Commit
[Bugfix] Fix try-catch conditions to import correct Flash Attention Backend in Draft Model (#9101)
tjtanaa authored Oct 6, 2024
1 parent f4dd830 commit 23fea87
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions vllm/spec_decode/draft_model_runner.py
@@ -6,11 +6,16 @@
 from vllm.model_executor.layers.sampler import SamplerOutput
 
 try:
-    from vllm.attention.backends.flash_attn import FlashAttentionMetadata
-except ModuleNotFoundError:
-    # vllm_flash_attn is not installed, use the identical ROCm FA metadata
-    from vllm.attention.backends.rocm_flash_attn import (
-        ROCmFlashAttentionMetadata as FlashAttentionMetadata)
+    try:
+        from vllm.attention.backends.flash_attn import FlashAttentionMetadata
+    except (ModuleNotFoundError, ImportError):
+        # vllm_flash_attn is not installed, try the ROCm FA metadata
+        from vllm.attention.backends.rocm_flash_attn import (
+            ROCmFlashAttentionMetadata as FlashAttentionMetadata)
+except (ModuleNotFoundError, ImportError) as err:
+    raise RuntimeError(
+        "Draft model speculative decoding currently only supports"
+        "CUDA and ROCm flash attention backend.") from err
 
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ObservabilityConfig, ParallelConfig,
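For readers skimming the diff, the sketch below illustrates the nested try/except fallback this commit introduces. It is a minimal, self-contained example, not vLLM code: the module names cuda_backend and rocm_backend are hypothetical stand-ins for vllm.attention.backends.flash_attn and vllm.attention.backends.rocm_flash_attn. The reason for also catching ImportError is that a backend package can be installed yet still fail to import (for example, because of a missing shared library), which raises a plain ImportError that the previous except ModuleNotFoundError clause did not catch.

try:
    try:
        # Hypothetical CUDA backend; stands in for
        # vllm.attention.backends.flash_attn in the real code.
        from cuda_backend import AttentionMetadata
    except (ModuleNotFoundError, ImportError):
        # CUDA backend missing or broken; fall back to the ROCm equivalent
        # (hypothetical stand-in for vllm.attention.backends.rocm_flash_attn).
        from rocm_backend import AttentionMetadata
except (ModuleNotFoundError, ImportError) as err:
    # Neither backend imported; fail with an explicit error instead of a
    # bare ImportError from deep inside the fallback chain.
    raise RuntimeError(
        "No usable flash attention backend could be imported.") from err

Note that ModuleNotFoundError is a subclass of ImportError, so catching both is redundant but harmless; listing both mirrors the committed code and makes the intent explicit.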
