Skip to content

Commit

Permalink
change test_distributed_sampling.py
Browse files Browse the repository at this point in the history
  • Loading branch information
CfromBU committed Dec 19, 2024
1 parent 636c61c commit 09f3718
Showing 1 changed file with 4 additions and 26 deletions.
30 changes: 4 additions & 26 deletions tests/distributed/test_distributed_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1862,33 +1862,11 @@ def test_local_sampling_heterograph(num_parts, use_graphbolt, prob_or_mask):
def check_hetero_dist_edge_dataloader_gb(
tmpdir, num_server, use_graphbolt=True
):
# Custom function to create a heterogeneous graph, ensuring that edges with missing masks
# can still be used to sample in DistEdgeDataloader. create_random_hetero does not support this case,
# so this function was added to handle the requirement.
def create_hetero_graph():
num_nodes = {"n1": 210, "n2": 200, "n3": 220, "n4": 230}
etypes = [("n1", "r12", "n2"), ("n2", "r23", "n3"), ("n3", "r34", "n4")]
edges = {}
random.seed(42)
for etype in etypes:
src_ntype, _, dst_ntype = etype
arr = spsp.random(
num_nodes[src_ntype],
num_nodes[dst_ntype],
density=0.1,
format="coo",
random_state=100,
)
edges[etype] = (arr.row, arr.col)
g = dgl.heterograph(edges, num_nodes)

return g

generate_ip_config("rpc_ip_config.txt", num_server, num_server)

g = create_hetero_graph()
eids = torch.randperm(g.num_edges("r34"))[:10]
mask = torch.zeros(g.num_edges("r34"), dtype=torch.bool)
g = create_random_hetero()
eids = torch.randperm(g.num_edges("r23"))[:10]
mask = torch.zeros(g.num_edges("r23"), dtype=torch.bool)
mask[eids] = True

num_parts = num_server
Expand Down Expand Up @@ -1930,7 +1908,7 @@ def create_hetero_graph():

os.environ["DGL_DIST_DEBUG"] = "1"

edges = {("n3", "r34", "n4"): eids}
edges = {("n2", "r23", "n3"): eids}
sampler = dgl.dataloading.MultiLayerNeighborSampler([10, 10], mask="mask")
loader = dgl.dataloading.DistEdgeDataLoader(
dist_graph, edges, sampler, batch_size=64
Expand Down

0 comments on commit 09f3718

Please sign in to comment.