Fix LayerNorm fp16 precision #3272

Merged
6 commits, merged Dec 3, 2024
Changes from 3 commits
16 changes: 5 additions & 11 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -134,34 +134,28 @@ def aten_ops_batch_norm_legit_no_training(
capability_validator=one_user_validator,
supports_dynamic_shapes=True,
)
@dynamo_tensorrt_converter(
torch.ops.aten.layer_norm.default, supports_dynamic_shapes=True
)
@dynamo_tensorrt_converter(torch.ops.aten.layer_norm, supports_dynamic_shapes=True)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_layer_norm(
def aten_ops_native_layer_norm(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.normalization.layer_norm(
return impl.normalization.native_layer_norm(
ctx,
target,
SourceIR.ATEN,
name,
input=args[0],
normalized_shape=args[1],
weight=args_bounds_check(args, 2, 1.0),
bias=args_bounds_check(args, 3, 0.0),
eps=args_bounds_check(args, 4, 1e-05),
cudnn_enable=args_bounds_check(args, 5, True),
return_mean_rstd=(target == torch.ops.aten.native_layer_norm.default),
weight=args_bounds_check(args, 2),
bias=args_bounds_check(args, 3),
eps=args[4],
)


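For context — and not part of the diff — here is a small plain-PyTorch sketch of the aten.native_layer_norm calling convention the updated converter assumes: the op always takes five positional arguments (input, normalized_shape, weight, bias, eps), weight and bias are optional tensors that may be None (hence args_bounds_check without a scalar default), and eps is always present (hence args[4]).

import torch

x = torch.randn(2, 4, 6)
normalized_shape = [6]

# With affine parameters:
out, mean, rstd = torch.ops.aten.native_layer_norm.default(
    x, normalized_shape, torch.randn(6), torch.randn(6), 1e-5
)

# Without affine parameters (weight/bias are None), as the new tests exercise:
out_no_affine, _, _ = torch.ops.aten.native_layer_norm.default(
    x, normalized_shape, None, None, 1e-5
)
print(out.shape, out_no_affine.shape)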
41 changes: 17 additions & 24 deletions py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
@@ -152,56 +152,49 @@ def batch_norm(
return output


def layer_norm(
def native_layer_norm(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
normalized_shape: List[int],
weight: Optional[Union[torch.Tensor, np.ndarray]],
bias: Optional[Union[torch.Tensor, np.ndarray]],
weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
eps: float,
cudnn_enable: bool,
return_mean_rstd: bool,
) -> Union[TRTTensor, Tuple[TRTTensor, torch.Tensor, torch.Tensor]]:
) -> Tuple[TRTTensor, torch.Tensor, torch.Tensor]:
dims = list(range(len(input.shape) - len(normalized_shape), len(input.shape)))
axes = get_axes_for_reduce_op(dims)
weight = get_trt_tensor(ctx, weight, f"{name}_weight")
bias = get_trt_tensor(ctx, bias, f"{name}_bias")

weight = get_trt_tensor(
ctx, weight if weight is not None else 1.0, f"{name}_weight"
)
bias = get_trt_tensor(ctx, bias if bias is not None else 0.0, f"{name}_bias")

# Cast weight and bias to have same dtype as input
weight = cast_trt_tensor(
ctx, weight, input.dtype, f"{name}_weight_cast", target, source_ir
)
bias = cast_trt_tensor(
ctx, bias, input.dtype, f"{name}_bias_cast", target, source_ir
)

if tuple(input.shape) != tuple(weight.shape):
weight = impl.slice.expand(
ctx, target, source_ir, f"{name}_expand_weight", weight, input.shape
)

if tuple(input.shape) != tuple(bias.shape):
bias = impl.slice.expand(
ctx, target, source_ir, f"{name}_expand_bias", bias, input.shape
)
strongly_typed_network = False
if ctx.net.get_flag(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED):
weight = cast_trt_tensor(ctx, weight, input.dtype, name)
bias = cast_trt_tensor(ctx, bias, input.dtype, name)
strongly_typed_network = True

layer_norm = ctx.net.add_normalization(input, weight, bias, axes)
layer_norm.epsilon = eps
# compute_precision ignored for strongly typed network.
if not strongly_typed_network:
layer_norm.compute_precision = input.dtype
set_layer_name(layer_norm, target, f"{name}_layer_norm", source_ir)

if return_mean_rstd:
# return fake mean and rstd for now
return layer_norm.get_output(0), None, None
layer = ctx.net.add_normalization(input, weight, bias, axes)
layer.epsilon = eps
set_layer_name(layer, target, name, source_ir)

Collaborator comment: LGTM overall. It looks like you do not need to explicitly set ILayer.precision or ILayer.set_output_type to set the output type of this layer with fp16 inputs.
return layer_norm.get_output(0)
# return fake mean and rstd for now
return layer.get_output(0), None, None


def native_group_norm(
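The review thread above concerns TensorRT's INormalizationLayer. As a standalone illustration (a sketch assuming the TensorRT 10 Python API, independent of the converter code and this PR), the layer exposes epsilon and compute_precision, and precision hints are ignored when the network is built as strongly typed, where dtypes simply follow the inputs:

import numpy as np
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)

# Strongly typed network: layer precision hints are ignored and types follow inputs.
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
)

x = network.add_input("x", trt.float16, (2, 3, 4))
scale = network.add_constant(
    (1, 1, 4), np.ones((1, 1, 4), dtype=np.float16)
).get_output(0)
bias = network.add_constant(
    (1, 1, 4), np.zeros((1, 1, 4), dtype=np.float16)
).get_output(0)

# Normalize over the last axis of a rank-3 tensor: axes bitmask with bit 2 set.
norm = network.add_normalization(x, scale, bias, 1 << 2)
norm.epsilon = 1e-5
# In a non-strongly-typed network one could additionally request
#     norm.compute_precision = trt.float16
# but with STRONGLY_TYPED that hint is ignored, matching the guard in the diff.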
105 changes: 29 additions & 76 deletions tests/py/dynamo/conversion/test_layer_norm_aten.py
@@ -6,78 +6,51 @@
from .harness import DispatchTestCase


class TestLayerNormConverter(DispatchTestCase):
class TestNativeLayerNormConverter(DispatchTestCase):
@parameterized.expand(
[
(
(5, 3, 2, 4),
[
4,
],
),
((5, 3, 2, 4), [2, 4]),
((5, 3, 2, 4), [3, 2, 4]),
((5, 3, 2, 4), [5, 3, 2, 4]),
((2, 4, 6), [6]),
((2, 4, 6), [4, 6]),
((2, 4, 6), [2, 4, 6]),
]
)
def test_layer_norm(self, input_shape, normalized_shape, eps=1e-05):
def test_layer_norm_1d(self, input_shape, normalized_shape):
class LayerNorm(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.layer_norm.default(
x,
normalized_shape,
torch.randn(normalized_shape),
torch.randn(normalized_shape),
eps,
)
return torch.ops.aten.native_layer_norm.default(
x, normalized_shape, None, None, 1e-05
)[0]

inputs = [torch.randn(input_shape)]
self.run_test(
LayerNorm(),
inputs,
)

self.run_test(LayerNorm(), inputs, use_dynamo_tracer=True)

class TestNativeLayerNormConverter(DispatchTestCase):
@parameterized.expand(
[
(
(5, 3, 2, 4),
[
4,
],
),
((5, 3, 2, 4), [4]),
((5, 3, 2, 4), [2, 4]),
((5, 3, 2, 4), [3, 2, 4]),
((5, 3, 2, 4), [5, 3, 2, 4]),
]
)
def test_layer_norm(self, input_shape, normalized_shape, eps=1e-05):
def test_layer_norm_2d(self, input_shape, normalized_shape):
class LayerNorm(torch.nn.Module):
def forward(self, x):
def forward(self, x, weight, bias):
return torch.ops.aten.native_layer_norm.default(
x,
normalized_shape,
torch.randn(normalized_shape),
torch.randn(normalized_shape),
eps,
x, normalized_shape, weight, bias, 1e-05
)[0]

inputs = [torch.randn(input_shape)]
self.run_test(
LayerNorm(),
inputs,
)
inputs = [
torch.randn(input_shape),
torch.randn(normalized_shape),
torch.randn(normalized_shape),
]
self.run_test(LayerNorm(), inputs, use_dynamo_tracer=True)

def test_layernorm_with_dynamic_shape(self):
class LayerNorm(torch.nn.Module):
def forward(self, x):
def forward(self, x, weight, bias):
return torch.ops.aten.native_layer_norm.default(
x,
torch.tensor([3, 224, 224]),
torch.ones((3, 224, 224)),
torch.zeros((3, 224, 224)),
1e-05,
x, [3, 224, 224], weight, bias, 1e-05
)[0]

input_specs = [
@@ -87,22 +60,19 @@ def forward(self, x):
opt_shape=(5, 3, 224, 224),
max_shape=(10, 3, 224, 224),
),
Input(dtype=torch.float32, shape=(3, 224, 224)),
Input(dtype=torch.float32, shape=(3, 224, 224)),
]

self.run_test_with_dynamic_shape(
LayerNorm(),
input_specs,
LayerNorm(), input_specs, use_dynamo_tracer=True
)

def test_layernorm_with_dynamic_shape_1(self):
class LayerNorm(torch.nn.Module):
def forward(self, x):
def forward(self, x, weight, bias):
return torch.ops.aten.native_layer_norm.default(
x,
torch.tensor([3]),
torch.ones((3)),
torch.zeros((3)),
1e-05,
x, [3], weight, bias, 1e-05
)[0]

input_specs = [
@@ -112,29 +82,12 @@ def forward(self, x):
opt_shape=(3, 3, 3),
max_shape=(4, 5, 3),
),
Input(dtype=torch.float32, shape=(3,)),
Input(dtype=torch.float32, shape=(3,)),
]

self.run_test_with_dynamic_shape(
LayerNorm(),
input_specs,
)

@parameterized.expand([((5, 3, 2, 4), [2, 4])])
def test_layer_norm_without_Scaling(self, input_shape, normalized_shape, eps=1e-05):
class LayerNorm(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.native_layer_norm.default(
x,
normalized_shape,
None,
None,
eps,
)[0]

inputs = [torch.randn(input_shape)]
self.run_test(
LayerNorm(),
inputs,
LayerNorm(), input_specs, use_dynamo_tracer=True
)


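Beyond the unit tests, a quick end-to-end sanity check of the fp16 path might look like the following (an illustrative script, not part of this PR: it assumes a CUDA machine with torch_tensorrt installed, and the model, shapes, and tolerances are arbitrary):

import torch
import torch_tensorrt

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = torch.nn.LayerNorm(64)

    def forward(self, x):
        return self.norm(x)

model = Net().half().eval().cuda()
x = torch.randn(8, 128, 64, dtype=torch.half, device="cuda")

# Compile through the dynamo frontend with fp16 enabled; this exercises the
# LayerNorm conversion path touched by this PR.
trt_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=[x],
    enabled_precisions={torch.half},
    min_block_size=1,
)

out = trt_model(x)
print(out.dtype)  # expected to stay torch.float16
torch.testing.assert_close(out, model(x), rtol=1e-2, atol=1e-2)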