-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathdatasets.py
59 lines (43 loc) · 1.88 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import os
from config import ROOT, DATASETS_HOME
# Identifiers of the supported datasets. These are presumably passed as the
# dataset name to Dataset / home_folder_for, which treat the name as a folder
# name, so the exact spelling (case, hyphens) matters.
FB15K = "FB15k"
FB15K_237 = "FB15k-237"
WN18 = "WN18"
WN18RR = "WN18RR"
YAGO3_10 = "YAGO3-10"

# Every supported dataset identifier, in a fixed order.
ALL_DATASET_NAMES = [
    FB15K,
    FB15K_237,
    WN18,
    WN18RR,
    YAGO3_10,
]
class Dataset:
    """A dataset of (head, relationship, tail) triples read from text files.

    On construction, reads ``train.txt``, ``valid.txt`` and ``test.txt`` from
    the folder ``os.path.join(ROOT, "datasets", name)`` and collects every
    entity (head or tail) and every relationship seen across all three splits.

    Attributes:
        name: the dataset name, also used as the folder name.
        home: path of the dataset folder.
        train_path, valid_path, test_path: paths of the three split files.
        entities: set of all heads and tails seen in any split.
        relationships: set of all relationships seen in any split.
        train_triples, valid_triples, test_triples: lists of
            ``(head, relationship, tail)`` string tuples, in file order.
    """

    def __init__(self, name, separator="\t"):
        """Load all three splits of dataset *name*.

        Args:
            name: dataset folder name under ``os.path.join(ROOT, "datasets")``.
            separator: field separator used inside the triple files.

        Raises:
            Exception: if the dataset folder does not exist.
            ValueError: if a non-blank line of a split file does not contain
                exactly three fields.
            OSError: if a split file cannot be opened.
        """
        self.name = name
        # NOTE(review): this resolves the folder from ROOT/"datasets", whereas
        # home_folder_for() uses DATASETS_HOME; confirm the two agree.
        self.home = os.path.join(ROOT, "datasets", self.name)
        if not os.path.isdir(self.home):
            raise Exception("Folder %s does not exist" % self.home)

        self.train_path = os.path.join(self.home, "train.txt")
        self.valid_path = os.path.join(self.home, "valid.txt")
        self.test_path = os.path.join(self.home, "test.txt")

        # Populated incrementally by _read_triples, so after the three reads
        # below they cover the union of all splits.
        self.entities = set()
        self.relationships = set()

        print("Reading train triples for %s..." % self.name)
        self.train_triples = self._read_triples(self.train_path, separator)
        print("Reading validation triples for %s..." % self.name)
        self.valid_triples = self._read_triples(self.valid_path, separator)
        print("Reading test triples for %s..." % self.name)
        self.test_triples = self._read_triples(self.test_path, separator)

    def _read_triples(self, triples_path, separator="\t"):
        """Parse one triple file and register its entities and relationships.

        Each non-blank line must hold exactly three fields separated by
        *separator*, in the order head, relationship, tail. Surrounding
        whitespace is stripped from each line. Blank lines (such as a trailing
        empty line at end of file) are skipped.

        Args:
            triples_path: path of the file to read; decoded as UTF-8.
            separator: field separator.

        Returns:
            A list of ``(head, relationship, tail)`` tuples in file order.

        Side effects:
            Adds every head and tail to ``self.entities`` and every
            relationship to ``self.relationships``.

        Raises:
            ValueError: if a non-blank line does not have exactly three fields.
        """
        triples = []
        # Explicit UTF-8: the platform default encoding is not guaranteed to be
        # UTF-8 (e.g. on Windows), and names may contain non-ASCII characters.
        with open(triples_path, "r", encoding="utf-8") as triples_file:
            # Iterate lazily instead of readlines() to avoid buffering the
            # whole file in memory.
            for line_number, raw_line in enumerate(triples_file, start=1):
                line = raw_line.strip()
                if not line:
                    continue
                fields = line.split(separator)
                if len(fields) != 3:
                    raise ValueError(
                        "%s:%d: expected 3 fields separated by %r, got %d"
                        % (triples_path, line_number, separator, len(fields))
                    )
                head, relationship, tail = fields
                triples.append((head, relationship, tail))
                self.entities.add(head)
                self.entities.add(tail)
                self.relationships.add(relationship)
        return triples
def home_folder_for(dataset_name):
    """Return the path of *dataset_name*'s folder under ``DATASETS_HOME``.

    Args:
        dataset_name: name of the dataset sub-folder.

    Returns:
        ``os.path.join(DATASETS_HOME, dataset_name)``, once verified to be an
        existing directory.

    Raises:
        Exception: if that path is not an existing directory.
    """
    candidate = os.path.join(DATASETS_HOME, dataset_name)
    if not os.path.isdir(candidate):
        raise Exception("Folder %s does not exist" % candidate)
    return candidate