forked from working-yuhao/DEAL
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils.py
159 lines (120 loc) · 4.86 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import random
import argparse
from typing import Union, Optional
import multiprocessing as mp
import networkx as nx
from torch_geometric.data import Data
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score
import torch
import numpy as np
# Module-wide default torch device: CUDA when torch reports it as available,
# otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def to_numpy_cpu(*x: torch.Tensor) -> Union[list, torch.Tensor]:
    """Convert any torch tensors among the arguments to NumPy arrays.

    Each ``torch.Tensor`` argument is detached, moved to the CPU and turned
    into a NumPy array; arguments of any other type are passed through
    untouched.

    :param x: values to convert.
    :return: the converted value itself when exactly one argument was given,
        otherwise a list of the converted values in argument order (an empty
        list when called with no arguments).
    """
    converted = [
        item.detach().cpu().numpy() if isinstance(item, torch.Tensor) else item
        for item in x
    ]
    # A lone argument is unwrapped so callers do not have to index the result.
    return converted[0] if len(converted) == 1 else converted
def single_source_shortest_path_length_range(graph, node_range, cutoff):
    """Run a single-source shortest-path-length search from each given node.

    :param graph: graph handed to ``nx.single_source_shortest_path_length``.
    :param node_range: iterable of source nodes to process.
    :param cutoff: cutoff forwarded unchanged to the NetworkX search.
    :return: dict mapping every source node in ``node_range`` to the result
        NetworkX returned for it.
    """
    return {
        source: nx.single_source_shortest_path_length(graph, source, cutoff)
        for source in node_range
    }
def merge_dicts(dicts):
    """Fold a sequence of dictionaries into one new dictionary.

    :param dicts: iterable of dictionaries, merged in iteration order.
    :return: a new dict holding every key; when a key occurs more than once
        the value from the later dictionary wins. The inputs are not modified.
    """
    merged = {}
    for mapping in dicts:
        merged.update(mapping)
    return merged
