Skip to content

Commit

Permalink
More review comments from @dsikka
Browse files Browse the repository at this point in the history
  • Loading branch information
rahul-tuli committed Dec 20, 2024
1 parent 4ef03bd commit 75664ce
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions src/compressed_tensors/compressors/sparse_compressors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ def compress(
f"Compressing model with {len(model_state)} parameterized layers..."
)
for name, value in tqdm(model_state.items(), desc="Compressing model"):
if not self.should_compress(name, compression_targets):
ignored = not self.should_compress(name, compression_targets)
if ignored:
compressed_dict[name] = value
continue
prefix = name
Expand Down Expand Up @@ -111,7 +112,7 @@ def decompress(
:param device: device to load decompressed weights onto
:return: iterator for generating decompressed weights
"""
weight_mappings, uncompressed_params = get_nested_weight_mappings(
weight_mappings, ignored_params = get_nested_weight_mappings(
path_to_model_or_tensors,
self.COMPRESSION_PARAM_NAMES,
return_unmatched_params=True,
Expand All @@ -125,10 +126,10 @@ def decompress(
decompressed = self.decompress_weight(weight_data)
yield merge_names(weight_name, "weight"), decompressed

for uncompressed_param_name, safe_path in uncompressed_params.items():
for ignored_param_name, safe_path in ignored_params.items():
with safe_open(safe_path, framework="pt", device=device) as f:
value = f.get_tensor(uncompressed_param_name)
yield uncompressed_param_name, value
value = f.get_tensor(ignored_param_name)
yield ignored_param_name, value

@staticmethod
def should_compress(name: str, expanded_targets: Optional[Set[str]] = None) -> bool:
Expand Down

0 comments on commit 75664ce

Please sign in to comment.