Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[distGB] graphbolt graph edge's mask will be filled with 0 if these edges have no mask initial #7846

Merged
merged 15 commits into from
Jan 9, 2025
Merged
25 changes: 18 additions & 7 deletions python/dgl/distributed/dist_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def collate(self, items):
raise NotImplementedError

@staticmethod
def add_edge_attribute_to_graph(g, data_name):
def add_edge_attribute_to_graph(g, data_name, gb_padding):
"""Add data into the graph as an edge attribute.

For some cases such as prob/mask-based sampling on GraphBolt partitions,
Expand All @@ -327,9 +327,11 @@ def add_edge_attribute_to_graph(g, data_name):
The graph.
data_name : str
The name of data that's stored in DistGraph.ndata/edata.
gb_padding : int, optional
The padding value for GraphBolt partitions' new edge_attributes.
classicsong marked this conversation as resolved.
Show resolved Hide resolved
"""
if g._use_graphbolt and data_name:
g.add_edge_attribute(data_name)
g.add_edge_attribute(data_name, gb_padding)


class NodeCollator(Collator):
Expand All @@ -344,6 +346,11 @@ class NodeCollator(Collator):
The node set to compute outputs.
graph_sampler : dgl.dataloading.BlockSampler
The neighborhood sampler.
gb_padding : int, optional
The padding value for GraphBolt partitions' new edge_attributes if the attributes in DistGraph are None.
e.g. prob/mask-based sampling.
Only when the mask of one edge is set as 1, an edge will be sampled in dgl.graphbolt.FusedCSCSamplingGraph.sample_neighbors.
The argument will be used in add_edge_attribute_to_graph to add new edge_attributes in graphbolt.

Examples
--------
Expand All @@ -366,7 +373,7 @@ class NodeCollator(Collator):
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
"""

def __init__(self, g, nids, graph_sampler):
def __init__(self, g, nids, graph_sampler, gb_padding=1):
self.g = g
if not isinstance(nids, Mapping):
assert (
Expand All @@ -380,7 +387,7 @@ def __init__(self, g, nids, graph_sampler):
# Add prob/mask into graphbolt partition's edge attributes if needed.
if hasattr(self.graph_sampler, "prob"):
Collator.add_edge_attribute_to_graph(
self.g, self.graph_sampler.prob
self.g, self.graph_sampler.prob, gb_padding
)

@property
Expand Down Expand Up @@ -508,8 +515,11 @@ class EdgeCollator(Collator):

A set of builtin negative samplers are provided in
:ref:`the negative sampling module <api-dataloading-negative-sampling>`.

Examples
gb_padding : int, optional
The padding value for GraphBolt partitions' new edge_attributes if the attributes in DistGraph are None.
e.g. prob/mask-based sampling.
Only when the mask of one edge is set as 1, an edge will be sampled in dgl.graphbolt.FusedCSCSamplingGraph.sample_neighbors.
The argument will be used in add_edge_attribute_to_graph to add new edge_attributes in graphbolt.
--------
The following example shows how to train a 3-layer GNN for edge classification on a
set of edges ``train_eid`` on a homogeneous undirected graph. Each node takes
Expand Down Expand Up @@ -612,6 +622,7 @@ def __init__(
reverse_eids=None,
reverse_etypes=None,
negative_sampler=None,
gb_padding=1,
):
self.g = g
if not isinstance(eids, Mapping):
Expand Down Expand Up @@ -642,7 +653,7 @@ def __init__(
# Add prob/mask into graphbolt partition's edge attributes if needed.
if hasattr(self.graph_sampler, "prob"):
Collator.add_edge_attribute_to_graph(
self.g, self.graph_sampler.prob
self.g, self.graph_sampler.prob, gb_padding
)

@property
Expand Down
21 changes: 15 additions & 6 deletions python/dgl/distributed/dist_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,15 +143,16 @@ def _copy_data_from_shared_mem(name, shape):
class AddEdgeAttributeFromKVRequest(rpc.Request):
"""Add edge attribute from kvstore to local GraphBolt partition."""

def __init__(self, name, kv_names):
def __init__(self, name, kv_names, padding):
self._name = name
self._kv_names = kv_names
self._padding = padding

def __getstate__(self):
return self._name, self._kv_names
return self._name, self._kv_names, self._padding

def __setstate__(self, state):
self._name, self._kv_names = state
self._name, self._kv_names, self._padding = state

def process_request(self, server_state):
# For now, this is only used to add prob/mask data to the graph.
Expand All @@ -169,7 +170,13 @@ def process_request(self, server_state):
gpb = server_state.partition_book
# Initialize the edge attribute.
num_edges = g.total_num_edges
attr_data = torch.zeros(num_edges, dtype=data_type)

# Padding is used to fill missing edge attributes (e.g., 'prob' or 'mask') for certain edge types.
# In DGLGraph, some edges may lack these attributes or have them set to None, but DGL will still sample these edges.
# In contrast, GraphBolt samples edges based on specific attributes (e.g., 'mask' == 1) and will skip edges with missing attributes.
# To ensure consistent sampling behavior in GraphBolt, we pad missing attributes with default values (e.g., 'mask' = 1),
# allowing all edges to be sampled, even if their attributes were missing or None in DGLGraph.
attr_data = torch.full((num_edges,), self._padding, dtype=data_type)
classicsong marked this conversation as resolved.
Show resolved Hide resolved
Rhett-Ying marked this conversation as resolved.
Show resolved Hide resolved
# Map data from kvstore to the local partition for inner edges only.
num_inner_edges = gpb.metadata()[gpb.partid]["num_edges"]
homo_eids = g.edge_attributes[EID][:num_inner_edges]
Expand Down Expand Up @@ -1620,13 +1627,15 @@ def _get_edata_names(self, etype=None):
edata_names.append(name)
return edata_names

def add_edge_attribute(self, name):
def add_edge_attribute(self, name, padding):
"""Add an edge attribute into GraphBolt partition from edge data.

