feat: EMA and all_sync
LutingWang committed Feb 10, 2024
1 parent 419d007 commit b7b3bf5
Showing 8 changed files with 555 additions and 557 deletions.
23 changes: 12 additions & 11 deletions todd/base/eta.py
@@ -9,6 +9,8 @@
 from abc import ABC, abstractmethod
 from typing import NamedTuple

+from ..utils import ExponentialMovingAverage as EMA
+from .configs import Config
 from .registries import ETARegistry


@@ -28,19 +30,19 @@ def _datum(self, x: int) -> Datum:
         return Datum(x, t)

     @abstractmethod
-    def _pace(self, datum: Datum) -> float:
+    def pace(self, datum: Datum) -> float:
         pass

     def __call__(self, x: int) -> float:
         datum = self._datum(x)
-        pace = self._pace(datum)
+        pace = self.pace(datum)
         return pace * (self._end - x) / 1000


 @ETARegistry.register_()
 class AverageETA(BaseETA):

-    def _pace(self, datum: Datum) -> float:
+    def pace(self, datum: Datum) -> float:
         t = datum.t - self._start.t
         x = datum.x - self._start.x
         return t.total_seconds() * 1000 / x
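
For orientation: _datum pairs the current progress x with a timestamp, pace is measured in milliseconds per unit of progress, and __call__ converts that back to seconds remaining. A worked example with illustrative numbers (not taken from the repository):

# AverageETA-style arithmetic: 100 of 1000 steps finished after 50 s.
pace = 50 * 1000 / 100            # 500 ms per step
eta = pace * (1000 - 100) / 1000  # 450.0 s remaining
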
@@ -49,14 +51,13 @@ def _pace(self, datum: Datum) -> float:
 @ETARegistry.register_()
 class EMA_ETA(AverageETA):  # noqa: N801 pylint: disable=invalid-name

-    def __init__(self, *args, decay: float, **kwargs) -> None:
-        assert 0 <= decay <= 1
+    def __init__(self, *args, ema: Config, **kwargs) -> None:
         super().__init__(*args, **kwargs)
-        self._decay = decay
-        self._ema_pace = 0.
+        self._ema = EMA(**ema)
+        self._pace: float | None = None

-    def _pace(self, datum: Datum) -> float:
-        pace = super()._pace(datum)
-        pace = self._decay * self._ema_pace + (1 - self._decay) * pace
-        self._ema_pace = pace
+    def pace(self, datum: Datum) -> float:
+        pace = super().pace(datum)
+        pace = self._ema(self._pace, pace)
+        self._pace = pace
         return pace
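
EMA_ETA now delegates smoothing to the shared ExponentialMovingAverage module (added in todd/utils/torch.py below) instead of keeping its own decay and running pace. A minimal sketch of the smoothing it performs, imported the same way eta.py does, with an illustrative decay of 0.9 (the module's default is 0.99):

from todd.utils import ExponentialMovingAverage as EMA

ema = EMA(decay=0.9)
pace = ema(None, 500.0)    # first measurement passes through unchanged
pace = ema(pace, 700.0)    # 0.9 * 500.0 + 0.1 * 700.0 == 520.0
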
1 change: 0 additions & 1 deletion todd/base/stores.py
@@ -6,7 +6,6 @@
 import os

 import torch
-import torch.distributed
 from packaging.version import parse

 from ..utils import NonInstantiableMeta
1 change: 0 additions & 1 deletion todd/runners/base.py
@@ -11,7 +11,6 @@
 from typing import TYPE_CHECKING, Any, Mapping

 import torch
-import torch.distributed
 import torch.utils.data

 from ..base import (
6 changes: 3 additions & 3 deletions todd/runners/strategies/cuda.py
@@ -5,7 +5,7 @@
 from typing import TypeVar

 import torch
-import torch.distributed
+import torch.distributed as dist
 from torch import nn

 from ...base import Config, Store, StrategyRegistry
@@ -31,12 +31,12 @@ def __init__(
         super().__init__(*args, **kwargs)

     def _setup(self, config: Config) -> None:
-        if not torch.distributed.is_initialized():
+        if not dist.is_initialized():
             init_process_group = config.get(
                 'init_process_group',
                 Config(backend='nccl'),
             )
-            torch.distributed.init_process_group(**init_process_group)
+            dist.init_process_group(**init_process_group)
         torch.cuda.set_device(get_local_rank() % torch.cuda.device_count())

     def map_model(
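
The strategy still defaults to NCCL and only initializes the process group when none exists. A sketch of what an override could look like; the 'gloo' backend and the init_method value are illustrative assumptions, not something this diff prescribes:

from todd.base import Config

# Hypothetical strategy config, e.g. for a CPU-only smoke test; _setup would
# forward these keyword arguments to dist.init_process_group.
setup = Config(init_process_group=Config(backend='gloo', init_method='env://'))
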
1 change: 0 additions & 1 deletion todd/runners/trainer.py
@@ -5,7 +5,6 @@
 from typing import Any, Mapping

 import torch
-import torch.distributed
 import torch.utils.data

 from ..base import Config, RunnerRegistry
1 change: 0 additions & 1 deletion todd/runners/validator.py
@@ -3,7 +3,6 @@
 ]

 import torch
-import torch.distributed
 import torch.utils.data

 from .base import BaseRunner, RunnerRegistry
56 changes: 53 additions & 3 deletions todd/utils/torch.py
@@ -4,16 +4,18 @@
     'get_world_size',
     'all_gather',
     'all_gather_',
+    'all_sync',
     'Shape',
     'ModuleList',
     'ModuleDict',
+    'ExponentialMovingAverage',
 ]

 import functools
 import itertools
 import operator
 import os
-from typing import Any
+from typing import TYPE_CHECKING

 import torch
 import torch.distributed as dist
@@ -90,6 +92,15 @@ def all_gather_(
     return tensors


+def all_sync(x: torch.Tensor) -> bool:
+    if get_world_size() <= 1:
+        return True
+    x_prime = x.clone()
+    dist.all_reduce(x)
+    x /= get_world_size()
+    return torch.allclose(x, x_prime)
+
+
 class Shape:

     @classmethod
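
all_sync reports whether the local tensor already agrees (within torch.allclose tolerance) with the cross-rank mean: the argument is all-reduced and averaged in place, then compared against a clone of its original value, and a single-process run short-circuits to True. Because of the in-place reduction, callers that still need the original value should pass a copy. A usage sketch; the loss tensor is illustrative, and the import assumes todd.utils re-exports the names listed in __all__:

import torch
from todd.utils import all_sync

loss = torch.tensor(0.123)      # illustrative per-rank scalar
if not all_sync(loss.clone()):  # clone: all_sync averages its argument in place
    print('ranks disagree on the loss')
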
@@ -134,11 +145,50 @@ def conv(

 class ModuleList(nn.ModuleList):

-    def forward(self, *args, **kwargs) -> list:
+    def forward(self, *args, **kwargs) -> list[nn.Module]:
         return [m(*args, **kwargs) for m in self]


 class ModuleDict(nn.ModuleDict):

-    def forward(self, *args, **kwargs) -> dict[str, Any]:
+    def forward(self, *args, **kwargs) -> dict[str, nn.Module]:
         return {k: m(*args, **kwargs) for k, m in self.items()}
+
+
+class ExponentialMovingAverage(nn.Module):
+
+    def __init__(
+        self,
+        *args,
+        decay=0.99,
+        **kwargs,
+    ) -> None:
+        self.check_decay(decay)
+        super().__init__(*args, **kwargs)
+        self._decay = decay
+
+    @staticmethod
+    def check_decay(decay) -> None:
+        if isinstance(decay, torch.Tensor):
+            assert decay.ge(0).all() and decay.le(1).all()
+        else:
+            assert 0 <= decay <= 1
+
+    @property
+    def decay(self):
+        return self._decay
+
+    def forward(self, x, y, decay=None):
+        assert x is not None or y is not None
+        if x is None:
+            return y
+        if y is None:
+            return x
+        if decay is None:
+            decay = self._decay
+        else:
+            self.check_decay(decay)
+        return x * decay + y * (1 - decay)
+
+    if TYPE_CHECKING:
+        __call__ = forward
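
ExponentialMovingAverage itself is stateless apart from its decay, so callers keep the running value themselves, as EMA_ETA does with self._pace. The None handling makes the first update a pass-through, and decay may also be a tensor, validated element-wise by check_decay. A short usage sketch (import path assumed from the __all__ entry above):

import torch

from todd.utils import ExponentialMovingAverage

ema = ExponentialMovingAverage()           # default decay 0.99
x = torch.tensor([1.0, 2.0])
y = torch.tensor([3.0, 4.0])
ema(x, y)                                  # tensor([1.0200, 2.0200])
ema(x, y, decay=torch.tensor([0.5, 0.0]))  # per-element decay: tensor([2., 4.])
ema(None, y)                               # first update returns y unchanged
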
