diff --git a/src/frontends/pytorch/src/op/quantile.cpp b/src/frontends/pytorch/src/op/quantile.cpp
new file mode 100644
index 00000000000000..2eefb34bbca029
--- /dev/null
+++ b/src/frontends/pytorch/src/op/quantile.cpp
@@ -0,0 +1,117 @@
+// Copyright (C) 2018-2025 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "openvino/frontend/pytorch/node_context.hpp"
+#include "openvino/op/add.hpp"
+#include "openvino/op/constant.hpp"
+#include "openvino/op/convert.hpp"
+#include "openvino/op/floor.hpp"
+#include "openvino/op/gather.hpp"
+#include "openvino/op/minimum.hpp"
+#include "openvino/op/multiply.hpp"
+#include "openvino/op/reshape.hpp"
+#include "openvino/op/round.hpp"
+#include "openvino/op/subtract.hpp"
+#include "openvino/op/topk.hpp"
+#include "openvino/op/unsqueeze.hpp"
+#include "utils.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+namespace op {
+
+using namespace ov::op;
+
+OutputVector translate_quantile(const NodeContext& context) {
+    // aten::quantile(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation="linear")
+    // Note: this conversion assumes a static input shape and f32 input data.
+    num_inputs_check(context, 2, 5);
+
+    auto input = context.get_input(0);
+    // Quantile(s) in [0, 1]; converted to f32 so the index arithmetic below is well-typed
+    auto q = context.mark_node(std::make_shared<v0::Convert>(context.get_input(1), element::f32));
+
+    const bool flatten = context.input_is_none(2);
+    auto dim = flatten ? 0 : context.const_input<int64_t>(2);
+    const bool keepdim = context.input_is_none(3) ? false : context.const_input<bool>(3);
+    const auto interpolation =
+        context.input_is_none(4) ? std::string("linear") : context.const_input<std::string>(4);
+
+    if (flatten) {
+        // dim=None: compute the quantile over the flattened input
+        auto minus_one = v0::Constant::create(element::i32, Shape{1}, {-1});
+        input = context.mark_node(std::make_shared<v1::Reshape>(input, minus_one, false));
+    } else if (dim < 0) {
+        dim += static_cast<int64_t>(input.get_shape().size());
+    }
+
+    const auto dim_size = static_cast<int64_t>(input.get_shape()[dim]);
+
+    // Sort along dim in ascending order: TopK with mode MIN over the full extent
+    auto k = v0::Constant::create(element::i64, Shape{}, {dim_size});
+    auto topk = context.mark_node(std::make_shared<v11::TopK>(input,
+                                                              k,
+                                                              dim,
+                                                              v11::TopK::Mode::MIN,
+                                                              v11::TopK::SortType::SORT_VALUES,
+                                                              element::i64,
+                                                              true));
+    auto sorted = topk->output(0);
+
+    // Fractional index of the requested quantile: q * (n - 1)
+    auto max_index = v0::Constant::create(element::f32, Shape{}, {static_cast<float>(dim_size - 1)});
+    auto indices = context.mark_node(std::make_shared<v1::Multiply>(q, max_index));
+    auto lower_f = context.mark_node(std::make_shared<v0::Floor>(indices));
+    auto weights = context.mark_node(std::make_shared<v1::Subtract>(indices, lower_f));
+    auto lower_indices = context.mark_node(std::make_shared<v0::Convert>(lower_f, element::i64));
+
+    // Clamp the upper neighbour so that q = 1.0 does not gather out of range
+    auto one = v0::Constant::create(element::i64, Shape{}, {1});
+    auto last = v0::Constant::create(element::i64, Shape{}, {dim_size - 1});
+    auto upper_unclamped = context.mark_node(std::make_shared<v1::Add>(lower_indices, one));
+    auto upper_indices = context.mark_node(std::make_shared<v1::Minimum>(upper_unclamped, last));
+
+    auto axis = v0::Constant::create(element::i64, Shape{}, {dim});
+    auto lower_values = context.mark_node(std::make_shared<v8::Gather>(sorted, lower_indices, axis));
+    auto upper_values = context.mark_node(std::make_shared<v8::Gather>(sorted, upper_indices, axis));
+
+    Output<Node> result;
+    if (interpolation == "linear") {
+        // lower + weight * (upper - lower)
+        auto diff = context.mark_node(std::make_shared<v1::Subtract>(upper_values, lower_values));
+        auto frac = context.mark_node(std::make_shared<v1::Multiply>(weights, diff));
+        result = context.mark_node(std::make_shared<v1::Add>(lower_values, frac));
+    } else if (interpolation == "lower") {
+        result = lower_values;
+    } else if (interpolation == "higher") {
+        result = upper_values;
+    } else if (interpolation == "nearest") {
+        auto rounded = context.mark_node(std::make_shared<v5::Round>(indices, v5::Round::RoundMode::HALF_TO_EVEN));
+        auto nearest_indices = context.mark_node(std::make_shared<v0::Convert>(rounded, element::i64));
+        result = context.mark_node(std::make_shared<v8::Gather>(sorted, nearest_indices, axis));
+    } else if (interpolation == "midpoint") {
+        auto half = v0::Constant::create(element::f32, Shape{}, {0.5f});
+        auto diff = context.mark_node(std::make_shared<v1::Subtract>(upper_values, lower_values));
+        auto offset = context.mark_node(std::make_shared<v1::Multiply>(half, diff));
+        result = context.mark_node(std::make_shared<v1::Add>(lower_values, offset));
+    } else {
+        FRONT_END_OP_CONVERSION_CHECK(false, "Unsupported interpolation method: ", interpolation);
+    }
+
+    if (keepdim) {
+        // Gather with scalar indices drops dim; restore it with size 1. For dim=None
+        // PyTorch keeps every dimension with size 1; a single kept dimension is an
+        // approximation of that case.
+        auto unsqueeze_axis = v0::Constant::create(element::i64, Shape{1}, {dim});
+        result = context.mark_node(std::make_shared<v0::Unsqueeze>(result, unsqueeze_axis));
+    }
+
+    return {result};
+}
+
+}  // namespace op
+}  // namespace pytorch
+}  // namespace frontend
+}  // namespace ov
diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp
index ef75a253f7506a..a1e05bd0dfb799 100644
--- a/src/frontends/pytorch/src/op_table.cpp
+++ b/src/frontends/pytorch/src/op_table.cpp
@@ -191,6 +191,7 @@ OP_CONVERTER(translate_quantized_add);
 OP_CONVERTER(translate_quantized_add_relu);
 OP_CONVERTER(translate_quantized_hardswish);
 OP_CONVERTER(translate_quantized_mul);
+OP_CONVERTER(translate_quantile);
 OP_CONVERTER(translate_range_length);
 OP_CONVERTER(translate_rand);
 OP_CONVERTER(translate_randn);
@@ -747,6 +748,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
         {"quantized::hardswish", op::translate_quantized_hardswish},
         {"quantized::linear", op::translate_quantized_linear},
         {"quantized::mul", op::translate_quantized_mul},
+        {"aten::quantile", op::translate_quantile},
         {"torchvision::deform_conv2d", op::translate_deform_conv},
         {"torchvision::nms", op::translate_nms},
         {"torchvision::roi_align", op::translate_roi_align},
diff --git a/tests/layer_tests/pytorch_tests/test_quantile.py b/tests/layer_tests/pytorch_tests/test_quantile.py
new file mode 100644
index 00000000000000..4dff96d4c96de2
--- /dev/null
+++ b/tests/layer_tests/pytorch_tests/test_quantile.py
@@ -0,0 +1,52 @@
+# Copyright (C) 2018-2025 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+import numpy as np
+import torch
+from pytorch_layer_test_class import PytorchLayerTest
+
+
+class TestQuantile(PytorchLayerTest):
+    def _prepare_input(self):
+        input_tensor = np.random.randn(1, 3, 224, 224).astype(np.float32)
+        quantile = np.array(0.5, dtype=np.float32)
+        return (input_tensor, quantile)
+
+    def create_model(self, dim=None, keepdim=False):
+        class aten_quantile(torch.nn.Module):
+            def __init__(self, dim, keepdim):
+                super(aten_quantile, self).__init__()
+                self.dim = dim
+                self.keepdim = keepdim
+
+            def forward(self, x, q):
+                return torch.quantile(x, q, dim=self.dim, keepdim=self.keepdim)
+
+        ref_net = None
+
+        return aten_quantile(dim, keepdim), ref_net, "aten::quantile"
+
+    @pytest.mark.parametrize("dim", [None, 0, 1, 2, 3, -1, -2, -3])
+    @pytest.mark.parametrize("keepdim", [True, False])
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_quantile(self, dim, keepdim, ie_device, precision, ir_version):
+        self._test(*self.create_model(dim, keepdim), ie_device, precision, ir_version)
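+
+    # A sketch of coverage for the `interpolation` argument handled by the
+    # converter; assumes torch.quantile accepts `interpolation` as a keyword
+    # (available in recent PyTorch versions).
+    @pytest.mark.parametrize("interpolation", ["linear", "lower", "higher", "nearest", "midpoint"])
+    @pytest.mark.nightly
+    def test_quantile_interpolation(self, interpolation, ie_device, precision, ir_version):
+        class aten_quantile_interp(torch.nn.Module):
+            def __init__(self, interpolation):
+                super().__init__()
+                self.interpolation = interpolation
+
+            def forward(self, x, q):
+                return torch.quantile(x, q, dim=1, keepdim=False, interpolation=self.interpolation)
+
+        self._test(aten_quantile_interp(interpolation), None, "aten::quantile",
+                   ie_device, precision, ir_version)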