|
| 1 | +import math |
| 2 | +from typing import Optional, Union |
| 3 | + |
| 4 | +import numpy as np |
| 5 | +import tensorrt as trt |
| 6 | +from torch.fx.node import Target |
| 7 | +from torch_tensorrt._enums import dtype |
| 8 | +from torch_tensorrt.dynamo.conversion import impl |
| 9 | +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext |
| 10 | +from torch_tensorrt.dynamo.conversion.converter_utils import ( |
| 11 | + SourceIR, |
| 12 | + cast_trt_tensor, |
| 13 | + get_trt_tensor, |
| 14 | +) |
| 15 | +from torch_tensorrt.fx.types import TRTTensor |
| 16 | + |
| 17 | + |
| 18 | +def tril( |
| 19 | + ctx: ConversionContext, |
| 20 | + target: Union[Target, str], |
| 21 | + source_ir: Optional[SourceIR], |
| 22 | + name: str, |
| 23 | + input: TRTTensor, |
| 24 | +) -> TRTTensor: |
| 25 | + # the lower triangle of the tensor means the rows greater than and equal to the cols |
| 26 | + row = impl.shape.shape(ctx, target, source_ir, name + "_shape_0", input, 0) |
| 27 | + col = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", input, 1) |
| 28 | + rc = impl.elementwise.mul(ctx, target, source_ir, name + "_mul", row, col) |
| 29 | + arange_tensor = impl.arange.arange( |
| 30 | + ctx, target, source_ir, name + "_arange", start=0, end=rc, step=1 |
| 31 | + ) |
| 32 | + # get the rows |
| 33 | + row_tensor = impl.elementwise.trunc_div( |
| 34 | + ctx, target, source_ir, name + "_trunc_div_col", arange_tensor, col |
| 35 | + ) |
| 36 | + # get the cols |
| 37 | + col_tensor = impl.elementwise.fmod( |
| 38 | + ctx, target, source_ir, name + "_trunc_div_row", arange_tensor, col |
| 39 | + ) |
| 40 | + cond = impl.elementwise.ge( |
| 41 | + ctx, target, source_ir, name + "_ge", row_tensor, col_tensor |
| 42 | + ) |
| 43 | + return impl.shuffle.reshape( |
| 44 | + ctx, target, source_ir, name + "_reshape", cond, [row, col] |
| 45 | + ) |
| 46 | + |
| 47 | + |
| 48 | +def scaled_dot_product_attention( |
| 49 | + ctx: ConversionContext, |
| 50 | + target: Union[Target, str], |
| 51 | + source_ir: Optional[SourceIR], |
| 52 | + name: str, |
| 53 | + query: TRTTensor, |
| 54 | + key: TRTTensor, |
| 55 | + value: TRTTensor, |
| 56 | + is_causal: bool, |
| 57 | + scale: Optional[float], |
| 58 | +) -> TRTTensor: |
| 59 | + # implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html |
| 60 | + mm = impl.matmul.matrix_multiply( |
| 61 | + ctx, |
| 62 | + target, |
| 63 | + source_ir, |
| 64 | + name + "_mm", |
| 65 | + query, |
| 66 | + key, |
| 67 | + other_matrix_op=trt.MatrixOperation.TRANSPOSE, |
| 68 | + ) |
| 69 | + if scale is None: |
| 70 | + scale = query.shape[-1] |
| 71 | + if scale < 0: |
| 72 | + # dynamic shape |
| 73 | + scale = impl.shape.shape(ctx, target, source_ir, name + "_shape", query, -1) |
| 74 | + sqrt_scaled = impl.unary.sqrt(ctx, target, source_ir, name + "_sqrt", scale) |
| 75 | + else: |
| 76 | + # static shape |
| 77 | + sqrt_scaled = math.sqrt(scale) |
| 78 | + scaled = impl.elementwise.div( |
| 79 | + ctx, |
| 80 | + target, |
| 81 | + source_ir, |
| 82 | + name + "_scale", |
| 83 | + mm, |
| 84 | + sqrt_scaled, |
| 85 | + ) |
| 86 | + else: |
| 87 | + scaled = impl.elementwise.mul( |
| 88 | + ctx, |
| 89 | + target, |
| 90 | + source_ir, |
| 91 | + name + "_scale", |
| 92 | + mm, |
| 93 | + scale, |
| 94 | + ) |
| 95 | + |
| 96 | + if is_causal: |
| 97 | + L, S = query.shape[-2], key.shape[-2] |
| 98 | + if L >= 0 and S >= 0: |
| 99 | + # static shape |
| 100 | + attn_bias = np.zeros((L, S), dtype=dtype._from(query.dtype).to(np.dtype)) |
| 101 | + temp_mask = np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0)) |
| 102 | + attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf")) |
| 103 | + attn_bias = get_trt_tensor(ctx, attn_bias, name + "_attn_bias") |
| 104 | + else: |
| 105 | + # if any of the L or S is dynamic shape |
| 106 | + if L < 0: |
| 107 | + L = impl.shape.shape( |
| 108 | + ctx, target, source_ir, name + "_shape_0", query, -2 |
| 109 | + ) |
| 110 | + if S < 0: |
| 111 | + S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, -2) |
| 112 | + |
| 113 | + LS = impl.elementwise.mul(ctx, target, source_ir, name + "_mul", L, S) |
| 114 | + |
| 115 | + # this is to generate a tensor which has shape (L, S), type is int32 |
| 116 | + arange_tensor = impl.arange.arange( |
| 117 | + ctx, target, source_ir, name=name + "_arange", start=0, end=LS, step=1 |
| 118 | + ) |
| 119 | + shape_tensor = impl.shuffle.reshape( |
| 120 | + ctx, target, source_ir, name + "_reshape", arange_tensor, [L, S] |
| 121 | + ) |
| 122 | + |
| 123 | + # since we want our attn_bias to be in float32, so cast it to float32 |
| 124 | + shape_tensor = cast_trt_tensor( |
| 125 | + ctx, shape_tensor, trt.float32, name + "_casted", target, source_ir |
| 126 | + ) |
| 127 | + |
| 128 | + # initialize the attn_bias as the zeros tensor |
| 129 | + attn_bias = impl.elementwise.mul( |
| 130 | + ctx, target, source_ir, name + "_mul_zero", shape_tensor, 0.0 |
| 131 | + ) |
| 132 | + |
| 133 | + # generate the mask tensor |
| 134 | + tril_tensor = tril(ctx, target, source_ir, name + "_tril", shape_tensor) |
| 135 | + temp_mask = impl.unary.logical_not( |
| 136 | + ctx, target, source_ir, name + "_logical_not", tril_tensor |
| 137 | + ) |
| 138 | + inf_tensor = impl.elementwise.mul( |
| 139 | + ctx, target, source_ir, name + "_mul_-inf", shape_tensor, float("-inf") |
| 140 | + ) |
| 141 | + cond = impl.elementwise.eq( |
| 142 | + ctx, target, source_ir, name + "_cond_true", temp_mask, bool(True) |
| 143 | + ) |
| 144 | + # mask out the certain part of the attn_bias |
| 145 | + attn_bias = impl.condition.select( |
| 146 | + ctx, target, source_ir, name + "_select", inf_tensor, attn_bias, cond |
| 147 | + ) |
| 148 | + |
| 149 | + scaled = impl.elementwise.add( |
| 150 | + ctx, target, source_ir, name + "_attn_bias_add", scaled, attn_bias |
| 151 | + ) |
| 152 | + |
| 153 | + softmax = impl.normalization.softmax( |
| 154 | + ctx, target, source_ir, name + "_softmax", scaled, -1, False |
| 155 | + ) |
| 156 | + out = impl.matmul.matrix_multiply( |
| 157 | + ctx, |
| 158 | + target, |
| 159 | + source_ir, |
| 160 | + name + "_out", |
| 161 | + softmax, |
| 162 | + value, |
| 163 | + ) |
| 164 | + |
| 165 | + return out |
0 commit comments