-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathpredict.py
131 lines (101 loc) · 4.74 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import os
import numpy as np
import pandas as pd
from PIL import Image
import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from net import DeepRank
from utils import euclidean_distance, data_transforms, DatasetImageNet
# -- path info (relative paths, resolved against the current working directory)
TRIPLET_PATH: str = 'triplet.csv'      # triplet CSV; its 'query' column names the training images
MODEL_PATH: str = 'deeprank.pt'        # saved DeepRank state_dict
EMBEDDING_PATH: str = 'embedding.txt'  # cached training embeddings (raw float32 bytes despite the .txt suffix)
class Prediction:
    """Rank training images by embedding distance to a query image.

    On construction the DeepRank weights are loaded from MODEL_PATH, the
    training triplets are read from TRIPLET_PATH (one row per unique 'query'
    image), and the training embeddings are read from EMBEDDING_PATH. If
    that file is missing, it is generated first.
    """

    # Width of one embedding vector. The embedding file is reshaped with this
    # value, so it must match the DeepRank output size used to write the file.
    EMBEDDING_DIM = 4096

    def __init__(self):
        self.model = DeepRank()
        # map_location='cpu' lets a checkpoint saved from a GPU run load on a
        # CPU-only machine; load_state_dict then copies it into the model.
        self.model.load_state_dict(torch.load(MODEL_PATH, map_location='cpu'))
        self.model.eval()
        # Row i of train_df is matched with row i of the embedding matrix.
        # NOTE(review): this relies on DatasetImageNet(embedding=True) yielding
        # images in the same de-duplicated order -- confirm in utils.
        self.train_df = (
            pd.read_csv(TRIPLET_PATH)
            .drop_duplicates('query', keep='first')
            .reset_index(drop=True)
        )
        # check embedding
        if not os.path.exists(EMBEDDING_PATH):
            print('pre-generated [embedding.txt] not exist!')
            self.embedding()
        # The file holds raw float32 values written with ndarray.tofile.
        self.train_embedded = np.fromfile(EMBEDDING_PATH, dtype=np.float32).reshape(-1, self.EMBEDDING_DIM)

    def _device(self):
        """Return the device the model's parameters currently live on."""
        try:
            return next(self.model.parameters()).device
        except StopIteration:  # parameter-less model: nothing was ever moved
            return torch.device('cpu')

    def embedding(self):
        """Embed every training image and write the result to EMBEDDING_PATH.

        The model is moved to CUDA when available. Images are run through the
        model one at a time, the outputs are concatenated along axis 0, and
        the matrix is written as raw float32.
        """
        print(' ==> Generate embedding...', end='')
        self.model.eval()  # set to eval mode
        if torch.cuda.is_available():
            self.model.to('cuda')
        device = self._device()
        train_dataset = DatasetImageNet(TRIPLET_PATH, embedding=True, transform=data_transforms['val'])
        embedded_images = []
        with torch.no_grad():  # inference only: skip autograd bookkeeping
            for idx in range(len(train_dataset)):
                input_batch = train_dataset[idx][0].unsqueeze(0).to(device)
                embedded_images.append(self.model(input_batch).cpu().numpy())
        if embedded_images:
            embedded_images_train = np.concatenate(embedded_images, axis=0)
        else:  # empty dataset: write an empty file instead of failing in concatenate
            embedded_images_train = np.empty((0, self.EMBEDDING_DIM), dtype=np.float32)
        embedded_images_train.astype('float32').tofile(EMBEDDING_PATH)  # save embedding result
        print('done! [embedding.txt] generated')

    def query_embedding(self, query_image_path):
        """Return the model's embedding of one query image as a numpy array.

        The image is converted to RGB, passed through data_transforms['val'],
        given a leading batch axis, and sent to the same device as the model.
        This matters because embedding() may have moved the model to CUDA.
        """
        print(f'Query image [{query_image_path}] embedding...', end='')
        with Image.open(query_image_path) as raw_image:  # closes the file handle promptly
            query_image = raw_image.convert('RGB')
        query_image = data_transforms['val'](query_image)
        query_batch = query_image[None].to(self._device())  # add batch axis
        self.model.eval()  # set to eval mode
        with torch.no_grad():
            embedding = self.model(query_batch)
        print('done!')
        return embedding.cpu().numpy()

    def save_result(self, result, result_num, result_name):
        """Save the query image and its ranked matches side by side.

        Args:
            result: list of (distance, image_path) pairs; the first is the query.
            result_num: requested number of matches. At most len(result)
                panels are drawn, so a shorter result list no longer raises.
            result_name: output path passed to savefig.
        """
        print('Save predicted result ...', end='')
        columns = min(result_num + 1, len(result))
        fig = plt.figure(figsize=(64, 64))
        try:
            for i, (dist, img_path) in enumerate(result[:columns], start=1):
                img = mpimg.imread(img_path)  # read image
                ax = fig.add_subplot(1, columns, i)
                if i == 1:  # query image
                    ax.set_title("query image", fontsize=50)
                else:  # ranked matches
                    ax.set_title("img_:" + str(i - 1), fontsize=50)
                    ax.set(xlabel='l2-dist=' + str(dist))
                ax.xaxis.label.set_fontsize(25)
                ax.imshow(img, cmap='Greys_r')
            fig.savefig(result_name)  # save as file
        finally:
            # pyplot keeps every figure alive until closed; repeated predict()
            # calls would otherwise accumulate large figures.
            plt.close(fig)
        print('done!')

    def predict(self, query_image_path, result_num, save_as='result.png'):
        """Find the result_num training images closest to the query image.

        Prints the ranked list, saves a comparison figure to save_as, and
        returns the list of (distance, image_path) pairs. Returns None after
        printing an error if query_image_path does not exist.
        """
        if not os.path.exists(query_image_path):
            print(f'[ERROR] invalid query image path: {query_image_path}')
            return None
        query_embedded = self.query_embedding(query_image_path)
        image_dist = euclidean_distance(self.train_embedded, query_embedded)
        # Sort (distance, row index) pairs by distance only (stable on ties).
        ranked = sorted(zip(image_dist, range(image_dist.shape[0])), key=lambda pair: pair[0])
        predicted_images = [(dist, self.train_df.loc[row, "query"]) for dist, row in ranked[:result_num]]
        print(predicted_images)
        self.save_result([(0.0, query_image_path)] + predicted_images, result_num, result_name=save_as)
        return predicted_images
def main():
    """Run a top-5 similarity search for the selected sample query image(s).

    Each query's comparison figure is written to result_<position>.png.
    """
    # NOTE(review): the directory names look mis-encoded (probably non-ASCII
    # originally) -- verify they match the folders on disk.
    candidate_paths = (
        'patent_image/train/λ/4020020037823.jpg',
        'patent_image/train/μμ/3020020001785M010.jpg',
        'patent_image/train/νμ/4019880001850.jpg',
    )
    predictor = Prediction()
    # Only the second sample is queried; the others are ready-made alternatives.
    selected = [candidate_paths[1]]
    for position, path in enumerate(selected):
        predictor.predict(path, 5, f'result_{position}.png')


if __name__ == '__main__':
    main()