[Lifecycle][Tests] Feature Branch (#38)
* test forward (#16)
* test frozen (#17)
* test frozen
* rename
* lifecycle conftest (#21)
* test initialize (#18)
* test initialize
* newline
* parametrize weights and inp_act
* remove dup
* test lifecycle (#19)
* test lifecycle
* comments
* comments
* add quantization test
* Lifecycle/min max obs (#20)
* min max test
* add minmax obs
* test scale range and min_max update
* rebase
* rebase
* fix
* fix
Showing 17 changed files with 672 additions and 7 deletions.
@@ -0,0 +1,162 @@
# Save/Load Compressed SafeTensors

## Motivation

* Reduce disk space by saving sparse models in a compressed format. Models in this compressed format will be loaded by vLLM for more efficient inference.
* Set up the save/load architecture so that we can easily expand to additional compression formats in the future. The config should be human readable so users can understand the compression format at a quick glance.

## SafeTensors File Format

For each parameter in the uncompressed state_dict, we store the following attributes
needed for decompression in the compressed state_dict:

* compressed tensor
* bitmask
* uncompressed shape
* row offsets

```python
# dense
{
    PARAM_NAME: uncompressed_tensor
}

# compressed
{
    PARAM_NAME.compressed: compressed_tensor,  # 1d tensor
    PARAM_NAME.bitmask: value,                 # 2d bitmask tensor (nrows x (ncols / 8))
    PARAM_NAME.shape: value,                   # uncompressed shape tensor
    PARAM_NAME.row_offsets: value,             # 1d offsets tensor
}
```
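
To make the layout concrete, here is a minimal sketch of how a dense 2d tensor could be packed into and recovered from this representation. It is illustrative only: `bitmask_compress` and `bitmask_decompress` are hypothetical helpers, not the library's API, and NumPy-style little-endian bit packing is an assumption.

```python
import numpy as np
import torch


def bitmask_compress(dense: torch.Tensor):
    """Pack a 2d dense tensor into (values, bitmask, shape, row_offsets)."""
    mask = dense != 0
    values = dense[mask]  # 1d tensor of nonzero values, row-major order
    row_counts = mask.sum(dim=1)
    # start index of each row's values within the flat values tensor
    row_offsets = torch.cumsum(row_counts, dim=0) - row_counts
    # pack 8 mask bits per byte along the column axis: nrows x ceil(ncols / 8)
    bitmask = torch.from_numpy(np.packbits(mask.numpy(), axis=-1, bitorder="little"))
    return values, bitmask, torch.tensor(dense.shape), row_offsets


def bitmask_decompress(values, bitmask, shape, row_offsets):
    """Invert bitmask_compress; row_offsets enable per-row random access."""
    nrows, ncols = int(shape[0]), int(shape[1])
    unpacked = np.unpackbits(bitmask.numpy(), axis=-1, bitorder="little")
    mask = torch.from_numpy(unpacked)[:, :ncols].bool()
    dense = torch.zeros(nrows, ncols, dtype=values.dtype)
    dense[mask] = values  # boolean assignment fills row-major, matching compress
    return dense


# round trip on a toy sparse matrix
x = torch.tensor([[0.0, 1.5], [2.5, 0.0]])
assert torch.equal(bitmask_decompress(*bitmask_compress(x)), x)
```
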
Config information gets stored in the HF config file:

```json
// config.json
{
    "sparsity_config": {
        "format": "sparse_bitmask", // "dense_sparsity" for original tensor format

        // informational
        "sparsity_structure": "unstructured", // or "2:4", "8:16", etc.
        "global_sparsity": "0.5"
    }
}
```

## Saving/Loading Interface

Loading a compressed model requires no interface changes:

```python
from sparseml.transformers.utils import SparseAutoModelForCausalLM

# should contain model.safetensors or model.safetensors.index.json
model_path = "/PATH/TO/COMPRESSED_MODEL"

model = SparseAutoModelForCausalLM.from_pretrained(
    model_name_or_path=model_path,
    **model_kwargs,
)
```

Saving a compressed model with an explicitly provided compression config. The config
is saved to the model's `config.json` file. **Note:** the model must have been
initialized with `SparseAutoModelForCausalLM.from_pretrained()`.

```python
from compressed_tensors import BitmaskConfig

output_dir = "/PATH/TO/SAVE/COMPRESSED_MODEL"
sparsity_config = BitmaskConfig()

model.save_pretrained(
    save_directory=output_dir,
    sparsity_config=sparsity_config,
)
```

Saving a compressed model, inferring the config from the model attributes:

```python
model.save_pretrained(
    save_directory=output_dir,
    save_compressed=True,
)
```

Saving a model in the dense format. If the model has at least 5% global sparsity, a
sparsity config will still be included in `config.json` with format `dense_sparsity`:

```python
model.save_pretrained(
    save_directory=output_dir,
)
```

Saving a model in the dense format, bypassing the sparsity config calculation. When the
`skip_compression_stats` flag is set, no sparsity config will be written to
`config.json`:

```python
model.save_pretrained(
    save_directory=output_dir,
    skip_compression_stats=True,
)
```

## Enable Compression During One-Shot and Sparse Finetuning

Models that are saved in a supported compressed format on disk will automatically be
decompressed when loaded as input to `sparseml.transformers.oneshot` or
`sparseml.transformers.train`.

To enable compression on save after one-shot or finetuning, simply add the
`save_compressed=True` argument to `sparseml.transformers.oneshot` or
`sparseml.transformers.train`:

```python
from sparseml.transformers import train

train(
    save_compressed=True,
    model="neuralmagic/TinyLlama-1.1B-Chat-v1.0-pruned2.4",
    recipe=RECIPE,
    dataset=DATASET,
)
```
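
The one-shot entrypoint takes the same flag. A minimal sketch, assuming `oneshot` accepts the same keywords as `train` and reusing the `RECIPE` and `DATASET` placeholders from above:

```python
from sparseml.transformers import oneshot

oneshot(
    save_compressed=True,  # compress the model on save, as with train()
    model="neuralmagic/TinyLlama-1.1B-Chat-v1.0-pruned2.4",
    recipe=RECIPE,
    dataset=DATASET,
)
```
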
## Example Code

Loads a 60% sparse model, compresses it using the inferred bitmask compression, then
reloads the compressed model:

```python
from sparseml.transformers import SparseAutoModelForCausalLM
from sparseml.utils.pytorch.utils import measure_cuda_memory
import torch

MODEL_PATH = "zoo:llama2-7b-open_platypus_orca_llama2_pretrain-pruned60"
OUTPUT_PATH = "./test_compress_output"
RECIPE = "zoo:llama2-7b-open_platypus_orca_llama2_pretrain-pruned60"

# load the dense model and measure peak GPU memory
torch.cuda.set_device(0)
with measure_cuda_memory() as m:
    model = SparseAutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="cuda:0")
print(f"Load dense model peak GPU {m.overall_peak_memory / float(2**30):.4f} GB")

# save with inferred bitmask compression
sparsity_config = getattr(model, "sparsity_config", None)
print(f"Sparsity config before compression: {sparsity_config}")
with measure_cuda_memory() as m:
    model.save_pretrained(OUTPUT_PATH, save_compressed=True)
print(f"Save compressed model peak GPU {m.overall_peak_memory / float(2**30):.4f} GB")

# reload the compressed model on a second device
torch.cuda.set_device(1)
with measure_cuda_memory() as m:
    model_again = SparseAutoModelForCausalLM.from_pretrained(
        OUTPUT_PATH, device_map="cuda:1"
    )
print(f"Load compressed model peak GPU {m.overall_peak_memory / float(2**30):.4f} GB")
sparsity_config = getattr(model_again, "sparsity_config", None)
print(f"Sparsity config after compression: {sparsity_config}")
```
@@ -0,0 +1,45 @@
```python
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Optional

from compressed_tensors.base import SPARSITY_CONFIG_NAME
from compressed_tensors.compressors import ModelCompressor
from compressed_tensors.config import CompressionConfig
from transformers import AutoConfig


__all__ = ["infer_compressor_from_model_config"]


def infer_compressor_from_model_config(
    pretrained_model_name_or_path: str,
) -> Optional[ModelCompressor]:
    """
    Given a path to a model config, extract a sparsity config if it exists and return
    the associated ModelCompressor

    :param pretrained_model_name_or_path: path to model config on disk or HF hub
    :return: matching compressor if config contains a sparsity config
    """
    config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
    sparsity_config = getattr(config, SPARSITY_CONFIG_NAME, None)
    if sparsity_config is None:
        return None

    format = sparsity_config.get("format")
    sparsity_config = CompressionConfig.load_from_registry(format, **sparsity_config)
    compressor = ModelCompressor.load_from_registry(format, config=sparsity_config)
    return compressor
```
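
A brief usage sketch for this helper. The import path and checkpoint path are illustrative assumptions, not confirmed by this commit:

```python
# hypothetical module path; adjust to wherever this helper lives in the package
from compressed_tensors.utils import infer_compressor_from_model_config

compressor = infer_compressor_from_model_config("/PATH/TO/COMPRESSED_MODEL")
if compressor is None:
    print("config.json carries no sparsity_config; checkpoint is dense")
else:
    print(f"decompress with {type(compressor).__name__}")
```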
File renamed without changes.
File renamed without changes.
37 changes: 37 additions & 0 deletions
tests/compressed_tensors/quantization/observers/quantization/lifecycle/conftest.py
@@ -0,0 +1,37 @@
```python
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional

import pytest
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.quant_scheme import QuantizationScheme


@pytest.fixture
def create_quantization_scheme():
    def quantization_scheme(
        targets: List[str],
        weights: Optional[QuantizationArgs] = None,
        input_activations: Optional[QuantizationArgs] = None,
        output_activations: Optional[QuantizationArgs] = None,
    ):
        return QuantizationScheme(
            targets=targets,
            weights=weights,
            input_activations=input_activations,
            output_activations=output_activations,
        )

    return quantization_scheme
```
File renamed without changes.
File renamed without changes.
82 changes: 82 additions & 0 deletions
tests/compressed_tensors/quantization/observers/quantization/lifecycle/test_forward.py
@@ -0,0 +1,82 @@
```python
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import pytest
import torch
from compressed_tensors.quantization.lifecycle.forward import (
    maybe_calibrate_or_quantize,
    wrap_module_forward_quantized,
)
from compressed_tensors.quantization.lifecycle.initialize import (
    initialize_module_for_quantization,
)
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.quant_config import QuantizationStatus
from torch.nn import Linear


def test_wrap_module_forward_quantized(create_quantization_scheme):
    num_bits = 8
    quantization_scheme = create_quantization_scheme(
        targets=["*"],
        weights=QuantizationArgs(num_bits=num_bits, symmetric=True),
        input_activations=QuantizationArgs(num_bits=num_bits, symmetric=False),
    )
    layer = Linear(4, 4)

    func_forward = layer.forward.__func__

    # check that the forward call is overwritten
    wrap_module_forward_quantized(layer, quantization_scheme)

    assert not func_forward == layer.forward.__func__


@pytest.mark.parametrize(
    "quantization_status", ["initialized", "calibration", "frozen"]
)
def test_maybe_calibrate_or_quantize(create_quantization_scheme, quantization_status):
    num_bits = 8
    quantization_scheme = create_quantization_scheme(
        targets=["*"],
        weights=QuantizationArgs(num_bits=num_bits, symmetric=True),
        input_activations=QuantizationArgs(num_bits=num_bits, symmetric=True),
    )
    quantization_args = QuantizationArgs(num_bits=num_bits, symmetric=True)
    layer = Linear(4, 4)
    layer.weight.data *= 100

    initialize_module_for_quantization(layer, quantization_scheme)
    layer.quantization_status = QuantizationStatus(quantization_status)

    # only calibration updates the scale and zero point
    if layer.quantization_status == QuantizationStatus.INITIALIZED:
        out = maybe_calibrate_or_quantize(
            layer, layer.weight.data, "input", quantization_args
        )
        assert torch.allclose(out, layer.weight.data)
    elif layer.quantization_status == QuantizationStatus.CALIBRATION:
        out = maybe_calibrate_or_quantize(
            layer, layer.weight.data, "input", quantization_args
        )
        assert torch.allclose(out, layer.weight.data, atol=0.2)
    elif layer.quantization_status == QuantizationStatus.FROZEN:
        # scale and zero point are empty -- cannot quantize
        with pytest.raises(Exception):
            out = maybe_calibrate_or_quantize(
                layer, layer.weight.data, "input", quantization_args
            )
```