diff --git a/tests/distributed/test_distributed_sampling.py b/tests/distributed/test_distributed_sampling.py index 2ecce809ff18..29e9f488c114 100644 --- a/tests/distributed/test_distributed_sampling.py +++ b/tests/distributed/test_distributed_sampling.py @@ -1859,8 +1859,10 @@ def test_local_sampling_heterograph(num_parts, use_graphbolt, prob_or_mask): ) -def check_mask_hetero_sampling_gb(tmpdir, num_server, use_graphbolt=True): - def create_hetero_graph(dense=False, empty=False): +def check_hetero_dist_edge_dataloader_gb( + tmpdir, num_server, use_graphbolt=True +): + def create_hetero_graph(): num_nodes = {"n1": 210, "n2": 200, "n3": 220, "n4": 230} etypes = [("n1", "r12", "n2"), ("n2", "r23", "n3"), ("n3", "r34", "n4")] edges = {} @@ -1868,8 +1870,8 @@ def create_hetero_graph(dense=False, empty=False): for etype in etypes: src_ntype, _, dst_ntype = etype arr = spsp.random( - num_nodes[src_ntype] - 10 if empty else num_nodes[src_ntype], - num_nodes[dst_ntype] - 10 if empty else num_nodes[dst_ntype], + num_nodes[src_ntype], + num_nodes[dst_ntype], density=0.1, format="coo", random_state=100, @@ -1930,19 +1932,22 @@ def create_hetero_graph(dense=False, empty=False): loader = dgl.dataloading.DistEdgeDataLoader( dist_graph, edges, sampler, batch_size=64 ) + dgl.distributed.exit_client() + for p in pserver_list: + p.join() + assert p.exitcode == 0 block = next(iter(loader))[2][0] assert block.num_src_nodes("n1") > 0 -@pytest.mark.parametrize("num_parts", [1]) -def test_local_masked_sampling_heterograph_gb( - num_server, +def test_hetero_dist_edge_dataloader_gb( + num_server=1, ): reset_envs() os.environ["DGL_DIST_MODE"] = "distributed" with tempfile.TemporaryDirectory() as tmpdirname: - check_mask_hetero_sampling_gb(Path(tmpdirname), num_server) + check_hetero_dist_edge_dataloader_gb(Path(tmpdirname), num_server) if __name__ == "__main__":