diff --git a/src/finetuning_scheduler/fts_supporters.py b/src/finetuning_scheduler/fts_supporters.py
index 50d5729..9197172 100644
--- a/src/finetuning_scheduler/fts_supporters.py
+++ b/src/finetuning_scheduler/fts_supporters.py
@@ -324,7 +324,7 @@ def _evaluate_stopping_criteria(self, current: Tensor) -> Tuple[bool, Optional[s
             should_stop = True
             reason = (
                 f"Monitored metric {self.monitor} did not improve in the last {self.wait_count} records."
-                f" Best score: {self.best_score:.3f}. Signaling Trainer to stop."
+                f" Best score: {self.best_score.item():.3f}. Signaling Trainer to stop."
             )
         else:
             self._transition_es_phase()
diff --git a/src/finetuning_scheduler/strategy_adapters/__init__.py b/src/finetuning_scheduler/strategy_adapters/__init__.py
index aa03567..fba81f2 100644
--- a/src/finetuning_scheduler/strategy_adapters/__init__.py
+++ b/src/finetuning_scheduler/strategy_adapters/__init__.py
@@ -15,5 +15,6 @@
 """
 from finetuning_scheduler.strategy_adapters.base import StrategyAdapter
 from finetuning_scheduler.strategy_adapters.fsdp import FSDPStrategyAdapter
+from finetuning_scheduler.strategy_adapters.model_parallel import ModelParallelStrategyAdapter
 
-__all__ = ["StrategyAdapter", "FSDPStrategyAdapter"]
+__all__ = ["StrategyAdapter", "FSDPStrategyAdapter", "ModelParallelStrategyAdapter"]
diff --git a/src/finetuning_scheduler/strategy_adapters/model_parallel.py b/src/finetuning_scheduler/strategy_adapters/model_parallel.py
index 1cb98c9..dd7db5a 100644
--- a/src/finetuning_scheduler/strategy_adapters/model_parallel.py
+++ b/src/finetuning_scheduler/strategy_adapters/model_parallel.py
@@ -17,734 +17,97 @@
 for PyTorch's SPMD style APIs (e.g. DeviceMesh, FSDP2).
 """
-from contextlib import contextmanager
-from typing import Any, Generator, List, Optional
-from typing_extensions import override
+from typing import Any, TYPE_CHECKING
-import torch
+from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from finetuning_scheduler.strategy_adapters.base import StrategyAdapter
-from contextlib import contextmanager
-from typing import TYPE_CHECKING, Any, Generator, List, Optional
-
-import torch
-from typing_extensions import override
+# TODO: replace local version once Lightning version available
+# from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_5
+import operator
+from lightning_utilities.core.imports import compare_version
+_TORCH_GREATER_EQUAL_2_5 = compare_version("torch", operator.ge, "2.5.0", use_base_version=True)
 if TYPE_CHECKING:
     pass
-if torch.distributed.is_available():
-
-    pass
-
-
-
-
 class ModelParallelStrategyAdapter(StrategyAdapter):
     """A :class:`~finetuning_scheduler.strategy_adapters.StrategyAdapter` for PyTorch's SPMD style APIs (e.g. DeviceMesh, FSDP2); requires PyTorch >= 2.5."""
-    # _fsdp_flat_to_unflat_mapping: Dict
-    # _fsdp_unflat_to_flat_mapping: Dict
-    # _ft_schedule_module_map: Dict
-    # _unscheduled_params: List
-    # _use_orig_params: bool
-    # _allow_mixed_req_grad: bool
-    # _rank_zero_logger: logging.Logger = logging.getLogger("lightning.pytorch.utilities.rank_zero")
-
-    def __init__(self, awp_overrides: Optional[List] = None, *args: Any, **kwargs: Any) -> None:
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         """There is currently no user-facing configuration for this adapter; initialization only validates that the required PyTorch version (>= 2.5) is available."""
         super().__init__(*args, **kwargs)
-        self.awp_overrides = awp_overrides or []
-        # self._min_wrap_validated: bool = False
-        # self._suppress_cm_warns()
-        # self.exec_ft_phase = partial(StrategyAdapter.base_ft_phase, translation_func=self.logical_param_translation)
+        if not _TORCH_GREATER_EQUAL_2_5:
+            # specifically, depends upon https://github.com/pytorch/pytorch/pull/133502 among other changes
+            raise 
MisconfigurationException(f"{type(self).__name__} requires PyTorch 2.5 or higher.") - def on_before_init_fts(self) -> None: - """In this hook executed immediately before - :meth:`~finetuning_scheduler.fts_supporters.ScheduleImplMixin.init_fts`, to accommodate FSDP we: - - 1. Disable Lightning's restoration of the optimizer to allow us to implement special handling - 2. Prune ``no_decay`` specification since it is not currently supported in the context of FSDP fine-tuning - 3. Validate the :attr:`~finetuning_scheduler.strategy_adapters.FSDPStrategyAdapter.awp_overrides` configuration - 4. Configure FTS wrapping of the provided :external+pl:class:`~lightning.pytorch.core.module.LightningModule` - to either use the provided ``LightningModule.configure_model`` method (if present) or a provided - ``auto_wrap_policy``. - """ - # # hack to avoid subclassing ModelParallelStrategy strategy for this adapter - # setattr(Strategy, "lightning_restore_optimizer", self.lightning_restore_optimizer) - # setattr(self.pls_handle, "optimizer_state", self.optimizer_state) - # self._maybe_squeeze_device_mesh() # TODO: not currently useful, probably remove before releasing - pass - #setattr(self.pls_handle, "_setup_device_mesh", ModelParallelStrategyAdapter._setup_device_mesh) - # self._use_orig_params = self.pls_handle.kwargs.get("use_orig_params", False) - # # w/ `use_orig_params`, schedule/wrapping alignment constraints can be relaxed - # self._allow_mixed_req_grad = self._use_orig_params - # self._prune_nodecay() - # self._validate_awp_overrides() - # if is_overridden("configure_model", self.pl_module): - # rank_zero_info( - # "You have overridden the `LightningModule.configure_model` hook. Fine-Tuning Scheduler" - # " will attempt to validate that you have wrapped the provided model in a manner that aligns with the" - # " defined fine-tuning schedule phases. If you would like to have Fine-Tuning Scheduler" - # " automatically wrap your model according to a given auto wrap policy, avoid overriding" - # " `configure_model` in your module and provide the desired auto wrap policy." - # ) - # csm_func = self._wrapped_configure_model(self.pl_module.configure_model) - # setattr(self.pl_module, "configure_model", csm_func) - # else: - # setattr(self.pl_module, "configure_model", self._fts_auto_configure_model) + # def on_before_init_fts(self) -> None: + # # TODO: if offering auto-wrap functionality hook `configure_model` here + # pass # def on_after_init_fts(self) -> None: - # """To accommodate FSDP, we defer executing the first fine-tuning phase that would otherwise be executed in - # this hook, which fires in :class:`~finetuning_scheduler.fts.FinetuningScheduler` setup immediately after - # :meth:`~finetuning_scheduler.fts_supporters.ScheduleImplMixin.init_fts`""" - # self._gen_ft_sched_module_map() - # self.scheduled_mod_lists = [list(self._ft_schedule_module_map[d]) for d in self._ft_schedule_module_map.keys()] - - def on_before_fts_fit_start(self) -> None: - """In this hook executed immediately before the :class:`~finetuning_scheduler.fts.FinetuningScheduler` - :meth:`~finetuning_scheduler.fts.FinetuningScheduler.on_fit_start` hook begins, we ensure the provided - fine-tuning schedule and FSDP wrapped :external+pl:class:`~lightning.pytorch.core.module.LightningModule` are - appropriately aligned and valid. 
If the fine-tuning schedule and wrapped module are detected to be incompatible, - detailed feedback is provided to the user (which is why multiple checks are aggregated before returning any - alignment exceptions). - - Raises: - MisconfigurationException: If any FTS FSDP fine-tuning schedule/module wrapping alignment exceptions are - thrown. The provided exceptions provide detailed feedback for the user to address the misalignment. - """ - pass - # world_feedback_set: Set = set() - # world_feedback = [[None] for _ in range(self.pls_handle.world_size)] - # all_gather_object(world_feedback, self._validate_fsdp_fts_config()) - # for feedback in world_feedback: - # world_feedback_set.update(feedback) # feedback could be rank-specific - # if world_feedback_set: - # exceptions, debug_msgs = [], [] - # for msg in world_feedback_set: - # if msg[0] == "ERROR": - # exceptions.append(MisconfigurationException(msg[1])) - # else: - # debug_msgs.append(msg[1]) # currently, diagnostics are for DEBUG level only - # if debug_msgs: - # for debug_msg in debug_msgs: - # rank_zero_debug(debug_msg) - # if exceptions: - # raise MisconfigurationException(*exceptions) - - # TODO: not currently useful, probably remove before releasing - # def _maybe_squeeze_device_mesh(self) -> None: - # from torch.distributed.device_mesh import init_device_mesh - # assert self.pls_handle.device_mesh - # if self.pls_handle._tensor_parallel_size == 1 and self.pls_handle._data_parallel_size > 1: - # self.pls_handle._device_mesh = init_device_mesh( - # device_type=self.pl_module.device.type, mesh_shape=(self.pls_handle._data_parallel_size,), - # mesh_dim_names=("data_parallel",)) - # elif self.pls_handle._tensor_parallel_size > 1 and self.pls_handle._data_parallel_size == 1: - # self.pls_handle._device_mesh = init_device_mesh( - # device_type=self.pl_module.device.type, mesh_shape=(self.pls_handle._tensor_parallel_size,), - # mesh_dim_names=("tensor_parallel",)) - # elif self.pls_handle._tensor_parallel_size == 1 and self.pls_handle._data_parallel_size == 1: - # raise MisconfigurationException( - # "When using model parallel, either `tensor_parallel_size` or `data_parallel_size` should be > 1.") - # # TODO: determine if we really need to block until all ranks have re-initialized the device mesh, probably safe - # self.pls_handle.barrier() - - # def on_before_restore_optimizers_and_lrs(self) -> None: - # """Allow the :class:`~finetuning_scheduler.strategy_adapters.FSDPStrategyAdapter` to override the default - # ``load_optimizer_state_dict`` method. - - # This is necessary so we can allow FSDP to manage the movement of restored optimizer states to the relevant - # devices. - # """ - # checkpoint_connector = self.pl_module.trainer._checkpoint_connector - - # # Restore the optimizer states from the pre-loaded checkpoint. - # self.load_optimizer_state_dict(checkpoint_connector) - - # def load_optimizer_state_dict(self, checkpoint_connector: _CheckpointConnector) -> None: - # """Override the default ``load_optimizer_state_dict`` method so that we can allow FSDP to manage the - # movement of restored optimizer states to the relevant devices. - - # Args: - # checkpoint_connector (_CheckpointConnector): The ``_CheckpointConnector`` associated with the current - # training session. 
- # """ - # optimizer_states = checkpoint_connector._loaded_checkpoint["optimizer_states"] - - # assert self.pls_handle.model is not None - - # # rank0_only should be false to enable loading of the optimizer state on all ranks - # # irrespective of `use_orig_params` mode, we start with a full, unflattened, unsharded, consolidated osd - # # we then ensure the local osd is properly keyed and transformed for loading into each rank's local optimizer - # with _get_full_state_dict_context( - # self.pls_handle.model, world_size=self.pls_handle.world_size, rank0_only=False - # ): - # for optimizer, opt_state in zip(self.pls_handle.optimizers, optimizer_states): - - # # usually, this will basically be a noop since FTS should be restoring osd saved with param fqn keys - # opt_state = FullyShardedDataParallel.rekey_optim_state_dict( - # opt_state, OptimStateKeyType.PARAM_NAME, self.pls_handle.model - # ) - - # opt_state = FullyShardedDataParallel.optim_state_dict_to_load( - # optim_state_dict=opt_state, - # model=self.pls_handle.model, - # optim=optimizer, - # ) - - # optimizer.load_state_dict(opt_state) - - # def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]: - # """Override the default ``optimizer_state`` method so that we can unify `use_orig_params` code-paths and - # save a full, consolidated optimizer state dict to be restored via ``load_optimizer_state_dict``. - - # Args: - # optimizer (Optimizer): The optimizer instance for which a full optimizer state dict will be captured. - - # Returns: - # Dict[str, Tensor]: The consolidated full optimizer state dict (if on rank 0, otherwise an empty dict). - # """ - # assert self.pls_handle.model is not None + # # TODO: if offering auto-wrap functionality gen module map here + # pass + # # self._gen_ft_sched_module_map() + # # self.scheduled_mod_lists = [list(self._ft_schedule_module_map[d]) for d in + # # self._ft_schedule_module_map.keys()] - # # irrespective of `use_orig_params` mode, we need the full, unflattened, unsharded, consolidated osd - # with _get_full_state_dict_context(self.pl_module, world_size=self.pls_handle.world_size, rank0_only=True): - # state_dict = FullyShardedDataParallel.optim_state_dict(self.pl_module, optimizer) - - # return state_dict - - # def fts_optim_transform(self, orig_pl: List, inspect_only: bool = False) -> List: - # """Because FSDP performs parameter transformations that cause the current optimizer's view of parameter - # names to diverge from the original parameter names, this parameter transformation is required for optimizer - # operations. - - # Args: - # orig_pl (List): The original parameter name list before FSDP's transformation of them. - # inspect_only (bool): Whether to use the specified transform in read-only (i.e. ``inspect_only``) mode, - # avoiding any persistent state transformation that may accompany normal usage. Typically useful for state - # inspection and validation contexts. - - # Returns: - # List: A transformed parameter name list that matches the current optimizer's view of them after FSDP's - # transformation of the original parameter names. - # """ - # return self.fsdp_param_transform(orig_pl, inspect_only) - - # def fsdp_param_transform(self, orig_thaw_pl: List, inspect_only: bool) -> List: - # """The parameter transformation function currently used by - # :meth:`~finetuning_scheduler.strategy_adapters.FSDPStrategyAdapter.fts_optim_transform` to transform original - # parameter lists for optimizer operations. 
- - # Args: - # orig_thaw_pl (List): The original parameter name list before FSDP's transformation of them. - # inspect_only (bool): Whether to use the specified transform in read-only (i.e. ``inspect_only``) mode, - # avoiding any persistent state transformation that may accompany normal usage. Typically useful for state - # inspection and validation contexts. - - # Returns: - # List: A transformed parameter name list that matches the current optimizer's view of them after FSDP's - # transformation of the original parameter names. - # """ - # flat_next_tl = {self._fsdp_unflat_to_flat_mapping[p] for p in orig_thaw_pl} - # if self._use_orig_params and not inspect_only: - # self._flat_param_thaw(flat_next_tl) - # return [n for n, p in self.pl_module.named_parameters() if p in flat_next_tl] - - # def _flat_param_thaw(self, flat_next_tl: Set) -> None: - # """For FSDP modules that have been configured with ``_use_orig_params`` set to ``True``, this method - # ensures that the ``FlatParameter`` objects containing the logically original ``Parameter`` objects require - # grad when one or more of those contained original parameters are transformed for optimizer operations. - - # Args: - # flat_next_tl (Set): The set of original ``Parameter`` s to transform for optimizer operations. These should - # be ``Parameter`` objects rather than ``FlatParameter`` objects because ``_use_orig_params`` is ``True`` in - # this context. - # """ - # use_orig_flat_params_mods = set() - # for m in self.pl_module.modules(): - # is_fsdp_managed = getattr(m, "_is_fsdp_managed_module", False) - # if is_fsdp_managed and m._fsdp_use_orig_params and getattr(m, FLAT_PARAM, None) is not None: - # use_orig_flat_params_mods.add(m) - # flat_params_to_thaw = set() - # for m in use_orig_flat_params_mods: - # for p in flat_next_tl: - # if any([p is ofp for ofp in m._flat_param._params]): # type: ignore[union-attr] - # flat_params_to_thaw.add((m, getattr(m, FLAT_PARAM))) - # thawed_fp_mods = set() - # for fpm, fp in flat_params_to_thaw: - # fp.requires_grad = True - # thawed_fp_mods.add(fpm) - # thawed_fp_fqns = [n + "." + FLAT_PARAM for n, m in self.pl_module.named_modules() if m in thawed_fp_mods] - # rank_zero_debug( - # "Since FSDP has been configured with `use_orig_params` set to `True`, the following `FlatParameter`s" - # " have been thawed because they contain the original parameters you specified be thawed." - # f" `FlatParameters` thawed: {os.linesep}{pformat(thawed_fp_fqns)}" - # ) - - # def logical_param_translation(self, param_names: List) -> List: - # """Effectively the reverse transformation of - # :meth:`~finetuning_scheduler.strategy_adapters.FSDPStrategyAdapter.fts_optim_transform`. - - # Args: - # param_names (List): A parameter name list from the current optimizer's view of them after FSDP's - # transformation of the original parameter names. - - # Returns: - # List: The original parameter name list before a given FSDP's transformation. - # """ - # logical_param_names = [] - # for n, p in self.pl_module.named_parameters(): - # if n in param_names: - # logical_param_names.extend(self._fsdp_flat_to_unflat_mapping[p]) - # return logical_param_names - - # def _prune_nodecay(self) -> None: - # """If the ``no_decay`` attribute is present on the provided. - - # :external+pl:class:`~lightning.pytorch.core.module.LightningModule` s remove it (with a warning) because it is - # not currently supported in the context of FSDP fine-tuning. 
- # """ - # if hasattr(self.pl_module, "no_decay") and self.pl_module.no_decay is not None: - # rank_zero_warn( - # "Specifying a `no_decay` lightning module attribute is not currently supported by the Fine-Tuning" - # f" Scheduler FSDP strategy adapter. The `no_decay` attribute currently set ({self.pl_module.no_decay})" - # " will now be unset by the adapter to allow training to proceed." - # ) - # setattr(self.pl_module, "no_decay", None) + # def on_before_fts_fit_start(self) -> None: + # # TODO: if offering auto-wrap functionality, validate config globally here + # pass # def _suppress_cm_warns(self) -> None: - # """Because Fine-Tuning Scheduler internally leverages the ``configure_model`` method to implement FSDP - # auto-wrapping enhancements, we suppress superfluous warnings about ``configure_model`` overrides.""" - # try: - # warns_to_suppress = (".*model is already wrapped.*", ".*model already contains checkpointed layers.*") - # for w in warns_to_suppress: - # warnings.filterwarnings("ignore", w, category=UserWarning) - # except Exception: - # # suppressing this message is largely cosmetic so if we cannot suppress this message for any reason at all - # # (e.g. logger rename) continue anyway - # pass - - # def _validate_fsdp_fts_config(self) -> List: - # """Execute fine-tuning schedule/module wrapping misalignment checks, generating and aggregating detailed - # feedback to facilitate the user's remediation of the issue. - - # Returns: - # List: Any FTS FSDP fine-tuning schedule/module wrapping misalignment feedback messages generated by the - # validation functions. - # """ - # # collect all configuration feedback before returning it to the user to facilitate faster remediation - # return [cf for cf in [self._validate_min_wrap_condition(), *self._validate_fsdp_phases_disjoint()] if cf] + # # TODO: only required if offering auto-wrap functionality + # pass # def _validate_awp_overrides(self) -> None: - # """Expand any regex expressions specified in - # :attr:`~finetuning_scheduler.strategy_adapters.FSDPStrategyAdapter.awp_overrides`. - - # Raises: - # MisconfigurationException: If a specified module name or regex does not resolve to at least one named - # module. - # """ - # if not self.awp_overrides: - # return - # if is_overridden("configure_model", self.pl_module): - # rank_zero_warn( - # "You have overridden the `LightningModule.configure_model` hook but also provided" - # " an `awp_overrides` configuration. Since `awp_overrides` only applies to configurations that use" - # f" policy-based FSDP wrapping, this configuration ({self.awp_overrides}) will be unset and not applied." - # ) - # self.awp_overrides = [] - # return - # named_modules = dict(self.pl_module.named_modules()).keys() - # resolved_modules = [] - # for m in self.awp_overrides: - # regex_modules = [] - # explicit_mods = False - # if m in named_modules: - # explicit_mods = True - # resolved_modules.append(m) - # else: - # mpat = re.compile(m) - # regex_modules = [m for m in named_modules if mpat.match(m)] - # resolved_modules.extend(regex_modules) - # if not (regex_modules or explicit_mods): - # raise MisconfigurationException( - # f"The module or regex '{m}' specified in `awp_overrides` did not match any named modules in the" - # " provided model." 
- # ) - # self.awp_overrides = resolved_modules - - # @staticmethod - # def _phasewise_intersection(phase_lists: List[List]) -> Set: - # """Calculates a phase-wise intersection of elements (whether modules or parameters) - - # Args: - # phase_lists (List[List]): Element lists (modules or parameters) for each fine-tuning schedule phase. - - # Returns: - # Set: The set of elements present in more than one phase. - # """ - # elems = Counter(list(itertools.chain(*phase_lists))) - # unique_elems = Counter(list(set().union(*phase_lists))) - # elems.subtract(unique_elems) - # dup_elems = set(elems.elements()) - # return dup_elems - - # def _log_nonzero_local_shards(self) -> Optional[str]: - # """If debugging diagnostics are requested, evaluate whether there are any ranks with no (non-zero sized) - # parameter shards and if so, provide parameter shard allocation debugging info for the user. - - # Returns: - # Optional[str]: Per-rank debugging info distilling relevant parameter shard allocation. - # """ - # curr_optimizer_params = [p for pg in self.pls_handle._optimizers[0].param_groups for p in pg["params"]] - # if not any(p.shape[0] for p in curr_optimizer_params if p.requires_grad): - # params_w_ls = set() - # for curr_optim_p in curr_optimizer_params: - # for fsdp_mod in FullyShardedDataParallel.fsdp_modules(self.pl_module): - # fp = fsdp_mod._flat_param - # assert fp is not None - # assert isinstance(fp._params, Iterable) - # assert isinstance(fp._shard_param_infos, Iterable) - # for fp_p, fp_shard_info in zip(fp._params, fp._shard_param_infos): - # if fp_p is curr_optim_p and not fp_shard_info[0]: - # w_local = [p for p, spi in zip(fp._params, fp._shard_param_infos) if spi[0]] - # params_w_ls.update(w_local) - # params_w_ls_names = [self._fsdp_flat_to_unflat_mapping[lsp][0] for lsp in params_w_ls] - # params_w_ls_names.sort() - # rank_specific_advice = ( - # "Below are parameters in the same FSDP module as those currently specified in phase 0 but that DO have " - # f"local shards for rank {self.pl_module.global_rank}: " - # f"{os.linesep}{pformat(params_w_ls_names)}{os.linesep}" - # ) - # local_shard_advice = ( - # f"The global rank {self.pl_module.global_rank} optimizer has no (non-zero sized) local shards of the " - # "phase 0 fine-tuning schedule parameters. \n" - # f"Additional rank-specific details for **RANK {self.pl_module.global_rank}**: {os.linesep}" - # f"{rank_specific_advice}" - # ) - # return local_shard_advice - - # def _validate_fsdp_phases_disjoint(self) -> Tuple: - # """Validate that the defined schedule does not specify any wrapped module or parameter in multiple phases. - - # Returns: - # Tuple: Any fine-tuning schedule/wrapped module misalignment feedback messages to be provided to the user. 
- # """ - # feedback_errors: List[str] = [] - # feedback_nonerrors: List[str] = [] - # if self._allow_mixed_req_grad: - # rank_zero_debug( - # "Bypassing FSDP-specific phase disjointness validation because `use_orig_params` is " - # "``True`` and PyTorch is >= `2.1.0`" - # ) - # assert self.pl_module._trainer is not None - # # check only required for mixed-precision training with DEBUG level logging requested - # if ( - # self.pl_module._trainer.precision in ("16-mixed", "bf16-mixed", "16-true", "bf16-true") - # and self._rank_zero_logger.level <= 10 - # ): - # has_no_local_shards = self._log_nonzero_local_shards() - # if has_no_local_shards: - # feedback_nonerrors.append(has_no_local_shards) - # fsdp_dup_params = set() - # unsched_dup_params = set() - # ft_sched_dup_mods = FSDPStrategyAdapter._phasewise_intersection(self.scheduled_mod_lists) - # fsdp_dup_params = self._phase_unaligned_fsdp_params() - # if not fsdp_dup_params: # unsched_dup_params will be a superset of fsdp_dup_params - # unsched_dup_params = self._phase_unaligned_fsdp_params(check_unsched=True) - # if ft_sched_dup_mods: - # feedback_errors.append(self._module_overlap_feedback(ft_sched_dup_mods)) - # if unsched_dup_params: # conditionally emphasize parameters not included in the fine-tuning schedule - # feedback_errors.append(self._fsdp_param_phase_overlap_feedback(unsched_dup_params, unsched_msg=True)) - # elif fsdp_dup_params: - # feedback_errors.append(self._fsdp_param_phase_overlap_feedback(fsdp_dup_params)) - # feedback_msgs = [("ERROR", fe) for fe in feedback_errors] - # for fw in feedback_nonerrors: - # feedback_msgs.append(("DEBUG", fw)) - # return tuple(feedback_msgs) - - # @staticmethod - # def _module_has_fp(submodule: torch.nn.Module) -> bool: - # """Evaluate whether a given module has any FSDP-flattened params. - - # Args: - # submodule (torch.nn.Module): The module to inspect for FSDP-flattened params. - - # Returns: - # bool: ``True`` if the specified module contains any FSDP-flattened params, otherwise ``False``. - # """ - # return any(_is_fsdp_flattened(param) for param in submodule.parameters()) - - # def _validate_min_wrap_condition(self) -> Optional[Tuple]: - # """Validate (prior to optimizer validation via Lightning that occurs after a potential FTS phase 0 - # override) that at least scheduled phase 0 contains FSDP flattened parameters with ``requires_grad`` set to - # ``True``. - - # Returns: - # Optional[str]: Error message for the user if the first fine-tuning phase does not include one or more FSDP - # flattened parameters. - # """ - # has_flattened = False - # # function configuration to inspect at a module level: - # mod_cfg = (self._ft_schedule_module_map[0], FSDPStrategyAdapter._module_has_fp, self.pl_module.get_submodule) - # # function configuration to inspect at a parameter level: - # param_cfg = (self.fts_handle.ft_schedule[0]["params"], _is_fsdp_flattened, self.pl_module.get_parameter) - - # def inspect_flattened(iter_inspect: Iterable, inspect_func: Callable, inspect_prepare: Callable) -> bool: - # return any(inspect_func(inspect_prepare(i)) for i in iter_inspect) - - # has_flattened = inspect_flattened(*mod_cfg) if not self._allow_mixed_req_grad else inspect_flattened(*param_cfg) - # if not has_flattened: - # fts_p0_err = ( - # "Training an FSDP wrapped model requires one or more FSDP parameters to be included in the optimizer." - # " The `configure_model method or auto_wrap_policy` you have specified did not wrap any of the" - # " layers specified in fine-tuning phase 0. 
Ensure your overridden `configure_model` method or" - # " auto_wrap_policy wraps at least one module included in phase `0`." - # ) - # return ("ERROR", fts_p0_err) - - # def _phase_unaligned_fsdp_params(self, check_unsched: bool = False) -> Set: - # """Inspect the fine-tuning schedule and FSDP-wrapped module for parameters that are unaligned with the FSDP - # wrapped module. - - # Args: - # check_unsched (bool, optional): Whether to include parameters not in the fine-tuning schedule in the - # disjointness check. The unscheduled parameter disjointness check will only be executed if the - # scheduled parameter phase disjointness check passes (since the unscheduled check is a superset of the - # scheduled one). Defaults to False. - - # Returns: - # Set: The set of module parameters in more than one fine-tuning phase. - # """ - # fsdp_param_sets: dict = {} - # inspection_map = deepcopy(self.fts_handle.ft_schedule) - # if check_unsched: - # inspection_map[-1] = {"params": self._unscheduled_params} - # for d, pl in inspection_map.items(): - # fsdp_param_sets[d] = set() - # for lp in pl["params"]: - # fsdp_param_sets[d].update(self._fsdp_flat_to_unflat_mapping[self._fsdp_unflat_to_flat_mapping[lp]]) - # fsdp_phase_lists = [list(fsdp_param_sets[d]) for d in fsdp_param_sets.keys()] - # return FSDPStrategyAdapter._phasewise_intersection(fsdp_phase_lists) - - # def _fsdp_param_phase_overlap_feedback(self, dup_params: Set, unsched_msg: bool = False) -> str: - # """Generate parameter-level phase overlap feedback for the user, identifying owning FSDP instances - # associated with parameters that span more than one fine-tuning phase. - - # Args: - # dup_params (Set): The set of module parameters in more than one fine-tuning phase. - # unsched_msg (bool, optional): Whether to indicate the misaligned parameters were not included in the - # fine-tuning schedule. Defaults to False. - - # Returns: - # str: User feedback detailing the FSDP instances associated with any parameters spanning more than one - # fine-tuning phase. - # """ - - # def get_fsdp_owner(lp: str) -> str: - # owner = "no owner found" - # for fsdp_mod in FullyShardedDataParallel.fsdp_modules(self.pl_module): - # for p in fsdp_mod.params: - # if self._fsdp_unflat_to_flat_mapping[lp] is p: - # owner = fsdp_mod.module._get_name() - # return owner - - # dup_params_fsdp_mapping = {lp: get_fsdp_owner(lp) for lp in dup_params} - # unsched_param_msg = ( - # "In this particular case, there are parameters not included in your fine-tuning schedule that span more" - # " than one fine-tuning phase.\nHINT: parameters associated with unwrapped modules will be included in the" - # " top-level (aka 'root') FSDP instance so ensuring all modules associated with fine-tuning scheduled" - # " parameters are wrapped separately from the top-level FSDP instance may avoid triggering this exception." - # ) - # warn_msg = ( - # "\n\nFine-tuning schedule phases do not have disjoint FSDP-flattened parameter sets. Because the" - # " `requires_grad` attribute of FSDP-flattened parameters currently must be the same for all flattened" - # " parameters (for PyTorch < ``2.1.0`` or if in ``use_orig_params=False`` mode), fine-tuning schedules must" - # " avoid thawing parameters in the same FSDP-flattened parameter in different phases. 
Please ensure" - # " parameters associated with each phase are wrapped in separate phase-aligned FSDP instances.\n\n" - # f"""{unsched_param_msg if unsched_msg else ''}\n\n""" - # "The following logical parameters are associated with an FSDP-flattened parameter that spans more than one" - # " fine-tuning phase. The mapping of each logical parameter with the module name wrapped by its associated" - # " FSDP instance is provided below:\n" - # f"{pformat(dup_params_fsdp_mapping)}{os.linesep}" - # ) - # return warn_msg - - # def _module_overlap_feedback(self, dup_mods: Set) -> str: - # """Generate module-level phase overlap feedback for the user, identifying owning FSDP instances associated - # with modules that span more than one fine-tuning phase. - - # Args: - # dup_mods (Set): The set of module parameters in more than one fine-tuning phase. - - # Returns: - # str: User feedback detailing the FSDP instances associated with any modules spanning more than one - # fine-tuning phase. - # """ - # ft_sched = self.fts_handle.ft_schedule - # dup_mod_dict = { - # m: list( - # itertools.chain( - # *[self._fsdp_flat_to_unflat_mapping[p] for p in self.pl_module.get_submodule(m).parameters()] - # ) - # ) - # for m in dup_mods - # } - # phase_mod_intersect: Dict = {} - # for m, plist in dup_mod_dict.items(): - # phase_mod_intersect[m] = {} - # for phase in ft_sched.keys(): - # if set(plist).intersection(set(ft_sched[phase]["params"])): - # phase_mod_intersect[m][phase] = set(plist).intersection(set(ft_sched[phase]["params"])) - # warn_msg = ( - # "Fine-tuning schedule phases do not have disjoint module sets. FSDP currently wraps at a module level" - # " which requires fine-tuning schedules avoid thawing parameters of the same module in different phases." - # " The following modules span fine-tuning phases (with associated parameters by phase):" - # f" {os.linesep}{phase_mod_intersect}" - # ) - # return warn_msg + # # TODO: only required if offering auto-wrap functionality + # pass # def _fts_auto_configure_model(self) -> None: - # """Apply the ``auto_wrap_policy`` provided by the user and generate the relevant module and parameter-level - # internal mappings that allow the FTS FSDP adapter to translate and orchestrate a fine-tuning schedule. - - # 1. Generate a mapping of fine-tuning schedule phases to associated modules - # 2. Apply the provided ``auto_wrap_policy`` (composed w/ any ``awp_overrides``) to the user's ``LightningModule`` - # 3. After module wrapping, generate parameter-level bi-directional translations between unflat (original) and - # flat (FSDP-flattened) parameters. - - # Raises: - # MisconfigurationException: If the module is already FSDP-wrapped before applying the ``auto_wrap_policy``. - # """ - # for m in self.pl_module.modules(): - # # if the model is already wrapped with FSDP - # if isinstance(m, FullyShardedDataParallel): - # raise MisconfigurationException( - # "The provided model is already wrapped by FSDP. Cannot apply an FSDP auto-wrapping policy along" - # " fine-tuning schedule phase boundaries if the model is already wrapped." - # ) - # self._fts_auto_wrap() - # self._after_configure_model() - + # # TODO: only required if offering auto-wrap functionality + # pass # def _fts_auto_wrap(self) -> None: - # """Apply the provided ``auto_wrap_policy`` within a context-manager that composes any ``awp_overrides`` - # directives with the policy. 
- - # Subsequently, apply activation checkpointing wrappers if requested - # """ - # if self.pls_handle.kwargs.get("auto_wrap_policy", None): - # with self._enable_name_based_overrides(): - # for n, m in self.pl_module.named_children(): - # setattr(self.pl_module, n, wrap(m)) - # else: - # rank_zero_warn( - # "Wrapping the provided model in an outer FSDP module since neither an ``auto_wrap_policy`` " - # "nor a manual ``configure_model`` method for wrapping have been provided. This " - # "configuration is often (but not always) degenerate and unintended by the user." - # ) - # for n, m in self.pl_module.named_children(): - # setattr(self.pl_module, n, wrap(m)) - - # # apply wrappers to enable activation checkpointing if requested - # if self.pls_handle._activation_checkpointing_kwargs: - # _setup_activation_checkpointing( - # module=self.pl_module, activation_checkpointing_kwargs=self.pls_handle._activation_checkpointing_kwargs - # ) + # # TODO: only required if offering auto-wrap functionality + # pass # def _after_configure_model(self) -> None: - # """Generate the parameter-level bi-directional translations the FTS FSDP adapter requires and then execute - # the previously deferred first fine-tuning phase.""" - # assert isinstance(self.fts_handle.ft_schedule, Dict) # TODO: move/consolidate ft_schedule assertions - # self._init_fsdp_param_map() - # self._maybe_set_bn_track_running_stats(0) - # _, self.fts_handle._fts_state._curr_thawed_params = self.exec_ft_phase( - # self.pl_module, - # thaw_pl=self.fts_optim_transform(self.fts_handle.ft_schedule[0]["params"]), - # init_thaw=True, - # ) - - # def _init_fsdp_param_map(self) -> None: - # """Generate parameter-level bi-directional translations between unflat (original) and flat (FSDP-flattened) - # parameters.""" - # self._fsdp_flat_to_unflat_mapping = _get_param_to_fqns(self.pl_module) - # self._fsdp_unflat_to_flat_mapping = { - # up: fpn for fpn, upl in self._fsdp_flat_to_unflat_mapping.items() for up in upl - # } - - # def _wrapped_configure_model(self, csm_func: Callable) -> Callable: - # """If the user has overridden ``configure_model`` in their ``LightningModule``, wrap the user's - # explicit wrapping method with the required - # :class:`~finetuning_scheduler.strategy_adapters.FSDPStrategyAdapter` methods. - - # Args: - # csm_func (Callable): The user's overridden ``LightningModule.configure_model`` method - - # Returns: - # Callable: The user's overridden ``LightningModule.configure_model`` method wrapped with this - # adapter's internal implementation methods. - # """ - - # @wraps(csm_func) - # def wrapped_func() -> None: - # csm_func() - # self._after_configure_model() - - # return wrapped_func - - @contextmanager - def _enable_name_based_overrides(self) -> Generator: - """A context manager that enables name-driven overriding of a given ``auto_wrap_policy`` with a list of - module names. + # # TODO: only required if offering auto-wrap functionality + # pass - The composition of module name-based wrapping directives with a given ``auto_wrap_policy`` is achieved here by: - 1. Generating an object id-based module name mapping lambda and passing it to the standard - ``lambda_auto_wrap_policy``. - 2. Composing the user's provided ``auto_wrap_policy`` with the above name-based policy using the standard - ``_or_policy``. 
+ # def _wrapped_configure_model(self) -> None: #csm_func: Callable) -> Callable: + # # TODO: only required if offering auto-wrap functionality + # pass - Yields: - Generator: A wrapping context that applies the provided ``auto_wrap_policy`` along with any user specified - name-based complements to that policy. - """ - # auto_wrap_policy_handle = _ConfigAutoWrap.kwargs.pop("auto_wrap_policy", None) - # override_ids = [id(m) for n, m in self.pl_module.named_modules() if n in self.awp_overrides] - # name_based_override_policy: Union[NameDrivenPolicy, Callable] - # if isinstance(auto_wrap_policy_handle, _Policy): - # name_based_override_policy = NameDrivenPolicy(auto_wrap_policy_handle, override_ids=override_ids) - # else: # Callable policy implementation path - # name_driven_policy = partial(lambda_auto_wrap_policy, lambda_fn=lambda m: id(m) in override_ids) - # name_based_override_policy = partial(_or_policy, policies=[auto_wrap_policy_handle, name_driven_policy]) - # _ConfigAutoWrap.kwargs["auto_wrap_policy"] = name_based_override_policy - # try: - # yield - # finally: - # _ConfigAutoWrap.kwargs["auto_wrap_policy"] = auto_wrap_policy_handle + # @contextmanager + # def _enable_name_based_overrides(self) -> Generator: + # TODO: only required if offering auto-wrap functionality + pass # TODO: just a stub for testing rn - @override - def _get_target_bn_modules(self, schedule_phase: int) -> List: - """Enumerate the :external+torch:class:`~torch.nn.modules.batchnorm._BatchNorm` modules for a given - schedule phase. - - Args: - schedule_phase (int): The phase of the schedule to evaluate. + # @override + # def _get_target_bn_modules(self, schedule_phase: int) -> List: + # """Enumerate the :external+torch:class:`~torch.nn.modules.batchnorm._BatchNorm` modules for a given + # schedule phase. - Returns: - List[Tuple[str, torch.nn.modules.batchnorm._BatchNorm]]: A list of tuples containing the names and - (possibly FSDP wrapped) instances of `BatchNorm` modules associated with a given schedule phase. - """ - target_bn_modules = [] - # for m_name in self.scheduled_mod_lists[schedule_phase]: - # mod = self.pl_module.get_submodule(m_name) - # if isinstance(mod, torch.nn.modules.batchnorm._BatchNorm): - # target_bn_modules.append((m_name, mod)) - # # TODO: once 2.0 is no longer supported, switch to using FSDP_WRAPPED_MODULE constant here - # elif orig_mod := getattr(mod, '_fsdp_wrapped_module', None): - # if isinstance(orig_mod, torch.nn.modules.batchnorm._BatchNorm): - # target_bn_modules.append((m_name, orig_mod)) - return target_bn_modules + # Args: + # schedule_phase (int): The phase of the schedule to evaluate. - # fts_optim_inspect = partialmethod(fts_optim_transform, inspect_only=True) + # Returns: + # List[Tuple[str, torch.nn.modules.batchnorm._BatchNorm]]: A list of tuples containing the names and + # (possibly FSDP wrapped) instances of `BatchNorm` modules associated with a given schedule phase. 
+ # """ + # target_bn_modules = [] + # # for m_name in self.scheduled_mod_lists[schedule_phase]: + # # mod = self.pl_module.get_submodule(m_name) + # # if isinstance(mod, torch.nn.modules.batchnorm._BatchNorm): + # # target_bn_modules.append((m_name, mod)) + # # # TODO: once 2.0 is no longer supported, switch to using FSDP_WRAPPED_MODULE constant here + # # elif orig_mod := getattr(mod, '_fsdp_wrapped_module', None): + # # if isinstance(orig_mod, torch.nn.modules.batchnorm._BatchNorm): + # # target_bn_modules.append((m_name, orig_mod)) + # return target_bn_modules diff --git a/src/fts_examples/stable/patching/dep_patch_shim.py b/src/fts_examples/stable/patching/dep_patch_shim.py index 7945e6c..8622b79 100644 --- a/src/fts_examples/stable/patching/dep_patch_shim.py +++ b/src/fts_examples/stable/patching/dep_patch_shim.py @@ -1,6 +1,7 @@ import operator import sys import os +from enum import Enum from typing import NamedTuple, Tuple, Callable from fts_examples.stable.patching._patch_utils import lwt_compare_version @@ -35,7 +36,7 @@ def _patch_einsum_strategies(): # In this case fortunately, we only import/call `gen_einsum_strategies` from # `torch.distributed._tensor.ops.matrix_ops`, so only need to patch there. - target_mod = 'torch.distributed._tensor.ops.matrix_ops' + target_mod = 'torch.distributed._tensor.ops._matrix_ops' sys.modules.get(target_mod).__dict__['gen_einsum_strategies'] = gen_einsum_strategies def _patch_unsupported_numpy_arrow_extractor(): @@ -54,7 +55,7 @@ def _patch_triton(): einsum_strategies_patch = DependencyPatch( - condition=(lwt_compare_version("torch", operator.le, "2.4.1"),), + condition=(lwt_compare_version("torch", operator.le, "2.5.1"),), env_flag=OSEnvToggle("ENABLE_FTS_EINSUM_STRATEGY_PATCH", default="0"), function=_patch_einsum_strategies, patched_package='torch', description='Address trivial tp submesh limitation until PyTorch provides upstream fix') @@ -73,11 +74,17 @@ def _patch_triton(): function=_patch_triton, patched_package='pytorch-triton', description='Address `triton` #3564 until PyTorch pins the upstream fix') -_DEFINED_PATCHES = {einsum_strategies_patch, datasets_numpy_extractor_patch, triton_codgen_patch} +class ExpPatch(Enum): + EINSUM_STRATEGIES = einsum_strategies_patch + NUMPY_EXTRACTOR = datasets_numpy_extractor_patch + TRITON_CODEGEN = triton_codgen_patch + +#_DEFINED_PATCHES = {einsum_strategies_patch, datasets_numpy_extractor_patch, triton_codgen_patch} +_DEFINED_PATCHES = set(ExpPatch) _ACTIVE_PATCHES = set() for defined_patch in _DEFINED_PATCHES: - if all(defined_patch.condition) and os.environ.get(defined_patch.env_flag.env_var_name, - defined_patch.env_flag.default) == "1": - defined_patch.function() + if all(defined_patch.value.condition) and os.environ.get(defined_patch.value.env_flag.env_var_name, + defined_patch.value.env_flag.default) == "1": + defined_patch.value.function() _ACTIVE_PATCHES.add(defined_patch) diff --git a/src/fts_examples/stable/patching/patched_einsum_strategies.py b/src/fts_examples/stable/patching/patched_einsum_strategies.py index dec2620..17bbe93 100644 --- a/src/fts_examples/stable/patching/patched_einsum_strategies.py +++ b/src/fts_examples/stable/patching/patched_einsum_strategies.py @@ -1,6 +1,6 @@ from fts_examples.stable.patching._patch_utils import _prepare_module_ctx -globals().update(_prepare_module_ctx('torch.distributed._tensor.ops.basic_strategy', globals())) +globals().update(_prepare_module_ctx('torch.distributed._tensor.ops._einsum_strategy', globals())) # we ignore these for the entire 
file since we're using our global namespace trickeration to patch # ruff: noqa: F821 diff --git a/tests/helpers/boring_models.py b/tests/helpers/boring_models.py index b2b55f1..a4dab0c 100644 --- a/tests/helpers/boring_models.py +++ b/tests/helpers/boring_models.py @@ -137,10 +137,6 @@ def step(self, batch: Tensor) -> Tensor: def training_step(self, batch: Tensor, batch_idx: int) -> STEP_OUTPUT: return {"loss": self.step(batch)} - # def training_step(self, batch, batch_idx): - # output = self(batch) - # loss = self.loss(batch, output) - # return {"loss": loss} def training_step_end(self, training_step_output: STEP_OUTPUT) -> STEP_OUTPUT: return training_step_output @@ -148,25 +144,10 @@ def training_step_end(self, training_step_output: STEP_OUTPUT) -> STEP_OUTPUT: def validation_step(self, batch: Tensor, batch_idx: int) -> Optional[STEP_OUTPUT]: return {"x": self.step(batch)} - # def validation_step(self, batch, batch_idx): - # output = self(batch) - # loss = self.loss(batch, output) - # return {"x": loss} - - # def on_validation_epoch_end(self, outputs) -> None: - # torch.stack([x["x"] for x in outputs]).mean() - - # def test_step(self, batch, batch_idx): - # output = self(batch) - # loss = self.loss(batch, output) - # return {"y": loss} def test_step(self, batch: Tensor, batch_idx: int) -> Optional[STEP_OUTPUT]: return {"y": self.step(batch)} - # def test_epoch_end(self, outputs) -> None: - # torch.stack([x["y"] for x in outputs]).mean() - def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List[LRScheduler]]: optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1) lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) @@ -224,15 +205,6 @@ def __init__(self): super().__init__() self.automatic_optimization = False - # def training_step(self, batch, batch_idx): - # opt = self.optimizers() - # output = self(batch) - # loss = self.loss(batch, output) - # opt.zero_grad() - # self.manual_backward(loss) - # opt.step() - # return loss - def training_step(self, batch: Tensor, batch_idx: int) -> STEP_OUTPUT: opt = self.optimizers() assert isinstance(opt, (Optimizer, LightningOptimizer)) @@ -249,43 +221,6 @@ class FTSWikiText2(WikiText2): def __init__(self, data_dir: Path = Path(_PATH_DATASETS), block_size: int = 32, *args, **kwargs) -> None: super().__init__(data_dir=data_dir, block_size=block_size, *args, **kwargs) - # @property - # def vocab_size(self) -> int: - # return len(self.dictionary) - - # def __len__(self) -> int: - # return len(self.data) // self.block_size - 1 - - # def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: - # start = index * self.block_size - # end = start + self.block_size - # inputs = self.data[start:end] - # target = self.data[(start + 1) : (end + 1)] - # return inputs, target - - # @staticmethod - # def download(destination: Path) -> None: - # if not _REQUESTS_AVAILABLE: - # raise ModuleNotFoundError(str(_REQUESTS_AVAILABLE)) - - # import requests - - # os.makedirs(destination.parent, exist_ok=True) - # url = "https://raw.githubusercontent.com/pytorch/examples/main/word_language_model/data/wikitext-2/train.txt" - # if os.path.exists(destination): - # return - # with open(destination, "w") as f: - # f.write(requests.get(url).text) - - -# class SampledOutput(NamedTuple): -# """Sampled Output Named Tuple. - -# Named tuple object for if we want to output both logits and tokens. 
-# """ - -# tokens: Union[torch.Tensor, str] -# logits: torch.Tensor ################################################################################ # Toy Configurable Transformer (non-TransformerLens) @@ -295,9 +230,6 @@ def __init__(self, data_dir: Path = Path(_PATH_DATASETS), block_size: int = 32, ################################################################################ - - - @dataclass class TestModelArgs: n_layers: int = 2 # 2 @@ -305,23 +237,10 @@ class TestModelArgs: max_seq_len: int = 200 # 10 dim: int = 200 # 10 n_heads: int = 2 - dropout_p: float = 0.2 # 0.1 + dropout_p: float = 0.0 # 0.2 # 0.1 use_attn_mask: bool = True weight_tying: bool = False # True checkpoint_activations: bool = False - #tokenizer: Optional[Callable] = None - #device: Optional[torch.device] = None - #dtype: Optional[torch.dtype] = None - # handle below can be used at runtime to allow this model's `generate` to adapt to various configuration contexts - #ctx_handle: Optional[torch.nn.Module] = None - - # def __post_init__(self): - # if self.ctx_handle: - # # snag potentially useful context references and then delete the handle - # #self.tokenizer = self.tokenizer or self.ctx_handle.it_cfg.tokenizer - # self.device = self.device or self.ctx_handle.device - # self.dtype = self.dtype or self.ctx_handle.torch_dtype - # del self.ctx_handle class Attention(torch.nn.Module): diff --git a/tests/helpers/runif.py b/tests/helpers/runif.py index e98cc9b..d4b7121 100644 --- a/tests/helpers/runif.py +++ b/tests/helpers/runif.py @@ -13,7 +13,7 @@ import os import re import sys -from typing import Optional +from typing import Optional, Set import pytest import torch @@ -21,15 +21,17 @@ from lightning.pytorch.strategies.deepspeed import _DEEPSPEED_AVAILABLE from packaging.version import Version from pkg_resources import get_distribution +from fts_examples.stable.patching.dep_patch_shim import ExpPatch, _ACTIVE_PATCHES EXTENDED_VER_PAT = re.compile(r"([0-9]+\.){2}[0-9]+") # RunIf aliases RUNIF_MAP = { - "min2_4": {"min_torch": "2.4.0"}, + "min2_5": {"min_torch": "2.5.0"}, "min2_2": {"min_torch": "2.2.0"}, "max3_12_min2_3": {"max_python": "3.12", "min_torch": "2.3.0"}, "max3_12_min2_2": {"max_python": "3.12", "min_torch": "2.2.0"}, + "einsum_exp": {"exp_patch": {ExpPatch.EINSUM_STRATEGIES}, "min_torch": "2.5.0"}, } @@ -58,6 +60,7 @@ def __new__( standalone: bool = False, deepspeed: bool = False, slow: bool = False, + exp_patch: Optional[ExpPatch|Set[ExpPatch]] = None, **kwargs, ): """ @@ -74,6 +77,7 @@ def __new__( standalone: Mark the test as standalone, our CI will run it in a separate process. This requires that the ``PL_RUN_STANDALONE_TESTS=1`` environment variable is set. deepspeed: Require that microsoft/DeepSpeed is installed. + exp_patch: Require that a given experimental patch is installed. slow: Mark the test as slow, our CI will run it in a separate job. This requires that the ``PL_RUN_SLOW_TESTS=1`` environment variable is set. **kwargs: Any :class:`pytest.mark.skipif` keyword arguments. 
@@ -143,6 +147,12 @@ def __new__( conditions.append(not _DEEPSPEED_AVAILABLE) reasons.append("Deepspeed") + if exp_patch: + if not isinstance(exp_patch, Set): + exp_patch = {exp_patch} + conditions.append(not exp_patch.issubset(_ACTIVE_PATCHES)) + reasons.append(f"Required experimental patch configuration {exp_patch} is not active.") + if slow: env_flag = os.getenv("PL_RUN_SLOW_TESTS", "0") conditions.append(env_flag != "1") diff --git a/tests/model_parallel_expected_paths.py b/tests/model_parallel_expected_paths.py index a92ddbe..7231f7c 100644 --- a/tests/model_parallel_expected_paths.py +++ b/tests/model_parallel_expected_paths.py @@ -1,685 +1,198 @@ import torch -from tests.helpers.common import DeviceMeshSummary - ## Expected Test Result Configuration Aliases ## example template result, providing TP weight and FSDP module states you want a test to validate # state_key: ({p_states, fsdp_mod_states}, len(self._fts_state._curr_thawed_params)) -# basic_template_result = { -# 0: ( -# {"p_states": { -# "model.layers.0.feed_forward.w2.weight": {}, -# "model.layers.0.feed_forward.w2.bias": {}, -# "model.layers.1.feed_forward.w2.weight": {}, -# "model.layers.1.feed_forward.w2.bias": {}, -# "model.norm.weight": {}, -# "model.norm.bias": {}, -# "model.output.weight": {}, -# }}, -# 3, -# ), -# 1: ( -# {"p_states": { -# "model.layers.0.feed_forward.w2.weight": {}, -# "model.layers.0.feed_forward.w2.bias": {}, -# "model.layers.1.feed_forward.w2.weight": {}, -# "model.layers.1.feed_forward.w2.bias": {}, -# "model.norm.weight": {}, -# "model.norm.bias": {}, -# "model.output.weight": {}, -# }}, -# 27, -# ), -# 2: ( -# {"p_states": { -# "model.layers.0.feed_forward.w2.weight": {}, -# "model.layers.0.feed_forward.w2.bias": {}, -# "model.layers.1.feed_forward.w2.weight": {}, -# "model.layers.1.feed_forward.w2.bias": {}, -# "model.norm.weight": {}, -# "model.norm.bias": {}, -# "model.output.weight": {}, -# }}, -# 29, -# ), -# } - -# extended_fsdp_template_result = { -# 0: ( -# {"p_states": { -# "model.layers.0.feed_forward.w2.weight": {}, -# "model.layers.0.feed_forward.w2.bias": {}, -# "model.layers.1.feed_forward.w2.weight": {}, -# "model.layers.1.feed_forward.w2.bias": {}, -# "model.norm.weight": {}, -# "model.norm.bias": {}, -# "model.output.weight": {}, -# }, -# "fsdp_mod_states": { -# "model.layers.0": {}, -# "model.layers.1": {}, -# "model.norm": {}, -# "model.output": {}, -# }, -# }, -# 3, -# ), -# 1: ( -# {"p_states": { -# "model.layers.0.feed_forward.w2.weight": {}, -# "model.layers.0.feed_forward.w2.bias": {}, -# "model.layers.1.feed_forward.w2.weight": {}, -# "model.layers.1.feed_forward.w2.bias": {}, -# "model.norm.weight": {}, -# "model.norm.bias": {}, -# "model.output.weight": {}, -# }, -# "fsdp_mod_states": { -# "model.layers.0": {}, -# "model.layers.1": {}, -# "model.norm": {}, -# "model.output": {}, -# }, -# }, -# 27, -# ), -# 2: ( -# {"p_states": { -# "model.layers.0.feed_forward.w2.weight": {}, -# "model.layers.0.feed_forward.w2.bias": {}, -# "model.layers.1.feed_forward.w2.weight": {}, -# "model.layers.1.feed_forward.w2.bias": {}, -# "model.norm.weight": {}, -# "model.norm.bias": {}, -# "model.output.weight": {}, -# }, -# "fsdp_mod_states": { -# "model.layers.0": {}, -# "model.layers.1": {}, -# "model.norm": {}, -# "model.output": {}, -# }, -# }, -# 29, -# ), -# } path_tt_tp_no_fsdp = { - 0: ( - {"p_states": { - "model.layers.0.feed_forward.w2.weight": {}, - "model.layers.0.feed_forward.w2.bias": {}, - "model.layers.1.feed_forward.w2.weight": {}, - 
"model.layers.1.feed_forward.w2.bias": {}, - "model.norm.weight": {}, - "model.norm.bias": {}, - "model.output.weight": {}, - }}, - 3, - ), - 1: ( - {"p_states": { - "model.layers.0.feed_forward.w2.weight": {}, - "model.layers.0.feed_forward.w2.bias": {}, - "model.layers.1.feed_forward.w2.weight": {}, - "model.layers.1.feed_forward.w2.bias": {}, - "model.norm.weight": {}, - "model.norm.bias": {}, - "model.output.weight": {}, - }}, - 27, - ), - 2: ( - {"p_states": { - "model.layers.0.feed_forward.w2.weight": {}, - "model.layers.0.feed_forward.w2.bias": {}, - "model.layers.1.feed_forward.w2.weight": {}, - "model.layers.1.feed_forward.w2.bias": {}, - "model.norm.weight": {}, - "model.norm.bias": {}, - "model.output.weight": {}, - }}, - 29, - ), + 0: ({'p_states': { + 'model.layers.0.feed_forward.w2.weight': {'requires_grad': False, 'is_DTensor': True}, + 'model.layers.0.feed_forward.w2.bias': {'requires_grad': False, 'is_DTensor': True}, + 'model.layers.1.feed_forward.w2.weight': {'requires_grad': False, 'is_DTensor': True}, + 'model.layers.1.feed_forward.w2.bias': {'requires_grad': False, 'is_DTensor': True}, + 'model.norm.weight': {'requires_grad': True, 'is_DTensor': True}, + 'model.norm.bias': {'requires_grad': True, 'is_DTensor': True}, + 'model.output.weight': {'requires_grad': True, 'is_DTensor': True}}}, 3), + 1: ({'p_states': { + 'model.layers.0.feed_forward.w2.weight': {'requires_grad': False, 'is_DTensor': True}, + 'model.layers.0.feed_forward.w2.bias': {'requires_grad': False, 'is_DTensor': True}, + 'model.layers.1.feed_forward.w2.weight': {'requires_grad': True, 'is_DTensor': True}, + 'model.layers.1.feed_forward.w2.bias': {'requires_grad': True, 'is_DTensor': True}, + 'model.norm.weight': {'requires_grad': True, 'is_DTensor': True}, + 'model.norm.bias': {'requires_grad': True, 'is_DTensor': True}, + 'model.output.weight': {'requires_grad': True, 'is_DTensor': True}}}, 15), + 2: ({'p_states': { + 'model.layers.0.feed_forward.w2.weight': {'requires_grad': True, 'is_DTensor': True}, + 'model.layers.0.feed_forward.w2.bias': {'requires_grad': True, 'is_DTensor': True}, + 'model.layers.1.feed_forward.w2.weight': {'requires_grad': True, 'is_DTensor': True}, + 'model.layers.1.feed_forward.w2.bias': {'requires_grad': True, 'is_DTensor': True}, + 'model.norm.weight': {'requires_grad': True, 'is_DTensor': True}, + 'model.norm.bias': {'requires_grad': True, 'is_DTensor': True}, + 'model.output.weight': {'requires_grad': True, 'is_DTensor': True}}}, 29), } path_tt_fsdp_no_tp = { 0: ( - {"p_states": { - "model.layers.0.feed_forward.w2.weight": {"requires_grad": False, "is_DTensor": True}, - "model.layers.0.feed_forward.w2.bias": {"requires_grad": False, "is_DTensor": True}, - "model.layers.1.feed_forward.w2.weight": {"requires_grad": False, "is_DTensor": True}, - "model.layers.1.feed_forward.w2.bias": {"requires_grad": False, "is_DTensor": True}, - "model.norm.weight": {"requires_grad": True, "is_DTensor": True}, - "model.norm.bias": {"requires_grad": True, "is_DTensor": True}, - "model.output.weight": {"requires_grad": True, "is_DTensor": True}, + {'p_states': { + 'model.layers.0.feed_forward.w2.weight': {'requires_grad': False, 'is_DTensor': True}, + 'model.layers.0.feed_forward.w2.bias': {'requires_grad': False, 'is_DTensor': True}, + 'model.layers.1.feed_forward.w2.weight': {'requires_grad': False, 'is_DTensor': True}, + 'model.layers.1.feed_forward.w2.bias': {'requires_grad': False, 'is_DTensor': True}, + 'model.norm.weight': {'requires_grad': True, 'is_DTensor': True}, + 
'model.norm.bias': {'requires_grad': True, 'is_DTensor': True}, + 'model.output.weight': {'requires_grad': True, 'is_DTensor': True} }}, - 3, - ), + 3), 1: ( - {"p_states": { - "model.layers.0.feed_forward.w2.weight": {"requires_grad": True, "is_DTensor": True}, - "model.layers.0.feed_forward.w2.bias": {"requires_grad": True, "is_DTensor": True}, - "model.layers.1.feed_forward.w2.weight": {"requires_grad": True, "is_DTensor": True}, - "model.layers.1.feed_forward.w2.bias": {"requires_grad": True, "is_DTensor": True}, - "model.norm.weight": {"requires_grad": True, "is_DTensor": True}, - "model.norm.bias": {"requires_grad": True, "is_DTensor": True}, - "model.output.weight": {"requires_grad": True, "is_DTensor": True}, + {'p_states': { + 'model.layers.0.feed_forward.w2.weight': {'requires_grad': False, 'is_DTensor': True}, + 'model.layers.0.feed_forward.w2.bias': {'requires_grad': False, 'is_DTensor': True}, + 'model.layers.1.feed_forward.w2.weight': {'requires_grad': True, 'is_DTensor': True}, + 'model.layers.1.feed_forward.w2.bias': {'requires_grad': True, 'is_DTensor': True}, + 'model.norm.weight': {'requires_grad': True, 'is_DTensor': True}, + 'model.norm.bias': {'requires_grad': True, 'is_DTensor': True}, + 'model.output.weight': {'requires_grad': True, 'is_DTensor': True} }}, - 27, - ), + 15), 2: ( - {"p_states": { - "model.layers.0.feed_forward.w2.weight": {"requires_grad": True, "is_DTensor": True}, - "model.layers.0.feed_forward.w2.bias": {"requires_grad": True, "is_DTensor": True}, - "model.layers.1.feed_forward.w2.weight": {"requires_grad": True, "is_DTensor": True}, - "model.layers.1.feed_forward.w2.bias": {"requires_grad": True, "is_DTensor": True}, - "model.norm.weight": {"requires_grad": True, "is_DTensor": True}, - "model.norm.bias": {"requires_grad": True, "is_DTensor": True}, - "model.output.weight": {"requires_grad": True, "is_DTensor": True}, + {'p_states': { + 'model.layers.0.feed_forward.w2.weight': {'requires_grad': True, 'is_DTensor': True}, + 'model.layers.0.feed_forward.w2.bias': {'requires_grad': True, 'is_DTensor': True}, + 'model.layers.1.feed_forward.w2.weight': {'requires_grad': True, 'is_DTensor': True}, + 'model.layers.1.feed_forward.w2.bias': {'requires_grad': True, 'is_DTensor': True}, + 'model.norm.weight': {'requires_grad': True, 'is_DTensor': True}, + 'model.norm.bias': {'requires_grad': True, 'is_DTensor': True}, + 'model.output.weight': {'requires_grad': True, 'is_DTensor': True} }}, - 29, - ), + 29), } -path_ff_fsdp_tp = { - 0: ( - { - "p_states": { - "model.w2.weight": { - "is_DTensor": True, - "requires_grad": False, - "dtype": torch.float32, - "orig_shape": torch.Size([64, 32]), - "local_shape": torch.Size([32, 32]), - "device_mesh": DeviceMeshSummary( - tensor_ndim=2, mesh_ndim=2, mesh_shape=(2, 1), - mesh_dim_names=("data_parallel", "tensor_parallel"), - placement_summary=["shard(dim=0)", "shard(dim=0)"], - ), - }, - "model.w3.weight": { - "is_DTensor": True, - "requires_grad": True, - "dtype": torch.float32, - "orig_shape": torch.Size([2, 64]), - "local_shape": torch.Size([1, 64]), - "device_mesh": DeviceMeshSummary( - tensor_ndim=2, mesh_ndim=2, mesh_shape=(2, 1), - mesh_dim_names=("data_parallel", "tensor_parallel"), - placement_summary=["shard(dim=0)", "shard(dim=1)"], - ), - }, - "model.w3.bias": { - "is_DTensor": True, - "requires_grad": True, - "dtype": torch.float32, - "orig_shape": torch.Size([2]), - "local_shape": torch.Size([1]), - "device_mesh": DeviceMeshSummary( - tensor_ndim=1, mesh_ndim=2, mesh_shape=(2, 1), - 
mesh_dim_names=("data_parallel", "tensor_parallel"), - placement_summary=["shard(dim=0)", "replica"], - ), - }, - }, - "fsdp_mod_states": { - "model.w2": { - "is_fsdp_managed": True, - "prec_policy_summ": (None, None, None, True), - "param_group_summ": [ - ("w2.weight", torch.Size([64, 32]), torch.Size([32, 32])), - ("w2.bias", torch.Size([64]), torch.Size([32])), - ], - }, - "model.w3": { - "is_fsdp_managed": True, - "prec_policy_summ": (None, None, None, True), - "param_group_summ": [ - ("w3.weight", torch.Size([2, 64]), torch.Size([1, 64])), - ("w3.bias", torch.Size([2]), torch.Size([1])), - ], - }, - }, - }, - 2, - ), - 1: ( - { - "p_states": { - "model.w2.weight": { - "is_DTensor": True, - "requires_grad": True, - "dtype": torch.float32, - "orig_shape": torch.Size([64, 32]), - "local_shape": torch.Size([32, 32]), - "device_mesh": DeviceMeshSummary( - tensor_ndim=2, mesh_ndim=2, mesh_shape=(2, 1), - mesh_dim_names=("data_parallel", "tensor_parallel"), - placement_summary=["shard(dim=0)", "shard(dim=0)"], - ), - }, - "model.w3.weight": { - "is_DTensor": True, - "requires_grad": True, - "dtype": torch.float32, - "orig_shape": torch.Size([2, 64]), - "local_shape": torch.Size([1, 64]), - "device_mesh": DeviceMeshSummary( - tensor_ndim=2, mesh_ndim=2, mesh_shape=(2, 1), - mesh_dim_names=("data_parallel", "tensor_parallel"), - placement_summary=["shard(dim=0)", "shard(dim=1)"], - ), - }, - "model.w3.bias": { - "is_DTensor": True, - "requires_grad": True, - "dtype": torch.float32, - "orig_shape": torch.Size([2]), - "local_shape": torch.Size([1]), - "device_mesh": DeviceMeshSummary( - tensor_ndim=1, mesh_ndim=2, mesh_shape=(2, 1), - mesh_dim_names=("data_parallel", "tensor_parallel"), - placement_summary=["shard(dim=0)", "replica"], - ), - }, - }, - "fsdp_mod_states": { - "model.w2": { - "is_fsdp_managed": True, - "prec_policy_summ": (None, None, None, True), - "param_group_summ": [ - ("w2.weight", torch.Size([64, 32]), torch.Size([32, 32])), - ("w2.bias", torch.Size([64]), torch.Size([32])), - ], - }, - "model.w3": { - "is_fsdp_managed": True, - "prec_policy_summ": (None, None, None, True), - "param_group_summ": [ - ("w3.weight", torch.Size([2, 64]), torch.Size([1, 64])), - ("w3.bias", torch.Size([2]), torch.Size([1])), - ], - }, - }, - }, - 4, - ), +path_tt_fsdp_tp = { + 0: ({'p_states': { + 'model.layers.0.feed_forward.w2.weight': {'requires_grad': False, 'is_DTensor': True}, + 'model.layers.0.feed_forward.w2.bias': {'requires_grad': False, 'is_DTensor': True}, + 'model.layers.1.feed_forward.w2.weight': {'requires_grad': False, 'is_DTensor': True}, + 'model.layers.1.feed_forward.w2.bias': {'requires_grad': False, 'is_DTensor': True}, + 'model.norm.weight': {'requires_grad': True, 'is_DTensor': True}, + 'model.norm.bias': {'requires_grad': True, 'is_DTensor': True}, + 'model.output.weight': {'requires_grad': True, 'is_DTensor': True}}, + 'fsdp_mod_states': { + 'model.layers.0': {'is_fsdp_managed': True, 'is_fsdp_composed': False}, + 'model.layers.1': { + 'is_fsdp_managed': True, 'is_fsdp_composed': True, + 'prec_policy_summ': (None, None, None, True), + 'param_group_summ': [ + (None, torch.Size([200]), torch.Size([100])), + (None, torch.Size([200]), torch.Size([100])), + (None, torch.Size([200, 200]), torch.Size([100, 200])), + (None, torch.Size([200, 200]), torch.Size([100, 200])), + (None, torch.Size([200, 200]), torch.Size([100, 200])), + (None, torch.Size([200, 200]), torch.Size([100, 200])), + (None, torch.Size([200]), torch.Size([100])), + (None, torch.Size([200]), torch.Size([100])), 
+ (None, torch.Size([800, 200]), torch.Size([400, 200])), + (None, torch.Size([800]), torch.Size([400])), + (None, torch.Size([200, 800]), torch.Size([100, 800])), + (None, torch.Size([200]), torch.Size([100]))]}, + 'model.norm': {'is_fsdp_managed': True, 'is_fsdp_composed': True, + 'prec_policy_summ': (None, None, None, True), + 'param_group_summ': [ + (None, torch.Size([200]), torch.Size([100])), + (None, torch.Size([200]), torch.Size([100]))]}, + 'model.output': {'is_fsdp_managed': True, 'is_fsdp_composed': True, + 'prec_policy_summ': (None, None, None, True), + 'param_group_summ': [ + (None, torch.Size([33278, 200]), torch.Size([16639, 200]))]}}}, + 3), + 1: ({'p_states': { + 'model.layers.0.feed_forward.w2.weight': {'requires_grad': False, 'is_DTensor': True}, + 'model.layers.0.feed_forward.w2.bias': {'requires_grad': False, 'is_DTensor': True}, + 'model.layers.1.feed_forward.w2.weight': {'requires_grad': True, 'is_DTensor': True}, + 'model.layers.1.feed_forward.w2.bias': {'requires_grad': True, 'is_DTensor': True}, + 'model.norm.weight': {'requires_grad': True, 'is_DTensor': True}, + 'model.norm.bias': {'requires_grad': True, 'is_DTensor': True}, + 'model.output.weight': {'requires_grad': True, 'is_DTensor': True}}, + 'fsdp_mod_states': { + 'model.layers.0': {'is_fsdp_managed': True, 'is_fsdp_composed': False}, + 'model.layers.1': { + 'is_fsdp_managed': True, 'is_fsdp_composed': True, + 'prec_policy_summ': (None, None, None, True), + 'param_group_summ': [ + ('layers.1.attention_norm.weight', torch.Size([200]), torch.Size([100])), + ('layers.1.attention_norm.bias', torch.Size([200]), torch.Size([100])), + ('layers.1.attention.wq.weight', torch.Size([200, 200]), torch.Size([100, 200])), + ('layers.1.attention.wk.weight', torch.Size([200, 200]), torch.Size([100, 200])), + ('layers.1.attention.wv.weight', torch.Size([200, 200]), torch.Size([100, 200])), + ('layers.1.attention.wo.weight', torch.Size([200, 200]), torch.Size([100, 200])), + ('layers.1.ffn_norm.weight', torch.Size([200]), torch.Size([100])), + ('layers.1.ffn_norm.bias', torch.Size([200]), torch.Size([100])), + ('layers.1.feed_forward.w1.weight', torch.Size([800, 200]), torch.Size([400, 200])), + ('layers.1.feed_forward.w1.bias', torch.Size([800]), torch.Size([400])), + ('layers.1.feed_forward.w2.weight', torch.Size([200, 800]), torch.Size([100, 800])), + ('layers.1.feed_forward.w2.bias', torch.Size([200]), torch.Size([100]))]}, + 'model.norm': {'is_fsdp_managed': True, 'is_fsdp_composed': True, + 'prec_policy_summ': (None, None, None, True), + 'param_group_summ': [ + ('norm.weight', torch.Size([200]), torch.Size([100])), + ('norm.bias', torch.Size([200]), torch.Size([100]))]}, + 'model.output': {'is_fsdp_managed': True, 'is_fsdp_composed': True, + 'prec_policy_summ': (None, None, None, True), + 'param_group_summ': [ + ('output.weight', torch.Size([33278, 200]), torch.Size([16639, 200]))]}}}, + 15), 2: ( { - "p_states": { - "model.w2.weight": { - "is_DTensor": True, - "requires_grad": True, - "dtype": torch.float32, - "orig_shape": torch.Size([64, 32]), - "local_shape": torch.Size([32, 32]), - "device_mesh": DeviceMeshSummary( - tensor_ndim=2, mesh_ndim=2, mesh_shape=(2, 1), - mesh_dim_names=("data_parallel", "tensor_parallel"), - placement_summary=["shard(dim=0)", "shard(dim=0)"], - ), - }, - "model.w3.weight": { - "is_DTensor": True, - "requires_grad": True, - "dtype": torch.float32, - "orig_shape": torch.Size([2, 64]), - "local_shape": torch.Size([1, 64]), - "device_mesh": DeviceMeshSummary( - tensor_ndim=2, mesh_ndim=2, 
mesh_shape=(2, 1), - mesh_dim_names=("data_parallel", "tensor_parallel"), - placement_summary=["shard(dim=0)", "shard(dim=1)"], - ), - }, - "model.w3.bias": { - "is_DTensor": True, - "requires_grad": True, - "dtype": torch.float32, - "orig_shape": torch.Size([2]), - "local_shape": torch.Size([1]), - "device_mesh": DeviceMeshSummary( - tensor_ndim=1, mesh_ndim=2, mesh_shape=(2, 1), - mesh_dim_names=("data_parallel", "tensor_parallel"), - placement_summary=["shard(dim=0)", "replica"], - ), - }, + 'p_states': { + 'model.layers.0.feed_forward.w2.weight': {'requires_grad': True, 'is_DTensor': True}, + 'model.layers.0.feed_forward.w2.bias': {'requires_grad': True, 'is_DTensor': True}, + 'model.layers.1.feed_forward.w2.weight': {'requires_grad': True, 'is_DTensor': True}, + 'model.layers.1.feed_forward.w2.bias': {'requires_grad': True, 'is_DTensor': True}, + 'model.norm.weight': {'requires_grad': True, 'is_DTensor': True}, + 'model.norm.bias': {'requires_grad': True, 'is_DTensor': True}, + 'model.output.weight': {'requires_grad': True, 'is_DTensor': True}, }, - "fsdp_mod_states": { - "model.w2": { - "is_fsdp_managed": True, - "prec_policy_summ": (None, None, None, True), - "param_group_summ": [ - ("w2.weight", torch.Size([64, 32]), torch.Size([32, 32])), - ("w2.bias", torch.Size([64]), torch.Size([32])), + 'fsdp_mod_states': { + 'model.layers.0': { + 'is_fsdp_managed': True, + 'is_fsdp_composed': False + }, + 'model.layers.1': { + 'is_fsdp_managed': True, + 'is_fsdp_composed': True, + 'prec_policy_summ': (None, None, None, True), + 'param_group_summ': [ + ('layers.1.attention_norm.weight', torch.Size([200]), torch.Size([100])), + ('layers.1.attention_norm.bias', torch.Size([200]), torch.Size([100])), + ('layers.1.attention.wq.weight', torch.Size([200, 200]), torch.Size([100, 200])), + ('layers.1.attention.wk.weight', torch.Size([200, 200]), torch.Size([100, 200])), + ('layers.1.attention.wv.weight', torch.Size([200, 200]), torch.Size([100, 200])), + ('layers.1.attention.wo.weight', torch.Size([200, 200]), torch.Size([100, 200])), + ('layers.1.ffn_norm.weight', torch.Size([200]), torch.Size([100])), + ('layers.1.ffn_norm.bias', torch.Size([200]), torch.Size([100])), + ('layers.1.feed_forward.w1.weight', torch.Size([800, 200]), torch.Size([400, 200])), + ('layers.1.feed_forward.w1.bias', torch.Size([800]), torch.Size([400])), + ('layers.1.feed_forward.w2.weight', torch.Size([200, 800]), torch.Size([100, 800])), + ('layers.1.feed_forward.w2.bias', torch.Size([200]), torch.Size([100])), ], }, - "model.w3": { - "is_fsdp_managed": True, - "prec_policy_summ": (None, None, None, True), - "param_group_summ": [ - ("w3.weight", torch.Size([2, 64]), torch.Size([1, 64])), - ("w3.bias", torch.Size([2]), torch.Size([1])), + 'model.norm': { + 'is_fsdp_managed': True, + 'is_fsdp_composed': True, + 'prec_policy_summ': (None, None, None, True), + 'param_group_summ': [ + ('norm.weight', torch.Size([200]), torch.Size([100])), + ('norm.bias', torch.Size([200]), torch.Size([100])), ], }, - }, - }, - 6, - ), -} - -path_ff_fsdp_no_tp = { - 0: ( - { - "p_states": { - "model.w2.weight": { - "is_DTensor": True, - "requires_grad": False, - "dtype": torch.float32, - "orig_shape": torch.Size([64, 32]), - "local_shape": torch.Size([32, 32]), - "device_mesh": DeviceMeshSummary( - tensor_ndim=2, - mesh_ndim=1, - mesh_shape=(2,), - mesh_dim_names=("data_parallel",), - placement_summary=["shard(dim=0)"], - ), - }, - "model.w3.weight": { - "is_DTensor": True, - "requires_grad": True, - "dtype": torch.float32, - "orig_shape": 
torch.Size([2, 64]), - "local_shape": torch.Size([1, 64]), - "device_mesh": DeviceMeshSummary( - tensor_ndim=2, - mesh_ndim=1, - mesh_shape=(2,), - mesh_dim_names=("data_parallel",), - placement_summary=["shard(dim=0)"], - ), - }, - }, - "fsdp_mod_states": { - "model.w2": { - "is_fsdp_managed": True, - "prec_policy_summ": (None, None, None, True), - "param_group_summ": [ - ("w2.weight", torch.Size([64, 32]), torch.Size([32, 32])), - ("w2.bias", torch.Size([64]), torch.Size([32])), - ], - }, - "model.w3": { - "is_fsdp_managed": True, - "prec_policy_summ": (None, None, None, True), - "param_group_summ": [ - ("w3.weight", torch.Size([2, 64]), torch.Size([1, 64])), - ("w3.bias", torch.Size([2]), torch.Size([1])), - ], - }, - }, - }, - 2, - ), - 1: ( - { - "p_states": { - "model.w2.weight": { - "is_DTensor": True, - "requires_grad": True, - "dtype": torch.float32, - "orig_shape": torch.Size([64, 32]), - "local_shape": torch.Size([32, 32]), - "device_mesh": DeviceMeshSummary( - tensor_ndim=2, - mesh_ndim=1, - mesh_shape=(2,), - mesh_dim_names=("data_parallel",), - placement_summary=["shard(dim=0)"], - ), - }, - "model.w3.weight": { - "is_DTensor": True, - "requires_grad": True, - "dtype": torch.float32, - "orig_shape": torch.Size([2, 64]), - "local_shape": torch.Size([1, 64]), - "device_mesh": DeviceMeshSummary( - tensor_ndim=2, - mesh_ndim=1, - mesh_shape=(2,), - mesh_dim_names=("data_parallel",), - placement_summary=["shard(dim=0)"], - ), - }, - }, - "fsdp_mod_states": { - "model.w2": { - "is_fsdp_managed": True, - "prec_policy_summ": (None, None, None, True), - "param_group_summ": [ - ("w2.weight", torch.Size([64, 32]), torch.Size([32, 32])), - ("w2.bias", torch.Size([64]), torch.Size([32])), - ], - }, - "model.w3": { - "is_fsdp_managed": True, - "prec_policy_summ": (None, None, None, True), - "param_group_summ": [ - ("w3.weight", torch.Size([2, 64]), torch.Size([1, 64])), - ("w3.bias", torch.Size([2]), torch.Size([1])), - ], - }, - }, - }, - 4, - ), - 2: ( - { - "p_states": { - "model.w2.weight": { - "is_DTensor": True, - "requires_grad": True, - "dtype": torch.float32, - "orig_shape": torch.Size([64, 32]), - "local_shape": torch.Size([32, 32]), - "device_mesh": DeviceMeshSummary( - tensor_ndim=2, - mesh_ndim=1, - mesh_shape=(2,), - mesh_dim_names=("data_parallel",), - placement_summary=["shard(dim=0)"], - ), - }, - "model.w3.weight": { - "is_DTensor": True, - "requires_grad": True, - "dtype": torch.float32, - "orig_shape": torch.Size([2, 64]), - "local_shape": torch.Size([1, 64]), - "device_mesh": DeviceMeshSummary( - tensor_ndim=2, - mesh_ndim=1, - mesh_shape=(2,), - mesh_dim_names=("data_parallel",), - placement_summary=["shard(dim=0)"], - ), - }, - }, - "fsdp_mod_states": { - "model.w2": { - "is_fsdp_managed": True, - "prec_policy_summ": (None, None, None, True), - "param_group_summ": [ - ("w2.weight", torch.Size([64, 32]), torch.Size([32, 32])), - ("w2.bias", torch.Size([64]), torch.Size([32])), - ], - }, - "model.w3": { - "is_fsdp_managed": True, - "prec_policy_summ": (None, None, None, True), - "param_group_summ": [ - ("w3.weight", torch.Size([2, 64]), torch.Size([1, 64])), - ("w3.bias", torch.Size([2]), torch.Size([1])), - ], + 'model.output': { + 'is_fsdp_managed': True, + 'is_fsdp_composed': True, + 'prec_policy_summ': (None, None, None, True), + 'param_group_summ': [('output.weight', torch.Size([33278, 200]), torch.Size([16639, 200]))], }, }, }, - 6, - ), -} - -path_ff_tp_no_fsdp = { - 0: ( - { - "p_states": { - "model.w2.weight": { - "is_DTensor": True, - "requires_grad": False, 
- "dtype": torch.float32, - "orig_shape": torch.Size([64, 32]), - "local_shape": torch.Size([32, 32]), - "device_mesh": DeviceMeshSummary( - tensor_ndim=2, - mesh_ndim=1, - mesh_shape=(2,), - mesh_dim_names=("tensor_parallel",), - placement_summary=["shard(dim=0)"], - ), - }, - "model.w3.weight": { - "is_DTensor": True, - "requires_grad": True, - "dtype": torch.float32, - "orig_shape": torch.Size([2, 64]), - "local_shape": torch.Size([2, 32]), - "device_mesh": DeviceMeshSummary( - tensor_ndim=2, - mesh_ndim=1, - mesh_shape=(2,), - mesh_dim_names=("tensor_parallel",), - placement_summary=["shard(dim=1)"], - ), - }, - "model.w3.bias": { - "is_DTensor": True, - "requires_grad": True, - "dtype": torch.float32, - "orig_shape": torch.Size([2]), - "local_shape": torch.Size([2]), - "device_mesh": DeviceMeshSummary( - tensor_ndim=1, - mesh_ndim=1, - mesh_shape=(2,), - mesh_dim_names=("tensor_parallel",), - placement_summary=["replica"], - ), - }, - }, - "fsdp_mod_states": {}, - }, - 2, - ), - 1: ( - { - "p_states": { - "model.w2.weight": { - "is_DTensor": True, - "requires_grad": True, - "dtype": torch.float32, - "orig_shape": torch.Size([64, 32]), - "local_shape": torch.Size([32, 32]), - "device_mesh": DeviceMeshSummary( - tensor_ndim=2, - mesh_ndim=1, - mesh_shape=(2,), - mesh_dim_names=("tensor_parallel",), - placement_summary=["shard(dim=0)"], - ), - }, - "model.w3.weight": { - "is_DTensor": True, - "requires_grad": True, - "dtype": torch.float32, - "orig_shape": torch.Size([2, 64]), - "local_shape": torch.Size([2, 32]), - "device_mesh": DeviceMeshSummary( - tensor_ndim=2, - mesh_ndim=1, - mesh_shape=(2,), - mesh_dim_names=("tensor_parallel",), - placement_summary=["shard(dim=1)"], - ), - }, - "model.w3.bias": { - "is_DTensor": True, - "requires_grad": True, - "dtype": torch.float32, - "orig_shape": torch.Size([2]), - "local_shape": torch.Size([2]), - "device_mesh": DeviceMeshSummary( - tensor_ndim=1, - mesh_ndim=1, - mesh_shape=(2,), - mesh_dim_names=("tensor_parallel",), - placement_summary=["replica"], - ), - }, - }, - "fsdp_mod_states": {}, - }, - 4, - ), - 2: ( - { - "p_states": { - "model.w2.weight": { - "is_DTensor": True, - "requires_grad": True, - "dtype": torch.float32, - "orig_shape": torch.Size([64, 32]), - "local_shape": torch.Size([32, 32]), - "device_mesh": DeviceMeshSummary( - tensor_ndim=2, - mesh_ndim=1, - mesh_shape=(2,), - mesh_dim_names=("tensor_parallel",), - placement_summary=["shard(dim=0)"], - ), - }, - "model.w3.weight": { - "is_DTensor": True, - "requires_grad": True, - "dtype": torch.float32, - "orig_shape": torch.Size([2, 64]), - "local_shape": torch.Size([2, 32]), - "device_mesh": DeviceMeshSummary( - tensor_ndim=2, - mesh_ndim=1, - mesh_shape=(2,), - mesh_dim_names=("tensor_parallel",), - placement_summary=["shard(dim=1)"], - ), - }, - "model.w3.bias": { - "is_DTensor": True, - "requires_grad": True, - "dtype": torch.float32, - "orig_shape": torch.Size([2]), - "local_shape": torch.Size([2]), - "device_mesh": DeviceMeshSummary( - tensor_ndim=1, - mesh_ndim=1, - mesh_shape=(2,), - mesh_dim_names=("tensor_parallel",), - placement_summary=["replica"], - ), - }, - }, - "fsdp_mod_states": {}, - }, - 6, + 29, ), } diff --git a/tests/test_finetuning_scheduler_callback.py b/tests/test_finetuning_scheduler_callback.py index 3ebadbb..b274e14 100644 --- a/tests/test_finetuning_scheduler_callback.py +++ b/tests/test_finetuning_scheduler_callback.py @@ -337,12 +337,12 @@ def on_fit_start(self, trainer, pl_module) -> None: raise SystemExit(0) def state_dict(self) -> Dict[str, 
Any]: - self.best_ckpt_test_weight = self.pl_module._modules["layer"]._modules["3"].bias.data.detach().clone() + self.best_ckpt_test_weight = self.pl_module._modules["model"]._modules["3"].bias.data.detach().clone() return super().state_dict() def restore_best_ckpt(self) -> None: super().restore_best_ckpt() - assert torch.equal(self.pl_module._modules["layer"]._modules["3"].bias.data, self.best_ckpt_test_weight) + assert torch.equal(self.pl_module._modules["model"]._modules["3"].bias.data, self.best_ckpt_test_weight) self.restored_best_cnt += 1 def on_train_epoch_start(self, trainer, pl_module): @@ -742,11 +742,6 @@ def boring_ft_schedule(tmpdir_factory) -> Tuple[Path, Dict]: 2: {'params': ['model.lin_base.bias', 'model.lin_base.weight']}} bn_sched_dict[0]["max_transition_epoch"] = 1 bn_sched_dict[1]["max_transition_epoch"] = 2 - # mp_tp_sched_dict = {0: {'params': ['model.w3.bias', 'model.w3.weight']}, - # 1: {'params': ['model.w2.bias', 'model.w2.weight']}, - # 2: {'params': ['model.w1.bias', 'model.w1.weight']}} - # mp_tp_sched_dict[0]["max_transition_epoch"] = 1 - # mp_tp_sched_dict[1]["max_transition_epoch"] = 2 return ( unmod_schedule_file, mod_sched_dict, diff --git a/tests/test_fsdp.py b/tests/test_fsdp.py index 4637e75..2e87b7d 100644 --- a/tests/test_fsdp.py +++ b/tests/test_fsdp.py @@ -37,8 +37,6 @@ EXPECTED_WARNS, ExplicitLossFTSCheckpoint, FinetuningSchedulerBoringModel, - get_fts, - nones, TestFinetuningScheduler, get_sched_fixture_tmpdir, ) @@ -624,9 +622,9 @@ def policy(self): # awp_overrides configuration aliases awp_5_9 = {"awp_overrides": ["model.9", "model.5"]} -awp_1 = {"awp_overrides": ["l.*yer.1"]} +awp_1 = {"awp_overrides": ["m.*del.1"]} awp_7 = {"awp_overrides": ["model.7"]} -awp_7_8 = {"awp_overrides": ["l.*yer.8", "model.7"]} +awp_7_8 = {"awp_overrides": ["m.*del.8", "model.7"]} # FSDP strategy configuration aliases act_ckpt_cfg = {"activation_checkpointing_policy": {torch.nn.Linear}, **DISABLE_USE_ORIG} @@ -828,7 +826,7 @@ def policy(self): ] -@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=False) +@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True) @pytest.mark.parametrize( "model_cfg_key, model_cls, auto_wrap_policy, use_precision, ft_sched_idx, model_cfg, strategy_adapter_cfg, fts_cfg,\ trainer_cfg, strategy_cfg", @@ -898,7 +896,7 @@ def test_fsdp_multi_gpus_resume(tmpdir, recwarn, fsdp_ft_schedules, fsdp_ckpt, m def test_fsdp_get_bn_unwrapped(): """Conservative (end-to-end) test for FTS training resumption with FSDP.""" test_adapter = FSDPStrategyAdapter() - test_adapter.scheduled_mod_lists = {0: ['model.0']} + test_adapter.scheduled_mod_lists = {0: ['layer.0']} test_module = torch.nn.Module() test_module.layer = torch.nn.Sequential(torch.nn.BatchNorm1d(32)) setattr(test_adapter, 'fts_handle', FinetuningScheduler()) diff --git a/tests/test_model_parallel.py b/tests/test_model_parallel.py index abec4e5..d0a2429 100644 --- a/tests/test_model_parallel.py +++ b/tests/test_model_parallel.py @@ -9,7 +9,6 @@ import pytest import torch -from torch.distributed._tensor.debug import CommDebugMode from lightning.pytorch import seed_everything, Trainer from lightning.pytorch.plugins.precision.fsdp import FSDPPrecision from lightning.pytorch.strategies import ModelParallelStrategy @@ -23,7 +22,6 @@ from torch.distributed.tensor.parallel import loss_parallel from torch.distributed._composable.fsdp import FSDPModule from torch.distributed._tensor import Replicate, Shard -from torch.distributed.device_mesh import DeviceMesh from 
torch.distributed.tensor.parallel import ( ColwiseParallel, PrepareModuleInput, @@ -33,12 +31,12 @@ ) from finetuning_scheduler import FinetuningScheduler, FTSCheckpoint, FTSEarlyStopping -from finetuning_scheduler.strategy_adapters import FSDPStrategyAdapter +from finetuning_scheduler.strategy_adapters import ModelParallelStrategyAdapter from tests.helpers.boring_models import FTSToyTransformer, TestModelArgs, FTSWikiText2 from tests.helpers.common import (ExpectedResults, fts_check_warns, pytest_param_factory, get_fts, default_fts_sanity_chk, DeviceMeshSummary) -from tests.model_parallel_expected_paths import (path_ff_tp_no_fsdp, path_ff_fsdp_no_tp, path_ff_fsdp_tp, +from tests.model_parallel_expected_paths import (path_tt_fsdp_tp, path_tt_fsdp_no_tp, path_tt_tp_no_fsdp) from tests.helpers.runif import RunIf @@ -51,21 +49,6 @@ ) FTS_GLOBAL_STATE_LOG_MODE = os.environ.get("FTS_GLOBAL_STATE_LOG_MODE", "0") == "1" - - -if torch.distributed.is_available(): - from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel - from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision - from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, wrap -else: - FullyShardedDataParallel = None # type: ignore[misc,assignment] - MixedPrecision = None # type: ignore[misc,assignment] - BackwardPrefetch = None # type: ignore[misc,assignment] - CPUOffload = None # type: ignore[misc,assignment] - size_based_auto_wrap_policy = object - wrap = object - - MODEL_PARALLEL_BASE_WARNS = copy(EXPECTED_WARNS) additional_model_parallel_warns = [ "model contains an instance of `UninitializedParameter`", @@ -76,85 +59,17 @@ "torch.cpu.amp.autocast", # required as of PT 2.4 "FSDP.state_dict_type", # temporarily required until Lightning uses new FSDP state dict API with PT 2.4 "Final phase max_transition_epoch", # required for some experimental dtensor tests with PT 2.4 + "interactive_bk attribute", # TODO: remove, only for temporary debugging with torch from source ] MODEL_PARALLEL_BASE_WARNS.extend(additional_model_parallel_warns) MODEL_PARALLEL_DYNAMO_EXPECTED_WARNS = [ "Final phase max_transition_epoch", # still required for PyTorch/Lightning <=2.4 ] - -# def _parallelize_base_model_parallel_tp(model, device_mesh): -# from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module - -# tp_mesh = device_mesh["tensor_parallel"] -# tp_plan = { -# "w1": ColwiseParallel(), -# "w2": ColwiseParallel(), -# "w3": RowwiseParallel(), -# } -# parallelize_module(model, tp_mesh, tp_plan) -# return model - -# def _parallelize_feed_forward_tp(model, device_mesh): -# from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module - -# tp_mesh = device_mesh["tensor_parallel"] -# tp_plan = { -# "w1": ColwiseParallel(), -# "w2": ColwiseParallel(), -# "w3": RowwiseParallel(), -# } -# parallelize_module(model, tp_mesh, tp_plan) -# return model - - -# def _parallelize_base_model_parallel_fsdp2(model, device_mesh): -# from torch.distributed._composable.fsdp.fully_shard import fully_shard - -# dp_mesh = device_mesh["data_parallel"] -# assert dp_mesh.ndim == 1 # Hybrid-sharding not supported - -# # Fully-shard each layer -# fully_shard(model.w1, mesh=dp_mesh) -# fully_shard(model.w2, mesh=dp_mesh) -# fully_shard(model.w3, mesh=dp_mesh) - -# # TODO: Re-enable activation checkpointing -# # Currently, state dict keys get prefixed with '_checkpoint_wrapper' in the keys -# # 
which leads to mismatches when loading weights into a checkpoint-wrapped module. -# # PyTorch should handle this automatically. - -# # model = checkpoint_wrapper(model) - -# return model - - -# def _parallelize_base_model_parallel_fsdp2_tp(model, device_mesh): -# model = _parallelize_base_model_parallel_tp(model, device_mesh) -# model = _parallelize_base_model_parallel_fsdp2(model, device_mesh) -# return model - -# class DeviceMeshSummary(NamedTuple): -# tensor_ndim: int -# mesh_ndim: int -# mesh_shape: Tuple -# mesh_dim_names: Tuple -# placement_summary: List[Optional[str | int]] - ################################################################################ # Model Parallel Test Models ################################################################################ -class FeedForward(nn.Module): - def __init__(self, *args, **kwargs): - super().__init__() - self.w1 = nn.Linear(32, 64) - self.w2 = nn.Linear(32, 64) - self.w3 = nn.Linear(64, 2) - - def forward(self, x): - return self.w3(F.silu(self.w1(x)) * self.w2(x)) - class FTSBaseModelParallel(FinetuningSchedulerBoringModel): def __init__(self, fsdp_plan: Dict, tp_plan: Dict | Callable, module_cls: nn.Module = FTSToyTransformer, loss_parallel: bool = True, @@ -189,14 +104,8 @@ def backward(self, *args, **kwargs): def training_step(self, batch: Tensor, batch_idx: int) -> STEP_OUTPUT: inputs, target = batch - #output = self(inputs, target) - with CommDebugMode() as comm_mode: - output = self(inputs) - #print(comm_mode.advanced_module_tracker.sharding_dict) - #with CommDebugMode() as comm_mode: - loss = self.loss_fn(output, target) - #comm_mode.advanced_module_tracker.sharding_dict - #loss = F.cross_entropy(output.reshape(-1, output.size(-1)), target.reshape(-1)) + output = self(inputs) + loss = self.loss_fn(output, target) self.training_step_outputs.append(loss) return {"loss": loss} @@ -208,12 +117,9 @@ def on_val_epoch_end(self) -> None: def validation_step(self, batch: Tensor, batch_idx: int) -> Optional[STEP_OUTPUT]: inputs, target = batch - #output = self(inputs, target) output = self(inputs) - #loss = self.val_loss(batch, output) # TODO: for now, not using diverge_on_epoch for simplicity loss = self.loss_fn(output, target) - #loss = F.cross_entropy(output.reshape(-1, output.size(-1)), target.reshape(-1)) self.validation_step_outputs.append(loss) # we would normally use sync_dist for epoch-only logging in a distributed context but leaving it `False` here # to test FTS transition behavior when the test model is used in a distributed context @@ -379,69 +285,39 @@ def model_parallel_ft_schedule(tmpdir_factory) -> Tuple[Path, Dict]: if rank == 0: with pytest.raises(SystemExit): trainer.fit(model) - # TODO: NEXT: continue setting up ftstransformer schedule fixture! 
mp_tp_sched_dict = get_fts(trainer).load_yaml_schedule(unmod_schedule_file) - mp_tp_sched_dict[0]["params"] = [r"model.output.weight", r"model.norm.*", - #r"model.tok_embeddings.weight", - #r"model.(pos_embeddings|tok_embeddings).weight", - ] + mp_tp_sched_dict[0]["params"] = [r"model.output.weight", r"model.norm.*"] mp_tp_sched_dict[0]["max_transition_epoch"] = 1 - mp_tp_sched_dict[1]["params"] = [#r"model.pos_embeddings.weight", - r"model.layers.[0-1].(feed_forward|ffn_norm|attention|attention_norm).*"] + mp_tp_sched_dict[1]["params"] = [r"model.layers.1.(feed_forward|ffn_norm|attention.w.*|attention_norm).*"] mp_tp_sched_dict[1]["max_transition_epoch"] = 2 - mp_tp_sched_dict[2]["params"] = [#r"model.layers.0.(feed_forward|ffn_norm|attention|attention_norm).*", + mp_tp_sched_dict[2]["params"] = [r"model.layers.0.(feed_forward|ffn_norm|attention.w.*|attention_norm).*", r"model.(pos_embeddings|tok_embeddings).weight"] mp_tp_sched_dict.pop(3) - # mp_tp_all_req_grad = deepcopy(mp_tp_sched_dict) - # mp_tp_all_req_grad[0]["params"] = [r"model.output.weight", r"model.norm.*", - # r"model.layers.[0-1].(feed_forward|ffn_norm|attention|attention_norm).*"] - # mp_tp_all_req_grad[1]["params"] = [r"model.(pos_embeddings|tok_embeddings).weight"] - # mp_tp_all_req_grad.pop(2) - mp_tp_all_req_grad = deepcopy(mp_tp_sched_dict) - mp_tp_all_req_grad[0]["params"] = [r"model.output.weight", r"model.norm.*", - r"model.layers.[0-1].(feed_forward|ffn_norm|attention|attention_norm).*", - r"model.(pos_embeddings|tok_embeddings).weight"] - #mp_tp_all_req_grad[1]["params"] = [r"model.(pos_embeddings|tok_embeddings).weight"] - mp_tp_all_req_grad.pop(1) - mp_tp_all_req_grad.pop(2) - mp_tp_ln_no_grad = deepcopy(mp_tp_sched_dict) - mp_tp_ln_no_grad[0]["params"] = [r"model.output.weight", r"model.norm.*"] - mp_tp_ln_no_grad[1]["params"] = [r"model.layers.[0-1].(feed_forward|ffn_norm|attention|attention_norm).*"] - mp_tp_ln_no_grad[2]["params"] = [r"model.(pos_embeddings|tok_embeddings).weight"] - # mp_tp_sched_dict = {0: {'params': ['model.w3.bias', 'model.w3.weight']}, - # 1: {'params': ['model.w2.bias', 'model.w2.weight']}, - # 2: {'params': ['model.w1.bias', 'model.w1.weight']}} - # mp_tp_sched_dict[0]["max_transition_epoch"] = 1 - # mp_tp_sched_dict[1]["max_transition_epoch"] = 2 + mp_tp_dbg_req_grad = deepcopy(mp_tp_sched_dict) + mp_tp_dbg_req_grad[0]["params"] = [r"model.layers.1.attention_norm.*", r"model.layers.1.feed_forward.w2.*", + ] + mp_tp_dbg_req_grad[1] = {"params": [ + r"model.output.weight", + r"model.norm.*", + r"model.layers.[0-1].(ffn_norm|attention.w.*).*", + r"model.layers.0.attention_norm.*", + r"model.layers.[0-1].feed_forward.w1.*", + r"model.layers.0.feed_forward.w2.*", + ]} + mp_tp_dbg_req_grad[2] = {"params": [r"model.pos_embeddings.weight",]} + return ( unmod_schedule_file, mp_tp_sched_dict, - mp_tp_all_req_grad, - mp_tp_ln_no_grad, - #mp_tp_sched_dict + mp_tp_dbg_req_grad, ) -# class FSDP2Model(FTSBaseModelParallel): - -# def configure_model(self): -# _parallelize_base_model_parallel_fsdp2(self.model, device_mesh=self.device_mesh) - -# class TensorParallelModel(FTSBaseModelParallel): -# def configure_model(self): -# _parallelize_base_model_parallel_tp(self.model, device_mesh=self.device_mesh) - - -# class FSDP2TensorParallelModel(FTSBaseModelParallel): -# def configure_model(self): -# _parallelize_base_model_parallel_fsdp2_tp(self.model, device_mesh=self.device_mesh) - # modified version of https://bit.ly/torchtitan_transformer_tp_plan def gen_apply_transformer_tp_plan(model: nn.Module, 
device_mesh: DeviceMesh, loss_parallel: bool) -> nn.Module: """Apply tensor parallelism.""" # we're only applying tensor parallelism, composable fsdp is applied subsequently elsewhere if requested tp_mesh = device_mesh["tensor_parallel"] - #loss_parallel = enable_loss_parallel # 1. Parallelize the embedding and shard its outputs # 2. Parallelize the root norm layer over the sequence dim @@ -451,12 +327,6 @@ def gen_apply_transformer_tp_plan(model: nn.Module, device_mesh: DeviceMesh, los ), "pos_embeddings": RowwiseParallel(input_layouts=Replicate(), output_layouts=Shard(0)), "norm": SequenceParallel(), - # TODO: prob not necessary, inspect opportunities for refactoring/optimization - # "layers.0": PrepareModuleInput( - # input_layouts=(Replicate(), None), - # desired_input_layouts=(Shard(1), None), - # use_local_output=True, - # ), } model = parallelize_module(model, tp_mesh, non_transformerblock_tp_plan) @@ -474,12 +344,12 @@ def gen_apply_transformer_tp_plan(model: nn.Module, device_mesh: DeviceMesh, los #for layer_id, transformer_block in model.layers.items(): # support # we currently support `ModuleList` and `ModuleDict` transformer_block containers - if isinstance(model.layers, nn.ModuleList): - module_iterable = model.layers - elif isinstance(model.layers, nn.ModuleDict): - module_iterable = model.layers.values() - else: - raise "Unsupported transformer_block container, expected `ModuleList` or `ModuleDict` model.layers" + # if isinstance(model.layers, nn.ModuleList): + # module_iterable = model.layers + # elif isinstance(model.layers, nn.ModuleDict): + # module_iterable = model.layers.values() + # else: + # raise "Unsupported transformer_block container, expected `ModuleList` or `ModuleDict` model.layers" for transformer_block in model.layers: layer_plan = { "attention": PrepareModuleInput( @@ -487,21 +357,12 @@ def gen_apply_transformer_tp_plan(model: nn.Module, device_mesh: DeviceMesh, los desired_input_layouts=Replicate(), ), "attention_norm": SequenceParallel(), - # "attention": PrepareModuleInput( - # input_layouts=(Shard(1), None), - # desired_input_layouts=(Replicate(), None), - # ), - - "attention.wq": ColwiseParallel(use_local_output=False), # try use_local_output=False ? - "attention.wk": ColwiseParallel(use_local_output=False), # try use_local_output=False ? - "attention.wv": ColwiseParallel(use_local_output=False), # try use_local_output=False ? + + "attention.wq": ColwiseParallel(use_local_output=False), + "attention.wk": ColwiseParallel(use_local_output=False), + "attention.wv": ColwiseParallel(use_local_output=False), "attention.wo": RowwiseParallel(output_layouts=Shard(1)), "ffn_norm": SequenceParallel(), - # "feed_forward": PrepareModuleInput( - # input_layouts=(Shard(1),), - # desired_input_layouts=(Replicate(),), - # ), - #"feed_forward.w1": ColwiseParallel(), "feed_forward.w1": ColwiseParallel(input_layouts=Shard(1)), "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)), } @@ -526,12 +387,6 @@ def gen_apply_transformer_tp_plan(model: nn.Module, device_mesh: DeviceMesh, los ) parallelize_module(model.output, tp_mesh, output_parallelize_plan) - # "output": ColwiseParallel( - # input_layouts=Shard(1), - # output_layouts=Shard(-1) if loss_parallel else Replicate(), - # use_local_output=not loss_parallel, - # ), - # Manually set output.weight so that parameters and gradients are shared. 
if model.init_args.weight_tying: model.output.weight = model.tok_embeddings.weight @@ -544,42 +399,24 @@ def gen_apply_transformer_tp_plan(model: nn.Module, device_mesh: DeviceMesh, los ################################################################################ ## Model Aliases -ff_mod_parallel = FTSBaseModelParallel - tt_mod_parallel = FTSBaseModelParallel ## DTensor Placement Plan Aliases - -basic_tp_plan = { - "w1": ColwiseParallel(), - "w2": ColwiseParallel(), - "w3": RowwiseParallel(), - } - # TODO: set tp_plan and model loss_parallel from same config tt_tp_plan = gen_apply_transformer_tp_plan -#tt_tp_plan = partial(gen_apply_transformer_tp_plan, enable_loss_parallel=True) -# tt_tp_plan_no_loss_parallel = partial(gen_apply_transformer_tp_plan, enable_loss_parallel=False) ## FSDP2 Model Configuration Aliases - -shard_all = {"sharded_mods": ['model.w1', 'model.w2', 'model.w3'], "unsharded_mods": []} -shard_tt_basic = {"sharded_mods": ['model.layers.1', 'model.norm', 'model.output'], "unsharded_mods": []} +shard_tt_basic = {"sharded_mods": ['model.layers.1', 'model.norm', 'model.output']} ## toy transformer cfgs - basic_tt = TestModelArgs() ## Model Parallel Model Configuration Aliases -tp_no_fsdp = {"fsdp_plan": None, "tp_plan": basic_tp_plan} -fsdp_no_tp = {"fsdp_plan": shard_all, "tp_plan": None} -fsdp_tp = {"fsdp_plan": shard_all, "tp_plan": basic_tp_plan} - +tt_fsdp_tp = {"fsdp_plan": shard_tt_basic, "tp_plan": tt_tp_plan, "module_cls": FTSToyTransformer, "tt_cfg": basic_tt} tt_fsdp_no_tp = {"fsdp_plan": shard_tt_basic, "tp_plan": None, "module_cls": FTSToyTransformer, "tt_cfg": basic_tt} -tt_tp_no_fsdp_lp = {"fsdp_plan": None, "tp_plan": tt_tp_plan, "module_cls": FTSToyTransformer, "tt_cfg": basic_tt, - "loss_parallel": True} -tt_tp_no_fsdp_no_lp = {"fsdp_plan": None, "tp_plan": tt_tp_plan, "module_cls": FTSToyTransformer, "tt_cfg": basic_tt, - "loss_parallel": False} +tt_tp_no_fsdp = {"fsdp_plan": None, "tp_plan": tt_tp_plan, "module_cls": FTSToyTransformer, "tt_cfg": basic_tt} +tt_tp_no_fsdp_lp = {**tt_tp_no_fsdp, "loss_parallel": True} +tt_tp_no_fsdp_no_lp = {**tt_tp_no_fsdp, "loss_parallel": False} ## Model Parallel Strategy Aliases dp1_tp2 = {"data_parallel_size": 1, "tensor_parallel_size": 2} @@ -598,7 +435,6 @@ def gen_apply_transformer_tp_plan(model: nn.Module, device_mesh: DeviceMesh, los max_depth_0 = {"max_depth": 0} no_restore_best = {"restore_best": False} -# with mock.patch.object(ModelCheckpoint, "_save_checkpoint"): ## Model Parallel Test Configuration Dataclass @@ -631,54 +467,32 @@ def __post_init__(self): self.es_cfg = {**self.es_cfg, **default_dep_cfg} self.ckpt_cfg = {**self.ckpt_cfg, **default_dep_cfg} +@mock.patch("finetuning_scheduler.strategy_adapters.model_parallel._TORCH_GREATER_EQUAL_2_5", False) +def test_torch_greater_equal_2_5(): + with pytest.raises(MisconfigurationException, match="requires PyTorch 2.5 or higher"): + ModelParallelStrategyAdapter() ## Model Parallel Test Definitions FTS_MODEL_PARALLEL_TESTS = ( - ModelParallelTestConfig(model_cfg_key="ff_tp_no_fsdp", model_cls=ff_mod_parallel, model_cfg=tp_no_fsdp, - strategy_cfg=dp1_tp2, runif_alias="min2_4", - expected_results=ExpectedResults(expected_state=path_ff_tp_no_fsdp)), - ModelParallelTestConfig(model_cfg_key="ff_fsdp_no_tp", model_cls=ff_mod_parallel, - model_cfg=fsdp_no_tp, - strategy_cfg=dp2_tp1, runif_alias="min2_4", - expected_results=ExpectedResults(expected_state=path_ff_fsdp_no_tp)), - ModelParallelTestConfig(model_cfg_key="ff_fsdp_tp", model_cls=ff_mod_parallel, - 
model_cfg=fsdp_tp, fts_cfg=no_restore_best, ckpt_cfg=no_ckpt_save, - strategy_cfg=dp2_tp1, runif_alias="min2_4", - expected_results=ExpectedResults(expected_state=path_ff_fsdp_tp)), + ModelParallelTestConfig(model_cfg_key="tt_fsdp_tp", model_cls=tt_mod_parallel, + model_cfg=tt_fsdp_tp, fts_cfg=no_restore_best, ckpt_cfg=no_ckpt_save, + strategy_cfg=dp2_tp1, runif_alias="einsum_exp", + expected_results=ExpectedResults(expected_state=path_tt_fsdp_tp)), ModelParallelTestConfig(model_cfg_key="tt_fsdp_no_tp", model_cls=tt_mod_parallel, - model_cfg=tt_fsdp_no_tp, strategy_cfg=dp2_tp1, runif_alias="min2_4", + model_cfg=tt_fsdp_no_tp, strategy_cfg=dp2_tp1, runif_alias="min2_5", expected_results=ExpectedResults(expected_state=path_tt_fsdp_no_tp)), - ModelParallelTestConfig(model_cfg_key="tt_tp_no_fsdp", model_cls=tt_mod_parallel, - model_cfg=tt_tp_no_fsdp_no_lp, strategy_cfg=dp1_tp2, runif_alias="min2_4", - expected_results=ExpectedResults(expected_state=path_tt_tp_no_fsdp)), - # ModelParallelTestConfig(model_cfg_key="tt_tp_no_fsdp_no_lp", model_cls=tt_mod_parallel, - # model_cfg=tt_tp_no_fsdp_no_lp, strategy_cfg=dp1_tp2, ft_sched_idx=1, - # runif_alias="min2_4", - # expected_results=ExpectedResults(expected_state=path_tt_tp_no_fsdp)), - ModelParallelTestConfig(model_cfg_key="tt_tp_no_fsdp_no_lp_no_lnreq", model_cls=tt_mod_parallel, - model_cfg=tt_tp_no_fsdp_lp, strategy_cfg=dp1_tp2, ft_sched_idx=3, - runif_alias="min2_4", + ModelParallelTestConfig(model_cfg_key="tt_tp_no_fsdp_lp", model_cls=tt_mod_parallel, + model_cfg=tt_tp_no_fsdp_lp, strategy_cfg=dp1_tp2, runif_alias="min2_5", expected_results=ExpectedResults(expected_state=path_tt_tp_no_fsdp)), - ModelParallelTestConfig(model_cfg_key="tt_tp_no_fsdp_lp_all_req", model_cls=tt_mod_parallel, - ft_sched_idx=2, fts_cfg=max_depth_0, - model_cfg=tt_tp_no_fsdp_lp, strategy_cfg=dp1_tp2, runif_alias="min2_4", + ModelParallelTestConfig(model_cfg_key="tt_tp_no_fsdp_no_lp", model_cls=tt_mod_parallel, + model_cfg=tt_tp_no_fsdp_no_lp, strategy_cfg=dp1_tp2, runif_alias="min2_5", expected_results=ExpectedResults(expected_state=path_tt_tp_no_fsdp)), - ModelParallelTestConfig(model_cfg_key="tt_tp_no_fsdp_no_lp_all_req", model_cls=tt_mod_parallel, - model_cfg=tt_tp_no_fsdp_no_lp, strategy_cfg=dp1_tp2, ft_sched_idx=2, fts_cfg=max_depth_0, - runif_alias="min2_4", - expected_results=ExpectedResults(expected_state=path_tt_tp_no_fsdp)), - ) - - -@RunIf(standalone=False, min_cuda_gpus=2) +@RunIf(min_cuda_gpus=2, standalone=True) @pytest.mark.parametrize("test_cfg", pytest_param_factory(FTS_MODEL_PARALLEL_TESTS)) def test_fts_model_parallel(tmpdir, recwarn, model_parallel_ft_schedule, test_cfg): """Validate :class:`~finetuning_scheduler.FinetuningScheduler` functions properly in a supported 'ddp' distributed context.""" - # some experimental tests may require version/os-env gated dependency patches to be applied, they may be loaded here - if test_cfg.model_cfg_key in ("ff_fsdp_tp"): - pass seed_everything(42) # one can manually set this to True for a local test override state_log_dir = tmpdir if FTS_GLOBAL_STATE_LOG_MODE else None @@ -703,13 +517,8 @@ def test_fts_model_parallel(tmpdir, recwarn, model_parallel_ft_schedule, test_cf expected_warns_dynamo=MODEL_PARALLEL_DYNAMO_EXPECTED_WARNS, use_dynamo=use_dynamo) -def gen_exceptions(trainer, model, model_cfg_key, exception_expected): - if model_cfg_key == "no_fsdp_params_p0": - with mock.patch.object(FSDPStrategyAdapter, "_rank_zero_logger", 42): - with pytest.raises(MisconfigurationException, match=exception_expected): - 
trainer.fit(model) - else: - with pytest.raises(MisconfigurationException, match=exception_expected): +def gen_exceptions(trainer, model, exception_expected): + with pytest.raises(MisconfigurationException, match=exception_expected): trainer.fit(model) @@ -722,100 +531,3 @@ def callbacks_cfg(ft_sched, state_log_dir, test_cfg): if issubclass(tcls, subc): callbacks.append(tcls(**tcfg)) return callbacks - - - - -# @RunIf(min_torch="2.3", standalone=False, min_cuda_gpus=2) -# def test_fsdp2_trivial_tp(): -# from torch.distributed._tensor import DTensor - -# class Model(FSDP2TensorParallelModel): -# def on_train_start(self): -# optimizer = self.optimizers() -# #assert all(isinstance(weight, DTensor) for weight in self.model.parameters()) -# #assert all(isinstance(tensor, DTensor) for tensor in optimizer.param_groups[0]["params"]) -# assert self.model.w1.weight.device_mesh.ndim == 2 -# assert self.model.w1.weight.device_mesh.size(0) == 2 -# assert self.model.w1.weight.device_mesh.size(1) == 1 -# #assert all(weight.device.type != "meta" for weight in self.model.parameters()) -# #assert all(tensor.device_mesh.ndim == 2 for tensor in optimizer.param_groups[0]["params"]) -# #assert all(tensor.device.type != "meta" for tensor in optimizer.param_groups[0]["params"]) - -# # No data sharding across TP dimension, sharding across data-parallel dimension only -# device_mesh = self.device_mesh -# dp_mesh = device_mesh["data_parallel"] -# dataloader = self.trainer.train_dataloader -# assert len(dataloader) == 8 // dataloader.batch_size // dp_mesh.size() -# assert isinstance(dataloader.sampler, DistributedSampler) - -# def training_step(self, batch): -# batches = self.all_gather(batch) -# dp_mesh = self.device_mesh["data_parallel"] -# tp_mesh = self.device_mesh["tensor_parallel"] - -# # Batches across the TP dimension must be identical -# batches_tp = batches[tp_mesh.mesh] -# assert all(torch.equal(batches_tp[0], batches_tp[i]) for i in range(1, len(batches_tp))) -# # Batches across the DP dimension must be different -# batches_dp = batches[dp_mesh.mesh] -# assert all(not torch.equal(batches_dp[0], batches_dp[i]) for i in range(1, len(batches_dp))) - -# return super().training_step(batch) - -# strategy = ModelParallelStrategy( -# data_parallel_size=2, -# tensor_parallel_size=1, -# ) -# trainer = Trainer( -# accelerator="auto", -# devices=2, -# strategy=strategy, -# max_steps=2, -# enable_checkpointing=False, -# logger=False, -# ) - -# seed_everything(0) -# with trainer.init_module(empty_init=True): -# model = Model() - -# trainer.fit(model) - - -# @RunIf(min_torch="2.3", standalone=False, min_cuda_gpus=2) -# def test_fsdp2_notp(): -# #from torch.distributed._tensor import DTensor - -# class Model(FSDP2Model): -# def on_train_start(self): -# optimizer = self.optimizers() -# #assert all(isinstance(weight, DTensor) for weight in self.model.parameters()) -# #assert all(isinstance(tensor, DTensor) for tensor in optimizer.param_groups[0]["params"]) -# #assert self.model.w1.weight.device_mesh.ndim == 2 -# #assert self.model.w1.weight.device_mesh.size(0) == 2 -# #assert self.model.w1.weight.device_mesh.size(1) == 2 -# #assert all(weight.device.type != "meta" for weight in self.model.parameters()) -# #assert all(tensor.device_mesh.ndim == 2 for tensor in optimizer.param_groups[0]["params"]) -# #assert all(tensor.device.type != "meta" for tensor in optimizer.param_groups[0]["params"]) - -# # No data sharding across TP dimension, sharding across data-parallel dimension only -# device_mesh = self.device_mesh -# dp_mesh = 
device_mesh["data_parallel"] -# dataloader = self.trainer.train_dataloader -# #assert len(dataloader) == 8 // dataloader.batch_size // dp_mesh.size() -# #assert isinstance(dataloader.sampler, DistributedSampler) - -# def training_step(self, batch): -# batches = self.all_gather(batch) -# dp_mesh = self.device_mesh["data_parallel"] -# #tp_mesh = self.device_mesh["tensor_parallel"] - -# # Batches across the TP dimension must be identical -# #batches_tp = batches[tp_mesh.mesh] -# #assert all(torch.equal(batches_tp[0], batches_tp[i]) for i in range(1, len(batches_tp))) -# # Batches across the DP dimension must be different -# #batches_dp = batches[dp_mesh.mesh] -# #assert all(not torch.equal(batches_dp[0], batches_dp[i]) for i in range(1, len(batches_dp))) - -# return super().training_step(batch)
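A minimal, illustrative sketch (not part of the patch) of how the model-parallel FTS path exercised by the tests above might be driven end to end. It assumes a 2-GPU host and a hypothetical schedule file path ("mp_tp_schedule.yaml"); the class and keyword names are taken from the imports and strategy aliases used in tests/test_model_parallel.py (e.g. the `dp2_tp1` alias), not from any new API.

# Illustrative sketch only -- assumptions: 2 CUDA GPUs available, and a placeholder
# schedule path "mp_tp_schedule.yaml" standing in for a phase-wise schedule like the
# one the `model_parallel_ft_schedule` fixture constructs.
import torch
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.strategies import ModelParallelStrategy
from finetuning_scheduler import FinetuningScheduler

if torch.cuda.device_count() >= 2:  # the tests above gate on min_cuda_gpus=2
    seed_everything(42)
    # mirrors the `dp2_tp1` strategy alias used in the test configurations above
    strategy = ModelParallelStrategy(data_parallel_size=2, tensor_parallel_size=1)
    # `ft_schedule` points at a YAML schedule; the concrete path here is a placeholder
    fts = FinetuningScheduler(ft_schedule="mp_tp_schedule.yaml")
    trainer = Trainer(accelerator="gpu", devices=2, strategy=strategy, callbacks=[fts], max_epochs=3)
    # trainer.fit(model)  # `model` would be an FTSBaseModelParallel-style LightningModule as above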