Use INormalizationLayer for GroupNorm #3273

Merged · 12 commits · Dec 12, 2024
38 changes: 2 additions & 36 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -182,8 +182,8 @@ def aten_ops_native_group_norm(
SourceIR.ATEN,
name,
input=args[0],
weight=args[1],
bias=args[2],
weight=args_bounds_check(args, 1),
bias=args_bounds_check(args, 2),
N=args[3],
C=args[4],
HxW=args[5],
@@ -192,40 +192,6 @@
)


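Editorial note: the sketch below is a simplified illustration (not the library's actual source) of the behaviour the change above relies on — `args_bounds_check(args, i, replacement)` returns `args[i]` when the argument is present and the replacement otherwise, which is what lets `weight` and `bias` be optional for `aten.native_group_norm`. The helper name with the `_sketch` suffix and the example `args` tuple are made up for illustration.

from typing import Any, Sequence

def args_bounds_check_sketch(args: Sequence[Any], i: int, replacement: Any = None) -> Any:
    # Illustrative stand-in for torch_tensorrt's args_bounds_check helper:
    # return args[i] if the index exists and is not None, otherwise the replacement.
    return args[i] if len(args) > i and args[i] is not None else replacement

# With the change above, a native_group_norm call that omits weight/bias
# (args shorter than expected) falls back to None instead of raising IndexError.
args = ("input_trt_tensor",)  # hypothetical converter args with no weight/bias supplied
assert args_bounds_check_sketch(args, 1) is None
assert args_bounds_check_sketch(args, 2) is None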
@dynamo_tensorrt_converter(
torch.ops.aten.group_norm.default,
supports_dynamic_shapes=True,
)
@dynamo_tensorrt_converter(
torch.ops.aten.group_norm,
supports_dynamic_shapes=True,
)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_group_norm(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.normalization.group_norm(
ctx,
target,
SourceIR.ATEN,
name,
input=args[0],
num_groups=args[1],
weight=args_bounds_check(args, 2, None),
bias=args_bounds_check(args, 3, None),
eps=args_bounds_check(args, 4, 1e-05),
cudnn_enabled=args_bounds_check(args, 5, True),
)


@dynamo_tensorrt_converter(torch.ops.aten.cat.default, supports_dynamic_shapes=True)
def aten_ops_cat(
ctx: ConversionContext,
261 changes: 49 additions & 212 deletions py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
@@ -1,5 +1,5 @@
import logging
from typing import Any, List, Optional, Sequence, Tuple, Union, cast
from typing import List, Optional, Sequence, Tuple, Union

import numpy as np
import tensorrt as trt
@@ -16,7 +16,6 @@
get_trt_tensor,
has_dynamic_shape,
set_layer_name,
to_numpy,
)
from torch_tensorrt.dynamo.conversion.impl.cat import cat
from torch_tensorrt.dynamo.conversion.impl.elementwise.ops import ge
@@ -203,234 +202,72 @@ def native_group_norm(
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
weight: Optional[Union[torch.Tensor, np.ndarray]],
bias: Optional[Union[torch.Tensor, np.ndarray]],
weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
N: int,
C: int,
HxW: int,
group: int,
eps: float,
return_mean_rstd: bool = True,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
# TODO: Ask TRT team about the usage of INormalization Layer usage with num_groups and update the implementation
# with INormalization Layer
assert (
len(input.shape) >= 3
), f"The input dimension should not be less than 3, got {len(input.shape)}!"

B = input.shape[0]
# if C is provided, it must be as same as the channel from the input shape,
# else if C is zero, we should get the channel from the input shape
if C == 0:
C = input.shape[1]
assert (
C == input.shape[1]
), f"The number of Channel={C} must be equal to the number of channels in the input shape={input.shape[1]}"
# Groups are a subdivision of the channel dimension.
assert (
C % group == 0
), f"The num of channels ({C}) should be divisible by num_groups ({group})!"
input = get_trt_tensor(ctx, input, f"{name}_input")

shape = list(input.shape)

for i, s in enumerate(shape):
if i == 0 and s > 0:
shape[i] = B * group
elif i == 1:
shape[i] = C // group
elif i > 1 and s == -1:
shape[i] = 0

# Normalize every group.
reshaped_input = impl.shuffle.reshape(
ctx,
target,
source_ir,
f"{name}_reshape_input",
input,
shape,
)

if weight is None:
weight = to_numpy(1.0)

if bias is None:
bias = to_numpy(0.0)

weight = get_trt_tensor(ctx, weight, f"{name}_weight")
bias = get_trt_tensor(ctx, bias, f"{name}_bias")
weight_bias_shape = (1, C) + (1,) * (len(input.shape) - 2)

dims = list(range(1, len(input.shape)))

# E[X]
mean_trt = impl.reduce.mean(
ctx,
target,
source_ir,
f"{name}_mean",
reshaped_input,
dims,
True,
)
) -> Tuple[TRTTensor, torch.Tensor, torch.Tensor]:
rank = len(input.shape)

mean_trt = impl.slice.expand(
ctx,
target,
source_ir,
f"{name}_expand_mean_trt",
mean_trt,
reshaped_input.shape,
)
assert rank >= 3, f"Expected at least 3 dimensions for input tensor but got {rank}"

# X - E[X]
sub_trt = impl.elementwise.sub(
ctx,
target,
source_ir,
f"{name}_sub",
reshaped_input,
mean_trt,
)
assert (
C == input.shape[1]
), f"num_channels ({C}) must be equal to number of channels in input ({input.shape[1]})"

# variance
pow_trt = get_trt_tensor(ctx, 2, f"{name}_power", np.float32)
pow_var = impl.elementwise.pow(
ctx,
target,
source_ir,
f"{name}_pow",
sub_trt,
pow_trt,
)
weight_one = get_trt_tensor(ctx, 1.0, f"{name}_weight_one", input.dtype)
bias_zero = get_trt_tensor(ctx, 0.0, f"{name}_bias_zero", input.dtype)

var_trt = impl.reduce.mean(
ctx,
target,
source_ir,
f"{name}_mean_var",
pow_var,
dims,
True,
)
shape = [1, group] + [1] * (rank - 2)

var_trt = impl.slice.expand(
ctx,
target,
source_ir,
f"{name}_expand_var_trt",
var_trt,
reshaped_input.shape,
weight_one = impl.slice.expand(
ctx, target, source_ir, f"{name}_expand_weight_one", weight_one, shape
)

eps_trt = get_trt_tensor(ctx, eps, f"{name}_eps", np.float32)
add_trt = impl.elementwise.add(
ctx,
target,
source_ir,
f"{name}_add",
var_trt,
eps_trt,
bias_zero = impl.slice.expand(
ctx, target, source_ir, f"{name}_expand_bias_zero", bias_zero, shape
)

sqrt_trt = impl.unary.sqrt(
ctx,
target,
source_ir,
f"{name}_sqrt",
add_trt,
)
axes = get_axes_for_reduce_op([i for i in range(1 if group == 1 else 2, rank)])

# y = (X - E[X]) / sqrt((var + eps))
output = impl.elementwise.div(
ctx,
target,
source_ir,
f"{name}_div",
sub_trt,
sqrt_trt,
)
# INormalizationLayer scales the normalized output per-group, but PyTorch scales it per-channel,
# which would give divergent results. Have TensorRT apply a no-op scale here and do the per-channel scaling ourselves afterwards.
layer = ctx.net.add_normalization(input, weight_one, bias_zero, axes)
layer.epsilon = eps
layer.num_groups = group
set_layer_name(layer, target, name, source_ir)
output = layer.get_output(0)

shape = list(output.shape)
for i, s in enumerate(shape):
if i == 0 and s > 0:
shape[i] = B
elif i == 1:
shape[i] = C
elif i > 1 and s == -1:
shape[i] = 0
shape[1] = C

reshaped_output = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_reshape_output", output, shape
)
reshaped_gamma = impl.shuffle.reshape(
ctx,
target,
source_ir,
f"{name}_reshape_gamma",
weight,
weight_bias_shape,
)

reshaped_output = impl.elementwise.mul(
ctx,
target,
source_ir,
f"{name}_mul_gamma",
reshaped_output,
reshaped_gamma,
)

reshaped_bias = impl.shuffle.reshape(
ctx,
target,
source_ir,
f"{name}_reshape_beta",
bias,
weight_bias_shape,
)
reshaped_output = impl.elementwise.add(
ctx,
target,
source_ir,
f"{name}_add_beta",
reshaped_output,
reshaped_bias,
)
if return_mean_rstd:
# return fake mean and rstd for now
return reshaped_output, None, None
return reshaped_output
if weight is not None:
weight = get_trt_tensor(ctx, weight, f"{name}_weight")
weight = cast_trt_tensor(
ctx, weight, input.dtype, f"{name}_cast_weight", target, source_ir
)
weight = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_reshape_weight", weight, shape
)
output = impl.elementwise.mul(
ctx, target, source_ir, f"{name}_mul_weight", output, weight
)

if bias is not None:
bias = get_trt_tensor(ctx, bias, f"{name}_bias")
bias = cast_trt_tensor(
ctx, bias, input.dtype, f"{name}_cast_bias", target, source_ir
)
bias = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_reshape_bias", bias, shape
)
output = impl.elementwise.add(
ctx, target, source_ir, f"{name}_add_bias", output, bias
)

def group_norm(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
num_groups: int,
weight: Optional[Union[torch.Tensor, np.ndarray]],
bias: Optional[Union[torch.Tensor, np.ndarray]],
eps: float,
cudnn_enabled: bool,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return native_group_norm(
ctx,
target,
source_ir,
name,
input,
weight,
bias,
0,
0,
0,
num_groups,
eps,
return_mean_rstd=False,
)
# return fake mean and rstd for now
return output, None, None

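Editorial note: the new implementation asks INormalizationLayer to normalize each group with a unit scale and zero bias, then applies PyTorch's per-channel gamma/beta with plain elementwise mul/add. The standalone PyTorch sketch below (not converter code; shapes and tensor names are made up for illustration) reproduces that two-step scheme and checks it against torch.nn.functional.group_norm.

import torch
import torch.nn.functional as F

N, C, H, W, G = 2, 6, 4, 4, 3
x = torch.randn(N, C, H, W)
gamma = torch.randn(C)
beta = torch.randn(C)
eps = 1e-5

# Step 1: per-group normalization with unit scale and zero bias
# (this is the part delegated to INormalizationLayer in the converter).
xg = x.reshape(N, G, C // G, H, W)
mean = xg.mean(dim=(2, 3, 4), keepdim=True)
var = xg.var(dim=(2, 3, 4), unbiased=False, keepdim=True)
normalized = ((xg - mean) / torch.sqrt(var + eps)).reshape(N, C, H, W)

# Step 2: per-channel affine, applied outside the normalization layer
# (the elementwise mul/add with the reshaped weight and bias above).
out = normalized * gamma.view(1, C, 1, 1) + beta.view(1, C, 1, 1)

ref = F.group_norm(x, G, weight=gamma, bias=beta, eps=eps)
assert torch.allclose(out, ref, atol=1e-5)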

def softmax(
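Editorial note: as a rough end-to-end sanity check of the converter path, the sketch below compiles a small GroupNorm module through the dynamo frontend and compares it against eager PyTorch. It assumes a CUDA device with TensorRT installed and the standard torch_tensorrt.compile entry point; the module definition and the compile options shown (e.g. min_block_size=1) are illustrative, not taken from this PR's tests.

import torch
import torch_tensorrt  # assumes a working TensorRT + CUDA installation

class GroupNormModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.gn = torch.nn.GroupNorm(num_groups=3, num_channels=6)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.gn(x)

model = GroupNormModule().eval().cuda()
x = torch.randn(2, 6, 4, 4, device="cuda")

# Compile through the dynamo frontend so aten.native_group_norm hits the converter above.
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=[x], min_block_size=1)

with torch.no_grad():
    torch.testing.assert_close(trt_model(x), model(x), rtol=1e-3, atol=1e-3)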