diff --git a/docs/source/index.rst b/docs/source/index.rst index 8ef3e2f..c16a08a 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -185,6 +185,16 @@ either integers or convertible to integers via ``int()``. ``0`` of the current fine-tuning schedule. This auto-configuration can be disabled if desired by setting :paramref:`~finetuning_scheduler.fts.FinetuningScheduler.enforce_phase0_params` to ``False``. +.. note:: + + When freezing ``torch.nn.modules.batchnorm._BatchNorm`` modules, Lightning by default disables + ``BatchNorm.track_running_stats``. To override this behavior so that even frozen ``BatchNorm`` layers continue to + have ``track_running_stats`` set to ``True``, set the FTS parameter + :paramref:`~finetuning_scheduler.fts.FinetuningScheduler.frozen_bn_track_running_stats` to ``True``. + Beginning with FTS ``2.4.0``, + :paramref:`~finetuning_scheduler.fts.FinetuningScheduler.frozen_bn_track_running_stats` will default to ``True``. + + EarlyStopping and Epoch-Driven Phase Transition Criteria ******************************************************** diff --git a/src/finetuning_scheduler/fts.py b/src/finetuning_scheduler/fts.py index c47f2ea..b0a26b3 100644 --- a/src/finetuning_scheduler/fts.py +++ b/src/finetuning_scheduler/fts.py @@ -20,6 +20,7 @@ from copy import deepcopy from pprint import pformat from typing import Any, Dict, Optional, Sequence, Tuple, Union +from typing_extensions import override import lightning.pytorch as pl import torch @@ -82,6 +83,12 @@ class FinetuningScheduler(ScheduleImplMixin, ScheduleParsingMixin, CallbackDepMi :class:`~finetuning_scheduler.fts_supporters.FTSCheckpoint` or :class:`~finetuning_scheduler.fts_supporters.FTSEarlyStopping` callback instances. + .. note:: + + While :class:`~finetuning_scheduler.fts.FinetuningScheduler` supports the use of + :external+torch:class:`~torch.distributed.optim.ZeroRedundancyOptimizer`, setting ``overlap_with_ddp`` to + ``True`` is not supported because that optimizer mode only supports a single parameter group. + .. note:: While :class:`~finetuning_scheduler.fts.FinetuningScheduler` supports the use of @@ -107,6 +114,7 @@ def __init__( apply_lambdas_new_pgs: bool = False, logging_level: int = logging.INFO, enforce_phase0_params: bool = True, + frozen_bn_track_running_stats: bool = False, ): r""" Arguments used to define and configure a scheduled fine-tuning training session: @@ -229,6 +237,11 @@ def __init__( and present in the optimizer differs from the parameters specified in phase 0. Only the parameters included in the optimizer are affected; the choice of optimizer, lr_scheduler etc. remains unaltered. Defaults to ``True``. + frozen_bn_track_running_stats: When freezing ``torch.nn.modules.batchnorm._BatchNorm`` layers, whether + :class:`~finetuning_scheduler.fts.FinetuningScheduler` should set ``BatchNorm`` ``track_running_stats`` + to ``True``. Setting this to ``True`` overrides the default Lightning behavior that sets + ``BatchNorm`` ``track_running_stats`` to ``False`` when freezing ``BatchNorm`` layers. Defaults to + ``False`` for backwards compatibility. The default will be ``True`` with FTS >= 2.4.0. Attributes: _fts_state: The internal :class:`~finetuning_scheduler.fts.FinetuningScheduler` state.
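A minimal usage sketch of the new option (the schedule file name below is a hypothetical placeholder, not part of this PR); only the callback wiring is shown, after which you would call ``trainer.fit(...)`` on a ``LightningModule`` containing ``BatchNorm`` layers:

```python
# Hypothetical example: opt in today to the planned FTS >= 2.4.0 default by passing
# `frozen_bn_track_running_stats=True`, so frozen `BatchNorm` layers keep updating their
# running statistics while their parameters remain frozen.
from lightning.pytorch import Trainer
from finetuning_scheduler import FinetuningScheduler, FTSCheckpoint, FTSEarlyStopping

fts = FinetuningScheduler(
    ft_schedule="my_ft_schedule.yaml",   # hypothetical fine-tuning schedule file
    frozen_bn_track_running_stats=True,  # keep frozen BatchNorm layers tracking running stats
)
trainer = Trainer(
    callbacks=[fts, FTSEarlyStopping(monitor="val_loss"), FTSCheckpoint(monitor="val_loss")],
)
# trainer.fit(your_bn_module)  # any LightningModule containing BatchNorm layers
```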
@@ -255,7 +268,9 @@ def __init__( self.allow_untested = allow_untested self.apply_lambdas_new_pgs = apply_lambdas_new_pgs self.enforce_phase0_params = enforce_phase0_params + self.frozen_bn_track_running_stats = frozen_bn_track_running_stats self._has_reinit_schedule = False + self._msg_cache = set() rz_logger = logging.getLogger("lightning.pytorch.utilities.rank_zero") rz_logger.setLevel(logging_level) @@ -292,6 +307,7 @@ def _supported_strategy_flags() -> Sequence[str]: # "deepspeed", # relevant FTS strategy adapter not yet available, PRs welcome! ) + @override def freeze_before_training(self, pl_module: "pl.LightningModule") -> None: """Freezes all model parameters so that parameter subsets can be subsequently thawed according to the fine- tuning schedule. @@ -300,7 +316,11 @@ def freeze_before_training(self, pl_module: "pl.LightningModule") -> None: pl_module (:external+pl:class:`~lightning.pytorch.core.module.LightningModule`): The target :external+pl:class:`~lightning.pytorch.core.module.LightningModule` to freeze parameters of """ - self.freeze(modules=pl_module, train_bn=False) + # We avoid overriding `BaseFinetuning`'s `freeze` and `freeze_module` methods at the small marginal cost + # of conditionally revisiting `BatchNorm` layers to set `track_running_stats` to `True` when we are in + # `frozen_bn_track_running_stats` mode. + BaseFinetuning.freeze(modules=pl_module, train_bn=False) + self.strategy_adapter._module_specific_freezing(modules=pl_module) def step(self) -> None: """Prepare and execute the next scheduled fine-tuning level @@ -387,6 +407,7 @@ def step_pg( else: thaw_layers = {depth: self.ft_schedule[depth]}.items() for i, orig_next_tl in thaw_layers: + self.strategy_adapter._maybe_set_bn_track_running_stats(i) next_tl = deepcopy(orig_next_tl) if i <= depth: next_tl["params"] = self.strategy_adapter.fts_optim_transform(next_tl["params"]) diff --git a/src/finetuning_scheduler/fts_supporters.py b/src/finetuning_scheduler/fts_supporters.py index 93129b2..75c1a68 100644 --- a/src/finetuning_scheduler/fts_supporters.py +++ b/src/finetuning_scheduler/fts_supporters.py @@ -30,7 +30,7 @@ from dataclasses import dataclass, field, fields from functools import reduce from pprint import pformat -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, Set from typing_extensions import TypeAlias import lightning.pytorch as pl @@ -1271,6 +1271,7 @@ class ScheduleImplMixin(ABC): reinit_optim_cfg: Optional[Dict] reinit_lr_cfg: Optional[Dict] max_depth: int + _msg_cache: Set _fts_state: FTSState PHASE_0_DIVERGENCE_MSG = ( "After executing the provided `configure_optimizers` method, the optimizer state differs from the configuration" @@ -1479,6 +1480,7 @@ def thaw_to_depth(self, depth: Optional[int] = None) -> None: depth = depth or self.curr_depth for i, next_tl in self.ft_schedule.items(): # type: ignore[union-attr] if i <= depth: + self.strategy_adapter._maybe_set_bn_track_running_stats(i) _, self._fts_state._curr_thawed_params = self.strategy_adapter.exec_ft_phase( self.pl_module, thaw_pl=self.strategy_adapter.fts_optim_transform(next_tl["params"]) ) @@ -1749,6 +1751,23 @@ def _validate_opt_init(self) -> None: ) rank_zero_warn(w_msg) + def _conditional_warn_once(self, condition: Any, msg: str) -> None: + """A helper function that conditionally issues a warning message only once based on the provided condition + variable. 
Robust to context managers that may prevent warnings.filterwarnings("once") from behaving as + expected. + + Args: + condition (Any): The condition to evaluate for issuing the warning. + msg (str): The warning message to display. + + Returns: + None + """ + if not bool(condition) or msg in self._msg_cache: + return + self._msg_cache.add(msg) + rank_zero_warn(msg) + class CallbackDepMixin(ABC): """Functionality for validating/managing callback dependencies.""" diff --git a/src/finetuning_scheduler/strategy_adapters/base.py b/src/finetuning_scheduler/strategy_adapters/base.py index 1ce3dcf..4bfdea3 100644 --- a/src/finetuning_scheduler/strategy_adapters/base.py +++ b/src/finetuning_scheduler/strategy_adapters/base.py @@ -18,15 +18,17 @@ """ from functools import partialmethod from pprint import pformat as pfmt -from typing import Callable, List, Optional, Tuple +from typing import Callable, List, Optional, Tuple, Dict + +import torch from lightning.fabric.utilities import rank_zero_info from lightning.fabric.utilities.types import ReduceLROnPlateau from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.callbacks import Callback +from lightning.pytorch.callbacks import BaseFinetuning from lightning.pytorch.strategies.strategy import Strategy from lightning.pytorch.utilities.rank_zero import rank_zero_debug -from torch.nn import Module class StrategyAdapter: @@ -51,6 +53,14 @@ class StrategyAdapter: """ fts_handle: Callback + _ft_schedule_module_map: Dict + _unscheduled_params: List + + FROZEN_BN_DEFAULT_WARN = ( # TODO: remove warning with release of FTS 2.4.0 + "Starting with the next minor release of FTS (2.4.0), the default value for `frozen_bn_track_running_stats`" + " will change to `True`. To retain the current `track_running_stats` `False` behavior with FTS >= 2.4.0, frozen" + " `BatchNorm` layers like those in this model will require setting `frozen_bn_track_running_stats` to `False`." + ) def __init__(self) -> None: """The default fine-tuning phase execution function is set on @@ -106,6 +116,9 @@ def on_after_init_fts(self) -> None: """Hook executed in :class:`~finetuning_scheduler.fts.FinetuningScheduler` setup immediately after :meth:`~finetuning_scheduler.fts_supporters.ScheduleImplMixin.init_fts`. 
""" + self._gen_ft_sched_module_map() + self.scheduled_mod_lists = [list(self._ft_schedule_module_map[d]) for d in self._ft_schedule_module_map.keys()] + self._maybe_set_bn_track_running_stats(0) _, self.fts_handle._fts_state._curr_thawed_params = self.exec_ft_phase( self.pl_module, thaw_pl=self.fts_optim_transform(self.fts_handle.ft_schedule[0]["params"]), @@ -160,6 +173,26 @@ def logical_param_translation(self, param_names: List) -> List: """ return param_names + def _gen_ft_sched_module_map(self) -> None: + """Generate a module-level mapping of the modules associated with each fine-tuning phase, including modules + not present in the fine-tuning schedule grouped together into a single unscheduled phase to facilitate the + relevant disjointness check.""" + assert isinstance(self.fts_handle.ft_schedule, Dict) + module_map: Dict = {} + for depth in self.fts_handle.ft_schedule.keys(): # type: ignore[union-attr] + phase_params = self.fts_handle.ft_schedule[depth].get("params", []) # type: ignore[union-attr] + module_map[depth] = set() + for p in phase_params: + module_map[depth].add(p.rpartition(".")[0]) + self._ft_schedule_module_map = module_map + scheduled_mods = list(set().union(*module_map.values())) + unscheduled_mods = tuple( + n for n, m in self.pl_module.named_modules() if n not in scheduled_mods and m._parameters + ) + self._unscheduled_params = [ + f"{m}.{n}" for m in unscheduled_mods for n, _ in self.pl_module.get_submodule(m).named_parameters() + ] + @staticmethod def _clean_optim_lr_pgs(trainer: Trainer) -> List: """Delete existing param groups from an optimizer that was found to be misaligned with respect to phase 0 @@ -246,8 +279,8 @@ def phase0_optimizer_override(self) -> None: @staticmethod def base_ft_phase( - module: Module, thaw_pl: List, translation_func: Optional[Callable] = None, init_thaw: bool = False - ) -> Tuple[List, List]: + module: torch.nn.Module, thaw_pl: List, translation_func: Optional[Callable] = None, init_thaw: bool = False) \ + -> Tuple[List, List]: """Thaw/unfreeze the provided list of parameters in the provided :class:`~torch.nn.Module` Args: @@ -281,4 +314,61 @@ def base_ft_phase( ) return thawed_p_names, curr_thawed + #################################################################################################################### + # BatchNorm module-specific handling + # (if additional modules require special handling, these will be refactored to accommodate a more generic + # dispatching pattern for module-specific handling) + #################################################################################################################### + + def _module_specific_freezing(self, modules: torch.nn.Module) -> None: + """Orchestrates module-specific freezing behavior. Currently only. + + :external+torch:class:`~torch.nn.modules.batchnorm._BatchNorm` layers require special handling. Running + statistics tracking for frozen `BatchNorm` layers is conditionally re-enabled here based on the + `frozen_bn_track_running_stats` flag. + + Args: + modules (torch.nn.Module): The modules for which the `BatchNorm` layer running statistics should be enabled. 
+ Returns: + None + """ + if self.fts_handle.frozen_bn_track_running_stats: + rank_zero_info("Since `frozen_bn_track_running_stats` is currently set to `True`, FinetuningScheduler" + " will set `track_running_stats` to `True` for all `BatchNorm` layers.") + modules = BaseFinetuning.flatten_modules(modules) # type: ignore[assignment] + for mod in modules: + if isinstance(mod, torch.nn.modules.batchnorm._BatchNorm): + mod.track_running_stats = True + + def _maybe_set_bn_track_running_stats(self, schedule_phase: int) -> None: + """Enable `track_running_stats` for :external+torch:class:`~torch.nn.modules.batchnorm._BatchNorm` modules + that may require it based on `frozen_bn_track_running_stats` and a given schedule phase. + + Args: + schedule_phase (int): The phase of the schedule to evaluate. + + Returns: + None + """ + if not self.fts_handle.frozen_bn_track_running_stats: + target_bn_modules = self._get_target_bn_modules(schedule_phase) + self.fts_handle._conditional_warn_once(target_bn_modules, self.FROZEN_BN_DEFAULT_WARN) + for _, m in target_bn_modules: + m.track_running_stats = True + + def _get_target_bn_modules(self, schedule_phase: int) -> List: + """Enumerate the :external+torch:class:`~torch.nn.modules.batchnorm._BatchNorm` modules for a given + schedule phase. + + Args: + schedule_phase (int): The phase of the schedule to evaluate. + + Returns: + List[Tuple[str, torch.nn.modules.batchnorm._BatchNorm]]: A list of tuples containing the names and instances + of `BatchNorm` modules associated with a given schedule phase. + """ + return [(n, m) for n, m in self.pl_module.named_modules() if + n in self.scheduled_mod_lists[schedule_phase] and + isinstance(m, torch.nn.modules.batchnorm._BatchNorm)] + fts_optim_inspect = partialmethod(fts_optim_transform, inspect_only=True) diff --git a/src/finetuning_scheduler/strategy_adapters/fsdp.py b/src/finetuning_scheduler/strategy_adapters/fsdp.py index 25f4b1d..51f2a65 100644 --- a/src/finetuning_scheduler/strategy_adapters/fsdp.py +++ b/src/finetuning_scheduler/strategy_adapters/fsdp.py @@ -28,6 +28,7 @@ from functools import partial, partialmethod, wraps from pprint import pformat from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Set, Tuple, Union +from typing_extensions import override import torch from lightning.fabric.strategies.fsdp import _get_full_state_dict_context, _setup_activation_checkpointing @@ -211,6 +212,8 @@ def on_after_init_fts(self) -> None: """To accommodate FSDP, we defer executing the first fine-tuning phase that would otherwise be executed in this hook, which fires in :class:`~finetuning_scheduler.fts.FinetuningScheduler` setup immediately after :meth:`~finetuning_scheduler.fts_supporters.ScheduleImplMixin.init_fts`""" + self._gen_ft_sched_module_map() + self.scheduled_mod_lists = [list(self._ft_schedule_module_map[d]) for d in self._ft_schedule_module_map.keys()] def on_before_fts_fit_start(self) -> None: """In this hook executed immediately before the :class:`~finetuning_scheduler.fts.FinetuningScheduler` @@ -539,8 +542,7 @@ def _validate_fsdp_phases_disjoint(self) -> Tuple: feedback_nonerrors.append(has_no_local_shards) fsdp_dup_params = set() unsched_dup_params = set() - scheduled_mod_lists = [list(self._ft_schedule_module_map[d]) for d in self._ft_schedule_module_map.keys()] - ft_sched_dup_mods = FSDPStrategyAdapter._phasewise_intersection(scheduled_mod_lists) + ft_sched_dup_mods = FSDPStrategyAdapter._phasewise_intersection(self.scheduled_mod_lists) fsdp_dup_params = 
self._phase_unaligned_fsdp_params() if not fsdp_dup_params: # unsched_dup_params will be a superset of fsdp_dup_params unsched_dup_params = self._phase_unaligned_fsdp_params(check_unsched=True) @@ -715,29 +717,9 @@ def _fts_auto_configure_model(self) -> None: "The provided model is already wrapped by FSDP. Cannot apply an FSDP auto-wrapping policy along" " fine-tuning schedule phase boundaries if the model is already wrapped." ) - self._gen_ft_sched_module_map() self._fts_auto_wrap() self._after_configure_model() - def _gen_ft_sched_module_map(self) -> None: - """Generate a module-level mapping of the modules associated with each fine-tuning phase, including modules - not present in the fine-tuning schedule grouped together into a single unscheduled phase to facilitate the - relevant disjointness check.""" - assert isinstance(self.fts_handle.ft_schedule, Dict) - module_map: Dict = {} - for depth in self.fts_handle.ft_schedule.keys(): # type: ignore[union-attr] - phase_params = self.fts_handle.ft_schedule[depth].get("params", []) # type: ignore[union-attr] - module_map[depth] = set() - for p in phase_params: - module_map[depth].add(p.rpartition(".")[0]) - self._ft_schedule_module_map = module_map - scheduled_mods = list(set().union(*module_map.values())) - unscheduled_mods = tuple( - n for n, m in self.pl_module.named_modules() if n not in scheduled_mods and m._parameters - ) - self._unscheduled_params = [ - f"{m}.{n}" for m in unscheduled_mods for n, _ in self.pl_module.get_submodule(m).named_parameters() - ] def _fts_auto_wrap(self) -> None: """Apply the provided ``auto_wrap_policy`` within a context-manager that composes any ``awp_overrides`` @@ -769,6 +751,7 @@ def _after_configure_model(self) -> None: the previously deferred first fine-tuning phase.""" assert isinstance(self.fts_handle.ft_schedule, Dict) # TODO: move/consolidate ft_schedule assertions self._init_fsdp_param_map() + self._maybe_set_bn_track_running_stats(0) _, self.fts_handle._fts_state._curr_thawed_params = self.exec_ft_phase( self.pl_module, thaw_pl=self.fts_optim_transform(self.fts_handle.ft_schedule[0]["params"]), @@ -798,7 +781,6 @@ def _wrapped_configure_model(self, csm_func: Callable) -> Callable: @wraps(csm_func) def wrapped_func() -> None: - self._gen_ft_sched_module_map() csm_func() self._after_configure_model() @@ -833,4 +815,27 @@ def _enable_name_based_overrides(self) -> Generator: finally: _ConfigAutoWrap.kwargs["auto_wrap_policy"] = auto_wrap_policy_handle + @override + def _get_target_bn_modules(self, schedule_phase: int) -> List: + """Enumerate the :external+torch:class:`~torch.nn.modules.batchnorm._BatchNorm` modules for a given + schedule phase. + + Args: + schedule_phase (int): The phase of the schedule to evaluate. + + Returns: + List[Tuple[str, torch.nn.modules.batchnorm._BatchNorm]]: A list of tuples containing the names and + (possibly FSDP wrapped) instances of `BatchNorm` modules associated with a given schedule phase. 
+ """ + target_bn_modules = [] + for m_name in self.scheduled_mod_lists[schedule_phase]: + mod = self.pl_module.get_submodule(m_name) + if isinstance(mod, torch.nn.modules.batchnorm._BatchNorm): + target_bn_modules.append((m_name, mod)) + # TODO: once 2.0 is no longer supported, switch to using FSDP_WRAPPED_MODULE constant here + elif orig_mod := getattr(mod, '_fsdp_wrapped_module', None): + if isinstance(orig_mod, torch.nn.modules.batchnorm._BatchNorm): + target_bn_modules.append((m_name, orig_mod)) + return target_bn_modules + fts_optim_inspect = partialmethod(fts_optim_transform, inspect_only=True) diff --git a/tests/fsdp_expected_paths.py b/tests/fsdp_expected_paths.py new file mode 100644 index 0000000..cbe5c79 --- /dev/null +++ b/tests/fsdp_expected_paths.py @@ -0,0 +1,224 @@ +from enum import Enum, auto + + +class AutoStrEnum(Enum): + def _generate_next_value_(name, start, count, last_values) -> str: # type: ignore + return name + +class ResultEnum(AutoStrEnum): + """Characterization of an expected result value based on a test sample transformation or approximation.""" + default = auto() + nondefault = auto() + +# expected training path aliases +path_default = {0: (2, 4), 1: (6, 12), 2: (7, 14)} +path_default_orig = {0: (4, 4), 1: (12, 12), 2: (14, 14)} +path_default_orig_eo_dyn = {0: (4, 4), 1: (12, 12), 2: (14, 14), 3: (14, 14)} +path_ignore_p_uo = {0: (4, 4), 1: (12, 12), 2: (14, 14)} +path_8_14 = {0: (2, 4), 1: (7, 12), 2: (8, 14)} +path_8_16 = {0: (4, 8), 1: (7, 14), 2: (8, 16)} +path_5_10 = {0: (2, 4), 1: (3, 6), 2: (5, 10)} +path_ext_7_14 = {0: (2, 4), 1: (2, 4), 2: (6, 12), 3: (6, 12), 4: (7, 14)} +path_ext_8_16 = {0: (3, 6), 1: (7, 14), 2: (8, 16)} +path_optimlr_reinit = {0: (2, 4, "SGD", 0, 0.1), 1: (6, 12, "Adam", 32, 0.00021), 2: (7, 14, "SGD", 64, 0.002)} +lrs_path_default = {0: (0.1,), 1: (0.07, 1e-06), 2: (0.049, 7e-07, 1e-05)} +lrs_path_optimlr_reinit = {0: (0.1,), 1: (0.00021, 1e-06), 2: (0.002, 1e-06, 3e-06)} + + +path_bn_track_false = { + 0: ( + { + 8: { + "layer_fqn": "layer._fsdp_wrapped_module.2._fsdp_wrapped_module", + "track_running_stats": False, + "training": True, + "running_mean": ResultEnum.default, + "running_var": ResultEnum.default, + "num_batches_tracked": 0, + "requires_grad": False, + }, + 16: { + "layer_fqn": "layer._fsdp_wrapped_module.6._fsdp_wrapped_module", + "track_running_stats": True, + "training": True, + "running_mean": ResultEnum.default, + "running_var": ResultEnum.default, + "num_batches_tracked": 0, + "requires_grad": True, + }, + }, + 4, + 8, + ), + 1: ( + { + 8: { + "layer_fqn": "layer._fsdp_wrapped_module.2._fsdp_wrapped_module", + "track_running_stats": True, + "training": True, + "running_mean": ResultEnum.default, + "running_var": ResultEnum.default, + "num_batches_tracked": 0, + "requires_grad": True, + }, + 16: { + "layer_fqn": "layer._fsdp_wrapped_module.6._fsdp_wrapped_module", + "track_running_stats": True, + "training": True, + "running_mean": ResultEnum.nondefault, + "running_var": ResultEnum.nondefault, + "num_batches_tracked": 16, + "requires_grad": True, + }, + }, + 8, + 16, + ), + 2: ( + { + 8: { + "layer_fqn": "layer._fsdp_wrapped_module.2._fsdp_wrapped_module", + "track_running_stats": True, + "training": True, + "running_mean": ResultEnum.default, + "running_var": ResultEnum.default, + "num_batches_tracked": 0, + "requires_grad": True, + }, + 16: { + "layer_fqn": "layer._fsdp_wrapped_module.6._fsdp_wrapped_module", + "track_running_stats": True, + "training": True, + "running_mean": ResultEnum.nondefault, + 
"running_var": ResultEnum.nondefault, + "num_batches_tracked": 16, + "requires_grad": True, + }, + }, + 9, + 18, + ), + 3: ( + { + 8: { + "layer_fqn": "layer._fsdp_wrapped_module.2._fsdp_wrapped_module", + "track_running_stats": True, + "training": True, + "running_mean": ResultEnum.nondefault, + "running_var": ResultEnum.nondefault, + "num_batches_tracked": 16, + "requires_grad": True, + }, + 16: { + "layer_fqn": "layer._fsdp_wrapped_module.6._fsdp_wrapped_module", + "track_running_stats": True, + "training": True, + "running_mean": ResultEnum.nondefault, + "running_var": ResultEnum.nondefault, + "num_batches_tracked": 32, + "requires_grad": True, + }, + }, + 9, + 18, + ), +} + +path_bn_track_true = { + 0: ( + { + 8: { + "layer_fqn": "layer._fsdp_wrapped_module.2._fsdp_wrapped_module", + "track_running_stats": True, + "training": True, + "running_mean": ResultEnum.default, + "running_var": ResultEnum.default, + "num_batches_tracked": 0, + "requires_grad": False, + }, + 16: { + "layer_fqn": "layer._fsdp_wrapped_module.6._fsdp_wrapped_module", + "track_running_stats": True, + "training": True, + "running_mean": ResultEnum.default, + "running_var": ResultEnum.default, + "num_batches_tracked": 0, + "requires_grad": True, + }, + }, + 4, + 8, + ), + 1: ( + { + 8: { + "layer_fqn": "layer._fsdp_wrapped_module.2._fsdp_wrapped_module", + "track_running_stats": True, + "training": True, + "running_mean": ResultEnum.nondefault, + "running_var": ResultEnum.nondefault, + "num_batches_tracked": 16, + "requires_grad": True, + }, + 16: { + "layer_fqn": "layer._fsdp_wrapped_module.6._fsdp_wrapped_module", + "track_running_stats": True, + "training": True, + "running_mean": ResultEnum.nondefault, + "running_var": ResultEnum.nondefault, + "num_batches_tracked": 16, + "requires_grad": True, + }, + }, + 8, + 16, + ), + 2: ( + { + 8: { + "layer_fqn": "layer._fsdp_wrapped_module.2._fsdp_wrapped_module", + "track_running_stats": True, + "training": True, + "running_mean": ResultEnum.nondefault, + "running_var": ResultEnum.nondefault, + "num_batches_tracked": 16, + "requires_grad": True, + }, + 16: { + "layer_fqn": "layer._fsdp_wrapped_module.6._fsdp_wrapped_module", + "track_running_stats": True, + "training": True, + "running_mean": ResultEnum.nondefault, + "running_var": ResultEnum.nondefault, + "num_batches_tracked": 16, + "requires_grad": True, + }, + }, + 9, + 18, + ), + 3: ( + { + 8: { + "layer_fqn": "layer._fsdp_wrapped_module.2._fsdp_wrapped_module", + "track_running_stats": True, + "training": True, + "running_mean": ResultEnum.nondefault, + "running_var": ResultEnum.nondefault, + "num_batches_tracked": 32, + "requires_grad": True, + }, + 16: { + "layer_fqn": "layer._fsdp_wrapped_module.6._fsdp_wrapped_module", + "track_running_stats": True, + "training": True, + "running_mean": ResultEnum.nondefault, + "running_var": ResultEnum.nondefault, + "num_batches_tracked": 32, + "requires_grad": True, + }, + }, + 9, + 18, + ), +} diff --git a/tests/test_finetuning_scheduler_callback.py b/tests/test_finetuning_scheduler_callback.py index c53f3ed..3081683 100644 --- a/tests/test_finetuning_scheduler_callback.py +++ b/tests/test_finetuning_scheduler_callback.py @@ -38,7 +38,8 @@ from finetuning_scheduler import CallbackResolverMixin, FinetuningScheduler, FTSCheckpoint, FTSEarlyStopping from tests.helpers import BoringModel -from tests.helpers.boring_model import CustomLRScheduler, LinearWarmupLR, unexpected_warns, unmatched_warns +from tests.helpers.boring_model import (CustomLRScheduler, LinearWarmupLR, 
unexpected_warns, unmatched_warns, + RandomDataset) from tests.helpers.runif import RunIf fts_resolver = CallbackResolverMixin() @@ -238,6 +239,19 @@ def configure_optimizers(self): return [optimizer], [lr_scheduler] +class BNBoringModel(FinetuningSchedulerBoringModel): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.layer = nn.Sequential( + OrderedDict( + [("lin_base", nn.Linear(32, 32)), ("bn", nn.BatchNorm1d(32)), ("lin_classif", nn.Linear(32, 2))] + ) + ) + + def train_dataloader(self): + # when testing BatchNorm layers, we need to ensure there are more than 1 samples per batch + return DataLoader(RandomDataset(32, 64), batch_size=2) + class FTSCustLRModel(FinetuningSchedulerBoringModel): """overrides lr_scheduler_step to allow lr scheduler testing.""" @@ -340,7 +354,10 @@ def log_dev_state(self) -> None: for dev_d in [self.dev_expected_states, self.dev_lrs_states]: fp.write(os.linesep) for k, v in dev_d.items(): # control formatting precisely to allow copy/paste expected output - fp.write(f"{' ' * 8}{k}: {v},{os.linesep}") + if isinstance(k, int): + fp.write(f"{' ' * 8}{k}: {v},{os.linesep}") + else: + fp.write(f"""{' ' * 8}'{k}': {v},{os.linesep}""") class FitStartOnlyFTS(TestFinetuningScheduler): @@ -348,6 +365,42 @@ def on_fit_start(self, trainer, pl_module) -> None: super().on_fit_start(trainer, pl_module) raise SystemExit(0) +class BNInspectFTS(TestFinetuningScheduler): + + def sample_bn_state(self, trainer, pl_module) -> None: + phase_subkey = "train" if pl_module.training else "val" + state_key = f"{trainer.current_epoch}_{phase_subkey}" + sample_running_mean = self.pl_module.layer.bn._buffers.get('running_mean', None) + sample_running_var = self.pl_module.layer.bn._buffers.get('running_var', None) + sample_num_tracked = self.pl_module.layer.bn.num_batches_tracked + current_state = ( + self.pl_module.layer.bn.track_running_stats, + self.pl_module.layer.bn.weight.requires_grad, + self.pl_module.layer.bn.training, + round(sample_num_tracked.item(), 6) if sample_num_tracked is not None else None, + round(sample_running_mean.max().item(), 6) if sample_running_mean is not None else None, + round(sample_running_var.max().item(), 6) if sample_running_var is not None else None, + ) + lrs_state = None + return current_state, lrs_state, state_key + + def on_train_epoch_start(self, trainer, pl_module): + super(TestFinetuningScheduler, self).on_train_epoch_start(trainer, pl_module) + current_state, lrs_state, state_key = self.sample_bn_state(trainer, pl_module) + self.inspect_or_assert(current_state, lrs_state, state_key) + + def on_validation_epoch_start(self, trainer, pl_module): + super(TestFinetuningScheduler, self).on_train_epoch_start(trainer, pl_module) + current_state, lrs_state, state_key = self.sample_bn_state(trainer, pl_module) + self.inspect_or_assert(current_state, lrs_state, state_key) + + def state_dict(self) -> Dict[str, Any]: + return super(TestFinetuningScheduler, self).state_dict() + + def restore_best_ckpt(self) -> None: + super(TestFinetuningScheduler, self).restore_best_ckpt() + self.restored_best_cnt += 1 + class OptInspectFTS(TestFinetuningScheduler): def on_train_epoch_start(self, trainer, pl_module): @@ -647,6 +700,11 @@ def boring_ft_schedule(tmpdir_factory) -> Tuple[Path, Dict]: "pl_lrs_cfg": {"interval": "epoch", "frequency": 1, "name": "Custom_Reinit_LR"}, "init_pg_lrs": [2.0e-06, 3.0e-06], } + bn_sched_dict = {0: {'params': ['layer.lin_classif.bias', 'layer.lin_classif.weight']}, + 1: {'params': ['layer.bn.bias', 
'layer.bn.weight']}, 2: {'params': ['layer.lin_base.bias', 'layer.lin_base.weight']}} + bn_sched_dict[0]["max_transition_epoch"] = 1 + bn_sched_dict[1]["max_transition_epoch"] = 2 return ( unmod_schedule_file, mod_sched_dict, @@ -659,6 +717,7 @@ def boring_ft_schedule(tmpdir_factory) -> Tuple[Path, Dict]: reinitlr_optim_lambdalr_sched, reinitlr_optim_rlrop_sched, reinitlr_optim_use_curr_sched_dict, + bn_sched_dict, ) @@ -2536,6 +2595,53 @@ def test_fts_epoch_trans_only(tmpdir, boring_ft_schedule, epoch_only_cfg: bool, trainer.fit(model) +EXPECTED_BN_INTRAFIT_STATE = { + (False,): { + '0_train': (False, False, True, 0, 0.0, 1.0), + '0_val': (False, False, False, 0, 0.0, 1.0), + '1_train': (True, True, True, 0, 0.0, 1.0), + '1_val': (True, True, False, 32, 0.30678, 0.688297), + '2_train': (True, True, True, 0, 0.0, 1.0), + '2_val': (True, True, False, 32, 0.306801, 0.688293), + +}, + (True,): { + '0_train': (True, False, True, 0, 0.0, 1.0), + '0_val': (True, False, False, 32, 0.30678, 0.688297), + '1_train': (True, True, True, 32, 0.30678, 0.688297), + '1_val': (True, True, False, 64, 0.317314, 0.677594), + '2_train': (True, True, True, 64, 0.317314, 0.677594), + '2_val': (True, True, False, 96, 0.317694, 0.677223), +} +} + + +@pytest.mark.parametrize("frozen_bn_track_running_stats", [True, False], ids=["frozen_bn_track", "no_frozen_bn_track"]) +def test_fts_frozen_bn_track_running_stats(tmpdir, boring_ft_schedule, frozen_bn_track_running_stats: bool): + """Inspect scheduled fine-tuning state within the training process to ensure it is taking the expected path in + both `frozen_bn_track_running_stats` modes.""" + seed_everything(42) + ft_schedule = boring_ft_schedule[11] + model = BNBoringModel() + callbacks = [ + BNInspectFTS(expected_state=EXPECTED_BN_INTRAFIT_STATE[(frozen_bn_track_running_stats,)], + ft_schedule=ft_schedule, frozen_bn_track_running_stats=frozen_bn_track_running_stats, + #state_log_dir=tmpdir + ), + FTSEarlyStopping(monitor="val_loss", patience=1), + ] + trainer = Trainer(default_root_dir=tmpdir, callbacks=callbacks, devices=1, max_epochs=3, num_sanity_val_steps=0) + if not frozen_bn_track_running_stats: + with pytest.warns(UserWarning, match="with the next minor release of FTS"): + trainer.fit(model) + else: + trainer.fit(model) + finetuningscheduler_callback = get_fts(trainer) + assert finetuningscheduler_callback.depth_remaining == 0 + assert finetuningscheduler_callback.curr_depth == 2 + assert finetuningscheduler_callback.curr_depth == finetuningscheduler_callback.max_depth + + @pytest.mark.parametrize("stop_value", [torch.tensor(np.inf), torch.tensor(np.nan)]) def test_early_stopping_on_non_finite_monitor(tmpdir, stop_value): callbacks = [ diff --git a/tests/test_fsdp.py b/tests/test_fsdp.py index a9cb862..022e5d0 100644 --- a/tests/test_fsdp.py +++ b/tests/test_fsdp.py @@ -32,6 +32,10 @@ from finetuning_scheduler.strategy_adapters import FSDPStrategyAdapter from tests.helpers.boring_model import RandomDataset, unexpected_warns, unmatched_warns from tests.helpers.runif import RunIf +from tests.fsdp_expected_paths import (path_default, path_default_orig, path_default_orig_eo_dyn, path_ignore_p_uo, + path_8_14, path_5_10, path_ext_7_14, path_ext_8_16, path_optimlr_reinit, + lrs_path_optimlr_reinit, path_bn_track_false, path_bn_track_true, ResultEnum) + from tests.test_finetuning_scheduler_callback import ( EXPECTED_WARNS, ExplicitLossFTSCheckpoint, @@ -87,7 +91,6 @@ # FTS FSDP Test Fixtures # ########################## - @pytest.fixture(scope="module") def 
fsdp_ft_schedules(tmpdir_factory) -> Tuple[Path, Dict]: """Generates a default fine-tuning schedule for 'implicit' testing, a modified one for 'explicit' mode and an @@ -156,8 +159,8 @@ def fsdp_ft_schedules(tmpdir_factory) -> Tuple[Path, Dict]: "pl_lrs_cfg": {"interval": "epoch", "frequency": 1, "name": "Custom_Reinit_LR"}, } fsdp_bn_gen_sched_dict = deepcopy(fsdp_gen_sched_dict) - fsdp_bn_gen_sched_dict[0]["params"] = ["layer.(8|[4-6]).*"] - fsdp_bn_gen_sched_dict[1]["params"] = ["layer.[1-3].*"] + fsdp_bn_gen_sched_dict[0]["params"] = ["layer.(9|[5-7]).*"] + fsdp_bn_gen_sched_dict[1]["params"] = ["layer.[1-4].*"] fsdp_shared_param_sched_dict = deepcopy(fsdp_gen_sched_dict) fsdp_shared_param_sched_dict[0]["params"] = ["layer.(7|4).*", "layer.5.weight", "layer.5.bias"] fsdp_shared_param_sched_dict[1]["params"] = ["layer.2.*", "layer.3.weight", "layer.3.bias"] @@ -441,6 +444,7 @@ def __init__(self, *args, **kwargs): self.layer = torch.nn.Sequential( torch.nn.Linear(32, 32), torch.nn.Linear(32, 32), + torch.nn.BatchNorm1d(32), torch.nn.Linear(32, 32), torch.nn.Linear(32, 32), torch.nn.Linear(32, 32), @@ -508,6 +512,41 @@ def on_train_epoch_start(self, trainer, pl_module): self.inspect_or_assert(current_state, lrs_state, state_key) +class BNInspectFTS(FSDPTestFinetuningScheduler): + def on_train_epoch_start(self, trainer, pl_module): + super(TestFinetuningScheduler, self).on_train_epoch_start(trainer, pl_module) + state_key = trainer.current_epoch + bn_layer_state = self._collect_bnl_state() + current_state = ( + bn_layer_state, + len(self._fts_state._curr_thawed_params), + len(self.strategy_adapter.logical_param_translation(self._fts_state._curr_thawed_params)), + ) + lrs_state = None + self.inspect_or_assert(current_state, lrs_state, state_key) + + def _collect_bnl_state(self): + bnl_sample = {} + for i, (n, bn_layer) in enumerate(self.pl_module.named_modules()): + if isinstance(bn_layer, torch.nn.modules.batchnorm._BatchNorm): + bnl_sample.setdefault(i, {}) + bnl_sample[i]['layer_fqn'] = n + for attr in ['track_running_stats', 'training']: + attr_v = getattr(bn_layer, attr, None) + bnl_sample[i][attr] = attr_v + for attr in ['running_mean', 'running_var']: + attr_v = bn_layer._buffers.get(attr, None) + if attr_v is not None: + attr_v = round(attr_v.max().item(), 9) + # inspect whether default or non-default bn tracking values are present + bnl_sample[i][attr] = ResultEnum.nondefault if attr_v not in [0.0, 1.0] else ResultEnum.default + else: + bnl_sample[i][attr] = attr_v # None + bnl_sample[i]['num_batches_tracked'] = round(bn_layer.num_batches_tracked.item(), 2) if \ + bn_layer.num_batches_tracked is not None else None + bnl_sample[i]['requires_grad'] = bn_layer.weight.requires_grad + return bnl_sample + # model aliases base_model = FTSBaseFSDPModel nond_loss_adam_model = NonDynamicLossAdamFSDPModel @@ -625,22 +664,11 @@ def policy(self): "test_es": "disable", "test_ckpt": ExplicitLossFTSCheckpoint(monitor="val_loss", verbose=True), } +bn_inspect = {"test_fts": BNInspectFTS} +bn_track_false = {**bn_inspect, "frozen_bn_track_running_stats": False} +bn_track_true = {**bn_inspect, "frozen_bn_track_running_stats": True} opt_inspect = {"test_fts": FSDPOptInspectFTS} -# expected training path aliases -path_default = {0: (2, 4), 1: (6, 12), 2: (7, 14)} -path_default_orig = {0: (4, 4), 1: (12, 12), 2: (14, 14)} -path_default_orig_eo_dyn = {0: (4, 4), 1: (12, 12), 2: (14, 14), 3: (14, 14)} -path_ignore_p_uo = {0: (4, 4), 1: (12, 12), 2: (14, 14)} -path_8_14 = {0: (2, 4), 1: (7, 12), 2: (8, 14)} 
-path_8_16 = {0: (4, 8), 1: (7, 14), 2: (8, 16)} -path_5_10 = {0: (2, 4), 1: (3, 6), 2: (5, 10)} -path_ext_7_14 = {0: (2, 4), 1: (2, 4), 2: (6, 12), 3: (6, 12), 4: (7, 14)} -path_ext_8_16 = {0: (3, 6), 1: (7, 14), 2: (8, 16)} -path_optimlr_reinit = {0: (2, 4, "SGD", 0, 0.1), 1: (6, 12, "Adam", 32, 0.00021), 2: (7, 14, "SGD", 64, 0.002)} -lrs_path_default = {0: (0.1,), 1: (0.07, 1e-06), 2: (0.049, 7e-07, 1e-05)} -lrs_path_optimlr_reinit = {0: (0.1,), 1: (0.00021, 1e-06), 2: (0.002, 1e-06, 3e-06)} - # consolidate all core FTS FSDP test configuration into this dictionary to dedup config FTS_FSDP_TESTS = { "cust_awp_noprec_no_use_orig": ( @@ -763,10 +791,15 @@ def policy(self): "min2_1", (path_default, *nones(3)), ), - "batch_norm_auto_prec_no_use_orig": ( - (BN_model, cust_awp, True, 2, unwrap_8_mp, *nones(3), DISABLE_USE_ORIG), + "batch_norm_auto_prec_no_use_orig_track_false": ( + (BN_model, cust_awp, True, 2, unwrap_8_mp, None, bn_track_false, max_epoch_5, DISABLE_USE_ORIG), "min2_1", - (path_8_16, ("Both mixed precision",), *nones(2)), + (path_bn_track_false, ("Both mixed precision", "retain the current `track_running_stats`"), *nones(2)), + ), + "batch_norm_auto_prec_no_use_orig_track_true": ( + (BN_model, cust_awp, True, 2, unwrap_8_mp, None, bn_track_true, max_epoch_5, DISABLE_USE_ORIG), + "min2_1", + (path_bn_track_true, ("Both mixed precision",), *nones(2)), ), "shared_params_auto_prec_no_use_orig": ( (shared_model, cust_awp, True, 3, unwrap_7_mp, awp_1, *nones(2), DISABLE_USE_ORIG), @@ -887,6 +920,17 @@ def test_fsdp_multi_gpus_resume(tmpdir, recwarn, fsdp_ft_schedules, fsdp_ckpt, m check_fts_fsdp_warns(warns_expected, recwarn) +def test_fsdp_get_bn_unwrapped(): + """Sanity check that `_get_target_bn_modules` returns unwrapped `BatchNorm` module instances when the scheduled + modules are not FSDP-wrapped.""" + test_adapter = FSDPStrategyAdapter() + test_adapter.scheduled_mod_lists = {0: ['layer.0']} + test_module = torch.nn.Module() + test_module.layer = torch.nn.Sequential(torch.nn.BatchNorm1d(32)) + setattr(test_adapter, 'fts_handle', FinetuningScheduler()) + setattr(test_adapter.fts_handle, 'pl_module', test_module) + bn_modules = test_adapter._get_target_bn_modules(0) + assert all(isinstance(m, torch.nn.modules.batchnorm._BatchNorm) for _, m in bn_modules) + def gen_exceptions(trainer, model, model_cfg_key, exception_expected): if model_cfg_key == "no_fsdp_params_p0": with mock.patch.object(FSDPStrategyAdapter, "_rank_zero_logger", 42): @@ -903,7 +947,7 @@ def init_fts_cfg(fts_state, lrs_state, strategy_adapter_cfg, fts_cfg, tmpdir): "expected_state": fts_state, "lrs_state": lrs_state, "strategy_adapter_cfg": strategy_adapter_cfg, - # "state_log_dir": tmpdir + #"state_log_dir": tmpdir } fts_cls = fts_cfg.pop("test_fts") if fts_cfg and fts_cfg.get("test_fts") else FSDPTestFinetuningScheduler test_cfg = {**fts_cfg, **def_fts_cfg}
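As a recap of the semantics the tests above exercise, here is a standalone sketch (toy ``Sequential`` model, not part of the PR) contrasting Lightning's default freeze behavior with the ``frozen_bn_track_running_stats=True`` handling:

```python
# Toy illustration of the BatchNorm handling under test: `BaseFinetuning.freeze(..., train_bn=False)`
# disables `track_running_stats` on frozen BatchNorm layers; FTS's `frozen_bn_track_running_stats=True`
# re-enables it (mirroring `StrategyAdapter._module_specific_freezing`) while parameters stay frozen.
import torch
from lightning.pytorch.callbacks import BaseFinetuning

model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.BatchNorm1d(32), torch.nn.Linear(32, 2))
BaseFinetuning.freeze(modules=model, train_bn=False)
assert model[1].track_running_stats is False  # Lightning's default when freezing BatchNorm
assert not model[1].weight.requires_grad      # parameters are frozen

# equivalent of the `frozen_bn_track_running_stats=True` code path
for mod in BaseFinetuning.flatten_modules(model):
    if isinstance(mod, torch.nn.modules.batchnorm._BatchNorm):
        mod.track_running_stats = True

assert model[1].track_running_stats is True   # running statistics are tracked again
assert not model[1].weight.requires_grad      # parameters remain frozen
```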