From 127195590aa39805cac541b8c70e6bb10cbbca58 Mon Sep 17 00:00:00 2001
From: Rahul Tuli
Date: Thu, 2 Feb 2023 11:15:54 -0500
Subject: [PATCH] [BugFix][Torchvision] update optimizer state dict before
 transfer learning (#1358)

* Add: an `_update_checkpoint_optimizer(...)` for deleting mismatching params from saved optimizer(s) state_dict

* Remove: _update_checkpoint_optimizer in favor of loading in the optim state_dict only when `args.resume` is set

* Remove: un-needed imports

* Address review comments

* Style
---
 src/sparseml/pytorch/torchvision/train.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/src/sparseml/pytorch/torchvision/train.py b/src/sparseml/pytorch/torchvision/train.py
index 7c77ee3e08c..d80ad837a30 100644
--- a/src/sparseml/pytorch/torchvision/train.py
+++ b/src/sparseml/pytorch/torchvision/train.py
@@ -533,7 +533,14 @@ def collate_fn(batch):
     # load params
     if checkpoint is not None:
         if "optimizer" in checkpoint and not args.test_only:
-            optimizer.load_state_dict(checkpoint["optimizer"])
+            if args.resume:
+                optimizer.load_state_dict(checkpoint["optimizer"])
+            else:
+                warnings.warn(
+                    "Optimizer state dict not loaded from checkpoint. Unless run is "
+                    "resumed with the --resume arg, the optimizer will start from a "
+                    "fresh state"
+                )
         if model_ema and "model_ema" in checkpoint:
             model_ema.load_state_dict(checkpoint["model_ema"])
         if scaler and "scaler" in checkpoint:
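
Background for the change above (an illustrative sketch, not part of the patch): PyTorch's Optimizer.load_state_dict raises a ValueError when the parameter groups in a saved state dict do not match the optimizer's current parameter groups in size, which is what a checkpointed optimizer state can hit when the model is rebuilt for transfer learning. The toy models below are hypothetical; only the mismatch behavior is the point. This is why the patch restores the optimizer state only when --resume is passed and warns otherwise.

    import torch
    from torch import nn

    # "Checkpointed" model and optimizer (stand-in for the original training run)
    old_model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10))
    old_optimizer = torch.optim.SGD(old_model.parameters(), lr=0.1)
    saved_state = old_optimizer.state_dict()

    # Model rebuilt for transfer learning with a different parameter layout
    new_model = nn.Linear(10, 5)
    new_optimizer = torch.optim.SGD(new_model.parameters(), lr=0.1)

    try:
        # Fails: the saved parameter group holds 4 params, the new one only 2
        new_optimizer.load_state_dict(saved_state)
    except ValueError as err:
        print(f"optimizer state dict mismatch: {err}")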