Skip to content

Commit

Permalink
Add SliceScatter15 decomposition transformation
Browse files Browse the repository at this point in the history
  • Loading branch information
mmikolajcz committed Oct 18, 2024
1 parent 9c432a3 commit af1198e
Show file tree
Hide file tree
Showing 5 changed files with 187 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/matcher_pass.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {

class TRANSFORMATIONS_API ConvertSliceScatter;

} // namespace pass
} // namespace ov

class ov::pass::ConvertSliceScatter : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ConvertSliceScatter", "0");
ConvertSliceScatter();
};
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
#include "transformations/op_conversions/convert_scatter_elements_update12_downgrade.hpp"
#include "transformations/op_conversions/convert_scatter_nd_update15_downgrade.hpp"
#include "transformations/op_conversions/convert_slice_to_strided_slice.hpp"
#include "transformations/op_conversions/convert_slicescatter.hpp"
#include "transformations/op_conversions/convert_softmax_downgrade.hpp"
#include "transformations/op_conversions/convert_softmax_upgrade.hpp"
#include "transformations/op_conversions/convert_space_to_depth.hpp"
Expand Down Expand Up @@ -233,6 +234,7 @@ bool ov::pass::CommonOptimizations::run_on_model(const std::shared_ptr<ov::Model
REGISTER_PASS(manager, ConvertEmbeddingBagOffsets15ToEmbeddingBagOffsetsSum3)
REGISTER_PASS(manager, ConvertEmbeddingBagPacked15ToEmbeddingBagPackedSum3)
REGISTER_PASS(manager, ConvertScatterNDUpdate15ToScatterNDUpdate3)
REGISTER_PASS(manager, ConvertSliceScatter)

auto fq_fusions = manager.register_pass<GraphRewrite>();
ADD_MATCHER(fq_fusions, FakeQuantizeMulFusion)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/op_conversions/convert_slicescatter.hpp"

#include <memory>
#include <vector>

#include "itt.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/range.hpp"
#include "openvino/op/reduce_prod.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/scatter_nd_update.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/slice.hpp"
#include "openvino/op/slice_scatter.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"

ov::pass::ConvertSliceScatter::ConvertSliceScatter() {
MATCHER_SCOPE(ConvertSliceScatter);

auto slicescatter = pattern::wrap_type<ov::op::v15::SliceScatter>();

matcher_pass_callback callback = [](pattern::Matcher& m) {
auto slice_node = ov::as_type_ptr<ov::op::v15::SliceScatter>(m.get_match_root());
if (!slice_node) {
return false;
}
NodeRegistry node_registry;
auto const_0 = node_registry.make<ov::op::v0::Constant>(ov::element::i64, Shape{}, 0);
auto const_1 = node_registry.make<ov::op::v0::Constant>(ov::element::i64, Shape{}, 1);
auto const_1d_neg_1 =
node_registry.make<ov::op::v0::Constant>(ov::element::i64, Shape{1}, std::vector<int64_t>{-1});
auto const_scatter_indices_shape =
node_registry.make<ov::op::v0::Constant>(ov::element::i64, Shape{2}, std::vector<int64_t>{-1, 1});
auto data_shape = node_registry.make<ov::op::v3::ShapeOf>(slice_node->input_value(0), ov::element::i64);
auto num_elements_data = node_registry.make<ov::op::v1::ReduceProd>(data_shape, const_0, false);
auto data_indices_flatten =
node_registry.make<ov::op::v4::Range>(const_0, num_elements_data, const_1, ov::element::i64);
auto full_data_indices = node_registry.make<ov::op::v1::Reshape>(data_indices_flatten, data_shape, false);
std::shared_ptr<ov::op::v8::Slice> slice_indices;
if (slice_node->get_input_size() == 5) {
slice_indices = node_registry.make<ov::op::v8::Slice>(full_data_indices,
slice_node->input_value(2),
slice_node->input_value(3),
slice_node->input_value(4));
} else {
slice_indices = node_registry.make<ov::op::v8::Slice>(full_data_indices,
slice_node->input_value(2),
slice_node->input_value(3),
slice_node->input_value(4),
slice_node->input_value(5));
}
auto slice_indices_flatten =
node_registry.make<ov::op::v1::Reshape>(slice_indices, const_scatter_indices_shape, false);
auto updates_flatten =
node_registry.make<ov::op::v1::Reshape>(slice_node->input_value(1), const_1d_neg_1, false);
auto data_flatten = node_registry.make<ov::op::v1::Reshape>(slice_node->input_value(0), const_1d_neg_1, false);
auto output_flatten =
node_registry.make<ov::op::v3::ScatterNDUpdate>(data_flatten, slice_indices_flatten, updates_flatten);
auto output = node_registry.make<ov::op::v1::Reshape>(output_flatten, data_shape, false);

output->set_friendly_name(slice_node->get_friendly_name());
copy_runtime_info(slice_node, node_registry.get());
replace_node(slice_node, output);
slice_node->clear_control_dependencies();

return true;
};

auto m = std::make_shared<pattern::Matcher>(slicescatter, matcher_name);
this->register_matcher(m, callback);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <gtest/gtest.h>

#include <memory>

#include "common_test_utils/ov_test_utils.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/opsets/opset15.hpp"
#include "openvino/opsets/opset8.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/op_conversions/convert_slicescatter.hpp"
#include "transformations/utils/utils.hpp"
using namespace testing;

namespace {

std::shared_ptr<ov::Model> create_v15_model(bool with_axes) {
const auto data = std::make_shared<ov::opset15::Parameter>(ov::element::f32, ov::Shape{256, 10, 15});
const auto updates = std::make_shared<ov::opset15::Parameter>(ov::element::f32, ov::Shape{4, 7, 2});
const auto start = ov::op::v0::Constant::create(ov::element::i32, {3}, {2, 0, 0});
const auto stop = ov::op::v0::Constant::create(ov::element::i32, {3}, {9, 7, 2});
const auto step = ov::op::v0::Constant::create(ov::element::i32, {3}, {2, 1, 1});
const auto axes = ov::op::v0::Constant::create(ov::element::i32, {3}, {0, 1, 2});
std::shared_ptr<ov::opset15::SliceScatter> slicescatter;
if (!with_axes) {
slicescatter = std::make_shared<ov::opset15::SliceScatter>(data, updates, start, stop, step);
} else {
slicescatter = std::make_shared<ov::opset15::SliceScatter>(data, updates, start, stop, step, axes);
}
slicescatter->set_friendly_name("slicescatter15");
return std::make_shared<ov::Model>(slicescatter->outputs(), ov::ParameterVector{data, updates});
}

std::shared_ptr<ov::Model> create_decomposed_model(bool with_axes) {
const auto data = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::Shape{256, 10, 15});
const auto updates = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::Shape{4, 7, 2});
const auto start = ov::op::v0::Constant::create(ov::element::i32, {3}, {2, 0, 0});
const auto stop = ov::op::v0::Constant::create(ov::element::i32, {3}, {9, 7, 2});
const auto step = ov::op::v0::Constant::create(ov::element::i32, {3}, {2, 1, 1});
const auto axes = ov::op::v0::Constant::create(ov::element::i32, {3}, {0, 1, 2});
auto zero = ov::op::v0::Constant::create(ov::element::i64, {}, {0});
auto one = ov::op::v0::Constant::create(ov::element::i64, {}, {1});
auto neg_one_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1});
auto scatter_shape = ov::op::v0::Constant::create(ov::element::i64, {2}, {-1, 1});
auto data_shape = std::make_shared<ov::opset8::ShapeOf>(data, ov::element::i64);
auto num_elements_data = std::make_shared<ov::opset8::ReduceProd>(data_shape, zero, false);
auto data_indices_flattened = std::make_shared<ov::opset8::Range>(zero, num_elements_data, one, ov::element::i64);
auto full_data_indices = std::make_shared<ov::opset8::Reshape>(data_indices_flattened, data_shape, false);
std::shared_ptr<ov::opset8::Slice> slice_indices;
if (!with_axes) {
slice_indices = std::make_shared<ov::opset8::Slice>(full_data_indices, start, stop, step);
} else {
slice_indices = std::make_shared<ov::opset8::Slice>(full_data_indices, start, stop, step, axes);
}
auto slice_indices_flatten = std::make_shared<ov::opset8::Reshape>(slice_indices, scatter_shape, false);
auto updates_flatten = std::make_shared<ov::opset8::Reshape>(updates, neg_one_1d, false);
auto data_flatten = std::make_shared<ov::opset8::Reshape>(data, neg_one_1d, false);
auto output_flatten =
std::make_shared<ov::opset8::ScatterNDUpdate>(data_flatten, slice_indices_flatten, updates_flatten);
auto slicescatter = std::make_shared<ov::opset8::Reshape>(output_flatten, data_shape, false);
slicescatter->set_friendly_name("slicescatter15");

