diff --git a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/decompositions.py b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/decompositions.py
index 8508615bb44173..49509f55a64c31 100644
--- a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/decompositions.py
+++ b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/decompositions.py
@@ -118,7 +118,10 @@ def get_aot_decomposition_list():
 
 
 def get_inf_decomposition_list():
-    return [torch.ops.aten.nll_loss_forward.default]
+    return [
+        torch.ops.aten._unsafe_index,
+        torch.ops.aten.nll_loss_forward.default,
+    ]
 
 
 def get_export_decomposition_list():
diff --git a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py
index 8ca3b7b489f665..d5287315a0d14b 100644
--- a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py
+++ b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py
@@ -51,6 +51,7 @@ def __init__(self, options):
             "torch.ops.aten._softmax.default": None,
             "torch.ops.aten._to_copy.default": None,
             "torch.ops.aten._unsafe_view.default": None,
+            "torch.ops.aten._weight_int4pack_mm.default": None,
             "torch.ops.aten.abs.default": None,
             "torch.ops.aten.acos.default": None,
             "torch.ops.aten.acosh.default": None,
diff --git a/src/frontends/pytorch/src/frontend.cpp b/src/frontends/pytorch/src/frontend.cpp
index bb69e8fa313130..200611ef6cb536 100644
--- a/src/frontends/pytorch/src/frontend.cpp
+++ b/src/frontends/pytorch/src/frontend.cpp
@@ -42,6 +42,7 @@
 #include "transforms/softmax_reshape_elimination.hpp"
 #include "transforms/string_equality_replacer.hpp"
 #include "transforms/torchfx_gptq_pattern_replacer.hpp"
+#include "transforms/torchfx_torchao_pattern_replacer.hpp"
 #include "transforms/tuple_unpack_replacer.hpp"
 #include "transforms/u4_block_repack.hpp"
 #include "translate_session.hpp"
@@ -261,6 +262,7 @@ void FrontEnd::normalize(const std::shared_ptr<Model>& model) const {
         ov::pass::Manager manager("Frontend:Pytorch:normalize::fx_gptq");
         manager.register_pass<ov::frontend::pytorch::pass::GPTQDecompressionReplacer>();
         manager.register_pass<ov::frontend::pytorch::pass::GPTQMultPatternReplacer>();
+        manager.register_pass<ov::frontend::pytorch::pass::WeightINT4PackMMReplacer>();
         manager.run_passes(model);
     }
 
diff --git a/src/frontends/pytorch/src/transforms/torchfx_torchao_pattern_replacer.cpp b/src/frontends/pytorch/src/transforms/torchfx_torchao_pattern_replacer.cpp
new file mode 100644
index 00000000000000..2dd00a6135329d
--- /dev/null
+++ b/src/frontends/pytorch/src/transforms/torchfx_torchao_pattern_replacer.cpp
@@ -0,0 +1,206 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "torchfx_torchao_pattern_replacer.hpp"
+
+#include "openvino/core/rt_info.hpp"
+#include "openvino/op/broadcast.hpp"
+#include "openvino/op/convert.hpp"
+#include "openvino/op/divide.hpp"
+#include "openvino/op/matmul.hpp"
+#include "openvino/op/multiply.hpp"
+#include "openvino/op/reshape.hpp"
+#include "openvino/op/split.hpp"
+#include "openvino/op/subtract.hpp"
+#include "openvino/op/transpose.hpp"
+#include "openvino/op/util/framework_node.hpp"
+#include "openvino/pass/pattern/matcher.hpp"
+#include "openvino/pass/pattern/op/or.hpp"
+#include "openvino/pass/pattern/op/wrap_type.hpp"
+#include "utils.hpp"
+#include "utils_quantize.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+namespace pass {
+
+using namespace ov::op;
+using namespace ov::pass::pattern;
+
+WeightINT4PackMMReplacer::WeightINT4PackMMReplacer() {
+    const auto& const_1 = wrap_type<v0::Constant>();
+    const auto& const_2 = wrap_type<v0::Constant>();
+    const auto& const_3 = wrap_type<v0::Constant>();
+    const auto& weight_int4pack_mm = wrap_type<ov::op::util::FrameworkNode>({any_input(), const_1, const_2, const_3});
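+
+    // Replace the matched aten._weight_int4pack_mm(x, packed_weight, group_size, scales_and_zeros)
+    // node with an explicit weight-decompression subgraph. All of the weight preprocessing below
+    // (unpacking, reordering, zero-point computation) runs on constants at conversion time via
+    // constant_fold; only the final Convert->Subtract->Multiply->Reshape->MatMul chain is kept in
+    // the graph so that device plugins can recognize it as a compressed-weight pattern.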
+    ov::matcher_pass_callback callback = [=](Matcher& m) {
+        auto weight_int4pack_mm = m.get_match_root();
+        if (!weight_int4pack_mm) {
+            return false;
+        }
+        if (!(weight_int4pack_mm = cast_fw_node(m.get_match_root(), "aten._weight_int4pack_mm.default"))) {
+            return false;
+        }
+        auto wt_const = std::dynamic_pointer_cast<v0::Constant>(weight_int4pack_mm->get_input_node_shared_ptr(1));
+        auto wt_const_flat = std::make_shared<v0::Constant>(wt_const->get_element_type(),
+                                                            Shape({shape_size(wt_const->get_shape()), 1}),
+                                                            wt_const->get_data_ptr());
+        std::vector<size_t> broadcast_shape_vec(wt_const_flat->get_shape());
+        broadcast_shape_vec[1] = 8;
+        auto broadcast_shape_const = std::make_shared<v0::Constant>(element::i32, Shape({2}), broadcast_shape_vec);
+
+        auto broadcast = std::make_shared<v3::Broadcast>(wt_const_flat, broadcast_shape_const);
+        OutputVector outputs_broadcast(broadcast->get_output_size());
+        OutputVector broadcast_inputs(2);
+        broadcast_inputs[0] = wt_const_flat->outputs()[0];
+        broadcast_inputs[1] = broadcast_shape_const->outputs()[0];
+        if (!broadcast->constant_fold(outputs_broadcast, broadcast_inputs)) {
+            return false;
+        }
+
+        // Each int32 element of the packed weight holds eight 4-bit values; after broadcasting
+        // every element into a row of eight copies, extract the nibbles with shifts and masks.
+        auto broadcast_out_const = std::dynamic_pointer_cast<v0::Constant>(outputs_broadcast[0].get_node_shared_ptr());
+        auto broadcast_out_ptr =
+            const_cast<int32_t*>(reinterpret_cast<const int32_t*>(broadcast_out_const->get_data_ptr()));
+        for (size_t k = 0; k < shape_size(outputs_broadcast[0].get_shape()); ++k) {
+            int32_t shift_val = (k % 8) * 4;
+            broadcast_out_ptr[k] = (broadcast_out_ptr[k] >> shift_val) & 15;
+        }
+        // Reorder the unpacked values from the inner-tile layout produced by
+        // aten._convert_weight_to_int4pack back into a contiguous 2D layout.
+        std::vector<size_t> wt_ordered_shape(2);
+        wt_ordered_shape[0] = wt_const->get_shape()[0] * 8;
+        wt_ordered_shape[1] = wt_const->get_shape()[1] * wt_const->get_shape()[2] * wt_const->get_shape()[3];
+        auto wt_const_ordered = std::make_shared<v0::Constant>(wt_const->get_element_type(), Shape(wt_ordered_shape));
+        auto wt_ordered_ptr = const_cast<int32_t*>(reinterpret_cast<const int32_t*>(wt_const_ordered->get_data_ptr()));
+        for (uint64_t b = 0; b < wt_ordered_shape[0] / 64; b++) {
+            for (uint64_t j = 0; j < wt_ordered_shape[1]; j++) {
+                for (uint64_t i = 0; i < 32; i++) {
+                    uint64_t l = 0;
+                    uint64_t m = (i * 2);
+                    l = b * 64 * (broadcast_out_const->get_shape()[0] / wt_ordered_shape[0]) + j * 8 +
+                        (m / broadcast_out_const->get_shape()[1]);
+                    m = m % broadcast_out_const->get_shape()[1];
+                    wt_ordered_ptr[(b * 64 + i) * wt_ordered_shape[1] + j] = broadcast_out_ptr[l * 8 + m];
+                    wt_ordered_ptr[(b * 64 + i + 32) * wt_ordered_shape[1] + j] = broadcast_out_ptr[l * 8 + m + 1];
+                }
+            }
+        }
+
+        auto transpose_order = v0::Constant::create(element::i32, Shape{2}, {1, 0});
+        auto transpose = std::make_shared<v1::Transpose>(wt_const_ordered, transpose_order);
+        OutputVector outputs_transpose(transpose->get_output_size());
+        OutputVector transpose_inputs(2);
+        transpose_inputs[0] = wt_const_ordered->outputs()[0];
+        transpose_inputs[1] = transpose_order->outputs()[0];
+        if (!transpose->constant_fold(outputs_transpose, transpose_inputs)) {
+            return false;
+        }
+        auto g_const = std::dynamic_pointer_cast<v0::Constant>(weight_int4pack_mm->get_input_node_shared_ptr(2));
+        auto g_ptr = const_cast<int32_t*>(reinterpret_cast<const int32_t*>(g_const->get_data_ptr()));
+        uint64_t g = (uint64_t)(g_ptr[0]);
+        if (g > outputs_transpose[0].get_shape()[1]) {
+            g = outputs_transpose[0].get_shape()[1];
+        }
+        auto wt_i32 = std::make_shared<v0::Constant>(
+            outputs_transpose[0].get_node_shared_ptr()->get_element_type(),
+            Shape({static_cast<size_t>(outputs_transpose[0].get_node_shared_ptr()->get_shape()[0] / g),
+                   static_cast<size_t>(g),
+                   outputs_transpose[0].get_node_shared_ptr()->get_shape()[1]}),
+            std::dynamic_pointer_cast<v0::Constant>(outputs_transpose[0].get_node_shared_ptr())->get_data_ptr());
+
+        auto convert_to_u4 = std::make_shared<v0::Convert>(wt_i32, element::u4);
+        OutputVector outputs_to_u4(convert_to_u4->get_output_size());
+        if (!convert_to_u4->constant_fold(outputs_to_u4, wt_i32->outputs())) {
+            return false;
+        }
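+
+        // scales_and_zeros is a single [K / group_size, N, 2] constant: channel 0 holds the
+        // per-group scales s and channel 1 the zero values z, with tinygemm dequantizing as
+        // w = (q - 8) * s + z. Rewriting this as w = (q - zp) * s with zp = 8 - z / s puts the
+        // weight into the standard (weight - zero_point) * scale form built below.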
+        auto sz_const = std::dynamic_pointer_cast<v0::Constant>(weight_int4pack_mm->get_input_node_shared_ptr(3));
+        const auto two = v0::Constant::create(element::i32, Shape{}, {2});
+        const auto split = std::make_shared<v1::Split>(sz_const, two, 2);
+        OutputVector outputs_split(split->get_output_size());
+        OutputVector split_inputs(2);
+        split_inputs[0] = sz_const->outputs()[0];
+        split_inputs[1] = two->outputs()[0];
+        if (!split->constant_fold(outputs_split, split_inputs)) {
+            return false;
+        }
+
+        const auto divide = std::make_shared<v1::Divide>(outputs_split[1], outputs_split[0]);
+        OutputVector outputs_divide(divide->get_output_size());
+        OutputVector divide_inputs(2);
+        divide_inputs[0] = outputs_split[1];
+        divide_inputs[1] = outputs_split[0];
+        if (!divide->constant_fold(outputs_divide, divide_inputs)) {
+            return false;
+        }
+
+        const auto eight = v0::Constant::create(sz_const->get_element_type(), Shape{}, {8.0});
+        const auto subtract = std::make_shared<v1::Subtract>(eight, outputs_divide[0]);
+        OutputVector outputs_subtract(subtract->get_output_size());
+        OutputVector subtract_inputs(2);
+        subtract_inputs[0] = eight->outputs()[0];
+        subtract_inputs[1] = outputs_divide[0];
+        if (!subtract->constant_fold(outputs_subtract, subtract_inputs)) {
+            return false;
+        }
+
+        auto new_scales = std::make_shared<v0::Constant>(
+            outputs_split[0].get_element_type(),
+            Shape({outputs_split[0].get_shape()[0], 1, outputs_split[0].get_shape()[1]}),
+            std::dynamic_pointer_cast<v0::Constant>(outputs_split[0].get_node_shared_ptr())->get_data_ptr());
+        auto new_zeros = std::make_shared<v0::Constant>(
+            outputs_subtract[0].get_element_type(),
+            Shape({outputs_subtract[0].get_shape()[0], 1, outputs_subtract[0].get_shape()[1]}),
+            std::dynamic_pointer_cast<v0::Constant>(outputs_subtract[0].get_node_shared_ptr())->get_data_ptr());
+
+        auto convert_scales_to_float = std::make_shared<v0::Convert>(new_scales, element::f32);
+        OutputVector outputs_scales_to_float(convert_scales_to_float->get_output_size());
+        if (!convert_scales_to_float->constant_fold(outputs_scales_to_float, new_scales->outputs())) {
+            return false;
+        }
+
+        auto convert_zeros_to_float = std::make_shared<v0::Convert>(new_zeros, element::f32);
+        OutputVector outputs_zeros_to_float(convert_zeros_to_float->get_output_size());
+        if (!convert_zeros_to_float->constant_fold(outputs_zeros_to_float, new_zeros->outputs())) {
+            return false;
+        }
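+
+        // Assemble the decompression subgraph that remains in the runtime graph: the u4 weight
+        // is converted to the zero-point type, shifted by the zero point, scaled per group, and
+        // reshaped back to 2D before feeding a regular MatMul.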
+        auto new_convert =
+            std::make_shared<v0::Convert>(outputs_to_u4[0].get_node_shared_ptr(), new_zeros->get_element_type());
+        auto new_subtract = std::make_shared<v1::Subtract>(new_convert, new_zeros);
+        auto new_mult = std::make_shared<v1::Multiply>(new_subtract, new_scales);
+        auto new_shape = v0::Constant::create(element::i32,
+                                              Shape{outputs_transpose[0].get_shape().size()},
+                                              outputs_transpose[0].get_shape());
+        auto new_reshape = std::make_shared<v1::Reshape>(new_mult, new_shape, false);
+        auto new_matmul = std::make_shared<v0::MatMul>(weight_int4pack_mm->get_input_node_shared_ptr(0), new_reshape);
+
+        replace_node(weight_int4pack_mm, new_matmul);
+
+        return true;
+    };
+
+    auto m = std::make_shared<Matcher>(weight_int4pack_mm, "ov::frontend::pytorch::pass::WeightINT4PackMMReplacer");
+    this->register_matcher(m, callback);
+};
+
+}  // namespace pass
+}  // namespace pytorch
+}  // namespace frontend
+}  // namespace ov
diff --git a/src/frontends/pytorch/src/transforms/torchfx_torchao_pattern_replacer.hpp b/src/frontends/pytorch/src/transforms/torchfx_torchao_pattern_replacer.hpp
new file mode 100644
index 00000000000000..e6bdb9453627f2
--- /dev/null
+++ b/src/frontends/pytorch/src/transforms/torchfx_torchao_pattern_replacer.hpp
@@ -0,0 +1,26 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include "openvino/pass/graph_rewrite.hpp"
+#include "openvino/pass/pass.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+namespace pass {
+
+// This transformation replaces the aten._weight_int4pack_mm op with a decompression
+// pattern which can be captured by OpenVINO device plugins.
+class WeightINT4PackMMReplacer : public ov::pass::MatcherPass {
+public:
+    OPENVINO_RTTI("ov::frontend::pytorch::pass::WeightINT4PackMMReplacer");
+    WeightINT4PackMMReplacer();
+};
+
+}  // namespace pass
+}  // namespace pytorch
+}  // namespace frontend
+}  // namespace ov
diff --git a/tests/layer_tests/pytorch_tests/test_weight_int4pack_mm.py b/tests/layer_tests/pytorch_tests/test_weight_int4pack_mm.py
new file mode 100644
index 00000000000000..bb68f306f33396
--- /dev/null
+++ b/tests/layer_tests/pytorch_tests/test_weight_int4pack_mm.py
@@ -0,0 +1,71 @@
+import numpy as np
+import pytest
+import torch
+from packaging import version
+
+from pytorch_layer_test_class import PytorchLayerTest
+
+
+class TestWeightInt4PackMMOperation(PytorchLayerTest):
+    def _prepare_input(self, input_size):
+        return (torch.rand(1, input_size, dtype=torch.float32),)
+
+    def create_model(self, input_size, group_size, inner_k_tiles):
+        class CustomWeightInt4PackMMOperation(torch.nn.Module):
+            def __init__(self):
+                super(CustomWeightInt4PackMMOperation, self).__init__()
+                # TODO: Generating random quantized weights, scales, and zero points is not ideal.
+                # An actual quantization algorithm could be implemented for more accurate testing.
+                self.gs = group_size
+                w = torch.randint(low=0, high=15, size=[input_size, input_size], dtype=torch.int32)
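+                # _convert_weight_to_int4pack packs the int32 weight into the tiled int4 layout
+                # consumed by _weight_int4pack_mm; the scales and zero values are concatenated
+                # into the [input_size / group_size, input_size, 2] scales-and-zeros tensor the
+                # op expects as its last argument.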
+                self.wq = torch.ops.aten._convert_weight_to_int4pack(w, inner_k_tiles)
+                scales = torch.randint(low=1, high=100, size=[int(input_size / self.gs), input_size, 1],
+                                       dtype=torch.int32).to(dtype=torch.bfloat16) / 10.0
+                zeros = torch.ones(int(input_size / self.gs), input_size, 1, dtype=torch.bfloat16) * 8.0
+                self.sz = torch.cat((scales, zeros), 2)
+
+            def forward(self, x):
+                return torch.ops.aten._weight_int4pack_mm(x.to(dtype=torch.bfloat16), self.wq, self.gs,
+                                                          self.sz).to(dtype=torch.float32)
+
+        model_class = CustomWeightInt4PackMMOperation()
+        ref_net = None
+        return model_class, ref_net, "aten._weight_int4pack_mm.default"
+
+    @pytest.mark.precommit_fx_backend
+    @pytest.mark.parametrize("input_size, group_size, inner_k_tiles, dtype", [
+        (1024, 32, 2, torch.float32),
+        (1024, 32, 4, torch.float32),
+        (1024, 32, 8, torch.float32),
+        (4096, 32, 2, torch.float32),
+        (4096, 32, 4, torch.float32),
+        (4096, 32, 8, torch.float32),
+        (4096, 64, 2, torch.float32),
+        (4096, 64, 4, torch.float32),
+        (4096, 64, 8, torch.float32),
+    ])
+    def test_weight_int4pack_mm_operation(self, input_size, group_size, inner_k_tiles, dtype,
+                                          ie_device, precision, ir_version):
+        # TODO: Input requirements for aten._convert_weight_to_int4pack changed after PyTorch 2.4,
+        # which was used to prepare the weight input for this test. Disable the test for PyTorch
+        # versions later than 2.4 until this issue is resolved.
+        if version.parse(torch.__version__) >= version.parse("2.5"):
+            pytest.skip("This test is not supported on PyTorch versions later than 2.4 due to "
+                        "weight handling updates in aten._convert_weight_to_int4pack.")
+
+        # Due to precision errors, the output accuracy may vary depending on the system this test
+        # is running on. The eps is adjusted accordingly; overall model accuracy should also be
+        # checked in full model tests.
+        self._test(
+            *self.create_model(input_size, group_size, inner_k_tiles),
+            ie_device,
+            precision,
+            ir_version,
+            kwargs_to_prepare_input={"input_size": input_size},
+            aot_autograd=True,
+            custom_eps=128.0,
+        )