From 637d510c436efc9bf771a7ec480b74ba8bef5dfc Mon Sep 17 00:00:00 2001 From: MATEUSZ MIKOLAJCZYK Date: Wed, 18 Dec 2024 11:06:23 +0000 Subject: [PATCH 01/18] Einsum core improvements Signed-off-by: MATEUSZ MIKOLAJCZYK --- src/core/reference/src/op/einsum.cpp | 25 ++- .../include/einsum_shape_inference.hpp | 21 +- src/core/src/op/einsum.cpp | 5 - src/core/tests/type_prop/einsum.cpp | 208 +++++++++++++++++- 4 files changed, 233 insertions(+), 26 deletions(-) diff --git a/src/core/reference/src/op/einsum.cpp b/src/core/reference/src/op/einsum.cpp index b8b23964346225..3b26491caa7b7b 100644 --- a/src/core/reference/src/op/einsum.cpp +++ b/src/core/reference/src/op/einsum.cpp @@ -425,7 +425,7 @@ void broadcast_input(ov::TensorVector& inputs, OPENVINO_ASSERT(input_ind < inputs.size()); ov::Tensor& input = inputs[input_ind]; const Shape old_shape = input.get_shape(); - Shape new_shape; + PartialShape new_shape; new_shape.insert(new_shape.end(), new_common_shape.begin(), new_common_shape.end()); if (is_separate_first) { new_shape.insert(new_shape.end(), separate_shape.begin(), separate_shape.end()); @@ -435,15 +435,15 @@ void broadcast_input(ov::TensorVector& inputs, new_shape.insert(new_shape.end(), separate_shape.begin(), separate_shape.end()); } - if (input.get_shape() == new_shape) { + if (input.get_shape() == new_shape.to_shape()) { return; } OPENVINO_ASSERT(old_shape.size() <= new_shape.size()); - auto output = ov::Tensor(input.get_element_type(), new_shape); - std::vector broadcast_axes(old_shape.size()); std::iota(broadcast_axes.begin(), broadcast_axes.end(), new_shape.size() - old_shape.size()); + OPENVINO_ASSERT(PartialShape::broadcast_merge_into(new_shape, old_shape, ov::op::AutoBroadcastType::NUMPY)); + auto output = ov::Tensor(input.get_element_type(), new_shape.to_shape()); reference::broadcast(reinterpret_cast(input.data()), reinterpret_cast(output.data()), @@ -853,8 +853,10 @@ void contract_two_inputs(ov::TensorVector& inputs, PartialShape common_sub_shape1 = compute_sub_shape(input_shape1, common_dims_begin, common_dims_end); PartialShape common_sub_shape2 = compute_sub_shape(input_shape2, common_dims_begin2, common_dims_end2); - Shape reduced_sub_shape_prod = compute_sub_shape(input_shape1, reduced_dims_begin, reduced_dims_end, true); - Shape reduced_sub_shape = compute_sub_shape(input_shape1, reduced_dims_begin, reduced_dims_end); + PartialShape reduced_sub_shape_prod = compute_sub_shape(input_shape1, reduced_dims_begin, reduced_dims_end, true); + PartialShape reduced_sub_shape = compute_sub_shape(input_shape1, reduced_dims_begin, reduced_dims_end); + Shape reduced_sub_shape_prod2 = compute_sub_shape(input_shape2, reduced_dims_begin2, reduced_dims_end2, true); + Shape reduced_sub_shape2 = compute_sub_shape(input_shape2, reduced_dims_begin2, reduced_dims_end2); Shape separate1_sub_shape = compute_sub_shape(input_shape1, separate1_dims_begin, separate1_dims_end); Shape separate2_sub_shape = compute_sub_shape(input_shape2, separate2_dims_begin, separate2_dims_end); @@ -862,29 +864,32 @@ void contract_two_inputs(ov::TensorVector& inputs, // in case of ellipsis among the common labels // reference::broadcast() PartialShape::broadcast_merge_into(common_sub_shape1, common_sub_shape2, op::AutoBroadcastType::NUMPY); + PartialShape::broadcast_merge_into(reduced_sub_shape, reduced_sub_shape2, op::AutoBroadcastType::NUMPY); + PartialShape::broadcast_merge_into(reduced_sub_shape_prod, reduced_sub_shape_prod2, op::AutoBroadcastType::NUMPY); Shape common_sub_shape = 
common_sub_shape1.get_shape(); broadcast_input(inputs, input_ind1, common_sub_shape, separate1_sub_shape, - reduced_sub_shape, + reduced_sub_shape.get_shape(), is_separate_first1); broadcast_input(inputs, input_ind2, common_sub_shape, separate2_sub_shape, - reduced_sub_shape, + reduced_sub_shape.get_shape(), is_separate_first2); ov::Tensor matmul_operand1 = reshape_input_for_matmul(input1, common_sub_shape, separate1_sub_shape, - reduced_sub_shape_prod, + reduced_sub_shape_prod.get_shape(), is_separate_first1); + ov::Tensor matmul_operand2 = reshape_input_for_matmul(input2, common_sub_shape, separate2_sub_shape, - reduced_sub_shape_prod, + reduced_sub_shape_prod.get_shape(), is_separate_first2); // step 3. apply MatMul operation for formatted inputs diff --git a/src/core/shape_inference/include/einsum_shape_inference.hpp b/src/core/shape_inference/include/einsum_shape_inference.hpp index eb84482af0f052..5de11922f894a4 100644 --- a/src/core/shape_inference/include/einsum_shape_inference.hpp +++ b/src/core/shape_inference/include/einsum_shape_inference.hpp @@ -31,15 +31,19 @@ std::vector shape_infer(const Einsum* op, const std::vector& input_s for (size_t input_idx = 0; input_idx < input_shapes.size(); ++input_idx) { const auto& pshape = input_shapes[input_idx]; const auto labels = Einsum::extract_labels(input_subscripts[input_idx]); + const auto has_ellipsis = std::any_of(labels.begin(), labels.end(), [](std::string label) { + return label == "..."; + }); if (pshape.rank().is_static()) { size_t input_rank = pshape.size(); // check that a rank is greater or equal to a number of labels // these numbers are always equal if there is no ellipsis in the subscript - NODE_VALIDATION_CHECK(op, - input_rank >= labels.size(), - "Input rank must be greater or equal to a number of labels in the " - "corresponding input subscript."); + NODE_VALIDATION_CHECK( + op, + (input_rank >= (labels.size() - 1) && has_ellipsis) || (input_rank == labels.size() && !has_ellipsis), + "Input rank must be greater or equal to a number of labels in the " + "corresponding input subscript."); for (size_t label_ind = 0, dim_ind = 0; label_ind < labels.size() && dim_ind < input_rank; ++label_ind) { auto const& label = labels[label_ind]; @@ -64,15 +68,20 @@ std::vector shape_infer(const Einsum* op, const std::vector& input_s label_to_shape[label] = TRShape{pshape[dim_ind]}; } else { NODE_VALIDATION_CHECK(op, - label_to_shape[label].compatible(TRShape{pshape[label_ind]}), + TRShape::broadcast_merge_into(label_to_shape[label], + TRShape{pshape[dim_ind]}, + op::AutoBroadcastType::NUMPY), "Different input dimensions indicated by the same labels for Einsum " "must be compatible."); - OPENVINO_ASSERT(TRShape::merge_into(label_to_shape[label], TRShape{pshape[dim_ind]})); } ++dim_ind; } } } else { + if (has_ellipsis) { + // Shape has dynamic rank and ellipsis + return {pshape}; + } for (auto const& label : labels) { NODE_VALIDATION_CHECK(op, label != "...", diff --git a/src/core/src/op/einsum.cpp b/src/core/src/op/einsum.cpp index 281dc58d07684e..8c6e6b34040760 100644 --- a/src/core/src/op/einsum.cpp +++ b/src/core/src/op/einsum.cpp @@ -139,11 +139,6 @@ void op::v7::Einsum::parse_equation(const std::string& equation, OPENVINO_ASSERT(is_subscript_correct(output_subscript, output_is_ellipsis_met), "Output subscript of Einsum equation must consist of either only " "alphabetic letters or alphabetic letters with one ellipsis."); - - // if the ellipsis is met in input subscripts, one ellipsis must be in the output subscript - 
OPENVINO_ASSERT(is_ellipsis_met == output_is_ellipsis_met, - "Output subscript of Einsum equation must contain one ellipsis if " - "ellipsis is met in any input subscript."); } } diff --git a/src/core/tests/type_prop/einsum.cpp b/src/core/tests/type_prop/einsum.cpp index 9fbb04fcc1b610..455c2840f432ac 100644 --- a/src/core/tests/type_prop/einsum.cpp +++ b/src/core/tests/type_prop/einsum.cpp @@ -177,7 +177,7 @@ TEST_F(TypePropEinsumTest, dynamic_shape_diag_extraction) { EXPECT_EQ(o->get_element_type(), et); EXPECT_EQ(o->get_output_size(), exp_einsum_outputs_count); EXPECT_EQ(o->get_output_partial_shape(0), PartialShape({{3, 5}, 3, 4})); - EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), ElementsAre(symbols[0], symbols[1], symbols[2])); + EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), ElementsAre(symbols[0], symbols[4], symbols[2])); } TEST_F(TypePropEinsumTest, dynamic_shape_ellipsis) { @@ -372,14 +372,212 @@ TEST_F(TypePropEinsumTest, incorrect_equation_not_broadcastable_shapes) { HasSubstr("Input dimensions labeled with ellipsis for Einsum must be broadcastable.")); } -TEST_F(TypePropEinsumTest, incorrect_equation_missed_ellipsis) { +TEST_F(TypePropEinsumTest, missed_out_ellipsis) { const std::string equation = "a...b,b...->a"; - const auto input_shapes = Shapes{{11, 1, 4, 3}, {3, 11, 7, 5}}; + const auto input_shapes = Shapes{{11, 1, 4, 3}, {3, 11, 7, 4}}; + const auto inputs = make_inputs(element::f32, input_shapes); + const auto o = make_op(inputs, equation); + EXPECT_EQ(o->get_equation(), equation); + EXPECT_EQ(o->get_element_type(), element::f32); + EXPECT_EQ(o->get_output_size(), exp_einsum_outputs_count); + EXPECT_EQ(o->get_output_partial_shape(0), PartialShape({11})); + EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), Each(nullptr)); +} + +TEST_F(TypePropEinsumTest, missed_rhs_out_ellipsis) { + const std::string equation = "a...b,b->a"; + + const auto input_shapes = Shapes{{11, 1, 4, 3}, {3}}; + const auto inputs = make_inputs(element::f32, input_shapes); + const auto o = make_op(inputs, equation); + EXPECT_EQ(o->get_equation(), equation); + EXPECT_EQ(o->get_element_type(), element::f32); + EXPECT_EQ(o->get_output_size(), exp_einsum_outputs_count); + EXPECT_EQ(o->get_output_partial_shape(0), PartialShape({11})); + EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), Each(nullptr)); +} + +TEST_F(TypePropEinsumTest, missed_lhs_out_ellipsis) { + const std::string equation = "ab,b...->a"; + + const auto input_shapes = Shapes{{11, 3}, {3, 11, 7, 4}}; + const auto inputs = make_inputs(element::f32, input_shapes); + const auto o = make_op(inputs, equation); + EXPECT_EQ(o->get_equation(), equation); + EXPECT_EQ(o->get_element_type(), element::f32); + EXPECT_EQ(o->get_output_size(), exp_einsum_outputs_count); + EXPECT_EQ(o->get_output_partial_shape(0), PartialShape({11})); + EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), Each(nullptr)); +} + +TEST_F(TypePropEinsumTest, missed_rhs_ellipsis) { + const std::string equation = "a...b,b->a..."; + + const auto input_shapes = Shapes{{11, 1, 4, 3}, {3}}; + const auto inputs = make_inputs(element::f32, input_shapes); + const auto o = make_op(inputs, equation); + EXPECT_EQ(o->get_equation(), equation); + EXPECT_EQ(o->get_element_type(), element::f32); + EXPECT_EQ(o->get_output_size(), exp_einsum_outputs_count); + EXPECT_EQ(o->get_output_partial_shape(0), PartialShape({11, 1, 4})); + EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), Each(nullptr)); +} + 
+TEST_F(TypePropEinsumTest, missed_lhs_ellipsis) {
+    const std::string equation = "ab,b...->a...";
+
+    const auto input_shapes = Shapes{{11, 3}, {3, 11, 7, 4}};
+    const auto inputs = make_inputs(element::f32, input_shapes);
+    const auto o = make_op(inputs, equation);
+    EXPECT_EQ(o->get_equation(), equation);
+    EXPECT_EQ(o->get_element_type(), element::f32);
+    EXPECT_EQ(o->get_output_size(), exp_einsum_outputs_count);
+    EXPECT_EQ(o->get_output_partial_shape(0), PartialShape({11, 11, 7, 4}));
+    EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), Each(nullptr));
+}
+
+TEST_F(TypePropEinsumTest, missed_rhs_ellipsis_implicit) {
+    const std::string equation = "a...b,b";
+
+    const auto input_shapes = Shapes{{11, 1, 4, 3}, {3}};
+    const auto inputs = make_inputs(element::f32, input_shapes);
+    const auto o = make_op(inputs, equation);
+    EXPECT_EQ(o->get_equation(), equation);
+    EXPECT_EQ(o->get_element_type(), element::f32);
+    EXPECT_EQ(o->get_output_size(), exp_einsum_outputs_count);
+    EXPECT_EQ(o->get_output_partial_shape(0), PartialShape({1, 4, 11}));
+    EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), Each(nullptr));
+}
+
+TEST_F(TypePropEinsumTest, missed_lhs_ellipsis_implicit) {
+    const std::string equation = "ab,b...";
+
+    const auto input_shapes = Shapes{{11, 3}, {3, 11, 7, 4}};
+    const auto inputs = make_inputs(element::f32, input_shapes);
+    const auto o = make_op(inputs, equation);
+    EXPECT_EQ(o->get_equation(), equation);
+    EXPECT_EQ(o->get_element_type(), element::f32);
+    EXPECT_EQ(o->get_output_size(), exp_einsum_outputs_count);
+    EXPECT_EQ(o->get_output_partial_shape(0), PartialShape({11, 7, 4, 11}));
+    EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), Each(nullptr));
+}
+
+TEST_F(TypePropEinsumTest, all_dynamic_rank_ellipsis) {
+    const std::string equation = "a...b,b...->...a";
+    constexpr auto et = element::i32;
+
+    auto input_shapes = PartialShapes(2, PartialShape::dynamic());
+    const auto inputs = make_inputs(et, input_shapes);
+    const auto o = make_op(inputs, equation);
+
+    EXPECT_EQ(o->get_equation(), equation);
+    EXPECT_EQ(o->get_element_type(), et);
+    EXPECT_EQ(o->get_output_size(), exp_einsum_outputs_count);
+    EXPECT_EQ(o->get_output_partial_shape(0), PartialShape::dynamic());
+    EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), Each(nullptr));
+}
+
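+// The broadcasting_same_symbol tests below check that dimensions sharing the same
+// label are merged with NUMPY broadcasting rules (broadcast_merge_into) instead of
+// being required to match exactly.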
+TEST_F(TypePropEinsumTest, broadcasting_same_symbol_common) {
+    const std::string equation = "ab,ba->b";
+    constexpr auto et = element::i32;
+
+    auto input_shapes = Shapes{{7, 5}, {1, 7}};
+    const auto inputs = make_inputs(et, input_shapes);
+    const auto o = make_op(inputs, equation);
+
+    EXPECT_EQ(o->get_equation(), equation);
+    EXPECT_EQ(o->get_element_type(), et);
+    EXPECT_EQ(o->get_output_size(), exp_einsum_outputs_count);
+    EXPECT_EQ(o->get_output_partial_shape(0), PartialShape({5}));
+    EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), Each(nullptr));
+}
+
+TEST_F(TypePropEinsumTest, broadcasting_same_symbol_reduced) {
+    const std::string equation = "ab,ba->b";
+    constexpr auto et = element::i32;
+
+    auto input_shapes = Shapes{{1, 5}, {5, 7}};
+    const auto inputs = make_inputs(et, input_shapes);
+    const auto o = make_op(inputs, equation);
+
+    EXPECT_EQ(o->get_equation(), equation);
+    EXPECT_EQ(o->get_element_type(), et);
+    EXPECT_EQ(o->get_output_size(), exp_einsum_outputs_count);
+    EXPECT_EQ(o->get_output_partial_shape(0), PartialShape({5}));
+    EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), Each(nullptr));
+}
+
+TEST_F(TypePropEinsumTest, broadcasting_same_symbol) {
+    const std::string equation = "ab,ba->b";
+    constexpr auto et = element::i32;
+
+    auto input_shapes = Shapes{{7, 1}, {5, 1}};
+    const auto inputs = make_inputs(et, input_shapes);
+    const auto o = make_op(inputs, equation);
+
+    EXPECT_EQ(o->get_equation(), equation);
+    EXPECT_EQ(o->get_element_type(), et);
+    EXPECT_EQ(o->get_output_size(), exp_einsum_outputs_count);
+    EXPECT_EQ(o->get_output_partial_shape(0), PartialShape({5}));
+    EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), Each(nullptr));
+}
+
+TEST_F(TypePropEinsumTest, ellipsis_no_dimension) {
+    const std::string equation = "...ab,ba...->b...";
+    constexpr auto et = element::i32;
+
+    auto input_shapes = Shapes{{5, 1}, {5, 5}};
+    const auto inputs = make_inputs(et, input_shapes);
+    const auto o = make_op(inputs, equation);
+
+    EXPECT_EQ(o->get_equation(), equation);
+    EXPECT_EQ(o->get_element_type(), et);
+    EXPECT_EQ(o->get_output_size(), exp_einsum_outputs_count);
+    EXPECT_EQ(o->get_output_partial_shape(0), PartialShape({5}));
+    EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), Each(nullptr));
+}
+
+TEST_F(TypePropEinsumTest, ellipsis_dynamic_shape) {
+    const std::string equation = "...ab,ba...->b...";
+    constexpr auto et = element::i32;
+
+    auto input_shapes = PartialShapes{{-1, 57, 5, 5}, {5, 5}};
+    const auto inputs = make_inputs(et, input_shapes);
+    const auto o = make_op(inputs, equation);
+
+    EXPECT_EQ(o->get_equation(), equation);
+    EXPECT_EQ(o->get_element_type(), et);
+    EXPECT_EQ(o->get_output_size(), exp_einsum_outputs_count);
+    EXPECT_EQ(o->get_output_partial_shape(0), PartialShape({5, -1, 57}));
+    EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), Each(nullptr));
+}
+
+TEST_F(TypePropEinsumTest, input_rank_incompatible_with_equation) {
+    const std::string equation = "ab,bc->ac";
+
+    const auto input_shapes = Shapes{{2, 2, 10}, {3, 4}};
+    const auto inputs = make_inputs(element::f32, input_shapes);
+
+    OV_EXPECT_THROW(auto o = make_op(inputs, equation),
+                    AssertFailure,
+                    HasSubstr("Input rank must be greater or equal to a number of labels in the "
+                              "corresponding input subscript."));
+}
+
+TEST_F(TypePropEinsumTest, input_rank_incompatible_with_equation_single_input) {
+    const std::string equation = "ab->ba";
+
+    const auto input_shapes = Shapes{{3, 5, 7}};
     const auto inputs = make_inputs(element::f32, input_shapes);
 
     OV_EXPECT_THROW(auto o = make_op(inputs, equation),
                     AssertFailure,
-                    HasSubstr("Output subscript of Einsum equation must contain one "
-                              "ellipsis if ellipsis is met in any input subscript."));
+                    HasSubstr("Input rank must be greater or equal to a number of labels in the "
+                              "corresponding input subscript."));
 }

From 50c98c189d4e7204a2afab03f36c0d3791005b42 Mon Sep 17 00:00:00 2001
From: MATEUSZ MIKOLAJCZYK
Date: Thu, 19 Dec 2024 18:07:55 +0000
Subject: [PATCH 02/18] Einsum decomposition broadcasting + ellipsis support

Signed-off-by: MATEUSZ MIKOLAJCZYK
---
 .../op_conversions/einsum_decomposition.cpp   | 300 +++++++++++++++---
 1 file changed, 264 insertions(+), 36 deletions(-)

diff --git a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp
index 7955e37cfcda14..7d93928243b78d 100644
--- a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp
+++ b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp
@@ -9,6 +9,7 @@
 #include "itt.hpp"
 #include
"openvino/core/rt_info.hpp" #include "openvino/core/validation_util.hpp" +#include "openvino/op/broadcast.hpp" #include "openvino/op/concat.hpp" #include "openvino/op/constant.hpp" #include "openvino/op/einsum.hpp" @@ -26,7 +27,7 @@ namespace { /// \brief Check if the EinsumDecomposition transformation is applicable to a given Einsum. -/// The transformation is applicable if input subscript does not have repeated labels and ellipsis. +/// The transformation is applicable if input subscript does not have repeated labels. /// /// \param subscript A subscript to check its format /// @@ -35,7 +36,7 @@ namespace { bool is_subscript_applicable(const std::string& subscript) { auto labels = ov::op::v7::Einsum::extract_labels(subscript); auto unique_labels = std::unordered_set(labels.begin(), labels.end()); - return std::find(labels.begin(), labels.end(), "...") == labels.end() && unique_labels.size() == labels.size(); + return unique_labels.size() == labels.size(); } /// \brief Compute einsum_path for a given Einsum node meaning that the (pseudo-)optimal @@ -174,7 +175,88 @@ void update_operands(ov::OutputVector& input_nodes, input_subscripts.erase(input_subscripts.begin() + input_ind1); input_subscripts.push_back(new_subscript); } +using LabelDimMap = std::unordered_map>; + +LabelDimMap compute_label_dim_map(const ov::Rank& input_rank, const std::string& input_subscript) { + static const std::string ellipsis = "..."; + const auto labels = ov::op::v7::Einsum::extract_labels(input_subscript); + const auto static_input_rank = input_rank.is_static(); + OPENVINO_ASSERT(static_input_rank || (std::find(labels.begin(), labels.end(), ellipsis) == labels.end()), + "Input rank cannot be dynamic in case of ellipsis in input subscript"); + const size_t input_rank_length = static_input_rank ? 
input_rank.get_length() : labels.size(); + OPENVINO_ASSERT(input_rank_length >= labels.size()); + const size_t num_broadcasted_dims = input_rank_length - labels.size() + 1; + OPENVINO_ASSERT(num_broadcasted_dims > 0); + + LabelDimMap resulted_map; + size_t current_dim = 0; + for (const auto& label : labels) { + if (label == ellipsis) { + std::vector label_dims(num_broadcasted_dims); + std::iota(label_dims.begin(), label_dims.end(), current_dim); + resulted_map[label] = label_dims; + current_dim += num_broadcasted_dims; + } else if (resulted_map.find(label) != resulted_map.end()) { + resulted_map[label].push_back(current_dim); + ++current_dim; + } else { + std::vector label_dims; + label_dims.push_back(current_dim); + resulted_map[label] = label_dims; + ++current_dim; + } + } + + return resulted_map; +} + +void compute_ranges(const ov::Rank& input_rank, + const std::string& input_subscript, + const std::vector& common_labels, + const std::vector& sep_labels, + const std::vector& reduced_labels, + size_t& common_begin, + size_t& common_end, + size_t& sep_begin, + size_t& sep_end, + size_t& reduced_begin, + size_t& reduced_end, + bool is_separated_first) { + auto label_to_dim_map = compute_label_dim_map(input_rank, input_subscript); + static const std::string ellipsis = "..."; + + size_t common_rank = common_labels.size(); + if (std::find(common_labels.begin(), common_labels.end(), ellipsis) != common_labels.end()) { + OPENVINO_ASSERT(label_to_dim_map.find(ellipsis) != label_to_dim_map.end()); + common_rank += label_to_dim_map[ellipsis].size() - 1; + } + + size_t sep_rank = sep_labels.size(); + if (std::find(sep_labels.begin(), sep_labels.end(), ellipsis) != sep_labels.end()) { + OPENVINO_ASSERT(label_to_dim_map.find(ellipsis) != label_to_dim_map.end()); + sep_rank += label_to_dim_map[ellipsis].size() - 1; + } + size_t reduced_rank = reduced_labels.size(); + if (std::find(reduced_labels.begin(), reduced_labels.end(), ellipsis) != reduced_labels.end()) { + OPENVINO_ASSERT(label_to_dim_map.find(ellipsis) != label_to_dim_map.end()); + reduced_rank += label_to_dim_map[ellipsis].size() - 1; + } + + common_begin = 0; + common_end = common_begin + common_rank; + if (is_separated_first) { + sep_begin = common_end; + sep_end = sep_begin + sep_rank; + reduced_begin = sep_end; + reduced_end = reduced_begin + reduced_rank; + } else { + reduced_begin = common_end; + reduced_end = reduced_begin + reduced_rank; + sep_begin = reduced_end; + sep_end = sep_begin + sep_rank; + } +} /// \brief Return input node with computed sub-shape defined by a range [s_begin;s_end) /// /// \param data_shape Input node that contains some tensor shape @@ -243,6 +325,84 @@ ov::Output unsqueeze_input(const ov::Output& input_node, return unsqueeze->output(0); } +ov::OutputVector broadcast_merge_shapes(ov::OutputVector& shapes_lhs, + ov::OutputVector& shapes_rhs, + ov::NodeVector& subgraph_nodes) { + // TODO - Refactor func to remove loop and duplicated Broadcast. 
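+    // For every pair of 1D shape tensors: broadcast a constant of ones to the lhs shape,
+    // broadcast the result bidirectionally by the rhs shape, and take ShapeOf of it,
+    // which yields the NUMPY broadcast-merged shape of the pair.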
+    OPENVINO_ASSERT(shapes_lhs.size() == shapes_rhs.size());
+    ov::OutputVector broadcasted_shape_nodes{shapes_lhs.size()};
+
+    for (size_t shp_i = 0; shp_i < shapes_lhs.size(); shp_i++) {
+        auto const_1 = ov::op::v0::Constant::create(ov::element::Type_t::i64, ov::Shape{1}, {1});
+        auto tmp_const_of_lhs_shp =
+            std::make_shared<ov::op::v3::Broadcast>(const_1, shapes_lhs[shp_i], ov::op::BroadcastType::NUMPY);
+        auto tmp_const_of_broadcasted_shp =
+            std::make_shared<ov::op::v3::Broadcast>(tmp_const_of_lhs_shp,
+                                                    shapes_rhs[shp_i],
+                                                    ov::op::BroadcastType::BIDIRECTIONAL);
+        auto broadcasted_shape = std::make_shared<ov::op::v3::ShapeOf>(tmp_const_of_broadcasted_shp);
+        broadcasted_shape_nodes[shp_i] = broadcasted_shape;
+        subgraph_nodes.insert(subgraph_nodes.end(),
+                              {const_1, tmp_const_of_lhs_shp, tmp_const_of_broadcasted_shp, broadcasted_shape});
+    }
+    return broadcasted_shape_nodes;
+}
+
+/// \brief Broadcast input node to the new shape specified by broadcasted sub-shapes of the common,
+/// separate and reduced dimensions so that the broadcasted input has a format acceptable by the
+/// subsequent Reshape and MatMul operations
+///
+/// \param input_node Input node to broadcast
+/// \param common_sub_shape A sub-shape corresponding to the broadcasted common dimensions
+/// \param separate_sub_shape A sub-shape corresponding to the broadcasted separate dimensions
+/// \param reduced_sub_shape A sub-shape corresponding to the broadcasted reduced dimensions
+/// \param is_separate_first true - the separate dimensions placed before reduced
+/// dimensions, otherwise, it is after them
+/// \param subgraph_nodes A vector of operation nodes that is included into
+/// a sub-graph decomposing Einsum that is needed for copy_runtime_info
+///
+/// \return Broadcasted input node
+///
+ov::Output<ov::Node> broadcast_input(const ov::Output<ov::Node>& input_node,
+                                     const ov::OutputVector& common_sub_shape,
+                                     const ov::OutputVector& separate_sub_shape,
+                                     const ov::OutputVector& reduced_sub_shape,
+                                     bool is_separate_first,
+                                     ov::NodeVector& subgraph_nodes) {
+    ov::OutputVector new_shape_parts;
+    new_shape_parts.insert(new_shape_parts.end(), common_sub_shape.begin(), common_sub_shape.end());
+    // form a new shape for input so that collapsed dimensions corresponding
+    // to the common, separate and reduced dimensions are placed in the correct order
+    if (is_separate_first) {
+        new_shape_parts.insert(new_shape_parts.end(), separate_sub_shape.begin(), separate_sub_shape.end());
+        new_shape_parts.insert(new_shape_parts.end(), reduced_sub_shape.begin(), reduced_sub_shape.end());
+    } else {
+        new_shape_parts.insert(new_shape_parts.end(), reduced_sub_shape.begin(), reduced_sub_shape.end());
+        new_shape_parts.insert(new_shape_parts.end(), separate_sub_shape.begin(), separate_sub_shape.end());
+    }
+
+    // in case of scalar, broadcasting is not needed
+    if (new_shape_parts.size() == 0) {
+        return input_node;
+    }
+    auto new_shape_op = std::make_shared<ov::op::v0::Concat>(new_shape_parts, 0);
+    // if new shape is possible to compute on the shape infer stage, insert Constant node immediately
+    // in order to prevent repeated computing during constant-folding pass
+    std::shared_ptr<ov::Node> reshaped_input_op;
+    if (auto new_shape_const = ov::util::get_constant_from_source(new_shape_op)) {
+        reshaped_input_op =
+            std::make_shared<ov::op::v3::Broadcast>(input_node, new_shape_const, ov::op::BroadcastType::BIDIRECTIONAL);
+        subgraph_nodes.insert(subgraph_nodes.end(), {new_shape_const});
+    } else {
+        reshaped_input_op = std::make_shared<ov::op::v3::Broadcast>(input_node,
+                                                                    new_shape_op->output(0),
+                                                                    ov::op::BroadcastType::BIDIRECTIONAL);
+        subgraph_nodes.insert(subgraph_nodes.end(), {new_shape_op});
+    }
+
+
subgraph_nodes.insert(subgraph_nodes.end(), {reshaped_input_op}); + return reshaped_input_op->output(0); +} + /// \brief Reshape input node to the new shape specified by sub-shapes of the common, /// separate and reduced dimensions so that the reshaped input has a format acceptable by MatMul /// @@ -334,7 +494,7 @@ void transpose_input(ov::OutputVector& input_nodes, size_t input_ind, ov::NodeVector& subgraph_nodes) { // perform sanity check for arguments - auto num_inputs = input_nodes.size(); + const auto num_inputs = input_nodes.size(); OPENVINO_ASSERT(num_inputs == input_subscripts.size(), "Each input must have own subscript."); OPENVINO_ASSERT(input_ind < num_inputs, "Input index is out of range."); @@ -350,21 +510,22 @@ void transpose_input(ov::OutputVector& input_nodes, // find permutation that establishes bijection between the input subscript // and the required one - auto labels = ov::op::v7::Einsum::extract_labels(input_subscript); - auto required_labels = ov::op::v7::Einsum::extract_labels(required_subscript); + const auto& input_node = input_nodes[input_ind]; + const auto labels = ov::op::v7::Einsum::extract_labels(input_subscript); + const auto required_labels = ov::op::v7::Einsum::extract_labels(required_subscript); OPENVINO_ASSERT(labels.size() == required_labels.size()); + const auto label_dim_map = compute_label_dim_map(input_node.get_partial_shape().rank(), input_subscript); for (const auto& required_label : required_labels) { - auto it = std::find(labels.begin(), labels.end(), required_label); - OPENVINO_ASSERT(it != labels.end()); - int64_t found_index = static_cast(it - labels.begin()); - permutation.push_back(found_index); + const auto label_dims_it = label_dim_map.find(required_label); + OPENVINO_ASSERT(label_dims_it != label_dim_map.end()); + const auto& label_dims = label_dims_it->second; + permutation.insert(permutation.end(), label_dims.begin(), label_dims.end()); } // create a sub-graph for transposing into the required layout - const auto& input_node = input_nodes[input_ind]; - auto permutation_const = + const auto permutation_const = ov::op::v0::Constant::create(ov::element::Type_t::i64, ov::Shape{permutation.size()}, permutation); - auto transpose = std::make_shared(input_node, permutation_const); + const auto transpose = std::make_shared(input_node, permutation_const); // update a vector of inputs and input subscripts input_nodes[input_ind] = transpose->output(0); @@ -468,6 +629,11 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr, const auto& input_node1 = input_nodes[input_ind1]; const auto& input_node2 = input_nodes[input_ind2]; + // extract diagonals in case repeated labels in the corresponding input subscripts + // TODO + // extract_diagonal(einsum_decompose_ptr, input_nodes, input_subscripts, input_ind1, subgraph_nodes); + // extract_diagonal(einsum_decompose_ptr, input_nodes, input_subscripts, input_ind2, subgraph_nodes); + // reduce dimensions for input operands if possible reduce_input(einsum_decompose_ptr, input_nodes, input_subscripts, output_subscript, input_ind1, subgraph_nodes); reduce_input(einsum_decompose_ptr, input_nodes, input_subscripts, output_subscript, input_ind2, subgraph_nodes); @@ -491,6 +657,7 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr, std::vector common_labels_inds1, common_labels_inds2; std::vector separate_labels_inds1, separate_labels_inds2; std::vector reduced_labels_inds1, reduced_labels_inds2; + std::vector common_labels, sep_labels1, sep_labels2, 
reduced_labels;
     for (size_t label_ind = 0; label_ind < labels1.size(); ++label_ind) {
         const auto& label = labels1[label_ind];
         auto iter = std::find(labels2.begin(), labels2.end(), label);
@@ -501,13 +668,16 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr,
             if (is_dim_reduced) {
                 reduced_labels_inds1.push_back(static_cast<int64_t>(label_ind));
                 reduced_labels_inds2.push_back(static_cast<int64_t>(iter - labels2.begin()));
+                reduced_labels.push_back(label);
             } else {
                 common_labels_inds1.push_back(static_cast<int64_t>(label_ind));
                 common_labels_inds2.push_back(static_cast<int64_t>(iter - labels2.begin()));
+                common_labels.push_back(label);
             }
         } else {
             separate_part1 += label;
             separate_labels_inds1.push_back(static_cast<int64_t>(label_ind));
+            sep_labels1.push_back(label);
         }
     }
     for (size_t label_ind = 0; label_ind < labels2.size(); ++label_ind) {
@@ -516,6 +686,7 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr,
         if (iter == labels1.end()) {
             separate_part2 += label;
             separate_labels_inds2.push_back(static_cast<int64_t>(label_ind));
+            sep_labels2.push_back(label);
         }
     }
@@ -601,26 +772,71 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr,
     auto matmul_operand1 = input_node1;
     auto matmul_operand2 = input_node2;
-    int64_t common_dims_begin = 0;
-    int64_t common_dims_end = common_labels_inds1.size();
+
+    size_t common_dims_begin, common_dims_end, reduced_dims_begin, reduced_dims_end, separate1_dims_begin,
+        separate1_dims_end;
+    compute_ranges(input_node1.get_partial_shape().rank(),
+                   input_subscript1,
+                   common_labels,
+                   sep_labels1,
+                   reduced_labels,
+                   common_dims_begin,
+                   common_dims_end,
+                   separate1_dims_begin,
+                   separate1_dims_end,
+                   reduced_dims_begin,
+                   reduced_dims_end,
+                   is_separate_first1);
+
+    size_t common_dims_begin2, common_dims_end2, reduced_dims_begin2, reduced_dims_end2, separate2_dims_begin,
+        separate2_dims_end;
+    compute_ranges(input_node2.get_partial_shape().rank(),
+                   input_subscript2,
+                   common_labels,
+                   sep_labels2,
+                   reduced_labels,
+                   common_dims_begin2,
+                   common_dims_end2,
+                   separate2_dims_begin,
+                   separate2_dims_end,
+                   reduced_dims_begin2,
+                   reduced_dims_end2,
+                   is_separate_first2);
+
+    no_reshape_for_matmul1 = false;
+    no_reshape_for_matmul2 = false;
+    // // no_reshape_after_matmul = false;
     ov::OutputVector common_sub_shape, separate1_sub_shape, separate2_sub_shape;
+
     if (no_reshape_for_matmul1 == false || no_reshape_for_matmul2 == false) {
         auto data_shape1 = std::make_shared<ov::op::v3::ShapeOf>(input_node1);
+        auto data_shape2 = std::make_shared<ov::op::v3::ShapeOf>(input_node2);
         common_sub_shape = compute_sub_shape(data_shape1, common_dims_begin, common_dims_end, subgraph_nodes);
-        int64_t reduced_dims_begin = (is_separate_first1 ?
common_labels_inds1.size() + separate_labels_inds1.size() - : common_labels_inds1.size()); - int64_t reduced_dims_end = reduced_dims_begin + reduced_labels_inds1.size(); + auto common_sub_shape2 = compute_sub_shape(data_shape2, common_dims_begin2, common_dims_end2, subgraph_nodes); + OPENVINO_ASSERT(common_sub_shape.size() == common_sub_shape2.size()); + common_sub_shape = broadcast_merge_shapes(common_sub_shape, common_sub_shape2, subgraph_nodes); auto reduced_sub_shape_prod = compute_sub_shape(data_shape1, reduced_dims_begin, reduced_dims_end, subgraph_nodes, true); - + auto reduced_sub_shape_prod2 = + compute_sub_shape(data_shape2, reduced_dims_begin2, reduced_dims_end2, subgraph_nodes, true); + auto reduced_sub_shape = + compute_sub_shape(data_shape1, reduced_dims_begin, reduced_dims_end, subgraph_nodes, false); + auto reduced_sub_shape2 = + compute_sub_shape(data_shape2, reduced_dims_begin2, reduced_dims_end2, subgraph_nodes, false); + + reduced_sub_shape_prod = + broadcast_merge_shapes(reduced_sub_shape_prod, reduced_sub_shape_prod2, subgraph_nodes); + reduced_sub_shape = broadcast_merge_shapes(reduced_sub_shape, reduced_sub_shape2, subgraph_nodes); if (no_reshape_for_matmul1 == false || no_reshape_after_matmul == false) { - int64_t separate1_dims_begin = - (is_separate_first1 ? common_labels_inds1.size() - : common_labels_inds1.size() + reduced_labels_inds1.size()); - int64_t separate1_dims_end = separate1_dims_begin + separate_labels_inds1.size(); separate1_sub_shape = compute_sub_shape(data_shape1, separate1_dims_begin, separate1_dims_end, subgraph_nodes); - matmul_operand1 = reshape_input_for_matmul(input_node1, + auto broadcasted1 = broadcast_input(input_node1, + common_sub_shape, + separate1_sub_shape, + reduced_sub_shape, + is_separate_first1, + subgraph_nodes); + matmul_operand1 = reshape_input_for_matmul(broadcasted1, common_sub_shape, separate1_sub_shape, reduced_sub_shape_prod, @@ -629,14 +845,15 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr, } if (no_reshape_for_matmul2 == false || no_reshape_after_matmul == false) { - auto data_shape2 = std::make_shared(input_node2); - int64_t separate2_dims_begin = - (is_separate_first2 ? common_labels_inds2.size() - : common_labels_inds2.size() + reduced_labels_inds2.size()); - int64_t separate2_dims_end = separate2_dims_begin + separate_labels_inds2.size(); separate2_sub_shape = compute_sub_shape(data_shape2, separate2_dims_begin, separate2_dims_end, subgraph_nodes); - matmul_operand2 = reshape_input_for_matmul(input_node2, + auto broadcasted2 = broadcast_input(input_node2, + common_sub_shape, + separate2_sub_shape, + reduced_sub_shape, + is_separate_first2, + subgraph_nodes); + matmul_operand2 = reshape_input_for_matmul(broadcasted2, common_sub_shape, separate2_sub_shape, reduced_sub_shape_prod, @@ -654,8 +871,11 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr, // step 4. 
reshape back by unrolling dimensions corresponding to separate labels if needed
     // now dimensions corresponding to reduced labels are reduced by the MatMul operation
-    std::string resultant_subscript =
-        input_subscript1.substr(common_dims_begin, common_dims_end) + separate_part1 + separate_part2;
+    std::string common_part = "";
+    for (const auto& common_label : common_labels) {
+        common_part += common_label;
+    }
+    const std::string resultant_subscript = common_part + separate_part1 + separate_part2;
     if (no_reshape_after_matmul) {
         // this is a case when Reshape is not needed after MatMul operation
         // since there are no collapsed (or auxiliary added) separated dimensions
@@ -667,12 +887,12 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr,
         new_shape.insert(new_shape.end(), separate2_sub_shape.begin(), separate2_sub_shape.end());
         auto result_shape_op = std::make_shared<ov::op::v0::Concat>(new_shape, 0);
-        // if new shape is possible to compute on the shape infer stage, insert Constant node immediatelly
+        // if new shape is possible to compute on the shape infer stage, insert Constant node immediately
         // in order to prevent repeated computing during constant-folding pass
         std::shared_ptr<ov::Node> result_op;
         if (auto new_shape_const = ov::util::get_constant_from_source(result_shape_op)) {
             result_op = std::make_shared<ov::op::v1::Reshape>(matmul->output(0), new_shape_const, false);
-            subgraph_nodes.insert(subgraph_nodes.end(), {new_shape_const});
+            subgraph_nodes.insert(subgraph_nodes.end(), {result_shape_op, new_shape_const});
         } else {
             result_op = std::make_shared<ov::op::v1::Reshape>(matmul->output(0), result_shape_op->output(0), false);
             subgraph_nodes.insert(subgraph_nodes.end(), {result_shape_op});
@@ -723,6 +943,12 @@ ov::pass::EinsumDecomposition::EinsumDecomposition() {
     // and a vector of sub-graph nodes for copy_runtime_info
     ov::OutputVector input_nodes = einsum_node->input_values();
     ov::NodeVector subgraph_nodes;
+    // check that the transformation is applicable
+    if (std::any_of(input_nodes.cbegin(), input_nodes.cend(), [](ov::Output<ov::Node> node) {
+            return node.get_partial_shape().rank().is_dynamic();
+        })) {
+        return false;
+    }
 
     // compute einsum path that is used to contract a pair of operands
     // in more optimal order
@@ -739,13 +965,15 @@ ov::pass::EinsumDecomposition::EinsumDecomposition() {
                      subgraph_nodes);
     }
 
-    // reduce dimensions for the remained input node
     OPENVINO_ASSERT(input_nodes.size() == 1);
-    reduce_input(this, input_nodes, input_subscripts, output_subscript, 0, subgraph_nodes);
+    // extract diagonal for the single operand
+    // TODO
+    // extract_diagonal(this, input_nodes, input_subscripts, 0, subgraph_nodes);
+    // reduce dimensions for the remained input node
+    reduce_input(this, input_nodes, input_subscripts, output_subscript, 0, subgraph_nodes);
     // transpose dimensions to layout required by the output subscript
     transpose_input(input_nodes, input_subscripts, output_subscript, 0, subgraph_nodes);
-
     // replace the original Einsum node with the last node from decomposing sub-graph
     // preserve the original node name
     auto last_node = input_nodes[0].get_node_shared_ptr();

From d3eac209cd424d07f17fe894ef8c986f760bf0bb Mon Sep 17 00:00:00 2001
From: MATEUSZ MIKOLAJCZYK
Date: Fri, 20 Dec 2024 11:21:32 +0000
Subject: [PATCH 03/18] Move broadcasting out of reshape conditional

Signed-off-by: MATEUSZ MIKOLAJCZYK
---
 .../op_conversions/einsum_decomposition.cpp   | 108 +++++++++---------
 1 file changed, 51 insertions(+), 57 deletions(-)

diff --git a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp
b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp index 7d93928243b78d..7ae13f270a03c5 100644 --- a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp +++ b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp @@ -772,6 +772,8 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr, auto matmul_operand1 = input_node1; auto matmul_operand2 = input_node2; + auto broadcasted_operand1 = input_node1; + auto broadcasted_operand2 = input_node2; size_t common_dims_begin, common_dims_end, reduced_dims_begin, reduced_dims_end, separate1_dims_begin, separate1_dims_end; @@ -803,65 +805,57 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr, reduced_dims_end2, is_separate_first2); - no_reshape_for_matmul1 = false; - no_reshape_for_matmul2 = false; - // // no_reshape_after_matmul = false; ov::OutputVector common_sub_shape, separate1_sub_shape, separate2_sub_shape; - if (no_reshape_for_matmul1 == false || no_reshape_for_matmul2 == false) { - auto data_shape1 = std::make_shared(input_node1); - auto data_shape2 = std::make_shared(input_node2); - common_sub_shape = compute_sub_shape(data_shape1, common_dims_begin, common_dims_end, subgraph_nodes); - auto common_sub_shape2 = compute_sub_shape(data_shape2, common_dims_begin2, common_dims_end2, subgraph_nodes); - OPENVINO_ASSERT(common_sub_shape.size() == common_sub_shape2.size()); - common_sub_shape = broadcast_merge_shapes(common_sub_shape, common_sub_shape2, subgraph_nodes); - auto reduced_sub_shape_prod = - compute_sub_shape(data_shape1, reduced_dims_begin, reduced_dims_end, subgraph_nodes, true); - auto reduced_sub_shape_prod2 = - compute_sub_shape(data_shape2, reduced_dims_begin2, reduced_dims_end2, subgraph_nodes, true); - auto reduced_sub_shape = - compute_sub_shape(data_shape1, reduced_dims_begin, reduced_dims_end, subgraph_nodes, false); - auto reduced_sub_shape2 = - compute_sub_shape(data_shape2, reduced_dims_begin2, reduced_dims_end2, subgraph_nodes, false); - - reduced_sub_shape_prod = - broadcast_merge_shapes(reduced_sub_shape_prod, reduced_sub_shape_prod2, subgraph_nodes); - reduced_sub_shape = broadcast_merge_shapes(reduced_sub_shape, reduced_sub_shape2, subgraph_nodes); - if (no_reshape_for_matmul1 == false || no_reshape_after_matmul == false) { - separate1_sub_shape = - compute_sub_shape(data_shape1, separate1_dims_begin, separate1_dims_end, subgraph_nodes); - auto broadcasted1 = broadcast_input(input_node1, - common_sub_shape, - separate1_sub_shape, - reduced_sub_shape, - is_separate_first1, - subgraph_nodes); - matmul_operand1 = reshape_input_for_matmul(broadcasted1, - common_sub_shape, - separate1_sub_shape, - reduced_sub_shape_prod, - is_separate_first1, - subgraph_nodes); - } - - if (no_reshape_for_matmul2 == false || no_reshape_after_matmul == false) { - separate2_sub_shape = - compute_sub_shape(data_shape2, separate2_dims_begin, separate2_dims_end, subgraph_nodes); - auto broadcasted2 = broadcast_input(input_node2, - common_sub_shape, - separate2_sub_shape, - reduced_sub_shape, - is_separate_first2, - subgraph_nodes); - matmul_operand2 = reshape_input_for_matmul(broadcasted2, - common_sub_shape, - separate2_sub_shape, - reduced_sub_shape_prod, - is_separate_first2, - subgraph_nodes); - subgraph_nodes.insert(subgraph_nodes.end(), {data_shape2}); - } - subgraph_nodes.insert(subgraph_nodes.end(), {data_shape1}); + auto data_shape1 = std::make_shared(input_node1); + auto 
data_shape2 = std::make_shared(input_node2); + subgraph_nodes.insert(subgraph_nodes.end(), {data_shape1}); + subgraph_nodes.insert(subgraph_nodes.end(), {data_shape2}); + common_sub_shape = compute_sub_shape(data_shape1, common_dims_begin, common_dims_end, subgraph_nodes); + auto common_sub_shape2 = compute_sub_shape(data_shape2, common_dims_begin2, common_dims_end2, subgraph_nodes); + OPENVINO_ASSERT(common_sub_shape.size() == common_sub_shape2.size()); + common_sub_shape = broadcast_merge_shapes(common_sub_shape, common_sub_shape2, subgraph_nodes); + auto reduced_sub_shape_prod = + compute_sub_shape(data_shape1, reduced_dims_begin, reduced_dims_end, subgraph_nodes, true); + auto reduced_sub_shape_prod2 = + compute_sub_shape(data_shape2, reduced_dims_begin2, reduced_dims_end2, subgraph_nodes, true); + auto reduced_sub_shape = + compute_sub_shape(data_shape1, reduced_dims_begin, reduced_dims_end, subgraph_nodes, false); + auto reduced_sub_shape2 = + compute_sub_shape(data_shape2, reduced_dims_begin2, reduced_dims_end2, subgraph_nodes, false); + + reduced_sub_shape_prod = broadcast_merge_shapes(reduced_sub_shape_prod, reduced_sub_shape_prod2, subgraph_nodes); + reduced_sub_shape = broadcast_merge_shapes(reduced_sub_shape, reduced_sub_shape2, subgraph_nodes); + separate1_sub_shape = compute_sub_shape(data_shape1, separate1_dims_begin, separate1_dims_end, subgraph_nodes); + broadcasted_operand1 = broadcast_input(input_node1, + common_sub_shape, + separate1_sub_shape, + reduced_sub_shape, + is_separate_first1, + subgraph_nodes); + separate2_sub_shape = compute_sub_shape(data_shape2, separate2_dims_begin, separate2_dims_end, subgraph_nodes); + broadcasted_operand2 = broadcast_input(input_node2, + common_sub_shape, + separate2_sub_shape, + reduced_sub_shape, + is_separate_first2, + subgraph_nodes); + if (no_reshape_for_matmul1 == false || no_reshape_after_matmul == false) { + matmul_operand1 = reshape_input_for_matmul(broadcasted_operand1, + common_sub_shape, + separate1_sub_shape, + reduced_sub_shape_prod, + is_separate_first1, + subgraph_nodes); + } + + if (no_reshape_for_matmul2 == false || no_reshape_after_matmul == false) { + matmul_operand2 = reshape_input_for_matmul(broadcasted_operand2, + common_sub_shape, + separate2_sub_shape, + reduced_sub_shape_prod, + is_separate_first2, + subgraph_nodes); } // step 3. 
apply MatMul operation for formatted inputs From 0ec79742ec4f12fba96ddb6470cee59a8ded7487 Mon Sep 17 00:00:00 2001 From: MATEUSZ MIKOLAJCZYK Date: Tue, 7 Jan 2025 11:44:24 +0000 Subject: [PATCH 04/18] Initial support for repeated labels Signed-off-by: MATEUSZ MIKOLAJCZYK --- .../op_conversions/einsum_decomposition.cpp | 224 ++++++++++++++---- 1 file changed, 183 insertions(+), 41 deletions(-) diff --git a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp index 7ae13f270a03c5..067e73e3a0326c 100644 --- a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp +++ b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp @@ -12,33 +12,25 @@ #include "openvino/op/broadcast.hpp" #include "openvino/op/concat.hpp" #include "openvino/op/constant.hpp" +#include "openvino/op/divide.hpp" #include "openvino/op/einsum.hpp" +#include "openvino/op/gather.hpp" #include "openvino/op/matmul.hpp" #include "openvino/op/multiply.hpp" +#include "openvino/op/range.hpp" #include "openvino/op/reduce_prod.hpp" #include "openvino/op/reduce_sum.hpp" #include "openvino/op/reshape.hpp" +#include "openvino/op/scatter_elements_update.hpp" #include "openvino/op/shape_of.hpp" #include "openvino/op/strided_slice.hpp" +#include "openvino/op/subtract.hpp" #include "openvino/op/transpose.hpp" #include "openvino/op/unsqueeze.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp" #include "transformations/utils/utils.hpp" namespace { -/// \brief Check if the EinsumDecomposition transformation is applicable to a given Einsum. -/// The transformation is applicable if input subscript does not have repeated labels. 
-///
-/// \param subscript A subscript to check its format
-///
-/// \return true - applicable, false - not applicable
-///
-bool is_subscript_applicable(const std::string& subscript) {
-    auto labels = ov::op::v7::Einsum::extract_labels(subscript);
-    auto unique_labels = std::unordered_set(labels.begin(), labels.end());
-    return unique_labels.size() == labels.size();
-}
-
 /// \brief Compute einsum_path for a given Einsum node meaning that the (pseudo-)optimal
 /// order of operands contraction in terms of performance and memory consumption
@@ -595,6 +587,167 @@ void reduce_input(ov::pass::EinsumDecomposition* einsum_decompose_ptr,
     subgraph_nodes.insert(subgraph_nodes.end(), {axes_const, reduce_sum});
 }
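+/// \brief Build an identity-like mask for a repeated label: a tensor that is 1 only at
+/// coordinates where all dimensions listed in repeated_label_dims share the same index
+/// and 0 elsewhere, so that multiplying the input by it (and reducing the extra axes)
+/// extracts the diagonal for that label.
+///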
+ov::Output<ov::Node> build_identity(const ov::Output<ov::Node>& input_node,
+                                    const std::vector<size_t>& repeated_label_dims,
+                                    ov::NodeVector& subgraph_nodes) {
+    OPENVINO_ASSERT(repeated_label_dims.size() > 1);
+    const auto input_shape = std::make_shared<ov::op::v3::ShapeOf>(input_node);
+    const auto repeated_label_indices =
+        ov::op::v0::Constant::create(ov::element::i64, {repeated_label_dims.size()}, repeated_label_dims);
+    const auto const_0 = ov::op::v0::Constant::create(ov::element::i64, {}, {0});
+    const auto const_1 = ov::op::v0::Constant::create(ov::element::i64, {}, {1});
+    const auto repeated_dimensions =
+        std::make_shared<ov::op::v8::Gather>(input_shape, repeated_label_indices, const_0);
+    const auto reduced_dimension = std::make_shared<ov::op::v8::Gather>(repeated_dimensions, const_0, const_0);
+    const auto reduced_dimension_min_1 = std::make_shared<ov::op::v1::Subtract>(reduced_dimension, const_1);
+
+    const auto reduced_size = std::make_shared<ov::op::v1::ReduceProd>(repeated_dimensions, const_0, true);
+    const auto reduced_size_min_1 = std::make_shared<ov::op::v1::Subtract>(reduced_size, const_1);
+    const auto step_size = std::make_shared<ov::op::v1::Divide>(reduced_size_min_1, reduced_dimension_min_1);
+    const auto range = std::make_shared<ov::op::v4::Range>(const_0, reduced_dimension, const_1, ov::element::i64);
+    const auto steps = std::make_shared<ov::op::v1::Multiply>(range, step_size);
+    const auto zeros = std::make_shared<ov::op::v3::Broadcast>(const_0, reduced_size);
+    const auto reduced_dimension_1d = std::make_shared<ov::op::v0::Unsqueeze>(reduced_dimension, const_0);
+    const auto ones = std::make_shared<ov::op::v3::Broadcast>(const_1, reduced_dimension_1d);
+    const auto eye_flattened = std::make_shared<ov::op::v3::ScatterElementsUpdate>(zeros, steps, ones, const_0);
+
+    const auto identity_rank = std::make_shared<ov::op::v3::ShapeOf>(input_shape);
+    const auto ones_of_input_shape_rank = std::make_shared<ov::op::v3::Broadcast>(const_1, identity_rank);
+    const auto identity_shape = std::make_shared<ov::op::v3::ScatterElementsUpdate>(ones_of_input_shape_rank,
+                                                                                    repeated_label_indices,
+                                                                                    repeated_dimensions,
+                                                                                    const_0);
+    const auto identity = std::make_shared<ov::op::v1::Reshape>(eye_flattened, identity_shape, false);
+    const auto identity_cvt = std::make_shared<ov::op::v0::Convert>(identity, input_node.get_element_type());
+    subgraph_nodes.insert(subgraph_nodes.end(),
+                          {input_shape,
+                           repeated_label_indices,
+                           const_0,
+                           const_1,
+                           repeated_dimensions,
+                           reduced_dimension,
+                           reduced_dimension_min_1,
+                           reduced_size,
+                           reduced_size_min_1,
+                           step_size,
+                           range,
+                           steps,
+                           zeros,
+                           reduced_dimension_1d,
+                           ones,
+                           eye_flattened,
+                           identity_rank,
+                           ones_of_input_shape_rank,
+                           identity_shape,
+                           identity,
+                           identity_cvt});
+    return subgraph_nodes.back();
+}
+
+ov::Output<ov::Node> build_multi_identity(ov::pass::EinsumDecomposition* einsum_decompose_ptr,
+                                          const ov::Output<ov::Node>& input_node,
+                                          const std::vector<std::string>& repeated_labels,
+                                          const LabelDimMap& label_dim_map,
+                                          ov::NodeVector& subgraph_nodes) {
+    OPENVINO_ASSERT(repeated_labels.size() > 0);
+
+    const auto get_identity = [&](size_t idx) {
+        const auto repeated_label_dims = label_dim_map.find(repeated_labels[idx]);
+        OPENVINO_ASSERT(repeated_label_dims != label_dim_map.end());
+        return build_identity(input_node, repeated_label_dims->second, subgraph_nodes);
+    };
+
+    // initially set multi-identity with identity for the first repeated label
+    auto multi_identity = get_identity(0);
+    for (size_t label_ind = 1; label_ind < repeated_labels.size(); ++label_ind) {
+        const auto identity = get_identity(label_ind);
+        const auto mul =
+            std::make_shared<ov::op::v1::Multiply>(multi_identity, identity, ov::op::AutoBroadcastType::NUMPY);
+        multi_identity = mul->output(0);
+        subgraph_nodes.insert(subgraph_nodes.end(), {mul});
+    }
+
+    return subgraph_nodes.back();
+}
+
+/// \brief Helper function to fill in the data needed for diagonal extraction: the resultant
+/// subscript, repeated labels and axes to reduce.
+///
+void prepare_diagonal_extraction_data(const std::string& input_subscript,
+                                      const LabelDimMap& label_dim_map,
+                                      std::string& resultant_subscript,
+                                      std::vector<std::string>& repeated_labels,
+                                      ov::AxisSet& reduced_axes) {
+    static const std::string ellipsis = "...";
+    const auto labels = ov::op::v7::Einsum::extract_labels(input_subscript);
+
+    for (const auto& label : labels) {
+        if (resultant_subscript.find(label) != std::string::npos) {
+            continue;
+        }
+
+        const auto dims_it = label_dim_map.find(label);
+        OPENVINO_ASSERT(dims_it != label_dim_map.end());
+
+        auto dims = dims_it->second;
+        const auto dims_size = dims.size();
+        OPENVINO_ASSERT(dims_size > 0);
+
+        if (label != ellipsis && dims_size > 1) {
+            // repeated label is found
+            for (size_t dim_ind = 1; dim_ind < dims_size; ++dim_ind) {
+                reduced_axes.insert(dims[dim_ind]);
+            }
+            // save only the first dimension corresponding to the repeated label
+            dims = {dims[0]};
+            repeated_labels.push_back(label);
+        }
+        resultant_subscript += label;
+    }
+}
+
+void extract_diagonal(ov::pass::EinsumDecomposition* einsum_decompose_ptr,
+                      ov::OutputVector& inputs,
+                      std::vector<std::string>& input_subscripts,
+                      size_t input_ind,
+                      ov::NodeVector& subgraph_nodes) {
+    // perform sanity check for arguments
+    const auto num_inputs = inputs.size();
+    OPENVINO_ASSERT(num_inputs == input_subscripts.size(), "Each input must have own subscript.");
+    OPENVINO_ASSERT(input_ind < num_inputs, "Input index is out of range.");
+
+    const auto& input_node = inputs[input_ind];
+    const auto& input_subscript = input_subscripts[input_ind];
+
+    const auto label_dim_map = compute_label_dim_map(input_node.get_partial_shape().rank(), input_subscript);
+    std::string resultant_subscript;
+    std::vector<std::string> repeated_labels;
+    ov::AxisSet reduced_axes;
+    prepare_diagonal_extraction_data(input_subscript,
+                                     label_dim_map,
+                                     resultant_subscript,
+                                     repeated_labels,
+                                     reduced_axes);
+
+    if (repeated_labels.size() == 0) {
+        return;
+    }
+    const auto multi_identity =
+        build_multi_identity(einsum_decompose_ptr, input_node, repeated_labels, label_dim_map, subgraph_nodes);
+
+    // multiply the input with the multi-identity mask using broadcasting
+    const auto mul =
+        std::make_shared<ov::op::v1::Multiply>(input_node, multi_identity, ov::op::AutoBroadcastType::NUMPY);
+    subgraph_nodes.insert(subgraph_nodes.end(), {mul});
+
+    const std::vector<int64_t> reduced_axes_vec{reduced_axes.cbegin(), reduced_axes.cend()};
+    const auto axes_const =
+        ov::op::v0::Constant::create(ov::element::Type_t::i64, ov::Shape{reduced_axes.size()}, reduced_axes_vec);
+    const auto reduce_sum = std::make_shared<ov::op::v1::ReduceSum>(mul->output(0), axes_const, false);
+    subgraph_nodes.insert(subgraph_nodes.end(), {axes_const, reduce_sum});
+
+    inputs[input_ind] = reduce_sum->output(0);
+    input_subscripts[input_ind] = resultant_subscript;
+}
+
 /// \brief Contract two inputs of Einsum operation according to equation.
/// The result of the contraction is appended into input_nodes along with its subscript. /// The input nodes for these two operands are removed from input_nodes along with their input @@ -630,9 +783,8 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr, const auto& input_node2 = input_nodes[input_ind2]; // extract diagonals in case repeated labels in the corresponding input subscripts - // TODO - // extract_diagonal(einsum_decompose_ptr, input_nodes, input_subscripts, input_ind1, subgraph_nodes); - // extract_diagonal(einsum_decompose_ptr, input_nodes, input_subscripts, input_ind2, subgraph_nodes); + extract_diagonal(einsum_decompose_ptr, input_nodes, input_subscripts, input_ind1, subgraph_nodes); + extract_diagonal(einsum_decompose_ptr, input_nodes, input_subscripts, input_ind2, subgraph_nodes); // reduce dimensions for input operands if possible reduce_input(einsum_decompose_ptr, input_nodes, input_subscripts, output_subscript, input_ind1, subgraph_nodes); @@ -772,8 +924,6 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr, auto matmul_operand1 = input_node1; auto matmul_operand2 = input_node2; - auto broadcasted_operand1 = input_node1; - auto broadcasted_operand2 = input_node2; size_t common_dims_begin, common_dims_end, reduced_dims_begin, reduced_dims_end, separate1_dims_begin, separate1_dims_end; @@ -827,21 +977,21 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr, reduced_sub_shape_prod = broadcast_merge_shapes(reduced_sub_shape_prod, reduced_sub_shape_prod2, subgraph_nodes); reduced_sub_shape = broadcast_merge_shapes(reduced_sub_shape, reduced_sub_shape2, subgraph_nodes); separate1_sub_shape = compute_sub_shape(data_shape1, separate1_dims_begin, separate1_dims_end, subgraph_nodes); - broadcasted_operand1 = broadcast_input(input_node1, - common_sub_shape, - separate1_sub_shape, - reduced_sub_shape, - is_separate_first1, - subgraph_nodes); + matmul_operand1 = broadcast_input(input_node1, + common_sub_shape, + separate1_sub_shape, + reduced_sub_shape, + is_separate_first1, + subgraph_nodes); separate2_sub_shape = compute_sub_shape(data_shape2, separate2_dims_begin, separate2_dims_end, subgraph_nodes); - broadcasted_operand2 = broadcast_input(input_node2, - common_sub_shape, - separate2_sub_shape, - reduced_sub_shape, - is_separate_first2, - subgraph_nodes); + matmul_operand2 = broadcast_input(input_node2, + common_sub_shape, + separate2_sub_shape, + reduced_sub_shape, + is_separate_first2, + subgraph_nodes); if (no_reshape_for_matmul1 == false || no_reshape_after_matmul == false) { - matmul_operand1 = reshape_input_for_matmul(broadcasted_operand1, + matmul_operand1 = reshape_input_for_matmul(matmul_operand1, common_sub_shape, separate1_sub_shape, reduced_sub_shape_prod, @@ -850,7 +1000,7 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr, } if (no_reshape_for_matmul2 == false || no_reshape_after_matmul == false) { - matmul_operand2 = reshape_input_for_matmul(broadcasted_operand2, + matmul_operand2 = reshape_input_for_matmul(matmul_operand2, common_sub_shape, separate2_sub_shape, reduced_sub_shape_prod, @@ -926,13 +1076,6 @@ ov::pass::EinsumDecomposition::EinsumDecomposition() { std::string output_subscript; ov::op::v7::Einsum::parse_equation(equation, input_subscripts, output_subscript); - // check that the transformation is applicable - if (std::any_of(input_subscripts.cbegin(), input_subscripts.cend(), [](const std::string& subscript) { - return 
is_subscript_applicable(subscript) == false; - })) { - return false; - } - // create a list of input nodes with preserving their order // and a vector of sub-graph nodes for copy_runtime_info ov::OutputVector input_nodes = einsum_node->input_values(); @@ -962,8 +1105,7 @@ ov::pass::EinsumDecomposition::EinsumDecomposition() { OPENVINO_ASSERT(input_nodes.size() == 1); // extract diagonal for the single operand - // TODO - // extract_diagonal(this, input_nodes, input_subscripts, 0, subgraph_nodes); + extract_diagonal(this, input_nodes, input_subscripts, 0, subgraph_nodes); // reduce dimensions for the remained input node reduce_input(this, input_nodes, input_subscripts, output_subscript, 0, subgraph_nodes); // transpose dimensions to layout required by the output subscript From fa041ca75dfa456ecb71e72115ad257beb3896c7 Mon Sep 17 00:00:00 2001 From: MATEUSZ MIKOLAJCZYK Date: Wed, 8 Jan 2025 11:11:18 +0000 Subject: [PATCH 05/18] Remove xfail for onnx einsum test Signed-off-by: MATEUSZ MIKOLAJCZYK --- src/frontends/onnx/tests/__init__.py | 1 - src/frontends/onnx/tests/tests_python/test_backend.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/src/frontends/onnx/tests/__init__.py b/src/frontends/onnx/tests/__init__.py index fdf1295dfd1dbe..bd5477dbcf8a13 100644 --- a/src/frontends/onnx/tests/__init__.py +++ b/src/frontends/onnx/tests/__init__.py @@ -120,7 +120,6 @@ def xfail_test(reason="Mark the test as expected to fail", strict=True): xfail_issue_49754 = xfail_test(reason="RuntimeError: Unsupported dynamic ops: v1::TopKIE") xfail_issue_52463 = xfail_test(reason="test_operator_add_size1_singleton_broadcast_cpu - " "Not equal to tolerance") -xfail_issue_58033 = xfail_test(reason="Einsum operation misses support for complex ellipsis equations") xfail_issue_58676 = xfail_test(reason="AssertionError: Not equal to tolerance rtol=0.001, atol=1e-07") skip_issue_58676 = pytest.mark.skip(reason="AssertionError: Not equal to tolerance rtol=0.001, atol=1e-07") xfail_issue_onnx_models_140 = xfail_test(reason="https://github.com/onnx/models/issues/140") diff --git a/src/frontends/onnx/tests/tests_python/test_backend.py b/src/frontends/onnx/tests/tests_python/test_backend.py index 39b9788d720af3..487454675ac50e 100644 --- a/src/frontends/onnx/tests/tests_python/test_backend.py +++ b/src/frontends/onnx/tests/tests_python/test_backend.py @@ -32,7 +32,6 @@ xfail_issue_73538, xfail_issue_48052, xfail_issue_52463, - xfail_issue_58033, xfail_issue_63033, xfail_issue_63036, xfail_issue_63043, @@ -292,7 +291,6 @@ def expect_fail(test_case_path, xfail): # type: (str) -> None "OnnxBackendNodeModelTest.test_sequence_insert_at_back_cpu", "OnnxBackendNodeModelTest.test_sequence_insert_at_front_cpu", ), - (xfail_issue_58033, "OnnxBackendNodeModelTest.test_einsum_batch_diagonal_cpu"), ( xfail_issue_63033, "OnnxBackendNodeModelTest.test_batchnorm_epsilon_training_mode_cpu", From 6796536d7dadca42999a6018366f56fbf71a0ec7 Mon Sep 17 00:00:00 2001 From: MATEUSZ MIKOLAJCZYK Date: Wed, 8 Jan 2025 13:27:41 +0000 Subject: [PATCH 06/18] Remove Einsum xfail for torch HF tests Signed-off-by: MATEUSZ MIKOLAJCZYK --- tests/model_hub_tests/pytorch/hf_transformers_models | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/model_hub_tests/pytorch/hf_transformers_models b/tests/model_hub_tests/pytorch/hf_transformers_models index 3d05a430a16671..c861c0cb6c64ee 100644 --- a/tests/model_hub_tests/pytorch/hf_transformers_models +++ b/tests/model_hub_tests/pytorch/hf_transformers_models @@ -88,7 +88,7 @@ 
hifigan,microsoft/speecht5_hifigan,xfail,Load error: The size of tensor a (100) hubert,facebook/hubert-large-ls960-ft hybridbert,gokuls/bert_12_layer_model_v1 ibert,DunnBC22/ibert-roberta-base-Abusive_Or_Threatening_Speech -idefics,HuggingFaceM4/tiny-random-idefics,xfail,aten::einsum Different input dimensions indicated by the same labels for Einsum must be compatible +idefics,HuggingFaceM4/tiny-random-idefics imagegpt,openai/imagegpt-small informer,huggingface/informer-tourism-monthly,xfail,Load error: mat1 and mat2 shapes cannot be multiplied instructblip,Salesforce/instructblip-vicuna-7b @@ -106,7 +106,7 @@ levit,facebook/levit-128S,xfail,Trace error: Cannot insert a Tensor that require lilt,nielsr/lilt-xlm-roberta-base llama_with_landmark,Leooyii/Landmark_512_Slimpajama_1B longformer,allenai/longformer-base-4096 -longt5,pszemraj/long-t5-tglobal-base-16384-book-summary,xfail,(CVS-148676) Compile error: unsupported Einsum +longt5,pszemraj/long-t5-tglobal-base-16384-book-summary luke,oshizo/sbert-jsnli-luke-japanese-base-lite lxmert,unc-nlp/lxmert-base-uncased m2m_100,facebook/nllb-200-distilled-600M @@ -119,7 +119,7 @@ mbart,facebook/mbart-large-50-many-to-many-mmt mctct,speechbrain/m-ctc-t-large mega,Bingsu/mega-150m-arch,xfail,Trace error: Cannot insert a Tensor that requires grad as a constant megatron-bert,UFNLP/gatortron-base -mgp-str,alibaba-damo/mgp-str-base,xfail,(CVS-148676) Compile error: unsupported Einsum +mgp-str,alibaba-damo/mgp-str-base mobilebert,google/mobilebert-uncased mobilenet_v1,google/mobilenet_v1_0.75_192 mobilenet_v2,google/mobilenet_v2_1.0_224 From be8400c7f95f8a8aa3fc216882ccd0a0e60d38d2 Mon Sep 17 00:00:00 2001 From: MATEUSZ MIKOLAJCZYK Date: Thu, 16 Jan 2025 12:42:26 +0000 Subject: [PATCH 07/18] Update transpose reshape elimination for MatMul to handle broadcast from Einsum Signed-off-by: MATEUSZ MIKOLAJCZYK --- ...anspose_reshape_elimination_for_matmul.cpp | 46 +++++++++++++++++-- ...anspose_reshape_elimination_for_matmul.cpp | 14 +++++- 2 files changed, 54 insertions(+), 6 deletions(-) diff --git a/src/common/transformations/src/transformations/common_optimizations/transpose_reshape_elimination_for_matmul.cpp b/src/common/transformations/src/transformations/common_optimizations/transpose_reshape_elimination_for_matmul.cpp index d3eff542d6b7af..caac91de147ab6 100644 --- a/src/common/transformations/src/transformations/common_optimizations/transpose_reshape_elimination_for_matmul.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/transpose_reshape_elimination_for_matmul.cpp @@ -9,10 +9,12 @@ #include "itt.hpp" #include "openvino/core/rt_info.hpp" +#include "openvino/op/broadcast.hpp" #include "openvino/op/constant.hpp" #include "openvino/op/matmul.hpp" #include "openvino/op/reshape.hpp" #include "openvino/op/transpose.hpp" +#include "openvino/pass/pattern/op/or.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp" namespace { @@ -124,9 +126,16 @@ ov::pass::TransposeReshapeEliminationForMatmul::TransposeReshapeEliminationForMa auto transpose_before_pattern = ov::pass::pattern::wrap_type({input_2_pattern, const_transpose_before_pattern}); + auto const_optional_broadcast_before_pattern = ov::pass::pattern::wrap_type(); + auto optional_broadcast_before_pattern = ov::pass::pattern::wrap_type( + {transpose_before_pattern, const_optional_broadcast_before_pattern}); + + auto transpose_or_transpose_broadcast = std::make_shared( + OutputVector{transpose_before_pattern, optional_broadcast_before_pattern}); + auto const_reshape_before_pattern 
= ov::pass::pattern::wrap_type();
-    auto reshape_before_pattern =
-        ov::pass::pattern::wrap_type({transpose_before_pattern, const_reshape_before_pattern});
+    auto reshape_before_pattern = ov::pass::pattern::wrap_type(
+        {transpose_or_transpose_broadcast, const_reshape_before_pattern});
 
     auto matmul_pattern = ov::pass::pattern::wrap_type({input_1_pattern, reshape_before_pattern});
 
@@ -181,8 +190,37 @@ ov::pass::TransposeReshapeEliminationForMatmul::TransposeReshapeEliminationForMa
         // transposes
         if (!check_transposes(transpose_before_order, transpose_after_order, transposed_b))
             return false;
-
-        const auto new_matmul = std::make_shared(input_1, input_2, transposed_a, false);
+        auto matmul_2_input = input_2;
+        // for einsum decomposition, check if a broadcast exists and, if so, reorder the target shape based on the transpose
+        if (pattern_value_map.count(optional_broadcast_before_pattern)) {
+            auto broadcast_before = ov::as_type_ptr(
+                pattern_value_map.at(optional_broadcast_before_pattern).get_node_shared_ptr());
+            if (!broadcast_before) {
+                return false;
+            }
+            auto broadcast_before_constant =
+                ov::as_type_ptr(broadcast_before->get_input_node_shared_ptr(1));
+            if (!broadcast_before_constant) {
+                return false;
+            }
+            auto broadcast_shape_after_transpose = broadcast_before_constant->cast_vector();
+            if (broadcast_shape_after_transpose.size() != transpose_before_order.size()) {
+                return false;
+            }
+            std::vector broadcast_shape_no_transpose;
+            broadcast_shape_no_transpose.reserve(broadcast_shape_after_transpose.size());
+            for (auto idx : transpose_before_order) {
+                broadcast_shape_no_transpose.push_back(broadcast_shape_after_transpose[idx]);
+            }
+            auto broadcast_shape_no_transpose_constant =
+                ov::op::v0::Constant::create(element::i64,
+                                             broadcast_before_constant->get_shape(),
+                                             broadcast_shape_no_transpose);
+            matmul_2_input = broadcast_before->clone_with_new_inputs({input_2, broadcast_shape_no_transpose_constant});
+            copy_runtime_info(broadcast_before, matmul_2_input.get_node_shared_ptr());
+        }
+
+        const auto new_matmul = std::make_shared(input_1, matmul_2_input, transposed_a, false);
         new_matmul->set_friendly_name(transpose_after->get_friendly_name());
         copy_runtime_info({transpose_before, reshape_before, matmul, reshape_after, transpose_after}, new_matmul);
         replace_node(transpose_after, new_matmul);
diff --git a/src/common/transformations/tests/common_optimizations/transpose_reshape_elimination_for_matmul.cpp b/src/common/transformations/tests/common_optimizations/transpose_reshape_elimination_for_matmul.cpp
index ea57598a16c653..1f8d376f86800d 100644
--- a/src/common/transformations/tests/common_optimizations/transpose_reshape_elimination_for_matmul.cpp
+++ b/src/common/transformations/tests/common_optimizations/transpose_reshape_elimination_for_matmul.cpp
@@ -138,11 +138,21 @@ TEST_F(TransformationTestsF, TransposeReshapeEliminationForMatMul_Einsum) {
     {
         auto data_1 = std::make_shared(element::f32, data_shape_1);
         auto data_2 = std::make_shared(element::f32, data_shape_2);
+        auto broadcast_shape_constant_1 =
+            std::make_shared(element::i64, Shape{data_shape_1.size()}, data_shape_1);
+        auto broadcast_shape_constant_2 =
+            std::make_shared(element::i64, Shape{data_shape_2.size()}, data_shape_2);
+        auto broadcast_1 = std::make_shared(data_1,
+                                            broadcast_shape_constant_1,
+                                            ov::op::BroadcastType::BIDIRECTIONAL);
+        auto broadcast_2 = std::make_shared(data_2,
+                                            broadcast_shape_constant_2,
+                                            ov::op::BroadcastType::BIDIRECTIONAL);
         // for some cases Reshape may be first input for Matmul
         auto shape_constant =
std::make_shared(element::i64, Shape{data_shape_1.size()}, data_shape_1); - auto reshape = std::make_shared(data_1, shape_constant, false); - auto matmul = std::make_shared(reshape, data_2, false, false); + auto reshape = std::make_shared(broadcast_1, shape_constant, false); + auto matmul = std::make_shared(reshape, broadcast_2, false, false); model_ref = std::make_shared(NodeVector{matmul}, ParameterVector{data_1, data_2}); } } From 81b5d3907e601c42a8b5b629b95daf9cba7128a8 Mon Sep 17 00:00:00 2001 From: MATEUSZ MIKOLAJCZYK Date: Fri, 17 Jan 2025 18:19:41 +0000 Subject: [PATCH 08/18] Initial Einsum update to handle ellipsis label without dimensions Signed-off-by: MATEUSZ MIKOLAJCZYK --- .../op_conversions/einsum_decomposition.cpp | 94 ++++++++++++++----- src/core/reference/src/op/einsum.cpp | 45 ++++++++- .../include/einsum_shape_inference.hpp | 4 + 3 files changed, 112 insertions(+), 31 deletions(-) diff --git a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp index 3a8f9b4c0c5ffb..0aea695fab1fb1 100644 --- a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp +++ b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp @@ -896,32 +896,6 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr, is_separate_first2); transpose_input(input_nodes, input_subscripts, int_subscript2, input_ind2, subgraph_nodes); - // step 2. reshape both operands so that separate labels and reduced labels are represented - // with just one dimension this is needed by MatMul operation requirement to operands - // format. For example, the shape must be in a format [B1, ..., Bm, X1, Y] or [B1, ..., Bm, - // Y, X2], where B1, ..., Bm are common dimensions, X1 and X2 are collapsed dimensions - // for separate labels and Y is collapsed dimension for reduced labels - // this step is not needed for the operand if it satisfies to one of the requirements: - // 1. there is just one separate dimension and just one reduced dimension - // 2. there is no separate dimension, no common dimensions, and just one reduced dimension - bool no_reshape_for_matmul1 = - (reduced_labels_inds1.size() == 1 && separate_labels_inds1.size() == 1) || - (reduced_labels_inds1.size() == 1 && common_labels_inds1.size() == 0 && separate_labels_inds1.size() == 0); - bool no_reshape_for_matmul2 = - (reduced_labels_inds2.size() == 1 && separate_labels_inds2.size() == 1) || - (reduced_labels_inds2.size() == 1 && common_labels_inds2.size() == 0 && separate_labels_inds2.size() == 0); - // reshape back after MatMul is not needed if one of two requrements satisfies for both operands: - // 1. there is just one separate dimension - // 2. there is no separate dimension and no common dimensions present. 
- // If there is no separate dimension and common dimensions present, reshape is needed - // because auxiliary separate dimension has been added by Unsqueeze operation - // in the purpose for MatMul - bool no_reshape_back1 = - (separate_labels_inds1.size() == 1) || (common_labels_inds1.size() == 0 && separate_labels_inds1.size() == 0); - bool no_reshape_back2 = - (separate_labels_inds2.size() == 1) || (common_labels_inds2.size() == 0 && separate_labels_inds2.size() == 0); - bool no_reshape_after_matmul = no_reshape_back1 && no_reshape_back2; - auto matmul_operand1 = input_node1; auto matmul_operand2 = input_node2; @@ -990,6 +964,34 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr, reduced_sub_shape, is_separate_first2, subgraph_nodes); + + // step 2. reshape both operands so that separate labels and reduced labels are represented + // with just one dimension this is needed by MatMul operation requirement to operands + // format. For example, the shape must be in a format [B1, ..., Bm, X1, Y] or [B1, ..., Bm, + // Y, X2], where B1, ..., Bm are common dimensions, X1 and X2 are collapsed dimensions + // for separate labels and Y is collapsed dimension for reduced labels + // this step is not needed for the operand if it satisfies to one of the requirements: + // 1. there is just one separate dimension and just one reduced dimension + // 2. there is no separate dimension, no common dimensions, and just one reduced dimension + const auto common_labels1_size = common_dims_end - common_dims_begin; + const auto common_labels2_size = common_dims_end2 - common_dims_begin2; + const auto reduced_labels1_size = reduced_dims_end - reduced_dims_begin; + const auto reduced_labels2_size = reduced_dims_end2 - reduced_dims_begin2; + const auto separate_labels1_size = separate1_dims_end - separate1_dims_begin; + const auto separate_labels2_size = separate2_dims_end - separate2_dims_begin; + bool no_reshape_for_matmul1 = (reduced_labels1_size == 1 && separate_labels1_size == 1) || + (reduced_labels1_size == 1 && common_labels1_size == 0 && separate_labels1_size == 0); + bool no_reshape_for_matmul2 = (reduced_labels2_size == 1 && separate_labels2_size == 1) || + (reduced_labels2_size == 1 && common_labels2_size == 0 && separate_labels2_size == 0); + // reshape back after MatMul is not needed if one of two requirements satisfies for both operands: + // 1. there is just one separate dimension + // 2. there is no separate dimension and no common dimensions present. 
+ // If there is no separate dimension and common dimensions present, reshape is needed + // because auxiliary separate dimension has been added by Unsqueeze operation + // in the purpose for MatMul + bool no_reshape_back1 = (separate_labels1_size == 1) || (common_labels1_size == 0 && separate_labels1_size == 0); + bool no_reshape_back2 = (separate_labels2_size == 1) || (common_labels2_size == 0 && separate_labels2_size == 0); + bool no_reshape_after_matmul = no_reshape_back1 && no_reshape_back2; if (no_reshape_for_matmul1 == false || no_reshape_after_matmul == false) { matmul_operand1 = reshape_input_for_matmul(matmul_operand1, common_sub_shape, @@ -1091,6 +1093,46 @@ ov::pass::EinsumDecomposition::EinsumDecomposition() { // in more optimal order auto einsum_path = compute_einsum_path(einsum_node); + // fix inputs where ellipsis does not contain any dimensions + std::vector ellipsis_inputs(input_nodes.size(), false); + std::vector no_ellipsis_or_empty_inputs(input_nodes.size(), false); + static const std::string ellipsis = "..."; + for (size_t inp_iter = 0; inp_iter < input_nodes.size(); inp_iter++) { + const auto& labels = ov::op::v7::Einsum::extract_labels(input_subscripts[inp_iter]); + ellipsis_inputs[inp_iter] = (std::find(labels.begin(), labels.end(), "...") != labels.end()); + if (!ellipsis_inputs[inp_iter] || + (input_nodes[inp_iter].get_partial_shape().rank() == (labels.size() - 1))) { + no_ellipsis_or_empty_inputs[inp_iter] = true; + } + } + if (std::none_of(ellipsis_inputs.begin(), ellipsis_inputs.end(), [](bool inp) { + return inp; + })) { + if (output_subscript.find("...") != std::string::npos) { + output_subscript.erase(output_subscript.find("..."), 3); + } + } else if (std::all_of(no_ellipsis_or_empty_inputs.begin(), no_ellipsis_or_empty_inputs.end(), [](bool inp) { + return inp; + })) { + for (size_t inp_iter = 0; inp_iter < input_nodes.size(); inp_iter++) { + if (input_subscripts[inp_iter].find("...") != std::string::npos) { + input_subscripts[inp_iter].erase(input_subscripts[inp_iter].find("..."), 3); + } + } + if (output_subscript.find("...") != std::string::npos) { + output_subscript.erase(output_subscript.find("..."), 3); + } + } else { + for (size_t inp_iter = 0; inp_iter < input_nodes.size(); inp_iter++) { + if (ellipsis_inputs[inp_iter] && no_ellipsis_or_empty_inputs[inp_iter]) { + auto labels = ov::op::v7::Einsum::extract_labels(input_subscripts[inp_iter]); + auto ellipsis_idx_iter = std::find(labels.begin(), labels.end(), "..."); + std::vector ellipsis_idx{std::distance(labels.begin(), ellipsis_idx_iter)}; + input_nodes[inp_iter] = unsqueeze_input(input_nodes[inp_iter], ellipsis_idx, subgraph_nodes); + } + } + } + // contract inputs by Einsum until just one is remained for (auto const& inds_pair : einsum_path) { contract_two_inputs(this, diff --git a/src/core/reference/src/op/einsum.cpp b/src/core/reference/src/op/einsum.cpp index d16e500b40e2fe..4457b628f670be 100644 --- a/src/core/reference/src/op/einsum.cpp +++ b/src/core/reference/src/op/einsum.cpp @@ -853,9 +853,7 @@ void contract_two_inputs(ov::TensorVector& inputs, PartialShape common_sub_shape1 = compute_sub_shape(input_shape1, common_dims_begin, common_dims_end); PartialShape common_sub_shape2 = compute_sub_shape(input_shape2, common_dims_begin2, common_dims_end2); - PartialShape reduced_sub_shape_prod = compute_sub_shape(input_shape1, reduced_dims_begin, reduced_dims_end, true); PartialShape reduced_sub_shape = compute_sub_shape(input_shape1, reduced_dims_begin, reduced_dims_end); - Shape 
reduced_sub_shape_prod2 = compute_sub_shape(input_shape2, reduced_dims_begin2, reduced_dims_end2, true); Shape reduced_sub_shape2 = compute_sub_shape(input_shape2, reduced_dims_begin2, reduced_dims_end2); Shape separate1_sub_shape = compute_sub_shape(input_shape1, separate1_dims_begin, separate1_dims_end); Shape separate2_sub_shape = compute_sub_shape(input_shape2, separate2_dims_begin, separate2_dims_end); @@ -865,7 +863,7 @@ void contract_two_inputs(ov::TensorVector& inputs, // reference::broadcast() PartialShape::broadcast_merge_into(common_sub_shape1, common_sub_shape2, op::AutoBroadcastType::NUMPY); PartialShape::broadcast_merge_into(reduced_sub_shape, reduced_sub_shape2, op::AutoBroadcastType::NUMPY); - PartialShape::broadcast_merge_into(reduced_sub_shape_prod, reduced_sub_shape_prod2, op::AutoBroadcastType::NUMPY); + Shape reduced_sub_shape_prod = {shape_size(reduced_sub_shape.get_shape())}; Shape common_sub_shape = common_sub_shape1.get_shape(); broadcast_input(inputs, input_ind1, @@ -883,13 +881,13 @@ void contract_two_inputs(ov::TensorVector& inputs, ov::Tensor matmul_operand1 = reshape_input_for_matmul(input1, common_sub_shape, separate1_sub_shape, - reduced_sub_shape_prod.get_shape(), + reduced_sub_shape_prod, is_separate_first1); ov::Tensor matmul_operand2 = reshape_input_for_matmul(input2, common_sub_shape, separate2_sub_shape, - reduced_sub_shape_prod.get_shape(), + reduced_sub_shape_prod, is_separate_first2); // step 3. apply MatMul operation for formatted inputs @@ -941,6 +939,43 @@ void einsum_impl(const ov::TensorVector& inputs, ov::TensorVector& outputs, cons auto einsum_path = compute_einsum_path(num_inputs); ov::TensorVector int_inputs = inputs; + std::vector ellipsis_inputs(inputs.size(), false); + std::vector no_ellipsis_or_empty_inputs(inputs.size(), false); + static const std::string ellipsis = "..."; + for (size_t inp_iter = 0; inp_iter < inputs.size(); inp_iter++) { + const auto& labels = ov::op::v7::Einsum::extract_labels(input_subscripts[inp_iter]); + ellipsis_inputs[inp_iter] = (std::find(labels.begin(), labels.end(), "...") != labels.end()); + if (!ellipsis_inputs[inp_iter] || (inputs[inp_iter].get_shape().size() == (labels.size() - 1))) { + no_ellipsis_or_empty_inputs[inp_iter] = true; + } + } + if (std::none_of(ellipsis_inputs.begin(), ellipsis_inputs.end(), [](bool inp) { + return inp; + })) { + if (output_subscript.find("...") != std::string::npos) { + output_subscript.erase(output_subscript.find("..."), 3); + } + } else if (std::all_of(no_ellipsis_or_empty_inputs.begin(), no_ellipsis_or_empty_inputs.end(), [](bool inp) { + return inp; + })) { + for (size_t inp_iter = 0; inp_iter < inputs.size(); inp_iter++) { + if (input_subscripts[inp_iter].find("...") != std::string::npos) { + input_subscripts[inp_iter].erase(input_subscripts[inp_iter].find("..."), 3); + } + } + if (output_subscript.find("...") != std::string::npos) { + output_subscript.erase(output_subscript.find("..."), 3); + } + } else { + for (size_t inp_iter = 0; inp_iter < inputs.size(); inp_iter++) { + if (ellipsis_inputs[inp_iter] && no_ellipsis_or_empty_inputs[inp_iter]) { + auto labels = ov::op::v7::Einsum::extract_labels(input_subscripts[inp_iter]); + auto ellipsis_idx_iter = std::find(labels.begin(), labels.end(), "..."); + std::vector ellipsis_idx{std::distance(labels.begin(), ellipsis_idx_iter)}; + int_inputs[inp_iter] = unsqueeze_input(inputs[inp_iter], ellipsis_idx); + } + } + } // contract inputs by Einsum until just one is remained for (auto const& inds_pair : einsum_path) { diff 
--git a/src/core/shape_inference/include/einsum_shape_inference.hpp b/src/core/shape_inference/include/einsum_shape_inference.hpp index 2a7cd60369261e..1ee471117d6872 100644 --- a/src/core/shape_inference/include/einsum_shape_inference.hpp +++ b/src/core/shape_inference/include/einsum_shape_inference.hpp @@ -101,6 +101,10 @@ std::vector shape_infer(const Einsum* op, const std::vector& input_s auto& output_shape = output_shapes[0]; for (auto const& output_label : output_labels) { + if (output_label == "..." && label_to_shape.find(output_label) == label_to_shape.end()) { + // Output labels may contain ellipsis that does not cover any dimensions. + continue; + } NODE_VALIDATION_CHECK(op, label_to_shape.find(output_label) != label_to_shape.end(), "Label in output subscript of Einsum equation must enter at least " From 28b579ac85d0fc0a8ce23e6fb1c37e1da09961a9 Mon Sep 17 00:00:00 2001 From: MATEUSZ MIKOLAJCZYK Date: Mon, 20 Jan 2025 15:35:11 +0000 Subject: [PATCH 09/18] Update reduce_input in einsum common decomposition Signed-off-by: MATEUSZ MIKOLAJCZYK --- .../op_conversions/einsum_decomposition.cpp | 32 ++++++++++++------- 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp index 0aea695fab1fb1..8fde064b2214af 100644 --- a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp +++ b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp @@ -546,38 +546,46 @@ void reduce_input(ov::pass::EinsumDecomposition* einsum_decompose_ptr, size_t input_ind, ov::NodeVector& subgraph_nodes) { // perform sanity check for arguments - auto num_inputs = input_nodes.size(); + const auto num_inputs = input_nodes.size(); OPENVINO_ASSERT(num_inputs == input_subscripts.size(), "Each input must have own subscript."); OPENVINO_ASSERT(input_ind < num_inputs, "Input index is out of range."); - std::vector reduced_axes; - auto labels = ov::op::v7::Einsum::extract_labels(input_subscripts[input_ind]); + const auto& input_node = input_nodes[input_ind]; + const auto& input_subscript = input_subscripts[input_ind]; + + // compute output shape and axes to reduce + std::set reduced_axes; + const auto labels = ov::op::v7::Einsum::extract_labels(input_subscripts[input_ind]); + auto label_dim_map = compute_label_dim_map(input_node.get_partial_shape().rank(), input_subscript); std::string new_input_subscript = ""; - for (size_t dim_ind = 0; dim_ind < labels.size(); ++dim_ind) { - const auto& label = labels[dim_ind]; + for (const auto& label : labels) { // check if the current label is met in the other input subscripts // or the output subscript - bool is_dim_reduced = is_dimension_reduced(input_subscripts, output_subscript, label, {input_ind}); + const bool is_dim_reduced = is_dimension_reduced(input_subscripts, output_subscript, label, {input_ind}); + + OPENVINO_ASSERT(label_dim_map.find(label) != label_dim_map.end()); + const auto& label_dims = label_dim_map[label]; // if label is not met, dimension corresponding to the label is to reduce if (is_dim_reduced) { - reduced_axes.push_back(dim_ind); + reduced_axes.insert(label_dims.begin(), label_dims.end()); } else { new_input_subscript += label; } } - if (reduced_axes.size() == 0) { + if (reduced_axes.empty()) { // there is no axis to reduce return; } // reduce by summed up elements along dimension for which label is met just once - 
const auto& input_node = input_nodes[input_ind]; - auto axes_const = - ov::op::v0::Constant::create(ov::element::Type_t::i64, ov::Shape{reduced_axes.size()}, reduced_axes); - auto reduce_sum = einsum_decompose_ptr->register_new_node(input_node, axes_const, false); + const std::vector reduced_axes_vec{reduced_axes.cbegin(), reduced_axes.cend()}; + const auto axes_const = + ov::op::v0::Constant::create(ov::element::Type_t::i64, ov::Shape{reduced_axes.size()}, reduced_axes_vec); + const auto reduce_sum = + einsum_decompose_ptr->register_new_node(input_node, axes_const, false); // update a vector of inputs and input subscripts input_nodes[input_ind] = reduce_sum->output(0); From 33acf2ecedfe41ac75d09575f78434e2b0c3d34f Mon Sep 17 00:00:00 2001 From: MATEUSZ MIKOLAJCZYK Date: Mon, 20 Jan 2025 16:26:50 +0000 Subject: [PATCH 10/18] Fix broadcasting of reduced part for reshape Signed-off-by: MATEUSZ MIKOLAJCZYK --- .../op_conversions/einsum_decomposition.cpp | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp index 8fde064b2214af..f21cfcf74aafc5 100644 --- a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp +++ b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp @@ -412,7 +412,7 @@ ov::Output broadcast_input(const ov::Output& input_node, ov::Output reshape_input_for_matmul(const ov::Output& input_node, const ov::OutputVector& common_sub_shape, const ov::OutputVector& separate_sub_shape, - const ov::OutputVector& reduced_sub_shape_prod, + const ov::OutputVector& reduced_sub_shape, bool is_separate_first, ov::NodeVector& subgraph_nodes) { ov::OutputVector new_shape_parts; @@ -436,9 +436,17 @@ ov::Output reshape_input_for_matmul(const ov::Output& input_ separate_parts.push_back(separate_shape_prod->output(0)); subgraph_nodes.insert(subgraph_nodes.end(), {reduce_axis_const, separate_shape_prod}); } + ov::OutputVector reduced_sub_shape_prod; + auto const_0 = ov::op::v0::Constant::create(ov::element::i32, {1}, {0}); + for (auto sub_shape : reduced_sub_shape) { + auto product = std::make_shared(sub_shape, const_0, true); + subgraph_nodes.insert(subgraph_nodes.end(), {const_0, product}); + reduced_sub_shape_prod.push_back(product->output(0)); + } // form a new shape for input so that collapsed dimensions corresponding // to the common, separate and reduced dimensions are placed in the correct order + if (is_separate_first) { new_shape_parts.insert(new_shape_parts.end(), separate_parts.begin(), separate_parts.end()); new_shape_parts.insert(new_shape_parts.end(), reduced_sub_shape_prod.begin(), reduced_sub_shape_prod.end()); @@ -454,7 +462,7 @@ ov::Output reshape_input_for_matmul(const ov::Output& input_ auto new_shape_op = std::make_shared(new_shape_parts, 0); - // if new shape is possible to compute on the shape infer stage, insert Constant node immediatelly + // if new shape is possible to compute on the shape infer stage, insert Constant node immediately // in order to prevent repeated computing during constant-folding pass std::shared_ptr reshaped_input_op; if (auto new_shape_const = ov::util::get_constant_from_source(new_shape_op)) { @@ -947,17 +955,13 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr, auto common_sub_shape2 = compute_sub_shape(data_shape2, common_dims_begin2, 
common_dims_end2, subgraph_nodes); OPENVINO_ASSERT(common_sub_shape.size() == common_sub_shape2.size()); common_sub_shape = broadcast_merge_shapes(common_sub_shape, common_sub_shape2, subgraph_nodes); - auto reduced_sub_shape_prod = - compute_sub_shape(data_shape1, reduced_dims_begin, reduced_dims_end, subgraph_nodes, true); - auto reduced_sub_shape_prod2 = - compute_sub_shape(data_shape2, reduced_dims_begin2, reduced_dims_end2, subgraph_nodes, true); auto reduced_sub_shape = compute_sub_shape(data_shape1, reduced_dims_begin, reduced_dims_end, subgraph_nodes, false); auto reduced_sub_shape2 = compute_sub_shape(data_shape2, reduced_dims_begin2, reduced_dims_end2, subgraph_nodes, false); - reduced_sub_shape_prod = broadcast_merge_shapes(reduced_sub_shape_prod, reduced_sub_shape_prod2, subgraph_nodes); reduced_sub_shape = broadcast_merge_shapes(reduced_sub_shape, reduced_sub_shape2, subgraph_nodes); + separate1_sub_shape = compute_sub_shape(data_shape1, separate1_dims_begin, separate1_dims_end, subgraph_nodes); matmul_operand1 = broadcast_input(input_node1, common_sub_shape, @@ -1004,7 +1008,7 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr, matmul_operand1 = reshape_input_for_matmul(matmul_operand1, common_sub_shape, separate1_sub_shape, - reduced_sub_shape_prod, + reduced_sub_shape, is_separate_first1, subgraph_nodes); } @@ -1013,7 +1017,7 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr, matmul_operand2 = reshape_input_for_matmul(matmul_operand2, common_sub_shape, separate2_sub_shape, - reduced_sub_shape_prod, + reduced_sub_shape, is_separate_first2, subgraph_nodes); } From 29b1072bce9df6f14edeb68e5577e40b7f51f636 Mon Sep 17 00:00:00 2001 From: MATEUSZ MIKOLAJCZYK Date: Tue, 21 Jan 2025 16:16:33 +0000 Subject: [PATCH 11/18] Extend Einsum reference test cases Signed-off-by: MATEUSZ MIKOLAJCZYK --- .../tests/functional/op_reference/einsum.cpp | 200 ++++++++++++++++++ 1 file changed, 200 insertions(+) diff --git a/src/plugins/template/tests/functional/op_reference/einsum.cpp b/src/plugins/template/tests/functional/op_reference/einsum.cpp index 2d3e7fb627305f..4dd8f46a405472 100644 --- a/src/plugins/template/tests/functional/op_reference/einsum.cpp +++ b/src/plugins/template/tests/functional/op_reference/einsum.cpp @@ -154,6 +154,205 @@ std::vector generateParams() { .equation("abbac,bad->ad") .expectedResult({ET, {2, 1}, std::vector{123, 129}}) .testcaseName("einsum_diagonal_with_matmul"), + + Builder{} + .inputs({{ET, {2, 3}, std::vector{1, 2, 3, 4, 5, 6}}}) + .equation("...->...") + .expectedResult({ET, {2, 3}, std::vector{1, 2, 3, 4, 5, 6}}) + .testcaseName("einsum_identity"), + Builder{} + .inputs({{ET, {2, 3}, std::vector{1, 2, 3, 4, 5, 6}}}) + .equation("i...->i") + .expectedResult({ET, {2}, std::vector{6, 15}}) + .testcaseName("einsum_reduce_ellipsis"), + Builder{} + .inputs({{ET, {3, 3, 3}, std::vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27}}}) + .equation("iii->") + .expectedResult({ET, {}, std::vector{42}}) + .testcaseName("einsum_trace"), + Builder{} + .inputs({{ET, {3, 3, 4}, std::vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36}}}) + .equation("ii...->") + .expectedResult({ET, {}, std::vector{222}}) + .testcaseName("einsum_trace_ellipsis"), + Builder{} + .inputs({{ET, {3, 2, 1, 2, 1, 3, 1}, std::vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 
18, 19, 20, 21, 22, 23, 24, + 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36}}}) + .equation("ijkjkik->ijk") + .expectedResult({ET, {3, 2, 1}, std::vector{1, 10, 14, 23, 27, 36}}) + .testcaseName("einsum_diagonal_mixed_order"), + Builder{} + .inputs({{ET, + {3, 3, 3, 3, 3}, + std::vector{ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, + 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, + 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, + 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, + 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, + 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, + 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, + 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, + 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, + 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, + 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, + 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243}}}) + .equation("iiiii->i") + .expectedResult({ET, {3}, std::vector{1, 122, 243}}) + .testcaseName("einsum_5d_diagonal"), + Builder{} + .inputs({{ET, {2, 1}, std::vector{1, 2}}, + {ET, {4, 1, 1}, std::vector{1, 2, 3, 4}}, + {ET, {3, 1, 3}, std::vector{1, 2, 3, 4, 5, 6, 7, 8, 9}}}) + .equation("ab,bcd,dbc->ca") + .expectedResult({ET, {3, 2}, std::vector{120, 240, 150, 300, 180, 360}}) + .testcaseName("einsum_3in_broadcast"), + Builder{} + .inputs({{ET, {2, 1}, std::vector{1, 2}}, {ET, {3, 2}, std::vector{1, 2, 3, 4, 5, 6}}}) + .equation("ab,bc->ac") + .expectedResult({ET, {2, 2}, std::vector{9, 12, 18, 24}}) + .testcaseName("einsum_2in_broadcast_lhs_reduced"), + Builder{} + .inputs({{ET, {2, 3}, std::vector{1, 2, 3, 4, 5, 6}}, {ET, {1, 2}, std::vector{1, 2}}}) + .equation("ab,bc->ac") + .expectedResult({ET, {2, 2}, std::vector{6, 12, 15, 30}}) + .testcaseName("einsum_2in_broadcast_rhs_reduced"), + Builder{} + .inputs({{ET, {2, 1}, std::vector{1, 2}}, {ET, {3, 2}, std::vector{1, 2, 3, 4, 5, 6}}}) + .equation("ab,bc->bc") + .expectedResult({ET, {3, 2}, std::vector{3, 6, 9, 12, 15, 18}}) + .testcaseName("einsum_2in_broadcast_lhs_common"), + Builder{} + .inputs({{ET, {2, 3}, std::vector{1, 2, 3, 4, 5, 6}}, {ET, {1, 2}, std::vector{1, 2}}}) + .equation("ab,bc->cb") + .expectedResult({ET, {2, 3}, std::vector{5, 7, 9, 10, 14, 18}}) + .testcaseName("einsum_2in_broadcast_rhs_common"), + Builder{} + .inputs({{ET, {1, 3}, std::vector{1, 2, 3}}, + {ET, {3, 4, 2, 1}, std::vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}}}) + .equation("aj,j...->a...") + .expectedResult({ET, {1, 4, 2, 1}, std::vector{70, 76, 82, 88, 94, 100, 106, 112}}) + .testcaseName("einsum_2in_only_rhs_out_ellipsis"), + Builder{} + .inputs({{ET, + {2, 7, 4, 3}, + std::vector{ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, + 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, + 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, + 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, + 77, 78, 79, 80, 
81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, + 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, + 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, + 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, + 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168}}, + {ET, {3}, std::vector{1, 2, 3}}}) + .equation("a...j,j->a...") + .expectedResult( + {ET, {2, 7, 4}, std::vector{14, 32, 50, 68, 86, 104, 122, 140, 158, 176, 194, 212, 230, 248, + 266, 284, 302, 320, 338, 356, 374, 392, 410, 428, 446, 464, 482, 500, + 518, 536, 554, 572, 590, 608, 626, 644, 662, 680, 698, 716, 734, 752, + 770, 788, 806, 824, 842, 860, 878, 896, 914, 932, 950, 968, 986, 1004}}) + .testcaseName("einsum_2in_only_lhs_out_ellipsis"), + Builder{} + .inputs({{ET, + {2, 7, 4, 3}, + std::vector{ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, + 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, + 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, + 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, + 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, + 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, + 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, + 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, + 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168}}, + {ET, {3}, std::vector{1, 2, 3}}}) + .equation("a...j,j->a") + .expectedResult({ET, {2}, std::vector{7196, 21308}}) + .testcaseName("einsum_2in_lhs_ellipsis_out_reduced"), + Builder{} + .inputs({{ET, {1, 3}, std::vector{1, 2, 3}}, + {ET, {3, 4, 2, 1}, std::vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}}}) + .equation("aj,j...->a") + .expectedResult({ET, {1}, std::vector{728}}) + .testcaseName("einsum_2in_rhs_ellipsis_out_reduced"), + Builder{} + .inputs({{ET, {1, 1, 4, 3}, std::vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}, + {ET, {3, 4, 2, 1}, std::vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}}}) + .equation("a...j,j...->a") + .expectedResult({ET, {1}, std::vector{8312}}) + .testcaseName("einsum_2in_broadcast_ellipsis_out_reduced"), + Builder{} + .inputs({{ET, {1, 3}, std::vector{1, 2, 3}}, + {ET, {3, 4, 2, 1}, std::vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}}}) + .equation("a...j,j...->a...") + .expectedResult({ET, {1, 4, 2, 1}, std::vector{70, 76, 82, 88, 94, 100, 106, 112}}) + .testcaseName("einsum_2in_unsqueeze_lhs_ellipsis"), + Builder{} + .inputs({{ET, {1, 1, 4, 3}, std::vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}, + {ET, {3}, std::vector{1, 2, 3}}}) + .equation("a...j,j...->a...") + .expectedResult({ET, {1, 1, 4}, std::vector{14, 32, 50, 68}}) + .testcaseName("einsum_2in_unsqueeze_rhs_ellipsis"), + Builder{} + .inputs({{ET, {1, 3}, std::vector{1, 2, 3}}, + {ET, {3, 4, 2, 1}, std::vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}}}) + .equation("a...j,j...->a") + .expectedResult({ET, {1}, std::vector{728}}) + .testcaseName("einsum_2in_unsqueeze_lhs_ellipsis_no_out_ellipsis"), + Builder{} + .inputs({{ET, {1, 1, 4, 3}, 
std::vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}, + {ET, {3}, std::vector{1, 2, 3}}}) + .equation("a...j,j...->a") + .expectedResult({ET, {1}, std::vector{164}}) + .testcaseName("einsum_2in_unsqueeze_rhs_ellipsis_no_out_ellipsis"), + Builder{} + .inputs({{ET, {1, 3}, std::vector{1, 2, 3}}, {ET, {3}, std::vector{1, 2, 3}}}) + .equation("a...j,j->a...") + .expectedResult({ET, {1}, std::vector{14}}) + .testcaseName("einsum_2in_prune_lhs_out_ellipsis"), + Builder{} + .inputs({{ET, {1, 3}, std::vector{1, 2, 3}}, {ET, {3}, std::vector{1, 2, 3}}}) + .equation("aj,j...->a...") + .expectedResult({ET, {1}, std::vector{14}}) + .testcaseName("einsum_2in_prune_rhs_out_ellipsis"), + Builder{} + .inputs({{ET, {1, 3}, std::vector{1, 2, 3}}, {ET, {3}, std::vector{1, 2, 3}}}) + .equation("aj,j->a...") + .expectedResult({ET, {1}, std::vector{14}}) + .testcaseName("einsum_2in_prune_out_ellipsis"), + Builder{} + .inputs({{ET, {1, 3}, std::vector{1, 2, 3}}, {ET, {3}, std::vector{1, 2, 3}}}) + .equation("a...j,j...->a...") + .expectedResult({ET, {1}, std::vector{14}}) + .testcaseName("einsum_2in_prune_all_ellipsis"), + Builder{} + .inputs({{ET, {1, 3}, std::vector{1, 2, 3}}, {ET, {1}, std::vector{1}}}) + .equation("a...j,j->a") + .expectedResult({ET, {1}, std::vector{6}}) + .testcaseName("einsum_2in_prune_lhs_ellipsis_no_out_ellipsis"), + Builder{} + .inputs({{ET, {1, 1}, std::vector{1}}, {ET, {3}, std::vector{1, 2, 3}}}) + .equation("aj,j...->a") + .expectedResult({ET, {1}, std::vector{6}}) + .testcaseName("einsum_2in_prune_rhs_ellipsis_no_out_ellipsis"), + Builder{} + .inputs({{ET, {1, 3}, std::vector{1, 2, 3}}, {ET, {3}, std::vector{1, 2, 3}}}) + .equation("a...j,j...->a") + .expectedResult({ET, {1}, std::vector{14}}) + .testcaseName("einsum_2in_prune_inp_ellipsis_no_out_ellipsis") + }; return params; } @@ -161,6 +360,7 @@ std::vector generateParams() { std::vector generateCombinedParams() { const std::vector> generatedParams{ generateParams(), + generateParams(), generateParams(), }; std::vector combinedParams; From 9e749f9bd8690532f72c568c03114b2bb4377670 Mon Sep 17 00:00:00 2001 From: Mateusz Mikolajczyk Date: Thu, 23 Jan 2025 15:50:03 +0100 Subject: [PATCH 12/18] FIx divide by 0 and handling 2+ repeated label types for einsum decomposition Signed-off-by: Mateusz Mikolajczyk --- .../op_conversions/einsum_decomposition.cpp | 48 +++++++++++-------- 1 file changed, 29 insertions(+), 19 deletions(-) diff --git a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp index f21cfcf74aafc5..26f88a3df71f3e 100644 --- a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp +++ b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp @@ -16,7 +16,9 @@ #include "openvino/op/einsum.hpp" #include "openvino/op/gather.hpp" #include "openvino/op/matmul.hpp" +#include "openvino/op/maximum.hpp" #include "openvino/op/multiply.hpp" +#include "openvino/op/power.hpp" #include "openvino/op/range.hpp" #include "openvino/op/reduce_prod.hpp" #include "openvino/op/reduce_sum.hpp" @@ -610,21 +612,25 @@ ov::Output build_identity(const ov::Output& input_node, const auto input_shape = std::make_shared(input_node); const auto repeated_label_indices = ov::op::v0::Constant::create(ov::element::i64, {repeated_label_dims.size()}, repeated_label_dims); + const auto repeated_label_indices_len = + ov::op::v0::Constant::create(ov::element::i64, {}, 
{repeated_label_dims.size()}); const auto const_0 = ov::op::v0::Constant::create(ov::element::i64, {}, {0}); const auto const_1 = ov::op::v0::Constant::create(ov::element::i64, {}, {1}); const auto repeated_dimensions = std::make_shared(input_shape, repeated_label_indices, const_0); const auto reduced_dimension = std::make_shared(repeated_dimensions, const_0, const_0); - const auto reduced_dimension_min_1 = std::make_shared(reduced_dimension, const_1); - - const auto reduced_size = std::make_shared(repeated_dimensions, const_0, true); - const auto reduced_size_min_1 = std::make_shared(reduced_size, const_1); - const auto step_size = std::make_shared(reduced_size_min_1, reduced_dimension_min_1); - const auto range = std::make_shared(const_0, reduced_dimension, const_1, ov::element::i64); - const auto steps = std::make_shared(range, step_size); - const auto zeros = std::make_shared(const_0, reduced_size); + const auto range_max_val = std::make_shared(reduced_dimension, repeated_label_indices_len); + const auto step_numerator = std::make_shared(range_max_val, const_1); + const auto step_denominator = std::make_shared(reduced_dimension, const_1); + const auto step_denominator_but_not_0 = std::make_shared(step_denominator, const_1); + const auto step_numerator_but_not_0 = std::make_shared(step_numerator, const_1); + const auto step = std::make_shared(step_numerator_but_not_0, step_denominator_but_not_0); + const auto eye_flattened_indices = std::make_shared(const_0, range_max_val, step); const auto reduced_dimension_1d = std::make_shared(reduced_dimension, const_0); const auto ones = std::make_shared(const_1, reduced_dimension_1d); - const auto eye_flattened = std::make_shared(zeros, steps, ones, const_0); + const auto reduced_size = std::make_shared(repeated_dimensions, const_0, true); + const auto zeros = std::make_shared(const_0, reduced_size); + const auto eye_flattened = + std::make_shared(zeros, eye_flattened_indices, ones, const_0); const auto identity_rank = std::make_shared(input_shape); const auto ones_of_input_shape_rank = std::make_shared(const_1, identity_rank); @@ -632,24 +638,28 @@ ov::Output build_identity(const ov::Output& input_node, repeated_label_indices, repeated_dimensions, const_0); + const auto identity = std::make_shared(eye_flattened, identity_shape, false); const auto identity_cvt = std::make_shared(identity, input_node.get_element_type()); subgraph_nodes.insert(subgraph_nodes.end(), {input_shape, repeated_label_indices, + repeated_label_indices_len, const_0, const_1, repeated_dimensions, reduced_dimension, - reduced_dimension_min_1, - reduced_size, - reduced_size_min_1, - step_size, - range, - steps, - zeros, + range_max_val, + step_numerator, + step_denominator, + step_denominator_but_not_0, + step_numerator_but_not_0, + step, + eye_flattened_indices, reduced_dimension_1d, ones, + reduced_size, + zeros, eye_flattened, identity_rank, ones_of_input_shape_rank, @@ -673,12 +683,12 @@ ov::Output build_multi_identity(ov::pass::EinsumDecomposition* einsum_ }; // initially set multi-identity with identity for the first repeated label - const auto multi_identity = get_identity(0); + auto multi_identity = get_identity(0).get_node_shared_ptr(); for (size_t label_ind = 1; label_ind < repeated_labels.size(); ++label_ind) { const auto identity = get_identity(label_ind); - const auto mul = + multi_identity = std::make_shared(multi_identity, identity, ov::op::AutoBroadcastType::NUMPY); - subgraph_nodes.insert(subgraph_nodes.end(), {mul}); + 
subgraph_nodes.insert(subgraph_nodes.end(), {multi_identity}); } return subgraph_nodes.back(); From 50b6d3ef81700e7b9a21048c78e253c6a72d2b48 Mon Sep 17 00:00:00 2001 From: Mateusz Mikolajczyk Date: Thu, 23 Jan 2025 17:45:54 +0100 Subject: [PATCH 13/18] Move fix_inputs_with_0d_ellipsis to separate function Signed-off-by: Mateusz Mikolajczyk --- .../op_conversions/einsum_decomposition.cpp | 94 +++++++++++-------- 1 file changed, 56 insertions(+), 38 deletions(-) diff --git a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp index 26f88a3df71f3e..0d322c99af813e 100644 --- a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp +++ b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp @@ -1079,6 +1079,61 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr, // update a vector of nodes for copy_runtime_info subgraph_nodes.insert(subgraph_nodes.end(), {matmul}); } + +/// \brief Adjusts input subscripts and nodes to handle 0-dimensional ellipsis in Einsum operations. +/// +/// Handle ellipses labels that do not represent any dimensions: +/// 1. If there is no ellipsis in the input subscripts, remove ellipsis from the output subscript. +/// 2. If all ellipses in the input subscripts do not represent any dimensions, remove ellipses from all subscripts. +/// 3. If there is at least one ellipsis that does not represent any dimensions, unsqueeze the corresponding input at +/// ellipsis dimension. +/// +/// \param input_nodes A vector of input nodes for the Einsum operation. +/// \param input_subscripts A vector of input subscripts corresponding to the input nodes. +/// \param output_subscript The output subscript for the Einsum operation. +/// \param subgraph_nodes A vector to store nodes created during the subgraph transformation. 
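+///
+/// For example (equations and shapes mirrored from the reference test cases added earlier in this series):
+/// - Case 1: "aj,j->a..." with inputs of shapes {1, 3} and {3}: no input subscript has an ellipsis,
+///   so the output ellipsis is erased and the equation becomes "aj,j->a".
+/// - Case 2: "a...j,j...->a..." with inputs of shapes {1, 3} and {3}: every ellipsis covers zero
+///   dimensions, so all ellipses are erased and the equation also becomes "aj,j->a".
+/// - Case 3: "a...j,j...->a..." with inputs of shapes {1, 1, 4, 3} and {3}: the first ellipsis covers
+///   two dimensions while the second covers none, so the second input is unsqueezed at its ellipsis
+///   position (shape {3} -> {3, 1}).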
+void fix_inputs_with_0d_ellipsis(ov::OutputVector& input_nodes, + std::vector& input_subscripts, + std::string& output_subscript, + ov::NodeVector& subgraph_nodes) { + std::vector ellipsis_inputs(input_nodes.size(), false); + std::vector no_ellipsis_or_empty_inputs(input_nodes.size(), false); + static const std::string ellipsis = "..."; + for (size_t inp_iter = 0; inp_iter < input_nodes.size(); inp_iter++) { + const auto& labels = ov::op::v7::Einsum::extract_labels(input_subscripts[inp_iter]); + ellipsis_inputs[inp_iter] = (std::find(labels.begin(), labels.end(), "...") != labels.end()); + if (!ellipsis_inputs[inp_iter] || (input_nodes[inp_iter].get_partial_shape().rank() == (labels.size() - 1))) { + no_ellipsis_or_empty_inputs[inp_iter] = true; + } + } + if (std::none_of(ellipsis_inputs.begin(), ellipsis_inputs.end(), [](bool inp) { + return inp; + })) { + if (output_subscript.find("...") != std::string::npos) { + output_subscript.erase(output_subscript.find("..."), 3); + } + } else if (std::all_of(no_ellipsis_or_empty_inputs.begin(), no_ellipsis_or_empty_inputs.end(), [](bool inp) { + return inp; + })) { + for (size_t inp_iter = 0; inp_iter < input_nodes.size(); inp_iter++) { + if (input_subscripts[inp_iter].find("...") != std::string::npos) { + input_subscripts[inp_iter].erase(input_subscripts[inp_iter].find("..."), 3); + } + } + if (output_subscript.find("...") != std::string::npos) { + output_subscript.erase(output_subscript.find("..."), 3); + } + } else { + for (size_t inp_iter = 0; inp_iter < input_nodes.size(); inp_iter++) { + if (ellipsis_inputs[inp_iter] && no_ellipsis_or_empty_inputs[inp_iter]) { + auto labels = ov::op::v7::Einsum::extract_labels(input_subscripts[inp_iter]); + auto ellipsis_idx_iter = std::find(labels.begin(), labels.end(), "..."); + std::vector ellipsis_idx{std::distance(labels.begin(), ellipsis_idx_iter)}; + input_nodes[inp_iter] = unsqueeze_input(input_nodes[inp_iter], ellipsis_idx, subgraph_nodes); + } + } + } +} } // namespace ov::pass::EinsumDecomposition::EinsumDecomposition() { @@ -1116,44 +1171,7 @@ ov::pass::EinsumDecomposition::EinsumDecomposition() { auto einsum_path = compute_einsum_path(einsum_node); // fix inputs where ellipsis does not contain any dimensions - std::vector ellipsis_inputs(input_nodes.size(), false); - std::vector no_ellipsis_or_empty_inputs(input_nodes.size(), false); - static const std::string ellipsis = "..."; - for (size_t inp_iter = 0; inp_iter < input_nodes.size(); inp_iter++) { - const auto& labels = ov::op::v7::Einsum::extract_labels(input_subscripts[inp_iter]); - ellipsis_inputs[inp_iter] = (std::find(labels.begin(), labels.end(), "...") != labels.end()); - if (!ellipsis_inputs[inp_iter] || - (input_nodes[inp_iter].get_partial_shape().rank() == (labels.size() - 1))) { - no_ellipsis_or_empty_inputs[inp_iter] = true; - } - } - if (std::none_of(ellipsis_inputs.begin(), ellipsis_inputs.end(), [](bool inp) { - return inp; - })) { - if (output_subscript.find("...") != std::string::npos) { - output_subscript.erase(output_subscript.find("..."), 3); - } - } else if (std::all_of(no_ellipsis_or_empty_inputs.begin(), no_ellipsis_or_empty_inputs.end(), [](bool inp) { - return inp; - })) { - for (size_t inp_iter = 0; inp_iter < input_nodes.size(); inp_iter++) { - if (input_subscripts[inp_iter].find("...") != std::string::npos) { - input_subscripts[inp_iter].erase(input_subscripts[inp_iter].find("..."), 3); - } - } - if (output_subscript.find("...") != std::string::npos) { - output_subscript.erase(output_subscript.find("..."), 3); - } - 
} else { - for (size_t inp_iter = 0; inp_iter < input_nodes.size(); inp_iter++) { - if (ellipsis_inputs[inp_iter] && no_ellipsis_or_empty_inputs[inp_iter]) { - auto labels = ov::op::v7::Einsum::extract_labels(input_subscripts[inp_iter]); - auto ellipsis_idx_iter = std::find(labels.begin(), labels.end(), "..."); - std::vector ellipsis_idx{std::distance(labels.begin(), ellipsis_idx_iter)}; - input_nodes[inp_iter] = unsqueeze_input(input_nodes[inp_iter], ellipsis_idx, subgraph_nodes); - } - } - } + fix_inputs_with_0d_ellipsis(input_nodes, input_subscripts, output_subscript, subgraph_nodes); // contract inputs by Einsum until just one is remained for (auto const& inds_pair : einsum_path) { From f666700635e2f5836f3461855b98e1ab738dd1a0 Mon Sep 17 00:00:00 2001 From: Mateusz Mikolajczyk Date: Thu, 23 Jan 2025 18:07:58 +0100 Subject: [PATCH 14/18] Modify reshape_input_for_matmul reduced prod to match ne for separate Signed-off-by: Mateusz Mikolajczyk --- .../op_conversions/einsum_decomposition.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp index 0d322c99af813e..a882c65f0b1dd5 100644 --- a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp +++ b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp @@ -422,6 +422,7 @@ ov::Output reshape_input_for_matmul(const ov::Output& input_ // compute a product of a sub-shape for separate labels ov::OutputVector separate_parts; + auto reduce_axis_const = ov::op::v0::Constant::create(ov::element::Type_t::i64, ov::Shape{1}, {0}); if (common_sub_shape.size() > 0 && separate_sub_shape.size() == 0) { // in this case new dimension corresponding to separate labels must be added // since MatMul operation is not possible to do without separate dimensions if the @@ -432,7 +433,6 @@ ov::Output reshape_input_for_matmul(const ov::Output& input_ } else if (separate_sub_shape.size() > 0) { // in this case compute a product of separate dimension sizes since they must be // presented with just one dimension for MatMul - auto reduce_axis_const = ov::op::v0::Constant::create(ov::element::Type_t::i64, ov::Shape{1}, {0}); auto separate_shape_prod = std::make_shared(separate_sub_shape[0], reduce_axis_const, true); separate_parts.push_back(separate_shape_prod->output(0)); @@ -440,9 +440,9 @@ ov::Output reshape_input_for_matmul(const ov::Output& input_ } ov::OutputVector reduced_sub_shape_prod; auto const_0 = ov::op::v0::Constant::create(ov::element::i32, {1}, {0}); - for (auto sub_shape : reduced_sub_shape) { - auto product = std::make_shared(sub_shape, const_0, true); - subgraph_nodes.insert(subgraph_nodes.end(), {const_0, product}); + if (reduced_sub_shape.size() > 0) { + auto product = std::make_shared(reduced_sub_shape[0], const_0, true); + subgraph_nodes.insert(subgraph_nodes.end(), {reduce_axis_const, product}); reduced_sub_shape_prod.push_back(product->output(0)); } From 6347ed2d072f6b5eef71e2583debdb0a8caeede7 Mon Sep 17 00:00:00 2001 From: Mateusz Mikolajczyk Date: Thu, 23 Jan 2025 19:05:51 +0100 Subject: [PATCH 15/18] Refactor empty ellipsis handling in Einsum decomposition to improve clarity Signed-off-by: Mateusz Mikolajczyk --- .../op_conversions/einsum_decomposition.cpp | 60 ++++++++++--------- 1 file changed, 31 insertions(+), 29 deletions(-) diff --git 
From 6347ed2d072f6b5eef71e2583debdb0a8caeede7 Mon Sep 17 00:00:00 2001
From: Mateusz Mikolajczyk
Date: Thu, 23 Jan 2025 19:05:51 +0100
Subject: [PATCH 15/18] Refactor empty ellipsis handling in Einsum
 decomposition to improve clarity

Signed-off-by: Mateusz Mikolajczyk

---
 .../op_conversions/einsum_decomposition.cpp   | 60 ++++++++++---------
 1 file changed, 31 insertions(+), 29 deletions(-)

diff --git a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp
index a882c65f0b1dd5..b58b58b57ae5d2 100644
--- a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp
+++ b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp
@@ -1085,8 +1085,7 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr,
 /// Handle ellipses labels that do not represent any dimensions:
 /// 1. If there is no ellipsis in the input subscripts, remove ellipsis from the output subscript.
 /// 2. If all ellipses in the input subscripts do not represent any dimensions, remove ellipses from all subscripts.
-/// 3. If there is at least one ellipsis that does not represent any dimensions, unsqueeze the corresponding input at
-/// ellipsis dimension.
+/// 3. If there is at least one ellipsis that represents dimensions, unsqueeze the ellipses that do not represent any.
 ///
 /// \param input_nodes A vector of input nodes for the Einsum operation.
 /// \param input_subscripts A vector of input subscripts corresponding to the input nodes.
@@ -1096,40 +1095,43 @@ void fix_inputs_with_0d_ellipsis(ov::OutputVector& input_nodes,
                                  std::vector<std::string>& input_subscripts,
                                  std::string& output_subscript,
                                  ov::NodeVector& subgraph_nodes) {
-    std::vector<bool> ellipsis_inputs(input_nodes.size(), false);
-    std::vector<bool> no_ellipsis_or_empty_inputs(input_nodes.size(), false);
     static const std::string ellipsis = "...";
-    for (size_t inp_iter = 0; inp_iter < input_nodes.size(); inp_iter++) {
-        const auto& labels = ov::op::v7::Einsum::extract_labels(input_subscripts[inp_iter]);
-        ellipsis_inputs[inp_iter] = (std::find(labels.begin(), labels.end(), "...") != labels.end());
-        if (!ellipsis_inputs[inp_iter] || (input_nodes[inp_iter].get_partial_shape().rank() == (labels.size() - 1))) {
-            no_ellipsis_or_empty_inputs[inp_iter] = true;
-        }
+    bool has_ellipsis = false;
+    bool all_no_ellipsis_or_empty = true;
+
+    for (size_t i = 0; i < input_nodes.size(); ++i) {
+        const auto& labels = ov::op::v7::Einsum::extract_labels(input_subscripts[i]);
+        bool has_ellipsis_in_input = std::find(labels.begin(), labels.end(), ellipsis) != labels.end();
+        has_ellipsis |= has_ellipsis_in_input;
+        all_no_ellipsis_or_empty &=
+            !has_ellipsis_in_input || (input_nodes[i].get_partial_shape().rank().get_length() ==
+                                       static_cast<int64_t>(labels.size() - 1));
     }
-    if (std::none_of(ellipsis_inputs.begin(), ellipsis_inputs.end(), [](bool inp) {
-            return inp;
-        })) {
-        if (output_subscript.find("...") != std::string::npos) {
-            output_subscript.erase(output_subscript.find("..."), 3);
+
+    if (!has_ellipsis) {
+        if (output_subscript.find(ellipsis) != std::string::npos) {
+            output_subscript.erase(output_subscript.find(ellipsis), ellipsis.size());
        }
-    } else if (std::all_of(no_ellipsis_or_empty_inputs.begin(), no_ellipsis_or_empty_inputs.end(), [](bool inp) {
-                   return inp;
-               })) {
-        for (size_t inp_iter = 0; inp_iter < input_nodes.size(); inp_iter++) {
-            if (input_subscripts[inp_iter].find("...") != std::string::npos) {
-                input_subscripts[inp_iter].erase(input_subscripts[inp_iter].find("..."), 3);
+    } else if (all_no_ellipsis_or_empty) {
+        for (auto& subscript : input_subscripts) {
+            if (subscript.find(ellipsis) != std::string::npos) {
+                subscript.erase(subscript.find(ellipsis), ellipsis.size());
             }
         }
-        if (output_subscript.find("...") != std::string::npos) {
-            output_subscript.erase(output_subscript.find("..."), 3);
+        if (output_subscript.find(ellipsis) != std::string::npos) {
+            output_subscript.erase(output_subscript.find(ellipsis), ellipsis.size());
         }
     } else {
-        for (size_t inp_iter = 0; inp_iter < input_nodes.size(); inp_iter++) {
-            if (ellipsis_inputs[inp_iter] && no_ellipsis_or_empty_inputs[inp_iter]) {
-                auto labels = ov::op::v7::Einsum::extract_labels(input_subscripts[inp_iter]);
-                auto ellipsis_idx_iter = std::find(labels.begin(), labels.end(), "...");
-                std::vector<int64_t> ellipsis_idx{std::distance(labels.begin(), ellipsis_idx_iter)};
-                input_nodes[inp_iter] = unsqueeze_input(input_nodes[inp_iter], ellipsis_idx, subgraph_nodes);
+        for (size_t i = 0; i < input_nodes.size(); ++i) {
+            const auto& labels = ov::op::v7::Einsum::extract_labels(input_subscripts[i]);
+            if (std::find(labels.begin(), labels.end(), ellipsis) != labels.end() &&
+                input_nodes[i].get_partial_shape().rank().get_length() ==
+                    static_cast<int64_t>(labels.size() - 1)) {
+                input_nodes[i] = unsqueeze_input(
+                    input_nodes[i],
+                    {static_cast<int64_t>(
+                        std::distance(labels.begin(), std::find(labels.begin(), labels.end(), ellipsis)))},
+                    subgraph_nodes);
             }
         }
     }
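The refactored helper picks exactly one of three actions. A self-contained sketch of the same decision logic (plain C++; classify and its inputs are illustrative stand-ins; an ellipsis is "empty" when the input rank equals the label count minus one, as in the rank check above):

#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

// Illustrative stand-in for the three cases handled above: given each input's
// subscript, rank, and label count, decide what happens to the "..." labels.
static std::string classify(const std::vector<std::string>& subscripts,
                            const std::vector<std::size_t>& ranks,
                            const std::vector<std::size_t>& label_counts) {
    bool has_ellipsis = false;
    bool all_no_ellipsis_or_empty = true;
    for (std::size_t i = 0; i < subscripts.size(); ++i) {
        bool in_input = subscripts[i].find("...") != std::string::npos;
        has_ellipsis |= in_input;
        // the ellipsis covers zero dims when rank == number of labels - 1
        all_no_ellipsis_or_empty &= !in_input || (ranks[i] == label_counts[i] - 1);
    }
    if (!has_ellipsis)
        return "drop '...' from the output subscript only";
    if (all_no_ellipsis_or_empty)
        return "drop '...' from every subscript";
    return "unsqueeze inputs whose '...' covers zero dims";
}

int main() {
    // "a...b,b...->...a" with inputs of rank 2 and 1: every ellipsis is empty.
    std::cout << classify({"a...b", "b..."}, {2, 1}, {3, 2}) << '\n';
    // Same equation with ranks 4 and 1: only the second ellipsis is empty.
    std::cout << classify({"a...b", "b..."}, {4, 1}, {3, 2}) << '\n';
}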
From d380155756cabcab19084663d54f9d4fde8c23de Mon Sep 17 00:00:00 2001
From: Mateusz Mikolajczyk
Date: Fri, 24 Jan 2025 12:25:27 +0100
Subject: [PATCH 16/18] Refactor handling of 0-dimensional ellipsis in Einsum
 operations for improved clarity

Signed-off-by: Mateusz Mikolajczyk

---
 src/core/reference/src/op/einsum.cpp | 93 ++++++++++++++++------------
 1 file changed, 55 insertions(+), 38 deletions(-)

diff --git a/src/core/reference/src/op/einsum.cpp b/src/core/reference/src/op/einsum.cpp
index 4457b628f670be..ec0be118fc13e1 100644
--- a/src/core/reference/src/op/einsum.cpp
+++ b/src/core/reference/src/op/einsum.cpp
@@ -927,55 +927,72 @@ void contract_two_inputs(ov::TensorVector& inputs,
     update_operands(inputs, input_subscripts, input_ind1, input_ind2, contract_output, resultant_subscript);
 }
 
+/// \brief Adjusts input subscripts and tensors to handle 0-dimensional ellipsis in Einsum operations.
+///
+/// Handle ellipses labels that do not represent any dimensions:
+/// 1. If there is no ellipsis in the input subscripts, remove ellipsis from the output subscript.
+/// 2. If all ellipses in the input subscripts do not represent any dimensions, remove ellipses from all subscripts.
+/// 3. If there is at least one ellipsis that represents dimensions, unsqueeze the ellipses that do not represent any.
+///
+/// \param input_nodes A vector of input tensors for the Einsum operation.
+/// \param input_subscripts A vector of input subscripts corresponding to the input nodes.
+/// \param output_subscript The output subscript for the Einsum operation.
-template <typename T>
-void einsum_impl(const ov::TensorVector& inputs, ov::TensorVector& outputs, const std::string& equation) {
-    std::vector<std::string> input_subscripts;
-    std::string output_subscript;
-    ov::op::v7::Einsum::parse_equation(equation, input_subscripts, output_subscript);
-
-    // compute einsum path that is used to contract a pair of operands
-    // in more optimal order
-    size_t num_inputs = inputs.size();
-    auto einsum_path = compute_einsum_path(num_inputs);
-
-    ov::TensorVector int_inputs = inputs;
-    std::vector<bool> ellipsis_inputs(inputs.size(), false);
-    std::vector<bool> no_ellipsis_or_empty_inputs(inputs.size(), false);
+void fix_inputs_with_0d_ellipsis(ov::TensorVector& input_nodes,
+                                 std::vector<std::string>& input_subscripts,
+                                 std::string& output_subscript) {
     static const std::string ellipsis = "...";
-    for (size_t inp_iter = 0; inp_iter < inputs.size(); inp_iter++) {
-        const auto& labels = ov::op::v7::Einsum::extract_labels(input_subscripts[inp_iter]);
-        ellipsis_inputs[inp_iter] = (std::find(labels.begin(), labels.end(), "...") != labels.end());
-        if (!ellipsis_inputs[inp_iter] || (inputs[inp_iter].get_shape().size() == (labels.size() - 1))) {
-            no_ellipsis_or_empty_inputs[inp_iter] = true;
-        }
+    bool has_ellipsis = false;
+    bool all_no_ellipsis_or_empty = true;
+
+    for (size_t i = 0; i < input_nodes.size(); ++i) {
+        const auto& labels = ov::op::v7::Einsum::extract_labels(input_subscripts[i]);
+        bool has_ellipsis_in_input = std::find(labels.begin(), labels.end(), ellipsis) != labels.end();
+        has_ellipsis |= has_ellipsis_in_input;
+        all_no_ellipsis_or_empty &=
+            !has_ellipsis_in_input || (input_nodes[i].get_shape().size() == (labels.size() - 1));
     }
-    if (std::none_of(ellipsis_inputs.begin(), ellipsis_inputs.end(), [](bool inp) {
-            return inp;
-        })) {
-        if (output_subscript.find("...") != std::string::npos) {
-            output_subscript.erase(output_subscript.find("..."), 3);
+
+    if (!has_ellipsis) {
+        if (output_subscript.find(ellipsis) != std::string::npos) {
+            output_subscript.erase(output_subscript.find(ellipsis), ellipsis.size());
         }
-    } else if (std::all_of(no_ellipsis_or_empty_inputs.begin(), no_ellipsis_or_empty_inputs.end(), [](bool inp) {
-                   return inp;
-               })) {
-        for (size_t inp_iter = 0; inp_iter < inputs.size(); inp_iter++) {
-            if (input_subscripts[inp_iter].find("...") != std::string::npos) {
-                input_subscripts[inp_iter].erase(input_subscripts[inp_iter].find("..."), 3);
+    } else if (all_no_ellipsis_or_empty) {
+        for (auto& subscript : input_subscripts) {
+            if (subscript.find(ellipsis) != std::string::npos) {
+                subscript.erase(subscript.find(ellipsis), ellipsis.size());
             }
         }
-        if (output_subscript.find("...") != std::string::npos) {
-            output_subscript.erase(output_subscript.find("..."), 3);
+        if (output_subscript.find(ellipsis) != std::string::npos) {
+            output_subscript.erase(output_subscript.find(ellipsis), ellipsis.size());
         }
     } else {
-        for (size_t inp_iter = 0; inp_iter < inputs.size(); inp_iter++) {
-            if (ellipsis_inputs[inp_iter] && no_ellipsis_or_empty_inputs[inp_iter]) {
-                auto labels = ov::op::v7::Einsum::extract_labels(input_subscripts[inp_iter]);
-                auto ellipsis_idx_iter = std::find(labels.begin(), labels.end(), "...");
-                std::vector<int64_t> ellipsis_idx{std::distance(labels.begin(), ellipsis_idx_iter)};
-                int_inputs[inp_iter] = unsqueeze_input(inputs[inp_iter], ellipsis_idx);
+        for (size_t i = 0; i < input_nodes.size(); ++i) {
+            const auto& labels = ov::op::v7::Einsum::extract_labels(input_subscripts[i]);
+            if (std::find(labels.begin(), labels.end(), ellipsis) != labels.end() &&
+                input_nodes[i].get_shape().size() == (labels.size() - 1)) {
+                std::vector<int64_t> ellipsis_idx{
+                    std::distance(labels.begin(), std::find(labels.begin(), labels.end(), ellipsis))};
+                input_nodes[i] = unsqueeze_input(input_nodes[i], ellipsis_idx);
             }
         }
     }
+}
+
+template <typename T>
+void einsum_impl(const ov::TensorVector& inputs, ov::TensorVector& outputs, const std::string& equation) {
+    std::vector<std::string> input_subscripts;
+    std::string output_subscript;
+    ov::op::v7::Einsum::parse_equation(equation, input_subscripts, output_subscript);
+
+    // compute einsum path that is used to contract a pair of operands
+    // in more optimal order
+    size_t num_inputs = inputs.size();
+    auto einsum_path = compute_einsum_path(num_inputs);
+    ov::TensorVector int_inputs = inputs;
+
+    // fix inputs where ellipsis does not contain any dimensions
+    fix_inputs_with_0d_ellipsis(int_inputs, input_subscripts, output_subscript);
 
     // contract inputs by Einsum until just one is remained
     for (auto const& inds_pair : einsum_path) {
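In the mixed case, the reference path inserts a size-1 axis at the ellipsis position of each empty-ellipsis input, so every operand ends up with an explicit ellipsis dimension. A minimal sketch of that shape adjustment (plain C++; unsqueeze_at is a hypothetical helper, not the unsqueeze_input above):

#include <cstddef>
#include <iostream>
#include <vector>

// Illustrative: insert a size-1 axis where an empty "..." sits, e.g. for
// subscript "a...b" (ellipsis at label index 1) shape {2, 3} -> {2, 1, 3}.
static std::vector<std::size_t> unsqueeze_at(std::vector<std::size_t> shape, std::size_t axis) {
    shape.insert(shape.begin() + static_cast<std::ptrdiff_t>(axis), 1);
    return shape;
}

int main() {
    const auto adjusted = unsqueeze_at({2, 3}, 1);
    for (std::size_t d : adjusted)
        std::cout << d << ' ';  // prints: 2 1 3
    std::cout << '\n';
}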
From 191899102de2eae9d1560999ea5b9d43bcd45f9e Mon Sep 17 00:00:00 2001
From: Mateusz Mikolajczyk
Date: Fri, 24 Jan 2025 19:05:56 +0100
Subject: [PATCH 17/18] Refactor broadcast_merge_shapes to eliminate loop

Signed-off-by: Mateusz Mikolajczyk

---
 .../op_conversions/einsum_decomposition.cpp   | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp
index b58b58b57ae5d2..35cc48253b0e1e 100644
--- a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp
+++ b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp
@@ -322,22 +322,24 @@ ov::Output<ov::Node> unsqueeze_input(const ov::Output<ov::Node>& input_node,
 ov::OutputVector broadcast_merge_shapes(ov::OutputVector& shapes_lhs,
                                         ov::OutputVector& shapes_rhs,
                                         ov::NodeVector& subgraph_nodes) {
-    // TODO - Refactor func to remove loop and duplicated Broadcast.
-    OPENVINO_ASSERT(shapes_lhs.size() == shapes_rhs.size());
-    ov::OutputVector broadcasted_shape_nodes{shapes_lhs.size()};
-
-    for (size_t shp_i = 0; shp_i < shapes_lhs.size(); shp_i++) {
+    ov::OutputVector broadcasted_shape_nodes{};
+    // OutputVector is either empty or contains a single shape
+    if (shapes_lhs.size() == 1 && shapes_rhs.size() == 1) {
         auto const_1 = ov::op::v0::Constant::create(ov::element::Type_t::i64, ov::Shape{1}, {1});
         auto tmp_const_of_lhs_shp =
-            std::make_shared<ov::op::v3::Broadcast>(const_1, shapes_lhs[shp_i], ov::op::BroadcastType::NUMPY);
+            std::make_shared<ov::op::v3::Broadcast>(const_1, shapes_lhs[0], ov::op::BroadcastType::NUMPY);
         auto tmp_const_of_broadcasted_shp =
             std::make_shared<ov::op::v3::Broadcast>(tmp_const_of_lhs_shp,
-                                                    shapes_rhs[shp_i],
+                                                    shapes_rhs[0],
                                                     ov::op::BroadcastType::BIDIRECTIONAL);
         auto broadcasted_shape = std::make_shared<ov::op::v3::ShapeOf>(tmp_const_of_broadcasted_shp);
-        broadcasted_shape_nodes[shp_i] = broadcasted_shape;
+        broadcasted_shape_nodes.push_back(broadcasted_shape->output(0));
         subgraph_nodes.insert(subgraph_nodes.end(),
                               {const_1, tmp_const_of_lhs_shp, tmp_const_of_broadcasted_shp, broadcasted_shape});
+    } else if (shapes_lhs.size() == 0 && shapes_rhs.size() == 1) {
+        return shapes_rhs;
+    } else if (shapes_lhs.size() == 1 && shapes_rhs.size() == 0) {
+        return shapes_lhs;
     }
     return broadcasted_shape_nodes;
 }
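The two chained Broadcast ops compute a NumPy-style broadcast merge at runtime: a tensor of ones is first broadcast to the lhs shape, then bidirectionally against the rhs shape, and ShapeOf reads off the merged result. A constant-folded analogue of that merge rule (plain C++, illustrative, assumes static shapes):

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <utility>
#include <vector>

// Illustrative: NumPy-style broadcast merge of two static shapes, i.e. the
// value that ShapeOf(Broadcast(Broadcast(1, lhs), rhs, BIDIRECTIONAL)) yields.
static std::vector<std::size_t> broadcast_merge(std::vector<std::size_t> lhs,
                                                std::vector<std::size_t> rhs) {
    if (lhs.size() < rhs.size())
        std::swap(lhs, rhs);
    // Right-align the shorter shape by padding it with leading ones.
    rhs.insert(rhs.begin(), lhs.size() - rhs.size(), 1);
    std::vector<std::size_t> merged(lhs.size());
    for (std::size_t i = 0; i < lhs.size(); ++i) {
        if (lhs[i] != rhs[i] && lhs[i] != 1 && rhs[i] != 1)
            throw std::runtime_error("shapes are not broadcastable");
        merged[i] = std::max(lhs[i], rhs[i]);
    }
    return merged;
}

int main() {
    for (std::size_t d : broadcast_merge({3, 1, 7}, {11, 1}))
        std::cout << d << ' ';  // prints: 3 11 7
    std::cout << '\n';
}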
From 2eee35cd98f0bd78fd7dfc289bb254396380da28 Mon Sep 17 00:00:00 2001
From: Mateusz Mikolajczyk
Date: Thu, 30 Jan 2025 11:21:19 +0100
Subject: [PATCH 18/18] Fix shape_infer for reduced out ellipsis with dynamic
 rank inputs

Signed-off-by: Mateusz Mikolajczyk

---
 .../include/einsum_shape_inference.hpp | 13 ++--
 src/core/tests/type_prop/einsum.cpp    | 60 +++++++++++++++++++
 2 files changed, 65 insertions(+), 8 deletions(-)

diff --git a/src/core/shape_inference/include/einsum_shape_inference.hpp b/src/core/shape_inference/include/einsum_shape_inference.hpp
index 1ee471117d6872..007b220ddda3d6 100644
--- a/src/core/shape_inference/include/einsum_shape_inference.hpp
+++ b/src/core/shape_inference/include/einsum_shape_inference.hpp
@@ -24,10 +24,13 @@ std::vector<TRShape> shape_infer(const Einsum* op, const std::vector<TShape>& input_s
                           input_subscripts.size() == input_shapes.size(),
                           "Equation must contain a number of subscripts equal to a number of Einsum inputs.");
 
+    const auto output_labels = Einsum::extract_labels(output_subscript);
+    const auto has_out_ellipsis = std::any_of(output_labels.begin(), output_labels.end(), [](std::string label) {
+        return label == "...";
+    });
     // create a dictionary with dimension sizes (or ranges in case of dynamic shapes) for each label
     // and check their compatibility in case of repeating labels
     std::unordered_map<std::string, TRShape> label_to_shape;
-
     for (size_t input_idx = 0; input_idx < input_shapes.size(); ++input_idx) {
         const auto& pshape = input_shapes[input_idx];
         const auto labels = Einsum::extract_labels(input_subscripts[input_idx]);
@@ -78,16 +81,11 @@ std::vector<TRShape> shape_infer(const Einsum* op, const std::vector<TShape>& input_s
                 }
             }
         } else {
-            if (has_ellipsis) {
+            if (has_ellipsis && has_out_ellipsis) {
                 // Shape has dynamic rank and ellipsis
                 return {pshape};
             }
             for (auto const& label : labels) {
-                NODE_VALIDATION_CHECK(op,
-                                      label != "...",
-                                      "The subscript corresponding to a dynamic rank input must "
-                                      "not contain ellipsis.");
-
                 if (label_to_shape.find(label) == label_to_shape.end()) {
                     label_to_shape[label] = ov::PartialShape{Dimension::dynamic()};
                 }
@@ -96,7 +94,6 @@ std::vector<TRShape> shape_infer(const Einsum* op, const std::vector<TShape>& input_s
     }
 
     // compute the output shape
-    const auto output_labels = Einsum::extract_labels(output_subscript);
     auto output_shapes = std::vector<TRShape>(1);
     auto& output_shape = output_shapes[0];

diff --git a/src/core/tests/type_prop/einsum.cpp b/src/core/tests/type_prop/einsum.cpp
index 4772393a89f497..5c96289af647ea 100644
--- a/src/core/tests/type_prop/einsum.cpp
+++ b/src/core/tests/type_prop/einsum.cpp
@@ -478,6 +478,66 @@ TEST_F(TypePropEinsumTest, all_dynamic_rank_ellipsis) {
     EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), Each(nullptr));
 }
 
+TEST_F(TypePropEinsumTest, lhs_dynamic_rank_ellipsis) {
+    const std::string equation = "a...b,b...->...a";
+    constexpr auto et = element::i32;
+
+    auto input_shapes = PartialShapes{PartialShape::dynamic(), {3, 11, 7, 4}};
+    const auto inputs = make_inputs(et, input_shapes);
+    const auto o = make_op(inputs, equation);
+
+    EXPECT_EQ(o->get_equation(), equation);
+    EXPECT_EQ(o->get_element_type(), et);
+    EXPECT_EQ(o->get_output_size(), exp_einsum_outputs_count);
+    EXPECT_EQ(o->get_output_partial_shape(0), PartialShape::dynamic());
+    EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), Each(nullptr));
+}
+
+TEST_F(TypePropEinsumTest, rhs_dynamic_rank_ellipsis) {
+    const std::string equation = "a...b,b...->...a";
+    constexpr auto et = element::i32;
+
+    auto input_shapes = PartialShapes{{3, 11, 7, 4}, PartialShape::dynamic()};
+    const auto inputs = make_inputs(et, input_shapes);
+    const auto o = make_op(inputs, equation);
+
+    EXPECT_EQ(o->get_equation(), equation);
+    EXPECT_EQ(o->get_element_type(), et);
+    EXPECT_EQ(o->get_output_size(), exp_einsum_outputs_count);
+    EXPECT_EQ(o->get_output_partial_shape(0), PartialShape::dynamic());
+    EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), Each(nullptr));
+}
+
+TEST_F(TypePropEinsumTest, lhs_dynamic_rank_ellipsis_reduced_out_ellipsis) {
+    const std::string equation = "a...b,b...->a";
+    constexpr auto et = element::i32;
+
+    auto input_shapes = PartialShapes{PartialShape::dynamic(), {3, 11, 7, 4}};
+    const auto inputs = make_inputs(et, input_shapes);
+    const auto o = make_op(inputs, equation);
+
+    EXPECT_EQ(o->get_equation(), equation);
+    EXPECT_EQ(o->get_element_type(), et);
+    EXPECT_EQ(o->get_output_size(), exp_einsum_outputs_count);
+    EXPECT_EQ(o->get_output_partial_shape(0), PartialShape({ov::Dimension::dynamic()}));
+    EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), Each(nullptr));
+}
+
+TEST_F(TypePropEinsumTest, rhs_dynamic_rank_ellipsis_reduced_out_ellipsis) {
+    const std::string equation = "a...b,b...->a";
+    constexpr auto et = element::i32;
+
+    auto input_shapes = PartialShapes{{3, 11, 7, 4}, PartialShape::dynamic()};
+    const auto inputs = make_inputs(et, input_shapes);
+    const auto o = make_op(inputs, equation);
+
+    EXPECT_EQ(o->get_equation(), equation);
+    EXPECT_EQ(o->get_element_type(), et);
+    EXPECT_EQ(o->get_output_size(), exp_einsum_outputs_count);
+    EXPECT_EQ(o->get_output_partial_shape(0), PartialShape({3}));
+    EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), Each(nullptr));
+}
+
 TEST_F(TypePropEinsumTest, broadcasting_same_symbol_common) {
     const std::string equation = "ab,ba->b";
     constexpr auto et = element::i32;
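Summarizing the rule these tests pin down (a simplified restatement, not the shape_infer code itself): with a dynamic-rank input whose subscript contains an ellipsis, a static output rank can only be derived when the output subscript drops the ellipsis:

#include <cstddef>
#include <iostream>
#include <string>

// Illustrative: with a dynamic-rank input carrying "...", the output rank is
// only known when the output subscript has no ellipsis.
static std::string infer_output_rank(bool input_rank_static, bool out_has_ellipsis, std::size_t out_labels) {
    if (!input_rank_static && out_has_ellipsis)
        return "dynamic rank";
    return "rank " + std::to_string(out_labels);  // one dim per non-ellipsis output label
}

int main() {
    std::cout << infer_output_rank(false, true, 1) << '\n';   // "a...b,b...->...a" -> dynamic rank
    std::cout << infer_output_rank(false, false, 1) << '\n';  // "a...b,b...->a"    -> rank 1, dim {?}
}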