-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathrun_all.py
48 lines (38 loc) · 1.15 KB
/
run_all.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import os
os.environ['MKL_THREADING_LAYER'] = 'GNU'
import random
import numpy as np
import torch
import json
# One fixed seed shared by every random number generator touched below; it is
# also the value handed to train.py through --seed.
seed = 2021

# Published through the environment so that processes started later from this
# script (e.g. via os.system) see the same hash seed.
os.environ['PYTHONHASHSEED'] = str(seed)

# Seed Python's, NumPy's and PyTorch's generators (CPU, current CUDA device,
# and all CUDA devices) with the same value.
for _seed_fn in (random.seed,
                 np.random.seed,
                 torch.manual_seed,
                 torch.cuda.manual_seed,
                 torch.cuda.manual_seed_all):
    _seed_fn(seed)

# Restrict cuDNN to deterministic algorithms.
torch.backends.cudnn.deterministic = True
# Datasets to train on, written as "<category>/<dataset>". Each entry must
# equal the 'name' field of one record in configs.json.
data_list = ['Structured/Amazon-Google',
             'Structured/BeerAdvo-RateBeer',
             'Structured/DBLP-ACM',
             'Structured/Fodors-Zagats',
             'Structured/iTunes-Amazon',
             'Dirty/DBLP-ACM',
             'Dirty/iTunes-Amazon',
             'Textual/Abt-Buy']

# (config key, train.py flag) pairs, in the order the flags are appended.
_CHANNEL_FLAGS = (
    ('literal_channel', '--literal'),
    ('digital_channel', '--digital'),
    ('structure_channel', '--structure'),
    ('name_channel', '--name'),
)


def load_configs(path='configs.json'):
    """Read the JSON list of per-dataset configs and index it by name.

    Args:
        path: Location of the JSON file; it must contain a list of objects
            that each carry a 'name' key.

    Returns:
        A dict mapping each record's 'name' to the record itself.
    """
    # The context manager closes the file even if parsing fails.
    with open(path) as config_file:
        records = json.load(config_file)
    return {record['name']: record for record in records}


def build_command(data, config, seed):
    """Build the shell command that trains on one dataset.

    Args:
        data: Dataset name passed to train.py as --data_name.
        config: Config record for that dataset; must contain 'epoch' and the
            four '*_channel' keys listed in _CHANNEL_FLAGS.
        seed: Value passed to train.py as --seed.

    Returns:
        The command line as a single string.

    Raises:
        KeyError: If `config` lacks 'epoch' or any of the channel keys.
    """
    cmd = 'python train.py --data_name %s --n_epoch %d --seed %d' % (
        data, config['epoch'], seed)
    for key, flag in _CHANNEL_FLAGS:
        # A truthy config value switches the corresponding channel on.
        if config[key]:
            cmd += ' ' + flag
    return cmd


def run_all(datasets, seed):
    """Print and execute the training command for each dataset, in order.

    A run that exits with a non-zero status is reported and the remaining
    datasets are still processed.

    Args:
        datasets: Iterable of dataset names; each must be a key of the loaded
            configs.
        seed: Value forwarded to every train.py invocation.

    Raises:
        KeyError: If a dataset name has no matching config record.
    """
    # The config file does not change between datasets, so parse it once
    # instead of once per iteration.
    configs = load_configs()
    for data in datasets:
        cmd = build_command(data, configs[data], seed)
        print(cmd)
        status = os.system(cmd)
        if status != 0:
            # Surface failures instead of silently moving on.
            print('WARNING: command exited with status %s: %s' % (status, cmd))


if __name__ == '__main__':
    run_all(data_list, seed)