From dc4caa7a4e41d13b43f2845edbc2fa6a383e8151 Mon Sep 17 00:00:00 2001
From: Varun Sundar Rabindranath
Date: Sun, 6 Oct 2024 03:33:33 +0000
Subject: [PATCH] fix multi-step + rocm_flash_attn support

---
 vllm/attention/backends/rocm_flash_attn.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py
index fb5cd11ec033..7456aab8b8d2 100644
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -191,12 +191,22 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
         )
         return self._cached_decode_metadata
 
-    def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata",
+    def advance_step(self,
+                     model_input: "ModelInputForGPUWithSamplingMetadata",
                      sampled_token_ids: Optional[torch.Tensor],
-                     block_size: int, num_seqs: int, num_queries: int):
+                     block_size: int,
+                     num_seqs: int,
+                     num_queries: int,
+                     turn_prefills_into_decodes: bool = False):
         """
         Update metadata in-place to advance one decode step.
         """
+
+        assert not turn_prefills_into_decodes, \
+            ("Chunked prefill is not supported with rocm_flash_attn yet."
+             "turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill "
+             "specific parameter.")
+
         # When using cudagraph, the num_seqs is padded to the next captured
         # batch sized, but num_queries tracks the actual number of requests in
         # the batch. For --enforce-eager mode, num_seqs == num_queries
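
Reviewer note (not part of the patch): for readers unfamiliar with multi-step scheduling, the sketch below illustrates with plain tensors what "advance one decode step" means for the attention metadata this patch touches: the sampled token becomes the next input token, every running sequence grows by one, and the new position is mapped to its physical KV-cache slot. The helper toy_advance_step and its tensor layout are simplified stand-ins of my own, not the backend's implementation; the real advance_step delegates this bookkeeping to fused ops and additionally receives turn_prefills_into_decodes, which the patch explicitly rejects on ROCm for now.

# Illustrative sketch only (reviewer-provided, NOT vLLM code). It mimics with
# plain torch tensors the bookkeeping advance_step performs after a decode
# step: the sampled token becomes the next input token, every sequence grows
# by one, and the new position is mapped to its physical KV-cache slot.
import torch

def toy_advance_step(seq_lens: torch.Tensor,
                     input_positions: torch.Tensor,
                     input_tokens: torch.Tensor,
                     slot_mapping: torch.Tensor,
                     block_tables: torch.Tensor,
                     sampled_token_ids: torch.Tensor,
                     block_size: int) -> None:
    # The token sampled at step t is the input token of step t + 1.
    input_tokens.copy_(sampled_token_ids.view(-1))
    # Each running sequence grows by exactly one token per decode step.
    seq_lens += 1
    input_positions.copy_(seq_lens - 1)
    # slot = block_tables[seq, pos // block_size] * block_size + pos % block_size
    pos = input_positions
    block_idx = torch.gather(block_tables, 1,
                             (pos // block_size).unsqueeze(1)).squeeze(1)
    slot_mapping.copy_(block_idx * block_size + pos % block_size)

# Usage: two sequences of current length 5 and 9, block_size 4.
block_size = 4
seq_lens = torch.tensor([5, 9])
input_positions = seq_lens - 1
input_tokens = torch.zeros(2, dtype=torch.long)
slot_mapping = torch.zeros(2, dtype=torch.long)
block_tables = torch.arange(8).view(2, 4)   # 4 KV-cache blocks per sequence
sampled = torch.tensor([[11], [22]])
toy_advance_step(seq_lens, input_positions, input_tokens, slot_mapping,
                 block_tables, sampled, block_size)
print(seq_lens.tolist(), input_positions.tolist(), slot_mapping.tolist())
# -> [6, 10] [5, 9] [5, 25]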