Skip to content

Commit

Permalink
Set evaluate_every_epoch value based on number of train epochs comple…
Browse files Browse the repository at this point in the history
…ted (#968)

Summary:
Pull Request resolved: #968

- Use `evaluate_every_n_epochs` until a certain number of epochs have been trained, after which evaluation is forced to run every epoch
- This requires being able to set the `evaluate_every_n_epochs` field of the `PhaseState` class
- Save total time by cutting down evaluations, 8.3 hours -> 6.3 hours

Reviewed By: galrotem

Differential Revision: D68933995

fbshipit-source-id: a16bbcc81482681cd297cbcc6dc01764db8bb097
  • Loading branch information
clarkdykang authored and facebook-github-bot committed Jan 31, 2025
1 parent d71a41b commit 93347d9
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 0 deletions.
17 changes: 17 additions & 0 deletions tests/framework/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,20 @@ def test_active_phase_into_phase(self) -> None:

predict_phase = ActivePhase.PREDICT
self.assertEqual(predict_phase.into_phase(), Phase.PREDICT)

def test_set_evaluate_every_n_steps_or_epochs(self) -> None:
    """Check the ``evaluate_every_n_steps`` / ``evaluate_every_n_epochs`` setters.

    For each field, assigning ``None`` or ``100`` must succeed, while
    assigning ``-2`` must raise a ``ValueError`` that names the field.
    """
    # Step-based evaluation frequency.
    step_state = PhaseState(dataloader=[], evaluate_every_n_steps=2)
    step_state.evaluate_every_n_steps = None
    step_state.evaluate_every_n_steps = 100
    rejects_bad_steps = self.assertRaisesRegex(
        ValueError, "Invalid value provided for evaluate_every_n_steps"
    )
    with rejects_bad_steps:
        step_state.evaluate_every_n_steps = -2

    # Epoch-based evaluation frequency.
    epoch_state = PhaseState(dataloader=[], evaluate_every_n_epochs=2)
    epoch_state.evaluate_every_n_epochs = None
    epoch_state.evaluate_every_n_epochs = 100
    rejects_bad_epochs = self.assertRaisesRegex(
        ValueError, "Invalid value provided for evaluate_every_n_epochs"
    )
    with rejects_bad_epochs:
        epoch_state.evaluate_every_n_epochs = -2
10 changes: 10 additions & 0 deletions torchtnt/framework/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,21 @@ def evaluate_every_n_steps(self) -> Optional[int]:
"""Frequency with which to evaluate in terms of training steps, when running :func:`~torchtnt.framework.fit`. Defined by the user."""
return self._evaluate_every_n_steps

@evaluate_every_n_steps.setter
def evaluate_every_n_steps(self, value: Optional[int]) -> None:
    """Update how often evaluation runs, measured in training steps.

    Args:
        value: the new frequency, or ``None``.

    Raises:
        ValueError: propagated from ``_check_loop_condition`` when ``value``
            is rejected (for example, a negative number).
    """
    # Validation runs before the assignment, so a rejected value never
    # replaces the one currently stored.
    _check_loop_condition("evaluate_every_n_steps", value)
    self._evaluate_every_n_steps = value

@property
def evaluate_every_n_epochs(self) -> Optional[int]:
    """User-defined evaluation frequency, in training epochs, used when running :func:`~torchtnt.framework.fit`."""
    return self._evaluate_every_n_epochs

@evaluate_every_n_epochs.setter
def evaluate_every_n_epochs(self, value: Optional[int]) -> None:
    """Update how often evaluation runs, measured in training epochs.

    Args:
        value: the new frequency, or ``None``.

    Raises:
        ValueError: propagated from ``_check_loop_condition`` when ``value``
            is rejected (for example, a negative number).
    """
    # Validation runs before the assignment, so a rejected value never
    # replaces the one currently stored.
    _check_loop_condition("evaluate_every_n_epochs", value)
    self._evaluate_every_n_epochs = value

@property
def step_output(self) -> Optional[TStepOutput]:
"""Output of the last step."""
Expand Down

0 comments on commit 93347d9

Please sign in to comment.