From a528334096a22c6f20ad21b6877eb2c732097f89 Mon Sep 17 00:00:00 2001
From: Rahul Tuli
Date: Wed, 27 Nov 2024 09:48:35 +0000
Subject: [PATCH] Enable sparse compression with targets and ignore lists

Signed-off-by: Rahul Tuli
---
 .../model_compressors/model_compressor.py  | 11 ++++-
 .../compressors/sparse_compressors/base.py | 44 ++++++++++++++++---
 2 files changed, 47 insertions(+), 8 deletions(-)

diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
index 68bd52ec..bc4633d9 100644
--- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py
+++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
@@ -18,7 +18,7 @@
 import os
 import re
 from copy import deepcopy
-from typing import TYPE_CHECKING, Any, Dict, Optional, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional, Set, TypeVar, Union
 
 import compressed_tensors
 import torch
@@ -38,6 +38,7 @@
     apply_quantization_config,
     load_pretrained_quantization,
 )
+from compressed_tensors.quantization.lifecycle import expand_targets
 from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.utils import (
     is_module_quantized,
@@ -282,8 +283,14 @@ def compress(
             )
 
         if self.sparsity_compressor is not None:
+            sparse_compression_targets: Set[str] = expand_targets(
+                model=model,
+                targets=self.sparsity_config.targets,
+                ignore=self.sparsity_config.ignore,
+            )
             compressed_state_dict = self.sparsity_compressor.compress(
-                compressed_state_dict
+                compressed_state_dict,
+                compression_targets=sparse_compression_targets,
             )
 
         # HACK: Override the dtype_byte_size function in transformers to
diff --git a/src/compressed_tensors/compressors/sparse_compressors/base.py b/src/compressed_tensors/compressors/sparse_compressors/base.py
index 1b1a6825..67e2727a 100644
--- a/src/compressed_tensors/compressors/sparse_compressors/base.py
+++ b/src/compressed_tensors/compressors/sparse_compressors/base.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, Generator, Tuple
+from typing import Dict, Generator, Optional, Set, Tuple
 
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.utils import get_nested_weight_mappings, merge_names
@@ -30,7 +30,8 @@ class BaseSparseCompressor(BaseCompressor):
     """
     Base class representing a sparse compression algorithm. Each child class should
-    implement compression_param_info, compress_weight and decompress_weight.
+    implement compression_param_info, compress_weight and decompress_weight; child
+    classes should also define COMPRESSION_PARAM_NAMES.
 
     Compressors support compressing/decompressing a full module state dict or a
     single quantized PyTorch leaf module.
@@ -59,11 +60,17 @@ class BaseSparseCompressor(BaseCompressor):
     :param config: config specifying compression parameters
     """
 
-    def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
+    def compress(
+        self,
+        model_state: Dict[str, Tensor],
+        compression_targets: Optional[Set[str]] = None,
+    ) -> Dict[str, Tensor]:
         """
         Compresses a dense state dict using bitmask compression
 
         :param model_state: state dict of uncompressed model
+        :param compression_targets: optional set of layer prefixes to compress;
+            if None, all layers are compressed (for backwards compatibility)
         :return: compressed state dict
         """
         compressed_dict = {}
@@ -71,6 +78,9 @@ def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
             f"Compressing model with {len(model_state)} parameterized layers..."
         )
         for name, value in tqdm(model_state.items(), desc="Compressing model"):
+            if not self.should_compress(name, compression_targets):
+                compressed_dict[name] = value
+                continue
             compression_data = self.compress_weight(name, value)
             for key in compression_data.keys():
                 if key in compressed_dict:
@@ -97,8 +107,10 @@ def decompress(
         :param device: device to load decompressed weights onto
         :return: iterator for generating decompressed weights
         """
-        weight_mappings = get_nested_weight_mappings(
-            path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
+        weight_mappings, other_params = get_nested_weight_mappings(
+            path_to_model_or_tensors,
+            self.COMPRESSION_PARAM_NAMES,
+            return_other_params=True,
         )
         for weight_name in weight_mappings.keys():
             weight_data = {}
@@ -107,4 +119,24 @@ def decompress(
                 with safe_open(safe_path, framework="pt", device=device) as f:
                     weight_data[param_name] = f.get_tensor(full_name)
             decompressed = self.decompress_weight(weight_data)
-            yield weight_name, decompressed
+            full_name = merge_names(weight_name, "weight")
+            yield full_name, decompressed
+
+        for other_name, safe_path in other_params.items():
+            with safe_open(safe_path, framework="pt", device=device) as f:
+                value = f.get_tensor(other_name)
+            yield other_name, value
+
+    @staticmethod
+    def should_compress(name: str, targets: Optional[Set[str]] = None) -> bool:
+        """
+        Check if a parameter should be compressed
+
+        :param name: name of the parameter
+        :param targets: optional set of layer prefixes to compress; if None,
+            any weight parameter is compressed
+        :return: True if the parameter should be compressed
+        """
+        if targets is None:
+            return name.endswith(".weight")
+
+        return name.endswith(".weight") and name[: -len(".weight")] in targets
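
Reviewer note: the filtering rule this patch adds is small enough to sanity-check in
isolation. Below is a minimal, self-contained Python sketch of the rule that
should_compress implements; the function body is copied from the patch above, and the
layer names are hypothetical. With no targets, every ".weight" parameter is compressed;
with targets, only ".weight" parameters whose layer prefix is in the expanded set.

    from typing import Optional, Set

    def should_compress(name: str, targets: Optional[Set[str]] = None) -> bool:
        # Mirrors BaseSparseCompressor.should_compress from this patch
        if targets is None:
            return name.endswith(".weight")
        return name.endswith(".weight") and name[: -len(".weight")] in targets

    # Hypothetical transformer-style parameter names
    targets = {"model.layers.0.self_attn.q_proj"}
    assert should_compress("model.layers.0.self_attn.q_proj.weight", targets)
    assert not should_compress("model.layers.0.self_attn.k_proj.weight", targets)
    assert not should_compress("model.layers.0.self_attn.q_proj.bias", targets)
    assert should_compress("model.layers.0.self_attn.k_proj.weight", None)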
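
To see how the two halves fit together, here is a sketch of the state-dict round trip
under targeted compression. It reuses should_compress from the sketch above;
merge_names is assumed here to be a simple dotted join (which is all the decompress
path relies on), and a single "compressed" entry stands in for the real
COMPRESSION_PARAM_NAMES a concrete compressor would emit.

    from typing import Dict, Optional, Set

    def merge_names(prefix: str, param: str) -> str:
        # Stand-in for compressed_tensors.utils.merge_names (assumed dotted join)
        return f"{prefix}.{param}"

    def compress(state: Dict[str, str], targets: Optional[Set[str]]) -> Dict[str, str]:
        out: Dict[str, str] = {}
        for name, value in state.items():
            if not should_compress(name, targets):
                out[name] = value  # non-targeted parameters pass through untouched
                continue
            prefix = name[: -len(".weight")]
            out[merge_names(prefix, "compressed")] = f"bitmask({value})"
        return out

    state = {
        "model.layers.0.q_proj.weight": "dense",
        "model.layers.0.q_proj.bias": "dense",
        "lm_head.weight": "dense",
    }
    compressed = compress(state, targets={"model.layers.0.q_proj"})
    assert "model.layers.0.q_proj.compressed" in compressed   # targeted: compressed
    assert compressed["model.layers.0.q_proj.bias"] == "dense"  # bias: passed through
    assert compressed["lm_head.weight"] == "dense"            # not targeted: untouched

The decompress side of the patch mirrors this: compressed layers are re-emitted under
merge_names(weight_name, "weight"), and the parameters that were stored uncompressed
are yielded through unchanged from other_params.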