From dd13810c760c2ffbe9f5aa4e722ad45ad32967e4 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Wed, 27 Nov 2024 07:58:21 +0000 Subject: [PATCH 01/10] Add expand_targets and is_target functions for target matching in apply.py Signed-off-by: Rahul Tuli --- .../quantization/lifecycle/apply.py | 47 ++++++++- .../quantization/quant_scheme.py | 1 + .../test_quantization/lifecycle/test_apply.py | 97 +++++++++++++++++++ 3 files changed, 144 insertions(+), 1 deletion(-) diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index ed9a50f7..800be585 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -18,7 +18,7 @@ from copy import deepcopy from typing import Dict, Iterable, List, Optional from typing import OrderedDict as OrderedDictType -from typing import Union +from typing import Set, Union import torch from compressed_tensors.config import CompressionFormat @@ -52,6 +52,8 @@ "apply_quantization_config", "apply_quantization_status", "find_name_or_class_matches", + "expand_targets", + "is_target", ] from compressed_tensors.quantization.utils.helpers import is_module_quantized @@ -245,6 +247,49 @@ def apply_quantization_status(model: Module, status: QuantizationStatus): model.apply(compress_quantized_weights) +def expand_targets( + model: Module, targets: Iterable[str], ignore: Iterable[str] +) -> Set[str]: + """ + Finds all the targets in the model that match the given + targets and ignore lists. + + Note: Targets must be regexes, layer types, or full layer names. + + :param model: model to search for targets in + :param targets: list of targets to search for + :param ignore: list of targets to ignore + :return: set of all targets that match the given targets and should + not be ignored + """ + return { + name + for name, module in iter_named_leaf_modules(model) + if is_target(name, module, targets, ignore) + } + + +def is_target( + name: str, module: Module, targets: Iterable[str], ignore: Iterable[str] +) -> bool: + """ + Determines if a module should be included in the targets based on the + targets and ignore lists. + + Note: Targets must be regexes, layer types, or full layer names. 
+ + :param name: name of the module + :param module: the module itself + :param targets: list of targets to search for + :param ignore: list of targets to ignore + :return: True if the module is a target and not ignored, False otherwise + """ + return bool( + find_name_or_class_matches(name, module, targets) + and not find_name_or_class_matches(name, module, ignore) + ) + + def find_name_or_class_matches( name: str, module: Module, targets: Iterable[str], check_contains: bool = False ) -> List[str]: diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py index 3a8152da..36b88604 100644 --- a/src/compressed_tensors/quantization/quant_scheme.py +++ b/src/compressed_tensors/quantization/quant_scheme.py @@ -62,6 +62,7 @@ def validate_model_after(model: "QuantizationArgs") -> Dict[str, Any]: return model + """ Pre-Set Quantization Scheme Args """ diff --git a/tests/test_quantization/lifecycle/test_apply.py b/tests/test_quantization/lifecycle/test_apply.py index 7268ca27..7474795b 100644 --- a/tests/test_quantization/lifecycle/test_apply.py +++ b/tests/test_quantization/lifecycle/test_apply.py @@ -14,6 +14,7 @@ import re from typing import Optional +from unittest.mock import MagicMock import pytest import torch @@ -26,12 +27,38 @@ from compressed_tensors.quantization.lifecycle import ( apply_quantization_config, apply_quantization_status, + expand_targets, + is_target, ) from compressed_tensors.quantization.utils import iter_named_leaf_modules from tests.testing_utils import requires_accelerate from transformers import AutoModelForCausalLM +@pytest.fixture +def mock_model(): + model = MagicMock() + model.named_modules.return_value = [ + ("layer1", MagicMock()), + ("layer2", MagicMock()), + ("layer3", MagicMock()), + ] + return model + + +@pytest.fixture +def mock_module(): + return MagicMock() + + +@pytest.fixture +def llama_stories_model(): + return AutoModelForCausalLM.from_pretrained( + "Xenova/llama2.c-stories15M", + torch_dtype="auto", + ) + + def test_target_prioritization(mock_frozen): # tests that the config_groups are applied in the correct order # of priority, where exact layer name > regex > module name @@ -266,3 +293,73 @@ def test_apply_quantization_status(caplog, ignore, should_raise_warning): assert len(caplog.text) > 0 else: assert len(caplog.text) == 0 + + +@pytest.mark.parametrize( + "targets, ignore, expected", + [ + ([], [], set()), + (["layer1", "layer2"], [], {"layer1", "layer2"}), + ([], ["layer1"], set()), + (["layer1", "layer2"], ["layer2"], {"layer1"}), + (["re:layer.*"], ["layer3"], {"layer1", "layer2"}), + ], +) +def test_expand_targets_with_mock(mock_model, targets, ignore, expected): + result = expand_targets(mock_model, targets, ignore) + assert result == expected + + +@pytest.mark.parametrize( + "targets, ignore, expected", + [ + ( + ["re:model.layers.[01].self_attn.q_proj"], + ["re:model.layers.1.self_attn.q_proj"], + set(["model.layers.0.self_attn.q_proj"]), + ), + ( + ["re:model.layers.[01].self_attn.q_proj"], + [], + set(["model.layers.0.self_attn.q_proj", "model.layers.1.self_attn.q_proj"]), + ), + ( + ["re:model.layers.[0-2].self_attn.q_proj"], + ["re:model.layers.1.self_attn.q_proj"], + set(["model.layers.0.self_attn.q_proj", "model.layers.2.self_attn.q_proj"]), + ), + ( + ["model.layers.0.self_attn.q_proj"], + ["model.layers.0.self_attn.q_proj"], + set(), + ), + ( + ["re:model.layers.*.self_attn.q_proj"], + ["re:model.layers.[01].self_attn.q_proj"], + set( + 
f"model.layers.{layer_idx}.self_attn.q_proj" + for layer_idx in range(2, 6) + ), + ), + ], +) +def test_expand_targets_with_llama_stories( + llama_stories_model, targets, ignore, expected +): + actual_targets = expand_targets(llama_stories_model, targets, ignore) + assert actual_targets == expected + + +@pytest.mark.parametrize( + "name, targets, ignore, expected", + [ + ("layer1", ["layer1"], [], True), + ("layer1", ["layer1"], ["layer1"], False), + ("layer1", ["layer2"], [], False), + ("layer1", ["re:layer.*"], [], True), + ("layer1", ["re:layer.*"], ["re:layer1"], False), + ], +) +def test_is_target_with_mock(mock_module, name, targets, ignore, expected): + result = is_target(name, mock_module, targets, ignore) + assert result == expected From 4c21a95d6293bd3083982a14a8f585d17a0d4038 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Wed, 27 Nov 2024 09:06:45 +0000 Subject: [PATCH 02/10] Update: get_nested_weight_mappings to optionally return other params Add: tests for get_nested_weight_mappings Signed-off-by: Rahul Tuli --- .../utils/safetensors_load.py | 58 ++++++++++---- tests/test_utils/test_safetensors_load.py | 76 +++++++++++++++++++ 2 files changed, 121 insertions(+), 13 deletions(-) create mode 100644 tests/test_utils/test_safetensors_load.py diff --git a/src/compressed_tensors/utils/safetensors_load.py b/src/compressed_tensors/utils/safetensors_load.py index 4fdb3007..e4f8d7a7 100644 --- a/src/compressed_tensors/utils/safetensors_load.py +++ b/src/compressed_tensors/utils/safetensors_load.py @@ -16,7 +16,7 @@ import os import re import struct -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple, Union from safetensors import safe_open from torch import Tensor @@ -32,8 +32,12 @@ "get_nested_weight_mappings", "get_quantization_state_dict", "is_quantization_param", + "get_nested_mappings_from_state_dict", ] +WEIGHT_MAPPING_TYPE = Dict[str, str] +NESTED_WEIGHT_MAPPING_TYPE = Dict[str, WEIGHT_MAPPING_TYPE] + def get_safetensors_folder( pretrained_model_name_or_path: str, cache_dir: Optional[str] = None @@ -92,7 +96,7 @@ def get_safetensors_header(safetensors_path: str) -> Dict[str, str]: return header -def match_param_name(full_name: str, param_name: str) -> str: +def match_param_name(full_name: str, param_name: str) -> Optional[str]: """ Helper function extracting the uncompressed parameterized layer name from a compressed name. Assumes the compressed name was merged using merge_names. @@ -176,8 +180,10 @@ def get_weight_mappings(path_to_model_or_tensors: str) -> Dict[str, str]: def get_nested_weight_mappings( - model_path: str, params_to_nest: List[str] -) -> Dict[str, Dict[str, str]]: + model_path: str, params_to_nest: List[str], return_other_params: bool = False +) -> Union[ + NESTED_WEIGHT_MAPPING_TYPE, Tuple[NESTED_WEIGHT_MAPPING_TYPE, WEIGHT_MAPPING_TYPE] +]: """ Takes a path to a state dict saved in safetensors format and returns a nested mapping from uncompressed parameterized layer names to the file locations of each @@ -190,24 +196,36 @@ def get_nested_weight_mappings( compressed: file_location } - This generalizes to cases where the model is split into multiple safetensors files + This generalizes to cases where the model is split into multiple safetensors files. :param model_path: path to safetensors state dict, must contain either a single - safetensors file or multiple files with an index - :return: nested mapping of parameterized layer name to file location + safetensors file or multiple files with an index. 
+ :param params_to_nest: list of parameter names to nest. + :param return_other_params: if True, return a second dictionary containing the + remaining parameters that were not matched to the nested parameters. + :return: nested mapping of parameterized layer name to file location if + return_other_params is False, else a tuple containing the nested mapping + and a mapping of the remaining parameters that were not matched to + the nested parameters. """ weight_mappings = get_weight_mappings(model_path) - nested_weight_mappings = {} - for key in weight_mappings.keys(): + other_params = {} + + for key, file_location in weight_mappings.items(): + matched = False for param_name in params_to_nest: - maybe_match = match_param_name(key, param_name) - if maybe_match is not None: - dense_param = maybe_match + dense_param = match_param_name(key, param_name) + if dense_param: if dense_param not in nested_weight_mappings: nested_weight_mappings[dense_param] = {} - nested_weight_mappings[dense_param][param_name] = weight_mappings[key] + nested_weight_mappings[dense_param][param_name] = file_location + matched = True + if not matched: + other_params[key] = file_location + if return_other_params: + return nested_weight_mappings, other_params return nested_weight_mappings @@ -238,3 +256,17 @@ def is_quantization_param(name: str) -> bool: return True return False + + +def get_nested_mappings_from_state_dict(state_dict, params_to_nest): + nested_weight_mappings = {} + for key in state_dict.keys(): + for param_name in params_to_nest: + maybe_match = match_param_name(key, param_name) + if maybe_match is not None: + dense_param = maybe_match + if dense_param not in nested_weight_mappings: + nested_weight_mappings[dense_param] = {} + nested_weight_mappings[dense_param][param_name] = state_dict[key] + + return nested_weight_mappings diff --git a/tests/test_utils/test_safetensors_load.py b/tests/test_utils/test_safetensors_load.py new file mode 100644 index 00000000..3af2342d --- /dev/null +++ b/tests/test_utils/test_safetensors_load.py @@ -0,0 +1,76 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from unittest.mock import patch + +import pytest +from compressed_tensors.utils.safetensors_load import get_nested_weight_mappings + + +mock_weight_mappings = { + "layer1.weight": "file1", + "layer1.bias": "file2", + "layer2.weight": "file3", + "layer2.bias": "file4", + "layer3.weight": "file5", +} + + +@pytest.fixture +def mock_get_weight_mappings(): + with patch( + "compressed_tensors.utils.safetensors_load.get_weight_mappings", + return_value=mock_weight_mappings, + ): + yield + + +@pytest.mark.usefixtures("mock_get_weight_mappings") +class TestGetNestedWeightMappings: + def test_single_param(self): + params_to_nest = ["weight"] + result = get_nested_weight_mappings("dummy_path", params_to_nest) + expected = { + "layer1": {"weight": "file1"}, + "layer2": {"weight": "file3"}, + "layer3": {"weight": "file5"}, + } + assert result == expected + + def test_multiple_params(self): + params_to_nest = ["weight", "bias"] + result = get_nested_weight_mappings("dummy_path", params_to_nest) + expected = { + "layer1": {"weight": "file1", "bias": "file2"}, + "layer2": {"weight": "file3", "bias": "file4"}, + "layer3": {"weight": "file5"}, + } + assert result == expected + + def test_return_other_params(self): + params_to_nest = ["weight"] + result, other_params = get_nested_weight_mappings( + "dummy_path", params_to_nest, return_other_params=True + ) + expected_nested = { + "layer1": {"weight": "file1"}, + "layer2": {"weight": "file3"}, + "layer3": {"weight": "file5"}, + } + expected_other = { + "layer1.bias": "file2", + "layer2.bias": "file4", + } + assert result == expected_nested + assert other_params == expected_other From a528334096a22c6f20ad21b6877eb2c732097f89 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Wed, 27 Nov 2024 09:48:35 +0000 Subject: [PATCH 03/10] Enable: Sparse Compression with targets and ignores Signed-off-by: Rahul Tuli --- .../model_compressors/model_compressor.py | 11 ++++- .../compressors/sparse_compressors/base.py | 44 ++++++++++++++++--- 2 files changed, 47 insertions(+), 8 deletions(-) diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index 68bd52ec..bc4633d9 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -18,7 +18,7 @@ import os import re from copy import deepcopy -from typing import TYPE_CHECKING, Any, Dict, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Set, TypeVar, Union import compressed_tensors import torch @@ -38,6 +38,7 @@ apply_quantization_config, load_pretrained_quantization, ) +from compressed_tensors.quantization.lifecycle import expand_targets from compressed_tensors.quantization.quant_args import QuantizationArgs from compressed_tensors.quantization.utils import ( is_module_quantized, @@ -282,8 +283,14 @@ def compress( ) if self.sparsity_compressor is not None: + sparse_compression_targets: Set[str] = expand_targets( + model=model, + targets=self.sparsity_config.targets, + ignore=self.sparsity_config.ignore, + ) compressed_state_dict = self.sparsity_compressor.compress( - compressed_state_dict + compressed_state_dict, + compression_targets=sparse_compression_targets, ) # HACK: Override the dtype_byte_size function in transformers to diff --git a/src/compressed_tensors/compressors/sparse_compressors/base.py b/src/compressed_tensors/compressors/sparse_compressors/base.py index 
1b1a6825..67e2727a 100644 --- a/src/compressed_tensors/compressors/sparse_compressors/base.py +++ b/src/compressed_tensors/compressors/sparse_compressors/base.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import Dict, Generator, Tuple +from typing import Dict, Generator, Optional, Set, Tuple from compressed_tensors.compressors.base import BaseCompressor from compressed_tensors.utils import get_nested_weight_mappings, merge_names @@ -30,7 +30,8 @@ class BaseSparseCompressor(BaseCompressor): """ Base class representing a sparse compression algorithm. Each child class should - implement compression_param_info, compress_weight and decompress_weight. + implement compression_param_info, compress_weight and decompress_weight; child + classes should also define COMPRESSION_PARAM_NAMES. Compressors support compressing/decompressing a full module state dict or a single quantized PyTorch leaf module. @@ -59,11 +60,17 @@ class BaseSparseCompressor(BaseCompressor): :param config: config specifying compression parameters """ - def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]: + def compress( + self, + model_state: Dict[str, Tensor], + compression_targets: Optional[Set[str]] = None, + ) -> Dict[str, Tensor]: """ Compresses a dense state dict using bitmask compression :param model_state: state dict of uncompressed model + :param compression_targets: optional set of layer prefixes to compress, if None + compress all layers (for backwards compatibility) :return: compressed state dict """ compressed_dict = {} @@ -71,6 +78,9 @@ def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]: f"Compressing model with {len(model_state)} parameterized layers..." ) for name, value in tqdm(model_state.items(), desc="Compressing model"): + if not self.should_compress(name, compression_targets): + compressed_dict[name] = value + continue compression_data = self.compress_weight(name, value) for key in compression_data.keys(): if key in compressed_dict: @@ -97,8 +107,10 @@ def decompress( :param device: device to load decompressed weights onto :return: iterator for generating decompressed weights """ - weight_mappings = get_nested_weight_mappings( - path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES + weight_mappings, other_params = get_nested_weight_mappings( + path_to_model_or_tensors, + self.COMPRESSION_PARAM_NAMES, + return_other_params=True, ) for weight_name in weight_mappings.keys(): weight_data = {} @@ -107,4 +119,24 @@ def decompress( with safe_open(safe_path, framework="pt", device=device) as f: weight_data[param_name] = f.get_tensor(full_name) decompressed = self.decompress_weight(weight_data) - yield weight_name, decompressed + full_name = merge_names(weight_name, "weight") + yield full_name, decompressed + + for other_name, safe_path in other_params.items(): + with safe_open(safe_path, framework="pt", device=device) as f: + value = f.get_tensor(other_name) + yield other_name, value + + @staticmethod + def should_compress(name: str, targets: Optional[Set[str]] = None) -> bool: + """ + Check if a parameter should be compressed + + :param name: name of the parameter + :param targets: set of layer prefixes to compress + :return: whether or not the parameter should be compressed + """ + if targets is None: + return name.endswith(".weight") + + return name.endswith(".weight") and name[: -(len(".weight"))] in targets From 305904cd8781b66efd8fa3bbddd9ce08eb9d6d55 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Wed, 27 Nov 2024 11:26:07 +0000 
Subject: [PATCH 04/10] Bugfix a typo --- src/compressed_tensors/compressors/sparse_compressors/base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/compressed_tensors/compressors/sparse_compressors/base.py b/src/compressed_tensors/compressors/sparse_compressors/base.py index 67e2727a..d15057ce 100644 --- a/src/compressed_tensors/compressors/sparse_compressors/base.py +++ b/src/compressed_tensors/compressors/sparse_compressors/base.py @@ -119,8 +119,7 @@ def decompress( with safe_open(safe_path, framework="pt", device=device) as f: weight_data[param_name] = f.get_tensor(full_name) decompressed = self.decompress_weight(weight_data) - full_name = merge_names(weight_name, "weight") - yield full_name, decompressed + yield weight_name, decompressed for other_name, safe_path in other_params.items(): with safe_open(safe_path, framework="pt", device=device) as f: From 3a6ccc8839b5cf8336ab334299ca19b1431e3081 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Wed, 27 Nov 2024 13:57:29 +0000 Subject: [PATCH 05/10] Add: Sparse24_compressor + tests --- .../sparse_compressors/__init__.py | 1 + .../sparse_compressors/sparse_24.py | 92 +++++++++++++++++++ src/compressed_tensors/config/__init__.py | 1 + src/compressed_tensors/config/base.py | 1 + src/compressed_tensors/config/sparse_24.py | 37 ++++++++ .../utils/semi_structured_conversions.py | 19 +++- .../test_semi_structured_conversions.py | 66 +++++++++++++ 7 files changed, 213 insertions(+), 4 deletions(-) create mode 100644 src/compressed_tensors/compressors/sparse_compressors/sparse_24.py create mode 100644 src/compressed_tensors/config/sparse_24.py create mode 100644 tests/test_utils/test_semi_structured_conversions.py diff --git a/src/compressed_tensors/compressors/sparse_compressors/__init__.py b/src/compressed_tensors/compressors/sparse_compressors/__init__.py index de4fd887..f1b59ad3 100644 --- a/src/compressed_tensors/compressors/sparse_compressors/__init__.py +++ b/src/compressed_tensors/compressors/sparse_compressors/__init__.py @@ -15,4 +15,5 @@ from .base import * from .dense import * +from .sparse_24 import * from .sparse_bitmask import * diff --git a/src/compressed_tensors/compressors/sparse_compressors/sparse_24.py b/src/compressed_tensors/compressors/sparse_compressors/sparse_24.py new file mode 100644 index 00000000..70974e68 --- /dev/null +++ b/src/compressed_tensors/compressors/sparse_compressors/sparse_24.py @@ -0,0 +1,92 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+
+from typing import Dict
+
+from compressed_tensors.compressors.base import BaseCompressor
+from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor
+from compressed_tensors.config import CompressionFormat, SparsityStructure
+from compressed_tensors.utils import (
+    merge_names,
+    sparse_semi_structured_from_dense_cutlass,
+    sparse_semi_structured_to_dense_cutlass,
+    tensor_follows_mask_structure,
+)
+from torch import Tensor
+
+
+@BaseCompressor.register(name=CompressionFormat.sparse_24.value)
+class Sparse24Compressor(BaseSparseCompressor):
+    """
+    Compresses a with 2:4 sparsity structure for inference
+    with sparse 2:4 kernels for float/float16/bfloat16.
+    https://github.com/pytorch/pytorch/blob/78cf8df4a019e919e8eac5f5d048d8842d4fc692/torch/sparse/semi_structured.py
+    """
+
+    COMPRESSION_PARAM_NAMES = ["sparse_24_packed_weight", "meta"]
+
+    @staticmethod
+    def validate_sparsity_structure(name: str, weight: Tensor) -> bool:
+        """
+        Checks if a tensor fits the required 2:4 sparsity structure
+        :param name: name of the tensor to check
+        :param weight: tensor to check for sparsity structure
+        :return: True if all rows match the 2:4 sparsity structure, raises
+            ValueError otherwise
+        """
+
+        if not tensor_follows_mask_structure(
+            weight, mask=SparsityStructure.TWO_FOUR.value
+        ):
+            raise ValueError(
+                "Sparse24Compressor is only compatible with weights that have "
+                f"a 2:4 sparsity structure. Found segments in {name} "
+                "that do not match the expected structure."
+            )
+
+        return True
+
+    def compress_weight(self, name: str, value: Tensor) -> Dict[str, Tensor]:
+        """
+        Compresses a given weight with 2:4 sparsity structure.
+        :param name: name of the tensor in state dict of uncompressed model
+        :param value: 2:4 sparse tensor to compress
+        :return: dictionary containing the compressed weight and associated
+            metadata
+        """
+        weight_suffix = ".weight"
+        if not name.endswith(weight_suffix):
+            return {}
+
+        prefix = name[: -len(weight_suffix)]
+        self.validate_sparsity_structure(name=prefix, weight=value)
+        sparse_24_packed_weight, meta = sparse_semi_structured_from_dense_cutlass(
+            dense=value
+        )
+        return {
+            merge_names(name, "sparse_24_packed_weight"): sparse_24_packed_weight.cpu(),
+            merge_names(name, "meta"): meta.cpu(),
+        }
+
+    def decompress_weight(self, weight_data):
+        assert (
+            "sparse_24_packed_weight" in weight_data
+        ), "sparse_24_packed_weight not found in weight_data"
+        assert "meta" in weight_data, "meta not found in weight_data"
+
+        return sparse_semi_structured_to_dense_cutlass(
+            sparse=weight_data["sparse_24_packed_weight"],
+            meta_reordered=weight_data["meta"],
+        )
diff --git a/src/compressed_tensors/config/__init__.py b/src/compressed_tensors/config/__init__.py
index ff83f5af..f021f284 100644
--- a/src/compressed_tensors/config/__init__.py
+++ b/src/compressed_tensors/config/__init__.py
@@ -15,4 +15,5 @@
 # flake8: noqa
 from .base import *
 from .dense import *
+from .sparse_24 import *
 from .sparse_bitmask import *
diff --git a/src/compressed_tensors/config/base.py b/src/compressed_tensors/config/base.py
index 79a4fcdd..2d280330 100644
--- a/src/compressed_tensors/config/base.py
+++ b/src/compressed_tensors/config/base.py
@@ -26,6 +26,7 @@ class CompressionFormat(Enum):
     dense = "dense"
     sparse_bitmask = "sparse-bitmask"
+    sparse_24 = "sparse-24"
     int_quantized = "int-quantized"
     float_quantized = "float-quantized"
     naive_quantized = "naive-quantized"
diff --git a/src/compressed_tensors/config/sparse_24.py b/src/compressed_tensors/config/sparse_24.py
new file mode 100644 index 00000000..2a5ed384 --- /dev/null +++ b/src/compressed_tensors/config/sparse_24.py @@ -0,0 +1,37 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +from compressed_tensors.config import ( + CompressionFormat, + SparsityCompressionConfig, + SparsityStructure, +) + + +__all__ = ["Sparse24Config"] + + +@SparsityCompressionConfig.register(name=CompressionFormat.sparse_24.value) +class Sparse24Config(SparsityCompressionConfig): + """ + Configuration for storing a sparse model using 2:4 compression + :param global_sparsity: average sparsity of the entire model + :param sparsity_structure: structure of the sparsity, "2:4" + """ + + format: str = CompressionFormat.sparse_24.value + global_sparsity: Optional[float] = 0.0 + sparsity_structure: Optional[str] = SparsityStructure.TWO_FOUR.value diff --git a/src/compressed_tensors/utils/semi_structured_conversions.py b/src/compressed_tensors/utils/semi_structured_conversions.py index ef318a48..480d1b48 100644 --- a/src/compressed_tensors/utils/semi_structured_conversions.py +++ b/src/compressed_tensors/utils/semi_structured_conversions.py @@ -75,6 +75,7 @@ def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device # This function converts dense matrix into sparse semi-structured # representation, producing "compressed" matrix, in the layout used by # CUTLASS backend, and corresponding metadata matrix. +# Modified from https://github.com/pytorch/pytorch/blob/78cf8df4a019e919e8eac5f5d048d8842d4fc692/torch/sparse/_semi_structured_conversions.py#L47 def sparse_semi_structured_from_dense_cutlass(dense): if dense.dim() != 2: raise RuntimeError( @@ -85,7 +86,7 @@ def sparse_semi_structured_from_dense_cutlass(dense): device = dense.device meta_dtype = torch.int8 - if dense.dtype == torch.int8: + if dense.dtype == torch.int8 or dense.dtype == torch.float8_e4m3fn: meta_dtype = torch.int32 elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]: meta_dtype = torch.int16 @@ -165,11 +166,15 @@ def sparse_semi_structured_from_dense_cutlass(dense): idxs1 = bit2 | (bit3.to(torch.int64) << 1) if dense.dtype != torch.float: + if dense.dtype == torch.float8_e4m3fn: + dense_4 = dense_4.view(torch.int8) sparse0 = dense_4.gather( -1, idxs0.unsqueeze(-1) ) # type: ignore[possibly-undefined] sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) + if dense.dtype == torch.float8_e4m3fn: + sparse = sparse.view(torch.float8_e4m3fn) else: sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view( m, k // 2 @@ -213,6 +218,7 @@ def sparse_semi_structured_from_dense_cutlass(dense): # reconstructs dense matrix from a pair of "compressed" matrix, given # in the layout used by CUTLASS backend, and accompanying metadata # matrix. 
+# Copied from https://github.com/pytorch/pytorch/blob/78cf8df4a019e919e8eac5f5d048d8842d4fc692/torch/sparse/_semi_structured_conversions.py#L180 def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): if sparse.dim() != 2: raise RuntimeError( @@ -298,16 +304,21 @@ def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): torch.arange(0, 2 * m * k // ksparse, device=device) * 4 ).view(-1, 1).repeat(1, 2).view(-1) - dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device) + sparse_dtype = sparse.dtype if sparse.dtype != torch.float8_e4m3fn else torch.int8 + dense = torch.zeros((m * 2 * k,), dtype=sparse_dtype, device=device) if sparse.dtype != torch.float: # dense.scatter_(0, dense_offsets, sparse.view(-1)) - dense.scatter_(0, dense_offsets, sparse.reshape(-1)) + if sparse.dtype == torch.float8_e4m3fn: + dense.scatter_(0, dense_offsets, sparse.view(torch.int8).view(-1)) + else: + dense.scatter_(0, dense_offsets, sparse.reshape(-1)) else: dense.view(torch.half).scatter_( 0, dense_offsets, sparse.view(torch.half).view(-1) ) - return dense.view(m, 2 * k) + result = dense.view(m, 2 * k) + return result.view(sparse.dtype) def mask_creator(tensor): diff --git a/tests/test_utils/test_semi_structured_conversions.py b/tests/test_utils/test_semi_structured_conversions.py new file mode 100644 index 00000000..e74722fb --- /dev/null +++ b/tests/test_utils/test_semi_structured_conversions.py @@ -0,0 +1,66 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest +import torch +from compressed_tensors.utils.semi_structured_conversions import ( + sparse_semi_structured_from_dense_cutlass, + sparse_semi_structured_to_dense_cutlass, +) + + +def supported_dtypes(): + return [torch.int8, torch.float16, torch.bfloat16, torch.float8_e4m3fn] + + +def get_random_mat(M, K, dtype): + rand_tensor_dtype = dtype + if dtype in [torch.int8, torch.float8_e4m3fn]: + rand_tensor_dtype = torch.float16 + mat = torch.rand(M, K, dtype=rand_tensor_dtype).cuda() + mat = mat.masked_fill_(mat == 0, 1) + return mat.to(dtype) + + +def generate_pruned_semi_structured_mat(M, K, dtype): + mask = torch.Tensor([0, 0, 1, 1]).tile((M, K // 4)).bool() + rand_tensor_dtype = dtype + if dtype in [torch.int8, torch.float8_e4m3fn]: + rand_tensor_dtype = torch.float16 + mat = torch.rand(M, K, dtype=rand_tensor_dtype) + mat = mat.masked_fill_(mat == 0, 1) + if dtype == torch.float8_e4m3fn: + # some float8_e4m3fn operations are not supported on CPU + mat = mat.cuda() + mask = mask.cuda() + mat = mat * mask + return mat.to(dtype) + + +@pytest.mark.parametrize("dtype", supported_dtypes()) +def test_inverse_property_from_dense_then_to_dense(dtype): + M, K = 1024, 1024 + dense_matrix = generate_pruned_semi_structured_mat(M, K, dtype) + compressed_matrix, meta = sparse_semi_structured_from_dense_cutlass(dense_matrix) + result = sparse_semi_structured_to_dense_cutlass(compressed_matrix, meta) + + assert ( + dense_matrix.dtype == result.dtype + ), f"Dtype Mis-match: {dense_matrix.dtype} and {result.dtype}" + assert ( + dense_matrix.shape == result.shape + ), f"Shape Mis-match: {dense_matrix.shape} and {result.shape}" + assert torch.equal( + dense_matrix, result + ), f"Failed for dtype: {dense_matrix.dtype} and input: {dense_matrix}" From 8fd469f0f317877e636600eb5b01eee1d7bfef43 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Wed, 27 Nov 2024 14:13:47 +0000 Subject: [PATCH 06/10] Run float8 test only if cuda is available and device capability is greater than 90 --- tests/test_utils/test_semi_structured_conversions.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/test_utils/test_semi_structured_conversions.py b/tests/test_utils/test_semi_structured_conversions.py index e74722fb..eb25b34a 100644 --- a/tests/test_utils/test_semi_structured_conversions.py +++ b/tests/test_utils/test_semi_structured_conversions.py @@ -21,7 +21,12 @@ def supported_dtypes(): - return [torch.int8, torch.float16, torch.bfloat16, torch.float8_e4m3fn] + dtypes = [torch.int8, torch.float16, torch.bfloat16] + if torch.cuda.is_available(): + major, minor = torch.cuda.get_device_capability() + if major > 9 or (major == 9 and minor >= 0): + dtypes += [torch.float8_e4m3fn] + return dtypes def get_random_mat(M, K, dtype): From c54699aad3eb9c5cf4506eac0a46b6566e91ce58 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Wed, 27 Nov 2024 18:09:02 +0000 Subject: [PATCH 07/10] Address: Review comments from @kylesayrs --- .../compressors/sparse_compressors/base.py | 10 ++++---- .../utils/safetensors_load.py | 23 +++---------------- .../test_quantization/lifecycle/test_apply.py | 16 ++++++------- 3 files changed, 17 insertions(+), 32 deletions(-) diff --git a/src/compressed_tensors/compressors/sparse_compressors/base.py b/src/compressed_tensors/compressors/sparse_compressors/base.py index d15057ce..12b43287 100644 --- a/src/compressed_tensors/compressors/sparse_compressors/base.py +++ b/src/compressed_tensors/compressors/sparse_compressors/base.py @@ -127,15 +127,17 @@ def decompress( yield 
other_name, value @staticmethod - def should_compress(name: str, targets: Optional[Set[str]] = None) -> bool: + def should_compress(name: str, expanded_targets: Optional[Set[str]] = None) -> bool: """ Check if a parameter should be compressed :param name: name of the parameter - :param targets: set of layer prefixes to compress + :param expanded_targets: set of layer prefixes to compress :return: whether or not the parameter should be compressed """ - if targets is None: + if expanded_targets is None: return name.endswith(".weight") - return name.endswith(".weight") and name[: -(len(".weight"))] in targets + return ( + name.endswith(".weight") and name[: -(len(".weight"))] in expanded_targets + ) diff --git a/src/compressed_tensors/utils/safetensors_load.py b/src/compressed_tensors/utils/safetensors_load.py index e4f8d7a7..11f0f326 100644 --- a/src/compressed_tensors/utils/safetensors_load.py +++ b/src/compressed_tensors/utils/safetensors_load.py @@ -32,11 +32,10 @@ "get_nested_weight_mappings", "get_quantization_state_dict", "is_quantization_param", - "get_nested_mappings_from_state_dict", ] -WEIGHT_MAPPING_TYPE = Dict[str, str] -NESTED_WEIGHT_MAPPING_TYPE = Dict[str, WEIGHT_MAPPING_TYPE] +WeightMappingType = Dict[str, str] +NestedWeightMappingType = Dict[str, WeightMappingType] def get_safetensors_folder( @@ -181,9 +180,7 @@ def get_weight_mappings(path_to_model_or_tensors: str) -> Dict[str, str]: def get_nested_weight_mappings( model_path: str, params_to_nest: List[str], return_other_params: bool = False -) -> Union[ - NESTED_WEIGHT_MAPPING_TYPE, Tuple[NESTED_WEIGHT_MAPPING_TYPE, WEIGHT_MAPPING_TYPE] -]: +) -> Union[NestedWeightMappingType, Tuple[NestedWeightMappingType, WeightMappingType]]: """ Takes a path to a state dict saved in safetensors format and returns a nested mapping from uncompressed parameterized layer names to the file locations of each @@ -256,17 +253,3 @@ def is_quantization_param(name: str) -> bool: return True return False - - -def get_nested_mappings_from_state_dict(state_dict, params_to_nest): - nested_weight_mappings = {} - for key in state_dict.keys(): - for param_name in params_to_nest: - maybe_match = match_param_name(key, param_name) - if maybe_match is not None: - dense_param = maybe_match - if dense_param not in nested_weight_mappings: - nested_weight_mappings[dense_param] = {} - nested_weight_mappings[dense_param][param_name] = state_dict[key] - - return nested_weight_mappings diff --git a/tests/test_quantization/lifecycle/test_apply.py b/tests/test_quantization/lifecycle/test_apply.py index 7474795b..d1799a39 100644 --- a/tests/test_quantization/lifecycle/test_apply.py +++ b/tests/test_quantization/lifecycle/test_apply.py @@ -296,7 +296,7 @@ def test_apply_quantization_status(caplog, ignore, should_raise_warning): @pytest.mark.parametrize( - "targets, ignore, expected", + "targets, ignore, expected_targets", [ ([], [], set()), (["layer1", "layer2"], [], {"layer1", "layer2"}), @@ -305,13 +305,13 @@ def test_apply_quantization_status(caplog, ignore, should_raise_warning): (["re:layer.*"], ["layer3"], {"layer1", "layer2"}), ], ) -def test_expand_targets_with_mock(mock_model, targets, ignore, expected): - result = expand_targets(mock_model, targets, ignore) - assert result == expected +def test_expand_targets_with_mock(mock_model, targets, ignore, expected_targets): + expanded_targets = expand_targets(mock_model, targets, ignore) + assert expanded_targets == expected_targets @pytest.mark.parametrize( - "targets, ignore, expected", + "targets, ignore, 
expected_targets", [ ( ["re:model.layers.[01].self_attn.q_proj"], @@ -344,10 +344,10 @@ def test_expand_targets_with_mock(mock_model, targets, ignore, expected): ], ) def test_expand_targets_with_llama_stories( - llama_stories_model, targets, ignore, expected + llama_stories_model, targets, ignore, expected_targets ): - actual_targets = expand_targets(llama_stories_model, targets, ignore) - assert actual_targets == expected + expanded_targets = expand_targets(llama_stories_model, targets, ignore) + assert expanded_targets == expected_targets @pytest.mark.parametrize( From 69361210519daad9864ebe662a695065c77e0818 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Tue, 3 Dec 2024 07:43:46 +0000 Subject: [PATCH 08/10] review suggestions from @dsikka --- .../model_compressors/model_compressor.py | 10 ++-- .../compressors/sparse_compressors/base.py | 10 ++-- .../quantization/lifecycle/apply.py | 6 +- .../utils/safetensors_load.py | 57 ++++++++++++------- .../test_quantization/lifecycle/test_apply.py | 6 +- tests/test_utils/test_safetensors_load.py | 2 +- 6 files changed, 55 insertions(+), 36 deletions(-) diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index bc4633d9..f9197bf3 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -38,7 +38,7 @@ apply_quantization_config, load_pretrained_quantization, ) -from compressed_tensors.quantization.lifecycle import expand_targets +from compressed_tensors.quantization.lifecycle import expand_sparse_target_names from compressed_tensors.quantization.quant_args import QuantizationArgs from compressed_tensors.quantization.utils import ( is_module_quantized, @@ -269,9 +269,9 @@ def compress( compressed_state_dict = state_dict - quantized_modules_to_args: Dict[ - str, QuantizationArgs - ] = map_modules_to_quant_args(model) + quantized_modules_to_args: Dict[str, QuantizationArgs] = ( + map_modules_to_quant_args(model) + ) if self.quantization_compressor is not None: compressed_state_dict = self.quantization_compressor.compress( @@ -283,7 +283,7 @@ def compress( ) if self.sparsity_compressor is not None: - sparse_compression_targets: Set[str] = expand_targets( + sparse_compression_targets: Set[str] = expand_sparse_target_names( model=model, targets=self.sparsity_config.targets, ignore=self.sparsity_config.ignore, diff --git a/src/compressed_tensors/compressors/sparse_compressors/base.py b/src/compressed_tensors/compressors/sparse_compressors/base.py index 12b43287..c6b715de 100644 --- a/src/compressed_tensors/compressors/sparse_compressors/base.py +++ b/src/compressed_tensors/compressors/sparse_compressors/base.py @@ -107,10 +107,10 @@ def decompress( :param device: device to load decompressed weights onto :return: iterator for generating decompressed weights """ - weight_mappings, other_params = get_nested_weight_mappings( + weight_mappings, uncompressed_params = get_nested_weight_mappings( path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES, - return_other_params=True, + return_unmatched_params=True, ) for weight_name in weight_mappings.keys(): weight_data = {} @@ -121,10 +121,10 @@ def decompress( decompressed = self.decompress_weight(weight_data) yield weight_name, decompressed - for other_name, safe_path in other_params.items(): + for uncompressed_param_name, safe_path in uncompressed_params.items(): with safe_open(safe_path, 
framework="pt", device=device) as f: - value = f.get_tensor(other_name) - yield other_name, value + value = f.get_tensor(uncompressed_param_name) + yield uncompressed_param_name, value @staticmethod def should_compress(name: str, expanded_targets: Optional[Set[str]] = None) -> bool: diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index 800be585..c297f8c7 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -52,7 +52,7 @@ "apply_quantization_config", "apply_quantization_status", "find_name_or_class_matches", - "expand_targets", + "expand_sparse_target_names", "is_target", ] @@ -247,11 +247,11 @@ def apply_quantization_status(model: Module, status: QuantizationStatus): model.apply(compress_quantized_weights) -def expand_targets( +def expand_sparse_target_names( model: Module, targets: Iterable[str], ignore: Iterable[str] ) -> Set[str]: """ - Finds all the targets in the model that match the given + Finds all unique module names in the model that match the given targets and ignore lists. Note: Targets must be regexes, layer types, or full layer names. diff --git a/src/compressed_tensors/utils/safetensors_load.py b/src/compressed_tensors/utils/safetensors_load.py index 11f0f326..bc58f351 100644 --- a/src/compressed_tensors/utils/safetensors_load.py +++ b/src/compressed_tensors/utils/safetensors_load.py @@ -179,13 +179,14 @@ def get_weight_mappings(path_to_model_or_tensors: str) -> Dict[str, str]: def get_nested_weight_mappings( - model_path: str, params_to_nest: List[str], return_other_params: bool = False + model_path: str, params_to_nest: List[str], return_unmatched_params: bool = False ) -> Union[NestedWeightMappingType, Tuple[NestedWeightMappingType, WeightMappingType]]: """ Takes a path to a state dict saved in safetensors format and returns a nested - mapping from uncompressed parameterized layer names to the file locations of each - of the layers compression parameters. + mapping from uncompressed parameterized layer names to the file locations of + each layer's compression parameters. + Example of the nested mapping: layer.weight: { bitmask: file_location, row_offsets: file_location, @@ -193,21 +194,39 @@ def get_nested_weight_mappings( compressed: file_location } - This generalizes to cases where the model is split into multiple safetensors files. - - :param model_path: path to safetensors state dict, must contain either a single - safetensors file or multiple files with an index. - :param params_to_nest: list of parameter names to nest. - :param return_other_params: if True, return a second dictionary containing the - remaining parameters that were not matched to the nested parameters. - :return: nested mapping of parameterized layer name to file location if - return_other_params is False, else a tuple containing the nested mapping - and a mapping of the remaining parameters that were not matched to - the nested parameters. + If other parameters are found that do not match the nested parameters, they will + be returned in a separate dictionary only if return_unmatched_params is True. + This dictionary may be needed for cases where compressors are stacked (e.g., + quantization compression followed by sparse compression). + + Example of the unmatched params mapping: + { + layer.weight_scale: file_location, + layer.input_scale: file_location + } + + This generalizes to cases where the model is split into multiple safetensors + files. 
+ + :param model_path: Path to the safetensors state dict, must contain either a + single safetensors file or multiple files with an index. + :param params_to_nest: List of parameter names to nest. + :param return_unmatched_params: If True, return a second dictionary containing + the remaining parameters that were not matched to the params_to_nest. + :return: + - If return_unmatched_params is False: + NestedWeightMappingType: A nested mapping of parameterized layer names to + file locations of each layer's compression parameters. + - If return_unmatched_params is True: + Tuple[NestedWeightMappingType, WeightMappingType]: A tuple containing: + - NestedWeightMappingType: A nested mapping of parameterized layer + names to file locations of each layer's compression parameters. + - WeightMappingType: A mapping of the remaining parameter names to + their file locations that were not matched to the params_to_nest. """ weight_mappings = get_weight_mappings(model_path) nested_weight_mappings = {} - other_params = {} + unmatched_params = {} for key, file_location in weight_mappings.items(): matched = False @@ -218,11 +237,11 @@ def get_nested_weight_mappings( nested_weight_mappings[dense_param] = {} nested_weight_mappings[dense_param][param_name] = file_location matched = True - if not matched: - other_params[key] = file_location + if return_unmatched_params and not matched: + unmatched_params[key] = file_location - if return_other_params: - return nested_weight_mappings, other_params + if return_unmatched_params: + return nested_weight_mappings, unmatched_params return nested_weight_mappings diff --git a/tests/test_quantization/lifecycle/test_apply.py b/tests/test_quantization/lifecycle/test_apply.py index d1799a39..e2dc2ad1 100644 --- a/tests/test_quantization/lifecycle/test_apply.py +++ b/tests/test_quantization/lifecycle/test_apply.py @@ -27,7 +27,7 @@ from compressed_tensors.quantization.lifecycle import ( apply_quantization_config, apply_quantization_status, - expand_targets, + expand_sparse_target_names, is_target, ) from compressed_tensors.quantization.utils import iter_named_leaf_modules @@ -306,7 +306,7 @@ def test_apply_quantization_status(caplog, ignore, should_raise_warning): ], ) def test_expand_targets_with_mock(mock_model, targets, ignore, expected_targets): - expanded_targets = expand_targets(mock_model, targets, ignore) + expanded_targets = expand_sparse_target_names(mock_model, targets, ignore) assert expanded_targets == expected_targets @@ -346,7 +346,7 @@ def test_expand_targets_with_mock(mock_model, targets, ignore, expected_targets) def test_expand_targets_with_llama_stories( llama_stories_model, targets, ignore, expected_targets ): - expanded_targets = expand_targets(llama_stories_model, targets, ignore) + expanded_targets = expand_sparse_target_names(llama_stories_model, targets, ignore) assert expanded_targets == expected_targets diff --git a/tests/test_utils/test_safetensors_load.py b/tests/test_utils/test_safetensors_load.py index 3af2342d..95cd6347 100644 --- a/tests/test_utils/test_safetensors_load.py +++ b/tests/test_utils/test_safetensors_load.py @@ -61,7 +61,7 @@ def test_multiple_params(self): def test_return_other_params(self): params_to_nest = ["weight"] result, other_params = get_nested_weight_mappings( - "dummy_path", params_to_nest, return_other_params=True + "dummy_path", params_to_nest, return_unmatched_params=True ) expected_nested = { "layer1": {"weight": "file1"}, From b07961c1e97a84a1dc52e15ad6eea1984a9ade62 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: 
Tue, 3 Dec 2024 08:12:57 +0000
Subject: [PATCH 09/10] Review comments from @dsikka and @kylesayrs

---
 .../compressors/sparse_compressors/sparse_24.py | 16 +++++++++++++++-
 .../utils/semi_structured_conversions.py | 16 ++++++++--------
 .../test_semi_structured_conversions.py | 9 +++++----
 3 files changed, 28 insertions(+), 13 deletions(-)

diff --git a/src/compressed_tensors/compressors/sparse_compressors/sparse_24.py b/src/compressed_tensors/compressors/sparse_compressors/sparse_24.py
index 70974e68..e2219a8d 100644
--- a/src/compressed_tensors/compressors/sparse_compressors/sparse_24.py
+++ b/src/compressed_tensors/compressors/sparse_compressors/sparse_24.py
@@ -30,7 +30,7 @@
 @BaseCompressor.register(name=CompressionFormat.sparse_24.value)
 class Sparse24Compressor(BaseSparseCompressor):
     """
-    Compresses a with 2:4 sparsity structure for inference
+    Compresses a model with 2:4 sparsity structure for inference
     with sparse 2:4 kernels for float/float16/bfloat16.
     https://github.com/pytorch/pytorch/blob/78cf8df4a019e919e8eac5f5d048d8842d4fc692/torch/sparse/semi_structured.py
     """
@@ -81,6 +81,20 @@ def compress_weight(self, name: str, value: Tensor) -> Dict[str, Tensor]:
         }
 
     def decompress_weight(self, weight_data):
+        """
+        Decompresses the given weight data from its compressed representation to its
+        dense form.
+
+        The weight_data dictionary must contain the keys 'sparse_24_packed_weight' and
+        'meta', which represent the sparse-compressed weight and its associated meta
+        tensor.
+
+        :param weight_data: A dictionary containing:
+            - sparse_24_packed_weight: The sparse-compressed representation of
+                the weight.
+            - meta: The meta tensor associated with the compressed weight.
+        :return: The dense representation of the weight.
+        """
         assert (
             "sparse_24_packed_weight" in weight_data
         ), "sparse_24_packed_weight not found in weight_data"
diff --git a/src/compressed_tensors/utils/semi_structured_conversions.py b/src/compressed_tensors/utils/semi_structured_conversions.py
index 480d1b48..17ea6ef2 100644
--- a/src/compressed_tensors/utils/semi_structured_conversions.py
+++ b/src/compressed_tensors/utils/semi_structured_conversions.py
@@ -20,7 +20,7 @@
 # limitations under the License.
import torch - +from compressed_tensors.quantization import FP8_DTYPE __all__ = [ "sparse_semi_structured_from_dense_cutlass", @@ -85,8 +85,8 @@ def sparse_semi_structured_from_dense_cutlass(dense): m, k = dense.shape device = dense.device - meta_dtype = torch.int8 - if dense.dtype == torch.int8 or dense.dtype == torch.float8_e4m3fn: + meta_dtype = None + if dense.dtype == torch.int8 or dense.dtype == FP8_DTYPE: meta_dtype = torch.int32 elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]: meta_dtype = torch.int16 @@ -166,15 +166,15 @@ def sparse_semi_structured_from_dense_cutlass(dense): idxs1 = bit2 | (bit3.to(torch.int64) << 1) if dense.dtype != torch.float: - if dense.dtype == torch.float8_e4m3fn: + if dense.dtype == FP8_DTYPE: dense_4 = dense_4.view(torch.int8) sparse0 = dense_4.gather( -1, idxs0.unsqueeze(-1) ) # type: ignore[possibly-undefined] sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) - if dense.dtype == torch.float8_e4m3fn: - sparse = sparse.view(torch.float8_e4m3fn) + if dense.dtype == FP8_DTYPE: + sparse = sparse.view(FP8_DTYPE) else: sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view( m, k // 2 @@ -304,11 +304,11 @@ def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): torch.arange(0, 2 * m * k // ksparse, device=device) * 4 ).view(-1, 1).repeat(1, 2).view(-1) - sparse_dtype = sparse.dtype if sparse.dtype != torch.float8_e4m3fn else torch.int8 + sparse_dtype = sparse.dtype if sparse.dtype != FP8_DTYPE else torch.int8 dense = torch.zeros((m * 2 * k,), dtype=sparse_dtype, device=device) if sparse.dtype != torch.float: # dense.scatter_(0, dense_offsets, sparse.view(-1)) - if sparse.dtype == torch.float8_e4m3fn: + if sparse.dtype == FP8_DTYPE: dense.scatter_(0, dense_offsets, sparse.view(torch.int8).view(-1)) else: dense.scatter_(0, dense_offsets, sparse.reshape(-1)) diff --git a/tests/test_utils/test_semi_structured_conversions.py b/tests/test_utils/test_semi_structured_conversions.py index eb25b34a..c2c198c6 100644 --- a/tests/test_utils/test_semi_structured_conversions.py +++ b/tests/test_utils/test_semi_structured_conversions.py @@ -14,6 +14,7 @@ import pytest import torch +from compressed_tensors.quantization import FP8_DTYPE from compressed_tensors.utils.semi_structured_conversions import ( sparse_semi_structured_from_dense_cutlass, sparse_semi_structured_to_dense_cutlass, @@ -25,13 +26,13 @@ def supported_dtypes(): if torch.cuda.is_available(): major, minor = torch.cuda.get_device_capability() if major > 9 or (major == 9 and minor >= 0): - dtypes += [torch.float8_e4m3fn] + dtypes += [FP8_DTYPE] return dtypes def get_random_mat(M, K, dtype): rand_tensor_dtype = dtype - if dtype in [torch.int8, torch.float8_e4m3fn]: + if dtype in [torch.int8, FP8_DTYPE]: rand_tensor_dtype = torch.float16 mat = torch.rand(M, K, dtype=rand_tensor_dtype).cuda() mat = mat.masked_fill_(mat == 0, 1) @@ -41,11 +42,11 @@ def get_random_mat(M, K, dtype): def generate_pruned_semi_structured_mat(M, K, dtype): mask = torch.Tensor([0, 0, 1, 1]).tile((M, K // 4)).bool() rand_tensor_dtype = dtype - if dtype in [torch.int8, torch.float8_e4m3fn]: + if dtype in [torch.int8, FP8_DTYPE]: rand_tensor_dtype = torch.float16 mat = torch.rand(M, K, dtype=rand_tensor_dtype) mat = mat.masked_fill_(mat == 0, 1) - if dtype == torch.float8_e4m3fn: + if dtype == FP8_DTYPE: # some float8_e4m3fn operations are not supported on CPU mat = mat.cuda() mask = mask.cuda() From 
c6ef4f96d0db7ba4cb5c2a7a17a008d4dc25bbe9 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Tue, 3 Dec 2024 16:54:01 +0000 Subject: [PATCH 10/10] remove extra .weight --- .../compressors/sparse_compressors/sparse_24.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/compressed_tensors/compressors/sparse_compressors/sparse_24.py b/src/compressed_tensors/compressors/sparse_compressors/sparse_24.py index e2219a8d..040daa50 100644 --- a/src/compressed_tensors/compressors/sparse_compressors/sparse_24.py +++ b/src/compressed_tensors/compressors/sparse_compressors/sparse_24.py @@ -76,8 +76,10 @@ def compress_weight(self, name: str, value: Tensor) -> Dict[str, Tensor]: dense=value ) return { - merge_names(name, "sparse_24_packed_weight"): sparse_24_packed_weight.cpu(), - merge_names(name, "meta"): meta.cpu(), + merge_names( + prefix, "sparse_24_packed_weight" + ): sparse_24_packed_weight.cpu(), + merge_names(prefix, "meta"): meta.cpu(), } def decompress_weight(self, weight_data):
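
Usage sketch: taken together, patches 01, 03, 08, and 10 let a caller resolve a sparsity config's targets/ignore lists into concrete module names and then sparse-compress only those modules, carrying everything else through untouched. The snippet below is a minimal illustration, not code from the series: load_pruned_model is a hypothetical placeholder, the target list is illustrative, and it assumes Sparse24Config() validates with its defaults and that every matched weight already follows a 2:4 pattern (otherwise compress_weight raises a ValueError).

    from compressed_tensors.compressors.sparse_compressors.sparse_24 import (
        Sparse24Compressor,
    )
    from compressed_tensors.config.sparse_24 import Sparse24Config
    from compressed_tensors.quantization.lifecycle import expand_sparse_target_names

    model = load_pruned_model()  # hypothetical: a 2:4-pruned torch.nn.Module

    # Targets may be layer types ("Linear"), regexes ("re:..."), or full
    # layer names; ignore entries are matched the same way (patches 01/08).
    sparse_targets = expand_sparse_target_names(
        model=model, targets=["Linear"], ignore=["lm_head"]
    )

    # Compress only the matched modules; every unmatched parameter is
    # copied into the compressed state dict unchanged (patch 03).
    compressor = Sparse24Compressor(config=Sparse24Config())
    compressed = compressor.compress(
        model.state_dict(), compression_targets=sparse_targets
    )
    # After patch 10, each targeted "<prefix>.weight" is stored as
    # "<prefix>.sparse_24_packed_weight" plus "<prefix>.meta".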
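
The packing helpers can also be exercised directly. This mirrors the round-trip property the new test checks, here on CPU with float16 (per patch 06, the float8_e4m3fn path additionally needs CUDA and sufficient device capability); the shapes and the fixed [0, 0, 1, 1] mask are taken from the test.

    import torch
    from compressed_tensors.utils.semi_structured_conversions import (
        sparse_semi_structured_from_dense_cutlass,
        sparse_semi_structured_to_dense_cutlass,
    )

    # 1024x1024 fp16 matrix where positions 2 and 3 of every contiguous
    # group of four survive, i.e. a fixed 2:4 sparsity pattern.
    mask = torch.tensor([0, 0, 1, 1], dtype=torch.bool).tile((1024, 256))
    mat = torch.rand(1024, 1024, dtype=torch.float16)
    mat = mat.masked_fill_(mat == 0, 1)  # keep surviving entries nonzero
    dense = mat * mask

    packed, meta = sparse_semi_structured_from_dense_cutlass(dense)
    assert packed.shape == (1024, 512)  # two values kept per group of four
    restored = sparse_semi_structured_to_dense_cutlass(packed, meta)
    assert torch.equal(dense, restored)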
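
On the decompression side, the flag added to get_nested_weight_mappings in patch 02 (as return_other_params, renamed to return_unmatched_params in patch 08) is what lets stacked compressors recover parameters, such as quantization scales, that no nested parameter name matches. A small self-contained illustration, patterned on the new test's use of a mocked get_weight_mappings; the flat mapping and file names are invented for the example.

    from unittest.mock import patch

    from compressed_tensors.utils.safetensors_load import get_nested_weight_mappings

    # Illustrative flat name -> file mapping, in the shape that
    # get_weight_mappings produces for a safetensors checkpoint.
    flat = {"layer1.weight": "file1", "layer1.weight_scale": "file2"}

    with patch(
        "compressed_tensors.utils.safetensors_load.get_weight_mappings",
        return_value=flat,
    ):
        nested, unmatched = get_nested_weight_mappings(
            "dummy_path", ["weight"], return_unmatched_params=True
        )

    assert nested == {"layer1": {"weight": "file1"}}
    assert unmatched == {"layer1.weight_scale": "file2"}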