Skip to content

Commit

Permalink
fix loss function labels
Browse files Browse the repository at this point in the history
  • Loading branch information
sthalles committed Jan 21, 2021
1 parent 4c056cb commit 13a7e64
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 4 deletions.
2 changes: 1 addition & 1 deletion run.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@

def main():
args = parser.parse_args()

assert args.n_views == 2, "Only two view training is supported."
# check if gpu training is available
if not args.disable_cuda and torch.cuda.is_available():
args.device = torch.device('cuda')
Expand Down
3 changes: 0 additions & 3 deletions simclr.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,6 @@ def info_nce_loss(self, features):
# select and combine multiple positives
positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)

# if there is more than one potive (n_views >= 2) combine the multiple positives
positives = positives.mean(dim=1).unsqueeze(1)

# select only the negatives the negatives
negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)

Expand Down

0 comments on commit 13a7e64

Please sign in to comment.