Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add 24 compressor #167

Closed
wants to merge 11 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import os
import re
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, Optional, TypeVar, Union
from typing import TYPE_CHECKING, Any, Dict, Optional, Set, TypeVar, Union

import compressed_tensors
import torch
Expand All @@ -38,6 +38,7 @@
apply_quantization_config,
load_pretrained_quantization,
)
from compressed_tensors.quantization.lifecycle import expand_targets
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.utils import (
is_module_quantized,
Expand Down Expand Up @@ -282,8 +283,14 @@ def compress(
)

if self.sparsity_compressor is not None:
sparse_compression_targets: Set[str] = expand_targets(
model=model,
targets=self.sparsity_config.targets,
ignore=self.sparsity_config.ignore,
)
compressed_state_dict = self.sparsity_compressor.compress(
compressed_state_dict
compressed_state_dict,
compression_targets=sparse_compression_targets,
)

# HACK: Override the dtype_byte_size function in transformers to
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@

from .base import *
from .dense import *
from .sparse_24 import *
from .sparse_bitmask import *
41 changes: 36 additions & 5 deletions src/compressed_tensors/compressors/sparse_compressors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import logging
from typing import Dict, Generator, Tuple
from typing import Dict, Generator, Optional, Set, Tuple

from compressed_tensors.compressors.base import BaseCompressor
from compressed_tensors.utils import get_nested_weight_mappings, merge_names
Expand All @@ -30,7 +30,8 @@
class BaseSparseCompressor(BaseCompressor):
"""
Base class representing a sparse compression algorithm. Each child class should
implement compression_param_info, compress_weight and decompress_weight.
implement compression_param_info, compress_weight and decompress_weight; child
classes should also define COMPRESSION_PARAM_NAMES.

Compressors support compressing/decompressing a full module state dict or a single
quantized PyTorch leaf module.
Expand Down Expand Up @@ -59,18 +60,27 @@ class BaseSparseCompressor(BaseCompressor):
:param config: config specifying compression parameters
"""

def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
def compress(
self,
model_state: Dict[str, Tensor],
compression_targets: Optional[Set[str]] = None,
) -> Dict[str, Tensor]:
"""
Compresses a dense state dict using bitmask compression

:param model_state: state dict of uncompressed model
:param compression_targets: optional set of layer prefixes to compress, if None
compress all layers (for backwards compatibility)
:return: compressed state dict
"""
compressed_dict = {}
_LOGGER.debug(
f"Compressing model with {len(model_state)} parameterized layers..."
)
for name, value in tqdm(model_state.items(), desc="Compressing model"):
if not self.should_compress(name, compression_targets):
compressed_dict[name] = value
continue
compression_data = self.compress_weight(name, value)
for key in compression_data.keys():
if key in compressed_dict:
Expand All @@ -97,8 +107,10 @@ def decompress(
:param device: device to load decompressed weights onto
:return: iterator for generating decompressed weights
"""
weight_mappings = get_nested_weight_mappings(
path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
weight_mappings, other_params = get_nested_weight_mappings(
path_to_model_or_tensors,
self.COMPRESSION_PARAM_NAMES,
return_other_params=True,
)
for weight_name in weight_mappings.keys():
weight_data = {}
Expand All @@ -108,3 +120,22 @@ def decompress(
weight_data[param_name] = f.get_tensor(full_name)
decompressed = self.decompress_weight(weight_data)
yield weight_name, decompressed

for other_name, safe_path in other_params.items():
with safe_open(safe_path, framework="pt", device=device) as f:
value = f.get_tensor(other_name)
yield other_name, value

@staticmethod
def should_compress(name: str, targets: Optional[Set[str]] = None) -> bool:
"""
Check if a parameter should be compressed

:param name: name of the parameter
:param targets: set of layer prefixes to compress
:return: whether or not the parameter should be compressed
"""
if targets is None:
return name.endswith(".weight")

return name.endswith(".weight") and name[: -(len(".weight"))] in targets
92 changes: 92 additions & 0 deletions src/compressed_tensors/compressors/sparse_compressors/sparse_24.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Dict

from compressed_tensors.compressors.base import BaseCompressor
from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor
from compressed_tensors.config import CompressionFormat, SparsityStructure
from compressed_tensors.utils import (
merge_names,
sparse_semi_structured_from_dense_cutlass,
sparse_semi_structured_to_dense_cutlass,
tensor_follows_mask_structure,
)
from torch import Tensor


@BaseCompressor.register(name=CompressionFormat.sparse_24.value)
class Sparse24Compressor(BaseSparseCompressor):
"""
Compresses a with 2:4 sparsity structure for inference
kylesayrs marked this conversation as resolved.
Show resolved Hide resolved
with sparse 2:4 kernels for float/float16/bfloat16.
https://github.com/pytorch/pytorch/blob/78cf8df4a019e919e8eac5f5d048d8842d4fc692/torch/sparse/semi_structured.py
"""

COMPRESSION_PARAM_NAMES = ["sparse_24_packed_weight", "meta"]

@staticmethod
def validate_sparsity_structure(name: str, weight: Tensor) -> bool:
"""
Checks if a tensor fits the required 2:4 sparsity structure
:param name: name of the tensor to check
:param weight: tensor to check for sparsity structure
:return: True if all rows match the 2:4 sparsity structure, raises
ValueError otherwise
"""

if not tensor_follows_mask_structure(
weight, mask=SparsityStructure.TWO_FOUR.value
):
raise ValueError(
"Sparse24Compressor is only compatible with weights that have "
f"a 2:4 sparsity structure. Found segments in {name} "
"that do not match the expected structure."
)

return True

def compress_weight(self, name: str, value: Tensor) -> Dict[str, Tensor]:
"""
Compresses a given with 2:4 sparsity structure.
:param name: name of the tensor in state dict of uncompressed model
:param value: 2:4 sparse tensor to compress
:return: dictionary containing the compressed weight and associated
metadata
"""
weight_suffix = ".weight"
if not name.endswith(weight_suffix):
return {}

prefix = name[: -len(weight_suffix)]
self.validate_sparsity_structure(name=prefix, weight=value)
sparse_24_packed_weight, meta = sparse_semi_structured_from_dense_cutlass(
dense=value
)
return {
merge_names(name, "sparse_24_packed_weight"): sparse_24_packed_weight.cpu(),
merge_names(name, "meta"): meta.cpu(),
}

def decompress_weight(self, weight_data):
kylesayrs marked this conversation as resolved.
Show resolved Hide resolved
assert (
"sparse_24_packed_weight" in weight_data
), "sparse_24_packed_weight not found in weight_data"
assert "meta" in weight_data, "meta not found in weight_data"

return sparse_semi_structured_to_dense_cutlass(
sparse=weight_data["sparse_24_packed_weight"],
meta_reordered=weight_data["meta"],
)
1 change: 1 addition & 0 deletions src/compressed_tensors/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@
# flake8: noqa
from .base import *
from .dense import *
from .sparse_24 import *
from .sparse_bitmask import *
1 change: 1 addition & 0 deletions src/compressed_tensors/config/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
class CompressionFormat(Enum):
dense = "dense"
sparse_bitmask = "sparse-bitmask"
sparse_24 = "sparse-24"
int_quantized = "int-quantized"
float_quantized = "float-quantized"
naive_quantized = "naive-quantized"
Expand Down
37 changes: 37 additions & 0 deletions src/compressed_tensors/config/sparse_24.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

from compressed_tensors.config import (
CompressionFormat,
SparsityCompressionConfig,
SparsityStructure,
)


__all__ = ["Sparse24Config"]


@SparsityCompressionConfig.register(name=CompressionFormat.sparse_24.value)
class Sparse24Config(SparsityCompressionConfig):
"""
Configuration for storing a sparse model using 2:4 compression
:param global_sparsity: average sparsity of the entire model
:param sparsity_structure: structure of the sparsity, "2:4"
"""

format: str = CompressionFormat.sparse_24.value
global_sparsity: Optional[float] = 0.0
sparsity_structure: Optional[str] = SparsityStructure.TWO_FOUR.value
47 changes: 46 additions & 1 deletion src/compressed_tensors/quantization/lifecycle/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from copy import deepcopy
from typing import Dict, Iterable, List, Optional
from typing import OrderedDict as OrderedDictType
from typing import Union
from typing import Set, Union

import torch
from compressed_tensors.config import CompressionFormat
Expand Down Expand Up @@ -52,6 +52,8 @@
"apply_quantization_config",
"apply_quantization_status",
"find_name_or_class_matches",
"expand_targets",
"is_target",
]

from compressed_tensors.quantization.utils.helpers import is_module_quantized
Expand Down Expand Up @@ -245,6 +247,49 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
model.apply(compress_quantized_weights)


def expand_targets(
model: Module, targets: Iterable[str], ignore: Iterable[str]
) -> Set[str]:
"""
Finds all the targets in the model that match the given
targets and ignore lists.

Note: Targets must be regexes, layer types, or full layer names.

:param model: model to search for targets in
:param targets: list of targets to search for
:param ignore: list of targets to ignore
:return: set of all targets that match the given targets and should
not be ignored
"""
return {
name
for name, module in iter_named_leaf_modules(model)
if is_target(name, module, targets, ignore)
}


def is_target(
name: str, module: Module, targets: Iterable[str], ignore: Iterable[str]
) -> bool:
"""
Determines if a module should be included in the targets based on the
targets and ignore lists.

Note: Targets must be regexes, layer types, or full layer names.

:param name: name of the module
:param module: the module itself
:param targets: list of targets to search for
:param ignore: list of targets to ignore
:return: True if the module is a target and not ignored, False otherwise
"""
return bool(
find_name_or_class_matches(name, module, targets)
and not find_name_or_class_matches(name, module, ignore)
)


def find_name_or_class_matches(
name: str, module: Module, targets: Iterable[str], check_contains: bool = False
) -> List[str]:
Expand Down
1 change: 1 addition & 0 deletions src/compressed_tensors/quantization/quant_scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def validate_model_after(model: "QuantizationArgs") -> Dict[str, Any]:

return model


"""
Pre-Set Quantization Scheme Args
"""
Expand Down
Loading