More lightning features (#82)
siddharth9820 authored Jun 26, 2024
1 parent a5201ce commit 0ee9562
Showing 6 changed files with 123 additions and 52 deletions.
1 change: 1 addition & 0 deletions axonn/communication.py
@@ -134,6 +134,7 @@ def __init__(
)
if self.world_rank in ranks_in_ith_jth_data_parallel_group:
self.coll_nccl_comm = ith_jth_data_parallel_group
self.data_parallel_group = ith_jth_data_parallel_group

# create communicators for intra-layer parallelism
for i_ in range(G_data):
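
The one-line addition above stores the data-parallel process group on the communicator as self.data_parallel_group; the new sync_gradients_data_parallel later in this commit reads it via ax.comm_handle.data_parallel_group. A minimal sketch of using that attribute directly, assuming ax.init(...) has already been called (the metric tensor and reduction are illustrative, not part of the commit):

    import torch
    import torch.distributed as dist
    from axonn import axonn as ax

    # Average a scalar metric across the data-parallel group that the
    # communicator now exposes as ax.comm_handle.data_parallel_group.
    metric = torch.tensor([1.0], device="cuda")
    dp_group = ax.comm_handle.data_parallel_group
    dist.all_reduce(metric, op=dist.ReduceOp.SUM, group=dp_group)
    metric /= dist.get_world_size(dp_group)
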
65 changes: 51 additions & 14 deletions axonn/intra_layer/__init__.py
@@ -44,6 +44,7 @@ def gather(
OVERLAP_REDUCE_SCATTER = False
OVERLAP_ALL_REDUCE = False
ALL_GATHER_ITERATOR = None
NO_GRADIENT_SYNC = False
handles = []
pending_grad_accumulations = []
weights_cache = {}
@@ -188,15 +189,30 @@ def optimize_communication(
ALL_GATHER_ITERATOR = None


@contextmanager
def no_grad_sync():
    global NO_GRADIENT_SYNC
    old_val = NO_GRADIENT_SYNC
    try:
        NO_GRADIENT_SYNC = True
        yield
    finally:
        NO_GRADIENT_SYNC = old_val


@torch.no_grad()
def sync_gradients(
model, gradient_attr_name="grad", mean=False, vectorize=False, mean_weight=None
def sync_gradients_depth_parallel(
model, gradient_attr_name="grad", mean=False, vectorize=False
):
if NO_GRADIENT_SYNC:
return
grads_to_sync = []
world_size = dist.get_world_size(ax.comm_handle.depth_intra_layer_parallel_group)
for param in model.parameters():
if param.requires_grad:
grad = getattr(param, gradient_attr_name)
if grad is not None:
if mean:
grad.div_(world_size)
if hasattr(param, "is_tensor_parallel") and param.is_tensor_parallel:
if (
hasattr(param, "needs_gradient_sync")
@@ -209,30 +225,51 @@ def sync_gradients(
if not grads_to_sync:
return

world_size = dist.get_world_size(ax.comm_handle.depth_intra_layer_parallel_group)
if vectorize:
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

global_grad = _flatten_dense_tensors(grads_to_sync)
dist.all_reduce(
global_grad, group=ax.comm_handle.depth_intra_layer_parallel_group
)
if mean:
global_grad.div_(world_size)

for old_tensor, new_tensor in zip(
grads_to_sync, _unflatten_dense_tensors(global_grad, grads_to_sync)
):
old_tensor.data = new_tensor
else:
for grad in grads_to_sync:
if mean:
if mean_weight is None:
grad.div_(world_size)
else:
mean_weight_pt = torch.tensor(
[mean_weight], device="cuda", dtype=torch.float32
)
dist.all_reduce(mean_weight_pt)
grad.mul_(mean_weight).div_(mean_weight_pt)
dist.all_reduce(grad, group=ax.comm_handle.depth_intra_layer_parallel_group)


@torch.no_grad()
def sync_gradients_data_parallel(
model, gradient_attr_name="grad", mean=False, vectorize=False
):
if NO_GRADIENT_SYNC:
return
grads_to_sync = []
world_size = dist.get_world_size(ax.comm_handle.data_parallel_group)
for param in model.parameters():
if param.requires_grad:
grad = getattr(param, gradient_attr_name)
if grad is not None:
if mean:
grad.div_(world_size)
grads_to_sync.append(grad)

if not grads_to_sync:
return

if vectorize:
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

global_grad = _flatten_dense_tensors(grads_to_sync)
dist.all_reduce(global_grad, group=ax.comm_handle.data_parallel_group)
for old_tensor, new_tensor in zip(
grads_to_sync, _unflatten_dense_tensors(global_grad, grads_to_sync)
):
old_tensor.data = new_tensor
else:
for grad in grads_to_sync:
dist.all_reduce(grad, group=ax.comm_handle.data_parallel_group)
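
The renamed sync_gradients_depth_parallel and the new sync_gradients_data_parallel split gradient averaging between the depth (intra-layer) group and the data-parallel group, while no_grad_sync turns both into no-ops so gradient accumulation can skip redundant all-reduces. A rough usage sketch under those assumptions (the loss and model here are placeholders, not part of this diff):

    import contextlib
    from axonn.intra_layer import (
        no_grad_sync,
        sync_gradients_data_parallel,
        sync_gradients_depth_parallel,
    )

    def backward_and_sync(loss, model, skip_sync=False):
        # With skip_sync=True, NO_GRADIENT_SYNC is set inside the context, so both
        # sync calls return immediately and gradients stay local for accumulation.
        ctx = no_grad_sync() if skip_sync else contextlib.nullcontext()
        with ctx:
            loss.backward()
            sync_gradients_depth_parallel(model, mean=True)  # all-reduce over the depth group
            sync_gradients_data_parallel(model, mean=True)   # all-reduce over the data-parallel group
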
14 changes: 7 additions & 7 deletions axonn/intra_layer/fully_connected.py
@@ -169,7 +169,7 @@ def __init__(
bias=True,
skip_bias_add=False,
init_method=None,
use_easy_api=True,
expert_mode=False,
**kwargs
):
super(Linear, self).__init__()
@@ -183,7 +183,7 @@ def __init__(

self.in_features = in_features
self.out_features = out_features
self.use_easy_api = use_easy_api
self.expert_mode = expert_mode

if init_method is None:
init_method = default_init_method
@@ -266,7 +266,7 @@ def forward(

weight = self.weight
if not self.transpose:
if self.use_easy_api:
if not self.expert_mode:
x = Drop.apply(x, self.inner_group)
x = AsyncLinear.apply(
x,
@@ -279,10 +279,10 @@ def forward(
axonn.intra_layer.OVERLAP_ALL_REDUCE,
False,
)
if self.use_easy_api:
if not self.expert_mode:
x = Gather.apply(x, self.outer_group)
else:
if self.use_easy_api:
if not self.expert_mode:
x = Drop.apply(x, self.outer_group)

x = AsyncLinear.apply(
@@ -296,14 +296,14 @@ def forward(
axonn.intra_layer.OVERLAP_ALL_REDUCE,
False,
)
if self.use_easy_api:
if not self.expert_mode:
x = Gather.apply(x, self.inner_group)

if self.bias is None:
return x
else:
bias = self.bias
if self.use_easy_api:
if not self.expert_mode:
bias = Gather.apply(
bias,
self.outer_group if not self.transpose else self.inner_group,
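
Renaming use_easy_api to expert_mode flips the flag without changing the default behaviour: with expert_mode=False (the default) the layer still applies Drop/Gather internally so callers work with full-sized activations, whereas expert_mode=True leaves input/output partitioning to the caller. A hedged sketch of both modes, assuming ax.init(...) has already set up the tensor-parallel groups (shapes and device placement are illustrative):

    import torch
    from axonn.intra_layer import Linear

    # Default (expert_mode=False): inputs are dropped and outputs gathered inside
    # forward, so the caller passes and receives full-sized tensors.
    layer = Linear(in_features=1024, out_features=4096).cuda()
    x = torch.randn(8, 1024, device="cuda")
    y = layer(x)  # full-sized output, roughly (8, 4096)

    # expert_mode=True skips the internal Drop/Gather; the caller must feed
    # activations that are already partitioned across the tensor-parallel groups.
    expert_layer = Linear(in_features=1024, out_features=4096, expert_mode=True).cuda()
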
63 changes: 48 additions & 15 deletions axonn/lightning/axonn_strategy.py
@@ -4,13 +4,15 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from datetime import timedelta
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union, ContextManager
from contextlib import nullcontext

import torch
import torch.distributed
from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from typing_extensions import override

from lightning.fabric.accelerators.accelerator import Accelerator
@@ -23,7 +25,7 @@
)
from lightning.fabric.strategies.parallel import ParallelStrategy
from lightning.fabric.strategies.registry import _StrategyRegistry
from lightning.fabric.strategies.strategy import TBroadcast
from lightning.fabric.strategies.strategy import TBroadcast, _BackwardSyncControl
from lightning.fabric.utilities.distributed import (
ReduceOp,
_distributed_is_initialized,
@@ -33,12 +35,17 @@
)
from lightning.fabric.utilities.distributed import group as _group
from lightning.fabric.utilities.rank_zero import rank_zero_only

from axonn import axonn as ax
from axonn.intra_layer import sync_gradients
from axonn.intra_layer import (
sync_gradients_data_parallel,
sync_gradients_depth_parallel,
clip_grad_norm_,
no_grad_sync,
)


class AxonnStrategy(ParallelStrategy):

def __init__(
self,
accelerator: Optional[Accelerator] = None,
@@ -48,8 +55,6 @@ def __init__(
precision: Optional[Precision] = None,
process_group_backend: Optional[str] = None,
timeout: Optional[timedelta] = default_pg_timeout,
G_data: int = 1,
G_inter: int = 1,
G_intra_r: int = 1,
G_intra_c: int = 1,
G_intra_d: int = 1,
@@ -63,19 +68,14 @@ def __init__(
precision=precision,
)

assert G_data == 1, "Data Parallelism not Supported in AxoNNStrategy"
assert (
G_inter == 1
), "Inter-layer (or pipeline) Parallellism not Supported in AxoNNStrategy"
self._num_nodes = 1
self._process_group_backend: Optional[str] = process_group_backend
self._timeout: Optional[timedelta] = timeout
self.G_data = G_data
self.G_inter = G_inter
self.G_intra_r = G_intra_r
self.G_intra_c = G_intra_c
self.G_intra_d = G_intra_d
self._axonn_kwargs = kwargs
self._backward_sync_control = _AxoNNBackwardSyncControl()

@property
@override
@@ -169,10 +169,13 @@ def _setup_distributed(self) -> None:
_init_dist_connection(
self.cluster_environment, self._process_group_backend, timeout=self._timeout
)
tensor_parallel_world_size = self.G_intra_c * self.G_intra_r * self.G_intra_d
assert torch.distributed.get_world_size() % tensor_parallel_world_size == 0
self.G_data = torch.distributed.get_world_size() // tensor_parallel_world_size

ax.init(
G_data=self.G_data,
G_inter=self.G_inter,
G_inter=1,
G_intra_r=self.G_intra_r,
G_intra_c=self.G_intra_c,
G_intra_d=self.G_intra_d,
@@ -199,13 +202,15 @@ def _determine_device_ids(self) -> Optional[List[int]]:
def backward(
self, tensor: Tensor, module: Optional[Module] = None, *args: Any, **kwargs: Any
) -> None:
super().backward(tensor / self.G_intra_d, module, *args, **kwargs)
super().backward(tensor, module, *args, **kwargs)
if self.G_intra_d > 1:
assert module is not None, (
"When using G_intra_d > 1 with AxoNN,"
" you need to pass the model in fabric.backward(model=..)"
)
sync_gradients(module)
sync_gradients_depth_parallel(module, mean=True)
if self.G_data > 1:
sync_gradients_data_parallel(module, mean=True)

def save_checkpoint(
self,
@@ -226,3 +231,31 @@ def load_checkpoint(
"Current fabric.load(..) is not supported with the"
" AxoNN strategy. Use axonn.load instead."
)

def clip_gradients_norm(
self,
module: Module,
optimizer: Optimizer,
max_norm: Union[float, int],
norm_type: Union[float, int] = 2.0,
error_if_nonfinite: bool = True,
) -> Tensor:
self.precision.unscale_gradients(optimizer)
parameters = self.precision.main_params(optimizer)
grad_norm = clip_grad_norm_(
parameters=parameters,
max_norm=max_norm,
norm_type=norm_type,
error_if_nonfinite=error_if_nonfinite,
)
return grad_norm


class _AxoNNBackwardSyncControl(_BackwardSyncControl):
@override
def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager:
"""Blocks gradient synchronization inside AxoNN"""
if not enabled:
return nullcontext()

return no_grad_sync()
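
With these strategy changes, G_data and G_inter are no longer constructor arguments (the data-parallel size is derived from the world size and the G_intra_* dimensions), backward no longer pre-divides the loss by G_intra_d, and _AxoNNBackwardSyncControl wires fabric.no_backward_sync to no_grad_sync. A rough Fabric training loop under those assumptions; build_model_and_optimizer and dataloader are hypothetical placeholders, and the import path for AxonnStrategy is assumed from the package layout:

    import lightning as L
    from axonn.lightning import AxonnStrategy

    strategy = AxonnStrategy(G_intra_r=2, G_intra_d=2)  # data-parallel size inferred from world size
    fabric = L.Fabric(devices=4, strategy=strategy, precision="bf16-mixed")
    fabric.launch()

    model, optimizer = build_model_and_optimizer()  # hypothetical helper
    model, optimizer = fabric.setup(model, optimizer)

    for step, batch in enumerate(dataloader):  # dataloader assumed to exist
        accumulating = (step + 1) % 4 != 0
        # Skips the depth- and data-parallel gradient all-reduces while accumulating.
        with fabric.no_backward_sync(model, enabled=accumulating):
            loss = model(batch).mean()
            fabric.backward(loss, model=model)  # model= is required when G_intra_d > 1
        if not accumulating:
            fabric.clip_gradients(model, optimizer, max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()
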
4 changes: 2 additions & 2 deletions axonn/tests/test_intra_layer_conv.py
@@ -7,7 +7,7 @@
Conv2d,
optimize_communication,
clear_weights_cache,
sync_gradients,
sync_gradients_depth_parallel,
)
import math
import torch.distributed as dist
@@ -185,7 +185,7 @@ def test_bw_pass(
Y_local.backward(Y_local_grad)

if not easy_tp:
sync_gradients(layer)
sync_gradients_depth_parallel(layer)
if comm_opt_level >= 3:
clear_weights_cache()
