Skip to content

Commit

Permalink
simplify UX
Browse files Browse the repository at this point in the history
  • Loading branch information
dbogunowicz committed Apr 19, 2024
1 parent 26192e9 commit a5cfaa1
Show file tree
Hide file tree
Showing 11 changed files with 107 additions and 85 deletions.
27 changes: 16 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,23 +36,28 @@ pip install -e .

### Saving

The function `save_compressed` returns an optional `compression_config` (if compression has been applied). It can be used to inspect the applied compression.
The function `save_compressed` uses the `compression_format` argument to apply compression to tensors.
The function `load_compressed` reverses the process: converts the compressed weights on disk to decompressed weights in device memory.

```python
from compressed_tensors import save_compressed
from compressed_tensors import save_compressed, load_compressed, BitmaskConfig
from torch import Tensor
from typing import Dict

tensors: Dict[str, Tensor] = ...
compression_config: Dict = save_compressed(tensors, "model.safetensors")
```
# the example BitmaskConfig efficiently compresses
# tensors with a large number of zero entries
compression_config = BitmaskConfig()

### Loading

```python
from compressed_tensors import load_compressed
from torch import Tensor
tensors: Dict[str, Tensor] = {"tensor_1": Tensor(
[[0.0, 0.0, 0.0],
[1.0, 1.0, 1.0]]
)}
# compress tensors using BitmaskConfig compression format (save them efficiently on disk)
save_compressed(tensors, "model.safetensors", compression_format=compression_config.format)

tensors: Dict[str, Tensor] = load_compressed("model.safetensors", device="cpu")
# decompress tensors (load the uncompressed representation to device memory)
tensors = load_compressed("model.safetensors", device="cpu", compression_config = compression_config)
```

