-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathevaluation_UNet.py
73 lines (54 loc) · 2.8 KB
/
evaluation_UNet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#coding:utf8
import models
from config import *
import torch as t
from tqdm import tqdm
import numpy
import time
import os
# Folder of saved U-Net weight files; each file is later loaded with
# model.load_state_dict(t.load(check_ing_path + name)), so this string
# must keep its trailing '/'.
check_ing_path = '/userhome/GUOXUTAO/2020_00/NET21/00/check/stage1_from_pth/'
# Folder whose file names are compared against checkpoint names: any
# checkpoint whose name already appears here is skipped by the main loop.
# (Presumably holds per-checkpoint dice results from earlier runs - confirm.)
dice_done_dir = '/userhome/GUOXUTAO/2020_00/NET21/00/check/dice/'

# Snapshot both folders once, checkpoints first, then finished results.
check_list, read_list = (
    os.listdir(check_ing_path),
    os.listdir(dice_done_dir),
)
# Evaluate every checkpoint file in check_ing_path that has not been processed
# yet: build a fresh 'unet_3d' model, load the checkpoint weights, then run
# overlapping sliding-window inference over every file in the test folder.
# NOTE(review): this view ends mid-loop (after `pre = label`); the dice
# computation that presumably fills WT_dice/TC_dice/ET_dice is not visible here.
for index,checkname in enumerate(check_list):
print(index,checkname)
# Skip checkpoints whose file name already appears in the dice folder listing.
if checkname not in read_list:
#if 1 > 0:
model = getattr(models, 'unet_3d')()
model.eval()
# NOTE(review): t.load without map_location restores tensors to the device
# they were saved from; a GPU-saved checkpoint will fail to load on a
# CUDA-less machine. Consider t.load(..., map_location='cpu').
model.load_state_dict(t.load(check_ing_path+checkname))
# Duplicate of the eval() call above - harmless.
model.eval()
# `opt` is not defined in the visible code; presumably it comes from
# `from config import *` - confirm.
if opt.use_gpu: model.cuda()
# Always-true guard, apparently a leftover on/off toggle.
if 1 > 0:
testpath = '/userhome/GUOXUTAO/data/datafristpaper/data/test/data/'
folderlist = os.listdir(testpath)
# Per-region dice accumulators (WT/TC/ET); they are not filled within this view.
WT_dice = []
TC_dice = []
ET_dice = []
# NOTE(review): reuses the name `index` from the outer checkpoint loop
# (shadowing); harmless here only because the outer enumerate rebinds it.
for index,fodername in enumerate(folderlist):
print(index,fodername)
# `np` is not imported in the visible header (only plain `numpy`); presumably
# it is provided by `from config import *` - verify. Same for `torch`
# (only the alias `t` is imported explicitly).
data = np.load(testpath+fodername)
# Axis 0 slices 0..3 are the network input; slice 4 is kept as `tru`
# (presumably 4 MRI modalities + ground-truth label - confirm).
vector = data[0:4,:,:,:]
tru = data[4,:,:,:]
# Per-class probability accumulator (float64, zeros). The 5 channels assume
# the model emits 5 classes - TODO confirm against unet_3d.
prob = np.zeros((5,data.shape[1],data.shape[2],data.shape[3]))
# Sliding-window parameters: g = border margin skipped on every axis,
# s0 = stride on axes 1-2, s1 = stride on axis 3, ss = cubic patch edge.
g = 10
s0 = 32
s1 = 32
ss = 128
# NOTE(review): all 50**3 = 125,000 index combinations are iterated even
# though most fail the bounds checks below; the window counts per axis
# could be computed directly instead.
for i in range(50):
for ii in range(50):
for iii in range(50):
# A window is used only if its exclusive end index is strictly less than
# shape - g. Consequences:
# - the first g voxels and the last >= g+1 voxels of each axis, plus any
#   stride remainder, are never scored;
# - those voxels keep prob == 0 and argmax labels them class 0;
# - an axis shorter than 2*g + ss + 1 = 149 voxels gets no windows at all.
# NOTE(review): consider adding a final window flush with the edge.
if g+s0*i+ss < data.shape[1]-g:
if g+s0*ii+ss < data.shape[2]-g:
if g+s1*iii+ss < data.shape[3]-g:
# Crop a (4, ss, ss, ss) patch and add a leading batch axis.
img_out = vector[:,g+s0*i:g+s0*i+ss,g+s0*ii:g+s0*ii+ss,g+s1*iii:g+s1*iii+ss]
img = torch.from_numpy(img_out).unsqueeze(0).float()
with torch.no_grad():
# autograd.Variable is deprecated (a no-op wrapper since PyTorch 0.4);
# `input` also shadows the builtin of the same name.
input = t.autograd.Variable(img)
# NOTE(review): `if True:` always moves the patch to CUDA, while the model
# is moved only when opt.use_gpu (above). With opt.use_gpu false this is
# likely a device mismatch or CUDA error; probably meant `if opt.use_gpu:`.
if True: input = input.cuda()
#down_1 = model_feature(input)
#print(down_1.shape)
score = model(input)
# Softmax over dim 1 (class axis). squeeze() drops the batch axis, giving
# a (classes, ss, ss, ss) numpy array.
score = torch.nn.Softmax(dim=1)(score).squeeze().detach().cpu().numpy()
# Sum probabilities where windows overlap (stride 32 < patch 128).
prob[:,g+s0*i:g+s0*i+ss,g+s0*ii:g+s0*ii+ss,g+s1*iii:g+s1*iii+ss] = prob[:,g+s0*i:g+s0*i+ss,g+s0*ii:g+s0*ii+ss,g+s1*iii:g+s1*iii+ss] + score
# Per-voxel predicted class = argmax over the class axis. The .astype(float)
# is redundant: prob is already float64 from np.zeros.
label = np.argmax((prob).astype(float),axis=0)
pre = label