train.py
import argparse
import datetime
import os

# Set the log level before TensorFlow is imported (directly or via the project
# modules below) so the C++ runtime actually honours it.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf
from tensorflow.python.keras.optimizer_v2.adam import Adam
from tqdm.auto import tqdm

from model import MusicTransformerDecoder
from custom.layers import *
from custom import callback
from data import Data
import params as par
import utils

tf.executing_eagerly()  # eager execution is the TF 2.x default; this call only queries it
def train(input_path, save_path, l_r=None, batch_size=2,
          max_seq=1024, epochs=100,
          load_path=None, num_layer=6, log_dir='/pfs/out/logs'):
    # load data
    dataset = Data(input_path)
    print('dataset', dataset)

    # learning rate: a fixed value if given, otherwise the warm-up schedule
    learning_rate = callback.CustomSchedule(par.embedding_dim) if l_r is None else l_r
    opt = Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)

    # define model
    mt = MusicTransformerDecoder(
        embedding_dim=256,
        vocab_size=par.vocab_size,
        num_layer=num_layer,
        max_seq=max_seq,
        dropout=0.2,
        debug=False, loader_path=load_path)
    mt.compile(optimizer=opt, loss=callback.transformer_dist_train_loss)

    # define tensorboard writers
    current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
    train_log_dir = '{log_dir}/{time}/train'.format(log_dir=log_dir, time=current_time)
    eval_log_dir = '{log_dir}/{time}/eval'.format(log_dir=log_dir, time=current_time)
    train_summary_writer = tf.summary.create_file_writer(train_log_dir)
    eval_summary_writer = tf.summary.create_file_writer(eval_log_dir)

    # train start
    idx = 0
    batches_per_epoch = len(dataset.files) // batch_size
    how_often_to_print = 50
    for e in tqdm(range(epochs), desc='epochs'):
        mt.reset_metrics()
        for b in tqdm(range(batches_per_epoch), desc='batches'):
            try:
                batch_x, batch_y = dataset.slide_seq2seq_batch(batch_size, max_seq)
            except Exception:
                # skip batches that cannot be assembled from the current files
                continue
            result_metrics = mt.train_on_batch(batch_x, batch_y)

            # every `how_often_to_print` batches: evaluate, checkpoint, and log to TensorBoard
            if b % how_often_to_print == 0:
                eval_x, eval_y = dataset.slide_seq2seq_batch(batch_size, max_seq, 'eval')
                eval_result_metrics, weights = mt.evaluate(eval_x, eval_y)
                mt.save(save_path)

                with train_summary_writer.as_default():
                    if b == 0:
                        tf.summary.histogram("target_analysis", batch_y, step=e)
                        tf.summary.histogram("source_analysis", batch_x, step=e)
                    tf.summary.scalar('loss', result_metrics[0], step=idx)
                    tf.summary.scalar('accuracy', result_metrics[1], step=idx)

                with eval_summary_writer.as_default():
                    if b == 0:
                        mt.sanity_check(eval_x, eval_y, step=e)
                    tf.summary.scalar('loss', eval_result_metrics[0], step=idx)
                    tf.summary.scalar('accuracy', eval_result_metrics[1], step=idx)
                    for i, weight in enumerate(weights):
                        with tf.name_scope("layer_%d" % i):
                            with tf.name_scope("w"):
                                utils.attention_image_summary(weight, step=idx)
                    # for i, weight in enumerate(weights):
                    #     with tf.name_scope("layer_%d" % i):
                    #         with tf.name_scope("_w0"):
                    #             utils.attention_image_summary(weight[0])
                    #         with tf.name_scope("_w1"):
                    #             utils.attention_image_summary(weight[1])

                idx += 1
                print('\n====================================================')
                print('Epoch/Batch: {}/{}'.format(e, b))
                print('Train >>>> Loss: {:6.6}, Accuracy: {}'.format(result_metrics[0], result_metrics[1]))
                print('Eval >>>> Loss: {:6.6}, Accuracy: {}'.format(eval_result_metrics[0], eval_result_metrics[1]))
if __name__ == "__main__":
    print('train.py')

    parser = argparse.ArgumentParser()
    parser.add_argument('--l_r', default=None, type=float)
    parser.add_argument('--batch_size', default=2, help='batch size', type=int)
    parser.add_argument('--max_seq', default=1024, type=int)
    parser.add_argument('--epochs', default=100, type=int)
    parser.add_argument('--load_path', default=None, type=str)
    parser.add_argument('--input_path')
    parser.add_argument('--save_path', default="/pfs/out")
    parser.add_argument('--num_layers', default=6, type=int)
    args = parser.parse_args()
    print('args', args)

    # set arguments
    l_r = args.l_r
    batch_size = args.batch_size
    max_seq = args.max_seq
    epochs = args.epochs
    load_path = args.load_path
    save_path = args.save_path
    num_layer = args.num_layers
    input_path = args.input_path

    train(input_path, save_path, l_r, batch_size, max_seq, epochs, load_path, num_layer)
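
For reference, a minimal sketch of driving the same entry point from Python instead of the command line. The input directory below is a placeholder (whatever folder the repository's preprocessing step produces); the remaining values simply mirror the argparse defaults above.

# Hypothetical driver snippet -- not part of the repository.
from train import train

train(
    input_path='dataset/processed',  # placeholder: directory of preprocessed sequences
    save_path='/pfs/out',            # same default as --save_path
    l_r=None,                        # None falls back to callback.CustomSchedule(par.embedding_dim)
    batch_size=2,
    max_seq=1024,
    epochs=100,
    load_path=None,                  # point at an existing checkpoint to resume training
    num_layer=6,
)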