Skip to content

Commit

Permalink
Enable: Sparse Compression with targets and ignores
Browse files Browse the repository at this point in the history
Signed-off-by: Rahul Tuli <rahul@neuralmagic.com>
  • Loading branch information
rahul-tuli committed Nov 27, 2024
1 parent 4c21a95 commit a528334
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import os
import re
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, Optional, TypeVar, Union
from typing import TYPE_CHECKING, Any, Dict, Optional, Set, TypeVar, Union

import compressed_tensors
import torch
Expand All @@ -38,6 +38,7 @@
apply_quantization_config,
load_pretrained_quantization,
)
from compressed_tensors.quantization.lifecycle import expand_targets
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.utils import (
is_module_quantized,
Expand Down Expand Up @@ -282,8 +283,14 @@ def compress(
)

if self.sparsity_compressor is not None:
sparse_compression_targets: Set[str] = expand_targets(
model=model,
targets=self.sparsity_config.targets,
ignore=self.sparsity_config.ignore,
)
compressed_state_dict = self.sparsity_compressor.compress(
compressed_state_dict
compressed_state_dict,
compression_targets=sparse_compression_targets,
)

# HACK: Override the dtype_byte_size function in transformers to
Expand Down
44 changes: 38 additions & 6 deletions src/compressed_tensors/compressors/sparse_compressors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import logging
from typing import Dict, Generator, Tuple
from typing import Dict, Generator, Optional, Set, Tuple

from compressed_tensors.compressors.base import BaseCompressor
from compressed_tensors.utils import get_nested_weight_mappings, merge_names
Expand All @@ -30,7 +30,8 @@
class BaseSparseCompressor(BaseCompressor):
"""
Base class representing a sparse compression algorithm. Each child class should
implement compression_param_info, compress_weight and decompress_weight.
implement compression_param_info, compress_weight and decompress_weight; child
classes should also define COMPRESSION_PARAM_NAMES.
Compressors support compressing/decompressing a full module state dict or a single
quantized PyTorch leaf module.
Expand Down Expand Up @@ -59,18 +60,27 @@ class BaseSparseCompressor(BaseCompressor):
:param config: config specifying compression parameters
"""

def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
def compress(
self,
model_state: Dict[str, Tensor],
compression_targets: Optional[Set[str]] = None,
) -> Dict[str, Tensor]:
"""
Compresses a dense state dict using bitmask compression
:param model_state: state dict of uncompressed model
:param compression_targets: optional set of layer prefixes to compress, if None
compress all layers (for backwards compatibility)
:return: compressed state dict
"""
compressed_dict = {}
_LOGGER.debug(
f"Compressing model with {len(model_state)} parameterized layers..."
)
for name, value in tqdm(model_state.items(), desc="Compressing model"):
if not self.should_compress(name, compression_targets):
compressed_dict[name] = value
continue
compression_data = self.compress_weight(name, value)
for key in compression_data.keys():
if key in compressed_dict:
Expand All @@ -97,8 +107,10 @@ def decompress(
:param device: device to load decompressed weights onto
:return: iterator for generating decompressed weights
"""
weight_mappings = get_nested_weight_mappings(
path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
weight_mappings, other_params = get_nested_weight_mappings(
path_to_model_or_tensors,
self.COMPRESSION_PARAM_NAMES,
return_other_params=True,
)
for weight_name in weight_mappings.keys():
weight_data = {}
Expand All @@ -107,4 +119,24 @@ def decompress(
with safe_open(safe_path, framework="pt", device=device) as f:
weight_data[param_name] = f.get_tensor(full_name)
decompressed = self.decompress_weight(weight_data)
yield weight_name, decompressed
full_name = merge_names(weight_name, "weight")
yield full_name, decompressed

for other_name, safe_path in other_params.items():
with safe_open(safe_path, framework="pt", device=device) as f:
value = f.get_tensor(other_name)
yield other_name, value

@staticmethod
def should_compress(name: str, targets: Optional[Set[str]] = None) -> bool:
"""
Check if a parameter should be compressed
:param name: name of the parameter
:param targets: set of layer prefixes to compress
:return: whether or not the parameter should be compressed
"""
if targets is None:
return name.endswith(".weight")

return name.endswith(".weight") and name[: -(len(".weight"))] in targets

0 comments on commit a528334

Please sign in to comment.