-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_model.py
65 lines (55 loc) · 1.93 KB
/
train_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import os
from glob import glob
from os import system

import torch
from flair.datasets import ColumnCorpus
from flair.embeddings import *
from flair.models import SequenceTagger
from flair.trainers import ModelTrainer
from tqdm import tqdm

import util
from tag import process_file
# NOTE(review): this module-level `device` is not referenced again anywhere in
# this file. If the intent is to pin flair to the first GPU, flair reads its
# device from `flair.device` -- confirm and assign there instead.
device = torch.device('cuda:0')


def annotation_path(out_dir: str, conll_file: str) -> str:
    """Return the ``.ann`` output path inside *out_dir* for a ``.conll`` file.

    Only the trailing extension is swapped, so a ``.conll`` substring that
    happens to appear elsewhere in the file name is left untouched.

    Args:
        out_dir: Directory the annotation file should be written to.
        conll_file: Path of the input ``.conll`` file; only its base name
            is used.

    Returns:
        ``out_dir`` joined with the input's base name, ending in ``.ann``.
    """
    stem, _ = os.path.splitext(os.path.basename(conll_file))
    return os.path.join(out_dir, stem + '.ann')


if __name__ == '__main__':
    # ---- Loading -----------------------------------------------------------
    util.print_flag('Loading')

    util.print_flag('Dataset', big=False)
    # Four-column CoNLL files; column 3 carries the NER tag that is trained on.
    # No dev file is passed (dev_file=None).
    corpus: ColumnCorpus = ColumnCorpus(
        data_folder='resources/data/',
        train_file='concat_PharmaCoNER.conll',
        dev_file=None,
        test_file='test_PharmaCoNER.conll',
        column_format={0: 'text', 1: 'begin', 2: 'end', 3: 'ner'}
    )

    util.print_flag('Embeddings', big=False)
    # The pooling operation is forwarded to util.get_embeddings and also
    # becomes part of the model name below.
    pooling_op = 'min'
    embeddings: StackedEmbeddings = util.get_embeddings(pooling_op)

    # ---- Training ----------------------------------------------------------
    util.print_flag('Training')
    tag_type = 'ner'
    model = f'PharmaCoNER-PCE_{pooling_op}-BPEmb-FT-w2v'
    # Both directories are derived once from the model name so the tagging
    # and evaluation steps cannot drift apart.
    model_dir = f'resources/models/{model}'
    out_path = f'system/{model}/'

    tagger: SequenceTagger = SequenceTagger(
        embeddings=embeddings,
        tag_dictionary=corpus.make_tag_dictionary(tag_type=tag_type),
        tag_type=tag_type,
        hidden_size=256,
        rnn_layers=1,
        dropout=0.0
    )
    print(tagger)

    trainer: ModelTrainer = ModelTrainer(tagger, corpus)
    trainer.train(
        model_dir,
        learning_rate=0.1,
        mini_batch_size=32,
        patience=3
    )

    # ---- Tagging -----------------------------------------------------------
    util.print_flag('Tagging')
    os.makedirs(out_path, exist_ok=True)

    # Tag every background file, writing one .ann file per .conll input and
    # showing the running sum of process_file's return values in the
    # progress bar.
    found_tags = 0
    tq = tqdm(glob('resources/data/background_processed/*.conll'))
    for file in tq:
        found_tags += process_file(tagger, file, annotation_path(out_path, file))
        tq.set_postfix(Tags=found_tags)

    # ---- Evaluating --------------------------------------------------------
    util.print_flag('Evaluating')
    # Run through the shell on purpose: the pipe to `tee` echoes the
    # evaluation output and also stores it next to the trained model.
    system(
        f'python3 evaluate.py ner gold/test/subtrack1 {out_path} '
        f'| tee {model_dir}/eval_results.txt')