-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathprepare_dataset.py
74 lines (53 loc) · 1.92 KB
/
prepare_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# Code referred from https://github.com/tkipf/relational-gcn/blob/master/rgcn/prepare_dataset.py
from data_utils import load_data
from utils import *
import pickle as pkl
import os
import sys
import time
import argparse
from sklearn.preprocessing import normalize
# Command-line interface: a single option selects which dataset to prepare.
ap = argparse.ArgumentParser()
ap.add_argument(
    "-d",
    "--dataset",
    type=str,
    default="aifb",
    help="Dataset string ('aifb', 'mutag', 'bgs', 'am')",
)
args = vars(ap.parse_args())  # Namespace -> plain dict keyed by option name
print(args)

# Define parameters
NUM_GC_LAYERS = 2  # Number of graph convolutional layers
DATASET = args["dataset"]
# Get data
(
    A,
    X,
    y,
    labeled_nodes_idx,
    train_idx,
    test_idx,
    rel_dict,
    train_names,
    test_names,
) = load_data(DATASET)

# Build a label for every adjacency slot. Relation index v names slot 2*v
# with the relation key and slot 2*v + 1 with the key plus an '_INV' suffix;
# slots that receive no name keep their integer position as a placeholder.
rel_list = list(range(len(A)))
for rel_name, rel_idx in rel_dict.items():
    even_slot = rel_idx * 2
    if even_slot < len(A):
        rel_list[even_slot] = rel_name
        rel_list[even_slot + 1] = rel_name + '_INV'

num_nodes = A[0].shape[0]

# add identity matrix (one extra adjacency, stored in CSR format)
A.append(sp.identity(num_nodes).tocsr())
support = len(A)

edge_totals = [adj.sum() for adj in A]
print("Relations used and their frequencies" + str(edge_totals))
print("Calculating level sets...")
t = time.time()

# Get level sets (used for memory optimization): level 0 is the set of
# labeled nodes, level 1 is the union of the sets produced by the first
# step of the relational BFS generator.
bfs_generator = bfs_relational(A, labeled_nodes_idx)
lvls = [set(labeled_nodes_idx)]
first_step = next(bfs_generator)
lvls.append(set.union(*first_step))

print(f"Done! Elapsed time {time.time() - t}")
# Delete unnecessary rows in adjacencies for memory efficiency: every node
# index that is in neither of the two level sets is collected in `todel` and
# handed to csr_zero_rows for each adjacency matrix.
# NOTE(review): csr_zero_rows comes from utils and is assumed to modify the
# matrix in place (its return value is not used) - confirm.
todel = list(set(range(num_nodes)) - set(lvls[0]).union(lvls[1]))
for i, adj in enumerate(A):
    csr_zero_rows(adj, todel)
    # L1-normalise each row. The returned matrix must be kept: with
    # copy=False scikit-learn only *tries* to work in place and returns a new
    # matrix whenever it has to convert the input (e.g. integer dtype or a
    # non-CSR format). Discarding the result would then leave A[i]
    # unnormalised in the saved pickle.
    A[i] = normalize(adj, norm='l1', axis=1, copy=False)
# Payload written to disk: the adjacency list, the labels and the train/test
# index splits.
# NOTE: a dense identity feature matrix (np.eye(A[0].shape[0]) under the key
# 'feat') is disabled and therefore not part of the pickle.
data = {
    'A': A,
    'y': y,
    'train_idx': train_idx,
    'test_idx': test_idx,
}

# Save as "<DATASET>.pickle" in the directory that contains this script
# (resolved from sys.argv[0]); os.path.join supplies the platform separator
# instead of a hard-coded '/'.
dirname = os.path.dirname(os.path.realpath(sys.argv[0]))
with open(os.path.join(dirname, DATASET + '.pickle'), 'wb') as f:
    pkl.dump(data, f, pkl.HIGHEST_PROTOCOL)