diff --git a/axonn/intra_layer/__init__.py b/axonn/intra_layer/__init__.py
index a0e9bc2..5d08cbc 100644
--- a/axonn/intra_layer/__init__.py
+++ b/axonn/intra_layer/__init__.py
@@ -1,6 +1,6 @@
 from contextlib import contextmanager
 from .fully_connected import Linear  # noqa: F401
-from .conv import Conv2d as Tensor_Parallel_Conv2d  # noqa: F401
+from .conv import Conv2d  # noqa: F401
 
 from .communication import Drop, Gather
 from .gradient_normalization import clip_grad_norm_  # noqa: F401
@@ -86,7 +86,7 @@ def clear_weights_cache():
 def trigger_async_all_gathers(model):
     global weights_cache
     for module in model.modules():
-        if isinstance(module, Linear):
+        if isinstance(module, Linear) or isinstance(module, Conv2d):
             weight = module.weight
             if weight not in weights_cache:
                 # only trigger all gathers if not in cache
diff --git a/axonn/intra_layer/conv.py b/axonn/intra_layer/conv.py
index 682d666..77f5710 100644
--- a/axonn/intra_layer/conv.py
+++ b/axonn/intra_layer/conv.py
@@ -1,21 +1,45 @@
 from axonn import axonn as ax
+import axonn
 import torch.distributed as dist
 import torch
-from .communication import ForwardAllReduce, BackwardAllReduce, Drop
+import math
+from .communication import (
+    ForwardAllReduce,
+    BackwardAllReduce,
+    Drop,
+    Gather,
+    ForwardGather_BackwardReduceScatter,
+)
 from .utils import divide
 
 
 @torch.no_grad()
 def initialize_params(
-    out_channels, in_channels, kernel_size, outer_group, inner_group, init_method
+    out_channels,
+    in_channels,
+    kernel_size,
+    outer_group,
+    inner_group,
+    depth_group,
+    init_method,
+    init_device="cuda",
 ):
-    params = torch.empty((out_channels, in_channels, kernel_size, kernel_size))
+    params = torch.empty(
+        (out_channels, in_channels, kernel_size, kernel_size), device=init_device
+    )
     init_method(params)
     params = Drop.apply(params, outer_group, 0)
     params = Drop.apply(params, inner_group, 1)
+    params = Drop.apply(params.reshape(-1), depth_group)
+    params = params.cpu()
     return params
 
 
+@torch.no_grad()
+def default_init_method(weight):
+    return torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
+
+
 class Conv2d(torch.nn.Module):
     def __init__(
         self,
@@ -24,54 +48,130 @@ def __init__(
         kernel_size,
         *args,
         transpose=False,
+        bias=True,
         skip_bias_add=False,
         init_method=None,
-        **kwargs
+        stride=1,
+        padding=0,
+        dilation=1,
+        groups=1,
     ):
         super(Conv2d, self).__init__()
 
+        # For transpose, inner and outer groups are swapped
         if not transpose:
             self.inner_group = ax.comm_handle.inner_intra_layer_parallel_group
             self.outer_group = ax.comm_handle.outer_intra_layer_parallel_group
+            self.depth_group = ax.comm_handle.depth_intra_layer_parallel_group
         else:
             self.outer_group = ax.comm_handle.inner_intra_layer_parallel_group
             self.inner_group = ax.comm_handle.outer_intra_layer_parallel_group
+            self.depth_group = ax.comm_handle.depth_intra_layer_parallel_group
 
         self.inner_group_size = dist.get_world_size(self.inner_group)
         self.outer_group_size = dist.get_world_size(self.outer_group)
+        self.depth_group_size = dist.get_world_size(self.depth_group)
 
-        self.in_channels = divide(in_channels, self.inner_group_size)
-        self.out_channels = divide(out_channels, self.outer_group_size)
+        if init_method is None:
+            init_method = default_init_method
 
-        self.conv = torch.nn.Conv2d(
-            in_channels=self.in_channels,
-            out_channels=self.out_channels,
-            kernel_size=kernel_size,
-            bias=False,
-            **kwargs
+        self.local_in_channels = divide(in_channels, self.inner_group_size)
+        self.local_out_channels = divide(out_channels, self.outer_group_size)
+
+        initial_params = initialize_params(
+            out_channels,
+            in_channels,
+            kernel_size,
+            self.outer_group,
+            self.inner_group,
+            self.depth_group,
+            init_method,
         )
-        if init_method:
-            initial_params = initialize_params(
-                out_channels,
-                in_channels,
-                kernel_size,
+        self.weight = torch.nn.Parameter(initial_params, requires_grad=True)
+        setattr(self.weight, "is_tensor_parallel", True)
+        setattr(self.weight, "needs_gradient_sync", False)
+        setattr(
+            self.weight,
+            "process_group_for_norm_reduction",
+            ax.comm_handle.intra_layer_group,  # What is intra_layer_group?
+        )
+
+        if bias:
+            self.bias = torch.nn.Parameter(
+                torch.zeros(self.local_out_channels), requires_grad=True
+            )
+            setattr(self.bias, "is_tensor_parallel", True)
+            setattr(self.bias, "needs_gradient_sync", True)
+            setattr(
+                self.bias,
+                "process_group_for_norm_reduction",
                 self.outer_group,
-                self.inner_group,
-                init_method,
             )
-            self.conv.weight.data.copy_(initial_params)
-
-        self.bias = torch.nn.Parameter(torch.zeros(self.out_channels))
+        else:
+            self.bias = None
 
+        self.kernel_size = kernel_size
         self.skip_bias_add = skip_bias_add
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+        self.groups = groups
+
+    def forward(
+        self,
+        x,
+        scatter_input=True,
+        gather_output=True,
+        cache_weights_in_all_gather=False,
+    ):
+        # Gather weights from depth parallel group
+        # TODO: We should make the OVERLAP_REDUCE_SCATTER flag part of axonn.axonn
+        weight = ForwardGather_BackwardReduceScatter.apply(
+            self.weight,
+            self.depth_group,
+            0,
+            axonn.intra_layer.OVERLAP_REDUCE_SCATTER,
+            cache_weights_in_all_gather,
+        ).reshape(
+            self.local_out_channels,
+            self.local_in_channels,
+            self.kernel_size,
+            self.kernel_size,
+        )
+
+        if scatter_input:
+            # Drop input across the in_channels dimension on the inner_group
+            x = Drop.apply(x, self.inner_group, 1)
+            # Drop input across the batch dimension on the depth_group
+            x = Drop.apply(x, self.depth_group, 0)
 
-    def forward(self, x):
         x = BackwardAllReduce.apply(x, self.outer_group)
-        h = self.conv(x)
+        h = torch.nn.functional.conv2d(
+            x,
+            weight,
+            bias=None,
+            stride=self.stride,
+            padding=self.padding,
+            dilation=self.dilation,
+            groups=self.groups,
+        )
         h = ForwardAllReduce.apply(h, self.inner_group)
-        if self.skip_bias_add:
-            return h, self.bias
+
+        if gather_output:
+            # Gather output across the out_channels dimension on the outer_group
+            h = Gather.apply(h, self.outer_group, 1)
+            # Gather output across the batch dimension on the depth_group
+            h = Gather.apply(h, self.depth_group, 0)
+
+        if self.bias is None:
+            return h
         else:
-            return h + self.bias.view(1, -1, 1, 1)
-        return h
+            bias = self.bias
+            if gather_output:
+                bias = Gather.apply(bias, self.outer_group)
+
+            if self.skip_bias_add:
+                return h, bias
+            else:
+                return h + bias.view(1, -1, 1, 1)
diff --git a/axonn/tests/test_intra_layer_conv.py b/axonn/tests/test_intra_layer_conv.py
index 99423f8..c05fb12 100644
--- a/axonn/tests/test_intra_layer_conv.py
+++ b/axonn/tests/test_intra_layer_conv.py
@@ -2,15 +2,49 @@
 import pytest
 from axonn import axonn as ax
 from axonn.intra_layer.communication import _drop, _gather
-from axonn.intra_layer import Tensor_Parallel_Conv2d
+from axonn.intra_layer import (
+    Conv2d,
+    optimize_communication,
+    clear_weights_cache,
+    sync_gradients,
+)
+import math
+import torch.distributed as dist
+
+
+def log_dist(msg, ranks=[]):
+    assert dist.is_initialized()
+    if dist.get_rank() in ranks:
+        print(f"Rank {dist.get_rank()} : {msg}")
+
+
+def norm_allclose(X, Y):
+    epsilon = 1e-6
+    squared_diff = torch.square(X - Y)
+    mse = torch.mean(squared_diff).item()
+    rmse = math.sqrt(mse)
+
+    log_dist(f"RMSE:{rmse}", [0])
+    log_dist(f"L2Norm:{torch.norm(X - Y, 2)}", [0])
+
+    if rmse < epsilon:
+        return True
+    else:
+        return False
 
 
 @pytest.mark.mpi
 @pytest.mark.parametrize("H, W, C", [(64, 64, 4), (64, 64, 8), (64, 32, 8)])
-@pytest.mark.parametrize("G_intra_r, G_intra_c", [(1, 2), (2, 1)])
-def test_fw_pass(G_intra_r, G_intra_c, H, W, C):
+@pytest.mark.parametrize("B", [2, 4, 16])
+@pytest.mark.parametrize(
+    "G_intra_r, G_intra_c, G_intra_d", [(1, 2, 1), (2, 1, 1), (1, 1, 2)]
+)
+@pytest.mark.parametrize("easy_tp", [True, False])
+@pytest.mark.parametrize("bias", [True, False])
+def test_fw_pass(G_intra_r, G_intra_c, G_intra_d, B, H, W, C, easy_tp, bias):
     # These tests are in fp-32
     torch.manual_seed(42)
+    torch.cuda.manual_seed(42)
     # Need to remove all non-determinism from convolutions
     torch.use_deterministic_algorithms(True)
     torch.backends.cudnn.benchmark = False
@@ -23,36 +57,60 @@ def test_fw_pass(G_intra_r, G_intra_c, H, W, C):
         G_inter=1,
         G_intra_r=G_intra_r,
         G_intra_c=G_intra_c,
+        G_intra_d=G_intra_d,
     )
-    X = torch.randn(1, C, H, W).cuda() * 0.01
+    X = torch.randn(B, C, H, W).cuda() * 0.01
 
     inner_group = ax.comm_handle.inner_intra_layer_parallel_group
     outer_group = ax.comm_handle.outer_intra_layer_parallel_group
+    depth_group = ax.comm_handle.depth_intra_layer_parallel_group
 
-    X_local = _drop(
-        X, 1, inner_group
-    )  # divide channels of X along the inner tensor group
-    layer = Tensor_Parallel_Conv2d(
-        in_channels=C, out_channels=2 * C, kernel_size=5, skip_bias_add=True
-    ).cuda()
+    if not easy_tp:
+        X_local = _drop(
+            X, 1, inner_group
+        )  # divide channels of X along the inner tensor group
+        X_local = _drop(
+            X_local, 0, depth_group
+        )  # divide the batch dimension of X along the depth tensor group
+    else:
+        X_local = X
+
+    layer = Conv2d(in_channels=C, out_channels=2 * C, kernel_size=5, bias=bias).cuda()
 
     with torch.no_grad():
         # parallel FW pass
-        Y_local, _ = layer(X_local)
-        Y_parallel = _gather(Y_local.clone(), 1, outer_group)
+        Y_local = layer(X_local, scatter_input=easy_tp, gather_output=easy_tp)
+        if not easy_tp:
+            Y_parallel = _gather(Y_local.clone(), 1, outer_group)
+            Y_parallel = _gather(Y_parallel.clone(), 0, depth_group)
+        else:
+            Y_parallel = Y_local
 
         # sequential FW pass
         layer_sequential = torch.nn.Conv2d(
             in_channels=C,
             out_channels=C * 2,
             kernel_size=5,
-            bias=False,
+            bias=bias,
         ).cuda()
         weight_sequential = _gather(
-            _gather(layer.conv.weight, 1, inner_group), 0, outer_group
+            _gather(
+                _gather(layer.weight, 0, depth_group).reshape(
+                    layer.local_out_channels,
+                    layer.local_in_channels,
+                    layer.kernel_size,
+                    layer.kernel_size,
+                ),
+                1,
+                inner_group,
+            ),
+            0,
+            outer_group,
         )
         layer_sequential.weight.copy_(weight_sequential)
+        if bias:
+            layer_sequential.bias.zero_()
         Y_sequential = layer_sequential(X)
 
     assert torch.allclose(Y_sequential, Y_parallel), "FW Pass - output does not match"
@@ -60,11 +118,20 @@
 
 @pytest.mark.mpi
 @pytest.mark.parametrize("H, W, C", [(64, 64, 4), (64, 64, 8), (64, 32, 8)])
-@pytest.mark.parametrize("G_intra_r, G_intra_c", [(1, 2), (2, 1)])
-def test_bw_pass(G_intra_r, G_intra_c, H, W, C):
+@pytest.mark.parametrize("B", [2, 4, 16])
+@pytest.mark.parametrize(
+    "G_intra_r, G_intra_c, G_intra_d", [(1, 2, 1), (2, 1, 1), (1, 1, 2)]
+)
+@pytest.mark.parametrize("easy_tp", [True, False])
+@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("comm_opt_level", [0, 3])
+def test_bw_pass(
+    G_intra_r, G_intra_c, G_intra_d, B, H, W, C, easy_tp, bias, comm_opt_level
+):
     # These tests are in fp-32
     # Need to remove all non-determinism from convolutions
     torch.manual_seed(42)
+    torch.cuda.manual_seed(42)
     torch.use_deterministic_algorithms(True)
     torch.backends.cudnn.benchmark = False
     torch.backends.cudnn.deterministic = True
@@ -76,50 +143,109 @@
         G_inter=1,
         G_intra_r=G_intra_r,
         G_intra_c=G_intra_c,
+        G_intra_d=G_intra_d,
     )
-    X = torch.randn(1, C, H, W).cuda() * 0.01
-    Y_grad = torch.randn(1, 2 * C, H - 4, W - 4).cuda() * 0.01
+    X = torch.randn(B, C, H, W).cuda() * 0.01
+    Y_grad = torch.randn(B, 2 * C, H - 4, W - 4).cuda() * 0.01
 
     inner_group = ax.comm_handle.inner_intra_layer_parallel_group
     outer_group = ax.comm_handle.outer_intra_layer_parallel_group
+    depth_group = ax.comm_handle.depth_intra_layer_parallel_group
 
     # parallel backward pass
-    layer = Tensor_Parallel_Conv2d(
-        in_channels=C, out_channels=2 * C, kernel_size=5, skip_bias_add=True
-    ).cuda()
-    X_local = (
-        _drop(X, 1, inner_group).detach().clone()
-    )  # divide input channels of X along the inner tensor group
+    layer = Conv2d(in_channels=C, out_channels=2 * C, kernel_size=5, bias=bias).cuda()
+
+    if not easy_tp:
+        X_local = (
+            _drop(X, 1, inner_group).detach().clone()
+        )  # divide input channels of X along the inner tensor group
+        X_local = (
+            _drop(X_local, 0, depth_group).detach().clone()
+        )  # divide the batch dimension of X along the depth tensor group
+    else:
+        X_local = X
+
     X_local.requires_grad = True
-    Y_local, _ = layer(X_local)
-    Y_local_grad = _drop(Y_grad, 1, outer_group)
-    Y_local.backward(Y_local_grad)
+    if not easy_tp:
+        Y_local_grad = _drop(Y_grad, 1, outer_group).detach().clone()
+        Y_local_grad = _drop(Y_local_grad, 0, depth_group).detach().clone()
+    else:
+        Y_local_grad = Y_grad
+
+    with optimize_communication(
+        overlap_reduce_scatter=comm_opt_level >= 1,
+        cache_weights=comm_opt_level >= 2,
+        overlap_all_gather=comm_opt_level == 3,
+        model_object_for_overlapping_allgathers=layer,
+    ):
+        Y_local = layer(X_local, scatter_input=easy_tp, gather_output=easy_tp)
+        Y_local.backward(Y_local_grad)
+
+    if not easy_tp:
+        sync_gradients(layer)
+    if comm_opt_level >= 3:
+        clear_weights_cache()
 
     # sequential backward pass
     layer_sequential = torch.nn.Conv2d(
         in_channels=C,
         out_channels=C * 2,
         kernel_size=5,
-        bias=False,
+        bias=bias,
     ).cuda()
     with torch.no_grad():
         weight_sequential = _gather(
-            _gather(layer.conv.weight, 1, inner_group), 0, outer_group
+            _gather(
+                _gather(layer.weight, 0, depth_group).reshape(
+                    layer.local_out_channels,
+                    layer.local_in_channels,
+                    layer.kernel_size,
+                    layer.kernel_size,
+                ),
+                1,
+                inner_group,
+            ),
+            0,
+            outer_group,
         )
         layer_sequential.weight.copy_(weight_sequential)
+        if bias:
+            layer_sequential.bias.zero_()
     X.requires_grad = True
     Y_sequential = layer_sequential(X)
     Y_sequential.backward(Y_grad)
 
-    X_grad_parallel = _gather(X_local.grad, 1, inner_group)
+    if not easy_tp:
+        X_grad_parallel = _gather(X_local.grad, 0, depth_group)
+        X_grad_parallel = _gather(X_grad_parallel, 1, inner_group)
+    else:
+        X_grad_parallel = X_local.grad
 
-    assert torch.allclose(
+    assert norm_allclose(
         X_grad_parallel, X.grad
     ), "BW Pass - gradients of input do not match"
 
     weight_grad_parallel = _gather(
-        _gather(layer.conv.weight.grad, 1, inner_group), 0, outer_group
+        _gather(
+            _gather(layer.weight.grad, 0, depth_group).reshape(
+                layer.local_out_channels,
+                layer.local_in_channels,
+                layer.kernel_size,
+                layer.kernel_size,
+            ),
+            1,
+            inner_group,
+        ),
+        0,
+        outer_group,
     )
-    assert torch.allclose(
+
+    assert norm_allclose(
        weight_grad_parallel, layer_sequential.weight.grad
     ), "BW Pass - gradients of weight do not match"
+
+    if bias:
+        bias_grad_parallel = _gather(layer.bias.grad, 0, outer_group)
+        assert norm_allclose(
+            bias_grad_parallel, layer_sequential.bias.grad
+        ), "BW Pass - gradients of bias do not match"
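
Usage sketch (not part of the patch): the snippet below shows how the new tensor-parallel Conv2d might be driven in the "easy" mode the tests exercise with easy_tp=True, where the layer scatters its input and gathers its output internally. The launch command, the G_data=1 argument, and the tensor shapes are illustrative assumptions, not something this diff prescribes; only the Conv2d constructor, the scatter_input/gather_output defaults, optimize_communication, and clear_weights_cache come from the patch above.

# Hypothetical driver, e.g. `mpirun -np 2 python conv_example.py`; assumes a CUDA build
# of AxoNN and an MPI launch, as the pytest.mark.mpi tests do.
import torch
from axonn import axonn as ax
from axonn.intra_layer import Conv2d, clear_weights_cache, optimize_communication

# Two ranks split the channel dimension (G_intra_r x G_intra_c x G_intra_d = 2 x 1 x 1).
# G_data=1 is an assumption; only the G_intra_* arguments appear in this diff.
ax.init(G_inter=1, G_data=1, G_intra_r=2, G_intra_c=1, G_intra_d=1)

# "Easy" tensor parallelism: scatter_input/gather_output default to True, so the layer
# stands in for torch.nn.Conv2d and returns the full-sized output on every rank.
layer = Conv2d(in_channels=8, out_channels=16, kernel_size=5, bias=True).cuda()
X = torch.randn(4, 8, 64, 64).cuda()
Y = layer(X)
Y.sum().backward()

# Optional communication optimizations, mirroring comm_opt_level == 3 in test_bw_pass:
# overlap the depth-group reduce-scatter/all-gather of the weights and cache the
# all-gathered copies across iterations.
with optimize_communication(
    overlap_reduce_scatter=True,
    cache_weights=True,
    overlap_all_gather=True,
    model_object_for_overlapping_allgathers=layer,
):
    Y = layer(X)
    Y.sum().backward()
clear_weights_cache()  # release the cached all-gathered weights when done

In the non-easy path (scatter_input=False, gather_output=False), the caller drops and gathers activations itself with _drop/_gather and calls sync_gradients(layer) afterwards, as test_bw_pass does when easy_tp is False.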