From 1f6a056f4ef0254a7a20a3346061d4eae5668462 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Wed, 2 Oct 2024 16:57:59 -0400 Subject: [PATCH] Add: base sparsity/quantization compressors (#165) * Add: base sparsity/quantization compressors Update: tests Update: Usages of Compressor -> BaseCompressor * Review Comments from @mgoin --- .../compressors/__init__.py | 2 +- src/compressed_tensors/compressors/base.py | 134 +++++----------- .../base_quantization_compressor.py | 146 ++++++++++++++++++ .../compressors/base_sparsity_compressor.py | 110 +++++++++++++ src/compressed_tensors/compressors/dense.py | 6 +- src/compressed_tensors/compressors/helpers.py | 12 +- .../compressors/marlin_24.py | 6 +- .../compressors/model_compressor.py | 8 +- .../compressors/naive_quantized.py | 16 +- .../compressors/pack_quantized.py | 9 +- .../compressors/sparse_bitmask.py | 73 ++------- .../linear/compressed_linear.py | 4 +- tests/test_compressors/test_marlin_24.py | 4 +- .../test_compressors/test_model_compressor.py | 2 +- tests/test_quantization/test_quant_scheme.py | 6 +- tests/test_registry.py | 4 +- 16 files changed, 344 insertions(+), 198 deletions(-) create mode 100644 src/compressed_tensors/compressors/base_quantization_compressor.py create mode 100644 src/compressed_tensors/compressors/base_sparsity_compressor.py diff --git a/src/compressed_tensors/compressors/__init__.py b/src/compressed_tensors/compressors/__init__.py index 6cffc6d7..21b20589 100644 --- a/src/compressed_tensors/compressors/__init__.py +++ b/src/compressed_tensors/compressors/__init__.py @@ -14,7 +14,7 @@ # flake8: noqa -from .base import Compressor +from .base import BaseCompressor from .dense import DenseCompressor from .helpers import load_compressed, save_compressed, save_compressed_model from .marlin_24 import Marlin24Compressor diff --git a/src/compressed_tensors/compressors/base.py b/src/compressed_tensors/compressors/base.py index 086d7a37..f63cab37 100644 --- a/src/compressed_tensors/compressors/base.py +++ b/src/compressed_tensors/compressors/base.py @@ -12,26 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging +from abc import ABC, abstractmethod from typing import Dict, Generator, Optional, Tuple, Union import torch from compressed_tensors.config import SparsityCompressionConfig from compressed_tensors.quantization import QuantizationArgs, QuantizationConfig from compressed_tensors.registry import RegistryMixin -from compressed_tensors.utils import get_nested_weight_mappings, merge_names -from safetensors import safe_open from torch import Tensor -from torch.nn.modules import Module -from tqdm import tqdm +from torch.nn import Module -_LOGGER: logging.Logger = logging.getLogger(__name__) +__all__ = ["BaseCompressor"] -__all__ = ["Compressor"] - -class Compressor(RegistryMixin): +class BaseCompressor(RegistryMixin, ABC): """ Base class representing a model compression algorithm. Each child class should implement compression_param_info, compress_weight and decompress_weight. 
@@ -43,12 +38,11 @@ class Compressor(RegistryMixin): - ModelCompressor.decompress() - apply_quantization_config() - Compressor.decompress() - - Compressor.decompress_weight() Model Save Lifecycle: - ModelCompressor.compress() - Compressor.compress() - - Compressor.compress_weight() + Module Lifecycle (run_compressed=True): - apply_quantization_config() @@ -83,61 +77,27 @@ def compression_param_info( """ raise NotImplementedError() + @abstractmethod def compress( self, model_state: Dict[str, Tensor], - names_to_scheme: Dict[str, QuantizationArgs], **kwargs, ) -> Dict[str, Tensor]: """ Compresses a dense state dict :param model_state: state dict of uncompressed model - :param names_to_scheme: quantization args for each quantized weight, needed for - quantize function to calculate bit depth + :param kwargs: additional arguments for compression :return: compressed state dict """ - compressed_dict = {} - weight_suffix = ".weight" - _LOGGER.debug( - f"Compressing model with {len(model_state)} parameterized layers..." - ) - - for name, value in tqdm(model_state.items(), desc="Compressing model"): - if name.endswith(weight_suffix): - prefix = name[: -(len(weight_suffix))] - scale = model_state.get(merge_names(prefix, "weight_scale"), None) - zp = model_state.get(merge_names(prefix, "weight_zero_point"), None) - g_idx = model_state.get(merge_names(prefix, "weight_g_idx"), None) - if scale is not None: - # weight is quantized, compress it - quant_args = names_to_scheme[prefix] - compressed_data = self.compress_weight( - weight=value, - scale=scale, - zero_point=zp, - g_idx=g_idx, - quantization_args=quant_args, - device="cpu", - ) - for key, value in compressed_data.items(): - compressed_dict[merge_names(prefix, key)] = value - else: - compressed_dict[name] = value.to("cpu") - elif name.endswith("zero_point") and torch.all(value == 0): - continue - elif name.endswith("g_idx") and torch.any(value <= -1): - continue - else: - compressed_dict[name] = value.to("cpu") - - return compressed_dict + raise NotImplementedError() + @abstractmethod def decompress( self, path_to_model_or_tensors: str, - names_to_scheme: Dict[str, QuantizationArgs], device: str = "cpu", + **kwargs, ) -> Generator[Tuple[str, Tensor], None, None]: """ Reads a compressed state dict located at path_to_model_or_tensors @@ -150,55 +110,6 @@ def decompress( :param device: optional device to load intermediate weights into :return: compressed state dict """ - weight_mappings = get_nested_weight_mappings( - path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES - ) - for weight_name in weight_mappings.keys(): - weight_data = {} - for param_name, safe_path in weight_mappings[weight_name].items(): - full_name = merge_names(weight_name, param_name) - with safe_open(safe_path, framework="pt", device=device) as f: - weight_data[param_name] = f.get_tensor(full_name) - - if "weight_scale" in weight_data: - quant_args = names_to_scheme[weight_name] - decompressed = self.decompress_weight( - compressed_data=weight_data, quantization_args=quant_args - ) - yield merge_names(weight_name, "weight"), decompressed - - def compress_weight( - self, - weight: Tensor, - scale: Tensor, - zero_point: Optional[Tensor] = None, - g_idx: Optional[torch.Tensor] = None, - quantization_args: Optional[QuantizationArgs] = None, - ) -> Dict[str, torch.Tensor]: - """ - Compresses a single uncompressed weight - - :param weight: uncompressed weight tensor - :param scale: quantization scale for weight - :param zero_point: quantization zero point for weight - :param g_idx: 
optional mapping from column index to group index - :param quantization_args: quantization parameters for weight - :return: dictionary of compressed weight data - """ - raise NotImplementedError() - - def decompress_weight( - self, - compressed_data: Dict[str, Tensor], - quantization_args: Optional[QuantizationArgs] = None, - ) -> torch.Tensor: - """ - Decompresses a single compressed weight - - :param compressed_data: dictionary of data needed for decompression - :param quantization_args: quantization parameters for the weight - :return: tensor of the decompressed weight - """ raise NotImplementedError() def compress_module(self, module: Module) -> Optional[Dict[str, torch.Tensor]]: @@ -228,6 +139,19 @@ def compress_module(self, module: Module) -> Optional[Dict[str, torch.Tensor]]: quantization_args=quantization_args, ) + def compress_weight( + self, + weight: Tensor, + **kwargs, + ) -> Dict[str, torch.Tensor]: + """ + Compresses a single uncompressed weight + + :param weight: uncompressed weight tensor + :param kwargs: additional arguments for compression + """ + raise NotImplementedError() + def decompress_module(self, module: Module): """ Decompresses a single compressed leaf PyTorch module. If the module is not @@ -250,3 +174,15 @@ def decompress_module(self, module: Module): return self.decompress_weight( compressed_data=compressed_data, quantization_args=quantization_args ) + + def decompress_weight( + self, compressed_data: Dict[str, Tensor], **kwargs + ) -> torch.Tensor: + """ + Decompresses a single compressed weight + + :param compressed_data: dictionary of data needed for decompression + :param kwargs: additional arguments for decompression + :return: tensor of the decompressed weight + """ + raise NotImplementedError() diff --git a/src/compressed_tensors/compressors/base_quantization_compressor.py b/src/compressed_tensors/compressors/base_quantization_compressor.py new file mode 100644 index 00000000..563ecf56 --- /dev/null +++ b/src/compressed_tensors/compressors/base_quantization_compressor.py @@ -0,0 +1,146 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Dict, Generator, Tuple + +import torch +from compressed_tensors.compressors.base import BaseCompressor +from compressed_tensors.quantization import QuantizationArgs +from compressed_tensors.utils import get_nested_weight_mappings, merge_names +from safetensors import safe_open +from torch import Tensor +from tqdm import tqdm + + +__all__ = ["BaseQuantizationCompressor"] + +_LOGGER: logging.Logger = logging.getLogger(__name__) + + +class BaseQuantizationCompressor(BaseCompressor): + """ + Base class representing a quant compression algorithm. Each child class should + implement compression_param_info, compress_weight and decompress_weight. + + Compressors support compressing/decompressing a full module state dict or a single + quantized PyTorch leaf module. 
+ + Model Load Lifecycle (run_compressed=False): + - ModelCompressor.decompress() + - apply_quantization_config() + - Compressor.decompress() + - Compressor.decompress_weight() + + Model Save Lifecycle: + - ModelCompressor.compress() + - Compressor.compress() + - Compressor.compress_weight() + + Module Lifecycle (run_compressed=True): + - apply_quantization_config() + - compressed_module = CompressedLinear(module) + - initialize_module_for_quantization() + - Compressor.compression_param_info() + - register_parameters() + - compressed_module.forward() + - compressed_module.decompress() + + + :param config: config specifying compression parameters + """ + + def compress( + self, + model_state: Dict[str, Tensor], + names_to_scheme: Dict[str, QuantizationArgs], + **kwargs, + ) -> Dict[str, Tensor]: + """ + Compresses a dense state dict + + :param model_state: state dict of uncompressed model + :param names_to_scheme: quantization args for each quantized weight, needed for + quantize function to calculate bit depth + :return: compressed state dict + """ + compressed_dict = {} + weight_suffix = ".weight" + _LOGGER.debug( + f"Compressing model with {len(model_state)} parameterized layers..." + ) + + for name, value in tqdm(model_state.items(), desc="Quantized Compression"): + if name.endswith(weight_suffix): + prefix = name[: -(len(weight_suffix))] + scale = model_state.get(merge_names(prefix, "weight_scale"), None) + zp = model_state.get(merge_names(prefix, "weight_zero_point"), None) + g_idx = model_state.get(merge_names(prefix, "weight_g_idx"), None) + if scale is not None: + # weight is quantized, compress it + quant_args = names_to_scheme[prefix] + compressed_data = self.compress_weight( + weight=value, + scale=scale, + zero_point=zp, + g_idx=g_idx, + quantization_args=quant_args, + device="cpu", + ) + for key, value in compressed_data.items(): + compressed_dict[merge_names(prefix, key)] = value + else: + compressed_dict[name] = value.to("cpu") + elif name.endswith("zero_point") and torch.all(value == 0): + continue + elif name.endswith("g_idx") and torch.any(value <= -1): + continue + else: + compressed_dict[name] = value.to("cpu") + + return compressed_dict + + def decompress( + self, + path_to_model_or_tensors: str, + names_to_scheme: Dict[str, QuantizationArgs], + device: str = "cpu", + ) -> Generator[Tuple[str, Tensor], None, None]: + """ + Reads a compressed state dict located at path_to_model_or_tensors + and returns a generator for sequentially decompressing back to a + dense state dict + + :param path_to_model_or_tensors: path to compressed safetensors model (directory + with one or more safetensors files) or compressed tensors file + :param names_to_scheme: quantization args for each quantized weight + :param device: optional device to load intermediate weights into + :return: compressed state dict + """ + weight_mappings = get_nested_weight_mappings( + path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES + ) + for weight_name in weight_mappings.keys(): + weight_data = {} + for param_name, safe_path in weight_mappings[weight_name].items(): + full_name = merge_names(weight_name, param_name) + with safe_open(safe_path, framework="pt", device=device) as f: + weight_data[param_name] = f.get_tensor(full_name) + + if "weight_scale" in weight_data: + quant_args = names_to_scheme[weight_name] + decompressed = self.decompress_weight( + compressed_data=weight_data, quantization_args=quant_args + ) + yield merge_names(weight_name, "weight"), decompressed diff --git 
a/src/compressed_tensors/compressors/base_sparsity_compressor.py b/src/compressed_tensors/compressors/base_sparsity_compressor.py new file mode 100644 index 00000000..f7ebbc2f --- /dev/null +++ b/src/compressed_tensors/compressors/base_sparsity_compressor.py @@ -0,0 +1,110 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Dict, Generator, Tuple + +from compressed_tensors.compressors.base import BaseCompressor +from compressed_tensors.utils import get_nested_weight_mappings, merge_names +from safetensors import safe_open +from torch import Tensor +from tqdm import tqdm + + +__all__ = ["BaseSparseCompressor"] + +_LOGGER: logging.Logger = logging.getLogger(__name__) + + +class BaseSparseCompressor(BaseCompressor): + """ + Base class representing a sparse compression algorithm. Each child class should + implement compression_param_info, compress_weight and decompress_weight. + + Compressors support compressing/decompressing a full module state dict or a single + quantized PyTorch leaf module. + + Model Load Lifecycle (run_compressed=False): + - ModelCompressor.decompress() + - apply_quantization_config() + - Compressor.decompress() + - Compressor.decompress_weight() + + Model Save Lifecycle: + - ModelCompressor.compress() + - Compressor.compress() + - Compressor.compress_weight() + + Module Lifecycle (run_compressed=True): + - apply_quantization_config() + - compressed_module = CompressedLinear(module) + - initialize_module_for_quantization() + - Compressor.compression_param_info() + - register_parameters() + - compressed_module.forward() + - compressed_module.decompress() + + + :param config: config specifying compression parameters + """ + + def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]: + """ + Compresses a dense state dict using bitmask compression + + :param model_state: state dict of uncompressed model + :return: compressed state dict + """ + compressed_dict = {} + _LOGGER.debug( + f"Compressing model with {len(model_state)} parameterized layers..." + ) + for name, value in tqdm(model_state.items(), desc="Compressing model"): + compression_data = self.compress_weight(name, value) + for key in compression_data.keys(): + if key in compressed_dict: + _LOGGER.warn( + f"Expected all compressed state_dict keys to be unique, but " + f"found an existing entry for {key}. The existing entry will " + "be replaced." 
+ ) + + compressed_dict.update(compression_data) + + return compressed_dict + + def decompress( + self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs + ) -> Generator[Tuple[str, Tensor], None, None]: + """ + Reads a bitmask compressed state dict located + at path_to_model_or_tensors and returns a generator + for sequentially decompressing back to a dense state dict + + :param model_path: path to compressed safetensors model (directory with + one or more safetensors files) or compressed tensors file + :param device: device to load decompressed weights onto + :return: iterator for generating decompressed weights + """ + weight_mappings = get_nested_weight_mappings( + path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES + ) + for weight_name in weight_mappings.keys(): + weight_data = {} + for param_name, safe_path in weight_mappings[weight_name].items(): + full_name = merge_names(weight_name, param_name) + with safe_open(safe_path, framework="pt", device=device) as f: + weight_data[param_name] = f.get_tensor(full_name) + decompressed = self.decompress_weight(weight_data) + yield weight_name, decompressed diff --git a/src/compressed_tensors/compressors/dense.py b/src/compressed_tensors/compressors/dense.py index 16707acd..53aaf6c8 100644 --- a/src/compressed_tensors/compressors/dense.py +++ b/src/compressed_tensors/compressors/dense.py @@ -14,13 +14,13 @@ from typing import Dict, Generator, Tuple -from compressed_tensors.compressors import Compressor +from compressed_tensors.compressors.base import BaseCompressor from compressed_tensors.config import CompressionFormat from torch import Tensor -@Compressor.register(name=CompressionFormat.dense.value) -class DenseCompressor(Compressor): +@BaseCompressor.register(name=CompressionFormat.dense.value) +class DenseCompressor(BaseCompressor): """ Identity compressor for dense models, returns the original state_dict """ diff --git a/src/compressed_tensors/compressors/helpers.py b/src/compressed_tensors/compressors/helpers.py index fe4b361c..2753621b 100644 --- a/src/compressed_tensors/compressors/helpers.py +++ b/src/compressed_tensors/compressors/helpers.py @@ -16,7 +16,7 @@ from typing import Dict, Generator, Optional, Tuple, Union import torch -from compressed_tensors.compressors import Compressor +from compressed_tensors.compressors.base import BaseCompressor from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig from compressed_tensors.utils.safetensors_load import get_weight_mappings from safetensors import safe_open @@ -52,16 +52,16 @@ def save_compressed( compression_format = compression_format or CompressionFormat.dense.value if not ( - compression_format in Compressor.registered_names() - or compression_format in Compressor.registered_aliases() + compression_format in BaseCompressor.registered_names() + or compression_format in BaseCompressor.registered_aliases() ): raise ValueError( f"Unknown compression format: {compression_format}. 
" - f"Must be one of {set(Compressor.registered_names() + Compressor.registered_aliases())}" # noqa E501 + f"Must be one of {set(BaseCompressor.registered_names() + BaseCompressor.registered_aliases())}" # noqa E501 ) # compress - compressor = Compressor.load_from_registry(compression_format) + compressor = BaseCompressor.load_from_registry(compression_format) # save compressed tensors compressed_tensors = compressor.compress(tensors) save_file(compressed_tensors, save_path) @@ -102,7 +102,7 @@ def load_compressed( else: # decompress tensors compression_format = compression_config.format - compressor = Compressor.load_from_registry( + compressor = BaseCompressor.load_from_registry( compression_format, config=compression_config ) yield from compressor.decompress(compressed_tensors, device=device) diff --git a/src/compressed_tensors/compressors/marlin_24.py b/src/compressed_tensors/compressors/marlin_24.py index 187dd5b7..7954b6b8 100644 --- a/src/compressed_tensors/compressors/marlin_24.py +++ b/src/compressed_tensors/compressors/marlin_24.py @@ -17,7 +17,7 @@ import numpy as np import torch -from compressed_tensors.compressors import Compressor +from compressed_tensors.compressors.base import BaseCompressor from compressed_tensors.config import CompressionFormat from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy from compressed_tensors.quantization.lifecycle.forward import quantize @@ -35,8 +35,8 @@ _LOGGER: logging.Logger = logging.getLogger(__name__) -@Compressor.register(name=CompressionFormat.marlin_24.value) -class Marlin24Compressor(Compressor): +@BaseCompressor.register(name=CompressionFormat.marlin_24.value) +class Marlin24Compressor(BaseCompressor): """ Compresses a quantized model with 2:4 sparsity structure for inference with the Marlin24 kernel. Decompression is not implemented for this compressor. 
diff --git a/src/compressed_tensors/compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressor.py index 7acc1030..ac15fdaa 100644 --- a/src/compressed_tensors/compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressor.py @@ -30,7 +30,7 @@ QUANTIZATION_METHOD_NAME, SPARSITY_CONFIG_NAME, ) -from compressed_tensors.compressors import Compressor +from compressed_tensors.compressors.base import BaseCompressor from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig from compressed_tensors.quantization import ( DEFAULT_QUANTIZATION_METHOD, @@ -247,11 +247,11 @@ def __init__( self.sparsity_config = None if sparsity_config is not None: - self.sparsity_compressor = Compressor.load_from_registry( + self.sparsity_compressor = BaseCompressor.load_from_registry( sparsity_config.format, config=sparsity_config ) if quantization_config is not None: - self.quantization_compressor = Compressor.load_from_registry( + self.quantization_compressor = BaseCompressor.load_from_registry( quantization_config.format, config=quantization_config ) @@ -262,7 +262,7 @@ def compress( Compresses a dense state dict or model with sparsity and/or quantization :param model: uncompressed model to compress - :param model_state: optional uncompressed state_dict to insert into model + :param state_dict: optional uncompressed state_dict to insert into model :return: compressed state dict """ if state_dict is None: diff --git a/src/compressed_tensors/compressors/naive_quantized.py b/src/compressed_tensors/compressors/naive_quantized.py index 81cc2fac..acc09932 100644 --- a/src/compressed_tensors/compressors/naive_quantized.py +++ b/src/compressed_tensors/compressors/naive_quantized.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging from typing import Dict, Optional, Tuple import torch -from compressed_tensors.compressors import Compressor +from compressed_tensors.compressors.base import BaseCompressor +from compressed_tensors.compressors.base_quantization_compressor import ( + BaseQuantizationCompressor, +) from compressed_tensors.config import CompressionFormat from compressed_tensors.quantization import QuantizationArgs from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize @@ -30,11 +32,9 @@ "FloatQuantizationCompressor", ] -_LOGGER: logging.Logger = logging.getLogger(__name__) - -@Compressor.register(name=CompressionFormat.naive_quantized.value) -class QuantizationCompressor(Compressor): +@BaseCompressor.register(name=CompressionFormat.naive_quantized.value) +class QuantizationCompressor(BaseQuantizationCompressor): """ Implements naive compression for quantized models. 
Weight of each quantized layer is converted from its original float type to the closest Pytorch @@ -122,7 +122,7 @@ def decompress_weight( return decompressed_weight -@Compressor.register(name=CompressionFormat.int_quantized.value) +@BaseCompressor.register(name=CompressionFormat.int_quantized.value) class IntQuantizationCompressor(QuantizationCompressor): """ Alias for integer quantized models @@ -131,7 +131,7 @@ class IntQuantizationCompressor(QuantizationCompressor): pass -@Compressor.register(name=CompressionFormat.float_quantized.value) +@BaseCompressor.register(name=CompressionFormat.float_quantized.value) class FloatQuantizationCompressor(QuantizationCompressor): """ Alias for fp quantized models diff --git a/src/compressed_tensors/compressors/pack_quantized.py b/src/compressed_tensors/compressors/pack_quantized.py index 78b0a848..9d63e264 100644 --- a/src/compressed_tensors/compressors/pack_quantized.py +++ b/src/compressed_tensors/compressors/pack_quantized.py @@ -16,7 +16,10 @@ import numpy as np import torch -from compressed_tensors.compressors import Compressor +from compressed_tensors.compressors.base import BaseCompressor +from compressed_tensors.compressors.base_quantization_compressor import ( + BaseQuantizationCompressor, +) from compressed_tensors.config import CompressionFormat from compressed_tensors.quantization import QuantizationArgs from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize @@ -27,8 +30,8 @@ __all__ = ["PackedQuantizationCompressor", "pack_to_int32", "unpack_from_int32"] -@Compressor.register(name=CompressionFormat.pack_quantized.value) -class PackedQuantizationCompressor(Compressor): +@BaseCompressor.register(name=CompressionFormat.pack_quantized.value) +class PackedQuantizationCompressor(BaseQuantizationCompressor): """ Compresses a quantized model by packing every eight 4-bit weights into an int32 """ diff --git a/src/compressed_tensors/compressors/sparse_bitmask.py b/src/compressed_tensors/compressors/sparse_bitmask.py index 796d2643..63124163 100644 --- a/src/compressed_tensors/compressors/sparse_bitmask.py +++ b/src/compressed_tensors/compressors/sparse_bitmask.py @@ -12,17 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging -from typing import Dict, Generator, List, Tuple, Union +from typing import Dict, List, Tuple, Union import numpy import torch -from compressed_tensors.compressors import Compressor +from compressed_tensors.compressors.base import BaseCompressor +from compressed_tensors.compressors.base_sparsity_compressor import BaseSparseCompressor from compressed_tensors.config import CompressionFormat -from compressed_tensors.utils import get_nested_weight_mappings, merge_names -from safetensors import safe_open +from compressed_tensors.utils import merge_names from torch import Tensor -from tqdm import tqdm __all__ = [ @@ -34,11 +32,9 @@ "unpack_bitmasks", ] -_LOGGER: logging.Logger = logging.getLogger(__name__) - -@Compressor.register(name=CompressionFormat.sparse_bitmask.value) -class BitmaskCompressor(Compressor): +@BaseCompressor.register(name=CompressionFormat.sparse_bitmask.value) +class BitmaskCompressor(BaseSparseCompressor): """ Compression for sparse models using bitmasks. 
Non-zero weights are stored in a 1d values tensor, with their locations stored in a 2d bitmask @@ -46,56 +42,15 @@ class BitmaskCompressor(Compressor): COMPRESSION_PARAM_NAMES = ["shape", "compressed", "bitmask", "row_offsets"] - def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]: - """ - Compresses a dense state dict using bitmask compression + def compress_weight(self, name, value): + bitmask_tensor = BitmaskTensor.from_dense(value) + bitmask_dict = bitmask_tensor.dict(name_prefix=name, device="cpu") + return bitmask_dict - :param model_state: state dict of uncompressed model - :return: compressed state dict - """ - compressed_dict = {} - _LOGGER.debug( - f"Compressing model with {len(model_state)} parameterized layers..." - ) - for name, value in tqdm(model_state.items(), desc="Compressing model"): - bitmask_tensor = BitmaskTensor.from_dense(value) - bitmask_dict = bitmask_tensor.dict(name_prefix=name, device="cpu") - for key in bitmask_dict.keys(): - if key in compressed_dict: - _LOGGER.warn( - f"Expected all compressed state_dict keys to be unique, but " - f"found an existing entry for {key}. The existing entry will " - "be replaced." - ) - compressed_dict.update(bitmask_dict) - - return compressed_dict - - def decompress( - self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs - ) -> Generator[Tuple[str, Tensor], None, None]: - """ - Reads a bitmask compressed state dict located - at path_to_model_or_tensors and returns a generator - for sequentially decompressing back to a dense state dict - - :param model_path: path to compressed safetensors model (directory with - one or more safetensors files) or compressed tensors file - :param device: device to load decompressed weights onto - :return: iterator for generating decompressed weights - """ - weight_mappings = get_nested_weight_mappings( - path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES - ) - for weight_name in weight_mappings.keys(): - weight_data = {} - for param_name, safe_path in weight_mappings[weight_name].items(): - full_name = merge_names(weight_name, param_name) - with safe_open(safe_path, framework="pt", device=device) as f: - weight_data[param_name] = f.get_tensor(full_name) - data = BitmaskTensor(**weight_data) - decompressed = data.decompress() - yield weight_name, decompressed + def decompress_weight(self, weight_data): + data = BitmaskTensor(**weight_data) + decompressed = data.decompress() + return decompressed class BitmaskTensor: diff --git a/src/compressed_tensors/linear/compressed_linear.py b/src/compressed_tensors/linear/compressed_linear.py index 5f013b32..a4d5b532 100644 --- a/src/compressed_tensors/linear/compressed_linear.py +++ b/src/compressed_tensors/linear/compressed_linear.py @@ -13,7 +13,7 @@ # limitations under the License. 
import torch -from compressed_tensors.compressors.base import Compressor +from compressed_tensors.compressors.base import BaseCompressor from compressed_tensors.quantization import ( QuantizationScheme, QuantizationStatus, @@ -44,7 +44,7 @@ def from_linear( quantization_format: str, ): module.__class__ = CompressedLinear - module.compressor = Compressor.load_from_registry(quantization_format) + module.compressor = BaseCompressor.load_from_registry(quantization_format) device = next(module.parameters()).device # this will initialize all the scales and zero points diff --git a/tests/test_compressors/test_marlin_24.py b/tests/test_compressors/test_marlin_24.py index baac1c22..f12accb8 100644 --- a/tests/test_compressors/test_marlin_24.py +++ b/tests/test_compressors/test_marlin_24.py @@ -17,7 +17,7 @@ import pytest import torch from compressed_tensors.compressors import ( - Compressor, + BaseCompressor, Marlin24Compressor, map_modules_to_quant_args, ) @@ -45,7 +45,7 @@ def get_2_4_quant_config(num_bits, strategy, ignore): def test_marlin_registered(): config_name = CompressionFormat.marlin_24.value - compressor = Compressor.load_from_registry(config_name) + compressor = BaseCompressor.load_from_registry(config_name) assert isinstance(compressor, Marlin24Compressor) diff --git a/tests/test_compressors/test_model_compressor.py b/tests/test_compressors/test_model_compressor.py index ab27ab0e..3f6940a9 100644 --- a/tests/test_compressors/test_model_compressor.py +++ b/tests/test_compressors/test_model_compressor.py @@ -15,7 +15,7 @@ from copy import deepcopy import pytest -from compressed_tensors.compressors.model_compressor import ModelCompressor +from compressed_tensors.compressors import ModelCompressor from compressed_tensors.config.base import SparsityCompressionConfig from compressed_tensors.quantization.quant_config import QuantizationConfig from tests.testing_utils import requires_hf_quantizer diff --git a/tests/test_quantization/test_quant_scheme.py b/tests/test_quantization/test_quant_scheme.py index f4e708c7..14ee5b72 100644 --- a/tests/test_quantization/test_quant_scheme.py +++ b/tests/test_quantization/test_quant_scheme.py @@ -13,11 +13,7 @@ # limitations under the License. import pytest -from compressed_tensors.quantization import ( - QuantizationArgs, - QuantizationConfig, - QuantizationScheme, -) +from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme from pydantic import ValidationError diff --git a/tests/test_registry.py b/tests/test_registry.py index 4726fcf2..9b69b1b4 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -14,10 +14,10 @@ import pytest from compressed_tensors import ( + BaseCompressor, BitmaskCompressor, BitmaskConfig, CompressionFormat, - Compressor, DenseCompressor, DenseSparsityConfig, SparsityCompressionConfig, @@ -45,7 +45,7 @@ def test_configs(name, type): ], ) def test_compressors(name, type): - compressor = Compressor.load_from_registry( + compressor = BaseCompressor.load_from_registry( name, config=SparsityCompressionConfig(format="none") ) assert isinstance(compressor, type)
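Note for reviewers: the main payoff of the new base classes is that concrete compressors only implement the per-weight hooks, while BaseSparseCompressor and BaseQuantizationCompressor own the state_dict iteration and the safetensors walk (as BitmaskCompressor now demonstrates). A minimal sketch under that split (the class name, format string, and identity "compression" below are illustrative only, not part of this patch):

    from typing import Dict

    from compressed_tensors.compressors import BaseCompressor
    from compressed_tensors.compressors.base_sparsity_compressor import (
        BaseSparseCompressor,
    )
    from compressed_tensors.utils import merge_names
    from torch import Tensor


    @BaseCompressor.register(name="identity-sparse")  # illustrative format name
    class IdentitySparseCompressor(BaseSparseCompressor):
        """Toy compressor: stores each tensor unchanged under "<name>.weight"."""

        # decompress() groups stored tensors by these suffixes via
        # get_nested_weight_mappings before calling decompress_weight()
        COMPRESSION_PARAM_NAMES = ["weight"]

        def compress_weight(self, name: str, value: Tensor) -> Dict[str, Tensor]:
            # called once per state_dict entry by BaseSparseCompressor.compress()
            return {merge_names(name, "weight"): value.to("cpu")}

        def decompress_weight(self, weight_data: Dict[str, Tensor]) -> Tensor:
            # called with the tensors read back for a single weight
            return weight_data["weight"]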