Skip to content

Commit

Permalink
forward 'save_every_n_eval_epochs', 'best_checkpoint_config' to check…
Browse files Browse the repository at this point in the history
…pointers (#680)

Summary:
Pull Request resolved: #680

Forwards the arguments of best checkpointing to DCP and TSS checkpointers
1. Adds `save_every_n_eval_epochs` arg
2. Augments `keep_last_n_checkpoints` docstring to reference its compatibility w/ best checkpoint feature
3. Adds `best_checkpoint_config` arg
4. Adds a note on how the metric value is retrieved for the best checkpoint feature in both checkpointer docstrings

Reviewed By: galrotem

Differential Revision: D52845663

fbshipit-source-id: dfb470c9757b92e0619961d4eb193886228eb12d
  • Loading branch information
JKSenthil authored and facebook-github-bot committed Jan 18, 2024
1 parent 49825e8 commit d236579
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 5 deletions.
2 changes: 1 addition & 1 deletion torchtnt/framework/callbacks/base_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class BaseCheckpointer(Callback, metaclass=abc.ABCMeta):
save_every_n_train_steps: Frequency of steps with which to save checkpoints during the train epoch. If None, no intra-epoch checkpoints are generated.
save_every_n_epochs: Frequency of epochs with which to save checkpoints during training. If None, no end-of-epoch checkpoints are generated.
save_every_n_eval_epochs: Frequency of evaluation epochs with which to save checkpoints during training. Use this if wanting to save checkpoints after every eval epoch during fit.
keep_last_n_checkpoints: Number of most recent checkpoints to keep. If None, all checkpoints are kept. If an excess of existing checkpoints are present, the oldest ones will be deleted to clean the difference.
keep_last_n_checkpoints: Number of most recent checkpoints to keep. If None, all checkpoints are kept. If an excess of existing checkpoints are present, the oldest ones will be deleted to clean the difference. If best checkpoint config is enabled, this param will manage the top n checkpoints instead.
best_checkpoint_config: Configuration for saving the best checkpoint based on a monitored metric. The metric is read off the attribute of the unit prior to checkpoint.
process_group: The process group on which the ranks will communicate on. default: ``None`` (the entire world)
Expand Down
17 changes: 15 additions & 2 deletions torchtnt/framework/callbacks/dcp_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@
)

