-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathlosses.py
120 lines (98 loc) · 4.46 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import random
from pytorch_metric_learning import miners, losses
def binarize(T, nb_classes, device=None):
    """One-hot encode a tensor of class labels.

    Args:
        T: Tensor of class labels; it is flattened to 1-D. A label that is not
            one of ``0 .. nb_classes - 1`` produces an all-zero row, as the
            previous ``sklearn.preprocessing.label_binarize`` version did.
        nb_classes: Number of classes, i.e. the number of output columns.
        device: Device for the result. When omitted, the result is placed on
            CUDA if it is available (the historical behaviour) and on ``T``'s
            own device otherwise, so CPU-only runs no longer crash.

    Returns:
        A float32 tensor of shape ``(T.numel(), nb_classes)``.
    """
    if device is None:
        device = torch.device('cuda') if torch.cuda.is_available() else T.device
    labels = T.reshape(-1, 1)
    class_ids = torch.arange(nb_classes, device=labels.device)
    # Comparing each label with every class id builds the one-hot matrix in
    # place on the label's device, with no NumPy round trip. Unlike
    # label_binarize, this still yields two columns when nb_classes == 2
    # instead of collapsing the binary case to a single column.
    one_hot = (labels == class_ids).float()
    return one_hot.to(device)
def l2_norm(input):
    """Rescale every row of a 2-D tensor to unit Euclidean length.

    The constant 1e-12 is added to each squared row norm before the square
    root. An all-zero row therefore maps to zeros instead of NaN.
    """
    original_shape = input.size()
    eps = 1e-12
    row_norms = (input.pow(2).sum(1) + eps).sqrt()
    denominator = row_norms.view(-1, 1).expand_as(input)
    return input.div(denominator).view(original_shape)
class Proxy_Anchor(torch.nn.Module):
    """Proxy Anchor loss (Kim et al., CVPR 2020).

    A learnable proxy is kept for every class (a ``(nb_classes, sz_embed)``
    parameter). Each proxy acts as an anchor. It pulls the batch embeddings
    of its own class towards it and pushes the other embeddings away. Both
    directions use the cosine similarity between embeddings and proxies.

    Args:
        nb_classes: Number of classes, i.e. number of proxies.
        sz_embed: Embedding dimensionality.
        mrg: Margin applied to both the positive and the negative term.
        alpha: Scaling factor applied to the margin-shifted similarities.
    """

    def __init__(self, nb_classes, sz_embed, mrg = 0.1, alpha = 32):
        torch.nn.Module.__init__(self)
        # Proxy Anchor initialization. The random tensor is still drawn on the
        # CPU (same RNG consumption as before) and then moved to the GPU when
        # one exists. Previously the move was unconditional and failed on
        # CPU-only machines.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.proxies = torch.nn.Parameter(torch.randn(nb_classes, sz_embed).to(device))
        nn.init.kaiming_normal_(self.proxies, mode='fan_out')

        self.nb_classes = nb_classes
        self.sz_embed = sz_embed
        self.mrg = mrg
        self.alpha = alpha

    def forward(self, X, T):
        """Compute the loss for one batch.

        Args:
            X: Batch of embeddings. Each row is compared with every proxy, so
                presumably shaped (batch, sz_embed); TODO confirm with caller.
            T: Class label of each row of ``X``.

        Returns:
            Scalar loss tensor.
        """
        P = self.proxies

        # Calculate cosine similarity between every embedding and every proxy.
        cos = F.linear(l2_norm(X), l2_norm(P))
        P_one_hot = binarize(T = T, nb_classes = self.nb_classes).to(cos.device)
        N_one_hot = 1 - P_one_hot

        # Paper Eq. (4): positives use exp(-alpha * (s - mrg)) and negatives use
        # exp(alpha * (s + mrg)). The margin must be *added* on the negative
        # side. Subtracting it, as before, effectively turned it into a
        # negative margin and did not separate positives from negatives.
        pos_exp = torch.exp(-self.alpha * (cos - self.mrg))
        neg_exp = torch.exp(self.alpha * (cos + self.mrg))

        # The set of proxies that have at least one positive sample in the batch.
        with_pos_proxies = torch.nonzero(P_one_hot.sum(dim = 0) != 0).squeeze(dim = 1)
        num_valid_proxies = len(with_pos_proxies)

        # torch.where (rather than multiplying by the mask) keeps unselected
        # entries exactly zero.
        P_sim_sum = torch.where(P_one_hot == 1, pos_exp, torch.zeros_like(pos_exp)).sum(dim=0)
        N_sim_sum = torch.where(N_one_hot == 1, neg_exp, torch.zeros_like(neg_exp)).sum(dim=0)

        # If no proxy has a positive in the batch (e.g. an empty batch), the
        # positive sum is 0. Dividing by 1 instead of 0 avoids a NaN loss.
        pos_term = torch.log(1 + P_sim_sum).sum() / max(num_valid_proxies, 1)
        neg_term = torch.log(1 + N_sim_sum).sum() / self.nb_classes
        loss = pos_term + neg_term
        return loss
