diff --git a/src/sparseml/pytorch/torchvision/train.py b/src/sparseml/pytorch/torchvision/train.py index 7c77ee3e08c..d80ad837a30 100644 --- a/src/sparseml/pytorch/torchvision/train.py +++ b/src/sparseml/pytorch/torchvision/train.py @@ -533,7 +533,14 @@ def collate_fn(batch): # load params if checkpoint is not None: if "optimizer" in checkpoint and not args.test_only: - optimizer.load_state_dict(checkpoint["optimizer"]) + if args.resume: + optimizer.load_state_dict(checkpoint["optimizer"]) + else: + warnings.warn( + "Optimizer state dict not loaded from checkpoint. Unless run is " + "resumed with the --resume arg, the optimizer will start from a " + "fresh state" + ) if model_ema and "model_ema" in checkpoint: model_ema.load_state_dict(checkpoint["model_ema"]) if scaler and "scaler" in checkpoint: