Skip to content

Commit

Permalink
Add temporary convert for multinomial i64 support
Browse files Browse the repository at this point in the history
  • Loading branch information
mmikolajcz committed Nov 28, 2023
1 parent 9c8cbd6 commit 2e7ebc7
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions src/frontends/tensorflow_common/src/op/multinomial.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "openvino/op/multinomial.hpp"

#include "common_op_table.hpp"
#include "openvino/op/convert.hpp"

namespace ov {
namespace frontend {
Expand All @@ -20,9 +21,10 @@ OutputVector translate_multinomial_op(const NodeContext& node) {
auto output_type = node.get_attribute<ov::element::Type>("output_dtype");

auto res =
std::make_shared<ov::op::v13::Multinomial>(logits, num_samples, output_type, true, true, global_seed, op_seed);
set_node_name(node.get_name(), res);
return res->outputs();
std::make_shared<ov::op::v13::Multinomial>(logits, num_samples, element::i32, true, true, global_seed, op_seed);
auto converted_res = std::make_shared<ov::op::v0::Convert>(res, output_type);
set_node_name(node.get_name(), converted_res);
return converted_res->outputs();
}

} // namespace op
Expand Down

0 comments on commit 2e7ebc7

Please sign in to comment.