-
Notifications
You must be signed in to change notification settings - Fork 4
/
inference.py
executable file
·96 lines (69 loc) · 2.69 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os
import sys
import glob
import math
from random import shuffle
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import nibabel as nib
import torch
import torch.nn as nn
import util
from model import UNet3D
def inference(T1_path, b0_d_path, model, device):
    """Run ``model`` on a T1 / distorted-b0 pair and return its b0 prediction.

    Both NIfTI volumes get a trailing channel axis, are zero-padded,
    converted with ``util.nii2torch`` and intensity-normalized. They are
    then stacked (b0 first, T1 second) and passed through ``model`` with
    gradients disabled. The output is mapped back to the b0 intensity
    range and the padding is cropped off.

    Args:
        T1_path: Path to the T1 NIfTI file.
        b0_d_path: Path to the distorted b0 NIfTI file.
        model: Network that accepts the stacked two-channel input
            (``UNet3D(2, 1)`` when run from this script's ``__main__``).
            It is switched to eval mode as a side effect.
        device: ``torch.device`` the input tensor is moved to. This is
            presumably the device ``model`` lives on.

    Returns:
        The un-normalized, cropped model output. It is not moved off
        ``device``.
    """
    # (before, after) zero padding per NIfTI axis; the channel axis is untouched.
    # Assumes a (77, 91, 77) volume and pads it to (80, 96, 80), because the
    # template size is not divisible by 8.
    pad_width = ((2, 1), (3, 2), (2, 1), (0, 0))
    # Inverse of pad_width, expressed in the torch layout. The spatial order
    # differs from pad_width, presumably because util.nii2torch permutes axes.
    # TODO confirm against util.
    crop = (slice(None), slice(None), slice(2, -1), slice(2, -1), slice(3, -2))

    with torch.no_grad():
        model.eval()

        # Each stage runs on T1 first, then b0, mirroring the original order.
        volumes = [np.expand_dims(util.get_nii_img(p), axis=3)
                   for p in (T1_path, b0_d_path)]
        volumes = [np.pad(v, pad_width, mode='constant') for v in volumes]
        t1_torch, b0_torch = [util.nii2torch(v) for v in volumes]

        # T1 uses a fixed intensity ceiling of 150.
        # b0 is scaled between 0 and its own 99th percentile.
        # Those b0 bounds are reused below to undo the normalization on the output.
        t1_norm = util.normalize_img(t1_torch, 150, 0, 1, -1)
        b0_max, b0_min = np.percentile(b0_torch, 99), 0
        b0_norm = util.normalize_img(b0_torch, b0_max, b0_min, 1, -1)

        # Stack along axis 1: distorted b0 first, then T1.
        net_input = torch.from_numpy(
            np.concatenate((b0_norm, t1_norm), axis=1)
        ).float().to(device)

        prediction = util.unnormalize_img(model(net_input), b0_max, b0_min, 1, -1)
        return prediction[crop]
if __name__ == '__main__':
    # Command line:
    #   <script> T1_input_path b0_input_path b0_output_path model_path
    # Loads UNet3D weights from model_path and runs inference on the T1 and
    # b0 inputs. The prediction is written to b0_output_path using the b0
    # input's affine and header.

    # Get input arguments ----------------------------------#
    # Exit with a usage message instead of an IndexError traceback when
    # arguments are missing. Extra arguments are ignored, as before.
    if len(sys.argv) < 5:
        sys.exit('Usage: ' + sys.argv[0] +
                 ' <T1_input_path> <b0_input_path> <b0_output_path> <model_path>')
    T1_input_path = sys.argv[1]
    b0_input_path = sys.argv[2]
    b0_output_path = sys.argv[3]
    model_path = sys.argv[4]

    print('T1 input path: ' + T1_input_path)
    print('b0 input path: ' + b0_input_path)
    print('b0 output path: ' + b0_output_path)
    print('Model path: ' + model_path)

    # Run code ---------------------------------------------#
    # Get device: prefer CUDA when available.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Get model: 2 input channels (distorted b0 + T1), 1 output channel.
    model = UNet3D(2, 1).to(device)
    # NOTE(review): torch.load unpickles the checkpoint, and unpickling can run
    # arbitrary code, so only pass trusted model files. weights_only=True would
    # be safer on torch >= 1.13. It is not used here, to keep older torch working.
    model.load_state_dict(torch.load(model_path, map_location=device))

    # Inference
    img_model = inference(T1_input_path, b0_input_path, model, device)

    # Save using the input b0's affine and header as the template.
    nii_template = nib.load(b0_input_path)
    nii = nib.Nifti1Image(util.torch2nii(img_model.detach().cpu()), nii_template.affine, nii_template.header)
    nib.save(nii, b0_output_path)