diff --git a/tests/distributed/test_distributed_sampling.py b/tests/distributed/test_distributed_sampling.py index 411a9546ac8f..6f2df8b84f0f 100644 --- a/tests/distributed/test_distributed_sampling.py +++ b/tests/distributed/test_distributed_sampling.py @@ -1862,33 +1862,11 @@ def test_local_sampling_heterograph(num_parts, use_graphbolt, prob_or_mask): def check_hetero_dist_edge_dataloader_gb( tmpdir, num_server, use_graphbolt=True ): - # Custom function to create a heterogeneous graph, ensuring that edges with missing masks - # can still be used to sample in DistEdgeDataloader. create_random_hetero does not support this case, - # so this function was added to handle the requirement. - def create_hetero_graph(): - num_nodes = {"n1": 210, "n2": 200, "n3": 220, "n4": 230} - etypes = [("n1", "r12", "n2"), ("n2", "r23", "n3"), ("n3", "r34", "n4")] - edges = {} - random.seed(42) - for etype in etypes: - src_ntype, _, dst_ntype = etype - arr = spsp.random( - num_nodes[src_ntype], - num_nodes[dst_ntype], - density=0.1, - format="coo", - random_state=100, - ) - edges[etype] = (arr.row, arr.col) - g = dgl.heterograph(edges, num_nodes) - - return g - generate_ip_config("rpc_ip_config.txt", num_server, num_server) - g = create_hetero_graph() - eids = torch.randperm(g.num_edges("r34"))[:10] - mask = torch.zeros(g.num_edges("r34"), dtype=torch.bool) + g = create_random_hetero() + eids = torch.randperm(g.num_edges("r23"))[:10] + mask = torch.zeros(g.num_edges("r23"), dtype=torch.bool) mask[eids] = True num_parts = num_server @@ -1930,7 +1908,7 @@ def create_hetero_graph(): os.environ["DGL_DIST_DEBUG"] = "1" - edges = {("n3", "r34", "n4"): eids} + edges = {("n2", "r23", "n3"): eids} sampler = dgl.dataloading.MultiLayerNeighborSampler([10, 10], mask="mask") loader = dgl.dataloading.DistEdgeDataLoader( dist_graph, edges, sampler, batch_size=64