Skip to content

Commit

Permalink
Fixed loss and added pytorch built in mixed precision
Browse files Browse the repository at this point in the history
  • Loading branch information
sthalles committed Feb 11, 2021
1 parent 7c06b6e commit 727cbae
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 27 deletions.
2 changes: 1 addition & 1 deletion run.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
help='seed for initializing training. ')
parser.add_argument('--disable-cuda', action='store_true',
help='Disable CUDA')
parser.add_argument('--fp16_precision', default=False, type=bool,
parser.add_argument('--fp16-precision', action='store_true',
help='Whether or not to use 16-bit precision GPU training.')

parser.add_argument('--out_dim', default=128, type=int,
Expand Down
38 changes: 12 additions & 26 deletions simclr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,13 @@

import torch
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from utils import save_config_file, accuracy, save_checkpoint

torch.manual_seed(0)

apex_support = False
try:
sys.path.append('./apex')
from apex import amp

apex_support = True
except:
print("Please install apex for mixed precision training from: https://github.com/NVIDIA/apex")
apex_support = False


class SimCLR(object):

Expand Down Expand Up @@ -59,18 +49,15 @@ def info_nce_loss(self, features):
negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)

logits = torch.cat([positives, negatives], dim=1)
labels = torch.zeros(logits.shape[0]).to(self.args.device)
labels = torch.zeros(logits.shape[0], dtype=torch.long).to(self.args.device)

logits = logits / self.args.temperature
return logits, labels

def train(self, train_loader):

if apex_support and self.args.fp16_precision:
logging.debug("Using apex for fp16 precision training.")
self.model, self.optimizer = amp.initialize(self.model, self.optimizer,
opt_level='O2',
keep_batchnorm_fp32=True)
scaler = GradScaler(enabled=self.args.fp16_precision)

# save config file
save_config_file(self.writer.log_dir, self.args)

Expand All @@ -84,18 +71,17 @@ def train(self, train_loader):

images = images.to(self.args.device)

features = self.model(images)
logits, labels = self.info_nce_loss(features)
loss = self.criterion(logits, labels)
with autocast(enabled=self.args.fp16_precision):
features = self.model(images)
logits, labels = self.info_nce_loss(features)
loss = self.criterion(logits, labels)

self.optimizer.zero_grad()
if apex_support and self.args.fp16_precision:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()

self.optimizer.step()
scaler.scale(loss).backward()

scaler.step(self.optimizer)
scaler.update()

if n_iter % self.args.log_every_n_steps == 0:
top1, top5 = accuracy(logits, labels, topk=(1, 5))
Expand Down

0 comments on commit 727cbae

Please sign in to comment.