Commit

reformat
siddharth9820 committed Oct 23, 2024
1 parent 5d1e531 commit 27740e8
Showing 4 changed files with 12 additions and 7 deletions.
3 changes: 2 additions & 1 deletion axonn/axonn.py
@@ -25,7 +25,7 @@ def init(
     G_intra_c: int = 1,
     G_intra_d: int = 1,
     gpus_per_node: Optional[int] = None,
-    enable_internal_timers: bool = False
+    enable_internal_timers: bool = False,
 ) -> None:
     """
     Initialize AxoNN's 2D parallelism with G_inter-way inter-layer
@@ -120,6 +120,7 @@ def create_dataloader(
         **kwargs,
     )  # not working with drop_last=False
 
+
 def get_timers():
     global timers
     return timers
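As a usage note (not part of this commit): with the trailing-comma fix above, a call to init with the timer flag might look like the sketch below. Only the argument names visible in these diffs are taken from the source; the parallelism degrees are illustrative, and any required leading arguments of init() not shown in this diff are assumed to have defaults.

from axonn import axonn as ax

# Hypothetical invocation; G_intra_* values are illustrative.
ax.init(
    G_intra_r=2,
    G_intra_c=1,
    G_intra_d=2,
    enable_internal_timers=True,
)
timers = ax.get_timers()  # module-level Timers registry defined above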
2 changes: 2 additions & 0 deletions axonn/intra_layer/asym_communication.py
@@ -7,6 +7,7 @@
 import torch.distributed as dist
 from axonn import axonn as ax
 
+
 def print_rank(msg):
     if dist.get_rank() == 0:
         print(f"{dist.get_rank()} | {msg}")
@@ -28,6 +29,7 @@ def gather_batch_sizes(local_batch_size, process_group=None):
     ax.get_timers().stop("gather-batch-sizes")
     return global_batch_tensor
 
+
 @torch.no_grad()
 def _allgatherv(tensor, rank_local_batch_sizes, process_group=None):
     ax.get_timers().start("allgatherv")
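For context on what _allgatherv has to do (this commit only reformats around it): it gathers tensors whose dim-0 sizes differ across ranks. The sketch below is one standard pad-gather-trim emulation of an allgatherv, not AxoNN's actual implementation; the helper name is hypothetical.

import torch
import torch.distributed as dist

@torch.no_grad()
def allgatherv_sketch(tensor, rank_local_batch_sizes, process_group=None):
    # Pad every rank's tensor to the largest dim-0 size so the fixed-size
    # all_gather collective can be used, then trim the padding per rank.
    max_n = int(max(rank_local_batch_sizes))
    padded = torch.zeros(max_n, *tensor.shape[1:],
                         dtype=tensor.dtype, device=tensor.device)
    padded[: tensor.shape[0]] = tensor
    world_size = dist.get_world_size(process_group)
    gathered = [torch.empty_like(padded) for _ in range(world_size)]
    dist.all_gather(gathered, padded, group=process_group)
    return torch.cat(
        [g[: int(n)] for g, n in zip(gathered, rank_local_batch_sizes)], dim=0
    )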
5 changes: 2 additions & 3 deletions axonn/lightning/axonn_strategy.py
@@ -70,7 +70,7 @@ def __init__(
         G_intra_c: int = 1,
         G_intra_d: int = 1,
         overlap_communication=False,
-        enable_timers = False,
+        enable_timers=False,
         activation_checkpointing: Optional[
             Union[Type[Module], List[Type[Module]]]
         ] = None,
@@ -218,7 +218,7 @@ def _setup_distributed(self) -> None:
             G_intra_r=self.G_intra_r,
             G_intra_c=self.G_intra_c,
             G_intra_d=self.G_intra_d,
-            enable_internal_timers=self.enable_timers
+            enable_internal_timers=self.enable_timers,
         )
 
     def _get_process_group_backend(self) -> str:
@@ -319,7 +319,6 @@ def module_init_context(self, empty_init: Optional[bool] = None):
     def module_sharded_context(self) -> ContextManager:
         return auto_parallelize()
 
-
     def get_timers(self):
         assert self.enable_timers, "you should set enable_timers=True in AxoNNStrategy"
         return ax.get_timers()
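A minimal sketch of how the corrected keyword is consumed, assuming the Lightning Fabric-style entry point this strategy appears to target; the import path and parallelism degrees here are assumptions, not taken from the diff.

from lightning.fabric import Fabric
from axonn.lightning import AxoNNStrategy  # assumed export path

# enable_timers must be set at construction time, or the assert in
# get_timers() above will fire.
strategy = AxoNNStrategy(G_intra_r=2, G_intra_c=1, G_intra_d=2,
                         enable_timers=True)
fabric = Fabric(strategy=strategy, devices=4)
fabric.launch()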
9 changes: 6 additions & 3 deletions axonn/timers.py
@@ -3,13 +3,14 @@
 from collections import deque
 import axonn
 
-class Timers():
+
+class Timers:
     def __init__(self):
         self.timers = defaultdict(list)
         self.curr_index = defaultdict(int)
         self.stack = deque()
 
-    def start(self, key):
+    def start(self, key):
         if not axonn.axonn.enable_timers:
             return
         self.stack.append(key)
@@ -18,7 +19,9 @@ def start(self, key):
         timers = self.timers[key]
         assert index == len(timers) or index < len(timers)
         if index == len(timers):
-            self.timers[key].append([torch.cuda.Event(enable_timing=True) for _ in range(2)])
+            self.timers[key].append(
+                [torch.cuda.Event(enable_timing=True) for _ in range(2)]
+            )
         self.timers[key][index][0].record()
 
     def stop(self, key):
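The reformatted line above appends a [start, stop] pair of CUDA events and records the first; stop() records the second. Standalone, the underlying timing pattern looks like this sketch (illustrative shapes, not code from the repository):

import torch

# Pair of events with timing enabled, mirroring the pairs Timers keeps per key.
start_evt = torch.cuda.Event(enable_timing=True)
stop_evt = torch.cuda.Event(enable_timing=True)

a = torch.randn(1024, 1024, device="cuda")
start_evt.record()        # analogous to Timers.start(key)
b = a @ a                 # timed region
stop_evt.record()         # analogous to Timers.stop(key)

torch.cuda.synchronize()  # events must complete before reading times
print(f"elapsed: {start_evt.elapsed_time(stop_evt):.3f} ms")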
