diff --git a/docs/source/index.rst b/docs/source/index.rst
index 8a2124c..c6bff16 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -20,7 +20,7 @@ foundation model experimentation with flexible fine-tuning schedules. Training w
    If you're exploring using the :class:`~finetuning_scheduler.fts.FinetuningScheduler`, this is a great place
    to start!
    You may also find the `notebook-based tutorial <https://pytorch-lightning.readthedocs.io/en/stable/notebooks/lightning_examples/finetuning-scheduler.html>`_
-   useful and for those using the :doc:`LightningCLI<cli/lightning_cli>`, there is a
+   useful and for those using the :external+pl:class:`~lightning.pytorch.cli.LightningCLI`, there is a
    :ref:`CLI-based<scheduled-fine-tuning-superglue>` example at the bottom of this introduction.
 
 Setup
@@ -69,6 +69,14 @@ and :class:`~finetuning_scheduler.fts_supporters.FTSCheckpoint` callbacks with
 
     trainer = L.Trainer(callbacks=[FinetuningScheduler()])
 
+.. note::
+    If not provided, FTS will instantiate its callback dependencies
+    (:class:`~finetuning_scheduler.fts_supporters.FTSEarlyStopping` and
+    :class:`~finetuning_scheduler.fts_supporters.FTSCheckpoint`) with default configurations and ``monitor=val_loss``.
+    If the user provides base versions of these dependencies (e.g.
+    :external+pl:class:`~lightning.pytorch.callbacks.early_stopping.EarlyStopping`,
+    :external+pl:class:`~lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint`) the provided configuration of
+    those callbacks will be used to instantiate their FTS analogs instead.
 
 .. _default schedule:
 
@@ -353,8 +361,8 @@ A demonstration of the scheduled fine-tuning callback
 :class:`~finetuning_scheduler.fts.FinetuningScheduler` using the
 `RTE <https://huggingface.co/datasets/viewer/?dataset=super_glue&config=rte>`_ and
 `BoolQ <https://github.com/google-research-datasets/boolean-questions>`_ tasks of the
-`SuperGLUE <https://paperswithcode.com/dataset/superglue>`_ benchmark and the :doc:`LightningCLI<cli/lightning_cli>`
-is available under ``./fts_examples/stable``.
+`SuperGLUE <https://paperswithcode.com/dataset/superglue>`_ benchmark and the
+:external+pl:class:`~lightning.pytorch.cli.LightningCLI` is available under ``./fts_examples/stable``.
 
 Since this CLI-based example requires a few additional packages (e.g. ``transformers``, ``sentencepiece``), you
 should install them using the ``[examples]`` extra:
diff --git a/src/finetuning_scheduler/fts_supporters.py b/src/finetuning_scheduler/fts_supporters.py
index be250a6..76e94c2 100644
--- a/src/finetuning_scheduler/fts_supporters.py
+++ b/src/finetuning_scheduler/fts_supporters.py
@@ -20,6 +20,7 @@
 import logging
 import os
 import pathlib
+import inspect
 import re
 import warnings
 from abc import ABC, abstractmethod
@@ -30,6 +31,7 @@
 from functools import reduce
 from pprint import pformat
 from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
+from typing_extensions import TypeAlias
 
 import lightning.pytorch as pl
 import torch
@@ -51,7 +53,8 @@
 from torch.nn import Module
 
 from finetuning_scheduler.strategy_adapters.fsdp import FSDPStrategyAdapter, StrategyAdapter
-from finetuning_scheduler.types import FTSLRSchedulerType, FTSLRSchedulerTypeTuple, ParamGroupAddable
+from finetuning_scheduler.types import (FTSLRSchedulerType, FTSLRSchedulerTypeTuple, ParamGroupAddable,
+                                        BaseCallbackDepType)
 
 log = logging.getLogger(__name__)
 
@@ -472,6 +475,8 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
             self.best_model_path = state_dict["best_model_path"]
 
 
+FTSCallbackDepType: TypeAlias = Union[Type[FTSEarlyStopping], Type[FTSCheckpoint]]
+
 class UniqueKeyLoader(yaml.SafeLoader):
     """Alters SafeLoader to enable duplicate key detection by the SafeConstructor."""
 
@@ -1822,6 +1827,40 @@ def _reorder_callback_by_type(callbacks: List[Callback], target_callback: type)
         other_callbacks = [c for c in callbacks if not isinstance(c, target_callback)]
         return other_callbacks + target_callbacks
 
