diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py
index 984747c53c6c..aaf6ec5f508c 100644
--- a/vllm/spec_decode/draft_model_runner.py
+++ b/vllm/spec_decode/draft_model_runner.py
@@ -6,11 +6,16 @@
 from vllm.model_executor.layers.sampler import SamplerOutput
 
 try:
-    from vllm.attention.backends.flash_attn import FlashAttentionMetadata
-except ModuleNotFoundError:
-    # vllm_flash_attn is not installed, use the identical ROCm FA metadata
-    from vllm.attention.backends.rocm_flash_attn import (
-        ROCmFlashAttentionMetadata as FlashAttentionMetadata)
+    try:
+        from vllm.attention.backends.flash_attn import FlashAttentionMetadata
+    except (ModuleNotFoundError, ImportError):
+        # vllm_flash_attn is not installed, try the ROCm FA metadata
+        from vllm.attention.backends.rocm_flash_attn import (
+            ROCmFlashAttentionMetadata as FlashAttentionMetadata)
+except (ModuleNotFoundError, ImportError) as err:
+    raise RuntimeError(
+        "Draft model speculative decoding currently only supports "
+        "the CUDA and ROCm flash attention backends.") from err
 
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ObservabilityConfig, ParallelConfig,
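
For context, a minimal self-contained sketch (illustrative only, not part of the patch; the module and symbol names below are hypothetical) of why the handler is widened from ModuleNotFoundError to (ModuleNotFoundError, ImportError): a module can be found yet still fail when a specific name is imported from it, which raises a plain ImportError that an except ModuleNotFoundError clause would not catch.

# Sketch: ModuleNotFoundError is a subclass of ImportError, so catching only
# ModuleNotFoundError misses other import failures, such as importing a name
# that does not exist in an installed module.
try:
    from math import no_such_symbol  # module exists, symbol does not
except ModuleNotFoundError:
    print("not reached: this raises a plain ImportError, not ModuleNotFoundError")
except ImportError as err:
    print(f"caught broader ImportError: {err}")

try:
    import no_such_module  # module itself is absent
except ImportError as err:
    # ModuleNotFoundError lands here because it subclasses ImportError.
    print(f"caught {type(err).__name__}: {err}")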