Enable quantized cat
Summary: As titled. Use it only in the WakeWord quantizer for now, because it has implications for numerics in general.

Reviewed By: zonglinpeng

Differential Revision: D69499329
mcremon-meta authored and facebook-github-bot committed Feb 27, 2025
1 parent 6804284 commit 2e611a6
Showing 3 changed files with 99 additions and 7 deletions.
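For context on the numerics caveat in the summary: concatenation only stays a cheap copy in the quantized domain when every input and the output share the same quantization parameters; otherwise each input has to be requantized first. A minimal sketch with made-up values (NumPy is used purely for illustration and is not part of the commit):

import numpy as np

# Hypothetical int8 tensors quantized with different scales (zero points assumed 0).
a_q, a_scale = np.array([10, 20], dtype=np.int8), 0.10
b_q, b_scale = np.array([10, 20], dtype=np.int8), 0.05

# With distinct scales the raw int8 values cannot simply be concatenated:
# dequantize, concatenate, then requantize to a common output scale.
out_scale = 0.10
out_q = np.round(
    np.concatenate([a_q * a_scale, b_q * b_scale]) / out_scale
).astype(np.int8)  # [10, 20, 5, 10]

# If every input and the output share one scale/zero point (what the
# SharedQuantizationSpec anchors below enforce), cat degenerates to a plain copy:
shared_q = np.concatenate([a_q, b_q])  # [10, 20, 10, 20]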
22 changes: 19 additions & 3 deletions backends/cadence/aot/quantizer/fusion_pass.py
@@ -13,6 +13,7 @@
AddmmPattern,
AddPattern,
BmmPattern,
CatPattern,
Conv1dPattern,
Conv2dPattern,
LayerNormPattern,
@@ -246,6 +247,14 @@ def get_args_and_kwargs_matmul(
return args, kwargs


def get_args_and_kwargs_cat(
    inputs_inputs: List[fx.Node], other_inputs: List[fx.Node], op_node: fx.Node
) -> Tuple[Tuple[ArgsType], Dict[str, ArgsType]]:
    args = tuple([inputs_inputs] + other_inputs)
    dim = op_node.args[1] if len(op_node.args) > 1 else 0
    # pyre-fixme[6]: Incompatible parameter type
    kwargs = {"dim": int(dim)}
    return args, kwargs
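To make the helper's contract concrete, here is a rough illustration of what it computes for a three-input cat along dim=1. The stand-in objects are hypothetical (real callers pass fx.Node instances); only the tuple and dict shapes matter:

from types import SimpleNamespace

# Stand-ins for fx.Node objects, assumed for illustration only.
dequant_outputs = ["dq_x", "dq_y", "dq_z"]              # inputs_inputs
other_inputs = []                                        # cat has no extra tensor args here
cat_node = SimpleNamespace(args=(["x", "y", "z"], 1))    # aten.cat([x, y, z], dim=1)

args = tuple([dequant_outputs] + other_inputs)           # (['dq_x', 'dq_y', 'dq_z'],)
dim = cat_node.args[1] if len(cat_node.args) > 1 else 0
kwargs = {"dim": int(dim)}                               # {'dim': 1}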


def get_args_and_kwargs_conv(
    graph_module: GraphModule,
    inputs_inputs: List[fx.Node],
@@ -390,12 +399,17 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
    self.mark_fused(p.nodes)

    dequants_inputs = []
-   for node, idx in anchors.inputs:
+   for node, idx, *_spec in anchors.inputs:
+       arg = (
+           node.args[idx]
+           if isinstance(idx, int)
+           else node.args[idx[0]][idx[1]]
+       )
        if (
-           node.args[idx].target
+           arg.target
            == torch.ops.quantized_decomposed.dequantize_per_tensor.default
        ):
-           dequants_inputs.append(node.args[idx])
+           dequants_inputs.append(arg)
    dequants_weights = []
    for node, idx in anchors.weights:
        if (
@@ -434,6 +448,8 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
            dequants_inputs,
            quant_node,
        )
    elif isinstance(pattern, CatPattern):
        args, kwargs = get_args_and_kwargs_cat(
            inputs_inputs, other_inputs, op_node
        )
    elif isinstance(pattern, (Conv1dPattern, Conv2dPattern)):
        args, kwargs = get_args_and_kwargs_conv(
            graph_module,
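The widened loop over anchors.inputs above exists because aten.cat receives all of its tensors inside a single list argument, so an anchor index can now be either a plain argument position or an (argument, list element) pair. A small, self-contained sketch of the resolution rule, with plain tuples and strings standing in for fx node arguments:

# Conv-style anchor: idx is an int, the input is node.args[idx].
conv_args = ("activation", "weight", "bias")
idx = 0
assert conv_args[idx] == "activation"

# Cat-style anchor: args[0] is a list of tensors, idx is (arg_position, list_position)
# and the input is node.args[idx[0]][idx[1]].
cat_args = (["x0", "x1", "x2"], 1)   # aten.cat([x0, x1, x2], dim=1)
idx = (0, 2)
assert cat_args[idx[0]][idx[1]] == "x2"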
58 changes: 57 additions & 1 deletion backends/cadence/aot/quantizer/patterns.py
@@ -33,7 +33,17 @@ class PartitionAnchors:
is used for other types of input values as well as handling default parameters.
"""

-   inputs: List[Tuple[fx.Node, int]] = field(default_factory=list)
+   # Inputs can share quantization parameters
+   inputs: List[
+       Union[
+           Tuple[fx.Node, Union[int, Tuple[int, int]]],
+           Tuple[
+               fx.Node,
+               Union[int, Tuple[int, int]],
+               SharedQuantizationSpec,
+           ],
+       ]
+   ] = field(default_factory=list)
    weights: List[Tuple[fx.Node, int]] = field(default_factory=list)
    biases: List[
        Union[Tuple[fx.Node, int], Tuple[fx.Node, int, DerivedQuantizationSpec]]
@@ -155,6 +165,52 @@ def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_matmul.default


class CatPattern(QuantizationPattern):
    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.cat.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> PartitionAnchors:
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        cat_node = fused_partition[0].nodes[-1]

        # Create args. The first argument does not have a quant spec and
        # will inherit from the overall quant spec. All subsequent args
        # will share that spec.
        # Note that outputs also share that spec.
        args: List[
            Union[
                Tuple[fx.Node, Union[int, Tuple[int, int]]],
                Tuple[
                    fx.Node,
                    Union[int, Tuple[int, int]],
                    SharedQuantizationSpec,
                ],
            ]
        ] = [(cat_node, (0, 0))]
        for i in range(1, len(cat_node.args[0])):
            args.append(
                (
                    cat_node,
                    (0, i),
                    SharedQuantizationSpec((cat_node.args[0][0], cat_node)),
                )
            )

        return PartitionAnchors(
            inputs=args,
            weights=[],
            biases=[],
            output=[
                (cat_node, SharedQuantizationSpec((cat_node.args[0][0], cat_node)))
            ],
        )

    def replacement_op(self) -> OpOverload:
        return torch.ops.aten.cat.default
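As a sanity check on the anchor layout, the sketch below traces a small three-input cat and builds the same structure get_anchors produces: the first tensor keeps the ordinary input-spec slot, while the remaining tensors and the output are tied to the first tensor's parameters. The module and variable names are illustrative, and the real pass operates on the fused partition rather than a freshly traced graph:

import torch
from torch import fx
from torch.ao.quantization.quantizer import SharedQuantizationSpec


class ThreeWayCat(torch.nn.Module):
    def forward(self, x0, x1, x2):
        return torch.cat([x0, x1, x2], dim=1)


gm = fx.symbolic_trace(ThreeWayCat())
cat_node = next(n for n in gm.graph.nodes if n.target is torch.cat)
first_input = cat_node.args[0][0]

# Same shape as the anchors above: index (0, 0) for the first tensor, then
# (0, i) plus a spec shared with (first_input, cat_node) for the rest.
shared = SharedQuantizationSpec((first_input, cat_node))
inputs = [(cat_node, (0, 0))] + [
    (cat_node, (0, i), shared) for i in range(1, len(cat_node.args[0]))
]
output = [(cat_node, shared)]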


class Conv1dPattern(QuantizationPattern):
    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.conv1d.default]
26 changes: 23 additions & 3 deletions backends/cadence/aot/quantizer/quantizer.py
@@ -14,6 +14,7 @@
AddmmPattern,
AddPattern,
BmmPattern,
CatPattern,
Conv1dPattern,
Conv2dPattern,
LayerNormPattern,
@@ -144,17 +145,35 @@ def annotate_inputs(
"quantization_annotation",
QuantizationAnnotation(_annotated=True),
)
arg = (
# pyre-ignore[16]: no attribute
node.args[idx]
if isinstance(idx, int)
# pyre-ignore[16]: no attribute
else node.args[idx[0]][idx[1]]
)
annotation.input_qspec_map[arg] = (
custom_spec[0] if custom_spec else spec
)
# pyre-ignore[16]: no attribute
node.meta["quantization_annotation"] = annotation

def annotate_weights_or_biases(weights_or_biases: List[Tuple[fx.Node, int]], spec: Optional[QuantizationSpec]) -> None:
for node, idx, *custom_spec in weights_or_biases:
annotation = node.meta.get(
"quantization_annotation",
QuantizationAnnotation(_annotated=True),
)
annotation.input_qspec_map[node.args[idx]] = (
custom_spec[0] if custom_spec else spec
)
# pyre-ignore[16]: no attribute
node.meta["quantization_annotation"] = annotation

# pyre-ignore[6]: incompatible parameter type
annotate_inputs(anchors.inputs, input_act_qspec)
annotate_inputs(anchors.weights, weight_qspec)
annotate_weights_or_biases(anchors.weights, weight_qspec)
# pyre-ignore[6]: incompatible parameter type
annotate_inputs(anchors.biases, bias_qspec)
annotate_weights_or_biases(anchors.biases, bias_qspec)
return model

def validate(self, model: fx.GraphModule) -> None:
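The split above keeps annotate_inputs as the only place that understands the new tuple indices and optional per-anchor specs, while weights and biases stay on the simpler (node, index) form. A runnable sketch of the input-side resolution, using SimpleNamespace stand-ins for fx nodes and strings in place of real QuantizationSpec objects (all names here are hypothetical):

from types import SimpleNamespace

from torch.ao.quantization.quantizer import QuantizationAnnotation

x0, x1 = SimpleNamespace(name="x0"), SimpleNamespace(name="x1")
cat_node = SimpleNamespace(args=([x0, x1], 1), meta={})
pattern_spec = "int8_act_spec"      # placeholder for the pattern-level QuantizationSpec
shared_spec = "shared_with_x0"      # placeholder for SharedQuantizationSpec((x0, cat_node))

anchors = [(cat_node, (0, 0)), (cat_node, (0, 1), shared_spec)]
for node, idx, *custom_spec in anchors:
    annotation = node.meta.get(
        "quantization_annotation", QuantizationAnnotation(_annotated=True)
    )
    arg = node.args[idx] if isinstance(idx, int) else node.args[idx[0]][idx[1]]
    # A per-anchor spec (here the shared one) overrides the pattern-level spec.
    annotation.input_qspec_map[arg] = custom_spec[0] if custom_spec else pattern_spec
    node.meta["quantization_annotation"] = annotation

# cat_node now maps x0 to the pattern spec and x1 to the shared spec.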
@@ -223,4 +242,5 @@ def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
        if quantizers is None:
            quantizers = get_cadence_default_quantizers()
        quantizers.append(CadenceAtenQuantizer(AddPattern(), qconfig_A8uW8u))
        quantizers.append(CadenceAtenQuantizer(CatPattern(), qconfig_A8uW8u))
        super().__init__(quantizers)
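For completeness, a hedged end-to-end sketch of how the updated quantizer would be exercised. The class name CadenceWakeWordQuantizer is assumed from the summary's mention of the WakeWord quantizer (the enclosing class name is cut off in the hunk above), and the calls follow the stock PT2E flow rather than any Cadence-specific wrapper:

import torch
from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceWakeWordQuantizer,  # assumed export name for the quantizer defined above
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e


class TinyCat(torch.nn.Module):
    def forward(self, x, y):
        return torch.cat([x + 1.0, y + 2.0], dim=1)


example_inputs = (torch.randn(1, 4), torch.randn(1, 4))
module = torch.export.export(TinyCat(), example_inputs).module()

quantizer = CadenceWakeWordQuantizer()   # now registers CatPattern with qconfig_A8uW8u
prepared = prepare_pt2e(module, quantizer)
prepared(*example_inputs)                # calibration pass
quantized = convert_pt2e(prepared)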
