From 2eee35cd98f0bd78fd7dfc289bb254396380da28 Mon Sep 17 00:00:00 2001 From: Mateusz Mikolajczyk Date: Thu, 30 Jan 2025 11:21:19 +0100 Subject: [PATCH] Fix shape_infer for reduced out ellipsis with dynamic rank inputs Signed-off-by: Mateusz Mikolajczyk --- .../include/einsum_shape_inference.hpp | 13 ++-- src/core/tests/type_prop/einsum.cpp | 60 +++++++++++++++++++ 2 files changed, 65 insertions(+), 8 deletions(-) diff --git a/src/core/shape_inference/include/einsum_shape_inference.hpp b/src/core/shape_inference/include/einsum_shape_inference.hpp index 1ee471117d6872..007b220ddda3d6 100644 --- a/src/core/shape_inference/include/einsum_shape_inference.hpp +++ b/src/core/shape_inference/include/einsum_shape_inference.hpp @@ -24,10 +24,13 @@ std::vector shape_infer(const Einsum* op, const std::vector& input_s input_subscripts.size() == input_shapes.size(), "Equation must contain a number of subscripts equal to a number of Einsum inputs."); + const auto output_labels = Einsum::extract_labels(output_subscript); + const auto has_out_ellipsis = std::any_of(output_labels.begin(), output_labels.end(), [](std::string label) { + return label == "..."; + }); // create a dictionary with dimension sizes (or ranges in case of dynamic shapes) for each label // and check their compatibility in case of repeating labels std::unordered_map label_to_shape; - for (size_t input_idx = 0; input_idx < input_shapes.size(); ++input_idx) { const auto& pshape = input_shapes[input_idx]; const auto labels = Einsum::extract_labels(input_subscripts[input_idx]); @@ -78,16 +81,11 @@ std::vector shape_infer(const Einsum* op, const std::vector& input_s } } } else { - if (has_ellipsis) { + if (has_ellipsis && has_out_ellipsis) { // Shape has dynamic rank and ellipsis return {pshape}; } for (auto const& label : labels) { - NODE_VALIDATION_CHECK(op, - label != "...", - "The subscript corresponding to a dynamic rank input must " - "not contain ellipsis."); - if (label_to_shape.find(label) == label_to_shape.end()) { label_to_shape[label] = ov::PartialShape{Dimension::dynamic()}; } @@ -96,7 +94,6 @@ std::vector shape_infer(const Einsum* op, const std::vector& input_s } // compute the output shape - const auto output_labels = Einsum::extract_labels(output_subscript); auto output_shapes = std::vector(1); auto& output_shape = output_shapes[0]; diff --git a/src/core/tests/type_prop/einsum.cpp b/src/core/tests/type_prop/einsum.cpp index 4772393a89f497..5c96289af647ea 100644 --- a/src/core/tests/type_prop/einsum.cpp +++ b/src/core/tests/type_prop/einsum.cpp @@ -478,6 +478,66 @@ TEST_F(TypePropEinsumTest, all_dynamic_rank_ellipsis) { EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), Each(nullptr)); } +TEST_F(TypePropEinsumTest, lhs_dynamic_rank_ellipsis) { + const std::string equation = "a...b,b...->...a"; + constexpr auto et = element::i32; + + auto input_shapes = PartialShapes{PartialShape::dynamic(), {3, 11, 7, 4}}; + const auto inputs = make_inputs(et, input_shapes); + const auto o = make_op(inputs, equation); + + EXPECT_EQ(o->get_equation(), equation); + EXPECT_EQ(o->get_element_type(), et); + EXPECT_EQ(o->get_output_size(), exp_einsum_outputs_count); + EXPECT_EQ(o->get_output_partial_shape(0), PartialShape::dynamic()); + EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), Each(nullptr)); +} + +TEST_F(TypePropEinsumTest, rhs_dynamic_rank_ellipsis) { + const std::string equation = "a...b,b...->...a"; + constexpr auto et = element::i32; + + auto input_shapes = PartialShapes{{3, 11, 7, 4}, PartialShape::dynamic()}; + const auto inputs = make_inputs(et, input_shapes); + const auto o = make_op(inputs, equation); + + EXPECT_EQ(o->get_equation(), equation); + EXPECT_EQ(o->get_element_type(), et); + EXPECT_EQ(o->get_output_size(), exp_einsum_outputs_count); + EXPECT_EQ(o->get_output_partial_shape(0), PartialShape::dynamic()); + EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), Each(nullptr)); +} + +TEST_F(TypePropEinsumTest, lhs_dynamic_rank_ellipsis_reduced_out_ellipsis) { + const std::string equation = "a...b,b...->a"; + constexpr auto et = element::i32; + + auto input_shapes = PartialShapes{PartialShape::dynamic(), {3, 11, 7, 4}}; + const auto inputs = make_inputs(et, input_shapes); + const auto o = make_op(inputs, equation); + + EXPECT_EQ(o->get_equation(), equation); + EXPECT_EQ(o->get_element_type(), et); + EXPECT_EQ(o->get_output_size(), exp_einsum_outputs_count); + EXPECT_EQ(o->get_output_partial_shape(0), PartialShape({ov::Dimension::dynamic()})); + EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), Each(nullptr)); +} + +TEST_F(TypePropEinsumTest, rhs_dynamic_rank_ellipsis_reduced_out_ellipsis) { + const std::string equation = "a...b,b...->a"; + constexpr auto et = element::i32; + + auto input_shapes = PartialShapes{{3, 11, 7, 4}, PartialShape::dynamic()}; + const auto inputs = make_inputs(et, input_shapes); + const auto o = make_op(inputs, equation); + + EXPECT_EQ(o->get_equation(), equation); + EXPECT_EQ(o->get_element_type(), et); + EXPECT_EQ(o->get_output_size(), exp_einsum_outputs_count); + EXPECT_EQ(o->get_output_partial_shape(0), PartialShape({3})); + EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), Each(nullptr)); +} + TEST_F(TypePropEinsumTest, broadcasting_same_symbol_common) { const std::string equation = "ab,ba->b"; constexpr auto et = element::i32;