Feature custom qat module (#999)
* Extend quantizable layer types (#881)

* Adds support for adding modules to the list of quantizable modules via a modifier flag.

* Passed layer_class_names argument to the recursive call.

* Changed flag name for better contrast with exclude_module_types

* Style and quality fixes

* Update src/sparseml/pytorch/sparsification/quantization/helpers.py

Co-authored-by: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com>

* Deactivate conversions that remove Q/DQ after MatMul

* Remove unused code

* Use public property name instead of alias

* Replace conditional with a dictionary

* Style and quality fixes

* Changed flag name to custom_quantizable_module_types to make it less confusing when contrasted with the existing flag exclude_module_types (sometimes the same module needs to be listed in both flags; see the usage sketch below this commit message).

* Fix calls that set _exclude_module_types

Co-authored-by: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com>
Co-authored-by: Rahul Tuli <rahul@neuralmagic.com>

* Revert dtype conversion; the dictionary approach was breaking in some cases (#989)

Co-authored-by: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com>
Co-authored-by: Rahul Tuli <rahul@neuralmagic.com>
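
For context, a minimal usage sketch of the new flag (not part of the commit; the import path, start_epoch value, and "MyCustomBlock" are assumptions based on the file paths and signatures in this diff):

from sparseml.pytorch.sparsification.quantization import QuantizationModifier

# "MyCustomBlock" is a hypothetical layer class name that is not in the
# default set of quantizable module types
modifier = QuantizationModifier(
    start_epoch=0.0,  # assumed scheduling argument from the ScheduledModifier base
    custom_quantizable_module_types=["MyCustomBlock"],  # opt this type into QuantWrapper wrapping
    exclude_module_types=["MyCustomBlock"],  # per the commit message, the same name may also need excluding from qconfig propagation
    exclude_batchnorm=True,
)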
3 people authored Aug 18, 2022
1 parent e7c5009 commit 1e5b27b
Showing 3 changed files with 40 additions and 35 deletions.
23 changes: 16 additions & 7 deletions src/sparseml/pytorch/sparsification/quantization/helpers.py
@@ -481,20 +481,25 @@ def configure_module_default_qconfigs(module: Module):
submodule.configure_qconfig()


def add_quant_dequant(module, name=None, parent_module=None):
def add_quant_dequant(
module: torch.nn.Module, name=None, parent_module=None, layer_class_names=None
):
"""
Wraps all Conv and Linear submodules that have a qconfig with a QuantWrapper
:param module: the module to modify
:param name: name of the module to modify; default to None
:param parent_module: parent module containing the module to modify; default to None
:param layer_class_names: list of module class names to be added to the
list of quantizable modules
:return: the modified module
"""
named_children = module.named_children()
if (
type(module) in _QUANTIZABLE_MODULE_TYPES
and hasattr(module, "qconfig")
and module.qconfig
):
is_quantizable = type(module) in _QUANTIZABLE_MODULE_TYPES
if layer_class_names:
is_quantizable = (
is_quantizable or module.__class__.__name__ in layer_class_names
)
if is_quantizable and hasattr(module, "qconfig") and module.qconfig:
module = torch_quantization.QuantWrapper(module)
if parent_module is not None and len(list(named_children)) <= 0:
if "." in name:
@@ -508,7 +513,11 @@ def add_quant_dequant(module, name=None, parent_module=None):
setattr(parent_module, name, module)
else:
for name, child in named_children:
setattr(module, name, add_quant_dequant(child))
setattr(
module,
name,
add_quant_dequant(child, layer_class_names=layer_class_names),
)
return module
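
To illustrate the effect of the new layer_class_names argument, a minimal sketch; "MyCustomBlock" and the toy model are hypothetical, and the import path is inferred from the file path in this diff:

import torch
from torch import nn

from sparseml.pytorch.sparsification.quantization.helpers import add_quant_dequant


class MyCustomBlock(nn.Module):
    """Hypothetical layer type not in the default quantizable module types."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 16)

    def forward(self, x):
        return self.fc(x)


model = nn.Sequential(MyCustomBlock(), nn.ReLU())
# attach and propagate a QAT qconfig so add_quant_dequant has something to wrap
model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
torch.quantization.propagate_qconfig_(model)

# with the new argument, MyCustomBlock is treated as quantizable and wrapped in a
# torch.quantization.QuantWrapper, just like Conv and Linear layers already are
model = add_quant_dequant(model, layer_class_names=["MyCustomBlock"])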


Second changed file (quantization modifier):
@@ -135,6 +135,8 @@ class QuantizationModifier(ScheduledModifier):
batch-normalization modules
:param exclude_module_types: optional list of module class names
to not propagate quantization configs to. Default is None
:param custom_quantizable_module_types: optional list of module class names
to be added to the list of quantizable modules. Default is None
:param activation_qconfig_kwargs: Additional kwargs for quantization of
activations.
:param weight_qconfig_kwargs: Additional kwargs for quantization of
@@ -162,6 +164,7 @@ def __init__(
num_calibration_steps: Optional[int] = None,
exclude_batchnorm: bool = True,
exclude_module_types: Optional[List[str]] = None,
custom_quantizable_module_types: Optional[List[str]] = None,
activation_qconfig_kwargs: Optional[Dict[str, Any]] = None,
weight_qconfig_kwargs: Optional[Dict[str, Any]] = None,
tensorrt: bool = False,
@@ -195,6 +198,7 @@ def __init__(
self._weight_bits = weight_bits
self._exclude_batchnorm = exclude_batchnorm
self._exclude_module_types = exclude_module_types
self._custom_quantizable_module_types = custom_quantizable_module_types

self._modules_to_quantize = None
self._qat_enabled = False
@@ -389,6 +393,14 @@ def quantize_embedding_activations(self) -> bool:
else:
return self._quantize_embedding_activations

@ModifierProp()
def custom_quantizable_module_types(self) -> Union[List[str], None]:
"""
:return: optional list of module class names to be included
in list of quantizable modules. Default is None
"""
return self._custom_quantizable_module_types

@ModifierProp()
def exclude_module_types(self) -> Union[List[str], None]:
"""
@@ -651,8 +663,9 @@ def _enable_module_qat(self, module: Module):
# wrap all conv / linear blocks with quantization observers
torch_quantization.propagate_qconfig_(quant_module)
configure_module_default_qconfigs(quant_module)

add_quant_dequant(quant_module, name, module)
add_quant_dequant(
quant_module, name, module, self.custom_quantizable_module_types
)

# Remove output quantization from appropriate modules
remove_activation_qat_by_layer_name(
@@ -661,16 +674,16 @@

# remove qconfigs for module types in exclude_module_types
to_exclude = []
if self._exclude_module_types:
to_exclude.extend(self._exclude_module_types)
if self.exclude_module_types:
to_exclude.extend(self.exclude_module_types)

# if exclude_batchnorm flag is used, add batch norm layers to list of
# modules to exclude qconfig
if self._exclude_batchnorm:
if self.exclude_batchnorm:
to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"])

self._exclude_module_types = to_exclude
if self._exclude_module_types:
if self.exclude_module_types:
self._strip_excluded_module_qconfigs(module)

# set modules with proper qconfigs to QAT mode
@@ -753,9 +766,9 @@ def _freeze_bn_stats_update_ready(self, epoch: float) -> bool:
)

def _strip_excluded_module_qconfigs(self, module: Module):
if not self._exclude_module_types:
if not self.exclude_module_types:
return
excluded_classes = set(self._exclude_module_types)
excluded_classes = set(self.exclude_module_types)
for submodule in module.modules():
if submodule.__class__.__name__ in excluded_classes and hasattr(
submodule, "qconfig"
Third changed file (ONNX QAT export helpers):
@@ -323,6 +323,7 @@ def _attribute_to_kwarg(attribute: onnx.AttributeProto):
def _quantize_array(
array: numpy.ndarray, scale: float, zero_point: int, dtype: Any = numpy.uint8
) -> numpy.ndarray:

if dtype == numpy.uint8:
tensor_dtype = torch.quint8
elif dtype == numpy.int8:
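
For reference, _quantize_array applies standard per-tensor affine quantization. A plain-NumPy sketch of the same formula, with a hypothetical helper name (the library routine itself is torch-backed and may differ on rounding edge cases):

import numpy

def quantize_array_reference(array, scale, zero_point, dtype=numpy.uint8):
    # affine quantization: q = clamp(round(x / scale) + zero_point, dtype range)
    info = numpy.iinfo(dtype)
    quantized = numpy.round(array / scale) + zero_point
    return numpy.clip(quantized, info.min, info.max).astype(dtype)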
@@ -1060,25 +1061,8 @@ def _convert_quantizable_matmul_and_add(model: ModelProto):
if not bias_add_node or bias_add_node.op_type != "Add":
continue

# Optionally find output QDQ block which will be deleted
output_quantize_node = graph.get_node_single_child(bias_add_node)
if (
not output_quantize_node
or output_quantize_node.op_type not in _QUANTIZE_OP_NAMES
):
output_quantize_node = None

output_dequantize_node = (
graph.get_node_single_child(output_quantize_node)
if output_quantize_node
else None
)
if (
not output_dequantize_node
or output_dequantize_node.op_type not in _QUANTIZE_OP_NAMES
):
output_quantize_node = None
output_dequantize_node = None
output_quantize_node = None
output_dequantize_node = None

input_quantize_params = get_quantization_params(
model, input_quantize_node, include_target=False
@@ -1587,7 +1571,6 @@ def quantize_torch_qat_export(
_convert_quantizable_gemm_no_activations(model)
quantize_resnet_identity_add_inputs(model)
_remove_duplicate_quantize_ops(model)
_cleanup_unused_quants(model)

graph = ONNXGraph(model)
graph.sort_nodes_topologically()