from torchtnt.framework.callbacks.base_checkpointer import BaseCheckpointer
from torchtnt.framework.callbacks.checkpointer_types import RestoreOptions
from torchtnt.framework.callbacks.checkpointer_types import (
BestCheckpointConfig,
RestoreOptions,
)
from torchtnt.framework.state import State
from torchtnt.framework.unit import AppStateMixin, TTrainData
from torchtnt.framework.utils import get_timing_context
Expand All @@ -45,7 +48,9 @@ class DistributedCheckpointSaver(BaseCheckpointer):
dirpath: Parent directory to save snapshots to.
save_every_n_train_steps: Frequency of steps with which to save snapshots during the train epoch. If None, no intra-epoch snapshots are generated.
save_every_n_epochs: Frequency of epochs with which to save snapshots during training. If None, no end-of-epoch snapshots are generated.
keep_last_n_checkpoints: Number of most recent checkpoints to keep. If None, all checkpoints are kept. If an excess of existing checkpoints are present, the oldest ones will be deleted to clean the difference.
save_every_n_eval_epochs: Frequency of evaluation epochs with which to save checkpoints during training. Use this if wanting to save checkpoints after every eval epoch during fit.
keep_last_n_checkpoints: Number of most recent checkpoints to keep. If None, all checkpoints are kept. If an excess of existing checkpoints are present, the oldest ones will be deleted to clean the difference. If best checkpoint config is enabled, this param will manage the top n checkpoints instead.
best_checkpoint_config: Configuration for saving the best checkpoint based on a monitored metric. The metric is read off the attribute of the unit prior to checkpoint.
process_group: The process group on which the ranks will communicate on. default: ``None`` (the entire world)
Note:
Expand All @@ -54,6 +59,10 @@ class DistributedCheckpointSaver(BaseCheckpointer):
Note:
If checkpointing FSDP model, you can set state_dict type calling `set_state_dict_type <https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.set_state_dict_type>`_ prior to starting training.
Note:
If best_checkpoint_config is enabled, the attribute must be on the unit upon checkpoint time, and must be castable to "float". This value must be maintained by the unit, and updated
appropriately. For example, if logging validation accuracy, the unit must be responsible for maintaining the value and resetting it when the epoch ends.
"""

def __init__(
Expand All @@ -62,14 +71,18 @@ def __init__(
*,
save_every_n_train_steps: Optional[int] = None,
save_every_n_epochs: Optional[int] = None,
save_every_n_eval_epochs: Optional[int] = None,
keep_last_n_checkpoints: Optional[int] = None,
best_checkpoint_config: Optional[BestCheckpointConfig] = None,
process_group: Optional[dist.ProcessGroup] = None,
) -> None:
super().__init__(
dirpath=dirpath,
save_every_n_train_steps=save_every_n_train_steps,
save_every_n_epochs=save_every_n_epochs,
save_every_n_eval_epochs=save_every_n_eval_epochs,
keep_last_n_checkpoints=keep_last_n_checkpoints,
best_checkpoint_config=best_checkpoint_config,
process_group=process_group,
)

Expand Down
18 changes: 16 additions & 2 deletions torchtnt/framework/callbacks/torchsnapshot_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@
)

from torchtnt.framework.callbacks.base_checkpointer import BaseCheckpointer
from torchtnt.framework.callbacks.checkpointer_types import KnobOptions, RestoreOptions
from torchtnt.framework.callbacks.checkpointer_types import (
BestCheckpointConfig,
KnobOptions,
RestoreOptions,
)
from torchtnt.framework.state import State
from torchtnt.framework.unit import (
AppStateMixin,
Expand Down Expand Up @@ -66,7 +70,9 @@ class TorchSnapshotSaver(BaseCheckpointer):
dirpath: Parent directory to save snapshots to.
save_every_n_train_steps: Frequency of steps with which to save snapshots during the train epoch. If None, no intra-epoch snapshots are generated.
save_every_n_epochs: Frequency of epochs with which to save snapshots during training. If None, no end-of-epoch snapshots are generated.
keep_last_n_checkpoints: Number of most recent checkpoints to keep. If None, all checkpoints are kept. If an excess of existing checkpoints are present, the oldest ones will be deleted to clean the difference.
save_every_n_eval_epochs: Frequency of evaluation epochs with which to save checkpoints during training. Use this if wanting to save checkpoints after every eval epoch during fit.
keep_last_n_checkpoints: Number of most recent checkpoints to keep. If None, all checkpoints are kept. If an excess of existing checkpoints are present, the oldest ones will be deleted to clean the difference. If best checkpoint config is enabled, this param will manage the top n checkpoints instead.
best_checkpoint_config: Configuration for saving the best checkpoint based on a monitored metric. The metric is read off the attribute of the unit prior to checkpoint.
process_group: The process group on which the ranks will communicate on. default: ``None`` (the entire world)
async_checkpoint: Whether to perform asynchronous snapshotting. Default: ``True``.
replicated: A glob-pattern of replicated key names that indicate which application state entries have the same state across all processes.
Expand All @@ -85,6 +91,10 @@ class TorchSnapshotSaver(BaseCheckpointer):
Note:
If checkpointing FSDP model, you can set state_dict type calling `set_state_dict_type <https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.set_state_dict_type>`_ prior to starting training.
Note:
If best_checkpoint_config is enabled, the attribute must be on the unit upon checkpoint time, and must be castable to "float". This value must be maintained by the unit, and updated
appropriately. For example, if logging validation accuracy, the unit must be responsible for maintaining the value and resetting it when the epoch ends.
"""

metadata_fname: Optional[str] = ".snapshot_metadata"
Expand All @@ -95,7 +105,9 @@ def __init__(
*,
save_every_n_train_steps: Optional[int] = None,
save_every_n_epochs: Optional[int] = None,
save_every_n_eval_epochs: Optional[int] = None,
keep_last_n_checkpoints: Optional[int] = None,
best_checkpoint_config: Optional[BestCheckpointConfig] = None,
process_group: Optional[dist.ProcessGroup] = None,
async_checkpoint: bool = True,
replicated: Optional[List[str]] = None,
Expand All @@ -107,7 +119,9 @@ def __init__(
dirpath=dirpath,
save_every_n_train_steps=save_every_n_train_steps,
save_every_n_epochs=save_every_n_epochs,
save_every_n_eval_epochs=save_every_n_eval_epochs,
keep_last_n_checkpoints=keep_last_n_checkpoints,
best_checkpoint_config=best_checkpoint_config,
process_group=process_group,
)
self._async_checkpoint = async_checkpoint
Expand Down

0 comments on commit d236579

Please sign in to comment.