diff --git a/examples/graphbolt/link_prediction.py b/examples/graphbolt/link_prediction.py index 60d48b57fc2d..cdd7440901f1 100644 --- a/examples/graphbolt/link_prediction.py +++ b/examples/graphbolt/link_prediction.py @@ -202,8 +202,9 @@ def create_dataloader(args, graph, features, itemset, is_train=True): # the negative samples. ############################################################################ if is_train and args.exclude_edges: - datapipe = datapipe.transform( - partial(gb.exclude_seed_edges, include_reverse_edges=True) + datapipe = datapipe.exclude_seed_edges( + include_reverse_edges=True, + asynchronous=args.storage_device != "cpu", ) ############################################################################ diff --git a/examples/graphbolt/pyg/link_prediction.py b/examples/graphbolt/pyg/link_prediction.py index 4c2b05fd410e..5dc782d9ff4d 100644 --- a/examples/graphbolt/pyg/link_prediction.py +++ b/examples/graphbolt/pyg/link_prediction.py @@ -163,8 +163,9 @@ def create_dataloader( asynchronous=args.graph_device != "cpu", ) if job == "train" and args.exclude_edges: - datapipe = datapipe.transform( - partial(gb.exclude_seed_edges, include_reverse_edges=True) + datapipe = datapipe.exclude_seed_edges( + include_reverse_edges=True, + asynchronous=args.graph_device != "cpu", ) # Copy the data to the specified device. if args.feature_device != "cpu" and need_copy: diff --git a/graphbolt/include/graphbolt/cuda_ops.h b/graphbolt/include/graphbolt/cuda_ops.h index e7f2f60721b4..91cd1a10c652 100644 --- a/graphbolt/include/graphbolt/cuda_ops.h +++ b/graphbolt/include/graphbolt/cuda_ops.h @@ -79,10 +79,22 @@ Sort(torch::Tensor input, int num_bits = 0); * @return * A boolean tensor of the same shape as elements that is True for elements * in test_elements and False otherwise. 
- * */ torch::Tensor IsIn(torch::Tensor elements, torch::Tensor test_elements); +/** + * @brief Returns the indexes of the nonzero elements in the given boolean mask + * if logical_not is false. Otherwise, returns the indexes of the zero elements + * instead. + * + * @param mask Input boolean mask. + * @param logical_not Whether mask should be treated as ~mask. + * + * @return An int64_t tensor containing the indexes of the selected elements; + * its length is the number of selected elements, not mask.numel(). + */ +torch::Tensor Nonzero(torch::Tensor mask, bool logical_not); + /** * @brief Select columns for a sparse matrix in a CSC format according to nodes * tensor. diff --git a/graphbolt/include/graphbolt/isin.h b/graphbolt/include/graphbolt/isin.h index 0b472858ecf6..4e52b429988f 100644 --- a/graphbolt/include/graphbolt/isin.h +++ b/graphbolt/include/graphbolt/isin.h @@ -7,6 +7,7 @@ #ifndef GRAPHBOLT_ISIN_H_ #define GRAPHBOLT_ISIN_H_ +#include #include namespace graphbolt { @@ -25,11 +26,27 @@ namespace sampling { * @return * A boolean tensor of the same shape as elements that is True for elements * in test_elements and False otherwise. - * */ torch::Tensor IsIn( const torch::Tensor& elements, const torch::Tensor& test_elements); +/** + * @brief Tests if each element of elements is not in test_elements. Returns an + * int64_t tensor containing the indexes of the elements not found in + * test_elements; its length is the number of such elements. + * + * @param elements Input elements + * @param test_elements Values against which to test for each input element. + * + * @return An int64_t tensor containing the indexes of the elements not found + * in test_elements. 
+ */ +torch::Tensor IsNotInIndex( + const torch::Tensor& elements, const torch::Tensor& test_elements); + +c10::intrusive_ptr> IsNotInIndexAsync( + const torch::Tensor& elements, const torch::Tensor& test_elements); + } // namespace sampling } // namespace graphbolt diff --git a/graphbolt/src/cuda/isin.cu b/graphbolt/src/cuda/isin.cu index af773934415e..aa2c724d2535 100644 --- a/graphbolt/src/cuda/isin.cu +++ b/graphbolt/src/cuda/isin.cu @@ -20,6 +20,8 @@ #include #include +#include + #include "./common.h" namespace graphbolt { @@ -42,5 +44,25 @@ torch::Tensor IsIn(torch::Tensor elements, torch::Tensor test_elements) { return result; } +torch::Tensor Nonzero(torch::Tensor mask, bool logical_not) { + thrust::counting_iterator iota(0); + auto result = torch::empty_like(mask, torch::kInt64); + auto mask_ptr = mask.data_ptr(); + auto result_ptr = result.data_ptr(); + auto allocator = cuda::GetAllocator(); + auto num_copied = allocator.AllocateStorage(1); + if (logical_not) { + CUB_CALL( + DeviceSelect::FlaggedIf, iota, mask_ptr, result_ptr, num_copied.get(), + mask.numel(), thrust::logical_not{}); + } else { + CUB_CALL( + DeviceSelect::Flagged, iota, mask_ptr, result_ptr, num_copied.get(), + mask.numel()); + } + cuda::CopyScalar num_copied_cpu(num_copied.get()); + return result.slice(0, 0, static_cast(num_copied_cpu)); +} + } // namespace ops } // namespace graphbolt diff --git a/graphbolt/src/isin.cc b/graphbolt/src/isin.cc index c41b839b1651..76cbf1f8d0f1 100644 --- a/graphbolt/src/isin.cc +++ b/graphbolt/src/isin.cc @@ -56,5 +56,22 @@ torch::Tensor IsIn( return IsInCPU(elements, test_elements); } } + +torch::Tensor IsNotInIndex( + const torch::Tensor& elements, const torch::Tensor& test_elements) { + auto mask = IsIn(elements, test_elements); + if (utils::is_on_gpu(mask)) { + GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE( + c10::DeviceType::CUDA, "NonzeroOperation", + { return ops::Nonzero(mask, true); }); + } + return torch::nonzero(torch::logical_not(mask)).squeeze(1); +} + 
+c10::intrusive_ptr> IsNotInIndexAsync( + const torch::Tensor& elements, const torch::Tensor& test_elements) { + return async([=] { return IsNotInIndex(elements, test_elements); }); +} + } // namespace sampling } // namespace graphbolt diff --git a/graphbolt/src/python_binding.cc b/graphbolt/src/python_binding.cc index 4df395b0f904..ea2b543761cf 100644 --- a/graphbolt/src/python_binding.cc +++ b/graphbolt/src/python_binding.cc @@ -181,6 +181,8 @@ TORCH_LIBRARY(graphbolt, m) { m.def("unique_and_compact_batched", &UniqueAndCompactBatched); m.def("unique_and_compact_batched_async", &UniqueAndCompactBatchedAsync); m.def("isin", &IsIn); + m.def("is_not_in_index", &IsNotInIndex); + m.def("is_not_in_index_async", &IsNotInIndexAsync); m.def("index_select", &ops::IndexSelect); m.def("index_select_async", &ops::IndexSelectAsync); m.def("scatter_async", &ops::ScatterAsync); diff --git a/python/dgl/graphbolt/external_utils.py b/python/dgl/graphbolt/external_utils.py index 98ddc310d213..89737fbe6dc3 100644 --- a/python/dgl/graphbolt/external_utils.py +++ b/python/dgl/graphbolt/external_utils.py @@ -1,10 +1,60 @@ """Utility functions for external use.""" - +from functools import partial from typing import Dict, Union import torch +from torch.utils.data import functional_datapipe + from .minibatch import MiniBatch +from .minibatch_transformer import MiniBatchTransformer + + +@functional_datapipe("exclude_seed_edges") +class SeedEdgesExcluder(MiniBatchTransformer): + """A mini-batch transformer that excludes seed edges from mini-batches. + + Functional name: :obj:`exclude_seed_edges`. + + Parameters + ---------- + datapipe : DataPipe + The datapipe. + include_reverse_edges : bool + Whether reverse edges should be excluded as well. Default is False. + reverse_etypes_mapping : Dict[str, str] = None + The mapping from the original edge types to their reverse edge types. 
+ asynchronous: bool + Boolean indicating whether edge exclusion stages should run on + background threads to hide the latency of CPU-GPU synchronization. + Should be enabled only when sampling on the GPU. + """ + + def __init__( + self, + datapipe, + include_reverse_edges: bool = False, + reverse_etypes_mapping: Dict[str, str] = None, + asynchronous=False, + ): + exclude_seed_edges_fn = partial( + exclude_seed_edges, + include_reverse_edges=include_reverse_edges, + reverse_etypes_mapping=reverse_etypes_mapping, + async_op=asynchronous, + ) + datapipe = datapipe.transform(exclude_seed_edges_fn) + if asynchronous: + datapipe = datapipe.buffer() + datapipe = datapipe.transform(self._wait_for_sampled_subgraphs) + super().__init__(datapipe) + + @staticmethod + def _wait_for_sampled_subgraphs(minibatch): + minibatch.sampled_subgraphs = [ + subgraph.wait() for subgraph in minibatch.sampled_subgraphs + ] + return minibatch def add_reverse_edges( @@ -79,6 +129,7 @@ def exclude_seed_edges( minibatch: MiniBatch, include_reverse_edges: bool = False, reverse_etypes_mapping: Dict[str, str] = None, + async_op: bool = False, ): """ Exclude seed edges with or without their reverse edges from the sampled @@ -88,8 +139,13 @@ def exclude_seed_edges( ---------- minibatch : MiniBatch The minibatch. + include_reverse_edges : bool + Whether reverse edges should be excluded as well. Default is False. reverse_etypes_mapping : Dict[str, str] = None The mapping from the original edge types to their reverse edge types. + async_op: bool + Boolean indicating whether the call is asynchronous. If so, call wait + on each of the modified sampled_subgraphs to obtain the result. 
""" edges_to_exclude = minibatch.seeds if include_reverse_edges: @@ -97,7 +153,7 @@ def exclude_seed_edges( edges_to_exclude, reverse_etypes_mapping ) minibatch.sampled_subgraphs = [ - subgraph.exclude_edges(edges_to_exclude) + subgraph.exclude_edges(edges_to_exclude, async_op=async_op) for subgraph in minibatch.sampled_subgraphs ] return minibatch diff --git a/python/dgl/graphbolt/sampled_subgraph.py b/python/dgl/graphbolt/sampled_subgraph.py index 8bff77de90b3..bcbd8a2004a1 100644 --- a/python/dgl/graphbolt/sampled_subgraph.py +++ b/python/dgl/graphbolt/sampled_subgraph.py @@ -20,6 +20,27 @@ __all__ = ["SampledSubgraph"] +class _ExcludeEdgesWaiter: + def __init__(self, sampled_subgraph, index): + self.sampled_subgraph = sampled_subgraph + self.index = index + + def wait(self): + """Returns the stored value when invoked.""" + sampled_subgraph = self.sampled_subgraph + index = self.index + # Ensure there is no memory leak. + self.sampled_subgraph = self.index = None + + if isinstance(index, dict): + for k in list(index.keys()): + index[k] = index[k].wait() + else: + index = index.wait() + + return type(sampled_subgraph)(*_slice_subgraph(sampled_subgraph, index)) + + class PyGLayerData(NamedTuple): """A named tuple class to represent homogenous inputs to a PyG model layer. The fields are x (input features), edge_index and size @@ -142,6 +163,7 @@ def exclude_edges( torch.Tensor, ], assume_num_node_within_int32: bool = True, + async_op: bool = False, ): r"""Exclude edges from the sampled subgraph. @@ -163,6 +185,9 @@ def exclude_edges( If True, assumes the value of node IDs in the provided `edges` fall within the int32 range, which can significantly enhance computation speed. Default: True + async_op: bool + Boolean indicating whether the call is asynchronous. If so, the + result can be obtained by calling wait on the returned future. 
Returns ------- @@ -222,9 +247,8 @@ def exclude_edges( self.original_column_node_ids, ) index = _exclude_homo_edges( - reverse_edges, edges, assume_num_node_within_int32 + reverse_edges, edges, assume_num_node_within_int32, async_op ) - return calling_class(*_slice_subgraph(self, index)) else: index = {} for etype, pair in self.sampled_csc.items(): @@ -252,7 +276,11 @@ def exclude_edges( reverse_edges, edges[etype], assume_num_node_within_int32, + async_op, ) + if async_op: + return _ExcludeEdgesWaiter(self, index) + else: return calling_class(*_slice_subgraph(self, index)) def to_pyg( @@ -367,6 +395,7 @@ def _exclude_homo_edges( edges: Tuple[torch.Tensor, torch.Tensor], edges_to_exclude: torch.Tensor, assume_num_node_within_int32: bool, + async_op: bool, ): """Return the indices of edges to be included.""" if assume_num_node_within_int32: @@ -381,8 +410,11 @@ def _exclude_homo_edges( raise NotImplementedError( "Values out of range int32 are not supported yet" ) - mask = ~isin(val, val_to_exclude) - return torch.nonzero(mask, as_tuple=True)[0] + if async_op: + return torch.ops.graphbolt.is_not_in_index_async(val, val_to_exclude) + else: + mask = ~isin(val, val_to_exclude) + return torch.nonzero(mask, as_tuple=True)[0] def _slice_subgraph(subgraph: SampledSubgraph, index: torch.Tensor): diff --git a/tests/python/pytorch/graphbolt/test_utils.py b/tests/python/pytorch/graphbolt/test_utils.py index 3795c791ebde..149942f98ea9 100644 --- a/tests/python/pytorch/graphbolt/test_utils.py +++ b/tests/python/pytorch/graphbolt/test_utils.py @@ -72,7 +72,8 @@ def test_add_reverse_edges_hetero(): F._default_context_str == "gpu", reason="Fails due to different result on the GPU.", ) -def test_exclude_seed_edges_homo_cpu(): +@pytest.mark.parametrize("use_datapipe", [False, True]) +def test_exclude_seed_edges_homo_cpu(use_datapipe): graph = dgl.graph(([5, 0, 6, 7, 2, 2, 4], [0, 1, 2, 2, 3, 4, 4])) graph = gb.from_dglgraph(graph, True).to(F.ctx()) items = torch.LongTensor([[0, 3], 
[4, 4]]) @@ -83,7 +84,10 @@ def test_exclude_seed_edges_homo_cpu(): fanouts = [torch.LongTensor([2]) for _ in range(num_layer)] sampler = gb.NeighborSampler datapipe = sampler(datapipe, graph, fanouts) - datapipe = datapipe.transform(partial(gb.exclude_seed_edges)) + if use_datapipe: + datapipe = datapipe.exclude_seed_edges() + else: + datapipe = datapipe.transform(partial(gb.exclude_seed_edges)) original_row_node_ids = [ torch.tensor([0, 3, 4, 5, 2, 6, 7]).to(F.ctx()), torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()), @@ -121,7 +125,9 @@ def test_exclude_seed_edges_homo_cpu(): F._default_context_str == "cpu", reason="Fails due to different result on the CPU.", ) -def test_exclude_seed_edges_gpu(): +@pytest.mark.parametrize("use_datapipe", [False, True]) +@pytest.mark.parametrize("async_op", [False, True]) +def test_exclude_seed_edges_gpu(use_datapipe, async_op): graph = dgl.graph(([5, 0, 7, 7, 2, 4], [0, 1, 2, 2, 3, 4])) graph = gb.from_dglgraph(graph, is_homogeneous=True).to(F.ctx()) items = torch.LongTensor([[0, 3], [4, 4]]) @@ -137,7 +143,12 @@ def test_exclude_seed_edges_gpu(): fanouts, deduplicate=True, ) - datapipe = datapipe.transform(partial(gb.exclude_seed_edges)) + if use_datapipe: + datapipe = datapipe.exclude_seed_edges(asynchronous=async_op) + else: + datapipe = datapipe.transform( + partial(gb.exclude_seed_edges, async_op=async_op) + ) if torch.cuda.get_device_capability()[0] < 7: original_row_node_ids = [ torch.tensor([0, 3, 4, 2, 5, 7]).to(F.ctx()), @@ -174,6 +185,8 @@ def test_exclude_seed_edges_gpu(): ] for data in datapipe: for step, sampled_subgraph in enumerate(data.sampled_subgraphs): + if async_op and not use_datapipe: + sampled_subgraph = sampled_subgraph.wait() assert torch.equal( sampled_subgraph.original_row_node_ids, original_row_node_ids[step],