Skip to content

Commit

Permalink
Update RopeFusion to support gpt-j model after SDPA to PA conversion (#28512)
Browse files Browse the repository at this point in the history

### Details:
Update RopeFusion for gpt-j models

### Tickets:
 - *CVS-161066*
  • Loading branch information
itikhono authored Jan 24, 2025
1 parent 174869c commit c421837
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -397,9 +397,11 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() {
auto varsplit = makePattern<opset1::VariadicSplit>({gather_sin_cos, -1, {ndims / 2, -1}});
varsplit->set_output_size(2);
// Reshape or Unsqueeze should both be supported
auto unsqueeze_sin = makePattern<opset1::Reshape>({varsplit->output(0), {1, -1, 1, 32}}) |
auto dim0 = ov::gen_pattern::Symbol("dim0");
auto dim1 = ov::gen_pattern::Symbol("dim1");
auto unsqueeze_sin = makePattern<opset1::Reshape>({varsplit->output(0), {dim0, dim1, 1, 32}}) |
makePattern<opset1::Unsqueeze>({varsplit->output(0), 2});
auto unsqueeze_cos = makePattern<opset1::Reshape>({varsplit->output(1), {1, -1, 1, 32}}) |
auto unsqueeze_cos = makePattern<opset1::Reshape>({varsplit->output(1), {dim0, dim1, 1, 32}}) |
makePattern<opset1::Unsqueeze>({varsplit->output(1), 2});
// repeat the cos/sin table
auto const_idx = makeConst(ov::element::i32, ov::PartialShape::dynamic(), [](const ov::op::v0::Constant& node) {
Expand All @@ -419,10 +421,17 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() {

auto neg_Multiply_1177 = makePattern<opset1::Multiply>({slice_Slice_1174, -1.0f}, {{"auto_broadcast", "numpy"}});
auto Unsqueeze_65524 = makePattern<opset1::Unsqueeze>({neg_Multiply_1177, -1});
auto head_num = ov::gen_pattern::Symbol("head_num");
auto Unsqueeze_28998 =
makePattern<opset1::Reshape>({neg_Multiply_1177, {-1, 1, head_num, 32, 1}}, {{"special_zero", false}});

auto slice_Slice_1168 = GenSlice(slice_Slice_965 | varsplit_view_Reshape->output(0), 0, int32_max, 2, 3);
auto Unsqueeze_65525 = makePattern<opset1::Unsqueeze>({slice_Slice_1168, -1});
auto stack_1182 = makePattern<opset1::Concat>({Unsqueeze_65524, Unsqueeze_65525}, {{"axis", -1}});
auto Unsqueeze_28999 =
makePattern<opset1::Reshape>({slice_Slice_1168, {-1, 1, head_num, 32, 1}}, {{"special_zero", false}});
auto stack_1182 =
makePattern<opset1::Concat>({Unsqueeze_28998 | Unsqueeze_65524, Unsqueeze_65525 | Unsqueeze_28999},
{{"axis", -1}});

auto ShapeOf_169068 = makePattern<opset1::ShapeOf>({stack_1182});
auto flatten_Slice_1194 = GenSlice(ShapeOf_169068, 0, 3, 1, 0);
Expand All @@ -445,9 +454,8 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() {
auto slice_Slice_971 = GenSlice(view_Reshape, ndims, int32_max, 1, 3);
auto cat_Concat_1211 =
makePattern<opset1::Concat>({rotary_emb, slice_Slice_971 | varsplit_view_Reshape->output(1)}, {{"axis", -1}});
auto permute_Transpose_1213 = makePattern<opset1::Transpose>({cat_Concat_1211, {0, 2, 1, 3}});

auto result = permute_Transpose_1213;
auto result = cat_Concat_1211;

matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
Expand All @@ -459,9 +467,31 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() {

op::internal::RoPE::Config config;
OutputVector new_args;
NodeVector rt_from = {pattern_map.at(varsplit).get_node_shared_ptr(),
pattern_map.at(repeat_interleave_sin).get_node_shared_ptr(),
pattern_map.at(repeat_interleave_cos).get_node_shared_ptr(),
pattern_map.at(neg_Multiply_1177).get_node_shared_ptr(),
pattern_map.at(stack_1182).get_node_shared_ptr(),
pattern_map.at(mul_cos).get_node_shared_ptr(),
pattern_map.at(mul_sin).get_node_shared_ptr(),
pattern_map.at(rotary_emb).get_node_shared_ptr(),
pattern_map.at(cat_Concat_1211).get_node_shared_ptr()};
config.rotary_ndims = static_cast<size_t>(validator["ndims"]);

config.output_trans0213 = true;
// Fuse output transpose to Rope.
auto root_target_inputs = root->output(0).get_target_inputs();
if (root_target_inputs.size() == 1) {
auto target_node = root_target_inputs.begin()->get_node()->shared_from_this();
if (auto transpose = ov::as_type_ptr<op::v1::Transpose>(target_node)) {
auto axes = transpose->input_value(1).get_node_shared_ptr();
auto axes_const = ov::as_type_ptr<op::v0::Constant>(axes);
if (axes_const && axes_const->cast_vector<int64_t>() == std::vector<int64_t>{0, 2, 1, 3}) {
config.output_trans0213 = true;
rt_from.push_back(target_node);
root = target_node;
}
}
}
config.is_interleaved = true;

// input is [B,L,H,S]
Expand All @@ -474,19 +504,7 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() {

auto new_node = std::make_shared<op::internal::RoPE>(new_args, config);
new_node->set_friendly_name(old_node->get_friendly_name());
ov::copy_runtime_info({pattern_map.at(varsplit).get_node_shared_ptr(),
pattern_map.at(repeat_interleave_sin).get_node_shared_ptr(),
pattern_map.at(repeat_interleave_cos).get_node_shared_ptr(),
pattern_map.at(neg_Multiply_1177).get_node_shared_ptr(),
pattern_map.at(Unsqueeze_65524).get_node_shared_ptr(),
pattern_map.at(Unsqueeze_65525).get_node_shared_ptr(),
pattern_map.at(stack_1182).get_node_shared_ptr(),
pattern_map.at(mul_cos).get_node_shared_ptr(),
pattern_map.at(mul_sin).get_node_shared_ptr(),
pattern_map.at(rotary_emb).get_node_shared_ptr(),
pattern_map.at(cat_Concat_1211).get_node_shared_ptr(),
pattern_map.at(permute_Transpose_1213).get_node_shared_ptr()},
new_node);
ov::copy_runtime_info(rt_from, new_node);
ov::replace_node(old_node, new_node);
// shapeof may be moved up from transpose to add,
// After RoPE fusion, shapeof must be moved to the data input of RoPE otherwise extra subgraph exists
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1215,4 +1215,87 @@ TEST_F(TransformationTestsF, ConvertToROPE_chatGLM3_PagedAttention) {
{"config.gather_position_arg_id", 0}});
model_ref = std::make_shared<ov::Model>(ov::NodeVector{rope}, ov::ParameterVector{input, gather_cos_sin});
}
}
}

// Checks that RoPEFusion recognizes the GPT-J rotary-embedding subgraph in its
// post-SDPA-to-PagedAttention form — Reshape-based "unsqueezes" and no trailing
// Transpose on the output — and collapses it into a single internal RoPE op.
TEST_F(TransformationTestsF, ConvertToROPE_GPTJ_PagedAttention) {
    disable_rt_info_check();
    const int batch = -1;         // dynamic batch dimension
    const int num_heads = 16;
    const int ndims = 256;        // full head size
    const int rotary_ndims = 64;  // leading slice of each head that gets rotated
    using namespace ov;
    {
        // Index table [0,0,1,1,2,2,...]: each sin/cos value is repeated twice so a
        // Gather along the last axis implements repeat_interleave.
        std::vector<int32_t> rpi_idx(rotary_ndims);
        for (int i = 0, index = 0; i < rotary_ndims; i += 2, index++) {
            rpi_idx[i] = index;
            rpi_idx[i + 1] = index;
        }
        auto repeat_interleave_index = makeConst(ov::element::i32, ov::Shape({rotary_ndims}), rpi_idx);

        // Input activations in [B, 1, H, S] layout and the gathered sin/cos table.
        auto input =
            std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::PartialShape{batch, 1, num_heads, ndims});
        auto aten_gather_GatherElements =
            std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::PartialShape{-1, 1, rotary_ndims});

        // Split the fused table into its two halves; each half is reshaped for
        // broadcasting (Reshape instead of Unsqueeze — the converted-model form)
        // and repeat-interleaved along the last axis via the index table above.
        auto prim_ListUnpack_VariadicSplit =
            makeOP<opset1::VariadicSplit>({aten_gather_GatherElements, -1, {rotary_ndims / 2, -1}});
        auto aten_unsqueeze_Unsqueeze_1 =
            makeOP<opset1::Reshape>({prim_ListUnpack_VariadicSplit->output(1), {-1, 1, 1, rotary_ndims / 2}},
                                    {{"special_zero", false}});
        auto aten_repeat_interleave_Gather_1 =
            makeOP<opset8::Gather>({aten_unsqueeze_Unsqueeze_1, repeat_interleave_index, 3}, {{"batch_dims", 0}});

        auto aten_unsqueeze_Unsqueeze_2 =
            makeOP<opset1::Reshape>({prim_ListUnpack_VariadicSplit->output(0), {-1, 1, 1, rotary_ndims / 2}},
                                    {{"special_zero", false}});
        auto aten_repeat_interleave_Gather_3 =
            makeOP<opset8::Gather>({aten_unsqueeze_Unsqueeze_2, repeat_interleave_index, 3}, {{"batch_dims", 0}});

        // Split each head into the rotary slice and the pass-through remainder.
        auto VariadicSplit_32371 = makeOP<opset1::VariadicSplit>({input, 3, {rotary_ndims, ndims - rotary_ndims}});
        auto aten_mul_Multiply =
            makeOP<opset1::Multiply>({VariadicSplit_32371->output(0), aten_repeat_interleave_Gather_1},
                                     {{"auto_broadcast", "numpy"}});
        // GPT-J interleaved rotation: negate the odd-indexed elements, pair them
        // with the even-indexed ones via Reshape(+1 axis) -> Concat, then flatten
        // back to [*, *, H, rotary_ndims].
        auto aten_slice_Slice_10 = makeOP<opset8::Slice>({VariadicSplit_32371->output(0), {1}, {INT_MAX}, {2}, {3}});
        auto Constant_65243 = makeConst(element::f32, ov::Shape({1, 1, 1, 1}), {-1.000000f});
        auto aten_neg_Multiply =
            makeOP<opset1::Multiply>({aten_slice_Slice_10, Constant_65243}, {{"auto_broadcast", "numpy"}});
        auto Unsqueeze_28998 = makeOP<opset1::Reshape>({aten_neg_Multiply, {-1, 1, num_heads, rotary_ndims / 2, 1}},
                                                       {{"special_zero", false}});
        auto aten_slice_Slice_14 = makeOP<opset8::Slice>({VariadicSplit_32371->output(0), {0}, {INT_MAX}, {2}, {3}});
        auto Unsqueeze_28999 = makeOP<opset1::Reshape>({aten_slice_Slice_14, {-1, 1, num_heads, rotary_ndims / 2, 1}},
                                                       {{"special_zero", false}});
        auto aten_stack = makeOP<opset1::Concat>({Unsqueeze_28998, Unsqueeze_28999}, {{"axis", -1}});
        auto aten_flatten_Reshape =
            makeOP<opset1::Reshape>({aten_stack, {0, 0, num_heads, rotary_ndims}}, {{"special_zero", true}});
        // x*cos + rotated(x)*sin, then re-attach the non-rotary remainder.
        auto aten_mul_Multiply_1 = makeOP<opset1::Multiply>({aten_flatten_Reshape, aten_repeat_interleave_Gather_3},
                                                            {{"auto_broadcast", "numpy"}});
        auto aten_add_Add =
            makeOP<opset1::Add>({aten_mul_Multiply, aten_mul_Multiply_1}, {{"auto_broadcast", "numpy"}});
        auto aten_cat_Concat_1 = makeOP<opset1::Concat>({aten_add_Add, VariadicSplit_32371->output(1)}, {{"axis", -1}});

        model = std::make_shared<ov::Model>(ov::NodeVector{aten_cat_Concat_1},
                                            ov::ParameterVector{input, aten_gather_GatherElements});
    }
    manager.register_pass<ov::pass::RoPEFusion>(false);
    {
        // Reference graph: the subgraph above fused into one internal RoPE op.
        // output_trans0213 is false since there is no trailing Transpose to fold in.
        auto input =
            std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::PartialShape{batch, 1, num_heads, ndims});
        auto aten_gather_GatherElements =
            std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::PartialShape{-1, 1, 64});
        auto rope = makeOP<ov::op::internal::RoPE>({input, aten_gather_GatherElements, aten_gather_GatherElements},
                                                   {{"config.slice_start", 0},
                                                    {"config.slice_stop", 0},
                                                    {"config.input_trans0213", false},
                                                    {"config.output_trans0213", false},
                                                    {"config.is_interleaved", true},
                                                    {"config.rotary_ndims", rotary_ndims},
                                                    {"config.is_chatglm", false},
                                                    {"config.support_2d_rope", false},
                                                    {"config.is_qwen", false},
                                                    {"config.head_cnt", 0},
                                                    {"config.head_size", 0},
                                                    {"config.gather_position_arg_id", 0}});
        model_ref =
            std::make_shared<ov::Model>(ov::NodeVector{rope}, ov::ParameterVector{input, aten_gather_GatherElements});
    }
}
5 changes: 2 additions & 3 deletions src/plugins/intel_cpu/src/nodes/rope.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ struct RoPE::RoPEExecutorInterleaved : public RoPE::Executor {
auto* x = t_src.ptr<T>(b, p, h);
float* sin = &t_sin_cos.at<float>({b, p, 0}, true);
float* cos = &t_sin_cos.at<float>({b, p, half_rotary_dims}, true);
auto* dst = t_dst.ptr<T>(b, h, p);
auto* dst = m_config.output_trans0213 ? t_dst.ptr<T>(b, h, p) : t_dst.ptr<T>(b, p, h);

if (m_rotaryKernel) {
execJitKernel(m_rotaryKernel, x, dst, cos, sin);
Expand Down Expand Up @@ -397,8 +397,7 @@ void RoPE::initSupportedPrimitiveDescriptors() {
m_executor = std::make_shared<RoPEExecutorChatGLM<float>>(m_config);
rtPrecision = ov::element::f32;
}
} else if (m_config.is_interleaved && m_config.output_trans0213) {
OPENVINO_ASSERT(m_config.input_trans0213 == false);
} else if (m_config.is_interleaved) {
OPENVINO_ASSERT(m_config.slice_start == 0);
OPENVINO_ASSERT(m_config.slice_stop == 0);
OPENVINO_ASSERT(m_config.gather_position_arg_id == 0);
Expand Down

0 comments on commit c421837

Please sign in to comment.