+    @staticmethod
+    def _extract_base_callback_cfg(trainer: "pl.Trainer", callback_type: BaseCallbackDepType) -> Dict:
+        """Extracts the configuration of a user-provided.
+
+        :external+pl:class:`~lightning.pytorch.callbacks.early_stopping.EarlyStopping` or
+        :external+pl:class:`~lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint` callback to enable the
+        subsequent instantiation of a fine-tuning schedule-capable FTS analog with a similar configuration.
+
+        Args:
+            trainer (pl.Trainer): The :external+pl:class:`~lightning.pytorch.trainer.trainer.Trainer` object.
+            callback_type (BaseCallbackDepType): The type of base callback from which to extract the configuration.
+
+        Returns:
+            Dict: The extracted user-provided callback configuration.
+        """
+        base_callback  = [c for c in trainer.callbacks if isinstance(c, callback_type)][0]
+        base_callback_params = dict(inspect.signature(base_callback.__init__).parameters)
+        return {k: v for k, v in base_callback.__dict__.items() if k in base_callback_params}
+
+    @staticmethod
+    def _add_fts_callback(trainer: "pl.Trainer", fts_cls: FTSCallbackDepType, cfg: Dict) -> None:
+        """Adds a fine-tuning schedule-capable FTS callback dependency with a specified configuration.
+
+        Args:
+            trainer (pl.Trainer): The :external+pl:class:`~lightning.pytorch.trainer.trainer.Trainer` object.
+            fts_cls (FTSCallbackDepType): The type of FTS callback dependency to instantiate.
+            cfg (Dict): The desired FTS callback configuration.
+        """
+        if cfg.get("monitor", None) is None:
+            cfg["monitor"] = "val_loss"
+            rank_zero_warn(f"No monitor metric specified for {fts_cls.__class__.__name__},"
+                           " using 'val_loss' as default.")
+        trainer.callbacks.append(fts_cls(**cfg))
+
     def _callback_dep_setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
         """Ensures all :class:`~finetuning_scheduler.fts.FinetuningScheduler` callback dependencies are met, adding
         and configuring them if necessary.
@@ -1858,7 +1897,7 @@ def _configure_callback_deps(self, trainer: "pl.Trainer") -> Tuple[List[Callback
             Bool: Whether a :class:`~finetuning_scheduler.fts_supporters.FTSCheckpoint` callback was added
         """
         has_ckpt_fts, has_ckpt_base, has_es_fts, has_es_base, has_lr_monitor = self._inspect_callback_deps(trainer)
