train.py
# -*- coding: utf-8 -*-
import tensorflow as tf # 0.12
from tensorflow.models.rnn.translate import seq2seq_model
import os
import numpy as np
import math
PAD_ID = 0
GO_ID = 1
EOS_ID = 2
UNK_ID = 3
train_encode_vec = 'train_encode.vec'
train_decode_vec = 'train_decode.vec'
test_encode_vec = 'test_encode.vec'
test_decode_vec = 'test_decode.vec'
# Vocabulary size: 50,000
vocabulary_encode_size = 50000
vocabulary_decode_size = 50000
buckets = [(5, 10), (10, 15), (20, 25), (40, 50)]
layer_size = 256  # size of each layer
num_layers = 3  # number of layers
batch_size = 64
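# Each (source_size, target_size) bucket collects question/answer id sequences that fit
# under those lengths; model.get_batch() later pads every pair in a batch up to the
# bucket's fixed sizes.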
# Read the *encode.vec and *decode.vec data (the data set is small enough to load into memory in one pass)
def read_data(source_path, target_path, max_size=None):
    data_set = [[] for _ in buckets]
    with tf.gfile.GFile(source_path, mode="r") as source_file:
        with tf.gfile.GFile(target_path, mode="r") as target_file:
            source, target = source_file.readline(), target_file.readline()
            counter = 0
            while source and target and (not max_size or counter < max_size):
                counter += 1
                source_ids = [int(x) for x in source.split()]
                target_ids = [int(x) for x in target.split()]
                target_ids.append(EOS_ID)
                for bucket_id, (source_size, target_size) in enumerate(buckets):
                    if len(source_ids) < source_size and len(target_ids) < target_size:
                        data_set[bucket_id].append([source_ids, target_ids])
                        break
                source, target = source_file.readline(), target_file.readline()
    return data_set
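# For example, with buckets[0] == (5, 10), a pair such as [[4, 51, 7], [12, 33, 6, EOS_ID]]
# (hypothetical ids, EOS appended above) ends up in data_set[0].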
print ("run 1")
model = seq2seq_model.Seq2SeqModel(source_vocab_size=vocabulary_encode_size, target_vocab_size=vocabulary_decode_size,
buckets=buckets, size=layer_size, num_layers=num_layers, max_gradient_norm=5.0,
batch_size=batch_size, learning_rate=0.5, learning_rate_decay_factor=0.97,
forward_only=False)
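# max_gradient_norm=5.0 clips gradients during training, learning_rate_decay_factor=0.97
# is applied by learning_rate_decay_op below, and forward_only=False builds the backward
# pass so the model can be trained (True would build an inference-only graph).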
print ("run 2")
config = tf.ConfigProto()
config.gpu_options.allocator_type = 'BFC' # 防止 out of memory
with tf.Session(config=config) as sess:
    print("run 3")
    # Resume from the previous training run if a checkpoint exists
    ckpt = tf.train.get_checkpoint_state('.')
    if ckpt is not None:
        print("ckpt != None")
        print(ckpt.model_checkpoint_path)
        model.saver.restore(sess, ckpt.model_checkpoint_path)
    else:
        sess.run(tf.initialize_all_variables())
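        # Note: tf.initialize_all_variables() is the TF 0.12-era initializer; later
        # releases renamed it to tf.global_variables_initializer().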
    train_set = read_data(train_encode_vec, train_decode_vec)
    test_set = read_data(test_encode_vec, test_decode_vec)
    print(type(train_set), len(train_set), len(train_set[1]), len(train_set[1][1]))
    print("(source_size < 5, target_size < 10)")
    print(train_set[0][:10])
    print("(source_size < 10, target_size < 15)")
    print(train_set[1][:10])
    train_bucket_sizes = [len(train_set[b]) for b in range(len(buckets))]
    print("size of each bucket:")
    print(train_bucket_sizes)
    train_total_size = float(sum(train_bucket_sizes))
    print("total number of training pairs:")
    print(train_total_size)
    train_buckets_scale = [sum(train_bucket_sizes[:i + 1]) / train_total_size for i in range(len(train_bucket_sizes))]
    print("cumulative share of the first i+1 buckets:")
    print(train_buckets_scale)
    loss = 0.0
    total_step = 0
    previous_losses = []
    # Train indefinitely, saving the model every so often
    while True:
        random_number_01 = np.random.random_sample()
        # print("random_number_01:")
        # print(random_number_01)
        bucket_id = min([i for i in range(len(train_buckets_scale)) if train_buckets_scale[i] > random_number_01])
        # print("bucket_id:")
        # print(bucket_id)
        # The two lines above pick a random bucket to train on, with probability proportional to its size
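        # For example, if train_buckets_scale were [0.2, 0.5, 0.8, 1.0], a draw of 0.6
        # would select bucket_id 2.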
        encoder_inputs, decoder_inputs, target_weights = model.get_batch(train_set, bucket_id)
        _, step_loss, _ = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, False)
        # print("encoder_inputs:")
        # print(encoder_inputs)
        # print("decoder_inputs:")
        # print(decoder_inputs)
        # print("target_weights:")
        # print(target_weights)
        loss += step_loss / 500
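        # Dividing by 500 turns the accumulator into an average step loss over the
        # 500-step reporting window below.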
        total_step += 1
        # print(total_step)
        if total_step % 500 == 0:
            print(model.global_step.eval(), model.learning_rate.eval(), loss)
            # If the loss has not improved, decay the learning rate
            if len(previous_losses) > 2 and loss > max(previous_losses[-3:]):
                sess.run(model.learning_rate_decay_op)
            previous_losses.append(loss)
            # Save the model
            checkpoint_path = "chatbot_seq2seq.ckpt"
            model.saver.save(sess, checkpoint_path, global_step=model.global_step)
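            # The checkpoint lands in the current working directory, which is where
            # tf.train.get_checkpoint_state('.') above looks when resuming training.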
            loss = 0.0
            # Evaluate the model on the test data
            for bucket_id in range(len(buckets)):
                if len(test_set[bucket_id]) == 0:
                    continue
                encoder_inputs, decoder_inputs, target_weights = model.get_batch(test_set, bucket_id)
                _, eval_loss, _ = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, True)
                eval_ppx = math.exp(eval_loss) if eval_loss < 300 else float('inf')
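                # Perplexity is exp(loss); losses of 300 or more would overflow math.exp,
                # so they are reported as infinity instead.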
                print(bucket_id, eval_ppx)