update checkpointing docs to include best checkpoint details (#683)
Summary:
Pull Request resolved: #683

Makes a few touch-ups to the checkpointing docs, and updates them to include details on the best checkpoint feature

Reviewed By: galrotem, williamhufb

Differential Revision: D52891863

fbshipit-source-id: 0e0e0ed112451ccb49c7fd27e56ff56c019b5f61
JKSenthil authored and facebook-github-bot committed Jan 19, 2024
1 parent ee06641 commit 994663d
Showing 1 changed file with 66 additions and 0 deletions.
docs/source/checkpointing.rst
@@ -23,6 +23,9 @@ TorchTNT offers checkpointing via the :class:`~torchtnt.framework.callbacks.TorchSnapshotSaver`
There is built-in support for saving and loading distributed models (DDP, FSDP).

Fully Sharded Data Parallel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The state dict type to be used for checkpointing FSDP modules can be specified in the :class:`~torchtnt.utils.prepare_module.FSDPStrategy`'s state_dict_type argument like so:

.. code-block:: python
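
    # A minimal sketch: assumes FSDPStrategy accepts a
    # torch.distributed.fsdp.StateDictType for its state_dict_type field.
    from torch.distributed.fsdp import StateDictType
    from torchtnt.utils.prepare_module import FSDPStrategy

    strategy = FSDPStrategy(state_dict_type=StateDictType.SHARDED_STATE_DICT)
    # The strategy is then passed wherever the module is wrapped for FSDP
    # (e.g. torchtnt's prepare_fsdp or an AutoUnit's strategy argument).
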
@@ -63,6 +66,10 @@ Or you can manually set this using `FSDP.set_state_dict_type <https://pytorch.or
        callbacks=[tss]
    )
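
As a rough sketch (assuming `module` has already been wrapped with FSDP), the direct call could look like:

.. code-block:: python

    # Hedged sketch: sets the state dict type directly on an FSDP-wrapped module.
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp import StateDictType

    FSDP.set_state_dict_type(module, StateDictType.SHARDED_STATE_DICT)
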
Finetuning
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

When finetuning your models, you can pass RestoreOptions to avoid loading optimizers and learning rate schedulers like so:

.. code-block:: python
@@ -81,3 +88,62 @@ When finetuning your models, you can pass RestoreOptions to avoid loading optimizers and learning rate schedulers like so:
        train_dataloader=dataloader,
        restore_options=RestoreOptions(restore_optimizers=False, restore_lr_schedulers=False)
    )
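
For reference, a rough sketch of what the full restore call might look like; the entry point and import locations here are assumptions rather than confirmed API:

.. code-block:: python

    # Hedged sketch: assumes TorchSnapshotSaver exposes a `restore` entry point
    # accepting the checkpoint path, the unit, and these keyword arguments.
    from torchtnt.framework.callbacks import TorchSnapshotSaver, RestoreOptions  # import locations assumed

    TorchSnapshotSaver.restore(
        your_checkpoint_path_here,  # hypothetical placeholder for a saved checkpoint path
        unit,
        train_dataloader=dataloader,
        restore_options=RestoreOptions(
            restore_optimizers=False,
            restore_lr_schedulers=False,
        ),
    )
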
Best Model by Metric
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Sometimes it is helpful to track how the model is performing and keep the best-performing checkpoints. This can be done via the `BestCheckpointConfig` param:

.. code-block:: python

    module = nn.Linear(input_dim, 1)
    unit = MyUnit(module=module)
    tss = TorchSnapshotSaver(
        dirpath=your_dirpath_here,
        save_every_n_epochs=1,
        best_checkpoint_config=BestCheckpointConfig(
            monitored_metric="train_loss",
            mode="min"
        )
    )
    train(
        unit,
        dataloader,
        callbacks=[tss]
    )

By specifying "train_loss" as the monitored metric, the checkpointer expects the :class:`~torchtnt.framework.unit.TrainUnit` to have a "train_loss" attribute at checkpointing time; it will cast this value to a float and append it to the checkpoint path name. The user is expected to compute this attribute and keep it up to date in the unit, as sketched below.

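For example, a minimal sketch of a unit that maintains such an attribute (the optimizer, loss function, and batch format below are illustrative assumptions):

.. code-block:: python

    # Hedged sketch: a TrainUnit that keeps `train_loss` up to date so the
    # checkpointer can read it when monitored_metric="train_loss".
    import torch
    from torchtnt.framework.unit import TrainUnit

    class MyUnit(TrainUnit):
        def __init__(self, module):
            super().__init__()
            self.module = module
            self.optimizer = torch.optim.SGD(module.parameters(), lr=0.01)
            self.train_loss = 0.0  # cast to float and appended to the checkpoint name at save time

        def train_step(self, state, data):
            inputs, targets = data
            loss = torch.nn.functional.mse_loss(self.module(inputs), targets)
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.train_loss = loss.item()  # keep the monitored metric current
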
Later on, the best checkpoint can be loaded via:

.. code-block:: python

    TorchSnapshotSaver.restore_from_best(your_dirpath_here, unit, metric_name="train_loss", mode="min")

If you'd like to monitor a validation metric (say, validation loss after each eval epoch during :py:func:`~torchtnt.framework.fit.fit`), you can use the `save_every_n_eval_epochs` flag instead, like so:

.. code-block:: python

    tss = TorchSnapshotSaver(
        dirpath=your_dirpath_here,
        save_every_n_eval_epochs=1,
        best_checkpoint_config=BestCheckpointConfig(
            monitored_metric="eval_loss",
            mode="min"
        )
    )

And to save only the top three performing models, you can use the existing `keep_last_n_checkpoints` flag like so:

.. code-block:: python

    tss = TorchSnapshotSaver(
        dirpath=your_dirpath_here,
        save_every_n_eval_epochs=1,
        keep_last_n_checkpoints=3,
        best_checkpoint_config=BestCheckpointConfig(
            monitored_metric="eval_loss",
            mode="min"
        )
    )
