[BugFix][Torchvision] update optimizer state dict before transfer learning (#1358)

* Add: an `_update_checkpoint_optimizer(...)` helper that deletes mismatching params from the saved optimizer state_dict(s) (see the sketch after this list)

* Remove: `_update_checkpoint_optimizer` in favor of loading the optimizer state_dict only when `args.resume` is set

* Remove: unneeded imports

* Address review comments

* Style
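
For context, here is a minimal sketch of what the first (later removed) approach might have looked like. The name `_update_checkpoint_optimizer` comes from the commit message; the body below is an illustrative assumption, not the actual removed code:

```python
import torch


def _update_checkpoint_optimizer(
    optimizer_state: dict, optimizer: torch.optim.Optimizer
) -> dict:
    # Hypothetical sketch: drop saved per-param state whose tensor shapes no
    # longer match the current optimizer's params (e.g. after replacing the
    # classifier head for transfer learning), keeping only compatible entries.
    current_params = [p for group in optimizer.param_groups for p in group["params"]]
    saved_state = optimizer_state.get("state", {})
    optimizer_state["state"] = {
        idx: state
        for idx, state in saved_state.items()
        if idx < len(current_params)
        and all(
            not torch.is_tensor(value) or value.shape == current_params[idx].shape
            for value in state.values()
        )
    }
    return optimizer_state
```

As the later bullets note, this helper was dropped in favor of simply not loading the optimizer state unless `--resume` is passed.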
rahul-tuli authored and bfineran committed Feb 3, 2023
1 parent b1ec8a9 commit 1271955
Showing 1 changed file with 8 additions and 1 deletion.

src/sparseml/pytorch/torchvision/train.py
@@ -533,7 +533,14 @@ def collate_fn(batch):
     # load params
     if checkpoint is not None:
         if "optimizer" in checkpoint and not args.test_only:
-            optimizer.load_state_dict(checkpoint["optimizer"])
+            if args.resume:
+                optimizer.load_state_dict(checkpoint["optimizer"])
+            else:
+                warnings.warn(
+                    "Optimizer state dict not loaded from checkpoint. Unless run is "
+                    "resumed with the --resume arg, the optimizer will start from a "
+                    "fresh state"
+                )
         if model_ema and "model_ema" in checkpoint:
             model_ema.load_state_dict(checkpoint["model_ema"])
         if scaler and "scaler" in checkpoint:
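
Why this matters: restoring an optimizer state dict that was saved for a model with a different head can crash the first update of a transfer-learning run. A minimal standalone repro, under assumed head shapes (not from the repo):

```python
import torch

# Train one step on the original 1000-class head so momentum buffers exist.
old_head = torch.nn.Linear(512, 1000)
opt = torch.optim.SGD(old_head.parameters(), lr=0.1, momentum=0.9)
old_head(torch.randn(4, 512)).sum().backward()
opt.step()
saved = opt.state_dict()  # momentum buffers shaped for the 1000-class head

# Swap in a 10-class head for transfer learning and restore the stale state.
new_head = torch.nn.Linear(512, 10)
new_opt = torch.optim.SGD(new_head.parameters(), lr=0.1, momentum=0.9)
new_opt.load_state_dict(saved)  # loads silently: group structure still matches
new_head(torch.randn(4, 512)).sum().backward()
# new_opt.step() would now raise a RuntimeError (momentum buffer shape
# mismatch), which is why the optimizer starts fresh unless --resume is set.
```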
