core.py
import torch
from torch.utils.data import TensorDataset
import pickle


def use_model(model_name, config_file_path, model_file_path, vocab_file_path, num_labels):
    # Select the model type and load its configuration, weights and tokenizer
    if model_name == 'bert':
        from transformers import BertConfig, BertForSequenceClassification, BertTokenizer
        model_config, model_class, model_tokenizer = (BertConfig, BertForSequenceClassification, BertTokenizer)
        config = model_config.from_pretrained(config_file_path, num_labels=num_labels)
        # from_tf is True only when the checkpoint is a TensorFlow '.ckpt' file
        model = model_class.from_pretrained(model_file_path, from_tf=bool('.ckpt' in model_file_path), config=config)
        tokenizer = model_tokenizer(vocab_file=vocab_file_path)
        return model, tokenizer
    elif model_name == 'albert':
        from albert.albert_zh import AlbertConfig, AlbertTokenizer, AlbertForSequenceClassification
        model_config, model_class, model_tokenizer = (AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer)
        config = model_config.from_pretrained(config_file_path, num_labels=num_labels)
        model = model_class.from_pretrained(model_file_path, config=config)
        tokenizer = model_tokenizer.from_pretrained(vocab_file_path)
        return model, tokenizer
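

# Usage sketch for use_model. The file names below are placeholders for a locally
# downloaded bert-base-chinese checkpoint, not paths fixed by this module:
#
#     model, tokenizer = use_model('bert',
#                                  config_file_path='trained_model/config.json',
#                                  model_file_path='trained_model/pytorch_model.bin',
#                                  vocab_file_path='trained_model/vocab.txt',
#                                  num_labels=150)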


def compute_accuracy(y_pred, y_target):
    # Compute classification accuracy (in percent) from logits and target labels
    _, y_pred_indices = y_pred.max(dim=1)
    n_correct = torch.eq(y_pred_indices, y_target).sum().item()
    return n_correct / len(y_pred_indices) * 100


def to_bert_ids(tokenizer, q_input):
    # Convert a text input into token ids with the special tokens ([CLS]/[SEP]) added
    return tokenizer.build_inputs_with_special_tokens(tokenizer.convert_tokens_to_ids(tokenizer.tokenize(q_input)))


def make_dataset(input_ids, input_masks, input_segment_ids, answer_lables):
    all_input_ids = torch.tensor(input_ids, dtype=torch.long)
    all_input_masks = torch.tensor(input_masks, dtype=torch.long)
    all_input_segment_ids = torch.tensor(input_segment_ids, dtype=torch.long)
    all_answer_lables = torch.tensor(answer_lables, dtype=torch.long)
    return TensorDataset(all_input_ids, all_input_masks, all_input_segment_ids, all_answer_lables)


def split_dataset(full_dataset, split_rate=0.8):
    train_size = int(split_rate * len(full_dataset))
    test_size = len(full_dataset) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])
    return train_dataset, test_dataset
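

# Sketch of how the two helpers above are typically combined; the batch size is an
# arbitrary illustrative value:
#
#     from torch.utils.data import DataLoader
#     full_dataset = make_dataset(input_ids, input_masks, input_segment_ids, answer_lables)
#     train_dataset, test_dataset = split_dataset(full_dataset, split_rate=0.8)
#     train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)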


class DataDic(object):
    def __init__(self, answers):
        self.answers = answers                              # all answers (duplicates included)
        self.answers_norepeat = sorted(list(set(answers)))  # deduplicated answers
        self.answers_types = len(self.answers_norepeat)     # number of distinct classes
        self.ans_list = []                                  # (id, text) pairs used for lookups
        self._make_dic()                                    # build the mapping

    def _make_dic(self):
        for index_a, a in enumerate(self.answers_norepeat):
            if a is not None:
                self.ans_list.append((index_a, a))

    def to_id(self, text):
        for ans_id, ans_text in self.ans_list:
            if text == ans_text:
                return ans_id

    def to_text(self, id):
        for ans_id, ans_text in self.ans_list:
            if id == ans_id:
                return ans_text

    @property
    def types(self):
        return self.answers_types

    @property
    def data(self):
        return self.answers

    def __len__(self):
        return len(self.answers)
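

# Illustrative DataDic example (the strings are made up). Ids follow the sorted order
# of the distinct answers, and duplicate answers share an id:
#
#     ans_dic = DataDic(['天氣', '問候', '天氣'])
#     ans_dic.types          # 2 distinct classes
#     ans_dic.to_id('問候')  # id assigned by sorted order
#     ans_dic.to_text(0)     # first answer in sorted order
#     len(ans_dic)           # 3, duplicates included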


def convert_data_to_feature(tokenizer, train_data_path):
    # Read the training file: one whitespace-separated "answer question" pair per line
    with open(train_data_path, 'r', encoding='utf-8') as f:
        data = f.read()
    qa_pairs = data.split("\n")

    questions = []
    answers = []
    for qa_pair in qa_pairs:
        qa_pair = qa_pair.split()
        try:
            a, q = qa_pair
            questions.append(q)
            answers.append(a)
        except ValueError:
            # Skip blank or malformed lines
            continue
    assert len(answers) == len(questions)

    ans_dic = DataDic(answers)
    question_dic = DataDic(questions)

    # Tokenize every question and record the longest sequence
    q_tokens = []
    max_seq_len = 0
    for q in question_dic.data:
        bert_ids = to_bert_ids(tokenizer, q)
        if len(bert_ids) > max_seq_len:
            max_seq_len = len(bert_ids)
        q_tokens.append(bert_ids)
    print("Longest question length:", max_seq_len)
    assert max_seq_len <= 512  # BERT-base sequence length limit

    # Pad every question to max_seq_len
    for q in q_tokens:
        while len(q) < max_seq_len:
            q.append(0)

    a_labels = []
    for a in ans_dic.data:
        a_labels.append(ans_dic.to_id(a))

    # BERT inputs; note the attention masks are all 1s, so padding positions are not masked out
    answer_lables = a_labels
    input_ids = q_tokens
    input_masks = [[1] * max_seq_len for _ in range(len(question_dic))]
    input_segment_ids = [[0] * max_seq_len for _ in range(len(question_dic))]
    assert len(input_ids) == len(question_dic) and len(input_ids) == len(input_masks) and len(input_ids) == len(input_segment_ids)

    data_features = {'input_ids': input_ids,
                     'input_masks': input_masks,
                     'input_segment_ids': input_segment_ids,
                     'answer_lables': answer_lables,
                     'question_dic': question_dic,
                     'answer_dic': ans_dic}

    with open('trained_model/data_features.pkl', 'wb') as output:
        pickle.dump(data_features, output)
    return data_features
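

# End-to-end usage sketch. Everything below is illustrative: the checkpoint files under
# 'trained_model/' and the training file 'data/train.txt' are assumed paths, not paths
# required by this module (only 'trained_model/data_features.pkl' is hard-coded above).
if __name__ == '__main__':
    from transformers import BertTokenizer

    # Build a tokenizer first so the training data can be converted and the number of
    # distinct answer classes counted
    tokenizer = BertTokenizer(vocab_file='trained_model/vocab.txt')
    features = convert_data_to_feature(tokenizer, 'data/train.txt')

    # Load the classifier with one output per distinct answer
    model, tokenizer = use_model('bert',
                                 config_file_path='trained_model/config.json',
                                 model_file_path='trained_model/pytorch_model.bin',
                                 vocab_file_path='trained_model/vocab.txt',
                                 num_labels=features['answer_dic'].types)

    # Wrap the features in a TensorDataset and split 80/20 into train and test sets
    full_dataset = make_dataset(features['input_ids'],
                                features['input_masks'],
                                features['input_segment_ids'],
                                features['answer_lables'])
    train_dataset, test_dataset = split_dataset(full_dataset, split_rate=0.8)
    print(len(train_dataset), 'training examples /', len(test_dataset), 'test examples')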