forward 'save_every_n_eval_epochs', 'best_checkpoint_config' to checkpointers #680

Closed · wants to merge 1 commit
2 changes: 1 addition & 1 deletion torchtnt/framework/callbacks/base_checkpointer.py
@@ -55,7 +55,7 @@ class BaseCheckpointer(Callback, metaclass=abc.ABCMeta):
         save_every_n_train_steps: Frequency of steps with which to save checkpoints during the train epoch. If None, no intra-epoch checkpoints are generated.
         save_every_n_epochs: Frequency of epochs with which to save checkpoints during training. If None, no end-of-epoch checkpoints are generated.
         save_every_n_eval_epochs: Frequency of evaluation epochs with which to save checkpoints during training. Use this if you want to save checkpoints after every eval epoch during fit.
-        keep_last_n_checkpoints: Number of most recent checkpoints to keep. If None, all checkpoints are kept. If an excess of existing checkpoints is present, the oldest ones will be deleted to clean the difference.
+        keep_last_n_checkpoints: Number of most recent checkpoints to keep. If None, all checkpoints are kept. If an excess of existing checkpoints is present, the oldest ones will be deleted to clean the difference. If a best checkpoint config is enabled, this param will manage the top n checkpoints instead.
         best_checkpoint_config: Configuration for saving the best checkpoint based on a monitored metric. The metric is read from the attribute on the unit prior to checkpointing.
         process_group: The process group on which the ranks will communicate. Default: ``None`` (the entire world)
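
Taken together, the two forwarded arguments change how pruning works. A minimal sketch of the intended configuration (the metric name is illustrative, and this assumes BestCheckpointConfig in checkpointer_types.py takes a monitored metric name plus a min/max mode):

from torchtnt.framework.callbacks.checkpointer_types import BestCheckpointConfig

# Rank checkpoints by the unit's `eval_accuracy` attribute (higher is better).
# When a best checkpoint config is set, keep_last_n_checkpoints keeps the
# top-n checkpoints by this metric rather than the n most recent ones.
best_config = BestCheckpointConfig(
    monitored_metric="eval_accuracy",  # attribute read off the unit at checkpoint time
    mode="max",                        # use "min" for metrics like loss
)
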
17 changes: 15 additions & 2 deletions torchtnt/framework/callbacks/dcp_saver.py
@@ -19,7 +19,10 @@
 )

 from torchtnt.framework.callbacks.base_checkpointer import BaseCheckpointer
-from torchtnt.framework.callbacks.checkpointer_types import RestoreOptions
+from torchtnt.framework.callbacks.checkpointer_types import (
+    BestCheckpointConfig,
+    RestoreOptions,
+)
 from torchtnt.framework.state import State
 from torchtnt.framework.unit import AppStateMixin, TTrainData
 from torchtnt.framework.utils import get_timing_context
@@ -45,7 +48,9 @@ class DistributedCheckpointSaver(BaseCheckpointer):
         dirpath: Parent directory to save snapshots to.
         save_every_n_train_steps: Frequency of steps with which to save snapshots during the train epoch. If None, no intra-epoch snapshots are generated.
         save_every_n_epochs: Frequency of epochs with which to save snapshots during training. If None, no end-of-epoch snapshots are generated.
-        keep_last_n_checkpoints: Number of most recent checkpoints to keep. If None, all checkpoints are kept. If an excess of existing checkpoints is present, the oldest ones will be deleted to clean the difference.
+        save_every_n_eval_epochs: Frequency of evaluation epochs with which to save checkpoints during training. Use this if you want to save checkpoints after every eval epoch during fit.
+        keep_last_n_checkpoints: Number of most recent checkpoints to keep. If None, all checkpoints are kept. If an excess of existing checkpoints is present, the oldest ones will be deleted to clean the difference. If a best checkpoint config is enabled, this param will manage the top n checkpoints instead.
+        best_checkpoint_config: Configuration for saving the best checkpoint based on a monitored metric. The metric is read from the attribute on the unit prior to checkpointing.
         process_group: The process group on which the ranks will communicate. Default: ``None`` (the entire world)

     Note:
@@ -54,6 +59,10 @@ class DistributedCheckpointSaver(BaseCheckpointer):

     Note:
         If checkpointing an FSDP model, you can set the state_dict type by calling `set_state_dict_type <https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.set_state_dict_type>`_ prior to starting training.
+
+    Note:
+        If best_checkpoint_config is enabled, the attribute must be on the unit at checkpoint time and must be castable to "float". This value must be maintained by the unit and updated
+        appropriately. For example, if logging validation accuracy, the unit is responsible for maintaining the value and resetting it when the epoch ends.
     """

@@ -62,14 +71,18 @@ def __init__(
         *,
         save_every_n_train_steps: Optional[int] = None,
         save_every_n_epochs: Optional[int] = None,
+        save_every_n_eval_epochs: Optional[int] = None,
         keep_last_n_checkpoints: Optional[int] = None,
+        best_checkpoint_config: Optional[BestCheckpointConfig] = None,
         process_group: Optional[dist.ProcessGroup] = None,
     ) -> None:
         super().__init__(
             dirpath=dirpath,
             save_every_n_train_steps=save_every_n_train_steps,
             save_every_n_epochs=save_every_n_epochs,
+            save_every_n_eval_epochs=save_every_n_eval_epochs,
             keep_last_n_checkpoints=keep_last_n_checkpoints,
+            best_checkpoint_config=best_checkpoint_config,
             process_group=process_group,
         )

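
With the arguments now forwarded, a DistributedCheckpointSaver can be configured directly. A sketch under the assumptions above (the dirpath and the eval_accuracy attribute are placeholders, and the unit must keep that attribute up to date as the new Note describes):

from torchtnt.framework.callbacks.checkpointer_types import BestCheckpointConfig
from torchtnt.framework.callbacks.dcp_saver import DistributedCheckpointSaver

dcp_saver = DistributedCheckpointSaver(
    dirpath="/tmp/checkpoints",            # placeholder directory
    save_every_n_eval_epochs=1,            # newly forwarded: save after each eval epoch during fit
    keep_last_n_checkpoints=2,             # with best_checkpoint_config set, keeps the top-2 by metric
    best_checkpoint_config=BestCheckpointConfig(
        monitored_metric="eval_accuracy",  # float-castable attribute the unit maintains
        mode="max",
    ),
)
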
18 changes: 16 additions & 2 deletions torchtnt/framework/callbacks/torchsnapshot_saver.py
@@ -20,7 +20,11 @@
 )

 from torchtnt.framework.callbacks.base_checkpointer import BaseCheckpointer
-from torchtnt.framework.callbacks.checkpointer_types import KnobOptions, RestoreOptions
+from torchtnt.framework.callbacks.checkpointer_types import (
+    BestCheckpointConfig,
+    KnobOptions,
+    RestoreOptions,
+)
 from torchtnt.framework.state import State
 from torchtnt.framework.unit import (
     AppStateMixin,
@@ -66,7 +70,9 @@ class TorchSnapshotSaver(BaseCheckpointer):
         dirpath: Parent directory to save snapshots to.
         save_every_n_train_steps: Frequency of steps with which to save snapshots during the train epoch. If None, no intra-epoch snapshots are generated.
         save_every_n_epochs: Frequency of epochs with which to save snapshots during training. If None, no end-of-epoch snapshots are generated.
-        keep_last_n_checkpoints: Number of most recent checkpoints to keep. If None, all checkpoints are kept. If an excess of existing checkpoints is present, the oldest ones will be deleted to clean the difference.
+        save_every_n_eval_epochs: Frequency of evaluation epochs with which to save checkpoints during training. Use this if you want to save checkpoints after every eval epoch during fit.
+        keep_last_n_checkpoints: Number of most recent checkpoints to keep. If None, all checkpoints are kept. If an excess of existing checkpoints is present, the oldest ones will be deleted to clean the difference. If a best checkpoint config is enabled, this param will manage the top n checkpoints instead.
+        best_checkpoint_config: Configuration for saving the best checkpoint based on a monitored metric. The metric is read from the attribute on the unit prior to checkpointing.
         process_group: The process group on which the ranks will communicate. Default: ``None`` (the entire world)
         async_checkpoint: Whether to perform asynchronous snapshotting. Default: ``True``.
         replicated: A glob-pattern of replicated key names that indicate which application state entries have the same state across all processes.
@@ -85,6 +91,10 @@ class TorchSnapshotSaver(BaseCheckpointer):

     Note:
         If checkpointing an FSDP model, you can set the state_dict type by calling `set_state_dict_type <https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.set_state_dict_type>`_ prior to starting training.
+
+    Note:
+        If best_checkpoint_config is enabled, the attribute must be on the unit at checkpoint time and must be castable to "float". This value must be maintained by the unit and updated
+        appropriately. For example, if logging validation accuracy, the unit is responsible for maintaining the value and resetting it when the epoch ends.
     """

metadata_fname: Optional[str] = ".snapshot_metadata"
@@ -95,7 +105,9 @@ def __init__(
         *,
         save_every_n_train_steps: Optional[int] = None,
         save_every_n_epochs: Optional[int] = None,
+        save_every_n_eval_epochs: Optional[int] = None,
         keep_last_n_checkpoints: Optional[int] = None,
+        best_checkpoint_config: Optional[BestCheckpointConfig] = None,
        process_group: Optional[dist.ProcessGroup] = None,
         async_checkpoint: bool = True,
         replicated: Optional[List[str]] = None,
@@ -107,7 +119,9 @@
             dirpath=dirpath,
             save_every_n_train_steps=save_every_n_train_steps,
             save_every_n_epochs=save_every_n_epochs,
+            save_every_n_eval_epochs=save_every_n_eval_epochs,
             keep_last_n_checkpoints=keep_last_n_checkpoints,
+            best_checkpoint_config=best_checkpoint_config,
             process_group=process_group,
         )
         self._async_checkpoint = async_checkpoint
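
The TorchSnapshotSaver path is symmetric. An end-to-end sketch, assuming the fit() entrypoint that accepts a unit, both dataloaders, and a callbacks list (unit and dataloaders are placeholders defined elsewhere, and monitoring eval_loss with mode="min" is just an example):

from torchtnt.framework import fit
from torchtnt.framework.callbacks.checkpointer_types import BestCheckpointConfig
from torchtnt.framework.callbacks.torchsnapshot_saver import TorchSnapshotSaver

snapshot_saver = TorchSnapshotSaver(
    dirpath="/tmp/snapshots",
    save_every_n_eval_epochs=1,
    keep_last_n_checkpoints=2,
    best_checkpoint_config=BestCheckpointConfig(monitored_metric="eval_loss", mode="min"),
    async_checkpoint=False,  # synchronous snapshots keep the sketch simple
)

# Per the Note above, the unit must maintain a float-castable `eval_loss`
# attribute and reset it appropriately at epoch boundaries.
fit(unit, train_dataloader, eval_dataloader, callbacks=[snapshot_saver])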