Skip to content

Commit

Permalink
Refactor empty ellipsis handling in Einsum decomposition to improve c…
Browse files Browse the repository at this point in the history
…larity

Signed-off-by: Mateusz Mikolajczyk <mateusz.mikolajczyk@intel.com>
  • Loading branch information
mmikolajcz committed Jan 23, 2025
1 parent f666700 commit 6347ed2
Showing 1 changed file with 31 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1085,8 +1085,7 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr,
/// Handle ellipses labels that do not represent any dimensions:
/// 1. If there is no ellipsis in the input subscripts, remove ellipsis from the output subscript.
/// 2. If all ellipses in the input subscripts do not represent any dimensions, remove ellipses from all subscripts.
/// 3. If there is at least one ellipsis that does not represent any dimensions, unsqueeze the corresponding input at
/// ellipsis dimension.
/// 3. If there is at least one ellipsis that represents dimension, unsqueeze ellipses that do not represent any,
///
/// \param input_nodes A vector of input nodes for the Einsum operation.
/// \param input_subscripts A vector of input subscripts corresponding to the input nodes.
Expand All @@ -1096,40 +1095,43 @@ void fix_inputs_with_0d_ellipsis(ov::OutputVector& input_nodes,
std::vector<std::string>& input_subscripts,
std::string& output_subscript,
ov::NodeVector& subgraph_nodes) {
std::vector<bool> ellipsis_inputs(input_nodes.size(), false);
std::vector<bool> no_ellipsis_or_empty_inputs(input_nodes.size(), false);
static const std::string ellipsis = "...";
for (size_t inp_iter = 0; inp_iter < input_nodes.size(); inp_iter++) {
const auto& labels = ov::op::v7::Einsum::extract_labels(input_subscripts[inp_iter]);
ellipsis_inputs[inp_iter] = (std::find(labels.begin(), labels.end(), "...") != labels.end());
if (!ellipsis_inputs[inp_iter] || (input_nodes[inp_iter].get_partial_shape().rank() == (labels.size() - 1))) {
no_ellipsis_or_empty_inputs[inp_iter] = true;
}
bool has_ellipsis = false;
bool all_no_ellipsis_or_empty = true;

for (size_t i = 0; i < input_nodes.size(); ++i) {
const auto& labels = ov::op::v7::Einsum::extract_labels(input_subscripts[i]);
bool has_ellipsis_in_input = std::find(labels.begin(), labels.end(), ellipsis) != labels.end();
has_ellipsis |= has_ellipsis_in_input;
all_no_ellipsis_or_empty &=
!has_ellipsis_in_input || (input_nodes[i].get_partial_shape().rank().get_length() ==
static_cast<ov::Dimension::value_type>(labels.size() - 1));
}
if (std::none_of(ellipsis_inputs.begin(), ellipsis_inputs.end(), [](bool inp) {
return inp;
})) {
if (output_subscript.find("...") != std::string::npos) {
output_subscript.erase(output_subscript.find("..."), 3);

if (!has_ellipsis) {
if (output_subscript.find(ellipsis) != std::string::npos) {
output_subscript.erase(output_subscript.find(ellipsis), ellipsis.size());
}
} else if (std::all_of(no_ellipsis_or_empty_inputs.begin(), no_ellipsis_or_empty_inputs.end(), [](bool inp) {
return inp;
})) {
for (size_t inp_iter = 0; inp_iter < input_nodes.size(); inp_iter++) {
if (input_subscripts[inp_iter].find("...") != std::string::npos) {
input_subscripts[inp_iter].erase(input_subscripts[inp_iter].find("..."), 3);
} else if (all_no_ellipsis_or_empty) {
for (auto& subscript : input_subscripts) {
if (subscript.find(ellipsis) != std::string::npos) {
subscript.erase(subscript.find(ellipsis), ellipsis.size());
}
}
if (output_subscript.find("...") != std::string::npos) {
output_subscript.erase(output_subscript.find("..."), 3);
if (output_subscript.find(ellipsis) != std::string::npos) {
output_subscript.erase(output_subscript.find(ellipsis), ellipsis.size());
}
} else {
for (size_t inp_iter = 0; inp_iter < input_nodes.size(); inp_iter++) {
if (ellipsis_inputs[inp_iter] && no_ellipsis_or_empty_inputs[inp_iter]) {
auto labels = ov::op::v7::Einsum::extract_labels(input_subscripts[inp_iter]);
auto ellipsis_idx_iter = std::find(labels.begin(), labels.end(), "...");
std::vector<int64_t> ellipsis_idx{std::distance(labels.begin(), ellipsis_idx_iter)};
input_nodes[inp_iter] = unsqueeze_input(input_nodes[inp_iter], ellipsis_idx, subgraph_nodes);
for (size_t i = 0; i < input_nodes.size(); ++i) {
const auto& labels = ov::op::v7::Einsum::extract_labels(input_subscripts[i]);
if (std::find(labels.begin(), labels.end(), ellipsis) != labels.end() &&
input_nodes[i].get_partial_shape().rank().get_length() ==
static_cast<ov::Dimension::value_type>(labels.size() - 1)) {
input_nodes[i] = unsqueeze_input(
input_nodes[i],
{static_cast<int64_t>(
std::distance(labels.begin(), std::find(labels.begin(), labels.end(), ellipsis)))},
subgraph_nodes);
}
}
}
Expand Down

0 comments on commit 6347ed2

Please sign in to comment.