Skip to content

Commit

Permalink
Major refactor, small fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
sthalles committed Jan 17, 2021
1 parent 2c9536f commit d0112ed
Show file tree
Hide file tree
Showing 4 changed files with 473 additions and 449 deletions.
9 changes: 7 additions & 2 deletions data_aug/contrastive_learning_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from data_aug.gaussian_blur import GaussianBlur
from torchvision import transforms, datasets
from data_aug.view_generator import ContrastiveLearningViewGenerator
from exceptions.exceptions import InvalidDatasetSelection


class ContrastiveLearningDataset:
Expand Down Expand Up @@ -33,5 +34,9 @@ def get_dataset(self, name, n_views):
n_views),
download=True)}

dataset = valid_datasets.get(name, 'Invalid dataset option.')()
return dataset
try:
dataset_fn = valid_datasets[name]
except KeyError:
raise InvalidDatasetSelection()
else:
return dataset_fn()
6 changes: 5 additions & 1 deletion exceptions/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,8 @@ class BaseSimCLRException(Exception):


class InvalidBackboneError(BaseSimCLRException):
"""Raised when the choice of backbone Convnet is invalid."""
"""Raised when the choice of backbone Convnet is invalid."""


class InvalidDatasetSelection(BaseSimCLRException):
"""Raised when the choice of dataset is invalid."""
Loading

0 comments on commit d0112ed

Please sign in to comment.