Skip to content

Commit

Permalink
fix the last bug hopefully.
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Oct 11, 2024
1 parent 01b77dc commit 20a35a7
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions python/dgl/graphbolt/impl/cooperative_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,12 @@ def forward(
(sum(counts_sent.get(ntype, [0])),) + typed_tensor.shape[1:],
requires_grad=typed_tensor.requires_grad,
)
default_splits = [0] * torch.distributed.get_world_size()
all_to_all(
torch.split(out, counts_sent.get(ntype, 0)),
torch.split(out, counts_sent.get(ntype, default_splits)),
torch.split(
typed_tensor[seed_inverse_ids.get(ntype, slice(None))],
counts_received.get(ntype, 0),
counts_received.get(ntype, default_splits),
),
)
outs[ntype] = out
Expand Down

0 comments on commit 20a35a7

Please sign in to comment.