[DO NOT MERGE] #232

Closed
wants to merge 5 commits into from
@@ -17,8 +17,9 @@
import operator
import os
import re
from contextlib import contextmanager
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, Optional, TypeVar, Union
from typing import TYPE_CHECKING, Any, Dict, Optional, Set, TypeVar, Union

import compressed_tensors
import torch
@@ -38,6 +39,7 @@
apply_quantization_config,
load_pretrained_quantization,
)
from compressed_tensors.quantization.lifecycle import expand_sparse_target_names
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.utils import (
is_module_quantized,
@@ -104,7 +106,6 @@ def from_pretrained(
"""
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
compression_config = getattr(config, QUANTIZATION_CONFIG_NAME, None)

return cls.from_compression_config(compression_config)

@classmethod
@@ -282,8 +283,14 @@ def compress(
)

if self.sparsity_compressor is not None:
sparse_compression_targets: Set[str] = expand_sparse_target_names(
model=model,
targets=self.sparsity_config.targets,
ignore=self.sparsity_config.ignore,
)
compressed_state_dict = self.sparsity_compressor.compress(
compressed_state_dict
compressed_state_dict,
compression_targets=sparse_compression_targets,
)

# HACK: Override the dtype_byte_size function in transformers to
@@ -301,23 +308,41 @@ def decompress(self, model_path: str, model: Module):
:param model: pytorch model to load decompressed weights into
"""
model_path = get_safetensors_folder(model_path)
sparse_decompressed = False

if self.sparsity_compressor is not None:
# Sparse decompression is applied to the model_path
dense_gen = self.sparsity_compressor.decompress(model_path)
self._replace_weights(dense_gen, model)
setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config)
sparse_decompressed = True

if self.quantization_compressor is not None:
names_to_scheme = apply_quantization_config(model, self.quantization_config)
load_pretrained_quantization(model, model_path)
# Temporarily set quantization status to FROZEN to prevent
# quantization during apply_quantization_config. This ensures
# that the dtypes of the weights are not unintentionally updated.
# The status is restored after quantization params are loaded.
with override_quantization_status(
self.quantization_config, QuantizationStatus.FROZEN
):
names_to_scheme = apply_quantization_config(
model, self.quantization_config
)
load_pretrained_quantization(model, model_path)

model_path_or_state_dict = (
model.state_dict() if sparse_decompressed else model_path
)

dense_gen = self.quantization_compressor.decompress(
model_path, names_to_scheme=names_to_scheme
model_path_or_state_dict, names_to_scheme=names_to_scheme
)
self._replace_weights(dense_gen, model)

def update_status(module):
def freeze_quantization_status(module):
module.quantization_status = QuantizationStatus.FROZEN

model.apply(update_status)
model.apply(freeze_quantization_status)
setattr(model, QUANTIZATION_CONFIG_NAME, self.quantization_config)

def update_config(self, save_directory: str):
@@ -402,3 +427,23 @@ def new_dtype_byte_size(dtype):
raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
bit_size = int(bit_search.groups()[0])
return bit_size // 8


@contextmanager
def override_quantization_status(
config: QuantizationConfig, status: QuantizationStatus
):
"""
Within this context, the quantization status will be set to the
supplied status. After the context exits, the original status
will be restored.

:param config: the quantization config to override
:param status: the status to temporarily set
"""
original_status = config.quantization_status
config.quantization_status = status
try:
yield
finally:
config.quantization_status = original_status
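
For reference, a minimal usage sketch of the context manager above (not part of the diff), written as if it lived in the same module; it assumes `QuantizationConfig` accepts an empty `config_groups` dict. The original status is restored even if the body raises, because the reset sits in a `finally` block:

```python
from compressed_tensors.quantization import QuantizationConfig, QuantizationStatus

# Hypothetical config used only for illustration
config = QuantizationConfig(config_groups={})
config.quantization_status = QuantizationStatus.COMPRESSED

with override_quantization_status(config, QuantizationStatus.FROZEN):
    # apply_quantization_config(model, config) would run here; a FROZEN status
    # keeps it from re-initializing (and re-casting) the quantized weights
    assert config.quantization_status == QuantizationStatus.FROZEN

assert config.quantization_status == QuantizationStatus.COMPRESSED  # restored on exit
```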
40 changes: 35 additions & 5 deletions src/compressed_tensors/compressors/quantized_compressors/base.py
@@ -13,12 +13,17 @@
# limitations under the License.

import logging
from typing import Dict, Generator, Tuple
from pathlib import Path
from typing import Any, Dict, Generator, Tuple, Union

import torch
from compressed_tensors.compressors.base import BaseCompressor
from compressed_tensors.quantization import QuantizationArgs
from compressed_tensors.utils import get_nested_weight_mappings, merge_names
from compressed_tensors.utils import (
get_nested_mappings_from_state_dict,
get_nested_weight_mappings,
merge_names,
)
from safetensors import safe_open
from torch import Tensor
from tqdm import tqdm
@@ -113,30 +118,55 @@ def compress(

def decompress(
self,
path_to_model_or_tensors: str,
path_to_model_or_tensors: Union[str, Path, Dict[str, Any]],
names_to_scheme: Dict[str, QuantizationArgs],
device: str = "cpu",
) -> Generator[Tuple[str, Tensor], None, None]:
"""
Reads a compressed state dict located at path_to_model_or_tensors
and returns a generator for sequentially decompressing back to a
dense state dict

:param path_to_model_or_tensors: path to compressed safetensors model (directory
with one or more safetensors files), a compressed tensors file, or an
already-loaded compressed state dict
:param names_to_scheme: quantization args for each quantized weight
:param device: optional device to load intermediate weights into
:return: generator of decompressed weights
"""
if isinstance(path_to_model_or_tensors, (str, Path)):
yield from self._decompress_from_path(
path_to_model_or_tensors, names_to_scheme, device
)

else:
yield from self._decompress_from_state_dict(
path_to_model_or_tensors, names_to_scheme
)

def _decompress_from_path(self, path_to_model, names_to_scheme, device):
weight_mappings = get_nested_weight_mappings(
path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
path_to_model, self.COMPRESSION_PARAM_NAMES
)
for weight_name in weight_mappings.keys():
weight_data = {}
for param_name, safe_path in weight_mappings[weight_name].items():
full_name = merge_names(weight_name, param_name)
with safe_open(safe_path, framework="pt", device=device) as f:
weight_data[param_name] = f.get_tensor(full_name)
if "weight_scale" in weight_data:
quant_args = names_to_scheme[weight_name]
decompressed = self.decompress_weight(
compressed_data=weight_data, quantization_args=quant_args
)
yield merge_names(weight_name, "weight"), decompressed

def _decompress_from_state_dict(self, state_dict, names_to_scheme):
weight_mappings = get_nested_mappings_from_state_dict(
state_dict, self.COMPRESSION_PARAM_NAMES
)
for weight_name in weight_mappings.keys():
weight_data = {}
for param_name, param_value in weight_mappings[weight_name].items():
weight_data[param_name] = param_value

if "weight_scale" in weight_data:
quant_args = names_to_scheme[weight_name]
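
With this split, `decompress` accepts either a safetensors directory on disk or an already-loaded state dict (e.g. the model's state dict after sparse decompression). A toy, self-contained version of just the branch logic, so the dispatch is easy to see; names here are illustrative, not the library's API:

```python
from pathlib import Path
from typing import Any, Dict, Union

def toy_decompress(source: Union[str, Path, Dict[str, Any]]) -> str:
    """Mirror of the dispatch above: path-like inputs are read from disk,
    mappings are treated as an in-memory compressed state dict."""
    if isinstance(source, (str, Path)):
        return f"decompress safetensors files under {source}"
    return f"decompress in-memory state dict with {len(source)} entries"

print(toy_decompress("path/to/compressed_model"))     # disk branch
print(toy_decompress({"layer.weight_packed": None}))  # state-dict branch
```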
@@ -68,9 +68,9 @@ def compress_weight(
self,
weight: Tensor,
scale: Tensor,
quantization_args: QuantizationArgs,
zero_point: Optional[Tensor] = None,
g_idx: Optional[torch.Tensor] = None,
quantization_args: Optional[QuantizationArgs] = None,
device: Optional[torch.device] = None,
) -> Dict[str, torch.Tensor]:
"""
@@ -79,8 +79,8 @@ def compress_weight(
:param weight: uncompressed weight tensor
:param scale: quantization scale for weight
:param zero_point: quantization zero point for weight
:param g_idx: optional mapping from column index to group index
:param quantization_args: quantization parameters for weight
:param g_idx: optional mapping from column index to group index
:param device: optional device to move compressed output to
:return: dictionary of compressed weight data
"""
@@ -68,9 +68,9 @@ def compress_weight(
self,
weight: Tensor,
scale: Tensor,
quantization_args: QuantizationArgs,
zero_point: Optional[Tensor] = None,
g_idx: Optional[torch.Tensor] = None,
quantization_args: Optional[QuantizationArgs] = None,
device: Optional[torch.device] = None,
) -> Dict[str, torch.Tensor]:
"""
53 changes: 46 additions & 7 deletions src/compressed_tensors/compressors/sparse_compressors/base.py
@@ -13,7 +13,7 @@
# limitations under the License.

import logging
from typing import Dict, Generator, Tuple
from typing import Dict, Generator, Optional, Set, Tuple

from compressed_tensors.compressors.base import BaseCompressor
from compressed_tensors.utils import get_nested_weight_mappings, merge_names
@@ -30,7 +30,8 @@
class BaseSparseCompressor(BaseCompressor):
"""
Base class representing a sparse compression algorithm. Each child class should
implement compression_param_info, compress_weight and decompress_weight.
implement compression_param_info, compress_weight and decompress_weight; child
classes should also define COMPRESSION_PARAM_NAMES.

Compressors support compressing/decompressing a full module state dict or a single
quantized PyTorch leaf module.
@@ -59,19 +60,33 @@ class BaseSparseCompressor(BaseCompressor):
:param config: config specifying compression parameters
"""

def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
def compress(
self,
model_state: Dict[str, Tensor],
compression_targets: Optional[Set[str]] = None,
) -> Dict[str, Tensor]:
"""
Compresses a dense state dict using bitmask compression

:param model_state: state dict of uncompressed model
:param compression_targets: optional set of layer prefixes to compress; if None,
all layers are compressed (for backwards compatibility)
:return: compressed state dict
"""
compressed_dict = {}
_LOGGER.debug(
f"Compressing model with {len(model_state)} parameterized layers..."
)
for name, value in tqdm(model_state.items(), desc="Compressing model"):
compression_data = self.compress_weight(name, value)
ignored = not self.should_compress(name, compression_targets)
if ignored:
compressed_dict[name] = value
continue
prefix = name
if prefix.endswith(".weight"):
prefix = prefix[: -(len(".weight"))]

compression_data = self.compress_weight(prefix, value)
for key in compression_data.keys():
if key in compressed_dict:
_LOGGER.warn(
@@ -97,8 +112,10 @@ def decompress(
:param device: device to load decompressed weights onto
:return: iterator for generating decompressed weights
"""
weight_mappings = get_nested_weight_mappings(
path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
weight_mappings, ignored_params = get_nested_weight_mappings(
path_to_model_or_tensors,
self.COMPRESSION_PARAM_NAMES,
return_unmatched_params=True,
)
for weight_name in weight_mappings.keys():
weight_data = {}
@@ -107,4 +124,26 @@
with safe_open(safe_path, framework="pt", device=device) as f:
weight_data[param_name] = f.get_tensor(full_name)
decompressed = self.decompress_weight(weight_data)
yield weight_name, decompressed
yield merge_names(weight_name, "weight"), decompressed

for ignored_param_name, safe_path in ignored_params.items():
with safe_open(safe_path, framework="pt", device=device) as f:
value = f.get_tensor(ignored_param_name)
yield ignored_param_name, value

@staticmethod
def should_compress(name: str, expanded_targets: Optional[Set[str]] = None) -> bool:
"""
Check if a parameter should be compressed.
Currently, this only returns True for weight parameters.

:param name: name of the parameter
:param expanded_targets: set of layer prefixes to compress
:return: whether or not the parameter should be compressed
"""
if expanded_targets is None:
return name.endswith(".weight")

return (
name.endswith(".weight") and name[: -(len(".weight"))] in expanded_targets
)
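
To make the intended semantics concrete, a small sketch of `should_compress` with made-up module names (the import path is assumed to mirror the file shown above):

```python
from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor

targets = {"model.layers.0.mlp.down_proj"}  # hypothetical expanded target set

# targeted weight -> compressed
assert BaseSparseCompressor.should_compress("model.layers.0.mlp.down_proj.weight", targets)
# weight outside the targets -> passed through uncompressed
assert not BaseSparseCompressor.should_compress("model.layers.0.mlp.up_proj.weight", targets)
# non-weight parameter of a targeted layer -> never compressed
assert not BaseSparseCompressor.should_compress("model.layers.0.mlp.down_proj.bias", targets)
# no targets supplied -> every weight is compressed (backwards-compatible default)
assert BaseSparseCompressor.should_compress("model.layers.0.mlp.down_proj.weight")
```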
@@ -19,6 +19,7 @@
from compressed_tensors.compressors.base import BaseCompressor
from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor
from compressed_tensors.config import CompressionFormat
from compressed_tensors.quantization import FP8_DTYPE
from compressed_tensors.utils import merge_names
from torch import Tensor

@@ -134,9 +135,14 @@ def bitmask_compress(tensor: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
bytemasks = tensor != 0
row_counts = bytemasks.sum(dim=-1)
row_offsets = torch.cumsum(row_counts, 0) - row_counts
values = tensor[bytemasks]
if tensor.dtype == FP8_DTYPE:
# access the raw bytes of the tensor
tensor_view = tensor.view(torch.int8)
values = tensor_view[bytemasks]
values = values.view(FP8_DTYPE)
else:
values = tensor[bytemasks]
bitmasks_packed = pack_bitmasks(bytemasks)

return values, bitmasks_packed, row_offsets
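
A standalone sketch of the byte-view trick above, assuming `FP8_DTYPE` is `torch.float8_e4m3fn` and a PyTorch build where float8 comparison is implemented; boolean-mask selection is routed through an `int8` view because fancy indexing is not generally available for float8 tensors:

```python
import torch

FP8_DTYPE = torch.float8_e4m3fn  # assumed to match compressed_tensors' FP8_DTYPE

tensor = torch.tensor([[0.0, 1.5, 0.0], [2.0, 0.0, -3.0]]).to(FP8_DTYPE)
bytemasks = tensor != 0                      # sparsity pattern over the dense tensor

# select the raw bytes of the nonzero elements, then reinterpret them as FP8
values = tensor.view(torch.int8)[bytemasks].view(FP8_DTYPE)

print(values.dtype, values.shape)            # torch.float8_e4m3fn, torch.Size([3])
```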

