# predict.py
# coding=utf-8
# /************************************************************************************
# ***
# *** File Author: Dell, 2018年 08月 19日 星期日 20:38:18 CST
# ***
# ************************************************************************************/
import argparse
import os
import random
import skimage.io
import numpy as np
import config
import model as modellib
import data as datalib
import utils
import torch
# Project root: the working directory the script is started from.
ROOT_DIR = os.getcwd()

# Directory used for logs and trained models.
MODEL_DIR = os.path.join(ROOT_DIR, "logs")

# Weights file used when no `-model` option is given on the command line.
COCO_MODEL_PATH = os.path.join(ROOT_DIR, "models/mask_rcnn_coco.pth")

# Command line: an optional `-model` weights path and one positional image.
parser = argparse.ArgumentParser(description='Mask RCNN Predictor')
parser.add_argument(
    '-model',
    default=COCO_MODEL_PATH,
    type=str,
    help='trained model [{}]'.format(COCO_MODEL_PATH),
)
parser.add_argument('image', help='image file', type=str)
if __name__ == '__main__':
    # Script entry point: build the Mask R-CNN model, load its weights, run
    # detection on the image named on the command line, print one line per
    # detection and hand the results to utils.display_instances.

    # Only `skimage.io` is imported at file level; the grayscale conversion
    # below needs `skimage.color`, so import it explicitly here.
    import skimage.color

    random.seed()
    args = parser.parse_args()

    # Use a distinct name so the imported `config` module is not shadowed.
    inference_config = config.CocoInferenceConfig()

    # Create model object.
    model = modellib.MaskRCNN(model_dir=MODEL_DIR, config=inference_config)
    use_gpu = bool(inference_config.GPU_COUNT)
    if use_gpu:
        model = model.cuda()

    # Load the weights (the MS-COCO checkpoint unless -model overrides it).
    # When no GPU is configured, map the stored tensors to the CPU so that a
    # checkpoint saved from a CUDA device can still be loaded.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state_dict = torch.load(
        args.model, map_location=None if use_gpu else 'cpu')
    model.load_state_dict(state_dict)
    print(model)

    img = skimage.io.imread(args.image)
    if img.ndim != 3:
        # `gray2rgb` is the supported spelling; the `grey2rgb` alias was
        # deprecated and later removed from scikit-image.
        img = skimage.color.gray2rgb(img)

    # Run detection
    class_ids, scores, boxes, masks = model.detect(img)

    # Defined before branching so the "nothing detected" path below can pass
    # it too (it was previously unbound there and raised NameError).
    class_names = []
    if class_ids is not None:
        for class_id, box, score in zip(class_ids, boxes, scores):
            class_names.append(datalib.CocoLabel.name(class_id))
            print(class_id, datalib.CocoLabel.zh_name(class_id), box, score)
        utils.display_instances(img, np.array(boxes), np.array(masks),
                                np.array(class_ids), class_names,
                                np.array(scores))
    else:
        utils.display_instances(img, None, None, None, class_names, None)