Commit

finalize the PR

dbogunowicz committed Apr 25, 2024
1 parent 1f0738d commit 10f976b
Showing 6 changed files with 95 additions and 58 deletions.
11 changes: 6 additions & 5 deletions README.md
@@ -47,16 +47,17 @@ from typing import Dict
# tensors with a large number of zero entries
compression_config = BitmaskConfig()


tensors: Dict[str, Tensor] = {"tensor_1": Tensor(
[[0.0, 0.0, 0.0],
[1.0, 1.0, 1.0]]
)}
# compress tensors using BitmaskConfig compression format (save them efficiently on disk)
save_compressed(tensors, "model.safetensors", compression_format=compression_config.format)

# decompress tensors (load the uncompressed representation to device memory)
tensors = load_compressed("model.safetensors", compression_config = compression_config)
# decompress tensors (load_compressed returns a generator for memory efficiency)
decompressed_tensors = {}
for tensor_name, tensor in load_compressed("model.safetensors", compression_config = compression_config):
decompressed_tensors[tensor_name] = tensor
```

## Saving/Loading Compressed Models (Bitmask Compression)
@@ -76,6 +77,6 @@ compression_config = BitmaskConfig()
# save compressed model weights
save_compressed_model(model, "compressed_model.safetensors", compression_format=compression_config.format)

# load compressed model weights
state_dict = load_compressed("compressed_model.safetensors", compression_config)
# load compressed model weights (`dict` turns the generator into a dictionary)
state_dict = dict(load_compressed("compressed_model.safetensors", compression_config))
```
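Since `load_compressed` now returns a generator, callers that do not need a full state dict can stream tensors one at a time. A minimal sketch of that pattern, assuming the `model.safetensors` file produced by the README snippet above:

```python
from compressed_tensors import load_compressed
from compressed_tensors.config import BitmaskConfig

compression_config = BitmaskConfig()

# stream tensors one at a time; only a single decompressed tensor
# is held in memory at any point
for name, tensor in load_compressed(
    "model.safetensors", compression_config=compression_config
):
    sparsity = (tensor == 0).sum().item() / tensor.numel()
    print(f"{name}: shape={tuple(tensor.shape)}, sparsity={sparsity:.2%}")
```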
18 changes: 9 additions & 9 deletions examples/bitmask_compression.ipynb
@@ -15,7 +15,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
@@ -29,7 +29,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 9,
"metadata": {},
"outputs": [
{
@@ -63,7 +63,7 @@
")"
]
},
"execution_count": 8,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
@@ -77,7 +77,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 10,
"metadata": {},
"outputs": [
{
@@ -98,7 +98,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 11,
"metadata": {},
"outputs": [
{
@@ -121,14 +121,14 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Compressing model: 100%|██████████| 111/111 [00:06<00:00, 17.73it/s]\n"
"Compressing model: 100%|██████████| 111/111 [00:06<00:00, 17.92it/s]\n"
]
},
{
@@ -168,7 +168,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 13,
"metadata": {},
"outputs": [
{
@@ -190,7 +190,7 @@
"\n",
"## load the compressed-tensors to memory ##\n",
"config = BitmaskConfig() # we need to specify the method for decompression\n",
"state_dict_2 = load_compressed(\"compressed_model.safetensors\", config)\n",
"state_dict_2 = dict(load_compressed(\"compressed_model.safetensors\", config)) # load_compressed returns a generator, we convert it to a dict\n",
"\n",
"tensors_equal = all(torch.equal(state_dict_1[key], state_dict_2[key]) for key in state_dict_1)\n",
"\n",
6 changes: 5 additions & 1 deletion makefile
@@ -1,4 +1,4 @@
BUILDDIR := $(PWD)

PYCHECKDIRS := src tests
PYCHECKGLOBS := 'src/**/*.py' 'tests/**/*.py' 'utils/**/*.py' 'examples/**/*.py' setup.py
# run checks on all files for the repo
@@ -23,6 +23,10 @@ test:
@echo "Running python tests";
pytest tests;

# create wheel file
build:
python3 setup.py sdist bdist_wheel $(BUILD_ARGS)

# clean package
clean:
@echo "Cleaning up";
7 changes: 6 additions & 1 deletion src/compressed_tensors/compressors/__init__.py
@@ -16,5 +16,10 @@

from .base import ModelCompressor
from .dense import DenseCompressor
from .helpers import infer_compressor_from_model_config
from .helpers import (
infer_compressor_from_model_config,
load_compressed,
save_compressed,
save_compressed_model,
)
from .sparse_bitmask import BitmaskCompressor, BitmaskTensor
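With these re-exports, the save/load helpers become part of the public surface of `compressed_tensors.compressors`. A quick sanity-check sketch of the new exports, assuming an installed build of this branch:

```python
from compressed_tensors import compressors

# each helper re-exported above should resolve on the subpackage
for helper in (
    "infer_compressor_from_model_config",
    "load_compressed",
    "save_compressed",
    "save_compressed_model",
):
    assert callable(getattr(compressors, helper)), f"missing export: {helper}"
```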
52 changes: 31 additions & 21 deletions src/compressed_tensors/compressors/helpers.py
@@ -13,16 +13,14 @@
# limitations under the License.

from pathlib import Path
from typing import Dict, Optional, Union
from typing import Dict, Generator, Optional, Tuple, Union

import torch
from compressed_tensors.base import SPARSITY_CONFIG_NAME
from compressed_tensors.compressors import ModelCompressor
from compressed_tensors.config import (
CompressionConfig,
CompressionFormat,
DenseSparsityConfig,
)
from compressed_tensors.config import CompressionConfig, CompressionFormat
from compressed_tensors.utils.safetensors_load import get_weight_mappings
from safetensors import safe_open
from safetensors.torch import save_file
from torch import Tensor
from transformers import AutoConfig
@@ -97,29 +95,41 @@ def load_compressed(
compressed_tensors: Union[str, Path],
compression_config: CompressionConfig = None,
device: Optional[str] = "cpu",
) -> Dict[str, Tensor]:
) -> Generator[Tuple[str, Tensor], None, None]:
"""
Load compressed tensors from disk. If tensors are not compressed,
load them as is.
Load compressed tensors from disk.
If tensors are not compressed, load them as is.
:param compressed_tensors: path to compressed tensors
:param compressed_tensors: path to compressed tensors.
This can be a path to a file or a directory containing
one or multiple safetensors files (if multiple, in the format
assumed by Hugging Face)
:param compression_config: compression config to use for decompressing tensors.
:param device: device to move tensors to. If None, tensors are loaded on CPU.
:return decompressed tensors
:param return_dict: if True, return a dictionary of decompressed tensors
:return a generator that yields the name and tensor of each decompressed tensor
"""

if compressed_tensors is None or not Path(compressed_tensors).exists():
raise ValueError("No compressed tensors provided to load")

# if no compression_config specified, default to `dense_sparsity`
compression_config = compression_config or DenseSparsityConfig()

# decompress
compression_format = compression_config.format
compressor = ModelCompressor.load_from_registry(
compression_format, config=compression_config
)
return dict(compressor.decompress(compressed_tensors, device=device))
if (
compression_config is None
or compression_config.format == CompressionFormat.dense_sparsity.value
):
# if no compression_config specified, or `dense_sparsity` format specified,
# assume tensors are not compressed on disk
weight_mappings = get_weight_mappings(compressed_tensors)
for weight_name, file_with_weight_name in weight_mappings.items():
with safe_open(file_with_weight_name, framework="pt", device=device) as f:
weight = f.get_tensor(weight_name)
yield weight_name, weight
else:
# decompress tensors
compression_format = compression_config.format
compressor = ModelCompressor.load_from_registry(
compression_format, config=compression_config
)
yield from compressor.decompress(compressed_tensors, device=device)


def save_compressed_model(
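A consequence of the new dense branch above: when no `compression_config` is given (or the `dense-sparsity` format is specified), `load_compressed` streams the raw safetensors via `safe_open`, shard by shard, with no compressor involved. A minimal sketch under that assumption, using a hypothetical checkpoint directory `./my_model/` holding one or more safetensors shards:

```python
from compressed_tensors import load_compressed

# no compression_config: the dense branch opens each shard with safe_open
# and yields raw (uncompressed) tensors one by one
for name, tensor in load_compressed("./my_model/", device="cpu"):
    print(name, tensor.dtype, tuple(tensor.shape))
```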
59 changes: 38 additions & 21 deletions tests/test_utils/test_helpers.py
@@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import pytest
import torch
from compressed_tensors import load_compressed, save_compressed, save_compressed_model
from compressed_tensors.config import BitmaskConfig
from safetensors import safe_open
from safetensors.torch import save_model
from transformers import AutoModelForCausalLM

@@ -60,7 +61,13 @@ def test_save_compressed_no_compression(tmp_path, tensors):
assert (tmp_path / "model.safetensors").exists()


def test_save_compressed_rubbish_compression_format(tmp_path, tensors):
def test_save_compressed_error(tmp_path, tensors):
with pytest.raises(Exception):
save_compressed({}, "")

with pytest.raises(Exception):
save_compressed(None, "")

with pytest.raises(Exception):
save_compressed(
tensors,
@@ -69,15 +76,6 @@ def test_save_compressed_rubbish_compression_format(tmp_path, tensors):
)


def test_save_compressed_empty():
# make sure function raises error
with pytest.raises(Exception):
save_compressed({}, "")

with pytest.raises(Exception):
save_compressed(None, "")


def test_load_compressed_sparse_bitmask(tmp_path, tensors):
save_compressed(
tensors,
@@ -87,7 +85,9 @@ def test_load_compressed_sparse_bitmask(tmp_path, tensors):
compression_config = BitmaskConfig(
format="sparse-bitmask",
)
loaded_tensors = load_compressed(tmp_path / "model.safetensors", compression_config)
loaded_tensors = dict(
load_compressed(tmp_path / "model.safetensors", compression_config)
)
for key in tensors:
assert torch.allclose(tensors[key], loaded_tensors[key])

@@ -98,10 +98,30 @@ def test_load_compressed_dense_sparsity(tmp_path, tensors):
compression_format="dense-sparsity",
save_path=tmp_path / "model.safetensors",
)
save_compressed(
tensors,
save_path=tmp_path / "model_.safetensors",
)

loaded_tensors = dict(load_compressed(tmp_path / "model.safetensors"))
loaded_tensors_ = dict(load_compressed(tmp_path / "model_.safetensors"))
# loaded_tensors should be equal to loaded_tensors_
for key in tensors:
assert torch.allclose(loaded_tensors[key], loaded_tensors_[key])


loaded_tensors = load_compressed(tmp_path / "model.safetensors")
# loaded_tensors is empty -> decompression returns empty dict
assert not loaded_tensors
def test_load_compressed_sharded(tmp_path, llama_model):
sharded_model_path = tmp_path / "sharded_model"
llama_model.save_pretrained(sharded_model_path, max_shard_size="2MB")
# make sure that the model is sharded on disk
assert len(os.listdir(sharded_model_path)) > 1
loaded_state_dict = dict(load_compressed(sharded_model_path))
for key, value in llama_model.state_dict().items():
if key == "lm_head.weight":
# lm_head doesn't have separate weights.
# It shares its weight tensor with the token embedding layer.
continue
assert torch.allclose(value, loaded_state_dict[key])


def test_save_compressed_model(tmp_path, llama_model):
@@ -119,12 +139,9 @@ def test_save_compressed_model(tmp_path, llama_model):
size_compressed_kb = path_to_compressed.stat().st_size / 1024

# compare that they are the same after loading
state_dict_1 = {}
with safe_open(path_to_uncompressed, framework="pt") as f:
for key in f.keys():
state_dict_1[key] = f.get_tensor(key)
state_dict_2 = load_compressed(
path_to_compressed, BitmaskConfig(format="sparse-bitmask")
state_dict_1 = dict(load_compressed(path_to_uncompressed))
state_dict_2 = dict(
load_compressed(path_to_compressed, BitmaskConfig(format="sparse-bitmask"))
)
assert all(
torch.allclose(state_dict_1[key], state_dict_2[key]) for key in state_dict_1
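The `lm_head.weight` skip in `test_load_compressed_sharded` reflects weight tying: with tied embeddings the language-model head has no separate tensor on disk. A hedged sketch of how the tied weight could still be checked, assuming a llama-style model where `model.embed_tokens.weight` is the embedding key (an assumption, not something this commit asserts):

```python
import torch

def check_tied_lm_head(model, loaded_state_dict):
    # with tied embeddings, lm_head reuses the input embedding matrix,
    # so compare against the embedding weight that is present on disk
    embedding_key = "model.embed_tokens.weight"  # assumed llama-style key
    if embedding_key in loaded_state_dict:
        assert torch.allclose(
            model.lm_head.weight, loaded_state_dict[embedding_key]
        )
```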
