From 10f976b7362062f8a86762a9565c71c67d8d522c Mon Sep 17 00:00:00 2001 From: dbogunowicz Date: Thu, 25 Apr 2024 12:02:16 +0000 Subject: [PATCH] finalize the PR --- README.md | 11 ++-- examples/bitmask_compression.ipynb | 18 +++--- makefile | 6 +- .../compressors/__init__.py | 7 ++- src/compressed_tensors/compressors/helpers.py | 52 +++++++++------- tests/test_utils/test_helpers.py | 59 ++++++++++++------- 6 files changed, 95 insertions(+), 58 deletions(-) diff --git a/README.md b/README.md index 29bb0b11..361a68f9 100644 --- a/README.md +++ b/README.md @@ -47,7 +47,6 @@ from typing import Dict # tensors with large number of zero entries compression_config = BitmaskConfig() - tensors: Dict[str, Tensor] = {"tensor_1": Tensor( [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]] @@ -55,8 +54,10 @@ tensors: Dict[str, Tensor] = {"tensor_1": Tensor( # compress tensors using BitmaskConfig compression format (save them efficiently on disk) save_compressed(tensors, "model.safetensors", compression_format=compression_config.format) -# decompress tensors (load the uncompressed representation to device memory) -tensors = load_compressed("model.safetensors", compression_config = compression_config) +# decompress tensors (load_compressed returns a generator for memory efficiency) +decompressed_tensors = {} +for tensor_name, tensor in load_compressed("model.safetensors", compression_config = compression_config): + decompressed_tensors[tensor_name] = tensor ``` ## Saving/Loading Compressed Models (Bitmask Compression) @@ -76,6 +77,6 @@ compression_config = BitmaskConfig() # save compressed model weights save_compressed_model(model, "compressed_model.safetensors", compression_format=compression_config.format) -# load compressed model weights -state_dict = load_compressed("compressed_model.safetensors", compression_config) +# load compressed model weights (`dict` turns generator into a dictionary) +state_dict = dict(load_compressed("compressed_model.safetensors", compression_config)) ``` diff --git a/examples/bitmask_compression.ipynb b/examples/bitmask_compression.ipynb index 7658a67a..995629c4 100644 --- a/examples/bitmask_compression.ipynb +++ b/examples/bitmask_compression.ipynb @@ -15,7 +15,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -29,7 +29,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -63,7 +63,7 @@ ")" ] }, - "execution_count": 8, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -77,7 +77,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -98,7 +98,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -121,14 +121,14 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Compressing model: 100%|██████████| 111/111 [00:06<00:00, 17.73it/s]\n" + "Compressing model: 100%|██████████| 111/111 [00:06<00:00, 17.92it/s]\n" ] }, { @@ -168,7 +168,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -190,7 +190,7 @@ "\n", "## load the compressed-tensors to memory ##\n", "config = BitmaskConfig() # we need to specify the method for decompression\n", - "state_dict_2 = load_compressed(\"compressed_model.safetensors\", config)\n", + "state_dict_2 = 
dict(load_compressed(\"compressed_model.safetensors\", config)) # load_compressed returns a generator, we convert it to a dict\n", "\n", "tensors_equal = all(torch.equal(state_dict_1[key], state_dict_2[key]) for key in state_dict_1)\n", "\n", diff --git a/makefile b/makefile index 435a37b9..255514f9 100644 --- a/makefile +++ b/makefile @@ -1,4 +1,4 @@ -BUILDDIR := $(PWD) + PYCHECKDIRS := src tests PYCHECKGLOBS := 'src/**/*.py' 'tests/**/*.py' 'utils/**/*.py' 'examples/**/*.py' setup.py # run checks on all files for the repo @@ -23,6 +23,10 @@ test: @echo "Running python tests"; pytest tests; +# creates wheel file +build: + python3 setup.py sdist bdist_wheel $(BUILD_ARGS) + # clean package clean: @echo "Cleaning up"; diff --git a/src/compressed_tensors/compressors/__init__.py b/src/compressed_tensors/compressors/__init__.py index 50d569e4..c93f1346 100644 --- a/src/compressed_tensors/compressors/__init__.py +++ b/src/compressed_tensors/compressors/__init__.py @@ -16,5 +16,10 @@ from .base import ModelCompressor from .dense import DenseCompressor -from .helpers import infer_compressor_from_model_config +from .helpers import ( + infer_compressor_from_model_config, + load_compressed, + save_compressed, + save_compressed_model, +) from .sparse_bitmask import BitmaskCompressor, BitmaskTensor diff --git a/src/compressed_tensors/compressors/helpers.py b/src/compressed_tensors/compressors/helpers.py index 8d6c26cf..1ba75636 100644 --- a/src/compressed_tensors/compressors/helpers.py +++ b/src/compressed_tensors/compressors/helpers.py @@ -13,16 +13,14 @@ # limitations under the License. from pathlib import Path -from typing import Dict, Optional, Union +from typing import Dict, Generator, Optional, Tuple, Union import torch from compressed_tensors.base import SPARSITY_CONFIG_NAME from compressed_tensors.compressors import ModelCompressor -from compressed_tensors.config import ( - CompressionConfig, - CompressionFormat, - DenseSparsityConfig, -) +from compressed_tensors.config import CompressionConfig, CompressionFormat +from compressed_tensors.utils.safetensors_load import get_weight_mappings +from safetensors import safe_open from safetensors.torch import save_file from torch import Tensor from transformers import AutoConfig @@ -97,29 +95,41 @@ def load_compressed( compressed_tensors: Union[str, Path], compression_config: CompressionConfig = None, device: Optional[str] = "cpu", -) -> Dict[str, Tensor]: +) -> Generator[Tuple[str, Tensor], None, None]: """ - Load compressed tensors from disk. If tensors are not compressed, - load them as is. + Load compressed tensors from disk. + If tensors are not compressed, load them as is. - :param compressed_tensors: path to compressed tensors + :param compressed_tensors: path to compressed tensors. + This can be a path to a file or a directory containing + one or multiple safetensor files (if multiple - in the format + assumed by huggingface) :param compression_config: compression config to use for decompressing tensors. :param device: device to move tensors to. If None, tensors are loaded on CPU. 
-    :return decompressed tensors
+    :return: a generator that yields the name and tensor of each
+        decompressed tensor
     """
-
     if compressed_tensors is None or not Path(compressed_tensors).exists():
         raise ValueError("No compressed tensors provided to load")
-    # if no compression_config specified, default to `dense_sparsity`
-    compression_config = compression_config or DenseSparsityConfig()
-
-    # decompress
-    compression_format = compression_config.format
-    compressor = ModelCompressor.load_from_registry(
-        compression_format, config=compression_config
-    )
-    return dict(compressor.decompress(compressed_tensors, device=device))
+    if (
+        compression_config is None
+        or compression_config.format == CompressionFormat.dense_sparsity.value
+    ):
+        # if no compression_config is specified, or the `dense_sparsity` format
+        # is specified, assume tensors are not compressed on disk
+        weight_mappings = get_weight_mappings(compressed_tensors)
+        for weight_name, file_with_weight_name in weight_mappings.items():
+            with safe_open(file_with_weight_name, framework="pt", device=device) as f:
+                weight = f.get_tensor(weight_name)
+                yield weight_name, weight
+    else:
+        # decompress tensors
+        compression_format = compression_config.format
+        compressor = ModelCompressor.load_from_registry(
+            compression_format, config=compression_config
+        )
+        yield from compressor.decompress(compressed_tensors, device=device)
 
 
 def save_compressed_model(
diff --git a/tests/test_utils/test_helpers.py b/tests/test_utils/test_helpers.py
index eeff70de..7ae0799d 100644
--- a/tests/test_utils/test_helpers.py
+++ b/tests/test_utils/test_helpers.py
@@ -12,11 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
+
 import pytest
 import torch
 from compressed_tensors import load_compressed, save_compressed, save_compressed_model
 from compressed_tensors.config import BitmaskConfig
-from safetensors import safe_open
 from safetensors.torch import save_model
 from transformers import AutoModelForCausalLM
 
@@ -60,7 +61,13 @@ def test_save_compressed_no_compression(tmp_path, tensors):
     assert (tmp_path / "model.safetensors").exists()
 
 
-def test_save_compressed_rubbish_compression_format(tmp_path, tensors):
+def test_save_compressed_error(tmp_path, tensors):
+    with pytest.raises(Exception):
+        save_compressed({}, "")
+
+    with pytest.raises(Exception):
+        save_compressed(None, "")
+
     with pytest.raises(Exception):
         save_compressed(
             tensors,
@@ -69,15 +76,6 @@
     )
 
 
-def test_save_compressed_empty():
-    # make sure function raises error
-    with pytest.raises(Exception):
-        save_compressed({}, "")
-
-    with pytest.raises(Exception):
-        save_compressed(None, "")
-
-
 def test_load_compressed_sparse_bitmask(tmp_path, tensors):
     save_compressed(
         tensors,
@@ -87,7 +85,9 @@
     compression_config = BitmaskConfig(
         format="sparse-bitmask",
     )
-    loaded_tensors = load_compressed(tmp_path / "model.safetensors", compression_config)
+    loaded_tensors = dict(
+        load_compressed(tmp_path / "model.safetensors", compression_config)
+    )
     for key in tensors:
         assert torch.allclose(tensors[key], loaded_tensors[key])
 
@@ -98,10 +98,30 @@ def test_load_compressed_dense_sparsity(tmp_path, tensors):
         compression_format="dense-sparsity",
         save_path=tmp_path / "model.safetensors",
     )
+    save_compressed(
+        tensors,
+        save_path=tmp_path / "model_.safetensors",
+    )
+
+    loaded_tensors = dict(load_compressed(tmp_path / "model.safetensors"))
+    loaded_tensors_ = dict(load_compressed(tmp_path / "model_.safetensors"))
+    # loaded_tensors should be equal to loaded_tensors_
+    for key in tensors:
+        assert torch.allclose(loaded_tensors[key], loaded_tensors_[key])
+
 
-    loaded_tensors = load_compressed(tmp_path / "model.safetensors")
-    # loaded_tensors is empty -> decompression returns empty dict
-    assert not loaded_tensors
+def test_load_compressed_sharded(tmp_path, llama_model):
+    sharded_model_path = tmp_path / "sharded_model"
+    llama_model.save_pretrained(sharded_model_path, max_shard_size="2MB")
+    # make sure that the model is sharded across multiple files on disk
+    assert len(os.listdir(sharded_model_path)) > 1
+    loaded_state_dict = dict(load_compressed(sharded_model_path))
+    for key, value in llama_model.state_dict().items():
+        if key == "lm_head.weight":
+            # lm_head doesn't have separate weights.
+            # It shares its weight tensor with the token embedding layer.
+            continue
+        assert torch.allclose(value, loaded_state_dict[key])
 
 
 def test_save_compressed_model(tmp_path, llama_model):
@@ -119,12 +139,9 @@
     size_compressed_kb = path_to_compressed.stat().st_size / 1024
 
     # compare that they are the same after loading
-    state_dict_1 = {}
-    with safe_open(path_to_uncompressed, framework="pt") as f:
-        for key in f.keys():
-            state_dict_1[key] = f.get_tensor(key)
-    state_dict_2 = load_compressed(
-        path_to_compressed, BitmaskConfig(format="sparse-bitmask")
+    state_dict_1 = dict(load_compressed(path_to_uncompressed))
+    state_dict_2 = dict(
+        load_compressed(path_to_compressed, BitmaskConfig(format="sparse-bitmask"))
     )
     assert all(
        torch.allclose(state_dict_1[key], state_dict_2[key]) for key in state_dict_1
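
Below is a minimal round-trip sketch of the generator-based `load_compressed` API this patch introduces. The tensor name and save path are illustrative placeholders; the imports and call signatures mirror those exercised in the README and tests above.

```python
# Illustrative round-trip: save mostly-zero tensors with bitmask compression,
# then stream them back one at a time. Names and paths are placeholders.
import torch

from compressed_tensors import load_compressed, save_compressed
from compressed_tensors.config import BitmaskConfig

tensors = {"layer.weight": torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])}

compression_config = BitmaskConfig()
save_compressed(
    tensors, "model.safetensors", compression_format=compression_config.format
)

# load_compressed now yields (name, tensor) pairs lazily, so only one tensor
# is resident in memory at a time; wrap it in dict() for a full state dict
for name, tensor in load_compressed(
    "model.safetensors", compression_config=compression_config
):
    assert torch.equal(tensor, tensors[name])
```

Calling `load_compressed` with no config (or a `dense-sparsity` one) takes the new uncompressed branch instead, which resolves each tensor's location with `get_weight_mappings` and streams it directly from its safetensors file, including sharded checkpoint directories.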