-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_wide_UNet.py
121 lines (86 loc) · 3.43 KB
/
train_wide_UNet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#coding:utf8
import models
from config import *
from data.dataset_2d import Brats17
from torch.utils.data import DataLoader
import torch as t
from tqdm import tqdm
import numpy
import time
############################################################################
def val(model, model_feature, dataloader):
    '''
    Compute the mean Dice coefficient and mean loss of ``model`` on the
    validation set.

    Args:
        model: segmentation head; consumes the features produced by
            ``model_feature`` and returns per-class logits.
        model_feature: feature-extraction network applied to the raw input
            before ``model``.  NOTE(review): assumed to be a module defined
            elsewhere in the project -- confirm at the call site.
        dataloader: iterable yielding ``(input, label)`` batches.

    Returns:
        Tuple ``(mean_dice, mean_val_loss)`` averaged over all batches.
    '''
    model.eval()
    val_losses, dcs = [], []
    # Inference only: no_grad avoids building the autograd graph.
    with t.no_grad():
        for ii, (input, label) in enumerate(dataloader):
            val_input = input
            val_label = label
            # BUG FIX: the original called `input.cuda()` unconditionally
            # (before the opt.use_gpu check), which crashes on CPU-only
            # machines and made the flag dead.  Move to GPU only if requested.
            if opt.use_gpu:
                val_input = val_input.cuda()
                val_label = val_label.cuda()
                model = model.cuda()
            concat = model_feature(val_input)
            outputs = model(concat)
            # argmax over the class dimension -> predicted label map
            pred = outputs.data.max(1)[1].cpu().numpy().squeeze()
            gt = val_label.data.cpu().numpy().squeeze()
            # (dead `if 1 > 0:` guard removed)
            dc, val_loss = calc_dice(gt[:, :, :], pred[:, :, :])
            dcs.append(dc)
            val_losses.append(val_loss)
    model.train()
    # BUG FIX: original returned `np.mean(...)`, but this module imports
    # `numpy` (not `np`); `np` would only exist via `from config import *`.
    return numpy.mean(dcs), numpy.mean(val_losses)
############################################################################
############################################################################
# Training entry point: builds the wide 3-D U-Net, trains it with
# cross-entropy, and once per epoch logs the loss, runs validation and
# checkpoints the weights.
print('train:')
lr = opt.lr
# BUG FIX: the original read `batch_size.batch_size` before `batch_size` was
# ever defined (NameError); the hyper-parameter lives on the config object.
batch_size = opt.batch_size
print('batch_size:', batch_size, 'lr:', lr)
plt_list = []  # sampled training losses, dumped to .npy each epoch
model = getattr(models, 'unet_3dwide')()
#model.load_state_dict(t.load('/userhome/xxxxxx.pth'))
if opt.use_gpu:
    model.cuda()
# NOTE(review): `dataread` is not imported in this file -- presumably supplied
# by `from config import *`, or it should be the imported Brats17 dataset;
# confirm before running.
train_data = dataread(opt.train_data_root, train=True, test=False, val=False)
val_data = dataread(opt.val_data_root, train=False, test=False, val=True)
val_dataloader = DataLoader(val_data, 1, shuffle=False,
                            num_workers=opt.num_workers)
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True,
                              num_workers=opt.num_workers)
criterion = t.nn.CrossEntropyLoss()
#criterion = DiceLoss3D()
if opt.use_gpu:
    criterion = criterion.cuda()
loss_meter = AverageMeter()
# NOTE(review): `previous_loss` is never read afterwards -- an lr-decay step
# seems to be missing (see the 'old:'/'new:' lr prints below); confirm.
previous_loss = 1e+20
optimizer = t.optim.Adam(model.parameters(), lr=lr,
                         weight_decay=opt.weight_decay)
# train
for epoch in range(opt.max_epoch):
    loss_meter.reset()
    # BUG FIX: tqdm's `total` is the number of batches, not the number of
    # samples; the original used len(train_data).
    for ii, (data, label) in tqdm(enumerate(train_dataloader),
                                  total=len(train_dataloader)):
        # train model (deprecated Variable wrapper removed -- the file
        # already uses loss.item(), i.e. PyTorch >= 0.4 tensor semantics)
        input = data
        target = label
        if opt.use_gpu:
            input = input.cuda()
            target = target.cuda()
        optimizer.zero_grad()
        score = model(input)
        loss = criterion(score, target)
        loss.backward()
        optimizer.step()
        loss_meter.update(loss.item())
        # Once per epoch (on the second batch): record/print the loss, run
        # validation and checkpoint.  The three identical consecutive
        # `if ii==1:` blocks of the original are merged into one.
        if ii == 1:
            plt_list.append(loss_meter.val)
            print('train-loss-avg:', loss_meter.avg,
                  'train-loss-each:', loss_meter.val)
            # NOTE(review): `model_feature` is never defined in this file;
            # this call raises NameError unless config provides it -- confirm.
            acc, val_loss = val(model, model_feature, val_dataloader)
            prefix = (opt.pth_save_path + str(acc) + '_' + str(val_loss)
                      + '_' + str(lr) + '_' + str(batch_size) + '_')
            # NOTE(review): ':' in filenames is invalid on Windows;
            # fine on POSIX filesystems.
            name = time.strftime(prefix + '%m%d_%H:%M:%S.pth')
            t.save(model.state_dict(), name)
            name1 = time.strftime(opt.plt_save_path + '%m%d_%H:%M:%S.npy')
            numpy.save(name1, plt_list)
    # NOTE(review): indentation was lost in this copy; these prints are
    # placed per-epoch, matching the usual lr-decay logging pattern --
    # confirm against the original file.
    print('old:', 'batch_size:', batch_size, 'lr:', lr)
    print('new:', 'batch_size:', batch_size, 'lr:', lr)
############################################################################