Skip to content

Commit

Permalink
review suggestions from @dsikka
Browse files Browse the repository at this point in the history
  • Loading branch information
rahul-tuli committed Dec 3, 2024
1 parent c54699a commit 6936121
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
apply_quantization_config,
load_pretrained_quantization,
)
from compressed_tensors.quantization.lifecycle import expand_targets
from compressed_tensors.quantization.lifecycle import expand_sparse_target_names
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.utils import (
is_module_quantized,
Expand Down Expand Up @@ -269,9 +269,9 @@ def compress(

compressed_state_dict = state_dict

quantized_modules_to_args: Dict[
str, QuantizationArgs
] = map_modules_to_quant_args(model)
quantized_modules_to_args: Dict[str, QuantizationArgs] = (
map_modules_to_quant_args(model)
)

if self.quantization_compressor is not None:
compressed_state_dict = self.quantization_compressor.compress(
Expand All @@ -283,7 +283,7 @@ def compress(
)

if self.sparsity_compressor is not None:
sparse_compression_targets: Set[str] = expand_targets(
sparse_compression_targets: Set[str] = expand_sparse_target_names(
model=model,
targets=self.sparsity_config.targets,
ignore=self.sparsity_config.ignore,
Expand Down
10 changes: 5 additions & 5 deletions src/compressed_tensors/compressors/sparse_compressors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,10 @@ def decompress(
:param device: device to load decompressed weights onto
:return: iterator for generating decompressed weights
"""
weight_mappings, other_params = get_nested_weight_mappings(
weight_mappings, uncompressed_params = get_nested_weight_mappings(
path_to_model_or_tensors,
self.COMPRESSION_PARAM_NAMES,
return_other_params=True,
return_unmatched_params=True,
)
for weight_name in weight_mappings.keys():
weight_data = {}
Expand All @@ -121,10 +121,10 @@ def decompress(
decompressed = self.decompress_weight(weight_data)
yield weight_name, decompressed

for other_name, safe_path in other_params.items():
for uncompressed_param_name, safe_path in uncompressed_params.items():
with safe_open(safe_path, framework="pt", device=device) as f:
value = f.get_tensor(other_name)
yield other_name, value
value = f.get_tensor(uncompressed_param_name)
yield uncompressed_param_name, value

@staticmethod
def should_compress(name: str, expanded_targets: Optional[Set[str]] = None) -> bool:
Expand Down
6 changes: 3 additions & 3 deletions src/compressed_tensors/quantization/lifecycle/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
"apply_quantization_config",
"apply_quantization_status",
"find_name_or_class_matches",
"expand_targets",
"expand_sparse_target_names",
"is_target",
]

Expand Down Expand Up @@ -247,11 +247,11 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
model.apply(compress_quantized_weights)


def expand_targets(
def expand_sparse_target_names(
model: Module, targets: Iterable[str], ignore: Iterable[str]
) -> Set[str]:
"""
Finds all the targets in the model that match the given
    Finds all unique module names in the model that match the given
    targets list and are not excluded by the ignore list.
Note: Targets must be regexes, layer types, or full layer names.
Expand Down
57 changes: 38 additions & 19 deletions src/compressed_tensors/utils/safetensors_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,35 +179,54 @@ def get_weight_mappings(path_to_model_or_tensors: str) -> Dict[str, str]:


def get_nested_weight_mappings(
model_path: str, params_to_nest: List[str], return_other_params: bool = False
model_path: str, params_to_nest: List[str], return_unmatched_params: bool = False
) -> Union[NestedWeightMappingType, Tuple[NestedWeightMappingType, WeightMappingType]]:
"""
Takes a path to a state dict saved in safetensors format and returns a nested
mapping from uncompressed parameterized layer names to the file locations of each
of the layers compression parameters.
mapping from uncompressed parameterized layer names to the file locations of
each layer's compression parameters.
Example of the nested mapping:
layer.weight: {
bitmask: file_location,
row_offsets: file_location,
shape: file_location,
compressed: file_location
}
This generalizes to cases where the model is split into multiple safetensors files.
:param model_path: path to safetensors state dict, must contain either a single
safetensors file or multiple files with an index.
:param params_to_nest: list of parameter names to nest.
:param return_other_params: if True, return a second dictionary containing the
remaining parameters that were not matched to the nested parameters.
:return: nested mapping of parameterized layer name to file location if
return_other_params is False, else a tuple containing the nested mapping
and a mapping of the remaining parameters that were not matched to
the nested parameters.
If other parameters are found that do not match the nested parameters, they will
be returned in a separate dictionary only if return_unmatched_params is True.
This dictionary may be needed for cases where compressors are stacked (e.g.,
quantization compression followed by sparse compression).
Example of the unmatched params mapping:
{
layer.weight_scale: file_location,
layer.input_scale: file_location
}
This generalizes to cases where the model is split into multiple safetensors
files.
:param model_path: Path to the safetensors state dict, must contain either a
single safetensors file or multiple files with an index.
:param params_to_nest: List of parameter names to nest.
    :param return_unmatched_params: If True, also return a second dictionary
        containing the remaining parameters that did not match any entry
        in params_to_nest.
:return:
- If return_unmatched_params is False:
NestedWeightMappingType: A nested mapping of parameterized layer names to
file locations of each layer's compression parameters.
- If return_unmatched_params is True:
Tuple[NestedWeightMappingType, WeightMappingType]: A tuple containing:
- NestedWeightMappingType: A nested mapping of parameterized layer
names to file locations of each layer's compression parameters.
- WeightMappingType: A mapping of the remaining parameter names to
their file locations that were not matched to the params_to_nest.
"""
weight_mappings = get_weight_mappings(model_path)
nested_weight_mappings = {}
other_params = {}
unmatched_params = {}

for key, file_location in weight_mappings.items():
matched = False
Expand All @@ -218,11 +237,11 @@ def get_nested_weight_mappings(
nested_weight_mappings[dense_param] = {}
nested_weight_mappings[dense_param][param_name] = file_location
matched = True
if not matched:
other_params[key] = file_location
if return_unmatched_params and not matched:
unmatched_params[key] = file_location

if return_other_params:
return nested_weight_mappings, other_params
if return_unmatched_params:
return nested_weight_mappings, unmatched_params
return nested_weight_mappings


Expand Down
6 changes: 3 additions & 3 deletions tests/test_quantization/lifecycle/test_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from compressed_tensors.quantization.lifecycle import (
apply_quantization_config,
apply_quantization_status,
expand_targets,
expand_sparse_target_names,
is_target,
)
from compressed_tensors.quantization.utils import iter_named_leaf_modules
Expand Down Expand Up @@ -306,7 +306,7 @@ def test_apply_quantization_status(caplog, ignore, should_raise_warning):
],
)
def test_expand_targets_with_mock(mock_model, targets, ignore, expected_targets):
expanded_targets = expand_targets(mock_model, targets, ignore)
expanded_targets = expand_sparse_target_names(mock_model, targets, ignore)
assert expanded_targets == expected_targets


Expand Down Expand Up @@ -346,7 +346,7 @@ def test_expand_targets_with_mock(mock_model, targets, ignore, expected_targets)
def test_expand_targets_with_llama_stories(
llama_stories_model, targets, ignore, expected_targets
):
expanded_targets = expand_targets(llama_stories_model, targets, ignore)
expanded_targets = expand_sparse_target_names(llama_stories_model, targets, ignore)
assert expanded_targets == expected_targets


Expand Down
2 changes: 1 addition & 1 deletion tests/test_utils/test_safetensors_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_multiple_params(self):
def test_return_other_params(self):
params_to_nest = ["weight"]
result, other_params = get_nested_weight_mappings(
"dummy_path", params_to_nest, return_other_params=True
"dummy_path", params_to_nest, return_unmatched_params=True
)
expected_nested = {
"layer1": {"weight": "file1"},
Expand Down

0 comments on commit 6936121

Please sign in to comment.