Syntactic sugar for FSDPStrategy (#723)
Summary:
Pull Request resolved: #723

In addition to the torch.distributed enum classes, accept plain classes
and strings and convert them internally, allowing simpler instantiation
of FSDPStrategy.

Reviewed By: galrotem

Differential Revision: D54568599

fbshipit-source-id: 57c27e6255320cab70501b380e030729d3f42e92
schwarzmx authored and facebook-github-bot committed Mar 12, 2024
1 parent 97d4bbc commit 67159a4
Showing 3 changed files with 289 additions and 8 deletions.
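
As an illustration of the intended sugar (a minimal sketch, not part of the diff itself; the StrMixedPrecision name is just a local alias), FSDPStrategy can now be built from plain strings instead of the native torch.distributed enum and dtype classes:

    import torch
    from torch.distributed.fsdp import ShardingStrategy
    from torch.distributed.fsdp.fully_sharded_data_parallel import (
        BackwardPrefetch,
        MixedPrecision,
    )
    from torchtnt.utils.fsdp_utils import MixedPrecision as StrMixedPrecision
    from torchtnt.utils.prepare_module import FSDPStrategy

    # Previously, the native torch.distributed classes had to be passed in:
    strategy = FSDPStrategy(
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        mixed_precision=MixedPrecision(param_dtype=torch.float16),
    )

    # With this change, plain strings (and the string-based MixedPrecision
    # from torchtnt.utils.fsdp_utils) are accepted and converted internally:
    strategy = FSDPStrategy(
        sharding_strategy="FULL_SHARD",
        backward_prefetch="BACKWARD_PRE",
        state_dict_type="SHARDED_STATE_DICT",
        mixed_precision=StrMixedPrecision(param_dtype="fp16"),
    )
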
126 changes: 125 additions & 1 deletion tests/utils/test_prepare_module.py
@@ -174,7 +174,131 @@ def _test_fdsp_precision() -> None:
fsdp_module.mixed_precision.param_dtype, mixed_precision.param_dtype
)

# test strategy options
@skip_if_not_distributed
@skip_if_not_gpu
def test_fdsp_str_types(self) -> None:
spawn_multi_process(
2,
"nccl",
self._test_fdsp_precision_str_types,
)
spawn_multi_process(
2,
"nccl",
self._test_fdsp_backward_prefetch_str_types,
)
spawn_multi_process(
2,
"nccl",
self._test_fdsp_sharding_strategy_str_types,
)
spawn_multi_process(
2,
"nccl",
self._test_fdsp_state_dict_str_types,
)

@staticmethod
def _test_fdsp_precision_str_types() -> None:
from torchtnt.utils.prepare_module import MixedPrecision as _MixedPrecision

module = torch.nn.Linear(1, 1)
device = init_from_env()
mixed_precision = _MixedPrecision(
param_dtype="fp16",
reduce_dtype="bf16",
buffer_dtype="fp32",
)

fsdp_module = prepare_fsdp(
module, device, FSDPStrategy(mixed_precision=mixed_precision)
)
tc = unittest.TestCase()
tc.assertTrue(isinstance(fsdp_module, FSDP))

@staticmethod
def _test_fdsp_backward_prefetch_str_types() -> None:
module = torch.nn.Linear(1, 1)
device = init_from_env()

tc = unittest.TestCase()
for value in ["BACKWARD_PRE", "BACKWARD_POST"]:
fsdp_module = prepare_fsdp(
module, device, FSDPStrategy(backward_prefetch=value)
)
tc.assertTrue(isinstance(fsdp_module, FSDP), f"tested value: {value}")

@staticmethod
def _test_fdsp_sharding_strategy_str_types() -> None:
module = torch.nn.Linear(1, 1)
device = init_from_env()

tc = unittest.TestCase()
for value in [
"FULL_SHARD",
"SHARD_GRAD_OP",
"NO_SHARD",
# skip hybrid strategy; tricky to configure in-test
]:

fsdp_module = prepare_fsdp(
module,
device,
FSDPStrategy(sharding_strategy=value),
)
tc.assertTrue(isinstance(fsdp_module, FSDP), f"tested value: {value}")

@staticmethod
def _test_fdsp_state_dict_str_types() -> None:
module = torch.nn.Linear(1, 1)
device = init_from_env()

tc = unittest.TestCase()
for value in [
"FULL_STATE_DICT",
"LOCAL_STATE_DICT",
"SHARDED_STATE_DICT",
]:
fsdp_module = prepare_fsdp(
module, device, FSDPStrategy(state_dict_type=value)
)
tc.assertTrue(isinstance(fsdp_module, FSDP), f"tested value: {value}")

def test_invalid_fsdp_strategy_str_values(self) -> None:
from torchtnt.utils.prepare_module import MixedPrecision as _MixedPrecision

with self.assertRaisesRegex(ValueError, "Invalid BackwardPrefetch 'foo'"):
FSDPStrategy(backward_prefetch="foo")

with self.assertRaisesRegex(ValueError, "Invalid ShardingStrategy 'FOO'"):
FSDPStrategy(sharding_strategy="FOO")

with self.assertRaisesRegex(ValueError, "Invalid StateDictType 'FOO'"):
FSDPStrategy(state_dict_type="FOO")

with self.assertRaisesRegex(
ValueError,
"Invalid module class 'torch.nn.modules._BatchNorm': module 'torch.nn.modules' has no attribute '_BatchNorm'",
):
FSDPStrategy(
mixed_precision=_MixedPrecision(
_module_classes_to_ignore=[
# correct type is torch.nn.modules.batchnorm._BatchNorm
"torch.nn.modules._BatchNorm"
]
)
)
with self.assertRaisesRegex(
ValueError,
"Invalid module class 'foo.bar.Baz': No module named 'foo'",
):
FSDPStrategy(
mixed_precision=_MixedPrecision(
_module_classes_to_ignore=["foo.bar.Baz"]
)
)

# # test strategy options
def test_prepare_module_strategy_invalid_str(self) -> None:
"""
Test that an exception is raised with an invalid strategy string
126 changes: 126 additions & 0 deletions torchtnt/utils/fsdp_utils.py
@@ -0,0 +1,126 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import importlib

from dataclasses import dataclass
from typing import List, Optional, Sequence, Type

import torch

from torch.distributed.fsdp import StateDictType as _StateDictType

from torch.distributed.fsdp.fully_sharded_data_parallel import (
BackwardPrefetch as _BackwardPrefetch,
MixedPrecision as _MixedPrecision,
ShardingStrategy as _ShardingStrategy,
)
from torchtnt.utils.precision import convert_precision_str_to_dtype


class ShardingStrategy:
"""Supported values for `ShardingStrategy <https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.ShardingStrategy>`_"""

FULL_SHARD = "FULL_SHARD"
SHARD_GRAD_OP = "SHARD_GRAD_OP"
NO_SHARD = "NO_SHARD"
HYBRID_SHARD = "HYBRID_SHARD"
_HYBRID_SHARD_ZERO2 = "_HYBRID_SHARD_ZERO2"

@staticmethod
def to_native_sharding_strategy(value: str) -> _ShardingStrategy:
"""Convert a string to its PyTorch native ShardingStrategy."""
if value not in [
ShardingStrategy.FULL_SHARD,
ShardingStrategy.SHARD_GRAD_OP,
ShardingStrategy.NO_SHARD,
ShardingStrategy.HYBRID_SHARD,
ShardingStrategy._HYBRID_SHARD_ZERO2,
]:
raise ValueError(f"Invalid ShardingStrategy '{value}'")

return _ShardingStrategy[value]


class BackwardPrefetch:
"""Supported values for `BackwardPrefetch <https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.BackwardPrefetch>`_"""

BACKWARD_PRE = "BACKWARD_PRE"
BACKWARD_POST = "BACKWARD_POST"

@staticmethod
def to_native_backward_prefetch(value: str) -> _BackwardPrefetch:
"""Convert a string to its PyTorch native BackwardPrefetch."""
if value not in [
BackwardPrefetch.BACKWARD_PRE,
BackwardPrefetch.BACKWARD_POST,
]:
raise ValueError(f"Invalid BackwardPrefetch '{value}'")

return _BackwardPrefetch[value]


class StateDictType:
"""Supported values for `StateDictType <https://pytorch.org/docs/stable/fsdp.html>`_"""

FULL_STATE_DICT = "FULL_STATE_DICT"
LOCAL_STATE_DICT = "LOCAL_STATE_DICT"
SHARDED_STATE_DICT = "SHARDED_STATE_DICT"

@staticmethod
def to_native_state_dict_type(value: str) -> _StateDictType:
"""Convert a string to its PyTorch native StateDictType."""
if value not in [
StateDictType.FULL_STATE_DICT,
StateDictType.LOCAL_STATE_DICT,
StateDictType.SHARDED_STATE_DICT,
]:
raise ValueError(f"Invalid StateDictType '{value}'")

return _StateDictType[value]


def _to_dtype_or_none(x: Optional[str]) -> Optional[torch.dtype]:
return convert_precision_str_to_dtype(x) if x else None


@dataclass
class MixedPrecision:
"""Supported values for `MixedPrecision <https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.MixedPrecision>`_"""

param_dtype: Optional[str] = None
reduce_dtype: Optional[str] = None
buffer_dtype: Optional[str] = None
keep_low_precision_grads: bool = False
cast_forward_inputs: bool = False
cast_root_forward_inputs: bool = True
_module_classes_to_ignore: Sequence[str] = (
"torch.nn.modules.batchnorm._BatchNorm",
)

def to_native_mixed_precision(self) -> _MixedPrecision:
"""Convert this instance to its PyTorch native MixedPrecision."""

# Convert string module classes to their corresponding types
# e.g. "torch.nn.modules.batchnorm._BatchNorm" -> torch.nn.modules.batchnorm._BatchNorm
target_types: List[Type[torch.nn.Module]] = []
for type_str in self._module_classes_to_ignore:
path, _, attr = type_str.rpartition(".")
try:
target_types.append(getattr(importlib.import_module(path), attr))
except (AttributeError, ModuleNotFoundError) as e:
raise ValueError(f"Invalid module class '{type_str}': {e}")
module_classes_to_ignore: Sequence[Type[torch.nn.Module]] = target_types

return _MixedPrecision(
param_dtype=_to_dtype_or_none(self.param_dtype),
reduce_dtype=_to_dtype_or_none(self.reduce_dtype),
buffer_dtype=_to_dtype_or_none(self.buffer_dtype),
keep_low_precision_grads=self.keep_low_precision_grads,
cast_forward_inputs=self.cast_forward_inputs,
cast_root_forward_inputs=self.cast_root_forward_inputs,
_module_classes_to_ignore=module_classes_to_ignore,
)
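
A quick sketch of how these helpers behave (illustrative, not part of the diff): the string values map one-to-one onto the native torch.distributed.fsdp enums, and MixedPrecision resolves dtype strings via convert_precision_str_to_dtype when converted.

    from torchtnt.utils.fsdp_utils import (
        BackwardPrefetch,
        MixedPrecision,
        ShardingStrategy,
        StateDictType,
    )

    # Each to_native_* helper validates the string and returns the native enum,
    # raising ValueError for unknown values.
    native_sharding = ShardingStrategy.to_native_sharding_strategy("FULL_SHARD")
    native_prefetch = BackwardPrefetch.to_native_backward_prefetch("BACKWARD_PRE")
    native_sd_type = StateDictType.to_native_state_dict_type("SHARDED_STATE_DICT")

    # MixedPrecision carries dtype names as strings and resolves them (plus the
    # module classes to ignore) when converted to the native dataclass.
    native_mp = MixedPrecision(
        param_dtype="fp16",
        reduce_dtype="bf16",
        buffer_dtype="fp32",
    ).to_native_mixed_precision()
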
45 changes: 38 additions & 7 deletions torchtnt/utils/prepare_module.py
@@ -13,15 +13,25 @@
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
StateDictType as _StateDictType,
)
from torch.distributed.fsdp.api import OptimStateDictConfig, StateDictConfig
from torch.distributed.fsdp.fully_sharded_data_parallel import (
BackwardPrefetch,
BackwardPrefetch as _BackwardPrefetch,
CPUOffload,
MixedPrecision as _MixedPrecision,
ShardingStrategy as _ShardingStrategy,
)
from torch.nn.parallel import DistributedDataParallel as DDP
from torchtnt.utils.fsdp_utils import (
BackwardPrefetch,
MixedPrecision,
ShardingStrategy,
StateDictType,
)
from torch.nn.parallel import DistributedDataParallel as DDP

from torchtnt.utils.rank_zero_log import rank_zero_warn
from torchtnt.utils.version import (
is_torch_version_geq_1_12,
@@ -91,11 +101,13 @@ class FSDPStrategy(Strategy):
"""Dataclass representing the `FullyShardedDataParallel <https://pytorch.org/docs/stable/fsdp.html>`_ strategy"""

process_group: Optional[ProcessGroup] = None
sharding_strategy: Optional[ShardingStrategy] = None
sharding_strategy: Optional[Union[str, _ShardingStrategy]] = None
cpu_offload: Optional[CPUOffload] = None
auto_wrap_policy: Optional[Callable[[torch.nn.Module, bool, int], bool]] = None
backward_prefetch: Optional[BackwardPrefetch] = BackwardPrefetch.BACKWARD_PRE
mixed_precision: Optional[MixedPrecision] = None
backward_prefetch: Optional[Union[str, _BackwardPrefetch]] = (
_BackwardPrefetch.BACKWARD_PRE
)
mixed_precision: Optional[Union[_MixedPrecision, MixedPrecision]] = None
ignored_modules: Optional[Iterable[torch.nn.Module]] = None
param_init_fn: Optional[Callable[[torch.nn.Module], None]] = None
sync_module_states: bool = False
@@ -105,10 +117,29 @@ class FSDPStrategy(Strategy):

# FSDP set_state_dict_type params: https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.set_state_dict_type
# for setting type of state dict for checkpointing
state_dict_type: Optional[StateDictType] = None
state_dict_type: Optional[Union[str, _StateDictType]] = None
state_dict_config: Optional[StateDictConfig] = None
optim_state_dict_config: Optional[OptimStateDictConfig] = None

def __post_init__(self) -> None:
if isinstance(self.sharding_strategy, str):
self.sharding_strategy = ShardingStrategy.to_native_sharding_strategy(
self.sharding_strategy
)

if isinstance(self.backward_prefetch, str):
self.backward_prefetch = BackwardPrefetch.to_native_backward_prefetch(
self.backward_prefetch
)

if isinstance(self.state_dict_type, str):
self.state_dict_type = StateDictType.to_native_state_dict_type(
self.state_dict_type
)

if isinstance(self.mixed_precision, MixedPrecision):
self.mixed_precision = self.mixed_precision.to_native_mixed_precision()


@dataclass
class TorchCompileParams:
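
The practical upshot of the __post_init__ hook, sketched below (illustrative, not taken from the diff): string arguments are normalized eagerly at construction time, so the rest of prepare_fsdp only ever sees the native torch.distributed types.

    from torch.distributed.fsdp import ShardingStrategy as NativeShardingStrategy
    from torchtnt.utils.prepare_module import FSDPStrategy

    strategy = FSDPStrategy(sharding_strategy="FULL_SHARD")

    # __post_init__ has already replaced the string with the native enum.
    assert strategy.sharding_strategy == NativeShardingStrategy.FULL_SHARD
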
