Composability for sparse+quant compression decompression
rahul-tuli committed Dec 2, 2024
1 parent 8fd469f commit ef4b4a0
Showing 2 changed files with 50 additions and 10 deletions.
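
For context, here is a minimal usage sketch of the composed flow this commit enables: a single ModelCompressor.decompress call now handles a checkpoint that was compressed with both a sparsity compressor and a quantization compressor. The checkpoint path and the placeholder model below are hypothetical, and the import location of ModelCompressor is assumed rather than taken from this diff.

import torch
from compressed_tensors.compressors import ModelCompressor  # assumed import path

model_path = "path/to/sparse-and-quantized-checkpoint"  # hypothetical checkpoint
model = torch.nn.Linear(8, 8)  # stand-in for the real dense-architecture model

# Build the compressor from the checkpoint's config; returns None if uncompressed.
compressor = ModelCompressor.from_pretrained(model_path)
if compressor is not None:
    # Sparse decompression runs first against the safetensors files on disk;
    # quantized decompression then reads the now-dense in-memory state dict.
    compressor.decompress(model_path=model_path, model=model)
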
@@ -104,7 +104,6 @@ def from_pretrained(
        :return: compressor for the configs, or None if model is not compressed
        """
        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        compression_config = getattr(config, QUANTIZATION_CONFIG_NAME, None)

        return cls.from_compression_config(compression_config)

@@ -269,9 +268,9 @@ def compress(

        compressed_state_dict = state_dict

        quantized_modules_to_args: Dict[
            str, QuantizationArgs
        ] = map_modules_to_quant_args(model)
        quantized_modules_to_args: Dict[str, QuantizationArgs] = (
            map_modules_to_quant_args(model)
        )

        if self.quantization_compressor is not None:
            compressed_state_dict = self.quantization_compressor.compress(
@@ -308,16 +307,28 @@ def decompress(self, model_path: str, model: Module):
        :param model: pytorch model to load decompressed weights into
        """
        model_path = get_safetensors_folder(model_path)
        sparse_decompressed = False
        if self.quantization_compressor is not None:
            # update model structure
            names_to_scheme = apply_quantization_config(model, self.quantization_config)
            load_pretrained_quantization(model, model_path)

        if self.sparsity_compressor is not None:
            # sparse decompression
            dense_gen = self.sparsity_compressor.decompress(model_path)
            self._replace_weights(dense_gen, model)
            setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config)
            sparse_decompressed = True

        if self.quantization_compressor is not None:
            # quantized decompression
            names_to_scheme = apply_quantization_config(model, self.quantization_config)
            load_pretrained_quantization(model, model_path)
            model_path_or_state_dict = (
                model.state_dict() if sparse_decompressed else model_path
            )

            dense_gen = self.quantization_compressor.decompress(
                model_path, names_to_scheme=names_to_scheme
                model_path_or_state_dict, names_to_scheme=names_to_scheme
            )
            self._replace_weights(dense_gen, model)
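
Both compressors above rely on the same generator contract: decompress() yields (name, tensor) pairs and _replace_weights() writes them into the model, which is what lets the sparse pass materialize the dense state dict that the quantized pass then consumes. Below is a self-contained sketch of that contract; toy_decompress and replace_weights are illustrative stand-ins, not library code.

from typing import Dict, Generator, Tuple

import torch
from torch import Tensor


def toy_decompress(
    compressed: Dict[str, Tensor]
) -> Generator[Tuple[str, Tensor], None, None]:
    # Hypothetical stand-in for a compressor's decompress(): here the
    # "decompression" is just an identity copy of each tensor.
    for name, tensor in compressed.items():
        yield name, tensor.clone()


def replace_weights(dense_gen, model: torch.nn.Module) -> None:
    # Plays the role of _replace_weights: copy each yielded (name, tensor)
    # pair into the matching parameter of the target model.
    params = dict(model.named_parameters())
    with torch.no_grad():
        for name, tensor in dense_gen:
            if name in params:
                params[name].copy_(tensor)


model = torch.nn.Linear(4, 4, bias=False)
replace_weights(toy_decompress({"weight": torch.zeros(4, 4)}), model)
assert int(torch.count_nonzero(model.weight)) == 0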

37 changes: 33 additions & 4 deletions src/compressed_tensors/compressors/quantized_compressors/base.py
@@ -13,12 +13,16 @@
# limitations under the License.

import logging
from typing import Dict, Generator, Tuple
from typing import Any, Dict, Generator, Tuple, Union

import torch
from compressed_tensors.compressors.base import BaseCompressor
from compressed_tensors.quantization import QuantizationArgs
from compressed_tensors.utils import get_nested_weight_mappings, merge_names
from compressed_tensors.utils import (
    get_nested_mappings_from_state_dict,
    get_nested_weight_mappings,
    merge_names,
)
from safetensors import safe_open
from torch import Tensor
from tqdm import tqdm
@@ -113,21 +117,31 @@ def compress(

    def decompress(
        self,
        path_to_model_or_tensors: str,
        path_to_model_or_tensors: Union[str, Dict[str, Any]],
        names_to_scheme: Dict[str, QuantizationArgs],
        device: str = "cpu",
    ) -> Generator[Tuple[str, Tensor], None, None]:
        """
        Reads a compressed state dict located at path_to_model_or_tensors
        and returns a generator for sequentially decompressing back to a
        dense state dict

        :param path_to_model_or_tensors: path to compressed safetensors model (directory
            with one or more safetensors files) or compressed tensors file
        :param names_to_scheme: quantization args for each quantized weight
        :param device: optional device to load intermediate weights into
        :return: generator of decompressed weights as (name, tensor) pairs
"""
if isinstance(path_to_model_or_tensors, str):
yield from self._decompress_from_path(
path_to_model_or_tensors, names_to_scheme, device
)

else:
yield from self._decompress_from_state_dict(
path_to_model_or_tensors, names_to_scheme
)

def _decompress_from_path(self, path_to_model_or_tensors, names_to_scheme, device):
weight_mappings = get_nested_weight_mappings(
path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
)
@@ -137,6 +151,21 @@ def decompress(
                full_name = merge_names(weight_name, param_name)
                with safe_open(safe_path, framework="pt", device=device) as f:
                    weight_data[param_name] = f.get_tensor(full_name)

            if "weight_scale" in weight_data:
                quant_args = names_to_scheme[weight_name]
                decompressed = self.decompress_weight(
                    compressed_data=weight_data, quantization_args=quant_args
                )
                yield merge_names(weight_name, "weight"), decompressed

    def _decompress_from_state_dict(self, state_dict, names_to_scheme):
        weight_mappings = get_nested_mappings_from_state_dict(
            state_dict, self.COMPRESSION_PARAM_NAMES
        )
        for weight_name in weight_mappings.keys():
            weight_data = {}
            for param_name, param_value in weight_mappings[weight_name].items():
                weight_data[param_name] = param_value

            if "weight_scale" in weight_data:
                quant_args = names_to_scheme[weight_name]
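
The new _decompress_from_state_dict path leans on get_nested_mappings_from_state_dict to regroup a flat state dict by module before decompressing each weight. The sketch below shows the grouping behavior this code assumes from that helper; the function body, parameter names, and key layout are illustrative, not copied from the library.

from typing import Dict, List

import torch
from torch import Tensor

# Illustrative compression parameter names; the real values come from the
# compressor's COMPRESSION_PARAM_NAMES attribute.
PARAM_NAMES: List[str] = ["weight_packed", "weight_scale", "weight_zero_point"]


def nested_mappings_from_state_dict(
    state_dict: Dict[str, Tensor], param_names: List[str]
) -> Dict[str, Dict[str, Tensor]]:
    # Assumed behavior: group flat "<module>.<param>" keys by module so each
    # entry holds that module's compression parameters.
    nested: Dict[str, Dict[str, Tensor]] = {}
    for key, value in state_dict.items():
        module_name, _, param_name = key.rpartition(".")
        if param_name in param_names:
            nested.setdefault(module_name, {})[param_name] = value
    return nested


state_dict = {
    "model.layers.0.self_attn.q_proj.weight_packed": torch.zeros(8, 4),
    "model.layers.0.self_attn.q_proj.weight_scale": torch.ones(8, 1),
}
mappings = nested_mappings_from_state_dict(state_dict, PARAM_NAMES)
# -> {"model.layers.0.self_attn.q_proj": {"weight_packed": ..., "weight_scale": ...}}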
