Extend Einsum Core and common transformation to support broadcasting, repeated labels and ellipsis #28151

Status: Open. Wants to merge 22 commits into base: master.

Commits (22):
637d510  Einsum core improvements (mmikolajcz, Dec 18, 2024)
50c98c1  Einsum decomposition broadcasting + ellipsis support (mmikolajcz, Dec 19, 2024)
d3eac20  Move broadcasting out of reshape conditional (mmikolajcz, Dec 20, 2024)
46098cc  Merge branch 'master' of https://github.com/openvinotoolkit/openvino … (mmikolajcz, Jan 2, 2025)
0ec7974  Initial support for repeated labels (mmikolajcz, Jan 7, 2025)
be601ca  Merge branch 'master' of https://github.com/openvinotoolkit/openvino … (mmikolajcz, Jan 8, 2025)
fa041ca  Remove xfail for onnx einsum test (mmikolajcz, Jan 8, 2025)
6796536  Remove Einsum xfail for torch HF tests (mmikolajcz, Jan 8, 2025)
be8400c  Update transpose reshape elimination for MatMul to handle broadcast f… (mmikolajcz, Jan 16, 2025)
d8147d5  Merge branch 'master' of https://github.com/openvinotoolkit/openvino … (mmikolajcz, Jan 17, 2025)
81b5d39  Initial Einsum update to handle ellipsis label without dimensions (mmikolajcz, Jan 17, 2025)
28b579a  Update reduce_input in einsum common decomposition (mmikolajcz, Jan 20, 2025)
33acf2e  Fix broadcasting of reduced part for reshape (mmikolajcz, Jan 20, 2025)
29b1072  Extend Einsum reference test cases (mmikolajcz, Jan 21, 2025)
9e749f9  FIx divide by 0 and handling 2+ repeated label types for einsum decom… (mmikolajcz, Jan 23, 2025)
50b6d3e  Move fix_inputs_with_0d_ellipsis to separate function (mmikolajcz, Jan 23, 2025)
f666700  Modify reshape_input_for_matmul reduced prod to match ne for separate (mmikolajcz, Jan 23, 2025)
6347ed2  Refactor empty ellipsis handling in Einsum decomposition to improve c… (mmikolajcz, Jan 23, 2025)
6f1732f  Merge branch 'master' of https://github.com/openvinotoolkit/openvino … (mmikolajcz, Jan 23, 2025)
d380155  Refactor handling of 0-dimensional ellipsis in Einsum operations for … (mmikolajcz, Jan 24, 2025)
1918991  Refactor broadcast_merge_shapes to eliminate loop (mmikolajcz, Jan 24, 2025)
2eee35c  Fix shape_infer for reduced out ellipsis with dynamic rank inputs (mmikolajcz, Jan 30, 2025)
@@ -9,10 +9,12 @@

#include "itt.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/pass/pattern/op/or.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"

namespace {
@@ -124,9 +126,16 @@ ov::pass::TransposeReshapeEliminationForMatmul::TransposeReshapeEliminationForMa
auto transpose_before_pattern =
ov::pass::pattern::wrap_type<ov::op::v1::Transpose>({input_2_pattern, const_transpose_before_pattern});

auto const_optional_broadcast_before_pattern = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
auto optional_broadcast_before_pattern = ov::pass::pattern::wrap_type<ov::op::v3::Broadcast>(
{transpose_before_pattern, const_optional_broadcast_before_pattern});

auto transpose_or_transpose_broadcast = std::make_shared<ov::pass::pattern::op::Or>(
OutputVector{transpose_before_pattern, optional_broadcast_before_pattern});

auto const_reshape_before_pattern = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
auto reshape_before_pattern =
ov::pass::pattern::wrap_type<ov::op::v1::Reshape>({transpose_before_pattern, const_reshape_before_pattern});
auto reshape_before_pattern = ov::pass::pattern::wrap_type<ov::op::v1::Reshape>(
{transpose_or_transpose_broadcast, const_reshape_before_pattern});

auto matmul_pattern = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({input_1_pattern, reshape_before_pattern});

@@ -181,8 +190,37 @@ ov::pass::TransposeReshapeEliminationForMatmul::TransposeReshapeEliminationForMa
// transposes
if (!check_transposes(transpose_before_order, transpose_after_order, transposed_b))
return false;

const auto new_matmul = std::make_shared<ov::op::v0::MatMul>(input_1, input_2, transposed_a, false);
auto matmul_2_input = input_2;
// for einsum decomposition, check if a broadcast exists and, if so, reorder its target shape based on the transpose order
if (pattern_value_map.count(optional_broadcast_before_pattern)) {
auto broadcast_before = ov::as_type_ptr<ov::op::v3::Broadcast>(
pattern_value_map.at(optional_broadcast_before_pattern).get_node_shared_ptr());
if (!broadcast_before) {
return false;
}
auto broadcast_before_constant =
ov::as_type_ptr<ov::op::v0::Constant>(broadcast_before->get_input_node_shared_ptr(1));
if (!broadcast_before_constant) {
return false;
}
auto broadcast_shape_after_transpose = broadcast_before_constant->cast_vector<int64_t>();
if (broadcast_shape_after_transpose.size() != transpose_before_order.size()) {
return false;
}
std::vector<int64_t> broadcast_shape_no_transpose;
broadcast_shape_no_transpose.reserve(broadcast_shape_after_transpose.size());
for (auto idx : transpose_before_order) {
broadcast_shape_no_transpose.push_back(broadcast_shape_after_transpose[idx]);
}
auto broadcast_shape_no_transpose_constant =
ov::op::v0::Constant::create(element::i64,
broadcast_before_constant->get_shape(),
broadcast_shape_no_transpose);
matmul_2_input = broadcast_before->clone_with_new_inputs({input_2, broadcast_shape_no_transpose_constant});
copy_runtime_info(broadcast_before, matmul_2_input.get_node_shared_ptr());
}

const auto new_matmul = std::make_shared<ov::op::v0::MatMul>(input_1, matmul_2_input, transposed_a, false);
new_matmul->set_friendly_name(transpose_after->get_friendly_name());
copy_runtime_info({transpose_before, reshape_before, matmul, reshape_after, transpose_after}, new_matmul);
replace_node(transpose_after, new_matmul);
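
For readers following the pattern-matcher change, the shape-reordering step above can be isolated into a small standalone sketch. It mirrors the loop in the hunk with a hypothetical transpose order and target shape; plain C++, not part of the diff and not the OpenVINO API:

#include <cstdint>
#include <iostream>
#include <vector>

// Given a Broadcast target shape expressed in the transposed layout and the
// transpose order, produce the target shape for a Broadcast applied directly
// to the untransposed input (same loop as in the hunk above).
std::vector<int64_t> reorder_broadcast_shape(const std::vector<int64_t>& shape_after_transpose,
                                             const std::vector<size_t>& transpose_order) {
    std::vector<int64_t> shape_no_transpose;
    shape_no_transpose.reserve(shape_after_transpose.size());
    for (auto idx : transpose_order) {
        shape_no_transpose.push_back(shape_after_transpose[idx]);
    }
    return shape_no_transpose;
}

int main() {
    // Hypothetical example: order {0, 2, 1} swaps the two innermost axes.
    const std::vector<size_t> order{0, 2, 1};
    const std::vector<int64_t> target_after_transpose{2, 5, 3};
    for (auto d : reorder_broadcast_shape(target_after_transpose, order)) {
        std::cout << d << ' ';  // prints: 2 3 5
    }
    std::cout << '\n';
}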

Large diffs are not rendered by default.

@@ -138,11 +138,21 @@ TEST_F(TransformationTestsF, TransposeReshapeEliminationForMatMul_Einsum) {
{
auto data_1 = std::make_shared<ov::op::v0::Parameter>(element::f32, data_shape_1);
auto data_2 = std::make_shared<ov::op::v0::Parameter>(element::f32, data_shape_2);
auto broadcast_shape_constant_1 =
std::make_shared<ov::op::v0::Constant>(element::i64, Shape{data_shape_1.size()}, data_shape_1);
auto broadcast_shape_constant_2 =
std::make_shared<ov::op::v0::Constant>(element::i64, Shape{data_shape_2.size()}, data_shape_2);
auto broadcast_1 = std::make_shared<ov::op::v3::Broadcast>(data_1,
broadcast_shape_constant_1,
ov::op::BroadcastType::BIDIRECTIONAL);
auto broadcast_2 = std::make_shared<ov::op::v3::Broadcast>(data_2,
broadcast_shape_constant_2,
ov::op::BroadcastType::BIDIRECTIONAL);
// in some cases Reshape may be the first input of MatMul
auto shape_constant =
std::make_shared<ov::op::v0::Constant>(element::i64, Shape{data_shape_1.size()}, data_shape_1);
auto reshape = std::make_shared<ov::op::v1::Reshape>(data_1, shape_constant, false);
auto matmul = std::make_shared<ov::op::v0::MatMul>(reshape, data_2, false, false);
auto reshape = std::make_shared<ov::op::v1::Reshape>(broadcast_1, shape_constant, false);
auto matmul = std::make_shared<ov::op::v0::MatMul>(reshape, broadcast_2, false, false);
model_ref = std::make_shared<Model>(NodeVector{matmul}, ParameterVector{data_1, data_2});
}
}
75 changes: 66 additions & 9 deletions src/core/reference/src/op/einsum.cpp
@@ -425,7 +425,7 @@ void broadcast_input(ov::TensorVector& inputs,
OPENVINO_ASSERT(input_ind < inputs.size());
ov::Tensor& input = inputs[input_ind];
const Shape old_shape = input.get_shape();
Shape new_shape;
PartialShape new_shape;
new_shape.insert(new_shape.end(), new_common_shape.begin(), new_common_shape.end());
if (is_separate_first) {
new_shape.insert(new_shape.end(), separate_shape.begin(), separate_shape.end());
@@ -435,15 +435,15 @@
new_shape.insert(new_shape.end(), separate_shape.begin(), separate_shape.end());
}

if (input.get_shape() == new_shape) {
if (input.get_shape() == new_shape.to_shape()) {
return;
}
OPENVINO_ASSERT(old_shape.size() <= new_shape.size());

auto output = ov::Tensor(input.get_element_type(), new_shape);

std::vector<size_t> broadcast_axes(old_shape.size());
std::iota(broadcast_axes.begin(), broadcast_axes.end(), new_shape.size() - old_shape.size());
OPENVINO_ASSERT(PartialShape::broadcast_merge_into(new_shape, old_shape, ov::op::AutoBroadcastType::NUMPY));
auto output = ov::Tensor(input.get_element_type(), new_shape.to_shape());

reference::broadcast(reinterpret_cast<const char*>(input.data<T>()),
reinterpret_cast<char*>(output.data<T>()),
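
The rewrite above replaces the manually computed broadcast axes with PartialShape::broadcast_merge_into, which applies the standard NumPy broadcasting rule. A minimal std-only sketch of that rule for static shapes follows; it is not part of the diff and not the OpenVINO implementation:

#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// NumPy broadcasting: align shapes right-to-left; at each position the sizes
// must match or one of them must be 1 (which stretches to the other size).
std::vector<int64_t> broadcast_merge(std::vector<int64_t> a, const std::vector<int64_t>& b) {
    if (a.size() < b.size())
        a.insert(a.begin(), b.size() - a.size(), 1);  // pad the shorter shape with leading 1s
    for (size_t i = 0; i < b.size(); ++i) {
        auto& da = a[a.size() - 1 - i];
        const int64_t db = b[b.size() - 1 - i];
        if (da == db || db == 1)
            continue;
        if (da == 1)
            da = db;
        else
            throw std::runtime_error("shapes are not broadcastable");
    }
    return a;  // broadcast_merge({3, 1, 5}, {4, 5}) == {3, 4, 5}
}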
@@ -853,34 +853,37 @@ void contract_two_inputs(ov::TensorVector& inputs,
PartialShape common_sub_shape1 = compute_sub_shape(input_shape1, common_dims_begin, common_dims_end);
PartialShape common_sub_shape2 = compute_sub_shape(input_shape2, common_dims_begin2, common_dims_end2);

Shape reduced_sub_shape_prod = compute_sub_shape(input_shape1, reduced_dims_begin, reduced_dims_end, true);
Shape reduced_sub_shape = compute_sub_shape(input_shape1, reduced_dims_begin, reduced_dims_end);
PartialShape reduced_sub_shape = compute_sub_shape(input_shape1, reduced_dims_begin, reduced_dims_end);
Shape reduced_sub_shape2 = compute_sub_shape(input_shape2, reduced_dims_begin2, reduced_dims_end2);
Shape separate1_sub_shape = compute_sub_shape(input_shape1, separate1_dims_begin, separate1_dims_end);
Shape separate2_sub_shape = compute_sub_shape(input_shape2, separate2_dims_begin, separate2_dims_end);

// broadcast both inputs so that their common and reduced sub-shapes are
// merged NumPy-style, which is needed in case of an ellipsis among the common labels
PartialShape::broadcast_merge_into(common_sub_shape1, common_sub_shape2, op::AutoBroadcastType::NUMPY);
PartialShape::broadcast_merge_into(reduced_sub_shape, reduced_sub_shape2, op::AutoBroadcastType::NUMPY);
Shape reduced_sub_shape_prod = {shape_size(reduced_sub_shape.get_shape())};
Shape common_sub_shape = common_sub_shape1.get_shape();
broadcast_input<T>(inputs,
input_ind1,
common_sub_shape,
separate1_sub_shape,
reduced_sub_shape,
reduced_sub_shape.get_shape(),
is_separate_first1);
broadcast_input<T>(inputs,
input_ind2,
common_sub_shape,
separate2_sub_shape,
reduced_sub_shape,
reduced_sub_shape.get_shape(),
is_separate_first2);

ov::Tensor matmul_operand1 = reshape_input_for_matmul<T>(input1,
common_sub_shape,
separate1_sub_shape,
reduced_sub_shape_prod,
is_separate_first1);

ov::Tensor matmul_operand2 = reshape_input_for_matmul<T>(input2,
common_sub_shape,
separate2_sub_shape,
@@ -924,6 +927,58 @@ void contract_two_inputs(ov::TensorVector& inputs,
update_operands(inputs, input_subscripts, input_ind1, input_ind2, contract_output, resultant_subscript);
}
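
In the hunk above, once the reduced sub-shapes of both operands have been broadcast to a common shape, reduced_sub_shape_prod collapses them into a single contraction axis so that the pairwise contraction becomes one MatMul. A sketch of that collapse, with a hypothetical helper name, not part of the diff:

#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// Collapse the broadcast reduced sub-shape into one dimension whose size is
// the product of its extents; the inputs are then reshaped to
// [common..., separate, prod] and [common..., prod, separate] for the MatMul.
std::size_t reduced_prod(const std::vector<std::size_t>& reduced_sub_shape) {
    return std::accumulate(reduced_sub_shape.begin(), reduced_sub_shape.end(),
                           std::size_t{1}, std::multiplies<std::size_t>());
}
// reduced_prod({3, 4}) == 12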

/// \brief Adjusts input subscripts and nodes to handle 0-dimensional ellipsis in Einsum operations.
///
/// Handles ellipsis labels that do not represent any dimensions:
/// 1. If there is no ellipsis in the input subscripts, remove the ellipsis from the output subscript.
/// 2. If all ellipses in the input subscripts do not represent any dimensions, remove the ellipses from all subscripts.
/// 3. If at least one ellipsis represents dimensions, unsqueeze the ellipses that do not represent any.
///
/// \param input_nodes A vector of input tensors for the Einsum operation.
/// \param input_subscripts A vector of input subscripts corresponding to the input nodes.
/// \param output_subscript The output subscript for the Einsum operation.
template <typename T>
void fix_inputs_with_0d_ellipsis(ov::TensorVector& input_nodes,
std::vector<std::string>& input_subscripts,
std::string& output_subscript) {
static const std::string ellipsis = "...";
bool has_ellipsis = false;
bool all_no_ellipsis_or_empty = true;

for (size_t i = 0; i < input_nodes.size(); ++i) {
const auto& labels = ov::op::v7::Einsum::extract_labels(input_subscripts[i]);
bool has_ellipsis_in_input = std::find(labels.begin(), labels.end(), ellipsis) != labels.end();
has_ellipsis |= has_ellipsis_in_input;
all_no_ellipsis_or_empty &=
!has_ellipsis_in_input || (input_nodes[i].get_shape().size() == (labels.size() - 1));
}

if (!has_ellipsis) {
if (output_subscript.find(ellipsis) != std::string::npos) {
output_subscript.erase(output_subscript.find(ellipsis), ellipsis.size());
}
} else if (all_no_ellipsis_or_empty) {
for (auto& subscript : input_subscripts) {
if (subscript.find(ellipsis) != std::string::npos) {
subscript.erase(subscript.find(ellipsis), ellipsis.size());
}
}
if (output_subscript.find(ellipsis) != std::string::npos) {
output_subscript.erase(output_subscript.find(ellipsis), ellipsis.size());
}
} else {
for (size_t i = 0; i < input_nodes.size(); ++i) {
const auto& labels = ov::op::v7::Einsum::extract_labels(input_subscripts[i]);
if (std::find(labels.begin(), labels.end(), ellipsis) != labels.end() &&
input_nodes[i].get_shape().size() == (labels.size() - 1)) {
std::vector<int64_t> ellipsis_idx{
std::distance(labels.begin(), std::find(labels.begin(), labels.end(), ellipsis))};
input_nodes[i] = unsqueeze_input<T>(input_nodes[i], ellipsis_idx);
}
}
}
}
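
To make case 2 above concrete: when every input ellipsis covers zero dimensions, the ellipsis is simply erased from all subscripts before contraction. A standalone illustration with a hypothetical equation, not the OpenVINO code path:

#include <iostream>
#include <string>
#include <vector>

int main() {
    static const std::string ellipsis = "...";
    // Inputs "a...b", "...b" and output "a...", where each "..." covers
    // zero dimensions, reduce to the plain equation "ab,b->a".
    std::vector<std::string> subscripts{"a...b", "...b", "a..."};
    for (auto& s : subscripts) {
        const auto pos = s.find(ellipsis);
        if (pos != std::string::npos)
            s.erase(pos, ellipsis.size());
        std::cout << s << '\n';  // prints: ab, then b, then a
    }
}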

template <typename T>
void einsum_impl(const ov::TensorVector& inputs, ov::TensorVector& outputs, const std::string& equation) {
std::vector<std::string> input_subscripts;
@@ -934,9 +989,11 @@ void einsum_impl(const ov::TensorVector& inputs, ov::TensorVector& outputs, cons
// in more optimal order
size_t num_inputs = inputs.size();
auto einsum_path = compute_einsum_path(num_inputs);

ov::TensorVector int_inputs = inputs;

// fix inputs where ellipsis does not contain any dimensions
fix_inputs_with_0d_ellipsis<T>(int_inputs, input_subscripts, output_subscript);

// contract inputs pairwise until just one remains
for (auto const& inds_pair : einsum_path) {
contract_two_inputs<T>(int_inputs, input_subscripts, output_subscript, inds_pair.first, inds_pair.second);
36 changes: 23 additions & 13 deletions src/core/shape_inference/include/einsum_shape_inference.hpp
@@ -24,22 +24,29 @@ std::vector<TRShape> shape_infer(const Einsum* op, const std::vector<T>& input_s
input_subscripts.size() == input_shapes.size(),
"Equation must contain a number of subscripts equal to a number of Einsum inputs.");

const auto output_labels = Einsum::extract_labels(output_subscript);
const auto has_out_ellipsis = std::any_of(output_labels.begin(), output_labels.end(), [](std::string label) {
return label == "...";
});
// create a dictionary with dimension sizes (or ranges in case of dynamic shapes) for each label
// and check their compatibility in case of repeating labels
std::unordered_map<std::string, TRShape> label_to_shape;

for (size_t input_idx = 0; input_idx < input_shapes.size(); ++input_idx) {
const auto& pshape = input_shapes[input_idx];
const auto labels = Einsum::extract_labels(input_subscripts[input_idx]);
const auto has_ellipsis = std::any_of(labels.begin(), labels.end(), [](std::string label) {
return label == "...";
});

if (pshape.rank().is_static()) {
size_t input_rank = pshape.size();
// check that a rank is greater or equal to a number of labels
// these numbers are always equal if there is no ellipsis in the subscript
NODE_VALIDATION_CHECK(op,
input_rank >= labels.size(),
"Input rank must be greater or equal to a number of labels in the "
"corresponding input subscript.");
NODE_VALIDATION_CHECK(
op,
(input_rank >= (labels.size() - 1) && has_ellipsis) || (input_rank == labels.size() && !has_ellipsis),
"Input rank must be greater or equal to a number of labels in the "
"corresponding input subscript.");

for (size_t label_ind = 0, dim_ind = 0; label_ind < labels.size() && dim_ind < input_rank; ++label_ind) {
auto const& label = labels[label_ind];
@@ -64,21 +71,21 @@ std::vector<TRShape> shape_infer(const Einsum* op, const std::vector<T>& input_s
label_to_shape[label] = TRShape{pshape[dim_ind]};
} else {
NODE_VALIDATION_CHECK(op,
label_to_shape[label].compatible(TRShape{pshape[label_ind]}),
TRShape::broadcast_merge_into(label_to_shape[label],
TRShape{pshape[dim_ind]},
op::AutoBroadcastType::NUMPY),
"Different input dimensions indicated by the same labels for Einsum "
"must be compatible.");
OPENVINO_ASSERT(TRShape::merge_into(label_to_shape[label], TRShape{pshape[dim_ind]}));
}
++dim_ind;
}
}
} else {
if (has_ellipsis && has_out_ellipsis) {
// Shape has dynamic rank and ellipsis
return {pshape};
}
for (auto const& label : labels) {
NODE_VALIDATION_CHECK(op,
label != "...",
"The subscript corresponding to a dynamic rank input must "
"not contain ellipsis.");

if (label_to_shape.find(label) == label_to_shape.end()) {
label_to_shape[label] = ov::PartialShape{Dimension::dynamic()};
}
Expand All @@ -87,11 +94,14 @@ std::vector<TRShape> shape_infer(const Einsum* op, const std::vector<T>& input_s
}

// compute the output shape
const auto output_labels = Einsum::extract_labels(output_subscript);
auto output_shapes = std::vector<TRShape>(1);
auto& output_shape = output_shapes[0];

for (auto const& output_label : output_labels) {
if (output_label == "..." && label_to_shape.find(output_label) == label_to_shape.end()) {
// Output labels may contain ellipsis that does not cover any dimensions.
continue;
}
NODE_VALIDATION_CHECK(op,
label_to_shape.find(output_label) != label_to_shape.end(),
"Label in output subscript of Einsum equation must enter at least "
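
The relaxed repeated-label check above means two occurrences of the same label are now merged NumPy-style instead of having to be exactly compatible. A minimal sketch of the static-dimension case (std-only, not ov::Dimension):

#include <cstdint>
#include <iostream>
#include <stdexcept>

// Merge two extents bound to the same Einsum label under NumPy broadcasting.
int64_t merge_label_dim(int64_t a, int64_t b) {
    if (a == b)
        return a;
    if (a == 1)
        return b;
    if (b == 1)
        return a;
    throw std::runtime_error("dimensions indicated by the same label are not broadcastable");
}

int main() {
    // Subscript "aa" with input shape {1, 5}: label 'a' now resolves to 5.
    // Under the previous exact-merge rule this pair was rejected.
    std::cout << merge_label_dim(1, 5) << '\n';  // prints: 5
}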
5 changes: 0 additions & 5 deletions src/core/src/op/einsum.cpp
@@ -139,11 +139,6 @@ void op::v7::Einsum::parse_equation(const std::string& equation,
OPENVINO_ASSERT(is_subscript_correct(output_subscript, output_is_ellipsis_met),
"Output subscript of Einsum equation must consist of either only "
"alphabetic letters or alphabetic letters with one ellipsis.");

// if the ellipsis is met in input subscripts, one ellipsis must be in the output subscript
OPENVINO_ASSERT(is_ellipsis_met == output_is_ellipsis_met,
"Output subscript of Einsum equation must contain one ellipsis if "
"ellipsis is met in any input subscript.");
}
}
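
With this assertion removed, an equation may use the ellipsis in an input subscript while omitting it from the output, which sums out the ellipsis-covered dimensions. A plain-C++ numeric illustration of the now-accepted equation "a...->a" (illustrative shapes, not the Einsum reference implementation):

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
    // Input of shape {2, 3, 4} viewed as {a, 12}: 'a' = 2, "..." covers {3, 4}.
    const std::size_t a = 2, rest = 3 * 4;
    std::vector<float> input(a * rest, 1.0f);  // all ones
    std::vector<float> out(a, 0.0f);
    for (std::size_t i = 0; i < a; ++i)
        for (std::size_t j = 0; j < rest; ++j)
            out[i] += input[i * rest + j];  // reduce the ellipsis block
    std::cout << out[0] << ' ' << out[1] << '\n';  // prints: 12 12
}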
