Change overlap for depth tp and do not initialize MPI unless absolutely needed (#62)

* optimize communication in the backward pass and do not initialize MPI unless absolutely needed
siddharth9820 authored Jan 27, 2024
1 parent 45647ea commit a9d38c2
Showing 4 changed files with 101 additions and 31 deletions.
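The MPI portion of this commit uses mpi4py's deferred-initialization mechanism, as applied in axonn/axonn.py and axonn/communication.py below. The following is a minimal, self-contained sketch of that pattern, not AxoNN code; the helper name get_world_rank_and_size is illustrative, and it assumes mpi4py is installed and the script is launched under mpirun or srun.

try:
    import mpi4py

    mpi4py.rc.initialize = False  # keep mpi4py from calling MPI_Init at import time
    from mpi4py import MPI

    MPI4PY = True
except ImportError:
    MPI4PY = False


def get_world_rank_and_size():
    # Initialize MPI only at the point where it is actually needed.
    assert MPI4PY, "mpi4py is required for this code path"
    if not MPI.Is_initialized():
        MPI.Init()
    return MPI.COMM_WORLD.Get_rank(), MPI.COMM_WORLD.Get_size()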
6 changes: 3 additions & 3 deletions .github/workflows/nvidia-rtx-3090-tests.yaml
@@ -31,7 +31,7 @@ jobs:
export G_data=$(( 2 / G_inter ))
export memopt=${{ matrix.memopt }}
echo "training with G_inter = ${G_inter}, G_data = $(( 2 / G_inter )) ${{ matrix.memopt }}"
mpirun -n 2 pytest --with-mpi ./axonn/tests/test_vit.py
mpirun -mca orte_allowed_exit_without_sync 1 -n 2 pytest --with-mpi ./axonn/tests/test_vit.py
- name: Uninstall AxoNN
run: |
pip uninstall --yes axonn
@@ -47,10 +47,10 @@ jobs:
pip install -r requirements.txt
- name: Run intra-layer FC unit tests
run: |
mpirun -n 2 pytest --with-mpi ./axonn/tests/test_intra_layer_fc.py
mpirun -mca orte_allowed_exit_without_sync 1 -n 2 pytest --with-mpi ./axonn/tests/test_intra_layer_fc.py
- name: Run intra-layer Conv unit tests
run: |
mpirun -n 2 pytest --with-mpi ./axonn/tests/test_intra_layer_conv.py
mpirun -mca orte_allowed_exit_without_sync 1 -n 2 pytest --with-mpi ./axonn/tests/test_intra_layer_conv.py
- name: Uninstall AxoNN
run: |
pip uninstall --yes axonn
14 changes: 13 additions & 1 deletion axonn/axonn.py
@@ -9,12 +9,21 @@
from .communication import communication_handle
from .optim import CPUAdam
import torch
from mpi4py import MPI
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from enum import Enum
import numpy as np
import types

try:
# from mpi4py import MPI
import mpi4py

MPI4PY = True
mpi4py.rc.initialize = False # do not initialize MPI automatically
from mpi4py import MPI
except ImportError:
MPI4PY = False

# True when init has been called
is_initialized = False
# Communication handle for point-to-point (MPI) and collective (NCCL) communication
@@ -577,6 +586,7 @@ def _recv(post_fw_recv=True, post_bw_recv=True, eval_mode=False) -> int:
Returns:
tag(int): the tag of the received message which is the microbatch number
"""
assert MPI4PY, "attempting to use inter-layer parallelism without mpi4py installed"
status = MPI.Status()
if (requests["bw"] is None) and (requests["fw"] is not None):
requests["fw"][1].Wait(status)
@@ -655,6 +665,8 @@ def _backward_pass(output_gradients, microbatch_no):


def _sync_scale(local_overflow):
assert MPI4PY, "attempting to use inter-layer parallelism without mpi4py installed"

global loss_scale, no_overflow_iters, max_scale
assert computation_dtype == torch.float16
overflow_np = np.array(int(local_overflow), "i")
13 changes: 11 additions & 2 deletions axonn/communication.py
@@ -6,9 +6,12 @@
import os

try:
from mpi4py import MPI
# from mpi4py import MPI
import mpi4py

MPI4PY = True
mpi4py.rc.initialize = False # do not initialize MPI automatically
from mpi4py import MPI
except ImportError:
MPI4PY = False
import torch
@@ -44,6 +47,8 @@ def __init__(
if not torch.distributed.is_initialized():
assert MPI4PY, "either install mpi4py and launch via mpirun/srun"
"or initialize torch.distributed outside axonn"
if not MPI.Is_initialized():
MPI.Init()
self.world_rank = MPI.COMM_WORLD.Get_rank()
self.world_size = MPI.COMM_WORLD.Get_size()
else:
@@ -88,6 +93,8 @@ def __init__(
if G_inter > 1:
# this needs to be checked
if MPI4PY:
if not MPI.Is_initialized():
MPI.Init()
self.p2p_mpi_comm = MPI.COMM_WORLD.Split(colour)
assert self.p2p_mpi_comm.Get_size() == G_inter
else:
@@ -244,7 +251,7 @@ def recv(
self,
tensor: torch.Tensor,
send_rank: int,
tag: int = MPI.ANY_TAG,
tag: int = None,
async_op: bool = True,
):
"""Receive a PyTorch tensor from a particular rank using MPI
@@ -260,6 +267,8 @@
mpi4py future object if async is true, else None - this object
can be queried to check for completion of communication
"""
if tag is None:
tag = MPI.ANY_TAG
mpi4py_compatible_array = self._torch_to_mpi(tensor)
if async_op:
mpi_future_object = self.p2p_mpi_comm.Irecv(
99 changes: 74 additions & 25 deletions axonn/intra_layer/fully_connected.py
@@ -6,7 +6,12 @@

from axonn import axonn as ax
import axonn
from .communication import Drop, Gather, ForwardGather_BackwardReduceScatter
from .communication import (
Drop,
Gather,
_gather,
_reduce_scatter,
)


def divide(a, b):
@@ -57,11 +62,20 @@ def forward(
weight,
forward_all_reduce_group,
backward_all_reduce_group,
depth_parallel_group,
local_weight_shape,
cache_weights,
backward_comm_async,
forward_comm_async,
):
ctx.save_for_backward(input_, weight)
original_weight = weight
weight = _gather(
weight, dim=0, process_group=depth_parallel_group, cache=cache_weights
)
weight = weight.reshape(local_weight_shape)
ctx.save_for_backward(input_, weight, original_weight)
ctx.backward_all_reduce_group = backward_all_reduce_group
ctx.depth_parallel_group = depth_parallel_group
ctx.backward_comm_async = backward_comm_async
if not forward_comm_async:
output = input_.matmul(weight.t())
@@ -86,24 +100,59 @@
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
input_, weight = ctx.saved_tensors
input_, weight, original_weight = ctx.saved_tensors
handle = None
if ctx.needs_input_grad[0]:
grad_input = grad_output.matmul(weight)
handle = dist.all_reduce(
grad_input,
group=ctx.backward_all_reduce_group,
async_op=ctx.backward_comm_async,
)
if ctx.needs_input_grad[1]:
grad_weight = (
grad_output.reshape(-1, grad_output.shape[-1])
.t()
.mm(input_.view(-1, input_.shape[-1]))
overlap_reduce_scatter = axonn.intra_layer.OVERLAP_REDUCE_SCATTER
if dist.get_world_size(ctx.backward_all_reduce_group) > 1 or (
not overlap_reduce_scatter
):
if ctx.needs_input_grad[0]:
grad_input = grad_output.matmul(weight)
handle = dist.all_reduce(
grad_input,
group=ctx.backward_all_reduce_group,
async_op=ctx.backward_comm_async,
)
if ctx.needs_input_grad[1]:
grad_weight = (
grad_output.reshape(-1, grad_output.shape[-1])
.t()
.mm(input_.view(-1, input_.shape[-1]))
)

grad_weight = grad_weight.reshape(-1)
grad_weight = _reduce_scatter(
grad_weight,
dim=0,
process_group=ctx.depth_parallel_group,
overlap_comm=overlap_reduce_scatter,
)
if handle and ctx.backward_comm_async:
handle.wait()
return grad_input, grad_weight, None, None, None, None

if handle and ctx.backward_comm_async:
handle.wait()
if overlap_reduce_scatter:
axonn.intra_layer.accumulate_later(original_weight, grad_weight)
grad_weight = None # weight gradients are not ready yet
return grad_input, grad_weight, None, None, None, None, None, None, None
else:
if ctx.needs_input_grad[1]:
grad_weight = (
grad_output.reshape(-1, grad_output.shape[-1])
.t()
.mm(input_.view(-1, input_.shape[-1]))
).reshape(-1)
grad_weight = _reduce_scatter(
grad_weight,
dim=0,
process_group=ctx.depth_parallel_group,
overlap_comm=True,
)
axonn.intra_layer.accumulate_later(original_weight, grad_weight)
grad_weight = None # weight gradients are not ready yet

if ctx.needs_input_grad[0]:
grad_input = grad_output.matmul(weight)
return grad_input, grad_weight, None, None, None, None, None, None, None


class Linear(torch.nn.Module):
@@ -210,14 +259,8 @@ def forward(
):
# gather weights from depth parallel group
# reduce scatter in the backward pass
weight = ForwardGather_BackwardReduceScatter.apply(
self.weight,
self.depth_group,
0,
axonn.intra_layer.OVERLAP_REDUCE_SCATTER,
cache_weights_in_all_gather,
).reshape(self.local_out_features, self.local_in_features)

weight = self.weight
if not self.transpose:
if scatter_input:
x = Drop.apply(x, self.inner_group)
@@ -227,6 +270,9 @@
weight,
self.inner_group,
self.outer_group,
self.depth_group,
(self.local_out_features, self.local_in_features),
cache_weights_in_all_gather,
axonn.intra_layer.OVERLAP_ALL_REDUCE,
False,
)
@@ -243,6 +289,9 @@
weight,
self.outer_group,
self.inner_group,
self.depth_group,
(self.local_out_features, self.local_in_features),
cache_weights_in_all_gather,
axonn.intra_layer.OVERLAP_ALL_REDUCE,
False,
)
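The backward-pass change above gathers the full weight across the depth group in the forward pass and then starts the reduce-scatter of the flattened weight gradient asynchronously, accumulating the resulting shard into the locally sharded parameter later so the communication overlaps with the rest of the backward computation. Below is a rough sketch of that idea written with plain torch.distributed; it assumes an already-initialized process group (e.g. launched via torchrun), and the helper names reduce_scatter_grad_async and finish_grad_sync are illustrative rather than AxoNN's API.

import torch
import torch.distributed as dist

_pending = []  # (sharded parameter, async handle, gradient shard) triples


def reduce_scatter_grad_async(param, full_grad, depth_group):
    # Kick off an asynchronous reduce-scatter of the flattened full-weight gradient.
    # param is the depth-sharded parameter, so the reduce-scattered shard matches its size
    # (assumes full_grad.numel() is divisible by the depth-group size).
    world_size = dist.get_world_size(depth_group)
    flat = full_grad.reshape(-1)
    shard = torch.empty(flat.numel() // world_size, dtype=flat.dtype, device=flat.device)
    handle = dist.reduce_scatter_tensor(shard, flat, group=depth_group, async_op=True)
    _pending.append((param, handle, shard))


def finish_grad_sync():
    # Call once before the optimizer step: wait for the overlapped reduce-scatters
    # and accumulate the shards into the parameter gradients.
    for param, handle, shard in _pending:
        handle.wait()
        if param.grad is None:
            param.grad = shard.view_as(param)
        else:
            param.grad += shard.view_as(param)
    _pending.clear()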
