@@ -109,28 +109,45 @@ def _save_graphs(filename, g_list, formats=None, sort_etypes=False):
109
109
save_graphs (filename , g_list , formats = formats )
110
110
111
111
112
- def _get_inner_node_mask (graph , ntype_id ):
113
- if NTYPE in graph .ndata :
114
- dtype = F .dtype (graph .ndata ["inner_node" ])
115
- return (
116
- graph .ndata ["inner_node" ]
117
- * F .astype (graph .ndata [NTYPE ] == ntype_id , dtype )
118
- == 1
112
+ def _get_inner_node_mask (graph , ntype_id , gpb = None ):
113
+ ndata = (
114
+ graph .node_attributes
115
+ if isinstance (graph , gb .FusedCSCSamplingGraph )
116
+ else graph .ndata
117
+ )
118
+ assert "inner_node" in ndata , "'inner_node' is not in nodes' data"
119
+ if NTYPE in ndata or gpb is not None :
120
+ ntype = (
121
+ gpb .map_to_per_ntype (ndata [NID ])[0 ]
122
+ if gpb is not None
123
+ else ndata [NTYPE ]
119
124
)
125
+ dtype = F .dtype (ndata ["inner_node" ])
126
+ return ndata ["inner_node" ] * F .astype (ntype == ntype_id , dtype ) == 1
120
127
else :
121
- return graph . ndata ["inner_node" ] == 1
128
+ return ndata ["inner_node" ] == 1
122
129
123
130
124
- def _get_inner_edge_mask (graph , etype_id ):
125
- if ETYPE in graph .edata :
126
- dtype = F .dtype (graph .edata ["inner_edge" ])
127
- return (
128
- graph .edata ["inner_edge" ]
129
- * F .astype (graph .edata [ETYPE ] == etype_id , dtype )
130
- == 1
131
- )
131
+ def _get_inner_edge_mask (
132
+ graph ,
133
+ etype_id ,
134
+ ):
135
+ edata = (
136
+ graph .edge_attributes
137
+ if isinstance (graph , gb .FusedCSCSamplingGraph )
138
+ else graph .edata
139
+ )
140
+ assert "inner_edge" in edata , "'inner_edge' is not in edges' data"
141
+ etype = (
142
+ graph .type_per_edge
143
+ if isinstance (graph , gb .FusedCSCSamplingGraph )
144
+ else (graph .edata [ETYPE ] if ETYPE in graph .edata else None )
145
+ )
146
+ if etype is not None :
147
+ dtype = F .dtype (edata ["inner_edge" ])
148
+ return edata ["inner_edge" ] * F .astype (etype == etype_id , dtype ) == 1
132
149
else :
133
- return graph . edata ["inner_edge" ] == 1
150
+ return edata ["inner_edge" ] == 1
134
151
135
152
136
153
def _get_part_ranges (id_ranges ):
0 commit comments