Skip to content

Commit

Permalink
[GraphBolt][CUDA] Cooperative Minibatching hetero bug fixes.
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Oct 11, 2024
1 parent 4a6bfa4 commit bec6f17
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 19 deletions.
6 changes: 3 additions & 3 deletions python/dgl/graphbolt/feature_fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,13 +166,13 @@ def _cooperative_exchange(self, data):
self.node_feature_keys, Dict
) or isinstance(self.edge_feature_keys, Dict)
if is_heterogeneous:
node_features = {key: {} for key, _ in data.node_features.keys()}
for (key, ntype), feature in data.node_features.items():
node_features = {key: {} for _, key in data.node_features.keys()}
for (ntype, key), feature in data.node_features.items():
node_features[key][ntype] = feature
for key, feature in node_features.items():
new_feature = CooperativeConvFunction.apply(subgraph, feature)
for ntype, tensor in new_feature.items():
data.node_features[(key, ntype)] = tensor
data.node_features[(ntype, key)] = tensor
else:
for key in data.node_features:
feature = data.node_features[key]
Expand Down
11 changes: 6 additions & 5 deletions python/dgl/graphbolt/impl/cooperative_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,18 @@ def forward(
outs = {}
for ntype, typed_tensor in convert_to_hetero(tensor).items():
out = typed_tensor.new_empty(
(sum(counts_sent[ntype]),) + typed_tensor.shape[1:]
(sum(counts_sent.get(ntype, [0])),) + typed_tensor.shape[1:],
requires_grad=typed_tensor.requires_grad,
)
all_to_all(
torch.split(out, counts_sent[ntype]),
torch.split(out, counts_sent.get(ntype, 0)),
torch.split(
typed_tensor[seed_inverse_ids[ntype]],
counts_received[ntype],
typed_tensor[seed_inverse_ids.get(ntype, slice(None))],
counts_received.get(ntype, 0),
),
)
outs[ntype] = out
return revert_to_homo(out)
return revert_to_homo(outs)

@staticmethod
def backward(
Expand Down
4 changes: 2 additions & 2 deletions python/dgl/graphbolt/impl/neighbor_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,7 @@ def _seeds_cooperative_exchange_1(minibatch):
seeds_offsets = {"_N": seeds_offsets}
num_ntypes = len(seeds_offsets)
counts_sent = torch.empty(world_size * num_ntypes, dtype=torch.int64)
for i, offsets in enumerate(seeds_offsets.values()):
for i, (_, offsets) in enumerate(sorted(seeds_offsets.items())):
counts_sent[
torch.arange(i, world_size * num_ntypes, num_ntypes)
] = offsets.diff()
Expand Down Expand Up @@ -589,7 +589,7 @@ def _seeds_cooperative_exchange_2(minibatch):
seeds_received = {}
counts_sent = {}
counts_received = {}
for i, (ntype, typed_seeds) in enumerate(seeds.items()):
for i, (ntype, typed_seeds) in enumerate(sorted(seeds.items())):
idx = torch.arange(i, world_size * num_ntypes, num_ntypes)
typed_counts_sent = subgraph._counts_sent[idx].tolist()
typed_counts_received = subgraph._counts_received[idx].tolist()
Expand Down
6 changes: 4 additions & 2 deletions python/dgl/graphbolt/subgraph_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,9 @@ def _seeds_cooperative_exchange_1_wait_future(minibatch):
else:
minibatch._seeds_offsets = {"_N": minibatch._seeds_offsets}
counts_sent = torch.empty(world_size * num_ntypes, dtype=torch.int64)
for i, offsets in enumerate(minibatch._seeds_offsets.values()):
for i, (_, offsets) in enumerate(
sorted(minibatch._seeds_offsets.items())
):
counts_sent[
torch.arange(i, world_size * num_ntypes, num_ntypes)
] = offsets.diff()
Expand All @@ -261,7 +263,7 @@ def _seeds_cooperative_exchange_2(minibatch):
seeds_received = {}
counts_sent = {}
counts_received = {}
for i, (ntype, typed_seeds) in enumerate(seeds.items()):
for i, (ntype, typed_seeds) in enumerate(sorted(seeds.items())):
idx = torch.arange(i, world_size * num_ntypes, num_ntypes)
typed_counts_sent = minibatch._counts_sent[idx].tolist()
typed_counts_received = minibatch._counts_received[idx].tolist()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,12 @@ def test_rank_sort_and_unique_and_compact(dtype, rank):
nodes1[off1[j] : off1[j + 1]], nodes4[off4[i] : off4[i + 1]]
)

unique, compacted, offsets = gb.unique_and_compact(
nodes_list1[:1], rank, WORLD_SIZE
)
nodes = {str(i): [typed_seeds] for i, typed_seeds in enumerate(nodes_list1)}

nodes1, idx1, offsets1 = res1[0]
unique, compacted, offsets = gb.unique_and_compact(nodes, rank, WORLD_SIZE)

assert_equal(unique, nodes1)
assert_equal(compacted[0], idx1)
assert_equal(offsets, offsets1)
for i in nodes.keys():
nodes1, idx1, offsets1 = res1[int(i)]
assert_equal(unique[i], nodes1)
assert_equal(compacted[i][0], idx1)
assert_equal(offsets[i], offsets1)

0 comments on commit bec6f17

Please sign in to comment.