From 6805e390315ea761feea486ecf503976d0f00b10 Mon Sep 17 00:00:00 2001
From: HolyWu
Date: Tue, 5 Nov 2024 23:13:17 +0800
Subject: [PATCH 1/2] Fix LayerNorm fp16 precision

---
 .../dynamo/conversion/aten_ops_converters.py  |  12 +-
 .../conversion/impl/normalization/ops.py      |  39 +++----
 .../dynamo/conversion/test_layer_norm_aten.py | 105 +++++-------------
 3 files changed, 48 insertions(+), 108 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 07c8c03697..aff5a1dfb3 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -134,10 +134,6 @@ def aten_ops_batch_norm_legit_no_training(
     capability_validator=one_user_validator,
     supports_dynamic_shapes=True,
 )
-@dynamo_tensorrt_converter(
-    torch.ops.aten.layer_norm.default, supports_dynamic_shapes=True
-)
-@dynamo_tensorrt_converter(torch.ops.aten.layer_norm, supports_dynamic_shapes=True)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
@@ -157,11 +153,9 @@ def aten_ops_layer_norm(
         name,
         input=args[0],
         normalized_shape=args[1],
-        weight=args_bounds_check(args, 2, 1.0),
-        bias=args_bounds_check(args, 3, 0.0),
-        eps=args_bounds_check(args, 4, 1e-05),
-        cudnn_enable=args_bounds_check(args, 5, True),
-        return_mean_rstd=(target == torch.ops.aten.native_layer_norm.default),
+        weight=args_bounds_check(args, 2),
+        bias=args_bounds_check(args, 3),
+        eps=args[4],
     )

diff --git a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
index 4f39a6d5d9..9d69daa1e8 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
@@ -159,16 +159,18 @@ def layer_norm(
     name: str,
     input: TRTTensor,
     normalized_shape: List[int],
-    weight: Optional[Union[torch.Tensor, np.ndarray]],
-    bias: Optional[Union[torch.Tensor, np.ndarray]],
+    weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
+    bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
     eps: float,
-    cudnn_enable: bool,
-    return_mean_rstd: bool,
-) -> Union[TRTTensor, Tuple[TRTTensor, torch.Tensor, torch.Tensor]]:
+) -> Tuple[TRTTensor, torch.Tensor, torch.Tensor]:
     dims = list(range(len(input.shape) - len(normalized_shape), len(input.shape)))
     axes = get_axes_for_reduce_op(dims)
-    weight = get_trt_tensor(ctx, weight, f"{name}_weight")
-    bias = get_trt_tensor(ctx, bias, f"{name}_bias")
+
+    weight = get_trt_tensor(
+        ctx, weight if weight is not None else 1.0, f"{name}_weight"
+    )
+    bias = get_trt_tensor(ctx, bias if bias is not None else 0.0, f"{name}_bias")
+
     # Cast weight and bias to have same dtype as input
     weight = cast_trt_tensor(
         ctx, weight, input.dtype, f"{name}_weight_cast", target, source_ir
@@ -176,32 +178,23 @@
     bias = cast_trt_tensor(
         ctx, bias, input.dtype, f"{name}_bias_cast", target, source_ir
     )
+
     if tuple(input.shape) != tuple(weight.shape):
         weight = impl.slice.expand(
             ctx, target, source_ir, f"{name}_expand_weight", weight, input.shape
         )
+
     if tuple(input.shape) != tuple(bias.shape):
         bias = impl.slice.expand(
             ctx, target, source_ir, f"{name}_expand_bias", bias, input.shape
         )

-    strongly_typed_network = False
-    if ctx.net.get_flag(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED):
-        weight = cast_trt_tensor(ctx, weight, input.dtype, name)
-        bias = cast_trt_tensor(ctx, bias, input.dtype, name)
-        strongly_typed_network = True
-
-    layer_norm = ctx.net.add_normalization(input, weight, bias, axes)
-    layer_norm.epsilon = eps
-    # compute_precision ignored for strongly typed network.
-    if not strongly_typed_network:
-        layer_norm.compute_precision = input.dtype
-    set_layer_name(layer_norm, target, f"{name}_layer_norm", source_ir)
-    if return_mean_rstd:
-        # return fake mean and rstd for now
-        return layer_norm.get_output(0), None, None
+    layer = ctx.net.add_normalization(input, weight, bias, axes)
+    layer.epsilon = eps
+    set_layer_name(layer, target, name, source_ir)

-    return layer_norm.get_output(0)
+    # return fake mean and rstd for now
+    return layer.get_output(0), None, None


 def native_group_norm(
diff --git a/tests/py/dynamo/conversion/test_layer_norm_aten.py b/tests/py/dynamo/conversion/test_layer_norm_aten.py
index c6cfc430ba..90b568f775 100644
--- a/tests/py/dynamo/conversion/test_layer_norm_aten.py
+++ b/tests/py/dynamo/conversion/test_layer_norm_aten.py
@@ -6,78 +6,51 @@
 from .harness import DispatchTestCase


-class TestLayerNormConverter(DispatchTestCase):
+class TestNativeLayerNormConverter(DispatchTestCase):
     @parameterized.expand(
         [
-            (
-                (5, 3, 2, 4),
-                [
-                    4,
-                ],
-            ),
-            ((5, 3, 2, 4), [2, 4]),
-            ((5, 3, 2, 4), [3, 2, 4]),
-            ((5, 3, 2, 4), [5, 3, 2, 4]),
+            ((2, 4, 6), [6]),
+            ((2, 4, 6), [4, 6]),
+            ((2, 4, 6), [2, 4, 6]),
         ]
     )
-    def test_layer_norm(self, input_shape, normalized_shape, eps=1e-05):
+    def test_layer_norm_1d(self, input_shape, normalized_shape):
         class LayerNorm(torch.nn.Module):
             def forward(self, x):
-                return torch.ops.aten.layer_norm.default(
-                    x,
-                    normalized_shape,
-                    torch.randn(normalized_shape),
-                    torch.randn(normalized_shape),
-                    eps,
-                )
+                return torch.ops.aten.native_layer_norm.default(
+                    x, normalized_shape, None, None, 1e-05
+                )[0]

         inputs = [torch.randn(input_shape)]
-        self.run_test(
-            LayerNorm(),
-            inputs,
-        )
-
+        self.run_test(LayerNorm(), inputs, use_dynamo_tracer=True)

-class TestNativeLayerNormConverter(DispatchTestCase):
     @parameterized.expand(
         [
-            (
-                (5, 3, 2, 4),
-                [
-                    4,
-                ],
-            ),
+            ((5, 3, 2, 4), [4]),
             ((5, 3, 2, 4), [2, 4]),
             ((5, 3, 2, 4), [3, 2, 4]),
             ((5, 3, 2, 4), [5, 3, 2, 4]),
         ]
     )
-    def test_layer_norm(self, input_shape, normalized_shape, eps=1e-05):
+    def test_layer_norm_2d(self, input_shape, normalized_shape):
         class LayerNorm(torch.nn.Module):
-            def forward(self, x):
+            def forward(self, x, weight, bias):
                 return torch.ops.aten.native_layer_norm.default(
-                    x,
-                    normalized_shape,
-                    torch.randn(normalized_shape),
-                    torch.randn(normalized_shape),
-                    eps,
+                    x, normalized_shape, weight, bias, 1e-05
                 )[0]

-        inputs = [torch.randn(input_shape)]
-        self.run_test(
-            LayerNorm(),
-            inputs,
-        )
+        inputs = [
+            torch.randn(input_shape),
+            torch.randn(normalized_shape),
+            torch.randn(normalized_shape),
+        ]
+        self.run_test(LayerNorm(), inputs, use_dynamo_tracer=True)

     def test_layernorm_with_dynamic_shape(self):
         class LayerNorm(torch.nn.Module):
-            def forward(self, x):
+            def forward(self, x, weight, bias):
                 return torch.ops.aten.native_layer_norm.default(
-                    x,
-                    torch.tensor([3, 224, 224]),
-                    torch.ones((3, 224, 224)),
-                    torch.zeros((3, 224, 224)),
-                    1e-05,
+                    x, [3, 224, 224], weight, bias, 1e-05
                 )[0]

         input_specs = [
@@ -87,22 +60,19 @@ def forward(self, x):
                 opt_shape=(5, 3, 224, 224),
                 max_shape=(10, 3, 224, 224),
             ),
+            Input(dtype=torch.float32, shape=(3, 224, 224)),
+            Input(dtype=torch.float32, shape=(3, 224, 224)),
         ]

         self.run_test_with_dynamic_shape(
-            LayerNorm(),
-            input_specs,
+            LayerNorm(), input_specs, use_dynamo_tracer=True
         )

     def test_layernorm_with_dynamic_shape_1(self):
         class LayerNorm(torch.nn.Module):
-            def forward(self, x):
+            def forward(self, x, weight, bias):
                 return torch.ops.aten.native_layer_norm.default(
-                    x,
-                    torch.tensor([3]),
-                    torch.ones((3)),
-                    torch.zeros((3)),
-                    1e-05,
+                    x, [3], weight, bias, 1e-05
                 )[0]

         input_specs = [
@@ -112,29 +82,12 @@ def forward(self, x):
                 opt_shape=(3, 3, 3),
                 max_shape=(4, 5, 3),
             ),
+            Input(dtype=torch.float32, shape=(3,)),
+            Input(dtype=torch.float32, shape=(3,)),
         ]

         self.run_test_with_dynamic_shape(
-            LayerNorm(),
-            input_specs,
-        )
-
-    @parameterized.expand([((5, 3, 2, 4), [2, 4])])
-    def test_layer_norm_without_Scaling(self, input_shape, normalized_shape, eps=1e-05):
-        class LayerNorm(torch.nn.Module):
-            def forward(self, x):
-                return torch.ops.aten.native_layer_norm.default(
-                    x,
-                    normalized_shape,
-                    None,
-                    None,
-                    eps,
-                )[0]
-
-        inputs = [torch.randn(input_shape)]
-        self.run_test(
-            LayerNorm(),
-            inputs,
+            LayerNorm(), input_specs, use_dynamo_tracer=True
         )

From fda5d24ca0989ce8d51d4833214c66ee1862a7f5 Mon Sep 17 00:00:00 2001
From: HolyWu
Date: Wed, 13 Nov 2024 22:41:47 +0800
Subject: [PATCH 2/2] Keep function name the same as operator

---
 py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py    | 4 ++--
 py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index aff5a1dfb3..884c51e8ea 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -139,14 +139,14 @@ def aten_ops_batch_norm_legit_no_training(
         0: (TRTTensor,),
     }
 )
-def aten_ops_layer_norm(
+def aten_ops_native_layer_norm(
     ctx: ConversionContext,
     target: Target,
     args: Tuple[Argument, ...],
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.normalization.layer_norm(
+    return impl.normalization.native_layer_norm(
         ctx,
         target,
         SourceIR.ATEN,
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
index 9d69daa1e8..b737fb7dbc 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
@@ -152,7 +152,7 @@ def batch_norm(
     return output


-def layer_norm(
+def native_layer_norm(
     ctx: ConversionContext,
     target: Target,
     source_ir: Optional[SourceIR],
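
Illustrative usage (not part of the patch series): a minimal sketch of how the
converter path touched by these patches can be exercised end to end through the
dynamo frontend with fp16 enabled. The module, shapes, and the error check
below are assumptions for demonstration only; they are not taken from the patch
or its test suite, and they require a CUDA-capable setup with torch_tensorrt
installed.

import torch
import torch_tensorrt


class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # nn.LayerNorm typically lowers to aten.native_layer_norm under dynamo,
        # which is the op handled by the converter in this series.
        self.norm = torch.nn.LayerNorm(256)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.norm(x)


model = Model().eval().cuda().half()
inputs = [torch.randn(8, 128, 256, dtype=torch.half, device="cuda")]

# Compile with fp16 enabled; this is the precision case the fix targets.
trt_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=inputs,
    enabled_precisions={torch.half},
)

with torch.no_grad():
    expected = model(*inputs)
    actual = trt_model(*inputs)

# Report the worst-case absolute difference between eager and TensorRT outputs.
print(torch.max(torch.abs(expected - actual)))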