Skip to content

Commit

Permalink
Major refactor, small fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
sthalles committed Jan 18, 2021
1 parent f60e9b8 commit 651d40a
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 33 deletions.
36 changes: 3 additions & 33 deletions simclr.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import logging
import os
import shutil
import sys

import torch
import torch.nn.functional as F
import yaml
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from utils import save_config_file, accuracy, save_checkpoint

torch.manual_seed(0)

apex_support = False
Expand All @@ -22,36 +22,6 @@
apex_support = False


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
torch.save(state, filename)
if is_best:
shutil.copyfile(filename, 'model_best.pth.tar')


def _save_config_file(model_checkpoints_folder, args):
if not os.path.exists(model_checkpoints_folder):
os.makedirs(model_checkpoints_folder)
with open(os.path.join(model_checkpoints_folder, 'config.yml'), 'w') as outfile:
yaml.dump(args, outfile, default_flow_style=False)


def accuracy(output, target, topk=(1,)):
"""Computes the accuracy over the k top predictions for the specified values of k"""
with torch.no_grad():
maxk = max(topk)
batch_size = target.size(0)

_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))

res = []
for k in topk:
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res


class SimCLR(object):

def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -86,7 +56,7 @@ def train(self, train_loader):
opt_level='O2',
keep_batchnorm_fp32=True)
# save config file
_save_config_file(self.writer.log_dir, self.args)
save_config_file(self.writer.log_dir, self.args)

n_iter = 0
logging.info(f"Start SimCLR training for {self.args.epochs} epochs.")
Expand Down
35 changes: 35 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import os
import shutil

import torch
import yaml


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
torch.save(state, filename)
if is_best:
shutil.copyfile(filename, 'model_best.pth.tar')


def save_config_file(model_checkpoints_folder, args):
if not os.path.exists(model_checkpoints_folder):
os.makedirs(model_checkpoints_folder)
with open(os.path.join(model_checkpoints_folder, 'config.yml'), 'w') as outfile:
yaml.dump(args, outfile, default_flow_style=False)


def accuracy(output, target, topk=(1,)):
"""Computes the accuracy over the k top predictions for the specified values of k"""
with torch.no_grad():
maxk = max(topk)
batch_size = target.size(0)

_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))

res = []
for k in topk:
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res

0 comments on commit 651d40a

Please sign in to comment.