diff --git a/bin/quant.py b/bin/quant.py
new file mode 100644
index 00000000..3d33191f
--- /dev/null
+++ b/bin/quant.py
@@ -0,0 +1,47 @@
+import torch
+from torch.nn import Linear
+
+from sparsetensors.quantization.quant_args import QuantizationArgs
+from sparsetensors.quantization.quant_scheme import QuantizationScheme
+from sparsetensors.quantization.lifecycle.initialize import initialize_module_for_quantization
+from sparsetensors.quantization.lifecycle.calibration import set_module_for_calibration
+from sparsetensors.quantization.lifecycle.frozen import freeze_module_quantization
+
+num_bits = 8
+
+scheme = QuantizationScheme(
+    input_activations=QuantizationArgs(num_bits=num_bits, symmetric=False),
+    weights=QuantizationArgs(num_bits=num_bits, symmetric=True),
+    output_activations=None,
+)
+
+layer = Linear(4, 4)
+print(layer)
+print(dict(layer.named_parameters()))
+
+
+initialize_module_for_quantization(layer, scheme)
+print(layer)  # should see observer under layer now
+print(dict(layer.named_parameters()))  # should see empty tensors for scale and zero point now
+
+
+set_module_for_calibration(layer)
+# do a calibration step
+layer(torch.randn(4, 4))
+print(dict(layer.named_parameters()))  # scale and zero point should have updated values
+
+for _ in range(10):
+    layer(torch.randn(4, 4))
+print(dict(layer.named_parameters()))  # scale and zero point should have updated again after more passes
+
+
+freeze_module_quantization(layer)
+for _ in range(10):
+    # do more forward passes but show args are frozen
+    layer(torch.randn(4, 4))
+print(dict(layer.named_parameters()))  # scale and zero point should not be updated now
+
+
+# still missing:
+# correctness checks
+# quantizing an entire model
diff --git a/src/sparsetensors/quantization/lifecycle/__init__.py b/src/sparsetensors/quantization/lifecycle/__init__.py
new file mode 100644
index 00000000..d90b28a9
--- /dev/null
+++ b/src/sparsetensors/quantization/lifecycle/__init__.py
@@ -0,0 +1,21 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# flake8: noqa
+
+from .calibration import *
+from .forward import *
+from .frozen import *
+from .initialize import *
+from .status import *
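For orientation, `bin/quant.py` walks the three lifecycle stages one call at a time. The same flow, condensed into a single helper (a sketch reusing the imports above; not part of the diff):

```python
def quantize_layer(layer, scheme, calibration_batches):
    # 1. attach empty scale/zero-point parameters and an observer per tensor
    initialize_module_for_quantization(layer, scheme)
    # 2. let the observer update scale and zero point from real data
    set_module_for_calibration(layer)
    for batch in calibration_batches:
        layer(batch)
    # 3. drop the observers and lock in the collected quantization parameters
    freeze_module_quantization(layer)
    return layer
```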
diff --git a/src/sparsetensors/quantization/lifecycle/calibration.py b/src/sparsetensors/quantization/lifecycle/calibration.py
new file mode 100644
index 00000000..a4f4dfea
--- /dev/null
+++ b/src/sparsetensors/quantization/lifecycle/calibration.py
@@ -0,0 +1,44 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+
+from torch.nn import Module
+
+from sparsetensors.quantization.lifecycle.status import QuantizationStatus
+
+
+__all__ = [
+    "set_module_for_calibration",
+]
+
+
+_LOGGER = logging.getLogger(__name__)
+
+
+def set_module_for_calibration(module: Module):
+    if not getattr(module, "quantization_scheme", None):
+        # no quantization scheme, nothing to do
+        return
+    status = getattr(module, "quantization_status", None)
+    if not status or status != QuantizationStatus.INITIALIZED:
+        _LOGGER.warning(
+            f"Attempting to set a module with status {status} to calibration mode, "
+            f"but status is not {QuantizationStatus.INITIALIZED} - you may be "
+            "calibrating an uninitialized module, which may fail, or attempting "
+            "to re-calibrate a frozen module"
+        )
+
+    module.quantization_status = QuantizationStatus.CALIBRATION
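Note that `set_module_for_calibration` only warns (it does not raise) when called out of order, and silently no-ops on modules without a scheme. A quick illustration, using the names from `bin/quant.py`:

```python
layer = Linear(4, 4)
set_module_for_calibration(layer)  # no quantization_scheme attribute: returns without changes

initialize_module_for_quantization(layer, scheme)
set_module_for_calibration(layer)  # INITIALIZED -> CALIBRATION, no warning
set_module_for_calibration(layer)  # logs a warning (already CALIBRATION) but still proceeds
```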
diff --git a/src/sparsetensors/quantization/lifecycle/forward.py b/src/sparsetensors/quantization/lifecycle/forward.py
new file mode 100644
index 00000000..4247e7c7
--- /dev/null
+++ b/src/sparsetensors/quantization/lifecycle/forward.py
@@ -0,0 +1,124 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import wraps
+
+import torch
+from torch.nn import Module
+
+from sparsetensors.quantization.lifecycle.status import QuantizationStatus
+from sparsetensors.quantization.quant_args import QuantizationArgs
+from sparsetensors.quantization.quant_scheme import QuantizationScheme
+
+
+__all__ = ["wrap_module_forward_quantized"]
+
+
+def quantize(
+    x: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    q_max: torch.Tensor,
+) -> torch.Tensor:
+    return torch.clamp(
+        torch.round(x / scale + zero_point),
+        0,
+        q_max,
+    )
+
+
+def dequantize(
+    x_q: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+) -> torch.Tensor:
+    return (x_q - zero_point) * scale
+
+
+def fake_quantize(
+    x: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    args: QuantizationArgs,
+) -> torch.Tensor:
+    # TODO: only per-tensor quantization is supported here; block- and group-wise
+    # quantization (iterating over columns by block_size/group_size) is not implemented yet
+    max_q = torch.tensor(2**args.num_bits - 1)
+    Q = quantize(x, scale, zero_point, max_q)
+    return dequantize(Q, scale, zero_point)
+
+
+def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
+    # expects a module already initialized and injected with the parameters in
+    # initialize_module_for_quantization
+    forward_func_orig = module.forward.__func__
+
+    @wraps(forward_func_orig)  # ensures docstring, names, etc are propagated
+    def wrapped_forward(self, *args, **kwargs):
+        input_ = args[0]
+
+        if scheme.input_activations is not None:
+            # calibrate and (fake) quantize input activations when applicable
+            input_ = _maybe_calibrate_or_quantize(
+                module, input_, "input", scheme.input_activations
+            )
+
+        if scheme.weights is not None:
+            # calibrate and (fake) quantize weights when applicable
+            self.weight.data = _maybe_calibrate_or_quantize(
+                module, self.weight, "weight", scheme.weights
+            )
+
+        # perform wrapped forward call
+        output = forward_func_orig.__get__(module, module.__class__)(
+            input_, *args[1:], **kwargs
+        )
+
+        if scheme.output_activations is not None:
+            # calibrate and (fake) quantize output activations when applicable
+            output = _maybe_calibrate_or_quantize(
+                module, output, "output", scheme.output_activations
+            )
+
+        return output
+
+    # bind wrapped forward to module class so reference to `self` is correct
+    bound_wrapped_forward = wrapped_forward.__get__(module, module.__class__)
+    # set forward to wrapped forward
+    setattr(module, "forward", bound_wrapped_forward)
+
+
+def _maybe_calibrate_or_quantize(
+    module: Module, value: torch.Tensor, base_name: str, args: "QuantizationArgs"
+) -> torch.Tensor:
+    # only run quantized for the included stages
+    if module.quantization_status not in {
+        QuantizationStatus.CALIBRATION,
+        QuantizationStatus.FROZEN,
+    }:
+        return value
+
+    scale = getattr(module, f"{base_name}_scale")
+    zero_point = getattr(module, f"{base_name}_zero_point")
+
+    if module.quantization_status == QuantizationStatus.CALIBRATION:
+        # get observer and get new quant params from observation
+        observer = getattr(module, f"{base_name}_observer")
+        updated_scale, updated_zero_point = observer(value)
+
+        # update scale and zero point
+        scale.data = updated_scale
+        zero_point.data = updated_zero_point
+
+    return fake_quantize(value, scale, zero_point, args)
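To make the quantize/dequantize math concrete, a worked 8-bit asymmetric example (values chosen arbitrarily):

```python
import torch

# observed range [-1.0, 3.0] -> scale = (max - min) / (2**8 - 1)
scale = torch.tensor(4.0 / 255)            # ~0.0157
zero_point = torch.round(1.0 / scale)      # (0 - min) / scale -> 64
q_max = torch.tensor(255)

x = torch.tensor([0.5])
x_q = quantize(x, scale, zero_point, q_max)   # round(0.5 / 0.0157 + 64) = 96
x_dq = dequantize(x_q, scale, zero_point)     # (96 - 64) * 0.0157 ~= 0.502
```

The round trip loses at most half a quantization step (~0.008 here), which is exactly the error `fake_quantize` injects during calibration.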
diff --git a/src/sparsetensors/quantization/lifecycle/frozen.py b/src/sparsetensors/quantization/lifecycle/frozen.py
new file mode 100644
index 00000000..d480465b
--- /dev/null
+++ b/src/sparsetensors/quantization/lifecycle/frozen.py
@@ -0,0 +1,41 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from torch.nn import Module
+
+from sparsetensors.quantization.lifecycle.status import QuantizationStatus
+
+
+__all__ = [
+    "freeze_module_quantization",
+]
+
+
+def freeze_module_quantization(module: Module):
+    if not getattr(module, "quantization_scheme", None):
+        # no quantization scheme, nothing to do
+        return
+
+    # collect observer names first so we do not mutate the module while iterating
+    observer_names = [
+        name
+        for name, _ in module.named_modules()
+        if "." not in name and name.endswith("_observer")
+    ]
+    # delete any observers that belong directly to this module
+    for name in observer_names:
+        delattr(module, name)
+
+    module.quantization_status = QuantizationStatus.FROZEN
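A quick way to check the frozen behavior (names from the demo script; a sketch):

```python
freeze_module_quantization(layer)
assert not hasattr(layer, "input_observer")
assert not hasattr(layer, "weight_observer")

scale_before = layer.weight_scale.clone()
layer(torch.randn(4, 4))                              # forward still fake-quantizes...
assert torch.equal(layer.weight_scale, scale_before)  # ...but no longer updates qparams
```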
diff --git a/src/sparsetensors/quantization/lifecycle/initialize.py b/src/sparsetensors/quantization/lifecycle/initialize.py
new file mode 100644
index 00000000..cfa4aa77
--- /dev/null
+++ b/src/sparsetensors/quantization/lifecycle/initialize.py
@@ -0,0 +1,73 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+
+import torch
+from torch.nn import Module, Parameter
+
+from sparsetensors.quantization.lifecycle.forward import (
+    wrap_module_forward_quantized,
+)
+from sparsetensors.quantization.lifecycle.status import QuantizationStatus
+from sparsetensors.quantization.quant_args import QuantizationArgs
+from sparsetensors.quantization.quant_scheme import QuantizationScheme
+
+
+__all__ = [
+    "initialize_module_for_quantization",
+]
+
+
+_LOGGER = logging.getLogger(__name__)
+
+
+def initialize_module_for_quantization(module: Module, scheme: QuantizationScheme):
+    if scheme.input_activations is not None:
+        _initialize_scale_zero_point_observer(module, "input", scheme.input_activations)
+    if scheme.weights is not None:
+        if hasattr(module, "weight"):
+            _initialize_scale_zero_point_observer(module, "weight", scheme.weights)
+        else:
+            _LOGGER.warning(
+                f"module type {type(module)} targeted for weight quantization but "
+                "has no attribute weight, skipping weight quantization "
+                f"for {type(module)}"
+            )
+    if scheme.output_activations is not None:
+        _initialize_scale_zero_point_observer(
+            module, "output", scheme.output_activations
+        )
+
+    module.quantization_scheme = scheme
+    module.quantization_status = QuantizationStatus.INITIALIZED
+
+    # wrap forward call of module to perform quantized actions based on calltime status
+    wrap_module_forward_quantized(module, scheme)
+
+
+def _initialize_scale_zero_point_observer(
+    module: Module, base_name: str, quantization_args: QuantizationArgs
+):
+    # initialize empty scale and zero point parameters for the module
+    init_scale = Parameter(torch.empty(0), requires_grad=False)
+    module.register_parameter(f"{base_name}_scale", init_scale)
+
+    init_zero_point = Parameter(torch.empty(0, dtype=torch.int64), requires_grad=False)
+    module.register_parameter(f"{base_name}_zero_point", init_zero_point)
+
+    # initialize observer module and attach as submodule
+    observer = quantization_args.get_observer()
+    module.register_module(f"{base_name}_observer", observer)
diff --git a/src/sparsetensors/quantization/lifecycle/status.py b/src/sparsetensors/quantization/lifecycle/status.py
new file mode 100644
index 00000000..3b6a441d
--- /dev/null
+++ b/src/sparsetensors/quantization/lifecycle/status.py
@@ -0,0 +1,26 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from enum import Enum
+
+
+__all__ = [
+    "QuantizationStatus",
+]
+
+
+class QuantizationStatus(Enum):
+    INITIALIZED = "INITIALIZED"
+    CALIBRATION = "CALIBRATION"
+    FROZEN = "FROZEN"
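What initialization attaches to a module, shown on the demo `Linear` layer (a sketch; exact reprs will vary):

```python
layer = Linear(4, 4)
initialize_module_for_quantization(layer, scheme)

# empty parameters, filled in during calibration:
print([n for n, _ in layer.named_parameters() if "scale" in n or "zero_point" in n])
# ['input_scale', 'input_zero_point', 'weight_scale', 'weight_zero_point']

print(layer.input_observer)        # one observer submodule per quantized tensor
print(layer.quantization_status)   # QuantizationStatus.INITIALIZED
```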
diff --git a/src/sparsetensors/quantization/observers/__init__.py b/src/sparsetensors/quantization/observers/__init__.py
new file mode 100644
index 00000000..1bec545d
--- /dev/null
+++ b/src/sparsetensors/quantization/observers/__init__.py
@@ -0,0 +1,19 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# flake8: noqa
+
+from .base import *
+from .memoryless import *
+from .min_max import *
diff --git a/src/sparsetensors/quantization/observers/base.py b/src/sparsetensors/quantization/observers/base.py
new file mode 100644
index 00000000..44c8ec37
--- /dev/null
+++ b/src/sparsetensors/quantization/observers/base.py
@@ -0,0 +1,70 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple
+
+from torch import FloatTensor, IntTensor, Tensor
+from torch.nn import Module
+
+from sparsetensors.quantization.quant_args import QuantizationArgs
+from sparsezoo.utils.registry import RegistryMixin
+
+
+__all__ = ["Observer"]
+
+
+class Observer(Module, RegistryMixin):
+    """
+    Base Observer class to be subclassed for specific implementations.
+    Subclasses should override `calculate_qparams` to return a scale, zero_point
+    pair
+    """
+
+    def __init__(self, quantization_args: QuantizationArgs):
+        super().__init__()
+        self.quantization_args: QuantizationArgs = quantization_args
+        self._scale = None
+        self._zero_point = None
+
+    def forward(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
+        """
+        maps directly to get_qparams
+        :param observed: optional observed tensor to calculate quantization parameters
+            from
+        :return: tuple of scale and zero point based on last observed value
+        """
+        return self.get_qparams(observed=observed)
+
+    def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
+        """
+        :param observed: observed tensor to calculate quantization parameters for
+        :return: tuple of scale and zero point derived from the observed tensor
+        """
+        raise NotImplementedError(f"{self.__class__} must implement calculate_qparams")
+
+    def get_qparams(
+        self, observed: Optional[Tensor] = None
+    ) -> Tuple[FloatTensor, IntTensor]:
+        """
+        Convenience function to wrap overwritten calculate_qparams
+        adds support to make observed tensor optional and support for tracking latest
+        calculated scale and zero point
+        :param observed: optional observed tensor to calculate quantization parameters
+            from
+        :return: tuple of scale and zero point based on last observed value
+        """
+        if observed is not None:
+            # re-calculate scale and zero point, update the stored values
+            self._scale, self._zero_point = self.calculate_qparams(observed)
+        return self._scale, self._zero_point
diff --git a/src/sparsetensors/quantization/observers/memoryless.py b/src/sparsetensors/quantization/observers/memoryless.py
new file mode 100644
index 00000000..e69de29b
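Registering a custom observer only requires subclassing `Observer` and implementing `calculate_qparams`; the registry name is what `QuantizationArgs.get_observer` resolves. A toy sketch (assuming the classes above are importable):

```python
import torch
from torch import Tensor

@Observer.register("fixed")
class FixedObserver(Observer):
    """Illustration only: always returns a constant scale and zero point."""

    def calculate_qparams(self, observed: Tensor):
        return torch.tensor([0.1]), torch.tensor([0], dtype=torch.int64)

# resolved by name, the same way QuantizationArgs.get_observer() does it:
observer = Observer.load_from_registry("fixed", quantization_args=QuantizationArgs())
scale, zero_point = observer(torch.randn(4, 4))
```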
diff --git a/src/sparsetensors/quantization/observers/min_max.py b/src/sparsetensors/quantization/observers/min_max.py
new file mode 100644
index 00000000..c72eb1c0
--- /dev/null
+++ b/src/sparsetensors/quantization/observers/min_max.py
@@ -0,0 +1,80 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Tuple
+
+import torch
+from torch import FloatTensor, IntTensor, Tensor
+
+from sparsetensors.quantization.observers.base import Observer
+from sparsetensors.quantization.quant_args import QuantizationArgs
+
+
+__all__ = ["MinMaxObserver"]
+
+
+@Observer.register("minmax")
+class MinMaxObserver(Observer):
+    """
+    Implements a quantization observer that sets the scale and zero point based
+    on a running average of the observed min and max values
+    """
+
+    def __init__(self, quantization_args: QuantizationArgs):
+        super().__init__(quantization_args=quantization_args)
+
+        self.min_val = float("inf")
+        self.max_val = -float("inf")
+        self.counter = 0
+
+    def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
+        """
+        :param observed: observed tensor to calculate quantization parameters for
+        :return: tuple of scale and zero point derived from the observed tensor
+        """
+        # TODO: Add support for full range of quantization args, only supports
+        # 8-bit per-tensor quantization for now
+        bit_range = 255
+        min_val = torch.tensor([observed.min()])
+        max_val = torch.tensor([observed.max()])
+
+        # running average
+        if self.counter > 0:
+            self.min_val = (self.min_val * self.counter + min_val) / (self.counter + 1)
+            self.max_val = (self.max_val * self.counter + max_val) / (self.counter + 1)
+        else:
+            self.min_val = min_val
+            self.max_val = max_val
+
+        # ensure that zero is in the range
+        self.min_val = torch.min(self.min_val, torch.zeros_like(self.min_val))
+        self.max_val = torch.max(self.max_val, torch.zeros_like(self.max_val))
+
+        self.counter += 1
+
+        if self.quantization_args.symmetric:
+            symmetric_range = 2 * max(self.min_val.abs(), self.max_val.abs())
+            scale = symmetric_range / bit_range
+            zero_point = torch.tensor(0).to(torch.int8)
+        else:
+            # non-symmetric
+            observed_range = self.max_val - self.min_val
+            scale = observed_range / bit_range
+
+            # scales from a 0 range should be set to 1
+            scale[observed_range == 0] = 1
+
+            # zero point must land on an integer grid point
+            zero_point = torch.round((0 - self.min_val) / scale)
+
+        return scale, zero_point
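The running-average behavior, traced by hand for the asymmetric case:

```python
obs = MinMaxObserver(QuantizationArgs(num_bits=8, symmetric=False))

# batch 1 spans [-1.0, 3.0]: scale = 4/255 ~= 0.0157, zero_point = round(1.0/scale) = 64
scale, zp = obs(torch.tensor([[-1.0, 0.5, 3.0]]))

# batch 2 spans [-3.0, 1.0]: averaged min/max become [-2.0, 2.0],
# so scale stays 4/255 and zero_point moves to round(2.0/scale) = 128
scale, zp = obs(torch.tensor([[-3.0, 1.0]]))
```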
diff --git a/src/sparsetensors/quantization/utils/quantization_scheme.py b/src/sparsetensors/quantization/utils/quantization_scheme.py
new file mode 100644
index 00000000..976b534e
--- /dev/null
+++ b/src/sparsetensors/quantization/utils/quantization_scheme.py
@@ -0,0 +1,385 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Schemas and types to support quantization
+"""
+from copy import deepcopy
+from functools import partial
+from typing import Any, Dict, Optional, Union
+
+import torch
+from packaging import version
+from pydantic import BaseModel, Field, validator
+from torch.nn import Identity
+
+
+try:
+    from torch import quantization as torch_quantization
+except Exception:
+    torch_quantization = None
+
+# only needed by the commented-out torch get_observer implementation below
+# from sparseml.modifiers.quantization.utils.fake_quant_wrapper import FakeQuantizeWrapper
+
+
+__all__ = [
+    "DictQuantizationArgs",
+    "DictQuantizationScheme",
+    "QuantizationArgs",
+    "QuantizationScheme",
+    "QuantizationSchemeLoadable",
+    "compute_range",
+]
+
+
+_PARSED_TORCH_VERSION = version.parse(torch.__version__)
+_TORCH_PRE_112 = _PARSED_TORCH_VERSION < version.parse("1.12.0")
+
+
+"""
+Type definition aliases for defining QuantizationArgs and QuantizationScheme
+as dictionaries for YAML serialization
+"""
+DictQuantizationArgs = Dict[str, Union[int, bool, Dict[str, Any]]]
+DictQuantizationScheme = Dict[str, DictQuantizationArgs]
+
+"""
+Type definition for a type that is valid for loading a QuantizationScheme
+using QuantizationScheme.load
+"""
+QuantizationSchemeLoadable = Union[
+    "QuantizationScheme",
+    DictQuantizationScheme,
+    str,
+    None,
+]
+
+
+class QuantizationArgs(BaseModel):
+    """
+    Class representing user facing arguments to define quantization Observers of
+    activations or weights in a network
+    """
+
+    num_bits: int = Field(
+        default=8, description="number of bits to target for quantization"
+    )
+    symmetric: bool = Field(
+        default=False,
+        description="set True to use symmetric quantization. Default False",
+    )
+    strategy: str = Field(
+        default="tensor",
+        description=(
+            "scope of the quantization to be applied. can be 'tensor' or 'channel'"
+        ),
+    )
+    observer: str = Field(
+        default="minmax",
+        description=(
+            "the class to use to compute the quantization params - scale and zero point"
+        ),
+    )
+    observer_kwargs: Dict[str, Any] = Field(
+        default_factory=dict,
+        description=(
+            "optional dict of kwargs to be passed directly to the observer "
+            "constructor excluding quantization range or symmetry"
+        ),
+    )
+
+    @classmethod
+    def default_activation_args(cls):
+        """
+        :return: default 8 bits asymmetric settings
+        """
+        return cls(num_bits=8, symmetric=False)
+
+    @classmethod
+    def default_weight_args(cls):
+        """
+        :return: default 8 bits symmetric settings
+        """
+        return cls(num_bits=8, symmetric=True)
+
+    def get_observer(self) -> "Observer":
+        """
+        :return: Observer built based on these QuantizationArgs
+        """
+        from sparsetensors.quantization.observers.base import Observer
+
+        return Observer.load_from_registry(self.observer, quantization_args=self)
+
+    @validator("strategy")
+    def validate_strategy(cls, value):
+        valid_scopes = ["tensor", "channel"]
+        if value not in valid_scopes:
+            raise ValueError(f"`strategy` must be one of {valid_scopes}, got {value}")
+        return value
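Typical construction, and how the `observer` field ties back to the registry (a sketch):

```python
args = QuantizationArgs(num_bits=8, symmetric=True, observer="minmax")
observer = args.get_observer()  # resolves the MinMaxObserver registered under "minmax"

QuantizationArgs(strategy="channel")  # ok
QuantizationArgs(strategy="block")    # fails validation: must be 'tensor' or 'channel'
```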
+class QuantizationScheme(BaseModel):
+    """
+    Class composed of QuantizationArgs to build QConfig and QuantWrapper objects for
+    quantizing models. Provides a simple user interface for defining how inputs,
+    weights, and outputs should be quantized
+    """
+
+    def __init__(self, *args, **kwargs):
+        # support for loading from yaml str
+        args = [arg if arg != "null" else None for arg in args]
+        for key, val in kwargs.items():
+            if val == "null":
+                kwargs[key] = None
+        super().__init__(*args, **kwargs)
+
+    input_activations: Optional[QuantizationArgs] = Field(
+        default_factory=QuantizationArgs.default_activation_args,
+        description=(
+            "target quantization setting for input activations. Set to None to "
+            "not quantize input activations. Default is 8 bits asymmetric"
+        ),
+    )
+    weights: Optional[QuantizationArgs] = Field(
+        default_factory=QuantizationArgs.default_weight_args,
+        description=(
+            "target quantization setting for model weights. Set to None to "
+            "not quantize weights. Default is 8 bits symmetric"
+        ),
+    )
+    output_activations: Optional[QuantizationArgs] = Field(
+        default=None,
+        description=(
+            "target quantization setting for output activations. Set to None to "
+            "not quantize output activations. Default is None"
+        ),
+    )
+    target_hardware: Optional[str] = Field(
+        default=None,
+        description=(
+            "target deployment runtime/hardware name to be set by default "
+            "classmethods. Default is None"
+        ),
+    )
+
+    @classmethod
+    def load(
+        cls,
+        scheme: QuantizationSchemeLoadable,
+        default: Optional["QuantizationScheme"] = None,
+    ) -> "QuantizationScheme":
+        """
+        :param scheme: QuantizationScheme, dict representation of scheme,
+            or string alias of a scheme to load. Valid strings:
+            ['default', 'deepsparse', 'tensorrt']
+        :param default: default QuantizationScheme to override 'default' scheme
+            with
+        :return: constructed QuantizationScheme object from the given scheme;
+            if given a dict, returns QuantizationScheme.parse_obj(scheme); string
+            input will return the default QuantizationScheme if set to 'default'
+        """
+        if isinstance(scheme, cls):
+            return scheme
+        elif scheme is None or scheme == "default":
+            # if no default override, defaults to QuantizationScheme()
+            return deepcopy(default) or cls()
+        elif isinstance(scheme, str):
+            if scheme == "deepsparse":
+                return cls.deepsparse()
+            elif scheme == "tensorrt":
+                return cls.tensorrt()
+            raise ValueError(
+                f"Unrecognized QuantizationScheme string alias {scheme}. "
+                "Valid strings: ['default', 'deepsparse', 'tensorrt']"
+            )
+        elif isinstance(scheme, dict):
+            # load scheme from dict representation
+            scheme = {key: _parse_quantization_arg(arg) for key, arg in scheme.items()}
+            return cls.parse_obj(scheme)
+        else:
+            raise ValueError(
+                f"Unrecognized type {type(scheme)} for QuantizationScheme.load, "
+                "expected one of: [QuantizationScheme, Dict, str, None]"
+            )
+
+    @classmethod
+    def deepsparse(cls) -> "QuantizationScheme":
+        """
+        :return: QuantizationScheme for deepsparse targeted deployments -
+            int8, symmetric weights, asymmetric inputs, no output quantization
+        """
+        return cls(
+            input_activations=QuantizationArgs(num_bits=8, symmetric=False),
+            weights=QuantizationArgs(num_bits=8, symmetric=True),
+            output_activations=None,
+            target_hardware="deepsparse",
+        )
+
+    @classmethod
+    def tensorrt(cls) -> "QuantizationScheme":
+        """
+        :return: QuantizationScheme for tensorrt targeted deployments -
+            compatibility with explicit quantization as supported by TensorRT 8.2:
+            int8, symmetric for both weights and inputs, no output quantization
+        """
+        return cls(
+            input_activations=QuantizationArgs(num_bits=8, symmetric=True),
+            weights=QuantizationArgs(num_bits=8, symmetric=True),
+            output_activations=None,
+            target_hardware="tensorrt",
+        )
+
+    def get_qconfig(self) -> "torch.quantization.QConfig":
+        """
+        :return: QConfig for Modules (output activations used,
+            use QuantWrapper for inputs)
+        """
+        qconfig = _get_qconfig(self.output_activations, self.weights)
+        # add reference to this quantization scheme for reference
+        qconfig.quantization_scheme = self
+        return qconfig
+
+    def get_wrapper_qconfig(self) -> "torch.quantization.QConfig":
+        """
+        :return: QConfig for QuantWrapper objects (input activations used)
+        """
+        qconfig = _get_qconfig(self.input_activations, None)
+        # add reference to this quantization scheme for reference
+        qconfig.quantization_scheme = self
+        return qconfig
+
+    def __str__(self) -> str:
+        """
+        :return: YAML friendly string serialization
+        """
+        dict_repr = self.dict()
+        dict_repr = {
+            key: val if val is not None else "null" for key, val in dict_repr.items()
+        }
+        return str(dict_repr)
+
+
+def compute_range(dtype: torch.dtype, bits: int):
+    """
+    compute quantization limits depending on data type and number of bits
+
+    :param dtype: data type.
+    :param bits: number of bits.
+    :return: minimum limit, maximum limit, whether the range is customized
+    """
+    bits = bits if bits else 8
+    is_custom = bits != 8
+    if dtype == torch.qint8:
+        quant_min = -(2 ** (bits - 1))
+        quant_max = (2 ** (bits - 1)) - 1
+    elif dtype == torch.quint8:
+        quant_min = 0
+        quant_max = (2**bits) - 1
+    else:
+        raise ValueError(f"unsupported quantization dtype {dtype}")
+
+    return quant_min, quant_max, is_custom
+# def get_observer(
+#     symmetric: bool,
+#     strategy: str,
+#     dtype: torch.dtype,
+#     bits: int,
+#     reduce_range: bool,
+#     qconfig_kwargs: Dict[str, Any],
+# ):
+#     quant_min, quant_max, is_custom_qrange = compute_range(dtype, bits)
+
+#     if strategy == "channel":
+#         qscheme = torch.per_channel_symmetric if symmetric else torch.per_channel_affine
+#         observer_cls = torch_quantization.MovingAveragePerChannelMinMaxObserver
+#         observer_kwargs = dict(
+#             ch_axis=0,
+#             dtype=dtype,
+#             qscheme=qscheme,
+#             reduce_range=reduce_range,
+#         )
+#     else:  # default to tensor strategy
+#         qscheme = torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine
+#         observer_cls = torch_quantization.MovingAverageMinMaxObserver
+#         observer_kwargs = dict(
+#             dtype=dtype,
+#             qscheme=qscheme,
+#             reduce_range=reduce_range,
+#         )
+#     """
+#     in torch 1.9.1, quant_min and quant_max are not passed to observer:
+#     https://github.com/pytorch/pytorch/blob/v1.9.1/torch/quantization/fake_quantize.py#L109
+#     however in 1.12.0, this is fixed so both are passed to observer:
+#     https://github.com/pytorch/pytorch/blob/v1.12.1/torch/ao/quantization/fake_quantize.py#L132
+
+#     Passing quant_min/quant_max to observer means the observer will have
+#     `self.has_customized_qrange == True` in both 1.9.1 and 1.12.0.
+
+#     For whatever reason, both versions calculate zero point for
+#     quint8 differently **if there is a customized_qrange**
+#     1. customized qrange has zero point of 127
+#     2. non-customized has zero point of 128.
+#     source:
+#     https://github.com/pytorch/pytorch/blob/v1.12.1/torch/ao/quantization/observer.py#L293
+
+#     **we want to ensure that the zero point is 128**
+#     see https://github.com/neuralmagic/sparseml/pull/604
+#     """
+#     if is_custom_qrange:
+#         # for both versions we need to include the custom min/max values in kwargs
+#         observer_kwargs["quant_min"] = quant_min
+#         observer_kwargs["quant_max"] = quant_max
+#         if _TORCH_PRE_112:
+#             # pre 1.12, the observer doesn't get passed the quant_min/quant_max values,
+#             # so we patch them in to the constructor of the observer
+#             observer_cls = partial(
+#                 observer_cls, quant_min=quant_min, quant_max=quant_max
+#             )
+#     else:
+#         # if using a non custom qrange, we can rely on default values used by
+#         # the observers
+#         if _TORCH_PRE_112:
+#             # pre 1.12, the observer doesn't get passed the quant_min/quant_max values,
+#             # so we are safe to pass these to FakeQuantize
+#             observer_kwargs["quant_min"] = quant_min
+#             observer_kwargs["quant_max"] = quant_max
+#         else:
+#             # post 1.12 we cannot pass them to the observer since that will set
+#             # has_customized_qrange. instead we rely on the default values
+#             # being equal to the `quant_min` and `quant_max` here.
+#             pass
+
+#     observer_kwargs["observer"] = observer_cls
+#     observer_kwargs.update(qconfig_kwargs or {})
+#     observer = FakeQuantizeWrapper.with_args(**observer_kwargs)
+
+#     return observer
+
+
+def _get_qconfig(
+    activation_args: Optional[QuantizationArgs], weight_args: Optional[QuantizationArgs]
+) -> "torch.quantization.QConfig":
+    return torch_quantization.QConfig(
+        activation=activation_args.get_observer() if activation_args else Identity,
+        weight=weight_args.get_observer() if weight_args else Identity,
+    )
+
+
+def _parse_quantization_arg(arg: Any):
+    if arg == "None":
+        return None
+    return arg