diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 4f201782..2d96ef50 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -76,13 +76,9 @@ jobs: with: ref: ${{ inputs.gitref }} - - name: install testmo - uses: neuralmagic/nm-actions/actions/install-testmo@v1.0.0 - - name: create testmo run id: create_testmo_run - uses: neuralmagic/nm-actions/actions/testmo-run-create@v1.2.0 - if: success() + uses: neuralmagic/nm-actions/actions/testmo-run-create@v1.11.0 with: testmo_url: https://neuralmagic.testmo.net testmo_token: ${{ secrets.TESTMO_TEST_TOKEN }} @@ -142,8 +138,8 @@ jobs: - name: report build status to testmo id: report_build - uses: neuralmagic/nm-actions/actions/testmo-run-submit-thread@v1.2.0 - if: (success() || failure()) && ${{ inputs.testmo_run_id != '' }} + uses: neuralmagic/nm-actions/actions/testmo-run-submit-thread@v1.11.0 + if: success() || failure() with: testmo_url: https://neuralmagic.testmo.net testmo_token: ${{ secrets.TESTMO_TEST_TOKEN }} diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 07654cbc..ef7c3ca2 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -90,9 +90,6 @@ jobs: with: venv: TEST - - name: install testmo - uses: neuralmagic/nm-actions/actions/install-testmo@v1.0.0 - - name: download whl id: download uses: actions/download-artifact@v4 @@ -108,8 +105,8 @@ jobs: - name: report test results id: report_test - uses: neuralmagic/nm-actions/actions/testmo-run-submit-thread@v1.2.0 - if: (success() || failure()) && ${{ inputs.testmo_run_id != '' }} + uses: neuralmagic/nm-actions/actions/testmo-run-submit-thread@v1.11.0 + if: (success() || failure()) && inputs.testmo_run_id != '' with: testmo_url: https://neuralmagic.testmo.net testmo_token: ${{ secrets.TESTMO_TEST_TOKEN }} diff --git a/.github/workflows/upload.yml b/.github/workflows/upload.yml index 2f1ba4f3..4505442a 100644 --- a/.github/workflows/upload.yml +++ b/.github/workflows/upload.yml @@ -69,12 +69,9 @@ jobs: with: python-version: 3.10.12 - - name: install testmo - uses: neuralmagic/nm-actions/actions/install-testmo@v1.0.0 - - name: complete testmo run - uses: neuralmagic/nm-actions/actions/testmo-run-complete@v1.2.0 - if: (success() || failure()) && ${{ inputs.testmo_run_id != '' }} + uses: neuralmagic/nm-actions/actions/testmo-run-complete@v1.11.0 + if: (success() || failure()) && inputs.testmo_run_id != '' with: testmo_url: https://neuralmagic.testmo.net testmo_token: ${{ secrets.TESTMO_TEST_TOKEN }} diff --git a/setup.py b/setup.py index eed23404..5206d75b 100644 --- a/setup.py +++ b/setup.py @@ -1,11 +1,11 @@ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -15,7 +15,33 @@ import os from setuptools import setup, find_packages from typing import List, Dict, Tuple -from utils.artifacts import get_release_and_version + + +def get_release_and_version(package_path: str) -> Tuple[bool, bool, str, str, str, str]: + """ + Load version and release info from compressed-tensors package + """ + # compressed-tensors/src/compressed_tensors/version.py always exists, default source of truth + version_path = os.path.join(package_path, "version.py") + + # exec() cannot set local variables so need to manually + locals_dict = {} + exec(open(version_path).read(), globals(), locals_dict) + is_release = locals_dict.get("is_release", False) + version = locals_dict.get("version", "unknown") + version_major = locals_dict.get("version_major", "unknown") + version_minor = locals_dict.get("version_minor", "unknown") + version_bug = locals_dict.get("version_bug", "unknown") + + print(f"Loaded version {version} from {version_path}") + + return ( + is_release, + version, + version_major, + version_minor, + version_bug, + ) package_path = os.path.join( @@ -35,7 +61,7 @@ _PACKAGE_NAME = "compressed-tensors" else: _PACKAGE_NAME = "compressed-tensors-nightly" - + def _setup_long_description() -> Tuple[str, str]: return open("README.md", "r", encoding="utf-8").read(), "text/markdown" @@ -44,7 +70,7 @@ def _setup_packages() -> List: return find_packages( "src", include=["compressed_tensors", "compressed_tensors.*"], exclude=["*.__pycache__.*"] ) - + def _setup_install_requires() -> List: return ["torch>=1.7.0", "transformers", "pydantic>=2.0"] diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index 68bd52ec..951eef1f 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -17,8 +17,9 @@ import operator import os import re +from contextlib import contextmanager from copy import deepcopy -from typing import TYPE_CHECKING, Any, Dict, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Set, TypeVar, Union import compressed_tensors import torch @@ -38,6 +39,7 @@ apply_quantization_config, load_pretrained_quantization, ) +from compressed_tensors.quantization.lifecycle import expand_sparse_target_names from compressed_tensors.quantization.quant_args import QuantizationArgs from compressed_tensors.quantization.utils import ( is_module_quantized, @@ -104,7 +106,6 @@ def from_pretrained( """ config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) compression_config = getattr(config, QUANTIZATION_CONFIG_NAME, None) - return cls.from_compression_config(compression_config) @classmethod @@ -137,7 +138,7 @@ def from_compression_config( format, **sparsity_config ) if quantization_config is not None: - quantization_config = QuantizationConfig.parse_obj(quantization_config) + quantization_config = QuantizationConfig.model_validate(quantization_config) return cls( sparsity_config=sparsity_config, quantization_config=quantization_config @@ -193,7 +194,7 @@ def parse_sparsity_config( if is_compressed_tensors_config(compression_config): s_config = compression_config.sparsity_config - return s_config.dict() if s_config is not None else None + return s_config.model_dump() if s_config is not None else None return compression_config.get(SPARSITY_CONFIG_NAME, None) @@ -214,7 +215,7 @@ def parse_quantization_config( if 
is_compressed_tensors_config(compression_config): q_config = compression_config.quantization_config - return q_config.dict() if q_config is not None else None + return q_config.model_dump() if q_config is not None else None quantization_config = deepcopy(compression_config) quantization_config.pop(SPARSITY_CONFIG_NAME, None) @@ -282,8 +283,14 @@ def compress( ) if self.sparsity_compressor is not None: + sparse_compression_targets: Set[str] = expand_sparse_target_names( + model=model, + targets=self.sparsity_config.targets, + ignore=self.sparsity_config.ignore, + ) compressed_state_dict = self.sparsity_compressor.compress( - compressed_state_dict + compressed_state_dict, + compression_targets=sparse_compression_targets, ) # HACK: Override the dtype_byte_size function in transformers to @@ -301,23 +308,44 @@ def decompress(self, model_path: str, model: Module): :param model: pytorch model to load decompressed weights into """ model_path = get_safetensors_folder(model_path) - if self.sparsity_compressor is not None: + sparse_decompressed = False + + if ( + self.sparsity_compressor is not None + and self.sparsity_config.format != CompressionFormat.dense.value + ): + # Sparse decompression is applied on the model_path dense_gen = self.sparsity_compressor.decompress(model_path) self._replace_weights(dense_gen, model) setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config) + sparse_decompressed = True if self.quantization_compressor is not None: - names_to_scheme = apply_quantization_config(model, self.quantization_config) - load_pretrained_quantization(model, model_path) + # Temporarily set quantization status to FROZEN to prevent + # quantization during apply_quantization_config. This ensures + # that the dtypes of the weights are not unintentionally updated. + # The status is restored after quantization params are loaded. + with override_quantization_status( + self.quantization_config, QuantizationStatus.FROZEN + ): + names_to_scheme = apply_quantization_config( + model, self.quantization_config + ) + load_pretrained_quantization(model, model_path) + + model_path_or_state_dict = ( + model.state_dict() if sparse_decompressed else model_path + ) + dense_gen = self.quantization_compressor.decompress( - model_path, names_to_scheme=names_to_scheme + model_path_or_state_dict, names_to_scheme=names_to_scheme ) self._replace_weights(dense_gen, model) - def update_status(module): + def freeze_quantization_status(module): module.quantization_status = QuantizationStatus.FROZEN - model.apply(update_status) + model.apply(freeze_quantization_status) setattr(model, QUANTIZATION_CONFIG_NAME, self.quantization_config) def update_config(self, save_directory: str): @@ -367,12 +395,26 @@ def update_config(self, save_directory: str): with open(config_file_path, "w") as config_file: json.dump(config_data, config_file, indent=2, sort_keys=True) - def _replace_weights(self, dense_weight_generator, model): + def _replace_weights(self, dense_weight_generator, model: Module): + """ + Replace the weights of the model with the + provided dense weights. + + This method iterates over the dense_weight_generator and + updates the corresponding weights in the model. If a parameter + name does not exist in the model, it will be skipped. + + :param dense_weight_generator (generator): A generator that yields + tuples of (name, data), where 'name' is the parameter name and + 'data' is the updated param data + :param model: The model whose weights are to be updated. 
+ """ for name, data in tqdm(dense_weight_generator, desc="Decompressing model"): split_name = name.split(".") prefix, param_name = ".".join(split_name[:-1]), split_name[-1] module = operator.attrgetter(prefix)(model) - update_parameter_data(module, data, param_name) + if hasattr(module, param_name): + update_parameter_data(module, data, param_name) def map_modules_to_quant_args(model: Module) -> Dict[str, QuantizationArgs]: @@ -402,3 +444,23 @@ def new_dtype_byte_size(dtype): raise ValueError(f"`dtype` is not a valid dtype: {dtype}.") bit_size = int(bit_search.groups()[0]) return bit_size // 8 + + +@contextmanager +def override_quantization_status( + config: QuantizationConfig, status: QuantizationStatus +): + """ + Within this context, the quantization status will be set to the + supplied status. After the context exits, the original status + will be restored. + + :param config: the quantization config to override + :param status: the status to temporarily set + """ + original_status = config.quantization_status + config.quantization_status = status + try: + yield + finally: + config.quantization_status = original_status diff --git a/src/compressed_tensors/compressors/quantized_compressors/base.py b/src/compressed_tensors/compressors/quantized_compressors/base.py index 8ceb8afd..b49361d4 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/base.py +++ b/src/compressed_tensors/compressors/quantized_compressors/base.py @@ -13,12 +13,17 @@ # limitations under the License. import logging -from typing import Dict, Generator, Tuple +from pathlib import Path +from typing import Any, Dict, Generator, Tuple, Union import torch from compressed_tensors.compressors.base import BaseCompressor from compressed_tensors.quantization import QuantizationArgs -from compressed_tensors.utils import get_nested_weight_mappings, merge_names +from compressed_tensors.utils import ( + get_nested_mappings_from_state_dict, + get_nested_weight_mappings, + merge_names, +) from safetensors import safe_open from torch import Tensor from tqdm import tqdm @@ -113,7 +118,7 @@ def compress( def decompress( self, - path_to_model_or_tensors: str, + path_to_model_or_tensors: Union[str, Path, Dict[str, Any]], names_to_scheme: Dict[str, QuantizationArgs], device: str = "cpu", ) -> Generator[Tuple[str, Tensor], None, None]: @@ -121,15 +126,25 @@ def decompress( Reads a compressed state dict located at path_to_model_or_tensors and returns a generator for sequentially decompressing back to a dense state dict - :param path_to_model_or_tensors: path to compressed safetensors model (directory with one or more safetensors files) or compressed tensors file :param names_to_scheme: quantization args for each quantized weight :param device: optional device to load intermediate weights into :return: compressed state dict """ + if isinstance(path_to_model_or_tensors, (str, Path)): + yield from self._decompress_from_path( + path_to_model_or_tensors, names_to_scheme, device + ) + + else: + yield from self._decompress_from_state_dict( + path_to_model_or_tensors, names_to_scheme + ) + + def _decompress_from_path(self, path_to_model, names_to_scheme, device): weight_mappings = get_nested_weight_mappings( - path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES + path_to_model, self.COMPRESSION_PARAM_NAMES ) for weight_name in weight_mappings.keys(): weight_data = {} @@ -137,6 +152,21 @@ def decompress( full_name = merge_names(weight_name, param_name) with safe_open(safe_path, framework="pt", device=device) as f: 
weight_data[param_name] = f.get_tensor(full_name) + if "weight_scale" in weight_data: + quant_args = names_to_scheme[weight_name] + decompressed = self.decompress_weight( + compressed_data=weight_data, quantization_args=quant_args + ) + yield merge_names(weight_name, "weight"), decompressed + + def _decompress_from_state_dict(self, state_dict, names_to_scheme): + weight_mappings = get_nested_mappings_from_state_dict( + state_dict, self.COMPRESSION_PARAM_NAMES + ) + for weight_name in weight_mappings.keys(): + weight_data = {} + for param_name, param_value in weight_mappings[weight_name].items(): + weight_data[param_name] = param_value if "weight_scale" in weight_data: quant_args = names_to_scheme[weight_name] diff --git a/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py index 85eebe00..69d9d596 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py @@ -68,9 +68,9 @@ def compress_weight( self, weight: Tensor, scale: Tensor, + quantization_args: QuantizationArgs, zero_point: Optional[Tensor] = None, g_idx: Optional[torch.Tensor] = None, - quantization_args: Optional[QuantizationArgs] = None, device: Optional[torch.device] = None, ) -> Dict[str, torch.Tensor]: """ @@ -78,9 +78,9 @@ def compress_weight( :param weight: uncompressed weight tensor :param scale: quantization scale for weight + :param quantization_args: quantization parameters for weight :param zero_point: quantization zero point for weight :param g_idx: optional mapping from column index to group index - :param quantization_args: quantization parameters for weight :param device: optional device to move compressed output to :return: dictionary of compressed weight data """ diff --git a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py index c236f8c9..629ef37e 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py @@ -68,9 +68,9 @@ def compress_weight( self, weight: Tensor, scale: Tensor, + quantization_args: QuantizationArgs, zero_point: Optional[Tensor] = None, g_idx: Optional[torch.Tensor] = None, - quantization_args: Optional[QuantizationArgs] = None, device: Optional[torch.device] = None, ) -> Dict[str, torch.Tensor]: """ @@ -78,9 +78,9 @@ def compress_weight( :param weight: uncompressed weight tensor :param scale: quantization scale for weight + :param quantization_args: quantization parameters for weight :param zero_point: quantization zero point for weight :param g_idx: optional mapping from column index to group index - :param quantization_args: quantization parameters for weight :param device: optional device to move compressed output to :return: dictionary of compressed weight data """ diff --git a/src/compressed_tensors/compressors/sparse_compressors/__init__.py b/src/compressed_tensors/compressors/sparse_compressors/__init__.py index de4fd887..871079ac 100644 --- a/src/compressed_tensors/compressors/sparse_compressors/__init__.py +++ b/src/compressed_tensors/compressors/sparse_compressors/__init__.py @@ -15,4 +15,5 @@ from .base import * from .dense import * +from .sparse_24_bitmask import * from .sparse_bitmask import * diff --git 
a/src/compressed_tensors/compressors/sparse_compressors/base.py b/src/compressed_tensors/compressors/sparse_compressors/base.py index 1b1a6825..7cd6e8e8 100644 --- a/src/compressed_tensors/compressors/sparse_compressors/base.py +++ b/src/compressed_tensors/compressors/sparse_compressors/base.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import Dict, Generator, Tuple +from typing import Dict, Generator, Optional, Set, Tuple from compressed_tensors.compressors.base import BaseCompressor from compressed_tensors.utils import get_nested_weight_mappings, merge_names @@ -30,7 +30,8 @@ class BaseSparseCompressor(BaseCompressor): """ Base class representing a sparse compression algorithm. Each child class should - implement compression_param_info, compress_weight and decompress_weight. + implement compression_param_info, compress_weight and decompress_weight; child + classes should also define COMPRESSION_PARAM_NAMES. Compressors support compressing/decompressing a full module state dict or a single quantized PyTorch leaf module. @@ -59,11 +60,17 @@ class BaseSparseCompressor(BaseCompressor): :param config: config specifying compression parameters """ - def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]: + def compress( + self, + model_state: Dict[str, Tensor], + compression_targets: Optional[Set[str]] = None, + ) -> Dict[str, Tensor]: """ Compresses a dense state dict using bitmask compression :param model_state: state dict of uncompressed model + :param compression_targets: optional set of layer prefixes to compress, + otherwise compress all layers (for backwards compatibility) :return: compressed state dict """ compressed_dict = {} @@ -71,7 +78,14 @@ def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]: f"Compressing model with {len(model_state)} parameterized layers..." ) for name, value in tqdm(model_state.items(), desc="Compressing model"): - compression_data = self.compress_weight(name, value) + if not self.should_compress(name, compression_targets): + compressed_dict[name] = value + continue + prefix = name + if prefix.endswith(".weight"): + prefix = prefix[: -(len(".weight"))] + + compression_data = self.compress_weight(prefix, value) for key in compression_data.keys(): if key in compressed_dict: _LOGGER.warn( @@ -97,8 +111,10 @@ def decompress( :param device: device to load decompressed weights onto :return: iterator for generating decompressed weights """ - weight_mappings = get_nested_weight_mappings( - path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES + weight_mappings, ignored_params = get_nested_weight_mappings( + path_to_model_or_tensors, + self.COMPRESSION_PARAM_NAMES, + return_unmatched_params=True, ) for weight_name in weight_mappings.keys(): weight_data = {} @@ -107,4 +123,26 @@ def decompress( with safe_open(safe_path, framework="pt", device=device) as f: weight_data[param_name] = f.get_tensor(full_name) decompressed = self.decompress_weight(weight_data) - yield weight_name, decompressed + yield merge_names(weight_name, "weight"), decompressed + + for ignored_param_name, safe_path in ignored_params.items(): + with safe_open(safe_path, framework="pt", device=device) as f: + value = f.get_tensor(ignored_param_name) + yield ignored_param_name, value + + @staticmethod + def should_compress(name: str, expanded_targets: Optional[Set[str]] = None) -> bool: + """ + Check if a parameter should be compressed. + Currently, this only returns True for weight parameters. 
+ + :param name: name of the parameter + :param expanded_targets: set of layer prefixes to compress + :return: whether or not the parameter should be compressed + """ + if expanded_targets is None: + return name.endswith(".weight") + + return ( + name.endswith(".weight") and name[: -(len(".weight"))] in expanded_targets + ) diff --git a/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py b/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py new file mode 100644 index 00000000..e51433c2 --- /dev/null +++ b/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py @@ -0,0 +1,238 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Dict, List, Tuple, Union + +import torch +from compressed_tensors.compressors.base import BaseCompressor +from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor +from compressed_tensors.config import CompressionFormat, SparsityStructure +from compressed_tensors.quantization import FP8_DTYPE +from compressed_tensors.utils import merge_names, pack_bitmasks, unpack_bitmasks +from torch import Tensor + + +__all__ = [ + "Sparse24BitMaskCompressor", + "Sparse24BitMaskTensor", + "sparse24_bitmask_compress", + "sparse24_bitmask_decompress", + "get_24_bytemasks", +] + + +@BaseCompressor.register(name=CompressionFormat.sparse_24_bitmask.value) +class Sparse24BitMaskCompressor(BaseSparseCompressor): + """ + Compression for sparse models using bitmasks. Non-zero weights are stored in a 2d + values tensor, with their locations stored in a 2d bitmask + """ + + COMPRESSION_PARAM_NAMES = [ + "shape", + "compressed", + "bitmask", + ] + + def compress_weight(self, name, value): + bitmask_tensor = Sparse24BitMaskTensor.from_dense( + value, self.config.sparsity_structure + ) + bitmask_dict = bitmask_tensor.dict(name_prefix=name, device="cpu") + return bitmask_dict + + def decompress_weight(self, weight_data): + data = Sparse24BitMaskTensor.from_compressed_data(**weight_data) + decompressed = data.decompress() + return decompressed + + +@dataclass +class Sparse24BitMaskTensor: + """ + Owns compressions and decompression for a single 2:4 sparse + bitmask compressed tensor. 
+ + :param shape: shape of dense tensor + :param compressed: 2d tensor of non-zero values + :param bitmask: 2d bitmask of non-zero values + """ + + shape: List[int] + compressed: Tensor + bitmask: Tensor + + @staticmethod + def from_dense( + tensor: Tensor, + sparsity_structure: Union[SparsityStructure, str] = SparsityStructure.TWO_FOUR, + ) -> "Sparse24BitMaskTensor": + """ + :param tensor: dense tensor to compress + :return: instantiated compressed tensor + """ + shape = list(tensor.shape) + compressed, bitmask = sparse24_bitmask_compress( + tensor.cpu(), sparsity_structure=sparsity_structure + ) + return Sparse24BitMaskTensor( + shape=shape, + compressed=compressed, + bitmask=bitmask, + ) + + @staticmethod + def from_compressed_data( + shape: Union[List[int], Tensor], compressed: Tensor, bitmask: Tensor + ) -> "Sparse24BitMaskTensor": + """ + :param shape: shape of the dense tensor (can be a list or a tensor) + :param compressed: 2d tensor of non-zero values + :param bitmask: 2d bitmask of non-zero values + :return: instantiated Sparse24BitMaskTensor + """ + if isinstance(shape, Tensor): + shape = shape.tolist() + return Sparse24BitMaskTensor( + shape=shape, compressed=compressed, bitmask=bitmask + ) + + def decompress(self) -> Tensor: + """ + :return: reconstructed dense tensor + """ + return sparse24_bitmask_decompress(self.compressed, self.bitmask, self.shape) + + def curr_memory_size_bytes(self) -> int: + """ + :return: size in bytes required to store compressed tensor on disk + """ + + def sizeof_tensor(a: Tensor) -> int: + return a.element_size() * a.nelement() + + return sizeof_tensor(self.compressed) + sizeof_tensor(self.bitmask) + + def dict(self, name_prefix: str, device: str = "cpu") -> Dict[str, Tensor]: + """ + :param name_prefix: name of original tensor to store compressed weight as + :return: dict of compressed data for the stored weight + """ + if name_prefix.endswith(".weight"): + name_prefix = name_prefix[: -len(".weight")] + return { + merge_names(name_prefix, "shape"): torch.tensor( + self.shape, device=device + ).reshape(-1, 1), + merge_names(name_prefix, "compressed"): self.compressed.to(device), + merge_names(name_prefix, "bitmask"): self.bitmask.to(device), + } + + def __repr__(self) -> str: + return f"BitMaskTensor(shape={self.shape}, compressed=True)" + + +def sparse24_bitmask_compress( + tensor: Tensor, + sparsity_structure: Union[SparsityStructure, str] = SparsityStructure.TWO_FOUR, +) -> Tuple[Tensor, Tensor, Tensor]: + """ + Compresses a dense tensor using bitmask compression + + :param tensor: dense 2D tensor to compress + :param sparsity_structure: structure of sparsity in the tensor, defaults + to unstructured, can also be set to `2:4` + :return: tuple of compressed data representing tensor + """ + assert len(tensor.shape) == 2, "Only 2D tensors are supported" + assert ( + SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR + ), "Only 2:4 sparsity is supported" + + bytemasks = get_24_bytemasks(tensor=tensor) + + if tensor.dtype == FP8_DTYPE: + # acces raw bytes of the tensor + tensor_view = tensor.view(torch.int8) + values = tensor_view[bytemasks] + values = values.view(FP8_DTYPE) + else: + values = tensor[bytemasks] + + num_rows, num_cols = tensor.shape + compressed_values = values.reshape(num_rows, num_cols // 2) + bitmasks_packed = pack_bitmasks(bytemasks) + return compressed_values, bitmasks_packed + + +def sparse24_bitmask_decompress( + values: Tensor, bitmasks: Tensor, original_shape: torch.Size +) -> Tensor: + """ + Reconstructs a 
dense tensor from a compressed one + + :param values: 1d tensor of non-zero values + :param bitmasks: 2d int8 tensor flagging locations of non-zero values in the + tensors original shape + :param original_shape: shape of the dense tensor + :return: decompressed dense tensor + """ + bytemasks_unpacked = unpack_bitmasks(bitmasks, original_shape) + + decompressed_tensor = torch.zeros(original_shape, dtype=values.dtype) + decompressed_tensor = decompressed_tensor.to(values.device) + values = values.flatten() + if decompressed_tensor.dtype == FP8_DTYPE: + decompressed_tensor[bytemasks_unpacked] = values + decompressed_tensor = decompressed_tensor.cuda() + else: + decompressed_tensor[bytemasks_unpacked] = values + return decompressed_tensor + + +def get_24_bytemasks(tensor): + """ + Generate a 2:4 sparsity mask for the given tensor. + + This function creates a mask where exactly 2 out of every 4 elements are + preserved based on their magnitudes. The preserved elements are the ones + with the highest absolute values in each group of 4 elements. + + :param tensor: The input tensor for which the 2:4 sparsity mask is to be created. + The tensor can be of any shape but its total number of elements + must be a multiple of 4. + :return: A boolean tensor of the same shape as the input tensor, where `True` + indicates the preserved elements and `False` indicates the pruned elements. + :raises ValueError: If the total number of elements in the tensor is not a + multiple of 4. + """ + original_dtype = tensor.dtype + if tensor.dtype == FP8_DTYPE: + tensor = tensor.view(torch.int8) + original_shape = tensor.shape + num_elements = tensor.numel() + + if num_elements % 4 != 0: + raise ValueError("Tensor size must be a multiple of 4 for TWO_FOUR sparsity") + + reshaped_tensor = tensor.view(-1, 4) + abs_tensor = reshaped_tensor.abs() + topk_indices = abs_tensor.topk(2, dim=1).indices + mask = torch.zeros_like(reshaped_tensor, dtype=torch.bool) + mask.scatter_(1, topk_indices, True) + mask = mask.view(original_shape) + tensor = tensor.view(original_dtype) + + return mask diff --git a/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py b/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py index a950aa64..7c2023cf 100644 --- a/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +++ b/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py @@ -14,12 +14,12 @@ from typing import Dict, List, Tuple, Union -import numpy import torch from compressed_tensors.compressors.base import BaseCompressor from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor from compressed_tensors.config import CompressionFormat -from compressed_tensors.utils import merge_names +from compressed_tensors.quantization import FP8_DTYPE +from compressed_tensors.utils import merge_names, pack_bitmasks, unpack_bitmasks from torch import Tensor @@ -28,8 +28,6 @@ "BitmaskTensor", "bitmask_compress", "bitmask_decompress", - "pack_bitmasks", - "unpack_bitmasks", ] @@ -134,9 +132,14 @@ def bitmask_compress(tensor: Tensor) -> Tuple[Tensor, Tensor, Tensor]: bytemasks = tensor != 0 row_counts = bytemasks.sum(dim=-1) row_offsets = torch.cumsum(row_counts, 0) - row_counts - values = tensor[bytemasks] + if tensor.dtype == FP8_DTYPE: + # acces raw bytes of the tensor + tensor_view = tensor.view(torch.int8) + values = tensor_view[bytemasks] + values = values.view(FP8_DTYPE) + else: + values = tensor[bytemasks] bitmasks_packed = pack_bitmasks(bytemasks) - 
return values, bitmasks_packed, row_offsets @@ -158,37 +161,3 @@ def bitmask_decompress( decompressed_tensor[bytemasks_unpacked] = values return decompressed_tensor - - -def pack_bitmasks(bytemasks: Tensor) -> Tensor: - """ - Converts a bytemask tensor to a bitmask tensor to reduce memory. Shape RxC will be - compressed to R x ceil(C/8) - :param bytemasks: mask tensor where each byte corresponds to a weight - :return: mask tensor where each bit corresounds to a weight - """ - packed_bits_numpy = numpy.packbits(bytemasks.numpy(), axis=-1, bitorder="little") - packed_bits_torch = torch.from_numpy(packed_bits_numpy) - - return packed_bits_torch - - -def unpack_bitmasks(packed_bitmasks: Tensor, original_shape: torch.Size) -> Tensor: - """ - Converts a bitmask tensor back to a bytemask tensor for use during decompression - - :param packed_bitmasks: mask tensor where each bit corresponds to a weight - :param original_shape: dense shape to decompress to - :return: boolean mask of weights in the original dense shape - """ - # Unpack the bits - unpacked_bits = numpy.unpackbits( - packed_bitmasks.numpy(), axis=-1, count=original_shape[-1], bitorder="little" - ) - - # Reshape to match the original shape - unpacked_bitmasks_torch = torch.from_numpy( - unpacked_bits.reshape(original_shape).astype(bool) - ) - - return unpacked_bitmasks_torch diff --git a/src/compressed_tensors/config/__init__.py b/src/compressed_tensors/config/__init__.py index ff83f5af..582b8a9e 100644 --- a/src/compressed_tensors/config/__init__.py +++ b/src/compressed_tensors/config/__init__.py @@ -15,4 +15,5 @@ # flake8: noqa from .base import * from .dense import * +from .sparse_24_bitmask import * from .sparse_bitmask import * diff --git a/src/compressed_tensors/config/base.py b/src/compressed_tensors/config/base.py index 79a4fcdd..9ca6f2cf 100644 --- a/src/compressed_tensors/config/base.py +++ b/src/compressed_tensors/config/base.py @@ -26,6 +26,7 @@ class CompressionFormat(Enum): dense = "dense" sparse_bitmask = "sparse-bitmask" + sparse_24_bitmask = "sparse-24-bitmask" int_quantized = "int-quantized" float_quantized = "float-quantized" naive_quantized = "naive-quantized" diff --git a/src/compressed_tensors/config/sparse_24_bitmask.py b/src/compressed_tensors/config/sparse_24_bitmask.py new file mode 100644 index 00000000..7aae2dbe --- /dev/null +++ b/src/compressed_tensors/config/sparse_24_bitmask.py @@ -0,0 +1,40 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Optional + +from compressed_tensors.config import ( + CompressionFormat, + SparsityCompressionConfig, + SparsityStructure, +) + + +__all__ = ["Sparse24BitMaskConfig"] + + +@SparsityCompressionConfig.register(name=CompressionFormat.sparse_24_bitmask.value) +class Sparse24BitMaskConfig(SparsityCompressionConfig): + """ + Configuration for storing a 24 sparse model using + bytemask compression + + :param global_sparsity: average sparsity of the entire model + :param sparsity_structure: structure of the sparsity, should always be + "2:4" for this compression format + """ + + format: str = CompressionFormat.sparse_24_bitmask.value + global_sparsity: Optional[float] = 0.0 + sparsity_structure: Optional[str] = SparsityStructure.TWO_FOUR.value diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index ed9a50f7..31f14df0 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -18,7 +18,7 @@ from copy import deepcopy from typing import Dict, Iterable, List, Optional from typing import OrderedDict as OrderedDictType -from typing import Union +from typing import Set, Union import torch from compressed_tensors.config import CompressionFormat @@ -52,6 +52,8 @@ "apply_quantization_config", "apply_quantization_status", "find_name_or_class_matches", + "expand_sparse_target_names", + "is_sparse_target", ] from compressed_tensors.quantization.utils.helpers import is_module_quantized @@ -245,6 +247,49 @@ def apply_quantization_status(model: Module, status: QuantizationStatus): model.apply(compress_quantized_weights) +def expand_sparse_target_names( + model: Module, targets: Iterable[str], ignore: Iterable[str] +) -> Set[str]: + """ + Finds all unique module names in the model that match the given + targets and ignore lists. + + Note: Targets must be regexes, layer types, or full layer names. + + :param model: model to search for targets in + :param targets: list of targets to search for + :param ignore: list of targets to ignore + :return: set of all targets that match the given targets and should + not be ignored + """ + return { + name + for name, module in iter_named_leaf_modules(model) + if is_sparse_target(name, module, targets, ignore) + } + + +def is_sparse_target( + name: str, module: Module, targets: Iterable[str], ignore: Iterable[str] +) -> bool: + """ + Determines if a module should be included in the targets based on the + targets and ignore lists. + + Note: Targets must be regexes, layer types, or full layer names. 
+ + :param name: name of the module + :param module: the module itself + :param targets: list of targets to search for + :param ignore: list of targets to ignore + :return: True if the module is a target and not ignored, False otherwise + """ + return bool( + find_name_or_class_matches(name, module, targets) + and not find_name_or_class_matches(name, module, ignore or []) + ) + + def find_name_or_class_matches( name: str, module: Module, targets: Iterable[str], check_contains: bool = False ) -> List[str]: diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index dcab122a..f4f93f27 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -82,8 +82,8 @@ def quantize( def dequantize( x_q: torch.Tensor, scale: torch.Tensor, - zero_point: torch.Tensor = None, - args: QuantizationArgs = None, + zero_point: Optional[torch.Tensor] = None, + args: Optional[QuantizationArgs] = None, dtype: Optional[torch.dtype] = None, g_idx: Optional[torch.Tensor] = None, ) -> torch.Tensor: diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index eb4d6b18..8dd8fc51 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -29,7 +29,11 @@ from compressed_tensors.quantization.quant_config import QuantizationStatus from compressed_tensors.quantization.quant_scheme import QuantizationScheme from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme -from compressed_tensors.utils import get_execution_device, is_module_offloaded +from compressed_tensors.utils import ( + disable_hf_hook, + has_offloaded_params, + register_offload_parameter, +) from torch.nn import Module, Parameter @@ -112,43 +116,10 @@ def initialize_module_for_quantization( module.quantization_scheme = scheme module.quantization_status = QuantizationStatus.INITIALIZED - offloaded = False - # What is this doing/why isn't this in the attn case? - if is_module_offloaded(module): - try: - from accelerate.hooks import add_hook_to_module, remove_hook_from_module - from accelerate.utils import PrefixedDataset - except ModuleNotFoundError: - raise ModuleNotFoundError( - "Offloaded model detected. 
To use CPU offloading with " - "compressed-tensors the `accelerate` package must be installed, " - "run `pip install compressed-tensors[accelerate]`" - ) - - offloaded = True - hook = module._hf_hook - prefix_dict = module._hf_hook.weights_map - new_prefix = {} - - # recreate the prefix dict (since it is immutable) - # and add quantization parameters - for key, data in module.named_parameters(): - if key not in prefix_dict: - new_prefix[f"{prefix_dict.prefix}{key}"] = data - else: - new_prefix[f"{prefix_dict.prefix}{key}"] = prefix_dict[key] - new_prefix_dict = PrefixedDataset(new_prefix, prefix_dict.prefix) - remove_hook_from_module(module) - - # wrap forward call of module to perform - # quantized actions based on calltime status - wrap_module_forward_quantized(module, scheme) - - if offloaded: - # we need to re-add the hook for offloading now that we've wrapped forward - add_hook_to_module(module, hook) - if prefix_dict is not None: - module._hf_hook.weights_map = new_prefix_dict + with disable_hf_hook(module): + # wrap forward call of module to perform + # quantized actions based on calltime status + wrap_module_forward_quantized(module, scheme) def is_attention_module(module: Module): @@ -169,12 +140,17 @@ def _initialize_scale_zero_point( if quantization_args.dynamic: return - device = next(module.parameters()).device - if is_module_offloaded(module): - device = get_execution_device(module) + # begin on the same device as other parameters or cpu if offloaded. + # in the offloaded case, there's no point moving tensors to the execution device + # if they're going to be immediately offloaded by `register_offload_parameter` + params_device = next(module.parameters()).device + device = "cpu" if has_offloaded_params(module) else params_device # infer expected scale/zero point shape - expected_shape = 1 # per tensor + if quantization_args.strategy == QuantizationStrategy.TOKEN: + expected_shape = (1, 1) + else: + expected_shape = 1 if base_name == "weight" and weight_shape is not None: if quantization_args.strategy == QuantizationStrategy.CHANNEL: @@ -193,7 +169,7 @@ def _initialize_scale_zero_point( torch.empty(expected_shape, dtype=scale_dtype, device=device), requires_grad=False, ) - module.register_parameter(f"{base_name}_scale", init_scale) + register_offload_parameter(module, f"{base_name}_scale", init_scale) if force_zero_point or not quantization_args.symmetric: zp_dtype = quantization_args.pytorch_dtype() @@ -201,7 +177,7 @@ def _initialize_scale_zero_point( torch.zeros(expected_shape, device=device, dtype=zp_dtype), requires_grad=False, ) - module.register_parameter(f"{base_name}_zero_point", init_zero_point) + register_offload_parameter(module, f"{base_name}_zero_point", init_zero_point) # only grouped activation ordering has g_idx if quantization_args.actorder == ActivationOrdering.GROUP: @@ -211,7 +187,7 @@ def _initialize_scale_zero_point( torch.full(g_idx_shape, -1, device=device, dtype=g_idx_dtype), requires_grad=False, ) - module.register_parameter(f"{base_name}_g_idx", init_g_idx) + register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx) def _initialize_attn_scales(module: Module) -> None: diff --git a/src/compressed_tensors/quantization/quant_config.py b/src/compressed_tensors/quantization/quant_config.py index 1d95aee8..3a80f0cb 100644 --- a/src/compressed_tensors/quantization/quant_config.py +++ b/src/compressed_tensors/quantization/quant_config.py @@ -160,7 +160,7 @@ def model_post_init(self, __context): def to_dict(self): # for compatibility with 
HFQuantizer - return self.dict() + return self.model_dump() @staticmethod def from_pretrained( diff --git a/src/compressed_tensors/utils/helpers.py b/src/compressed_tensors/utils/helpers.py index 9ec522e0..48cd3bd7 100644 --- a/src/compressed_tensors/utils/helpers.py +++ b/src/compressed_tensors/utils/helpers.py @@ -12,8 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Any, Dict, Optional +import warnings +from functools import wraps +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional +import numpy import torch from transformers import AutoConfig @@ -28,7 +31,13 @@ "tensor_follows_mask_structure", "replace_module", "is_compressed_tensors_config", + "getattr_chain", + "deprecated", "Aliasable", + "combine_shards", + "shard_tensor", + "pack_bitmasks", + "unpack_bitmasks", ] FSDP_WRAPPER_NAME = "_fsdp_wrapped_module" @@ -126,6 +135,65 @@ def is_compressed_tensors_config(compression_config: Any) -> bool: return False +def getattr_chain(obj: Any, chain_str: str, *args, **kwargs) -> Any: + """ + Chain multiple getattr calls, separated by `.` + + :param obj: base object whose attributes are being retrieved + :param chain_str: attribute names separated by `.` + :param default: default value, throw error otherwise + """ + if len(args) >= 1: + has_default = True + default = args[0] + elif "default" in kwargs: + has_default = True + default = kwargs["default"] + else: + has_default = False + + attr_names = chain_str.split(".") + + res = obj + for attr_name in attr_names: + if not hasattr(res, attr_name): + if has_default: + return default + else: + raise AttributeError(f"{res} object has no attribute {attr_name}") + res = getattr(res, attr_name) + + return res + + +def deprecated(future_name: Optional[str] = None, message: Optional[str] = None): + """ + Decorator to mark functions as deprecated + + :param new_function: Function called in place of depreciated function + :param message: Depreciation message, replaces default depreciation message + """ + + def decorator(func: Callable[[Any], Any]): + nonlocal message + + if message is None: + message = ( + f"{func.__name__} is deprecated and will be removed in a future release" + ) + if future_name is not None: + message += f". Please use {future_name} instead." + + @wraps(func) + def wrapped(*args, **kwargs): + warnings.warn(message, DeprecationWarning, stacklevel=2) + return func(*args, **kwargs) + + return wrapped + + return decorator + + class Aliasable: """ A mixin for enums to allow aliasing of enum members @@ -155,3 +223,108 @@ def __eq__(self, other): def __hash__(self): canonical_value = self.aliases.get(self.value, self.value) return hash(canonical_value) + + +def shard_tensor( + tensor: torch.Tensor, shard_sizes: List[int], dim: int = 0 +) -> List[torch.Tensor]: + """ + Shards a tensor into a list of tensors along a given dimension. + + raises: ValueError: If the sum of shard_sizes does not match the + size of the tensor along the given dimension. + + :param tensor: The input tensor to shard. + :param shard_sizes : List of sizes for each shard along the specified dimension. + :param dim : The dimension along which to shard the tensor. + :returns: A list of tensors sharded along the specified dimension. + """ + if sum(shard_sizes) != tensor.size(dim): + raise ValueError( + "Sum of shard_sizes must equal the size of the tensor " + "along the specified dimension." 
+ ) + + shards = [] + start_idx = 0 + + for size in shard_sizes: + end_idx = start_idx + size + shard = tensor.narrow(dim, start_idx, size) + shards.append(shard) + start_idx = end_idx + + return shards + + +def combine_shards(shards, dim=0): + """ + Combine decompressed shards along a given dimension using `narrow`. + + :param shards: List of decompressed shard tensors. + :param dim: Dimension to combine along (default: 0). + :return: Combined decompressed tensor. + """ + if not shards: + raise ValueError("The list of shards is empty.") + + # Assert that all shards have the same dtype + shard_dtypes = {shard.dtype for shard in shards} + if len(shard_dtypes) > 1: + raise ValueError("All shards must have the same dtype.") + + # Determine the total shape of the combined tensor + total_shape = list(shards[0].shape) + total_shape[dim] = sum(shard.shape[dim] for shard in shards) + + # Create the combined tensor + combined = torch.zeros(total_shape, dtype=shards[0].dtype, device=shards[0].device) + + # Fill the combined tensor using narrow + shard_offset = 0 + for shard in shards: + shard_size = shard.shape[dim] + combined.narrow(dim, shard_offset, shard_size).copy_(shard) + shard_offset += shard_size + + return combined + + +def pack_bitmasks(bytemasks: torch.Tensor) -> torch.Tensor: + """ + Converts a bytemask tensor to a bitmask tensor to reduce memory. Shape RxC will be + compressed to R x ceil(C/8) + + :param bytemasks: mask tensor where each byte corresponds to a weight + :return: mask tensor where each bit corresounds to a weight + """ + packed_bits_numpy = numpy.packbits(bytemasks.numpy(), axis=-1, bitorder="little") + packed_bits_torch = torch.from_numpy(packed_bits_numpy) + + return packed_bits_torch + + +def unpack_bitmasks( + packed_bitmasks: torch.Tensor, original_shape: torch.Size +) -> torch.Tensor: + """ + Converts a bitmask tensor back to a bytemask tensor for use during decompression + + :param packed_bitmasks: mask tensor where each bit corresponds to a weight + :param original_shape: dense shape to decompress to + :return: boolean mask of weights in the original dense shape + """ + # Unpack the bits + unpacked_bits = numpy.unpackbits( + packed_bitmasks.cpu().numpy(), + axis=-1, + count=original_shape[-1], + bitorder="little", + ) + + # Reshape to match the original shape + unpacked_bitmasks_torch = torch.from_numpy( + unpacked_bits.reshape(original_shape).astype(bool) + ) + + return unpacked_bitmasks_torch diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index 9dd7b22d..b3c77c58 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -11,9 +11,48 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +Utilities associated with offloading functionality provided by `accelerate`. 
+ +| ----------------------------------------------------------------------------------------------------- | # noqa: E501 +| Operation | Without offloading support | With offloading support | # noqa: E501 +| --------- | -------------------------------------- | ------------------------------------------------ | # noqa: E501 +| Add | module.register_parameter(name, param) | register_offload_parameter(module, name, param) | # noqa: E501 +| Check | N/A | has_offloaded_params(module) | # noqa: E501 +| Onload | N/A | with align_module_device(module) | # noqa: E501 +| Update | module.name.data.copy_(new_data) | update_offload_parameter(module, name, new_data) | # noqa: E501 +| Delete | del module.name | delete_offload_parameter(module, name) | # noqa: E501 +| ----------------------------------------------------------------------------------------------------- | # noqa: E501 +""" + +import contextlib +from functools import wraps +from typing import Any, Callable, Dict, Literal, Optional, Union import torch -from torch.nn import Module + + +try: + from accelerate.hooks import ( + AlignDevicesHook, + add_hook_to_module, + remove_hook_from_module, + ) + from accelerate.utils import ( + OffloadedWeightsLoader, + PrefixedDataset, + set_module_tensor_to_device, + ) + + _has_accelerate = True +except ImportError: + _has_accelerate = False + AlignDevicesHook = None + add_hook_to_module = None + remove_hook_from_module = None + OffloadedWeightsLoader = None + PrefixedDataset = None + set_module_tensor_to_device = None __all__ = [ @@ -22,23 +61,44 @@ "get_offloaded_device", "update_prefix_dict", "update_parameter_data", + "register_offload_parameter", + "update_offload_parameter", + "delete_offload_parameter", + "has_offloaded_params", + "disable_hf_hook", + "align_module_device", ] -def is_module_offloaded(module: Module) -> bool: - """ - :param module: layer to check - :return: True if layer is offloaded from GPU, False otherwise - """ - return hasattr(module, "_hf_hook") and module._hf_hook.offload +def check_accelerate(fallback: Any): + def decorator(func: Callable[[Any], Any]): + if not _has_accelerate: + + @wraps(func) + def fallback_fn(*args, **kwargs): + return fallback + + return fallback_fn + + return func + return decorator -def get_execution_device(module: Module) -> torch.device: + +""" Candidates for Depreciation """ + + +@check_accelerate(fallback=False) +def is_module_offloaded(module: torch.nn.Module) -> bool: + return has_offloaded_params(module) + + +def get_execution_device(module: torch.nn.Module) -> torch.device: """ - :param module: layer to check - :return: device layer is loaded onto during forward pass + :param module: module to check + :return: device module is loaded onto during forward pass """ - if is_module_offloaded(module): + if has_offloaded_params(module): return module._hf_hook.execution_device device = next(module.parameters()).device @@ -49,68 +109,296 @@ def get_execution_device(module: Module) -> torch.device: return device -def get_offloaded_device(module: Module) -> torch.device: +def get_offloaded_device(module: torch.nn.Module) -> torch.device: """ - :param module: layer to check - :return: device layer is offloaded to onto after forward pass + :param module: module to check + :return: device module is offloaded to onto after forward pass """ - if is_module_offloaded(module): + if has_offloaded_params(module): first_key = list(module._hf_hook.weights_map.keys())[0] prefix_dataset = module._hf_hook.weights_map.dataset return prefix_dataset[first_key].device return 
next(module.parameters()).device -def update_prefix_dict(module: Module, key: str, data: torch.Tensor): +@check_accelerate(fallback=None) +def update_prefix_dict(module: torch.nn.Module, key: str, data: torch.Tensor): """ Updates the offloaded state dict for a given module. Parameter named key is replaced by data. This is neccesary because parameter updates for offloaded modules do not persist automatically between loads. This function only affects the offloaded state dict and not the current state of the loaded module. - :param module: layer containing the parameter to update + :param module: module containing the parameter to update :param key: name of parameter to update :param data: tensor to update parameter with in the offloaded state dict """ - if not is_module_offloaded(module): + if not has_offloaded_params(module): raise ValueError("Prefix dict is only applicable to offloaded modules") - prefix_dict = module._hf_hook.weights_map - prefix_dict.dataset[f"{prefix_dict.prefix}{key}"] = data + + weights_map = module._hf_hook.weights_map + offload_to_weights_map(weights_map, key, data) def update_parameter_data( - module: Module, new_param_data: torch.Tensor, param_name: str + module: torch.nn.Module, new_param_data: torch.Tensor, param_name: str ): """ - Updates the paramter value named param_name for a given module. This function - updates both the current loaded module state and the offloaded state dict if - the module is offloaded. This is neccesary because parameter updates for offloaded - modules do not persist automatically between loads. + Update the data of an existing parameter and its offload dict. Supports both + parameters of offloaded modules and non-offloaded modules - :param module: layer containing the parameter to update + :param module: module containing the parameter to update :param new_param_data: tensor to update parameter with - :param param_name: name of layer parameter to update + :param param_name: name of module parameter to update """ - if not hasattr(module, param_name): - return + update_offload_parameter(module, param_name, new_param_data) + + +""" Candidates for Upstreaming """ + + +def register_offload_parameter( + module: torch.nn.Module, + name: str, + parameter: torch.nn.Parameter, + offload_device: Optional[Union[torch.device, Literal["disk"]]] = None, +): + """ + Register a parameter to the given module which may be offloaded + + :param module: maybe offloaded module + :param name: name of newly registered parameter + :param parameter: parameter being registered + :param offload_device: device on which weight will be offloaded to. If None is + provided, then infer device from parameters on module + """ + has_onload = any(p.device != torch.device("meta") for p in module.parameters()) + module.register_parameter(name, parameter) + + if has_offloaded_params(module): + weights_map = module._hf_hook.weights_map + offload_to_weights_map(weights_map, name, parameter.data, offload_device) + if not has_onload: + set_module_tensor_to_device(module, name, "meta") + + +def update_offload_parameter( + module: torch.nn.Module, + name: str, + data: Optional[torch.Tensor], + offload_device: Optional[Union[torch.device, Literal["disk"]]] = None, +): + """ + Update the data of an existing parameter and its offload dict. 
Supports both + parameters of offloaded modules and non-offloaded modules + + :param module: module containing the parameter to update + :param name: name of module parameter to update + :param data: tensor to update parameter with + :param offload_device: device on which weight will be offloaded to. If None is + provided, then infer device from parameters on module + """ + param = getattr(module, name) + data = data.to(param.dtype) + + # copy data into onloaded parameter if applicable + if param.device != "meta": + param.data.copy_(data) + + # update offload dict + if has_offloaded_params(module): + weights_map = module._hf_hook.weights_map + offload_to_weights_map(weights_map, name, data, offload_device) + + +def delete_offload_parameter(module: torch.nn.Module, name: str): + """ + Delete a parameter from a module which may be offloaded + + :param module: maybe offloaded module + :param name: name of parameter being deleted + """ + delattr(module, name) + + if has_offloaded_params(module): + weights_map = module._hf_hook.weights_map + delete_from_weights_map(weights_map, name) - device = next(module.parameters()).device - offloaded = False - if is_module_offloaded(module): - offload_device = get_offloaded_device(module) - offloaded = True +@check_accelerate(fallback=contextlib.nullcontext()) +@contextlib.contextmanager +def disable_hf_hook(module: torch.nn.Module): + hooks = {} - parameter = getattr(module, param_name, None) - if parameter is None: - raise ValueError("Attempted to update uninitialized parameter") + def collect_hooks(module): + nonlocal hooks + if hasattr(module, "_hf_hook"): + hooks[module] = module._hf_hook + remove_hook_from_module(module) - dtype = parameter.dtype - parameter.data = new_param_data.to(device).to(dtype) + module.apply(collect_hooks) - if offloaded: - prefix_dict = module._hf_hook.weights_map.dataset - prefix = module._hf_hook.weights_map.prefix - prefix_dict[f"{prefix}{param_name}"] = new_param_data.to(offload_device).to( - dtype + yield + + for submodule, hook in hooks.items(): + add_hook_to_module(submodule, hook) + + +@check_accelerate(fallback=None) +def offload_to_weights_map( + weights_map: Union[PrefixedDataset, Dict, OffloadedWeightsLoader], + key: str, + value: torch.Tensor, + offload_device: Optional[Union[torch.device, Literal["disk"]]] = None, +): + """ + Helper function which implements offloaded item assignment for PrefixedDataset, + OffloadedWeightsLoader, and Dict types. + + :param weights_map: weight map to be updated with offload information + :param key: key used to identify weight location + :param value: weight being offloaded + :param offload_device: device on which weight will be offloaded to. 
If None is + provided, then infer device from parameters in weights_map + """ + if isinstance(weights_map, PrefixedDataset): + if offload_device == "disk": + raise ValueError(f"Cannot offload to disk with type {type(weights_map)}") + + dataset = weights_map.dataset + key = f"{weights_map.prefix}{key}" + offload_to_weights_map(dataset, key, value, offload_device) + + elif isinstance(weights_map, OffloadedWeightsLoader): + if key not in weights_map.all_keys: + weights_map.all_keys.append(key) + + if len(weights_map.index) <= 0 and offload_device != "disk": + offload_to_weights_map(weights_map.state_dict, key, value, offload_device) + + else: + raise NotImplementedError( + "Updating weights_map with disk offloading is not implemented yet" + ) + + elif isinstance(weights_map, dict): + if offload_device == "disk": + raise ValueError(f"Cannot offload to disk with type {type(weights_map)}") + + # infer offload device + if offload_device is None: + if key in weights_map: + offload_device = weights_map[key].device + else: + tens = next(iter(weights_map.values()), None) + if tens is None: + raise ValueError( + "Cannot infer offload device from empty weights_map" + ) + offload_device = tens.device + + weights_map[key] = value.to(device=offload_device) + + else: + raise NotImplementedError( + "Updating offload data not implemented for weights_map of type " + f"{type(weights_map)}" + ) + + +@check_accelerate(fallback=None) +def delete_from_weights_map( + weights_map: Union[PrefixedDataset, Dict, OffloadedWeightsLoader], + key: str, +): + if isinstance(weights_map, PrefixedDataset): + dataset = weights_map.dataset + key = f"{weights_map.prefix}{key}" + delete_from_weights_map(dataset, key) + + elif isinstance(weights_map, OffloadedWeightsLoader): + if len(weights_map.index) <= 0: + delete_from_weights_map(weights_map.state_dict, key) + + else: + raise NotImplementedError( + "Delete from weights_map with disk offloading is not implemented yet" + ) + + elif isinstance(weights_map, dict): + del weights_map[key] + + else: + raise NotImplementedError( + "Updating offload data not implemented for weights_map of type " + f"{type(weights_map)}" ) + + +""" Upstreamed Functions """ + + +# introduced in accelerate v1.1.0 +@check_accelerate(fallback=False) +def has_offloaded_params(module: torch.nn.Module) -> bool: + """ + Checks if a module has offloaded parameters by checking if the given module has a + AlignDevicesHook attached with offloading enabled + + Args: + module (`torch.nn.Module`): The module to check for an offload hook. + + Returns: + bool: `True` if the module has an offload hook and offloading is enabled, + `False` otherwise. + """ + return ( + hasattr(module, "_hf_hook") + and isinstance(module._hf_hook, AlignDevicesHook) + and module._hf_hook.offload + ) + + +# introduced in accelerate v1.1.0 +@check_accelerate(fallback=contextlib.nullcontext()) +@contextlib.contextmanager +def align_module_device( + module: torch.nn.Module, execution_device: Optional[torch.device] = None +): + """ + Context manager that moves a module's parameters to the specified execution device. + + Args: + module (`torch.nn.Module`): + Module with parameters to align. + execution_device (`torch.device`, *optional*): + If provided, overrides the module's execution device within the context. 
+ Otherwise, use hook execution device or pass + """ + if has_offloaded_params(module): + if execution_device is not None: + original_device = module._hf_hook.execution_device + module._hf_hook.execution_device = execution_device + + try: + module._hf_hook.pre_forward(module) + yield + finally: + module._hf_hook.post_forward(module, None) + if execution_device is not None: + module._hf_hook.execution_device = original_device + + elif execution_device is not None: + devices = { + name: param.device for name, param in module.named_parameters(recurse=False) + } + try: + for name in devices: + set_module_tensor_to_device(module, name, execution_device) + yield + finally: + for name, device in devices.items(): + set_module_tensor_to_device(module, name, device) + + else: + yield diff --git a/src/compressed_tensors/utils/safetensors_load.py b/src/compressed_tensors/utils/safetensors_load.py index 4fdb3007..ab4d04bf 100644 --- a/src/compressed_tensors/utils/safetensors_load.py +++ b/src/compressed_tensors/utils/safetensors_load.py @@ -16,7 +16,7 @@ import os import re import struct -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple, Union from safetensors import safe_open from torch import Tensor @@ -30,10 +30,14 @@ "merge_names", "get_weight_mappings", "get_nested_weight_mappings", + "get_nested_mappings_from_state_dict", "get_quantization_state_dict", "is_quantization_param", ] +WeightMappingType = Dict[str, str] +NestedWeightMappingType = Dict[str, WeightMappingType] + def get_safetensors_folder( pretrained_model_name_or_path: str, cache_dir: Optional[str] = None @@ -92,7 +96,7 @@ def get_safetensors_header(safetensors_path: str) -> Dict[str, str]: return header -def match_param_name(full_name: str, param_name: str) -> str: +def match_param_name(full_name: str, param_name: str) -> Optional[str]: """ Helper function extracting the uncompressed parameterized layer name from a compressed name. Assumes the compressed name was merged using merge_names. @@ -176,38 +180,100 @@ def get_weight_mappings(path_to_model_or_tensors: str) -> Dict[str, str]: def get_nested_weight_mappings( - model_path: str, params_to_nest: List[str] -) -> Dict[str, Dict[str, str]]: + model_path: str, params_to_nest: List[str], return_unmatched_params: bool = False +) -> Union[NestedWeightMappingType, Tuple[NestedWeightMappingType, WeightMappingType]]: """ Takes a path to a state dict saved in safetensors format and returns a nested - mapping from uncompressed parameterized layer names to the file locations of each - of the layers compression parameters. + mapping from uncompressed parameterized layer names to the file locations of + each layer's compression parameters. - layer.weight: { + Example of the nested mapping: + layer: { bitmask: file_location, row_offsets: file_location, shape: file_location, compressed: file_location } - This generalizes to cases where the model is split into multiple safetensors files + If other parameters are found that do not match the nested parameters, they will + be returned in a separate dictionary only if return_unmatched_params is True. + This dictionary may be needed for cases where compressors are stacked (e.g., + quantization compression followed by sparse compression). 
+ + Example of the unmatched params mapping: + { + layer.weight_scale: file_location, + layer.input_scale: file_location + } - :param model_path: path to safetensors state dict, must contain either a single - safetensors file or multiple files with an index - :return: nested mapping of parameterized layer name to file location + This generalizes to cases where the model is split into multiple safetensors + files. + + :param model_path: Path to the safetensors state dict, must contain either a + single safetensors file or multiple files with an index. + :param params_to_nest: List of parameter names to nest. + :param return_unmatched_params: If True, return a second dictionary containing + the remaining parameters that were not matched to the params_to_nest. + :return: + - If return_unmatched_params is False: + NestedWeightMappingType: A nested mapping of parameterized layer names to + file locations of each layer's compression parameters. + - If return_unmatched_params is True: + Tuple[NestedWeightMappingType, WeightMappingType]: A tuple containing: + - NestedWeightMappingType: A nested mapping of parameterized layer + names to file locations of each layer's compression parameters. + - WeightMappingType: A mapping of the remaining parameter names to + their file locations that were not matched to the params_to_nest. """ weight_mappings = get_weight_mappings(model_path) - nested_weight_mappings = {} - for key in weight_mappings.keys(): + unmatched_params = {} + + for key, file_location in weight_mappings.items(): + matched = False for param_name in params_to_nest: - maybe_match = match_param_name(key, param_name) - if maybe_match is not None: - dense_param = maybe_match + dense_param = match_param_name(key, param_name) + if dense_param: if dense_param not in nested_weight_mappings: nested_weight_mappings[dense_param] = {} - nested_weight_mappings[dense_param][param_name] = weight_mappings[key] + nested_weight_mappings[dense_param][param_name] = file_location + matched = True + if return_unmatched_params and not matched: + unmatched_params[key] = file_location + + if return_unmatched_params: + return nested_weight_mappings, unmatched_params + return nested_weight_mappings + +def get_nested_mappings_from_state_dict( + state_dict, params_to_nest +) -> NestedWeightMappingType: + """ + Takes a state dict and returns a nested mapping from uncompressed + parameterized layer names to the value of + each layer's compression parameters. + + Example of the nested mapping: + layer: { + weight_scale: ..., + weight: ..., + zero_point: ..., + } + + :param state_dict: state dict of the model + :param params_to_nest: List of parameter names to nest. + :return: Nested mapping of parameterized layer names to the value of + each layer's compression parameters. 
+ """ + nested_weight_mappings = {} + for key in state_dict.keys(): + for param_name in params_to_nest: + dense_param = match_param_name(key, param_name) + if dense_param: + if dense_param not in nested_weight_mappings: + nested_weight_mappings[dense_param] = {} + nested_weight_mappings[dense_param][param_name] = state_dict[key] return nested_weight_mappings diff --git a/src/compressed_tensors/version.py b/src/compressed_tensors/version.py index 8b4ea8c7..73356205 100644 --- a/src/compressed_tensors/version.py +++ b/src/compressed_tensors/version.py @@ -17,7 +17,7 @@ """ -version_base = "0.8.0" +version_base = "0.8.1" is_release = True # change to True to set the generated version as a release version diff --git a/tests/conftest.py b/tests/conftest.py index a1c1d861..492f7af0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -44,8 +44,6 @@ def update_scale_zp(module: torch.nn.Module, base_name: str, value: torch.Tensor min_val = torch.amin(value, dim=dim, keepdims=True) max_val = torch.amax(value, dim=dim, keepdims=True) scale, zp = calculate_qparams(min_val, max_val, args) - scale = scale.reshape((1, 1)) - zp = zp.reshape((1, 1)) update_parameter_data(module, scale, f"{base_name}_scale") update_parameter_data(module, zp, f"{base_name}_zero_point") diff --git a/tests/test_compressors/model_compressors/test_model_compressor.py b/tests/test_compressors/model_compressors/test_model_compressor.py index 3f6940a9..d91ab262 100644 --- a/tests/test_compressors/model_compressors/test_model_compressor.py +++ b/tests/test_compressors/model_compressors/test_model_compressor.py @@ -12,13 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json from copy import deepcopy +from pathlib import Path import pytest +import torch from compressed_tensors.compressors import ModelCompressor from compressed_tensors.config.base import SparsityCompressionConfig from compressed_tensors.quantization.quant_config import QuantizationConfig -from tests.testing_utils import requires_hf_quantizer +from safetensors.torch import save_file +from tests.testing_utils import induce_sparsity, requires_hf_quantizer def sparsity_config(): @@ -99,8 +103,8 @@ def test_hf_compressor_tensors_config(s_config, q_config, tmp_path): ) q_config = QuantizationConfig(**q_config) if q_config is not None else None - s_config_dict = s_config.dict() if s_config is not None else None - q_config_dict = q_config.dict() if q_config is not None else None + s_config_dict = s_config.model_dump() if s_config is not None else None + q_config_dict = q_config.model_dump() if q_config is not None else None assert compressor.sparsity_config == s_config assert compressor.quantization_config == q_config @@ -109,3 +113,144 @@ def test_hf_compressor_tensors_config(s_config, q_config, tmp_path): assert ( ModelCompressor.parse_quantization_config(compression_config) == q_config_dict ) + + +@pytest.fixture +def fake_model_class(): + import torch.nn as nn + + class CustomLinearModel(nn.Module): + def __init__(self, weights, weight_scale=None, weight_zero_point=None): + super(CustomLinearModel, self).__init__() + out_features, in_features = weights.shape + + # Define a linear layer without bias + self.linear = nn.Linear(in_features, out_features, bias=False) + + # Set the weights of the linear layer + self.linear.weight = nn.Parameter(weights, requires_grad=False) + + # Attach weight_scale and weight_zero_point as parameters + if weight_scale is not None: + self.linear.weight_scale = 
nn.Parameter( + torch.tensor(weight_scale), requires_grad=False + ) + if weight_zero_point is not None: + self.linear.weight_zero_point = nn.Parameter( + torch.tensor(weight_zero_point), requires_grad=False + ) + + def forward(self, x): + return self.linear(x) + + return CustomLinearModel + + +def get_bitmask_sparsity_config(): + from compressed_tensors import BitmaskConfig + + return BitmaskConfig( + format="sparse-bitmask", + global_sparsity=0.7, + targets=["Linear"], + sparsity_structure="unstructured", + ) + + +def create_quantization_config(bits=8, type="int", strategy="tensor"): + + config_dict = { + "format": "int-quantized", + "global_compression_ratio": 1.0, + "quant_method": "compressed-tensors", + "config_groups": { + "group_0": { + "targets": ["Linear"], + "weights": { + "num_bits": bits, + "strategy": strategy, + "symmetric": True, + "type": type, + }, + } + }, + } + + return QuantizationConfig.model_validate(config_dict) + + +@pytest.mark.parametrize("sparsity_config", [get_bitmask_sparsity_config()]) +@pytest.mark.parametrize( + "quantization_config", + [ + create_quantization_config(bits=8, type="int", strategy="channel"), + create_quantization_config(bits=8, type="float", strategy="channel"), + ], +) +def test_composability( + tmp_path, fake_model_class, sparsity_config, quantization_config +): + from compressed_tensors.quantization.lifecycle.forward import quantize + + weights = torch.rand(10, 5) + sparse_weights = induce_sparsity(weights, sparsity_config.global_sparsity) + + quantization_args = quantization_config.config_groups["group_0"].weights + + if quantization_args.strategy == "channel": + scale = torch.ones((weights.shape[0], 1)) + elif quantization_args.strategy == "tensor": + scale = torch.tensor([1.0]) + + zero_point = torch.zeros_like(scale) + + quantized_weights = quantize( + sparse_weights, + scale=scale, + zero_point=zero_point, + args=quantization_args, + ) + + fake_oneshot_model = fake_model_class(quantized_weights, scale, zero_point) + fake_oneshot_model.linear.quantization_scheme = quantization_config.config_groups[ + "group_0" + ] + model_compressor = ModelCompressor( + sparsity_config=sparsity_config, quantization_config=quantization_config + ) + # does both sparse and quantization compression + compressed_state_dict = model_compressor.compress(fake_oneshot_model) + + save_dir = tmp_path / "model" + save_dir = _create_dummy_checkpoint( + compressed_state_dict, save_dir, model_compressor + ) + + decompressed_model = fake_model_class(torch.zeros_like(weights)) + model_compressor.decompress(model=decompressed_model, model_path=save_dir) + + # check that the decompressed model is the same as the original model + _check_state_dicts(fake_oneshot_model.state_dict(), decompressed_model.state_dict()) + + +def _create_dummy_checkpoint(state_dict, save_dir, model_compressor): + save_dir = Path(save_dir) + save_dir.mkdir(parents=True, exist_ok=True) + save_file(state_dict, save_dir / "model.safetensors") + + config_file_path = save_dir / "config.json" + with open(config_file_path, "w") as config_file: + json.dump({}, config_file, indent=2, sort_keys=True) + + model_compressor.update_config(save_dir) + return save_dir + + +def _check_state_dicts(state_dict1, state_dict2): + for key in state_dict1.keys(): + assert key in state_dict2, f"Missing tensor: {key}" + if key.endswith("weight"): + original_tensor = state_dict1[key] + decompressed_tensor = state_dict2[key].to(original_tensor.dtype) + diff = torch.abs(original_tensor - decompressed_tensor) + assert not 
torch.any(diff > 0.01), f"Max diff: {torch.max(diff)}" diff --git a/tests/test_compressors/sparse_compressors/test_bitmask.py b/tests/test_compressors/sparse_compressors/test_bitmask.py index 248580bc..491ee2e8 100644 --- a/tests/test_compressors/sparse_compressors/test_bitmask.py +++ b/tests/test_compressors/sparse_compressors/test_bitmask.py @@ -44,16 +44,16 @@ def test_bitmask_sizes(shape, sparsity, dtype): assert len(dense_state_dict) * 4 == len(sparse_state_dict) # bitmask should be 1 bit per dense element, rounded up to nearest int8 - sparse_shape = sparse_state_dict["dummy.weight.shape"] + sparse_shape = sparse_state_dict["dummy.shape"] assert torch.all(torch.eq(sparse_shape, torch.tensor(shape))) - bitmask_shape = sparse_state_dict["dummy.weight.bitmask"].shape + bitmask_shape = sparse_state_dict["dummy.bitmask"].shape assert bitmask_shape[0] == sparse_shape[0] assert bitmask_shape[1] == int(math.ceil(sparse_shape[1] / 8.0)) # one value for each non-zero weight - values_shape = sparse_state_dict["dummy.weight.compressed"].shape + values_shape = sparse_state_dict["dummy.compressed"].shape assert values_shape[0] == torch.sum(test_tensor != 0) - row_offsets_shape = sparse_state_dict["dummy.weight.row_offsets"].shape + row_offsets_shape = sparse_state_dict["dummy.row_offsets"].shape assert row_offsets_shape[0] == test_tensor.shape[0] diff --git a/tests/test_compressors/sparse_compressors/test_sparse_24_bitmask.py b/tests/test_compressors/sparse_compressors/test_sparse_24_bitmask.py new file mode 100644 index 00000000..0e28f004 --- /dev/null +++ b/tests/test_compressors/sparse_compressors/test_sparse_24_bitmask.py @@ -0,0 +1,201 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
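
For orientation, the new test file below exercises the Sparse24BitMaskTensor compress/decompress round trip. A minimal usage sketch, assuming the from_dense/decompress API and the 2:4 mask convention shown elsewhere in this diff (values are illustrative only):

import torch
from compressed_tensors import Sparse24BitMaskTensor

# build a 2:4 semi-structured int8 matrix: two non-zeros in every group of four columns
dense = torch.randint(1, 127, (1024, 1024), dtype=torch.int8)
mask = torch.tensor([0, 0, 1, 1], dtype=torch.bool).tile((1024, 1024 // 4))
dense = dense * mask

# compress to packed values plus a bitmask, then reconstruct the dense layout
compressed = Sparse24BitMaskTensor.from_dense(dense, sparsity_structure="2:4")
restored = compressed.decompress()
assert torch.equal(dense, restored.to(dense.device))
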
+ + +import pytest +import torch +from compressed_tensors import Sparse24BitMaskTensor +from compressed_tensors.quantization import FP8_DTYPE +from compressed_tensors.utils import combine_shards, shard_tensor +from tests.testing_utils import generate_pruned_semi_structured_mat, requires_gpu + + +@pytest.fixture +def dense_matrix_fixture(): + def _generate_dense_matrix(M, K, dtype): + return generate_pruned_semi_structured_mat(M, K, dtype) + + return _generate_dense_matrix + + +@pytest.fixture +def shard_validation(): + def _validate_shard_shapes(sharded_values, sharded_bitmask, expected_shapes): + for shard_values, shard_bitmask, expected_shape in zip( + sharded_values, sharded_bitmask, expected_shapes + ): + assert ( + shard_values.shape == expected_shape["compressed"] + ), f"Shape mismatch: {shard_values.shape} != {expected_shape['compressed']}" + assert ( + shard_bitmask.shape == expected_shape["bitmask"] + ), f"Shape mismatch: {shard_bitmask.shape} != {expected_shape['bitmask']}" + + return _validate_shard_shapes + + +def validate_compression(dense_matrix, decompressed_tensor): + """Validate that the decompressed tensor matches the original dense matrix.""" + dense_matrix = dense_matrix.to(decompressed_tensor.device) + assert dense_matrix.dtype == decompressed_tensor.dtype, "Dtype mismatch" + assert dense_matrix.shape == decompressed_tensor.shape, "Shape mismatch" + assert torch.equal(dense_matrix, decompressed_tensor), "Decompression failed" + + +@pytest.mark.parametrize("dtype", [torch.int8]) +def test_bitmask_compress_decompress(dense_matrix_fixture, dtype): + M, K = 1024, 1024 + dense_matrix = dense_matrix_fixture(M, K, dtype) + + bitmask_tensor = Sparse24BitMaskTensor.from_dense( + dense_matrix, sparsity_structure="2:4" + ) + decompressed_tensor = bitmask_tensor.decompress() + + validate_compression(dense_matrix, decompressed_tensor) + + +@pytest.mark.parametrize( + "dtype, M, K, shard_sizes, shard_dim, expected_shapes", + [ + ( + torch.int8, + 2560, + 2048, + [2048, 256, 256], + 0, + [ + {"compressed": (2048, 1024), "bitmask": (2048, 2048 // 8)}, + {"compressed": (256, 1024), "bitmask": (256, 2048 // 8)}, + {"compressed": (256, 1024), "bitmask": (256, 2048 // 8)}, + ], + ), + ( + torch.int8, + 2048, + 2048, + [1024, 1024], + 1, + [ + {"compressed": (2048, 512), "bitmask": (2048, 2048 // 8 // 2)}, + {"compressed": (2048, 512), "bitmask": (2048, 2048 // 8 // 2)}, + ], + ), + ], +) +def test_bitmask_compress_decompress_sharded( + dense_matrix_fixture, + shard_validation, + dtype, + M, + K, + shard_sizes, + shard_dim, + expected_shapes, +): + dense_matrix = dense_matrix_fixture(M, K, dtype) + + bitmask_tensor = Sparse24BitMaskTensor.from_dense(dense_matrix) + compressed_values = bitmask_tensor.compressed + compressed_bitmask = bitmask_tensor.bitmask + + if shard_dim == 1: + compressed_shard_sizes = [size // 2 for size in shard_sizes] + bitmask_shard_sizes = [size // 8 for size in shard_sizes] + else: + compressed_shard_sizes = shard_sizes + bitmask_shard_sizes = shard_sizes + + sharded_compressed_values = shard_tensor( + compressed_values, compressed_shard_sizes, dim=shard_dim + ) + sharded_compressed_bitmask = shard_tensor( + compressed_bitmask, bitmask_shard_sizes, dim=shard_dim + ) + + shard_validation( + sharded_compressed_values, sharded_compressed_bitmask, expected_shapes + ) + + decompressed_shards = [ + Sparse24BitMaskTensor( + shape=(expected_shape["bitmask"][0], expected_shape["bitmask"][1] * 8), + compressed=shard_values, + bitmask=shard_bitmask, + ).decompress() + for 
shard_values, shard_bitmask, expected_shape in zip( + sharded_compressed_values, sharded_compressed_bitmask, expected_shapes + ) + ] + + decompressed_combined = combine_shards(decompressed_shards, dim=shard_dim) + validate_compression(dense_matrix, decompressed_combined) + + +# GPU-Specific Tests for FP8_DTYPE +@pytest.mark.parametrize("dtype", [FP8_DTYPE]) +@requires_gpu +def test_bitmask_compress_decompress_fp8(dense_matrix_fixture, dtype): + test_bitmask_compress_decompress(dense_matrix_fixture, dtype) + + +@pytest.mark.parametrize( + "dtype, M, K, shard_sizes, shard_dim, expected_shapes", + [ + ( + FP8_DTYPE, + 2560, + 2048, + [2048, 256, 256], + 0, + [ + {"compressed": (2048, 1024), "bitmask": (2048, 2048 // 8)}, + {"compressed": (256, 1024), "bitmask": (256, 2048 // 8)}, + {"compressed": (256, 1024), "bitmask": (256, 2048 // 8)}, + ], + ), + ( + FP8_DTYPE, + 2048, + 2048, + [1024, 1024], + 1, + [ + {"compressed": (2048, 512), "bitmask": (2048, 2048 // 8 // 2)}, + {"compressed": (2048, 512), "bitmask": (2048, 2048 // 8 // 2)}, + ], + ), + ], +) +@requires_gpu +def test_bitmask_compress_decompress_sharded_fp8( + dense_matrix_fixture, + shard_validation, + dtype, + M, + K, + shard_sizes, + shard_dim, + expected_shapes, +): + test_bitmask_compress_decompress_sharded( + dense_matrix_fixture, + shard_validation, + dtype, + M, + K, + shard_sizes, + shard_dim, + expected_shapes, + ) diff --git a/tests/test_quantization/lifecycle/test_apply.py b/tests/test_quantization/lifecycle/test_apply.py index 7268ca27..958ec3a5 100644 --- a/tests/test_quantization/lifecycle/test_apply.py +++ b/tests/test_quantization/lifecycle/test_apply.py @@ -13,7 +13,9 @@ # limitations under the License. import re +from collections import defaultdict from typing import Optional +from unittest.mock import MagicMock import pytest import torch @@ -26,12 +28,38 @@ from compressed_tensors.quantization.lifecycle import ( apply_quantization_config, apply_quantization_status, + expand_sparse_target_names, + is_sparse_target, ) from compressed_tensors.quantization.utils import iter_named_leaf_modules from tests.testing_utils import requires_accelerate from transformers import AutoModelForCausalLM +@pytest.fixture +def mock_model(): + model = MagicMock() + model.named_modules.return_value = [ + ("layer1", MagicMock()), + ("layer2", MagicMock()), + ("layer3", MagicMock()), + ] + return model + + +@pytest.fixture +def mock_module(): + return MagicMock() + + +@pytest.fixture +def llama_stories_model(): + return AutoModelForCausalLM.from_pretrained( + "Xenova/llama2.c-stories15M", + torch_dtype="auto", + ) + + def test_target_prioritization(mock_frozen): # tests that the config_groups are applied in the correct order # of priority, where exact layer name > regex > module name @@ -87,31 +115,36 @@ def test_apply_quantization_config_tinyllama(): for module in model.modules(): _test_layer_quantization_status(module, inputs=False, weights=False) + count_layer_names = ("Linear", "Embeddidng", "LlamaRotaryEmbedding") + count_layer_num = defaultdict(int) + + for name, module in model.named_modules(): + if name in quant_config.ignore: + continue + module_type = module.__class__.__name__ + if module_type in count_layer_names: + count_layer_num[module_type] += 1 + + assert len(count_layer_num) > 0, f"None of {count_layer_names} found in model" + assert all(value > 0 for value in count_layer_num.values()) + # apply quant config to model apply_quantization_config(model, quant_config) # check for correct application of quant config - 
num_linears = 0 - num_embeddings = 0 - num_rotary_embeddings = 0 for name, module in model.named_modules(): if name in quant_config.ignore: continue module_type = module.__class__.__name__ - if module_type == "Linear": - num_linears += 1 - _test_layer_quantization_status(module, inputs=True, weights=True) - elif module_type == "Embedding": - num_embeddings += 1 - _test_layer_quantization_status(module, inputs=False, weights=True) - elif module_type == "LlamaRotaryEmbedding": - num_rotary_embeddings += 1 - _test_layer_quantization_status(module, inputs=False, weights=False) - - # sanity check correct number of layers targeted - assert num_linears == 154 # 155 Linear layers - 1 that gets ignored - assert num_embeddings == 1 - assert num_rotary_embeddings == 23 # model updated, now has model.rotary_embedding + if module_type in count_layer_names: + count_layer_num[module_type] -= 1 + _inputs = module_type == "Linear" + _weights = not module_type == "LlamaRotaryEmbedding" + _test_layer_quantization_status(module, inputs=_inputs, weights=_weights) + + assert all( + value == 0 for value in count_layer_num.values() + ), "Not all values are zero" # test quantization compression # sample forward pass to fill scales, zps @@ -222,7 +255,7 @@ def get_sample_tinyllama_quant_config(status: str = "frozen"): }, "ignore": ["LlamaRotaryEmbedding", "model.layers.1.mlp.down_proj"], } - return QuantizationConfig.parse_obj(config_dict) + return QuantizationConfig.model_validate(config_dict) @requires_accelerate() @@ -266,3 +299,73 @@ def test_apply_quantization_status(caplog, ignore, should_raise_warning): assert len(caplog.text) > 0 else: assert len(caplog.text) == 0 + + +@pytest.mark.parametrize( + "targets, ignore, expected_targets", + [ + ([], [], set()), + (["layer1", "layer2"], [], {"layer1", "layer2"}), + ([], ["layer1"], set()), + (["layer1", "layer2"], ["layer2"], {"layer1"}), + (["re:layer.*"], ["layer3"], {"layer1", "layer2"}), + ], +) +def test_expand_targets_with_mock(mock_model, targets, ignore, expected_targets): + expanded_targets = expand_sparse_target_names(mock_model, targets, ignore) + assert expanded_targets == expected_targets + + +@pytest.mark.parametrize( + "targets, ignore, expected_targets", + [ + ( + ["re:model.layers.[01].self_attn.q_proj"], + ["re:model.layers.1.self_attn.q_proj"], + set(["model.layers.0.self_attn.q_proj"]), + ), + ( + ["re:model.layers.[01].self_attn.q_proj"], + [], + set(["model.layers.0.self_attn.q_proj", "model.layers.1.self_attn.q_proj"]), + ), + ( + ["re:model.layers.[0-2].self_attn.q_proj"], + ["re:model.layers.1.self_attn.q_proj"], + set(["model.layers.0.self_attn.q_proj", "model.layers.2.self_attn.q_proj"]), + ), + ( + ["model.layers.0.self_attn.q_proj"], + ["model.layers.0.self_attn.q_proj"], + set(), + ), + ( + ["re:model.layers.*.self_attn.q_proj"], + ["re:model.layers.[01].self_attn.q_proj"], + set( + f"model.layers.{layer_idx}.self_attn.q_proj" + for layer_idx in range(2, 6) + ), + ), + ], +) +def test_expand_targets_with_llama_stories( + llama_stories_model, targets, ignore, expected_targets +): + expanded_targets = expand_sparse_target_names(llama_stories_model, targets, ignore) + assert expanded_targets == expected_targets + + +@pytest.mark.parametrize( + "name, targets, ignore, expected", + [ + ("layer1", ["layer1"], [], True), + ("layer1", ["layer1"], ["layer1"], False), + ("layer1", ["layer2"], [], False), + ("layer1", ["re:layer.*"], [], True), + ("layer1", ["re:layer.*"], ["re:layer1"], False), + ], +) +def test_is_target_with_mock(mock_module, 
name, targets, ignore, expected): + result = is_sparse_target(name, mock_module, targets, ignore) + assert result == expected diff --git a/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py b/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py index dd700637..3ac91e85 100644 --- a/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py +++ b/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py @@ -110,4 +110,4 @@ def get_sample_dynamic_tinyllama_quant_config(): }, "ignore": ["LlamaRotaryEmbedding", "model.layers.1.mlp.down_proj"], } - return QuantizationConfig.parse_obj(config_dict) + return QuantizationConfig.model_validate(config_dict) diff --git a/tests/test_quantization/lifecycle/test_initialize.py b/tests/test_quantization/lifecycle/test_initialize.py index 987b2ae2..80a1629d 100644 --- a/tests/test_quantization/lifecycle/test_initialize.py +++ b/tests/test_quantization/lifecycle/test_initialize.py @@ -14,15 +14,31 @@ import pytest +from compressed_tensors.quantization import ( + ActivationOrdering, + QuantizationArgs, + QuantizationScheme, + QuantizationStatus, + QuantizationStrategy, +) from compressed_tensors.quantization.lifecycle.initialize import ( initialize_module_for_quantization, ) -from compressed_tensors.quantization.quant_args import QuantizationArgs -from compressed_tensors.quantization.quant_config import QuantizationStatus +from tests.testing_utils import requires_accelerate from torch.nn import Linear NUM_BITS = 8 +Q_PARAM_NAMES = { + "input_activations": "input", + "weights": "weight", + "output_activations": "output", +} + + +@pytest.fixture +def layer(): + return Linear(4, 4) @pytest.mark.parametrize( @@ -43,14 +59,13 @@ ], ) def test_initialize_module_for_quantization( - create_quantization_scheme, weights, input_activations + create_quantization_scheme, weights, input_activations, layer ): quantization_scheme = create_quantization_scheme( targets=["*"], weights=weights, input_activations=input_activations, ) - layer = Linear(4, 4) assert not hasattr(layer, "quantization_scheme") assert not hasattr(layer, "quantization_status") @@ -77,3 +92,111 @@ def test_initialize_module_for_quantization( assert hasattr(layer, "quantization_status") assert layer.quantization_status == QuantizationStatus.INITIALIZED + + +@requires_accelerate() +@pytest.mark.parametrize( + "weights,input_activations", + [ + ( + QuantizationArgs(num_bits=NUM_BITS, symmetric=True), + None, + ), + ( + None, + QuantizationArgs(num_bits=NUM_BITS, symmetric=True), + ), + ( + QuantizationArgs(num_bits=NUM_BITS, symmetric=True), + QuantizationArgs(num_bits=NUM_BITS, symmetric=True), + ), + ], +) +def test_initialize_module_for_quantization_offloaded( + create_quantization_scheme, weights, input_activations, layer +): + from accelerate.hooks import attach_align_device_hook + + attach_align_device_hook(layer, offload=True) + + test_initialize_module_for_quantization( + create_quantization_scheme, + weights, + input_activations, + layer, + ) + + +@pytest.mark.parametrize( + "weights,input_activations", + [ + ( + QuantizationArgs(strategy="tensor"), + QuantizationArgs(strategy="tensor"), + ), + ( + QuantizationArgs(strategy="channel"), + None, + ), + ( + QuantizationArgs(strategy="group", group_size=2), + None, + ), + ( + QuantizationArgs(strategy="group", group_size=2, actorder="group"), + None, + ), + ( + QuantizationArgs(strategy="group", group_size=2, actorder="weight"), + None, + ), + ( + QuantizationArgs(strategy="block"), + QuantizationArgs(strategy="block"), + ), + ( + 
QuantizationArgs(strategy="token"), + QuantizationArgs(strategy="token"), + ), + ], +) +def test_initialize_quantization_parameters(weights, input_activations): + quantization_scheme = QuantizationScheme( + targets=["*"], + weights=weights, + input_activations=input_activations, + ) + layer = Linear(7, 8) + initialize_module_for_quantization(layer, quantization_scheme) + + for q_type in ("input_activations", "weights"): + args = getattr(quantization_scheme, q_type) + if args is None: + continue + q_param_name = Q_PARAM_NAMES[q_type] + + # scale and zero point + if args.strategy == QuantizationStrategy.TENSOR: + expected_shape = (1,) + + elif args.strategy == QuantizationStrategy.CHANNEL: # only weight + expected_shape = (layer.weight.shape[0], 1) + + elif args.strategy == QuantizationStrategy.GROUP: # only weight + num_groups = layer.weight.shape[1] // args.group_size + expected_shape = (layer.weight.shape[0], max(num_groups, 1)) + + elif args.strategy == QuantizationStrategy.BLOCK: + expected_shape = (1,) + + elif args.strategy == QuantizationStrategy.TOKEN: + expected_shape = (1, 1) + + assert getattr(layer, f"{q_param_name}_scale").shape == expected_shape + assert getattr(layer, f"{q_param_name}_zero_point").shape == expected_shape + + # g_idx + if args.actorder == ActivationOrdering.GROUP: + assert getattr(layer, f"{q_param_name}_g_idx").shape == ( + layer.weight.shape[1], + ) diff --git a/tests/test_quantization/test_configs/test_strategies.py b/tests/test_quantization/test_configs/test_strategies.py index 94201463..6605daf0 100644 --- a/tests/test_quantization/test_configs/test_strategies.py +++ b/tests/test_quantization/test_configs/test_strategies.py @@ -67,8 +67,8 @@ def test_channelwise( if input_symmetry is not None: mock_per_channel_calibration(model, base_name="input", value=inputs) - assert list(model.weight_scale.shape) == [model_shape[1], 1] - assert list(model.weight_zero_point.shape) == [model_shape[1], 1] + assert model.weight_scale.shape == (model_shape[1], 1) + assert model.weight_zero_point.shape == (model_shape[1], 1) @torch.no_grad @@ -97,14 +97,14 @@ def test_group( model, base_name="input", value=inputs, group_size=group_size ) - assert list(model.weight_scale.shape) == [ + assert model.weight_scale.shape == ( model_shape[1], int(model_shape[0] / group_size), - ] - assert list(model.weight_zero_point.shape) == [ + ) + assert model.weight_zero_point.shape == ( model_shape[1], int(model_shape[0] / group_size), - ] + ) @torch.no_grad @@ -131,8 +131,8 @@ def test_token( mock_per_channel_calibration(model, base_name="weight", value=model.weight) mock_per_token_calibration(model, base_name="input", value=inputs) - assert list(model.input_scale.shape) == [1, 1] - assert list(model.input_zero_point.shape) == [1, 1] + assert model.input_scale.shape == (1, 1) + assert model.input_zero_point.shape == (1, 1) - assert list(model.weight_scale.shape) == [256, 1] - assert list(model.weight_zero_point.shape) == [256, 1] + assert model.weight_scale.shape == (256, 1) + assert model.weight_zero_point.shape == (256, 1) diff --git a/tests/test_quantization/test_quant_config.py b/tests/test_quantization/test_quant_config.py index 460db82b..c3830a02 100644 --- a/tests/test_quantization/test_quant_config.py +++ b/tests/test_quantization/test_quant_config.py @@ -72,3 +72,10 @@ def test_load_scheme_from_preset(scheme_name: str): assert scheme_name in config.config_groups assert isinstance(config.config_groups[scheme_name], QuantizationScheme) assert config.config_groups[scheme_name].targets 
== targets + + +def test_to_dict(): + config_groups = {"group_1": QuantizationScheme(targets=[])} + config = QuantizationConfig(config_groups=config_groups) + reloaded = QuantizationConfig.model_validate(config.to_dict()) + assert config == reloaded diff --git a/tests/test_quantization/test_utils/test_helpers.py b/tests/test_quantization/test_utils/test_helpers.py new file mode 100644 index 00000000..b106ee2d --- /dev/null +++ b/tests/test_quantization/test_utils/test_helpers.py @@ -0,0 +1,58 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch +from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy +from compressed_tensors.quantization.utils import calculate_qparams + + +@pytest.mark.parametrize( + "keepdims,strategy,exp_shape", + [ + ( + False, + QuantizationStrategy.TENSOR, + torch.Size( + [ + 1, + ] + ), + ), + (True, QuantizationStrategy.CHANNEL, torch.Size([1, 1])), + (True, QuantizationStrategy.GROUP, torch.Size([1, 1])), + ( + False, + QuantizationStrategy.BLOCK, + torch.Size( + [ + 1, + ] + ), + ), + (True, QuantizationStrategy.TOKEN, torch.Size([1, 1])), + ], +) +def test_calculate_qparams(keepdims, strategy, exp_shape): + value = torch.randn(14, 5) + min_val = torch.amin(value, dim=tuple(), keepdims=keepdims) + max_val = torch.amax(value, dim=tuple(), keepdims=keepdims) + + if strategy == QuantizationStrategy.GROUP: + args = QuantizationArgs(strategy=strategy, group_size=2) + else: + args = QuantizationArgs(strategy=strategy) + scale, zp = calculate_qparams(min_val, max_val, args) + assert scale.shape == exp_shape + assert zp.shape == exp_shape diff --git a/tests/test_utils/test_offload.py b/tests/test_utils/test_offload.py new file mode 100644 index 00000000..1002a4f5 --- /dev/null +++ b/tests/test_utils/test_offload.py @@ -0,0 +1,255 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
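
The test file below covers the offload utilities introduced earlier in this diff (register_offload_parameter, update_offload_parameter, align_module_device, and friends). A minimal sketch of the intended call pattern, assuming an accelerate-offloaded module set up the same way as in these tests:

import torch
from accelerate.hooks import attach_align_device_hook
from compressed_tensors.utils import (
    align_module_device,
    has_offloaded_params,
    register_offload_parameter,
    update_offload_parameter,
)

linear = torch.nn.Linear(4, 4)
attach_align_device_hook(linear, offload=True, weights_map=linear.state_dict())
assert has_offloaded_params(linear)

# newly registered parameters are written into the offloaded weights map
scale = torch.nn.Parameter(torch.ones(4, 1), requires_grad=False)
register_offload_parameter(linear, "weight_scale", scale)

# updates go through the offload dict so they persist across onload/offload cycles
update_offload_parameter(linear, "weight_scale", torch.full((4, 1), 0.5))

# parameters are onloaded only inside the context, then returned to the meta device
with align_module_device(linear, execution_device="cpu"):
    assert linear.weight_scale.device == torch.device("cpu")
assert linear.weight_scale.device == torch.device("meta")
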
+import pytest +import torch +from compressed_tensors.utils import ( + align_module_device, + delete_offload_parameter, + disable_hf_hook, + has_offloaded_params, + register_offload_parameter, + update_offload_parameter, +) +from compressed_tensors.utils.offload import offload_to_weights_map +from tests.testing_utils import requires_accelerate + + +class ExampleModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.a = torch.nn.Parameter(torch.tensor(0).float()) + self.b = torch.nn.Parameter(torch.tensor(0).float()) + + def forward(self, x): + return x * self.a + self.b + + +@requires_accelerate() +def test_has_offloaded_params(): + from accelerate.big_modeling import cpu_offload_with_hook + from accelerate.hooks import attach_align_device_hook, remove_hook_from_module + + module = ExampleModule() + assert not has_offloaded_params(module) + + attach_align_device_hook(module, offload=False) + assert not has_offloaded_params(module) + + remove_hook_from_module(module) + module, _ = cpu_offload_with_hook(module) + assert not has_offloaded_params(module) + + remove_hook_from_module(module) + attach_align_device_hook(module, offload=True, weights_map=module.state_dict()) + assert has_offloaded_params(module) + + +@requires_accelerate() +def test_register_offload_parameter(): + from accelerate.hooks import attach_align_device_hook + + module = ExampleModule() + parameter = torch.nn.Parameter(torch.tensor(1.0)) + + # register a param prior to offloading + register_offload_parameter(module, "c", parameter) + assert hasattr(module, "c") and module.c == parameter + + # offloading, check that added param was offloaded + attach_align_device_hook(module, offload=True, weights_map=module.state_dict()) + assert "c" in module._hf_hook.weights_map + + # register a param after offloading, check that added param was offloaded + register_offload_parameter(module, "d", parameter) + assert hasattr(module, "d") and module.d.device == torch.device("meta") + assert module._hf_hook.weights_map["d"].device == torch.device("cpu") + + # added parameters can be onloaded and offloaded + with align_module_device(module, execution_device="cpu"): + assert module.c.device == torch.device("cpu") + assert module.d.device == torch.device("cpu") + assert module.c.device == torch.device("meta") + assert module.d.device == torch.device("meta") + + # parameters can be added during onload + with align_module_device(module, execution_device="cpu"): + register_offload_parameter(module, "e", parameter) + assert module.e.device == torch.device("cpu") + + # parameters can be added before onload and with explicit offload + register_offload_parameter(module, "f", parameter, offload_device="cpu") + assert module._hf_hook.weights_map["f"].device == torch.device("cpu") + with align_module_device(module, execution_device="cpu"): + assert module.f.device == torch.device("cpu") + assert module._hf_hook.weights_map["f"].device == torch.device("cpu") + + +@requires_accelerate() +def test_update_offload_parameter(): + from accelerate.hooks import attach_align_device_hook + + module = ExampleModule() + param_a = torch.nn.Parameter(torch.tensor(1.0)) + param_b = torch.nn.Parameter(torch.tensor(2.0)) + + # can update modules which are not offloaded + update_offload_parameter(module, "a", param_a) + assert module.a == param_a + + # can update modules which are offloaded + attach_align_device_hook(module, offload=True, weights_map=module.state_dict()) + update_offload_parameter(module, "b", param_b) + assert module.b.device == 
torch.device("meta") + assert module._hf_hook.weights_map["b"] == param_b.data + + # data persists across onloading + with align_module_device(module, execution_device="cpu"): + assert module.a == param_a + assert module.b == param_b + assert module._hf_hook.weights_map["a"] == param_a.data + assert module._hf_hook.weights_map["b"] == param_b.data + + # data persists across offloading + assert module.a.device == torch.device("meta") + assert module.b.device == torch.device("meta") + assert module._hf_hook.weights_map["a"] == param_a.data + assert module._hf_hook.weights_map["b"] == param_b.data + + +@requires_accelerate() +def test_delete_offload_parameter(): + from accelerate.hooks import attach_align_device_hook + + module = ExampleModule() + param_c = torch.nn.Parameter(torch.tensor(1.0)) + param_d = torch.nn.Parameter(torch.tensor(2.0)) + register_offload_parameter(module, "c", param_c) + register_offload_parameter(module, "d", param_d) + + # parameters are deleted + delete_offload_parameter(module, "a") + delete_offload_parameter(module, "c") + assert not hasattr(module, "a") + assert hasattr(module, "b") + assert not hasattr(module, "c") + assert hasattr(module, "d") + + # parameters and their offload are deleted + attach_align_device_hook(module, offload=True, weights_map=module.state_dict()) + delete_offload_parameter(module, "b") + delete_offload_parameter(module, "d") + assert not hasattr(module, "a") + assert not hasattr(module, "b") + assert not hasattr(module, "c") + assert not hasattr(module, "d") + assert "a" not in module._hf_hook.weights_map + assert "b" not in module._hf_hook.weights_map + assert "c" not in module._hf_hook.weights_map + assert "d" not in module._hf_hook.weights_map + + +@requires_accelerate() +def test_disable_hf_hook(): + from accelerate.hooks import attach_align_device_hook + + module = ExampleModule() + + def custom_forward(): + pass + + attach_align_device_hook(module, offload=True, weights_map=module.state_dict()) + with disable_hf_hook(module): + assert not hasattr(module, "_hf_hook") + module.forward = custom_forward + + assert hasattr(module, "_hf_hook") + assert module._old_forward == custom_forward + + +@requires_accelerate() +def test_disable_hf_hook_model_recurse(): + from accelerate.hooks import attach_align_device_hook + + module0 = ExampleModule() + module1 = ExampleModule() + module2 = ExampleModule() + model = torch.nn.Sequential(module0, torch.nn.Sequential(module1, module2)) + attach_align_device_hook(model, offload=True, weights_map=model.state_dict()) + + with disable_hf_hook(model): + assert not hasattr(module0, "_hf_hook") + assert not hasattr(module1, "_hf_hook") + assert not hasattr(module2, "_hf_hook") + + assert hasattr(module0, "_hf_hook") + assert hasattr(module1, "_hf_hook") + assert hasattr(module2, "_hf_hook") + + +@requires_accelerate() +def test_offload_to_weights_map(): + from accelerate.utils import OffloadedWeightsLoader, PrefixedDataset + + name = "name" + old_value = torch.tensor(0.0) + new_value = torch.tensor(1.0) + prefix = "prefix" + + # Dict empty + weights_map = {} + with pytest.raises(ValueError): + offload_to_weights_map(weights_map, name, new_value) + offload_to_weights_map(weights_map, name, new_value, offload_device="cpu") + assert weights_map[name] == new_value + + # Dict populated + weights_map = {name: old_value} + offload_to_weights_map(weights_map, name, new_value) + assert weights_map[name] == new_value + + # OffloadedWeightsLoader[Dict] empty + weights_map = OffloadedWeightsLoader({}) + with 
pytest.raises(ValueError): + offload_to_weights_map(weights_map, name, new_value) + offload_to_weights_map(weights_map, name, new_value, offload_device="cpu") + assert weights_map[name] == new_value + + # OffloadedWeightsLoader[Dict] populated + weights_map = OffloadedWeightsLoader({name: old_value}) + offload_to_weights_map(weights_map, name, new_value) + assert weights_map[name] == new_value + + # PrefixedDataset[Dict] empty + weights_map = PrefixedDataset({}, prefix) + with pytest.raises(ValueError): + offload_to_weights_map(weights_map, name, new_value) + offload_to_weights_map(weights_map, name, new_value, offload_device="cpu") + assert weights_map[name] == new_value + + # PrefixedDataset[Dict] populated + weights_map = PrefixedDataset({name: old_value}, prefix) + offload_to_weights_map(weights_map, name, new_value) + assert weights_map[name] == new_value + + # PrefixedDataset[OffloadedWeightsLoader[Dict]] empty + weights_map = PrefixedDataset(OffloadedWeightsLoader({}), prefix) + with pytest.raises(ValueError): + offload_to_weights_map(weights_map, name, new_value) + offload_to_weights_map(weights_map, name, new_value, offload_device="cpu") + assert weights_map[name] == new_value + + # PrefixedDataset[OffloadedWeightsLoader[Dict]] populated + weights_map = PrefixedDataset(OffloadedWeightsLoader({name: old_value}), prefix) + offload_to_weights_map(weights_map, name, new_value) + assert weights_map[name] == new_value diff --git a/tests/test_utils/test_safetensors_load.py b/tests/test_utils/test_safetensors_load.py new file mode 100644 index 00000000..932a8926 --- /dev/null +++ b/tests/test_utils/test_safetensors_load.py @@ -0,0 +1,82 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
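
The mock-based tests below check the new return_unmatched_params path of get_nested_weight_mappings. A minimal sketch of both return modes against a hypothetical safetensors checkpoint directory ("./ckpt" is illustrative, not part of this patch):

from compressed_tensors.utils.safetensors_load import get_nested_weight_mappings

# default: a nested mapping such as {"layer1": {"weight": "model.safetensors"}}
nested = get_nested_weight_mappings("./ckpt", params_to_nest=["weight"])

# with return_unmatched_params=True, parameters that match none of params_to_nest
# (for example "layer1.weight_scale") are returned in a second flat mapping
nested, unmatched = get_nested_weight_mappings(
    "./ckpt", params_to_nest=["weight"], return_unmatched_params=True
)
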
+ +from unittest.mock import patch + +import pytest +from compressed_tensors.utils.safetensors_load import get_nested_weight_mappings + + +mock_weight_mappings = { + "layer1.weight": "file1", + "layer1.bias": "file2", + "layer2.weight": "file3", + "layer2.bias": "file4", + "layer3.weight": "file5", +} + + +@pytest.fixture +def mock_get_weight_mappings(): + with patch( + "compressed_tensors.utils.safetensors_load.get_weight_mappings", + return_value=mock_weight_mappings, + ): + yield + + +@pytest.mark.usefixtures("mock_get_weight_mappings") +class TestGetNestedWeightMappings: + """ + Tests for the get_nested_weight_mappings function + in different scenarios, such as single and multiple + parameters to nest, and returning other parameters + """ + + def test_single_param(self): + params_to_nest = ["weight"] + result = get_nested_weight_mappings("dummy_path", params_to_nest) + expected = { + "layer1": {"weight": "file1"}, + "layer2": {"weight": "file3"}, + "layer3": {"weight": "file5"}, + } + assert result == expected + + def test_multiple_params(self): + params_to_nest = ["weight", "bias"] + result = get_nested_weight_mappings("dummy_path", params_to_nest) + expected = { + "layer1": {"weight": "file1", "bias": "file2"}, + "layer2": {"weight": "file3", "bias": "file4"}, + "layer3": {"weight": "file5"}, + } + assert result == expected + + def test_return_other_params(self): + params_to_nest = ["weight"] + result, other_params = get_nested_weight_mappings( + "dummy_path", params_to_nest, return_unmatched_params=True + ) + expected_nested = { + "layer1": {"weight": "file1"}, + "layer2": {"weight": "file3"}, + "layer3": {"weight": "file5"}, + } + expected_other = { + "layer1.bias": "file2", + "layer2.bias": "file4", + } + assert result == expected_nested + assert other_params == expected_other diff --git a/tests/testing_utils.py b/tests/testing_utils.py index e446cad3..ebe7a0c6 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
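
The hunk below adds sparse-matrix helpers and a GPU gate to tests/testing_utils.py. A minimal sketch of how a test might use them, assuming the signatures added in this diff (the test name is hypothetical):

import torch
from tests.testing_utils import induce_sparsity, requires_gpu

# zero out roughly the 70% smallest-magnitude entries of a dense weight tensor
weights = torch.randn(64, 64)
sparse_weights = induce_sparsity(weights, sparsity_ratio=0.7)
assert sparse_weights.shape == weights.shape

# GPU-only tests are skipped on machines without CUDA
@requires_gpu
def test_fp8_roundtrip():
    assert torch.cuda.is_available()
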
+# flake8: noqa +import unittest import pytest @@ -52,3 +54,91 @@ def requires_accelerate(): not _is_accelerate_available, reason="requires accelerate", ) + + +def get_random_mat(M, K, dtype) -> "torch.Tensor": + """ + :param M: number of rows + :param K: number of columns + :param dtype: data type of the matrix + :return: random matrix of shape (M, K) with non-zero values + """ + import torch + from compressed_tensors.quantization import FP8_DTYPE + + rand_tensor_dtype = dtype + if dtype in [torch.int8, FP8_DTYPE]: + rand_tensor_dtype = torch.float16 + mat = torch.rand(M, K, dtype=rand_tensor_dtype).cuda() + mat = mat.masked_fill_(mat == 0, 1) + return mat.to(dtype) + + +def generate_pruned_semi_structured_mat(M, K, dtype) -> "torch.Tensor": + """ + :param M: number of rows + :param K: number of columns + :param dtype: data type of the matrix + :return: random matrix of shape (M, K) with 2:4 sparsity pattern + """ + import torch + from compressed_tensors.quantization import FP8_DTYPE + + mask = torch.Tensor([0, 0, 1, 1]).tile((M, K // 4)).bool() + rand_tensor_dtype = dtype + if dtype in [torch.int8, FP8_DTYPE]: + rand_tensor_dtype = torch.float16 + mat = torch.rand(M, K, dtype=rand_tensor_dtype) + mat = mat.masked_fill_(mat == 0, 1) + if dtype == FP8_DTYPE: + # some float8_e4m3fn operations are not supported on CPU + mat = mat.cuda() + mask = mask.cuda() + mat = mat * mask + return mat.to(dtype) + + +def induce_sparsity(tensor, sparsity_ratio) -> "torch.Tensor": + """ + Makes a tensor sparse by zeroing out a given fraction + of its smallest absolute values. + + :param: weight_tensor (torch.Tensor): The input weight tensor. + :param: sparsity_ratio (float): Fraction of weights to be zeroed + (0 <= sparsity_ratio <= 1). + :returns: torch.Tensor: Sparse version of the input tensor. + """ + import torch + + if not (0 <= sparsity_ratio <= 1): + raise ValueError("Sparsity ratio must be between 0 and 1.") + + # Flatten the tensor and compute the threshold for sparsity + flattened = tensor.view(-1) + k = int(sparsity_ratio * flattened.numel()) + + if k > 0: + threshold = torch.topk(flattened.abs(), k, largest=False).values.max() + sparse_tensor = torch.where( + tensor.abs() > threshold, tensor, torch.zeros_like(tensor) + ) + else: + sparse_tensor = tensor + + return sparse_tensor + + +def is_gpu_available(): + """ + :return: True if a GPU is available, False otherwise + """ + try: + import torch # noqa: F401 + + return torch.cuda.device_count() > 0 + except ImportError: + return False + + +def requires_gpu(test_case): + return unittest.skipUnless(is_gpu_available(), "test requires GPU")(test_case) diff --git a/utils/artifacts.py b/utils/artifacts.py deleted file mode 100644 index d16e0f44..00000000 --- a/utils/artifacts.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -import os -from typing import Tuple - - -def get_release_and_version(package_path: str) -> Tuple[bool, bool, str, str, str, str]: - """ - Load version and release info from deepsparse package - """ - # deepsparse/src/deepsparse/version.py always exists, default source of truth - version_path = os.path.join(package_path, "version.py") - - # exec() cannot set local variables so need to manually - locals_dict = {} - exec(open(version_path).read(), globals(), locals_dict) - is_release = locals_dict.get("is_release", False) - version = locals_dict.get("version", "unknown") - version_major = locals_dict.get("version_major", "unknown") - version_minor = locals_dict.get("version_minor", "unknown") - version_bug = locals_dict.get("version_bug", "unknown") - - print(f"Loaded version {version} from {version_path}") - - return ( - is_release, - version, - version_major, - version_minor, - version_bug, - )
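
Throughout the test changes above, Pydantic v1 calls (parse_obj, dict) are replaced with their v2 counterparts (model_validate, model_dump). A minimal round-trip sketch of the new spelling, assuming QuantizationConfig behaves as exercised by test_to_dict above:

from compressed_tensors.quantization import QuantizationScheme
from compressed_tensors.quantization.quant_config import QuantizationConfig

config = QuantizationConfig(
    config_groups={"group_0": QuantizationScheme(targets=["Linear"])}
)
as_dict = config.model_dump()                          # Pydantic v2 replacement for .dict()
reloaded = QuantizationConfig.model_validate(as_dict)  # Pydantic v2 replacement for .parse_obj()
assert reloaded == config
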