fix multi-step + rocm_flash_attn support
Varun Sundar Rabindranath committed Oct 6, 2024
1 parent 6cad135 commit dc4caa7
Showing 1 changed file with 12 additions and 2 deletions.
vllm/attention/backends/rocm_flash_attn.py: 12 additions & 2 deletions
@@ -191,12 +191,22 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
             )
         return self._cached_decode_metadata
 
-    def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata",
+    def advance_step(self,
+                     model_input: "ModelInputForGPUWithSamplingMetadata",
                      sampled_token_ids: Optional[torch.Tensor],
-                     block_size: int, num_seqs: int, num_queries: int):
+                     block_size: int,
+                     num_seqs: int,
+                     num_queries: int,
+                     turn_prefills_into_decodes: bool = False):
         """
         Update metadata in-place to advance one decode step.
         """
+
+        assert not turn_prefills_into_decodes, \
+            ("Chunked prefill is not supported with rocm_flash_attn yet."
+             "turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill "
+             "specific parameter.")
+
         # When using cudagraph, the num_seqs is padded to the next captured
         # batch sized, but num_queries tracks the actual number of requests in
         # the batch. For --enforce-eager mode, num_seqs == num_queries
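For orientation, here is a minimal, self-contained sketch of the guard this commit adds. This is not vLLM's actual class: the _SketchMetadata type, its seq_lens_tensor field, and the main-block usage are illustrative assumptions. It shows the same pattern as the diff above: advance_step accepts the Multi-Step + Chunked-Prefill keyword for interface compatibility with the other backends, but asserts that it stays unset because the ROCm backend does not support chunked prefill yet.

from typing import Optional

import torch


class _SketchMetadata:
    """Illustrative stand-in for ROCmFlashAttentionMetadata (not the real class)."""

    def __init__(self, seq_lens_tensor: torch.Tensor):
        self.seq_lens_tensor = seq_lens_tensor

    def advance_step(self,
                     model_input,  # "ModelInputForGPUWithSamplingMetadata" in vLLM
                     sampled_token_ids: Optional[torch.Tensor],
                     block_size: int,
                     num_seqs: int,
                     num_queries: int,
                     turn_prefills_into_decodes: bool = False):
        """Update metadata in-place to advance one decode step."""
        # The new flag only matters for Multi-Step + Chunked-Prefill, which the
        # ROCm backend does not support yet, so any caller that sets it is
        # rejected up front.
        assert not turn_prefills_into_decodes, (
            "Chunked prefill is not supported with rocm_flash_attn yet. "
            "turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill "
            "specific parameter.")
        # Advance every running sequence by the one token just sampled
        # (simplified stand-in for the real metadata update).
        self.seq_lens_tensor[:num_queries] += 1


if __name__ == "__main__":
    meta = _SketchMetadata(seq_lens_tensor=torch.tensor([7, 3]))
    # Plain multi-step decode: the default flag value passes the guard.
    meta.advance_step(model_input=None, sampled_token_ids=None,
                      block_size=16, num_seqs=2, num_queries=2)
    print(meta.seq_lens_tensor)  # tensor([8, 4])

With the default turn_prefills_into_decodes=False the call succeeds and each sequence length advances by one; passing True trips the assertion with the same message the commit adds.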
