-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy patheval.py
72 lines (41 loc) · 1.41 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import numpy as np
import cv2
from tqdm import tqdm
from train_config import config as cfg
from lib.dataset.augmentor.augmentation import CenterCrop
from lib.core.api.classifier import Shufflenet
def _parse_val_line(line):
    """Split one ``<image_path>|<label>`` annotation line into ``(path, int_label)``.

    Splitting on the *last* ``|`` keeps any ``|`` characters inside the path.

    Raises:
        ValueError: if the line has no ``|`` separator or the label is not an int.
    """
    img_path, sep, label_str = line.rpartition('|')
    if not sep:
        raise ValueError('malformed annotation line (expected "path|label"): %r' % line)
    return img_path, int(label_str)


def eval(models, val_file='val.txt'):
    """Run a Shufflenet classifier over a validation list and print top-1/top-5 error.

    NOTE: the name shadows the ``eval`` builtin; it is kept for compatibility
    with existing callers.

    Args:
        models: model handed straight to ``Shufflenet`` (the CLI describes it
            as "the tensorflow model").
        val_file: annotation file with one ``<image_path>|<label>`` entry per
            line. Blank lines are ignored. Defaults to ``'val.txt'``.

    Returns:
        Tuple ``(top1_err, top5_err)`` -- the same values that are printed.

    Raises:
        ValueError: if ``val_file`` contains no entries, or a line is malformed.
        FileNotFoundError: if an image listed in ``val_file`` cannot be read.
    """
    # Read the list first so an empty/missing file fails before the model loads.
    with open(val_file, 'r') as f:
        # Skip blank lines (e.g. a trailing newline) instead of failing on int('').
        val_lines = [line.rstrip() for line in f if line.strip()]
    if not val_lines:
        raise ValueError('no validation entries found in %s' % val_file)

    classifier = Shufflenet(models)
    center_crop = CenterCrop(target_size=224, resize_size=256)

    top1_right = 0
    top5_right = 0
    for line in tqdm(val_lines):
        img_path, label = _parse_val_line(line)

        image = cv2.imread(img_path, cv2.IMREAD_COLOR)
        # cv2.imread signals failure by returning None rather than raising.
        if image is None:
            raise FileNotFoundError('cannot read image: %s' % img_path)
        # OpenCV decodes to BGR; convert only when the config asks for RGB.
        if cfg.DATA.rgb:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = center_crop(image)

        # Prepend a batch axis of size 1 for the classifier.
        batch = np.expand_dims(image, axis=0)
        # NOTE(review): assumes run(...)[0] is the 1-D score vector for this
        # single image -- confirm against Shufflenet.run.
        scores = np.asarray(classifier.run(batch)[0])
        # Class indices ordered from highest to lowest score.
        ranked = np.argsort(-scores)

        if ranked[0] == label:
            top1_right += 1
        if label in ranked[:5]:
            top5_right += 1

    total = len(val_lines)
    top1_err = 1 - top1_right / total
    top5_err = 1 - top5_right / total
    print('top1 err:%f' % top1_err)
    print('top5 err:%f' % top5_err)
    return top1_err, top5_err
if __name__ == '__main__':
    # Command-line entry point: evaluate the model given by --model.
    import argparse

    ap = argparse.ArgumentParser(description='Evaluate top-1/top-5 error of a classifier.')
    # No default: the argument is required, so a default could never be used.
    ap.add_argument("--model", required=True, help="the tensorflow model")
    args = ap.parse_args()
    eval(args.model)