-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathpretrain_config.py
37 lines (30 loc) · 1.03 KB
/
pretrain_config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import time
import torch
# Use the first CUDA device when one is available, otherwise fall back to CPU.
cuda_condition = torch.cuda.is_available()
device = torch.device('cuda:0' if cuda_condition else 'cpu')
# ## Model / data file paths ## #
SourcePath = '../../data/src_data/src_set.csv'
CorpusPath = '../../data/train_data/train_set.csv'
# CorpusPath = '../../data/train_data/train_demo.csv'
EvalPath = '../../data/test_data/eval_set.csv'
TestPath = '../../data/test_data/test_a.csv'
# Auxiliary file that stores the max sentence length, character count and class count.
Assistant = '../../data/train_data/assistant.txt'
# ## Training / debugging parameters: start ## #
Epochs = 16
BatchSize = 1
LearningRate = 1e-5
# NOTE(review): presumably the Transformer-XL memory length — confirm against the model code.
MemoryLength = 128
AttentionMask = False
HiddenLayerNum = 6
SentenceLength = 128
# Checkpoint path; the sentence length is embedded in the file name.
PretrainPath = '../../checkpoint/finetune/transformerXL_classify_%s.model' % SentenceLength
# ## Training / debugging parameters: end ## #
# ## Common parameters ## #
DropOut = 0.1
# Vocabulary size = integer in the first comma-separated field of the assistant
# file's first line. The `with` block closes the handle right after reading,
# instead of leaking it until garbage collection as the bare open() did.
# NOTE(review): the comment on `Assistant` lists max sentence length first,
# yet field 0 is read as VocabSize — confirm the file's actual field order.
with open(Assistant, 'r', encoding='utf-8') as _assistant_file:
    VocabSize = int(_assistant_file.readline().split(',')[0])
del _assistant_file
HiddenSize = 768
IntermediateSize = 3072
AttentionHeadNum = 12
def get_time():
    """Return the current local time as a 'YYYY-MM-DD HH:MM:SS' string."""
    timestamp_format = "%Y-%m-%d %H:%M:%S"
    # With no time tuple supplied, time.strftime formats time.localtime().
    return time.strftime(timestamp_format)