diff --git a/src/frontends/pytorch/src/op/quantile.cpp b/src/frontends/pytorch/src/op/quantile.cpp new file mode 100644 index 00000000000000..724e12e39e6eff --- /dev/null +++ b/src/frontends/pytorch/src/op/quantile.cpp @@ -0,0 +1,89 @@ +// Copyright (C) 2018-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/frontend/pytorch/node_context.hpp" +#include "openvino/op/convert.hpp" +#include "openvino/op/gather.hpp" +#include "openvino/op/range.hpp" +#include "openvino/op/reshape.hpp" +#include "openvino/opsets/opset10.hpp" +#include "openvino/op/multiply.hpp" +#include "openvino/op/floor.hpp" +#include "openvino/op/add.hpp" +#include "openvino/op/subtract.hpp" +#include "openvino/op/maximum.hpp" +#include "openvino/op/minimum.hpp" +#include "utils.hpp" + +namespace ov { +namespace frontend { +namespace pytorch { +namespace op { + +using namespace ov::op; + +OutputVector translate_quantile(const NodeContext& context) { + num_inputs_check(context, 2, 4); + + auto input = context.get_input(0); + auto quantiles = context.get_input(1); + + auto dim = context.input_is_none(2) ? -1 : context.get_input(2); + auto keepdim = context.input_is_none(3) ? false : context.get_input(3); + + if (dim == -1) { + input = context.mark_node(std::make_shared( + input, context.mark_node(v0::Constant::create(element::i64, {1}, {-1})), true)); + dim = 0; + } + + + auto sort_result = context.mark_node(std::make_shared(input, dim, true)); + auto sorted_tensor = sort_result->output(0); + + + auto input_shape = context.mark_node(std::make_shared(input)); + auto dim_size = context.mark_node(std::make_shared( + input_shape, context.mark_node(v0::Constant::create(element::i64, {}, {dim})), + v0::Constant::create(element::i64, {}, {0}))); + + auto scaled_q = context.mark_node(std::make_shared( + quantiles, context.mark_node(std::make_shared( + dim_size, v0::Constant::create(element::i64, {}, {1}))))); + auto lower_indices = context.mark_node(std::make_shared(scaled_q)); + auto upper_indices = context.mark_node(std::make_shared( + lower_indices, v0::Constant::create(element::i64, {}, {1}))); + + lower_indices = context.mark_node(std::make_shared( + lower_indices, v0::Constant::create(element::i64, {}, {0}))); + upper_indices = context.mark_node(std::make_shared( + upper_indices, context.mark_node(std::make_shared( + dim_size, v0::Constant::create(element::i64, {}, {1}))))); + + + auto lower_values = context.mark_node(std::make_shared(sorted_tensor, lower_indices, dim)); + auto upper_values = context.mark_node(std::make_shared(sorted_tensor, upper_indices, dim)); + + auto weights = context.mark_node(std::make_shared(scaled_q, lower_indices)); + + + auto result = context.mark_node(std::make_shared( + lower_values, context.mark_node(std::make_shared(weights, context.mark_node(std::make_shared(upper_values, lower_values)))))); + + if (!keepdim) { + auto input_shape = context.mark_node(std::make_shared(input)); + auto output_shape = context.mark_node(std::make_shared( + input_shape, + context.mark_node(v0::Constant::create(element::i64, {1}, {dim})), + v0::Constant::create(element::i64, {}, {0}))); + result = context.mark_node(std::make_shared(result, output_shape, true)); + } + + return {result}; +} + +} // namespace op +} // namespace pytorch +} // namespace frontend +} // namespace ov diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index 00e3a55b0bc327..9d47fc313d8106 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -190,6 +190,7 @@ OP_CONVERTER(translate_quantized_add); OP_CONVERTER(translate_quantized_add_relu); OP_CONVERTER(translate_quantized_hardswish); OP_CONVERTER(translate_quantized_mul); +OP_CONVERTER(translate_quantile); OP_CONVERTER(translate_range_length); OP_CONVERTER(translate_rand); OP_CONVERTER(translate_randn); @@ -745,6 +746,7 @@ const std::unordered_map get_supported_ops_ts() { {"quantized::hardswish", op::translate_quantized_hardswish}, {"quantized::linear", op::translate_quantized_linear}, {"quantized::mul", op::translate_quantized_mul}, + {"aten::quantile", op::translate_quantile}, {"torchvision::deform_conv2d", op::translate_deform_conv}, {"torchvision::nms", op::translate_nms}, {"torchvision::roi_align", op::translate_roi_align}, diff --git a/tests/layer_tests/pytorch_tests/test_quantile.py b/tests/layer_tests/pytorch_tests/test_quantile.py new file mode 100644 index 00000000000000..4dff96d4c96de2 --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_quantile.py @@ -0,0 +1,36 @@ +# Copyright (C) 2018-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import numpy as np +import torch +from pytorch_layer_test_class import PytorchLayerTest + + +class TestQuantile(PytorchLayerTest): + def _prepare_input(self): + input_tensor = np.random.randn(1, 3, 224, 224).astype(np.float32) + quantile = np.array(0.5, dtype=np.float32) + return (input_tensor, quantile) + + def create_model(self, dim=None, keepdim=False): + class aten_quantile(torch.nn.Module): + def __init__(self, dim, keepdim): + super(aten_quantile, self).__init__() + self.dim = dim + self.keepdim = keepdim + + def forward(self, x, q): + return torch.quantile(x, q, dim=self.dim, keepdim=self.keepdim) + + ref_net = None + + return aten_quantile(dim, keepdim), ref_net, "aten::quantile" + + @pytest.mark.parametrize("dim", [None, 0, 1, 2, 3, -1, -2, -3]) + @pytest.mark.parametrize("keepdim", [True, False]) + @pytest.mark.nightly + @pytest.mark.precommit + def test_quantile(self, dim, keepdim, ie_device, precision, ir_version): + self._test(*self.create_model(dim, keepdim), ie_device, precision, ir_version) +