diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index e9db69ae..43266503 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -112,18 +112,22 @@ def _maybe_calibrate_or_quantize( }: return value - device = next(module.parameters()).device - scale = getattr(module, f"{base_name}_scale") - # zero_point = getattr(module, f"{base_name}_zero_point").data - zero_point = getattr(module, f"{base_name}_zero_point") - - if module.quantization_status == QuantizationStatus.CALIBRATION: - # get observer and get new quant params from observation - observer = getattr(module, f"{base_name}_observer") - updated_scale, updated_zero_point = observer(value) - - # update scale and zero point - scale.data = updated_scale.to(device) - zero_point.data = updated_zero_point.to(device) + observer = getattr(module, f"{base_name}_observer") + if observer.DYNAMIC: + # dynamic quantization - get scale and zero point directly from observer + scale, zero_point = observer(value) + else: + # static quantization - get previous scale and zero point from layer + scale = getattr(module, f"{base_name}_scale") + zero_point = getattr(module, f"{base_name}_zero_point") + + if module.quantization_status == QuantizationStatus.CALIBRATION: + # calibration mode - get new quant params from observer + updated_scale, updated_zero_point = observer(value) + + # update scale and zero point + device = next(module.parameters()).device + scale.data = updated_scale.to(device) + zero_point.data = updated_zero_point.to(device) return fake_quantize(value, scale, zero_point, args) diff --git a/src/compressed_tensors/quantization/lifecycle/frozen.py b/src/compressed_tensors/quantization/lifecycle/frozen.py index 9df8ec46..47129ca3 100644 --- a/src/compressed_tensors/quantization/lifecycle/frozen.py +++ b/src/compressed_tensors/quantization/lifecycle/frozen.py @@ 
-36,9 +36,12 @@ def freeze_module_quantization(module: Module): # delete observers from module observer_names = [] - for submodule_name, _ in module.named_modules(): + for submodule_name, submodule in module.named_modules(): if "." not in submodule_name and submodule_name.endswith("_observer"): - # delete any observers that belong directly to this module + if getattr(submodule, "DYNAMIC", False): + continue # do not delete dynamic observers + + # delete any non-dynamic observers that belong directly to this module observer_names.append(submodule_name) for observer_name in observer_names: delattr(module, observer_name) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index d78997c1..f1471cb0 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -78,6 +78,13 @@ def initialize_module_for_quantization( def _initialize_scale_zero_point_observer( module: Module, base_name: str, quantization_args: QuantizationArgs ): + # initialize observer module and attach as submodule + observer = quantization_args.get_observer() + module.register_module(f"{base_name}_observer", observer) + + if observer.DYNAMIC: + return # no need to register a scale and zero point for a dynamic observer + device = next(module.parameters()).device # initializes empty scale and zero point parameters for the module @@ -88,7 +95,3 @@ def _initialize_scale_zero_point_observer( torch.empty(0, device=device, dtype=int), requires_grad=False ) module.register_parameter(f"{base_name}_zero_point", init_zero_point) - - # initialize observer module and attach as submodule - observer = quantization_args.get_observer() - module.register_module(f"{base_name}_observer", observer) diff --git a/src/compressed_tensors/quantization/observers/__init__.py b/src/compressed_tensors/quantization/observers/__init__.py index 7e7ea908..1ae2c2f7 100644 --- 
a/src/compressed_tensors/quantization/observers/__init__.py +++ b/src/compressed_tensors/quantization/observers/__init__.py @@ -19,3 +19,4 @@ from .base import * from .memoryless import * from .min_max import * +from .dynamic import * diff --git a/src/compressed_tensors/quantization/observers/base.py b/src/compressed_tensors/quantization/observers/base.py index 96fe1049..023cb12f 100644 --- a/src/compressed_tensors/quantization/observers/base.py +++ b/src/compressed_tensors/quantization/observers/base.py @@ -30,6 +30,9 @@ class Observer(Module, RegistryMixin): pair """ + # child classes should set to True if they are meant to be used as dynamic + DYNAMIC = False + def __init__(self, quantization_args: QuantizationArgs): self.quantization_args: QuantizationArgs = quantization_args super().__init__() diff --git a/src/compressed_tensors/quantization/observers/dynamic.py b/src/compressed_tensors/quantization/observers/dynamic.py new file mode 100644 index 00000000..a2e9bf5f --- /dev/null +++ b/src/compressed_tensors/quantization/observers/dynamic.py @@ -0,0 +1,35 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from compressed_tensors.quantization.observers.base import Observer +from compressed_tensors.quantization.observers.memoryless import MemorylessObserver + + +__all__ = ["DynamicObserver"] + + +@Observer.register("dynamic") +class DynamicObserver(MemorylessObserver): +    """ +    Values targeted for a dynamic observer do not require calibration; +    this observer will persist in the model through the lifecycle, calculating +    the quantization parameters on the fly for each observed Tensor. +    +    This base dynamic observer uses the `calculate_qparams` from MemorylessObserver +    where each scale and zero point is based solely on the currently observed +    Tensor. +    """ +    +    DYNAMIC = True