diff --git a/train.py b/train.py index 90c512629d31..5d31fdf04482 100644 --- a/train.py +++ b/train.py @@ -215,7 +215,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary LOGGER.info('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' % (weights, start_epoch-1, epochs)) epochs += start_epoch # finetune additional epochs - if sparseml_wrapper.qat_active(start_epoch): + if sparseml_wrapper.qat_active(start_epoch) and ema: ema.enabled = False # Optimizer @@ -353,7 +353,8 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary LOGGER.info('Disabling half precision and EMA, QAT scheduled to run') half_precision = False scaler._enabled = False - ema.enabled = False + if ema: + ema.enabled = False model.train() # Update image weights (optional, single-GPU only)