feat: define holder mixin
LutingWang committed Mar 14, 2024
1 parent 0a7955a commit 3be34ba
Showing 18 changed files with 833 additions and 568 deletions.
18 changes: 9 additions & 9 deletions todd/runners/callbacks/checkpoint.py
@@ -36,27 +36,27 @@ def __init__(

def init(self, *args, **kwargs) -> None:
super().init(*args, **kwargs)
self._checkpoint_dir = self._runner.work_dir / 'checkpoints'
self._checkpoint_dir = self.runner.work_dir / 'checkpoints'
self._latest_checkpoint_dir = self._checkpoint_dir / 'latest'

self._checkpoint_dir.mkdir(parents=True, exist_ok=True)

if self._runner._auto_resume and self._latest_checkpoint_dir.exists():
if self.runner._auto_resume and self._latest_checkpoint_dir.exists():
load_from = self._latest_checkpoint_dir
elif self._runner.load_from is not None:
load_from = pathlib.Path(self._runner.load_from)
elif self.runner.load_from is not None:
load_from = pathlib.Path(self.runner.load_from)
assert load_from.exists()
else:
load_from = None

if load_from is not None:
if get_rank() == 0:
self._runner.logger.info("Loading from %s", load_from)
self.runner.logger.info("Loading from %s", load_from)
state_dict = {
f.stem: torch.load(f, 'cpu')
for f in load_from.glob('*.pth')
}
self._runner.load_state_dict(state_dict, **self._load_state_dict)
self.runner.load_state_dict(state_dict, **self._load_state_dict)

@property
def checkpoint_dir(self) -> pathlib.Path:
@@ -71,13 +71,13 @@ def _work_dir(self, name: str) -> pathlib.Path:

def _save(self, name: str) -> None:
# for FSDP, all ranks should call state dict
state_dict = self._runner.state_dict(**self._state_dict)
state_dict = self.runner.state_dict(**self._state_dict)

if get_rank() != 0:
return
work_dir = self._work_dir(name)
work_dir.mkdir(parents=True, exist_ok=True)
self._runner.logger.info("Saving state dict to %s", work_dir)
self.runner.logger.info("Saving state dict to %s", work_dir)
for k, v in state_dict.items():
torch.save(v, work_dir / f'{k}.pth')

@@ -88,7 +88,7 @@ def _save(self, name: str) -> None:
def after_run_iter(self, batch, memo: Memo) -> None:
super().after_run_iter(batch, memo)
if self._should_run_iter():
self._save(f'iter_{self._runner.iter_}')
self._save(f'iter_{self.runner.iter_}')

def after_run_epoch(self, epoch_memo: Memo, memo: Memo) -> None:
super().after_run_epoch(epoch_memo, memo)
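Note: the on-disk layout this callback produces and consumes can be reproduced outside the runner. A minimal sketch (function names are illustrative, not part of the commit):

import pathlib

import torch

def save_checkpoint(state_dict: dict, work_dir: pathlib.Path) -> None:
    # One ``<key>.pth`` file per state-dict entry, mirroring ``_save`` above.
    work_dir.mkdir(parents=True, exist_ok=True)
    for k, v in state_dict.items():
        torch.save(v, work_dir / f'{k}.pth')

def load_checkpoint(work_dir: pathlib.Path) -> dict:
    # Rebuild the mapping keyed by file stem, loading onto CPU as ``init`` above does.
    return {f.stem: torch.load(f, 'cpu') for f in work_dir.glob('*.pth')}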
2 changes: 1 addition & 1 deletion todd/runners/callbacks/composed.py
@@ -21,7 +21,7 @@ def __init__(self, *args, callbacks: Iterable[Config], **kwargs) -> None:
super().__init__(*args, **kwargs)
priorities = [c.pop('priority', dict()) for c in callbacks]
queue = [
CallbackRegistry.build(c, runner=self._runner) for c in callbacks
CallbackRegistry.build(c, runner=self.runner) for c in callbacks
]
self._priority_queue: PriorityQueue[KT, BaseCallback] = \
PriorityQueue(priorities, queue)
6 changes: 3 additions & 3 deletions todd/runners/callbacks/git.py
@@ -33,10 +33,10 @@ def init(self, *args, **kwargs) -> None:
diff = subprocess_run(args_)
except subprocess.CalledProcessError as e:
diff = str(e)
self._runner.logger.error(e)
self.runner.logger.error(e)
else:
file = (
self._runner.work_dir / f'git_diff_{get_timestamp()}.log'
self.runner.work_dir / f'git_diff_{get_timestamp()}.log'
)
self._runner.logger.info('Saving git diff to %s', file)
self.runner.logger.info('Saving git diff to %s', file)
file.write_text(diff)
2 changes: 1 addition & 1 deletion todd/runners/callbacks/interval.py
@@ -22,7 +22,7 @@ def __should_run(self, step: int) -> bool:
return self._interval > 0 and step % self._interval == 0

def _should_run_iter(self) -> bool:
return not self._by_epoch and self.__should_run(self._runner.iter_)
return not self._by_epoch and self.__should_run(self.runner.iter_)

def _should_run_epoch(self) -> bool:
return (
16 changes: 8 additions & 8 deletions todd/runners/callbacks/log.py
@@ -43,24 +43,24 @@ def init(self, *args, **kwargs) -> None:
if get_rank() > 0:
return
if self._with_file_handler:
file = self._runner.work_dir / f'{get_timestamp()}.log'
file = self.runner.work_dir / f'{get_timestamp()}.log'
handler = logging.FileHandler(file)
handler.setFormatter(Formatter())
self._runner.logger.addHandler(handler)
self.runner.logger.addHandler(handler)
if self._collect_env is not None:
from ...base import ( # noqa: E501 pylint: disable=import-outside-toplevel
collect_env,
)
env = collect_env(**self._collect_env)
self._runner.logger.info(env)
self.runner.logger.info(env)

def before_run(self, memo: Memo) -> None:
super().before_run(memo)
self._eta: BaseETA | None = (
None if self._eta_config is None else ETARegistry.build(
self._eta_config,
start=self._runner.iter_ - 1,
end=self._runner.iters,
start=self.runner.iter_ - 1,
end=self.runner.iters,
)
)

@@ -73,10 +73,10 @@ def after_run_iter(self, batch, memo: Memo) -> None:
super().after_run_iter(batch, memo)
if 'log' not in memo:
return
prefix = f"Iter [{self._runner.iter_}/{self._runner.iters}] "
prefix = f"Iter [{self.runner.iter_}/{self.runner.iters}] "

if self._eta is not None:
eta = self._eta(self._runner.iter_)
eta = self._eta(self.runner.iter_)
eta = round(eta)
prefix += f"ETA {str(datetime.timedelta(seconds=eta))} "

@@ -90,7 +90,7 @@ def after_run_iter(self, batch, memo: Memo) -> None:

log: dict[str, Any] = memo.pop('log')
message = ' '.join(f'{k}={v}' for k, v in log.items() if v is not None)
self._runner.logger.info(prefix + message)
self.runner.logger.info(prefix + message)

def before_run_epoch(self, epoch_memo: Memo, memo: Memo) -> None:
super().before_run_epoch(epoch_memo, memo)
8 changes: 4 additions & 4 deletions todd/runners/callbacks/lr.py
@@ -26,15 +26,15 @@ def __init__(
**kwargs,
) -> None:
super().__init__(*args, interval=interval, **kwargs)
assert isinstance(self._runner, Trainer)
assert isinstance(self.runner, Trainer)
self._lr_scheduler_config = lr_scheduler

def init(self, *args, **kwargs) -> None:
super().init(*args, **kwargs)
self._build_lr_scheduler()

def _build_lr_scheduler(self) -> None:
runner = cast(Trainer, self._runner)
runner = cast(Trainer, self.runner)
self._lr_scheduler: torch.optim.lr_scheduler.LRScheduler = \
LRSchedulerRegistry.build(
self._lr_scheduler_config,
@@ -75,11 +75,11 @@ class LRScaleCallback(BaseCallback):

def __init__(self, *args, lr_scaler: Config, **kwargs) -> None:
super().__init__(*args, **kwargs)
assert isinstance(self._runner, Trainer)
assert isinstance(self.runner, Trainer)
self._lr_scaler_config = lr_scaler

def _scale_lr(self, config: Config) -> None:
runner = cast(Trainer, self._runner)
runner = cast(Trainer, self.runner)
assert runner.dataloader.batch_size is not None
base_batch_size = config.base_batch_size
batch_size = get_world_size() * runner.dataloader.batch_size
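Note: the remainder of _scale_lr is not shown above. Assuming the conventional linear-scaling rule (an assumption, not confirmed by this diff), the rest would multiply each parameter group's learning rate by batch_size / base_batch_size, roughly:

import torch

def scale_lr(
    optimizer: torch.optim.Optimizer,
    batch_size: int,
    base_batch_size: int,
) -> None:
    # Linear-scaling rule: lr' = lr * (effective batch size / base batch size).
    scale = batch_size / base_batch_size
    for param_group in optimizer.param_groups:
        param_group['lr'] *= scale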
4 changes: 2 additions & 2 deletions todd/runners/callbacks/monitor.py
@@ -37,8 +37,8 @@ def run_iter_context(
) -> None:
super().run_iter_context(exit_stack, batch, memo)
context = Context(
self._runner.logger,
iter_=self._runner.iter_,
self.runner.logger,
iter_=self.runner.iter_,
batch=batch,
memo=memo,
)
2 changes: 1 addition & 1 deletion todd/runners/callbacks/tensorboard.py
@@ -31,7 +31,7 @@ def init(self, *args, **kwargs) -> None:
super().init(*args, **kwargs)
if get_rank() > 0:
return
log_dir = self._runner.work_dir / 'tensorboard'
log_dir = self.runner.work_dir / 'tensorboard'
self._summary_writer = SummaryWriter(
log_dir,
**self._summary_writer_config,
2 changes: 1 addition & 1 deletion todd/runners/epoch_based_trainer.py
@@ -18,7 +18,7 @@


@RunnerRegistry.register_()
class EpochBasedTrainer(Trainer):
class EpochBasedTrainer(Trainer[T]):

def __init__(self, *args, epochs: int, **kwargs) -> None:
super().__init__(*args, **kwargs)
2 changes: 1 addition & 1 deletion todd/runners/iter_based_trainer.py
@@ -16,7 +16,7 @@


@RunnerRegistry.register_()
class IterBasedTrainer(Trainer):
class IterBasedTrainer(Trainer[T]):

def __init__(self, *args, iters: int, **kwargs) -> None:
super().__init__(*args, **kwargs)
10 changes: 5 additions & 5 deletions todd/runners/strategies/base.py
@@ -2,7 +2,7 @@
'BaseStrategy',
]

from typing import Any, Generic, Mapping, TypeVar, cast
from typing import Any, Mapping, TypeVar, cast

import torch
from torch import nn
@@ -15,7 +15,7 @@


@StrategyRegistry.register_()
class BaseStrategy(RunnerHolderMixin, StateDictMixin, Generic[T]):
class BaseStrategy(RunnerHolderMixin[T], StateDictMixin):

def __init__(
self,
@@ -45,7 +45,7 @@ def build_optimizer(self, config: Config) -> torch.optim.Optimizer:

@property
def module(self) -> nn.Module:
return self._runner.model
return self.runner.model

def model_state_dict(self, *args, **kwargs) -> dict[str, Any]:
return self.module.state_dict(*args, **kwargs)
@@ -62,7 +62,7 @@ def load_model_state_dict(
**kwargs,
)
if get_rank() == 0:
self._runner.logger.info(incompatible_keys)
self.runner.logger.info(incompatible_keys)

def load_model_from(
self,
@@ -77,7 +77,7 @@ def load_model_from(
model_state_dict = dict()
for f_ in f_list:
if get_rank() == 0:
self._runner.logger.info("Loading model from %s", f_)
self.runner.logger.info("Loading model from %s", f_)
model_state_dict.update(torch.load(f_, 'cpu'))
self.load_model_state_dict(model_state_dict, *args, **kwargs)

2 changes: 1 addition & 1 deletion todd/runners/strategies/ddp.py
@@ -23,4 +23,4 @@ def wrap_model(self, model: nn.Module, config: Config) -> T:

@property
def module(self) -> nn.Module:
return self._runner.model.module
return self.runner.model.module
8 changes: 4 additions & 4 deletions todd/runners/strategies/fsdp.py
@@ -26,21 +26,21 @@ def wrap_model(self, model: nn.Module, config: Config) -> T:

@property
def module(self) -> nn.Module:
return self._runner.model.module
return self.runner.model.module

def build_optimizer(self, config: Config) -> torch.optim.Optimizer:
return OptimizerRegistry.build(config, model=self._runner.model)
return OptimizerRegistry.build(config, model=self.runner.model)

def model_state_dict(self, *args, **kwargs) -> dict[str, Any]:
return self._runner.model.state_dict(*args, **kwargs)
return self.runner.model.state_dict(*args, **kwargs)

def load_model_state_dict(
self,
state_dict: Mapping[str, Any],
*args,
**kwargs,
) -> None:
self._runner.model.load_state_dict(state_dict, *args, **kwargs)
self.runner.model.load_state_dict(state_dict, *args, **kwargs)

def optim_state_dict(
self,
47 changes: 25 additions & 22 deletions todd/runners/utils.py
@@ -2,42 +2,45 @@
'RunnerHolderMixin',
]

import weakref
from typing import cast
from typing import TypeVar

from torch import nn

from ..utils import HolderMixin
from .base import BaseRunner
from .epoch_based_trainer import EpochBasedTrainer
from .iter_based_trainer import IterBasedTrainer
from .trainer import Trainer
from .validator import Validator

T = TypeVar('T', bound=nn.Module)


class RunnerHolderMixin:
class RunnerHolderMixin(HolderMixin[BaseRunner[T]]):

def __init__(self, *args, runner: BaseRunner, **kwargs) -> None:
super().__init__(*args, **kwargs)
runner_proxy = (
runner if isinstance(runner, weakref.ProxyTypes) else
weakref.proxy(runner)
)
self._runner = cast(BaseRunner, runner_proxy)
def __init__(self, *args, runner: BaseRunner[T], **kwargs) -> None:
super().__init__(*args, instance=runner, **kwargs)

@property
def runner(self) -> BaseRunner[T]:
return self._instance

@property
def trainer(self) -> Trainer:
assert isinstance(self._runner, Trainer)
return self._runner
def trainer(self) -> Trainer[T]:
assert isinstance(self._instance, Trainer)
return self._instance

@property
def validator(self) -> Validator:
assert isinstance(self._runner, Validator)
return self._runner
def validator(self) -> Validator[T]:
assert isinstance(self._instance, Validator)
return self._instance

@property
def iter_based_trainer(self) -> IterBasedTrainer:
assert isinstance(self._runner, IterBasedTrainer)
return self._runner
def iter_based_trainer(self) -> IterBasedTrainer[T]:
assert isinstance(self._instance, IterBasedTrainer)
return self._instance

@property
def epoch_based_trainer(self) -> EpochBasedTrainer:
assert isinstance(self._runner, EpochBasedTrainer)
return self._runner
def epoch_based_trainer(self) -> EpochBasedTrainer[T]:
assert isinstance(self._instance, EpochBasedTrainer)
return self._instance
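Note: HolderMixin itself (imported from ..utils above) is not among the files shown in this diff. Judging from the weakref-proxy code it replaces and from the instance= / self._instance usage in RunnerHolderMixin, a plausible sketch could look like the following; this is an inference, not the actual definition:

import weakref
from typing import Generic, TypeVar, cast

T = TypeVar('T')

class HolderMixin(Generic[T]):

    def __init__(self, *args, instance: T, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Hold a weak proxy so the holder does not keep its owner alive.
        instance_proxy = (
            instance if isinstance(instance, weakref.ProxyTypes)
            else weakref.proxy(instance)
        )
        self._instance = cast(T, instance_proxy)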
1 change: 1 addition & 0 deletions todd/utils/__init__.py
@@ -1,3 +1,4 @@
from .constants import *
from .enums import *
from .generic_tensors import *
from .metas import *
11 changes: 11 additions & 0 deletions todd/utils/constants.py
@@ -0,0 +1,11 @@
__all__ = [
'IMAGENET_MEAN',
'IMAGENET_STD',
'IMAGENET_MEAN_255',
'IMAGENET_STD_255',
]

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
IMAGENET_MEAN_255 = tuple(x * 255 for x in IMAGENET_MEAN)
IMAGENET_STD_255 = tuple(x * 255 for x in IMAGENET_STD)
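As an illustration (not part of the commit), these constants would typically be used to normalize RGB image tensors before feeding an ImageNet-pretrained model; the import path follows the from .constants import * re-export added above:

import torch

from todd.utils import IMAGENET_MEAN, IMAGENET_STD

def normalize(image: torch.Tensor) -> torch.Tensor:
    # image: (C, H, W) float tensor in [0, 1], channels in RGB order.
    mean = torch.tensor(IMAGENET_MEAN).view(-1, 1, 1)
    std = torch.tensor(IMAGENET_STD).view(-1, 1, 1)
    return (image - mean) / std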