diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index 297fc3a6..8a583cfe 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -17,8 +17,9 @@ import operator import os import re +from contextlib import contextmanager from copy import deepcopy -from typing import TYPE_CHECKING, Any, Dict, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Set, TypeVar, Union import compressed_tensors import torch @@ -38,6 +39,7 @@ apply_quantization_config, load_pretrained_quantization, ) +from compressed_tensors.quantization.lifecycle import expand_sparse_target_names from compressed_tensors.quantization.quant_args import QuantizationArgs from compressed_tensors.quantization.utils import ( is_module_quantized, @@ -104,7 +106,6 @@ def from_pretrained( """ config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) compression_config = getattr(config, QUANTIZATION_CONFIG_NAME, None) - return cls.from_compression_config(compression_config) @classmethod @@ -282,8 +283,14 @@ def compress( ) if self.sparsity_compressor is not None: + sparse_compression_targets: Set[str] = expand_sparse_target_names( + model=model, + targets=self.sparsity_config.targets, + ignore=self.sparsity_config.ignore, + ) compressed_state_dict = self.sparsity_compressor.compress( - compressed_state_dict + compressed_state_dict, + compression_targets=sparse_compression_targets, ) # HACK: Override the dtype_byte_size function in transformers to @@ -301,23 +308,41 @@ def decompress(self, model_path: str, model: Module): :param model: pytorch model to load decompressed weights into """ model_path = get_safetensors_folder(model_path) + sparse_decompressed = False + if self.sparsity_compressor is not None: + # Sparse decompression is applied on the model_path dense_gen = self.sparsity_compressor.decompress(model_path) self._replace_weights(dense_gen, model) setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config) + sparse_decompressed = True if self.quantization_compressor is not None: - names_to_scheme = apply_quantization_config(model, self.quantization_config) - load_pretrained_quantization(model, model_path) + # Temporarily set quantization status to FROZEN to prevent + # quantization during apply_quantization_config. This ensures + # that the dtypes of the weights are not unintentionally updated. + # The status is restored after quantization params are loaded. 
+ with override_quantization_status( + self.quantization_config, QuantizationStatus.FROZEN + ): + names_to_scheme = apply_quantization_config( + model, self.quantization_config + ) + load_pretrained_quantization(model, model_path) + + model_path_or_state_dict = ( + model.state_dict() if sparse_decompressed else model_path + ) + dense_gen = self.quantization_compressor.decompress( - model_path, names_to_scheme=names_to_scheme + model_path_or_state_dict, names_to_scheme=names_to_scheme ) self._replace_weights(dense_gen, model) - def update_status(module): + def freeze_quantization_status(module): module.quantization_status = QuantizationStatus.FROZEN - model.apply(update_status) + model.apply(freeze_quantization_status) setattr(model, QUANTIZATION_CONFIG_NAME, self.quantization_config) def update_config(self, save_directory: str): @@ -402,3 +427,23 @@ def new_dtype_byte_size(dtype): raise ValueError(f"`dtype` is not a valid dtype: {dtype}.") bit_size = int(bit_search.groups()[0]) return bit_size // 8 + + +@contextmanager +def override_quantization_status( + config: QuantizationConfig, status: QuantizationStatus +): + """ + Within this context, the quantization status will be set to the + supplied status. After the context exits, the original status + will be restored. + + :param config: the quantization config to override + :param status: the status to temporarily set + """ + original_status = config.quantization_status + config.quantization_status = status + try: + yield + finally: + config.quantization_status = original_status diff --git a/src/compressed_tensors/compressors/quantized_compressors/base.py b/src/compressed_tensors/compressors/quantized_compressors/base.py index 8ceb8afd..b49361d4 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/base.py +++ b/src/compressed_tensors/compressors/quantized_compressors/base.py @@ -13,12 +13,17 @@ # limitations under the License. 
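# --- Usage sketch (illustration only, not part of the patch) ---
# Shows how the override_quantization_status context manager introduced above in
# model_compressor.py is intended to be used: it temporarily swaps the status on a
# QuantizationConfig and restores the original on exit. The config dict mirrors the one
# built in tests/test_compressors/model_compressors/test_model_compressor.py; the import
# paths are assumed to match the rest of this repository.

from compressed_tensors.compressors.model_compressors.model_compressor import (
    override_quantization_status,
)
from compressed_tensors.quantization import QuantizationConfig, QuantizationStatus

config = QuantizationConfig.model_validate(
    {
        "format": "int-quantized",
        "quant_method": "compressed-tensors",
        "config_groups": {
            "group_0": {
                "targets": ["Linear"],
                "weights": {
                    "num_bits": 8,
                    "type": "int",
                    "symmetric": True,
                    "strategy": "channel",
                },
            }
        },
    }
)

with override_quantization_status(config, QuantizationStatus.FROZEN):
    # while the context is active, apply_quantization_config() sees a FROZEN config,
    # so weights are not re-quantized and their dtypes are left untouched
    assert config.quantization_status == QuantizationStatus.FROZEN
# the original status is restored on exit, even if an exception was raised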
import logging -from typing import Dict, Generator, Tuple +from pathlib import Path +from typing import Any, Dict, Generator, Tuple, Union import torch from compressed_tensors.compressors.base import BaseCompressor from compressed_tensors.quantization import QuantizationArgs -from compressed_tensors.utils import get_nested_weight_mappings, merge_names +from compressed_tensors.utils import ( + get_nested_mappings_from_state_dict, + get_nested_weight_mappings, + merge_names, +) from safetensors import safe_open from torch import Tensor from tqdm import tqdm @@ -113,7 +118,7 @@ def compress( def decompress( self, - path_to_model_or_tensors: str, + path_to_model_or_tensors: Union[str, Path, Dict[str, Any]], names_to_scheme: Dict[str, QuantizationArgs], device: str = "cpu", ) -> Generator[Tuple[str, Tensor], None, None]: @@ -121,15 +126,25 @@ def decompress( Reads a compressed state dict located at path_to_model_or_tensors and returns a generator for sequentially decompressing back to a dense state dict - :param path_to_model_or_tensors: path to compressed safetensors model (directory with one or more safetensors files) or compressed tensors file :param names_to_scheme: quantization args for each quantized weight :param device: optional device to load intermediate weights into :return: compressed state dict """ + if isinstance(path_to_model_or_tensors, (str, Path)): + yield from self._decompress_from_path( + path_to_model_or_tensors, names_to_scheme, device + ) + + else: + yield from self._decompress_from_state_dict( + path_to_model_or_tensors, names_to_scheme + ) + + def _decompress_from_path(self, path_to_model, names_to_scheme, device): weight_mappings = get_nested_weight_mappings( - path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES + path_to_model, self.COMPRESSION_PARAM_NAMES ) for weight_name in weight_mappings.keys(): weight_data = {} @@ -137,6 +152,21 @@ def decompress( full_name = merge_names(weight_name, param_name) with safe_open(safe_path, framework="pt", device=device) as f: weight_data[param_name] = f.get_tensor(full_name) + if "weight_scale" in weight_data: + quant_args = names_to_scheme[weight_name] + decompressed = self.decompress_weight( + compressed_data=weight_data, quantization_args=quant_args + ) + yield merge_names(weight_name, "weight"), decompressed + + def _decompress_from_state_dict(self, state_dict, names_to_scheme): + weight_mappings = get_nested_mappings_from_state_dict( + state_dict, self.COMPRESSION_PARAM_NAMES + ) + for weight_name in weight_mappings.keys(): + weight_data = {} + for param_name, param_value in weight_mappings[weight_name].items(): + weight_data[param_name] = param_value if "weight_scale" in weight_data: quant_args = names_to_scheme[weight_name] diff --git a/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py index 85eebe00..eea50848 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py @@ -68,9 +68,9 @@ def compress_weight( self, weight: Tensor, scale: Tensor, + quantization_args: QuantizationArgs, zero_point: Optional[Tensor] = None, g_idx: Optional[torch.Tensor] = None, - quantization_args: Optional[QuantizationArgs] = None, device: Optional[torch.device] = None, ) -> Dict[str, torch.Tensor]: """ @@ -79,8 +79,8 @@ def compress_weight( :param weight: uncompressed weight tensor :param scale: quantization scale for weight 
:param zero_point: quantization zero point for weight - :param g_idx: optional mapping from column index to group index :param quantization_args: quantization parameters for weight + :param g_idx: optional mapping from column index to group index :param device: optional device to move compressed output to :return: dictionary of compressed weight data """ diff --git a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py index c236f8c9..8f694c8b 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py @@ -68,9 +68,9 @@ def compress_weight( self, weight: Tensor, scale: Tensor, + quantization_args: QuantizationArgs, zero_point: Optional[Tensor] = None, g_idx: Optional[torch.Tensor] = None, - quantization_args: Optional[QuantizationArgs] = None, device: Optional[torch.device] = None, ) -> Dict[str, torch.Tensor]: """ diff --git a/src/compressed_tensors/compressors/sparse_compressors/base.py b/src/compressed_tensors/compressors/sparse_compressors/base.py index 1b1a6825..72232801 100644 --- a/src/compressed_tensors/compressors/sparse_compressors/base.py +++ b/src/compressed_tensors/compressors/sparse_compressors/base.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import Dict, Generator, Tuple +from typing import Dict, Generator, Optional, Set, Tuple from compressed_tensors.compressors.base import BaseCompressor from compressed_tensors.utils import get_nested_weight_mappings, merge_names @@ -30,7 +30,8 @@ class BaseSparseCompressor(BaseCompressor): """ Base class representing a sparse compression algorithm. Each child class should - implement compression_param_info, compress_weight and decompress_weight. + implement compression_param_info, compress_weight and decompress_weight; child + classes should also define COMPRESSION_PARAM_NAMES. Compressors support compressing/decompressing a full module state dict or a single quantized PyTorch leaf module. @@ -59,11 +60,17 @@ class BaseSparseCompressor(BaseCompressor): :param config: config specifying compression parameters """ - def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]: + def compress( + self, + model_state: Dict[str, Tensor], + compression_targets: Optional[Set[str]] = None, + ) -> Dict[str, Tensor]: """ Compresses a dense state dict using bitmask compression :param model_state: state dict of uncompressed model + :param compression_targets: optional set of layer prefixes to compress, if None + compress all layers (for backwards compatibility) :return: compressed state dict """ compressed_dict = {} @@ -71,7 +78,15 @@ def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]: f"Compressing model with {len(model_state)} parameterized layers..." 
) for name, value in tqdm(model_state.items(), desc="Compressing model"): - compression_data = self.compress_weight(name, value) + ignored = not self.should_compress(name, compression_targets) + if ignored: + compressed_dict[name] = value + continue + prefix = name + if prefix.endswith(".weight"): + prefix = prefix[: -(len(".weight"))] + + compression_data = self.compress_weight(prefix, value) for key in compression_data.keys(): if key in compressed_dict: _LOGGER.warn( @@ -97,8 +112,10 @@ def decompress( :param device: device to load decompressed weights onto :return: iterator for generating decompressed weights """ - weight_mappings = get_nested_weight_mappings( - path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES + weight_mappings, ignored_params = get_nested_weight_mappings( + path_to_model_or_tensors, + self.COMPRESSION_PARAM_NAMES, + return_unmatched_params=True, ) for weight_name in weight_mappings.keys(): weight_data = {} @@ -107,4 +124,26 @@ def decompress( with safe_open(safe_path, framework="pt", device=device) as f: weight_data[param_name] = f.get_tensor(full_name) decompressed = self.decompress_weight(weight_data) - yield weight_name, decompressed + yield merge_names(weight_name, "weight"), decompressed + + for ignored_param_name, safe_path in ignored_params.items(): + with safe_open(safe_path, framework="pt", device=device) as f: + value = f.get_tensor(ignored_param_name) + yield ignored_param_name, value + + @staticmethod + def should_compress(name: str, expanded_targets: Optional[Set[str]] = None) -> bool: + """ + Check if a parameter should be compressed. + Currently, this only returns True for weight parameters. + + :param name: name of the parameter + :param expanded_targets: set of layer prefixes to compress + :return: whether or not the parameter should be compressed + """ + if expanded_targets is None: + return name.endswith(".weight") + + return ( + name.endswith(".weight") and name[: -(len(".weight"))] in expanded_targets + ) diff --git a/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py b/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py index a950aa64..0434499d 100644 --- a/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +++ b/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py @@ -19,6 +19,7 @@ from compressed_tensors.compressors.base import BaseCompressor from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor from compressed_tensors.config import CompressionFormat +from compressed_tensors.quantization import FP8_DTYPE from compressed_tensors.utils import merge_names from torch import Tensor @@ -134,9 +135,14 @@ def bitmask_compress(tensor: Tensor) -> Tuple[Tensor, Tensor, Tensor]: bytemasks = tensor != 0 row_counts = bytemasks.sum(dim=-1) row_offsets = torch.cumsum(row_counts, 0) - row_counts - values = tensor[bytemasks] + if tensor.dtype == FP8_DTYPE: + # acces raw bytes of the tensor + tensor_view = tensor.view(torch.int8) + values = tensor_view[bytemasks] + values = values.view(FP8_DTYPE) + else: + values = tensor[bytemasks] bitmasks_packed = pack_bitmasks(bytemasks) - return values, bitmasks_packed, row_offsets diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index ed9a50f7..43279d71 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -18,7 +18,7 @@ from copy import deepcopy from 
typing import Dict, Iterable, List, Optional from typing import OrderedDict as OrderedDictType -from typing import Union +from typing import Set, Union import torch from compressed_tensors.config import CompressionFormat @@ -52,6 +52,8 @@ "apply_quantization_config", "apply_quantization_status", "find_name_or_class_matches", + "expand_sparse_target_names", + "is_target", ] from compressed_tensors.quantization.utils.helpers import is_module_quantized @@ -245,6 +247,49 @@ def apply_quantization_status(model: Module, status: QuantizationStatus): model.apply(compress_quantized_weights) +def expand_sparse_target_names( + model: Module, targets: Iterable[str], ignore: Iterable[str] +) -> Set[str]: + """ + Finds all unique module names in the model that match the given + targets and ignore lists. + + Note: Targets must be regexes, layer types, or full layer names. + + :param model: model to search for targets in + :param targets: list of targets to search for + :param ignore: list of targets to ignore + :return: set of all targets that match the given targets and should + not be ignored + """ + return { + name + for name, module in iter_named_leaf_modules(model) + if is_target(name, module, targets, ignore) + } + + +def is_target( + name: str, module: Module, targets: Iterable[str], ignore: Iterable[str] +) -> bool: + """ + Determines if a module should be included in the targets based on the + targets and ignore lists. + + Note: Targets must be regexes, layer types, or full layer names. + + :param name: name of the module + :param module: the module itself + :param targets: list of targets to search for + :param ignore: list of targets to ignore + :return: True if the module is a target and not ignored, False otherwise + """ + return bool( + find_name_or_class_matches(name, module, targets) + and not find_name_or_class_matches(name, module, ignore or []) + ) + + def find_name_or_class_matches( name: str, module: Module, targets: Iterable[str], check_contains: bool = False ) -> List[str]: diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index dcab122a..5426a226 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -83,7 +83,7 @@ def dequantize( x_q: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor = None, - args: QuantizationArgs = None, + args: Optional[QuantizationArgs] = None, dtype: Optional[torch.dtype] = None, g_idx: Optional[torch.Tensor] = None, ) -> torch.Tensor: diff --git a/src/compressed_tensors/utils/safetensors_load.py b/src/compressed_tensors/utils/safetensors_load.py index 4fdb3007..f7569b98 100644 --- a/src/compressed_tensors/utils/safetensors_load.py +++ b/src/compressed_tensors/utils/safetensors_load.py @@ -16,7 +16,7 @@ import os import re import struct -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple, Union from safetensors import safe_open from torch import Tensor @@ -30,10 +30,14 @@ "merge_names", "get_weight_mappings", "get_nested_weight_mappings", + "get_nested_mappings_from_state_dict", "get_quantization_state_dict", "is_quantization_param", ] +WeightMappingType = Dict[str, str] +NestedWeightMappingType = Dict[str, WeightMappingType] + def get_safetensors_folder( pretrained_model_name_or_path: str, cache_dir: Optional[str] = None @@ -92,7 +96,7 @@ def get_safetensors_header(safetensors_path: str) -> Dict[str, str]: return header -def match_param_name(full_name: 
str, param_name: str) -> str: +def match_param_name(full_name: str, param_name: str) -> Optional[str]: """ Helper function extracting the uncompressed parameterized layer name from a compressed name. Assumes the compressed name was merged using merge_names. @@ -176,38 +180,98 @@ def get_weight_mappings(path_to_model_or_tensors: str) -> Dict[str, str]: def get_nested_weight_mappings( - model_path: str, params_to_nest: List[str] -) -> Dict[str, Dict[str, str]]: + model_path: str, params_to_nest: List[str], return_unmatched_params: bool = False +) -> Union[NestedWeightMappingType, Tuple[NestedWeightMappingType, WeightMappingType]]: """ Takes a path to a state dict saved in safetensors format and returns a nested - mapping from uncompressed parameterized layer names to the file locations of each - of the layers compression parameters. + mapping from uncompressed parameterized layer names to the file locations of + each layer's compression parameters. - layer.weight: { + Example of the nested mapping: + layer: { bitmask: file_location, row_offsets: file_location, shape: file_location, compressed: file_location } - This generalizes to cases where the model is split into multiple safetensors files + If other parameters are found that do not match the nested parameters, they will + be returned in a separate dictionary only if return_unmatched_params is True. + This dictionary may be needed for cases where compressors are stacked (e.g., + quantization compression followed by sparse compression). + + Example of the unmatched params mapping: + { + layer.weight_scale: file_location, + layer.input_scale: file_location + } - :param model_path: path to safetensors state dict, must contain either a single - safetensors file or multiple files with an index - :return: nested mapping of parameterized layer name to file location + This generalizes to cases where the model is split into multiple safetensors + files. + + :param model_path: Path to the safetensors state dict, must contain either a + single safetensors file or multiple files with an index. + :param params_to_nest: List of parameter names to nest. + :param return_unmatched_params: If True, return a second dictionary containing + the remaining parameters that were not matched to the params_to_nest. + :return: + - If return_unmatched_params is False: + NestedWeightMappingType: A nested mapping of parameterized layer names to + file locations of each layer's compression parameters. + - If return_unmatched_params is True: + Tuple[NestedWeightMappingType, WeightMappingType]: A tuple containing: + - NestedWeightMappingType: A nested mapping of parameterized layer + names to file locations of each layer's compression parameters. + - WeightMappingType: A mapping of the remaining parameter names to + their file locations that were not matched to the params_to_nest. 
""" weight_mappings = get_weight_mappings(model_path) - nested_weight_mappings = {} - for key in weight_mappings.keys(): + unmatched_params = {} + + for key, file_location in weight_mappings.items(): + matched = False for param_name in params_to_nest: - maybe_match = match_param_name(key, param_name) - if maybe_match is not None: - dense_param = maybe_match + dense_param = match_param_name(key, param_name) + if dense_param: if dense_param not in nested_weight_mappings: nested_weight_mappings[dense_param] = {} - nested_weight_mappings[dense_param][param_name] = weight_mappings[key] + nested_weight_mappings[dense_param][param_name] = file_location + matched = True + if return_unmatched_params and not matched: + unmatched_params[key] = file_location + + if return_unmatched_params: + return nested_weight_mappings, unmatched_params + return nested_weight_mappings + +def get_nested_mappings_from_state_dict(state_dict, params_to_nest): + """ + Takes a state dict and returns a nested mapping from uncompressed + parameterized layer names to the value of + each layer's compression parameters. + + Example of the nested mapping: + layer: { + weight_scale: ..., + weight: ..., + zero_point: ..., + } + + :param state_dict: state dict of the model + :param params_to_nest: List of parameter names to nest. + :return: Nested mapping of parameterized layer names to the value of + each layer's compression parameters. + """ + nested_weight_mappings = {} + for key in state_dict.keys(): + for param_name in params_to_nest: + dense_param = match_param_name(key, param_name) + if dense_param: + if dense_param not in nested_weight_mappings: + nested_weight_mappings[dense_param] = {} + nested_weight_mappings[dense_param][param_name] = state_dict[key] return nested_weight_mappings diff --git a/tests/test_compressors/model_compressors/test_model_compressor.py b/tests/test_compressors/model_compressors/test_model_compressor.py index 4a8327ce..bbde3011 100644 --- a/tests/test_compressors/model_compressors/test_model_compressor.py +++ b/tests/test_compressors/model_compressors/test_model_compressor.py @@ -12,13 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import json from copy import deepcopy +from pathlib import Path import pytest +import torch from compressed_tensors.compressors import ModelCompressor from compressed_tensors.config.base import SparsityCompressionConfig from compressed_tensors.quantization.quant_config import QuantizationConfig -from tests.testing_utils import requires_hf_quantizer +from safetensors.torch import save_file +from tests.testing_utils import induce_sparsity, requires_hf_quantizer def sparsity_config(): @@ -109,3 +113,144 @@ def test_hf_compressor_tensors_config(s_config, q_config, tmp_path): assert ( ModelCompressor.parse_quantization_config(compression_config) == q_config_dict ) + + +@pytest.fixture +def fake_model_class(): + import torch.nn as nn + + class CustomLinearModel(nn.Module): + def __init__(self, weights, weight_scale=None, weight_zero_point=None): + super(CustomLinearModel, self).__init__() + out_features, in_features = weights.shape + + # Define a linear layer without bias + self.linear = nn.Linear(in_features, out_features, bias=False) + + # Set the weights of the linear layer + self.linear.weight = nn.Parameter(weights, requires_grad=False) + + # Attach weight_scale and weight_zero_point as parameters + if weight_scale is not None: + self.linear.weight_scale = nn.Parameter( + torch.tensor(weight_scale), requires_grad=False + ) + if weight_zero_point is not None: + self.linear.weight_zero_point = nn.Parameter( + torch.tensor(weight_zero_point), requires_grad=False + ) + + def forward(self, x): + return self.linear(x) + + return CustomLinearModel + + +def get_bitmask_sparsity_config(): + from compressed_tensors import BitmaskConfig + + return BitmaskConfig( + format="sparse-bitmask", + global_sparsity=0.7, + targets=["Linear"], + sparsity_structure="unstructured", + ) + + +def create_quantization_config(bits=8, type="int", strategy="tensor"): + + config_dict = { + "format": "int-quantized", + "global_compression_ratio": 1.0, + "quant_method": "compressed-tensors", + "config_groups": { + "group_0": { + "targets": ["Linear"], + "weights": { + "num_bits": bits, + "strategy": strategy, + "symmetric": True, + "type": type, + }, + } + }, + } + + return QuantizationConfig.model_validate(config_dict) + + +@pytest.mark.parametrize("sparsity_config", [get_bitmask_sparsity_config()]) +@pytest.mark.parametrize( + "quantization_config", + [ + create_quantization_config(bits=8, type="int", strategy="channel"), + create_quantization_config(bits=8, type="float", strategy="channel"), + ], +) +def test_composability( + tmp_path, fake_model_class, sparsity_config, quantization_config +): + from compressed_tensors.quantization.lifecycle.forward import quantize + + weights = torch.rand(10, 5) + sparse_weights = induce_sparsity(weights, sparsity_config.global_sparsity) + + quantization_args = quantization_config.config_groups["group_0"].weights + + if quantization_args.strategy == "channel": + scale = torch.tensor([1.0] * weights.shape[1]) + elif quantization_args.strategy == "tensor": + scale = torch.tensor([1.0]) + + zero_point = torch.zeros_like(scale) + + quantized_weights = quantize( + sparse_weights, + scale=scale, + zero_point=zero_point, + args=quantization_args, + ) + + fake_oneshot_model = fake_model_class(quantized_weights, scale, zero_point) + fake_oneshot_model.linear.quantization_scheme = quantization_config.config_groups[ + "group_0" + ] + model_compressor = ModelCompressor( + sparsity_config=sparsity_config, quantization_config=quantization_config + ) + # does both sparse and quantization compression 
+ compressed_state_dict = model_compressor.compress(fake_oneshot_model) + + save_dir = tmp_path / "model" + save_dir = _create_dummy_checkpoint( + compressed_state_dict, save_dir, model_compressor + ) + + decompressed_model = fake_model_class(torch.zeros_like(weights)) + model_compressor.decompress(model=decompressed_model, model_path=save_dir) + + # check that the decompressed model is the same as the original model + _check_state_dicts(fake_oneshot_model.state_dict(), decompressed_model.state_dict()) + + +def _create_dummy_checkpoint(state_dict, save_dir, model_compressor): + save_dir = Path(save_dir) + save_dir.mkdir(parents=True, exist_ok=True) + save_file(state_dict, save_dir / "model.safetensors") + + config_file_path = save_dir / "config.json" + with open(config_file_path, "w") as config_file: + json.dump({}, config_file, indent=2, sort_keys=True) + + model_compressor.update_config(save_dir) + return save_dir + + +def _check_state_dicts(state_dict1, state_dict2): + for key in state_dict1.keys(): + assert key in state_dict2, f"Missing tensor: {key}" + if key.endswith("weight"): + original_tensor = state_dict1[key] + decompressed_tensor = state_dict2[key].to(original_tensor.dtype) + diff = torch.abs(original_tensor - decompressed_tensor) + assert not torch.any(diff > 0.01), f"Max diff: {torch.max(diff)}" diff --git a/tests/test_compressors/sparse_compressors/test_bitmask.py b/tests/test_compressors/sparse_compressors/test_bitmask.py index 248580bc..491ee2e8 100644 --- a/tests/test_compressors/sparse_compressors/test_bitmask.py +++ b/tests/test_compressors/sparse_compressors/test_bitmask.py @@ -44,16 +44,16 @@ def test_bitmask_sizes(shape, sparsity, dtype): assert len(dense_state_dict) * 4 == len(sparse_state_dict) # bitmask should be 1 bit per dense element, rounded up to nearest int8 - sparse_shape = sparse_state_dict["dummy.weight.shape"] + sparse_shape = sparse_state_dict["dummy.shape"] assert torch.all(torch.eq(sparse_shape, torch.tensor(shape))) - bitmask_shape = sparse_state_dict["dummy.weight.bitmask"].shape + bitmask_shape = sparse_state_dict["dummy.bitmask"].shape assert bitmask_shape[0] == sparse_shape[0] assert bitmask_shape[1] == int(math.ceil(sparse_shape[1] / 8.0)) # one value for each non-zero weight - values_shape = sparse_state_dict["dummy.weight.compressed"].shape + values_shape = sparse_state_dict["dummy.compressed"].shape assert values_shape[0] == torch.sum(test_tensor != 0) - row_offsets_shape = sparse_state_dict["dummy.weight.row_offsets"].shape + row_offsets_shape = sparse_state_dict["dummy.row_offsets"].shape assert row_offsets_shape[0] == test_tensor.shape[0] diff --git a/tests/test_quantization/lifecycle/test_apply.py b/tests/test_quantization/lifecycle/test_apply.py index 5ad56b8e..8f3c93fc 100644 --- a/tests/test_quantization/lifecycle/test_apply.py +++ b/tests/test_quantization/lifecycle/test_apply.py @@ -14,6 +14,7 @@ import re from typing import Optional +from unittest.mock import MagicMock import pytest import torch @@ -26,12 +27,38 @@ from compressed_tensors.quantization.lifecycle import ( apply_quantization_config, apply_quantization_status, + expand_sparse_target_names, + is_target, ) from compressed_tensors.quantization.utils import iter_named_leaf_modules from tests.testing_utils import requires_accelerate from transformers import AutoModelForCausalLM +@pytest.fixture +def mock_model(): + model = MagicMock() + model.named_modules.return_value = [ + ("layer1", MagicMock()), + ("layer2", MagicMock()), + ("layer3", MagicMock()), + ] + return model + + 
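# --- Usage sketch (illustration only, not part of the patch) ---
# expand_sparse_target_names and is_target, added above in quantization/lifecycle/apply.py,
# resolve target/ignore entries (regexes, layer types, or full layer names) into a concrete
# set of module names. The checkpoint id is the same tiny llama model used by the
# llama_stories_model fixture in this file.

from compressed_tensors.quantization.lifecycle import (
    expand_sparse_target_names,
    is_target,
)
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "Xenova/llama2.c-stories15M", torch_dtype="auto"
)

resolved = expand_sparse_target_names(
    model=model,
    targets=["re:model.layers.[01].self_attn.q_proj"],
    ignore=["re:model.layers.1.self_attn.q_proj"],
)
assert resolved == {"model.layers.0.self_attn.q_proj"}

# is_target applies the same match/ignore logic to a single named module
q_proj = model.get_submodule("model.layers.0.self_attn.q_proj")
assert is_target(
    "model.layers.0.self_attn.q_proj", q_proj, targets=["Linear"], ignore=[]
)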
+@pytest.fixture +def mock_module(): + return MagicMock() + + +@pytest.fixture +def llama_stories_model(): + return AutoModelForCausalLM.from_pretrained( + "Xenova/llama2.c-stories15M", + torch_dtype="auto", + ) + + def test_target_prioritization(mock_frozen): # tests that the config_groups are applied in the correct order # of priority, where exact layer name > regex > module name @@ -266,3 +293,73 @@ def test_apply_quantization_status(caplog, ignore, should_raise_warning): assert len(caplog.text) > 0 else: assert len(caplog.text) == 0 + + +@pytest.mark.parametrize( + "targets, ignore, expected_targets", + [ + ([], [], set()), + (["layer1", "layer2"], [], {"layer1", "layer2"}), + ([], ["layer1"], set()), + (["layer1", "layer2"], ["layer2"], {"layer1"}), + (["re:layer.*"], ["layer3"], {"layer1", "layer2"}), + ], +) +def test_expand_targets_with_mock(mock_model, targets, ignore, expected_targets): + expanded_targets = expand_sparse_target_names(mock_model, targets, ignore) + assert expanded_targets == expected_targets + + +@pytest.mark.parametrize( + "targets, ignore, expected_targets", + [ + ( + ["re:model.layers.[01].self_attn.q_proj"], + ["re:model.layers.1.self_attn.q_proj"], + set(["model.layers.0.self_attn.q_proj"]), + ), + ( + ["re:model.layers.[01].self_attn.q_proj"], + [], + set(["model.layers.0.self_attn.q_proj", "model.layers.1.self_attn.q_proj"]), + ), + ( + ["re:model.layers.[0-2].self_attn.q_proj"], + ["re:model.layers.1.self_attn.q_proj"], + set(["model.layers.0.self_attn.q_proj", "model.layers.2.self_attn.q_proj"]), + ), + ( + ["model.layers.0.self_attn.q_proj"], + ["model.layers.0.self_attn.q_proj"], + set(), + ), + ( + ["re:model.layers.*.self_attn.q_proj"], + ["re:model.layers.[01].self_attn.q_proj"], + set( + f"model.layers.{layer_idx}.self_attn.q_proj" + for layer_idx in range(2, 6) + ), + ), + ], +) +def test_expand_targets_with_llama_stories( + llama_stories_model, targets, ignore, expected_targets +): + expanded_targets = expand_sparse_target_names(llama_stories_model, targets, ignore) + assert expanded_targets == expected_targets + + +@pytest.mark.parametrize( + "name, targets, ignore, expected", + [ + ("layer1", ["layer1"], [], True), + ("layer1", ["layer1"], ["layer1"], False), + ("layer1", ["layer2"], [], False), + ("layer1", ["re:layer.*"], [], True), + ("layer1", ["re:layer.*"], ["re:layer1"], False), + ], +) +def test_is_target_with_mock(mock_module, name, targets, ignore, expected): + result = is_target(name, mock_module, targets, ignore) + assert result == expected diff --git a/tests/test_utils/test_safetensors_load.py b/tests/test_utils/test_safetensors_load.py new file mode 100644 index 00000000..932a8926 --- /dev/null +++ b/tests/test_utils/test_safetensors_load.py @@ -0,0 +1,82 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
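# --- Usage sketch (illustration only, not part of the patch) ---
# Complements the BaseSparseCompressor changes earlier in this diff: should_compress is
# the predicate used by BaseSparseCompressor.compress() to decide which state dict entries
# get bitmask-compressed. Only ".weight" parameters whose prefix is in the expanded target
# set qualify; everything else (biases, quantization scales, ignored layers) passes
# through unchanged. Layer names below are placeholders.

from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor

targets = {"model.layers.0.mlp.down_proj"}
assert BaseSparseCompressor.should_compress(
    "model.layers.0.mlp.down_proj.weight", targets
)
assert not BaseSparseCompressor.should_compress(
    "model.layers.0.mlp.down_proj.weight_scale", targets
)
assert not BaseSparseCompressor.should_compress(
    "model.layers.1.mlp.down_proj.weight", targets
)
# with no target set, every ".weight" parameter is compressed (backwards-compatible default)
assert BaseSparseCompressor.should_compress("model.layers.1.mlp.down_proj.weight", None)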
+ +from unittest.mock import patch + +import pytest +from compressed_tensors.utils.safetensors_load import get_nested_weight_mappings + + +mock_weight_mappings = { + "layer1.weight": "file1", + "layer1.bias": "file2", + "layer2.weight": "file3", + "layer2.bias": "file4", + "layer3.weight": "file5", +} + + +@pytest.fixture +def mock_get_weight_mappings(): + with patch( + "compressed_tensors.utils.safetensors_load.get_weight_mappings", + return_value=mock_weight_mappings, + ): + yield + + +@pytest.mark.usefixtures("mock_get_weight_mappings") +class TestGetNestedWeightMappings: + """ + Tests for the get_nested_weight_mappings function + in different scenarios, such as single and multiple + parameters to nest, and returning other parameters + """ + + def test_single_param(self): + params_to_nest = ["weight"] + result = get_nested_weight_mappings("dummy_path", params_to_nest) + expected = { + "layer1": {"weight": "file1"}, + "layer2": {"weight": "file3"}, + "layer3": {"weight": "file5"}, + } + assert result == expected + + def test_multiple_params(self): + params_to_nest = ["weight", "bias"] + result = get_nested_weight_mappings("dummy_path", params_to_nest) + expected = { + "layer1": {"weight": "file1", "bias": "file2"}, + "layer2": {"weight": "file3", "bias": "file4"}, + "layer3": {"weight": "file5"}, + } + assert result == expected + + def test_return_other_params(self): + params_to_nest = ["weight"] + result, other_params = get_nested_weight_mappings( + "dummy_path", params_to_nest, return_unmatched_params=True + ) + expected_nested = { + "layer1": {"weight": "file1"}, + "layer2": {"weight": "file3"}, + "layer3": {"weight": "file5"}, + } + expected_other = { + "layer1.bias": "file2", + "layer2.bias": "file4", + } + assert result == expected_nested + assert other_params == expected_other diff --git a/tests/testing_utils.py b/tests/testing_utils.py index e446cad3..fe11c8a9 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- +# flake8: noqa import pytest @@ -52,3 +52,75 @@ def requires_accelerate(): not _is_accelerate_available, reason="requires accelerate", ) + + +def get_random_mat(M, K, dtype) -> "torch.Tensor": + """ + :param M: number of rows + :param K: number of columns + :param dtype: data type of the matrix + :return: random matrix of shape (M, K) with non-zero values + """ + import torch + from compressed_tensors.quantization import FP8_DTYPE + + rand_tensor_dtype = dtype + if dtype in [torch.int8, FP8_DTYPE]: + rand_tensor_dtype = torch.float16 + mat = torch.rand(M, K, dtype=rand_tensor_dtype).cuda() + mat = mat.masked_fill_(mat == 0, 1) + return mat.to(dtype) + + +def generate_pruned_semi_structured_mat(M, K, dtype) -> "torch.Tensor": + """ + :param M: number of rows + :param K: number of columns + :param dtype: data type of the matrix + :return: random matrix of shape (M, K) with 2:4 sparsity pattern + """ + import torch + from compressed_tensors.quantization import FP8_DTYPE + + mask = torch.Tensor([0, 0, 1, 1]).tile((M, K // 4)).bool() + rand_tensor_dtype = dtype + if dtype in [torch.int8, FP8_DTYPE]: + rand_tensor_dtype = torch.float16 + mat = torch.rand(M, K, dtype=rand_tensor_dtype) + mat = mat.masked_fill_(mat == 0, 1) + if dtype == FP8_DTYPE: + # some float8_e4m3fn operations are not supported on CPU + mat = mat.cuda() + mask = mask.cuda() + mat = mat * mask + return mat.to(dtype) + + +def induce_sparsity(tensor, sparsity_ratio) -> "torch.Tensor": + """ + Makes a tensor sparse by zeroing out a given fraction + of its smallest absolute values. + + :param: weight_tensor (torch.Tensor): The input weight tensor. + :param: sparsity_ratio (float): Fraction of weights to be zeroed + (0 <= sparsity_ratio <= 1). + :returns: torch.Tensor: Sparse version of the input tensor. + """ + import torch + + if not (0 <= sparsity_ratio <= 1): + raise ValueError("Sparsity ratio must be between 0 and 1.") + + # Flatten the tensor and compute the threshold for sparsity + flattened = tensor.view(-1) + k = int(sparsity_ratio * flattened.numel()) + + if k > 0: + threshold = torch.topk(flattened.abs(), k, largest=False).values.max() + sparse_tensor = torch.where( + tensor.abs() > threshold, tensor, torch.zeros_like(tensor) + ) + else: + sparse_tensor = tensor + + return sparse_tensor
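A minimal end-to-end sketch of the sparse + quantized composability this change enables, mirroring the new test_composability test above (illustration only: the checkpoint-saving plumbing is elided, and calibrated_model, empty_model, and save_dir are placeholder names):

    from compressed_tensors import BitmaskConfig
    from compressed_tensors.compressors import ModelCompressor
    from compressed_tensors.quantization.quant_config import QuantizationConfig

    sparsity_config = BitmaskConfig(
        format="sparse-bitmask",
        global_sparsity=0.7,
        targets=["Linear"],
        sparsity_structure="unstructured",
    )
    quantization_config = QuantizationConfig.model_validate(
        {
            "format": "int-quantized",
            "quant_method": "compressed-tensors",
            "global_compression_ratio": 1.0,
            "config_groups": {
                "group_0": {
                    "targets": ["Linear"],
                    "weights": {
                        "num_bits": 8,
                        "type": "int",
                        "symmetric": True,
                        "strategy": "channel",
                    },
                }
            },
        }
    )

    compressor = ModelCompressor(
        sparsity_config=sparsity_config, quantization_config=quantization_config
    )

    # calibrated_model stands in for a quantization-calibrated model with sparse
    # weights (e.g. the CustomLinearModel built in test_composability above);
    # compress() applies quantization compression followed by bitmask compression
    compressed_state_dict = compressor.compress(calibrated_model)

    # after saving compressed_state_dict with safetensors and calling
    # compressor.update_config(save_dir), both stages are reversed in one call
    compressor.decompress(model_path=save_dir, model=empty_model)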