diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 73f6c7b55d82b5..f9e044ecadb914 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1725,7 +1725,12 @@ def _save_checkpoint(self, model, trial, metrics=None): torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME)) # Determine the new best metric / best model checkpoint - if metrics is not None and self.args.metric_for_best_model is not None: + if ( + metrics is not None + and self.args.metric_for_best_model is not None + and self.args.best_model_after_epoch is not None + and self.state.epoch > self.args.best_model_after_epoch + ): metric_to_check = self.args.metric_for_best_model if not metric_to_check.startswith("eval_"): metric_to_check = f"eval_{metric_to_check}" @@ -2661,7 +2666,9 @@ def prediction_step( logits = smp_nested_concat(logits_mb) else: if has_labels: - with self.autocast_smart_context_manager(enabled=hasattr(self, "scaler") and self.scaler.is_enabled()): + with self.autocast_smart_context_manager( + enabled=hasattr(self, "scaler") and self.scaler.is_enabled() + ): loss, outputs = self.compute_loss(model, inputs, return_outputs=True) loss = loss.mean().detach() @@ -2671,7 +2678,9 @@ def prediction_step( logits = outputs[1:] else: loss = None - with self.autocast_smart_context_manager(enabled=hasattr(self, "scaler") and self.scaler.is_enabled()): + with self.autocast_smart_context_manager( + enabled=hasattr(self, "scaler") and self.scaler.is_enabled() + ): outputs = model(**inputs) if isinstance(outputs, dict): logits = tuple(v for k, v in outputs.items() if k not in ignore_keys) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 62c91341f307e0..c15dfe01c91c88 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -640,6 +640,10 @@ class TrainingArguments: default=False, metadata={"help": "Whether or not to load the best model found during training at the end of training."}, ) + best_model_after_epoch: int = field( + default=None, + metadata={"help": "Epoch after which best model will be saved."}, + ) metric_for_best_model: Optional[str] = field( default=None, metadata={"help": "The metric to use to compare two different models."} ) @@ -748,12 +752,8 @@ class TrainingArguments: metadata={"help": "Used by the SageMaker launcher to send mp-specific args. Ignored in Trainer"}, ) modifier_log_frequency: float = field( - default = 0.1, - metadata={ - "help": ( - "How often to log SparseML modifier data, in number of epochs or fraction of epochs" - ) - } + default=0.1, + metadata={"help": ("How often to log SparseML modifier data, in number of epochs or fraction of epochs")}, ) def __post_init__(self):