From 17292b643f6980d4e03612705306e495f63d19d0 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih BALIN Date: Wed, 11 Sep 2024 17:46:50 -0400 Subject: [PATCH] [GraphBolt][CUDA] Cooperative Minibatching initial exchange. (#7795) --- python/dgl/graphbolt/internal/sample_utils.py | 1 + python/dgl/graphbolt/subgraph_sampler.py | 193 ++++++++++++++++-- .../graphbolt/impl/test_neighbor_sampler.py | 4 +- 3 files changed, 184 insertions(+), 14 deletions(-) diff --git a/python/dgl/graphbolt/internal/sample_utils.py b/python/dgl/graphbolt/internal/sample_utils.py index e88ea0193a55..f499694d6d72 100644 --- a/python/dgl/graphbolt/internal/sample_utils.py +++ b/python/dgl/graphbolt/internal/sample_utils.py @@ -349,6 +349,7 @@ def wait(self): if is_homogeneous: compacted_csc_formats = list(compacted_csc_formats.values())[0] unique_nodes = list(unique_nodes.values())[0] + offsets = list(offsets.values())[0] return unique_nodes, compacted_csc_formats, offsets diff --git a/python/dgl/graphbolt/subgraph_sampler.py b/python/dgl/graphbolt/subgraph_sampler.py index e26ec3336be6..88fdd3808714 100644 --- a/python/dgl/graphbolt/subgraph_sampler.py +++ b/python/dgl/graphbolt/subgraph_sampler.py @@ -5,10 +5,12 @@ from typing import Dict import torch +import torch.distributed as thd from torch.utils.data import functional_datapipe from .base import seed_type_str_to_ntypes from .internal import compact_temporal_nodes, unique_and_compact +from .minibatch import MiniBatch from .minibatch_transformer import MiniBatchTransformer __all__ = [ @@ -28,6 +30,25 @@ def wait(self): return result +def _shift(inputs: list, group=None): + cutoff = len(inputs) - thd.get_rank(group) + return inputs[cutoff:] + inputs[:cutoff] + + +def all_to_all(outputs, inputs, group=None, async_op=False): + """Wrapper for thd.all_to_all that permuted outputs and inputs before + calling it. The arguments have the permutation + `rank, ..., world_size - 1, 0, ..., rank - 1` and we make it + `0, world_size - 1` before calling `thd.all_to_all`.""" + shift_fn = partial(_shift, group=group) + return thd.all_to_all(shift_fn(outputs), shift_fn(inputs), group, async_op) + + +def _revert_to_homo(d: dict): + is_homogenous = len(d) == 1 and "_N" in d + return list(d.values())[0] if is_homogenous else d + + @functional_datapipe("sample_subgraph") class SubgraphSampler(MiniBatchTransformer): """A subgraph sampler used to sample a subgraph from a given set of nodes @@ -49,8 +70,8 @@ class SubgraphSampler(MiniBatchTransformer): Arguments to be passed into sampling_stages. kwargs : Keyword Arguments Arguments to be passed into sampling_stages. Preprocessing stage makes - use of the `asynchronous` parameter before it is passed to - the sampling stages. + use of the `asynchronous` and `cooperative` parameters before they are + passed to the sampling stages. """ def __init__( @@ -60,10 +81,22 @@ def __init__( **kwargs, ): async_op = kwargs.get("asynchronous", False) - preprocess_fn = partial(self._preprocess, async_op=async_op) + cooperative = kwargs.get("cooperative", False) + preprocess_fn = partial( + self._preprocess, cooperative=cooperative, async_op=async_op + ) datapipe = datapipe.transform(preprocess_fn) if async_op: - datapipe = datapipe.buffer().transform(self._wait_preprocess_future) + fn = partial(self._wait_preprocess_future, cooperative=cooperative) + datapipe = datapipe.buffer().transform(fn) + if cooperative: + datapipe = datapipe.transform(self._seeds_cooperative_exchange_1) + datapipe = datapipe.buffer() + datapipe = datapipe.transform(self._seeds_cooperative_exchange_2) + datapipe = datapipe.buffer() + datapipe = datapipe.transform(self._seeds_cooperative_exchange_3) + datapipe = datapipe.buffer() + datapipe = datapipe.transform(self._seeds_cooperative_exchange_4) datapipe = self.sampling_stages(datapipe, *args, **kwargs) datapipe = datapipe.transform(self._postprocess) super().__init__(datapipe) @@ -75,12 +108,16 @@ def _postprocess(minibatch): return minibatch @staticmethod - def _preprocess(minibatch, async_op: bool): + def _preprocess(minibatch, cooperative: bool, async_op: bool): if minibatch.seeds is None: raise ValueError( f"Invalid minibatch {minibatch}: `seeds` should have a value." ) - results = SubgraphSampler._seeds_preprocess(minibatch, async_op) + rank = thd.get_rank() if cooperative else 0 + world_size = thd.get_world_size() if cooperative else 1 + results = SubgraphSampler._seeds_preprocess( + minibatch, rank, world_size, async_op + ) if async_op: minibatch._preprocess_future = results else: @@ -88,17 +125,125 @@ def _preprocess(minibatch, async_op: bool): minibatch._seed_nodes, minibatch._seeds_timestamp, minibatch.compacted_seeds, + offsets, ) = results + if cooperative: + minibatch._seeds_offsets = offsets return minibatch @staticmethod - def _wait_preprocess_future(minibatch): + def _wait_preprocess_future(minibatch, cooperative: bool): ( minibatch._seed_nodes, minibatch._seeds_timestamp, minibatch.compacted_seeds, + offsets, ) = minibatch._preprocess_future.wait() delattr(minibatch, "_preprocess_future") + if cooperative: + minibatch._seeds_offsets = offsets + return minibatch + + @staticmethod + def _seeds_cooperative_exchange_1(minibatch, group=None): + rank = thd.get_rank(group) + world_size = thd.get_world_size(group) + assert world_size > 1 + seeds = minibatch._seed_nodes + is_homogeneous = not isinstance(seeds, dict) + if is_homogeneous: + seeds = {"_N": seeds} + if minibatch._seeds_offsets is None: + seeds_list = list(seeds.values()) + ( + sorted_seeds_list, + index_list, + offsets_list, + ) = torch.ops.graphbolt.rank_sort(seeds_list, rank, world_size) + assert minibatch.compacted_seeds is None + sorted_seeds, sorted_compacted, sorted_offsets = {}, {}, {} + num_ntypes = len(seeds.keys()) + for i, ( + seed_type, + typed_sorted_seeds, + typed_index, + typed_offsets, + ) in enumerate( + zip( + seeds.keys(), + sorted_seeds_list, + index_list, + offsets_list, + ) + ): + sorted_seeds[seed_type] = typed_sorted_seeds + sorted_compacted[seed_type] = typed_index + sorted_offsets[seed_type] = typed_offsets.tolist() + + minibatch._seed_nodes = sorted_seeds + minibatch.compacted_seeds = sorted_compacted + minibatch._seeds_offsets = sorted_offsets + else: + minibatch._seeds_offsets = {"_N": minibatch._seeds_offsets} + counts_sent = torch.empty(world_size * num_ntypes, dtype=torch.int64) + for i, offsets in enumerate(minibatch._seeds_offsets[0].values()): + counts_sent[ + torch.arange(i, world_size * num_ntypes, num_ntypes) + ] = offsets.diff() + delattr(minibatch, "_seeds_offsets") + counts_received = torch.empty_like(counts_sent) + minibatch._counts_future = all_to_all( + counts_received.split(num_ntypes), + counts_sent.split(num_ntypes), + group=group, + async_op=True, + ) + minibatch._counts_sent = counts_sent + minibatch._counts_received = counts_received + return minibatch + + @staticmethod + def _seeds_cooperative_exchange_2(minibatch, group=None): + world_size = thd.get_world_size(group) + seeds = minibatch._seed_nodes + minibatch._counts_future.wait() + delattr(minibatch, "_counts_future") + counts_received = minibatch._counts_received + num_ntypes = len(seeds.keys()) + seeds_received = {} + counts_sent = {} + counts_received = {} + for i, (ntype, typed_seeds) in enumerate(seeds.items()): + idx = torch.arange(i, world_size * num_ntypes, num_ntypes) + typed_counts_sent = minibatch._counts_sent[idx].tolist() + typed_counts_received = minibatch._counts_received[idx].tolist() + typed_seeds_received = typed_seeds.new_empty( + sum(typed_counts_received) + ) + all_to_all( + typed_seeds_received.split(typed_counts_received), + typed_seeds.split(typed_counts_sent), + group, + ) + seeds_received[ntype] = typed_seeds_received + minibatch._seed_nodes = _revert_to_homo(seeds_received) + minibatch._counts_sent = _revert_to_homo(counts_sent) + minibatch._counts_received = _revert_to_homo(counts_received) + return minibatch + + @staticmethod + def _seeds_cooperative_exchange_3(minibatch): + minibatch._unique_future = unique_and_compact( + minibatch._seed_nodes, 0, 1, async_op=True + ) + return minibatch + + @staticmethod + def _seeds_cooperative_exchange_4(minibatch): + unique_seeds, inverse_seeds, _ = minibatch._unique_future.wait() + delattr(minibatch, "_unique_future") + minibatch._seed_nodes = _revert_to_homo(unique_seeds) + minibatch._seed_inverse_ids = _revert_to_homo(inverse_seeds) return minibatch def _sample(self, minibatch): @@ -119,7 +264,12 @@ def sampling_stages(self, datapipe): return datapipe.transform(self._sample) @staticmethod - def _seeds_preprocess(minibatch, async_op): + def _seeds_preprocess( + minibatch: MiniBatch, + rank: int = 0, + world_size: int = 1, + async_op: bool = False, + ): """Preprocess `seeds` in a minibatch to construct `unique_seeds`, `node_timestamp` and `compacted_seeds` for further sampling. It optionally incorporates timestamps for temporal graphs, organizing and @@ -130,6 +280,11 @@ def _seeds_preprocess(minibatch, async_op): ---------- minibatch: MiniBatch The minibatch. + rank : int + The rank of the current process among cooperating processes. + world_size : int + The number of cooperating + (`arXiv:2210.13339`__) processes. async_op: bool Boolean indicating whether the call is asynchronous. If so, the result can be obtained by calling wait on the returned future. @@ -145,8 +300,16 @@ def _seeds_preprocess(minibatch, async_op): compacted_seeds: torch.tensor or a Dict[str, torch.Tensor] Representation of compacted seeds corresponding to 'seeds', where all node ids inside are compacted. + offsets: None or torch.Tensor or Dict[src, torch.Tensor] + The unique nodes offsets tensor partitions the unique_nodes tensor. + Has size `world_size + 1` and + `unique_nodes[offsets[i]: offsets[i + 1]]` belongs to the rank + `(rank + i) % world_size`. """ use_timestamp = hasattr(minibatch, "timestamp") + assert ( + not use_timestamp or world_size == 1 + ), "Temporal code path does not currently support Cooperative Minibatching" seeds = minibatch.seeds is_heterogeneous = isinstance(seeds, Dict) if is_heterogeneous: @@ -164,7 +327,7 @@ def _seeds_preprocess(minibatch, async_op): if hasattr(minibatch, "timestamp") else None ) - result = _NoOpWaiter((seeds, nodes_timestamp, None)) + result = _NoOpWaiter((seeds, nodes_timestamp, None, None)) break result = None assert typed_seeds.ndim == 2, ( @@ -200,7 +363,7 @@ def __init__(self, nodes, nodes_timestamp, seeds): ) else: self.future = unique_and_compact( - nodes, async_op=async_op + nodes, rank, world_size, async_op ) self.seeds = seeds @@ -208,8 +371,9 @@ def wait(self): """Returns the stored value when invoked.""" if use_timestamp: unique_seeds, nodes_timestamp, compacted = self.future + offsets = None else: - unique_seeds, compacted, _ = ( + unique_seeds, compacted, offsets = ( self.future.wait() if async_op else self.future ) nodes_timestamp = None @@ -234,6 +398,7 @@ def wait(self): unique_seeds, nodes_timestamp, compacted_seeds, + offsets, ) # When typed_seeds is not a one-dimensional tensor @@ -248,7 +413,7 @@ def wait(self): if hasattr(minibatch, "timestamp") else None ) - result = _NoOpWaiter((seeds, nodes_timestamp, None)) + result = _NoOpWaiter((seeds, nodes_timestamp, None, None)) else: # Collect nodes from all types of input. nodes = [seeds.view(-1)] @@ -289,8 +454,9 @@ def wait(self): nodes_timestamp, compacted, ) = self.future + offsets = None else: - unique_seeds, compacted, _ = ( + unique_seeds, compacted, offsets = ( self.future.wait() if async_op else self.future ) nodes_timestamp = None @@ -305,6 +471,7 @@ def wait(self): unique_seeds, nodes_timestamp, compacted_seeds, + offsets, ) result = _Waiter(nodes, nodes_timestamp, seeds) diff --git a/tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py b/tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py index 4a095ac1b10d..5326e620b0c3 100644 --- a/tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py +++ b/tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py @@ -65,7 +65,9 @@ def test_NeighborSampler_GraphFetch( graph.type_per_edge = None item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx()) fanout = torch.LongTensor([2]) - preprocess_fn = partial(gb.SubgraphSampler._preprocess, async_op=False) + preprocess_fn = partial( + gb.SubgraphSampler._preprocess, cooperative=False, async_op=False + ) datapipe = item_sampler.map(preprocess_fn) datapipe = datapipe.map( partial(gb.NeighborSampler._prepare, graph.node_type_to_id)