From 62e6f39133e60621e40278d52394beaa5ebec3b9 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Thu, 31 Oct 2024 20:27:39 +0000
Subject: [PATCH 1/2] clean up observer defaulting logic, better error message

Signed-off-by: Kyle Sayers
---
 .../quantization/quant_args.py              | 16 +++++-----------
 src/compressed_tensors/registry/registry.py |  2 +-
 2 files changed, 6 insertions(+), 12 deletions(-)

diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py
index 3259976c..bfd027f3 100644
--- a/src/compressed_tensors/quantization/quant_args.py
+++ b/src/compressed_tensors/quantization/quant_args.py
@@ -114,12 +114,6 @@ def get_observer(self):
         """
         :return: torch quantization FakeQuantize built based on these QuantizationArgs
         """
-
-        # No observer required for the dynamic case
-        if self.dynamic:
-            self.observer = None
-            return self.observer
-
         return self.observer
 
     @field_validator("type", mode="before")
@@ -217,15 +211,15 @@ def validate_model_after(model: "QuantizationArgs") -> Dict[str, Any]:
                 warnings.warn(
                     "No observer is used for dynamic quantization, setting to None"
                 )
-                model.observer = None
+                observer = None
 
-        # if we have not set an observer and we
-        # are running static quantization, use minmax
-        if not observer and not dynamic:
-            model.observer = "minmax"
+        elif observer is None:
+            # default to minmax for non-dynamic cases
+            observer = "minmax"
 
         # write back modified values
         model.strategy = strategy
+        model.observer = observer
         return model
 
     def pytorch_dtype(self) -> torch.dtype:
diff --git a/src/compressed_tensors/registry/registry.py b/src/compressed_tensors/registry/registry.py
index d8d8bc6d..76026313 100644
--- a/src/compressed_tensors/registry/registry.py
+++ b/src/compressed_tensors/registry/registry.py
@@ -258,7 +258,7 @@ def get_from_registry(
         retrieved_value = _import_and_get_value_from_module(module_path, value_name)
     else:
         # look up name in alias registry
-        name = _ALIAS_REGISTRY[parent_class].get(name)
+        name = _ALIAS_REGISTRY[parent_class].get(name, name)
         # look up name in registry
         retrieved_value = _REGISTRY[parent_class].get(name)
         if retrieved_value is None:

From a953056a070656a7243074f3b369030641fab734 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Fri, 1 Nov 2024 15:59:08 +0000
Subject: [PATCH 2/2] avoid annoying users with old configs

Signed-off-by: Kyle Sayers
---
 src/compressed_tensors/quantization/quant_args.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py
index bfd027f3..4619d581 100644
--- a/src/compressed_tensors/quantization/quant_args.py
+++ b/src/compressed_tensors/quantization/quant_args.py
@@ -197,6 +197,7 @@ def validate_model_after(model: "QuantizationArgs") -> Dict[str, Any]:
                     "activation ordering"
                 )
 
+        # infer observer w.r.t. dynamic
         if dynamic:
             if strategy not in (
                 QuantizationStrategy.TOKEN,
@@ -208,9 +209,10 @@ def validate_model_after(model: "QuantizationArgs") -> Dict[str, Any]:
                     "quantization",
                 )
             if observer is not None:
-                warnings.warn(
-                    "No observer is used for dynamic quantization, setting to None"
-                )
+                if observer != "memoryless":  # avoid annoying users with old configs
+                    warnings.warn(
+                        "No observer is used for dynamic quantization, setting to None"
+                    )
                 observer = None
 
         elif observer is None:
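
Note on the resulting behavior (illustration only, not part of the patches): the sketch below assumes QuantizationArgs is importable from compressed_tensors.quantization with both commits applied and that pydantic coerces the string arguments shown; exact defaults may differ between library versions.

    from compressed_tensors.quantization import QuantizationArgs

    # Static (non-dynamic) args with no observer given: the validator now
    # defaults the observer to "minmax" instead of leaving it unset.
    static_args = QuantizationArgs(num_bits=8, dynamic=False)
    print(static_args.observer)  # expected: "minmax"

    # Dynamic args carrying a legacy "memoryless" observer: the observer is
    # cleared to None, and PATCH 2/2 suppresses the warning for this case so
    # old configs load quietly.
    dynamic_args = QuantizationArgs(
        num_bits=8, dynamic=True, strategy="token", observer="memoryless"
    )
    print(dynamic_args.observer)  # expected: None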