From 4c056cb919f873103b0b058aeaa5a86b8c866f13 Mon Sep 17 00:00:00 2001 From: Thalles Silva Date: Thu, 21 Jan 2021 06:15:50 -0300 Subject: [PATCH] fix loss function labels --- simclr.py | 31 +++++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/simclr.py b/simclr.py index 2fa78b4..5bd0e41 100644 --- a/simclr.py +++ b/simclr.py @@ -34,19 +34,38 @@ def __init__(self, *args, **kwargs): self.criterion = torch.nn.CrossEntropyLoss().to(self.args.device) def info_nce_loss(self, features): - batch_targets = torch.arange(self.args.batch_size, dtype=torch.long).to(self.args.device) - batch_targets = torch.cat(self.args.n_views * [batch_targets]) + + labels = torch.cat([torch.arange(self.args.batch_size) for i in range(self.args.n_views)], dim=0) + labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float() + labels = labels.to(self.args.device) features = F.normalize(features, dim=1) similarity_matrix = torch.matmul(features, features.T) # assert similarity_matrix.shape == ( # self.args.n_views * self.args.batch_size, self.args.n_views * self.args.batch_size) + # assert similarity_matrix.shape == labels.shape + + # discard the main diagonal from both: labels and similarities matrix + mask = torch.eye(labels.shape[0], dtype=torch.bool).to(self.args.device) + labels = labels[~mask].view(labels.shape[0], -1) + similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1) + # assert similarity_matrix.shape == labels.shape + + # select and combine multiple positives + positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1) + + # if there is more than one positive (n_views > 2) combine the multiple positives + positives = positives.mean(dim=1).unsqueeze(1) + + # select only the negatives + negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1) + + logits = torch.cat([positives, negatives], dim=1) + labels = torch.zeros(logits.shape[0]).to(self.args.device) - mask = 
torch.eye(len(batch_targets)).to(self.args.device) - similarities = similarity_matrix[~mask.bool()].view(similarity_matrix.shape[0], -1) - similarities = similarities / self.args.temperature - return similarities, batch_targets + logits = logits / self.args.temperature + return logits, labels def train(self, train_loader):