Skip to content

Commit

Permalink
Apply code review changes
Browse files Browse the repository at this point in the history
  • Loading branch information
mmikolajcz committed Oct 23, 2024
1 parent ef19e35 commit 40bf304
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,25 +22,26 @@
ov::pass::ConvertSliceScatter::ConvertSliceScatter() {
MATCHER_SCOPE(ConvertSliceScatter);

auto slicescatter = pattern::wrap_type<ov::op::v15::SliceScatter>();
const auto& slicescatter = pattern::wrap_type<ov::op::v15::SliceScatter>();

matcher_pass_callback callback = [](pattern::Matcher& m) {
auto slice_node = ov::as_type_ptr<ov::op::v15::SliceScatter>(m.get_match_root());
if (!slice_node) {
const matcher_pass_callback callback = [this](pattern::Matcher& m) {
const auto& slice_node = ov::as_type_ptr<ov::op::v15::SliceScatter>(m.get_match_root());
if (!slice_node || transformation_callback(slice_node)) {
return false;
}
NodeRegistry node_registry;
auto const_0 = node_registry.make<ov::op::v0::Constant>(ov::element::i64, Shape{}, 0);
auto const_1 = node_registry.make<ov::op::v0::Constant>(ov::element::i64, Shape{}, 1);
auto const_1d_neg_1 =
const auto& const_0 = node_registry.make<ov::op::v0::Constant>(ov::element::i64, Shape{}, 0);
const auto& const_1 = node_registry.make<ov::op::v0::Constant>(ov::element::i64, Shape{}, 1);
const auto& const_1d_neg_1 =
node_registry.make<ov::op::v0::Constant>(ov::element::i64, Shape{1}, std::vector<int64_t>{-1});
auto const_scatter_indices_shape =
const auto& const_scatter_indices_shape =
node_registry.make<ov::op::v0::Constant>(ov::element::i64, Shape{2}, std::vector<int64_t>{-1, 1});
auto data_shape = node_registry.make<ov::op::v3::ShapeOf>(slice_node->input_value(0), ov::element::i64);
auto num_elements_data = node_registry.make<ov::op::v1::ReduceProd>(data_shape, const_0, false);
auto data_indices_flatten =
const auto& data_shape = node_registry.make<ov::op::v3::ShapeOf>(slice_node->input_value(0), ov::element::i64);
const auto& num_elements_data = node_registry.make<ov::op::v1::ReduceProd>(data_shape, const_0, false);
const auto& data_indices_flatten =
node_registry.make<ov::op::v4::Range>(const_0, num_elements_data, const_1, ov::element::i64);
auto full_data_indices = node_registry.make<ov::op::v1::Reshape>(data_indices_flatten, data_shape, false);
const auto& full_data_indices =
node_registry.make<ov::op::v1::Reshape>(data_indices_flatten, data_shape, false);
std::shared_ptr<ov::op::v8::Slice> slice_indices;
if (slice_node->get_input_size() == 5) {
slice_indices = node_registry.make<ov::op::v8::Slice>(full_data_indices,
Expand All @@ -54,23 +55,23 @@ ov::pass::ConvertSliceScatter::ConvertSliceScatter() {
slice_node->input_value(4),
slice_node->input_value(5));
}
auto slice_indices_flatten =
const auto& slice_indices_flatten =
node_registry.make<ov::op::v1::Reshape>(slice_indices, const_scatter_indices_shape, false);
auto updates_flatten =
const auto& updates_flatten =
node_registry.make<ov::op::v1::Reshape>(slice_node->input_value(1), const_1d_neg_1, false);
auto data_flatten = node_registry.make<ov::op::v1::Reshape>(slice_node->input_value(0), const_1d_neg_1, false);
auto output_flatten =
const auto& data_flatten =
node_registry.make<ov::op::v1::Reshape>(slice_node->input_value(0), const_1d_neg_1, false);
const auto& output_flatten =
node_registry.make<ov::op::v3::ScatterNDUpdate>(data_flatten, slice_indices_flatten, updates_flatten);
auto output = node_registry.make<ov::op::v1::Reshape>(output_flatten, data_shape, false);
const auto& output = node_registry.make<ov::op::v1::Reshape>(output_flatten, data_shape, false);

output->set_friendly_name(slice_node->get_friendly_name());
copy_runtime_info(slice_node, node_registry.get());
replace_node(slice_node, output);
slice_node->clear_control_dependencies();

return true;
};

auto m = std::make_shared<pattern::Matcher>(slicescatter, matcher_name);
const auto& m = std::make_shared<pattern::Matcher>(slicescatter, matcher_name);
this->register_matcher(m, callback);
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@
#include "transformations/utils/utils.hpp"
namespace {
std::shared_ptr<ov::Model> create_v15_model(ov::NodeVector inputs) {
const auto data = inputs.at(0);
const auto updates = inputs.at(1);
const auto start = inputs.at(2);
const auto stop = inputs.at(3);
const auto step = inputs.at(4);
const auto& data = inputs.at(0);
const auto& updates = inputs.at(1);
const auto& start = inputs.at(2);
const auto& stop = inputs.at(3);
const auto& step = inputs.at(4);
ov::ParameterVector params{};
for (auto inp : inputs) {
const auto param = ov::as_type_ptr<ov::op::v0::Parameter>(inp);
for (const auto& inp : inputs) {
const auto& param = ov::as_type_ptr<ov::op::v0::Parameter>(inp);
if (param) {
params.push_back(param);
}
Expand All @@ -38,40 +38,40 @@ std::shared_ptr<ov::Model> create_v15_model(ov::NodeVector inputs) {
}

std::shared_ptr<ov::Model> create_decomposed_model(ov::NodeVector inputs) {
const auto data = inputs.at(0);
const auto updates = inputs.at(1);
const auto start = inputs.at(2);
const auto stop = inputs.at(3);
const auto step = inputs.at(4);
const auto& data = inputs.at(0);
const auto& updates = inputs.at(1);
const auto& start = inputs.at(2);
const auto& stop = inputs.at(3);
const auto& step = inputs.at(4);
ov::ParameterVector params{};
for (auto inp : inputs) {
const auto param = ov::as_type_ptr<ov::op::v0::Parameter>(inp);
for (const auto& inp : inputs) {
const auto& param = ov::as_type_ptr<ov::op::v0::Parameter>(inp);
if (param) {
params.push_back(param);
}
}
auto const_0 = ov::op::v0::Constant::create(ov::element::i64, {}, {0});
auto const_1 = ov::op::v0::Constant::create(ov::element::i64, {}, {1});
auto const_1d_neg_1 = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1});
auto const_scatter_indices_shape = ov::op::v0::Constant::create(ov::element::i64, {2}, {-1, 1});
auto data_shape = std::make_shared<ov::opset8::ShapeOf>(data, ov::element::i64);
auto num_elements_data = std::make_shared<ov::opset8::ReduceProd>(data_shape, const_0, false);
auto data_indices_flatten =
const auto& const_0 = ov::op::v0::Constant::create(ov::element::i64, {}, {0});
const auto& const_1 = ov::op::v0::Constant::create(ov::element::i64, {}, {1});
const auto& const_1d_neg_1 = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1});
const auto& const_scatter_indices_shape = ov::op::v0::Constant::create(ov::element::i64, {2}, {-1, 1});
const auto& data_shape = std::make_shared<ov::opset8::ShapeOf>(data, ov::element::i64);
const auto& num_elements_data = std::make_shared<ov::opset8::ReduceProd>(data_shape, const_0, false);
const auto& data_indices_flatten =
std::make_shared<ov::opset8::Range>(const_0, num_elements_data, const_1, ov::element::i64);
auto full_data_indices = std::make_shared<ov::opset8::Reshape>(data_indices_flatten, data_shape, false);
const auto& full_data_indices = std::make_shared<ov::opset8::Reshape>(data_indices_flatten, data_shape, false);
std::shared_ptr<ov::opset8::Slice> slice_indices;
if (inputs.size() == 5) {
slice_indices = std::make_shared<ov::opset8::Slice>(full_data_indices, start, stop, step);
} else {
slice_indices = std::make_shared<ov::opset8::Slice>(full_data_indices, start, stop, step, inputs.at(5));
}
auto slice_indices_flatten =
const auto& slice_indices_flatten =
std::make_shared<ov::opset8::Reshape>(slice_indices, const_scatter_indices_shape, false);
auto updates_flatten = std::make_shared<ov::opset8::Reshape>(updates, const_1d_neg_1, false);
auto data_flatten = std::make_shared<ov::opset8::Reshape>(data, const_1d_neg_1, false);
auto output_flatten =
const auto& updates_flatten = std::make_shared<ov::opset8::Reshape>(updates, const_1d_neg_1, false);
const auto& data_flatten = std::make_shared<ov::opset8::Reshape>(data, const_1d_neg_1, false);
const auto& output_flatten =
std::make_shared<ov::opset8::ScatterNDUpdate>(data_flatten, slice_indices_flatten, updates_flatten);
auto slicescatter = std::make_shared<ov::opset8::Reshape>(output_flatten, data_shape, false);
const auto& slicescatter = std::make_shared<ov::opset8::Reshape>(output_flatten, data_shape, false);
slicescatter->set_friendly_name("slicescatter15");
return std::make_shared<ov::Model>(slicescatter->outputs(), params);
}
Expand Down

0 comments on commit 40bf304

Please sign in to comment.