From a5cfaa1b621360aeda5a226bbfb8a6da6c5033ca Mon Sep 17 00:00:00 2001 From: dbogunowicz Date: Fri, 19 Apr 2024 15:54:16 +0000 Subject: [PATCH] simplify UX --- README.md | 27 ++++--- src/compressed_tensors/README.md | 2 +- src/compressed_tensors/compressors/base.py | 4 +- src/compressed_tensors/compressors/dense.py | 3 +- .../compressors/sparse_bitmask.py | 3 +- src/compressed_tensors/config/base.py | 8 +- src/compressed_tensors/config/dense.py | 6 +- .../config/sparse_bitmask.py | 6 +- src/compressed_tensors/utils/helpers.py | 45 +++++------ tests/test_registry.py | 10 ++- tests/test_utils/test_helpers.py | 78 +++++++++++-------- 11 files changed, 107 insertions(+), 85 deletions(-) diff --git a/README.md b/README.md index 294c6a2f..cf292014 100644 --- a/README.md +++ b/README.md @@ -36,23 +36,28 @@ pip install -e . ### Saving -The function `save_compressed` returns an optional `compression_config` (if compression has been applied). It can be used to inspect the applied compression. +The function `save_compressed` uses the `compression_format` argument to apply compression to tensors. +The function `load_compressed` reverses the process: converts the compressed weights on disk to decompressed weights in device memory. ```python -from compressed_tensors import save_compressed +from compressed_tensors import save_compressed, load_compressed, BitmaskConfig from torch import Tensor +from typing import Dict -tensors: Dict[str, Tensor] = ... 
-compression_config: Dict = save_compressed(tensors, "model.safetensors") -``` +# the example BitmaskConfig method efficiently compresses +# tensors with a large number of zero entries +compression_config = BitmaskConfig() -### Loading -```python -from compressed_tensors import load_compressed -from torch import Tensor +tensors: Dict[str, Tensor] = {"tensor_1": Tensor( + [[0.0, 0.0, 0.0], + [1.0, 1.0, 1.0]] +)} +# compress tensors using the BitmaskConfig compression format (save them efficiently on disk) +save_compressed(tensors, "model.safetensors", compression_format=compression_config.format) -tensors: Dict[str, Tensor] = load_compressed("model.safetensors", device="cpu") +# decompress tensors (load the uncompressed representation to device memory) +tensors = load_compressed("model.safetensors", device="cpu", compression_config = compression_config) ``` ## Benefits @@ -87,7 +92,7 @@ The library provides pathways to automatically add the config information to the ```json // config.json { - "sparsity_config": { + "compression_config": { "format": "sparse_bitmask", // "dense_sparsity" for the original tensor format // Informational diff --git a/src/compressed_tensors/README.md b/src/compressed_tensors/README.md index 5b1c8ece..107eca65 100644 --- a/src/compressed_tensors/README.md +++ b/src/compressed_tensors/README.md @@ -34,7 +34,7 @@ Config information gets stored in the HF config file ```json // config.json { - "sparsity_config": { + "compression_config": { "format": "sparse_bitmask", // "dense_sparsity" for original tensor format // informational diff --git a/src/compressed_tensors/compressors/base.py b/src/compressed_tensors/compressors/base.py index 9d7033d3..50c34da8 100644 --- a/src/compressed_tensors/compressors/base.py +++ b/src/compressed_tensors/compressors/base.py @@ -13,7 +13,7 @@ # limitations under the License. 
import operator -from typing import Dict, Generator, Tuple +from typing import Dict, Generator, Optional, Tuple from compressed_tensors.base import CONFIG_NAME from compressed_tensors.config import CompressionConfig @@ -33,7 +33,7 @@ class ModelCompressor(RegistryMixin): :param config: config specifying compression parameters """ - def __init__(self, config: CompressionConfig): + def __init__(self, config: Optional[CompressionConfig] = None): self.config = config def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]: diff --git a/src/compressed_tensors/compressors/dense.py b/src/compressed_tensors/compressors/dense.py index 2fa1603c..97a62a81 100644 --- a/src/compressed_tensors/compressors/dense.py +++ b/src/compressed_tensors/compressors/dense.py @@ -15,10 +15,11 @@ from typing import Dict, Generator, Tuple from compressed_tensors.compressors import ModelCompressor +from compressed_tensors.config import CompressionFormat from torch import Tensor -@ModelCompressor.register(name="dense_sparsity") +@ModelCompressor.register(name=CompressionFormat.dense_sparsity.value) class DenseCompressor(ModelCompressor): """ Identity compressor for dense models, returns the original state_dict diff --git a/src/compressed_tensors/compressors/sparse_bitmask.py b/src/compressed_tensors/compressors/sparse_bitmask.py index cb19b633..4a946fb9 100644 --- a/src/compressed_tensors/compressors/sparse_bitmask.py +++ b/src/compressed_tensors/compressors/sparse_bitmask.py @@ -18,6 +18,7 @@ import numpy import torch from compressed_tensors.compressors import ModelCompressor +from compressed_tensors.config import CompressionFormat from compressed_tensors.utils import get_nested_weight_mappings, merge_names from safetensors import safe_open from torch import Tensor @@ -36,7 +37,7 @@ _LOGGER: logging.Logger = logging.getLogger(__name__) -@ModelCompressor.register(name="sparse_bitmask") +@ModelCompressor.register(name=CompressionFormat.sparse_bitmask.value) class 
BitmaskCompressor(ModelCompressor): """ Compression for sparse models using bitmasks. Non-zero weights are stored in a 1d diff --git a/src/compressed_tensors/config/base.py b/src/compressed_tensors/config/base.py index f58b11f8..96778995 100644 --- a/src/compressed_tensors/config/base.py +++ b/src/compressed_tensors/config/base.py @@ -12,13 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +from enum import Enum from typing import Optional from compressed_tensors.registry import RegistryMixin from pydantic import BaseModel -__all__ = ["CompressionConfig"] +__all__ = ["CompressionConfig", "CompressionFormat"] + + +class CompressionFormat(Enum): + dense_sparsity = "dense-sparsity" + sparse_bitmask = "sparse-bitmask" class CompressionConfig(RegistryMixin, BaseModel): diff --git a/src/compressed_tensors/config/dense.py b/src/compressed_tensors/config/dense.py index aa23220c..0a18309e 100644 --- a/src/compressed_tensors/config/dense.py +++ b/src/compressed_tensors/config/dense.py @@ -14,13 +14,13 @@ from typing import Optional -from compressed_tensors.config import CompressionConfig +from compressed_tensors.config import CompressionConfig, CompressionFormat __all__ = ["DenseSparsityConfig"] -@CompressionConfig.register(name="dense_sparsity") +@CompressionConfig.register(name=CompressionFormat.dense_sparsity.value) class DenseSparsityConfig(CompressionConfig): """ Identity configuration for storing a sparse model in @@ -31,6 +31,6 @@ class DenseSparsityConfig(CompressionConfig): "unstructured", "2:4", "8:16" etc """ - format: str = "dense_sparsity" + format: str = CompressionFormat.dense_sparsity.value global_sparsity: Optional[float] = 0.0 sparsity_structure: Optional[str] = "unstructured" diff --git a/src/compressed_tensors/config/sparse_bitmask.py b/src/compressed_tensors/config/sparse_bitmask.py index 9b9cf211..9d2015f3 100644 --- a/src/compressed_tensors/config/sparse_bitmask.py +++ 
b/src/compressed_tensors/config/sparse_bitmask.py @@ -14,13 +14,13 @@ from typing import Optional -from compressed_tensors.config.base import CompressionConfig +from compressed_tensors.config import CompressionConfig, CompressionFormat __all__ = ["BitmaskConfig"] -@CompressionConfig.register(name="sparse_bitmask") +@CompressionConfig.register(name=CompressionFormat.sparse_bitmask.value) class BitmaskConfig(CompressionConfig): """ Configuration for storing a sparse model using @@ -31,6 +31,6 @@ class BitmaskConfig(CompressionConfig): "unstructured", "2:4", "8:16" etc """ - format: str = "sparse_bitmask" + format: str = CompressionFormat.sparse_bitmask.value global_sparsity: Optional[float] = 0.0 sparsity_structure: Optional[str] = "unstructured" diff --git a/src/compressed_tensors/utils/helpers.py b/src/compressed_tensors/utils/helpers.py index 8de3c233..1c8dd29f 100644 --- a/src/compressed_tensors/utils/helpers.py +++ b/src/compressed_tensors/utils/helpers.py @@ -13,11 +13,11 @@ # limitations under the License. from pathlib import Path -from typing import Dict, Optional, Union +from typing import Dict, Literal, Optional, Union from compressed_tensors.base import CONFIG_NAME from compressed_tensors.compressors import ModelCompressor -from compressed_tensors.config import CompressionConfig +from compressed_tensors.config import CompressionConfig, CompressionFormat from safetensors import safe_open from safetensors.torch import save_file from torch import Tensor @@ -51,46 +51,46 @@ def infer_compressor_from_model_config( def save_compressed( tensors: Dict[str, Tensor], save_path: Union[str, Path], - compression_config: Optional[CompressionConfig] = None, -) -> Optional[CompressionConfig]: + compression_format: Optional[ + Literal[CompressionFormat.sparse_bitmask, CompressionFormat.dense_sparsity] + ] = None, +): """ Save compressed tensors to disk. If tensors are not compressed, save them as is. 
:param tensors: dictionary of tensors to compress :param save_path: path to save compressed tensors - :param compression_config: compression config to use for compressing tensors. - Can be either inferred from tensors or provided explicitly + :param compression_format: compression format used for the tensors :return: compression config, if tensors were compressed - None otherwise """ if tensors is None or len(tensors) == 0: raise ValueError("No tensors or empty tensors provided to compress") - # create compression config if not provided - # TODO: Not implemented, need to get this in ASAP - # compression_config = compression_config or infer_compression_config(tensors) - - if compression_config is None: + if compression_format is None: # no compression applied save_file(tensors, save_path) - return None + return + + if not ( + compression_format in ModelCompressor.registered_names() + or compression_format in ModelCompressor.registered_aliases() + ): + raise ValueError( + f"Unknown compression format: {compression_format}. 
" + f"Must be one of {set(ModelCompressor.registered_names() + ModelCompressor.registered_aliases())}" # noqa E501 + ) # compress - compression_format = compression_config.format - compressor = ModelCompressor.load_from_registry( - compression_format, config=compression_config - ) + compressor = ModelCompressor.load_from_registry(compression_format) # save compressed tensors compressed_tensors = compressor.compress(tensors) save_file(compressed_tensors, save_path) - # return compression_config as dict - return {CONFIG_NAME: compression_config.model_dump(exclude_unset=True)} - def load_compressed( compressed_tensors: Union[str, Path], - compression_config: Optional[CompressionConfig] = None, + compression_config: CompressionConfig = None, device: Optional[str] = "cpu", ) -> Dict[str, Tensor]: """ @@ -99,7 +99,6 @@ def load_compressed( :param compressed_tensors: path to compressed tensors :param compression_config: compression config to use for decompressing tensors. - Can be either inferred from tensors or provided explicitly. :param device: device to move tensors to. If None, tensors are loaded on CPU. 
:return decompressed tensors """ @@ -107,10 +106,6 @@ def load_compressed( if compressed_tensors is None or not Path(compressed_tensors).exists(): raise ValueError("No compressed tensors provided to load") - # create compression config if not provided - # TODO: Not implemented, need to get this in ASAP - # compression_config = compression_config or infer_compression_config(tensors) - if compression_config is None: # no compression applied tensors = {} diff --git a/tests/test_registry.py b/tests/test_registry.py index a183d77d..ffe66b85 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -17,6 +17,7 @@ BitmaskCompressor, BitmaskConfig, CompressionConfig, + CompressionFormat, DenseCompressor, DenseSparsityConfig, ModelCompressor, @@ -26,8 +27,8 @@ @pytest.mark.parametrize( "name,type", [ - ["sparse_bitmask", BitmaskConfig], - ["dense_sparsity", DenseSparsityConfig], + [CompressionFormat.sparse_bitmask.value, BitmaskConfig], + [CompressionFormat.dense_sparsity.value, DenseSparsityConfig], ], ) def test_configs(name, type): @@ -38,7 +39,10 @@ def test_configs(name, type): @pytest.mark.parametrize( "name,type", - [["sparse_bitmask", BitmaskCompressor], ["dense_sparsity", DenseCompressor]], + [ + [CompressionFormat.sparse_bitmask.value, BitmaskCompressor], + [CompressionFormat.dense_sparsity.value, DenseCompressor], + ], ) def test_compressors(name, type): compressor = ModelCompressor.load_from_registry( diff --git a/tests/test_utils/test_helpers.py b/tests/test_utils/test_helpers.py index 54b6b1f9..f643233c 100644 --- a/tests/test_utils/test_helpers.py +++ b/tests/test_utils/test_helpers.py @@ -19,47 +19,44 @@ @pytest.fixture -def tensors_and_config_sparse(): +def tensors(): tensors = {"tensor_1": torch.Tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])} - expected_config_json = { - "compression_config": { - "format": "sparse_bitmask", - "global_sparsity": ( - tensors["tensor_1"].sum() / tensors["tensor_1"].numel() - ).item(), - "sparsity_structure": 
"unstructured", - } - } - return tensors, expected_config_json - - -@pytest.fixture -def tensors_dense(): - tensors = {"tensor_1": torch.Tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])} return tensors -def test_save_compressed_sparse(tmp_path, tensors_and_config_sparse): - tensors, expected_config_json = tensors_and_config_sparse - - config_json = save_compressed( +def test_save_compressed_sparse_bitmask(tmp_path, tensors): + save_compressed( tensors, - compression_config=BitmaskConfig(**expected_config_json["compression_config"]), + compression_format="sparse-bitmask", save_path=tmp_path / "model.safetensors", ) assert (tmp_path / "model.safetensors").exists() - assert config_json == expected_config_json -def test_save_compressed_dense(tmp_path, tensors_dense): - tensors = tensors_dense +def test_save_compressed_dense_sparsity(tmp_path, tensors): + save_compressed( + tensors, + compression_format="dense-sparsity", + save_path=tmp_path / "model.safetensors", + ) + assert (tmp_path / "model.safetensors").exists() - config_json = save_compressed( + +def test_save_compressed_no_compression(tmp_path, tensors): + save_compressed( tensors, save_path=tmp_path / "model.safetensors", ) assert (tmp_path / "model.safetensors").exists() - assert config_json is None + + +def test_save_compressed_rubbish_compression_format(tmp_path, tensors): + with pytest.raises(Exception): + save_compressed( + tensors, + compression_format="this_is_not_a_valid_format", + save_path=tmp_path / "model.safetensors", + ) def test_save_compressed_empty(): @@ -71,24 +68,37 @@ def test_save_compressed_empty(): save_compressed(None, "") -def test_load_compressed_sparse(tmp_path, tensors_and_config_sparse): - tensors, expected_config_json = tensors_and_config_sparse - compression_config = BitmaskConfig(**expected_config_json["compression_config"]) +def test_load_compressed_sparse_bitmask(tmp_path, tensors): save_compressed( tensors, - compression_config=compression_config, + 
compression_format="sparse-bitmask", save_path=tmp_path / "model.safetensors", ) + compression_config = BitmaskConfig( + format="sparse-bitmask", + ) loaded_tensors = load_compressed(tmp_path / "model.safetensors", compression_config) for key in tensors: assert torch.allclose(tensors[key], loaded_tensors[key]) -def test_load_compressed_dense(tmp_path, tensors_dense): +def test_load_compressed_dense_sparsity(tmp_path, tensors): save_compressed( - tensors_dense, + tensors, + compression_format="dense-sparsity", + save_path=tmp_path / "model.safetensors", + ) + compression_config = BitmaskConfig(format="dense-sparsity") + loaded_tensors = load_compressed(tmp_path / "model.safetensors", compression_config) + # loaded_tensors is empty -> decompression returns empty dict + assert not loaded_tensors + + +def test_load_compressed_no_compression(tmp_path, tensors): + save_compressed( + tensors, save_path=tmp_path / "model.safetensors", ) loaded_tensors = load_compressed(tmp_path / "model.safetensors") - for key in tensors_dense: - assert torch.allclose(tensors_dense[key], loaded_tensors[key]) + for key in tensors: + assert torch.allclose(tensors[key], loaded_tensors[key])