
Commit

Added logic to more robustly condition depth-aligned checkpoint metadata updates to address edge-cases where `current_score` precisely equaled the `best_model_score` at multiple different depths.
speediedan committed Aug 28, 2024
1 parent 22f2a1a commit 0ef51de
Showing 4 changed files with 137 additions and 33 deletions.
10 changes: 10 additions & 0 deletions CHANGELOG.md
@@ -4,6 +4,16 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.4.1] - 2024-XX-XX

### Added

- Support for Lightning and PyTorch ``2.4.1``

### Fixed

- Added logic to more robustly condition depth-aligned checkpoint metadata updates to address edge-cases where `current_score` precisely equaled the `best_model_score` at multiple different depths. Resolved [#15](https://github.com/speediedan/finetuning-scheduler/issues/15).

## [2.4.0] - 2024-08-15

### Added
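The edge case addressed by this commit can be illustrated with a minimal, self-contained sketch (hypothetical stand-in code, not part of the library): when the monitored value repeats at multiple depths, comparing `current_score` against `best_model_score` for equality matches at every depth, while checking whether `best_model_path` changed during the save identifies only the depth that actually produced the retained best checkpoint.

```python
# Minimal sketch (hypothetical stand-in, not finetuning-scheduler code): with a
# constant monitored metric, a score-equality check "matches" at every depth,
# while the best checkpoint path only changes when a new best entry is saved.
from dataclasses import dataclass


@dataclass
class ToyCheckpointState:
    best_model_score: float = float("inf")
    best_model_path: str = ""


def toy_save(state: ToyCheckpointState, depth: int, current_score: float) -> None:
    # mode="min" semantics: only the first or a strictly better score displaces the best path
    if not state.best_model_path or current_score < state.best_model_score:
        state.best_model_score = current_score
        state.best_model_path = f"depth={depth}-score={current_score}.ckpt"


state = ToyCheckpointState()
best_depth_by_score_eq = best_depth_by_path_change = 0
for depth, score in enumerate([1.0, 1.0, 1.0]):  # same monitored value at every depth
    prev_best_path = state.best_model_path
    toy_save(state, depth, score)
    if score == state.best_model_score:           # previous equality-based condition
        best_depth_by_score_eq = depth            # drifts to the deepest matching depth
    if prev_best_path != state.best_model_path:   # path-change check this commit relies on
        best_depth_by_path_change = depth         # stays at the depth that produced the best ckpt

print(best_depth_by_score_eq, best_depth_by_path_change)  # -> 2 0
```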
2 changes: 1 addition & 1 deletion src/finetuning_scheduler/fts.py
@@ -721,7 +721,7 @@ def state_dict(self) -> Dict[str, Any]:
assert self.pl_module is not None and self.pl_module.trainer is not None
trainer = self.pl_module.trainer
checkpoint_callback = trainer.checkpoint_callback
if checkpoint_callback.current_score == checkpoint_callback.best_model_score: # type: ignore[union-attr]
if checkpoint_callback._should_update_depth_meta: # type: ignore [union-attr]
self._fts_state._best_ckpt_depth = self._fts_state._curr_depth
for opt_idx, _ in enumerate(trainer.optimizers):
self._fts_state._fts_ckpt_metadata["best_ckpt_pgs"][opt_idx] = deepcopy(
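Below is a simplified sketch of the consuming side of the new property (hypothetical stand-in classes rather than the real Trainer/callback wiring): when the scheduler's own `state_dict` is captured as part of a checkpoint save, it asks the checkpoint callback whether the in-flight top-k save established a new best path and only then re-aligns its depth metadata.

```python
# Simplified sketch (hypothetical classes) of the cross-callback handshake shown in fts.py:
# the scheduler snapshots depth-aligned metadata only when the checkpoint callback reports
# that the in-flight top-k save produced a new best checkpoint.
from typing import Any, Dict


class ToyCheckpointCallback:
    def __init__(self) -> None:
        self._has_depth_metadata_lock = False  # True only while a top-k save is in progress
        self._prev_best_model_path = ""
        self.best_model_path = ""

    @property
    def _should_update_depth_meta(self) -> bool:
        return self._has_depth_metadata_lock and self._prev_best_model_path != self.best_model_path


class ToyScheduler:
    def __init__(self, checkpoint_callback: ToyCheckpointCallback) -> None:
        self.checkpoint_callback = checkpoint_callback
        self._curr_depth = 0
        self._best_ckpt_depth = 0

    def state_dict(self) -> Dict[str, Any]:
        # mirror of the conditional in fts.py: only update best-ckpt depth when warranted
        if self.checkpoint_callback._should_update_depth_meta:
            self._best_ckpt_depth = self._curr_depth
        return {"best_ckpt_depth": self._best_ckpt_depth, "curr_depth": self._curr_depth}
```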
38 changes: 33 additions & 5 deletions src/finetuning_scheduler/fts_supporters.py
@@ -23,15 +23,16 @@
import inspect
import re
import warnings
from contextlib import contextmanager
from abc import ABC, abstractmethod
from collections import Counter, defaultdict
from collections.abc import KeysView
from copy import copy, deepcopy
from dataclasses import dataclass, field, fields
from functools import reduce
from pprint import pformat
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, Set
from typing_extensions import TypeAlias
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, Set, Iterator
from typing_extensions import TypeAlias, override

import lightning.pytorch as pl
import torch
@@ -375,6 +376,8 @@ def __init__(self, *args: Any, **kwargs: Any):
self.current_ckpt_depth = 0
self.best_ckpt_depth = 0
self.finetuningscheduler_callback = None
self._prev_best_model_path = ''
self._has_depth_metadata_lock = False

def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
"""Verify a valid callback configuration is present before beginning training.
@@ -422,10 +425,8 @@ def state_dict(self) -> Dict[str, Any]:
Dict[str, Any]: the callback state dictionary that will be saved.
"""
self.current_ckpt_depth = self.finetuningscheduler_callback.curr_depth # type: ignore[attr-defined]
# note, if current score is precisely the best score but a previous depth had the same score the
# best ckpt depth will be set to the latest (deepest) depth with that score.
# a future enhancement of per-depth best score mapping could allow more fine-grained control of this behavior
if self.current_score == self.best_model_score:
if self._should_update_depth_meta:
self.best_ckpt_depth = self.current_ckpt_depth
return {
"monitor": self.monitor,
@@ -474,6 +475,33 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
self.last_model_path = state_dict.get("last_model_path", self.last_model_path)
self.best_model_path = state_dict["best_model_path"]

@property
def _should_update_depth_meta(self) -> bool:
# Depth-aligned checkpoint metadata is only updated if:
# 1. We are currently saving a top-k checkpoint
# 2. The `best_model_path` has changed
return self._has_depth_metadata_lock and self._prev_best_model_path != self.best_model_path

@contextmanager
def _depth_metadata_lock(self) -> Iterator[None]:
"""Context manager that conditions just-in-time mutability of depth-aligned checkpoint metadata."""
try:
self._has_depth_metadata_lock = True
yield
finally:
self._has_depth_metadata_lock = False

@override
def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None:
"""Wrapper around :external+pl:class:`~lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint`'s
``_save_topk_checkpoint`` method.
To avoid altering the checkpoint sorting and saving logic of the superclass while conditionally enriching it
with depth-aligned checkpoint metadata and handling edge cases, we wrap this method with additional context.
"""
with self._depth_metadata_lock():
self._prev_best_model_path = self.best_model_path
super()._save_topk_checkpoint(trainer, monitor_candidates)

FTSCallbackDepType: TypeAlias = Union[Type[FTSEarlyStopping], Type[FTSCheckpoint]]

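The guard-flag pattern introduced above can be distilled into a short generic sketch (a hypothetical example, not the library code): the context manager marks the window in which the superclass save runs, and the previous best path is recorded immediately before delegating, so any state captured re-entrantly inside that window (e.g. by another callback's `state_dict`) can tell a genuine new-best top-k save apart from any other call.

```python
# Generic sketch (hypothetical) of the just-in-time guard-flag pattern used above.
from contextlib import contextmanager
from typing import Iterator


class DepthAwareSaver:
    def __init__(self) -> None:
        self._has_depth_metadata_lock = False
        self._prev_best_model_path = ""
        self.best_model_path = ""

    @contextmanager
    def _depth_metadata_lock(self) -> Iterator[None]:
        try:
            self._has_depth_metadata_lock = True
            yield
        finally:
            self._has_depth_metadata_lock = False  # flag never leaks outside the save window

    @property
    def _should_update_depth_meta(self) -> bool:
        return self._has_depth_metadata_lock and self._prev_best_model_path != self.best_model_path

    def _superclass_save(self, new_best: bool) -> None:
        # stand-in for ModelCheckpoint's top-k save, which may move best_model_path
        if new_best:
            self.best_model_path = f"{self.best_model_path}+"
        # while the lock is held, re-entrant consumers see the property reflect the path change
        assert self._should_update_depth_meta is new_best

    def save_topk(self, new_best: bool) -> None:
        with self._depth_metadata_lock():
            self._prev_best_model_path = self.best_model_path  # snapshot before delegating
            self._superclass_save(new_best)


saver = DepthAwareSaver()
saver.save_topk(new_best=True)    # property is True inside the save window
saver.save_topk(new_best=False)   # property is False: best path did not change
assert not saver._should_update_depth_meta  # and always False outside the window
```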
120 changes: 93 additions & 27 deletions tests/test_finetuning_scheduler_callback.py
@@ -509,6 +509,13 @@ def configure_optimizers(self):
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.7)
return [optimizer], [lr_scheduler]

class ConstantLossBoringModel(FinetuningSchedulerBoringModel):
def validation_step(self, batch, batch_idx):
output = self(batch)
loss = self.val_loss(batch, output)
self.validation_step_outputs.append(loss)
self.log(self.monitor_metric, 1, prog_bar=False) # edge case with every checkpoint having the same loss
return {"x": loss}

class NonDynamicLossBoringModel(FinetuningSchedulerBoringModel):
def val_loss(self, batch, prediction):
@@ -543,28 +550,38 @@ def _monitor_candidates(self, trainer: Trainer) -> Dict[str, Tensor]:

class NonDynamicPhase0EnforceModel(NonDynamicLossBoringModel):
def configure_optimizers(self):
# if self.p0_params:
# for n, p in self.named_parameters():
# p.requires_grad = True if n in self.p0_params else False
# parameters = list(filter(lambda x: x.requires_grad, self.parameters()))
# optimizer = torch.optim.SGD(parameters, lr=1e-3, weight_decay=self.weight_decay)
# else:
optimizer = torch.optim.SGD(self.parameters(), lr=1e-3, weight_decay=self.weight_decay)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.7)
return [optimizer], [lr_scheduler]


@pytest.fixture(scope="function")
def ckpt_set(tmpdir_factory) -> Dict:
"""A fixture that generates a 'best' and 'kth' checkpoint to be used in scheduled fine-tuning resumption
testing."""
def ckpt_set_setup(save_top_k: int, save_last: Optional[bool] = None) -> Tuple:
seed_everything(42)
callbacks = [
FinetuningScheduler(max_depth=1),
FTSEarlyStopping(monitor="val_loss", patience=1, min_delta=0.001),
FTSCheckpoint(monitor="val_loss", verbose=True, save_top_k=2),
FTSCheckpoint(monitor="val_loss", verbose=True, save_last=save_last, save_top_k=save_top_k),
]
model = FinetuningSchedulerBoringModel()
return callbacks, model


@pytest.fixture(scope="function")
def ckpt_set_last(tmpdir_factory) -> Dict:
"""A fixture that generates a 'best', 'kth' and 'last' checkpoint to be used in scheduled fine-tuning
resumption testing."""
callbacks, model = ckpt_set_setup(save_top_k=3, save_last=True)
trainer = Trainer(default_root_dir=tmpdir_factory.getbasetemp(), callbacks=callbacks, devices=1)
trainer.fit(model)
return {"best": trainer.checkpoint_callback.best_model_path, "kth": trainer.checkpoint_callback.kth_best_model_path,
"last": trainer.checkpoint_callback.last_model_path}


@pytest.fixture(scope="function")
def ckpt_set(tmpdir_factory) -> Dict:
"""A fixture that generates a 'best' and 'kth' checkpoint to be used in scheduled fine-tuning resumption
testing."""
callbacks, model = ckpt_set_setup(save_top_k=2)
trainer = Trainer(default_root_dir=tmpdir_factory.getbasetemp(), callbacks=callbacks, devices=1)
trainer.fit(model)
return {"best": trainer.checkpoint_callback.best_model_path, "kth": trainer.checkpoint_callback.kth_best_model_path}
@@ -1311,21 +1328,14 @@ def test_fts_decay(tmpdir, boring_ft_schedule, explicit_mode: bool, nodecay_mode
]
EXPECTED_DIRPATH = "is not empty."

@pytest.mark.parametrize("diff_dirpath,", [True, False], ids=["diffdirpath", "samedirpath"])
@pytest.mark.parametrize("train_chk_mode,", [None, True], ids=["defaultchk", "trainchk"])
@pytest.mark.parametrize("ckpt,", ["best", "kth"], ids=["best", "kth"])
@pytest.mark.parametrize("max_depth", [-1, 1], ids=["nomaxdepth", "maxdepth1"])
def test_fts_callback_resume(
tmpdir, ckpt_set, recwarn, diff_dirpath: bool, train_chk_mode: Optional[bool], ckpt: str, max_depth: int
):
"""Validate scheduled fine-tuning resumption functions as expected from both 'best' and 'kth'(not-best)
checkpoints in both train/val stage check modes with and without max_depth specified."""
resume_warns = copy(EXPECTED_WARNS)
dirpath = None if diff_dirpath else Path(ckpt_set["best"]).parent
def ckpt_resume_launch(ckpt_set_fixture: object, diff_dirpath: bool, ckpt: str, max_depth: int, tmpdir: Path,
save_on_train_epoch_end: Optional[bool] = None) -> Tuple:
dirpath = None if diff_dirpath else Path(ckpt_set_fixture["best"]).parent
resume_callbacks = [
FTSEarlyStopping(monitor="val_loss", patience=1, min_delta=0.001),
FTSCheckpoint(
monitor="val_loss", dirpath=dirpath, save_on_train_epoch_end=train_chk_mode, verbose=True, save_top_k=3
monitor="val_loss", dirpath=dirpath, verbose=True, save_top_k=3,
save_on_train_epoch_end=save_on_train_epoch_end
),
]
resume_callbacks.append(FinetuningScheduler(max_depth=max_depth, logging_level=DEBUG))
@@ -1334,8 +1344,38 @@ def test_fts_callback_resume(
model = FinetuningSchedulerBoringModel()
trainer = Trainer(default_root_dir=tmpdir, callbacks=resume_callbacks, devices=1)
finetuningscheduler_callback = get_fts(trainer)
trainer.ckpt_path = ckpt_set[ckpt]
trainer.ckpt_path = ckpt_set_fixture[ckpt]
trainer.fit(model)
return finetuningscheduler_callback, resume_callbacks, trainer

@pytest.mark.parametrize("diff_dirpath,", [True, False], ids=["diffdirpath", "samedirpath"])
@pytest.mark.parametrize("ckpt,", ["best", "kth", "last"], ids=["best", "kth", "last"])
def test_fts_callback_resume_last(tmpdir, ckpt_set_last, recwarn, diff_dirpath: bool, ckpt: str):
"""Validate scheduled fine-tuning resumption functions as expected from both 'best' and 'kth'(not-best)
checkpoints in both train/val stage check modes with and without max_depth specified."""
resume_warns = copy(EXPECTED_WARNS)
fts_callback, *_ = ckpt_resume_launch(ckpt_set_fixture=ckpt_set_last, diff_dirpath=diff_dirpath, ckpt=ckpt,
max_depth=1, tmpdir=tmpdir)
assert fts_callback.curr_depth == fts_callback.max_depth
if not diff_dirpath:
resume_warns.append(EXPECTED_DIRPATH)
# ensure no unexpected warnings detected
unexpected = unexpected_warns(rec_warns=recwarn.list, expected_warns=resume_warns)
assert not unexpected, tuple(w.message.args[0] + ":" + w.filename + ":" + str(w.lineno) for w in unexpected)


@pytest.mark.parametrize("diff_dirpath,", [True, False], ids=["diffdirpath", "samedirpath"])
@pytest.mark.parametrize("train_chk_mode,", [None, True], ids=["defaultchk", "trainchk"])
@pytest.mark.parametrize("ckpt,", ["best", "kth"], ids=["best", "kth"])
@pytest.mark.parametrize("max_depth", [-1, 1], ids=["nomaxdepth", "maxdepth1"])
def test_fts_callback_resume(tmpdir, ckpt_set, recwarn, diff_dirpath: bool, train_chk_mode: Optional[bool], ckpt: str,
max_depth: int):
"""Validate scheduled fine-tuning resumption functions as expected from both 'best' and 'kth'(not-best)
checkpoints in both train/val stage check modes with and without max_depth specified."""
resume_warns = copy(EXPECTED_WARNS)
fts_callback, resume_callbacks, trainer = ckpt_resume_launch(ckpt_set_fixture=ckpt_set, diff_dirpath=diff_dirpath,
ckpt=ckpt, save_on_train_epoch_end=train_chk_mode,
max_depth=max_depth, tmpdir=tmpdir)
# note if save_on_train_epoch_end is set to `None` then it will be False by default
expected_state = EXPECTED_RESUME_RESULTS[
(
@@ -1346,9 +1386,9 @@ def test_fts_callback_resume(
)
]
assert trainer.checkpoint_callback.best_ckpt_depth == expected_state[0]
assert finetuningscheduler_callback.depth_remaining == expected_state[1]
assert finetuningscheduler_callback.curr_depth == expected_state[2]
assert finetuningscheduler_callback.curr_depth == finetuningscheduler_callback.max_depth
assert fts_callback.depth_remaining == expected_state[1]
assert fts_callback.curr_depth == expected_state[2]
assert fts_callback.curr_depth == fts_callback.max_depth
if not diff_dirpath:
resume_warns.append(EXPECTED_DIRPATH)
# ensure no unexpected warnings detected
@@ -2563,6 +2603,32 @@ def test_fts_zero_opt_support(monkeypatch, tmpdir, strategy, enf_p0):
trainer.fit(model)


@pytest.mark.parametrize(
"constant_loss, expected_state",
[(True, (5, 0, 3, 1, 1, 0)), (False, (6, 2, 1, 1, 1, 0))],
ids=["constant_loss", "normal_loss"],
)
def test_fts_constant_loss(tmpdir, constant_loss: bool, expected_state: Tuple):
"""Validate scheduled fine-tuning works as expected in edge cases where the monitored loss value is constant
across multiple depths, exercising the logic necessary to disambiguate the current best checkpoint metadata."""
seed_everything(42)
model = ConstantLossBoringModel() if constant_loss else FinetuningSchedulerBoringModel()
callbacks = [
FTSCheckpoint(monitor="val_loss", verbose=True, save_top_k=3),
FinetuningScheduler(),
FTSEarlyStopping(monitor="val_loss", patience=1), # including an extraneous earlystopping callback to test warn
]
trainer = Trainer(default_root_dir=tmpdir, callbacks=callbacks, devices=1, max_epochs=6)
finetuningscheduler_callback = get_fts(trainer)
trainer.fit(model)
assert finetuningscheduler_callback._fts_state._ft_epoch == expected_state[0]
assert finetuningscheduler_callback.depth_remaining == expected_state[1]
assert finetuningscheduler_callback.curr_depth == expected_state[2]
assert len(finetuningscheduler_callback._fts_state._fts_ckpt_metadata['best_ckpt_pgs']) == expected_state[3]
assert finetuningscheduler_callback._fts_state._fts_ckpt_metadata['current_ckpt_depth'] == expected_state[4]
assert finetuningscheduler_callback._fts_state._fts_ckpt_metadata['best_ckpt_depth'] == expected_state[5]

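For readability, the positional `expected_state` tuple asserted above corresponds to the following fields (a hypothetical named view inferred from the assertions, not part of the test suite):

```python
# Hypothetical named view of the positional expected_state tuple asserted above,
# inferred from the assert statements (not part of the test suite itself).
from typing import NamedTuple


class ExpectedConstantLossState(NamedTuple):
    ft_epoch: int            # _fts_state._ft_epoch
    depth_remaining: int     # finetuningscheduler_callback.depth_remaining
    curr_depth: int          # finetuningscheduler_callback.curr_depth
    best_ckpt_pg_count: int  # len(_fts_state._fts_ckpt_metadata['best_ckpt_pgs'])
    current_ckpt_depth: int  # _fts_state._fts_ckpt_metadata['current_ckpt_depth']
    best_ckpt_depth: int     # _fts_state._fts_ckpt_metadata['best_ckpt_depth']


# e.g. the constant-loss case above: ExpectedConstantLossState(5, 0, 3, 1, 1, 0)
```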

@pytest.mark.parametrize(
"epoch_only_cfg, expected_state",
[(True, ((0, 2, 6, 8, 3, 3), "extraneous EarlyS", "maximum phase-specified")), (False, (None, "missing a max_"))],
