diff --git a/tests/attention/__init__.py b/tests/attention/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/attention/prefill_only/__init__.py b/tests/attention/prefill_only/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/attention/prefill_only/test_basic_correctness.py b/tests/attention/prefill_only/test_basic_correctness.py new file mode 100644 index 000000000000..39e721ea8d6b --- /dev/null +++ b/tests/attention/prefill_only/test_basic_correctness.py @@ -0,0 +1,89 @@ +import itertools as it + +import pytest +import torch +import torch.nn.functional as F + +from vllm.attention.layer import Attention +from vllm.attention.prefill_only.abstract import AttentionType +from vllm.attention.prefill_only.selector import (AttentionImpls, AttnBackend, + _Backend) +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE + + +def compare_embeddings(embeddings1, embeddings2): + similarities = [ + F.cosine_similarity(torch.tensor(e1), torch.tensor(e2), dim=0) + for e1, e2 in zip(embeddings1, embeddings2) + ] + return similarities + + +SEQ_LENS = [1, 2, 3, 5, 7, 11, 13, 17, 19, 23, 29] + + +@pytest.mark.parametrize("head_dim", [64]) +@pytest.mark.parametrize("num_heads", [8, 16]) +@pytest.mark.parametrize("num_kv_heads", [1, 2, 4, 8]) +@pytest.mark.parametrize("dtype", ["float", "half", "bfloat16"]) +@pytest.mark.parametrize("attn_type", ["DECODER", "ENCODER"]) +@pytest.mark.parametrize("n_seqs", list(range(1, len(SEQ_LENS)))) +def test_basic_correctness(head_dim: int, num_heads: int, num_kv_heads: int, + attn_type: str, dtype: str, n_seqs: int): + assert num_heads % num_kv_heads == 0 + + torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype] + + attention_impls = AttentionImpls[dtype] + + seq_lens = SEQ_LENS[:n_seqs] + batchsize = sum(seq_lens) + + query = torch.rand((batchsize, num_heads, head_dim), + dtype=torch_dtype, + device="cuda:0").view((batchsize, -1)) + key = torch.rand((batchsize, num_kv_heads, head_dim), + dtype=torch_dtype, + device="cuda:0").view((batchsize, -1)) + value = torch.rand((batchsize, num_kv_heads, head_dim), + dtype=torch_dtype, + device="cuda:0").view((batchsize, -1)) + + impl_outputs_list = [] + + for attention_impl in attention_impls: + selected_backend = _Backend.backend_name_to_enum(attention_impl) + backend_cls = AttnBackend.get_backend_cls(selected_backend) + + attn_type_enum = AttentionType.attn_type_name_to_enum(attn_type) + + attn_backend = backend_cls(attn_type_enum) + scaling = head_dim**-0.5 + + attn = Attention(num_heads, + head_dim, + scale=scaling, + num_kv_heads=num_kv_heads, + attn_backend=attn_backend) + + metadata_builder = attn_backend.make_metadata_builder() + attn_metadata = metadata_builder(seq_lens=seq_lens) + attn_metadata = attn_metadata.to("cuda:0") + + outputs = attn.forward(query, + key, + value, + kv_cache=None, + attn_metadata=attn_metadata) + + impl_outputs_list.append((attention_impl, outputs)) + + tolerance = 1e-2 + for a, b in it.combinations(impl_outputs_list, 2): + similarities = compare_embeddings(a[1], b[1]) + all_similarities = torch.stack(similarities) + + assert torch.all( + (all_similarities <= 1.0 + tolerance) + & (all_similarities >= 1.0 - tolerance) + ), f"{a[0]} vs {b[0]}, not all values are within {tolerance} of 1.0" diff --git a/tests/attention/prefill_only/test_enum_verify.py b/tests/attention/prefill_only/test_enum_verify.py new file mode 100644 index 000000000000..1996f4f42cbb --- /dev/null +++ b/tests/attention/prefill_only/test_enum_verify.py @@ -0,0 +1,54 @@ +import 
pytest + +from vllm.attention.prefill_only.abstract import (AttentionType, + PrefillOnlyAttentionBackend) +from vllm.attention.prefill_only.selector import (AttentionImpls, AttnBackend, + _Backend) + + +def get_attn_backend(attention_impl: str, attn_type: str): + selected_backend = _Backend.backend_name_to_enum(attention_impl) + backend_cls = AttnBackend.get_backend_cls(selected_backend) + + attn_type_enum = AttentionType.attn_type_name_to_enum(attn_type) + + attn_backend = backend_cls(attn_type_enum) + return attn_backend + + +@pytest.mark.parametrize("attn_type", ["DECODER", "ENCODER"]) +@pytest.mark.parametrize("dtype", ["float", "half", "bfloat16"]) +def test_backend(dtype: str, attn_type: str): + attention_impls = AttentionImpls[dtype] + + for attention_impl in attention_impls: + attn_backend = get_attn_backend(attention_impl, attn_type) + + assert isinstance(attn_backend, PrefillOnlyAttentionBackend) + + +@pytest.mark.parametrize("attn_type", ["ENCODER_DECODER"]) +@pytest.mark.parametrize("dtype", ["float", "half", "bfloat16"]) +def test_ENCODER_DECODER_not_supported(dtype: str, attn_type: str): + attention_impls = AttentionImpls[dtype] + + for attention_impl in attention_impls: + with pytest.raises(NotImplementedError): + get_attn_backend(attention_impl, attn_type) + + +def test_not_supported_backend(): + attention_impls = ["not_supported_backend", 0, 1.0] + + for attention_impl in attention_impls: + with pytest.raises(ValueError): + selected_backend = _Backend.backend_name_to_enum(attention_impl) + AttnBackend.get_backend_cls(selected_backend) + + +def test_not_supported_attn_type(): + attn_types = ["not_supported_attn_type", 0, 1.0] + + for attn_type in attn_types: + with pytest.raises(ValueError): + AttentionType.attn_type_name_to_enum(attn_type) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 2bc36ff18a96..c16e4126c016 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -18,6 +18,19 @@ class AttentionType(Enum): ENCODER = auto() # Encoder attention between previous layer Q/K/V ENCODER_DECODER = auto() # Attention between dec. Q and enc. K/V + @staticmethod + def attn_type_name_to_enum(attn_type: str) -> "AttentionType": + assert attn_type is not None + + attn_type_members = AttentionType.__members__ + if attn_type not in attn_type_members: + raise ValueError( + f"Invalid attn_type '{attn_type}'. 
" + f"Available backends: {', '.join(attn_type_members)} " + "(case-sensitive).") + + return AttentionType[attn_type] + class AttentionBackend(ABC): """Abstract class for attention backends.""" diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index ecf964fa49d9..4e1542770470 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn -from vllm.attention import AttentionMetadata, AttentionType +from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType from vllm.attention.selector import get_attn_backend from vllm.config import CacheConfig from vllm.model_executor.layers.quantization.base_config import ( @@ -36,6 +36,7 @@ def __init__( blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, prefix: str = "", + attn_backend: Optional[AttentionBackend] = None, ) -> None: super().__init__() if cache_config is not None: @@ -73,14 +74,18 @@ def __init__( self.quant_method = quant_method self.quant_method.create_weights(self) - # During model initialization, the default dtype is set as the model - # weight and activation dtype. - dtype = torch.get_default_dtype() - attn_backend = get_attn_backend(num_heads, head_size, num_kv_heads, - sliding_window, dtype, kv_cache_dtype, - block_size, blocksparse_params - is not None) - impl_cls = attn_backend.get_impl_cls() + if attn_backend is None: + # During model initialization, the default dtype is set as the model + # weight and activation dtype. + + dtype = torch.get_default_dtype() + self.attn_backend = get_attn_backend( + num_heads, head_size, num_kv_heads, sliding_window, dtype, + kv_cache_dtype, block_size, blocksparse_params is not None)() + else: + self.attn_backend = attn_backend + + impl_cls = self.attn_backend.get_impl_cls() self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, blocksparse_params, logits_soft_cap) @@ -94,6 +99,11 @@ def forward( attn_metadata: AttentionMetadata, attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: + if hasattr(self.attn_backend, "attn_type"): + return self.impl.forward(query, key, value, kv_cache, + attn_metadata, self._k_scale, + self._v_scale, + self.attn_backend.attn_type) return self.impl.forward(query, key, diff --git a/vllm/attention/prefill_only/__init__.py b/vllm/attention/prefill_only/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/attention/prefill_only/abstract.py b/vllm/attention/prefill_only/abstract.py new file mode 100644 index 000000000000..292b03b668af --- /dev/null +++ b/vllm/attention/prefill_only/abstract.py @@ -0,0 +1,125 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Dict, Generic, List, Optional, Type, TypeVar + +import torch + +from vllm.attention.backends.abstract import AttentionType +from vllm.utils import is_pin_memory_available + +pin_memory = is_pin_memory_available() + + +class PrefillOnlyAttentionBackend(ABC): + + def __init__(self, attn_type: AttentionType): + if attn_type == AttentionType.ENCODER_DECODER: + raise NotImplementedError("Encoder/decoder cross-attention " + "are not implemented for " + "PrefillOnlyAttentionBackend") + + self._attn_type = attn_type + + @property + def attn_type(self) -> AttentionType: + return self._attn_type + + @staticmethod + @abstractmethod + def get_name() -> str: + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_impl_cls() -> 
Type["PrefillOnlyAttentionImpl"]: + raise NotImplementedError + + @staticmethod + def get_metadata_cls() -> Type["PrefillOnlyAttentionMetadata"]: + return PrefillOnlyAttentionMetadata + + @classmethod + def make_metadata(cls, *args, **kwargs) -> "PrefillOnlyAttentionMetadata": + return cls.get_metadata_cls()(*args, **kwargs) + + @staticmethod + def get_builder_cls() -> Type["PrefillOnlyAttentionMetadataBuilder"]: + return PrefillOnlyAttentionMetadataBuilder + + @classmethod + def make_metadata_builder( + cls, *args, **kwargs) -> "PrefillOnlyAttentionMetadataBuilder": + return cls.get_builder_cls()(*args, **kwargs) + + +@dataclass +class PrefillOnlyAttentionMetadata: + max_seq_len: int + seq_lens: List[int] + + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] + + def to(self, device, non_blocking=False): + for k, v in self.__dict__.items(): + if isinstance(v, torch.Tensor): + self.__dict__[k] = v.to(device, non_blocking=non_blocking) + + return self + + +T = TypeVar("T", bound=PrefillOnlyAttentionMetadata) + + +class PrefillOnlyAttentionMetadataBuilder(Generic[T]): + + def __call__(self, seq_lens: List[int]): + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.long, + pin_memory=pin_memory, + device="cpu") + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device="cpu") + torch.cumsum(seq_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + + return PrefillOnlyAttentionMetadata(seq_lens=seq_lens, + max_seq_len=max(seq_lens), + seq_start_loc=seq_start_loc) + + +class PrefillOnlyAttentionImpl(ABC): + + @abstractmethod + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + ) -> None: + raise NotImplementedError + + @abstractmethod + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Optional[torch.Tensor], + attn_metadata: T, + k_scale: float = 1.0, + v_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: + raise NotImplementedError diff --git a/vllm/attention/prefill_only/flash_attn.py b/vllm/attention/prefill_only/flash_attn.py new file mode 100644 index 000000000000..63566c2bf629 --- /dev/null +++ b/vllm/attention/prefill_only/flash_attn.py @@ -0,0 +1,126 @@ +from typing import Any, Dict, List, Optional, Type + +import torch + +from vllm.attention.prefill_only.abstract import (AttentionType, + PrefillOnlyAttentionBackend, + PrefillOnlyAttentionImpl, + PrefillOnlyAttentionMetadata) + + +class PrefillOnlyFlashAttentionBackend(PrefillOnlyAttentionBackend): + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [32, 64, 96, 128, 160, 192, 224, 256] + + @staticmethod + def get_name() -> str: + return "flash_attn" + + @staticmethod + def get_impl_cls() -> Type["PrefillOnlyFlashAttentionImpl"]: + return PrefillOnlyFlashAttentionImpl + + +class PrefillOnlyFlashAttentionImpl(PrefillOnlyAttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + 
blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + ) -> None: + if blocksparse_params is not None: + raise ValueError( + "FlashAttention does not support block-sparse attention.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = ((sliding_window, sliding_window) + if sliding_window is not None else (-1, -1)) + self.kv_cache_dtype = kv_cache_dtype + if logits_soft_cap is None: + # In flash-attn, setting logits_soft_cap as 0 means no soft cap. + logits_soft_cap = 0 + self.logits_soft_cap = logits_soft_cap + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + if sliding_window is not None: + # NOTE(woosuk): flash-attn's sliding window does not work with + # paged KV cache. + raise ValueError( + "Sliding window is not supported in FlashAttention.") + + support_head_sizes = ( + PrefillOnlyFlashAttentionBackend.get_supported_head_sizes()) + if head_size not in support_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by FlashAttention. " + f"Supported head sizes are: {support_head_sizes}.") + + from vllm_flash_attn import flash_attn_varlen_func + + self.flash_attn_varlen_func = flash_attn_varlen_func + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Optional[torch.Tensor], + attn_metadata: PrefillOnlyAttentionMetadata, + k_scale: float = 1.0, + v_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: + + assert kv_cache is None + + if attn_type == AttentionType.ENCODER: + causal = False + elif attn_type == AttentionType.DECODER: + causal = True + else: + raise NotImplementedError("Encoder/decoder cross-attention " + "are not implemented for " + "PrefillOnlyFlashAttentionImpl") + + assert k_scale == 1.0 and v_scale == 1.0, ( + "key/v_scale is not supported in FlashAttention.") + + num_tokens, hidden_size = query.shape + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + attn_output = self.flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=attn_metadata.seq_start_loc, + cu_seqlens_k=attn_metadata.seq_start_loc, + max_seqlen_q=attn_metadata.max_seq_len, + max_seqlen_k=attn_metadata.max_seq_len, + softmax_scale=self.scale, + causal=causal, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + softcap=self.logits_soft_cap, + ) + + # Reshape the output tensor. 
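+        # flash_attn_varlen_func returns (num_tokens, num_heads, head_size);
+        # flatten the head dimensions back into hidden_size for the caller.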
+ return attn_output.view(num_tokens, hidden_size) diff --git a/vllm/attention/prefill_only/flashinfer.py b/vllm/attention/prefill_only/flashinfer.py new file mode 100644 index 000000000000..846d74a27f2d --- /dev/null +++ b/vllm/attention/prefill_only/flashinfer.py @@ -0,0 +1,22 @@ +from typing import Type + +from vllm.attention.prefill_only.flash_attn import ( + PrefillOnlyFlashAttentionBackend, PrefillOnlyFlashAttentionImpl) + + +class PrefillOnlyFlashInferBackend(PrefillOnlyFlashAttentionBackend): + + @staticmethod + def get_name() -> str: + return "flashinfer" + + @staticmethod + def get_impl_cls() -> Type["PrefillOnlyFlashInferImpl"]: + return PrefillOnlyFlashInferImpl + + +class PrefillOnlyFlashInferImpl(PrefillOnlyFlashAttentionImpl): + # Because prefill only models do not involve kv cache, + # When using Flashinfer backend in prefill only models, + # you are actually using FLASH ATTN backend + pass diff --git a/vllm/attention/prefill_only/selector.py b/vllm/attention/prefill_only/selector.py new file mode 100644 index 000000000000..58c6d2662f62 --- /dev/null +++ b/vllm/attention/prefill_only/selector.py @@ -0,0 +1,140 @@ +import enum +from typing import Optional + +import torch + +import vllm.envs as envs +from vllm.attention.prefill_only.abstract import AttentionType +from vllm.logger import init_logger +from vllm.platforms import current_platform + +logger = init_logger(__name__) + + +class _Backend(enum.Enum): + FLASH_ATTN = enum.auto() + XFORMERS = enum.auto() + ROCM_FLASH = enum.auto() + TORCH_SDPA = enum.auto() + OPENVINO = enum.auto() + FLASHINFER = enum.auto() + PALLAS = enum.auto() + IPEX = enum.auto() + TORCH_NAIVE = enum.auto() + + @staticmethod + def backend_name_to_enum(backend_name: str) -> "_Backend": + assert backend_name is not None + + backend_members = _Backend.__members__ + if backend_name not in backend_members: + raise ValueError( + f"Invalid attention backend '{backend_name}'. 
" + f"Available backends: {', '.join(backend_members)} " + "(case-sensitive).") + + return _Backend[backend_name] + + +class AttnBackend: + + @classmethod + def from_engine(cls, engine): + model_config = engine.engine_config.model_config + num_heads = model_config.get_num_attention_heads() + head_size = model_config.get_head_size() + num_kv_heads = model_config.get_num_kv_heads() + sliding_window = model_config.get_sliding_window() + dtype = model_config.dtype + + backend = cls.which_attn_to_use(num_heads, head_size, num_kv_heads, + sliding_window, dtype) + + backend_cls = cls.get_backend_cls(backend) + + attn_type = AttentionType.attn_type_name_to_enum( + engine.workflow.attn_type) + + return backend_cls(attn_type) + + @staticmethod + def get_backend_cls(backend): + if backend == _Backend.FLASH_ATTN: + logger.info("Using FLASH ATTN backend.") + from vllm.attention.prefill_only.flash_attn import ( # noqa: E501 + PrefillOnlyFlashAttentionBackend) + return PrefillOnlyFlashAttentionBackend + if backend == _Backend.XFORMERS: + logger.info("Using XFormers backend.") + from vllm.attention.prefill_only.xformers import ( # noqa: E501 + PrefillOnlyXFormersBackend) + return PrefillOnlyXFormersBackend + elif backend == _Backend.TORCH_SDPA: + logger.info("Using Torch SDPA backend.") + from vllm.attention.prefill_only.torch_sdpa import ( # noqa: E501 + PrefillOnlyTorchSDPABackend) + return PrefillOnlyTorchSDPABackend + elif backend == _Backend.FLASHINFER: + logger.info("Using Flashinfer backend.") + logger.info("When using Flashinfer backend in encode only models, " + "you are actually using FLASH ATTN backend") + from vllm.attention.prefill_only.flashinfer import ( # noqa: E501 + PrefillOnlyFlashInferBackend) + return PrefillOnlyFlashInferBackend + elif backend == _Backend.TORCH_NAIVE: + logger.info("Using Torch naive backend.") + from vllm.attention.prefill_only.torch_naive import ( # noqa: E501 + PrefillOnlyTorchNAIVEBackend) + return PrefillOnlyTorchNAIVEBackend + else: + raise ValueError("Invalid attention backend.") + + @classmethod + def which_attn_to_use(cls, num_heads: int, head_size: int, + num_kv_heads: int, sliding_window: Optional[int], + dtype: torch.dtype): + # Default case. + selected_backend = _Backend.FLASH_ATTN + + # get_env_variable_attn_backend + # Check the environment variable and override if specified + backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND + if backend_by_env_var is not None: + selected_backend = _Backend.backend_name_to_enum( + backend_by_env_var) + + # FlashAttn in NVIDIA GPUs. + if selected_backend == _Backend.FLASH_ATTN: + if current_platform.get_device_capability()[0] < 8: + # Volta and Turing NVIDIA GPUs. + logger.info( + "Cannot use FlashAttention-2 backend for Volta and Turing " + "GPUs.") + selected_backend = _Backend.XFORMERS + elif dtype not in (torch.float16, torch.bfloat16): + logger.info( + "Cannot use FlashAttention-2 backend for dtype other than " + "torch.float16 or torch.bfloat16.") + selected_backend = _Backend.XFORMERS + elif sliding_window is not None: + logger.info( + "Cannot use FlashAttention-2 backend due to sliding window." 
+ ) + selected_backend = _Backend.XFORMERS + + return selected_backend + + +AttentionImpls_fp32 = ["TORCH_SDPA", "XFORMERS", "TORCH_NAIVE"] +AttentionImpls_fp16 = [ + "FLASH_ATTN", "TORCH_SDPA", "XFORMERS", "FLASHINFER", "TORCH_NAIVE" +] +AttentionImpls_bf16 = [ + "FLASH_ATTN", "TORCH_SDPA", "XFORMERS", "FLASHINFER", "TORCH_NAIVE" +] + +AttentionImpls = { + "float": AttentionImpls_fp32, + "half": AttentionImpls_fp16, + "bfloat16": AttentionImpls_bf16, +} \ No newline at end of file diff --git a/vllm/attention/prefill_only/torch_naive.py b/vllm/attention/prefill_only/torch_naive.py new file mode 100644 index 000000000000..ac9df84a101b --- /dev/null +++ b/vllm/attention/prefill_only/torch_naive.py @@ -0,0 +1,150 @@ +import math +from typing import Any, Dict, List, Optional, Type + +import torch + +from vllm.attention.prefill_only.abstract import (AttentionType, + PrefillOnlyAttentionBackend, + PrefillOnlyAttentionImpl, + PrefillOnlyAttentionMetadata) +from vllm.utils import is_pin_memory_available + +pin_memory = is_pin_memory_available() + + +class PrefillOnlyTorchNAIVEBackend(PrefillOnlyAttentionBackend): + + @staticmethod + def get_name() -> str: + return "torch_naive" + + @staticmethod + def get_impl_cls() -> Type["PrefillOnlyTorchNaiveBackendImpl"]: + return PrefillOnlyTorchNaiveBackendImpl + + +class PrefillOnlyTorchNaiveBackendImpl(PrefillOnlyAttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + ) -> None: + if blocksparse_params is not None: + raise ValueError( + "Torch Naive does not support block-sparse attention.") + if logits_soft_cap is not None: + raise ValueError("Torch Naive does not support logits soft cap.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = sliding_window + self.kv_cache_dtype = kv_cache_dtype + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.need_mask = (self.alibi_slopes is not None + or self.sliding_window is not None) + + if kv_cache_dtype != "auto": + raise NotImplementedError( + "Torch Naive backend does not support FP8 KV cache. " + "Please use xFormers backend instead.") + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Optional[torch.Tensor], + attn_metadata: PrefillOnlyAttentionMetadata, + k_scale: float = 1.0, + v_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: + + assert kv_cache is None + + assert k_scale == 1.0 and v_scale == 1.0, ( + "key/v_scale is not supported in Torch Naive.") + + if attn_type == AttentionType.ENCODER: + causal = False + elif attn_type == AttentionType.DECODER: + causal = True + else: + raise NotImplementedError("Encoder/decoder cross-attention " + "are not implemented for " + "PrefillOnlyTorchNaiveBackendImpl") + + num_tokens, hidden_size = query.shape + + # Reshape the query, key, and value tensors. 
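+        # Inputs arrive packed as (num_tokens, hidden_size); unflatten the
+        # head dimension, expand KV heads for GQA/MQA, then move heads ahead
+        # of the token dimension for the matmul-based attention below.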
+ query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, dim=1) + value = value.repeat_interleave(self.num_queries_per_kv, dim=1) + + query = query.movedim(0, query.dim() - 2) + key = key.movedim(0, key.dim() - 2) + value = value.movedim(0, value.dim() - 2) + + start = 0 + output = torch.empty((num_tokens, self.num_heads, self.head_size), + dtype=query.dtype, + device=query.device) + + for seq_len in attn_metadata.seq_lens: + end = start + seq_len + sub_out = scaled_dot_product_attention( + query[None, :, start:end, :], + key[None, :, start:end, :], + value[None, :, start:end, :], + is_causal=causal, + scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0) + output[start:end, :, :] = sub_out + start = end + + # Reshape the output tensor. + return output.view(-1, self.num_heads * self.head_size) + + +def scaled_dot_product_attention(query, + key, + value, + attn_mask=None, + is_causal=False, + scale=None) -> torch.Tensor: + L, S = query.size(-2), key.size(-2) + scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale + attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) + if is_causal: + assert attn_mask is None + temp_mask = torch.ones(L, S, dtype=torch.bool, + device=query.device).tril(diagonal=0) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: + attn_bias += attn_mask + attn_weight = query @ key.transpose(-2, -1) * scale_factor + attn_weight += attn_bias + attn_weight = torch.softmax(attn_weight, dim=-1) + return attn_weight @ value diff --git a/vllm/attention/prefill_only/torch_sdpa.py b/vllm/attention/prefill_only/torch_sdpa.py new file mode 100644 index 000000000000..8f0d806fd669 --- /dev/null +++ b/vllm/attention/prefill_only/torch_sdpa.py @@ -0,0 +1,124 @@ +from typing import Any, Dict, List, Optional, Type + +import torch +from torch.nn.functional import scaled_dot_product_attention + +from vllm.attention.prefill_only.abstract import (AttentionType, + PrefillOnlyAttentionBackend, + PrefillOnlyAttentionImpl, + PrefillOnlyAttentionMetadata) +from vllm.utils import is_pin_memory_available + +pin_memory = is_pin_memory_available() + + +class PrefillOnlyTorchSDPABackend(PrefillOnlyAttentionBackend): + + @staticmethod + def get_name() -> str: + return "torch_sdpa" + + @staticmethod + def get_impl_cls() -> Type["PrefillOnlyTorchSDPABackendImpl"]: + return PrefillOnlyTorchSDPABackendImpl + + +class PrefillOnlyTorchSDPABackendImpl(PrefillOnlyAttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + ) -> None: + if blocksparse_params is not None: + raise ValueError( + "Torch SPDA does not support block-sparse attention.") + if logits_soft_cap is not None: + raise ValueError("Torch SPDA does not support logits soft cap.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, 
dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = sliding_window + self.kv_cache_dtype = kv_cache_dtype + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.need_mask = (self.alibi_slopes is not None + or self.sliding_window is not None) + + if kv_cache_dtype != "auto": + raise NotImplementedError( + "Torch SDPA backend does not support FP8 KV cache. " + "Please use xFormers backend instead.") + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Optional[torch.Tensor], + attn_metadata: PrefillOnlyAttentionMetadata, + k_scale: float = 1.0, + v_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: + + assert kv_cache is None + + assert k_scale == 1.0 and v_scale == 1.0, ( + "key/v_scale is not supported in TorchSDPA.") + + if attn_type == AttentionType.ENCODER: + causal = False + elif attn_type == AttentionType.DECODER: + causal = True + else: + raise NotImplementedError("Encoder/decoder cross-attention " + "are not implemented for " + "PrefillOnlyTorchSDPABackendImpl") + + num_tokens, hidden_size = query.shape + + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, dim=1) + value = value.repeat_interleave(self.num_queries_per_kv, dim=1) + + query = query.movedim(0, query.dim() - 2) + key = key.movedim(0, key.dim() - 2) + value = value.movedim(0, value.dim() - 2) + + start = 0 + output = torch.empty((num_tokens, self.num_heads, self.head_size), + dtype=query.dtype, + device=query.device) + + for seq_len in attn_metadata.seq_lens: + end = start + seq_len + sub_out = scaled_dot_product_attention( + query[None, :, start:end, :], + key[None, :, start:end, :], + value[None, :, start:end, :], + dropout_p=0.0, + is_causal=causal, + scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0) + output[start:end, :, :] = sub_out + start = end + + # Reshape the output tensor. 
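+        # output was filled per sequence as (num_tokens, num_heads, head_size);
+        # flatten the head dimensions back into hidden_size.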
+ return output.view(-1, self.num_heads * self.head_size) diff --git a/vllm/attention/prefill_only/xformers.py b/vllm/attention/prefill_only/xformers.py new file mode 100644 index 000000000000..4497afadc49e --- /dev/null +++ b/vllm/attention/prefill_only/xformers.py @@ -0,0 +1,121 @@ +from typing import Any, Dict, List, Optional, Type + +import torch +from xformers import ops as xops +from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask, + BlockDiagonalMask) + +from vllm.attention.prefill_only.abstract import (AttentionType, + PrefillOnlyAttentionBackend, + PrefillOnlyAttentionImpl, + PrefillOnlyAttentionMetadata) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class PrefillOnlyXFormersBackend(PrefillOnlyAttentionBackend): + + @staticmethod + def get_name() -> str: + return "xformers" + + @staticmethod + def get_impl_cls() -> Type["PrefillOnlyXFormersImpl"]: + return PrefillOnlyXFormersImpl + + +class PrefillOnlyXFormersImpl(PrefillOnlyAttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + ) -> None: + if blocksparse_params is not None: + raise ValueError( + "XFormers does not support block-sparse attention.") + if logits_soft_cap is not None: + raise ValueError( + "XFormers does not support attention logits soft capping.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = sliding_window + self.kv_cache_dtype = kv_cache_dtype + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Optional[torch.Tensor], + attn_metadata: PrefillOnlyAttentionMetadata, + k_scale: float = 1.0, + v_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: + + assert kv_cache is None + + if attn_type == AttentionType.ENCODER: + causal = False + elif attn_type == AttentionType.DECODER: + causal = True + else: + raise NotImplementedError("Encoder/decoder cross-attention " + "are not implemented for " + "PrefillOnlyXFormersImpl") + original_query = query + + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if self.num_kv_heads != self.num_heads: + # GQA/MQA requires the shape [B, M, G, H, K]. + # Note that the output also has the same shape (which is different + # from a spec from the doc). + query = query.view(query.shape[0], self.num_kv_heads, + self.num_queries_per_kv, query.shape[-1]) + key = key[:, :, + None, :].expand(key.shape[0], self.num_kv_heads, + self.num_queries_per_kv, key.shape[-1]) + value = value[:, :, + None, :].expand(value.shape[0], self.num_kv_heads, + self.num_queries_per_kv, + value.shape[-1]) + + if causal: + attn_bias = BlockDiagonalCausalMask.from_seqlens( + attn_metadata.seq_lens) + else: + attn_bias = BlockDiagonalMask.from_seqlens(attn_metadata.seq_lens) + + # Add the batch dimension. 
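+        # xformers' memory-efficient attention expects a leading batch
+        # dimension; all sequences are packed into a single batch entry and
+        # kept separate by the block-diagonal attn_bias built above.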
+ query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) + + out = xops.memory_efficient_attention_forward(query, + key, + value, + p=0.0, + attn_bias=attn_bias, + scale=self.scale) + return out.view_as(original_query)
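
For reviewers, the sketch below shows how the pieces added in this diff compose, mirroring tests/attention/prefill_only/test_basic_correctness.py: the selector resolves a backend name to a PrefillOnlyAttentionBackend, the backend's metadata builder turns per-sequence lengths into packed varlen metadata, and the backend is injected into the Attention layer via the new attn_backend argument (bypassing vllm.attention.selector.get_attn_backend). It is illustrative only; it assumes a CUDA device with vllm_flash_attn installed, and the head counts, sequence lengths, and dtype are arbitrary.

```python
import torch

from vllm.attention.layer import Attention
from vllm.attention.prefill_only.abstract import AttentionType
from vllm.attention.prefill_only.selector import AttnBackend, _Backend

# Resolve the backend by (case-sensitive) name and bind the attention type.
selected_backend = _Backend.backend_name_to_enum("FLASH_ATTN")
backend_cls = AttnBackend.get_backend_cls(selected_backend)
attn_backend = backend_cls(AttentionType.attn_type_name_to_enum("ENCODER"))

num_heads, num_kv_heads, head_dim = 8, 4, 64
seq_lens = [5, 7, 11]
num_tokens = sum(seq_lens)

# Prefill-only attention consumes packed (num_tokens, hidden_size) tensors.
query = torch.rand(num_tokens, num_heads * head_dim,
                   dtype=torch.half, device="cuda:0")
key = torch.rand(num_tokens, num_kv_heads * head_dim,
                 dtype=torch.half, device="cuda:0")
value = torch.rand(num_tokens, num_kv_heads * head_dim,
                   dtype=torch.half, device="cuda:0")

# The builder computes max_seq_len and the cumulative seq_start_loc offsets.
attn_metadata = attn_backend.make_metadata_builder()(seq_lens=seq_lens)
attn_metadata = attn_metadata.to("cuda:0")

# Passing attn_backend skips the KV-cache-oriented backend selection.
attn = Attention(num_heads,
                 head_dim,
                 scale=head_dim**-0.5,
                 num_kv_heads=num_kv_heads,
                 attn_backend=attn_backend)

out = attn.forward(query, key, value,
                   kv_cache=None,
                   attn_metadata=attn_metadata)
assert out.shape == (num_tokens, num_heads * head_dim)
```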