From 5d80dec98cfd5063c0097172183fd53b3b057d3f Mon Sep 17 00:00:00 2001
From: Daniel Dale
Date: Fri, 26 Jan 2024 15:08:09 -0800
Subject: [PATCH] inspect user-provided non-finetuning schedule capable
 callback dependency configuration and use it to instantiate analogous FTS
 callback dependencies

---
 docs/source/index.rst                       | 14 ++++--
 src/finetuning_scheduler/fts_supporters.py  | 49 +++++++++++++++++++--
 src/finetuning_scheduler/types.py           |  4 ++
 tests/test_finetuning_scheduler_callback.py | 49 ++++++++++++---------
 4 files changed, 89 insertions(+), 27 deletions(-)

diff --git a/docs/source/index.rst b/docs/source/index.rst
index 8a2124c..c6bff16 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -20,7 +20,7 @@ foundation model experimentation with flexible fine-tuning schedules. Training w
     If you're exploring using the :class:`~finetuning_scheduler.fts.FinetuningScheduler`, this is a great place to
     start!
     You may also find the `notebook-based tutorial `_
-    useful and for those using the :doc:`LightningCLI`, there is a
+    useful and for those using the :external+pl:class:`~lightning.pytorch.cli.LightningCLI`, there is a
     :ref:`CLI-based` example at the bottom of this introduction.
 
 Setup
@@ -69,6 +69,14 @@ and :class:`~finetuning_scheduler.fts_supporters.FTSCheckpoint` callbacks with
 
     trainer = L.Trainer(callbacks=[FinetuningScheduler()])
 
+.. note::
+    If not provided, FTS will instantiate its callback dependencies
+    (:class:`~finetuning_scheduler.fts_supporters.FTSEarlyStopping` and
+    :class:`~finetuning_scheduler.fts_supporters.FTSCheckpoint`) with default configurations and ``monitor=val_loss``.
+    If the user provides base versions of these dependencies (e.g.
+    :external+pl:class:`~lightning.pytorch.callbacks.early_stopping.EarlyStopping`,
+    :external+pl:class:`~lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint`), the configuration of those
+    user-provided callbacks will be used to instantiate their FTS analogs instead.
 
 .. _default schedule:
 
@@ -353,8 +361,8 @@ A demonstration of the scheduled fine-tuning callback
 :class:`~finetuning_scheduler.fts.FinetuningScheduler` using the
 `RTE `_ and
 `BoolQ `_ tasks of the
-`SuperGLUE `_ benchmark and the :doc:`LightningCLI`
-is available under ``./fts_examples/stable``.
+`SuperGLUE `_ benchmark and the
+:external+pl:class:`~lightning.pytorch.cli.LightningCLI` is available under ``./fts_examples/stable``.
 Since this CLI-based example requires a few additional packages (e.g.
 ``transformers``, ``sentencepiece``), you should install them using the ``[examples]`` extra:

diff --git a/src/finetuning_scheduler/fts_supporters.py b/src/finetuning_scheduler/fts_supporters.py
index be250a6..76e94c2 100644
--- a/src/finetuning_scheduler/fts_supporters.py
+++ b/src/finetuning_scheduler/fts_supporters.py
@@ -20,6 +20,7 @@
 import logging
 import os
 import pathlib
+import inspect
 import re
 import warnings
 from abc import ABC, abstractmethod
@@ -30,6 +31,7 @@
 from functools import reduce
 from pprint import pformat
 from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
+from typing_extensions import TypeAlias
 
 import lightning.pytorch as pl
 import torch
@@ -51,7 +53,8 @@
 from torch.nn import Module
 
 from finetuning_scheduler.strategy_adapters.fsdp import FSDPStrategyAdapter, StrategyAdapter
-from finetuning_scheduler.types import FTSLRSchedulerType, FTSLRSchedulerTypeTuple, ParamGroupAddable
+from finetuning_scheduler.types import (FTSLRSchedulerType, FTSLRSchedulerTypeTuple, ParamGroupAddable,
+                                        BaseCallbackDepType)
 
 log = logging.getLogger(__name__)
 
@@ -472,6 +475,8 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         self.best_model_path = state_dict["best_model_path"]
 
 
+FTSCallbackDepType: TypeAlias = Union[Type[FTSEarlyStopping], Type[FTSCheckpoint]]
+
 class UniqueKeyLoader(yaml.SafeLoader):
     """Alters SafeLoader to enable duplicate key detection by the SafeConstructor."""
 
@@ -1822,6 +1827,40 @@ def _reorder_callback_by_type(callbacks: List[Callback], target_callback: type)
         other_callbacks = [c for c in callbacks if not isinstance(c, target_callback)]
         return other_callbacks + target_callbacks
 
+    @staticmethod
+    def _extract_base_callback_cfg(trainer: "pl.Trainer", callback_type: BaseCallbackDepType) -> Dict:
+        """Extracts the configuration of a user-provided base callback dependency.
+
+        Inspects a user-provided :external+pl:class:`~lightning.pytorch.callbacks.early_stopping.EarlyStopping` or
+        :external+pl:class:`~lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint` callback to enable the
+        subsequent instantiation of a fine-tuning schedule-capable FTS analog with a similar configuration.
+
+        Args:
+            trainer (pl.Trainer): The :external+pl:class:`~lightning.pytorch.trainer.trainer.Trainer` object.
+            callback_type (BaseCallbackDepType): The type of base callback from which to extract the configuration.
+
+        Returns:
+            Dict: The extracted user-provided callback configuration.
+        """
+        base_callback = [c for c in trainer.callbacks if isinstance(c, callback_type)][0]
+        base_callback_params = dict(inspect.signature(base_callback.__init__).parameters)
+        return {k: v for k, v in base_callback.__dict__.items() if k in base_callback_params}
+
+    @staticmethod
+    def _add_fts_callback(trainer: "pl.Trainer", fts_cls: FTSCallbackDepType, cfg: Dict) -> None:
+        """Adds a fine-tuning schedule-capable FTS callback dependency with a specified configuration.
+
+        Args:
+            trainer (pl.Trainer): The :external+pl:class:`~lightning.pytorch.trainer.trainer.Trainer` object.
+            fts_cls (FTSCallbackDepType): The type of FTS callback dependency to instantiate.
+            cfg (Dict): The desired FTS callback configuration.
+        """
+        if cfg.get("monitor", None) is None:
+            cfg["monitor"] = "val_loss"
+            rank_zero_warn(f"No monitor metric specified for {fts_cls.__name__},"
+                           " using 'val_loss' as default.")
+        trainer.callbacks.append(fts_cls(**cfg))
+
     def _callback_dep_setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
         """Ensures all :class:`~finetuning_scheduler.fts.FinetuningScheduler` callback dependencies are met, adding
         and configuring them if necessary.
@@ -1858,7 +1897,7 @@ def _configure_callback_deps(self, trainer: "pl.Trainer") -> Tuple[List[Callback
             Bool: Whether a :class:`~finetuning_scheduler.fts_supporters.FTSCheckpoint` callback was added
         """
         has_ckpt_fts, has_ckpt_base, has_es_fts, has_es_base, has_lr_monitor = self._inspect_callback_deps(trainer)
-        added_ckpt_fts, added_es_fts = False, False
+        added_ckpt_fts, added_es_fts, added_ckpt_fts_kwargs, added_es_fts_kwargs = False, False, {}, {}
         if not any([has_es_fts, self.epoch_transitions_only, self.gen_ft_sched_only]):  # type: ignore[attr-defined]
             if has_es_base:
                 rank_zero_warn(
@@ -1866,13 +1905,14 @@ def _configure_callback_deps(self, trainer: "pl.Trainer") -> Tuple[List[Callback
                     "capable EarlyStopping callback such as FTSEarlyStopping. Substituting current "
                     "EarlyStopping for FTSEarlyStopping"
                 )
+                added_es_fts_kwargs = CallbackDepMixin._extract_base_callback_cfg(trainer, EarlyStopping)
                 trainer.callbacks = [c for c in trainer.callbacks if not isinstance(c, EarlyStopping)]
             else:
                 rank_zero_warn(
                     f"{self.__class__.__name__} currently depends upon an FTSEarlyStopping callback unless configured "
                     "in epoch_transitions_only mode. Adding an FTSEarlyStopping callback with default configuration."
                 )
-            trainer.callbacks.append(FTSEarlyStopping(monitor="val_loss"))
+            CallbackDepMixin._add_fts_callback(trainer, FTSEarlyStopping, added_es_fts_kwargs)
             added_es_fts = True
         if (has_es_fts or has_es_base) and self.epoch_transitions_only:  # type: ignore[attr-defined]
             rank_zero_warn(
@@ -1887,8 +1927,9 @@ def _configure_callback_deps(self, trainer: "pl.Trainer") -> Tuple[List[Callback
                     "capable ModelCheckpoint callback such as FTSCheckpoint. Substituting current "
                    "ModelCheckpoint for FTSCheckpoint"
                 )
+                added_ckpt_fts_kwargs = CallbackDepMixin._extract_base_callback_cfg(trainer, ModelCheckpoint)
                 trainer.callbacks = [c for c in trainer.callbacks if not isinstance(c, ModelCheckpoint)]
-            trainer.callbacks.append(FTSCheckpoint(monitor="val_loss", verbose=True))
+            CallbackDepMixin._add_fts_callback(trainer, FTSCheckpoint, added_ckpt_fts_kwargs)
             added_ckpt_fts = True
         for uc in [c for c in trainer.callbacks if any([isinstance(c, d) for d in CALLBACK_DEP_PARENTS.values()])]:
             uc.connect_callback(trainer)  # type: ignore[attr-defined]
diff --git a/src/finetuning_scheduler/types.py b/src/finetuning_scheduler/types.py
index 70ab581..ac135f1 100644
--- a/src/finetuning_scheduler/types.py
+++ b/src/finetuning_scheduler/types.py
@@ -17,9 +17,11 @@
 """
 from typing import Any, Dict, Protocol, runtime_checkable, Type, Union
+from typing_extensions import TypeAlias
 
 import torch
 from lightning.fabric.utilities.types import _TORCH_LRSCHEDULER, Optimizable, ReduceLROnPlateau
+from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
 
 
 @runtime_checkable
@@ -45,3 +47,5 @@ def add_param_group(self, param_group: Dict[Any, Any]) -> None:
 ]
 FTSLRSchedulerTypeTuple = tuple(getattr(torch.optim.lr_scheduler, lr_class) for lr_class in supported_lrs)
 FTSLRSchedulerType = Union[Type[_TORCH_LRSCHEDULER], Type[ReduceLROnPlateau]]
+
+BaseCallbackDepType: TypeAlias = Union[Type[EarlyStopping], Type[ModelCheckpoint]]
diff --git a/tests/test_finetuning_scheduler_callback.py b/tests/test_finetuning_scheduler_callback.py
index 0fa496a..7c50911 100644
--- a/tests/test_finetuning_scheduler_callback.py
+++ b/tests/test_finetuning_scheduler_callback.py
@@ -26,7 +26,8 @@
 from lightning.fabric.utilities import rank_zero_only
 from lightning.fabric.utilities.cloud_io import get_filesystem
 from lightning.pytorch import LightningModule, seed_everything, Trainer
-from lightning.pytorch.callbacks import Callback, EarlyStopping, LearningRateFinder, LearningRateMonitor
+from lightning.pytorch.callbacks import (Callback, EarlyStopping, LearningRateFinder, LearningRateMonitor,
+                                         ModelCheckpoint)
 from lightning.pytorch.strategies import StrategyRegistry
 from lightning.pytorch.strategies.single_device import SingleDeviceStrategy
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
@@ -164,6 +165,7 @@ def __init__(
         weight_decay: float = 1.0e-06,
         init_lr_key: str = None,
         p0_params: Optional[List] = None,
+        monitor_metric: str = None,
     ):
         super().__init__()
         self.layer = nn.Sequential(nn.Linear(32, 32), nn.Linear(32, 32), nn.Linear(32, 32), nn.Linear(32, 2))
@@ -175,6 +177,7 @@ def __init__(
         self.weight_decay = weight_decay
         self.init_lr_key = init_lr_key
         self.p0_params = p0_params
+        self.monitor_metric = monitor_metric or "val_loss"
 
     def training_step(self, batch, batch_idx: int):
         loss = self.step(batch)
@@ -193,7 +196,7 @@ def validation_step(self, batch, batch_idx):
         self.validation_step_outputs.append(loss)
         # we would normally use sync_dist for epoch-only logging in a distributed context but leaving it `False` here
         # to test FTS transition behavior when the test model is used in a distributed context
-        self.log("val_loss", loss, prog_bar=False)
+        self.log(self.monitor_metric, loss, prog_bar=False)
         return {"x": loss}
 
     def on_validation_epoch_end(self):
@@ -1234,12 +1237,10 @@ def test_fts_decay(tmpdir, boring_ft_schedule, explicit_mode: bool, nodecay_mode
     (False, True, "kth", 1): (0, 0, 1),
 }
 EXPECTED_WARNS = [
-    "does not have many workers",
-    "GPU available but",
-    "`max_epochs` was not",
-    "The dirpath has changed from",
-    #"reduce_op is deprecated", # warning caused upstream
-    #"`pydantic.config.Extra` is deprecated",
+    "does not have many workers",  # required for all PyTorch/Lightning versions
+    "GPU available but",  # required for all PyTorch/Lightning versions
+    "`max_epochs` was not",  # required for all PyTorch/Lightning versions
+    "The dirpath has changed from",  # required for all PyTorch/Lightning versions
 ]
 EXPECTED_TRAIN_CHK_WARNS = []
 EXPECTED_DIRPATH = ""
@@ -1294,10 +1295,6 @@ def test_fts_callback_resume(
 
 DYNAMO_EXPECTED_WARNS = [
     "Final phase max_transition_epoch",
-    # using different callbacks for now to avoid creating another fixture with limited utility
-    # "Be aware that when using `ckpt_path`, callbacks used",
-    # "Your compiler for AOTAutograd is returning", # out of initial scope
-    #"tensor cores for float32 matrix multiplication available", # out of initial scope
 ]
 
 
@@ -2036,17 +2033,29 @@ def on_train_epoch_start(self, trainer, pl_module) -> None:
 
 
 @pytest.mark.parametrize(
-    "callbacks, dist_mode, expected",
+    "callbacks, cust_monitor, dist_mode, expected",
     [
-        ([FinetuningScheduler()], None, ("an FTSEarlyStopping", "as FTSCheck")),
-        ([FinetuningScheduler(), FTSEarlyStopping(monitor="val_loss", patience=1)], None, ("FTSCheckpoint. Subs")),
+        ([FinetuningScheduler()], None, None, ("an FTSEarlyStopping", "as FTSCheck")),
+        ([FinetuningScheduler(), FTSEarlyStopping(monitor="val_loss", patience=1)], None, None,
+         ("FTSCheckpoint. Subs",)),
         (
             [FinetuningScheduler(), EarlyStopping(monitor="val_loss", patience=1)],
-            None,
+            None, None,
             ("Stopping. Sub", "Checkpoint. Sub"),
         ),
+        (
+            [FinetuningScheduler(), EarlyStopping(monitor="abc_val_loss", patience=1),
+             ModelCheckpoint(monitor="abc_val_loss", verbose=True)], "abc_val_loss", None,
+            ("Stopping. Sub", "Checkpoint. Sub"),
+        ),
         (
-            [FinetuningScheduler(), FTSCheckpoint(monitor="val_loss", verbose=True)],
+            [FinetuningScheduler(), ModelCheckpoint(verbose=True)], None,
+            None,
+            ("Adding an FTSEarlyStopping", "Checkpoint. Sub", "No monitor metric specified"),
+        ),
+        (
+            [FinetuningScheduler(), FTSCheckpoint(monitor="val_loss", verbose=True)], None,
             None,
             ("Adding an FTSEarlyStopping",),
         ),
@@ -2055,19 +2064,19 @@ def on_train_epoch_start(self, trainer, pl_module) -> None:
                 MockDistFTS(),
                 FTSCheckpoint(monitor="val_loss", verbose=True),
                 FTSEarlyStopping(monitor="val_loss", patience=1),
-            ],
+            ], None,
             "ddp",
             ("not being synchronized",),
         ),
     ],
-    ids=["default", "nondef_es", "def_es", "nondef_ftsckpt", "no_sync"],
+    ids=["default", "nondef_es", "def_es", "extract_base_callback_cfg", "missing_monitor", "nondef_ftsckpt", "no_sync"],
 )
 def test_fts_callback_warns(
-    tmpdir, recwarn, callbacks: List[Callback], dist_mode: str, expected: Tuple[str]
+    tmpdir, recwarn, callbacks: List[Callback], cust_monitor: Optional[str], dist_mode: str, expected: Tuple[str]
 ):
     """Validate :class:`~finetuning_scheduler.FinetuningScheduler` warnings that require a
     :class:`~pytorch_lighting.trainer.Trainer` to be defined are properly issued"""
-    model = FinetuningSchedulerBoringModel()
+    model = FinetuningSchedulerBoringModel(monitor_metric=cust_monitor)
     dist_args = {"strategy": dist_mode, "accelerator": "cpu", "devices": "auto"} if dist_mode else {"devices": 1}
     trainer = Trainer(default_root_dir=tmpdir, callbacks=callbacks, **dist_args)
     trainer.fit(model)
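
A minimal usage sketch of the behavior introduced by this patch (it assumes the patch is applied; the
``my_val_metric`` metric name, ``patience`` and ``save_top_k`` values are placeholders chosen only for
illustration)::

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
    from finetuning_scheduler import FinetuningScheduler

    # Provide base (non-FTS) callback dependencies with a custom configuration alongside FinetuningScheduler.
    trainer = Trainer(
        callbacks=[
            FinetuningScheduler(),
            EarlyStopping(monitor="my_val_metric", patience=3),
            ModelCheckpoint(monitor="my_val_metric", save_top_k=2),
        ]
    )

    # During trainer.fit(...) setup, FTS warns that it is substituting the base callbacks, extracts their
    # user-provided constructor configuration (monitor, patience, save_top_k, ...) via
    # _extract_base_callback_cfg and instantiates FTSEarlyStopping/FTSCheckpoint analogs with that
    # configuration via _add_fts_callback. If a provided base callback sets no monitor, "val_loss" is used
    # and a "No monitor metric specified" warning is emitted.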