From 8f9e6393b3f4bfcb8e2c5c18b78c65ddeaa17ef2 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 14 Aug 2024 02:35:23 +0000 Subject: [PATCH 01/37] change a variable --- python/dgl/distributed/partition.py | 93 ++++++++++++++++++----------- 1 file changed, 57 insertions(+), 36 deletions(-) diff --git a/python/dgl/distributed/partition.py b/python/dgl/distributed/partition.py index 73ea48959597..fd0fcae9d9c2 100644 --- a/python/dgl/distributed/partition.py +++ b/python/dgl/distributed/partition.py @@ -1105,7 +1105,7 @@ def get_homogeneous(g, balance_ntypes): inner_node_mask = _get_inner_node_mask(parts[i], ntype_id) val.append( F.as_scalar(F.sum(F.astype(inner_node_mask, F.int64), 0)) - ) + )#note inner_node_mask(tensor[n,bool])->tensor[n,int64]->sum->scalar, compute the num of one partition inner_nids = F.boolean_mask( parts[i].ndata[NID], inner_node_mask ) @@ -1115,7 +1115,7 @@ def get_homogeneous(g, balance_ntypes): int(F.as_scalar(inner_nids[-1])) + 1, ] ) - val = np.cumsum(val).tolist() + val = np.cumsum(val).tolist()# note computing the cumulative sum of array elements. 
assert val[-1] == g.num_nodes(ntype) for etype in g.canonical_etypes: etype_id = g.get_etype_id(etype) @@ -1135,7 +1135,7 @@ def get_homogeneous(g, balance_ntypes): [int(inner_eids[0]), int(inner_eids[-1]) + 1] ) val = np.cumsum(val).tolist() - assert val[-1] == g.num_edges(etype) + assert val[-1] == g.num_edges(etype)# note assure the tot graph can be used else: node_map_val = {} edge_map_val = {} @@ -1305,32 +1305,52 @@ def get_homogeneous(g, balance_ntypes): part_dir = os.path.join(out_path, "part" + str(part_id)) node_feat_file = os.path.join(part_dir, "node_feat.dgl") edge_feat_file = os.path.join(part_dir, "edge_feat.dgl") - part_graph_file = os.path.join(part_dir, "graph.dgl") - part_metadata["part-{}".format(part_id)] = { - "node_feats": os.path.relpath(node_feat_file, out_path), - "edge_feats": os.path.relpath(edge_feat_file, out_path), - "part_graph": os.path.relpath(part_graph_file, out_path), - } + os.makedirs(part_dir, mode=0o775, exist_ok=True) save_tensors(node_feat_file, node_feats) save_tensors(edge_feat_file, edge_feats) - sort_etypes = len(g.etypes) > 1 - _save_graphs( - part_graph_file, - [part], - formats=graph_formats, - sort_etypes=sort_etypes, + #save + if use_graphbolt: + part_metadata["part-{}".format(part_id)] = { + "node_feats": os.path.relpath(node_feat_file, out_path), + "edge_feats": os.path.relpath(edge_feat_file, out_path), + } + else: + part_graph_file = os.path.join(part_dir, "graph.dgl") + + part_metadata["part-{}".format(part_id)] = { + "node_feats": os.path.relpath(node_feat_file, out_path), + "edge_feats": os.path.relpath(edge_feat_file, out_path), + "part_graph": os.path.relpath(part_graph_file, out_path), + } + sort_etypes = len(g.etypes) > 1 + _save_graphs( + part_graph_file, + [part], + formats=graph_formats, + sort_etypes=sort_etypes, + ) + + + part_config = os.path.join(out_path, graph_name + ".json") + if use_graphbolt: + kwargs["graph_formats"] = graph_formats + dgl_partition_to_graphbolt( + part_config, + parts=parts, 
+ part_meta=part_metadata, + **kwargs, ) + else: + _dump_part_config(part_config, part_metadata) + print( "Save partitions: {:.3f} seconds, peak memory: {:.3f} GB".format( time.time() - start, get_peak_mem() ) ) - part_config = os.path.join(out_path, graph_name + ".json") - _dump_part_config(part_config, part_metadata) - num_cuts = sim_g.num_edges() - tot_num_inner_edges if num_parts == 1: num_cuts = 0 @@ -1340,13 +1360,6 @@ def get_homogeneous(g, balance_ntypes): ) ) - if use_graphbolt: - kwargs["graph_formats"] = graph_formats - dgl_partition_to_graphbolt( - part_config, - **kwargs, - ) - if return_mapping: return orig_nids, orig_eids @@ -1392,9 +1405,9 @@ def init_type_per_edge(graph, gpb): etype_ids = gpb.map_to_per_etype(graph.edata[EID])[0] return etype_ids - -def gb_convert_single_dgl_partition( +def gb_convert_single_dgl_partition(# TODO change this part_id, + parts, graph_formats, part_config, store_eids, @@ -1427,14 +1440,18 @@ def gb_convert_single_dgl_partition( "Running in debug mode which means all attributes of DGL partitions" " will be saved to the new format." ) - + part_meta = _load_part_config(part_config) num_parts = part_meta["num_parts"] - graph, _, _, gpb, _, _, _ = load_partition( - part_config, part_id, load_feats=False - ) - _, _, ntypes, etypes = load_partition_book(part_config, part_id) + if parts!=None: + assert len(parts)==num_parts + graph=parts[part_id] + else: + graph, _, _, gpb, _, _, _ = load_partition( + part_config, part_id, load_feats=False + ) + gpb, _, ntypes, etypes = load_partition_book(part_config, part_id) is_homo = is_homogeneous(ntypes, etypes) node_type_to_id = ( None if is_homo else {ntype: ntid for ntid, ntype in enumerate(ntypes)} @@ -1503,7 +1520,7 @@ def gb_convert_single_dgl_partition( indptr, dtype=indices.dtype ) - # Cast various data to minimum dtype. + # Cast various data to minimum dtype.#note convert to minimun dtype # Cast 1: indptr. 
indptr = _cast_to_minimum_dtype(graph.num_edges(), indptr) # Cast 2: indices. @@ -1552,7 +1569,6 @@ def gb_convert_single_dgl_partition( return os.path.relpath(csc_graph_path, os.path.dirname(part_config)) # Update graph path. - def dgl_partition_to_graphbolt( part_config, *, @@ -1561,7 +1577,10 @@ def dgl_partition_to_graphbolt( store_inner_edge=False, graph_formats=None, n_jobs=1, -): + parts=None, + part_meta=None +):# note + """Convert partitions of dgl to FusedCSCSamplingGraph of GraphBolt. This API converts `DGLGraph` partitions to `FusedCSCSamplingGraph` which is @@ -1598,7 +1617,8 @@ def dgl_partition_to_graphbolt( "Running in debug mode which means all attributes of DGL partitions" " will be saved to the new format." ) - part_meta = _load_part_config(part_config) + if part_meta==None: + part_meta = _load_part_config(part_config) new_part_meta = copy.deepcopy(part_meta) num_parts = part_meta["num_parts"] @@ -1615,6 +1635,7 @@ def dgl_partition_to_graphbolt( convert_with_format = partial( gb_convert_single_dgl_partition, graph_formats=graph_formats, + parts=parts, part_config=part_config, store_eids=store_eids, store_inner_node=store_inner_node, From bfeb3b454e4dc928ed6e69d2e5f38976c9d83968 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 20 Aug 2024 09:35:36 +0000 Subject: [PATCH 02/37] modify partition test case --- tests/distributed/test_partition.py | 641 +++++++++++++++++++++------- 1 file changed, 485 insertions(+), 156 deletions(-) diff --git a/tests/distributed/test_partition.py b/tests/distributed/test_partition.py index 5fb121750e01..0f2425cb054d 100644 --- a/tests/distributed/test_partition.py +++ b/tests/distributed/test_partition.py @@ -5,10 +5,12 @@ import dgl import dgl.backend as F +import dgl.sparse as dglsp import numpy as np import pytest import torch as th from dgl import function as fn +from dgl.base import NTYPE from dgl.distributed import ( dgl_partition_to_graphbolt, load_partition, @@ -35,12 +37,19 @@ from utils import reset_envs 
-def _verify_partition_data_types(part_g): - for k, dtype in RESERVED_FIELD_DTYPE.items(): - if k in part_g.ndata: - assert part_g.ndata[k].dtype == dtype - if k in part_g.edata: - assert part_g.edata[k].dtype == dtype +def _verify_partition_data_types(part_g, use_graphbolt=False): + if not use_graphbolt: + for k, dtype in RESERVED_FIELD_DTYPE.items(): + if k in part_g.ndata: + assert part_g.ndata[k].dtype == dtype + if k in part_g.edata: + assert part_g.edata[k].dtype == dtype + else: + for k, dtype in RESERVED_FIELD_DTYPE.items(): + if k in part_g.node_attributes: + assert part_g.node_attributes[k].dtype == dtype + if k in part_g.edge_attributes: + assert part_g.edge_attributes[k].dtype == dtype def _verify_partition_formats(part_g, formats): @@ -81,11 +90,58 @@ def create_random_hetero(): return dgl.heterograph(edges, num_nodes) -def verify_hetero_graph(g, parts): +def verify_hetero_graph(g, parts, use_graphbolt=False): + if use_graphbolt: + num_nodes = {ntype: 0 for ntype in g.ntypes} + num_edges = {etype: 0 for etype in g.canonical_etypes} + for part in parts: + for etype in g.canonical_etypes: + etype_id = g.get_etype_id(etype) + inner_edge_mask = _get_inner_edge_mask( + part, etype_id, use_graphbolt + ) + num_inner_edges = F.sum(F.astype(inner_edge_mask, F.int64), 0) + num_edges[etype] += num_inner_edges + + # Verify the number of edges are correct. 
+ for etype in g.canonical_etypes: + print( + "edge {}: {}, {}".format( + etype, g.num_edges(etype), num_edges[etype] + ) + ) + assert g.num_edges(etype) == num_edges[etype] + + nids = {ntype: [] for ntype in g.ntypes} + eids = {etype: [] for etype in g.canonical_etypes} + for part in parts: + eid = th.arange(len(part.edge_attributes[dgl.EID])) + etype_arr = F.gather_row(part.type_per_edge, eid) + eid_type = F.gather_row(part.edge_attributes[dgl.EID], eid) + for etype in g.canonical_etypes: + etype_id = g.get_etype_id(etype) + eids[etype].append( + F.boolean_mask(eid_type, etype_arr == etype_id) + ) + # Make sure edge Ids fall into a range. + inner_edge_mask = _get_inner_edge_mask( + part, etype_id, use_graphbolt + ) + inner_eids = np.sort( + F.asnumpy( + F.boolean_mask( + part.edge_attributes[dgl.EID], inner_edge_mask + ) + ) + ) + assert np.all( + inner_eids == np.arange(inner_eids[0], inner_eids[-1] + 1) + ) + return + num_nodes = {ntype: 0 for ntype in g.ntypes} num_edges = {etype: 0 for etype in g.canonical_etypes} for part in parts: - assert len(g.ntypes) == len(F.unique(part.ndata[dgl.NTYPE])) assert len(g.canonical_etypes) == len(F.unique(part.edata[dgl.ETYPE])) for ntype in g.ntypes: ntype_id = g.get_ntype_id(ntype) @@ -161,47 +217,107 @@ def verify_hetero_graph(g, parts): def verify_graph_feats( - g, gpb, part, node_feats, edge_feats, orig_nids, orig_eids + g, + gpb, + part, + node_feats, + edge_feats, + orig_nids, + orig_eids, + use_graphbolt=False, ): - for ntype in g.ntypes: - ntype_id = g.get_ntype_id(ntype) - inner_node_mask = _get_inner_node_mask(part, ntype_id) - inner_nids = F.boolean_mask(part.ndata[dgl.NID], inner_node_mask) - ntype_ids, inner_type_nids = gpb.map_to_per_ntype(inner_nids) - partid = gpb.nid2partid(inner_type_nids, ntype) - assert np.all(F.asnumpy(ntype_ids) == ntype_id) - assert np.all(F.asnumpy(partid) == gpb.partid) - - orig_id = orig_nids[ntype][inner_type_nids] - local_nids = gpb.nid2localnid(inner_type_nids, gpb.partid, 
ntype) - - for name in g.nodes[ntype].data: - if name in [dgl.NID, "inner_node"]: - continue - true_feats = F.gather_row(g.nodes[ntype].data[name], orig_id) - ndata = F.gather_row(node_feats[ntype + "/" + name], local_nids) - assert np.all(F.asnumpy(ndata == true_feats)) + if use_graphbolt: + for ntype in g.ntypes: + ntype_id = g.get_ntype_id(ntype) + inner_node_mask = _get_inner_node_mask( + part, ntype_id, use_graphbolt + ) + inner_nids = F.boolean_mask( + part.node_attributes[dgl.NID], inner_node_mask + ) + ntype_ids, inner_type_nids = gpb.map_to_per_ntype(inner_nids) + partid = gpb.nid2partid(inner_type_nids, ntype) + assert np.all(F.asnumpy(ntype_ids) == ntype_id) + assert np.all(F.asnumpy(partid) == gpb.partid) - for etype in g.canonical_etypes: - etype_id = g.get_etype_id(etype) - inner_edge_mask = _get_inner_edge_mask(part, etype_id) - inner_eids = F.boolean_mask(part.edata[dgl.EID], inner_edge_mask) - etype_ids, inner_type_eids = gpb.map_to_per_etype(inner_eids) - partid = gpb.eid2partid(inner_type_eids, etype) - assert np.all(F.asnumpy(etype_ids) == etype_id) - assert np.all(F.asnumpy(partid) == gpb.partid) - - orig_id = orig_eids[etype][inner_type_eids] - local_eids = gpb.eid2localeid(inner_type_eids, gpb.partid, etype) - - for name in g.edges[etype].data: - if name in [dgl.EID, "inner_edge"]: - continue - true_feats = F.gather_row(g.edges[etype].data[name], orig_id) - edata = F.gather_row( - edge_feats[_etype_tuple_to_str(etype) + "/" + name], local_eids + orig_id = orig_nids[ntype][inner_type_nids] + local_nids = gpb.nid2localnid(inner_type_nids, gpb.partid, ntype) + + for name in g.nodes[ntype].data: + if name in [dgl.NID, "inner_node"]: + continue + true_feats = F.gather_row(g.nodes[ntype].data[name], orig_id) + ndata = F.gather_row(node_feats[ntype + "/" + name], local_nids) + assert np.all(F.asnumpy(ndata == true_feats)) + + for etype in g.canonical_etypes: + etype_id = g.get_etype_id(etype) + inner_edge_mask = _get_inner_edge_mask( + part, 
etype_id, use_graphbolt ) - assert np.all(F.asnumpy(edata == true_feats)) + inner_eids = F.boolean_mask( + part.edge_attributes[dgl.EID], inner_edge_mask + ) + etype_ids, inner_type_eids = gpb.map_to_per_etype(inner_eids) + partid = gpb.eid2partid(inner_type_eids, etype) + assert np.all(F.asnumpy(etype_ids) == etype_id) + assert np.all(F.asnumpy(partid) == gpb.partid) + + orig_id = orig_eids[etype][inner_type_eids] + local_eids = gpb.eid2localeid(inner_type_eids, gpb.partid, etype) + + for name in g.edges[etype].data: + if name in [dgl.EID, "inner_edge"]: + continue + true_feats = F.gather_row(g.edges[etype].data[name], orig_id) + edata = F.gather_row( + edge_feats[_etype_tuple_to_str(etype) + "/" + name], + local_eids, + ) + assert np.all(F.asnumpy(edata == true_feats)) + else: + for ntype in g.ntypes: + ntype_id = g.get_ntype_id(ntype) + inner_node_mask = _get_inner_node_mask( + part, ntype_id, use_graphbolt + ) + inner_nids = F.boolean_mask(part.ndata[dgl.NID], inner_node_mask) + ntype_ids, inner_type_nids = gpb.map_to_per_ntype(inner_nids) + partid = gpb.nid2partid(inner_type_nids, ntype) + assert np.all(F.asnumpy(ntype_ids) == ntype_id) + assert np.all(F.asnumpy(partid) == gpb.partid) + + orig_id = orig_nids[ntype][inner_type_nids] + local_nids = gpb.nid2localnid(inner_type_nids, gpb.partid, ntype) + + for name in g.nodes[ntype].data: + if name in [dgl.NID, "inner_node"]: + continue + true_feats = F.gather_row(g.nodes[ntype].data[name], orig_id) + ndata = F.gather_row(node_feats[ntype + "/" + name], local_nids) + assert np.all(F.asnumpy(ndata == true_feats)) + + for etype in g.canonical_etypes: + etype_id = g.get_etype_id(etype) + inner_edge_mask = _get_inner_edge_mask(part, etype_id) + inner_eids = F.boolean_mask(part.edata[dgl.EID], inner_edge_mask) + etype_ids, inner_type_eids = gpb.map_to_per_etype(inner_eids) + partid = gpb.eid2partid(inner_type_eids, etype) + assert np.all(F.asnumpy(etype_ids) == etype_id) + assert np.all(F.asnumpy(partid) == gpb.partid) 
+ + orig_id = orig_eids[etype][inner_type_eids] + local_eids = gpb.eid2localeid(inner_type_eids, gpb.partid, etype) + + for name in g.edges[etype].data: + if name in [dgl.EID, "inner_edge"]: + continue + true_feats = F.gather_row(g.edges[etype].data[name], orig_id) + edata = F.gather_row( + edge_feats[_etype_tuple_to_str(etype) + "/" + name], + local_eids, + ) + assert np.all(F.asnumpy(edata == true_feats)) def check_hetero_partition( @@ -245,7 +361,7 @@ def check_hetero_partition( shuffled_labels = [] shuffled_elabels = [] for i in range(num_parts): - part_g, node_feats, edge_feats, gpb, _, ntypes, etypes = load_partition( + part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition( "/tmp/partition/test.json", i, load_feats=load_feats ) _verify_partition_data_types(part_g) @@ -1075,17 +1191,12 @@ def test_not_sorted_node_edge_map(): @pytest.mark.parametrize("part_method", ["metis", "random"]) @pytest.mark.parametrize("num_parts", [1, 4]) -@pytest.mark.parametrize("store_eids", [True, False]) -@pytest.mark.parametrize("store_inner_node", [True, False]) -@pytest.mark.parametrize("store_inner_edge", [True, False]) @pytest.mark.parametrize("debug_mode", [True, False]) def test_partition_graph_graphbolt_homo( part_method, num_parts, - store_eids, - store_inner_node, - store_inner_edge, debug_mode, + num_trainers_per_machine=1, ): reset_envs() if debug_mode: @@ -1093,148 +1204,369 @@ def test_partition_graph_graphbolt_homo( with tempfile.TemporaryDirectory() as test_dir: g = create_random_graph(1000) graph_name = "test" - partition_graph( + g.ndata["labels"] = F.arange(0, g.num_nodes()) + g.ndata["feats"] = F.tensor( + np.random.randn(g.num_nodes(), 10), F.float32 + ) + g.edata["feats"] = F.tensor( + np.random.randn(g.num_edges(), 10), F.float32 + ) + g.update_all(fn.copy_u("feats", "msg"), fn.sum("msg", "h")) + g.update_all(fn.copy_e("feats", "msg"), fn.sum("msg", "eh")) + + orig_nids, orig_eids = partition_graph( g, graph_name, num_parts, test_dir, 
part_method=part_method, use_graphbolt=True, - store_eids=store_eids, - store_inner_node=store_inner_node, - store_inner_edge=store_inner_edge, + store_eids=True, + store_inner_node=True, + store_inner_edge=True, + return_mapping=True, ) + part_sizes = [] + shuffled_labels = [] + shuffled_edata = [] part_config = os.path.join(test_dir, f"{graph_name}.json") - for part_id in range(num_parts): - orig_g = dgl.load_graphs( - os.path.join(test_dir, f"part{part_id}/graph.dgl") - )[0][0] - new_g = load_partition( - part_config, part_id, load_feats=False, use_graphbolt=True - )[0] - orig_indptr, orig_indices, orig_eids = orig_g.adj().csc() - assert th.equal(orig_indptr, new_g.csc_indptr) - assert th.equal(orig_indices, new_g.indices) - assert new_g.node_type_offset is None - assert th.equal( - orig_g.ndata[dgl.NID], new_g.node_attributes[dgl.NID] + for i in range(num_parts): + part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition( + part_config, i, load_feats=True, use_graphbolt=True ) - if store_inner_node or debug_mode: - assert th.equal( - orig_g.ndata["inner_node"], - new_g.node_attributes["inner_node"], - ) - else: - assert "inner_node" not in new_g.node_attributes - if store_eids or debug_mode: - assert th.equal( - orig_g.edata[dgl.EID][orig_eids], - new_g.edge_attributes[dgl.EID], - ) - else: - assert dgl.EID not in new_g.edge_attributes - if store_inner_edge or debug_mode: - assert th.equal( - orig_g.edata["inner_edge"][orig_eids], - new_g.edge_attributes["inner_edge"], - ) - else: - assert "inner_edge" not in new_g.edge_attributes - assert new_g.type_per_edge is None - assert new_g.node_type_to_id is None - assert new_g.edge_type_to_id is None + if num_trainers_per_machine > 1: + for ntype in g.ntypes: + name = ntype + "/trainer_id" + assert name in node_feats + part_ids = F.floor_div( + node_feats[name], num_trainers_per_machine + ) + assert np.all(F.asnumpy(part_ids) == i) + + for etype in g.canonical_etypes: + name = _etype_tuple_to_str(etype) + 
"/trainer_id" + assert name in edge_feats + part_ids = F.floor_div( + edge_feats[name], num_trainers_per_machine + ) + assert np.all(F.asnumpy(part_ids) == i) + + # Check the metadata + assert gpb._num_nodes() == g.num_nodes() + assert gpb._num_edges() == g.num_edges() + + assert gpb.num_partitions() == num_parts + gpb_meta = gpb.metadata() + assert len(gpb_meta) == num_parts + assert len(gpb.partid2nids(i)) == gpb_meta[i]["num_nodes"] + assert len(gpb.partid2eids(i)) == gpb_meta[i]["num_edges"] + part_sizes.append( + (gpb_meta[i]["num_nodes"], gpb_meta[i]["num_edges"]) + ) + + nid = F.boolean_mask( + part_g.node_attributes[dgl.NID], + part_g.node_attributes["inner_node"], + ) + local_nid = gpb.nid2localnid(nid, i) + assert F.dtype(local_nid) in (F.int64, F.int32) + assert np.all(F.asnumpy(local_nid) == np.arange(0, len(local_nid))) + eid = F.boolean_mask( + part_g.edge_attributes[dgl.EID], + part_g.edge_attributes["inner_edge"], + ) + local_eid = gpb.eid2localeid(eid, i) + assert F.dtype(local_eid) in (F.int64, F.int32) + assert np.all( + np.sort(F.asnumpy(local_eid)) == np.arange(0, len(local_eid)) + ) + + # Check the node map. + local_nodes = F.boolean_mask( + part_g.node_attributes[dgl.NID], + part_g.node_attributes["inner_node"], + ) + llocal_nodes = F.nonzero_1d(part_g.node_attributes["inner_node"]) + local_nodes1 = gpb.partid2nids(i) + assert F.dtype(local_nodes1) in (F.int32, F.int64) + assert np.all( + np.sort(F.asnumpy(local_nodes)) + == np.sort(F.asnumpy(local_nodes1)) + ) + assert np.all( + F.asnumpy(llocal_nodes) == np.arange(len(llocal_nodes)) + ) + + # Check the edge map. 
+ local_edges = F.boolean_mask( + part_g.edge_attributes[dgl.EID], + part_g.edge_attributes["inner_edge"], + ) + llocal_edges = F.nonzero_1d(part_g.edge_attributes["inner_edge"]) + local_edges1 = gpb.partid2eids(i) + assert F.dtype(local_edges1) in (F.int32, F.int64) + assert np.all( + np.sort(F.asnumpy(local_edges)) + == np.sort(F.asnumpy(local_edges1)) + ) + assert np.all( + F.asnumpy(llocal_edges) == np.arange(len(llocal_edges)) + ) + + # Verify the mapping between the reshuffled IDs and the original IDs. + indices, indptr = part_g.indices, part_g.csc_indptr + adj_matrix = dglsp.from_csc(indptr, indices) + part_src_ids, part_dst_ids = adj_matrix.coo() + part_src_ids = F.gather_row( + part_g.node_attributes[dgl.NID], part_src_ids + ) + part_dst_ids = F.gather_row( + part_g.node_attributes[dgl.NID], part_dst_ids + ) + part_eids = part_g.edge_attributes[dgl.EID] + orig_src_ids = F.gather_row(orig_nids, part_src_ids) + orig_dst_ids = F.gather_row(orig_nids, part_dst_ids) + orig_eids1 = F.gather_row(orig_eids, part_eids) + orig_eids2 = g.edge_ids(orig_src_ids, orig_dst_ids) + assert F.shape(orig_eids1)[0] == F.shape(orig_eids2)[0] + assert np.all(F.asnumpy(orig_eids1) == F.asnumpy(orig_eids2)) + + local_orig_nids = orig_nids[part_g.node_attributes[dgl.NID]] + local_orig_eids = orig_eids[part_g.edge_attributes[dgl.EID]] + part_g.node_attributes["feats"] = F.gather_row( + g.ndata["feats"], local_orig_nids + ) + part_g.edge_attributes["feats"] = F.gather_row( + g.edata["feats"], local_orig_eids + ) + local_nodes = orig_nids[local_nodes] + local_edges = orig_eids[local_edges] + + # part_g.update_all(fn.copy_u("feats", "msg"), fn.sum("msg", "h")) + # part_g.update_all(fn.copy_e("feats", "msg"), fn.sum("msg", "eh")) + # part_g.node_attributes["h"] = adj_matrix@part_g.node_attributes["h"] + + # assert F.allclose( + # F.gather_row(g.ndata["h"], local_nodes), + # F.gather_row(part_g.node_attributes["h"], llocal_nodes), + # ) + # assert F.allclose( + # 
F.gather_row(g.ndata["eh"], local_nodes), + # F.gather_row(part_g.node_attributes["eh"], llocal_nodes), + # ) + + for name in ["labels", "feats"]: + assert "_N/" + name in node_feats + assert node_feats["_N/" + name].shape[0] == len(local_nodes) + true_feats = F.gather_row(g.ndata[name], local_nodes) + ndata = F.gather_row(node_feats["_N/" + name], local_nid) + assert np.all(F.asnumpy(true_feats) == F.asnumpy(ndata)) + for name in ["feats"]: + efeat_name = _etype_tuple_to_str(DEFAULT_ETYPE) + "/" + name + assert efeat_name in edge_feats + assert edge_feats[efeat_name].shape[0] == len(local_edges) + true_feats = F.gather_row(g.edata[name], local_edges) + edata = F.gather_row(edge_feats[efeat_name], local_eid) + assert np.all(F.asnumpy(true_feats) == F.asnumpy(edata)) + + # This only works if node/edge IDs are shuffled. + shuffled_labels.append(node_feats["_N/labels"]) + shuffled_edata.append(edge_feats["_N:_E:_N/feats"]) + + # Verify that we can reconstruct node/edge data for original IDs. 
+ shuffled_labels = F.asnumpy(F.cat(shuffled_labels, 0)) + shuffled_edata = F.asnumpy(F.cat(shuffled_edata, 0)) + orig_labels = np.zeros( + shuffled_labels.shape, dtype=shuffled_labels.dtype + ) + orig_edata = np.zeros(shuffled_edata.shape, dtype=shuffled_edata.dtype) + orig_labels[F.asnumpy(orig_nids)] = shuffled_labels + orig_edata[F.asnumpy(orig_eids)] = shuffled_edata + assert np.all(orig_labels == F.asnumpy(g.ndata["labels"])) + assert np.all(orig_edata == F.asnumpy(g.edata["feats"])) + + node_map = [] + edge_map = [] + for i, (num_nodes, num_edges) in enumerate(part_sizes): + node_map.append(np.ones(num_nodes) * i) + edge_map.append(np.ones(num_edges) * i) + node_map = np.concatenate(node_map) + edge_map = np.concatenate(edge_map) + nid2pid = gpb.nid2partid(F.arange(0, len(node_map))) + assert F.dtype(nid2pid) in (F.int32, F.int64) + assert np.all(F.asnumpy(nid2pid) == node_map) + eid2pid = gpb.eid2partid(F.arange(0, len(edge_map))) + assert F.dtype(eid2pid) in (F.int32, F.int64) + assert np.all(F.asnumpy(eid2pid) == edge_map) @pytest.mark.parametrize("part_method", ["metis", "random"]) @pytest.mark.parametrize("num_parts", [1, 4]) -@pytest.mark.parametrize("store_eids", [True, False]) -@pytest.mark.parametrize("store_inner_node", [True, False]) -@pytest.mark.parametrize("store_inner_edge", [True, False]) @pytest.mark.parametrize("debug_mode", [True, False]) def test_partition_graph_graphbolt_hetero( part_method, num_parts, - store_eids, - store_inner_node, - store_inner_edge, debug_mode, n_jobs=1, + num_trainers_per_machine=1, ): + test_ntype = "n1" + test_etype = ("n1", "r1", "n2") reset_envs() if debug_mode: os.environ["DGL_DIST_DEBUG"] = "1" with tempfile.TemporaryDirectory() as test_dir: - g = create_random_hetero() + hg = create_random_hetero() graph_name = "test" - partition_graph( - g, + hg.nodes[test_ntype].data["labels"] = F.arange( + 0, hg.num_nodes(test_ntype) + ) + hg.nodes[test_ntype].data["feats"] = F.tensor( + 
np.random.randn(hg.num_nodes(test_ntype), 10), F.float32 + ) + hg.edges[test_etype].data["feats"] = F.tensor( + np.random.randn(hg.num_edges(test_etype), 10), F.float32 + ) + hg.edges[test_etype].data["labels"] = F.arange( + 0, hg.num_edges(test_etype) + ) + num_hops = 1 + orig_nids, orig_eids = partition_graph( + hg, graph_name, num_parts, test_dir, part_method=part_method, + return_mapping=True, + num_trainers_per_machine=1, use_graphbolt=True, - store_eids=store_eids, - store_inner_node=store_inner_node, - store_inner_edge=store_inner_edge, + store_eids=True, + store_inner_node=True, + store_inner_edge=True, n_jobs=n_jobs, ) + assert len(orig_nids) == len(hg.ntypes) + assert len(orig_eids) == len(hg.canonical_etypes) + for ntype in hg.ntypes: + assert len(orig_nids[ntype]) == hg.num_nodes(ntype) + for etype in hg.canonical_etypes: + assert len(orig_eids[etype]) == hg.num_edges(etype) + parts = [] + shuffled_labels = [] + shuffled_elabels = [] part_config = os.path.join(test_dir, f"{graph_name}.json") for part_id in range(num_parts): - orig_g = dgl.load_graphs( - os.path.join(test_dir, f"part{part_id}/graph.dgl") - )[0][0] - new_g = load_partition( - part_config, part_id, load_feats=False, use_graphbolt=True - )[0] - orig_indptr, orig_indices, orig_eids = orig_g.adj().csc() - assert th.equal(orig_indptr, new_g.csc_indptr) - assert th.equal(orig_indices, new_g.indices) - assert th.equal( - orig_g.ndata[dgl.NID], new_g.node_attributes[dgl.NID] + part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition( + part_config, part_id, load_feats=True, use_graphbolt=True ) - if store_inner_node or debug_mode: - assert th.equal( - orig_g.ndata["inner_node"], - new_g.node_attributes["inner_node"], + if num_trainers_per_machine > 1: + for ntype in hg.ntypes: + name = ntype + "/trainer_id" + assert name in node_feats + part_ids = F.floor_div( + node_feats[name], num_trainers_per_machine + ) + assert np.all(F.asnumpy(part_ids) == part_id) + + for etype in 
hg.canonical_etypes: + name = _etype_tuple_to_str(etype) + "/trainer_id" + assert name in edge_feats + part_ids = F.floor_div( + edge_feats[name], num_trainers_per_machine + ) + assert np.all(F.asnumpy(part_ids) == part_id) + + # Verify the mapping between the reshuffled IDs and the original IDs. + # These are partition-local IDs. + indices, indptr = part_g.indices, part_g.csc_indptr + csc_matrix = dglsp.from_csc(indptr, indices) + part_src_ids, part_dst_ids = csc_matrix.coo() + # These are reshuffled global homogeneous IDs. + part_src_ids = F.gather_row( + part_g.node_attributes[dgl.NID], part_src_ids + ) + part_dst_ids = F.gather_row( + part_g.node_attributes[dgl.NID], part_dst_ids + ) + part_eids = part_g.edge_attributes[dgl.EID] + # These are reshuffled per-type IDs. + src_ntype_ids, part_src_ids = gpb.map_to_per_ntype(part_src_ids) + dst_ntype_ids, part_dst_ids = gpb.map_to_per_ntype(part_dst_ids) + etype_ids, part_eids = gpb.map_to_per_etype(part_eids) + # `IdMap` is in int64 by default. + assert src_ntype_ids.dtype == F.int64 + assert dst_ntype_ids.dtype == F.int64 + assert etype_ids.dtype == F.int64 + with pytest.raises(dgl.utils.internal.InconsistentDtypeException): + gpb.map_to_per_ntype(F.tensor([0], F.int32)) + with pytest.raises(dgl.utils.internal.InconsistentDtypeException): + gpb.map_to_per_etype(F.tensor([0], F.int32)) + # These are original per-type IDs. 
+ for etype_id, etype in enumerate(hg.canonical_etypes): + part_src_ids1 = F.boolean_mask( + part_src_ids, etype_ids == etype_id ) - else: - assert "inner_node" not in new_g.node_attributes - if debug_mode: - assert th.equal( - orig_g.ndata[dgl.NTYPE], new_g.node_attributes[dgl.NTYPE] + src_ntype_ids1 = F.boolean_mask( + src_ntype_ids, etype_ids == etype_id ) - else: - assert dgl.NTYPE not in new_g.node_attributes - if store_eids or debug_mode: - assert th.equal( - orig_g.edata[dgl.EID][orig_eids], - new_g.edge_attributes[dgl.EID], + part_dst_ids1 = F.boolean_mask( + part_dst_ids, etype_ids == etype_id ) - else: - assert dgl.EID not in new_g.edge_attributes - if store_inner_edge or debug_mode: - assert th.equal( - orig_g.edata["inner_edge"], - new_g.edge_attributes["inner_edge"], + dst_ntype_ids1 = F.boolean_mask( + dst_ntype_ids, etype_ids == etype_id ) - else: - assert "inner_edge" not in new_g.edge_attributes - if debug_mode: - assert th.equal( - orig_g.edata[dgl.ETYPE][orig_eids], - new_g.edge_attributes[dgl.ETYPE], + part_eids1 = F.boolean_mask(part_eids, etype_ids == etype_id) + assert np.all(F.asnumpy(src_ntype_ids1 == src_ntype_ids1[0])) + assert np.all(F.asnumpy(dst_ntype_ids1 == dst_ntype_ids1[0])) + src_ntype = hg.ntypes[F.as_scalar(src_ntype_ids1[0])] + dst_ntype = hg.ntypes[F.as_scalar(dst_ntype_ids1[0])] + orig_src_ids1 = F.gather_row( + orig_nids[src_ntype], part_src_ids1 ) - else: - assert dgl.ETYPE not in new_g.edge_attributes - assert th.equal( - orig_g.edata[dgl.ETYPE][orig_eids], new_g.type_per_edge + orig_dst_ids1 = F.gather_row( + orig_nids[dst_ntype], part_dst_ids1 + ) + orig_eids1 = F.gather_row(orig_eids[etype], part_eids1) + orig_eids2 = hg.edge_ids( + orig_src_ids1, orig_dst_ids1, etype=etype + ) + assert len(orig_eids1) == len(orig_eids2) + assert np.all(F.asnumpy(orig_eids1) == F.asnumpy(orig_eids2)) + parts.append(part_g) + if NTYPE in part_g.node_attributes: + verify_graph_feats( + hg, + gpb, + part_g, + node_feats, + edge_feats, + 
orig_nids, + orig_eids, + use_graphbolt=True, + ) + + shuffled_labels.append(node_feats[test_ntype + "/labels"]) + shuffled_elabels.append( + edge_feats[_etype_tuple_to_str(test_etype) + "/labels"] ) + verify_hetero_graph(hg, parts, True) - for node_type, type_id in new_g.node_type_to_id.items(): - assert g.get_ntype_id(node_type) == type_id - for edge_type, type_id in new_g.edge_type_to_id.items(): - assert g.get_etype_id(_etype_str_to_tuple(edge_type)) == type_id - assert new_g.node_type_offset is None + shuffled_labels = F.asnumpy(F.cat(shuffled_labels, 0)) + shuffled_elabels = F.asnumpy(F.cat(shuffled_elabels, 0)) + orig_labels = np.zeros( + shuffled_labels.shape, dtype=shuffled_labels.dtype + ) + orig_elabels = np.zeros( + shuffled_elabels.shape, dtype=shuffled_elabels.dtype + ) + orig_labels[F.asnumpy(orig_nids[test_ntype])] = shuffled_labels + orig_elabels[F.asnumpy(orig_eids[test_etype])] = shuffled_elabels + assert np.all( + orig_labels == F.asnumpy(hg.nodes[test_ntype].data["labels"]) + ) + assert np.all( + orig_elabels == F.asnumpy(hg.edges[test_etype].data["labels"]) + ) @pytest.mark.parametrize("part_method", ["metis", "random"]) @@ -1461,9 +1793,6 @@ def test_partition_graph_graphbolt_hetero_multi( part_method="random", num_parts=num_parts, n_jobs=4, - store_eids=True, - store_inner_node=True, - store_inner_edge=True, debug_mode=False, ) From bec2af3f84d468a9ffb2603710adf3f4405f393e Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 20 Aug 2024 09:45:50 +0000 Subject: [PATCH 03/37] change pr --- python/dgl/distributed/partition.py | 95 ++--- tests/distributed/test_partition.py | 641 +++++++++++++++++++++------- 2 files changed, 522 insertions(+), 214 deletions(-) diff --git a/python/dgl/distributed/partition.py b/python/dgl/distributed/partition.py index fd0fcae9d9c2..ab5cf670d743 100644 --- a/python/dgl/distributed/partition.py +++ b/python/dgl/distributed/partition.py @@ -1105,7 +1105,7 @@ def get_homogeneous(g, balance_ntypes): inner_node_mask = 
_get_inner_node_mask(parts[i], ntype_id) val.append( F.as_scalar(F.sum(F.astype(inner_node_mask, F.int64), 0)) - )#note inner_node_mask(tensor[n,bool])->tensor[n,int64]->sum->scalar, compute the num of one partition + ) inner_nids = F.boolean_mask( parts[i].ndata[NID], inner_node_mask ) @@ -1115,7 +1115,7 @@ def get_homogeneous(g, balance_ntypes): int(F.as_scalar(inner_nids[-1])) + 1, ] ) - val = np.cumsum(val).tolist()# note computing the cumulative sum of array elements. + val = np.cumsum(val).tolist() assert val[-1] == g.num_nodes(ntype) for etype in g.canonical_etypes: etype_id = g.get_etype_id(etype) @@ -1135,7 +1135,7 @@ def get_homogeneous(g, balance_ntypes): [int(inner_eids[0]), int(inner_eids[-1]) + 1] ) val = np.cumsum(val).tolist() - assert val[-1] == g.num_edges(etype)# note assure the tot graph can be used + assert val[-1] == g.num_edges(etype) else: node_map_val = {} edge_map_val = {} @@ -1305,52 +1305,32 @@ def get_homogeneous(g, balance_ntypes): part_dir = os.path.join(out_path, "part" + str(part_id)) node_feat_file = os.path.join(part_dir, "node_feat.dgl") edge_feat_file = os.path.join(part_dir, "edge_feat.dgl") - - os.makedirs(part_dir, mode=0o775, exist_ok=True) - save_tensors(node_feat_file, node_feats) - save_tensors(edge_feat_file, edge_feats) - - #save - if use_graphbolt: - part_metadata["part-{}".format(part_id)] = { - "node_feats": os.path.relpath(node_feat_file, out_path), - "edge_feats": os.path.relpath(edge_feat_file, out_path), - } - else: - part_graph_file = os.path.join(part_dir, "graph.dgl") - - part_metadata["part-{}".format(part_id)] = { + part_graph_file = os.path.join(part_dir, "graph.dgl") + part_metadata["part-{}".format(part_id)] = { "node_feats": os.path.relpath(node_feat_file, out_path), "edge_feats": os.path.relpath(edge_feat_file, out_path), "part_graph": os.path.relpath(part_graph_file, out_path), } - sort_etypes = len(g.etypes) > 1 - _save_graphs( - part_graph_file, - [part], - formats=graph_formats, - 
sort_etypes=sort_etypes, - ) - - - part_config = os.path.join(out_path, graph_name + ".json") - if use_graphbolt: - kwargs["graph_formats"] = graph_formats - dgl_partition_to_graphbolt( - part_config, - parts=parts, - part_meta=part_metadata, - **kwargs, + os.makedirs(part_dir, mode=0o775, exist_ok=True) + save_tensors(node_feat_file, node_feats) + save_tensors(edge_feat_file, edge_feats) + + sort_etypes = len(g.etypes) > 1 + _save_graphs( + part_graph_file, + [part], + formats=graph_formats, + sort_etypes=sort_etypes, ) - else: - _dump_part_config(part_config, part_metadata) - print( "Save partitions: {:.3f} seconds, peak memory: {:.3f} GB".format( time.time() - start, get_peak_mem() ) ) + part_config = os.path.join(out_path, graph_name + ".json") + _dump_part_config(part_config, part_metadata) + num_cuts = sim_g.num_edges() - tot_num_inner_edges if num_parts == 1: num_cuts = 0 @@ -1360,6 +1340,13 @@ def get_homogeneous(g, balance_ntypes): ) ) + if use_graphbolt: + kwargs["graph_formats"] = graph_formats + dgl_partition_to_graphbolt( + part_config, + **kwargs, + ) + if return_mapping: return orig_nids, orig_eids @@ -1405,9 +1392,9 @@ def init_type_per_edge(graph, gpb): etype_ids = gpb.map_to_per_etype(graph.edata[EID])[0] return etype_ids -def gb_convert_single_dgl_partition(# TODO change this + +def gb_convert_single_dgl_partition( part_id, - parts, graph_formats, part_config, store_eids, @@ -1440,18 +1427,14 @@ def gb_convert_single_dgl_partition(# TODO change this "Running in debug mode which means all attributes of DGL partitions" " will be saved to the new format." 
) - + part_meta = _load_part_config(part_config) num_parts = part_meta["num_parts"] - if parts!=None: - assert len(parts)==num_parts - graph=parts[part_id] - else: - graph, _, _, gpb, _, _, _ = load_partition( - part_config, part_id, load_feats=False - ) - gpb, _, ntypes, etypes = load_partition_book(part_config, part_id) + graph, _, _, gpb, _, _, _ = load_partition( + part_config, part_id, load_feats=False + ) + _, _, ntypes, etypes = load_partition_book(part_config, part_id) is_homo = is_homogeneous(ntypes, etypes) node_type_to_id = ( None if is_homo else {ntype: ntid for ntid, ntype in enumerate(ntypes)} @@ -1520,7 +1503,7 @@ def gb_convert_single_dgl_partition(# TODO change this indptr, dtype=indices.dtype ) - # Cast various data to minimum dtype.#note convert to minimun dtype + # Cast various data to minimum dtype. # Cast 1: indptr. indptr = _cast_to_minimum_dtype(graph.num_edges(), indptr) # Cast 2: indices. @@ -1569,6 +1552,7 @@ def gb_convert_single_dgl_partition(# TODO change this return os.path.relpath(csc_graph_path, os.path.dirname(part_config)) # Update graph path. + def dgl_partition_to_graphbolt( part_config, *, @@ -1577,10 +1561,7 @@ def dgl_partition_to_graphbolt( store_inner_edge=False, graph_formats=None, n_jobs=1, - parts=None, - part_meta=None -):# note - +): """Convert partitions of dgl to FusedCSCSamplingGraph of GraphBolt. This API converts `DGLGraph` partitions to `FusedCSCSamplingGraph` which is @@ -1617,8 +1598,7 @@ def dgl_partition_to_graphbolt( "Running in debug mode which means all attributes of DGL partitions" " will be saved to the new format." 
) - if part_meta==None: - part_meta = _load_part_config(part_config) + part_meta = _load_part_config(part_config) new_part_meta = copy.deepcopy(part_meta) num_parts = part_meta["num_parts"] @@ -1635,7 +1615,6 @@ def dgl_partition_to_graphbolt( convert_with_format = partial( gb_convert_single_dgl_partition, graph_formats=graph_formats, - parts=parts, part_config=part_config, store_eids=store_eids, store_inner_node=store_inner_node, @@ -1675,4 +1654,4 @@ def dgl_partition_to_graphbolt( new_part_meta["edge_map_dtype"] = "int64" _dump_part_config(part_config, new_part_meta) - print(f"Converted partitions to GraphBolt format into {part_config}") + print(f"Converted partitions to GraphBolt format into {part_config}") \ No newline at end of file diff --git a/tests/distributed/test_partition.py b/tests/distributed/test_partition.py index 5fb121750e01..0f2425cb054d 100644 --- a/tests/distributed/test_partition.py +++ b/tests/distributed/test_partition.py @@ -5,10 +5,12 @@ import dgl import dgl.backend as F +import dgl.sparse as dglsp import numpy as np import pytest import torch as th from dgl import function as fn +from dgl.base import NTYPE from dgl.distributed import ( dgl_partition_to_graphbolt, load_partition, @@ -35,12 +37,19 @@ from utils import reset_envs -def _verify_partition_data_types(part_g): - for k, dtype in RESERVED_FIELD_DTYPE.items(): - if k in part_g.ndata: - assert part_g.ndata[k].dtype == dtype - if k in part_g.edata: - assert part_g.edata[k].dtype == dtype +def _verify_partition_data_types(part_g, use_graphbolt=False): + if not use_graphbolt: + for k, dtype in RESERVED_FIELD_DTYPE.items(): + if k in part_g.ndata: + assert part_g.ndata[k].dtype == dtype + if k in part_g.edata: + assert part_g.edata[k].dtype == dtype + else: + for k, dtype in RESERVED_FIELD_DTYPE.items(): + if k in part_g.node_attributes: + assert part_g.node_attributes[k].dtype == dtype + if k in part_g.edge_attributes: + assert part_g.edge_attributes[k].dtype == dtype def 
_verify_partition_formats(part_g, formats): @@ -81,11 +90,58 @@ def create_random_hetero(): return dgl.heterograph(edges, num_nodes) -def verify_hetero_graph(g, parts): +def verify_hetero_graph(g, parts, use_graphbolt=False): + if use_graphbolt: + num_nodes = {ntype: 0 for ntype in g.ntypes} + num_edges = {etype: 0 for etype in g.canonical_etypes} + for part in parts: + for etype in g.canonical_etypes: + etype_id = g.get_etype_id(etype) + inner_edge_mask = _get_inner_edge_mask( + part, etype_id, use_graphbolt + ) + num_inner_edges = F.sum(F.astype(inner_edge_mask, F.int64), 0) + num_edges[etype] += num_inner_edges + + # Verify the number of edges are correct. + for etype in g.canonical_etypes: + print( + "edge {}: {}, {}".format( + etype, g.num_edges(etype), num_edges[etype] + ) + ) + assert g.num_edges(etype) == num_edges[etype] + + nids = {ntype: [] for ntype in g.ntypes} + eids = {etype: [] for etype in g.canonical_etypes} + for part in parts: + eid = th.arange(len(part.edge_attributes[dgl.EID])) + etype_arr = F.gather_row(part.type_per_edge, eid) + eid_type = F.gather_row(part.edge_attributes[dgl.EID], eid) + for etype in g.canonical_etypes: + etype_id = g.get_etype_id(etype) + eids[etype].append( + F.boolean_mask(eid_type, etype_arr == etype_id) + ) + # Make sure edge Ids fall into a range. 
+ inner_edge_mask = _get_inner_edge_mask( + part, etype_id, use_graphbolt + ) + inner_eids = np.sort( + F.asnumpy( + F.boolean_mask( + part.edge_attributes[dgl.EID], inner_edge_mask + ) + ) + ) + assert np.all( + inner_eids == np.arange(inner_eids[0], inner_eids[-1] + 1) + ) + return + num_nodes = {ntype: 0 for ntype in g.ntypes} num_edges = {etype: 0 for etype in g.canonical_etypes} for part in parts: - assert len(g.ntypes) == len(F.unique(part.ndata[dgl.NTYPE])) assert len(g.canonical_etypes) == len(F.unique(part.edata[dgl.ETYPE])) for ntype in g.ntypes: ntype_id = g.get_ntype_id(ntype) @@ -161,47 +217,107 @@ def verify_hetero_graph(g, parts): def verify_graph_feats( - g, gpb, part, node_feats, edge_feats, orig_nids, orig_eids + g, + gpb, + part, + node_feats, + edge_feats, + orig_nids, + orig_eids, + use_graphbolt=False, ): - for ntype in g.ntypes: - ntype_id = g.get_ntype_id(ntype) - inner_node_mask = _get_inner_node_mask(part, ntype_id) - inner_nids = F.boolean_mask(part.ndata[dgl.NID], inner_node_mask) - ntype_ids, inner_type_nids = gpb.map_to_per_ntype(inner_nids) - partid = gpb.nid2partid(inner_type_nids, ntype) - assert np.all(F.asnumpy(ntype_ids) == ntype_id) - assert np.all(F.asnumpy(partid) == gpb.partid) - - orig_id = orig_nids[ntype][inner_type_nids] - local_nids = gpb.nid2localnid(inner_type_nids, gpb.partid, ntype) - - for name in g.nodes[ntype].data: - if name in [dgl.NID, "inner_node"]: - continue - true_feats = F.gather_row(g.nodes[ntype].data[name], orig_id) - ndata = F.gather_row(node_feats[ntype + "/" + name], local_nids) - assert np.all(F.asnumpy(ndata == true_feats)) + if use_graphbolt: + for ntype in g.ntypes: + ntype_id = g.get_ntype_id(ntype) + inner_node_mask = _get_inner_node_mask( + part, ntype_id, use_graphbolt + ) + inner_nids = F.boolean_mask( + part.node_attributes[dgl.NID], inner_node_mask + ) + ntype_ids, inner_type_nids = gpb.map_to_per_ntype(inner_nids) + partid = gpb.nid2partid(inner_type_nids, ntype) + assert 
np.all(F.asnumpy(ntype_ids) == ntype_id) + assert np.all(F.asnumpy(partid) == gpb.partid) - for etype in g.canonical_etypes: - etype_id = g.get_etype_id(etype) - inner_edge_mask = _get_inner_edge_mask(part, etype_id) - inner_eids = F.boolean_mask(part.edata[dgl.EID], inner_edge_mask) - etype_ids, inner_type_eids = gpb.map_to_per_etype(inner_eids) - partid = gpb.eid2partid(inner_type_eids, etype) - assert np.all(F.asnumpy(etype_ids) == etype_id) - assert np.all(F.asnumpy(partid) == gpb.partid) - - orig_id = orig_eids[etype][inner_type_eids] - local_eids = gpb.eid2localeid(inner_type_eids, gpb.partid, etype) - - for name in g.edges[etype].data: - if name in [dgl.EID, "inner_edge"]: - continue - true_feats = F.gather_row(g.edges[etype].data[name], orig_id) - edata = F.gather_row( - edge_feats[_etype_tuple_to_str(etype) + "/" + name], local_eids + orig_id = orig_nids[ntype][inner_type_nids] + local_nids = gpb.nid2localnid(inner_type_nids, gpb.partid, ntype) + + for name in g.nodes[ntype].data: + if name in [dgl.NID, "inner_node"]: + continue + true_feats = F.gather_row(g.nodes[ntype].data[name], orig_id) + ndata = F.gather_row(node_feats[ntype + "/" + name], local_nids) + assert np.all(F.asnumpy(ndata == true_feats)) + + for etype in g.canonical_etypes: + etype_id = g.get_etype_id(etype) + inner_edge_mask = _get_inner_edge_mask( + part, etype_id, use_graphbolt ) - assert np.all(F.asnumpy(edata == true_feats)) + inner_eids = F.boolean_mask( + part.edge_attributes[dgl.EID], inner_edge_mask + ) + etype_ids, inner_type_eids = gpb.map_to_per_etype(inner_eids) + partid = gpb.eid2partid(inner_type_eids, etype) + assert np.all(F.asnumpy(etype_ids) == etype_id) + assert np.all(F.asnumpy(partid) == gpb.partid) + + orig_id = orig_eids[etype][inner_type_eids] + local_eids = gpb.eid2localeid(inner_type_eids, gpb.partid, etype) + + for name in g.edges[etype].data: + if name in [dgl.EID, "inner_edge"]: + continue + true_feats = F.gather_row(g.edges[etype].data[name], orig_id) + edata 
= F.gather_row( + edge_feats[_etype_tuple_to_str(etype) + "/" + name], + local_eids, + ) + assert np.all(F.asnumpy(edata == true_feats)) + else: + for ntype in g.ntypes: + ntype_id = g.get_ntype_id(ntype) + inner_node_mask = _get_inner_node_mask( + part, ntype_id, use_graphbolt + ) + inner_nids = F.boolean_mask(part.ndata[dgl.NID], inner_node_mask) + ntype_ids, inner_type_nids = gpb.map_to_per_ntype(inner_nids) + partid = gpb.nid2partid(inner_type_nids, ntype) + assert np.all(F.asnumpy(ntype_ids) == ntype_id) + assert np.all(F.asnumpy(partid) == gpb.partid) + + orig_id = orig_nids[ntype][inner_type_nids] + local_nids = gpb.nid2localnid(inner_type_nids, gpb.partid, ntype) + + for name in g.nodes[ntype].data: + if name in [dgl.NID, "inner_node"]: + continue + true_feats = F.gather_row(g.nodes[ntype].data[name], orig_id) + ndata = F.gather_row(node_feats[ntype + "/" + name], local_nids) + assert np.all(F.asnumpy(ndata == true_feats)) + + for etype in g.canonical_etypes: + etype_id = g.get_etype_id(etype) + inner_edge_mask = _get_inner_edge_mask(part, etype_id) + inner_eids = F.boolean_mask(part.edata[dgl.EID], inner_edge_mask) + etype_ids, inner_type_eids = gpb.map_to_per_etype(inner_eids) + partid = gpb.eid2partid(inner_type_eids, etype) + assert np.all(F.asnumpy(etype_ids) == etype_id) + assert np.all(F.asnumpy(partid) == gpb.partid) + + orig_id = orig_eids[etype][inner_type_eids] + local_eids = gpb.eid2localeid(inner_type_eids, gpb.partid, etype) + + for name in g.edges[etype].data: + if name in [dgl.EID, "inner_edge"]: + continue + true_feats = F.gather_row(g.edges[etype].data[name], orig_id) + edata = F.gather_row( + edge_feats[_etype_tuple_to_str(etype) + "/" + name], + local_eids, + ) + assert np.all(F.asnumpy(edata == true_feats)) def check_hetero_partition( @@ -245,7 +361,7 @@ def check_hetero_partition( shuffled_labels = [] shuffled_elabels = [] for i in range(num_parts): - part_g, node_feats, edge_feats, gpb, _, ntypes, etypes = load_partition( + part_g, 
node_feats, edge_feats, gpb, _, _, _ = load_partition( "/tmp/partition/test.json", i, load_feats=load_feats ) _verify_partition_data_types(part_g) @@ -1075,17 +1191,12 @@ def test_not_sorted_node_edge_map(): @pytest.mark.parametrize("part_method", ["metis", "random"]) @pytest.mark.parametrize("num_parts", [1, 4]) -@pytest.mark.parametrize("store_eids", [True, False]) -@pytest.mark.parametrize("store_inner_node", [True, False]) -@pytest.mark.parametrize("store_inner_edge", [True, False]) @pytest.mark.parametrize("debug_mode", [True, False]) def test_partition_graph_graphbolt_homo( part_method, num_parts, - store_eids, - store_inner_node, - store_inner_edge, debug_mode, + num_trainers_per_machine=1, ): reset_envs() if debug_mode: @@ -1093,148 +1204,369 @@ def test_partition_graph_graphbolt_homo( with tempfile.TemporaryDirectory() as test_dir: g = create_random_graph(1000) graph_name = "test" - partition_graph( + g.ndata["labels"] = F.arange(0, g.num_nodes()) + g.ndata["feats"] = F.tensor( + np.random.randn(g.num_nodes(), 10), F.float32 + ) + g.edata["feats"] = F.tensor( + np.random.randn(g.num_edges(), 10), F.float32 + ) + g.update_all(fn.copy_u("feats", "msg"), fn.sum("msg", "h")) + g.update_all(fn.copy_e("feats", "msg"), fn.sum("msg", "eh")) + + orig_nids, orig_eids = partition_graph( g, graph_name, num_parts, test_dir, part_method=part_method, use_graphbolt=True, - store_eids=store_eids, - store_inner_node=store_inner_node, - store_inner_edge=store_inner_edge, + store_eids=True, + store_inner_node=True, + store_inner_edge=True, + return_mapping=True, ) + part_sizes = [] + shuffled_labels = [] + shuffled_edata = [] part_config = os.path.join(test_dir, f"{graph_name}.json") - for part_id in range(num_parts): - orig_g = dgl.load_graphs( - os.path.join(test_dir, f"part{part_id}/graph.dgl") - )[0][0] - new_g = load_partition( - part_config, part_id, load_feats=False, use_graphbolt=True - )[0] - orig_indptr, orig_indices, orig_eids = orig_g.adj().csc() - assert 
th.equal(orig_indptr, new_g.csc_indptr) - assert th.equal(orig_indices, new_g.indices) - assert new_g.node_type_offset is None - assert th.equal( - orig_g.ndata[dgl.NID], new_g.node_attributes[dgl.NID] + for i in range(num_parts): + part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition( + part_config, i, load_feats=True, use_graphbolt=True ) - if store_inner_node or debug_mode: - assert th.equal( - orig_g.ndata["inner_node"], - new_g.node_attributes["inner_node"], - ) - else: - assert "inner_node" not in new_g.node_attributes - if store_eids or debug_mode: - assert th.equal( - orig_g.edata[dgl.EID][orig_eids], - new_g.edge_attributes[dgl.EID], - ) - else: - assert dgl.EID not in new_g.edge_attributes - if store_inner_edge or debug_mode: - assert th.equal( - orig_g.edata["inner_edge"][orig_eids], - new_g.edge_attributes["inner_edge"], - ) - else: - assert "inner_edge" not in new_g.edge_attributes - assert new_g.type_per_edge is None - assert new_g.node_type_to_id is None - assert new_g.edge_type_to_id is None + if num_trainers_per_machine > 1: + for ntype in g.ntypes: + name = ntype + "/trainer_id" + assert name in node_feats + part_ids = F.floor_div( + node_feats[name], num_trainers_per_machine + ) + assert np.all(F.asnumpy(part_ids) == i) + + for etype in g.canonical_etypes: + name = _etype_tuple_to_str(etype) + "/trainer_id" + assert name in edge_feats + part_ids = F.floor_div( + edge_feats[name], num_trainers_per_machine + ) + assert np.all(F.asnumpy(part_ids) == i) + + # Check the metadata + assert gpb._num_nodes() == g.num_nodes() + assert gpb._num_edges() == g.num_edges() + + assert gpb.num_partitions() == num_parts + gpb_meta = gpb.metadata() + assert len(gpb_meta) == num_parts + assert len(gpb.partid2nids(i)) == gpb_meta[i]["num_nodes"] + assert len(gpb.partid2eids(i)) == gpb_meta[i]["num_edges"] + part_sizes.append( + (gpb_meta[i]["num_nodes"], gpb_meta[i]["num_edges"]) + ) + + nid = F.boolean_mask( + part_g.node_attributes[dgl.NID], + 
part_g.node_attributes["inner_node"], + ) + local_nid = gpb.nid2localnid(nid, i) + assert F.dtype(local_nid) in (F.int64, F.int32) + assert np.all(F.asnumpy(local_nid) == np.arange(0, len(local_nid))) + eid = F.boolean_mask( + part_g.edge_attributes[dgl.EID], + part_g.edge_attributes["inner_edge"], + ) + local_eid = gpb.eid2localeid(eid, i) + assert F.dtype(local_eid) in (F.int64, F.int32) + assert np.all( + np.sort(F.asnumpy(local_eid)) == np.arange(0, len(local_eid)) + ) + + # Check the node map. + local_nodes = F.boolean_mask( + part_g.node_attributes[dgl.NID], + part_g.node_attributes["inner_node"], + ) + llocal_nodes = F.nonzero_1d(part_g.node_attributes["inner_node"]) + local_nodes1 = gpb.partid2nids(i) + assert F.dtype(local_nodes1) in (F.int32, F.int64) + assert np.all( + np.sort(F.asnumpy(local_nodes)) + == np.sort(F.asnumpy(local_nodes1)) + ) + assert np.all( + F.asnumpy(llocal_nodes) == np.arange(len(llocal_nodes)) + ) + + # Check the edge map. + local_edges = F.boolean_mask( + part_g.edge_attributes[dgl.EID], + part_g.edge_attributes["inner_edge"], + ) + llocal_edges = F.nonzero_1d(part_g.edge_attributes["inner_edge"]) + local_edges1 = gpb.partid2eids(i) + assert F.dtype(local_edges1) in (F.int32, F.int64) + assert np.all( + np.sort(F.asnumpy(local_edges)) + == np.sort(F.asnumpy(local_edges1)) + ) + assert np.all( + F.asnumpy(llocal_edges) == np.arange(len(llocal_edges)) + ) + + # Verify the mapping between the reshuffled IDs and the original IDs. 
+ indices, indptr = part_g.indices, part_g.csc_indptr + adj_matrix = dglsp.from_csc(indptr, indices) + part_src_ids, part_dst_ids = adj_matrix.coo() + part_src_ids = F.gather_row( + part_g.node_attributes[dgl.NID], part_src_ids + ) + part_dst_ids = F.gather_row( + part_g.node_attributes[dgl.NID], part_dst_ids + ) + part_eids = part_g.edge_attributes[dgl.EID] + orig_src_ids = F.gather_row(orig_nids, part_src_ids) + orig_dst_ids = F.gather_row(orig_nids, part_dst_ids) + orig_eids1 = F.gather_row(orig_eids, part_eids) + orig_eids2 = g.edge_ids(orig_src_ids, orig_dst_ids) + assert F.shape(orig_eids1)[0] == F.shape(orig_eids2)[0] + assert np.all(F.asnumpy(orig_eids1) == F.asnumpy(orig_eids2)) + + local_orig_nids = orig_nids[part_g.node_attributes[dgl.NID]] + local_orig_eids = orig_eids[part_g.edge_attributes[dgl.EID]] + part_g.node_attributes["feats"] = F.gather_row( + g.ndata["feats"], local_orig_nids + ) + part_g.edge_attributes["feats"] = F.gather_row( + g.edata["feats"], local_orig_eids + ) + local_nodes = orig_nids[local_nodes] + local_edges = orig_eids[local_edges] + + # part_g.update_all(fn.copy_u("feats", "msg"), fn.sum("msg", "h")) + # part_g.update_all(fn.copy_e("feats", "msg"), fn.sum("msg", "eh")) + # part_g.node_attributes["h"] = adj_matrix@part_g.node_attributes["h"] + + # assert F.allclose( + # F.gather_row(g.ndata["h"], local_nodes), + # F.gather_row(part_g.node_attributes["h"], llocal_nodes), + # ) + # assert F.allclose( + # F.gather_row(g.ndata["eh"], local_nodes), + # F.gather_row(part_g.node_attributes["eh"], llocal_nodes), + # ) + + for name in ["labels", "feats"]: + assert "_N/" + name in node_feats + assert node_feats["_N/" + name].shape[0] == len(local_nodes) + true_feats = F.gather_row(g.ndata[name], local_nodes) + ndata = F.gather_row(node_feats["_N/" + name], local_nid) + assert np.all(F.asnumpy(true_feats) == F.asnumpy(ndata)) + for name in ["feats"]: + efeat_name = _etype_tuple_to_str(DEFAULT_ETYPE) + "/" + name + assert efeat_name in 
edge_feats + assert edge_feats[efeat_name].shape[0] == len(local_edges) + true_feats = F.gather_row(g.edata[name], local_edges) + edata = F.gather_row(edge_feats[efeat_name], local_eid) + assert np.all(F.asnumpy(true_feats) == F.asnumpy(edata)) + + # This only works if node/edge IDs are shuffled. + shuffled_labels.append(node_feats["_N/labels"]) + shuffled_edata.append(edge_feats["_N:_E:_N/feats"]) + + # Verify that we can reconstruct node/edge data for original IDs. + shuffled_labels = F.asnumpy(F.cat(shuffled_labels, 0)) + shuffled_edata = F.asnumpy(F.cat(shuffled_edata, 0)) + orig_labels = np.zeros( + shuffled_labels.shape, dtype=shuffled_labels.dtype + ) + orig_edata = np.zeros(shuffled_edata.shape, dtype=shuffled_edata.dtype) + orig_labels[F.asnumpy(orig_nids)] = shuffled_labels + orig_edata[F.asnumpy(orig_eids)] = shuffled_edata + assert np.all(orig_labels == F.asnumpy(g.ndata["labels"])) + assert np.all(orig_edata == F.asnumpy(g.edata["feats"])) + + node_map = [] + edge_map = [] + for i, (num_nodes, num_edges) in enumerate(part_sizes): + node_map.append(np.ones(num_nodes) * i) + edge_map.append(np.ones(num_edges) * i) + node_map = np.concatenate(node_map) + edge_map = np.concatenate(edge_map) + nid2pid = gpb.nid2partid(F.arange(0, len(node_map))) + assert F.dtype(nid2pid) in (F.int32, F.int64) + assert np.all(F.asnumpy(nid2pid) == node_map) + eid2pid = gpb.eid2partid(F.arange(0, len(edge_map))) + assert F.dtype(eid2pid) in (F.int32, F.int64) + assert np.all(F.asnumpy(eid2pid) == edge_map) @pytest.mark.parametrize("part_method", ["metis", "random"]) @pytest.mark.parametrize("num_parts", [1, 4]) -@pytest.mark.parametrize("store_eids", [True, False]) -@pytest.mark.parametrize("store_inner_node", [True, False]) -@pytest.mark.parametrize("store_inner_edge", [True, False]) @pytest.mark.parametrize("debug_mode", [True, False]) def test_partition_graph_graphbolt_hetero( part_method, num_parts, - store_eids, - store_inner_node, - store_inner_edge, debug_mode, 
n_jobs=1, + num_trainers_per_machine=1, ): + test_ntype = "n1" + test_etype = ("n1", "r1", "n2") reset_envs() if debug_mode: os.environ["DGL_DIST_DEBUG"] = "1" with tempfile.TemporaryDirectory() as test_dir: - g = create_random_hetero() + hg = create_random_hetero() graph_name = "test" - partition_graph( - g, + hg.nodes[test_ntype].data["labels"] = F.arange( + 0, hg.num_nodes(test_ntype) + ) + hg.nodes[test_ntype].data["feats"] = F.tensor( + np.random.randn(hg.num_nodes(test_ntype), 10), F.float32 + ) + hg.edges[test_etype].data["feats"] = F.tensor( + np.random.randn(hg.num_edges(test_etype), 10), F.float32 + ) + hg.edges[test_etype].data["labels"] = F.arange( + 0, hg.num_edges(test_etype) + ) + num_hops = 1 + orig_nids, orig_eids = partition_graph( + hg, graph_name, num_parts, test_dir, part_method=part_method, + return_mapping=True, + num_trainers_per_machine=1, use_graphbolt=True, - store_eids=store_eids, - store_inner_node=store_inner_node, - store_inner_edge=store_inner_edge, + store_eids=True, + store_inner_node=True, + store_inner_edge=True, n_jobs=n_jobs, ) + assert len(orig_nids) == len(hg.ntypes) + assert len(orig_eids) == len(hg.canonical_etypes) + for ntype in hg.ntypes: + assert len(orig_nids[ntype]) == hg.num_nodes(ntype) + for etype in hg.canonical_etypes: + assert len(orig_eids[etype]) == hg.num_edges(etype) + parts = [] + shuffled_labels = [] + shuffled_elabels = [] part_config = os.path.join(test_dir, f"{graph_name}.json") for part_id in range(num_parts): - orig_g = dgl.load_graphs( - os.path.join(test_dir, f"part{part_id}/graph.dgl") - )[0][0] - new_g = load_partition( - part_config, part_id, load_feats=False, use_graphbolt=True - )[0] - orig_indptr, orig_indices, orig_eids = orig_g.adj().csc() - assert th.equal(orig_indptr, new_g.csc_indptr) - assert th.equal(orig_indices, new_g.indices) - assert th.equal( - orig_g.ndata[dgl.NID], new_g.node_attributes[dgl.NID] + part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition( + part_config, 
part_id, load_feats=True, use_graphbolt=True ) - if store_inner_node or debug_mode: - assert th.equal( - orig_g.ndata["inner_node"], - new_g.node_attributes["inner_node"], + if num_trainers_per_machine > 1: + for ntype in hg.ntypes: + name = ntype + "/trainer_id" + assert name in node_feats + part_ids = F.floor_div( + node_feats[name], num_trainers_per_machine + ) + assert np.all(F.asnumpy(part_ids) == part_id) + + for etype in hg.canonical_etypes: + name = _etype_tuple_to_str(etype) + "/trainer_id" + assert name in edge_feats + part_ids = F.floor_div( + edge_feats[name], num_trainers_per_machine + ) + assert np.all(F.asnumpy(part_ids) == part_id) + + # Verify the mapping between the reshuffled IDs and the original IDs. + # These are partition-local IDs. + indices, indptr = part_g.indices, part_g.csc_indptr + csc_matrix = dglsp.from_csc(indptr, indices) + part_src_ids, part_dst_ids = csc_matrix.coo() + # These are reshuffled global homogeneous IDs. + part_src_ids = F.gather_row( + part_g.node_attributes[dgl.NID], part_src_ids + ) + part_dst_ids = F.gather_row( + part_g.node_attributes[dgl.NID], part_dst_ids + ) + part_eids = part_g.edge_attributes[dgl.EID] + # These are reshuffled per-type IDs. + src_ntype_ids, part_src_ids = gpb.map_to_per_ntype(part_src_ids) + dst_ntype_ids, part_dst_ids = gpb.map_to_per_ntype(part_dst_ids) + etype_ids, part_eids = gpb.map_to_per_etype(part_eids) + # `IdMap` is in int64 by default. + assert src_ntype_ids.dtype == F.int64 + assert dst_ntype_ids.dtype == F.int64 + assert etype_ids.dtype == F.int64 + with pytest.raises(dgl.utils.internal.InconsistentDtypeException): + gpb.map_to_per_ntype(F.tensor([0], F.int32)) + with pytest.raises(dgl.utils.internal.InconsistentDtypeException): + gpb.map_to_per_etype(F.tensor([0], F.int32)) + # These are original per-type IDs. 
+ for etype_id, etype in enumerate(hg.canonical_etypes): + part_src_ids1 = F.boolean_mask( + part_src_ids, etype_ids == etype_id ) - else: - assert "inner_node" not in new_g.node_attributes - if debug_mode: - assert th.equal( - orig_g.ndata[dgl.NTYPE], new_g.node_attributes[dgl.NTYPE] + src_ntype_ids1 = F.boolean_mask( + src_ntype_ids, etype_ids == etype_id ) - else: - assert dgl.NTYPE not in new_g.node_attributes - if store_eids or debug_mode: - assert th.equal( - orig_g.edata[dgl.EID][orig_eids], - new_g.edge_attributes[dgl.EID], + part_dst_ids1 = F.boolean_mask( + part_dst_ids, etype_ids == etype_id ) - else: - assert dgl.EID not in new_g.edge_attributes - if store_inner_edge or debug_mode: - assert th.equal( - orig_g.edata["inner_edge"], - new_g.edge_attributes["inner_edge"], + dst_ntype_ids1 = F.boolean_mask( + dst_ntype_ids, etype_ids == etype_id ) - else: - assert "inner_edge" not in new_g.edge_attributes - if debug_mode: - assert th.equal( - orig_g.edata[dgl.ETYPE][orig_eids], - new_g.edge_attributes[dgl.ETYPE], + part_eids1 = F.boolean_mask(part_eids, etype_ids == etype_id) + assert np.all(F.asnumpy(src_ntype_ids1 == src_ntype_ids1[0])) + assert np.all(F.asnumpy(dst_ntype_ids1 == dst_ntype_ids1[0])) + src_ntype = hg.ntypes[F.as_scalar(src_ntype_ids1[0])] + dst_ntype = hg.ntypes[F.as_scalar(dst_ntype_ids1[0])] + orig_src_ids1 = F.gather_row( + orig_nids[src_ntype], part_src_ids1 ) - else: - assert dgl.ETYPE not in new_g.edge_attributes - assert th.equal( - orig_g.edata[dgl.ETYPE][orig_eids], new_g.type_per_edge + orig_dst_ids1 = F.gather_row( + orig_nids[dst_ntype], part_dst_ids1 + ) + orig_eids1 = F.gather_row(orig_eids[etype], part_eids1) + orig_eids2 = hg.edge_ids( + orig_src_ids1, orig_dst_ids1, etype=etype + ) + assert len(orig_eids1) == len(orig_eids2) + assert np.all(F.asnumpy(orig_eids1) == F.asnumpy(orig_eids2)) + parts.append(part_g) + if NTYPE in part_g.node_attributes: + verify_graph_feats( + hg, + gpb, + part_g, + node_feats, + edge_feats, + 
orig_nids, + orig_eids, + use_graphbolt=True, + ) + + shuffled_labels.append(node_feats[test_ntype + "/labels"]) + shuffled_elabels.append( + edge_feats[_etype_tuple_to_str(test_etype) + "/labels"] ) + verify_hetero_graph(hg, parts, True) - for node_type, type_id in new_g.node_type_to_id.items(): - assert g.get_ntype_id(node_type) == type_id - for edge_type, type_id in new_g.edge_type_to_id.items(): - assert g.get_etype_id(_etype_str_to_tuple(edge_type)) == type_id - assert new_g.node_type_offset is None + shuffled_labels = F.asnumpy(F.cat(shuffled_labels, 0)) + shuffled_elabels = F.asnumpy(F.cat(shuffled_elabels, 0)) + orig_labels = np.zeros( + shuffled_labels.shape, dtype=shuffled_labels.dtype + ) + orig_elabels = np.zeros( + shuffled_elabels.shape, dtype=shuffled_elabels.dtype + ) + orig_labels[F.asnumpy(orig_nids[test_ntype])] = shuffled_labels + orig_elabels[F.asnumpy(orig_eids[test_etype])] = shuffled_elabels + assert np.all( + orig_labels == F.asnumpy(hg.nodes[test_ntype].data["labels"]) + ) + assert np.all( + orig_elabels == F.asnumpy(hg.edges[test_etype].data["labels"]) + ) @pytest.mark.parametrize("part_method", ["metis", "random"]) @@ -1461,9 +1793,6 @@ def test_partition_graph_graphbolt_hetero_multi( part_method="random", num_parts=num_parts, n_jobs=4, - store_eids=True, - store_inner_node=True, - store_inner_edge=True, debug_mode=False, ) From 21b592da7f21677b4d9e8efd042a3f64b0e1619f Mon Sep 17 00:00:00 2001 From: Ubuntu <2649624957@qq.com> Date: Wed, 21 Aug 2024 03:55:43 +0000 Subject: [PATCH 04/37] change test_partition.py and partiton.py --- python/dgl/distributed/partition.py | 41 ++++++++++++++++++++++------- 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/python/dgl/distributed/partition.py b/python/dgl/distributed/partition.py index 73ea48959597..7d81a15c8cfd 100644 --- a/python/dgl/distributed/partition.py +++ b/python/dgl/distributed/partition.py @@ -109,19 +109,40 @@ def _save_graphs(filename, g_list, formats=None, 
sort_etypes=False): save_graphs(filename, g_list, formats=formats) -def _get_inner_node_mask(graph, ntype_id): - if NTYPE in graph.ndata: - dtype = F.dtype(graph.ndata["inner_node"]) - return ( - graph.ndata["inner_node"] - * F.astype(graph.ndata[NTYPE] == ntype_id, dtype) - == 1 - ) +def _get_inner_node_mask(graph, ntype_id, use_graphbolt=False): + if use_graphbolt: + if NTYPE in graph.node_attributes: + dtype = F.dtype(graph.node_attributes["inner_node"]) + return ( + graph.node_attributes["inner_node"] + * F.astype(graph.node_attributes[NTYPE] == ntype_id, dtype) + == 1 + ) + else: + return graph.node_attributes["inner_node"] == 1 else: - return graph.ndata["inner_node"] == 1 + if NTYPE in graph.ndata: + dtype = F.dtype(graph.ndata["inner_node"]) + return ( + graph.ndata["inner_node"] + * F.astype(graph.ndata[NTYPE] == ntype_id, dtype) + == 1 + ) + else: + return graph.ndata["inner_node"] == 1 -def _get_inner_edge_mask(graph, etype_id): +def _get_inner_edge_mask(graph, etype_id, use_graphbolt=False): + if use_graphbolt: + if graph.type_per_edge is not None: + dtype = F.dtype(graph.edge_attributes["inner_edge"]) + return ( + graph.edge_attributes["inner_edge"] + * F.astype(graph.type_per_edge == etype_id, dtype) + == 1 + ) + else: + return graph.edge_attributes["inner_edge"] == 1 if ETYPE in graph.edata: dtype = F.dtype(graph.edata["inner_edge"]) return ( From 4ef95d5ae932c81a7d94fa1afc015969a39b2801 Mon Sep 17 00:00:00 2001 From: Ubuntu <2649624957@qq.com> Date: Wed, 21 Aug 2024 04:08:24 +0000 Subject: [PATCH 05/37] partition --- python/dgl/distributed/partition.py | 354 +++++++++++++++++++++------- 1 file changed, 270 insertions(+), 84 deletions(-) diff --git a/python/dgl/distributed/partition.py b/python/dgl/distributed/partition.py index 7d81a15c8cfd..2559f6ec943e 100644 --- a/python/dgl/distributed/partition.py +++ b/python/dgl/distributed/partition.py @@ -88,7 +88,7 @@ def _dump_part_config(part_config, part_metadata): json.dump(part_metadata, outfile, 
sort_keys=False, indent=4) -def _save_graphs(filename, g_list, formats=None, sort_etypes=False): +def _process_partitions(g_list, formats=None, sort_etypes=False): """Preprocess partitions before saving: 1. format data types. 2. sort csc/csr by tag. @@ -106,6 +106,13 @@ def _save_graphs(filename, g_list, formats=None, sort_etypes=False): g = sort_csr_by_tag(g, tag=g.edata[ETYPE], tag_type="edge") if "csc" in formats: g = sort_csc_by_tag(g, tag=g.edata[ETYPE], tag_type="edge") + return g_list + + +def _save_graphs(filename, g_list, formats=None, sort_etypes=False): + g_list = _process_partitions( + g_list, formats=formats, sort_etypes=sort_etypes + ) save_graphs(filename, g_list, formats=formats) @@ -336,9 +343,10 @@ def load_partition(part_config, part_id, load_feats=True, use_graphbolt=False): "part-{}".format(part_id) in part_metadata ), "part-{} does not exist".format(part_id) part_files = part_metadata["part-{}".format(part_id)] - part_graph_field = "part_graph" if use_graphbolt: part_graph_field = "part_graph_graphbolt" + else: + part_graph_field = "part_graph" assert ( part_graph_field in part_files ), f"the partition does not contain graph structure: {part_graph_field}" @@ -465,6 +473,105 @@ def load_partition_feats( return node_feats, edge_feats +def _load_partition_book_from_metadata(part_metadata, part_id): + assert "num_parts" in part_metadata, "num_parts does not exist." + assert ( + part_metadata["num_parts"] > part_id + ), "part {} is out of range (#parts: {})".format( + part_id, part_metadata["num_parts"] + ) + num_parts = part_metadata["num_parts"] + assert ( + "num_nodes" in part_metadata + ), "cannot get the number of nodes of the global graph." + assert ( + "num_edges" in part_metadata + ), "cannot get the number of edges of the global graph." + assert "node_map" in part_metadata, "cannot get the node map." + assert "edge_map" in part_metadata, "cannot get the edge map." 
+ assert "graph_name" in part_metadata, "cannot get the graph name" + + # If this is a range partitioning, node_map actually stores a list, whose elements + # indicate the boundary of range partitioning. Otherwise, node_map stores a filename + # that contains node map in a NumPy array. + node_map = part_metadata["node_map"] + edge_map = part_metadata["edge_map"] + if isinstance(node_map, dict): + for key in node_map: + is_range_part = isinstance(node_map[key], list) + break + elif isinstance(node_map, list): + is_range_part = True + node_map = {DEFAULT_NTYPE: node_map} + else: + is_range_part = False + if isinstance(edge_map, list): + edge_map = {DEFAULT_ETYPE: edge_map} + + ntypes = {DEFAULT_NTYPE: 0} + etypes = {DEFAULT_ETYPE: 0} + if "ntypes" in part_metadata: + ntypes = part_metadata["ntypes"] + if "etypes" in part_metadata: + etypes = part_metadata["etypes"] + + if isinstance(node_map, dict): + for key in node_map: + assert key in ntypes, "The node type {} is invalid".format(key) + if isinstance(edge_map, dict): + for key in edge_map: + assert key in etypes, "The edge type {} is invalid".format(key) + + if not is_range_part: + raise TypeError("Only RangePartitionBook is supported currently.") + + node_map = _get_part_ranges(node_map) + edge_map = _get_part_ranges(edge_map) + + # Format dtype of node/edge map if dtype is specified. + def _format_node_edge_map(part_metadata, map_type, data): + key = f"{map_type}_map_dtype" + if key not in part_metadata: + return data + dtype = part_metadata[key] + assert dtype in ["int32", "int64"], ( + f"The {map_type} map dtype should be either int32 or int64, " + f"but got {dtype}." + ) + for key in data: + data[key] = data[key].astype(dtype) + return data + + node_map = _format_node_edge_map(part_metadata, "node", node_map) + edge_map = _format_node_edge_map(part_metadata, "edge", edge_map) + + # Sort the node/edge maps by the node/edge type ID. 
+ node_map = dict(sorted(node_map.items(), key=lambda x: ntypes[x[0]])) + edge_map = dict(sorted(edge_map.items(), key=lambda x: etypes[x[0]])) + + def _assert_is_sorted(id_map): + id_ranges = np.array(list(id_map.values())) + ids = [] + for i in range(num_parts): + ids.append(id_ranges[:, i, :]) + ids = np.array(ids).flatten() + assert np.all( + ids[:-1] <= ids[1:] + ), f"The node/edge map is not sorted: {ids}" + + _assert_is_sorted(node_map) + _assert_is_sorted(edge_map) + + return ( + RangePartitionBook( + part_id, num_parts, node_map, edge_map, ntypes, etypes + ), + part_metadata["graph_name"], + ntypes, + etypes, + ) + + def load_partition_book(part_config, part_id): """Load a graph partition book from the partition config file. @@ -1326,31 +1433,41 @@ def get_homogeneous(g, balance_ntypes): part_dir = os.path.join(out_path, "part" + str(part_id)) node_feat_file = os.path.join(part_dir, "node_feat.dgl") edge_feat_file = os.path.join(part_dir, "edge_feat.dgl") - part_graph_file = os.path.join(part_dir, "graph.dgl") - part_metadata["part-{}".format(part_id)] = { - "node_feats": os.path.relpath(node_feat_file, out_path), - "edge_feats": os.path.relpath(edge_feat_file, out_path), - "part_graph": os.path.relpath(part_graph_file, out_path), - } + os.makedirs(part_dir, mode=0o775, exist_ok=True) save_tensors(node_feat_file, node_feats) save_tensors(edge_feat_file, edge_feats) + part_metadata["part-{}".format(part_id)] = { + "node_feats": os.path.relpath(node_feat_file, out_path), + "edge_feats": os.path.relpath(edge_feat_file, out_path), + } sort_etypes = len(g.etypes) > 1 - _save_graphs( - part_graph_file, - [part], - formats=graph_formats, - sort_etypes=sort_etypes, - ) - print( - "Save partitions: {:.3f} seconds, peak memory: {:.3f} GB".format( - time.time() - start, get_peak_mem() - ) - ) + if not use_graphbolt: + part_graph_file = os.path.join(part_dir, "graph.dgl") + part_metadata["part-{}".format(part_id)][ + "part_graph" + ] = os.path.relpath(part_graph_file, 
out_path) + _save_graphs( + part_graph_file, + [part], + formats=graph_formats, + sort_etypes=sort_etypes, + ) + else: + part = _process_partitions([part], graph_formats, sort_etypes)[0] part_config = os.path.join(out_path, graph_name + ".json") - _dump_part_config(part_config, part_metadata) + if use_graphbolt: + kwargs["graph_formats"] = graph_formats + _dgl_partition_to_graphbolt( + part_config, + parts=parts, + part_meta=part_metadata, + **kwargs, + ) + else: + _dump_part_config(part_config, part_metadata) num_cuts = sim_g.num_edges() - tot_num_inner_edges if num_parts == 1: @@ -1361,12 +1478,11 @@ def get_homogeneous(g, balance_ntypes): ) ) - if use_graphbolt: - kwargs["graph_formats"] = graph_formats - dgl_partition_to_graphbolt( - part_config, - **kwargs, + print( + "Save partitions: {:.3f} seconds, peak memory: {:.3f} GB".format( + time.time() - start, get_peak_mem() ) + ) if return_mapping: return orig_nids, orig_eids @@ -1414,8 +1530,21 @@ def init_type_per_edge(graph, gpb): return etype_ids +def _load_parts(part_config, part_id, parts): + """load parts from variable or dist.""" + if parts is None: + graph, _, _, _, _, _, _ = load_partition( + part_config, part_id, load_feats=False + ) + else: + graph = parts[part_id] + return graph + + def gb_convert_single_dgl_partition( part_id, + parts, + part_meta, graph_formats, part_config, store_eids, @@ -1448,14 +1577,18 @@ def gb_convert_single_dgl_partition( "Running in debug mode which means all attributes of DGL partitions" " will be saved to the new format." 
) - - part_meta = _load_part_config(part_config) + if part_meta is None: + part_meta = _load_part_config(part_config) num_parts = part_meta["num_parts"] - graph, _, _, gpb, _, _, _ = load_partition( - part_config, part_id, load_feats=False + graph = _load_parts(part_config, part_id, parts) + + gpb, _, ntypes, etypes = ( + load_partition_book(part_config, part_id) + if part_meta is None + else _load_partition_book_from_metadata(part_meta, part_id) ) - _, _, ntypes, etypes = load_partition_book(part_config, part_id) + is_homo = is_homogeneous(ntypes, etypes) node_type_to_id = ( None if is_homo else {ntype: ntid for ntid, ntype in enumerate(ntypes)} @@ -1561,12 +1694,12 @@ def gb_convert_single_dgl_partition( node_type_to_id=node_type_to_id, edge_type_to_id=edge_type_to_id, ) - orig_graph_path = os.path.join( + orig_feats_path = os.path.join( os.path.dirname(part_config), - part_meta[f"part-{part_id}"]["part_graph"], + part_meta[f"part-{part_id}"]["node_feats"], ) csc_graph_path = os.path.join( - os.path.dirname(orig_graph_path), "fused_csc_sampling_graph.pt" + os.path.dirname(orig_feats_path), "fused_csc_sampling_graph.pt" ) torch.save(csc_graph, csc_graph_path) @@ -1574,55 +1707,17 @@ def gb_convert_single_dgl_partition( # Update graph path. -def dgl_partition_to_graphbolt( +def convert_partition_to_graphbolt( + part_meta, + graph_formats, part_config, - *, - store_eids=True, - store_inner_node=False, - store_inner_edge=False, - graph_formats=None, - n_jobs=1, + store_eids, + store_inner_node, + store_inner_edge, + n_jobs, + num_parts, + parts=None, ): - """Convert partitions of dgl to FusedCSCSamplingGraph of GraphBolt. - - This API converts `DGLGraph` partitions to `FusedCSCSamplingGraph` which is - dedicated for sampling in `GraphBolt`. New graphs will be stored alongside - original graph as `fused_csc_sampling_graph.pt`. - - In the near future, partitions are supposed to be saved as - `FusedCSCSamplingGraph` directly. At that time, this API should be deprecated. 
- - Parameters - ---------- - part_config : str - The partition configuration JSON file. - store_eids : bool, optional - Whether to store edge IDs in the new graph. Default: True. - store_inner_node : bool, optional - Whether to store inner node mask in the new graph. Default: False. - store_inner_edge : bool, optional - Whether to store inner edge mask in the new graph. Default: False. - graph_formats : str or list[str], optional - Save partitions in specified formats. It could be any combination of - `coo`, `csc`. As `csc` format is mandatory for `FusedCSCSamplingGraph`, - it is not necessary to specify this argument. It's mainly for - specifying `coo` format to save edge ID mapping and destination node - IDs. If not specified, whether to save `coo` format is determined by - the availability of the format in DGL partitions. Default: None. - n_jobs: int - Number of parallel jobs to run during partition conversion. Max parallelism - is determined by the partition count. - """ - debug_mode = "DGL_DIST_DEBUG" in os.environ - if debug_mode: - dgl_warning( - "Running in debug mode which means all attributes of DGL partitions" - " will be saved to the new format." - ) - part_meta = _load_part_config(part_config) - new_part_meta = copy.deepcopy(part_meta) - num_parts = part_meta["num_parts"] - # [Rui] DGL partitions are always saved as homogeneous graphs even though # the original graph is heterogeneous. But heterogeneous information like # node/edge types are saved as node/edge data alongside with partitions. @@ -1635,6 +1730,8 @@ def dgl_partition_to_graphbolt( # Iterate over partitions. convert_with_format = partial( gb_convert_single_dgl_partition, + parts=parts, + part_meta=part_meta, graph_formats=graph_formats, part_config=part_config, store_eids=store_eids, @@ -1664,15 +1761,104 @@ def dgl_partition_to_graphbolt( for part_id in range(num_parts): # Update graph path. 
- new_part_meta[f"part-{part_id}"][ + part_meta[f"part-{part_id}"][ "part_graph_graphbolt" ] = rel_path_results[part_id] # Save dtype info into partition config. # [TODO][Rui] Always use int64_t for node/edge IDs in GraphBolt. See more # details in #7175. - new_part_meta["node_map_dtype"] = "int64" - new_part_meta["edge_map_dtype"] = "int64" + part_meta["node_map_dtype"] = "int64" + part_meta["edge_map_dtype"] = "int64" - _dump_part_config(part_config, new_part_meta) + _dump_part_config(part_config, part_meta) print(f"Converted partitions to GraphBolt format into {part_config}") + +def _dgl_partition_to_graphbolt( + part_config, + part_meta, + parts, + *, + store_eids=True, + store_inner_node=False, + store_inner_edge=False, + graph_formats=None, + n_jobs=1, +): + debug_mode = "DGL_DIST_DEBUG" in os.environ + if debug_mode: + dgl_warning( + "Running in debug mode which means all attributes of DGL partitions" + " will be saved to the new format." + ) + new_part_meta = copy.deepcopy(part_meta) + num_parts = part_meta["num_parts"] + convert_partition_to_graphbolt(new_part_meta, + graph_formats, + part_config, + store_eids, + store_inner_node, + store_inner_edge, + n_jobs, + num_parts, + parts=parts, + ) + + +def dgl_partition_to_graphbolt( + part_config, + *, + store_eids=True, + store_inner_node=False, + store_inner_edge=False, + graph_formats=None, + n_jobs=1, +): + """Convert partitions of dgl to FusedCSCSamplingGraph of GraphBolt. + + This API converts `DGLGraph` partitions to `FusedCSCSamplingGraph` which is + dedicated for sampling in `GraphBolt`. New graphs will be stored alongside + original graph as `fused_csc_sampling_graph.pt`. + + In the near future, partitions are supposed to be saved as + `FusedCSCSamplingGraph` directly. At that time, this API should be deprecated. + + Parameters + ---------- + part_config : str + The partition configuration JSON file. + store_eids : bool, optional + Whether to store edge IDs in the new graph. Default: True. 
+ store_inner_node : bool, optional + Whether to store inner node mask in the new graph. Default: False. + store_inner_edge : bool, optional + Whether to store inner edge mask in the new graph. Default: False. + graph_formats : str or list[str], optional + Save partitions in specified formats. It could be any combination of + `coo`, `csc`. As `csc` format is mandatory for `FusedCSCSamplingGraph`, + it is not necessary to specify this argument. It's mainly for + specifying `coo` format to save edge ID mapping and destination node + IDs. If not specified, whether to save `coo` format is determined by + the availability of the format in DGL partitions. Default: None. + n_jobs: int + Number of parallel jobs to run during partition conversion. Max parallelism + is determined by the partition count. + """ + debug_mode = "DGL_DIST_DEBUG" in os.environ + if debug_mode: + dgl_warning( + "Running in debug mode which means all attributes of DGL partitions" + " will be saved to the new format." + ) + part_meta = _load_part_config(part_config) + new_part_meta = copy.deepcopy(part_meta) + num_parts = part_meta["num_parts"] + convert_partition_to_graphbolt(new_part_meta, + graph_formats, + part_config, + store_eids, + store_inner_node, + store_inner_edge, + n_jobs, + num_parts, + ) \ No newline at end of file From 1074f85a8ef230e16a73497ce1251036bd603c5b Mon Sep 17 00:00:00 2001 From: Ubuntu <2649624957@qq.com> Date: Wed, 21 Aug 2024 04:22:27 +0000 Subject: [PATCH 06/37] change partition --- python/dgl/distributed/partition.py | 49 +++++++++++++++-------------- 1 file changed, 26 insertions(+), 23 deletions(-) diff --git a/python/dgl/distributed/partition.py b/python/dgl/distributed/partition.py index 2559f6ec943e..c70e406691f1 100644 --- a/python/dgl/distributed/partition.py +++ b/python/dgl/distributed/partition.py @@ -1761,9 +1761,9 @@ def convert_partition_to_graphbolt( for part_id in range(num_parts): # Update graph path. 
- part_meta[f"part-{part_id}"][ - "part_graph_graphbolt" - ] = rel_path_results[part_id] + part_meta[f"part-{part_id}"]["part_graph_graphbolt"] = rel_path_results[ + part_id + ] # Save dtype info into partition config. # [TODO][Rui] Always use int64_t for node/edge IDs in GraphBolt. See more @@ -1774,6 +1774,7 @@ def convert_partition_to_graphbolt( _dump_part_config(part_config, part_meta) print(f"Converted partitions to GraphBolt format into {part_config}") + def _dgl_partition_to_graphbolt( part_config, part_meta, @@ -1793,17 +1794,18 @@ def _dgl_partition_to_graphbolt( ) new_part_meta = copy.deepcopy(part_meta) num_parts = part_meta["num_parts"] - convert_partition_to_graphbolt(new_part_meta, - graph_formats, - part_config, - store_eids, - store_inner_node, - store_inner_edge, - n_jobs, - num_parts, - parts=parts, - ) - + convert_partition_to_graphbolt( + new_part_meta, + graph_formats, + part_config, + store_eids, + store_inner_node, + store_inner_edge, + n_jobs, + num_parts, + parts=parts, + ) + def dgl_partition_to_graphbolt( part_config, @@ -1853,12 +1855,13 @@ def dgl_partition_to_graphbolt( part_meta = _load_part_config(part_config) new_part_meta = copy.deepcopy(part_meta) num_parts = part_meta["num_parts"] - convert_partition_to_graphbolt(new_part_meta, - graph_formats, - part_config, - store_eids, - store_inner_node, - store_inner_edge, - n_jobs, - num_parts, - ) \ No newline at end of file + convert_partition_to_graphbolt( + new_part_meta, + graph_formats, + part_config, + store_eids, + store_inner_node, + store_inner_edge, + n_jobs, + num_parts, + ) From e03376dbd637dcce6b5179e3f49d991d17e0da5e Mon Sep 17 00:00:00 2001 From: Ubuntu <2649624957@qq.com> Date: Wed, 21 Aug 2024 04:53:31 +0000 Subject: [PATCH 07/37] change partition internal function --- python/dgl/distributed/partition.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/dgl/distributed/partition.py b/python/dgl/distributed/partition.py index 
c70e406691f1..0f5e317b1627 100644 --- a/python/dgl/distributed/partition.py +++ b/python/dgl/distributed/partition.py @@ -1707,7 +1707,7 @@ def gb_convert_single_dgl_partition( # Update graph path. -def convert_partition_to_graphbolt( +def _convert_partition_to_graphbolt( part_meta, graph_formats, part_config, @@ -1794,7 +1794,7 @@ def _dgl_partition_to_graphbolt( ) new_part_meta = copy.deepcopy(part_meta) num_parts = part_meta["num_parts"] - convert_partition_to_graphbolt( + _convert_partition_to_graphbolt( new_part_meta, graph_formats, part_config, @@ -1855,7 +1855,7 @@ def dgl_partition_to_graphbolt( part_meta = _load_part_config(part_config) new_part_meta = copy.deepcopy(part_meta) num_parts = part_meta["num_parts"] - convert_partition_to_graphbolt( + _convert_partition_to_graphbolt( new_part_meta, graph_formats, part_config, From 16097f8002b3b0943681bb951a2b9abfc51ea8e0 Mon Sep 17 00:00:00 2001 From: Ubuntu <2649624957@qq.com> Date: Wed, 21 Aug 2024 06:57:43 +0000 Subject: [PATCH 08/37] dist partition --- tools/distpartitioning/convert_partition.py | 214 ++++++++++++++----- tools/distpartitioning/data_proc_pipeline.py | 5 + tools/distpartitioning/data_shuffle.py | 10 +- tools/distpartitioning/utils.py | 28 ++- 4 files changed, 203 insertions(+), 54 deletions(-) diff --git a/tools/distpartitioning/convert_partition.py b/tools/distpartitioning/convert_partition.py index a169589a3f97..c0605e28dc8e 100644 --- a/tools/distpartitioning/convert_partition.py +++ b/tools/distpartitioning/convert_partition.py @@ -8,6 +8,8 @@ import constants import dgl +import dgl.graphbolt as gb +import dgl.sparse as dglsp import numpy as np import pandas as pd import pyarrow @@ -164,6 +166,12 @@ def _get_unique_invidx(srcids, dstids, nids, low_mem=True): return uniques, idxes, srcids, dstids +# Utility functions. 
+def is_homogeneous(ntypes, etypes): + """Checks if the provided ntypes and etypes form a homogeneous graph.""" + return len(ntypes) == 1 and len(etypes) == 1 + + def create_dgl_object( schema, part_id, @@ -174,6 +182,7 @@ def create_dgl_object( edge_typecounts, return_orig_nids=False, return_orig_eids=False, + use_graphbolt=False, ): """ This function creates dgl objects for a given graph partition, as in function @@ -450,56 +459,156 @@ def create_dgl_object( ) # create the graph here now. - part_graph = dgl.graph( - data=(part_local_src_id, part_local_dst_id), num_nodes=len(uniq_ids) - ) - part_graph.edata[dgl.EID] = th.arange( - edgeid_offset, - edgeid_offset + part_graph.num_edges(), - dtype=th.int64, - ) - part_graph.edata[dgl.ETYPE] = th.as_tensor( - etype_ids, dtype=RESERVED_FIELD_DTYPE[dgl.ETYPE] - ) - part_graph.edata["inner_edge"] = th.ones( - part_graph.num_edges(), dtype=RESERVED_FIELD_DTYPE["inner_edge"] - ) + # create the graph here now. + if use_graphbolt: + edge_attr = {} + num_edges = len(part_local_dst_id) + edge_attr[dgl.EID] = th.arange( + edgeid_offset, + edgeid_offset + num_edges, + dtype=th.int64, + ) + type_per_edge = th.as_tensor( + etype_ids, dtype=RESERVED_FIELD_DTYPE[dgl.ETYPE] + ) + edge_attr["inner_edge"] = th.ones( + num_edges, dtype=RESERVED_FIELD_DTYPE["inner_edge"] + ) - # compute per_type_ids and ntype for all the nodes in the graph. - global_ids = np.concatenate([global_src_id, global_dst_id, global_homo_nid]) - part_global_ids = global_ids[idx] - part_global_ids = part_global_ids[reshuffle_nodes] - ntype, per_type_ids = id_map(part_global_ids) + # compute per_type_ids and ntype for all the nodes in the graph. 
+ global_ids = np.concatenate( + [global_src_id, global_dst_id, global_homo_nid] + ) + part_global_ids = global_ids[idx] + part_global_ids = part_global_ids[reshuffle_nodes] + ntype, per_type_ids = id_map(part_global_ids) + + # continue with the graph creation + node_attr = {} + node_attr[dgl.NTYPE] = th.as_tensor( + ntype, dtype=RESERVED_FIELD_DTYPE[dgl.NTYPE] + ) + node_attr[dgl.NID] = th.as_tensor(uniq_ids[reshuffle_nodes]) + node_attr["inner_node"] = th.as_tensor( + inner_nodes[reshuffle_nodes], + dtype=RESERVED_FIELD_DTYPE["inner_node"], + ) - # continue with the graph creation - part_graph.ndata[dgl.NTYPE] = th.as_tensor( - ntype, dtype=RESERVED_FIELD_DTYPE[dgl.NTYPE] - ) - part_graph.ndata[dgl.NID] = th.as_tensor(uniq_ids[reshuffle_nodes]) - part_graph.ndata["inner_node"] = th.as_tensor( - inner_nodes[reshuffle_nodes], dtype=RESERVED_FIELD_DTYPE["inner_node"] - ) + is_homo = is_homogeneous(ntypes, etypes) + orig_nids = None + orig_eids = None + if return_orig_nids: + orig_nids = {} + for ntype, ntype_id in ntypes_map.items(): + mask = th.logical_and( + node_attr[dgl.NTYPE] == ntype_id, + node_attr["inner_node"], + ) + orig_nids[ntype] = th.as_tensor(per_type_ids[mask]) + if return_orig_eids: + orig_eids = {} + for etype, etype_id in etypes_map.items(): + mask = th.logical_and( + type_per_edge == etype_id, + edge_attr["inner_edge"], + ) + orig_eids[_etype_tuple_to_str(etype)] = th.as_tensor( + global_edge_id[mask] + ) + edge_type_to_id = ( + None + if is_homo + else { + gb.etype_tuple_to_str(etype): etid + for etype, etid in etypes_map.items() + } + ) - orig_nids = None - orig_eids = None - if return_orig_nids: - orig_nids = {} - for ntype, ntype_id in ntypes_map.items(): - mask = th.logical_and( - part_graph.ndata[dgl.NTYPE] == ntype_id, - part_graph.ndata["inner_node"], - ) - orig_nids[ntype] = th.as_tensor(per_type_ids[mask]) - if return_orig_eids: - orig_eids = {} - for etype, etype_id in etypes_map.items(): - mask = th.logical_and( - 
part_graph.edata[dgl.ETYPE] == etype_id, - part_graph.edata["inner_edge"], - ) - orig_eids[_etype_tuple_to_str(etype)] = th.as_tensor( - global_edge_id[mask] - ) + part_local_src_id, part_local_dst_id = th.tensor( + part_local_src_id, dtype=th.int64 + ), th.tensor(part_local_dst_id, dtype=th.int64) + size = max(part_local_src_id.max(), part_local_dst_id.max()) + 1 + adj_matrix = dglsp.from_coo( + part_local_src_id, part_local_dst_id, shape=(size, size) + ) + print(adj_matrix) + indptr, indices, _ = adj_matrix.csc() + del adj_matrix + + part_graph = gb.fused_csc_sampling_graph( + indptr, + indices, + node_type_offset=None, + type_per_edge=type_per_edge, + node_attributes=node_attr, + edge_attributes=edge_attr, + node_type_to_id=ntypes_map, + edge_type_to_id=edge_type_to_id, + ) + return ( + part_graph, + node_map_val, + edge_map_val, + ntypes_map, + etypes_map, + orig_nids, + orig_eids, + ) + + else: + part_graph = dgl.graph( + data=(part_local_src_id, part_local_dst_id), num_nodes=len(uniq_ids) + ) + part_graph.edata[dgl.EID] = th.arange( + edgeid_offset, + edgeid_offset + part_graph.num_edges(), + dtype=th.int64, + ) + part_graph.edata[dgl.ETYPE] = th.as_tensor( + etype_ids, dtype=RESERVED_FIELD_DTYPE[dgl.ETYPE] + ) + part_graph.edata["inner_edge"] = th.ones( + part_graph.num_edges(), dtype=RESERVED_FIELD_DTYPE["inner_edge"] + ) + + # compute per_type_ids and ntype for all the nodes in the graph. 
+ global_ids = np.concatenate( + [global_src_id, global_dst_id, global_homo_nid] + ) + part_global_ids = global_ids[idx] + part_global_ids = part_global_ids[reshuffle_nodes] + ntype, per_type_ids = id_map(part_global_ids) + + # continue with the graph creation + part_graph.ndata[dgl.NTYPE] = th.as_tensor( + ntype, dtype=RESERVED_FIELD_DTYPE[dgl.NTYPE] + ) + part_graph.ndata[dgl.NID] = th.as_tensor(uniq_ids[reshuffle_nodes]) + part_graph.ndata["inner_node"] = th.as_tensor( + inner_nodes[reshuffle_nodes], + dtype=RESERVED_FIELD_DTYPE["inner_node"], + ) + + orig_nids = None + orig_eids = None + if return_orig_nids: + orig_nids = {} + for ntype, ntype_id in ntypes_map.items(): + mask = th.logical_and( + part_graph.ndata[dgl.NTYPE] == ntype_id, + part_graph.ndata["inner_node"], + ) + orig_nids[ntype] = th.as_tensor(per_type_ids[mask]) + if return_orig_eids: + orig_eids = {} + for etype, etype_id in etypes_map.items(): + mask = th.logical_and( + part_graph.edata[dgl.ETYPE] == etype_id, + part_graph.edata["inner_edge"], + ) + orig_eids[_etype_tuple_to_str(etype)] = th.as_tensor( + global_edge_id[mask] + ) return ( part_graph, @@ -523,6 +632,7 @@ def create_metadata_json( ntypes_map, etypes_map, output_dir, + use_graphbolt, ): """ Auxiliary function to create json file for the graph partition metadata @@ -549,6 +659,8 @@ def create_metadata_json( map between edge type(string) and edge_type_id(int) output_dir : string directory where the output files are to be stored + use_graphbolt : bool + whether to use graphbolt or not Returns: -------- @@ -572,10 +684,14 @@ def create_metadata_json( part_dir = "part" + str(part_id) node_feat_file = os.path.join(part_dir, "node_feat.dgl") edge_feat_file = os.path.join(part_dir, "edge_feat.dgl") - part_graph_file = os.path.join(part_dir, "graph.dgl") + if use_graphbolt: + part_graph_file = os.path.join(part_dir, "fused_csc_sampling_graph.pt") + else: + part_graph_file = os.path.join(part_dir, "graph.dgl") + part_graph_type = 
"part_graph_graphbolt" if use_graphbolt else "part_graph" part_metadata["part-{}".format(part_id)] = { "node_feats": node_feat_file, "edge_feats": edge_feat_file, - "part_graph": part_graph_file, + part_graph_type: part_graph_file, } return part_metadata diff --git a/tools/distpartitioning/data_proc_pipeline.py b/tools/distpartitioning/data_proc_pipeline.py index 4c249a34b6b2..b26760eec5fb 100644 --- a/tools/distpartitioning/data_proc_pipeline.py +++ b/tools/distpartitioning/data_proc_pipeline.py @@ -94,6 +94,11 @@ def log_params(params): action="store_true", help="Save original edge IDs into files", ) + parser.add_argument( + "--use-graphbolt", + action="store_true", + help="Use GraphBolt for distributed partition.", + ) parser.add_argument( "--graph-formats", default=None, diff --git a/tools/distpartitioning/data_shuffle.py b/tools/distpartitioning/data_shuffle.py index e85bff5ecc4c..3b140da1b7a3 100644 --- a/tools/distpartitioning/data_shuffle.py +++ b/tools/distpartitioning/data_shuffle.py @@ -472,8 +472,8 @@ def exchange_feature( ) # exchange actual data here. 
- logging.debug(f"Rank: {rank} {featdata_key.shape=}") if featdata_key is not None: + logging.debug(f"Rank: {rank} {featdata_key.shape=}") feat_dims_dtype = list(featdata_key.shape) assert ( len(featdata_key.shape) == 2 or len(featdata_key.shape) == 1 @@ -1336,6 +1336,7 @@ def prepare_local_data(src_data, local_part_id): edge_typecounts, params.save_orig_nids, params.save_orig_eids, + params.use_graphbolt, ) sort_etypes = len(etypes_map) > 1 local_node_features = prepare_local_data( @@ -1354,8 +1355,12 @@ def prepare_local_data(src_data, local_part_id): orig_eids, graph_formats, sort_etypes, + params.use_graphbolt, ) - memory_snapshot("DiskWriteDGLObjectsComplete: ", rank) + if params.use_graphbolt: + memory_snapshot("DiskWriteGrapgboltObjectsComplete: ", rank) + else: + memory_snapshot("DiskWriteDGLObjectsComplete: ", rank) # get the meta-data json_metadata = create_metadata_json( @@ -1369,6 +1374,7 @@ def prepare_local_data(src_data, local_part_id): ntypes_map, etypes_map, params.output, + params.use_graphbolt, ) output_meta_json[ "local-part-id-" + str(local_part_id * world_size + rank) diff --git a/tools/distpartitioning/utils.py b/tools/distpartitioning/utils.py index cdb984be3796..fbf4ae8c0fed 100644 --- a/tools/distpartitioning/utils.py +++ b/tools/distpartitioning/utils.py @@ -504,6 +504,20 @@ def write_edge_features(edge_features, edge_file): dgl.data.utils.save_tensors(edge_file, edge_features) +def write_graph_graghbolt(graph_file, graph_obj): + """ + Utility function to serialize FusedCSCSamplingGraph + + Parameters: + ----------- + graph_obj : FusedCSCSamplingGraph + FusedCSCSamplingGraph, as created in convert_partition.py, which is to be serialized + graph_file : string + File name in which graph object is serialized + """ + torch.save(graph_obj, graph_file) + + def write_graph_dgl(graph_file, graph_obj, formats, sort_etypes): """ Utility function to serialize graph dgl objects @@ -534,6 +548,7 @@ def write_dgl_objects( orig_eids, formats, 
sort_etypes, + use_graphbolt, ): """ Wrapper function to write graph, node/edge feature, original node/edge IDs. @@ -558,12 +573,19 @@ def write_dgl_objects( Save graph in formats. sort_etypes : bool Whether to sort etypes in csc/csr. + use_graphbolt : bool + Whether to use graphbolt or not. """ part_dir = output_dir + "/part" + str(part_id) os.makedirs(part_dir, exist_ok=True) - write_graph_dgl( - os.path.join(part_dir, "graph.dgl"), graph_obj, formats, sort_etypes - ) + if use_graphbolt: + write_graph_graghbolt( + os.path.join(part_dir, "fused_csc_sampling_graph.pt"), graph_obj + ) + else: + write_graph_dgl( + os.path.join(part_dir, "graph.dgl"), graph_obj, formats, sort_etypes + ) if node_features != None: write_node_features( From 098742b08105d27178019316d53aeb596d5fdcb5 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 5 Sep 2024 12:07:17 +0000 Subject: [PATCH 09/37] modify convert_partition.py --- tools/distpartitioning/convert_partition.py | 436 ++++++++++++++------ tools/distpartitioning/data_shuffle.py | 34 +- 2 files changed, 316 insertions(+), 154 deletions(-) diff --git a/tools/distpartitioning/convert_partition.py b/tools/distpartitioning/convert_partition.py index c0605e28dc8e..976d09e3a090 100644 --- a/tools/distpartitioning/convert_partition.py +++ b/tools/distpartitioning/convert_partition.py @@ -9,10 +9,10 @@ import dgl import dgl.graphbolt as gb -import dgl.sparse as dglsp import numpy as np import pandas as pd import pyarrow +import scipy.sparse as spsp import torch as th from dgl.distributed.partition import ( _etype_str_to_tuple, @@ -167,12 +167,258 @@ def _get_unique_invidx(srcids, dstids, nids, low_mem=True): # Utility functions. 
-def is_homogeneous(ntypes, etypes): +def _is_homogeneous(ntypes, etypes): """Checks if the provided ntypes and etypes form a homogeneous graph.""" return len(ntypes) == 1 and len(etypes) == 1 -def create_dgl_object( +def _create_csc_data(part_local_src_id, part_local_dst_id): + part_local_src_id, part_local_dst_id = th.tensor( + part_local_src_id, dtype=th.int64 + ), th.tensor(part_local_dst_id, dtype=th.int64) + indptr = th.zeros(len(part_local_dst_id) + 1, dtype=th.int64) + col_counts = th.bincount(part_local_src_id, minlength=part_local_dst_id) + indptr[1:] = th.cumsum(col_counts, 0) + indices = part_local_dst_id + return indptr, indices + + +def _create_edge_data(edgeid_offset, etype_ids, num_edges): + eid = th.arange( + edgeid_offset, + edgeid_offset + num_edges, + dtype=th.int64, + ) + etype = th.as_tensor(etype_ids, dtype=RESERVED_FIELD_DTYPE[dgl.ETYPE]) + inner_edge = th.ones(num_edges, dtype=RESERVED_FIELD_DTYPE["inner_edge"]) + return eid, etype, inner_edge + + +def _create_node_data(ntype, uniq_ids, reshuffle_nodes, inner_nodes): + node_type = th.as_tensor(ntype, dtype=RESERVED_FIELD_DTYPE[dgl.NTYPE]) + node_id = th.as_tensor(uniq_ids[reshuffle_nodes]) + inner_node = th.as_tensor( + inner_nodes[reshuffle_nodes], + dtype=RESERVED_FIELD_DTYPE["inner_node"], + ) + return node_type, node_id, inner_node + + +def _compute_node_ntype( + global_src_id, global_dst_id, global_homo_nid, idx, reshuffle_nodes, id_map +): + global_ids = np.concatenate([global_src_id, global_dst_id, global_homo_nid]) + part_global_ids = global_ids[idx] + part_global_ids = part_global_ids[reshuffle_nodes] + ntype, per_type_ids = id_map(part_global_ids) + return ntype, per_type_ids + + +def _graph_orig_ids( + return_orig_nids, + return_orig_eids, + ntypes_map, + etypes_map, + node_attr, + edge_attr, + per_type_ids, + type_per_edge, + global_edge_id, +): + orig_nids = None + orig_eids = None + if return_orig_nids: + orig_nids = {} + for ntype, ntype_id in ntypes_map.items(): + mask = 
def _partition_DGLGraph(
    part_local_src_id,
    part_local_dst_id,
    global_src_id,
    global_dst_id,
    global_homo_nid,
    idx,
    reshuffle_nodes,
    id_map,
    edgeid_offset,
    etype_ids,
    return_orig_nids,
    return_orig_eids,
    ntypes_map,
    etypes_map,
    global_edge_id,
    uniq_ids,
    inner_nodes,
):
    """Assemble one partition as a ``dgl.graph`` with its reserved fields.

    The graph is built from the partition-local edge list with
    ``len(uniq_ids)`` nodes. Edge fields come from ``_create_edge_data``,
    node fields from ``_create_node_data`` (using the node types returned by
    ``_compute_node_ntype``), and the optional original-ID lookups from
    ``_graph_orig_ids``.

    Returns
    -------
    tuple
        ``(part_graph, ntypes_map, etypes_map, orig_nids, orig_eids)``;
        ``ntypes_map`` and ``etypes_map`` are handed back unchanged.
    """
    edge_count = len(part_local_dst_id)
    part_graph = dgl.graph(
        data=(part_local_src_id, part_local_dst_id), num_nodes=len(uniq_ids)
    )

    # Reserved edge fields: edge IDs, edge type IDs, inner-edge flags.
    edge_ids, edge_types, inner_edge_flags = _create_edge_data(
        edgeid_offset, etype_ids, edge_count
    )
    part_graph.edata[dgl.EID] = edge_ids
    part_graph.edata[dgl.ETYPE] = edge_types
    part_graph.edata["inner_edge"] = inner_edge_flags

    # Node type and per-type ID of every node in the partition.
    ntype, per_type_ids = _compute_node_ntype(
        global_src_id,
        global_dst_id,
        global_homo_nid,
        idx,
        reshuffle_nodes,
        id_map,
    )

    # Reserved node fields: node type IDs, node IDs, inner-node flags.
    node_types, node_ids, inner_node_flags = _create_node_data(
        ntype, uniq_ids, reshuffle_nodes, inner_nodes
    )
    part_graph.ndata[dgl.NTYPE] = node_types
    part_graph.ndata[dgl.NID] = node_ids
    part_graph.ndata["inner_node"] = inner_node_flags

    # Original node/edge IDs, computed only when the caller asked for them.
    orig_nids, orig_eids = _graph_orig_ids(
        return_orig_nids,
        return_orig_eids,
        ntypes_map,
        etypes_map,
        part_graph.ndata,
        part_graph.edata,
        per_type_ids,
        part_graph.edata[dgl.ETYPE],
        global_edge_id,
    )
    return part_graph, ntypes_map, etypes_map, orig_nids, orig_eids
+ orig_nids, orig_eids = _graph_orig_ids( + return_orig_nids, + return_orig_eids, + ntypes_map, + etypes_map, + node_attr, + edge_attr, + per_type_ids, + type_per_edge, + global_edge_id, + ) + if not store_inner_edge: + edge_attr.pop("inner_edge") + + if not store_eids: + edge_attr.pop(dgl.EID) + + if not store_inner_node: + node_attr.pop("inner_node") + + edge_type_to_id = ( + None + if is_homo + else { + gb.etype_tuple_to_str(etype): etid + for etype, etid in etypes_map.items() + } + ) + + indptr, indices = _create_csc_data(part_local_src_id, part_local_dst_id) + part_graph = gb.fused_csc_sampling_graph( + csc_indptr=indptr, + indices=indices, + node_type_offset=None, + type_per_edge=type_per_edge, + node_attributes=node_attr, + edge_attributes=edge_attr, + node_type_to_id=ntypes_map, + edge_type_to_id=edge_type_to_id, + ) + return ( + part_graph, + ntypes_map, + etypes_map, + orig_nids, + orig_eids, + ) + + +def create_graph_object( schema, part_id, node_data, @@ -183,6 +429,7 @@ def create_dgl_object( return_orig_nids=False, return_orig_eids=False, use_graphbolt=False, + **kwargs, ): """ This function creates dgl objects for a given graph partition, as in function @@ -458,92 +705,35 @@ def create_dgl_object( nid_map[part_local_dst_id], ) - # create the graph here now. # create the graph here now. if use_graphbolt: - edge_attr = {} - num_edges = len(part_local_dst_id) - edge_attr[dgl.EID] = th.arange( + ( + part_graph, + ntypes_map, + etypes_map, + orig_nids, + orig_eids, + ) = _partition_graphbolt( + part_local_src_id, + part_local_dst_id, + global_src_id, + global_dst_id, + global_homo_nid, + idx, + reshuffle_nodes, + id_map, edgeid_offset, - edgeid_offset + num_edges, - dtype=th.int64, - ) - type_per_edge = th.as_tensor( - etype_ids, dtype=RESERVED_FIELD_DTYPE[dgl.ETYPE] - ) - edge_attr["inner_edge"] = th.ones( - num_edges, dtype=RESERVED_FIELD_DTYPE["inner_edge"] - ) - - # compute per_type_ids and ntype for all the nodes in the graph. 
- global_ids = np.concatenate( - [global_src_id, global_dst_id, global_homo_nid] - ) - part_global_ids = global_ids[idx] - part_global_ids = part_global_ids[reshuffle_nodes] - ntype, per_type_ids = id_map(part_global_ids) - - # continue with the graph creation - node_attr = {} - node_attr[dgl.NTYPE] = th.as_tensor( - ntype, dtype=RESERVED_FIELD_DTYPE[dgl.NTYPE] - ) - node_attr[dgl.NID] = th.as_tensor(uniq_ids[reshuffle_nodes]) - node_attr["inner_node"] = th.as_tensor( - inner_nodes[reshuffle_nodes], - dtype=RESERVED_FIELD_DTYPE["inner_node"], - ) - - is_homo = is_homogeneous(ntypes, etypes) - orig_nids = None - orig_eids = None - if return_orig_nids: - orig_nids = {} - for ntype, ntype_id in ntypes_map.items(): - mask = th.logical_and( - node_attr[dgl.NTYPE] == ntype_id, - node_attr["inner_node"], - ) - orig_nids[ntype] = th.as_tensor(per_type_ids[mask]) - if return_orig_eids: - orig_eids = {} - for etype, etype_id in etypes_map.items(): - mask = th.logical_and( - type_per_edge == etype_id, - edge_attr["inner_edge"], - ) - orig_eids[_etype_tuple_to_str(etype)] = th.as_tensor( - global_edge_id[mask] - ) - edge_type_to_id = ( - None - if is_homo - else { - gb.etype_tuple_to_str(etype): etid - for etype, etid in etypes_map.items() - } - ) - - part_local_src_id, part_local_dst_id = th.tensor( - part_local_src_id, dtype=th.int64 - ), th.tensor(part_local_dst_id, dtype=th.int64) - size = max(part_local_src_id.max(), part_local_dst_id.max()) + 1 - adj_matrix = dglsp.from_coo( - part_local_src_id, part_local_dst_id, shape=(size, size) - ) - print(adj_matrix) - indptr, indices, _ = adj_matrix.csc() - del adj_matrix - - part_graph = gb.fused_csc_sampling_graph( - indptr, - indices, - node_type_offset=None, - type_per_edge=type_per_edge, - node_attributes=node_attr, - edge_attributes=edge_attr, - node_type_to_id=ntypes_map, - edge_type_to_id=edge_type_to_id, + etype_ids, + ntypes, + etypes, + return_orig_nids, + return_orig_eids, + ntypes_map, + etypes_map, + global_edge_id, 
+ uniq_ids, + inner_nodes, + **kwargs, ) return ( part_graph, @@ -556,60 +746,32 @@ def create_dgl_object( ) else: - part_graph = dgl.graph( - data=(part_local_src_id, part_local_dst_id), num_nodes=len(uniq_ids) - ) - part_graph.edata[dgl.EID] = th.arange( + ( + part_graph, + ntypes_map, + etypes_map, + orig_nids, + orig_eids, + ) = _partition_DGLGraph( + part_local_src_id, + part_local_dst_id, + global_src_id, + global_dst_id, + global_homo_nid, + idx, + reshuffle_nodes, + id_map, edgeid_offset, - edgeid_offset + part_graph.num_edges(), - dtype=th.int64, - ) - part_graph.edata[dgl.ETYPE] = th.as_tensor( - etype_ids, dtype=RESERVED_FIELD_DTYPE[dgl.ETYPE] - ) - part_graph.edata["inner_edge"] = th.ones( - part_graph.num_edges(), dtype=RESERVED_FIELD_DTYPE["inner_edge"] - ) - - # compute per_type_ids and ntype for all the nodes in the graph. - global_ids = np.concatenate( - [global_src_id, global_dst_id, global_homo_nid] - ) - part_global_ids = global_ids[idx] - part_global_ids = part_global_ids[reshuffle_nodes] - ntype, per_type_ids = id_map(part_global_ids) - - # continue with the graph creation - part_graph.ndata[dgl.NTYPE] = th.as_tensor( - ntype, dtype=RESERVED_FIELD_DTYPE[dgl.NTYPE] - ) - part_graph.ndata[dgl.NID] = th.as_tensor(uniq_ids[reshuffle_nodes]) - part_graph.ndata["inner_node"] = th.as_tensor( - inner_nodes[reshuffle_nodes], - dtype=RESERVED_FIELD_DTYPE["inner_node"], + etype_ids, + return_orig_nids, + return_orig_eids, + ntypes_map, + etypes_map, + global_edge_id, + uniq_ids, + inner_nodes, ) - orig_nids = None - orig_eids = None - if return_orig_nids: - orig_nids = {} - for ntype, ntype_id in ntypes_map.items(): - mask = th.logical_and( - part_graph.ndata[dgl.NTYPE] == ntype_id, - part_graph.ndata["inner_node"], - ) - orig_nids[ntype] = th.as_tensor(per_type_ids[mask]) - if return_orig_eids: - orig_eids = {} - for etype, etype_id in etypes_map.items(): - mask = th.logical_and( - part_graph.edata[dgl.ETYPE] == etype_id, - 
part_graph.edata["inner_edge"], - ) - orig_eids[_etype_tuple_to_str(etype)] = th.as_tensor( - global_edge_id[mask] - ) - return ( part_graph, node_map_val, diff --git a/tools/distpartitioning/data_shuffle.py b/tools/distpartitioning/data_shuffle.py index 3b140da1b7a3..f8837abd398b 100644 --- a/tools/distpartitioning/data_shuffle.py +++ b/tools/distpartitioning/data_shuffle.py @@ -13,7 +13,7 @@ import torch import torch.distributed as dist import torch.multiprocessing as mp -from convert_partition import create_dgl_object, create_metadata_json +from convert_partition import create_graph_object, create_metadata_json from dataset_utils import get_dataset from dist_lookup import DistLookupService from globalids import ( @@ -285,21 +285,21 @@ def exchange_edge_data(rank, world_size, num_parts, edge_data, id_lookup): local_etype_ids.append(rcvd_edge_data[:, 3]) local_eids.append(rcvd_edge_data[:, 4]) - edge_data[ - constants.GLOBAL_SRC_ID + "/" + str(local_part_id) - ] = np.concatenate(local_src_ids) - edge_data[ - constants.GLOBAL_DST_ID + "/" + str(local_part_id) - ] = np.concatenate(local_dst_ids) - edge_data[ - constants.GLOBAL_TYPE_EID + "/" + str(local_part_id) - ] = np.concatenate(local_type_eids) - edge_data[ - constants.ETYPE_ID + "/" + str(local_part_id) - ] = np.concatenate(local_etype_ids) - edge_data[ - constants.GLOBAL_EID + "/" + str(local_part_id) - ] = np.concatenate(local_eids) + edge_data[constants.GLOBAL_SRC_ID + "/" + str(local_part_id)] = ( + np.concatenate(local_src_ids) + ) + edge_data[constants.GLOBAL_DST_ID + "/" + str(local_part_id)] = ( + np.concatenate(local_dst_ids) + ) + edge_data[constants.GLOBAL_TYPE_EID + "/" + str(local_part_id)] = ( + np.concatenate(local_type_eids) + ) + edge_data[constants.ETYPE_ID + "/" + str(local_part_id)] = ( + np.concatenate(local_etype_ids) + ) + edge_data[constants.GLOBAL_EID + "/" + str(local_part_id)] = ( + np.concatenate(local_eids) + ) # Check if the data was exchanged correctly local_edge_count = 0 @@ 
-1323,7 +1323,7 @@ def prepare_local_data(src_data, local_part_id): etypes_map, orig_nids, orig_eids, - ) = create_dgl_object( + ) = create_graph_object( schema_map, rank + local_part_id * world_size, local_node_data, From cde6a648aa5ae8253971eef8080fe779cc1ec199 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 5 Sep 2024 12:24:42 +0000 Subject: [PATCH 10/37] renew test_partition --- python/dgl/distributed/partition.py | 435 +++------ tests/distributed/test_partition.py | 1375 +++++++++++++++++---------- 2 files changed, 979 insertions(+), 831 deletions(-) diff --git a/python/dgl/distributed/partition.py b/python/dgl/distributed/partition.py index 0f5e317b1627..07601fd5d2ca 100644 --- a/python/dgl/distributed/partition.py +++ b/python/dgl/distributed/partition.py @@ -88,7 +88,7 @@ def _dump_part_config(part_config, part_metadata): json.dump(part_metadata, outfile, sort_keys=False, indent=4) -def _process_partitions(g_list, formats=None, sort_etypes=False): +def _save_graphs(filename, g_list, formats=None, sort_etypes=False): """Preprocess partitions before saving: 1. format data types. 2. sort csc/csr by tag. 
def _get_inner_node_mask(graph, ntype_id, gpb=None):
    """Return a boolean mask of the inner nodes of type ``ntype_id``.

    Node data is read from ``graph.node_attributes`` for a
    ``gb.FusedCSCSamplingGraph`` and from ``graph.ndata`` otherwise. Node
    types come from ``gpb.map_to_per_ntype`` when ``gpb`` is given, else from
    the ``NTYPE`` field. When neither is available, every inner node is
    selected regardless of ``ntype_id``.
    """
    if isinstance(graph, gb.FusedCSCSamplingGraph):
        ndata = graph.node_attributes
    else:
        ndata = graph.ndata
    assert "inner_node" in ndata, "'inner_node' is not in nodes' data"

    # No type information at all: the mask is just the inner-node flag.
    if gpb is None and NTYPE not in ndata:
        return ndata["inner_node"] == 1

    if gpb is not None:
        ntype = gpb.map_to_per_ntype(ndata[NID])[0]
    else:
        ntype = ndata[NTYPE]
    flag_dtype = F.dtype(ndata["inner_node"])
    # The product is 1 only where a node is both inner and of the wanted type.
    return ndata["inner_node"] * F.astype(ntype == ntype_id, flag_dtype) == 1
def _get_inner_edge_mask(
    graph,
    etype_id,
):
    """Return a boolean mask of the inner edges of type ``etype_id``.

    For a ``gb.FusedCSCSamplingGraph`` the edge data is
    ``graph.edge_attributes`` and the edge types are ``graph.type_per_edge``.
    Otherwise ``graph.edata`` is used and the types come from its ``ETYPE``
    field if present. Without edge types, every inner edge is selected.
    """
    is_graphbolt = isinstance(graph, gb.FusedCSCSamplingGraph)
    edata = graph.edge_attributes if is_graphbolt else graph.edata
    assert "inner_edge" in edata, "'inner_edge' is not in edges' data"

    if is_graphbolt:
        etype = graph.type_per_edge
    elif ETYPE in graph.edata:
        etype = graph.edata[ETYPE]
    else:
        etype = None

    if etype is None:
        return edata["inner_edge"] == 1
    flag_dtype = F.dtype(edata["inner_edge"])
    # The product is 1 only where an edge is both inner and of the wanted type.
    return edata["inner_edge"] * F.astype(etype == etype_id, flag_dtype) == 1
- assert "graph_name" in part_metadata, "cannot get the graph name" - - # If this is a range partitioning, node_map actually stores a list, whose elements - # indicate the boundary of range partitioning. Otherwise, node_map stores a filename - # that contains node map in a NumPy array. - node_map = part_metadata["node_map"] - edge_map = part_metadata["edge_map"] - if isinstance(node_map, dict): - for key in node_map: - is_range_part = isinstance(node_map[key], list) - break - elif isinstance(node_map, list): - is_range_part = True - node_map = {DEFAULT_NTYPE: node_map} - else: - is_range_part = False - if isinstance(edge_map, list): - edge_map = {DEFAULT_ETYPE: edge_map} - - ntypes = {DEFAULT_NTYPE: 0} - etypes = {DEFAULT_ETYPE: 0} - if "ntypes" in part_metadata: - ntypes = part_metadata["ntypes"] - if "etypes" in part_metadata: - etypes = part_metadata["etypes"] - - if isinstance(node_map, dict): - for key in node_map: - assert key in ntypes, "The node type {} is invalid".format(key) - if isinstance(edge_map, dict): - for key in edge_map: - assert key in etypes, "The edge type {} is invalid".format(key) - - if not is_range_part: - raise TypeError("Only RangePartitionBook is supported currently.") - - node_map = _get_part_ranges(node_map) - edge_map = _get_part_ranges(edge_map) - - # Format dtype of node/edge map if dtype is specified. - def _format_node_edge_map(part_metadata, map_type, data): - key = f"{map_type}_map_dtype" - if key not in part_metadata: - return data - dtype = part_metadata[key] - assert dtype in ["int32", "int64"], ( - f"The {map_type} map dtype should be either int32 or int64, " - f"but got {dtype}." - ) - for key in data: - data[key] = data[key].astype(dtype) - return data - - node_map = _format_node_edge_map(part_metadata, "node", node_map) - edge_map = _format_node_edge_map(part_metadata, "edge", edge_map) - - # Sort the node/edge maps by the node/edge type ID. 
- node_map = dict(sorted(node_map.items(), key=lambda x: ntypes[x[0]])) - edge_map = dict(sorted(edge_map.items(), key=lambda x: etypes[x[0]])) - - def _assert_is_sorted(id_map): - id_ranges = np.array(list(id_map.values())) - ids = [] - for i in range(num_parts): - ids.append(id_ranges[:, i, :]) - ids = np.array(ids).flatten() - assert np.all( - ids[:-1] <= ids[1:] - ), f"The node/edge map is not sorted: {ids}" - - _assert_is_sorted(node_map) - _assert_is_sorted(edge_map) - - return ( - RangePartitionBook( - part_id, num_parts, node_map, edge_map, ntypes, etypes - ), - part_metadata["graph_name"], - ntypes, - etypes, - ) - - def load_partition_book(part_config, part_id): """Load a graph partition book from the partition config file. @@ -1433,41 +1322,31 @@ def get_homogeneous(g, balance_ntypes): part_dir = os.path.join(out_path, "part" + str(part_id)) node_feat_file = os.path.join(part_dir, "node_feat.dgl") edge_feat_file = os.path.join(part_dir, "edge_feat.dgl") - - os.makedirs(part_dir, mode=0o775, exist_ok=True) - save_tensors(node_feat_file, node_feats) - save_tensors(edge_feat_file, edge_feats) - + part_graph_file = os.path.join(part_dir, "graph.dgl") part_metadata["part-{}".format(part_id)] = { "node_feats": os.path.relpath(node_feat_file, out_path), "edge_feats": os.path.relpath(edge_feat_file, out_path), + "part_graph": os.path.relpath(part_graph_file, out_path), } + os.makedirs(part_dir, mode=0o775, exist_ok=True) + save_tensors(node_feat_file, node_feats) + save_tensors(edge_feat_file, edge_feats) + sort_etypes = len(g.etypes) > 1 - if not use_graphbolt: - part_graph_file = os.path.join(part_dir, "graph.dgl") - part_metadata["part-{}".format(part_id)][ - "part_graph" - ] = os.path.relpath(part_graph_file, out_path) - _save_graphs( - part_graph_file, - [part], - formats=graph_formats, - sort_etypes=sort_etypes, - ) - else: - part = _process_partitions([part], graph_formats, sort_etypes)[0] + _save_graphs( + part_graph_file, + [part], + 
formats=graph_formats, + sort_etypes=sort_etypes, + ) + print( + "Save partitions: {:.3f} seconds, peak memory: {:.3f} GB".format( + time.time() - start, get_peak_mem() + ) + ) part_config = os.path.join(out_path, graph_name + ".json") - if use_graphbolt: - kwargs["graph_formats"] = graph_formats - _dgl_partition_to_graphbolt( - part_config, - parts=parts, - part_meta=part_metadata, - **kwargs, - ) - else: - _dump_part_config(part_config, part_metadata) + _dump_part_config(part_config, part_metadata) num_cuts = sim_g.num_edges() - tot_num_inner_edges if num_parts == 1: @@ -1478,11 +1357,12 @@ def get_homogeneous(g, balance_ntypes): ) ) - print( - "Save partitions: {:.3f} seconds, peak memory: {:.3f} GB".format( - time.time() - start, get_peak_mem() + if use_graphbolt: + kwargs["graph_formats"] = graph_formats + dgl_partition_to_graphbolt( + part_config, + **kwargs, ) - ) if return_mapping: return orig_nids, orig_eids @@ -1530,21 +1410,8 @@ def init_type_per_edge(graph, gpb): return etype_ids -def _load_parts(part_config, part_id, parts): - """load parts from variable or dist.""" - if parts is None: - graph, _, _, _, _, _, _ = load_partition( - part_config, part_id, load_feats=False - ) - else: - graph = parts[part_id] - return graph - - def gb_convert_single_dgl_partition( part_id, - parts, - part_meta, graph_formats, part_config, store_eids, @@ -1577,18 +1444,14 @@ def gb_convert_single_dgl_partition( "Running in debug mode which means all attributes of DGL partitions" " will be saved to the new format." 
) - if part_meta is None: - part_meta = _load_part_config(part_config) - num_parts = part_meta["num_parts"] - graph = _load_parts(part_config, part_id, parts) + part_meta = _load_part_config(part_config) + num_parts = part_meta["num_parts"] - gpb, _, ntypes, etypes = ( - load_partition_book(part_config, part_id) - if part_meta is None - else _load_partition_book_from_metadata(part_meta, part_id) + graph, _, _, gpb, _, _, _ = load_partition( + part_config, part_id, load_feats=False ) - + _, _, ntypes, etypes = load_partition_book(part_config, part_id) is_homo = is_homogeneous(ntypes, etypes) node_type_to_id = ( None if is_homo else {ntype: ntid for ntid, ntype in enumerate(ntypes)} @@ -1694,12 +1557,12 @@ def gb_convert_single_dgl_partition( node_type_to_id=node_type_to_id, edge_type_to_id=edge_type_to_id, ) - orig_feats_path = os.path.join( + orig_graph_path = os.path.join( os.path.dirname(part_config), - part_meta[f"part-{part_id}"]["node_feats"], + part_meta[f"part-{part_id}"]["part_graph"], ) csc_graph_path = os.path.join( - os.path.dirname(orig_feats_path), "fused_csc_sampling_graph.pt" + os.path.dirname(orig_graph_path), "fused_csc_sampling_graph.pt" ) torch.save(csc_graph, csc_graph_path) @@ -1707,106 +1570,6 @@ def gb_convert_single_dgl_partition( # Update graph path. -def _convert_partition_to_graphbolt( - part_meta, - graph_formats, - part_config, - store_eids, - store_inner_node, - store_inner_edge, - n_jobs, - num_parts, - parts=None, -): - # [Rui] DGL partitions are always saved as homogeneous graphs even though - # the original graph is heterogeneous. But heterogeneous information like - # node/edge types are saved as node/edge data alongside with partitions. - # What needs more attention is that due to the existence of HALO nodes in - # each partition, the local node IDs are not sorted according to the node - # types. So we fail to assign ``node_type_offset`` as required by GraphBolt. 
- # But this is not a problem since such information is not used in sampling. - # We can simply pass None to it. - - # Iterate over partitions. - convert_with_format = partial( - gb_convert_single_dgl_partition, - parts=parts, - part_meta=part_meta, - graph_formats=graph_formats, - part_config=part_config, - store_eids=store_eids, - store_inner_node=store_inner_node, - store_inner_edge=store_inner_edge, - ) - # Need to create entirely new interpreters, because we call C++ downstream - # See https://docs.python.org/3.12/library/multiprocessing.html#contexts-and-start-methods - # and https://pybind11.readthedocs.io/en/stable/advanced/misc.html#global-interpreter-lock-gil - rel_path_results = [] - if n_jobs > 1 and num_parts > 1: - mp_ctx = mp.get_context("spawn") - with concurrent.futures.ProcessPoolExecutor( # pylint: disable=unexpected-keyword-arg - max_workers=min(num_parts, n_jobs), - mp_context=mp_ctx, - ) as executor: - futures = [] - for part_id in range(num_parts): - futures.append(executor.submit(convert_with_format, part_id)) - - for part_id in range(num_parts): - rel_path_results.append(futures[part_id].result()) - else: - # If running single-threaded, avoid spawning new interpreter, which is slow - for part_id in range(num_parts): - rel_path_results.append(convert_with_format(part_id)) - - for part_id in range(num_parts): - # Update graph path. - part_meta[f"part-{part_id}"]["part_graph_graphbolt"] = rel_path_results[ - part_id - ] - - # Save dtype info into partition config. - # [TODO][Rui] Always use int64_t for node/edge IDs in GraphBolt. See more - # details in #7175. 
- part_meta["node_map_dtype"] = "int64" - part_meta["edge_map_dtype"] = "int64" - - _dump_part_config(part_config, part_meta) - print(f"Converted partitions to GraphBolt format into {part_config}") - - -def _dgl_partition_to_graphbolt( - part_config, - part_meta, - parts, - *, - store_eids=True, - store_inner_node=False, - store_inner_edge=False, - graph_formats=None, - n_jobs=1, -): - debug_mode = "DGL_DIST_DEBUG" in os.environ - if debug_mode: - dgl_warning( - "Running in debug mode which means all attributes of DGL partitions" - " will be saved to the new format." - ) - new_part_meta = copy.deepcopy(part_meta) - num_parts = part_meta["num_parts"] - _convert_partition_to_graphbolt( - new_part_meta, - graph_formats, - part_config, - store_eids, - store_inner_node, - store_inner_edge, - n_jobs, - num_parts, - parts=parts, - ) - - def dgl_partition_to_graphbolt( part_config, *, @@ -1855,13 +1618,57 @@ def dgl_partition_to_graphbolt( part_meta = _load_part_config(part_config) new_part_meta = copy.deepcopy(part_meta) num_parts = part_meta["num_parts"] - _convert_partition_to_graphbolt( - new_part_meta, - graph_formats, - part_config, - store_eids, - store_inner_node, - store_inner_edge, - n_jobs, - num_parts, + + # [Rui] DGL partitions are always saved as homogeneous graphs even though + # the original graph is heterogeneous. But heterogeneous information like + # node/edge types are saved as node/edge data alongside with partitions. + # What needs more attention is that due to the existence of HALO nodes in + # each partition, the local node IDs are not sorted according to the node + # types. So we fail to assign ``node_type_offset`` as required by GraphBolt. + # But this is not a problem since such information is not used in sampling. + # We can simply pass None to it. + + # Iterate over partitions. 
+ convert_with_format = partial( + gb_convert_single_dgl_partition, + graph_formats=graph_formats, + part_config=part_config, + store_eids=store_eids, + store_inner_node=store_inner_node, + store_inner_edge=store_inner_edge, ) + # Need to create entirely new interpreters, because we call C++ downstream + # See https://docs.python.org/3.12/library/multiprocessing.html#contexts-and-start-methods + # and https://pybind11.readthedocs.io/en/stable/advanced/misc.html#global-interpreter-lock-gil + rel_path_results = [] + if n_jobs > 1 and num_parts > 1: + mp_ctx = mp.get_context("spawn") + with concurrent.futures.ProcessPoolExecutor( # pylint: disable=unexpected-keyword-arg + max_workers=min(num_parts, n_jobs), + mp_context=mp_ctx, + ) as executor: + futures = [] + for part_id in range(num_parts): + futures.append(executor.submit(convert_with_format, part_id)) + + for part_id in range(num_parts): + rel_path_results.append(futures[part_id].result()) + else: + # If running single-threaded, avoid spawning new interpreter, which is slow + for part_id in range(num_parts): + rel_path_results.append(convert_with_format(part_id)) + + for part_id in range(num_parts): + # Update graph path. + new_part_meta[f"part-{part_id}"][ + "part_graph_graphbolt" + ] = rel_path_results[part_id] + + # Save dtype info into partition config. + # [TODO][Rui] Always use int64_t for node/edge IDs in GraphBolt. See more + # details in #7175. 
def _verify_partition_data_types(part_g):
    """
    check list:
        every reserved field present in the node or edge data of the
        partition has the dtype listed in RESERVED_FIELD_DTYPE.
    """
    if isinstance(part_g, gb.FusedCSCSamplingGraph):
        attr_stores = (part_g.node_attributes, part_g.edge_attributes)
    else:
        attr_stores = (part_g.ndata, part_g.edata)

    for field, expected_dtype in RESERVED_FIELD_DTYPE.items():
        # Node data is checked before edge data for each field.
        for store in attr_stores:
            if field in store:
                assert store[field].dtype == expected_dtype
+ """ + num_nodes = {ntype: 0 for ntype in g.ntypes} + num_edges = {etype: 0 for etype in g.canonical_etypes} + for part in parts: + edata = ( + part.edge_attributes + if isinstance(part, gb.FusedCSCSamplingGraph) + else part.edata + ) + if dgl.ETYPE in edata: + assert len(g.canonical_etypes) == len(F.unique(edata[dgl.ETYPE])) + if debug_mode or isinstance(part, dgl.DGLGraph): + for ntype in g.ntypes: + ntype_id = g.get_ntype_id(ntype) + inner_node_mask = _get_inner_node_mask(part, ntype_id) + num_inner_nodes = F.sum(F.astype(inner_node_mask, F.int64), 0) + num_nodes[ntype] += num_inner_nodes + if store_inner_edge or isinstance(part, dgl.DGLGraph): for etype in g.canonical_etypes: etype_id = g.get_etype_id(etype) - inner_edge_mask = _get_inner_edge_mask( - part, etype_id, use_graphbolt - ) + inner_edge_mask = _get_inner_edge_mask(part, etype_id) num_inner_edges = F.sum(F.astype(inner_edge_mask, F.int64), 0) num_edges[etype] += num_inner_edges - # Verify the number of edges are correct. + # Verify the number of nodes are correct. + if debug_mode or isinstance(part, dgl.DGLGraph): + for ntype in g.ntypes: + print( + "node {}: {}, {}".format( + ntype, g.num_nodes(ntype), num_nodes[ntype] + ) + ) + assert g.num_nodes(ntype) == num_nodes[ntype] + # Verify the number of edges are correct. 
+ if store_inner_edge or isinstance(part, dgl.DGLGraph): for etype in g.canonical_etypes: print( "edge {}: {}, {}".format( @@ -112,111 +164,196 @@ def verify_hetero_graph(g, parts, use_graphbolt=False): ) assert g.num_edges(etype) == num_edges[etype] - nids = {ntype: [] for ntype in g.ntypes} - eids = {etype: [] for etype in g.canonical_etypes} - for part in parts: - eid = th.arange(len(part.edge_attributes[dgl.EID])) - etype_arr = F.gather_row(part.type_per_edge, eid) - eid_type = F.gather_row(part.edge_attributes[dgl.EID], eid) - for etype in g.canonical_etypes: - etype_id = g.get_etype_id(etype) - eids[etype].append( - F.boolean_mask(eid_type, etype_arr == etype_id) - ) - # Make sure edge Ids fall into a range. - inner_edge_mask = _get_inner_edge_mask( - part, etype_id, use_graphbolt - ) - inner_eids = np.sort( - F.asnumpy( - F.boolean_mask( - part.edge_attributes[dgl.EID], inner_edge_mask - ) - ) - ) - assert np.all( - inner_eids == np.arange(inner_eids[0], inner_eids[-1] + 1) - ) - return - num_nodes = {ntype: 0 for ntype in g.ntypes} - num_edges = {etype: 0 for etype in g.canonical_etypes} - for part in parts: - assert len(g.canonical_etypes) == len(F.unique(part.edata[dgl.ETYPE])) - for ntype in g.ntypes: - ntype_id = g.get_ntype_id(ntype) - inner_node_mask = _get_inner_node_mask(part, ntype_id) - num_inner_nodes = F.sum(F.astype(inner_node_mask, F.int64), 0) - num_nodes[ntype] += num_inner_nodes - for etype in g.canonical_etypes: - etype_id = g.get_etype_id(etype) - inner_edge_mask = _get_inner_edge_mask(part, etype_id) - num_inner_edges = F.sum(F.astype(inner_edge_mask, F.int64), 0) - num_edges[etype] += num_inner_edges - # Verify the number of nodes are correct. +def _verify_edge_id_range_hetero( + g, + part, + eids, +): + """ + check list: + make sure inner_eids fall into a range. + make sure all edges are included. 
+ """ + edata = ( + part.edge_attributes + if isinstance(part, gb.FusedCSCSamplingGraph) + else part.edata + ) + etype = ( + part.type_per_edge + if isinstance(part, gb.FusedCSCSamplingGraph) + else edata[dgl.ETYPE] + ) + eid = th.arange(len(edata[dgl.EID])) + etype_arr = F.gather_row(etype, eid) + eid_arr = F.gather_row(edata[dgl.EID], eid) + for etype in g.canonical_etypes: + etype_id = g.get_etype_id(etype) + eids[etype].append(F.boolean_mask(eid_arr, etype_arr == etype_id)) + # Make sure edge Ids fall into a range. + inner_edge_mask = _get_inner_edge_mask(part, etype_id) + inner_eids = np.sort( + F.asnumpy(F.boolean_mask(edata[dgl.EID], inner_edge_mask)) + ) + assert np.all( + inner_eids == np.arange(inner_eids[0], inner_eids[-1] + 1) + ) + return eids + + +def _verify_node_id_range_hetero(g, part, nids): + """ + check list: + make sure inner nodes have Ids fall into a range. + """ for ntype in g.ntypes: - print( - "node {}: {}, {}".format( - ntype, g.num_nodes(ntype), num_nodes[ntype] - ) + ntype_id = g.get_ntype_id(ntype) + # Make sure inner nodes have Ids fall into a range. + inner_node_mask = _get_inner_node_mask(part, ntype_id) + inner_nids = F.boolean_mask( + part.node_attributes[dgl.NID], inner_node_mask ) - assert g.num_nodes(ntype) == num_nodes[ntype] - # Verify the number of edges are correct. - for etype in g.canonical_etypes: - print( - "edge {}: {}, {}".format( - etype, g.num_edges(etype), num_edges[etype] + assert np.all( + F.asnumpy( + inner_nids + == F.arange( + F.as_scalar(inner_nids[0]), + F.as_scalar(inner_nids[-1]) + 1, + ) ) ) - assert g.num_edges(etype) == num_edges[etype] + nids[ntype].append(inner_nids) + return nids + +def _verify_graph_attributes_hetero( + g, + parts, + store_inner_edge, + store_inner_node, +): + """ + check list: + make sure edge ids fall into a range. + make sure inner nodes have Ids fall into a range. + make sure all nodes is included. + make sure all edges is included. 
+ """ nids = {ntype: [] for ntype in g.ntypes} eids = {etype: [] for etype in g.canonical_etypes} - for part in parts: - _, _, eid = part.edges(form="all") - etype_arr = F.gather_row(part.edata[dgl.ETYPE], eid) - eid_type = F.gather_row(part.edata[dgl.EID], eid) - for etype in g.canonical_etypes: - etype_id = g.get_etype_id(etype) - eids[etype].append(F.boolean_mask(eid_type, etype_arr == etype_id)) - # Make sure edge Ids fall into a range. - inner_edge_mask = _get_inner_edge_mask(part, etype_id) - inner_eids = np.sort( - F.asnumpy(F.boolean_mask(part.edata[dgl.EID], inner_edge_mask)) - ) - assert np.all( - inner_eids == np.arange(inner_eids[0], inner_eids[-1] + 1) - ) + # check edge id. + if store_inner_edge or isinstance(parts[0], dgl.DGLGraph): + for part in parts: + # collect eids + eids = _verify_edge_id_range_hetero(g, part, eids) + for etype in eids: + eids_type = F.cat(eids[etype], 0) + uniq_ids = F.unique(eids_type) + # We should get all nodes. + assert len(uniq_ids) == g.num_edges(etype) + + # check node id. + if store_inner_node or isinstance(parts[0], dgl.DGLGraph): + for part in parts: + nids = _verify_node_id_range_hetero(g, part, nids) + for ntype in nids: + nids_type = F.cat(nids[ntype], 0) + uniq_ids = F.unique(nids_type) + # We should get all nodes. + assert len(uniq_ids) == g.num_nodes(ntype) - for ntype in g.ntypes: - ntype_id = g.get_ntype_id(ntype) - # Make sure inner nodes have Ids fall into a range. - inner_node_mask = _get_inner_node_mask(part, ntype_id) - inner_nids = F.boolean_mask(part.ndata[dgl.NID], inner_node_mask) - assert np.all( - F.asnumpy( - inner_nids - == F.arange( - F.as_scalar(inner_nids[0]), - F.as_scalar(inner_nids[-1]) + 1, - ) - ) - ) - nids[ntype].append(inner_nids) - for ntype in nids: - nids_type = F.cat(nids[ntype], 0) - uniq_ids = F.unique(nids_type) - # We should get all nodes. 
- assert len(uniq_ids) == g.num_nodes(ntype) - for etype in eids: - eids_type = F.cat(eids[etype], 0) - uniq_ids = F.unique(eids_type) - assert len(uniq_ids) == g.num_edges(etype) - # TODO(zhengda) this doesn't check 'part_id' +def _verify_hetero_graph( + g, + parts, + store_eids=False, + store_inner_edge=False, + store_inner_node=False, + debug_mode=False, +): + _verify_hetero_graph_node_edge_num( + g, + parts, + store_inner_edge=store_inner_edge, + debug_mode=debug_mode, + ) + if store_eids: + _verify_graph_attributes_hetero( + g, + parts, + store_inner_edge=store_inner_edge, + store_inner_node=store_inner_node, + ) + + +def _verify_node_feats(g, part, gpb, orig_nids, node_feats, is_homo=False): + for ntype in g.ntypes: + ndata = ( + part.node_attributes + if isinstance(part, gb.FusedCSCSamplingGraph) + else part.ndata + ) + ntype_id = g.get_ntype_id(ntype) + inner_node_mask = _get_inner_node_mask( + part, + ntype_id, + (gpb if isinstance(part, gb.FusedCSCSamplingGraph) else None), + ) + inner_nids = F.boolean_mask(ndata[dgl.NID], inner_node_mask) + ntype_ids, inner_type_nids = gpb.map_to_per_ntype(inner_nids) + partid = gpb.nid2partid(inner_type_nids, ntype) + if is_homo: + assert np.all(F.asnumpy(ntype_ids) == ntype_id) + assert np.all(F.asnumpy(partid) == gpb.partid) + if is_homo: + orig_id = orig_nids[inner_type_nids] + else: + orig_id = orig_nids[ntype][inner_type_nids] + local_nids = gpb.nid2localnid(inner_type_nids, gpb.partid, ntype) + + for name in g.nodes[ntype].data: + if name in [dgl.NID, "inner_node"]: + continue + true_feats = F.gather_row(g.nodes[ntype].data[name], orig_id) + ndata = F.gather_row(node_feats[ntype + "/" + name], local_nids) + assert np.all(F.asnumpy(ndata == true_feats)) -def verify_graph_feats( + +def _verify_edge_feats(g, part, gpb, orig_eids, edge_feats, is_homo=False): + for etype in g.canonical_etypes: + edata = ( + part.edge_attributes + if isinstance(part, gb.FusedCSCSamplingGraph) + else part.edata + ) + etype_id = 
g.get_etype_id(etype) + inner_edge_mask = _get_inner_edge_mask(part, etype_id) + inner_eids = F.boolean_mask(edata[dgl.EID], inner_edge_mask) + etype_ids, inner_type_eids = gpb.map_to_per_etype(inner_eids) + partid = gpb.eid2partid(inner_type_eids, etype) + assert np.all(F.asnumpy(etype_ids) == etype_id) + assert np.all(F.asnumpy(partid) == gpb.partid) + + if is_homo: + orig_id = orig_eids[inner_type_eids] + else: + orig_id = orig_eids[etype][inner_type_eids] + local_eids = gpb.eid2localeid(inner_type_eids, gpb.partid, etype) + + for name in g.edges[etype].data: + if name in [dgl.EID, "inner_edge"]: + continue + true_feats = F.gather_row(g.edges[etype].data[name], orig_id) + edata = F.gather_row( + edge_feats[_etype_tuple_to_str(etype) + "/" + name], + local_eids, + ) + assert np.all(F.asnumpy(edata == true_feats)) + + +def verify_graph_feats_hetero_dgl( g, gpb, part, @@ -224,100 +361,70 @@ def verify_graph_feats( edge_feats, orig_nids, orig_eids, - use_graphbolt=False, ): - if use_graphbolt: - for ntype in g.ntypes: - ntype_id = g.get_ntype_id(ntype) - inner_node_mask = _get_inner_node_mask( - part, ntype_id, use_graphbolt - ) - inner_nids = F.boolean_mask( - part.node_attributes[dgl.NID], inner_node_mask - ) - ntype_ids, inner_type_nids = gpb.map_to_per_ntype(inner_nids) - partid = gpb.nid2partid(inner_type_nids, ntype) - assert np.all(F.asnumpy(ntype_ids) == ntype_id) - assert np.all(F.asnumpy(partid) == gpb.partid) + """ + check list: + make sure the feats of nodes and edges are correct + """ + _verify_node_feats(g, part, gpb, orig_nids, node_feats) - orig_id = orig_nids[ntype][inner_type_nids] - local_nids = gpb.nid2localnid(inner_type_nids, gpb.partid, ntype) + _verify_edge_feats(g, part, gpb, orig_eids, edge_feats) - for name in g.nodes[ntype].data: - if name in [dgl.NID, "inner_node"]: - continue - true_feats = F.gather_row(g.nodes[ntype].data[name], orig_id) - ndata = F.gather_row(node_feats[ntype + "/" + name], local_nids) - assert np.all(F.asnumpy(ndata 
== true_feats)) - for etype in g.canonical_etypes: - etype_id = g.get_etype_id(etype) - inner_edge_mask = _get_inner_edge_mask( - part, etype_id, use_graphbolt - ) - inner_eids = F.boolean_mask( - part.edge_attributes[dgl.EID], inner_edge_mask +def verify_graph_feats_gb( + g, + gpbs, + parts, + tot_node_feats, + tot_edge_feats, + orig_nids, + orig_eids, + shuffled_labels, + shuffled_edata, + test_ntype, + test_etype, + store_inner_node=False, + store_inner_edge=False, + store_eids=False, + is_homo=False, +): + """ + check list: + make sure the feats of nodes and edges are correct + """ + for part_id in range(len(parts)): + part = parts[part_id] + gpb = gpbs[part_id] + node_feats = tot_node_feats[part_id] + edge_feats = tot_edge_feats[part_id] + if store_inner_node: + _verify_node_feats( + g, + part, + gpb, + orig_nids, + node_feats, + is_homo=is_homo, ) - etype_ids, inner_type_eids = gpb.map_to_per_etype(inner_eids) - partid = gpb.eid2partid(inner_type_eids, etype) - assert np.all(F.asnumpy(etype_ids) == etype_id) - assert np.all(F.asnumpy(partid) == gpb.partid) - - orig_id = orig_eids[etype][inner_type_eids] - local_eids = gpb.eid2localeid(inner_type_eids, gpb.partid, etype) - - for name in g.edges[etype].data: - if name in [dgl.EID, "inner_edge"]: - continue - true_feats = F.gather_row(g.edges[etype].data[name], orig_id) - edata = F.gather_row( - edge_feats[_etype_tuple_to_str(etype) + "/" + name], - local_eids, - ) - assert np.all(F.asnumpy(edata == true_feats)) - else: - for ntype in g.ntypes: - ntype_id = g.get_ntype_id(ntype) - inner_node_mask = _get_inner_node_mask( - part, ntype_id, use_graphbolt + if store_inner_edge and store_eids: + _verify_edge_feats( + g, + part, + gpb, + orig_eids, + edge_feats, + is_homo=is_homo, ) - inner_nids = F.boolean_mask(part.ndata[dgl.NID], inner_node_mask) - ntype_ids, inner_type_nids = gpb.map_to_per_ntype(inner_nids) - partid = gpb.nid2partid(inner_type_nids, ntype) - assert np.all(F.asnumpy(ntype_ids) == ntype_id) - 
assert np.all(F.asnumpy(partid) == gpb.partid) - - orig_id = orig_nids[ntype][inner_type_nids] - local_nids = gpb.nid2localnid(inner_type_nids, gpb.partid, ntype) - for name in g.nodes[ntype].data: - if name in [dgl.NID, "inner_node"]: - continue - true_feats = F.gather_row(g.nodes[ntype].data[name], orig_id) - ndata = F.gather_row(node_feats[ntype + "/" + name], local_nids) - assert np.all(F.asnumpy(ndata == true_feats)) - - for etype in g.canonical_etypes: - etype_id = g.get_etype_id(etype) - inner_edge_mask = _get_inner_edge_mask(part, etype_id) - inner_eids = F.boolean_mask(part.edata[dgl.EID], inner_edge_mask) - etype_ids, inner_type_eids = gpb.map_to_per_etype(inner_eids) - partid = gpb.eid2partid(inner_type_eids, etype) - assert np.all(F.asnumpy(etype_ids) == etype_id) - assert np.all(F.asnumpy(partid) == gpb.partid) - - orig_id = orig_eids[etype][inner_type_eids] - local_eids = gpb.eid2localeid(inner_type_eids, gpb.partid, etype) - - for name in g.edges[etype].data: - if name in [dgl.EID, "inner_edge"]: - continue - true_feats = F.gather_row(g.edges[etype].data[name], orig_id) - edata = F.gather_row( - edge_feats[_etype_tuple_to_str(etype) + "/" + name], - local_eids, - ) - assert np.all(F.asnumpy(edata == true_feats)) + _verify_shuffled_labels_gb( + g, + shuffled_labels, + shuffled_edata, + orig_nids, + orig_eids, + test_ntype, + test_etype, + ) def check_hetero_partition( @@ -429,7 +536,7 @@ def check_hetero_partition( assert len(orig_eids1) == len(orig_eids2) assert np.all(F.asnumpy(orig_eids1) == F.asnumpy(orig_eids2)) parts.append(part_g) - verify_graph_feats( + verify_graph_feats_hetero_dgl( hg, gpb, part_g, node_feats, edge_feats, orig_nids, orig_eids ) @@ -437,8 +544,7 @@ def check_hetero_partition( shuffled_elabels.append( edge_feats[_etype_tuple_to_str(test_etype) + "/labels"] ) - verify_hetero_graph(hg, parts) - + _verify_hetero_graph(hg, parts) shuffled_labels = F.asnumpy(F.cat(shuffled_labels, 0)) shuffled_elabels = 
F.asnumpy(F.cat(shuffled_elabels, 0)) orig_labels = np.zeros(shuffled_labels.shape, dtype=shuffled_labels.dtype) @@ -905,8 +1011,6 @@ def test_dgl_partition_to_graphbolt_homo( orig_g.ndata["inner_node"], new_g.node_attributes["inner_node"], ) - else: - assert "inner_node" not in new_g.node_attributes if store_eids or debug_mode: assert orig_g.edata[dgl.EID].dtype == th.int64 assert new_g.edge_attributes[dgl.EID].dtype == th.int64 @@ -914,8 +1018,6 @@ def test_dgl_partition_to_graphbolt_homo( orig_g.edata[dgl.EID][orig_eids], new_g.edge_attributes[dgl.EID], ) - else: - assert dgl.EID not in new_g.edge_attributes if store_inner_edge or debug_mode: assert orig_g.edata["inner_edge"].dtype == th.uint8 assert new_g.edge_attributes["inner_edge"].dtype == th.uint8 @@ -923,8 +1025,6 @@ def test_dgl_partition_to_graphbolt_homo( orig_g.edata["inner_edge"][orig_eids], new_g.edge_attributes["inner_edge"], ) - else: - assert "inner_edge" not in new_g.edge_attributes assert new_g.type_per_edge is None assert new_g.node_type_to_id is None assert new_g.edge_type_to_id is None @@ -1031,16 +1131,12 @@ def test_dgl_partition_to_graphbolt_hetero( orig_g.ndata["inner_node"], new_g.node_attributes["inner_node"], ) - else: - assert "inner_node" not in new_g.node_attributes if debug_mode: assert orig_g.ndata[dgl.NTYPE].dtype == th.int32 assert new_g.node_attributes[dgl.NTYPE].dtype == th.int8 assert th.equal( orig_g.ndata[dgl.NTYPE], new_g.node_attributes[dgl.NTYPE] ) - else: - assert dgl.NTYPE not in new_g.node_attributes if store_eids or debug_mode: assert orig_g.edata[dgl.EID].dtype == th.int64 assert new_g.edge_attributes[dgl.EID].dtype == th.int64 @@ -1048,8 +1144,6 @@ def test_dgl_partition_to_graphbolt_hetero( orig_g.edata[dgl.EID][orig_eids], new_g.edge_attributes[dgl.EID], ) - else: - assert dgl.EID not in new_g.edge_attributes if store_inner_edge or debug_mode: assert orig_g.edata["inner_edge"].dtype == th.uint8 assert new_g.edge_attributes["inner_edge"].dtype == th.uint8 @@ 
-1057,8 +1151,6 @@ def test_dgl_partition_to_graphbolt_hetero( orig_g.edata["inner_edge"], new_g.edge_attributes["inner_edge"], ) - else: - assert "inner_edge" not in new_g.edge_attributes if debug_mode: assert orig_g.edata[dgl.ETYPE].dtype == th.int32 assert new_g.edge_attributes[dgl.ETYPE].dtype == th.int8 @@ -1066,8 +1158,6 @@ def test_dgl_partition_to_graphbolt_hetero( orig_g.edata[dgl.ETYPE][orig_eids], new_g.edge_attributes[dgl.ETYPE], ) - else: - assert dgl.ETYPE not in new_g.edge_attributes assert th.equal( orig_g.edata[dgl.ETYPE][orig_eids], new_g.type_per_edge ) @@ -1189,14 +1279,230 @@ def test_not_sorted_node_edge_map(): assert gpb.local_etype_offset == [0, 500, 1100, 1800, 2600] +def _get_part_IDs(part_g): + # These are partition-local IDs. + num_columns = part_g.csc_indptr.diff() + part_src_ids = part_g.indices + part_dst_ids = th.arange(part_g.total_num_nodes).repeat_interleave( + num_columns + ) + # These are reshuffled global homogeneous IDs. + part_src_ids = F.gather_row(part_g.node_attributes[dgl.NID], part_src_ids) + part_dst_ids = F.gather_row(part_g.node_attributes[dgl.NID], part_dst_ids) + return part_src_ids, part_dst_ids + + +def _verify_orig_edge_IDs_gb( + g, + orig_nids, + orig_eids, + part_eids, + part_src_ids, + part_dst_ids, + src_ntype=None, + dst_ntype=None, + etype=None, +): + """ + check list: + make sure orig edge id are correct after + """ + if src_ntype is not None and dst_ntype is not None: + orig_src_nid = orig_nids[src_ntype] + orig_dst_nid = orig_nids[dst_ntype] + else: + orig_src_nid = orig_nids + orig_dst_nid = orig_nids + orig_src_ids = F.gather_row(orig_src_nid, part_src_ids) + orig_dst_ids = F.gather_row(orig_dst_nid, part_dst_ids) + if etype is not None: + orig_eids = orig_eids[etype] + orig_eids1 = F.gather_row(orig_eids, part_eids) + orig_eids2 = g.edge_ids(orig_src_ids, orig_dst_ids, etype=etype) + assert len(orig_eids1) == len(orig_eids2) + assert np.all(F.asnumpy(orig_eids1) == F.asnumpy(orig_eids2)) + + +def 
_verify_metadata_gb(gpb, g, num_parts, part_id, part_sizes): + """ + check list: + make sure the number of nodes and edges is correct. + make sure the number of parts is correct. + make sure the number of nodes and edges in each parts os corrcet. + """ + assert gpb._num_nodes() == g.num_nodes() + assert gpb._num_edges() == g.num_edges() + + assert gpb.num_partitions() == num_parts + gpb_meta = gpb.metadata() + assert len(gpb_meta) == num_parts + assert len(gpb.partid2nids(part_id)) == gpb_meta[part_id]["num_nodes"] + assert len(gpb.partid2eids(part_id)) == gpb_meta[part_id]["num_edges"] + part_sizes.append( + (gpb_meta[part_id]["num_nodes"], gpb_meta[part_id]["num_edges"]) + ) + + +def _verify_local_id_gb(part_g, part_id, gpb): + """ + check list: + make sure the type of local id is correct. + make sure local id have a right order. + """ + nid = F.boolean_mask( + part_g.node_attributes[dgl.NID], + part_g.node_attributes["inner_node"], + ) + local_nid = gpb.nid2localnid(nid, part_id) + assert F.dtype(local_nid) in (F.int64, F.int32) + assert np.all(F.asnumpy(local_nid) == np.arange(0, len(local_nid))) + eid = F.boolean_mask( + part_g.edge_attributes[dgl.EID], + part_g.edge_attributes["inner_edge"], + ) + local_eid = gpb.eid2localeid(eid, part_id) + assert F.dtype(local_eid) in (F.int64, F.int32) + assert np.all(np.sort(F.asnumpy(local_eid)) == np.arange(0, len(local_eid))) + return local_nid, local_eid + + +def _verify_map_gb( + part_g, + part_id, + gpb, +): + """ + check list: + make sure the map node and its data type is correct. + """ + # Check the node map. 
+ local_nodes = F.boolean_mask( + part_g.node_attributes[dgl.NID], + part_g.node_attributes["inner_node"], + ) + inner_node_index = F.nonzero_1d(part_g.node_attributes["inner_node"]) + mapping_nodes = gpb.partid2nids(part_id) + assert F.dtype(mapping_nodes) in (F.int32, F.int64) + assert np.all( + np.sort(F.asnumpy(local_nodes)) == np.sort(F.asnumpy(mapping_nodes)) + ) + assert np.all( + F.asnumpy(inner_node_index) == np.arange(len(inner_node_index)) + ) + + # Check the edge map. + + local_edges = F.boolean_mask( + part_g.edge_attributes[dgl.EID], + part_g.edge_attributes["inner_edge"], + ) + inner_edge_index = F.nonzero_1d(part_g.edge_attributes["inner_edge"]) + mapping_edges = gpb.partid2eids(part_id) + assert F.dtype(mapping_edges) in (F.int32, F.int64) + assert np.all( + np.sort(F.asnumpy(local_edges)) == np.sort(F.asnumpy(mapping_edges)) + ) + assert np.all( + F.asnumpy(inner_edge_index) == np.arange(len(inner_edge_index)) + ) + return local_nodes, local_edges + + +def _verify_local_and_map_id_gb( + part_g, + part_id, + gpb, + store_inner_node, + store_inner_edge, + store_eids, +): + """ + check list: + make sure local id are correct. + make sure mapping id are correct. + """ + if store_inner_node and store_inner_edge and store_eids: + _verify_local_id_gb(part_g, part_id, gpb) + _verify_map_gb(part_g, part_id, gpb) + + +def _verify_orig_IDs_gb( + part_g, + gpb, + g, + is_homo=False, + part_src_ids=None, + part_dst_ids=None, + src_ntype_ids=None, + dst_ntype_ids=None, + orig_nids=None, + orig_eids=None, +): + """ + check list: + make sure orig edge id are correct. + make sure hetero ntype id are correct. 
+ """ + part_eids = part_g.edge_attributes[dgl.EID] + if is_homo: + _verify_orig_edge_IDs_gb( + g, orig_nids, orig_eids, part_eids, part_src_ids, part_dst_ids + ) + local_orig_nids = orig_nids[part_g.node_attributes[dgl.NID]] + local_orig_eids = orig_eids[part_g.edge_attributes[dgl.EID]] + part_g.node_attributes["feats"] = F.gather_row( + g.ndata["feats"], local_orig_nids + ) + part_g.edge_attributes["feats"] = F.gather_row( + g.edata["feats"], local_orig_eids + ) + else: + etype_ids, part_eids = gpb.map_to_per_etype(part_eids) + # `IdMap` is in int64 by default. + assert etype_ids.dtype == F.int64 + + # These are original per-type IDs. + for etype_id, etype in enumerate(g.canonical_etypes): + part_src_ids1 = F.boolean_mask(part_src_ids, etype_ids == etype_id) + src_ntype_ids1 = F.boolean_mask( + src_ntype_ids, etype_ids == etype_id + ) + part_dst_ids1 = F.boolean_mask(part_dst_ids, etype_ids == etype_id) + dst_ntype_ids1 = F.boolean_mask( + dst_ntype_ids, etype_ids == etype_id + ) + part_eids1 = F.boolean_mask(part_eids, etype_ids == etype_id) + assert np.all(F.asnumpy(src_ntype_ids1 == src_ntype_ids1[0])) + assert np.all(F.asnumpy(dst_ntype_ids1 == dst_ntype_ids1[0])) + src_ntype = g.ntypes[F.as_scalar(src_ntype_ids1[0])] + dst_ntype = g.ntypes[F.as_scalar(dst_ntype_ids1[0])] + + _verify_orig_edge_IDs_gb( + g, + orig_nids, + orig_eids, + part_eids1, + part_src_ids1, + part_dst_ids1, + src_ntype, + dst_ntype, + etype, + ) + + @pytest.mark.parametrize("part_method", ["metis", "random"]) @pytest.mark.parametrize("num_parts", [1, 4]) +@pytest.mark.parametrize("store_eids", [True, False]) +@pytest.mark.parametrize("store_inner_node", [True, False]) +@pytest.mark.parametrize("store_inner_edge", [True, False]) @pytest.mark.parametrize("debug_mode", [True, False]) def test_partition_graph_graphbolt_homo( part_method, num_parts, + store_eids, + store_inner_node, + store_inner_edge, debug_mode, - num_trainers_per_machine=1, ): reset_envs() if debug_mode: @@ -1211,8 +1517,6 
@@ def test_partition_graph_graphbolt_homo( g.edata["feats"] = F.tensor( np.random.randn(g.num_edges(), 10), F.float32 ) - g.update_all(fn.copy_u("feats", "msg"), fn.sum("msg", "h")) - g.update_all(fn.copy_e("feats", "msg"), fn.sum("msg", "eh")) orig_nids, orig_eids = partition_graph( g, @@ -1221,194 +1525,324 @@ def test_partition_graph_graphbolt_homo( test_dir, part_method=part_method, use_graphbolt=True, - store_eids=True, - store_inner_node=True, - store_inner_edge=True, + store_eids=store_eids, + store_inner_node=store_inner_node, + store_inner_edge=store_inner_edge, return_mapping=True, ) - part_sizes = [] - shuffled_labels = [] - shuffled_edata = [] - part_config = os.path.join(test_dir, f"{graph_name}.json") - for i in range(num_parts): - part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition( - part_config, i, load_feats=True, use_graphbolt=True - ) - if num_trainers_per_machine > 1: - for ntype in g.ntypes: - name = ntype + "/trainer_id" - assert name in node_feats - part_ids = F.floor_div( - node_feats[name], num_trainers_per_machine - ) - assert np.all(F.asnumpy(part_ids) == i) - for etype in g.canonical_etypes: - name = _etype_tuple_to_str(etype) + "/trainer_id" - assert name in edge_feats - part_ids = F.floor_div( - edge_feats[name], num_trainers_per_machine - ) - assert np.all(F.asnumpy(part_ids) == i) - - # Check the metadata - assert gpb._num_nodes() == g.num_nodes() - assert gpb._num_edges() == g.num_edges() - - assert gpb.num_partitions() == num_parts - gpb_meta = gpb.metadata() - assert len(gpb_meta) == num_parts - assert len(gpb.partid2nids(i)) == gpb_meta[i]["num_nodes"] - assert len(gpb.partid2eids(i)) == gpb_meta[i]["num_edges"] - part_sizes.append( - (gpb_meta[i]["num_nodes"], gpb_meta[i]["num_edges"]) - ) + if debug_mode: + store_eids = store_inner_node = store_inner_edge = True - nid = F.boolean_mask( - part_g.node_attributes[dgl.NID], - part_g.node_attributes["inner_node"], - ) - local_nid = gpb.nid2localnid(nid, i) - assert 
F.dtype(local_nid) in (F.int64, F.int32) - assert np.all(F.asnumpy(local_nid) == np.arange(0, len(local_nid))) - eid = F.boolean_mask( - part_g.edge_attributes[dgl.EID], - part_g.edge_attributes["inner_edge"], - ) - local_eid = gpb.eid2localeid(eid, i) - assert F.dtype(local_eid) in (F.int64, F.int32) - assert np.all( - np.sort(F.asnumpy(local_eid)) == np.arange(0, len(local_eid)) - ) + _verify_graphbolt_part( + g, + test_dir, + orig_nids, + orig_eids, + graph_name, + num_parts, + store_inner_node, + store_inner_edge, + store_eids, + is_homo=True, + ) - # Check the node map. - local_nodes = F.boolean_mask( - part_g.node_attributes[dgl.NID], - part_g.node_attributes["inner_node"], - ) - llocal_nodes = F.nonzero_1d(part_g.node_attributes["inner_node"]) - local_nodes1 = gpb.partid2nids(i) - assert F.dtype(local_nodes1) in (F.int32, F.int64) - assert np.all( - np.sort(F.asnumpy(local_nodes)) - == np.sort(F.asnumpy(local_nodes1)) - ) - assert np.all( - F.asnumpy(llocal_nodes) == np.arange(len(llocal_nodes)) - ) - # Check the edge map. - local_edges = F.boolean_mask( - part_g.edge_attributes[dgl.EID], - part_g.edge_attributes["inner_edge"], - ) - llocal_edges = F.nonzero_1d(part_g.edge_attributes["inner_edge"]) - local_edges1 = gpb.partid2eids(i) - assert F.dtype(local_edges1) in (F.int32, F.int64) - assert np.all( - np.sort(F.asnumpy(local_edges)) - == np.sort(F.asnumpy(local_edges1)) - ) - assert np.all( - F.asnumpy(llocal_edges) == np.arange(len(llocal_edges)) - ) +def _verify_constructed_id_gb(part_sizes, gpb): + """ + verify the part id of each node by constructed nids. 
+ check list: + make sure each node' part id and its type are corect + """ + node_map = [] + edge_map = [] + for part_i, (num_nodes, num_edges) in enumerate(part_sizes): + node_map.append(np.ones(num_nodes) * part_i) + edge_map.append(np.ones(num_edges) * part_i) + node_map = np.concatenate(node_map) + edge_map = np.concatenate(edge_map) + nid2pid = gpb.nid2partid(F.arange(0, len(node_map))) + assert F.dtype(nid2pid) in (F.int32, F.int64) + assert np.all(F.asnumpy(nid2pid) == node_map) + eid2pid = gpb.eid2partid(F.arange(0, len(edge_map))) + assert F.dtype(eid2pid) in (F.int32, F.int64) + assert np.all(F.asnumpy(eid2pid) == edge_map) - # Verify the mapping between the reshuffled IDs and the original IDs. - indices, indptr = part_g.indices, part_g.csc_indptr - adj_matrix = dglsp.from_csc(indptr, indices) - part_src_ids, part_dst_ids = adj_matrix.coo() - part_src_ids = F.gather_row( - part_g.node_attributes[dgl.NID], part_src_ids - ) - part_dst_ids = F.gather_row( - part_g.node_attributes[dgl.NID], part_dst_ids - ) - part_eids = part_g.edge_attributes[dgl.EID] - orig_src_ids = F.gather_row(orig_nids, part_src_ids) - orig_dst_ids = F.gather_row(orig_nids, part_dst_ids) - orig_eids1 = F.gather_row(orig_eids, part_eids) - orig_eids2 = g.edge_ids(orig_src_ids, orig_dst_ids) - assert F.shape(orig_eids1)[0] == F.shape(orig_eids2)[0] - assert np.all(F.asnumpy(orig_eids1) == F.asnumpy(orig_eids2)) - local_orig_nids = orig_nids[part_g.node_attributes[dgl.NID]] - local_orig_eids = orig_eids[part_g.edge_attributes[dgl.EID]] - part_g.node_attributes["feats"] = F.gather_row( - g.ndata["feats"], local_orig_nids - ) - part_g.edge_attributes["feats"] = F.gather_row( - g.edata["feats"], local_orig_eids - ) - local_nodes = orig_nids[local_nodes] - local_edges = orig_eids[local_edges] - - # part_g.update_all(fn.copy_u("feats", "msg"), fn.sum("msg", "h")) - # part_g.update_all(fn.copy_e("feats", "msg"), fn.sum("msg", "eh")) - # part_g.node_attributes["h"] = 
adj_matrix@part_g.node_attributes["h"] - - # assert F.allclose( - # F.gather_row(g.ndata["h"], local_nodes), - # F.gather_row(part_g.node_attributes["h"], llocal_nodes), - # ) - # assert F.allclose( - # F.gather_row(g.ndata["eh"], local_nodes), - # F.gather_row(part_g.node_attributes["eh"], llocal_nodes), - # ) - - for name in ["labels", "feats"]: - assert "_N/" + name in node_feats - assert node_feats["_N/" + name].shape[0] == len(local_nodes) - true_feats = F.gather_row(g.ndata[name], local_nodes) - ndata = F.gather_row(node_feats["_N/" + name], local_nid) - assert np.all(F.asnumpy(true_feats) == F.asnumpy(ndata)) - for name in ["feats"]: - efeat_name = _etype_tuple_to_str(DEFAULT_ETYPE) + "/" + name - assert efeat_name in edge_feats - assert edge_feats[efeat_name].shape[0] == len(local_edges) - true_feats = F.gather_row(g.edata[name], local_edges) - edata = F.gather_row(edge_feats[efeat_name], local_eid) - assert np.all(F.asnumpy(true_feats) == F.asnumpy(edata)) - - # This only works if node/edge IDs are shuffled. - shuffled_labels.append(node_feats["_N/labels"]) - shuffled_edata.append(edge_feats["_N:_E:_N/feats"]) - - # Verify that we can reconstruct node/edge data for original IDs. - shuffled_labels = F.asnumpy(F.cat(shuffled_labels, 0)) - shuffled_edata = F.asnumpy(F.cat(shuffled_edata, 0)) - orig_labels = np.zeros( - shuffled_labels.shape, dtype=shuffled_labels.dtype +def _verify_shuffled_labels_gb( + g, + shuffled_labels, + shuffled_edata, + orig_nids, + orig_eids, + test_ntype=None, + test_etype=None, +): + """ + check list: + make sure node data are correct. + make sure edge data are correct. 
+ """ + shuffled_labels = F.asnumpy(F.cat(shuffled_labels, 0)) + shuffled_edata = F.asnumpy(F.cat(shuffled_edata, 0)) + orig_labels = np.zeros(shuffled_labels.shape, dtype=shuffled_labels.dtype) + orig_edata = np.zeros(shuffled_edata.shape, dtype=shuffled_edata.dtype) + + orig_nid = orig_nids if test_ntype is None else orig_nids[test_ntype] + orig_eid = orig_eids if test_etype is None else orig_eids[test_etype] + nlabel = ( + g.ndata["labels"] + if test_ntype is None + else g.nodes[test_ntype].data["labels"] + ) + edata = ( + g.edata["feats"] + if test_etype is None + else g.edges[test_etype].data["labels"] + ) + + orig_labels[F.asnumpy(orig_nid)] = shuffled_labels + orig_edata[F.asnumpy(orig_eid)] = shuffled_edata + assert np.all(orig_labels == F.asnumpy(nlabel)) + assert np.all(orig_edata == F.asnumpy(edata)) + + +def _verify_node_type_ID_gb(part_g, gpb): + """ + check list: + make sure ntype id have correct data type + """ + part_src_ids, part_dst_ids = _get_part_IDs(part_g) + # These are reshuffled per-type IDs. + src_ntype_ids, part_src_ids = gpb.map_to_per_ntype(part_src_ids) + dst_ntype_ids, part_dst_ids = gpb.map_to_per_ntype(part_dst_ids) + # `IdMap` is in int64 by default. 
+ assert src_ntype_ids.dtype == F.int64 + assert dst_ntype_ids.dtype == F.int64 + + with pytest.raises(dgl.utils.internal.InconsistentDtypeException): + gpb.map_to_per_ntype(F.tensor([0], F.int32)) + with pytest.raises(dgl.utils.internal.InconsistentDtypeException): + gpb.map_to_per_etype(F.tensor([0], F.int32)) + return ( + part_src_ids, + part_dst_ids, + src_ntype_ids, + part_src_ids, + dst_ntype_ids, + ) + + +def _verify_IDs_gb( + g, + part_g, + part_id, + gpb, + part_sizes, + orig_nids, + orig_eids, + store_inner_node, + store_inner_edge, + store_eids, + is_homo, +): + # verify local id and mapping id + _verify_local_and_map_id_gb( + part_g, + part_id, + gpb, + store_inner_node, + store_inner_edge, + store_eids, + ) + + # Verify the mapping between the reshuffled IDs and the original IDs. + ( + part_src_ids, + part_dst_ids, + src_ntype_ids, + part_src_ids, + dst_ntype_ids, + ) = _verify_node_type_ID_gb(part_g, gpb) + + if store_eids: + _verify_orig_IDs_gb( + part_g, + gpb, + g, + part_src_ids=part_src_ids, + part_dst_ids=part_dst_ids, + src_ntype_ids=src_ntype_ids, + dst_ntype_ids=dst_ntype_ids, + orig_nids=orig_nids, + orig_eids=orig_eids, + is_homo=is_homo, + ) + _verify_constructed_id_gb(part_sizes, gpb) + + +def _collect_data_gb( + parts, + part_g, + gpbs, + gpb, + tot_node_feats, + node_feats, + tot_edge_feats, + edge_feats, + shuffled_labels, + shuffled_edata, + test_ntype, + test_etype, +): + if test_ntype != None: + shuffled_labels.append(node_feats[test_ntype + "/labels"]) + shuffled_edata.append( + edge_feats[_etype_tuple_to_str(test_etype) + "/labels"] + ) + else: + shuffled_labels.append(node_feats["_N/labels"]) + shuffled_edata.append(edge_feats["_N:_E:_N/feats"]) + parts.append(part_g) + gpbs.append(gpb) + tot_node_feats.append(node_feats) + tot_edge_feats.append(edge_feats) + + +def _verify_graphbolt_part( + g, + test_dir, + orig_nids, + orig_eids, + graph_name, + num_parts, + store_inner_node, + store_inner_edge, + store_eids, + test_ntype=None, 
+ test_etype=None, + is_homo=False, +): + """ + check list: + _verify_metadata_gb: + data type, ID's order and ID's number of edges and nodes + _verify_IDs_gb: + local id, mapping id,node type id, orig edge, hetero ntype id + verify_graph_feats_gb: + nodes and edges' feats + _verify_graphbolt_attributes: + arguments + """ + parts = [] + tot_node_feats = [] + tot_edge_feats = [] + shuffled_labels = [] + shuffled_edata = [] + part_sizes = [] + gpbs = [] + part_config = os.path.join(test_dir, f"{graph_name}.json") + # test each part + for part_id in range(num_parts): + part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition( + part_config, part_id, load_feats=True, use_graphbolt=True + ) + # verify metadata + _verify_metadata_gb( + gpb, + g, + num_parts, + part_id, + part_sizes, ) - orig_edata = np.zeros(shuffled_edata.shape, dtype=shuffled_edata.dtype) - orig_labels[F.asnumpy(orig_nids)] = shuffled_labels - orig_edata[F.asnumpy(orig_eids)] = shuffled_edata - assert np.all(orig_labels == F.asnumpy(g.ndata["labels"])) - assert np.all(orig_edata == F.asnumpy(g.edata["feats"])) - - node_map = [] - edge_map = [] - for i, (num_nodes, num_edges) in enumerate(part_sizes): - node_map.append(np.ones(num_nodes) * i) - edge_map.append(np.ones(num_edges) * i) - node_map = np.concatenate(node_map) - edge_map = np.concatenate(edge_map) - nid2pid = gpb.nid2partid(F.arange(0, len(node_map))) - assert F.dtype(nid2pid) in (F.int32, F.int64) - assert np.all(F.asnumpy(nid2pid) == node_map) - eid2pid = gpb.eid2partid(F.arange(0, len(edge_map))) - assert F.dtype(eid2pid) in (F.int32, F.int64) - assert np.all(F.asnumpy(eid2pid) == edge_map) + + # verify eid and nid + _verify_IDs_gb( + g, + part_g, + part_id, + gpb, + part_sizes, + orig_nids, + orig_eids, + store_inner_node, + store_inner_edge, + store_eids, + is_homo, + ) + + # collect shuffled data and parts + _collect_data_gb( + parts, + part_g, + gpbs, + gpb, + tot_node_feats, + node_feats, + tot_edge_feats, + edge_feats, + 
shuffled_labels, + shuffled_edata, + test_ntype, + test_etype, + ) + + # verify graph feats + verify_graph_feats_gb( + g, + gpbs, + parts, + tot_node_feats, + tot_edge_feats, + orig_nids, + orig_eids, + shuffled_labels=shuffled_labels, + shuffled_edata=shuffled_edata, + test_ntype=test_ntype, + test_etype=test_etype, + store_inner_node=store_inner_node, + store_inner_edge=store_inner_edge, + store_eids=store_eids, + is_homo=is_homo, + ) + + _verify_graphbolt_attributes( + parts, store_inner_node, store_inner_edge, store_eids + ) + + return parts + + +def _verify_original_IDs_type_hetero(hg, orig_nids, orig_eids): + """ + check list: + make sure type of nodes and edges' ids are correct. + make sure nodes and edges' number in each type is correct. + """ + assert len(orig_nids) == len(hg.ntypes) + assert len(orig_eids) == len(hg.canonical_etypes) + for ntype in hg.ntypes: + assert len(orig_nids[ntype]) == hg.num_nodes(ntype) + assert F.dtype(orig_nids[ntype]) in (F.int64, F.int32) + for etype in hg.canonical_etypes: + assert len(orig_eids[etype]) == hg.num_edges(etype) + assert F.dtype(orig_eids[etype]) in (F.int64, F.int32) @pytest.mark.parametrize("part_method", ["metis", "random"]) @pytest.mark.parametrize("num_parts", [1, 4]) +@pytest.mark.parametrize("store_eids", [True, False]) +@pytest.mark.parametrize("store_inner_node", [True, False]) +@pytest.mark.parametrize("store_inner_edge", [True, False]) @pytest.mark.parametrize("debug_mode", [True, False]) def test_partition_graph_graphbolt_hetero( part_method, num_parts, + store_eids, + store_inner_node, + store_inner_edge, debug_mode, n_jobs=1, - num_trainers_per_machine=1, ): test_ntype = "n1" test_etype = ("n1", "r1", "n2") @@ -1430,7 +1864,6 @@ def test_partition_graph_graphbolt_hetero( hg.edges[test_etype].data["labels"] = F.arange( 0, hg.num_edges(test_etype) ) - num_hops = 1 orig_nids, orig_eids = partition_graph( hg, graph_name, @@ -1440,132 +1873,37 @@ def test_partition_graph_graphbolt_hetero( 
return_mapping=True, num_trainers_per_machine=1, use_graphbolt=True, - store_eids=True, - store_inner_node=True, - store_inner_edge=True, + store_eids=store_eids, + store_inner_node=store_inner_node, + store_inner_edge=store_inner_edge, n_jobs=n_jobs, ) - assert len(orig_nids) == len(hg.ntypes) - assert len(orig_eids) == len(hg.canonical_etypes) - for ntype in hg.ntypes: - assert len(orig_nids[ntype]) == hg.num_nodes(ntype) - for etype in hg.canonical_etypes: - assert len(orig_eids[etype]) == hg.num_edges(etype) - parts = [] - shuffled_labels = [] - shuffled_elabels = [] - part_config = os.path.join(test_dir, f"{graph_name}.json") - for part_id in range(num_parts): - part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition( - part_config, part_id, load_feats=True, use_graphbolt=True - ) - if num_trainers_per_machine > 1: - for ntype in hg.ntypes: - name = ntype + "/trainer_id" - assert name in node_feats - part_ids = F.floor_div( - node_feats[name], num_trainers_per_machine - ) - assert np.all(F.asnumpy(part_ids) == part_id) - for etype in hg.canonical_etypes: - name = _etype_tuple_to_str(etype) + "/trainer_id" - assert name in edge_feats - part_ids = F.floor_div( - edge_feats[name], num_trainers_per_machine - ) - assert np.all(F.asnumpy(part_ids) == part_id) - - # Verify the mapping between the reshuffled IDs and the original IDs. - # These are partition-local IDs. - indices, indptr = part_g.indices, part_g.csc_indptr - csc_matrix = dglsp.from_csc(indptr, indices) - part_src_ids, part_dst_ids = csc_matrix.coo() - # These are reshuffled global homogeneous IDs. - part_src_ids = F.gather_row( - part_g.node_attributes[dgl.NID], part_src_ids - ) - part_dst_ids = F.gather_row( - part_g.node_attributes[dgl.NID], part_dst_ids - ) - part_eids = part_g.edge_attributes[dgl.EID] - # These are reshuffled per-type IDs. 
- src_ntype_ids, part_src_ids = gpb.map_to_per_ntype(part_src_ids) - dst_ntype_ids, part_dst_ids = gpb.map_to_per_ntype(part_dst_ids) - etype_ids, part_eids = gpb.map_to_per_etype(part_eids) - # `IdMap` is in int64 by default. - assert src_ntype_ids.dtype == F.int64 - assert dst_ntype_ids.dtype == F.int64 - assert etype_ids.dtype == F.int64 - with pytest.raises(dgl.utils.internal.InconsistentDtypeException): - gpb.map_to_per_ntype(F.tensor([0], F.int32)) - with pytest.raises(dgl.utils.internal.InconsistentDtypeException): - gpb.map_to_per_etype(F.tensor([0], F.int32)) - # These are original per-type IDs. - for etype_id, etype in enumerate(hg.canonical_etypes): - part_src_ids1 = F.boolean_mask( - part_src_ids, etype_ids == etype_id - ) - src_ntype_ids1 = F.boolean_mask( - src_ntype_ids, etype_ids == etype_id - ) - part_dst_ids1 = F.boolean_mask( - part_dst_ids, etype_ids == etype_id - ) - dst_ntype_ids1 = F.boolean_mask( - dst_ntype_ids, etype_ids == etype_id - ) - part_eids1 = F.boolean_mask(part_eids, etype_ids == etype_id) - assert np.all(F.asnumpy(src_ntype_ids1 == src_ntype_ids1[0])) - assert np.all(F.asnumpy(dst_ntype_ids1 == dst_ntype_ids1[0])) - src_ntype = hg.ntypes[F.as_scalar(src_ntype_ids1[0])] - dst_ntype = hg.ntypes[F.as_scalar(dst_ntype_ids1[0])] - orig_src_ids1 = F.gather_row( - orig_nids[src_ntype], part_src_ids1 - ) - orig_dst_ids1 = F.gather_row( - orig_nids[dst_ntype], part_dst_ids1 - ) - orig_eids1 = F.gather_row(orig_eids[etype], part_eids1) - orig_eids2 = hg.edge_ids( - orig_src_ids1, orig_dst_ids1, etype=etype - ) - assert len(orig_eids1) == len(orig_eids2) - assert np.all(F.asnumpy(orig_eids1) == F.asnumpy(orig_eids2)) - parts.append(part_g) - if NTYPE in part_g.node_attributes: - verify_graph_feats( - hg, - gpb, - part_g, - node_feats, - edge_feats, - orig_nids, - orig_eids, - use_graphbolt=True, - ) - - shuffled_labels.append(node_feats[test_ntype + "/labels"]) - shuffled_elabels.append( - edge_feats[_etype_tuple_to_str(test_etype) + 
"/labels"] - ) - verify_hetero_graph(hg, parts, True) + _verify_original_IDs_type_hetero(hg, orig_nids, orig_eids) + if debug_mode: + store_eids = store_inner_node = store_inner_edge = True - shuffled_labels = F.asnumpy(F.cat(shuffled_labels, 0)) - shuffled_elabels = F.asnumpy(F.cat(shuffled_elabels, 0)) - orig_labels = np.zeros( - shuffled_labels.shape, dtype=shuffled_labels.dtype - ) - orig_elabels = np.zeros( - shuffled_elabels.shape, dtype=shuffled_elabels.dtype - ) - orig_labels[F.asnumpy(orig_nids[test_ntype])] = shuffled_labels - orig_elabels[F.asnumpy(orig_eids[test_etype])] = shuffled_elabels - assert np.all( - orig_labels == F.asnumpy(hg.nodes[test_ntype].data["labels"]) + parts = _verify_graphbolt_part( + hg, + test_dir, + orig_nids, + orig_eids, + graph_name, + num_parts, + store_inner_node, + store_inner_edge, + store_eids, + test_ntype, + test_etype, + is_homo=False, ) - assert np.all( - orig_elabels == F.asnumpy(hg.edges[test_etype].data["labels"]) + + _verify_hetero_graph( + hg, + parts, + store_eids=store_eids, + store_inner_edge=store_inner_edge, + debug_mode=debug_mode, ) @@ -1793,6 +2131,9 @@ def test_partition_graph_graphbolt_hetero_multi( part_method="random", num_parts=num_parts, n_jobs=4, + store_eids=True, + store_inner_node=True, + store_inner_edge=True, debug_mode=False, ) From d0de08961eda142d04629fad663e01703b8acc17 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 5 Sep 2024 12:28:31 +0000 Subject: [PATCH 11/37] change data_shuffle.py --- tools/distpartitioning/data_shuffle.py | 30 +++++++++++++------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/tools/distpartitioning/data_shuffle.py b/tools/distpartitioning/data_shuffle.py index f8837abd398b..a7abcc75f648 100644 --- a/tools/distpartitioning/data_shuffle.py +++ b/tools/distpartitioning/data_shuffle.py @@ -285,21 +285,21 @@ def exchange_edge_data(rank, world_size, num_parts, edge_data, id_lookup): local_etype_ids.append(rcvd_edge_data[:, 3]) 
local_eids.append(rcvd_edge_data[:, 4]) - edge_data[constants.GLOBAL_SRC_ID + "/" + str(local_part_id)] = ( - np.concatenate(local_src_ids) - ) - edge_data[constants.GLOBAL_DST_ID + "/" + str(local_part_id)] = ( - np.concatenate(local_dst_ids) - ) - edge_data[constants.GLOBAL_TYPE_EID + "/" + str(local_part_id)] = ( - np.concatenate(local_type_eids) - ) - edge_data[constants.ETYPE_ID + "/" + str(local_part_id)] = ( - np.concatenate(local_etype_ids) - ) - edge_data[constants.GLOBAL_EID + "/" + str(local_part_id)] = ( - np.concatenate(local_eids) - ) + edge_data[ + constants.GLOBAL_SRC_ID + "/" + str(local_part_id) + ] = np.concatenate(local_src_ids) + edge_data[ + constants.GLOBAL_DST_ID + "/" + str(local_part_id) + ] = np.concatenate(local_dst_ids) + edge_data[ + constants.GLOBAL_TYPE_EID + "/" + str(local_part_id) + ] = np.concatenate(local_type_eids) + edge_data[ + constants.ETYPE_ID + "/" + str(local_part_id) + ] = np.concatenate(local_etype_ids) + edge_data[ + constants.GLOBAL_EID + "/" + str(local_part_id) + ] = np.concatenate(local_eids) # Check if the data was exchanged correctly local_edge_count = 0 From 6729bed667b88ff03dfd40f38018d5d74d291937 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 6 Sep 2024 08:20:54 +0000 Subject: [PATCH 12/37] change convert_partition.py --- tools/distpartitioning/convert_partition.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/tools/distpartitioning/convert_partition.py b/tools/distpartitioning/convert_partition.py index 976d09e3a090..3151141ecd7b 100644 --- a/tools/distpartitioning/convert_partition.py +++ b/tools/distpartitioning/convert_partition.py @@ -1,26 +1,19 @@ -import argparse import gc -import json import logging import os -import time import constants import dgl import dgl.graphbolt as gb import numpy as np -import pandas as pd -import pyarrow -import scipy.sparse as spsp import torch as th from dgl.distributed.partition import ( _etype_str_to_tuple, _etype_tuple_to_str, 
RESERVED_FIELD_DTYPE, ) -from pyarrow import csv -from utils import get_idranges, memory_snapshot, read_json +from utils import get_idranges, memory_snapshot def _get_unique_invidx(srcids, dstids, nids, low_mem=True): @@ -172,7 +165,7 @@ def _is_homogeneous(ntypes, etypes): return len(ntypes) == 1 and len(etypes) == 1 -def _create_csc_data(part_local_src_id, part_local_dst_id): +def _coo2csc(part_local_src_id, part_local_dst_id): part_local_src_id, part_local_dst_id = th.tensor( part_local_src_id, dtype=th.int64 ), th.tensor(part_local_dst_id, dtype=th.int64) @@ -398,7 +391,7 @@ def _partition_graphbolt( } ) - indptr, indices = _create_csc_data(part_local_src_id, part_local_dst_id) + indptr, indices = _coo2csc(part_local_src_id, part_local_dst_id) part_graph = gb.fused_csc_sampling_graph( csc_indptr=indptr, indices=indices, From 0cd0d9738651bb99a4252104062acc90b8dcbe2f Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 6 Sep 2024 09:20:25 +0000 Subject: [PATCH 13/37] change convert_partition.py --- tools/distpartitioning/convert_partition.py | 256 +++++++------------- 1 file changed, 81 insertions(+), 175 deletions(-) diff --git a/tools/distpartitioning/convert_partition.py b/tools/distpartitioning/convert_partition.py index 3151141ecd7b..a2a041122583 100644 --- a/tools/distpartitioning/convert_partition.py +++ b/tools/distpartitioning/convert_partition.py @@ -241,98 +241,7 @@ def _graph_orig_ids( return orig_nids, orig_eids -def _partition_DGLGraph( - part_local_src_id, - part_local_dst_id, - global_src_id, - global_dst_id, - global_homo_nid, - idx, - reshuffle_nodes, - id_map, - edgeid_offset, - etype_ids, - return_orig_nids, - return_orig_eids, - ntypes_map, - etypes_map, - global_edge_id, - uniq_ids, - inner_nodes, -): - num_edges = len(part_local_dst_id) - part_graph = dgl.graph( - data=(part_local_src_id, part_local_dst_id), num_nodes=len(uniq_ids) - ) - # create edge data in graph. 
- ( - part_graph.edata[dgl.EID], - part_graph.edata[dgl.ETYPE], - part_graph.edata["inner_edge"], - ) = _create_edge_data(edgeid_offset, etype_ids, num_edges) - - # compute per_type_ids and ntype for all the nodes in the graph. - ntype, per_type_ids = _compute_node_ntype( - global_src_id, - global_dst_id, - global_homo_nid, - idx, - reshuffle_nodes, - id_map, - ) - - # create node data in graph. - ( - part_graph.ndata[dgl.NTYPE], - part_graph.ndata[dgl.NID], - part_graph.ndata["inner_node"], - ) = _create_node_data(ntype, uniq_ids, reshuffle_nodes, inner_nodes) - - # get the original node ids and edge ids from original graph. - orig_nids, orig_eids = _graph_orig_ids( - return_orig_nids, - return_orig_eids, - ntypes_map, - etypes_map, - part_graph.ndata, - part_graph.edata, - per_type_ids, - part_graph.edata[dgl.ETYPE], - global_edge_id, - ) - return ( - part_graph, - ntypes_map, - etypes_map, - orig_nids, - orig_eids, - ) - - -def _partition_graphbolt( - part_local_src_id, - part_local_dst_id, - global_src_id, - global_dst_id, - global_homo_nid, - idx, - reshuffle_nodes, - id_map, - edgeid_offset, - etype_ids, - ntypes, - etypes, - return_orig_nids, - return_orig_eids, - ntypes_map, - etypes_map, - global_edge_id, - uniq_ids, - inner_nodes, - store_eids=True, - store_inner_node=True, - store_inner_edge=True, -): +def _create_edge_attr_gb(part_local_dst_id,edgeid_offset,etype_ids,ntypes,etypes,etypes_map): edge_attr = {} # create edge data in graph. 
num_edges = len(part_local_dst_id) @@ -342,6 +251,28 @@ def _partition_graphbolt( edge_attr["inner_edge"], ) = _create_edge_data(edgeid_offset, etype_ids, num_edges) + is_homo = _is_homogeneous(ntypes, etypes) + + edge_type_to_id = ( + None + if is_homo + else { + gb.etype_tuple_to_str(etype): etid + for etype, etid in etypes_map.items() + } + ) + return edge_attr,type_per_edge,edge_type_to_id + + +def _create_node_attr( + idx, + global_src_id, + global_dst_id, + global_homo_nid, + uniq_ids, + reshuffle_nodes, + id_map, + inner_nodes): # compute per_type_ids and ntype for all the nodes in the graph. ntype, per_type_ids = _compute_node_ntype( global_src_id, @@ -359,20 +290,10 @@ def _partition_graphbolt( node_attr[dgl.NID], node_attr["inner_node"], ) = _create_node_data(ntype, uniq_ids, reshuffle_nodes, inner_nodes) + return node_attr, per_type_ids - is_homo = _is_homogeneous(ntypes, etypes) - # get the original node ids and edge ids from original graph. - orig_nids, orig_eids = _graph_orig_ids( - return_orig_nids, - return_orig_eids, - ntypes_map, - etypes_map, - node_attr, - edge_attr, - per_type_ids, - type_per_edge, - global_edge_id, - ) + +def remove_attr_gb(edge_attr,node_attr,store_inner_node,store_inner_edge,store_eids): if not store_inner_edge: edge_attr.pop("inner_edge") @@ -381,34 +302,7 @@ def _partition_graphbolt( if not store_inner_node: node_attr.pop("inner_node") - - edge_type_to_id = ( - None - if is_homo - else { - gb.etype_tuple_to_str(etype): etid - for etype, etid in etypes_map.items() - } - ) - - indptr, indices = _coo2csc(part_local_src_id, part_local_dst_id) - part_graph = gb.fused_csc_sampling_graph( - csc_indptr=indptr, - indices=indices, - node_type_offset=None, - type_per_edge=type_per_edge, - node_attributes=node_attr, - edge_attributes=edge_attr, - node_type_to_id=ntypes_map, - edge_type_to_id=edge_type_to_id, - ) - return ( - part_graph, - ntypes_map, - etypes_map, - orig_nids, - orig_eids, - ) + return edge_attr,node_attr def 
create_graph_object( @@ -700,33 +594,38 @@ def create_graph_object( # create the graph here now. if use_graphbolt: - ( - part_graph, - ntypes_map, - etypes_map, - orig_nids, - orig_eids, - ) = _partition_graphbolt( - part_local_src_id, - part_local_dst_id, - global_src_id, - global_dst_id, - global_homo_nid, - idx, - reshuffle_nodes, - id_map, - edgeid_offset, - etype_ids, - ntypes, - etypes, - return_orig_nids, - return_orig_eids, - ntypes_map, - etypes_map, - global_edge_id, - uniq_ids, - inner_nodes, - **kwargs, + edge_attr,type_per_edge,edge_type_to_id = _create_edge_attr_gb(part_local_dst_id,edgeid_offset,etype_ids,ntypes,etypes,etypes_map) + node_attr, per_type_ids = _create_node_attr( + idx, + global_src_id, + global_dst_id, + global_homo_nid, + uniq_ids, + reshuffle_nodes, + id_map, + inner_nodes) + orig_nids, orig_eids = _graph_orig_ids( + return_orig_nids, + return_orig_eids, + ntypes_map, + etypes_map, + node_attr, + edge_attr, + per_type_ids, + type_per_edge, + global_edge_id, + ) + remove_attr_gb(edge_attr,node_attr,**kwargs) + indptr, indices = _coo2csc(part_local_src_id, part_local_dst_id) + part_graph = gb.fused_csc_sampling_graph( + csc_indptr=indptr, + indices=indices, + node_type_offset=None, + type_per_edge=type_per_edge, + node_attributes=node_attr, + edge_attributes=edge_attr, + node_type_to_id=ntypes_map, + edge_type_to_id=edge_type_to_id, ) return ( part_graph, @@ -737,34 +636,41 @@ def create_graph_object( orig_nids, orig_eids, ) - else: + num_edges = len(part_local_dst_id) + part_graph = dgl.graph( + data=(part_local_src_id, part_local_dst_id), num_nodes=len(uniq_ids) + ) + # create edge data in graph. 
( - part_graph, - ntypes_map, - etypes_map, - orig_nids, - orig_eids, - ) = _partition_DGLGraph( - part_local_src_id, - part_local_dst_id, + part_graph.edata[dgl.EID], + part_graph.edata[dgl.ETYPE], + part_graph.edata["inner_edge"], + ) = _create_edge_data(edgeid_offset, etype_ids, num_edges) + + part_graph.ndata, per_type_ids = _create_node_attr( + idx, global_src_id, global_dst_id, global_homo_nid, - idx, + uniq_ids, reshuffle_nodes, id_map, - edgeid_offset, - etype_ids, + inner_nodes + ) + + # get the original node ids and edge ids from original graph. + orig_nids, orig_eids = _graph_orig_ids( return_orig_nids, return_orig_eids, ntypes_map, etypes_map, + part_graph.ndata, + part_graph.edata, + per_type_ids, + part_graph.edata[dgl.ETYPE], global_edge_id, - uniq_ids, - inner_nodes, ) - return ( part_graph, node_map_val, From d03a323edb2caee8d495b0bcfdfaf1d68d9cfe81 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 6 Sep 2024 14:02:53 +0000 Subject: [PATCH 14/37] change convert_partition.py --- tools/distpartitioning/convert_partition.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/tools/distpartitioning/convert_partition.py b/tools/distpartitioning/convert_partition.py index a2a041122583..8a675aa32230 100644 --- a/tools/distpartitioning/convert_partition.py +++ b/tools/distpartitioning/convert_partition.py @@ -169,11 +169,13 @@ def _coo2csc(part_local_src_id, part_local_dst_id): part_local_src_id, part_local_dst_id = th.tensor( part_local_src_id, dtype=th.int64 ), th.tensor(part_local_dst_id, dtype=th.int64) - indptr = th.zeros(len(part_local_dst_id) + 1, dtype=th.int64) - col_counts = th.bincount(part_local_src_id, minlength=part_local_dst_id) + num_nodes = th.max(th.cat([part_local_src_id,part_local_dst_id],dim=0)) + indptr = th.zeros(num_nodes + 2, dtype=th.int64) + col_counts = th.bincount(part_local_dst_id, minlength=num_nodes+1) indptr[1:] = th.cumsum(col_counts, 0) - indices = part_local_dst_id - return indptr, 
indices + edge_id = th.argsort(part_local_dst_id) + indices = part_local_src_id[edge_id] + return indptr, indices, edge_id def _create_edge_data(edgeid_offset, etype_ids, num_edges): @@ -616,16 +618,16 @@ def create_graph_object( global_edge_id, ) remove_attr_gb(edge_attr,node_attr,**kwargs) - indptr, indices = _coo2csc(part_local_src_id, part_local_dst_id) + indptr, indices, csc_edge_ids = _coo2csc(part_local_src_id, part_local_dst_id) part_graph = gb.fused_csc_sampling_graph( csc_indptr=indptr, indices=indices, node_type_offset=None, - type_per_edge=type_per_edge, + type_per_edge=type_per_edge[csc_edge_ids], node_attributes=node_attr, - edge_attributes=edge_attr, + edge_attributes=edge_attr[csc_edge_ids], node_type_to_id=ntypes_map, - edge_type_to_id=edge_type_to_id, + edge_type_to_id=edge_type_to_id[csc_edge_ids], ) return ( part_graph, From fdbea5ea711bb839ec1e282f3c11adf35491c21b Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 6 Sep 2024 14:17:51 +0000 Subject: [PATCH 15/37] change code format --- tools/distpartitioning/convert_partition.py | 87 ++++++++++++--------- 1 file changed, 51 insertions(+), 36 deletions(-) diff --git a/tools/distpartitioning/convert_partition.py b/tools/distpartitioning/convert_partition.py index 8a675aa32230..390fb0619d17 100644 --- a/tools/distpartitioning/convert_partition.py +++ b/tools/distpartitioning/convert_partition.py @@ -169,9 +169,9 @@ def _coo2csc(part_local_src_id, part_local_dst_id): part_local_src_id, part_local_dst_id = th.tensor( part_local_src_id, dtype=th.int64 ), th.tensor(part_local_dst_id, dtype=th.int64) - num_nodes = th.max(th.cat([part_local_src_id,part_local_dst_id],dim=0)) + num_nodes = th.max(th.cat([part_local_src_id, part_local_dst_id], dim=0)) indptr = th.zeros(num_nodes + 2, dtype=th.int64) - col_counts = th.bincount(part_local_dst_id, minlength=num_nodes+1) + col_counts = th.bincount(part_local_dst_id, minlength=num_nodes + 1) indptr[1:] = th.cumsum(col_counts, 0) edge_id = 
th.argsort(part_local_dst_id) indices = part_local_src_id[edge_id] @@ -243,7 +243,9 @@ def _graph_orig_ids( return orig_nids, orig_eids -def _create_edge_attr_gb(part_local_dst_id,edgeid_offset,etype_ids,ntypes,etypes,etypes_map): +def _create_edge_attr_gb( + part_local_dst_id, edgeid_offset, etype_ids, ntypes, etypes, etypes_map +): edge_attr = {} # create edge data in graph. num_edges = len(part_local_dst_id) @@ -263,18 +265,19 @@ def _create_edge_attr_gb(part_local_dst_id,edgeid_offset,etype_ids,ntypes,etypes for etype, etid in etypes_map.items() } ) - return edge_attr,type_per_edge,edge_type_to_id + return edge_attr, type_per_edge, edge_type_to_id def _create_node_attr( - idx, - global_src_id, - global_dst_id, - global_homo_nid, - uniq_ids, - reshuffle_nodes, - id_map, - inner_nodes): + idx, + global_src_id, + global_dst_id, + global_homo_nid, + uniq_ids, + reshuffle_nodes, + id_map, + inner_nodes, +): # compute per_type_ids and ntype for all the nodes in the graph. ntype, per_type_ids = _compute_node_ntype( global_src_id, @@ -295,7 +298,9 @@ def _create_node_attr( return node_attr, per_type_ids -def remove_attr_gb(edge_attr,node_attr,store_inner_node,store_inner_edge,store_eids): +def remove_attr_gb( + edge_attr, node_attr, store_inner_node, store_inner_edge, store_eids +): if not store_inner_edge: edge_attr.pop("inner_edge") @@ -304,7 +309,7 @@ def remove_attr_gb(edge_attr,node_attr,store_inner_node,store_inner_edge,store_e if not store_inner_node: node_attr.pop("inner_node") - return edge_attr,node_attr + return edge_attr, node_attr def create_graph_object( @@ -596,29 +601,39 @@ def create_graph_object( # create the graph here now. 
if use_graphbolt: - edge_attr,type_per_edge,edge_type_to_id = _create_edge_attr_gb(part_local_dst_id,edgeid_offset,etype_ids,ntypes,etypes,etypes_map) + edge_attr, type_per_edge, edge_type_to_id = _create_edge_attr_gb( + part_local_dst_id, + edgeid_offset, + etype_ids, + ntypes, + etypes, + etypes_map, + ) node_attr, per_type_ids = _create_node_attr( - idx, - global_src_id, - global_dst_id, - global_homo_nid, - uniq_ids, - reshuffle_nodes, - id_map, - inner_nodes) + idx, + global_src_id, + global_dst_id, + global_homo_nid, + uniq_ids, + reshuffle_nodes, + id_map, + inner_nodes, + ) orig_nids, orig_eids = _graph_orig_ids( - return_orig_nids, - return_orig_eids, - ntypes_map, - etypes_map, - node_attr, - edge_attr, - per_type_ids, - type_per_edge, - global_edge_id, - ) - remove_attr_gb(edge_attr,node_attr,**kwargs) - indptr, indices, csc_edge_ids = _coo2csc(part_local_src_id, part_local_dst_id) + return_orig_nids, + return_orig_eids, + ntypes_map, + etypes_map, + node_attr, + edge_attr, + per_type_ids, + type_per_edge, + global_edge_id, + ) + remove_attr_gb(edge_attr, node_attr, **kwargs) + indptr, indices, csc_edge_ids = _coo2csc( + part_local_src_id, part_local_dst_id + ) part_graph = gb.fused_csc_sampling_graph( csc_indptr=indptr, indices=indices, @@ -658,7 +673,7 @@ def create_graph_object( uniq_ids, reshuffle_nodes, id_map, - inner_nodes + inner_nodes, ) # get the original node ids and edge ids from original graph. 
From a693a396b729cb614cafde30c4ba6de13e46a1eb Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sun, 8 Sep 2024 04:00:32 +0000 Subject: [PATCH 16/37] test dist partition --- tools/distpartitioning/data_proc_pipeline.py | 19 +++++++++++++++++++ tools/distpartitioning/utils.py | 2 ++ 2 files changed, 21 insertions(+) diff --git a/tools/distpartitioning/data_proc_pipeline.py b/tools/distpartitioning/data_proc_pipeline.py index b26760eec5fb..8f34c56f851c 100644 --- a/tools/distpartitioning/data_proc_pipeline.py +++ b/tools/distpartitioning/data_proc_pipeline.py @@ -99,6 +99,25 @@ def log_params(params): action="store_true", help="Use GraphBolt for distributed partition.", ) + parser.add_argument( + "--store-inner-node", + action="store_true", + default=False, + help="Store inner nodes.", + ) + + parser.add_argument( + "--store-inner-edge", + action="store_true", + default=False, + help="Store inner edges.", + ) + parser.add_argument( + "--store-eids", + action="store_true", + default=False, + help="Store edge IDs.", + ) parser.add_argument( "--graph-formats", default=None, diff --git a/tools/distpartitioning/utils.py b/tools/distpartitioning/utils.py index fbf4ae8c0fed..72908fb71da2 100644 --- a/tools/distpartitioning/utils.py +++ b/tools/distpartitioning/utils.py @@ -3,6 +3,8 @@ import os from itertools import cycle +import sys +sys.path.append('/home/ubuntu/workspace/dgl/tools/distpartitioning') import constants import dgl From 8def0950ceb28c745c2c005cc43a351d00845773 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sun, 8 Sep 2024 04:06:19 +0000 Subject: [PATCH 17/37] convert_partition --- tools/distpartitioning/convert_partition.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/tools/distpartitioning/convert_partition.py b/tools/distpartitioning/convert_partition.py index 390fb0619d17..55d119725409 100644 --- a/tools/distpartitioning/convert_partition.py +++ b/tools/distpartitioning/convert_partition.py @@ -254,6 +254,7 @@ def 
_create_edge_attr_gb( type_per_edge, edge_attr["inner_edge"], ) = _create_edge_data(edgeid_offset, etype_ids, num_edges) + assert 'inner_edge' in edge_attr is_homo = _is_homogeneous(ntypes, etypes) @@ -302,12 +303,16 @@ def remove_attr_gb( edge_attr, node_attr, store_inner_node, store_inner_edge, store_eids ): if not store_inner_edge: + assert 'inner_edge' in edge_attr edge_attr.pop("inner_edge") + assert 'inner_edge' in edge_attr if not store_eids: + assert dgl.EID in edge_attr edge_attr.pop(dgl.EID) if not store_inner_node: + assert 'inner_node' in node_attr node_attr.pop("inner_node") return edge_attr, node_attr @@ -634,15 +639,18 @@ def create_graph_object( indptr, indices, csc_edge_ids = _coo2csc( part_local_src_id, part_local_dst_id ) + edge_attr = { + attr: edge_attr[attr][csc_edge_ids] for attr in edge_attr.keys() + } part_graph = gb.fused_csc_sampling_graph( csc_indptr=indptr, indices=indices, node_type_offset=None, type_per_edge=type_per_edge[csc_edge_ids], node_attributes=node_attr, - edge_attributes=edge_attr[csc_edge_ids], + edge_attributes=edge_attr, node_type_to_id=ntypes_map, - edge_type_to_id=edge_type_to_id[csc_edge_ids], + edge_type_to_id=edge_type_to_id, ) return ( part_graph, @@ -665,7 +673,7 @@ def create_graph_object( part_graph.edata["inner_edge"], ) = _create_edge_data(edgeid_offset, etype_ids, num_edges) - part_graph.ndata, per_type_ids = _create_node_attr( + ndata, per_type_ids = _create_node_attr( idx, global_src_id, global_dst_id, @@ -675,6 +683,8 @@ def create_graph_object( id_map, inner_nodes, ) + for (attr_name,node_attributes) in ndata.items(): + part_graph.ndata[attr_name]=node_attributes # get the original node ids and edge ids from original graph. 
orig_nids, orig_eids = _graph_orig_ids( From 296882f30d70a56f88eeb1cb0410b1ed9f1fbb33 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sun, 8 Sep 2024 04:07:41 +0000 Subject: [PATCH 18/37] change format --- tools/distpartitioning/convert_partition.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tools/distpartitioning/convert_partition.py b/tools/distpartitioning/convert_partition.py index 55d119725409..15d98f1d0ee6 100644 --- a/tools/distpartitioning/convert_partition.py +++ b/tools/distpartitioning/convert_partition.py @@ -254,7 +254,7 @@ def _create_edge_attr_gb( type_per_edge, edge_attr["inner_edge"], ) = _create_edge_data(edgeid_offset, etype_ids, num_edges) - assert 'inner_edge' in edge_attr + assert "inner_edge" in edge_attr is_homo = _is_homogeneous(ntypes, etypes) @@ -303,16 +303,16 @@ def remove_attr_gb( edge_attr, node_attr, store_inner_node, store_inner_edge, store_eids ): if not store_inner_edge: - assert 'inner_edge' in edge_attr + assert "inner_edge" in edge_attr edge_attr.pop("inner_edge") - assert 'inner_edge' in edge_attr + assert "inner_edge" in edge_attr if not store_eids: assert dgl.EID in edge_attr edge_attr.pop(dgl.EID) if not store_inner_node: - assert 'inner_node' in node_attr + assert "inner_node" in node_attr node_attr.pop("inner_node") return edge_attr, node_attr @@ -683,8 +683,8 @@ def create_graph_object( id_map, inner_nodes, ) - for (attr_name,node_attributes) in ndata.items(): - part_graph.ndata[attr_name]=node_attributes + for attr_name, node_attributes in ndata.items(): + part_graph.ndata[attr_name] = node_attributes # get the original node ids and edge ids from original graph. 
orig_nids, orig_eids = _graph_orig_ids( From 71140a711475854e699cd1e412598a11159ac957 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sun, 8 Sep 2024 04:09:56 +0000 Subject: [PATCH 19/37] change utils --- tools/distpartitioning/utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tools/distpartitioning/utils.py b/tools/distpartitioning/utils.py index 72908fb71da2..fbf4ae8c0fed 100644 --- a/tools/distpartitioning/utils.py +++ b/tools/distpartitioning/utils.py @@ -3,8 +3,6 @@ import os from itertools import cycle -import sys -sys.path.append('/home/ubuntu/workspace/dgl/tools/distpartitioning') import constants import dgl From 3e811abe5fe7e4597c6e9f6bbefc8659c6f024cb Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 10 Sep 2024 00:39:52 +0000 Subject: [PATCH 20/37] change dispatch_data.py --- tests/tools/test_dist_partition_graphbolt.py | 844 +++++++++++++++++++ tools/dispatch_data.py | 30 +- tools/distpartitioning/data_proc_pipeline.py | 2 +- tools/distpartitioning/data_shuffle.py | 9 +- 4 files changed, 880 insertions(+), 5 deletions(-) create mode 100644 tests/tools/test_dist_partition_graphbolt.py diff --git a/tests/tools/test_dist_partition_graphbolt.py b/tests/tools/test_dist_partition_graphbolt.py new file mode 100644 index 000000000000..ea596c67d6aa --- /dev/null +++ b/tests/tools/test_dist_partition_graphbolt.py @@ -0,0 +1,844 @@ +import json +import os +import tempfile + +import dgl +import dgl.backend as F +import dgl.graphbolt as gb + +import numpy as np +import pyarrow.parquet as pq +import pytest +import torch +from dgl.data.utils import load_graphs, load_tensors +from dgl.distributed.partition import ( + _etype_str_to_tuple, + _etype_tuple_to_str, + _get_inner_edge_mask, + _get_inner_node_mask, + load_partition, + RESERVED_FIELD_DTYPE, +) + +from distpartitioning import array_readwriter +from distpartitioning.utils import generate_read_list +from pytest_utils import create_chunked_dataset + +from tools.verification_utils import ( + verify_graph_feats, 
+ verify_partition_data_types, + verify_partition_formats, +) + + +def _verify_metadata_gb(gpb, g, num_parts, part_id, part_sizes): + """ + check list: + make sure the number of nodes and edges is correct. + make sure the number of parts is correct. + make sure the number of nodes and edges in each parts os corrcet. + """ + assert gpb._num_nodes() == g.num_nodes() + assert gpb._num_edges() == g.num_edges() + + assert gpb.num_partitions() == num_parts + gpb_meta = gpb.metadata() + assert len(gpb_meta) == num_parts + assert len(gpb.partid2nids(part_id)) == gpb_meta[part_id]["num_nodes"] + assert len(gpb.partid2eids(part_id)) == gpb_meta[part_id]["num_edges"] + part_sizes.append( + (gpb_meta[part_id]["num_nodes"], gpb_meta[part_id]["num_edges"]) + ) + + +def _verify_local_id_gb(part_g, part_id, gpb): + """ + check list: + make sure the type of local id is correct. + make sure local id have a right order. + """ + nid = F.boolean_mask( + part_g.node_attributes[dgl.NID], + part_g.node_attributes["inner_node"], + ) + local_nid = gpb.nid2localnid(nid, part_id) + assert F.dtype(local_nid) in (F.int64, F.int32) + assert np.all(F.asnumpy(local_nid) == np.arange(0, len(local_nid))) + eid = F.boolean_mask( + part_g.edge_attributes[dgl.EID], + part_g.edge_attributes["inner_edge"], + ) + local_eid = gpb.eid2localeid(eid, part_id) + assert F.dtype(local_eid) in (F.int64, F.int32) + assert np.all(np.sort(F.asnumpy(local_eid)) == np.arange(0, len(local_eid))) + return local_nid, local_eid + + +def _verify_map_gb( + part_g, + part_id, + gpb, +): + """ + check list: + make sure the map node and its data type is correct. + """ + # Check the node map. 
+ local_nodes = F.boolean_mask( + part_g.node_attributes[dgl.NID], + part_g.node_attributes["inner_node"], + ) + inner_node_index = F.nonzero_1d(part_g.node_attributes["inner_node"]) + mapping_nodes = gpb.partid2nids(part_id) + assert F.dtype(mapping_nodes) in (F.int32, F.int64) + assert np.all( + np.sort(F.asnumpy(local_nodes)) == np.sort(F.asnumpy(mapping_nodes)) + ) + assert np.all( + F.asnumpy(inner_node_index) == np.arange(len(inner_node_index)) + ) + + # Check the edge map. + + local_edges = F.boolean_mask( + part_g.edge_attributes[dgl.EID], + part_g.edge_attributes["inner_edge"], + ) + inner_edge_index = F.nonzero_1d(part_g.edge_attributes["inner_edge"]) + mapping_edges = gpb.partid2eids(part_id) + assert F.dtype(mapping_edges) in (F.int32, F.int64) + assert np.all( + np.sort(F.asnumpy(local_edges)) == np.sort(F.asnumpy(mapping_edges)) + ) + assert np.all( + F.asnumpy(inner_edge_index) == np.arange(len(inner_edge_index)) + ) + return local_nodes, local_edges + + +def _verify_local_and_map_id_gb( + part_g, + part_id, + gpb, + store_inner_node, + store_inner_edge, + store_eids, +): + """ + check list: + make sure local id are correct. + make sure mapping id are correct. + """ + if store_inner_node and store_inner_edge and store_eids: + _verify_local_id_gb(part_g, part_id, gpb) + _verify_map_gb(part_g, part_id, gpb) + + +def _get_part_IDs(part_g): + # These are partition-local IDs. + num_columns = part_g.csc_indptr.diff() + part_src_ids = part_g.indices + part_dst_ids = torch.arange(part_g.total_num_nodes).repeat_interleave( + num_columns + ) + # These are reshuffled global homogeneous IDs. 
+ part_src_ids = F.gather_row(part_g.node_attributes[dgl.NID], part_src_ids) + part_dst_ids = F.gather_row(part_g.node_attributes[dgl.NID], part_dst_ids) + return part_src_ids, part_dst_ids + + +def _verify_node_type_ID_gb(part_g, gpb): + """ + check list: + make sure ntype id have correct data type + """ + part_src_ids, part_dst_ids = _get_part_IDs(part_g) + # These are reshuffled per-type IDs. + src_ntype_ids, part_src_ids = gpb.map_to_per_ntype(part_src_ids) + dst_ntype_ids, part_dst_ids = gpb.map_to_per_ntype(part_dst_ids) + # `IdMap` is in int64 by default. + assert src_ntype_ids.dtype == F.int64 + assert dst_ntype_ids.dtype == F.int64 + + with pytest.raises(dgl.utils.internal.InconsistentDtypeException): + gpb.map_to_per_ntype(F.tensor([0], F.int32)) + with pytest.raises(dgl.utils.internal.InconsistentDtypeException): + gpb.map_to_per_etype(F.tensor([0], F.int32)) + return ( + part_src_ids, + part_dst_ids, + src_ntype_ids, + part_src_ids, + dst_ntype_ids, + ) + + +def _verify_orig_edge_IDs_gb( + g, + orig_nids, + orig_eids, + part_eids, + part_src_ids, + part_dst_ids, + src_ntype=None, + dst_ntype=None, + etype=None, +): + """ + check list: + make sure orig edge id are correct after + """ + if src_ntype is not None and dst_ntype is not None: + orig_src_nid = orig_nids[src_ntype] + orig_dst_nid = orig_nids[dst_ntype] + else: + orig_src_nid = orig_nids + orig_dst_nid = orig_nids + orig_src_ids = F.gather_row(orig_src_nid, part_src_ids) + orig_dst_ids = F.gather_row(orig_dst_nid, part_dst_ids) + if etype is not None: + orig_eids = orig_eids[etype] + orig_eids1 = F.gather_row(orig_eids, part_eids) + orig_eids2 = g.edge_ids(orig_src_ids, orig_dst_ids, etype=etype) + assert len(orig_eids1) == len(orig_eids2) + assert np.all(F.asnumpy(orig_eids1) == F.asnumpy(orig_eids2)) + + +def _verify_orig_IDs_gb( + part_g, + gpb, + g, + is_homo=False, + part_src_ids=None, + part_dst_ids=None, + src_ntype_ids=None, + dst_ntype_ids=None, + orig_nids=None, + orig_eids=None, +): + 
""" + check list: + make sure orig edge id are correct. + make sure hetero ntype id are correct. + """ + part_eids = part_g.edge_attributes[dgl.EID] + if is_homo: + _verify_orig_edge_IDs_gb( + g, orig_nids, orig_eids, part_eids, part_src_ids, part_dst_ids + ) + local_orig_nids = orig_nids[part_g.node_attributes[dgl.NID]] + local_orig_eids = orig_eids[part_g.edge_attributes[dgl.EID]] + part_g.node_attributes["feats"] = F.gather_row( + g.ndata["feats"], local_orig_nids + ) + part_g.edge_attributes["feats"] = F.gather_row( + g.edata["feats"], local_orig_eids + ) + else: + etype_ids, part_eids = gpb.map_to_per_etype(part_eids) + # `IdMap` is in int64 by default. + assert etype_ids.dtype == F.int64 + + # These are original per-type IDs. + for etype_id, etype in enumerate(g.canonical_etypes): + part_src_ids1 = F.boolean_mask(part_src_ids, etype_ids == etype_id) + src_ntype_ids1 = F.boolean_mask( + src_ntype_ids, etype_ids == etype_id + ) + part_dst_ids1 = F.boolean_mask(part_dst_ids, etype_ids == etype_id) + dst_ntype_ids1 = F.boolean_mask( + dst_ntype_ids, etype_ids == etype_id + ) + part_eids1 = F.boolean_mask(part_eids, etype_ids == etype_id) + assert np.all(F.asnumpy(src_ntype_ids1 == src_ntype_ids1[0])) + assert np.all(F.asnumpy(dst_ntype_ids1 == dst_ntype_ids1[0])) + src_ntype = g.ntypes[F.as_scalar(src_ntype_ids1[0])] + dst_ntype = g.ntypes[F.as_scalar(dst_ntype_ids1[0])] + + _verify_orig_edge_IDs_gb( + g, + orig_nids, + orig_eids, + part_eids1, + part_src_ids1, + part_dst_ids1, + src_ntype, + dst_ntype, + etype, + ) + + +def _verify_constructed_id_gb(part_sizes, gpb): + """ + verify the part id of each node by constructed nids. 
+ check list: + make sure each node' part id and its type are corect + """ + node_map = [] + edge_map = [] + for part_i, (num_nodes, num_edges) in enumerate(part_sizes): + node_map.append(np.ones(num_nodes) * part_i) + edge_map.append(np.ones(num_edges) * part_i) + node_map = np.concatenate(node_map) + edge_map = np.concatenate(edge_map) + nid2pid = gpb.nid2partid(F.arange(0, len(node_map))) + assert F.dtype(nid2pid) in (F.int32, F.int64) + assert np.all(F.asnumpy(nid2pid) == node_map) + eid2pid = gpb.eid2partid(F.arange(0, len(edge_map))) + assert F.dtype(eid2pid) in (F.int32, F.int64) + assert np.all(F.asnumpy(eid2pid) == edge_map) + + +def _verify_IDs_gb( + g, + part_g, + part_id, + gpb, + part_sizes, + orig_nids, + orig_eids, + store_inner_node, + store_inner_edge, + store_eids, + is_homo, +): + # verify local id and mapping id + _verify_local_and_map_id_gb( + part_g, + part_id, + gpb, + store_inner_node, + store_inner_edge, + store_eids, + ) + + # Verify the mapping between the reshuffled IDs and the original IDs. 
+ ( + part_src_ids, + part_dst_ids, + src_ntype_ids, + part_src_ids, + dst_ntype_ids, + ) = _verify_node_type_ID_gb(part_g, gpb) + + if store_eids: + _verify_orig_IDs_gb( + part_g, + gpb, + g, + part_src_ids=part_src_ids, + part_dst_ids=part_dst_ids, + src_ntype_ids=src_ntype_ids, + dst_ntype_ids=dst_ntype_ids, + orig_nids=orig_nids, + orig_eids=orig_eids, + is_homo=is_homo, + ) + _verify_constructed_id_gb(part_sizes, gpb) + + +def _collect_data_gb( + parts, + part_g, + gpbs, + gpb, + tot_node_feats, + node_feats, + tot_edge_feats, + edge_feats, + shuffled_labels, + shuffled_edata, + test_ntype, + test_etype, +): + if test_ntype != None: + shuffled_labels.append(node_feats[test_ntype + "/label"]) + shuffled_edata.append( + edge_feats[_etype_tuple_to_str(test_etype) + "/count"] + ) + else: + shuffled_labels.append(node_feats["_N/labels"]) + shuffled_edata.append(edge_feats["_N:_E:_N/feats"]) + parts.append(part_g) + gpbs.append(gpb) + tot_node_feats.append(node_feats) + tot_edge_feats.append(edge_feats) + + +def _verify_node_feats(g, part, gpb, orig_nids, node_feats, is_homo=False): + for ntype in g.ntypes: + ndata = ( + part.node_attributes + if isinstance(part, gb.FusedCSCSamplingGraph) + else part.ndata + ) + ntype_id = g.get_ntype_id(ntype) + inner_node_mask = _get_inner_node_mask( + part, + ntype_id, + (gpb if isinstance(part, gb.FusedCSCSamplingGraph) else None), + ) + inner_nids = F.boolean_mask(ndata[dgl.NID], inner_node_mask) + ntype_ids, inner_type_nids = gpb.map_to_per_ntype(inner_nids) + partid = gpb.nid2partid(inner_type_nids, ntype) + if is_homo: + assert np.all(F.asnumpy(ntype_ids) == ntype_id) + assert np.all(F.asnumpy(partid) == gpb.partid) + + if is_homo: + orig_id = orig_nids[inner_type_nids] + else: + orig_id = orig_nids[ntype][inner_type_nids] + local_nids = gpb.nid2localnid(inner_type_nids, gpb.partid, ntype) + + for name in g.nodes[ntype].data: + if name in [dgl.NID, "inner_node"]: + continue + true_feats = 
F.gather_row(g.nodes[ntype].data[name], orig_id) + ndata = F.gather_row(node_feats[ntype + "/" + name], local_nids) + assert np.all(F.asnumpy(ndata == true_feats)) + + +def _verify_edge_feats(g, part, gpb, orig_eids, edge_feats, is_homo=False): + for etype in g.canonical_etypes: + edata = ( + part.edge_attributes + if isinstance(part, gb.FusedCSCSamplingGraph) + else part.edata + ) + etype_id = g.get_etype_id(etype) + inner_edge_mask = _get_inner_edge_mask(part, etype_id) + inner_eids = F.boolean_mask(edata[dgl.EID], inner_edge_mask) + etype_ids, inner_type_eids = gpb.map_to_per_etype(inner_eids) + partid = gpb.eid2partid(inner_type_eids, etype) + assert np.all(F.asnumpy(etype_ids) == etype_id) + assert np.all(F.asnumpy(partid) == gpb.partid) + + if is_homo: + orig_id = orig_eids[inner_type_eids] + else: + orig_id = orig_eids[etype][inner_type_eids] + local_eids = gpb.eid2localeid(inner_type_eids, gpb.partid, etype) + + for name in g.edges[etype].data: + if name in [dgl.EID, "inner_edge"]: + continue + true_feats = F.gather_row(g.edges[etype].data[name], orig_id) + edata = F.gather_row( + edge_feats[_etype_tuple_to_str(etype) + "/" + name], + local_eids, + ) + assert np.all(F.asnumpy(edata == true_feats)) + + +def _verify_shuffled_labels_gb( + g, + shuffled_labels, + shuffled_edata, + orig_nids, + orig_eids, + test_ntype=None, + test_etype=None, +): + """ + check list: + make sure node data are correct. + make sure edge data are correct. 
+ """ + shuffled_labels = F.asnumpy(F.cat(shuffled_labels, 0)) + shuffled_edata = F.asnumpy(F.cat(shuffled_edata, 0)) + orig_labels = np.zeros(shuffled_labels.shape, dtype=shuffled_labels.dtype) + orig_edata = np.zeros(shuffled_edata.shape, dtype=shuffled_edata.dtype) + + orig_nid = orig_nids if test_ntype is None else orig_nids[test_ntype] + orig_eid = orig_eids if test_etype is None else orig_eids[test_etype] + nlabel = ( + g.ndata["labels"] + if test_ntype is None + else g.nodes[test_ntype].data["label"] + ) + edata = ( + g.edata["feats"] + if test_etype is None + else g.edges[test_etype].data["count"] + ) + + orig_labels[F.asnumpy(orig_nid)] = shuffled_labels + orig_edata[F.asnumpy(orig_eid)] = shuffled_edata + assert np.all(orig_labels == F.asnumpy(nlabel)) + assert np.all(orig_edata == F.asnumpy(edata)) + + +def verify_graph_feats_gb( + g, + gpbs, + parts, + tot_node_feats, + tot_edge_feats, + orig_nids, + orig_eids, + shuffled_labels, + shuffled_edata, + test_ntype, + test_etype, + store_inner_node=False, + store_inner_edge=False, + store_eids=False, + is_homo=False, +): + """ + check list: + make sure the feats of nodes and edges are correct + """ + for part_id in range(len(parts)): + part = parts[part_id] + gpb = gpbs[part_id] + node_feats = tot_node_feats[part_id] + edge_feats = tot_edge_feats[part_id] + if store_inner_node: + _verify_node_feats( + g, + part, + gpb, + orig_nids, + node_feats, + is_homo=is_homo, + ) + if store_inner_edge and store_eids: + _verify_edge_feats( + g, + part, + gpb, + orig_eids, + edge_feats, + is_homo=is_homo, + ) + + _verify_shuffled_labels_gb( + g, + shuffled_labels, + shuffled_edata, + orig_nids, + orig_eids, + test_ntype, + test_etype, + ) + + +def _verify_graphbolt_attributes( + parts, store_inner_node, store_inner_edge, store_eids +): + """ + check list: + make sure arguments work. 
+ """ + for part in parts: + assert store_inner_edge == ("inner_edge" in part.edge_attributes) + assert store_inner_node == ("inner_node" in part.node_attributes) + assert store_eids == (dgl.EID in part.edge_attributes) + + +def _verify_graphbolt_part( + g, + test_dir, + orig_nids, + orig_eids, + graph_name, + num_parts, + store_inner_node, + store_inner_edge, + store_eids, + part_config=None, + test_ntype=None, + test_etype=None, + is_homo=False, +): + """ + check list: + _verify_metadata_gb: + data type, ID's order and ID's number of edges and nodes + _verify_IDs_gb: + local id, mapping id,node type id, orig edge, hetero ntype id + verify_graph_feats_gb: + nodes and edges' feats + _verify_graphbolt_attributes: + arguments + """ + parts = [] + tot_node_feats = [] + tot_edge_feats = [] + shuffled_labels = [] + shuffled_edata = [] + part_sizes = [] + gpbs = [] + if part_config is None: + part_config = os.path.join(test_dir, f"{graph_name}.json") + # test each part + for part_id in range(num_parts): + part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition( + part_config, part_id, load_feats=True, use_graphbolt=True + ) + # verify metadata + _verify_metadata_gb( + gpb, + g, + num_parts, + part_id, + part_sizes, + ) + + # verify eid and nid + _verify_IDs_gb( + g, + part_g, + part_id, + gpb, + part_sizes, + orig_nids, + orig_eids, + store_inner_node, + store_inner_edge, + store_eids, + is_homo, + ) + + # collect shuffled data and parts + _collect_data_gb( + parts, + part_g, + gpbs, + gpb, + tot_node_feats, + node_feats, + tot_edge_feats, + edge_feats, + shuffled_labels, + shuffled_edata, + test_ntype, + test_etype, + ) + + # verify graph feats + verify_graph_feats_gb( + g, + gpbs, + parts, + tot_node_feats, + tot_edge_feats, + orig_nids, + orig_eids, + shuffled_labels=shuffled_labels, + shuffled_edata=shuffled_edata, + test_ntype=test_ntype, + test_etype=test_etype, + store_inner_node=store_inner_node, + store_inner_edge=store_inner_edge, + store_eids=store_eids, 
+ is_homo=is_homo, + ) + + _verify_graphbolt_attributes( + parts, store_inner_node, store_inner_edge, store_eids + ) + + return parts + + +def _test_pipeline_graphbolt( + num_chunks, + num_parts, + world_size, + graph_formats=None, + data_fmt="numpy", + num_chunks_nodes=None, + num_chunks_edges=None, + num_chunks_node_data=None, + num_chunks_edge_data=None, + use_verify_partitions=False, + store_eids=True, + store_inner_edge=True, + store_inner_node=True, +): + if num_parts % world_size != 0: + # num_parts should be a multiple of world_size + return + + with tempfile.TemporaryDirectory() as root_dir: + g = create_chunked_dataset( + root_dir, + num_chunks, + data_fmt=data_fmt, + num_chunks_nodes=num_chunks_nodes, + num_chunks_edges=num_chunks_edges, + num_chunks_node_data=num_chunks_node_data, + num_chunks_edge_data=num_chunks_edge_data, + ) + graph_name = "test" + test_ntype = "paper" + test_etype = ("paper", "cites", "paper") + + # Step1: graph partition + in_dir = os.path.join(root_dir, "chunked-data") + output_dir = os.path.join(root_dir, "parted_data") + os.system( + "/opt/conda/envs/pytorch/bin/python tools/partition_algo/random_partition.py " + "--in_dir {} --out_dir {} --num_partitions {}".format( + in_dir, output_dir, num_parts + ) + ) + for ntype in ["author", "institution", "paper"]: + fname = os.path.join(output_dir, "{}.txt".format(ntype)) + with open(fname, "r") as f: + header = f.readline().rstrip() + assert isinstance(int(header), int) + + # Step2: data dispatch + partition_dir = os.path.join(root_dir, "parted_data") + out_dir = os.path.join(root_dir, "partitioned") + ip_config = os.path.join(root_dir, "ip_config.txt") + with open(ip_config, "w") as f: + for i in range(world_size): + f.write(f"127.0.0.{i + 1}\n") + + cmd = "/opt/conda/envs/pytorch/bin/python tools/dispatch_data.py" + cmd += f" --in-dir {in_dir}" + cmd += f" --partitions-dir {partition_dir}" + cmd += f" --out-dir {out_dir}" + cmd += f" --ip-config {ip_config}" + cmd += " --ssh-port 
22" + cmd += " --process-group-timeout 60" + cmd += " --save-orig-nids" + cmd += " --save-orig-eids" + cmd += " --use-graphbolt" + cmd += f" --graph-formats {graph_formats}" if graph_formats else "" + + if store_eids: + cmd += " --store-eids" + if store_inner_edge: + cmd += " --store-inner-edge" + if store_inner_node: + cmd += " --store-inner-node" + os.system(cmd) + + # check if verify_partitions.py is used for validation. + if use_verify_partitions: + cmd = ( + "/opt/conda/envs/pytorch/bin/python tools/verify_partitions.py " + ) + cmd += f" --orig-dataset-dir {in_dir}" + cmd += f" --part-graph {out_dir}" + cmd += f" --partitions-dir {output_dir}" + os.system(cmd) + return + + # read original node/edge IDs + def read_orig_ids(fname): + orig_ids = {} + for i in range(num_parts): + ids_path = os.path.join(out_dir, f"part{i}", fname) + part_ids = load_tensors(ids_path) + for type, data in part_ids.items(): + if type not in orig_ids: + orig_ids[type] = data + else: + orig_ids[type] = torch.cat((orig_ids[type], data)) + return orig_ids + + orig_nids = read_orig_ids("orig_nids.dgl") + orig_eids_str = read_orig_ids("orig_eids.dgl") + + orig_eids = {} + # transmit etype from string to tuple. 
+ for etype, eids in orig_eids_str.items(): + orig_eids[_etype_str_to_tuple(etype)] = eids + + # load partitions and verify + part_config = os.path.join(out_dir, "metadata.json") + parts = _verify_graphbolt_part( + g, + root_dir, + orig_nids, + orig_eids, + graph_name, + num_parts, + store_inner_node, + store_inner_edge, + store_eids, + test_ntype=test_ntype, + test_etype=test_etype, + part_config=part_config, + is_homo=False, + ) + + # for i in range(num_parts): + # part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition( + # part_config, i, use_graphbolt=True + # ) + # verify_partition_data_types(part_g, use_graphbolt=True) + # verify_graph_feats( + # g, + # gpb, + # part_g, + # node_feats, + # edge_feats, + # orig_nids, + # orig_eids, + # use_graphbolt=True, + # ) + + +@pytest.mark.parametrize( + "num_chunks, num_parts, world_size", + [[4, 4, 4], [8, 4, 2], [8, 4, 4], [9, 6, 3], [11, 11, 1], [11, 4, 1]], +) +def test_pipeline_basics(num_chunks, num_parts, world_size): + _test_pipeline_graphbolt(num_chunks, num_parts, world_size) + _test_pipeline_graphbolt( + num_chunks, num_parts, world_size, use_verify_partitions=False + ) + + +@pytest.mark.parametrize( + "num_chunks, " + "num_parts, " + "world_size, " + "num_chunks_node_data, " + "num_chunks_edge_data", + [ + # Test cases where no. of chunks more than + # no. of partitions + [8, 4, 4, 8, 8], + [8, 4, 2, 8, 8], + [9, 7, 5, 9, 9], + [8, 8, 4, 8, 8], + # Test cases where no. of chunks smaller + # than no. 
of partitions + [7, 8, 4, 7, 7], + [1, 8, 4, 1, 1], + [1, 4, 4, 1, 1], + [3, 4, 4, 3, 3], + [1, 4, 2, 1, 1], + [3, 4, 2, 3, 3], + [1, 5, 3, 1, 1], + ], +) +def test_pipeline_arbitrary_chunks( + num_chunks, + num_parts, + world_size, + num_chunks_node_data, + num_chunks_edge_data, +): + + _test_pipeline_graphbolt( + num_chunks, + num_parts, + world_size, + num_chunks_node_data=num_chunks_node_data, + num_chunks_edge_data=num_chunks_edge_data, + ) + + +@pytest.mark.parametrize("data_fmt", ["numpy", "parquet"]) +def test_pipeline_feature_format(data_fmt): + _test_pipeline_graphbolt(4, 4, 4, data_fmt=data_fmt) diff --git a/tools/dispatch_data.py b/tools/dispatch_data.py index 3cf1d0fbf224..27b9e9f61928 100644 --- a/tools/dispatch_data.py +++ b/tools/dispatch_data.py @@ -74,7 +74,10 @@ def submit_jobs(args) -> str: argslist += "--process-group-timeout {} ".format(args.process_group_timeout) argslist += "--log-level {} ".format(args.log_level) argslist += "--save-orig-nids " if args.save_orig_nids else "" - argslist += "--save-orig-eids " if args.save_orig_eids else "" + argslist += "--use-graphbolt " if args.use_graphbolt else "" + argslist += "--store-eids " if args.store_eids else "" + argslist += "--store-inner-node " if args.store_inner_node else "" + argslist += "--store-inner-edge " if args.store_inner_edge else "" argslist += ( f"--graph-formats {args.graph_formats} " if args.graph_formats else "" ) @@ -159,6 +162,30 @@ def main(): action="store_true", help="Save original edge IDs into files", ) + parser.add_argument( + "--use-graphbolt", + action="store_true", + help="Use GraphBolt for distributed partition.", + ) + parser.add_argument( + "--store-inner-node", + action="store_true", + default=False, + help="Store inner nodes.", + ) + + parser.add_argument( + "--store-inner-edge", + action="store_true", + default=False, + help="Store inner edges.", + ) + parser.add_argument( + "--store-eids", + action="store_true", + default=False, + help="Store edge IDs.", + ) 
parser.add_argument( "--graph-formats", type=str, @@ -170,6 +197,7 @@ def main(): ) args, _ = parser.parse_known_args() + assert args.use_graphbolt==True fmt = "%(asctime)s %(levelname)s %(message)s" logging.basicConfig( diff --git a/tools/distpartitioning/data_proc_pipeline.py b/tools/distpartitioning/data_proc_pipeline.py index 8f34c56f851c..62a19f7a4f74 100644 --- a/tools/distpartitioning/data_proc_pipeline.py +++ b/tools/distpartitioning/data_proc_pipeline.py @@ -125,7 +125,7 @@ def log_params(params): help="Save partitions in specified formats.", ) params = parser.parse_args() - + assert params.use_graphbolt is True # invoke the pipeline function numeric_level = getattr(logging, params.log_level.upper(), None) logging.basicConfig( diff --git a/tools/distpartitioning/data_shuffle.py b/tools/distpartitioning/data_shuffle.py index a7abcc75f648..20d6d533fd96 100644 --- a/tools/distpartitioning/data_shuffle.py +++ b/tools/distpartitioning/data_shuffle.py @@ -1334,9 +1334,12 @@ def prepare_local_data(src_data, local_part_id): schema_map[constants.STR_NUM_NODES_PER_TYPE], ), edge_typecounts, - params.save_orig_nids, - params.save_orig_eids, - params.use_graphbolt, + return_orig_nids=params.save_orig_nids, + return_orig_eids=params.save_orig_eids, + use_graphbolt=params.use_graphbolt, + store_inner_node=params.store_inner_node, + store_inner_edge=params.store_inner_edge, + store_eids=params.store_eids, ) sort_etypes = len(etypes_map) > 1 local_node_features = prepare_local_data( From 3804841abcd50f0dac778558e3aee6861ba26efb Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 10 Sep 2024 06:38:52 +0000 Subject: [PATCH 21/37] [distGB]change test_dist_partition --- python/dgl/distributed/partition.py | 524 +++++++++++++------ tests/tools/test_dist_partition_graphbolt.py | 255 +++++++-- tools/distpartitioning/convert_partition.py | 110 ++-- tools/distpartitioning/data_proc_pipeline.py | 1 - tools/distpartitioning/utils.py | 5 +- 5 files changed, 632 insertions(+), 263 
deletions(-) diff --git a/python/dgl/distributed/partition.py b/python/dgl/distributed/partition.py index 07601fd5d2ca..331f4a29a335 100644 --- a/python/dgl/distributed/partition.py +++ b/python/dgl/distributed/partition.py @@ -88,24 +88,26 @@ def _dump_part_config(part_config, part_metadata): json.dump(part_metadata, outfile, sort_keys=False, indent=4) -def _save_graphs(filename, g_list, formats=None, sort_etypes=False): +def _process_partitions(g, formats=None, sort_etypes=False): """Preprocess partitions before saving: 1. format data types. 2. sort csc/csr by tag. """ - for g in g_list: - for k, dtype in RESERVED_FIELD_DTYPE.items(): - if k in g.ndata: - g.ndata[k] = F.astype(g.ndata[k], dtype) - if k in g.edata: - g.edata[k] = F.astype(g.edata[k], dtype) - for g in g_list: - if (not sort_etypes) or (formats is None): - continue + for k, dtype in RESERVED_FIELD_DTYPE.items(): + if k in g.ndata: + g.ndata[k] = F.astype(g.ndata[k], dtype) + if k in g.edata: + g.edata[k] = F.astype(g.edata[k], dtype) + + if (sort_etypes) and (formats is not None): if "csr" in formats: g = sort_csr_by_tag(g, tag=g.edata[ETYPE], tag_type="edge") if "csc" in formats: g = sort_csc_by_tag(g, tag=g.edata[ETYPE], tag_type="edge") + return g + + +def _save_dgl_graphs(filename, g_list, formats=None): save_graphs(filename, g_list, formats=formats) @@ -332,9 +334,10 @@ def load_partition(part_config, part_id, load_feats=True, use_graphbolt=False): "part-{}".format(part_id) in part_metadata ), "part-{} does not exist".format(part_id) part_files = part_metadata["part-{}".format(part_id)] - part_graph_field = "part_graph" if use_graphbolt: part_graph_field = "part_graph_graphbolt" + else: + part_graph_field = "part_graph" assert ( part_graph_field in part_files ), f"the partition does not contain graph structure: {part_graph_field}" @@ -461,7 +464,7 @@ def load_partition_feats( return node_feats, edge_feats -def load_partition_book(part_config, part_id): +def load_partition_book(part_config, 
part_id, part_metadata=None): """Load a graph partition book from the partition config file. Parameters @@ -470,6 +473,8 @@ def load_partition_book(part_config, part_id): The path of the partition config file. part_id : int The partition ID. + part_metadata : dict + The meta data of partition. Returns ------- @@ -482,7 +487,8 @@ def load_partition_book(part_config, part_id): dict The edge types """ - part_metadata = _load_part_config(part_config) + if part_metadata is None: + part_metadata = _load_part_config(part_config) assert "num_parts" in part_metadata, "num_parts does not exist." assert ( part_metadata["num_parts"] > part_id @@ -666,6 +672,38 @@ def _set_trainer_ids(g, sim_g, node_parts): g.edges[c_etype].data["trainer_id"] = trainer_id +def _partition_to_graphbolt( + parts, + part_i, + part_config, + part_metadata, + *, + store_eids=True, + store_inner_node=False, + store_inner_edge=False, + graph_formats=None, +): + gpb, _, ntypes, etypes = load_partition_book( + part_config=part_config, part_id=part_i, part_metadata=part_metadata + ) + graph = parts[part_i] + csc_graph = gb_convert_single_dgl_partition( + ntypes=ntypes, + etypes=etypes, + gpb=gpb, + part_meta=part_metadata, + graph=graph, + store_eids=store_eids, + store_inner_edge=store_inner_edge, + store_inner_node=store_inner_node, + graph_formats=graph_formats, + ) + rel_path_result = _save_graph_gb( + part_config=part_config, part_id=part_i, csc_graph=csc_graph + ) + part_metadata[f"part-{part_i}"]["part_graph_graphbolt"] = rel_path_result + + def partition_graph( g, graph_name, @@ -1200,6 +1238,7 @@ def get_homogeneous(g, balance_ntypes): "ntypes": ntypes, "etypes": etypes, } + part_config = os.path.join(out_path, graph_name + ".json") for part_id in range(num_parts): part = parts[part_id] @@ -1322,30 +1361,52 @@ def get_homogeneous(g, balance_ntypes): part_dir = os.path.join(out_path, "part" + str(part_id)) node_feat_file = os.path.join(part_dir, "node_feat.dgl") edge_feat_file = 
os.path.join(part_dir, "edge_feat.dgl") - part_graph_file = os.path.join(part_dir, "graph.dgl") - part_metadata["part-{}".format(part_id)] = { - "node_feats": os.path.relpath(node_feat_file, out_path), - "edge_feats": os.path.relpath(edge_feat_file, out_path), - "part_graph": os.path.relpath(part_graph_file, out_path), - } + os.makedirs(part_dir, mode=0o775, exist_ok=True) save_tensors(node_feat_file, node_feats) save_tensors(edge_feat_file, edge_feats) + part_metadata["part-{}".format(part_id)] = { + "node_feats": os.path.relpath(node_feat_file, out_path), + "edge_feats": os.path.relpath(edge_feat_file, out_path), + } sort_etypes = len(g.etypes) > 1 - _save_graphs( - part_graph_file, - [part], - formats=graph_formats, - sort_etypes=sort_etypes, - ) - print( - "Save partitions: {:.3f} seconds, peak memory: {:.3f} GB".format( - time.time() - start, get_peak_mem() - ) - ) + part = _process_partitions(part, graph_formats, sort_etypes) + + # transmit to graphbolt and save graph + if use_graphbolt: + # save FusedCSCSamplingGraph + kwargs["graph_formats"] = graph_formats + n_jobs = kwargs.pop("n_jobs", 1) + mp_ctx = mp.get_context("spawn") + with concurrent.futures.ProcessPoolExecutor( + max_workers=min(num_parts, n_jobs), + mp_context=mp_ctx, + ) as executor: + for part_id in range(num_parts): + executor.submit( + _partition_to_graphbolt( + part_i=part_id, + part_config=part_config, + part_metadata=part_metadata, + parts=parts, + **kwargs, + ) + ) + else: + for part_id, part in parts.items(): + part_dir = os.path.join(out_path, "part" + str(part_id)) + part_graph_file = os.path.join(part_dir, "graph.dgl") + part_metadata["part-{}".format(part_id)][ + "part_graph" + ] = os.path.relpath(part_graph_file, out_path) + # save DGLGraph + _save_dgl_graphs( + part_graph_file, + [part], + formats=graph_formats, + ) - part_config = os.path.join(out_path, graph_name + ".json") _dump_part_config(part_config, part_metadata) num_cuts = sim_g.num_edges() - tot_num_inner_edges @@ 
-1357,12 +1418,11 @@ def get_homogeneous(g, balance_ntypes): ) ) - if use_graphbolt: - kwargs["graph_formats"] = graph_formats - dgl_partition_to_graphbolt( - part_config, - **kwargs, + print( + "Save partitions: {:.3f} seconds, peak memory: {:.3f} GB".format( + time.time() - start, get_peak_mem() ) + ) if return_mapping: return orig_nids, orig_eids @@ -1410,20 +1470,138 @@ def init_type_per_edge(graph, gpb): return etype_ids +def _load_part(part_config, part_id, parts=None): + """load parts from variable or dist.""" + if parts is None: + graph, _, _, _, _, _, _ = load_partition( + part_config, part_id, load_feats=False + ) + else: + graph = parts[part_id] + return graph + + +def _save_graph_gb(part_config, part_id, csc_graph): + csc_graph_save_dir = os.path.join( + os.path.dirname(part_config), + f"part{part_id}", + ) + csc_graph_path = os.path.join( + csc_graph_save_dir, "fused_csc_sampling_graph.pt" + ) + torch.save(csc_graph, csc_graph_path) + + return os.path.relpath(csc_graph_path, os.path.dirname(part_config)) + + +def cast_various_to_minimum_dtype_gb( + graph, + part_meta, + num_parts, + indptr, + indices, + type_per_edge, + etypes, + ntypes, + node_attributes, + edge_attributes, +): + """Cast various data to minimum dtype.""" + # Cast 1: indptr. + indptr = _cast_to_minimum_dtype(graph.num_edges(), indptr) + # Cast 2: indices. + indices = _cast_to_minimum_dtype(graph.num_nodes(), indices) + # Cast 3: type_per_edge. + type_per_edge = _cast_to_minimum_dtype( + len(etypes), type_per_edge, field=ETYPE + ) + # Cast 4: node/edge_attributes. 
+ predicates = { + NID: part_meta["num_nodes"], + "part_id": num_parts, + NTYPE: len(ntypes), + EID: part_meta["num_edges"], + ETYPE: len(etypes), + DGL2GB_EID: part_meta["num_edges"], + GB_DST_ID: part_meta["num_nodes"], + } + for attributes in [node_attributes, edge_attributes]: + for key in attributes: + if key not in predicates: + continue + attributes[key] = _cast_to_minimum_dtype( + predicates[key], attributes[key], field=key + ) + return indptr, indices, type_per_edge + + +def _create_attributes_gb( + graph, + gpb, + edge_ids, + is_homo, + store_inner_node, + store_inner_edge, + store_eids, + debug_mode, +): + # Save node attributes. Detailed attributes are shown below. + # DGL_GB\Attributes dgl.NID("_ID") dgl.NTYPE("_TYPE") "inner_node" "part_id" + # DGL_Homograph ✅ 🚫 ✅ ✅ + # GB_Homograph ✅ 🚫 optional 🚫 + # DGL_Heterograph ✅ ✅ ✅ ✅ + # GB_Heterograph ✅ 🚫 optional 🚫 + required_node_attrs = [NID] + if store_inner_node: + required_node_attrs.append("inner_node") + if debug_mode: + required_node_attrs = list(graph.ndata.keys()) + node_attributes = {attr: graph.ndata[attr] for attr in required_node_attrs} + + # Save edge attributes. Detailed attributes are shown below. 
+ # DGL_GB\Attributes dgl.EID("_ID") dgl.ETYPE("_TYPE") "inner_edge" + # DGL_Homograph ✅ 🚫 ✅ + # GB_Homograph optional 🚫 optional + # DGL_Heterograph ✅ ✅ ✅ + # GB_Heterograph optional ✅ optional + type_per_edge = None + if not is_homo: + type_per_edge = init_type_per_edge(graph, gpb)[edge_ids] + type_per_edge = type_per_edge.to(RESERVED_FIELD_DTYPE[ETYPE]) + required_edge_attrs = [] + if store_eids: + required_edge_attrs.append(EID) + if store_inner_edge: + required_edge_attrs.append("inner_edge") + if debug_mode: + required_edge_attrs = list(graph.edata.keys()) + edge_attributes = { + attr: graph.edata[attr][edge_ids] for attr in required_edge_attrs + } + return node_attributes, edge_attributes, type_per_edge + + def gb_convert_single_dgl_partition( - part_id, + ntypes, + etypes, + gpb, graph_formats, - part_config, store_eids, store_inner_node, store_inner_edge, + part_meta, + graph, ): """Converts a single DGL partition to GraphBolt. Parameters ---------- - part_id : int - The numerical ID of the partition to convert. + node types : dict + The node types + edge types : dict + The edge types + gpb : GraphPartitionBook + The global partition information. graph_formats : str or list[str], optional Save partitions in specified formats. It could be any combination of `coo`, `csc`. As `csc` format is mandatory for `FusedCSCSamplingGraph`, @@ -1437,6 +1615,10 @@ def gb_convert_single_dgl_partition( Whether to store inner node mask in the new graph. Default: False. store_inner_edge : bool, optional Whether to store inner edge mask in the new graph. Default: False. + part_meta : dict + Contain the meta data of the partition. + graph : DGLGraph + The graph to be converted to graphbolt graph. """ debug_mode = "DGL_DIST_DEBUG" in os.environ if debug_mode: @@ -1444,14 +1626,8 @@ def gb_convert_single_dgl_partition( "Running in debug mode which means all attributes of DGL partitions" " will be saved to the new format." 
) - - part_meta = _load_part_config(part_config) num_parts = part_meta["num_parts"] - graph, _, _, gpb, _, _, _ = load_partition( - part_config, part_id, load_feats=False - ) - _, _, ntypes, etypes = load_partition_book(part_config, part_id) is_homo = is_homogeneous(ntypes, etypes) node_type_to_id = ( None if is_homo else {ntype: ntid for ntid, ntype in enumerate(ntypes)} @@ -1466,39 +1642,16 @@ def gb_convert_single_dgl_partition( # Obtain CSC indtpr and indices. indptr, indices, edge_ids = graph.adj_tensors("csc") - # Save node attributes. Detailed attributes are shown below. - # DGL_GB\Attributes dgl.NID("_ID") dgl.NTYPE("_TYPE") "inner_node" "part_id" - # DGL_Homograph ✅ 🚫 ✅ ✅ - # GB_Homograph ✅ 🚫 optional 🚫 - # DGL_Heterograph ✅ ✅ ✅ ✅ - # GB_Heterograph ✅ 🚫 optional 🚫 - required_node_attrs = [NID] - if store_inner_node: - required_node_attrs.append("inner_node") - if debug_mode: - required_node_attrs = list(graph.ndata.keys()) - node_attributes = {attr: graph.ndata[attr] for attr in required_node_attrs} - - # Save edge attributes. Detailed attributes are shown below. 
- # DGL_GB\Attributes dgl.EID("_ID") dgl.ETYPE("_TYPE") "inner_edge" - # DGL_Homograph ✅ 🚫 ✅ - # GB_Homograph optional 🚫 optional - # DGL_Heterograph ✅ ✅ ✅ - # GB_Heterograph optional ✅ optional - type_per_edge = None - if not is_homo: - type_per_edge = init_type_per_edge(graph, gpb)[edge_ids] - type_per_edge = type_per_edge.to(RESERVED_FIELD_DTYPE[ETYPE]) - required_edge_attrs = [] - if store_eids: - required_edge_attrs.append(EID) - if store_inner_edge: - required_edge_attrs.append("inner_edge") - if debug_mode: - required_edge_attrs = list(graph.edata.keys()) - edge_attributes = { - attr: graph.edata[attr][edge_ids] for attr in required_edge_attrs - } + node_attributes, edge_attributes, type_per_edge = _create_attributes_gb( + graph, + gpb, + edge_ids, + is_homo, + store_inner_node, + store_inner_edge, + store_eids, + debug_mode, + ) # When converting DGLGraph to FusedCSCSamplingGraph, edge IDs are # re-ordered(actually FusedCSCSamplingGraph does not have edge IDs # in nature). So we need to save such re-order info for any @@ -1520,32 +1673,18 @@ def gb_convert_single_dgl_partition( indptr, dtype=indices.dtype ) - # Cast various data to minimum dtype. - # Cast 1: indptr. - indptr = _cast_to_minimum_dtype(graph.num_edges(), indptr) - # Cast 2: indices. - indices = _cast_to_minimum_dtype(graph.num_nodes(), indices) - # Cast 3: type_per_edge. - type_per_edge = _cast_to_minimum_dtype( - len(etypes), type_per_edge, field=ETYPE + indptr, indices, type_per_edge = cast_various_to_minimum_dtype_gb( + graph, + part_meta, + num_parts, + indptr, + indices, + type_per_edge, + etypes, + ntypes, + node_attributes, + edge_attributes, ) - # Cast 4: node/edge_attributes. 
- predicates = { - NID: part_meta["num_nodes"], - "part_id": num_parts, - NTYPE: len(ntypes), - EID: part_meta["num_edges"], - ETYPE: len(etypes), - DGL2GB_EID: part_meta["num_edges"], - GB_DST_ID: part_meta["num_nodes"], - } - for attributes in [node_attributes, edge_attributes]: - for key in attributes: - if key not in predicates: - continue - attributes[key] = _cast_to_minimum_dtype( - predicates[key], attributes[key], field=key - ) csc_graph = gb.fused_csc_sampling_graph( indptr, @@ -1557,17 +1696,125 @@ def gb_convert_single_dgl_partition( node_type_to_id=node_type_to_id, edge_type_to_id=edge_type_to_id, ) - orig_graph_path = os.path.join( - os.path.dirname(part_config), - part_meta[f"part-{part_id}"]["part_graph"], + return csc_graph + + +def convert_partition_to_graphbolt_multi_process( + part_config, + part_id, + graph_formats, + store_eids, + store_inner_node, + store_inner_edge, +): + """ + Convert signle partition to graphbolt, which is used for multiple process. + Parameters + ---------- + part_config : str + The path of the partition config file. + part_id : int + The partition ID. + graph_formats : str or list[str], optional + Save partitions in specified formats. It could be any combination of + `coo`, `csc`. As `csc` format is mandatory for `FusedCSCSamplingGraph`, + it is not necessary to specify this argument. It's mainly for + specifying `coo` format to save edge ID mapping and destination node + IDs. If not specified, whether to save `coo` format is determined by + the availability of the format in DGL partitions. Default: None. + store_eids : bool, optional + Whether to store edge IDs in the new graph. Default: True. + store_inner_node : bool, optional + Whether to store inner node mask in the new graph. Default: False. + store_inner_edge : bool, optional + Whether to store inner edge mask in the new graph. Default: False. + + Returns + ------- + str + The path csc_graph to save. 
+ """ + gpb, _, ntypes, etypes = load_partition_book( + part_config=part_config, part_id=part_id ) - csc_graph_path = os.path.join( - os.path.dirname(orig_graph_path), "fused_csc_sampling_graph.pt" + part = _load_part(part_config, part_id) + part_meta = copy.deepcopy(_load_part_config(part_config)) + csc_graph = gb_convert_single_dgl_partition( + graph=part, + ntypes=ntypes, + etypes=etypes, + gpb=gpb, + part_meta=part_meta, + graph_formats=graph_formats, + store_eids=store_eids, + store_inner_node=store_inner_node, + store_inner_edge=store_inner_edge, ) - torch.save(csc_graph, csc_graph_path) + rel_path = _save_graph_gb(part_config, part_id, csc_graph) + return rel_path - return os.path.relpath(csc_graph_path, os.path.dirname(part_config)) - # Update graph path. + +def _convert_partition_to_graphbolt( + graph_formats, + part_config, + store_eids, + store_inner_node, + store_inner_edge, + n_jobs, + num_parts, +): + # [Rui] DGL partitions are always saved as homogeneous graphs even though + # the original graph is heterogeneous. But heterogeneous information like + # node/edge types are saved as node/edge data alongside with partitions. + # What needs more attention is that due to the existence of HALO nodes in + # each partition, the local node IDs are not sorted according to the node + # types. So we fail to assign ``node_type_offset`` as required by GraphBolt. + # But this is not a problem since such information is not used in sampling. + # We can simply pass None to it. + + # Iterate over partitions. 
+ convert_with_format = partial( + convert_partition_to_graphbolt_multi_process, + part_config=part_config, + graph_formats=graph_formats, + store_eids=store_eids, + store_inner_node=store_inner_node, + store_inner_edge=store_inner_edge, + ) + # Need to create entirely new interpreters, because we call C++ downstream + # See https://docs.python.org/3.12/library/multiprocessing.html#contexts-and-start-methods + # and https://pybind11.readthedocs.io/en/stable/advanced/misc.html#global-interpreter-lock-gil + rel_path_results = [] + if n_jobs > 1 and num_parts > 1: + mp_ctx = mp.get_context("spawn") + with concurrent.futures.ProcessPoolExecutor( # pylint: disable=unexpected-keyword-arg + max_workers=min(num_parts, n_jobs), + mp_context=mp_ctx, + ) as executor: + for part_id in range(num_parts): + rel_path_results.append( + executor.submit(part_id=part_id).result() + ) + + else: + # If running single-threaded, avoid spawning new interpreter, which is slow + for part_id in range(num_parts): + rel_path = convert_with_format(part_id=part_id) + rel_path_results.append(rel_path) + part_meta = _load_part_config(part_config) + for part_id in range(num_parts): + # Update graph path. + part_meta[f"part-{part_id}"]["part_graph_graphbolt"] = rel_path_results[ + part_id + ] + + # Save dtype info into partition config. + # [TODO][Rui] Always use int64_t for node/edge IDs in GraphBolt. See more + # details in #7175. + part_meta["node_map_dtype"] = "int64" + part_meta["edge_map_dtype"] = "int64" + + return part_meta def dgl_partition_to_graphbolt( @@ -1616,59 +1863,14 @@ def dgl_partition_to_graphbolt( " will be saved to the new format." ) part_meta = _load_part_config(part_config) - new_part_meta = copy.deepcopy(part_meta) num_parts = part_meta["num_parts"] - - # [Rui] DGL partitions are always saved as homogeneous graphs even though - # the original graph is heterogeneous. But heterogeneous information like - # node/edge types are saved as node/edge data alongside with partitions. 
- # What needs more attention is that due to the existence of HALO nodes in - # each partition, the local node IDs are not sorted according to the node - # types. So we fail to assign ``node_type_offset`` as required by GraphBolt. - # But this is not a problem since such information is not used in sampling. - # We can simply pass None to it. - - # Iterate over partitions. - convert_with_format = partial( - gb_convert_single_dgl_partition, + part_meta = _convert_partition_to_graphbolt( graph_formats=graph_formats, part_config=part_config, store_eids=store_eids, store_inner_node=store_inner_node, store_inner_edge=store_inner_edge, + n_jobs=n_jobs, + num_parts=num_parts, ) - # Need to create entirely new interpreters, because we call C++ downstream - # See https://docs.python.org/3.12/library/multiprocessing.html#contexts-and-start-methods - # and https://pybind11.readthedocs.io/en/stable/advanced/misc.html#global-interpreter-lock-gil - rel_path_results = [] - if n_jobs > 1 and num_parts > 1: - mp_ctx = mp.get_context("spawn") - with concurrent.futures.ProcessPoolExecutor( # pylint: disable=unexpected-keyword-arg - max_workers=min(num_parts, n_jobs), - mp_context=mp_ctx, - ) as executor: - futures = [] - for part_id in range(num_parts): - futures.append(executor.submit(convert_with_format, part_id)) - - for part_id in range(num_parts): - rel_path_results.append(futures[part_id].result()) - else: - # If running single-threaded, avoid spawning new interpreter, which is slow - for part_id in range(num_parts): - rel_path_results.append(convert_with_format(part_id)) - - for part_id in range(num_parts): - # Update graph path. - new_part_meta[f"part-{part_id}"][ - "part_graph_graphbolt" - ] = rel_path_results[part_id] - - # Save dtype info into partition config. - # [TODO][Rui] Always use int64_t for node/edge IDs in GraphBolt. See more - # details in #7175. 
- new_part_meta["node_map_dtype"] = "int64" - new_part_meta["edge_map_dtype"] = "int64" - - _dump_part_config(part_config, new_part_meta) - print(f"Converted partitions to GraphBolt format into {part_config}") + _dump_part_config(part_config, part_meta) diff --git a/tests/tools/test_dist_partition_graphbolt.py b/tests/tools/test_dist_partition_graphbolt.py index ea596c67d6aa..b0d49343042e 100644 --- a/tests/tools/test_dist_partition_graphbolt.py +++ b/tests/tools/test_dist_partition_graphbolt.py @@ -644,6 +644,183 @@ def _verify_graphbolt_part( return parts +def _verify_hetero_graph_node_edge_num( + g, + parts, + store_inner_edge, + debug_mode, +): + """ + check list: + make sure edge type are correct. + make sure the number of nodes in each node type are correct. + make sure the number of nodes in each node type are correct. + """ + num_nodes = {ntype: 0 for ntype in g.ntypes} + num_edges = {etype: 0 for etype in g.canonical_etypes} + for part in parts: + edata = ( + part.edge_attributes + if isinstance(part, gb.FusedCSCSamplingGraph) + else part.edata + ) + if dgl.ETYPE in edata: + assert len(g.canonical_etypes) == len(F.unique(edata[dgl.ETYPE])) + if debug_mode or isinstance(part, dgl.DGLGraph): + for ntype in g.ntypes: + ntype_id = g.get_ntype_id(ntype) + inner_node_mask = _get_inner_node_mask(part, ntype_id) + num_inner_nodes = F.sum(F.astype(inner_node_mask, F.int64), 0) + num_nodes[ntype] += num_inner_nodes + if store_inner_edge or isinstance(part, dgl.DGLGraph): + for etype in g.canonical_etypes: + etype_id = g.get_etype_id(etype) + inner_edge_mask = _get_inner_edge_mask(part, etype_id) + num_inner_edges = F.sum(F.astype(inner_edge_mask, F.int64), 0) + num_edges[etype] += num_inner_edges + + # Verify the number of nodes are correct. 
+ if debug_mode or isinstance(part, dgl.DGLGraph): + for ntype in g.ntypes: + print( + "node {}: {}, {}".format( + ntype, g.num_nodes(ntype), num_nodes[ntype] + ) + ) + assert g.num_nodes(ntype) == num_nodes[ntype] + # Verify the number of edges are correct. + if store_inner_edge or isinstance(part, dgl.DGLGraph): + for etype in g.canonical_etypes: + print( + "edge {}: {}, {}".format( + etype, g.num_edges(etype), num_edges[etype] + ) + ) + assert g.num_edges(etype) == num_edges[etype] + + +def _verify_edge_id_range_hetero( + g, + part, + eids, +): + """ + check list: + make sure inner_eids fall into a range. + make sure all edges are included. + """ + edata = ( + part.edge_attributes + if isinstance(part, gb.FusedCSCSamplingGraph) + else part.edata + ) + etype = ( + part.type_per_edge + if isinstance(part, gb.FusedCSCSamplingGraph) + else edata[dgl.ETYPE] + ) + eid = torch.arange(len(edata[dgl.EID])) + etype_arr = F.gather_row(etype, eid) + eid_arr = F.gather_row(edata[dgl.EID], eid) + for etype in g.canonical_etypes: + etype_id = g.get_etype_id(etype) + eids[etype].append(F.boolean_mask(eid_arr, etype_arr == etype_id)) + # Make sure edge Ids fall into a range. + inner_edge_mask = _get_inner_edge_mask(part, etype_id) + inner_eids = np.sort( + F.asnumpy(F.boolean_mask(edata[dgl.EID], inner_edge_mask)) + ) + assert np.all( + inner_eids == np.arange(inner_eids[0], inner_eids[-1] + 1) + ) + return eids + + +def _verify_node_id_range_hetero(g, part, nids): + """ + check list: + make sure inner nodes have Ids fall into a range. + """ + for ntype in g.ntypes: + ntype_id = g.get_ntype_id(ntype) + # Make sure inner nodes have Ids fall into a range. 
+ inner_node_mask = _get_inner_node_mask(part, ntype_id) + inner_nids = F.boolean_mask( + part.node_attributes[dgl.NID], inner_node_mask + ) + assert np.all( + F.asnumpy( + inner_nids + == F.arange( + F.as_scalar(inner_nids[0]), + F.as_scalar(inner_nids[-1]) + 1, + ) + ) + ) + nids[ntype].append(inner_nids) + return nids + + +def _verify_graph_attributes_hetero( + g, + parts, + store_inner_edge, + store_inner_node, +): + """ + check list: + make sure edge ids fall into a range. + make sure inner nodes have Ids fall into a range. + make sure all nodes is included. + make sure all edges is included. + """ + nids = {ntype: [] for ntype in g.ntypes} + eids = {etype: [] for etype in g.canonical_etypes} + # check edge id. + if store_inner_edge or isinstance(parts[0], dgl.DGLGraph): + for part in parts: + # collect eids + eids = _verify_edge_id_range_hetero(g, part, eids) + for etype in eids: + eids_type = F.cat(eids[etype], 0) + uniq_ids = F.unique(eids_type) + # We should get all nodes. + assert len(uniq_ids) == g.num_edges(etype) + + # check node id. + if store_inner_node or isinstance(parts[0], dgl.DGLGraph): + for part in parts: + nids = _verify_node_id_range_hetero(g, part, nids) + for ntype in nids: + nids_type = F.cat(nids[ntype], 0) + uniq_ids = F.unique(nids_type) + # We should get all nodes. 
+ assert len(uniq_ids) == g.num_nodes(ntype) + + +def _verify_hetero_graph( + g, + parts, + store_eids=False, + store_inner_edge=False, + store_inner_node=False, + debug_mode=False, +): + _verify_hetero_graph_node_edge_num( + g, + parts, + store_inner_edge=store_inner_edge, + debug_mode=debug_mode, + ) + if store_eids: + _verify_graph_attributes_hetero( + g, + parts, + store_inner_edge=store_inner_edge, + store_inner_node=store_inner_node, + ) + + def _test_pipeline_graphbolt( num_chunks, num_parts, @@ -681,7 +858,7 @@ def _test_pipeline_graphbolt( in_dir = os.path.join(root_dir, "chunked-data") output_dir = os.path.join(root_dir, "parted_data") os.system( - "/opt/conda/envs/pytorch/bin/python tools/partition_algo/random_partition.py " + "python3 tools/partition_algo/random_partition.py " "--in_dir {} --out_dir {} --num_partitions {}".format( in_dir, output_dir, num_parts ) @@ -700,31 +877,29 @@ def _test_pipeline_graphbolt( for i in range(world_size): f.write(f"127.0.0.{i + 1}\n") - cmd = "/opt/conda/envs/pytorch/bin/python tools/dispatch_data.py" - cmd += f" --in-dir {in_dir}" - cmd += f" --partitions-dir {partition_dir}" - cmd += f" --out-dir {out_dir}" - cmd += f" --ip-config {ip_config}" - cmd += " --ssh-port 22" - cmd += " --process-group-timeout 60" - cmd += " --save-orig-nids" - cmd += " --save-orig-eids" - cmd += " --use-graphbolt" - cmd += f" --graph-formats {graph_formats}" if graph_formats else "" + cmd = "python3 tools/dispatch_data.py " + cmd += f" --in-dir {in_dir} " + cmd += f" --partitions-dir {partition_dir} " + cmd += f" --out-dir {out_dir} " + cmd += f" --ip-config {ip_config} " + cmd += " --ssh-port 22 " + cmd += " --process-group-timeout 60 " + cmd += " --save-orig-nids " + cmd += " --save-orig-eids " + cmd += " --use-graphbolt " + cmd += f" --graph-formats {graph_formats} " if graph_formats else "" if store_eids: - cmd += " --store-eids" + cmd += " --store-eids " if store_inner_edge: - cmd += " --store-inner-edge" + cmd += " 
--store-inner-edge " if store_inner_node: - cmd += " --store-inner-node" + cmd += " --store-inner-node " os.system(cmd) # check if verify_partitions.py is used for validation. if use_verify_partitions: - cmd = ( - "/opt/conda/envs/pytorch/bin/python tools/verify_partitions.py " - ) + cmd = "python3 tools/verify_partitions.py " cmd += f" --orig-dataset-dir {in_dir}" cmd += f" --part-graph {out_dir}" cmd += f" --partitions-dir {output_dir}" @@ -744,7 +919,9 @@ def read_orig_ids(fname): orig_ids[type] = torch.cat((orig_ids[type], data)) return orig_ids + orig_nids, orig_eids = None, None orig_nids = read_orig_ids("orig_nids.dgl") + orig_eids_str = read_orig_ids("orig_eids.dgl") orig_eids = {} @@ -769,22 +946,12 @@ def read_orig_ids(fname): part_config=part_config, is_homo=False, ) - - # for i in range(num_parts): - # part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition( - # part_config, i, use_graphbolt=True - # ) - # verify_partition_data_types(part_g, use_graphbolt=True) - # verify_graph_feats( - # g, - # gpb, - # part_g, - # node_feats, - # edge_feats, - # orig_nids, - # orig_eids, - # use_graphbolt=True, - # ) + _verify_hetero_graph( + g, + parts, + store_eids=store_eids, + store_inner_edge=store_inner_edge, + ) @pytest.mark.parametrize( @@ -792,12 +959,30 @@ def read_orig_ids(fname): [[4, 4, 4], [8, 4, 2], [8, 4, 4], [9, 6, 3], [11, 11, 1], [11, 4, 1]], ) def test_pipeline_basics(num_chunks, num_parts, world_size): - _test_pipeline_graphbolt(num_chunks, num_parts, world_size) + _test_pipeline_graphbolt( + num_chunks, + num_parts, + world_size, + ) _test_pipeline_graphbolt( num_chunks, num_parts, world_size, use_verify_partitions=False ) +@pytest.mark.parametrize("store_inner_node", [True, False]) +@pytest.mark.parametrize("store_inner_edge", [True, False]) +@pytest.mark.parametrize("store_eids", [True, False]) +def test_pipeline_attributes(store_inner_node, store_inner_edge, store_eids): + _test_pipeline_graphbolt( + 4, + 4, + 4, + 
store_inner_node=store_inner_node, + store_inner_edge=store_inner_edge, + store_eids=store_eids, + ) + + @pytest.mark.parametrize( "num_chunks, " "num_parts, " diff --git a/tools/distpartitioning/convert_partition.py b/tools/distpartitioning/convert_partition.py index 15d98f1d0ee6..bec4250da0fd 100644 --- a/tools/distpartitioning/convert_partition.py +++ b/tools/distpartitioning/convert_partition.py @@ -1,3 +1,4 @@ +import copy import gc import logging import os @@ -165,24 +166,22 @@ def _is_homogeneous(ntypes, etypes): return len(ntypes) == 1 and len(etypes) == 1 -def _coo2csc(part_local_src_id, part_local_dst_id): - part_local_src_id, part_local_dst_id = th.tensor( - part_local_src_id, dtype=th.int64 - ), th.tensor(part_local_dst_id, dtype=th.int64) - num_nodes = th.max(th.cat([part_local_src_id, part_local_dst_id], dim=0)) - indptr = th.zeros(num_nodes + 2, dtype=th.int64) - col_counts = th.bincount(part_local_dst_id, minlength=num_nodes + 1) - indptr[1:] = th.cumsum(col_counts, 0) - edge_id = th.argsort(part_local_dst_id) - indices = part_local_src_id[edge_id] - return indptr, indices, edge_id +def _coo2csc(src_ids, dst_ids): + src_ids, dst_ids = th.tensor(src_ids, dtype=th.int64), th.tensor( + dst_ids, dtype=th.int64 + ) + num_nodes = th.max(th.stack([src_ids, dst_ids], dim=0)).item() + 1 + dst, idx = dst_ids.sort() + indptr = th.searchsorted(dst, th.arange(num_nodes + 1)) + indices = src_ids[idx] + return indptr, indices, idx def _create_edge_data(edgeid_offset, etype_ids, num_edges): eid = th.arange( edgeid_offset, edgeid_offset + num_edges, - dtype=th.int64, + dtype=RESERVED_FIELD_DTYPE[dgl.EID], ) etype = th.as_tensor(etype_ids, dtype=RESERVED_FIELD_DTYPE[dgl.ETYPE]) inner_edge = th.ones(num_edges, dtype=RESERVED_FIELD_DTYPE["inner_edge"]) @@ -302,19 +301,19 @@ def _create_node_attr( def remove_attr_gb( edge_attr, node_attr, store_inner_node, store_inner_edge, store_eids ): + edata, ndata = copy.deepcopy(edge_attr), copy.deepcopy(node_attr) if not 
store_inner_edge: - assert "inner_edge" in edge_attr - edge_attr.pop("inner_edge") - assert "inner_edge" in edge_attr + assert "inner_edge" in edata + edata.pop("inner_edge") if not store_eids: - assert dgl.EID in edge_attr - edge_attr.pop(dgl.EID) + assert dgl.EID in edata + edata.pop(dgl.EID) if not store_inner_node: - assert "inner_node" in node_attr - node_attr.pop("inner_node") - return edge_attr, node_attr + assert "inner_node" in ndata + ndata.pop("inner_node") + return edata, ndata def create_graph_object( @@ -605,8 +604,18 @@ def create_graph_object( ) # create the graph here now. + ndata, per_type_ids = _create_node_attr( + idx, + global_src_id, + global_dst_id, + global_homo_nid, + uniq_ids, + reshuffle_nodes, + id_map, + inner_nodes, + ) if use_graphbolt: - edge_attr, type_per_edge, edge_type_to_id = _create_edge_attr_gb( + edata, type_per_edge, edge_type_to_id = _create_edge_attr_gb( part_local_dst_id, edgeid_offset, etype_ids, @@ -614,28 +623,9 @@ def create_graph_object( etypes, etypes_map, ) - node_attr, per_type_ids = _create_node_attr( - idx, - global_src_id, - global_dst_id, - global_homo_nid, - uniq_ids, - reshuffle_nodes, - id_map, - inner_nodes, - ) - orig_nids, orig_eids = _graph_orig_ids( - return_orig_nids, - return_orig_eids, - ntypes_map, - etypes_map, - node_attr, - edge_attr, - per_type_ids, - type_per_edge, - global_edge_id, + edge_attr, node_attr = remove_attr_gb( + edge_attr=edata, node_attr=ndata, **kwargs ) - remove_attr_gb(edge_attr, node_attr, **kwargs) indptr, indices, csc_edge_ids = _coo2csc( part_local_src_id, part_local_dst_id ) @@ -652,15 +642,6 @@ def create_graph_object( node_type_to_id=ntypes_map, edge_type_to_id=edge_type_to_id, ) - return ( - part_graph, - node_map_val, - edge_map_val, - ntypes_map, - etypes_map, - orig_nids, - orig_eids, - ) else: num_edges = len(part_local_dst_id) part_graph = dgl.graph( @@ -685,19 +666,20 @@ def create_graph_object( ) for attr_name, node_attributes in ndata.items(): 
part_graph.ndata[attr_name] = node_attributes - - # get the original node ids and edge ids from original graph. - orig_nids, orig_eids = _graph_orig_ids( - return_orig_nids, - return_orig_eids, - ntypes_map, - etypes_map, - part_graph.ndata, - part_graph.edata, - per_type_ids, - part_graph.edata[dgl.ETYPE], - global_edge_id, - ) + type_per_edge = part_graph.edata[dgl.ETYPE] + ndata, edata = part_graph.ndata, part_graph.edata + # get the original node ids and edge ids from original graph. + orig_nids, orig_eids = _graph_orig_ids( + return_orig_nids, + return_orig_eids, + ntypes_map, + etypes_map, + ndata, + edata, + per_type_ids, + type_per_edge, + global_edge_id, + ) return ( part_graph, node_map_val, diff --git a/tools/distpartitioning/data_proc_pipeline.py b/tools/distpartitioning/data_proc_pipeline.py index 62a19f7a4f74..e0159f55b9a1 100644 --- a/tools/distpartitioning/data_proc_pipeline.py +++ b/tools/distpartitioning/data_proc_pipeline.py @@ -125,7 +125,6 @@ def log_params(params): help="Save partitions in specified formats.", ) params = parser.parse_args() - assert params.use_graphbolt is True # invoke the pipeline function numeric_level = getattr(logging, params.log_level.upper(), None) logging.basicConfig( diff --git a/tools/distpartitioning/utils.py b/tools/distpartitioning/utils.py index fbf4ae8c0fed..71b9b7712984 100644 --- a/tools/distpartitioning/utils.py +++ b/tools/distpartitioning/utils.py @@ -533,9 +533,10 @@ def write_graph_dgl(graph_file, graph_obj, formats, sort_etypes): sort_etypes : bool Whether to sort etypes in csc/csr. 
""" - dgl.distributed.partition._save_graphs( - graph_file, [graph_obj], formats, sort_etypes + dgl.distributed.partition.process_partitions( + graph_obj, formats, sort_etypes ) + dgl.save_graphs(graph_file, [graph_obj], formats=formats) def write_dgl_objects( From 00bb70c5bb9f735151b433536639a62c83024879 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 10 Sep 2024 06:40:02 +0000 Subject: [PATCH 22/37] modify dispatch_data.py --- tools/dispatch_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/dispatch_data.py b/tools/dispatch_data.py index 27b9e9f61928..b2b54e51a6ec 100644 --- a/tools/dispatch_data.py +++ b/tools/dispatch_data.py @@ -74,6 +74,7 @@ def submit_jobs(args) -> str: argslist += "--process-group-timeout {} ".format(args.process_group_timeout) argslist += "--log-level {} ".format(args.log_level) argslist += "--save-orig-nids " if args.save_orig_nids else "" + argslist += "--save-orig-eids " if args.save_orig_eids else "" argslist += "--use-graphbolt " if args.use_graphbolt else "" argslist += "--store-eids " if args.store_eids else "" argslist += "--store-inner-node " if args.store_inner_node else "" @@ -197,7 +198,6 @@ def main(): ) args, _ = parser.parse_known_args() - assert args.use_graphbolt==True fmt = "%(asctime)s %(levelname)s %(message)s" logging.basicConfig( From 8ca0f8930feb8652b314248503f4bd9c6395dd1f Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 10 Sep 2024 08:00:59 +0000 Subject: [PATCH 23/37] change partition format --- python/dgl/distributed/partition.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/dgl/distributed/partition.py b/python/dgl/distributed/partition.py index ae18c20b7881..7bc6ce0aeb62 100644 --- a/python/dgl/distributed/partition.py +++ b/python/dgl/distributed/partition.py @@ -88,7 +88,7 @@ def _dump_part_config(part_config, part_metadata): json.dump(part_metadata, outfile, sort_keys=False, indent=4) -def _process_partitions(g, formats=None, sort_etypes=False): 
+def process_partitions(g, formats=None, sort_etypes=False): """Preprocess partitions before saving: 1. format data types. 2. sort csc/csr by tag. @@ -702,6 +702,8 @@ def _partition_to_graphbolt( part_config=part_config, part_id=part_i, csc_graph=csc_graph ) part_metadata[f"part-{part_i}"]["part_graph_graphbolt"] = rel_path_result + + def _update_node_edge_map(node_map_val, edge_map_val, g, num_parts): """ If the original graph contains few nodes or edges for specific node/edge @@ -1472,7 +1474,7 @@ def get_homogeneous(g, balance_ntypes): "edge_feats": os.path.relpath(edge_feat_file, out_path), } sort_etypes = len(g.etypes) > 1 - part = _process_partitions(part, graph_formats, sort_etypes) + part = process_partitions(part, graph_formats, sort_etypes) # transmit to graphbolt and save graph if use_graphbolt: From 412cf7deee0b1c2cd258232af8b5879ebf8a9d1e Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 10 Sep 2024 08:36:11 +0000 Subject: [PATCH 24/37] partition.py --- python/dgl/distributed/partition.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/dgl/distributed/partition.py b/python/dgl/distributed/partition.py index 7bc6ce0aeb62..982f15dfcc92 100644 --- a/python/dgl/distributed/partition.py +++ b/python/dgl/distributed/partition.py @@ -1503,6 +1503,7 @@ def get_homogeneous(g, balance_ntypes): part_metadata["part-{}".format(part_id)][ "part_graph" ] = os.path.relpath(part_graph_file, out_path) + # save DGLGraph _save_dgl_graphs( part_graph_file, From eb29e14742e4802d2f8c26101c43c86399853876 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 10 Sep 2024 08:39:56 +0000 Subject: [PATCH 25/37] change partition --- python/dgl/distributed/partition.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/dgl/distributed/partition.py b/python/dgl/distributed/partition.py index 982f15dfcc92..71a190fc7d44 100644 --- a/python/dgl/distributed/partition.py +++ b/python/dgl/distributed/partition.py @@ -1501,8 +1501,8 @@ def get_homogeneous(g, 
balance_ntypes): part_dir = os.path.join(out_path, "part" + str(part_id)) part_graph_file = os.path.join(part_dir, "graph.dgl") part_metadata["part-{}".format(part_id)][ - "part_graph" - ] = os.path.relpath(part_graph_file, out_path) + "part_graph" + ] = os.path.relpath(part_graph_file, out_path) # save DGLGraph _save_dgl_graphs( From b4e3afd03c644d66db541fd9d22f4ade0f09e6f2 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 10 Sep 2024 08:43:21 +0000 Subject: [PATCH 26/37] change partition format --- python/dgl/distributed/partition.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/dgl/distributed/partition.py b/python/dgl/distributed/partition.py index 71a190fc7d44..4aadfe98e617 100644 --- a/python/dgl/distributed/partition.py +++ b/python/dgl/distributed/partition.py @@ -1503,7 +1503,6 @@ def get_homogeneous(g, balance_ntypes): part_metadata["part-{}".format(part_id)][ "part_graph" ] = os.path.relpath(part_graph_file, out_path) - # save DGLGraph _save_dgl_graphs( part_graph_file, From 2e58bad1ac3e11ec70b2e49184a2692c3b8780e1 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 10 Sep 2024 09:12:29 +0000 Subject: [PATCH 27/37] change partition --- python/dgl/distributed/partition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/dgl/distributed/partition.py b/python/dgl/distributed/partition.py index 4aadfe98e617..9acbc5329802 100644 --- a/python/dgl/distributed/partition.py +++ b/python/dgl/distributed/partition.py @@ -1896,7 +1896,7 @@ def _convert_partition_to_graphbolt( ) as executor: for part_id in range(num_parts): rel_path_results.append( - executor.submit(part_id=part_id).result() + executor.submit(convert_with_format,part_id=part_id).result() ) else: From 283eaccaf004d9b1470f675933c42c219a8b7af0 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 10 Sep 2024 09:28:05 +0000 Subject: [PATCH 28/37] change partition --- python/dgl/distributed/partition.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git 
a/python/dgl/distributed/partition.py b/python/dgl/distributed/partition.py index 9acbc5329802..7fe4d61b0edb 100644 --- a/python/dgl/distributed/partition.py +++ b/python/dgl/distributed/partition.py @@ -1482,7 +1482,7 @@ def get_homogeneous(g, balance_ntypes): kwargs["graph_formats"] = graph_formats n_jobs = kwargs.pop("n_jobs", 1) mp_ctx = mp.get_context("spawn") - with concurrent.futures.ProcessPoolExecutor( + with concurrent.futures.ProcessPoolExecutor( # pylint: disable=unexpected-keyword-arg max_workers=min(num_parts, n_jobs), mp_context=mp_ctx, ) as executor: @@ -1896,7 +1896,9 @@ def _convert_partition_to_graphbolt( ) as executor: for part_id in range(num_parts): rel_path_results.append( - executor.submit(convert_with_format,part_id=part_id).result() + executor.submit( + convert_with_format, part_id=part_id + ).result() ) else: From 86a0c990419c85181a81075df8b541c32d5ccd17 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 12 Sep 2024 04:09:03 +0000 Subject: [PATCH 29/37] change dist partition --- tests/tools/test_dist_partition_graphbolt.py | 6 - tools/distpartitioning/convert_partition.py | 128 ++++++++++++++++++- tools/distpartitioning/data_shuffle.py | 34 ++--- tools/distpartitioning/utils.py | 29 +++-- 4 files changed, 166 insertions(+), 31 deletions(-) diff --git a/tests/tools/test_dist_partition_graphbolt.py b/tests/tools/test_dist_partition_graphbolt.py index b0d49343042e..5ebf37718c8b 100644 --- a/tests/tools/test_dist_partition_graphbolt.py +++ b/tests/tools/test_dist_partition_graphbolt.py @@ -24,12 +24,6 @@ from distpartitioning.utils import generate_read_list from pytest_utils import create_chunked_dataset -from tools.verification_utils import ( - verify_graph_feats, - verify_partition_data_types, - verify_partition_formats, -) - def _verify_metadata_gb(gpb, g, num_parts, part_id, part_sizes): """ diff --git a/tools/distpartitioning/convert_partition.py b/tools/distpartitioning/convert_partition.py index bec4250da0fd..3d1f7a41a3aa 100644 --- 
a/tools/distpartitioning/convert_partition.py +++ b/tools/distpartitioning/convert_partition.py @@ -4,12 +4,15 @@ import os import constants - import dgl import dgl.graphbolt as gb import numpy as np import torch as th +from dgl import EID, ETYPE, NID, NTYPE + +from dgl.distributed.constants import DGL2GB_EID, GB_DST_ID from dgl.distributed.partition import ( + _cast_to_minimum_dtype, _etype_str_to_tuple, _etype_tuple_to_str, RESERVED_FIELD_DTYPE, @@ -316,7 +319,72 @@ def remove_attr_gb( return edata, ndata +def cast_various_to_minimum_dtype_gb( + node_count, + edge_count, + num_parts, + indptr, + indices, + type_per_edge, + etypes, + ntypes, + node_attributes, + edge_attributes, +): + """Cast various data to minimum dtype.""" + # Cast 1: indptr. + indptr = _cast_to_minimum_dtype(edge_count, indptr) + # Cast 2: indices. + indices = _cast_to_minimum_dtype(node_count, indices) + # Cast 3: type_per_edge. + type_per_edge = _cast_to_minimum_dtype( + len(etypes), type_per_edge, field=ETYPE + ) + # Cast 4: node/edge_attributes. 
+ predicates = { + NID: node_count, + "part_id": num_parts, + NTYPE: len(ntypes), + EID: edge_count, + ETYPE: len(etypes), + DGL2GB_EID: edge_count, + GB_DST_ID: node_count, + } + for attributes in [node_attributes, edge_attributes]: + for key in attributes: + if key not in predicates: + continue + attributes[key] = _cast_to_minimum_dtype( + predicates[key], attributes[key], field=key + ) + return indptr, indices, type_per_edge + + +def _process_partition_gb( + part_local_src_id, + part_local_dst_id, + edge_attr, + node_attr, + type_per_edge, + formats, +): + temp_g = dgl.DGLGraph((part_local_src_id, part_local_dst_id)) + for edge_type, data in edge_attr.items(): + temp_g.edata[edge_type] = data + for node_type, data in node_attr.items(): + temp_g.ndata[node_type] = data + sort_etypes = max(type_per_edge) > 1 + temp_g = dgl.distributed.partition.process_partitions( + temp_g, formats=formats, sort_etypes=sort_etypes + ) + return temp_g.edata, temp_g.ndata + + def create_graph_object( + node_count, + edge_count, + graph_formats, + num_parts, schema, part_id, node_data, @@ -377,6 +445,14 @@ def create_graph_object( Parameters: ----------- + node_count : int + the number of all nodes + edge_count : int + the number of all edges + graph_formats : str + the format of graph + num_parts : int + the number of parts schame : json object json object created by reading the graph metadata json file part_id : int @@ -603,6 +679,33 @@ def create_graph_object( nid_map[part_local_dst_id], ) + """ + Creating attributes for graphbolt and DGLGraph is as follows. + + node attributes: + this part is implemented in _create_node_attr. + compute the ntype and per type ids for each node with global node type id. + create ntype, nid and inner node with orig ntype and inner nodes + this part is shared by graphbolt and DGLGraph. + + the attributes created for graphbolt are as follows: + + edge attributes: + this part is implemented in _create_edge_attr_gb. 
+ create eid, type per edge and inner edge with edgeid_offset. + create edge_type_to_id with etypes_map. + + The process to remove extra attribute is implemented in remove_attr_gb. + the unused attributes like inner_node, inner_edge, eids will be removed following the arguments in kwargs. + edge_attr, node_attr are the variable that have removed extra attributes to construct csc_graph. + edata, ndata are the variable that reserve extra attributes to be used to generate orig_nid and orig_eid. + + the src_ids and dst_ids will be transformed into indptr and indices in _coo2csc. + + all variable mentioned above will be casted to minimum data type in cast_various_to_minimum_dtype_gb. + + orig_nids and orig_eids will be generated in _graph_orig_ids with ndata and edata. + """ # create the graph here now. ndata, per_type_ids = _create_node_attr( idx, @@ -623,6 +726,17 @@ def create_graph_object( etypes, etypes_map, ) + edata, ndata = _process_partition_gb( + part_local_src_id, + part_local_dst_id, + edata, + ndata, + type_per_edge, + graph_formats, + ) + assert edata is not None + assert ndata is not None + edge_attr, node_attr = remove_attr_gb( edge_attr=edata, node_attr=ndata, **kwargs ) @@ -632,6 +746,18 @@ def create_graph_object( edge_attr = { attr: edge_attr[attr][csc_edge_ids] for attr in edge_attr.keys() } + cast_various_to_minimum_dtype_gb( + node_count, + edge_count, + num_parts, + indptr, + indices, + type_per_edge, + etypes, + ntypes, + node_attr, + edge_attr, + ) part_graph = gb.fused_csc_sampling_graph( csc_indptr=indptr, indices=indices, diff --git a/tools/distpartitioning/data_shuffle.py b/tools/distpartitioning/data_shuffle.py index 20d6d533fd96..c2dcd01e400e 100644 --- a/tools/distpartitioning/data_shuffle.py +++ b/tools/distpartitioning/data_shuffle.py @@ -285,21 +285,21 @@ def exchange_edge_data(rank, world_size, num_parts, edge_data, id_lookup): local_etype_ids.append(rcvd_edge_data[:, 3]) local_eids.append(rcvd_edge_data[:, 4]) - edge_data[ - 
constants.GLOBAL_SRC_ID + "/" + str(local_part_id) - ] = np.concatenate(local_src_ids) - edge_data[ - constants.GLOBAL_DST_ID + "/" + str(local_part_id) - ] = np.concatenate(local_dst_ids) - edge_data[ - constants.GLOBAL_TYPE_EID + "/" + str(local_part_id) - ] = np.concatenate(local_type_eids) - edge_data[ - constants.ETYPE_ID + "/" + str(local_part_id) - ] = np.concatenate(local_etype_ids) - edge_data[ - constants.GLOBAL_EID + "/" + str(local_part_id) - ] = np.concatenate(local_eids) + edge_data[constants.GLOBAL_SRC_ID + "/" + str(local_part_id)] = ( + np.concatenate(local_src_ids) + ) + edge_data[constants.GLOBAL_DST_ID + "/" + str(local_part_id)] = ( + np.concatenate(local_dst_ids) + ) + edge_data[constants.GLOBAL_TYPE_EID + "/" + str(local_part_id)] = ( + np.concatenate(local_type_eids) + ) + edge_data[constants.ETYPE_ID + "/" + str(local_part_id)] = ( + np.concatenate(local_etype_ids) + ) + edge_data[constants.GLOBAL_EID + "/" + str(local_part_id)] = ( + np.concatenate(local_eids) + ) # Check if the data was exchanged correctly local_edge_count = 0 @@ -1324,6 +1324,10 @@ def prepare_local_data(src_data, local_part_id): orig_nids, orig_eids, ) = create_graph_object( + node_count, + edge_count, + graph_formats, + params.num_parts, schema_map, rank + local_part_id * world_size, local_node_data, diff --git a/tools/distpartitioning/utils.py b/tools/distpartitioning/utils.py index 71b9b7712984..32292a843bc5 100644 --- a/tools/distpartitioning/utils.py +++ b/tools/distpartitioning/utils.py @@ -539,6 +539,19 @@ def write_graph_dgl(graph_file, graph_obj, formats, sort_etypes): dgl.save_graphs(graph_file, [graph_obj], formats=formats) +def _write_graph( + part_dir, graph_obj, formats=None, sort_etypes=None, use_graphbolt=False +): + if use_graphbolt: + write_graph_graghbolt( + os.path.join(part_dir, "fused_csc_sampling_graph.pt"), graph_obj + ) + else: + write_graph_dgl( + os.path.join(part_dir, "graph.dgl"), graph_obj, formats, sort_etypes + ) + + def 
write_dgl_objects( graph_obj, node_features, @@ -579,15 +592,13 @@ def write_dgl_objects( """ part_dir = output_dir + "/part" + str(part_id) os.makedirs(part_dir, exist_ok=True) - if use_graphbolt: - write_graph_graghbolt( - os.path.join(part_dir, "fused_csc_sampling_graph.pt"), graph_obj - ) - else: - write_graph_dgl( - os.path.join(part_dir, "graph.dgl"), graph_obj, formats, sort_etypes - ) - + _write_graph( + part_dir, + graph_obj, + formats=formats, + sort_etypes=sort_etypes, + use_graphbolt=use_graphbolt, + ) if node_features != None: write_node_features( node_features, os.path.join(part_dir, "node_feat.dgl") From da02eb42aee5da1b1a03e9dfeed4ebf9f3d47fdb Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 12 Sep 2024 04:16:43 +0000 Subject: [PATCH 30/37] fix format problem --- tools/distpartitioning/data_shuffle.py | 30 +++++++++++++------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/tools/distpartitioning/data_shuffle.py b/tools/distpartitioning/data_shuffle.py index c2dcd01e400e..273db1d9914d 100644 --- a/tools/distpartitioning/data_shuffle.py +++ b/tools/distpartitioning/data_shuffle.py @@ -285,21 +285,21 @@ def exchange_edge_data(rank, world_size, num_parts, edge_data, id_lookup): local_etype_ids.append(rcvd_edge_data[:, 3]) local_eids.append(rcvd_edge_data[:, 4]) - edge_data[constants.GLOBAL_SRC_ID + "/" + str(local_part_id)] = ( - np.concatenate(local_src_ids) - ) - edge_data[constants.GLOBAL_DST_ID + "/" + str(local_part_id)] = ( - np.concatenate(local_dst_ids) - ) - edge_data[constants.GLOBAL_TYPE_EID + "/" + str(local_part_id)] = ( - np.concatenate(local_type_eids) - ) - edge_data[constants.ETYPE_ID + "/" + str(local_part_id)] = ( - np.concatenate(local_etype_ids) - ) - edge_data[constants.GLOBAL_EID + "/" + str(local_part_id)] = ( - np.concatenate(local_eids) - ) + edge_data[ + constants.GLOBAL_SRC_ID + "/" + str(local_part_id) + ] = np.concatenate(local_src_ids) + edge_data[ + constants.GLOBAL_DST_ID + "/" + str(local_part_id) 
+ ] = np.concatenate(local_dst_ids) + edge_data[ + constants.GLOBAL_TYPE_EID + "/" + str(local_part_id) + ] = np.concatenate(local_type_eids) + edge_data[ + constants.ETYPE_ID + "/" + str(local_part_id) + ] = np.concatenate(local_etype_ids) + edge_data[ + constants.GLOBAL_EID + "/" + str(local_part_id) + ] = np.concatenate(local_eids) # Check if the data was exchanged correctly local_edge_count = 0 From 954c4d7861d8a897073d3dd352944e37eb56ad6b Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 12 Sep 2024 07:13:42 +0000 Subject: [PATCH 31/37] change partition --- python/dgl/distributed/partition.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/python/dgl/distributed/partition.py b/python/dgl/distributed/partition.py index 7fe4d61b0edb..27828cc8dcc5 100644 --- a/python/dgl/distributed/partition.py +++ b/python/dgl/distributed/partition.py @@ -93,17 +93,27 @@ def process_partitions(g, formats=None, sort_etypes=False): 1. format data types. 2. sort csc/csr by tag. 
""" + ndata = ( + g.node_attributes + if isinstance(g, gb.FusedCSCSamplingGraph) + else g.ndata + ) + edata = ( + g.edge_attributes + if isinstance(g, gb.FusedCSCSamplingGraph) + else g.edata + ) for k, dtype in RESERVED_FIELD_DTYPE.items(): if k in g.ndata: - g.ndata[k] = F.astype(g.ndata[k], dtype) + ndata[k] = F.astype(ndata[k], dtype) if k in g.edata: - g.edata[k] = F.astype(g.edata[k], dtype) + edata[k] = F.astype(edata[k], dtype) if (sort_etypes) and (formats is not None): if "csr" in formats: - g = sort_csr_by_tag(g, tag=g.edata[ETYPE], tag_type="edge") + g = sort_csr_by_tag(g, tag=edata[ETYPE], tag_type="edge") if "csc" in formats: - g = sort_csc_by_tag(g, tag=g.edata[ETYPE], tag_type="edge") + g = sort_csc_by_tag(g, tag=edata[ETYPE], tag_type="edge") return g From b3c1be51d936991e6a97ed8cfd4234c23fe5c3f9 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 12 Sep 2024 07:51:23 +0000 Subject: [PATCH 32/37] change docstring in test case --- tests/tools/test_dist_partition_graphbolt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tools/test_dist_partition_graphbolt.py b/tests/tools/test_dist_partition_graphbolt.py index 5ebf37718c8b..81c16f8809c3 100644 --- a/tests/tools/test_dist_partition_graphbolt.py +++ b/tests/tools/test_dist_partition_graphbolt.py @@ -30,7 +30,7 @@ def _verify_metadata_gb(gpb, g, num_parts, part_id, part_sizes): check list: make sure the number of nodes and edges is correct. make sure the number of parts is correct. - make sure the number of nodes and edges in each parts os corrcet. + make sure the number of nodes and edges in each part is corrcet. 
""" assert gpb._num_nodes() == g.num_nodes() assert gpb._num_edges() == g.num_edges() From 3834358678147b6d9f469b7048019a922b245b36 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 12 Sep 2024 08:13:12 +0000 Subject: [PATCH 33/37] change partition --- python/dgl/distributed/partition.py | 55 +++++++++++++---------------- 1 file changed, 24 insertions(+), 31 deletions(-) diff --git a/python/dgl/distributed/partition.py b/python/dgl/distributed/partition.py index 27828cc8dcc5..079ed8806a96 100644 --- a/python/dgl/distributed/partition.py +++ b/python/dgl/distributed/partition.py @@ -93,27 +93,17 @@ def process_partitions(g, formats=None, sort_etypes=False): 1. format data types. 2. sort csc/csr by tag. """ - ndata = ( - g.node_attributes - if isinstance(g, gb.FusedCSCSamplingGraph) - else g.ndata - ) - edata = ( - g.edge_attributes - if isinstance(g, gb.FusedCSCSamplingGraph) - else g.edata - ) for k, dtype in RESERVED_FIELD_DTYPE.items(): if k in g.ndata: - ndata[k] = F.astype(ndata[k], dtype) + g.ndata[k] = F.astype(g.ndata[k], dtype) if k in g.edata: - edata[k] = F.astype(edata[k], dtype) + g.edata[k] = F.astype(g.edata[k], dtype) if (sort_etypes) and (formats is not None): if "csr" in formats: - g = sort_csr_by_tag(g, tag=edata[ETYPE], tag_type="edge") + g = sort_csr_by_tag(g, tag=g.edata[ETYPE], tag_type="edge") if "csc" in formats: - g = sort_csc_by_tag(g, tag=edata[ETYPE], tag_type="edge") + g = sort_csc_by_tag(g, tag=g.edata[ETYPE], tag_type="edge") return g @@ -1506,6 +1496,8 @@ def get_homogeneous(g, balance_ntypes): **kwargs, ) ) + part_metadata["node_map_dtype"] = "int64" + part_metadata["edge_map_dtype"] = "int64" else: for part_id, part in parts.items(): part_dir = os.path.join(out_path, "part" + str(part_id)) @@ -1698,12 +1690,12 @@ def gb_convert_single_dgl_partition( ntypes, etypes, gpb, - graph_formats, - store_eids, - store_inner_node, - store_inner_edge, part_meta, graph, + graph_formats=None, + store_eids=False, + store_inner_node=False, + 
store_inner_edge=False, ): """Converts a single DGL partition to GraphBolt. @@ -1715,6 +1707,10 @@ def gb_convert_single_dgl_partition( The edge types gpb : GraphPartitionBook The global partition information. + part_meta : dict + Contain the meta data of the partition. + graph : DGLGraph + The graph to be converted to graphbolt graph. graph_formats : str or list[str], optional Save partitions in specified formats. It could be any combination of `coo`, `csc`. As `csc` format is mandatory for `FusedCSCSamplingGraph`, @@ -1728,10 +1724,6 @@ def gb_convert_single_dgl_partition( Whether to store inner node mask in the new graph. Default: False. store_inner_edge : bool, optional Whether to store inner edge mask in the new graph. Default: False. - part_meta : dict - Contain the meta data of the partition. - graph : DGLGraph - The graph to be converted to graphbolt graph. """ debug_mode = "DGL_DIST_DEBUG" in os.environ if debug_mode: @@ -1812,16 +1804,17 @@ def gb_convert_single_dgl_partition( return csc_graph -def convert_partition_to_graphbolt_multi_process( +def _convert_partition_to_graphbolt( part_config, part_id, - graph_formats, - store_eids, - store_inner_node, - store_inner_edge, + graph_formats=None, + store_eids=False, + store_inner_node=False, + store_inner_edge=False, ): """ - Convert signle partition to graphbolt, which is used for multiple process. + The pipeline converting signle partition to graphbolt. + Parameters ---------- part_config : str @@ -1867,7 +1860,7 @@ def convert_partition_to_graphbolt_multi_process( return rel_path -def _convert_partition_to_graphbolt( +def _convert_partition_to_graphbolt_wrapper( graph_formats, part_config, store_eids, @@ -1887,7 +1880,7 @@ def _convert_partition_to_graphbolt( # Iterate over partitions. 
convert_with_format = partial( - convert_partition_to_graphbolt_multi_process, + _convert_partition_to_graphbolt, part_config=part_config, graph_formats=graph_formats, store_eids=store_eids, @@ -1979,7 +1972,7 @@ def dgl_partition_to_graphbolt( ) part_meta = _load_part_config(part_config) num_parts = part_meta["num_parts"] - part_meta = _convert_partition_to_graphbolt( + part_meta = _convert_partition_to_graphbolt_wrapper( graph_formats=graph_formats, part_config=part_config, store_eids=store_eids, From b29e5a21808761d96e319a30dc3b37eafe69a4be Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 13 Sep 2024 13:45:29 +0000 Subject: [PATCH 34/37] change convert_partition.py --- tools/distpartitioning/convert_partition.py | 60 +++++++++++---------- 1 file changed, 33 insertions(+), 27 deletions(-) diff --git a/tools/distpartitioning/convert_partition.py b/tools/distpartitioning/convert_partition.py index 3d1f7a41a3aa..e4f2368fb653 100644 --- a/tools/distpartitioning/convert_partition.py +++ b/tools/distpartitioning/convert_partition.py @@ -9,6 +9,7 @@ import numpy as np import torch as th from dgl import EID, ETYPE, NID, NTYPE +import dgl.backend as F from dgl.distributed.constants import DGL2GB_EID, GB_DST_ID from dgl.distributed.partition import ( @@ -261,7 +262,7 @@ def _create_edge_attr_gb( is_homo = _is_homogeneous(ntypes, etypes) edge_type_to_id = ( - None + {gb.etype_tuple_to_str(('_N','_E','_N')) : 0} if is_homo else { gb.etype_tuple_to_str(etype): etid @@ -361,29 +362,40 @@ def cast_various_to_minimum_dtype_gb( def _process_partition_gb( - part_local_src_id, - part_local_dst_id, - edge_attr, node_attr, + edge_attr, type_per_edge, - formats, + src_ids, + dst_ids, + sort_etypes, ): - temp_g = dgl.DGLGraph((part_local_src_id, part_local_dst_id)) - for edge_type, data in edge_attr.items(): - temp_g.edata[edge_type] = data - for node_type, data in node_attr.items(): - temp_g.ndata[node_type] = data - sort_etypes = max(type_per_edge) > 1 - temp_g = 
dgl.distributed.partition.process_partitions( - temp_g, formats=formats, sort_etypes=sort_etypes - ) - return temp_g.edata, temp_g.ndata + """Preprocess partitions before saving: + 1. format data types. + 2. sort csc/csr by tag. + """ + for k, dtype in RESERVED_FIELD_DTYPE.items(): + if k in node_attr: + node_attr[k] = F.astype(node_attr[k], dtype) + if k in edge_attr: + edge_attr[k] = F.astype(edge_attr[k], dtype) + + indptr,indices,edge_ids=_coo2csc(src_ids,dst_ids) + if sort_etypes: + split_size = th.diff(indptr) + split_indices = th.split(type_per_edge, tuple(split_size), dim=0) + sorted_idxs=[] + for split_indice in split_indices: + sorted_idxs.append(split_indice.sort()[1]) + + sorted_idx = th.cat(sorted_idxs, dim=0) + sorted_idx=th.repeat_interleave(indptr[:-1], split_size, dim=0)+sorted_idx + + return indptr, indices, edge_ids def create_graph_object( node_count, edge_count, - graph_formats, num_parts, schema, part_id, @@ -726,23 +738,17 @@ def create_graph_object( etypes, etypes_map, ) - edata, ndata = _process_partition_gb( - part_local_src_id, - part_local_dst_id, - edata, - ndata, - type_per_edge, - graph_formats, - ) + assert edata is not None assert ndata is not None + sort_etypes = len(etypes_map) > 1 + indptr, indices, csc_edge_ids = _process_partition_gb( + ndata, edata, type_per_edge, part_local_src_id, part_local_dst_id,sort_etypes + ) edge_attr, node_attr = remove_attr_gb( edge_attr=edata, node_attr=ndata, **kwargs ) - indptr, indices, csc_edge_ids = _coo2csc( - part_local_src_id, part_local_dst_id - ) edge_attr = { attr: edge_attr[attr][csc_edge_ids] for attr in edge_attr.keys() } From 33c6ea8050474fe5036d6aceb75f25d78605a281 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 18 Sep 2024 02:15:26 +0000 Subject: [PATCH 35/37] change cast_various_to_minimum_dtype_gb --- python/dgl/distributed/partition.py | 70 +++++++++------ tools/distpartitioning/convert_partition.py | 97 ++++++++------------- tools/distpartitioning/data_shuffle.py | 36 ++++---- 
3 files changed, 98 insertions(+), 105 deletions(-) diff --git a/python/dgl/distributed/partition.py b/python/dgl/distributed/partition.py index 079ed8806a96..7c6fc7138199 100644 --- a/python/dgl/distributed/partition.py +++ b/python/dgl/distributed/partition.py @@ -1417,9 +1417,9 @@ def get_homogeneous(g, balance_ntypes): for name in g.edges[etype].data: if name in [EID, "inner_edge"]: continue - edge_feats[ - _etype_tuple_to_str(etype) + "/" + name - ] = F.gather_row(g.edges[etype].data[name], local_edges) + edge_feats[_etype_tuple_to_str(etype) + "/" + name] = ( + F.gather_row(g.edges[etype].data[name], local_edges) + ) else: for ntype in g.ntypes: if len(g.ntypes) > 1: @@ -1454,9 +1454,9 @@ def get_homogeneous(g, balance_ntypes): for name in g.edges[etype].data: if name in [EID, "inner_edge"]: continue - edge_feats[ - _etype_tuple_to_str(etype) + "/" + name - ] = F.gather_row(g.edges[etype].data[name], local_edges) + edge_feats[_etype_tuple_to_str(etype) + "/" + name] = ( + F.gather_row(g.edges[etype].data[name], local_edges) + ) # delete `orig_id` from ndata/edata del part.ndata["orig_id"] del part.edata["orig_id"] @@ -1502,9 +1502,9 @@ def get_homogeneous(g, balance_ntypes): for part_id, part in parts.items(): part_dir = os.path.join(out_path, "part" + str(part_id)) part_graph_file = os.path.join(part_dir, "graph.dgl") - part_metadata["part-{}".format(part_id)][ - "part_graph" - ] = os.path.relpath(part_graph_file, out_path) + part_metadata["part-{}".format(part_id)]["part_graph"] = ( + os.path.relpath(part_graph_file, out_path) + ) # save DGLGraph _save_dgl_graphs( part_graph_file, @@ -1600,8 +1600,6 @@ def _save_graph_gb(part_config, part_id, csc_graph): def cast_various_to_minimum_dtype_gb( - graph, - part_meta, num_parts, indptr, indices, @@ -1610,25 +1608,43 @@ def cast_various_to_minimum_dtype_gb( ntypes, node_attributes, edge_attributes, + part_meta=None, + graph=None, + edge_count=None, + node_count=None, + tot_edge_count=None, + tot_node_count=None, 
): """Cast various data to minimum dtype.""" + if graph is not None: + assert part_meta is not None + tot_edge_count = graph.num_edges() + tot_node_count = graph.num_nodes() + node_count = part_meta["num_nodes"] + edge_count = part_meta["num_edges"] + else: + assert tot_edge_count is not None + assert tot_node_count is not None + assert edge_count is not None + assert node_count is not None + # Cast 1: indptr. - indptr = _cast_to_minimum_dtype(graph.num_edges(), indptr) + indptr = _cast_to_minimum_dtype(tot_edge_count, indptr) # Cast 2: indices. - indices = _cast_to_minimum_dtype(graph.num_nodes(), indices) + indices = _cast_to_minimum_dtype(tot_node_count, indices) # Cast 3: type_per_edge. type_per_edge = _cast_to_minimum_dtype( len(etypes), type_per_edge, field=ETYPE ) # Cast 4: node/edge_attributes. predicates = { - NID: part_meta["num_nodes"], + NID: node_count, "part_id": num_parts, NTYPE: len(ntypes), - EID: part_meta["num_edges"], + EID: edge_count, ETYPE: len(etypes), - DGL2GB_EID: part_meta["num_edges"], - GB_DST_ID: part_meta["num_nodes"], + DGL2GB_EID: edge_count, + GB_DST_ID: node_count, } for attributes in [node_attributes, edge_attributes]: for key in attributes: @@ -1779,16 +1795,16 @@ def gb_convert_single_dgl_partition( ) indptr, indices, type_per_edge = cast_various_to_minimum_dtype_gb( - graph, - part_meta, - num_parts, - indptr, - indices, - type_per_edge, - etypes, - ntypes, - node_attributes, - edge_attributes, + graph=graph, + part_meta=part_meta, + num_parts=num_parts, + indptr=indptr, + indices=indices, + type_per_edge=type_per_edge, + etypes=etypes, + ntypes=ntypes, + node_attributes=node_attributes, + edge_attributes=edge_attributes, ) csc_graph = gb.fused_csc_sampling_graph( diff --git a/tools/distpartitioning/convert_partition.py b/tools/distpartitioning/convert_partition.py index e4f2368fb653..d351efcfa11a 100644 --- a/tools/distpartitioning/convert_partition.py +++ b/tools/distpartitioning/convert_partition.py @@ -5,17 +5,18 @@ import 
constants import dgl +import dgl.backend as F import dgl.graphbolt as gb import numpy as np import torch as th from dgl import EID, ETYPE, NID, NTYPE -import dgl.backend as F from dgl.distributed.constants import DGL2GB_EID, GB_DST_ID from dgl.distributed.partition import ( _cast_to_minimum_dtype, _etype_str_to_tuple, _etype_tuple_to_str, + cast_various_to_minimum_dtype_gb, RESERVED_FIELD_DTYPE, ) from utils import get_idranges, memory_snapshot @@ -262,7 +263,7 @@ def _create_edge_attr_gb( is_homo = _is_homogeneous(ntypes, etypes) edge_type_to_id = ( - {gb.etype_tuple_to_str(('_N','_E','_N')) : 0} + {gb.etype_tuple_to_str(("_N", "_E", "_N")): 0} if is_homo else { gb.etype_tuple_to_str(etype): etid @@ -320,47 +321,6 @@ def remove_attr_gb( return edata, ndata -def cast_various_to_minimum_dtype_gb( - node_count, - edge_count, - num_parts, - indptr, - indices, - type_per_edge, - etypes, - ntypes, - node_attributes, - edge_attributes, -): - """Cast various data to minimum dtype.""" - # Cast 1: indptr. - indptr = _cast_to_minimum_dtype(edge_count, indptr) - # Cast 2: indices. - indices = _cast_to_minimum_dtype(node_count, indices) - # Cast 3: type_per_edge. - type_per_edge = _cast_to_minimum_dtype( - len(etypes), type_per_edge, field=ETYPE - ) - # Cast 4: node/edge_attributes. 
- predicates = { - NID: node_count, - "part_id": num_parts, - NTYPE: len(ntypes), - EID: edge_count, - ETYPE: len(etypes), - DGL2GB_EID: edge_count, - GB_DST_ID: node_count, - } - for attributes in [node_attributes, edge_attributes]: - for key in attributes: - if key not in predicates: - continue - attributes[key] = _cast_to_minimum_dtype( - predicates[key], attributes[key], field=key - ) - return indptr, indices, type_per_edge - - def _process_partition_gb( node_attr, edge_attr, @@ -378,22 +338,26 @@ def _process_partition_gb( node_attr[k] = F.astype(node_attr[k], dtype) if k in edge_attr: edge_attr[k] = F.astype(edge_attr[k], dtype) - - indptr,indices,edge_ids=_coo2csc(src_ids,dst_ids) + + indptr, indices, edge_ids = _coo2csc(src_ids, dst_ids) if sort_etypes: split_size = th.diff(indptr) split_indices = th.split(type_per_edge, tuple(split_size), dim=0) - sorted_idxs=[] + sorted_idxs = [] for split_indice in split_indices: sorted_idxs.append(split_indice.sort()[1]) sorted_idx = th.cat(sorted_idxs, dim=0) - sorted_idx=th.repeat_interleave(indptr[:-1], split_size, dim=0)+sorted_idx - + sorted_idx = ( + th.repeat_interleave(indptr[:-1], split_size, dim=0) + sorted_idx + ) + return indptr, indices, edge_ids def create_graph_object( + tot_node_count, + tot_edge_count, node_count, edge_count, num_parts, @@ -457,10 +421,14 @@ def create_graph_object( Parameters: ----------- - node_count : int + tot_node_count : int the number of all nodes - edge_count : int + tot_edge_count : int the number of all edges + node_count : int + the number of nodes in partition + edge_count : int + the number of edges in partition graph_formats : str the format of graph num_parts : int @@ -744,7 +712,12 @@ def create_graph_object( sort_etypes = len(etypes_map) > 1 indptr, indices, csc_edge_ids = _process_partition_gb( - ndata, edata, type_per_edge, part_local_src_id, part_local_dst_id,sort_etypes + ndata, + edata, + type_per_edge, + part_local_src_id, + part_local_dst_id, + sort_etypes, ) 
edge_attr, node_attr = remove_attr_gb( edge_attr=edata, node_attr=ndata, **kwargs @@ -753,16 +726,18 @@ def create_graph_object( attr: edge_attr[attr][csc_edge_ids] for attr in edge_attr.keys() } cast_various_to_minimum_dtype_gb( - node_count, - edge_count, - num_parts, - indptr, - indices, - type_per_edge, - etypes, - ntypes, - node_attr, - edge_attr, + node_count=node_count, + edge_count=edge_count, + tot_node_count=tot_node_count, + tot_edge_count=tot_edge_count, + num_parts=num_parts, + indptr=indptr, + indices=indices, + type_per_edge=type_per_edge, + etypes=etypes, + ntypes=ntypes, + node_attributes=node_attr, + edge_attributes=edge_attr, ) part_graph = gb.fused_csc_sampling_graph( csc_indptr=indptr, diff --git a/tools/distpartitioning/data_shuffle.py b/tools/distpartitioning/data_shuffle.py index 273db1d9914d..587e9ab718d2 100644 --- a/tools/distpartitioning/data_shuffle.py +++ b/tools/distpartitioning/data_shuffle.py @@ -285,21 +285,21 @@ def exchange_edge_data(rank, world_size, num_parts, edge_data, id_lookup): local_etype_ids.append(rcvd_edge_data[:, 3]) local_eids.append(rcvd_edge_data[:, 4]) - edge_data[ - constants.GLOBAL_SRC_ID + "/" + str(local_part_id) - ] = np.concatenate(local_src_ids) - edge_data[ - constants.GLOBAL_DST_ID + "/" + str(local_part_id) - ] = np.concatenate(local_dst_ids) - edge_data[ - constants.GLOBAL_TYPE_EID + "/" + str(local_part_id) - ] = np.concatenate(local_type_eids) - edge_data[ - constants.ETYPE_ID + "/" + str(local_part_id) - ] = np.concatenate(local_etype_ids) - edge_data[ - constants.GLOBAL_EID + "/" + str(local_part_id) - ] = np.concatenate(local_eids) + edge_data[constants.GLOBAL_SRC_ID + "/" + str(local_part_id)] = ( + np.concatenate(local_src_ids) + ) + edge_data[constants.GLOBAL_DST_ID + "/" + str(local_part_id)] = ( + np.concatenate(local_dst_ids) + ) + edge_data[constants.GLOBAL_TYPE_EID + "/" + str(local_part_id)] = ( + np.concatenate(local_type_eids) + ) + edge_data[constants.ETYPE_ID + "/" + 
str(local_part_id)] = ( + np.concatenate(local_etype_ids) + ) + edge_data[constants.GLOBAL_EID + "/" + str(local_part_id)] = ( + np.concatenate(local_eids) + ) # Check if the data was exchanged correctly local_edge_count = 0 @@ -1121,7 +1121,6 @@ def gen_dist_partitions(rank, world_size, params): ) id_map = dgl.distributed.id_map.IdMap(global_nid_ranges) id_lookup.set_idMap(id_map) - # read input graph files and augment these datastructures with # appropriate information (global_nid and owner process) for node and edge data ( @@ -1315,6 +1314,8 @@ def prepare_local_data(src_data, local_part_id): ) local_node_data = prepare_local_data(node_data, local_part_id) local_edge_data = prepare_local_data(edge_data, local_part_id) + tot_node_count = sum(schema_map["num_nodes_per_type"]) + tot_edge_count = sum(schema_map["num_edges_per_type"]) ( graph_obj, ntypes_map_val, @@ -1324,9 +1325,10 @@ def prepare_local_data(src_data, local_part_id): orig_nids, orig_eids, ) = create_graph_object( + tot_node_count, + tot_edge_count, node_count, edge_count, - graph_formats, params.num_parts, schema_map, rank + local_part_id * world_size, From 3d44eb951f32a275adc45272076198dd76641bde Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 18 Sep 2024 02:22:29 +0000 Subject: [PATCH 36/37] change format --- python/dgl/distributed/partition.py | 18 ++++++++-------- tools/distpartitioning/data_shuffle.py | 30 +++++++++++++------------- 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/python/dgl/distributed/partition.py b/python/dgl/distributed/partition.py index 7c6fc7138199..48005ffb4d27 100644 --- a/python/dgl/distributed/partition.py +++ b/python/dgl/distributed/partition.py @@ -1417,9 +1417,9 @@ def get_homogeneous(g, balance_ntypes): for name in g.edges[etype].data: if name in [EID, "inner_edge"]: continue - edge_feats[_etype_tuple_to_str(etype) + "/" + name] = ( - F.gather_row(g.edges[etype].data[name], local_edges) - ) + edge_feats[ + _etype_tuple_to_str(etype) + "/" + name + 
] = F.gather_row(g.edges[etype].data[name], local_edges) else: for ntype in g.ntypes: if len(g.ntypes) > 1: @@ -1454,9 +1454,9 @@ def get_homogeneous(g, balance_ntypes): for name in g.edges[etype].data: if name in [EID, "inner_edge"]: continue - edge_feats[_etype_tuple_to_str(etype) + "/" + name] = ( - F.gather_row(g.edges[etype].data[name], local_edges) - ) + edge_feats[ + _etype_tuple_to_str(etype) + "/" + name + ] = F.gather_row(g.edges[etype].data[name], local_edges) # delete `orig_id` from ndata/edata del part.ndata["orig_id"] del part.edata["orig_id"] @@ -1502,9 +1502,9 @@ def get_homogeneous(g, balance_ntypes): for part_id, part in parts.items(): part_dir = os.path.join(out_path, "part" + str(part_id)) part_graph_file = os.path.join(part_dir, "graph.dgl") - part_metadata["part-{}".format(part_id)]["part_graph"] = ( - os.path.relpath(part_graph_file, out_path) - ) + part_metadata["part-{}".format(part_id)][ + "part_graph" + ] = os.path.relpath(part_graph_file, out_path) # save DGLGraph _save_dgl_graphs( part_graph_file, diff --git a/tools/distpartitioning/data_shuffle.py b/tools/distpartitioning/data_shuffle.py index 587e9ab718d2..6800064a2b0b 100644 --- a/tools/distpartitioning/data_shuffle.py +++ b/tools/distpartitioning/data_shuffle.py @@ -285,21 +285,21 @@ def exchange_edge_data(rank, world_size, num_parts, edge_data, id_lookup): local_etype_ids.append(rcvd_edge_data[:, 3]) local_eids.append(rcvd_edge_data[:, 4]) - edge_data[constants.GLOBAL_SRC_ID + "/" + str(local_part_id)] = ( - np.concatenate(local_src_ids) - ) - edge_data[constants.GLOBAL_DST_ID + "/" + str(local_part_id)] = ( - np.concatenate(local_dst_ids) - ) - edge_data[constants.GLOBAL_TYPE_EID + "/" + str(local_part_id)] = ( - np.concatenate(local_type_eids) - ) - edge_data[constants.ETYPE_ID + "/" + str(local_part_id)] = ( - np.concatenate(local_etype_ids) - ) - edge_data[constants.GLOBAL_EID + "/" + str(local_part_id)] = ( - np.concatenate(local_eids) - ) + edge_data[ + 
constants.GLOBAL_SRC_ID + "/" + str(local_part_id) + ] = np.concatenate(local_src_ids) + edge_data[ + constants.GLOBAL_DST_ID + "/" + str(local_part_id) + ] = np.concatenate(local_dst_ids) + edge_data[ + constants.GLOBAL_TYPE_EID + "/" + str(local_part_id) + ] = np.concatenate(local_type_eids) + edge_data[ + constants.ETYPE_ID + "/" + str(local_part_id) + ] = np.concatenate(local_etype_ids) + edge_data[ + constants.GLOBAL_EID + "/" + str(local_part_id) + ] = np.concatenate(local_eids) # Check if the data was exchanged correctly local_edge_count = 0 From 305d7f67ba667ee1636f5a8f5d6b7719ed108b25 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 18 Sep 2024 03:23:41 +0000 Subject: [PATCH 37/37] change convert_partition.py --- tools/distpartitioning/convert_partition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/distpartitioning/convert_partition.py b/tools/distpartitioning/convert_partition.py index d351efcfa11a..5013b6d40f20 100644 --- a/tools/distpartitioning/convert_partition.py +++ b/tools/distpartitioning/convert_partition.py @@ -352,7 +352,7 @@ def _process_partition_gb( th.repeat_interleave(indptr[:-1], split_size, dim=0) + sorted_idx ) - return indptr, indices, edge_ids + return indptr, indices[sorted_idx], edge_ids[sorted_idx] def create_graph_object(