utils.py
import csv
import logging
import os
import random

import numpy as np
import torch
from sklearn.metrics import f1_score, precision_score, recall_score


def worker_init(worker_id):
    """Seed NumPy and Python RNGs in each DataLoader worker for reproducibility."""
    worker_seed = torch.initial_seed() % 2 ** 32
    np.random.seed(worker_seed)
    random.seed(worker_seed)
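# Usage sketch (a minimal example; 'dataset' is a hypothetical Dataset object, not part of
# this module): pass the function to a PyTorch DataLoader so each worker process is seeded.
#   loader = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=4,
#                                        worker_init_fn=worker_init)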


def read_entity(table_paths, skip=True, add_token=True):
    """
    Read entities from one or more CSV tables and serialize each row as text.

    Args:
        table_paths: a single CSV path or a list of CSV paths.
        skip: if True, skip empty cell values.
        add_token: if True, mark attribute/value pairs with 'COL'/'VAL' tokens.

    Returns:
        A list of serialized entity strings, one per data row.
    """
    entity_list = []
    if not isinstance(table_paths, list):
        table_paths = [table_paths]
    for table_path in table_paths:
        with open(table_path, 'r') as f:
            lines = list(csv.reader(f))
        att = []
        for i, line in enumerate(lines):
            sentence = ''
            for j in range(1, len(line)):
                if i == 0:
                    # Header row: collect attribute names (column 0 is skipped).
                    att.append(line[j])
                elif skip and (line[j] == ''):
                    continue
                elif add_token:
                    sentence += 'COL ' + att[j - 1] + ' VAL ' + line[j] + ' '
                else:
                    sentence += att[j - 1] + ' ' + line[j] + ' '
            if i != 0:
                entity_list.append(sentence.strip())
    return entity_list
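# Usage sketch ('data/tableA.csv' is a hypothetical path, not part of this repo):
#   entities = read_entity('data/tableA.csv')
#   # each entry looks like: 'COL title VAL iphone 6s COL price VAL 599'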


def evaluate(y_truth, y_pred):
    """
    Compute precision, recall, and F1 for binary (0/1) predictions.
    """
    precision = precision_score(y_truth, y_pred)
    recall = recall_score(y_truth, y_pred)
    f1 = f1_score(y_truth, y_pred)
    return precision, recall, f1
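# Usage sketch with toy labels (illustrative values only):
#   p, r, f1 = evaluate([1, 0, 1, 1], [1, 0, 0, 1])
#   # p = 1.0, r = 0.667 (approximately), f1 = 0.8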


def set_logger(name):
    """
    Write logs to a file under ./log and to the console.
    """
    # Make sure the log directory exists before configuring the file handler.
    os.makedirs('./log', exist_ok=True)
    log_file = os.path.join('./log', name)
    logging.basicConfig(
        format='%(asctime)s %(levelname)-8s %(message)s',
        level=logging.INFO,
        datefmt='%Y-%m-%d %H:%M:%S',
        filename=log_file,
        filemode='w'
    )
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s')
    console.setFormatter(formatter)
    logging.getLogger('').addHandler(console)
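# Usage sketch ('train.log' is an arbitrary file name chosen for illustration):
#   set_logger('train.log')
#   logging.info('training started')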


def get_tokenizer(lm='bert'):
    """Return the tokenizer for the given language model, loaded from a local path.

    Args:
        lm (string): the name of the language model ('bert', 'roberta', or 'xlnet')

    Returns:
        BertTokenizer, RobertaTokenizer, or XLNetTokenizer
    """
    tokenizer = None
    if lm == 'bert':
        from transformers import BertTokenizer
        tokenizer = BertTokenizer.from_pretrained('./lm_model/bert-base-uncased')
    elif lm == 'roberta':
        from transformers import RobertaTokenizer
        tokenizer = RobertaTokenizer.from_pretrained('./lm_model/roberta-base')
    elif lm == 'xlnet':
        from transformers import XLNetTokenizer
        tokenizer = XLNetTokenizer.from_pretrained('./lm_model/xlnet-base-cased')
    return tokenizer
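# Usage sketch (assumes the pretrained weights have been downloaded to
# ./lm_model/bert-base-uncased, as the paths above expect):
#   tokenizer = get_tokenizer('bert')
#   ids = tokenizer.encode('COL title VAL iphone 6s', max_length=128, truncation=True)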