Apply torch.compile to fused_moe/grouped_topk (vllm-project#12637)
mgoin authored Feb 1, 2025
1 parent: 4f4d427 · commit: 3194039
Showing 2 changed files with 3 additions and 2 deletions.
vllm/model_executor/layers/fused_moe/fused_moe.py (1 addition, 0 deletions)

@@ -759,6 +759,7 @@ def fused_topk(


 # This is used by the Deepseek-V2 and Deepseek-V3 model
+@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
 def grouped_topk(hidden_states: torch.Tensor,
                  gating_output: torch.Tensor,
                  topk: int,
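For context, grouped_topk is the expert-routing helper used by the DeepSeek-V2/V3 MoE layers, and this hunk wraps it in torch.compile. Below is a minimal, self-contained sketch (assuming PyTorch 2.x) of the same decoration pattern. The toy_topk function is an illustrative stand-in, not vLLM's grouped_topk, and it uses the default Inductor backend rather than current_platform.simple_compile_backend as in the actual commit.

import torch


# Illustrative stand-in for a top-k routing helper; NOT vLLM's grouped_topk.
# dynamic=True asks Dynamo/Inductor for shape-polymorphic code, so varying
# token counts reuse the same compiled artifact instead of recompiling.
@torch.compile(dynamic=True)
def toy_topk(gating_output: torch.Tensor, topk: int):
    scores = torch.softmax(gating_output, dim=-1)  # per-token expert probabilities
    topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)
    return topk_weights, topk_ids


if __name__ == "__main__":
    for num_tokens in (4, 16, 64):        # varying numbers of tokens
        gating = torch.randn(num_tokens, 8)  # 8 hypothetical experts
        weights, ids = toy_topk(gating, 2)
        print(weights.shape, ids.shape)

Compilation happens lazily on the first call; subsequent calls with different leading dimensions hit the already-compiled code because of dynamic=True.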
vllm/model_executor/models/deepseek_v3.py (2 additions, 2 deletions)

@@ -27,6 +27,7 @@
 from transformers import PretrainedConfig

 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, ModelConfig, VllmConfig
 from vllm.distributed import (get_pp_group,
                               get_tensor_model_parallel_world_size,
@@ -566,8 +567,7 @@ def forward(
         return hidden_states, residual


-# TODO(simon): check whether we support torch compile for Deepseek V3
-# @support_torch_compile
+@support_torch_compile
 class DeepseekV3Model(nn.Module):

     fall_back_to_pt_during_load = False
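The second change resolves the earlier TODO by enabling vLLM's support_torch_compile class decorator on DeepseekV3Model. The sketch below is a simplified stand-in for the general pattern of a class decorator that routes an nn.Module's forward through torch.compile; vLLM's actual decorator is driven by its compilation config and is not reproduced here. The compile_forward name is hypothetical.

import torch
from torch import nn


def compile_forward(cls):
    """Hypothetical class decorator: after construction, replace the instance's
    bound forward with a torch.compile-wrapped version. A simplified
    illustration of the pattern, not vLLM's support_torch_compile."""
    orig_init = cls.__init__

    def __init__(self, *args, **kwargs):
        orig_init(self, *args, **kwargs)
        # Later calls to model(...) dispatch to the compiled forward.
        self.forward = torch.compile(self.forward, dynamic=True)

    cls.__init__ = __init__
    return cls


@compile_forward
class TinyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(16, 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.proj(x))


if __name__ == "__main__":
    model = TinyModel()
    print(model(torch.randn(4, 16)).shape)  # torch.Size([4, 16])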
