Commit d3dea3f
fix test which required accelerate, apply style (#194)
kylesayrs authored Oct 23, 2024
1 parent d3216bc commit d3dea3f
Showing 5 changed files with 6 additions and 11 deletions.
1 change: 0 additions & 1 deletion src/compressed_tensors/quantization/lifecycle/frozen.py
```diff
@@ -14,7 +14,6 @@
 
 
 from compressed_tensors.quantization.quant_config import QuantizationStatus
-from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
 from torch.nn import Module
 
 
```
2 changes: 1 addition & 1 deletion src/compressed_tensors/quantization/observers/helpers.py
```diff
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from collections import Counter
-from typing import Optional, Tuple
+from typing import Tuple
 
 import torch
 from compressed_tensors.quantization.quant_args import (
```
4 changes: 3 additions & 1 deletion src/compressed_tensors/quantization/observers/mse.py
```diff
@@ -70,7 +70,9 @@ def calculate_mse_min_max(
     absolute_min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
     absolute_max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)
 
-    best = torch.full_like(absolute_min_val, torch.finfo(absolute_min_val.dtype).max)
+    best = torch.full_like(
+        absolute_min_val, torch.finfo(absolute_min_val.dtype).max
+    )
     min_val = torch.ones_like(absolute_min_val)
     max_val = torch.zeros_like(absolute_max_val)
     for i in range(int(self.maxshrink * self.grid)):
```
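For context, the reformatted lines sit inside the MSE observer's grid search, which repeatedly shrinks the absolute min/max range and keeps the candidate range with the lowest quantization error. A minimal self-contained sketch of that search pattern (not the library's exact code; the asymmetric int8 fake-quantize step and default parameter values here are illustrative assumptions):

```python
import torch


def mse_grid_search(observed: torch.Tensor, maxshrink: float = 0.8, grid: int = 100):
    """Sketch of an MSE-based min/max search over shrinking candidate ranges."""
    absolute_min_val = torch.amin(observed)
    absolute_max_val = torch.amax(observed)

    # Seed the best error with the dtype's max, as in the diff above,
    # so the first candidate always wins the initial comparison.
    best = torch.full_like(absolute_min_val, torch.finfo(absolute_min_val.dtype).max)
    min_val = torch.ones_like(absolute_min_val)
    max_val = torch.zeros_like(absolute_max_val)

    for i in range(int(maxshrink * grid)):
        # Shrink the candidate range by a factor p in (1 - maxshrink, 1].
        p = 1 - i / grid
        cand_min = p * absolute_min_val
        cand_max = p * absolute_max_val

        # Fake-quantize with a simple asymmetric 8-bit scheme (assumed).
        scale = (cand_max - cand_min).clamp(min=1e-8) / 255
        zero_point = torch.round(-cand_min / scale)
        q = torch.clamp(torch.round(observed / scale) + zero_point, 0, 255)
        dq = (q - zero_point) * scale

        # Keep the candidate range with the lowest mean squared error.
        err = torch.mean((dq - observed) ** 2)
        if err < best:
            best, min_val, max_val = err, cand_min, cand_max

    return min_val, max_val
```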
8 changes: 1 addition & 7 deletions tests/test_quantization/lifecycle/test_apply.py
```diff
@@ -236,14 +236,8 @@ def get_sample_tinyllama_quant_config(status: str = "frozen"):
 def test_apply_quantization_status(caplog, ignore, should_raise_warning):
     import logging
 
-    from transformers import AutoModelForCausalLM
-
-    # load a dense, unquantized tiny llama model
-    model_name = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
-    model = AutoModelForCausalLM.from_pretrained(
-        model_name, device_map="cpu", torch_dtype="auto"
-    )
-
+    model = get_tinyllama_model()
     quantization_config_dict = {
         "quant_method": "sparseml",
         "format": "pack-quantized",
```
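This hunk is the "fix test which required accelerate" part of the commit: passing `device_map` to `AutoModelForCausalLM.from_pretrained` makes transformers require the accelerate package, so the inline load is replaced with a shared helper. The helper's body is not shown in this diff; a hypothetical sketch of what `get_tinyllama_model()` plausibly does, reusing the model name and dtype from the deleted lines:

```python
# Hypothetical shape of the shared test helper; the actual implementation
# lives elsewhere in the test suite and is not part of this diff.
from transformers import AutoModelForCausalLM


def get_tinyllama_model():
    # Omitting device_map avoids the transformers -> accelerate dependency;
    # the model is loaded on CPU by default.
    return AutoModelForCausalLM.from_pretrained(
        "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
        torch_dtype="auto",
    )
```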
2 changes: 1 addition & 1 deletion tests/test_quantization/test_observers/test_mse.py
```diff
@@ -27,7 +27,7 @@
     ],
 )
 def test_mse_observer(symmetric, expected_scale, expected_zero_point):
-    tensor = torch.tensor([1., 1., 1., 1., 1.])
+    tensor = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0])
     num_bits = 8
     weights = QuantizationArgs(num_bits=num_bits, symmetric=symmetric, observer="mse")
 
```
