Commit b06e632

Optimized lowering and decomposition to benchmark quantization again
1 parent 044acdf commit b06e632

File tree

9 files changed: +527 -152 lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py (+32 lines)
@@ -619,6 +619,38 @@ def aten_ops_quantize_op(
    )


def attention_validator(
    node: Node, settings: Optional[CompilationSettings] = None
) -> bool:
    # Currently, `attn_mask` is not supported
    return args_bounds_check(node.args, 3) is None


@dynamo_tensorrt_converter(
    torch.nn.functional.scaled_dot_product_attention,
    capability_validator=attention_validator,
    supports_dynamic_shapes=True,
)
def tensorrt_scaled_dot_product_attention(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.attention.scaled_dot_product_attention(
        ctx,
        target,
        SourceIR.TORCHTRT_LOWERED,
        name,
        args[0],
        args[1],
        args[2],
        args_bounds_check(args, 5, False),
        kwargs.get("scale", None),
    )


@dynamo_tensorrt_converter(torch.ops.aten.squeeze.dim, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.squeeze.dims, supports_dynamic_shapes=True)
def aten_ops_squeeze(
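For context (not part of the commit), a hedged sketch of how this converter would be reached, assuming the dynamo frontend via torch_tensorrt.compile; the module, shapes, and settings below are illustrative:

import torch
import torch.nn.functional as F
import torch_tensorrt

class SDPA(torch.nn.Module):
    def forward(self, q, k, v):
        # no attn_mask is passed, so attention_validator accepts the node;
        # is_causal=True exercises the causal-bias path of the lowering
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)

inputs = [torch.randn(1, 8, 128, 64, device="cuda") for _ in range(3)]
trt_mod = torch_tensorrt.compile(SDPA().eval().cuda(), ir="dynamo", inputs=inputs)
out = trt_mod(*inputs)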

py/torch_tensorrt/dynamo/conversion/impl/__init__.py (+1 line)
@@ -2,6 +2,7 @@
    activation,
    addmm,
    arange,
    attention,
    cast,
    cat,
    condition,
py/torch_tensorrt/dynamo/conversion/impl/attention.py (new file, +165 lines)
@@ -0,0 +1,165 @@
import math
from typing import Optional, Union

import numpy as np
import tensorrt as trt
from torch.fx.node import Target
from torch_tensorrt._enums import dtype
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import (
    SourceIR,
    cast_trt_tensor,
    get_trt_tensor,
)
from torch_tensorrt.fx.types import TRTTensor


def tril(
    ctx: ConversionContext,
    target: Union[Target, str],
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
) -> TRTTensor:
    # an element belongs to the lower triangle when its row index is >= its column index
    row = impl.shape.shape(ctx, target, source_ir, name + "_shape_0", input, 0)
    col = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", input, 1)
    rc = impl.elementwise.mul(ctx, target, source_ir, name + "_mul", row, col)
    arange_tensor = impl.arange.arange(
        ctx, target, source_ir, name + "_arange", start=0, end=rc, step=1
    )
    # recover the row indices from the flat arange
    row_tensor = impl.elementwise.trunc_div(
        ctx, target, source_ir, name + "_trunc_div_col", arange_tensor, col
    )
    # recover the column indices from the flat arange
    col_tensor = impl.elementwise.fmod(
        ctx, target, source_ir, name + "_trunc_div_row", arange_tensor, col
    )
    cond = impl.elementwise.ge(
        ctx, target, source_ir, name + "_ge", row_tensor, col_tensor
    )
    return impl.shuffle.reshape(
        ctx, target, source_ir, name + "_reshape", cond, [row, col]
    )

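For intuition (not part of the commit), a small NumPy sketch of the index arithmetic tril() relies on: flatten an arange of length row * col, recover row/column indices via truncated division and modulo, and keep positions where row >= col.

import numpy as np

row, col = 4, 5
idx = np.arange(row * col)    # 0 .. row*col - 1
row_idx = idx // col          # counterpart of impl.elementwise.trunc_div
col_idx = idx % col           # counterpart of impl.elementwise.fmod
mask = (row_idx >= col_idx).reshape(row, col)
assert (mask == np.tril(np.ones((row, col), dtype=bool))).all()

In the converter the same comparison runs on TRT tensors, which is why the mask is built from elementwise ops rather than a precomputed constant when the shape is dynamic.
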
def scaled_dot_product_attention(
    ctx: ConversionContext,
    target: Union[Target, str],
    source_ir: Optional[SourceIR],
    name: str,
    query: TRTTensor,
    key: TRTTensor,
    value: TRTTensor,
    is_causal: bool,
    scale: Optional[float],
) -> TRTTensor:
    # implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
    mm = impl.matmul.matrix_multiply(
        ctx,
        target,
        source_ir,
        name + "_mm",
        query,
        key,
        other_matrix_op=trt.MatrixOperation.TRANSPOSE,
    )
    if scale is None:
        scale = query.shape[-1]
        if scale < 0:
            # dynamic shape
            scale = impl.shape.shape(ctx, target, source_ir, name + "_shape", query, -1)
            sqrt_scaled = impl.unary.sqrt(ctx, target, source_ir, name + "_sqrt", scale)
        else:
            # static shape
            sqrt_scaled = math.sqrt(scale)
        scaled = impl.elementwise.div(
            ctx,
            target,
            source_ir,
            name + "_scale",
            mm,
            sqrt_scaled,
        )
    else:
        scaled = impl.elementwise.mul(
            ctx,
            target,
            source_ir,
            name + "_scale",
            mm,
            scale,
        )

    if is_causal:
        L, S = query.shape[-2], key.shape[-2]
        if L >= 0 and S >= 0:
            # static shape
            attn_bias = np.zeros((L, S), dtype=dtype._from(query.dtype).to(np.dtype))
            temp_mask = np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0))
            attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf"))
            attn_bias = get_trt_tensor(ctx, attn_bias, name + "_attn_bias")
        else:
            # L and/or S is a dynamic shape
            if L < 0:
                L = impl.shape.shape(
                    ctx, target, source_ir, name + "_shape_0", query, -2
                )
            if S < 0:
                S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, -2)

            LS = impl.elementwise.mul(ctx, target, source_ir, name + "_mul", L, S)

            # generate an int32 tensor with shape (L, S)
            arange_tensor = impl.arange.arange(
                ctx, target, source_ir, name=name + "_arange", start=0, end=LS, step=1
            )
            shape_tensor = impl.shuffle.reshape(
                ctx, target, source_ir, name + "_reshape", arange_tensor, [L, S]
            )

            # cast to float32, since attn_bias should be float32
            shape_tensor = cast_trt_tensor(
                ctx, shape_tensor, trt.float32, name + "_casted", target, source_ir
            )

            # initialize attn_bias as a zeros tensor
            attn_bias = impl.elementwise.mul(
                ctx, target, source_ir, name + "_mul_zero", shape_tensor, 0.0
            )

            # generate the causal mask tensor
            tril_tensor = tril(ctx, target, source_ir, name + "_tril", shape_tensor)
            temp_mask = impl.unary.logical_not(
                ctx, target, source_ir, name + "_logical_not", tril_tensor
            )
            inf_tensor = impl.elementwise.mul(
                ctx, target, source_ir, name + "_mul_-inf", shape_tensor, float("-inf")
            )
            cond = impl.elementwise.eq(
                ctx, target, source_ir, name + "_cond_true", temp_mask, bool(True)
            )
            # mask out the upper-triangular part of attn_bias with -inf
            attn_bias = impl.condition.select(
                ctx, target, source_ir, name + "_select", inf_tensor, attn_bias, cond
            )

        scaled = impl.elementwise.add(
            ctx, target, source_ir, name + "_attn_bias_add", scaled, attn_bias
        )

    softmax = impl.normalization.softmax(
        ctx, target, source_ir, name + "_softmax", scaled, -1, False
    )
    out = impl.matmul.matrix_multiply(
        ctx,
        target,
        source_ir,
        name + "_out",
        softmax,
        value,
    )

    return out
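
As a cross-check (not part of the commit), a hedged eager-mode sketch of the math this lowering reproduces, following the reference pseudo-code in the PyTorch documentation linked above; sdpa_reference and its argument names are illustrative:

import math
import torch

def sdpa_reference(query, key, value, is_causal=False, scale=None):
    # scores = Q @ K^T, scaled by 1/sqrt(E) unless an explicit scale is given
    scale_factor = 1.0 / math.sqrt(query.size(-1)) if scale is None else scale
    attn = query @ key.transpose(-2, -1) * scale_factor
    if is_causal:
        L, S = query.size(-2), key.size(-2)
        # zero bias on/below the diagonal, -inf above it, as in the causal branch above
        causal = torch.ones(L, S, dtype=torch.bool, device=query.device).tril()
        bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
        bias = bias.masked_fill(causal.logical_not(), float("-inf"))
        attn = attn + bias
    return torch.softmax(attn, dim=-1) @ value

For float32 inputs this should match torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True) up to floating-point tolerance.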