-        added_ckpt_fts, added_es_fts = False, False
+        added_ckpt_fts, added_es_fts, added_ckpt_fts_kwargs, added_es_fts_kwargs = False, False, {}, {}
         if not any([has_es_fts, self.epoch_transitions_only, self.gen_ft_sched_only]):  # type: ignore[attr-defined]
             if has_es_base:
                 rank_zero_warn(
@@ -1866,13 +1905,14 @@ def _configure_callback_deps(self, trainer: "pl.Trainer") -> Tuple[List[Callback
                     "capable EarlyStopping callback such as FTSEarlyStopping. Substituting current "
                     "EarlyStopping for FTSEarlyStopping"
                 )
+                added_es_fts_kwargs = CallbackDepMixin._extract_base_callback_cfg(trainer, EarlyStopping)
                 trainer.callbacks = [c for c in trainer.callbacks if not isinstance(c, EarlyStopping)]
             else:
                 rank_zero_warn(
                     f"{self.__class__.__name__} currently depends upon an FTSEarlyStopping callback unless configured "
                     "in epoch_transitions_only mode. Adding an FTSEarlyStopping callback with default configuration."
                 )
-            trainer.callbacks.append(FTSEarlyStopping(monitor="val_loss"))
+            CallbackDepMixin._add_fts_callback(trainer, FTSEarlyStopping, added_es_fts_kwargs)
             added_es_fts = True
         if (has_es_fts or has_es_base) and self.epoch_transitions_only:  # type: ignore[attr-defined]
             rank_zero_warn(
@@ -1887,8 +1927,9 @@ def _configure_callback_deps(self, trainer: "pl.Trainer") -> Tuple[List[Callback
                     "capable ModelCheckpoint callback such as FTSCheckpoint. Substituting current "
                     "ModelCheckpoint for FTSCheckpoint"
                 )
+                added_ckpt_fts_kwargs = CallbackDepMixin._extract_base_callback_cfg(trainer, ModelCheckpoint)
                 trainer.callbacks = [c for c in trainer.callbacks if not isinstance(c, ModelCheckpoint)]
-            trainer.callbacks.append(FTSCheckpoint(monitor="val_loss", verbose=True))
+            CallbackDepMixin._add_fts_callback(trainer, FTSCheckpoint, added_ckpt_fts_kwargs)
             added_ckpt_fts = True
         for uc in [c for c in trainer.callbacks if any([isinstance(c, d) for d in CALLBACK_DEP_PARENTS.values()])]:
             uc.connect_callback(trainer)  # type: ignore[attr-defined]
diff --git a/src/finetuning_scheduler/types.py b/src/finetuning_scheduler/types.py
index 70ab581..ac135f1 100644
--- a/src/finetuning_scheduler/types.py
+++ b/src/finetuning_scheduler/types.py
@@ -17,9 +17,11 @@
 
 """
 from typing import Any, Dict, Protocol, runtime_checkable, Type, Union
+from typing_extensions import TypeAlias
 
 import torch
 from lightning.fabric.utilities.types import _TORCH_LRSCHEDULER, Optimizable, ReduceLROnPlateau
+from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
 
 
 @runtime_checkable
@@ -45,3 +47,5 @@ def add_param_group(self, param_group: Dict[Any, Any]) -> None:
 ]
 FTSLRSchedulerTypeTuple = tuple(getattr(torch.optim.lr_scheduler, lr_class) for lr_class in supported_lrs)
 FTSLRSchedulerType = Union[Type[_TORCH_LRSCHEDULER], Type[ReduceLROnPlateau]]
+
+BaseCallbackDepType: TypeAlias = Union[Type[EarlyStopping], Type[ModelCheckpoint]]
diff --git a/tests/test_finetuning_scheduler_callback.py b/tests/test_finetuning_scheduler_callback.py
index 0fa496a..7c50911 100644
--- a/tests/test_finetuning_scheduler_callback.py
+++ b/tests/test_finetuning_scheduler_callback.py
@@ -26,7 +26,8 @@
 from lightning.fabric.utilities import rank_zero_only
 from lightning.fabric.utilities.cloud_io import get_filesystem
 from lightning.pytorch import LightningModule, seed_everything, Trainer
-from lightning.pytorch.callbacks import Callback, EarlyStopping, LearningRateFinder, LearningRateMonitor
+from lightning.pytorch.callbacks import (Callback, EarlyStopping, LearningRateFinder, LearningRateMonitor,
+                                         ModelCheckpoint)
 from lightning.pytorch.strategies import StrategyRegistry
 from lightning.pytorch.strategies.single_device import SingleDeviceStrategy
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
@@ -164,6 +165,7 @@ def __init__(
         weight_decay: float = 1.0e-06,
         init_lr_key: str = None,
         p0_params: Optional[List] = None,
+        monitor_metric: str = None,
     ):
         super().__init__()
         self.layer = nn.Sequential(nn.Linear(32, 32), nn.Linear(32, 32), nn.Linear(32, 32), nn.Linear(32, 2))
@@ -175,6 +177,7 @@ def __init__(
         self.weight_decay = weight_decay
         self.init_lr_key = init_lr_key
         self.p0_params = p0_params
+        self.monitor_metric = monitor_metric or "val_loss"
 
     def training_step(self, batch, batch_idx: int):
         loss = self.step(batch)
@@ -193,7 +196,7 @@ def validation_step(self, batch, batch_idx):
         self.validation_step_outputs.append(loss)
         # we would normally use sync_dist for epoch-only logging in a distributed context but leaving it `False` here
         # to test FTS transition behavior when the test model is used in a distributed context
-        self.log("val_loss", loss, prog_bar=False)
+        self.log(self.monitor_metric, loss, prog_bar=False)
         return {"x": loss}
 
     def on_validation_epoch_end(self):
@@ -1234,12 +1237,10 @@ def test_fts_decay(tmpdir, boring_ft_schedule, explicit_mode: bool, nodecay_mode
     (False, True, "kth", 1): (0, 0, 1),
 }
 EXPECTED_WARNS = [
-    "does not have many workers",
-    "GPU available but",
-    "`max_epochs` was not",
-    "The dirpath has changed from",
-    #"reduce_op is deprecated",  # warning caused upstream
-    #"`pydantic.config.Extra` is deprecated",
+    "does not have many workers",  # required for all PyTorch/Lightning versions
+    "GPU available but",  # required for all PyTorch/Lightning versions
+    "`max_epochs` was not",  # required for all PyTorch/Lightning versions
+    "The dirpath has changed from",  # required for all PyTorch/Lightning versions
 ]
 EXPECTED_TRAIN_CHK_WARNS = []
 EXPECTED_DIRPATH = ""
@@ -1294,10 +1295,6 @@ def test_fts_callback_resume(
 
 DYNAMO_EXPECTED_WARNS = [
     "Final phase max_transition_epoch",
-    # using different callbacks for now to avoid creating another fixture with limited utility
-    # "Be aware that when using `ckpt_path`, callbacks used",
-    # "Your compiler for AOTAutograd is returning",  # out of initial scope
-    #"tensor cores for float32 matrix multiplication available",  # out of initial scope
 ]
 
 
@@ -2036,17 +2033,29 @@ def on_train_epoch_start(self, trainer, pl_module) -> None:
 
 
 @pytest.mark.parametrize(
-    "callbacks, dist_mode, expected",
+    "callbacks, cust_monitor, dist_mode, expected",
     [
-        ([FinetuningScheduler()], None, ("an FTSEarlyStopping", "as FTSCheck")),
-        ([FinetuningScheduler(), FTSEarlyStopping(monitor="val_loss", patience=1)], None, ("FTSCheckpoint. Subs")),
+        ([FinetuningScheduler()], None, None, ("an FTSEarlyStopping", "as FTSCheck")),
+        ([FinetuningScheduler(), FTSEarlyStopping(monitor="val_loss", patience=1)], None, None,
+         ("FTSCheckpoint. Subs")),
         (
             [FinetuningScheduler(), EarlyStopping(monitor="val_loss", patience=1)],
+            None, None,
+            ("Stopping. Sub", "Checkpoint. Sub"),
+        ),
+        (
+            [FinetuningScheduler(), EarlyStopping(monitor="abc_val_loss", patience=1),
+             ModelCheckpoint(monitor="abc_val_loss", verbose=True)], 'abc_val_loss',
             None,
             ("Stopping. Sub", "Checkpoint. Sub"),
         ),
         (
-            [FinetuningScheduler(), FTSCheckpoint(monitor="val_loss", verbose=True)],
+            [FinetuningScheduler(), ModelCheckpoint(verbose=True)], None,
+            None,
+            ("Adding an FTSEarlyStopping", "Checkpoint. Sub", "No monitor metric specified"),
+        ),
+        (
+            [FinetuningScheduler(), FTSCheckpoint(monitor="val_loss", verbose=True)], None,
             None,
             ("Adding an FTSEarlyStopping",),
         ),
@@ -2055,19 +2064,19 @@ def on_train_epoch_start(self, trainer, pl_module) -> None:
                 MockDistFTS(),
                 FTSCheckpoint(monitor="val_loss", verbose=True),
                 FTSEarlyStopping(monitor="val_loss", patience=1),
-            ],
+            ], None,
             "ddp",
             ("not being synchronized",),
         ),
     ],
-    ids=["default", "nondef_es", "def_es", "nondef_ftsckpt", "no_sync"],
+    ids=["default", "nondef_es", "def_es", "extract_base_callback_cfg", "missing_monitor"," nondef_ftsckpt", "no_sync"],
 )
 def test_fts_callback_warns(
-    tmpdir, recwarn, callbacks: List[Callback], dist_mode: str, expected: Tuple[str]
+    tmpdir, recwarn, callbacks: List[Callback], cust_monitor: Optional[str], dist_mode: str, expected: Tuple[str]
 ):
     """Validate :class:`~finetuning_scheduler.FinetuningScheduler` warnings that require a
     :class:`~pytorch_lighting.trainer.Trainer` to be defined are properly issued"""
-    model = FinetuningSchedulerBoringModel()
+    model = FinetuningSchedulerBoringModel(monitor_metric=cust_monitor)
     dist_args = {"strategy": dist_mode, "accelerator": "cpu", "devices": "auto"} if dist_mode else {"devices": 1}
     trainer = Trainer(default_root_dir=tmpdir, callbacks=callbacks, **dist_args)
     trainer.fit(model)