Skip to content

Commit

Permalink
Fix for quantization modifier w/ DDP (#1594)
Browse files Browse the repository at this point in the history
* Removes "module." from submodule names inserted by DDP

* Renamed variables to make them consistent
  • Loading branch information
anmarques authored Jun 6, 2023
1 parent 7abe53e commit 22e63cd
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions src/sparseml/pytorch/sparsification/quantization/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,10 +361,13 @@ def _match_submodule_name_or_type(
# 2. match the submodule prefix (longest first)
submodule_match = ""
for name_or_type in names_or_types:
name_to_compare = submodule_name[:]
if name_to_compare.startswith("module."):
name_to_compare = name_to_compare[7:]
if name_or_type == submodule.__class__.__name__:
# type match, return type name
return name_or_type
if submodule_name.startswith(name_or_type) and (
if name_to_compare.startswith(name_or_type) and (
len(name_or_type) > len(submodule_match)
):
# match to most specific submodule name
Expand Down Expand Up @@ -422,7 +425,10 @@ def _get_unmatched_types_or_names(types_or_names):
for type_or_name in types_or_names:
matched = False
for submodule_name, submodule in model.named_modules():
if submodule_name.startswith(type_or_name) or (
name_to_compare = submodule_name[:]
if name_to_compare.startswith("module."):
name_to_compare = name_to_compare[7:]
if name_to_compare.startswith(type_or_name) or (
submodule.__class__.__name__ == type_or_name
):
matched = True
Expand Down

0 comments on commit 22e63cd

Please sign in to comment.