From f5eebb70664bf4dfdfd66db3da0dc1eea10dfd33 Mon Sep 17 00:00:00 2001
From: Sara Adkins
Date: Wed, 17 Apr 2024 18:04:07 +0000
Subject: [PATCH 01/10] testing fixes

---
 src/sparsetensors/quantization/lifecycle/forward.py | 1 -
 src/sparsetensors/quantization/observers/min_max.py | 6 ++++--
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/sparsetensors/quantization/lifecycle/forward.py b/src/sparsetensors/quantization/lifecycle/forward.py
index 6416a10b..9388bd1a 100644
--- a/src/sparsetensors/quantization/lifecycle/forward.py
+++ b/src/sparsetensors/quantization/lifecycle/forward.py
@@ -114,7 +114,6 @@ def _maybe_calibrate_or_quantize(
     device = next(module.parameters()).device
 
     scale = getattr(module, f"{base_name}_scale")
-    # zero_point = getattr(module, f"{base_name}_zero_point").data
     zero_point = getattr(module, f"{base_name}_zero_point")
 
     if module.quantization_status == QuantizationStatus.CALIBRATION:
diff --git a/src/sparsetensors/quantization/observers/min_max.py b/src/sparsetensors/quantization/observers/min_max.py
index e73805b4..6b012918 100644
--- a/src/sparsetensors/quantization/observers/min_max.py
+++ b/src/sparsetensors/quantization/observers/min_max.py
@@ -50,8 +50,10 @@ def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
 
         # update running average
         if self.counter > 0:
-            self.min_val = (self.min_val * self.counter + min_val) / (self.counter + 1)
-            self.max_val = (self.max_val * self.counter + max_val) / (self.counter + 1)
+            # self.min_val = (self.min_val * self.counter + min_val) / (self.counter + 1)
+            # self.max_val = (self.max_val * self.counter + max_val) / (self.counter + 1)
+            self.min_val = torch.min(min_val, self.min_val)
+            self.max_val = torch.max(max_val, self.max_val)
         else:
             self.min_val = min_val
             self.max_val = max_val

From afbd617bd3505970b73e79715156fe0231c0476f Mon Sep 17 00:00:00 2001
From: Sara Adkins
Date: Wed, 17 Apr 2024 19:04:19 +0000
Subject: [PATCH 02/10] fix clamp range

---
 src/sparsetensors/quantization/lifecycle/forward.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/sparsetensors/quantization/lifecycle/forward.py b/src/sparsetensors/quantization/lifecycle/forward.py
index 9388bd1a..08aff33e 100644
--- a/src/sparsetensors/quantization/lifecycle/forward.py
+++ b/src/sparsetensors/quantization/lifecycle/forward.py
@@ -31,12 +31,13 @@ def quantize(
     zero_point: torch.Tensor,
     q_max: torch.Tensor,
 ) -> torch.Tensor:
+    # TODO: don't hardcode these, will change for other bit-depths
     return torch.clamp(
         torch.round(
             x / scale + zero_point,
         ),
-        0,
-        q_max,
+        -128,
+        127,
     )
 
 

From 0cce0183ad8aa1a7fc65fb28de5e4b2c786d5044 Mon Sep 17 00:00:00 2001
From: Sara Adkins
Date: Wed, 17 Apr 2024 19:49:19 +0000
Subject: [PATCH 03/10] clean up fixes

---
 src/sparsetensors/quantization/lifecycle/forward.py | 11 ++++++-----
 src/sparsetensors/quantization/observers/min_max.py |  6 ++----
 2 files changed, 8 insertions(+), 9 deletions(-)

diff --git a/src/sparsetensors/quantization/lifecycle/forward.py b/src/sparsetensors/quantization/lifecycle/forward.py
index 08aff33e..5a37b226 100644
--- a/src/sparsetensors/quantization/lifecycle/forward.py
+++ b/src/sparsetensors/quantization/lifecycle/forward.py
@@ -29,15 +29,15 @@ def quantize(
     x: torch.Tensor,
     scale: torch.Tensor,
     zero_point: torch.Tensor,
+    q_min: torch.Tensor,
     q_max: torch.Tensor,
 ) -> torch.Tensor:
-    # TODO: don't hardcode these, will change for other bit-depths
     return torch.clamp(
         torch.round(
             x / scale + zero_point,
         ),
-        -128,
-        127,
+        q_min,
+        q_max,
     )
 
 
@@ -57,9 +57,10 @@ def fake_quantize(
     zero_point: torch.Tensor,
     args: QuantizationArgs,
 ) -> torch.Tensor:
-    max_q = torch.tensor(2**args.num_bits - 1, device=x.device)
+    max_q = torch.tensor((2**args.num_bits) / 2 - 1, device=x.device)
+    min_q = torch.tensor(max_q - 2**args.num_bits, device=x.device)
     Q = torch.zeros_like(x)
-    Q = quantize(x, scale, zero_point, max_q)
+    Q = quantize(x, scale, zero_point, min_q, max_q)
     return dequantize(Q, scale, zero_point)
diff --git a/src/sparsetensors/quantization/observers/min_max.py b/src/sparsetensors/quantization/observers/min_max.py
index 6b012918..2e9ca5d4 100644
--- a/src/sparsetensors/quantization/observers/min_max.py
+++ b/src/sparsetensors/quantization/observers/min_max.py
@@ -27,7 +27,7 @@
 class MinMaxObserver(Observer):
     """
     Implements a dynamic quantization observer that sets the scale and
-    zero point based on the latest observed value
+    zero point based on the overall min and max value
     """
 
     def __init__(self, quantization_args: QuantizationArgs):
@@ -48,10 +48,8 @@ def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
         min_val = torch.tensor([observed.min()])
         max_val = torch.tensor([observed.max()])
 
-        # update running average
+        # update global min and max
        if self.counter > 0:
-            # self.min_val = (self.min_val * self.counter + min_val) / (self.counter + 1)
-            # self.max_val = (self.max_val * self.counter + max_val) / (self.counter + 1)
             self.min_val = torch.min(min_val, self.min_val)
             self.max_val = torch.max(max_val, self.max_val)
         else:

From c48c9c78e94cd7c989ad9547466047c4c90aa228 Mon Sep 17 00:00:00 2001
From: Sara Adkins
Date: Wed, 17 Apr 2024 20:36:41 +0000
Subject: [PATCH 04/10] fixing symmetry issue

---
 src/sparsetensors/quantization/lifecycle/forward.py |  5 +++--
 src/sparsetensors/quantization/observers/min_max.py | 10 ++++++----
 2 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/src/sparsetensors/quantization/lifecycle/forward.py b/src/sparsetensors/quantization/lifecycle/forward.py
index 5a37b226..23e0cb3e 100644
--- a/src/sparsetensors/quantization/lifecycle/forward.py
+++ b/src/sparsetensors/quantization/lifecycle/forward.py
@@ -57,8 +57,9 @@ def fake_quantize(
     zero_point: torch.Tensor,
     args: QuantizationArgs,
 ) -> torch.Tensor:
-    max_q = torch.tensor((2**args.num_bits) / 2 - 1, device=x.device)
-    min_q = torch.tensor(max_q - 2**args.num_bits, device=x.device)
+    bit_range = 2**args.num_bits
+    max_q = torch.tensor(bit_range / 2 - 1, device=x.device)
+    min_q = torch.tensor(-bit_range / 2, device=x.device)
     Q = torch.zeros_like(x)
     Q = quantize(x, scale, zero_point, min_q, max_q)
     return dequantize(Q, scale, zero_point)
diff --git a/src/sparsetensors/quantization/observers/min_max.py b/src/sparsetensors/quantization/observers/min_max.py
index 2e9ca5d4..fb1781e8 100644
--- a/src/sparsetensors/quantization/observers/min_max.py
+++ b/src/sparsetensors/quantization/observers/min_max.py
@@ -44,7 +44,8 @@ def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
         """
         # TODO: Add support for full range of quantization Args, only supports 8bit
         # per tensor
-        bit_range = 255
+        bit_min = -128
+        bit_max = 127
         min_val = torch.tensor([observed.min()])
         max_val = torch.tensor([observed.max()])
@@ -64,16 +65,17 @@ def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
 
         if self.quantization_args.symmetric:
             symmetric_range = 2 * max(min_val.abs(), max_val.abs())
-            scale = symmetric_range / bit_range
+            scale = symmetric_range / (bit_max - bit_min)
             zero_point = torch.tensor(0).to(torch.int8)
         else:
             # non-symmetric
             observed_range = max_val - min_val
-            scale = observed_range / bit_range
+            quantized_range = bit_max - bit_min
+            scale = observed_range / (quantized_range)
 
             # scales from a 0 range should be set to 1
             scale[observed_range == 0] = 1
-            zero_point = ((0 - min_val) / scale).to(torch.int8)
+            zero_point = ((0 - min_val) / scale + bit_min).to(torch.int8)
 
         return scale, zero_point

From a89efcc8687e48bd5f55b6bdae6fabab57cdfd08 Mon Sep 17 00:00:00 2001
From: Sara Adkins
Date: Wed, 17 Apr 2024 21:05:49 +0000
Subject: [PATCH 05/10] update examples

---
 examples/llama_1.1b/ex_config_quantization.py | 77 +++++++++++++++++++
 .../llama_1.1b/ex_sparseml_quantization.py    | 33 +++++---
 examples/llama_1.1b/example_quant_config.json | 13 +---
 examples/llama_1.1b/example_quant_recipe.yaml | 32 ++++++++
 4 files changed, 136 insertions(+), 19 deletions(-)
 create mode 100644 examples/llama_1.1b/ex_config_quantization.py

diff --git a/examples/llama_1.1b/ex_config_quantization.py b/examples/llama_1.1b/ex_config_quantization.py
new file mode 100644
index 00000000..175ca0b9
--- /dev/null
+++ b/examples/llama_1.1b/ex_config_quantization.py
@@ -0,0 +1,77 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from tqdm import tqdm
+
+from compressed_tensors.quantization import (
+    apply_quantization_config,
+    freeze_module_quantization,
+    QuantizationConfig,
+    QuantizationStatus,
+)
+from sparseml.transformers.finetune.data.data_args import DataTrainingArguments
+from sparseml.transformers.finetune.data.base import TextGenerationDataset
+from transformers import AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator
+from torch.utils.data import DataLoader
+
+
+config_file = "example_quant_config.json"
+model_name = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
+dataset_name = "open_platypus"
+split = "train"
+num_calibration_samples = 512
+max_seq_length = 1024
+pad_to_max_length = False
+output_dir = "./llama1.1b_new_quant_out"
+
+model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda:0")
+model.eval()  # no grad or updates needed for base model
+config = QuantizationConfig.parse_file(config_file)
+
+# set status to calibration
+config.quantization_status = QuantizationStatus.CALIBRATION
+
+# initialize quantization
+apply_quantization_config(model, config)
+
+# create dataset
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+data_args = DataTrainingArguments(
+    dataset=dataset_name,
+    max_seq_length=max_seq_length,
+    pad_to_max_length=pad_to_max_length,
+)
+dataset_manager = TextGenerationDataset.load_from_registry(
+    data_args.dataset,
+    data_args=data_args,
+    split=split,
+    tokenizer=tokenizer,
+)
+calib_dataset = dataset_manager.tokenize_and_process(
+    dataset_manager.get_raw_dataset()
+)
+data_loader = DataLoader(
+    calib_dataset, batch_size=1, collate_fn=DefaultDataCollator()
+)
+
+# run calibration
+for idx, sample in tqdm(enumerate(data_loader)):
+    _ = model(**sample)
+
+    if idx >= num_calibration_samples:
+        break
+
+# freeze params after calibration
+model.apply(freeze_module_quantization)
+model.save_pretrained(output_dir)
\ No newline at end of file
diff --git a/examples/llama_1.1b/ex_sparseml_quantization.py b/examples/llama_1.1b/ex_sparseml_quantization.py
index 3c66f5e8..655ea0f2 100644
--- a/examples/llama_1.1b/ex_sparseml_quantization.py
+++ b/examples/llama_1.1b/ex_sparseml_quantization.py
@@ -13,28 +13,43 @@
 # limitations under the License.
 
 from sparseml.transformers import oneshot, SparseAutoModelForCausalLM
+from sparseml.transformers.finetune.data.data_args import DataTrainingArguments
+from sparseml.transformers.finetune.data.base import TextGenerationDataset
+from transformers import AutoTokenizer
 
-dataset_name = "open_platypus"
-overwrite_output_dir = True
-splits = {"calibration": "train"}
-seed = 42
-output_dir = "./llama_1.1b_quant_mod_only"
-num_calibration_samples = 1024
 recipe = "example_quant_recipe.yaml"
 model_name = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
+dataset_name = "open_platypus"
+split = "train"
+num_calibration_samples = 512
 max_seq_length = 1024
 pad_to_max_length = False
+output_dir = "./llama1.1b_old_quant_out"
 
 model = SparseAutoModelForCausalLM.from_pretrained(model_name, device_map="cuda:0")
 
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+data_args = DataTrainingArguments(
+    dataset=dataset_name,
+    max_seq_length=max_seq_length,
+    pad_to_max_length=pad_to_max_length,
+)
+dataset_manager = TextGenerationDataset.load_from_registry(
+    data_args.dataset,
+    data_args=data_args,
+    split=split,
+    tokenizer=tokenizer,
+)
+calib_dataset = dataset_manager.tokenize_and_process(
+    dataset_manager.get_raw_dataset()
+)
+
 oneshot(
     model=model_name,
     dataset=dataset_name,
     output_dir=output_dir,
-    overwrite_output_dir=overwrite_output_dir,
-    splits = splits,
+    overwrite_output_dir=True,
     max_seq_length = max_seq_length,
-    seed=seed,
     num_calibration_samples=num_calibration_samples,
     recipe=recipe,
     pad_to_max_length=pad_to_max_length
diff --git a/examples/llama_1.1b/example_quant_config.json b/examples/llama_1.1b/example_quant_config.json
index 65d15740..969cea76 100644
--- a/examples/llama_1.1b/example_quant_config.json
+++ b/examples/llama_1.1b/example_quant_config.json
@@ -1,7 +1,6 @@
 {
     "quant_method": "sparseml",
     "format": "fakequant",
-    "quantization_status": "frozen",
     "global_compression_ratio": null,
     "config_groups": {
         "group_1": {
@@ -14,7 +13,7 @@
             "input_activations": {
                 "num_bits": 8,
                 "type": "int",
-                "symmetric": true,
+                "symmetric": false,
                 "strategy": "tensor"
             },
             "targets": ["Linear"]
@@ -23,17 +22,11 @@
             "weights": {
                 "num_bits": 8,
                 "type": "int",
-                "symmetric": false,
+                "symmetric": true,
                 "strategy": "tensor"
             },
-            "input_activations": null,
             "targets": ["Embedding"]
         }
     },
-    "ignore": [
-        "LlamaRotaryEmbedding", "LlamaRMSNorm", "SiLUActivation",
-        "model.layers.1.mlp.down_proj", "MatMulLeftInput_QK", "MatMulRightInput_QK",
-        "MatMulOutput_QK", "MatMulLeftInput_PV", "MatMulRightInput_PV",
-        "MatMulOutput_PV"
-    ]
+    "ignore": ["model.layers.0.mlp.down_proj"]
 }
\ No newline at end of file
diff --git a/examples/llama_1.1b/example_quant_recipe.yaml b/examples/llama_1.1b/example_quant_recipe.yaml
index e69de29b..c94e9285 100644
--- a/examples/llama_1.1b/example_quant_recipe.yaml
+++ b/examples/llama_1.1b/example_quant_recipe.yaml
@@ -0,0 +1,32 @@
+test_stage:
+  quant_modifiers:
+    QuantizationModifier:
+      ignore:
+        - model.layers.0.mlp.down_proj
+        - LlamaRotaryEmbedding
+        - LlamaRMSNorm
+        - SiLU
+        - MatMulLeftInput_QK
+        - MatMulRightInput_QK
+        - MatMulOutput_QK
+        - MatMulLeftInput_PV
+        - MatMulRightInput_PV
+        - MatMulOutput_PV
+      scheme_overrides:
+        Linear:
+          weights:
+            num_bits: 8
+            symmetric: true
+            strategy: "tensor"
+          input_activations:
+            num_bits: 8
+            symmetric: false
+            strategy: "tensor"
+          output_activations: null
+        Embedding:
+          weights:
+            num_bits: 8
+            symmetric: true
+            strategy: "tensor"
+          input_activations: null
+          output_activations: null
\ No newline at end of file

From fd9545d294ff82fb9a3152dbfeef59deaf99bd90 Mon Sep 17 00:00:00 2001
From: Benjamin Fineran
Date: Wed, 17 Apr 2024 18:07:53 -0400
Subject: [PATCH 06/10] fix style post rename PR (#25)

---
 src/compressed_tensors/compressors/sparse_bitmask.py        | 2 +-
 src/compressed_tensors/config/base.py                       | 2 +-
 src/compressed_tensors/quantization/lifecycle/apply.py      | 4 +++-
 src/compressed_tensors/quantization/lifecycle/initialize.py | 4 +++-
 src/compressed_tensors/quantization/quant_config.py         | 2 +-
 src/compressed_tensors/quantization/quant_scheme.py         | 2 +-
 tests/quantization/test_quant_args.py                       | 2 +-
 tests/quantization/test_quant_config.py                     | 2 +-
 tests/quantization/test_quant_scheme.py                     | 2 +-
 tests/test_bitmask.py                                       | 2 +-
 10 files changed, 14 insertions(+), 10 deletions(-)

diff --git a/src/compressed_tensors/compressors/sparse_bitmask.py b/src/compressed_tensors/compressors/sparse_bitmask.py
index 9edf0b69..f6f03f0b 100644
--- a/src/compressed_tensors/compressors/sparse_bitmask.py
+++ b/src/compressed_tensors/compressors/sparse_bitmask.py
@@ -17,9 +17,9 @@
 
 import numpy
 import torch
-from safetensors import safe_open
 from compressed_tensors.compressors import ModelCompressor
 from compressed_tensors.utils import get_nested_weight_mappings, merge_names
+from safetensors import safe_open
 from torch import Tensor
 from tqdm import tqdm
diff --git a/src/compressed_tensors/config/base.py b/src/compressed_tensors/config/base.py
index 552c10a1..f58b11f8 100644
--- a/src/compressed_tensors/config/base.py
+++ b/src/compressed_tensors/config/base.py
@@ -14,8 +14,8 @@
 
 from typing import Optional
 
-from pydantic import BaseModel
 from compressed_tensors.registry import RegistryMixin
+from pydantic import BaseModel
 
 
 __all__ = ["CompressionConfig"]
diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py
index 77c5245b..08cb42f9 100644
--- a/src/compressed_tensors/quantization/lifecycle/apply.py
+++ b/src/compressed_tensors/quantization/lifecycle/apply.py
@@ -16,7 +16,9 @@
 from collections import OrderedDict
 from typing import Iterable, Optional
 
-from compressed_tensors.quantization.lifecycle.calibration import set_module_for_calibration
+from compressed_tensors.quantization.lifecycle.calibration import (
+    set_module_for_calibration,
+)
 from compressed_tensors.quantization.lifecycle.frozen import freeze_module_quantization
 from compressed_tensors.quantization.lifecycle.initialize import (
     initialize_module_for_quantization,
diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py
index d78997c1..4ef6379b 100644
--- a/src/compressed_tensors/quantization/lifecycle/initialize.py
+++ b/src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -17,7 +17,9 @@
 from typing import Optional
 
 import torch
-from compressed_tensors.quantization.lifecycle.forward import wrap_module_forward_quantized
+from compressed_tensors.quantization.lifecycle.forward import (
+    wrap_module_forward_quantized,
+)
 from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
diff --git a/src/compressed_tensors/quantization/quant_config.py b/src/compressed_tensors/quantization/quant_config.py
index 7214bc83..a62a79bd 100644
--- a/src/compressed_tensors/quantization/quant_config.py
+++ b/src/compressed_tensors/quantization/quant_config.py
@@ -15,7 +15,6 @@
 from enum import Enum
 from typing import Dict, List, Optional
 
-from pydantic import BaseModel, Field
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from compressed_tensors.quantization.utils import (
     calculate_compression_ratio,
@@ -23,6 +22,7 @@
     iter_named_leaf_modules,
     module_type,
 )
+from pydantic import BaseModel, Field
 from torch.nn import Module
diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py
index c083fc55..ed0f8245 100644
--- a/src/compressed_tensors/quantization/quant_scheme.py
+++ b/src/compressed_tensors/quantization/quant_scheme.py
@@ -14,8 +14,8 @@
 
 from typing import List, Optional
 
-from pydantic import BaseModel
 from compressed_tensors.quantization.quant_args import QuantizationArgs
+from pydantic import BaseModel
 
 
 __all__ = ["QuantizationScheme"]
diff --git a/tests/quantization/test_quant_args.py b/tests/quantization/test_quant_args.py
index 46a4cc49..c1c84be6 100644
--- a/tests/quantization/test_quant_args.py
+++ b/tests/quantization/test_quant_args.py
@@ -13,12 +13,12 @@
 # limitations under the License.
 
 import pytest
-from pydantic import ValidationError
 from compressed_tensors.quantization import (
     QuantizationArgs,
     QuantizationStrategy,
     QuantizationType,
 )
+from pydantic import ValidationError
 
 
 def test_defaults():
diff --git a/tests/quantization/test_quant_config.py b/tests/quantization/test_quant_config.py
index 68688c36..091be723 100644
--- a/tests/quantization/test_quant_config.py
+++ b/tests/quantization/test_quant_config.py
@@ -14,12 +14,12 @@
 
 import pytest
-from pydantic import ValidationError
 from compressed_tensors.quantization import (
     QuantizationConfig,
     QuantizationScheme,
     QuantizationStatus,
 )
+from pydantic import ValidationError
 
 
 def test_basic_config():
diff --git a/tests/quantization/test_quant_scheme.py b/tests/quantization/test_quant_scheme.py
index 1c198812..14ba9f7e 100644
--- a/tests/quantization/test_quant_scheme.py
+++ b/tests/quantization/test_quant_scheme.py
@@ -13,8 +13,8 @@
 # limitations under the License.
 
 import pytest
-from pydantic import ValidationError
 from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
+from pydantic import ValidationError
 
 
 def test_basic_scheme():
diff --git a/tests/test_bitmask.py b/tests/test_bitmask.py
index 28d29ed0..248580bc 100644
--- a/tests/test_bitmask.py
+++ b/tests/test_bitmask.py
@@ -17,8 +17,8 @@
 
 import pytest
 import torch
-from safetensors.torch import save_file
 from compressed_tensors import BitmaskCompressor, BitmaskConfig, BitmaskTensor
+from safetensors.torch import save_file
 
 
 @pytest.mark.parametrize(

From 20986a69544a4a7cb5fe4627a63cd43e7d68680f Mon Sep 17 00:00:00 2001
From: Sara Adkins
Date: Thu, 18 Apr 2024 14:37:36 +0000
Subject: [PATCH 07/10] fix example

---
 examples/llama_1.1b/ex_config_quantization.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/examples/llama_1.1b/ex_config_quantization.py b/examples/llama_1.1b/ex_config_quantization.py
index 175ca0b9..131ba8ba 100644
--- a/examples/llama_1.1b/ex_config_quantization.py
+++ b/examples/llama_1.1b/ex_config_quantization.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from tqdm import tqdm
-
+from torch.utils.data import RandomSampler
 from compressed_tensors.quantization import (
     apply_quantization_config,
     freeze_module_quantization,
@@ -24,7 +24,7 @@
 from sparseml.transformers.finetune.data.base import TextGenerationDataset
 from transformers import AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator
 from torch.utils.data import DataLoader
-
+from sparseml.pytorch.utils import tensors_to_device
 
 config_file = "example_quant_config.json"
 model_name = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
@@ -62,11 +62,12 @@
     dataset_manager.get_raw_dataset()
 )
 data_loader = DataLoader(
-    calib_dataset, batch_size=1, collate_fn=DefaultDataCollator()
+    calib_dataset, batch_size=1, collate_fn=DefaultDataCollator(), sampler=RandomSampler(calib_dataset)
 )
 
 # run calibration
-for idx, sample in tqdm(enumerate(data_loader)):
+for idx, sample in tqdm(enumerate(data_loader), desc="Running calibration"):
+    sample = tensors_to_device(sample, "cuda:0")
     _ = model(**sample)
 
     if idx >= num_calibration_samples:
@@ -74,4 +75,9 @@
 
 # freeze params after calibration
 model.apply(freeze_module_quantization)
+
+# this functionality will move but for now we need to get the save override from
+# SparseML in order to save the config
+from sparseml.transformers.compression import modify_save_pretrained
+modify_save_pretrained(model)
 model.save_pretrained(output_dir)
\ No newline at end of file

From 6140bd212d4245093a7fd091f7e5dca5805f5790 Mon Sep 17 00:00:00 2001
From: Sara Adkins
Date: Thu, 18 Apr 2024 18:26:24 +0000
Subject: [PATCH 08/10] address PR comments

---
 examples/llama_1.1b/ex_config_quantization.py |  6 +++--
 .../llama_1.1b/ex_sparseml_quantization.py    |  4 +++-
 .../compressors/sparse_bitmask.py             |  2 +-
 src/compressed_tensors/config/base.py         |  2 +-
 .../quantization/lifecycle/apply.py           |  4 +++-
 .../quantization/lifecycle/initialize.py      |  4 +++-
 .../quantization/observers/helpers.py         |  3 ++-
 .../quantization/observers/min_max.py         | 23 ++-----------------
 .../quantization/quant_config.py              |  2 +-
 .../quantization/quant_scheme.py              |  2 +-
 tests/quantization/test_quant_args.py         |  2 +-
 tests/quantization/test_quant_config.py       |  2 +-
 tests/quantization/test_quant_scheme.py       |  2 +-
 tests/test_bitmask.py                         |  2 +-
 14 files changed, 25 insertions(+), 35 deletions(-)

diff --git a/examples/llama_1.1b/ex_config_quantization.py b/examples/llama_1.1b/ex_config_quantization.py
index 131ba8ba..4dfbc951 100644
--- a/examples/llama_1.1b/ex_config_quantization.py
+++ b/examples/llama_1.1b/ex_config_quantization.py
@@ -25,6 +25,7 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator
 from torch.utils.data import DataLoader
 from sparseml.pytorch.utils import tensors_to_device
+import torch
 
 config_file = "example_quant_config.json"
 model_name = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
@@ -34,8 +35,9 @@
 max_seq_length = 1024
 pad_to_max_length = False
 output_dir = "./llama1.1b_new_quant_out"
+device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
-model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda:0")
+model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device)
 model.eval() # no grad or updates needed for base model
 config = QuantizationConfig.parse_file(config_file)
 
@@ -80,4 +82,4 @@
 # SparseML in order to save the config
 from sparseml.transformers.compression import modify_save_pretrained
 modify_save_pretrained(model)
-model.save_pretrained(output_dir)
\ No newline at end of file
+model.save_pretrained(output_dir)
diff --git a/examples/llama_1.1b/ex_sparseml_quantization.py b/examples/llama_1.1b/ex_sparseml_quantization.py
index 655ea0f2..5ef492b1 100644
--- a/examples/llama_1.1b/ex_sparseml_quantization.py
+++ b/examples/llama_1.1b/ex_sparseml_quantization.py
@@ -16,6 +16,7 @@
 from sparseml.transformers.finetune.data.data_args import DataTrainingArguments
 from sparseml.transformers.finetune.data.base import TextGenerationDataset
 from transformers import AutoTokenizer
+import torch
 
 recipe = "example_quant_recipe.yaml"
 model_name = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
@@ -25,8 +26,9 @@
 max_seq_length = 1024
 pad_to_max_length = False
 output_dir = "./llama1.1b_old_quant_out"
+device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
-model = SparseAutoModelForCausalLM.from_pretrained(model_name, device_map="cuda:0")
+model = SparseAutoModelForCausalLM.from_pretrained(model_name, device_map=device)
 
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 data_args = DataTrainingArguments(
diff --git a/src/compressed_tensors/compressors/sparse_bitmask.py b/src/compressed_tensors/compressors/sparse_bitmask.py
index 9edf0b69..f6f03f0b 100644
--- a/src/compressed_tensors/compressors/sparse_bitmask.py
+++ b/src/compressed_tensors/compressors/sparse_bitmask.py
@@ -17,9 +17,9 @@
 
 import numpy
 import torch
-from safetensors import safe_open
 from compressed_tensors.compressors import ModelCompressor
 from compressed_tensors.utils import get_nested_weight_mappings, merge_names
+from safetensors import safe_open
 from torch import Tensor
 from tqdm import tqdm
diff --git a/src/compressed_tensors/config/base.py b/src/compressed_tensors/config/base.py
index 552c10a1..f58b11f8 100644
--- a/src/compressed_tensors/config/base.py
+++ b/src/compressed_tensors/config/base.py
@@ -14,8 +14,8 @@
 
 from typing import Optional
 
-from pydantic import BaseModel
 from compressed_tensors.registry import RegistryMixin
+from pydantic import BaseModel
 
 
 __all__ = ["CompressionConfig"]
diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py
index 77c5245b..08cb42f9 100644
--- a/src/compressed_tensors/quantization/lifecycle/apply.py
+++ b/src/compressed_tensors/quantization/lifecycle/apply.py
@@ -16,7 +16,9 @@
 from collections import OrderedDict
 from typing import Iterable, Optional
 
-from compressed_tensors.quantization.lifecycle.calibration import set_module_for_calibration
+from compressed_tensors.quantization.lifecycle.calibration import (
+    set_module_for_calibration,
+)
 from compressed_tensors.quantization.lifecycle.frozen import freeze_module_quantization
 from compressed_tensors.quantization.lifecycle.initialize import (
     initialize_module_for_quantization,
diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py
index d78997c1..4ef6379b 100644
--- a/src/compressed_tensors/quantization/lifecycle/initialize.py
+++ b/src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -17,7 +17,9 @@
 from typing import Optional
 
 import torch
-from compressed_tensors.quantization.lifecycle.forward import wrap_module_forward_quantized
+from compressed_tensors.quantization.lifecycle.forward import (
+    wrap_module_forward_quantized,
+)
 from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
diff --git a/src/compressed_tensors/quantization/observers/helpers.py b/src/compressed_tensors/quantization/observers/helpers.py
index 3fd2f4b6..bc43bbec 100644
--- a/src/compressed_tensors/quantization/observers/helpers.py
+++ b/src/compressed_tensors/quantization/observers/helpers.py
@@ -34,6 +34,7 @@ def calculate_qparams(
     :return: tuple of the calculated scale(s) and zero point(s)
     """
     bit_range = 2**quantization_args.num_bits - 1
+    bit_min = -(bit_range + 1) / 2
     if quantization_args.symmetric:
         symmetric_range = 2 * max(min_vals.abs(), max_vals.abs())
         scales = symmetric_range / bit_range
@@ -46,6 +47,6 @@ def calculate_qparams(
 
         # scales from a 0 range should be set to 1
         scales[observed_range == 0] = 1
-        zero_points = ((0 - min_vals) / scales).to(torch.int8)
+        zero_points = ((0 - min_vals) / scales + bit_min).to(torch.int8)
 
     return scales, zero_points
diff --git a/src/compressed_tensors/quantization/observers/min_max.py b/src/compressed_tensors/quantization/observers/min_max.py
index 4d477fcd..3496bb77 100644
--- a/src/compressed_tensors/quantization/observers/min_max.py
+++ b/src/compressed_tensors/quantization/observers/min_max.py
@@ -43,10 +43,7 @@ def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
         :param observed: observed tensor to calculate quantization parameters for
         :return: tuple of scale and zero point derived from the observed tensor
         """
-        # TODO: Add support for full range of quantization Args, only supports 8bit
-        # per tensor
-        bit_min = -128
-        bit_max = 127
+
         min_val = torch.tensor([observed.min()])
         max_val = torch.tensor([observed.max()])
@@ -63,20 +60,4 @@ def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
             max_val = torch.max(self.max_val, torch.zeros_like(self.max_val))
 
         self.counter += 1
-
-        if self.quantization_args.symmetric:
-            symmetric_range = 2 * max(min_val.abs(), max_val.abs())
-            scale = symmetric_range / (bit_max - bit_min)
-            zero_point = torch.tensor(0).to(torch.int8)
-        else:
-            # non-symmetric
-            observed_range = max_val - min_val
-            quantized_range = bit_max - bit_min
-            scale = observed_range / (quantized_range)
-
-            # scales from a 0 range should be set to 1
-            scale[observed_range == 0] = 1
-
-            zero_point = ((0 - min_val) / scale + bit_min).to(torch.int8)
-
-        return scale, zero_point
+        return calculate_qparams(min_val, max_val, self.quantization_args)
diff --git a/src/compressed_tensors/quantization/quant_config.py b/src/compressed_tensors/quantization/quant_config.py
index 7214bc83..a62a79bd 100644
--- a/src/compressed_tensors/quantization/quant_config.py
+++ b/src/compressed_tensors/quantization/quant_config.py
@@ -15,7 +15,6 @@
 from enum import Enum
 from typing import Dict, List, Optional
 
-from pydantic import BaseModel, Field
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from compressed_tensors.quantization.utils import (
     calculate_compression_ratio,
@@ -23,6 +22,7 @@
     iter_named_leaf_modules,
     module_type,
 )
+from pydantic import BaseModel, Field
 from torch.nn import Module
diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py
index c083fc55..ed0f8245 100644
--- a/src/compressed_tensors/quantization/quant_scheme.py
+++ b/src/compressed_tensors/quantization/quant_scheme.py
@@ -14,8 +14,8 @@
 
 from typing import List, Optional
 
-from pydantic import BaseModel
 from compressed_tensors.quantization.quant_args import QuantizationArgs
+from pydantic import BaseModel
 
 
 __all__ = ["QuantizationScheme"]
diff --git a/tests/quantization/test_quant_args.py b/tests/quantization/test_quant_args.py
index 46a4cc49..c1c84be6 100644
--- a/tests/quantization/test_quant_args.py
+++ b/tests/quantization/test_quant_args.py
@@ -13,12 +13,12 @@
 # limitations under the License.
 
 import pytest
-from pydantic import ValidationError
 from compressed_tensors.quantization import (
     QuantizationArgs,
     QuantizationStrategy,
     QuantizationType,
 )
+from pydantic import ValidationError
 
 
 def test_defaults():
diff --git a/tests/quantization/test_quant_config.py b/tests/quantization/test_quant_config.py
index 68688c36..091be723 100644
--- a/tests/quantization/test_quant_config.py
+++ b/tests/quantization/test_quant_config.py
@@ -14,12 +14,12 @@
 
 import pytest
-from pydantic import ValidationError
 from compressed_tensors.quantization import (
     QuantizationConfig,
     QuantizationScheme,
     QuantizationStatus,
 )
+from pydantic import ValidationError
 
 
 def test_basic_config():
diff --git a/tests/quantization/test_quant_scheme.py b/tests/quantization/test_quant_scheme.py
index 1c198812..14ba9f7e 100644
--- a/tests/quantization/test_quant_scheme.py
+++ b/tests/quantization/test_quant_scheme.py
@@ -13,8 +13,8 @@
 # limitations under the License.
 
 import pytest
-from pydantic import ValidationError
 from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
+from pydantic import ValidationError
 
 
 def test_basic_scheme():
diff --git a/tests/test_bitmask.py b/tests/test_bitmask.py
index 28d29ed0..248580bc 100644
--- a/tests/test_bitmask.py
+++ b/tests/test_bitmask.py
@@ -17,8 +17,8 @@
 
 import pytest
 import torch
-from safetensors.torch import save_file
 from compressed_tensors import BitmaskCompressor, BitmaskConfig, BitmaskTensor
+from safetensors.torch import save_file
 
 
 @pytest.mark.parametrize(

From 41168b9e03a71827875198faeec47babf8b063c2 Mon Sep 17 00:00:00 2001
From: Sara Adkins
Date: Thu, 18 Apr 2024 19:21:38 +0000
Subject: [PATCH 09/10] oops rounding

---
 src/compressed_tensors/quantization/observers/helpers.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/compressed_tensors/quantization/observers/helpers.py b/src/compressed_tensors/quantization/observers/helpers.py
index bc43bbec..d0fca813 100644
--- a/src/compressed_tensors/quantization/observers/helpers.py
+++ b/src/compressed_tensors/quantization/observers/helpers.py
@@ -47,6 +47,6 @@ def calculate_qparams(
 
         # scales from a 0 range should be set to 1
         scales[observed_range == 0] = 1
-        zero_points = ((0 - min_vals) / scales + bit_min).to(torch.int8)
+        zero_points = torch.round(((0.0 - min_vals) / scales + bit_min)).to(torch.int8)
 
     return scales, zero_points

From 06200fc3f6221ebf16a4989bc017e4ef27afb02c Mon Sep 17 00:00:00 2001
From: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com>
Date: Fri, 19 Apr 2024 12:48:53 +0200
Subject: [PATCH 10/10] Fix failing GHA (#29)

---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index 6ab8dc09..225d7b8d 100644
--- a/setup.py
+++ b/setup.py
@@ -18,7 +18,7 @@
 
 def _setup_packages() -> List:
     return find_packages(
-        "src", include=["compressed-tensors", "compressed-tensors.*"], exclude=["*.__pycache__.*"]
+        "src", include=["compressed_tensors", "compressed_tensors.*"], exclude=["*.__pycache__.*"]
     )
 
 def _setup_install_requires() -> List:
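
The three sketches below restate the numerics this series converges on, for readers skimming the diffs. They are standalone illustrations with made-up values, not code from the patches themselves.

PATCH 01 swaps the observer's running average for a running min/max. The accumulation rule in isolation:

import torch

# Keep the overall min and max seen across calibration batches (PATCH 01's rule).
min_val, max_val = None, None
for observed in [torch.tensor([0.2, -0.5]), torch.tensor([1.3, 0.1])]:
    batch_min = torch.tensor([observed.min()])
    batch_max = torch.tensor([observed.max()])
    if min_val is None:
        min_val, max_val = batch_min, batch_max
    else:
        min_val = torch.min(batch_min, min_val)  # global min so far
        max_val = torch.max(batch_max, max_val)  # global max so far
# min_val -> tensor([-0.5000]), max_val -> tensor([1.3000])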
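PATCHes 02 through 04 move quantize() from an unsigned [0, 2**num_bits - 1] grid to a signed one: q_min = -(2**num_bits) / 2 and q_max = (2**num_bits) / 2 - 1. A round-trip sketch of that arithmetic, assuming the usual affine dequantize form (the patches do not show dequantize's body):

import torch

num_bits = 8
bit_range = 2**num_bits                  # 256
q_max = torch.tensor(bit_range / 2 - 1)  # 127
q_min = torch.tensor(-bit_range / 2)     # -128

x = torch.tensor([-1.0, 0.0, 0.5, 3.0])
scale = torch.tensor(0.02)
zero_point = torch.tensor(0)

# quantize as in PATCH 04: round, then clamp to the signed range
q = torch.clamp(torch.round(x / scale + zero_point), q_min, q_max)
x_hat = (q - zero_point) * scale  # assumed affine dequantize
# q -> tensor([-50., 0., 25., 127.]); 3.0 saturates at q_max * scale = 2.54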
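PATCHes 08 and 09 fold the per-observer scale/zero-point math into calculate_qparams and fix the asymmetric zero point: it is shifted by bit_min into the signed int8 range and rounded before the cast. The same arithmetic in isolation (a sketch of the formula, not the helpers.py function itself):

import torch

num_bits = 8
bit_range = 2**num_bits - 1     # 255
bit_min = -(bit_range + 1) / 2  # -128.0

min_vals = torch.tensor([-0.75])
max_vals = torch.tensor([1.50])

observed_range = max_vals - min_vals
scales = observed_range / bit_range
scales[observed_range == 0] = 1  # a zero observed range gets scale 1
zero_points = torch.round((0.0 - min_vals) / scales + bit_min).to(torch.int8)
# scales -> tensor([0.0088]), zero_points -> tensor([-43], dtype=torch.int8)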