-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun.py
118 lines (83 loc) · 3.33 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import os, yaml, argparse, torch
from tokenizers import Tokenizer
from tokenizers.processors import TemplateProcessing
from module import (
load_dataloader,
load_model,
Trainer,
Tester,
SeqGenerator
)
def set_seed(SEED=42):
    """Seed every RNG source (python, numpy, torch CPU/CUDA) and force
    deterministic cuDNN kernels so runs are reproducible."""
    import random
    import numpy as np
    import torch.backends.cudnn as cudnn

    # One pass over every seeding entry point keeps them in sync.
    for seed_fn in (random.seed,
                    np.random.seed,
                    torch.manual_seed,
                    torch.cuda.manual_seed,
                    torch.cuda.manual_seed_all):
        seed_fn(SEED)

    # Trade autotuner speed for bitwise-repeatable convolutions.
    cudnn.deterministic = True
    cudnn.benchmark = False
class Config(object):
    """Flat run configuration.

    Every key from every group in ``config.yaml`` becomes a top-level
    attribute, then the CLI arguments are overlaid and the checkpoint /
    tokenizer paths and torch device are derived from them.
    """

    def __init__(self, args):
        # safe_load only builds plain Python objects; yaml.load with
        # FullLoader permits richer object construction than a config
        # file should ever need.
        with open('config.yaml', 'r') as f:
            params = yaml.safe_load(f)

        # Flatten: group names are discarded, each inner key becomes
        # an attribute (later groups win on key collisions).
        for group in params.values():
            for key, val in group.items():
                setattr(self, key, val)

        self.task = args.task
        self.mode = args.mode
        self.model_type = args.model
        self.search_method = args.search
        self.ckpt = f"ckpt/{self.task}/{self.model_type}_model.pt"
        self.tokenizer_path = f'data/{self.task}/tokenizer.json'

        # Inference is forced onto CPU even when CUDA is available.
        use_cuda = torch.cuda.is_available()
        device_condition = use_cuda and self.mode != 'inference'
        self.device_type = 'cuda' if device_condition else 'cpu'
        self.device = torch.device(self.device_type)

    def print_attr(self):
        """Print every attribute/value pair as a run summary."""
        for attribute, value in self.__dict__.items():
            print(f"* {attribute}: {value}")
def load_tokenizer(config):
    """Load the task tokenizer and attach BOS/EOS post-processing.

    Args:
        config: Config with ``tokenizer_path``, ``bos_token``/``bos_id``
            and ``eos_token``/``eos_id`` attributes.

    Returns:
        A ``tokenizers.Tokenizer`` that wraps every encoded sequence as
        ``<bos> ... <eos>``.

    Raises:
        FileNotFoundError: if the tokenizer file is missing.  An
            ``assert`` here would be silently stripped under ``python -O``.
    """
    if not os.path.exists(config.tokenizer_path):
        raise FileNotFoundError(
            f"tokenizer file not found: {config.tokenizer_path}"
        )
    tokenizer = Tokenizer.from_file(config.tokenizer_path)
    tokenizer.post_processor = TemplateProcessing(
        single=f"{config.bos_token} $A {config.eos_token}",
        special_tokens=[(config.bos_token, config.bos_id),
                        (config.eos_token, config.eos_id)]
    )
    return tokenizer
def main(args):
    """Build the configured model/tokenizer and dispatch on args.mode
    ('train' / 'test' / 'inference')."""
    set_seed()
    config = Config(args)
    model = load_model(config)
    tokenizer = load_tokenizer(config)

    mode = config.mode
    if mode == 'train':
        Trainer(config, model,
                load_dataloader(config, tokenizer, 'train'),
                load_dataloader(config, tokenizer, 'valid')).train()
    elif mode == 'test':
        Tester(config, model, tokenizer,
               load_dataloader(config, tokenizer, 'test')).test()
    elif mode == 'inference':
        SeqGenerator(config, model, tokenizer).inference()
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # type=str.lower normalizes the value BEFORE choices validation, so
    # e.g. "-mode Train" is accepted and actually reaches main() as
    # 'train'.  The original asserts checked args.mode.lower() but main()
    # compared the raw string, so mixed-case input passed validation and
    # then silently did nothing; asserts are also stripped under -O.
    parser.add_argument('-task', required=True, type=str.lower,
                        choices=['translation', 'dialogue', 'summarization'])
    parser.add_argument('-mode', required=True, type=str.lower,
                        choices=['train', 'test', 'inference'])
    parser.add_argument('-model', required=True, type=str.lower,
                        choices=['standard', 'recurrent', 'evolved',
                                 'recurrent_hybrid', 'evolved_hybrid'])
    parser.add_argument('-search', default='greedy', required=False,
                        type=str.lower, choices=['greedy', 'beam'])
    args = parser.parse_args()

    if args.mode == 'train':
        # Training creates its checkpoint directory on demand.
        os.makedirs(f"ckpt/{args.task}", exist_ok=True)
    else:
        # test/inference need an already-trained checkpoint; fail loudly
        # instead of via a strippable assert.
        ckpt = f'ckpt/{args.task}/{args.model}_model.pt'
        if not os.path.exists(ckpt):
            raise FileNotFoundError(f"checkpoint not found: {ckpt}")

    main(args)