diff --git a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp index 7d93928243b78d..7ae13f270a03c5 100644 --- a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp +++ b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp @@ -772,6 +772,8 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr, auto matmul_operand1 = input_node1; auto matmul_operand2 = input_node2; + auto broadcasted_operand1 = input_node1; + auto broadcasted_operand2 = input_node2; size_t common_dims_begin, common_dims_end, reduced_dims_begin, reduced_dims_end, separate1_dims_begin, separate1_dims_end; @@ -803,65 +805,57 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr, reduced_dims_end2, is_separate_first2); - no_reshape_for_matmul1 = false; - no_reshape_for_matmul2 = false; - // // no_reshape_after_matmul = false; ov::OutputVector common_sub_shape, separate1_sub_shape, separate2_sub_shape; - if (no_reshape_for_matmul1 == false || no_reshape_for_matmul2 == false) { - auto data_shape1 = std::make_shared(input_node1); - auto data_shape2 = std::make_shared(input_node2); - common_sub_shape = compute_sub_shape(data_shape1, common_dims_begin, common_dims_end, subgraph_nodes); - auto common_sub_shape2 = compute_sub_shape(data_shape2, common_dims_begin2, common_dims_end2, subgraph_nodes); - OPENVINO_ASSERT(common_sub_shape.size() == common_sub_shape2.size()); - common_sub_shape = broadcast_merge_shapes(common_sub_shape, common_sub_shape2, subgraph_nodes); - auto reduced_sub_shape_prod = - compute_sub_shape(data_shape1, reduced_dims_begin, reduced_dims_end, subgraph_nodes, true); - auto reduced_sub_shape_prod2 = - compute_sub_shape(data_shape2, reduced_dims_begin2, reduced_dims_end2, subgraph_nodes, true); - auto reduced_sub_shape = - compute_sub_shape(data_shape1, reduced_dims_begin, reduced_dims_end, subgraph_nodes, false); - auto reduced_sub_shape2 = - compute_sub_shape(data_shape2, reduced_dims_begin2, reduced_dims_end2, subgraph_nodes, false); - - reduced_sub_shape_prod = - broadcast_merge_shapes(reduced_sub_shape_prod, reduced_sub_shape_prod2, subgraph_nodes); - reduced_sub_shape = broadcast_merge_shapes(reduced_sub_shape, reduced_sub_shape2, subgraph_nodes); - if (no_reshape_for_matmul1 == false || no_reshape_after_matmul == false) { - separate1_sub_shape = - compute_sub_shape(data_shape1, separate1_dims_begin, separate1_dims_end, subgraph_nodes); - auto broadcasted1 = broadcast_input(input_node1, - common_sub_shape, - separate1_sub_shape, - reduced_sub_shape, - is_separate_first1, - subgraph_nodes); - matmul_operand1 = reshape_input_for_matmul(broadcasted1, - common_sub_shape, - separate1_sub_shape, - reduced_sub_shape_prod, - is_separate_first1, - subgraph_nodes); - } - - if (no_reshape_for_matmul2 == false || no_reshape_after_matmul == false) { - separate2_sub_shape = - compute_sub_shape(data_shape2, separate2_dims_begin, separate2_dims_end, subgraph_nodes); - auto broadcasted2 = broadcast_input(input_node2, - common_sub_shape, - separate2_sub_shape, - reduced_sub_shape, - is_separate_first2, - subgraph_nodes); - matmul_operand2 = reshape_input_for_matmul(broadcasted2, - common_sub_shape, - separate2_sub_shape, - reduced_sub_shape_prod, - is_separate_first2, - subgraph_nodes); - subgraph_nodes.insert(subgraph_nodes.end(), {data_shape2}); - } - subgraph_nodes.insert(subgraph_nodes.end(), {data_shape1}); + auto data_shape1 = std::make_shared(input_node1); + auto data_shape2 = std::make_shared(input_node2); + subgraph_nodes.insert(subgraph_nodes.end(), {data_shape1}); + subgraph_nodes.insert(subgraph_nodes.end(), {data_shape2}); + common_sub_shape = compute_sub_shape(data_shape1, common_dims_begin, common_dims_end, subgraph_nodes); + auto common_sub_shape2 = compute_sub_shape(data_shape2, common_dims_begin2, common_dims_end2, subgraph_nodes); + OPENVINO_ASSERT(common_sub_shape.size() == common_sub_shape2.size()); + common_sub_shape = broadcast_merge_shapes(common_sub_shape, common_sub_shape2, subgraph_nodes); + auto reduced_sub_shape_prod = + compute_sub_shape(data_shape1, reduced_dims_begin, reduced_dims_end, subgraph_nodes, true); + auto reduced_sub_shape_prod2 = + compute_sub_shape(data_shape2, reduced_dims_begin2, reduced_dims_end2, subgraph_nodes, true); + auto reduced_sub_shape = + compute_sub_shape(data_shape1, reduced_dims_begin, reduced_dims_end, subgraph_nodes, false); + auto reduced_sub_shape2 = + compute_sub_shape(data_shape2, reduced_dims_begin2, reduced_dims_end2, subgraph_nodes, false); + + reduced_sub_shape_prod = broadcast_merge_shapes(reduced_sub_shape_prod, reduced_sub_shape_prod2, subgraph_nodes); + reduced_sub_shape = broadcast_merge_shapes(reduced_sub_shape, reduced_sub_shape2, subgraph_nodes); + separate1_sub_shape = compute_sub_shape(data_shape1, separate1_dims_begin, separate1_dims_end, subgraph_nodes); + broadcasted_operand1 = broadcast_input(input_node1, + common_sub_shape, + separate1_sub_shape, + reduced_sub_shape, + is_separate_first1, + subgraph_nodes); + separate2_sub_shape = compute_sub_shape(data_shape2, separate2_dims_begin, separate2_dims_end, subgraph_nodes); + broadcasted_operand2 = broadcast_input(input_node2, + common_sub_shape, + separate2_sub_shape, + reduced_sub_shape, + is_separate_first2, + subgraph_nodes); + if (no_reshape_for_matmul1 == false || no_reshape_after_matmul == false) { + matmul_operand1 = reshape_input_for_matmul(broadcasted_operand1, + common_sub_shape, + separate1_sub_shape, + reduced_sub_shape_prod, + is_separate_first1, + subgraph_nodes); + } + + if (no_reshape_for_matmul2 == false || no_reshape_after_matmul == false) { + matmul_operand2 = reshape_input_for_matmul(broadcasted_operand2, + common_sub_shape, + separate2_sub_shape, + reduced_sub_shape_prod, + is_separate_first2, + subgraph_nodes); } // step 3. apply MatMul operation for formatted inputs