feat(runners): add interfaces to map and wrap models in strategy
LutingWang committed Jan 27, 2024
1 parent 5fb2ade commit f453075
Showing 4 changed files with 72 additions and 46 deletions.
39 changes: 32 additions & 7 deletions todd/runners/strategies/base.py
@@ -2,7 +2,7 @@
     'BaseStrategy',
 ]
 
-from typing import Any, Mapping
+from typing import Any, Generic, Mapping, TypeVar, cast
 
 import torch
 from torch import nn
@@ -11,28 +11,53 @@
 from ...utils import StateDictMixin, get_rank
 from ..utils import RunnerHolderMixin
 
+T = TypeVar('T', bound=nn.Module)
+
 
 @StrategyRegistry.register_()
-class BaseStrategy(RunnerHolderMixin, StateDictMixin):
-    _model: nn.Module
+class BaseStrategy(RunnerHolderMixin, StateDictMixin, Generic[T]):
 
     def __init__(
         self,
         *args,
         model: Config,
+        map_model: Config | None = None,
+        wrap_model: Config | None = None,
         **kwargs,
     ) -> None:
         super().__init__(*args, **kwargs)
-        self._build_model(model)
+        self._build_model(model, map_model, wrap_model)
 
-    def _build_model(self, config: Config) -> None:
-        self._model = ModelRegistry.build(config)
+    def _build_model(
+        self,
+        config: Config,
+        map_config: Config | None,
+        wrap_config: Config | None,
+    ) -> None:
+        model = ModelRegistry.build(config)
+        model = self.map_model(model, map_config)
+        model = self.wrap_model(model, wrap_config)
+        self._model = model
+
+    def map_model(
+        self,
+        model: nn.Module,
+        config: Config | None = None,
+    ) -> nn.Module:
+        if config is None:
+            config = Config()
+        return model
+
+    def wrap_model(self, model: nn.Module, config: Config | None = None) -> T:
+        if config is None:
+            config = Config()
+        return cast(T, model)
 
     def build_optimizer(self, config: Config) -> torch.optim.Optimizer:
         return OptimizerRegistry.build(config, model=self.module)
 
     @property
-    def model(self) -> nn.Module:
+    def model(self) -> T:
         return self._model
 
     @property
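With the new hooks, a subclass customizes model construction by overriding map_model (transform the plain module, e.g. device placement) and/or wrap_model (wrap it for parallelism) instead of re-implementing _build_model. A hypothetical downstream strategy might look like the sketch below; the ChannelsLastStrategy name and the todd.* import paths are assumptions for illustration, not part of this commit.

import torch
from torch import nn

from todd.base import Config, StrategyRegistry
from todd.runners.strategies import BaseStrategy


@StrategyRegistry.register_()
class ChannelsLastStrategy(BaseStrategy[nn.Module]):
    """Hypothetical strategy that only customizes the mapping step."""

    def map_model(
        self,
        model: nn.Module,
        config: Config | None = None,
    ) -> nn.Module:
        model = super().map_model(model, config)
        # Transform the plain module; BaseStrategy later hands the result to
        # wrap_model before storing it as self._model.
        return model.to(memory_format=torch.channels_last)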
17 changes: 13 additions & 4 deletions todd/runners/strategies/cuda.py
@@ -2,16 +2,21 @@
     'CUDAStrategy',
 ]
 
+from typing import TypeVar
+
 import torch
 import torch.distributed
+from torch import nn
 
 from ...base import Config, Store, StrategyRegistry
 from ...utils import get_local_rank
 from .base import BaseStrategy
 
+T = TypeVar('T', bound=nn.Module)
+
 
 @StrategyRegistry.register_()
-class CUDAStrategy(BaseStrategy):
+class CUDAStrategy(BaseStrategy[T]):
 
     def __init__(
         self,
@@ -34,6 +39,10 @@ def _setup(self, config: Config) -> None:
         torch.distributed.init_process_group(**init_process_group)
         torch.cuda.set_device(get_local_rank() % torch.cuda.device_count())
 
-    def _build_model(self, config: Config) -> None:
-        super()._build_model(config)
-        self._model = self._model.cuda()
+    def map_model(
+        self,
+        model: nn.Module,
+        config: Config | None = None,
+    ) -> nn.Module:
+        model = super().map_model(model, config)
+        return model.cuda()
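Moving device placement from _build_model into map_model makes the build order explicit: the module is materialized, then mapped onto a device, and only then handed to wrap_model. That order matters because distributed wrappers generally expect parameters to already sit on the target device. A minimal plain-PyTorch sketch of the same ordering (not todd code; it assumes CUDA is available and a process group has already been initialized):

from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP


def build(model: nn.Module) -> DDP:
    # map step: move parameters to the local GPU first ...
    model = model.cuda()
    # ... wrap step: construct the parallel wrapper around the mapped module.
    return DDP(model)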
34 changes: 11 additions & 23 deletions todd/runners/strategies/ddp.py
@@ -2,39 +2,27 @@
     'DDPStrategy',
 ]
 
-from typing import TYPE_CHECKING
+from typing import TypeVar, cast
 
 from torch import nn
 from torch.nn.parallel import DistributedDataParallel as DDP
 
 from ...base import Config, StrategyRegistry
 from .cuda import CUDAStrategy
 
+T = TypeVar('T', bound=DDP)
+
 
 @StrategyRegistry.register_()
-class DDPStrategy(CUDAStrategy):
-    _model: DDP
-
-    def __init__(
-        self,
-        *args,
-        wrap_model: Config | None = None,
-        **kwargs,
-    ) -> None:
-        super().__init__(*args, **kwargs)
-        if wrap_model is None:
-            wrap_model = Config()
-        self._wrap_model(wrap_model)
-
-    def _wrap_model(self, config: Config) -> None:
-        self._model = DDP(self._model, **config)
+class DDPStrategy(CUDAStrategy[T]):
+
+    def wrap_model(self, model: nn.Module, config: Config | None = None) -> T:
+        if config is None:
+            config = Config()
+        model = super().wrap_model(model, config)
+        model = DDP(model, **config)
+        return cast(T, model)
 
     @property
     def module(self) -> nn.Module:
         return self._model.module
-
-    if TYPE_CHECKING:
-
-        @property
-        def model(self) -> DDP:
-            ...
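Binding T to DDP lets static type checkers distinguish the wrapper from its payload without the old TYPE_CHECKING stub: model is the DistributedDataParallel instance, while module is the unwrapped network. A hypothetical annotated call site (the todd.runners.strategies import path is an assumption):

from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP

from todd.runners.strategies import DDPStrategy


def checkpointable(strategy: DDPStrategy[DDP]) -> nn.Module:
    wrapped: DDP = strategy.model      # the DDP wrapper, used for forward/backward
    raw: nn.Module = strategy.module   # the plain module, e.g. for saving state dicts
    assert wrapped.module is raw
    return raw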
28 changes: 16 additions & 12 deletions todd/runners/strategies/fsdp.py
@@ -2,23 +2,33 @@
     'FSDPStrategy',
 ]
 
-from typing import TYPE_CHECKING, Any, Mapping
+from typing import Any, Mapping, TypeVar, cast
 
 import torch
+from torch import nn
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 
 from ...base import Config, OptimizerRegistry, StrategyRegistry
-from .ddp import DDPStrategy
+from .cuda import CUDAStrategy
 
 # TODO: update when pytorch updates
 
+T = TypeVar('T', bound=FSDP)
+
 
 @StrategyRegistry.register_()
-class FSDPStrategy(DDPStrategy):
-    _model: FSDP  # type: ignore[assignment]
+class FSDPStrategy(CUDAStrategy[T]):
 
-    def _wrap_model(self, config: Config) -> None:
-        self._model = FSDP(self._model, **config)
+    def wrap_model(self, model: nn.Module, config: Config | None = None) -> T:
+        if config is None:
+            config = Config()
+        model = super().wrap_model(model, config)
+        model = FSDP(model, **config)
+        return cast(T, model)
+
+    @property
+    def module(self) -> nn.Module:
+        return self._model.module
 
     def build_optimizer(self, config: Config) -> torch.optim.Optimizer:
         return OptimizerRegistry.build(config, model=self._model)
@@ -53,9 +63,3 @@ def load_optim_state_dict(
             self._model,
         )
         self.trainer.optimizer.load_state_dict(sharded_state_dict)
-
-    if TYPE_CHECKING:
-
-        @property
-        def model(self) -> FSDP:  # type: ignore[override]
-            ...
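With FSDPStrategy now built directly on CUDAStrategy, FSDP-specific construction options travel through the wrap_model config, which wrap_model forwards verbatim to FSDP(model, **config). A hedged sketch of such a config follows; it assumes todd's Config accepts keyword arguments like a dict, while the keywords themselves are standard FullyShardedDataParallel constructor arguments:

from torch.distributed.fsdp import CPUOffload, ShardingStrategy

from todd.base import Config

# Everything below ends up as FSDP(model, sharding_strategy=..., cpu_offload=...).
wrap_model = Config(
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    cpu_offload=CPUOffload(offload_params=False),
)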
