Skip to content

Commit

Permalink
Added support for aten::quantile and its tests
Browse files Browse the repository at this point in the history
  • Loading branch information
geeky33 committed Jan 21, 2025
1 parent 54da6dd commit 7935cd9
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 0 deletions.
89 changes: 89 additions & 0 deletions src/frontends/pytorch/src/op/quantile.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/range.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/floor.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/op/maximum.hpp"
#include "openvino/op/minimum.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_quantile(const NodeContext& context) {
num_inputs_check(context, 2, 4);

auto input = context.get_input(0);
auto quantiles = context.get_input(1);

auto dim = context.input_is_none(2) ? -1 : context.get_input<int64_t>(2);
auto keepdim = context.input_is_none(3) ? false : context.get_input<bool>(3);

if (dim == -1) {
input = context.mark_node(std::make_shared<v0::Reshape>(
input, context.mark_node(v0::Constant::create(element::i64, {1}, {-1})), true));
dim = 0;
}


auto sort_result = context.mark_node(std::make_shared<v0::Sort>(input, dim, true));
auto sorted_tensor = sort_result->output(0);


auto input_shape = context.mark_node(std::make_shared<v0::ShapeOf>(input));
auto dim_size = context.mark_node(std::make_shared<v0::Gather>(
input_shape, context.mark_node(v0::Constant::create(element::i64, {}, {dim})),
v0::Constant::create(element::i64, {}, {0})));

auto scaled_q = context.mark_node(std::make_shared<v1::Multiply>(
quantiles, context.mark_node(std::make_shared<v1::Subtract>(
dim_size, v0::Constant::create(element::i64, {}, {1})))));
auto lower_indices = context.mark_node(std::make_shared<v0::Floor>(scaled_q));
auto upper_indices = context.mark_node(std::make_shared<v1::Add>(
lower_indices, v0::Constant::create(element::i64, {}, {1})));

lower_indices = context.mark_node(std::make_shared<v1::Maximum>(
lower_indices, v0::Constant::create(element::i64, {}, {0})));
upper_indices = context.mark_node(std::make_shared<v1::Minimum>(
upper_indices, context.mark_node(std::make_shared<v1::Subtract>(
dim_size, v0::Constant::create(element::i64, {}, {1})))));


auto lower_values = context.mark_node(std::make_shared<v1::Gather>(sorted_tensor, lower_indices, dim));
auto upper_values = context.mark_node(std::make_shared<v1::Gather>(sorted_tensor, upper_indices, dim));

auto weights = context.mark_node(std::make_shared<v1::Subtract>(scaled_q, lower_indices));


auto result = context.mark_node(std::make_shared<v1::Add>(
lower_values, context.mark_node(std::make_shared<v1::Multiply>(weights, context.mark_node(std::make_shared<v1::Subtract>(upper_values, lower_values))))));

if (!keepdim) {
auto input_shape = context.mark_node(std::make_shared<v0::ShapeOf>(input));
auto output_shape = context.mark_node(std::make_shared<v1::Gather>(
input_shape,
context.mark_node(v0::Constant::create(element::i64, {1}, {dim})),
v0::Constant::create(element::i64, {}, {0})));
result = context.mark_node(std::make_shared<v0::Reshape>(result, output_shape, true));
}

return {result};
}

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
2 changes: 2 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ OP_CONVERTER(translate_quantized_add);
OP_CONVERTER(translate_quantized_add_relu);
OP_CONVERTER(translate_quantized_hardswish);
OP_CONVERTER(translate_quantized_mul);
OP_CONVERTER(translate_quantile);
OP_CONVERTER(translate_range_length);
OP_CONVERTER(translate_rand);
OP_CONVERTER(translate_randn);
Expand Down Expand Up @@ -745,6 +746,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
{"quantized::hardswish", op::translate_quantized_hardswish},
{"quantized::linear", op::translate_quantized_linear},
{"quantized::mul", op::translate_quantized_mul},
{"aten::quantile", op::translate_quantile},
{"torchvision::deform_conv2d", op::translate_deform_conv},
{"torchvision::nms", op::translate_nms},
{"torchvision::roi_align", op::translate_roi_align},
Expand Down
36 changes: 36 additions & 0 deletions tests/layer_tests/pytorch_tests/test_quantile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (C) 2018-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest
import numpy as np
import torch
from pytorch_layer_test_class import PytorchLayerTest


class TestQuantile(PytorchLayerTest):
def _prepare_input(self):
input_tensor = np.random.randn(1, 3, 224, 224).astype(np.float32)
quantile = np.array(0.5, dtype=np.float32)
return (input_tensor, quantile)

def create_model(self, dim=None, keepdim=False):
class aten_quantile(torch.nn.Module):
def __init__(self, dim, keepdim):
super(aten_quantile, self).__init__()
self.dim = dim
self.keepdim = keepdim

def forward(self, x, q):
return torch.quantile(x, q, dim=self.dim, keepdim=self.keepdim)

ref_net = None

return aten_quantile(dim, keepdim), ref_net, "aten::quantile"

@pytest.mark.parametrize("dim", [None, 0, 1, 2, 3, -1, -2, -3])
@pytest.mark.parametrize("keepdim", [True, False])
@pytest.mark.nightly
@pytest.mark.precommit
def test_quantile(self, dim, keepdim, ie_device, precision, ir_version):
self._test(*self.create_model(dim, keepdim), ie_device, precision, ir_version)

0 comments on commit 7935cd9

Please sign in to comment.