diff --git a/src/sparseml/onnx/utils/graph_optimizer.py b/src/sparseml/onnx/utils/graph_optimizer.py index d91f900c898..d0015a7bcdf 100644 --- a/src/sparseml/onnx/utils/graph_optimizer.py +++ b/src/sparseml/onnx/utils/graph_optimizer.py @@ -40,7 +40,6 @@ __all__ = [ "fold_conv_bns", "quantize_resnet_identity_add_inputs", - "quantized_residual_add_optim", ] @@ -202,91 +201,6 @@ def quantize_resnet_identity_add_inputs(quantized_model: onnx.ModelProto) -> boo return optimization_made -def quantized_residual_add_optim(quantized_model: onnx.ModelProto) -> bool: - """ - This optimization adds a quant/dequant block to the identity branch of a - residual whose non-identity branch is quantized. This enables the add at the - end of the residual to be fused at runtime. - - Function will match to any node who has two children nodes - one add node - and one quantize node whose branch eventually leads to the other add node. - - :param quantized_model: A loaded quantized model to perform this optimization on - :return: True if an in-place optimization was made - """ - graph = ONNXGraph(quantized_model) - optimization_made = False - for node in quantized_model.graph.node: - children_nodes = graph.get_node_children(node) - if len(children_nodes) != 2: - continue - - add_node = [node for node in children_nodes if node.op_type == "Add"] - quant_node = [ - node for node in children_nodes if node.op_type == "QuantizeLinear" - ] - if not add_node or not quant_node: - continue - add_node = add_node[0] - quant_node = quant_node[0] - - # verify that quant_node eventually leads to add_node - curr_node = [quant_node] - iter = 0 - max_iter = 20 # avoid cycles - while curr_node and curr_node[0] != add_node and iter < max_iter: - curr_node = graph.get_node_children(curr_node[0]) - iter += 1 - if curr_node[0] != add_node: - continue - - # create de-quantize node for identity - dequant_node = _make_dequant_node_for_quant(quant_node) - - # update graph - identity_edge_idx = 0 if add_node.input[0] == node.output[0] else 1 - graph.add_node(dequant_node) - graph.update_node_input(add_node, dequant_node.output[0], identity_edge_idx) - optimization_made = True - - # if any of the add children have are a quantize op while others aren't - # add a quant/dequant block to the non quantized paths to allow for fusion - # of the add - add_node_children = graph.get_node_children(add_node) - add_node_quant_child_idx = [ - idx - for idx, node in enumerate(add_node_children) - if node.op_type == "QuantizeLinear" - ] - if not add_node_quant_child_idx or all( - n.op_type == "Add" or n.op_type == "QuantizeLinear" - for n in add_node_children - ): - # no quant child node, or all child nodes are quant/add nodes - continue - - # make dequant pair node for quant child and add to graph - add_node_dequant_child = _make_dequant_node_for_quant( - add_node_children[add_node_quant_child_idx[0]] - ) - graph.add_node(add_node_dequant_child) - - # update all non quant node children to take the quant/dequant block as input - for add_child_node in add_node_children: - if add_child_node.op_type == "QuantizeLinear": - continue - add_node_id_idx = [ - idx - for idx, output_id in enumerate(add_child_node.input) - if output_id == add_node.output[0] - ][0] - graph.update_node_input( - add_child_node, add_node_dequant_child.output[0], add_node_id_idx - ) - - return optimization_made - - def _make_dequant_node_for_quant(quant_node: onnx.NodeProto) -> onnx.NodeProto: return onnx.helper.make_node( "DequantizeLinear", diff --git a/src/sparseml/pytorch/sparsification/quantization/quantize_qat_export.py b/src/sparseml/pytorch/sparsification/quantization/quantize_qat_export.py index 6fb0ecce3b2..9219db829cd 100644 --- a/src/sparseml/pytorch/sparsification/quantization/quantize_qat_export.py +++ b/src/sparseml/pytorch/sparsification/quantization/quantize_qat_export.py @@ -25,6 +25,7 @@ import numpy import onnx +import torch from onnx import ModelProto, NodeProto, numpy_helper from sparseml.onnx.utils import ( @@ -34,7 +35,6 @@ get_node_attributes, get_node_output_nodes, quantize_resnet_identity_add_inputs, - quantized_residual_add_optim, remove_node_and_params_from_graph, swap_node_output, update_model_param, @@ -323,9 +323,21 @@ def _attribute_to_kwarg(attribute: onnx.AttributeProto): def _quantize_array( array: numpy.ndarray, scale: float, zero_point: int, dtype: Any = numpy.uint8 ) -> numpy.ndarray: - dmin = numpy.iinfo(dtype).min - dmax = numpy.iinfo(dtype).max - return ((array / scale).round() + zero_point).clip(dmin, dmax).astype(dtype) + if dtype == numpy.uint8: + tensor_dtype = torch.quint8 + elif dtype == numpy.int8: + tensor_dtype = torch.qint8 + elif dtype == numpy.int32: + tensor_dtype = torch.qint32 + + tensor = torch.Tensor(array).to(torch.float32) + if isinstance(scale, numpy.ndarray): + scale = scale.item() + if isinstance(zero_point, numpy.ndarray): + zero_point = zero_point.item() + + quant_tensor = torch.quantize_per_tensor(tensor, scale, zero_point, tensor_dtype) + return quant_tensor.int_repr().numpy() def _convert_quantizable_conv( @@ -450,6 +462,7 @@ def _convert_quantizable_gemm( weight_quantize_params.target, weight_quantize_params.scale, weight_quantize_params.zero_point, + weight_quantize_params.zero_point.dtype, ) quantized_weight = quantized_weight.transpose() # Gemm has implicit transpose quantized_weight_name = "{}.weight_quantized".format(gemm_node.name) @@ -732,6 +745,7 @@ def _add_quantized_conv_matmul_add_ops( weight_quantize_params.target, weight_quantize_params.scale, weight_quantize_params.zero_point, + weight_quantize_params.zero_point.dtype, ) if transpose_weight: quantized_weight = quantized_weight.transpose() @@ -1404,7 +1418,9 @@ def _quantize_qat_embedding(model: ModelProto): embedding = numpy_helper.to_array(embedding_initializer) scale = numpy_helper.to_array(scale_initializer) zero_point = numpy_helper.to_array(zp_initializer) - embedding_quant = _quantize_array(embedding, scale, zero_point) + embedding_quant = _quantize_array( + embedding, scale, zero_point, zero_point.dtype + ) embedding_quant_initializer = numpy_helper.from_array( embedding_quant, name=f"{embedding_initializer.name}_quant" ) @@ -1569,7 +1585,6 @@ def quantize_torch_qat_export( _convert_quantizable_gemm_no_activations(model) _quantize_qat_embedding(model) quantize_resnet_identity_add_inputs(model) - quantized_residual_add_optim(model) _remove_duplicate_quantize_ops(model) _cleanup_unused_quants(model) diff --git a/src/sparseml/version.py b/src/sparseml/version.py index 8b5831ba6fc..44bfac5dd78 100644 --- a/src/sparseml/version.py +++ b/src/sparseml/version.py @@ -19,7 +19,7 @@ from datetime import date -version_base = "1.0.0" +version_base = "1.0.1" is_release = False # change to True to set the generated version as a release version