# run.py
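# Usage sketch (inferred from the argparse flags and checkpoint checks below;
# the exact contents of config.yaml, data/ and ckpt/ are defined elsewhere in
# the repo):
#
#   python run.py -mode train -strategy standard
#   python run.py -mode finetune -strategy auxiliary
#   python run.py -mode test -strategy auxiliary -search beam
#   python run.py -mode inference -strategy auxiliary
#
# 'finetune' expects ckpt/baseline_model.pt to exist; 'test' and 'inference'
# expect ckpt/<strategy>_ft_model.pt to exist.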
import os, yaml, argparse, torch
from tokenizers import Tokenizer
from tokenizers.processors import TemplateProcessing
from module import (
    load_dataloader,
    load_model,
    Trainer,
    Tester,
    Generator
)


def set_seed(SEED=42):
    # Seed python, numpy and torch (CPU and GPU) and force deterministic cuDNN
    # so that repeated runs are reproducible.
    import random
    import numpy as np
    import torch.backends.cudnn as cudnn
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    cudnn.benchmark = False
    cudnn.deterministic = True


class Config(object):
    def __init__(self, args):
        # Flatten every group in config.yaml into attributes on this object.
        with open('config.yaml', 'r') as f:
            params = yaml.load(f, Loader=yaml.FullLoader)
            for group in params.keys():
                for key, val in params[group].items():
                    setattr(self, key, val)

        self.mode = args.mode
        self.strategy = args.strategy
        self.search_method = args.search

        if self.mode == 'finetune':
            self.lr = self.fine_lr
            self.base_ckpt = 'ckpt/baseline_model.pt'

        self.ckpt = f"ckpt/{self.strategy}_ft_model.pt"
        self.tokenizer_path = 'data/tokenizer.json'

        # Inference runs on CPU; other modes use CUDA when available.
        use_cuda = torch.cuda.is_available()
        self.device_type = 'cuda' \
                           if use_cuda and self.mode != 'inference' \
                           else 'cpu'
        self.device = torch.device(self.device_type)

    def print_attr(self):
        for attribute, value in self.__dict__.items():
            print(f"* {attribute}: {value}")


def load_tokenizer(config):
    # Load the trained tokenizer and wrap every encoded sequence with the
    # BOS/EOS special tokens via a post-processing template.
    assert os.path.exists(config.tokenizer_path)
    tokenizer = Tokenizer.from_file(config.tokenizer_path)
    tokenizer.post_processor = TemplateProcessing(
        single=f"{config.bos_token} $A {config.eos_token}",
        special_tokens=[(config.bos_token, config.bos_id),
                        (config.eos_token, config.eos_id)]
    )
    return tokenizer


def main(args):
    set_seed()
    config = Config(args)
    model = load_model(config)
    tokenizer = load_tokenizer(config)

    # Dispatch on the requested mode.
    if config.mode in ['train', 'finetune']:
        train_dataloader = load_dataloader(config, tokenizer, 'train')
        valid_dataloader = load_dataloader(config, tokenizer, 'valid')
        trainer = Trainer(config, model, train_dataloader, valid_dataloader)
        trainer.train()

    elif config.mode == 'test':
        test_dataloader = load_dataloader(config, tokenizer, 'test')
        tester = Tester(config, model, tokenizer, test_dataloader)
        tester.test()

    elif config.mode == 'inference':
        generator = Generator(config, model, tokenizer)
        generator.inference()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-mode', required=True)
    parser.add_argument('-strategy', default='standard', required=True)
    parser.add_argument('-search', default='greedy', required=False)
    args = parser.parse_args()

    # Normalize casing once so the checks below and Config see the same values.
    args.mode = args.mode.lower()
    args.strategy = args.strategy.lower()
    args.search = args.search.lower()

    assert args.mode in ['train', 'finetune', 'test', 'inference']
    assert args.strategy in ['standard', 'auxiliary', 'recurrent', 'generative']
    assert args.search in ['greedy', 'beam']

    # Finetuning needs the pretrained baseline; test/inference need a finetuned
    # checkpoint for the chosen strategy.
    if args.mode == 'finetune':
        assert os.path.exists('ckpt/baseline_model.pt')
    elif args.mode in ['test', 'inference']:
        assert os.path.exists(f'ckpt/{args.strategy}_ft_model.pt')

    main(args)