Commit
version guard mp imports, encapsulating the import logic
speediedan committed Oct 9, 2024
1 parent 3593ae9 commit 60a9942
Showing 8 changed files with 143 additions and 213 deletions.
28 changes: 28 additions & 0 deletions src/finetuning_scheduler/strategy_adapters/_mp_imports.py
@@ -0,0 +1,28 @@
# TODO: replace local version once Lightning version available
# from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_5
from lightning_utilities.core.imports import compare_version
import operator
_TORCH_GREATER_EQUAL_2_5 = compare_version("torch", operator.ge, "2.5.0", use_base_version=True)

# ruff: noqa: F401
# we require torch 2.5 or higher for composable distributed API support, so until torch 2.5.0 is the minimum
# supported version, we conditionally import indirectly to avoid duplicating import logic in several different modules
if _TORCH_GREATER_EQUAL_2_5:
    from torch.distributed._composable import checkpoint
    from torch.distributed._composable.fsdp._fsdp_api import CPUOffloadPolicy
    from torch.nn.attention import SDPBackend, sdpa_kernel
    from torch.distributed.device_mesh import DeviceMesh
    from torch.distributed.tensor import DTensor, Replicate, Shard
    from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker
    from torch.distributed.tensor.experimental import implicit_replication
    from torch.distributed._composable.fsdp import FSDPModule, fully_shard
    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (checkpoint_wrapper, offload_wrapper,
                                                                             ActivationWrapper)
    from torch.distributed.tensor.parallel import (ColwiseParallel, PrepareModuleInput, RowwiseParallel,
                                                   SequenceParallel, parallelize_module, loss_parallel)
else:
    for mp_obj in ["SDPBackend", "DeviceMesh", "DTensor", "Replicate", "Shard", "ColwiseParallel",
                   "PrepareModuleInput", "RowwiseParallel", "SequenceParallel", "implicit_replication",
                   "parallelize_module", "loss_parallel", "FSDPModule", "fully_shard", "checkpoint",
                   "checkpoint_wrapper", "offload_wrapper", "ActivationWrapper", "CPUOffloadPolicy",
                   "sdpa_kernel", "FSDPMemTracker"]:
        globals()[mp_obj] = None
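
A minimal consumer sketch (illustrative only, not part of this commit) of the pattern this module enables: import the version-guarded names from _mp_imports and gate any use on _TORCH_GREATER_EQUAL_2_5, since the names are bound to None on older torch versions. The helper below is hypothetical.

import torch.nn as nn
from finetuning_scheduler.strategy_adapters._mp_imports import (
    _TORCH_GREATER_EQUAL_2_5, FSDPModule, fully_shard)

def maybe_fully_shard(module: nn.Module) -> nn.Module:
    # hypothetical helper: no-op when the composable distributed API is unavailable
    if not _TORCH_GREATER_EQUAL_2_5:
        return module
    fully_shard(module)  # in-place composable FSDP2 sharding (torch >= 2.5)
    assert isinstance(module, FSDPModule)  # fully_shard swaps in an FSDPModule subclass
    return module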
17 changes: 4 additions & 13 deletions src/finetuning_scheduler/strategy_adapters/model_parallel.py
@@ -23,30 +23,21 @@
import re
import os
from pprint import pformat
# TODO: replace local version once Lightning version available
# from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_5
import operator
from dataclasses import dataclass, field

import torch
from torch.distributed.tensor import DTensor
from torch.distributed._composable import checkpoint
from torch.distributed._composable.fsdp.fully_shard import fully_shard, FSDPModule
from torch.distributed._composable.fsdp._fsdp_api import CPUOffloadPolicy
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (checkpoint_wrapper, offload_wrapper,
ActivationWrapper)
from lightning.fabric.utilities.enums import LightningEnum
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning_utilities.core.imports import compare_version
from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning.fabric.utilities import rank_zero_warn, rank_zero_info
from lightning.pytorch.utilities.rank_zero import rank_zero_debug

from finetuning_scheduler.strategy_adapters.base import StrategyAdapter
from finetuning_scheduler.strategy_adapters._wrap_utils import _compose_ncac

_TORCH_GREATER_EQUAL_2_5 = compare_version("torch", operator.ge, "2.5.0", use_base_version=True)

# conditionally import indirectly to avoid duplicating import logic in several different modules
from finetuning_scheduler.strategy_adapters._mp_imports import (_TORCH_GREATER_EQUAL_2_5, DTensor, FSDPModule,
                                                                fully_shard, CPUOffloadPolicy, checkpoint,
                                                                checkpoint_wrapper, ActivationWrapper, offload_wrapper)

class ActCkptEnum(LightningEnum):
COMPOSABLE = "composable"
2 changes: 1 addition & 1 deletion src/fts_examples/model_parallel/mp_examples.py
@@ -19,7 +19,7 @@

from fts_examples.cli_experiment_utils import ExpHarness, FTSExperimentCLI, ExperimentCfg
from fts_examples.model_parallel.torchtitan_llama import ModelCfg, Transformer
from finetuning_scheduler.strategy_adapters.model_parallel import _TORCH_GREATER_EQUAL_2_5
from finetuning_scheduler.strategy_adapters._mp_imports import _TORCH_GREATER_EQUAL_2_5

# Lightning ModelParallel still uses `torch.load` with `weights_only=False`
warnings.filterwarnings("ignore", ".*uses the default pickle.*")
3 changes: 2 additions & 1 deletion src/fts_examples/patching/dep_patch_shim.py
@@ -55,7 +55,8 @@ def _patch_triton():


einsum_strategies_patch = DependencyPatch(
condition=(lwt_compare_version("torch", operator.le, "2.5.1"),),
condition=(lwt_compare_version("torch", operator.le, "2.5.1"),
lwt_compare_version("torch", operator.ge, "2.5.0"),),
env_flag=OSEnvToggle("ENABLE_FTS_EINSUM_STRATEGY_PATCH", default="0"),
function=_patch_einsum_strategies, patched_package='torch',
description='Address trivial tp submesh limitation until PyTorch provides upstream fix')
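
For context, a standalone sketch of the range check this commit adds (an illustration of the gating semantics, not the shim's actual implementation): the einsum-strategies patch is only eligible when torch falls within [2.5.0, 2.5.1] and the user opts in via the ENABLE_FTS_EINSUM_STRATEGY_PATCH toggle.

import operator
import os
from fts_examples.patching._patch_utils import lwt_compare_version

# both version bounds must hold (mirroring the two-element condition tuple above)
# and the env toggle must be explicitly enabled before the patch is applied
eligible = all((
    lwt_compare_version("torch", operator.ge, "2.5.0"),
    lwt_compare_version("torch", operator.le, "2.5.1"),
)) and os.environ.get("ENABLE_FTS_EINSUM_STRATEGY_PATCH", "0") == "1"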
186 changes: 94 additions & 92 deletions src/fts_examples/patching/patched_einsum_strategies.py
@@ -1,103 +1,105 @@
import operator
from fts_examples.patching._patch_utils import _prepare_module_ctx, lwt_compare_version

# we ignore these for the entire file since we're using our global namespace trickeration to patch
# ruff: noqa: F821
# pyright: reportUndefinedVariable=false

if lwt_compare_version("torch", operator.ge, "2.5.0"):
    globals().update(_prepare_module_ctx('torch.distributed.tensor._ops._einsum_strategy', globals()))

    def gen_einsum_strategies(
        equation: str,
        mesh: DeviceMesh,
        *,
        linearity: bool = False,
    ) -> OpStrategy:
        """Generate a strategy list for the ops that follow einsum style notation."""
        # parse einop equation and extract dims
        input_dims, output_dim = EinsumDims.parse_equation(equation)
        edims = EinsumDims.parse_dims(input_dims, output_dim)

        all_mesh_dim_strategies = []

        # generate strategies for each mesh dim
        for mesh_dim in range(mesh.ndim):
            mesh_dim_strategies = []

            # placement list stores placements of [output, input1, input2, ...]
            # first we always have replicate all for inputs and output
            placement_list: List[Placement] = [Replicate()] * (len(input_dims) + 1)
            mesh_dim_strategies.append(placement_list)

            # if mesh.size(mesh_dim) <= 1:
            #     # only replicate strategy for mesh dim with size 1
            #     # TODO: see if this is valid for the submesh case
            #     continue

            # split batch dim
            for batch_dim in edims.batch_dims:
                output_batch_dim = output_dim.index(batch_dim)
                placement_list = [Shard(output_batch_dim)]
                for input_dim in input_dims:
                    input_batch_dim = input_dim.index(batch_dim)
                    placement_list.append(Shard(input_batch_dim))

                mesh_dim_strategies.append(placement_list)

            # split contracting dim
            for contracting_dim in edims.contracting_dims:
                placement_list = [Partial()]
                for input_dim in input_dims:
                    input_contracting_dim = input_dim.index(contracting_dim)
                    placement_list.append(Shard(input_contracting_dim))

                mesh_dim_strategies.append(placement_list)

            # split lhs free dim
            for lhs_dim in edims.lhs_out_only_dims:
                lhs_free_dim = output_dim.index(lhs_dim)
                # this means split the lhs input and output
                # i.e. S(0), R -> S(0)
                lhs_placement_list: List[Placement] = [
                    Shard(lhs_free_dim),
                    Shard(lhs_free_dim),
                    Replicate(),
                ]
                mesh_dim_strategies.append(lhs_placement_list)

            # split rhs free dim
            for rhs_dim in edims.rhs_out_only_dims:
                rhs_free_dim = output_dim.index(rhs_dim)
                rhs_placement_list: List[Placement] = [
                    Shard(rhs_free_dim),
                    Replicate(),
                    Shard(rhs_free_dim),
                ]
                mesh_dim_strategies.append(rhs_placement_list)

            # linearity strategy
            if linearity:
                linearity_placement_list: List[Placement] = [Partial()]
                for input_dim in input_dims:
                    linearity_placement_list.append(Partial())
                mesh_dim_strategies.append(linearity_placement_list)

            all_mesh_dim_strategies.append(mesh_dim_strategies)

        # generate strategies for entire mesh
        strategy_combs = itertools.product(*all_mesh_dim_strategies)

        # TODO: filter out invalid strategies; at this point we generate all possible strategies
        # without considering whether the tensor dim could be sharded or not, so we would need to
        # filter out invalid strategies based on the actual tensor shape
        # (i.e. for Shard, tensor dim size must > mesh size)
        all_strategies = []
        for strategy_comb in strategy_combs:
            spec_list = []
            for specs in zip(*strategy_comb):
                spec_list.append(DTensorSpec(mesh, tuple(specs)))
            strat = PlacementStrategy(output_specs=spec_list[0], input_specs=spec_list[1:])
            all_strategies.append(strat)

        return OpStrategy(all_strategies)
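
For reference, a minimal way to exercise the (patched or stock) strategy generation on a trivial single-process setup; this is an illustration under the assumption that torch >= 2.5 is installed and that these private torch.distributed.tensor modules keep their current locations.

import os
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor._ops._einsum_strategy import gen_einsum_strategies

# single-process gloo group so a 1-D CPU mesh can be constructed
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

mesh = init_device_mesh("cpu", (1,))
op_strategy = gen_einsum_strategies("mk,kn->mn", mesh)
# expect one replicate strategy plus shardings of the contracting and free dims per mesh dim
print(len(op_strategy.strategies))

dist.destroy_process_group()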
10 changes: 2 additions & 8 deletions src/fts_examples/profiling/memprofiler.py
@@ -17,17 +17,11 @@
from lightning.fabric.utilities.cloud_io import get_filesystem
from lightning.pytorch.utilities.exceptions import MisconfigurationException

from finetuning_scheduler.strategy_adapters.model_parallel import _TORCH_GREATER_EQUAL_2_5
# conditionally import indirectly to avoid duplicating import logic in several different modules
from finetuning_scheduler.strategy_adapters._mp_imports import _TORCH_GREATER_EQUAL_2_5, FSDPModule, FSDPMemTracker
from fts_examples.cfg_utils import resolve_funcs
from finetuning_scheduler.types import AutoStrEnum

if _TORCH_GREATER_EQUAL_2_5:
    from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker
    from torch.distributed._composable.fsdp import FSDPModule
else:
    FSDPMemTracker = None
    FSDPModule = None


class DefaultMemHooks(AutoStrEnum):
pre_forward = 'fts_examples.profiling.npp_hooks._hook_npp_pre_forward'