# train_xp.py
# import numpy as np
# %load_ext autoreload
# %autoreload 2
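# Train the Spec2HRd transformer on Gaia XP coefficient spectra. The
# dataset is split into two disjoint halves (A and B) plus a validation
# set; tr_select in the main block picks which half this run trains on.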
import time
import torch
from torch.utils.data import DataLoader
from model import Spec2HRd
from data import GaiaXPlabel_cont_norm
if __name__ == "__main__":
    #========================= Data loading =========================
    data_dir = "/data/jdli/gaia/"
    tr_file = "ap17_wise_xpcont_cut.npy"

    device = torch.device('cuda:0')
    TOTAL_NUM = 6000
    BATCH_SIZE = 1024

    gdata = GaiaXPlabel_cont_norm(
        data_dir + tr_file,
        total_num=TOTAL_NUM, part_train=False,
        device=device,
    )

    # Hold out 10% for validation; split the rest evenly into two
    # disjoint training halves A and B. The fixed seed makes the split
    # reproducible across runs.
    val_size = int(0.1 * len(gdata))
    A_size = int(0.5 * (len(gdata) - val_size))
    B_size = len(gdata) - A_size - val_size

    A_dataset, B_dataset, val_dataset = torch.utils.data.random_split(
        gdata, [A_size, B_size, val_size],
        generator=torch.Generator().manual_seed(42),
    )
    print(len(A_dataset), len(B_dataset), len(val_dataset))

    A_loader = DataLoader(A_dataset, batch_size=BATCH_SIZE)
    B_loader = DataLoader(B_dataset, batch_size=BATCH_SIZE)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
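    # NOTE: neither training loader shuffles between epochs; pass
    # shuffle=True to DataLoader if the input .npy file is not already
    # in random order.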
    #========================= Model parameters =========================
    INPUT_LEN = 30

    model = Spec2HRd(
        n_encoder_inputs=INPUT_LEN, n_decoder_inputs=INPUT_LEN + 2,
        n_outputs=2, channels=512, n_heads=8, n_layers=8,
    ).to(device)

    # cost = torch.nn.GaussianNLLLoss(full=True, reduction='sum')
    cost = torch.nn.MSELoss(reduction='mean')

    # optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=1e-3, weight_decay=1e-5,
    )

    # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=10, factor=0.1)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)
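    # Plain MSE on the (normalized) labels; the commented-out
    # GaussianNLLLoss is an alternative objective that would additionally
    # require predicted variances. StepLR multiplies the learning rate by
    # 0.95 after every epoch.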
    tr_select = "A"
    model_dir = "/data/jdli/gaia/model/1119/" + tr_select

    if tr_select == "A":
        tr_loader = A_loader
    elif tr_select == "B":
        tr_loader = B_loader

    # Optionally resume from a saved checkpoint:
    # check_point = "/data/jdli/gaia/model/1119/B/sp2_4labels_mse_B_ep23.pt"
    # print("Loading checkpoint %s" % check_point)
    # model.load_state_dict(torch.load(check_point))

    print("===================================")
    print("Training %s begin" % tr_select)
    def train_epoch(tr_loader, epoch):
        model.train()
        total_loss = 0.
        start_time = time.time()

        for batch, data in enumerate(tr_loader):
            output = model(data['x'])
            # Flatten predictions to (N, n_outputs); must match n_outputs=2.
            loss = cost(output.view(-1, 2), data['y'])

            optimizer.zero_grad()
            loss.backward()
            # Clip gradients to keep transformer training stable.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()

            total_loss += loss.item()
            del data, output

        print("epoch %d train loss:%.4f | %.4f s"
              % (epoch, total_loss / (batch + 1), time.time() - start_time))
    def evaluate(val_loader):
        model.eval()
        total_val_loss = 0.
        with torch.no_grad():
            for bs, data in enumerate(val_loader):
                output = model(data['x'])
                loss = cost(output.view(-1, 2), data['y'])
                total_val_loss += loss.item()
                del data, output
        print("val loss:%.4f" % (total_val_loss / (bs + 1)))
    num_epochs = 200
    for epoch in range(num_epochs + 1):
        train_epoch(tr_loader, epoch)
        scheduler.step()  # decay the learning rate once per epoch

        if epoch % 5 == 0:
            evaluate(val_loader)

        if epoch % 50 == 0:
            save_point = "/sp2_4labels_mse_%s_ep%d.pt" % (tr_select, epoch)
            torch.save(model.state_dict(), model_dir + save_point)

        torch.cuda.empty_cache()  # release cached GPU memory between epochs
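    # Checkpoints land in model_dir as sp2_4labels_mse_<A|B>_ep<N>.pt every
    # 50 epochs; rerun with tr_select = "B" to train the companion model on
    # the other half of the data.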