
[GraphBolt][CUDA] rank_sort_async for Cooperative Minibatching. #7805

Merged (3 commits) on Sep 19, 2024. Changes shown are from 1 commit.
[GraphBolt][CUDA] rank_sort_async for Cooperative Minibatching.
mfbalin committed Sep 19, 2024

commit 1c5dac155a163a38b4f650f4b6deb502d911a933
11 changes: 11 additions & 0 deletions graphbolt/src/cuda/cooperative_minibatching_utils.cu
@@ -25,6 +25,7 @@
 #include <cub/cub.cuh>
 #include <cuda/functional>
 
+#include "../utils.h"
 #include "./common.h"
 #include "./cooperative_minibatching_utils.cuh"
 #include "./cooperative_minibatching_utils.h"
@@ -144,5 +145,15 @@ std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>> RankSort(
   return results;
 }
 
+c10::intrusive_ptr<Future<
+    std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>>>
+RankSortAsync(
+    const std::vector<torch::Tensor>& nodes_list, const int64_t rank,
+    const int64_t world_size) {
+  return async(
+      [=] { return RankSort(nodes_list, rank, world_size); },
+      utils::is_on_gpu(nodes_list.at(0)));
+}
+
 }  // namespace cuda
 }  // namespace graphbolt
7 changes: 7 additions & 0 deletions graphbolt/src/cuda/cooperative_minibatching_utils.h
@@ -22,6 +22,7 @@
 #define GRAPHBOLT_CUDA_COOPERATIVE_MINIBATCHING_UTILS_H_
 
 #include <ATen/cuda/CUDAEvent.h>
+#include <graphbolt/async.h>
 #include <torch/script.h>
 
 namespace graphbolt {
@@ -83,6 +84,12 @@ std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>> RankSort(
     const std::vector<torch::Tensor>& nodes_list, int64_t rank,
     int64_t world_size);
 
+c10::intrusive_ptr<Future<
+    std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>>>
+RankSortAsync(
+    const std::vector<torch::Tensor>& nodes_list, const int64_t rank,
+    const int64_t world_size);
+
 }  // namespace cuda
 }  // namespace graphbolt

8 changes: 8 additions & 0 deletions graphbolt/src/python_binding.cc
@@ -59,6 +59,13 @@ TORCH_LIBRARY(graphbolt, m) {
           &Future<std::vector<std::tuple<
               torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>>>::
               Wait);
+  m.class_<Future<
+      std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>>>(
+      "RankSortFuture")
+      .def(
+          "wait",
+          &Future<std::vector<
+              std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>>::Wait);
   m.class_<Future<std::tuple<torch::Tensor, torch::Tensor, int64_t, int64_t>>>(
       "GpuGraphCacheQueryFuture")
       .def(
@@ -198,6 +205,7 @@ TORCH_LIBRARY(graphbolt, m) {
 #ifdef GRAPHBOLT_USE_CUDA
   m.def("set_max_uva_threads", &cuda::set_max_uva_threads);
   m.def("rank_sort", &cuda::RankSort);
+  m.def("rank_sort_async", &cuda::RankSortAsync);
 #endif
 #ifdef HAS_IMPL_ABSTRACT_PYSTUB
   m.impl_abstract_pystub("dgl.graphbolt.base", "//dgl.graphbolt.base");
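
With the RankSortFuture class and the rank_sort_async op registered above, the asynchronous path can be driven from Python alongside the existing synchronous rank_sort. A rough usage sketch, assuming a CUDA build of DGL GraphBolt (the seed tensor below is a made-up example):

import torch
import dgl.graphbolt  # loads the GraphBolt library and registers torch.ops.graphbolt

seeds_list = [torch.randint(0, 1_000_000, (4096,), device="cuda")]
rank, world_size = 0, 4

# Existing synchronous op: returns one (sorted_seeds, index, offsets) tuple
# per tensor in seeds_list.
result = torch.ops.graphbolt.rank_sort(seeds_list, rank, world_size)

# New asynchronous op: returns a RankSortFuture immediately; calling wait()
# later yields the same structure that rank_sort returns.
future = torch.ops.graphbolt.rank_sort_async(seeds_list, rank, world_size)
# ... overlap other host-side work here ...
result_async = future.wait()
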
35 changes: 25 additions & 10 deletions python/dgl/graphbolt/subgraph_sampler.py
@@ -140,6 +140,9 @@ def __init__(
         if cooperative:
             datapipe = datapipe.transform(self._seeds_cooperative_exchange_1)
             datapipe = datapipe.buffer()
+            datapipe = datapipe.transform(
+                self._seeds_cooperative_exchange_1_wait_future
+            ).buffer()
             datapipe = datapipe.transform(self._seeds_cooperative_exchange_2)
             datapipe = datapipe.buffer()
             datapipe = datapipe.transform(self._seeds_cooperative_exchange_3)
@@ -193,19 +196,33 @@ def _wait_preprocess_future(minibatch, cooperative: bool):
         return minibatch
 
     @staticmethod
-    def _seeds_cooperative_exchange_1(minibatch, group=None):
-        rank = thd.get_rank(group)
-        world_size = thd.get_world_size(group)
+    def _seeds_cooperative_exchange_1(minibatch):
+        rank = thd.get_rank()
+        world_size = thd.get_world_size()
         seeds = minibatch._seed_nodes
         is_homogeneous = not isinstance(seeds, dict)
         if is_homogeneous:
             seeds = {"_N": seeds}
         if minibatch._seeds_offsets is None:
-            seeds_list = list(seeds.values())
-            result = torch.ops.graphbolt.rank_sort(seeds_list, rank, world_size)
             assert minibatch.compacted_seeds is None
+            seeds_list = list(seeds.values())
+            minibatch._rank_sort_future = torch.ops.graphbolt.rank_sort_async(
+                seeds_list, rank, world_size
+            )
+        return minibatch
+
+    @staticmethod
+    def _seeds_cooperative_exchange_1_wait_future(minibatch):
+        world_size = thd.get_world_size()
+        seeds = minibatch._seed_nodes
+        is_homogeneous = not isinstance(seeds, dict)
+        if is_homogeneous:
+            seeds = {"_N": seeds}
+        num_ntypes = len(seeds.keys())
+        if minibatch._seeds_offsets is None:
+            result = minibatch._rank_sort_future.wait()
+            delattr(minibatch, "_rank_sort_future")
             sorted_seeds, sorted_compacted, sorted_offsets = {}, {}, {}
-            num_ntypes = len(seeds.keys())
             for i, (
                 seed_type,
                 (typed_sorted_seeds, typed_index, typed_offsets),
@@ -229,16 +246,15 @@ def _seeds_cooperative_exchange_1(minibatch, group=None):
         minibatch._counts_future = all_to_all(
             counts_received.split(num_ntypes),
             counts_sent.split(num_ntypes),
-            group=group,
             async_op=True,
         )
         minibatch._counts_sent = counts_sent
         minibatch._counts_received = counts_received
         return minibatch
 
     @staticmethod
-    def _seeds_cooperative_exchange_2(minibatch, group=None):
-        world_size = thd.get_world_size(group)
+    def _seeds_cooperative_exchange_2(minibatch):
+        world_size = thd.get_world_size()
         seeds = minibatch._seed_nodes
         minibatch._counts_future.wait()
         delattr(minibatch, "_counts_future")
@@ -256,7 +272,6 @@ def _seeds_cooperative_exchange_2(minibatch, group=None):
            all_to_all(
                typed_seeds_received.split(typed_counts_received),
                typed_seeds.split(typed_counts_sent),
-               group,
            )
            seeds_received[ntype] = typed_seeds_received
            counts_sent[ntype] = typed_counts_sent
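
In subgraph_sampler.py, the old single-stage _seeds_cooperative_exchange_1 is split into a stage that only launches rank_sort_async and a later _seeds_cooperative_exchange_1_wait_future stage that calls wait() on the stored future, with a buffer() between them so the sort can run while the pipeline keeps processing other minibatches. The toy generator below illustrates that launch/buffer/wait staging in isolation; it is only an analogy, with plain Python futures standing in for GraphBolt's RankSortFuture and datapipes:

from concurrent.futures import ThreadPoolExecutor

_pool = ThreadPoolExecutor(max_workers=1)  # stands in for the background sort thread

def _slow_sort(seeds):
    return sorted(seeds)  # stands in for the GPU-side rank sort

def staged_pipeline(minibatches, buffer_size=2):
    # transform(launch): start the asynchronous sort and attach the future.
    launched = ({"future": _pool.submit(_slow_sort, mb)} for mb in minibatches)
    # buffer() + transform(wait): keep up to buffer_size futures in flight and
    # only block on the oldest one, so later minibatches are launched while
    # earlier sorts are still running.
    pending = []
    for item in launched:
        pending.append(item)
        if len(pending) > buffer_size:
            yield pending.pop(0)["future"].result()
    for item in pending:
        yield item["future"].result()

for sorted_seeds in staged_pipeline([[3, 1, 2], [9, 7, 8], [6, 4, 5]]):
    print(sorted_seeds)
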