[GraphBolt][CUDA] Cooperative Minibatching initial exchange. #7795

Merged · 5 commits · Sep 11, 2024
1 change: 1 addition & 0 deletions python/dgl/graphbolt/internal/sample_utils.py
@@ -349,6 +349,7 @@ def wait(self):
if is_homogeneous:
compacted_csc_formats = list(compacted_csc_formats.values())[0]
unique_nodes = list(unique_nodes.values())[0]
offsets = list(offsets.values())[0]

return unique_nodes, compacted_csc_formats, offsets

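For context, a toy illustration (hypothetical values, not part of the diff) of the unwrapping this one-line change extends to `offsets`: in the homogeneous case all three returned values are stripped of their single-key "_N" dict so their structures match.

# Hypothetical values for illustration only.
unique_nodes = {"_N": [0, 3, 5]}
compacted_csc_formats = {"_N": [2, 0, 1]}
offsets = {"_N": [0, 2, 3]}
is_homogeneous = len(unique_nodes) == 1
if is_homogeneous:
    unique_nodes = list(unique_nodes.values())[0]
    compacted_csc_formats = list(compacted_csc_formats.values())[0]
    offsets = list(offsets.values())[0]  # the added line keeps offsets consistent
assert unique_nodes == [0, 3, 5] and offsets == [0, 2, 3]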
193 changes: 180 additions & 13 deletions python/dgl/graphbolt/subgraph_sampler.py
@@ -5,10 +5,12 @@
from typing import Dict

import torch
import torch.distributed as thd
from torch.utils.data import functional_datapipe

from .base import seed_type_str_to_ntypes
from .internal import compact_temporal_nodes, unique_and_compact
from .minibatch import MiniBatch
from .minibatch_transformer import MiniBatchTransformer

__all__ = [
@@ -28,6 +30,25 @@ def wait(self):
return result


def _shift(inputs: list, group=None):
cutoff = len(inputs) - thd.get_rank(group)
return inputs[cutoff:] + inputs[:cutoff]


def all_to_all(outputs, inputs, group=None, async_op=False):
"""Wrapper for thd.all_to_all that permuted outputs and inputs before
calling it. The arguments have the permutation
`rank, ..., world_size - 1, 0, ..., rank - 1` and we make it
`0, world_size - 1` before calling `thd.all_to_all`."""
shift_fn = partial(_shift, group=group)
return thd.all_to_all(shift_fn(outputs), shift_fn(inputs), group, async_op)


def _revert_to_homo(d: dict):
is_homogeneous = len(d) == 1 and "_N" in d
return list(d.values())[0] if is_homogeneous else d
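A quick aside (illustration only, not part of the diff): the rotation `_shift` performs for `all_to_all` can be checked with plain Python lists and no process group; `shift_for_rank` below is a hypothetical stand-in for `_shift` with the rank passed explicitly.

# Buffers produced in rank-relative order [rank, ..., world_size - 1, 0, ..., rank - 1]
# are rotated back to the canonical [0, ..., world_size - 1] order that
# torch.distributed.all_to_all expects.
def shift_for_rank(buffers, rank):
    cutoff = len(buffers) - rank
    return buffers[cutoff:] + buffers[:cutoff]

world_size = 4
for rank in range(world_size):
    # buffers[i] is destined for rank (rank + i) % world_size.
    buffers = [f"to_{(rank + i) % world_size}" for i in range(world_size)]
    assert shift_for_rank(buffers, rank) == [f"to_{i}" for i in range(world_size)]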


@functional_datapipe("sample_subgraph")
class SubgraphSampler(MiniBatchTransformer):
"""A subgraph sampler used to sample a subgraph from a given set of nodes
@@ -49,8 +70,8 @@ class SubgraphSampler(MiniBatchTransformer):
Arguments to be passed into sampling_stages.
kwargs : Keyword Arguments
Arguments to be passed into sampling_stages. Preprocessing stage makes
use of the `asynchronous` and `cooperative` parameters before they are
passed to the sampling stages.
"""

def __init__(
@@ -60,10 +81,22 @@ def __init__(
**kwargs,
):
async_op = kwargs.get("asynchronous", False)
cooperative = kwargs.get("cooperative", False)
preprocess_fn = partial(
self._preprocess, cooperative=cooperative, async_op=async_op
)
datapipe = datapipe.transform(preprocess_fn)
if async_op:
fn = partial(self._wait_preprocess_future, cooperative=cooperative)
datapipe = datapipe.buffer().transform(fn)
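# Orientation note (comment added in this writeup, not in the original diff):
# stage 1 below rank-sorts the seeds and launches an asynchronous all_to_all of
# per-peer counts, stage 2 waits on it and exchanges the variable-sized seed
# payloads, stage 3 launches an asynchronous unique_and_compact on the received
# seeds, and stage 4 waits on it and stores the deduplicated seeds and their
# inverse indices.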
if cooperative:
datapipe = datapipe.transform(self._seeds_cooperative_exchange_1)
datapipe = datapipe.buffer()
datapipe = datapipe.transform(self._seeds_cooperative_exchange_2)
datapipe = datapipe.buffer()
datapipe = datapipe.transform(self._seeds_cooperative_exchange_3)
datapipe = datapipe.buffer()
datapipe = datapipe.transform(self._seeds_cooperative_exchange_4)
datapipe = self.sampling_stages(datapipe, *args, **kwargs)
datapipe = datapipe.transform(self._postprocess)
super().__init__(datapipe)
@@ -75,30 +108,142 @@ def _postprocess(minibatch):
return minibatch

@staticmethod
def _preprocess(minibatch, cooperative: bool, async_op: bool):
if minibatch.seeds is None:
raise ValueError(
f"Invalid minibatch {minibatch}: `seeds` should have a value."
)
rank = thd.get_rank() if cooperative else 0
world_size = thd.get_world_size() if cooperative else 1
results = SubgraphSampler._seeds_preprocess(
minibatch, rank, world_size, async_op
)
if async_op:
minibatch._preprocess_future = results
else:
(
minibatch._seed_nodes,
minibatch._seeds_timestamp,
minibatch.compacted_seeds,
offsets,
) = results
if cooperative:
minibatch._seeds_offsets = offsets
return minibatch

@staticmethod
def _wait_preprocess_future(minibatch, cooperative: bool):
(
minibatch._seed_nodes,
minibatch._seeds_timestamp,
minibatch.compacted_seeds,
offsets,
) = minibatch._preprocess_future.wait()
delattr(minibatch, "_preprocess_future")
if cooperative:
minibatch._seeds_offsets = offsets
return minibatch

@staticmethod
def _seeds_cooperative_exchange_1(minibatch, group=None):
rank = thd.get_rank(group)
world_size = thd.get_world_size(group)
assert world_size > 1
seeds = minibatch._seed_nodes
is_homogeneous = not isinstance(seeds, dict)
if is_homogeneous:
seeds = {"_N": seeds}
if minibatch._seeds_offsets is None:
seeds_list = list(seeds.values())
(
sorted_seeds_list,
index_list,
offsets_list,
) = torch.ops.graphbolt.rank_sort(seeds_list, rank, world_size)
assert minibatch.compacted_seeds is None
sorted_seeds, sorted_compacted, sorted_offsets = {}, {}, {}
num_ntypes = len(seeds.keys())
for i, (
seed_type,
typed_sorted_seeds,
typed_index,
typed_offsets,
) in enumerate(
zip(
seeds.keys(),
sorted_seeds_list,
index_list,
offsets_list,
)
):
sorted_seeds[seed_type] = typed_sorted_seeds
sorted_compacted[seed_type] = typed_index
sorted_offsets[seed_type] = typed_offsets

minibatch._seed_nodes = sorted_seeds
minibatch.compacted_seeds = sorted_compacted
minibatch._seeds_offsets = sorted_offsets
else:
minibatch._seeds_offsets = {"_N": minibatch._seeds_offsets}
counts_sent = torch.empty(world_size * num_ntypes, dtype=torch.int64)
for i, offsets in enumerate(minibatch._seeds_offsets.values()):
counts_sent[
torch.arange(i, world_size * num_ntypes, num_ntypes)
] = offsets.diff()
delattr(minibatch, "_seeds_offsets")
counts_received = torch.empty_like(counts_sent)
minibatch._counts_future = all_to_all(
counts_received.split(num_ntypes),
counts_sent.split(num_ntypes),
group=group,
async_op=True,
)
minibatch._counts_sent = counts_sent
minibatch._counts_received = counts_received
return minibatch

@staticmethod
def _seeds_cooperative_exchange_2(minibatch, group=None):
world_size = thd.get_world_size(group)
seeds = minibatch._seed_nodes
minibatch._counts_future.wait()
delattr(minibatch, "_counts_future")
num_ntypes = len(seeds.keys())
seeds_received = {}
counts_sent = {}
counts_received = {}
for i, (ntype, typed_seeds) in enumerate(seeds.items()):
idx = torch.arange(i, world_size * num_ntypes, num_ntypes)
typed_counts_sent = minibatch._counts_sent[idx].tolist()
typed_counts_received = minibatch._counts_received[idx].tolist()
typed_seeds_received = typed_seeds.new_empty(
sum(typed_counts_received)
)
all_to_all(
typed_seeds_received.split(typed_counts_received),
typed_seeds.split(typed_counts_sent),
group,
)
seeds_received[ntype] = typed_seeds_received
counts_sent[ntype] = typed_counts_sent
counts_received[ntype] = typed_counts_received
minibatch._seed_nodes = _revert_to_homo(seeds_received)
minibatch._counts_sent = _revert_to_homo(counts_sent)
minibatch._counts_received = _revert_to_homo(counts_received)
return minibatch
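To make the layout of the counts tensor concrete, here is a standalone sketch with hypothetical numbers (two node types, `world_size` of 3, made-up offsets) showing how the interleaving used above lets `split(num_ntypes)` hand one per-type count chunk to each peer rank.

import torch

world_size, num_ntypes = 3, 2
# typed_offsets.diff() gives the number of seeds of that type destined for each peer.
offsets = {"A": torch.tensor([0, 4, 6, 9]), "B": torch.tensor([0, 1, 3, 3])}
counts_sent = torch.empty(world_size * num_ntypes, dtype=torch.int64)
for i, typed_offsets in enumerate(offsets.values()):
    idx = torch.arange(i, world_size * num_ntypes, num_ntypes)
    counts_sent[idx] = typed_offsets.diff()
assert counts_sent.tolist() == [4, 1, 2, 2, 3, 0]
# Chunk j holds [count_of_A_for_peer_j, count_of_B_for_peer_j].
assert [c.tolist() for c in counts_sent.split(num_ntypes)] == [[4, 1], [2, 2], [3, 0]]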

@staticmethod
def _seeds_cooperative_exchange_3(minibatch):
minibatch._unique_future = unique_and_compact(
minibatch._seed_nodes, 0, 1, async_op=True
)
return minibatch

@staticmethod
def _seeds_cooperative_exchange_4(minibatch):
unique_seeds, inverse_seeds, _ = minibatch._unique_future.wait()
delattr(minibatch, "_unique_future")
minibatch._seed_nodes = _revert_to_homo(unique_seeds)
minibatch._seed_inverse_ids = _revert_to_homo(inverse_seeds)
return minibatch

def _sample(self, minibatch):
@@ -119,7 +264,12 @@ def sampling_stages(self, datapipe):
return datapipe.transform(self._sample)

@staticmethod
def _seeds_preprocess(
minibatch: MiniBatch,
rank: int = 0,
world_size: int = 1,
async_op: bool = False,
):
"""Preprocess `seeds` in a minibatch to construct `unique_seeds`,
`node_timestamp` and `compacted_seeds` for further sampling. It
optionally incorporates timestamps for temporal graphs, organizing and
@@ -130,6 +280,11 @@ def _seeds_preprocess(minibatch, async_op):
----------
minibatch: MiniBatch
The minibatch.
rank : int
The rank of the current process among cooperating processes.
world_size : int
The number of cooperating
(`arXiv:2310.12403 <https://arxiv.org/abs/2310.12403>`__) processes.
async_op: bool
Boolean indicating whether the call is asynchronous. If so, the
result can be obtained by calling wait on the returned future.
@@ -145,8 +300,16 @@
compacted_seeds: torch.tensor or a Dict[str, torch.Tensor]
Representation of compacted seeds corresponding to 'seeds', where
all node ids inside are compacted.
offsets: None or torch.Tensor or Dict[str, torch.Tensor]
The offsets tensor partitions the unique_seeds tensor. It has size
`world_size + 1`, and `unique_seeds[offsets[i]: offsets[i + 1]]` belongs
to the rank `(rank + i) % world_size`.
"""
use_timestamp = hasattr(minibatch, "timestamp")
assert (
not use_timestamp or world_size == 1
), "Temporal code path does not currently support Cooperative Minibatching"
seeds = minibatch.seeds
is_heterogeneous = isinstance(seeds, Dict)
if is_heterogeneous:
@@ -164,7 +327,7 @@
if hasattr(minibatch, "timestamp")
else None
)
result = _NoOpWaiter((seeds, nodes_timestamp, None, None))
break
result = None
assert typed_seeds.ndim == 2, (
@@ -200,16 +363,17 @@ def __init__(self, nodes, nodes_timestamp, seeds):
)
else:
self.future = unique_and_compact(
nodes, rank, world_size, async_op
)
self.seeds = seeds

def wait(self):
"""Returns the stored value when invoked."""
if use_timestamp:
unique_seeds, nodes_timestamp, compacted = self.future
offsets = None
else:
unique_seeds, compacted, offsets = (
self.future.wait() if async_op else self.future
)
nodes_timestamp = None
@@ -234,6 +398,7 @@ def wait(self):
unique_seeds,
nodes_timestamp,
compacted_seeds,
offsets,
)

# When typed_seeds is not a one-dimensional tensor
@@ -248,7 +413,7 @@ def wait(self):
if hasattr(minibatch, "timestamp")
else None
)
result = _NoOpWaiter((seeds, nodes_timestamp, None, None))
else:
# Collect nodes from all types of input.
nodes = [seeds.view(-1)]
@@ -289,8 +454,9 @@ def wait(self):
nodes_timestamp,
compacted,
) = self.future
offsets = None
else:
unique_seeds, compacted, offsets = (
self.future.wait() if async_op else self.future
)
nodes_timestamp = None
@@ -305,6 +471,7 @@
unique_seeds,
nodes_timestamp,
compacted_seeds,
offsets,
)

result = _Waiter(nodes, nodes_timestamp, seeds)
4 changes: 3 additions & 1 deletion tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py
@@ -65,7 +65,9 @@ def test_NeighborSampler_GraphFetch(
graph.type_per_edge = None
item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
fanout = torch.LongTensor([2])
preprocess_fn = partial(
gb.SubgraphSampler._preprocess, cooperative=False, async_op=False
)
datapipe = item_sampler.map(preprocess_fn)
datapipe = datapipe.map(
partial(gb.NeighborSampler._prepare, graph.node_type_to_id)
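For readers who want to try the communication pattern in isolation, below is a hypothetical standalone sketch (not DGL code; `exchange_seeds`, the file name and the `seed % world_size` partition rule are all made up) of the two-phase exchange that stages 1 and 2 implement: trade the per-peer counts first, then the variable-sized seed payloads. It assumes one GPU per process on a single node and an all_to_all-capable backend such as NCCL, launched with torchrun.

# demo_exchange.py -- standalone sketch, not part of this PR.
# Run: torchrun --nproc_per_node=<num_gpus> demo_exchange.py
import os

import torch
import torch.distributed as dist


def exchange_seeds(seeds_per_peer):
    """seeds_per_peer[j] holds the seed IDs this rank routes to peer j."""
    counts_sent = torch.tensor(
        [len(s) for s in seeds_per_peer], device=seeds_per_peer[0].device
    )
    counts_received = torch.empty_like(counts_sent)
    # Phase 1: exchange how many seeds each pair of ranks will trade.
    dist.all_to_all_single(counts_received, counts_sent)
    payload = torch.cat(seeds_per_peer)
    received = payload.new_empty(int(counts_received.sum().item()))
    # Phase 2: exchange the variable-sized seed payloads themselves.
    dist.all_to_all(
        list(received.split(counts_received.tolist())),
        list(payload.split(counts_sent.tolist())),
    )
    return received


if __name__ == "__main__":
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    dist.init_process_group("nccl")
    rank, world_size = dist.get_rank(), dist.get_world_size()
    # Toy partition rule: seed s belongs to rank (s % world_size).
    local_seeds = torch.randint(0, 1000, (16,), device="cuda")
    per_peer = [local_seeds[local_seeds % world_size == j] for j in range(world_size)]
    owned = exchange_seeds(per_peer)
    print(f"rank {rank} now owns {owned.numel()} seeds")
    dist.destroy_process_group()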