Skip to content

Commit

Permalink
Refactor handling of 0-dimensional ellipsis in Einsum operations for …
Browse files Browse the repository at this point in the history
…improved clarity

Signed-off-by: Mateusz Mikolajczyk <mateusz.mikolajczyk@intel.com>
  • Loading branch information
mmikolajcz committed Jan 24, 2025
1 parent 6f1732f commit d380155
Showing 1 changed file with 55 additions and 38 deletions.
93 changes: 55 additions & 38 deletions src/core/reference/src/op/einsum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -927,55 +927,72 @@ void contract_two_inputs(ov::TensorVector& inputs,
update_operands(inputs, input_subscripts, input_ind1, input_ind2, contract_output, resultant_subscript);
}

/// \brief Adjusts input subscripts and nodes to handle 0-dimensional ellipsis in Einsum operations.
///
/// Handle ellipses labels that do not represent any dimensions:
/// 1. If there is no ellipsis in the input subscripts, remove ellipsis from the output subscript.
/// 2. If all ellipses in the input subscripts do not represent any dimensions, remove ellipses from all subscripts.
/// 3. If at least one ellipsis represents dimensions, unsqueeze (at the ellipsis position) every input whose
///    ellipsis does not represent any dimensions.
///
/// \param input_nodes A vector of input tensors for the Einsum operation.
/// \param input_subscripts A vector of input subscripts corresponding to the input nodes.
/// \param output_subscript The output subscript for the Einsum operation.
template <typename T>
void einsum_impl(const ov::TensorVector& inputs, ov::TensorVector& outputs, const std::string& equation) {
std::vector<std::string> input_subscripts;
std::string output_subscript;
ov::op::v7::Einsum::parse_equation(equation, input_subscripts, output_subscript);

// compute einsum path that is used to contract a pair of operands
// in more optimal order
size_t num_inputs = inputs.size();
auto einsum_path = compute_einsum_path(num_inputs);

ov::TensorVector int_inputs = inputs;
std::vector<bool> ellipsis_inputs(inputs.size(), false);
std::vector<bool> no_ellipsis_or_empty_inputs(inputs.size(), false);
void fix_inputs_with_0d_ellipsis(ov::TensorVector& input_nodes,
std::vector<std::string>& input_subscripts,
std::string& output_subscript) {
static const std::string ellipsis = "...";
for (size_t inp_iter = 0; inp_iter < inputs.size(); inp_iter++) {
const auto& labels = ov::op::v7::Einsum::extract_labels(input_subscripts[inp_iter]);
ellipsis_inputs[inp_iter] = (std::find(labels.begin(), labels.end(), "...") != labels.end());
if (!ellipsis_inputs[inp_iter] || (inputs[inp_iter].get_shape().size() == (labels.size() - 1))) {
no_ellipsis_or_empty_inputs[inp_iter] = true;
}
bool has_ellipsis = false;
bool all_no_ellipsis_or_empty = true;

for (size_t i = 0; i < input_nodes.size(); ++i) {
const auto& labels = ov::op::v7::Einsum::extract_labels(input_subscripts[i]);
bool has_ellipsis_in_input = std::find(labels.begin(), labels.end(), ellipsis) != labels.end();
has_ellipsis |= has_ellipsis_in_input;
all_no_ellipsis_or_empty &=
!has_ellipsis_in_input || (input_nodes[i].get_shape().size() == (labels.size() - 1));
}
if (std::none_of(ellipsis_inputs.begin(), ellipsis_inputs.end(), [](bool inp) {
return inp;
})) {
if (output_subscript.find("...") != std::string::npos) {
output_subscript.erase(output_subscript.find("..."), 3);

if (!has_ellipsis) {
if (output_subscript.find(ellipsis) != std::string::npos) {
output_subscript.erase(output_subscript.find(ellipsis), ellipsis.size());
}
} else if (std::all_of(no_ellipsis_or_empty_inputs.begin(), no_ellipsis_or_empty_inputs.end(), [](bool inp) {
return inp;
})) {
for (size_t inp_iter = 0; inp_iter < inputs.size(); inp_iter++) {
if (input_subscripts[inp_iter].find("...") != std::string::npos) {
input_subscripts[inp_iter].erase(input_subscripts[inp_iter].find("..."), 3);
} else if (all_no_ellipsis_or_empty) {
for (auto& subscript : input_subscripts) {
if (subscript.find(ellipsis) != std::string::npos) {
subscript.erase(subscript.find(ellipsis), ellipsis.size());
}
}
if (output_subscript.find("...") != std::string::npos) {
output_subscript.erase(output_subscript.find("..."), 3);
if (output_subscript.find(ellipsis) != std::string::npos) {
output_subscript.erase(output_subscript.find(ellipsis), ellipsis.size());
}
} else {
for (size_t inp_iter = 0; inp_iter < inputs.size(); inp_iter++) {
if (ellipsis_inputs[inp_iter] && no_ellipsis_or_empty_inputs[inp_iter]) {
auto labels = ov::op::v7::Einsum::extract_labels(input_subscripts[inp_iter]);
auto ellipsis_idx_iter = std::find(labels.begin(), labels.end(), "...");
std::vector<int64_t> ellipsis_idx{std::distance(labels.begin(), ellipsis_idx_iter)};
int_inputs[inp_iter] = unsqueeze_input<T>(inputs[inp_iter], ellipsis_idx);
for (size_t i = 0; i < input_nodes.size(); ++i) {
const auto& labels = ov::op::v7::Einsum::extract_labels(input_subscripts[i]);
if (std::find(labels.begin(), labels.end(), ellipsis) != labels.end() &&
input_nodes[i].get_shape().size() == (labels.size() - 1)) {
std::vector<int64_t> ellipsis_idx{
std::distance(labels.begin(), std::find(labels.begin(), labels.end(), ellipsis))};
input_nodes[i] = unsqueeze_input<T>(input_nodes[i], ellipsis_idx);
}
}
}
}

template <typename T>
void einsum_impl(const ov::TensorVector& inputs, ov::TensorVector& outputs, const std::string& equation) {
std::vector<std::string> input_subscripts;
std::string output_subscript;
ov::op::v7::Einsum::parse_equation(equation, input_subscripts, output_subscript);

// compute einsum path that is used to contract a pair of operands
// in more optimal order
size_t num_inputs = inputs.size();
auto einsum_path = compute_einsum_path(num_inputs);
ov::TensorVector int_inputs = inputs;

// fix inputs where ellipsis does not contain any dimensions
fix_inputs_with_0d_ellipsis<T>(int_inputs, input_subscripts, output_subscript);

// contract inputs by Einsum until only one remains
for (auto const& inds_pair : einsum_path) {
Expand Down

0 comments on commit d380155

Please sign in to comment.