diff --git a/.github/workflows/build-docs.yml b/.github/workflows/build-docs.yml index d8db9c81..f28dcb77 100644 --- a/.github/workflows/build-docs.yml +++ b/.github/workflows/build-docs.yml @@ -30,7 +30,7 @@ jobs: run: | echo "PYTHON_VERSION=$(python -c "import platform; print(platform.python_version())")" - - name: Cache folder for Torch Uncertainty + - name: Cache folder for TorchUncertainty uses: actions/cache@v3 id: cache-folder with: diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index b3872575..c404a106 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -52,7 +52,7 @@ jobs: LICENSE .gitignore - - name: Cache folder for Torch Uncertainty + - name: Cache folder for TorchUncertainty if: steps.changed-files-specific.outputs.only_changed != 'true' uses: actions/cache@v4 id: cache-folder @@ -70,8 +70,8 @@ jobs: - name: Check style & format if: steps.changed-files-specific.outputs.only_changed != 'true' run: | - python3 -m ruff check torch_uncertainty tests --no-fix - python3 -m ruff format torch_uncertainty tests --check + python3 -m ruff check torch_uncertainty --no-fix + python3 -m ruff format torch_uncertainty --check - name: Test with pytest and compute coverage if: steps.changed-files-specific.outputs.only_changed != 'true' diff --git a/.gitignore b/.gitignore index 6ed77954..4aa0c186 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,8 @@ docs/*/auto_tutorials/ *.pth *.ckpt *.out +docs/source/sg_execution_times.rst +test**/*.csv # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/README.md b/README.md index 45472d62..b90edf6f 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@
-![Torch Uncertainty Logo](https://github.com/ENSTA-U2IS-AI/torch-uncertainty/blob/main/docs/source/_static/images/torch_uncertainty.png) +![TorchUncertaintyLogo](https://github.com/ENSTA-U2IS-AI/torch-uncertainty/blob/main/docs/source/_static/images/torch_uncertainty.png) [![pypi](https://img.shields.io/pypi/v/torch_uncertainty.svg)](https://pypi.python.org/pypi/torch_uncertainty) [![tests](https://github.com/ENSTA-U2IS-AI/torch-uncertainty/actions/workflows/run-tests.yml/badge.svg?branch=main&event=push)](https://github.com/ENSTA-U2IS-AI/torch-uncertainty/actions/workflows/run-tests.yml) @@ -11,40 +11,42 @@ [![Discord Badge](https://dcbadge.vercel.app/api/server/HMCawt5MJu?compact=true&style=flat)](https://discord.gg/HMCawt5MJu)
-_TorchUncertainty_ is a package designed to help you leverage uncertainty quantification techniques and make your deep neural networks more reliable. It aims at being collaborative and including as many methods as possible, so reach out to add yours! +_TorchUncertainty_ is a package designed to help you leverage [uncertainty quantification techniques](https://github.com/ENSTA-U2IS-AI/awesome-uncertainty-deeplearning) and make your deep neural networks more reliable. It aims at being collaborative and including as many methods as possible, so reach out to add yours! :construction: _TorchUncertainty_ is in early development :construction: - expect changes, but reach out and contribute if you are interested in the project! **Please raise an issue if you have any bugs or difficulties and join the [discord server](https://discord.gg/HMCawt5MJu).** +Our webpage and documentation is available here: [torch-uncertainty.github.io](https://torch-uncertainty.github.io). + --- This package provides a multi-level API, including: +- easy-to-use ⚡️ lightning **uncertainty-aware** training & evaluation routines for **4 tasks**: classification, probabilistic and pointwise regression, and segmentation. - ready-to-train baselines on research datasets, such as ImageNet and CIFAR -- deep learning baselines available for training on your datasets - [pretrained weights](https://huggingface.co/torch-uncertainty) for these baselines on ImageNet and CIFAR (work in progress 🚧). -- layers available for use in your networks -- scikit-learn style post-processing methods such as Temperature Scaling +- **layers**, **models**, **metrics**, & **losses** available for use in your networks +- scikit-learn style post-processing methods such as Temperature Scaling. -See the [Reference page](https://torch-uncertainty.github.io/references.html) or the [API reference](https://torch-uncertainty.github.io/api.html) for a more exhaustive list of the implemented methods, datasets, metrics, etc. +Have a look at the [Reference page](https://torch-uncertainty.github.io/references.html) or the [API reference](https://torch-uncertainty.github.io/api.html) for a more exhaustive list of the implemented methods, datasets, metrics, etc. -## Installation +## ⚙️ Installation -Install the desired PyTorch version in your environment. +TorchUncertainty requires Python 3.10 or greater. Install the desired PyTorch version in your environment. Then, install the package from PyPI: ```sh pip install torch-uncertainty ``` -If you aim to contribute, have a look at the [contribution page](https://torch-uncertainty.github.io/contributing.html). +The installation procedure for contributors is different: have a look at the [contribution page](https://torch-uncertainty.github.io/contributing.html). -## Getting Started and Documentation +## :racehorse: Quickstart -Please find the documentation at [torch-uncertainty.github.io](https://torch-uncertainty.github.io). +We make a quickstart available at [torch-uncertainty.github.io/quickstart](https://torch-uncertainty.github.io/quickstart.html). -A quickstart is available at [torch-uncertainty.github.io/quickstart](https://torch-uncertainty.github.io/quickstart.html). +## :books: Implemented methods -## Implemented methods +TorchUncertainty currently supports **Classification**, **probabilistic** and pointwise **Regression** and **Segmentation**. ### Baselines @@ -55,7 +57,7 @@ To date, the following deep learning baselines have been implemented: - BatchEnsemble - Masksembles - MIMO -- Packed-Ensembles (see [blog post](https://medium.com/@adrien.lafage/make-your-neural-networks-more-reliable-with-packed-ensembles-7ad0b737a873)) - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_pe_cifar10.html) +- Packed-Ensembles (see [Blog post](https://medium.com/@adrien.lafage/make-your-neural-networks-more-reliable-with-packed-ensembles-7ad0b737a873)) - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_pe_cifar10.html) - Bayesian Neural Networks :construction: Work in progress :construction: - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_bayesian.html) - Regression with Beta Gaussian NLL Loss - Deep Evidential Classification & Regression - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_evidential_classification.html) @@ -75,7 +77,7 @@ To date, the following post-processing methods have been implemented: ## Tutorials -We provide the following tutorials in our documentation: +Our documentation contains the following tutorials: - [From a Standard Classifier to a Packed-Ensemble](https://torch-uncertainty.github.io/auto_tutorials/tutorial_pe_cifar10.html) - [Training a Bayesian Neural Network in 3 minutes](https://torch-uncertainty.github.io/auto_tutorials/tutorial_bayesian.html) @@ -84,10 +86,6 @@ We provide the following tutorials in our documentation: - [Training a LeNet with Monte-Carlo Dropout](https://torch-uncertainty.github.io/auto_tutorials/tutorial_mc_dropout.html) - [Training a LeNet with Deep Evidential Classification](https://torch-uncertainty.github.io/auto_tutorials/tutorial_evidential_classification.html) -## Awesome Uncertainty repositories - -You may find a lot of papers about modern uncertainty estimation techniques on the [Awesome Uncertainty in Deep Learning](https://github.com/ENSTA-U2IS-AI/awesome-uncertainty-deeplearning). - ## Other References This package also contains the official implementation of Packed-Ensembles. diff --git a/auto_tutorials_source/tutorial_bayesian.py b/auto_tutorials_source/tutorial_bayesian.py index 64c8fce8..04f1202e 100644 --- a/auto_tutorials_source/tutorial_bayesian.py +++ b/auto_tutorials_source/tutorial_bayesian.py @@ -2,58 +2,56 @@ Train a Bayesian Neural Network in Three Minutes ================================================ -In this tutorial, we will train a Bayesian Neural Network (BNN) LeNet classifier on the MNIST dataset. +In this tutorial, we will train a variational inference Bayesian Neural Network (BNN) LeNet classifier on the MNIST dataset. Foreword on Bayesian Neural Networks ------------------------------------ -Bayesian Neural Networks (BNNs) are a class of neural networks that can estimate the uncertainty of their predictions via uncertainty on their weights. This is achieved by considering the weights of the neural network as random variables, and by learning their posterior distribution. This is in contrast to standard neural networks, which only learn a single set of weights, which can be seen as Dirac distributions on the weights. +Bayesian Neural Networks (BNNs) are a class of neural networks that estimate the uncertainty on their predictions via uncertainty +on their weights. This is achieved by considering the weights of the neural network as random variables, and by learning their +posterior distribution. This is in contrast to standard neural networks, which only learn a single set of weights, which can be +seen as Dirac distributions on the weights. For more information on Bayesian Neural Networks, we refer the reader to the following resources: - Weight Uncertainty in Neural Networks `ICML2015 `_ - Hands-on Bayesian Neural Networks - a Tutorial for Deep Learning Users `IEEE Computational Intelligence Magazine `_ -Training a Bayesian LeNet using TorchUncertainty models and PyTorch Lightning ------------------------------------------------------------------------------ +Training a Bayesian LeNet using TorchUncertainty models and Lightning +--------------------------------------------------------------------- In this part, we train a bayesian LeNet, based on the model and routines already implemented in TU. 1. Loading the utilities ~~~~~~~~~~~~~~~~~~~~~~~~ -To train a BNN using TorchUncertainty, we have to load the following utilities from TorchUncertainty: +To train a BNN using TorchUncertainty, we have to load the following modules: -- the cli handler: cli_main and argument parser: init_args -- the model: bayesian_lenet, which lies in the torch_uncertainty.model module -- the classification training routine in the torch_uncertainty.training.classification module +- the Trainer from Lightning +- the model: bayesian_lenet, which lies in the torch_uncertainty.model +- the classification training routine from torch_uncertainty.routines - the bayesian objective: the ELBOLoss, which lies in the torch_uncertainty.losses file -- the datamodule that handles dataloaders: MNISTDataModule, which lies in the torch_uncertainty.datamodule -""" +- the datamodule that handles dataloaders: MNISTDataModule from torch_uncertainty.datamodules -from torch_uncertainty import cli_main, init_args -from torch_uncertainty.datamodules import MNISTDataModule -from torch_uncertainty.losses import ELBOLoss -from torch_uncertainty.models.lenet import bayesian_lenet -from torch_uncertainty.routines.classification import ClassificationSingle +We will also need to define an optimizer using torch.optim, the +neural network utils from torch.nn, as well as the partial util to provide +the modified default arguments for the ELBO loss. +""" # %% -# We will also need to define an optimizer using torch.optim as well as the -# neural network utils withing torch.nn, as well as the partial util to provide -# the modified default arguments for the ELBO loss. -# -# We also import sys to override the command line arguments. - -import os -from functools import partial from pathlib import Path -import sys +from lightning.pytorch import Trainer from torch import nn, optim +from torch_uncertainty.datamodules import MNISTDataModule +from torch_uncertainty.losses import ELBOLoss +from torch_uncertainty.models.lenet import bayesian_lenet +from torch_uncertainty.routines import ClassificationRoutine + # %% -# 2. Creating the Optimizer Wrapper -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# 2. The Optimization Recipe +# ~~~~~~~~~~~~~~~~~~~~~~~~~~ # We will use the Adam optimizer with the default learning rate of 0.001. @@ -69,26 +67,19 @@ def optim_lenet(model: nn.Module) -> dict: # 3. Creating the necessary variables # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -# In the following, we will need to define the root of the datasets and the -# logs, and to fake-parse the arguments needed for using the PyTorch Lightning -# Trainer. We also create the datamodule that handles the MNIST dataset, -# dataloaders and transforms. Finally, we create the model using the -# blueprint from torch_uncertainty.models. - -root = Path(os.path.abspath("")) +# In the following, we define the Lightning trainer, the root of the datasets and the logs. +# We also create the datamodule that handles the MNIST dataset, dataloaders and transforms. +# Please note that the datamodules can also handle OOD detection by setting the eval_ood +# parameter to True. Finally, we create the model using the blueprint from torch_uncertainty.models. -# We mock the arguments for the trainer -sys.argv = ["file.py", "--max_epochs", "1", "--enable_progress_bar", "False"] -args = init_args(datamodule=MNISTDataModule) - -net_name = "logs/bayesian-lenet-mnist" +trainer = Trainer(accelerator="cpu", enable_progress_bar=False, max_epochs=1) # datamodule -args.root = str(root / "data") -dm = MNISTDataModule(**vars(args)) +root = Path("") / "data" +datamodule = MNISTDataModule(root=root, batch_size=128, eval_ood=False) # model -model = bayesian_lenet(dm.num_channels, dm.num_classes) +model = bayesian_lenet(datamodule.num_channels, datamodule.num_classes) # %% # 4. The Loss and the Training Routine @@ -99,24 +90,21 @@ def optim_lenet(model: nn.Module) -> dict: # library. As we are train a classification model, we use the CrossEntropyLoss # as the likelihood. # We then define the training routine using the classification training routine -# from torch_uncertainty.training.classification. We provide the model, the ELBO -# loss and the optimizer, as well as all the default arguments. +# from torch_uncertainty.classification. We provide the model, the ELBO +# loss and the optimizer to the routine. -loss = partial( - ELBOLoss, +loss = ELBOLoss( model=model, - criterion=nn.CrossEntropyLoss(), + inner_loss=nn.CrossEntropyLoss(), kl_weight=1 / 50000, num_samples=3, ) -baseline = ClassificationSingle( +routine = ClassificationRoutine( model=model, - num_classes=dm.num_classes, - in_channels=dm.num_channels, + num_classes=datamodule.num_classes, loss=loss, - optimization_procedure=optim_lenet, - **vars(args), + optim_recipe=optim_lenet(model), ) # %% @@ -124,14 +112,14 @@ def optim_lenet(model: nn.Module) -> dict: # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # Now that we have prepared all of this, we just have to gather everything in -# the main function and to train the model using the PyTorch Lightning Trainer. -# Specifically, it needs the baseline, that includes the model as well as the -# training routine, the datamodule, the root for the datasets and the logs, the -# name of the model for the logs and all the training arguments. +# the main function and to train the model using the Lightning Trainer. +# Specifically, it needs the routine, that includes the model as well as the +# training/eval logic and the datamodule # The dataset will be downloaded automatically in the root/data folder, and the # logs will be saved in the root/logs folder. -results = cli_main(baseline, dm, root, net_name, args) +trainer.fit(model=routine, datamodule=datamodule) +trainer.test(model=routine, datamodule=datamodule) # %% # 6. Testing the Model @@ -140,19 +128,20 @@ def optim_lenet(model: nn.Module) -> dict: # Now that the model is trained, let's test it on MNIST import matplotlib.pyplot as plt +import numpy as np import torch import torchvision -import numpy as np - def imshow(img): npimg = img.numpy() plt.imshow(np.transpose(npimg, (1, 2, 0))) + plt.axis("off") + plt.tight_layout() plt.show() -dataiter = iter(dm.val_dataloader()) +dataiter = iter(datamodule.val_dataloader()) images, labels = next(dataiter) # print images diff --git a/auto_tutorials_source/tutorial_corruptions.py b/auto_tutorials_source/tutorial_corruptions.py index 709fa0a7..6713834f 100644 --- a/auto_tutorials_source/tutorial_corruptions.py +++ b/auto_tutorials_source/tutorial_corruptions.py @@ -105,7 +105,7 @@ def show_images(transform): #%% # 10. Frost -# ~~~~~~~~ +# ~~~~~~~~~ from torch_uncertainty.transforms.corruptions import Frost show_images(Frost) diff --git a/auto_tutorials_source/tutorial_der_cubic.py b/auto_tutorials_source/tutorial_der_cubic.py index 3470689b..b77a0a4d 100644 --- a/auto_tutorials_source/tutorial_der_cubic.py +++ b/auto_tutorials_source/tutorial_der_cubic.py @@ -2,9 +2,14 @@ Deep Evidential Regression on a Toy Example =========================================== -This tutorial aims to provide an introductory overview of Deep Evidential Regression (DER) using a practical example. We demonstrate an application of DER by tackling the toy-problem of fitting :math:`y=x^3` using a Multi-Layer Perceptron (MLP) neural network model. The output layer of the MLP has four outputs, and is trained by minimizing the Normal Inverse-Gamma (NIG) loss function. +This tutorial provides an introduction to probabilistic regression in TorchUncertainty. -DER represents an evidential approach to quantifying uncertainty in neural network regression models. This method involves introducing prior distributions over the parameters of the Gaussian likelihood function. Then, the MLP model estimate the parameters of the evidential distribution. +More specifically, we present Deep Evidential Regression (DER) using a practical example. We demonstrate an application of DER by tackling the toy-problem of fitting :math:`y=x^3` using a Multi-Layer Perceptron (MLP) neural network model. +The output layer of the MLP provides a NormalInverseGamma distribution which is used to optimize the model, through its negative log-likelihood. + +DER represents an evidential approach to quantifying epistemic and aleatoric uncertainty in neural network regression models. +This method involves introducing prior distributions over the parameters of the Gaussian likelihood function. +Then, the MLP model estimates the parameters of this evidential distribution. Training a MLP with DER using TorchUncertainty models and PyTorch Lightning --------------------------------------------------------------------------- @@ -14,42 +19,32 @@ 1. Loading the utilities ~~~~~~~~~~~~~~~~~~~~~~~~ -To train a MLP with the NIG loss function using TorchUncertainty, we have to load the following utilities from TorchUncertainty: +To train a MLP with the DER loss function using TorchUncertainty, we have to load the following modules: -- the cli handler: cli_main and argument parser: init_args -- the model: mlp, which lies in the torch_uncertainty.baselines.regression.mlp module. -- the regression training routine in the torch_uncertainty.routines.regression module. -- the evidential objective: the NIGLoss, which lies in the torch_uncertainty.losses file -- a dataset that generates samples from a noisy cubic function: Cubic, which lies in the torch_uncertainty.datasets.regression -""" - -from pytorch_lightning import LightningDataModule -from torch_uncertainty import cli_main, init_args -from torch_uncertainty.baselines.regression.mlp import mlp -from torch_uncertainty.datasets.regression.toy import Cubic -from torch_uncertainty.losses import NIGLoss -from torch_uncertainty.routines.regression import RegressionSingle +- the Trainer from Lightning +- the model: mlp from torch_uncertainty.models.mlp +- the regression training routine from torch_uncertainty.routines +- the evidential objective: the DERLoss from torch_uncertainty.losses. This loss contains the classic NLL loss and a regularization term. +- a dataset that generates samples from a noisy cubic function: Cubic from torch_uncertainty.datasets.regression +We also need to define an optimizer using torch.optim and the neural network utils within torch.nn. +""" # %% -# We also need to define an optimizer using torch.optim as well as the -# neural network utils withing torch.nn, as well as the partial util to provide -# the modified default arguments for the NIG loss. -# -# We also import sys to override the command line arguments. - -import os -import sys -from functools import partial -from pathlib import Path - import torch +from lightning.pytorch import Trainer +from lightning import LightningDataModule from torch import nn, optim -# %% -# 2. Creating the Optimizer Wrapper -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -# We use the Adam optimizer with the default learning rate of 0.001. +from torch_uncertainty.models.mlp import mlp +from torch_uncertainty.datasets.regression.toy import Cubic +from torch_uncertainty.losses import DERLoss +from torch_uncertainty.routines import RegressionRoutine +from torch_uncertainty.layers.distributions import NormalInverseGammaLayer +# %% +# 2. The Optimization Recipe +# ~~~~~~~~~~~~~~~~~~~~~~~~~~ +# We use the Adam optimizer with a rate of 5e-4. def optim_regression( model: nn.Module, @@ -69,85 +64,74 @@ def optim_regression( # 3. Creating the necessary variables # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -# In the following, we need to define the root of the logs, and to -# fake-parse the arguments needed for using the PyTorch Lightning Trainer. We -# also use the same synthetic regression task example as that used in the -# original DER paper. - -root = Path(os.path.abspath("")) - -# We mock the arguments for the trainer -sys.argv = ["file.py", "--max_epochs", "50", "--enable_progress_bar", "False"] -args = init_args() +# In the following, we create a trainer to train the model, the same synthetic regression +# datasets as in the original DER paper and the model, a simple MLP with 2 hidden layers of 64 neurons each. +# Please note that this MLP finishes with a NormalInverseGammaLayer that interpret the outputs of the model +# as the parameters of a Normal Inverse Gamma distribution. -net_name = "logs/der-mlp-cubic" +trainer = Trainer(accelerator="cpu", max_epochs=50)#, enable_progress_bar=False) # dataset train_ds = Cubic(num_samples=1000) val_ds = Cubic(num_samples=300) -test_ds = train_ds # datamodule datamodule = LightningDataModule.from_datasets( - train_ds, val_dataset=val_ds, test_dataset=test_ds, batch_size=32 + train_ds, val_dataset=val_ds, test_dataset=val_ds, batch_size=32 ) datamodule.training_task = "regression" # model -model = mlp(in_features=1, num_outputs=4, hidden_dims=[64, 64]) +model = mlp( + in_features=1, + num_outputs=4, + hidden_dims=[64, 64], + final_layer=NormalInverseGammaLayer, + final_layer_args={"dim": 1}, +) # %% # 4. The Loss and the Training Routine # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Next, we need to define the loss to be used during training. To do this, we -# redefine the default parameters for the NIG loss using the partial -# function from functools. After that, we define the training routine using -# the regression training routine from torch_uncertainty.routines.regression. In -# this routine, we provide the model, the NIG loss, and the optimizer, -# along with the dist_estimation parameter, which refers to the number of -# distribution parameters, and all the default arguments. - -loss = partial( - NIGLoss, - reg_weight=1e-2, -) +# set the weight of the regularizer of the DER Loss. After that, we define the +# training routine using the probabilistic regression training routine from +# torch_uncertainty.routines. In this routine, we provide the model, the DER +# loss, and the optimization recipe. -baseline = RegressionSingle( +loss = DERLoss(reg_weight=1e-2) + +routine = RegressionRoutine( + probabilistic=True, + output_dim=1, model=model, loss=loss, - optimization_procedure=optim_regression, - dist_estimation=4, - **vars(args), + optim_recipe=optim_regression(model), ) # %% # 5. Gathering Everything and Training the Model # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# Finally, we train the model using the trainer and the regression routine. We also +# test the model using the same trainer -results = cli_main(baseline, datamodule, root, net_name, args) +trainer.fit(model=routine, datamodule=datamodule) +trainer.test(model=routine, datamodule=datamodule) # %% # 6. Testing the Model # ~~~~~~~~~~~~~~~~~~~~ +# We can now test the model by plotting the predictions and the uncertainty estimates. +# In this specific case, we can reproduce the results of the paper. import matplotlib.pyplot as plt -from torch.nn import functional as F with torch.no_grad(): - x = torch.linspace(-7, 7, 1000).unsqueeze(-1) - - logits = model(x) - means, v, alpha, beta = logits.split(1, dim=-1) - - v = F.softplus(v) - alpha = 1 + F.softplus(alpha) - beta = F.softplus(beta) - - vars = torch.sqrt(beta / (v * (alpha - 1))) + x = torch.linspace(-7, 7, 1000) - means.squeeze_(1) - vars.squeeze_(1) - x.squeeze_(1) + dists = model(x.unsqueeze(-1)) + means = dists.loc.squeeze(1) + variances = torch.sqrt(dists.variance_loc).squeeze(1) fig, ax = plt.subplots(1, 1) ax.plot(x, x**3, "--r", label="ground truth", zorder=3) @@ -155,8 +139,8 @@ def optim_regression( for k in torch.linspace(0, 4, 4): ax.fill_between( x, - means - k * vars, - means + k * vars, + means - k * variances, + means + k * variances, linewidth=0, alpha=0.3, edgecolor=None, diff --git a/auto_tutorials_source/tutorial_evidential_classification.py b/auto_tutorials_source/tutorial_evidential_classification.py index a3a44e17..1b780361 100644 --- a/auto_tutorials_source/tutorial_evidential_classification.py +++ b/auto_tutorials_source/tutorial_evidential_classification.py @@ -16,36 +16,29 @@ To train a LeNet with the DEC loss function using TorchUncertainty, we have to load the following utilities from TorchUncertainty: -- the cli handler: cli_main and argument parser: init_args +- the Trainer from Lightning - the model: LeNet, which lies in torch_uncertainty.models -- the classification training routine in the torch_uncertainty.training.classification module -- the evidential objective: the DECLoss, which lies in the torch_uncertainty.losses file -- the datamodule that handles dataloaders: MNISTDataModule, which lies in the torch_uncertainty.datamodule -""" - -# %% -from torch_uncertainty import cli_main, init_args -from torch_uncertainty.models.lenet import lenet -from torch_uncertainty.routines.classification import ClassificationSingle -from torch_uncertainty.losses import DECLoss -from torch_uncertainty.datamodules import MNISTDataModule +- the classification training routine in the torch_uncertainty.routines +- the evidential objective: the DECLoss from torch_uncertainty.losses +- the datamodule that handles dataloaders & transforms: MNISTDataModule from torch_uncertainty.datamodules +We also need to define an optimizer using torch.optim, the neural network utils within torch.nn, as well as the partial util to provide +the modified default arguments for the DEC loss. +""" # %% -# We also need to define an optimizer using torch.optim as well as the -# neural network utils withing torch.nn, as well as the partial util to provide -# the modified default arguments for the DEC loss. -# -# We also import sys to override the command line arguments. - -import os from functools import partial from pathlib import Path import torch -from cli_test_helpers import ArgvContext +from lightning.pytorch import Trainer from torch import nn, optim +from torch_uncertainty.datamodules import MNISTDataModule +from torch_uncertainty.losses import DECLoss +from torch_uncertainty.models.lenet import lenet +from torch_uncertainty.routines import ClassificationRoutine + # %% # 2. Creating the Optimizer Wrapper @@ -54,9 +47,7 @@ # with the default learning rate of 0.001 and a step scheduler. def optim_lenet(model: nn.Module) -> dict: optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.005) - exp_lr_scheduler = optim.lr_scheduler.StepLR( - optimizer, step_size=7, gamma=0.1 - ) + exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1) return {"optimizer": optimizer, "lr_scheduler": exp_lr_scheduler} @@ -67,29 +58,16 @@ def optim_lenet(model: nn.Module) -> dict: # In the following, we need to define the root of the logs, and to # fake-parse the arguments needed for using the PyTorch Lightning Trainer. We # also use the same MNIST classification example as that used in the -# original DEC paper. We only train for 5 epochs for the sake of time. -root = Path(os.path.abspath("")) - -# We mock the arguments for the trainer. Replace with 25 epochs on your machine. -with ArgvContext( - "file.py", - "--max_epochs", - "5", - "--enable_progress_bar", - "True", -): - args = init_args(datamodule=MNISTDataModule) - -net_name = "logs/dec-lenet-mnist" +# original DEC paper. We only train for 3 epochs for the sake of time. +trainer = Trainer(accelerator="cpu", max_epochs=3, enable_progress_bar=False) # datamodule -args.root = str(root / "data") -dm = MNISTDataModule(**vars(args)) - +root = Path() / "data" +datamodule = MNISTDataModule(root=root, batch_size=128) model = lenet( - in_channels=dm.num_channels, - num_classes=dm.num_classes, + in_channels=datamodule.num_channels, + num_classes=datamodule.num_classes, ) # %% @@ -103,25 +81,21 @@ def optim_lenet(model: nn.Module) -> dict: # In this routine, we provide the model, the DEC loss, the optimizer, # and all the default arguments. -loss = partial( - DECLoss, - reg_weight=1e-2, -) +loss = DECLoss(reg_weight=1e-2) -baseline = ClassificationSingle( +routine = ClassificationRoutine( model=model, - num_classes=dm.num_classes, - in_channels=dm.num_channels, + num_classes=datamodule.num_classes, loss=loss, - optimization_procedure=optim_lenet, - **vars(args), + optim_recipe=optim_lenet(model), ) # %% # 5. Gathering Everything and Training the Model # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -results = cli_main(baseline, dm, root, net_name, args) +trainer.fit(model=routine, datamodule=datamodule) +trainer.test(model=routine, datamodule=datamodule) # %% # 6. Testing the Model @@ -129,12 +103,10 @@ def optim_lenet(model: nn.Module) -> dict: # Now that the model is trained, let's test it on MNIST. import matplotlib.pyplot as plt -import torch +import numpy as np import torchvision import torchvision.transforms.functional as F -import numpy as np - def imshow(img) -> None: npimg = img.numpy() @@ -150,11 +122,11 @@ def rotated_mnist(angle: int) -> None: """ rotated_images = F.rotate(images, angle) # print rotated images - plt.axis('off') + plt.axis("off") imshow(torchvision.utils.make_grid(rotated_images[:4, ...])) print("Ground truth: ", " ".join(f"{labels[j]}" for j in range(4))) - evidence = baseline(rotated_images) + evidence = routine(rotated_images) alpha = torch.relu(evidence) + 1 strength = torch.sum(alpha, dim=1, keepdim=True) probs = alpha / strength @@ -167,16 +139,15 @@ def rotated_mnist(angle: int) -> None: ) -dataiter = iter(dm.val_dataloader()) +dataiter = iter(datamodule.val_dataloader()) images, labels = next(dataiter) with torch.no_grad(): - baseline.eval() + routine.eval() rotated_mnist(0) rotated_mnist(45) rotated_mnist(90) - # %% # References # ---------- diff --git a/auto_tutorials_source/tutorial_mc_batch_norm.py b/auto_tutorials_source/tutorial_mc_batch_norm.py index b2ed7d4e..217e69ed 100644 --- a/auto_tutorials_source/tutorial_mc_batch_norm.py +++ b/auto_tutorials_source/tutorial_mc_batch_norm.py @@ -13,103 +13,71 @@ First, we have to load the following utilities from TorchUncertainty: -- the cli handler: cli_main and argument parser: init_args +- the Trainer from Lightning - the datamodule that handles dataloaders: MNISTDataModule, which lies in the torch_uncertainty.datamodule - the model: LeNet, which lies in torch_uncertainty.models - the mc-batch-norm wrapper: mc_dropout, which lies in torch_uncertainty.models -- a resnet baseline to get the command line arguments: ResNet, which lies in torch_uncertainty.baselines - the classification training routine in the torch_uncertainty.training.classification module -- the optimizer wrapper in the torch_uncertainty.optimization_procedures module. -""" -# %% -from torch_uncertainty import cli_main, init_args -from torch_uncertainty.datamodules import MNISTDataModule -from torch_uncertainty.models.lenet import lenet -from torch_uncertainty.post_processing.mc_batch_norm import MCBatchNorm -from torch_uncertainty.baselines.classification import ResNet -from torch_uncertainty.routines.classification import ClassificationSingle -from torch_uncertainty.optimization_procedures import optim_cifar10_resnet18 +- an optimization recipe in the torch_uncertainty.optim_recipes module. +We also need import the neural network utils within `torch.nn`. +""" # %% -# We will also need import the neural network utils withing `torch.nn`. -# -# We also import sys to override the command line arguments. - -import os from pathlib import Path +from lightning import Trainer from torch import nn -from cli_test_helpers import ArgvContext + +from torch_uncertainty.datamodules import MNISTDataModule +from torch_uncertainty.models.lenet import lenet +from torch_uncertainty.optim_recipes import optim_cifar10_resnet18 +from torch_uncertainty.post_processing.mc_batch_norm import MCBatchNorm +from torch_uncertainty.routines import ClassificationRoutine # %% # 2. Creating the necessary variables # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -# -# In the following, we will need to define the root of the datasets and the -# logs, and to fake-parse the arguments needed for using the PyTorch Lightning -# Trainer. We also create the datamodule that handles the MNIST dataset, -# dataloaders and transforms. We create the model using the -# blueprint from torch_uncertainty.models and we wrap it into mc-dropout. -# -# It is important to specify the arguments ``version`` as ``mc-dropout``, -# ``num_estimators`` and the ``dropout_rate`` to use Monte Carlo dropout. - -root = Path(os.path.abspath("")) - -# We mock the arguments for the trainer -with ArgvContext( - "file.py", - "--max_epochs", - "1", - "--enable_progress_bar", - "False", - "--num_estimators", - "8", - "--max_epochs", - "2" -): - args = init_args(network=ResNet, datamodule=MNISTDataModule) - -net_name = "logs/lenet-mnist" +# In the following, we define the root of the datasets and the +# logs. We also create the datamodule that handles the MNIST dataset +# dataloaders and transforms. + +trainer = Trainer(accelerator="cpu", max_epochs=2, enable_progress_bar=False) # datamodule -args.root = str(root / "data") -dm = MNISTDataModule(**vars(args)) +root = Path("") / "data" +datamodule = MNISTDataModule(root, batch_size=128) model = lenet( - in_channels=dm.num_channels, - num_classes=dm.num_classes, - norm = nn.BatchNorm2d, + in_channels=datamodule.num_channels, + num_classes=datamodule.num_classes, + norm=nn.BatchNorm2d, ) # %% # 3. The Loss and the Training Routine # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -# This is a classification problem, and we use CrossEntropyLoss as the likelihood. +# This is a classification problem, and we use CrossEntropyLoss as likelihood. # We define the training routine using the classification training routine from -# torch_uncertainty.training.classification. We provide the number of classes -# and channels, the optimizer wrapper, the dropout rate, and the number of -# forward passes to perform through the network, as well as all the default -# arguments. +# torch_uncertainty.training.classification. We provide the number of classes, +# and the optimization recipe. -baseline = ClassificationSingle( - num_classes=dm.num_classes, +routine = ClassificationRoutine( + num_classes=datamodule.num_classes, model=model, - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet18, - **vars(args), + loss=nn.CrossEntropyLoss(), + optim_recipe=optim_cifar10_resnet18(model), ) # %% -# 5. Gathering Everything and Training the Model +# 4. Gathering Everything and Training the Model # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -results = cli_main(baseline, dm, root, net_name, args) - +trainer.fit(model=routine, datamodule=datamodule) +trainer.test(model=routine, datamodule=datamodule) # %% -# 6. Wrapping the Model in a MCBatchNorm +# 5. Wrapping the Model in a MCBatchNorm # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # We can now wrap the model in a MCBatchNorm to add stochasticity to the # predictions. We specify that the BatchNorm layers are to be converted to @@ -119,12 +87,14 @@ # The authors suggest 32 as a good value for ``mc_batch_size`` but we use 4 here # to highlight the effect of stochasticity on the predictions. -baseline.model = MCBatchNorm(baseline.model, num_estimators=8, convert=True, mc_batch_size=32) -baseline.model.fit(dm.train) -baseline.eval() +routine.model = MCBatchNorm( + routine.model, num_estimators=8, convert=True, mc_batch_size=4 +) +routine.model.fit(datamodule.train) +routine.eval() # %% -# 7. Testing the Model +# 6. Testing the Model # ~~~~~~~~~~~~~~~~~~~~ # Now that the model is trained, let's test it on MNIST. Don't forget to call # .eval() to enable Monte Carlo batch normalization at inference. @@ -132,27 +102,28 @@ # the variance of the predictions is the highest. import matplotlib.pyplot as plt +import numpy as np import torch import torchvision -import numpy as np - def imshow(img): npimg = img.numpy() plt.imshow(np.transpose(npimg, (1, 2, 0))) + plt.axis("off") + plt.tight_layout() plt.show() -dataiter = iter(dm.val_dataloader()) +dataiter = iter(datamodule.val_dataloader()) images, labels = next(dataiter) # print images imshow(torchvision.utils.make_grid(images[:4, ...])) print("Ground truth: ", " ".join(f"{labels[j]}" for j in range(4))) -baseline.eval() -logits = baseline(images).reshape(8, 128, 10) +routine.eval() +logits = routine(images).reshape(8, 128, 10) probs = torch.nn.functional.softmax(logits, dim=-1) diff --git a/auto_tutorials_source/tutorial_mc_dropout.py b/auto_tutorials_source/tutorial_mc_dropout.py index 02b291e2..d8902bfe 100644 --- a/auto_tutorials_source/tutorial_mc_dropout.py +++ b/auto_tutorials_source/tutorial_mc_dropout.py @@ -19,33 +19,26 @@ First, we have to load the following utilities from TorchUncertainty: -- the cli handler: cli_main and argument parser: init_args +- the Trainer from Lightning - the datamodule that handles dataloaders: MNISTDataModule, which lies in the torch_uncertainty.datamodule - the model: LeNet, which lies in torch_uncertainty.models - the mc-dropout wrapper: mc_dropout, which lies in torch_uncertainty.models -- a resnet baseline to get the command line arguments: ResNet, which lies in torch_uncertainty.baselines - the classification training routine in the torch_uncertainty.training.classification module -- the optimizer wrapper in the torch_uncertainty.optimization_procedures module. -""" -# %% -from torch_uncertainty import cli_main, init_args -from torch_uncertainty.datamodules import MNISTDataModule -from torch_uncertainty.models.lenet import lenet -from torch_uncertainty.models.mc_dropout import mc_dropout -from torch_uncertainty.baselines.classification import ResNet -from torch_uncertainty.routines.classification import ClassificationEnsemble -from torch_uncertainty.optimization_procedures import optim_cifar10_resnet18 +- an optimization recipe in the torch_uncertainty.optim_recipes module. +We also need import the neural network utils within `torch.nn`. +""" # %% -# We will also need import the neural network utils withing `torch.nn`. -# -# We also import sys to override the command line arguments. - -import os from pathlib import Path +from lightning.pytorch import Trainer from torch import nn -from cli_test_helpers import ArgvContext + +from torch_uncertainty.datamodules import MNISTDataModule +from torch_uncertainty.models.lenet import lenet +from torch_uncertainty.models import mc_dropout +from torch_uncertainty.optim_recipes import optim_cifar10_resnet18 +from torch_uncertainty.routines import ClassificationRoutine # %% # 2. Creating the necessary variables @@ -60,91 +53,76 @@ # It is important to specify the arguments ``version`` as ``mc-dropout``, # ``num_estimators`` and the ``dropout_rate`` to use Monte Carlo dropout. -root = Path(os.path.abspath("")) - -# We mock the arguments for the trainer -with ArgvContext( - "file.py", - "--max_epochs", - "1", - "--enable_progress_bar", - "False", - "--dropout_rate", - "0.6", - "--num_estimators", - "16", - "--max_epochs", - "2" -): - args = init_args(network=ResNet, datamodule=MNISTDataModule) - -net_name = "logs/mc-dropout-lenet-mnist" +trainer = Trainer(accelerator="cpu", max_epochs=2, enable_progress_bar=False) # datamodule -args.root = str(root / "data") -dm = MNISTDataModule(**vars(args)) +root = Path("") / "data" +datamodule = MNISTDataModule(root=root, batch_size=128) model = lenet( - in_channels=dm.num_channels, - num_classes=dm.num_classes, - dropout_rate=args.dropout_rate, + in_channels=datamodule.num_channels, + num_classes=datamodule.num_classes, + dropout_rate=0.6, ) -mc_model = mc_dropout(model, num_estimators=args.num_estimators, last_layer=0.0) +mc_model = mc_dropout(model, num_estimators=16, last_layer=False) # %% # 3. The Loss and the Training Routine # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # This is a classification problem, and we use CrossEntropyLoss as the likelihood. # We define the training routine using the classification training routine from -# torch_uncertainty.training.classification. We provide the number of classes +# torch_uncertainty.routines.classification. We provide the number of classes # and channels, the optimizer wrapper, the dropout rate, and the number of # forward passes to perform through the network, as well as all the default # arguments. -baseline = ClassificationEnsemble( - num_classes=dm.num_classes, +routine = ClassificationRoutine( + num_classes=datamodule.num_classes, model=mc_model, - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet18, - **vars(args), + loss=nn.CrossEntropyLoss(), + optim_recipe=optim_cifar10_resnet18(mc_model), + num_estimators=16, + ) # %% -# 5. Gathering Everything and Training the Model +# 4. Gathering Everything and Training the Model # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -results = cli_main(baseline, dm, root, net_name, args) +trainer.fit(model=routine, datamodule=datamodule) +trainer.test(model=routine, datamodule=datamodule) # %% -# 6. Testing the Model +# 5. Testing the Model # ~~~~~~~~~~~~~~~~~~~~ # Now that the model is trained, let's test it on MNIST. Don't forget to call # .eval() to enable dropout at inference. import matplotlib.pyplot as plt +import numpy as np import torch import torchvision -import numpy as np - def imshow(img): npimg = img.numpy() plt.imshow(np.transpose(npimg, (1, 2, 0))) + plt.axis("off") + plt.tight_layout() plt.show() -dataiter = iter(dm.val_dataloader()) +dataiter = iter(datamodule.val_dataloader()) images, labels = next(dataiter) # print images imshow(torchvision.utils.make_grid(images[:4, ...])) print("Ground truth: ", " ".join(f"{labels[j]}" for j in range(4))) -baseline.eval() -logits = baseline(images).reshape(16, 128, 10) +routine.eval() +logits = routine(images).reshape(16, 128, 10) probs = torch.nn.functional.softmax(logits, dim=-1) @@ -152,7 +130,7 @@ def imshow(img): for j in range(4): values, predicted = torch.max(probs[:, j], 1) print( - f"Predicted digits for the image {j}: ", + f"Predicted digits for the image {j+1}: ", " ".join([str(image_id.item()) for image_id in predicted]), ) diff --git a/auto_tutorials_source/tutorial_pe_cifar10.py b/auto_tutorials_source/tutorial_pe_cifar10.py index 8a4af0e7..52820064 100644 --- a/auto_tutorials_source/tutorial_pe_cifar10.py +++ b/auto_tutorials_source/tutorial_pe_cifar10.py @@ -119,7 +119,7 @@ def imshow(img): # %% # 2. Define a Packed-Ensemble from a standard classifier -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ # First we define a standard classifier for CIFAR10 for reference. We will use a # convolutional neural network. diff --git a/auto_tutorials_source/tutorial_scaler.py b/auto_tutorials_source/tutorial_scaler.py index b67f21ca..2d927b10 100644 --- a/auto_tutorials_source/tutorial_scaler.py +++ b/auto_tutorials_source/tutorial_scaler.py @@ -3,10 +3,13 @@ ====================================================== In this tutorial, we use *TorchUncertainty* to improve the calibration -of the top-label predictions -and the reliability of the underlying neural network. +of the top-label predictions and the reliability of the underlying neural network. -We also see how to use the datamodules outside any Lightning trainers, +This tutorial provides extensive details on how to use the TemperatureScaler +class, however, this is done automatically in the classification routine when setting +the `calibration_set` to val or test. + +Through this tutorial, we also see how to use the datamodules outside any Lightning trainers, and how to use TorchUncertainty's models. 1. Loading the Utilities @@ -17,11 +20,12 @@ - torch for its objects - the "calibration error" metric to compute the ECE and evaluate the top-label calibration - the CIFAR-100 datamodule to handle the data -- a ResNet 18 as starting model +- a ResNet 18 as starting model - the temperature scaler to improve the top-label calibration - a utility to download hf models easily -- the calibration plot to visualize the calibration. If you use the classification routine, - the plots will be automatically available in the tensorboard logs. +- the calibration plot to visualize the calibration. + +If you use the classification routine, the plots will be automatically available in the tensorboard logs. """ from torch_uncertainty.datamodules import CIFAR100DataModule @@ -52,7 +56,8 @@ # # To get the dataloader from the datamodule, just call prepare_data, setup, and # extract the first element of the test dataloader list. There are more than one -# element if `:attr:eval_ood` is True. +# element if eval_ood is True: the dataloader of in-distribution data and the dataloader +# of out-of-distribution data. Otherwise, it is a list of 1 element. dm = CIFAR100DataModule(root="./data", eval_ood=False, batch_size=32) dm.prepare_data() @@ -61,7 +66,6 @@ # Get the full test dataloader (unused in this tutorial) dataloader = dm.test_dataloader()[0] - # %% # 4. Iterating on the Dataloader and Computing the ECE # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -93,8 +97,7 @@ ece.update(probs, target) # Compute & print the calibration error -cal = ece.compute() -print(f"ECE before scaling - {cal*100:.3}%.") +print(f"ECE before scaling - {ece.compute():.3%}.") # %% # We also compute and plot the top-label calibration figure. We see that the @@ -133,8 +136,7 @@ probs = logits.softmax(-1) ece.update(probs, target) -cal = ece.compute() -print(f"ECE after scaling - {cal*100:.3}%.") +print(f"ECE after scaling - {ece.compute():.3%}.") # %% # We finally compute and plot the scaled top-label calibration figure. We see diff --git a/codecov.yml b/codecov.yml index 3c1d530a..c5a92954 100644 --- a/codecov.yml +++ b/codecov.yml @@ -6,3 +6,6 @@ coverage: patch: default: target: 95% + +codecov: + disable_default_path_fixes: true diff --git a/docs/source/api.rst b/docs/source/api.rst index 6ae95cf7..24abb1ef 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -1,14 +1,16 @@ -API reference +API Reference ============= .. currentmodule:: torch_uncertainty -Baselines ---------- +Routines +-------- -This API provides lightning-based models that can be easily trained and evaluated. +The routine are the main building blocks of the library. They define the framework +in which the models are trained and evaluated. They allow for easy computation of different +metrics crucial for uncertainty estimation in different contexts, namely classification, regression and segmentation. -.. currentmodule:: torch_uncertainty.baselines.classification +.. currentmodule:: torch_uncertainty.routines Classification ^^^^^^^^^^^^^^ @@ -18,11 +20,7 @@ Classification :nosignatures: :template: class.rst - ResNet - VGG - WideResNet - -.. currentmodule:: torch_uncertainty.baselines.regression + ClassificationRoutine Regression ^^^^^^^^^^ @@ -32,118 +30,62 @@ Regression :nosignatures: :template: class.rst - MLP - -.. Models -.. ------ - -.. This section encapsulates the model backbones currently supported by the library. - -.. ResNet -.. ^^^^^^ - -.. .. currentmodule:: torch_uncertainty.models.resnet - -.. Concerning ResNet backbones, we provide building functions for ResNet18, ResNet34, -.. ResNet50, ResNet101 and, ResNet152 (from `Deep Residual Learning for Image Recognition -.. `_, CVPR 2016). - -.. Standard -.. ~~~~~~~ - -.. .. autosummary:: -.. :toctree: generated/ -.. :nosignatures: + RegressionRoutine -.. resnet18 -.. resnet34 -.. resnet50 -.. resnet101 -.. resnet152 +Segmentation +^^^^^^^^^^^^ -.. Packed-Ensembles -.. ~~~~~~~~~~~~~~~~ - -.. .. autosummary:: -.. :toctree: generated/ -.. :nosignatures: - -.. packed_resnet18 -.. packed_resnet34 -.. packed_resnet50 -.. packed_resnet101 -.. packed_resnet152 - -.. Masksembles -.. ~~~~~~~~~~~ - -.. .. autosummary:: -.. :toctree: generated/ -.. :nosignatures: - -.. masked_resnet18 -.. masked_resnet34 -.. masked_resnet50 -.. masked_resnet101 -.. masked_resnet152 - -.. BatchEnsemble -.. ~~~~~~~~~~~~~ - -.. .. autosummary:: -.. :toctree: generated/ -.. :nosignatures: - -.. batched_resnet18 -.. batched_resnet34 -.. batched_resnet50 -.. batched_resnet101 -.. batched_resnet152 +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: class.rst -.. Wide-ResNet -.. ^^^^^^^^^^^ + SegmentationRoutine -.. .. currentmodule:: torch_uncertainty.models.wideresnet +Baselines +--------- -.. Concerning Wide-ResNet backbones, we provide building functions for Wide-ResNet28x10 -.. (from `Wide Residual Networks `_, British -.. Machine Vision Conference 2016). +TorchUncertainty provide lightning-based models that can be easily trained and evaluated. +These models inherit from the routines and are specifically designed to benchmark +different methods in similar settings, here with constant architectures. -.. Standard -.. ~~~~~~~ +.. currentmodule:: torch_uncertainty.baselines.classification -.. .. autosummary:: -.. :toctree: generated/ -.. :nosignatures: +Classification +^^^^^^^^^^^^^^ -.. wideresnet28x10 +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: class.rst -.. Packed-Ensembles -.. ~~~~~~~~~~~~~~~~ + ResNetBaseline + VGGBaseline + WideResNetBaseline -.. .. autosummary:: -.. :toctree: generated/ -.. :nosignatures: +.. currentmodule:: torch_uncertainty.baselines.regression -.. packed_wideresnet28x10 +Regression +^^^^^^^^^^ -.. Masksembles -.. ~~~~~~~~~~~ +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: class.rst -.. .. autosummary:: -.. :toctree: generated/ -.. :nosignatures: + MLPBaseline -.. masked_wideresnet28x10 +.. currentmodule:: torch_uncertainty.baselines.segmentation -.. BatchEnsemble -.. ~~~~~~~~~~~~~ +Segmentation +^^^^^^^^^^^^ -.. .. autosummary:: -.. :toctree: generated/ -.. :nosignatures: +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: class.rst -.. batched_wideresnet28x10 + SegFormerBaseline Layers ------ @@ -181,6 +123,30 @@ Bayesian layers BayesConv2d BayesConv3d +Models +------ + +.. currentmodule:: torch_uncertainty.models + +Deep Ensembles +^^^^^^^^^^^^^^ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: class.rst + + deep_ensembles + +Monte Carlo Dropout + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: class.rst + + mc_dropout + Metrics ------- @@ -193,13 +159,18 @@ Metrics AUSE BrierScore + CategoricalNLL CE Disagreement + DistributionNLL Entropy - MutualInformation - NegativeLogLikelihood - GaussianNegativeLogLikelihood FPR95 + Log10 + MeanGTRelativeAbsoluteError + MeanGTRelativeSquaredError + MutualInformation + SILog + ThresholdAccuracy Losses ------ @@ -211,10 +182,10 @@ Losses :nosignatures: :template: class.rst + DistributionNLLLoss KLDiv ELBOLoss BetaNLL - NIGLoss DECLoss Post-Processing Methods @@ -230,12 +201,24 @@ Post-Processing Methods TemperatureScaler VectorScaler MatrixScaler + MCBatchNorm Datamodules ----------- +.. currentmodule:: torch_uncertainty.datamodules.abstract + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: class.rst + + .. currentmodule:: torch_uncertainty.datamodules +Classification +^^^^^^^^^^^^^^ + .. autosummary:: :toctree: generated/ :nosignatures: @@ -246,4 +229,26 @@ Datamodules MNISTDataModule TinyImageNetDataModule ImageNetDataModule + +Regression +^^^^^^^^^^ +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: class.rst + UCIDataModule + +.. currentmodule:: torch_uncertainty.datamodules.segmentation + +Segmentation +^^^^^^^^^^^^ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: class.rst + + CamVidDataModule + CityscapesDataModule + MUADDataModule diff --git a/docs/source/cli_guide.rst b/docs/source/cli_guide.rst new file mode 100644 index 00000000..ec111cf1 --- /dev/null +++ b/docs/source/cli_guide.rst @@ -0,0 +1,251 @@ +CLI Guide +========= + +Introduction to the Lightning CLI +--------------------------------- + +The Lightning CLI tool eases the implementation of a CLI to instanciate models to train and evaluate them on +some data. The CLI tool is a wrapper around the ``Trainer`` class and provides a set of subcommands to train +and test a ``LightningModule`` on a ``LightningDataModule``. To better match our needs, we created an inherited +class from the ``LightningCLI`` class, namely ``TULightningCLI``. + +.. note:: + ``TULightningCLI`` adds a new argument to the ``LightningCLI`` class: :attr:`eval_after_fit` to know whether + an evaluation on the test set should be performed after the training phase. + +Let's see how to implement the CLI, by checking out the ``experiments/classification/cifar10/resnet.py``. + +.. code:: python + + import torch + from lightning.pytorch.cli import LightningArgumentParser + + from torch_uncertainty.baselines.classification import ResNetBaseline + from torch_uncertainty.datamodules import CIFAR10DataModule + from torch_uncertainty.utils import TULightningCLI + + + class ResNetCLI(TULightningCLI): + def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + parser.add_optimizer_args(torch.optim.SGD) + parser.add_lr_scheduler_args(torch.optim.lr_scheduler.MultiStepLR) + + + def cli_main() -> ResNetCLI: + return ResNetCLI(ResNetBaseline, CIFAR10DataModule) + + + if __name__ == "__main__": + cli = cli_main() + if ( + (not cli.trainer.fast_dev_run) + and cli.subcommand == "fit" + and cli._get(cli.config, "eval_after_fit") + ): + cli.trainer.test(datamodule=cli.datamodule, ckpt_path="best") + +This file enables both training and testing ResNet architectures on the CIFAR-10 dataset. +The ``ResNetCLI`` class inherits from the ``TULightningCLI`` class and implements the +``add_arguments_to_parser`` method to add the optimizer and learning rate scheduler arguments +into the parser. In this case, we use the ``torch.optim.SGD`` optimizer and the +``torch.optim.lr_scheduler.MultiStepLR`` learning rate scheduler. + +.. code:: python + + class ResNetCLI(TULightningCLI): + def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + parser.add_optimizer_args(torch.optim.SGD) + parser.add_lr_scheduler_args(torch.optim.lr_scheduler.MultiStepLR) + +The ``LightningCLI`` takes a ``LightningModule`` and a ``LightningDataModule`` as arguments. +Here the ``cli_main`` function creates an instance of the ``ResNetCLI`` class by taking the ``ResNetBaseline`` +model and the ``CIFAR10DataModule`` as arguments. + +.. code:: python + + def cli_main() -> ResNetCLI: + return ResNetCLI(ResNetBaseline, CIFAR10DataModule) + +.. note:: + + The ``ResNetBaseline`` is a subclass of the ``ClassificationRoutine`` seemlessly instanciating a + ResNet model based on a :attr:`version` and an :attr:`arch` to be passed to the routine. + +Depending on the CLI subcommand calling ``cli_main()`` will either train or test the model on the using +the CIFAR-10 dataset. But what are these subcommands? + +.. code:: bash + + python resnet.py --help + +This command will display the available subcommands of the CLI tool. + +.. code:: bash + + subcommands: + For more details of each subcommand, add it as an argument followed by --help. + + Available subcommands: + fit Runs the full optimization routine. + validate Perform one evaluation epoch over the validation set. + test Perform one evaluation epoch over the test set. + predict Run inference on your data. + +You can execute whichever subcommand you like and set up all your hyperparameters directly using the command line + +.. code:: bash + + python resnet.py fit --trainer.max_epochs 75 --trainer.accelerators gpu --trainer.devices 1 --model.version std --model.arch 18 --model.in_channels 3 --model.num_classes 10 --model.loss CrossEntropyLoss --model.style cifar --data.root ./data --data.batch_size 128 --optimizer.lr 0.05 --lr_scheduler.milestones [25,50] + +All arguments in the ``__init__()`` methods of the ``Trainer``, ``LightningModule`` (here ``ResNetBaseline``), +``LightningDataModule`` (here ``CIFAR10DataModule``), ``torch.optim.SGD``, and ``torch.optim.lr_scheduler.MultiStepLR`` +classes are configurable using the CLI tool using the ``--trainer``, ``--model``, ``--data``, ``--optimizer``, and +``--lr_scheduler`` prefixes, respectively. + +However for a large number of hyperparameters, it is not practical to pass them all in the command line. +It is more convenient to use configuration files to store these hyperparameters and ease the burden of +repeating them each time you want to train or test a model. Let's see how to do that. + +.. note:: + + Note that ``Pytorch`` classes are supported by the CLI tool, so you can use them directly: ``--model.loss CrossEntropyLoss`` + and they would be automatically instanciated by the CLI tool with their default arguments (i.e., ``CrossEntropyLoss()``). + +.. tip:: + + Add the following after calling ``cli=cli_main()`` to eventually evaluate the model on the test set + after training, if the ``eval_after_fit`` argument is set to ``True`` and ``trainer.fast_dev_run`` + is set to ``False``. + + .. code:: python + + if ( + (not cli.trainer.fast_dev_run) + and cli.subcommand == "fit" + and cli._get(cli.config, "eval_after_fit") + ): + cli.trainer.test(datamodule=cli.datamodule, ckpt_path="best") + +Configuration files +------------------- + +By default the ``LightningCLI`` support configuration files in the YAML format (learn more about this format +`here `_). +Taking the previous example, we can create a configuration file named ``config.yaml`` with the following content: + +.. code:: yaml + + # config.yaml + eval_after_fit: true + trainer: + max_epochs: 75 + accelerators: gpu + devices: 1 + model: + version: std + arch: 18 + in_channels: 3 + num_classes: 10 + loss: CrossEntropyLoss + style: cifar + data: + root: ./data + batch_size: 128 + optimizer: + lr: 0.05 + lr_scheduler: + milestones: + - 25 + - 50 + +Then, we can run the following command to train the model: + +.. code:: bash + + python resnet.py fit --config config.yaml + +By default, executing the command above will store the experiment results in a directory named ``lightning_logs``, +and the last state of the model will be saved in a directory named ``lightning_logs/version_{int}/checkpoints``. +In addition, all arguments passed to instanciate the ``Trainer``, ``ResNetBaseline``, ``CIFAR10DataModule``, +``torch.optim.SGD``, and ``torch.optim.lr_scheduler.MultiStepLR`` classes will be saved in a file named +``lightning_logs/version_{int}/config.yaml``. When testing the model, we advise to use this configuration file +to ensure that the same hyperparameters are used for training and testing. + +.. code:: bash + + python resnet.py test --config lightning_logs/version_{int}/config.yaml --ckpt_path lightning_logs/version_{int}/checkpoints/{filename}.ckpt + +Experiment folder usage +----------------------- + +Now that we have seen how to implement the CLI tool and how to use configuration files, let explore the +configurations available in the ``experiments`` directory. The ``experiments`` directory is +mainly organized as follows: + +.. code:: bash + + experiments + ├── classification + │ ├── cifar10 + │ │ ├── configs + │ │ ├── resnet.py + │ │ ├── vgg.py + │ │ └── wideresnet.py + │ └── cifar100 + │ ├── configs + │ ├── resnet.py + │ ├── vgg.py + │ └── wideresnet.py + ├── regression + │ └── uci_datasets + │ ├── configs + │ └── mlp.py + └── segmentation + ├── cityscapes + │ ├── configs + │ └── segformer.py + └── muad + ├── configs + └── segformer.py + +For each task (**classification**, **regression**, and **segmentation**), we have a directory containing the datasets +(e.g., CIFAR10, CIFAR100, UCI datasets, Cityscapes, and Muad) and for each dataset, we have a directory containing +the configuration files and the CLI files for different backbones. + +You can directly use the CLI files with the command line or use the predefined configuration files to train and test +the models. The configuration files are stored in the ``configs``. For example, the configuration file for the classic +ResNet-18 model on the CIFAR-10 dataset is stored in the ``experiments/classification/cifar10/configs/resnet18/standard.yaml`` +file. For the Packed ResNet-18 model on the CIFAR-10 dataset, the configuration file is stored in the +``experiments/classification/cifar10/configs/resnet18/packed.yaml`` file. + +If you are interested in using a ResNet model but want to choose some of the hyperparameters using the command line, +you can use the configuration file and override the hyperparameters using the command line. For example, to train +a ResNet-18 model on the CIFAR-10 dataset with a batch size of :math:`256`, you can use the following command: + +.. code:: bash + + python resnet.py fit --config configs/resnet18/standard.yaml --data.batch_size 256 + +To use the weights argument of the ``torch.nn.CrossEntropyLoss`` class, you can use the following command: + +.. code:: bash + + python resnet.py fit --config configs/resnet18/standard.yaml --model.loss CrossEntropyLoss --model.loss.weight Tensor --model.loss.weight.dict_kwargs.data [1,2,3,4,5,6,7,8,9,10] + + +In addition, we provide a default configuration file for some backbones in the ``configs`` directory. For example, +``experiments/classification/cifar10/configs/resnet.yaml`` contains the default hyperparameters to train a ResNet model +on the CIFAR-10 dataset. Yet, some hyperparameters are purposely missing to be set by the user using the command line. + +For instance, to train a Packed ResNet-34 model on the CIFAR-10 dataset with :math:`4` estimators and a :math:`\alpha` value of :math:`2`, +you can use the following command: + +.. code:: bash + + python resnet.py fit --config configs/resnet.yaml --trainer.max_epochs 75 --model.version packed --model.arch 34 --model.num_estimators 4 --model.alpha 2 --optimizer.lr 0.05 --lr_scheduler.milestones [25,50] + + +.. tip:: + + Explore the `Lightning CLI docs `_ to learn more about the CLI tool, + the available arguments, and how to use them with configuration files. diff --git a/docs/source/conf.py b/docs/source/conf.py index c442fdbc..3a03317b 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -10,9 +10,12 @@ # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information project = "TorchUncertainty" -copyright = f"{datetime.utcnow().year!s}, Adrien Lafage and Olivier Laurent" # noqa: A001 + +copyright = ( # noqa: A001 + f"{datetime.now().year!s}, Adrien Lafage and Olivier Laurent" +) author = "Adrien Lafage and Olivier Laurent" -release = "0.1.6" +release = "0.2.0" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration @@ -83,7 +86,3 @@ html_static_path = ["_static"] html_style = "css/custom.css" -# html_css_files = [ -# 'https://cdn.jsdelivr.net/npm/katex@0.10.0-beta/dist/katex.min.css', -# 'css/custom.css' -# ] diff --git a/docs/source/contributing.rst b/docs/source/contributing.rst index 4a822d4c..466479b6 100644 --- a/docs/source/contributing.rst +++ b/docs/source/contributing.rst @@ -4,7 +4,6 @@ Contributing .. role:: bash(code) :language: bash - TorchUncertainty is in early development stage. We are looking for contributors to help us build a comprehensive library for uncertainty quantification in PyTorch. @@ -54,7 +53,6 @@ Then navigate to :bash:`./docs` and build the documentation with: make html - Optionally, specify :bash:`html-noplot` instead of :bash:`html` to avoid running the tutorials. Guidelines diff --git a/docs/source/index.rst b/docs/source/index.rst index aebd2502..09a9d53e 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -1,9 +1,9 @@ -.. Torch Uncertainty documentation master file, created by +.. TorchUncertainty documentation master file, created by sphinx-quickstart on Wed Feb 1 18:07:01 2023. You can adapt this file completely to your liking, but it should at least contain the root `toctree` directive. -Torch Uncertainty +TorchUncertainty ================= .. role:: bash(code) @@ -45,6 +45,7 @@ the following paper: quickstart introduction_uncertainty auto_tutorials/index + cli_guide api contributing references diff --git a/docs/source/installation.rst b/docs/source/installation.rst index d2153e31..a05ae32e 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -5,7 +5,7 @@ Installation :language: bash -You can install the package from PyPI or from source. Choose the latter if you +You can install the package either from PyPI or from source. Choose the latter if you want to access the files included the `experiments `_ folder or if you want to contribute to the project. @@ -59,11 +59,11 @@ Options You can install the package with the following options: * dev: includes all the dependencies for the development of the package -including ruff and the pre-commits hooks. + including ruff and the pre-commits hooks. * docs: includes all the dependencies for the documentation of the package -based on sphinx + based on sphinx * image: includes all the dependencies for the image processing module -including opencv and scikit-image + including opencv and scikit-image * tabular: includes pandas * all: includes all the aforementioned dependencies diff --git a/docs/source/introduction_uncertainty.rst b/docs/source/introduction_uncertainty.rst index 6e2d226a..4e92c4fe 100644 --- a/docs/source/introduction_uncertainty.rst +++ b/docs/source/introduction_uncertainty.rst @@ -22,7 +22,6 @@ it may not be a good idea to trust these predictions. Let's see why in more deta The Overconfidence of Neural Networks ------------------------------------- - References ---------- diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index e27461f2..18960abb 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -4,89 +4,136 @@ Quickstart .. role:: bash(code) :language: bash -Torch Uncertainty comes with different usage levels ranging from specific -PyTorch layers to ready to train Lightning-based models. The following -presents a short introduction to each one of them. Let's start with the -highest-level usage. +TorchUncertainty is centered around **uncertainty-aware** training and evaluation routines. +These routines make it very easy to: -Using the Lightning-based CLI tool ----------------------------------- +- train ensembles-like methods (Deep Ensembles, Packed-Ensembles, MIMO, Masksembles, etc) +- compute and monitor uncertainty metrics: calibration, out-of-distribution detection, proper scores, grouping loss, etc. +- leverage calibration methods automatically during evaluation -Procedure -^^^^^^^^^ +Yet, we take account that their will be as many different uses of TorchUncertainty as there are of users. +This page provides ideas on how to benefit from TorchUncertainty at all levels: from ready-to-train lightning-based models to using only specific +PyTorch layers. -The library provides a full-fledged trainer which can be used directly, via -CLI. To do so, create a file in the experiments folder and use the `cli_main` -routine, which takes as arguments: - -* a Lightning Module corresponding to the model, its own arguments, and - forward/validation/test logic. For instance, you might use already available - modules, such as the Packed-Ensembles-style ResNet available at - `torch_uncertainty/baselines/packed/resnet.py `_ -* a Lightning DataModule corresponding to the training, validation, and test - sets with again its arguments and logic. CIFAR-10/100, ImageNet, and - ImageNet-200 are available, for instance. -* a PyTorch loss such as the torch.nn.CrossEntropyLoss -* a dictionary containing the optimization procedure, namely a scheduler and - an optimizer. Many procedures are available at - `torch_uncertainty/optimization_procedures.py `_ - -* the path to the data and logs folder, in the example below, the root of the library -* and finally, the name of your model (used for logs) - -Move to the directory containing your file and execute the code with :bash:`python3 experiment.py`. -Add lightning arguments such as :bash:`--accelerator gpu --devices "0, 1" --benchmark True` -for multi-gpu training and cuDNN benchmark, etc. +Training with TorchUncertainty's Uncertainty-aware Routines +----------------------------------------------------------- -Example -^^^^^^^ +Let's have a look at the `Classification routine `_. + +.. code:: python -The following code - `available in the experiments folder `_ - -trains any ResNet architecture on CIFAR10: + from lightning.pytorch import LightningModule + + class ClassificationRoutine(LightningModule): + def __init__( + self, + model: nn.Module, + num_classes: int, + loss: nn.Module, + num_estimators: int = 1, + format_batch_fn: nn.Module | None = None, + optim_recipe: dict | Optimizer | None = None, + mixtype: str = "erm", + mixmode: str = "elem", + dist_sim: str = "emb", + kernel_tau_max: float = 1.0, + kernel_tau_std: float = 0.5, + mixup_alpha: float = 0, + cutmix_alpha: float = 0, + eval_ood: bool = False, + eval_grouping_loss: bool = False, + ood_criterion: Literal[ + "msp", "logit", "energy", "entropy", "mi", "vr" + ] = "msp", + log_plots: bool = False, + save_in_csv: bool = False, + calibration_set: Literal["val", "test"] | None = None, + ) -> None: + ... + + +Building your First Routine +^^^^^^^^^^^^^^^^^^^^^^^^^^^ +This routine is a wrapper of any custom or TorchUncertainty classification model. To use it, +just build your model and pass it to the routine as argument along with an optimization recipe +and the loss as well as the number of classes that we use for torch metrics. .. code:: python - from pathlib import Path + from torch import nn, optim - from torch import nn + model = MyModel(num_classes=10) + routine = ClassificationRoutine( + model, + num_classes=10, + loss=nn.CrossEntropyLoss(), + optim_recipe=optim.Adam(model.parameters(), lr=1e-3), + ) - from torch_uncertainty import cli_main, init_args - from torch_uncertainty.baselines import ResNet - from torch_uncertainty.datamodules import CIFAR10DataModule - from torch_uncertainty.optimization_procedures import get_procedure - root = Path(__file__).parent.absolute().parents[1] +Training with the Routine +^^^^^^^^^^^^^^^^^^^^^^^^^ - args = init_args(ResNet, CIFAR10DataModule) +To train with this routine, you will first need to create a lightning Trainer and have either a lightning datamodule +or PyTorch dataloaders. When benchmarking models, we advise to use lightning datamodules that will automatically handle +train/val/test splits, out-of-distribution detection and dataset shift. For this example, let us use TorchUncertainty's +CIFAR10 datamodule. Please keep in mind that you could use your own datamodule or dataloaders. - net_name = f"{args.version}-resnet{args.arch}-cifar10" +.. code:: python - # datamodule - args.root = str(root / "data") - dm = CIFAR10DataModule(**vars(args)) + from torch_uncertainty.datamodules import CIFAR10DataModule + from pytorch_lightning import Trainer - # model - model = ResNet( - num_classes=dm.num_classes, - in_channels=dm.in_channels, - loss=nn.CrossEntropyLoss(), - optimization_procedure=get_procedure( - f"resnet{args.arch}", "cifar10", args.version - ), - imagenet_structure=False, - **vars(args), - ) + dm = CIFAR10DataModule(root="data", batch_size=32) + trainer = Trainer(gpus=1, max_epochs=100) + trainer.fit(routine, dm) + trainer.test(routine, dm) + +Here it is, you have trained your first model with TorchUncertainty! As a result, you will get access to various metrics +measuring the ability of your model to handle uncertainty. + +More metrics +^^^^^^^^^^^^ + +With TorchUncertainty datamodules, you can easily test models on out-of-distribution datasets, by +setting the ``eval_ood`` parameter to ``True``. You can also evaluate the grouping loss by setting ``eval_grouping_loss`` to ``True``. +Finally, you can calibrate your model using the ``calibration_set`` parameter. In this case, you will get +metrics for but the uncalibrated and calibrated models: the metrics corresponding to the temperature scaled +model will begin with ``ts_``. + +---- + +Using the Lightning CLI tool +---------------------------------- + +Procedure +^^^^^^^^^ + +The library leverages the `Lightning CLI tool `_ +to provide a simple way to train models and evaluate them, while insuring reproducibility via configuration files. +Under the ``experiment`` folder, you will find scripts for the three application tasks covered by the library: +classification, regression and segmentation. Take the most out of the CLI by checking our `CLI Guide `_. + +.. note:: + + In particular, the ``experiments/classification`` folder contains scripts to reproduce the experiments covered + in the paper: *Packed-Ensembles for Efficient Uncertainty Estimation*, O. Laurent & A. Lafage, et al., in ICLR 2023. - cli_main(model, dm, root, net_name, args) -Run this model with, for instance: + +Example +^^^^^^^ + +Training a model with the Lightning CLI tool is as simple as running the following command: .. code:: bash - python3 resnet.py --version std --arch 18 --accelerator gpu --device 1 --benchmark True --max_epochs 75 --precision 16 + # in pyjam/experiments/classification/cifar10 + python resnet.py fit --config configs/resnet18/standard.yaml -You may replace the architecture (which should be a Lightning Module), the -Datamodule (a Lightning Datamodule), the loss or the optimization procedure to your likings. +Which trains a classic ResNet18 model on CIFAR10 with the settings used in the Packed-Ensembles paper. + +---- Using the PyTorch-based models ------------------------------ @@ -118,6 +165,8 @@ backbone with the following code: num_classes = 10, ) +---- + Using the PyTorch-based layers ------------------------------ @@ -135,7 +184,7 @@ issue on the GitHub repository! .. tip:: - Do not hesitate to go to the API reference to get better explanations on the + Do not hesitate to go to the `API Reference `_ to get better explanations on the layer usage. Example @@ -178,6 +227,8 @@ code: packed_net = PackedNet() +---- + Other usage ----------- diff --git a/experiments/README.md b/experiments/README.md deleted file mode 100644 index 19d6dd0a..00000000 --- a/experiments/README.md +++ /dev/null @@ -1,11 +0,0 @@ -# Experiments - -Torch-Uncertainty proposes various benchmarks to evaluate the uncertainty estimation methods. - -## Classification - -*Work in progress* - -## Regression - -*Work in progress* diff --git a/experiments/classification/cifar10/configs/resnet.yaml b/experiments/classification/cifar10/configs/resnet.yaml new file mode 100644 index 00000000..aa053391 --- /dev/null +++ b/experiments/classification/cifar10/configs/resnet.yaml @@ -0,0 +1,34 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/resnet + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: cls_val/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: cls_val/Acc + patience: 1000 + check_finite: true +model: + num_classes: 10 + in_channels: 3 + loss: CrossEntropyLoss + style: cifar +data: + root: ./data + batch_size: 128 diff --git a/experiments/classification/cifar10/configs/resnet18/batched.yaml b/experiments/classification/cifar10/configs/resnet18/batched.yaml new file mode 100644 index 00000000..e71130f9 --- /dev/null +++ b/experiments/classification/cifar10/configs/resnet18/batched.yaml @@ -0,0 +1,49 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 75 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/resnet18 + name: batched + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: cls_val/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: cls_val/Acc + patience: 1000 + check_finite: true +model: + num_classes: 10 + in_channels: 3 + loss: CrossEntropyLoss + version: batched + arch: 18 + style: cifar + num_estimators: 4 +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.05 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 25 + - 50 + gamma: 0.1 diff --git a/experiments/classification/cifar10/configs/resnet18/masked.yaml b/experiments/classification/cifar10/configs/resnet18/masked.yaml new file mode 100644 index 00000000..202ba0c4 --- /dev/null +++ b/experiments/classification/cifar10/configs/resnet18/masked.yaml @@ -0,0 +1,50 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 75 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/resnet18 + name: masked + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: cls_val/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: cls_val/Acc + patience: 1000 + check_finite: true +model: + num_classes: 10 + in_channels: 3 + loss: CrossEntropyLoss + version: masked + arch: 18 + style: cifar + num_estimators: 4 + scale: 2 +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.05 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 25 + - 50 + gamma: 0.1 diff --git a/experiments/classification/cifar10/configs/resnet18/mimo.yaml b/experiments/classification/cifar10/configs/resnet18/mimo.yaml new file mode 100644 index 00000000..e45988db --- /dev/null +++ b/experiments/classification/cifar10/configs/resnet18/mimo.yaml @@ -0,0 +1,50 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 75 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/resnet18 + name: mimo + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: cls_val/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: cls_val/Acc + patience: 1000 + check_finite: true +model: + num_classes: 10 + in_channels: 3 + loss: CrossEntropyLoss + version: mimo + arch: 18 + style: cifar + num_estimators: 4 + rho: 1.0 +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.05 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 25 + - 50 + gamma: 0.1 diff --git a/experiments/classification/cifar10/configs/resnet18/packed.yaml b/experiments/classification/cifar10/configs/resnet18/packed.yaml new file mode 100644 index 00000000..79bd47f3 --- /dev/null +++ b/experiments/classification/cifar10/configs/resnet18/packed.yaml @@ -0,0 +1,51 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 75 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/resnet18 + name: packed + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: cls_val/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: cls_val/Acc + patience: 1000 + check_finite: true +model: + num_classes: 10 + in_channels: 3 + loss: CrossEntropyLoss + version: packed + arch: 18 + style: cifar + num_estimators: 4 + alpha: 2 + gamma: 2 +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.05 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 25 + - 50 + gamma: 0.1 diff --git a/experiments/classification/cifar10/configs/resnet18/standard.yaml b/experiments/classification/cifar10/configs/resnet18/standard.yaml new file mode 100644 index 00000000..b5406a28 --- /dev/null +++ b/experiments/classification/cifar10/configs/resnet18/standard.yaml @@ -0,0 +1,48 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 75 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/resnet18 + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: cls_val/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: cls_val/Acc + patience: 1000 + check_finite: true +model: + num_classes: 10 + in_channels: 3 + loss: CrossEntropyLoss + version: std + arch: 18 + style: cifar +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.05 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 25 + - 50 + gamma: 0.1 diff --git a/experiments/classification/cifar10/configs/resnet50/batched.yaml b/experiments/classification/cifar10/configs/resnet50/batched.yaml new file mode 100644 index 00000000..7133cc5f --- /dev/null +++ b/experiments/classification/cifar10/configs/resnet50/batched.yaml @@ -0,0 +1,50 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 200 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/resnet50 + name: batched + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: cls_val/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: cls_val/Acc + patience: 1000 + check_finite: true +model: + num_classes: 10 + in_channels: 3 + loss: CrossEntropyLoss + version: batched + arch: 50 + style: cifar + num_estimators: 4 +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.08 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 60 + - 120 + - 160 + gamma: 0.2 diff --git a/experiments/classification/cifar10/configs/resnet50/masked.yaml b/experiments/classification/cifar10/configs/resnet50/masked.yaml new file mode 100644 index 00000000..00eaf9c3 --- /dev/null +++ b/experiments/classification/cifar10/configs/resnet50/masked.yaml @@ -0,0 +1,51 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 200 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/resnet50 + name: masked + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: cls_val/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: cls_val/Acc + patience: 1000 + check_finite: true +model: + num_classes: 10 + in_channels: 3 + loss: CrossEntropyLoss + version: masked + arch: 50 + style: cifar + num_estimators: 4 + scale: 2 +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.1 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 60 + - 120 + - 160 + gamma: 0.2 diff --git a/experiments/classification/cifar10/configs/resnet50/mimo.yaml b/experiments/classification/cifar10/configs/resnet50/mimo.yaml new file mode 100644 index 00000000..d7d23ccd --- /dev/null +++ b/experiments/classification/cifar10/configs/resnet50/mimo.yaml @@ -0,0 +1,51 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 200 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/resnet50 + name: mimo + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: cls_val/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: cls_val/Acc + patience: 1000 + check_finite: true +model: + num_classes: 10 + in_channels: 3 + loss: CrossEntropyLoss + version: mimo + arch: 50 + style: cifar + num_estimators: 4 + rho: 1.0 +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.1 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 60 + - 120 + - 160 + gamma: 0.2 diff --git a/experiments/classification/cifar10/configs/resnet50/packed.yaml b/experiments/classification/cifar10/configs/resnet50/packed.yaml new file mode 100644 index 00000000..2ecc4e6a --- /dev/null +++ b/experiments/classification/cifar10/configs/resnet50/packed.yaml @@ -0,0 +1,52 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 200 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/resnet50 + name: packed + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: cls_val/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: cls_val/Acc + patience: 1000 + check_finite: true +model: + num_classes: 10 + in_channels: 3 + loss: CrossEntropyLoss + version: packed + arch: 50 + style: cifar + num_estimators: 4 + alpha: 2 + gamma: 2 +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.1 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 60 + - 120 + - 160 + gamma: 0.2 diff --git a/experiments/classification/cifar10/configs/resnet50/standard.yaml b/experiments/classification/cifar10/configs/resnet50/standard.yaml new file mode 100644 index 00000000..1797df73 --- /dev/null +++ b/experiments/classification/cifar10/configs/resnet50/standard.yaml @@ -0,0 +1,49 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 200 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/resnet50 + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: cls_val/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: cls_val/Acc + patience: 1000 + check_finite: true +model: + num_classes: 10 + in_channels: 3 + loss: CrossEntropyLoss + version: std + arch: 50 + style: cifar +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.1 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 60 + - 120 + - 160 + gamma: 0.2 diff --git a/experiments/classification/cifar10/configs/wideresnet28x10.yaml b/experiments/classification/cifar10/configs/wideresnet28x10.yaml new file mode 100644 index 00000000..fb1bea00 --- /dev/null +++ b/experiments/classification/cifar10/configs/wideresnet28x10.yaml @@ -0,0 +1,35 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 200 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/wideresnet28x10 + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: cls_val/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: cls_val/Acc + patience: 1000 + check_finite: true +model: + num_classes: 10 + in_channels: 3 + loss: CrossEntropyLoss + style: cifar +data: + root: ./data + batch_size: 128 diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/batched.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/batched.yaml new file mode 100644 index 00000000..f4010902 --- /dev/null +++ b/experiments/classification/cifar10/configs/wideresnet28x10/batched.yaml @@ -0,0 +1,49 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 200 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/resnet18 + name: batched + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: cls_val/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: cls_val/Acc + patience: 1000 + check_finite: true +model: + num_classes: 10 + in_channels: 3 + loss: CrossEntropyLoss + version: batched + style: cifar + num_estimators: 4 +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.1 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 60 + - 120 + - 160 + gamma: 0.1 diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/masked.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/masked.yaml new file mode 100644 index 00000000..ae31197b --- /dev/null +++ b/experiments/classification/cifar10/configs/wideresnet28x10/masked.yaml @@ -0,0 +1,50 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 200 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/wideresnet28x10 + name: masked + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: cls_val/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: cls_val/Acc + patience: 1000 + check_finite: true +model: + num_classes: 10 + in_channels: 3 + loss: CrossEntropyLoss + version: masked + style: cifar + num_estimators: 4 + scale: 2 +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.1 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 60 + - 120 + - 160 + gamma: 0.1 diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/mimo.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/mimo.yaml new file mode 100644 index 00000000..31a09775 --- /dev/null +++ b/experiments/classification/cifar10/configs/wideresnet28x10/mimo.yaml @@ -0,0 +1,50 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 200 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/wideresnet28x10 + name: mimo + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: cls_val/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: cls_val/Acc + patience: 1000 + check_finite: true +model: + num_classes: 10 + in_channels: 3 + loss: CrossEntropyLoss + version: mimo + style: cifar + num_estimators: 4 + rho: 1.0 +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.1 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 60 + - 120 + - 160 + gamma: 0.1 diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/packed.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/packed.yaml new file mode 100644 index 00000000..a46c6fac --- /dev/null +++ b/experiments/classification/cifar10/configs/wideresnet28x10/packed.yaml @@ -0,0 +1,51 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 200 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/wideresnet28x10 + name: packed + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: cls_val/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: cls_val/Acc + patience: 1000 + check_finite: true +model: + num_classes: 10 + in_channels: 3 + loss: CrossEntropyLoss + version: packed + style: cifar + num_estimators: 4 + alpha: 2 + gamma: 2 +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.1 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 60 + - 120 + - 160 + gamma: 0.2 diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/standard.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/standard.yaml new file mode 100644 index 00000000..c5cd566f --- /dev/null +++ b/experiments/classification/cifar10/configs/wideresnet28x10/standard.yaml @@ -0,0 +1,48 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 200 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/wideresnet28x10 + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: cls_val/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: cls_val/Acc + patience: 1000 + check_finite: true +model: + num_classes: 10 + in_channels: 3 + loss: CrossEntropyLoss + version: std + style: cifar +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.1 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 60 + - 120 + - 160 + gamma: 0.2 diff --git a/experiments/classification/cifar10/deep_ensembles.py b/experiments/classification/cifar10/deep_ensembles.py index f497092c..d7316811 100644 --- a/experiments/classification/cifar10/deep_ensembles.py +++ b/experiments/classification/cifar10/deep_ensembles.py @@ -1,11 +1,11 @@ from pathlib import Path from torch_uncertainty import cli_main, init_args -from torch_uncertainty.baselines import DeepEnsembles +from torch_uncertainty.baselines import DeepEnsemblesBaseline from torch_uncertainty.datamodules import CIFAR10DataModule if __name__ == "__main__": - args = init_args(DeepEnsembles, CIFAR10DataModule) + args = init_args(DeepEnsemblesBaseline, CIFAR10DataModule) if args.root == "./data/": root = Path(__file__).parent.absolute().parents[2] else: @@ -19,7 +19,7 @@ # model args.task = "classification" - model = DeepEnsembles( + model = DeepEnsemblesBaseline( **vars(args), num_classes=dm.num_classes, in_channels=dm.num_channels, diff --git a/experiments/classification/cifar10/readme.md b/experiments/classification/cifar10/readme.md new file mode 100644 index 00000000..6fdfb043 --- /dev/null +++ b/experiments/classification/cifar10/readme.md @@ -0,0 +1,64 @@ +# CIFAR10 - Benchmark + +This folder contains the code to train models on the CIFAR10 dataset. The task is to classify images into $10$ classes. + +## ResNet-backbone models + +`torch-uncertainty` leverages [LightningCLI](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.cli.LightningCLI.html#lightning.pytorch.cli.LightningCLI) the configurable command line tool for pytorch-lightning. To ease the train of models, we provide a set of predefined configurations for the CIFAR10 dataset (corresponding to the experiments reported in [Packed-Ensembles for Efficient Uncertainty Estimation](https://arxiv.org/abs/2210.09184)). The configurations are located in the `configs` folder. + +*Examples:* + +* Training a standard ResNet18 model as in [Packed-Ensembles for Efficient Uncertainty Estimation](https://arxiv.org/abs/2210.09184): + +```bash +python resnet.py fit --config configs/resnet18/standard.yaml +``` + +* Training Packed-Ensembles ResNet50 model as in [Packed-Ensembles for Efficient Uncertainty Estimation](https://arxiv.org/abs/2210.09184): + +```bash +python resnet.py fit --config configs/resnet50/packed.yaml +``` + + +**Note:** In addition we provide a default resnet config file (`configs/resnet.yaml`) to enable the training of any ResNet model. Here a basic example to train a MIMO ResNet101 model with $4$ estimators and $\rho=1.0$: + +```bash +python resnet.py fit --config configs/resnet.yaml --model.arch 101 --model.version mimo --model.num_estimators 4 --model.rho 1.0 +``` + +## Available configurations: + +### ResNet + +||ResNet18|ResNet34|ResNet50|ResNet101|ResNet152| +|---|---|---|---|---|---| +|Standard|✅|✅|✅|✅|✅| +|Packed-Ensembles|✅|✅|✅|✅|✅| +|BatchEnsemble|✅|✅|✅|✅|✅| +|Masked-Ensembles|✅|✅|✅|✅|✅| +|MIMO|✅|✅|✅|✅|✅| +|MC Dropout|✅|✅|✅|✅|✅| + + +### WideResNet + +||WideResNet28-10| +|---|---| +|Standard|✅| +|Packed-Ensembles|✅| +|BatchEnsemble|✅| +|Masked-Ensembles|✅| +|MIMO|✅| +|MC Dropout|✅| + +### VGG + +||VGG11|VGG13|VGG16|VGG19| +|---|---|---|---|---| +|Standard|✅|✅|✅|✅| +|Packed-Ensembles|✅|✅|✅|✅| +|BatchEnsemble||||| +|Masked-Ensembles||||| +|MIMO||||| +|MC Dropout|✅|✅|✅|✅| diff --git a/experiments/classification/cifar10/resnet.py b/experiments/classification/cifar10/resnet.py index c09f4af9..6deddd4c 100644 --- a/experiments/classification/cifar10/resnet.py +++ b/experiments/classification/cifar10/resnet.py @@ -1,74 +1,27 @@ -from pathlib import Path +import torch +from lightning.pytorch.cli import LightningArgumentParser -from torch import nn - -from torch_uncertainty import cli_main, init_args -from torch_uncertainty.baselines import ResNet +from torch_uncertainty.baselines.classification import ResNetBaseline from torch_uncertainty.datamodules import CIFAR10DataModule -from torch_uncertainty.optimization_procedures import get_procedure -from torch_uncertainty.utils import csv_writer - -if __name__ == "__main__": - args = init_args(ResNet, CIFAR10DataModule) - if args.root == "./data/": - root = Path(__file__).parent.absolute().parents[2] - else: - root = Path(args.root) +from torch_uncertainty.utils import TULightningCLI - if args.exp_name == "": - args.exp_name = f"{args.version}-resnet{args.arch}-cifar10" - # datamodule - args.root = str(root / "data") - dm = CIFAR10DataModule(**vars(args)) +class ResNetCLI(TULightningCLI): + def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + parser.add_optimizer_args(torch.optim.SGD) + parser.add_lr_scheduler_args(torch.optim.lr_scheduler.MultiStepLR) - if args.opt_temp_scaling: - calibration_set = dm.get_test_set - elif args.val_temp_scaling: - calibration_set = dm.get_val_set - else: - calibration_set = None - results = None - if args.use_cv: - list_dm = dm.make_cross_val_splits(args.n_splits, args.train_over) - list_model = [ - ResNet( - num_classes=list_dm[i].dm.num_classes, - in_channels=list_dm[i].dm.num_channels, - loss=nn.CrossEntropyLoss, - optimization_procedure=get_procedure( - f"resnet{args.arch}", "cifar10", args.version - ), - style="cifar", - calibration_set=calibration_set, - **vars(args), - ) - for i in range(len(list_dm)) - ] +def cli_main() -> ResNetCLI: + return ResNetCLI(ResNetBaseline, CIFAR10DataModule) - results = cli_main( - list_model, list_dm, args.exp_dir, args.exp_name, args - ) - else: - # model - model = ResNet( - num_classes=dm.num_classes, - in_channels=dm.num_channels, - loss=nn.CrossEntropyLoss, - optimization_procedure=get_procedure( - f"resnet{args.arch}", "cifar10", args.version - ), - style="cifar", - calibration_set=calibration_set, - **vars(args), - ) - results = cli_main(model, dm, args.exp_dir, args.exp_name, args) - - if results is not None: - for dict_result in results: - csv_writer( - Path(args.exp_dir) / Path(args.exp_name) / "results.csv", - dict_result, - ) +if __name__ == "__main__": + torch.set_float32_matmul_precision("medium") + cli = cli_main() + if ( + (not cli.trainer.fast_dev_run) + and cli.subcommand == "fit" + and cli._get(cli.config, "eval_after_fit") + ): + cli.trainer.test(datamodule=cli.datamodule, ckpt_path="best") diff --git a/experiments/classification/cifar10/vgg.py b/experiments/classification/cifar10/vgg.py index 004a8526..a4614e3a 100644 --- a/experiments/classification/cifar10/vgg.py +++ b/experiments/classification/cifar10/vgg.py @@ -1,36 +1,27 @@ -from pathlib import Path +import torch +from lightning.pytorch.cli import LightningArgumentParser -from torch import nn - -from torch_uncertainty import cli_main, init_args -from torch_uncertainty.baselines import VGG +from torch_uncertainty.baselines.classification import VGGBaseline from torch_uncertainty.datamodules import CIFAR10DataModule -from torch_uncertainty.optimization_procedures import get_procedure +from torch_uncertainty.utils import TULightningCLI -if __name__ == "__main__": - args = init_args(VGG, CIFAR10DataModule) - if args.root == "./data/": - root = Path(__file__).parent.absolute().parents[2] - else: - root = Path(args.root) - if args.exp_name == "": - args.exp_name = f"{args.version}-vgg{args.arch}-cifar10" +class ResNetCLI(TULightningCLI): + def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + parser.add_optimizer_args(torch.optim.Adam) + parser.add_lr_scheduler_args(torch.optim.lr_scheduler.MultiStepLR) - # datamodule - args.root = str(root / "data") - dm = CIFAR10DataModule(**vars(args)) - # model - model = VGG( - num_classes=dm.num_classes, - in_channels=dm.num_channels, - loss=nn.CrossEntropyLoss, - optimization_procedure=get_procedure( - f"vgg{args.arch}", "cifar10", args.version - ), - style="cifar", - **vars(args), - ) +def cli_main() -> ResNetCLI: + return ResNetCLI(VGGBaseline, CIFAR10DataModule) - cli_main(model, dm, args.exp_dir, args.exp_name, args) + +if __name__ == "__main__": + torch.set_float32_matmul_precision("medium") + cli = cli_main() + if ( + (not cli.trainer.fast_dev_run) + and cli.subcommand == "fit" + and cli._get(cli.config, "eval_after_fit") + ): + cli.trainer.test(datamodule=cli.datamodule, ckpt_path="best") diff --git a/experiments/classification/cifar10/wideresnet.py b/experiments/classification/cifar10/wideresnet.py index fb004a87..03870002 100644 --- a/experiments/classification/cifar10/wideresnet.py +++ b/experiments/classification/cifar10/wideresnet.py @@ -1,36 +1,27 @@ -from pathlib import Path +import torch +from lightning.pytorch.cli import LightningArgumentParser -from torch import nn - -from torch_uncertainty import cli_main, init_args -from torch_uncertainty.baselines import WideResNet +from torch_uncertainty.baselines.classification import WideResNetBaseline from torch_uncertainty.datamodules import CIFAR10DataModule -from torch_uncertainty.optimization_procedures import get_procedure +from torch_uncertainty.utils import TULightningCLI -if __name__ == "__main__": - args = init_args(WideResNet, CIFAR10DataModule) - if args.root == "./data/": - root = Path(__file__).parent.absolute().parents[2] - else: - root = Path(args.root) - if args.exp_name == "": - args.exp_name = f"{args.version}-wideresnet28x10-cifar10" +class ResNetCLI(TULightningCLI): + def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + parser.add_optimizer_args(torch.optim.SGD) + parser.add_lr_scheduler_args(torch.optim.lr_scheduler.MultiStepLR) - # datamodule - args.root = str(root / "data") - dm = CIFAR10DataModule(**vars(args)) - # model - model = WideResNet( - num_classes=dm.num_classes, - in_channels=dm.num_channels, - loss=nn.CrossEntropyLoss, - optimization_procedure=get_procedure( - "wideresnet28x10", "cifar10", args.version - ), - style="cifar", - **vars(args), - ) +def cli_main() -> ResNetCLI: + return ResNetCLI(WideResNetBaseline, CIFAR10DataModule) - cli_main(model, dm, args.exp_dir, args.exp_name, args) + +if __name__ == "__main__": + torch.set_float32_matmul_precision("medium") + cli = cli_main() + if ( + (not cli.trainer.fast_dev_run) + and cli.subcommand == "fit" + and cli._get(cli.config, "eval_after_fit") + ): + cli.trainer.test(datamodule=cli.datamodule, ckpt_path="best") diff --git a/experiments/classification/cifar100/configs/resnet.yaml b/experiments/classification/cifar100/configs/resnet.yaml new file mode 100644 index 00000000..d72a2c2b --- /dev/null +++ b/experiments/classification/cifar100/configs/resnet.yaml @@ -0,0 +1,34 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/ + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: cls_val/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: cls_val/Acc + patience: 1000 + check_finite: true +model: + num_classes: 10 + in_channels: 3 + loss: CrossEntropyLoss + style: cifar +data: + root: ./data + batch_size: 128 diff --git a/experiments/classification/cifar100/configs/resnet18/batched.yaml b/experiments/classification/cifar100/configs/resnet18/batched.yaml new file mode 100644 index 00000000..61393563 --- /dev/null +++ b/experiments/classification/cifar100/configs/resnet18/batched.yaml @@ -0,0 +1,49 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 75 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/resnet18 + name: batched + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: cls_val/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: cls_val/Acc + patience: 1000 + check_finite: true +model: + num_classes: 100 + in_channels: 3 + loss: CrossEntropyLoss + version: batched + arch: 18 + style: cifar + num_estimators: 4 +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.05 + momentum: 0.9 + weight_decay: 1e-4 + nesterov: true +lr_scheduler: + milestones: + - 25 + - 50 + gamma: 0.2 diff --git a/experiments/classification/cifar100/configs/resnet18/masked.yaml b/experiments/classification/cifar100/configs/resnet18/masked.yaml new file mode 100644 index 00000000..31f6e2a8 --- /dev/null +++ b/experiments/classification/cifar100/configs/resnet18/masked.yaml @@ -0,0 +1,50 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 75 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/resnet18 + name: masked + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: cls_val/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: cls_val/Acc + patience: 1000 + check_finite: true +model: + num_classes: 100 + in_channels: 3 + loss: CrossEntropyLoss + version: masked + arch: 18 + style: cifar + num_estimators: 4 + scale: 2 +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.05 + momentum: 0.9 + weight_decay: 1e-4 + nesterov: true +lr_scheduler: + milestones: + - 25 + - 50 + gamma: 0.2 diff --git a/experiments/classification/cifar100/configs/resnet18/mimo.yaml b/experiments/classification/cifar100/configs/resnet18/mimo.yaml new file mode 100644 index 00000000..7a3aec17 --- /dev/null +++ b/experiments/classification/cifar100/configs/resnet18/mimo.yaml @@ -0,0 +1,50 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 75 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/resnet18 + name: mimo + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: cls_val/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: cls_val/Acc + patience: 1000 + check_finite: true +model: + num_classes: 100 + in_channels: 3 + loss: CrossEntropyLoss + version: mimo + arch: 18 + style: cifar + num_estimators: 4 + rho: 1.0 +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.05 + momentum: 0.9 + weight_decay: 1e-4 + nesterov: true +lr_scheduler: + milestones: + - 25 + - 50 + gamma: 0.2 diff --git a/experiments/classification/cifar100/configs/resnet18/packed.yaml b/experiments/classification/cifar100/configs/resnet18/packed.yaml new file mode 100644 index 00000000..4e14cce9 --- /dev/null +++ b/experiments/classification/cifar100/configs/resnet18/packed.yaml @@ -0,0 +1,51 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 75 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/resnet18 + name: packed + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: cls_val/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: cls_val/Acc + patience: 1000 + check_finite: true +model: + num_classes: 100 + in_channels: 3 + loss: CrossEntropyLoss + version: packed + arch: 18 + style: cifar + num_estimators: 4 + alpha: 2 + gamma: 2 +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.05 + momentum: 0.9 + weight_decay: 1e-4 + nesterov: true +lr_scheduler: + milestones: + - 25 + - 50 + gamma: 0.2 diff --git a/experiments/classification/cifar100/configs/resnet18/standard.yaml b/experiments/classification/cifar100/configs/resnet18/standard.yaml new file mode 100644 index 00000000..f8e9b821 --- /dev/null +++ b/experiments/classification/cifar100/configs/resnet18/standard.yaml @@ -0,0 +1,48 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 75 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/resnet18 + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: cls_val/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: cls_val/Acc + patience: 1000 + check_finite: true +model: + num_classes: 100 + in_channels: 3 + loss: CrossEntropyLoss + version: standard + arch: 18 + style: cifar +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.05 + momentum: 0.9 + weight_decay: 1e-4 + nesterov: true +lr_scheduler: + milestones: + - 25 + - 50 + gamma: 0.2 diff --git a/experiments/classification/cifar100/configs/resnet50/batched.yaml b/experiments/classification/cifar100/configs/resnet50/batched.yaml new file mode 100644 index 00000000..69259b96 --- /dev/null +++ b/experiments/classification/cifar100/configs/resnet50/batched.yaml @@ -0,0 +1,50 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 200 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/resnet50 + name: batched + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: cls_val/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: cls_val/Acc + patience: 1000 + check_finite: true +model: + num_classes: 100 + in_channels: 3 + loss: CrossEntropyLoss + version: batched + arch: 50 + style: cifar + num_estimators: 4 +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.08 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 60 + - 120 + - 160 + gamma: 0.2 diff --git a/experiments/classification/cifar100/configs/resnet50/masked.yaml b/experiments/classification/cifar100/configs/resnet50/masked.yaml new file mode 100644 index 00000000..a1707666 --- /dev/null +++ b/experiments/classification/cifar100/configs/resnet50/masked.yaml @@ -0,0 +1,51 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 200 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/resnet50 + name: masked + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: cls_val/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: cls_val/Acc + patience: 1000 + check_finite: true +model: + num_classes: 100 + in_channels: 3 + loss: CrossEntropyLoss + version: masked + arch: 50 + style: cifar + num_estimators: 4 + scale: 2 +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.1 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 60 + - 120 + - 160 + gamma: 0.2 diff --git a/experiments/classification/cifar100/configs/resnet50/mimo.yaml b/experiments/classification/cifar100/configs/resnet50/mimo.yaml new file mode 100644 index 00000000..987a632d --- /dev/null +++ b/experiments/classification/cifar100/configs/resnet50/mimo.yaml @@ -0,0 +1,51 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 200 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/resnet50 + name: mimo + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: cls_val/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: cls_val/Acc + patience: 1000 + check_finite: true +model: + num_classes: 100 + in_channels: 3 + loss: CrossEntropyLoss + version: mimo + arch: 50 + style: cifar + num_estimators: 4 + rho: 1.0 +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.1 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 60 + - 120 + - 160 + gamma: 0.2 diff --git a/experiments/classification/cifar100/configs/resnet50/packed.yaml b/experiments/classification/cifar100/configs/resnet50/packed.yaml new file mode 100644 index 00000000..954caf11 --- /dev/null +++ b/experiments/classification/cifar100/configs/resnet50/packed.yaml @@ -0,0 +1,52 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 200 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/resnet50 + name: packed + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: cls_val/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: cls_val/Acc + patience: 1000 + check_finite: true +model: + num_classes: 100 + in_channels: 3 + loss: CrossEntropyLoss + version: packed + arch: 50 + style: cifar + num_estimators: 4 + alpha: 2 + gamma: 2 +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.1 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 60 + - 120 + - 160 + gamma: 0.2 diff --git a/experiments/classification/cifar100/configs/resnet50/standard.yaml b/experiments/classification/cifar100/configs/resnet50/standard.yaml new file mode 100644 index 00000000..575b6e6f --- /dev/null +++ b/experiments/classification/cifar100/configs/resnet50/standard.yaml @@ -0,0 +1,49 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 200 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/resnet50 + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: cls_val/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: cls_val/Acc + patience: 1000 + check_finite: true +model: + num_classes: 100 + in_channels: 3 + loss: CrossEntropyLoss + version: standard + arch: 50 + style: cifar +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.1 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 60 + - 120 + - 160 + gamma: 0.2 diff --git a/experiments/classification/cifar100/deep_ensembles.py b/experiments/classification/cifar100/deep_ensembles.py index 69a419a8..3a1ed65f 100644 --- a/experiments/classification/cifar100/deep_ensembles.py +++ b/experiments/classification/cifar100/deep_ensembles.py @@ -1,11 +1,11 @@ from pathlib import Path from torch_uncertainty import cli_main, init_args -from torch_uncertainty.baselines import DeepEnsembles +from torch_uncertainty.baselines import DeepEnsemblesBaseline from torch_uncertainty.datamodules import CIFAR100DataModule if __name__ == "__main__": - args = init_args(DeepEnsembles, CIFAR100DataModule) + args = init_args(DeepEnsemblesBaseline, CIFAR100DataModule) if args.root == "./data/": root = Path(__file__).parent.absolute().parents[2] else: @@ -19,7 +19,7 @@ # model args.task = "classification" - model = DeepEnsembles( + model = DeepEnsemblesBaseline( **vars(args), num_classes=dm.num_classes, in_channels=dm.num_channels, diff --git a/experiments/classification/cifar100/readme.md b/experiments/classification/cifar100/readme.md new file mode 100644 index 00000000..5bbe475a --- /dev/null +++ b/experiments/classification/cifar100/readme.md @@ -0,0 +1,3 @@ +# CIFAR100 - Benchmark + +TODO diff --git a/experiments/classification/cifar100/resnet.py b/experiments/classification/cifar100/resnet.py index 19d9ea36..0c3a0068 100644 --- a/experiments/classification/cifar100/resnet.py +++ b/experiments/classification/cifar100/resnet.py @@ -1,36 +1,27 @@ -from pathlib import Path +import torch +from lightning.pytorch.cli import LightningArgumentParser -from torch import nn - -from torch_uncertainty import cli_main, init_args -from torch_uncertainty.baselines import ResNet +from torch_uncertainty.baselines.classification import ResNetBaseline from torch_uncertainty.datamodules import CIFAR100DataModule -from torch_uncertainty.optimization_procedures import get_procedure +from torch_uncertainty.utils import TULightningCLI -if __name__ == "__main__": - args = init_args(ResNet, CIFAR100DataModule) - if args.root == "./data/": - root = Path(__file__).parent.absolute().parents[2] - else: - root = Path(args.root) - if args.exp_name == "": - args.exp_name = f"{args.version}-resnet{args.arch}-cifar100" +class ResNetCLI(TULightningCLI): + def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + parser.add_optimizer_args(torch.optim.SGD) + parser.add_lr_scheduler_args(torch.optim.lr_scheduler.MultiStepLR) - # datamodule - args.root = str(root / "data") - dm = CIFAR100DataModule(**vars(args)) - # model - model = ResNet( - num_classes=dm.num_classes, - in_channels=dm.num_channels, - loss=nn.CrossEntropyLoss, - optimization_procedure=get_procedure( - f"resnet{args.arch}", "cifar100", args.version - ), - style="cifar", - **vars(args), - ) +def cli_main() -> ResNetCLI: + return ResNetCLI(ResNetBaseline, CIFAR100DataModule) - cli_main(model, dm, args.exp_dir, args.exp_name, args) + +if __name__ == "__main__": + torch.set_float32_matmul_precision("medium") + cli = cli_main() + if ( + (not cli.trainer.fast_dev_run) + and cli.subcommand == "fit" + and cli._get(cli.config, "eval_after_fit") + ): + cli.trainer.test(datamodule=cli.datamodule, ckpt_path="best") diff --git a/experiments/classification/cifar100/vgg.py b/experiments/classification/cifar100/vgg.py index f218f83c..1936f809 100644 --- a/experiments/classification/cifar100/vgg.py +++ b/experiments/classification/cifar100/vgg.py @@ -1,36 +1,27 @@ -from pathlib import Path +import torch +from lightning.pytorch.cli import LightningArgumentParser -from torch import nn - -from torch_uncertainty import cli_main, init_args -from torch_uncertainty.baselines import VGG +from torch_uncertainty.baselines.classification import VGGBaseline from torch_uncertainty.datamodules import CIFAR100DataModule -from torch_uncertainty.optimization_procedures import get_procedure +from torch_uncertainty.utils import TULightningCLI -if __name__ == "__main__": - args = init_args(VGG, CIFAR100DataModule) - if args.root == "./data/": - root = Path(__file__).parent.absolute().parents[2] - else: - root = Path(args.root) - if args.exp_name == "": - args.exp_name = f"{args.version}-vgg{args.arch}-cifar100" +class ResNetCLI(TULightningCLI): + def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + parser.add_optimizer_args(torch.optim.Adam) + parser.add_lr_scheduler_args(torch.optim.lr_scheduler.MultiStepLR) - # datamodule - args.root = str(root / "data") - dm = CIFAR100DataModule(**vars(args)) - # model - model = VGG( - num_classes=dm.num_classes, - in_channels=dm.num_channels, - loss=nn.CrossEntropyLoss, - optimization_procedure=get_procedure( - f"vgg{args.arch}", "cifar100", args.version - ), - style="cifar", - **vars(args), - ) +def cli_main() -> ResNetCLI: + return ResNetCLI(VGGBaseline, CIFAR100DataModule) - cli_main(model, dm, args.exp_dir, args.exp_name, args) + +if __name__ == "__main__": + torch.set_float32_matmul_precision("medium") + cli = cli_main() + if ( + (not cli.trainer.fast_dev_run) + and cli.subcommand == "fit" + and cli._get(cli.config, "eval_after_fit") + ): + cli.trainer.test(datamodule=cli.datamodule, ckpt_path="best") diff --git a/experiments/classification/cifar100/wideresnet.py b/experiments/classification/cifar100/wideresnet.py index 729f07ba..49b9a227 100644 --- a/experiments/classification/cifar100/wideresnet.py +++ b/experiments/classification/cifar100/wideresnet.py @@ -1,36 +1,27 @@ -from pathlib import Path +import torch +from lightning.pytorch.cli import LightningArgumentParser -from torch import nn - -from torch_uncertainty import cli_main, init_args -from torch_uncertainty.baselines import WideResNet +from torch_uncertainty.baselines.classification import WideResNetBaseline from torch_uncertainty.datamodules import CIFAR100DataModule -from torch_uncertainty.optimization_procedures import get_procedure +from torch_uncertainty.utils import TULightningCLI -if __name__ == "__main__": - args = init_args(WideResNet, CIFAR100DataModule) - if args.root == "./data/": - root = Path(__file__).parent.absolute().parents[2] - else: - root = Path(args.root) - if args.exp_name == "": - args.exp_name = f"{args.version}-wideresnet28x10-cifar100" +class ResNetCLI(TULightningCLI): + def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + parser.add_optimizer_args(torch.optim.SGD) + parser.add_lr_scheduler_args(torch.optim.lr_scheduler.MultiStepLR) - # datamodule - args.root = str(root / "data") - dm = CIFAR100DataModule(**vars(args)) - # model - model = WideResNet( - num_classes=dm.num_classes, - in_channels=dm.num_channels, - loss=nn.CrossEntropyLoss, - optimization_procedure=get_procedure( - "wideresnet28x10", "cifar100", args.version - ), - style="cifar", - **vars(args), - ) +def cli_main() -> ResNetCLI: + return ResNetCLI(WideResNetBaseline, CIFAR100DataModule) - cli_main(model, dm, args.exp_dir, args.exp_name, args) + +if __name__ == "__main__": + torch.set_float32_matmul_precision("medium") + cli = cli_main() + if ( + (not cli.trainer.fast_dev_run) + and cli.subcommand == "fit" + and cli._get(cli.config, "eval_after_fit") + ): + cli.trainer.test(datamodule=cli.datamodule, ckpt_path="best") diff --git a/experiments/classification/mnist/bayesian_lenet.py b/experiments/classification/mnist/bayesian_lenet.py index 4789407c..05a7c17e 100644 --- a/experiments/classification/mnist/bayesian_lenet.py +++ b/experiments/classification/mnist/bayesian_lenet.py @@ -11,7 +11,7 @@ def optim_lenet(model: nn.Module) -> dict: - """Optimization procedure for LeNet. + """Optimization recipe for LeNet. Uses Adam default hyperparameters. @@ -45,7 +45,7 @@ def optim_lenet(model: nn.Module) -> dict: # hyperparameters are from blitz. loss = partial( ELBOLoss, - criterion=nn.CrossEntropyLoss(), + inner_loss=nn.CrossEntropyLoss(), kl_weight=1 / 50000, num_samples=3, ) @@ -55,7 +55,7 @@ def optim_lenet(model: nn.Module) -> dict: num_classes=dm.num_classes, in_channels=dm.num_channels, loss=loss, - optimization_procedure=optim_lenet, + optim_recipe=optim_lenet, **vars(args), ) diff --git a/experiments/classification/mnist/lenet.py b/experiments/classification/mnist/lenet.py index 0514c892..450f72c2 100644 --- a/experiments/classification/mnist/lenet.py +++ b/experiments/classification/mnist/lenet.py @@ -9,7 +9,7 @@ def optim_lenet(model: nn.Module) -> dict: - """Optimization procedure for LeNet. + """Optimization recipe for LeNet. Uses Adam default hyperparameters. @@ -44,8 +44,8 @@ def optim_lenet(model: nn.Module) -> dict: model=model, num_classes=dm.num_classes, in_channels=dm.num_channels, - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_lenet, + loss=nn.CrossEntropyLoss(), + optim_recipe=optim_lenet, **vars(args), ) diff --git a/experiments/classification/README.md b/experiments/classification/readme.md similarity index 100% rename from experiments/classification/README.md rename to experiments/classification/readme.md diff --git a/experiments/classification/tiny-imagenet/resnet.py b/experiments/classification/tiny-imagenet/resnet.py index 390223eb..e003ae84 100644 --- a/experiments/classification/tiny-imagenet/resnet.py +++ b/experiments/classification/tiny-imagenet/resnet.py @@ -3,9 +3,9 @@ from torch import nn, optim from torch_uncertainty import cli_main, init_args -from torch_uncertainty.baselines import ResNet +from torch_uncertainty.baselines import ResNetBaseline from torch_uncertainty.datamodules import TinyImageNetDataModule -from torch_uncertainty.optimization_procedures import get_procedure +from torch_uncertainty.optim_recipes import get_procedure from torch_uncertainty.utils import csv_writer @@ -22,13 +22,12 @@ def optim_tiny(model: nn.Module) -> dict: if __name__ == "__main__": - args = init_args(ResNet, TinyImageNetDataModule) + args = init_args(ResNetBaseline, TinyImageNetDataModule) if args.root == "./data/": root = Path(__file__).parent.absolute().parents[2] else: root = Path(args.root) - # net_name = f"{args.version}-resnet{args.arch}-tiny-imagenet" if args.exp_name == "": args.exp_name = f"{args.version}-resnet{args.arch}-tinyimagenet" @@ -46,11 +45,11 @@ def optim_tiny(model: nn.Module) -> dict: if args.use_cv: list_dm = dm.make_cross_val_splits(args.n_splits, args.train_over) list_model = [ - ResNet( + ResNetBaseline( num_classes=list_dm[i].dm.num_classes, in_channels=list_dm[i].dm.num_channels, - loss=nn.CrossEntropyLoss, - optimization_procedure=get_procedure( + loss=nn.CrossEntropyLoss(), + optim_recipe=get_procedure( f"resnet{args.arch}", "tiny-imagenet", args.version ), style="cifar", @@ -65,11 +64,11 @@ def optim_tiny(model: nn.Module) -> dict: ) else: # model - model = ResNet( + model = ResNetBaseline( num_classes=dm.num_classes, in_channels=dm.num_channels, - loss=nn.CrossEntropyLoss, - optimization_procedure=get_procedure( + loss=nn.CrossEntropyLoss(), + optim_recipe=get_procedure( f"resnet{args.arch}", "tiny-imagenet", args.version ), calibration_set=calibration_set, diff --git a/experiments/readme.md b/experiments/readme.md new file mode 100644 index 00000000..0035c5a7 --- /dev/null +++ b/experiments/readme.md @@ -0,0 +1,19 @@ +# Experiments + +Torch-Uncertainty proposes various benchmarks to evaluate uncertainty quantification methods. + +## Classification + +*Work in progress* + +## Segmentation + +*Work in progress* + +## Regression + +*Work in progress* + +## Monocular Depth Estimation + +*Work in progress* diff --git a/experiments/regression/uci_datasets/configs/gaussian_mlp_kin8nm.yaml b/experiments/regression/uci_datasets/configs/gaussian_mlp_kin8nm.yaml new file mode 100644 index 00000000..2e9b056d --- /dev/null +++ b/experiments/regression/uci_datasets/configs/gaussian_mlp_kin8nm.yaml @@ -0,0 +1,43 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 10 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/gaussian_mlp_kin8nm + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: reg_val/NLL + mode: min + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: reg_val/NLL + patience: 1000 + check_finite: true +model: + output_dim: 1 + in_features: 8 + hidden_dims: + - 100 + loss: torch_uncertainty.losses.DistributionNLLLoss + version: std + distribution: normal +data: + root: ./data + batch_size: 128 + dataset_name: kin8nm +optimizer: + lr: 5e-3 + weight_decay: 0 diff --git a/experiments/regression/uci_datasets/configs/laplace_mlp_kin8nm.yaml b/experiments/regression/uci_datasets/configs/laplace_mlp_kin8nm.yaml new file mode 100644 index 00000000..d95e09a1 --- /dev/null +++ b/experiments/regression/uci_datasets/configs/laplace_mlp_kin8nm.yaml @@ -0,0 +1,43 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 10 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/gaussian_mlp_kin8nm + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: reg_val/NLL + mode: min + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: reg_val/NLL + patience: 1000 + check_finite: true +model: + output_dim: 1 + in_features: 8 + hidden_dims: + - 100 + loss: torch_uncertainty.losses.DistributionNLLLoss + version: std + distribution: laplace +data: + root: ./data + batch_size: 128 + dataset_name: kin8nm +optimizer: + lr: 5e-3 + weight_decay: 0 diff --git a/experiments/regression/uci_datasets/configs/pw_mlp_kin8nm.yaml b/experiments/regression/uci_datasets/configs/pw_mlp_kin8nm.yaml new file mode 100644 index 00000000..b6ce9fad --- /dev/null +++ b/experiments/regression/uci_datasets/configs/pw_mlp_kin8nm.yaml @@ -0,0 +1,42 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 10 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/pw_mlp_kin8nm + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: reg_val/MSE + mode: min + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: reg_val/MSE + patience: 1000 + check_finite: true +model: + output_dim: 1 + in_features: 8 + hidden_dims: + - 100 + loss: MSELoss + version: std +data: + root: ./data + batch_size: 128 + dataset_name: kin8nm +optimizer: + lr: 5e-3 + weight_decay: 0 diff --git a/experiments/regression/uci_datasets/deep_ensemble.py b/experiments/regression/uci_datasets/deep_ensemble.py index ef7217f3..6628e6e6 100644 --- a/experiments/regression/uci_datasets/deep_ensemble.py +++ b/experiments/regression/uci_datasets/deep_ensemble.py @@ -1,11 +1,11 @@ from pathlib import Path from torch_uncertainty import cli_main, init_args -from torch_uncertainty.baselines import DeepEnsembles +from torch_uncertainty.baselines import DeepEnsemblesBaseline from torch_uncertainty.datamodules import UCIDataModule if __name__ == "__main__": - args = init_args(DeepEnsembles, UCIDataModule) + args = init_args(DeepEnsemblesBaseline, UCIDataModule) if args.root == "./data/": root = Path(__file__).parent.absolute().parents[2] else: @@ -19,7 +19,7 @@ # model args.task = "regression" - model = DeepEnsembles( + model = DeepEnsemblesBaseline( **vars(args), ) diff --git a/experiments/regression/uci_datasets/mlp-kin8nm.py b/experiments/regression/uci_datasets/mlp-kin8nm.py deleted file mode 100644 index d96979a6..00000000 --- a/experiments/regression/uci_datasets/mlp-kin8nm.py +++ /dev/null @@ -1,48 +0,0 @@ -from pathlib import Path - -from torch import nn, optim - -from torch_uncertainty import cli_main, init_args -from torch_uncertainty.baselines.regression.mlp import MLP -from torch_uncertainty.datamodules import UCIDataModule - - -def optim_regression( - model: nn.Module, - learning_rate: float = 5e-3, -) -> dict: - optimizer = optim.Adam( - model.parameters(), - lr=learning_rate, - weight_decay=0, - ) - return { - "optimizer": optimizer, - } - - -if __name__ == "__main__": - args = init_args(MLP, UCIDataModule) - if args.root == "./data/": - root = Path(__file__).parent.absolute().parents[2] - else: - root = Path(args.root) - - net_name = "mlp-kin8nm" - - # datamodule - args.root = str(root / "data") - dm = UCIDataModule(dataset_name="kin8nm", **vars(args)) - - # model - model = MLP( - num_outputs=2, - in_features=8, - hidden_dims=[100], - loss=nn.GaussianNLLLoss, - optimization_procedure=optim_regression, - dist_estimation=2, - **vars(args), - ) - - cli_main(model, dm, root, net_name, args) diff --git a/experiments/regression/uci_datasets/mlp.py b/experiments/regression/uci_datasets/mlp.py new file mode 100644 index 00000000..a0605472 --- /dev/null +++ b/experiments/regression/uci_datasets/mlp.py @@ -0,0 +1,26 @@ +import torch +from lightning.pytorch.cli import LightningArgumentParser + +from torch_uncertainty.baselines.regression import MLPBaseline +from torch_uncertainty.datamodules import UCIDataModule +from torch_uncertainty.utils import TULightningCLI + + +class MLPCLI(TULightningCLI): + def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + parser.add_optimizer_args(torch.optim.Adam) + + +def cli_main() -> MLPCLI: + return MLPCLI(MLPBaseline, UCIDataModule) + + +if __name__ == "__main__": + torch.set_float32_matmul_precision("medium") + cli = cli_main() + if ( + (not cli.trainer.fast_dev_run) + and cli.subcommand == "fit" + and cli._get(cli.config, "eval_after_fit") + ): + cli.trainer.test(datamodule=cli.datamodule, ckpt_path="best") diff --git a/experiments/regression/uci_datasets/readme.md b/experiments/regression/uci_datasets/readme.md new file mode 100644 index 00000000..3e0ec7b0 --- /dev/null +++ b/experiments/regression/uci_datasets/readme.md @@ -0,0 +1,17 @@ +# UCI Regression - Benchmark + +This folder contains the code to train models on the UCI regression datasets. The task is to predict (a) continuous target variable(s). + +Three experiments are provided: + +```bash +python mlp.py fit --config configs/pw_mlp_kin8nm.yaml +``` + +```bash +python mlp.py fit --config configs/gaussian_mlp_kin8nm.yaml +``` + +```bash +python mlp.py fit --config configs/laplace_mlp_kin8nm.yaml +``` diff --git a/experiments/segmentation/camvid/configs/segformer.yaml b/experiments/segmentation/camvid/configs/segformer.yaml new file mode 100644 index 00000000..7cbb001b --- /dev/null +++ b/experiments/segmentation/camvid/configs/segformer.yaml @@ -0,0 +1,22 @@ +# lightning.pytorch==2.1.3 +eval_after_fit: true +seed_everything: false +trainer: + accelerator: gpu + devices: 1 +model: + num_classes: 12 + loss: CrossEntropyLoss + version: std + arch: 0 + num_estimators: 1 +data: + root: ./data + batch_size: 16 + num_workers: 20 +optimizer: + lr: 0.01 +lr_scheduler: + milestones: + - 30 + gamma: 0.1 diff --git a/experiments/segmentation/camvid/segformer.py b/experiments/segmentation/camvid/segformer.py new file mode 100644 index 00000000..8eecfb50 --- /dev/null +++ b/experiments/segmentation/camvid/segformer.py @@ -0,0 +1,27 @@ +import torch +from lightning.pytorch.cli import LightningArgumentParser + +from torch_uncertainty.baselines.segmentation import SegFormerBaseline +from torch_uncertainty.datamodules.segmentation import CamVidDataModule +from torch_uncertainty.utils import TULightningCLI + + +class SegFormerCLI(TULightningCLI): + def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + parser.add_optimizer_args(torch.optim.SGD) + parser.add_lr_scheduler_args(torch.optim.lr_scheduler.MultiStepLR) + + +def cli_main() -> SegFormerCLI: + return SegFormerCLI(SegFormerBaseline, CamVidDataModule) + + +if __name__ == "__main__": + torch.set_float32_matmul_precision("medium") + cli = cli_main() + if ( + (not cli.trainer.fast_dev_run) + and cli.subcommand == "fit" + and cli._get(cli.config, "eval_after_fit") + ): + cli.trainer.test(datamodule=cli.datamodule, ckpt_path="best") diff --git a/experiments/segmentation/cityscapes/configs/segformer.yaml b/experiments/segmentation/cityscapes/configs/segformer.yaml new file mode 100644 index 00000000..b2abf11e --- /dev/null +++ b/experiments/segmentation/cityscapes/configs/segformer.yaml @@ -0,0 +1,26 @@ +# lightning.pytorch==2.2.0 +eval_after_fit: true +seed_everything: false +trainer: + accelerator: gpu + devices: 1 + max_steps: 160000 +model: + num_classes: 19 + loss: CrossEntropyLoss + version: std + arch: 0 + num_estimators: 1 +data: + root: ./data + batch_size: 8 + crop_size: 1024 + inference_size: + - 1024 + - 2048 + num_workers: 30 +optimizer: + lr: 6e-5 +lr_scheduler: + step_size: 10000 + gamma: 0.1 diff --git a/experiments/segmentation/cityscapes/segformer.py b/experiments/segmentation/cityscapes/segformer.py new file mode 100644 index 00000000..2b7fe992 --- /dev/null +++ b/experiments/segmentation/cityscapes/segformer.py @@ -0,0 +1,27 @@ +import torch +from lightning.pytorch.cli import LightningArgumentParser + +from torch_uncertainty.baselines.segmentation import SegFormerBaseline +from torch_uncertainty.datamodules.segmentation import CityscapesDataModule +from torch_uncertainty.utils import TULightningCLI + + +class SegFormerCLI(TULightningCLI): + def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + parser.add_optimizer_args(torch.optim.AdamW) + parser.add_lr_scheduler_args(torch.optim.lr_scheduler.StepLR) + + +def cli_main() -> SegFormerCLI: + return SegFormerCLI(SegFormerBaseline, CityscapesDataModule) + + +if __name__ == "__main__": + torch.set_float32_matmul_precision("medium") + cli = cli_main() + if ( + (not cli.trainer.fast_dev_run) + and cli.subcommand == "fit" + and cli._get(cli.config, "eval_after_fit") + ): + cli.trainer.test(datamodule=cli.datamodule, ckpt_path="best") diff --git a/experiments/segmentation/muad/configs/segformer.yaml b/experiments/segmentation/muad/configs/segformer.yaml new file mode 100644 index 00000000..b2abf11e --- /dev/null +++ b/experiments/segmentation/muad/configs/segformer.yaml @@ -0,0 +1,26 @@ +# lightning.pytorch==2.2.0 +eval_after_fit: true +seed_everything: false +trainer: + accelerator: gpu + devices: 1 + max_steps: 160000 +model: + num_classes: 19 + loss: CrossEntropyLoss + version: std + arch: 0 + num_estimators: 1 +data: + root: ./data + batch_size: 8 + crop_size: 1024 + inference_size: + - 1024 + - 2048 + num_workers: 30 +optimizer: + lr: 6e-5 +lr_scheduler: + step_size: 10000 + gamma: 0.1 diff --git a/experiments/segmentation/muad/segformer.py b/experiments/segmentation/muad/segformer.py new file mode 100644 index 00000000..67ad9564 --- /dev/null +++ b/experiments/segmentation/muad/segformer.py @@ -0,0 +1,27 @@ +import torch +from lightning.pytorch.cli import LightningArgumentParser + +from torch_uncertainty.baselines.segmentation import SegFormerBaseline +from torch_uncertainty.datamodules.segmentation import MUADDataModule +from torch_uncertainty.utils import TULightningCLI + + +class SegFormerCLI(TULightningCLI): + def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + parser.add_optimizer_args(torch.optim.AdamW) + parser.add_lr_scheduler_args(torch.optim.lr_scheduler.StepLR) + + +def cli_main() -> SegFormerCLI: + return SegFormerCLI(SegFormerBaseline, MUADDataModule) + + +if __name__ == "__main__": + torch.set_float32_matmul_precision("medium") + cli = cli_main() + if ( + (not cli.trainer.fast_dev_run) + and cli.subcommand == "fit" + and cli._get(cli.config, "eval_after_fit") + ): + cli.trainer.test(datamodule=cli.datamodule, ckpt_path="best") diff --git a/experiments/segmentation/readme.md b/experiments/segmentation/readme.md new file mode 100644 index 00000000..e8ef0698 --- /dev/null +++ b/experiments/segmentation/readme.md @@ -0,0 +1 @@ +# Segmentation Benchmarks diff --git a/pyproject.toml b/pyproject.toml index dbba42a5..4d003ce5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi" [project] name = "torch_uncertainty" -version = "0.1.6" +version = "0.2.0" authors = [ { name = "ENSTA U2IS", email = "olivier.laurent@ensta-paris.fr" }, { name = "Adrien Lafage", email = "adrienlafage@outlook.com" }, @@ -32,7 +32,8 @@ classifiers = [ ] dependencies = [ "timm", - "pytorch-lightning<2", + "lightning[pytorch-extra]", + "torchvision>=0.16", "tensorboard", "einops", "torchinfo", @@ -45,11 +46,10 @@ dependencies = [ [project.optional-dependencies] dev = [ - "ruff", + "ruff==0.3.4", "pytest-cov", "pre-commit", "pre-commit-hooks", - "cli-test-helpers", ] docs = [ "sphinx<6", @@ -74,11 +74,14 @@ name = "torch_uncertainty" [tool.ruff] line-length = 80 target-version = "py310" -extend-select = [ +lint.extend-select = [ "A", + "ARG", "B", "C4", "D", + "ERA", + "F", "G", "I", "ISC", @@ -88,19 +91,20 @@ extend-select = [ "PIE", "PTH", "PYI", + "Q", "RET", "RUF", "RSE", "S", "SIM", - "UP", "TCH", "TID", "TRY", + "UP", "YTT", ] -ignore = [ - "B017", +lint.ignore = [ + "ARG002", "D100", "D101", "D102", @@ -142,7 +146,7 @@ exclude = [ "venv", ] -[tool.ruff.pydocstyle] +[tool.ruff.lint.pydocstyle] convention = "google" [tool.coverage.run] diff --git a/tests/_dummies/__init__.py b/tests/_dummies/__init__.py index db6a25de..ac5d4d0d 100644 --- a/tests/_dummies/__init__.py +++ b/tests/_dummies/__init__.py @@ -1,6 +1,18 @@ # ruff: noqa: F401 -from .baseline import DummyClassificationBaseline, DummyRegressionBaseline -from .datamodule import DummyClassificationDataModule, DummyRegressionDataModule -from .dataset import DummyClassificationDataset, DummyRegressionDataset +from .baseline import ( + DummyClassificationBaseline, + DummyRegressionBaseline, + DummySegmentationBaseline, +) +from .datamodule import ( + DummyClassificationDataModule, + DummyRegressionDataModule, + DummySegmentationDataModule, +) +from .dataset import ( + DummyClassificationDataset, + DummyRegressionDataset, + DummySegmentationDataset, +) from .model import dummy_model from .transform import DummyTransform diff --git a/tests/_dummies/baseline.py b/tests/_dummies/baseline.py index 47a375c4..b650f180 100644 --- a/tests/_dummies/baseline.py +++ b/tests/_dummies/baseline.py @@ -1,20 +1,22 @@ -from argparse import ArgumentParser -from typing import Any +import copy from pytorch_lightning import LightningModule from torch import nn -from torch_uncertainty.routines.classification import ( - ClassificationEnsemble, - ClassificationSingle, +from torch_uncertainty.layers.distributions import ( + LaplaceLayer, + NormalInverseGammaLayer, + NormalLayer, ) -from torch_uncertainty.routines.regression import ( - RegressionEnsemble, - RegressionSingle, +from torch_uncertainty.models.deep_ensembles import deep_ensembles +from torch_uncertainty.routines import ( + ClassificationRoutine, + RegressionRoutine, + SegmentationRoutine, ) from torch_uncertainty.transforms import RepeatTarget -from .model import dummy_model +from .model import dummy_model, dummy_segmentation_model class DummyClassificationBaseline: @@ -23,91 +25,165 @@ def __new__( num_classes: int, in_channels: int, loss: type[nn.Module], - optimization_procedure: Any, baseline_type: str = "single", + optim_recipe=None, with_feats: bool = True, with_linear: bool = True, - **kwargs, + ood_criterion: str = "msp", + eval_ood: bool = False, + eval_grouping_loss: bool = False, + calibrate: bool = False, + save_in_csv: bool = False, + mixtype: str = "erm", + mixmode: str = "elem", + dist_sim: str = "emb", + kernel_tau_max: float = 1, + kernel_tau_std: float = 0.5, + mixup_alpha: float = 0, + cutmix_alpha: float = 0, ) -> LightningModule: model = dummy_model( in_channels=in_channels, num_classes=num_classes, - num_estimators=1 + int(baseline_type == "ensemble"), with_feats=with_feats, with_linear=with_linear, ) if baseline_type == "single": - return ClassificationSingle( + return ClassificationRoutine( num_classes=num_classes, model=model, loss=loss, - optimization_procedure=optimization_procedure, format_batch_fn=nn.Identity(), log_plots=True, - **kwargs, + optim_recipe=optim_recipe(model), + num_estimators=1, + mixtype=mixtype, + mixmode=mixmode, + dist_sim=dist_sim, + kernel_tau_max=kernel_tau_max, + kernel_tau_std=kernel_tau_std, + mixup_alpha=mixup_alpha, + cutmix_alpha=cutmix_alpha, + ood_criterion=ood_criterion, + eval_ood=eval_ood, + eval_grouping_loss=eval_grouping_loss, + calibration_set="val" if calibrate else None, + save_in_csv=save_in_csv, ) # baseline_type == "ensemble": - kwargs["num_estimators"] = 2 - return ClassificationEnsemble( + model = deep_ensembles( + [model, copy.deepcopy(model)], + task="classification", + ) + return ClassificationRoutine( num_classes=num_classes, model=model, loss=loss, - optimization_procedure=optimization_procedure, + optim_recipe=optim_recipe(model), format_batch_fn=RepeatTarget(2), log_plots=True, - **kwargs, + num_estimators=2, + ood_criterion=ood_criterion, + eval_ood=eval_ood, + eval_grouping_loss=eval_grouping_loss, + calibration_set="val" if calibrate else None, + save_in_csv=save_in_csv, ) - @classmethod - def add_model_specific_args( - cls, - parser: ArgumentParser, - ) -> ArgumentParser: - return ClassificationEnsemble.add_model_specific_args(parser) - class DummyRegressionBaseline: def __new__( cls, + probabilistic: bool, in_features: int, - out_features: int, + output_dim: int, loss: type[nn.Module], - optimization_procedure: Any, baseline_type: str = "single", - dist_estimation: int = 1, - **kwargs, + optim_recipe=None, + dist_type: str = "normal", ) -> LightningModule: + if probabilistic: + if dist_type == "normal": + last_layer = NormalLayer(output_dim) + num_classes = output_dim * 2 + elif dist_type == "laplace": + last_layer = LaplaceLayer(output_dim) + num_classes = output_dim * 2 + else: # dist_type == "nig" + last_layer = NormalInverseGammaLayer(output_dim) + num_classes = output_dim * 4 + else: + last_layer = nn.Identity() + num_classes = output_dim + model = dummy_model( in_channels=in_features, - num_classes=out_features, - num_estimators=1 + int(baseline_type == "ensemble"), + num_classes=num_classes, + last_layer=last_layer, ) - if baseline_type == "single": - return RegressionSingle( - out_features=out_features, + return RegressionRoutine( + probabilistic=probabilistic, + output_dim=output_dim, model=model, loss=loss, - optimization_procedure=optimization_procedure, - dist_estimation=dist_estimation, - **kwargs, + num_estimators=1, + optim_recipe=optim_recipe(model), ) # baseline_type == "ensemble": - kwargs["num_estimators"] = 2 - return RegressionEnsemble( + model = deep_ensembles( + [model, copy.deepcopy(model)], + task="regression", + probabilistic=probabilistic, + ) + return RegressionRoutine( + probabilistic=probabilistic, + output_dim=output_dim, model=model, loss=loss, - optimization_procedure=optimization_procedure, - dist_estimation=dist_estimation, - mode="mean", - out_features=out_features, - **kwargs, + num_estimators=2, + optim_recipe=optim_recipe(model), + format_batch_fn=RepeatTarget(2), ) - @classmethod - def add_model_specific_args( + +class DummySegmentationBaseline: + def __new__( cls, - parser: ArgumentParser, - ) -> ArgumentParser: - return ClassificationEnsemble.add_model_specific_args(parser) + in_channels: int, + num_classes: int, + image_size: int, + loss: type[nn.Module], + baseline_type: str = "single", + optim_recipe=None, + ) -> LightningModule: + model = dummy_segmentation_model( + in_channels=in_channels, + num_classes=num_classes, + image_size=image_size, + ) + + if baseline_type == "single": + return SegmentationRoutine( + num_classes=num_classes, + model=model, + loss=loss, + format_batch_fn=None, + num_estimators=1, + optim_recipe=optim_recipe(model), + ) + + # baseline_type == "ensemble": + model = deep_ensembles( + [model, copy.deepcopy(model)], + task="segmentation", + ) + return SegmentationRoutine( + num_classes=num_classes, + model=model, + loss=loss, + format_batch_fn=RepeatTarget(2), + num_estimators=2, + optim_recipe=optim_recipe(model), + ) diff --git a/tests/_dummies/datamodule.py b/tests/_dummies/datamodule.py index a6b15d2a..9cf0ab77 100644 --- a/tests/_dummies/datamodule.py +++ b/tests/_dummies/datamodule.py @@ -1,15 +1,19 @@ -from argparse import ArgumentParser from pathlib import Path -from typing import Any import numpy as np -import torchvision.transforms as T +import torch +import torchvision.transforms.v2 as T from numpy.typing import ArrayLike from torch.utils.data import DataLoader +from torchvision import tv_tensors from torch_uncertainty.datamodules.abstract import AbstractDataModule -from .dataset import DummyClassificationDataset, DummyRegressionDataset +from .dataset import ( + DummyClassificationDataset, + DummyRegressionDataset, + DummySegmentationDataset, +) class DummyClassificationDataModule(AbstractDataModule): @@ -20,17 +24,17 @@ class DummyClassificationDataModule(AbstractDataModule): def __init__( self, root: str | Path, - eval_ood: bool, batch_size: int, num_classes: int = 2, num_workers: int = 1, + eval_ood: bool = False, pin_memory: bool = True, persistent_workers: bool = True, num_images: int = 2, - **kwargs, ) -> None: super().__init__( root=root, + val_split=None, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, @@ -98,16 +102,6 @@ def _get_train_data(self) -> ArrayLike: def _get_train_targets(self) -> ArrayLike: return np.array(self.train.targets) - @classmethod - def add_argparse_args( - cls, - parent_parser: ArgumentParser, - **kwargs: Any, - ) -> ArgumentParser: - p = super().add_argparse_args(parent_parser) - p.add_argument("--eval-ood", action="store_true") - return parent_parser - class DummyRegressionDataModule(AbstractDataModule): in_features = 4 @@ -116,13 +110,11 @@ class DummyRegressionDataModule(AbstractDataModule): def __init__( self, root: str | Path, - eval_ood: bool, batch_size: int, out_features: int = 2, num_workers: int = 1, pin_memory: bool = True, persistent_workers: bool = True, - **kwargs, ) -> None: super().__init__( root=root, @@ -130,9 +122,9 @@ def __init__( num_workers=num_workers, pin_memory=pin_memory, persistent_workers=persistent_workers, + val_split=0, ) - self.eval_ood = eval_ood self.out_features = out_features self.dataset = DummyRegressionDataset @@ -162,25 +154,95 @@ def setup(self, stage: str | None = None) -> None: out_features=self.out_features, transform=self.test_transform, ) - if self.eval_ood: - self.ood = self.ood_dataset( + + def test_dataloader(self) -> DataLoader | list[DataLoader]: + return [self._data_loader(self.test)] + + +class DummySegmentationDataModule(AbstractDataModule): + num_channels = 3 + training_task = "segmentation" + + def __init__( + self, + root: str | Path, + batch_size: int, + num_classes: int = 2, + num_workers: int = 1, + image_size: int = 4, + pin_memory: bool = True, + persistent_workers: bool = True, + num_images: int = 2, + ) -> None: + super().__init__( + root=root, + val_split=None, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=persistent_workers, + ) + + self.num_classes = num_classes + self.num_channels = 3 + self.num_images = num_images + self.image_size = image_size + + self.dataset = DummySegmentationDataset + + self.train_transform = T.ToDtype( + dtype={ + tv_tensors.Image: torch.float32, + tv_tensors.Mask: torch.int64, + "others": None, + }, + scale=True, + ) + self.test_transform = T.ToDtype( + dtype={ + tv_tensors.Image: torch.float32, + tv_tensors.Mask: torch.int64, + "others": None, + }, + scale=True, + ) + + def prepare_data(self) -> None: + pass + + def setup(self, stage: str | None = None) -> None: + if stage == "fit" or stage is None: + self.train = self.dataset( self.root, - out_features=self.out_features, - transform=self.test_transform, + num_channels=self.num_channels, + num_classes=self.num_classes, + image_size=self.image_size, + transforms=self.train_transform, + num_images=self.num_images, + ) + self.val = self.dataset( + self.root, + num_channels=self.num_channels, + num_classes=self.num_classes, + image_size=self.image_size, + transforms=self.test_transform, + num_images=self.num_images, + ) + elif stage == "test": + self.test = self.dataset( + self.root, + num_channels=self.num_channels, + num_classes=self.num_classes, + image_size=self.image_size, + transforms=self.test_transform, + num_images=self.num_images, ) def test_dataloader(self) -> DataLoader | list[DataLoader]: - dataloader = [self._data_loader(self.test)] - if self.eval_ood: - dataloader.append(self._data_loader(self.ood)) - return dataloader + return [self._data_loader(self.test)] + + def _get_train_data(self) -> ArrayLike: + return self.train.data - @classmethod - def add_argparse_args( - cls, - parent_parser: ArgumentParser, - **kwargs: Any, - ) -> ArgumentParser: - p = super().add_argparse_args(parent_parser) - p.add_argument("--eval-ood", action="store_true") - return parent_parser + def _get_train_targets(self) -> ArrayLike: + return np.array(self.train.targets) diff --git a/tests/_dummies/dataset.py b/tests/_dummies/dataset.py index 9c35c7d2..3e5e4024 100644 --- a/tests/_dummies/dataset.py +++ b/tests/_dummies/dataset.py @@ -6,6 +6,7 @@ import torch from PIL import Image from torch.utils.data import Dataset +from torchvision import tv_tensors class DummyClassificationDataset(Dataset): @@ -37,7 +38,7 @@ def __init__( image_size: int = 4, num_classes: int = 10, num_images: int = 2, - **kwargs: Any, + **args, ) -> None: self.root = root self.train = train # training set or test set @@ -111,7 +112,7 @@ def __init__( in_features: int = 3, out_features: int = 10, num_samples: int = 2, - **kwargs: Any, + **args, ) -> None: self.root = root self.train = train # training set or test set @@ -122,7 +123,10 @@ def __init__( self.targets = [] input_shape = (num_samples, in_features) - output_shape = (num_samples, out_features) + if out_features != 1: + output_shape = (num_samples, out_features) + else: + output_shape = (num_samples,) self.data = torch.rand( size=input_shape, @@ -153,3 +157,114 @@ def __getitem__(self, index: int) -> tuple[Any, Any]: def __len__(self) -> int: return len(self.data) + + +class DummySegmentationDataset(Dataset): + def __init__( + self, + root: Path, + split: str = "train", + transforms: Callable[..., Any] | None = None, + num_channels: int = 3, + image_size: int = 4, + num_classes: int = 10, + num_images: int = 2, + **args, + ) -> None: + super().__init__() + + self.root = root + self.split = split + self.transforms = transforms + + self.data: Any = [] + self.targets = [] + + if num_channels == 1: + img_shape = (num_images, image_size, image_size) + else: + img_shape = (num_images, num_channels, image_size, image_size) + + smnt_shape = (num_images, 1, image_size, image_size) + + self.data = np.random.randint( + low=0, + high=255, + size=img_shape, + dtype=np.uint8, + ) + + self.targets = np.random.randint( + low=0, + high=num_classes, + size=smnt_shape, + dtype=np.uint8, + ) + + def __getitem__(self, index: int) -> tuple[Any, Any]: + img = tv_tensors.Image(self.data[index]) + target = tv_tensors.Mask(self.targets[index]) + + if self.transforms is not None: + img, target = self.transforms(img, target) + + return img, target + + def __len__(self) -> int: + return len(self.data) + + +class DummyDepthDataset(Dataset): + def __init__( + self, + root: Path, + split: str = "train", + transforms: Callable[..., Any] | None = None, + num_channels: int = 3, + image_size: int = 4, + num_images: int = 2, + **args, + ) -> None: + super().__init__() + + self.root = root + self.split = split + self.transforms = transforms + + self.data: Any = [] + self.targets = [] + + if num_channels == 1: + img_shape = (num_images, image_size, image_size) + else: + img_shape = (num_images, num_channels, image_size, image_size) + + smnt_shape = (num_images, 1, image_size, image_size) + + self.data = np.random.randint( + low=0, + high=255, + size=img_shape, + dtype=np.uint8, + ) + + self.targets = ( + np.random.uniform( + low=0, + high=1, + size=smnt_shape, + ) + * 100 + ) + + def __getitem__(self, index: int) -> tuple[Any, Any]: + img = tv_tensors.Image(self.data[index]) + target = tv_tensors.Mask(self.targets[index]) + + if self.transforms is not None: + img, target = self.transforms(img, target) + + return img, target + + def __len__(self) -> int: + return len(self.data) diff --git a/tests/_dummies/model.py b/tests/_dummies/model.py index 51b65a77..2e29e2b5 100644 --- a/tests/_dummies/model.py +++ b/tests/_dummies/model.py @@ -11,11 +11,12 @@ def __init__( self, in_channels: int, num_classes: int, - num_estimators: int, dropout_rate: float, with_linear: bool, + last_layer: nn.Module, ) -> None: super().__init__() + self.in_channels = in_channels self.dropout_rate = dropout_rate if with_linear: @@ -28,15 +29,17 @@ def __init__( 1, num_classes, ) + self.last_layer = last_layer self.dropout = nn.Dropout(p=dropout_rate) - self.num_estimators = num_estimators - def forward(self, x: Tensor) -> Tensor: - return self.dropout( - self.linear( - torch.ones( - (x.shape[0] * self.num_estimators, 1), dtype=torch.float32 + return self.last_layer( + self.dropout( + self.linear( + torch.ones( + (x.shape[0], 1), + dtype=torch.float32, + ) ) ) ) @@ -44,16 +47,53 @@ def forward(self, x: Tensor) -> Tensor: class _DummyWithFeats(_Dummy): def feats_forward(self, x: Tensor) -> Tensor: - return self.forward(x) + return torch.ones( + (x.shape[0], 1), + dtype=torch.float32, + ) + + +class _DummySegmentation(nn.Module): + def __init__( + self, + in_channels: int, + num_classes: int, + dropout_rate: float, + image_size: int, + ) -> None: + super().__init__() + self.dropout_rate = dropout_rate + self.in_channels = in_channels + self.num_classes = num_classes + self.image_size = image_size + self.conv = nn.Conv2d( + in_channels, num_classes, kernel_size=3, padding=1 + ) + self.dropout = nn.Dropout(p=dropout_rate) + + def forward(self, x: Tensor) -> Tensor: + return self.dropout( + self.conv( + torch.ones( + ( + x.shape[0], + self.in_channels, + self.image_size, + self.image_size, + ), + dtype=torch.float32, + ) + ) + ) def dummy_model( in_channels: int, num_classes: int, - num_estimators: int, dropout_rate: float = 0.0, with_feats: bool = True, with_linear: bool = True, + last_layer=None, ) -> _Dummy: """Dummy model for testing purposes. @@ -65,22 +105,50 @@ def dummy_model( with_feats (bool, optional): Whether to include features. Defaults to True. with_linear (bool, optional): Whether to include a linear layer. Defaults to True. + last_layer ([type], optional): Last layer of the model. Defaults to None. Returns: _Dummy: Dummy model. """ + if last_layer is None: + last_layer = nn.Identity() if with_feats: return _DummyWithFeats( in_channels=in_channels, num_classes=num_classes, - num_estimators=num_estimators, dropout_rate=dropout_rate, with_linear=with_linear, + last_layer=last_layer, ) return _Dummy( in_channels=in_channels, num_classes=num_classes, - num_estimators=num_estimators, dropout_rate=dropout_rate, with_linear=with_linear, + last_layer=last_layer, + ) + + +def dummy_segmentation_model( + in_channels: int, + num_classes: int, + image_size: int, + dropout_rate: float = 0.0, +) -> nn.Module: + """Dummy segmentation model for testing purposes. + + Args: + in_channels (int): Number of input channels. + num_classes (int): Number of output classes. + image_size (int): Size of the input image. + dropout_rate (float, optional): Dropout rate. Defaults to 0.0. + + Returns: + nn.Module: Dummy segmentation model. + """ + return _DummySegmentation( + in_channels=in_channels, + num_classes=num_classes, + dropout_rate=dropout_rate, + image_size=image_size, ) diff --git a/tests/baselines/test_batched.py b/tests/baselines/test_batched.py index b6cd1fae..ef208523 100644 --- a/tests/baselines/test_batched.py +++ b/tests/baselines/test_batched.py @@ -2,11 +2,9 @@ from torch import nn from torchinfo import summary -from torch_uncertainty.baselines import ResNet, WideResNet -from torch_uncertainty.optimization_procedures import ( - optim_cifar10_wideresnet, - optim_cifar100_resnet18, - optim_cifar100_resnet50, +from torch_uncertainty.baselines.classification import ( + ResNetBaseline, + WideResNetBaseline, ) @@ -14,11 +12,10 @@ class TestBatchedBaseline: """Testing the BatchedResNet baseline class.""" def test_batched_18(self): - net = ResNet( + net = ResNetBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar100_resnet18, + loss=nn.CrossEntropyLoss(), version="batched", arch=18, style="cifar", @@ -27,17 +24,13 @@ def test_batched_18(self): ) summary(net) - - _ = net.criterion - _ = net.configure_optimizers() _ = net(torch.rand(1, 3, 32, 32)) def test_batched_50(self): - net = ResNet( + net = ResNetBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar100_resnet50, + loss=nn.CrossEntropyLoss(), version="batched", arch=50, style="imagenet", @@ -46,9 +39,6 @@ def test_batched_50(self): ) summary(net) - - _ = net.criterion - _ = net.configure_optimizers() _ = net(torch.rand(1, 3, 40, 40)) @@ -56,11 +46,10 @@ class TestBatchedWideBaseline: """Testing the BatchedWideResNet baseline class.""" def test_batched(self): - net = WideResNet( + net = WideResNetBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_wideresnet, + loss=nn.CrossEntropyLoss(), version="batched", style="cifar", num_estimators=4, @@ -68,7 +57,4 @@ def test_batched(self): ) summary(net) - - _ = net.criterion - _ = net.configure_optimizers() _ = net(torch.rand(1, 3, 32, 32)) diff --git a/tests/baselines/test_deep_ensembles.py b/tests/baselines/test_deep_ensembles.py index 5f25910c..fbb7a512 100644 --- a/tests/baselines/test_deep_ensembles.py +++ b/tests/baselines/test_deep_ensembles.py @@ -1,23 +1,18 @@ -from argparse import ArgumentParser +import pytest -from torch_uncertainty.baselines import DeepEnsembles +from torch_uncertainty.baselines.classification.deep_ensembles import ( + DeepEnsemblesBaseline, +) class TestDeepEnsembles: """Testing the Deep Ensembles baseline class.""" - def test_standard(self): - DeepEnsembles( - task="classification", - log_path=".", - checkpoint_ids=[], - backbone="resnet", - in_channels=3, - num_classes=10, - version="std", - arch=18, - style="cifar", - groups=1, - ) - parser = ArgumentParser() - DeepEnsembles.add_model_specific_args(parser) + def test_failure(self): + with pytest.raises(ValueError): + DeepEnsemblesBaseline( + log_path=".", + checkpoint_ids=[], + backbone="resnet", + num_classes=10, + ) diff --git a/tests/baselines/test_masked.py b/tests/baselines/test_masked.py index e992bdea..3fd48ebf 100644 --- a/tests/baselines/test_masked.py +++ b/tests/baselines/test_masked.py @@ -3,11 +3,9 @@ from torch import nn from torchinfo import summary -from torch_uncertainty.baselines import ResNet, WideResNet -from torch_uncertainty.optimization_procedures import ( - optim_cifar10_wideresnet, - optim_cifar100_resnet18, - optim_cifar100_resnet50, +from torch_uncertainty.baselines.classification import ( + ResNetBaseline, + WideResNetBaseline, ) @@ -15,11 +13,10 @@ class TestMaskedBaseline: """Testing the MaskedResNet baseline class.""" def test_masked_18(self): - net = ResNet( + net = ResNetBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar100_resnet18, + loss=nn.CrossEntropyLoss(), version="masked", arch=18, style="cifar", @@ -29,17 +26,13 @@ def test_masked_18(self): ) summary(net) - - _ = net.criterion - _ = net.configure_optimizers() _ = net(torch.rand(1, 3, 32, 32)) def test_masked_50(self): - net = ResNet( + net = ResNetBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar100_resnet50, + loss=nn.CrossEntropyLoss(), version="masked", arch=50, style="imagenet", @@ -49,18 +42,14 @@ def test_masked_50(self): ) summary(net) - - _ = net.criterion - _ = net.configure_optimizers() _ = net(torch.rand(1, 3, 40, 40)) - def test_masked_scale_lt_1(self): - with pytest.raises(Exception): - _ = ResNet( + def test_masked_errors(self): + with pytest.raises(ValueError): + _ = ResNetBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar100_resnet18, + loss=nn.CrossEntropyLoss(), version="masked", arch=18, style="cifar", @@ -69,13 +58,11 @@ def test_masked_scale_lt_1(self): groups=1, ) - def test_masked_groups_lt_1(self): - with pytest.raises(Exception): - _ = ResNet( + with pytest.raises(ValueError): + _ = ResNetBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar100_resnet18, + loss=nn.CrossEntropyLoss(), version="masked", arch=18, style="cifar", @@ -89,11 +76,10 @@ class TestMaskedWideBaseline: """Testing the MaskedWideResNet baseline class.""" def test_masked(self): - net = WideResNet( + net = WideResNetBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_wideresnet, + loss=nn.CrossEntropyLoss(), version="masked", style="cifar", num_estimators=4, @@ -102,7 +88,4 @@ def test_masked(self): ) summary(net) - - _ = net.criterion - _ = net.configure_optimizers() _ = net(torch.rand(1, 3, 32, 32)) diff --git a/tests/baselines/test_mc_dropout.py b/tests/baselines/test_mc_dropout.py index c12daf15..dca61c3c 100644 --- a/tests/baselines/test_mc_dropout.py +++ b/tests/baselines/test_mc_dropout.py @@ -2,22 +2,21 @@ from torch import nn from torchinfo import summary -from torch_uncertainty.baselines import VGG, ResNet, WideResNet -from torch_uncertainty.optimization_procedures import ( - optim_cifar10_resnet18, - optim_cifar10_wideresnet, +from torch_uncertainty.baselines.classification import ( + ResNetBaseline, + VGGBaseline, + WideResNetBaseline, ) class TestStandardBaseline: - """Testing the ResNet baseline class.""" + """Testing the ResNetBaseline baseline class.""" def test_standard(self): - net = ResNet( + net = ResNetBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet18, + loss=nn.CrossEntropyLoss(), version="mc-dropout", dropout_rate=0.1, num_estimators=4, @@ -26,21 +25,17 @@ def test_standard(self): groups=1, ) summary(net) - - _ = net.criterion - net.configure_optimizers() net(torch.rand(1, 3, 32, 32)) class TestStandardWideBaseline: - """Testing the WideResNet baseline class.""" + """Testing the WideResNetBaseline baseline class.""" def test_standard(self): - net = WideResNet( + net = WideResNetBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_wideresnet, + loss=nn.CrossEntropyLoss(), version="mc-dropout", dropout_rate=0.1, num_estimators=4, @@ -48,21 +43,17 @@ def test_standard(self): groups=1, ) summary(net) - - _ = net.criterion - net.configure_optimizers() net(torch.rand(1, 3, 32, 32)) class TestStandardVGGBaseline: - """Testing the VGG baseline class.""" + """Testing the VGGBaseline baseline class.""" def test_standard(self): - net = VGG( + net = VGGBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet18, + loss=nn.CrossEntropyLoss(), version="mc-dropout", dropout_rate=0.1, num_estimators=4, @@ -71,7 +62,18 @@ def test_standard(self): last_layer_dropout=True, ) summary(net) + net(torch.rand(1, 3, 32, 32)) - _ = net.criterion - net.configure_optimizers() + net = VGGBaseline( + num_classes=10, + in_channels=3, + loss=nn.CrossEntropyLoss(), + version="mc-dropout", + num_estimators=4, + arch=11, + groups=1, + dropout_rate=0.3, + last_layer_dropout=True, + ) + net.eval() net(torch.rand(1, 3, 32, 32)) diff --git a/tests/baselines/test_mimo.py b/tests/baselines/test_mimo.py index cf4a29cc..18c83a08 100644 --- a/tests/baselines/test_mimo.py +++ b/tests/baselines/test_mimo.py @@ -2,11 +2,9 @@ from torch import nn from torchinfo import summary -from torch_uncertainty.baselines import ResNet, WideResNet -from torch_uncertainty.optimization_procedures import ( - optim_cifar10_resnet18, - optim_cifar10_resnet50, - optim_cifar10_wideresnet, +from torch_uncertainty.baselines.classification import ( + ResNetBaseline, + WideResNetBaseline, ) @@ -14,11 +12,10 @@ class TestMIMOBaseline: """Testing the MIMOResNet baseline class.""" def test_mimo_50(self): - net = ResNet( + net = ResNetBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet50, + loss=nn.CrossEntropyLoss(), version="mimo", arch=50, style="cifar", @@ -29,17 +26,13 @@ def test_mimo_50(self): ).eval() summary(net) - - _ = net.criterion - _ = net.configure_optimizers() _ = net(torch.rand(1, 3, 32, 32)) def test_mimo_18(self): - net = ResNet( + net = ResNetBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet18, + loss=nn.CrossEntropyLoss(), version="mimo", arch=18, style="imagenet", @@ -50,9 +43,6 @@ def test_mimo_18(self): ).eval() summary(net) - - _ = net.criterion - _ = net.configure_optimizers() _ = net(torch.rand(1, 3, 40, 40)) @@ -60,11 +50,10 @@ class TestMIMOWideBaseline: """Testing the PackedWideResNet baseline class.""" def test_mimo(self): - net = WideResNet( + net = WideResNetBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_wideresnet, + loss=nn.CrossEntropyLoss(), version="mimo", style="cifar", num_estimators=4, @@ -74,7 +63,4 @@ def test_mimo(self): ).eval() summary(net) - - _ = net.criterion - _ = net.configure_optimizers() _ = net(torch.rand(1, 3, 32, 32)) diff --git a/tests/baselines/test_packed.py b/tests/baselines/test_packed.py index 834718ad..c8331119 100644 --- a/tests/baselines/test_packed.py +++ b/tests/baselines/test_packed.py @@ -3,24 +3,22 @@ from torch import nn from torchinfo import summary -from torch_uncertainty.baselines import VGG, ResNet, WideResNet -from torch_uncertainty.baselines.regression import MLP -from torch_uncertainty.optimization_procedures import ( - optim_cifar10_resnet18, - optim_cifar10_resnet50, - optim_cifar10_wideresnet, +from torch_uncertainty.baselines.classification import ( + ResNetBaseline, + VGGBaseline, + WideResNetBaseline, ) +from torch_uncertainty.baselines.regression import MLPBaseline class TestPackedBaseline: """Testing the PackedResNet baseline class.""" def test_packed_50(self): - net = ResNet( + net = ResNetBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet50, + loss=nn.CrossEntropyLoss(), version="packed", arch=50, style="cifar", @@ -32,16 +30,13 @@ def test_packed_50(self): summary(net) - _ = net.criterion - _ = net.configure_optimizers() _ = net(torch.rand(1, 3, 32, 32)) def test_packed_18(self): - net = ResNet( + net = ResNetBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet18, + loss=nn.CrossEntropyLoss(), version="packed", arch=18, style="imagenet", @@ -52,18 +47,14 @@ def test_packed_18(self): ) summary(net) - - _ = net.criterion - _ = net.configure_optimizers() _ = net(torch.rand(1, 3, 40, 40)) def test_packed_exception(self): - with pytest.raises(Exception): - _ = ResNet( + with pytest.raises(ValueError): + _ = ResNetBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet50, + loss=nn.CrossEntropyLoss(), version="packed", arch=50, style="cifar", @@ -73,12 +64,11 @@ def test_packed_exception(self): groups=1, ) - with pytest.raises(Exception): - _ = ResNet( + with pytest.raises(ValueError): + _ = ResNetBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet50, + loss=nn.CrossEntropyLoss(), version="packed", arch=50, style="cifar", @@ -93,11 +83,10 @@ class TestPackedWideBaseline: """Testing the PackedWideResNet baseline class.""" def test_packed(self): - net = WideResNet( + net = WideResNetBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_wideresnet, + loss=nn.CrossEntropyLoss(), version="packed", style="cifar", num_estimators=4, @@ -107,9 +96,6 @@ def test_packed(self): ) summary(net) - - _ = net.criterion - _ = net.configure_optimizers() _ = net(torch.rand(1, 3, 32, 32)) @@ -117,12 +103,11 @@ class TestPackedVGGBaseline: """Testing the PackedWideResNet baseline class.""" def test_packed(self): - net = VGG( + net = VGGBaseline( num_classes=10, in_channels=3, arch=13, - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet50, + loss=nn.CrossEntropyLoss(), version="packed", num_estimators=4, alpha=2, @@ -131,9 +116,6 @@ def test_packed(self): ) summary(net) - - _ = net.criterion - _ = net.configure_optimizers() _ = net(torch.rand(2, 3, 32, 32)) @@ -141,20 +123,15 @@ class TestPackedMLPBaseline: """Testing the Packed MLP baseline class.""" def test_packed(self): - net = MLP( + net = MLPBaseline( in_features=3, - num_outputs=10, - loss=nn.MSELoss, - optimization_procedure=optim_cifar10_resnet18, + output_dim=10, + loss=nn.MSELoss(), version="packed", hidden_dims=[1], num_estimators=2, alpha=2, gamma=1, - dist_estimation=1, ) summary(net) - - _ = net.criterion - _ = net.configure_optimizers() _ = net(torch.rand(1, 3)) diff --git a/tests/baselines/test_standard.py b/tests/baselines/test_standard.py index b530ecca..77cb948f 100644 --- a/tests/baselines/test_standard.py +++ b/tests/baselines/test_standard.py @@ -1,48 +1,39 @@ -from argparse import ArgumentParser - import pytest import torch from torch import nn from torchinfo import summary -from torch_uncertainty.baselines import VGG, ResNet, WideResNet -from torch_uncertainty.baselines.regression import MLP -from torch_uncertainty.baselines.utils.parser_addons import ( - add_mlp_specific_args, -) -from torch_uncertainty.optimization_procedures import ( - optim_cifar10_resnet18, - optim_cifar10_wideresnet, +from torch_uncertainty.baselines.classification import ( + ResNetBaseline, + VGGBaseline, + WideResNetBaseline, ) +from torch_uncertainty.baselines.regression import MLPBaseline +from torch_uncertainty.baselines.segmentation import SegFormerBaseline class TestStandardBaseline: - """Testing the ResNet baseline class.""" + """Testing the ResNetBaseline baseline class.""" def test_standard(self): - net = ResNet( + net = ResNetBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet18, + loss=nn.CrossEntropyLoss(), version="std", arch=18, style="cifar", groups=1, ) summary(net) - - _ = net.criterion - _ = net.configure_optimizers() _ = net(torch.rand(1, 3, 32, 32)) def test_errors(self): with pytest.raises(ValueError): - ResNet( + ResNetBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet18, + loss=nn.CrossEntropyLoss(), version="test", arch=18, style="cifar", @@ -51,31 +42,26 @@ def test_errors(self): class TestStandardWideBaseline: - """Testing the WideResNet baseline class.""" + """Testing the WideResNetBaseline baseline class.""" def test_standard(self): - net = WideResNet( + net = WideResNetBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_wideresnet, + loss=nn.CrossEntropyLoss(), version="std", style="cifar", groups=1, ) summary(net) - - _ = net.criterion - _ = net.configure_optimizers() _ = net(torch.rand(1, 3, 32, 32)) def test_errors(self): with pytest.raises(ValueError): - WideResNet( + WideResNetBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_wideresnet, + loss=nn.CrossEntropyLoss(), version="test", style="cifar", groups=1, @@ -83,31 +69,26 @@ def test_errors(self): class TestStandardVGGBaseline: - """Testing the VGG baseline class.""" + """Testing the VGGBaseline baseline class.""" def test_standard(self): - net = VGG( + net = VGGBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet18, + loss=nn.CrossEntropyLoss(), version="std", arch=11, groups=1, ) summary(net) - - _ = net.criterion - _ = net.configure_optimizers() _ = net(torch.rand(1, 3, 32, 32)) def test_errors(self): with pytest.raises(ValueError): - VGG( + VGGBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet18, + loss=nn.CrossEntropyLoss(), version="test", arch=11, groups=1, @@ -118,32 +99,55 @@ class TestStandardMLPBaseline: """Testing the MLP baseline class.""" def test_standard(self): - net = MLP( + net = MLPBaseline( in_features=3, - num_outputs=10, - loss=nn.MSELoss, - optimization_procedure=optim_cifar10_resnet18, + output_dim=10, + loss=nn.MSELoss(), version="std", hidden_dims=[1], - dist_estimation=1, ) summary(net) - - _ = net.criterion - _ = net.configure_optimizers() _ = net(torch.rand(1, 3)) - parser = ArgumentParser() - add_mlp_specific_args(parser) + for distribution in ["normal", "laplace", "nig"]: + MLPBaseline( + in_features=3, + output_dim=10, + loss=nn.MSELoss(), + version="std", + hidden_dims=[1], + distribution=distribution, + ) def test_errors(self): with pytest.raises(ValueError): - MLP( + MLPBaseline( in_features=3, - num_outputs=10, - loss=nn.MSELoss, - optimization_procedure=optim_cifar10_resnet18, + output_dim=10, + loss=nn.MSELoss(), version="test", hidden_dims=[1], - dist_estimation=1, + ) + + +class TestStandardSegFormerBaseline: + """Testing the SegFormer baseline class.""" + + def test_standard(self): + net = SegFormerBaseline( + num_classes=10, + loss=nn.CrossEntropyLoss(), + version="std", + arch=0, + ) + summary(net) + _ = net(torch.rand(1, 3, 32, 32)) + + def test_errors(self): + with pytest.raises(ValueError): + SegFormerBaseline( + num_classes=10, + loss=nn.CrossEntropyLoss(), + version="test", + arch=0, ) diff --git a/torch_uncertainty/baselines/utils/__init__.py b/tests/datamodules/classification/__init__.py similarity index 100% rename from torch_uncertainty/baselines/utils/__init__.py rename to tests/datamodules/classification/__init__.py diff --git a/tests/datamodules/test_cifar100_datamodule.py b/tests/datamodules/classification/test_cifar100_datamodule.py similarity index 61% rename from tests/datamodules/test_cifar100_datamodule.py rename to tests/datamodules/classification/test_cifar100_datamodule.py index 5ca47529..e24af243 100644 --- a/tests/datamodules/test_cifar100_datamodule.py +++ b/tests/datamodules/classification/test_cifar100_datamodule.py @@ -1,6 +1,3 @@ -from argparse import ArgumentParser -from pathlib import Path - import pytest from torchvision.datasets import CIFAR100 @@ -13,14 +10,7 @@ class TestCIFAR100DataModule: """Testing the CIFAR100DataModule datamodule class.""" def test_cifar100(self): - parser = ArgumentParser() - parser = CIFAR100DataModule.add_argparse_args(parser) - - # Simulate that cutout is set to 8 - args = parser.parse_args("") - args.cutout = 8 - - dm = CIFAR100DataModule(**vars(args)) + dm = CIFAR100DataModule(root="./data/", batch_size=128, cutout=16) assert dm.dataset == CIFAR100 assert isinstance(dm.train_transform.transforms[2], Cutout) @@ -30,7 +20,6 @@ def test_cifar100(self): dm.prepare_data() dm.setup() - dm.setup("test") dm.train_dataloader() dm.val_dataloader() @@ -41,18 +30,22 @@ def test_cifar100(self): dm.setup("test") dm.test_dataloader() - args.test_alt = "c" - args.cutout = 0 - args.root = Path(args.root) - dm = CIFAR100DataModule(**vars(args)) + dm = CIFAR100DataModule( + root="./data/", batch_size=128, cutout=0, test_alt="c" + ) dm.dataset = DummyClassificationDataset + dm.setup("test") with pytest.raises(ValueError): dm.setup() - args.test_alt = None - args.num_dataloaders = 2 - args.val_split = 0.1 - dm = CIFAR100DataModule(**vars(args)) + dm = CIFAR100DataModule( + root="./data/", + batch_size=128, + cutout=0, + test_alt=None, + val_split=0.1, + num_dataloaders=2, + ) dm.dataset = DummyClassificationDataset dm.ood_dataset = DummyClassificationDataset dm.ood_dataset = DummyClassificationDataset @@ -62,27 +55,25 @@ def test_cifar100(self): with pytest.raises(ValueError): dm.setup("other") - args.num_dataloaders = 1 - args.cutout = 8 - args.randaugment = True with pytest.raises(ValueError): - dm = CIFAR100DataModule(**vars(args)) + dm = CIFAR100DataModule( + root="./data/", + batch_size=128, + num_dataloaders=1, + cutout=8, + randaugment=True, + ) - args.cutout = None - dm = CIFAR100DataModule(**vars(args)) - args.randaugment = False + dm = CIFAR100DataModule( + root="./data/", batch_size=128, randaugment=True + ) - args.auto_augment = "rand-m9-n2-mstd0.5" - dm = CIFAR100DataModule(**vars(args)) + dm = CIFAR100DataModule( + root="./data/", batch_size=128, auto_augment="rand-m9-n2-mstd0.5" + ) def test_cifar100_cv(self): - parser = ArgumentParser() - parser = CIFAR100DataModule.add_argparse_args(parser) - - # Simulate that cutout is set to 8 - args = parser.parse_args("") - - dm = CIFAR100DataModule(**vars(args)) + dm = CIFAR100DataModule(root="./data/", batch_size=128) dm.dataset = ( lambda root, train, download, transform: DummyClassificationDataset( root, @@ -94,8 +85,7 @@ def test_cifar100_cv(self): ) dm.make_cross_val_splits(2, 1) - args.val_split = 0.1 - dm = CIFAR100DataModule(**vars(args)) + dm = CIFAR100DataModule(root="./data/", batch_size=128, val_split=0.1) dm.dataset = ( lambda root, train, download, transform: DummyClassificationDataset( root, diff --git a/tests/datamodules/test_cifar10_datamodule.py b/tests/datamodules/classification/test_cifar10_datamodule.py similarity index 63% rename from tests/datamodules/test_cifar10_datamodule.py rename to tests/datamodules/classification/test_cifar10_datamodule.py index 0bcd99f9..df12f214 100644 --- a/tests/datamodules/test_cifar10_datamodule.py +++ b/tests/datamodules/classification/test_cifar10_datamodule.py @@ -1,5 +1,3 @@ -from argparse import ArgumentParser - import pytest from torchvision.datasets import CIFAR10 @@ -12,14 +10,7 @@ class TestCIFAR10DataModule: """Testing the CIFAR10DataModule datamodule class.""" def test_cifar10_main(self): - parser = ArgumentParser() - parser = CIFAR10DataModule.add_argparse_args(parser) - - # Simulate that cutout is set to 8 - args = parser.parse_args("") - args.cutout = 16 - - dm = CIFAR10DataModule(**vars(args)) + dm = CIFAR10DataModule(root="./data/", batch_size=128, cutout=16) assert dm.dataset == CIFAR10 assert isinstance(dm.train_transform.transforms[2], Cutout) @@ -29,7 +20,6 @@ def test_cifar10_main(self): dm.prepare_data() dm.setup() - dm.setup("test") with pytest.raises(ValueError): dm.setup("xxx") @@ -48,44 +38,52 @@ def test_cifar10_main(self): dm.setup("test") dm.test_dataloader() - args.test_alt = "c" - dm = CIFAR10DataModule(**vars(args)) + dm = CIFAR10DataModule( + root="./data/", batch_size=128, cutout=16, test_alt="c" + ) dm.dataset = DummyClassificationDataset with pytest.raises(ValueError): dm.setup() - args.test_alt = "h" - dm = CIFAR10DataModule(**vars(args)) + dm = CIFAR10DataModule( + root="./data/", batch_size=128, cutout=16, test_alt="h" + ) dm.dataset = DummyClassificationDataset dm.setup("test") - args.test_alt = None - args.num_dataloaders = 2 - args.val_split = 0.1 - dm = CIFAR10DataModule(**vars(args)) + dm = CIFAR10DataModule( + root="./data/", + batch_size=128, + cutout=16, + num_dataloaders=2, + val_split=0.1, + ) dm.dataset = DummyClassificationDataset dm.ood_dataset = DummyClassificationDataset dm.setup() - dm.setup("test") dm.train_dataloader() - args.cutout = 8 - args.auto_augment = "rand-m9-n2-mstd0.5" with pytest.raises(ValueError): - dm = CIFAR10DataModule(**vars(args)) + dm = CIFAR10DataModule( + root="./data/", + batch_size=128, + cutout=8, + num_dataloaders=2, + val_split=0.1, + auto_augment="rand-m9-n2-mstd0.5", + ) - args.cutout = None - args.auto_augment = "rand-m9-n2-mstd0.5" - dm = CIFAR10DataModule(**vars(args)) + dm = CIFAR10DataModule( + root="./data/", + batch_size=128, + cutout=None, + num_dataloaders=2, + val_split=0.1, + auto_augment="rand-m9-n2-mstd0.5", + ) def test_cifar10_cv(self): - parser = ArgumentParser() - parser = CIFAR10DataModule.add_argparse_args(parser) - - # Simulate that cutout is set to 8 - args = parser.parse_args("") - - dm = CIFAR10DataModule(**vars(args)) + dm = CIFAR10DataModule(root="./data/", batch_size=128) dm.dataset = ( lambda root, train, download, transform: DummyClassificationDataset( root, @@ -97,8 +95,7 @@ def test_cifar10_cv(self): ) dm.make_cross_val_splits(2, 1) - args.val_split = 0.1 - dm = CIFAR10DataModule(**vars(args)) + dm = CIFAR10DataModule(root="./data/", batch_size=128, val_split=0.1) dm.dataset = ( lambda root, train, download, transform: DummyClassificationDataset( root, diff --git a/tests/datamodules/test_imagenet_datamodule.py b/tests/datamodules/classification/test_imagenet_datamodule.py similarity index 57% rename from tests/datamodules/test_imagenet_datamodule.py rename to tests/datamodules/classification/test_imagenet_datamodule.py index 23af2e63..9088a701 100644 --- a/tests/datamodules/test_imagenet_datamodule.py +++ b/tests/datamodules/classification/test_imagenet_datamodule.py @@ -1,5 +1,4 @@ -import pathlib -from argparse import ArgumentParser +from pathlib import Path import pytest from torchvision.datasets import ImageNet @@ -12,26 +11,17 @@ class TestImageNetDataModule: """Testing the ImageNetDataModule datamodule class.""" def test_imagenet(self): - parser = ArgumentParser() - parser = ImageNetDataModule.add_argparse_args(parser) - - args = parser.parse_args("") - args.val_split = 0.1 - dm = ImageNetDataModule(**vars(args)) - + dm = ImageNetDataModule(root="./data/", batch_size=128, val_split=0.1) assert dm.dataset == ImageNet - dm.dataset = DummyClassificationDataset dm.ood_dataset = DummyClassificationDataset dm.prepare_data() dm.setup() - dm.setup("test") - args.val_split = ( - pathlib.Path(__file__).parent.resolve() - / "../assets/dummy_indices.yaml" + path = ( + Path(__file__).parent.resolve() / "../../assets/dummy_indices.yaml" ) - dm = ImageNetDataModule(**vars(args)) + dm = ImageNetDataModule(root="./data/", batch_size=128, val_split=path) dm.dataset = DummyClassificationDataset dm.ood_dataset = DummyClassificationDataset dm.setup("fit") @@ -51,25 +41,33 @@ def test_imagenet(self): dm.setup("test") dm.test_dataloader() + ImageNetDataModule( + root="./data/", + batch_size=128, + val_split=path, + rand_augment_opt="rand-m9-n1", + ) + with pytest.raises(ValueError): dm.setup("other") for test_alt in ["r", "o", "a"]: - args.test_alt = test_alt - dm = ImageNetDataModule(**vars(args)) + dm = ImageNetDataModule( + root="./data/", batch_size=128, test_alt=test_alt + ) with pytest.raises(ValueError): dm.setup() - args.test_alt = "x" with pytest.raises(ValueError): - dm = ImageNetDataModule(**vars(args)) - - args.test_alt = None + dm = ImageNetDataModule( + root="./data/", batch_size=128, test_alt="x" + ) for ood_ds in ["inaturalist", "imagenet-o", "textures", "openimage-o"]: - args.ood_ds = ood_ds - dm = ImageNetDataModule(**vars(args)) + dm = ImageNetDataModule( + root="./data/", batch_size=128, ood_ds=ood_ds + ) if ood_ds == "inaturalist": dm.eval_ood = True dm.dataset = DummyClassificationDataset @@ -78,21 +76,27 @@ def test_imagenet(self): dm.setup("test") dm.test_dataloader() - args.ood_ds = "other" with pytest.raises(ValueError): - dm = ImageNetDataModule(**vars(args)) - - args.ood_ds = "svhn" + dm = ImageNetDataModule( + root="./data/", batch_size=128, ood_ds="other" + ) for procedure in ["ViT", "A3"]: - args.procedure = procedure - dm = ImageNetDataModule(**vars(args)) + dm = ImageNetDataModule( + root="./data/", + batch_size=128, + ood_ds="svhn", + procedure=procedure, + ) - args.procedure = "A2" with pytest.raises(ValueError): - dm = ImageNetDataModule(**vars(args)) + dm = ImageNetDataModule( + root="./data/", batch_size=128, procedure="A2" + ) + + with pytest.raises(FileNotFoundError): + dm._verify_splits(split="test") - args.procedure = None - args.rand_augment_opt = "rand-m9-n2-mstd0.5" with pytest.raises(FileNotFoundError): + dm.root = Path("./tests/testlog") dm._verify_splits(split="test") diff --git a/tests/datamodules/test_mnist_datamodule.py b/tests/datamodules/classification/test_mnist_datamodule.py similarity index 61% rename from tests/datamodules/test_mnist_datamodule.py rename to tests/datamodules/classification/test_mnist_datamodule.py index 36255381..1707409a 100644 --- a/tests/datamodules/test_mnist_datamodule.py +++ b/tests/datamodules/classification/test_mnist_datamodule.py @@ -1,6 +1,3 @@ -from argparse import ArgumentParser -from pathlib import Path - import pytest from torch import nn from torchvision.datasets import MNIST @@ -14,35 +11,28 @@ class TestMNISTDataModule: """Testing the MNISTDataModule datamodule class.""" def test_mnist_cutout(self): - parser = ArgumentParser() - parser = MNISTDataModule.add_argparse_args(parser) - - # Simulate that cutout is set to 16 - args = parser.parse_args("") - args.cutout = 16 - args.val_split = 0.1 - dm = MNISTDataModule(**vars(args)) + dm = MNISTDataModule( + root="./data/", batch_size=128, cutout=16, val_split=0.1 + ) assert dm.dataset == MNIST assert isinstance(dm.train_transform.transforms[0], Cutout) - args.root = Path(args.root) - args.ood_ds = "not" - args.cutout = 0 - args.val_split = 0 - dm = MNISTDataModule(**vars(args)) + dm = MNISTDataModule( + root="./data/", batch_size=128, ood_ds="not", cutout=0, val_split=0 + ) assert isinstance(dm.train_transform.transforms[0], nn.Identity) - args.ood_ds = "other" with pytest.raises(ValueError): - MNISTDataModule(**vars(args)) + MNISTDataModule(root="./data/", batch_size=128, ood_ds="other") + + MNISTDataModule(root="./data/", batch_size=128, test_alt="c") dm.dataset = DummyClassificationDataset dm.ood_dataset = DummyClassificationDataset dm.prepare_data() dm.setup() - dm.setup("test") dm.train_dataloader() dm.val_dataloader() @@ -54,5 +44,5 @@ def test_mnist_cutout(self): dm.eval_ood = True dm.val_split = 0.1 dm.prepare_data() - dm.setup("test") + dm.setup() dm.test_dataloader() diff --git a/tests/datamodules/test_tiny_imagenet_datamodule.py b/tests/datamodules/classification/test_tiny_imagenet_datamodule.py similarity index 58% rename from tests/datamodules/test_tiny_imagenet_datamodule.py rename to tests/datamodules/classification/test_tiny_imagenet_datamodule.py index 954e07ef..007b5f4d 100644 --- a/tests/datamodules/test_tiny_imagenet_datamodule.py +++ b/tests/datamodules/classification/test_tiny_imagenet_datamodule.py @@ -1,5 +1,3 @@ -from argparse import ArgumentParser - import pytest from tests._dummies.dataset import DummyClassificationDataset @@ -11,31 +9,31 @@ class TestTinyImageNetDataModule: """Testing the TinyImageNetDataModule datamodule class.""" def test_tiny_imagenet(self): - parser = ArgumentParser() - parser = TinyImageNetDataModule.add_argparse_args(parser) - - args = parser.parse_args("") - dm = TinyImageNetDataModule(**vars(args)) + dm = TinyImageNetDataModule(root="./data/", batch_size=128) assert dm.dataset == TinyImageNet - args.rand_augment_opt = "rand-m9-n3-mstd0.5" - args.ood_ds = "imagenet-o" - dm = TinyImageNetDataModule(**vars(args)) + dm = TinyImageNetDataModule( + root="./data/", + batch_size=128, + rand_augment_opt="rand-m9-n3-mstd0.5", + ood_ds="imagenet-o", + ) - args.ood_ds = "textures" - dm = TinyImageNetDataModule(**vars(args)) + dm = TinyImageNetDataModule( + root="./data/", batch_size=128, ood_ds="textures" + ) - args.ood_ds = "other" with pytest.raises(ValueError): - TinyImageNetDataModule(**vars(args)) + TinyImageNetDataModule( + root="./data/", batch_size=128, ood_ds="other" + ) dm.dataset = DummyClassificationDataset dm.ood_dataset = DummyClassificationDataset dm.prepare_data() dm.setup() - dm.setup("test") dm.train_dataloader() dm.val_dataloader() @@ -49,21 +47,25 @@ def test_tiny_imagenet(self): dm.setup("test") dm.test_dataloader() - def test_tiny_imagenet_cv(self): - parser = ArgumentParser() - parser = TinyImageNetDataModule.add_argparse_args(parser) - - # Simulate that cutout is set to 8 - args = parser.parse_args("") + dm = TinyImageNetDataModule( + root="./data/", batch_size=128, ood_ds="svhn" + ) + dm.dataset = DummyClassificationDataset + dm.ood_dataset = DummyClassificationDataset + dm.eval_ood = True + dm.prepare_data() + dm.setup("test") - dm = TinyImageNetDataModule(**vars(args)) + def test_tiny_imagenet_cv(self): + dm = TinyImageNetDataModule(root="./data/", batch_size=128) dm.dataset = lambda root, split, transform: DummyClassificationDataset( root, split=split, transform=transform, num_images=20 ) dm.make_cross_val_splits(2, 1) - args.val_split = 0.1 - dm = TinyImageNetDataModule(**vars(args)) + dm = TinyImageNetDataModule( + root="./data/", batch_size=128, val_split=0.1 + ) dm.dataset = lambda root, split, transform: DummyClassificationDataset( root, split=split, transform=transform, num_images=20 ) diff --git a/tests/datamodules/test_uci_regression_datamodule.py b/tests/datamodules/classification/test_uci_regression_datamodule.py similarity index 67% rename from tests/datamodules/test_uci_regression_datamodule.py rename to tests/datamodules/classification/test_uci_regression_datamodule.py index 7094ce31..1297666c 100644 --- a/tests/datamodules/test_uci_regression_datamodule.py +++ b/tests/datamodules/classification/test_uci_regression_datamodule.py @@ -1,4 +1,3 @@ -from argparse import ArgumentParser from functools import partial from tests._dummies.dataset import DummyRegressionDataset @@ -9,18 +8,14 @@ class TestUCIDataModule: """Testing the UCIDataModule datamodule class.""" def test_uci_regression(self): - parser = ArgumentParser() - parser = UCIDataModule.add_argparse_args(parser) - - args = parser.parse_args("") - - dm = UCIDataModule(dataset_name="kin8nm", **vars(args)) + dm = UCIDataModule( + dataset_name="kin8nm", root="./data/", batch_size=128 + ) dm.dataset = partial(DummyRegressionDataset, num_samples=64) dm.prepare_data() dm.val_split = 0.5 dm.setup() - dm.setup("test") dm.train_dataloader() dm.val_dataloader() diff --git a/tests/datamodules/depth_estimation/__init__.py b/tests/datamodules/depth_estimation/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/datamodules/depth_estimation/test_muad.py b/tests/datamodules/depth_estimation/test_muad.py new file mode 100644 index 00000000..cce2088a --- /dev/null +++ b/tests/datamodules/depth_estimation/test_muad.py @@ -0,0 +1,37 @@ +import pytest + +from tests._dummies.dataset import DummyDepthDataset +from torch_uncertainty.datamodules.depth_estimation import MUADDataModule +from torch_uncertainty.datasets import MUAD + + +class TestMUADDataModule: + """Testing the MUADDataModule datamodule.""" + + def test_camvid_main(self): + dm = MUADDataModule(root="./data/", batch_size=128) + + assert dm.dataset == MUAD + + dm.dataset = DummyDepthDataset + + dm.prepare_data() + dm.setup() + + with pytest.raises(ValueError): + dm.setup("xxx") + + # test abstract methods + dm.get_train_set() + dm.get_val_set() + dm.get_test_set() + + dm.train_dataloader() + dm.val_dataloader() + dm.test_dataloader() + + dm.val_split = 0.1 + dm.prepare_data() + dm.setup() + dm.train_dataloader() + dm.val_dataloader() diff --git a/tests/datamodules/segmentation/__init__.py b/tests/datamodules/segmentation/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/datamodules/segmentation/test_camvid.py b/tests/datamodules/segmentation/test_camvid.py new file mode 100644 index 00000000..23b1d4d3 --- /dev/null +++ b/tests/datamodules/segmentation/test_camvid.py @@ -0,0 +1,37 @@ +import pytest + +from tests._dummies.dataset import DummySegmentationDataset +from torch_uncertainty.datamodules.segmentation import CamVidDataModule +from torch_uncertainty.datasets.segmentation import CamVid + + +class TestCamVidDataModule: + """Testing the CamVidDataModule datamodule.""" + + def test_camvid_main(self): + dm = CamVidDataModule(root="./data/", batch_size=128) + + assert dm.dataset == CamVid + + dm.dataset = DummySegmentationDataset + + dm.prepare_data() + dm.setup() + + with pytest.raises(ValueError): + dm.setup("xxx") + + # test abstract methods + dm.get_train_set() + dm.get_val_set() + dm.get_test_set() + + dm.train_dataloader() + dm.val_dataloader() + dm.test_dataloader() + + dm.val_split = 0.1 + dm.prepare_data() + dm.setup() + dm.train_dataloader() + dm.val_dataloader() diff --git a/tests/datamodules/segmentation/test_cityscapes.py b/tests/datamodules/segmentation/test_cityscapes.py new file mode 100644 index 00000000..25b0bbd1 --- /dev/null +++ b/tests/datamodules/segmentation/test_cityscapes.py @@ -0,0 +1,37 @@ +import pytest + +from tests._dummies.dataset import DummySegmentationDataset +from torch_uncertainty.datamodules.segmentation import CityscapesDataModule +from torch_uncertainty.datasets.segmentation import Cityscapes + + +class TestCityscapesDataModule: + """Testing the CityscapesDataModule datamodule.""" + + def test_camvid_main(self): + dm = CityscapesDataModule(root="./data/", batch_size=128) + + assert dm.dataset == Cityscapes + + dm.dataset = DummySegmentationDataset + + dm.prepare_data() + dm.setup() + + with pytest.raises(ValueError): + dm.setup("xxx") + + # test abstract methods + dm.get_train_set() + dm.get_val_set() + dm.get_test_set() + + dm.train_dataloader() + dm.val_dataloader() + dm.test_dataloader() + + dm.val_split = 0.1 + dm.prepare_data() + dm.setup() + dm.train_dataloader() + dm.val_dataloader() diff --git a/tests/datamodules/segmentation/test_muad.py b/tests/datamodules/segmentation/test_muad.py new file mode 100644 index 00000000..862206f0 --- /dev/null +++ b/tests/datamodules/segmentation/test_muad.py @@ -0,0 +1,37 @@ +import pytest + +from tests._dummies.dataset import DummySegmentationDataset +from torch_uncertainty.datamodules.segmentation import MUADDataModule +from torch_uncertainty.datasets import MUAD + + +class TestMUADDataModule: + """Testing the MUADDataModule datamodule.""" + + def test_camvid_main(self): + dm = MUADDataModule(root="./data/", batch_size=128) + + assert dm.dataset == MUAD + + dm.dataset = DummySegmentationDataset + + dm.prepare_data() + dm.setup() + + with pytest.raises(ValueError): + dm.setup("xxx") + + # test abstract methods + dm.get_train_set() + dm.get_val_set() + dm.get_test_set() + + dm.train_dataloader() + dm.val_dataloader() + dm.test_dataloader() + + dm.val_split = 0.1 + dm.prepare_data() + dm.setup() + dm.train_dataloader() + dm.val_dataloader() diff --git a/tests/datamodules/test_abstract_datamodule.py b/tests/datamodules/test_abstract_datamodule.py index 02c5b8e8..7b0f5e66 100644 --- a/tests/datamodules/test_abstract_datamodule.py +++ b/tests/datamodules/test_abstract_datamodule.py @@ -13,7 +13,7 @@ class TestAbstractDataModule: """Testing the AbstractDataModule class.""" def test_errors(self): - dm = AbstractDataModule("root", 128, 4, True, True) + dm = AbstractDataModule("root", 128, 0.0, 4, True, True) with pytest.raises(NotImplementedError): dm.setup() dm._get_train_data() @@ -24,12 +24,14 @@ class TestCrossValDataModule: """Testing the CrossValDataModule class.""" def test_cv_main(self): - dm = AbstractDataModule("root", 128, 4, True, True) + dm = AbstractDataModule("root", 128, 0.0, 4, True, True) ds = DummyClassificationDataset(Path("root")) dm.train = ds dm.val = ds dm.test = ds - cv_dm = CrossValDataModule("root", [0], [1], dm, 128, 4, True, True) + cv_dm = CrossValDataModule( + "root", [0], [1], dm, 128, 0.0, 4, True, True + ) cv_dm.setup() cv_dm.setup("test") @@ -44,13 +46,18 @@ def test_cv_main(self): cv_dm.test_dataloader() def test_errors(self): - dm = AbstractDataModule("root", 128, 4, True, True) + dm = AbstractDataModule("root", 128, 0.0, 4, True, True) ds = DummyClassificationDataset(Path("root")) dm.train = ds dm.val = ds dm.test = ds - cv_dm = CrossValDataModule("root", [0], [1], dm, 128, 4, True, True) + cv_dm = CrossValDataModule( + "root", [0], [1], dm, 128, 0.0, 4, True, True + ) with pytest.raises(NotImplementedError): cv_dm.setup() cv_dm._get_train_data() cv_dm._get_train_targets() + + with pytest.raises(ValueError): + cv_dm.setup("other") diff --git a/tests/datasets/segmentation/__init__.py b/tests/datasets/segmentation/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/datasets/segmentation/test_camvid.py b/tests/datasets/segmentation/test_camvid.py new file mode 100644 index 00000000..2777ad4d --- /dev/null +++ b/tests/datasets/segmentation/test_camvid.py @@ -0,0 +1,11 @@ +import pytest + +from torch_uncertainty.datasets.segmentation import CamVid + + +class TestCamVid: + """Testing the CamVid dataset class.""" + + def test_nodataset(self): + with pytest.raises(RuntimeError): + _ = CamVid("./.data") diff --git a/tests/datasets/segmentation/test_cityscapes.py b/tests/datasets/segmentation/test_cityscapes.py new file mode 100644 index 00000000..c9b9e6f5 --- /dev/null +++ b/tests/datasets/segmentation/test_cityscapes.py @@ -0,0 +1,11 @@ +import pytest + +from torch_uncertainty.datasets.segmentation import Cityscapes + + +class TestCityscapes: + """Testing the Cityscapes dataset class.""" + + def test_nodataset(self): + with pytest.raises(RuntimeError): + _ = Cityscapes("./.data") diff --git a/tests/datasets/test_muad.py b/tests/datasets/test_muad.py new file mode 100644 index 00000000..3a431f3f --- /dev/null +++ b/tests/datasets/test_muad.py @@ -0,0 +1,11 @@ +import pytest + +from torch_uncertainty.datasets import MUAD + + +class TestMUAD: + """Testing the MUAD dataset class.""" + + def test_nodataset(self): + with pytest.raises(FileNotFoundError): + _ = MUAD("./.data", split="train") diff --git a/tests/layers/test_distributions.py b/tests/layers/test_distributions.py new file mode 100644 index 00000000..b1fbf4bd --- /dev/null +++ b/tests/layers/test_distributions.py @@ -0,0 +1,16 @@ +import pytest + +from torch_uncertainty.layers.distributions import ( + LaplaceLayer, + NormalLayer, +) + + +class TestDistributions: + def test_errors(self): + with pytest.raises(ValueError): + NormalLayer(-1, 1) + with pytest.raises(ValueError): + NormalLayer(1, -1) + with pytest.raises(ValueError): + LaplaceLayer(1, -1) diff --git a/tests/layers/test_filter_response_norm.py b/tests/layers/test_filter_response_norm.py index bde2d534..e1f58eb1 100644 --- a/tests/layers/test_filter_response_norm.py +++ b/tests/layers/test_filter_response_norm.py @@ -5,7 +5,7 @@ FilterResponseNorm1d, FilterResponseNorm2d, FilterResponseNorm3d, - FilterResponseNormNd, + _FilterResponseNormNd, ) from torch_uncertainty.layers.mc_batch_norm import ( MCBatchNorm1d, @@ -27,7 +27,7 @@ def test_main(self): def test_errors(self): """Test errors.""" with pytest.raises(ValueError): - FilterResponseNormNd(-1, 1) + _FilterResponseNormNd(-1, 1) with pytest.raises(ValueError): FilterResponseNorm2d(0) with pytest.raises(ValueError): diff --git a/tests/layers/test_mask.py b/tests/layers/test_mask.py index 972d3f7f..bf8e2c2d 100644 --- a/tests/layers/test_mask.py +++ b/tests/layers/test_mask.py @@ -34,7 +34,7 @@ def test_linear_one_estimator(self, feat_input_odd: torch.Tensor): def test_linear_two_estimators_odd(self, feat_input_odd: torch.Tensor): layer = MaskedLinear(10, 2, num_estimators=2, scale=2) - with pytest.raises(Exception): + with pytest.raises(RuntimeError): _ = layer(feat_input_odd) def test_linear_two_estimators_even(self, feat_input_even: torch.Tensor): diff --git a/tests/metrics/classification/__init__.py b/tests/metrics/classification/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/metrics/test_brier_score.py b/tests/metrics/classification/test_brier_score.py similarity index 97% rename from tests/metrics/test_brier_score.py rename to tests/metrics/classification/test_brier_score.py index 559109e3..b0d0b03f 100644 --- a/tests/metrics/test_brier_score.py +++ b/tests/metrics/classification/test_brier_score.py @@ -193,10 +193,12 @@ def test_compute_3d_to_2d( assert metric.compute() == 0.5 def test_bad_input(self) -> None: - with pytest.raises(Exception): + with pytest.raises(ValueError): metric = BrierScore(num_classes=2, reduction="none") metric.update(torch.ones(2, 2, 2, 2), torch.ones(2, 2, 2, 2)) def test_bad_argument(self): - with pytest.raises(Exception): + with pytest.raises( + ValueError, match="Expected argument `reduction` to be one of" + ): _ = BrierScore(num_classes=2, reduction="geometric_mean") diff --git a/tests/metrics/test_calibration.py b/tests/metrics/classification/test_calibration.py similarity index 51% rename from tests/metrics/test_calibration.py rename to tests/metrics/classification/test_calibration.py index 8f807f28..fb3c2035 100644 --- a/tests/metrics/test_calibration.py +++ b/tests/metrics/classification/test_calibration.py @@ -5,41 +5,15 @@ from torch_uncertainty.metrics import CE -@pytest.fixture -def preds_binary() -> torch.Tensor: - return torch.as_tensor([0.25, 0.25, 0.55, 0.75, 0.75]) - - -@pytest.fixture -def targets_binary() -> torch.Tensor: - return torch.as_tensor([0, 0, 1, 1, 1]) - - -@pytest.fixture -def preds_multiclass() -> torch.Tensor: - return torch.as_tensor( - [ - [0.25, 0.20, 0.55], - [0.55, 0.05, 0.40], - [0.10, 0.30, 0.60], - [0.90, 0.05, 0.05], - ] - ) - - -@pytest.fixture -def targets_multiclass() -> torch.Tensor: - return torch.as_tensor([0, 1, 2, 0]) - - class TestCE: """Testing the CE metric class.""" - def test_plot_binary( - self, preds_binary: torch.Tensor, targets_binary: torch.Tensor - ) -> None: + def test_plot_binary(self) -> None: metric = CE(task="binary", n_bins=2, norm="l1") - metric.update(preds_binary, targets_binary) + metric.update( + torch.as_tensor([0.25, 0.25, 0.55, 0.75, 0.75]), + torch.as_tensor([0, 0, 1, 1, 1]), + ) fig, ax = metric.plot() assert isinstance(fig, plt.Figure) assert isinstance(ax, plt.Axes) @@ -48,10 +22,20 @@ def test_plot_binary( plt.close(fig) def test_plot_multiclass( - self, preds_multiclass: torch.Tensor, targets_multiclass: torch.Tensor + self, ) -> None: metric = CE(task="multiclass", n_bins=3, norm="l1", num_classes=3) - metric.update(preds_multiclass, targets_multiclass) + metric.update( + torch.as_tensor( + [ + [0.25, 0.20, 0.55], + [0.55, 0.05, 0.40], + [0.10, 0.30, 0.60], + [0.90, 0.05, 0.05], + ] + ), + torch.as_tensor([0, 1, 2, 0]), + ) fig, ax = metric.plot() assert isinstance(fig, plt.Figure) assert isinstance(ax, plt.Axes) @@ -59,10 +43,9 @@ def test_plot_multiclass( assert ax.get_ylabel() == "Success Rate (%)" plt.close(fig) - def test_bad_task_argument(self) -> None: + def test_errors(self) -> None: with pytest.raises(ValueError): _ = CE(task="geometric_mean") - def test_bad_num_classes_argument(self) -> None: with pytest.raises(ValueError): _ = CE(task="multiclass", num_classes=1.5) diff --git a/tests/metrics/test_disagreement.py b/tests/metrics/classification/test_disagreement.py similarity index 100% rename from tests/metrics/test_disagreement.py rename to tests/metrics/classification/test_disagreement.py diff --git a/tests/metrics/test_entropy.py b/tests/metrics/classification/test_entropy.py similarity index 98% rename from tests/metrics/test_entropy.py rename to tests/metrics/classification/test_entropy.py index 0239e13c..558184c4 100644 --- a/tests/metrics/test_entropy.py +++ b/tests/metrics/classification/test_entropy.py @@ -83,5 +83,5 @@ def test_compute_3d_to_2d(self, vec3d: torch.Tensor): assert res == math.log(2) def test_bad_argument(self): - with pytest.raises(Exception): + with pytest.raises(ValueError): _ = Entropy("geometric_mean") diff --git a/tests/metrics/classification/test_fpr95.py b/tests/metrics/classification/test_fpr95.py new file mode 100644 index 00000000..e94e785c --- /dev/null +++ b/tests/metrics/classification/test_fpr95.py @@ -0,0 +1,37 @@ +import pytest +import torch + +from torch_uncertainty.metrics.classification.fpr95 import FPR95, FPRx + + +class TestFPR95: + """Testing the Entropy metric class.""" + + def test_compute_zero(self): + metric = FPR95(pos_label=1) + metric.update( + torch.as_tensor([1] * 99 + [0.99]), torch.as_tensor([1] * 99 + [0]) + ) + res = metric.compute() + assert res == 0 + + def test_compute_half(self): + metric = FPR95(pos_label=1) + metric.update( + torch.as_tensor([0.9] * 100 + [0.95] * 50 + [0.85] * 50), + torch.as_tensor([1] * 100 + [0] * 100), + ) + res = metric.compute() + assert res == 0.5 + + def test_compute_one(self): + metric = FPR95(pos_label=1) + metric.update( + torch.as_tensor([0.99] * 99 + [1]), torch.as_tensor([1] * 99 + [0]) + ) + res = metric.compute() + assert res == 1 + + def test_error(self): + with pytest.raises(ValueError): + FPRx(recall_level=1.2, pos_label=1) diff --git a/tests/metrics/test_grouping_loss.py b/tests/metrics/classification/test_grouping_loss.py similarity index 58% rename from tests/metrics/test_grouping_loss.py rename to tests/metrics/classification/test_grouping_loss.py index fe1c4376..ffce34b1 100644 --- a/tests/metrics/test_grouping_loss.py +++ b/tests/metrics/classification/test_grouping_loss.py @@ -4,28 +4,33 @@ from torch_uncertainty.metrics import GroupingLoss -@pytest.fixture() -def disagreement_probas_3() -> torch.Tensor: - return torch.as_tensor([[[0.0, 1.0], [0.0, 1.0], [1.0, 0.0]]]) - - class TestGroupingLoss: """Testing the GroupingLoss metric class.""" def test_compute(self): metric = GroupingLoss() metric.update( - torch.ones((100, 4, 10)) / 10, - torch.arange(100), - torch.rand((100, 4, 10)), + torch.cat([torch.tensor([0, 1, 0, 1]), torch.ones(200) / 10]), + torch.cat( + [torch.tensor([0, 0, 1, 1]), torch.zeros(100), torch.ones(100)] + ).long(), + torch.cat([torch.zeros((104, 10)), torch.ones((100, 10))]), + ) + metric.compute() + + metric = GroupingLoss() + metric.update( + torch.ones((200, 4, 10)), + torch.cat([torch.arange(100), torch.arange(100)]), + torch.cat([torch.zeros((100, 4, 10)), torch.ones((100, 4, 10))]), ) metric.compute() metric.reset() metric.update( - torch.ones((100, 10)) / 10, - torch.nn.functional.one_hot(torch.arange(100)), - torch.rand((100, 10)), + torch.ones((200, 10)) / 10, + torch.nn.functional.one_hot(torch.arange(200)), + torch.cat([torch.zeros((100, 10)), torch.ones((1004, 10))]), ) def test_errors(self): diff --git a/tests/metrics/test_mutual_information.py b/tests/metrics/classification/test_mutual_information.py similarity index 97% rename from tests/metrics/test_mutual_information.py rename to tests/metrics/classification/test_mutual_information.py index 14acf90f..99dde31d 100644 --- a/tests/metrics/test_mutual_information.py +++ b/tests/metrics/classification/test_mutual_information.py @@ -55,5 +55,5 @@ def test_compute_mixed( assert res[1] == pytest.approx(math.log(2), 1e-5) def test_bad_argument(self): - with pytest.raises(Exception): + with pytest.raises(ValueError): _ = MutualInformation("geometric_mean") diff --git a/tests/metrics/test_sparsification.py b/tests/metrics/classification/test_sparsification.py similarity index 100% rename from tests/metrics/test_sparsification.py rename to tests/metrics/classification/test_sparsification.py diff --git a/tests/metrics/test_variation_ratio.py b/tests/metrics/classification/test_variation_ratio.py similarity index 97% rename from tests/metrics/test_variation_ratio.py rename to tests/metrics/classification/test_variation_ratio.py index 263b25f9..10936f2d 100644 --- a/tests/metrics/test_variation_ratio.py +++ b/tests/metrics/classification/test_variation_ratio.py @@ -52,5 +52,5 @@ def test_compute_disagreement( assert res == pytest.approx(0.8, 1e-6) def test_bad_argument(self): - with pytest.raises(Exception): + with pytest.raises(ValueError): _ = VariationRatio(reduction="geometric_mean") diff --git a/tests/metrics/regression/__init__.py b/tests/metrics/regression/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/metrics/regression/test_depth_estimation_metrics.py b/tests/metrics/regression/test_depth_estimation_metrics.py new file mode 100644 index 00000000..0c1fbea0 --- /dev/null +++ b/tests/metrics/regression/test_depth_estimation_metrics.py @@ -0,0 +1,112 @@ +import pytest +import torch + +from torch_uncertainty.metrics import ( + Log10, + MeanGTRelativeAbsoluteError, + MeanGTRelativeSquaredError, + MeanSquaredLogError, + SILog, + ThresholdAccuracy, +) + + +class TestLog10: + """Testing the Log10 metric.""" + + def test_main(self): + metric = Log10() + preds = torch.rand((10, 2)).double() + targets = torch.rand((10, 2)).double() + metric.update(preds[:, 0], targets[:, 0]) + metric.update(preds[:, 1], targets[:, 1]) + assert torch.mean( + preds.log10().flatten() - targets.log10().flatten() + ) == pytest.approx(metric.compute()) + + +class TestMeanGTRelativeAbsoluteError: + """Testing the MeanGTRelativeAbsoluteError metric.""" + + def test_main(self): + metric = MeanGTRelativeAbsoluteError() + preds = torch.rand((10, 2)) + targets = torch.rand((10, 2)) + metric.update(preds[:, 0], targets[:, 0]) + metric.update(preds[:, 1], targets[:, 1]) + assert (torch.abs(preds - targets) / targets).mean() == pytest.approx( + metric.compute() + ) + + +class TestMeanGTRelativeSquaredError: + """Testing the MeanGTRelativeSquaredError metric.""" + + def test_main(self): + metric = MeanGTRelativeSquaredError() + preds = torch.rand((10, 2)) + targets = torch.rand((10, 2)) + metric.update(preds[:, 0], targets[:, 0]) + metric.update(preds[:, 1], targets[:, 1]) + assert torch.flatten( + (preds - targets) ** 2 / targets + ).mean() == pytest.approx(metric.compute()) + + +class TestSILog: + """Testing the SILog metric.""" + + def test_main(self): + metric = SILog() + preds = torch.rand((10, 2)).double() + targets = torch.rand((10, 2)).double() + metric.update(preds[:, 0], targets[:, 0]) + metric.update(preds[:, 1], targets[:, 1]) + mean_log_dists = torch.mean( + targets.flatten().log() - preds.flatten().log() + ) + assert torch.mean( + (preds.flatten().log() - targets.flatten().log() + mean_log_dists) + ** 2 + ) == pytest.approx(metric.compute()) + + +class TestThresholdAccuracy: + """Testing the ThresholdAccuracy metric.""" + + def test_main(self): + metric = ThresholdAccuracy(power=1, lmbda=1.25) + preds = torch.ones((10, 2)) + targets = torch.ones((10, 2)) * 1.3 + metric.update(preds[:, 0], targets[:, 0]) + metric.update(preds[:, 1], targets[:, 1]) + assert metric.compute() == 0.0 + + metric = ThresholdAccuracy(power=1, lmbda=1.25) + preds = torch.cat( + [torch.ones((10, 2)) * 1.2, torch.ones((10, 2))], dim=0 + ) + targets = torch.ones((20, 2)) * 1.3 + metric.update(preds[:, 0], targets[:, 0]) + metric.update(preds[:, 1], targets[:, 1]) + assert metric.compute() == 0.5 + + def test_error(self): + with pytest.raises(ValueError, match="Power must be"): + ThresholdAccuracy(power=-1) + with pytest.raises(ValueError, match="Lambda must be"): + ThresholdAccuracy(power=1, lmbda=0.5) + + +class TestMeanSquaredLogError: + """Testing the MeanSquaredLogError metric.""" + + def test_main(self): + metric = MeanSquaredLogError() + preds = torch.rand((10, 2)).double() + targets = torch.rand((10, 2)).double() + metric.update(preds[:, 0], targets[:, 0]) + metric.update(preds[:, 1], targets[:, 1]) + assert torch.mean( + (preds.log() - targets.log()).flatten() ** 2 + ) == pytest.approx(metric.compute()) diff --git a/tests/metrics/regression/test_nll.py b/tests/metrics/regression/test_nll.py new file mode 100644 index 00000000..316d6c93 --- /dev/null +++ b/tests/metrics/regression/test_nll.py @@ -0,0 +1,56 @@ +import pytest +import torch +from torch.distributions import Normal + +from torch_uncertainty.metrics import CategoricalNLL, DistributionNLL + + +class TestCategoricalNegativeLogLikelihood: + """Testing the CategoricalNLL metric class.""" + + def test_compute_zero(self) -> None: + probs = torch.as_tensor([[1, 0.0], [0.0, 1.0]]) + targets = torch.as_tensor([0, 1]) + + metric = CategoricalNLL() + metric.update(probs, targets) + res = metric.compute() + assert res == 0 + + metric = CategoricalNLL(reduction="none") + metric.update(probs, targets) + res_sum = metric.compute() + assert torch.all(res_sum == torch.zeros(2)) + + metric = CategoricalNLL(reduction="sum") + metric.update(probs, targets) + res_sum = metric.compute() + assert torch.all(res_sum == torch.zeros(1)) + + def test_bad_argument(self) -> None: + with pytest.raises(ValueError): + _ = CategoricalNLL(reduction="geometric_mean") + + +class TestDistributionNLL: + """Testing the TestDistributionNLL metric class.""" + + def test_compute_zero(self) -> None: + metric = DistributionNLL(reduction="mean") + means = torch.as_tensor([1, 10]).float() + stds = torch.as_tensor([1, 2]).float() + targets = torch.as_tensor([1, 10]).float() + dist = Normal(means, stds) + metric.update(dist, targets) + res_mean = metric.compute() + assert res_mean == torch.mean(torch.log(2 * torch.pi * (stds**2)) / 2) + + metric = DistributionNLL(reduction="sum") + metric.update(dist, targets) + res_sum = metric.compute() + assert res_sum == torch.log(2 * torch.pi * (stds**2)).sum() / 2 + + metric = DistributionNLL(reduction="none") + metric.update(dist, targets) + res_all = metric.compute() + assert torch.all(res_all == torch.log(2 * torch.pi * (stds**2)) / 2) diff --git a/tests/metrics/test_fpr95.py b/tests/metrics/test_fpr95.py deleted file mode 100644 index 6015d484..00000000 --- a/tests/metrics/test_fpr95.py +++ /dev/null @@ -1,62 +0,0 @@ -import pytest -import torch - -from torch_uncertainty.metrics import FPR95 - - -@pytest.fixture() -def confs_zero() -> torch.Tensor: - return torch.as_tensor([1] * 99 + [0.99]) - - -@pytest.fixture() -def target_zero() -> torch.Tensor: - return torch.as_tensor([1] * 99 + [0]) - - -@pytest.fixture() -def confs_half() -> torch.Tensor: - return torch.as_tensor([0.9] * 100 + [0.95] * 50 + [0.85] * 50) - - -@pytest.fixture() -def target_half() -> torch.Tensor: - return torch.as_tensor([1] * 100 + [0] * 100) - - -@pytest.fixture() -def confs_one() -> torch.Tensor: - return torch.as_tensor([0.99] * 99 + [1]) - - -@pytest.fixture() -def target_one() -> torch.Tensor: - return torch.as_tensor([1] * 99 + [0]) - - -class TestFPR95: - """Testing the Entropy metric class.""" - - def test_compute_zero( - self, confs_zero: torch.Tensor, target_zero: torch.Tensor - ): - metric = FPR95(pos_label=1) - metric.update(confs_zero, target_zero) - res = metric.compute() - assert res == 0 - - def test_compute_half( - self, confs_half: torch.Tensor, target_half: torch.Tensor - ): - metric = FPR95(pos_label=1) - metric.update(confs_half, target_half) - res = metric.compute() - assert res == 0.5 - - def test_compute_one( - self, confs_one: torch.Tensor, target_one: torch.Tensor - ): - metric = FPR95(pos_label=1) - metric.update(confs_one, target_one) - res = metric.compute() - assert res == 1 diff --git a/tests/metrics/test_nll.py b/tests/metrics/test_nll.py deleted file mode 100644 index 2ec65a6b..00000000 --- a/tests/metrics/test_nll.py +++ /dev/null @@ -1,61 +0,0 @@ -import pytest -import torch - -from torch_uncertainty.metrics import ( - GaussianNegativeLogLikelihood, - NegativeLogLikelihood, -) - - -@pytest.fixture() -def probs_zero() -> torch.Tensor: - return torch.as_tensor([[1, 0.0], [0.0, 1.0]]) - - -@pytest.fixture() -def targets_zero() -> torch.Tensor: - return torch.as_tensor([0, 1]) - - -class TestNegativeLogLikelihood: - """Testing the NegativeLogLikelihood metric class.""" - - def test_compute_zero( - self, probs_zero: torch.Tensor, targets_zero: torch.Tensor - ) -> None: - metric = NegativeLogLikelihood() - metric.update(probs_zero, targets_zero) - res = metric.compute() - assert res == 0 - - metric = NegativeLogLikelihood(reduction="none") - metric.update(probs_zero, targets_zero) - res_sum = metric.compute() - assert torch.all(res_sum == torch.zeros(2)) - - def test_bad_argument(self) -> None: - with pytest.raises(Exception): - _ = NegativeLogLikelihood(reduction="geometric_mean") - - -class TestGaussianNegativeLogLikelihood: - """Testing the NegativeLogLikelihood metric class.""" - - def test_compute_zero(self) -> None: - metric = GaussianNegativeLogLikelihood() - means = torch.as_tensor([1, 10]).float() - variances = torch.as_tensor([1, 2]).float() - targets = torch.as_tensor([1, 10]).float() - metric.update(means, targets, variances) - res_mean = metric.compute() - assert res_mean == torch.log(variances).mean() / 2 - - metric = GaussianNegativeLogLikelihood(reduction="sum") - metric.update(means, targets, variances) - res_sum = metric.compute() - assert res_sum == torch.log(variances).sum() / 2 - - metric = GaussianNegativeLogLikelihood(reduction="none") - metric.update(means, targets, variances) - res_sum = metric.compute() - assert torch.all(res_sum == torch.log(variances) / 2) diff --git a/tests/models/test_deep_ensembles.py b/tests/models/test_deep_ensembles.py index 7d791414..98e45c08 100644 --- a/tests/models/test_deep_ensembles.py +++ b/tests/models/test_deep_ensembles.py @@ -9,30 +9,30 @@ class TestDeepEnsemblesModel: """Testing the deep_ensembles function.""" def test_main(self): - model_1 = dummy_model(1, 10, 1) - model_2 = dummy_model(1, 10, 1) + model_1 = dummy_model(1, 10) + model_2 = dummy_model(1, 10) de = deep_ensembles([model_1, model_2]) # Check B N C assert de(torch.randn(3, 4, 4)).shape == (6, 10) def test_list_and_num_estimators(self): - model_1 = dummy_model(1, 10, 1) - model_2 = dummy_model(1, 10, 1) + model_1 = dummy_model(1, 10) + model_2 = dummy_model(1, 10) with pytest.raises(ValueError): deep_ensembles([model_1, model_2], num_estimators=2) def test_list_singleton(self): - model_1 = dummy_model(1, 10, 1) + model_1 = dummy_model(1, 10) - deep_ensembles([model_1], num_estimators=2) - deep_ensembles(model_1, num_estimators=2) + deep_ensembles([model_1], num_estimators=2, reset_model_parameters=True) + deep_ensembles(model_1, num_estimators=2, reset_model_parameters=False) with pytest.raises(ValueError): deep_ensembles([model_1], num_estimators=1) - def test_model_and_no_num_estimator(self): - model_1 = dummy_model(1, 10, 1) + def test_errors(self): + model_1 = dummy_model(1, 10) with pytest.raises(ValueError): deep_ensembles(model_1, num_estimators=None) @@ -41,3 +41,9 @@ def test_model_and_no_num_estimator(self): with pytest.raises(ValueError): deep_ensembles(model_1, num_estimators=1) + + with pytest.raises(ValueError): + deep_ensembles(model_1, num_estimators=2, task="regression") + + with pytest.raises(ValueError): + deep_ensembles(model_1, num_estimators=2, task="other") diff --git a/tests/models/test_mc_dropout.py b/tests/models/test_mc_dropout.py index 59f084af..b0cd9327 100644 --- a/tests/models/test_mc_dropout.py +++ b/tests/models/test_mc_dropout.py @@ -9,7 +9,7 @@ class TestMCDropout: """Testing the MC Dropout class.""" def test_mc_dropout_train(self): - model = dummy_model(10, 5, 1, 0.1) + model = dummy_model(10, 5, 0.1) dropout_model = mc_dropout(model, num_estimators=5) dropout_model.train() assert dropout_model.training @@ -21,14 +21,14 @@ def test_mc_dropout_train(self): dropout_model(torch.rand(1, 10)) def test_mc_dropout_eval(self): - model = dummy_model(10, 5, 1, 0.1) + model = dummy_model(10, 5, 0.1) dropout_model = mc_dropout(model, num_estimators=5) dropout_model.eval() assert not dropout_model.training dropout_model(torch.rand(1, 10)) def test_mc_dropout_errors(self): - model = dummy_model(10, 5, 1, 0.1) + model = dummy_model(10, 5, 0.1) with pytest.raises(ValueError): _MCDropout(model=model, num_estimators=-1, last_layer=True) @@ -47,10 +47,10 @@ def test_mc_dropout_errors(self): with pytest.raises(ValueError): dropout_model = mc_dropout(model, 5) - model = dummy_model(10, 5, 1, 0.1) + model = dummy_model(10, 5, 0.1) with pytest.raises(ValueError): dropout_model = mc_dropout(model, None) - model = dummy_model(10, 5, 1, dropout_rate=0) + model = dummy_model(10, 5, dropout_rate=0) with pytest.raises(ValueError): dropout_model = mc_dropout(model, None) diff --git a/tests/models/test_mlps.py b/tests/models/test_mlps.py index afb56d8d..2e7a72e8 100644 --- a/tests/models/test_mlps.py +++ b/tests/models/test_mlps.py @@ -1,11 +1,18 @@ -from torch_uncertainty.models.mlp import bayesian_mlp, packed_mlp +from torch_uncertainty.layers.distributions import NormalLayer +from torch_uncertainty.models.mlp import bayesian_mlp, mlp, packed_mlp class TestMLPModel: """Testing the mlp models.""" - def test_packed(self): + def test_mlps(self): + mlp( + 1, + 1, + hidden_dims=[1, 1, 1], + final_layer=NormalLayer, + final_layer_args={"dim": 1}, + ) + mlp(1, 1, hidden_dims=[]) packed_mlp(1, 1, hidden_dims=[]) - - def test_bayesian(self): bayesian_mlp(1, 1, hidden_dims=[1, 1, 1]) diff --git a/tests/models/test_segformer.py b/tests/models/test_segformer.py new file mode 100644 index 00000000..f69439d8 --- /dev/null +++ b/tests/models/test_segformer.py @@ -0,0 +1,25 @@ +import torch + +from torch_uncertainty.models.segmentation.segformer import ( + seg_former_b0, + seg_former_b1, + seg_former_b2, + seg_former_b3, + seg_former_b4, + seg_former_b5, +) + + +class TestSegformer: + """Testing the Segformer class.""" + + def test_main(self): + seg_former_b1(10) + seg_former_b2(10) + seg_former_b3(10) + seg_former_b4(10) + seg_former_b5(10) + + model = seg_former_b0(10) + with torch.no_grad(): + model(torch.randn(1, 3, 32, 32)) diff --git a/tests/routines/test_classification.py b/tests/routines/test_classification.py index 2fc02163..9b22d898 100644 --- a/tests/routines/test_classification.py +++ b/tests/routines/test_classification.py @@ -1,389 +1,400 @@ -from functools import partial from pathlib import Path import pytest -from cli_test_helpers import ArgvContext from torch import nn from tests._dummies import ( DummyClassificationBaseline, DummyClassificationDataModule, - DummyClassificationDataset, dummy_model, ) -from torch_uncertainty import cli_main, init_args from torch_uncertainty.losses import DECLoss, ELBOLoss -from torch_uncertainty.optimization_procedures import optim_cifar10_resnet18 -from torch_uncertainty.routines.classification import ( - ClassificationEnsemble, - ClassificationSingle, -) - - -class TestClassificationSingle: - """Testing the classification routine with a single model.""" - - def test_cli_main_dummy_binary(self): - root = Path(__file__).parent.absolute().parents[0] - with ArgvContext("file.py"): - args = init_args( - DummyClassificationBaseline, DummyClassificationDataModule - ) - - args.root = str(root / "data") - args.eval_grouping_loss = True - dm = DummyClassificationDataModule( - num_classes=1, num_images=100, **vars(args) - ) - - model = DummyClassificationBaseline( - num_classes=dm.num_classes, - in_channels=dm.num_channels, - loss=nn.BCEWithLogitsLoss, - optimization_procedure=optim_cifar10_resnet18, - baseline_type="single", - **vars(args), - ) - cli_main(model, dm, root, "logs/dummy", args) - - with ArgvContext("file.py", "--logits"): - args = init_args( - DummyClassificationBaseline, DummyClassificationDataModule - ) - - args.root = str(root / "data") - dm = DummyClassificationDataModule(num_classes=1, **vars(args)) - - model = DummyClassificationBaseline( - num_classes=dm.num_classes, - in_channels=dm.num_channels, - loss=nn.BCEWithLogitsLoss, - optimization_procedure=optim_cifar10_resnet18, - baseline_type="single", - **vars(args), - ) - cli_main(model, dm, root, "logs/dummy", args) - - def test_cli_main_dummy_ood(self): - root = Path(__file__).parent.absolute().parents[0] - with ArgvContext("file.py", "--fast_dev_run"): - args = init_args( - DummyClassificationBaseline, DummyClassificationDataModule - ) - - args.root = str(root / "data") - dm = DummyClassificationDataModule(**vars(args)) - loss = partial( - ELBOLoss, - criterion=nn.CrossEntropyLoss(), - kl_weight=1e-5, - num_samples=2, - ) - model = DummyClassificationBaseline( - num_classes=dm.num_classes, - in_channels=dm.num_channels, - loss=loss, - optimization_procedure=optim_cifar10_resnet18, - baseline_type="single", - **vars(args), - ) - cli_main(model, dm, root, "logs/dummy", args) - - with ArgvContext( - "file.py", - "--eval-ood", - "--entropy", - ): - args = init_args( - DummyClassificationBaseline, DummyClassificationDataModule - ) - - args.root = str(root / "data") - dm = DummyClassificationDataModule(**vars(args)) - - model = DummyClassificationBaseline( - num_classes=dm.num_classes, - in_channels=dm.num_channels, - loss=DECLoss, - optimization_procedure=optim_cifar10_resnet18, - baseline_type="single", - **vars(args), - ) - cli_main(model, dm, root, "logs/dummy", args) - - with ArgvContext( - "file.py", - "--eval-ood", - "--entropy", - "--cutmix_alpha", - "0.5", - "--mixtype", - "timm", - ): - args = init_args( - DummyClassificationBaseline, DummyClassificationDataModule - ) - - args.root = str(root / "data") - dm = DummyClassificationDataModule(**vars(args)) - - model = DummyClassificationBaseline( - num_classes=dm.num_classes, - in_channels=dm.num_channels, - loss=DECLoss, - optimization_procedure=optim_cifar10_resnet18, - baseline_type="single", - **vars(args), - ) - with pytest.raises(NotImplementedError): - cli_main(model, dm, root, "logs/dummy", args) - - def test_cli_main_dummy_mixup_ts_cv(self): - root = Path(__file__).parent.absolute().parents[0] - with ArgvContext( - "file.py", - "--mixtype", - "kernel_warping", - "--mixup_alpha", - "1.", - "--dist_sim", - "inp", - "--val_temp_scaling", - "--use_cv", - ): - args = init_args( - DummyClassificationBaseline, DummyClassificationDataModule - ) - - args.root = str(root / "data") - dm = DummyClassificationDataModule(num_classes=10, **vars(args)) - dm.dataset = ( - lambda root, - num_channels, - num_classes, - image_size, - transform, - num_images: DummyClassificationDataset( - root, - num_channels=num_channels, - num_classes=num_classes, - image_size=image_size, - transform=transform, - num_images=20, - ) - ) - - list_dm = dm.make_cross_val_splits(2, 1) - list_model = [ - DummyClassificationBaseline( - num_classes=list_dm[i].dm.num_classes, - in_channels=list_dm[i].dm.num_channels, - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet18, - baseline_type="single", - calibration_set=dm.get_val_set, - **vars(args), - ) - for i in range(len(list_dm)) - ] - - cli_main(list_model, list_dm, root, "logs/dummy", args) - - with ArgvContext( - "file.py", - "--mixtype", - "kernel_warping", - "--mixup_alpha", - "1.", - "--dist_sim", - "emb", - "--val_temp_scaling", - "--use_cv", - ): - args = init_args( - DummyClassificationBaseline, DummyClassificationDataModule - ) - - args.root = str(root / "data") - dm = DummyClassificationDataModule(num_classes=10, **vars(args)) - dm.dataset = ( - lambda root, - num_channels, - num_classes, - image_size, - transform, - num_images: DummyClassificationDataset( - root, - num_channels=num_channels, - num_classes=num_classes, - image_size=image_size, - transform=transform, - num_images=20, - ) - ) - - list_dm = dm.make_cross_val_splits(2, 1) - list_model = [] - for i in range(len(list_dm)): - list_model.append( - DummyClassificationBaseline( - num_classes=list_dm[i].dm.num_classes, - in_channels=list_dm[i].dm.num_channels, - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet18, - baseline_type="single", - calibration_set=dm.get_val_set, - **vars(args), - ) - ) - - cli_main(list_model, list_dm, root, "logs/dummy", args) +from torch_uncertainty.optim_recipes import optim_cifar10_resnet18 +from torch_uncertainty.routines import ClassificationRoutine +from torch_uncertainty.utils import TUTrainer + + +class TestClassification: + """Testing the classification routine.""" + + def test_one_estimator_binary(self): + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) + + dm = DummyClassificationDataModule( + root=Path(), + batch_size=16, + num_classes=1, + num_images=100, + ) + model = DummyClassificationBaseline( + in_channels=dm.num_channels, + num_classes=dm.num_classes, + loss=nn.BCEWithLogitsLoss(), + optim_recipe=optim_cifar10_resnet18, + baseline_type="single", + ood_criterion="msp", + ) + + trainer.fit(model, dm) + trainer.validate(model, dm) + trainer.test(model, dm) + model(dm.get_test_set()[0][0]) + + def test_two_estimators_binary(self): + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) + + dm = DummyClassificationDataModule( + root=Path(), + batch_size=16, + num_classes=1, + num_images=100, + ) + model = DummyClassificationBaseline( + in_channels=dm.num_channels, + num_classes=dm.num_classes, + loss=nn.BCEWithLogitsLoss(), + optim_recipe=optim_cifar10_resnet18, + baseline_type="single", + ood_criterion="logit", + ) + + trainer.fit(model, dm) + trainer.validate(model, dm) + trainer.test(model, dm) + model(dm.get_test_set()[0][0]) + + def test_one_estimator_two_classes(self): + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) + + dm = DummyClassificationDataModule( + root=Path(), + batch_size=16, + num_classes=2, + num_images=100, + eval_ood=True, + ) + model = DummyClassificationBaseline( + num_classes=dm.num_classes, + in_channels=dm.num_channels, + loss=nn.CrossEntropyLoss(), + optim_recipe=optim_cifar10_resnet18, + baseline_type="single", + ood_criterion="entropy", + eval_ood=True, + ) + + trainer.fit(model, dm) + trainer.validate(model, dm) + trainer.test(model, dm) + model(dm.get_test_set()[0][0]) + + def test_one_estimator_two_classes_timm(self): + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) + + dm = DummyClassificationDataModule( + root=Path(), + batch_size=16, + num_classes=2, + num_images=100, + eval_ood=True, + ) + model = DummyClassificationBaseline( + num_classes=dm.num_classes, + in_channels=dm.num_channels, + loss=nn.CrossEntropyLoss(), + optim_recipe=optim_cifar10_resnet18, + baseline_type="single", + ood_criterion="entropy", + eval_ood=True, + mixtype="timm", + mixup_alpha=1.0, + cutmix_alpha=0.5, + ) + + trainer.fit(model, dm) + trainer.validate(model, dm) + trainer.test(model, dm) + model(dm.get_test_set()[0][0]) + + def test_one_estimator_two_classes_mixup(self): + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) + + dm = DummyClassificationDataModule( + root=Path(), + batch_size=16, + num_classes=2, + num_images=100, + eval_ood=True, + ) + model = DummyClassificationBaseline( + num_classes=dm.num_classes, + in_channels=dm.num_channels, + loss=nn.CrossEntropyLoss(), + optim_recipe=optim_cifar10_resnet18, + baseline_type="single", + ood_criterion="entropy", + eval_ood=True, + mixtype="mixup", + mixup_alpha=1.0, + ) + + trainer.fit(model, dm) + trainer.validate(model, dm) + trainer.test(model, dm) + model(dm.get_test_set()[0][0]) + + def test_one_estimator_two_classes_mixup_io(self): + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) + + dm = DummyClassificationDataModule( + root=Path(), + batch_size=16, + num_classes=2, + num_images=100, + eval_ood=True, + ) + model = DummyClassificationBaseline( + num_classes=dm.num_classes, + in_channels=dm.num_channels, + loss=nn.CrossEntropyLoss(), + optim_recipe=optim_cifar10_resnet18, + baseline_type="single", + ood_criterion="entropy", + eval_ood=True, + mixtype="mixup_io", + mixup_alpha=1.0, + ) + + trainer.fit(model, dm) + trainer.validate(model, dm) + trainer.test(model, dm) + model(dm.get_test_set()[0][0]) + + def test_one_estimator_two_classes_regmixup(self): + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) + + dm = DummyClassificationDataModule( + root=Path(), + batch_size=16, + num_classes=2, + num_images=100, + eval_ood=True, + ) + model = DummyClassificationBaseline( + num_classes=dm.num_classes, + in_channels=dm.num_channels, + loss=nn.CrossEntropyLoss(), + optim_recipe=optim_cifar10_resnet18, + baseline_type="single", + ood_criterion="entropy", + eval_ood=True, + mixtype="regmixup", + mixup_alpha=1.0, + ) + + trainer.fit(model, dm) + trainer.validate(model, dm) + trainer.test(model, dm) + model(dm.get_test_set()[0][0]) + + def test_one_estimator_two_classes_kernel_warping_emb(self): + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) + + dm = DummyClassificationDataModule( + root=Path(), + batch_size=16, + num_classes=2, + num_images=100, + eval_ood=True, + ) + model = DummyClassificationBaseline( + num_classes=dm.num_classes, + in_channels=dm.num_channels, + loss=nn.CrossEntropyLoss(), + optim_recipe=optim_cifar10_resnet18, + baseline_type="single", + ood_criterion="entropy", + eval_ood=True, + mixtype="kernel_warping", + mixup_alpha=0.5, + ) + + trainer.fit(model, dm) + trainer.validate(model, dm) + trainer.test(model, dm) + model(dm.get_test_set()[0][0]) + + def test_one_estimator_two_classes_kernel_warping_inp(self): + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) + + dm = DummyClassificationDataModule( + root=Path(), + batch_size=16, + num_classes=2, + num_images=100, + eval_ood=True, + ) + model = DummyClassificationBaseline( + num_classes=dm.num_classes, + in_channels=dm.num_channels, + loss=nn.CrossEntropyLoss(), + optim_recipe=optim_cifar10_resnet18, + baseline_type="single", + ood_criterion="entropy", + eval_ood=True, + mixtype="kernel_warping", + dist_sim="inp", + mixup_alpha=0.5, + ) + + trainer.fit(model, dm) + trainer.validate(model, dm) + trainer.test(model, dm) + model(dm.get_test_set()[0][0]) + + def test_one_estimator_two_classes_calibrated_with_ood(self): + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True, logger=True) + + dm = DummyClassificationDataModule( + root=Path(), + batch_size=19, # lower than 19 it doesn't work :'( + num_classes=2, + num_images=100, + eval_ood=True, + ) + model = DummyClassificationBaseline( + num_classes=dm.num_classes, + in_channels=dm.num_channels, + loss=nn.CrossEntropyLoss(), + optim_recipe=optim_cifar10_resnet18, + baseline_type="single", + ood_criterion="energy", + eval_ood=True, + eval_grouping_loss=True, + calibrate=True, + ) + + trainer.fit(model, dm) + trainer.validate(model, dm) + trainer.test(model, dm) + model(dm.get_test_set()[0][0]) + + def test_two_estimators_two_classes_mi(self): + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) + + dm = DummyClassificationDataModule( + root=Path(), + batch_size=16, + num_classes=2, + num_images=100, + eval_ood=True, + ) + model = DummyClassificationBaseline( + num_classes=dm.num_classes, + in_channels=dm.num_channels, + loss=DECLoss(1, 1e-2), + optim_recipe=optim_cifar10_resnet18, + baseline_type="ensemble", + ood_criterion="mi", + eval_ood=True, + ) + + trainer.fit(model, dm) + trainer.validate(model, dm) + trainer.test(model, dm) + model(dm.get_test_set()[0][0]) + + def test_two_estimator_two_classes_elbo_vr_logs(self): + trainer = TUTrainer( + accelerator="cpu", + max_epochs=1, + limit_train_batches=1, + limit_val_batches=1, + limit_test_batches=1, + enable_checkpointing=False, + ) + + dm = DummyClassificationDataModule( + root=Path(), + batch_size=16, + num_classes=2, + num_images=100, + eval_ood=True, + ) + model = DummyClassificationBaseline( + num_classes=dm.num_classes, + in_channels=dm.num_channels, + loss=ELBOLoss( + None, nn.CrossEntropyLoss(), kl_weight=1.0, num_samples=4 + ), + optim_recipe=optim_cifar10_resnet18, + baseline_type="ensemble", + ood_criterion="vr", + eval_ood=True, + save_in_csv=True, + ) + + trainer.fit(model, dm) + trainer.validate(model, dm) + trainer.test(model, dm) + model(dm.get_test_set()[0][0]) def test_classification_failures(self): + # num_estimators with pytest.raises(ValueError): - ClassificationSingle( - 10, nn.Module(), None, None, use_entropy=True, use_logits=True + ClassificationRoutine( + num_classes=10, model=nn.Module(), loss=None, num_estimators=-1 ) - + # num_classes with pytest.raises(ValueError): - ClassificationSingle(10, nn.Module(), None, None, cutmix_alpha=-1) - + ClassificationRoutine(num_classes=0, model=nn.Module(), loss=None) + # single & MI with pytest.raises(ValueError): - ClassificationSingle( - 10, nn.Module(), None, None, eval_grouping_loss=True + ClassificationRoutine( + num_classes=10, + model=nn.Module(), + loss=None, + num_estimators=1, + ood_criterion="mi", ) - - model = dummy_model(1, 1, 1, 0, with_feats=False, with_linear=True) - - with pytest.raises(ValueError): - ClassificationSingle(10, model, None, None, eval_grouping_loss=True) - - model = dummy_model(1, 1, 1, 0, with_feats=True, with_linear=False) - with pytest.raises(ValueError): - ClassificationSingle(10, model, None, None, eval_grouping_loss=True) - - -class TestClassificationEnsemble: - """Testing the classification routine with an ensemble model.""" - - def test_cli_main_dummy_binary(self): - root = Path(__file__).parent.absolute().parents[0] - with ArgvContext("file.py"): - args = init_args( - DummyClassificationBaseline, DummyClassificationDataModule - ) - - # datamodule - args.root = str(root / "data") - dm = DummyClassificationDataModule(num_classes=1, **vars(args)) - loss = partial( - ELBOLoss, - criterion=nn.CrossEntropyLoss(), - kl_weight=1e-5, - num_samples=1, - ) - model = DummyClassificationBaseline( - num_classes=dm.num_classes, - in_channels=dm.num_channels, - loss=loss, - optimization_procedure=optim_cifar10_resnet18, - baseline_type="ensemble", - **vars(args), - ) - - cli_main(model, dm, root, "logs/dummy", args) - - with ArgvContext("file.py", "--mutual_information"): - args = init_args( - DummyClassificationBaseline, DummyClassificationDataModule - ) - - # datamodule - args.root = str(root / "data") - dm = DummyClassificationDataModule(num_classes=1, **vars(args)) - - model = DummyClassificationBaseline( - num_classes=dm.num_classes, - in_channels=dm.num_channels, - loss=nn.BCEWithLogitsLoss, - optimization_procedure=optim_cifar10_resnet18, - baseline_type="ensemble", - **vars(args), - ) - - cli_main(model, dm, root, "logs/dummy", args) - - def test_cli_main_dummy_ood(self): - root = Path(__file__).parent.absolute().parents[0] - with ArgvContext("file.py", "--logits"): - args = init_args( - DummyClassificationBaseline, DummyClassificationDataModule - ) - - # datamodule - args.root = str(root / "data") - dm = DummyClassificationDataModule(**vars(args)) - - model = DummyClassificationBaseline( - num_classes=dm.num_classes, - in_channels=dm.num_channels, - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet18, - baseline_type="ensemble", - **vars(args), + ClassificationRoutine( + num_classes=10, + model=nn.Module(), + loss=None, + ood_criterion="other", ) - cli_main(model, dm, root, "logs/dummy", args) - - with ArgvContext("file.py", "--eval-ood", "--entropy"): - args = init_args( - DummyClassificationBaseline, DummyClassificationDataModule + with pytest.raises(ValueError): + ClassificationRoutine( + num_classes=10, model=nn.Module(), loss=None, cutmix_alpha=-1 ) - # datamodule - args.root = str(root / "data") - dm = DummyClassificationDataModule(**vars(args)) - - model = DummyClassificationBaseline( - num_classes=dm.num_classes, - in_channels=dm.num_channels, - loss=DECLoss, - optimization_procedure=optim_cifar10_resnet18, - baseline_type="ensemble", - **vars(args), + with pytest.raises(ValueError): + ClassificationRoutine( + num_classes=10, + model=nn.Module(), + loss=None, + eval_grouping_loss=True, ) - cli_main(model, dm, root, "logs/dummy", args) - - with ArgvContext("file.py", "--eval-ood", "--variation_ratio"): - args = init_args( - DummyClassificationBaseline, DummyClassificationDataModule + with pytest.raises(NotImplementedError): + ClassificationRoutine( + num_classes=10, + model=nn.Module(), + loss=None, + num_estimators=2, + eval_grouping_loss=True, ) - # datamodule - args.root = str(root / "data") - dm = DummyClassificationDataModule(**vars(args)) - - model = DummyClassificationBaseline( - num_classes=dm.num_classes, - in_channels=dm.num_channels, - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet18, - baseline_type="ensemble", - **vars(args), + model = dummy_model(1, 1, 0, with_feats=False, with_linear=True) + with pytest.raises(ValueError): + ClassificationRoutine( + num_classes=10, model=model, loss=None, eval_grouping_loss=True ) - cli_main(model, dm, root, "logs/dummy", args) - - def test_classification_failures(self): + model = dummy_model(1, 1, 0, with_feats=True, with_linear=False) with pytest.raises(ValueError): - ClassificationEnsemble( - 10, - nn.Module(), - None, - None, - 2, - use_entropy=True, - use_variation_ratio=True, + ClassificationRoutine( + num_classes=10, model=model, loss=None, eval_grouping_loss=True ) diff --git a/tests/routines/test_regression.py b/tests/routines/test_regression.py index 193e01f5..2c7eb469 100644 --- a/tests/routines/test_regression.py +++ b/tests/routines/test_regression.py @@ -1,160 +1,153 @@ -from functools import partial from pathlib import Path import pytest -from cli_test_helpers import ArgvContext from torch import nn from tests._dummies import DummyRegressionBaseline, DummyRegressionDataModule -from torch_uncertainty import cli_main, init_args -from torch_uncertainty.losses import BetaNLL, NIGLoss -from torch_uncertainty.optimization_procedures import optim_cifar10_resnet18 - - -class TestRegressionSingle: - """Testing the Regression routine with a single model.""" - - def test_cli_main_dummy_dist(self): - root = Path(__file__).parent.absolute().parents[0] - with ArgvContext("file.py"): - args = init_args(DummyRegressionBaseline, DummyRegressionDataModule) - - # datamodule - args.root = str(root / "data") - dm = DummyRegressionDataModule(out_features=1, **vars(args)) - - model = DummyRegressionBaseline( - in_features=dm.in_features, - out_features=2, - loss=nn.GaussianNLLLoss, - optimization_procedure=optim_cifar10_resnet18, - baseline_type="single", - dist_estimation=2, - **vars(args), - ) - - cli_main(model, dm, root, "logs/dummy", args) - - def test_cli_main_dummy_dist_der(self): - root = Path(__file__).parent.absolute().parents[0] - with ArgvContext("file.py"): - args = init_args(DummyRegressionBaseline, DummyRegressionDataModule) - - # datamodule - args.root = str(root / "data") - dm = DummyRegressionDataModule(out_features=1, **vars(args)) - - loss = partial( - NIGLoss, - reg_weight=1e-2, - ) - - model = DummyRegressionBaseline( - in_features=dm.in_features, - out_features=4, - loss=loss, - optimization_procedure=optim_cifar10_resnet18, - baseline_type="single", - dist_estimation=4, - **vars(args), - ) - - cli_main(model, dm, root, "logs/dummy_der", args) - - def test_cli_main_dummy_dist_betanll(self): - root = Path(__file__).parent.absolute().parents[0] - with ArgvContext("file.py"): - args = init_args(DummyRegressionBaseline, DummyRegressionDataModule) - - # datamodule - args.root = str(root / "data") - dm = DummyRegressionDataModule(out_features=1, **vars(args)) - - loss = partial( - BetaNLL, - beta=0.5, - ) - - model = DummyRegressionBaseline( - in_features=dm.in_features, - out_features=2, - loss=loss, - optimization_procedure=optim_cifar10_resnet18, - baseline_type="single", - dist_estimation=2, - **vars(args), - ) - - cli_main(model, dm, root, "logs/dummy_betanll", args) - - def test_cli_main_dummy(self): - root = Path(__file__).parent.absolute().parents[0] - with ArgvContext("file.py"): - args = init_args(DummyRegressionBaseline, DummyRegressionDataModule) - - # datamodule - args.root = str(root / "data") - dm = DummyRegressionDataModule(out_features=2, **vars(args)) - - model = DummyRegressionBaseline( - in_features=dm.in_features, - out_features=dm.out_features, - loss=nn.MSELoss, - optimization_procedure=optim_cifar10_resnet18, - baseline_type="single", - **vars(args), - ) - - cli_main(model, dm, root, "logs/dummy", args) +from torch_uncertainty.losses import DistributionNLLLoss +from torch_uncertainty.optim_recipes import optim_cifar10_resnet18 +from torch_uncertainty.routines import RegressionRoutine +from torch_uncertainty.utils import TUTrainer + + +class TestRegression: + """Testing the Regression routine.""" + + def test_one_estimator_one_output(self): + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) + + root = Path(__file__).parent.absolute().parents[0] / "data" + dm = DummyRegressionDataModule(out_features=1, root=root, batch_size=4) + + model = DummyRegressionBaseline( + probabilistic=True, + in_features=dm.in_features, + output_dim=1, + loss=DistributionNLLLoss(), + optim_recipe=optim_cifar10_resnet18, + baseline_type="single", + ) + + trainer.fit(model, dm) + trainer.validate(model, dm) + trainer.test(model, dm) + model(dm.get_test_set()[0][0]) + + model = DummyRegressionBaseline( + probabilistic=False, + in_features=dm.in_features, + output_dim=1, + loss=DistributionNLLLoss(), + optim_recipe=optim_cifar10_resnet18, + baseline_type="single", + ) + + trainer.fit(model, dm) + trainer.validate(model, dm) + trainer.test(model, dm) + model(dm.get_test_set()[0][0]) + + def test_one_estimator_two_outputs(self): + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) + + root = Path(__file__).parent.absolute().parents[0] / "data" + dm = DummyRegressionDataModule(out_features=2, root=root, batch_size=4) + + model = DummyRegressionBaseline( + probabilistic=True, + in_features=dm.in_features, + output_dim=2, + loss=DistributionNLLLoss(), + optim_recipe=optim_cifar10_resnet18, + baseline_type="single", + dist_type="laplace", + ) + trainer.fit(model, dm) + trainer.test(model, dm) + model(dm.get_test_set()[0][0]) + + model = DummyRegressionBaseline( + probabilistic=False, + in_features=dm.in_features, + output_dim=2, + loss=DistributionNLLLoss(), + optim_recipe=optim_cifar10_resnet18, + baseline_type="single", + ) + trainer.fit(model, dm) + trainer.test(model, dm) + model(dm.get_test_set()[0][0]) + + def test_two_estimators_one_output(self): + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) + + root = Path(__file__).parent.absolute().parents[0] / "data" + dm = DummyRegressionDataModule(out_features=1, root=root, batch_size=4) + + model = DummyRegressionBaseline( + probabilistic=True, + in_features=dm.in_features, + output_dim=1, + loss=DistributionNLLLoss(), + optim_recipe=optim_cifar10_resnet18, + baseline_type="ensemble", + dist_type="nig", + ) + trainer.fit(model, dm) + trainer.test(model, dm) + model(dm.get_test_set()[0][0]) + + model = DummyRegressionBaseline( + probabilistic=False, + in_features=dm.in_features, + output_dim=1, + loss=DistributionNLLLoss(), + optim_recipe=optim_cifar10_resnet18, + baseline_type="ensemble", + ) + trainer.fit(model, dm) + trainer.test(model, dm) + model(dm.get_test_set()[0][0]) + + def test_two_estimators_two_outputs(self): + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) + + root = Path(__file__).parent.absolute().parents[0] / "data" + dm = DummyRegressionDataModule(out_features=2, root=root, batch_size=4) + + model = DummyRegressionBaseline( + probabilistic=True, + in_features=dm.in_features, + output_dim=2, + loss=DistributionNLLLoss(), + optim_recipe=optim_cifar10_resnet18, + baseline_type="ensemble", + ) + trainer.fit(model, dm) + trainer.validate(model, dm) + trainer.test(model, dm) + model(dm.get_test_set()[0][0]) + + model = DummyRegressionBaseline( + probabilistic=False, + in_features=dm.in_features, + output_dim=2, + loss=DistributionNLLLoss(), + optim_recipe=optim_cifar10_resnet18, + baseline_type="ensemble", + ) + trainer.fit(model, dm) + trainer.validate(model, dm) + trainer.test(model, dm) + model(dm.get_test_set()[0][0]) def test_regression_failures(self): with pytest.raises(ValueError): - DummyRegressionBaseline( - in_features=10, - out_features=3, - loss=nn.GaussianNLLLoss, - optimization_procedure=optim_cifar10_resnet18, - dist_estimation=4, + RegressionRoutine( + True, 1, nn.Identity(), nn.MSELoss, num_estimators=0 ) with pytest.raises(ValueError): - DummyRegressionBaseline( - in_features=10, - out_features=3, - loss=nn.GaussianNLLLoss, - optimization_procedure=optim_cifar10_resnet18, - dist_estimation=-4, - ) - - with pytest.raises(TypeError): - DummyRegressionBaseline( - in_features=10, - out_features=4, - loss=nn.GaussianNLLLoss, - optimization_procedure=optim_cifar10_resnet18, - dist_estimation=4.2, + RegressionRoutine( + True, 0, nn.Identity(), nn.MSELoss, num_estimators=1 ) - - -class TestRegressionEnsemble: - """Testing the Regression routine with an ensemble model.""" - - def test_cli_main_dummy(self): - root = Path(__file__).parent.absolute().parents[0] - with ArgvContext("file.py"): - args = init_args(DummyRegressionBaseline, DummyRegressionDataModule) - - # datamodule - args.root = str(root / "data") - dm = DummyRegressionDataModule(out_features=1, **vars(args)) - - model = DummyRegressionBaseline( - in_features=dm.in_features, - out_features=dm.out_features, - loss=nn.MSELoss, - optimization_procedure=optim_cifar10_resnet18, - baseline_type="ensemble", - **vars(args), - ) - - cli_main(model, dm, root, "logs/dummy", args) diff --git a/tests/routines/test_segmentation.py b/tests/routines/test_segmentation.py new file mode 100644 index 00000000..7e03b673 --- /dev/null +++ b/tests/routines/test_segmentation.py @@ -0,0 +1,67 @@ +from pathlib import Path + +import pytest +from torch import nn + +from tests._dummies import ( + DummySegmentationBaseline, + DummySegmentationDataModule, +) +from torch_uncertainty.optim_recipes import optim_cifar10_resnet18 +from torch_uncertainty.routines import SegmentationRoutine +from torch_uncertainty.utils import TUTrainer + + +class TestSegmentation: + def test_one_estimator_two_classes(self): + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) + + root = Path(__file__).parent.absolute().parents[0] / "data" + dm = DummySegmentationDataModule(root=root, batch_size=4, num_classes=2) + + model = DummySegmentationBaseline( + in_channels=dm.num_channels, + num_classes=dm.num_classes, + image_size=dm.image_size, + loss=nn.CrossEntropyLoss(), + baseline_type="single", + optim_recipe=optim_cifar10_resnet18, + ) + + trainer.fit(model, dm) + trainer.validate(model, dm) + trainer.test(model, dm) + model(dm.get_test_set()[0][0]) + + def test_two_estimators_two_classes(self): + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) + + root = Path(__file__).parent.absolute().parents[0] / "data" + dm = DummySegmentationDataModule(root=root, batch_size=4, num_classes=2) + + model = DummySegmentationBaseline( + in_channels=dm.num_channels, + num_classes=dm.num_classes, + image_size=dm.image_size, + loss=nn.CrossEntropyLoss(), + baseline_type="ensemble", + optim_recipe=optim_cifar10_resnet18, + ) + + trainer.fit(model, dm) + trainer.validate(model, dm) + trainer.test(model, dm) + model(dm.get_test_set()[0][0]) + + def test_segmentation_failures(self): + with pytest.raises(ValueError): + SegmentationRoutine( + model=nn.Identity(), + num_classes=2, + loss=nn.CrossEntropyLoss(), + num_estimators=0, + ) + with pytest.raises(ValueError): + SegmentationRoutine( + model=nn.Identity(), num_classes=1, loss=nn.CrossEntropyLoss() + ) diff --git a/tests/test_cli.py b/tests/test_cli.py index 4950af81..edce26d4 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,401 +1,39 @@ import sys -from pathlib import Path -import pytest -from cli_test_helpers import ArgvContext -from torch import nn - -from torch_uncertainty import cli_main, init_args -from torch_uncertainty.baselines import VGG, ResNet, WideResNet -from torch_uncertainty.baselines.regression import MLP -from torch_uncertainty.datamodules import CIFAR10DataModule, UCIDataModule -from torch_uncertainty.optimization_procedures import ( - optim_cifar10_resnet18, - optim_cifar10_vgg16, - optim_cifar10_wideresnet, - optim_regression, -) -from torch_uncertainty.utils.misc import csv_writer - -from ._dummies.dataset import DummyClassificationDataset +from torch_uncertainty.baselines.classification import ResNetBaseline +from torch_uncertainty.datamodules import CIFAR10DataModule +from torch_uncertainty.utils.cli import TULightningCLI, TUSaveConfigCallback class TestCLI: - """Testing the CLI function.""" + """Testing torch-uncertainty CLI.""" - def test_cli_main_resnet(self): - root = Path(__file__).parent.absolute().parents[0] - with ArgvContext( + def test_cli_init(self): + """Test CLI initialization.""" + sys.argv = [ "file.py", - "--version", - "mc-dropout", - "--dropout_rate", - "0.2", - "--num_estimators", + "--model.in_channels", + "3", + "--model.num_classes", + "10", + "--model.version", + "std", + "--model.arch", + "18", + "--model.loss.class_path", + "torch.nn.CrossEntropyLoss", + "--data.root", + "./data", + "--data.batch_size", "4", - "--last_layer_dropout", - ): - args = init_args(ResNet, CIFAR10DataModule) - - # datamodule - args.root = str(root / "data") - dm = CIFAR10DataModule(**vars(args)) - - # Simulate that summary is True & the only argument - args.summary = True - - model = ResNet( - num_classes=dm.num_classes, - in_channels=dm.num_channels, - style="cifar", - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet18, - **vars(args), - ) - - results = cli_main(model, dm, root, "std", args) - results_path = root / "tests" / "logs" - if not results_path.exists(): - results_path.mkdir(parents=True) - for dict_result in results: - csv_writer( - results_path / "results.csv", - dict_result, - ) - # Test if file already exists - for dict_result in results: - csv_writer( - results_path / "results.csv", - dict_result, - ) - - def test_cli_main_other_arguments(self): - root = Path(__file__).parent.absolute().parents[0] - with ArgvContext( - "file.py", - "--seed", - "42", - "--max_epochs", - "1", - "--channels_last", - "--eval-grouping-loss", - ): - print(sys.orig_argv, sys.argv) - args = init_args(ResNet, CIFAR10DataModule) - - # datamodule - args.root = root / "data" - dm = CIFAR10DataModule(**vars(args)) - - # Simulate that summary is True & the only argument - args.summary = True - - model = ResNet( - num_classes=dm.num_classes, - in_channels=dm.num_channels, - style="cifar", - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet18, - **vars(args), - ) - - cli_main(model, dm, root, "std", args) - - def test_cli_main_wideresnet(self): - root = Path(__file__).parent.absolute().parents[0] - with ArgvContext("file.py"): - args = init_args(WideResNet, CIFAR10DataModule) - - # datamodule - args.root = root / "data" - dm = CIFAR10DataModule(**vars(args)) - - args.summary = True - - model = WideResNet( - num_classes=dm.num_classes, - in_channels=dm.num_channels, - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_wideresnet, - **vars(args), - ) - - cli_main(model, dm, root, "std", args) - - def test_cli_main_vgg(self): - root = Path(__file__).parent.absolute().parents[0] - with ArgvContext("file.py"): - args = init_args(VGG, CIFAR10DataModule) - - # datamodule - args.root = root / "data" - dm = CIFAR10DataModule(**vars(args)) - - args.summary = True - - model = VGG( - num_classes=dm.num_classes, - in_channels=dm.num_channels, - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_vgg16, - **vars(args), - ) - - cli_main(model, dm, root, "std", args) - - def test_cli_main_mlp(self): - root = str(Path(__file__).parent.absolute().parents[0]) - with ArgvContext("file.py"): - args = init_args(MLP, UCIDataModule) - - # datamodule - args.root = root + "/data" - dm = UCIDataModule( - dataset_name="kin8nm", input_shape=(1, 5), **vars(args) - ) - - args.summary = True - - model = MLP( - num_outputs=1, - in_features=5, - hidden_dims=[], - dist_estimation=1, - loss=nn.MSELoss, - optimization_procedure=optim_regression, - **vars(args), - ) - - cli_main(model, dm, root, "std", args) - - args.test = True - cli_main(model, dm, root, "std", args) - - def test_cli_other_training_task(self): - root = Path(__file__).parent.absolute().parents[0] - with ArgvContext("file.py"): - args = init_args(MLP, UCIDataModule) - - # datamodule - args.root = root / "data" - dm = UCIDataModule( - dataset_name="kin8nm", input_shape=(1, 5), **vars(args) - ) - - dm.training_task = "time-series-regression" - - args.summary = True - - model = MLP( - num_outputs=1, - in_features=5, - hidden_dims=[], - dist_estimation=1, - loss=nn.MSELoss, - optimization_procedure=optim_regression, - **vars(args), - ) - with pytest.raises(ValueError): - cli_main(model, dm, root, "std", args) - - def test_cli_cv_ts(self): - root = Path(__file__).parent.absolute().parents[0] - with ArgvContext("file.py", "--use_cv", "--channels_last"): - args = init_args(ResNet, CIFAR10DataModule) - - # datamodule - args.root = str(root / "data") - dm = CIFAR10DataModule(**vars(args)) - - # Simulate that summary is True & the only argument - args.summary = True - - dm.dataset = ( - lambda root, - train, - download, - transform: DummyClassificationDataset( - root, - train=train, - download=download, - transform=transform, - num_images=20, - ) - ) - - list_dm = dm.make_cross_val_splits(2, 1) - list_model = [ - ResNet( - num_classes=list_dm[i].dm.num_classes, - in_channels=list_dm[i].dm.num_channels, - style="cifar", - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet18, - **vars(args), - ) - for i in range(len(list_dm)) - ] - - cli_main(list_model, list_dm, root, "std", args) - - with ArgvContext("file.py", "--use_cv", "--mixtype", "mixup"): - args = init_args(ResNet, CIFAR10DataModule) - - # datamodule - args.root = str(root / "data") - dm = CIFAR10DataModule(**vars(args)) - - # Simulate that summary is True & the only argument - args.summary = True - - dm.dataset = ( - lambda root, - train, - download, - transform: DummyClassificationDataset( - root, - train=train, - download=download, - transform=transform, - num_images=20, - ) - ) - - list_dm = dm.make_cross_val_splits(2, 1) - list_model = [] - for i in range(len(list_dm)): - list_model.append( - ResNet( - num_classes=list_dm[i].dm.num_classes, - in_channels=list_dm[i].dm.num_channels, - style="cifar", - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet18, - **vars(args), - ) - ) - - cli_main(list_model, list_dm, root, "std", args) - - with ArgvContext("file.py", "--use_cv", "--mixtype", "mixup_io"): - args = init_args(ResNet, CIFAR10DataModule) - - # datamodule - args.root = str(root / "data") - dm = CIFAR10DataModule(**vars(args)) - - # Simulate that summary is True & the only argument - args.summary = True - - dm.dataset = ( - lambda root, - train, - download, - transform: DummyClassificationDataset( - root, - train=train, - download=download, - transform=transform, - num_images=20, - ) - ) - - list_dm = dm.make_cross_val_splits(2, 1) - list_model = [] - for i in range(len(list_dm)): - list_model.append( - ResNet( - num_classes=list_dm[i].dm.num_classes, - in_channels=list_dm[i].dm.num_channels, - style="cifar", - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet18, - **vars(args), - ) - ) - - cli_main(list_model, list_dm, root, "std", args) - - with ArgvContext("file.py", "--use_cv", "--mixtype", "regmixup"): - args = init_args(ResNet, CIFAR10DataModule) - - # datamodule - args.root = str(root / "data") - dm = CIFAR10DataModule(**vars(args)) - - # Simulate that summary is True & the only argument - args.summary = True - - dm.dataset = ( - lambda root, - train, - download, - transform: DummyClassificationDataset( - root, - train=train, - download=download, - transform=transform, - num_images=20, - ) - ) - - list_dm = dm.make_cross_val_splits(2, 1) - list_model = [] - for i in range(len(list_dm)): - list_model.append( - ResNet( - num_classes=list_dm[i].dm.num_classes, - in_channels=list_dm[i].dm.num_channels, - style="cifar", - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet18, - **vars(args), - ) - ) - - cli_main(list_model, list_dm, root, "std", args) - - with ArgvContext( - "file.py", "--use_cv", "--mixtype", "kernel_warping" - ): - args = init_args(ResNet, CIFAR10DataModule) - - # datamodule - args.root = str(root / "data") - dm = CIFAR10DataModule(**vars(args)) - - # Simulate that summary is True & the only argument - args.summary = True - - dm.dataset = ( - lambda root, - train, - download, - transform: DummyClassificationDataset( - root, - train=train, - download=download, - transform=transform, - num_images=20, - ) - ) - - list_dm = dm.make_cross_val_splits(2, 1) - list_model = [] - for i in range(len(list_dm)): - list_model.append( - ResNet( - num_classes=list_dm[i].dm.num_classes, - in_channels=list_dm[i].dm.num_channels, - style="cifar", - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet18, - **vars(args), - ) - ) - - cli_main(list_model, list_dm, root, "std", args) - - def test_init_args_void(self): - with ArgvContext("file.py"): - init_args() + "--trainer.callbacks+=ModelCheckpoint", + "--trainer.callbacks.monitor=cls_val/acc", + "--trainer.callbacks.mode=max", + ] + cli = TULightningCLI(ResNetBaseline, CIFAR10DataModule, run=False) + assert cli.eval_after_fit_default is False + assert cli.save_config_callback == TUSaveConfigCallback + assert isinstance(cli.trainer.callbacks[0], TUSaveConfigCallback) + cli.trainer.callbacks[0].setup(cli.trainer, cli.model, stage="fit") + cli.trainer.callbacks[0].already_saved = True + cli.trainer.callbacks[0].setup(cli.trainer, cli.model, stage="fit") diff --git a/tests/test_losses.py b/tests/test_losses.py index e907b7ba..f368e6cc 100644 --- a/tests/test_losses.py +++ b/tests/test_losses.py @@ -3,9 +3,26 @@ import pytest import torch from torch import nn +from torch.distributions import Normal from torch_uncertainty.layers.bayesian import BayesLinear -from torch_uncertainty.losses import BetaNLL, DECLoss, ELBOLoss, NIGLoss +from torch_uncertainty.layers.distributions import NormalInverseGamma +from torch_uncertainty.losses import ( + BetaNLL, + DECLoss, + DERLoss, + DistributionNLLLoss, + ELBOLoss, +) + + +class TestDistributionNLL: + """Testing the DistributionNLLLoss class.""" + + def test_sum(self): + loss = DistributionNLLLoss(reduction="sum") + dist = Normal(0, 1) + loss(dist, torch.tensor([0.0])) class TestELBOLoss: @@ -14,9 +31,14 @@ class TestELBOLoss: def test_main(self): model = BayesLinear(1, 1) criterion = nn.BCEWithLogitsLoss() - loss = ELBOLoss(model, criterion, kl_weight=1e-5, num_samples=1) + loss(model(torch.randn(1, 1)), torch.randn(1, 1)) + + model = nn.Linear(1, 1) + criterion = nn.BCEWithLogitsLoss() + ELBOLoss(None, criterion, kl_weight=1e-5, num_samples=1) + loss = ELBOLoss(model, criterion, kl_weight=1e-5, num_samples=1) loss(model(torch.randn(1, 1)), torch.randn(1, 1)) def test_failures(self): @@ -35,53 +57,52 @@ def test_failures(self): with pytest.raises(TypeError): ELBOLoss(model, criterion, kl_weight=1e-5, num_samples=1.5) - def test_no_bayes(self): - model = nn.Linear(1, 1) - criterion = nn.BCEWithLogitsLoss() - loss = ELBOLoss(model, criterion, kl_weight=1e-5, num_samples=1) - loss(model(torch.randn(1, 1)), torch.randn(1, 1)) - - -class TestNIGLoss: - """Testing the NIGLoss class.""" +class TestDERLoss: + """Testing the DERLoss class.""" def test_main(self): - loss = NIGLoss(reg_weight=1e-2) - - inputs = torch.tensor([[1.0, 1.0, 1.0, 1.0]], dtype=torch.float32) + loss = DERLoss(reg_weight=1e-2) + layer = NormalInverseGamma + inputs = layer( + torch.ones(1), torch.ones(1), torch.ones(1), torch.ones(1) + ) targets = torch.tensor([[1.0]], dtype=torch.float32) - assert loss(*inputs.split(1, dim=-1), targets) == pytest.approx( - 2 * math.log(2) - ) + assert loss(inputs, targets) == pytest.approx(2 * math.log(2)) - loss = NIGLoss( + loss = DERLoss( reg_weight=1e-2, reduction="sum", ) + inputs = layer( + torch.ones((2, 1)), + torch.ones((2, 1)), + torch.ones((2, 1)), + torch.ones((2, 1)), + ) assert loss( - *inputs.repeat(2, 1).split(1, dim=-1), - targets.repeat(2, 1), + inputs, + targets, ) == pytest.approx(4 * math.log(2)) - loss = NIGLoss( + loss = DERLoss( reg_weight=1e-2, reduction="none", ) assert loss( - *inputs.repeat(2, 1).split(1, dim=-1), - targets.repeat(2, 1), + inputs, + targets, ) == pytest.approx([2 * math.log(2), 2 * math.log(2)]) def test_failures(self): with pytest.raises(ValueError): - NIGLoss(reg_weight=-1) + DERLoss(reg_weight=-1) with pytest.raises(ValueError): - NIGLoss(reg_weight=1.0, reduction="median") + DERLoss(reg_weight=1.0, reduction="median") class TestDECLoss: diff --git a/tests/test_optimization_procedures.py b/tests/test_optim_recipes.py similarity index 93% rename from tests/test_optimization_procedures.py rename to tests/test_optim_recipes.py index 8a120057..48fcd06f 100644 --- a/tests/test_optimization_procedures.py +++ b/tests/test_optim_recipes.py @@ -4,9 +4,8 @@ from torch_uncertainty.models.resnet import resnet18, resnet34, resnet50 from torch_uncertainty.models.vgg import vgg16 from torch_uncertainty.models.wideresnet import wideresnet28x10 -from torch_uncertainty.optimization_procedures import ( +from torch_uncertainty.optim_recipes import ( get_procedure, - optim_regression, ) @@ -72,10 +71,6 @@ def test_optim_imagenet_resnet50(self): model = resnet50(in_channels=3, num_classes=1000) procedure(model) - def test_optim_regression(self): - model = resnet18(in_channels=3, num_classes=1) - optim_regression(model) - def test_optim_unknown(self): with pytest.raises(NotImplementedError): _ = get_procedure("unknown", "cifar100") diff --git a/tests/test_utils.py b/tests/test_utils.py index 8fc5b131..0ce5c482 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,35 +1,88 @@ from pathlib import Path import pytest +import torch +from huggingface_hub.utils._errors import RepositoryNotFoundError +from torch.distributions import Laplace, Normal -from torch_uncertainty import utils +from torch_uncertainty.utils import ( + csv_writer, + distributions, + get_version, + hub, + plot_hist, +) class TestUtils: """Testing utils methods.""" - def test_getversion_log_success(self): - utils.get_version("tests/testlog", version=42) - utils.get_version(Path("tests/testlog"), version=42) + def test_get_version_log_success(self): + get_version("tests/testlog", version=42) + get_version(Path("tests/testlog"), version=42) - utils.get_version("tests/testlog", version=42, checkpoint=45) + get_version("tests/testlog", version=42, checkpoint=45) def test_getversion_log_failure(self): - with pytest.raises(Exception): - utils.get_version("tests/testlog", version=52) + with pytest.raises(FileNotFoundError): + get_version("tests/testlog", version=52) class TestHub: """Testing hub methods.""" def test_hub_exists(self): - utils.hub.load_hf("test") - utils.hub.load_hf("test", version=1) - utils.hub.load_hf("test", version=2) + hub.load_hf("test") + hub.load_hf("test", version=1) + hub.load_hf("test", version=2) def test_hub_notexists(self): - with pytest.raises(Exception): - utils.hub.load_hf("tests") + with pytest.raises(RepositoryNotFoundError): + hub.load_hf("tests") with pytest.raises(ValueError): - utils.hub.load_hf("test", version=42) + hub.load_hf("test", version=42) + + +class TestMisc: + """Testing misc methods.""" + + def test_csv_writer(self): + root = Path(__file__).parent.resolve() + csv_writer(root / "testlog" / "results.csv", {"a": 1.0, "b": 2.0}) + csv_writer( + root / "testlog" / "results.csv", {"a": 1.0, "b": 2.0, "c": 3.0} + ) + + def test_plot_hist(self): + conf = [torch.rand(20), torch.rand(20)] + plot_hist(conf, bins=10, title="test") + + +class TestDistributions: + """Testing distributions methods.""" + + def test_nig(self): + dist = distributions.NormalInverseGamma( + 0.0, + 1.1, + 1.1, + 1.1, + ) + dist = distributions.NormalInverseGamma( + torch.tensor(0.0), + torch.tensor(1.1), + torch.tensor(1.1), + torch.tensor(1.1), + ) + _ = dist.mean, dist.mean_loc, dist.mean_variance, dist.variance_loc + + def test_errors(self): + with pytest.raises(ValueError): + distributions.cat_dist( + [ + Normal(torch.tensor([0.0]), torch.tensor([1.0])), + Laplace(torch.tensor([0.0]), torch.tensor([1.0])), + ], + dim=0, + ) diff --git a/tests/transforms/test_image.py b/tests/transforms/test_image.py index 56eeb560..872707d9 100644 --- a/tests/transforms/test_image.py +++ b/tests/transforms/test_image.py @@ -2,6 +2,7 @@ import pytest import torch from PIL import Image +from torchvision import tv_tensors from torch_uncertainty.transforms import ( AutoContrast, @@ -11,6 +12,7 @@ Equalize, MIMOBatchFormat, Posterize, + RandomRescale, RepeatTarget, Rotate, Sharpen, @@ -26,6 +28,16 @@ def img_input() -> torch.Tensor: return Image.fromarray(imarray.astype("uint8")).convert("RGB") +@pytest.fixture() +def tv_tensors_input() -> tuple[torch.Tensor, torch.Tensor]: + imarray1 = np.random.rand(3, 28, 28) * 255 + imarray2 = np.random.rand(1, 28, 28) * 255 + return ( + tv_tensors.Image(imarray1.astype("uint8")), + tv_tensors.Mask(imarray2.astype("uint8")), + ) + + @pytest.fixture() def batch_input() -> tuple[torch.Tensor, torch.Tensor]: imgs = torch.rand(2, 3, 28, 28) @@ -210,3 +222,11 @@ def test_failures(self): with pytest.raises(ValueError): _ = MIMOBatchFormat(1, 0, 0) + + +class TestRandomRescale: + """Testing the RandomRescale transform.""" + + def test_tv_tensors(self, tv_tensors_input): + aug = RandomRescale(0.5, 2.0) + _ = aug(tv_tensors_input) diff --git a/torch_uncertainty/__init__.py b/torch_uncertainty/__init__.py index 8316a0e5..e69de29b 100644 --- a/torch_uncertainty/__init__.py +++ b/torch_uncertainty/__init__.py @@ -1,255 +0,0 @@ -from argparse import ArgumentParser, Namespace -from collections import defaultdict -from pathlib import Path -from typing import Any - -import numpy as np -import pytorch_lightning as pl -import torch -from pytorch_lightning.callbacks import ( - EarlyStopping, - LearningRateMonitor, - ModelCheckpoint, -) -from pytorch_lightning.loggers.tensorboard import TensorBoardLogger -from torchinfo import summary - -from .datamodules.abstract import AbstractDataModule -from .utils import get_version - - -def init_args( - network: Any = None, - datamodule: type[pl.LightningDataModule] | None = None, -) -> Namespace: - parser = ArgumentParser("torch-uncertainty") - parser.add_argument( - "--seed", - type=int, - default=None, - help="Random seed to make the training deterministic.", - ) - parser.add_argument( - "--test", - type=int, - default=None, - help="Run in test mode. Set to the checkpoint version number to test.", - ) - parser.add_argument( - "--ckpt", type=int, default=None, help="The number of the checkpoint" - ) - parser.add_argument( - "--summary", - dest="summary", - action="store_true", - help="Print model summary", - ) - parser.add_argument("--log_graph", dest="log_graph", action="store_true") - parser.add_argument( - "--channels_last", - action="store_true", - help="Use channels last memory format", - ) - parser.add_argument( - "--enable_resume", - action="store_true", - help="Allow resuming the training (save optimizer's states)", - ) - parser.add_argument( - "--exp_dir", - type=str, - default="logs/", - help="Directory to store experiment files", - ) - parser.add_argument( - "--exp_name", - type=str, - default="", - help="Name of the experiment folder", - ) - parser.add_argument( - "--opt_temp_scaling", - action="store_true", - default=False, - help="Compute optimal temperature on the test set", - ) - parser.add_argument( - "--val_temp_scaling", - action="store_true", - default=False, - help="Compute temperature on the validation set", - ) - parser = pl.Trainer.add_argparse_args(parser) - if network is not None: - parser = network.add_model_specific_args(parser) - - if datamodule is not None: - parser = datamodule.add_argparse_args(parser) - - return parser.parse_args() - - -def cli_main( - network: pl.LightningModule | list[pl.LightningModule], - datamodule: AbstractDataModule | list[AbstractDataModule], - root: Path | str, - net_name: str, - args: Namespace, -) -> list[dict]: - if isinstance(root, str): - root = Path(root) - - if isinstance(datamodule, list): - training_task = datamodule[0].dm.training_task - else: - training_task = datamodule.training_task - if training_task == "classification": - monitor = "cls_val/acc" - mode = "max" - elif training_task == "regression": - monitor = "reg_val/mse" - mode = "min" - else: - raise ValueError("Unknown problem type.") - - if args.test is None and args.max_epochs is None: - print( - "Setting max_epochs to 1 for testing purposes. Set max_epochs" - " manually to train the model." - ) - args.max_epochs = 1 - - if isinstance(args.seed, int): - pl.seed_everything(args.seed, workers=True) - - if args.channels_last: - if isinstance(network, list): - for i in range(len(network)): - network[i] = network[i].to(memory_format=torch.channels_last) - else: - network = network.to(memory_format=torch.channels_last) - - if hasattr(args, "use_cv") and args.use_cv: - test_values = [] - for i in range(len(datamodule)): - print( - f"Starting fold {i} out of {args.train_over} of a {args.n_splits}-fold CV." - ) - - # logger - tb_logger = TensorBoardLogger( - str(root), - name=net_name, - default_hp_metric=False, - log_graph=args.log_graph, - version=f"fold_{i}", - ) - - # callbacks - save_checkpoints = ModelCheckpoint( - dirpath=tb_logger.log_dir, - monitor=monitor, - mode=mode, - save_last=True, - save_weights_only=not args.enable_resume, - ) - - # Select the best model, monitor the lr and stop if NaN - callbacks = [ - save_checkpoints, - LearningRateMonitor(logging_interval="step"), - EarlyStopping( - monitor=monitor, patience=np.inf, check_finite=True - ), - ] - - trainer = pl.Trainer.from_argparse_args( - args, - callbacks=callbacks, - logger=tb_logger, - deterministic=(args.seed is not None), - inference_mode=not ( - args.opt_temp_scaling or args.val_temp_scaling - ), - ) - if args.summary: - summary( - network[i], - input_size=list(datamodule[i].dm.input_shape).insert(0, 1), - ) - test_values.append({}) - else: - trainer.fit(network[i], datamodule[i]) - test_values.append( - trainer.test(datamodule=datamodule[i], ckpt_path="last")[0] - ) - - all_test_values = defaultdict(list) - for test_value in test_values: - for key in test_value: - all_test_values[key].append(test_value[key]) - - avg_test_values = {} - for key in all_test_values: - avg_test_values[key] = np.mean(all_test_values[key]) - - return [avg_test_values] - - # logger - tb_logger = TensorBoardLogger( - str(root), - name=net_name, - default_hp_metric=False, - log_graph=args.log_graph, - version=args.test, - ) - - # callbacks - save_checkpoints = ModelCheckpoint( - monitor=monitor, - mode=mode, - save_last=True, - save_weights_only=not args.enable_resume, - ) - - # Select the best model, monitor the lr and stop if NaN - callbacks = [ - save_checkpoints, - LearningRateMonitor(logging_interval="step"), - EarlyStopping(monitor=monitor, patience=np.inf, check_finite=True), - ] - - # trainer - trainer = pl.Trainer.from_argparse_args( - args, - callbacks=callbacks, - logger=tb_logger, - deterministic=(args.seed is not None), - inference_mode=not (args.opt_temp_scaling or args.val_temp_scaling), - ) - if args.summary: - summary( - network, - input_size=list(datamodule.input_shape).insert(0, 1), - ) - test_values = [{}] - elif args.test is not None: - if args.test >= 0: - ckpt_file, _ = get_version( - root=(root / net_name), - version=args.test, - checkpoint=args.ckpt, - ) - test_values = trainer.test( - network, datamodule=datamodule, ckpt_path=str(ckpt_file) - ) - else: - test_values = trainer.test(network, datamodule=datamodule) - else: - # training and testing - trainer.fit(network, datamodule) - if not args.fast_dev_run: - test_values = trainer.test(datamodule=datamodule, ckpt_path="best") - else: - test_values = [{}] - return test_values diff --git a/torch_uncertainty/baselines/__init__.py b/torch_uncertainty/baselines/__init__.py index 71fe93e3..e69de29b 100644 --- a/torch_uncertainty/baselines/__init__.py +++ b/torch_uncertainty/baselines/__init__.py @@ -1,5 +0,0 @@ -# ruff: noqa: F401 -from .classification.resnet import ResNet -from .classification.vgg import VGG -from .classification.wideresnet import WideResNet -from .deep_ensembles import DeepEnsembles diff --git a/torch_uncertainty/baselines/classification/__init__.py b/torch_uncertainty/baselines/classification/__init__.py index 1326c2e3..e080ee4e 100644 --- a/torch_uncertainty/baselines/classification/__init__.py +++ b/torch_uncertainty/baselines/classification/__init__.py @@ -1,4 +1,4 @@ # ruff: noqa: F401 -from .resnet import ResNet -from .vgg import VGG -from .wideresnet import WideResNet +from .resnet import ResNetBaseline +from .vgg import VGGBaseline +from .wideresnet import WideResNetBaseline diff --git a/torch_uncertainty/baselines/classification/deep_ensembles.py b/torch_uncertainty/baselines/classification/deep_ensembles.py new file mode 100644 index 00000000..fd6bc6c1 --- /dev/null +++ b/torch_uncertainty/baselines/classification/deep_ensembles.py @@ -0,0 +1,62 @@ +from pathlib import Path +from typing import Literal + +from torch_uncertainty.models import deep_ensembles +from torch_uncertainty.routines.classification import ClassificationRoutine +from torch_uncertainty.utils import get_version + +from . import ResNetBaseline, VGGBaseline, WideResNetBaseline + + +class DeepEnsemblesBaseline(ClassificationRoutine): + backbones = { + "resnet": ResNetBaseline, + "vgg": VGGBaseline, + "wideresnet": WideResNetBaseline, + } + + def __init__( + self, + num_classes: int, + log_path: str | Path, + checkpoint_ids: list[int], + backbone: Literal["resnet", "vgg", "wideresnet"], + eval_ood: bool = False, + eval_grouping_loss: bool = False, + ood_criterion: Literal[ + "msp", "logits", "energy", "entropy", "mi", "VR" + ] = "msp", + log_plots: bool = False, + calibration_set: Literal["val", "test"] | None = None, + ) -> None: + log_path = Path(log_path) + + backbone_cls = self.backbones[backbone] + + models = [] + for version in checkpoint_ids: # coverage: ignore + ckpt_file, hparams_file = get_version( + root=log_path, version=version + ) + trained_model = backbone_cls.load_from_checkpoint( + checkpoint_path=ckpt_file, + hparams_file=hparams_file, + loss=None, + optim_recipe=None, + ).eval() + models.append(trained_model.model) + + de = deep_ensembles(models=models) + + super().__init__( + num_classes=num_classes, + model=de, + loss=None, + num_estimators=de.num_estimators, + eval_ood=eval_ood, + eval_grouping_loss=eval_grouping_loss, + ood_criterion=ood_criterion, + log_plots=log_plots, + calibration_set=calibration_set, + ) + self.save_hyperparameters() # coverage: ignore diff --git a/torch_uncertainty/baselines/classification/resnet.py b/torch_uncertainty/baselines/classification/resnet.py index e27e57b3..989bb9d8 100644 --- a/torch_uncertainty/baselines/classification/resnet.py +++ b/torch_uncertainty/baselines/classification/resnet.py @@ -1,23 +1,8 @@ -from argparse import ArgumentParser, BooleanOptionalAction -from pathlib import Path -from typing import Any, Literal +from typing import Literal -import torch -from pytorch_lightning import LightningModule -from pytorch_lightning.core.saving import ( - load_hparams_from_tags_csv, - load_hparams_from_yaml, -) from torch import nn -from torch_uncertainty.baselines.utils.parser_addons import ( - add_masked_specific_args, - add_mc_dropout_specific_args, - add_mimo_specific_args, - add_packed_specific_args, - add_resnet_specific_args, -) -from torch_uncertainty.models.mc_dropout import mc_dropout +from torch_uncertainty.models import mc_dropout from torch_uncertainty.models.resnet import ( batched_resnet18, batched_resnet20, @@ -50,14 +35,11 @@ resnet101, resnet152, ) -from torch_uncertainty.routines.classification import ( - ClassificationEnsemble, - ClassificationSingle, -) +from torch_uncertainty.routines.classification import ClassificationRoutine from torch_uncertainty.transforms import MIMOBatchFormat, RepeatTarget -class ResNet: +class ResNetBaseline(ClassificationRoutine): single = ["std"] ensemble = ["packed", "batched", "masked", "mc-dropout", "mimo"] versions = { @@ -112,12 +94,11 @@ class ResNet: } archs = [18, 20, 34, 50, 101, 152] - def __new__( - cls, + def __init__( + self, num_classes: int, in_channels: int, - loss: type[nn.Module], - optimization_procedure: Any, + loss: nn.Module, version: Literal[ "std", "mc-dropout", @@ -128,22 +109,32 @@ def __new__( ], arch: int, style: str = "imagenet", - num_estimators: int | None = None, + num_estimators: int = 1, dropout_rate: float = 0.0, + mixtype: str = "erm", + mixmode: str = "elem", + dist_sim: str = "emb", + kernel_tau_max: float = 1.0, + kernel_tau_std: float = 0.5, + mixup_alpha: float = 0, + cutmix_alpha: float = 0, last_layer_dropout: bool = False, groups: int = 1, scale: float | None = None, - alpha: float | None = None, + alpha: int | None = None, gamma: int = 1, rho: float = 1.0, batch_repeat: int = 1, - use_entropy: bool = False, - use_logits: bool = False, - use_mi: bool = False, - use_variation_ratio: bool = False, + ood_criterion: Literal[ + "msp", "logit", "energy", "entropy", "mi", "vr" + ] = "msp", + log_plots: bool = False, + save_in_csv: bool = False, + calibration_set: Literal["val", "test"] | None = None, + eval_ood: bool = False, + eval_grouping_loss: bool = False, pretrained: bool = False, - **kwargs, - ) -> LightningModule: + ) -> None: r"""ResNet backbone baseline for classification providing support for various versions and architectures. @@ -151,7 +142,7 @@ def __new__( num_classes (int): Number of classes to predict. in_channels (int): Number of input channels. loss (nn.Module): Training loss. - optimization_procedure (Any): Optimization procedure, corresponds to + optim_recipe (Any): optimization recipe, corresponds to what expect the `LightningModule.configure_optimizers() `_ method. @@ -180,6 +171,22 @@ def __new__( Only used if :attr:`version` is either ``"packed"``, ``"batched"``, ``"masked"`` or ``"mc-dropout"`` Defaults to ``None``. dropout_rate (float, optional): Dropout rate. Defaults to ``0.0``. + mixtype (str, optional): Mixup type. Defaults to ``"erm"``. + mixmode (str, optional): Mixup mode. Defaults to ``"elem"``. + dist_sim (str, optional): Distance similarity. Defaults to ``"emb"``. + kernel_tau_max (float, optional): Maximum value for the kernel tau. + Defaults to ``1.0``. + kernel_tau_std (float, optional): Standard deviation for the kernel + tau. Defaults to ``0.5``. + mixup_alpha (float, optional): Alpha parameter for Mixup. Defaults + to ``0``. + cutmix_alpha (float, optional): Alpha parameter for CutMix. + Defaults to ``0``. + groups (int, optional): Number of groups in convolutions. Defaults + to ``1``. + scale (float, optional): Expansion factor affecting the width of + the estimators. Only used if :attr:`version` is ``"masked"``. + Defaults to ``None``. last_layer_dropout (bool): whether to apply dropout to the last layer only. groups (int, optional): Number of groups in convolutions. Defaults to ``1``. @@ -197,18 +204,24 @@ def __new__( ``1``. batch_repeat (int, optional): Number of times to repeat the batch. Only used if :attr:`version` is ``"mimo"``. Defaults to ``1``. - use_entropy (bool, optional): Indicates whether to use the entropy - values as the OOD criterion or not. Defaults to ``False``. - use_logits (bool, optional): Indicates whether to use the logits as the - OOD criterion or not. Defaults to ``False``. - use_mi (bool, optional): Indicates whether to use the mutual - information as the OOD criterion or not. Defaults to ``False``. - use_variation_ratio (bool, optional): Indicates whether to use the - variation ratio as the OOD criterion or not. Defaults to ``False``. + ood_criterion (str, optional): OOD criterion. Defaults to ``"msp"``. + MSP is the maximum softmax probability, logit is the maximum + logit, entropy is the entropy of the mean prediction, mi is the + mutual information of the ensemble and vr is the variation ratio + of the ensemble. + log_plots (bool, optional): Indicates whether to log the plots or not. + Defaults to ``False``. + save_in_csv (bool, optional): Indicates whether to save the results in + a csv file or not. Defaults to ``False``. + calibration_set (Callable, optional): Calibration set. Defaults to + ``None``. + eval_ood (bool, optional): Indicates whether to evaluate the + OOD detection or not. Defaults to ``False``. + eval_grouping_loss (bool, optional): Indicates whether to evaluate the + grouping loss or not. Defaults to ``False``. pretrained (bool, optional): Indicates whether to use the pretrained weights or not. Only used if :attr:`version` is ``"packed"``. Defaults to ``False``. - **kwargs: Additional arguments. Raises: ValueError: If :attr:`version` is not either ``"std"``, @@ -228,32 +241,29 @@ def __new__( format_batch_fn = nn.Identity() - if version not in cls.versions: + if version not in self.versions: raise ValueError(f"Unknown version: {version}") - if version in cls.ensemble: - params.update( - { - "num_estimators": num_estimators, - } - ) + if version in self.ensemble: + params |= { + "num_estimators": num_estimators, + } + if version != "mc-dropout": format_batch_fn = RepeatTarget(num_repeats=num_estimators) if version == "packed": - params.update( - { - "alpha": alpha, - "gamma": gamma, - "pretrained": pretrained, - } - ) + params |= { + "alpha": alpha, + "gamma": gamma, + "pretrained": pretrained, + } + elif version == "masked": - params.update( - { - "scale": scale, - } - ) + params |= { + "scale": scale, + } + elif version == "mimo": format_batch_fn = MIMOBatchFormat( num_estimators=num_estimators, @@ -261,12 +271,10 @@ def __new__( batch_repeat=batch_repeat, ) - # for lightning params - kwargs.update(params | {"version": version, "arch": arch}) - if version == "mc-dropout": # std ResNets don't have `num_estimators` del params["num_estimators"] - model = cls.versions[version][cls.archs.index(arch)](**params) + + model = self.versions[version][self.archs.index(arch)](**params) if version == "mc-dropout": model = mc_dropout( model=model, @@ -274,73 +282,24 @@ def __new__( last_layer=last_layer_dropout, ) - # routine specific parameters - if version in cls.single: - return ClassificationSingle( - model=model, - loss=loss, - optimization_procedure=optimization_procedure, - format_batch_fn=format_batch_fn, - use_entropy=use_entropy, - use_logits=use_logits, - **kwargs, - ) - # version in cls.ensemble - return ClassificationEnsemble( + super().__init__( + num_classes=num_classes, model=model, loss=loss, - optimization_procedure=optimization_procedure, + num_estimators=num_estimators, format_batch_fn=format_batch_fn, - use_entropy=use_entropy, - use_logits=use_logits, - use_mi=use_mi, - use_variation_ratio=use_variation_ratio, - **kwargs, - ) - - @classmethod - def load_from_checkpoint( - cls, - checkpoint_path: str | Path, - hparams_file: str | Path, - **kwargs, - ) -> LightningModule: # coverage: ignore - if hparams_file is not None: - extension = str(hparams_file).split(".")[-1] - if extension.lower() == "csv": - hparams = load_hparams_from_tags_csv(hparams_file) - elif extension.lower() in ("yml", "yaml"): - hparams = load_hparams_from_yaml(hparams_file) - else: - raise ValueError( - ".csv, .yml or .yaml is required for `hparams_file`" - ) - - hparams.update(kwargs) - checkpoint = torch.load(checkpoint_path) - obj = cls(**hparams) - obj.load_state_dict(checkpoint["state_dict"]) - return obj - - @classmethod - def add_model_specific_args(cls, parser: ArgumentParser) -> ArgumentParser: - parser = ClassificationEnsemble.add_model_specific_args(parser) - parser = add_resnet_specific_args(parser) - parser = add_packed_specific_args(parser) - parser = add_masked_specific_args(parser) - parser = add_mimo_specific_args(parser) - parser = add_mc_dropout_specific_args(parser) - parser.add_argument( - "--version", - type=str, - choices=cls.versions.keys(), - default="std", - help=f"Variation of ResNet. Choose among: {cls.versions.keys()}", - ) - parser.add_argument( - "--pretrained", - dest="pretrained", - action=BooleanOptionalAction, - default=False, + mixtype=mixtype, + mixmode=mixmode, + dist_sim=dist_sim, + kernel_tau_max=kernel_tau_max, + kernel_tau_std=kernel_tau_std, + mixup_alpha=mixup_alpha, + cutmix_alpha=cutmix_alpha, + eval_ood=eval_ood, + eval_grouping_loss=eval_grouping_loss, + ood_criterion=ood_criterion, + log_plots=log_plots, + save_in_csv=save_in_csv, + calibration_set=calibration_set, ) - return parser + self.save_hyperparameters(ignore=["loss"]) diff --git a/torch_uncertainty/baselines/classification/vgg.py b/torch_uncertainty/baselines/classification/vgg.py index 24579ed1..9c429ea1 100644 --- a/torch_uncertainty/baselines/classification/vgg.py +++ b/torch_uncertainty/baselines/classification/vgg.py @@ -1,21 +1,8 @@ -from argparse import ArgumentParser -from pathlib import Path -from typing import Any, Literal - -import torch -from pytorch_lightning import LightningModule -from pytorch_lightning.core.saving import ( - load_hparams_from_tags_csv, - load_hparams_from_yaml, -) +from typing import Literal + from torch import nn -from torch_uncertainty.baselines.utils.parser_addons import ( - add_mc_dropout_specific_args, - add_packed_specific_args, - add_vgg_specific_args, -) -from torch_uncertainty.models.mc_dropout import mc_dropout +from torch_uncertainty.models import mc_dropout from torch_uncertainty.models.vgg import ( packed_vgg11, packed_vgg13, @@ -26,16 +13,13 @@ vgg16, vgg19, ) -from torch_uncertainty.routines.classification import ( - ClassificationEnsemble, - ClassificationSingle, -) +from torch_uncertainty.routines.classification import ClassificationRoutine from torch_uncertainty.transforms import RepeatTarget -class VGG: +class VGGBaseline(ClassificationRoutine): single = ["std"] - ensemble = ["packed", "mc-dropout"] + ensemble = ["mc-dropout", "packed"] versions = { "std": [vgg11, vgg13, vgg16, vgg19], "mc-dropout": [vgg11, vgg13, vgg16, vgg19], @@ -48,27 +32,36 @@ class VGG: } archs = [11, 13, 16, 19] - def __new__( - cls, + def __init__( + self, num_classes: int, in_channels: int, - loss: type[nn.Module], - optimization_procedure: Any, + loss: nn.Module, version: Literal["std", "mc-dropout", "packed"], arch: int, - num_estimators: int | None = None, + style: str = "imagenet", + num_estimators: int = 1, dropout_rate: float = 0.0, last_layer_dropout: bool = False, - style: str = "imagenet", + mixtype: str = "erm", + mixmode: str = "elem", + dist_sim: str = "emb", + kernel_tau_max: float = 1, + kernel_tau_std: float = 0.5, + mixup_alpha: float = 0, + cutmix_alpha: float = 0, groups: int = 1, - alpha: float | None = None, + alpha: int | None = None, gamma: int = 1, - use_entropy: bool = False, - use_logits: bool = False, - use_mi: bool = False, - use_variation_ratio: bool = False, - **kwargs, - ) -> LightningModule: + ood_criterion: Literal[ + "msp", "logit", "energy", "entropy", "mi", "vr" + ] = "msp", + log_plots: bool = False, + save_in_csv: bool = False, + calibration_set: Literal["val", "test"] | None = None, + eval_ood: bool = False, + eval_grouping_loss: bool = False, + ) -> None: r"""VGG backbone baseline for classification providing support for various versions and architectures. @@ -76,10 +69,6 @@ def __new__( num_classes (int): Number of classes to predict. in_channels (int): Number of input channels. loss (nn.Module): Training loss. - optimization_procedure (Any): Optimization procedure, corresponds to - what expect the `LightningModule.configure_optimizers() - `_ - method. version (str): Determines which VGG version to use: @@ -101,6 +90,17 @@ def __new__( Only used if :attr:`version` is either ``"packed"``, ``"batched"`` or ``"masked"`` Defaults to ``None``. dropout_rate (float, optional): Dropout rate. Defaults to ``0.0``. + mixtype (str, optional): Mixup type. Defaults to ``"erm"``. + mixmode (str, optional): Mixup mode. Defaults to ``"elem"``. + dist_sim (str, optional): Distance similarity. Defaults to ``"emb"``. + kernel_tau_max (float, optional): Maximum value for the kernel tau. + Defaults to ``1.0``. + kernel_tau_std (float, optional): Standard deviation for the kernel + tau. Defaults to ``0.5``. + mixup_alpha (float, optional): Alpha parameter for Mixup. Defaults + to ``0``. + cutmix_alpha (float, optional): Alpha parameter for CutMix. + Defaults to ``0``. last_layer_dropout (bool): whether to apply dropout to the last layer only. groups (int, optional): Number of groups in convolutions. Defaults to ``1``. @@ -110,15 +110,22 @@ def __new__( gamma (int, optional): Number of groups within each estimator. Only used if :attr:`version` is ``"packed"`` and scales with :attr:`groups`. Defaults to ``1s``. - use_entropy (bool, optional): Indicates whether to use the entropy - values as the OOD criterion or not. Defaults to ``False``. - use_logits (bool, optional): Indicates whether to use the logits as the - OOD criterion or not. Defaults to ``False``. - use_mi (bool, optional): Indicates whether to use the mutual - information as the OOD criterion or not. Defaults to ``False``. - use_variation_ratio (bool, optional): Indicates whether to use the - variation ratio as the OOD criterion or not. Defaults to ``False``. - **kwargs: Additional arguments to be passed to the + ood_criterion (str, optional): OOD criterion. Defaults to ``"msp"``. + MSP is the maximum softmax probability, logit is the maximum + logit, entropy is the entropy of the mean prediction, mi is the + mutual information of the ensemble and vr is the variation ratio + of the ensemble. + log_plots (bool, optional): Indicates whether to log the plots or not. + Defaults to ``False``. + save_in_csv (bool, optional): Indicates whether to save the results in + a csv file or not. Defaults to ``False``. + calibration_set (Callable, optional): Calibration set. Defaults to + ``None``. + eval_ood (bool, optional): Indicates whether to evaluate the + OOD detection or not. Defaults to ``False``. + eval_grouping_loss (bool, optional): Indicates whether to evaluate the + grouping loss or not. Defaults to ``False``. + Raises: ValueError: If :attr:`version` is not either ``"std"``, ``"packed"``, ``"batched"`` or ``"masked"``. @@ -134,102 +141,64 @@ def __new__( "groups": groups, } - if version not in cls.versions: + if version not in self.versions: raise ValueError(f"Unknown version: {version}") format_batch_fn = nn.Identity() - if version in cls.ensemble: - params.update( - { - "num_estimators": num_estimators, - } - ) + if version == "std": + params |= { + "dropout_rate": dropout_rate, + } + + elif version == "mc-dropout": + params |= { + "dropout_rate": dropout_rate, + "num_estimators": num_estimators, + } + + if version in self.ensemble: + params |= { + "num_estimators": num_estimators, + } if version != "mc-dropout": format_batch_fn = RepeatTarget(num_repeats=num_estimators) if version == "packed": - params.update( - { - "alpha": alpha, - "style": style, - "gamma": gamma, - } - ) - - # for lightning params - kwargs.update(params | {"version": version, "arch": arch}) + params |= { + "alpha": alpha, + "style": style, + "gamma": gamma, + } if version == "mc-dropout": # std VGGs don't have `num_estimators` del params["num_estimators"] - model = cls.versions[version][cls.archs.index(arch)](**params) + model = self.versions[version][self.archs.index(arch)](**params) if version == "mc-dropout": model = mc_dropout( model=model, num_estimators=num_estimators, last_layer=last_layer_dropout, ) - - # routine specific parameters - if version in cls.single: - return ClassificationSingle( - model=model, - loss=loss, - optimization_procedure=optimization_procedure, - format_batch_fn=format_batch_fn, - use_entropy=use_entropy, - use_logits=use_logits, - **kwargs, - ) - # version in cls.ensemble - return ClassificationEnsemble( + super().__init__( + num_classes=num_classes, model=model, loss=loss, - optimization_procedure=optimization_procedure, + num_estimators=num_estimators, format_batch_fn=format_batch_fn, - use_entropy=use_entropy, - use_logits=use_logits, - use_mi=use_mi, - use_variation_ratio=use_variation_ratio, - **kwargs, - ) - - @classmethod - def load_from_checkpoint( - cls, - checkpoint_path: str | Path, - hparams_file: str | Path, - **kwargs, - ) -> LightningModule: # coverage: ignore - if hparams_file is not None: - extension = str(hparams_file).split(".")[-1] - if extension.lower() == "csv": - hparams = load_hparams_from_tags_csv(hparams_file) - elif extension.lower() in ("yml", "yaml"): - hparams = load_hparams_from_yaml(hparams_file) - else: - raise ValueError( - ".csv, .yml or .yaml is required for `hparams_file`" - ) - - hparams.update(kwargs) - checkpoint = torch.load(checkpoint_path) - obj = cls(**hparams) - obj.load_state_dict(checkpoint["state_dict"]) - return obj - - @classmethod - def add_model_specific_args(cls, parser: ArgumentParser) -> ArgumentParser: - parser = ClassificationEnsemble.add_model_specific_args(parser) - parser = add_vgg_specific_args(parser) - parser = add_packed_specific_args(parser) - parser = add_mc_dropout_specific_args(parser) - parser.add_argument( - "--version", - type=str, - choices=cls.versions.keys(), - default="std", - help=f"Variation of VGG. Choose among: {cls.versions.keys()}", + mixtype=mixtype, + mixmode=mixmode, + dist_sim=dist_sim, + kernel_tau_max=kernel_tau_max, + kernel_tau_std=kernel_tau_std, + mixup_alpha=mixup_alpha, + cutmix_alpha=cutmix_alpha, + eval_ood=eval_ood, + ood_criterion=ood_criterion, + log_plots=log_plots, + save_in_csv=save_in_csv, + calibration_set=calibration_set, + eval_grouping_loss=eval_grouping_loss, ) - return parser + self.save_hyperparameters(ignore=["loss"]) diff --git a/torch_uncertainty/baselines/classification/wideresnet.py b/torch_uncertainty/baselines/classification/wideresnet.py index fb2d2638..ffda0d48 100644 --- a/torch_uncertainty/baselines/classification/wideresnet.py +++ b/torch_uncertainty/baselines/classification/wideresnet.py @@ -1,23 +1,8 @@ -from argparse import ArgumentParser, BooleanOptionalAction -from pathlib import Path -from typing import Any, Literal +from typing import Literal -import torch -from pytorch_lightning import LightningModule -from pytorch_lightning.core.saving import ( - load_hparams_from_tags_csv, - load_hparams_from_yaml, -) from torch import nn -from torch_uncertainty.baselines.utils.parser_addons import ( - add_masked_specific_args, - add_mc_dropout_specific_args, - add_mimo_specific_args, - add_packed_specific_args, - add_wideresnet_specific_args, -) -from torch_uncertainty.models.mc_dropout import mc_dropout +from torch_uncertainty.models import mc_dropout from torch_uncertainty.models.wideresnet import ( batched_wideresnet28x10, masked_wideresnet28x10, @@ -26,15 +11,14 @@ wideresnet28x10, ) from torch_uncertainty.routines.classification import ( - ClassificationEnsemble, - ClassificationSingle, + ClassificationRoutine, ) from torch_uncertainty.transforms import MIMOBatchFormat, RepeatTarget -class WideResNet: +class WideResNetBaseline(ClassificationRoutine): single = ["std"] - ensemble = ["packed", "batched", "masked", "mc-dropout", "mimo"] + ensemble = ["packed", "batched", "masked", "mimo", "mc-dropout"] versions = { "std": [wideresnet28x10], "mc-dropout": [wideresnet28x10], @@ -44,32 +28,40 @@ class WideResNet: "mimo": [mimo_wideresnet28x10], } - def __new__( - cls, + def __init__( + self, num_classes: int, in_channels: int, - loss: type[nn.Module], - optimization_procedure: Any, + loss: nn.Module, version: Literal[ "std", "mc-dropout", "packed", "batched", "masked", "mimo" ], style: str = "imagenet", - num_estimators: int | None = None, + num_estimators: int = 1, dropout_rate: float = 0.0, + mixtype: str = "erm", + mixmode: str = "elem", + dist_sim: str = "emb", + kernel_tau_max: float = 1.0, + kernel_tau_std: float = 0.5, + mixup_alpha: float = 0, + cutmix_alpha: float = 0, + groups: int = 1, last_layer_dropout: bool = False, - groups: int | None = None, scale: float | None = None, alpha: int | None = None, - gamma: int | None = None, + gamma: int = 1, rho: float = 1.0, batch_repeat: int = 1, - use_entropy: bool = False, - use_logits: bool = False, - use_mi: bool = False, - use_variation_ratio: bool = False, - # pretrained: bool = False, - **kwargs, - ) -> LightningModule: + ood_criterion: Literal[ + "msp", "logit", "energy", "entropy", "mi", "vr" + ] = "msp", + log_plots: bool = False, + save_in_csv: bool = False, + calibration_set: Literal["val", "test"] | None = None, + eval_ood: bool = False, + eval_grouping_loss: bool = False, + ) -> None: r"""Wide-ResNet28x10 backbone baseline for classification providing support for various versions. @@ -77,7 +69,7 @@ def __new__( num_classes (int): Number of classes to predict. in_channels (int): Number of input channels. loss (nn.Module): Training loss. - optimization_procedure (Any): Optimization procedure, corresponds to + optim_recipe (Any): optimization recipe, corresponds to what expect the `LightningModule.configure_optimizers() `_ method. @@ -97,6 +89,17 @@ def __new__( Only used if :attr:`version` is either ``"packed"``, ``"batched"`` or ``"masked"`` Defaults to ``None``. dropout_rate (float, optional): Dropout rate. Defaults to ``0.0``. + mixtype (str, optional): Mixup type. Defaults to ``"erm"``. + mixmode (str, optional): Mixup mode. Defaults to ``"elem"``. + dist_sim (str, optional): Distance similarity. Defaults to ``"emb"``. + kernel_tau_max (float, optional): Maximum value for the kernel tau. + Defaults to ``1.0``. + kernel_tau_std (float, optional): Standard deviation for the kernel + tau. Defaults to ``0.5``. + mixup_alpha (float, optional): Alpha parameter for Mixup. Defaults + to ``0``. + cutmix_alpha (float, optional): Alpha parameter for CutMix. + Defaults to ``0``. last_layer_dropout (bool): whether to apply dropout to the last layer only. groups (int, optional): Number of groups in convolutions. Defaults to ``1``. @@ -114,18 +117,21 @@ def __new__( ``1``. batch_repeat (int, optional): Number of times to repeat the batch. Only used if :attr:`version` is ``"mimo"``. Defaults to ``1``. - use_entropy (bool, optional): Indicates whether to use the entropy - values as the OOD criterion or not. Defaults to ``False``. - use_logits (bool, optional): Indicates whether to use the logits as the - OOD criterion or not. Defaults to ``False``. - use_mi (bool, optional): Indicates whether to use the mutual - information as the OOD criterion or not. Defaults to ``False``. - use_variation_ratio (bool, optional): Indicates whether to use the - variation ratio as the OOD criterion or not. Defaults to ``False``. - pretrained (bool, optional): Indicates whether to use the pretrained - weights or not. Only used if :attr:`version` is ``"packed"``. + ood_criterion (str, optional): OOD criterion. Defaults to ``"msp"``. + MSP is the maximum softmax probability, logit is the maximum + logit, entropy is the entropy of the mean prediction, mi is the + mutual information of the ensemble and vr is the variation ratio + of the ensemble. + log_plots (bool, optional): Indicates whether to log the plots or not. Defaults to ``False``. - **kwargs: Additional arguments. + save_in_csv (bool, optional): Indicates whether to save the results in + a csv file or not. Defaults to ``False``. + calibration_set (Callable, optional): Calibration set. Defaults to + ``None``. + eval_ood (bool, optional): Indicates whether to evaluate the + OOD detection or not. Defaults to ``False``. + eval_grouping_loss (bool, optional): Indicates whether to evaluate the + grouping loss or not. Defaults to ``False``. Raises: ValueError: If :attr:`version` is not either ``"std"``, @@ -146,31 +152,28 @@ def __new__( format_batch_fn = nn.Identity() - if version not in cls.versions: + if version not in self.versions: raise ValueError(f"Unknown version: {version}") - if version in cls.ensemble: - params.update( - { - "num_estimators": num_estimators, - } - ) + if version in self.ensemble: + params |= { + "num_estimators": num_estimators, + } + if version != "mc-dropout": format_batch_fn = RepeatTarget(num_repeats=num_estimators) if version == "packed": - params.update( - { - "alpha": alpha, - "gamma": gamma, - } - ) + params |= { + "alpha": alpha, + "gamma": gamma, + } + elif version == "masked": - params.update( - { - "scale": scale, - } - ) + params |= { + "scale": scale, + } + elif version == "mimo": format_batch_fn = MIMOBatchFormat( num_estimators=num_estimators, @@ -178,12 +181,11 @@ def __new__( batch_repeat=batch_repeat, ) - # for lightning params - kwargs.update(params | {"version": version}) - if version == "mc-dropout": # std wideRn don't have `num_estimators` del params["num_estimators"] - model = cls.versions[version][0](**params) + + model = self.versions[version][0](**params) + if version == "mc-dropout": model = mc_dropout( model=model, @@ -191,74 +193,24 @@ def __new__( last_layer=last_layer_dropout, ) - # routine specific parameters - if version in cls.single: - return ClassificationSingle( - model=model, - loss=loss, - optimization_procedure=optimization_procedure, - format_batch_fn=format_batch_fn, - use_entropy=use_entropy, - use_logits=use_logits, - **kwargs, - ) - # version in cls.ensemble - return ClassificationEnsemble( + super().__init__( + num_classes=num_classes, model=model, loss=loss, - optimization_procedure=optimization_procedure, + num_estimators=num_estimators, format_batch_fn=format_batch_fn, - use_entropy=use_entropy, - use_logits=use_logits, - use_mi=use_mi, - use_variation_ratio=use_variation_ratio, - **kwargs, + mixtype=mixtype, + mixmode=mixmode, + dist_sim=dist_sim, + kernel_tau_max=kernel_tau_max, + kernel_tau_std=kernel_tau_std, + mixup_alpha=mixup_alpha, + cutmix_alpha=cutmix_alpha, + eval_ood=eval_ood, + eval_grouping_loss=eval_grouping_loss, + ood_criterion=ood_criterion, + log_plots=log_plots, + save_in_csv=save_in_csv, + calibration_set=calibration_set, ) - - @classmethod - def load_from_checkpoint( - cls, - checkpoint_path: str | Path, - hparams_file: str | Path, - **kwargs, - ) -> LightningModule: # coverage: ignore - if hparams_file is not None: - extension = str(hparams_file).split(".")[-1] - if extension.lower() == "csv": - hparams = load_hparams_from_tags_csv(hparams_file) - elif extension.lower() in ("yml", "yaml"): - hparams = load_hparams_from_yaml(hparams_file) - else: - raise ValueError( - ".csv, .yml or .yaml is required for `hparams_file`" - ) - - hparams.update(kwargs) - checkpoint = torch.load(checkpoint_path) - obj = cls(**hparams) - obj.load_state_dict(checkpoint["state_dict"]) - return obj - - @classmethod - def add_model_specific_args(cls, parser: ArgumentParser) -> ArgumentParser: - parser = ClassificationEnsemble.add_model_specific_args(parser) - parser = add_wideresnet_specific_args(parser) - parser = add_packed_specific_args(parser) - parser = add_masked_specific_args(parser) - parser = add_mimo_specific_args(parser) - parser = add_mc_dropout_specific_args(parser) - parser.add_argument( - "--version", - type=str, - choices=cls.versions.keys(), - default="std", - help=f"Variation of WideResNet. Choose among: {cls.versions.keys()}", - ) - parser.add_argument( - "--pretrained", - dest="pretrained", - action=BooleanOptionalAction, - default=False, - ) - - return parser + self.save_hyperparameters(ignore=["loss"]) diff --git a/torch_uncertainty/baselines/deep_ensembles.py b/torch_uncertainty/baselines/deep_ensembles.py deleted file mode 100644 index 1ccc5d7d..00000000 --- a/torch_uncertainty/baselines/deep_ensembles.py +++ /dev/null @@ -1,109 +0,0 @@ -from argparse import ArgumentParser -from pathlib import Path -from typing import Literal - -from pytorch_lightning import LightningModule - -from torch_uncertainty.models import deep_ensembles -from torch_uncertainty.routines.classification import ClassificationEnsemble -from torch_uncertainty.routines.regression import RegressionEnsemble -from torch_uncertainty.utils import get_version - -from .classification import VGG, ResNet, WideResNet -from .regression import MLP - - -class DeepEnsembles: - backbones = { - "mlp": MLP, - "resnet": ResNet, - "vgg": VGG, - "wideresnet": WideResNet, - } - - def __new__( - cls, - task: Literal["classification", "regression"], - log_path: str | Path, - checkpoint_ids: list[int], - backbone: Literal["mlp", "resnet", "vgg", "wideresnet"], - # num_estimators: int, - in_channels: int | None = None, - num_classes: int | None = None, - use_entropy: bool = False, - use_logits: bool = False, - use_mi: bool = False, - use_variation_ratio: bool = False, - **kwargs, - ) -> LightningModule: - log_path = Path(log_path) - - backbone_cls = cls.backbones[backbone] - - models = [] - for version in checkpoint_ids: # coverage: ignore - ckpt_file, hparams_file = get_version( - root=log_path, version=version - ) - trained_model = backbone_cls.load_from_checkpoint( - checkpoint_path=ckpt_file, - hparams_file=hparams_file, - loss=None, - optimization_procedure=None, - ).eval() - models.append(trained_model.model) - - de = deep_ensembles(models=models) - - if task == "classification": - return ClassificationEnsemble( - in_channels=in_channels, - num_classes=num_classes, - model=de, - loss=None, - optimization_procedure=None, - num_estimators=de.num_estimators, - use_entropy=use_entropy, - use_logits=use_logits, - use_mi=use_mi, - use_variation_ratio=use_variation_ratio, - ) - # task == "regression": - return RegressionEnsemble( - model=de, - loss=None, - optimization_procedure=None, - dist_estimation=2, - num_estimators=de.num_estimators, - mode="mean", - ) - - @classmethod - def add_model_specific_args(cls, parser: ArgumentParser) -> ArgumentParser: - parser = ClassificationEnsemble.add_model_specific_args(parser) - parser.add_argument( - "--task", - type=str, - choices=["classification", "regression"], - help="Task to be performed", - ) - parser.add_argument( - "--backbone", - type=str, - choices=cls.backbones.keys(), - help="Backbone architecture", - required=True, - ) - parser.add_argument( - "--versions", - type=int, - nargs="+", - help="Versions of the model to be ensembled", - ) - parser.add_argument( - "--log_path", - type=str, - help="Root directory of the models", - required=True, - ) - return parser diff --git a/torch_uncertainty/baselines/regression/__init__.py b/torch_uncertainty/baselines/regression/__init__.py index 9320254f..b4a1391a 100644 --- a/torch_uncertainty/baselines/regression/__init__.py +++ b/torch_uncertainty/baselines/regression/__init__.py @@ -1,2 +1,2 @@ # ruff: noqa: F401 -from .mlp import MLP +from .mlp import MLPBaseline diff --git a/torch_uncertainty/baselines/regression/mlp.py b/torch_uncertainty/baselines/regression/mlp.py index 0cf0c6c6..02e3c658 100644 --- a/torch_uncertainty/baselines/regression/mlp.py +++ b/torch_uncertainty/baselines/regression/mlp.py @@ -1,118 +1,88 @@ -from argparse import ArgumentParser -from pathlib import Path -from typing import Any, Literal +from typing import Literal -import torch -from pytorch_lightning import LightningModule -from pytorch_lightning.core.saving import ( - load_hparams_from_tags_csv, - load_hparams_from_yaml, -) from torch import nn -from torch_uncertainty.baselines.utils.parser_addons import ( - add_packed_specific_args, +from torch_uncertainty.layers.distributions import ( + LaplaceLayer, + NormalInverseGammaLayer, + NormalLayer, ) from torch_uncertainty.models.mlp import mlp, packed_mlp from torch_uncertainty.routines.regression import ( - RegressionEnsemble, - RegressionSingle, + RegressionRoutine, ) +from torch_uncertainty.transforms.batch import RepeatTarget -class MLP: +class MLPBaseline(RegressionRoutine): single = ["std"] ensemble = ["packed"] versions = {"std": mlp, "packed": packed_mlp} - def __new__( - cls, - num_outputs: int, + def __init__( + self, + output_dim: int, in_features: int, - loss: type[nn.Module], - optimization_procedure: Any, + loss: nn.Module, version: Literal["std", "packed"], hidden_dims: list[int], - dist_estimation: int, num_estimators: int | None = 1, + dropout_rate: float = 0.0, alpha: float | None = None, gamma: int = 1, - **kwargs, - ) -> LightningModule: + distribution: Literal["normal", "laplace", "nig"] | None = None, + ) -> None: r"""MLP baseline for regression providing support for various versions.""" + probabilistic = True params = { + "dropout_rate": dropout_rate, "in_features": in_features, - "num_outputs": num_outputs, + "num_outputs": output_dim, "hidden_dims": hidden_dims, } + if distribution == "normal": + final_layer = NormalLayer + final_layer_args = {"dim": output_dim} + params["num_outputs"] *= 2 + elif distribution == "laplace": + final_layer = LaplaceLayer + final_layer_args = {"dim": output_dim} + params["num_outputs"] *= 2 + elif distribution == "nig": + final_layer = NormalInverseGammaLayer + final_layer_args = {"dim": output_dim} + params["num_outputs"] *= 4 + else: # distribution is None: + probabilistic = False + final_layer = nn.Identity + final_layer_args = {} + + params["final_layer"] = final_layer + params["final_layer_args"] = final_layer_args + + format_batch_fn = nn.Identity() + + if version not in self.versions: + raise ValueError(f"Unknown version: {version}") + if version == "packed": params |= { "alpha": alpha, "num_estimators": num_estimators, "gamma": gamma, } + format_batch_fn = RepeatTarget(num_repeats=num_estimators) - if version not in cls.versions: - raise ValueError(f"Unknown version: {version}") - - model = cls.versions[version](**params) + model = self.versions[version](**params) - kwargs.update(params) - kwargs.update({"version": version}) - # routine specific parameters - if version in cls.single: - return RegressionSingle( - model=model, - loss=loss, - optimization_procedure=optimization_procedure, - dist_estimation=dist_estimation, - **kwargs, - ) - # version in cls.versions.keys(): - return RegressionEnsemble( + # version in self.versions: + super().__init__( + probabilistic=probabilistic, + output_dim=output_dim, model=model, loss=loss, - optimization_procedure=optimization_procedure, - dist_estimation=dist_estimation, - mode="mean", - **kwargs, - ) - return None - - @classmethod - def load_from_checkpoint( - cls, - checkpoint_path: str | Path, - hparams_file: str | Path, - **kwargs, - ) -> LightningModule: # coverage: ignore - if hparams_file is not None: - extension = str(hparams_file).split(".")[-1] - if extension.lower() == "csv": - hparams = load_hparams_from_tags_csv(hparams_file) - elif extension.lower() in ("yml", "yaml"): - hparams = load_hparams_from_yaml(hparams_file) - else: - raise ValueError( - ".csv, .yml or .yaml is required for `hparams_file`" - ) - - hparams.update(kwargs) - checkpoint = torch.load(checkpoint_path) - obj = cls(**hparams) - obj.load_state_dict(checkpoint["state_dict"]) - return obj - - @classmethod - def add_model_specific_args(cls, parser: ArgumentParser) -> ArgumentParser: - parser = RegressionEnsemble.add_model_specific_args(parser) - parser = add_packed_specific_args(parser) - parser.add_argument( - "--version", - type=str, - choices=cls.versions.keys(), - default="std", - help=f"Variation of MLP. Choose among: {cls.versions.keys()}", + num_estimators=num_estimators, + format_batch_fn=format_batch_fn, ) - return parser + self.save_hyperparameters(ignore=["loss"]) diff --git a/torch_uncertainty/baselines/segmentation/__init__.py b/torch_uncertainty/baselines/segmentation/__init__.py new file mode 100644 index 00000000..fe2488e4 --- /dev/null +++ b/torch_uncertainty/baselines/segmentation/__init__.py @@ -0,0 +1,2 @@ +# ruff: noqa: F401 +from .segformer import SegFormerBaseline diff --git a/torch_uncertainty/baselines/segmentation/segformer.py b/torch_uncertainty/baselines/segmentation/segformer.py new file mode 100644 index 00000000..1e8185a1 --- /dev/null +++ b/torch_uncertainty/baselines/segmentation/segformer.py @@ -0,0 +1,80 @@ +from typing import Literal + +from torch import nn + +from torch_uncertainty.models.segmentation.segformer import ( + seg_former_b0, + seg_former_b1, + seg_former_b2, + seg_former_b3, + seg_former_b4, + seg_former_b5, +) +from torch_uncertainty.routines.segmentation import SegmentationRoutine + + +class SegFormerBaseline(SegmentationRoutine): + single = ["std"] + versions = { + "std": [ + seg_former_b0, + seg_former_b1, + seg_former_b2, + seg_former_b3, + seg_former_b4, + seg_former_b5, + ] + } + archs = [0, 1, 2, 3, 4, 5] + + def __init__( + self, + num_classes: int, + loss: nn.Module, + version: Literal["std"], + arch: int, + num_estimators: int = 1, + ) -> None: + r"""SegFormer backbone baseline for segmentation providing support for + various versions and architectures. + + Args: + num_classes (int): Number of classes to predict. + loss (type[Module]): Training loss. + version (str): + Determines which SegFormer version to use. Options are: + + - ``"std"``: original SegFormer + + arch (int): + Determines which architecture to use. Options are: + + - ``0``: SegFormer-B0 + - ``1``: SegFormer-B1 + - ``2``: SegFormer-B2 + - ``3``: SegFormer-B3 + - ``4``: SegFormer-B4 + - ``5``: SegFormer-B5 + + num_estimators (int, optional): Number of estimators in the + ensemble. Defaults to 1 (single model). + """ + params = { + "num_classes": num_classes, + } + + format_batch_fn = nn.Identity() + + if version not in self.versions: + raise ValueError(f"Unknown version {version}") + + model = self.versions[version][self.archs.index(arch)](**params) + + super().__init__( + num_classes=num_classes, + model=model, + loss=loss, + num_estimators=num_estimators, + format_batch_fn=format_batch_fn, + ) + self.save_hyperparameters(ignore=["loss"]) diff --git a/torch_uncertainty/baselines/utils/parser_addons.py b/torch_uncertainty/baselines/utils/parser_addons.py deleted file mode 100644 index 763aa908..00000000 --- a/torch_uncertainty/baselines/utils/parser_addons.py +++ /dev/null @@ -1,139 +0,0 @@ -from argparse import ArgumentParser - - -def add_resnet_specific_args(parser: ArgumentParser) -> ArgumentParser: - """Add ResNet specific arguments to parser. - - Args: - parser (ArgumentParser): Argument parser. - - Adds the following arguments: - --arch (int): Architecture of ResNet. Choose among: [18, 34, 50, 101, 152] - --dropout_rate (float): Dropout rate. - --groups (int): Number of groups. - """ - # style_choices = ["cifar", "imagenet", "robust"] - archs = [18, 20, 34, 50, 101, 152] - parser.add_argument( - "--arch", - type=int, - choices=archs, - default=18, - help=f"Architecture of ResNet. Choose among: {archs}", - ) - parser.add_argument( - "--dropout_rate", - type=float, - default=0.0, - help="Dropout rate", - ) - parser.add_argument( - "--groups", - type=int, - default=1, - help="Number of groups", - ) - return parser - - -def add_vgg_specific_args(parser: ArgumentParser) -> ArgumentParser: - # style_choices = ["cifar", "imagenet", "robust"] - archs = [11, 13, 16, 19] - parser.add_argument( - "--arch", - type=int, - choices=archs, - default=11, - help=f"Architecture of VGG. Choose among: {archs}", - ) - parser.add_argument( - "--groups", - type=int, - default=1, - help="Number of groups", - ) - parser.add_argument( - "--dropout_rate", - type=float, - default=0.1, - help="Dropout rate", - ) - return parser - - -def add_wideresnet_specific_args(parser: ArgumentParser) -> ArgumentParser: - # style_choices = ["cifar", "imagenet"] - parser.add_argument( - "--dropout_rate", - type=float, - default=0.3, - help="Dropout rate", - ) - parser.add_argument( - "--groups", - type=int, - default=1, - help="Number of groups", - ) - return parser - - -def add_mlp_specific_args(parser: ArgumentParser) -> ArgumentParser: - parser.add_argument( - "--dropout_rate", - type=float, - default=0.1, - help="Dropout rate", - ) - return parser - - -def add_packed_specific_args(parser: ArgumentParser) -> ArgumentParser: - parser.add_argument( - "--alpha", - type=int, - default=None, - help="Alpha for Packed-Ensembles", - ) - parser.add_argument( - "--gamma", - type=int, - default=1, - help="Gamma for Packed-Ensembles", - ) - return parser - - -def add_masked_specific_args(parser: ArgumentParser) -> ArgumentParser: - parser.add_argument( - "--scale", - type=float, - default=None, - help="Scale for Masksembles", - ) - return parser - - -def add_mimo_specific_args(parser: ArgumentParser) -> ArgumentParser: - parser.add_argument( - "--rho", - type=float, - default=0.0, - help="Rho for MIMO", - ) - parser.add_argument( - "--batch_repeat", - type=int, - default=1, - help="Batch repeat for MIMO", - ) - return parser - - -def add_mc_dropout_specific_args(parser: ArgumentParser) -> ArgumentParser: - parser.add_argument( - "--last_layer_dropout", - action="store_true", - help="Whether to apply dropout to the last layer only", - ) - return parser diff --git a/torch_uncertainty/datamodules/__init__.py b/torch_uncertainty/datamodules/__init__.py index 90b9f0eb..24859701 100644 --- a/torch_uncertainty/datamodules/__init__.py +++ b/torch_uncertainty/datamodules/__init__.py @@ -1,7 +1,8 @@ # ruff: noqa: F401 -from .cifar10 import CIFAR10DataModule -from .cifar100 import CIFAR100DataModule -from .imagenet import ImageNetDataModule -from .mnist import MNISTDataModule -from .tiny_imagenet import TinyImageNetDataModule +from .classification.cifar10 import CIFAR10DataModule +from .classification.cifar100 import CIFAR100DataModule +from .classification.imagenet import ImageNetDataModule +from .classification.mnist import MNISTDataModule +from .classification.tiny_imagenet import TinyImageNetDataModule +from .segmentation import CamVidDataModule, CityscapesDataModule from .uci_regression import UCIDataModule diff --git a/torch_uncertainty/datamodules/abstract.py b/torch_uncertainty/datamodules/abstract.py index ecda2634..1da19ced 100644 --- a/torch_uncertainty/datamodules/abstract.py +++ b/torch_uncertainty/datamodules/abstract.py @@ -1,9 +1,8 @@ -from argparse import ArgumentParser from pathlib import Path -from typing import Any, Literal +from typing import Literal +from lightning.pytorch.core import LightningDataModule from numpy.typing import ArrayLike -from pytorch_lightning import LightningDataModule from sklearn.model_selection import StratifiedKFold from torch.utils.data import DataLoader, Dataset from torch.utils.data.sampler import SubsetRandomSampler @@ -19,10 +18,10 @@ def __init__( self, root: str | Path, batch_size: int, - num_workers: int = 1, - pin_memory: bool = True, - persistent_workers: bool = True, - **kwargs, + val_split: float, + num_workers: int, + pin_memory: bool, + persistent_workers: bool, ) -> None: """Abstract DataModule class. @@ -33,17 +32,16 @@ def __init__( Args: root (str): Root directory of the datasets. batch_size (int): Number of samples per batch. - num_workers (int): Number of workers to use for data loading. Defaults - to ``1``. - pin_memory (bool): Whether to pin memory. Defaults to ``True``. - persistent_workers (bool): Whether to use persistent workers. Defaults - to ``True``. - kwargs (Any): Other arguments. + val_split (float): Share of samples to use for validation. + num_workers (int): Number of workers to use for data loading. + pin_memory (bool): Whether to pin memory. + persistent_workers (bool): Whether to use persistent workers. """ super().__init__() self.root = Path(root) self.batch_size = batch_size + self.val_split = val_split self.num_workers = num_workers self.pin_memory = pin_memory @@ -140,6 +138,7 @@ def make_cross_val_splits( val_idx=val_idx, datamodule=self, batch_size=self.batch_size, + val_split=self.val_split, num_workers=self.num_workers, pin_memory=self.pin_memory, persistent_workers=self.persistent_workers, @@ -148,22 +147,6 @@ def make_cross_val_splits( return cv_dm - @classmethod - def add_argparse_args( - cls, - parent_parser: ArgumentParser, - **kwargs: Any, - ) -> ArgumentParser: - p = parent_parser.add_argument_group("datamodule") - p.add_argument("--root", type=str, default="./data/") - p.add_argument("--batch_size", type=int, default=128) - p.add_argument("--val_split", type=float, default=None) - p.add_argument("--num_workers", type=int, default=4) - p.add_argument("--use_cv", action="store_true") - p.add_argument("--n_splits", type=int, default=10) - p.add_argument("--train_over", type=int, default=4) - return parent_parser - class CrossValDataModule(AbstractDataModule): def __init__( @@ -173,18 +156,18 @@ def __init__( val_idx: ArrayLike, datamodule: AbstractDataModule, batch_size: int, - num_workers: int = 1, - pin_memory: bool = True, - persistent_workers: bool = True, - **kwargs, + val_split: float, + num_workers: int, + pin_memory: bool, + persistent_workers: bool, ) -> None: super().__init__( - root, - batch_size, - num_workers, - pin_memory, - persistent_workers, - **kwargs, + root=root, + batch_size=batch_size, + val_split=val_split, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=persistent_workers, ) self.train_idx = train_idx @@ -197,6 +180,8 @@ def setup(self, stage: str | None = None) -> None: self.val = self.dm.val elif stage == "test": self.test = self.val + else: + raise ValueError(f"Stage {stage} not supported.") def _data_loader(self, dataset: Dataset, idx: ArrayLike) -> DataLoader: return DataLoader( diff --git a/torch_uncertainty/datamodules/cifar10.py b/torch_uncertainty/datamodules/classification/cifar10.py similarity index 86% rename from torch_uncertainty/datamodules/cifar10.py rename to torch_uncertainty/datamodules/classification/cifar10.py index 7c7bf6dc..45452115 100644 --- a/torch_uncertainty/datamodules/cifar10.py +++ b/torch_uncertainty/datamodules/classification/cifar10.py @@ -1,20 +1,19 @@ -from argparse import ArgumentParser from pathlib import Path -from typing import Any, Literal +from typing import Literal import numpy as np import torchvision.transforms as T from numpy.typing import ArrayLike from timm.data.auto_augment import rand_augment_transform from torch import nn -from torch.utils.data import DataLoader, random_split +from torch.utils.data import DataLoader from torchvision.datasets import CIFAR10, SVHN +from torch_uncertainty.datamodules.abstract import AbstractDataModule from torch_uncertainty.datasets import AggregatedDataset from torch_uncertainty.datasets.classification import CIFAR10C, CIFAR10H from torch_uncertainty.transforms import Cutout - -from .abstract import AbstractDataModule +from torch_uncertainty.utils import create_train_val_split class CIFAR10DataModule(AbstractDataModule): @@ -26,9 +25,9 @@ class CIFAR10DataModule(AbstractDataModule): def __init__( self, root: str | Path, - eval_ood: bool, batch_size: int, - val_split: float = 0.0, + eval_ood: bool = False, + val_split: float | None = None, num_workers: int = 1, cutout: int | None = None, auto_augment: str | None = None, @@ -37,7 +36,6 @@ def __init__( num_dataloaders: int = 1, pin_memory: bool = True, persistent_workers: bool = True, - **kwargs, ) -> None: """DataModule for CIFAR10. @@ -60,11 +58,11 @@ def __init__( pin_memory (bool): Whether to pin memory. Defaults to ``True``. persistent_workers (bool): Whether to use persistent workers. Defaults to ``True``. - kwargs: Additional arguments. """ super().__init__( root=root, batch_size=batch_size, + val_split=val_split, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=persistent_workers, @@ -151,13 +149,12 @@ def setup(self, stage: Literal["fit", "test"] | None = None) -> None: transform=self.train_transform, ) if self.val_split: - self.train, self.val = random_split( + self.train, self.val = create_train_val_split( full, - [ - 1 - self.val_split, - self.val_split, - ], + self.val_split, + self.test_transform, ) + else: self.train = full self.val = self.dataset( @@ -166,7 +163,7 @@ def setup(self, stage: Literal["fit", "test"] | None = None) -> None: download=False, transform=self.test_transform, ) - elif stage == "test": + if stage == "test" or stage is None: if self.test_alt is None: self.test = self.dataset( self.root, @@ -187,7 +184,7 @@ def setup(self, stage: Literal["fit", "test"] | None = None) -> None: download=False, transform=self.test_transform, ) - else: + if stage not in ["fit", "test", None]: raise ValueError(f"Stage {stage} is not supported.") def train_dataloader(self) -> DataLoader: @@ -224,21 +221,3 @@ def _get_train_targets(self) -> ArrayLike: if self.val_split: return np.array(self.train.dataset.targets)[self.train.indices] return np.array(self.train.targets) - - @classmethod - def add_argparse_args( - cls, - parent_parser: ArgumentParser, - **kwargs: Any, - ) -> ArgumentParser: - p = super().add_argparse_args(parent_parser) - - # Arguments for CIFAR10 - p.add_argument("--cutout", type=int, default=0) - p.add_argument("--auto_augment", type=str) - p.add_argument("--test_alt", choices=["c", "h"], default=None) - p.add_argument( - "--severity", dest="corruption_severity", type=int, default=None - ) - p.add_argument("--eval-ood", action="store_true") - return parent_parser diff --git a/torch_uncertainty/datamodules/cifar100.py b/torch_uncertainty/datamodules/classification/cifar100.py similarity index 85% rename from torch_uncertainty/datamodules/cifar100.py rename to torch_uncertainty/datamodules/classification/cifar100.py index 63574ce1..bc5a3691 100644 --- a/torch_uncertainty/datamodules/cifar100.py +++ b/torch_uncertainty/datamodules/classification/cifar100.py @@ -1,6 +1,5 @@ -from argparse import ArgumentParser from pathlib import Path -from typing import Any, Literal +from typing import Literal import numpy as np import torch @@ -8,14 +7,14 @@ from numpy.typing import ArrayLike from timm.data.auto_augment import rand_augment_transform from torch import nn -from torch.utils.data import DataLoader, random_split +from torch.utils.data import DataLoader from torchvision.datasets import CIFAR100, SVHN +from torch_uncertainty.datamodules.abstract import AbstractDataModule from torch_uncertainty.datasets import AggregatedDataset from torch_uncertainty.datasets.classification import CIFAR100C from torch_uncertainty.transforms import Cutout - -from .abstract import AbstractDataModule +from torch_uncertainty.utils import create_train_val_split class CIFAR100DataModule(AbstractDataModule): @@ -27,9 +26,9 @@ class CIFAR100DataModule(AbstractDataModule): def __init__( self, root: str | Path, - eval_ood: bool, batch_size: int, - val_split: float = 0.0, + eval_ood: bool = False, + val_split: float | None = None, cutout: int | None = None, randaugment: bool = False, auto_augment: str | None = None, @@ -39,7 +38,6 @@ def __init__( num_workers: int = 1, pin_memory: bool = True, persistent_workers: bool = True, - **kwargs, ) -> None: """DataModule for CIFAR100. @@ -63,18 +61,17 @@ def __init__( pin_memory (bool): Whether to pin memory. Defaults to ``True``. persistent_workers (bool): Whether to use persistent workers. Defaults to ``True``. - kwargs: Additional arguments. """ super().__init__( root=root, batch_size=batch_size, + val_split=val_split, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=persistent_workers, ) self.eval_ood = eval_ood - self.val_split = val_split self.num_dataloaders = num_dataloaders if test_alt == "c": @@ -152,12 +149,10 @@ def setup(self, stage: Literal["fit", "test"] | None = None) -> None: transform=self.train_transform, ) if self.val_split: - self.train, self.val = random_split( + self.train, self.val = create_train_val_split( full, - [ - 1 - self.val_split, - self.val_split, - ], + self.val_split, + self.test_transform, ) else: self.train = full @@ -167,7 +162,7 @@ def setup(self, stage: Literal["fit", "test"] | None = None) -> None: download=False, transform=self.test_transform, ) - elif stage == "test": + if stage == "test" or stage is None: if self.test_alt is None: self.test = self.dataset( self.root, @@ -188,7 +183,7 @@ def setup(self, stage: Literal["fit", "test"] | None = None) -> None: download=False, transform=self.test_transform, ) - else: + if stage not in ["fit", "test", None]: raise ValueError(f"Stage {stage} is not supported.") def train_dataloader(self) -> DataLoader: @@ -225,22 +220,3 @@ def _get_train_targets(self) -> ArrayLike: if self.val_split: return np.array(self.train.dataset.targets)[self.train.indices] return np.array(self.train.targets) - - @classmethod - def add_argparse_args( - cls, - parent_parser: ArgumentParser, - **kwargs: Any, - ) -> ArgumentParser: - p = super().add_argparse_args(parent_parser) - - # Arguments for CIFAR100 - p.add_argument("--cutout", type=int, default=0) - p.add_argument("--randaugment", dest="randaugment", action="store_true") - p.add_argument("--auto_augment", type=str) - p.add_argument("--test_alt", choices=["c"], default=None) - p.add_argument( - "--severity", dest="corruption_severity", type=int, default=1 - ) - p.add_argument("--eval-ood", action="store_true") - return parent_parser diff --git a/torch_uncertainty/datamodules/imagenet.py b/torch_uncertainty/datamodules/classification/imagenet.py similarity index 88% rename from torch_uncertainty/datamodules/imagenet.py rename to torch_uncertainty/datamodules/classification/imagenet.py index f32774ee..8f89a23b 100644 --- a/torch_uncertainty/datamodules/imagenet.py +++ b/torch_uncertainty/datamodules/classification/imagenet.py @@ -1,14 +1,13 @@ import copy -from argparse import ArgumentParser from pathlib import Path -from typing import Any, Literal +from typing import Literal import torchvision.transforms as T import yaml from timm.data.auto_augment import rand_augment_transform from timm.data.mixup import Mixup from torch import nn -from torch.utils.data import DataLoader, Subset, random_split +from torch.utils.data import DataLoader, Subset from torchvision.datasets import DTD, SVHN, ImageNet, INaturalist from torch_uncertainty.datamodules.abstract import AbstractDataModule @@ -18,6 +17,7 @@ ImageNetR, OpenImageO, ) +from torch_uncertainty.utils.misc import create_train_val_split class ImageNetDataModule(AbstractDataModule): @@ -38,8 +38,8 @@ class ImageNetDataModule(AbstractDataModule): def __init__( self, root: str | Path, - eval_ood: bool, batch_size: int, + eval_ood: bool = False, val_split: float | Path | None = None, ood_ds: str = "openimage-o", test_alt: str | None = None, @@ -49,7 +49,6 @@ def __init__( num_workers: int = 1, pin_memory: bool = True, persistent_workers: bool = True, - **kwargs, ) -> None: """DataModule for ImageNet. @@ -77,6 +76,7 @@ def __init__( super().__init__( root=Path(root), batch_size=batch_size, + val_split=val_split, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=persistent_workers, @@ -203,12 +203,10 @@ def setup(self, stage: Literal["fit", "test"] | None = None) -> None: transform=self.train_transform, ) if self.val_split and isinstance(self.val_split, float): - self.train, self.val = random_split( + self.train, self.val = create_train_val_split( full, - [ - 1 - self.val_split, - self.val_split, - ], + self.val_split, + self.test_transform, ) elif isinstance(self.val_split, Path): self.train = Subset(full, self.train_indices) @@ -222,13 +220,13 @@ def setup(self, stage: Literal["fit", "test"] | None = None) -> None: split="val", transform=self.test_transform, ) - elif stage == "test": + if stage == "test" or stage is None: self.test = self.dataset( self.root, split="val", transform=self.test_transform, ) - else: + if stage not in ["fit", "test", None]: raise ValueError(f"Stage {stage} is not supported.") if self.eval_ood: @@ -257,25 +255,6 @@ def test_dataloader(self) -> list[DataLoader]: dataloader.append(self._data_loader(self.ood)) return dataloader - @classmethod - def add_argparse_args( - cls, - parent_parser: ArgumentParser, - **kwargs: Any, - ) -> ArgumentParser: - p = super().add_argparse_args(parent_parser) - - # Arguments for ImageNet - p.add_argument("--eval-ood", action="store_true") - p.add_argument("--ood_ds", choices=cls.ood_datasets, default="svhn") - p.add_argument("--test_alt", choices=cls.test_datasets, default=None) - p.add_argument("--procedure", choices=["ViT", "A3"], default=None) - p.add_argument("--train_size", type=int, default=224) - p.add_argument( - "--rand_augment", dest="rand_augment_opt", type=str, default=None - ) - return parent_parser - def read_indices(path: Path) -> list[str]: # coverage: ignore """Read a file and return its lines as a list. diff --git a/torch_uncertainty/datamodules/mnist.py b/torch_uncertainty/datamodules/classification/mnist.py similarity index 84% rename from torch_uncertainty/datamodules/mnist.py rename to torch_uncertainty/datamodules/classification/mnist.py index f60606a0..77a6f4f5 100644 --- a/torch_uncertainty/datamodules/mnist.py +++ b/torch_uncertainty/datamodules/classification/mnist.py @@ -1,15 +1,15 @@ -from argparse import ArgumentParser from pathlib import Path -from typing import Any, Literal +from typing import Literal import torchvision.transforms as T from torch import nn -from torch.utils.data import DataLoader, random_split +from torch.utils.data import DataLoader from torchvision.datasets import MNIST, FashionMNIST from torch_uncertainty.datamodules.abstract import AbstractDataModule from torch_uncertainty.datasets.classification import MNISTC, NotMNIST from torch_uncertainty.transforms import Cutout +from torch_uncertainty.utils import create_train_val_split class MNISTDataModule(AbstractDataModule): @@ -22,16 +22,15 @@ class MNISTDataModule(AbstractDataModule): def __init__( self, root: str | Path, - eval_ood: bool, batch_size: int, + eval_ood: bool = False, ood_ds: Literal["fashion", "not"] = "fashion", - val_split: float = 0.0, + val_split: float | None = None, num_workers: int = 1, cutout: int | None = None, test_alt: Literal["c"] | None = None, pin_memory: bool = True, persistent_workers: bool = True, - **kwargs, ) -> None: """DataModule for MNIST. @@ -56,6 +55,7 @@ def __init__( super().__init__( root=root, batch_size=batch_size, + val_split=val_split, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=persistent_workers, @@ -63,7 +63,6 @@ def __init__( self.eval_ood = eval_ood self.batch_size = batch_size - self.val_split = val_split if test_alt == "c": self.dataset = MNISTC @@ -114,12 +113,10 @@ def setup(self, stage: Literal["fit", "test"] | None = None) -> None: transform=self.train_transform, ) if self.val_split: - self.train, self.val = random_split( + self.train, self.val = create_train_val_split( full, - [ - 1 - self.val_split, - self.val_split, - ], + self.val_split, + self.test_transform, ) else: self.train = full @@ -129,14 +126,14 @@ def setup(self, stage: Literal["fit", "test"] | None = None) -> None: download=False, transform=self.test_transform, ) - elif stage == "test": + if stage == "test" or stage is None: self.test = self.dataset( self.root, train=False, download=False, transform=self.test_transform, ) - else: + if stage not in ["fit", "test", None]: raise ValueError(f"Stage {stage} is not supported.") if self.eval_ood: @@ -158,17 +155,3 @@ def test_dataloader(self) -> list[DataLoader]: if self.eval_ood: dataloader.append(self._data_loader(self.ood)) return dataloader - - @classmethod - def add_argparse_args( - cls, - parent_parser: ArgumentParser, - **kwargs: Any, - ) -> ArgumentParser: - p = super().add_argparse_args(parent_parser) - - # Arguments for MNIST - p.add_argument("--eval-ood", action="store_true") - p.add_argument("--ood_ds", choices=cls.ood_datasets, default="fashion") - p.add_argument("--test_alt", choices=["c"], default=None) - return parent_parser diff --git a/torch_uncertainty/datamodules/tiny_imagenet.py b/torch_uncertainty/datamodules/classification/tiny_imagenet.py similarity index 85% rename from torch_uncertainty/datamodules/tiny_imagenet.py rename to torch_uncertainty/datamodules/classification/tiny_imagenet.py index 1a0568ac..5430264d 100644 --- a/torch_uncertainty/datamodules/tiny_imagenet.py +++ b/torch_uncertainty/datamodules/classification/tiny_imagenet.py @@ -1,7 +1,7 @@ -from argparse import ArgumentParser from pathlib import Path -from typing import Any, Literal +from typing import Literal +import numpy as np import torchvision.transforms as T from numpy.typing import ArrayLike from timm.data.auto_augment import rand_augment_transform @@ -9,9 +9,9 @@ from torch.utils.data import ConcatDataset, DataLoader from torchvision.datasets import DTD, SVHN +from torch_uncertainty.datamodules.abstract import AbstractDataModule from torch_uncertainty.datasets.classification import ImageNetO, TinyImageNet - -from .abstract import AbstractDataModule +from torch_uncertainty.utils import create_train_val_split class TinyImageNetDataModule(AbstractDataModule): @@ -22,18 +22,19 @@ class TinyImageNetDataModule(AbstractDataModule): def __init__( self, root: str | Path, - eval_ood: bool, batch_size: int, + eval_ood: bool = False, + val_split: float | None = None, ood_ds: str = "svhn", rand_augment_opt: str | None = None, num_workers: int = 1, pin_memory: bool = True, persistent_workers: bool = True, - **kwargs, ) -> None: super().__init__( root=root, batch_size=batch_size, + val_split=val_split, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=persistent_workers, @@ -121,23 +122,31 @@ def prepare_data(self) -> None: # coverage: ignore def setup(self, stage: Literal["fit", "test"] | None = None) -> None: if stage == "fit" or stage is None: - self.train = self.dataset( + full = self.dataset( self.root, split="train", transform=self.train_transform, ) - self.val = self.dataset( - self.root, - split="val", - transform=self.test_transform, - ) - elif stage == "test": + if self.val_split: + self.train, self.val = create_train_val_split( + full, + self.val_split, + self.test_transform, + ) + else: + self.train = full + self.val = self.dataset( + self.root, + split="val", + transform=self.test_transform, + ) + if stage == "test" or stage is None: self.test = self.dataset( self.root, split="val", transform=self.test_transform, ) - else: + if stage not in ["fit", "test", None]: raise ValueError(f"Stage {stage} is not supported.") if self.eval_ood: @@ -200,22 +209,11 @@ def test_dataloader(self) -> list[DataLoader]: return dataloader def _get_train_data(self) -> ArrayLike: + if self.val_split: + return self.train.dataset.samples[self.train.indices] return self.train.samples def _get_train_targets(self) -> ArrayLike: - return self.train.label_data - - @classmethod - def add_argparse_args( - cls, - parent_parser: ArgumentParser, - **kwargs: Any, - ) -> ArgumentParser: - p = super().add_argparse_args(parent_parser) - - # Arguments for Tiny Imagenet - p.add_argument( - "--rand_augment", dest="rand_augment_opt", type=str, default=None - ) - p.add_argument("--eval-ood", action="store_true") - return parent_parser + if self.val_split: + return np.array(self.train.dataset.label_data)[self.train.indices] + return np.array(self.train.label_data) diff --git a/torch_uncertainty/datamodules/depth_estimation/__init__.py b/torch_uncertainty/datamodules/depth_estimation/__init__.py new file mode 100644 index 00000000..dc94a8cb --- /dev/null +++ b/torch_uncertainty/datamodules/depth_estimation/__init__.py @@ -0,0 +1,2 @@ +# ruff: noqa: F401 +from .muad import MUADDataModule diff --git a/torch_uncertainty/datamodules/depth_estimation/muad.py b/torch_uncertainty/datamodules/depth_estimation/muad.py new file mode 100644 index 00000000..a751926e --- /dev/null +++ b/torch_uncertainty/datamodules/depth_estimation/muad.py @@ -0,0 +1,144 @@ +from pathlib import Path + +import torch +from torch.nn.common_types import _size_2_t +from torch.nn.modules.utils import _pair +from torchvision import tv_tensors +from torchvision.transforms import v2 + +from torch_uncertainty.datamodules.abstract import AbstractDataModule +from torch_uncertainty.datasets import MUAD +from torch_uncertainty.transforms import RandomRescale +from torch_uncertainty.utils.misc import create_train_val_split + + +class MUADDataModule(AbstractDataModule): + def __init__( + self, + root: str | Path, + batch_size: int, + crop_size: _size_2_t = 1024, + inference_size: _size_2_t = (1024, 2048), + val_split: float | None = None, + num_workers: int = 1, + pin_memory: bool = True, + persistent_workers: bool = True, + ) -> None: + r"""Segmentation DataModule for the MUAD dataset. + + Args: + root (str or Path): Root directory of the datasets. + batch_size (int): Number of samples per batch. + crop_size (sequence or int, optional): Desired input image and + segmentation mask sizes during training. If :attr:`crop_size` is an + int instead of sequence like :math:`(H, W)`, a square crop + :math:`(\text{size},\text{size})` is made. If provided a sequence + of length :math:`1`, it will be interpreted as + :math:`(\text{size[0]},\text{size[1]})`. Defaults to ``1024``. + inference_size (sequence or int, optional): Desired input image and + segmentation mask sizes during inference. If size is an int, + smaller edge of the images will be matched to this number, i.e., + :math:`\text{height}>\text{width}`, then image will be rescaled to + :math:`(\text{size}\times\text{height}/\text{width},\text{size})`. + Defaults to ``(1024,2048)``. + val_split (float or None, optional): Share of training samples to use + for validation. Defaults to ``None``. + num_workers (int, optional): Number of dataloaders to use. Defaults to + ``1``. + pin_memory (bool, optional): Whether to pin memory. Defaults to + ``True``. + persistent_workers (bool, optional): Whether to use persistent workers. + Defaults to ``True``. + """ + super().__init__( + root=root, + batch_size=batch_size, + val_split=val_split, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=persistent_workers, + ) + + self.dataset = MUAD + self.crop_size = _pair(crop_size) + self.inference_size = _pair(inference_size) + + self.train_transform = v2.Compose( + [ + RandomRescale(min_scale=0.5, max_scale=2.0, antialias=True), + v2.RandomCrop( + size=self.crop_size, + pad_if_needed=True, + fill={tv_tensors.Image: 0, tv_tensors.Mask: -float("inf")}, + ), + v2.RandomHorizontalFlip(), + v2.ToDtype( + dtype={ + tv_tensors.Image: torch.float32, + "others": None, + }, + scale=True, + ), + v2.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) + self.test_transform = v2.Compose( + [ + v2.Resize(size=self.inference_size, antialias=True), + v2.ToDtype( + dtype={ + tv_tensors.Image: torch.float32, + "others": None, + }, + scale=True, + ), + v2.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) + + def prepare_data(self) -> None: # coverage: ignore + self.dataset( + root=self.root, split="train", target_type="depth", download=True + ) + self.dataset( + root=self.root, split="val", target_type="depth", download=True + ) + + def setup(self, stage: str | None = None) -> None: + if stage == "fit" or stage is None: + full = self.dataset( + root=self.root, + split="train", + target_type="depth", + transforms=self.train_transform, + ) + + if self.val_split is not None: + self.train, self.val = create_train_val_split( + full, + self.val_split, + self.test_transform, + ) + else: + self.train = full + self.val = self.dataset( + root=self.root, + split="val", + target_type="depth", + transforms=self.test_transform, + ) + + if stage == "test" or stage is None: + self.test = self.dataset( + root=self.root, + split="val", + target_type="depth", + transforms=self.test_transform, + ) + + if stage not in ["fit", "test", None]: + raise ValueError(f"Stage {stage} is not supported.") diff --git a/torch_uncertainty/datamodules/segmentation/__init__.py b/torch_uncertainty/datamodules/segmentation/__init__.py new file mode 100644 index 00000000..b4f55984 --- /dev/null +++ b/torch_uncertainty/datamodules/segmentation/__init__.py @@ -0,0 +1,4 @@ +# ruff: noqa: F401 +from .camvid import CamVidDataModule +from .cityscapes import CityscapesDataModule +from .muad import MUADDataModule diff --git a/torch_uncertainty/datamodules/segmentation/camvid.py b/torch_uncertainty/datamodules/segmentation/camvid.py new file mode 100644 index 00000000..4a4aee65 --- /dev/null +++ b/torch_uncertainty/datamodules/segmentation/camvid.py @@ -0,0 +1,123 @@ +from pathlib import Path + +import torch +from torchvision import tv_tensors +from torchvision.transforms import v2 + +from torch_uncertainty.datamodules.abstract import AbstractDataModule +from torch_uncertainty.datasets.segmentation import CamVid + + +class CamVidDataModule(AbstractDataModule): + def __init__( + self, + root: str | Path, + batch_size: int, + val_split: float | None = None, # FIXME: not used for now + num_workers: int = 1, + pin_memory: bool = True, + persistent_workers: bool = True, + ) -> None: + r"""DataModule for the CamVid dataset. + + Args: + root (str or Path): Root directory of the datasets. + batch_size (int): Number of samples per batch. + val_split (float or None, optional): Share of training samples to use + for validation. Defaults to ``None``. + num_workers (int, optional): Number of dataloaders to use. Defaults to + ``1``. + pin_memory (bool, optional): Whether to pin memory. Defaults to + ``True``. + persistent_workers (bool, optional): Whether to use persistent workers. + Defaults to ``True``. + + Note: + This datamodule injects the following transforms into the training and + validation/test datasets: + + .. code-block:: python + + from torchvision.transforms import v2 + + v2.Compose( + [ + v2.Resize((360, 480)), + v2.ToDtype( + dtype={ + tv_tensors.Image: torch.float32, + tv_tensors.Mask: torch.int64, + "others": None, + }, + scale=True, + ), + ] + ) + + This behavior can be modified by overriding ``self.train_transform`` + and ``self.test_transform`` after initialization. + """ + super().__init__( + root=root, + batch_size=batch_size, + val_split=val_split, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=persistent_workers, + ) + self.dataset = CamVid + + self.train_transform = v2.Compose( + [ + v2.Resize((360, 480)), + v2.ToDtype( + dtype={ + tv_tensors.Image: torch.float32, + tv_tensors.Mask: torch.int64, + "others": None, + }, + scale=True, + ), + ] + ) + self.test_transform = v2.Compose( + [ + v2.Resize((360, 480)), + v2.ToDtype( + dtype={ + tv_tensors.Image: torch.float32, + tv_tensors.Mask: torch.int64, + "others": None, + }, + scale=True, + ), + ] + ) + + def prepare_data(self) -> None: # coverage: ignore + self.dataset(root=self.root, download=True) + + def setup(self, stage: str | None = None) -> None: + if stage == "fit" or stage is None: + self.train = self.dataset( + root=self.root, + split="train", + download=False, + transforms=self.train_transform, + ) + self.val = self.dataset( + root=self.root, + split="val", + download=False, + transforms=self.test_transform, + ) + if stage == "test" or stage is None: + self.test = self.dataset( + root=self.root, + split="test", + download=False, + transforms=self.test_transform, + ) + + if stage not in ["fit", "test", None]: + raise ValueError(f"Stage {stage} is not supported.") diff --git a/torch_uncertainty/datamodules/segmentation/cityscapes.py b/torch_uncertainty/datamodules/segmentation/cityscapes.py new file mode 100644 index 00000000..f35bd65d --- /dev/null +++ b/torch_uncertainty/datamodules/segmentation/cityscapes.py @@ -0,0 +1,195 @@ +from pathlib import Path + +import torch +from torch.nn.common_types import _size_2_t +from torch.nn.modules.utils import _pair +from torchvision import tv_tensors +from torchvision.transforms import v2 + +from torch_uncertainty.datamodules.abstract import AbstractDataModule +from torch_uncertainty.datasets.segmentation import Cityscapes +from torch_uncertainty.transforms import RandomRescale +from torch_uncertainty.utils.misc import create_train_val_split + + +class CityscapesDataModule(AbstractDataModule): + def __init__( + self, + root: str | Path, + batch_size: int, + crop_size: _size_2_t = 1024, + inference_size: _size_2_t = (1024, 2048), + val_split: float | None = None, + num_workers: int = 1, + pin_memory: bool = True, + persistent_workers: bool = True, + ) -> None: + r"""DataModule for the Cityscapes dataset. + + Args: + root (str or Path): Root directory of the datasets. + batch_size (int): Number of samples per batch. + crop_size (sequence or int, optional): Desired input image and + segmentation mask sizes during training. If :attr:`crop_size` is an + int instead of sequence like :math:`(H, W)`, a square crop + :math:`(\text{size},\text{size})` is made. If provided a sequence + of length :math:`1`, it will be interpreted as + :math:`(\text{size[0]},\text{size[1]})`. Defaults to ``1024``. + inference_size (sequence or int, optional): Desired input image and + segmentation mask sizes during inference. If size is an int, + smaller edge of the images will be matched to this number, i.e., + :math:`\text{height}>\text{width}`, then image will be rescaled to + :math:`(\text{size}\times\text{height}/\text{width},\text{size})`. + Defaults to ``(1024,2048)``. + val_split (float or None, optional): Share of training samples to use + for validation. Defaults to ``None``. + num_workers (int, optional): Number of dataloaders to use. Defaults to + ``1``. + pin_memory (bool, optional): Whether to pin memory. Defaults to + ``True``. + persistent_workers (bool, optional): Whether to use persistent workers. + Defaults to ``True``. + + + Note: + This datamodule injects the following transforms into the training and + validation/test datasets: + + Training transforms: + + .. code-block:: python + + from torchvision.transforms import v2 + + v2.Compose([ + v2.ToImage(), + RandomRescale(min_scale=0.5, max_scale=2.0, antialias=True), + v2.RandomCrop(size=crop_size, pad_if_needed=True), + v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5), + v2.RandomHorizontalFlip(), + v2.ToDtype({ + tv_tensors.Image: torch.float32, + tv_tensors.Mask: torch.int64, + "others": None + }, scale=True), + v2.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]) + + Validation/Test transforms: + + .. code-block:: python + + from torchvision.transforms import v2 + + v2.Compose([ + v2.ToImage(), + v2.Resize(size=inference_size, antialias=True), + v2.ToDtype({ + tv_tensors.Image: torch.float32, + tv_tensors.Mask: torch.int64, + "others": None + }, scale=True), + v2.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]) + + This behavior can be modified by overriding ``self.train_transform`` + and ``self.test_transform`` after initialization. + """ + super().__init__( + root=root, + batch_size=batch_size, + val_split=val_split, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=persistent_workers, + ) + + self.dataset = Cityscapes + self.mode = "fine" + self.crop_size = _pair(crop_size) + self.inference_size = _pair(inference_size) + + self.train_transform = v2.Compose( + [ + v2.ToImage(), + RandomRescale(min_scale=0.5, max_scale=2.0, antialias=True), + v2.RandomCrop( + size=self.crop_size, + pad_if_needed=True, + fill={tv_tensors.Image: 0, tv_tensors.Mask: 255}, + ), + v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5), + v2.RandomHorizontalFlip(), + v2.ToDtype( + dtype={ + tv_tensors.Image: torch.float32, + tv_tensors.Mask: torch.int64, + "others": None, + }, + scale=True, + ), + v2.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) + self.test_transform = v2.Compose( + [ + v2.ToImage(), + v2.Resize(size=self.inference_size, antialias=True), + v2.ToDtype( + dtype={ + tv_tensors.Image: torch.float32, + tv_tensors.Mask: torch.int64, + "others": None, + }, + scale=True, + ), + v2.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) + + def prepare_data(self) -> None: # coverage: ignore + self.dataset(root=self.root, split="train", mode=self.mode) + + def setup(self, stage: str | None = None) -> None: + if stage == "fit" or stage is None: + full = self.dataset( + root=self.root, + split="train", + mode=self.mode, + target_type="semantic", + transforms=self.train_transform, + ) + + if self.val_split is not None: + self.train, self.val = create_train_val_split( + full, + self.val_split, + self.test_transform, + ) + else: + self.train = full + self.val = self.dataset( + root=self.root, + split="val", + mode=self.mode, + target_type="semantic", + transforms=self.test_transform, + ) + + if stage == "test" or stage is None: + self.test = self.dataset( + root=self.root, + split="val", + mode=self.mode, + target_type="semantic", + transforms=self.test_transform, + ) + + if stage not in ["fit", "test", None]: + raise ValueError(f"Stage {stage} is not supported.") diff --git a/torch_uncertainty/datamodules/segmentation/muad.py b/torch_uncertainty/datamodules/segmentation/muad.py new file mode 100644 index 00000000..c126b05e --- /dev/null +++ b/torch_uncertainty/datamodules/segmentation/muad.py @@ -0,0 +1,194 @@ +from pathlib import Path + +import torch +from torch.nn.common_types import _size_2_t +from torch.nn.modules.utils import _pair +from torchvision import tv_tensors +from torchvision.transforms import v2 + +from torch_uncertainty.datamodules.abstract import AbstractDataModule +from torch_uncertainty.datasets import MUAD +from torch_uncertainty.transforms import RandomRescale +from torch_uncertainty.utils.misc import create_train_val_split + + +class MUADDataModule(AbstractDataModule): + def __init__( + self, + root: str | Path, + batch_size: int, + crop_size: _size_2_t = 1024, + inference_size: _size_2_t = (1024, 2048), + val_split: float | None = None, + num_workers: int = 1, + pin_memory: bool = True, + persistent_workers: bool = True, + ) -> None: + r"""Segmentation DataModule for the MUAD dataset. + + Args: + root (str or Path): Root directory of the datasets. + batch_size (int): Number of samples per batch. + crop_size (sequence or int, optional): Desired input image and + segmentation mask sizes during training. If :attr:`crop_size` is an + int instead of sequence like :math:`(H, W)`, a square crop + :math:`(\text{size},\text{size})` is made. If provided a sequence + of length :math:`1`, it will be interpreted as + :math:`(\text{size[0]},\text{size[1]})`. Defaults to ``1024``. + inference_size (sequence or int, optional): Desired input image and + segmentation mask sizes during inference. If size is an int, + smaller edge of the images will be matched to this number, i.e., + :math:`\text{height}>\text{width}`, then image will be rescaled to + :math:`(\text{size}\times\text{height}/\text{width},\text{size})`. + Defaults to ``(1024,2048)``. + val_split (float or None, optional): Share of training samples to use + for validation. Defaults to ``None``. + num_workers (int, optional): Number of dataloaders to use. Defaults to + ``1``. + pin_memory (bool, optional): Whether to pin memory. Defaults to + ``True``. + persistent_workers (bool, optional): Whether to use persistent workers. + Defaults to ``True``. + + + Note: + This datamodule injects the following transforms into the training and + validation/test datasets: + + Training transforms: + + .. code-block:: python + + from torchvision.transforms import v2 + + v2.Compose([ + v2.ToImage(), + RandomRescale(min_scale=0.5, max_scale=2.0, antialias=True), + v2.RandomCrop(size=crop_size, pad_if_needed=True), + v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5), + v2.RandomHorizontalFlip(), + v2.ToDtype({ + tv_tensors.Image: torch.float32, + tv_tensors.Mask: torch.int64, + "others": None + }, scale=True), + v2.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]) + + Validation/Test transforms: + + .. code-block:: python + + from torchvision.transforms import v2 + + v2.Compose([ + v2.ToImage(), + v2.Resize(size=inference_size, antialias=True), + v2.ToDtype({ + tv_tensors.Image: torch.float32, + tv_tensors.Mask: torch.int64, + "others": None + }, scale=True), + v2.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]) + + This behavior can be modified by overriding ``self.train_transform`` + and ``self.test_transform`` after initialization. + """ + super().__init__( + root=root, + batch_size=batch_size, + val_split=val_split, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=persistent_workers, + ) + + self.dataset = MUAD + self.crop_size = _pair(crop_size) + self.inference_size = _pair(inference_size) + + self.train_transform = v2.Compose( + [ + RandomRescale(min_scale=0.5, max_scale=2.0, antialias=True), + v2.RandomCrop( + size=self.crop_size, + pad_if_needed=True, + fill={tv_tensors.Image: 0, tv_tensors.Mask: 255}, + ), + v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5), + v2.RandomHorizontalFlip(), + v2.ToDtype( + dtype={ + tv_tensors.Image: torch.float32, + tv_tensors.Mask: torch.int64, + "others": None, + }, + scale=True, + ), + v2.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) + self.test_transform = v2.Compose( + [ + v2.Resize(size=self.inference_size, antialias=True), + v2.ToDtype( + dtype={ + tv_tensors.Image: torch.float32, + tv_tensors.Mask: torch.int64, + "others": None, + }, + scale=True, + ), + v2.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) + + def prepare_data(self) -> None: # coverage: ignore + self.dataset( + root=self.root, split="train", target_type="semantic", download=True + ) + self.dataset( + root=self.root, split="val", target_type="semantic", download=True + ) + + def setup(self, stage: str | None = None) -> None: + if stage == "fit" or stage is None: + full = self.dataset( + root=self.root, + split="train", + target_type="semantic", + transforms=self.train_transform, + ) + + if self.val_split is not None: + self.train, self.val = create_train_val_split( + full, + self.val_split, + self.test_transform, + ) + else: + self.train = full + self.val = self.dataset( + root=self.root, + split="val", + target_type="semantic", + transforms=self.test_transform, + ) + + if stage == "test" or stage is None: + self.test = self.dataset( + root=self.root, + split="val", + target_type="semantic", + transforms=self.test_transform, + ) + + if stage not in ["fit", "test", None]: + raise ValueError(f"Stage {stage} is not supported.") diff --git a/torch_uncertainty/datamodules/uci_regression.py b/torch_uncertainty/datamodules/uci_regression.py index 1fe028c7..66571959 100644 --- a/torch_uncertainty/datamodules/uci_regression.py +++ b/torch_uncertainty/datamodules/uci_regression.py @@ -1,7 +1,5 @@ -from argparse import ArgumentParser from functools import partial from pathlib import Path -from typing import Any from torch import Generator from torch.utils.data import random_split @@ -25,7 +23,6 @@ def __init__( persistent_workers: bool = True, input_shape: tuple[int, ...] | None = None, split_seed: int = 42, - **kwargs, ) -> None: """The UCI regression datasets. @@ -48,18 +45,16 @@ def __init__( ``None``. split_seed (int, optional): The seed to use for splitting the dataset. Defaults to ``42``. - **kwargs: Additional arguments. """ super().__init__( root=root, batch_size=batch_size, + val_split=val_split, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=persistent_workers, ) - self.val_split = val_split - self.dataset = partial( UCIRegression, dataset_name=dataset_name, seed=split_seed ) @@ -70,6 +65,7 @@ def prepare_data(self) -> None: """Download the dataset.""" self.dataset(root=self.root, download=True) + # ruff: noqa: ARG002 def setup(self, stage: str | None = None) -> None: """Split the datasets into train, val, and test.""" full = self.dataset( @@ -87,22 +83,3 @@ def setup(self, stage: str | None = None) -> None: ) if self.val_split == 0: self.val = self.test - - # Change by default test_dataloader -> List[DataLoader] - # def test_dataloader(self) -> DataLoader: - # """Get the test dataloader for UCI Regression. - - # Return: - # DataLoader: UCI Regression test dataloader. - # """ - # return self._data_loader(self.test) - - @classmethod - def add_argparse_args( - cls, - parent_parser: ArgumentParser, - **kwargs: Any, - ) -> ArgumentParser: - super().add_argparse_args(parent_parser) - - return parent_parser diff --git a/torch_uncertainty/datasets/__init__.py b/torch_uncertainty/datasets/__init__.py index 5aa7ef67..732334a0 100644 --- a/torch_uncertainty/datasets/__init__.py +++ b/torch_uncertainty/datasets/__init__.py @@ -1,3 +1,4 @@ # ruff: noqa: F401 from .aggregated_dataset import AggregatedDataset from .frost import FrostImages +from .muad import MUAD diff --git a/torch_uncertainty/datasets/classification/cifar/cifar_c.py b/torch_uncertainty/datasets/classification/cifar/cifar_c.py index c98563e8..10f9f230 100644 --- a/torch_uncertainty/datasets/classification/cifar/cifar_c.py +++ b/torch_uncertainty/datasets/classification/cifar/cifar_c.py @@ -1,8 +1,8 @@ from collections.abc import Callable from pathlib import Path -from typing import Any import numpy as np +from torch import Tensor from torchvision.datasets import VisionDataset from torchvision.datasets.utils import ( check_integrity, @@ -169,7 +169,7 @@ def __len__(self) -> int: """The number of samples in the dataset.""" return self.labels.shape[0] - def __getitem__(self, index: int) -> Any: + def __getitem__(self, index: int) -> tuple[np.ndarray | Tensor, int]: """Get the samples and targets of the dataset. Args: diff --git a/torch_uncertainty/datasets/classification/cifar/cifar_h.py b/torch_uncertainty/datasets/classification/cifar/cifar_h.py index 4c0ae7f6..168f8571 100644 --- a/torch_uncertainty/datasets/classification/cifar/cifar_h.py +++ b/torch_uncertainty/datasets/classification/cifar/cifar_h.py @@ -70,9 +70,7 @@ def __init__( def _check_specific_integrity(self) -> bool: filename, md5 = self.h_test_list fpath = self.root / filename - if not check_integrity(fpath, md5): - return False - return True + return check_integrity(fpath, md5) def download_h(self) -> None: download_url( diff --git a/torch_uncertainty/datasets/classification/cifar/cifar_n.py b/torch_uncertainty/datasets/classification/cifar/cifar_n.py index 56a74704..069a081a 100644 --- a/torch_uncertainty/datasets/classification/cifar/cifar_n.py +++ b/torch_uncertainty/datasets/classification/cifar/cifar_n.py @@ -70,9 +70,7 @@ def __init__( def _check_specific_integrity(self) -> bool: filename, md5 = self.n_test_list fpath = self.root / filename - if not check_integrity(fpath, md5): - return False - return True + return check_integrity(fpath, md5) def download_n(self) -> None: download_and_extract_archive( @@ -124,9 +122,7 @@ def __init__( def _check_specific_integrity(self) -> bool: filename, md5 = self.n_test_list fpath = self.root / filename - if not check_integrity(fpath, md5): - return False - return True + return check_integrity(fpath, md5) def download_n(self) -> None: download_and_extract_archive( diff --git a/torch_uncertainty/datasets/classification/imagenet/base.py b/torch_uncertainty/datasets/classification/imagenet/base.py index a33ae285..891bfb9a 100644 --- a/torch_uncertainty/datasets/classification/imagenet/base.py +++ b/torch_uncertainty/datasets/classification/imagenet/base.py @@ -49,6 +49,7 @@ def __init__( self.download() self.root = Path(root) + self.split = split if not self._check_integrity(): raise RuntimeError( diff --git a/torch_uncertainty/datasets/classification/imagenet/tiny_imagenet.py b/torch_uncertainty/datasets/classification/imagenet/tiny_imagenet.py index 5dd08ffe..553fbd1b 100644 --- a/torch_uncertainty/datasets/classification/imagenet/tiny_imagenet.py +++ b/torch_uncertainty/datasets/classification/imagenet/tiny_imagenet.py @@ -51,7 +51,7 @@ def make_dataset(self) -> None: self.samples = samples self.label_data = torch.as_tensor(labels).long() - def _add_channels(self, img): + def _add_channels(self, img: np.ndarray) -> np.ndarray: while len(img.shape) < 3: # third axis is the channels img = np.expand_dims(img, axis=-1) while (img.shape[-1]) < 3: @@ -78,7 +78,7 @@ def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]: return sample, target - def _make_paths(self): + def _make_paths(self) -> list[tuple[Path, int]]: self.ids = [] with self.wnids_path.open() as idf: for nid in idf: @@ -103,7 +103,7 @@ def _make_paths(self): label_id = self.ids.index(nid) with anno_path.open() as annof: for line in annof: - fname, x0, y0, x1, y1 = line.split() + fname, _, _, _, _ = line.split() fname = imgs_path / fname paths.append((fname, label_id)) @@ -111,7 +111,7 @@ def _make_paths(self): val_path = self.root / "val" with (val_path / "val_annotations.txt").open() as valf: for line in valf: - fname, nid, x0, y0, x1, y1 = line.split() + fname, nid, _, _, _, _ = line.split() fname = val_path / "images" / fname label_id = self.ids.index(nid) paths.append((fname, label_id)) diff --git a/torch_uncertainty/datasets/classification/mnist_c.py b/torch_uncertainty/datasets/classification/mnist_c.py index 3a72e404..65febcf9 100644 --- a/torch_uncertainty/datasets/classification/mnist_c.py +++ b/torch_uncertainty/datasets/classification/mnist_c.py @@ -163,9 +163,7 @@ def __getitem__(self, index: int) -> Any: def _check_integrity(self) -> bool: """Check the integrity of the dataset.""" fpath = self.root / self.filename - if not check_integrity(fpath, self.zip_md5): - return False - return True + return check_integrity(fpath, self.zip_md5) def download(self) -> None: """Download the dataset.""" diff --git a/torch_uncertainty/datasets/classification/openimage_o.py b/torch_uncertainty/datasets/classification/openimage_o.py index 4ec40f50..2cc9104a 100644 --- a/torch_uncertainty/datasets/classification/openimage_o.py +++ b/torch_uncertainty/datasets/classification/openimage_o.py @@ -46,6 +46,7 @@ def __init__( Wang H., et al. In CVPR 2022. """ self.root = Path(root) + self.split = split self.transform = transform self.target_transform = target_transform @@ -78,4 +79,3 @@ def download(self) -> None: filename=self.filename, md5=self.md5sum, ) - print(f"Downloaded {self.filename} to {self.root}") diff --git a/torch_uncertainty/datasets/muad.py b/torch_uncertainty/datasets/muad.py index b7babdb6..ffe842e8 100644 --- a/torch_uncertainty/datasets/muad.py +++ b/torch_uncertainty/datasets/muad.py @@ -1,18 +1,23 @@ import json +import os +import shutil from collections.abc import Callable from pathlib import Path -from typing import Literal +from typing import Any, Literal -import cv2 as cv -import matplotlib.pyplot as plt +import cv2 import numpy as np +import torch +from einops import rearrange from PIL import Image +from torchvision import tv_tensors from torchvision.datasets import VisionDataset from torchvision.datasets.utils import ( check_integrity, download_and_extract_archive, download_url, ) +from torchvision.transforms.v2 import functional as F class MUAD(VisionDataset): @@ -26,25 +31,30 @@ class MUAD(VisionDataset): "val": "957af9c1c36f0a85c33279e06b6cf8d8", "val_depth": "0282030d281aeffee3335f713ba12373", } - samples: list[Path] = [] + _num_samples = { + "train": 3420, + "val": 492, + "test": ..., + } targets: list[Path] = [] - # TODO: Add depth regression mode def __init__( self, root: str | Path, - split: Literal["train", "val", "train_depth", "val_depth"], - transform: Callable | None = None, + split: Literal["train", "val"], + target_type: Literal["semantic", "depth"] = "semantic", + transforms: Callable | None = None, download: bool = False, ) -> None: """The MUAD Dataset. Args: root (str): Root directory of dataset where directory 'leftImg8bit' - and 'leftLabel' are located. - split (str, optional): The image split to use, 'train', 'val', - 'train_depth' or 'val_depth'. - transform (callable, optional): A function/transform that takes in + and 'leftLabel' or 'leftDepth' are located. + split (str, optional): The image split to use, 'train' or 'val'. + target_type (str, optional): The type of target to use, 'semantic' + or 'depth'. + transforms (callable, optional): A function/transform that takes in a tuple of PIL images and returns a transformed version. download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory. If dataset is already @@ -63,23 +73,56 @@ def __init__( ) super().__init__( root=Path(root) / "MUAD", - transform=transform, + transforms=transforms, ) - if split not in ["train", "val", "train_depth", "val_depth"]: + if split not in ["train", "val"]: raise ValueError( - "split must be one of ['train', 'val', 'train_depth', " - f"'val_depth']. Got {split}." + f"split must be one of ['train', 'val']. Got {split}." ) self.split = split - - split_path = self.root / (split + ".zip") - if not check_integrity(split_path, self.zip_md5[split]) and download: - self._download() + self.target_type = target_type + + if not self.check_split_integrity("leftImg8bit"): + if download: + self._download(split=split) + else: + raise FileNotFoundError( + f"MUAD {split} split not found or incomplete. Set download=True to download it." + ) + + if ( + not self.check_split_integrity("leftLabel") + and target_type == "semantic" + ): + if download: + self._download(split=split) + else: + raise FileNotFoundError( + f"MUAD {split} split not found or incomplete. Set download=True to download it." + ) + + if ( + not self.check_split_integrity("leftDepth") + and target_type == "depth" + ): + if download: + self._download(split=f"{split}_depth") + # FIXME: Depth target for train are in a different folder + # thus we move them to the correct folder + if split == "train": + shutil.move( + self.root / f"{split}_depth", + self.root / split / "leftDepth", + ) + else: + raise FileNotFoundError( + f"MUAD {split} split not found or incomplete. Set download=True to download it." + ) # Load classes metadata cls_path = self.root / "classes.json" - if not check_integrity(cls_path, self.classes_md5) and download: + if (not check_integrity(cls_path, self.classes_md5)) and download: download_url( self.classes_url, self.root, @@ -100,37 +143,76 @@ def __init__( self._make_dataset(self.root / split) + def encode_target(self, target: Image.Image) -> Image.Image: + """Encode target image to tensor. + + Args: + target (Image.Image): Target PIL image. + + Returns: + torch.Tensor: Encoded target. + """ + target = F.pil_to_tensor(target) + target = rearrange(target, "c h w -> h w c") + out = torch.zeros_like(target[..., :1]) + # convert target color to index + for muad_class in self.classes: + out[ + ( + target == torch.tensor(muad_class["id"], dtype=target.dtype) + ).all(dim=-1) + ] = muad_class["train_id"] + + return F.to_pil_image(rearrange(out, "h w c -> c h w")) + def decode_target(self, target: Image.Image) -> np.ndarray: target[target == 255] = 19 return self.train_id_to_color[target] - def __getitem__(self, index: int) -> tuple[Image.Image, Image.Image]: - """Get the image and its segmentation target.""" - img_path = self.samples[index] - seg_path = self.targets[index] - - image = cv.imread(img_path) - image = cv.cvtColor(image, cv.COLOR_BGR2RGB) + def __getitem__(self, index: int) -> tuple[Any, Any]: + """Get the sample at the given index. - segm = plt.imread(seg_path) * 255.0 - target = np.zeros((segm.shape[0], segm.shape[1])) + 255.0 + Args: + index (int): Index - for c in self.classes: - upper = np.array(c["train_id"]) - mask = cv.inRange(segm, upper, upper) - target[mask == 255] = c["train_id"] - target = target.astype(np.uint8) - target = Image.fromarray(target) + Returns: + tuple: (image, target) where target is either a segmentation mask + or a depth map. + """ + image = tv_tensors.Image(Image.open(self.samples[index]).convert("RGB")) + if self.target_type == "semantic": + target = tv_tensors.Mask( + self.encode_target(Image.open(self.targets[index])) + ) + else: + os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" + target = Image.fromarray( + cv2.imread( + str(self.targets[index]), + cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH, + ) + ) + # TODO: in the long tun it would be better to use a custom + # tv_tensor for depth maps (e.g. tv_tensors.DepthMap) + target = np.asarray(target, np.float32) + target = tv_tensors.Mask(400 * (1 - target)) # convert to meters - image = Image.fromarray(image) + if self.transforms is not None: + image, target = self.transforms(image, target) - if self.transform: - image, target = self.transform(image, target) return image, target + def check_split_integrity(self, folder: str) -> bool: + split_path = self.root / self.split + return ( + split_path.is_dir() + and len(list((split_path / folder).glob("**/*"))) + == self._num_samples[self.split] + ) + def __len__(self) -> int: """The number of samples in the dataset.""" - return len(self.samples) + return self._num_samples[self.split] def _make_dataset(self, path: Path) -> None: """Create a list of samples and targets. @@ -143,12 +225,19 @@ def _make_dataset(self, path: Path) -> None: "Depth regression mode is not implemented yet. Raise an issue " "if you need it." ) - self.samples = list((path / "leftImg8bit/").glob("**/*")) - self.targets = list((path / "leftLabel/").glob("**/*")) + self.samples = sorted((path / "leftImg8bit/").glob("**/*")) + if self.target_type == "semantic": + self.targets = sorted((path / "leftLabel/").glob("**/*")) + elif self.target_type == "depth": + self.targets = sorted((path / "leftDepth/").glob("**/*")) + else: + raise ValueError( + f"target_type must be one of ['semantic', 'depth']. Got {self.target_type}." + ) - def _download(self): + def _download(self, split: str) -> None: """Download and extract the chosen split of the dataset.""" - split_url = self.base_url + self.split + ".zip" + split_url = self.base_url + split + ".zip" download_and_extract_archive( - split_url, self.root, md5=self.zip_md5[self.split] + split_url, self.root, md5=self.zip_md5[split] ) diff --git a/torch_uncertainty/datasets/regression/toy.py b/torch_uncertainty/datasets/regression/toy.py index 161ddb8f..d5c3cdfb 100644 --- a/torch_uncertainty/datasets/regression/toy.py +++ b/torch_uncertainty/datasets/regression/toy.py @@ -28,6 +28,6 @@ def __init__( samples = torch.linspace( lower_bound, upper_bound, num_samples - ).unsqueeze(1) + ).unsqueeze(-1) targets = samples**3 + torch.normal(*noise, size=samples.size()) - super().__init__(samples, targets) + super().__init__(samples, targets.squeeze(-1)) diff --git a/torch_uncertainty/datasets/regression/uci_regression.py b/torch_uncertainty/datasets/regression/uci_regression.py index ab3fd160..59722abc 100644 --- a/torch_uncertainty/datasets/regression/uci_regression.py +++ b/torch_uncertainty/datasets/regression/uci_regression.py @@ -174,11 +174,11 @@ def _check_integrity(self) -> bool: self.md5, ) - def _standardize(self): + def _standardize(self) -> None: self.data = (self.data - self.data_mean) / self.data_std self.targets = (self.targets - self.target_mean) / self.target_std - def _compute_statistics(self): + def _compute_statistics(self) -> None: self.data_mean = self.data.mean(axis=0) self.data_std = self.data.std(axis=0) self.data_std[self.data_std == 0] = 1 @@ -253,8 +253,6 @@ def _make_dataset(self) -> None: ) # convert Ex to 10^x and remove second target array = df.apply(pd.to_numeric, errors="coerce").to_numpy()[:, :-1] - # elif self.dataset_name == "power-plant": - # array = pd.read_excel(path / "Folds5x2_pp.xlsx").to_numpy() elif self.dataset_name == "protein": array = pd.read_csv( path / "CASP.csv", diff --git a/torch_uncertainty/datasets/segmentation/__init__.py b/torch_uncertainty/datasets/segmentation/__init__.py new file mode 100644 index 00000000..11d4f9fd --- /dev/null +++ b/torch_uncertainty/datasets/segmentation/__init__.py @@ -0,0 +1,3 @@ +# ruff: noqa: F401 +from .camvid import CamVid +from .cityscapes import Cityscapes diff --git a/torch_uncertainty/datasets/segmentation/camvid.py b/torch_uncertainty/datasets/segmentation/camvid.py new file mode 100644 index 00000000..5a25c821 --- /dev/null +++ b/torch_uncertainty/datasets/segmentation/camvid.py @@ -0,0 +1,248 @@ +import json +import shutil +from collections.abc import Callable +from pathlib import Path +from typing import Literal, NamedTuple + +import torch +from einops import rearrange, repeat +from PIL import Image +from torchvision import tv_tensors +from torchvision.datasets import VisionDataset +from torchvision.datasets.utils import ( + download_and_extract_archive, + download_url, +) +from torchvision.transforms.v2 import functional as F + + +class CamVidClass(NamedTuple): + name: str + index: int + color: tuple[int, int, int] + + +class CamVid(VisionDataset): + # Notes: some classes are not used here + classes = [ + CamVidClass("sky", 0, (128, 128, 128)), + CamVidClass("building", 1, (128, 0, 0)), + CamVidClass("pole", 2, (192, 192, 128)), + CamVidClass("road_marking", 3, (255, 69, 0)), + CamVidClass("road", 4, (128, 64, 128)), + CamVidClass("pavement", 5, (60, 40, 222)), + CamVidClass("tree", 6, (128, 128, 0)), + CamVidClass("sign_symbol", 7, (192, 128, 128)), + CamVidClass("fence", 8, (64, 64, 128)), + CamVidClass("car", 9, (64, 0, 128)), + CamVidClass("pedestrian", 10, (64, 64, 0)), + CamVidClass("bicyclist", 11, (0, 128, 192)), + CamVidClass("unlabelled", 12, (0, 0, 0)), + ] + + urls = { + "raw": "http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/files/701_StillsRaw_full.zip", + "label": "http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/data/LabeledApproved_full.zip", + "splits": "https://raw.githubusercontent.com/torch-uncertainty/dataset-metadata/main/segmentation/camvid/splits.json", + } + + splits_md5 = "db45289aaa83c60201391b11e78c6382" + + filenames = { + "raw": "701_StillsRaw_full.zip", + "label": "LabeledApproved_full.zip", + } + base_folder = "camvid" + num_samples = { + "train": 367, + "val": 101, + "test": 233, + "all": 701, + } + + def __init__( + self, + root: str, + split: Literal["train", "val", "test"] | None = None, + transforms: Callable | None = None, + download: bool = False, + ) -> None: + """`CamVid `_ Dataset. + + Args: + root (str): Root directory of dataset where ``camvid/`` exists or + will be saved to if download is set to ``True``. + split (str, optional): The dataset split, supports ``train``, + ``val`` and ``test``. Default: ``None``. + transforms (callable, optional): A function/transform that takes + input sample and its target as entry and returns a transformed + version. Default: ``None``. + download (bool, optional): If true, downloads the dataset from the + internet and puts it in root directory. If dataset is already + downloaded, it is not downloaded again. + """ + if split not in ["train", "val", "test", None]: + raise ValueError( + f"Unknown split '{split}'. " + "Supported splits are ['train', 'val', 'test', None]" + ) + + super().__init__(root, transforms, None, None) + + if download: + self.download() + + if not self._check_integrity(): + raise RuntimeError( + "Dataset not found or corrupted. " + "You can use download=True to download it" + ) + + # get filenames for split + if split is None: + self.images = sorted( + (Path(self.root) / "camvid" / "raw").glob("*.png") + ) + self.targets = sorted( + (Path(self.root) / "camvid" / "label").glob("*.png") + ) + else: + with (Path(self.root) / "camvid" / "splits.json").open() as f: + filenames = json.load(f)[split] + + self.images = sorted( + [ + path + for path in (Path(self.root) / "camvid" / "raw").glob( + "*.png" + ) + if path.stem in filenames + ] + ) + self.targets = sorted( + [ + path + for path in (Path(self.root) / "camvid" / "label").glob( + "*.png" + ) + if path.stem[:-2] in filenames + ] + ) + + self.split = split if split is not None else "all" + + def encode_target(self, target: Image.Image) -> torch.Tensor: + """Encode target image to tensor. + + Args: + target (Image.Image): Target PIL image. + + Returns: + torch.Tensor: Encoded target. + """ + colored_target = F.pil_to_tensor(target) + colored_target = rearrange(colored_target, "c h w -> h w c") + target = torch.zeros_like(colored_target[..., :1]) + # convert target color to index + for camvid_class in self.classes: + index = camvid_class.index if camvid_class.index != 12 else 255 + target[ + ( + colored_target + == torch.tensor(camvid_class.color, dtype=target.dtype) + ).all(dim=-1) + ] = index + + return rearrange(target, "h w c -> c h w") + + def decode_target(self, target: torch.Tensor) -> Image.Image: + """Decode target tensor to image. + + Args: + target (torch.Tensor): Target tensor. + + Returns: + Image.Image: Decoded target as a PIL.Image. + """ + colored_target = repeat(target.clone(), "h w -> h w 3", c=3) + + for camvid_class in self.classes: + colored_target[ + ( + target + == torch.tensor(camvid_class.index, dtype=target.dtype) + ).all(dim=0) + ] = torch.tensor(camvid_class.color, dtype=target.dtype) + + return F.to_pil_image(rearrange(colored_target, "h w c -> c h w")) + + def __getitem__( + self, index: int + ) -> tuple[tv_tensors.Image, tv_tensors.Mask]: + """Get the image and target at the given index. + + Args: + index (int): Sample index. + + Returns: + tuple[tv_tensors.Image, tv_tensors.Mask]: Image and target. + """ + image = tv_tensors.Image(Image.open(self.images[index]).convert("RGB")) + target = tv_tensors.Mask( + self.encode_target(Image.open(self.targets[index])) + ) + + if self.transforms is not None: + image, target = self.transforms(image, target) + + return image, target + + def __len__(self) -> int: + """Return the number of samples.""" + return self.num_samples[self.split] + + def _check_integrity(self) -> bool: + """Check if the dataset exists.""" + if ( + len(list((Path(self.root) / "camvid" / "raw").glob("*.png"))) + != self.num_samples["all"] + ): + return False + if ( + len(list((Path(self.root) / "camvid" / "label").glob("*.png"))) + != self.num_samples["all"] + ): + return False + + return (Path(self.root) / "camvid" / "splits.json").exists() + + def download(self) -> None: + """Download the CamVid data if it doesn't exist already.""" + if self._check_integrity(): + print("Files already downloaded and verified") + return + + if (Path(self.root) / self.base_folder).exists(): + shutil.rmtree(Path(self.root) / self.base_folder) + + download_and_extract_archive( + self.urls["raw"], + self.root, + extract_root=Path(self.root) / "camvid", + filename=self.filenames["raw"], + ) + (Path(self.root) / "camvid" / "701_StillsRaw_full").replace( + Path(self.root) / "camvid" / "raw" + ) + download_and_extract_archive( + self.urls["label"], + self.root, + extract_root=Path(self.root) / "camvid" / "label", + filename=self.filenames["label"], + ) + download_url( + self.urls["splits"], + Path(self.root) / "camvid", + filename="splits.json", + md5=self.splits_md5, + ) diff --git a/torch_uncertainty/datasets/segmentation/cityscapes.py b/torch_uncertainty/datasets/segmentation/cityscapes.py new file mode 100644 index 00000000..234a6ee5 --- /dev/null +++ b/torch_uncertainty/datasets/segmentation/cityscapes.py @@ -0,0 +1,81 @@ +from typing import Any + +import torch +from einops import rearrange +from PIL import Image +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE +from torchvision import tv_tensors +from torchvision.datasets import Cityscapes as OriginalCityscapes +from torchvision.transforms.v2 import functional as F + + +class Cityscapes(OriginalCityscapes): + def encode_target(self, target: Image.Image) -> Image.Image: + """Encode target image to tensor. + + Args: + target (Image.Image): Target PIL image. + + Returns: + torch.Tensor: Encoded target. + """ + colored_target = F.pil_to_tensor(target) + colored_target = rearrange(colored_target, "c h w -> h w c") + target = torch.zeros_like(colored_target[..., :1]) + # convert target color to index + for cityscapes_class in self.classes: + target[ + ( + colored_target + == torch.tensor(cityscapes_class.id, dtype=target.dtype) + ).all(dim=-1) + ] = cityscapes_class.train_id + + return F.to_pil_image(rearrange(target, "h w c -> c h w")) + + def __getitem__(self, index: int) -> tuple[Any, Any]: + """Get the sample at the given index. + + Args: + index (int): Index + Returns: + tuple: (image, target) where target is a tuple of all target types + if ``target_type`` is a list with more + than one item. Otherwise, target is a json object if + ``target_type="polygon"``, else the image segmentation. + """ + image = tv_tensors.Image(Image.open(self.images[index]).convert("RGB")) + + targets: Any = [] + for i, t in enumerate(self.target_type): + if t == "polygon": + target = self._load_json(self.targets[index][i]) + elif t == "semantic": + target = tv_tensors.Mask( + self.encode_target(Image.open(self.targets[index][i])) + ) + else: + target = Image.open(self.targets[index][i]) + + targets.append(target) + + target = tuple(targets) if len(targets) > 1 else targets[0] + + if self.transforms is not None: + image, target = self.transforms(image, target) + + return image, target + + def plot_sample( + self, index: int, ax: _AX_TYPE | None = None + ) -> _PLOT_OUT_TYPE: + """Plot a sample from the dataset. + + Args: + index: The index of the sample to plot. + ax: Optional matplotlib axis to plot on. + + Returns: + The axis on which the sample was plotted. + """ + raise NotImplementedError("This method is not implemented yet.") diff --git a/torch_uncertainty/layers/__init__.py b/torch_uncertainty/layers/__init__.py index 64fd39fa..f91746bd 100644 --- a/torch_uncertainty/layers/__init__.py +++ b/torch_uncertainty/layers/__init__.py @@ -2,4 +2,5 @@ from .batch_ensemble import BatchConv2d, BatchLinear from .bayesian import BayesConv1d, BayesConv2d, BayesConv3d, BayesLinear from .masksembles import MaskedConv2d, MaskedLinear +from .modules import Identity from .packed import PackedConv1d, PackedConv2d, PackedConv3d, PackedLinear diff --git a/torch_uncertainty/layers/batch_ensemble.py b/torch_uncertainty/layers/batch_ensemble.py index 10b83a62..6022f40b 100644 --- a/torch_uncertainty/layers/batch_ensemble.py +++ b/torch_uncertainty/layers/batch_ensemble.py @@ -25,8 +25,13 @@ def __init__( device=None, dtype=None, ) -> None: - r"""Applies a linear transformation using BatchEnsemble method to the - incoming data: :math:`y=(x\circ \hat{r_{group}})W^{T}\circ \hat{s_{group}} + \hat{b}`. + r"""BatchEnsemble-style Linear layer. + + Applies a linear transformation using BatchEnsemble method to the incoming + data. + + .. math:: + y=(x\circ \widehat{r_{group}})W^{T}\circ \widehat{s_{group}} + \widehat{b} Args: in_features (int): size of each input sample. @@ -70,9 +75,9 @@ def __init__( Shape: - Input: :math:`(N, H_{in})` where :math:`N` is the batch size and - :math:`H_{in} = \text{in_features}`. + :math:`H_{in} = \text{in_features}`. - Output: :math:`(N, H_{out})` where - :math:`H_{out} = \text{out_features}`. + :math:`H_{out} = \text{out_features}`. Warning: Make sure that :attr:`num_estimators` divides :attr:`out_features` when calling :func:`forward()`. @@ -110,10 +115,6 @@ def __init__( self.reset_parameters() def reset_parameters(self) -> None: - # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with - # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see - # https://github.com/pytorch/pytorch/issues/57109 - # nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) nn.init.normal_(self.r_group, mean=1.0, std=0.5) nn.init.normal_(self.s_group, mean=1.0, std=0.5) if self.bias is not None: @@ -204,7 +205,9 @@ def __init__( device=None, dtype=None, ) -> None: - r"""Applies a 2d convolution over an input signal composed of several input + r"""BatchEnsemble-style Conv2d layer. + + Applies a 2d convolution over an input signal composed of several input planes using BatchEnsemble method to the incoming data. In the simplest case, the output value of the layer with input size @@ -212,12 +215,12 @@ def __init__( :math:`(N, C_{out}, H_{out}, W_{out})` can be precisely described as: .. math:: - \text{out}(N_i, C_{\text{out}_j})=\ - &\hat{b}(N_i,C_{\text{out}_j}) - +\hat{s_group}(N_{i},C_{\text{out}_j}) \\ - &\times \sum_{k = 0}^{C_{\text{in}} - 1} - \text{weight}(C_{\text{out}_j}, k)\star (\text{input}(N_i, k) - \times \hat{r_group}(N_i, k)) + \text{out}(N_i, C_{\text{out}_j})=\ + &\widehat{b}(N_i,C_{\text{out}_j}) + +\widehat{s_{group}}(N_{i},C_{\text{out}_j}) \\ + &\times \sum_{k = 0}^{C_{\text{in}} - 1} + \text{weight}(C_{\text{out}_j}, k)\star (\text{input}(N_i, k) + \times \widehat{r_{group}}(N_i, k)) Reference: Introduced by the paper `BatchEnsemble: An Alternative Approach to @@ -335,12 +338,6 @@ def __init__( self.reset_parameters() def reset_parameters(self) -> None: - # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with - # uniform(-1/sqrt(k), 1/sqrt(k)), where - # k = weight.size(1) * prod(*kernel_size) - # For more details see: - # https://github.com/pytorch/pytorch/issues/15314#issuecomment-477448573 - # nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) nn.init.normal_(self.r_group, mean=1.0, std=0.5) nn.init.normal_(self.s_group, mean=1.0, std=0.5) if self.bias is not None: @@ -408,7 +405,7 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: bias if bias is not None else 0 ) - def extra_repr(self): + def extra_repr(self) -> str: s = ( "{in_channels}, {out_channels}, kernel_size={kernel_size}" ", num_estimators={num_estimators}, stride={stride}" diff --git a/torch_uncertainty/layers/bayesian/bayes_conv.py b/torch_uncertainty/layers/bayesian/bayes_conv.py index 450d9f63..d9fc4df4 100644 --- a/torch_uncertainty/layers/bayesian/bayes_conv.py +++ b/torch_uncertainty/layers/bayesian/bayes_conv.py @@ -84,9 +84,7 @@ def __init__( valid_padding_modes = {"zeros", "reflect", "replicate", "circular"} if padding_mode not in valid_padding_modes: raise ValueError( - "padding_mode must be one of {}, but got '{}'".format( - valid_padding_modes, padding_mode - ) + f"padding_mode must be one of {valid_padding_modes}, but got '{padding_mode}'" ) if transposed: @@ -180,7 +178,7 @@ def sample(self) -> tuple[Tensor, Tensor | None]: bias = self.bias_sampler.sample() if self.bias_mu is not None else None return weight, bias - def extra_repr(self): # coverage: ignore + def extra_repr(self) -> str: # coverage: ignore s = ( "{in_channels}, {out_channels}, kernel_size={kernel_size}" ", stride={stride}" @@ -199,7 +197,7 @@ def extra_repr(self): # coverage: ignore s += ", padding_mode={padding_mode}" return s.format(**self.__dict__) - def __setstate__(self, state): + def __setstate__(self, state) -> None: super().__setstate__(state) if not hasattr(self, "padding_mode"): # coverage: ignore self.padding_mode = "zeros" diff --git a/torch_uncertainty/layers/bayesian/bayes_linear.py b/torch_uncertainty/layers/bayesian/bayes_linear.py index 9dd1d06e..074f8554 100644 --- a/torch_uncertainty/layers/bayesian/bayes_linear.py +++ b/torch_uncertainty/layers/bayesian/bayes_linear.py @@ -113,7 +113,7 @@ def forward(self, inputs: Tensor) -> Tensor: return self._frozen_forward(inputs) return self._forward(inputs) - def _frozen_forward(self, inputs): + def _frozen_forward(self, inputs) -> Tensor: return F.linear(inputs, self.weight_mu, self.bias_mu) def _forward(self, inputs: Tensor) -> Tensor: @@ -146,6 +146,4 @@ def sample(self) -> tuple[Tensor, Tensor | None]: return weight, bias def extra_repr(self) -> str: - return "in_features={}, out_features={}, bias={}".format( - self.in_features, self.out_features, self.bias_mu is not None - ) + return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias_mu is not None}" diff --git a/torch_uncertainty/layers/distributions.py b/torch_uncertainty/layers/distributions.py new file mode 100644 index 00000000..108cc8c4 --- /dev/null +++ b/torch_uncertainty/layers/distributions.py @@ -0,0 +1,107 @@ +import torch.nn.functional as F +from torch import Tensor, nn +from torch.distributions import Distribution, Laplace, Normal + +from torch_uncertainty.utils.distributions import NormalInverseGamma + + +class _AbstractDist(nn.Module): + def __init__(self, dim: int) -> None: + super().__init__() + if dim < 1: + raise ValueError(f"dim must be positive, got {dim}.") + self.dim = dim + + def forward(self, x: Tensor) -> Distribution: + raise NotImplementedError + + +class NormalLayer(_AbstractDist): + """Normal distribution layer. + + Converts model outputs to Independent Normal distributions. + + Args: + dim (int): The number of independent dimensions for each prediction. + eps (float): The minimal value of the :attr:`scale` parameter. + """ + + def __init__(self, dim: int, eps: float = 1e-6) -> None: + super().__init__(dim) + if eps <= 0: + raise ValueError(f"eps must be positive, got {eps}.") + self.eps = eps + + def forward(self, x: Tensor) -> Normal: + r"""Forward pass of the Normal distribution layer. + + Args: + x (Tensor): A tensor of shape (:attr:`dim` :math:`\times`2). + + Returns: + Normal: The output normal distribution. + """ + loc = x[:, : self.dim] + scale = F.softplus(x[:, self.dim :]) + self.eps + return Normal(loc, scale) + + +class LaplaceLayer(_AbstractDist): + """Laplace distribution layer. + + Converts model outputs to Independent Laplace distributions. + + Args: + dim (int): The number of independent dimensions for each prediction. + eps (float): The minimal value of the :attr:`scale` parameter. + """ + + def __init__(self, dim: int, eps: float = 1e-6) -> None: + super().__init__(dim) + if eps <= 0: + raise ValueError(f"eps must be positive, got {eps}.") + self.eps = eps + + def forward(self, x: Tensor) -> Laplace: + r"""Forward pass of the Laplace distribution layer. + + Args: + x (Tensor): A tensor of shape (:attr:`dim` :math:`\times`2). + + Returns: + Laplace: The output Laplace distribution. + """ + loc = x[:, : self.dim] + scale = F.softplus(x[:, self.dim :]) + self.eps + return Laplace(loc, scale) + + +class NormalInverseGammaLayer(_AbstractDist): + """Normal-Inverse-Gamma distribution layer. + + Converts model outputs to Independent Normal-Inverse-Gamma distributions. + + Args: + dim (int): The number of independent dimensions for each prediction. + eps (float): The minimal values of the :attr:`lmbda`, :attr:`alpha`-1 + and :attr:`beta` parameters. + """ + + def __init__(self, dim: int, eps: float = 1e-6) -> None: + super().__init__(dim) + self.eps = eps + + def forward(self, x: Tensor) -> NormalInverseGamma: + r"""Forward pass of the NormalInverseGamma distribution layer. + + Args: + x (Tensor): A tensor of shape (:attr:`dim` :math:`\times`4). + + Returns: + NormalInverseGamma: The output NormalInverseGamma distribution. + """ + loc = x[:, : self.dim] + lmbda = F.softplus(x[:, self.dim : 2 * self.dim]) + self.eps + alpha = 1 + F.softplus(x[:, 2 * self.dim : 3 * self.dim]) + self.eps + beta = F.softplus(x[:, 3 * self.dim :]) + self.eps + return NormalInverseGamma(loc, lmbda, alpha, beta) diff --git a/torch_uncertainty/layers/filter_response_norm.py b/torch_uncertainty/layers/filter_response_norm.py index ade98582..8c9f3aee 100644 --- a/torch_uncertainty/layers/filter_response_norm.py +++ b/torch_uncertainty/layers/filter_response_norm.py @@ -2,7 +2,7 @@ from torch import Tensor, nn -class FilterResponseNormNd(nn.Module): +class _FilterResponseNormNd(nn.Module): def __init__( self, dimension: int, @@ -49,7 +49,7 @@ def forward(self, x: Tensor) -> Tensor: return torch.max(y, self.tau) -class FilterResponseNorm1d(FilterResponseNormNd): +class FilterResponseNorm1d(_FilterResponseNormNd): def __init__( self, num_channels: int, eps: float = 1e-6, device=None, dtype=None ) -> None: @@ -70,7 +70,7 @@ def __init__( ) -class FilterResponseNorm2d(FilterResponseNormNd): +class FilterResponseNorm2d(_FilterResponseNormNd): def __init__( self, num_channels: int, eps: float = 1e-6, device=None, dtype=None ) -> None: @@ -91,7 +91,7 @@ def __init__( ) -class FilterResponseNorm3d(FilterResponseNormNd): +class FilterResponseNorm3d(_FilterResponseNormNd): def __init__( self, num_channels: int, eps: float = 1e-6, device=None, dtype=None ) -> None: diff --git a/torch_uncertainty/layers/mc_batch_norm.py b/torch_uncertainty/layers/mc_batch_norm.py index 1dd5907a..9a68e633 100644 --- a/torch_uncertainty/layers/mc_batch_norm.py +++ b/torch_uncertainty/layers/mc_batch_norm.py @@ -101,7 +101,7 @@ class MCBatchNorm1d(_MCBatchNorm): Check MCBatchNorm in torch_uncertainty/post_processing/. """ - def _check_input_dim(self, inputs): + def _check_input_dim(self, inputs) -> None: if inputs.dim() != 2 and inputs.dim() != 3: raise ValueError( f"expected 2D or 3D input (got {inputs.dim()}D input)" @@ -127,7 +127,7 @@ class MCBatchNorm2d(_MCBatchNorm): Check MCBatchNorm in torch_uncertainty/post_processing/. """ - def _check_input_dim(self, inputs): + def _check_input_dim(self, inputs) -> None: if inputs.dim() != 3 and inputs.dim() != 4: raise ValueError( f"expected 3D or 4D input (got {inputs.dim()}D input)" @@ -153,7 +153,7 @@ class MCBatchNorm3d(_MCBatchNorm): Check MCBatchNorm in torch_uncertainty/post_processing/. """ - def _check_input_dim(self, inputs): + def _check_input_dim(self, inputs) -> None: if inputs.dim() != 4 and inputs.dim() != 5: raise ValueError( f"expected 4D or 5D input (got {inputs.dim()}D input)" diff --git a/torch_uncertainty/layers/modules.py b/torch_uncertainty/layers/modules.py new file mode 100644 index 00000000..a5b9b18e --- /dev/null +++ b/torch_uncertainty/layers/modules.py @@ -0,0 +1,12 @@ +from typing import Any + +from torch import nn + + +class Identity(nn.Module): + # ruff: noqa: ARG002 + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__() + + def forward(self, *args) -> Any: + return args diff --git a/torch_uncertainty/layers/packed.py b/torch_uncertainty/layers/packed.py index 38bcb7db..336c5576 100644 --- a/torch_uncertainty/layers/packed.py +++ b/torch_uncertainty/layers/packed.py @@ -8,6 +8,13 @@ def check_packed_parameters_consistency( alpha: float, num_estimators: int, gamma: int ) -> None: + """Check the consistency of the parameters of the Packed-Ensembles layers. + + Args: + alpha (float): The width multiplier of the layer. + num_estimators (int): The number of estimators in the ensemble. + gamma (int): The number of groups in the ensemble. + """ if alpha is None: raise ValueError("You must specify the value of the arg. `alpha`") @@ -92,7 +99,7 @@ def __init__( this constraint. Note: - The input should be of size (`batch_size`, :attr:`in_features`, 1, + The input should be of shape (`batch_size`, :attr:`in_features`, 1, 1). The (often) necessary rearrange operation is executed by default. """ diff --git a/torch_uncertainty/losses.py b/torch_uncertainty/losses.py index 8afe9661..55aeb91a 100644 --- a/torch_uncertainty/losses.py +++ b/torch_uncertainty/losses.py @@ -1,8 +1,40 @@ +from typing import Literal + import torch from torch import Tensor, nn +from torch.distributions import Distribution from torch.nn import functional as F -from .layers.bayesian import bayesian_modules +from torch_uncertainty.layers.bayesian import bayesian_modules +from torch_uncertainty.utils.distributions import NormalInverseGamma + + +class DistributionNLLLoss(nn.Module): + def __init__( + self, reduction: Literal["mean", "sum"] | None = "mean" + ) -> None: + """Negative Log-Likelihood loss using given distributions as inputs. + + Args: + reduction (str, optional): specifies the reduction to apply to the + output:``'none'`` | ``'mean'`` | ``'sum'``. Defaults to "mean". + """ + super().__init__() + self.reduction = reduction + + def forward(self, dist: Distribution, targets: Tensor) -> Tensor: + """Compute the NLL of the targets given predicted distributions. + + Args: + dist (Distribution): The predicted distributions + targets (Tensor): The target values + """ + loss = -dist.log_prob(targets) + if self.reduction == "mean": + loss = loss.mean() + elif self.reduction == "sum": + loss = loss.sum() + return loss class KLDiv(nn.Module): @@ -34,57 +66,40 @@ def _kl_div(self) -> Tensor: class ELBOLoss(nn.Module): def __init__( self, - model: nn.Module, - criterion: nn.Module, + model: nn.Module | None, + inner_loss: nn.Module, kl_weight: float, num_samples: int, ) -> None: """The Evidence Lower Bound (ELBO) loss for Bayesian Neural Networks. ELBO loss for Bayesian Neural Networks. Use this loss function with the - objective that you seek to minimize as :attr:`criterion`. + objective that you seek to minimize as :attr:`inner_loss`. Args: model (nn.Module): The Bayesian Neural Network to compute the loss for - criterion (nn.Module): The loss function to use during training + inner_loss (nn.Module): The loss function to use during training kl_weight (float): The weight of the KL divergence term num_samples (int): The number of samples to use for the ELBO loss + + Note: + Set the model to None if you use the ELBOLoss within + the ClassificationRoutine. It will get filled automatically. """ super().__init__() - self.model = model - self._kl_div = KLDiv(model) - - if isinstance(criterion, type): - raise TypeError( - "The criterion should be an instance of a class." - f"Got {criterion}." - ) - self.criterion = criterion + _elbo_loss_checks(inner_loss, kl_weight, num_samples) + self.set_model(model) - if kl_weight < 0: - raise ValueError( - f"The KL weight should be non-negative. Got {kl_weight}." - ) + self.inner_loss = inner_loss self.kl_weight = kl_weight - - if num_samples < 1: - raise ValueError( - "The number of samples should not be lower than 1." - f"Got {num_samples}." - ) - if not isinstance(num_samples, int): - raise TypeError( - "The number of samples should be an integer. " - f"Got {type(num_samples)}." - ) self.num_samples = num_samples def forward(self, inputs: Tensor, targets: Tensor) -> Tensor: - """Gather the kl divergence from the bayesian modules and aggregate + """Gather the KL divergence from the bayesian modules and aggregate the ELBO loss for a given network. Args: - inputs (Tensor): The *inputs* of the Bayesian Neural Network + inputs (Tensor): The inputs of the Bayesian Neural Network targets (Tensor): The target values Returns: @@ -93,16 +108,50 @@ def forward(self, inputs: Tensor, targets: Tensor) -> Tensor: aggregated_elbo = torch.zeros(1, device=inputs.device) for _ in range(self.num_samples): logits = self.model(inputs) - aggregated_elbo += self.criterion(logits, targets) + aggregated_elbo += self.inner_loss(logits, targets) aggregated_elbo += self.kl_weight * self._kl_div() return aggregated_elbo / self.num_samples + def set_model(self, model: nn.Module) -> None: + self.model = model + if model is not None: + self._kl_div = KLDiv(model) + + +def _elbo_loss_checks( + inner_loss: nn.Module, kl_weight: float, num_samples: int +) -> None: + if isinstance(inner_loss, type): + raise TypeError( + "The inner_loss should be an instance of a class." + f"Got {inner_loss}." + ) + + if kl_weight < 0: + raise ValueError( + f"The KL weight should be non-negative. Got {kl_weight}." + ) -class NIGLoss(nn.Module): + if num_samples < 1: + raise ValueError( + "The number of samples should not be lower than 1." + f"Got {num_samples}." + ) + if not isinstance(num_samples, int): + raise TypeError( + "The number of samples should be an integer. " + f"Got {type(num_samples)}." + ) + + +class DERLoss(DistributionNLLLoss): def __init__( self, reg_weight: float, reduction: str | None = "mean" ) -> None: - """The Normal Inverse-Gamma loss. + """The Deep Evidential loss. + + This loss combines the negative log-likelihood loss of the normal + inverse gamma distribution and a weighted regularization term. Args: reg_weight (float): The weight of the regularization term. @@ -113,7 +162,11 @@ def __init__( Amini, A., Schwarting, W., Soleimany, A., & Rus, D. (2019). Deep evidential regression. https://arxiv.org/abs/1910.02600. """ - super().__init__() + super().__init__(reduction=None) + + if reduction not in (None, "none", "mean", "sum"): + raise ValueError(f"{reduction} is not a valid value for reduction.") + self.final_reduction = reduction if reg_weight < 0: raise ValueError( @@ -121,49 +174,24 @@ def __init__( f"{reg_weight}." ) self.reg_weight = reg_weight - if reduction not in ("none", "mean", "sum"): - raise ValueError(f"{reduction} is not a valid value for reduction.") - self.reduction = reduction - - def _nig_nll( - self, - gamma: Tensor, - v: Tensor, - alpha: Tensor, - beta: Tensor, - targets: Tensor, - ) -> Tensor: - gam = 2 * beta * (1 + v) - return ( - 0.5 * torch.log(torch.pi / v) - - alpha * gam.log() - + (alpha + 0.5) * torch.log(gam + v * (targets - gamma) ** 2) - + torch.lgamma(alpha) - - torch.lgamma(alpha + 0.5) - ) - def _nig_reg( - self, gamma: Tensor, v: Tensor, alpha: Tensor, targets: Tensor - ) -> Tensor: - return torch.norm(targets - gamma, 1, dim=1, keepdim=True) * ( - 2 * v + alpha + def _reg(self, dist: NormalInverseGamma, targets: Tensor) -> Tensor: + return torch.norm(targets - dist.loc, 1, dim=1, keepdim=True) * ( + 2 * dist.lmbda + dist.alpha ) def forward( self, - gamma: Tensor, - v: Tensor, - alpha: Tensor, - beta: Tensor, + dist: NormalInverseGamma, targets: Tensor, ) -> Tensor: - loss_nll = self._nig_nll(gamma, v, alpha, beta, targets) - loss_reg = self._nig_reg(gamma, v, alpha, targets) + loss_nll = super().forward(dist, targets) + loss_reg = self._reg(dist, targets) loss = loss_nll + self.reg_weight * loss_reg - if self.reduction == "mean": + if self.final_reduction == "mean": return loss.mean() - if self.reduction == "sum": + if self.final_reduction == "sum": return loss.sum() return loss @@ -347,7 +375,7 @@ def forward( raise NotImplementedError( "DECLoss does not yet support mixup/cutmix." ) - # else: # TODO: handle binary + # TODO: handle binary targets = F.one_hot(targets, num_classes=evidence.size()[-1]) if self.loss_type == "mse": diff --git a/torch_uncertainty/metrics/__init__.py b/torch_uncertainty/metrics/__init__.py index 572e5432..207d0c9b 100644 --- a/torch_uncertainty/metrics/__init__.py +++ b/torch_uncertainty/metrics/__init__.py @@ -1,11 +1,23 @@ # ruff: noqa: F401 -from .brier_score import BrierScore -from .calibration import CE -from .disagreement import Disagreement -from .entropy import Entropy -from .fpr95 import FPR95 -from .grouping_loss import GroupingLoss -from .mutual_information import MutualInformation -from .nll import GaussianNegativeLogLikelihood, NegativeLogLikelihood -from .sparsification import AUSE -from .variation_ratio import VariationRatio +from .classification import ( + AUSE, + CE, + FPR95, + BrierScore, + CategoricalNLL, + Disagreement, + Entropy, + GroupingLoss, + MeanIntersectionOverUnion, + MutualInformation, + VariationRatio, +) +from .regression import ( + DistributionNLL, + Log10, + MeanGTRelativeAbsoluteError, + MeanGTRelativeSquaredError, + MeanSquaredLogError, + SILog, + ThresholdAccuracy, +) diff --git a/torch_uncertainty/metrics/classification/__init__.py b/torch_uncertainty/metrics/classification/__init__.py new file mode 100644 index 00000000..df6078c9 --- /dev/null +++ b/torch_uncertainty/metrics/classification/__init__.py @@ -0,0 +1,12 @@ +# ruff: noqa: F401 +from .brier_score import BrierScore +from .calibration import CE +from .disagreement import Disagreement +from .entropy import Entropy +from .fpr95 import FPR95 +from .grouping_loss import GroupingLoss +from .mean_iou import MeanIntersectionOverUnion +from .mutual_information import MutualInformation +from .nll import CategoricalNLL +from .sparsification import AUSE +from .variation_ratio import VariationRatio diff --git a/torch_uncertainty/metrics/brier_score.py b/torch_uncertainty/metrics/classification/brier_score.py similarity index 99% rename from torch_uncertainty/metrics/brier_score.py rename to torch_uncertainty/metrics/classification/brier_score.py index 4d67450a..43b12f2c 100644 --- a/torch_uncertainty/metrics/brier_score.py +++ b/torch_uncertainty/metrics/classification/brier_score.py @@ -8,7 +8,7 @@ class BrierScore(Metric): - is_differentiable: bool = False + is_differentiable: bool = True higher_is_better: bool | None = False full_state_update: bool = False diff --git a/torch_uncertainty/metrics/calibration.py b/torch_uncertainty/metrics/classification/calibration.py similarity index 96% rename from torch_uncertainty/metrics/calibration.py rename to torch_uncertainty/metrics/classification/calibration.py index 55fc443c..c32787f4 100644 --- a/torch_uncertainty/metrics/calibration.py +++ b/torch_uncertainty/metrics/classification/calibration.py @@ -76,8 +76,8 @@ class MulticlassCE(MulticlassCalibrationError): # noqa: N818 def plot(self, ax: _AX_TYPE | None = None) -> _PLOT_OUT_TYPE: fig, ax = plt.subplots() if ax is None else (None, ax) - conf = dim_zero_cat(self.confidences) - acc = dim_zero_cat(self.accuracies) + conf = dim_zero_cat(self.confidences).cpu() + acc = dim_zero_cat(self.accuracies).cpu() bin_width = 1 / self.n_bins @@ -98,9 +98,9 @@ def plot(self, ax: _AX_TYPE | None = None) -> _PLOT_OUT_TYPE: acc.unsqueeze(1) * torch.nn.functional.one_hot(inverse).float(), 0, ) - / (val_oh.T @ counts + 1e-6).float() + / (val_oh.T.float() @ counts.float() + 1e-6) ) - counts_all = (val_oh.T @ counts).float() + counts_all = val_oh.T.float() @ counts.float() total = torch.sum(counts) plt.rc("axes", axisbelow=True) diff --git a/torch_uncertainty/metrics/disagreement.py b/torch_uncertainty/metrics/classification/disagreement.py similarity index 100% rename from torch_uncertainty/metrics/disagreement.py rename to torch_uncertainty/metrics/classification/disagreement.py diff --git a/torch_uncertainty/metrics/entropy.py b/torch_uncertainty/metrics/classification/entropy.py similarity index 98% rename from torch_uncertainty/metrics/entropy.py rename to torch_uncertainty/metrics/classification/entropy.py index dabb4cb4..a7eae7f6 100644 --- a/torch_uncertainty/metrics/entropy.py +++ b/torch_uncertainty/metrics/classification/entropy.py @@ -16,7 +16,7 @@ def __init__( **kwargs: Any, ) -> None: """The Shannon Entropy Metric to estimate the confidence of a single model - or the mean confidence accross estimators. + or the mean confidence across estimators. Args: reduction (str, optional): Determines how to reduce over the diff --git a/torch_uncertainty/metrics/fpr95.py b/torch_uncertainty/metrics/classification/fpr95.py similarity index 73% rename from torch_uncertainty/metrics/fpr95.py rename to torch_uncertainty/metrics/classification/fpr95.py index 1f351aed..87a1b93a 100644 --- a/torch_uncertainty/metrics/fpr95.py +++ b/torch_uncertainty/metrics/classification/fpr95.py @@ -38,7 +38,7 @@ def stable_cumsum(arr: ArrayLike, rtol: float = 1e-05, atol: float = 1e-08): return out -class FPR95(Metric): +class FPRx(Metric): is_differentiable: bool = False higher_is_better: bool = False full_state_update: bool = False @@ -46,29 +46,46 @@ class FPR95(Metric): conf: list[Tensor] targets: list[Tensor] - def __init__(self, pos_label: int, **kwargs) -> None: - """The False Positive Rate at 95% Recall metric.""" + def __init__(self, recall_level: float, pos_label: int, **kwargs) -> None: + """The False Positive Rate at x% Recall metric. + + Args: + recall_level (float): The recall level at which to compute the FPR. + pos_label (int): The positive label. + kwargs: Additional arguments to pass to the metric class. + """ super().__init__(**kwargs) + if recall_level < 0 or recall_level > 1: + raise ValueError( + f"Recall level must be between 0 and 1. Got {recall_level}." + ) + self.recall_level = recall_level self.pos_label = pos_label self.add_state("conf", [], dist_reduce_fx="cat") self.add_state("targets", [], dist_reduce_fx="cat") rank_zero_warn( - "Metric `FPR95` will save all targets and predictions" + f"Metric `FPR{int(recall_level*100)}` will save all targets and predictions" " in buffer. For large datasets this may lead to large memory" " footprint." ) def update(self, conf: Tensor, target: Tensor) -> None: + """Update the metric state. + + Args: + conf (Tensor): The confidence scores. + target (Tensor): The target labels. + """ self.conf.append(conf) self.targets.append(target) def compute(self) -> Tensor: - r"""Compute the actual False Positive Rate at 95% Recall. + r"""Compute the actual False Positive Rate at x% Recall. Returns: - Tensor: The value of the FPR95. + Tensor: The value of the FPRx. Reference: Inpired by https://github.com/hendrycks/anomaly-seg. @@ -82,7 +99,6 @@ def compute(self) -> Tensor: in_scores = conf[np.logical_not(out_labels)] out_scores = conf[out_labels] - # pos = OOD neg = np.array(in_scores[:]).reshape((-1, 1)) pos = np.array(out_scores[:]).reshape((-1, 1)) examples = np.squeeze(np.vstack((pos, neg))) @@ -120,8 +136,20 @@ def compute(self) -> Tensor: thresholds[sl], ) - cutoff = np.argmin(np.abs(recall - 0.95)) + cutoff = np.argmin(np.abs(recall - self.recall_level)) return torch.tensor( fps[cutoff] / (np.sum(np.logical_not(labels))), dtype=torch.float32 ) + + +class FPR95(FPRx): + def __init__(self, pos_label: int, **kwargs) -> None: + """The False Positive Rate at 95% Recall metric. + + Args: + recall_level (float): The recall level at which to compute the FPR. + pos_label (int): The positive label. + kwargs: Additional arguments to pass to the metric class. + """ + super().__init__(recall_level=0.95, pos_label=pos_label, **kwargs) diff --git a/torch_uncertainty/metrics/grouping_loss.py b/torch_uncertainty/metrics/classification/grouping_loss.py similarity index 94% rename from torch_uncertainty/metrics/grouping_loss.py rename to torch_uncertainty/metrics/classification/grouping_loss.py index cfa23367..da9eab41 100644 --- a/torch_uncertainty/metrics/grouping_loss.py +++ b/torch_uncertainty/metrics/classification/grouping_loss.py @@ -7,7 +7,9 @@ class GLEstimator(GLEstimatorBase): - def fit(self, probs: Tensor, targets: Tensor, features: Tensor): + def fit( + self, probs: Tensor, targets: Tensor, features: Tensor + ) -> "GLEstimator": probs = probs.detach().cpu().numpy() features = features.detach().cpu().numpy() targets = (targets * 1).detach().cpu().numpy() @@ -33,15 +35,11 @@ def __init__( Inputs: - :attr:`probs`: :math:`(B, C)` or :math:`(B, N, C)` - :attr:`target`: :math:`(B)` or :math:`(B, C)` + - :attr:`features`: :math:`(B, F)` or :math:`(B, N, F)` where :math:`B` is the batch size, :math:`C` is the number of classes and :math:`N` is the number of estimators. - Note: - If :attr:`probs` is a 3d tensor, then the metric computes the mean of - the Brier score over the estimators ie. :math:`t = \frac{1}{N} - \sum_{i=0}^{N-1} BrierScore(probs[:,i,:], target)`. - Warning: Make sure that the probabilities in :attr:`probs` are normalized to sum to one. diff --git a/torch_uncertainty/metrics/classification/mean_iou.py b/torch_uncertainty/metrics/classification/mean_iou.py new file mode 100644 index 00000000..95c5b8a0 --- /dev/null +++ b/torch_uncertainty/metrics/classification/mean_iou.py @@ -0,0 +1,16 @@ +from torch import Tensor +from torchmetrics.classification.stat_scores import MulticlassStatScores +from torchmetrics.utilities.compute import _safe_divide + + +class MeanIntersectionOverUnion(MulticlassStatScores): + """Compute the MeanIntersection over Union (IoU) score.""" + + is_differentiable: bool = False + higher_is_better: bool = True + full_state_update: bool = False + + def compute(self) -> Tensor: + """Compute the Means Intersection over Union (MIoU) based on saved inputs.""" + tp, fp, _, fn = self._final_state() + return _safe_divide(tp, tp + fp + fn).mean() diff --git a/torch_uncertainty/metrics/mutual_information.py b/torch_uncertainty/metrics/classification/mutual_information.py similarity index 100% rename from torch_uncertainty/metrics/mutual_information.py rename to torch_uncertainty/metrics/classification/mutual_information.py diff --git a/torch_uncertainty/metrics/nll.py b/torch_uncertainty/metrics/classification/nll.py similarity index 57% rename from torch_uncertainty/metrics/nll.py rename to torch_uncertainty/metrics/classification/nll.py index 08c98bb9..6a08f6d2 100644 --- a/torch_uncertainty/metrics/nll.py +++ b/torch_uncertainty/metrics/classification/nll.py @@ -2,14 +2,15 @@ import torch import torch.nn.functional as F +from torch import Tensor from torchmetrics import Metric from torchmetrics.utilities.data import dim_zero_cat -class NegativeLogLikelihood(Metric): - is_differentiable: bool = False - higher_is_better: bool | None = False - full_state_update: bool = False +class CategoricalNLL(Metric): + is_differentiable = False + higher_is_better = False + full_state_update = False def __init__( self, @@ -66,12 +67,12 @@ def __init__( self.add_state("values", default=[], dist_reduce_fx="cat") self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") - def update(self, probs: torch.Tensor, target: torch.Tensor) -> None: + def update(self, probs: Tensor, target: Tensor) -> None: """Update state with prediction probabilities and targets. Args: - probs (torch.Tensor): Probabilities from the model. - target (torch.Tensor): Ground truth labels. + probs (Tensor): Probabilities from the model. + target (Tensor): Ground truth labels. """ if self.reduction is None or self.reduction == "none": self.values.append( @@ -81,7 +82,7 @@ def update(self, probs: torch.Tensor, target: torch.Tensor) -> None: self.values += F.nll_loss(torch.log(probs), target, reduction="sum") self.total += target.size(0) - def compute(self) -> torch.Tensor: + def compute(self) -> Tensor: """Computes NLL based on inputs passed in to ``update`` previously.""" values = dim_zero_cat(self.values) @@ -91,52 +92,3 @@ def compute(self) -> torch.Tensor: return values.sum(dim=-1) / self.total # reduction is None or "none" return values - - -class GaussianNegativeLogLikelihood(NegativeLogLikelihood): - """The Gaussian Negative Log Likelihood Metric. - - Args: - reduction (str, optional): Determines how to reduce over the - :math:`B`/batch dimension: - - - ``'mean'`` [default]: Averages score across samples - - ``'sum'``: Sum score across samples - - ``'none'`` or ``None``: Returns score per sample - - kwargs: Additional keyword arguments, see `Advanced metric settings - `_. - - Inputs: - - :attr:`mean`: :math:`(B, D)` - - :attr:`target`: :math:`(B, D)` - - :attr:`var`: :math:`(B, D)` - - where :math:`B` is the batch size and :math:`D` is the number of - dimensions. :math:`D` is optional. - - Raises: - ValueError: - If :attr:`reduction` is not one of ``'mean'``, ``'sum'``, - ``'none'`` or ``None``. - """ - - def update( - self, mean: torch.Tensor, target: torch.Tensor, var: torch.Tensor - ) -> None: - """Update state with prediction mean, targets, and prediction varoance. - - Args: - mean (torch.Tensor): Probabilities from the model. - target (torch.Tensor): Ground truth labels. - var (torch.Tensor): Predicted variance from the model. - """ - if self.reduction is None or self.reduction == "none": - self.values.append( - F.gaussian_nll_loss(mean, target, var, reduction="none") - ) - else: - self.values += F.gaussian_nll_loss( - mean, target, var, reduction="sum" - ) - self.total += target.size(0) diff --git a/torch_uncertainty/metrics/sparsification.py b/torch_uncertainty/metrics/classification/sparsification.py similarity index 98% rename from torch_uncertainty/metrics/sparsification.py rename to torch_uncertainty/metrics/classification/sparsification.py index a647f182..c843f442 100644 --- a/torch_uncertainty/metrics/sparsification.py +++ b/torch_uncertainty/metrics/classification/sparsification.py @@ -19,7 +19,7 @@ class AUSE(Metric): scores: list[Tensor] errors: list[Tensor] - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: """The Area Under the Sparsification Error curve (AUSE) metric to estimate the quality of the uncertainty estimates, i.e., how much they coincide with the true errors. @@ -35,7 +35,7 @@ def __init__(self, **kwargs): Inputs: - :attr:`scores`: Uncertainty scores of shape :math:`(B,)`. A higher - score means a higher uncertainty. + score means a higher uncertainty. - :attr:`errors`: Binary errors of shape :math:`(B,)`, where :math:`B` is the batch size. diff --git a/torch_uncertainty/metrics/variation_ratio.py b/torch_uncertainty/metrics/classification/variation_ratio.py similarity index 96% rename from torch_uncertainty/metrics/variation_ratio.py rename to torch_uncertainty/metrics/classification/variation_ratio.py index faaddce3..cdf05c89 100644 --- a/torch_uncertainty/metrics/variation_ratio.py +++ b/torch_uncertainty/metrics/classification/variation_ratio.py @@ -52,7 +52,7 @@ def compute(self) -> Tensor: n_estimators = probs_per_est.shape[1] probs = probs_per_est.mean(dim=1) - # best class for exemple + # best class for example max_classes = probs.argmax(dim=-1) if self.probabilistic: @@ -61,7 +61,7 @@ def compute(self) -> Tensor: torch.arange(probs_per_est.size(0)), max_classes ].mean(dim=1) else: - # best class for (exemple, estimator) + # best class for (example, estimator) max_classes_per_est = probs_per_est.argmax(dim=-1) variation_ratio = ( 1 diff --git a/torch_uncertainty/metrics/regression/__init__.py b/torch_uncertainty/metrics/regression/__init__.py new file mode 100644 index 00000000..50f26c74 --- /dev/null +++ b/torch_uncertainty/metrics/regression/__init__.py @@ -0,0 +1,10 @@ +# ruff: noqa: F401 +from .log10 import Log10 +from .mse_log import MeanSquaredLogError +from .nll import DistributionNLL +from .relative_error import ( + MeanGTRelativeAbsoluteError, + MeanGTRelativeSquaredError, +) +from .silog import SILog +from .threshold_accuracy import ThresholdAccuracy diff --git a/torch_uncertainty/metrics/regression/log10.py b/torch_uncertainty/metrics/regression/log10.py new file mode 100644 index 00000000..acd9a0e1 --- /dev/null +++ b/torch_uncertainty/metrics/regression/log10.py @@ -0,0 +1,36 @@ +import torch +from torch import Tensor +from torchmetrics import Metric +from torchmetrics.utilities.data import dim_zero_cat + + +class Log10(Metric): + def __init__(self, **kwargs) -> None: + r"""The Log10 metric. + + .. math:: \text{Log10} = \frac{1}{N} \sum_{i=1}^{N} \log_{10}(y_i) - \log_{10}(\hat{y_i}) + + where :math:`N` is the batch size, :math:`y_i` is a tensor of target values and :math:`\hat{y_i}` is a tensor of prediction. + + Inputs: + - :attr:`preds`: :math:`(N)` + - :attr:`target`: :math:`(N)` + + Args: + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + """ + super().__init__(**kwargs) + self.add_state( + "values", default=torch.tensor(0.0), dist_reduce_fx="sum" + ) + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + + def update(self, pred: Tensor, target: Tensor) -> None: + """Update state with predictions and targets.""" + self.values += torch.sum(pred.log10() - target.log10()) + self.total += target.size(0) + + def compute(self) -> Tensor: + """Compute the Log10 metric.""" + values = dim_zero_cat(self.values) + return values / self.total diff --git a/torch_uncertainty/metrics/regression/mse_log.py b/torch_uncertainty/metrics/regression/mse_log.py new file mode 100644 index 00000000..caae3186 --- /dev/null +++ b/torch_uncertainty/metrics/regression/mse_log.py @@ -0,0 +1,34 @@ +from torch import Tensor +from torchmetrics import MeanSquaredError + + +class MeanSquaredLogError(MeanSquaredError): + def __init__(self, squared: bool = True, **kwargs) -> None: + r"""`Compute MeanSquaredLogError`_ (MSELog). + + .. math:: \text{MSELog} = \frac{1}{N}\sum_i^N (\log \hat{y_i} - \log y_i)^2 + + where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions. + + As input to ``forward`` and ``update`` the metric accepts the following input: + + - ``preds`` (:class:`~torch.Tensor`): Predictions from model + - ``target`` (:class:`~torch.Tensor`): Ground truth values + + As output of ``forward`` and ``compute`` the metric returns the following output: + + - ``mse_log`` (:class:`~torch.Tensor`): A tensor with the + relative mean absolute error over the state + + Args: + squared: If True returns MSELog value, if False returns EMSELog value. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Reference: + As in e.g. From big to small: Multi-scale local planar guidance for monocular depth estimation + """ + super().__init__(squared, **kwargs) + + def update(self, pred: Tensor, target: Tensor) -> None: + """Update state with predictions and targets.""" + return super().update(pred.log(), target.log()) diff --git a/torch_uncertainty/metrics/regression/nll.py b/torch_uncertainty/metrics/regression/nll.py new file mode 100644 index 00000000..9b2f9c3a --- /dev/null +++ b/torch_uncertainty/metrics/regression/nll.py @@ -0,0 +1,30 @@ +from torch import Tensor, distributions +from torchmetrics.utilities.data import dim_zero_cat + +from torch_uncertainty.metrics import CategoricalNLL + + +class DistributionNLL(CategoricalNLL): + def update(self, dist: distributions.Distribution, target: Tensor) -> None: + """Update state with the predicted distributions and the targets. + + Args: + dist (torch.distributions.Distribution): Predicted distributions. + target (Tensor): Ground truth labels. + """ + if self.reduction is None or self.reduction == "none": + self.values.append(-dist.log_prob(target)) + else: + self.values += -dist.log_prob(target).sum() + self.total += target.size(0) + + def compute(self) -> Tensor: + """Computes NLL based on inputs passed in to ``update`` previously.""" + values = dim_zero_cat(self.values) + + if self.reduction == "sum": + return values.sum(dim=-1) + if self.reduction == "mean": + return values.sum(dim=-1) / self.total + # reduction is None or "none" + return values diff --git a/torch_uncertainty/metrics/regression/relative_error.py b/torch_uncertainty/metrics/regression/relative_error.py new file mode 100644 index 00000000..27ac1eb4 --- /dev/null +++ b/torch_uncertainty/metrics/regression/relative_error.py @@ -0,0 +1,68 @@ +import torch +from torch import Tensor +from torchmetrics import MeanAbsoluteError, MeanSquaredError + + +class MeanGTRelativeAbsoluteError(MeanAbsoluteError): + def __init__(self, **kwargs) -> None: + r"""`Compute Mean Absolute Error relative to the Ground Truth`_ (MAErel or ARE). + + .. math:: \text{MAErel} = \frac{1}{N}\sum_i^N \frac{| y_i - \hat{y_i} |}{y_i} + + where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions. + + As input to ``forward`` and ``update`` the metric accepts the following input: + + - ``preds`` (:class:`~torch.Tensor`): Predictions from model + - ``target`` (:class:`~torch.Tensor`): Ground truth values + + As output of ``forward`` and ``compute`` the metric returns the following output: + + - ``rel_mean_absolute_error`` (:class:`~torch.Tensor`): A tensor with the + relative mean absolute error over the state + + Args: + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Reference: + As in e.g. From big to small: Multi-scale local planar guidance for monocular depth estimation + """ + super().__init__(**kwargs) + + def update(self, pred: Tensor, target: Tensor) -> None: + """Update state with predictions and targets.""" + return super().update(pred / target, torch.ones_like(target)) + + +class MeanGTRelativeSquaredError(MeanSquaredError): + def __init__( + self, squared: bool = True, num_outputs: int = 1, **kwargs + ) -> None: + r"""Compute `mean squared error relative to the Ground Truth`_ (MSErel or SRE). + + .. math:: \text{MSErel} = \frac{1}{N}\sum_i^N \frac{(y_i - \hat{y_i})^2}{y_i} + + Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions. + + As input to ``forward`` and ``update`` the metric accepts the following input: + + - ``preds`` (:class:`~torch.Tensor`): Predictions from model + - ``target`` (:class:`~torch.Tensor`): Ground truth values + + As output of ``forward`` and ``compute`` the metric returns the following output: + + - ``rel_mean_squared_error`` (:class:`~torch.Tensor`): A tensor with the relative mean squared error + + Args: + squared: If True returns MSErel value, if False returns RMSErel value. + num_outputs: Number of outputs in multioutput setting + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Reference: + As in e.g. From big to small: Multi-scale local planar guidance for monocular depth estimation + """ + super().__init__(squared, num_outputs, **kwargs) + + def update(self, pred: Tensor, target: Tensor) -> None: + """Update state with predictions and targets.""" + return super().update(pred / torch.sqrt(target), torch.sqrt(target)) diff --git a/torch_uncertainty/metrics/regression/silog.py b/torch_uncertainty/metrics/regression/silog.py new file mode 100644 index 00000000..370b7036 --- /dev/null +++ b/torch_uncertainty/metrics/regression/silog.py @@ -0,0 +1,45 @@ +from typing import Any + +import torch +from torch import Tensor +from torchmetrics import Metric +from torchmetrics.utilities.data import dim_zero_cat + + +class SILog(Metric): + def __init__(self, lmbda: float = 1, **kwargs: Any) -> None: + r"""The Scale-Invariant Logarithmic Loss metric. + + .. math:: \text{SILog} = \frac{1}{N} \sum_{i=1}^{N} \left(\log(y_i) - \log(\hat{y_i})\right)^2 - \left(\frac{1}{N} \sum_{i=1}^{N} \log(y_i) \right)^2 + + where :math:`N` is the batch size, :math:`y_i` is a tensor of target values and :math:`\hat{y_i}` is a tensor of prediction. + + Inputs: + - :attr:`pred`: :math:`(N)` + - :attr:`target`: :math:`(N)` + + Args: + lmbda: The regularization parameter on the variance of error (default 1). + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Reference: + Depth Map Prediction from a Single Image using a Multi-Scale Deep Network. + David Eigen, Christian Puhrsch, Rob Fergus. NeurIPS 2014. + From Big to Small: Multi-Scale Local Planar Guidance for Monocular Depth Estimation. + Jin Han Lee, Myung-Kyu Han, Dong Wook Ko and Il Hong Suh. For the lambda parameter. + """ + super().__init__(**kwargs) + self.lmbda = lmbda + self.add_state("log_dists", default=[], dist_reduce_fx="cat") + + def update(self, pred: Tensor, target: Tensor) -> None: + """Update state with predictions and targets.""" + self.log_dists.append(torch.flatten(pred.log() - target.log())) + + def compute(self) -> Tensor: + """Compute the Scale-Invariant Logarithmic Loss.""" + log_dists = dim_zero_cat(self.log_dists) + num_samples = log_dists.size(0) + return torch.mean(log_dists**2) - self.lmbda * torch.sum( + log_dists + ) ** 2 / (num_samples * num_samples) diff --git a/torch_uncertainty/metrics/regression/threshold_accuracy.py b/torch_uncertainty/metrics/regression/threshold_accuracy.py new file mode 100644 index 00000000..68068ad8 --- /dev/null +++ b/torch_uncertainty/metrics/regression/threshold_accuracy.py @@ -0,0 +1,43 @@ +import torch +from torch import Tensor +from torchmetrics import Metric +from torchmetrics.utilities.data import dim_zero_cat + + +class ThresholdAccuracy(Metric): + def __init__(self, power: int, lmbda: float = 1.25, **kwargs) -> None: + r"""The Threshold Accuracy metric, a.k.a. d1, d2, d3. + + Args: + power: The power to raise the threshold to. Often in [1, 2, 3]. + lmbda: The threshold to compare the max of ratio of predictions + to targets and its inverse to. Defaults to 1.25. + kwargs: Additional arguments to pass to the metric class. + """ + super().__init__(**kwargs) + if power < 0: + raise ValueError( + f"Power must be greater than or equal to 0. Got {power}." + ) + self.power = power + if lmbda < 1: + raise ValueError( + f"Lambda must be greater than or equal to 1. Got {lmbda}." + ) + self.lmbda = lmbda + self.add_state( + "values", default=torch.tensor(0.0), dist_reduce_fx="sum" + ) + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor) -> None: + """Update state with predictions and targets.""" + self.values += torch.sum( + torch.max(preds / target, target / preds) < self.lmbda**self.power + ) + self.total += target.size(0) + + def compute(self) -> Tensor: + """Compute the Threshold Accuracy.""" + values = dim_zero_cat(self.values) + return values / self.total diff --git a/torch_uncertainty/models/__init__.py b/torch_uncertainty/models/__init__.py index afc7480f..08dfc824 100644 --- a/torch_uncertainty/models/__init__.py +++ b/torch_uncertainty/models/__init__.py @@ -1,2 +1,3 @@ # ruff: noqa: F401 from .deep_ensembles import deep_ensembles +from .mc_dropout import mc_dropout diff --git a/torch_uncertainty/models/deep_ensembles.py b/torch_uncertainty/models/deep_ensembles.py index d82152df..637e43d6 100644 --- a/torch_uncertainty/models/deep_ensembles.py +++ b/torch_uncertainty/models/deep_ensembles.py @@ -1,7 +1,11 @@ import copy +from typing import Literal import torch from torch import nn +from torch.distributions import Distribution + +from torch_uncertainty.utils.distributions import cat_dist class _DeepEnsembles(nn.Module): @@ -9,9 +13,8 @@ def __init__( self, models: list[nn.Module], ) -> None: - """Create a deep ensembles from a list of models.""" + """Create a classification deep ensembles from a list of models.""" super().__init__() - self.models = nn.ModuleList(models) self.num_estimators = len(models) @@ -29,18 +32,51 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.cat([model.forward(x) for model in self.models], dim=0) +class _RegDeepEnsembles(_DeepEnsembles): + def __init__( + self, + probabilistic: bool, + models: list[nn.Module], + ) -> None: + """Create a regression deep ensembles from a list of models.""" + super().__init__(models) + self.probabilistic = probabilistic + + def forward(self, x: torch.Tensor) -> Distribution: + r"""Return the logits of the ensemble. + + Args: + x (Tensor): The input of the model. + + Returns: + Distribution: + """ + if self.probabilistic: + return cat_dist([model.forward(x) for model in self.models], dim=0) + return super().forward(x) + + def deep_ensembles( models: list[nn.Module] | nn.Module, num_estimators: int | None = None, -) -> nn.Module: + task: Literal[ + "classification", "regression", "segmentation" + ] = "classification", + probabilistic: bool | None = None, + reset_model_parameters: bool = False, +) -> _DeepEnsembles: """Build a Deep Ensembles out of the original models. Args: models (list[nn.Module] | nn.Module): The model to be ensembled. num_estimators (int | None): The number of estimators in the ensemble. + task (Literal["classification", "regression"]): The model task. + probabilistic (bool): Whether the regression model is probabilistic. + reset_model_parameters (bool): Whether to reset the model parameters + when :attr:models is a module or a list of length 1. Returns: - nn.Module: The ensembled model. + _DeepEnsembles: The ensembled model. Raises: ValueError: If :attr:num_estimators is not specified and :attr:models @@ -55,10 +91,8 @@ def deep_ensembles( Simple and scalable predictive uncertainty estimation using deep ensembles. In NeurIPS, 2017. """ - if ( - isinstance(models, list) - and len(models) == 1 - or isinstance(models, nn.Module) + if (isinstance(models, list) and len(models) == 1) or isinstance( + models, nn.Module ): if num_estimators is None: raise ValueError( @@ -73,6 +107,13 @@ def deep_ensembles( models = models[0] models = [copy.deepcopy(models) for _ in range(num_estimators)] + + if reset_model_parameters: + for model in models: + for layer in model.children(): + if hasattr(layer, "reset_parameters"): + layer.reset_parameters() + elif ( isinstance(models, list) and len(models) > 1 @@ -82,4 +123,12 @@ def deep_ensembles( "num_estimators must be None if you provided a non-singleton list." ) - return _DeepEnsembles(models=models) + if task in ("classification", "segmentation"): + return _DeepEnsembles(models=models) + if task == "regression": + if probabilistic is None: + raise ValueError( + "probabilistic must be specified for regression models." + ) + return _RegDeepEnsembles(probabilistic=probabilistic, models=models) + raise ValueError(f"Unknown task: {task}.") diff --git a/torch_uncertainty/models/lenet.py b/torch_uncertainty/models/lenet.py index 61832ce7..fcb9663e 100644 --- a/torch_uncertainty/models/lenet.py +++ b/torch_uncertainty/models/lenet.py @@ -10,7 +10,7 @@ from torch_uncertainty.layers.packed import PackedConv2d, PackedLinear from torch_uncertainty.models.utils import stochastic_model -__all__ = ["lenet", "packed_lenet", "bayesian_lenet"] +__all__ = ["bayesian_lenet", "lenet", "packed_lenet"] class _LeNet(nn.Module): diff --git a/torch_uncertainty/models/mc_dropout.py b/torch_uncertainty/models/mc_dropout.py index d3ac77af..355fe43e 100644 --- a/torch_uncertainty/models/mc_dropout.py +++ b/torch_uncertainty/models/mc_dropout.py @@ -13,7 +13,8 @@ def __init__( last_layer (bool): whether to apply dropout to the last layer only. Warning: - The underlying models must have a `dropout_rate` attribute. + The underlying models must have a non-zero :attr:`dropout_rate` + attribute. Warning: For the `last-layer` option to work properly, the model must @@ -70,7 +71,7 @@ def train(self, mode: bool = True) -> nn.Module: def forward( self, x: Tensor, - ) -> tuple[Tensor, Tensor]: + ) -> Tensor: if not self.training: x = x.repeat(self.num_estimators, 1, 1, 1) return self.model(x) @@ -85,7 +86,7 @@ def mc_dropout( model (nn.Module): model to wrap num_estimators (int): number of estimators to use last_layer (bool, optional): whether to apply dropout to the last - layer. Defaults to False. + layer only. Defaults to False. """ return _MCDropout( model=model, num_estimators=num_estimators, last_layer=last_layer diff --git a/torch_uncertainty/models/mlp.py b/torch_uncertainty/models/mlp.py index 8e0cefbc..a822343d 100644 --- a/torch_uncertainty/models/mlp.py +++ b/torch_uncertainty/models/mlp.py @@ -7,7 +7,7 @@ from torch_uncertainty.layers.packed import PackedLinear from torch_uncertainty.models.utils import stochastic_model -__all__ = ["mlp", "packed_mlp", "bayesian_mlp"] +__all__ = ["bayesian_mlp", "mlp", "packed_mlp"] class _MLP(nn.Module): @@ -19,7 +19,9 @@ def __init__( layer: type[nn.Module], activation: Callable, layer_args: dict, - dropout: float, + final_layer: nn.Module, + final_layer_args: dict, + dropout_rate: float, ) -> None: """Multi-layer perceptron class. @@ -30,11 +32,13 @@ def __init__( layer (nn.Module): Layer class. activation (Callable): Activation function. layer_args (Dict): Arguments for the layer class. - dropout (float): Dropout probability. + final_layer (nn.Module): Final layer class for distribution regression. + final_layer_args (Dict): Arguments for the final layer class. + dropout_rate (float): Dropout probability. """ super().__init__() self.activation = activation - self.dropout = dropout + self.dropout_rate = dropout_rate layers = nn.ModuleList() @@ -70,14 +74,14 @@ def __init__( ) else: layers.append(layer(hidden_dims[-1], num_outputs, **layer_args)) - self.layers = layers + self.final_layer = final_layer(**final_layer_args) def forward(self, x: Tensor) -> Tensor: for layer in self.layers[:-1]: - x = F.dropout(layer(x), p=self.dropout, training=self.training) + x = F.dropout(layer(x), p=self.dropout_rate, training=self.training) x = self.activation(x) - return self.layers[-1](x) + return self.final_layer(self.layers[-1](x)) @stochastic_model @@ -93,10 +97,14 @@ def _mlp( layer_args: dict | None = None, layer: type[nn.Module] = nn.Linear, activation: Callable = F.relu, - dropout: float = 0.0, + final_layer: nn.Module = nn.Identity, + final_layer_args: dict | None = None, + dropout_rate: float = 0.0, ) -> _MLP | _StochasticMLP: if layer_args is None: layer_args = {} + if final_layer_args is None: + final_layer_args = {} model = _MLP if not stochastic else _StochasticMLP return model( in_features=in_features, @@ -105,7 +113,9 @@ def _mlp( layer_args=layer_args, layer=layer, activation=activation, - dropout=dropout, + final_layer=final_layer, + final_layer_args=final_layer_args, + dropout_rate=dropout_rate, ) @@ -115,7 +125,9 @@ def mlp( hidden_dims: list[int], layer: type[nn.Module] = nn.Linear, activation: Callable = F.relu, - dropout: float = 0.0, + final_layer: nn.Module = nn.Identity, + final_layer_args: dict | None = None, + dropout_rate: float = 0.0, ) -> _MLP: """Multi-layer perceptron. @@ -126,7 +138,10 @@ def mlp( layer (nn.Module, optional): Layer type. Defaults to nn.Linear. activation (Callable, optional): Activation function. Defaults to F.relu. - dropout (float, optional): Dropout probability. Defaults to 0.0. + final_layer (nn.Module, optional): Final layer class for distribution + regression. Defaults to nn.Identity. + final_layer_args (Dict, optional): Arguments for the final layer class. + dropout_rate (float, optional): Dropout probability. Defaults to 0.0. Returns: _MLP: A Multi-Layer-Perceptron model. @@ -138,7 +153,9 @@ def mlp( hidden_dims=hidden_dims, layer=layer, activation=activation, - dropout=dropout, + final_layer=final_layer, + final_layer_args=final_layer_args, + dropout_rate=dropout_rate, ) @@ -150,7 +167,9 @@ def packed_mlp( alpha: float = 2, gamma: float = 1, activation: Callable = F.relu, - dropout: float = 0.0, + final_layer: nn.Module = nn.Identity, + final_layer_args: dict | None = None, + dropout_rate: float = 0.0, ) -> _MLP: layer_args = { "num_estimators": num_estimators, @@ -165,7 +184,9 @@ def packed_mlp( layer=PackedLinear, activation=activation, layer_args=layer_args, - dropout=dropout, + final_layer=final_layer, + final_layer_args=final_layer_args, + dropout_rate=dropout_rate, ) @@ -174,7 +195,9 @@ def bayesian_mlp( num_outputs: int, hidden_dims: list[int], activation: Callable = F.relu, - dropout: float = 0.0, + final_layer: nn.Module = nn.Identity, + final_layer_args: dict | None = None, + dropout_rate: float = 0.0, ) -> _StochasticMLP: return _mlp( stochastic=True, @@ -183,5 +206,7 @@ def bayesian_mlp( hidden_dims=hidden_dims, layer=BayesLinear, activation=activation, - dropout=dropout, + final_layer=final_layer, + final_layer_args=final_layer_args, + dropout_rate=dropout_rate, ) diff --git a/torch_uncertainty/models/resnet/batched.py b/torch_uncertainty/models/resnet/batched.py index fcdc1a5b..52b3fc1f 100644 --- a/torch_uncertainty/models/resnet/batched.py +++ b/torch_uncertainty/models/resnet/batched.py @@ -300,7 +300,7 @@ def _make_layer( self.in_planes = planes * block.expansion return nn.Sequential(*layers) - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: out = x.repeat(self.num_estimators, 1, 1, 1) out = F.relu(self.bn1(self.conv1(out))) out = self.optional_pool(out) diff --git a/torch_uncertainty/models/resnet/masked.py b/torch_uncertainty/models/resnet/masked.py index 8c8ce3c3..ea61606d 100644 --- a/torch_uncertainty/models/resnet/masked.py +++ b/torch_uncertainty/models/resnet/masked.py @@ -105,7 +105,7 @@ def __init__( num_estimators=num_estimators, scale=scale, groups=groups, - bias=False, + bias=conv_bias, ) self.bn1 = normalization_layer(planes) self.conv2 = MaskedConv2d( @@ -117,7 +117,7 @@ def __init__( stride=stride, padding=1, groups=groups, - bias=False, + bias=conv_bias, ) self.dropout = nn.Dropout2d(p=dropout_rate) self.bn2 = normalization_layer(planes) @@ -128,7 +128,7 @@ def __init__( num_estimators=num_estimators, scale=scale, groups=groups, - bias=False, + bias=conv_bias, ) self.bn3 = normalization_layer(self.expansion * planes) @@ -143,7 +143,7 @@ def __init__( scale=scale, stride=stride, groups=groups, - bias=False, + bias=conv_bias, ), normalization_layer(self.expansion * planes), ) diff --git a/torch_uncertainty/models/resnet/std.py b/torch_uncertainty/models/resnet/std.py index 148740d1..0eeea7ba 100644 --- a/torch_uncertainty/models/resnet/std.py +++ b/torch_uncertainty/models/resnet/std.py @@ -143,6 +143,7 @@ def forward(self, x: Tensor) -> Tensor: return self.activation_fn(out) +# ruff: noqa: ERA001 # class Robust_Bottleneck(nn.Module): # """Robust _Bottleneck from "Can CNNs be more robust than transformers?" # This corresponds to ResNet-Up-Inverted-DW in the paper. diff --git a/torch_uncertainty/models/segmentation/__init__.py b/torch_uncertainty/models/segmentation/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/torch_uncertainty/models/segmentation/segformer/__init__.py b/torch_uncertainty/models/segmentation/segformer/__init__.py new file mode 100644 index 00000000..dc3fb2ee --- /dev/null +++ b/torch_uncertainty/models/segmentation/segformer/__init__.py @@ -0,0 +1,2 @@ +# ruff: noqa: F401, F403 +from .std import * diff --git a/torch_uncertainty/models/segmentation/segformer/std.py b/torch_uncertainty/models/segmentation/segformer/std.py new file mode 100644 index 00000000..3881e055 --- /dev/null +++ b/torch_uncertainty/models/segmentation/segformer/std.py @@ -0,0 +1,806 @@ +# --------------------------------------------------------------- +# Copyright (c) 2021, NVIDIA Corporation. All rights reserved. +# +# This work is licensed under the NVIDIA Source Code License +# --------------------------------------------------------------- + +import math +import warnings +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class DWConv(nn.Module): + def __init__(self, dim=768): + super().__init__() + self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, x, h, w): + b, _, c = x.shape + x = x.transpose(1, 2).view(b, c, h, w) + x = self.dwconv(x) + return x.flatten(2).transpose(1, 2) + + +class Mlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.dwconv = DWConv(hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, h, w): + x = self.fc1(x) + x = self.dwconv(x, h, w) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + return self.drop(x) + + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + sr_ratio=1, + ): + super().__init__() + assert ( + dim % num_heads == 0 + ), f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = nn.LayerNorm(dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, h, w): + b, n, c = x.shape + q = ( + self.q(x) + .reshape(b, n, self.num_heads, c // self.num_heads) + .permute(0, 2, 1, 3) + ) + + if self.sr_ratio > 1: + x_ = x.permute(0, 2, 1).reshape(b, c, h, w) + x_ = self.sr(x_).reshape(b, c, -1).permute(0, 2, 1) + x_ = self.norm(x_) + kv = ( + self.kv(x_) + .reshape(b, -1, 2, self.num_heads, c // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + else: + kv = ( + self.kv(x) + .reshape(b, -1, 2, self.num_heads, c // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + k, v = kv[0], kv[1] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(b, n, c) + x = self.proj(x) + return self.proj_drop(x) + + +class Block(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + sr_ratio=1, + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + sr_ratio=sr_ratio, + ) + # NOTE: drop path for stochastic depth, we shall see if this is better + # than dropout here + self.drop_path = ( + DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + ) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, h, w): + x = x + self.drop_path(self.attn(self.norm1(x), h, w)) + return x + self.drop_path(self.mlp(self.norm2(x), h, w)) + + +class OverlapPatchEmbed(nn.Module): + """Image to Patch Embedding.""" + + def __init__( + self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768 + ): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + + self.img_size = img_size + self.patch_size = patch_size + self.h, self.w = ( + img_size[0] // patch_size[0], + img_size[1] // patch_size[1], + ) + self.num_patches = self.h * self.w + self.proj = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=patch_size, + stride=stride, + padding=(patch_size[0] // 2, patch_size[1] // 2), + ) + self.norm = nn.LayerNorm(embed_dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x): + x = self.proj(x) + _, _, h, w = x.shape + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + + return x, h, w + + +class MixVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + embed_dims=None, + num_heads=None, + mlp_ratios=None, + qkv_bias=False, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + norm_layer=nn.LayerNorm, + depths=None, + sr_ratios=None, + ): + if sr_ratios is None: + sr_ratios = [8, 4, 2, 1] + if depths is None: + depths = [3, 4, 6, 3] + if mlp_ratios is None: + mlp_ratios = [4, 4, 4, 4] + if num_heads is None: + num_heads = [1, 2, 4, 8] + if embed_dims is None: + embed_dims = [64, 128, 256, 512] + super().__init__() + self.num_classes = num_classes + self.depths = depths + + # patch_embed + self.patch_embed1 = OverlapPatchEmbed( + img_size=img_size, + patch_size=7, + stride=4, + in_chans=in_chans, + embed_dim=embed_dims[0], + ) + self.patch_embed2 = OverlapPatchEmbed( + img_size=img_size // 4, + patch_size=3, + stride=2, + in_chans=embed_dims[0], + embed_dim=embed_dims[1], + ) + self.patch_embed3 = OverlapPatchEmbed( + img_size=img_size // 8, + patch_size=3, + stride=2, + in_chans=embed_dims[1], + embed_dim=embed_dims[2], + ) + self.patch_embed4 = OverlapPatchEmbed( + img_size=img_size // 16, + patch_size=3, + stride=2, + in_chans=embed_dims[2], + embed_dim=embed_dims[3], + ) + + # transformer encoder + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule + cur = 0 + self.block1 = nn.ModuleList( + [ + Block( + dim=embed_dims[0], + num_heads=num_heads[0], + mlp_ratio=mlp_ratios[0], + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[cur + i], + norm_layer=norm_layer, + sr_ratio=sr_ratios[0], + ) + for i in range(depths[0]) + ] + ) + self.norm1 = norm_layer(embed_dims[0]) + + cur += depths[0] + self.block2 = nn.ModuleList( + [ + Block( + dim=embed_dims[1], + num_heads=num_heads[1], + mlp_ratio=mlp_ratios[1], + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[cur + i], + norm_layer=norm_layer, + sr_ratio=sr_ratios[1], + ) + for i in range(depths[1]) + ] + ) + self.norm2 = norm_layer(embed_dims[1]) + + cur += depths[1] + self.block3 = nn.ModuleList( + [ + Block( + dim=embed_dims[2], + num_heads=num_heads[2], + mlp_ratio=mlp_ratios[2], + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[cur + i], + norm_layer=norm_layer, + sr_ratio=sr_ratios[2], + ) + for i in range(depths[2]) + ] + ) + self.norm3 = norm_layer(embed_dims[2]) + + cur += depths[2] + self.block4 = nn.ModuleList( + [ + Block( + dim=embed_dims[3], + num_heads=num_heads[3], + mlp_ratio=mlp_ratios[3], + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[cur + i], + norm_layer=norm_layer, + sr_ratio=sr_ratios[3], + ) + for i in range(depths[3]) + ] + ) + self.norm4 = norm_layer(embed_dims[3]) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward_features(self, x): + b = x.shape[0] + outs = [] + + # stage 1 + x, h, w = self.patch_embed1(x) + for _i, blk in enumerate(self.block1): + x = blk(x, h, w) + x = self.norm1(x) + x = x.reshape(b, h, w, -1).permute(0, 3, 1, 2).contiguous() + outs.append(x) + + # stage 2 + x, h, w = self.patch_embed2(x) + for _i, blk in enumerate(self.block2): + x = blk(x, h, w) + x = self.norm2(x) + x = x.reshape(b, h, w, -1).permute(0, 3, 1, 2).contiguous() + outs.append(x) + + # stage 3 + x, h, w = self.patch_embed3(x) + for _i, blk in enumerate(self.block3): + x = blk(x, h, w) + x = self.norm3(x) + x = x.reshape(b, h, w, -1).permute(0, 3, 1, 2).contiguous() + outs.append(x) + + # stage 4 + x, h, w = self.patch_embed4(x) + for _i, blk in enumerate(self.block4): + x = blk(x, h, w) + x = self.norm4(x) + x = x.reshape(b, h, w, -1).permute(0, 3, 1, 2).contiguous() + outs.append(x) + + return outs + + def forward(self, x): + return self.forward_features(x) + + +class MitB0(MixVisionTransformer): + def __init__(self): + super().__init__( + patch_size=4, + embed_dims=[32, 64, 160, 256], + num_heads=[1, 2, 5, 8], + mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=[2, 2, 2, 2], + sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, + drop_path_rate=0.1, + ) + + +class MitB1(MixVisionTransformer): + def __init__(self): + super().__init__( + patch_size=4, + embed_dims=[64, 128, 320, 512], + num_heads=[1, 2, 5, 8], + mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=[2, 2, 2, 2], + sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, + drop_path_rate=0.1, + ) + + +class MitB2(MixVisionTransformer): + def __init__(self): + super().__init__( + patch_size=4, + embed_dims=[64, 128, 320, 512], + num_heads=[1, 2, 5, 8], + mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=[3, 4, 6, 3], + sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, + drop_path_rate=0.1, + ) + + +class MitB3(MixVisionTransformer): + def __init__(self): + super().__init__( + patch_size=4, + embed_dims=[64, 128, 320, 512], + num_heads=[1, 2, 5, 8], + mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=[3, 4, 18, 3], + sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, + drop_path_rate=0.1, + ) + + +class MitB4(MixVisionTransformer): + def __init__(self): + super().__init__( + patch_size=4, + embed_dims=[64, 128, 320, 512], + num_heads=[1, 2, 5, 8], + mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=[3, 8, 27, 3], + sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, + drop_path_rate=0.1, + ) + + +class MitB5(MixVisionTransformer): + def __init__(self): + super().__init__( + patch_size=4, + embed_dims=[64, 128, 320, 512], + num_heads=[1, 2, 5, 8], + mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=[3, 6, 40, 3], + sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, + drop_path_rate=0.1, + ) + + +class MLPHead(nn.Module): + """Linear Embedding.""" + + def __init__(self, input_dim=2048, embed_dim=768): + super().__init__() + self.proj = nn.Linear(input_dim, embed_dim) + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) + return self.proj(x) + + +def resize( + inputs, + size=None, + scale_factor=None, + mode="nearest", + align_corners=None, + warning=True, +): + if warning and size is not None and align_corners: + input_h, input_w = tuple(int(x) for x in inputs.shape[2:]) + output_h, output_w = tuple(int(x) for x in size) + if (output_h > input_h or output_w > output_h) and ( + (output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1) + and (output_h - 1) % (input_h - 1) + and (output_w - 1) % (input_w - 1) + ): + warnings.warn( + f"When align_corners={align_corners}, " + "the output would more aligned if " + f"input size {(input_h, input_w)} is `x+1` and " + f"out size {(output_h, output_w)} is `nx+1`", + stacklevel=2, + ) + if isinstance(size, torch.Size): + size = tuple(int(x) for x in size) + return F.interpolate(inputs, size, scale_factor, mode, align_corners) + + +class SegFormerHead(nn.Module): + """SegFormer: Simple and Efficient Design for Semantic Segmentation with + Transformers. + """ + + def __init__( + self, + in_channels, + feature_strides, + decoder_params, + num_classes, + dropout_ratio=0.1, + ): + super().__init__() + self.in_channels = in_channels + assert len(feature_strides) == len(self.in_channels) + assert min(feature_strides) == feature_strides[0] + self.feature_strides = feature_strides + self.num_classes = num_classes + # --- self in_index [0, 1, 2, 3] + + ( + c1_in_channels, + c2_in_channels, + c3_in_channels, + c4_in_channels, + ) = self.in_channels + + embedding_dim = decoder_params["embed_dim"] + + self.linear_c4 = MLPHead( + input_dim=c4_in_channels, embed_dim=embedding_dim + ) + self.linear_c3 = MLPHead( + input_dim=c3_in_channels, embed_dim=embedding_dim + ) + self.linear_c2 = MLPHead( + input_dim=c2_in_channels, embed_dim=embedding_dim + ) + self.linear_c1 = MLPHead( + input_dim=c1_in_channels, embed_dim=embedding_dim + ) + + self.fuse = nn.Sequential( + nn.Conv2d( + embedding_dim * 4, embedding_dim, kernel_size=1, bias=False + ), + nn.ReLU(), + nn.BatchNorm2d(embedding_dim), + ) + + self.linear_pred = nn.Conv2d( + embedding_dim, self.num_classes, kernel_size=1 + ) + + if dropout_ratio > 0: + self.dropout = nn.Dropout2d(dropout_ratio) + else: + self.dropout = None + + def forward(self, inputs): + # x [inputs[i] for i in self.in_index] # len=4, 1/4,1/8,1/16,1/32 + c1, c2, c3, c4 = inputs[0], inputs[1], inputs[2], inputs[3] + + n, _, _, _ = c4.shape + + _c4 = ( + self.linear_c4(c4) + .permute(0, 2, 1) + .reshape(n, -1, c4.shape[2], c4.shape[3]) + ) + _c4 = resize( + _c4, size=c1.size()[2:], mode="bilinear", align_corners=False + ) + + _c3 = ( + self.linear_c3(c3) + .permute(0, 2, 1) + .reshape(n, -1, c3.shape[2], c3.shape[3]) + ) + _c3 = resize( + _c3, size=c1.size()[2:], mode="bilinear", align_corners=False + ) + + _c2 = ( + self.linear_c2(c2) + .permute(0, 2, 1) + .reshape(n, -1, c2.shape[2], c2.shape[3]) + ) + _c2 = resize( + _c2, size=c1.size()[2:], mode="bilinear", align_corners=False + ) + + _c1 = ( + self.linear_c1(c1) + .permute(0, 2, 1) + .reshape(n, -1, c1.shape[2], c1.shape[3]) + ) + + _c = self.fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1)) + + x = self.dropout(_c) + return self.linear_pred(x) + + +class _SegFormer(nn.Module): + def __init__( + self, + in_channels, + feature_strides, + decoder_params, + num_classes, + dropout_ratio, + mit: nn.Module, + ): + super().__init__() + + self.encoder = mit() + self.head = SegFormerHead( + in_channels, + feature_strides, + decoder_params, + num_classes, + dropout_ratio, + ) + + def forward(self, x): + features = self.encoder(x) + return self.head(features) + + +def seg_former_b0(num_classes: int): + return _SegFormer( + in_channels=[32, 64, 160, 256], + feature_strides=[4, 8, 16, 32], + decoder_params={"embed_dim": 256}, + num_classes=num_classes, + dropout_ratio=0.1, + mit=MitB0, + ) + + +def seg_former_b1(num_classes: int): + return _SegFormer( + in_channels=[64, 128, 320, 512], + feature_strides=[4, 8, 16, 32], + decoder_params={"embed_dim": 512}, + num_classes=num_classes, + dropout_ratio=0.1, + mit=MitB1, + ) + + +def seg_former_b2(num_classes: int): + return _SegFormer( + in_channels=[64, 128, 320, 512], + feature_strides=[4, 8, 16, 32], + decoder_params={"embed_dim": 512}, + num_classes=num_classes, + dropout_ratio=0.1, + mit=MitB2, + ) + + +def seg_former_b3(num_classes: int): + return _SegFormer( + in_channels=[64, 128, 320, 512], + feature_strides=[4, 8, 16, 32], + decoder_params={"embed_dim": 512}, + num_classes=num_classes, + dropout_ratio=0.1, + mit=MitB3, + ) + + +def seg_former_b4(num_classes: int): + return _SegFormer( + in_channels=[64, 128, 320, 512], + feature_strides=[4, 8, 16, 32], + decoder_params={"embed_dim": 512}, + num_classes=num_classes, + dropout_ratio=0.1, + mit=MitB4, + ) + + +def seg_former_b5(num_classes: int): + return _SegFormer( + in_channels=[64, 128, 320, 512], + feature_strides=[4, 8, 16, 32], + decoder_params={"embed_dim": 512}, + num_classes=num_classes, + dropout_ratio=0.1, + mit=MitB5, + ) diff --git a/torch_uncertainty/models/wideresnet/mimo.py b/torch_uncertainty/models/wideresnet/mimo.py index 9adca113..c3a25e0a 100644 --- a/torch_uncertainty/models/wideresnet/mimo.py +++ b/torch_uncertainty/models/wideresnet/mimo.py @@ -39,7 +39,6 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: if not self.training: x = x.repeat(self.num_estimators, 1, 1, 1) - out = rearrange(x, "(m b) c h w -> b (m c) h w", m=self.num_estimators) out = super().forward(out) return rearrange(out, "b (m d) -> (m b) d", m=self.num_estimators) diff --git a/torch_uncertainty/models/wideresnet/std.py b/torch_uncertainty/models/wideresnet/std.py index 3c943eaa..3e14b2c8 100644 --- a/torch_uncertainty/models/wideresnet/std.py +++ b/torch_uncertainty/models/wideresnet/std.py @@ -163,7 +163,7 @@ def _wide_layer( num_blocks: int, dropout_rate: float, stride: int, - groups, + groups: int, ) -> nn.Module: strides = [stride] + [1] * (int(num_blocks) - 1) layers = [] diff --git a/torch_uncertainty/optimization_procedures.py b/torch_uncertainty/optim_recipes.py similarity index 93% rename from torch_uncertainty/optimization_procedures.py rename to torch_uncertainty/optim_recipes.py index 07d7f21d..a413b02c 100644 --- a/torch_uncertainty/optimization_procedures.py +++ b/torch_uncertainty/optim_recipes.py @@ -8,17 +8,16 @@ __all__ = [ "optim_cifar10_resnet18", + "optim_cifar10_resnet34", "optim_cifar10_resnet50", - "optim_cifar10_wideresnet", "optim_cifar10_vgg16", + "optim_cifar10_wideresnet", "optim_cifar100_resnet18", + "optim_cifar100_resnet34", "optim_cifar100_resnet50", "optim_cifar100_vgg16", "optim_imagenet_resnet50", "optim_imagenet_resnet50_a3", - "optim_regression", - "optim_cifar10_resnet34", - "optim_cifar100_resnet34", "optim_tinyimagenet_resnet34", "optim_tinyimagenet_resnet50", ] @@ -270,7 +269,7 @@ def optim_cifar100_resnet34( def optim_tinyimagenet_resnet34( model: nn.Module, ) -> dict[str, Optimizer | LRScheduler]: - """Optimization procedure from 'The Devil is in the Margin: Margin-based + """Optimization recipe from 'The Devil is in the Margin: Margin-based Label Smoothing for Network Calibration', (CVPR 2022, https://arxiv.org/abs/2111.15430): 'We train for 100 epochs with a learning rate of 0.1 for the first @@ -295,7 +294,7 @@ def optim_tinyimagenet_resnet34( def optim_tinyimagenet_resnet50( model: nn.Module, ) -> dict[str, Optimizer | LRScheduler]: - """Optimization procedure from 'The Devil is in the Margin: Margin-based + """Optimization recipe from 'The Devil is in the Margin: Margin-based Label Smoothing for Network Calibration', (CVPR 2022, https://arxiv.org/abs/2111.15430): 'We train for 100 epochs with a learning rate of 0.1 for the first @@ -317,25 +316,8 @@ def optim_tinyimagenet_resnet50( return {"optimizer": optimizer, "lr_scheduler": scheduler} -def optim_regression( - model: nn.Module, - learning_rate: float = 1e-2, -) -> dict: - optimizer = optim.SGD( - model.parameters(), - lr=learning_rate, - weight_decay=0, - ) - return { - "optimizer": optimizer, - "monitor": "reg/val_nll", - } - - -def batch_ensemble_wrapper( - model: nn.Module, optimization_procedure: Callable -) -> dict: - procedure = optimization_procedure(model) +def batch_ensemble_wrapper(model: nn.Module, optim_recipe: Callable) -> dict: + procedure = optim_recipe(model) param_optimizer = procedure["optimizer"] scheduler = procedure["lr_scheduler"] @@ -379,7 +361,7 @@ def get_procedure( method: str = "", imagenet_recipe: str | None = None, ) -> Callable: - """Get the optimization procedure for a given architecture and dataset. + """Get the optimization recipe for a given architecture and dataset. Args: arch_name (str): The name of the architecture. @@ -389,7 +371,7 @@ def get_procedure( ImageNet. Defaults to None. Returns: - callable: The optimization procedure. + callable: The optimization recipe. """ if arch_name in ["resnet18", "resnet20"]: if ds_name == "cifar10": @@ -437,8 +419,6 @@ def get_procedure( raise NotImplementedError(f"No recipe for architecture: {arch_name}.") if method == "batched": - procedure = partial( - batch_ensemble_wrapper, optimization_procedure=procedure - ) + procedure = partial(batch_ensemble_wrapper, optim_recipe=procedure) return procedure diff --git a/torch_uncertainty/plotting_utils.py b/torch_uncertainty/plotting_utils.py deleted file mode 100644 index c9bdb337..00000000 --- a/torch_uncertainty/plotting_utils.py +++ /dev/null @@ -1,43 +0,0 @@ -import matplotlib.pyplot as plt -import torch -from matplotlib.axes import Axes -from matplotlib.figure import Figure - - -def plot_hist( - conf: list[torch.Tensor], - bins: int = 20, - title: str = "Histogram with 'auto' bins", - dpi: int = 60, -) -> tuple[Figure, Axes]: - """Plot a confidence histogram. - - Args: - conf (Any): The confidence values. - bins (int, optional): The number of bins. Defaults to 20. - title (str, optional): The title of the plot. Defaults to "Histogram - with 'auto' bins". - dpi (int, optional): The dpi of the plot. Defaults to 60. - - Returns: - Tuple[Figure, Axes]: The figure and axes of the plot. - """ - plt.rc("axes", axisbelow=True) - fig, ax = plt.subplots(1, figsize=(7, 5), dpi=dpi) - for i in [1, 0]: - ax.hist( - conf[i], - bins=bins, - density=True, - label=["In-distribution", "Out-of-Distribution"][i], - alpha=0.4, - linewidth=1, - edgecolor=["#0d559f", "#d45f00"][i], - color=["#1f77b4", "#ff7f0e"][i], - ) - - ax.set_title(title) - plt.grid(True, linestyle="--", alpha=0.7, zorder=0) - plt.legend() - fig.tight_layout() - return fig, ax diff --git a/torch_uncertainty/post_processing/calibration/scaler.py b/torch_uncertainty/post_processing/calibration/scaler.py index c2d0c368..f4141214 100644 --- a/torch_uncertainty/post_processing/calibration/scaler.py +++ b/torch_uncertainty/post_processing/calibration/scaler.py @@ -96,7 +96,7 @@ def calib_eval() -> float: def forward(self, inputs: Tensor) -> Tensor: if not self.trained: print( - "TemperatureScaler has not been trained yet. Returning a " + "TemperatureScaler has not been trained yet. Returning " "manually tempered inputs." ) return self._scale(self.model(inputs)) diff --git a/torch_uncertainty/post_processing/mc_batch_norm.py b/torch_uncertainty/post_processing/mc_batch_norm.py index ec2276dd..d99fdd7c 100644 --- a/torch_uncertainty/post_processing/mc_batch_norm.py +++ b/torch_uncertainty/post_processing/mc_batch_norm.py @@ -62,7 +62,7 @@ def __init__( "model does not contain any MCBatchNorm2d after conversion." ) - def fit(self, dataset: Dataset): + def fit(self, dataset: Dataset) -> None: """Fit the model on the dataset. Args: diff --git a/torch_uncertainty/routines/__init__.py b/torch_uncertainty/routines/__init__.py index e69de29b..41b7ea80 100644 --- a/torch_uncertainty/routines/__init__.py +++ b/torch_uncertainty/routines/__init__.py @@ -0,0 +1,4 @@ +# ruff: noqa: F401 +from .classification import ClassificationRoutine +from .regression import RegressionRoutine +from .segmentation import SegmentationRoutine diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 21bda294..29019873 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -1,48 +1,49 @@ -from argparse import ArgumentParser from collections.abc import Callable -from functools import partial -from typing import Any +from pathlib import Path +from typing import Literal -import pytorch_lightning as pl import torch import torch.nn.functional as F from einops import rearrange -from pytorch_lightning.loggers import TensorBoardLogger -from pytorch_lightning.utilities.memory import get_model_size_mb -from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT +from lightning.pytorch import LightningModule +from lightning.pytorch.loggers import Logger +from lightning.pytorch.utilities.types import STEP_OUTPUT from timm.data import Mixup as timm_Mixup from torch import Tensor, nn +from torch.optim import Optimizer from torchmetrics import Accuracy, MetricCollection from torchmetrics.classification import ( BinaryAUROC, BinaryAveragePrecision, ) +from torch_uncertainty.layers import Identity from torch_uncertainty.losses import DECLoss, ELBOLoss from torch_uncertainty.metrics import ( CE, FPR95, BrierScore, + CategoricalNLL, Disagreement, Entropy, GroupingLoss, MutualInformation, - NegativeLogLikelihood, VariationRatio, ) -from torch_uncertainty.plotting_utils import plot_hist from torch_uncertainty.post_processing import TemperatureScaler from torch_uncertainty.transforms import Mixup, MixupIO, RegMixup, WarpingMixup +from torch_uncertainty.utils import csv_writer, plot_hist -class ClassificationSingle(pl.LightningModule): +class ClassificationRoutine(LightningModule): def __init__( self, - num_classes: int, model: nn.Module, - loss: type[nn.Module], - optimization_procedure: Any, + num_classes: int, + loss: nn.Module, + num_estimators: int = 1, format_batch_fn: nn.Module | None = None, + optim_recipe: dict | Optimizer | None = None, mixtype: str = "erm", mixmode: str = "elem", dist_sim: str = "emb", @@ -52,121 +53,112 @@ def __init__( cutmix_alpha: float = 0, eval_ood: bool = False, eval_grouping_loss: bool = False, - use_entropy: bool = False, - use_logits: bool = False, + ood_criterion: Literal[ + "msp", "logit", "energy", "entropy", "mi", "vr" + ] = "msp", log_plots: bool = False, - calibration_set: Callable | None = None, - **kwargs, + save_in_csv: bool = False, + calibration_set: Literal["val", "test"] | None = None, ) -> None: - """Classification routine for single models. + r"""Routine for efficient training and testing on **classification tasks** + using LightningModule. Args: + model (torch.nn.Module): Model to train. num_classes (int): Number of classes. - model (nn.Module): Model to train. - loss (type[nn.Module]): Loss function. - optimization_procedure (Any): Optimization procedure. - format_batch_fn (nn.Module, optional): Function to format the batch. + loss (torch.nn.Module): Loss function to optimize the :attr:`model`. + num_estimators (int, optional): Number of estimators for the + ensemble. Defaults to ``1`` (single model). + format_batch_fn (torch.nn.Module, optional): Function to format the batch. Defaults to :class:`torch.nn.Identity()`. + optim_recipe (dict or torch.optim.Optimizer, optional): The optimizer and + optionally the scheduler to use. Defaults to ``None``. mixtype (str, optional): Mixup type. Defaults to ``"erm"``. mixmode (str, optional): Mixup mode. Defaults to ``"elem"``. dist_sim (str, optional): Distance similarity. Defaults to ``"emb"``. kernel_tau_max (float, optional): Maximum value for the kernel tau. - Defaults to 1.0. + Defaults to ``1.0``. kernel_tau_std (float, optional): Standard deviation for the kernel tau. - Defaults to 0.5. - mixup_alpha (float, optional): Alpha parameter for Mixup. Defaults to 0. + Defaults to ``0.5``. + mixup_alpha (float, optional): Alpha parameter for Mixup. Defaults to ``0``. cutmix_alpha (float, optional): Alpha parameter for Cutmix. - Defaults to 0. + Defaults to ``0``. eval_ood (bool, optional): Indicates whether to evaluate the OOD detection performance or not. Defaults to ``False``. eval_grouping_loss (bool, optional): Indicates whether to evaluate the grouping loss or not. Defaults to ``False``. - use_entropy (bool, optional): Indicates whether to use the entropy - values as the OOD criterion or not. Defaults to ``False``. - use_logits (bool, optional): Indicates whether to use the logits as the - OOD criterion or not. Defaults to ``False``. + ood_criterion (str, optional): OOD criterion. Available options are + + - ``"msp"`` (default): Maximum softmax probability. + - ``"logit"``: Maximum logit. + - ``"energy"``: Logsumexp of the mean logits. + - ``"entropy"``: Entropy of the mean prediction. + - ``"mi"``: Mutual information of the ensemble. + - ``"vr"``: Variation ratio of the ensemble. + log_plots (bool, optional): Indicates whether to log plots from metrics. Defaults to ``False``. - calibration_set (Callable, optional): Function to get the calibration - set. Defaults to ``None``. - kwargs (Any): Additional arguments. - - Note: - The default OOD criterion is the softmax confidence score. + save_in_csv(bool, optional): Save the results in csv. Defaults to + ``False``. + calibration_set (str, optional): The calibration dataset to use for + scaling. If not ``None``, it uses either the validation set when + set to ``"val"`` or the test set when set to ``"test"``. + Defaults to ``None``. Warning: - Make sure at most only one of :attr:`use_entropy` and :attr:`use_logits` - attributes is set to ``True``. Otherwise a :class:`ValueError()` will - be raised. + You must define :attr:`optim_recipe` if you do not use the CLI. + + Note: + :attr:`optim_recipe` can be anything that can be returned by + :meth:`LightningModule.configure_optimizers()`. Find more details + `here `_. """ super().__init__() + _classification_routine_checks( + model=model, + num_classes=num_classes, + num_estimators=num_estimators, + ood_criterion=ood_criterion, + eval_grouping_loss=eval_grouping_loss, + ) if format_batch_fn is None: format_batch_fn = nn.Identity() - self.save_hyperparameters( - ignore=[ - "model", - "loss", - "optimization_procedure", - "format_batch_fn", - "calibration_set", - ] - ) - - if (use_logits + use_entropy) > 1: - raise ValueError("You cannot choose more than one OOD criterion.") - - if eval_grouping_loss and not hasattr(model, "feats_forward"): - raise ValueError( - "Your model must have a `feats_forward` method to compute the " - "grouping loss." - ) - - if eval_grouping_loss and not ( - hasattr(model, "classification_head") or hasattr(model, "linear") - ): - raise ValueError( - "Your model must have a `classification_head` or `linear` " - "attribute to compute the grouping loss." - ) - self.num_classes = num_classes + self.num_estimators = num_estimators self.eval_ood = eval_ood self.eval_grouping_loss = eval_grouping_loss - self.use_logits = use_logits - self.use_entropy = use_entropy + self.ood_criterion = ood_criterion self.log_plots = log_plots - + self.save_in_csv = save_in_csv self.calibration_set = calibration_set - self.binary_cls = num_classes == 1 self.model = model self.loss = loss - self.optimization_procedure = optimization_procedure - # batch format self.format_batch_fn = format_batch_fn + self.optim_recipe = optim_recipe # metrics if self.binary_cls: cls_metrics = MetricCollection( { - "acc": Accuracy(task="binary"), - "ece": CE(task="binary"), - "brier": BrierScore(num_classes=1), + "Acc": Accuracy(task="binary"), + "ECE": CE(task="binary"), + "Brier": BrierScore(num_classes=1), }, compute_groups=False, ) else: cls_metrics = MetricCollection( { - "nll": NegativeLogLikelihood(), - "acc": Accuracy( + "NLL": CategoricalNLL(), + "Acc": Accuracy( task="multiclass", num_classes=self.num_classes ), - "ece": CE(task="multiclass", num_classes=self.num_classes), - "brier": BrierScore(num_classes=self.num_classes), + "ECE": CE(task="multiclass", num_classes=self.num_classes), + "Brier": BrierScore(num_classes=self.num_classes), }, compute_groups=False, ) @@ -175,263 +167,67 @@ def __init__( self.test_cls_metrics = cls_metrics.clone(prefix="cls_test/") if self.calibration_set is not None: - self.ts_cls_metrics = cls_metrics.clone(prefix="ts_") + self.ts_cls_metrics = cls_metrics.clone(prefix="cls_test/ts_") self.test_entropy_id = Entropy() if self.eval_ood: ood_metrics = MetricCollection( { - "fpr95": FPR95(pos_label=1), - "auroc": BinaryAUROC(), - "aupr": BinaryAveragePrecision(), + "FPR95": FPR95(pos_label=1), + "AUROC": BinaryAUROC(), + "AUPR": BinaryAveragePrecision(), }, - compute_groups=[["auroc", "aupr"], ["fpr95"]], + compute_groups=[["AUROC", "AUPR"], ["FPR95"]], ) self.test_ood_metrics = ood_metrics.clone(prefix="ood/") self.test_entropy_ood = Entropy() - if self.eval_grouping_loss: - grouping_loss = MetricCollection({"grouping_loss": GroupingLoss()}) - self.val_grouping_loss = grouping_loss.clone(prefix="gpl/val_") - self.test_grouping_loss = grouping_loss.clone(prefix="gpl/test_") - - if mixup_alpha < 0 or cutmix_alpha < 0: - raise ValueError( - "Cutmix alpha and Mixup alpha must be positive." - f"Got {mixup_alpha} and {cutmix_alpha}." - ) - self.mixtype = mixtype self.mixmode = mixmode self.dist_sim = dist_sim + if num_estimators == 1: + if mixup_alpha < 0 or cutmix_alpha < 0: + raise ValueError( + "Cutmix alpha and Mixup alpha must be positive." + f"Got {mixup_alpha} and {cutmix_alpha}." + ) - self.mixup = self.init_mixup( - mixup_alpha, cutmix_alpha, kernel_tau_max, kernel_tau_std - ) - - # Handle ELBO special cases - self.is_elbo = ( - isinstance(self.loss, partial) and self.loss.func == ELBOLoss - ) - - # Deep Evidential Classification - self.is_dec = self.loss == DECLoss or ( - isinstance(self.loss, partial) and self.loss.func == DECLoss - ) - - def configure_optimizers(self) -> Any: - return self.optimization_procedure(self) - - @property - def criterion(self) -> nn.Module: - if self.is_elbo: - self.loss = partial(self.loss, model=self.model) - return self.loss() - - def forward(self, inputs: Tensor, return_features: bool = False) -> Tensor: - """Forward pass of the model. - - Args: - inputs (Tensor): Input tensor. - return_features (bool, optional): Whether to store the features or - not. Defaults to ``False``. - - Note: - The features are stored in the :attr:`features` attribute. - """ - if return_features: - self.features = self.model.feats_forward(inputs) - if hasattr(self.model, "classification_head"): # coverage: ignore - logits = self.model.classification_head(self.features) - else: - logits = self.model.linear(self.features) - else: - self.features = None - logits = self.model(inputs) - return logits - - def on_train_start(self) -> None: - # hyperparameters for performances - param = {} - param["storage"] = f"{get_model_size_mb(self)} MB" - - def training_step( - self, batch: tuple[Tensor, Tensor], batch_idx: int - ) -> STEP_OUTPUT: - if self.mixtype == "kernel_warping": - if self.dist_sim == "emb": - with torch.no_grad(): - feats = self.model.feats_forward(batch[0]).detach() - - batch = self.mixup(*batch, feats) - elif self.dist_sim == "inp": - batch = self.mixup(*batch, batch[0]) - else: - batch = self.mixup(*batch) - - inputs, targets = self.format_batch_fn(batch) - - if self.is_elbo: - loss = self.criterion(inputs, targets) - else: - logits = self.forward(inputs) - # BCEWithLogitsLoss expects float targets - if self.binary_cls and self.loss == nn.BCEWithLogitsLoss: - logits = logits.squeeze(-1) - targets = targets.float() - - if not self.is_dec: - loss = self.criterion(logits, targets) - else: - loss = self.criterion(logits, targets, self.current_epoch) - self.log("train_loss", loss) - return loss - - def validation_step( - self, batch: tuple[Tensor, Tensor], batch_idx: int - ) -> None: - inputs, targets = batch - - logits = self.forward(inputs, return_features=self.eval_grouping_loss) - - if self.binary_cls: - probs = torch.sigmoid(logits).squeeze(-1) - else: - probs = F.softmax(logits, dim=-1) - - self.val_cls_metrics.update(probs, targets) - - if self.eval_grouping_loss: - self.val_grouping_loss.update(probs, targets, self.features) - - def validation_epoch_end( - self, outputs: EPOCH_OUTPUT | list[EPOCH_OUTPUT] - ) -> None: - self.log_dict(self.val_cls_metrics.compute()) - self.val_cls_metrics.reset() - - if self.eval_grouping_loss: - self.log_dict(self.val_grouping_loss.compute()) - self.val_grouping_loss.reset() - - def on_test_start(self) -> None: - if self.calibration_set is not None: - self.cal_model = TemperatureScaler( - model=self.model, device=self.device - ).fit(calibration_set=self.calibration_set()) - else: - self.cal_model = None - - def test_step( - self, - batch: tuple[Tensor, Tensor], - batch_idx: int, - dataloader_idx: int | None = 0, - ) -> Tensor: - inputs, targets = batch - logits = self.forward(inputs, return_features=self.eval_grouping_loss) - - if self.binary_cls: - probs = torch.sigmoid(logits).squeeze(-1) - else: - probs = F.softmax(logits, dim=-1) - - # self.cal_plot.update(probs, targets) - confs = probs.max(dim=-1)[0] - - if self.use_logits: - ood_scores = -logits.max(dim=-1)[0] - elif self.use_entropy: - ood_scores = torch.special.entr(probs).sum(dim=-1) - else: - ood_scores = -confs - - if self.calibration_set is not None and self.cal_model is not None: - cal_logits = self.cal_model(inputs) - cal_probs = F.softmax(cal_logits, dim=-1) - self.ts_cls_metrics.update(cal_probs, targets) + self.mixup = self.init_mixup( + mixup_alpha, cutmix_alpha, kernel_tau_max, kernel_tau_std + ) - if dataloader_idx == 0: - self.test_cls_metrics.update(probs, targets) if self.eval_grouping_loss: - self.test_grouping_loss.update(probs, targets, self.features) - self.test_entropy_id(probs) - self.log( - "cls_test/entropy", - self.test_entropy_id, - on_epoch=True, - add_dataloader_idx=False, - ) - if self.eval_ood: - self.test_ood_metrics.update( - ood_scores, torch.zeros_like(targets) + grouping_loss = MetricCollection( + {"grouping_loss": GroupingLoss()} + ) + self.val_grouping_loss = grouping_loss.clone(prefix="gpl/val_") + self.test_grouping_loss = grouping_loss.clone( + prefix="gpl/test_" ) - elif self.eval_ood and dataloader_idx == 1: - self.test_ood_metrics.update(ood_scores, torch.ones_like(targets)) - self.test_entropy_ood(probs) - self.log( - "ood/entropy", - self.test_entropy_ood, - on_epoch=True, - add_dataloader_idx=False, - ) - return logits - - def test_epoch_end( - self, outputs: EPOCH_OUTPUT | list[EPOCH_OUTPUT] - ) -> None: - self.log_dict( - self.test_cls_metrics.compute(), - ) - if self.eval_grouping_loss: - self.log_dict( - self.test_grouping_loss.compute(), - ) - if self.calibration_set is not None and self.cal_model is not None: - self.log_dict(self.ts_cls_metrics.compute()) - self.ts_cls_metrics.reset() + self.is_elbo = isinstance(self.loss, ELBOLoss) + if self.is_elbo: + self.loss.set_model(self.model) + self.is_dec = isinstance(self.loss, DECLoss) - if self.eval_ood: - self.log_dict( - self.test_ood_metrics.compute(), + # metrics for ensembles only + if self.num_estimators > 1: + ens_metrics = MetricCollection( + { + "Disagreement": Disagreement(), + "MI": MutualInformation(), + "Entropy": Entropy(), + } ) - self.test_ood_metrics.reset() - if isinstance(self.logger, TensorBoardLogger) and self.log_plots: - self.logger.experiment.add_figure( - "Calibration Plot", self.test_cls_metrics["ece"].plot()[0] - ) + self.test_id_ens_metrics = ens_metrics.clone(prefix="cls_test/ens_") if self.eval_ood: - id_logits = torch.cat(outputs[0], 0).float().cpu() - ood_logits = torch.cat(outputs[1], 0).float().cpu() + self.test_ood_ens_metrics = ens_metrics.clone(prefix="ood/ens_") - id_probs = F.softmax(id_logits, dim=-1) - ood_probs = F.softmax(ood_logits, dim=-1) - - logits_fig = plot_hist( - [id_logits.max(-1).values, ood_logits.max(-1).values], - 20, - "Histogram of the logits", - )[0] - probs_fig = plot_hist( - [id_probs.max(-1).values, ood_probs.max(-1).values], - 20, - "Histogram of the likelihoods", - )[0] - self.logger.experiment.add_figure("Logit Histogram", logits_fig) - self.logger.experiment.add_figure( - "Likelihood Histogram", probs_fig - ) - - self.test_cls_metrics.reset() - if self.eval_grouping_loss: - self.test_grouping_loss.reset() - - def identity(self, x: float, y: float): - return x, y + self.id_logit_storage = None + self.ood_logit_storage = None def init_mixup( self, @@ -474,217 +270,95 @@ def init_mixup( tau_max=kernel_tau_max, tau_std=kernel_tau_std, ) - return self.identity + return Identity() - @staticmethod - def add_model_specific_args( - parent_parser: ArgumentParser, - ) -> ArgumentParser: - """Defines the routine's attributes via command-line options. + def configure_optimizers(self) -> Optimizer | dict: + return self.optim_recipe - Args: - parent_parser (ArgumentParser): Parent parser to be completed. - - Adds: - - ``--entropy``: sets :attr:`use_entropy` to ``True``. - - ``--logits``: sets :attr:`use_logits` to ``True``. - - ``--eval-grouping-loss``: sets :attr:`eval_grouping-loss` to - ``True``. - - ``--mixup_alpha``: sets :attr:`mixup_alpha` for Mixup - - ``--cutmix_alpha``: sets :attr:`cutmix_alpha` for Cutmix - - ``--mixtype``: sets :attr:`mixtype` for Mixup - - ``--mixmode``: sets :attr:`mixmode` for Mixup - - ``--dist_sim``: sets :attr:`dist_sim` for Mixup - - ``--kernel_tau_max``: sets :attr:`kernel_tau_max` for Mixup - - ``--kernel_tau_std``: sets :attr:`kernel_tau_std` for Mixup - """ - parent_parser.add_argument( - "--entropy", dest="use_entropy", action="store_true" - ) - parent_parser.add_argument( - "--logits", dest="use_logits", action="store_true" - ) - parent_parser.add_argument( - "--eval-grouping-loss", - action="store_true", - help="Whether to evaluate the grouping loss or not", - ) + def on_train_start(self) -> None: + init_metrics = dict.fromkeys(self.val_cls_metrics, 0) + init_metrics.update(dict.fromkeys(self.test_cls_metrics, 0)) - # Mixup args - parent_parser.add_argument( - "--mixup_alpha", dest="mixup_alpha", type=float, default=0 - ) - parent_parser.add_argument( - "--cutmix_alpha", dest="cutmix_alpha", type=float, default=0 - ) - parent_parser.add_argument( - "--mixtype", dest="mixtype", type=str, default="erm" - ) - parent_parser.add_argument( - "--mixmode", dest="mixmode", type=str, default="elem" - ) - parent_parser.add_argument( - "--dist_sim", dest="dist_sim", type=str, default="emb" - ) - parent_parser.add_argument( - "--kernel_tau_max", dest="kernel_tau_max", type=float, default=1.0 - ) - parent_parser.add_argument( - "--kernel_tau_std", dest="kernel_tau_std", type=float, default=0.5 - ) - return parent_parser + if self.logger is not None: # coverage: ignore + self.logger.log_hyperparams( + self.hparams, + init_metrics, + ) + def on_test_start(self) -> None: + if isinstance(self.calibration_set, str) and self.calibration_set in [ + "val", + "test", + ]: + dataset = ( + self.trainer.datamodule.val_dataloader().dataset + if self.calibration_set == "val" + else self.trainer.datamodule.test_dataloader().dataset + ) + with torch.inference_mode(False): + self.cal_model = TemperatureScaler( + model=self.model, device=self.device + ).fit(calibration_set=dataset) + else: + self.cal_model = None -class ClassificationEnsemble(ClassificationSingle): - def __init__( - self, - num_classes: int, - model: nn.Module, - loss: type[nn.Module], - optimization_procedure: Any, - num_estimators: int, - format_batch_fn: nn.Module | None = None, - mixtype: str = "erm", - mixmode: str = "elem", - dist_sim: str = "emb", - kernel_tau_max: float = 1.0, - kernel_tau_std: float = 0.5, - mixup_alpha: float = 0, - cutmix_alpha: float = 0, - eval_ood: bool = False, - eval_grouping_loss: bool = False, - use_entropy: bool = False, - use_logits: bool = False, - use_mi: bool = False, - use_variation_ratio: bool = False, - log_plots: bool = False, - **kwargs, - ) -> None: - """Classification routine for ensemble models. + if self.eval_ood and self.log_plots and isinstance(self.logger, Logger): + self.id_logit_storage = [] + self.ood_logit_storage = [] + + def forward(self, inputs: Tensor, save_feats: bool = False) -> Tensor: + """Forward pass of the model. Args: - num_classes (int): Number of classes. - model (nn.Module): Model to train. - loss (type[nn.Module]): Loss function. - optimization_procedure (Any): Optimization procedure. - num_estimators (int): Number of estimators in the ensemble. - format_batch_fn (nn.Module, optional): Function to format the batch. - Defaults to :class:`torch.nn.Identity()`. - mixtype (str, optional): Mixup type. Defaults to ``"erm"``. - mixmode (str, optional): Mixup mode. Defaults to ``"elem"``. - dist_sim (str, optional): Distance similarity. Defaults to ``"emb"``. - kernel_tau_max (float, optional): Maximum value for the kernel tau. - Defaults to 1.0. - kernel_tau_std (float, optional): Standard deviation for the kernel tau. - Defaults to 0.5. - mixup_alpha (float, optional): Alpha parameter for Mixup. Defaults to 0. - cutmix_alpha (float, optional): Alpha parameter for Cutmix. - Defaults to 0. - eval_ood (bool, optional): Indicates whether to evaluate the OOD - detection performance or not. Defaults to ``False``. - eval_grouping_loss (bool, optional): Indicates whether to evaluate the - grouping loss or not. Defaults to ``False``. - use_entropy (bool, optional): Indicates whether to use the entropy - values as the OOD criterion or not. Defaults to ``False``. - use_logits (bool, optional): Indicates whether to use the logits as the - OOD criterion or not. Defaults to ``False``. - use_mi (bool, optional): Indicates whether to use the mutual - information as the OOD criterion or not. Defaults to ``False``. - use_variation_ratio (bool, optional): Indicates whether to use the - variation ratio as the OOD criterion or not. Defaults to ``False``. - log_plots (bool, optional): Indicates whether to log plots from - metrics. Defaults to ``False``. - calibration_set (Callable, optional): Function to get the calibration - set. Defaults to ``None``. - kwargs (Any): Additional arguments. + inputs (Tensor): Input tensor. + save_feats (bool, optional): Whether to store the features or + not. Defaults to ``False``. Note: - The default OOD criterion is the averaged softmax confidence score. - - Warning: - Make sure at most only one of :attr:`use_entropy`, :attr:`use_logits` - , :attr:`use_mi`, and :attr:`use_variation_ratio` attributes is set to - ``True``. Otherwise a :class:`ValueError()` will be raised. + The features are stored in the :attr:`self.features` attribute. """ - super().__init__( - num_classes=num_classes, - model=model, - loss=loss, - optimization_procedure=optimization_procedure, - format_batch_fn=format_batch_fn, - mixtype=mixtype, - mixmode=mixmode, - dist_sim=dist_sim, - kernel_tau_max=kernel_tau_max, - kernel_tau_std=kernel_tau_std, - mixup_alpha=mixup_alpha, - cutmix_alpha=cutmix_alpha, - eval_ood=eval_ood, - eval_grouping_loss=eval_grouping_loss, - use_entropy=use_entropy, - use_logits=use_logits, - **kwargs, - ) - - self.num_estimators = num_estimators - - self.use_mi = use_mi - self.use_variation_ratio = use_variation_ratio - self.log_plots = log_plots - - if ( - self.use_logits - + self.use_entropy - + self.use_mi - + self.use_variation_ratio - ) > 1: - raise ValueError("You cannot choose more than one OOD criterion.") - - # metrics for ensembles only - ens_metrics = MetricCollection( - { - "disagreement": Disagreement(), - "mi": MutualInformation(), - "entropy": Entropy(), - } - ) - self.test_id_ens_metrics = ens_metrics.clone(prefix="ens_id/test_") - - if self.eval_ood: - self.test_ood_ens_metrics = ens_metrics.clone( - prefix="ens_ood/test_" - ) - - if self.eval_grouping_loss: - raise NotImplementedError( - "Grouping loss not implemented for ensembles. Raise an issue" - " if you need it." - ) - - def on_train_start(self) -> None: - param = {} - param["storage"] = f"{get_model_size_mb(self)} MB" + if save_feats: + self.features = self.model.feats_forward(inputs) + if hasattr(self.model, "classification_head"): # coverage: ignore + logits = self.model.classification_head(self.features) + else: + logits = self.model.linear(self.features) + else: + self.features = None + logits = self.model(inputs) + return logits def training_step( self, batch: tuple[Tensor, Tensor], batch_idx: int ) -> STEP_OUTPUT: - batch = self.mixup(*batch) - # eventual input repeat is done in the model + # Mixup only for single models + if self.num_estimators == 1: + if self.mixtype == "kernel_warping": + if self.dist_sim == "emb": + with torch.no_grad(): + feats = self.model.feats_forward(batch[0]).detach() + + batch = self.mixup(*batch, feats) + elif self.dist_sim == "inp": + batch = self.mixup(*batch, batch[0]) + else: + batch = self.mixup(*batch) + inputs, targets = self.format_batch_fn(batch) if self.is_elbo: - loss = self.criterion(inputs, targets) + loss = self.loss(inputs, targets) else: logits = self.forward(inputs) # BCEWithLogitsLoss expects float targets - if self.binary_cls and self.loss == nn.BCEWithLogitsLoss: + if self.binary_cls and isinstance(self.loss, nn.BCEWithLogitsLoss): logits = logits.squeeze(-1) targets = targets.float() if not self.is_dec: - loss = self.criterion(logits, targets) + loss = self.loss(logits, targets) else: - loss = self.criterion(logits, targets, self.current_epoch) + loss = self.loss(logits, targets, self.current_epoch) self.log("train_loss", loss) return loss @@ -693,8 +367,11 @@ def validation_step( self, batch: tuple[Tensor, Tensor], batch_idx: int ) -> None: inputs, targets = batch - logits = self.forward(inputs) + logits = self.forward( + inputs, save_feats=self.eval_grouping_loss + ) # (m*b, c) logits = rearrange(logits, "(m b) c -> b m c", m=self.num_estimators) + if self.binary_cls: probs_per_est = torch.sigmoid(logits).squeeze(-1) else: @@ -703,14 +380,19 @@ def validation_step( probs = probs_per_est.mean(dim=1) self.val_cls_metrics.update(probs, targets) + if self.eval_grouping_loss: + self.val_grouping_loss.update(probs, targets, self.features) + def test_step( self, batch: tuple[Tensor, Tensor], batch_idx: int, - dataloader_idx: int | None = 0, - ) -> Tensor: + dataloader_idx: int = 0, + ) -> None: inputs, targets = batch - logits = self.forward(inputs) + logits = self.forward( + inputs, save_feats=self.eval_grouping_loss + ) # (m*b, c) if logits.size(0) % self.num_estimators != 0: # coverage: ignore raise ValueError( f"The number of predicted samples {logits.size(0)} is not " @@ -726,33 +408,49 @@ def test_step( probs_per_est = F.softmax(logits, dim=-1) probs = probs_per_est.mean(dim=1) - # self.cal_plot.update(probs, targets) + confs = probs.max(-1)[0] - if self.use_logits: - ood_scores = -logits.mean(dim=1).max(dim=-1)[0] - elif self.use_entropy: + if self.ood_criterion == "logit": + ood_scores = -logits.mean(dim=1).max(dim=-1).values + elif self.ood_criterion == "energy": + ood_scores = -logits.mean(dim=1).logsumexp(dim=-1) + elif self.ood_criterion == "entropy": ood_scores = ( torch.special.entr(probs_per_est).sum(dim=-1).mean(dim=1) ) - elif self.use_mi: + elif self.ood_criterion == "mi": mi_metric = MutualInformation(reduction="none") ood_scores = mi_metric(probs_per_est) - elif self.use_variation_ratio: + elif self.ood_criterion == "vr": vr_metric = VariationRatio(reduction="none", probabilistic=False) ood_scores = vr_metric(probs_per_est.transpose(0, 1)) else: ood_scores = -confs + # Scaling for single models + if ( + self.num_estimators == 1 + and self.calibration_set is not None + and self.cal_model is not None + ): + cal_logits = self.cal_model(inputs) + cal_probs = F.softmax(cal_logits, dim=-1) + self.ts_cls_metrics.update(cal_probs, targets) + if dataloader_idx == 0: # squeeze if binary classification only for binary metrics self.test_cls_metrics.update( probs.squeeze(-1) if self.binary_cls else probs, targets, ) - self.test_entropy_id(probs) + if self.eval_grouping_loss: + self.test_grouping_loss.update(probs, targets, self.features) - self.test_id_ens_metrics.update(probs_per_est) + self.log_dict( + self.test_cls_metrics, on_epoch=True, add_dataloader_idx=False + ) + self.test_entropy_id(probs) self.log( "cls_test/entropy", self.test_entropy_id, @@ -760,52 +458,92 @@ def test_step( add_dataloader_idx=False, ) + if self.num_estimators > 1: + self.test_id_ens_metrics.update(probs_per_est) + if self.eval_ood: self.test_ood_metrics.update( ood_scores, torch.zeros_like(targets) ) + + if self.id_logit_storage is not None: + self.id_logit_storage.append(logits.detach().cpu()) + elif self.eval_ood and dataloader_idx == 1: self.test_ood_metrics.update(ood_scores, torch.ones_like(targets)) self.test_entropy_ood(probs) - self.test_ood_ens_metrics.update(probs_per_est) self.log( "ood/entropy", self.test_entropy_ood, on_epoch=True, add_dataloader_idx=False, ) - return logits + if self.num_estimators > 1: + self.test_ood_ens_metrics.update(probs_per_est) - def test_epoch_end( - self, outputs: EPOCH_OUTPUT | list[EPOCH_OUTPUT] - ) -> None: - self.log_dict( - self.test_cls_metrics.compute(), - ) + if self.ood_logit_storage is not None: + self.ood_logit_storage.append(logits.detach().cpu()) - self.log_dict( - self.test_id_ens_metrics.compute(), - ) + def on_validation_epoch_end(self) -> None: + self.log_dict(self.val_cls_metrics.compute()) + self.val_cls_metrics.reset() - if self.eval_ood: - self.log_dict( - self.test_ood_metrics.compute(), - ) + if self.eval_grouping_loss: + self.log_dict(self.val_grouping_loss.compute()) + self.val_grouping_loss.reset() + + def on_test_epoch_end(self) -> None: + # already logged + result_dict = self.test_cls_metrics.compute() + + # already logged + result_dict.update({"cls_test/entropy": self.test_entropy_id.compute()}) + + if ( + self.num_estimators == 1 + and self.calibration_set is not None + and self.cal_model is not None + ): + tmp_metrics = self.ts_cls_metrics.compute() + self.log_dict(tmp_metrics) + result_dict.update(tmp_metrics) + self.ts_cls_metrics.reset() + + if self.eval_grouping_loss: self.log_dict( - self.test_ood_ens_metrics.compute(), + self.test_grouping_loss.compute(), ) + if self.num_estimators > 1: + tmp_metrics = self.test_id_ens_metrics.compute() + self.log_dict(tmp_metrics) + result_dict.update(tmp_metrics) + self.test_id_ens_metrics.reset() + + if self.eval_ood: + tmp_metrics = self.test_ood_metrics.compute() + self.log_dict(tmp_metrics) + result_dict.update(tmp_metrics) self.test_ood_metrics.reset() - self.test_ood_ens_metrics.reset() - if isinstance(self.logger, TensorBoardLogger) and self.log_plots: + # already logged + result_dict.update({"ood/entropy": self.test_entropy_ood.compute()}) + + if self.num_estimators > 1: + tmp_metrics = self.test_ood_ens_metrics.compute() + self.log_dict(tmp_metrics) + result_dict.update(tmp_metrics) + self.test_ood_ens_metrics.reset() + + if isinstance(self.logger, Logger) and self.log_plots: self.logger.experiment.add_figure( - "Calibration Plot", self.test_cls_metrics["ece"].plot()[0] + "Calibration Plot", self.test_cls_metrics["ECE"].plot()[0] ) + # plot histograms of logits and likelihoods if self.eval_ood: - id_logits = torch.cat(outputs[0], 0).float().cpu() - ood_logits = torch.cat(outputs[1], 0).float().cpu() + id_logits = torch.cat(self.id_logit_storage, dim=0) + ood_logits = torch.cat(self.ood_logit_storage, dim=0) id_probs = F.softmax(id_logits, dim=-1) ood_probs = F.softmax(ood_logits, dim=-1) @@ -831,45 +569,70 @@ def test_epoch_end( "Likelihood Histogram", probs_fig ) - self.test_cls_metrics.reset() - self.test_id_ens_metrics.reset() - - @staticmethod - def add_model_specific_args( - parent_parser: ArgumentParser, - ) -> ArgumentParser: - """Defines the routine's attributes via command-line options. - - Adds: - - ``--entropy``: sets :attr:`use_entropy` to ``True``. - - ``--logits``: sets :attr:`use_logits` to ``True``. - - ``--mutual_information``: sets :attr:`use_mi` to ``True``. - - ``--variation_ratio``: sets :attr:`use_variation_ratio` to ``True``. - - ``--num_estimators``: sets :attr:`num_estimators`. - """ - parent_parser = ClassificationSingle.add_model_specific_args( - parent_parser + if self.save_in_csv: + self.save_results_to_csv(result_dict) + + def save_results_to_csv(self, results: dict[str, float]) -> None: + if self.logger is not None: + csv_writer( + Path(self.logger.log_dir) / "results.csv", + results, + ) + + +def _classification_routine_checks( + model: nn.Module, + num_classes: int, + num_estimators: int, + ood_criterion: str, + eval_grouping_loss: bool, +) -> None: + if not isinstance(num_estimators, int) or num_estimators < 1: + raise ValueError( + "The number of estimators must be a positive integer >= 1." + f"Got {num_estimators}." + ) + + if ood_criterion not in [ + "msp", + "logit", + "energy", + "entropy", + "mi", + "vr", + ]: + raise ValueError( + "The OOD criterion must be one of 'msp', 'logit', 'energy', 'entropy'," + f" 'mi' or 'vr'. Got {ood_criterion}." + ) + + if num_estimators == 1 and ood_criterion in ["mi", "vr"]: + raise ValueError( + "You cannot use mutual information or variation ratio with a single" + " model." ) - # FIXME: should be a str to choose among the available OOD criteria - # rather than a boolean, but it is not possible since - # ClassificationSingle and ClassificationEnsemble have different OOD - # criteria. - parent_parser.add_argument( - "--mutual_information", - dest="use_mi", - action="store_true", - default=False, + + if num_estimators != 1 and eval_grouping_loss: + raise NotImplementedError( + "Groupng loss for ensembles is not yet implemented. Raise an issue if needed." ) - parent_parser.add_argument( - "--variation_ratio", - dest="use_variation_ratio", - action="store_true", - default=False, + + if num_classes < 1: + raise ValueError( + "The number of classes must be a positive integer >= 1." + f"Got {num_classes}." ) - parent_parser.add_argument( - "--num_estimators", - type=int, - default=None, - help="Number of estimators for ensemble", + + if eval_grouping_loss and not hasattr(model, "feats_forward"): + raise ValueError( + "Your model must have a `feats_forward` method to compute the " + "grouping loss." + ) + + if eval_grouping_loss and not ( + hasattr(model, "classification_head") or hasattr(model, "linear") + ): + raise ValueError( + "Your model must have a `classification_head` or `linear` " + "attribute to compute the grouping loss." ) - return parent_parser diff --git a/torch_uncertainty/routines/regression.py b/torch_uncertainty/routines/regression.py index 16522831..3124856d 100644 --- a/torch_uncertainty/routines/regression.py +++ b/torch_uncertainty/routines/regression.py @@ -1,268 +1,192 @@ -from argparse import ArgumentParser -from typing import Any, Literal - -import pytorch_lightning as pl import torch -import torch.nn.functional as F from einops import rearrange -from pytorch_lightning.utilities.memory import get_model_size_mb -from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT -from torch import nn -from torchmetrics import MeanSquaredError, MetricCollection - -from torch_uncertainty.metrics.nll import GaussianNegativeLogLikelihood - - -class RegressionSingle(pl.LightningModule): +from lightning.pytorch import LightningModule +from lightning.pytorch.utilities.types import STEP_OUTPUT +from torch import Tensor, nn +from torch.distributions import ( + Categorical, + Independent, + MixtureSameFamily, +) +from torch.optim import Optimizer +from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection + +from torch_uncertainty.metrics.regression.nll import DistributionNLL +from torch_uncertainty.utils.distributions import dist_rearrange, squeeze_dist + + +class RegressionRoutine(LightningModule): def __init__( self, model: nn.Module, - loss: type[nn.Module], - optimization_procedure: Any, - dist_estimation: int, - **kwargs, + output_dim: int, + probabilistic: bool, + loss: nn.Module, + num_estimators: int = 1, + optim_recipe: dict | Optimizer | None = None, + format_batch_fn: nn.Module | None = None, ) -> None: + r"""Routine for efficient training and testing on **regression tasks** + using LightningModule. + + Args: + model (torch.nn.Module): Model to train. + output_dim (int): Number of outputs of the model. + probabilistic (bool): Whether the model is probabilistic, i.e., + outputs a PyTorch distribution. + loss (torch.nn.Module): Loss function to optimize the :attr:`model`. + num_estimators (int, optional): The number of estimators for the + ensemble. Defaults to ``1`` (single model). + optim_recipe (dict or torch.optim.Optimizer, optional): The optimizer and + optionally the scheduler to use. Defaults to ``None``. + format_batch_fn (torch.nn.Module, optional): The function to format the + batch. Defaults to ``None``. + + Warning: + If :attr:`probabilistic` is True, the model must output a `PyTorch + distribution `_. + + Warning: + You must define :attr:`optim_recipe` if you do not use + the CLI. + + Note: + :attr:`optim_recipe` can be anything that can be returned by + :meth:`LightningModule.configure_optimizers()`. Find more details + `here `_. + """ super().__init__() - - self.save_hyperparameters( - ignore=[ - "model", - "loss", - "optimization_procedure", - ] - ) + _regression_routine_checks(num_estimators, output_dim) self.model = model + self.probabilistic = probabilistic + self.output_dim = output_dim self.loss = loss - self.optimization_procedure = optimization_procedure - - # metrics - if isinstance(dist_estimation, int): - if dist_estimation <= 0: - raise ValueError( - "Expected the argument ``dist_estimation`` to be integer " - f" larger than 0, but got {dist_estimation}." - ) - else: - raise TypeError( - "Expected the argument ``dist_estimation`` to be integer, but " - f"got {type(dist_estimation)}" - ) + self.num_estimators = num_estimators - out_features = list(self.model.parameters())[-1].size(0) - if dist_estimation > out_features: - raise ValueError( - "Expected argument ``dist_estimation`` to be an int lower or " - f"equal than the size of the output layer, but got " - f"{dist_estimation} and {out_features}." - ) + if format_batch_fn is None: + format_batch_fn = nn.Identity() - self.dist_estimation = dist_estimation + self.optim_recipe = optim_recipe + self.format_batch_fn = format_batch_fn - if dist_estimation in (4, 2): - reg_metrics = MetricCollection( - { - "mse": MeanSquaredError(squared=True), - "gnll": GaussianNegativeLogLikelihood(), - }, - compute_groups=False, - ) - else: - reg_metrics = MetricCollection( - { - "mse": MeanSquaredError(squared=True), - }, - compute_groups=False, - ) + reg_metrics = MetricCollection( + { + "MAE": MeanAbsoluteError(), + "MSE": MeanSquaredError(squared=True), + "RMSE": MeanSquaredError(squared=False), + }, + compute_groups=True, + ) self.val_metrics = reg_metrics.clone(prefix="reg_val/") self.test_metrics = reg_metrics.clone(prefix="reg_test/") - def configure_optimizers(self) -> Any: - return self.optimization_procedure(self) + if self.probabilistic: + reg_prob_metrics = MetricCollection( + {"NLL": DistributionNLL(reduction="mean")} + ) + self.val_prob_metrics = reg_prob_metrics.clone(prefix="reg_val/") + self.test_prob_metrics = reg_prob_metrics.clone(prefix="reg_test/") - @property - def criterion(self) -> nn.Module: - return self.loss() + self.one_dim_regression = output_dim == 1 - def forward(self, inputs: torch.Tensor) -> torch.Tensor: - return self.model.forward(inputs) + def configure_optimizers(self) -> Optimizer | dict: + return self.optim_recipe def on_train_start(self) -> None: - # hyperparameters for performances - param = {} - param["storage"] = f"{get_model_size_mb(self)} MB" - - def training_step( - self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int - ) -> STEP_OUTPUT: - inputs, targets = batch - logits = self.forward(inputs) - - if self.dist_estimation == 4: - means, v, alpha, beta = logits.split(1, dim=-1) - v = F.softplus(v) - alpha = 1 + F.softplus(alpha) - beta = F.softplus(beta) - loss = self.criterion(means, v, alpha, beta, targets) - elif self.dist_estimation == 2: - means = logits[..., 0] - variances = F.softplus(logits[..., 1]) - loss = self.criterion(means, targets, variances) - else: - loss = self.criterion(logits, targets) - - self.log("train_loss", loss) - return loss + init_metrics = dict.fromkeys(self.val_metrics, 0) + init_metrics.update(dict.fromkeys(self.test_metrics, 0)) + if self.probabilistic: + init_metrics.update(dict.fromkeys(self.val_prob_metrics, 0)) + init_metrics.update(dict.fromkeys(self.test_prob_metrics, 0)) + + if self.logger is not None: # coverage: ignore + self.logger.log_hyperparams( + self.hparams, + init_metrics, + ) - def validation_step( - self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int - ) -> None: - inputs, targets = batch - logits = self.forward(inputs) - if self.dist_estimation == 4: - means = logits[..., 0] - alpha = 1 + F.softplus(logits[..., 2]) - beta = F.softplus(logits[..., 3]) - variances = beta / (alpha - 1) - self.val_metrics.gnll.update(means, targets, variances) - - targets = targets.view(means.size()) - elif self.dist_estimation == 2: - means = logits[..., 0] - variances = F.softplus(logits[..., 1]) - self.val_metrics.gnll.update(means, targets, variances) - - if means.ndim == 1: - means = means.unsqueeze(-1) - else: - means = logits.squeeze(-1) + def forward(self, inputs: Tensor) -> Tensor: + """Forward pass of the routine. - self.val_metrics.mse.update(means, targets) + The forward pass automatically squeezes the output if the regression + is one-dimensional and if the routine contains a single model. - def validation_epoch_end( - self, outputs: EPOCH_OUTPUT | list[EPOCH_OUTPUT] - ) -> None: - self.log_dict(self.val_metrics.compute()) - self.val_metrics.reset() + Args: + inputs (Tensor): The input tensor. - def test_step( - self, - batch: tuple[torch.Tensor, torch.Tensor], - batch_idx: int, - ) -> None: - inputs, targets = batch - logits = self.forward(inputs) - - if self.dist_estimation == 4: - means = logits[..., 0] - alpha = 1 + F.softplus(logits[..., 2]) - beta = F.softplus(logits[..., 3]) - variances = beta / (alpha - 1) - self.test_metrics.gnll.update(means, targets, variances) - - targets = targets.view(means.size()) - elif self.dist_estimation == 2: - means = logits[..., 0] - variances = F.softplus(logits[..., 1]) - self.test_metrics.gnll.update(means, targets, variances) - - if means.ndim == 1: - means = means.unsqueeze(-1) + Returns: + Tensor: The output tensor. + """ + pred = self.model(inputs) + if self.probabilistic: + if self.one_dim_regression: + pred = squeeze_dist(pred, -1) + if self.num_estimators == 1: + pred = squeeze_dist(pred, -1) else: - means = logits.squeeze(-1) - - self.test_metrics.mse.update(means, targets) - - def test_epoch_end( - self, outputs: EPOCH_OUTPUT | list[EPOCH_OUTPUT] - ) -> None: - self.log_dict( - self.test_metrics.compute(), - ) - self.test_metrics.reset() - - @staticmethod - def add_model_specific_args( - parent_parser: ArgumentParser, - ) -> ArgumentParser: - return parent_parser - - -class RegressionEnsemble(RegressionSingle): - def __init__( - self, - model: nn.Module, - loss: type[nn.Module], - optimization_procedure: Any, - dist_estimation: int, - num_estimators: int, - mode: Literal["mean", "mixture"], - out_features: int | None = 1, - **kwargs, - ) -> None: - super().__init__( - model=model, - loss=loss, - optimization_procedure=optimization_procedure, - dist_estimation=dist_estimation, - **kwargs, - ) - - if mode == "mixture": - raise NotImplementedError( - "Mixture of gaussians not implemented yet. Raise an issue if " - "needed." - ) - - self.mode = mode - self.num_estimators = num_estimators - self.out_features = out_features + if self.one_dim_regression: + pred = pred.squeeze(-1) + if self.num_estimators == 1: + pred = pred.squeeze(-1) + return pred def training_step( - self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int + self, batch: tuple[Tensor, Tensor], batch_idx: int ) -> STEP_OUTPUT: - inputs, targets = batch + inputs, targets = self.format_batch_fn(batch) + dists = self.model(inputs) - # eventual input repeat is done in the model - targets = targets.repeat((self.num_estimators, 1)) - return super().training_step((inputs, targets), batch_idx) + if self.one_dim_regression: + targets = targets.unsqueeze(-1) + + loss = self.loss(dists, targets) + self.log("train_loss", loss) + return loss def validation_step( - self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int + self, batch: tuple[Tensor, Tensor], batch_idx: int ) -> None: inputs, targets = batch - logits = self.forward(inputs) - - if self.out_features == 1: - logits = rearrange( - logits, "(m b) dist -> b m dist", m=self.num_estimators + if self.one_dim_regression: + targets = targets.unsqueeze(-1) + preds = self.model(inputs) + + if self.probabilistic: + ens_dist = Independent( + dist_rearrange( + preds, "(m b) c -> b m c", m=self.num_estimators + ), + 1, ) - else: - logits = rearrange( - logits, - "(m b) (f dist) -> b f m dist", - m=self.num_estimators, - f=self.out_features, + mix = Categorical( + torch.ones(self.num_estimators, device=self.device) ) - - if self.mode == "mean": - logits = logits.mean(dim=1) - - if self.dist_estimation == 2: - means = logits[..., 0] - variances = F.softplus(logits[..., 1]) - self.val_metrics.gnll.update(means, targets, variances) + mixture = MixtureSameFamily(mix, ens_dist) + preds = mixture.mean else: - means = logits + preds = rearrange(preds, "(m b) c -> b m c", m=self.num_estimators) + preds = preds.mean(dim=1) - self.val_metrics.mse.update(means, targets) + self.val_metrics.update(preds, targets) + if self.probabilistic: + self.val_prob_metrics.update(mixture, targets) + + def on_validation_epoch_end(self) -> None: + self.log_dict(self.val_metrics.compute()) + self.val_metrics.reset() + if self.probabilistic: + self.log_dict( + self.val_prob_metrics.compute(), + ) + self.val_prob_metrics.reset() def test_step( self, - batch: tuple[torch.Tensor, torch.Tensor], + batch: tuple[Tensor, Tensor], batch_idx: int, - dataloader_idx: int | None = 0, + dataloader_idx: int = 0, ) -> None: if dataloader_idx != 0: raise NotImplementedError( @@ -271,46 +195,48 @@ def test_step( ) inputs, targets = batch - logits = self.forward(inputs) - - if self.out_features == 1: - logits = rearrange( - logits, "(m b) dist -> b m dist", m=self.num_estimators + if self.one_dim_regression: + targets = targets.unsqueeze(-1) + preds = self.model(inputs) + + if self.probabilistic: + ens_dist = Independent( + dist_rearrange( + preds, "(m b) c -> b m c", m=self.num_estimators + ), + 1, ) - else: - logits = rearrange( - logits, - "(m b) (f dist) -> b f m dist", - m=self.num_estimators, - f=self.out_features, + mix = Categorical( + torch.ones(self.num_estimators, device=self.device) ) + mixture = MixtureSameFamily(mix, ens_dist) + preds = mixture.mean + else: + preds = rearrange(preds, "(m b) c -> b m c", m=self.num_estimators) + preds = preds.mean(dim=1) - if self.mode == "mean": - logits = logits.mean(dim=1) + self.test_metrics.update(preds, targets) + if self.probabilistic: + self.test_prob_metrics.update(mixture, targets) - if self.dist_estimation == 2: - means = logits[..., 0] - variances = F.softplus(logits[..., 1]) - self.test_metrics.gnll.update(means, targets, variances) - else: - means = logits + def on_test_epoch_end(self) -> None: + self.log_dict( + self.test_metrics.compute(), + ) + self.test_metrics.reset() - self.test_metrics.mse.update(means, targets) + if self.probabilistic: + self.log_dict( + self.test_prob_metrics.compute(), + ) + self.test_prob_metrics.reset() - @staticmethod - def add_model_specific_args( - parent_parser: ArgumentParser, - ) -> ArgumentParser: - """Defines the routine's attributes via command-line options. - Adds: - - ``--num_estimators``: sets :attr:`num_estimators`. - """ - parent_parser = RegressionSingle.add_model_specific_args(parent_parser) - parent_parser.add_argument( - "--num_estimators", - type=int, - default=None, - help="Number of estimators for ensemble", +def _regression_routine_checks(num_estimators: int, output_dim: int) -> None: + if num_estimators < 1: + raise ValueError( + f"num_estimators must be positive, got {num_estimators}." ) - return parent_parser + + if output_dim < 1: + raise ValueError(f"output_dim must be positive, got {output_dim}.") diff --git a/torch_uncertainty/routines/segmentation.py b/torch_uncertainty/routines/segmentation.py new file mode 100644 index 00000000..a3227dcf --- /dev/null +++ b/torch_uncertainty/routines/segmentation.py @@ -0,0 +1,155 @@ +from einops import rearrange +from lightning.pytorch import LightningModule +from lightning.pytorch.utilities.types import STEP_OUTPUT +from torch import Tensor, nn +from torch.optim import Optimizer +from torchmetrics import Accuracy, MetricCollection +from torchvision.transforms.v2 import functional as F + +from torch_uncertainty.metrics import ( + CE, + BrierScore, + CategoricalNLL, + MeanIntersectionOverUnion, +) + + +class SegmentationRoutine(LightningModule): + def __init__( + self, + model: nn.Module, + num_classes: int, + loss: nn.Module, + num_estimators: int = 1, + optim_recipe: dict | Optimizer | None = None, + format_batch_fn: nn.Module | None = None, + ) -> None: + """Routine for efficient training and testing on **segmentation tasks** + using LightningModule. + + Args: + model (torch.nn.Module): Model to train. + num_classes (int): Number of classes in the segmentation task. + loss (torch.nn.Module): Loss function to optimize the :attr:`model`. + num_estimators (int, optional): The number of estimators for the + ensemble. Defaults to ̀`1̀` (single model). + optim_recipe (dict or Optimizer, optional): The optimizer and + optionally the scheduler to use. Defaults to ``None``. + format_batch_fn (torch.nn.Module, optional): The function to format the + batch. Defaults to ``None``. + + Warning: + You must define :attr:`optim_recipe` if you do not use + the CLI. + + Note: + :attr:`optim_recipe` can be anything that can be returned by + :meth:`LightningModule.configure_optimizers()`. Find more details + `here `_. + """ + super().__init__() + _segmentation_routine_checks(num_estimators, num_classes) + + self.model = model + self.num_classes = num_classes + self.loss = loss + self.num_estimators = num_estimators + + if format_batch_fn is None: + format_batch_fn = nn.Identity() + + self.optim_recipe = optim_recipe + self.format_batch_fn = format_batch_fn + + # metrics + seg_metrics = MetricCollection( + { + "Acc": Accuracy(task="multiclass", num_classes=num_classes), + "ECE": CE(task="multiclass", num_classes=num_classes), + "mIoU": MeanIntersectionOverUnion(num_classes=num_classes), + "Brier": BrierScore(num_classes=num_classes), + "NLL": CategoricalNLL(), + }, + compute_groups=[["Acc", "mIoU"], ["ECE"], ["Brier"], ["NLL"]], + ) + + self.val_seg_metrics = seg_metrics.clone(prefix="seg_val/") + self.test_seg_metrics = seg_metrics.clone(prefix="seg_test/") + + def configure_optimizers(self) -> Optimizer | dict: + return self.optim_recipe + + def forward(self, img: Tensor) -> Tensor: + return self.model(img) + + def on_train_start(self) -> None: + init_metrics = dict.fromkeys(self.val_seg_metrics, 0) + init_metrics.update(dict.fromkeys(self.test_seg_metrics, 0)) + + self.logger.log_hyperparams(self.hparams, init_metrics) + + def training_step( + self, batch: tuple[Tensor, Tensor], batch_idx: int + ) -> STEP_OUTPUT: + img, target = batch + img, target = self.format_batch_fn((img, target)) + logits = self.forward(img) + target = F.resize( + target, logits.shape[-2:], interpolation=F.InterpolationMode.NEAREST + ) + logits = rearrange(logits, "b c h w -> (b h w) c") + target = target.flatten() + valid_mask = target != 255 + loss = self.loss(logits[valid_mask], target[valid_mask]) + self.log("train_loss", loss) + return loss + + def validation_step( + self, batch: tuple[Tensor, Tensor], batch_idx: int + ) -> None: + img, target = batch + logits = self.forward(img) + target = F.resize( + target, logits.shape[-2:], interpolation=F.InterpolationMode.NEAREST + ) + logits = rearrange( + logits, "(m b) c h w -> (b h w) m c", m=self.num_estimators + ) + probs_per_est = logits.softmax(dim=-1) + probs = probs_per_est.mean(dim=1) + target = target.flatten() + valid_mask = target != 255 + self.val_seg_metrics.update(probs[valid_mask], target[valid_mask]) + + def test_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> None: + img, target = batch + logits = self.forward(img) + target = F.resize( + target, logits.shape[-2:], interpolation=F.InterpolationMode.NEAREST + ) + logits = rearrange( + logits, "(m b) c h w -> (b h w) m c", m=self.num_estimators + ) + probs_per_est = logits.softmax(dim=-1) + probs = probs_per_est.mean(dim=1) + target = target.flatten() + valid_mask = target != 255 + self.test_seg_metrics.update(probs[valid_mask], target[valid_mask]) + + def on_validation_epoch_end(self) -> None: + self.log_dict(self.val_seg_metrics.compute()) + self.val_seg_metrics.reset() + + def on_test_epoch_end(self) -> None: + self.log_dict(self.test_seg_metrics.compute()) + self.test_seg_metrics.reset() + + +def _segmentation_routine_checks(num_estimators: int, num_classes: int) -> None: + if num_estimators < 1: + raise ValueError( + f"num_estimators must be positive, got {num_estimators}." + ) + + if num_classes < 2: + raise ValueError(f"num_classes must be at least 2, got {num_classes}.") diff --git a/torch_uncertainty/transforms/__init__.py b/torch_uncertainty/transforms/__init__.py index b2339f3c..d3aae6ec 100644 --- a/torch_uncertainty/transforms/__init__.py +++ b/torch_uncertainty/transforms/__init__.py @@ -8,6 +8,7 @@ Contrast, Equalize, Posterize, + RandomRescale, Rotate, Sharpen, Shear, diff --git a/torch_uncertainty/transforms/batch.py b/torch_uncertainty/transforms/batch.py index 29e035e8..600cea3d 100644 --- a/torch_uncertainty/transforms/batch.py +++ b/torch_uncertainty/transforms/batch.py @@ -25,7 +25,7 @@ def __init__(self, num_repeats: int) -> None: def forward(self, batch: tuple[Tensor, Tensor]) -> tuple[Tensor, Tensor]: inputs, targets = batch - return inputs, targets.repeat(self.num_repeats) + return inputs, targets.repeat_interleave(self.num_repeats, dim=0) class MIMOBatchFormat(nn.Module): diff --git a/torch_uncertainty/transforms/corruptions.py b/torch_uncertainty/transforms/corruptions.py index 193e5898..5235c6fe 100644 --- a/torch_uncertainty/transforms/corruptions.py +++ b/torch_uncertainty/transforms/corruptions.py @@ -24,16 +24,16 @@ from torch_uncertainty.datasets import FrostImages __all__ = [ - "GaussianNoise", - "ShotNoise", - "ImpulseNoise", - "SpeckleNoise", + "DefocusBlur", + "Frost", "GaussianBlur", + "GaussianNoise", "GlassBlur", - "DefocusBlur", + "ImpulseNoise", "JPEGCompression", "Pixelate", - "Frost", + "ShotNoise", + "SpeckleNoise", ] diff --git a/torch_uncertainty/transforms/image.py b/torch_uncertainty/transforms/image.py index 09c72903..2e8707eb 100644 --- a/torch_uncertainty/transforms/image.py +++ b/torch_uncertainty/transforms/image.py @@ -1,7 +1,11 @@ +from typing import Any + import torch -import torchvision.transforms.functional as F +import torchvision.transforms.v2.functional as F from PIL import Image, ImageEnhance from torch import Tensor, nn +from torchvision.transforms.v2 import InterpolationMode, Transform +from torchvision.transforms.v2._utils import query_size class AutoContrast(nn.Module): @@ -242,3 +246,79 @@ def forward( if isinstance(img, Tensor): img: Image.Image = F.to_pil_image(img) return ImageEnhance.Color(img).enhance(level) + + +class RandomRescale(Transform): + """Randomly rescale the input. + + This transformation can be used together with ``RandomCrop`` as data augmentations to train + models on image segmentation task. + + Output spatial size is randomly sampled from the interval ``[min_size, max_size]``: + + .. code-block:: python + + scale = uniform_sample(min_scale, max_scale) + output_width = input_width * scale + output_height = input_height * scale + + If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`, + :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.) + it can have arbitrary number of leading batch dimensions. For example, + the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape. + + Args: + min_scale (int): Minimum scale for random sampling + max_scale (int): Maximum scale for random sampling + interpolation (InterpolationMode, optional): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``, + ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported. + The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well. + antialias (bool, optional): Whether to apply antialiasing. + It only affects **tensors** with bilinear or bicubic modes and it is + ignored otherwise: on PIL images, antialiasing is always applied on + bilinear or bicubic modes; on other modes (for PIL images and + tensors), antialiasing makes no sense and this parameter is ignored. + Possible values are: + + - ``True`` (default): will apply antialiasing for bilinear or bicubic modes. + Other mode aren't affected. This is probably what you want to use. + - ``False``: will not apply antialiasing for tensors on any mode. PIL + images are still antialiased on bilinear or bicubic modes, because + PIL doesn't support no antialias. + - ``None``: equivalent to ``False`` for tensors and ``True`` for + PIL images. This value exists for legacy reasons and you probably + don't want to use it unless you really know what you are doing. + + The default value changed from ``None`` to ``True`` in + v0.17, for the PIL and Tensor backends to be consistent. + """ + + def __init__( + self, + min_scale: int, + max_scale: int, + interpolation: InterpolationMode | int = InterpolationMode.BILINEAR, + antialias: bool | None = True, + ) -> None: + super().__init__() + self.min_scale = min_scale + self.max_scale = max_scale + self.interpolation = interpolation + self.antialias = antialias + + def _get_params(self, flat_inputs: list[Any]) -> dict[str, Any]: + height, width = query_size(flat_inputs) + scale = torch.rand(1) + scale = self.min_scale + scale * (self.max_scale - self.min_scale) + return {"size": (int(height * scale), int(width * scale))} + + def _transform(self, inpt: Any, params: dict[str, Any]) -> Any: + return self._call_kernel( + F.resize, + inpt, + params["size"], + interpolation=self.interpolation, + antialias=self.antialias, + ) diff --git a/torch_uncertainty/transforms/mixup.py b/torch_uncertainty/transforms/mixup.py index 4447b2bc..ec51ae34 100644 --- a/torch_uncertainty/transforms/mixup.py +++ b/torch_uncertainty/transforms/mixup.py @@ -16,11 +16,12 @@ def sim_gauss_kernel(dist, tau_max: float = 1.0, tau_std: float = 0.5) -> float: return 1 / (dist_rate + 1e-12) +# ruff: noqa: ERA001 # def tensor_linspace(start: Tensor, stop: Tensor, num: int): # """ # Creates a tensor of shape [num, *start.shape] whose values are evenly # spaced from start to end, inclusive. -# Replicates but the multi-dimensional bahaviour of numpy.linspace in PyTorch. +# Replicates but the multi-dimensional behaviour of numpy.linspace in PyTorch. # """ # # create a tensor of 'num' steps from 0 to 1 # steps = torch.arange(num, dtype=torch.float32, device=start.device) / ( @@ -31,7 +32,7 @@ def sim_gauss_kernel(dist, tau_max: float = 1.0, tau_std: float = 0.5) -> float: # # to allow for broadcastings # # using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here # # but torchscript -# # "cannot statically infer the expected size of a list in this contex", +# # "cannot statically infer the expected size of a list in this context", # # hence the code below # for i in range(start.ndim): # steps = steps.unsqueeze(-1) diff --git a/torch_uncertainty/utils/__init__.py b/torch_uncertainty/utils/__init__.py index e6b70312..de0547c7 100644 --- a/torch_uncertainty/utils/__init__.py +++ b/torch_uncertainty/utils/__init__.py @@ -1,4 +1,6 @@ # ruff: noqa: F401 from .checkpoints import get_version +from .cli import TULightningCLI from .hub import load_hf -from .misc import csv_writer +from .misc import create_train_val_split, csv_writer, plot_hist +from .trainer import TUTrainer diff --git a/torch_uncertainty/utils/cli.py b/torch_uncertainty/utils/cli.py new file mode 100644 index 00000000..8b8659c4 --- /dev/null +++ b/torch_uncertainty/utils/cli.py @@ -0,0 +1,132 @@ +from collections.abc import Callable +from pathlib import Path +from typing import Any + +from lightning.fabric.utilities.cloud_io import get_filesystem +from lightning.pytorch import LightningDataModule, LightningModule, Trainer +from lightning.pytorch.cli import ( + ArgsType, + LightningArgumentParser, + LightningCLI, + SaveConfigCallback, +) +from typing_extensions import override + +from torch_uncertainty.utils.trainer import TUTrainer + + +class TUSaveConfigCallback(SaveConfigCallback): + @override + def setup( + self, trainer: Trainer, pl_module: LightningModule, stage: str + ) -> None: + if self.already_saved: + return + + if self.save_to_log_dir and stage == "fit": + log_dir = trainer.log_dir # this broadcasts the directory + assert log_dir is not None + config_path = Path(log_dir) / self.config_filename + fs = get_filesystem(log_dir) + + if not self.overwrite: + # check if the file exists on rank 0 + file_exists = ( + fs.isfile(config_path) if trainer.is_global_zero else False + ) + # broadcast whether to fail to all ranks + file_exists = trainer.strategy.broadcast(file_exists) + if file_exists: # coverage: ignore + raise RuntimeError( + f"{self.__class__.__name__} expected {config_path} to NOT exist. Aborting to avoid overwriting" + " results of a previous run. You can delete the previous config file," + " set `LightningCLI(save_config_callback=None)` to disable config saving," + ' or set `LightningCLI(save_config_kwargs={"overwrite": True})` to overwrite the config file.' + ) + + if trainer.is_global_zero: + fs.makedirs(log_dir, exist_ok=True) + self.parser.save( + self.config, + config_path, + skip_none=False, + overwrite=self.overwrite, + multifile=self.multifile, + ) + if trainer.is_global_zero: + self.save_config(trainer, pl_module, stage) + self.already_saved = True + + self.already_saved = trainer.strategy.broadcast(self.already_saved) + + +class TULightningCLI(LightningCLI): + def __init__( + self, + model_class: ( + type[LightningModule] | Callable[..., LightningModule] | None + ) = None, + datamodule_class: ( + type[LightningDataModule] + | Callable[..., LightningDataModule] + | None + ) = None, + save_config_callback: type[SaveConfigCallback] + | None = TUSaveConfigCallback, + save_config_kwargs: dict[str, Any] | None = None, + trainer_class: type[Trainer] | Callable[..., Trainer] = TUTrainer, + trainer_defaults: dict[str, Any] | None = None, + seed_everything_default: bool | int = True, + parser_kwargs: dict[str, Any] | dict[str, dict[str, Any]] | None = None, + subclass_mode_model: bool = False, + subclass_mode_data: bool = False, + args: ArgsType = None, + run: bool = True, + auto_configure_optimizers: bool = True, + eval_after_fit_default: bool = False, + ) -> None: + """Custom LightningCLI for torch-uncertainty. + + Args: + model_class (type[LightningModule] | Callable[..., LightningModule] | None, optional): _description_. Defaults to None. + datamodule_class (type[LightningDataModule] | Callable[..., LightningDataModule] | None, optional): _description_. Defaults to None. + save_config_callback (type[SaveConfigCallback] | None, optional): _description_. Defaults to TUSaveConfigCallback. + save_config_kwargs (dict[str, Any] | None, optional): _description_. Defaults to None. + trainer_class (type[Trainer] | Callable[..., Trainer], optional): _description_. Defaults to Trainer. + trainer_defaults (dict[str, Any] | None, optional): _description_. Defaults to None. + seed_everything_default (bool | int, optional): _description_. Defaults to True. + parser_kwargs (dict[str, Any] | dict[str, dict[str, Any]] | None, optional): _description_. Defaults to None. + subclass_mode_model (bool, optional): _description_. Defaults to False. + subclass_mode_data (bool, optional): _description_. Defaults to False. + args (ArgsType, optional): _description_. Defaults to None. + run (bool, optional): _description_. Defaults to True. + auto_configure_optimizers (bool, optional): _description_. Defaults to True. + eval_after_fit_default (bool, optional): _description_. Defaults to False. + """ + self.eval_after_fit_default = eval_after_fit_default + super().__init__( + model_class, + datamodule_class, + save_config_callback, + save_config_kwargs, + trainer_class, + trainer_defaults, + seed_everything_default, + parser_kwargs, + subclass_mode_model, + subclass_mode_data, + args, + run, + auto_configure_optimizers, + ) + + def add_default_arguments_to_parser( + self, parser: LightningArgumentParser + ) -> None: + """Adds default arguments to the parser.""" + parser.add_argument( + "--eval_after_fit", + action="store_true", + default=self.eval_after_fit_default, + ) + super().add_default_arguments_to_parser(parser) diff --git a/torch_uncertainty/utils/distributions.py b/torch_uncertainty/utils/distributions.py new file mode 100644 index 00000000..ab2336bc --- /dev/null +++ b/torch_uncertainty/utils/distributions.py @@ -0,0 +1,180 @@ +from numbers import Number + +import torch +from einops import rearrange +from torch import Tensor +from torch.distributions import Distribution, Laplace, Normal, constraints +from torch.distributions.utils import broadcast_all + + +def cat_dist(distributions: list[Distribution], dim: int) -> Distribution: + r"""Concatenate a list of distributions into a single distribution. + + Args: + distributions (list[Distribution]): The list of distributions. + dim (int): The dimension to concatenate. + + Returns: + Distribution: The concatenated distributions. + """ + dist_type = type(distributions[0]) + if not all( + isinstance(distribution, dist_type) for distribution in distributions + ): + raise ValueError("All distributions must have the same type.") + + if isinstance(distributions[0], Normal | Laplace): + locs = torch.cat( + [distribution.loc for distribution in distributions], dim=dim + ) + scales = torch.cat( + [distribution.scale for distribution in distributions], dim=dim + ) + return dist_type(loc=locs, scale=scales) + if isinstance(distributions[0], NormalInverseGamma): + locs = torch.cat( + [distribution.loc for distribution in distributions], dim=dim + ) + lmbdas = torch.cat( + [distribution.lmbda for distribution in distributions], dim=dim + ) + alphas = torch.cat( + [distribution.alpha for distribution in distributions], dim=dim + ) + betas = torch.cat( + [distribution.beta for distribution in distributions], dim=dim + ) + return dist_type(loc=locs, lmbda=lmbdas, alpha=alphas, beta=betas) + raise NotImplementedError( + f"Concatenation of {dist_type} distributions is not supported." + "Raise an issue if needed." + ) + + +def squeeze_dist(distribution: Distribution, dim: int) -> Distribution: + """Squeeze the distribution along a given dimension. + + Args: + distribution (Distribution): The distribution to squeeze. + dim (int): The dimension to squeeze. + + Returns: + Distribution: The squeezed distribution. + """ + dist_type = type(distribution) + if isinstance(distribution, Normal | Laplace): + loc = distribution.loc.squeeze(dim) + scale = distribution.scale.squeeze(dim) + return dist_type(loc=loc, scale=scale) + if isinstance(distribution, NormalInverseGamma): + loc = distribution.loc.squeeze(dim) + lmbda = distribution.lmbda.squeeze(dim) + alpha = distribution.alpha.squeeze(dim) + beta = distribution.beta.squeeze(dim) + return dist_type(loc=loc, lmbda=lmbda, alpha=alpha, beta=beta) + raise NotImplementedError( + f"Squeezing of {dist_type} distributions is not supported." + "Raise an issue if needed." + ) + + +def dist_rearrange( + distribution: Distribution, pattern: str, **axes_lengths: int +) -> Distribution: + dist_type = type(distribution) + if isinstance(distribution, Normal | Laplace): + loc = rearrange(distribution.loc, pattern=pattern, **axes_lengths) + scale = rearrange(distribution.scale, pattern=pattern, **axes_lengths) + return dist_type(loc=loc, scale=scale) + if isinstance(distribution, NormalInverseGamma): + loc = rearrange(distribution.loc, pattern=pattern, **axes_lengths) + lmbda = rearrange(distribution.lmbda, pattern=pattern, **axes_lengths) + alpha = rearrange(distribution.alpha, pattern=pattern, **axes_lengths) + beta = rearrange(distribution.beta, pattern=pattern, **axes_lengths) + return dist_type(loc=loc, lmbda=lmbda, alpha=alpha, beta=beta) + raise NotImplementedError( + f"Ensemble distribution of {dist_type} is not supported." + "Raise an issue if needed." + ) + + +class NormalInverseGamma(Distribution): + arg_constraints = { + "loc": constraints.real, + "lmbda": constraints.positive, + "alpha": constraints.greater_than_eq(1), + "beta": constraints.positive, + } + support = constraints.real + has_rsample = False + + def __init__( + self, + loc: Number | Tensor, + lmbda: Number | Tensor, + alpha: Number | Tensor, + beta: Number | Tensor, + validate_args: bool | None = None, + ) -> None: + self.loc, self.lmbda, self.alpha, self.beta = broadcast_all( + loc, lmbda, alpha, beta + ) + if ( + isinstance(loc, Number) + and isinstance(lmbda, Number) + and isinstance(alpha, Number) + and isinstance(beta, Number) + ): + batch_shape = torch.Size() + else: + batch_shape = self.loc.size() + super().__init__(batch_shape, validate_args=validate_args) + + @property + def mean(self) -> Tensor: + """Impromper mean of the NormalInverseGamma distribution. + + This value is necessary to perform point-wise predictions in the + regression routine. + """ + return self.loc + + def mode(self) -> None: + raise NotImplementedError( + "Mode is not meaningful for the NormalInverseGamma distribution" + ) + + def stddev(self) -> None: + raise NotImplementedError( + "Standard deviation is not meaningful for the NormalInverseGamma distribution" + ) + + def variance(self) -> None: + raise NotImplementedError( + "Variance is not meaningful for the NormalInverseGamma distribution" + ) + + @property + def mean_loc(self) -> Tensor: + return self.loc + + @property + def mean_variance(self) -> Tensor: + return self.beta / (self.alpha - 1) + + @property + def variance_loc(self) -> Tensor: + return self.beta / (self.alpha - 1) / self.lmbda + + def log_prob(self, value: Tensor) -> Tensor: + if self._validate_args: # coverage: ignore + self._validate_sample(value) + gam: Tensor = 2 * self.beta * (1 + self.lmbda) + return ( + -0.5 * torch.log(torch.pi / self.lmbda) + + self.alpha * gam.log() + - (self.alpha + 0.5) + * torch.log(gam + self.lmbda * (value - self.loc) ** 2) + - torch.lgamma(self.alpha) + + torch.lgamma(self.alpha + 0.5) + ) diff --git a/torch_uncertainty/utils/evaluation_loop.py b/torch_uncertainty/utils/evaluation_loop.py new file mode 100644 index 00000000..ac11c02a --- /dev/null +++ b/torch_uncertainty/utils/evaluation_loop.py @@ -0,0 +1,126 @@ +import os +import shutil +import sys +from typing import Any + +from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE +from lightning.pytorch.loops.evaluation_loop import _EvaluationLoop +from lightning.pytorch.trainer.connectors.logger_connector.result import ( + _OUT_DICT, +) +from lightning_utilities.core.apply_func import apply_to_collection +from torch import Tensor + + +class TUEvaluationLoop(_EvaluationLoop): + @staticmethod + def _print_results(results: list[_OUT_DICT], stage: str) -> None: + # remove the dl idx suffix + results = [ + {k.split("/dataloader_idx_")[0]: v for k, v in result.items()} + for result in results + ] + metrics_paths = { + k + for keys in apply_to_collection( + results, dict, _EvaluationLoop._get_keys + ) + for k in keys + } + if not metrics_paths: + return + + metrics_strs = [":".join(metric) for metric in metrics_paths] + # sort both lists based on metrics_strs + metrics_strs, metrics_paths = zip( + *sorted(zip(metrics_strs, metrics_paths, strict=False)), + strict=False, + ) + + if len(results) == 2: + headers = ["In-Distribution", "Out-of-Distribution"] + else: + headers = [f"DataLoader {i}" for i in range(len(results))] + + # fallback is useful for testing of printed output + term_size = shutil.get_terminal_size(fallback=(120, 30)).columns or 120 + max_length = int( + min( + max( + len(max(metrics_strs, key=len)), + len(max(headers, key=len)), + 25, + ), + term_size / 2, + ) + ) + + rows: list[list[Any]] = [[] for _ in metrics_paths] + + for result in results: + for metric, row in zip(metrics_paths, rows, strict=False): + val = _EvaluationLoop._find_value(result, metric) + if val is not None: + if isinstance(val, Tensor): + val = val.item() if val.numel() == 1 else val.tolist() + row.append(f"{val:.5f}") + else: + row.append(" ") + + # keep one column with max length for metrics + num_cols = int((term_size - max_length) / max_length) + + for i in range(0, len(headers), num_cols): + table_headers = headers[i : (i + num_cols)] + table_rows = [row[i : (i + num_cols)] for row in rows] + + table_headers.insert(0, f"{stage} Metric".capitalize()) + + if _RICH_AVAILABLE: + from rich import get_console + from rich.table import Column, Table + + columns = [ + Column( + h, justify="center", style="magenta", width=max_length + ) + for h in table_headers + ] + columns[0].style = "cyan" + + table = Table(*columns) + for metric, row in zip(metrics_strs, table_rows, strict=False): + row.insert(0, metric) + table.add_row(*row) + + console = get_console() + console.print(table) + else: # coverage: ignore + row_format = f"{{:^{max_length}}}" * len(table_headers) + half_term_size = int(term_size / 2) + + try: + # some terminals do not support this character + if sys.stdout.encoding is not None: + "─".encode(sys.stdout.encoding) + except UnicodeEncodeError: + bar_character = "-" + else: + bar_character = "─" + bar = bar_character * term_size + + lines = [bar, row_format.format(*table_headers).rstrip(), bar] + for metric, row in zip(metrics_strs, table_rows, strict=False): + # deal with column overflow + if len(metric) > half_term_size: + while len(metric) > half_term_size: + row_metric = metric[:half_term_size] + metric = metric[half_term_size:] + lines.append( + row_format.format(row_metric, *row).rstrip() + ) + lines.append(row_format.format(metric, " ").rstrip()) + else: + lines.append(row_format.format(metric, *row).rstrip()) + lines.append(bar) + print(os.linesep.join(lines)) diff --git a/torch_uncertainty/utils/misc.py b/torch_uncertainty/utils/misc.py index 33a43c92..ab5d697d 100644 --- a/torch_uncertainty/utils/misc.py +++ b/torch_uncertainty/utils/misc.py @@ -1,6 +1,14 @@ +import copy import csv +from collections.abc import Callable from pathlib import Path +import matplotlib.pyplot as plt +import torch +from matplotlib.axes import Axes +from matplotlib.figure import Figure +from torch.utils.data import Dataset, random_split + def csv_writer(path: Path, dic: dict) -> None: """Write a dictionary to a csv file. @@ -15,7 +23,6 @@ def csv_writer(path: Path, dic: dict) -> None: else: append_mode = False rw_mode = "w" - # Write dic with path.open(rw_mode) as csvfile: writer = csv.writer(csvfile, delimiter=",") @@ -23,3 +30,53 @@ def csv_writer(path: Path, dic: dict) -> None: if append_mode is False: writer.writerow(dic.keys()) writer.writerow([f"{elem:.4f}" for elem in dic.values()]) + + +def plot_hist( + conf: list[torch.Tensor], + bins: int = 20, + title: str = "Histogram with 'auto' bins", + dpi: int = 60, +) -> tuple[Figure, Axes]: + """Plot a confidence histogram. + + Args: + conf (Any): The confidence values. + bins (int, optional): The number of bins. Defaults to 20. + title (str, optional): The title of the plot. Defaults to "Histogram + with 'auto' bins". + dpi (int, optional): The dpi of the plot. Defaults to 60. + + Returns: + Tuple[Figure, Axes]: The figure and axes of the plot. + """ + plt.rc("axes", axisbelow=True) + fig, ax = plt.subplots(1, figsize=(7, 5), dpi=dpi) + for i in [1, 0]: + ax.hist( + conf[i], + bins=bins, + density=True, + label=["In-distribution", "Out-of-Distribution"][i], + alpha=0.4, + linewidth=1, + edgecolor=["#0d559f", "#d45f00"][i], + color=["#1f77b4", "#ff7f0e"][i], + ) + + ax.set_title(title) + plt.grid(True, linestyle="--", alpha=0.7, zorder=0) + plt.legend() + fig.tight_layout() + return fig, ax + + +def create_train_val_split( + dataset: Dataset, + val_split_rate: float, + val_transforms: Callable | None = None, +) -> tuple[Dataset, Dataset]: + train, val = random_split(dataset, [1 - val_split_rate, val_split_rate]) + val = copy.deepcopy(val) + val.dataset.transform = val_transforms + return train, val diff --git a/torch_uncertainty/utils/trainer.py b/torch_uncertainty/utils/trainer.py new file mode 100644 index 00000000..e1b09f7d --- /dev/null +++ b/torch_uncertainty/utils/trainer.py @@ -0,0 +1,19 @@ +from lightning.pytorch import Trainer +from lightning.pytorch.trainer.states import ( + RunningStage, + TrainerFn, +) + +from torch_uncertainty.utils.evaluation_loop import TUEvaluationLoop + + +class TUTrainer(Trainer): + def __init__(self, inference_mode: bool = True, **kwargs): + super().__init__(inference_mode=inference_mode, **kwargs) + + self.test_loop = TUEvaluationLoop( + self, + TrainerFn.TESTING, + RunningStage.TESTING, + inference_mode=inference_mode, + )