From d3dea3ffd3f81d9840e99dddf242bdde6f22d6af Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 23 Oct 2024 15:18:32 -0400 Subject: [PATCH] fix test which required accelerate, apply style (#194) --- src/compressed_tensors/quantization/lifecycle/frozen.py | 1 - src/compressed_tensors/quantization/observers/helpers.py | 2 +- src/compressed_tensors/quantization/observers/mse.py | 4 +++- tests/test_quantization/lifecycle/test_apply.py | 8 +------- tests/test_quantization/test_observers/test_mse.py | 2 +- 5 files changed, 6 insertions(+), 11 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/frozen.py b/src/compressed_tensors/quantization/lifecycle/frozen.py index e4723431..4a65482c 100644 --- a/src/compressed_tensors/quantization/lifecycle/frozen.py +++ b/src/compressed_tensors/quantization/lifecycle/frozen.py @@ -14,7 +14,6 @@ from compressed_tensors.quantization.quant_config import QuantizationStatus -from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme from torch.nn import Module diff --git a/src/compressed_tensors/quantization/observers/helpers.py b/src/compressed_tensors/quantization/observers/helpers.py index 875a05b3..ec474303 100644 --- a/src/compressed_tensors/quantization/observers/helpers.py +++ b/src/compressed_tensors/quantization/observers/helpers.py @@ -13,7 +13,7 @@ # limitations under the License. from collections import Counter -from typing import Optional, Tuple +from typing import Tuple import torch from compressed_tensors.quantization.quant_args import ( diff --git a/src/compressed_tensors/quantization/observers/mse.py b/src/compressed_tensors/quantization/observers/mse.py index 24d80584..739e921f 100644 --- a/src/compressed_tensors/quantization/observers/mse.py +++ b/src/compressed_tensors/quantization/observers/mse.py @@ -70,7 +70,9 @@ def calculate_mse_min_max( absolute_min_val = torch.amin(observed, dim=reduce_dims, keepdims=True) absolute_max_val = torch.amax(observed, dim=reduce_dims, keepdims=True) - best = torch.full_like(absolute_min_val, torch.finfo(absolute_min_val.dtype).max) + best = torch.full_like( + absolute_min_val, torch.finfo(absolute_min_val.dtype).max + ) min_val = torch.ones_like(absolute_min_val) max_val = torch.zeros_like(absolute_max_val) for i in range(int(self.maxshrink * self.grid)): diff --git a/tests/test_quantization/lifecycle/test_apply.py b/tests/test_quantization/lifecycle/test_apply.py index 5f0bd093..dcb980f2 100644 --- a/tests/test_quantization/lifecycle/test_apply.py +++ b/tests/test_quantization/lifecycle/test_apply.py @@ -236,14 +236,8 @@ def get_sample_tinyllama_quant_config(status: str = "frozen"): def test_apply_quantization_status(caplog, ignore, should_raise_warning): import logging - from transformers import AutoModelForCausalLM - # load a dense, unquantized tiny llama model - model_name = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T" - model = AutoModelForCausalLM.from_pretrained( - model_name, device_map="cpu", torch_dtype="auto" - ) - + model = get_tinyllama_model() quantization_config_dict = { "quant_method": "sparseml", "format": "pack-quantized", diff --git a/tests/test_quantization/test_observers/test_mse.py b/tests/test_quantization/test_observers/test_mse.py index 659fc573..098551a2 100644 --- a/tests/test_quantization/test_observers/test_mse.py +++ b/tests/test_quantization/test_observers/test_mse.py @@ -27,7 +27,7 @@ ], ) def test_mse_observer(symmetric, expected_scale, expected_zero_point): - tensor = torch.tensor([1., 1., 1., 1., 1.]) + tensor = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0]) num_bits = 8 weights = QuantizationArgs(num_bits=num_bits, symmetric=symmetric, observer="mse")