From 1229c5a344310d69e0173316cc5199bf4e3df294 Mon Sep 17 00:00:00 2001 From: George Ohashi Date: Fri, 3 May 2024 15:36:37 +0000 Subject: [PATCH] pass test_quant_args --- src/compressed_tensors/quantization/quant_args.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py index 140cfdb8..f8c82d8a 100644 --- a/src/compressed_tensors/quantization/quant_args.py +++ b/src/compressed_tensors/quantization/quant_args.py @@ -64,8 +64,8 @@ class QuantizationArgs(BaseModel): num_bits: int = 8 type: QuantizationType = QuantizationType.INT symmetric: bool = True - strategy: Optional[QuantizationStrategy] = None group_size: Optional[int] = None + strategy: Optional[QuantizationStrategy] = None block_structure: Optional[str] = None dynamic: bool = False observer: str = Field( @@ -96,7 +96,7 @@ def get_observer(self): return Observer.load_from_registry(self.observer, quantization_args=self) - @validator("strategy", pre=True) + @validator("strategy", pre=True, always=True) def validate_strategy(cls, value, values): group_size = values.get("group_size") @@ -114,8 +114,7 @@ def validate_strategy(cls, value, values): "group_size > 0 for strategy='group' and " "group_size = -1 for 'channel'" ) - # breakpoint() - group_size = 128 + if value == QuantizationStrategy.GROUP: if group_size is None: raise ValueError(f"strategy {value} requires group_size to be set.")