add restore_from_best to BaseCheckpointer (#677)
Summary:
Pull Request resolved: #677

# Context
We are adding a best-checkpoint feature to the TNT checkpointers. This requires us to both save and load the best checkpoint.

# This Diff
Adds a `restore_from_best` method to load from the best checkpoint.
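
For reference, a minimal usage sketch of the new classmethod (not part of this diff). `MyCheckpointer` and `my_unit` are hypothetical placeholder names: any concrete `BaseCheckpointer` subclass and any unit with state to restore would do.

```python
# Hypothetical names: MyCheckpointer stands in for a concrete BaseCheckpointer
# subclass, my_unit for a TrainUnit/EvalUnit/PredictUnit whose state to restore.
restored = MyCheckpointer.restore_from_best(
    dirpath="/tmp/checkpoints",  # parent directory holding the saved checkpoints
    unit=my_unit,                # unit to load the checkpointed state into
    metric_name="val_loss",      # metric encoded in the checkpoint directory names
    mode="min",                  # "min" loads the lowest val_loss, "max" the highest
)
if not restored:
    # restore_from_best returns False when no checkpoint with the metric is found
    ...
```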

# Next Diff
Implement saving the best checkpoint.

Reviewed By: galrotem

Differential Revision: D52746033

fbshipit-source-id: 03bdeffa780363715a5da94703242c8f0190d53a
JKSenthil authored and facebook-github-bot committed Jan 18, 2024
1 parent 4073700 commit daa0e69
Showing 2 changed files with 131 additions and 2 deletions.
69 changes: 69 additions & 0 deletions tests/framework/callbacks/test_base_checkpointer.py
@@ -210,6 +210,75 @@ def test_restore_from_latest_empty_dir(self) -> None:
            )
            self.assertFalse(restored)

    def test_restore_from_best(self) -> None:
        input_dim = 2
        state = get_dummy_train_state()

        with tempfile.TemporaryDirectory() as temp_dir:
            bcs_cb = BaseCheckpointSaver(temp_dir)

            my_unit = DummyTrainUnit(input_dim=input_dim)
            bcs_cb._generate_checkpoint_and_upkeep(state, my_unit, hook="foo")
            os.rename(
                os.path.join(temp_dir, "epoch_0_step_0"),
                os.path.join(temp_dir, "epoch_0_step_0_val_loss=0.01"),
            )

            my_unit.train_progress._num_steps_completed = 10
            bcs_cb._generate_checkpoint_and_upkeep(state, my_unit, hook="foo")
            os.rename(
                os.path.join(temp_dir, "epoch_0_step_10"),
                os.path.join(temp_dir, "epoch_0_step_10_val_loss=-0.1"),
            )

            my_unit.train_progress._num_steps_completed = 20
            bcs_cb._generate_checkpoint_and_upkeep(state, my_unit, hook="foo")
            os.rename(
                os.path.join(temp_dir, "epoch_0_step_20"),
                os.path.join(temp_dir, "epoch_0_step_20_val_loss=0.1"),
            )

            my_unit = DummyTrainUnit(input_dim=input_dim)
            with self.assertLogs(level="INFO") as log:
                restored = bcs_cb.restore_from_best(
                    temp_dir, my_unit, "val_loss", "min"
                )
            self.assertTrue(restored)
            self.assertIn(
                f"INFO:torchtnt.utils.rank_zero_log:Loading checkpoint from {os.path.join(temp_dir, 'epoch_0_step_10_val_loss=-0.1')}",
                log.output,
            )

            my_unit = DummyTrainUnit(input_dim=input_dim)
            with self.assertLogs(level="INFO") as log:
                restored = bcs_cb.restore_from_best(
                    temp_dir, my_unit, "val_loss", "max"
                )
            self.assertTrue(restored)
            self.assertIn(
                f"INFO:torchtnt.utils.rank_zero_log:Loading checkpoint from {os.path.join(temp_dir, 'epoch_0_step_20_val_loss=0.1')}",
                log.output,
            )

    def test_restore_from_best_empty_dir(self) -> None:
        input_dim = 2

        my_unit = DummyTrainUnit(input_dim=input_dim)
        with tempfile.TemporaryDirectory() as temp_dir:
            bcs_cb = BaseCheckpointSaver(
                temp_dir,
            )

            with self.assertLogs(level="WARNING") as log:
                restored = bcs_cb.restore_from_best(
                    temp_dir, my_unit, "val_loss", "min"
                )
            self.assertIn(
                f"WARNING:torchtnt.framework.callbacks.base_checkpointer:No checkpoints with metric name val_loss were found in {temp_dir}. Not loading any checkpoint.",
                log.output,
            )
            self.assertFalse(restored)

    def test_save_on_train_end(self) -> None:
        input_dim = 2
        dataset_len = 10
64 changes: 62 additions & 2 deletions torchtnt/framework/callbacks/base_checkpointer.py
@@ -7,7 +7,7 @@
import abc
import logging
import os
-from typing import Any, cast, Iterable, List, Optional
+from typing import Any, cast, Iterable, List, Literal, Optional

import torch.distributed as dist

@@ -16,6 +16,7 @@
    _delete_checkpoint,
    _metadata_exists,
    _sort_by_recency,
    get_best_checkpoint_path,
    get_checkpoint_dirpaths,
    get_latest_checkpoint_path,
    rank_zero_read_and_broadcast,
@@ -26,7 +27,7 @@
from torchtnt.framework.utils import get_timing_context
from torchtnt.utils.distributed import PGWrapper
from torchtnt.utils.fsspec import get_filesystem
-from torchtnt.utils.rank_zero_log import rank_zero_warn
+from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn

logger: logging.Logger = logging.getLogger(__name__)

@@ -308,6 +309,65 @@ def restore_from_latest(
        )
        return True

    @classmethod
    def restore_from_best(
        cls,
        dirpath: str,
        unit: AppStateMixin,
        metric_name: str,
        mode: Literal["min", "max"],
        *,
        train_dataloader: Optional[Iterable[TTrainData]] = None,
        process_group: Optional[dist.ProcessGroup] = None,
        restore_options: Optional[RestoreOptions] = None,
        **kwargs: Any,
    ) -> bool:
        """
        Given a parent directory where checkpoints are saved, restore the checkpoint state from the best checkpoint in the directory.

        There are additional flags offered should the user want to skip loading the train and eval progress.
        By default, the train and eval progress are restored, if applicable.

        Args:
            dirpath: Parent directory from which to get the best checkpoint.
            unit: An instance of :class:`~torchtnt.framework.unit.TrainUnit`, :class:`~torchtnt.framework.unit.EvalUnit`, or :class:`~torchtnt.framework.unit.PredictUnit` containing states to restore.
            metric_name: Name of the metric to use to find the best checkpoint.
            mode: Either 'min' or 'max'. If 'min', finds and loads the checkpoint with the lowest metric value; if 'max', the highest.
            train_dataloader: An optional train dataloader to restore.
            process_group: The process group on which the ranks will communicate. Default: ``None`` (the entire world).
            restore_options: Controls what to filter when restoring the state.

        Returns:
            True if the best checkpoint directory was found and successfully restored, otherwise False.
        """
        best_checkpoint_path = get_best_checkpoint_path(
            dirpath,
            metric_name=metric_name,
            mode=mode,
            metadata_fname=cls.metadata_fname,
            process_group=process_group,
        )

        if best_checkpoint_path is None:
            rank_zero_warn(
                f"No checkpoints with metric name {metric_name} were found in {dirpath}. Not loading any checkpoint.",
                logger=logger,
            )
            return False

        rank_zero_info(f"Loading checkpoint from {best_checkpoint_path}")

        cls.restore(
            best_checkpoint_path,
            unit,
            train_dataloader=train_dataloader,
            process_group=process_group,
            restore_options=restore_options,
            **kwargs,
        )

        return True

    @rank_zero_read_and_broadcast
    def _does_checkpoint_exist(
        self, checkpoint_path: str, process_group: Optional[dist.ProcessGroup] = None

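The tests in `test_base_checkpointer.py` above exercise a directory-name convention in which the metric value is appended to the checkpoint name (e.g. `epoch_0_step_10_val_loss=-0.1`), while `restore_from_best` delegates the actual lookup to `get_best_checkpoint_path`, whose implementation is not shown in this diff. As an illustration only (a sketch of the selection idea, not the library's code), picking the best checkpoint from such names could look like:

```python
import os
from typing import Optional


def pick_best_checkpoint(dirpath: str, metric_name: str, mode: str) -> Optional[str]:
    """Illustrative stand-in for the lookup done by get_best_checkpoint_path:
    parse the metric value out of each checkpoint directory name and return the
    path with the smallest ("min") or largest ("max") value."""
    marker = f"_{metric_name}="
    candidates = []
    for name in os.listdir(dirpath):
        if marker not in name:
            continue  # skip checkpoints that do not encode this metric
        try:
            value = float(name.split(marker)[-1])
        except ValueError:
            continue  # skip names whose suffix is not a parseable number
        candidates.append((value, os.path.join(dirpath, name)))
    if not candidates:
        return None  # mirrors restore_from_best warning and returning False
    best = min(candidates) if mode == "min" else max(candidates)
    return best[1]
```

With the directories created in the test, `pick_best_checkpoint(temp_dir, "val_loss", "min")` would return the `epoch_0_step_10_val_loss=-0.1` path and `"max"` the `epoch_0_step_20_val_loss=0.1` path, matching the assertions above.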