Skip to content

Commit

Permalink
Refactor broadcast_merge_shapes to eliminate loop
Browse files Browse the repository at this point in the history
Signed-off-by: Mateusz Mikolajczyk <mateusz.mikolajczyk@intel.com>
  • Loading branch information
mmikolajcz committed Jan 24, 2025
1 parent d380155 commit 1918991
Showing 1 changed file with 10 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -322,22 +322,24 @@ ov::Output<ov::Node> unsqueeze_input(const ov::Output<ov::Node>& input_node,
ov::OutputVector broadcast_merge_shapes(ov::OutputVector& shapes_lhs,
ov::OutputVector& shapes_rhs,
ov::NodeVector& subgraph_nodes) {
// TODO - Refactor func to remove loop and duplicated Broadcast.
OPENVINO_ASSERT(shapes_lhs.size() == shapes_rhs.size());
ov::OutputVector broadcasted_shape_nodes{shapes_lhs.size()};

for (size_t shp_i = 0; shp_i < shapes_lhs.size(); shp_i++) {
ov::OutputVector broadcasted_shape_nodes{};
// OutputVector is either empty or contains a single shape
if (shapes_lhs.size() == 1 && shapes_rhs.size() == 1) {
auto const_1 = ov::op::v0::Constant::create(ov::element::Type_t::i64, ov::Shape{1}, {1});
auto tmp_const_of_lhs_shp =
std::make_shared<ov::op::v3::Broadcast>(const_1, shapes_lhs[shp_i], ov::op::BroadcastType::NUMPY);
std::make_shared<ov::op::v3::Broadcast>(const_1, shapes_lhs[0], ov::op::BroadcastType::NUMPY);
auto tmp_const_of_broadcasted_shp =
std::make_shared<ov::op::v3::Broadcast>(tmp_const_of_lhs_shp,
shapes_rhs[shp_i],
shapes_rhs[0],
ov::op::BroadcastType::BIDIRECTIONAL);
auto broadcasted_shape = std::make_shared<ov::op::v3::ShapeOf>(tmp_const_of_broadcasted_shp);
broadcasted_shape_nodes[shp_i] = broadcasted_shape;
broadcasted_shape_nodes.push_back(broadcasted_shape->output(0));
subgraph_nodes.insert(subgraph_nodes.end(),
{const_1, tmp_const_of_lhs_shp, tmp_const_of_broadcasted_shp, broadcasted_shape});
} else if (shapes_lhs.size() == 0 && shapes_rhs.size() == 1) {
return shapes_rhs;
} else if (shapes_lhs.size() == 1 && shapes_rhs.size() == 0) {
return shapes_lhs;
}
return broadcasted_shape_nodes;
}
Expand Down

0 comments on commit 1918991

Please sign in to comment.