From bec6f17b97a8e29c2bae958fe34f01b8866079e0 Mon Sep 17 00:00:00 2001
From: Muhammed Fatih BALIN
Date: Fri, 11 Oct 2024 18:21:24 +0000
Subject: [PATCH 1/5] [GraphBolt][CUDA] Cooperative Minibatching hetero bug
 fixes.

---
 python/dgl/graphbolt/feature_fetcher.py          |  6 +++---
 python/dgl/graphbolt/impl/cooperative_conv.py    | 11 ++++++-----
 python/dgl/graphbolt/impl/neighbor_sampler.py    |  4 ++--
 python/dgl/graphbolt/subgraph_sampler.py         |  6 ++++--
 .../impl/test_cooperative_minibatching_utils.py  | 14 +++++++-------
 5 files changed, 22 insertions(+), 19 deletions(-)

diff --git a/python/dgl/graphbolt/feature_fetcher.py b/python/dgl/graphbolt/feature_fetcher.py
index cf9d5f4104c2..e3db4df38bd2 100644
--- a/python/dgl/graphbolt/feature_fetcher.py
+++ b/python/dgl/graphbolt/feature_fetcher.py
@@ -166,13 +166,13 @@ def _cooperative_exchange(self, data):
             self.node_feature_keys, Dict
         ) or isinstance(self.edge_feature_keys, Dict)
         if is_heterogeneous:
-            node_features = {key: {} for key, _ in data.node_features.keys()}
-            for (key, ntype), feature in data.node_features.items():
+            node_features = {key: {} for _, key in data.node_features.keys()}
+            for (ntype, key), feature in data.node_features.items():
                 node_features[key][ntype] = feature
             for key, feature in node_features.items():
                 new_feature = CooperativeConvFunction.apply(subgraph, feature)
                 for ntype, tensor in new_feature.items():
-                    data.node_features[(key, ntype)] = tensor
+                    data.node_features[(ntype, key)] = tensor
         else:
             for key in data.node_features:
                 feature = data.node_features[key]
diff --git a/python/dgl/graphbolt/impl/cooperative_conv.py b/python/dgl/graphbolt/impl/cooperative_conv.py
index 22c5ae316c71..0daa106c17a6 100644
--- a/python/dgl/graphbolt/impl/cooperative_conv.py
+++ b/python/dgl/graphbolt/impl/cooperative_conv.py
@@ -44,17 +44,18 @@ def forward(
         outs = {}
         for ntype, typed_tensor in convert_to_hetero(tensor).items():
             out = typed_tensor.new_empty(
-                (sum(counts_sent[ntype]),) + typed_tensor.shape[1:]
+                (sum(counts_sent.get(ntype, [0])),) + typed_tensor.shape[1:],
+                requires_grad=typed_tensor.requires_grad,
             )
             all_to_all(
-                torch.split(out, counts_sent[ntype]),
+                torch.split(out, counts_sent.get(ntype, 0)),
                 torch.split(
-                    typed_tensor[seed_inverse_ids[ntype]],
-                    counts_received[ntype],
+                    typed_tensor[seed_inverse_ids.get(ntype, slice(None))],
+                    counts_received.get(ntype, 0),
                 ),
             )
             outs[ntype] = out
-        return revert_to_homo(out)
+        return revert_to_homo(outs)

     @staticmethod
     def backward(
diff --git a/python/dgl/graphbolt/impl/neighbor_sampler.py b/python/dgl/graphbolt/impl/neighbor_sampler.py
index 7ddba6d7ccac..059fca7ef4c0 100644
--- a/python/dgl/graphbolt/impl/neighbor_sampler.py
+++ b/python/dgl/graphbolt/impl/neighbor_sampler.py
@@ -561,7 +561,7 @@ def _seeds_cooperative_exchange_1(minibatch):
             seeds_offsets = {"_N": seeds_offsets}
         num_ntypes = len(seeds_offsets)
         counts_sent = torch.empty(world_size * num_ntypes, dtype=torch.int64)
-        for i, offsets in enumerate(seeds_offsets.values()):
+        for i, (_, offsets) in enumerate(sorted(seeds_offsets.items())):
             counts_sent[
                 torch.arange(i, world_size * num_ntypes, num_ntypes)
             ] = offsets.diff()
@@ -589,7 +589,7 @@ def _seeds_cooperative_exchange_2(minibatch):
         seeds_received = {}
         counts_sent = {}
         counts_received = {}
-        for i, (ntype, typed_seeds) in enumerate(seeds.items()):
+        for i, (ntype, typed_seeds) in enumerate(sorted(seeds.items())):
             idx = torch.arange(i, world_size * num_ntypes, num_ntypes)
             typed_counts_sent = subgraph._counts_sent[idx].tolist()
             typed_counts_received = subgraph._counts_received[idx].tolist()
diff --git a/python/dgl/graphbolt/subgraph_sampler.py b/python/dgl/graphbolt/subgraph_sampler.py
index 88fc9c124de5..85029bb627e0 100644
--- a/python/dgl/graphbolt/subgraph_sampler.py
+++ b/python/dgl/graphbolt/subgraph_sampler.py
@@ -236,7 +236,9 @@ def _seeds_cooperative_exchange_1_wait_future(minibatch):
         else:
             minibatch._seeds_offsets = {"_N": minibatch._seeds_offsets}
         counts_sent = torch.empty(world_size * num_ntypes, dtype=torch.int64)
-        for i, offsets in enumerate(minibatch._seeds_offsets.values()):
+        for i, (_, offsets) in enumerate(
+            sorted(minibatch._seeds_offsets.items())
+        ):
             counts_sent[
                 torch.arange(i, world_size * num_ntypes, num_ntypes)
             ] = offsets.diff()
@@ -261,7 +263,7 @@ def _seeds_cooperative_exchange_2(minibatch):
         seeds_received = {}
         counts_sent = {}
         counts_received = {}
-        for i, (ntype, typed_seeds) in enumerate(seeds.items()):
+        for i, (ntype, typed_seeds) in enumerate(sorted(seeds.items())):
             idx = torch.arange(i, world_size * num_ntypes, num_ntypes)
             typed_counts_sent = minibatch._counts_sent[idx].tolist()
             typed_counts_received = minibatch._counts_received[idx].tolist()
diff --git a/tests/python/pytorch/graphbolt/impl/test_cooperative_minibatching_utils.py b/tests/python/pytorch/graphbolt/impl/test_cooperative_minibatching_utils.py
index f88e011f4385..a2895191dd0e 100644
--- a/tests/python/pytorch/graphbolt/impl/test_cooperative_minibatching_utils.py
+++ b/tests/python/pytorch/graphbolt/impl/test_cooperative_minibatching_utils.py
@@ -57,12 +57,12 @@ def test_rank_sort_and_unique_and_compact(dtype, rank):
             nodes1[off1[j] : off1[j + 1]], nodes4[off4[i] : off4[i + 1]]
         )

-    unique, compacted, offsets = gb.unique_and_compact(
-        nodes_list1[:1], rank, WORLD_SIZE
-    )
+    nodes = {str(i): [typed_seeds] for i, typed_seeds in enumerate(nodes_list1)}

-    nodes1, idx1, offsets1 = res1[0]
+    unique, compacted, offsets = gb.unique_and_compact(nodes, rank, WORLD_SIZE)

-    assert_equal(unique, nodes1)
-    assert_equal(compacted[0], idx1)
-    assert_equal(offsets, offsets1)
+    for i in nodes.keys():
+        nodes1, idx1, offsets1 = res1[int(i)]
+        assert_equal(unique[i], nodes1)
+        assert_equal(compacted[i][0], idx1)
+        assert_equal(offsets[i], offsets1)

From 01b77dcb6ad30a818f4c4645200cd600e1e591c2 Mon Sep 17 00:00:00 2001
From: Muhammed Fatih BALIN
Date: Fri, 11 Oct 2024 19:00:19 +0000
Subject: [PATCH 2/5] more fixes.

---
 python/dgl/graphbolt/feature_fetcher.py       | 4 ++--
 python/dgl/graphbolt/impl/cooperative_conv.py | 6 ++++--
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/python/dgl/graphbolt/feature_fetcher.py b/python/dgl/graphbolt/feature_fetcher.py
index e3db4df38bd2..832d00181588 100644
--- a/python/dgl/graphbolt/feature_fetcher.py
+++ b/python/dgl/graphbolt/feature_fetcher.py
@@ -169,12 +169,12 @@ def _cooperative_exchange(self, data):
             node_features = {key: {} for _, key in data.node_features.keys()}
             for (ntype, key), feature in data.node_features.items():
                 node_features[key][ntype] = feature
-            for key, feature in node_features.items():
+            for key, feature in sorted(node_features.items()):
                 new_feature = CooperativeConvFunction.apply(subgraph, feature)
                 for ntype, tensor in new_feature.items():
                     data.node_features[(ntype, key)] = tensor
         else:
-            for key in data.node_features:
+            for key in sorted(data.node_features):
                 feature = data.node_features[key]
                 new_feature = CooperativeConvFunction.apply(subgraph, feature)
                 data.node_features[key] = new_feature
diff --git a/python/dgl/graphbolt/impl/cooperative_conv.py b/python/dgl/graphbolt/impl/cooperative_conv.py
index 0daa106c17a6..c2d40d94d186 100644
--- a/python/dgl/graphbolt/impl/cooperative_conv.py
+++ b/python/dgl/graphbolt/impl/cooperative_conv.py
@@ -42,7 +42,7 @@ def forward(
             seed_sizes,
         )
         outs = {}
-        for ntype, typed_tensor in convert_to_hetero(tensor).items():
+        for ntype, typed_tensor in sorted(convert_to_hetero(tensor).items()):
             out = typed_tensor.new_empty(
                 (sum(counts_sent.get(ntype, [0])),) + typed_tensor.shape[1:],
                 requires_grad=typed_tensor.requires_grad,
@@ -70,7 +70,9 @@ def backward(
         ) = ctx.communication_variables
         delattr(ctx, "communication_variables")
         outs = {}
-        for ntype, typed_grad_output in convert_to_hetero(grad_output).items():
+        for ntype, typed_grad_output in sorted(
+            convert_to_hetero(grad_output).items()
+        ):
             out = typed_grad_output.new_empty(
                 (sum(counts_received[ntype]),) + typed_grad_output.shape[1:]
             )

From 20a35a7f451ec9d63a3f120352d8c17f94c7cf5e Mon Sep 17 00:00:00 2001
From: Muhammed Fatih BALIN
Date: Fri, 11 Oct 2024 19:31:53 +0000
Subject: [PATCH 3/5] fix the last bug hopefully.

---
 python/dgl/graphbolt/impl/cooperative_conv.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/python/dgl/graphbolt/impl/cooperative_conv.py b/python/dgl/graphbolt/impl/cooperative_conv.py
index c2d40d94d186..133f8d0c9835 100644
--- a/python/dgl/graphbolt/impl/cooperative_conv.py
+++ b/python/dgl/graphbolt/impl/cooperative_conv.py
@@ -47,11 +47,12 @@ def forward(
                 (sum(counts_sent.get(ntype, [0])),) + typed_tensor.shape[1:],
                 requires_grad=typed_tensor.requires_grad,
             )
+            default_splits = [0] * torch.distributed.get_world_size()
             all_to_all(
-                torch.split(out, counts_sent.get(ntype, 0)),
+                torch.split(out, counts_sent.get(ntype, default_splits)),
                 torch.split(
                     typed_tensor[seed_inverse_ids.get(ntype, slice(None))],
-                    counts_received.get(ntype, 0),
+                    counts_received.get(ntype, default_splits),
                 ),
             )
             outs[ntype] = out

From c3e8e89efdbe491837694995876693ab14621b97 Mon Sep 17 00:00:00 2001
From: Muhammed Fatih BALIN
Date: Fri, 11 Oct 2024 20:30:28 +0000
Subject: [PATCH 4/5] extend test coverage.

---
 .../graphbolt/impl/test_cooperative_minibatching_utils.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/tests/python/pytorch/graphbolt/impl/test_cooperative_minibatching_utils.py b/tests/python/pytorch/graphbolt/impl/test_cooperative_minibatching_utils.py
index a2895191dd0e..4f2d3d6d04b9 100644
--- a/tests/python/pytorch/graphbolt/impl/test_cooperative_minibatching_utils.py
+++ b/tests/python/pytorch/graphbolt/impl/test_cooperative_minibatching_utils.py
@@ -37,13 +37,14 @@ def test_rank_sort_and_unique_and_compact(dtype, rank):
         assert_equal(offsets1, offsets2)
         assert offsets1.is_pinned() and offsets2.is_pinned()

-    res3 = torch.ops.graphbolt.rank_sort(nodes_list1, rank, WORLD_SIZE)
+    # Test with the reverse order of ntypes. See if results are equivalent.
+    res3 = torch.ops.graphbolt.rank_sort(nodes_list1[::-1], rank, WORLD_SIZE)

     # This function is deterministic. Call with identical arguments and check.
-    for (nodes1, idx1, offsets1), (nodes3, idx3, offsets3) in zip(res1, res3):
+    for (nodes1, idx1, offsets1), (nodes3, idx3, offsets3) in zip(res1, reversed(res3)):
         assert_equal(nodes1, nodes3)
         assert_equal(idx1, idx3)
-        assert_equal(offsets1, offsets3)
+        assert_equal(offsets1.diff(), offsets3.diff())

     # The dependency on the rank argument is simply a permutation.
     res4 = torch.ops.graphbolt.rank_sort(nodes_list1, 0, WORLD_SIZE)

From 61ca2797848f5d146fc1d1320f9b25e96bb46551 Mon Sep 17 00:00:00 2001
From: Muhammed Fatih BALIN
Date: Fri, 11 Oct 2024 22:29:04 +0000
Subject: [PATCH 5/5] linting

---
 .../graphbolt/impl/test_cooperative_minibatching_utils.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/tests/python/pytorch/graphbolt/impl/test_cooperative_minibatching_utils.py b/tests/python/pytorch/graphbolt/impl/test_cooperative_minibatching_utils.py
index 4f2d3d6d04b9..3eb5bd591752 100644
--- a/tests/python/pytorch/graphbolt/impl/test_cooperative_minibatching_utils.py
+++ b/tests/python/pytorch/graphbolt/impl/test_cooperative_minibatching_utils.py
@@ -41,7 +41,9 @@ def test_rank_sort_and_unique_and_compact(dtype, rank):
     res3 = torch.ops.graphbolt.rank_sort(nodes_list1[::-1], rank, WORLD_SIZE)

     # This function is deterministic. Call with identical arguments and check.
-    for (nodes1, idx1, offsets1), (nodes3, idx3, offsets3) in zip(res1, reversed(res3)):
+    for (nodes1, idx1, offsets1), (nodes3, idx3, offsets3) in zip(
+        res1, reversed(res3)
+    ):
         assert_equal(nodes1, nodes3)
         assert_equal(idx1, idx3)
         assert_equal(offsets1.diff(), offsets3.diff())
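
Note (not part of the patches above): the recurring sorted(...) changes make every rank walk its node-type dictionaries in the same order before the collective all_to_all exchanges, so the k-th per-type split on each rank refers to the same node type. The snippet below is a minimal standalone sketch of that alignment requirement only; the "user"/"item" type names and toy buffers are hypothetical and are not taken from GraphBolt.

# Hypothetical two-rank example; type names and values are illustrative only.
rank0_feats = {"user": [1, 2], "item": [3]}     # built in order: user, item
rank1_feats = {"item": [40], "user": [50, 60]}  # built in order: item, user

# Pairing raw .items() aligns buffers by insertion order, so rank 0's "user"
# split would be exchanged against rank 1's "item" split.
naive = list(zip(rank0_feats.items(), rank1_feats.items()))
assert naive[0][0][0] != naive[0][1][0]  # misaligned: "user" vs "item"

# Sorting by node type gives all ranks the same traversal order, so the k-th
# exchanged split refers to the same node type everywhere.
aligned = list(zip(sorted(rank0_feats.items()), sorted(rank1_feats.items())))
assert all(a[0] == b[0] for a, b in aligned)

The same reasoning applies to the per-type send/receive counts indexed by torch.arange(i, world_size * num_ntypes, num_ntypes) in the samplers: that slot layout is only consistent across ranks if i enumerates node types in an agreed order.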