From bec6f17b97a8e29c2bae958fe34f01b8866079e0 Mon Sep 17 00:00:00 2001
From: Muhammed Fatih BALIN
Date: Fri, 11 Oct 2024 18:21:24 +0000
Subject: [PATCH] [GraphBolt][CUDA] Cooperative Minibatching hetero bug fixes.

---
 python/dgl/graphbolt/feature_fetcher.py          |  6 +++---
 python/dgl/graphbolt/impl/cooperative_conv.py    | 11 ++++++-----
 python/dgl/graphbolt/impl/neighbor_sampler.py    |  4 ++--
 python/dgl/graphbolt/subgraph_sampler.py         |  6 ++++--
 .../impl/test_cooperative_minibatching_utils.py  | 14 +++++++-------
 5 files changed, 22 insertions(+), 19 deletions(-)

diff --git a/python/dgl/graphbolt/feature_fetcher.py b/python/dgl/graphbolt/feature_fetcher.py
index cf9d5f4104c2..e3db4df38bd2 100644
--- a/python/dgl/graphbolt/feature_fetcher.py
+++ b/python/dgl/graphbolt/feature_fetcher.py
@@ -166,13 +166,13 @@ def _cooperative_exchange(self, data):
             self.node_feature_keys, Dict
         ) or isinstance(self.edge_feature_keys, Dict)
         if is_heterogeneous:
-            node_features = {key: {} for key, _ in data.node_features.keys()}
-            for (key, ntype), feature in data.node_features.items():
+            node_features = {key: {} for _, key in data.node_features.keys()}
+            for (ntype, key), feature in data.node_features.items():
                 node_features[key][ntype] = feature
             for key, feature in node_features.items():
                 new_feature = CooperativeConvFunction.apply(subgraph, feature)
                 for ntype, tensor in new_feature.items():
-                    data.node_features[(key, ntype)] = tensor
+                    data.node_features[(ntype, key)] = tensor
         else:
             for key in data.node_features:
                 feature = data.node_features[key]
