From 774da35f273fa415b976c5f2b157e58b9411f744 Mon Sep 17 00:00:00 2001 From: George Date: Wed, 1 May 2024 10:53:00 -0400 Subject: [PATCH] [Lifecycle][Tests] Feature Branch (#38) * test forward (#16) * test frozen (#17) * test frozen * rename * lifecycle conftest (#21) * test initalize (#18) * test initalize * newline * parametrize weights and inp_act * remove dup * test lifecycle (#19) * test lifecycle * comments * comments * add quantization test * Lifecycle/min max obs (#20) * min max test * add minmax obs * test scale range and min_max update * rebase * rebase * fix * fix --- README.md | 3 +- src/compressed_tensors/README.md | 162 ++++++++++++++++++ .../quantization/lifecycle/forward.py | 11 +- src/compressed_tensors/utils/helpers.py | 45 +++++ .../observers}/quantization/__init__.py | 0 .../quantization/lifecycle/__init__.py | 0 .../quantization/lifecycle/conftest.py | 37 ++++ .../quantization/lifecycle/test_apply.py | 0 .../lifecycle/test_dynamic_lifecycle.py | 0 .../quantization/lifecycle/test_forward.py | 82 +++++++++ .../quantization/lifecycle/test_frozen.py | 47 +++++ .../quantization/lifecycle/test_initialize.py | 79 +++++++++ .../quantization/lifecycle/test_lifecycle.py | 122 +++++++++++++ .../quantization/test_quant_args.py | 0 .../quantization/test_quant_config.py | 0 .../quantization/test_quant_scheme.py | 0 .../quantization/observers/test_min_max.py | 91 ++++++++++ 17 files changed, 672 insertions(+), 7 deletions(-) create mode 100644 src/compressed_tensors/README.md create mode 100644 src/compressed_tensors/utils/helpers.py rename tests/{ => compressed_tensors/quantization/observers}/quantization/__init__.py (100%) rename tests/{ => compressed_tensors/quantization/observers}/quantization/lifecycle/__init__.py (100%) create mode 100644 tests/compressed_tensors/quantization/observers/quantization/lifecycle/conftest.py rename tests/{ => compressed_tensors/quantization/observers}/quantization/lifecycle/test_apply.py (100%) rename tests/{ => compressed_tensors/quantization/observers}/quantization/lifecycle/test_dynamic_lifecycle.py (100%) create mode 100644 tests/compressed_tensors/quantization/observers/quantization/lifecycle/test_forward.py create mode 100644 tests/compressed_tensors/quantization/observers/quantization/lifecycle/test_frozen.py create mode 100644 tests/compressed_tensors/quantization/observers/quantization/lifecycle/test_initialize.py create mode 100644 tests/compressed_tensors/quantization/observers/quantization/lifecycle/test_lifecycle.py rename tests/{ => compressed_tensors/quantization/observers}/quantization/test_quant_args.py (100%) rename tests/{ => compressed_tensors/quantization/observers}/quantization/test_quant_config.py (100%) rename tests/{ => compressed_tensors/quantization/observers}/quantization/test_quant_scheme.py (100%) create mode 100644 tests/compressed_tensors/quantization/observers/test_min_max.py diff --git a/README.md b/README.md index dc2a2b04..c3381e28 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# compressed-tensors +# compressed_tensors This repository extends a [safetensors](https://github.com/huggingface/safetensors) format to efficiently store sparse and/or quantized tensors on disk. `compressed-tensors` format supports multiple compression types to minimize the disk space and facilitate the tensor manipulation. @@ -82,4 +82,3 @@ state_dict = dict(load_compressed("compressed_model.safetensors", compression_co ``` For more in-depth tutorial on bitmask compression, refer to the [notebook](https://github.com/neuralmagic/compressed-tensors/blob/d707c5b84bc3fef164aebdcd97cb6eaa571982f8/examples/bitmask_compression.ipynb). - diff --git a/src/compressed_tensors/README.md b/src/compressed_tensors/README.md new file mode 100644 index 00000000..5b1c8ece --- /dev/null +++ b/src/compressed_tensors/README.md @@ -0,0 +1,162 @@ +# Save/Load Compressed SafeTensors + +## Motivation + +* Reduce disk space by saving in a compressed format for sparse models. Models in this compressed format will be loaded by vLLM for more efficient inference +* Set up the save/load architecture such that we can easily expand to additional compression formats in the future. The config should be human readable so users can understand the compression format at a quick glance + +## SafeTensors File Format + +For each parameter in the uncompressed state_dict, we store the following attributes +needed for decompression in the compressed state_dict: + +* compressed tensor +* bitmask +* uncompressed shape +* row offsets + +```python +# dense +{ + PARAM_NAME: uncompressed_tensor +} + +# compressed +{ + PARAM_NAME.compressed: compressed_tensor # 1d tensor + PARAM_NAME.bitmask: value # 2d bitmask tensor (nrows x (ncols / 8)) + PARAM_NAME.shape: value # uncompressed shape tensor + PARAM_NAME.row_offsets: value # 1d offsets tensor +} +``` + +Config information gets stored in the HF config file +```json +// config.json +{ + "sparsity_config": { + "format": "sparse_bitmask", // "dense_sparsity" for original tensor format + + // informational + "sparsity_structure": "unstructured", // or 2:4, 8:16 etc... + "global_sparsity": "0.5" + } +} +``` + +## Saving/Loading Interface + +Loading in a compressed model requires no interface changes + +```python +from sparseml.transformers.utils import SparseAutoModelForCausalLM + +# should contain model.safetensors or model.safetensors.index.json +model_path = "/PATH/TO/COMPRESSED_MODEL" + +model = SparseAutoModelForCausalLM.from_pretrained( + model_name_or_path=model_path, + **model_kwargs, +) +``` + +Saving a compressed model with an explicitly provided compression config. The config +is saved to the model's `config.json` file. **Note:** the model must have been +initialized with SparseAutoModelForCausalLM.from_pretrained() + +```python +from compressed_tensors import BitmaskConfig + +output_dir = "/PATH/TO/SAVE/COMPRESSED_MODEL" +sparsity_config = BitmaskConfig() + +model.save_pretrained( + save_directory=output_dir, + sparsity_config=sparsity_config, +) +``` + +Saving a compressed model, inferring the config from the model attributes + +```python +model.save_pretrained( + save_directory=output_dir, + save_compressed=True +) +``` + +Saving a model in the dense format. If the model has at least 5% global sparsity a +sparsity config will still be included in `config.json` with format `dense_sparsity` + +```python +model.save_pretrained( + save_directory=output_dir +) +``` + +Saving a model in the dense format, bypassing the sparsity config calculation. When the +`skip_compression_stats` flag is set, no sparsity config will be written to +`config.json` + +```python +model.save_pretrained( + save_directory=output_dir + skip_compression_stats=True +) +``` + +## Enable Compression During One-Shot and Sparse Finetunining +Models that are saved in a supported compressed format on disk will automatically be +decompressed when loaded as input to `sparseml.transformers.oneshot` or +`sparseml.transformers.train` + +To enable compression on save after oneshot or finetuning simply add the +`save_compressed=True` argument to `sparseml.transformers.oneshot` or +`sparseml.transformers.train` + +```python +from sparseml.transformers import train + +train( + save_compressed=True, + model="neuralmagic/TinyLlama-1.1B-Chat-v1.0-pruned2.4", + recipe=RECIPE, + dataset=DATASET +) +``` + + +## Example Code + +Loads a 60% sparse model, compresses it using the inferred bitmask compression, then +reloads the compressed model. + +```python +from sparseml.transformers import SparseAutoModelForCausalLM +from sparseml.utils.pytorch.utils import measure_cuda_memory +import torch + +MODEL_PATH = "zoo:llama2-7b-open_platypus_orca_llama2_pretrain-pruned60" +OUTPUT_PATH = "./test_compress_output" +RECIPE = "zoo:llama2-7b-open_platypus_orca_llama2_pretrain-pruned60" + +torch.cuda.set_device(0) +with measure_cuda_memory() as m: + model = SparseAutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="cuda:0") +print(f"Load dense model peak GPU {m.overall_peak_memory / float(2**30):.4f} GB") + +sparsity_config = getattr(model,"sparsity_config", None) +print(f"Sparsity config before compression: {sparsity_config}") +with measure_cuda_memory() as m: + model.save_pretrained(OUTPUT_PATH, save_compressed=True) +print(f"Save compressed model peak GPU {m.overall_peak_memory / float(2**30):.4f} GB") + +torch.cuda.set_device(1) +with measure_cuda_memory() as m: + model_again = SparseAutoModelForCausalLM.from_pretrained( + OUTPUT_PATH, device_map="cuda:1" + ) +print(f"Load compressed model peak GPU {m.overall_peak_memory / float(2**30):.4f} GB") +sparsity_config = getattr(model_again,"sparsity_config", None) +print(f"Sparsity config after compression: {sparsity_config}") +``` diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index 47dca276..b3165f26 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -21,7 +21,7 @@ from torch.nn import Module -__all__ = ["wrap_module_forward_quantized"] +__all__ = ["wrap_module_forward_quantized", "maybe_calibrate_or_quantize"] @torch.no_grad() @@ -76,14 +76,14 @@ def wrapped_forward(self, *args, **kwargs): if scheme.input_activations is not None: # calibrate and (fake) quantize input activations when applicable - input_ = _maybe_calibrate_or_quantize( + input_ = maybe_calibrate_or_quantize( module, input_, "input", scheme.input_activations ) if scheme.weights is not None: # calibrate and (fake) quantize weights when applicable unquantized_weight = self.weight.data.clone() - self.weight.data = _maybe_calibrate_or_quantize( + self.weight.data = maybe_calibrate_or_quantize( module, self.weight, "weight", scheme.weights ) @@ -94,7 +94,7 @@ def wrapped_forward(self, *args, **kwargs): if scheme.output_activations is not None: # calibrate and (fake) quantize output activations when applicable - output = _maybe_calibrate_or_quantize( + output = maybe_calibrate_or_quantize( module, output, "output", scheme.output_activations ) @@ -110,7 +110,7 @@ def wrapped_forward(self, *args, **kwargs): setattr(module, "forward", bound_wrapped_forward) -def _maybe_calibrate_or_quantize( +def maybe_calibrate_or_quantize( module: Module, value: torch.Tensor, base_name: str, args: "QuantizationArgs" ) -> torch.Tensor: # only run quantized for the included stages @@ -132,6 +132,7 @@ def _maybe_calibrate_or_quantize( if module.quantization_status == QuantizationStatus.CALIBRATION: # calibration mode - get new quant params from observer observer = getattr(module, f"{base_name}_observer") + updated_scale, updated_zero_point = observer(value) # update scale and zero point diff --git a/src/compressed_tensors/utils/helpers.py b/src/compressed_tensors/utils/helpers.py new file mode 100644 index 00000000..ac9ed229 --- /dev/null +++ b/src/compressed_tensors/utils/helpers.py @@ -0,0 +1,45 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Optional + +from compressed_tensors.base import SPARSITY_CONFIG_NAME +from compressed_tensors.compressors import ModelCompressor +from compressed_tensors.config import CompressionConfig +from transformers import AutoConfig + + +__all__ = ["infer_compressor_from_model_config"] + + +def infer_compressor_from_model_config( + pretrained_model_name_or_path: str, +) -> Optional[ModelCompressor]: + """ + Given a path to a model config, extract a sparsity config if it exists and return + the associated ModelCompressor + + :param pretrained_model_name_or_path: path to model config on disk or HF hub + :return: matching compressor if config contains a sparsity config + """ + config = AutoConfig.from_pretrained(pretrained_model_name_or_path) + sparsity_config = getattr(config, SPARSITY_CONFIG_NAME, None) + if sparsity_config is None: + return None + + format = sparsity_config.get("format") + sparsity_config = CompressionConfig.load_from_registry(format, **sparsity_config) + compressor = ModelCompressor.load_from_registry(format, config=sparsity_config) + return compressor diff --git a/tests/quantization/__init__.py b/tests/compressed_tensors/quantization/observers/quantization/__init__.py similarity index 100% rename from tests/quantization/__init__.py rename to tests/compressed_tensors/quantization/observers/quantization/__init__.py diff --git a/tests/quantization/lifecycle/__init__.py b/tests/compressed_tensors/quantization/observers/quantization/lifecycle/__init__.py similarity index 100% rename from tests/quantization/lifecycle/__init__.py rename to tests/compressed_tensors/quantization/observers/quantization/lifecycle/__init__.py diff --git a/tests/compressed_tensors/quantization/observers/quantization/lifecycle/conftest.py b/tests/compressed_tensors/quantization/observers/quantization/lifecycle/conftest.py new file mode 100644 index 00000000..97bf8b0c --- /dev/null +++ b/tests/compressed_tensors/quantization/observers/quantization/lifecycle/conftest.py @@ -0,0 +1,37 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional + +import pytest +from compressed_tensors.quantization.quant_args import QuantizationArgs +from compressed_tensors.quantization.quant_scheme import QuantizationScheme + + +@pytest.fixture +def create_quantization_scheme(): + def quantization_scheme( + targets: List[str], + weights: Optional[QuantizationArgs] = None, + input_activations: Optional[QuantizationArgs] = None, + output_activations: Optional[QuantizationArgs] = None, + ): + return QuantizationScheme( + targets=targets, + weights=weights, + input_activations=input_activations, + output_activations=output_activations, + ) + + return quantization_scheme diff --git a/tests/quantization/lifecycle/test_apply.py b/tests/compressed_tensors/quantization/observers/quantization/lifecycle/test_apply.py similarity index 100% rename from tests/quantization/lifecycle/test_apply.py rename to tests/compressed_tensors/quantization/observers/quantization/lifecycle/test_apply.py diff --git a/tests/quantization/lifecycle/test_dynamic_lifecycle.py b/tests/compressed_tensors/quantization/observers/quantization/lifecycle/test_dynamic_lifecycle.py similarity index 100% rename from tests/quantization/lifecycle/test_dynamic_lifecycle.py rename to tests/compressed_tensors/quantization/observers/quantization/lifecycle/test_dynamic_lifecycle.py diff --git a/tests/compressed_tensors/quantization/observers/quantization/lifecycle/test_forward.py b/tests/compressed_tensors/quantization/observers/quantization/lifecycle/test_forward.py new file mode 100644 index 00000000..00c95d16 --- /dev/null +++ b/tests/compressed_tensors/quantization/observers/quantization/lifecycle/test_forward.py @@ -0,0 +1,82 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest +import torch +from compressed_tensors.quantization.lifecycle.forward import ( + maybe_calibrate_or_quantize, + wrap_module_forward_quantized, +) +from compressed_tensors.quantization.lifecycle.initialize import ( + initialize_module_for_quantization, +) +from compressed_tensors.quantization.quant_args import QuantizationArgs +from compressed_tensors.quantization.quant_config import QuantizationStatus +from torch.nn import Linear + + +def test_wrap_module_forward_quantized(create_quantization_scheme): + num_bits = 8 + quantization_scheme = create_quantization_scheme( + targets=["*"], + weights=QuantizationArgs(num_bits=num_bits, symmetric=True), + input_activations=QuantizationArgs(num_bits=num_bits, symmetric=False), + ) + layer = Linear(4, 4) + + func_forward = layer.forward.__func__ + + # check that the forward call is overwritten + wrap_module_forward_quantized(layer, quantization_scheme) + + assert not func_forward == layer.forward.__func__ + + +@pytest.mark.parametrize( + "quantization_status", ["initialized", "calibration", "frozen"] +) +def test_maybe_calibrate_or_quantize(create_quantization_scheme, quantization_status): + num_bits = 8 + quantization_scheme = create_quantization_scheme( + targets=["*"], + weights=QuantizationArgs(num_bits=num_bits, symmetric=True), + input_activations=QuantizationArgs(num_bits=num_bits, symmetric=True), + ) + quantization_args = QuantizationArgs(num_bits=num_bits, symmetric=True) + layer = Linear(4, 4) + layer.weight.data *= 100 + + initialize_module_for_quantization(layer, quantization_scheme) + layer.quantization_status = QuantizationStatus(quantization_status) + + # only calibration updates the scale and zero-point + if layer.quantization_status == QuantizationStatus.INITIALIZED: + out = maybe_calibrate_or_quantize( + layer, layer.weight.data, "input", quantization_args + ) + assert torch.allclose(out, layer.weight.data) + elif layer.quantization_status == QuantizationStatus.CALIBRATION: + + out = maybe_calibrate_or_quantize( + layer, layer.weight.data, "input", quantization_args + ) + assert torch.allclose(out, layer.weight.data, atol=0.2) + + elif layer.quantization_status == QuantizationStatus.FROZEN: + # scale and zero points are empty -- cannot quantize + with pytest.raises(Exception): + out = maybe_calibrate_or_quantize( + layer, layer.weight.data, "input", quantization_args + ) diff --git a/tests/compressed_tensors/quantization/observers/quantization/lifecycle/test_frozen.py b/tests/compressed_tensors/quantization/observers/quantization/lifecycle/test_frozen.py new file mode 100644 index 00000000..056c6089 --- /dev/null +++ b/tests/compressed_tensors/quantization/observers/quantization/lifecycle/test_frozen.py @@ -0,0 +1,47 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from compressed_tensors.quantization.lifecycle.frozen import freeze_module_quantization +from compressed_tensors.quantization.lifecycle.initialize import ( + initialize_module_for_quantization, +) +from compressed_tensors.quantization.quant_args import QuantizationArgs +from compressed_tensors.quantization.quant_config import QuantizationStatus +from torch.nn import Linear + + +def test_set_module_for_calibration(create_quantization_scheme): + num_bits = 8 + quantization_scheme = create_quantization_scheme( + targets=["*"], + weights=QuantizationArgs(num_bits=num_bits, symmetric=True), + input_activations=QuantizationArgs(num_bits=num_bits, symmetric=False), + ) + + layer = Linear(4, 4) + + initialize_module_for_quantization(layer, quantization_scheme) + layer.quantization_status = QuantizationStatus("calibration") + + # should have both input and weight observer after initalizing + assert hasattr(layer, "input_observer") + assert hasattr(layer, "weight_observer") + + # observers should get deleted after freezing + freeze_module_quantization(layer) + assert not hasattr(layer, "input_observer") + assert not hasattr(layer, "weight_observer") + + assert layer.quantization_status == QuantizationStatus("frozen") diff --git a/tests/compressed_tensors/quantization/observers/quantization/lifecycle/test_initialize.py b/tests/compressed_tensors/quantization/observers/quantization/lifecycle/test_initialize.py new file mode 100644 index 00000000..987b2ae2 --- /dev/null +++ b/tests/compressed_tensors/quantization/observers/quantization/lifecycle/test_initialize.py @@ -0,0 +1,79 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest +from compressed_tensors.quantization.lifecycle.initialize import ( + initialize_module_for_quantization, +) +from compressed_tensors.quantization.quant_args import QuantizationArgs +from compressed_tensors.quantization.quant_config import QuantizationStatus +from torch.nn import Linear + + +NUM_BITS = 8 + + +@pytest.mark.parametrize( + "weights,input_activations", + [ + ( + QuantizationArgs(num_bits=NUM_BITS, symmetric=True), + None, + ), + ( + None, + QuantizationArgs(num_bits=NUM_BITS, symmetric=True), + ), + ( + QuantizationArgs(num_bits=NUM_BITS, symmetric=True), + QuantizationArgs(num_bits=NUM_BITS, symmetric=True), + ), + ], +) +def test_initialize_module_for_quantization( + create_quantization_scheme, weights, input_activations +): + quantization_scheme = create_quantization_scheme( + targets=["*"], + weights=weights, + input_activations=input_activations, + ) + layer = Linear(4, 4) + + assert not hasattr(layer, "quantization_scheme") + assert not hasattr(layer, "quantization_status") + + # add attributes, zero_points and scale + initialize_module_for_quantization(layer, quantization_scheme) + + registered_params = {"weight", "bias"} + if weights is not None: + registered_params.add("weight_scale") + registered_params.add("weight_zero_point") + + if input_activations is not None: + registered_params.add("input_scale") + registered_params.add("input_zero_point") + + for key in layer.state_dict().keys(): + assert key in registered_params + registered_params.remove(key) + + assert len(registered_params) == 0 + + assert hasattr(layer, "quantization_scheme") + assert hasattr(layer, "quantization_status") + + assert layer.quantization_status == QuantizationStatus.INITIALIZED diff --git a/tests/compressed_tensors/quantization/observers/quantization/lifecycle/test_lifecycle.py b/tests/compressed_tensors/quantization/observers/quantization/lifecycle/test_lifecycle.py new file mode 100644 index 00000000..352fcb4d --- /dev/null +++ b/tests/compressed_tensors/quantization/observers/quantization/lifecycle/test_lifecycle.py @@ -0,0 +1,122 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from copy import deepcopy + +import torch +from compressed_tensors.quantization.lifecycle.calibration import ( + set_module_for_calibration, +) +from compressed_tensors.quantization.lifecycle.frozen import freeze_module_quantization +from compressed_tensors.quantization.lifecycle.initialize import ( + initialize_module_for_quantization, +) +from compressed_tensors.quantization.quant_args import QuantizationArgs +from compressed_tensors.quantization.quant_config import QuantizationStatus +from torch.nn import Linear + + +def test_lifecyle(create_quantization_scheme): + num_bits = 8 + + quantization_scheme = create_quantization_scheme( + input_activations=QuantizationArgs(num_bits=num_bits, symmetric=False), + weights=QuantizationArgs(num_bits=num_bits, symmetric=True), + targets=["*"], + ) + + layer = Linear(4, 4) + layer.weight.data *= 100 + + # updated layer keys check + expected_layer_keys = {"weight", "bias"} + for key in layer.state_dict().keys(): + expected_layer_keys.remove(key) + assert len(expected_layer_keys) == 0 + + # over write forward pass and register zero_point and scale + initialize_module_for_quantization(layer, quantization_scheme) + expected_layer_keys = { + "input_scale", + "input_zero_point", + "weight_scale", + "weight_zero_point", + "weight", + "bias", + } + for key in layer.state_dict().keys(): + expected_layer_keys.remove(key) + assert len(expected_layer_keys) == 0 + + # should have both input and weight observer after initalizing + assert hasattr(layer, "input_observer") + assert hasattr(layer, "weight_observer") + + assert hasattr(layer, "quantization_scheme") + assert hasattr(layer, "quantization_status") + assert layer.quantization_status == QuantizationStatus.INITIALIZED + + set_module_for_calibration(layer) + assert layer.quantization_status == QuantizationStatus.CALIBRATION + + # do a calibration step + assert torch.numel(layer.input_zero_point.data) == 0 + assert torch.numel(layer.input_scale) == 0 + assert torch.numel(layer.weight_scale) == 0 + assert torch.numel(layer.weight_zero_point) == 0 + + layer(torch.randn(4, 4)) + + # zero-points and scale should be updated after forward pass + assert torch.numel(layer.input_zero_point.data) > 0 + assert torch.numel(layer.input_scale) > 0 + assert torch.numel(layer.weight_scale) > 0 + assert torch.numel(layer.weight_zero_point) > 0 + + # symmetric zero points should center at 0 + assert layer.weight_zero_point.data == 0 + + # check high and low bound of the weights + assert torch.all(layer.weight.data >= -128) and torch.all(layer.weight.data <= 127) + + initialized_layer_input_zero_point = deepcopy(layer.input_zero_point) + initialized_layer_input_scale = deepcopy(layer.input_scale) + initialized_layer_weight_scale = deepcopy(layer.weight_scale) + # calibrate the layers with each iteration + for _ in range(10): + layer(torch.randn(4, 4)) + + assert initialized_layer_input_zero_point != layer.input_zero_point + assert initialized_layer_input_scale != layer.input_scale + assert initialized_layer_weight_scale == layer.weight_scale + + # check quantization f_q(x) is applied after frozen without update + input_check_for_quant = torch.randn(4, 4) + out_calibration = layer(input_check_for_quant) + + layer_before_freeze_input_zero_point = deepcopy(layer.input_zero_point) + layer_before_freeze_input_scale = deepcopy(layer.input_scale) + layer_before_freeze_weight_scale = deepcopy(layer.weight_scale) + + # Freeze, no update after any forward pass + freeze_module_quantization(layer) + + for _ in range(10): + layer(torch.randn(4, 4)) + assert layer_before_freeze_input_zero_point == layer.input_zero_point + assert layer_before_freeze_input_scale == layer.input_scale + assert layer_before_freeze_weight_scale == layer.weight_scale + + # check that the same quantization is applied as calibration to frozen + assert torch.all(out_calibration == layer(input_check_for_quant)) diff --git a/tests/quantization/test_quant_args.py b/tests/compressed_tensors/quantization/observers/quantization/test_quant_args.py similarity index 100% rename from tests/quantization/test_quant_args.py rename to tests/compressed_tensors/quantization/observers/quantization/test_quant_args.py diff --git a/tests/quantization/test_quant_config.py b/tests/compressed_tensors/quantization/observers/quantization/test_quant_config.py similarity index 100% rename from tests/quantization/test_quant_config.py rename to tests/compressed_tensors/quantization/observers/quantization/test_quant_config.py diff --git a/tests/quantization/test_quant_scheme.py b/tests/compressed_tensors/quantization/observers/quantization/test_quant_scheme.py similarity index 100% rename from tests/quantization/test_quant_scheme.py rename to tests/compressed_tensors/quantization/observers/quantization/test_quant_scheme.py diff --git a/tests/compressed_tensors/quantization/observers/test_min_max.py b/tests/compressed_tensors/quantization/observers/test_min_max.py new file mode 100644 index 00000000..ee5a63bb --- /dev/null +++ b/tests/compressed_tensors/quantization/observers/test_min_max.py @@ -0,0 +1,91 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest +import torch +from compressed_tensors.quantization.quant_args import QuantizationArgs + + +@pytest.mark.parametrize( + "symmetric,expected_scale,expected_zero_point", + [ + (True, 0.0078, 0), + (False, 0.0039, -128), + ], +) +def test_min_max_observer(symmetric, expected_scale, expected_zero_point): + tensor = torch.tensor([1, 1, 1, 1, 1]) + num_bits = 8 + weights = QuantizationArgs(num_bits=num_bits, symmetric=symmetric) + + observer = weights.get_observer() + scale, zero_point = observer(tensor) + + assert round(scale.item(), 4) == expected_scale + assert round(zero_point.item(), 4) == expected_zero_point + + +def test_min_max_observer_symmetric_scale_range(): + tensor = torch.rand(4, 4) + tensor *= 127 + + num_bits = 8 + weights = QuantizationArgs(num_bits=num_bits, symmetric=True) + + observer = weights.get_observer() + scale, zero_point = observer(tensor) + + # if symmetric, max symmetric_range = abs(-128) / 255 + assert round(scale.item(), 4) <= 1.0039 + assert round(zero_point.item(), 4) == 0 + + +def test_min_max_observer_value_update(): + inp = torch.tensor([1, 1, 1, 1, 1]) + inp_update_max = torch.tensor([127, 1, 1, 1, 1]) + inp_update_min = torch.tensor([-128, 1, 1, 1, 1]) + + delta = 1e-6 + + # udpate the min, max twice total + tensors = [ + inp, + inp, + inp_update_max, # update max + inp, + inp_update_min, # update min + ] + + tensor = inp + num_bits = 8 + weights = QuantizationArgs(num_bits=num_bits, symmetric=True) + + observer = weights.get_observer() + curr_max = 1 + curr_min = 1 + for i, tensor in enumerate(tensors): + observer(tensor) + curr_max = max(observer.max_val, curr_max) + curr_min = min(observer.min_val, curr_max) + + if i < 2: + assert curr_max == 1 + assert curr_min == 1 + elif i < 4: + assert abs(curr_max - 2.2600) < delta + assert curr_min == 1 + else: + assert abs(curr_max - 2.2600) < delta + assert abs(curr_min - (-0.2900)) < delta