diff --git a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp index b58b58b57ae5d2..35cc48253b0e1e 100644 --- a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp +++ b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp @@ -322,22 +322,24 @@ ov::Output unsqueeze_input(const ov::Output& input_node, ov::OutputVector broadcast_merge_shapes(ov::OutputVector& shapes_lhs, ov::OutputVector& shapes_rhs, ov::NodeVector& subgraph_nodes) { - // TODO - Refactor func to remove loop and duplicated Broadcast. - OPENVINO_ASSERT(shapes_lhs.size() == shapes_rhs.size()); - ov::OutputVector broadcasted_shape_nodes{shapes_lhs.size()}; - - for (size_t shp_i = 0; shp_i < shapes_lhs.size(); shp_i++) { + ov::OutputVector broadcasted_shape_nodes{}; + // OutputVector is either empty or contains a single shape + if (shapes_lhs.size() == 1 && shapes_rhs.size() == 1) { auto const_1 = ov::op::v0::Constant::create(ov::element::Type_t::i64, ov::Shape{1}, {1}); auto tmp_const_of_lhs_shp = - std::make_shared(const_1, shapes_lhs[shp_i], ov::op::BroadcastType::NUMPY); + std::make_shared(const_1, shapes_lhs[0], ov::op::BroadcastType::NUMPY); auto tmp_const_of_broadcasted_shp = std::make_shared(tmp_const_of_lhs_shp, - shapes_rhs[shp_i], + shapes_rhs[0], ov::op::BroadcastType::BIDIRECTIONAL); auto broadcasted_shape = std::make_shared(tmp_const_of_broadcasted_shp); - broadcasted_shape_nodes[shp_i] = broadcasted_shape; + broadcasted_shape_nodes.push_back(broadcasted_shape->output(0)); subgraph_nodes.insert(subgraph_nodes.end(), {const_1, tmp_const_of_lhs_shp, tmp_const_of_broadcasted_shp, broadcasted_shape}); + } else if (shapes_lhs.size() == 0 && shapes_rhs.size() == 1) { + return shapes_rhs; + } else if (shapes_lhs.size() == 1 && shapes_rhs.size() == 0) { + return shapes_lhs; } return broadcasted_shape_nodes; }