-
Notifications
You must be signed in to change notification settings - Fork 1
/
losses.py
97 lines (73 loc) · 3.32 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import torch
from torch import nn
from torch.nn import functional
import math
import numpy as np
class Arcface(nn.Module):
    """
    An implementation of ArcFace: Additive Angular Margin Loss for Deep Face Recognition: https://arxiv.org/pdf/1801.07698.pdf

    The module returns scaled logits intended for a softmax cross-entropy loss.
    In training mode the target-class logit is ``s * cos(theta_y + m)``, and every
    other logit is ``s * cos(theta_j)``, where ``theta`` is the angle between the
    L2-normalised embedding and the L2-normalised class weight. In eval mode no
    margin is applied, so the output is simply ``s * cos(theta)``.

    Args:
        embedding_size (int): Feature dimension.
        classnum (int): Number of total classes.
        m (float): Margin value, see the paper for details. Default: 0.5.
        s (float): The scale value, see the paper for details. Default: 64.
    """

    def __init__(self, embedding_size, classnum, m=0.5, s=64.):
        super(Arcface, self).__init__()
        # One weight vector per class, shape (classnum, embedding_size); it is
        # L2-normalised on every forward pass.
        self.kernel = nn.Parameter(torch.empty(classnum, embedding_size))
        nn.init.xavier_uniform_(self.kernel)
        self.classnum = classnum
        self.s = s
        self.m = m

    def forward(self, embbedings, label):
        """Compute the scaled ArcFace logits.

        Args:
            embbedings (Tensor): (N, embedding_size) features.
            label (Tensor): N integer class indices (int64, as required by
                ``scatter_``).

        Returns:
            Tensor: (N, classnum) logits multiplied by ``self.s``.
        """
        # Use a local margin. Assigning to self.m here would permanently
        # disable the margin for all training after the first eval pass.
        m = self.m if self.training else 0.
        cos_m = math.cos(m)
        sin_m = math.sin(m)
        # cos(theta + m) only decreases monotonically while theta + m <= pi,
        # i.e. while cos(theta) > cos(pi - m). Beyond that point, use the linear
        # fallback cos(theta) - sin(pi - m) * m, as the reference implementation does.
        threshold = math.cos(math.pi - m)
        fallback_penalty = math.sin(math.pi - m) * m

        cos_theta = functional.linear(functional.normalize(embbedings), functional.normalize(self.kernel))
        # Rounding can push |cos| slightly above 1. Clamp so sqrt never sees a
        # negative value, which would produce NaN.
        sin_theta = torch.sqrt((1.0 - torch.pow(cos_theta, 2)).clamp(0, 1))
        cos_theta_m = cos_theta * cos_m - sin_theta * sin_m
        cos_theta_m = torch.where(cos_theta > threshold, cos_theta_m, cos_theta - fallback_penalty)
        # One-hot mask that selects the ground-truth column of each row.
        one_hot = torch.zeros_like(cos_theta).scatter_(1, label.view(-1, 1), 1)
        output = (one_hot * cos_theta_m) + (1.0 - one_hot) * cos_theta
        return output * self.s
class SVSoftmax(nn.Module):
    """
    An implementation of Support Vector Guided Softmax Loss for Face Recognition: https://arxiv.org/pdf/1812.11317.pdf

    The module returns scaled logits intended for a softmax cross-entropy loss.
    In training mode:

    * The target-class logit is the ArcFace-penalised cosine
      ``f(m, theta_y) = cos(theta_y + m)``.
    * A non-target class k is a "support vector" (mis-classified) when
      ``cos(theta_k) > f(m, theta_y)``. Its logit becomes ``t * cos(theta_k) + t - 1``.
    * All other non-target logits stay ``cos(theta_k)``.

    Every logit is then multiplied by ``s``. In eval mode the margin and the
    indicator reweighting are disabled (m = 0, t = 1), so the output is
    ``s * cos(theta)``.

    Args:
        embedding_size (int): Feature dimension, e.g. 512.
        classnum (int): Number of total classes.
        s (float): The scale value. Default: 30.
        t (float): Indicator parameter, see the paper for detailed introduction. Default: 1.2.
        m: Margin value used in Arcface Loss. Default: 0.5.
    Notes:
        This implementation is based on arcface.
    """

    def __init__(self, embedding_size, classnum, s=30., t=1.2, m=0.5):
        super(SVSoftmax, self).__init__()
        # One weight vector per class, shape (classnum, embedding_size); it is
        # L2-normalised on every forward pass.
        self.kernel = nn.Parameter(torch.empty(classnum, embedding_size))
        nn.init.xavier_uniform_(self.kernel)
        self.classnum = classnum
        self.m = m
        self.s = s
        self.t = t

    def forward(self, embbedings, label):
        """Compute the scaled SV-Arc-Softmax logits.

        Args:
            embbedings (Tensor): (N, embedding_size) features.
            label (Tensor): N integer class indices (int64, as required by
                ``gather``/``scatter_``).

        Returns:
            Tensor: (N, classnum) logits multiplied by ``self.s``.
        """
        # Use local hyper-parameters. Assigning to self.m / self.t here would
        # permanently disable them for all training after the first eval pass.
        if self.training:
            m, t = self.m, self.t
        else:
            m, t = 0., 1.
        cos_m = math.cos(m)
        sin_m = math.sin(m)
        # cos(theta + m) only decreases monotonically while theta + m <= pi,
        # i.e. while cos(theta) > cos(pi - m). Beyond that point, use the linear
        # ArcFace fallback cos(theta) - sin(pi - m) * m.
        threshold = math.cos(math.pi - m)
        fallback_penalty = math.sin(math.pi - m) * m

        cos_theta = functional.linear(functional.normalize(embbedings), functional.normalize(self.kernel))
        # Rounding can push |cos| slightly above 1. Clamp so sqrt never sees a
        # negative value, which would produce NaN.
        sin_theta = torch.sqrt((1.0 - torch.pow(cos_theta, 2)).clamp(0, 1))
        cos_theta_m = cos_theta * cos_m - sin_theta * sin_m
        cos_theta_m = torch.where(cos_theta > threshold, cos_theta_m, cos_theta - fallback_penalty)

        label = label.view(-1, 1)
        # Margin-penalised ground-truth logit f(m, theta_y) of each sample, shape (N, 1).
        # gather keeps the index on the same device as the logits.
        target_m = cos_theta_m.gather(1, label)
        # Paper's indicator: I_k = 1 (support vector) iff cos(theta_k) > f(m, theta_y).
        cos_theta_t = t * cos_theta + t - 1
        negatives = torch.where(cos_theta > target_m, cos_theta_t, cos_theta)
        # One-hot mask that selects the ground-truth column of each row.
        one_hot = torch.zeros_like(cos_theta).scatter_(1, label, 1)
        output = (one_hot * cos_theta_m) + (1.0 - one_hot) * negatives
        return output * self.s