Skip to content

Commit 0734e33

Browse files
CfromBUUbuntuUbuntu
authored
[DistGB] save as graphbolt graph directly after partition test case (#7724)
Co-authored-by: Ubuntu <ubuntu@ip-172-31-8-126.us-west-2.compute.internal> Co-authored-by: Ubuntu <ubuntu@ip-172-31-52-174.us-west-2.compute.internal>
1 parent 4c86533 commit 0734e33

File tree

2 files changed

+905
-218
lines changed

2 files changed

+905
-218
lines changed

python/dgl/distributed/partition.py

+34-17
Original file line numberDiff line numberDiff line change
@@ -109,28 +109,45 @@ def _save_graphs(filename, g_list, formats=None, sort_etypes=False):
109109
save_graphs(filename, g_list, formats=formats)
110110

111111

112-
def _get_inner_node_mask(graph, ntype_id):
113-
if NTYPE in graph.ndata:
114-
dtype = F.dtype(graph.ndata["inner_node"])
115-
return (
116-
graph.ndata["inner_node"]
117-
* F.astype(graph.ndata[NTYPE] == ntype_id, dtype)
118-
== 1
112+
def _get_inner_node_mask(graph, ntype_id, gpb=None):
113+
ndata = (
114+
graph.node_attributes
115+
if isinstance(graph, gb.FusedCSCSamplingGraph)
116+
else graph.ndata
117+
)
118+
assert "inner_node" in ndata, "'inner_node' is not in nodes' data"
119+
if NTYPE in ndata or gpb is not None:
120+
ntype = (
121+
gpb.map_to_per_ntype(ndata[NID])[0]
122+
if gpb is not None
123+
else ndata[NTYPE]
119124
)
125+
dtype = F.dtype(ndata["inner_node"])
126+
return ndata["inner_node"] * F.astype(ntype == ntype_id, dtype) == 1
120127
else:
121-
return graph.ndata["inner_node"] == 1
128+
return ndata["inner_node"] == 1
122129

123130

124-
def _get_inner_edge_mask(graph, etype_id):
125-
if ETYPE in graph.edata:
126-
dtype = F.dtype(graph.edata["inner_edge"])
127-
return (
128-
graph.edata["inner_edge"]
129-
* F.astype(graph.edata[ETYPE] == etype_id, dtype)
130-
== 1
131-
)
131+
def _get_inner_edge_mask(
132+
graph,
133+
etype_id,
134+
):
135+
edata = (
136+
graph.edge_attributes
137+
if isinstance(graph, gb.FusedCSCSamplingGraph)
138+
else graph.edata
139+
)
140+
assert "inner_edge" in edata, "'inner_edge' is not in edges' data"
141+
etype = (
142+
graph.type_per_edge
143+
if isinstance(graph, gb.FusedCSCSamplingGraph)
144+
else (graph.edata[ETYPE] if ETYPE in graph.edata else None)
145+
)
146+
if etype is not None:
147+
dtype = F.dtype(edata["inner_edge"])
148+
return edata["inner_edge"] * F.astype(etype == etype_id, dtype) == 1
132149
else:
133-
return graph.edata["inner_edge"] == 1
150+
return edata["inner_edge"] == 1
134151

135152

136153
def _get_part_ranges(id_ranges):

0 commit comments

Comments
 (0)