diff --git a/backends/cadence/aot/quantizer/fusion_pass.py b/backends/cadence/aot/quantizer/fusion_pass.py
index 51d019f155..91c9eb11b3 100644
--- a/backends/cadence/aot/quantizer/fusion_pass.py
+++ b/backends/cadence/aot/quantizer/fusion_pass.py
@@ -13,6 +13,7 @@
     AddmmPattern,
     AddPattern,
     BmmPattern,
+    CatPattern,
     Conv1dPattern,
     Conv2dPattern,
     LayerNormPattern,
@@ -246,6 +247,14 @@ def get_args_and_kwargs_matmul(
     return args, kwargs
 
 
+def get_args_and_kwargs_cat(inputs_inputs: List[fx.Node], other_inputs: List[fx.Node], op_node: fx.Node) -> Tuple[Tuple[ArgsType], Dict[str, ArgsType]]:
+    args = tuple([inputs_inputs] + other_inputs)
+    dim = op_node.args[1] if len(op_node.args) > 1 else 0
+    # pyre-fixme[6]: Incompatible parameter type
+    kwargs = {"dim": int(dim)}
+    return args, kwargs
+
+
 def get_args_and_kwargs_conv(
     graph_module: GraphModule,
     inputs_inputs: List[fx.Node],
@@ -390,12 +399,17 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                 self.mark_fused(p.nodes)
 
                 dequants_inputs = []
-                for node, idx in anchors.inputs:
+                for node, idx, *_spec in anchors.inputs:
+                    arg = (
+                        node.args[idx]
+                        if isinstance(idx, int)
+                        else node.args[idx[0]][idx[1]]
+                    )
                     if (
-                        node.args[idx].target
+                        arg.target
                         == torch.ops.quantized_decomposed.dequantize_per_tensor.default
                     ):
-                        dequants_inputs.append(node.args[idx])
+                        dequants_inputs.append(arg)
                 dequants_weights = []
                 for node, idx in anchors.weights:
                     if (
@@ -434,6 +448,8 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                         dequants_inputs,
                         quant_node,
                     )
+                elif isinstance(pattern, CatPattern):
+                    args, kwargs = get_args_and_kwargs_cat(inputs_inputs, other_inputs, op_node)
                 elif isinstance(pattern, (Conv1dPattern, Conv2dPattern)):
                     args, kwargs = get_args_and_kwargs_conv(
                         graph_module,
diff --git a/backends/cadence/aot/quantizer/patterns.py b/backends/cadence/aot/quantizer/patterns.py
index 0e907812b1..66f6772d94 100644
--- a/backends/cadence/aot/quantizer/patterns.py
+++ b/backends/cadence/aot/quantizer/patterns.py
@@ -33,7 +33,17 @@ class PartitionAnchors:
     is used for other types of input values as well as handling default parameters.
     """
 
-    inputs: List[Tuple[fx.Node, int]] = field(default_factory=list)
+    # Inputs can share quantization parameters
+    inputs: List[
+        Union[
+            Tuple[fx.Node, Union[int, Tuple[int, int]]],
+            Tuple[
+                fx.Node,
+                Union[int, Tuple[int, int]],
+                SharedQuantizationSpec,
+            ],
+        ]
+    ] = field(default_factory=list)
     weights: List[Tuple[fx.Node, int]] = field(default_factory=list)
     biases: List[
         Union[Tuple[fx.Node, int], Tuple[fx.Node, int, DerivedQuantizationSpec]]
@@ -155,6 +165,52 @@ def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_matmul.default
 
 
+class CatPattern(QuantizationPattern):
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.cat.default]
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> PartitionAnchors:
+        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
+        cat_node = fused_partition[0].nodes[-1]
+
+        # Create args. The first argument does not have a quant spec and
+        # will inherit from the overall quant spec. All subsequent args
+        # will share that spec.
+        # Note that outputs also share that spec.
+        args: List[
+            Union[
+                Tuple[fx.Node, Union[int, Tuple[int, int]]],
+                Tuple[
+                    fx.Node,
+                    Union[int, Tuple[int, int]],
+                    SharedQuantizationSpec,
+                ],
+            ]
+        ] = [(cat_node, (0, 0))]
+        for i in range(1, len(cat_node.args[0])):
+            args.append(
+                (
+                    cat_node,
+                    (0, i),
+                    SharedQuantizationSpec((cat_node.args[0][0], cat_node)),
+                )
+            )
+
+        return PartitionAnchors(
+            inputs=args,
+            weights=[],
+            biases=[],
+            output=[
+                (cat_node, SharedQuantizationSpec((cat_node.args[0][0], cat_node)))
+            ],
+        )
+
+    def replacement_op(self) -> OpOverload:
+        return torch.ops.aten.cat.default
+
+
 class Conv1dPattern(QuantizationPattern):
     def partition_types(self) -> List[OpOverload]:
         return [torch.ops.aten.conv1d.default]
diff --git a/backends/cadence/aot/quantizer/quantizer.py b/backends/cadence/aot/quantizer/quantizer.py
index 42cc1a1df1..366426aa60 100644
--- a/backends/cadence/aot/quantizer/quantizer.py
+++ b/backends/cadence/aot/quantizer/quantizer.py
@@ -14,6 +14,7 @@
     AddmmPattern,
     AddPattern,
     BmmPattern,
+    CatPattern,
     Conv1dPattern,
     Conv2dPattern,
     LayerNormPattern,
@@ -144,17 +145,35 @@ def annotate_inputs(
                 "quantization_annotation",
                 QuantizationAnnotation(_annotated=True),
             )
+            arg = (
+                # pyre-ignore[16]: no attribute
+                node.args[idx]
+                if isinstance(idx, int)
+                # pyre-ignore[16]: no attribute
+                else node.args[idx[0]][idx[1]]
+            )
+            annotation.input_qspec_map[arg] = (
+                custom_spec[0] if custom_spec else spec
+            )
+            # pyre-ignore[16]: no attribute
+            node.meta["quantization_annotation"] = annotation
+
+        def annotate_weights_or_biases(weights_or_biases: List[Tuple[fx.Node, int]], spec: Optional[QuantizationSpec]) -> None:
+            for node, idx, *custom_spec in weights_or_biases:
+                annotation = node.meta.get(
+                    "quantization_annotation",
+                    QuantizationAnnotation(_annotated=True),
+                )
                 annotation.input_qspec_map[node.args[idx]] = (
                     custom_spec[0] if custom_spec else spec
                 )
-                # pyre-ignore[16]: no attribute
                 node.meta["quantization_annotation"] = annotation
 
+        # pyre-ignore[6]: incompatible parameter type
         annotate_inputs(anchors.inputs, input_act_qspec)
-        annotate_inputs(anchors.weights, weight_qspec)
+        annotate_weights_or_biases(anchors.weights, weight_qspec)
         # pyre-ignore[6]: incompatible parameter type
-        annotate_inputs(anchors.biases, bias_qspec)
+        annotate_weights_or_biases(anchors.biases, bias_qspec)
         return model
 
     def validate(self, model: fx.GraphModule) -> None:
@@ -223,4 +242,5 @@ def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
         if quantizers is None:
             quantizers = get_cadence_default_quantizers()
         quantizers.append(CadenceAtenQuantizer(AddPattern(), qconfig_A8uW8u))
+        quantizers.append(CadenceAtenQuantizer(CatPattern(), qconfig_A8uW8u))
         super().__init__(quantizers)
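
Reviewer note, not part of the patch: the two ideas this diff relies on are (1) anchor indices that may now be an `(outer, inner)` tuple, because `aten.cat` packs all of its inputs into a single list-valued argument, and (2) `SharedQuantizationSpec`, which ties every input after the first, plus the output, to the quantization parameters of the first input. The sketch below illustrates both under stated assumptions: the toy module `M`, the `symbolic_trace` call, and the hand-built annotation are hypothetical; only `SharedQuantizationSpec` / `QuantizationAnnotation` (from `torch.ao.quantization.quantizer`) and the index-resolution expression mirrored from `fusion_pass.py` come from the patch itself.

```python
import torch
import torch.fx as fx
from torch.ao.quantization.quantizer import (
    QuantizationAnnotation,
    SharedQuantizationSpec,
)


class M(torch.nn.Module):  # hypothetical toy module, not from the PR
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.cat([x, y], dim=1)


gm = fx.symbolic_trace(M())
cat_node = next(n for n in gm.graph.nodes if n.target == torch.cat)

# cat's first positional arg is a *list* of nodes, so a plain int index
# cannot name a single input. An (outer, inner) tuple can, resolved the
# same way the updated fusion pass resolves it:
for idx in [0, (0, 0), (0, 1)]:
    arg = (
        cat_node.args[idx]
        if isinstance(idx, int)
        else cat_node.args[idx[0]][idx[1]]
    )
    print(idx, "->", arg)
# 0      -> [x, y]  (the whole list-valued argument)
# (0, 0) -> x       (anchored with no spec: inherits the pattern's overall spec)
# (0, 1) -> y       (anchored with a SharedQuantizationSpec)

# The sharing itself: the spec names the edge from the first input into
# the cat node; later inputs and the output reuse that edge's qparams.
shared = SharedQuantizationSpec((cat_node.args[0][0], cat_node))
cat_node.meta["quantization_annotation"] = QuantizationAnnotation(
    input_qspec_map={cat_node.args[0][1]: shared},
    output_qspec=shared,
    _annotated=True,
)
```

Without the shared specs, each input of `cat` could be observed with its own scale and zero-point, and the fused op would need a requantize step per input; anchoring everything to the first input's spec is what lets the quantized `cat` remain a pure concatenation of int8 buffers.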