diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py
index 40e804934cbd..ba9b2d043c64 100644
--- a/vllm/attention/backends/flashinfer.py
+++ b/vllm/attention/backends/flashinfer.py
@@ -26,6 +26,7 @@
                                             compute_slot_mapping_start_idx,
                                             is_block_tables_empty)
 from vllm.attention.ops.paged_attn import PagedAttention
+from vllm.forward_context import get_forward_context
 from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
                         make_tensor_with_pad)
@@ -761,73 +762,132 @@ def forward(
                                       "encoder/decoder cross-attention "
                                       "are not implemented for "
                                       "FlashInferImpl")
-        num_tokens, hidden_size = query.shape
-        query = query.view(-1, self.num_heads, self.head_size)
-        key = key.view(-1, self.num_kv_heads, self.head_size)
-        value = value.view(-1, self.num_kv_heads, self.head_size)
-        if attn_metadata.num_prefill_tokens > 0:
-            assert attn_metadata.num_decode_tokens == 0, (
-                "Chunked prefill is not supported with flashinfer yet.")
-        if attn_metadata.num_decode_tokens > 0:
-            assert attn_metadata.num_prefill_tokens == 0, (
-                "Chunked prefill is not supported with flashinfer yet.")
-        if kv_cache.numel() > 0:
-            # Use the same reshape and cache kernel as flash attention.
-            ops.reshape_and_cache_flash(
-                key,
-                value,
-                kv_cache[:, 0],
-                kv_cache[:, 1],
-                attn_metadata.slot_mapping.flatten(),
-                self.kv_cache_dtype,
-                k_scale,
-                v_scale,
+        return torch.ops.vllm.unified_flash_infer(
+            query,
+            key,
+            value,
+            self.num_heads,
+            self.head_size,
+            self.num_kv_heads,
+            kv_cache,
+            self.kv_cache_dtype,
+            k_scale,
+            v_scale,
+            self.scale,
+            self.sliding_window,
+            self.alibi_slopes,
+            self.logits_soft_cap,
+        )
+
+
+@torch.library.custom_op("vllm::unified_flash_infer",
+                         mutates_args=["kv_cache"])
+def unified_flash_infer(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    num_heads: int,
+    head_size: int,
+    num_kv_heads: int,
+    kv_cache: torch.Tensor,
+    kv_cache_dtype: str,
+    k_scale: float,
+    v_scale: float,
+    softmax_scale: float,
+    window_size: Optional[List[int]] = None,
+    alibi_slopes: Optional[torch.Tensor] = None,
+    logits_soft_cap: Optional[float] = None,
+) -> torch.Tensor:
+
+    current_metadata = get_forward_context()
+    assert current_metadata is not None
+    assert isinstance(current_metadata, FlashInferMetadata)
+    attn_metadata: FlashInferMetadata = current_metadata
+
+    num_tokens, hidden_size = query.shape
+    query = query.view(-1, num_heads, head_size)
+    key = key.view(-1, num_kv_heads, head_size)
+    value = value.view(-1, num_kv_heads, head_size)
+
+    if attn_metadata.num_prefill_tokens > 0:
+        assert attn_metadata.num_decode_tokens == 0, (
+            "Chunked prefill is not supported with flashinfer yet.")
+    if attn_metadata.num_decode_tokens > 0:
+        assert attn_metadata.num_prefill_tokens == 0, (
+            "Chunked prefill is not supported with flashinfer yet.")
+    if kv_cache.numel() > 0:
+        # Use the same reshape and cache kernel as flash attention.
+        ops.reshape_and_cache_flash(
+            key,
+            value,
+            kv_cache[:, 0],
+            kv_cache[:, 1],
+            attn_metadata.slot_mapping.flatten(),
+            kv_cache_dtype,
+            k_scale,
+            v_scale,
+        )
+        # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
+        # to process the cache when the kv_cache_dtype is fp8
+        if kv_cache_dtype.startswith("fp8"):
+            torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
+                kv_cache_dtype)
+            kv_cache = kv_cache.view(torch_dtype)
+
+    query = query.contiguous()  # Flashinfer requires query to be contiguous
+    if prefill_meta := attn_metadata.prefill_metadata:
+        # We will use flash attention for prefill
+        # when kv_cache is not provided.
+        # This happens when vllm runs the profiling to
+        # determine the number of blocks.
+        if kv_cache.numel() == 0:
+            output = flash_attn_varlen_func(
+                q=query,
+                k=key,
+                v=value,
+                cu_seqlens_q=prefill_meta.seq_start_loc,
+                cu_seqlens_k=prefill_meta.seq_start_loc,
+                max_seqlen_q=prefill_meta.max_prefill_seq_len,
+                max_seqlen_k=prefill_meta.max_prefill_seq_len,
+                softmax_scale=softmax_scale,
+                causal=True,
+                window_size=window_size,
+                alibi_slopes=alibi_slopes,
             )
-            # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
-            # to process the cache when the kv_cache_dtype is fp8
-            if self.kv_cache_dtype.startswith("fp8"):
-                torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
-                    self.kv_cache_dtype)
-                kv_cache = kv_cache.view(torch_dtype)
-
-        query = query.contiguous(
-        )  # Flashinfer requires query to be contiguous
-        if prefill_meta := attn_metadata.prefill_metadata:
-            # We will use flash attention for prefill
-            # when kv_cache is not provided.
-            # This happens when vllm runs the profiling to
-            # determine the number of blocks.
-            if kv_cache.numel() == 0:
-                output = flash_attn_varlen_func(
-                    q=query,
-                    k=key,
-                    v=value,
-                    cu_seqlens_q=prefill_meta.seq_start_loc,
-                    cu_seqlens_k=prefill_meta.seq_start_loc,
-                    max_seqlen_q=prefill_meta.max_prefill_seq_len,
-                    max_seqlen_k=prefill_meta.max_prefill_seq_len,
-                    softmax_scale=self.scale,
-                    causal=True,
-                    window_size=self.sliding_window,
-                    alibi_slopes=self.alibi_slopes,
-                )
-            else:
-                assert prefill_meta is not None
-                assert prefill_meta.prefill_wrapper is not None
-                output = prefill_meta.prefill_wrapper.forward(
-                    query,
-                    kv_cache,
-                    logits_soft_cap=self.logits_soft_cap,
-                    causal=True)
         else:
-            assert attn_metadata.decode_metadata is not None
-            assert attn_metadata.decode_metadata.decode_wrapper is not None
-            output = attn_metadata.decode_metadata.decode_wrapper.forward(
-                query,
-                kv_cache,
-                sm_scale=self.scale,
-                logits_soft_cap=self.logits_soft_cap,
-                k_scale=k_scale,
-                v_scale=v_scale)
-        return output.view(num_tokens, hidden_size)
+            assert prefill_meta is not None
+            assert prefill_meta.prefill_wrapper is not None
+            output = prefill_meta.prefill_wrapper.forward(
+                query, kv_cache, logits_soft_cap=logits_soft_cap, causal=True)
+    else:
+        assert attn_metadata.decode_metadata is not None
+        assert attn_metadata.decode_metadata.decode_wrapper is not None
+        output = attn_metadata.decode_metadata.decode_wrapper.forward(
+            query,
+            kv_cache,
+            sm_scale=softmax_scale,
+            logits_soft_cap=logits_soft_cap,
+            k_scale=k_scale,
+            v_scale=v_scale)
+    return output.view(num_tokens, hidden_size)
+
+
+@unified_flash_infer.register_fake
+def _(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    num_heads: int,
+    head_size: int,
+    num_kv_heads: int,
+    kv_cache: torch.Tensor,
+    kv_cache_dtype: str,
+    k_scale: float,
+    v_scale: float,
+    softmax_scale: float,
+    window_size: Optional[List[int]] = None,
+    alibi_slopes: Optional[torch.Tensor] = None,
+    logits_soft_cap: Optional[float] = None,
+) -> torch.Tensor:
+    return torch.empty_like(query).contiguous()
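
Note on the pattern used above: `torch.library.custom_op` registers the whole attention call as an opaque custom op (so torch.compile treats it as a single node instead of tracing into FlashInfer), `mutates_args=["kv_cache"]` declares that the op writes into the KV cache in place, and `register_fake` supplies a shape-only implementation used during tracing; the attention metadata is fetched from the forward context rather than passed as an argument, since it is not a type a custom op can accept. The following is a minimal, self-contained sketch of that same custom-op + fake-impl pattern, not vLLM code: the op name "demo::scaled_copy" and the tensors are made up for illustration, and it assumes a PyTorch version (>= 2.4) that provides `torch.library.custom_op`.

    import torch

    @torch.library.custom_op("demo::scaled_copy", mutates_args=["out"])
    def scaled_copy(x: torch.Tensor, scale: float,
                    out: torch.Tensor) -> torch.Tensor:
        # Eager implementation: write the scaled input into the mutated
        # buffer and also return a fresh result tensor (no aliasing).
        out.copy_(x * scale)
        return x * scale

    @scaled_copy.register_fake
    def _(x: torch.Tensor, scale: float, out: torch.Tensor) -> torch.Tensor:
        # Fake (meta) implementation: propagate shape/dtype only, so
        # torch.compile can trace through the op without running it.
        return torch.empty_like(x)

    x = torch.randn(4)
    buf = torch.empty(4)
    # Dispatched through torch.ops, mirroring torch.ops.vllm.unified_flash_infer.
    y = torch.ops.demo.scaled_copy(x, 2.0, buf)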