diff --git a/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp b/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp
index 6b12f56215ca83..06b86508ce1282 100644
--- a/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp
+++ b/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp
@@ -397,9 +397,11 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() {
     auto varsplit = makePattern<opset1::VariadicSplit>({gather_sin_cos, -1, {ndims / 2, -1}});
     varsplit->set_output_size(2);
     // Reshape or UnSqueeze should both be support
-    auto unsqueeze_sin = makePattern<opset1::Reshape>({varsplit->output(0), {1, -1, 1, 32}}) |
+    auto dim0 = ov::gen_pattern::Symbol("dim0");
+    auto dim1 = ov::gen_pattern::Symbol("dim1");
+    auto unsqueeze_sin = makePattern<opset1::Reshape>({varsplit->output(0), {dim0, dim1, 1, 32}}) |
                          makePattern<opset1::Unsqueeze>({varsplit->output(0), 2});
-    auto unsqueeze_cos = makePattern<opset1::Reshape>({varsplit->output(1), {1, -1, 1, 32}}) |
+    auto unsqueeze_cos = makePattern<opset1::Reshape>({varsplit->output(1), {dim0, dim1, 1, 32}}) |
                          makePattern<opset1::Unsqueeze>({varsplit->output(1), 2});
     // repeate cos/sin table
     auto const_idx = makeConst(ov::element::i32, ov::PartialShape::dynamic(), [](const ov::op::v0::Constant& node) {
@@ -419,10 +421,17 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() {
     auto neg_Multiply_1177 = makePattern<opset1::Multiply>({slice_Slice_1174, -1.0f}, {{"auto_broadcast", "numpy"}});
     auto Unsqueeze_65524 = makePattern<opset1::Unsqueeze>({neg_Multiply_1177, -1});
 
+    auto head_num = ov::gen_pattern::Symbol("head_num");
+    auto Unsqueeze_28998 =
+        makePattern<opset1::Reshape>({neg_Multiply_1177, {-1, 1, head_num, 32, 1}}, {{"special_zero", false}});
     auto slice_Slice_1168 = GenSlice(slice_Slice_965 | varsplit_view_Reshape->output(0), 0, int32_max, 2, 3);
     auto Unsqueeze_65525 = makePattern<opset1::Unsqueeze>({slice_Slice_1168, -1});
-    auto stack_1182 = makePattern<opset1::Concat>({Unsqueeze_65524, Unsqueeze_65525}, {{"axis", -1}});
+    auto Unsqueeze_28999 =
+        makePattern<opset1::Reshape>({slice_Slice_1168, {-1, 1, head_num, 32, 1}}, {{"special_zero", false}});
+    auto stack_1182 =
+        makePattern<opset1::Concat>({Unsqueeze_28998 | Unsqueeze_65524, Unsqueeze_65525 | Unsqueeze_28999},
+                                    {{"axis", -1}});
 
     auto ShapeOf_169068 = makePattern<opset3::ShapeOf>({stack_1182});
     auto flatten_Slice_1194 = GenSlice(ShapeOf_169068, 0, 3, 1, 0);
@@ -445,9 +454,8 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() {
     auto slice_Slice_971 = GenSlice(view_Reshape, ndims, int32_max, 1, 3);
     auto cat_Concat_1211 =
         makePattern<opset1::Concat>({rotary_emb, slice_Slice_971 | varsplit_view_Reshape->output(1)}, {{"axis", -1}});
-    auto permute_Transpose_1213 = makePattern<opset1::Transpose>({cat_Concat_1211, {0, 2, 1, 3}});
 
-    auto result = permute_Transpose_1213;
+    auto result = cat_Concat_1211;
 
     matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
         const auto& pattern_map = m.get_pattern_value_map();
@@ -459,9 +467,31 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() {
         op::internal::RoPE::Config config;
         OutputVector new_args;
 
+        NodeVector rt_from = {pattern_map.at(varsplit).get_node_shared_ptr(),
+                              pattern_map.at(repeat_interleave_sin).get_node_shared_ptr(),
+                              pattern_map.at(repeat_interleave_cos).get_node_shared_ptr(),
+                              pattern_map.at(neg_Multiply_1177).get_node_shared_ptr(),
+                              pattern_map.at(stack_1182).get_node_shared_ptr(),
+                              pattern_map.at(mul_cos).get_node_shared_ptr(),
+                              pattern_map.at(mul_sin).get_node_shared_ptr(),
+                              pattern_map.at(rotary_emb).get_node_shared_ptr(),
+                              pattern_map.at(cat_Concat_1211).get_node_shared_ptr()};
         config.rotary_ndims = static_cast<size_t>(validator["ndims"]);
 
-        config.output_trans0213 = true;
+        // Fuse output transpose to Rope.
+        auto root_target_inputs = root->output(0).get_target_inputs();
+        if (root_target_inputs.size() == 1) {
+            auto target_node = root_target_inputs.begin()->get_node()->shared_from_this();
+            if (auto transpose = ov::as_type_ptr<opset1::Transpose>(target_node)) {
+                auto axes = transpose->input_value(1).get_node_shared_ptr();
+                auto axes_const = ov::as_type_ptr<opset1::Constant>(axes);
+                if (axes_const && axes_const->cast_vector<int32_t>() == std::vector<int32_t>{0, 2, 1, 3}) {
+                    config.output_trans0213 = true;
+                    rt_from.push_back(target_node);
+                    root = target_node;
+                }
+            }
+        }
         config.is_interleaved = true;
 
         // input is [B,L,H,S]
@@ -474,19 +504,7 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() {
 
         auto new_node = std::make_shared<op::internal::RoPE>(new_args, config);
         new_node->set_friendly_name(old_node->get_friendly_name());
-        ov::copy_runtime_info({pattern_map.at(varsplit).get_node_shared_ptr(),
-                               pattern_map.at(repeat_interleave_sin).get_node_shared_ptr(),
-                               pattern_map.at(repeat_interleave_cos).get_node_shared_ptr(),
-                               pattern_map.at(neg_Multiply_1177).get_node_shared_ptr(),
-                               pattern_map.at(Unsqueeze_65524).get_node_shared_ptr(),
-                               pattern_map.at(Unsqueeze_65525).get_node_shared_ptr(),
-                               pattern_map.at(stack_1182).get_node_shared_ptr(),
-                               pattern_map.at(mul_cos).get_node_shared_ptr(),
-                               pattern_map.at(mul_sin).get_node_shared_ptr(),
-                               pattern_map.at(rotary_emb).get_node_shared_ptr(),
-                               pattern_map.at(cat_Concat_1211).get_node_shared_ptr(),
-                               pattern_map.at(permute_Transpose_1213).get_node_shared_ptr()},
-                              new_node);
+        ov::copy_runtime_info(rt_from, new_node);
         ov::replace_node(old_node, new_node);
         // shapeof may be moved up from transpose to add,
         // After RoPE fusion, shapeof must be moved to the data input of RoPE otherwise extra subgraph exists
diff --git a/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp b/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp
index 1b34e0c4423d3d..7ccddec3c5597a 100644
--- a/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp
+++ b/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp
@@ -1215,4 +1215,87 @@ TEST_F(TransformationTestsF, ConvertToROPE_chatGLM3_PagedAttention) {
                                          {"config.gather_position_arg_id", 0}});
         model_ref = std::make_shared<ov::Model>(ov::NodeVector{rope}, ov::ParameterVector{input, gather_cos_sin});
     }
-}
\ No newline at end of file
+}
+
+TEST_F(TransformationTestsF, ConvertToROPE_GPTJ_PagedAttention) {
+    disable_rt_info_check();
+    const int batch = -1;
+    const int num_heads = 16;
+    const int ndims = 256;
+    const int rotary_ndims = 64;
+    using namespace ov;
+    {
+        std::vector<int32_t> rpi_idx(rotary_ndims);
+        for (int i = 0, index = 0; i < rotary_ndims; i += 2, index++) {
+            rpi_idx[i] = index;
+            rpi_idx[i + 1] = index;
+        }
+        auto repeat_interleave_index = makeConst(ov::element::i32, ov::Shape({rotary_ndims}), rpi_idx);
+
+        auto input =
+            std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::PartialShape{batch, 1, num_heads, ndims});
+        auto aten_gather_GatherElements =
+            std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::PartialShape{-1, 1, rotary_ndims});
+
+        auto prim_ListUnpack_VariadicSplit =
+            makeOP<ov::opset1::VariadicSplit>({aten_gather_GatherElements, -1, {rotary_ndims / 2, -1}});
+        auto aten_unsqueeze_Unsqueeze_1 =
+            makeOP<ov::opset1::Reshape>({prim_ListUnpack_VariadicSplit->output(1), {-1, 1, 1, rotary_ndims / 2}},
+                                        {{"special_zero", false}});
+        auto aten_repeat_interleave_Gather_1 =
+            makeOP<ov::opset8::Gather>({aten_unsqueeze_Unsqueeze_1, repeat_interleave_index, 3}, {{"batch_dims", 0}});
+
+        auto aten_unsqueeze_Unsqueeze_2 =
+            makeOP<ov::opset1::Reshape>({prim_ListUnpack_VariadicSplit->output(0), {-1, 1, 1, rotary_ndims / 2}},
+                                        {{"special_zero", false}});
+        auto aten_repeat_interleave_Gather_3 =
+            makeOP<ov::opset8::Gather>({aten_unsqueeze_Unsqueeze_2, repeat_interleave_index, 3}, {{"batch_dims", 0}});
+
+        auto VariadicSplit_32371 = makeOP<ov::opset1::VariadicSplit>({input, 3, {rotary_ndims, ndims - rotary_ndims}});
+        auto aten_mul_Multiply =
+            makeOP<ov::opset1::Multiply>({VariadicSplit_32371->output(0), aten_repeat_interleave_Gather_1},
+                                         {{"auto_broadcast", "numpy"}});
+        auto aten_slice_Slice_10 = makeOP<ov::opset8::Slice>({VariadicSplit_32371->output(0), {1}, {INT_MAX}, {2}, {3}});
+        auto Constant_65243 = makeConst(element::f32, ov::Shape({1, 1, 1, 1}), {-1.000000f});
+        auto aten_neg_Multiply =
+            makeOP<ov::opset1::Multiply>({aten_slice_Slice_10, Constant_65243}, {{"auto_broadcast", "numpy"}});
+        auto Unsqueeze_28998 = makeOP<ov::opset1::Reshape>({aten_neg_Multiply, {-1, 1, num_heads, rotary_ndims / 2, 1}},
+                                                           {{"special_zero", false}});
+        auto aten_slice_Slice_14 = makeOP<ov::opset8::Slice>({VariadicSplit_32371->output(0), {0}, {INT_MAX}, {2}, {3}});
+        auto Unsqueeze_28999 = makeOP<ov::opset1::Reshape>({aten_slice_Slice_14, {-1, 1, num_heads, rotary_ndims / 2, 1}},
+                                                           {{"special_zero", false}});
+        auto aten_stack = makeOP<ov::opset1::Concat>({Unsqueeze_28998, Unsqueeze_28999}, {{"axis", -1}});
+        auto aten_flatten_Reshape =
+            makeOP<ov::opset1::Reshape>({aten_stack, {0, 0, num_heads, rotary_ndims}}, {{"special_zero", true}});
+        auto aten_mul_Multiply_1 = makeOP<ov::opset1::Multiply>({aten_flatten_Reshape, aten_repeat_interleave_Gather_3},
+                                                                {{"auto_broadcast", "numpy"}});
+        auto aten_add_Add =
+            makeOP<ov::opset1::Add>({aten_mul_Multiply, aten_mul_Multiply_1}, {{"auto_broadcast", "numpy"}});
+        auto aten_cat_Concat_1 = makeOP<ov::opset1::Concat>({aten_add_Add, VariadicSplit_32371->output(1)}, {{"axis", -1}});
+
+        model = std::make_shared<ov::Model>(ov::NodeVector{aten_cat_Concat_1},
+                                            ov::ParameterVector{input, aten_gather_GatherElements});
+    }
+    manager.register_pass<ov::pass::RoPEFusion>(false);
+    {
+        auto input =
+            std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::PartialShape{batch, 1, num_heads, ndims});
+        auto aten_gather_GatherElements =
+            std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::PartialShape{-1, 1, 64});
+        auto rope = makeOP<ov::op::internal::RoPE>({input, aten_gather_GatherElements, aten_gather_GatherElements},
+                                                   {{"config.slice_start", 0},
+                                                    {"config.slice_stop", 0},
+                                                    {"config.input_trans0213", false},
+                                                    {"config.output_trans0213", false},
+                                                    {"config.is_interleaved", true},
+                                                    {"config.rotary_ndims", rotary_ndims},
+                                                    {"config.is_chatglm", false},
+                                                    {"config.support_2d_rope", false},
+                                                    {"config.is_qwen", false},
+                                                    {"config.head_cnt", 0},
+                                                    {"config.head_size", 0},
+                                                    {"config.gather_position_arg_id", 0}});
+        model_ref =
+            std::make_shared<ov::Model>(ov::NodeVector{rope}, ov::ParameterVector{input, aten_gather_GatherElements});
+    }
+}
diff --git a/src/plugins/intel_cpu/src/nodes/rope.cpp b/src/plugins/intel_cpu/src/nodes/rope.cpp
index 984b35237d93ba..aaa060ff15afbc 100644
--- a/src/plugins/intel_cpu/src/nodes/rope.cpp
+++ b/src/plugins/intel_cpu/src/nodes/rope.cpp
@@ -207,7 +207,7 @@ struct RoPE::RoPEExecutorInterleaved : public RoPE::Executor {
             auto* x = t_src.ptr<T>(b, p, h);
             float* sin = &t_sin_cos.at<float>({b, p, 0}, true);
             float* cos = &t_sin_cos.at<float>({b, p, half_rotary_dims}, true);
-            auto* dst = t_dst.ptr<T>(b, h, p);
+            auto* dst = m_config.output_trans0213 ? t_dst.ptr<T>(b, h, p) : t_dst.ptr<T>(b, p, h);
 
             if (m_rotaryKernel) {
                 execJitKernel(m_rotaryKernel, x, dst, cos, sin);
@@ -397,8 +397,7 @@ void RoPE::initSupportedPrimitiveDescriptors() {
            m_executor = std::make_shared<RoPEExecutorQwen<float>>(m_config);
            rtPrecision = ov::element::f32;
        }
-    } else if (m_config.is_interleaved && m_config.output_trans0213) {
-        OPENVINO_ASSERT(m_config.input_trans0213 == false);
+    } else if (m_config.is_interleaved) {
        OPENVINO_ASSERT(m_config.slice_start == 0);
        OPENVINO_ASSERT(m_config.slice_stop == 0);
        OPENVINO_ASSERT(m_config.gather_position_arg_id == 0);
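
For reference, a minimal standalone sketch (illustrative only, not taken from the patch; all names below are hypothetical) of the GPT-J style interleaved rotation that the fused RoPE node computes, together with the destination-layout choice that the CPU executor change above makes conditional on config.output_trans0213:

#include <cstddef>
#include <vector>

// Rotate one (batch, position, head) slice in GPT-J interleaved style.
// src holds rotary_ndims values (x0, x1, x2, x3, ...); sin_tab/cos_tab hold
// rotary_ndims / 2 values already gathered for this position.
static void rotate_interleaved(const float* src, float* dst, const float* sin_tab, const float* cos_tab, size_t rotary_ndims) {
    for (size_t i = 0; i < rotary_ndims; i += 2) {
        const float s = sin_tab[i / 2];
        const float c = cos_tab[i / 2];
        dst[i] = src[i] * c - src[i + 1] * s;
        dst[i + 1] = src[i + 1] * c + src[i] * s;
    }
}

// Destination offset analogous to the executor change: with output_trans0213 the
// rotated heads land in a [B, H, L, S] buffer, otherwise the [B, L, H, S] source
// layout is kept and no transpose is implied by the RoPE node itself.
static float* dst_ptr(std::vector<float>& out, bool output_trans0213,
                      size_t b, size_t p, size_t h, size_t L, size_t H, size_t S) {
    return output_trans0213 ? out.data() + ((b * H + h) * L + p) * S   // [B, H, L, S]
                            : out.data() + ((b * L + p) * H + h) * S;  // [B, L, H, S]
}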