Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for targets and ignore in Sparsity Compressors #182

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import os
import re
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, Optional, TypeVar, Union
from typing import TYPE_CHECKING, Any, Dict, Optional, Set, TypeVar, Union

import compressed_tensors
import torch
Expand All @@ -38,6 +38,7 @@
apply_quantization_config,
load_pretrained_quantization,
)
from compressed_tensors.quantization.lifecycle import expand_sparse_target_names
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.utils import (
is_module_quantized,
Expand Down Expand Up @@ -268,9 +269,9 @@ def compress(

compressed_state_dict = state_dict

quantized_modules_to_args: Dict[
str, QuantizationArgs
] = map_modules_to_quant_args(model)
quantized_modules_to_args: Dict[str, QuantizationArgs] = (
map_modules_to_quant_args(model)
)

if self.quantization_compressor is not None:
compressed_state_dict = self.quantization_compressor.compress(
Expand All @@ -282,8 +283,14 @@ def compress(
)

if self.sparsity_compressor is not None:
sparse_compression_targets: Set[str] = expand_sparse_target_names(
model=model,
targets=self.sparsity_config.targets,
ignore=self.sparsity_config.ignore,
)
compressed_state_dict = self.sparsity_compressor.compress(
compressed_state_dict
compressed_state_dict,
compression_targets=sparse_compression_targets,
)

# HACK: Override the dtype_byte_size function in transformers to
Expand Down
44 changes: 39 additions & 5 deletions src/compressed_tensors/compressors/sparse_compressors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import logging
from typing import Dict, Generator, Tuple
from typing import Dict, Generator, Optional, Set, Tuple

from compressed_tensors.compressors.base import BaseCompressor
from compressed_tensors.utils import get_nested_weight_mappings, merge_names
Expand All @@ -30,7 +30,8 @@
class BaseSparseCompressor(BaseCompressor):
"""
Base class representing a sparse compression algorithm. Each child class should
implement compression_param_info, compress_weight and decompress_weight.
implement compression_param_info, compress_weight and decompress_weight; child
classes should also define COMPRESSION_PARAM_NAMES.

Compressors support compressing/decompressing a full module state dict or a single
quantized PyTorch leaf module.
Expand Down Expand Up @@ -59,18 +60,27 @@ class BaseSparseCompressor(BaseCompressor):
:param config: config specifying compression parameters
"""

def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
def compress(
self,
model_state: Dict[str, Tensor],
compression_targets: Optional[Set[str]] = None,
) -> Dict[str, Tensor]:
"""
Compresses a dense state dict using bitmask compression

:param model_state: state dict of uncompressed model
:param compression_targets: optional set of layer prefixes to compress, if None
compress all layers (for backwards compatibility)
rahul-tuli marked this conversation as resolved.
Show resolved Hide resolved
:return: compressed state dict
"""
compressed_dict = {}
_LOGGER.debug(
f"Compressing model with {len(model_state)} parameterized layers..."
)
for name, value in tqdm(model_state.items(), desc="Compressing model"):
if not self.should_compress(name, compression_targets):
compressed_dict[name] = value
continue
compression_data = self.compress_weight(name, value)
for key in compression_data.keys():
if key in compressed_dict:
Expand All @@ -97,8 +107,10 @@ def decompress(
:param device: device to load decompressed weights onto
:return: iterator for generating decompressed weights
"""
weight_mappings = get_nested_weight_mappings(
path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
weight_mappings, uncompressed_params = get_nested_weight_mappings(
path_to_model_or_tensors,
self.COMPRESSION_PARAM_NAMES,
horheynm marked this conversation as resolved.
Show resolved Hide resolved
return_unmatched_params=True,
)
for weight_name in weight_mappings.keys():
weight_data = {}
Expand All @@ -108,3 +120,25 @@ def decompress(
weight_data[param_name] = f.get_tensor(full_name)
decompressed = self.decompress_weight(weight_data)
yield weight_name, decompressed

for uncompressed_param_name, safe_path in uncompressed_params.items():
with safe_open(safe_path, framework="pt", device=device) as f:
value = f.get_tensor(uncompressed_param_name)
yield uncompressed_param_name, value

@staticmethod
def should_compress(name: str, expanded_targets: Optional[Set[str]] = None) -> bool:
"""
Check if a parameter should be compressed.
Currently, this only returns True for weight parameters.

:param name: name of the parameter
:param expanded_targets: set of layer prefixes to compress
:return: whether or not the parameter should be compressed
"""
if expanded_targets is None:
return name.endswith(".weight")

return (
name.endswith(".weight") and name[: -(len(".weight"))] in expanded_targets
rahul-tuli marked this conversation as resolved.
Show resolved Hide resolved
)
47 changes: 46 additions & 1 deletion src/compressed_tensors/quantization/lifecycle/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from copy import deepcopy
from typing import Dict, Iterable, List, Optional
from typing import OrderedDict as OrderedDictType
from typing import Union
from typing import Set, Union

import torch
from compressed_tensors.config import CompressionFormat
Expand Down Expand Up @@ -52,6 +52,8 @@
"apply_quantization_config",
"apply_quantization_status",
"find_name_or_class_matches",
"expand_sparse_target_names",
"is_target",
]

from compressed_tensors.quantization.utils.helpers import is_module_quantized
Expand Down Expand Up @@ -245,6 +247,49 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
model.apply(compress_quantized_weights)


def expand_sparse_target_names(
    model: Module, targets: Iterable[str], ignore: Iterable[str]
) -> Set[str]:
    """
    Finds all unique module names in the model that match the given
    targets and ignore lists.

    Note: Targets must be regexes, layer types, or full layer names.

    :param model: model to search for targets in
    :param targets: list of targets to search for; treated as empty when None
    :param ignore: list of targets to ignore; treated as empty when None
    :return: set of all module names that match the given targets and should
        not be ignored
    """
    # Sparsity configs may leave targets/ignore unset (None); normalize to
    # empty lists so the matching helpers never iterate over None.
    targets = targets or []
    ignore = ignore or []
    return {
        name
        for name, module in iter_named_leaf_modules(model)
        if is_target(name, module, targets, ignore)
    }


def is_target(
    name: str, module: Module, targets: Iterable[str], ignore: Iterable[str]
) -> bool:
    """
    Determines if a module should be included in the targets based on the
    targets and ignore lists.

    A module is a target when it matches at least one entry in ``targets``
    and no entry in ``ignore``.

    Note: Targets must be regexes, layer types, or full layer names.

    :param name: name of the module
    :param module: the module itself
    :param targets: list of targets to search for
    :param ignore: list of targets to ignore
    :return: True if the module is a target and not ignored, False otherwise
    """
    # Guard-clause form; short-circuits exactly like the boolean expression:
    # the ignore list is only consulted when a target actually matched.
    if not find_name_or_class_matches(name, module, targets):
        return False
    return not find_name_or_class_matches(name, module, ignore)


def find_name_or_class_matches(
name: str, module: Module, targets: Iterable[str], check_contains: bool = False
) -> List[str]:
Expand Down
66 changes: 50 additions & 16 deletions src/compressed_tensors/utils/safetensors_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import os
import re
import struct
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Tuple, Union

from safetensors import safe_open
from torch import Tensor
Expand All @@ -34,6 +34,9 @@
"is_quantization_param",
]

WeightMappingType = Dict[str, str]
NestedWeightMappingType = Dict[str, WeightMappingType]


def get_safetensors_folder(
pretrained_model_name_or_path: str, cache_dir: Optional[str] = None
Expand Down Expand Up @@ -92,7 +95,7 @@ def get_safetensors_header(safetensors_path: str) -> Dict[str, str]:
return header


def match_param_name(full_name: str, param_name: str) -> str:
def match_param_name(full_name: str, param_name: str) -> Optional[str]:
"""
Helper function extracting the uncompressed parameterized layer name from a
compressed name. Assumes the compressed name was merged using merge_names.
Expand Down Expand Up @@ -176,38 +179,69 @@ def get_weight_mappings(path_to_model_or_tensors: str) -> Dict[str, str]:


def get_nested_weight_mappings(
    model_path: str, params_to_nest: List[str], return_unmatched_params: bool = False
) -> Union[NestedWeightMappingType, Tuple[NestedWeightMappingType, WeightMappingType]]:
    """
    Takes a path to a state dict saved in safetensors format and returns a
    nested mapping from uncompressed parameterized layer names to the file
    locations of each layer's compression parameters.

    Example of the nested mapping:
        layer.weight: {
            bitmask: file_location,
            shape: file_location,
            compressed: file_location
        }

    Parameters that do not match any entry in ``params_to_nest`` are
    collected into a second, flat dictionary that is returned only when
    ``return_unmatched_params`` is True. This is needed when compressors are
    stacked (e.g. quantization compression followed by sparse compression).

    Example of the unmatched params mapping:
        {
            layer.weight_scale: file_location,
            layer.input_scale: file_location
        }

    This generalizes to cases where the model is split into multiple
    safetensors files.

    :param model_path: path to the safetensors state dict, must contain either
        a single safetensors file or multiple files with an index
    :param params_to_nest: list of parameter names to nest
    :param return_unmatched_params: if True, also return a mapping of the
        remaining parameter names to their file locations
    :return: the nested mapping, or a ``(nested, unmatched)`` tuple when
        ``return_unmatched_params`` is True
    """
    weight_mappings = get_weight_mappings(model_path)

    nested_weight_mappings = {}
    unmatched_params = {}

    for key, file_location in weight_mappings.items():
        matched = False
        for param_name in params_to_nest:
            dense_param = match_param_name(key, param_name)
            if dense_param:
                layer_entry = nested_weight_mappings.setdefault(dense_param, {})
                layer_entry[param_name] = file_location
                matched = True
        if return_unmatched_params and not matched:
            unmatched_params[key] = file_location

    if return_unmatched_params:
        return nested_weight_mappings, unmatched_params
    return nested_weight_mappings


Expand Down
Loading