Skip to content

Commit

Permalink
Update RopeFusion for qwen, gpt-j models
Browse files Browse the repository at this point in the history
  • Loading branch information
itikhono committed Jan 17, 2025
1 parent 5ec8343 commit 5baf699
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -397,9 +397,11 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() {
auto varsplit = makePattern<opset1::VariadicSplit>({gather_sin_cos, -1, {ndims / 2, -1}});
varsplit->set_output_size(2);
// Reshape or Unsqueeze should both be supported
auto unsqueeze_sin = makePattern<opset1::Reshape>({varsplit->output(0), {1, -1, 1, 32}}) |
auto dim0 = ov::gen_pattern::Symbol("dim0");
auto dim1 = ov::gen_pattern::Symbol("dim1");
auto unsqueeze_sin = makePattern<opset1::Reshape>({varsplit->output(0), {dim0, dim1, 1, 32}}) |
makePattern<opset1::Unsqueeze>({varsplit->output(0), 2});
auto unsqueeze_cos = makePattern<opset1::Reshape>({varsplit->output(1), {1, -1, 1, 32}}) |
auto unsqueeze_cos = makePattern<opset1::Reshape>({varsplit->output(1), {dim0, dim1, 1, 32}}) |
makePattern<opset1::Unsqueeze>({varsplit->output(1), 2});
// repeat cos/sin table
auto const_idx = makeConst(ov::element::i32, ov::PartialShape::dynamic(), [](const ov::op::v0::Constant& node) {
Expand All @@ -419,10 +421,17 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() {

auto neg_Multiply_1177 = makePattern<opset1::Multiply>({slice_Slice_1174, -1.0f}, {{"auto_broadcast", "numpy"}});
auto Unsqueeze_65524 = makePattern<opset1::Unsqueeze>({neg_Multiply_1177, -1});
auto head_num = ov::gen_pattern::Symbol("head_num");
auto Unsqueeze_28998 =
makePattern<opset1::Reshape>({neg_Multiply_1177, {-1, 1, head_num, 32, 1}}, {{"special_zero", false}});

auto slice_Slice_1168 = GenSlice(slice_Slice_965 | varsplit_view_Reshape->output(0), 0, int32_max, 2, 3);
auto Unsqueeze_65525 = makePattern<opset1::Unsqueeze>({slice_Slice_1168, -1});
auto stack_1182 = makePattern<opset1::Concat>({Unsqueeze_65524, Unsqueeze_65525}, {{"axis", -1}});
auto Unsqueeze_28999 =
makePattern<opset1::Reshape>({slice_Slice_1168, {-1, 1, head_num, 32, 1}}, {{"special_zero", false}});
auto stack_1182 =
makePattern<opset1::Concat>({Unsqueeze_28998 | Unsqueeze_65524, Unsqueeze_65525 | Unsqueeze_28999},
{{"axis", -1}});

auto ShapeOf_169068 = makePattern<opset1::ShapeOf>({stack_1182});
auto flatten_Slice_1194 = GenSlice(ShapeOf_169068, 0, 3, 1, 0);
Expand All @@ -447,7 +456,7 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() {
makePattern<opset1::Concat>({rotary_emb, slice_Slice_971 | varsplit_view_Reshape->output(1)}, {{"axis", -1}});
auto permute_Transpose_1213 = makePattern<opset1::Transpose>({cat_Concat_1211, {0, 2, 1, 3}});

auto result = permute_Transpose_1213;
auto result = cat_Concat_1211 | permute_Transpose_1213;

matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
Expand All @@ -461,7 +470,8 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() {
OutputVector new_args;
config.rotary_ndims = static_cast<size_t>(validator["ndims"]);

config.output_trans0213 = true;
if (pattern_map.count(permute_Transpose_1213))
config.output_trans0213 = true;
config.is_interleaved = true;

// input is [B,L,H,S]
Expand All @@ -478,14 +488,11 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() {
pattern_map.at(repeat_interleave_sin).get_node_shared_ptr(),
pattern_map.at(repeat_interleave_cos).get_node_shared_ptr(),
pattern_map.at(neg_Multiply_1177).get_node_shared_ptr(),
pattern_map.at(Unsqueeze_65524).get_node_shared_ptr(),
pattern_map.at(Unsqueeze_65525).get_node_shared_ptr(),
pattern_map.at(stack_1182).get_node_shared_ptr(),
pattern_map.at(mul_cos).get_node_shared_ptr(),
pattern_map.at(mul_sin).get_node_shared_ptr(),
pattern_map.at(rotary_emb).get_node_shared_ptr(),
pattern_map.at(cat_Concat_1211).get_node_shared_ptr(),
pattern_map.at(permute_Transpose_1213).get_node_shared_ptr()},
pattern_map.at(cat_Concat_1211).get_node_shared_ptr()},
new_node);
ov::replace_node(old_node, new_node);
// shapeof may be moved up from transpose to add,
Expand Down Expand Up @@ -705,6 +712,7 @@ ov::pass::RoPEFusionQwen::RoPEFusionQwen(int split_output_id) {
auto rotary_emb_cos = makePattern("[1,?,1,?]"); // [1,..4096,1,128]
auto rotary_emb_sin = makePattern("[1,?,1,?]"); // [1,..4096,1,128]
auto qkv_proj = makePattern("[?,?,?]"); // [?,?,12288]
auto position_ids = makePattern();

auto head_cnt = ov::gen_pattern::Symbol("head_cnt");
auto head_size = ov::gen_pattern::Symbol("head_size");
Expand All @@ -731,14 +739,19 @@ ov::pass::RoPEFusionQwen::RoPEFusionQwen(int split_output_id) {
auto ScatterUpdate_463814 = makePattern<opset3::ScatterUpdate>({{0, 0}, {1}, Gather_377635 | neg_Multiply, {0}});
auto slice_Slice_446 =
makePattern<ov::opset8::Slice>({rotary_emb_cos, Gather_377635 | neg_Multiply, {INT_MAX}, {1}, {1}});

auto gather_cos_by_pos_ids = makePattern<opset8::Gather>({rotary_emb_cos, position_ids, 1}, {{"batch_dims", 0}});
auto reshape_cos_to_expected_layout =
makePattern<opset8::Reshape>({gather_cos_by_pos_ids, {-1, 1, 1, 128}}, {{"special_zero", false}});

auto slice_StridedSlice_446 = GenStridedSlice(rotary_emb_cos,
ScatterUpdate_463814,
{0, INT_MAX},
{1, 1},
1); // tensor_array<f32[1,..4096,1,128]>
auto mul_Multiply_552 =
makePattern<opset1::Multiply>({slice_Slice_543, slice_StridedSlice_446 | slice_Slice_446},
{{"auto_broadcast", "numpy"}}); // tensor_array<f32[?,?,32,128]>
auto mul_Multiply_552 = makePattern<opset1::Multiply>(
{slice_Slice_543, slice_StridedSlice_446 | slice_Slice_446 | reshape_cos_to_expected_layout},
{{"auto_broadcast", "numpy"}}); // tensor_array<f32[?,?,32,128]>

auto reshape_opt1 = [&](std::shared_ptr<Node> input_BLHS) {
auto ShapeOf_485814 = makePattern<opset1::ShapeOf>({input_BLHS}, {});
Expand Down Expand Up @@ -772,18 +785,28 @@ ov::pass::RoPEFusionQwen::RoPEFusionQwen(int split_output_id) {
makePattern<opset1::Squeeze>({Multiply_567527, -2}); // tensor_array<f32[?,?,32,64]>
auto ListUnpack_586_Squeeze =
makePattern<opset1::Squeeze>({ListUnpack_586_Split->output(0), -2}); // tensor_array<f32[?,?,32,64]>
auto cat_Concat_593 = makePattern<opset1::Concat>({ListUnpack_586_Squeeze_0, ListUnpack_586_Squeeze},
{{"axis", -1}}); // tensor_array<f32[?,?,32,128]>

auto ListUnpack_Squeeze_0_1 =
makePattern<opset1::Reshape>({Multiply_567527, {-1, 1, 32, 64}}, {{"special_zero", false}});
auto ListUnpack_Squeeze_1 =
makePattern<opset1::Reshape>({ListUnpack_586_Split->output(0), {-1, 1, 32, 64}}, {{"special_zero", false}});

auto cat_Concat_593 = makePattern<opset1::Concat>(
{ListUnpack_586_Squeeze_0 | ListUnpack_Squeeze_0_1, ListUnpack_586_Squeeze | ListUnpack_Squeeze_1},
{{"axis", -1}}); // tensor_array<f32[?,?,32,128]>
auto slice_StridedSlice_470 = GenStridedSlice(rotary_emb_sin,
ScatterUpdate_463814,
{0, INT_MAX},
{1, 1},
1); // tensor_array<f32[1,..4096,1,128]>
auto slice_Slice_470 =
makePattern<opset8::Slice>({rotary_emb_sin, Gather_377635 | neg_Multiply, {INT_MAX}, {1}, {1}});
auto mul_Multiply_594 =
makePattern<opset1::Multiply>({cat_Concat_593, slice_StridedSlice_470 | slice_Slice_470},
{{"auto_broadcast", "numpy"}}); // tensor_array<f32[?,?,32,128]>
auto gather_sin_by_pos_ids = makePattern<opset8::Gather>({rotary_emb_sin, position_ids, 1}, {{"batch_dims", 0}});
auto reshape_sin_to_expected_layout =
makePattern<opset8::Reshape>({gather_sin_by_pos_ids, {-1, 1, 1, 128}}, {{"special_zero", false}});
auto mul_Multiply_594 = makePattern<opset1::Multiply>(
{cat_Concat_593, slice_StridedSlice_470 | slice_Slice_470 | reshape_sin_to_expected_layout},
{{"auto_broadcast", "numpy"}}); // tensor_array<f32[?,?,32,128]>
auto add_Add_597 = makePattern<opset1::Add>({mul_Multiply_552, mul_Multiply_594},
{{"auto_broadcast", "numpy"}}); // tensor_array<f32[?,?,32,128]>

Expand Down Expand Up @@ -844,8 +867,8 @@ ov::pass::RoPEFusionQwen::RoPEFusionQwen(int split_output_id) {
auto new_node = std::make_shared<op::internal::RoPE>(new_args, config);
new_node->set_friendly_name(old_node->get_friendly_name());
ov::copy_runtime_info({pattern_map.at(Multiply_567527).get_node_shared_ptr(),
pattern_map.at(ListUnpack_586_Squeeze_0).get_node_shared_ptr(),
pattern_map.at(ListUnpack_586_Squeeze).get_node_shared_ptr(),
// pattern_map.at(ListUnpack_586_Squeeze_0).get_node_shared_ptr(),
// pattern_map.at(ListUnpack_586_Squeeze).get_node_shared_ptr(),
pattern_map.at(cat_Concat_593).get_node_shared_ptr(),
pattern_map.at(mul_Multiply_594).get_node_shared_ptr(),
pattern_map.at(add_Add_597).get_node_shared_ptr()},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1215,4 +1215,87 @@ TEST_F(TransformationTestsF, ConvertToROPE_chatGLM3_PagedAttention) {
{"config.gather_position_arg_id", 0}});
model_ref = std::make_shared<ov::Model>(ov::NodeVector{rope}, ov::ParameterVector{input, gather_cos_sin});
}
}

// Checks that RoPEFusion collapses the GPT-J-style rotary-embedding subgraph in its
// PagedAttention flavour — "unsqueeze"/"squeeze" steps expressed as Reshape nodes and
// no trailing Transpose on the output — into a single internal RoPE op.
// The reference model expects is_interleaved=true and output_trans0213=false,
// matching the new pattern alternatives added to RoPEFusionGPTJ in this commit.
TEST_F(TransformationTestsF, ConvertToROPE_GPTJ_PagedAttention) {
    disable_rt_info_check();
    const int batch = -1;         // dynamic batch dimension
    const int num_heads = 16;
    const int ndims = 256;        // full per-head size of the input
    const int rotary_ndims = 64;  // leading slice of each head that gets rotated
    using namespace ov;
    {
        // Index table [0, 0, 1, 1, 2, 2, ...]: gathers each sin/cos value twice along
        // the last axis, producing the interleaved rotary layout used by GPT-J.
        std::vector<int32_t> rpi_idx(rotary_ndims);
        for (int i = 0, index = 0; i < rotary_ndims; i += 2, index++) {
            rpi_idx[i] = index;
            rpi_idx[i + 1] = index;
        }
        auto repeat_interleave_index = makeConst(ov::element::i32, ov::Shape({rotary_ndims}), rpi_idx);

        // Hidden states: [batch, 1, num_heads, ndims].
        auto input =
            std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::PartialShape{batch, 1, num_heads, ndims});
        // Pre-gathered sin/cos table: [-1, 1, rotary_ndims] (the two halves are split below).
        auto aten_gather_GatherElements =
            std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::PartialShape{-1, 1, rotary_ndims});

        // Split the gathered table along the last axis into two rotary_ndims/2 halves.
        auto prim_ListUnpack_VariadicSplit =
            makeOP<opset1::VariadicSplit>({aten_gather_GatherElements, -1, {rotary_ndims / 2, -1}});
        // "Unsqueeze" written as a Reshape to [-1, 1, 1, rotary_ndims/2] — the variant the
        // updated pattern must now accept (instead of an actual Unsqueeze op).
        auto aten_unsqueeze_Unsqueeze_1 =
            makeOP<opset1::Reshape>({prim_ListUnpack_VariadicSplit->output(1), {-1, 1, 1, rotary_ndims / 2}},
                                    {{"special_zero", false}});
        // Duplicate each table value along axis 3 via the repeat-interleave index.
        auto aten_repeat_interleave_Gather_1 =
            makeOP<opset8::Gather>({aten_unsqueeze_Unsqueeze_1, repeat_interleave_index, 3}, {{"batch_dims", 0}});

        // Same Reshape-as-Unsqueeze + repeat-interleave for the other half of the table.
        auto aten_unsqueeze_Unsqueeze_2 =
            makeOP<opset1::Reshape>({prim_ListUnpack_VariadicSplit->output(0), {-1, 1, 1, rotary_ndims / 2}},
                                    {{"special_zero", false}});
        auto aten_repeat_interleave_Gather_3 =
            makeOP<opset8::Gather>({aten_unsqueeze_Unsqueeze_2, repeat_interleave_index, 3}, {{"batch_dims", 0}});

        // Split the input head into the rotated part (rotary_ndims) and the pass-through rest.
        auto VariadicSplit_32371 = makeOP<opset1::VariadicSplit>({input, 3, {rotary_ndims, ndims - rotary_ndims}});
        auto aten_mul_Multiply =
            makeOP<opset1::Multiply>({VariadicSplit_32371->output(0), aten_repeat_interleave_Gather_1},
                                     {{"auto_broadcast", "numpy"}});
        // rotate_half: take odd-indexed elements, negate them ...
        auto aten_slice_Slice_10 = makeOP<opset8::Slice>({VariadicSplit_32371->output(0), {1}, {INT_MAX}, {2}, {3}});
        auto Constant_65243 = makeConst(element::f32, ov::Shape({1, 1, 1, 1}), {-1.000000f});
        auto aten_neg_Multiply =
            makeOP<opset1::Multiply>({aten_slice_Slice_10, Constant_65243}, {{"auto_broadcast", "numpy"}});
        // ... and "unsqueeze" via Reshape to [-1, 1, num_heads, rotary_ndims/2, 1]
        // (another Reshape-based alternative the updated GPT-J pattern must match).
        auto Unsqueeze_28998 = makeOP<opset1::Reshape>({aten_neg_Multiply, {-1, 1, num_heads, rotary_ndims / 2, 1}},
                                                       {{"special_zero", false}});
        // Even-indexed elements, reshaped the same way.
        auto aten_slice_Slice_14 = makeOP<opset8::Slice>({VariadicSplit_32371->output(0), {0}, {INT_MAX}, {2}, {3}});
        auto Unsqueeze_28999 = makeOP<opset1::Reshape>({aten_slice_Slice_14, {-1, 1, num_heads, rotary_ndims / 2, 1}},
                                                       {{"special_zero", false}});
        // Interleave (-odd, even) pairs back into a rotary_ndims-wide tensor.
        auto aten_stack = makeOP<opset1::Concat>({Unsqueeze_28998, Unsqueeze_28999}, {{"axis", -1}});
        auto aten_flatten_Reshape =
            makeOP<opset1::Reshape>({aten_stack, {0, 0, num_heads, rotary_ndims}}, {{"special_zero", true}});
        auto aten_mul_Multiply_1 = makeOP<opset1::Multiply>({aten_flatten_Reshape, aten_repeat_interleave_Gather_3},
                                                            {{"auto_broadcast", "numpy"}});
        // x*cos + rotate_half(x)*sin
        auto aten_add_Add =
            makeOP<opset1::Add>({aten_mul_Multiply, aten_mul_Multiply_1}, {{"auto_broadcast", "numpy"}});
        // Re-attach the non-rotated remainder of the head; note: no Transpose afterwards,
        // which is the PagedAttention-specific shape this test covers.
        auto aten_cat_Concat_1 = makeOP<opset1::Concat>({aten_add_Add, VariadicSplit_32371->output(1)}, {{"axis", -1}});

        model = std::make_shared<ov::Model>(ov::NodeVector{aten_cat_Concat_1},
                                            ov::ParameterVector{input, aten_gather_GatherElements});
    }
    manager.register_pass<ov::pass::RoPEFusion>(false);
    {
        // Reference: the whole subgraph above replaced by a single internal RoPE op.
        auto input =
            std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::PartialShape{batch, 1, num_heads, ndims});
        auto aten_gather_GatherElements =
            std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::PartialShape{-1, 1, 64});
        // Gathered sin/cos table is passed twice (cos input, sin input).
        auto rope = makeOP<ov::op::internal::RoPE>({input, aten_gather_GatherElements, aten_gather_GatherElements},
                                                   {{"config.slice_start", 0},
                                                    {"config.slice_stop", 0},
                                                    {"config.input_trans0213", false},
                                                    // false: no Transpose in the source graph.
                                                    {"config.output_trans0213", false},
                                                    {"config.is_interleaved", true},
                                                    {"config.rotary_ndims", rotary_ndims},
                                                    {"config.is_chatglm", false},
                                                    {"config.support_2d_rope", false},
                                                    {"config.is_qwen", false},
                                                    {"config.head_cnt", 0},
                                                    {"config.head_size", 0},
                                                    {"config.gather_position_arg_id", 0}});
        model_ref =
            std::make_shared<ov::Model>(ov::NodeVector{rope}, ov::ParameterVector{input, aten_gather_GatherElements});
    }
}

0 comments on commit 5baf699

Please sign in to comment.