From b2a52fc9dfdea0dd3fbb04d329d51e72f5705ed6 Mon Sep 17 00:00:00 2001 From: Cavus Mustafa Date: Sat, 11 Jan 2025 17:39:01 -0800 Subject: [PATCH 1/9] TorchFX: Decomposition for aten._unsafe_index --- .../src/openvino/frontend/pytorch/torchdynamo/decompositions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/decompositions.py b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/decompositions.py index eb117f56ab167d..ddb7c5a9db519d 100644 --- a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/decompositions.py +++ b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/decompositions.py @@ -107,6 +107,7 @@ def get_aot_decomposition_list(): torch.ops.aten._scaled_dot_product_flash_attention.default, torch.ops.aten._softmax.default, torch.ops.aten._softmax_backward_data.default, + torch.ops.aten._unsafe_index, torch.ops.aten.convolution_backward.default, torch.ops.aten.gelu_backward.default, torch.ops.aten.native_group_norm.default, From 31a11ee54a6ce97b3c862f41421df41b8971a1b9 Mon Sep 17 00:00:00 2001 From: Cavus Mustafa Date: Sat, 11 Jan 2025 17:40:16 -0800 Subject: [PATCH 2/9] TorchFX: Initial support for aten._weight_int4pack_mm.default --- .../pytorch/torchdynamo/op_support.py | 1 + src/frontends/pytorch/src/frontend.cpp | 2 + .../torchfx_torchao_pattern_replacer.cpp | 199 ++++++++++++++++++ .../torchfx_torchao_pattern_replacer.hpp | 26 +++ 4 files changed, 228 insertions(+) create mode 100644 src/frontends/pytorch/src/transforms/torchfx_torchao_pattern_replacer.cpp create mode 100644 src/frontends/pytorch/src/transforms/torchfx_torchao_pattern_replacer.hpp diff --git a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py index b86ae857847a1f..12f8b497e48f2e 100644 --- a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py +++ b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py @@ -51,6 +51,7 @@ def __init__(self, options): "torch.ops.aten._softmax.default": None, "torch.ops.aten._to_copy.default": None, "torch.ops.aten._unsafe_view.default": None, + "torch.ops.aten._weight_int4pack_mm.default": None, "torch.ops.aten.abs.default": None, "torch.ops.aten.acos.default": None, "torch.ops.aten.acosh.default": None, diff --git a/src/frontends/pytorch/src/frontend.cpp b/src/frontends/pytorch/src/frontend.cpp index 2b0ab6db9d3a09..a806bd44ef3464 100644 --- a/src/frontends/pytorch/src/frontend.cpp +++ b/src/frontends/pytorch/src/frontend.cpp @@ -44,6 +44,7 @@ #include "transforms/softmax_reshape_elimination.hpp" #include "transforms/string_equality_replacer.hpp" #include "transforms/torchfx_gptq_pattern_replacer.hpp" +#include "transforms/torchfx_torchao_pattern_replacer.hpp" #include "transforms/tuple_unpack_replacer.hpp" #include "transforms/u4_block_repack.hpp" #include "translate_session.hpp" @@ -258,6 +259,7 @@ void FrontEnd::normalize(const std::shared_ptr& model) const { ov::pass::Manager manager("Frontend:Pytorch:normalize::fx_gptq"); manager.register_pass(); manager.register_pass(); + manager.register_pass(); manager.run_passes(model); } diff --git a/src/frontends/pytorch/src/transforms/torchfx_torchao_pattern_replacer.cpp b/src/frontends/pytorch/src/transforms/torchfx_torchao_pattern_replacer.cpp new file mode 100644 index 00000000000000..096a6db38fbd02 --- /dev/null +++ b/src/frontends/pytorch/src/transforms/torchfx_torchao_pattern_replacer.cpp @@ -0,0 +1,199 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "torchfx_torchao_pattern_replacer.hpp" + +#include "openvino/core/rt_info.hpp" +#include "openvino/op/broadcast.hpp" +#include "openvino/op/divide.hpp" +#include "openvino/op/matmul.hpp" +#include "openvino/op/multiply.hpp" +#include "openvino/op/reshape.hpp" +#include "openvino/op/split.hpp" +#include "openvino/op/subtract.hpp" +#include "openvino/op/transpose.hpp" +#include "openvino/op/util/framework_node.hpp" +#include "openvino/pass/pattern/matcher.hpp" +#include "openvino/pass/pattern/op/or.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "utils.hpp" +#include "utils_quantize.hpp" + +namespace ov { +namespace frontend { +namespace pytorch { +namespace pass { + +using namespace ov::op; +using namespace ov::pass::pattern; + +WeightINT4PackMMReplacer::WeightINT4PackMMReplacer() { + const auto& const_1 = wrap_type(); + const auto& const_2 = wrap_type(); + const auto& const_3 = wrap_type(); + + const auto& weight_int4pack_mm = wrap_type({any_input(), const_1, const_2, const_3}); + + ov::matcher_pass_callback callback = [=](Matcher& m) { + auto weight_int4pack_mm = m.get_match_root(); + if (!weight_int4pack_mm) { + return false; + } + const auto& pattern_map = m.get_pattern_value_map(); + if (!(weight_int4pack_mm = cast_fw_node(m.get_match_root(), "aten._weight_int4pack_mm.default"))) { + return false; + } + auto wt_const = std::dynamic_pointer_cast(weight_int4pack_mm->get_input_node_shared_ptr(1)); + auto wt_const_flat = std::make_shared(wt_const->get_element_type(), + Shape({shape_size(wt_const->get_shape()), 1}), + wt_const->get_data_ptr()); + std::vector broadcast_shape_vec(wt_const_flat->get_shape()); + broadcast_shape_vec[1] = 8; + auto broadcast_shape_const = std::make_shared(element::i32, Shape({2}), broadcast_shape_vec); + + auto broadcast = std::make_shared(wt_const_flat, broadcast_shape_const); + OutputVector outputs_broadcast(broadcast->get_output_size()); + OutputVector broadcast_inputs(2); + broadcast_inputs[0] = wt_const_flat->outputs()[0]; + broadcast_inputs[1] = broadcast_shape_const->outputs()[0]; + if (!broadcast->constant_fold(outputs_broadcast, broadcast_inputs)) { + return false; + } + + auto broadcast_out_const = std::dynamic_pointer_cast(outputs_broadcast[0].get_node_shared_ptr()); + auto broadcast_out_ptr = + const_cast(reinterpret_cast(broadcast_out_const->get_data_ptr())); + for (size_t k = 0; k < shape_size(outputs_broadcast[0].get_shape()); ++k) { + int32_t shift_val = (k % 8) * 4; + broadcast_out_ptr[k] = (broadcast_out_ptr[k] >> shift_val) & 15; + } + std::vector wt_ordered_shape(2); + wt_ordered_shape[0] = wt_const->get_shape()[0] * 8; + wt_ordered_shape[1] = wt_const->get_shape()[1] * wt_const->get_shape()[2] * wt_const->get_shape()[3]; + auto wt_const_ordered = std::make_shared(wt_const->get_element_type(), Shape(wt_ordered_shape)); + auto wt_ordered_ptr = const_cast(reinterpret_cast(wt_const_ordered->get_data_ptr())); + for (uint64_t b = 0; b < wt_ordered_shape[0] / 64; b++) { + for (uint64_t j = 0; j < wt_ordered_shape[1]; j++) { + for (uint64_t i = 0; i < 32; i++) { + uint64_t l = 0; + uint64_t m = (i * 2); + l = b * 64 * (broadcast_out_const->get_shape()[0] / wt_ordered_shape[0]) + j * 8 + + (m / broadcast_out_const->get_shape()[1]); + m = m % broadcast_out_const->get_shape()[1]; + wt_ordered_ptr[(b * 64 + i) * wt_ordered_shape[1] + j] = broadcast_out_ptr[l * 8 + m]; + wt_ordered_ptr[(b * 64 + i + 32) * wt_ordered_shape[1] + j] = broadcast_out_ptr[l * 8 + m + 1]; + } + } + } + + auto transpose_order = v0::Constant::create(element::i32, Shape{2}, {1, 0}); + auto transpose = std::make_shared(wt_const_ordered, transpose_order); + OutputVector outputs_transpose(transpose->get_output_size()); + OutputVector transpose_inputs(2); + transpose_inputs[0] = wt_const_ordered->outputs()[0]; + transpose_inputs[1] = transpose_order->outputs()[0]; + if (!transpose->constant_fold(outputs_transpose, transpose_inputs)) { + return false; + } + auto g_const = std::dynamic_pointer_cast(weight_int4pack_mm->get_input_node_shared_ptr(2)); + auto g_ptr = const_cast(reinterpret_cast(g_const->get_data_ptr())); + uint64_t g = (uint64_t)(g_ptr[0]); + if (g > outputs_transpose[0].get_shape()[1]) { + g = outputs_transpose[0].get_shape()[1]; + } + auto wt_i32 = std::make_shared( + outputs_transpose[0].get_node_shared_ptr()->get_element_type(), + Shape({outputs_transpose[0].get_node_shared_ptr()->get_shape()[0] / g, + g, + outputs_transpose[0].get_node_shared_ptr()->get_shape()[1]}), + std::dynamic_pointer_cast(outputs_transpose[0].get_node_shared_ptr()) + ->get_data_ptr()); + + auto wt_i32_ptr = const_cast(reinterpret_cast(wt_i32->get_data_ptr())); + + auto convert_to_u4 = std::make_shared(wt_i32, element::u4); + OutputVector outputs_to_u4(convert_to_u4->get_output_size()); + if (!convert_to_u4->constant_fold(outputs_to_u4, wt_i32->outputs())) { + return false; + } + + auto sz_const = std::dynamic_pointer_cast(weight_int4pack_mm->get_input_node_shared_ptr(3)); + const auto two = v0::Constant::create(element::i32, Shape{}, {2}); + const auto split = std::make_shared(sz_const, two, 2); + OutputVector outputs_split(split->get_output_size()); + OutputVector split_inputs(2); + split_inputs[0] = sz_const->outputs()[0]; + split_inputs[1] = two->outputs()[0]; + if (!split->constant_fold(outputs_split, split_inputs)) { + return false; + } + + const auto divide = std::make_shared(outputs_split[1], outputs_split[0]); + OutputVector outputs_divide(divide->get_output_size()); + OutputVector divide_inputs(2); + divide_inputs[0] = outputs_split[1]; + divide_inputs[1] = outputs_split[0]; + if (!divide->constant_fold(outputs_divide, divide_inputs)) { + return false; + } + + const auto eight = v0::Constant::create(sz_const->get_element_type(), Shape{}, {8.0}); + const auto subtract = std::make_shared(eight, outputs_divide[0]); + OutputVector outputs_subtract(subtract->get_output_size()); + OutputVector subtract_inputs(2); + subtract_inputs[0] = eight->outputs()[0]; + subtract_inputs[1] = outputs_divide[0]; + if (!subtract->constant_fold(outputs_subtract, subtract_inputs)) { + return false; + } + + auto new_scales = std::make_shared( + outputs_split[0].get_element_type(), + Shape({outputs_split[0].get_shape()[0], 1, outputs_split[0].get_shape()[1]}), + std::dynamic_pointer_cast(outputs_split[0].get_node_shared_ptr())->get_data_ptr()); + auto new_zeros = std::make_shared( + outputs_subtract[0].get_element_type(), + Shape({outputs_subtract[0].get_shape()[0], 1, outputs_subtract[0].get_shape()[1]}), + std::dynamic_pointer_cast(outputs_subtract[0].get_node_shared_ptr()) + ->get_data_ptr()); + + auto convert_scales_to_float = std::make_shared(new_scales, element::f32); + OutputVector outputs_scales_to_float(convert_scales_to_float->get_output_size()); + if (!convert_scales_to_float->constant_fold(outputs_scales_to_float, new_scales->outputs())) { + return false; + } + auto new_scales_ptr = const_cast(reinterpret_cast( + std::dynamic_pointer_cast(outputs_scales_to_float[0].get_node_shared_ptr())->get_data_ptr())); + + auto convert_zeros_to_float = std::make_shared(new_zeros, element::f32); + OutputVector outputs_zeros_to_float(convert_zeros_to_float->get_output_size()); + if (!convert_zeros_to_float->constant_fold(outputs_zeros_to_float, new_zeros->outputs())) { + return false; + } + auto new_zeros_ptr = const_cast(reinterpret_cast( + std::dynamic_pointer_cast(outputs_zeros_to_float[0].get_node_shared_ptr())->get_data_ptr())); + + auto new_convert = + std::make_shared(outputs_to_u4[0].get_node_shared_ptr(), new_zeros->get_element_type()); + auto new_subtract = std::make_shared(new_convert, new_zeros); + auto new_mult = std::make_shared(new_subtract, new_scales); + auto new_shape = v0::Constant::create(element::i32, + Shape{outputs_transpose[0].get_shape().size()}, + outputs_transpose[0].get_shape()); + auto new_reshape = std::make_shared(new_mult, new_shape, false); + auto new_matmul = std::make_shared(weight_int4pack_mm->get_input_node_shared_ptr(0), new_reshape); + + replace_node(weight_int4pack_mm, new_matmul); + + return true; + }; + + auto m = std::make_shared(weight_int4pack_mm, "ov::frontend::pytorch::pass::WeightINT4PackMMReplacer"); + this->register_matcher(m, callback); +}; + +} // namespace pass +} // namespace pytorch +} // namespace frontend +} // namespace ov diff --git a/src/frontends/pytorch/src/transforms/torchfx_torchao_pattern_replacer.hpp b/src/frontends/pytorch/src/transforms/torchfx_torchao_pattern_replacer.hpp new file mode 100644 index 00000000000000..e6bdb9453627f2 --- /dev/null +++ b/src/frontends/pytorch/src/transforms/torchfx_torchao_pattern_replacer.hpp @@ -0,0 +1,26 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/pass/graph_rewrite.hpp" +#include "openvino/pass/pass.hpp" + +namespace ov { +namespace frontend { +namespace pytorch { +namespace pass { + +// This transformation replaces aten._weight_int4pack_mm op with a decompression +// pattern which can be captured by OpenVINO device plugins +class WeightINT4PackMMReplacer : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("ov::frontend::pytorch::pass::WeightINT4PackMMReplacer"); + WeightINT4PackMMReplacer(); +}; + +} // namespace pass +} // namespace pytorch +} // namespace frontend +} // namespace ov From 91368bed1cfa314f456d017c9a1f25d3505dc6bd Mon Sep 17 00:00:00 2001 From: Cavus Mustafa Date: Sat, 11 Jan 2025 17:40:44 -0800 Subject: [PATCH 3/9] TorchFX: added layer test for aten._weight_int4pack_mm.default --- .../pytorch_tests/test_weight_int4pack_mm.py | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 tests/layer_tests/pytorch_tests/test_weight_int4pack_mm.py diff --git a/tests/layer_tests/pytorch_tests/test_weight_int4pack_mm.py b/tests/layer_tests/pytorch_tests/test_weight_int4pack_mm.py new file mode 100644 index 00000000000000..ff9c2e1bd92d62 --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_weight_int4pack_mm.py @@ -0,0 +1,52 @@ +import numpy as np +import pytest +import torch +from pytorch_layer_test_class import PytorchLayerTest + +class TestWeightInt4PackMMOperation(PytorchLayerTest): + def _prepare_input(self, input_size): + return (torch.rand(1,input_size,dtype=torch.float32),) + + def create_model(self, input_size, group_size, inner_k_tiles): + class CustomWeightInt4PackMMOperation(torch.nn.Module): + def __init__(self): + super(CustomWeightInt4PackMMOperation, self).__init__() + # TODO: Generating random quantized weights, scales, and zero points is not ideal. + # An actual quantizing algorithm could be implemented for more accurate testing. + self.gs = group_size + w = torch.randint(low=0,high=15,size=[input_size,input_size], dtype=torch.int32) + self.wq = torch.ops.aten._convert_weight_to_int4pack(w, inner_k_tiles) + scales = torch.randint(low=1,high=100,size=[int(input_size/self.gs),input_size,1], dtype=torch.int32).to(dtype=torch.bfloat16)/10.0 + zeros = torch.ones(int(input_size/self.gs),input_size,1, dtype=torch.bfloat16)*8.0 + self.sz = torch.cat((scales,zeros), 2) + def forward(self, x): + return torch.ops.aten._weight_int4pack_mm(x.to(dtype=torch.bfloat16), self.wq, self.gs, self.sz).to(dtype=torch.float32) + + model_class = CustomWeightInt4PackMMOperation() + ref_net = None + return model_class, ref_net, "aten._weight_int4pack_mm.default" + + @pytest.mark.precommit_fx_backend + @pytest.mark.parametrize("input_size, group_size, inner_k_tiles, dtype", [ + (1024, 32, 2, torch.float32), + (1024, 32, 4, torch.float32), + (1024, 32, 8, torch.float32), + (4096, 32, 2, torch.float32), + (4096, 32, 4, torch.float32), + (4096, 32, 8, torch.float32), + (4096, 64, 2, torch.float32), + (4096, 64, 4, torch.float32), + (4096, 64, 8, torch.float32), + ]) + def test_weight_int4pack_mm_operation(self, input_size, group_size, inner_k_tiles, dtype, ie_device, precision, ir_version): + # Due to precision errors, the output accuracy may change based on the system this test it running on. + # The eps is adjusted accordingly, but overall model accuracy should be observed in full model tests as well. + self._test( + *self.create_model(input_size, group_size, inner_k_tiles), + ie_device, + precision, + ir_version, + kwargs_to_prepare_input={"input_size": input_size}, + aot_autograd=True, + custom_eps=128.0 + ) From 5705512344c8a0269077c512e607ebee3086f083 Mon Sep 17 00:00:00 2001 From: Cavus Mustafa Date: Sat, 11 Jan 2025 18:29:17 -0800 Subject: [PATCH 4/9] Unused variables removed --- .../src/transforms/torchfx_torchao_pattern_replacer.cpp | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/frontends/pytorch/src/transforms/torchfx_torchao_pattern_replacer.cpp b/src/frontends/pytorch/src/transforms/torchfx_torchao_pattern_replacer.cpp index 096a6db38fbd02..c3469b649821cc 100644 --- a/src/frontends/pytorch/src/transforms/torchfx_torchao_pattern_replacer.cpp +++ b/src/frontends/pytorch/src/transforms/torchfx_torchao_pattern_replacer.cpp @@ -40,7 +40,6 @@ WeightINT4PackMMReplacer::WeightINT4PackMMReplacer() { if (!weight_int4pack_mm) { return false; } - const auto& pattern_map = m.get_pattern_value_map(); if (!(weight_int4pack_mm = cast_fw_node(m.get_match_root(), "aten._weight_int4pack_mm.default"))) { return false; } @@ -110,8 +109,6 @@ WeightINT4PackMMReplacer::WeightINT4PackMMReplacer() { std::dynamic_pointer_cast(outputs_transpose[0].get_node_shared_ptr()) ->get_data_ptr()); - auto wt_i32_ptr = const_cast(reinterpret_cast(wt_i32->get_data_ptr())); - auto convert_to_u4 = std::make_shared(wt_i32, element::u4); OutputVector outputs_to_u4(convert_to_u4->get_output_size()); if (!convert_to_u4->constant_fold(outputs_to_u4, wt_i32->outputs())) { @@ -163,16 +160,12 @@ WeightINT4PackMMReplacer::WeightINT4PackMMReplacer() { if (!convert_scales_to_float->constant_fold(outputs_scales_to_float, new_scales->outputs())) { return false; } - auto new_scales_ptr = const_cast(reinterpret_cast( - std::dynamic_pointer_cast(outputs_scales_to_float[0].get_node_shared_ptr())->get_data_ptr())); auto convert_zeros_to_float = std::make_shared(new_zeros, element::f32); OutputVector outputs_zeros_to_float(convert_zeros_to_float->get_output_size()); if (!convert_zeros_to_float->constant_fold(outputs_zeros_to_float, new_zeros->outputs())) { return false; } - auto new_zeros_ptr = const_cast(reinterpret_cast( - std::dynamic_pointer_cast(outputs_zeros_to_float[0].get_node_shared_ptr())->get_data_ptr())); auto new_convert = std::make_shared(outputs_to_u4[0].get_node_shared_ptr(), new_zeros->get_element_type()); From a91da9e73e08b9f6c0c3f94db5cdfb2f19c13f8a Mon Sep 17 00:00:00 2001 From: Cavus Mustafa Date: Sat, 11 Jan 2025 22:20:21 -0800 Subject: [PATCH 5/9] Disable aten.weight_int4pack_mm test for PyTorch versions later than 2.4 --- tests/layer_tests/pytorch_tests/test_weight_int4pack_mm.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/layer_tests/pytorch_tests/test_weight_int4pack_mm.py b/tests/layer_tests/pytorch_tests/test_weight_int4pack_mm.py index ff9c2e1bd92d62..bb68f306f33396 100644 --- a/tests/layer_tests/pytorch_tests/test_weight_int4pack_mm.py +++ b/tests/layer_tests/pytorch_tests/test_weight_int4pack_mm.py @@ -1,6 +1,7 @@ import numpy as np import pytest import torch +from packaging import version from pytorch_layer_test_class import PytorchLayerTest class TestWeightInt4PackMMOperation(PytorchLayerTest): @@ -39,6 +40,11 @@ def forward(self, x): (4096, 64, 8, torch.float32), ]) def test_weight_int4pack_mm_operation(self, input_size, group_size, inner_k_tiles, dtype, ie_device, precision, ir_version): + # TODO: Input requirements for aten._convert_weight_to_int4pack changed after PyTorch 2.4 which was used to prepare + # weight input for this test. Disabling the test for PyTorch versions later than 2.4 until this issue is resolved. + if version.parse(torch.__version__) >= version.parse("2.5"): + pytest.skip("Current test is not supported PyTorch versions later than 2.4 due to weight handling updates in aten._convert_weight_to_int4pack.") + # Due to precision errors, the output accuracy may change based on the system this test it running on. # The eps is adjusted accordingly, but overall model accuracy should be observed in full model tests as well. self._test( From 0de3fee040b428f1c2e6e4f14ce6830cdbe6945d Mon Sep 17 00:00:00 2001 From: Cavus Mustafa Date: Mon, 13 Jan 2025 13:29:36 -0800 Subject: [PATCH 6/9] Moved aten._unsafe_index decomposition into get_inf_decomposition_list --- .../openvino/frontend/pytorch/torchdynamo/decompositions.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/decompositions.py b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/decompositions.py index ddb7c5a9db519d..18d2f09ed1e7ee 100644 --- a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/decompositions.py +++ b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/decompositions.py @@ -107,7 +107,6 @@ def get_aot_decomposition_list(): torch.ops.aten._scaled_dot_product_flash_attention.default, torch.ops.aten._softmax.default, torch.ops.aten._softmax_backward_data.default, - torch.ops.aten._unsafe_index, torch.ops.aten.convolution_backward.default, torch.ops.aten.gelu_backward.default, torch.ops.aten.native_group_norm.default, @@ -119,7 +118,10 @@ def get_aot_decomposition_list(): def get_inf_decomposition_list(): - return [torch.ops.aten.nll_loss_forward.default] + return [ + torch.ops.aten._unsafe_index, + torch.ops.aten.nll_loss_forward.default, + ] def get_export_decomposition_list(): From 641d1920d97617d1325c41372c84c24840241c43 Mon Sep 17 00:00:00 2001 From: Cavus Mustafa Date: Mon, 13 Jan 2025 14:26:31 -0800 Subject: [PATCH 7/9] Type casting fix --- .../src/transforms/torchfx_torchao_pattern_replacer.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/frontends/pytorch/src/transforms/torchfx_torchao_pattern_replacer.cpp b/src/frontends/pytorch/src/transforms/torchfx_torchao_pattern_replacer.cpp index c3469b649821cc..6bda18ff34e99f 100644 --- a/src/frontends/pytorch/src/transforms/torchfx_torchao_pattern_replacer.cpp +++ b/src/frontends/pytorch/src/transforms/torchfx_torchao_pattern_replacer.cpp @@ -103,8 +103,8 @@ WeightINT4PackMMReplacer::WeightINT4PackMMReplacer() { } auto wt_i32 = std::make_shared( outputs_transpose[0].get_node_shared_ptr()->get_element_type(), - Shape({outputs_transpose[0].get_node_shared_ptr()->get_shape()[0] / g, - g, + Shape({static_cast(outputs_transpose[0].get_node_shared_ptr()->get_shape()[0] / g), + static_cast(g), outputs_transpose[0].get_node_shared_ptr()->get_shape()[1]}), std::dynamic_pointer_cast(outputs_transpose[0].get_node_shared_ptr()) ->get_data_ptr()); From b765bb25a852eb3461c44787ea887521ea52b54d Mon Sep 17 00:00:00 2001 From: Cavus Mustafa Date: Fri, 17 Jan 2025 09:45:56 -0800 Subject: [PATCH 8/9] Data type update as some compilers fail building --- .../pytorch/src/transforms/torchfx_torchao_pattern_replacer.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/frontends/pytorch/src/transforms/torchfx_torchao_pattern_replacer.cpp b/src/frontends/pytorch/src/transforms/torchfx_torchao_pattern_replacer.cpp index 6bda18ff34e99f..ff1b7f471bf6e5 100644 --- a/src/frontends/pytorch/src/transforms/torchfx_torchao_pattern_replacer.cpp +++ b/src/frontends/pytorch/src/transforms/torchfx_torchao_pattern_replacer.cpp @@ -47,7 +47,7 @@ WeightINT4PackMMReplacer::WeightINT4PackMMReplacer() { auto wt_const_flat = std::make_shared(wt_const->get_element_type(), Shape({shape_size(wt_const->get_shape()), 1}), wt_const->get_data_ptr()); - std::vector broadcast_shape_vec(wt_const_flat->get_shape()); + std::vector broadcast_shape_vec(wt_const_flat->get_shape()); broadcast_shape_vec[1] = 8; auto broadcast_shape_const = std::make_shared(element::i32, Shape({2}), broadcast_shape_vec); From a81b93f70819460e6d19f107efaaac94c14e1f8d Mon Sep 17 00:00:00 2001 From: Cavus Mustafa Date: Fri, 17 Jan 2025 10:10:52 -0800 Subject: [PATCH 9/9] Data type fix as some compiler fail building --- .../pytorch/src/transforms/torchfx_torchao_pattern_replacer.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/frontends/pytorch/src/transforms/torchfx_torchao_pattern_replacer.cpp b/src/frontends/pytorch/src/transforms/torchfx_torchao_pattern_replacer.cpp index ff1b7f471bf6e5..2dd00a6135329d 100644 --- a/src/frontends/pytorch/src/transforms/torchfx_torchao_pattern_replacer.cpp +++ b/src/frontends/pytorch/src/transforms/torchfx_torchao_pattern_replacer.cpp @@ -67,7 +67,7 @@ WeightINT4PackMMReplacer::WeightINT4PackMMReplacer() { int32_t shift_val = (k % 8) * 4; broadcast_out_ptr[k] = (broadcast_out_ptr[k] >> shift_val) & 15; } - std::vector wt_ordered_shape(2); + std::vector wt_ordered_shape(2); wt_ordered_shape[0] = wt_const->get_shape()[0] * 8; wt_ordered_shape[1] = wt_const->get_shape()[1] * wt_const->get_shape()[2] * wt_const->get_shape()[3]; auto wt_const_ordered = std::make_shared(wt_const->get_element_type(), Shape(wt_ordered_shape));