Add more timers (#104)
anishbh authored Nov 4, 2024
1 parent 4dae813 commit 5a664ea
Showing 3 changed files with 30 additions and 7 deletions.
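
All three diffs below follow the same pattern: wrap a region of interest with ax.get_timers().start(name) and ax.get_timers().stop(name), closing nested regions innermost-first (which also appears to be what the one-line reorder in asym_communication.py fixes). As a rough illustration of that bracketing, here is a minimal, hypothetical stand-in for the timer object; AxoNN's real timer implementation is not part of this commit.

# Hypothetical stand-in for the object returned by ax.get_timers(); AxoNN's
# real timers are not shown in this commit, so this is only an illustration
# of the start/stop bracketing the diff adds.
import time
from collections import defaultdict


class Timers:
    def __init__(self):
        self._starts = {}                  # region name -> start timestamp
        self.totals = defaultdict(float)   # region name -> accumulated seconds

    def start(self, name):
        self._starts[name] = time.perf_counter()

    def stop(self, name):
        self.totals[name] += time.perf_counter() - self._starts.pop(name)


timers = Timers()

timers.start("gather-channels-scatter-batch")   # outer region
timers.start("alltoallv")                       # inner region: just the collective
# ... torch.distributed.all_to_all(...) would run here ...
timers.stop("alltoallv")                        # stop the inner region first
timers.stop("gather-channels-scatter-batch")    # then the enclosing region
print(dict(timers.totals))

Note that a wall-clock timer like this only measures kernel-launch time for asynchronous CUDA work unless the region synchronizes; whether AxoNN's timers handle that is not visible from this diff.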
2 changes: 1 addition & 1 deletion axonn/intra_layer/asym_communication.py
@@ -168,8 +168,8 @@ def _gather_channels_scatter_batch(input_, rank_local_batch_sizes, process_group
)

torch.distributed.all_to_all(recv_tensors, send_tensors, group=process_group)
ax.get_timers().stop("gather-channels-scatter-batch")
ax.get_timers().stop("alltoallv")
ax.get_timers().stop("gather-channels-scatter-batch")
return torch.cat(recv_tensors, dim=-1)


16 changes: 11 additions & 5 deletions axonn/intra_layer/communication.py
@@ -6,37 +6,40 @@
import torch.distributed as dist
import torch
import axonn.intra_layer.overlap_communication as overlap_communication
from axonn import axonn as ax


def _all_reduce(input_, process_group=None, overlap_comm=False):
ax.get_timers().start("all-reduce")
input_ = input_.contiguous()
if dist.get_world_size(process_group) > 1:
handle = dist.all_reduce(
input_.contiguous(), group=process_group, async_op=overlap_comm
)
if overlap_comm:
overlap_communication.register_handle(handle)
ax.get_timers().stop("all-reduce")
return input_


def _drop(input_, dim, process_group=None):
"""Divide a tensor among the tensor parallel ranks"""
if dist.get_world_size(process_group) == 1:
return input_

ax.get_timers().start("drop")
total_chunks = dist.get_world_size(process_group)
this_chunk = dist.get_rank(process_group)
assert input_.shape[dim] % total_chunks == 0
chunk_size = input_.shape[dim] // total_chunks

ax.get_timers().stop("drop")
return torch.narrow(input_, dim, this_chunk * chunk_size, chunk_size)


def _gather(input_, dim, process_group=None, cache=False):
"""Gather tensors and concatenate them along a dimension"""
if dist.get_world_size(process_group) == 1:
return input_

ax.get_timers().start("gather")
if input_ in overlap_communication.weights_cache:
output, handle = overlap_communication.retrieve_all_gathered_weight(
input_, delete=not cache
@@ -61,7 +64,7 @@ def _gather(input_, dim, process_group=None, cache=False):

if cache:
overlap_communication.weights_cache[input_] = output, None

ax.get_timers().stop("gather")
return output


@@ -70,7 +73,7 @@ def _reduce_scatter(input_, dim, process_group=None, overlap_comm=False):

if dist.get_world_size(process_group) == 1:
return input_

ax.get_timers().start("reduce-scatter")
total_chunks = dist.get_world_size(process_group)
assert input_.shape[dim] % total_chunks == 0
tensor_shape = list(input_.shape)
@@ -79,6 +82,7 @@
tensor_shape, dtype=input_.dtype, device=torch.cuda.current_device()
)

ax.get_timers().start("reduce-scatter-dist")
if hasattr(torch.distributed, "reduce_scatter_tensor"):
handle = torch.distributed.reduce_scatter_tensor(
output, input_, group=process_group, async_op=overlap_comm
@@ -88,8 +92,10 @@
output, input_, group=process_group, async_op=overlap_comm
)

ax.get_timers().stop("reduce-scatter-dist")
if overlap_comm:
overlap_communication.register_handle(handle)
ax.get_timers().stop("reduce-scatter")
return output


19 changes: 18 additions & 1 deletion axonn/intra_layer/fully_connected.py
@@ -106,6 +106,7 @@ def forward(
local_weight_shape,
cache_weights,
):
ax.get_timers().start("forward-async")
original_weight = weight
weight = _gather(
weight, dim=0, process_group=depth_parallel_group, cache=cache_weights
@@ -115,14 +116,17 @@
ctx.backward_all_reduce_group = backward_all_reduce_group
ctx.depth_parallel_group = depth_parallel_group
ctx.shape = local_weight_shape
ax.get_timers().start("compute")
output = input_.matmul(weight.t())
ax.get_timers().stop("compute")
dist.all_reduce(output, group=forward_all_reduce_group, async_op=False)

ax.get_timers().stop("forward-async")
return output

@staticmethod
@version_aware_custom_bwd
def backward(ctx, grad_output):
ax.get_timers().start("backward-async")
input_, original_weight = ctx.saved_tensors
weight = _gather(
original_weight, dim=0, process_group=ctx.depth_parallel_group, cache=False
@@ -137,18 +141,22 @@ def backward(ctx, grad_output):
grad_input, grad_weight = None, None

if ctx.needs_input_grad[0]:
ax.get_timers().start("compute")
grad_input = grad_output.matmul(weight)
ax.get_timers().stop("compute")
handle = dist.all_reduce(
grad_input,
group=ctx.backward_all_reduce_group,
async_op=overlap_all_reduce,
)
if ctx.needs_input_grad[1]:
ax.get_timers().start("compute")
grad_weight = (
grad_output.reshape(-1, grad_output.shape[-1])
.t()
.mm(input_.view(-1, input_.shape[-1]))
)
ax.get_timers().stop("compute")

grad_weight = grad_weight.reshape(-1)
grad_weight = _reduce_scatter(
@@ -163,16 +171,19 @@
if overlap_reduce_scatter and ctx.needs_input_grad[1]:
overlap_communication.accumulate_later(original_weight, grad_weight)
grad_weight = None # weight gradients are not ready yet
ax.get_timers().stop("backward-async")
return grad_input, grad_weight, None, None, None, None, None, None, None
else:
grad_input, grad_weight = None, None

if ctx.needs_input_grad[1]:
ax.get_timers().start("compute")
grad_weight = (
grad_output.reshape(-1, grad_output.shape[-1])
.t()
.mm(input_.view(-1, input_.shape[-1]))
).reshape(-1)
ax.get_timers().stop("compute")
grad_weight = _reduce_scatter(
grad_weight,
dim=0,
@@ -183,7 +194,10 @@
grad_weight = None # weight gradients are not ready yet

if ctx.needs_input_grad[0]:
ax.get_timers().start("compute")
grad_input = grad_output.matmul(weight)
ax.get_timers().stop("compute")
ax.get_timers().stop("backward-async")
return grad_input, grad_weight, None, None, None, None, None, None, None
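
Sketched end-to-end as a toy example, the pattern in this file wraps the custom autograd function's forward and backward in their own regions and each matmul in a "compute" region. The sketch below repeats the hypothetical Timers stand-in from above so it runs on its own; TimedLinear is not AxoNN's actual layer, and the real commit additionally times the collectives (all_reduce, reduce_scatter) in communication.py.

# Toy end-to-end illustration of the timing pattern in this file: a custom
# autograd Function whose forward/backward and inner matmuls are bracketed by
# named regions. Timers is a hypothetical stand-in for ax.get_timers(), and
# TimedLinear is not AxoNN's real layer.
import time
from collections import defaultdict

import torch


class Timers:
    def __init__(self):
        self._starts, self.totals = {}, defaultdict(float)

    def start(self, name):
        self._starts[name] = time.perf_counter()

    def stop(self, name):
        self.totals[name] += time.perf_counter() - self._starts.pop(name)


timers = Timers()


class TimedLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, weight):
        timers.start("forward")
        ctx.save_for_backward(input_, weight)
        timers.start("compute")
        output = input_.matmul(weight.t())   # y = x W^T
        timers.stop("compute")
        timers.stop("forward")
        return output

    @staticmethod
    def backward(ctx, grad_output):
        timers.start("backward")
        input_, weight = ctx.saved_tensors
        timers.start("compute")
        grad_input = grad_output.matmul(weight)    # dL/dx
        grad_weight = grad_output.t().mm(input_)   # dL/dW
        timers.stop("compute")
        timers.stop("backward")
        return grad_input, grad_weight


x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(16, 8, requires_grad=True)
TimedLinear.apply(x, w).sum().backward()
print(dict(timers.totals))  # accumulated seconds per region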


@@ -305,6 +319,7 @@ def forward(
x,
cache_weights_in_all_gather=False,
):
ax.get_timers().start("forward-linear")
original_shape_x = x.shape
x = x.reshape(-1, x.shape[-1])
weight = self.weight
@@ -345,11 +360,13 @@
x = x.reshape(*original_shape_x[:-1], x.shape[-1])

if self.bias is None:
ax.get_timers().stop("forward-linear")
return x
else:
bias = self.bias
if not self.expert_mode:
bias = Gather.apply(bias, self.outer_group)
ax.get_timers().stop("forward-linear")
if self.skip_bias_add:
return x, bias
else:
