From 1c5dac155a163a38b4f650f4b6deb502d911a933 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih BALIN Date: Thu, 19 Sep 2024 20:06:08 +0000 Subject: [PATCH 1/3] [GraphBolt][CUDA] `rank_sort_async` for Cooperative Minibatching. --- .../cuda/cooperative_minibatching_utils.cu | 11 ++++++ .../src/cuda/cooperative_minibatching_utils.h | 7 ++++ graphbolt/src/python_binding.cc | 8 +++++ python/dgl/graphbolt/subgraph_sampler.py | 35 +++++++++++++------ 4 files changed, 51 insertions(+), 10 deletions(-) diff --git a/graphbolt/src/cuda/cooperative_minibatching_utils.cu b/graphbolt/src/cuda/cooperative_minibatching_utils.cu index e192de458650..583e58629449 100644 --- a/graphbolt/src/cuda/cooperative_minibatching_utils.cu +++ b/graphbolt/src/cuda/cooperative_minibatching_utils.cu @@ -25,6 +25,7 @@ #include #include +#include "../utils.h" #include "./common.h" #include "./cooperative_minibatching_utils.cuh" #include "./cooperative_minibatching_utils.h" @@ -144,5 +145,15 @@ std::vector> RankSort( return results; } +c10::intrusive_ptr>>> +RankSortAsync( + const std::vector& nodes_list, const int64_t rank, + const int64_t world_size) { + return async( + [=] { return RankSort(nodes_list, rank, world_size); }, + utils::is_on_gpu(nodes_list.at(0))); +} + } // namespace cuda } // namespace graphbolt diff --git a/graphbolt/src/cuda/cooperative_minibatching_utils.h b/graphbolt/src/cuda/cooperative_minibatching_utils.h index efe2b5b28bf1..c506c18d21c3 100644 --- a/graphbolt/src/cuda/cooperative_minibatching_utils.h +++ b/graphbolt/src/cuda/cooperative_minibatching_utils.h @@ -22,6 +22,7 @@ #define GRAPHBOLT_CUDA_COOPERATIVE_MINIBATCHING_UTILS_H_ #include +#include #include namespace graphbolt { @@ -83,6 +84,12 @@ std::vector> RankSort( const std::vector& nodes_list, int64_t rank, int64_t world_size); +c10::intrusive_ptr>>> +RankSortAsync( + const std::vector& nodes_list, const int64_t rank, + const int64_t world_size); + } // namespace cuda } // namespace graphbolt diff --git a/graphbolt/src/python_binding.cc b/graphbolt/src/python_binding.cc index 35ab345c56f9..ed5e6273e7ac 100644 --- a/graphbolt/src/python_binding.cc +++ b/graphbolt/src/python_binding.cc @@ -59,6 +59,13 @@ TORCH_LIBRARY(graphbolt, m) { &Future>>:: Wait); + m.class_>>>( + "RankSortFuture") + .def( + "wait", + &Future>>::Wait); m.class_>>( "GpuGraphCacheQueryFuture") .def( @@ -198,6 +205,7 @@ TORCH_LIBRARY(graphbolt, m) { #ifdef GRAPHBOLT_USE_CUDA m.def("set_max_uva_threads", &cuda::set_max_uva_threads); m.def("rank_sort", &cuda::RankSort); + m.def("rank_sort_async", &cuda::RankSortAsync); #endif #ifdef HAS_IMPL_ABSTRACT_PYSTUB m.impl_abstract_pystub("dgl.graphbolt.base", "//dgl.graphbolt.base"); diff --git a/python/dgl/graphbolt/subgraph_sampler.py b/python/dgl/graphbolt/subgraph_sampler.py index dd5093ae5f69..9ec0af4a2b7a 100644 --- a/python/dgl/graphbolt/subgraph_sampler.py +++ b/python/dgl/graphbolt/subgraph_sampler.py @@ -140,6 +140,9 @@ def __init__( if cooperative: datapipe = datapipe.transform(self._seeds_cooperative_exchange_1) datapipe = datapipe.buffer() + datapipe = datapipe.transform( + self._seeds_cooperative_exchange_1_wait_future + ).buffer() datapipe = datapipe.transform(self._seeds_cooperative_exchange_2) datapipe = datapipe.buffer() datapipe = datapipe.transform(self._seeds_cooperative_exchange_3) @@ -193,19 +196,33 @@ def _wait_preprocess_future(minibatch, cooperative: bool): return minibatch @staticmethod - def _seeds_cooperative_exchange_1(minibatch, group=None): - rank = thd.get_rank(group) - world_size = thd.get_world_size(group) + def _seeds_cooperative_exchange_1(minibatch): + rank = thd.get_rank() + world_size = thd.get_world_size() seeds = minibatch._seed_nodes is_homogeneous = not isinstance(seeds, dict) if is_homogeneous: seeds = {"_N": seeds} if minibatch._seeds_offsets is None: - seeds_list = list(seeds.values()) - result = torch.ops.graphbolt.rank_sort(seeds_list, rank, world_size) assert minibatch.compacted_seeds is None + seeds_list = list(seeds.values()) + minibatch._rank_sort_future = torch.ops.graphbolt.rank_sort_async( + seeds_list, rank, world_size + ) + return minibatch + + @staticmethod + def _seeds_cooperative_exchange_1_wait_future(minibatch): + world_size = thd.get_world_size() + seeds = minibatch._seed_nodes + is_homogeneous = not isinstance(seeds, dict) + if is_homogeneous: + seeds = {"_N": seeds} + num_ntypes = len(seeds.keys()) + if minibatch._seeds_offsets is None: + result = minibatch._rank_sort_future.wait() + delattr(minibatch, "_rank_sort_future") sorted_seeds, sorted_compacted, sorted_offsets = {}, {}, {} - num_ntypes = len(seeds.keys()) for i, ( seed_type, (typed_sorted_seeds, typed_index, typed_offsets), @@ -229,7 +246,6 @@ def _seeds_cooperative_exchange_1(minibatch, group=None): minibatch._counts_future = all_to_all( counts_received.split(num_ntypes), counts_sent.split(num_ntypes), - group=group, async_op=True, ) minibatch._counts_sent = counts_sent @@ -237,8 +253,8 @@ def _seeds_cooperative_exchange_1(minibatch, group=None): return minibatch @staticmethod - def _seeds_cooperative_exchange_2(minibatch, group=None): - world_size = thd.get_world_size(group) + def _seeds_cooperative_exchange_2(minibatch): + world_size = thd.get_world_size() seeds = minibatch._seed_nodes minibatch._counts_future.wait() delattr(minibatch, "_counts_future") @@ -256,7 +272,6 @@ def _seeds_cooperative_exchange_2(minibatch, group=None): all_to_all( typed_seeds_received.split(typed_counts_received), typed_seeds.split(typed_counts_sent), - group, ) seeds_received[ntype] = typed_seeds_received counts_sent[ntype] = typed_counts_sent From 50cdf78b9c8fa0137c56e7f463bc5d14f13ee668 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih BALIN Date: Thu, 19 Sep 2024 20:14:25 +0000 Subject: [PATCH 2/3] fix the test. --- python/dgl/graphbolt/subgraph_sampler.py | 3 +-- tests/python/pytorch/graphbolt/test_dataloader.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/python/dgl/graphbolt/subgraph_sampler.py b/python/dgl/graphbolt/subgraph_sampler.py index 9ec0af4a2b7a..88fc9c124de5 100644 --- a/python/dgl/graphbolt/subgraph_sampler.py +++ b/python/dgl/graphbolt/subgraph_sampler.py @@ -205,9 +205,8 @@ def _seeds_cooperative_exchange_1(minibatch): seeds = {"_N": seeds} if minibatch._seeds_offsets is None: assert minibatch.compacted_seeds is None - seeds_list = list(seeds.values()) minibatch._rank_sort_future = torch.ops.graphbolt.rank_sort_async( - seeds_list, rank, world_size + list(seeds.values()), rank, world_size ) return minibatch diff --git a/tests/python/pytorch/graphbolt/test_dataloader.py b/tests/python/pytorch/graphbolt/test_dataloader.py index ee8f2b0cb9f5..e2f2664c6acb 100644 --- a/tests/python/pytorch/graphbolt/test_dataloader.py +++ b/tests/python/pytorch/graphbolt/test_dataloader.py @@ -159,7 +159,7 @@ def test_gpu_sampling_DataLoader( if asynchronous: bufferer_cnt += 2 * num_layers + 1 # _preprocess stage has 1. if cooperative: - bufferer_cnt += 3 * num_layers + bufferer_cnt += 3 * num_layers + 1 if enable_feature_fetch: bufferer_cnt += 1 # feature fetch has 1. if cooperative: From 38e8dd382d68ad56c7b51ae82411672464181094 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih BALIN Date: Thu, 19 Sep 2024 20:23:41 +0000 Subject: [PATCH 3/3] fix the test. --- tests/python/pytorch/graphbolt/test_dataloader.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/pytorch/graphbolt/test_dataloader.py b/tests/python/pytorch/graphbolt/test_dataloader.py index e2f2664c6acb..5843264516fc 100644 --- a/tests/python/pytorch/graphbolt/test_dataloader.py +++ b/tests/python/pytorch/graphbolt/test_dataloader.py @@ -159,12 +159,12 @@ def test_gpu_sampling_DataLoader( if asynchronous: bufferer_cnt += 2 * num_layers + 1 # _preprocess stage has 1. if cooperative: - bufferer_cnt += 3 * num_layers + 1 + bufferer_cnt += 3 * num_layers if enable_feature_fetch: bufferer_cnt += 1 # feature fetch has 1. if cooperative: - # _preprocess stage and each sampling layer. - bufferer_cnt += 3 + # _preprocess stage. + bufferer_cnt += 4 datapipe_graph = traverse_dps(dataloader) bufferers = find_dps( datapipe_graph,