Parameters
----------
name : str
The name of the edge attribute.
padding : int, optional
The padding value for the new edge attribute.
"""
# Sanity checks.
if not self._use_graphbolt:
Expand All @@ -1643,7 +1652,7 @@ def add_edge_attribute(self, name):
]
rpc.send_request(
self._client._main_server_id,
AddEdgeAttributeFromKVRequest(name, kv_names),
AddEdgeAttributeFromKVRequest(name, kv_names, padding),
)
# Wait for the response.
assert rpc.recv_response()._name == name
Expand Down
78 changes: 77 additions & 1 deletion tests/distributed/test_distributed_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
import unittest
from pathlib import Path

import backend as F
import dgl

import dgl.backend as F
import numpy as np
import pytest
import torch
Expand Down Expand Up @@ -1858,6 +1859,81 @@ def test_local_sampling_heterograph(num_parts, use_graphbolt, prob_or_mask):
)


def check_hetero_dist_edge_dataloader_gb(
tmpdir, num_server, use_graphbolt=True
):
generate_ip_config("rpc_ip_config.txt", num_server, num_server)

g = create_random_hetero()
eids = torch.randperm(g.num_edges("r23"))[:10]
mask = torch.zeros(g.num_edges("r23"), dtype=torch.bool)
mask[eids] = True

num_parts = num_server

orig_nid_map, orig_eid_map = partition_graph(
g,
"test_sampling",
num_parts,
tmpdir,
num_hops=1,
part_method="metis",
return_mapping=True,
use_graphbolt=use_graphbolt,
store_eids=True,
)

part_config = tmpdir / "test_sampling.json"

pserver_list = []
ctx = mp.get_context("spawn")
for i in range(num_server):
p = ctx.Process(
target=start_server,
args=(
i,
tmpdir,
num_server > 1,
"test_sampling",
["csc", "coo"],
True,
),
)
p.start()
time.sleep(1)
pserver_list.append(p)

dgl.distributed.initialize("rpc_ip_config.txt", use_graphbolt=True)
dist_graph = DistGraph("test_sampling", part_config=part_config)

os.environ["DGL_DIST_DEBUG"] = "1"

edges = {("n2", "r23", "n3"): eids}
sampler = dgl.dataloading.MultiLayerNeighborSampler([10, 10], mask="mask")
loader = dgl.dataloading.DistEdgeDataLoader(
dist_graph, edges, sampler, batch_size=64
)
dgl.distributed.exit_client()
for p in pserver_list:
p.join()
assert p.exitcode == 0

block = next(iter(loader))[2][0]
assert block.num_src_nodes("n1") > 0
CfromBU marked this conversation as resolved.
Show resolved Hide resolved
assert block.num_edges("r12") > 0
assert block.num_edges("r13") > 0
assert block.num_edges("r23") > 0


def test_hetero_dist_edge_dataloader_gb(
num_server=1,
):
reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed"
with tempfile.TemporaryDirectory() as tmpdirname:
check_hetero_dist_edge_dataloader_gb(Path(tmpdirname), num_server)


if __name__ == "__main__":
import tempfile

Expand Down
Loading