`stable-ssl` streamlines the training and evaluation of deep learning models by offering all the essential boilerplate code with minimal hardcoded utilities. `stable-ssl` adopts a flexible and modular design for seamless integration of components from external libraries, including architectures, loss functions, evaluation metrics, and augmentations. The provided utilities are primarily focused on Self-Supervised Learning, yet `stable-ssl` will save you time and headaches regardless of your use case.
At its core, `stable-ssl` provides a `BaseTrainer` class that manages job submission, data loading, training, evaluation, logging, monitoring, checkpointing, and requeuing, all customizable via a configuration file. This class is intended to be subclassed for specific training needs (see these trainers as examples).
`stable-ssl` uses Hydra to manage input parameters through configuration files, enabling efficient hyperparameter tuning with multirun and integration with job launchers like submitit for Slurm.
The first step is to specify a trainer class, which must be a subclass of `BaseTrainer`. Optionally, the trainer may require a loss function, which is then used in the `compute_loss` method of the trainer (a minimal subclass sketch is shown below).
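For orientation, here is a minimal sketch of what such a subclass could look like. It is an illustration under assumptions, not the library's exact API: the top-level import `stable_ssl.BaseTrainer` and the attribute names `self.module`, `self.batch`, and `self.loss` are assumed; check the bundled trainers such as `JointEmbeddingTrainer` for the real interface.

```python
import stable_ssl


class MyTrainer(stable_ssl.BaseTrainer):
    """Minimal custom-trainer sketch (attribute names below are assumptions)."""

    def forward(self, x):
        # Chain the networks declared under the `module:` section of the config;
        # accessing them through a dict-like `self.module` is an assumption.
        return self.module["projector"](self.module["backbone"](x))

    def compute_loss(self):
        # Assumed batch layout: (list_of_augmented_views, labels),
        # matching what `stable_ssl.data.MultiViewSampler` produces.
        views, _ = self.batch
        z1 = self.forward(views[0])
        z2 = self.forward(views[1])
        # `self.loss` would be the object built from the `loss:` entry of the
        # config, e.g. `stable_ssl.NTXEntLoss`.
        return self.loss(z1, z2)
```

In practice, you would then point the `_target_` of the `trainer:` section of your configuration at this class (e.g. `_target_: my_package.MyTrainer`).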
The trainer parameters are then structured according to the following categories:
| Category | Description |
|---|---|
| `data` | Defines the dataset, loading, and augmentation pipelines. The `train` dataset is used for training; if it is absent, the model runs in evaluation mode. Its structure is fully flexible. |
| `module` | Specifies the neural network modules and their architecture. Its structure is fully flexible. |
| `optim` | Defines the optimization components, including the optimizer, scheduler, and the number of epochs. See the default parameters in `OptimConfig`. |
| `hardware` | Specifies the hardware configuration, including the number of GPUs, CPUs, and precision settings. See the default parameters in `HardwareConfig`. |
| `logger` | Configures model performance monitoring. APIs like WandB are supported. See the default parameters in `LoggerConfig`. |
Config Example: SimCLR CIFAR10
```yaml
trainer:
  # ===== Base Trainer =====
  _target_: stable_ssl.JointEmbeddingTrainer

  # ===== Loss Parameters =====
  loss:
    _target_: stable_ssl.NTXEntLoss
    temperature: 0.5

  # ===== Module Parameters =====
  module:
    backbone:
      _target_: stable_ssl.modules.load_backbone
      name: resnet50
      low_resolution: True
      num_classes: null
    projector:
      _target_: stable_ssl.modules.MLP
      sizes: [2048, 2048, 128]
    projector_classifier:
      _target_: torch.nn.Linear
      in_features: 128
      out_features: ${trainer.data._num_classes}
    backbone_classifier:
      _target_: torch.nn.Linear
      in_features: 2048
      out_features: ${trainer.data._num_classes}

  # ===== Optim Parameters =====
  optim:
    epochs: 1000
    optimizer:
      _target_: stable_ssl.optimizers.LARS
      _partial_: True
      lr: 5
      weight_decay: 1e-6
    scheduler:
      _target_: stable_ssl.schedulers.LinearWarmupCosineAnnealing
      _partial_: True
      total_steps: ${eval:'${trainer.optim.epochs} * ${trainer.data._num_train_samples} // ${trainer.data.train.batch_size}'}

  # ===== Data Parameters =====
  data:
    _num_classes: 10
    _num_train_samples: 50000
    train: # training dataset as indicated by name 'train'
      _target_: torch.utils.data.DataLoader
      batch_size: 256
      drop_last: True
      shuffle: True
      num_workers: 6
      dataset:
        _target_: torchvision.datasets.CIFAR10
        root: ~/data
        train: True
        download: True
        transform:
          _target_: stable_ssl.data.MultiViewSampler
          transforms:
            # === First View ===
            - _target_: torchvision.transforms.v2.Compose
              transforms:
                - _target_: torchvision.transforms.v2.RandomResizedCrop
                  size: 32
                  scale:
                    - 0.2
                    - 1.0
                - _target_: torchvision.transforms.v2.RandomHorizontalFlip
                  p: 0.5
                - _target_: torchvision.transforms.v2.RandomApply
                  p: 0.8
                  transforms:
                    - {
                        _target_: torchvision.transforms.v2.ColorJitter,
                        brightness: 0.4,
                        contrast: 0.4,
                        saturation: 0.2,
                        hue: 0.1,
                      }
                - _target_: torchvision.transforms.v2.RandomGrayscale
                  p: 0.2
                - _target_: torchvision.transforms.v2.ToImage
                - _target_: torchvision.transforms.v2.ToDtype
                  dtype:
                    _target_: stable_ssl.utils.str_to_dtype
                    _args_: [float32]
                  scale: True
            # === Second View ===
            - _target_: torchvision.transforms.v2.Compose
              transforms:
                - _target_: torchvision.transforms.v2.RandomResizedCrop
                  size: 32
                  scale:
                    - 0.2
                    - 1.0
                - _target_: torchvision.transforms.v2.RandomHorizontalFlip
                  p: 0.5
                - _target_: torchvision.transforms.v2.RandomApply
                  p: 0.8
                  transforms:
                    - {
                        _target_: torchvision.transforms.v2.ColorJitter,
                        brightness: 0.4,
                        contrast: 0.4,
                        saturation: 0.2,
                        hue: 0.1,
                      }
                - _target_: torchvision.transforms.v2.RandomGrayscale
                  p: 0.2
                - _target_: torchvision.transforms.v2.RandomSolarize
                  threshold: 128
                  p: 0.2
                - _target_: torchvision.transforms.v2.ToImage
                - _target_: torchvision.transforms.v2.ToDtype
                  dtype:
                    _target_: stable_ssl.utils.str_to_dtype
                    _args_: [float32]
                  scale: True
    test: # can be any name
      _target_: torch.utils.data.DataLoader
      batch_size: 256
      num_workers: ${trainer.data.train.num_workers}
      dataset:
        _target_: torchvision.datasets.CIFAR10
        train: False
        root: ~/data
        transform:
          _target_: torchvision.transforms.v2.Compose
          transforms:
            - _target_: torchvision.transforms.v2.ToImage
            - _target_: torchvision.transforms.v2.ToDtype
              dtype:
                _target_: stable_ssl.utils.str_to_dtype
                _args_: [float32]
              scale: True

  # ===== Logger Parameters =====
  logger:
    eval_every_epoch: 10
    log_every_step: 100
    wandb: True
    metric:
      test:
        acc1:
          _target_: torchmetrics.classification.MulticlassAccuracy
          num_classes: ${trainer.data._num_classes}
          top_k: 1
        acc5:
          _target_: torchmetrics.classification.MulticlassAccuracy
          num_classes: ${trainer.data._num_classes}
          top_k: 5

  # ===== Hardware Parameters =====
  hardware:
    seed: 0
    float16: true
    device: "cuda:0"
    world_size: 1
```
To launch a run using a configuration file located in a specified folder, simply use the following command:

```bash
stable-ssl --config-path <config_path> --config-name <config_name>
```

Replace `<config_path>` with the path to your configuration folder and `<config_name>` with the name of your configuration file.
Useful options include:
- Launching in multirun mode (here, sweeping over the training batch size):

  ```bash
  stable-ssl --multirun --config-path <config_path> --config-name <config_name> ++trainer.data.train.batch_size=128,256,512
  ```

- Launching on Slurm:

  ```bash
  stable-ssl --multirun --config-path <config_path> --config-name <config_name> hydra/launcher=submitit_slurm
  ```
Note: one must include the `--multirun` flag when using a launcher like `submitit_slurm`.
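As an alternative to the `stable-ssl` CLI, the same configuration can be driven from a short Python entry point. The sketch below is an assumption-laden illustration: `hydra.main` and `hydra.utils.instantiate` are standard Hydra APIs, while the `setup()` and `launch()` calls on the trainer are assumed method names; check the `BaseTrainer` documentation of your version for the exact entry points.

```python
# run.py -- rough Python equivalent of
# `stable-ssl --config-path config --config-name simclr_cifar10`.
import hydra
from omegaconf import DictConfig


@hydra.main(version_base=None, config_path="config", config_name="simclr_cifar10")
def main(cfg: DictConfig) -> None:
    # Build the trainer from the `trainer:` node of the config. Whether the
    # trainer expects raw sub-configs (`_recursive_=False`) or fully
    # instantiated children depends on the library version.
    trainer = hydra.utils.instantiate(cfg.trainer, _convert_="object", _recursive_=False)
    trainer.setup()   # assumed method name
    trainer.launch()  # assumed method name


if __name__ == "__main__":
    main()
```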
The library is not yet available on PyPI, so you need to install it from source. From a local clone of the repository, run:

```bash
pip install -e .
```

Alternatively, install directly from GitHub:

```bash
pip install -U git+https://github.com/rbalestr-lab/stable-ssl
```
- If you'd like to contribute new features, bug fixes, or improvements to the documentation, please refer to our contributing guide for detailed instructions on how to get started.
- You can also contribute by adding new methods, datasets, or configurations that improve the current performance of a method in the benchmark section.