From 8b145324a9ee861c4e63e958c8285ac151786f8f Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Wed, 2 Oct 2024 17:11:56 -0400
Subject: [PATCH] Fix 2/4 GPTQ Model Tests (#769)

* fix 2 / 4 failing bugs

* commented code
---
 .../transformers/gptq/test_oneshot.py | 18 +++++++++++++-----
 1 file changed, 13 insertions(+), 5 deletions(-)

diff --git a/tests/llmcompressor/transformers/gptq/test_oneshot.py b/tests/llmcompressor/transformers/gptq/test_oneshot.py
index d2dba3815..017679fa5 100644
--- a/tests/llmcompressor/transformers/gptq/test_oneshot.py
+++ b/tests/llmcompressor/transformers/gptq/test_oneshot.py
@@ -79,15 +79,23 @@ def test_oneshot_application(self):
         model_loaded = SparseAutoModelForCausalLM.from_pretrained(self.output)
 
         # Check that the model is quantized
-        assert model_loaded.quantization_config is not None
+        # for compression_config - decompress() will attach a quantization_config
+        # to the model as we decompress right away
+        # for quantization_config - we have CompressedLinear which will only
+        # decompress on the forward pass and does not call decompress(). Results
+        # in a slightly different parameter tree to access the quant config
+        quantization_config = (
+            model_loaded.config.quantization_config.quantization_config
+        )
+        assert quantization_config is not None
 
         # check config is set properly
-        assert model_loaded.quantization_config.ignore == ["lm_head"]
-        assert len(model_loaded.quantization_config.config_groups) == 1
-        quant_scheme = model_loaded.quantization_config.config_groups["group_0"]
+        assert quantization_config.ignore == ["lm_head"]
+        assert len(quantization_config.config_groups) == 1
+        quant_scheme = quantization_config.config_groups["group_0"]
         assert isinstance(quant_scheme, QuantizationScheme)
         assert quant_scheme.targets == ["Linear"]
 
-        weight_args = model_loaded.quantization_config.config_groups["group_0"].weights
+        weight_args = quantization_config.config_groups["group_0"].weights
         assert isinstance(weight_args, QuantizationArgs)
         assert weight_args.num_bits == 4
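
Note: for context, a minimal sketch of the access pattern the updated test
asserts, assuming a quantized model saved by a prior oneshot run. The
"./oneshot_out" path is illustrative, and the import path is assumed from
llm-compressor's layout at the time of this patch:

    from llmcompressor.transformers import SparseAutoModelForCausalLM

    # Load a model that was quantized and saved with a quantization config.
    model_loaded = SparseAutoModelForCausalLM.from_pretrained("./oneshot_out")

    # With CompressedLinear modules, weights decompress lazily on the forward
    # pass and decompress() is never called, so the quantization config sits
    # one level deeper than model_loaded.quantization_config:
    quantization_config = model_loaded.config.quantization_config.quantization_config

    assert quantization_config is not None
    # Per-group scheme and weight args, as checked by the test:
    quant_scheme = quantization_config.config_groups["group_0"]
    weight_args = quant_scheme.weights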