return std::make_shared<ov::Model>(slicescatter->outputs(), ov::ParameterVector{data, updates});
}

} // namespace

TEST_F(TransformationTestsF, ConvertSliceScatter15Decomposition_axes) {
manager.register_pass<ov::pass::ConvertSliceScatter>();
model = create_v15_model(true);
model_ref = create_decomposed_model(true);
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
}

TEST_F(TransformationTestsF, ConvertSliceScatter15Decomposition_no_axes) {
manager.register_pass<ov::pass::ConvertSliceScatter>();
model = create_v15_model(false);
model_ref = create_decomposed_model(false);
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
#include "transformations/op_conversions/convert_scatter_nd_update15_downgrade.hpp"
#include "transformations/op_conversions/convert_sequences_to_tensor_iterator.hpp"
#include "transformations/op_conversions/convert_shuffle_channels3.hpp"
#include "transformations/op_conversions/convert_slicescatter.hpp"
#include "transformations/op_conversions/convert_slice_to_strided_slice.hpp"
#include "transformations/op_conversions/convert_space_to_batch.hpp"
#include "transformations/op_conversions/convert_space_to_depth.hpp"
Expand Down Expand Up @@ -644,6 +645,7 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
CPU_DISABLE_PASS_COMMON(manager, ov::pass::HSwishDecomposition);
CPU_DISABLE_PASS_COMMON(manager, ov::pass::MatMulConstTransposesExtraction);
CPU_DISABLE_PASS_COMMON(manager, ov::pass::ConvertScatterNDUpdate15ToScatterNDUpdate3);
CPU_DISABLE_PASS_COMMON(manager, ov::pass::ConvertSliceScatter);
CPU_DISABLE_PASS_X64(manager, ov::pass::HSigmoidDecomposition);

CPU_DISABLE_PASS_X64(manager, ov::pass::ReduceL1Decomposition);
Expand Down

0 comments on commit af1198e

Please sign in to comment.