
Commit

prefill only attention
noooop committed Oct 7, 2024
1 parent 1fbeec4 commit c0b0d3c
Showing 7 changed files with 23 additions and 119 deletions.
2 changes: 1 addition & 1 deletion vllm/attention/layer.py
@@ -74,7 +74,7 @@ def __init__(
         self.quant_method = quant_method
         self.quant_method.create_weights(self)

-        self.attn_backend: Union[AttentionBackend, type(AttentionBackend)]
+        self.attn_backend: Union[AttentionBackend, type[AttentionBackend]]
         if attn_backend is None:
             # During model initialization, the default dtype is set as the model
             # weight and activation dtype.
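The one-character change above is not cosmetic: type(AttentionBackend) is a runtime call that returns the metaclass, so the old annotation never actually meant "AttentionBackend or a subclass of it", while type[AttentionBackend] is the typing syntax for exactly that. A minimal sketch of the difference, using stand-in classes rather than the real vLLM ones:

    from typing import Union

    class AttentionBackend:                  # stand-in for the real vLLM base class
        pass

    class FlashBackend(AttentionBackend):    # hypothetical subclass for illustration
        pass

    # type(AttentionBackend) is evaluated at runtime and returns the metaclass,
    # so the old annotation did not describe "a backend class" at all.
    print(type(AttentionBackend))            # <class 'type'>

    # type[AttentionBackend] is the static form: AttentionBackend or any subclass,
    # referred to as a class object rather than an instance.
    attn_backend: Union[AttentionBackend, type[AttentionBackend]] = FlashBackend
    attn_backend = FlashBackend()            # an instance also satisfies the Union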
6 changes: 2 additions & 4 deletions vllm/attention/prefill_only/abstract.py
@@ -35,18 +35,16 @@ def get_impl_cls() -> Type["PrefillOnlyAttentionImpl"]:
         raise NotImplementedError

     @staticmethod
-    @abstractmethod
     def get_metadata_cls() -> Type["PrefillOnlyAttentionMetadata"]:
-        raise NotImplementedError
+        return PrefillOnlyAttentionMetadata

     @classmethod
     def make_metadata(cls, *args, **kwargs) -> "PrefillOnlyAttentionMetadata":
         return cls.get_metadata_cls()(*args, **kwargs)

     @staticmethod
-    @abstractmethod
     def get_builder_cls() -> Type["PrefillOnlyAttentionMetadataBuilder"]:
-        raise NotImplementedError
+        return PrefillOnlyAttentionMetadataBuilder

     @classmethod
     def make_metadata_builder(
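Dropping @abstractmethod and returning the base classes turns get_metadata_cls() and get_builder_cls() into overridable defaults rather than hooks every backend must implement, which is why the per-backend metadata and builder subclasses in the files below can be deleted. A rough sketch of the resulting pattern, with simplified stand-ins for the real vLLM classes (the real metadata dataclass has more fields and the real builder is generic):

    from dataclasses import dataclass
    from typing import Type

    @dataclass
    class PrefillOnlyAttentionMetadata:
        max_seq_len: int = 0                 # simplified; the real dataclass has more fields

    class PrefillOnlyAttentionMetadataBuilder:
        pass                                 # simplified stand-in

    class PrefillOnlyAttentionBackend:

        @staticmethod
        def get_metadata_cls() -> Type["PrefillOnlyAttentionMetadata"]:
            return PrefillOnlyAttentionMetadata        # default instead of abstract

        @classmethod
        def make_metadata(cls, *args, **kwargs) -> "PrefillOnlyAttentionMetadata":
            return cls.get_metadata_cls()(*args, **kwargs)

        @staticmethod
        def get_builder_cls() -> Type["PrefillOnlyAttentionMetadataBuilder"]:
            return PrefillOnlyAttentionMetadataBuilder  # default instead of abstract

    # A backend that is fine with the generic metadata needs no overrides at all:
    class PrefillOnlyFlashAttentionBackend(PrefillOnlyAttentionBackend):
        pass

    meta = PrefillOnlyFlashAttentionBackend.make_metadata(max_seq_len=128)
    print(type(meta).__name__)               # PrefillOnlyAttentionMetadata

Because make_metadata() goes through cls.get_metadata_cls(), a backend that does need specialized metadata can still override the default and everything downstream keeps working.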
27 changes: 4 additions & 23 deletions vllm/attention/prefill_only/flash_attn.py
@@ -1,11 +1,11 @@
-from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Type

 import torch

-from vllm.attention.prefill_only.abstract import (
-    AttentionType, PrefillOnlyAttentionBackend, PrefillOnlyAttentionImpl,
-    PrefillOnlyAttentionMetadata, PrefillOnlyAttentionMetadataBuilder)
+from vllm.attention.prefill_only.abstract import (AttentionType,
+                                                  PrefillOnlyAttentionBackend,
+                                                  PrefillOnlyAttentionImpl,
+                                                  PrefillOnlyAttentionMetadata)


class PrefillOnlyFlashAttentionBackend(PrefillOnlyAttentionBackend):
@@ -22,25 +22,6 @@ def get_name() -> str:
     def get_impl_cls() -> Type["PrefillOnlyFlashAttentionImpl"]:
         return PrefillOnlyFlashAttentionImpl

-    @staticmethod
-    def get_metadata_cls() -> Type["PrefillOnlyFlashAttentionMetadata"]:
-        return PrefillOnlyFlashAttentionMetadata
-
-    @staticmethod
-    def get_builder_cls() -> Type["PrefillOnlyAttentionMetadataBuilder"]:
-        return PrefillOnlyFlashAttentionMetadataBuilder
-
-
-@dataclass
-class PrefillOnlyFlashAttentionMetadata(PrefillOnlyAttentionMetadata):
-    pass
-
-
-class PrefillOnlyFlashAttentionMetadataBuilder(
-        PrefillOnlyAttentionMetadataBuilder[PrefillOnlyFlashAttentionMetadata]
-):
-    pass


class PrefillOnlyFlashAttentionImpl(PrefillOnlyAttentionImpl):

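With the defaults in place, each backend module in vllm/attention/prefill_only keeps only its Backend and Impl classes: the per-backend *Metadata and *MetadataBuilder subclasses, which were empty pass bodies, are removed, and forward() is annotated with the shared PrefillOnlyAttentionMetadata. A hypothetical skeleton of what such a module reduces to, assuming this branch's prefill_only package is importable (implementation bodies elided):

    from typing import Type

    from vllm.attention.prefill_only.abstract import (AttentionType,
                                                       PrefillOnlyAttentionBackend,
                                                       PrefillOnlyAttentionImpl,
                                                       PrefillOnlyAttentionMetadata)


    class PrefillOnlyFlashAttentionBackend(PrefillOnlyAttentionBackend):

        @staticmethod
        def get_name() -> str:
            return "flash_attn"              # assumed name; only the structure matters here

        @staticmethod
        def get_impl_cls() -> Type["PrefillOnlyFlashAttentionImpl"]:
            return PrefillOnlyFlashAttentionImpl
        # get_metadata_cls() and get_builder_cls() are now the inherited defaults.


    class PrefillOnlyFlashAttentionImpl(PrefillOnlyAttentionImpl):
        # forward(..., attn_metadata: PrefillOnlyAttentionMetadata, ...,
        #         attn_type: AttentionType) goes here; the attention kernel
        # itself is untouched by this commit.
        ...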
23 changes: 1 addition & 22 deletions vllm/attention/prefill_only/flashinfer.py
@@ -1,10 +1,7 @@
-from dataclasses import dataclass
 from typing import Type

 from vllm.attention.prefill_only.flash_attn import (
-    PrefillOnlyFlashAttentionBackend, PrefillOnlyFlashAttentionImpl,
-    PrefillOnlyFlashAttentionMetadata,
-    PrefillOnlyFlashAttentionMetadataBuilder)
+    PrefillOnlyFlashAttentionBackend, PrefillOnlyFlashAttentionImpl)


class PrefillOnlyFlashInferBackend(PrefillOnlyFlashAttentionBackend):
@@ -17,24 +14,6 @@ def get_name() -> str:
     def get_impl_cls() -> Type["PrefillOnlyFlashInferImpl"]:
         return PrefillOnlyFlashInferImpl

-    @staticmethod
-    def get_metadata_cls() -> Type["PrefillOnlyFlashInferMetadata"]:
-        return PrefillOnlyFlashInferMetadata
-
-    @staticmethod
-    def get_builder_cls() -> Type["PrefillOnlyFlashInferMetadataBuilder"]:
-        return PrefillOnlyFlashInferMetadataBuilder
-
-
-@dataclass
-class PrefillOnlyFlashInferMetadata(PrefillOnlyFlashAttentionMetadata):
-    pass
-
-
-class PrefillOnlyFlashInferMetadataBuilder(
-        PrefillOnlyFlashAttentionMetadataBuilder):
-    pass


class PrefillOnlyFlashInferImpl(PrefillOnlyFlashAttentionImpl):
# Because prefill only models do not involve kv cache,
28 changes: 5 additions & 23 deletions vllm/attention/prefill_only/torch_naive.py
@@ -1,12 +1,12 @@
 import math
-from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Type

 import torch

-from vllm.attention.prefill_only.abstract import (
-    AttentionType, PrefillOnlyAttentionBackend, PrefillOnlyAttentionImpl,
-    PrefillOnlyAttentionMetadata, PrefillOnlyAttentionMetadataBuilder)
+from vllm.attention.prefill_only.abstract import (AttentionType,
+                                                  PrefillOnlyAttentionBackend,
+                                                  PrefillOnlyAttentionImpl,
+                                                  PrefillOnlyAttentionMetadata)
from vllm.utils import is_pin_memory_available

pin_memory = is_pin_memory_available()
@@ -22,24 +22,6 @@ def get_name() -> str:
     def get_impl_cls() -> Type["PrefillOnlyTorchNaiveBackendImpl"]:
         return PrefillOnlyTorchNaiveBackendImpl

-    @staticmethod
-    def get_metadata_cls() -> Type["PrefillOnlyTorchNaiveMetadata"]:
-        return PrefillOnlyTorchNaiveMetadata
-
-    @staticmethod
-    def get_builder_cls() -> Type["PrefillOnlyAttentionMetadataBuilder"]:
-        return PrefillOnlyTorchNaiveMetadataBuilder
-
-
-@dataclass
-class PrefillOnlyTorchNaiveMetadata(PrefillOnlyAttentionMetadata):
-    pass
-
-
-class PrefillOnlyTorchNaiveMetadataBuilder(
-        PrefillOnlyAttentionMetadataBuilder[PrefillOnlyTorchNaiveMetadata]):
-    pass


class PrefillOnlyTorchNaiveBackendImpl(PrefillOnlyAttentionImpl):

@@ -86,7 +68,7 @@ def forward(
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: Optional[torch.Tensor],
-        attn_metadata: PrefillOnlyTorchNaiveMetadata,
+        attn_metadata: PrefillOnlyAttentionMetadata,
         k_scale: float = 1.0,
         v_scale: float = 1.0,
         attn_type: AttentionType = AttentionType.DECODER,
28 changes: 5 additions & 23 deletions vllm/attention/prefill_only/torch_sdpa.py
@@ -1,12 +1,12 @@
-from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Type

 import torch
 from torch.nn.functional import scaled_dot_product_attention

-from vllm.attention.prefill_only.abstract import (
-    AttentionType, PrefillOnlyAttentionBackend, PrefillOnlyAttentionImpl,
-    PrefillOnlyAttentionMetadata, PrefillOnlyAttentionMetadataBuilder)
+from vllm.attention.prefill_only.abstract import (AttentionType,
+                                                  PrefillOnlyAttentionBackend,
+                                                  PrefillOnlyAttentionImpl,
+                                                  PrefillOnlyAttentionMetadata)
from vllm.utils import is_pin_memory_available

pin_memory = is_pin_memory_available()
@@ -22,24 +22,6 @@ def get_name() -> str:
     def get_impl_cls() -> Type["PrefillOnlyTorchSDPABackendImpl"]:
         return PrefillOnlyTorchSDPABackendImpl

-    @staticmethod
-    def get_metadata_cls() -> Type["PrefillOnlyTorchSDPAMetadata"]:
-        return PrefillOnlyTorchSDPAMetadata
-
-    @staticmethod
-    def get_builder_cls() -> Type["PrefillOnlyAttentionMetadataBuilder"]:
-        return PrefillOnlyTorchSDPAMetadataBuilder
-
-
-@dataclass
-class PrefillOnlyTorchSDPAMetadata(PrefillOnlyAttentionMetadata):
-    pass
-
-
-class PrefillOnlyTorchSDPAMetadataBuilder(
-        PrefillOnlyAttentionMetadataBuilder[PrefillOnlyTorchSDPAMetadata]):
-    pass


class PrefillOnlyTorchSDPABackendImpl(PrefillOnlyAttentionImpl):

@@ -86,7 +68,7 @@ def forward(
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: Optional[torch.Tensor],
-        attn_metadata: PrefillOnlyTorchSDPAMetadata,
+        attn_metadata: PrefillOnlyAttentionMetadata,
         k_scale: float = 1.0,
         v_scale: float = 1.0,
         attn_type: AttentionType = AttentionType.DECODER,
28 changes: 5 additions & 23 deletions vllm/attention/prefill_only/xformers.py
@@ -1,14 +1,14 @@
-from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Type

 import torch
 from xformers import ops as xops
 from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
                                           BlockDiagonalMask)

-from vllm.attention.prefill_only.abstract import (
-    AttentionType, PrefillOnlyAttentionBackend, PrefillOnlyAttentionImpl,
-    PrefillOnlyAttentionMetadata, PrefillOnlyAttentionMetadataBuilder)
+from vllm.attention.prefill_only.abstract import (AttentionType,
+                                                  PrefillOnlyAttentionBackend,
+                                                  PrefillOnlyAttentionImpl,
+                                                  PrefillOnlyAttentionMetadata)
from vllm.logger import init_logger

logger = init_logger(__name__)
@@ -24,24 +24,6 @@ def get_name() -> str:
     def get_impl_cls() -> Type["PrefillOnlyXFormersImpl"]:
         return PrefillOnlyXFormersImpl

-    @staticmethod
-    def get_metadata_cls() -> Type["PrefillOnlyAttentionMetadata"]:
-        return PrefillOnlyXFormersMetadata
-
-    @staticmethod
-    def get_builder_cls() -> Type["PrefillOnlyXFormersMetadataBuilder"]:
-        return PrefillOnlyXFormersMetadataBuilder
-
-
-@dataclass
-class PrefillOnlyXFormersMetadata(PrefillOnlyAttentionMetadata):
-    pass
-
-
-class PrefillOnlyXFormersMetadataBuilder(
-        PrefillOnlyAttentionMetadataBuilder[PrefillOnlyXFormersMetadata]):
-    pass


class PrefillOnlyXFormersImpl(PrefillOnlyAttentionImpl):

@@ -82,7 +64,7 @@ def forward(
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: Optional[torch.Tensor],
-        attn_metadata: PrefillOnlyXFormersMetadata,
+        attn_metadata: PrefillOnlyAttentionMetadata,
         k_scale: float = 1.0,
         v_scale: float = 1.0,
         attn_type: AttentionType = AttentionType.DECODER,
