From 2575cb7966109c965963f53faee3f4867f79fa2e Mon Sep 17 00:00:00 2001
From: Dipika
Date: Fri, 11 Oct 2024 14:07:43 +0000
Subject: [PATCH] clean-up

---
 src/compressed_tensors/quantization/lifecycle/apply.py  | 1 -
 src/compressed_tensors/quantization/lifecycle/frozen.py | 6 +-----
 src/compressed_tensors/quantization/quant_args.py       | 2 +-
 3 files changed, 2 insertions(+), 7 deletions(-)

diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py
index e4d4a95f..09281528 100644
--- a/src/compressed_tensors/quantization/lifecycle/apply.py
+++ b/src/compressed_tensors/quantization/lifecycle/apply.py
@@ -246,7 +246,6 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
     if status >= QuantizationStatus.CALIBRATION > current_status:
         # only quantize weights up front when our end goal state is calibration,
         # weight quantization parameters are already loaded for frozen/compressed
-        # TODO: to be removed from compressed-tensors
         quantize_weights_upfront = status == QuantizationStatus.CALIBRATION
         model.apply(
             lambda module: set_module_for_calibration(
diff --git a/src/compressed_tensors/quantization/lifecycle/frozen.py b/src/compressed_tensors/quantization/lifecycle/frozen.py
index 1ef51d81..24e51822 100644
--- a/src/compressed_tensors/quantization/lifecycle/frozen.py
+++ b/src/compressed_tensors/quantization/lifecycle/frozen.py
@@ -45,11 +45,7 @@ def freeze_module_quantization(module: Module):
         delattr(module, "input_observer")
     if hasattr(module, "weight_observer") and not scheme.weights.dynamic:
         delattr(module, "weight_observer")
-    if (
-        hasattr(module, "output_observer")
-        and not is_kv_cache_quant_scheme(scheme)
-        and not scheme.output_activations.dynamic
-    ):
+    if hasattr(module, "output_observer") and not scheme.output_activations.dynamic:
         delattr(module, "output_observer")
 
     module.quantization_status = QuantizationStatus.FROZEN
diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py
index b43ffc75..0863637d 100644
--- a/src/compressed_tensors/quantization/quant_args.py
+++ b/src/compressed_tensors/quantization/quant_args.py
@@ -120,7 +120,7 @@ def get_observer(self):
             # keeps state across samples for dynamic
             self.observer = "memoryless"
 
-        return self.observer
+        return Observer.load_from_registry(self.observer, quantization_args=self)
 
     # TODO: update to be removed into llm-compressor
     def get_kv_cache(self):
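
A minimal usage sketch of the patched `get_observer` path, assuming only what the hunks show: `QuantizationArgs` lives in `compressed_tensors.quantization.quant_args`, dynamic args are forced onto the `"memoryless"` observer, and the method now resolves the configured observer name through `Observer.load_from_registry` instead of returning the raw string. The field values below are illustrative, not taken from this patch:

```python
from compressed_tensors.quantization.quant_args import QuantizationArgs

# Static quantization: after this patch, get_observer() resolves the
# configured observer name through the registry and returns an Observer
# instance bound to these args, rather than the bare name string.
static_args = QuantizationArgs(num_bits=8, dynamic=False)
static_observer = static_args.get_observer()

# Dynamic quantization: per the patched comment, the args are forced onto
# the "memoryless" observer so no min/max state persists across samples.
dynamic_args = QuantizationArgs(num_bits=8, dynamic=True)
dynamic_observer = dynamic_args.get_observer()
```

On the `frozen.py` side, dropping the `is_kv_cache_quant_scheme(scheme)` guard means `freeze_module_quantization` now deletes `output_observer` for KV-cache schemes as well, whenever the output activations are not dynamic.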