From b8756a4dbba08be2b291c0d06a88a1fee1a5d915 Mon Sep 17 00:00:00 2001 From: Diego Urgell Date: Mon, 24 Feb 2025 15:06:19 -0800 Subject: [PATCH] Add str method to ActivePhase (#975) Summary: Pull Request resolved: https://github.com/pytorch/tnt/pull/975 Differential Revision: D70127457 --- tests/framework/test_state.py | 10 ++++++++++ torchtnt/framework/state.py | 10 ++++++++++ 2 files changed, 20 insertions(+) diff --git a/tests/framework/test_state.py b/tests/framework/test_state.py index ca1c9dabd5..9e2f658e03 100644 --- a/tests/framework/test_state.py +++ b/tests/framework/test_state.py @@ -53,6 +53,16 @@ def test_active_phase_into_phase(self) -> None: predict_phase = ActivePhase.PREDICT self.assertEqual(predict_phase.into_phase(), Phase.PREDICT) + def test_actiive_phase_str(self) -> None: + active_phase = ActivePhase.TRAIN + self.assertEqual(str(active_phase), "train") + + eval_phase = ActivePhase.EVALUATE + self.assertEqual(str(eval_phase), "eval") + + predict_phase = ActivePhase.PREDICT + self.assertEqual(str(predict_phase), "predict") + def test_set_evaluate_every_n_steps_or_epochs(self) -> None: state = PhaseState(dataloader=[], evaluate_every_n_steps=2) state.evaluate_every_n_steps = None diff --git a/torchtnt/framework/state.py b/torchtnt/framework/state.py index 99d4ae6f96..57eb29cd3e 100644 --- a/torchtnt/framework/state.py +++ b/torchtnt/framework/state.py @@ -74,6 +74,16 @@ def into_phase(self) -> Phase: else: raise AssertionError("Should match an ActivePhase") + def __str__(self) -> str: + if self == ActivePhase.TRAIN: + return "train" + elif self == ActivePhase.EVALUATE: + return "eval" + elif self == ActivePhase.PREDICT: + return "predict" + else: + raise AssertionError("Should match an ActivePhase") + class PhaseState(Generic[TData, TStepOutput]): """State for each phase (train, eval, predict).