[GraphBolt][CUDA] Eliminate synchronization from exclude edges. (#7757)
mfbalin authored Aug 29, 2024
1 parent 03e83ac commit d6cf415
Showing 10 changed files with 189 additions and 16 deletions.
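This commit registers seed-edge exclusion as a functional datapipe (`exclude_seed_edges`) and adds an asynchronous path that computes the kept-edge indexes on a background thread, so GPU sampling pipelines no longer stall on a CPU-GPU synchronization. A minimal sketch of the new call site, assuming `datapipe` and `args` are set up as in the bundled link-prediction examples below:

import dgl.graphbolt as gb

# Sketch only: `datapipe` and `args` come from a dataloader built as in
# examples/graphbolt/link_prediction.py.
datapipe = datapipe.exclude_seed_edges(
    include_reverse_edges=True,
    # Asynchronous exclusion only pays off when sampling runs on the GPU.
    asynchronous=args.storage_device != "cpu",
)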
5 changes: 3 additions & 2 deletions examples/graphbolt/link_prediction.py
@@ -202,8 +202,9 @@ def create_dataloader(args, graph, features, itemset, is_train=True):
# the negative samples.
############################################################################
if is_train and args.exclude_edges:
-        datapipe = datapipe.transform(
-            partial(gb.exclude_seed_edges, include_reverse_edges=True)
+        datapipe = datapipe.exclude_seed_edges(
+            include_reverse_edges=True,
+            asynchronous=args.storage_device != "cpu",
        )

############################################################################
5 changes: 3 additions & 2 deletions examples/graphbolt/pyg/link_prediction.py
@@ -163,8 +163,9 @@ def create_dataloader(
asynchronous=args.graph_device != "cpu",
)
if job == "train" and args.exclude_edges:
-        datapipe = datapipe.transform(
-            partial(gb.exclude_seed_edges, include_reverse_edges=True)
+        datapipe = datapipe.exclude_seed_edges(
+            include_reverse_edges=True,
+            asynchronous=args.graph_device != "cpu",
        )
# Copy the data to the specified device.
if args.feature_device != "cpu" and need_copy:
14 changes: 13 additions & 1 deletion graphbolt/include/graphbolt/cuda_ops.h
@@ -79,10 +79,22 @@ Sort(torch::Tensor input, int num_bits = 0);
* @return
* A boolean tensor of the same shape as elements that is True for elements
* in test_elements and False otherwise.
*
*/
torch::Tensor IsIn(torch::Tensor elements, torch::Tensor test_elements);

/**
* @brief Returns the indexes of the nonzero elements in the given boolean mask
* if logical_not is false. Otherwise, returns the indexes of the zero elements
* instead.
*
* @param mask Input boolean mask.
* @param logical_not Whether mask should be treated as ~mask.
*
 * @return An int64_t tensor containing the indexes of the selected elements,
 * sliced to the number of selected elements.
*/
torch::Tensor Nonzero(torch::Tensor mask, bool logical_not);

/**
* @brief Select columns for a sparse matrix in a CSC format according to nodes
* tensor.
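For reference, the semantics of the new `Nonzero` can be sketched in PyTorch (a semantic sketch only, not the implementation; the CUDA version exists precisely to avoid the host synchronization that `torch.nonzero` forces):

import torch

def nonzero_reference(mask: torch.Tensor, logical_not: bool) -> torch.Tensor:
    # Same result as ops::Nonzero on a 1-D boolean mask.
    if logical_not:
        mask = ~mask
    return torch.nonzero(mask).squeeze(1)  # (N, 1) -> (N,)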
19 changes: 18 additions & 1 deletion graphbolt/include/graphbolt/isin.h
@@ -7,6 +7,7 @@
#ifndef GRAPHBOLT_ISIN_H_
#define GRAPHBOLT_ISIN_H_

#include <graphbolt/async.h>
#include <torch/torch.h>

namespace graphbolt {
@@ -25,11 +26,27 @@ namespace sampling {
* @return
* A boolean tensor of the same shape as elements that is True for elements
* in test_elements and False otherwise.
*
*/
torch::Tensor IsIn(
const torch::Tensor& elements, const torch::Tensor& test_elements);

/**
 * @brief Tests if each element of elements is not in test_elements. Returns an
 * int64_t tensor containing the indexes of the elements not found in
 * test_elements.
*
* @param elements Input elements
* @param test_elements Values against which to test for each input element.
*
 * @return An int64_t tensor containing the indexes of the elements not found
 * in test_elements.
*/
torch::Tensor IsNotInIndex(
const torch::Tensor& elements, const torch::Tensor& test_elements);

/**
 * @brief Asynchronous version of IsNotInIndex; launches the computation
 * asynchronously and returns the result wrapped in a Future.
 */
c10::intrusive_ptr<Future<torch::Tensor>> IsNotInIndexAsync(
    const torch::Tensor& elements, const torch::Tensor& test_elements);

} // namespace sampling
} // namespace graphbolt

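`IsNotInIndex` composes the two primitives above; a PyTorch equivalent matching the CPU fallback in `isin.cc` further below:

import torch

def is_not_in_index_reference(elements, test_elements):
    # Indexes of the entries of `elements` absent from `test_elements`.
    mask = torch.isin(elements, test_elements)
    return torch.nonzero(~mask).squeeze(1)

# e.g. elements=[1, 2, 3], test_elements=[2] -> tensor([0, 2])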
22 changes: 22 additions & 0 deletions graphbolt/src/cuda/isin.cu
@@ -20,6 +20,8 @@
#include <graphbolt/cuda_ops.h>
#include <thrust/binary_search.h>

#include <cub/cub.cuh>

#include "./common.h"

namespace graphbolt {
@@ -42,5 +44,25 @@ torch::Tensor IsIn(torch::Tensor elements, torch::Tensor test_elements) {
return result;
}

torch::Tensor Nonzero(torch::Tensor mask, bool logical_not) {
thrust::counting_iterator<int64_t> iota(0);
auto result = torch::empty_like(mask, torch::kInt64);
auto mask_ptr = mask.data_ptr<bool>();
auto result_ptr = result.data_ptr<int64_t>();
auto allocator = cuda::GetAllocator();
auto num_copied = allocator.AllocateStorage<int64_t>(1);
if (logical_not) {
CUB_CALL(
DeviceSelect::FlaggedIf, iota, mask_ptr, result_ptr, num_copied.get(),
mask.numel(), thrust::logical_not<bool>{});
} else {
CUB_CALL(
DeviceSelect::Flagged, iota, mask_ptr, result_ptr, num_copied.get(),
mask.numel());
}
cuda::CopyScalar num_copied_cpu(num_copied.get());
return result.slice(0, 0, static_cast<int64_t>(num_copied_cpu));
}

} // namespace ops
} // namespace graphbolt
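The kernel selects indexes from a counting iterator in a single `cub::DeviceSelect` pass, overallocates the output to `mask.numel()`, and copies back only the scalar count used to slice the result. A rough PyTorch emulation of that select-then-slice structure (a sketch; unlike the CUDA code, the `int(...)` below does synchronize):

import torch

def nonzero_via_select(mask: torch.Tensor, logical_not: bool) -> torch.Tensor:
    flags = ~mask if logical_not else mask
    iota = torch.arange(mask.numel(), device=mask.device)
    result = torch.empty_like(iota)  # overallocated like the CUDA code
    num_copied = int(flags.sum())  # the one scalar copied back to the host
    result[:num_copied] = iota[flags]  # what DeviceSelect::Flagged produces
    return result[:num_copied]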
17 changes: 17 additions & 0 deletions graphbolt/src/isin.cc
@@ -56,5 +56,22 @@ torch::Tensor IsIn(
return IsInCPU(elements, test_elements);
}
}

torch::Tensor IsNotInIndex(
const torch::Tensor& elements, const torch::Tensor& test_elements) {
auto mask = IsIn(elements, test_elements);
if (utils::is_on_gpu(mask)) {
GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(
c10::DeviceType::CUDA, "NonzeroOperation",
{ return ops::Nonzero(mask, true); });
}
return torch::nonzero(torch::logical_not(mask)).squeeze(1);
}

c10::intrusive_ptr<Future<torch::Tensor>> IsNotInIndexAsync(
const torch::Tensor& elements, const torch::Tensor& test_elements) {
return async([=] { return IsNotInIndex(elements, test_elements); });
}

} // namespace sampling
} // namespace graphbolt
2 changes: 2 additions & 0 deletions graphbolt/src/python_binding.cc
@@ -181,6 +181,8 @@ TORCH_LIBRARY(graphbolt, m) {
m.def("unique_and_compact_batched", &UniqueAndCompactBatched);
m.def("unique_and_compact_batched_async", &UniqueAndCompactBatchedAsync);
m.def("isin", &IsIn);
m.def("is_not_in_index", &IsNotInIndex);
m.def("is_not_in_index_async", &IsNotInIndexAsync);
m.def("index_select", &ops::IndexSelect);
m.def("index_select_async", &ops::IndexSelectAsync);
m.def("scatter_async", &ops::ScatterAsync);
60 changes: 58 additions & 2 deletions python/dgl/graphbolt/external_utils.py
@@ -1,10 +1,60 @@
"""Utility functions for external use."""

from functools import partial
from typing import Dict, Union

import torch

from torch.utils.data import functional_datapipe

from .minibatch import MiniBatch
from .minibatch_transformer import MiniBatchTransformer


@functional_datapipe("exclude_seed_edges")
class SeedEdgesExcluder(MiniBatchTransformer):
    """A mini-batch transformer that excludes seed edges from the sampled
    subgraphs in the mini-batch.
    Functional name: :obj:`exclude_seed_edges`.
Parameters
----------
datapipe : DataPipe
The datapipe.
include_reverse_edges : bool
Whether reverse edges should be excluded as well. Default is False.
reverse_etypes_mapping : Dict[str, str] = None
The mapping from the original edge types to their reverse edge types.
asynchronous: bool
Boolean indicating whether edge exclusion stages should run on
        background threads to hide the latency of CPU-GPU synchronization.
Should be enabled only when sampling on the GPU.
"""

def __init__(
self,
datapipe,
include_reverse_edges: bool = False,
reverse_etypes_mapping: Dict[str, str] = None,
asynchronous=False,
):
exclude_seed_edges_fn = partial(
exclude_seed_edges,
include_reverse_edges=include_reverse_edges,
reverse_etypes_mapping=reverse_etypes_mapping,
async_op=asynchronous,
)
datapipe = datapipe.transform(exclude_seed_edges_fn)
if asynchronous:
datapipe = datapipe.buffer()
datapipe = datapipe.transform(self._wait_for_sampled_subgraphs)
super().__init__(datapipe)

@staticmethod
def _wait_for_sampled_subgraphs(minibatch):
minibatch.sampled_subgraphs = [
subgraph.wait() for subgraph in minibatch.sampled_subgraphs
]
return minibatch

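The composition above is the heart of the latency hiding: the first transform launches the exclusion and leaves futures in `minibatch.sampled_subgraphs`, `buffer()` lets the pipeline start the next mini-batch while those futures run, and the final transform waits. The same wiring spelled out manually (a sketch; `datapipe` is assumed to come from a GPU sampler stage):

from functools import partial

import dgl.graphbolt as gb

def wait_for_sampled_subgraphs(minibatch):
    # Resolve the futures produced by the asynchronous exclusion stage.
    minibatch.sampled_subgraphs = [
        subgraph.wait() for subgraph in minibatch.sampled_subgraphs
    ]
    return minibatch

datapipe = datapipe.transform(partial(gb.exclude_seed_edges, async_op=True))
datapipe = datapipe.buffer()  # next batch's exclusion starts before we wait
datapipe = datapipe.transform(wait_for_sampled_subgraphs)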

def add_reverse_edges(
@@ -79,6 +129,7 @@ def exclude_seed_edges(
minibatch: MiniBatch,
include_reverse_edges: bool = False,
reverse_etypes_mapping: Dict[str, str] = None,
async_op: bool = False,
):
"""
Exclude seed edges with or without their reverse edges from the sampled
@@ -88,16 +139,21 @@
----------
minibatch : MiniBatch
The minibatch.
include_reverse_edges : bool
Whether reverse edges should be excluded as well. Default is False.
reverse_etypes_mapping : Dict[str, str] = None
The mapping from the original edge types to their reverse edge types.
async_op: bool
Boolean indicating whether the call is asynchronous. If so, the result
can be obtained by calling wait on the modified sampled_subgraphs.
"""
edges_to_exclude = minibatch.seeds
if include_reverse_edges:
edges_to_exclude = add_reverse_edges(
edges_to_exclude, reverse_etypes_mapping
)
minibatch.sampled_subgraphs = [
subgraph.exclude_edges(edges_to_exclude)
subgraph.exclude_edges(edges_to_exclude, async_op=async_op)
for subgraph in minibatch.sampled_subgraphs
]
return minibatch
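When `exclude_seed_edges` is called with `async_op=True` outside the datapipe wrapper, the caller owns the waiting; a sketch mirroring the `async_op and not use_datapipe` branch of the GPU test at the bottom of this commit:

for minibatch in datapipe:
    minibatch.sampled_subgraphs = [
        subgraph.wait() for subgraph in minibatch.sampled_subgraphs
    ]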
40 changes: 36 additions & 4 deletions python/dgl/graphbolt/sampled_subgraph.py
@@ -20,6 +20,27 @@
__all__ = ["SampledSubgraph"]


class _ExcludeEdgesWaiter:
def __init__(self, sampled_subgraph, index):
self.sampled_subgraph = sampled_subgraph
self.index = index

def wait(self):
"""Returns the stored value when invoked."""
sampled_subgraph = self.sampled_subgraph
index = self.index
# Ensure there is no memory leak.
self.sampled_subgraph = self.index = None

if isinstance(index, dict):
for k in list(index.keys()):
index[k] = index[k].wait()
else:
index = index.wait()

return type(sampled_subgraph)(*_slice_subgraph(sampled_subgraph, index))


class PyGLayerData(NamedTuple):
"""A named tuple class to represent homogenous inputs to a PyG model layer.
The fields are x (input features), edge_index and size
@@ -142,6 +163,7 @@ def exclude_edges(
torch.Tensor,
],
assume_num_node_within_int32: bool = True,
async_op: bool = False,
):
r"""Exclude edges from the sampled subgraph.
@@ -163,6 +185,9 @@
If True, assumes the value of node IDs in the provided `edges` fall
within the int32 range, which can significantly enhance computation
speed. Default: True
async_op: bool
Boolean indicating whether the call is asynchronous. If so, the
result can be obtained by calling wait on the returned future.
Returns
-------
@@ -222,9 +247,8 @@
self.original_column_node_ids,
)
index = _exclude_homo_edges(
-                reverse_edges, edges, assume_num_node_within_int32
+                reverse_edges, edges, assume_num_node_within_int32, async_op
            )
-            return calling_class(*_slice_subgraph(self, index))
else:
index = {}
for etype, pair in self.sampled_csc.items():
@@ -252,7 +276,11 @@
reverse_edges,
edges[etype],
assume_num_node_within_int32,
async_op,
)
if async_op:
return _ExcludeEdgesWaiter(self, index)
else:
return calling_class(*_slice_subgraph(self, index))
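At the `SampledSubgraph` level, the asynchronous path hands back the `_ExcludeEdgesWaiter` defined at the top of this file instead of a new subgraph. Both call styles, sketched under the assumption that `subgraph` is a homogeneous `SampledSubgraph` and `seeds` its seed-edge tensor:

pruned = subgraph.exclude_edges(seeds)  # synchronous: a pruned subgraph

waiter = subgraph.exclude_edges(seeds, async_op=True)
pruned = waiter.wait()  # slices once is_not_in_index_async has resolved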

def to_pyg(
@@ -367,6 +395,7 @@ def _exclude_homo_edges(
edges: Tuple[torch.Tensor, torch.Tensor],
edges_to_exclude: torch.Tensor,
assume_num_node_within_int32: bool,
async_op: bool,
):
"""Return the indices of edges to be included."""
if assume_num_node_within_int32:
@@ -381,8 +410,11 @@
raise NotImplementedError(
"Values out of range int32 are not supported yet"
)
-    mask = ~isin(val, val_to_exclude)
-    return torch.nonzero(mask, as_tuple=True)[0]
+    if async_op:
+        return torch.ops.graphbolt.is_not_in_index_async(val, val_to_exclude)
+    else:
+        mask = ~isin(val, val_to_exclude)
+        return torch.nonzero(mask, as_tuple=True)[0]


def _slice_subgraph(subgraph: SampledSubgraph, index: torch.Tensor):
21 changes: 17 additions & 4 deletions tests/python/pytorch/graphbolt/test_utils.py
@@ -72,7 +72,8 @@ def test_add_reverse_edges_hetero():
F._default_context_str == "gpu",
reason="Fails due to different result on the GPU.",
)
-def test_exclude_seed_edges_homo_cpu():
+@pytest.mark.parametrize("use_datapipe", [False, True])
+def test_exclude_seed_edges_homo_cpu(use_datapipe):
graph = dgl.graph(([5, 0, 6, 7, 2, 2, 4], [0, 1, 2, 2, 3, 4, 4]))
graph = gb.from_dglgraph(graph, True).to(F.ctx())
items = torch.LongTensor([[0, 3], [4, 4]])
@@ -83,7 +84,10 @@ def test_exclude_seed_edges_homo_cpu():
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
sampler = gb.NeighborSampler
datapipe = sampler(datapipe, graph, fanouts)
-    datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
+    if use_datapipe:
+        datapipe = datapipe.exclude_seed_edges()
+    else:
+        datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
original_row_node_ids = [
torch.tensor([0, 3, 4, 5, 2, 6, 7]).to(F.ctx()),
torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),
@@ -121,7 +125,9 @@ def test_exclude_seed_edges_homo_cpu():
F._default_context_str == "cpu",
reason="Fails due to different result on the CPU.",
)
-def test_exclude_seed_edges_gpu():
+@pytest.mark.parametrize("use_datapipe", [False, True])
+@pytest.mark.parametrize("async_op", [False, True])
+def test_exclude_seed_edges_gpu(use_datapipe, async_op):
graph = dgl.graph(([5, 0, 7, 7, 2, 4], [0, 1, 2, 2, 3, 4]))
graph = gb.from_dglgraph(graph, is_homogeneous=True).to(F.ctx())
items = torch.LongTensor([[0, 3], [4, 4]])
@@ -137,7 +143,12 @@
fanouts,
deduplicate=True,
)
-    datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
+    if use_datapipe:
+        datapipe = datapipe.exclude_seed_edges(asynchronous=async_op)
+    else:
+        datapipe = datapipe.transform(
+            partial(gb.exclude_seed_edges, async_op=async_op)
+        )
if torch.cuda.get_device_capability()[0] < 7:
original_row_node_ids = [
torch.tensor([0, 3, 4, 2, 5, 7]).to(F.ctx()),
@@ -174,6 +185,8 @@
]
for data in datapipe:
for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
if async_op and not use_datapipe:
sampled_subgraph = sampled_subgraph.wait()
assert torch.equal(
sampled_subgraph.original_row_node_ids,
original_row_node_ids[step],