diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index 68bd52ec..f9197bf3 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -18,7 +18,7 @@ import os import re from copy import deepcopy -from typing import TYPE_CHECKING, Any, Dict, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Set, TypeVar, Union import compressed_tensors import torch @@ -38,6 +38,7 @@ apply_quantization_config, load_pretrained_quantization, ) +from compressed_tensors.quantization.lifecycle import expand_sparse_target_names from compressed_tensors.quantization.quant_args import QuantizationArgs from compressed_tensors.quantization.utils import ( is_module_quantized, @@ -268,9 +269,9 @@ def compress( compressed_state_dict = state_dict - quantized_modules_to_args: Dict[ - str, QuantizationArgs - ] = map_modules_to_quant_args(model) + quantized_modules_to_args: Dict[str, QuantizationArgs] = ( + map_modules_to_quant_args(model) + ) if self.quantization_compressor is not None: compressed_state_dict = self.quantization_compressor.compress( @@ -282,8 +283,14 @@ def compress( ) if self.sparsity_compressor is not None: + sparse_compression_targets: Set[str] = expand_sparse_target_names( + model=model, + targets=self.sparsity_config.targets, + ignore=self.sparsity_config.ignore, + ) compressed_state_dict = self.sparsity_compressor.compress( - compressed_state_dict + compressed_state_dict, + compression_targets=sparse_compression_targets, ) # HACK: Override the dtype_byte_size function in transformers to diff --git a/src/compressed_tensors/compressors/sparse_compressors/__init__.py b/src/compressed_tensors/compressors/sparse_compressors/__init__.py index de4fd887..f1b59ad3 100644 --- a/src/compressed_tensors/compressors/sparse_compressors/__init__.py +++ b/src/compressed_tensors/compressors/sparse_compressors/__init__.py @@ -15,4 +15,5 @@ from .base import * from .dense import * +from .sparse_24 import * from .sparse_bitmask import * diff --git a/src/compressed_tensors/compressors/sparse_compressors/base.py b/src/compressed_tensors/compressors/sparse_compressors/base.py index 1b1a6825..c6b715de 100644 --- a/src/compressed_tensors/compressors/sparse_compressors/base.py +++ b/src/compressed_tensors/compressors/sparse_compressors/base.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import Dict, Generator, Tuple +from typing import Dict, Generator, Optional, Set, Tuple from compressed_tensors.compressors.base import BaseCompressor from compressed_tensors.utils import get_nested_weight_mappings, merge_names @@ -30,7 +30,8 @@ class BaseSparseCompressor(BaseCompressor): """ Base class representing a sparse compression algorithm. Each child class should - implement compression_param_info, compress_weight and decompress_weight. + implement compression_param_info, compress_weight and decompress_weight; child + classes should also define COMPRESSION_PARAM_NAMES. Compressors support compressing/decompressing a full module state dict or a single quantized PyTorch leaf module. 
@@ -59,11 +60,17 @@ class BaseSparseCompressor(BaseCompressor): :param config: config specifying compression parameters """ - def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]: + def compress( + self, + model_state: Dict[str, Tensor], + compression_targets: Optional[Set[str]] = None, + ) -> Dict[str, Tensor]: """ Compresses a dense state dict using bitmask compression :param model_state: state dict of uncompressed model + :param compression_targets: optional set of layer prefixes to compress, if None + compress all layers (for backwards compatibility) :return: compressed state dict """ compressed_dict = {} @@ -71,6 +78,9 @@ def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]: f"Compressing model with {len(model_state)} parameterized layers..." ) for name, value in tqdm(model_state.items(), desc="Compressing model"): + if not self.should_compress(name, compression_targets): + compressed_dict[name] = value + continue compression_data = self.compress_weight(name, value) for key in compression_data.keys(): if key in compressed_dict: @@ -97,8 +107,10 @@ def decompress( :param device: device to load decompressed weights onto :return: iterator for generating decompressed weights """ - weight_mappings = get_nested_weight_mappings( - path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES + weight_mappings, uncompressed_params = get_nested_weight_mappings( + path_to_model_or_tensors, + self.COMPRESSION_PARAM_NAMES, + return_unmatched_params=True, ) for weight_name in weight_mappings.keys(): weight_data = {} @@ -108,3 +120,24 @@ def decompress( weight_data[param_name] = f.get_tensor(full_name) decompressed = self.decompress_weight(weight_data) yield weight_name, decompressed + + for uncompressed_param_name, safe_path in uncompressed_params.items(): + with safe_open(safe_path, framework="pt", device=device) as f: + value = f.get_tensor(uncompressed_param_name) + yield uncompressed_param_name, value + + @staticmethod + def should_compress(name: str, expanded_targets: Optional[Set[str]] = None) -> bool: + """ + Check if a parameter should be compressed + + :param name: name of the parameter + :param expanded_targets: set of layer prefixes to compress + :return: whether or not the parameter should be compressed + """ + if expanded_targets is None: + return name.endswith(".weight") + + return ( + name.endswith(".weight") and name[: -(len(".weight"))] in expanded_targets + ) diff --git a/src/compressed_tensors/compressors/sparse_compressors/sparse_24.py b/src/compressed_tensors/compressors/sparse_compressors/sparse_24.py new file mode 100644 index 00000000..040daa50 --- /dev/null +++ b/src/compressed_tensors/compressors/sparse_compressors/sparse_24.py @@ -0,0 +1,108 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+
+from typing import Dict
+
+from compressed_tensors.compressors.base import BaseCompressor
+from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor
+from compressed_tensors.config import CompressionFormat, SparsityStructure
+from compressed_tensors.utils import (
+    merge_names,
+    sparse_semi_structured_from_dense_cutlass,
+    sparse_semi_structured_to_dense_cutlass,
+    tensor_follows_mask_structure,
+)
+from torch import Tensor
+
+
+@BaseCompressor.register(name=CompressionFormat.sparse_24.value)
+class Sparse24Compressor(BaseSparseCompressor):
+    """
+    Compresses a model with 2:4 sparsity structure for inference
+    with sparse 2:4 kernels for float/float16/bfloat16.
+    https://github.com/pytorch/pytorch/blob/78cf8df4a019e919e8eac5f5d048d8842d4fc692/torch/sparse/semi_structured.py
+    """
+
+    COMPRESSION_PARAM_NAMES = ["sparse_24_packed_weight", "meta"]
+
+    @staticmethod
+    def validate_sparsity_structure(name: str, weight: Tensor) -> bool:
+        """
+        Checks if a tensor fits the required 2:4 sparsity structure
+        :param name: name of the tensor to check
+        :param weight: tensor to check for sparsity structure
+        :return: True if all rows match the 2:4 sparsity structure, raises
+            ValueError otherwise
+        """
+
+        if not tensor_follows_mask_structure(
+            weight, mask=SparsityStructure.TWO_FOUR.value
+        ):
+            raise ValueError(
+                "Sparse24Compressor is only compatible with weights that have "
+                f"a 2:4 sparsity structure. Found segments in {name} "
+                "that do not match the expected structure."
+            )
+
+        return True
+
+    def compress_weight(self, name: str, value: Tensor) -> Dict[str, Tensor]:
+        """
+        Compresses a given weight tensor with a 2:4 sparsity structure.
+        :param name: name of the tensor in state dict of uncompressed model
+        :param value: 2:4 sparse tensor to compress
+        :return: dictionary containing the compressed weight and associated
+            metadata
+        """
+        weight_suffix = ".weight"
+        if not name.endswith(weight_suffix):
+            return {}
+
+        prefix = name[: -len(weight_suffix)]
+        self.validate_sparsity_structure(name=prefix, weight=value)
+        sparse_24_packed_weight, meta = sparse_semi_structured_from_dense_cutlass(
+            dense=value
+        )
+        return {
+            merge_names(
+                prefix, "sparse_24_packed_weight"
+            ): sparse_24_packed_weight.cpu(),
+            merge_names(prefix, "meta"): meta.cpu(),
+        }
+
+    def decompress_weight(self, weight_data):
+        """
+        Decompresses the given weight data from its compressed representation to its
+        dense form.
+
+        The weight_data dictionary must contain the keys 'sparse_24_packed_weight' and
+        'meta', which represent the sparse-compressed weight and its associated meta
+        tensor.
+
+        :param weight_data: A dictionary containing:
+            - sparse_24_packed_weight: The sparse-compressed representation of
+                the weight.
+            - meta: The meta tensor associated with the compressed weight.
+        :return: The dense representation of the weight.
+ """ + assert ( + "sparse_24_packed_weight" in weight_data + ), "sparse_24_packed_weight not found in weight_data" + assert "meta" in weight_data, "meta not found in weight_data" + + return sparse_semi_structured_to_dense_cutlass( + sparse=weight_data["sparse_24_packed_weight"], + meta_reordered=weight_data["meta"], + ) diff --git a/src/compressed_tensors/config/__init__.py b/src/compressed_tensors/config/__init__.py index ff83f5af..f021f284 100644 --- a/src/compressed_tensors/config/__init__.py +++ b/src/compressed_tensors/config/__init__.py @@ -15,4 +15,5 @@ # flake8: noqa from .base import * from .dense import * +from .sparse_24 import * from .sparse_bitmask import * diff --git a/src/compressed_tensors/config/base.py b/src/compressed_tensors/config/base.py index 79a4fcdd..2d280330 100644 --- a/src/compressed_tensors/config/base.py +++ b/src/compressed_tensors/config/base.py @@ -26,6 +26,7 @@ class CompressionFormat(Enum): dense = "dense" sparse_bitmask = "sparse-bitmask" + sparse_24 = "sparse-24" int_quantized = "int-quantized" float_quantized = "float-quantized" naive_quantized = "naive-quantized" diff --git a/src/compressed_tensors/config/sparse_24.py b/src/compressed_tensors/config/sparse_24.py new file mode 100644 index 00000000..2a5ed384 --- /dev/null +++ b/src/compressed_tensors/config/sparse_24.py @@ -0,0 +1,37 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Optional + +from compressed_tensors.config import ( + CompressionFormat, + SparsityCompressionConfig, + SparsityStructure, +) + + +__all__ = ["Sparse24Config"] + + +@SparsityCompressionConfig.register(name=CompressionFormat.sparse_24.value) +class Sparse24Config(SparsityCompressionConfig): + """ + Configuration for storing a sparse model using 2:4 compression + :param global_sparsity: average sparsity of the entire model + :param sparsity_structure: structure of the sparsity, "2:4" + """ + + format: str = CompressionFormat.sparse_24.value + global_sparsity: Optional[float] = 0.0 + sparsity_structure: Optional[str] = SparsityStructure.TWO_FOUR.value diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index ed9a50f7..c297f8c7 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -18,7 +18,7 @@ from copy import deepcopy from typing import Dict, Iterable, List, Optional from typing import OrderedDict as OrderedDictType -from typing import Union +from typing import Set, Union import torch from compressed_tensors.config import CompressionFormat @@ -52,6 +52,8 @@ "apply_quantization_config", "apply_quantization_status", "find_name_or_class_matches", + "expand_sparse_target_names", + "is_target", ] from compressed_tensors.quantization.utils.helpers import is_module_quantized @@ -245,6 +247,49 @@ def apply_quantization_status(model: Module, status: QuantizationStatus): model.apply(compress_quantized_weights) +def expand_sparse_target_names( + model: Module, targets: Iterable[str], ignore: Iterable[str] +) -> Set[str]: + """ + Finds all unique module names in the model that match the given + targets and ignore lists. + + Note: Targets must be regexes, layer types, or full layer names. + + :param model: model to search for targets in + :param targets: list of targets to search for + :param ignore: list of targets to ignore + :return: set of all targets that match the given targets and should + not be ignored + """ + return { + name + for name, module in iter_named_leaf_modules(model) + if is_target(name, module, targets, ignore) + } + + +def is_target( + name: str, module: Module, targets: Iterable[str], ignore: Iterable[str] +) -> bool: + """ + Determines if a module should be included in the targets based on the + targets and ignore lists. + + Note: Targets must be regexes, layer types, or full layer names. 
+ + :param name: name of the module + :param module: the module itself + :param targets: list of targets to search for + :param ignore: list of targets to ignore + :return: True if the module is a target and not ignored, False otherwise + """ + return bool( + find_name_or_class_matches(name, module, targets) + and not find_name_or_class_matches(name, module, ignore) + ) + + def find_name_or_class_matches( name: str, module: Module, targets: Iterable[str], check_contains: bool = False ) -> List[str]: diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py index 3a8152da..36b88604 100644 --- a/src/compressed_tensors/quantization/quant_scheme.py +++ b/src/compressed_tensors/quantization/quant_scheme.py @@ -62,6 +62,7 @@ def validate_model_after(model: "QuantizationArgs") -> Dict[str, Any]: return model + """ Pre-Set Quantization Scheme Args """ diff --git a/src/compressed_tensors/utils/safetensors_load.py b/src/compressed_tensors/utils/safetensors_load.py index 4fdb3007..bc58f351 100644 --- a/src/compressed_tensors/utils/safetensors_load.py +++ b/src/compressed_tensors/utils/safetensors_load.py @@ -16,7 +16,7 @@ import os import re import struct -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple, Union from safetensors import safe_open from torch import Tensor @@ -34,6 +34,9 @@ "is_quantization_param", ] +WeightMappingType = Dict[str, str] +NestedWeightMappingType = Dict[str, WeightMappingType] + def get_safetensors_folder( pretrained_model_name_or_path: str, cache_dir: Optional[str] = None @@ -92,7 +95,7 @@ def get_safetensors_header(safetensors_path: str) -> Dict[str, str]: return header -def match_param_name(full_name: str, param_name: str) -> str: +def match_param_name(full_name: str, param_name: str) -> Optional[str]: """ Helper function extracting the uncompressed parameterized layer name from a compressed name. Assumes the compressed name was merged using merge_names. @@ -176,13 +179,14 @@ def get_weight_mappings(path_to_model_or_tensors: str) -> Dict[str, str]: def get_nested_weight_mappings( - model_path: str, params_to_nest: List[str] -) -> Dict[str, Dict[str, str]]: + model_path: str, params_to_nest: List[str], return_unmatched_params: bool = False +) -> Union[NestedWeightMappingType, Tuple[NestedWeightMappingType, WeightMappingType]]: """ Takes a path to a state dict saved in safetensors format and returns a nested - mapping from uncompressed parameterized layer names to the file locations of each - of the layers compression parameters. + mapping from uncompressed parameterized layer names to the file locations of + each layer's compression parameters. + Example of the nested mapping: layer.weight: { bitmask: file_location, row_offsets: file_location, @@ -190,24 +194,54 @@ def get_nested_weight_mappings( compressed: file_location } - This generalizes to cases where the model is split into multiple safetensors files + If other parameters are found that do not match the nested parameters, they will + be returned in a separate dictionary only if return_unmatched_params is True. + This dictionary may be needed for cases where compressors are stacked (e.g., + quantization compression followed by sparse compression). 
-    :param model_path: path to safetensors state dict, must contain either a single
-        safetensors file or multiple files with an index
-    :return: nested mapping of parameterized layer name to file location
+    Example of the unmatched params mapping:
+    {
+        layer.weight_scale: file_location,
+        layer.input_scale: file_location
+    }
+
+    This generalizes to cases where the model is split into multiple safetensors
+    files.
+
+    :param model_path: Path to the safetensors state dict, must contain either a
+        single safetensors file or multiple files with an index.
+    :param params_to_nest: List of parameter names to nest.
+    :param return_unmatched_params: If True, return a second dictionary containing
+        the remaining parameters that were not matched to the params_to_nest.
+    :return:
+        - If return_unmatched_params is False:
+            NestedWeightMappingType: A nested mapping of parameterized layer names to
+                file locations of each layer's compression parameters.
+        - If return_unmatched_params is True:
+            Tuple[NestedWeightMappingType, WeightMappingType]: A tuple containing:
+                - NestedWeightMappingType: A nested mapping of parameterized layer
+                    names to file locations of each layer's compression parameters.
+                - WeightMappingType: A mapping of the remaining parameter names to
+                    their file locations that were not matched to the params_to_nest.
     """
     weight_mappings = get_weight_mappings(model_path)
     nested_weight_mappings = {}
-    for key in weight_mappings.keys():
+    unmatched_params = {}
+
+    for key, file_location in weight_mappings.items():
+        matched = False
         for param_name in params_to_nest:
-            maybe_match = match_param_name(key, param_name)
-            if maybe_match is not None:
-                dense_param = maybe_match
+            dense_param = match_param_name(key, param_name)
+            if dense_param:
                 if dense_param not in nested_weight_mappings:
                     nested_weight_mappings[dense_param] = {}
-                nested_weight_mappings[dense_param][param_name] = weight_mappings[key]
+                nested_weight_mappings[dense_param][param_name] = file_location
+                matched = True
+        if return_unmatched_params and not matched:
+            unmatched_params[key] = file_location
+
+    if return_unmatched_params:
+        return nested_weight_mappings, unmatched_params

     return nested_weight_mappings
diff --git a/src/compressed_tensors/utils/semi_structured_conversions.py b/src/compressed_tensors/utils/semi_structured_conversions.py
index ef318a48..17ea6ef2 100644
--- a/src/compressed_tensors/utils/semi_structured_conversions.py
+++ b/src/compressed_tensors/utils/semi_structured_conversions.py
@@ -20,7 +20,7 @@
 # limitations under the License.

 import torch
-
+from compressed_tensors.quantization import FP8_DTYPE

 __all__ = [
     "sparse_semi_structured_from_dense_cutlass",
@@ -75,6 +75,7 @@ def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device
 # This function converts dense matrix into sparse semi-structured
 # representation, producing "compressed" matrix, in the layout used by
 # CUTLASS backend, and corresponding metadata matrix.
+# Modified from https://github.com/pytorch/pytorch/blob/78cf8df4a019e919e8eac5f5d048d8842d4fc692/torch/sparse/_semi_structured_conversions.py#L47 def sparse_semi_structured_from_dense_cutlass(dense): if dense.dim() != 2: raise RuntimeError( @@ -84,8 +85,8 @@ def sparse_semi_structured_from_dense_cutlass(dense): m, k = dense.shape device = dense.device - meta_dtype = torch.int8 - if dense.dtype == torch.int8: + meta_dtype = None + if dense.dtype == torch.int8 or dense.dtype == FP8_DTYPE: meta_dtype = torch.int32 elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]: meta_dtype = torch.int16 @@ -165,11 +166,15 @@ def sparse_semi_structured_from_dense_cutlass(dense): idxs1 = bit2 | (bit3.to(torch.int64) << 1) if dense.dtype != torch.float: + if dense.dtype == FP8_DTYPE: + dense_4 = dense_4.view(torch.int8) sparse0 = dense_4.gather( -1, idxs0.unsqueeze(-1) ) # type: ignore[possibly-undefined] sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) + if dense.dtype == FP8_DTYPE: + sparse = sparse.view(FP8_DTYPE) else: sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view( m, k // 2 @@ -213,6 +218,7 @@ def sparse_semi_structured_from_dense_cutlass(dense): # reconstructs dense matrix from a pair of "compressed" matrix, given # in the layout used by CUTLASS backend, and accompanying metadata # matrix. +# Copied from https://github.com/pytorch/pytorch/blob/78cf8df4a019e919e8eac5f5d048d8842d4fc692/torch/sparse/_semi_structured_conversions.py#L180 def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): if sparse.dim() != 2: raise RuntimeError( @@ -298,16 +304,21 @@ def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): torch.arange(0, 2 * m * k // ksparse, device=device) * 4 ).view(-1, 1).repeat(1, 2).view(-1) - dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device) + sparse_dtype = sparse.dtype if sparse.dtype != FP8_DTYPE else torch.int8 + dense = torch.zeros((m * 2 * k,), dtype=sparse_dtype, device=device) if sparse.dtype != torch.float: # dense.scatter_(0, dense_offsets, sparse.view(-1)) - dense.scatter_(0, dense_offsets, sparse.reshape(-1)) + if sparse.dtype == FP8_DTYPE: + dense.scatter_(0, dense_offsets, sparse.view(torch.int8).view(-1)) + else: + dense.scatter_(0, dense_offsets, sparse.reshape(-1)) else: dense.view(torch.half).scatter_( 0, dense_offsets, sparse.view(torch.half).view(-1) ) - return dense.view(m, 2 * k) + result = dense.view(m, 2 * k) + return result.view(sparse.dtype) def mask_creator(tensor): diff --git a/tests/test_quantization/lifecycle/test_apply.py b/tests/test_quantization/lifecycle/test_apply.py index 7268ca27..e2dc2ad1 100644 --- a/tests/test_quantization/lifecycle/test_apply.py +++ b/tests/test_quantization/lifecycle/test_apply.py @@ -14,6 +14,7 @@ import re from typing import Optional +from unittest.mock import MagicMock import pytest import torch @@ -26,12 +27,38 @@ from compressed_tensors.quantization.lifecycle import ( apply_quantization_config, apply_quantization_status, + expand_sparse_target_names, + is_target, ) from compressed_tensors.quantization.utils import iter_named_leaf_modules from tests.testing_utils import requires_accelerate from transformers import AutoModelForCausalLM +@pytest.fixture +def mock_model(): + model = MagicMock() + model.named_modules.return_value = [ + ("layer1", MagicMock()), + ("layer2", MagicMock()), + ("layer3", MagicMock()), + ] + return model + + +@pytest.fixture +def mock_module(): + return MagicMock() 
+ + +@pytest.fixture +def llama_stories_model(): + return AutoModelForCausalLM.from_pretrained( + "Xenova/llama2.c-stories15M", + torch_dtype="auto", + ) + + def test_target_prioritization(mock_frozen): # tests that the config_groups are applied in the correct order # of priority, where exact layer name > regex > module name @@ -266,3 +293,73 @@ def test_apply_quantization_status(caplog, ignore, should_raise_warning): assert len(caplog.text) > 0 else: assert len(caplog.text) == 0 + + +@pytest.mark.parametrize( + "targets, ignore, expected_targets", + [ + ([], [], set()), + (["layer1", "layer2"], [], {"layer1", "layer2"}), + ([], ["layer1"], set()), + (["layer1", "layer2"], ["layer2"], {"layer1"}), + (["re:layer.*"], ["layer3"], {"layer1", "layer2"}), + ], +) +def test_expand_targets_with_mock(mock_model, targets, ignore, expected_targets): + expanded_targets = expand_sparse_target_names(mock_model, targets, ignore) + assert expanded_targets == expected_targets + + +@pytest.mark.parametrize( + "targets, ignore, expected_targets", + [ + ( + ["re:model.layers.[01].self_attn.q_proj"], + ["re:model.layers.1.self_attn.q_proj"], + set(["model.layers.0.self_attn.q_proj"]), + ), + ( + ["re:model.layers.[01].self_attn.q_proj"], + [], + set(["model.layers.0.self_attn.q_proj", "model.layers.1.self_attn.q_proj"]), + ), + ( + ["re:model.layers.[0-2].self_attn.q_proj"], + ["re:model.layers.1.self_attn.q_proj"], + set(["model.layers.0.self_attn.q_proj", "model.layers.2.self_attn.q_proj"]), + ), + ( + ["model.layers.0.self_attn.q_proj"], + ["model.layers.0.self_attn.q_proj"], + set(), + ), + ( + ["re:model.layers.*.self_attn.q_proj"], + ["re:model.layers.[01].self_attn.q_proj"], + set( + f"model.layers.{layer_idx}.self_attn.q_proj" + for layer_idx in range(2, 6) + ), + ), + ], +) +def test_expand_targets_with_llama_stories( + llama_stories_model, targets, ignore, expected_targets +): + expanded_targets = expand_sparse_target_names(llama_stories_model, targets, ignore) + assert expanded_targets == expected_targets + + +@pytest.mark.parametrize( + "name, targets, ignore, expected", + [ + ("layer1", ["layer1"], [], True), + ("layer1", ["layer1"], ["layer1"], False), + ("layer1", ["layer2"], [], False), + ("layer1", ["re:layer.*"], [], True), + ("layer1", ["re:layer.*"], ["re:layer1"], False), + ], +) +def test_is_target_with_mock(mock_module, name, targets, ignore, expected): + result = is_target(name, mock_module, targets, ignore) + assert result == expected diff --git a/tests/test_utils/test_safetensors_load.py b/tests/test_utils/test_safetensors_load.py new file mode 100644 index 00000000..95cd6347 --- /dev/null +++ b/tests/test_utils/test_safetensors_load.py @@ -0,0 +1,76 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from unittest.mock import patch + +import pytest +from compressed_tensors.utils.safetensors_load import get_nested_weight_mappings + + +mock_weight_mappings = { + "layer1.weight": "file1", + "layer1.bias": "file2", + "layer2.weight": "file3", + "layer2.bias": "file4", + "layer3.weight": "file5", +} + + +@pytest.fixture +def mock_get_weight_mappings(): + with patch( + "compressed_tensors.utils.safetensors_load.get_weight_mappings", + return_value=mock_weight_mappings, + ): + yield + + +@pytest.mark.usefixtures("mock_get_weight_mappings") +class TestGetNestedWeightMappings: + def test_single_param(self): + params_to_nest = ["weight"] + result = get_nested_weight_mappings("dummy_path", params_to_nest) + expected = { + "layer1": {"weight": "file1"}, + "layer2": {"weight": "file3"}, + "layer3": {"weight": "file5"}, + } + assert result == expected + + def test_multiple_params(self): + params_to_nest = ["weight", "bias"] + result = get_nested_weight_mappings("dummy_path", params_to_nest) + expected = { + "layer1": {"weight": "file1", "bias": "file2"}, + "layer2": {"weight": "file3", "bias": "file4"}, + "layer3": {"weight": "file5"}, + } + assert result == expected + + def test_return_other_params(self): + params_to_nest = ["weight"] + result, other_params = get_nested_weight_mappings( + "dummy_path", params_to_nest, return_unmatched_params=True + ) + expected_nested = { + "layer1": {"weight": "file1"}, + "layer2": {"weight": "file3"}, + "layer3": {"weight": "file5"}, + } + expected_other = { + "layer1.bias": "file2", + "layer2.bias": "file4", + } + assert result == expected_nested + assert other_params == expected_other diff --git a/tests/test_utils/test_semi_structured_conversions.py b/tests/test_utils/test_semi_structured_conversions.py new file mode 100644 index 00000000..c2c198c6 --- /dev/null +++ b/tests/test_utils/test_semi_structured_conversions.py @@ -0,0 +1,72 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest +import torch +from compressed_tensors.quantization import FP8_DTYPE +from compressed_tensors.utils.semi_structured_conversions import ( + sparse_semi_structured_from_dense_cutlass, + sparse_semi_structured_to_dense_cutlass, +) + + +def supported_dtypes(): + dtypes = [torch.int8, torch.float16, torch.bfloat16] + if torch.cuda.is_available(): + major, minor = torch.cuda.get_device_capability() + if major > 9 or (major == 9 and minor >= 0): + dtypes += [FP8_DTYPE] + return dtypes + + +def get_random_mat(M, K, dtype): + rand_tensor_dtype = dtype + if dtype in [torch.int8, FP8_DTYPE]: + rand_tensor_dtype = torch.float16 + mat = torch.rand(M, K, dtype=rand_tensor_dtype).cuda() + mat = mat.masked_fill_(mat == 0, 1) + return mat.to(dtype) + + +def generate_pruned_semi_structured_mat(M, K, dtype): + mask = torch.Tensor([0, 0, 1, 1]).tile((M, K // 4)).bool() + rand_tensor_dtype = dtype + if dtype in [torch.int8, FP8_DTYPE]: + rand_tensor_dtype = torch.float16 + mat = torch.rand(M, K, dtype=rand_tensor_dtype) + mat = mat.masked_fill_(mat == 0, 1) + if dtype == FP8_DTYPE: + # some float8_e4m3fn operations are not supported on CPU + mat = mat.cuda() + mask = mask.cuda() + mat = mat * mask + return mat.to(dtype) + + +@pytest.mark.parametrize("dtype", supported_dtypes()) +def test_inverse_property_from_dense_then_to_dense(dtype): + M, K = 1024, 1024 + dense_matrix = generate_pruned_semi_structured_mat(M, K, dtype) + compressed_matrix, meta = sparse_semi_structured_from_dense_cutlass(dense_matrix) + result = sparse_semi_structured_to_dense_cutlass(compressed_matrix, meta) + + assert ( + dense_matrix.dtype == result.dtype + ), f"Dtype Mis-match: {dense_matrix.dtype} and {result.dtype}" + assert ( + dense_matrix.shape == result.shape + ), f"Shape Mis-match: {dense_matrix.shape} and {result.shape}" + assert torch.equal( + dense_matrix, result + ), f"Failed for dtype: {dense_matrix.dtype} and input: {dense_matrix}"
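
Reviewer sketch, not part of the patch: a minimal end-to-end pass through the new 2:4 path, assuming the classes added above (Sparse24Compressor, Sparse24Config) and the pass-through behavior of BaseSparseCompressor.compress. The layer name, shapes, and the direct constructor call are illustrative; ModelCompressor normally builds the compressor from the registry using the sparsity config.

import torch

from compressed_tensors.compressors.sparse_compressors.sparse_24 import (
    Sparse24Compressor,
)
from compressed_tensors.config import Sparse24Config

# Toy FP16 weight that already follows the 2:4 pattern (two zeros in every
# contiguous group of four values), as validate_sparsity_structure requires.
m, k = 128, 128
mask = torch.tensor([0, 0, 1, 1], dtype=torch.float16).tile(m, k // 4)
weight = torch.rand(m, k).clamp_min(1e-3).to(torch.float16) * mask

compressor = Sparse24Compressor(config=Sparse24Config(targets=["Linear"]))

# Only names listed in compression_targets are packed; other entries (biases,
# non-targeted layers) are passed through unchanged by BaseSparseCompressor.compress.
state_dict = {"decoder.layers.0.fc1.weight": weight}
compressed = compressor.compress(
    state_dict, compression_targets={"decoder.layers.0.fc1"}
)
assert "decoder.layers.0.fc1.sparse_24_packed_weight" in compressed
assert "decoder.layers.0.fc1.meta" in compressed

# Decompressing the packed weight/meta pair recovers the dense 2:4 weight.
restored = compressor.decompress_weight(
    {
        "sparse_24_packed_weight": compressed[
            "decoder.layers.0.fc1.sparse_24_packed_weight"
        ],
        "meta": compressed["decoder.layers.0.fc1.meta"],
    }
)
assert torch.equal(restored, weight)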
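
A second sketch, also illustrative: how expand_sparse_target_names resolves the sparsity config's targets/ignore lists into concrete leaf-module names before ModelCompressor passes them to the sparsity compressor. The toy module tree is hypothetical; targets may be regexes (prefixed with "re:"), layer types, or full layer names, and ignore entries take precedence.

import torch.nn as nn

from compressed_tensors.quantization.lifecycle import expand_sparse_target_names

model = nn.ModuleDict(
    {
        "attn": nn.ModuleDict({"q_proj": nn.Linear(8, 8), "k_proj": nn.Linear(8, 8)}),
        "lm_head": nn.Linear(8, 8),
    }
)

# "Linear" matches by module class name; "lm_head" is excluded by its full name.
targets = expand_sparse_target_names(model=model, targets=["Linear"], ignore=["lm_head"])
assert targets == {"attn.q_proj", "attn.k_proj"}

# Regex targets and ignores are resolved through find_name_or_class_matches.
assert expand_sparse_target_names(
    model=model, targets=["re:attn.*"], ignore=["re:.*k_proj"]
) == {"attn.q_proj"}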
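
A final aside on the FP8 handling in semi_structured_conversions.py: gather and scatter have no float8 kernels, so the patch reinterprets FP8 storage as int8 for the indexing steps and views the bytes back afterwards. A standalone sketch of that reinterpretation trick, assuming a torch build that provides float8_e4m3fn (which FP8_DTYPE aliases); on some builds the cast may only be supported on GPU.

import torch

from compressed_tensors.quantization import FP8_DTYPE  # torch.float8_e4m3fn

x = torch.tensor([0.5, -0.25, 1.0, 2.0]).to(FP8_DTYPE)

# Reinterpret the same bytes as int8 so gather can run, then view them back as FP8.
picked = x.view(torch.int8).gather(0, torch.tensor([2, 3])).view(FP8_DTYPE)
assert picked.to(torch.float32).tolist() == [1.0, 2.0]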