Skip to content

Commit

Permalink
Best model after epoch (#46)
Browse files — browse the repository at this point in the history
  • Loading branch information
natuan authored May 27, 2022
1 parent 053646a commit 15ba9b7
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 9 deletions.
15 changes: 12 additions & 3 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1725,7 +1725,12 @@ def _save_checkpoint(self, model, trial, metrics=None):
torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))

# Determine the new best metric / best model checkpoint
if metrics is not None and self.args.metric_for_best_model is not None:
if (
metrics is not None
and self.args.metric_for_best_model is not None
and self.args.best_model_after_epoch is not None
and self.state.epoch > self.args.best_model_after_epoch
):
metric_to_check = self.args.metric_for_best_model
if not metric_to_check.startswith("eval_"):
metric_to_check = f"eval_{metric_to_check}"
Expand Down Expand Up @@ -2661,7 +2666,9 @@ def prediction_step(
logits = smp_nested_concat(logits_mb)
else:
if has_labels:
with self.autocast_smart_context_manager(enabled=hasattr(self, "scaler") and self.scaler.is_enabled()):
with self.autocast_smart_context_manager(
enabled=hasattr(self, "scaler") and self.scaler.is_enabled()
):
loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
loss = loss.mean().detach()

Expand All @@ -2671,7 +2678,9 @@ def prediction_step(
logits = outputs[1:]
else:
loss = None
with self.autocast_smart_context_manager(enabled=hasattr(self, "scaler") and self.scaler.is_enabled()):
with self.autocast_smart_context_manager(
enabled=hasattr(self, "scaler") and self.scaler.is_enabled()
):
outputs = model(**inputs)
if isinstance(outputs, dict):
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
Expand Down
12 changes: 6 additions & 6 deletions src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,10 @@ class TrainingArguments:
default=False,
metadata={"help": "Whether or not to load the best model found during training at the end of training."},
)
best_model_after_epoch: int = field(
default=None,
metadata={"help": "Epoch after which best model will be saved."},
)
metric_for_best_model: Optional[str] = field(
default=None, metadata={"help": "The metric to use to compare two different models."}
)
Expand Down Expand Up @@ -748,12 +752,8 @@ class TrainingArguments:
metadata={"help": "Used by the SageMaker launcher to send mp-specific args. Ignored in Trainer"},
)
modifier_log_frequency: float = field(
default = 0.1,
metadata={
"help": (
"How often to log SparseML modifier data, in number of epochs or fraction of epochs"
)
}
default=0.1,
metadata={"help": ("How often to log SparseML modifier data, in number of epochs or fraction of epochs")},
)

def __post_init__(self):
Expand Down

0 comments on commit 15ba9b7

Please sign in to comment.