Merge pull request #24 from sthalles/simclr-refactor
Simclr refactor
sthalles authored Jan 18, 2021
2 parents e8a690a + e3065f3 commit 0ba2d0d
Showing 14 changed files with 919 additions and 1,286 deletions.
83 changes: 17 additions & 66 deletions README.md
@@ -16,78 +16,29 @@ $ python run.py
 
 ## Config file
 
-Before running SimCLR, make sure you choose the correct running configurations on the ```config.yaml``` file.
-
-```yaml
-
-# A batch size of N, produces 2 * (N-1) negative samples. Original implementation uses a batch size of 8192
-batch_size: 512
-
-# Number of epochs to train
-epochs: 40
-
-# Frequency to eval the similarity score using the validation set
-eval_every_n_epochs: 1
-
-# Specify a folder containing a pre-trained model to fine-tune. If training from scratch, pass None.
-fine_tune_from: 'resnet-18_80-epochs'
-
-# Frequency to which tensorboard is updated
-log_every_n_steps: 50
-
-# l2 Weight decay magnitude, original implementation uses 10e-6
-weight_decay: 10e-6
-
-# if True, training is done using mixed precision. Apex needs to be installed in this case.
-fp16_precision: False
-
-# Model related parameters
-model:
-  # Output dimensionality of the embedding vector z. Original implementation uses 2048
-  out_dim: 256
-
-  # The ConvNet base model. Choose one of: "resnet18" or "resnet50". Original implementation uses resnet50
-  base_model: "resnet18"
-
-# Dataset related parameters
-dataset:
-  s: 1
-
-  # dataset input shape. For datasets containing images of different size, this defines the final
-  input_shape: (96,96,3)
-
-  # Number of workers for the data loader
-  num_workers: 0
-
-  # Size of the validation set in percentage
-  valid_size: 0.05
-
-# NTXent loss related parameters
-loss:
-  # Temperature parameter for the contrastive objective
-  temperature: 0.5
-
-  # Distance metric for contrastive loss. If False, uses dot product. Original implementation uses cosine similarity.
-  use_cosine_similarity: True
+Before running SimCLR, make sure you choose the correct running configurations. You can change the running configurations by passing keyword arguments to the ```run.py``` file.
+
+```python
+
+$ python run.py -data ./datasets --dataset-name stl10 --log-every-n-steps 100 --epochs 100
+
 ```
+
+If you want to run it on CPU (for debugging purposes) use the ```--disable-cuda``` option.
+
+For 16-bit precision GPU training, make sure to install [NVIDIA apex](https://github.com/NVIDIA/apex) and use the ```--fp16_precision``` flag.
 
 ## Feature Evaluation
 
 Feature evaluation is done using a linear model protocol.
 
-Features are learned using the ```STL10 train+unsupervised``` set and evaluated in the ```test``` set;
+First, we learned features using SimCLR on the ```STL10 unsupervised``` set. Then, we train a linear classifier on top of the frozen features from SimCLR. The linear model is trained on features extracted from the ```STL10 train``` set and evaluated on the ```STL10 test``` set.
 
-Check the [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://github.com/sthalles/SimCLR/blob/master/feature_eval/linear_feature_eval.ipynb) notebook for reproducibility.
+Check the [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://github.com/sthalles/SimCLR/blob/simclr-refactor/feature_eval/mini_batch_logistic_regression_evaluator.ipynb) notebook for reproducibility.
 
 
-| Linear Classifier | Feature Extractor | Architecture | Feature dimensionality | Projection Head dimensionality | Epochs | STL10 Top 1 |
-|:---------------------------:|:-----------------:|:------------:|:----------------------:|:-------------------------------:|:------:|:-----------:|
-| Logistic Regression | PCA Features | - | 256 | - | | 36.0% |
-| KNN | PCA Features | - | 256 | - | | 31.8% |
-| Logistic Regression (LBFGS) | SimCLR | [ResNet-18](https://drive.google.com/open?id=1c4eVon0sUd-ChVhH6XMpF6nCngNJsAPk) | 512 | 256 | 40 | 70.3% |
-| KNN | SimCLR | ResNet-18 | 512 | 256 | 40 | 66.2% |
-| Logistic Regression (LBFGS) | SimCLR | [ResNet-18](https://drive.google.com/open?id=1L0yoeY9i2mzDcj69P4slTWb-cfr3PyoT) | 512 | 256 | 80 | 72.9% |
-| KNN | SimCLR | ResNet-18 | 512 | 256 | 80 | 69.8% |
-| Logistic Regression (Adam) | SimCLR | [ResNet-18](https://drive.google.com/open?id=1aZ12TITXnajZ6QWmS_SDm8Sp8gXNbeCQ) | 512 | 256 | 100 | 75.4% |
-| Logistic Regression (Adam) | SimCLR | [ResNet-50](https://drive.google.com/open?id=1TZqBNTFCsO-mxAiR-zJeyupY-J2gA27Q) | 2048 | 128 | 40 | 74.6% |
-| Logistic Regression (Adam) | SimCLR | [ResNet-50](https://drive.google.com/open?id=1is1wkBRccHdhSKQnPUTQoaFkVNSaCb35) | 2048 | 128 | 80 | 77.3% |
+| Linear Classification | Dataset | Feature Extractor | Architecture | Feature dimensionality | Projection Head dimensionality | Epochs | Top1 % |
+|----------------------------|---------|-------------------|---------------------------------------------------------------------------------|------------------------|--------------------------------|--------|--------|
+| Logistic Regression (Adam) | STL10 | SimCLR | [ResNet-18](https://drive.google.com/open?id=14_nH2FkyKbt61cieQDiSbBVNP8-gtwgF) | 512 | 128 | 100 | 70.45 |
+| Logistic Regression (Adam) | CIFAR10 | SimCLR | [ResNet-18](https://drive.google.com/open?id=1lc2aoVtrAetGn0PnTkOyFzPCIucOJq7C) | 512 | 128 | 100 | 64.82 |
+| Logistic Regression (Adam) | STL10 | SimCLR | [ResNet-50](https://drive.google.com/open?id=1ByTKAUsdm_X7tLcii6oAEl5qFRqRMZSu) | 2048 | 128 | 50 | 67.075 |
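
The README change moves configuration from `config.yaml` to command-line flags on `run.py`. The parser itself is not visible in this part of the diff, so the following is only a minimal sketch of how those flags could be declared with `argparse`. The flag names come from the README text; the defaults, types, `choices`, and the `build_parser` helper are assumptions made for illustration.

```python
import argparse


def build_parser():
    """Hypothetical parser; only the flag names are taken from the README."""
    parser = argparse.ArgumentParser(description="SimCLR training (illustrative sketch)")
    parser.add_argument("-data", default="./datasets",
                        help="dataset root folder (default is an assumption)")
    parser.add_argument("--dataset-name", default="stl10", choices=["stl10", "cifar10"],
                        help="dataset to train on")
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--log-every-n-steps", type=int, default=100)
    parser.add_argument("--disable-cuda", action="store_true",
                        help="run on CPU, e.g. for debugging")
    parser.add_argument("--fp16_precision", action="store_true",
                        help="enable 16-bit precision training")
    return parser


# Parse the same command line shown in the README example.
args = build_parser().parse_args(
    ["-data", "./datasets", "--dataset-name", "stl10",
     "--log-every-n-steps", "100", "--epochs", "100"]
)
print(args.dataset_name, args.epochs, args.disable_cuda)
```

Boolean switches such as `--disable-cuda` are modeled as `store_true` here so that omitting them leaves the feature off, which matches how the README describes them as opt-in.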
21 changes: 0 additions & 21 deletions config.yaml

This file was deleted.

42 changes: 42 additions & 0 deletions data_aug/contrastive_learning_dataset.py
@@ -0,0 +1,42 @@
+from torchvision.transforms import transforms
+from data_aug.gaussian_blur import GaussianBlur
+from torchvision import transforms, datasets
+from data_aug.view_generator import ContrastiveLearningViewGenerator
+from exceptions.exceptions import InvalidDatasetSelection
+
+
+class ContrastiveLearningDataset:
+    def __init__(self, root_folder):
+        self.root_folder = root_folder
+
+    @staticmethod
+    def get_simclr_pipeline_transform(size, s=1):
+        """Return a set of data augmentation transformations as described in the SimCLR paper."""
+        color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
+        data_transforms = transforms.Compose([transforms.RandomResizedCrop(size=size),
+                                              transforms.RandomHorizontalFlip(),
+                                              transforms.RandomApply([color_jitter], p=0.8),
+                                              transforms.RandomGrayscale(p=0.2),
+                                              GaussianBlur(kernel_size=int(0.1 * size)),
+                                              transforms.ToTensor()])
+        return data_transforms
+
+    def get_dataset(self, name, n_views):
+        valid_datasets = {'cifar10': lambda: datasets.CIFAR10(self.root_folder, train=True,
+                                                              transform=ContrastiveLearningViewGenerator(
+                                                                  self.get_simclr_pipeline_transform(32),
+                                                                  n_views),
+                                                              download=True),
+
+                          'stl10': lambda: datasets.STL10(self.root_folder, split='unlabeled',
+                                                          transform=ContrastiveLearningViewGenerator(
+                                                              self.get_simclr_pipeline_transform(96),
+                                                              n_views),
+                                                          download=True)}
+
+        try:
+            dataset_fn = valid_datasets[name]
+        except KeyError:
+            raise InvalidDatasetSelection()
+        else:
+            return dataset_fn()
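
`get_dataset` keeps its constructors behind `lambda`s so that only the requested dataset is instantiated (and downloaded), and it turns an unknown name into `InvalidDatasetSelection`. The pattern can be sketched without torchvision. In the sketch below, `make_dataset`, the `built` list, and the string placeholders are stand-ins for illustration, not part of the commit.

```python
class InvalidDatasetSelection(Exception):
    """Stand-in for the exception defined in exceptions/exceptions.py."""


def make_dataset(name, built):
    """Look up a dataset factory by name and call only that factory.

    `built` records which factories actually ran, to make the laziness visible.
    """
    def build(label):
        built.append(label)          # side effect marks that construction happened
        return f"<{label} dataset>"  # placeholder for a real dataset object

    # Values are zero-argument callables, so nothing is constructed here.
    valid_datasets = {
        "cifar10": lambda: build("cifar10"),
        "stl10": lambda: build("stl10"),
    }
    try:
        dataset_fn = valid_datasets[name]
    except KeyError:
        raise InvalidDatasetSelection(f"unknown dataset: {name!r}")
    else:
        return dataset_fn()


built = []
dataset = make_dataset("stl10", built)
print(dataset, built)  # only the stl10 factory has run

try:
    make_dataset("imagenet", built)
except InvalidDatasetSelection as err:
    print("rejected:", err)
```

Had the dictionary stored constructed datasets instead of callables, both entries would be built on every call, which for the real classes would mean downloading both CIFAR10 and STL10.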
68 changes: 0 additions & 68 deletions data_aug/dataset_wrapper.py

This file was deleted.

53 changes: 38 additions & 15 deletions data_aug/gaussian_blur.py
@@ -1,25 +1,48 @@
-import cv2
 import numpy as np
+import torch
+from torch import nn
+from torchvision.transforms import transforms
 
 np.random.seed(0)
 
 
 class GaussianBlur(object):
-    # Implements Gaussian blur as described in the SimCLR paper
-    def __init__(self, kernel_size, min=0.1, max=2.0):
-        self.min = min
-        self.max = max
-        # kernel size is set to be 10% of the image height/width
-        self.kernel_size = kernel_size
+    """blur a single image on CPU"""
+    def __init__(self, kernel_size):
+        radias = kernel_size // 2
+        kernel_size = radias * 2 + 1
+        self.blur_h = nn.Conv2d(3, 3, kernel_size=(kernel_size, 1),
+                                stride=1, padding=0, bias=False, groups=3)
+        self.blur_v = nn.Conv2d(3, 3, kernel_size=(1, kernel_size),
+                                stride=1, padding=0, bias=False, groups=3)
+        self.k = kernel_size
+        self.r = radias
 
-    def __call__(self, sample):
-        sample = np.array(sample)
+        self.blur = nn.Sequential(
+            nn.ReflectionPad2d(radias),
+            self.blur_h,
+            self.blur_v
+        )
 
-        # blur the image with a 50% chance
-        prob = np.random.random_sample()
+        self.pil_to_tensor = transforms.ToTensor()
+        self.tensor_to_pil = transforms.ToPILImage()
 
-        if prob < 0.5:
-            sigma = (self.max - self.min) * np.random.random_sample() + self.min
-            sample = cv2.GaussianBlur(sample, (self.kernel_size, self.kernel_size), sigma)
+    def __call__(self, img):
+        img = self.pil_to_tensor(img).unsqueeze(0)
 
-        return sample
+        sigma = np.random.uniform(0.1, 2.0)
+        x = np.arange(-self.r, self.r + 1)
+        x = np.exp(-np.power(x, 2) / (2 * sigma * sigma))
+        x = x / x.sum()
+        x = torch.from_numpy(x).view(1, -1).repeat(3, 1)
+
+        self.blur_h.weight.data.copy_(x.view(3, 1, self.k, 1))
+        self.blur_v.weight.data.copy_(x.view(3, 1, 1, self.k))
+
+        with torch.no_grad():
+            img = self.blur(img)
+            img = img.squeeze()
+
+        img = self.tensor_to_pil(img)
+
+        return img
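
The new `GaussianBlur` builds a normalized 1-D Gaussian for a randomly drawn `sigma` and applies it once along each image axis, which is equivalent to a 2-D Gaussian because the kernel is separable. The kernel arithmetic can be checked in isolation. This is a stdlib-only sketch, and `gaussian_kernel_1d` is a hypothetical helper rather than a function from the commit.

```python
import math


def gaussian_kernel_1d(kernel_size, sigma):
    """Return normalized 1-D Gaussian weights for an odd-sized kernel."""
    radius = kernel_size // 2
    kernel_size = radius * 2 + 1  # force an odd size, as the class does
    weights = [math.exp(-(x * x) / (2 * sigma * sigma))
               for x in range(-radius, radius + 1)]
    total = sum(weights)
    return [w / total for w in weights]


# For 96x96 STL10 images the pipeline passes int(0.1 * 96) as the kernel size.
k = gaussian_kernel_1d(int(0.1 * 96), sigma=1.0)

# Outer product of the 1-D kernel with itself gives the equivalent 2-D kernel.
k2d = [[a * b for b in k] for a in k]

print(len(k), round(sum(k), 6), round(sum(map(sum, k2d)), 6))
```

Because both the 1-D weights and their outer product sum to one, the blur preserves overall image brightness. Note also that the new `__call__` blurs every image, whereas the removed OpenCV version blurred with a 50% chance.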
14 changes: 14 additions & 0 deletions data_aug/view_generator.py
@@ -0,0 +1,14 @@
+import numpy as np
+
+np.random.seed(0)
+
+
+class ContrastiveLearningViewGenerator(object):
+    """Take two random crops of one image as the query and key."""
+
+    def __init__(self, base_transform, n_views=2):
+        self.base_transform = base_transform
+        self.n_views = n_views
+
+    def __call__(self, x):
+        return [self.base_transform(x) for i in range(self.n_views)]
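
Since `base_transform` is stochastic, each call in the list comprehension yields a different augmentation of the same input, which is what provides the positive pairs for contrastive training. A toy illustration follows, with a random shift standing in for the real torchvision pipeline; the `random_shift` transform is an assumption for demonstration only.

```python
import random


class ContrastiveLearningViewGenerator(object):
    """Same logic as the class in the diff: apply the transform n_views times."""

    def __init__(self, base_transform, n_views=2):
        self.base_transform = base_transform
        self.n_views = n_views

    def __call__(self, x):
        return [self.base_transform(x) for _ in range(self.n_views)]


def random_shift(values):
    """Toy stochastic transform: add one random offset to every element."""
    offset = random.uniform(-1.0, 1.0)
    return [v + offset for v in values]


random.seed(0)
generator = ContrastiveLearningViewGenerator(random_shift, n_views=2)
views = generator([1.0, 2.0, 3.0])

print(len(views))            # one entry per requested view
print(views[0] != views[1])  # the views differ because the transform is random
```

When this object is passed as a dataset's `transform`, every sample becomes a list of `n_views` tensors rather than a single tensor, so downstream code has to concatenate or iterate over the views.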
10 changes: 10 additions & 0 deletions exceptions/exceptions.py
@@ -0,0 +1,10 @@
+class BaseSimCLRException(Exception):
+    """Base exception"""
+
+
+class InvalidBackboneError(BaseSimCLRException):
+    """Raised when the choice of backbone Convnet is invalid."""
+
+
+class InvalidDatasetSelection(BaseSimCLRException):
+    """Raised when the choice of dataset is invalid."""