Skip to content

Commit

Permalink
Modify reshape_input_for_matmul reduced prod to match ne for separate
Browse files Browse the repository at this point in the history
Signed-off-by: Mateusz Mikolajczyk <mateusz.mikolajczyk@intel.com>
  • Loading branch information
mmikolajcz committed Jan 23, 2025
1 parent 50b6d3e commit f666700
Showing 1 changed file with 4 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,7 @@ ov::Output<ov::Node> reshape_input_for_matmul(const ov::Output<ov::Node>& input_

// compute a product of a sub-shape for separate labels
ov::OutputVector separate_parts;
auto reduce_axis_const = ov::op::v0::Constant::create(ov::element::Type_t::i64, ov::Shape{1}, {0});
if (common_sub_shape.size() > 0 && separate_sub_shape.size() == 0) {
// in this case new dimension corresponding to separate labels must be added
// since MatMul operation is not possible to do without separate dimensions if the
Expand All @@ -432,17 +433,16 @@ ov::Output<ov::Node> reshape_input_for_matmul(const ov::Output<ov::Node>& input_
} else if (separate_sub_shape.size() > 0) {
// in this case compute a product of separate dimension sizes since they must be
// presented with just one dimension for MatMul
auto reduce_axis_const = ov::op::v0::Constant::create(ov::element::Type_t::i64, ov::Shape{1}, {0});
auto separate_shape_prod =
std::make_shared<ov::op::v1::ReduceProd>(separate_sub_shape[0], reduce_axis_const, true);
separate_parts.push_back(separate_shape_prod->output(0));
subgraph_nodes.insert(subgraph_nodes.end(), {reduce_axis_const, separate_shape_prod});
}
ov::OutputVector reduced_sub_shape_prod;
auto const_0 = ov::op::v0::Constant::create(ov::element::i32, {1}, {0});
for (auto sub_shape : reduced_sub_shape) {
auto product = std::make_shared<ov::op::v1::ReduceProd>(sub_shape, const_0, true);
subgraph_nodes.insert(subgraph_nodes.end(), {const_0, product});
if (reduced_sub_shape.size() > 0) {
auto product = std::make_shared<ov::op::v1::ReduceProd>(reduced_sub_shape[0], const_0, true);
subgraph_nodes.insert(subgraph_nodes.end(), {reduce_axis_const, product});
reduced_sub_shape_prod.push_back(product->output(0));
}

Expand Down

0 comments on commit f666700

Please sign in to comment.