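"""Inference script for DLCA cerebral aneurysm detection.

Splits an input CTA volume into overlapping patches, runs the detection
network on each patch, recombines the per-patch outputs, and saves the
NMS-filtered candidate boxes to <name>_pbb.npy.
"""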
import argparse
import os
from importlib import import_module

import numpy as np
import torch
from torch.nn import DataParallel
from torch.backends import cudnn
from torch.utils.data import DataLoader

import data
from utils.log_utils import *
from utils.inference_utils import SplitComb, postprocess, plot_box
parser = argparse.ArgumentParser(description='cerebral aneurysm detection')
parser.add_argument('--model', '-m', metavar='MODEL', default='model.network',
                    help='module path of the model to load (default: model.network)')
parser.add_argument('-j', '--workers', default=0, type=int, metavar='N',
                    help='number of data loading workers (default: 0)')
parser.add_argument('-b', '--batch-size', default=1, type=int,
                    metavar='N', help='mini-batch size (default: 1)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to the checkpoint to load (default: none)')
parser.add_argument('--input', default='', type=str, metavar='DATA',
                    help='path to the input image (default: none)')
parser.add_argument('--output', default='', type=str, metavar='SAVE',
                    help='directory to save prediction results (default: none)')
parser.add_argument('--test', default=1, type=int, metavar='TEST',
                    help='1 to run test evaluation, 0 otherwise')
parser.add_argument('--n_test', default=1, type=int, metavar='N',
                    help='number of patches fed to the network per forward pass '
                         '(typically the number of GPUs)')
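# Example invocation (paths are illustrative, not shipped with the repo):
#   python inference.py --model model.network \
#       --resume ./checkpoint/model.ckpt \
#       --input ./data/test_volume.nii.gz \
#       --output ./prediction/test_volume/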
def main():
    global args
    args = parser.parse_args()
    torch.manual_seed(0)
    # Dynamically import the model module and build the network.
    model = import_module(args.model)
    config, net, loss, get_pbb = model.get_model()
    # Derive the image file name, its parent directory, and the output
    # directory from the supplied paths.
    test_name = args.input.split("/")[-1]
    data_dir = args.input.split("/")[-2]
    save_dir = args.output.split("/")[-2]
    # Restore trained weights if a checkpoint path was given.
    if args.resume:
        checkpoint = torch.load(args.resume)
        net.load_state_dict(checkpoint['state_dict'])
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    logfile = os.path.join(save_dir, 'log')
    net = net.cuda()
    loss = loss.cuda()
    cudnn.benchmark = True
    net = DataParallel(net)
    # Split each volume into overlapping patches for inference and
    # recombine the per-patch outputs afterwards.
    margin = config["margin"]
    sidelen = config["split_size"]
    split_comber = SplitComb(sidelen, config['max_stride'], config['stride'],
                             margin, config['pad_value'])
    dataset = data.TestDetector(
        data_dir,
        test_name,
        config,
        split_comber=split_comber)
    test_loader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=args.workers,
        collate_fn=data.collate,
        pin_memory=False)
    test(test_loader, net, get_pbb, save_dir, config)
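# Note: SplitComb (from utils.inference_utils) is assumed here to follow the
# usual split-and-stitch contract for patch-wise 3D inference: split() pads
# the volume and cuts it into overlapping cubes of side `split_size` with a
# `margin`-voxel overlap, and combine() crops the margins and tiles the
# per-cube outputs back onto the original grid described by `nzhw`.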
def test(data_loader, net, get_pbb, save_dir, config):
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    print(save_dir)
    net.eval()
    split_comber = data_loader.dataset.split_comber
    for i_name, (data, coord, nzhw) in enumerate(data_loader):
        nzhw = nzhw[0]
        name = data_loader.dataset.filenames[i_name].split('-')[0].split('/')[-1]
        data = data[0][0]
        coord = coord[0][0]
        # Feed the patches through the network n_per_run at a time.
        n_per_run = args.n_test
        print(data.size())
        splitlist = list(range(0, len(data) + 1, n_per_run))
        if splitlist[-1] != len(data):
            splitlist.append(len(data))
        outputlist = []
        with torch.no_grad():
            for i in range(len(splitlist) - 1):
                input_data = data[splitlist[i]:splitlist[i + 1]].cuda()
                input_coord = coord[splitlist[i]:splitlist[i + 1]].cuda()
                output = net(input_data, input_coord)
                outputlist.append(output.data.cpu().numpy())
        # Stitch the per-patch predictions back into a full volume.
        output = np.concatenate(outputlist, 0)
        output = split_comber.combine(output, nzhw=nzhw)
        # Keep candidate boxes whose score exceeds the threshold, then
        # suppress overlapping duplicates before saving.
        thresh = -3
        pbb, mask = get_pbb(output, thresh, ismask=True)
        print([i_name, name])
        pbb_nms = postprocess(pbb)
        np.save(os.path.join(save_dir, name + '_pbb.npy'), pbb_nms)
        print("start plotting prediction boxes")
        plot_box(name, pbb_nms)
if __name__ == '__main__':
main()
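# A minimal sketch for inspecting the saved predictions afterwards, assuming
# the conventional candidate layout [score, z, y, x, diameter] (verify against
# get_pbb in the model module):
#
#   import numpy as np
#   pbb = np.load('prediction/test_volume/name_pbb.npy')
#   for score, z, y, x, d in pbb[:5]:
#       print('score=%.2f center=(%.0f, %.0f, %.0f) d=%.1f' % (score, z, y, x, d))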