Skip to content

Commit

Permalink
Change the order of the SDPA fusion transformations
Browse files Browse the repository at this point in the history
  • Loading branch information
maxnick committed Nov 8, 2024
1 parent 696bdde commit b217730
Showing 1 changed file with 4 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -618,14 +618,10 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
CPU_SET_CALLBACK_COMMON(manager, nmsCallback, ov::pass::ConvertNMS9ToNMSIEInternal);
CPU_SET_CALLBACK_COMMON(manager, nmsCallback, ov::pass::ConvertMulticlassNmsToMulticlassNmsIE);
CPU_SET_CALLBACK_COMMON(manager, nmsCallback, ov::pass::ConvertMatrixNmsToMatrixNmsIE);
CPU_SET_CALLBACK_COMMON(manager,
[this](const_node_ptr &node) -> bool {
std::string errorMsg;
// Current SDPA impl is optimized only for LLM models, so we decompose it for others to avoid perf regression.
// Matching the pattern is a little complicated, so we just check whether there are any state nodes.
return node::ScaledDotProductAttention::isSupportedOperation(node, errorMsg) && model->get_variables().size() > 0;
},
ov::pass::ScaledDotProductAttentionDecomposition);

CPU_REGISTER_PASS_COMMON(manager, StatefulSDPAFusion);
CPU_REGISTER_PASS_X64(manager, ov::intel_cpu::SDPAFuseTransposeReshape);
CPU_REGISTER_PASS_COMMON(manager, ov::pass::ScaledDotProductAttentionDecomposition);

// List of enabled/disabled transformations

Expand Down Expand Up @@ -882,8 +878,6 @@ void Transformations::PostLpt() {
#endif // OPENVINO_ARCH_X86_64

CPU_REGISTER_PASS_COMMON(postLPTPassManager, ov::pass::transpose_sinking::TSShapeOfForward);
CPU_REGISTER_PASS_COMMON(postLPTPassManager, StatefulSDPAFusion);
CPU_REGISTER_PASS_X64(postLPTPassManager, ov::intel_cpu::SDPAFuseTransposeReshape);
CPU_REGISTER_PASS_X64(postLPTPassManager, ov::pass::RMSFusion, false);
CPU_REGISTER_PASS_X64(postLPTPassManager, ov::intel_cpu::DecomposeRMSNorm);
CPU_SET_CALLBACK_X64(postLPTPassManager,
Expand Down

0 comments on commit b217730

Please sign in to comment.