def all_pairs_shortest_path_length_parallel(graph, cutoff=None):
    """Compute shortest-path lengths from every node using a process pool.

    The node list is shuffled, cut into contiguous chunks (one per worker),
    and each chunk is handed to ``single_source_shortest_path_length_range``
    in a worker process. The per-chunk dictionaries are then merged.

    :param graph: graph exposing a ``nodes`` attribute; it is sent to the
        worker processes together with each chunk of nodes.
    :param cutoff: forwarded unchanged to the per-source search.
    :return: dict mapping every node of ``graph`` to the per-source result
        produced by the worker function; empty for a graph without nodes.
    """
    nodes = list(graph.nodes)
    if not nodes:
        # Nothing to compute, so do not spawn worker processes at all.
        return {}
    # Shuffling uses the global ``random`` state (presumably to even out the
    # per-chunk workload).
    random.shuffle(nodes)

    num_nodes = len(nodes)
    # Target ~80% of the cores, but always at least one worker (int()
    # truncates to 0 on a single-core host) and never more workers than nodes.
    num_workers = max(1, min(int(mp.cpu_count() * 0.8), num_nodes))
    # Integer arithmetic guarantees the last bound is exactly ``num_nodes``.
    # The float form ``int(n / k * k)`` can land on ``n - 1`` (for example
    # 1 / 49 * 49 < 1) and silently drop trailing nodes.
    bounds = [num_nodes * i // num_workers for i in range(num_workers + 1)]

    # The context manager tears the pool down even if a worker raises; every
    # result has already been collected by the time the block exits.
    with mp.Pool(processes=num_workers) as pool:
        pending = [
            pool.apply_async(
                single_source_shortest_path_length_range,
                args=(graph, nodes[start:stop], cutoff),
            )
            for start, stop in zip(bounds, bounds[1:])
        ]
        partial_results = [task.get() for task in pending]
    return merge_dicts(partial_results)
def precompute_dist_data(edge_index, num_nodes, approximate=0) -> np.ndarray:
    """Build a dense matrix of inverse shortest-path distances.

    Here dist is 1 / (real_dist + 1): higher actually means closer, the
    diagonal is 1, and 0 means disconnected (or not reached within the cutoff
    when ``approximate`` is used).

    :param edge_index: edge container supporting ``transpose(1, 0).tolist()``
        that yields one entry per edge (presumably shaped ``(2, num_edges)``;
        confirm against the caller). Node ids are used directly as row/column
        indices of the result, so they must be valid indices into a
        ``num_nodes`` x ``num_nodes`` array.
    :param num_nodes: side length of the returned square matrix.
    :param approximate: when greater than 0 it is passed as the search cutoff;
        otherwise no cutoff is applied.
    :return: ``(num_nodes, num_nodes)`` float array of inverse distances.
    """
    graph = nx.Graph()
    graph.add_edges_from(edge_index.transpose(1, 0).tolist())

    dists_array = np.zeros((num_nodes, num_nodes))
    # Every node is at distance 0 from itself, including nodes that do not
    # appear in any edge and are therefore absent from ``graph``.
    np.fill_diagonal(dists_array, 1)

    cutoff = approximate if approximate > 0 else None
    dists_dict = all_pairs_shortest_path_length_parallel(graph, cutoff=cutoff)
    for node_i in graph.nodes():
        # Only reachable targets are present in the per-source dict, so
        # walking its items avoids probing every node of the graph.
        for node_j, dist in dists_dict[node_i].items():
            dists_array[node_i, node_j] = 1 / (dist + 1)
    return dists_array
def score_link_prediction(labels, scores):
    """Score link predictions with ROC-AUC and average precision.

    Both inputs go through ``to_numpy_cpu`` first, so torch tensors are
    accepted as well as array-likes.

    :param labels: ground-truth labels.
    :param scores: predicted scores for the same samples.
    :return: tuple ``(roc_auc, average_precision)``.
    """
    y_true, y_score = to_numpy_cpu(labels, scores)
    auc = roc_auc_score(y_true, y_score)
    ap = average_precision_score(y_true, y_score)
    return auc, ap
def inductive_eval(cmodel, nodes, gt_labels, X, lambdas = (0, 1, 1)):
    """Evaluate link prediction using attribute-derived embeddings only.

    Node embeddings are produced by ``cmodel.attr_emb`` from ``X``; each pair
    is scored as ``attr_layer * lambdas[1] + inter_layer * lambdas[2]``.
    ``lambdas[0]`` is not used here.

    :param cmodel: model exposing ``attr_emb``, ``attr_layer`` and
        ``inter_layer``.
    :param nodes: two-column index container; column 0 and column 1 select
        the two endpoints of each candidate pair.
    :param gt_labels: ground-truth labels for the pairs.
    :param X: node attributes, wrapped in a ``Data`` object for ``attr_emb``.
    :param lambdas: weights of the score components.
    :return: ``(roc_auc, average_precision)`` from ``score_link_prediction``.
    """
    attr_node_emb = cmodel.attr_emb(Data(X, None))
    src_idx, dst_idx = nodes[:, 0], nodes[:, 1]
    src_emb = attr_node_emb[src_idx]
    dst_emb = attr_node_emb[dst_idx]

    attr_score = cmodel.attr_layer(src_emb, dst_emb) * lambdas[1]
    # The second endpoint for the inter layer is read from a copy of the
    # embedding matrix rather than from the matrix itself.
    emb_copy = attr_node_emb.clone()
    inter_score = cmodel.inter_layer(src_emb, emb_copy[dst_idx]) * lambdas[2]
    scores = attr_score + inter_score

    # Multi-column output: keep the softmax probability of column 1.
    if len(scores.shape) > 1:
        scores = scores.softmax(dim=1)[:, 1]
    return score_link_prediction(gt_labels, scores.detach().cpu().numpy())
def transductive_eval(cmodel, edge_index, gt_labels, data, lambdas=(1, 1, 1)):
    """Evaluate link prediction through the model's own ``evaluate`` method.

    :param cmodel: model exposing ``evaluate(edge_index, data, lambdas)``.
    :param edge_index: candidate edges to score.
    :param gt_labels: ground-truth labels for those edges.
    :param data: data object forwarded to ``cmodel.evaluate``.
    :param lambdas: component weights forwarded to ``cmodel.evaluate``.
    :return: ``(roc_auc, average_precision)`` from ``score_link_prediction``.
    """
    scores = cmodel.evaluate(edge_index, data, lambdas)
    has_class_axis = len(scores.shape) > 1
    if has_class_axis:
        # Multi-column output: keep the softmax probability of column 1.
        scores = scores.softmax(dim=1)[:, 1]
    return score_link_prediction(gt_labels, scores)
def detailed_eval(model, test_data, gt_labels, sp_M, evaluate, nodes_keep=None, verbose=False, lambdas=(1, 1, 1)):
    """Run ``evaluate`` once per weight configuration and collect the results.

    Configurations, in order: ``'Full '`` (the given ``lambdas``), ``'Inter'``
    ``(0, 0, 1)``, ``'Attr '`` ``(0, 1, 0)`` if ``lambdas[1]`` is truthy, and
    ``'Node '`` ``(1, 0, 0)`` if ``lambdas[0]`` is truthy.

    :param model: forwarded to ``evaluate``.
    :param test_data: forwarded to ``evaluate``.
    :param gt_labels: forwarded to ``evaluate``.
    :param sp_M: forwarded to ``evaluate``.
    :param evaluate: callable invoked as
        ``evaluate(model, test_data, gt_labels, sp_M, weights)``, or with
        ``nodes_keep`` inserted before ``weights`` when ``nodes_keep`` is
        given.
    :param nodes_keep: optional extra argument for ``evaluate``; when it is
        not ``None`` the ``'Node '`` configuration is skipped.
    :param verbose: print each result; this formats the result with two
        ``%.4f`` fields, so ``evaluate`` must return a pair of numbers.
    :param lambdas: weights of the full configuration.
    :return: dict mapping configuration name to the value ``evaluate``
        returned for it.
    """
    configs = {'Full ': lambdas, 'Inter': (0, 0, 1)}
    if lambdas[1]:
        configs['Attr '] = (0, 1, 0)
    if lambdas[0]:
        configs['Node '] = (1, 0, 0)

    restricted = nodes_keep is not None
    scores = {}
    for label, weights in configs.items():
        if restricted:
            if label == 'Node ':
                # The node-only configuration is not evaluated on a subset.
                continue
            scores[label] = evaluate(model, test_data, gt_labels, sp_M, nodes_keep, weights)
        else:
            scores[label] = evaluate(model, test_data, gt_labels, sp_M, weights)
        if verbose:
            print(label + ' ROC-AUC:%.4f AP:%.4f' % scores[label])
    return scores
def seed_everything(seed: Optional[int] = None) -> int:
    """Seed Python's ``random``, NumPy and torch (CPU and all CUDA devices).

    :param seed: seed to apply. When ``None`` a fresh seed is drawn from the
        operating system's entropy source.
    :return: the seed that was actually applied, so a run can be reproduced.
    """
    if seed is None:
        # torch.manual_seed does not accept None, so the documented default
        # used to raise. Pick a concrete value inside the range NumPy's
        # legacy seeding accepts, [0, 2**32).
        seed = random.SystemRandom().randrange(2 ** 32)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    return seed
def str2bool(v):
    """Parse a command-line style boolean.

    :param v: an actual ``bool`` (returned unchanged) or a string compared
        case-insensitively against the accepted spellings.
    :return: ``True`` for yes/true/t/y/1, ``False`` for no/false/f/n/0.
    :raises argparse.ArgumentTypeError: for any other string.
    """
    if isinstance(v, bool):
        return v
    token = v.lower()
    if token in ('yes', 'true', 't', 'y', '1'):
        return True
    if token in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')