
Commit ea3afc4

initial ModelParallelStrategy support with PyTorch 20240827 nightly containing required PR commits
speediedan committed Aug 29, 2024
1 parent 970942a commit ea3afc4
Showing 11 changed files with 315 additions and 1,797 deletions.
src/finetuning_scheduler/fts_supporters.py (2 changes: 1 addition & 1 deletion)

@@ -324,7 +324,7 @@ def _evaluate_stopping_criteria(self, current: Tensor) -> Tuple[bool, Optional[s
should_stop = True
reason = (
f"Monitored metric {self.monitor} did not improve in the last {self.wait_count} records."
- f" Best score: {self.best_score:.3f}. Signaling Trainer to stop."
+ f" Best score: {self.best_score.item():.3f}. Signaling Trainer to stop."
)
else:
self._transition_es_phase()
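For context on the one-line change above, a minimal sketch (not part of the commit) of the formatting involved, assuming `best_score` is tracked as a 0-dim tensor: `.item()` converts it to a plain Python float before the `:.3f` format spec is applied, which keeps the message construction independent of whatever tensor subclass (for example, a distributed tensor under model parallelism) happens to hold the score.

import torch

best_score = torch.tensor(0.1234)  # stand-in for the tracked early-stopping score
# .item() yields a plain Python float, so the standard float format spec always applies
print(f"Best score: {best_score.item():.3f}")  # -> Best score: 0.123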
src/finetuning_scheduler/strategy_adapters/__init__.py (3 changes: 2 additions & 1 deletion)

@@ -15,5 +15,6 @@
"""
from finetuning_scheduler.strategy_adapters.base import StrategyAdapter
from finetuning_scheduler.strategy_adapters.fsdp import FSDPStrategyAdapter
+ from finetuning_scheduler.strategy_adapters.model_parallel import ModelParallelStrategyAdapter

- __all__ = ["StrategyAdapter", "FSDPStrategyAdapter"]
+ __all__ = ["StrategyAdapter", "FSDPStrategyAdapter", "ModelParallelStrategyAdapter"]
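As a quick illustration (not from the commit) of what the __init__.py change exposes, the new adapter is now importable from the strategy_adapters package root alongside the existing FSDP adapter:

# assumes finetuning-scheduler is installed from this branch
from finetuning_scheduler.strategy_adapters import (
    FSDPStrategyAdapter,
    ModelParallelStrategyAdapter,
    StrategyAdapter,
)

# all three names are re-exported via __all__ after this change
print([cls.__name__ for cls in (StrategyAdapter, FSDPStrategyAdapter, ModelParallelStrategyAdapter)])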
src/finetuning_scheduler/strategy_adapters/model_parallel.py (755 changes: 59 additions & 696 deletions)

Large diffs are not rendered by default.

src/fts_examples/stable/patching/dep_patch_shim.py (19 changes: 13 additions & 6 deletions)

@@ -1,6 +1,7 @@
import operator
import sys
import os
+ from enum import Enum
from typing import NamedTuple, Tuple, Callable
from fts_examples.stable.patching._patch_utils import lwt_compare_version

@@ -35,7 +36,7 @@ def _patch_einsum_strategies():

# In this case fortunately, we only import/call `gen_einsum_strategies` from
# `torch.distributed._tensor.ops.matrix_ops`, so only need to patch there.
- target_mod = 'torch.distributed._tensor.ops.matrix_ops'
+ target_mod = 'torch.distributed._tensor.ops._matrix_ops'
sys.modules.get(target_mod).__dict__['gen_einsum_strategies'] = gen_einsum_strategies

def _patch_unsupported_numpy_arrow_extractor():
@@ -54,7 +55,7 @@ def _patch_triton():


einsum_strategies_patch = DependencyPatch(
- condition=(lwt_compare_version("torch", operator.le, "2.4.1"),),
+ condition=(lwt_compare_version("torch", operator.le, "2.5.1"),),
env_flag=OSEnvToggle("ENABLE_FTS_EINSUM_STRATEGY_PATCH", default="0"),
function=_patch_einsum_strategies, patched_package='torch',
description='Address trivial tp submesh limitation until PyTorch provides upstream fix')
@@ -73,11 +74,17 @@ def _patch_triton():
function=_patch_triton, patched_package='pytorch-triton',
description='Address `triton` #3564 until PyTorch pins the upstream fix')

- _DEFINED_PATCHES = {einsum_strategies_patch, datasets_numpy_extractor_patch, triton_codgen_patch}
+ class ExpPatch(Enum):
+     EINSUM_STRATEGIES = einsum_strategies_patch
+     NUMPY_EXTRACTOR = datasets_numpy_extractor_patch
+     TRITON_CODEGEN = triton_codgen_patch
+
+ #_DEFINED_PATCHES = {einsum_strategies_patch, datasets_numpy_extractor_patch, triton_codgen_patch}
+ _DEFINED_PATCHES = set(ExpPatch)
_ACTIVE_PATCHES = set()

for defined_patch in _DEFINED_PATCHES:
-     if all(defined_patch.condition) and os.environ.get(defined_patch.env_flag.env_var_name,
-                                                        defined_patch.env_flag.default) == "1":
-         defined_patch.function()
+     if all(defined_patch.value.condition) and os.environ.get(defined_patch.value.env_flag.env_var_name,
+                                                              defined_patch.value.env_flag.default) == "1":
+         defined_patch.value.function()
_ACTIVE_PATCHES.add(defined_patch)
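To make the new activation pattern above easier to follow, here is a minimal self-contained sketch (simplified stand-ins for DependencyPatch, OSEnvToggle, and the patch functions, which live elsewhere in fts_examples; the demo names are hypothetical). Wrapping each patch definition in an Enum member means the gating loop reads the DependencyPatch payload through `.value` before checking the version condition and the opt-in environment flag, and `_ACTIVE_PATCHES` ends up holding readable Enum members that test utilities can check for membership.

import os
from enum import Enum
from typing import Callable, NamedTuple, Tuple


class OSEnvToggle(NamedTuple):          # simplified stand-in
    env_var_name: str
    default: str


class DependencyPatch(NamedTuple):      # simplified stand-in
    condition: Tuple[bool, ...]
    env_flag: OSEnvToggle
    function: Callable
    patched_package: str
    description: str


def _demo_patch_fn():                   # stand-in for e.g. _patch_einsum_strategies
    print("patch applied")


demo_patch = DependencyPatch(
    condition=(True,),                  # e.g. lwt_compare_version("torch", operator.le, "2.5.1")
    env_flag=OSEnvToggle("ENABLE_DEMO_PATCH", default="0"),
    function=_demo_patch_fn, patched_package="torch",
    description="demo only")


class ExpPatch(Enum):
    DEMO = demo_patch


_DEFINED_PATCHES = set(ExpPatch)
_ACTIVE_PATCHES = set()

for defined_patch in _DEFINED_PATCHES:
    # the Enum member's .value is the DependencyPatch payload
    if all(defined_patch.value.condition) and os.environ.get(
            defined_patch.value.env_flag.env_var_name, defined_patch.value.env_flag.default) == "1":
        defined_patch.value.function()
        _ACTIVE_PATCHES.add(defined_patch)

print(_ACTIVE_PATCHES)  # empty unless ENABLE_DEMO_PATCH=1 is exported before running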
(additional changed file; path not rendered in this view)
@@ -1,6 +1,6 @@
from fts_examples.stable.patching._patch_utils import _prepare_module_ctx

- globals().update(_prepare_module_ctx('torch.distributed._tensor.ops.basic_strategy', globals()))
+ globals().update(_prepare_module_ctx('torch.distributed._tensor.ops._einsum_strategy', globals()))

# we ignore these for the entire file since we're using our global namespace trickeration to patch
# ruff: noqa: F821
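A short illustrative sketch (not from this commit; `demo_lib`, `demo_consumer`, and `gen_strategies` are hypothetical stand-ins) of why `_patch_einsum_strategies` only needs to patch the single consuming module named in its comment: rebinding a name in an already-imported module's `__dict__` changes what that module's code sees at call time, but the original library module keeps its own binding, so the patch must be installed in the namespace where the function is actually looked up.

import sys
import types

# hypothetical "library" module exporting the function to be patched
lib = types.ModuleType("demo_lib")
lib.gen_strategies = lambda: "original"
sys.modules["demo_lib"] = lib

# hypothetical consumer that imported the function by name at import time,
# i.e. the moral equivalent of `from demo_lib import gen_strategies`
consumer = types.ModuleType("demo_consumer")
consumer.gen_strategies = lib.gen_strategies
sys.modules["demo_consumer"] = consumer

# patch the consumer's namespace directly, mirroring the shim's
# sys.modules.get(target_mod).__dict__['gen_einsum_strategies'] = gen_einsum_strategies
sys.modules.get("demo_consumer").__dict__["gen_strategies"] = lambda: "patched"

print(sys.modules["demo_consumer"].gen_strategies())  # -> patched
print(lib.gen_strategies())                           # -> original (library module untouched)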
tests/helpers/boring_models.py (83 changes: 1 addition & 82 deletions)

@@ -137,36 +137,17 @@ def step(self, batch: Tensor) -> Tensor:
def training_step(self, batch: Tensor, batch_idx: int) -> STEP_OUTPUT:
return {"loss": self.step(batch)}

- # def training_step(self, batch, batch_idx):
- # output = self(batch)
- # loss = self.loss(batch, output)
- # return {"loss": loss}

def training_step_end(self, training_step_output: STEP_OUTPUT) -> STEP_OUTPUT:
return training_step_output

def validation_step(self, batch: Tensor, batch_idx: int) -> Optional[STEP_OUTPUT]:
return {"x": self.step(batch)}

- # def validation_step(self, batch, batch_idx):
- # output = self(batch)
- # loss = self.loss(batch, output)
- # return {"x": loss}

- # def on_validation_epoch_end(self, outputs) -> None:
- # torch.stack([x["x"] for x in outputs]).mean()

- # def test_step(self, batch, batch_idx):
- # output = self(batch)
- # loss = self.loss(batch, output)
- # return {"y": loss}

def test_step(self, batch: Tensor, batch_idx: int) -> Optional[STEP_OUTPUT]:
return {"y": self.step(batch)}

- # def test_epoch_end(self, outputs) -> None:
- # torch.stack([x["y"] for x in outputs]).mean()

def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List[LRScheduler]]:
optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
@@ -224,15 +205,6 @@ def __init__(self):
super().__init__()
self.automatic_optimization = False

- # def training_step(self, batch, batch_idx):
- # opt = self.optimizers()
- # output = self(batch)
- # loss = self.loss(batch, output)
- # opt.zero_grad()
- # self.manual_backward(loss)
- # opt.step()
- # return loss

def training_step(self, batch: Tensor, batch_idx: int) -> STEP_OUTPUT:
opt = self.optimizers()
assert isinstance(opt, (Optimizer, LightningOptimizer))
@@ -249,43 +221,6 @@ class FTSWikiText2(WikiText2):
def __init__(self, data_dir: Path = Path(_PATH_DATASETS), block_size: int = 32, *args, **kwargs) -> None:
super().__init__(data_dir=data_dir, block_size=block_size, *args, **kwargs)

- # @property
- # def vocab_size(self) -> int:
- # return len(self.dictionary)

- # def __len__(self) -> int:
- # return len(self.data) // self.block_size - 1

- # def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
- # start = index * self.block_size
- # end = start + self.block_size
- # inputs = self.data[start:end]
- # target = self.data[(start + 1) : (end + 1)]
- # return inputs, target

- # @staticmethod
- # def download(destination: Path) -> None:
- # if not _REQUESTS_AVAILABLE:
- # raise ModuleNotFoundError(str(_REQUESTS_AVAILABLE))

- # import requests

- # os.makedirs(destination.parent, exist_ok=True)
- # url = "https://raw.githubusercontent.com/pytorch/examples/main/word_language_model/data/wikitext-2/train.txt"
- # if os.path.exists(destination):
- # return
- # with open(destination, "w") as f:
- # f.write(requests.get(url).text)


- # class SampledOutput(NamedTuple):
- # """Sampled Output Named Tuple.

- # Named tuple object for if we want to output both logits and tokens.
- # """

- # tokens: Union[torch.Tensor, str]
- # logits: torch.Tensor

################################################################################
# Toy Configurable Transformer (non-TransformerLens)
@@ -295,33 +230,17 @@ def __init__(self, data_dir: Path = Path(_PATH_DATASETS), block_size: int = 32,
################################################################################





@dataclass
class TestModelArgs:
n_layers: int = 2 # 2
vocab_size: int = 33278 # 33
max_seq_len: int = 200 # 10
dim: int = 200 # 10
n_heads: int = 2
- dropout_p: float = 0.2 # 0.1
+ dropout_p: float = 0.0 # 0.2 # 0.1
use_attn_mask: bool = True
weight_tying: bool = False # True
checkpoint_activations: bool = False
- #tokenizer: Optional[Callable] = None
- #device: Optional[torch.device] = None
- #dtype: Optional[torch.dtype] = None
- # handle below can be used at runtime to allow this model's `generate` to adapt to various configuration contexts
- #ctx_handle: Optional[torch.nn.Module] = None

- # def __post_init__(self):
- # if self.ctx_handle:
- # # snag potentially useful context references and then delete the handle
- # #self.tokenizer = self.tokenizer or self.ctx_handle.it_cfg.tokenizer
- # self.device = self.device or self.ctx_handle.device
- # self.dtype = self.dtype or self.ctx_handle.torch_dtype
- # del self.ctx_handle


class Attention(torch.nn.Module):
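As a usage note (a sketch paraphrasing the fields shown above, not the full test model), the toy transformer config is an ordinary dataclass, so tests can override individual defaults per scenario; the functional change in this file is that dropout is now off by default:

from dataclasses import dataclass


@dataclass
class TestModelArgs:                 # abbreviated copy of the config fields shown above
    n_layers: int = 2
    vocab_size: int = 33278
    max_seq_len: int = 200
    dim: int = 200
    n_heads: int = 2
    dropout_p: float = 0.0           # changed default; previously 0.2
    use_attn_mask: bool = True
    weight_tying: bool = False
    checkpoint_activations: bool = False


# per-test override, leaving the remaining defaults intact
args = TestModelArgs(n_layers=4, dropout_p=0.1)
print(args.dropout_p, args.dim)  # -> 0.1 200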
tests/helpers/runif.py (14 changes: 12 additions & 2 deletions)

@@ -13,23 +13,25 @@
import os
import re
import sys
- from typing import Optional
+ from typing import Optional, Set

import pytest
import torch
from lightning.fabric.accelerators.cuda import num_cuda_devices
from lightning.pytorch.strategies.deepspeed import _DEEPSPEED_AVAILABLE
from packaging.version import Version
from pkg_resources import get_distribution
+ from fts_examples.stable.patching.dep_patch_shim import ExpPatch, _ACTIVE_PATCHES

EXTENDED_VER_PAT = re.compile(r"([0-9]+\.){2}[0-9]+")

# RunIf aliases
RUNIF_MAP = {
"min2_4": {"min_torch": "2.4.0"},
"min2_5": {"min_torch": "2.5.0"},
"min2_2": {"min_torch": "2.2.0"},
"max3_12_min2_3": {"max_python": "3.12", "min_torch": "2.3.0"},
"max3_12_min2_2": {"max_python": "3.12", "min_torch": "2.2.0"},
"einsum_exp": {"exp_patch": {ExpPatch.EINSUM_STRATEGIES}, "min_torch": "2.5.0"},
}


@@ -58,6 +60,7 @@ def __new__(
standalone: bool = False,
deepspeed: bool = False,
slow: bool = False,
+ exp_patch: Optional[ExpPatch|Set[ExpPatch]] = None,
**kwargs,
):
"""
@@ -74,6 +77,7 @@ def __new__(
standalone: Mark the test as standalone, our CI will run it in a separate process.
This requires that the ``PL_RUN_STANDALONE_TESTS=1`` environment variable is set.
deepspeed: Require that microsoft/DeepSpeed is installed.
+ exp_patch: Require that a given experimental patch is installed.
slow: Mark the test as slow, our CI will run it in a separate job.
This requires that the ``PL_RUN_SLOW_TESTS=1`` environment variable is set.
**kwargs: Any :class:`pytest.mark.skipif` keyword arguments.
@@ -143,6 +147,12 @@ def __new__(
conditions.append(not _DEEPSPEED_AVAILABLE)
reasons.append("Deepspeed")

+ if exp_patch:
+     if not isinstance(exp_patch, Set):
+         exp_patch = {exp_patch}
+     conditions.append(not exp_patch.issubset(_ACTIVE_PATCHES))
+     reasons.append(f"Required experimental patch configuration {exp_patch} is not active.")

if slow:
env_flag = os.getenv("PL_RUN_SLOW_TESTS", "0")
conditions.append(env_flag != "1")
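To show how the new exp_patch gate is meant to be consumed, a hypothetical test sketch (not part of this diff; the test name and body are invented, and the import path mirrors how the repository's test suite typically imports its helpers). Per its docstring, RunIf builds a pytest skipif marker, so a test can require both a minimum torch version and an active experimental patch, echoing the einsum_exp alias added to RUNIF_MAP:

from fts_examples.stable.patching.dep_patch_shim import ExpPatch
from tests.helpers.runif import RunIf


# skipped unless the EINSUM_STRATEGIES experimental patch is active
# (ENABLE_FTS_EINSUM_STRATEGY_PATCH=1 with a torch version satisfying the patch
# condition) and torch >= 2.5.0 is installed
@RunIf(exp_patch={ExpPatch.EINSUM_STRATEGIES}, min_torch="2.5.0")
def test_tp_einsum_strategy_path():
    ...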
(Remaining changed files not rendered in this view.)