diff --git a/python/dgl/graphbolt/impl/cooperative_conv.py b/python/dgl/graphbolt/impl/cooperative_conv.py
index 22c5ae316c71..0daa106c17a6 100644
--- a/python/dgl/graphbolt/impl/cooperative_conv.py
+++ b/python/dgl/graphbolt/impl/cooperative_conv.py
@@ -44,17 +44,18 @@ def forward(
         outs = {}
         for ntype, typed_tensor in convert_to_hetero(tensor).items():
             out = typed_tensor.new_empty(
-                (sum(counts_sent[ntype]),) + typed_tensor.shape[1:]
+                (sum(counts_sent.get(ntype, [0])),) + typed_tensor.shape[1:],
+                requires_grad=typed_tensor.requires_grad,
             )
             all_to_all(
-                torch.split(out, counts_sent[ntype]),
+                torch.split(out, counts_sent.get(ntype, 0)),
                 torch.split(
-                    typed_tensor[seed_inverse_ids[ntype]],
-                    counts_received[ntype],
+                    typed_tensor[seed_inverse_ids.get(ntype, slice(None))],
+                    counts_received.get(ntype, 0),
                 ),
             )
             outs[ntype] = out
-        return revert_to_homo(out)
+        return revert_to_homo(outs)
 
     @staticmethod
     def backward(
diff --git a/python/dgl/graphbolt/impl/neighbor_sampler.py b/python/dgl/graphbolt/impl/neighbor_sampler.py
index 7ddba6d7ccac..059fca7ef4c0 100644
--- a/python/dgl/graphbolt/impl/neighbor_sampler.py
+++ b/python/dgl/graphbolt/impl/neighbor_sampler.py
@@ -561,7 +561,7 @@ def _seeds_cooperative_exchange_1(minibatch):
             seeds_offsets = {"_N": seeds_offsets}
         num_ntypes = len(seeds_offsets)
         counts_sent = torch.empty(world_size * num_ntypes, dtype=torch.int64)
-        for i, offsets in enumerate(seeds_offsets.values()):
+        for i, (_, offsets) in enumerate(sorted(seeds_offsets.items())):
             counts_sent[
                 torch.arange(i, world_size * num_ntypes, num_ntypes)
             ] = offsets.diff()
@@ -589,7 +589,7 @@ def _seeds_cooperative_exchange_2(minibatch):
         seeds_received = {}
         counts_sent = {}
         counts_received = {}
-        for i, (ntype, typed_seeds) in enumerate(seeds.items()):
+        for i, (ntype, typed_seeds) in enumerate(sorted(seeds.items())):
             idx = torch.arange(i, world_size * num_ntypes, num_ntypes)
             typed_counts_sent = subgraph._counts_sent[idx].tolist()
             typed_counts_received = subgraph._counts_received[idx].tolist()
diff --git a/python/dgl/graphbolt/subgraph_sampler.py b/python/dgl/graphbolt/subgraph_sampler.py
index 88fc9c124de5..85029bb627e0 100644
--- a/python/dgl/graphbolt/subgraph_sampler.py
+++ b/python/dgl/graphbolt/subgraph_sampler.py
@@ -236,7 +236,9 @@ def _seeds_cooperative_exchange_1_wait_future(minibatch):
         else:
             minibatch._seeds_offsets = {"_N": minibatch._seeds_offsets}
         counts_sent = torch.empty(world_size * num_ntypes, dtype=torch.int64)
-        for i, offsets in enumerate(minibatch._seeds_offsets.values()):
+        for i, (_, offsets) in enumerate(
+            sorted(minibatch._seeds_offsets.items())
+        ):
             counts_sent[
                 torch.arange(i, world_size * num_ntypes, num_ntypes)
             ] = offsets.diff()
@@ -261,7 +263,7 @@ def _seeds_cooperative_exchange_2(minibatch):
         seeds_received = {}
         counts_sent = {}
         counts_received = {}
-        for i, (ntype, typed_seeds) in enumerate(seeds.items()):
+        for i, (ntype, typed_seeds) in enumerate(sorted(seeds.items())):
             idx = torch.arange(i, world_size * num_ntypes, num_ntypes)
             typed_counts_sent = minibatch._counts_sent[idx].tolist()
             typed_counts_received = minibatch._counts_received[idx].tolist()
diff --git a/tests/python/pytorch/graphbolt/impl/test_cooperative_minibatching_utils.py b/tests/python/pytorch/graphbolt/impl/test_cooperative_minibatching_utils.py
index f88e011f4385..a2895191dd0e 100644
--- a/tests/python/pytorch/graphbolt/impl/test_cooperative_minibatching_utils.py
+++ b/tests/python/pytorch/graphbolt/impl/test_cooperative_minibatching_utils.py
@@ -57,12 +57,12 @@ def test_rank_sort_and_unique_and_compact(dtype, rank):
                 nodes1[off1[j] : off1[j + 1]], nodes4[off4[i] : off4[i + 1]]
             )
 
-    unique, compacted, offsets = gb.unique_and_compact(
-        nodes_list1[:1], rank, WORLD_SIZE
-    )
+    nodes = {str(i): [typed_seeds] for i, typed_seeds in enumerate(nodes_list1)}
 
-    nodes1, idx1, offsets1 = res1[0]
+    unique, compacted, offsets = gb.unique_and_compact(nodes, rank, WORLD_SIZE)
 
-    assert_equal(unique, nodes1)
-    assert_equal(compacted[0], idx1)
-    assert_equal(offsets, offsets1)
+    for i in nodes.keys():
+        nodes1, idx1, offsets1 = res1[int(i)]
+        assert_equal(unique[i], nodes1)
+        assert_equal(compacted[i][0], idx1)
+        assert_equal(offsets[i], offsets1)