Skip to content

Commit

Permalink
change load partition
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubuntu committed Oct 11, 2024
1 parent 512c616 commit fc61534
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 17 deletions.
17 changes: 16 additions & 1 deletion python/dgl/distributed/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def _verify_graphbolt_partition(graph, part_id, gpb, ntypes, etypes):
print(f"Partition {part_id} looks good!")


def load_partition(part_config, part_id, load_feats=True, use_graphbolt=False):
def load_partition(part_config, part_id, load_feats=True, use_graphbolt=None):
"""Load data of a partition from the data path.
A partition data includes a graph structure of the partition, a dict of node tensors,
Expand Down Expand Up @@ -334,6 +334,21 @@ def load_partition(part_config, part_id, load_feats=True, use_graphbolt=False):
"part-{}".format(part_id) in part_metadata
), "part-{} does not exist".format(part_id)
part_files = part_metadata["part-{}".format(part_id)]

if use_graphbolt is None:
if os.path.exists(
os.path.join(config_path, f"part{part_id}", "graph.dgl")
):
use_graphbolt = False
elif os.path.exists(
os.path.join(
config_path, f"part{part_id}", "fused_csc_sampling_graph.pt"
)
):
use_graphbolt = True
else:
raise ValueError("The graph object doesn't exist.")

if use_graphbolt:
part_graph_field = "part_graph_graphbolt"
else:
Expand Down
22 changes: 6 additions & 16 deletions tests/distributed/test_mp_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,7 @@ def start_dist_dataloader(
gpb = None
disable_shared_mem = num_server > 1
if disable_shared_mem:
_, _, _, gpb, _, _, _ = load_partition(
part_config, rank, use_graphbolt=use_graphbolt
)
_, _, _, gpb, _, _, _ = load_partition(part_config, rank)
num_nodes_to_sample = 202
batch_size = 32
train_nid = th.arange(num_nodes_to_sample)
Expand Down Expand Up @@ -465,9 +463,7 @@ def start_node_dataloader(
gpb = None
disable_shared_mem = num_server > 1
if disable_shared_mem:
_, _, _, gpb, _, _, _ = load_partition(
part_config, rank, use_graphbolt=use_graphbolt
)
_, _, _, gpb, _, _, _ = load_partition(part_config, rank)
num_nodes_to_sample = 202
batch_size = 32
graph_name = os.path.splitext(os.path.basename(part_config))[0]
Expand All @@ -486,9 +482,7 @@ def start_node_dataloader(
}

for i in range(num_server):
part, _, _, _, _, _, _ = load_partition(
part_config, i, use_graphbolt=use_graphbolt
)
part, _, _, _, _, _, _ = load_partition(part_config, i)

# Create sampler
_prob = None
Expand Down Expand Up @@ -595,9 +589,7 @@ def start_edge_dataloader(
gpb = None
disable_shared_mem = num_server > 1
if disable_shared_mem:
_, _, _, gpb, _, _, _ = load_partition(
part_config, rank, use_graphbolt=use_graphbolt
)
_, _, _, gpb, _, _, _ = load_partition(part_config, rank)
num_edges_to_sample = 202
batch_size = 32
graph_name = os.path.splitext(os.path.basename(part_config))[0]
Expand All @@ -612,9 +604,7 @@ def start_edge_dataloader(
}

for i in range(num_server):
part, _, _, _, _, _, _ = load_partition(
part_config, i, use_graphbolt=use_graphbolt
)
part, _, _, _, _, _, _ = load_partition(part_config, i)

# Create sampler
_prob = None
Expand Down Expand Up @@ -1187,4 +1177,4 @@ def test_deprecated_dataloader(dataloader_type):
0,
dataloader_type,
use_deprecated_dataloader=True,

Check warning on line 1179 in tests/distributed/test_mp_dataloader.py

View workflow job for this annotation

GitHub Actions / lintrunner

UFMT format

Run `lintrunner -a` to apply this patch.
)
)

0 comments on commit fc61534

Please sign in to comment.