-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtransform_dataset.py
68 lines (50 loc) · 2.5 KB
/
transform_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import pickle
import argparse
import numpy as np
from util.dataset_io import save_dataset
from util.embeddings_io import load_embeddings
# Command-line interface of the script.
# --graph_embeddings_file is mandatory: the main section always passes it to
# load_embeddings, so it must be a real file name. (It was previously declared
# with default=True, which handed the boolean True to the loader whenever the
# flag was omitted.)
parser = argparse.ArgumentParser(description='Transform the original dataset to use the precalculated embeddings.')
parser.add_argument('--input_file', required=True,
                    help='file containing the dataset to be transformed')
parser.add_argument('--output_name', required=True,
                    help='name of the transformed dataset')
parser.add_argument('--graph_embeddings_file', required=True,
                    help='name of the file containing the graph embeddings')
# Optional: when omitted, args.text_embeddings_file is None and only the graph
# embeddings are used.
parser.add_argument('--text_embeddings_file',
                    help='name of the file containing the text embeddings')
# Boolean switch: False unless the flag is given on the command line.
parser.add_argument('--concatenate', action='store_true',
                    help='if set, will concatenate the two embeddings instead of combining them')
def _get_vector(key, embeddings, vec_shape):
    """Return the vector stored for ``key``, or a random one if it is missing.

    When ``key`` is present in ``embeddings`` the stored value is returned
    as-is (no copy). Otherwise a new array of shape ``vec_shape`` is sampled
    with ``np.random.normal`` using its default parameters; nothing is written
    back into ``embeddings``.
    """
    if key not in embeddings:
        return np.random.normal(size=vec_shape)
    return embeddings[key]
if __name__ == "__main__":
    # Script entry point: read node pairs from --input_file, turn each pair
    # into a feature vector built from the precalculated embeddings, and save
    # the result (plus labels, when present) with save_dataset.
    args = parser.parse_args()

    # Load graph embeddings (keys are converted with int via key_transform).
    graph_emb = load_embeddings(args.graph_embeddings_file, key_transform=int)
    first_graph_vec = next(iter(graph_emb.values()), None)
    if first_graph_vec is None:
        raise ValueError("graph embeddings file contains no embeddings")
    # Shape of the random fallback used for nodes without a graph embedding.
    graph_shape = first_graph_vec.shape

    # Load text embeddings (optional).
    text_emb, text_shape = None, None
    if args.text_embeddings_file is not None:
        text_emb = load_embeddings(args.text_embeddings_file, key_transform=int)
        # Text embeddings need not have the same dimensionality as the graph
        # ones, so missing text vectors must be sampled with the text shape.
        # Falls back to the graph shape only if no text embedding exists.
        first_text_vec = next(iter(text_emb.values()), None)
        text_shape = graph_shape if first_text_vec is None else first_text_vec.shape

    # Transform dataset
    X, y = [], []
    with open(args.input_file, "r") as f:
        for line in f:
            # Each line is "<src> <tgt> [<label>]", whitespace separated.
            fields = line.split()
            if not fields:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            src, tgt = int(fields[0]), int(fields[1])

            src_embedding = _get_vector(src, graph_emb, graph_shape)
            tgt_embedding = _get_vector(tgt, graph_emb, graph_shape)
            if text_emb is not None:
                # Append each node's text embedding to its graph embedding;
                # axis=None flattens both inputs before joining.
                src_embedding = np.concatenate([src_embedding, _get_vector(src, text_emb, text_shape)], axis=None)
                tgt_embedding = np.concatenate([tgt_embedding, _get_vector(tgt, text_emb, text_shape)], axis=None)

            # --concatenate is a store_true flag, i.e. always False or True
            # (never None), so branch on its truth value.
            if args.concatenate:
                # Concatenate the two embeddings
                X.append(np.concatenate([src_embedding, tgt_embedding], axis=None))
            else:
                # Use the Hadamard operator to combine the two embeddings
                X.append(np.multiply(src_embedding, tgt_embedding))

            # The label column is optional.
            if len(fields) >= 3:
                y.append(int(fields[2]))

    X = np.array(X)
    # No label seen in the whole file -> pass None instead of an empty array.
    y = np.array(y).ravel() if y else None

    # Save datasets
    save_dataset(args.output_name, X, y)