diff --git a/examples/llama_1.1b/ex_config_quantization.py b/examples/llama_1.1b/ex_config_quantization.py
new file mode 100644
index 00000000..4dfbc951
--- /dev/null
+++ b/examples/llama_1.1b/ex_config_quantization.py
@@ -0,0 +1,85 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from compressed_tensors.quantization import (
+    apply_quantization_config,
+    freeze_module_quantization,
+    QuantizationConfig,
+    QuantizationStatus,
+)
+from sparseml.pytorch.utils import tensors_to_device
+from sparseml.transformers.finetune.data.base import TextGenerationDataset
+from sparseml.transformers.finetune.data.data_args import DataTrainingArguments
+from torch.utils.data import DataLoader, RandomSampler
+from tqdm import tqdm
+from transformers import AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator
+
+config_file = "example_quant_config.json"
+model_name = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
+dataset_name = "open_platypus"
+split = "train"
+num_calibration_samples = 512
+max_seq_length = 1024
+pad_to_max_length = False
+output_dir = "./llama1.1b_new_quant_out"
+device = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device)
+model.eval()  # no grad or updates needed for base model
+config = QuantizationConfig.parse_file(config_file)
+
+# set status to calibration
+config.quantization_status = QuantizationStatus.CALIBRATION
+
+# initialize quantization
+apply_quantization_config(model, config)
+
+# create dataset
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+data_args = DataTrainingArguments(
+    dataset=dataset_name,
+    max_seq_length=max_seq_length,
+    pad_to_max_length=pad_to_max_length,
+)
+dataset_manager = TextGenerationDataset.load_from_registry(
+    data_args.dataset,
+    data_args=data_args,
+    split=split,
+    tokenizer=tokenizer,
+)
+calib_dataset = dataset_manager.tokenize_and_process(
+    dataset_manager.get_raw_dataset()
+)
+data_loader = DataLoader(
+    calib_dataset,
+    batch_size=1,
+    collate_fn=DefaultDataCollator(),
+    sampler=RandomSampler(calib_dataset),
+)
+
+# run calibration
+for idx, sample in tqdm(enumerate(data_loader), desc="Running calibration"):
+    sample = tensors_to_device(sample, device)
+    _ = model(**sample)
+
+    if idx + 1 >= num_calibration_samples:
+        break
+
+# freeze quantization params after calibration
+model.apply(freeze_module_quantization)
+
+# this functionality will move, but for now we need the save override from
+# SparseML in order to save the config
+from sparseml.transformers.compression import modify_save_pretrained
+
+modify_save_pretrained(model)
+model.save_pretrained(output_dir)
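Note on the calibration pass above: each forward call lets the observers attached by `apply_quantization_config` record activation statistics. The snippet below is not the compressed-tensors implementation, just a minimal pure-PyTorch sketch of the idea using forward hooks; all names in it (`attach_minmax_hooks`, `stats`) are made up for the illustration.

```python
import torch
from torch import nn

def attach_minmax_hooks(model: nn.Module):
    """Record the running min/max of each Linear layer's input."""
    stats = {}

    def make_hook(name):
        def hook(module, inputs, output):
            x = inputs[0].detach()
            lo, hi = x.min(), x.max()
            if name in stats:
                prev_lo, prev_hi = stats[name]
                stats[name] = (torch.min(lo, prev_lo), torch.max(hi, prev_hi))
            else:
                stats[name] = (lo, hi)
        return hook

    handles = [
        m.register_forward_hook(make_hook(n))
        for n, m in model.named_modules()
        if isinstance(m, nn.Linear)
    ]
    return stats, handles

# usage: run a few batches through the model, then detach the hooks
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
stats, handles = attach_minmax_hooks(model)
for _ in range(4):  # stand-in for the calibration data loader
    _ = model(torch.randn(2, 8))
for h in handles:
    h.remove()
print({k: (v[0].item(), v[1].item()) for k, v in stats.items()})
```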
diff --git a/examples/llama_1.1b/ex_sparseml_quantization.py b/examples/llama_1.1b/ex_sparseml_quantization.py
index 3c66f5e8..5ef492b1 100644
--- a/examples/llama_1.1b/ex_sparseml_quantization.py
+++ b/examples/llama_1.1b/ex_sparseml_quantization.py
@@ -13,28 +13,45 @@
 # limitations under the License.
 
 from sparseml.transformers import oneshot, SparseAutoModelForCausalLM
+from sparseml.transformers.finetune.data.data_args import DataTrainingArguments
+from sparseml.transformers.finetune.data.base import TextGenerationDataset
+from transformers import AutoTokenizer
+import torch
 
-dataset_name = "open_platypus"
-overwrite_output_dir = True
-splits = {"calibration": "train"}
-seed = 42
-output_dir = "./llama_1.1b_quant_mod_only"
-num_calibration_samples = 1024
 recipe = "example_quant_recipe.yaml"
 model_name = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
+dataset_name = "open_platypus"
+split = "train"
+num_calibration_samples = 512
 max_seq_length = 1024
 pad_to_max_length = False
+output_dir = "./llama1.1b_old_quant_out"
+device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
-model = SparseAutoModelForCausalLM.from_pretrained(model_name, device_map="cuda:0")
+model = SparseAutoModelForCausalLM.from_pretrained(model_name, device_map=device)
+
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+data_args = DataTrainingArguments(
+    dataset=dataset_name,
+    max_seq_length=max_seq_length,
+    pad_to_max_length=pad_to_max_length,
+)
+dataset_manager = TextGenerationDataset.load_from_registry(
+    data_args.dataset,
+    data_args=data_args,
+    split=split,
+    tokenizer=tokenizer,
+)
+calib_dataset = dataset_manager.tokenize_and_process(
+    dataset_manager.get_raw_dataset()
+)
 
 oneshot(
     model=model_name,
     dataset=dataset_name,
     output_dir=output_dir,
-    overwrite_output_dir=overwrite_output_dir,
-    splits = splits,
+    overwrite_output_dir=True,
     max_seq_length = max_seq_length,
-    seed=seed,
     num_calibration_samples=num_calibration_samples,
     recipe=recipe,
     pad_to_max_length=pad_to_max_length
diff --git a/examples/llama_1.1b/example_quant_config.json b/examples/llama_1.1b/example_quant_config.json
index 65d15740..969cea76 100644
--- a/examples/llama_1.1b/example_quant_config.json
+++ b/examples/llama_1.1b/example_quant_config.json
@@ -1,7 +1,6 @@
 {
     "quant_method": "sparseml",
     "format": "fakequant",
-    "quantization_status": "frozen",
     "global_compression_ratio": null,
     "config_groups": {
         "group_1": {
@@ -14,7 +13,7 @@
             "input_activations": {
                 "num_bits": 8,
                 "type": "int",
-                "symmetric": true,
+                "symmetric": false,
                 "strategy": "tensor"
             },
             "targets": ["Linear"]
@@ -23,17 +22,11 @@
             "weights": {
                 "num_bits": 8,
                 "type": "int",
-                "symmetric": false,
+                "symmetric": true,
                 "strategy": "tensor"
             },
-            "input_activations": null,
             "targets": ["Embedding"]
         }
     },
-    "ignore": [
-        "LlamaRotaryEmbedding", "LlamaRMSNorm", "SiLUActivation",
-        "model.layers.1.mlp.down_proj", "MatMulLeftInput_QK", "MatMulRightInput_QK",
-        "MatMulOutput_QK", "MatMulLeftInput_PV", "MatMulRightInput_PV",
-        "MatMulOutput_PV"
-    ]
+    "ignore": ["model.layers.0.mlp.down_proj"]
 }
\ No newline at end of file
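The config change above flips `symmetric` to `false` for Linear input activations and to `true` for Embedding weights. As a rough sketch of what that flag means for the derived quantization parameters (simplified relative to `calculate_qparams` in `observers/helpers.py`; in particular the symmetric scale here divides by `q_max` rather than the library's `bit_range`):

```python
import torch

def qparams(min_val, max_val, num_bits=8, symmetric=True):
    """Derive (scale, zero_point) for a signed integer grid."""
    q_min = -(2 ** (num_bits - 1))     # -128 for int8
    q_max = 2 ** (num_bits - 1) - 1    #  127 for int8
    if symmetric:
        # range centred on zero; zero_point stays at 0
        scale = torch.max(min_val.abs(), max_val.abs()) / q_max
        zero_point = torch.zeros_like(scale, dtype=torch.int8)
    else:
        # full observed range; zero_point shifts into the signed grid
        scale = (max_val - min_val) / (q_max - q_min)
        zero_point = torch.round(q_min - min_val / scale).to(torch.int8)
    return scale, zero_point

# asymmetric suits post-ReLU-style skewed activations; symmetric suits weights
print(qparams(torch.tensor(-0.2), torch.tensor(1.0), symmetric=False))
print(qparams(torch.tensor(-0.8), torch.tensor(0.8), symmetric=True))
```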
diff --git a/examples/llama_1.1b/example_quant_recipe.yaml b/examples/llama_1.1b/example_quant_recipe.yaml
index e69de29b..c94e9285 100644
--- a/examples/llama_1.1b/example_quant_recipe.yaml
+++ b/examples/llama_1.1b/example_quant_recipe.yaml
@@ -0,0 +1,32 @@
+test_stage:
+  quant_modifiers:
+    QuantizationModifier:
+      ignore:
+        - model.layers.0.mlp.down_proj
+        - LlamaRotaryEmbedding
+        - LlamaRMSNorm
+        - SiLU
+        - MatMulLeftInput_QK
+        - MatMulRightInput_QK
+        - MatMulOutput_QK
+        - MatMulLeftInput_PV
+        - MatMulRightInput_PV
+        - MatMulOutput_PV
+      scheme_overrides:
+        Linear:
+          weights:
+            num_bits: 8
+            symmetric: true
+            strategy: "tensor"
+          input_activations:
+            num_bits: 8
+            symmetric: false
+            strategy: "tensor"
+          output_activations: null
+        Embedding:
+          weights:
+            num_bits: 8
+            symmetric: true
+            strategy: "tensor"
+          input_activations: null
+          output_activations: null
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 6ab8dc09..225d7b8d 100644
--- a/setup.py
+++ b/setup.py
@@ -18,7 +18,7 @@ def _setup_packages() -> List:
     return find_packages(
-        "src", include=["compressed-tensors", "compressed-tensors.*"], exclude=["*.__pycache__.*"]
+        "src", include=["compressed_tensors", "compressed_tensors.*"], exclude=["*.__pycache__.*"]
     )
 
 def _setup_install_requires() -> List:
diff --git a/src/compressed_tensors/compressors/sparse_bitmask.py b/src/compressed_tensors/compressors/sparse_bitmask.py
index 9edf0b69..f6f03f0b 100644
--- a/src/compressed_tensors/compressors/sparse_bitmask.py
+++ b/src/compressed_tensors/compressors/sparse_bitmask.py
@@ -17,9 +17,9 @@
 
 import numpy
 import torch
-from safetensors import safe_open
 from compressed_tensors.compressors import ModelCompressor
 from compressed_tensors.utils import get_nested_weight_mappings, merge_names
+from safetensors import safe_open
 from torch import Tensor
 from tqdm import tqdm
diff --git a/src/compressed_tensors/config/base.py b/src/compressed_tensors/config/base.py
index 552c10a1..f58b11f8 100644
--- a/src/compressed_tensors/config/base.py
+++ b/src/compressed_tensors/config/base.py
@@ -14,8 +14,8 @@
 
 from typing import Optional
 
-from pydantic import BaseModel
 from compressed_tensors.registry import RegistryMixin
+from pydantic import BaseModel
 
 __all__ = ["CompressionConfig"]
diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py
index 77c5245b..08cb42f9 100644
--- a/src/compressed_tensors/quantization/lifecycle/apply.py
+++ b/src/compressed_tensors/quantization/lifecycle/apply.py
@@ -16,7 +16,9 @@
 from collections import OrderedDict
 from typing import Iterable, Optional
 
-from compressed_tensors.quantization.lifecycle.calibration import set_module_for_calibration
+from compressed_tensors.quantization.lifecycle.calibration import (
+    set_module_for_calibration,
+)
 from compressed_tensors.quantization.lifecycle.frozen import freeze_module_quantization
 from compressed_tensors.quantization.lifecycle.initialize import (
     initialize_module_for_quantization,
diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py
index e9db69ae..48b93e02 100644
--- a/src/compressed_tensors/quantization/lifecycle/forward.py
+++ b/src/compressed_tensors/quantization/lifecycle/forward.py
@@ -29,13 +29,14 @@
 def quantize(
     x: torch.Tensor,
     scale: torch.Tensor,
     zero_point: torch.Tensor,
+    q_min: torch.Tensor,
     q_max: torch.Tensor,
 ) -> torch.Tensor:
     return torch.clamp(
         torch.round(
             x / scale + zero_point,
         ),
-        0,
+        q_min,
         q_max,
     )
 
@@ -56,9 +57,11 @@
 def fake_quantize(
     x: torch.Tensor,
     scale: torch.Tensor,
     zero_point: torch.Tensor,
     args: QuantizationArgs,
 ) -> torch.Tensor:
-    max_q = torch.tensor(2**args.num_bits - 1, device=x.device)
+    bit_range = 2**args.num_bits
+    max_q = torch.tensor(bit_range / 2 - 1, device=x.device)
+    min_q = torch.tensor(-bit_range / 2, device=x.device)
     Q = torch.zeros_like(x)
-    Q = quantize(x, scale, zero_point, max_q)
+    Q = quantize(x, scale, zero_point, min_q, max_q)
     return dequantize(Q, scale, zero_point)
 
@@ -114,7 +117,6 @@ def _maybe_calibrate_or_quantize(
     device = next(module.parameters()).device
 
     scale = getattr(module, f"{base_name}_scale")
-    # zero_point = getattr(module, f"{base_name}_zero_point").data
     zero_point = getattr(module, f"{base_name}_zero_point")
 
     if module.quantization_status == QuantizationStatus.CALIBRATION:
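The forward.py hunk above moves fake quantization from an unsigned `[0, 2**num_bits - 1]` grid to a signed `[-2**(num_bits - 1), 2**(num_bits - 1) - 1]` grid. A standalone numeric check of why the old `0` lower clamp was lossy for symmetric int8 (this mirrors the patched functions, trimmed of the module plumbing):

```python
import torch

def quantize(x, scale, zero_point, q_min, q_max):
    return torch.clamp(torch.round(x / scale + zero_point), q_min, q_max)

def dequantize(q, scale, zero_point):
    return (q - zero_point) * scale

x = torch.tensor([-1.0, -0.5, 0.0, 0.5, 1.0])
scale, zero_point = torch.tensor(1.0 / 127), torch.tensor(0)

# signed int8 grid keeps negative values
q_signed = quantize(x, scale, zero_point, q_min=-128, q_max=127)
print(q_signed)     # tensor([-127., -64., 0., 64., 127.])
print(dequantize(q_signed, scale, zero_point))

# the old unsigned clamp collapsed everything below the zero point
q_unsigned = quantize(x, scale, zero_point, q_min=0, q_max=255)
print(q_unsigned)   # tensor([0., 0., 0., 64., 127.])
```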
diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py
index d78997c1..4ef6379b 100644
--- a/src/compressed_tensors/quantization/lifecycle/initialize.py
+++ b/src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -17,7 +17,9 @@
 from typing import Optional
 
 import torch
-from compressed_tensors.quantization.lifecycle.forward import wrap_module_forward_quantized
+from compressed_tensors.quantization.lifecycle.forward import (
+    wrap_module_forward_quantized,
+)
 from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
diff --git a/src/compressed_tensors/quantization/observers/helpers.py b/src/compressed_tensors/quantization/observers/helpers.py
index 3fd2f4b6..d0fca813 100644
--- a/src/compressed_tensors/quantization/observers/helpers.py
+++ b/src/compressed_tensors/quantization/observers/helpers.py
@@ -34,6 +34,7 @@ def calculate_qparams(
     :return: tuple of the calculated scale(s) and zero point(s)
     """
     bit_range = 2**quantization_args.num_bits - 1
+    bit_min = -(bit_range + 1) / 2
     if quantization_args.symmetric:
         symmetric_range = 2 * max(min_vals.abs(), max_vals.abs())
         scales = symmetric_range / bit_range
@@ -46,6 +47,6 @@
         # scales from a 0 range should be set to 1
         scales[observed_range == 0] = 1
 
-        zero_points = ((0 - min_vals) / scales).to(torch.int8)
+        zero_points = torch.round((0.0 - min_vals) / scales + bit_min).to(torch.int8)
 
     return scales, zero_points
diff --git a/src/compressed_tensors/quantization/observers/min_max.py b/src/compressed_tensors/quantization/observers/min_max.py
index 163392f9..fe16956f 100644
--- a/src/compressed_tensors/quantization/observers/min_max.py
+++ b/src/compressed_tensors/quantization/observers/min_max.py
@@ -28,7 +28,7 @@
 class MinMaxObserver(Observer):
     """
     Implements a dynamic quantization observer that sets the scale and
-    zero point based on the latest observed value
+    zero point based on the overall min and max value
     """
 
     def __init__(self, quantization_args: QuantizationArgs):
@@ -56,12 +56,8 @@ def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
 
-        # update running average
+        # update global min and max
         if self.counter > 0:
-            self.min_vals = (self.min_vals * self.counter + min_vals) / (
-                self.counter + 1
-            )
-            self.max_vals = (self.max_vals * self.counter + max_vals) / (
-                self.counter + 1
-            )
+            self.min_vals = torch.min(min_vals, self.min_vals)
+            self.max_vals = torch.max(max_vals, self.max_vals)
         else:
             self.min_vals = min_vals
             self.max_vals = max_vals
@@ -76,10 +72,10 @@ def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
         min_val = torch.tensor([observed.min()])
         max_val = torch.tensor([observed.max()])
 
-        # update running average
+        # update global min and max
        if self.counter > 0:
-            self.min_val = (self.min_val * self.counter + min_val) / (self.counter + 1)
-            self.max_val = (self.max_val * self.counter + max_val) / (self.counter + 1)
+            self.min_val = torch.min(min_val, self.min_val)
+            self.max_val = torch.max(max_val, self.max_val)
         else:
             self.min_val = min_val
             self.max_val = max_val
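The min_max.py change replaces the running average of observed statistics with a true global min/max. A quick self-contained comparison of the two accumulation strategies (illustrative only, not the observer code itself) shows why the average under-covers the range:

```python
import torch

torch.manual_seed(0)
batches = [torch.randn(1024) for _ in range(8)]

avg_max, global_max = None, None
for i, b in enumerate(batches):
    cur = b.max()
    if i == 0:
        avg_max, global_max = cur, cur
    else:
        avg_max = (avg_max * i + cur) / (i + 1)   # old: running average
        global_max = torch.max(global_max, cur)   # new: true maximum
print(f"running average max: {avg_max:.3f}")
print(f"global max:          {global_max:.3f}")
# the averaged estimate sits below the true peak, so any activation above
# it would be clipped at inference time
```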
diff --git a/src/compressed_tensors/quantization/quant_config.py b/src/compressed_tensors/quantization/quant_config.py
index 7214bc83..a62a79bd 100644
--- a/src/compressed_tensors/quantization/quant_config.py
+++ b/src/compressed_tensors/quantization/quant_config.py
@@ -15,7 +15,6 @@
 from enum import Enum
 from typing import Dict, List, Optional
 
-from pydantic import BaseModel, Field
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from compressed_tensors.quantization.utils import (
     calculate_compression_ratio,
@@ -23,6 +22,7 @@
     iter_named_leaf_modules,
     module_type,
 )
+from pydantic import BaseModel, Field
 from torch.nn import Module
diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py
index c083fc55..ed0f8245 100644
--- a/src/compressed_tensors/quantization/quant_scheme.py
+++ b/src/compressed_tensors/quantization/quant_scheme.py
@@ -14,8 +14,8 @@
 
 from typing import List, Optional
 
-from pydantic import BaseModel
 from compressed_tensors.quantization.quant_args import QuantizationArgs
+from pydantic import BaseModel
 
 __all__ = ["QuantizationScheme"]
diff --git a/tests/quantization/test_quant_args.py b/tests/quantization/test_quant_args.py
index 46a4cc49..c1c84be6 100644
--- a/tests/quantization/test_quant_args.py
+++ b/tests/quantization/test_quant_args.py
@@ -13,12 +13,12 @@
 # limitations under the License.
 
 import pytest
-from pydantic import ValidationError
 from compressed_tensors.quantization import (
     QuantizationArgs,
     QuantizationStrategy,
     QuantizationType,
 )
+from pydantic import ValidationError
 
 
 def test_defaults():
diff --git a/tests/quantization/test_quant_config.py b/tests/quantization/test_quant_config.py
index 68688c36..091be723 100644
--- a/tests/quantization/test_quant_config.py
+++ b/tests/quantization/test_quant_config.py
@@ -14,12 +14,12 @@
 
 import pytest
-from pydantic import ValidationError
 from compressed_tensors.quantization import (
     QuantizationConfig,
     QuantizationScheme,
     QuantizationStatus,
 )
+from pydantic import ValidationError
 
 
 def test_basic_config():
diff --git a/tests/quantization/test_quant_scheme.py b/tests/quantization/test_quant_scheme.py
index 1c198812..14ba9f7e 100644
--- a/tests/quantization/test_quant_scheme.py
+++ b/tests/quantization/test_quant_scheme.py
@@ -13,8 +13,8 @@
 # limitations under the License.
 
 import pytest
-from pydantic import ValidationError
 from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
+from pydantic import ValidationError
 
 
 def test_basic_scheme():
diff --git a/tests/test_bitmask.py b/tests/test_bitmask.py
index 28d29ed0..248580bc 100644
--- a/tests/test_bitmask.py
+++ b/tests/test_bitmask.py
@@ -17,8 +17,8 @@
 
 import pytest
 import torch
-from safetensors.torch import save_file
 from compressed_tensors import BitmaskCompressor, BitmaskConfig, BitmaskTensor
+from safetensors.torch import save_file
 
 
 @pytest.mark.parametrize(