-
-
Notifications
You must be signed in to change notification settings - Fork 4.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
14 changed files
with
983 additions
and
9 deletions.
There are no files selected for viewing
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
import itertools as it | ||
|
||
import pytest | ||
import torch | ||
import torch.nn.functional as F | ||
|
||
from vllm.attention.layer import Attention | ||
from vllm.attention.prefill_only.abstract import AttentionType | ||
from vllm.attention.prefill_only.selector import (AttentionImpls, AttnBackend, | ||
_Backend) | ||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE | ||
|
||
|
||
def compare_embeddings(embeddings1, embeddings2): | ||
similarities = [ | ||
F.cosine_similarity(torch.tensor(e1), torch.tensor(e2), dim=0) | ||
for e1, e2 in zip(embeddings1, embeddings2) | ||
] | ||
return similarities | ||
|
||
|
||
SEQ_LENS = [1, 2, 3, 5, 7, 11, 13, 17, 19, 23, 29] | ||
|
||
|
||
@pytest.mark.parametrize("head_dim", [64]) | ||
@pytest.mark.parametrize("num_heads", [8, 16]) | ||
@pytest.mark.parametrize("num_kv_heads", [1, 2, 4, 8]) | ||
@pytest.mark.parametrize("dtype", ["float", "half", "bfloat16"]) | ||
@pytest.mark.parametrize("attn_type", ["DECODER", "ENCODER"]) | ||
@pytest.mark.parametrize("n_seqs", list(range(1, len(SEQ_LENS)))) | ||
def test_basic_correctness(head_dim: int, num_heads: int, num_kv_heads: int, | ||
attn_type: str, dtype: str, n_seqs: int): | ||
assert num_heads % num_kv_heads == 0 | ||
|
||
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype] | ||
|
||
attention_impls = AttentionImpls[dtype] | ||
|
||
seq_lens = SEQ_LENS[:n_seqs] | ||
batchsize = sum(seq_lens) | ||
|
||
query = torch.rand((batchsize, num_heads, head_dim), | ||
dtype=torch_dtype, | ||
device="cuda:0").view((batchsize, -1)) | ||
key = torch.rand((batchsize, num_kv_heads, head_dim), | ||
dtype=torch_dtype, | ||
device="cuda:0").view((batchsize, -1)) | ||
value = torch.rand((batchsize, num_kv_heads, head_dim), | ||
dtype=torch_dtype, | ||
device="cuda:0").view((batchsize, -1)) | ||
|
||
impl_outputs_list = [] | ||
|
||
for attention_impl in attention_impls: | ||
selected_backend = _Backend.backend_name_to_enum(attention_impl) | ||
backend_cls = AttnBackend.get_backend_cls(selected_backend) | ||
|
||
attn_type_enum = AttentionType.attn_type_name_to_enum(attn_type) | ||
|
||
attn_backend = backend_cls(attn_type_enum) | ||
scaling = head_dim**-0.5 | ||
|
||
attn = Attention(num_heads, | ||
head_dim, | ||
scale=scaling, | ||
num_kv_heads=num_kv_heads, | ||
attn_backend=attn_backend) | ||
|
||
metadata_builder = attn_backend.make_metadata_builder() | ||
attn_metadata = metadata_builder(seq_lens=seq_lens) | ||
attn_metadata = attn_metadata.to("cuda:0") | ||
|
||
outputs = attn.forward(query, | ||
key, | ||
value, | ||
kv_cache=None, | ||
attn_metadata=attn_metadata) | ||
|
||
impl_outputs_list.append((attention_impl, outputs)) | ||
|
||
tolerance = 1e-2 | ||
for a, b in it.combinations(impl_outputs_list, 2): | ||
similarities = compare_embeddings(a[1], b[1]) | ||
all_similarities = torch.stack(similarities) | ||
|
||
assert torch.all( | ||
(all_similarities <= 1.0 + tolerance) | ||
& (all_similarities >= 1.0 - tolerance) | ||
), f"{a[0]} vs {b[0]}, not all values are within {tolerance} of 1.0" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import pytest | ||
|
||
from vllm.attention.prefill_only.abstract import (AttentionType, | ||
PrefillOnlyAttentionBackend) | ||
from vllm.attention.prefill_only.selector import (AttentionImpls, AttnBackend, | ||
_Backend) | ||
|
||
|
||
def get_attn_backend(attention_impl: str, attn_type: str): | ||
selected_backend = _Backend.backend_name_to_enum(attention_impl) | ||
backend_cls = AttnBackend.get_backend_cls(selected_backend) | ||
|
||
attn_type_enum = AttentionType.attn_type_name_to_enum(attn_type) | ||
|
||
attn_backend = backend_cls(attn_type_enum) | ||
return attn_backend | ||
|
||
|
||
@pytest.mark.parametrize("attn_type", ["DECODER", "ENCODER"]) | ||
@pytest.mark.parametrize("dtype", ["float", "half", "bfloat16"]) | ||
def test_backend(dtype: str, attn_type: str): | ||
attention_impls = AttentionImpls[dtype] | ||
|
||
for attention_impl in attention_impls: | ||
attn_backend = get_attn_backend(attention_impl, attn_type) | ||
|
||
assert isinstance(attn_backend, PrefillOnlyAttentionBackend) | ||
|
||
|
||
@pytest.mark.parametrize("attn_type", ["ENCODER_DECODER"]) | ||
@pytest.mark.parametrize("dtype", ["float", "half", "bfloat16"]) | ||
def test_ENCODER_DECODER_not_supported(dtype: str, attn_type: str): | ||
attention_impls = AttentionImpls[dtype] | ||
|
||
for attention_impl in attention_impls: | ||
with pytest.raises(NotImplementedError): | ||
get_attn_backend(attention_impl, attn_type) | ||
|
||
|
||
def test_not_supported_backend(): | ||
attention_impls = ["not_supported_backend", 0, 1.0] | ||
|
||
for attention_impl in attention_impls: | ||
with pytest.raises(ValueError): | ||
selected_backend = _Backend.backend_name_to_enum(attention_impl) | ||
AttnBackend.get_backend_cls(selected_backend) | ||
|
||
|
||
def test_not_supported_attn_type(): | ||
attn_types = ["not_supported_attn_type", 0, 1.0] | ||
|
||
for attn_type in attn_types: | ||
with pytest.raises(ValueError): | ||
AttentionType.attn_type_name_to_enum(attn_type) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
from abc import ABC, abstractmethod | ||
from dataclasses import dataclass | ||
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar | ||
|
||
import torch | ||
|
||
from vllm.attention.backends.abstract import AttentionType | ||
from vllm.utils import is_pin_memory_available | ||
|
||
pin_memory = is_pin_memory_available() | ||
|
||
|
||
class PrefillOnlyAttentionBackend(ABC): | ||
|
||
def __init__(self, attn_type: AttentionType): | ||
if attn_type == AttentionType.ENCODER_DECODER: | ||
raise NotImplementedError("Encoder/decoder cross-attention " | ||
"are not implemented for " | ||
"PrefillOnlyAttentionBackend") | ||
|
||
self._attn_type = attn_type | ||
|
||
@property | ||
def attn_type(self) -> AttentionType: | ||
return self._attn_type | ||
|
||
@staticmethod | ||
@abstractmethod | ||
def get_name() -> str: | ||
raise NotImplementedError | ||
|
||
@staticmethod | ||
@abstractmethod | ||
def get_impl_cls() -> Type["PrefillOnlyAttentionImpl"]: | ||
raise NotImplementedError | ||
|
||
@staticmethod | ||
def get_metadata_cls() -> Type["PrefillOnlyAttentionMetadata"]: | ||
return PrefillOnlyAttentionMetadata | ||
|
||
@classmethod | ||
def make_metadata(cls, *args, **kwargs) -> "PrefillOnlyAttentionMetadata": | ||
return cls.get_metadata_cls()(*args, **kwargs) | ||
|
||
@staticmethod | ||
def get_builder_cls() -> Type["PrefillOnlyAttentionMetadataBuilder"]: | ||
return PrefillOnlyAttentionMetadataBuilder | ||
|
||
@classmethod | ||
def make_metadata_builder( | ||
cls, *args, **kwargs) -> "PrefillOnlyAttentionMetadataBuilder": | ||
return cls.get_builder_cls()(*args, **kwargs) | ||
|
||
|
||
@dataclass | ||
class PrefillOnlyAttentionMetadata: | ||
max_seq_len: int | ||
seq_lens: List[int] | ||
|
||
# (batch_size + 1,). The cumulative sequence lengths of the sequences in | ||
# the batch, used to index into sequence. E.g., if the sequence length is | ||
# [4, 6], it is [0, 4, 10]. | ||
seq_start_loc: Optional[torch.Tensor] | ||
|
||
def to(self, device, non_blocking=False): | ||
for k, v in self.__dict__.items(): | ||
if isinstance(v, torch.Tensor): | ||
self.__dict__[k] = v.to(device, non_blocking=non_blocking) | ||
|
||
return self | ||
|
||
|
||
T = TypeVar("T", bound=PrefillOnlyAttentionMetadata) | ||
|
||
|
||
class PrefillOnlyAttentionMetadataBuilder(Generic[T]): | ||
|
||
def __call__(self, seq_lens: List[int]): | ||
seq_lens_tensor = torch.tensor(seq_lens, | ||
dtype=torch.long, | ||
pin_memory=pin_memory, | ||
device="cpu") | ||
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, | ||
dtype=torch.int32, | ||
device="cpu") | ||
torch.cumsum(seq_lens_tensor, | ||
dim=0, | ||
dtype=seq_start_loc.dtype, | ||
out=seq_start_loc[1:]) | ||
|
||
return PrefillOnlyAttentionMetadata(seq_lens=seq_lens, | ||
max_seq_len=max(seq_lens), | ||
seq_start_loc=seq_start_loc) | ||
|
||
|
||
class PrefillOnlyAttentionImpl(ABC): | ||
|
||
@abstractmethod | ||
def __init__( | ||
self, | ||
num_heads: int, | ||
head_size: int, | ||
scale: float, | ||
num_kv_heads: Optional[int] = None, | ||
alibi_slopes: Optional[List[float]] = None, | ||
sliding_window: Optional[int] = None, | ||
kv_cache_dtype: str = "auto", | ||
blocksparse_params: Optional[Dict[str, Any]] = None, | ||
logits_soft_cap: Optional[float] = None, | ||
) -> None: | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def forward( | ||
self, | ||
query: torch.Tensor, | ||
key: torch.Tensor, | ||
value: torch.Tensor, | ||
kv_cache: Optional[torch.Tensor], | ||
attn_metadata: T, | ||
k_scale: float = 1.0, | ||
v_scale: float = 1.0, | ||
attn_type: AttentionType = AttentionType.DECODER, | ||
) -> torch.Tensor: | ||
raise NotImplementedError |
Oops, something went wrong.