Revert changes for gpt-j
itikhono committed Jan 15, 2025
1 parent c8ef5ba commit 97c1d0f
Showing 2 changed files with 9 additions and 99 deletions.
@@ -397,11 +397,9 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() {
     auto varsplit = makePattern<opset1::VariadicSplit>({gather_sin_cos, -1, {ndims / 2, -1}});
     varsplit->set_output_size(2);
     // Reshape or UnSqueeze should both be support
-    auto dim0 = ov::gen_pattern::Symbol("dim0");
-    auto dim1 = ov::gen_pattern::Symbol("dim1");
-    auto unsqueeze_sin = makePattern<opset1::Reshape>({varsplit->output(0), {dim0, dim1, 1, 32}}) |
+    auto unsqueeze_sin = makePattern<opset1::Reshape>({varsplit->output(0), {1, -1, 1, 32}}) |
                          makePattern<opset1::Unsqueeze>({varsplit->output(0), 2});
-    auto unsqueeze_cos = makePattern<opset1::Reshape>({varsplit->output(1), {dim0, dim1, 1, 32}}) |
+    auto unsqueeze_cos = makePattern<opset1::Reshape>({varsplit->output(1), {1, -1, 1, 32}}) |
                          makePattern<opset1::Unsqueeze>({varsplit->output(1), 2});
     // repeate cos/sin table
     auto const_idx = makeConst(ov::element::i32, ov::PartialShape::dynamic(), [](const ov::op::v0::Constant& node) {
@@ -421,17 +419,10 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() {
 
     auto neg_Multiply_1177 = makePattern<opset1::Multiply>({slice_Slice_1174, -1.0f}, {{"auto_broadcast", "numpy"}});
     auto Unsqueeze_65524 = makePattern<opset1::Unsqueeze>({neg_Multiply_1177, -1});
-    auto head_num = ov::gen_pattern::Symbol("head_num");
-    auto Unsqueeze_28998 =
-        makePattern<opset1::Reshape>({neg_Multiply_1177, {-1, 1, head_num, 32, 1}}, {{"special_zero", false}});
 
     auto slice_Slice_1168 = GenSlice(slice_Slice_965 | varsplit_view_Reshape->output(0), 0, int32_max, 2, 3);
     auto Unsqueeze_65525 = makePattern<opset1::Unsqueeze>({slice_Slice_1168, -1});
-    auto Unsqueeze_28999 =
-        makePattern<opset1::Reshape>({slice_Slice_1168, {-1, 1, head_num, 32, 1}}, {{"special_zero", false}});
-    auto stack_1182 =
-        makePattern<opset1::Concat>({Unsqueeze_28998 | Unsqueeze_65524, Unsqueeze_65525 | Unsqueeze_28999},
-                                    {{"axis", -1}});
+    auto stack_1182 = makePattern<opset1::Concat>({Unsqueeze_65524, Unsqueeze_65525}, {{"axis", -1}});
 
     auto ShapeOf_169068 = makePattern<opset1::ShapeOf>({stack_1182});
     auto flatten_Slice_1194 = GenSlice(ShapeOf_169068, 0, 3, 1, 0);
@@ -456,7 +447,7 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() {
         makePattern<opset1::Concat>({rotary_emb, slice_Slice_971 | varsplit_view_Reshape->output(1)}, {{"axis", -1}});
     auto permute_Transpose_1213 = makePattern<opset1::Transpose>({cat_Concat_1211, {0, 2, 1, 3}});
 
-    auto result = cat_Concat_1211 | permute_Transpose_1213;
+    auto result = permute_Transpose_1213;
 
     matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
         const auto& pattern_map = m.get_pattern_value_map();
@@ -470,8 +461,7 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() {
         OutputVector new_args;
         config.rotary_ndims = static_cast<size_t>(validator["ndims"]);
 
-        if (pattern_map.count(permute_Transpose_1213))
-            config.output_trans0213 = true;
+        config.output_trans0213 = true;
         config.is_interleaved = true;
 
         // input is [B,L,H,S]
@@ -488,11 +478,14 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() {
                               pattern_map.at(repeat_interleave_sin).get_node_shared_ptr(),
                               pattern_map.at(repeat_interleave_cos).get_node_shared_ptr(),
                               pattern_map.at(neg_Multiply_1177).get_node_shared_ptr(),
+                              pattern_map.at(Unsqueeze_65524).get_node_shared_ptr(),
+                              pattern_map.at(Unsqueeze_65525).get_node_shared_ptr(),
                               pattern_map.at(stack_1182).get_node_shared_ptr(),
                               pattern_map.at(mul_cos).get_node_shared_ptr(),
                               pattern_map.at(mul_sin).get_node_shared_ptr(),
                               pattern_map.at(rotary_emb).get_node_shared_ptr(),
-                              pattern_map.at(cat_Concat_1211).get_node_shared_ptr()},
+                              pattern_map.at(cat_Concat_1211).get_node_shared_ptr(),
+                              pattern_map.at(permute_Transpose_1213).get_node_shared_ptr()},
                              new_node);
         ov::replace_node(old_node, new_node);
         // shapeof may be moved up from transpose to add,
@@ -1215,87 +1215,4 @@ TEST_F(TransformationTestsF, ConvertToROPE_chatGLM3_PagedAttention) {
                                                     {"config.gather_position_arg_id", 0}});
         model_ref = std::make_shared<ov::Model>(ov::NodeVector{rope}, ov::ParameterVector{input, gather_cos_sin});
     }
 }
-
-TEST_F(TransformationTestsF, ConvertToROPE_GPTJ_PagedAttention) {
-    disable_rt_info_check();
-    const int batch = -1;
-    const int num_heads = 16;
-    const int ndims = 256;
-    const int rotary_ndims = 64;
-    using namespace ov;
-    {
-        std::vector<int32_t> rpi_idx(rotary_ndims);
-        for (int i = 0, index = 0; i < rotary_ndims; i += 2, index++) {
-            rpi_idx[i] = index;
-            rpi_idx[i + 1] = index;
-        }
-        auto repeat_interleave_index = makeConst(ov::element::i32, ov::Shape({rotary_ndims}), rpi_idx);
-
-        auto input =
-            std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::PartialShape{batch, 1, num_heads, ndims});
-        auto aten_gather_GatherElements =
-            std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::PartialShape{-1, 1, rotary_ndims});
-
-        auto prim_ListUnpack_VariadicSplit =
-            makeOP<opset1::VariadicSplit>({aten_gather_GatherElements, -1, {rotary_ndims / 2, -1}});
-        auto aten_unsqueeze_Unsqueeze_1 =
-            makeOP<opset1::Reshape>({prim_ListUnpack_VariadicSplit->output(1), {-1, 1, 1, rotary_ndims / 2}},
-                                    {{"special_zero", false}});
-        auto aten_repeat_interleave_Gather_1 =
-            makeOP<opset8::Gather>({aten_unsqueeze_Unsqueeze_1, repeat_interleave_index, 3}, {{"batch_dims", 0}});
-
-        auto aten_unsqueeze_Unsqueeze_2 =
-            makeOP<opset1::Reshape>({prim_ListUnpack_VariadicSplit->output(0), {-1, 1, 1, rotary_ndims / 2}},
-                                    {{"special_zero", false}});
-        auto aten_repeat_interleave_Gather_3 =
-            makeOP<opset8::Gather>({aten_unsqueeze_Unsqueeze_2, repeat_interleave_index, 3}, {{"batch_dims", 0}});
-
-        auto VariadicSplit_32371 = makeOP<opset1::VariadicSplit>({input, 3, {rotary_ndims, ndims - rotary_ndims}});
-        auto aten_mul_Multiply =
-            makeOP<opset1::Multiply>({VariadicSplit_32371->output(0), aten_repeat_interleave_Gather_1},
-                                     {{"auto_broadcast", "numpy"}});
-        auto aten_slice_Slice_10 = makeOP<opset8::Slice>({VariadicSplit_32371->output(0), {1}, {INT_MAX}, {2}, {3}});
-        auto Constant_65243 = makeConst(element::f32, ov::Shape({1, 1, 1, 1}), {-1.000000f});
-        auto aten_neg_Multiply =
-            makeOP<opset1::Multiply>({aten_slice_Slice_10, Constant_65243}, {{"auto_broadcast", "numpy"}});
-        auto Unsqueeze_28998 = makeOP<opset1::Reshape>({aten_neg_Multiply, {-1, 1, num_heads, rotary_ndims / 2, 1}},
-                                                       {{"special_zero", false}});
-        auto aten_slice_Slice_14 = makeOP<opset8::Slice>({VariadicSplit_32371->output(0), {0}, {INT_MAX}, {2}, {3}});
-        auto Unsqueeze_28999 = makeOP<opset1::Reshape>({aten_slice_Slice_14, {-1, 1, num_heads, rotary_ndims / 2, 1}},
-                                                       {{"special_zero", false}});
-        auto aten_stack = makeOP<opset1::Concat>({Unsqueeze_28998, Unsqueeze_28999}, {{"axis", -1}});
-        auto aten_flatten_Reshape =
-            makeOP<opset1::Reshape>({aten_stack, {0, 0, num_heads, rotary_ndims}}, {{"special_zero", true}});
-        auto aten_mul_Multiply_1 = makeOP<opset1::Multiply>({aten_flatten_Reshape, aten_repeat_interleave_Gather_3},
-                                                            {{"auto_broadcast", "numpy"}});
-        auto aten_add_Add =
-            makeOP<opset1::Add>({aten_mul_Multiply, aten_mul_Multiply_1}, {{"auto_broadcast", "numpy"}});
-        auto aten_cat_Concat_1 = makeOP<opset1::Concat>({aten_add_Add, VariadicSplit_32371->output(1)}, {{"axis", -1}});
-
-        model = std::make_shared<ov::Model>(ov::NodeVector{aten_cat_Concat_1},
-                                            ov::ParameterVector{input, aten_gather_GatherElements});
-    }
-    manager.register_pass<ov::pass::RoPEFusion>(false);
-    {
-        auto input =
-            std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::PartialShape{batch, 1, num_heads, ndims});
-        auto aten_gather_GatherElements =
-            std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::PartialShape{-1, 1, 64});
-        auto rope = makeOP<ov::op::internal::RoPE>({input, aten_gather_GatherElements, aten_gather_GatherElements},
-                                                   {{"config.slice_start", 0},
-                                                    {"config.slice_stop", 0},
-                                                    {"config.input_trans0213", false},
-                                                    {"config.output_trans0213", false},
-                                                    {"config.is_interleaved", true},
-                                                    {"config.rotary_ndims", rotary_ndims},
-                                                    {"config.is_chatglm", false},
-                                                    {"config.support_2d_rope", false},
-                                                    {"config.is_qwen", false},
-                                                    {"config.head_cnt", 0},
-                                                    {"config.head_size", 0},
-                                                    {"config.gather_position_arg_id", 0}});
-        model_ref =
-            std::make_shared<ov::Model>(ov::NodeVector{rope}, ov::ParameterVector{input, aten_gather_GatherElements});
-    }
-}