# The following loss classes are thin wrappers around the PyTorch Metric Learning library.
# Please refer to "https://github.com/KevinMusgrave/pytorch-metric-learning" for details.
class Proxy_NCA(torch.nn.Module):
    """Proxy-NCA loss delegated to ``pytorch_metric_learning.losses.ProxyNCALoss``.

    Args:
        nb_classes: Forwarded as the PML loss's ``num_classes``.
        sz_embed: Forwarded as the PML loss's ``embedding_size``.
        scale: Forwarded as the PML loss's ``softmax_scale``.
    """

    def __init__(self, nb_classes, sz_embed, scale=32):
        super(Proxy_NCA, self).__init__()
        self.nb_classes = nb_classes
        self.sz_embed = sz_embed
        self.scale = scale
        loss_func = losses.ProxyNCALoss(num_classes = self.nb_classes, embedding_size = self.sz_embed, softmax_scale = self.scale)
        # The PML loss is a Module (presumably holding learnable proxies). Keep
        # it on the GPU when one exists, as before; the previous unconditional
        # `.cuda()` failed on CPU-only machines.
        if torch.cuda.is_available():
            loss_func = loss_func.cuda()
        self.loss_func = loss_func

    def forward(self, embeddings, labels):
        """Return the PML Proxy-NCA loss for ``embeddings`` with ``labels``."""
        loss = self.loss_func(embeddings, labels)
        return loss
class MultiSimilarityLoss(torch.nn.Module):
    """Multi-Similarity loss on pairs selected by the Multi-Similarity miner (PML).

    The former hard-coded hyperparameters are now constructor arguments.
    Their defaults equal the old constants, so ``MultiSimilarityLoss()``
    behaves exactly as before.

    Args:
        thresh: Forwarded as ``base`` of ``losses.MultiSimilarityLoss``.
        epsilon: Forwarded as ``epsilon`` of ``miners.MultiSimilarityMiner``.
        scale_pos: Forwarded as ``alpha`` of ``losses.MultiSimilarityLoss``.
        scale_neg: Forwarded as ``beta`` of ``losses.MultiSimilarityLoss``.
    """

    def __init__(self, thresh=0.5, epsilon=0.1, scale_pos=2, scale_neg=50):
        super(MultiSimilarityLoss, self).__init__()
        self.thresh = thresh
        self.epsilon = epsilon
        self.scale_pos = scale_pos
        self.scale_neg = scale_neg

        self.miner = miners.MultiSimilarityMiner(epsilon=self.epsilon)
        # Keyword arguments instead of the previous positional order
        # (alpha, beta, base), so the mapping is explicit.
        self.loss_func = losses.MultiSimilarityLoss(
            alpha=self.scale_pos,
            beta=self.scale_neg,
            base=self.thresh,
        )

    def forward(self, embeddings, labels):
        """Mine pairs from the batch and return the loss over the mined pairs."""
        hard_pairs = self.miner(embeddings, labels)
        loss = self.loss_func(embeddings, labels, hard_pairs)
        return loss
class ContrastiveLoss(nn.Module):
    """Contrastive loss backed by PML ``losses.ContrastiveLoss``.

    Args:
        margin: Forwarded to the PML loss as ``neg_margin``.
        **kwargs: Accepted so callers may pass extra options; they are ignored.
    """

    def __init__(self, margin=0.5, **kwargs):
        super().__init__()
        self.margin = margin
        self.loss_func = losses.ContrastiveLoss(neg_margin=margin)

    def forward(self, embeddings, labels):
        """Return the contrastive loss over all pairs in the batch."""
        return self.loss_func(embeddings, labels)
class TripletLoss(nn.Module):
    """Triplet-margin loss on semi-hard triplets mined by PML.

    Args:
        margin: Used both by the ``TripletMarginMiner`` and by the
            ``TripletMarginLoss``.
        **kwargs: Accepted so callers may pass extra options; they are ignored.
    """

    def __init__(self, margin=0.1, **kwargs):
        super().__init__()
        self.margin = margin
        # The miner and the loss share the same margin value.
        self.miner = miners.TripletMarginMiner(margin, type_of_triplets = 'semihard')
        self.loss_func = losses.TripletMarginLoss(margin = margin)

    def forward(self, embeddings, labels):
        """Mine semi-hard triplets and return the loss computed on them."""
        mined = self.miner(embeddings, labels)
        return self.loss_func(embeddings, labels, mined)
class NPairLoss(nn.Module):
    """N-pairs loss backed by PML ``losses.NPairsLoss``.

    Args:
        l2_reg: Forwarded to the PML loss as ``l2_reg_weight``. The loss is
            built with ``normalize_embeddings=False``.
    """

    def __init__(self, l2_reg=0):
        super().__init__()
        self.l2_reg = l2_reg
        self.loss_func = losses.NPairsLoss(
            l2_reg_weight=l2_reg,
            normalize_embeddings=False,
        )

    def forward(self, embeddings, labels):
        """Return the N-pairs loss for the batch."""
        return self.loss_func(embeddings, labels)