diff --git a/python/dgl/graphbolt/impl/cooperative_conv.py b/python/dgl/graphbolt/impl/cooperative_conv.py index c2d40d94d186..133f8d0c9835 100644 --- a/python/dgl/graphbolt/impl/cooperative_conv.py +++ b/python/dgl/graphbolt/impl/cooperative_conv.py @@ -47,11 +47,12 @@ def forward( (sum(counts_sent.get(ntype, [0])),) + typed_tensor.shape[1:], requires_grad=typed_tensor.requires_grad, ) + default_splits = [0] * torch.distributed.get_world_size() all_to_all( - torch.split(out, counts_sent.get(ntype, 0)), + torch.split(out, counts_sent.get(ntype, default_splits)), torch.split( typed_tensor[seed_inverse_ids.get(ntype, slice(None))], - counts_received.get(ntype, 0), + counts_received.get(ntype, default_splits), ), ) outs[ntype] = out