-
Notifications
You must be signed in to change notification settings - Fork 0
/
embedding_loss.py
28 lines (22 loc) · 1008 Bytes
/
embedding_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch
import torch.nn as nn
import torch.nn.functional as F
class EmbeddingLoss(nn.Module):
    """MSE loss between target embeddings and embeddings re-extracted from a sequence.

    The wrapped feature extractor is frozen (its parameters get
    ``requires_grad = False``) so gradients only flow into the model that
    produced ``output_sequence``, not into the extractor itself.
    """

    def __init__(self, feature_extractor):
        super().__init__()
        self.feature_extractor = feature_extractor
        # Freeze the extractor: it serves as a fixed measurement device.
        for param in self.feature_extractor.parameters():
            param.requires_grad = False

    def forward(self, input_embeddings, output_sequence):
        """Compute the embedding-matching loss.

        Parameters:
            input_embeddings (torch.Tensor): Target embeddings to match;
                must have the same shape as the concatenation of the
                extractor's style and positional embeddings along dim 0.
            output_sequence (torch.Tensor): Sequence passed through the
                (frozen) feature extractor.

        Returns:
            torch.Tensor: Scalar MSE loss between the extracted embeddings
            and ``input_embeddings``.
        """
        # First return value (reconstructed sequence) is unused here.
        _, style_embedding, positional_embedding = self.feature_extractor(output_sequence)
        # Stack style and positional embeddings along dim 0 to mirror the
        # layout expected of input_embeddings.
        output_embeddings = torch.cat((style_embedding, positional_embedding), 0)
        return F.mse_loss(output_embeddings, input_embeddings)