Commit 6140bd2: address PR comments

Sara Adkins committed Apr 18, 2024
1 parent 20986a6
Showing 14 changed files with 25 additions and 35 deletions.
6 changes: 4 additions & 2 deletions examples/llama_1.1b/ex_config_quantization.py
@@ -25,6 +25,7 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator
 from torch.utils.data import DataLoader
 from sparseml.pytorch.utils import tensors_to_device
+import torch
 
 config_file = "example_quant_config.json"
 model_name = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
@@ -34,8 +35,9 @@
 max_seq_length = 1024
 pad_to_max_length = False
 output_dir = "./llama1.1b_new_quant_out"
+device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
-model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda:0")
+model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device)
 model.eval()  # no grad or updates needed for base model
 config = QuantizationConfig.parse_file(config_file)
@@ -80,4 +82,4 @@
 # SparseML in order to save the config
 from sparseml.transformers.compression import modify_save_pretrained
 modify_save_pretrained(model)
-model.save_pretrained(output_dir)
\ No newline at end of file
+model.save_pretrained(output_dir)
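
Why this change matters: a hard-coded device_map="cuda:0" makes the example crash on CPU-only machines, while the fallback keeps it runnable anywhere. A minimal standalone sketch of the pattern (same model name as the example; assumes a recent torch and transformers install):

import torch
from transformers import AutoModelForCausalLM

# Use GPU 0 when present, otherwise fall back to CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T", device_map=device
)
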
4 changes: 3 additions & 1 deletion examples/llama_1.1b/ex_sparseml_quantization.py
@@ -16,6 +16,7 @@
 from sparseml.transformers.finetune.data.data_args import DataTrainingArguments
 from sparseml.transformers.finetune.data.base import TextGenerationDataset
 from transformers import AutoTokenizer
+import torch
 
 recipe = "example_quant_recipe.yaml"
 model_name = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
@@ -25,8 +26,9 @@
 max_seq_length = 1024
 pad_to_max_length = False
 output_dir = "./llama1.1b_old_quant_out"
+device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
-model = SparseAutoModelForCausalLM.from_pretrained(model_name, device_map="cuda:0")
+model = SparseAutoModelForCausalLM.from_pretrained(model_name, device_map=device)
 
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 data_args = DataTrainingArguments(
2 changes: 1 addition & 1 deletion src/compressed_tensors/compressors/sparse_bitmask.py
@@ -17,9 +17,9 @@
 
 import numpy
 import torch
-from safetensors import safe_open
 from compressed_tensors.compressors import ModelCompressor
 from compressed_tensors.utils import get_nested_weight_mappings, merge_names
+from safetensors import safe_open
 from torch import Tensor
 from tqdm import tqdm

2 changes: 1 addition & 1 deletion src/compressed_tensors/config/base.py
@@ -14,8 +14,8 @@
 
 from typing import Optional
 
-from pydantic import BaseModel
 from compressed_tensors.registry import RegistryMixin
+from pydantic import BaseModel
 
 
 __all__ = ["CompressionConfig"]
4 changes: 3 additions & 1 deletion src/compressed_tensors/quantization/lifecycle/apply.py
@@ -16,7 +16,9 @@
 from collections import OrderedDict
 from typing import Iterable, Optional
 
-from compressed_tensors.quantization.lifecycle.calibration import set_module_for_calibration
+from compressed_tensors.quantization.lifecycle.calibration import (
+    set_module_for_calibration,
+)
 from compressed_tensors.quantization.lifecycle.frozen import freeze_module_quantization
 from compressed_tensors.quantization.lifecycle.initialize import (
     initialize_module_for_quantization,
4 changes: 3 additions & 1 deletion src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -17,7 +17,9 @@
 from typing import Optional
 
 import torch
-from compressed_tensors.quantization.lifecycle.forward import wrap_module_forward_quantized
+from compressed_tensors.quantization.lifecycle.forward import (
+    wrap_module_forward_quantized,
+)
 from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
3 changes: 2 additions & 1 deletion src/compressed_tensors/quantization/observers/helpers.py
@@ -34,6 +34,7 @@ def calculate_qparams(
     :return: tuple of the calculated scale(s) and zero point(s)
     """
     bit_range = 2**quantization_args.num_bits - 1
+    bit_min = -(bit_range + 1) / 2
     if quantization_args.symmetric:
         symmetric_range = 2 * max(min_vals.abs(), max_vals.abs())
         scales = symmetric_range / bit_range
@@ -46,6 +47,6 @@
         # scales from a 0 range should be set to 1
         scales[observed_range == 0] = 1
 
-        zero_points = ((0 - min_vals) / scales).to(torch.int8)
+        zero_points = ((0 - min_vals) / scales + bit_min).to(torch.int8)
 
     return scales, zero_points
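
A worked example of the new zero-point math (a sketch with assumed observed values, not part of the commit): for num_bits=8, bit_range is 255 and bit_min is -128, so the + bit_min shift maps the zero point from the unsigned range [0, 255] into int8's [-128, 127]:

import torch

num_bits = 8
bit_range = 2**num_bits - 1     # 255
bit_min = -(bit_range + 1) / 2  # -128.0

min_vals = torch.tensor([-1.0])  # assumed observed minimum
max_vals = torch.tensor([3.0])   # assumed observed maximum

observed_range = max_vals - min_vals  # 4.0
scales = observed_range / bit_range   # ~0.0157
scales[observed_range == 0] = 1       # zero ranges get scale 1
zero_points = ((0 - min_vals) / scales + bit_min).to(torch.int8)
print(zero_points)  # tensor([-64]); without bit_min this would be 63, and a
                    # minimum near -range would map above 127 and overflow int8
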
23 changes: 2 additions & 21 deletions src/compressed_tensors/quantization/observers/min_max.py
@@ -43,10 +43,7 @@ def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
         :param observed: observed tensor to calculate quantization parameters for
         :return: tuple of scale and zero point derived from the observed tensor
         """
-        # TODO: Add support for full range of quantization Args, only supports 8bit
-        # per tensor
-        bit_min = -128
-        bit_max = 127
+
         min_val = torch.tensor([observed.min()])
         max_val = torch.tensor([observed.max()])
 
@@ -63,20 +60,4 @@ def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
         max_val = torch.max(self.max_val, torch.zeros_like(self.max_val))
 
         self.counter += 1
-
-        if self.quantization_args.symmetric:
-            symmetric_range = 2 * max(min_val.abs(), max_val.abs())
-            scale = symmetric_range / (bit_max - bit_min)
-            zero_point = torch.tensor(0).to(torch.int8)
-        else:
-            # non-symmetric
-            observed_range = max_val - min_val
-            quantized_range = bit_max - bit_min
-            scale = observed_range / (quantized_range)
-
-            # scales from a 0 range should be set to 1
-            scale[observed_range == 0] = 1
-
-            zero_point = ((0 - min_val) / scale + bit_min).to(torch.int8)
-
-        return scale, zero_point
+        return calculate_qparams(min_val, max_val, self.quantization_args)
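
Why this deletion is safe: the removed branch duplicated the shared helper while hard-coding the int8 range, whereas calculate_qparams derives the range from quantization_args.num_bits. A hedged usage sketch (import path inferred from the file layout above; QuantizationArgs field names assumed from the diffs, not verified against the full source):

import torch
from compressed_tensors.quantization import QuantizationArgs
from compressed_tensors.quantization.observers.helpers import calculate_qparams

observed = torch.randn(64, 64)
# Mirror the observer's clamping so min <= 0 and max >= 0.
min_val = torch.tensor([observed.min()])
max_val = torch.tensor([observed.max()])
min_val = torch.min(min_val, torch.zeros_like(min_val))
max_val = torch.max(max_val, torch.zeros_like(max_val))

args = QuantizationArgs(num_bits=8, symmetric=False)
scale, zero_point = calculate_qparams(min_val, max_val, args)
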
2 changes: 1 addition & 1 deletion src/compressed_tensors/quantization/quant_config.py
@@ -15,14 +15,14 @@
 from enum import Enum
 from typing import Dict, List, Optional
 
-from pydantic import BaseModel, Field
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from compressed_tensors.quantization.utils import (
     calculate_compression_ratio,
     is_module_quantized,
     iter_named_leaf_modules,
     module_type,
 )
+from pydantic import BaseModel, Field
 from torch.nn import Module
 
 
2 changes: 1 addition & 1 deletion src/compressed_tensors/quantization/quant_scheme.py
@@ -14,8 +14,8 @@
 
 from typing import List, Optional
 
-from pydantic import BaseModel
 from compressed_tensors.quantization.quant_args import QuantizationArgs
+from pydantic import BaseModel
 
 
 __all__ = ["QuantizationScheme"]
2 changes: 1 addition & 1 deletion tests/quantization/test_quant_args.py
@@ -13,12 +13,12 @@
 # limitations under the License.
 
 import pytest
-from pydantic import ValidationError
 from compressed_tensors.quantization import (
     QuantizationArgs,
     QuantizationStrategy,
     QuantizationType,
 )
+from pydantic import ValidationError
 
 
 def test_defaults():
2 changes: 1 addition & 1 deletion tests/quantization/test_quant_config.py
@@ -14,12 +14,12 @@
 
 
 import pytest
-from pydantic import ValidationError
 from compressed_tensors.quantization import (
     QuantizationConfig,
     QuantizationScheme,
     QuantizationStatus,
 )
+from pydantic import ValidationError
 
 
 def test_basic_config():
2 changes: 1 addition & 1 deletion tests/quantization/test_quant_scheme.py
@@ -13,8 +13,8 @@
 # limitations under the License.
 
 import pytest
-from pydantic import ValidationError
 from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
+from pydantic import ValidationError
 
 
 def test_basic_scheme():
2 changes: 1 addition & 1 deletion tests/test_bitmask.py
@@ -17,8 +17,8 @@
 
 import pytest
 import torch
-from safetensors.torch import save_file
 from compressed_tensors import BitmaskCompressor, BitmaskConfig, BitmaskTensor
+from safetensors.torch import save_file
 
 
 @pytest.mark.parametrize(
