
Commit

format
bastiscode committed Dec 12, 2024
1 parent d99440d commit 4a15d6a
Showing 1 changed file with 28 additions and 29 deletions.
57 changes: 28 additions & 29 deletions python/text_utils/modules/optimizer.py
@@ -5,8 +5,7 @@
 
 
 def _select_params_and_modules(
-    modules: Iterator[Tuple[str, nn.Module]],
-    prefix: str
+    modules: Iterator[Tuple[str, nn.Module]], prefix: str
 ) -> Iterator[Tuple[str, nn.Module, nn.Parameter]]:
     for name, mod in modules:
         for p_name, param in mod.named_parameters(prefix=name, recurse=False):
@@ -17,10 +16,9 @@ def _select_params_and_modules(
 def optimizer_from_config(
     model: nn.Module,
     cfg: Dict[str, Any],
-    additional_optimizer_fn: Optional[Callable[
-        [nn.Module, Dict[str, Any]],
-        optim.Optimizer
-    ]] = None
+    additional_optimizer_fn: Optional[
+        Callable[[nn.Module, Dict[str, Any]], optim.Optimizer]
+    ] = None,
 ) -> optim.Optimizer:
     cfg = copy.deepcopy(cfg)
     opt_type = cfg.pop("type")
@@ -40,8 +38,7 @@ def optimizer_from_config(
     assert len(param_groups) > 0, "param_groups must be non-empty"
 
     weight_decay_modules: dict[str, list[str]] | str = cfg.pop(
-        "weight_decay_modules",
-        "all"
+        "weight_decay_modules", "all"
     )
     all = set(name for name, p in model.named_parameters() if p.requires_grad)
     params = []
@@ -54,28 +51,25 @@ def optimizer_from_config(
         decay = set()
         param_dict = {}
         for name, mod, param in _select_params_and_modules(
-            model.named_modules(),
-            prefix
+            model.named_modules(), prefix
         ):
             if name not in all:
                 # this should only happen for shared
                 # or non-trainable parameters
                 continue
+
             if fix:
                 param.requires_grad = False
                 continue
+
             names.add(name)
             param_dict[name] = param
             mod_name = mod.__class__.__name__
-            if (
-                weight_decay_modules == "all"
-                or (
-                    isinstance(weight_decay_modules, dict)
-                    and mod_name in weight_decay_modules
-                    and any(
-                        name.endswith(suffix)
-                        for suffix in weight_decay_modules[mod_name]
-                    )
+            if weight_decay_modules == "all" or (
+                isinstance(weight_decay_modules, dict)
+                and mod_name in weight_decay_modules
+                and any(
+                    name.endswith(suffix) for suffix in weight_decay_modules[mod_name]
                 )
             ):
                 decay.add(name)
@@ -86,18 +80,23 @@ def optimizer_from_config(
         assert len(param_dict.keys() - (decay | no_decay)) == 0
 
         if len(decay) > 0:
-            params.append({
-                "params": [param_dict[name] for name in sorted(list(decay))],
-                **(cfg | group)
-            })
+            params.append(
+                {
+                    "params": [param_dict[name] for name in sorted(list(decay))],
+                    **(cfg | group),
+                }
+            )
         if len(no_decay) > 0:
-            params.append({
-                "params": [param_dict[name] for name in sorted(list(no_decay))],
-                **(cfg | group | {"weight_decay": 0.0})
-            })
+            params.append(
+                {
+                    "params": [param_dict[name] for name in sorted(list(no_decay))],
+                    **(cfg | group | {"weight_decay": 0.0}),
+                }
+            )
 
     unused = all - names
-    assert len(unused) == 0, \
-        f"parameter groups dont match trainable model parameters: {unused}"
+    assert (
+        len(unused) == 0
+    ), f"parameter groups dont match trainable model parameters: {unused}"
 
     return optim_cls(params, **cfg)
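
For context, a rough usage sketch of optimizer_from_config, the function reformatted by this commit. The import path, the optimizer type value, the extra optimizer kwargs, and the shape of the "param_groups" entries ("prefix", "fix") are assumptions inferred from the variables visible in this diff, not a documented schema; only the "type", "weight_decay_modules", and "param_groups" keys and the function signature appear in the code above.

# Hypothetical usage sketch; config values below are assumptions, not documented behavior.
import torch.nn as nn

from text_utils.modules.optimizer import optimizer_from_config  # assumed import path

model = nn.Sequential(nn.Linear(16, 32), nn.LayerNorm(32), nn.Linear(32, 4))

cfg = {
    "type": "adamw",  # popped via cfg.pop("type") to choose the optimizer class (value assumed)
    "lr": 1e-3,  # leftover keys are forwarded to the optimizer via optim_cls(params, **cfg)
    "weight_decay": 0.01,
    # dict[str, list[str]]: apply weight decay only to parameters whose module class
    # and name suffix match; everything else lands in the weight_decay=0.0 group
    "weight_decay_modules": {"Linear": ["weight"]},
    # assumed shape: one entry per parameter-name prefix; "fix": True would freeze the matched parameters
    "param_groups": [{"prefix": ""}],
}

optimizer = optimizer_from_config(model, cfg)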
