[GraphBolt][CUDA] Add CooperativeConv and minor fixes. (#7797)
mfbalin authored Sep 13, 2024
1 parent 189b83c commit 864b023
Showing 5 changed files with 173 additions and 3 deletions.
1 change: 1 addition & 0 deletions python/dgl/graphbolt/impl/__init__.py
@@ -15,3 +15,4 @@
from .gpu_graph_cache import *
from .cpu_feature_cache import *
from .cpu_cached_feature import *
from .cooperative_conv import *
109 changes: 109 additions & 0 deletions python/dgl/graphbolt/impl/cooperative_conv.py
@@ -0,0 +1,109 @@
"""Graphbolt cooperative convolution."""
from typing import Dict, Union

import torch

from ..sampled_subgraph import SampledSubgraph
from ..subgraph_sampler import all_to_all, convert_to_hetero, revert_to_homo

__all__ = ["CooperativeConvFunction", "CooperativeConv"]


class CooperativeConvFunction(torch.autograd.Function):
    """Cooperative convolution operation from Cooperative Minibatching.

    Implements the `all-to-all` message passing algorithm in Cooperative
    Minibatching, which was initially proposed in `Deep Graph Library PR#4337
    <https://github.com/dmlc/dgl/pull/4337>`__ and was later described in full
    in `Cooperative Minibatching in Graph Neural Networks
    <https://arxiv.org/abs/2310.12403>`__. Cooperation between the GPUs
    eliminates duplicate work performed across the GPUs due to the overlapping
    sampled k-hop neighborhoods of seed nodes when performing GNN minibatching.
    This reduces the redundant computations across GPUs at the expense of
    communication.
    """

    @staticmethod
    def forward(
        ctx,
        subgraph: SampledSubgraph,
        tensor: Union[torch.Tensor, Dict[str, torch.Tensor]],
    ):
        """Implements the forward pass."""
        counts_sent = convert_to_hetero(subgraph._counts_sent)
        counts_received = convert_to_hetero(subgraph._counts_received)
        seed_inverse_ids = convert_to_hetero(subgraph._seed_inverse_ids)
        seed_sizes = convert_to_hetero(subgraph._seed_sizes)
        # The communication metadata consists of dictionaries, not plain
        # tensors, so it is stashed on ctx directly instead of going through
        # ctx.save_for_backward, which accepts only tensors.
        ctx.communication_variables = (
            counts_sent,
            counts_received,
            seed_inverse_ids,
            seed_sizes,
        )
        outs = {}
        for ntype, typed_tensor in convert_to_hetero(tensor).items():
            out = typed_tensor.new_empty(
                (sum(counts_sent[ntype]),) + typed_tensor.shape[1:]
            )
            # Exchange rows with the other ranks: each rank sends the rows its
            # peers requested and receives the rows it requested from them.
            all_to_all(
                torch.split(out, counts_sent[ntype]),
                torch.split(
                    typed_tensor[seed_inverse_ids[ntype]],
                    counts_received[ntype],
                ),
            )
            outs[ntype] = out
        return revert_to_homo(outs)

    @staticmethod
    def backward(
        ctx, grad_output: Union[torch.Tensor, Dict[str, torch.Tensor]]
    ):
        """Implements the backward pass."""
        (
            counts_sent,
            counts_received,
            seed_inverse_ids,
            seed_sizes,
        ) = ctx.communication_variables
        outs = {}
        for ntype, typed_grad_output in convert_to_hetero(grad_output).items():
            out = typed_grad_output.new_empty(
                (sum(counts_received[ntype]),) + typed_grad_output.shape[1:]
            )
            # Reverse of the forward exchange: send the gradients back to the
            # ranks the corresponding rows were received from.
            all_to_all(
                torch.split(out, counts_received[ntype]),
                torch.split(typed_grad_output, counts_sent[ntype]),
            )
            # Multiple received rows may map to the same local seed, so the
            # gradients are scatter-added with a 0/1 sparse matrix holding an
            # entry at (dst, src) for every received row: row dst of the
            # product sums all gradient rows that map to seed dst.
            i = out.new_empty(2, out.shape[0], dtype=torch.int64)
            i[0] = seed_inverse_ids[ntype]  # dst (row indices)
            i[1] = torch.arange(
                out.shape[0], device=typed_grad_output.device
            )  # src (column indices)
            coo = torch.sparse_coo_tensor(
                i,
                torch.ones(i.shape[1], dtype=out.dtype, device=out.device),
                size=(seed_sizes[ntype], i.shape[1]),
            )
            outs[ntype] = torch.sparse.mm(coo, out)
        return None, revert_to_homo(outs)


class CooperativeConv(torch.nn.Module):
    """Cooperative convolution operation from Cooperative Minibatching.

    Implements the `all-to-all` message passing algorithm in Cooperative
    Minibatching, which was initially proposed in `Deep Graph Library PR#4337
    <https://github.com/dmlc/dgl/pull/4337>`__ and was later described in full
    in `Cooperative Minibatching in Graph Neural Networks
    <https://arxiv.org/abs/2310.12403>`__. Cooperation between the GPUs
    eliminates duplicate work performed across the GPUs due to the overlapping
    sampled k-hop neighborhoods of seed nodes when performing GNN minibatching.
    This reduces the redundant computations across GPUs at the expense of
    communication.
    """

    def forward(
        self,
        subgraph: SampledSubgraph,
        x: Union[torch.Tensor, Dict[str, torch.Tensor]],
    ):
        """Implements the forward pass."""
        return CooperativeConvFunction.apply(subgraph, x)
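
The subtlest step in the new file is the backward scatter-add: the reverse all-to-all returns one gradient row per received copy of a seed, and every row that maps to the same local seed must be summed. A self-contained toy sketch of that sparse-matmul trick (the values are invented for illustration):

import torch

# Three received gradient rows map onto two local seeds: rows 0 and 2
# belong to seed 0, row 1 to seed 1.
seed_inverse_ids = torch.tensor([0, 1, 0])
out = torch.tensor([[1.0], [2.0], [3.0]])  # received gradient rows
num_seeds = 2

src = torch.arange(out.shape[0])
indices = torch.stack([seed_inverse_ids, src])  # one (dst, src) pair per row
values = torch.ones(indices.shape[1])
coo = torch.sparse_coo_tensor(indices, values, size=(num_seeds, out.shape[0]))
result = torch.sparse.mm(coo, out)  # row d sums all rows with dst == d
assert result.flatten().tolist() == [4.0, 2.0]  # rows 0 and 2 summed into seed 0

The same reduction could be expressed with `torch.Tensor.index_add_`; the sparse form packs it into a single matmul.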
32 changes: 29 additions & 3 deletions python/dgl/graphbolt/impl/neighbor_sampler.py
@@ -601,17 +601,18 @@ def _seeds_cooperative_exchange_2(minibatch):
                typed_seeds.split(typed_counts_sent),
            )
            seeds_received[ntype] = typed_seeds_received
        subgraph._seeds_received = seeds_received
            counts_sent[ntype] = typed_counts_sent
            counts_received[ntype] = typed_counts_received
        minibatch._seed_nodes = seeds_received
        subgraph._counts_sent = revert_to_homo(counts_sent)
        subgraph._counts_received = revert_to_homo(counts_received)
        return minibatch

    @staticmethod
    def _seeds_cooperative_exchange_3(minibatch):
        subgraph = minibatch.sampled_subgraphs[0]
        nodes = {
            ntype: [typed_seeds]
            for ntype, typed_seeds in subgraph._seeds_received.items()
            for ntype, typed_seeds in minibatch._seed_nodes.items()
        }
        minibatch._unique_future = unique_and_compact(
            nodes, 0, 1, async_op=True
@@ -627,6 +628,11 @@ def _seeds_cooperative_exchange_4(minibatch):
        }
        minibatch._seed_nodes = revert_to_homo(unique_seeds)
        subgraph = minibatch.sampled_subgraphs[0]
        sizes = {
            ntype: typed_seeds.size(0)
            for ntype, typed_seeds in unique_seeds.items()
        }
        subgraph._seed_sizes = revert_to_homo(sizes)
        subgraph._seed_inverse_ids = revert_to_homo(inverse_seeds)
        return minibatch

@@ -831,6 +837,16 @@ class NeighborSampler(NeighborSamplerImpl):
    gpu_cache_threshold : int, optional
        Determines how many times a vertex needs to be accessed before its
        neighborhood ends up being cached on the GPU.
    cooperative: bool, optional
        Boolean indicating whether Cooperative Minibatching will be used,
        which was initially proposed in `Deep Graph Library PR#4337
        <https://github.com/dmlc/dgl/pull/4337>`__ and was later described in
        full in `Cooperative Minibatching in Graph Neural Networks
        <https://arxiv.org/abs/2310.12403>`__. Cooperation between the GPUs
        eliminates duplicate work performed across the GPUs due to the
        overlapping sampled k-hop neighborhoods of seed nodes when performing
        GNN minibatching.
    asynchronous: bool
        Boolean indicating whether sampling and compaction stages should run
        in background threads to hide the latency of CPU-GPU synchronization.
@@ -986,6 +1002,16 @@ class LayerNeighborSampler(NeighborSamplerImpl):
    gpu_cache_threshold : int, optional
        Determines how many times a vertex needs to be accessed before its
        neighborhood ends up being cached on the GPU.
    cooperative: bool, optional
        Boolean indicating whether Cooperative Minibatching will be used,
        which was initially proposed in `Deep Graph Library PR#4337
        <https://github.com/dmlc/dgl/pull/4337>`__ and was later described in
        full in `Cooperative Minibatching in Graph Neural Networks
        <https://arxiv.org/abs/2310.12403>`__. Cooperation between the GPUs
        eliminates duplicate work performed across the GPUs due to the
        overlapping sampled k-hop neighborhoods of seed nodes when performing
        GNN minibatching.
    asynchronous: bool
        Boolean indicating whether sampling and compaction stages should run
        in background threads to hide the latency of CPU-GPU synchronization.
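
For orientation, a minimal sketch of where the new `cooperative` flag plugs into a GraphBolt pipeline. The graph, seeds, and fanouts are toy placeholders, and the datapipe chain and `names="seeds"` follow the usual GraphBolt pattern rather than this diff; cooperative mode additionally assumes an initialized `torch.distributed` process group with one rank per GPU, as in the test further down.

import torch
import dgl.graphbolt as gb

# Toy graph in CSC form: 3 nodes, each with 2 in-edges (placeholder data).
indptr = torch.tensor([0, 2, 4, 6])
indices = torch.tensor([1, 2, 0, 2, 0, 1])
graph = gb.fused_csc_sampling_graph(indptr, indices).to("cuda")
itemset = gb.ItemSet(torch.arange(3), names="seeds")

datapipe = gb.ItemSampler(itemset, batch_size=2, shuffle=True)
datapipe = datapipe.copy_to(torch.device("cuda"))
# The new flag: ranks exchange their sampled seeds so that overlapping k-hop
# neighborhoods are expanded only once across the GPUs.
datapipe = datapipe.sample_neighbor(graph, [2, 2], cooperative=True)
dataloader = gb.DataLoader(datapipe)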
15 changes: 15 additions & 0 deletions python/dgl/graphbolt/subgraph_sampler.py
@@ -16,6 +16,7 @@
__all__ = [
    "SubgraphSampler",
    "all_to_all",
    "convert_to_hetero",
    "revert_to_homo",
]

@@ -89,6 +90,13 @@ def revert_to_homo(d: dict):
    return list(d.values())[0] if is_homogenous else d


def convert_to_hetero(item):
    """Utility function to convert homogeneous data to heterogeneous with a
    single node type."""
    is_heterogenous = isinstance(item, dict)
    return item if is_heterogenous else {"_N": item}


@functional_datapipe("sample_subgraph")
class SubgraphSampler(MiniBatchTransformer):
    """A subgraph sampler used to sample a subgraph from a given set of nodes
@@ -251,6 +259,8 @@ def _seeds_cooperative_exchange_2(minibatch, group=None):
                group,
            )
            seeds_received[ntype] = typed_seeds_received
            counts_sent[ntype] = typed_counts_sent
            counts_received[ntype] = typed_counts_received
        minibatch._seed_nodes = seeds_received
        minibatch._counts_sent = revert_to_homo(counts_sent)
        minibatch._counts_received = revert_to_homo(counts_received)
@@ -275,6 +285,11 @@ def _seeds_cooperative_exchange_4(minibatch):
            ntype: typed_inv[0] for ntype, typed_inv in inverse_seeds.items()
        }
        minibatch._seed_nodes = revert_to_homo(unique_seeds)
        sizes = {
            ntype: typed_seeds.size(0)
            for ntype, typed_seeds in unique_seeds.items()
        }
        minibatch._seed_sizes = revert_to_homo(sizes)
        minibatch._seed_inverse_ids = revert_to_homo(inverse_seeds)
        return minibatch

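
To make the helpers' contract concrete, a standalone round-trip sketch; these are re-implementations for illustration only, and `revert_to_homo`'s homogeneity check (a lone `"_N"` entry) is an assumption mirroring `convert_to_hetero` above:

import torch

def convert_to_hetero(item):
    # Wrap a homogeneous tensor as a single-node-type dictionary.
    return item if isinstance(item, dict) else {"_N": item}

def revert_to_homo(d):
    # Assumed check: a lone "_N" entry means the data was homogeneous.
    is_homogenous = len(d) == 1 and "_N" in d
    return list(d.values())[0] if is_homogenous else d

x = torch.arange(4)
assert torch.equal(revert_to_homo(convert_to_hetero(x)), x)  # round trip is identity
hetero = {"user": torch.arange(2), "item": torch.arange(3)}
assert revert_to_homo(convert_to_hetero(hetero)) is hetero  # dicts pass through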
19 changes: 19 additions & 0 deletions tests/python/pytorch/graphbolt/test_dataloader.py
@@ -6,6 +6,7 @@

import dgl
import dgl.graphbolt
import dgl.graphbolt as gb
import pytest
import torch
import torch.distributed as thd
@@ -194,5 +195,23 @@ def test_gpu_sampling_DataLoader(
    if sampler_name == "LayerNeighborSampler":
        assert torch.equal(edge_feature, edge_feature_ref)
    assert len(list(dataloader)) == N // B

    if asynchronous and cooperative:
        for minibatch in minibatches:
            x = torch.ones((minibatch.node_ids().size(0), 1), device=F.ctx())
            for subgraph in minibatch.sampled_subgraphs:
                x = gb.CooperativeConvFunction.apply(subgraph, x)
                x, edge_index, size = subgraph.to_pyg(x)
                x = x[0]
                one = torch.ones(
                    edge_index.shape[1], dtype=x.dtype, device=x.device
                )
                coo = torch.sparse_coo_tensor(
                    edge_index.flipud(), one, size=(size[1], size[0])
                )
                x = torch.sparse.mm(coo, x)
            assert x.shape[0] == minibatch.seeds.shape[0]
            assert x.shape[1] == 1

    if thd.is_initialized():
        thd.destroy_process_group()
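
One note on the check above: `edge_index.flipud()` turns the `(src, dst)` rows into `(dst, src)`, so each `torch.sparse.mm` hop sums incoming source features per destination; with all-ones features that is just the in-degree, and after one hop per sampled layer exactly one row per seed remains. A standalone toy version of the aggregation:

import torch

edge_index = torch.tensor([[0, 1, 2], [1, 2, 2]])  # row 0: src, row 1: dst
x = torch.ones(3, 1)
one = torch.ones(edge_index.shape[1], dtype=x.dtype)
# flipud -> (dst, src): entry (d, s) routes node s's features into node d.
coo = torch.sparse_coo_tensor(edge_index.flipud(), one, size=(3, 3))
out = torch.sparse.mm(coo, x)
assert out.flatten().tolist() == [0.0, 1.0, 2.0]  # in-degrees of nodes 0..2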
