loss_func_utils.py
import torch
import torch.nn as nn

from experiment_logger import get_logger


class BaseLossFunc(nn.Module):
    """Base class for training losses used with DIORA. Subclasses override
    the hooks below and implement forward, returning a (loss, ret) pair."""

    def __init__(self, **kwargs):
        super().__init__()

        # defaults
        self.enabled = True
        self.unregistered = {}
        self.logger = get_logger()
        self.init_defaults()

        # override defaults with constructor kwargs
        for k, v in kwargs.items():
            if k in self.unregistered:
                self.unregistered[k] = v
                continue
            setattr(self, k, v)

        self.init_parameters()
        self.reset_parameters()

    @classmethod
    def from_kwargs_dict(cls, context, kwargs_dict):
        kwargs_dict['_cuda'] = context['cuda']
        return cls(**kwargs_dict)

    def init_defaults(self):
        pass

    def init_parameters(self):
        pass

    def reset_parameters(self):
        params = [p for p in self.parameters() if p.requires_grad]
        for i, param in enumerate(params):
            param.data.normal_()

    def forward(self, sentences, neg_samples, diora, info, embed=None):
        # Subclasses return a (loss, ret) pair, e.g.:
        # loss = 0
        # ret = {}
        # ret[self.name] = loss
        # return loss, ret
        raise NotImplementedError
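

# --- Illustrative sketch (not part of the original module) ---
# A minimal example of how a concrete loss is expected to plug into
# BaseLossFunc: declare defaults in init_defaults, build parameters in
# init_parameters, and return (loss, ret) from forward. The class name, the
# sizes, and the assumption that `diora.outside_h` holds outside vectors of
# shape (batch, length, size) and that `embed` is the shared embedding module
# are hypothetical and only illustrate the contract.
class ExampleReconstructLoss(BaseLossFunc):
    name = 'example_reconstruct_loss'

    def init_defaults(self):
        # Hypothetical defaults; `_cuda` is injected by from_kwargs_dict.
        self.input_size = 400
        self.size = 400
        self._cuda = False

    def init_parameters(self):
        # Projection matrix used to score cells against token embeddings.
        # Initialized by reset_parameters in the base constructor.
        self.mat = nn.Parameter(torch.FloatTensor(self.size, self.input_size))

    def forward(self, sentences, neg_samples, diora, info, embed=None):
        batch_size, length = sentences.shape
        # Score each outside cell against the true token and the negatives,
        # then apply the duplicate-masked cross entropy defined below.
        cell = diora.outside_h[:, :length]
        score = scores_for_cross_entropy(sentences, neg_samples, cell, embed, self.mat)
        xent = cross_entropy(sentences, neg_samples, score, mask_duplicates=True)
        loss = xent.mean(dim=-1).sum()
        ret = {self.name: loss}
        return loss, ret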


def scores_for_cross_entropy(sentences, neg_samples, cell, embeddings, mat):
    """Score each cell against its gold token and the k negative samples.

    Returns shape (batch_size, length, 1 + k); column 0 is the positive score.
    """
    batch_size, length = sentences.shape
    k = neg_samples.shape[0]
    emb_pos = embeddings(sentences)
    emb_neg = embeddings(neg_samples.unsqueeze(0))
    cell = cell.view(batch_size, length, 1, -1)
    proj_pos = torch.matmul(emb_pos, torch.t(mat))
    proj_neg = torch.matmul(emb_neg, torch.t(mat))
    # Dot product of each cell with its own token (xp) and each negative (xn).
    xp = torch.einsum('abc,abxc->abx', proj_pos, cell)
    xn = torch.einsum('zec,abxc->abe', proj_neg, cell)
    score = torch.cat([xp, xn], 2)
    return score


def scores_for_tokens(tokens, cell, embeddings, mat):
    """Score every token in a 1-D tensor of token ids against every cell.

    Returns shape (batch_size, length, num_tokens).
    """
    assert len(tokens.shape) == 1
    batch_size, length, size = cell.shape
    k = tokens.shape[0]
    emb = embeddings(tokens.unsqueeze(0))
    cell = cell.view(batch_size, length, 1, size)
    proj = torch.matmul(emb, torch.t(mat))
    score = torch.einsum('zec,abxc->abe', proj, cell)
    return score


def cross_entropy(sentences, neg_samples, score, mask_duplicates=True):
    """Negative log-likelihood of the gold token, where column 0 of `score`
    is the positive score and the remaining columns are negative samples."""
    batch_size, length = sentences.shape
    k = neg_samples.shape[0]

    # Ignore duplicates of the ground truth among the negative samples.
    mask = sentences.view(batch_size, length, 1) == neg_samples.view(1, 1, k)
    if mask.shape[0] == 0 or not mask_duplicates:
        mask = None

    def logsumexp(x, dim=-1, eps=1e-8, mask=None):
        if mask is None:
            argval, argmax = torch.max(x, dim=dim, keepdim=True)
            exp = torch.exp(x - argval)
        else:
            # Optionally mask elements.
            x = x * 1  # copy x
            x[:, :, 1:][mask] = x.min().data  # so that the max is not affected
            argval, argmax = torch.max(x, dim=dim, keepdim=True)
            exp = torch.exp(x - argval)
            # Zero out the masked entries of exp as well.
            diff = torch.zeros(x.shape, dtype=torch.float, device=x.device)
            diff[:, :, 1:][mask] = exp[:, :, 1:][mask]
            exp = exp - diff
        return argval.squeeze(dim) + torch.log(torch.sum(exp, dim=dim) + eps)

    xent = -score[:, :, 0] + logsumexp(score, dim=-1, mask=mask)
    return xent
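

# --- Illustrative usage sketch (not part of the original module) ---
# A small self-contained check of the scoring and cross-entropy helpers using
# random tensors in place of real DIORA cells and a real vocabulary. All
# shapes and sizes here are arbitrary and chosen only for the demonstration.
if __name__ == '__main__':
    torch.manual_seed(0)

    batch_size, length, size, vocab, k = 2, 5, 8, 20, 3

    embeddings = nn.Embedding(vocab, size)
    mat = torch.randn(size, size)

    sentences = torch.randint(0, vocab, (batch_size, length))
    neg_samples = torch.randint(0, vocab, (k,))
    cell = torch.randn(batch_size, length, size)

    # score[:, :, 0] is the ground-truth score, score[:, :, 1:] the negatives.
    score = scores_for_cross_entropy(sentences, neg_samples, cell, embeddings, mat)
    assert score.shape == (batch_size, length, 1 + k)

    xent = cross_entropy(sentences, neg_samples, score)
    print('cross entropy per position:', xent.shape)  # (batch_size, length)

    # scores_for_tokens scores every token in a 1-D id tensor against each cell.
    tokens = torch.arange(vocab)
    token_scores = scores_for_tokens(tokens, cell, embeddings, mat)
    assert token_scores.shape == (batch_size, length, vocab)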