Composability (#219)
* Add: Support for targets and ignore in SparseCompressors
Enable: Operations on state_dict to allow composability
Add: Composability for compress/decompress pathways
Update: Typing for a few methods
Add: Composability Test
Add: Some testing utils

* Add: FP8 Test for composability

* Review Comments!

* More review comments from @dsikka

* Fix failing tests

* Rename is_target to is_sparse_target
Update _replace_weight to work with updates from `85b473e`
Add docstring to _replace_weights
Update failing test

* review comments from @kylesayrs
rahul-tuli authored Jan 7, 2025
1 parent fe4a442 commit 7801f00
Showing 14 changed files with 694 additions and 54 deletions.
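
As context for the diff below, here is a minimal usage sketch of the composed pathway this commit enables. It is not part of the commit: the checkpoint path is hypothetical, the import paths are assumed, and the load call is schematic.

from transformers import AutoModelForCausalLM
from compressed_tensors.compressors import ModelCompressor

model_path = "path/to/sparse-quantized-checkpoint"  # hypothetical checkpoint
model = AutoModelForCausalLM.from_pretrained(model_path)  # schematic load of the compressed model

# from_pretrained returns None when the model config carries no compression config
compressor = ModelCompressor.from_pretrained(model_path)
if compressor is not None:
    # sparse decompression runs first (into the model's state dict),
    # then quantized decompression, as in the updated decompress() below
    compressor.decompress(model_path=model_path, model=model)
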
@@ -17,8 +17,9 @@
import operator
import os
import re
from contextlib import contextmanager
from copy import deepcopy
- from typing import TYPE_CHECKING, Any, Dict, Optional, TypeVar, Union
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Set, TypeVar, Union

import compressed_tensors
import torch
@@ -38,6 +39,7 @@
apply_quantization_config,
load_pretrained_quantization,
)
from compressed_tensors.quantization.lifecycle import expand_sparse_target_names
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.utils import (
is_module_quantized,
@@ -104,7 +106,6 @@ def from_pretrained(
"""
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
compression_config = getattr(config, QUANTIZATION_CONFIG_NAME, None)

return cls.from_compression_config(compression_config)

@classmethod
@@ -282,8 +283,14 @@ def compress(
)

if self.sparsity_compressor is not None:
sparse_compression_targets: Set[str] = expand_sparse_target_names(
model=model,
targets=self.sparsity_config.targets,
ignore=self.sparsity_config.ignore,
)
compressed_state_dict = self.sparsity_compressor.compress(
- compressed_state_dict
+ compressed_state_dict,
+ compression_targets=sparse_compression_targets,
)

# HACK: Override the dtype_byte_size function in transformers to
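
For illustration, a small self-contained sketch of what expand_sparse_target_names is expected to return for a sparsity config's targets/ignore lists; the toy model, the target values, and the exact matching behaviour are assumptions, not part of this commit.

import torch
from compressed_tensors.quantization.lifecycle import expand_sparse_target_names

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)      # expected to be picked up via the "Linear" target
        self.lm_head = torch.nn.Linear(8, 8)   # expected to be dropped via the ignore list

sparse_compression_targets = expand_sparse_target_names(
    model=TinyModel(),
    targets=["Linear"],
    ignore=["lm_head"],
)
print(sparse_compression_targets)  # expected (assumed): {"proj"}
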
@@ -301,23 +308,41 @@ def decompress(self, model_path: str, model: Module):
:param model: pytorch model to load decompressed weights into
"""
model_path = get_safetensors_folder(model_path)
sparse_decompressed = False

if self.sparsity_compressor is not None:
# Sparse decompression is applied on the model_path
dense_gen = self.sparsity_compressor.decompress(model_path)
self._replace_weights(dense_gen, model)
setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config)
sparse_decompressed = True

if self.quantization_compressor is not None:
- names_to_scheme = apply_quantization_config(model, self.quantization_config)
- load_pretrained_quantization(model, model_path)
+ # Temporarily set quantization status to FROZEN to prevent
+ # quantization during apply_quantization_config. This ensures
+ # that the dtypes of the weights are not unintentionally updated.
+ # The status is restored after quantization params are loaded.
+ with override_quantization_status(
+ self.quantization_config, QuantizationStatus.FROZEN
+ ):
+ names_to_scheme = apply_quantization_config(
+ model, self.quantization_config
+ )
+ load_pretrained_quantization(model, model_path)

model_path_or_state_dict = (
model.state_dict() if sparse_decompressed else model_path
)

dense_gen = self.quantization_compressor.decompress(
- model_path, names_to_scheme=names_to_scheme
+ model_path_or_state_dict, names_to_scheme=names_to_scheme
)
self._replace_weights(dense_gen, model)

- def update_status(module):
+ def freeze_quantization_status(module):
module.quantization_status = QuantizationStatus.FROZEN

- model.apply(update_status)
+ model.apply(freeze_quantization_status)
setattr(model, QUANTIZATION_CONFIG_NAME, self.quantization_config)

def update_config(self, save_directory: str):
@@ -367,12 +392,26 @@ def update_config(self, save_directory: str):
with open(config_file_path, "w") as config_file:
json.dump(config_data, config_file, indent=2, sort_keys=True)

- def _replace_weights(self, dense_weight_generator, model):
+ def _replace_weights(self, dense_weight_generator, model: Module):
"""
Replace the weights of the model with the
provided dense weights.
This method iterates over the dense_weight_generator and
updates the corresponding weights in the model. If a parameter
name does not exist in the model, it will be skipped.
:param dense_weight_generator (generator): A generator that yields
tuples of (name, data), where 'name' is the parameter name and
'data' is the updated param data
:param model: The model whose weights are to be updated.
"""
for name, data in tqdm(dense_weight_generator, desc="Decompressing model"):
split_name = name.split(".")
prefix, param_name = ".".join(split_name[:-1]), split_name[-1]
module = operator.attrgetter(prefix)(model)
- update_parameter_data(module, data, param_name)
+ if hasattr(module, param_name):
+ update_parameter_data(module, data, param_name)


def map_modules_to_quant_args(model: Module) -> Dict[str, QuantizationArgs]:
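
A toy illustration (not from the repository) of the name-to-module resolution that _replace_weights relies on:

import operator

import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 4))
name = "0.weight"  # a name as yielded by a dense weight generator
split_name = name.split(".")
prefix, param_name = ".".join(split_name[:-1]), split_name[-1]
module = operator.attrgetter(prefix)(model)
print(hasattr(module, param_name))  # True, so this parameter would be updated
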
@@ -402,3 +441,23 @@ def new_dtype_byte_size(dtype):
raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
bit_size = int(bit_search.groups()[0])
return bit_size // 8


@contextmanager
def override_quantization_status(
config: QuantizationConfig, status: QuantizationStatus
):
"""
Within this context, the quantization status will be set to the
supplied status. After the context exits, the original status
will be restored.
:param config: the quantization config to override
:param status: the status to temporarily set
"""
original_status = config.quantization_status
config.quantization_status = status
try:
yield
finally:
config.quantization_status = original_status
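
A quick usage sketch for the new context manager. The empty config_groups, the COMPRESSED starting status, and the import path are assumptions made for illustration.

from compressed_tensors.quantization import QuantizationConfig, QuantizationStatus
# import path assumed; override_quantization_status is defined in the module shown above
from compressed_tensors.compressors.model_compressors.model_compressor import (
    override_quantization_status,
)

config = QuantizationConfig(
    config_groups={}, quantization_status=QuantizationStatus.COMPRESSED
)
with override_quantization_status(config, QuantizationStatus.FROZEN):
    assert config.quantization_status == QuantizationStatus.FROZEN
assert config.quantization_status == QuantizationStatus.COMPRESSED  # restored on exit
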
40 changes: 35 additions & 5 deletions src/compressed_tensors/compressors/quantized_compressors/base.py
@@ -13,12 +13,17 @@
# limitations under the License.

import logging
- from typing import Dict, Generator, Tuple
+ from pathlib import Path
+ from typing import Any, Dict, Generator, Tuple, Union

import torch
from compressed_tensors.compressors.base import BaseCompressor
from compressed_tensors.quantization import QuantizationArgs
- from compressed_tensors.utils import get_nested_weight_mappings, merge_names
+ from compressed_tensors.utils import (
+ get_nested_mappings_from_state_dict,
+ get_nested_weight_mappings,
+ merge_names,
+ )
from safetensors import safe_open
from torch import Tensor
from tqdm import tqdm
@@ -113,30 +118,55 @@ def compress(

def decompress(
self,
- path_to_model_or_tensors: str,
+ path_to_model_or_tensors: Union[str, Path, Dict[str, Any]],
names_to_scheme: Dict[str, QuantizationArgs],
device: str = "cpu",
) -> Generator[Tuple[str, Tensor], None, None]:
"""
Reads a compressed state dict located at path_to_model_or_tensors
and returns a generator for sequentially decompressing back to a
dense state dict
:param path_to_model_or_tensors: path to compressed safetensors model (directory
with one or more safetensors files), a compressed tensors file, or an
already-loaded compressed state dict
:param names_to_scheme: quantization args for each quantized weight
:param device: optional device to load intermediate weights into
:return: generator of (parameter name, decompressed weight) pairs
"""
if isinstance(path_to_model_or_tensors, (str, Path)):
yield from self._decompress_from_path(
path_to_model_or_tensors, names_to_scheme, device
)

else:
yield from self._decompress_from_state_dict(
path_to_model_or_tensors, names_to_scheme
)

def _decompress_from_path(self, path_to_model, names_to_scheme, device):
weight_mappings = get_nested_weight_mappings(
- path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
+ path_to_model, self.COMPRESSION_PARAM_NAMES
)
for weight_name in weight_mappings.keys():
weight_data = {}
for param_name, safe_path in weight_mappings[weight_name].items():
full_name = merge_names(weight_name, param_name)
with safe_open(safe_path, framework="pt", device=device) as f:
weight_data[param_name] = f.get_tensor(full_name)
if "weight_scale" in weight_data:
quant_args = names_to_scheme[weight_name]
decompressed = self.decompress_weight(
compressed_data=weight_data, quantization_args=quant_args
)
yield merge_names(weight_name, "weight"), decompressed

def _decompress_from_state_dict(self, state_dict, names_to_scheme):
weight_mappings = get_nested_mappings_from_state_dict(
state_dict, self.COMPRESSION_PARAM_NAMES
)
for weight_name in weight_mappings.keys():
weight_data = {}
for param_name, param_value in weight_mappings[weight_name].items():
weight_data[param_name] = param_value

if "weight_scale" in weight_data:
quant_args = names_to_scheme[weight_name]
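
To make the state-dict branch concrete, a small sketch of the nesting get_nested_mappings_from_state_dict is expected to produce; the parameter names and the exact return shape are assumed from the path-based counterpart.

import torch
from compressed_tensors.utils import get_nested_mappings_from_state_dict

state_dict = {
    "decoder.layers.0.fc1.weight_packed": torch.zeros(4, 2, dtype=torch.int32),
    "decoder.layers.0.fc1.weight_scale": torch.ones(1),
}
nested = get_nested_mappings_from_state_dict(state_dict, ["weight_packed", "weight_scale"])
# expected (assumed): {"decoder.layers.0.fc1": {"weight_packed": ..., "weight_scale": ...}}
print(nested.keys())
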
@@ -68,19 +68,19 @@ def compress_weight(
self,
weight: Tensor,
scale: Tensor,
- quantization_args: QuantizationArgs,
zero_point: Optional[Tensor] = None,
g_idx: Optional[torch.Tensor] = None,
+ quantization_args: Optional[QuantizationArgs] = None,
device: Optional[torch.device] = None,
) -> Dict[str, torch.Tensor]:
"""
Compresses a single uncompressed weight
:param weight: uncompressed weight tensor
:param scale: quantization scale for weight
- :param quantization_args: quantization parameters for weight
:param zero_point: quantization zero point for weight
:param g_idx: optional mapping from column index to group index
+ :param quantization_args: quantization parameters for weight
:param device: optional device to move compressed output to
:return: dictionary of compressed weight data
"""
@@ -68,19 +68,19 @@ def compress_weight(
self,
weight: Tensor,
scale: Tensor,
- quantization_args: QuantizationArgs,
zero_point: Optional[Tensor] = None,
g_idx: Optional[torch.Tensor] = None,
+ quantization_args: Optional[QuantizationArgs] = None,
device: Optional[torch.device] = None,
) -> Dict[str, torch.Tensor]:
"""
Compresses a single uncompressed weight
:param weight: uncompressed weight tensor
:param scale: quantization scale for weight
- :param quantization_args: quantization parameters for weight
:param zero_point: quantization zero point for weight
:param g_idx: optional mapping from column index to group index
+ :param quantization_args: quantization parameters for weight
:param device: optional device to move compressed output to
:return: dictionary of compressed weight data
"""
52 changes: 45 additions & 7 deletions src/compressed_tensors/compressors/sparse_compressors/base.py
@@ -13,7 +13,7 @@
# limitations under the License.

import logging
- from typing import Dict, Generator, Tuple
+ from typing import Dict, Generator, Optional, Set, Tuple

from compressed_tensors.compressors.base import BaseCompressor
from compressed_tensors.utils import get_nested_weight_mappings, merge_names
@@ -30,7 +30,8 @@
class BaseSparseCompressor(BaseCompressor):
"""
Base class representing a sparse compression algorithm. Each child class should
- implement compression_param_info, compress_weight and decompress_weight.
+ implement compression_param_info, compress_weight and decompress_weight; child
+ classes should also define COMPRESSION_PARAM_NAMES.
Compressors support compressing/decompressing a full module state dict or a single
quantized PyTorch leaf module.
@@ -59,19 +60,32 @@ class BaseSparseCompressor(BaseCompressor):
:param config: config specifying compression parameters
"""

- def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
+ def compress(
+ self,
+ model_state: Dict[str, Tensor],
+ compression_targets: Optional[Set[str]] = None,
+ ) -> Dict[str, Tensor]:
"""
Compresses a dense state dict using bitmask compression
:param model_state: state dict of uncompressed model
:param compression_targets: optional set of layer prefixes to compress,
otherwise compress all layers (for backwards compatibility)
:return: compressed state dict
"""
compressed_dict = {}
_LOGGER.debug(
f"Compressing model with {len(model_state)} parameterized layers..."
)
for name, value in tqdm(model_state.items(), desc="Compressing model"):
- compression_data = self.compress_weight(name, value)
+ if not self.should_compress(name, compression_targets):
+ compressed_dict[name] = value
+ continue
+ prefix = name
+ if prefix.endswith(".weight"):
+ prefix = prefix[: -(len(".weight"))]
+
+ compression_data = self.compress_weight(prefix, value)
for key in compression_data.keys():
if key in compressed_dict:
_LOGGER.warn(
@@ -97,8 +111,10 @@ def decompress(
:param device: device to load decompressed weights onto
:return: iterator for generating decompressed weights
"""
- weight_mappings = get_nested_weight_mappings(
- path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
+ weight_mappings, ignored_params = get_nested_weight_mappings(
+ path_to_model_or_tensors,
+ self.COMPRESSION_PARAM_NAMES,
+ return_unmatched_params=True,
)
for weight_name in weight_mappings.keys():
weight_data = {}
@@ -107,4 +123,26 @@
with safe_open(safe_path, framework="pt", device=device) as f:
weight_data[param_name] = f.get_tensor(full_name)
decompressed = self.decompress_weight(weight_data)
- yield weight_name, decompressed
+ yield merge_names(weight_name, "weight"), decompressed

for ignored_param_name, safe_path in ignored_params.items():
with safe_open(safe_path, framework="pt", device=device) as f:
value = f.get_tensor(ignored_param_name)
yield ignored_param_name, value

@staticmethod
def should_compress(name: str, expanded_targets: Optional[Set[str]] = None) -> bool:
"""
Check if a parameter should be compressed.
Currently, this only returns True for weight parameters.
:param name: name of the parameter
:param expanded_targets: set of layer prefixes to compress
:return: whether or not the parameter should be compressed
"""
if expanded_targets is None:
return name.endswith(".weight")

return (
name.endswith(".weight") and name[: -(len(".weight"))] in expanded_targets
)
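
Since should_compress is a staticmethod, the gating can be exercised directly; the parameter names below are illustrative.

from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor

targets = {"model.layers.0.mlp.down_proj"}
assert BaseSparseCompressor.should_compress("model.layers.0.mlp.down_proj.weight", targets)
assert not BaseSparseCompressor.should_compress("model.layers.0.mlp.down_proj.bias", targets)
assert not BaseSparseCompressor.should_compress("lm_head.weight", targets)
# with no targets supplied, every ".weight" parameter is compressed (the backwards-compatible default)
assert BaseSparseCompressor.should_compress("lm_head.weight")
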
@@ -19,6 +19,7 @@
from compressed_tensors.compressors.base import BaseCompressor
from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor
from compressed_tensors.config import CompressionFormat
from compressed_tensors.quantization import FP8_DTYPE
from compressed_tensors.utils import merge_names
from torch import Tensor

@@ -134,9 +135,14 @@ def bitmask_compress(tensor: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
bytemasks = tensor != 0
row_counts = bytemasks.sum(dim=-1)
row_offsets = torch.cumsum(row_counts, 0) - row_counts
- values = tensor[bytemasks]
+ if tensor.dtype == FP8_DTYPE:
+ # access the raw bytes of the tensor
+ tensor_view = tensor.view(torch.int8)
+ values = tensor_view[bytemasks]
+ values = values.view(FP8_DTYPE)
+ else:
+ values = tensor[bytemasks]
bitmasks_packed = pack_bitmasks(bytemasks)

return values, bitmasks_packed, row_offsets
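
A short sketch of why the int8 view is used for FP8 values; the mask is computed on a float32 copy purely for this illustration, and the assumption is that boolean indexing of float8 tensors is not supported on the targeted torch versions.

import torch

FP8_DTYPE = torch.float8_e4m3fn  # matches compressed_tensors' FP8_DTYPE
dense = torch.tensor([0.0, 1.5, 0.0, -2.0])
fp8 = dense.to(FP8_DTYPE)
bytemasks = dense != 0  # mask from the float32 copy, for illustration only
# gather the raw bytes through an int8 view, then reinterpret them as FP8 again
values = fp8.view(torch.int8)[bytemasks].view(FP8_DTYPE)
print(values.to(torch.float32))  # tensor([ 1.5000, -2.0000])
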

