-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict.py
49 lines (34 loc) · 1.49 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import argparse
import os
import joblib
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from data import IMAGE_TRANSFORMER_TORCH, cv2_face_segmentation
from model_utils import torch_active_device
def make_prediction(model_name, img):
    """Predict the facial-emotion label for a single image.

    Loads a saved PyTorch model and its label encoder from the directory
    ``model_name``, converts the image to grayscale, passes it through
    ``cv2_face_segmentation`` and ``IMAGE_TRANSFORMER_TORCH``, runs the model,
    and decodes the arg-max class index back to a label.

    The directory is expected to contain two files named after the directory's
    last path component: ``<stem>`` (the model passed to ``torch.load``) and
    ``<stem>.joblib`` (the label encoder). ``model_name`` may be a bare name
    (``cnn``) or a path (``models/cnn`` or ``models/cnn/``).

    :param model_name: Path to the model directory.
    :param img: Image file path (or file object) accepted by ``PIL.Image.open``.
    :return: Tuple ``(label, bounds)``. ``label`` is the decoded prediction;
        ``bounds`` is whatever ``cv2_face_segmentation`` returns as its second
        value (format defined by that helper).
    """
    model_dir = os.path.normpath(model_name)
    # The files inside the directory are named after its last path component,
    # so "models/cnn" resolves to models/cnn/cnn, not models/cnn/models/cnn.
    stem = os.path.basename(model_dir)

    # Load model and LabelEncoder.
    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted
    # model files. On newer PyTorch releases (weights_only defaults to True)
    # loading a full nn.Module may require weights_only=False; confirm against
    # the installed version.
    model = torch.load(os.path.join(model_dir, stem), map_location=torch_active_device)
    target_encoder = joblib.load(os.path.join(model_dir, f"{stem}.joblib"))
    model.to(torch_active_device)
    model.eval()

    # Read the image as 8-bit grayscale; the context manager closes the file.
    with Image.open(img) as pil_img:
        gray = np.array(pil_img.convert('L'), dtype='uint8')

    # Segment the face if found.
    segmented_face, segmented_bounds = cv2_face_segmentation(gray)
    # Add a batch dimension and put the input on the same device as the model,
    # otherwise a GPU-resident model fails with a device-mismatch error.
    x = IMAGE_TRANSFORMER_TORCH(segmented_face).unsqueeze(0).to(torch_active_device)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(x=x)

    class_idx = F.softmax(logits, dim=1).argmax(-1).cpu().numpy()
    y_pred = target_encoder.inverse_transform(class_idx)[0]
    return y_pred, segmented_bounds
if __name__ == "__main__":
    # Command-line entry point: both options are mandatory string paths.
    cli = argparse.ArgumentParser(description='Facial Emotion Recognition - Prediction')
    option_specs = (
        ('--image-file', 'Image file path.'),
        ('--model-path', 'Path to saved model.'),
    )
    for flag, help_text in option_specs:
        cli.add_argument(flag, type=str, action='store', help=help_text, required=True)
    cli_args = cli.parse_args()

    # Prints the (label, bounds) tuple returned by make_prediction.
    result = make_prediction(cli_args.model_path, cli_args.image_file)
    print(result)