## Benefits
Expand Down Expand Up @@ -87,7 +92,7 @@ The library provides pathways to automatically add the config information to the
```json
// config.json
{
"sparsity_config": {
"compression_config": {
"format": "sparse_bitmask", // "dense_sparsity" for the original tensor format

// Informational
Expand Down
2 changes: 1 addition & 1 deletion src/compressed_tensors/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ Config information gets stored in the HF config file
```json
// config.json
{
"sparsity_config": {
"compression_config": {
"format": "sparse_bitmask", // "dense_sparsity" for original tensor format

// informational
Expand Down
4 changes: 2 additions & 2 deletions src/compressed_tensors/compressors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import operator
from typing import Dict, Generator, Tuple
from typing import Dict, Generator, Optional, Tuple

from compressed_tensors.base import CONFIG_NAME
from compressed_tensors.config import CompressionConfig
Expand All @@ -33,7 +33,7 @@ class ModelCompressor(RegistryMixin):
:param config: config specifying compression parameters
"""

def __init__(self, config: CompressionConfig):
def __init__(self, config: Optional[CompressionConfig] = None):
self.config = config

def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
Expand Down
3 changes: 2 additions & 1 deletion src/compressed_tensors/compressors/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
from typing import Dict, Generator, Tuple

from compressed_tensors.compressors import ModelCompressor
from compressed_tensors.config import CompressionFormat
from torch import Tensor


@ModelCompressor.register(name="dense_sparsity")
@ModelCompressor.register(name=CompressionFormat.dense_sparsity.value)
class DenseCompressor(ModelCompressor):
"""
Identity compressor for dense models, returns the original state_dict
Expand Down
3 changes: 2 additions & 1 deletion src/compressed_tensors/compressors/sparse_bitmask.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import numpy
import torch
from compressed_tensors.compressors import ModelCompressor
from compressed_tensors.config import CompressionFormat
from compressed_tensors.utils import get_nested_weight_mappings, merge_names
from safetensors import safe_open
from torch import Tensor
Expand All @@ -36,7 +37,7 @@
_LOGGER: logging.Logger = logging.getLogger(__name__)


@ModelCompressor.register(name="sparse_bitmask")
@ModelCompressor.register(name=CompressionFormat.sparse_bitmask.value)
class BitmaskCompressor(ModelCompressor):
"""
Compression for sparse models using bitmasks. Non-zero weights are stored in a 1d
Expand Down
8 changes: 7 additions & 1 deletion src/compressed_tensors/config/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum
from typing import Optional

from compressed_tensors.registry import RegistryMixin
from pydantic import BaseModel


__all__ = ["CompressionConfig"]
__all__ = ["CompressionConfig", "CompressionFormat"]


class CompressionFormat(Enum):
dense_sparsity = "dense-sparsity"
sparse_bitmask = "sparse-bitmask"


class CompressionConfig(RegistryMixin, BaseModel):
Expand Down
6 changes: 3 additions & 3 deletions src/compressed_tensors/config/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@

from typing import Optional

from compressed_tensors.config import CompressionConfig
from compressed_tensors.config import CompressionConfig, CompressionFormat


__all__ = ["DenseSparsityConfig"]


@CompressionConfig.register(name="dense_sparsity")
@CompressionConfig.register(name=CompressionFormat.dense_sparsity.value)
class DenseSparsityConfig(CompressionConfig):
"""
Identity configuration for storing a sparse model in
Expand All @@ -31,6 +31,6 @@ class DenseSparsityConfig(CompressionConfig):
"unstructured", "2:4", "8:16" etc
"""

format: str = "dense_sparsity"
format: str = CompressionFormat.dense_sparsity.value
global_sparsity: Optional[float] = 0.0
sparsity_structure: Optional[str] = "unstructured"
6 changes: 3 additions & 3 deletions src/compressed_tensors/config/sparse_bitmask.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@

from typing import Optional

from compressed_tensors.config.base import CompressionConfig
from compressed_tensors.config import CompressionConfig, CompressionFormat


__all__ = ["BitmaskConfig"]


@CompressionConfig.register(name="sparse_bitmask")
@CompressionConfig.register(name=CompressionFormat.sparse_bitmask.value)
class BitmaskConfig(CompressionConfig):
"""
Configuration for storing a sparse model using
Expand All @@ -31,6 +31,6 @@ class BitmaskConfig(CompressionConfig):
"unstructured", "2:4", "8:16" etc
"""

format: str = "sparse_bitmask"
format: str = CompressionFormat.sparse_bitmask.value
global_sparsity: Optional[float] = 0.0
sparsity_structure: Optional[str] = "unstructured"
45 changes: 20 additions & 25 deletions src/compressed_tensors/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
# limitations under the License.

from pathlib import Path
from typing import Dict, Optional, Union
from typing import Dict, Literal, Optional, Union

from compressed_tensors.base import CONFIG_NAME
from compressed_tensors.compressors import ModelCompressor
from compressed_tensors.config import CompressionConfig
from compressed_tensors.config import CompressionConfig, CompressionFormat
from safetensors import safe_open
from safetensors.torch import save_file
from torch import Tensor
Expand Down Expand Up @@ -51,46 +51,46 @@ def infer_compressor_from_model_config(
def save_compressed(
tensors: Dict[str, Tensor],
save_path: Union[str, Path],
compression_config: Optional[CompressionConfig] = None,
) -> Optional[CompressionConfig]:
compression_format: Optional[
Literal[CompressionFormat.sparse_bitmask, CompressionFormat.dense_sparsity]
] = None,
):
"""
Save compressed tensors to disk. If tensors are not compressed,
save them as is.
:param tensors: dictionary of tensors to compress
:param save_path: path to save compressed tensors
:param compression_config: compression config to use for compressing tensors.
Can be either inferred from tensors or provided explicitly
:param compression_format: compression format used for the tensors
:return: compression config, if tensors were compressed - None otherwise
"""
if tensors is None or len(tensors) == 0:
raise ValueError("No tensors or empty tensors provided to compress")

# create compression config if not provided
# TODO: Not implemented, need to get this in ASAP
# compression_config = compression_config or infer_compression_config(tensors)

if compression_config is None:
if compression_format is None:
# no compression applied
save_file(tensors, save_path)
return None
return

if not (
compression_format in ModelCompressor.registered_names()
or compression_format in ModelCompressor.registered_aliases()
):
raise ValueError(
f"Unknown compression format: {compression_format}. "
f"Must be one of {set(ModelCompressor.registered_names() + ModelCompressor.registered_aliases())}" # noqa E501
)

# compress
compression_format = compression_config.format
compressor = ModelCompressor.load_from_registry(
compression_format, config=compression_config
)
compressor = ModelCompressor.load_from_registry(compression_format)
# save compressed tensors
compressed_tensors = compressor.compress(tensors)
save_file(compressed_tensors, save_path)

# return compression_config as dict
return {CONFIG_NAME: compression_config.model_dump(exclude_unset=True)}


def load_compressed(
compressed_tensors: Union[str, Path],
compression_config: Optional[CompressionConfig] = None,
compression_config: CompressionConfig = None,
device: Optional[str] = "cpu",
) -> Dict[str, Tensor]:
"""
Expand All @@ -99,18 +99,13 @@ def load_compressed(
:param compressed_tensors: path to compressed tensors
:param compression_config: compression config to use for decompressing tensors.
Can be either inferred from tensors or provided explicitly.
:param device: device to move tensors to. If None, tensors are loaded on CPU.
:return: decompressed tensors
"""

if compressed_tensors is None or not Path(compressed_tensors).exists():
raise ValueError("No compressed tensors provided to load")

# create compression config if not provided
# TODO: Not implemented, need to get this in ASAP
# compression_config = compression_config or infer_compression_config(tensors)

if compression_config is None:
# no compression applied
tensors = {}
Expand Down
10 changes: 7 additions & 3 deletions tests/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
BitmaskCompressor,
BitmaskConfig,
CompressionConfig,
CompressionFormat,
DenseCompressor,
DenseSparsityConfig,
ModelCompressor,
Expand All @@ -26,8 +27,8 @@
@pytest.mark.parametrize(
"name,type",
[
["sparse_bitmask", BitmaskConfig],
["dense_sparsity", DenseSparsityConfig],
[CompressionFormat.sparse_bitmask.value, BitmaskConfig],
[CompressionFormat.dense_sparsity.value, DenseSparsityConfig],
],
)
def test_configs(name, type):
Expand All @@ -38,7 +39,10 @@ def test_configs(name, type):

@pytest.mark.parametrize(
"name,type",
[["sparse_bitmask", BitmaskCompressor], ["dense_sparsity", DenseCompressor]],
[
[CompressionFormat.sparse_bitmask.value, BitmaskCompressor],
[CompressionFormat.dense_sparsity.value, DenseCompressor],
],
)
def test_compressors(name, type):
compressor = ModelCompressor.load_from_registry(
Expand Down
Loading

0 comments on commit a5cfaa1

Please sign in to comment.