From 6140bd212d4245093a7fd091f7e5dca5805f5790 Mon Sep 17 00:00:00 2001
From: Sara Adkins
Date: Thu, 18 Apr 2024 18:26:24 +0000
Subject: [PATCH] address PR comments

---
 examples/llama_1.1b/ex_config_quantization.py   |  6 +++--
 .../llama_1.1b/ex_sparseml_quantization.py      |  4 +++-
 .../compressors/sparse_bitmask.py               |  2 +-
 src/compressed_tensors/config/base.py           |  2 +-
 .../quantization/lifecycle/apply.py             |  4 +++-
 .../quantization/lifecycle/initialize.py        |  4 +++-
 .../quantization/observers/helpers.py           |  3 ++-
 .../quantization/observers/min_max.py           | 23 ++-----------------
 .../quantization/quant_config.py                |  2 +-
 .../quantization/quant_scheme.py                |  2 +-
 tests/quantization/test_quant_args.py           |  2 +-
 tests/quantization/test_quant_config.py         |  2 +-
 tests/quantization/test_quant_scheme.py         |  2 +-
 tests/test_bitmask.py                           |  2 +-
 14 files changed, 25 insertions(+), 35 deletions(-)

diff --git a/examples/llama_1.1b/ex_config_quantization.py b/examples/llama_1.1b/ex_config_quantization.py
index 131ba8ba..4dfbc951 100644
--- a/examples/llama_1.1b/ex_config_quantization.py
+++ b/examples/llama_1.1b/ex_config_quantization.py
@@ -25,6 +25,7 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator
 from torch.utils.data import DataLoader
 from sparseml.pytorch.utils import tensors_to_device
+import torch
 
 config_file = "example_quant_config.json"
 model_name = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
@@ -34,8 +35,9 @@
 max_seq_length = 1024
 pad_to_max_length = False
 output_dir = "./llama1.1b_new_quant_out"
+device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
-model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda:0")
+model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device)
 model.eval()  # no grad or updates needed for base model
 
 config = QuantizationConfig.parse_file(config_file)
@@ -80,4 +82,4 @@
 # SparseML in order to save the config
 from sparseml.transformers.compression import modify_save_pretrained
 modify_save_pretrained(model)
-model.save_pretrained(output_dir)
\ No newline at end of file
+model.save_pretrained(output_dir)
diff --git a/examples/llama_1.1b/ex_sparseml_quantization.py b/examples/llama_1.1b/ex_sparseml_quantization.py
index 655ea0f2..5ef492b1 100644
--- a/examples/llama_1.1b/ex_sparseml_quantization.py
+++ b/examples/llama_1.1b/ex_sparseml_quantization.py
@@ -16,6 +16,7 @@
 from sparseml.transformers.finetune.data.data_args import DataTrainingArguments
 from sparseml.transformers.finetune.data.base import TextGenerationDataset
 from transformers import AutoTokenizer
+import torch
 
 recipe = "example_quant_recipe.yaml"
 model_name = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
@@ -25,8 +26,9 @@
 max_seq_length = 1024
 pad_to_max_length = False
 output_dir = "./llama1.1b_old_quant_out"
+device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
-model = SparseAutoModelForCausalLM.from_pretrained(model_name, device_map="cuda:0")
+model = SparseAutoModelForCausalLM.from_pretrained(model_name, device_map=device)
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 data_args = DataTrainingArguments(
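
Both example scripts now select their device at runtime instead of assuming "cuda:0" is present (note the correct PyTorch spelling is torch.cuda.is_available(), with a dot). A minimal standalone sketch of the fallback pattern used above:

    import torch

    # Fall back to CPU when CUDA is absent so the example still runs.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    print(f"loading model on {device}")
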
diff --git a/src/compressed_tensors/compressors/sparse_bitmask.py b/src/compressed_tensors/compressors/sparse_bitmask.py
index 9edf0b69..f6f03f0b 100644
--- a/src/compressed_tensors/compressors/sparse_bitmask.py
+++ b/src/compressed_tensors/compressors/sparse_bitmask.py
@@ -17,9 +17,9 @@
 
 import numpy
 import torch
-from safetensors import safe_open
 from compressed_tensors.compressors import ModelCompressor
 from compressed_tensors.utils import get_nested_weight_mappings, merge_names
+from safetensors import safe_open
 from torch import Tensor
 from tqdm import tqdm
 
diff --git a/src/compressed_tensors/config/base.py b/src/compressed_tensors/config/base.py
index 552c10a1..f58b11f8 100644
--- a/src/compressed_tensors/config/base.py
+++ b/src/compressed_tensors/config/base.py
@@ -14,8 +14,8 @@
 
 from typing import Optional
 
-from pydantic import BaseModel
 from compressed_tensors.registry import RegistryMixin
+from pydantic import BaseModel
 
 
 __all__ = ["CompressionConfig"]
diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py
index 77c5245b..08cb42f9 100644
--- a/src/compressed_tensors/quantization/lifecycle/apply.py
+++ b/src/compressed_tensors/quantization/lifecycle/apply.py
@@ -16,7 +16,9 @@
 from collections import OrderedDict
 from typing import Iterable, Optional
 
-from compressed_tensors.quantization.lifecycle.calibration import set_module_for_calibration
+from compressed_tensors.quantization.lifecycle.calibration import (
+    set_module_for_calibration,
+)
 from compressed_tensors.quantization.lifecycle.frozen import freeze_module_quantization
 from compressed_tensors.quantization.lifecycle.initialize import (
     initialize_module_for_quantization,
diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py
index d78997c1..4ef6379b 100644
--- a/src/compressed_tensors/quantization/lifecycle/initialize.py
+++ b/src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -17,7 +17,9 @@
 from typing import Optional
 
 import torch
-from compressed_tensors.quantization.lifecycle.forward import wrap_module_forward_quantized
+from compressed_tensors.quantization.lifecycle.forward import (
+    wrap_module_forward_quantized,
+)
 from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
diff --git a/src/compressed_tensors/quantization/observers/helpers.py b/src/compressed_tensors/quantization/observers/helpers.py
index 3fd2f4b6..bc43bbec 100644
--- a/src/compressed_tensors/quantization/observers/helpers.py
+++ b/src/compressed_tensors/quantization/observers/helpers.py
@@ -34,6 +34,7 @@ def calculate_qparams(
     :return: tuple of the calculated scale(s) and zero point(s)
     """
     bit_range = 2**quantization_args.num_bits - 1
+    bit_min = -(bit_range + 1) / 2
     if quantization_args.symmetric:
         symmetric_range = 2 * max(min_vals.abs(), max_vals.abs())
         scales = symmetric_range / bit_range
@@ -46,6 +47,6 @@ def calculate_qparams(
         # scales from a 0 range should be set to 1
         scales[observed_range == 0] = 1
 
-    zero_points = ((0 - min_vals) / scales).to(torch.int8)
+    zero_points = ((0 - min_vals) / scales + bit_min).to(torch.int8)
 
     return scales, zero_points
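
The helpers.py change stops hard-coding int8 bounds: bit_min is derived from num_bits as -(2**num_bits) / 2 (i.e. -128 for 8 bits) and added into the asymmetric zero point. A standalone re-derivation with plain Python floats, illustrative only; it uses integer division for bit_min where the patch's true division yields -128.0:

    # Plain-float sketch of the asymmetric scale/zero-point math; names
    # mirror the patch but this is not the library's tensor implementation.
    def qparams_asymmetric(min_val: float, max_val: float, num_bits: int = 8):
        bit_range = 2**num_bits - 1      # 255 for 8 bits
        bit_min = -(bit_range + 1) // 2  # -128 for 8 bits
        observed_range = max_val - min_val
        # a zero observed range gets scale 1, matching the patched helper
        scale = observed_range / bit_range if observed_range != 0 else 1.0
        zero_point = int((0 - min_val) / scale + bit_min)
        return scale, zero_point

    # Values observed in [-0.5, 1.5]: scale = 2.0 / 255 ~= 0.00784 and
    # zero_point = int(63.75 - 128) = -64, so real 0.0 quantizes to -64.
    print(qparams_asymmetric(-0.5, 1.5))
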
diff --git a/src/compressed_tensors/quantization/observers/min_max.py b/src/compressed_tensors/quantization/observers/min_max.py
index 4d477fcd..3496bb77 100644
--- a/src/compressed_tensors/quantization/observers/min_max.py
+++ b/src/compressed_tensors/quantization/observers/min_max.py
@@ -43,10 +43,7 @@ def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
         :param observed: observed tensor to calculate quantization parameters for
         :return: tuple of scale and zero point derived from the observed tensor
         """
-        # TODO: Add support for full range of quantization Args, only supports 8bit
-        #       per tensor
-        bit_min = -128
-        bit_max = 127
+
         min_val = torch.tensor([observed.min()])
         max_val = torch.tensor([observed.max()])
 
@@ -63,20 +60,4 @@ def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
             max_val = torch.max(self.max_val, torch.zeros_like(self.max_val))
 
         self.counter += 1
-
-        if self.quantization_args.symmetric:
-            symmetric_range = 2 * max(min_val.abs(), max_val.abs())
-            scale = symmetric_range / (bit_max - bit_min)
-            zero_point = torch.tensor(0).to(torch.int8)
-        else:
-            # non-symmetric
-            observed_range = max_val - min_val
-            quantized_range = bit_max - bit_min
-            scale = observed_range / (quantized_range)
-
-            # scales from a 0 range should be set to 1
-            scale[observed_range == 0] = 1
-
-            zero_point = ((0 - min_val) / scale + bit_min).to(torch.int8)
-
-        return scale, zero_point
+        return calculate_qparams(min_val, max_val, self.quantization_args)
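
After this deletion, the observer's only remaining job is tracking a running range and deferring the shared math to calculate_qparams in helpers.py. A simplified stand-in for that tracking behavior, not the actual observer class:

    import torch

    # Simplified stand-in: remember the smallest min / largest max seen so
    # far, and clamp the reported range to include zero, as the patched
    # observer appears to do for max_val in the hunk above.
    class RunningMinMax:
        def __init__(self):
            self.min_val = None
            self.max_val = None

        def observe(self, observed: torch.Tensor):
            batch_min = torch.tensor([observed.min()])
            batch_max = torch.tensor([observed.max()])
            if self.min_val is None:
                self.min_val, self.max_val = batch_min, batch_max
            else:
                self.min_val = torch.min(self.min_val, batch_min)
                self.max_val = torch.max(self.max_val, batch_max)
            return (
                torch.min(self.min_val, torch.zeros_like(self.min_val)),
                torch.max(self.max_val, torch.zeros_like(self.max_val)),
            )
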
diff --git a/src/compressed_tensors/quantization/quant_config.py b/src/compressed_tensors/quantization/quant_config.py
index 7214bc83..a62a79bd 100644
--- a/src/compressed_tensors/quantization/quant_config.py
+++ b/src/compressed_tensors/quantization/quant_config.py
@@ -15,7 +15,6 @@
 from enum import Enum
 from typing import Dict, List, Optional
 
-from pydantic import BaseModel, Field
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from compressed_tensors.quantization.utils import (
     calculate_compression_ratio,
@@ -23,6 +22,7 @@
     iter_named_leaf_modules,
     module_type,
 )
+from pydantic import BaseModel, Field
 from torch.nn import Module
 
 
diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py
index c083fc55..ed0f8245 100644
--- a/src/compressed_tensors/quantization/quant_scheme.py
+++ b/src/compressed_tensors/quantization/quant_scheme.py
@@ -14,8 +14,8 @@
 
 from typing import List, Optional
 
-from pydantic import BaseModel
 from compressed_tensors.quantization.quant_args import QuantizationArgs
+from pydantic import BaseModel
 
 
 __all__ = ["QuantizationScheme"]
diff --git a/tests/quantization/test_quant_args.py b/tests/quantization/test_quant_args.py
index 46a4cc49..c1c84be6 100644
--- a/tests/quantization/test_quant_args.py
+++ b/tests/quantization/test_quant_args.py
@@ -13,12 +13,12 @@
 # limitations under the License.
 
 import pytest
-from pydantic import ValidationError
 from compressed_tensors.quantization import (
     QuantizationArgs,
     QuantizationStrategy,
     QuantizationType,
 )
+from pydantic import ValidationError
 
 
 def test_defaults():
diff --git a/tests/quantization/test_quant_config.py b/tests/quantization/test_quant_config.py
index 68688c36..091be723 100644
--- a/tests/quantization/test_quant_config.py
+++ b/tests/quantization/test_quant_config.py
@@ -14,12 +14,12 @@
 
 import pytest
-from pydantic import ValidationError
 from compressed_tensors.quantization import (
     QuantizationConfig,
     QuantizationScheme,
     QuantizationStatus,
 )
+from pydantic import ValidationError
 
 
 def test_basic_config():
diff --git a/tests/quantization/test_quant_scheme.py b/tests/quantization/test_quant_scheme.py
index 1c198812..14ba9f7e 100644
--- a/tests/quantization/test_quant_scheme.py
+++ b/tests/quantization/test_quant_scheme.py
@@ -13,8 +13,8 @@
 # limitations under the License.
 
 import pytest
-from pydantic import ValidationError
 from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
+from pydantic import ValidationError
 
 
 def test_basic_scheme():
diff --git a/tests/test_bitmask.py b/tests/test_bitmask.py
index 28d29ed0..248580bc 100644
--- a/tests/test_bitmask.py
+++ b/tests/test_bitmask.py
@@ -17,8 +17,8 @@
 
 import pytest
 import torch
-from safetensors.torch import save_file
 from compressed_tensors import BitmaskCompressor, BitmaskConfig, BitmaskTensor
+from safetensors.torch import save_file
 
 
 @pytest.mark.parametrize(
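
For context on the bitmask tests touched last: bitmask compression stores a sparse tensor as its nonzero values plus one packed bit per element recording where they belong. A conceptual round-trip sketch of that idea, not the BitmaskCompressor / BitmaskTensor API itself:

    import numpy
    import torch

    # Conceptual sketch only: pack the sparsity mask to 1 bit per element
    # and keep just the nonzero values (row-major order).
    def bitmask_compress(dense: torch.Tensor):
        mask = dense != 0
        values = dense[mask]
        bitmask = numpy.packbits(mask.flatten().numpy())
        return values, bitmask, dense.shape

    def bitmask_decompress(values, bitmask, shape):
        numel = int(numpy.prod(shape))
        mask = numpy.unpackbits(bitmask)[:numel].reshape(shape).astype(bool)
        dense = torch.zeros(shape, dtype=values.dtype)
        dense[torch.from_numpy(mask)] = values  # scatter values back
        return dense

    # Round trip on a mostly-zero tensor
    t = torch.tensor([[0.0, 1.5, 0.0], [2.5, 0.0, 0.0]])
    assert torch.equal(bitmask_decompress(*bitmask_compress(t)), t)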