-
Notifications
You must be signed in to change notification settings - Fork 1
/
train
executable file
·91 lines (71 loc) · 2.62 KB
/
train
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#!/usr/bin/env python3
import argparse
import subprocess
import os
import json
import re
import psutil
print("Training...")

# Hold back 512 MiB of physical memory for the system; whatever remains,
# in whole mebibytes, is exported as TERRIER_HEAP_MEM in the environment
# handed to the terrier subprocesses.
reserved_memory = 512 * 1024 * 1024
available_memory = psutil.virtual_memory().total  # total RAM (first field)
usable_memory = (available_memory - reserved_memory) // (1024 * 1024)
my_env = dict(os.environ, TERRIER_HEAP_MEM=f"{usable_memory}M")
# The whole job description arrives as a single JSON document in --json;
# any other command-line arguments are accepted and ignored.
parser = argparse.ArgumentParser()
parser.add_argument("--json", type=json.loads, required=True, help="json arguments")
args, _ = parser.parse_known_args()

# The mode has to be an octal literal: decimal 777 equals 0o1411 (sticky bit,
# owner read-only), which would leave the directory unwritable for its owner.
if not os.path.isdir("/output"):
    os.mkdir("/output", 0o777)

# Delete any existing copies of these output files before starting;
# -f keeps rm quiet when a file is absent.
stale_outputs = [
    "/output/ensemble.txt",
    "/output/jforests-discrete-va.letor",
    "/output/tr.letor",
    "/output/va.letor",
    "/output/jforests-discrete-tr.letor",
    "/output/jforests-feature-stats.txt",
    "/output/tr.bin",
    "/output/va.bin",
]
subprocess.run(["rm", "-f"] + stale_outputs)
alltopics = args.json["topic"]

# Write the requested features (a ';'-separated string in the JSON options)
# to /output/features.list, one per line, echoing each to stdout.
print("Features are:")
feature_names = args.json["opts"]["features"].split(";")
with open("/output/features.list", "w+") as features_file:
    for name in feature_names:
        features_file.write(name + "\n")
        print("\t" + name)
# Assemble the JSON argument document for /search. The labels-file setting
# is passed as an option *key* (with an empty value), built from the qrels
# path supplied in the input JSON.
search_opts = {"config": "bm25_ltr_features"}
search_opts["-Dlearning.labels.file=" + args.json["qrels"]["path"]] = ""
run_args = {
    "opts": search_opts,
    "collection": args.json["collection"],
    "topic": alltopics,
    "top_k": 1000,
}
run_args_json = json.dumps(run_args)

# The logged form is JSON-encoded a second time so the argument is shown as
# one quoted string.
print("calling /search --json {}".format(json.dumps(run_args_json)))
subprocess.run(["/search", "--json", run_args_json])
# Load the test and validation query ids, one per line. Each line is stripped
# so the ids compare equal to the newline-free qid captured by the regex
# below; building the set from the raw file lines would keep the trailing
# "\n" on every id and no membership test would ever succeed.
with open("/data/splits/test_split.txt", "r") as te_qids_file:
    teqids = {line.strip() for line in te_qids_file if line.strip()}
with open("/data/splits/validation_split.txt", "r") as va_qids_file:
    vaqids = {line.strip() for line in va_qids_file if line.strip()}

# NOTE: the name is reused; from here on it holds the path of the run file
# (presumably produced by the /search call) rather than the topics argument.
alltopics = "/output/run.{0}.{1}.txt".format(args.json['collection']['name'], "bm25_ltr_features")
# NOTE(review): the count printed next to "training" is the size of the test
# split; the training set is everything in neither split.
print("Separating {0} into training ({1}) & validation ({2})".format(alltopics, len(teqids), len(vaqids)))

# Raw string so that \w reaches the regex engine without an invalid-escape
# warning; compiled once because it is applied to every line.
qid_pattern = re.compile(r".*qid:(\w+).*")

# Context managers guarantee all three files are closed even if the split
# fails part-way through.
with open(alltopics, "r") as allrun, \
        open("/output/tr.letor", "w") as tr, \
        open("/output/va.letor", "w") as va:
    for line in allrun:
        m = qid_pattern.match(line)
        if m is None:
            # Lines carrying no qid are copied to both files.
            tr.write(line)
            va.write(line)
            continue
        qid = m.group(1)
        if qid in vaqids:
            va.write(line)
        elif qid not in teqids:
            tr.write(line)
        # Remaining lines belong to test queries and are written to neither.
# Both steps run the terrier "jforests" command with the same launcher and
# properties file, using the environment that carries TERRIER_HEAP_MEM.
jforests_base = [
    "/work/terrier-core/bin/terrier",
    "jforests",
    "--config-file",
    "/work/terrier-core/etc/jforests.properties",
]

# Step 1: generate-bin over tr.letor and va.letor in /output/.
generate_bin_args = [
    "--cmd=generate-bin", "--ranking",
    "--folder", "/output/",
    "--file", "tr.letor",
    "--file", "va.letor",
]
subprocess.run(jforests_base + generate_bin_args, env=my_env)

# Step 2: train on tr.bin with va.bin as validation, writing the model to
# /output/ensemble.txt.
train_args = [
    "--cmd=train", "--ranking",
    "--folder", "/output/",
    "--train-file", "/output/tr.bin",
    "--validation-file", "/output/va.bin",
    "--output-model", "/output/ensemble.txt",
]
subprocess.run(jforests_base + train_args, env=my_env)