
Commit

prefill only attention
noooop committed Oct 7, 2024
1 parent f19da64 commit cb4946d
Showing 14 changed files with 983 additions and 9 deletions.
Empty file added tests/attention/__init__.py
Empty file.
89 changes: 89 additions & 0 deletions tests/attention/prefill_only/test_basic_correctness.py
@@ -0,0 +1,89 @@
import itertools as it

import pytest
import torch
import torch.nn.functional as F

from vllm.attention.layer import Attention
from vllm.attention.prefill_only.abstract import AttentionType
from vllm.attention.prefill_only.selector import (AttentionImpls, AttnBackend,
_Backend)
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE


def compare_embeddings(embeddings1, embeddings2):
similarities = [
F.cosine_similarity(torch.tensor(e1), torch.tensor(e2), dim=0)
for e1, e2 in zip(embeddings1, embeddings2)
]
return similarities


SEQ_LENS = [1, 2, 3, 5, 7, 11, 13, 17, 19, 23, 29]


@pytest.mark.parametrize("head_dim", [64])
@pytest.mark.parametrize("num_heads", [8, 16])
@pytest.mark.parametrize("num_kv_heads", [1, 2, 4, 8])
@pytest.mark.parametrize("dtype", ["float", "half", "bfloat16"])
@pytest.mark.parametrize("attn_type", ["DECODER", "ENCODER"])
@pytest.mark.parametrize("n_seqs", list(range(1, len(SEQ_LENS))))
def test_basic_correctness(head_dim: int, num_heads: int, num_kv_heads: int,
attn_type: str, dtype: str, n_seqs: int):
assert num_heads % num_kv_heads == 0

torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]

attention_impls = AttentionImpls[dtype]

seq_lens = SEQ_LENS[:n_seqs]
batchsize = sum(seq_lens)

query = torch.rand((batchsize, num_heads, head_dim),
dtype=torch_dtype,
device="cuda:0").view((batchsize, -1))
key = torch.rand((batchsize, num_kv_heads, head_dim),
dtype=torch_dtype,
device="cuda:0").view((batchsize, -1))
value = torch.rand((batchsize, num_kv_heads, head_dim),
dtype=torch_dtype,
device="cuda:0").view((batchsize, -1))

impl_outputs_list = []

for attention_impl in attention_impls:
selected_backend = _Backend.backend_name_to_enum(attention_impl)
backend_cls = AttnBackend.get_backend_cls(selected_backend)

attn_type_enum = AttentionType.attn_type_name_to_enum(attn_type)

attn_backend = backend_cls(attn_type_enum)
scaling = head_dim**-0.5

attn = Attention(num_heads,
head_dim,
scale=scaling,
num_kv_heads=num_kv_heads,
attn_backend=attn_backend)

metadata_builder = attn_backend.make_metadata_builder()
attn_metadata = metadata_builder(seq_lens=seq_lens)
attn_metadata = attn_metadata.to("cuda:0")

outputs = attn.forward(query,
key,
value,
kv_cache=None,
attn_metadata=attn_metadata)

impl_outputs_list.append((attention_impl, outputs))

tolerance = 1e-2
for a, b in it.combinations(impl_outputs_list, 2):
similarities = compare_embeddings(a[1], b[1])
all_similarities = torch.stack(similarities)

assert torch.all(
(all_similarities <= 1.0 + tolerance)
& (all_similarities >= 1.0 - tolerance)
), f"{a[0]} vs {b[0]}, not all values are within {tolerance} of 1.0"
54 changes: 54 additions & 0 deletions tests/attention/prefill_only/test_enum_verify.py
@@ -0,0 +1,54 @@
import pytest

from vllm.attention.prefill_only.abstract import (AttentionType,
PrefillOnlyAttentionBackend)
from vllm.attention.prefill_only.selector import (AttentionImpls, AttnBackend,
_Backend)


def get_attn_backend(attention_impl: str, attn_type: str):
selected_backend = _Backend.backend_name_to_enum(attention_impl)
backend_cls = AttnBackend.get_backend_cls(selected_backend)

attn_type_enum = AttentionType.attn_type_name_to_enum(attn_type)

attn_backend = backend_cls(attn_type_enum)
return attn_backend


@pytest.mark.parametrize("attn_type", ["DECODER", "ENCODER"])
@pytest.mark.parametrize("dtype", ["float", "half", "bfloat16"])
def test_backend(dtype: str, attn_type: str):
attention_impls = AttentionImpls[dtype]

for attention_impl in attention_impls:
attn_backend = get_attn_backend(attention_impl, attn_type)

assert isinstance(attn_backend, PrefillOnlyAttentionBackend)


@pytest.mark.parametrize("attn_type", ["ENCODER_DECODER"])
@pytest.mark.parametrize("dtype", ["float", "half", "bfloat16"])
def test_ENCODER_DECODER_not_supported(dtype: str, attn_type: str):
attention_impls = AttentionImpls[dtype]

for attention_impl in attention_impls:
with pytest.raises(NotImplementedError):
get_attn_backend(attention_impl, attn_type)


def test_not_supported_backend():
attention_impls = ["not_supported_backend", 0, 1.0]

for attention_impl in attention_impls:
with pytest.raises(ValueError):
selected_backend = _Backend.backend_name_to_enum(attention_impl)
AttnBackend.get_backend_cls(selected_backend)


def test_not_supported_attn_type():
attn_types = ["not_supported_attn_type", 0, 1.0]

for attn_type in attn_types:
with pytest.raises(ValueError):
AttentionType.attn_type_name_to_enum(attn_type)
13 changes: 13 additions & 0 deletions vllm/attention/backends/abstract.py
@@ -18,6 +18,19 @@ class AttentionType(Enum):
ENCODER = auto() # Encoder attention between previous layer Q/K/V
ENCODER_DECODER = auto() # Attention between dec. Q and enc. K/V

@staticmethod
def attn_type_name_to_enum(attn_type: str) -> "AttentionType":
assert attn_type is not None

attn_type_members = AttentionType.__members__
if attn_type not in attn_type_members:
raise ValueError(
f"Invalid attn_type '{attn_type}'. "
f"Available backends: {', '.join(attn_type_members)} "
"(case-sensitive).")

return AttentionType[attn_type]


class AttentionBackend(ABC):
"""Abstract class for attention backends."""
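
A quick usage sketch of the AttentionType.attn_type_name_to_enum helper added above (illustrative only; the lookup is case-sensitive, matching the error message):

from vllm.attention.backends.abstract import AttentionType

# Valid, case-sensitive member names map to the corresponding enum value.
assert AttentionType.attn_type_name_to_enum("DECODER") is AttentionType.DECODER

# Unknown or wrongly cased names (e.g. "decoder") raise ValueError, which is
# what tests/attention/prefill_only/test_enum_verify.py exercises.
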
28 changes: 19 additions & 9 deletions vllm/attention/layer.py
@@ -4,7 +4,7 @@
import torch
import torch.nn as nn

from vllm.attention import AttentionMetadata, AttentionType
from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import (
@@ -36,6 +36,7 @@ def __init__(
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
prefix: str = "",
attn_backend: Optional[AttentionBackend] = None,
) -> None:
super().__init__()
if cache_config is not None:
@@ -73,14 +74,18 @@ def __init__(
self.quant_method = quant_method
self.quant_method.create_weights(self)

# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype = torch.get_default_dtype()
attn_backend = get_attn_backend(num_heads, head_size, num_kv_heads,
sliding_window, dtype, kv_cache_dtype,
block_size, blocksparse_params
is not None)
impl_cls = attn_backend.get_impl_cls()
if attn_backend is None:
# During model initialization, the default dtype is set as the model
# weight and activation dtype.

dtype = torch.get_default_dtype()
self.attn_backend = get_attn_backend(
num_heads, head_size, num_kv_heads, sliding_window, dtype,
kv_cache_dtype, block_size, blocksparse_params is not None)()
else:
self.attn_backend = attn_backend

impl_cls = self.attn_backend.get_impl_cls()
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap)
@@ -94,6 +99,11 @@ def forward(
attn_metadata: AttentionMetadata,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor:
if hasattr(self.attn_backend, "attn_type"):
return self.impl.forward(query, key, value, kv_cache,
attn_metadata, self._k_scale,
self._v_scale,
self.attn_backend.attn_type)

return self.impl.forward(query,
key,
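
Based on the change above and the new tests, a sketch of how a caller can bypass get_attn_backend() by passing a pre-built prefill-only backend to Attention; the "float" dtype key and the choice of the first registered impl are illustrative, not prescribed by the commit:

from vllm.attention.layer import Attention
from vllm.attention.prefill_only.abstract import AttentionType
from vllm.attention.prefill_only.selector import (AttentionImpls, AttnBackend,
                                                  _Backend)

# Pick any implementation registered for fp32 and build the backend directly.
attention_impl = next(iter(AttentionImpls["float"]))
backend_cls = AttnBackend.get_backend_cls(
    _Backend.backend_name_to_enum(attention_impl))
attn_backend = backend_cls(AttentionType.ENCODER)

# With attn_backend supplied, Attention skips the default backend selection,
# and forward() dispatches using attn_backend.attn_type.
attn = Attention(8,
                 64,
                 scale=64**-0.5,
                 num_kv_heads=8,
                 attn_backend=attn_backend)
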
125 changes: 125 additions & 0 deletions vllm/attention/prefill_only/abstract.py
@@ -0,0 +1,125 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar

import torch

from vllm.attention.backends.abstract import AttentionType
from vllm.utils import is_pin_memory_available

pin_memory = is_pin_memory_available()


class PrefillOnlyAttentionBackend(ABC):

def __init__(self, attn_type: AttentionType):
if attn_type == AttentionType.ENCODER_DECODER:
raise NotImplementedError("Encoder/decoder cross-attention "
"are not implemented for "
"PrefillOnlyAttentionBackend")

self._attn_type = attn_type

@property
def attn_type(self) -> AttentionType:
return self._attn_type

@staticmethod
@abstractmethod
def get_name() -> str:
raise NotImplementedError

@staticmethod
@abstractmethod
def get_impl_cls() -> Type["PrefillOnlyAttentionImpl"]:
raise NotImplementedError

@staticmethod
def get_metadata_cls() -> Type["PrefillOnlyAttentionMetadata"]:
return PrefillOnlyAttentionMetadata

@classmethod
def make_metadata(cls, *args, **kwargs) -> "PrefillOnlyAttentionMetadata":
return cls.get_metadata_cls()(*args, **kwargs)

@staticmethod
def get_builder_cls() -> Type["PrefillOnlyAttentionMetadataBuilder"]:
return PrefillOnlyAttentionMetadataBuilder

@classmethod
def make_metadata_builder(
cls, *args, **kwargs) -> "PrefillOnlyAttentionMetadataBuilder":
return cls.get_builder_cls()(*args, **kwargs)


@dataclass
class PrefillOnlyAttentionMetadata:
max_seq_len: int
seq_lens: List[int]

# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor]

def to(self, device, non_blocking=False):
for k, v in self.__dict__.items():
if isinstance(v, torch.Tensor):
self.__dict__[k] = v.to(device, non_blocking=non_blocking)

return self


T = TypeVar("T", bound=PrefillOnlyAttentionMetadata)


class PrefillOnlyAttentionMetadataBuilder(Generic[T]):

def __call__(self, seq_lens: List[int]):
seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.long,
pin_memory=pin_memory,
device="cpu")
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device="cpu")
torch.cumsum(seq_lens_tensor,
dim=0,
dtype=seq_start_loc.dtype,
out=seq_start_loc[1:])

return PrefillOnlyAttentionMetadata(seq_lens=seq_lens,
max_seq_len=max(seq_lens),
seq_start_loc=seq_start_loc)


class PrefillOnlyAttentionImpl(ABC):

@abstractmethod
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
kv_cache_dtype: str = "auto",
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
) -> None:
raise NotImplementedError

@abstractmethod
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
attn_metadata: T,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor:
raise NotImplementedError
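
A small sketch of the metadata builder defined above, matching the seq_start_loc example in its comments ([4, 6] -> [0, 4, 10]); illustrative, not part of the commit:

from vllm.attention.prefill_only.abstract import (
    PrefillOnlyAttentionMetadataBuilder)

builder = PrefillOnlyAttentionMetadataBuilder()
attn_metadata = builder(seq_lens=[4, 6])

assert attn_metadata.max_seq_len == 6
assert attn_metadata.seq_start_loc.tolist() == [0, 4, 10]

# Tensors are built on the CPU; move them with attn_metadata.to("cuda:0")
# before the forward pass, as the tests do.
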
