From 22e63cdab500af45f657e2f1d36cee5df8bd65a6 Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Tue, 6 Jun 2023 00:45:50 -0400 Subject: [PATCH] Fix for quantization modifier w/ DDP (#1594) * Removes "module." from submodule names inserted by DDP * Renamed variables to make them consistent --- .../pytorch/sparsification/quantization/quantize.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/quantize.py b/src/sparseml/pytorch/sparsification/quantization/quantize.py index 20415f54dee..ef109b9f05d 100644 --- a/src/sparseml/pytorch/sparsification/quantization/quantize.py +++ b/src/sparseml/pytorch/sparsification/quantization/quantize.py @@ -361,10 +361,13 @@ def _match_submodule_name_or_type( # 2. match the submodule prefix (longest first) submodule_match = "" for name_or_type in names_or_types: + name_to_compare = submodule_name[:] + if name_to_compare.startswith("module."): + name_to_compare = name_to_compare[7:] if name_or_type == submodule.__class__.__name__: # type match, return type name return name_or_type - if submodule_name.startswith(name_or_type) and ( + if name_to_compare.startswith(name_or_type) and ( len(name_or_type) > len(submodule_match) ): # match to most specific submodule name @@ -422,7 +425,10 @@ def _get_unmatched_types_or_names(types_or_names): for type_or_name in types_or_names: matched = False for submodule_name, submodule in model.named_modules(): - if submodule_name.startswith(type_or_name) or ( + name_to_compare = submodule_name[:] + if name_to_compare.startswith("module."): + name_to_compare = name_to_compare[7:] + if name_to_compare.startswith(type_or_name) or ( submodule.__class__.__name__ == type_or_name ): matched = True