
inspect user-provided non-finetuning schedule capable callback dependency configuration and use it to instantiate analogous FTS callback dependencies
speediedan committed Jan 26, 2024
1 parent d3003d0 commit 5d80dec
Showing 4 changed files with 89 additions and 27 deletions.
14 changes: 11 additions & 3 deletions docs/source/index.rst
@@ -20,7 +20,7 @@ foundation model experimentation with flexible fine-tuning schedules. Training w
If you're exploring using the :class:`~finetuning_scheduler.fts.FinetuningScheduler`, this is a great place
to start!
You may also find the `notebook-based tutorial <https://pytorch-lightning.readthedocs.io/en/stable/notebooks/lightning_examples/finetuning-scheduler.html>`_
useful and for those using the :doc:`LightningCLI<cli/lightning_cli>`, there is a
useful and for those using the :external+pl:class:`~lightning.pytorch.cli.LightningCLI`, there is a
:ref:`CLI-based<scheduled-fine-tuning-superglue>` example at the bottom of this introduction.

Setup
@@ -69,6 +69,14 @@ and :class:`~finetuning_scheduler.fts_supporters.FTSCheckpoint` callbacks with
trainer = L.Trainer(callbacks=[FinetuningScheduler()])
.. note::
If not provided, FTS will instantiate its callback dependencies
(:class:`~finetuning_scheduler.fts_supporters.FTSEarlyStopping` and
:class:`~finetuning_scheduler.fts_supporters.FTSCheckpoint`) with default configurations and ``monitor='val_loss'``.
If the user provides base versions of these dependencies (e.g.
:external+pl:class:`~lightning.pytorch.callbacks.early_stopping.EarlyStopping`,
:external+pl:class:`~lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint`), the provided configuration of
those callbacks will be used to instantiate their FTS analogs instead.
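
For illustration, a minimal sketch (assuming your ``LightningModule`` logs a ``val_loss`` metric) of
providing a base ``EarlyStopping`` callback whose configuration FTS will reuse when substituting its
``FTSEarlyStopping`` analog:

.. code-block:: python

    import lightning as L
    from lightning.pytorch.callbacks import EarlyStopping
    from finetuning_scheduler import FinetuningScheduler

    # FTS substitutes the base EarlyStopping below with an FTSEarlyStopping instantiated
    # using the same configuration (monitor="val_loss", patience=3)
    trainer = L.Trainer(callbacks=[FinetuningScheduler(), EarlyStopping(monitor="val_loss", patience=3)])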

.. _default schedule:

@@ -353,8 +361,8 @@ A demonstration of the scheduled fine-tuning callback
:class:`~finetuning_scheduler.fts.FinetuningScheduler` using the
`RTE <https://huggingface.co/datasets/viewer/?dataset=super_glue&config=rte>`_ and
`BoolQ <https://github.com/google-research-datasets/boolean-questions>`_ tasks of the
`SuperGLUE <https://paperswithcode.com/dataset/superglue>`_ benchmark and the :doc:`LightningCLI<cli/lightning_cli>`
is available under ``./fts_examples/stable``.
`SuperGLUE <https://paperswithcode.com/dataset/superglue>`_ benchmark and the
:external+pl:class:`~lightning.pytorch.cli.LightningCLI` is available under ``./fts_examples/stable``.

Since this CLI-based example requires a few additional packages (e.g. ``transformers``, ``sentencepiece``), you
should install them using the ``[examples]`` extra:
49 changes: 45 additions & 4 deletions src/finetuning_scheduler/fts_supporters.py
@@ -20,6 +20,7 @@
import logging
import os
import pathlib
import inspect
import re
import warnings
from abc import ABC, abstractmethod
@@ -30,6 +31,7 @@
from functools import reduce
from pprint import pformat
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from typing_extensions import TypeAlias

import lightning.pytorch as pl
import torch
@@ -51,7 +53,8 @@
from torch.nn import Module

from finetuning_scheduler.strategy_adapters.fsdp import FSDPStrategyAdapter, StrategyAdapter
from finetuning_scheduler.types import FTSLRSchedulerType, FTSLRSchedulerTypeTuple, ParamGroupAddable
from finetuning_scheduler.types import (FTSLRSchedulerType, FTSLRSchedulerTypeTuple, ParamGroupAddable,
BaseCallbackDepType)

log = logging.getLogger(__name__)

@@ -472,6 +475,8 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
self.best_model_path = state_dict["best_model_path"]


FTSCallbackDepType: TypeAlias = Union[Type[FTSEarlyStopping], Type[FTSCheckpoint]]

class UniqueKeyLoader(yaml.SafeLoader):
"""Alters SafeLoader to enable duplicate key detection by the SafeConstructor."""

@@ -1822,6 +1827,40 @@ def _reorder_callback_by_type(callbacks: List[Callback], target_callback: type)
other_callbacks = [c for c in callbacks if not isinstance(c, target_callback)]
return other_callbacks + target_callbacks

@staticmethod
def _extract_base_callback_cfg(trainer: "pl.Trainer", callback_type: BaseCallbackDepType) -> Dict:
"""Extracts the configuration of a user-provided.
:external+pl:class:`~lightning.pytorch.callbacks.early_stopping.EarlyStopping` or
:external+pl:class:`~lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint` callback to enable the
subsequent instantiation of a fine-tuning schedule-capable FTS analog with a similar configuration.
Args:
trainer (pl.Trainer): The :external+pl:class:`~lightning.pytorch.trainer.trainer.Trainer` object.
callback_type (BaseCallbackDepType): The type of base callback from which to extract the configuration.
Returns:
Dict: The extracted user-provided callback configuration.
"""
base_callback = [c for c in trainer.callbacks if isinstance(c, callback_type)][0]
base_callback_params = dict(inspect.signature(base_callback.__init__).parameters)
return {k: v for k, v in base_callback.__dict__.items() if k in base_callback_params}
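
# Illustrative sketch (hypothetical usage, not part of this commit): assuming the user passed
# EarlyStopping(monitor="val_loss", patience=3) to the Trainer, only the instance attributes that
# correspond to EarlyStopping.__init__ parameters are retained, e.g.:
#     cfg = CallbackDepMixin._extract_base_callback_cfg(trainer, EarlyStopping)
#     # cfg includes keys such as "monitor", "patience", "min_delta" and "verbose"
#     trainer.callbacks.append(FTSEarlyStopping(**cfg))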

@staticmethod
def _add_fts_callback(trainer: "pl.Trainer", fts_cls: FTSCallbackDepType, cfg: Dict) -> None:
"""Adds a fine-tuning schedule-capable FTS callback dependency with a specified configuration.
Args:
trainer (pl.Trainer): The :external+pl:class:`~lightning.pytorch.trainer.trainer.Trainer` object.
fts_cls (FTSCallbackDepType): The type of FTS callback dependency to instantiate.
cfg (Dict): The desired FTS callback configuration.
"""
if cfg.get("monitor", None) is None:
cfg["monitor"] = "val_loss"
rank_zero_warn(f"No monitor metric specified for {fts_cls.__class__.__name__},"
" using 'val_loss' as default.")
trainer.callbacks.append(fts_cls(**cfg))
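
# Illustrative sketch (hypothetical usage, not part of this commit): if the extracted configuration
# lacks a monitor key (e.g. the user provided ModelCheckpoint(verbose=True)), the FTS analog is added
# with the "val_loss" default and a warning is emitted:
#     CallbackDepMixin._add_fts_callback(trainer, FTSCheckpoint, {"verbose": True})
#     # warns: "No monitor metric specified for FTSCheckpoint, using 'val_loss' as default."
#     # trainer.callbacks now includes an FTSCheckpoint(monitor="val_loss", verbose=True)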

def _callback_dep_setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
"""Ensures all :class:`~finetuning_scheduler.fts.FinetuningScheduler` callback dependencies are met, adding
and configuring them if necessary.
@@ -1858,21 +1897,22 @@ def _configure_callback_deps(self, trainer: "pl.Trainer") -> Tuple[List[Callback
Bool: Whether a :class:`~finetuning_scheduler.fts_supporters.FTSCheckpoint` callback was added
"""
has_ckpt_fts, has_ckpt_base, has_es_fts, has_es_base, has_lr_monitor = self._inspect_callback_deps(trainer)
added_ckpt_fts, added_es_fts = False, False
added_ckpt_fts, added_es_fts, added_ckpt_fts_kwargs, added_es_fts_kwargs = False, False, {}, {}
if not any([has_es_fts, self.epoch_transitions_only, self.gen_ft_sched_only]): # type: ignore[attr-defined]
if has_es_base:
rank_zero_warn(
f"{self.__class__.__name__} currently depends upon a fine-tuning schedule "
"capable EarlyStopping callback such as FTSEarlyStopping. Substituting current "
"EarlyStopping for FTSEarlyStopping"
)
added_es_fts_kwargs = CallbackDepMixin._extract_base_callback_cfg(trainer, EarlyStopping)
trainer.callbacks = [c for c in trainer.callbacks if not isinstance(c, EarlyStopping)]
else:
rank_zero_warn(
f"{self.__class__.__name__} currently depends upon an FTSEarlyStopping callback unless configured "
"in epoch_transitions_only mode. Adding an FTSEarlyStopping callback with default configuration."
)
trainer.callbacks.append(FTSEarlyStopping(monitor="val_loss"))
CallbackDepMixin._add_fts_callback(trainer, FTSEarlyStopping, added_es_fts_kwargs)
added_es_fts = True
if (has_es_fts or has_es_base) and self.epoch_transitions_only: # type: ignore[attr-defined]
rank_zero_warn(
@@ -1887,8 +1927,9 @@
"capable ModelCheckpoint callback such as FTSCheckpoint. Substituting current "
"ModelCheckpoint for FTSCheckpoint"
)
added_ckpt_fts_kwargs = CallbackDepMixin._extract_base_callback_cfg(trainer, ModelCheckpoint)
trainer.callbacks = [c for c in trainer.callbacks if not isinstance(c, ModelCheckpoint)]
trainer.callbacks.append(FTSCheckpoint(monitor="val_loss", verbose=True))
CallbackDepMixin._add_fts_callback(trainer, FTSCheckpoint, added_ckpt_fts_kwargs)
added_ckpt_fts = True
for uc in [c for c in trainer.callbacks if any([isinstance(c, d) for d in CALLBACK_DEP_PARENTS.values()])]:
uc.connect_callback(trainer) # type: ignore[attr-defined]
4 changes: 4 additions & 0 deletions src/finetuning_scheduler/types.py
@@ -17,9 +17,11 @@
"""
from typing import Any, Dict, Protocol, runtime_checkable, Type, Union
from typing_extensions import TypeAlias

import torch
from lightning.fabric.utilities.types import _TORCH_LRSCHEDULER, Optimizable, ReduceLROnPlateau
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint


@runtime_checkable
@@ -45,3 +47,5 @@ def add_param_group(self, param_group: Dict[Any, Any]) -> None:
]
FTSLRSchedulerTypeTuple = tuple(getattr(torch.optim.lr_scheduler, lr_class) for lr_class in supported_lrs)
FTSLRSchedulerType = Union[Type[_TORCH_LRSCHEDULER], Type[ReduceLROnPlateau]]

BaseCallbackDepType: TypeAlias = Union[Type[EarlyStopping], Type[ModelCheckpoint]]
49 changes: 29 additions & 20 deletions tests/test_finetuning_scheduler_callback.py
@@ -26,7 +26,8 @@
from lightning.fabric.utilities import rank_zero_only
from lightning.fabric.utilities.cloud_io import get_filesystem
from lightning.pytorch import LightningModule, seed_everything, Trainer
from lightning.pytorch.callbacks import Callback, EarlyStopping, LearningRateFinder, LearningRateMonitor
from lightning.pytorch.callbacks import (Callback, EarlyStopping, LearningRateFinder, LearningRateMonitor,
ModelCheckpoint)
from lightning.pytorch.strategies import StrategyRegistry
from lightning.pytorch.strategies.single_device import SingleDeviceStrategy
from lightning.pytorch.utilities.exceptions import MisconfigurationException
@@ -164,6 +165,7 @@ def __init__(
weight_decay: float = 1.0e-06,
init_lr_key: str = None,
p0_params: Optional[List] = None,
monitor_metric: str = None,
):
super().__init__()
self.layer = nn.Sequential(nn.Linear(32, 32), nn.Linear(32, 32), nn.Linear(32, 32), nn.Linear(32, 2))
@@ -175,6 +177,7 @@ def __init__(
self.weight_decay = weight_decay
self.init_lr_key = init_lr_key
self.p0_params = p0_params
self.monitor_metric = monitor_metric or "val_loss"

def training_step(self, batch, batch_idx: int):
loss = self.step(batch)
@@ -193,7 +196,7 @@ def validation_step(self, batch, batch_idx):
self.validation_step_outputs.append(loss)
# we would normally use sync_dist for epoch-only logging in a distributed context but leaving it `False` here
# to test FTS transition behavior when the test model is used in a distributed context
self.log("val_loss", loss, prog_bar=False)
self.log(self.monitor_metric, loss, prog_bar=False)
return {"x": loss}

def on_validation_epoch_end(self):
@@ -1234,12 +1237,10 @@ def test_fts_decay(tmpdir, boring_ft_schedule, explicit_mode: bool, nodecay_mode
(False, True, "kth", 1): (0, 0, 1),
}
EXPECTED_WARNS = [
"does not have many workers",
"GPU available but",
"`max_epochs` was not",
"The dirpath has changed from",
#"reduce_op is deprecated", # warning caused upstream
#"`pydantic.config.Extra` is deprecated",
"does not have many workers", # required for all PyTorch/Lightning versions
"GPU available but", # required for all PyTorch/Lightning versions
"`max_epochs` was not", # required for all PyTorch/Lightning versions
"The dirpath has changed from", # required for all PyTorch/Lightning versions
]
EXPECTED_TRAIN_CHK_WARNS = []
EXPECTED_DIRPATH = ""
@@ -1294,10 +1295,6 @@ def test_fts_callback_resume(

DYNAMO_EXPECTED_WARNS = [
"Final phase max_transition_epoch",
# using different callbacks for now to avoid creating another fixture with limited utility
# "Be aware that when using `ckpt_path`, callbacks used",
# "Your compiler for AOTAutograd is returning", # out of initial scope
#"tensor cores for float32 matrix multiplication available", # out of initial scope
]


@@ -2036,17 +2033,29 @@ def on_train_epoch_start(self, trainer, pl_module) -> None:


@pytest.mark.parametrize(
"callbacks, dist_mode, expected",
"callbacks, cust_monitor, dist_mode, expected",
[
([FinetuningScheduler()], None, ("an FTSEarlyStopping", "as FTSCheck")),
([FinetuningScheduler(), FTSEarlyStopping(monitor="val_loss", patience=1)], None, ("FTSCheckpoint. Subs")),
([FinetuningScheduler()], None, None, ("an FTSEarlyStopping", "as FTSCheck")),
([FinetuningScheduler(), FTSEarlyStopping(monitor="val_loss", patience=1)], None, None,
("FTSCheckpoint. Subs")),
(
[FinetuningScheduler(), EarlyStopping(monitor="val_loss", patience=1)],
None, None,
("Stopping. Sub", "Checkpoint. Sub"),
),
(
[FinetuningScheduler(), EarlyStopping(monitor="abc_val_loss", patience=1),
ModelCheckpoint(monitor="abc_val_loss", verbose=True)], 'abc_val_loss',
None,
("Stopping. Sub", "Checkpoint. Sub"),
),
(
[FinetuningScheduler(), FTSCheckpoint(monitor="val_loss", verbose=True)],
[FinetuningScheduler(), ModelCheckpoint(verbose=True)], None,
None,
("Adding an FTSEarlyStopping", "Checkpoint. Sub", "No monitor metric specified"),
),
(
[FinetuningScheduler(), FTSCheckpoint(monitor="val_loss", verbose=True)], None,
None,
("Adding an FTSEarlyStopping",),
),
@@ -2055,19 +2064,19 @@ def on_train_epoch_start(self, trainer, pl_module) -> None:
MockDistFTS(),
FTSCheckpoint(monitor="val_loss", verbose=True),
FTSEarlyStopping(monitor="val_loss", patience=1),
],
], None,
"ddp",
("not being synchronized",),
),
],
ids=["default", "nondef_es", "def_es", "nondef_ftsckpt", "no_sync"],
ids=["default", "nondef_es", "def_es", "extract_base_callback_cfg", "missing_monitor"," nondef_ftsckpt", "no_sync"],
)
def test_fts_callback_warns(
tmpdir, recwarn, callbacks: List[Callback], dist_mode: str, expected: Tuple[str]
tmpdir, recwarn, callbacks: List[Callback], cust_monitor: Optional[str], dist_mode: str, expected: Tuple[str]
):
"""Validate :class:`~finetuning_scheduler.FinetuningScheduler` warnings that require a
:class:`~pytorch_lighting.trainer.Trainer` to be defined are properly issued"""
model = FinetuningSchedulerBoringModel()
model = FinetuningSchedulerBoringModel(monitor_metric=cust_monitor)
dist_args = {"strategy": dist_mode, "accelerator": "cpu", "devices": "auto"} if dist_mode else {"devices": 1}
trainer = Trainer(default_root_dir=tmpdir, callbacks=callbacks, **dist_args)
trainer.fit(model)
