diff --git a/bin/quant.py b/bin/quant.py
index 3d33191f..94bfc448 100644
--- a/bin/quant.py
+++ b/bin/quant.py
@@ -1,17 +1,18 @@
 import torch
 from torch.nn import Linear
-# from sparseml.modifiers.quantization.utils.quantization_scheme import QuantizationScheme, QuantizationArgs
+
 from sparsetensors.quantization.quant_args import QuantizationArgs
 from sparsetensors.quantization.quant_scheme import QuantizationScheme
-from sparseml.modifiers.quantization.lifecycle.initialize import initialize_module_for_quantization
-from sparseml.modifiers.quantization.lifecycle.calibration import set_module_for_calibration
-from sparseml.modifiers.quantization.lifecycle.frozen import freeze_module_quantization
+from sparsetensors.quantization.lifecycle.initialize import initialize_module_for_quantization
+from sparsetensors.quantization.lifecycle.calibration import set_module_for_calibration
+from sparsetensors.quantization.lifecycle.frozen import freeze_module_quantization
 
 num_bits = 8
 scheme = QuantizationScheme(
     input_activations=QuantizationArgs(num_bits=num_bits, symmetric=False),
     weights=QuantizationArgs(num_bits=num_bits, symmetric=True),
     output_activations=None,
+    targets=["*"],
 )
 
 layer = Linear(4, 4)
@@ -31,25 +32,29 @@
 layer(torch.randn(4, 4))
 print(dict(layer.named_parameters()))  # scale and zero point should have updated values
 
 print(2)
-for _ in range(10):
+print("calib layers ")
+for i in range(10):
+    print("iter", i)
     layer(torch.randn(4, 4))
 print(dict(layer.named_parameters()))  # scale and zero point should have updated values again since we did another pass
 
 print(3)
-breakpoint()
+# breakpoint()
 freeze_module_quantization(layer)
-for _ in range(10):
+print("freeze layers ")
+for i in range(10):
     # do more forward passes but show args are frozen
-    layer(torch.random.randn(4, 4))
+    print("iter", i)
+    layer(torch.randn(4, 4))
 print(dict(layer.named_parameters()))  # scale and zero point should not be updated now
 
-# missing
+# # missing
 
-# correctness
-# quantizing an entire model
+# # correctness
+# # quantizing an entire model
diff --git a/setup.py b/setup.py
index 12ef67a3..de506c99 100644
--- a/setup.py
+++ b/setup.py
@@ -25,7 +25,7 @@ def _setup_install_requires() -> List:
     return ["torch>=1.7.0", "transformers<=4.40", "pydantic<2.7"]
 
 def _setup_extras() -> Dict:
-    return {"dev": ["black==22.12.0", "isort==5.8.0", "wheel>=0.36.2", "flake8>=3.8.3", "pytest>=6.0.0"]}
+    return {"dev": ["black==22.12.0", "isort==5.8.0", "wheel>=0.36.2", "flake8>=3.8.3", "pytest>=6.0.0", "sparsezoo"]}
 
 setup(
     name="sparsetensors",
diff --git a/src/sparsetensors/quantization/lifecycle/__init__.py b/src/sparsetensors/quantization/lifecycle/__init__.py
index d90b28a9..52b86440 100644
--- a/src/sparsetensors/quantization/lifecycle/__init__.py
+++ b/src/sparsetensors/quantization/lifecycle/__init__.py
@@ -19,4 +19,3 @@
 from .frozen import *
 from .initialize import *
 from .status import *
-from .initialize import *
diff --git a/src/sparsetensors/quantization/lifecycle/calirbation.py b/src/sparsetensors/quantization/lifecycle/calibration.py
similarity index 90%
rename from src/sparsetensors/quantization/lifecycle/calirbation.py
rename to src/sparsetensors/quantization/lifecycle/calibration.py
index a4f4dfea..986b062a 100644
--- a/src/sparsetensors/quantization/lifecycle/calirbation.py
+++ b/src/sparsetensors/quantization/lifecycle/calibration.py
@@ -15,10 +15,9 @@
 
 import logging
 
+from sparsetensors.quantization.lifecycle.status import QuantizationStatus
 from torch.nn import Module
 
-from sparseml.modifiers.quantization.lifecycle.status import QuantizationStatus
-
 
 __all__ = [
     "set_module_for_calibration",
@@ -41,4 +40,4 @@ def set_module_for_calibration(module: Module):
             "to re-calibrate a frozen module"
         )
 
-    module.quantization_status = QuantizationStatus.CALIBRATION
\ No newline at end of file
+    module.quantization_status = QuantizationStatus.CALIBRATION
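Note: set_module_for_calibration above only moves an already-initialized module into the CALIBRATION state, and warns when asked to re-calibrate a frozen module. Below is a minimal sketch of the full status progression these lifecycle functions drive; it assumes freeze_module_quantization symmetrically sets QuantizationStatus.FROZEN (only the INITIALIZED and CALIBRATION transitions are visible in this patch).

```python
import torch
from torch.nn import Linear

from sparsetensors.quantization.lifecycle.calibration import set_module_for_calibration
from sparsetensors.quantization.lifecycle.frozen import freeze_module_quantization
from sparsetensors.quantization.lifecycle.initialize import initialize_module_for_quantization
from sparsetensors.quantization.lifecycle.status import QuantizationStatus
from sparsetensors.quantization.quant_args import QuantizationArgs
from sparsetensors.quantization.quant_scheme import QuantizationScheme

scheme = QuantizationScheme(
    input_activations=QuantizationArgs(num_bits=8, symmetric=False),
    weights=QuantizationArgs(num_bits=8, symmetric=True),
    output_activations=None,
    targets=["*"],
)

layer = Linear(4, 4)

initialize_module_for_quantization(layer, scheme)  # attaches scale/zero point params
assert layer.quantization_status == QuantizationStatus.INITIALIZED

set_module_for_calibration(layer)  # observers may now update the qparams
assert layer.quantization_status == QuantizationStatus.CALIBRATION
layer(torch.randn(4, 4))  # forward passes recalibrate scale/zero point

freeze_module_quantization(layer)  # qparams stop updating (FROZEN status assumed)
assert layer.quantization_status == QuantizationStatus.FROZEN
```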
diff --git a/src/sparsetensors/quantization/lifecycle/forward.py b/src/sparsetensors/quantization/lifecycle/forward.py
index 4247e7c7..cbb27dea 100644
--- a/src/sparsetensors/quantization/lifecycle/forward.py
+++ b/src/sparsetensors/quantization/lifecycle/forward.py
@@ -15,11 +15,16 @@
 from functools import wraps
 
 import torch
+from sparsetensors.quantization.lifecycle.status import QuantizationStatus
+
+# from sparsetensors.quantization.utils.quantization_scheme import (
+#     QuantizationArgs,
+#     QuantizationScheme,
+# )
+from sparsetensors.quantization.quant_args import QuantizationArgs
+from sparsetensors.quantization.quant_scheme import QuantizationScheme
 from torch.nn import Module
 
-from sparseml.modifiers.quantization.lifecycle.status import QuantizationStatus
-
-from sparseml.modifiers.quantization.utils.quantization_scheme import QuantizationScheme, QuantizationArgs
 
 __all__ = ["wrap_module_forward_quantized"]
 
@@ -34,8 +39,8 @@ def quantize(
         torch.round(
             x / scale + zero_point,
         ),
-        0,
-        q_max,
+        0,
+        q_max,
     )
 
@@ -83,7 +88,7 @@ def fake_quantize(
     #         q = quantize(w.unsqueeze(1), scale, zero, max_q).flatten()
     #         Q1[:, i] = q
     #     Q[:, i1:i2] = Q1
-    Q = quantize(x, scale, zero_point, max_q)
+    Q = quantize(x, scale, zero_point, max_q)
 
     return dequantize(Q, scale, zero_point)
 
@@ -138,7 +143,7 @@ def _maybe_calibrate_or_quantize(
         return value
 
     scale = getattr(module, f"{base_name}_scale")
-    # zero_point = getattr(module, f"{base_name}_zero_point").data
+    # zero_point = getattr(module, f"{base_name}_zero_point").data
     zero_point = getattr(module, f"{base_name}_zero_point")
 
     print(scale, zero_point)
@@ -152,4 +157,4 @@ def _maybe_calibrate_or_quantize(
     scale.data = updated_scale
     zero_point.data = updated_zero_point
 
-    return fake_quantize(value, scale, zero_point, args)
\ No newline at end of file
+    return fake_quantize(value, scale, zero_point, args)
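Note: the quantize/dequantize helpers that forward.py wires into the wrapped forward pass implement standard asymmetric fake quantization. The standalone sketch below (plain PyTorch, independent of this patch) shows the round trip and the error it implies; the scale/zero-point derivation mirrors what a memoryless min/max observer computes.

```python
import torch

def quantize(x, scale, zero_point, q_max):
    # clamp(round(x / scale + zero_point), 0, q_max), as in forward.py
    return torch.clamp(torch.round(x / scale + zero_point), 0, q_max)

def dequantize(x_q, scale, zero_point):
    # map integer levels back onto the real line
    return (x_q - zero_point) * scale

x = torch.randn(4, 4)
q_max = 2**8 - 1  # 255 levels for num_bits=8
scale = (x.max() - x.min()) / q_max
zero_point = torch.round(-x.min() / scale)

x_dq = dequantize(quantize(x, scale, zero_point, q_max), scale, zero_point)
# round-trip error is bounded by roughly half a quantization step
print((x - x_dq).abs().max().item(), (scale / 2).item())
```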
diff --git a/src/sparsetensors/quantization/lifecycle/frozen.py b/src/sparsetensors/quantization/lifecycle/frozen.py
index d480465b..6b92eee7 100644
--- a/src/sparsetensors/quantization/lifecycle/frozen.py
+++ b/src/sparsetensors/quantization/lifecycle/frozen.py
@@ -13,10 +13,9 @@
 # limitations under the License.
 
+from sparsetensors.quantization.lifecycle.status import QuantizationStatus
 from torch.nn import Module
 
-from sparseml.modifiers.quantization.lifecycle.status import QuantizationStatus
-
 
 __all__ = [
     "freeze_module_quantization",
diff --git a/src/sparsetensors/quantization/lifecycle/initialize.py b/src/sparsetensors/quantization/lifecycle/initialize.py
index cfa4aa77..6d23f4cc 100644
--- a/src/sparsetensors/quantization/lifecycle/initialize.py
+++ b/src/sparsetensors/quantization/lifecycle/initialize.py
@@ -16,17 +16,17 @@
 import logging
 
 import torch
+from sparsetensors.quantization.lifecycle.forward import wrap_module_forward_quantized
+from sparsetensors.quantization.lifecycle.status import QuantizationStatus
+
+# from sparsetensors.quantization.utils.quantization_scheme import (
+#     QuantizationArgs,
+#     QuantizationScheme,
+# )
+from sparsetensors.quantization.quant_args import QuantizationArgs
+from sparsetensors.quantization.quant_scheme import QuantizationScheme
 from torch.nn import Module, Parameter
 
-from sparseml.modifiers.quantization.lifecycle.forward import (
-    wrap_module_forward_quantized,
-)
-from sparseml.modifiers.quantization.lifecycle.status import QuantizationStatus
-from sparseml.modifiers.quantization.utils.quantization_scheme import (
-    QuantizationArgs,
-    QuantizationScheme,
-)
-
 
 __all__ = [
     "initialize_module_for_quantization",
@@ -39,9 +39,7 @@
 def initialize_module_for_quantization(module: Module, scheme: QuantizationScheme):
     if scheme.input_activations is not None:
-        _initialize_scale_zero_point_observer(
-            module, "input", scheme.input_activations
-        )
+        _initialize_scale_zero_point_observer(module, "input", scheme.input_activations)
     if scheme.weights is not None:
         if hasattr(module, "weight"):
             _initialize_scale_zero_point_observer(module, "weight", scheme.weights)
@@ -52,7 +50,9 @@ def initialize_module_for_quantization(module: Module, scheme: QuantizationSchem
                 f"for {type(module)}"
             )
     if scheme.output_activations is not None:
-        _initialize_scale_zero_point_observer(module, "output", scheme.output_activations)
+        _initialize_scale_zero_point_observer(
+            module, "output", scheme.output_activations
+        )
 
     module.quantization_scheme = scheme
     module.quantization_status = QuantizationStatus.INITIALIZED
@@ -61,7 +61,6 @@ def initialize_module_for_quantization(module: Module, scheme: QuantizationSchem
 
     wrap_module_forward_quantized(module, scheme)
 
-
 def _initialize_scale_zero_point_observer(
     module: Module, base_name: str, quantization_args: QuantizationArgs
 ):
diff --git a/src/sparsetensors/quantization/observers/__init__.py b/src/sparsetensors/quantization/observers/__init__.py
index 1bec545d..d0362b8f 100644
--- a/src/sparsetensors/quantization/observers/__init__.py
+++ b/src/sparsetensors/quantization/observers/__init__.py
@@ -16,4 +16,4 @@
 
 from .base import *
 from .memoryless import *
-from .min_max import *
\ No newline at end of file
+from .min_max import *
diff --git a/src/sparsetensors/quantization/observers/base.py b/src/sparsetensors/quantization/observers/base.py
index 44c8ec37..00cd7561 100644
--- a/src/sparsetensors/quantization/observers/base.py
+++ b/src/sparsetensors/quantization/observers/base.py
@@ -14,12 +14,12 @@
 
 from typing import Optional, Tuple
 
+# from sparsetensors.quantization.utils.quantization_scheme import QuantizationArgs
+from sparsetensors.quantization.quant_args import QuantizationArgs
+from sparsezoo.utils.registry import RegistryMixin
 from torch import FloatTensor, IntTensor, Tensor
 from torch.nn import Module
 
-from sparseml.modifiers.quantization.utils.quantization_scheme import QuantizationArgs
-from sparsezoo.utils.registry import RegistryMixin
-
 
 __all__ = ["Observer"]
 
@@ -31,9 +31,7 @@ class Observer(Module, RegistryMixin):
     pair
     """
 
-    def __init__(self,
-        quantization_args: QuantizationArgs
-    ):
+    def __init__(self, quantization_args: QuantizationArgs):
         self.quantization_args: QuantizationArgs = quantization_args
         super().__init__()
         self._scale = None
@@ -69,4 +67,4 @@ def get_qparams(
         if observed is not None:
             # re-calcualte scale and zero point, update the stored value
             self._scale, self._zero_point = self.calculate_qparams(observed)
-        return self._scale, self._zero_point
\ No newline at end of file
+        return self._scale, self._zero_point
diff --git a/src/sparsetensors/quantization/observers/memoryless.py b/src/sparsetensors/quantization/observers/memoryless.py
index 5f74448b..faabbb5a 100644
--- a/src/sparsetensors/quantization/observers/memoryless.py
+++ b/src/sparsetensors/quantization/observers/memoryless.py
@@ -15,10 +15,11 @@
 from typing import Tuple
 
 import torch
+from sparsetensors.quantization.observers.base import Observer
 from torch import FloatTensor, IntTensor, Tensor
 
-from sparseml.modifiers.quantization.observers.base import Observer
-# from sparseml.modifiers.quantization.utils.quantization_scheme import QuantizationArgs
+
+# from sparsetensors.quantization.utils.quantization_scheme import QuantizationArgs
 
 __all__ = ["MemorylessObserver"]
 
@@ -60,4 +61,4 @@ def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
 
         zero_point = (0 - min_val) / scale
 
-        return scale, zero_point
\ No newline at end of file
+        return scale, zero_point
diff --git a/src/sparsetensors/quantization/observers/min_max.py b/src/sparsetensors/quantization/observers/min_max.py
index c72eb1c0..40cde72c 100644
--- a/src/sparsetensors/quantization/observers/min_max.py
+++ b/src/sparsetensors/quantization/observers/min_max.py
@@ -15,11 +15,10 @@
 from typing import Tuple
 
 import torch
+from sparsetensors.quantization.observers.base import Observer
+from sparsetensors.quantization.quant_args import QuantizationArgs
 from torch import FloatTensor, IntTensor, Tensor
 
-from sparseml.modifiers.quantization.observers.base import Observer
-from sparseml.modifiers.quantization.utils.quantization_scheme import QuantizationArgs
-
 
 __all__ = ["MinMaxObserver"]
 
@@ -77,4 +76,4 @@ def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
 
         zero_point = (0 - self.min_val) / scale
 
-        return scale, zero_point
\ No newline at end of file
+        return scale, zero_point
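Note: the observers above plug into sparsezoo's RegistryMixin, so QuantizationArgs can name an observer as a string and resolve it at runtime. A sketch of that lookup, assuming MinMaxObserver registers under the "minmax" name used as the QuantizationArgs default, and that get_qparams accepts the observed tensor as shown in base.py:

```python
import torch

# importing the observers package runs the registrations via its star imports
import sparsetensors.quantization.observers  # noqa: F401
from sparsetensors.quantization.quant_args import QuantizationArgs

args = QuantizationArgs(num_bits=8, symmetric=False, observer="minmax")
observer = args.get_observer()  # Observer.load_from_registry("minmax", ...)

# get_qparams recomputes scale/zero point when handed newly observed values
scale, zero_point = observer.get_qparams(observed=torch.randn(16, 16))
print(scale, zero_point)
```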
diff --git a/src/sparsetensors/quantization/quant_args.py b/src/sparsetensors/quantization/quant_args.py
index 89a2e3df..fb9e9b01 100644
--- a/src/sparsetensors/quantization/quant_args.py
+++ b/src/sparsetensors/quantization/quant_args.py
@@ -13,9 +13,9 @@
 # limitations under the License.
 
 from enum import Enum
-from typing import Optional
+from typing import Any, Dict, Optional
 
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 
 __all__ = ["QuantizationType", "QuantizationStrategy", "QuantizationArgs"]
 
@@ -61,3 +61,24 @@ class QuantizationArgs(BaseModel):
     strategy: QuantizationStrategy = QuantizationStrategy.TENSOR
     group_size: Optional[int] = None
     block_structure: Optional[str] = None
+    observer: str = Field(
+        default="minmax",
+        description=(
+            "The class to use to compute the quantization params - scale and zero-point"
+        ),
+    )
+    observer_kwargs: Dict[str, Any] = Field(
+        default_factory=dict,
+        description=(
+            "optional dict of kwargs to be passed directly to torch quantization "
+            "Observers constructor excluding quantization range or symmetry"
+        ),
+    )
+
+    def get_observer(self):
+        """
+        :return: Observer built based on these QuantizationArgs
+        """
+        from sparsetensors.quantization.observers.base import Observer
+
+        return Observer.load_from_registry(self.observer, quantization_args=self)
diff --git a/src/sparsetensors/quantization/utils/quantization_scheme.py b/src/sparsetensors/quantization/utils/quantization_scheme.py
deleted file mode 100644
index 976b534e..00000000
--- a/src/sparsetensors/quantization/utils/quantization_scheme.py
+++ /dev/null
@@ -1,391 +0,0 @@
-# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-Schemas and types to support quantization
-"""
-from copy import deepcopy
-from functools import partial
-from typing import Any, Dict, Optional, Union
-
-import torch
-from packaging import version
-from pydantic import BaseModel, Field, validator
-from torch.nn import Identity
-
-
-try:
-    from torch import quantization as torch_quantization
-except Exception:
-    torch_quantization = None
-
-from sparseml.modifiers.quantization.utils.fake_quant_wrapper import FakeQuantizeWrapper
-
-
-__all__ = [
-    "DictQuantizationArgs",
-    "DictQuantizationScheme",
-    "QuantizationArgs",
-    "QuantizationScheme",
-    "QuantizationSchemeLoadable",
-    "compute_range",
-    "get_observer",
-]
-
-
-_PARSED_TORCH_VERSION = version.parse(torch.__version__)
-_TORCH_PRE_112 = _PARSED_TORCH_VERSION < version.parse("1.12.0")
-
-
-"""
-Type definition aliases for defining QuantizationArgs and QuantizationScheme
-as dictionaries for YAML serialization
-"""
-DictQuantizationArgs = Dict[str, Union[int, bool, Dict[str, Any]]]
-DictQuantizationScheme = Dict[str, DictQuantizationArgs]
-
-"""
-Type definition for a type that is valid for loading a QuantizationScheme
-using QuantizationScheme.load
-"""
-QuantizationSchemeLoadable = Union[
-    "QuantizationScheme",
-    DictQuantizationScheme,
-    str,
-    None,
-]
-
-
-class QuantizationArgs(BaseModel):
-    """
-    Class representing user facing arguments to define quantization Observers of
-    activations or weights in a network
-    """
-
-    num_bits: int = Field(
-        default=8, description="number of bits to target for quantization"
-    )
-    symmetric: bool = Field(
-        default=False,
-        description="set True to use symmetric quantization. Default False",
-    )
-    strategy: str = Field(
-        default="tensor",
-        description=(
-            "scope of the quantization to be applied. can be 'tensor' or 'channel'"
-        ),
-    )
-    observer: str = Field(
-        default="minmax",
-        description=(
-            "The class to use to compute the quantization params - scale and zero-point'"
-        ),
-
-    )
-
-    # kwargs: Dict[str, Any] = Field(
-    #     default_factory=dict,
-    #     description=(
-    #         "optional dict of kwargs to be passed directly to torch quantization "
-    #         "Observers constructor excluding quantization range or symmetry"
-    #     ),
-    # )
-    observer_kwargs: Dict[str, Any] = Field(
-        default_factory=dict,
-        description=(
-            "optional dict of kwargs to be passed directly to torch quantization "
-            "Observers constructor excluding quantization range or symmetry"
-        ),
-    )
-
-    @classmethod
-    def default_activation_args(cls):
-        """
-        :return: default 8 bits asymmetric settings
-        """
-        return cls(num_bits=8, symmetric=False)
-
-    @classmethod
-    def default_weight_args(cls):
-        """
-        :return: default 8 bits symmetric settings
-        """
-        return cls(num_bits=8, symmetric=True)
-
-    def get_observer(self) -> "torch.quantization.FakeQuantize":
-        """
-        :return: torch quantization FakeQuantize built based on these QuantizationArgs
-        """
-        from sparseml.modifiers.quantization.observers.base import Observer
-        return Observer.load_from_registry(self.observer, quantization_args=self)
-
-    @validator("strategy")
-    def validate_strategy(cls, value):
-        valid_scopes = ["tensor", "channel"]
-        if value not in valid_scopes:
-            raise ValueError(f"`strategy` must be one of {valid_scopes}, got {value}")
-        return value
-
-
-class QuantizationScheme(BaseModel):
-    """
-    Class composed of QuantizationArgs to build QConfig and QuantWrapper objects for
-    quantizing models. Provides a simple user interface for defining how inputs,
-    weights, and outputs should be quantized
-    """
-
-    def __init__(self, *args, **kwargs):
-        # support for loading from yaml str
-        args = [arg if arg != "null" else None for arg in args]
-        for key, val in kwargs.items():
-            if val == "null":
-                kwargs[key] = None
-        super().__init__(*args, **kwargs)
-
-    input_activations: Optional[QuantizationArgs] = Field(
-        default_factory=QuantizationArgs.default_activation_args,
-        description=(
-            "target quantization setting for input activations. Set to None to "
-            "not quantize input activations. Default is 8 bits asymmetric"
-        ),
-    )
-    weights: Optional[QuantizationArgs] = Field(
-        default_factory=QuantizationArgs.default_weight_args,
-        description=(
-            "target quantization setting for model weights. Set to None to "
-            "not quantize weights. Default is 8 bits symmetric"
-        ),
-    )
-    output_activations: Optional[QuantizationArgs] = Field(
-        default=None,
-        description=(
-            "target quantization setting for output activations. Set to None to "
-            "not quantize output activations. Default is None"
-        ),
-    )
-    target_hardware: Optional[str] = Field(
-        default=None,
-        description=(
-            "target deployment runtime/hardware name to be set by default "
-            "classmethods. Default is None"
-        ),
-    )
-
-    @classmethod
-    def load(
-        cls,
-        scheme: QuantizationSchemeLoadable,
-        default: Optional["QuantizationScheme"] = None,
-    ) -> "QuantizationScheme":
-        """
-        :param scheme: QuantizationScheme, dict representation of scheme,
-            or string alias of a scheme to load. Valid strings:
-            ['default', 'deepsparse', 'tensorrt']
-        :param default: default QuantizationScheme to override 'default' scheme
-            with
-        :return: constructed QuantizationScheme object from the given scheme;
-            if given a dict, returns QuantizationScheme.parse_obj(scheme), string
-            input will return the defualt QuantizationScheme if set to 'default'.
-        """
-        if isinstance(scheme, cls):
-            return scheme
-        elif scheme is None or scheme == "default":
-            # if no default override, defaults to QuantizationScheme()
-            return deepcopy(default) or cls()
-        elif isinstance(scheme, str):
-            if scheme == "deepsparse":
-                return cls.deepsparse()
-            elif scheme == "tensorrt":
-                return cls.tensorrt()
-            raise ValueError(
-                f"Unrecognized QuantizationScheme string alias {scheme}. "
-                "Valid strings: ['default', 'deepsparse', 'tensorrt']"
-            )
-        elif isinstance(scheme, dict):
-            # default to dict
-            scheme = {key: _parse_quantization_arg(arg) for key, arg in scheme.items()}
-            return cls.parse_obj(scheme)
-        else:
-            raise ValueError(
-                f"Unrecognized type {type(scheme)} for QuantizationScheme.load, "
-                "expected one of: [QuantizationScheme, Dict, str, None]"
-            )
-
-    @classmethod
-    def deepsparse(cls) -> "QuantizationScheme":
-        """
-        :return: QuantizationScheme for deepsparse targeted deployments -
-            int8, symmetric weights, asymmetric inputs, no output quantization
-        """
-        return cls(
-            input_activations=QuantizationArgs(num_bits=8, symmetric=False),
-            weights=QuantizationArgs(num_bits=8, symmetric=True),
-            output_activations=None,
-            target_hardware="deepsparse",
-        )
-
-    @classmethod
-    def tensorrt(cls) -> "QuantizationScheme":
-        """
-        :return: QuantizationScheme for tensorrt targeted deployments -
-            compatibility with explict quantization as supported by TensorRT 8.2:
-            int8, symmetric for both weights and inputs, no output quantization
-        """
-        return cls(
-            input_activations=QuantizationArgs(num_bits=8, symmetric=True),
-            weights=QuantizationArgs(num_bits=8, symmetric=True),
-            output_activations=None,
-            target_hardware="tensorrt",
-        )
-
-    def get_qconfig(self) -> "torch.quantization.QConfig":
-        """
-        :return: QConfig for Modules (output activations used,
-            use QuantWrapper for inputs)
-        """
-        qconfig = _get_qconfig(self.output_activations, self.weights)
-        # add reference to this quantization scheme for reference
-        qconfig.quantization_scheme = self
-        return qconfig
-
-    def get_wrapper_qconfig(self) -> "torch.quantization.QConfig":
-        """
-        :return: QConfig for QuantWrapper objects (input activations used)
-        """
-        qconfig = _get_qconfig(self.input_activations, None)
-        # add reference to this quantization scheme for reference
-        qconfig.quantization_scheme = self
-        return qconfig
-
-    def __str__(self) -> str:
-        """
-        :return: YAML friendly string serialization
-        """
-        dict_repr = self.dict()
-        dict_repr = {
-            key: val if val is not None else "null" for key, val in dict_repr.items()
-        }
-        return str(dict_repr)
-
-
-def compute_range(dtype: torch.dtype, bits: int):
-    """
-    compute quantization limits depending on data type and number of bits
-
-    :param dtype: data type.
-    :param bits: number of bits.
-    :return: minimum limit, maximum limit, whether the range is customized
-    """
-    bits = bits if bits else 8
-    is_custom = bits != 8
-    if dtype == torch.qint8:
-        quant_min = -(2 ** (bits - 1))
-        quant_max = (2 ** (bits - 1)) - 1
-    elif dtype == torch.quint8:
-        quant_min = 0
-        quant_max = (2**bits) - 1
-
-    return quant_min, quant_max, is_custom
-
-
-# def get_observer(
-#     symmetric: bool,
-#     strategy: str,
-#     dtype: torch.dtype,
-#     bits: int,
-#     reduce_range: bool,
-#     qconfig_kwargs: Dict[str, Any],
-# ):
-#     quant_min, quant_max, is_custom_qrange = compute_range(dtype, bits)
-
-#     if strategy == "channel":
-#         qscheme = torch.per_channel_symmetric if symmetric else torch.per_channel_affine
-#         observer_cls = torch_quantization.MovingAveragePerChannelMinMaxObserver
-#         observer_kwargs = dict(
-#             ch_axis=0,
-#             dtype=dtype,
-#             qscheme=qscheme,
-#             reduce_range=reduce_range,
-#         )
-#     else:  # default to tensor strategy
-#         qscheme = torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine
-#         observer_cls = torch_quantization.MovingAverageMinMaxObserver
-#         observer_kwargs = dict(
-#             dtype=dtype,
-#             qscheme=qscheme,
-#             reduce_range=reduce_range,
-#         )
-#     """
-#     in torch 1.9.1, quant_min and quant_max are not passed to observer:
-#     https://github.com/pytorch/pytorch/blob/v1.9.1/torch/quantization/fake_quantize.py#L109
-#     however in 1.12.0, this is fixed so both are passed to observer:
-#     https://github.com/pytorch/pytorch/blob/v1.12.1/torch/ao/quantization/fake_quantize.py#L132
-
-#     Passing quant_min/quant_max to observer means the observer will have
-#     `self.has_customized_qrange == True` in both 1.9.1 and 1.12.0.
-
-#     For whatever reason, both versions calculate zero point for
-#     quint8 differently **if there is a customized_qrange**
-#     1. customized qrange has zero point of 127
-#     2. non-customized has zero point of 128.
-#     source:
-#     https://github.com/pytorch/pytorch/blob/v1.12.1/torch/ao/quantization/observer.py#L293
-
-#     **we want to ensure that the zero point is 128**
-#     see https://github.com/neuralmagic/sparseml/pull/604
-#     """
-#     if is_custom_qrange:
-#         # for both versions we need to include the custom min/max values in kwargs
-#         observer_kwargs["quant_min"] = quant_min
-#         observer_kwargs["quant_max"] = quant_max
-#         if _TORCH_PRE_112:
-#             # pre 1.12, the observer doesn't get passed the quant_min/quant_max values,
-#             # so we patch them in to the constructor of the observer
-#             observer_cls = partial(
-#                 observer_cls, quant_min=quant_min, quant_max=quant_max
-#             )
-#     else:
-#         # if using a non custom qrange, we can rely on default values used by
-#         # the observers
-#         if _TORCH_PRE_112:
-#             # pre 1.12, the observer doesn't get passed the quant_min/quant_max values,
-#             # so we are safe to pass these to FakeQuantize
-#             observer_kwargs["quant_min"] = quant_min
-#             observer_kwargs["quant_max"] = quant_max
-#         else:
-#             # post 1.12 we cannot pass them to the observer since that will set
-#             # has_customized_qrange. instead we rely on the default values
-#             # being equal to the `quant_min` and `quant_max` here.
-#             pass
-
-#     observer_kwargs["observer"] = observer_cls
-#     observer_kwargs.update(qconfig_kwargs or {})
-#     observer = FakeQuantizeWrapper.with_args(**observer_kwargs)
-
-#     return observer
-
-
-def _get_qconfig(
-    activation_args: Optional[QuantizationArgs], weight_args: Optional[QuantizationArgs]
-) -> "torch.quantization.QConfig":
-    return torch_quantization.QConfig(
-        activation=activation_args.get_observer() if activation_args else Identity,
-        weight=weight_args.get_observer() if weight_args else Identity,
-    )
-
-
-def _parse_quantization_arg(arg: Any):
-    if arg == "None":
-        return None
-    return arg
\ No newline at end of file