Commit e5bfd8a

Address review comments
rahul-tuli committed Oct 23, 2024
1 parent 286c081 commit e5bfd8a
Showing 5 changed files with 67 additions and 30 deletions.
@@ -39,7 +39,7 @@
     apply_quantization_config,
     load_pretrained_quantization,
 )
-from compressed_tensors.quantization.lifecycle import find_compression_targets
+from compressed_tensors.quantization.lifecycle import expand_targets
 from compressed_tensors.quantization.utils import (
     is_module_quantized,
     iter_named_leaf_modules,
@@ -277,9 +277,13 @@ def compress(
         )

         if self.sparsity_compressor is not None:
-            compression_targets = self._find_sparse_compression_targets(model=model)
+            sparse_compression_targets: Set[str] = expand_targets(
+                model=model,
+                targets=self.sparsity_config.targets,
+                ignore=self.sparsity_config.ignore,
+            )
             compressed_state_dict = self.sparsity_compressor.compress(
-                compressed_state_dict, compression_targets=compression_targets
+                compressed_state_dict, compression_targets=sparse_compression_targets
             )

         # HACK: Override the dtype_byte_size function in transformers to
@@ -370,13 +374,6 @@ def _replace_weights(self, dense_weight_generator, model):
             module = operator.attrgetter(prefix)(model)
             update_parameter_data(module, data, param_name)

-    def _find_sparse_compression_targets(self, model: Module) -> Set[str]:
-        return find_compression_targets(
-            model=model,
-            targets=self.sparsity_config.targets,
-            ignore=self.sparsity_config.ignore,
-        )
-

 def map_modules_to_quant_args(model: Module) -> Dict:
     quantized_modules_to_args = {}
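For orientation, a minimal sketch of how the compressor now resolves its sparse compression targets inline; the target and ignore values below are illustrative assumptions, not taken from this diff:

    # hypothetical sparsity_config values, for illustration only
    sparse_compression_targets: Set[str] = expand_targets(
        model=model,
        targets=["Linear"],   # e.g. match every Linear module
        ignore=["lm_head"],   # but leave the LM head dense
    )
    # -> a set of concrete layer prefixes such as {"model.layers.0.mlp.down_proj", ...},
    #    which is passed to sparsity_compressor.compress(...) to select weights to compress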
29 changes: 25 additions & 4 deletions src/compressed_tensors/compressors/sparse_compressors/base.py
@@ -77,8 +77,8 @@ def compress(
             f"Compressing model with {len(model_state)} parameterized layers..."
         )
         for name, value in tqdm(model_state.items(), desc="Compressing model"):
-            prefix = name.rsplit(".", 1)[0]
-            if compression_targets and prefix not in compression_targets:
+            if not self.should_compress(name, compression_targets):
                 compressed_dict[name] = value
                 continue
             compression_data = self.compress_weight(name, value)
             for key in compression_data.keys():
@@ -106,8 +106,10 @@ def decompress(
         :param device: device to load decompressed weights onto
         :return: iterator for generating decompressed weights
         """
-        weight_mappings = get_nested_weight_mappings(
-            path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
+        weight_mappings, other_params = get_nested_weight_mappings(
+            path_to_model_or_tensors,
+            self.COMPRESSION_PARAM_NAMES,
+            return_other_params=True,
         )
         for weight_name in weight_mappings.keys():
             weight_data = {}
@@ -117,3 +119,22 @@
                     weight_data[param_name] = f.get_tensor(full_name)
             decompressed = self.decompress_weight(weight_data)
             yield weight_name, decompressed
+
+        for other_name, safe_path in other_params.items():
+            with safe_open(safe_path, framework="pt", device=device) as f:
+                value = f.get_tensor(other_name)
+                yield other_name, value
+
+    @staticmethod
+    def should_compress(name: str, targets: Optional[Set[str]] = None) -> bool:
+        """
+        Check if a parameter should be compressed
+
+        :param name: name of the parameter
+        :param targets: set of layer prefixes to compress
+        :return: whether or not the parameter should be compressed
+        """
+        if targets is None:
+            return name.endswith(".weight")
+
+        return name.endswith(".weight") and name[: -(len(".weight"))] in targets
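To illustrate the new should_compress helper, a hedged sketch of its behavior; the layer names, and the BaseSparseCompressor class name used to qualify the static method, are assumptions for illustration:

    targets = {"model.layers.0.mlp.down_proj"}
    # prefix is targeted and the parameter ends in ".weight" -> compress
    BaseSparseCompressor.should_compress("model.layers.0.mlp.down_proj.weight", targets)  # True
    # not a ".weight" parameter -> keep as-is
    BaseSparseCompressor.should_compress("model.layers.0.mlp.down_proj.bias", targets)    # False
    # a ".weight" parameter whose prefix is not in targets -> keep as-is
    BaseSparseCompressor.should_compress("model.layers.1.mlp.down_proj.weight", targets)  # False
    # no targets given -> every ".weight" parameter is compressed
    BaseSparseCompressor.should_compress("model.layers.1.mlp.down_proj.weight")           # True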
4 changes: 2 additions & 2 deletions src/compressed_tensors/quantization/lifecycle/apply.py
@@ -56,7 +56,7 @@
     "apply_quantization_config",
     "apply_quantization_status",
     "find_name_or_class_matches",
-    "find_compression_targets",
+    "expand_targets",
 ]

 from compressed_tensors.quantization.utils.helpers import is_module_quantized
@@ -281,7 +281,7 @@ def find_name_or_class_matches(
     return matches


-def find_compression_targets(
+def expand_targets(
     model: Module, targets: Iterable[str], ignore: Iterable[str]
 ) -> Set[str]:
     """
31 changes: 25 additions & 6 deletions src/compressed_tensors/utils/safetensors_load.py
@@ -16,7 +16,7 @@
 import os
 import re
 import struct
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple, Union

 from safetensors import safe_open
 from torch import Tensor
@@ -34,6 +34,9 @@
     "is_quantization_param",
 ]

+WEIGHT_MAPPING_TYPE = Dict[str, str]
+NESTED_WEIGHT_MAPPING_TYPE = Dict[str, WEIGHT_MAPPING_TYPE]
+

 def get_safetensors_folder(
     pretrained_model_name_or_path: str, cache_dir: Optional[str] = None
@@ -176,8 +179,10 @@ def get_weight_mappings(path_to_model_or_tensors: str) -> Dict[str, str]:


 def get_nested_weight_mappings(
-    model_path: str, params_to_nest: List[str]
-) -> Dict[str, Dict[str, str]]:
+    model_path: str, params_to_nest: List[str], return_other_params: bool = False
+) -> Union[
+    NESTED_WEIGHT_MAPPING_TYPE, Tuple[NESTED_WEIGHT_MAPPING_TYPE, WEIGHT_MAPPING_TYPE]
+]:
     """
     Takes a path to a state dict saved in safetensors format and returns a nested
     mapping from uncompressed parameterized layer names to the file locations of each
@@ -193,22 +198,36 @@ def get_nested_weight_mappings(
     This generalizes to cases where the model is split into multiple safetensors files

     :param model_path: path to safetensors state dict, must contain either a single
-        safetensors file or multiple files with an index
-    :return: nested mapping of parameterized layer name to file location
+        safetensors file or multiple files with an index
+    :param return_other_params: if True, return a second dictionary containing the
+        remaining parameters that were not matched to the nested parameters
+    :return: nested mapping of parameterized layer name to file location if
+        return_other_params is False, else a tuple containing the nested mapping
+        and a mapping of the remaining parameters that were not matched to
+        the nested parameters
     """
     weight_mappings = get_weight_mappings(model_path)
+    other_params = {}

     nested_weight_mappings = {}
     for key in weight_mappings.keys():
+        matched = False
         for param_name in params_to_nest:
             maybe_match = match_param_name(key, param_name)
             if maybe_match is not None:
                 dense_param = maybe_match
                 if dense_param not in nested_weight_mappings:
                     nested_weight_mappings[dense_param] = {}
+                matched = True
                 nested_weight_mappings[dense_param][param_name] = weight_mappings[key]
+        if not matched:
+            other_params[key] = weight_mappings[key]

-    return nested_weight_mappings
+    return (
+        nested_weight_mappings
+        if not return_other_params
+        else (nested_weight_mappings, other_params)
+    )


 def get_quantization_state_dict(model_path: str) -> Dict[str, Tensor]:
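A hedged sketch of what the new return_other_params flag changes, assuming a hypothetical checkpoint whose compressed parameters are named layer.weight.compressed and layer.weight.bitmask; the parameter and file names here are illustrative only:

    # default: only the nested mapping of compression parameters is returned
    nested = get_nested_weight_mappings(model_path, ["compressed", "bitmask"])
    # e.g. {"layer.weight": {"compressed": "model.safetensors", "bitmask": "model.safetensors"}}

    # with return_other_params=True, parameters that matched nothing are returned separately
    nested, other = get_nested_weight_mappings(
        model_path, ["compressed", "bitmask"], return_other_params=True
    )
    # e.g. other == {"layer.bias": "model.safetensors"}, which decompress() can now pass through unchanged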
16 changes: 8 additions & 8 deletions tests/test_quantization/lifecycle/test_apply.py
@@ -33,6 +33,14 @@
 from transformers import AutoModelForCausalLM


+@pytest.fixture
+def model():
+    return AutoModelForCausalLM.from_pretrained(
+        "Xenova/llama2.c-stories15M",
+        torch_dtype="auto",
+    )
+
+
 def test_target_prioritization():
     # tests that the config_groups are applied in the correct order
     # of priority, where exact layer name > regex > module name
@@ -275,14 +283,6 @@ def test_apply_quantization_status(caplog, ignore, should_raise_warning):
         assert len(caplog.text) == 0


-@pytest.fixture
-def model():
-    return AutoModelForCausalLM.from_pretrained(
-        "Xenova/llama2.c-stories15M",
-        torch_dtype="auto",
-    )
-
-
 @pytest.mark.parametrize(
     "targets, ignore, expected",
     [
