-
Notifications
You must be signed in to change notification settings - Fork 9
/
predict_single.py
104 lines (80 loc) · 3.04 KB
/
predict_single.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import argparse
from pathlib import Path
import numpy as np
from PIL import Image
import torch
import wasr.models as models
# Per-channel mean / std used to normalize input images after scaling to
# [0, 1] (these are the standard ImageNet statistics).
NORM_MEAN = np.array((0.485, 0.456, 0.406))
NORM_STD = np.array((0.229, 0.224, 0.225))

# RGB color for each segmentation class: row i is the color of class index i.
SEGMENTATION_COLORS = np.asarray(
    [
        (247, 195, 37),
        (41, 167, 224),
        (90, 75, 164),
    ],
    dtype=np.uint8,
)

# NOTE(review): BATCH_SIZE is not referenced by the single-image code path.
BATCH_SIZE = 12
# Default value of the --architecture CLI option.
ARCHITECTURE = 'wasr_resnet101_imu'
def get_arguments():
    """Parse the command-line arguments for single-image inference.

    Returns:
        argparse.Namespace with the attributes ``image``, ``output``,
        ``imu_mask``, ``architecture`` and ``weights``.
    """
    cli = argparse.ArgumentParser(description="WaSR Network MaSTr1325 Inference")

    # Required positional arguments, registered in order.
    positional = (
        ("image", "Path to the image to run inference on."),
        ("output", "Path to the file, where the output prediction will be saved."),
    )
    for arg_name, arg_help in positional:
        cli.add_argument(arg_name, type=str, help=arg_help)

    cli.add_argument(
        "--imu_mask",
        type=str,
        default=None,
        help="Path to the corresponding IMU mask (if needed by the model).",
    )
    cli.add_argument(
        "--architecture",
        type=str,
        choices=models.model_list,
        default=ARCHITECTURE,
        help="Model architecture.",
    )
    cli.add_argument(
        "--weights",
        type=str,
        required=True,
        help="Path to the model weights or a model checkpoint.",
    )

    parsed = cli.parse_args()
    return parsed
def predict_image(model, image, imu_mask=None, device='cuda'):
    """Run ``model`` on one prepared input batch and return softmax probabilities.

    Args:
        model: Callable that takes a dict with key ``'image'`` (and
            ``'imu_mask'`` when a mask is given) and returns a dict that
            contains the logits under ``'out'``.
        image: Input image tensor. It is moved to ``device`` before inference.
        imu_mask: Optional IMU mask tensor. When given, it is moved to
            ``device`` and passed to the model under ``'imu_mask'``.
        device: Device the inputs are moved to. Defaults to ``'cuda'``, which
            matches the previous hard-coded behavior. It must match the
            device the model lives on.

    Returns:
        Tensor of softmax probabilities over dim 1, on the CPU, with no
        autograd history.
    """
    feat = {'image': image.to(device)}
    if imu_mask is not None:
        feat['imu_mask'] = imu_mask.to(device)

    # Inference only: don't build an autograd graph. Building one and
    # detaching afterwards wastes memory and time.
    with torch.no_grad():
        res = model(feat)

    return res['out'].softmax(1).cpu()
def _load_model(architecture, weights_path):
    """Build ``architecture``, load its weights and return it in eval mode on CUDA.

    ``weights_path`` may hold either a bare state dict or a checkpoint dict
    that stores the state dict under the ``'model'`` key.
    """
    model = models.get_model(architecture, pretrained=False)
    # NOTE(review): torch.load unpickles the file, so only load weights from
    # trusted sources.
    state_dict = torch.load(weights_path, map_location='cpu')
    if 'model' in state_dict:
        # Loading weights from checkpoint
        state_dict = state_dict['model']
    model.load_state_dict(state_dict)
    return model.eval().cuda()


def _load_image(path):
    """Load an image and normalize it with NORM_MEAN / NORM_STD.

    Returns:
        Tuple ``(tensor, (height, width))``. ``tensor`` is float32 with
        shape [1, 3, H, W].
    """
    with Image.open(path) as pil_img:
        # Force 3 channels. Grayscale, RGBA and palette images would otherwise
        # break the shape unpacking / 3-channel normalization below.
        img = np.array(pil_img.convert('RGB'))
    height, width, _ = img.shape
    img = (img / 255.0 - NORM_MEAN) / NORM_STD
    tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).float()  # [1xCxHxW]
    return tensor, (height, width)


def _load_imu_mask(path):
    """Load an IMU mask as a boolean tensor with a leading batch dimension."""
    with Image.open(path) as pil_mask:
        mask = np.array(pil_mask)
    # np.bool was removed in NumPy 1.24; the builtin bool gives the same dtype.
    # Presumably the mask is single-channel (H, W), which gives [1xHxW] here.
    # TODO confirm.
    return torch.from_numpy(mask.astype(bool)).unsqueeze(0)


def _save_prediction(preds, output_path):
    """Color a [H, W] class-index tensor with SEGMENTATION_COLORS and save it."""
    preds_rgb = SEGMENTATION_COLORS[preds.numpy()]
    # exist_ok avoids the check-then-create race of a separate exists() test.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    Image.fromarray(preds_rgb).save(output_path)


def predict(args):
    """Run inference on a single image and save the colorized class prediction.

    Args:
        args: Parsed CLI namespace providing ``architecture``, ``weights``,
            ``image``, ``imu_mask`` (may be None) and ``output``.
    """
    model = _load_model(args.architecture, args.weights)
    img, (height, width) = _load_image(args.image)

    imu_mask = None
    if args.imu_mask is not None:
        imu_mask = _load_imu_mask(args.imu_mask)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        probs = predict_image(model, img, imu_mask)

    # Resize the probabilities back to the input resolution before argmax.
    probs = torch.nn.functional.interpolate(probs, (height, width), mode='bilinear')
    preds = probs.argmax(1)[0]

    _save_prediction(preds, args.output)
def main():
    """Command-line entry point: parse arguments, echo them, then run inference."""
    cli_args = get_arguments()
    print(cli_args)
    predict(cli_args)


if __name__ == '__main__':
    main()