diff --git a/src/common/transformations/include/transformations/op_conversions/convert_slicescatter.hpp b/src/common/transformations/include/transformations/op_conversions/convert_slicescatter.hpp new file mode 100644 index 00000000000000..020b4e236fcac5 --- /dev/null +++ b/src/common/transformations/include/transformations/op_conversions/convert_slicescatter.hpp @@ -0,0 +1,22 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/pass/matcher_pass.hpp" +#include "transformations_visibility.hpp" + +namespace ov { +namespace pass { + +class TRANSFORMATIONS_API ConvertSliceScatter; + +} // namespace pass +} // namespace ov + +class ov::pass::ConvertSliceScatter : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("ConvertSliceScatter", "0"); + ConvertSliceScatter(); +}; diff --git a/src/common/transformations/src/transformations/common_optimizations/common_optimizations.cpp b/src/common/transformations/src/transformations/common_optimizations/common_optimizations.cpp index 500d003bd4642e..9d46b583a828f2 100644 --- a/src/common/transformations/src/transformations/common_optimizations/common_optimizations.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/common_optimizations.cpp @@ -94,6 +94,7 @@ #include "transformations/op_conversions/convert_scatter_elements_update12_downgrade.hpp" #include "transformations/op_conversions/convert_scatter_nd_update15_downgrade.hpp" #include "transformations/op_conversions/convert_slice_to_strided_slice.hpp" +#include "transformations/op_conversions/convert_slicescatter.hpp" #include "transformations/op_conversions/convert_softmax_downgrade.hpp" #include "transformations/op_conversions/convert_softmax_upgrade.hpp" #include "transformations/op_conversions/convert_space_to_depth.hpp" @@ -233,6 +234,7 @@ bool ov::pass::CommonOptimizations::run_on_model(const std::shared_ptr(); ADD_MATCHER(fq_fusions, FakeQuantizeMulFusion) diff --git a/src/common/transformations/src/transformations/op_conversions/convert_slicescatter.cpp b/src/common/transformations/src/transformations/op_conversions/convert_slicescatter.cpp new file mode 100644 index 00000000000000..391e1c54903b41 --- /dev/null +++ b/src/common/transformations/src/transformations/op_conversions/convert_slicescatter.cpp @@ -0,0 +1,76 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/op_conversions/convert_slicescatter.hpp" + +#include +#include + +#include "itt.hpp" +#include "openvino/core/rt_info.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/range.hpp" +#include "openvino/op/reduce_prod.hpp" +#include "openvino/op/reshape.hpp" +#include "openvino/op/scatter_nd_update.hpp" +#include "openvino/op/shape_of.hpp" +#include "openvino/op/slice.hpp" +#include "openvino/op/slice_scatter.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" + +ov::pass::ConvertSliceScatter::ConvertSliceScatter() { + MATCHER_SCOPE(ConvertSliceScatter); + + auto slicescatter = pattern::wrap_type(); + + matcher_pass_callback callback = [](pattern::Matcher& m) { + auto slice_node = ov::as_type_ptr(m.get_match_root()); + if (!slice_node) { + return false; + } + NodeRegistry node_registry; + auto const_0 = node_registry.make(ov::element::i64, Shape{}, 0); + auto const_1 = node_registry.make(ov::element::i64, Shape{}, 1); + auto const_1d_neg_1 = + node_registry.make(ov::element::i64, Shape{1}, std::vector{-1}); + auto const_scatter_indices_shape = + node_registry.make(ov::element::i64, Shape{2}, std::vector{-1, 1}); + auto data_shape = node_registry.make(slice_node->input_value(0), ov::element::i64); + auto num_elements_data = node_registry.make(data_shape, const_0, false); + auto data_indices_flatten = + node_registry.make(const_0, num_elements_data, const_1, ov::element::i64); + auto full_data_indices = node_registry.make(data_indices_flatten, data_shape, false); + std::shared_ptr slice_indices; + if (slice_node->get_input_size() == 5) { + slice_indices = node_registry.make(full_data_indices, + slice_node->input_value(2), + slice_node->input_value(3), + slice_node->input_value(4)); + } else { + slice_indices = node_registry.make(full_data_indices, + slice_node->input_value(2), + slice_node->input_value(3), + slice_node->input_value(4), + slice_node->input_value(5)); + } + auto slice_indices_flatten = + node_registry.make(slice_indices, const_scatter_indices_shape, false); + auto updates_flatten = + node_registry.make(slice_node->input_value(1), const_1d_neg_1, false); + auto data_flatten = node_registry.make(slice_node->input_value(0), const_1d_neg_1, false); + auto output_flatten = + node_registry.make(data_flatten, slice_indices_flatten, updates_flatten); + auto output = node_registry.make(output_flatten, data_shape, false); + + output->set_friendly_name(slice_node->get_friendly_name()); + copy_runtime_info(slice_node, node_registry.get()); + replace_node(slice_node, output); + slice_node->clear_control_dependencies(); + + return true; + }; + + auto m = std::make_shared(slicescatter, matcher_name); + this->register_matcher(m, callback); +} diff --git a/src/common/transformations/tests/op_conversions/convert_slicescatter_decomposition_test.cpp b/src/common/transformations/tests/op_conversions/convert_slicescatter_decomposition_test.cpp new file mode 100644 index 00000000000000..4aa94c4afb66ca --- /dev/null +++ b/src/common/transformations/tests/op_conversions/convert_slicescatter_decomposition_test.cpp @@ -0,0 +1,85 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include + +#include "common_test_utils/ov_test_utils.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/opsets/opset15.hpp" +#include "openvino/opsets/opset8.hpp" +#include "openvino/pass/manager.hpp" +#include "transformations/op_conversions/convert_slicescatter.hpp" +#include "transformations/utils/utils.hpp" +using namespace testing; + +namespace { + +std::shared_ptr create_v15_model(bool with_axes) { + const auto data = std::make_shared(ov::element::f32, ov::Shape{256, 10, 15}); + const auto updates = std::make_shared(ov::element::f32, ov::Shape{4, 7, 2}); + const auto start = ov::op::v0::Constant::create(ov::element::i32, {3}, {2, 0, 0}); + const auto stop = ov::op::v0::Constant::create(ov::element::i32, {3}, {9, 7, 2}); + const auto step = ov::op::v0::Constant::create(ov::element::i32, {3}, {2, 1, 1}); + const auto axes = ov::op::v0::Constant::create(ov::element::i32, {3}, {0, 1, 2}); + std::shared_ptr slicescatter; + if (!with_axes) { + slicescatter = std::make_shared(data, updates, start, stop, step); + } else { + slicescatter = std::make_shared(data, updates, start, stop, step, axes); + } + slicescatter->set_friendly_name("slicescatter15"); + return std::make_shared(slicescatter->outputs(), ov::ParameterVector{data, updates}); +} + +std::shared_ptr create_decomposed_model(bool with_axes) { + const auto data = std::make_shared(ov::element::f32, ov::Shape{256, 10, 15}); + const auto updates = std::make_shared(ov::element::f32, ov::Shape{4, 7, 2}); + const auto start = ov::op::v0::Constant::create(ov::element::i32, {3}, {2, 0, 0}); + const auto stop = ov::op::v0::Constant::create(ov::element::i32, {3}, {9, 7, 2}); + const auto step = ov::op::v0::Constant::create(ov::element::i32, {3}, {2, 1, 1}); + const auto axes = ov::op::v0::Constant::create(ov::element::i32, {3}, {0, 1, 2}); + auto zero = ov::op::v0::Constant::create(ov::element::i64, {}, {0}); + auto one = ov::op::v0::Constant::create(ov::element::i64, {}, {1}); + auto neg_one_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1}); + auto scatter_shape = ov::op::v0::Constant::create(ov::element::i64, {2}, {-1, 1}); + auto data_shape = std::make_shared(data, ov::element::i64); + auto num_elements_data = std::make_shared(data_shape, zero, false); + auto data_indices_flattened = std::make_shared(zero, num_elements_data, one, ov::element::i64); + auto full_data_indices = std::make_shared(data_indices_flattened, data_shape, false); + std::shared_ptr slice_indices; + if (!with_axes) { + slice_indices = std::make_shared(full_data_indices, start, stop, step); + } else { + slice_indices = std::make_shared(full_data_indices, start, stop, step, axes); + } + auto slice_indices_flatten = std::make_shared(slice_indices, scatter_shape, false); + auto updates_flatten = std::make_shared(updates, neg_one_1d, false); + auto data_flatten = std::make_shared(data, neg_one_1d, false); + auto output_flatten = + std::make_shared(data_flatten, slice_indices_flatten, updates_flatten); + auto slicescatter = std::make_shared(output_flatten, data_shape, false); + slicescatter->set_friendly_name("slicescatter15"); + + return std::make_shared(slicescatter->outputs(), ov::ParameterVector{data, updates}); +} + +} // namespace + +TEST_F(TransformationTestsF, ConvertSliceScatter15Decomposition_axes) { + manager.register_pass(); + model = create_v15_model(true); + model_ref = create_decomposed_model(true); + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); + comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); +} + +TEST_F(TransformationTestsF, ConvertSliceScatter15Decomposition_no_axes) { + manager.register_pass(); + model = create_v15_model(false); + model_ref = create_decomposed_model(false); + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); + comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); +} diff --git a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp index abf1ad8f283205..3941f800ab6ba0 100644 --- a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp +++ b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp @@ -63,6 +63,7 @@ #include "transformations/op_conversions/convert_scatter_nd_update15_downgrade.hpp" #include "transformations/op_conversions/convert_sequences_to_tensor_iterator.hpp" #include "transformations/op_conversions/convert_shuffle_channels3.hpp" +#include "transformations/op_conversions/convert_slicescatter.hpp" #include "transformations/op_conversions/convert_slice_to_strided_slice.hpp" #include "transformations/op_conversions/convert_space_to_batch.hpp" #include "transformations/op_conversions/convert_space_to_depth.hpp" @@ -644,6 +645,7 @@ void Transformations::PreLpt(const std::vector& defaultPrecis CPU_DISABLE_PASS_COMMON(manager, ov::pass::HSwishDecomposition); CPU_DISABLE_PASS_COMMON(manager, ov::pass::MatMulConstTransposesExtraction); CPU_DISABLE_PASS_COMMON(manager, ov::pass::ConvertScatterNDUpdate15ToScatterNDUpdate3); + CPU_DISABLE_PASS_COMMON(manager, ov::pass::ConvertSliceScatter); CPU_DISABLE_PASS_X64(manager, ov::pass::HSigmoidDecomposition); CPU_DISABLE_PASS_X64(manager, ov::pass::ReduceL1Decomposition);