Skip to content

Commit

Permalink
Replace scaled_dot_product_attention lowering pass with decomposition (
Browse files Browse the repository at this point in the history
  • Loading branch information
HolyWu authored Dec 16, 2024
1 parent bed5d37 commit 7d0d06d
Show file tree
Hide file tree
Showing 8 changed files with 477 additions and 622 deletions.
32 changes: 0 additions & 32 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2750,38 +2750,6 @@ def aten_ops_max_pool(
)


def attention_validator(
node: Node, settings: Optional[CompilationSettings] = None
) -> bool:
# Currently, `attn_mask` is not supported
return args_bounds_check(node.args, 3) is None


@dynamo_tensorrt_converter(
torch.nn.functional.scaled_dot_product_attention,
capability_validator=attention_validator,
supports_dynamic_shapes=True,
)
def tensorrt_scaled_dot_product_attention(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.attention.scaled_dot_product_attention(
ctx,
target,
SourceIR.TORCHTRT_LOWERED,
name,
args[0],
args[1],
args[2],
args_bounds_check(args, 5, False),
kwargs.get("scale", None),
)


@dynamo_tensorrt_converter(torch.ops.aten.reshape.default, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.view.default, supports_dynamic_shapes=True)
@enforce_tensor_types(
Expand Down
1 change: 0 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
activation,
addmm,
arange,
attention,
cast,
cat,
condition,
Expand Down
165 changes: 0 additions & 165 deletions py/torch_tensorrt/dynamo/conversion/impl/attention.py

This file was deleted.

128 changes: 127 additions & 1 deletion py/torch_tensorrt/dynamo/lowering/_decompositions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from enum import Enum, auto
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from torch._decomp import register_decomposition
Expand Down Expand Up @@ -423,6 +423,132 @@ def instance_norm_decomposition(
)


@register_torch_trt_decomposition(
aten.scaled_dot_product_attention, registry=TORCH_TRT_DECOMPOSITIONS
)
def scaled_dot_product_attention_decomposition(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
*,
scale: Optional[float] = None,
enable_gqa: bool = False,
) -> torch.Tensor:
L, S = query.size(-2), key.size(-2)
device = query.device
attn_bias = torch.zeros(L, S, dtype=query.dtype, device=device)

if is_causal:
assert attn_mask is None, "attn_mask must be None when is_causal=True"
temp_mask = torch.ones(L, S, dtype=torch.bool, device=device).tril(diagonal=0)
attn_bias = attn_bias.masked_fill(temp_mask.logical_not(), float("-inf"))

if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_bias = attn_bias.masked_fill(attn_mask.logical_not(), float("-inf"))
else:
attn_bias = attn_mask + attn_bias

if enable_gqa:
key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)

attn_weight = query @ key.transpose(-2, -1)

if scale is None:
scale = torch.sqrt(torch.scalar_tensor(query.size(-1), dtype=torch.int))
attn_weight = attn_weight / scale
else:
attn_weight = attn_weight * scale

attn_weight = attn_weight + attn_bias
attn_weight = torch.softmax(attn_weight, dim=-1)
return attn_weight @ value


@register_torch_trt_decomposition(
aten._scaled_dot_product_flash_attention, registry=TORCH_TRT_DECOMPOSITIONS
)
def scaled_dot_product_flash_attention_decomposition(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
dropout_p: float = 0.0,
is_causal: bool = False,
return_debug_mask: bool = False,
*,
scale: Optional[float] = None,
) -> Tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.SymInt,
torch.SymInt,
torch.Tensor,
torch.Tensor,
torch.Tensor,
]:
attn = scaled_dot_product_attention_decomposition(
query, key, value, None, dropout_p, is_causal, scale=scale
)
return attn, None, None, None, 0, 0, None, None, None


@register_torch_trt_decomposition(
aten._scaled_dot_product_efficient_attention, registry=TORCH_TRT_DECOMPOSITIONS
)
def scaled_dot_product_efficient_attention_decomposition(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_bias: Optional[torch.Tensor],
compute_log_sumexp: bool,
dropout_p: float = 0.0,
is_causal: bool = False,
*,
scale: Optional[float] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
attn = scaled_dot_product_attention_decomposition(
query, key, value, attn_bias, dropout_p, is_causal, scale=scale
)
return attn, None, None, None


@register_torch_trt_decomposition(
aten._scaled_dot_product_cudnn_attention, registry=TORCH_TRT_DECOMPOSITIONS
)
def scaled_dot_product_cudnn_attention_decomposition(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_bias: Optional[torch.Tensor],
compute_log_sumexp: bool,
dropout_p: float = 0.0,
is_causal: bool = False,
return_debug_mask: bool = False,
*,
scale: Optional[float] = None,
) -> Tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.SymInt,
torch.SymInt,
torch.Tensor,
torch.Tensor,
torch.Tensor,
]:
attn = scaled_dot_product_attention_decomposition(
query, key, value, attn_bias, dropout_p, is_causal, scale=scale
)
return attn, None, None, None, 0, 0, None, None, None


def get_decompositions(
enable_experimental_decompositions: bool = False,
) -> Dict[OpOverload, Callable[[Any], Any]]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from .constant_folding import constant_fold
from .fuse_prims_broadcast import fuse_prims_broadcast
from .lower_linear import lower_linear
from .lower_scaled_dot_product_attention import lower_scaled_dot_product_attention
from .pass_manager import DynamoPassManager
from .remove_assert_scalar import remove_assert_scalar
from .remove_detach import remove_detach
Expand All @@ -23,7 +22,6 @@
remove_input_alias_fixing_clones,
constant_fold,
repair_input_as_output,
lower_scaled_dot_product_attention,
lower_linear,
fuse_prims_broadcast,
replace_max_pool_with_indices,
Expand Down
Loading

0 comments on commit 7d0d06d

Please sign in to comment.