diff --git a/src/compressed_tensors/compressors/sparse_compressors/base.py b/src/compressed_tensors/compressors/sparse_compressors/base.py index 73f87050..72232801 100644 --- a/src/compressed_tensors/compressors/sparse_compressors/base.py +++ b/src/compressed_tensors/compressors/sparse_compressors/base.py @@ -78,7 +78,8 @@ def compress( f"Compressing model with {len(model_state)} parameterized layers..." ) for name, value in tqdm(model_state.items(), desc="Compressing model"): - if not self.should_compress(name, compression_targets): + ignored = not self.should_compress(name, compression_targets) + if ignored: compressed_dict[name] = value continue prefix = name @@ -111,7 +112,7 @@ def decompress( :param device: device to load decompressed weights onto :return: iterator for generating decompressed weights """ - weight_mappings, uncompressed_params = get_nested_weight_mappings( + weight_mappings, ignored_params = get_nested_weight_mappings( path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES, return_unmatched_params=True, @@ -125,10 +126,10 @@ def decompress( decompressed = self.decompress_weight(weight_data) yield merge_names(weight_name, "weight"), decompressed - for uncompressed_param_name, safe_path in uncompressed_params.items(): + for ignored_param_name, safe_path in ignored_params.items(): with safe_open(safe_path, framework="pt", device=device) as f: - value = f.get_tensor(uncompressed_param_name) - yield uncompressed_param_name, value + value = f.get_tensor(ignored_param_name) + yield ignored_param_name, value @staticmethod def should_compress(name: str, expanded_targets: Optional[Set[str]] = None) -> bool: