-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathget_embed.py
39 lines (29 loc) · 1.19 KB
/
get_embed.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from facenet_pytorch import MTCNN, InceptionResnetV1
from tqdm import tqdm
from PIL import Image
import torch
import os
import json
# Build face embeddings for every person listed in data/id.json.
#
# Steps:
#   1. Read the ID list (each entry's 'image', 'first', 'last' keys are used).
#   2. Detect and align each person's face with MTCNN (160x160 crops).
#   3. Embed the aligned crops with InceptionResnetV1 pretrained on VGGFace2.
#   4. Save the embeddings to data/embeddings.pt.
#
# Saved dict:
#   'embedding': tensor stacked from the per-face embeddings, one row per
#                successfully aligned entry.
#   'index':     list of int; row i of 'embedding' belongs to ID[index[i]].
#                Entries with no detectable face are skipped, so rows do not
#                line up with ID positions unless every face was found.

# os.path.join instead of the literal 'data\id.json': that form only works on
# Windows, and '\i' is an invalid escape sequence (SyntaxWarning on 3.12+).
# Explicit UTF-8 so non-ASCII names load the same on every platform.
with open(os.path.join('data', 'id.json'), 'r', encoding='utf-8') as f:
    ID = json.load(f)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Running on device: {}'.format(device))

print("Aligning for faces from data")
mtcnn = MTCNN(image_size=160, device=device)
aligned = []
aligned_idx = []  # positions in ID of the entries that produced an aligned face
# enumerate(tqdm(...)) rather than tqdm(enumerate(...)) so tqdm knows the total.
for idx, person in enumerate(tqdm(ID)):
    fpath = os.path.join('./data/images', person['image'])
    with Image.open(fpath) as img:
        # Normalize RGBA / grayscale / palette images to 3-channel RGB before
        # detection; convert() on an RGB image just returns a copy.
        x_aligned = mtcnn(img.convert('RGB'))
    if x_aligned is not None:
        aligned.append(x_aligned)
        aligned_idx.append(idx)
    else:
        # tqdm.write keeps the message from corrupting the progress bar.
        tqdm.write(
            f"Error: No image of {person['first']} {person['last']} found in {person['image']}. "
            "You might wanna crop it and reduce the dimensions to around 700px on the shorter side."
        )

# torch.stack([]) would fail with an opaque RuntimeError; stop with a clear message.
if not aligned:
    raise SystemExit("Error: no faces could be aligned; nothing to embed.")
if len(aligned_idx) < len(ID):
    print(f"Warning: {len(ID) - len(aligned_idx)} of {len(ID)} entries have no embedding; "
          "use the saved 'index' list to map embedding rows back to id.json entries.")

print("Loading face recognition model")
resnet = InceptionResnetV1(pretrained='vggface2').eval().to(device)

print("Computing face embeddings")
aligned = torch.stack(aligned).to(device)
# Inference only: no_grad skips building an autograd graph over the batch,
# which the original discarded with .detach() anyway.
with torch.no_grad():
    embeddings = resnet(aligned).cpu()

# 'index' is an additional key; readers that only use 'embedding' are unaffected.
torch.save({'embedding': embeddings, 'index': aligned_idx}, 'data/embeddings.pt')
print("\n[All done]")