Skip to content

Commit

Permalink
Fix shape_infer for reduced out ellipsis with dynamic rank inputs
Browse files Browse the repository at this point in the history
Signed-off-by: Mateusz Mikolajczyk <mateusz.mikolajczyk@intel.com>
  • Loading branch information
mmikolajcz committed Jan 30, 2025
1 parent 1918991 commit 2eee35c
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 8 deletions.
13 changes: 5 additions & 8 deletions src/core/shape_inference/include/einsum_shape_inference.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,13 @@ std::vector<TRShape> shape_infer(const Einsum* op, const std::vector<T>& input_s
input_subscripts.size() == input_shapes.size(),
"Equation must contain a number of subscripts equal to a number of Einsum inputs.");

const auto output_labels = Einsum::extract_labels(output_subscript);
const auto has_out_ellipsis = std::any_of(output_labels.begin(), output_labels.end(), [](std::string label) {
return label == "...";
});
// create a dictionary with dimension sizes (or ranges in case of dynamic shapes) for each label
// and check their compatibility in case of repeating labels
std::unordered_map<std::string, TRShape> label_to_shape;

for (size_t input_idx = 0; input_idx < input_shapes.size(); ++input_idx) {
const auto& pshape = input_shapes[input_idx];
const auto labels = Einsum::extract_labels(input_subscripts[input_idx]);
Expand Down Expand Up @@ -78,16 +81,11 @@ std::vector<TRShape> shape_infer(const Einsum* op, const std::vector<T>& input_s
}
}
} else {
if (has_ellipsis) {
if (has_ellipsis && has_out_ellipsis) {
// Shape has dynamic rank and ellipsis
return {pshape};
}
for (auto const& label : labels) {
NODE_VALIDATION_CHECK(op,
label != "...",
"The subscript corresponding to a dynamic rank input must "
"not contain ellipsis.");

if (label_to_shape.find(label) == label_to_shape.end()) {
label_to_shape[label] = ov::PartialShape{Dimension::dynamic()};
}
Expand All @@ -96,7 +94,6 @@ std::vector<TRShape> shape_infer(const Einsum* op, const std::vector<T>& input_s
}

// compute the output shape
const auto output_labels = Einsum::extract_labels(output_subscript);
auto output_shapes = std::vector<TRShape>(1);
auto& output_shape = output_shapes[0];

Expand Down
60 changes: 60 additions & 0 deletions src/core/tests/type_prop/einsum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,66 @@ TEST_F(TypePropEinsumTest, all_dynamic_rank_ellipsis) {
EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), Each(nullptr));
}

TEST_F(TypePropEinsumTest, lhs_dynamic_rank_ellipsis) {
const std::string equation = "a...b,b...->...a";
constexpr auto et = element::i32;

auto input_shapes = PartialShapes{PartialShape::dynamic(), {3, 11, 7, 4}};
const auto inputs = make_inputs(et, input_shapes);
const auto o = make_op(inputs, equation);

EXPECT_EQ(o->get_equation(), equation);
EXPECT_EQ(o->get_element_type(), et);
EXPECT_EQ(o->get_output_size(), exp_einsum_outputs_count);
EXPECT_EQ(o->get_output_partial_shape(0), PartialShape::dynamic());
EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), Each(nullptr));
}

TEST_F(TypePropEinsumTest, rhs_dynamic_rank_ellipsis) {
const std::string equation = "a...b,b...->...a";
constexpr auto et = element::i32;

auto input_shapes = PartialShapes{{3, 11, 7, 4}, PartialShape::dynamic()};
const auto inputs = make_inputs(et, input_shapes);
const auto o = make_op(inputs, equation);

EXPECT_EQ(o->get_equation(), equation);
EXPECT_EQ(o->get_element_type(), et);
EXPECT_EQ(o->get_output_size(), exp_einsum_outputs_count);
EXPECT_EQ(o->get_output_partial_shape(0), PartialShape::dynamic());
EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), Each(nullptr));
}

TEST_F(TypePropEinsumTest, lhs_dynamic_rank_ellipsis_reduced_out_ellipsis) {
const std::string equation = "a...b,b...->a";
constexpr auto et = element::i32;

auto input_shapes = PartialShapes{PartialShape::dynamic(), {3, 11, 7, 4}};
const auto inputs = make_inputs(et, input_shapes);
const auto o = make_op(inputs, equation);

EXPECT_EQ(o->get_equation(), equation);
EXPECT_EQ(o->get_element_type(), et);
EXPECT_EQ(o->get_output_size(), exp_einsum_outputs_count);
EXPECT_EQ(o->get_output_partial_shape(0), PartialShape({ov::Dimension::dynamic()}));
EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), Each(nullptr));
}

TEST_F(TypePropEinsumTest, rhs_dynamic_rank_ellipsis_reduced_out_ellipsis) {
const std::string equation = "a...b,b...->a";
constexpr auto et = element::i32;

auto input_shapes = PartialShapes{{3, 11, 7, 4}, PartialShape::dynamic()};
const auto inputs = make_inputs(et, input_shapes);
const auto o = make_op(inputs, equation);

EXPECT_EQ(o->get_equation(), equation);
EXPECT_EQ(o->get_element_type(), et);
EXPECT_EQ(o->get_output_size(), exp_einsum_outputs_count);
EXPECT_EQ(o->get_output_partial_shape(0), PartialShape({3}));
EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), Each(nullptr));
}

TEST_F(TypePropEinsumTest, broadcasting_same_symbol_common) {
const std::string equation = "ab,ba->b";
constexpr auto et = element::i32;
Expand Down

0 comments on commit 2eee35c

Please sign in to comment.