Skip to content

Commit

Permalink
Major refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
sthalles committed Jan 17, 2021
1 parent e8a690a commit 2c9536f
Show file tree
Hide file tree
Showing 11 changed files with 263 additions and 340 deletions.
21 changes: 0 additions & 21 deletions config.yaml

This file was deleted.

37 changes: 37 additions & 0 deletions data_aug/contrastive_learning_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from torchvision.transforms import transforms
from data_aug.gaussian_blur import GaussianBlur
from torchvision import transforms, datasets
from data_aug.view_generator import ContrastiveLearningViewGenerator


class ContrastiveLearningDataset:
def __init__(self, root_folder):
self.root_folder = root_folder

@staticmethod
def get_simclr_pipeline_transform(size, s=1):
"""Return a set of data augmentation transformations as described in the SimCLR paper."""
color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
data_transforms = transforms.Compose([transforms.RandomResizedCrop(size=size),
transforms.RandomHorizontalFlip(),
transforms.RandomApply([color_jitter], p=0.8),
transforms.RandomGrayscale(p=0.2),
GaussianBlur(kernel_size=int(0.1 * size)),
transforms.ToTensor()])
return data_transforms

def get_dataset(self, name, n_views):
valid_datasets = {'cifar10': lambda: datasets.CIFAR10(self.root_folder, train=True,
transform=ContrastiveLearningViewGenerator(
self.get_simclr_pipeline_transform(32),
n_views),
download=True),

'stl10': lambda: datasets.STL10(self.root_folder, split='unlabeled',
transform=ContrastiveLearningViewGenerator(
self.get_simclr_pipeline_transform(96),
n_views),
download=True)}

dataset = valid_datasets.get(name, 'Invalid dataset option.')()
return dataset
68 changes: 0 additions & 68 deletions data_aug/dataset_wrapper.py

This file was deleted.

53 changes: 38 additions & 15 deletions data_aug/gaussian_blur.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,48 @@
import cv2
import numpy as np
import torch
from torch import nn
from torchvision.transforms import transforms

np.random.seed(0)


class GaussianBlur(object):
# Implements Gaussian blur as described in the SimCLR paper
def __init__(self, kernel_size, min=0.1, max=2.0):
self.min = min
self.max = max
# kernel size is set to be 10% of the image height/width
self.kernel_size = kernel_size
"""blur a single image on CPU"""
def __init__(self, kernel_size):
radias = kernel_size // 2
kernel_size = radias * 2 + 1
self.blur_h = nn.Conv2d(3, 3, kernel_size=(kernel_size, 1),
stride=1, padding=0, bias=False, groups=3)
self.blur_v = nn.Conv2d(3, 3, kernel_size=(1, kernel_size),
stride=1, padding=0, bias=False, groups=3)
self.k = kernel_size
self.r = radias

def __call__(self, sample):
sample = np.array(sample)
self.blur = nn.Sequential(
nn.ReflectionPad2d(radias),
self.blur_h,
self.blur_v
)

# blur the image with a 50% chance
prob = np.random.random_sample()
self.pil_to_tensor = transforms.ToTensor()
self.tensor_to_pil = transforms.ToPILImage()

if prob < 0.5:
sigma = (self.max - self.min) * np.random.random_sample() + self.min
sample = cv2.GaussianBlur(sample, (self.kernel_size, self.kernel_size), sigma)
def __call__(self, img):
img = self.pil_to_tensor(img).unsqueeze(0)

return sample
sigma = np.random.uniform(0.1, 2.0)
x = np.arange(-self.r, self.r + 1)
x = np.exp(-np.power(x, 2) / (2 * sigma * sigma))
x = x / x.sum()
x = torch.from_numpy(x).view(1, -1).repeat(3, 1)

self.blur_h.weight.data.copy_(x.view(3, 1, self.k, 1))
self.blur_v.weight.data.copy_(x.view(3, 1, 1, self.k))

with torch.no_grad():
img = self.blur(img)
img = img.squeeze()

img = self.tensor_to_pil(img)

return img
14 changes: 14 additions & 0 deletions data_aug/view_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import numpy as np

np.random.seed(0)


class ContrastiveLearningViewGenerator(object):
"""Take two random crops of one image as the query and key."""

def __init__(self, base_transform, n_views=2):
self.base_transform = base_transform
self.n_views = n_views

def __call__(self, x):
return [self.base_transform(x) for i in range(self.n_views)]
6 changes: 6 additions & 0 deletions exceptions/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
class BaseSimCLRException(Exception):
"""Base exception"""


class InvalidBackboneError(BaseSimCLRException):
"""Raised when the choice of backbone Convnet is invalid."""
65 changes: 0 additions & 65 deletions loss/nt_xent.py

This file was deleted.

43 changes: 0 additions & 43 deletions models/baseline_encoder.py

This file was deleted.

33 changes: 13 additions & 20 deletions models/resnet_simclr.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,30 @@
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

from exceptions.exceptions import InvalidBackboneError


class ResNetSimCLR(nn.Module):

def __init__(self, base_model, out_dim):
super(ResNetSimCLR, self).__init__()
self.resnet_dict = {"resnet18": models.resnet18(pretrained=False),
"resnet50": models.resnet50(pretrained=False)}

resnet = self._get_basemodel(base_model)
num_ftrs = resnet.fc.in_features
self.resnet_dict = {"resnet18": models.resnet18(pretrained=False, num_classes=out_dim),
"resnet50": models.resnet50(pretrained=False, num_classes=out_dim)}

self.features = nn.Sequential(*list(resnet.children())[:-1])
self.backbone = self._get_basemodel(base_model)
dim_mlp = self.backbone.fc.in_features

# projection MLP
self.l1 = nn.Linear(num_ftrs, num_ftrs)
self.l2 = nn.Linear(num_ftrs, out_dim)
# add mlp projection head
self.backbone.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.backbone.fc)

def _get_basemodel(self, model_name):
try:
model = self.resnet_dict[model_name]
print("Feature extractor:", model_name)
except KeyError:
raise InvalidBackboneError(
"Invalid backbone architecture. Check the config file and pass one of: resnet18 or resnet50")
else:
return model
except:
raise ("Invalid model name. Check the config file and pass one of: resnet18 or resnet50")

def forward(self, x):
h = self.features(x)
h = h.squeeze()

x = self.l1(h)
x = F.relu(x)
x = self.l2(x)
return h, x
return self.backbone(x)
Loading

0 comments on commit 2c9536f

Please sign in to comment.