Torchfx weight int4pack mm initial support #28391

Open · wants to merge 12 commits into base: master
Changes from 9 commits
@@ -118,7 +118,10 @@ def get_aot_decomposition_list():


def get_inf_decomposition_list():
-    return [torch.ops.aten.nll_loss_forward.default]
+    return [
+        torch.ops.aten._unsafe_index,
+        torch.ops.aten.nll_loss_forward.default,
+    ]


def get_export_decomposition_list():
@@ -51,6 +51,7 @@ def __init__(self, options):
"torch.ops.aten._softmax.default": None,
"torch.ops.aten._to_copy.default": None,
"torch.ops.aten._unsafe_view.default": None,
"torch.ops.aten._weight_int4pack_mm.default": None,
"torch.ops.aten.abs.default": None,
"torch.ops.aten.acos.default": None,
"torch.ops.aten.acosh.default": None,
2 changes: 2 additions & 0 deletions src/frontends/pytorch/src/frontend.cpp
@@ -44,6 +44,7 @@
#include "transforms/softmax_reshape_elimination.hpp"
#include "transforms/string_equality_replacer.hpp"
#include "transforms/torchfx_gptq_pattern_replacer.hpp"
#include "transforms/torchfx_torchao_pattern_replacer.hpp"
#include "transforms/tuple_unpack_replacer.hpp"
#include "transforms/u4_block_repack.hpp"
#include "translate_session.hpp"
@@ -258,6 +259,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
ov::pass::Manager manager("Frontend:Pytorch:normalize::fx_gptq");
manager.register_pass<ov::frontend::pytorch::pass::GPTQDecompressionReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::GPTQMultPatternReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::WeightINT4PackMMReplacer>();
manager.run_passes(model);
}

@@ -0,0 +1,192 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "torchfx_torchao_pattern_replacer.hpp"

#include "openvino/core/rt_info.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/divide.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/split.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/op/util/framework_node.hpp"
#include "openvino/pass/pattern/matcher.hpp"
#include "openvino/pass/pattern/op/or.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "utils.hpp"
#include "utils_quantize.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace pass {

using namespace ov::op;
using namespace ov::pass::pattern;

WeightINT4PackMMReplacer::WeightINT4PackMMReplacer() {
const auto& const_1 = wrap_type<v0::Constant>();
const auto& const_2 = wrap_type<v0::Constant>();
const auto& const_3 = wrap_type<v0::Constant>();

const auto& weight_int4pack_mm = wrap_type<ov::op::util::FrameworkNode>({any_input(), const_1, const_2, const_3});

ov::matcher_pass_callback callback = [=](Matcher& m) {
auto weight_int4pack_mm = m.get_match_root();
if (!weight_int4pack_mm) {
return false;
}
if (!(weight_int4pack_mm = cast_fw_node(m.get_match_root(), "aten._weight_int4pack_mm.default"))) {
return false;
}
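// Unpack the weight constant produced by aten._convert_weight_to_int4pack: flatten it, then
// extract the eight 4-bit values packed into each int32 element.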
auto wt_const = std::dynamic_pointer_cast<v0::Constant>(weight_int4pack_mm->get_input_node_shared_ptr(1));
auto wt_const_flat = std::make_shared<v0::Constant>(wt_const->get_element_type(),
Shape({shape_size(wt_const->get_shape()), 1}),
wt_const->get_data_ptr<uint8_t>());
std::vector<uint64_t> broadcast_shape_vec(wt_const_flat->get_shape());
broadcast_shape_vec[1] = 8;
auto broadcast_shape_const = std::make_shared<v0::Constant>(element::i32, Shape({2}), broadcast_shape_vec);

auto broadcast = std::make_shared<v3::Broadcast>(wt_const_flat, broadcast_shape_const);
OutputVector outputs_broadcast(broadcast->get_output_size());
OutputVector broadcast_inputs(2);
broadcast_inputs[0] = wt_const_flat->outputs()[0];
broadcast_inputs[1] = broadcast_shape_const->outputs()[0];
if (!broadcast->constant_fold(outputs_broadcast, broadcast_inputs)) {
return false;
}

auto broadcast_out_const = std::dynamic_pointer_cast<v0::Constant>(outputs_broadcast[0].get_node_shared_ptr());
auto broadcast_out_ptr =
const_cast<int32_t*>(reinterpret_cast<const int32_t*>(broadcast_out_const->get_data_ptr()));
for (size_t k = 0; k < shape_size(outputs_broadcast[0].get_shape()); ++k) {
int32_t shift_val = (k % 8) * 4;
broadcast_out_ptr[k] = (broadcast_out_ptr[k] >> shift_val) & 15;
}
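// Rearrange the unpacked 4-bit values from the packed tile layout into a plain 2-D [N, K] weight.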
std::vector<uint64_t> wt_ordered_shape(2);
wt_ordered_shape[0] = wt_const->get_shape()[0] * 8;
wt_ordered_shape[1] = wt_const->get_shape()[1] * wt_const->get_shape()[2] * wt_const->get_shape()[3];
auto wt_const_ordered = std::make_shared<v0::Constant>(wt_const->get_element_type(), Shape(wt_ordered_shape));
auto wt_ordered_ptr = const_cast<int32_t*>(reinterpret_cast<const int32_t*>(wt_const_ordered->get_data_ptr()));
for (uint64_t b = 0; b < wt_ordered_shape[0] / 64; b++) {
for (uint64_t j = 0; j < wt_ordered_shape[1]; j++) {
for (uint64_t i = 0; i < 32; i++) {
uint64_t l = 0;
uint64_t m = (i * 2);
l = b * 64 * (broadcast_out_const->get_shape()[0] / wt_ordered_shape[0]) + j * 8 +
(m / broadcast_out_const->get_shape()[1]);
m = m % broadcast_out_const->get_shape()[1];
wt_ordered_ptr[(b * 64 + i) * wt_ordered_shape[1] + j] = broadcast_out_ptr[l * 8 + m];
wt_ordered_ptr[(b * 64 + i + 32) * wt_ordered_shape[1] + j] = broadcast_out_ptr[l * 8 + m + 1];
}
}
}

auto transpose_order = v0::Constant::create(element::i32, Shape{2}, {1, 0});
auto transpose = std::make_shared<v1::Transpose>(wt_const_ordered, transpose_order);
OutputVector outputs_transpose(transpose->get_output_size());
OutputVector transpose_inputs(2);
transpose_inputs[0] = wt_const_ordered->outputs()[0];
transpose_inputs[1] = transpose_order->outputs()[0];
if (!transpose->constant_fold(outputs_transpose, transpose_inputs)) {
return false;
}
auto g_const = std::dynamic_pointer_cast<v0::Constant>(weight_int4pack_mm->get_input_node_shared_ptr(2));
auto g_ptr = const_cast<int64_t*>(reinterpret_cast<const int64_t*>(g_const->get_data_ptr()));
uint64_t g = (uint64_t)(g_ptr[0]);
if (g > outputs_transpose[0].get_shape()[1]) {
g = outputs_transpose[0].get_shape()[1];
}
auto wt_i32 = std::make_shared<v0::Constant>(
outputs_transpose[0].get_node_shared_ptr()->get_element_type(),
Shape({static_cast<unsigned long>(outputs_transpose[0].get_node_shared_ptr()->get_shape()[0] / g),
static_cast<unsigned long>(g),
outputs_transpose[0].get_node_shared_ptr()->get_shape()[1]}),
std::dynamic_pointer_cast<v0::Constant>(outputs_transpose[0].get_node_shared_ptr())
->get_data_ptr<uint8_t>());

auto convert_to_u4 = std::make_shared<v0::Convert>(wt_i32, element::u4);
OutputVector outputs_to_u4(convert_to_u4->get_output_size());
if (!convert_to_u4->constant_fold(outputs_to_u4, wt_i32->outputs())) {
return false;
}

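// The scales-and-zeros input stacks scales and zero terms along its last axis; split them and
// derive the zero points used for dequantization as 8 - (stored_zero / scale).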
auto sz_const = std::dynamic_pointer_cast<v0::Constant>(weight_int4pack_mm->get_input_node_shared_ptr(3));
const auto two = v0::Constant::create(element::i32, Shape{}, {2});
const auto split = std::make_shared<v1::Split>(sz_const, two, 2);
OutputVector outputs_split(split->get_output_size());
OutputVector split_inputs(2);
split_inputs[0] = sz_const->outputs()[0];
split_inputs[1] = two->outputs()[0];
if (!split->constant_fold(outputs_split, split_inputs)) {
return false;
}

const auto divide = std::make_shared<v1::Divide>(outputs_split[1], outputs_split[0]);
OutputVector outputs_divide(divide->get_output_size());
OutputVector divide_inputs(2);
divide_inputs[0] = outputs_split[1];
divide_inputs[1] = outputs_split[0];
if (!divide->constant_fold(outputs_divide, divide_inputs)) {
return false;
}

const auto eight = v0::Constant::create(sz_const->get_element_type(), Shape{}, {8.0});
const auto subtract = std::make_shared<v1::Subtract>(eight, outputs_divide[0]);
OutputVector outputs_subtract(subtract->get_output_size());
OutputVector subtract_inputs(2);
subtract_inputs[0] = eight->outputs()[0];
subtract_inputs[1] = outputs_divide[0];
if (!subtract->constant_fold(outputs_subtract, subtract_inputs)) {
return false;
}

auto new_scales = std::make_shared<v0::Constant>(
outputs_split[0].get_element_type(),
Shape({outputs_split[0].get_shape()[0], 1, outputs_split[0].get_shape()[1]}),
std::dynamic_pointer_cast<v0::Constant>(outputs_split[0].get_node_shared_ptr())->get_data_ptr<uint8_t>());
auto new_zeros = std::make_shared<v0::Constant>(
outputs_subtract[0].get_element_type(),
Shape({outputs_subtract[0].get_shape()[0], 1, outputs_subtract[0].get_shape()[1]}),
std::dynamic_pointer_cast<v0::Constant>(outputs_subtract[0].get_node_shared_ptr())
->get_data_ptr<uint8_t>());

auto convert_scales_to_float = std::make_shared<v0::Convert>(new_scales, element::f32);
OutputVector outputs_scales_to_float(convert_scales_to_float->get_output_size());
if (!convert_scales_to_float->constant_fold(outputs_scales_to_float, new_scales->outputs())) {
return false;
}

auto convert_zeros_to_float = std::make_shared<v0::Convert>(new_zeros, element::f32);
OutputVector outputs_zeros_to_float(convert_zeros_to_float->get_output_size());
if (!convert_zeros_to_float->constant_fold(outputs_zeros_to_float, new_zeros->outputs())) {
return false;
}

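// Assemble the decompression subgraph: convert the u4 weight, subtract the zero points,
// multiply by the scales, reshape back to 2-D, and feed the result into a regular MatMul.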
auto new_convert =
std::make_shared<v0::Convert>(outputs_to_u4[0].get_node_shared_ptr(), new_zeros->get_element_type());
auto new_subtract = std::make_shared<v1::Subtract>(new_convert, new_zeros);
auto new_mult = std::make_shared<v1::Multiply>(new_subtract, new_scales);
auto new_shape = v0::Constant::create(element::i32,
Shape{outputs_transpose[0].get_shape().size()},
outputs_transpose[0].get_shape());
auto new_reshape = std::make_shared<v1::Reshape>(new_mult, new_shape, false);
auto new_matmul = std::make_shared<v0::MatMul>(weight_int4pack_mm->get_input_node_shared_ptr(0), new_reshape);

replace_node(weight_int4pack_mm, new_matmul);

return true;
};

auto m = std::make_shared<Matcher>(weight_int4pack_mm, "ov::frontend::pytorch::pass::WeightINT4PackMMReplacer");
this->register_matcher(m, callback);
};

} // namespace pass
} // namespace pytorch
} // namespace frontend
} // namespace ov
@@ -0,0 +1,26 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace pass {

// This transformation replaces the aten._weight_int4pack_mm op with a decompression
// pattern that can be captured by OpenVINO device plugins (a rough numerical sketch
// of the pattern follows this file).
class WeightINT4PackMMReplacer : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::frontend::pytorch::pass::WeightINT4PackMMReplacer");
WeightINT4PackMMReplacer();
};

} // namespace pass
} // namespace pytorch
} // namespace frontend
} // namespace ov
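For illustration only (not part of this change): the decompression pattern emitted by the pass is, in plain NumPy terms, roughly the sketch below. The names and shapes are assumptions for readability: weight_u4 stands for the unpacked uint4 weight as a [K, N] array of values in [0, 15], and scales/zeros for the per-group constants the pass derives from the scales-and-zeros input, with group size g along K.

import numpy as np

def dequant_matmul(x, weight_u4, scales, zeros, g):
    # x: [M, K] activations; weight_u4: uint8 values in [0, 15], shape [K, N];
    # scales, zeros: per-group constants of shape [K // g, 1, N].
    K, N = weight_u4.shape
    w = weight_u4.reshape(K // g, g, N).astype(np.float32)
    w = (w - zeros) * scales   # Convert -> Subtract -> Multiply in the emitted graph
    w = w.reshape(K, N)        # Reshape back to a plain 2-D weight
    return x @ w               # ordinary MatMul replacing aten._weight_int4pack_mm

Keeping the weight as a u4 constant followed by Convert, Subtract and Multiply is what makes the pattern recognizable to OpenVINO device plugins, as the header comment above notes.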
58 changes: 58 additions & 0 deletions tests/layer_tests/pytorch_tests/test_weight_int4pack_mm.py
@@ -0,0 +1,58 @@
import numpy as np
import pytest
import torch
from packaging import version
from pytorch_layer_test_class import PytorchLayerTest

class TestWeightInt4PackMMOperation(PytorchLayerTest):
    def _prepare_input(self, input_size):
        return (torch.rand(1, input_size, dtype=torch.float32),)

    def create_model(self, input_size, group_size, inner_k_tiles):
        class CustomWeightInt4PackMMOperation(torch.nn.Module):
            def __init__(self):
                super(CustomWeightInt4PackMMOperation, self).__init__()
                # TODO: Generating random quantized weights, scales, and zero points is not ideal.
                # An actual quantization algorithm could be implemented for more accurate testing
                # (a rough sketch of such a reference quantizer follows this file's diff).
                self.gs = group_size
                w = torch.randint(low=0, high=15, size=[input_size, input_size], dtype=torch.int32)
                self.wq = torch.ops.aten._convert_weight_to_int4pack(w, inner_k_tiles)
                scales = torch.randint(low=1, high=100, size=[int(input_size / self.gs), input_size, 1],
                                        dtype=torch.int32).to(dtype=torch.bfloat16) / 10.0
                zeros = torch.ones(int(input_size / self.gs), input_size, 1, dtype=torch.bfloat16) * 8.0
                self.sz = torch.cat((scales, zeros), 2)

            def forward(self, x):
                return torch.ops.aten._weight_int4pack_mm(x.to(dtype=torch.bfloat16), self.wq, self.gs,
                                                          self.sz).to(dtype=torch.float32)

        model_class = CustomWeightInt4PackMMOperation()
        ref_net = None
        return model_class, ref_net, "aten._weight_int4pack_mm.default"

    @pytest.mark.precommit_fx_backend
    @pytest.mark.parametrize("input_size, group_size, inner_k_tiles, dtype", [
        (1024, 32, 2, torch.float32),
        (1024, 32, 4, torch.float32),
        (1024, 32, 8, torch.float32),
        (4096, 32, 2, torch.float32),
        (4096, 32, 4, torch.float32),
        (4096, 32, 8, torch.float32),
        (4096, 64, 2, torch.float32),
        (4096, 64, 4, torch.float32),
        (4096, 64, 8, torch.float32),
    ])
    def test_weight_int4pack_mm_operation(self, input_size, group_size, inner_k_tiles, dtype,
                                          ie_device, precision, ir_version):
        # TODO: The input requirements for aten._convert_weight_to_int4pack changed after PyTorch 2.4,
        # the version used to prepare the weight input for this test. The test is disabled for PyTorch
        # versions later than 2.4 until this is resolved.
        if version.parse(torch.__version__) >= version.parse("2.5"):
            pytest.skip("Test is not supported on PyTorch versions later than 2.4 due to weight "
                        "handling updates in aten._convert_weight_to_int4pack.")

        # Due to precision errors, the output accuracy may vary with the system this test is running on.
        # The eps is adjusted accordingly, but overall model accuracy should also be verified in full-model tests.
        self._test(
            *self.create_model(input_size, group_size, inner_k_tiles),
            ie_device,
            precision,
            ir_version,
            kwargs_to_prepare_input={"input_size": input_size},
            aot_autograd=True,
            custom_eps=128.0
        )
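For illustration only (not part of this change): the TODO in create_model notes that a real quantization algorithm would produce more realistic test data than random tensors. A minimal group-wise asymmetric int4 quantizer could look roughly like the sketch below; the function name, tensor layout and rounding choices are assumptions, not necessarily the convention aten._convert_weight_to_int4pack and aten._weight_int4pack_mm expect.

import torch

def groupwise_int4_quantize(w: torch.Tensor, group_size: int):
    # w: [K, N] float weight, quantized in groups of `group_size` along K.
    K, N = w.shape
    wg = w.reshape(K // group_size, group_size, N)
    w_min = wg.amin(dim=1, keepdim=True)
    w_max = wg.amax(dim=1, keepdim=True)
    scales = (w_max - w_min).clamp(min=1e-6) / 15.0   # map each group onto [0, 15]
    zeros = (-w_min / scales).round().clamp(0, 15)    # asymmetric zero point
    q = (wg / scales + zeros).round().clamp(0, 15).to(torch.int32)
    return q.reshape(K, N), scales.squeeze(1), zeros.squeeze(1)

In principle the quantized values could then be packed with torch.ops.aten._convert_weight_to_int4pack and the scales and zero points stacked into the scales-and-zeros tensor the way the test already builds it, replacing the random inputs.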