- bumped lightning dev sha
- removed `_maybe_sync_loops` after Lightning-AI/pytorch-lightning#20379 obviated the need for it
speediedan committed Nov 23, 2024
1 parent 2a7a4d1 commit 6b1da2d
Showing 6 changed files with 40 additions and 61 deletions.
2 changes: 1 addition & 1 deletion requirements/base.txt
@@ -1,4 +1,4 @@
#lightning>=2.5.0,<2.5.1
# the below is uncommented when master is targeting a specific pl dev master commit
git+https://github.com/Lightning-AI/lightning.git@06a8d5bf33faf0a4f9a24207ae77b439354350af#egg=lightning
git+https://github.com/Lightning-AI/lightning.git@8ce52876ad6e5eb05e0965f72e034ffe46b327ba#egg=lightning
torch>=2.2.0
2 changes: 1 addition & 1 deletion requirements/standalone_base.txt
@@ -1,4 +1,4 @@
#pytorch-lightning>=2.5.0,<2.5.1
# the below is uncommented when master is targeting a specific pl dev master commit
git+https://github.com/Lightning-AI/pytorch-lightning.git@06a8d5bf33faf0a4f9a24207ae77b439354350af#egg=pytorch-lightning
git+https://github.com/Lightning-AI/pytorch-lightning.git@8ce52876ad6e5eb05e0965f72e034ffe46b327ba#egg=pytorch-lightning
torch>=2.2.0
2 changes: 1 addition & 1 deletion setup.py
@@ -138,7 +138,7 @@ def _setup_args(standalone: bool = False) -> Dict[str, Any]:
_INSTALL_PATHS["require"],
file_name=base_reqs,
standalone=standalone,
pl_commit="06a8d5bf33faf0a4f9a24207ae77b439354350af",
pl_commit="8ce52876ad6e5eb05e0965f72e034ffe46b327ba",
)
base_setup["install_requires"] = install_requires
return base_setup
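Note on the `pl_commit` argument above: it pins the Lightning dependency to the targeted development commit rather than a released version, matching the `git+https://...@<sha>#egg=...` entries in the requirements files. A minimal sketch of how such a pin might be assembled follows; the helper name `_pin_lightning_commit` and its filtering logic are illustrative assumptions, not the repository's actual implementation.

from typing import List

def _pin_lightning_commit(requirements: List[str], pl_commit: str, standalone: bool = False) -> List[str]:
    # Illustrative only: swap the commented release pin for a pip VCS requirement
    # targeting the given Lightning development commit (mirroring the requirements entries above).
    pkg = "pytorch-lightning" if standalone else "lightning"
    pinned = f"git+https://github.com/Lightning-AI/{pkg}.git@{pl_commit}#egg={pkg}"
    # keep every requirement except the (commented) lightning release pin
    kept = [r for r in requirements if not r.lstrip("#").strip().startswith(pkg)]
    return [pinned, *kept]

# hypothetical usage mirroring the sha bumped in this commit
install_requires = _pin_lightning_commit(
    ["#lightning>=2.5.0,<2.5.1", "torch>=2.2.0"],
    pl_commit="8ce52876ad6e5eb05e0965f72e034ffe46b327ba",
)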
38 changes: 16 additions & 22 deletions src/finetuning_scheduler/fts.py
@@ -48,10 +48,9 @@


class FinetuningScheduler(ScheduleImplMixin, ScheduleParsingMixin, CallbackDepMixin, BaseFinetuning):
r"""
This callback enables flexible, multi-phase, scheduled fine-tuning of foundation models. Gradual
unfreezing/thawing can help maximize foundation model knowledge retention while allowing (typically upper layers
of) the model to optimally adapt to new tasks during transfer learning.
r"""This callback enables flexible, multi-phase, scheduled fine-tuning of foundation models. Gradual
unfreezing/thawing can help maximize foundation model knowledge retention while allowing (typically upper
layers of) the model to optimally adapt to new tasks during transfer learning.
:class:`~finetuning_scheduler.fts.FinetuningScheduler` orchestrates the gradual unfreezing of models via a
fine-tuning schedule that is either implicitly generated (the default) or explicitly provided by the user (more
computationally efficient).
@@ -358,9 +357,6 @@ def step(self) -> None:
assert self.pl_module.trainer.early_stopping_callback is not None
self.pl_module.trainer.early_stopping_callback.final_phase = True # type: ignore[attr-defined]
assert self._fts_state._ft_sync_objects is not None
if self._fts_state._resume_fit_from_ckpt:
# ensure multi-phase training session loops are synchronized for a fresh epoch restart
self._maybe_sync_loops()
FinetuningScheduler.sync(self._fts_state._ft_sync_objects, self._fts_state._ft_sync_props)
if self.pl_module._compiler_ctx and self.pl_module._compiler_ctx.get("compiler", None) == "dynamo":
# reset currently required as `AOTAutograd`` is getting confused by `requires_grad` alteration
@@ -383,8 +379,8 @@ def step_pg(
) -> None:
"""Configure optimizer parameter groups for the next scheduled fine-tuning level, adding parameter groups
beyond the restored optimizer state up to
:paramref:`~finetuning_scheduler.fts.FinetuningScheduler.current_depth` and reinitializing the optimizer and/or
learning rate scheduler as configured.
:paramref:`~finetuning_scheduler.fts.FinetuningScheduler.current_depth` and reinitializing the optimizer
and/or learning rate scheduler as configured.
Args:
optimizer (:class:`~finetuning_scheduler.types.ParamGroupAddable`): The supported optimizer instance to
@@ -731,8 +727,8 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
super().on_fit_start(trainer, pl_module)

def state_dict(self) -> Dict[str, Any]:
"""Before saving a checkpoint, add the
:class:`~finetuning_scheduler.fts.FinetuningScheduler` callback state to be saved.
"""Before saving a checkpoint, add the :class:`~finetuning_scheduler.fts.FinetuningScheduler` callback
state to be saved.
Returns:
Dict[str, Any]: The :class:`~finetuning_scheduler.fts.FinetuningScheduler` callback state dictionary
@@ -756,9 +752,8 @@ def state_dict(self) -> Dict[str, Any]:
}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
"""After loading a checkpoint, load the saved
:class:`~finetuning_scheduler.fts.FinetuningScheduler` callback state and update the
current callback state accordingly.
"""After loading a checkpoint, load the saved :class:`~finetuning_scheduler.fts.FinetuningScheduler`
callback state and update the current callback state accordingly.
Args:
state_dict: The :class:`~finetuning_scheduler.fts.FinetuningScheduler` callback state dictionary that will
@@ -775,10 +770,9 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:

def should_transition(self, trainer: "pl.Trainer") -> bool:
"""Phase transition logic is contingent on whether we are composing
:class:`~finetuning_scheduler.fts_supporters.FTSEarlyStopping` criteria with
epoch-driven transition constraints or exclusively using epoch-driven transition scheduling. (i.e.,
:attr:`~finetuning_scheduler.fts.FinetuningScheduler.epoch_transitions_only` is
``True``)
:class:`~finetuning_scheduler.fts_supporters.FTSEarlyStopping` criteria with epoch-driven transition
constraints or exclusively using epoch-driven transition scheduling. (i.e.,
:attr:`~finetuning_scheduler.fts.FinetuningScheduler.epoch_transitions_only` is ``True``)
Args:
trainer (:external+pl:class:`~lightning.pytorch.trainer.trainer.Trainer`): The
@@ -816,8 +810,8 @@ def should_transition(self, trainer: "pl.Trainer") -> bool:

def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Before beginning a training epoch, configure the internal
:attr:`~finetuning_scheduler.fts.FinetuningScheduler._fts_state`, prepare the next
scheduled fine-tuning level and store the updated optimizer configuration before continuing training
:attr:`~finetuning_scheduler.fts.FinetuningScheduler._fts_state`, prepare the next scheduled fine-tuning
level and store the updated optimizer configuration before continuing training.
Args:
trainer (:external+pl:class:`~lightning.pytorch.trainer.trainer.Trainer`): The
@@ -866,8 +860,8 @@ def on_before_zero_grad(
optimizer: ParamGroupAddable, # type: ignore[override]
) -> None:
"""Afer the latest optimizer step, update the
:attr:`~finetuning_scheduler.fts.FinetuningScheduler._fts_state`, incrementing the
global fine-tuning steps taken
:attr:`~finetuning_scheduler.fts.FinetuningScheduler._fts_state`, incrementing the global fine-tuning steps
taken.
Args:
trainer (:external+pl:class:`~lightning.pytorch.trainer.trainer.Trainer`): The
46 changes: 14 additions & 32 deletions src/finetuning_scheduler/fts_supporters.py
@@ -69,8 +69,7 @@

@dataclass
class FTSState:
"""Dataclass to encapsulate the
:class:`~finetuning_scheduler.fts.FinetuningScheduler` internal state."""
"""Dataclass to encapsulate the :class:`~finetuning_scheduler.fts.FinetuningScheduler` internal state."""

_resume_fit_from_ckpt: bool = False
_ft_epoch: int = 0
@@ -164,9 +163,8 @@ def connect_callback(self, trainer: "pl.Trainer", reconnect: bool = False) -> No


class FTSEarlyStopping(EarlyStopping, CallbackResolverMixin):
r"""
Extends/specializes :external+pl:class:`~lightning.pytorch.callbacks.early_stopping.EarlyStopping` to facilitate
multi-phase scheduled fine-tuning.
r"""Extends/specializes :external+pl:class:`~lightning.pytorch.callbacks.early_stopping.EarlyStopping` to
facilitate multi-phase scheduled fine-tuning.
Adds :attr:`es_phase_complete`, :attr:`final_phase` and :attr:`finetuningscheduler_callback` attributes and modifies
``EarlyStopping._evaluate_stopping_criteria`` to enable multi-phase behavior. Usage of
@@ -186,7 +184,6 @@ class FTSEarlyStopping(EarlyStopping, CallbackResolverMixin):
Currently, :class:`~finetuning_scheduler.fts.FinetuningScheduler` supports the use of one
:class:`~finetuning_scheduler.fts_supporters.FTSEarlyStopping` callback instance at a time.
"""
_check_on_train_epoch_end: Optional[bool]
best_score: Tensor
@@ -323,13 +320,11 @@ def _improvement_message(self, current: Tensor) -> str:


class FTSCheckpoint(ModelCheckpoint, CallbackResolverMixin):
r"""
Extends/specializes :external+pl:class:`~lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint` to facilitate
multi-phase scheduled fine-tuning. Overrides the
``state_dict`` and ``load_state_dict`` hooks to maintain additional state (:attr:`current_ckpt_depth`,
:attr:`best_ckpt_depth`, :attr:`finetuningscheduler_callback`). Usage of
:class:`~finetuning_scheduler.fts_supporters.FTSCheckpoint` is identical to
:external+pl:class:`~lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint` and
r"""Extends/specializes :external+pl:class:`~lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint` to
facilitate multi-phase scheduled fine-tuning. Overrides the ``state_dict`` and ``load_state_dict`` hooks to
maintain additional state (:attr:`current_ckpt_depth`, :attr:`best_ckpt_depth`,
:attr:`finetuningscheduler_callback`). Usage of :class:`~finetuning_scheduler.fts_supporters.FTSCheckpoint` is
identical to :external+pl:class:`~lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint` and
:class:`~finetuning_scheduler.fts_supporters.FTSCheckpoint` will automatically be used if a
:class:`~finetuning_scheduler.fts.FinetuningScheduler` callback is detected.
@@ -1365,8 +1360,10 @@ def gen_or_load_sched(self) -> None:

def init_ft_sched(self) -> None:
"""Generate the default fine-tuning schedule and/or load it into
:paramref:`~finetuning_scheduler.fts.FinetuningScheduler.ft_schedule`. Broadcast the
schedule to ensure it is available for use in a distributed context."""
:paramref:`~finetuning_scheduler.fts.FinetuningScheduler.ft_schedule`.
Broadcast the schedule to ensure it is available for use in a distributed context.
"""
self.gen_or_load_sched()
assert isinstance(self.ft_schedule, Dict)
if self.max_depth == -1:
@@ -1495,9 +1492,8 @@ def load_yaml_schedule(schedule_yaml_file: os.PathLike) -> Dict:
return schedule_dict

def thaw_to_depth(self, depth: Optional[int] = None) -> None:
"""Thaw/unfreeze the current
:paramref:`~finetuning_scheduler.fts.FinetuningScheduler.pl_module` to the specified
fine-tuning depth (aka level)
"""Thaw/unfreeze the current :paramref:`~finetuning_scheduler.fts.FinetuningScheduler.pl_module` to the
specified fine-tuning depth (aka level)
Args:
depth: The depth/level to which the
@@ -1672,20 +1668,6 @@ def sync(objs: Tuple, asets: Tuple, agg_func: Callable = max) -> None:
for o, a in zip(objs, attrs):
setattr(o, a, agg)

def _maybe_sync_loops(self) -> None:
"""Synchronize total and current progress loops for the restart of a multi-phase training session."""
assert self.pl_module._trainer is not None
fit_loop = self.pl_module._trainer.fit_loop
if fit_loop.epoch_loop.restarting: # if ``True``, we haven't completed resetting state
# since we're restoring from a checkpoint saved prior to processed and completed incrementing
fit_loop.epoch_progress.increment_processed()
fit_loop.epoch_progress.increment_completed()
# ensure current and total are synchronized for the continuation of our multi-phase fine-tuning session
fit_loop.epoch_progress.current = copy(fit_loop.epoch_progress.total)
# restarting outside of epoch end is not supported so the assumption here is to start with a fresh epoch
fit_loop.epoch_loop.restarting = False
fit_loop.epoch_loop.val_loop._restarting = False

def _inspect_fts_opt_state(self) -> Tuple:
"""Distills relevant initialized optimizer state for validation prior to fit start.
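Aside on the retained `sync` helper whose tail is visible above (`for o, a in zip(objs, attrs): setattr(o, a, agg)`): it reduces the zipped (object, attribute) values with `agg_func` and writes the aggregate back to each object. A self-contained sketch of that aggregate-and-broadcast pattern, using flat attribute names and toy objects rather than the callback's actual `_ft_sync_objects`/`_ft_sync_props` state:

from typing import Callable, Tuple

def sync_attrs(objs: Tuple, attrs: Tuple[str, ...], agg_func: Callable = max) -> None:
    # Illustrative sketch: reduce the zipped (object, attribute) values with agg_func,
    # then set the aggregate on every object so their counters agree after a restart.
    agg = agg_func(getattr(o, a) for o, a in zip(objs, attrs))
    for o, a in zip(objs, attrs):
        setattr(o, a, agg)

class _Progress:
    def __init__(self, steps: int) -> None:
        self.steps = steps

loop_a, loop_b = _Progress(3), _Progress(5)
sync_attrs((loop_a, loop_b), ("steps", "steps"))
assert loop_a.steps == loop_b.steps == 5  # both objects now carry the max of the two values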
11 changes: 7 additions & 4 deletions tests/test_finetuning_scheduler_callback.py
@@ -154,8 +154,8 @@ def __init__(self, param1, param2):

class FinetuningSchedulerBoringModel(BoringModel):
"""Extend :class:`~tests.helpers.BoringModel` to facilitate testing of
:class:`~finetuning_scheduler.FinetuningScheduler` by ensuring deterministic divergence
and accommodating no_decay list configuration"""
:class:`~finetuning_scheduler.FinetuningScheduler` by ensuring deterministic divergence and accommodating
no_decay list configuration."""

def __init__(
self,
@@ -288,7 +288,7 @@ def train_dataloader(self):
# return self.model(x)

class FTSCustLRModel(FinetuningSchedulerBoringModel):
"""overrides lr_scheduler_step to allow lr scheduler testing."""
"""Overrides lr_scheduler_step to allow lr scheduler testing."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -1356,6 +1356,7 @@ def test_fts_decay(tmpdir, boring_ft_schedule, explicit_mode: bool, nodecay_mode
"Conversion of an array with ndim > 0" # required for PyTorch 2.2
]
EXPECTED_DIRPATH = "is not empty."
EXPECTED_TRAINCHK = "could not find the monitored key in the returned"

def ckpt_resume_launch(ckpt_set_fixture: object, diff_dirpath: bool, ckpt: str, max_depth: int, tmpdir: Path,
save_on_train_epoch_end: Optional[bool] = None) -> None:
@@ -1420,6 +1421,8 @@ def test_fts_callback_resume(tmpdir, ckpt_set, recwarn, diff_dirpath: bool, trai
assert fts_callback.curr_depth == fts_callback.max_depth
if not diff_dirpath:
resume_warns.append(EXPECTED_DIRPATH)
if train_chk_mode:
resume_warns.append(EXPECTED_TRAINCHK)
# ensure no unexpected warnings detected
unexpected = unexpected_warns(rec_warns=recwarn.list, expected_warns=resume_warns)
assert not unexpected, tuple(w.message.args[0] + ":" + w.filename + ":" + str(w.lineno) for w in unexpected)
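The `unexpected_warns` helper used here is not shown in this diff; a minimal sketch of the allow-list pattern it appears to implement (keep only recorded warnings whose message matches none of the expected substrings, e.g. EXPECTED_DIRPATH or the newly added EXPECTED_TRAINCHK) might look like the following, with the function name and signature assumed for illustration.

import warnings
from typing import List, Sequence

def filter_unexpected_warns(rec_warns: List[warnings.WarningMessage],
                            expected_warns: Sequence[str]) -> List[warnings.WarningMessage]:
    # Illustrative sketch: a warning counts as "expected" if any allow-listed substring
    # occurs in its message; everything else is surfaced for the test to fail on.
    return [w for w in rec_warns if not any(exp in str(w.message) for exp in expected_warns)]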
@@ -2213,7 +2216,7 @@ def test_fts_callback_warns(
tmpdir, recwarn, callbacks: List[Callback], cust_monitor: Optional[str], dist_mode: str, expected: Tuple[str]
):
"""Validate :class:`~finetuning_scheduler.FinetuningScheduler` warnings that require a
:class:`~pytorch_lighting.trainer.Trainer` to be defined are properly issued"""
:class:`~pytorch_lighting.trainer.Trainer` to be defined are properly issued."""
model = FinetuningSchedulerBoringModel(monitor_metric=cust_monitor)
dist_args = {"strategy": dist_mode, "accelerator": "cpu", "devices": "auto"} if dist_mode else {"devices": 1}
trainer = Trainer(default_root_dir=tmpdir, callbacks=callbacks, **dist_args)
