From 6b1da2de678623d669cf3dafbefbdb50d1b13972 Mon Sep 17 00:00:00 2001 From: Daniel Dale Date: Fri, 22 Nov 2024 16:04:40 -0800 Subject: [PATCH] - bumped lightning dev sha - removed `_maybe_sync_loops` after https://github.com/Lightning-AI/pytorch-lightning/pull/20379 obviated the need for it --- requirements/base.txt | 2 +- requirements/standalone_base.txt | 2 +- setup.py | 2 +- src/finetuning_scheduler/fts.py | 38 +++++++---------- src/finetuning_scheduler/fts_supporters.py | 46 +++++++-------------- tests/test_finetuning_scheduler_callback.py | 11 +++-- 6 files changed, 40 insertions(+), 61 deletions(-) diff --git a/requirements/base.txt b/requirements/base.txt index 4a61ad3..c5cb302 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -1,4 +1,4 @@ #lightning>=2.5.0,<2.5.1 # the below is uncommented when master is targeting a specific pl dev master commit -git+https://github.com/Lightning-AI/lightning.git@06a8d5bf33faf0a4f9a24207ae77b439354350af#egg=lightning +git+https://github.com/Lightning-AI/lightning.git@8ce52876ad6e5eb05e0965f72e034ffe46b327ba#egg=lightning torch>=2.2.0 diff --git a/requirements/standalone_base.txt b/requirements/standalone_base.txt index 7c08a3e..e5afeb8 100644 --- a/requirements/standalone_base.txt +++ b/requirements/standalone_base.txt @@ -1,4 +1,4 @@ #pytorch-lightning>=2.5.0,<2.5.1 # the below is uncommented when master is targeting a specific pl dev master commit -git+https://github.com/Lightning-AI/pytorch-lightning.git@06a8d5bf33faf0a4f9a24207ae77b439354350af#egg=pytorch-lightning +git+https://github.com/Lightning-AI/pytorch-lightning.git@8ce52876ad6e5eb05e0965f72e034ffe46b327ba#egg=pytorch-lightning torch>=2.2.0 diff --git a/setup.py b/setup.py index 3d8b70c..79f4f13 100755 --- a/setup.py +++ b/setup.py @@ -138,7 +138,7 @@ def _setup_args(standalone: bool = False) -> Dict[str, Any]: _INSTALL_PATHS["require"], file_name=base_reqs, standalone=standalone, - pl_commit="06a8d5bf33faf0a4f9a24207ae77b439354350af", + pl_commit="8ce52876ad6e5eb05e0965f72e034ffe46b327ba", ) base_setup["install_requires"] = install_requires return base_setup diff --git a/src/finetuning_scheduler/fts.py b/src/finetuning_scheduler/fts.py index b93ef1f..6289e75 100644 --- a/src/finetuning_scheduler/fts.py +++ b/src/finetuning_scheduler/fts.py @@ -48,10 +48,9 @@ class FinetuningScheduler(ScheduleImplMixin, ScheduleParsingMixin, CallbackDepMixin, BaseFinetuning): - r""" - This callback enables flexible, multi-phase, scheduled fine-tuning of foundation models. Gradual - unfreezing/thawing can help maximize foundation model knowledge retention while allowing (typically upper layers - of) the model to optimally adapt to new tasks during transfer learning. + r"""This callback enables flexible, multi-phase, scheduled fine-tuning of foundation models. Gradual + unfreezing/thawing can help maximize foundation model knowledge retention while allowing (typically upper + layers of) the model to optimally adapt to new tasks during transfer learning. :class:`~finetuning_scheduler.fts.FinetuningScheduler` orchestrates the gradual unfreezing of models via a fine-tuning schedule that is either implicitly generated (the default) or explicitly provided by the user (more computationally efficient). @@ -358,9 +357,6 @@ def step(self) -> None: assert self.pl_module.trainer.early_stopping_callback is not None self.pl_module.trainer.early_stopping_callback.final_phase = True # type: ignore[attr-defined] assert self._fts_state._ft_sync_objects is not None - if self._fts_state._resume_fit_from_ckpt: - # ensure multi-phase training session loops are synchronized for a fresh epoch restart - self._maybe_sync_loops() FinetuningScheduler.sync(self._fts_state._ft_sync_objects, self._fts_state._ft_sync_props) if self.pl_module._compiler_ctx and self.pl_module._compiler_ctx.get("compiler", None) == "dynamo": # reset currently required as `AOTAutograd`` is getting confused by `requires_grad` alteration @@ -383,8 +379,8 @@ def step_pg( ) -> None: """Configure optimizer parameter groups for the next scheduled fine-tuning level, adding parameter groups beyond the restored optimizer state up to - :paramref:`~finetuning_scheduler.fts.FinetuningScheduler.current_depth` and reinitializing the optimizer and/or - learning rate scheduler as configured. + :paramref:`~finetuning_scheduler.fts.FinetuningScheduler.current_depth` and reinitializing the optimizer + and/or learning rate scheduler as configured. Args: optimizer (:class:`~finetuning_scheduler.types.ParamGroupAddable`): The supported optimizer instance to @@ -731,8 +727,8 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") - super().on_fit_start(trainer, pl_module) def state_dict(self) -> Dict[str, Any]: - """Before saving a checkpoint, add the - :class:`~finetuning_scheduler.fts.FinetuningScheduler` callback state to be saved. + """Before saving a checkpoint, add the :class:`~finetuning_scheduler.fts.FinetuningScheduler` callback + state to be saved. Returns: Dict[str, Any]: The :class:`~finetuning_scheduler.fts.FinetuningScheduler` callback state dictionary @@ -756,9 +752,8 @@ def state_dict(self) -> Dict[str, Any]: } def load_state_dict(self, state_dict: Dict[str, Any]) -> None: - """After loading a checkpoint, load the saved - :class:`~finetuning_scheduler.fts.FinetuningScheduler` callback state and update the - current callback state accordingly. + """After loading a checkpoint, load the saved :class:`~finetuning_scheduler.fts.FinetuningScheduler` + callback state and update the current callback state accordingly. Args: state_dict: The :class:`~finetuning_scheduler.fts.FinetuningScheduler` callback state dictionary that will @@ -775,10 +770,9 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: def should_transition(self, trainer: "pl.Trainer") -> bool: """Phase transition logic is contingent on whether we are composing - :class:`~finetuning_scheduler.fts_supporters.FTSEarlyStopping` criteria with - epoch-driven transition constraints or exclusively using epoch-driven transition scheduling. (i.e., - :attr:`~finetuning_scheduler.fts.FinetuningScheduler.epoch_transitions_only` is - ``True``) + :class:`~finetuning_scheduler.fts_supporters.FTSEarlyStopping` criteria with epoch-driven transition + constraints or exclusively using epoch-driven transition scheduling. (i.e., + :attr:`~finetuning_scheduler.fts.FinetuningScheduler.epoch_transitions_only` is ``True``) Args: trainer (:external+pl:class:`~lightning.pytorch.trainer.trainer.Trainer`): The @@ -816,8 +810,8 @@ def should_transition(self, trainer: "pl.Trainer") -> bool: def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: """Before beginning a training epoch, configure the internal - :attr:`~finetuning_scheduler.fts.FinetuningScheduler._fts_state`, prepare the next - scheduled fine-tuning level and store the updated optimizer configuration before continuing training + :attr:`~finetuning_scheduler.fts.FinetuningScheduler._fts_state`, prepare the next scheduled fine-tuning + level and store the updated optimizer configuration before continuing training. Args: trainer (:external+pl:class:`~lightning.pytorch.trainer.trainer.Trainer`): The @@ -866,8 +860,8 @@ def on_before_zero_grad( optimizer: ParamGroupAddable, # type: ignore[override] ) -> None: """Afer the latest optimizer step, update the - :attr:`~finetuning_scheduler.fts.FinetuningScheduler._fts_state`, incrementing the - global fine-tuning steps taken + :attr:`~finetuning_scheduler.fts.FinetuningScheduler._fts_state`, incrementing the global fine-tuning steps + taken. Args: trainer (:external+pl:class:`~lightning.pytorch.trainer.trainer.Trainer`): The diff --git a/src/finetuning_scheduler/fts_supporters.py b/src/finetuning_scheduler/fts_supporters.py index f67c16f..44e2e9a 100644 --- a/src/finetuning_scheduler/fts_supporters.py +++ b/src/finetuning_scheduler/fts_supporters.py @@ -69,8 +69,7 @@ @dataclass class FTSState: - """Dataclass to encapsulate the - :class:`~finetuning_scheduler.fts.FinetuningScheduler` internal state.""" + """Dataclass to encapsulate the :class:`~finetuning_scheduler.fts.FinetuningScheduler` internal state.""" _resume_fit_from_ckpt: bool = False _ft_epoch: int = 0 @@ -164,9 +163,8 @@ def connect_callback(self, trainer: "pl.Trainer", reconnect: bool = False) -> No class FTSEarlyStopping(EarlyStopping, CallbackResolverMixin): - r""" - Extends/specializes :external+pl:class:`~lightning.pytorch.callbacks.early_stopping.EarlyStopping` to facilitate - multi-phase scheduled fine-tuning. + r"""Extends/specializes :external+pl:class:`~lightning.pytorch.callbacks.early_stopping.EarlyStopping` to + facilitate multi-phase scheduled fine-tuning. Adds :attr:`es_phase_complete`, :attr:`final_phase` and :attr:`finetuningscheduler_callback` attributes and modifies ``EarlyStopping._evaluate_stopping_criteria`` to enable multi-phase behavior. Usage of @@ -186,7 +184,6 @@ class FTSEarlyStopping(EarlyStopping, CallbackResolverMixin): Currently, :class:`~finetuning_scheduler.fts.FinetuningScheduler` supports the use of one :class:`~finetuning_scheduler.fts_supporters.FTSEarlyStopping` callback instance at a time. - """ _check_on_train_epoch_end: Optional[bool] best_score: Tensor @@ -323,13 +320,11 @@ def _improvement_message(self, current: Tensor) -> str: class FTSCheckpoint(ModelCheckpoint, CallbackResolverMixin): - r""" - Extends/specializes :external+pl:class:`~lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint` to facilitate - multi-phase scheduled fine-tuning. Overrides the - ``state_dict`` and ``load_state_dict`` hooks to maintain additional state (:attr:`current_ckpt_depth`, - :attr:`best_ckpt_depth`, :attr:`finetuningscheduler_callback`). Usage of - :class:`~finetuning_scheduler.fts_supporters.FTSCheckpoint` is identical to - :external+pl:class:`~lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint` and + r"""Extends/specializes :external+pl:class:`~lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint` to + facilitate multi-phase scheduled fine-tuning. Overrides the ``state_dict`` and ``load_state_dict`` hooks to + maintain additional state (:attr:`current_ckpt_depth`, :attr:`best_ckpt_depth`, + :attr:`finetuningscheduler_callback`). Usage of :class:`~finetuning_scheduler.fts_supporters.FTSCheckpoint` is + identical to :external+pl:class:`~lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint` and :class:`~finetuning_scheduler.fts_supporters.FTSCheckpoint` will automatically be used if a :class:`~finetuning_scheduler.fts.FinetuningScheduler` callback is detected. @@ -1365,8 +1360,10 @@ def gen_or_load_sched(self) -> None: def init_ft_sched(self) -> None: """Generate the default fine-tuning schedule and/or load it into - :paramref:`~finetuning_scheduler.fts.FinetuningScheduler.ft_schedule`. Broadcast the - schedule to ensure it is available for use in a distributed context.""" + :paramref:`~finetuning_scheduler.fts.FinetuningScheduler.ft_schedule`. + + Broadcast the schedule to ensure it is available for use in a distributed context. + """ self.gen_or_load_sched() assert isinstance(self.ft_schedule, Dict) if self.max_depth == -1: @@ -1495,9 +1492,8 @@ def load_yaml_schedule(schedule_yaml_file: os.PathLike) -> Dict: return schedule_dict def thaw_to_depth(self, depth: Optional[int] = None) -> None: - """Thaw/unfreeze the current - :paramref:`~finetuning_scheduler.fts.FinetuningScheduler.pl_module` to the specified - fine-tuning depth (aka level) + """Thaw/unfreeze the current :paramref:`~finetuning_scheduler.fts.FinetuningScheduler.pl_module` to the + specified fine-tuning depth (aka level) Args: depth: The depth/level to which the @@ -1672,20 +1668,6 @@ def sync(objs: Tuple, asets: Tuple, agg_func: Callable = max) -> None: for o, a in zip(objs, attrs): setattr(o, a, agg) - def _maybe_sync_loops(self) -> None: - """Synchronize total and current progress loops for the restart of a multi-phase training session.""" - assert self.pl_module._trainer is not None - fit_loop = self.pl_module._trainer.fit_loop - if fit_loop.epoch_loop.restarting: # if ``True``, we haven't completed resetting state - # since we're restoring from a checkpoint saved prior to processed and completed incrementing - fit_loop.epoch_progress.increment_processed() - fit_loop.epoch_progress.increment_completed() - # ensure current and total are synchronized for the continuation of our multi-phase fine-tuning session - fit_loop.epoch_progress.current = copy(fit_loop.epoch_progress.total) - # restarting outside of epoch end is not supported so the assumption here is to start with a fresh epoch - fit_loop.epoch_loop.restarting = False - fit_loop.epoch_loop.val_loop._restarting = False - def _inspect_fts_opt_state(self) -> Tuple: """Distills relevant initialized optimizer state for validation prior to fit start. diff --git a/tests/test_finetuning_scheduler_callback.py b/tests/test_finetuning_scheduler_callback.py index dfd124a..7b9d996 100644 --- a/tests/test_finetuning_scheduler_callback.py +++ b/tests/test_finetuning_scheduler_callback.py @@ -154,8 +154,8 @@ def __init__(self, param1, param2): class FinetuningSchedulerBoringModel(BoringModel): """Extend :class:`~tests.helpers.BoringModel` to facilitate testing of - :class:`~finetuning_scheduler.FinetuningScheduler` by ensuring deterministic divergence - and accommodating no_decay list configuration""" + :class:`~finetuning_scheduler.FinetuningScheduler` by ensuring deterministic divergence and accommodating + no_decay list configuration.""" def __init__( self, @@ -288,7 +288,7 @@ def train_dataloader(self): # return self.model(x) class FTSCustLRModel(FinetuningSchedulerBoringModel): - """overrides lr_scheduler_step to allow lr scheduler testing.""" + """Overrides lr_scheduler_step to allow lr scheduler testing.""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -1356,6 +1356,7 @@ def test_fts_decay(tmpdir, boring_ft_schedule, explicit_mode: bool, nodecay_mode "Conversion of an array with ndim > 0" # required for PyTorch 2.2 ] EXPECTED_DIRPATH = "is not empty." +EXPECTED_TRAINCHK = "could not find the monitored key in the returned" def ckpt_resume_launch(ckpt_set_fixture: object, diff_dirpath: bool, ckpt: str, max_depth: int, tmpdir: Path, save_on_train_epoch_end: Optional[bool] = None) -> None: @@ -1420,6 +1421,8 @@ def test_fts_callback_resume(tmpdir, ckpt_set, recwarn, diff_dirpath: bool, trai assert fts_callback.curr_depth == fts_callback.max_depth if not diff_dirpath: resume_warns.append(EXPECTED_DIRPATH) + if train_chk_mode: + resume_warns.append(EXPECTED_TRAINCHK) # ensure no unexpected warnings detected unexpected = unexpected_warns(rec_warns=recwarn.list, expected_warns=resume_warns) assert not unexpected, tuple(w.message.args[0] + ":" + w.filename + ":" + str(w.lineno) for w in unexpected) @@ -2213,7 +2216,7 @@ def test_fts_callback_warns( tmpdir, recwarn, callbacks: List[Callback], cust_monitor: Optional[str], dist_mode: str, expected: Tuple[str] ): """Validate :class:`~finetuning_scheduler.FinetuningScheduler` warnings that require a - :class:`~pytorch_lighting.trainer.Trainer` to be defined are properly issued""" + :class:`~pytorch_lighting.trainer.Trainer` to be defined are properly issued.""" model = FinetuningSchedulerBoringModel(monitor_metric=cust_monitor) dist_args = {"strategy": dist_mode, "accelerator": "cpu", "devices": "auto"} if dist_mode else {"devices": 1} trainer = Trainer(default_root_dir=tmpdir, callbacks=callbacks, **dist_args)