From daa0e697f57f24688689806f64ff57cf3eac6079 Mon Sep 17 00:00:00 2001
From: Jason Senthil
Date: Wed, 17 Jan 2024 16:04:16 -0800
Subject: [PATCH] add `restore_from_best` to BaseCheckpointer (#677)

Summary:
Pull Request resolved: https://github.com/pytorch/tnt/pull/677

# Context

We are adding a best-checkpoint feature to the TNT checkpointers. This requires us to save, as well as load, the best checkpoint.

# This Diff

Adds a `restore_from_best` method to load from the best checkpoint.
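
A minimal usage sketch (illustrative only: `TorchSnapshotSaver` stands in for any concrete `BaseCheckpointer` subclass, and `MyTrainUnit` and the dirpath are hypothetical):

```python
from torchtnt.framework.callbacks import TorchSnapshotSaver  # a concrete BaseCheckpointer subclass

# Hypothetical user-defined TrainUnit; its tracked app state (module,
# optimizer, progress, ...) is what restore_from_best loads into.
unit = MyTrainUnit()

# Pick the checkpoint with the lowest val_loss under the parent directory,
# e.g. a subdirectory named "epoch_0_step_10_val_loss=-0.1".
restored = TorchSnapshotSaver.restore_from_best(
    "/tmp/checkpoints",  # hypothetical dirpath containing saved checkpoints
    unit,
    metric_name="val_loss",
    mode="min",
)
if not restored:
    # No checkpoint carrying a val_loss value was found; nothing was loaded.
    ...
```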

# Next Diff

Implement saving the best checkpoint.

Reviewed By: galrotem

Differential Revision: D52746033

fbshipit-source-id: 03bdeffa780363715a5da94703242c8f0190d53a
---
 .../callbacks/test_base_checkpointer.py       | 69 +++++++++++++++++++
 .../framework/callbacks/base_checkpointer.py  | 64 ++++++++++++++++-
 2 files changed, 131 insertions(+), 2 deletions(-)

diff --git a/tests/framework/callbacks/test_base_checkpointer.py b/tests/framework/callbacks/test_base_checkpointer.py
index 7135ac003b..2c16a4af19 100644
--- a/tests/framework/callbacks/test_base_checkpointer.py
+++ b/tests/framework/callbacks/test_base_checkpointer.py
@@ -210,6 +210,75 @@ def test_restore_from_latest_empty_dir(self) -> None:
         )
         self.assertFalse(restored)
 
+    def test_restore_from_best(self) -> None:
+        input_dim = 2
+        state = get_dummy_train_state()
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            bcs_cb = BaseCheckpointSaver(temp_dir)
+
+            my_unit = DummyTrainUnit(input_dim=input_dim)
+            bcs_cb._generate_checkpoint_and_upkeep(state, my_unit, hook="foo")
+            os.rename(
+                os.path.join(temp_dir, "epoch_0_step_0"),
+                os.path.join(temp_dir, "epoch_0_step_0_val_loss=0.01"),
+            )
+
+            my_unit.train_progress._num_steps_completed = 10
+            bcs_cb._generate_checkpoint_and_upkeep(state, my_unit, hook="foo")
+            os.rename(
+                os.path.join(temp_dir, "epoch_0_step_10"),
+                os.path.join(temp_dir, "epoch_0_step_10_val_loss=-0.1"),
+            )
+
+            my_unit.train_progress._num_steps_completed = 20
+            bcs_cb._generate_checkpoint_and_upkeep(state, my_unit, hook="foo")
+            os.rename(
+                os.path.join(temp_dir, "epoch_0_step_20"),
+                os.path.join(temp_dir, "epoch_0_step_20_val_loss=0.1"),
+            )
+
+            my_unit = DummyTrainUnit(input_dim=input_dim)
+            with self.assertLogs(level="INFO") as log:
+                restored = bcs_cb.restore_from_best(
+                    temp_dir, my_unit, "val_loss", "min"
+                )
+            self.assertTrue(restored)
+            self.assertIn(
+                f"INFO:torchtnt.utils.rank_zero_log:Loading checkpoint from {os.path.join(temp_dir, 'epoch_0_step_10_val_loss=-0.1')}",
+                log.output,
+            )
+
+            my_unit = DummyTrainUnit(input_dim=input_dim)
+            with self.assertLogs(level="INFO") as log:
+                restored = bcs_cb.restore_from_best(
+                    temp_dir, my_unit, "val_loss", "max"
+                )
+            self.assertTrue(restored)
+            self.assertIn(
+                f"INFO:torchtnt.utils.rank_zero_log:Loading checkpoint from {os.path.join(temp_dir, 'epoch_0_step_20_val_loss=0.1')}",
+                log.output,
+            )
+
+    def test_restore_from_best_empty_dir(self) -> None:
+        input_dim = 2
+
+        my_unit = DummyTrainUnit(input_dim=input_dim)
+        with tempfile.TemporaryDirectory() as temp_dir:
+            bcs_cb = BaseCheckpointSaver(
+                temp_dir,
+            )
+
+            with self.assertLogs(level="WARNING") as log:
+                restored = bcs_cb.restore_from_best(
+                    temp_dir, my_unit, "val_loss", "min"
+                )
+            self.assertIn(
+                f"WARNING:torchtnt.framework.callbacks.base_checkpointer:No checkpoints with metric name val_loss were found in {temp_dir}. Not loading any checkpoint.",
+                log.output,
+            )
+            self.assertFalse(restored)
+
     def test_save_on_train_end(self) -> None:
         input_dim = 2
         dataset_len = 10
diff --git a/torchtnt/framework/callbacks/base_checkpointer.py b/torchtnt/framework/callbacks/base_checkpointer.py
index 55ebc4f8b6..82a2426b8a 100644
--- a/torchtnt/framework/callbacks/base_checkpointer.py
+++ b/torchtnt/framework/callbacks/base_checkpointer.py
@@ -7,7 +7,7 @@
 import abc
 import logging
 import os
-from typing import Any, cast, Iterable, List, Optional
+from typing import Any, cast, Iterable, List, Literal, Optional
 
 import torch.distributed as dist
 
@@ -16,6 +16,7 @@
     _delete_checkpoint,
     _metadata_exists,
     _sort_by_recency,
+    get_best_checkpoint_path,
     get_checkpoint_dirpaths,
     get_latest_checkpoint_path,
     rank_zero_read_and_broadcast,
@@ -26,7 +27,7 @@
 from torchtnt.framework.utils import get_timing_context
 from torchtnt.utils.distributed import PGWrapper
 from torchtnt.utils.fsspec import get_filesystem
-from torchtnt.utils.rank_zero_log import rank_zero_warn
+from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -308,6 +309,65 @@ def restore_from_latest(
         )
         return True
 
+    @classmethod
+    def restore_from_best(
+        cls,
+        dirpath: str,
+        unit: AppStateMixin,
+        metric_name: str,
+        mode: Literal["min", "max"],
+        *,
+        train_dataloader: Optional[Iterable[TTrainData]] = None,
+        process_group: Optional[dist.ProcessGroup] = None,
+        restore_options: Optional[RestoreOptions] = None,
+        **kwargs: Any,
+    ) -> bool:
+        """
+        Given a parent directory where checkpoints are saved, restore the checkpoint state from the best checkpoint in the directory.
+
+        There are additional flags offered should the user want to skip loading the train and eval progress.
+        By default, the train and eval progress are restored, if applicable.
+
+        Args:
+            dirpath: Parent directory from which to get the best checkpoint.
+            unit: An instance of :class:`~torchtnt.framework.unit.TrainUnit`, :class:`~torchtnt.framework.unit.EvalUnit`, or :class:`~torchtnt.framework.unit.PredictUnit` containing states to restore.
+            metric_name: Name of the metric to use to find the best checkpoint.
+            mode: Either 'min' or 'max'. If 'min', finds and loads the lowest value metric checkpoint. If 'max', finds and loads the largest.
+            train_dataloader: An optional train dataloader to restore.
+            process_group: The process group on which the ranks will communicate. default: ``None`` (the entire world)
+            restore_options: Controls what to filter when restoring the state.
+
+        Returns:
+            True if the best checkpoint directory was found and successfully restored, otherwise False.
+        """
+        best_checkpoint_path = get_best_checkpoint_path(
+            dirpath,
+            metric_name=metric_name,
+            mode=mode,
+            metadata_fname=cls.metadata_fname,
+            process_group=process_group,
+        )
+
+        if best_checkpoint_path is None:
+            rank_zero_warn(
+                f"No checkpoints with metric name {metric_name} were found in {dirpath}. Not loading any checkpoint.",
+                logger=logger,
+            )
+            return False
+
+        rank_zero_info(f"Loading checkpoint from {best_checkpoint_path}")
+
+        cls.restore(
+            best_checkpoint_path,
+            unit,
+            train_dataloader=train_dataloader,
+            process_group=process_group,
+            restore_options=restore_options,
+            **kwargs,
+        )
+
+        return True
+
     @rank_zero_read_and_broadcast
     def _does_checkpoint_exist(
         self, checkpoint_path: str, process_group: Optional[dist.ProcessGroup] = None