Skip to content

Commit

Permalink
support reading checkpoints w/ metric values (#675)
Browse files Browse the repository at this point in the history
Summary:

# Context

We want to support checkpointing best model in TorchTNT. This requires adding to the existing utils that handle reading checkpoints

# This Diff
1. Adds `metric_name` arg to `_retrieve_checkpoint_dirpaths()` (and `get_checkpoint_dirpaths()`, its distributed equivalent), which is responsible for reading the appropriate checkpoint paths. It will append to its regex and only consider checkpoints that contain the metric name in their name
2. Adds `get_best_checkpoint_path()` method, adding a sibling to the `get_latest_checkpoint_path()` method when restoring checkpoints. It now supports sorting by metric value, alongside sorting by latest (which it could already do)

# Next Diff
Implement the best checkpoint feature using these utils

Differential Revision: D52714747
  • Loading branch information
JKSenthil authored and facebook-github-bot committed Jan 12, 2024
1 parent bb8ed78 commit cec0ed8
Show file tree
Hide file tree
Showing 2 changed files with 208 additions and 8 deletions.
131 changes: 130 additions & 1 deletion tests/framework/callbacks/test_checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import shutil
import tempfile
import unittest
from functools import partial

import torch
import torch.distributed as dist
Expand All @@ -22,6 +23,7 @@
_metadata_exists,
_prepare_app_state_for_checkpoint,
_retrieve_checkpoint_dirpaths,
get_best_checkpoint_path,
get_checkpoint_dirpaths,
get_latest_checkpoint_path,
rank_zero_read_and_broadcast,
Expand Down Expand Up @@ -68,7 +70,7 @@ def test_latest_checkpoint_path(self) -> None:
path_1 = os.path.join(temp_dir, "epoch_0_step_0")
os.mkdir(path_1)
self._create_snapshot_metadata(path_1)
path_2 = os.path.join(temp_dir, "epoch_0_step_100")
path_2 = os.path.join(temp_dir, "epoch_0_step_100_val_loss=0.002")
os.mkdir(path_2)
self._create_snapshot_metadata(path_2)

Expand Down Expand Up @@ -136,6 +138,66 @@ def _latest_checkpoint_path_distributed() -> None:
if is_rank0:
shutil.rmtree(temp_dir) # delete temp directory

def test_best_checkpoint_path(self) -> None:
get_best_checkpoint_fn = partial(
get_best_checkpoint_path, metric_name="val_loss", mode="min"
)

with tempfile.TemporaryDirectory() as temp_dir:
self.assertIsNone(get_best_checkpoint_fn(temp_dir))

# no checkpoint w/ metric value
path = os.path.join(temp_dir, "epoch_0_step_0")
os.mkdir(path)
self.assertIsNone(get_best_checkpoint_fn(temp_dir))

with tempfile.TemporaryDirectory() as temp_dir:
best_path = os.path.join(temp_dir, "epoch_0_step_0_val_loss=0.01")
os.mkdir(best_path)
self.assertEqual(
get_best_checkpoint_fn(temp_dir),
best_path,
)
self.assertEqual(
get_best_checkpoint_fn(temp_dir, metadata_fname=METADATA_FNAME),
None,
)
self._create_snapshot_metadata(best_path)
self.assertEqual(
get_best_checkpoint_fn(temp_dir, metadata_fname=METADATA_FNAME),
best_path,
)

# handle negative values
best_path_2 = os.path.join(temp_dir, "epoch_0_step_0_val_loss=-0.01")
os.mkdir(best_path_2)
self.assertEqual(
get_best_checkpoint_fn(temp_dir),
best_path_2,
)

# handle "max" mode correctly
best_path_3 = os.path.join(temp_dir, "epoch_0_step_100_val_loss=0.1")
os.mkdir(best_path_3)
self.assertEqual(
get_best_checkpoint_path(temp_dir, metric_name="val_loss", mode="max"),
best_path_3,
)

# handle different metric correctly
best_path_4 = os.path.join(temp_dir, "epoch_0_step_100_train_loss=0.2")
os.mkdir(best_path_4)
self.assertEqual(
get_best_checkpoint_path(temp_dir, metric_name="val_loss", mode="max"),
best_path_3,
)
self.assertEqual(
get_best_checkpoint_path(
temp_dir, metric_name="train_loss", mode="max"
),
best_path_4,
)

def test_retrieve_checkpoint_dirpaths(self) -> None:
"""
Tests retrieving checkpoint directories from a given root directory
Expand Down Expand Up @@ -178,6 +240,60 @@ def test_retrieve_checkpoint_dirpaths(self) -> None:
{os.path.join(temp_dir, paths[2])},
)

def test_retrieve_checkpoint_dirpaths_with_metrics(self) -> None:
"""
Tests retrieving checkpoint (w/ metrics) directories from a given root directory
"""
with tempfile.TemporaryDirectory() as temp_dir:
paths = [
"epoch_0_step_10_val_loss=10",
"epoch_1_step_10_val_loss=5",
"epoch_2_step_10",
"epoch_0_step_5",
"epoch_0_step_6_train_loss=13",
"epoch_0_step_3_val_loss=3",
]
for path in paths[:-1]:
os.mkdir(os.path.join(temp_dir, path))
# make last path a file instead of a directory
with open(os.path.join(temp_dir, paths[-1]), "w"):
pass

# compares set equality since order of returned dirpaths is not guaranteed
# in _retrieve_checkpoint_dirpaths
self.assertEqual(
set(_retrieve_checkpoint_dirpaths(temp_dir, metadata_fname=None)),
{os.path.join(temp_dir, path) for path in paths[:-1]},
)
self.assertEqual(
set(
_retrieve_checkpoint_dirpaths(
temp_dir, metadata_fname=None, metric_name="val_loss"
)
),
{
os.path.join(temp_dir, path) for path in paths[:2]
}, # since last path is a file
)
self.assertEqual(
_retrieve_checkpoint_dirpaths(temp_dir, metadata_fname=".metadata"),
[],
)

# check metadata file is correct filtered for
# by creating metadata for 3rd path in list
with open(os.path.join(temp_dir, paths[1], ".metadata"), "w"):
pass

self.assertEqual(
set(
_retrieve_checkpoint_dirpaths(
temp_dir, metadata_fname=".metadata", metric_name="val_loss"
)
),
{os.path.join(temp_dir, paths[1])},
)

@unittest.skipUnless(
condition=distributed_available, reason="Torch distributed is needed to run"
)
Expand Down Expand Up @@ -235,6 +351,19 @@ def test_get_checkpoint_dirpaths(self) -> None:
[path3, path1, path2], # sorted by epoch and step
)

with tempfile.TemporaryDirectory() as temp_dir:
path1 = os.path.join(temp_dir, "epoch_1_step_20_val_loss=0.01")
path2 = os.path.join(temp_dir, "epoch_4_step_130_val_loss=-0.2")
path3 = os.path.join(temp_dir, "epoch_0_step_10_val_loss=0.12")
os.mkdir(path1)
os.mkdir(path2)
os.mkdir(path3)

self.assertEqual(
get_checkpoint_dirpaths(temp_dir, metric_name="val_loss", mode="min"),
[path2, path1, path3], # sorted by val_loss, ascending
)

with tempfile.TemporaryDirectory() as temp_dir:
self.assertEqual(
get_checkpoint_dirpaths(temp_dir),
Expand Down
85 changes: 78 additions & 7 deletions torchtnt/framework/callbacks/_checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,18 @@
import os
import re

from typing import Any, Callable, cast, Dict, List, Optional, Pattern, Tuple, TypeVar
from typing import (
Any,
Callable,
cast,
Dict,
List,
Literal,
Optional,
Pattern,
Tuple,
TypeVar,
)

import fsspec

Expand Down Expand Up @@ -97,9 +108,9 @@ def _latest_checkpoint_path(
# dirname will be of the format epoch_N_step_M
# where N is the epoch number and M is the step number as integers
split = dirname.split("_")
if len(split) != 4:
if len(split) < 4:
raise AssertionError(
f"Expected exactly 4 elements for pattern of epoch_N_step_M, but received {split})"
f"Expected 4 or more elements for pattern of epoch_N_step_M, but received {split})"
)

epoch_num, step_num = int(split[1]), int(split[3])
Expand All @@ -119,40 +130,97 @@ def _latest_checkpoint_path(
return os.path.join(dirpath, none_throws(largest_subdirectory))


@rank_zero_read_and_broadcast
def get_best_checkpoint_path(
dirpath: str,
metric_name: str,
mode: Literal["min", "max"],
metadata_fname: Optional[str] = None,
process_group: Optional[dist.ProcessGroup] = None,
) -> Optional[str]:
"""
Given a parent directory where checkpoints are saved, return the best checkpoint subdirectory.
Args:
dirpath: parent directory where checkpoints are saved.
metric_name: Name of the metric to use to find the best checkpoint.
mode: Either 'min' or 'max'. If 'min', finds and loads the lowest value metric checkpoint. If 'max', finds and loads the largest.
metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist.
process_group: the process group on which the ranks will communicate on. default: ``None`` (the entire world)
"""

dirpaths = _retrieve_checkpoint_dirpaths(dirpath, metadata_fname, metric_name)
if len(dirpaths) == 0:
# no checkpoints found
return None

best_checkpoint_path = None
best_metric_value = float("inf") if mode == "min" else float("-inf")
for dirpath in dirpaths:
dirname = os.path.basename(dirpath)
metric_value = float(dirname.split("=")[-1])

if mode == "min":
if metric_value < best_metric_value:
best_metric_value = metric_value
best_checkpoint_path = dirpath
else:
if metric_value > best_metric_value:
best_metric_value = metric_value
best_checkpoint_path = dirpath

return best_checkpoint_path


@rank_zero_read_and_broadcast
def get_checkpoint_dirpaths(
dirpath: str,
metadata_fname: Optional[str] = None,
metric_name: Optional[str] = None,
mode: Literal["min", "max"] = "min",
process_group: Optional[dist.ProcessGroup] = None,
) -> List[str]:
"""
Given a parent directory where checkpoints are saved, return the sorted checkpoint subdirectories
from oldest to newest.
Given a parent directory where checkpoints are saved, returns the sorted checkpoint subdirectories
from oldest to newest (if no metric specified), else in the order of their metric values.
Args:
dirpath: parent directory where checkpoints are saved.
metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist.
metric_name: Name of the metric that must exist in checkpoint name.
mode: Either 'min' or 'max'. If 'min', sorts from lowest to highest metric value. If 'max', sorts from highest to lowest.
process_group: the process group on which the ranks will communicate on. default: ``None`` (the entire world)
"""

def sort_fn(path: str) -> Tuple[int, int]:
x = os.path.basename(path)
return (int(x.split("_")[1]), int(x.split("_")[3]))

dirpaths = _retrieve_checkpoint_dirpaths(dirpath, metadata_fname)
return sorted(dirpaths, key=sort_fn)
def sort_metric_fn(path: str) -> float:
x = os.path.basename(path)
metric_val = float(x.split("=")[-1])
return metric_val

dirpaths = _retrieve_checkpoint_dirpaths(dirpath, metadata_fname, metric_name)
return sorted(
dirpaths,
key=sort_fn if not metric_name else sort_metric_fn,
reverse=(mode == "max"), # sort ascending if min is best metric
)


def _retrieve_checkpoint_dirpaths(
dirpath: str,
metadata_fname: Optional[str],
metric_name: Optional[str] = None,
) -> List[str]:
"""
Given a parent directory where checkpoints are saved, return the unsorted checkpoint subdirectories
Args:
dirpath: parent directory where checkpoints are saved.
metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist.
metric_name: Name of the metric that must exist in checkpoint name.
"""

if dirpath[-1] == "/":
Expand All @@ -174,6 +242,9 @@ def _retrieve_checkpoint_dirpaths(

# Define the regex pattern to match the directory names
pattern = rf"^{dirpath}/epoch_\d+_step_\d+"
if metric_name:
# inject metric name in regex search
pattern += rf"_{metric_name}="
snapshot_dirpath_pattern: Pattern[str] = re.compile(pattern)
candidate_dirpaths = list(filter(snapshot_dirpath_pattern.match, contents))

Expand Down

0 comments on commit cec0ed8

Please sign in to comment.