forked from openvinotoolkit/openvino
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add SliceScatter15 decomposition transformation
- Loading branch information
1 parent
9c432a3
commit af1198e
Showing
5 changed files
with
187 additions
and
0 deletions.
There are no files selected for viewing
22 changes: 22 additions & 0 deletions
22
src/common/transformations/include/transformations/op_conversions/convert_slicescatter.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "openvino/pass/matcher_pass.hpp" | ||
#include "transformations_visibility.hpp" | ||
|
||
namespace ov { | ||
namespace pass { | ||
|
||
class TRANSFORMATIONS_API ConvertSliceScatter; | ||
|
||
} // namespace pass | ||
} // namespace ov | ||
|
||
class ov::pass::ConvertSliceScatter : public ov::pass::MatcherPass { | ||
public: | ||
OPENVINO_RTTI("ConvertSliceScatter", "0"); | ||
ConvertSliceScatter(); | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
76 changes: 76 additions & 0 deletions
76
src/common/transformations/src/transformations/op_conversions/convert_slicescatter.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "transformations/op_conversions/convert_slicescatter.hpp" | ||
|
||
#include <memory> | ||
#include <vector> | ||
|
||
#include "itt.hpp" | ||
#include "openvino/core/rt_info.hpp" | ||
#include "openvino/op/constant.hpp" | ||
#include "openvino/op/range.hpp" | ||
#include "openvino/op/reduce_prod.hpp" | ||
#include "openvino/op/reshape.hpp" | ||
#include "openvino/op/scatter_nd_update.hpp" | ||
#include "openvino/op/shape_of.hpp" | ||
#include "openvino/op/slice.hpp" | ||
#include "openvino/op/slice_scatter.hpp" | ||
#include "openvino/pass/pattern/op/wrap_type.hpp" | ||
|
||
ov::pass::ConvertSliceScatter::ConvertSliceScatter() { | ||
MATCHER_SCOPE(ConvertSliceScatter); | ||
|
||
auto slicescatter = pattern::wrap_type<ov::op::v15::SliceScatter>(); | ||
|
||
matcher_pass_callback callback = [](pattern::Matcher& m) { | ||
auto slice_node = ov::as_type_ptr<ov::op::v15::SliceScatter>(m.get_match_root()); | ||
if (!slice_node) { | ||
return false; | ||
} | ||
NodeRegistry node_registry; | ||
auto const_0 = node_registry.make<ov::op::v0::Constant>(ov::element::i64, Shape{}, 0); | ||
auto const_1 = node_registry.make<ov::op::v0::Constant>(ov::element::i64, Shape{}, 1); | ||
auto const_1d_neg_1 = | ||
node_registry.make<ov::op::v0::Constant>(ov::element::i64, Shape{1}, std::vector<int64_t>{-1}); | ||
auto const_scatter_indices_shape = | ||
node_registry.make<ov::op::v0::Constant>(ov::element::i64, Shape{2}, std::vector<int64_t>{-1, 1}); | ||
auto data_shape = node_registry.make<ov::op::v3::ShapeOf>(slice_node->input_value(0), ov::element::i64); | ||
auto num_elements_data = node_registry.make<ov::op::v1::ReduceProd>(data_shape, const_0, false); | ||
auto data_indices_flatten = | ||
node_registry.make<ov::op::v4::Range>(const_0, num_elements_data, const_1, ov::element::i64); | ||
auto full_data_indices = node_registry.make<ov::op::v1::Reshape>(data_indices_flatten, data_shape, false); | ||
std::shared_ptr<ov::op::v8::Slice> slice_indices; | ||
if (slice_node->get_input_size() == 5) { | ||
slice_indices = node_registry.make<ov::op::v8::Slice>(full_data_indices, | ||
slice_node->input_value(2), | ||
slice_node->input_value(3), | ||
slice_node->input_value(4)); | ||
} else { | ||
slice_indices = node_registry.make<ov::op::v8::Slice>(full_data_indices, | ||
slice_node->input_value(2), | ||
slice_node->input_value(3), | ||
slice_node->input_value(4), | ||
slice_node->input_value(5)); | ||
} | ||
auto slice_indices_flatten = | ||
node_registry.make<ov::op::v1::Reshape>(slice_indices, const_scatter_indices_shape, false); | ||
auto updates_flatten = | ||
node_registry.make<ov::op::v1::Reshape>(slice_node->input_value(1), const_1d_neg_1, false); | ||
auto data_flatten = node_registry.make<ov::op::v1::Reshape>(slice_node->input_value(0), const_1d_neg_1, false); | ||
auto output_flatten = | ||
node_registry.make<ov::op::v3::ScatterNDUpdate>(data_flatten, slice_indices_flatten, updates_flatten); | ||
auto output = node_registry.make<ov::op::v1::Reshape>(output_flatten, data_shape, false); | ||
|
||
output->set_friendly_name(slice_node->get_friendly_name()); | ||
copy_runtime_info(slice_node, node_registry.get()); | ||
replace_node(slice_node, output); | ||
slice_node->clear_control_dependencies(); | ||
|
||
return true; | ||
}; | ||
|
||
auto m = std::make_shared<pattern::Matcher>(slicescatter, matcher_name); | ||
this->register_matcher(m, callback); | ||
} |
85 changes: 85 additions & 0 deletions
85
src/common/transformations/tests/op_conversions/convert_slicescatter_decomposition_test.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include <gtest/gtest.h> | ||
|
||
#include <memory> | ||
|
||
#include "common_test_utils/ov_test_utils.hpp" | ||
#include "openvino/op/constant.hpp" | ||
#include "openvino/opsets/opset15.hpp" | ||
#include "openvino/opsets/opset8.hpp" | ||
#include "openvino/pass/manager.hpp" | ||
#include "transformations/op_conversions/convert_slicescatter.hpp" | ||
#include "transformations/utils/utils.hpp" | ||
using namespace testing; | ||
|
||
namespace { | ||
|
||
std::shared_ptr<ov::Model> create_v15_model(bool with_axes) { | ||
const auto data = std::make_shared<ov::opset15::Parameter>(ov::element::f32, ov::Shape{256, 10, 15}); | ||
const auto updates = std::make_shared<ov::opset15::Parameter>(ov::element::f32, ov::Shape{4, 7, 2}); | ||
const auto start = ov::op::v0::Constant::create(ov::element::i32, {3}, {2, 0, 0}); | ||
const auto stop = ov::op::v0::Constant::create(ov::element::i32, {3}, {9, 7, 2}); | ||
const auto step = ov::op::v0::Constant::create(ov::element::i32, {3}, {2, 1, 1}); | ||
const auto axes = ov::op::v0::Constant::create(ov::element::i32, {3}, {0, 1, 2}); | ||
std::shared_ptr<ov::opset15::SliceScatter> slicescatter; | ||
if (!with_axes) { | ||
slicescatter = std::make_shared<ov::opset15::SliceScatter>(data, updates, start, stop, step); | ||
} else { | ||
slicescatter = std::make_shared<ov::opset15::SliceScatter>(data, updates, start, stop, step, axes); | ||
} | ||
slicescatter->set_friendly_name("slicescatter15"); | ||
return std::make_shared<ov::Model>(slicescatter->outputs(), ov::ParameterVector{data, updates}); | ||
} | ||
|
||
std::shared_ptr<ov::Model> create_decomposed_model(bool with_axes) { | ||
const auto data = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::Shape{256, 10, 15}); | ||
const auto updates = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::Shape{4, 7, 2}); | ||
const auto start = ov::op::v0::Constant::create(ov::element::i32, {3}, {2, 0, 0}); | ||
const auto stop = ov::op::v0::Constant::create(ov::element::i32, {3}, {9, 7, 2}); | ||
const auto step = ov::op::v0::Constant::create(ov::element::i32, {3}, {2, 1, 1}); | ||
const auto axes = ov::op::v0::Constant::create(ov::element::i32, {3}, {0, 1, 2}); | ||
auto zero = ov::op::v0::Constant::create(ov::element::i64, {}, {0}); | ||
auto one = ov::op::v0::Constant::create(ov::element::i64, {}, {1}); | ||
auto neg_one_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1}); | ||
auto scatter_shape = ov::op::v0::Constant::create(ov::element::i64, {2}, {-1, 1}); | ||
auto data_shape = std::make_shared<ov::opset8::ShapeOf>(data, ov::element::i64); | ||
auto num_elements_data = std::make_shared<ov::opset8::ReduceProd>(data_shape, zero, false); | ||
auto data_indices_flattened = std::make_shared<ov::opset8::Range>(zero, num_elements_data, one, ov::element::i64); | ||
auto full_data_indices = std::make_shared<ov::opset8::Reshape>(data_indices_flattened, data_shape, false); | ||
std::shared_ptr<ov::opset8::Slice> slice_indices; | ||
if (!with_axes) { | ||
slice_indices = std::make_shared<ov::opset8::Slice>(full_data_indices, start, stop, step); | ||
} else { | ||
slice_indices = std::make_shared<ov::opset8::Slice>(full_data_indices, start, stop, step, axes); | ||
} | ||
auto slice_indices_flatten = std::make_shared<ov::opset8::Reshape>(slice_indices, scatter_shape, false); | ||
auto updates_flatten = std::make_shared<ov::opset8::Reshape>(updates, neg_one_1d, false); | ||
auto data_flatten = std::make_shared<ov::opset8::Reshape>(data, neg_one_1d, false); | ||
auto output_flatten = | ||
std::make_shared<ov::opset8::ScatterNDUpdate>(data_flatten, slice_indices_flatten, updates_flatten); | ||
auto slicescatter = std::make_shared<ov::opset8::Reshape>(output_flatten, data_shape, false); | ||
slicescatter->set_friendly_name("slicescatter15"); | ||
|
||
return std::make_shared<ov::Model>(slicescatter->outputs(), ov::ParameterVector{data, updates}); | ||
} | ||
|
||
} // namespace | ||
|
||
TEST_F(TransformationTestsF, ConvertSliceScatter15Decomposition_axes) { | ||
manager.register_pass<ov::pass::ConvertSliceScatter>(); | ||
model = create_v15_model(true); | ||
model_ref = create_decomposed_model(true); | ||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); | ||
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); | ||
} | ||
|
||
TEST_F(TransformationTestsF, ConvertSliceScatter15Decomposition_no_axes) { | ||
manager.register_pass<ov::pass::ConvertSliceScatter>(); | ||
model = create_v15_model(false); | ||
model_ref = create_decomposed_model(false); | ||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); | ||
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters