Skip to content

Commit

Permalink
Move broadcasting out of reshape conditional
Browse files Browse the repository at this point in the history
Signed-off-by: MATEUSZ MIKOLAJCZYK <MATEUSZ.MIKOLAJCZYK@intel.com>
  • Loading branch information
mmikolajcz committed Dec 20, 2024
1 parent 50c98c1 commit d3eac20
Showing 1 changed file with 51 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -772,6 +772,8 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr,

auto matmul_operand1 = input_node1;
auto matmul_operand2 = input_node2;
auto broadcasted_operand1 = input_node1;
auto broadcasted_operand2 = input_node2;

size_t common_dims_begin, common_dims_end, reduced_dims_begin, reduced_dims_end, separate1_dims_begin,
separate1_dims_end;
Expand Down Expand Up @@ -803,65 +805,57 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr,
reduced_dims_end2,
is_separate_first2);

no_reshape_for_matmul1 = false;
no_reshape_for_matmul2 = false;
// // no_reshape_after_matmul = false;
ov::OutputVector common_sub_shape, separate1_sub_shape, separate2_sub_shape;

if (no_reshape_for_matmul1 == false || no_reshape_for_matmul2 == false) {
auto data_shape1 = std::make_shared<ov::op::v3::ShapeOf>(input_node1);
auto data_shape2 = std::make_shared<ov::op::v3::ShapeOf>(input_node2);
common_sub_shape = compute_sub_shape(data_shape1, common_dims_begin, common_dims_end, subgraph_nodes);
auto common_sub_shape2 = compute_sub_shape(data_shape2, common_dims_begin2, common_dims_end2, subgraph_nodes);
OPENVINO_ASSERT(common_sub_shape.size() == common_sub_shape2.size());
common_sub_shape = broadcast_merge_shapes(common_sub_shape, common_sub_shape2, subgraph_nodes);
auto reduced_sub_shape_prod =
compute_sub_shape(data_shape1, reduced_dims_begin, reduced_dims_end, subgraph_nodes, true);
auto reduced_sub_shape_prod2 =
compute_sub_shape(data_shape2, reduced_dims_begin2, reduced_dims_end2, subgraph_nodes, true);
auto reduced_sub_shape =
compute_sub_shape(data_shape1, reduced_dims_begin, reduced_dims_end, subgraph_nodes, false);
auto reduced_sub_shape2 =
compute_sub_shape(data_shape2, reduced_dims_begin2, reduced_dims_end2, subgraph_nodes, false);

reduced_sub_shape_prod =
broadcast_merge_shapes(reduced_sub_shape_prod, reduced_sub_shape_prod2, subgraph_nodes);
reduced_sub_shape = broadcast_merge_shapes(reduced_sub_shape, reduced_sub_shape2, subgraph_nodes);
if (no_reshape_for_matmul1 == false || no_reshape_after_matmul == false) {
separate1_sub_shape =
compute_sub_shape(data_shape1, separate1_dims_begin, separate1_dims_end, subgraph_nodes);
auto broadcasted1 = broadcast_input(input_node1,
common_sub_shape,
separate1_sub_shape,
reduced_sub_shape,
is_separate_first1,
subgraph_nodes);
matmul_operand1 = reshape_input_for_matmul(broadcasted1,
common_sub_shape,
separate1_sub_shape,
reduced_sub_shape_prod,
is_separate_first1,
subgraph_nodes);
}

if (no_reshape_for_matmul2 == false || no_reshape_after_matmul == false) {
separate2_sub_shape =
compute_sub_shape(data_shape2, separate2_dims_begin, separate2_dims_end, subgraph_nodes);
auto broadcasted2 = broadcast_input(input_node2,
common_sub_shape,
separate2_sub_shape,
reduced_sub_shape,
is_separate_first2,
subgraph_nodes);
matmul_operand2 = reshape_input_for_matmul(broadcasted2,
common_sub_shape,
separate2_sub_shape,
reduced_sub_shape_prod,
is_separate_first2,
subgraph_nodes);
subgraph_nodes.insert(subgraph_nodes.end(), {data_shape2});
}
subgraph_nodes.insert(subgraph_nodes.end(), {data_shape1});
auto data_shape1 = std::make_shared<ov::op::v3::ShapeOf>(input_node1);
auto data_shape2 = std::make_shared<ov::op::v3::ShapeOf>(input_node2);
subgraph_nodes.insert(subgraph_nodes.end(), {data_shape1});
subgraph_nodes.insert(subgraph_nodes.end(), {data_shape2});
common_sub_shape = compute_sub_shape(data_shape1, common_dims_begin, common_dims_end, subgraph_nodes);
auto common_sub_shape2 = compute_sub_shape(data_shape2, common_dims_begin2, common_dims_end2, subgraph_nodes);
OPENVINO_ASSERT(common_sub_shape.size() == common_sub_shape2.size());
common_sub_shape = broadcast_merge_shapes(common_sub_shape, common_sub_shape2, subgraph_nodes);
auto reduced_sub_shape_prod =
compute_sub_shape(data_shape1, reduced_dims_begin, reduced_dims_end, subgraph_nodes, true);
auto reduced_sub_shape_prod2 =
compute_sub_shape(data_shape2, reduced_dims_begin2, reduced_dims_end2, subgraph_nodes, true);
auto reduced_sub_shape =
compute_sub_shape(data_shape1, reduced_dims_begin, reduced_dims_end, subgraph_nodes, false);
auto reduced_sub_shape2 =
compute_sub_shape(data_shape2, reduced_dims_begin2, reduced_dims_end2, subgraph_nodes, false);

reduced_sub_shape_prod = broadcast_merge_shapes(reduced_sub_shape_prod, reduced_sub_shape_prod2, subgraph_nodes);
reduced_sub_shape = broadcast_merge_shapes(reduced_sub_shape, reduced_sub_shape2, subgraph_nodes);
separate1_sub_shape = compute_sub_shape(data_shape1, separate1_dims_begin, separate1_dims_end, subgraph_nodes);
broadcasted_operand1 = broadcast_input(input_node1,
common_sub_shape,
separate1_sub_shape,
reduced_sub_shape,
is_separate_first1,
subgraph_nodes);
separate2_sub_shape = compute_sub_shape(data_shape2, separate2_dims_begin, separate2_dims_end, subgraph_nodes);
broadcasted_operand2 = broadcast_input(input_node2,
common_sub_shape,
separate2_sub_shape,
reduced_sub_shape,
is_separate_first2,
subgraph_nodes);
if (no_reshape_for_matmul1 == false || no_reshape_after_matmul == false) {
matmul_operand1 = reshape_input_for_matmul(broadcasted_operand1,
common_sub_shape,
separate1_sub_shape,
reduced_sub_shape_prod,
is_separate_first1,
subgraph_nodes);
}

if (no_reshape_for_matmul2 == false || no_reshape_after_matmul == false) {
matmul_operand2 = reshape_input_for_matmul(broadcasted_operand2,
common_sub_shape,
separate2_sub_shape,
reduced_sub_shape_prod,
is_separate_first2,
subgraph_nodes);
}

// step 3. apply MatMul operation for formatted inputs
Expand Down

0 comments on commit d3eac20

Please sign in to comment.