[GraphBolt][CUDA] Add CooperativeConv.
mfbalin committed Sep 12, 2024
1 parent 189b83c commit 6a02b52
Showing 3 changed files with 67 additions and 3 deletions.
35 changes: 35 additions & 0 deletions python/dgl/graphbolt/impl/cooperative_conv.py
@@ -0,0 +1,35 @@
import torch

from ..sampled_subgraph import SampledSubgraph
from ..subgraph_sampler import all_to_all

class CooperativeConvFunction(torch.autograd.Function):
    """Differentiable all-to-all exchange for Cooperative Minibatching."""

    @staticmethod
    def forward(ctx, subgraph: SampledSubgraph, h):
        counts_sent = subgraph._counts_sent
        counts_received = subgraph._counts_received
        seed_inverse_ids = subgraph._seed_inverse_ids
        seed_sizes = subgraph._seed_sizes
        # The counts and sizes are not necessarily tensors, so stash them on
        # ctx directly instead of save_for_backward(), which accepts tensors only.
        ctx.communication_variables = (
            counts_sent, counts_received, seed_inverse_ids, seed_sizes
        )
        # Expand the unique seed embeddings to the duplicated seeds and send
        # each slice back to the GPU that requested it during sampling.
        out = h.new_empty((sum(counts_sent),) + h.shape[1:])
        all_to_all(
            torch.split(out, counts_sent),
            torch.split(h[seed_inverse_ids], counts_received),
        )
        return out

    @staticmethod
    def backward(ctx, grad_output):
        (
            counts_sent, counts_received, seed_inverse_ids, seed_sizes
        ) = ctx.communication_variables
        # Reverse the forward exchange so each gradient reaches the GPU that
        # owns the corresponding unique seed.
        out = grad_output.new_empty((sum(counts_received),) + grad_output.shape[1:])
        all_to_all(
            torch.split(out, counts_received),
            torch.split(grad_output, counts_sent),
        )
        # Scatter-add duplicate gradients into the unique seeds with a sparse
        # matmul: rout[j] = sum of out[p] over p with seed_inverse_ids[p] == j.
        i = out.new_empty(2, out.shape[0], dtype=torch.int64)
        i[0] = seed_inverse_ids  # dst (row: unique seed)
        i[1] = torch.arange(out.shape[0], device=grad_output.device)  # src
        coo = torch.sparse_coo_tensor(
            i, out.new_ones(i.shape[1]), size=(seed_sizes, i.shape[1])
        )
        rout = torch.sparse.mm(coo, out)
        return None, rout

class CooperativeConv(torch.nn.Module):
    """Module wrapper so the cooperative exchange can be used as a layer."""

    def forward(self, subgraph: SampledSubgraph, x):
        return CooperativeConvFunction.apply(subgraph, x)
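
A minimal usage sketch (not part of the commit; the dataloader setup, the layer list, and the per-layer placement of the exchange are assumptions following the Cooperative Minibatching paper):

from dgl.graphbolt.impl.cooperative_conv import CooperativeConv

cooperative_conv = CooperativeConv()

def forward_pass(minibatch, h, layers):
    # `minibatch` is assumed to come from a dataloader running with
    # cooperative=True, so each sampled subgraph carries the exchange
    # metadata (_counts_sent, _counts_received, _seed_inverse_ids, ...).
    for layer, subgraph in zip(layers, minibatch.sampled_subgraphs):
        # Redistribute embeddings across the GPUs along the seed ownership
        # recorded during sampling, then apply the local GNN layer.
        h = cooperative_conv(subgraph, h)
        h = layer(subgraph, h)
    return h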
30 changes: 27 additions & 3 deletions python/dgl/graphbolt/impl/neighbor_sampler.py
@@ -601,17 +601,16 @@ def _seeds_cooperative_exchange_2(minibatch):
typed_seeds.split(typed_counts_sent),
)
seeds_received[ntype] = typed_seeds_received
subgraph._seeds_received = seeds_received
minibatch._seed_nodes = seeds_received
subgraph._counts_sent = revert_to_homo(counts_sent)
subgraph._counts_received = revert_to_homo(counts_received)
return minibatch

@staticmethod
def _seeds_cooperative_exchange_3(minibatch):
subgraph = minibatch.sampled_subgraphs[0]
nodes = {
ntype: [typed_seeds]
for ntype, typed_seeds in subgraph._seeds_received.items()
for ntype, typed_seeds in minibatch._seed_nodes.items()
}
minibatch._unique_future = unique_and_compact(
nodes, 0, 1, async_op=True
@@ -627,6 +626,11 @@ def _seeds_cooperative_exchange_4(minibatch):
}
minibatch._seed_nodes = revert_to_homo(unique_seeds)
subgraph = minibatch.sampled_subgraphs[0]
sizes = {
ntype: typed_seeds.size(0)
for ntype, typed_seeds in unique_seeds.items()
}
subgraph._seed_sizes = revert_to_homo(sizes)
subgraph._seed_inverse_ids = revert_to_homo(inverse_seeds)
return minibatch
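
As an illustration of the bookkeeping above (a standalone sketch, not part of the diff): `unique_and_compact` deduplicates the seeds received from the other GPUs, `_seed_inverse_ids` maps every duplicate back to its unique seed, and the new `_seed_sizes` records the number of unique seeds per node type. `torch.unique` with `return_inverse=True` exhibits the same relationship:

import torch

# Hypothetical seeds received on one GPU, with duplicates across requesters.
seeds_received = torch.tensor([7, 3, 7, 9, 3])
unique_seeds, inverse_ids = torch.unique(seeds_received, return_inverse=True)
seed_size = unique_seeds.size(0)  # what _seed_sizes records per node type

# Every duplicate recovers its unique seed through the inverse ids.
assert torch.equal(unique_seeds[inverse_ids], seeds_received)
print(unique_seeds)  # tensor([3, 7, 9])
print(inverse_ids)   # tensor([1, 0, 1, 2, 0])
print(seed_size)     # 3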

@@ -831,6 +835,16 @@ class NeighborSampler(NeighborSamplerImpl):
gpu_cache_threshold : int, optional
Determines how many times a vertex needs to be accessed before its
neighborhood ends up being cached on the GPU.
cooperative: bool, optional
Boolean indicating whether Cooperative Minibatching should be used. It was
initially proposed in
`Deep Graph Library PR#4337 <https://github.com/dmlc/dgl/pull/4337>`__
and was later fully described in
`Cooperative Minibatching in Graph Neural Networks
<https://arxiv.org/abs/2310.12403>`__. Cooperation between the GPUs
eliminates duplicate work performed across the GPUs due to the
overlapping sampled k-hop neighborhoods of seed nodes when performing
GNN minibatching.
asynchronous: bool
Boolean indicating whether sampling and compaction stages should run
in background threads to hide the latency of CPU GPU synchronization.
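
A hedged configuration sketch (not part of the diff; the dataset objects, fanouts, and batch size are made up, and only the cooperative=True flag comes from the docstring above):

import dgl.graphbolt as gb

# Hypothetical per-GPU pipeline: `graph` and `train_set` come from a dataset
# partitioned across the GPUs, `device` is this rank's GPU.
datapipe = gb.ItemSampler(train_set, batch_size=1024, shuffle=True)
datapipe = datapipe.copy_to(device)
# Enable Cooperative Minibatching when sampling neighbors.
datapipe = datapipe.sample_neighbor(graph, fanouts=[10, 10], cooperative=True)
dataloader = gb.DataLoader(datapipe)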
@@ -986,6 +1000,16 @@ class LayerNeighborSampler(NeighborSamplerImpl):
gpu_cache_threshold : int, optional
Determines how many times a vertex needs to be accessed before its
neighborhood ends up being cached on the GPU.
cooperative: bool, optional
Boolean indicating whether Cooperative Minibatching should be used. It was
initially proposed in
`Deep Graph Library PR#4337 <https://github.com/dmlc/dgl/pull/4337>`__
and was later fully described in
`Cooperative Minibatching in Graph Neural Networks
<https://arxiv.org/abs/2310.12403>`__. Cooperation between the GPUs
eliminates duplicate work performed across the GPUs due to the
overlapping sampled k-hop neighborhoods of seed nodes when performing
GNN minibatching.
asynchronous: bool
Boolean indicating whether sampling and compaction stages should run
in background threads to hide the latency of CPU GPU synchronization.
5 changes: 5 additions & 0 deletions python/dgl/graphbolt/subgraph_sampler.py
@@ -275,6 +275,11 @@ def _seeds_cooperative_exchange_4(minibatch):
ntype: typed_inv[0] for ntype, typed_inv in inverse_seeds.items()
}
minibatch._seed_nodes = revert_to_homo(unique_seeds)
sizes = {
ntype: typed_seeds.size(0)
for ntype, typed_seeds in unique_seeds.items()
}
minibatch._seed_sizes = revert_to_homo(sizes)
minibatch._seed_inverse_ids = revert_to_homo(inverse_seeds)
return minibatch

