From 5501556b094c174fee28ab7a8667c74267fa36f9 Mon Sep 17 00:00:00 2001 From: Olivier Date: Thu, 27 Jun 2024 17:01:08 +0200 Subject: [PATCH 01/14] :shirt: Fix emojis --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index a66c6b85..52b0ed8e 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ This package provides a multi-level API, including: - easy-to-use :zap: lightning **uncertainty-aware** training & evaluation routines for **4 tasks**: classification, probabilistic and pointwise regression, and segmentation. - ready-to-train baselines on research datasets, such as ImageNet and CIFAR -- [pretrained weights](https://huggingface.co/torch-uncertainty) for these baselines on ImageNet and CIFAR (:construction: work in progress :construction:). +- [pretrained weights](https://huggingface.co/torch-uncertainty) for these baselines on ImageNet and CIFAR ( :construction: work in progress :construction: ). - **layers**, **models**, **metrics**, & **losses** available for use in your networks - scikit-learn style post-processing methods such as Temperature Scaling. @@ -59,7 +59,7 @@ We also provide the following methods: ### Baselines -To date, the following deep learning baselines have been implemented. **Click on the methods for tutorials**: +To date, the following deep learning baselines have been implemented. **Click** :inbox_tray: **on the methods for tutorials**: - [Deep Ensembles](https://torch-uncertainty.github.io/auto_tutorials/tutorial_from_de_to_pe.html), BatchEnsemble, Masksembles, & MIMO - [MC-Dropout](https://torch-uncertainty.github.io/auto_tutorials/tutorial_mc_dropout.html) From 8d66a34795897abae9c924d25654fe049dce67ed Mon Sep 17 00:00:00 2001 From: Olivier Date: Thu, 27 Jun 2024 17:12:57 +0200 Subject: [PATCH 02/14] :shirt: Fix typos in tutorials --- auto_tutorials_source/tutorial_bayesian.py | 2 -- auto_tutorials_source/tutorial_corruption.py | 1 - .../tutorial_evidential_classification.py | 15 +++----- .../tutorial_from_de_to_pe.py | 35 ++++++++++--------- 4 files changed, 23 insertions(+), 30 deletions(-) diff --git a/auto_tutorials_source/tutorial_bayesian.py b/auto_tutorials_source/tutorial_bayesian.py index 68628d2a..d50c7bf7 100644 --- a/auto_tutorials_source/tutorial_bayesian.py +++ b/auto_tutorials_source/tutorial_bayesian.py @@ -37,8 +37,6 @@ neural network utils from torch.nn, as well as the partial util to provide the modified default arguments for the ELBO loss. """ - -# %% from pathlib import Path from lightning.pytorch import Trainer diff --git a/auto_tutorials_source/tutorial_corruption.py b/auto_tutorials_source/tutorial_corruption.py index 9e4e7a10..01b6a2d9 100644 --- a/auto_tutorials_source/tutorial_corruption.py +++ b/auto_tutorials_source/tutorial_corruption.py @@ -11,7 +11,6 @@ torch_uncertainty.transforms.corruptions. We also need to load utilities from torchvision and matplotlib. 
""" - from torchvision.datasets import CIFAR10 from torchvision.transforms import Compose, ToTensor, Resize diff --git a/auto_tutorials_source/tutorial_evidential_classification.py b/auto_tutorials_source/tutorial_evidential_classification.py index 1b780361..dccda568 100644 --- a/auto_tutorials_source/tutorial_evidential_classification.py +++ b/auto_tutorials_source/tutorial_evidential_classification.py @@ -22,12 +22,8 @@ - the evidential objective: the DECLoss from torch_uncertainty.losses - the datamodule that handles dataloaders & transforms: MNISTDataModule from torch_uncertainty.datamodules -We also need to define an optimizer using torch.optim, the neural network utils within torch.nn, as well as the partial util to provide -the modified default arguments for the DEC loss. +We also need to define an optimizer using torch.optim, the neural network utils within torch.nn. """ - -# %% -from functools import partial from pathlib import Path import torch @@ -73,11 +69,10 @@ def optim_lenet(model: nn.Module) -> dict: # %% # 4. The Loss and the Training Routine # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -# Next, we need to define the loss to be used during training. To do this, we -# redefine the default parameters for the DEC loss using the partial -# function from functools. After that, we define the training routine using +# Next, we need to define the loss to be used during training. +# After that, we define the training routine using # the single classification model training routine from -# torch_uncertainty.routines.classification.ClassificationSingle. +# torch_uncertainty.routines.ClassificationRoutine. # In this routine, we provide the model, the DEC loss, the optimizer, # and all the default arguments. @@ -152,4 +147,4 @@ def rotated_mnist(angle: int) -> None: # References # ---------- # -# - **Deep Evidential Classification:** Murat Sensoy, Lance Kaplan, & Melih Kandemir (2018). Evidential Deep Learning to Quantify Classification Uncertainty `NeurIPS 2018 `_ +# - **Deep Evidential Classification:** Murat Sensoy, Lance Kaplan, & Melih Kandemir (2018). Evidential Deep Learning to Quantify Classification Uncertainty `NeurIPS 2018 `_. diff --git a/auto_tutorials_source/tutorial_from_de_to_pe.py b/auto_tutorials_source/tutorial_from_de_to_pe.py index 2290a024..24933566 100644 --- a/auto_tutorials_source/tutorial_from_de_to_pe.py +++ b/auto_tutorials_source/tutorial_from_de_to_pe.py @@ -1,11 +1,12 @@ -"""Improved Ensemble parameter-efficiency with Packed-Ensembles +""" +Improved Ensemble parameter-efficiency with Packed-Ensembles ============================================================ -*This tutorial is adapted from a notebook part of a lecture given at the [Helmholtz AI Conference](https://haicon24.de/) by Sebastian Starke, Peter Steinbach, Gianni Franchi, and Olivier Laurent.* +*This tutorial is adapted from a notebook part of a lecture given at the `Helmholtz AI Conference `_ by Sebastian Starke, Peter Steinbach, Gianni Franchi, and Olivier Laurent.* In this notebook will work on the MNIST dataset that was introduced by Corinna Cortes, Christopher J.C. Burges, and later modified by Yann LeCun in the foundational paper: -- [Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner. "Gradient-based learning applied to document recognition." Proceedings of the IEEE.](http://yann.lecun.com/exdb/publis/pdf/lecun-98.pdf) +- `Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner. "Gradient-based learning applied to document recognition." Proceedings of the IEEE. `_. 
 The MNIST dataset consists of 70 000 images of handwritten digits from 0 to 9. The images are grayscale and 28x28-pixel sized.
 The task is to classify the images into their respective digits. The dataset can be automatically downloaded using the `torchvision` library.
@@ -15,19 +16,19 @@
 - Calibration error: a measure of the calibration of the predicted probabilities,
 - Negative Log-Likelihood: the value of the loss on the test set.

-Throughout this notebook, we abstract the training and evaluation process using [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/)
-and [TorchUncertainty](https://torch-uncertainty.github.io/).
+Throughout this notebook, we abstract the training and evaluation process using `PyTorch Lightning <https://lightning.ai/docs/pytorch/stable/>`_
+and `TorchUncertainty <https://torch-uncertainty.github.io/>`_.

 Similarly to Keras for TensorFlow, PyTorch Lightning is a high-level interface for PyTorch that simplifies the training and evaluation process using a Trainer.
 TorchUncertainty is partly built on top of PyTorch Lightning and provides tools to train and evaluate models with uncertainty quantification.

 TorchUncertainty includes datamodules that handle the data loading and preprocessing. We don't use them here for tutorial purposes.
-"""
-# 1. Download, instantiate and visualize the datasets
-# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-#
-# The dataset is automatically downloaded using torchvision. We then visualize a few images to see a bit what we are working with.
+1. Download, instantiate and visualize the datasets
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+The dataset is automatically downloaded using torchvision. We then visualize a few images to see a bit what we are working with.
+"""
 # Create the transforms for the images
 import torch
 import torchvision.transforms as T
@@ -138,7 +139,7 @@ def optim_recipe(model, lr_mult: float = 1.0):
 # %%
-# To train the model, we use [TorchUncertainty](https://torch-uncertainty.github.io/), a library that we have developed to ease
+# To train the model, we use `TorchUncertainty <https://torch-uncertainty.github.io/>`_, a library that we have developed to ease
 # the training and evaluation of models with uncertainty.
 #
 # **Note:** To train supervised classification models we most often use the cross-entropy loss.
@@ -240,7 +241,7 @@
 # We have put the pre-trained models on Hugging Face that you can download with the utility function
 # "hf_hub_download" imported just below. These models are trained for 75 epochs and are therefore not
 # comparable to all the other models trained in this notebook. The pretrained models can be seen
-# [here](https://huggingface.co/ENSTA-U2IS/tutorial-models) and TorchUncertainty's are [here](https://huggingface.co/torch-uncertainty).
+# `here <https://huggingface.co/ENSTA-U2IS/tutorial-models>`_ and TorchUncertainty's are `here <https://huggingface.co/torch-uncertainty>`_.

 from torch_uncertainty.utils.hub import hf_hub_download

@@ -289,15 +290,15 @@
 # 4. From Deep Ensembles to Packed-Ensembles
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 #
-# In the paper [Packed-Ensembles for Efficient Uncertainty Quantification](https://arxiv.org/abs/2210.09184)
+# In the paper `Packed-Ensembles for Efficient Uncertainty Quantification <https://arxiv.org/abs/2210.09184>`_
 # published at the International Conference on Learning Representations (ICLR) in 2023, we introduced a
 # modification of Deep Ensembles to make it more computationally-efficient. The idea is to pack the ensemble
 # members into a single model, which allows us to train the ensemble in a single forward pass.
 # This modification is particularly useful when the ensemble size is large, as is often the case in practice.
 #
 # We will need to update the model and replace the layers with their Packed equivalents. You can find the
-# documentation of the Packed-Linear layer [here](https://torch-uncertainty.github.io/generated/torch_uncertainty.layers.PackedLinear.html),
-# and the Packed-Conv2D, [here](https://torch-uncertainty.github.io/generated/torch_uncertainty.layers.PackedLinear.html).
+# documentation of the Packed-Linear layer `here <https://torch-uncertainty.github.io/generated/torch_uncertainty.layers.PackedLinear.html>`_,
+# and the Packed-Conv2D, `here <https://torch-uncertainty.github.io/generated/torch_uncertainty.layers.PackedLinear.html>`_.

 import torch
 import torch.nn as nn
@@ -387,7 +388,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 # %%
 # The training time should be approximately similar to the one of the single model that you trained before. However, please note that we are working with very small models, hence completely underusing your GPU. As such, the training time is not representative of what you would observe with larger models.
 #
-# You can read more on Packed-Ensembles in the [paper](https://arxiv.org/abs/2210.09184) or the [Medium](https://medium.com/@adrien.lafage/make-your-neural-networks-more-reliable-with-packed-ensembles-7ad0b737a873) post.
+# You can read more on Packed-Ensembles in the `paper <https://arxiv.org/abs/2210.09184>`_ or the `Medium <https://medium.com/@adrien.lafage/make-your-neural-networks-more-reliable-with-packed-ensembles-7ad0b737a873>`_ post.
 #
 # To Go Further & More Concepts of Uncertainty in ML
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -413,4 +414,4 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 # Grouping Loss
 # ^^^^^^^^^^^^^
 #
-# The grouping loss is a measure of uncertainty orthogonal to calibration. Have a look at [this paper](https://arxiv.org/abs/2210.16315) to learn about it. Check out their small library [GLest](https://github.com/aperezlebel/glest). TorchUncertainty includes a wrapper of the library to compute the grouping loss with eval_grouping_loss parameter.
+# The grouping loss is a measure of uncertainty orthogonal to calibration. Have a look at `this paper <https://arxiv.org/abs/2210.16315>`_ to learn about it. Check out their small library `GLest <https://github.com/aperezlebel/glest>`_. TorchUncertainty includes a wrapper of the library to compute the grouping loss with eval_grouping_loss parameter.
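The packing trick described in the tutorial above can be illustrated in a few lines of plain PyTorch. The sketch below is not TorchUncertainty's PackedLinear (which additionally exposes an alpha width factor and a gamma grouping parameter); it only shows how a grouped 1x1 convolution fuses several independent linear members into a single forward pass:

import torch
from torch import nn

class GroupedEnsembleLinear(nn.Module):
    """N independent linear layers fused into one grouped 1x1 convolution."""

    def __init__(self, in_features: int, out_features: int, num_estimators: int) -> None:
        super().__init__()
        self.num_estimators = num_estimators
        # groups=num_estimators keeps the members' weights fully independent
        self.conv = nn.Conv1d(
            in_channels=in_features * num_estimators,
            out_channels=out_features * num_estimators,
            kernel_size=1,
            groups=num_estimators,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, num_estimators * in_features) -> (batch, num_estimators * out_features)
        return self.conv(x.unsqueeze(-1)).squeeze(-1)

layer = GroupedEnsembleLinear(in_features=4, out_features=2, num_estimators=3)
x = torch.randn(8, 4).repeat(1, 3)  # replicate the input once per member
print(layer(x).shape)  # torch.Size([8, 6]): 3 members x 2 output features

Because the groups share no weights, each slice of the output behaves like an independent ensemble member while the whole ensemble runs as a single kernel launch.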
From 23ee3b3eef1d7e08f0c80c60c819410f87f50ad6 Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 28 Jun 2024 09:20:24 +0200 Subject: [PATCH 03/14] :shirt: Add NPY rule --- pyproject.toml | 1 + tests/_dummies/dataset.py | 21 +++++++-------- tests/transforms/test_image.py | 8 +++--- torch_uncertainty/layers/masksembles.py | 3 ++- torch_uncertainty/transforms/corruptions.py | 3 ++- torch_uncertainty/transforms/cutout.py | 5 ++-- torch_uncertainty/transforms/mixup.py | 16 ++++++----- torch_uncertainty/transforms/pixmix.py | 30 ++++++++++++--------- torch_uncertainty/utils/trainer.py | 2 +- 9 files changed, 51 insertions(+), 38 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6e214522..37f4068f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,6 +93,7 @@ lint.extend-select = [ "ISC", "ICN", "N", + "NPY", "PERF", "PIE", "PTH", diff --git a/tests/_dummies/dataset.py b/tests/_dummies/dataset.py index 662e4f9f..9f39ce05 100644 --- a/tests/_dummies/dataset.py +++ b/tests/_dummies/dataset.py @@ -53,7 +53,7 @@ def __init__( else: shape = (num_images, num_channels, image_size, image_size) - self.data = np.random.randint( + self.data = np.random.default_rng().integers( low=0, high=255, size=shape, @@ -187,14 +187,15 @@ def __init__( smnt_shape = (num_images, 1, image_size, image_size) - self.data = np.random.randint( + rng = np.random.default_rng() + self.data = rng.integers( low=0, high=255, size=img_shape, dtype=np.uint8, ) - self.targets = np.random.randint( + self.targets = rng.integers( low=0, high=num_classes, size=smnt_shape, @@ -245,20 +246,18 @@ def __init__( else: smnt_shape = (num_images, output_dim, image_size, image_size) - self.data = np.random.randint( + rng = np.random.default_rng() + self.data = rng.integers( low=0, high=255, size=img_shape, dtype=np.uint8, ) - self.targets = ( - np.random.uniform( - low=0, - high=1, - size=smnt_shape, - ) - * 100 + self.targets = rng.uniform( + low=0, + high=100, + size=smnt_shape, ) def __getitem__(self, index: int) -> tuple[Any, Any]: diff --git a/tests/transforms/test_image.py b/tests/transforms/test_image.py index 872707d9..c79e5210 100644 --- a/tests/transforms/test_image.py +++ b/tests/transforms/test_image.py @@ -24,14 +24,16 @@ @pytest.fixture() def img_input() -> torch.Tensor: - imarray = np.random.rand(28, 28, 3) * 255 + rng = np.random.default_rng() + imarray = rng.uniform(low=0, high=255, size=(28, 28, 3)) return Image.fromarray(imarray.astype("uint8")).convert("RGB") @pytest.fixture() def tv_tensors_input() -> tuple[torch.Tensor, torch.Tensor]: - imarray1 = np.random.rand(3, 28, 28) * 255 - imarray2 = np.random.rand(1, 28, 28) * 255 + rng = np.random.default_rng() + imarray1 = rng.uniform(low=0, high=255, size=(3, 28, 28)) + imarray2 = rng.uniform(low=0, high=255, size=(1, 28, 28)) return ( tv_tensors.Image(imarray1.astype("uint8")), tv_tensors.Mask(imarray2.astype("uint8")), diff --git a/torch_uncertainty/layers/masksembles.py b/torch_uncertainty/layers/masksembles.py index 9444cef6..d8493c9e 100644 --- a/torch_uncertainty/layers/masksembles.py +++ b/torch_uncertainty/layers/masksembles.py @@ -23,12 +23,13 @@ def _generate_masks(m: int, n: int, s: float) -> np.ndarray: Returns: np.ndarray: Matrix of binary vectors. 
""" + rng = np.random.default_rng() total_positions = int(m * s) masks = [] for _ in range(n): new_vector = np.zeros([total_positions]) - idx = np.random.choice(range(total_positions), m, replace=False) + idx = rng.choice(range(total_positions), m, replace=False) new_vector[idx] = 1 masks.append(new_vector) diff --git a/torch_uncertainty/transforms/corruptions.py b/torch_uncertainty/transforms/corruptions.py index 5235c6fe..8b3d674c 100644 --- a/torch_uncertainty/transforms/corruptions.py +++ b/torch_uncertainty/transforms/corruptions.py @@ -283,6 +283,7 @@ def __repr__(self) -> str: class Frost(nn.Module): def __init__(self, severity: int) -> None: super().__init__() + self.rng = np.random.default_rng() if not (0 <= severity <= 5): raise ValueError("Severity must be between 0 and 5.") if not isinstance(severity, int): @@ -300,7 +301,7 @@ def forward(self, img: Tensor) -> Tensor: return img _, height, width = img.shape frost_img = RandomResizedCrop((height, width))( - self.frost_ds[np.random.randint(5)] + self.frost_ds[self.rng.integers(low=0, high=4)] ) return torch.clip(self.mix[0] * img + self.mix[1] * frost_img, 0, 1) diff --git a/torch_uncertainty/transforms/cutout.py b/torch_uncertainty/transforms/cutout.py index 4d83e474..9a91215f 100644 --- a/torch_uncertainty/transforms/cutout.py +++ b/torch_uncertainty/transforms/cutout.py @@ -12,6 +12,7 @@ def __init__(self, length: int, value: int = 0) -> None: value (int): Pixel value to be filled in the cutout square. """ super().__init__() + self.rng = np.random.default_rng() if length <= 0: raise ValueError("Cutout length must be positive.") @@ -26,8 +27,8 @@ def __call__(self, img: torch.Tensor) -> torch.Tensor: img = img.unsqueeze(0) h, w = img.size(1), img.size(2) mask = np.ones((h, w), np.float32) - y = np.random.randint(h) - x = np.random.randint(w) + y = self.rng.integers(low=0, high=h - 1) + x = self.rng.integers(low=0, high=w - 1) y1 = np.clip(y - self.length // 2, 0, h) y2 = np.clip(y + self.length // 2, 0, h) diff --git a/torch_uncertainty/transforms/mixup.py b/torch_uncertainty/transforms/mixup.py index ec51ae34..bd26534d 100644 --- a/torch_uncertainty/transforms/mixup.py +++ b/torch_uncertainty/transforms/mixup.py @@ -2,7 +2,7 @@ import scipy import torch import torch.nn.functional as F -from torch import Tensor +from torch import Tensor, nn def beta_warping(x, alpha_cdf: float = 1.0, eps: float = 1e-12) -> float: @@ -79,20 +79,23 @@ def sim_gauss_kernel(dist, tau_max: float = 1.0, tau_std: float = 0.5) -> float: # return 1 / (dist_rate + 1e-12) -class AbstractMixup: +# TODO: Should be a torchvision transform +class AbstractMixup(nn.Module): def __init__( self, alpha: float = 1.0, mode: str = "batch", num_classes: int = 1000 ) -> None: + super().__init__() + self.rng = np.random.default_rng() self.alpha = alpha self.num_classes = num_classes self.mode = mode def _get_params(self, batch_size: int, device: torch.device): if self.mode == "batch": - lam = np.random.beta(self.alpha, self.alpha) + lam = self.rng.beta(a=self.alpha, b=self.alpha) else: lam = torch.as_tensor( - np.random.beta(self.alpha, self.alpha, batch_size), + self.rng.beta(a=self.alpha, b=self.alpha, size=batch_size), device=device, ) index = torch.randperm(batch_size, device=device) @@ -106,7 +109,6 @@ def _linear_mixing( ) -> Tensor: if isinstance(lam, Tensor): lam = lam.view(-1, *[1 for _ in range(inp.ndim - 1)]).float() - return lam * inp + (1 - lam) * inp[index, :] def _mix_target( @@ -202,9 +204,9 @@ def __init__( def _get_params(self, batch_size: int, device: 
torch.device): if self.mode == "batch": - lam = np.random.beta(self.alpha, self.alpha) + lam = self.rng.beta(a=self.alpha, b=self.alpha) else: - lam = np.random.beta(self.alpha, self.alpha, batch_size) + lam = self.rng.beta(a=self.alpha, b=self.alpha, size=batch_size) index = torch.randperm(batch_size, device=device) return lam, index diff --git a/torch_uncertainty/transforms/pixmix.py b/torch_uncertainty/transforms/pixmix.py index 84c19cce..48f5e893 100644 --- a/torch_uncertainty/transforms/pixmix.py +++ b/torch_uncertainty/transforms/pixmix.py @@ -8,12 +8,13 @@ def get_ab(beta: float) -> tuple[float, float]: - if np.random.random() < 0.5: - a = np.float32(np.random.beta(beta, 1)) - b = np.float32(np.random.beta(1, beta)) + rng = np.random.default_rng() + if rng.uniform(low=0, high=1) < 0.5: + a = np.float32(rng.beta(a=beta, b=1)) + b = np.float32(rng.beta(a=1, b=beta)) else: - a = 1 + np.float32(np.random.beta(1, beta)) - b = -np.float32(np.random.beta(1, beta)) + a = 1 + np.float32(rng.beta(a=1, b=beta)) + b = -np.float32(rng.beta(a=1, b=beta)) return a, b @@ -43,6 +44,7 @@ def __init__( augmentation_severity: float = 3, mixing_severity: float = 3, all_ops: bool = True, + seed: int = 12345, ) -> None: """PixMix augmentation class. @@ -53,11 +55,13 @@ def __init__( mixing_severity (float): Severity of mixing. all_ops (bool): Whether to use augmentations included in ImageNet-C. Defaults to True. + seed (int): Seed for random number generator. Defaults to 12345. Note: Default arguments are set to follow original guidelines. """ super().__init__() + self.rng = np.random.default_rng(seed) self.mixing_set = mixing_set self.num_mixing_images = len(mixing_set) self.mixing_iterations = mixing_iterations @@ -80,15 +84,17 @@ def __init__( self.aug_instances.append(aug()) def __call__(self, img: Image.Image) -> np.ndarray: - mixed = self.augment_input(img) if np.random.random() < 0.5 else img + # TODO: Fix + mixed = self.augment_input(img) if self.rng.random() < 0.5 else img - for _ in range(np.random.randint(self.mixing_iterations + 1)): - if np.random.random() < 0.5: + for _ in range(self.rng.integers(low=0, high=self.mixing_iterations)): + if self.rng.random() < 0.5: aug_image_copy = self._augment(img) else: - aug_image_copy = np.random.choice(self.num_mixing_images) + aug_image_copy = self.rng.choice(self.num_mixing_images) - mixed_op = np.random.choice(mixings) + # TODO: Fix + mixed_op = self.rng.choice(mixings) mixed = mixed_op( np.array(mixed), np.array(aug_image_copy), self.mixing_severity ) @@ -96,7 +102,7 @@ def __call__(self, img: Image.Image) -> np.ndarray: return mixed def _augment(self, image: Image.Image) -> np.ndarray: - op = np.random.choice(self.aug_instances) + op = self.rng.choice(self.aug_instances, 1) if op.level_type is int: aug_level = self._sample_int(op.pixmix_max_level) else: @@ -104,7 +110,7 @@ def _augment(self, image: Image.Image) -> np.ndarray: return op(image.copy(), aug_level) def _sample_level(self) -> float: - return np.random.uniform(low=0.1, high=self.augmentation_severity) + return self.rng.uniform(low=0.1, high=self.augmentation_severity) def _sample_int(self, maxval: int) -> int: """Helper method to scale `level` between 0 and maxval. 
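The hunks above migrate from the legacy np.random functions to NumPy's Generator API. One pitfall worth keeping in mind when reviewing such migrations, shown here as a standalone sketch (plain NumPy, no TorchUncertainty code): Generator.integers excludes its high bound by default, so np.random.randint(5) maps to integers(low=0, high=5), not high=4.

import numpy as np

rng = np.random.default_rng(42)  # seedable, unlike the legacy global state

vals = rng.integers(low=0, high=5, size=1000)  # values in {0, ..., 4}; high is exclusive
# legacy equivalent: np.random.randint(5)

lam = rng.beta(a=1.0, b=1.0)        # legacy: np.random.beta(1.0, 1.0)
u = rng.uniform(low=0.0, high=1.0)  # legacy: np.random.random()
idx = rng.choice(10, size=3, replace=False)  # legacy: np.random.choice(10, 3, replace=False)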
diff --git a/torch_uncertainty/utils/trainer.py b/torch_uncertainty/utils/trainer.py index e1b09f7d..45ef8fc3 100644 --- a/torch_uncertainty/utils/trainer.py +++ b/torch_uncertainty/utils/trainer.py @@ -8,7 +8,7 @@ class TUTrainer(Trainer): - def __init__(self, inference_mode: bool = True, **kwargs): + def __init__(self, inference_mode: bool = True, **kwargs) -> None: super().__init__(inference_mode=inference_mode, **kwargs) self.test_loop = TUEvaluationLoop( From ac4919f1e8c1a59d5a5623611508d7863dd1188a Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 28 Jun 2024 13:52:16 +0200 Subject: [PATCH 04/14] :shirt: Restore ruff's conciseness on commits --- .github/workflows/run-tests.yml | 2 +- .pre-commit-config.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index a870c00e..86564ecc 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -70,7 +70,7 @@ jobs: - name: Check style & format if: steps.changed-files-specific.outputs.only_changed != 'true' run: | - python3 -m ruff check torch_uncertainty --no-fix + python3 -m ruff check torch_uncertainty --no-fix --statistics python3 -m ruff format torch_uncertainty --check - name: Test with pytest and compute coverage diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e63bb16f..39a2aff4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -30,7 +30,7 @@ repos: language: python types_or: [python, pyi] require_serial: true - args: [--force-exclude, --fix, --exit-non-zero-on-fix] + args: [--force-exclude, --fix, --exit-non-zero-on-fix, --output-format, concise] exclude: ^auto_tutorials_source/ - id: ruff-format name: ruff-format From 657341abc17addb2c34ec4cd962badc4ad9cb024 Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 8 Jul 2024 15:46:06 +0200 Subject: [PATCH 05/14] :shirt: Improve SWAG docs. --- torch_uncertainty/models/wrappers/swag.py | 26 +++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/torch_uncertainty/models/wrappers/swag.py b/torch_uncertainty/models/wrappers/swag.py index fb12588c..e96bbe87 100644 --- a/torch_uncertainty/models/wrappers/swag.py +++ b/torch_uncertainty/models/wrappers/swag.py @@ -30,6 +30,11 @@ def __init__( posterior after each update. Uses the SWAG posterior estimation only at test time. Otherwise, uses the base model for training. + Call :meth:`update_wrapper` at the end of each epoch. It will update + the SWAG posterior if the current epoch number minus :attr:`cycle_start` + is a multiple of :attr:`cycle_length`. Call :meth:`bn_update` to update + the batchnorm statistics of the current SWAG samples. + Args: model (nn.Module): PyTorch model to be trained. cycle_start (int): Begininning of the first SWAG averaging cycle. @@ -65,13 +70,19 @@ def __init__( self.fit = False self.samples = [] - def eval_forward(self, x: torch.Tensor) -> torch.Tensor: + def eval_forward(self, x: Tensor) -> Tensor: + """Forward pass of the SWAG model when in eval mode.""" if not self.fit: return self.core_model.forward(x) return torch.cat([mod.to(device=x.device)(x) for mod in self.samples]) def initialize_stats(self) -> None: - """Initialize the SWAG dictionary of statistics.""" + """Initialize the SWAG dictionary of statistics. + + For each parameter, we create a mean, squared mean, and covariance + square root. The covariance square root is only used when + `diag_covariance` is False. 
+ """ self.swag_stats = {} for name_p, param in self.core_model.named_parameters(): mean, squared_mean = ( @@ -140,7 +151,7 @@ def update_wrapper(self, epoch: int) -> None: self.need_bn_update = True self.fit = True - def bn_update(self, loader: DataLoader, device) -> None: + def bn_update(self, loader: DataLoader, device: torch.device) -> None: """Update the bachnorm statistics of the current SWAG samples. Args: @@ -214,18 +225,21 @@ def _fullrank_sample( param.data = sample.to(device="cpu", dtype=param.dtype) return new_sample - def _save_to_state_dict(self, destination, prefix, keep_vars): + def _save_to_state_dict(self, destination, prefix: str, keep_vars: bool): + """Add the SWAG statistics to the destination dict.""" super()._save_to_state_dict(destination, prefix, keep_vars) destination |= self.swag_stats def state_dict( self, *args, destination=None, prefix="", keep_vars=False - ) -> dict[str, Tensor]: + ) -> Mapping: + """Add the SWAG statistics to the state dict.""" return self.swag_stats | super().state_dict( *args, destination=destination, prefix=prefix, keep_vars=keep_vars ) - def _load_swag_stats(self, state_dict): + def _load_swag_stats(self, state_dict: Mapping): + """Load the SWAG statistics from the state dict.""" self.swag_stats = { k: v for k, v in state_dict.items() if k in self.swag_stats } From 6d3f248823837866c92ea1b4db551b2fd3ff28a3 Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 15 Jul 2024 12:22:56 +0200 Subject: [PATCH 06/14] :sparkles: Improve dropout wrapper --- auto_tutorials_source/tutorial_mc_dropout.py | 42 +++++----- tests/models/wrappers/test_mc_dropout.py | 31 ++++---- torch_uncertainty/models/lenet.py | 21 ++--- .../models/wrappers/mc_dropout.py | 79 +++++++++++-------- 4 files changed, 92 insertions(+), 81 deletions(-) diff --git a/auto_tutorials_source/tutorial_mc_dropout.py b/auto_tutorials_source/tutorial_mc_dropout.py index e19eee61..de8e829c 100644 --- a/auto_tutorials_source/tutorial_mc_dropout.py +++ b/auto_tutorials_source/tutorial_mc_dropout.py @@ -19,9 +19,9 @@ First, we have to load the following utilities from TorchUncertainty: -- the Trainer from Lightning +- the TUTrainer from TorchUncertainty utils - the datamodule handling dataloaders: MNISTDataModule from torch_uncertainty.datamodules -- the model: LeNet, which lies in torch_uncertainty.models +- the model: lenet from torch_uncertainty.models - the MC Dropout wrapper: mc_dropout, from torch_uncertainty.models.wrappers - the classification training & evaluation routine in the torch_uncertainty.routines - an optimization recipe in the torch_uncertainty.optim_recipes module. @@ -29,10 +29,9 @@ We also need import the neural network utils within `torch.nn`. """ -# %% from pathlib import Path -from lightning.pytorch import Trainer +from torch_uncertainty.utils import TUTrainer from torch import nn from torch_uncertainty.datamodules import MNISTDataModule @@ -42,18 +41,17 @@ from torch_uncertainty.routines import ClassificationRoutine # %% -# 2. Creating the necessary variables -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# 2. Defining the Model and the Trainer +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -# In the following, we will need to define the root of the datasets and the -# logs, and to fake-parse the arguments needed for using the PyTorch Lightning -# Trainer. We also create the datamodule that handles the MNIST dataset, +# In the following, we first create the trainer and instantiate +# the datamodule that handles the MNIST dataset, # dataloaders and transforms. 
We create the model using the -# blueprint from torch_uncertainty.models and we wrap it into mc_dropout. -# -# It is important to add a ``dropout_rate`` argument in your model to use Monte Carlo dropout. +# blueprint from torch_uncertainty.models and we wrap it into an mc_dropout. +# To use the mc_dropout wrapper, **make sure that you use dropout modules** and +# not functionals. Moreover, **they have to be** instantiated in the __init__ method. -trainer = Trainer(accelerator="cpu", max_epochs=2, enable_progress_bar=False) +trainer = TUTrainer(accelerator="cpu", max_epochs=2, enable_progress_bar=False) # datamodule root = Path("data") @@ -63,7 +61,7 @@ model = lenet( in_channels=datamodule.num_channels, num_classes=datamodule.num_classes, - dropout_rate=0.5, + dropout_rate=0.4, ) mc_model = mc_dropout(model, num_estimators=16, last_layer=False) @@ -71,10 +69,10 @@ # %% # 3. The Loss and the Training Routine # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -# This is a classification problem, and we use CrossEntropyLoss as the likelihood. +# This is a classification problem, and we use CrossEntropyLoss as the (negative-log-)likelihood. # We define the training routine using the classification training routine from -# torch_uncertainty.routines.classification. We provide the number of classes -# and channels, the optimizer wrapper, and the dropout rate. +# torch_uncertainty.routines. We provide the number of classes +# the optimization recipe, and tell the routine that our model is an ensemble at evalutation time. routine = ClassificationRoutine( num_classes=datamodule.num_classes, @@ -87,15 +85,19 @@ # %% # 4. Gathering Everything and Training the Model # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# We can now train the model using the trainer. We pass the routine and the datamodule +# to the fit and test methods of the trainer. It will automatically evaluate some uncertainty +# metrics that you will find in the table below. trainer.fit(model=routine, datamodule=datamodule) -trainer.test(model=routine, datamodule=datamodule) +results = trainer.test(model=routine, datamodule=datamodule) # %% # 5. Testing the Model # ~~~~~~~~~~~~~~~~~~~~ # Now that the model is trained, let's test it on MNIST. Don't forget to call -# .eval() to enable dropout at inference. +# .eval() to enable dropout at evaluation and get multiple (here 16) predictions. 
import matplotlib.pyplot as plt import numpy as np @@ -127,7 +129,7 @@ def imshow(img): for j in range(6): values, predicted = torch.max(probs[:, j], 1) print( - f"Predicted digits for the image {j+1}: ", + f"MC-Dropout predictions for the image {j+1}: ", " ".join([str(image_id.item()) for image_id in predicted]), ) diff --git a/tests/models/wrappers/test_mc_dropout.py b/tests/models/wrappers/test_mc_dropout.py index 23a70c6a..bf63d5b1 100644 --- a/tests/models/wrappers/test_mc_dropout.py +++ b/tests/models/wrappers/test_mc_dropout.py @@ -35,31 +35,32 @@ def test_mc_dropout_eval(self): def test_mc_dropout_errors(self): model = dummy_model(10, 5, 0.1) - with pytest.raises(ValueError): + with pytest.raises( + ValueError, match="`num_estimators` must be strictly positive" + ): MCDropout( model=model, num_estimators=-1, last_layer=True, on_batch=True ) - with pytest.raises(ValueError): - MCDropout( - model=model, num_estimators=0, last_layer=False, on_batch=False - ) - dropout_model = mc_dropout(model, 5) - with pytest.raises(TypeError): + with pytest.raises( + TypeError, match="Training mode is expected to be boolean" + ): dropout_model.train(mode=1) - with pytest.raises(TypeError): + with pytest.raises( + TypeError, match="Training mode is expected to be boolean" + ): dropout_model.train(mode=None) - del model.dropout_rate - with pytest.raises(ValueError): + model = dummy_model(10, 5, 0.0) + with pytest.raises( + ValueError, + match="At least one dropout module must have a dropout rate", + ): dropout_model = mc_dropout(model, 5) - model = dummy_model(10, 5, 0.1) - with pytest.raises(ValueError): - dropout_model = mc_dropout(model, None) - model = dummy_model(10, 5, dropout_rate=0) + del model.dropout with pytest.raises(ValueError): - dropout_model = mc_dropout(model, None) + dropout_model = mc_dropout(model, 5) diff --git a/torch_uncertainty/models/lenet.py b/torch_uncertainty/models/lenet.py index 6804c6f9..36b175c0 100644 --- a/torch_uncertainty/models/lenet.py +++ b/torch_uncertainty/models/lenet.py @@ -51,35 +51,28 @@ def __init__( ) if batchnorm: self.norm1 = norm(6) + self.conv_dropout = nn.Dropout2d(p=dropout_rate) self.conv2 = conv2d_layer(6, 16, (5, 5), groups=groups, **layer_args) if batchnorm: self.norm2 = norm(16) self.pooling = nn.AdaptiveAvgPool2d((4, 4)) self.fc1 = linear_layer(256, 120, **layer_args) + self.fc_dropout = nn.Dropout(p=dropout_rate) self.fc2 = linear_layer(120, 84, **layer_args) + self.last_fc_dropout = nn.Dropout(p=dropout_rate) self.fc3 = linear_layer(84, num_classes, **layer_args) def forward(self, x: torch.Tensor) -> torch.Tensor: - out = F.dropout( - self.activation(self.norm1(self.conv1(x))), - p=self.dropout_rate, - ) + out = self.conv_dropout(self.activation(self.norm1(self.conv1(x)))) out = F.max_pool2d(out, 2) - out = F.dropout( - self.activation(self.norm2(self.conv2(out))), - p=self.dropout_rate, - ) + out = self.conv_dropout(self.activation(self.norm2(self.conv2(out)))) out = F.max_pool2d(out, 2) out = self.pooling(out) out = torch.flatten(out, 1) - out = F.dropout( + out = self.fc_dropout( self.activation(self.fc1(out)), - p=self.dropout_rate, - ) - out = F.dropout( - self.activation(self.fc2(out)), - p=self.dropout_rate, ) + out = self.last_fc_dropout(self.activation(self.fc2(out))) return self.fc3(out) diff --git a/torch_uncertainty/models/wrappers/mc_dropout.py b/torch_uncertainty/models/wrappers/mc_dropout.py index 986d23a8..d56a6548 100644 --- a/torch_uncertainty/models/wrappers/mc_dropout.py +++ 
b/torch_uncertainty/models/wrappers/mc_dropout.py
@@ -1,6 +1,8 @@
 import torch
 from torch import Tensor, nn
 
+DROPOUT_MODULES = (nn.Dropout, nn.Dropout1d, nn.Dropout2d, nn.Dropout3d)
+
 
 class MCDropout(nn.Module):
     def __init__(
@@ -14,50 +16,49 @@ def __init__(
 
         Args:
             model (nn.Module): model to wrap
-            num_estimators (int): number of estimators to use
+            num_estimators (int): number of estimators to use during the
+                evaluation
             last_layer (bool): whether to apply dropout to the last layer only.
-            on_batch (bool): Increase the batch_size to perform MC-Dropout.
-                Otherwise in a for loop.
-
-        Warning:
-            Apply dropout using modules and not functional for this wrapper to
-            work as intended.
+            on_batch (bool): Perform the MC-Dropout on the batch-size.
+                Otherwise in a for loop. Useful when constrained in memory.
 
         Warning:
-            The underlying models must have a non-zero :attr:`dropout_rate`
-            attribute.
+            This module will work only if you apply dropout through modules
+            declared in the constructor (__init__).
 
         Warning:
-            For the `last-layer` option to work properly, the model must
-            declare the last dropout at the end of the initialization
-            (i.e. after all the other dropout layers).
+            The `last-layer` option disables the lastly initialized dropout
+            during evaluation: make sure that the last dropout is either
+            functional or a module of its own.
         """
         super().__init__()
-        _dropout_checks(model, num_estimators)
-        self.last_layer = last_layer
-        self.on_batch = on_batch
-        self.core_model = model
-        self.num_estimators = num_estimators
-
-        self.filtered_modules = list(
+        filtered_modules = list(
             filter(
-                lambda m: isinstance(m, nn.Dropout | nn.Dropout2d),
+                lambda m: isinstance(m, DROPOUT_MODULES),
                 model.modules(),
             )
         )
         if last_layer:
-            self.filtered_modules = self.filtered_modules[-1:]
+            filtered_modules = filtered_modules[-1:]
+
+        _dropout_checks(filtered_modules, num_estimators)
+        self.last_layer = last_layer
+        self.on_batch = on_batch
+        self.core_model = model
+        self.num_estimators = num_estimators
+        self.filtered_modules = filtered_modules
 
     def train(self, mode: bool = True) -> nn.Module:
         """Override the default train method to set the training mode of
-        each submodule to be the same as the module itself.
+        each submodule to be the same as the module itself except for the
+        selected dropout modules.
 
         Args:
             mode (bool, optional): whether to set the module to training
                 mode. Defaults to True.
         """
-        if not isinstance(mode, bool):  # coverage: ignore
-            raise TypeError("training mode is expected to be boolean")
+        if not isinstance(mode, bool):
+            raise TypeError("Training mode is expected to be boolean")
         self.training = mode
         for module in self.children():
             module.train(mode)
@@ -69,6 +70,19 @@ def forward(
         self,
         x: Tensor,
     ) -> Tensor:
+        """Forward pass of the model.
+
+        During training, the forward pass is the same as that of the core
+        model. During evaluation, the forward pass is repeated
+        `num_estimators` times either on the batch size or in a for loop
+        depending on :attr:`on_batch`.
+
+        Args:
+            x (Tensor): input tensor of shape (B, ...)
+
+        Returns:
+            Tensor: output tensor of shape (:attr:`num_estimators` * B, ...)
+        """
         if self.training:
             return self.core_model(x)
         if self.on_batch:
@@ -96,7 +110,6 @@ def mc_dropout(
         on_batch (bool): Increase the batch_size to perform MC-Dropout.
             Otherwise in a for loop to reduce memory footprint. Defaults
             to true.
- """ return MCDropout( model=model, @@ -106,17 +119,19 @@ def mc_dropout( ) -def _dropout_checks(model: nn.Module, num_estimators: int) -> None: - if not hasattr(model, "dropout_rate"): +def _dropout_checks( + filtered_modules: list[nn.Module], num_estimators: int +) -> None: + if not filtered_modules: raise ValueError( - "`dropout_rate` must be set in the model to use MC Dropout." + "No dropout module found in the model. " + "Please use `nn.Dropout`-like modules to apply dropout." ) - if model.dropout_rate <= 0.0: + # Check that at least one module has > 0.0 dropout rate + if not any(mod.p > 0.0 for mod in filtered_modules): raise ValueError( - "`dropout_rate` must be strictly positive to use MC Dropout." + "At least one dropout module must have a dropout rate > 0.0." ) - if num_estimators is None: - raise ValueError("`num_estimators` must be set to use MC Dropout.") if num_estimators <= 0: raise ValueError( "`num_estimators` must be strictly positive to use MC Dropout." From 7a077bf423cb574be4520f48eb9d9ba5e5321b66 Mon Sep 17 00:00:00 2001 From: Olivier Date: Tue, 16 Jul 2024 18:35:33 +0200 Subject: [PATCH 07/14] :bug: Fix TU issues with bare install --- torch_uncertainty/datasets/nyu.py | 13 ++++++++++++- .../datasets/regression/uci_regression.py | 7 +++++-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/torch_uncertainty/datasets/nyu.py b/torch_uncertainty/datasets/nyu.py index 90c5736e..3c142a69 100644 --- a/torch_uncertainty/datasets/nyu.py +++ b/torch_uncertainty/datasets/nyu.py @@ -1,9 +1,9 @@ from collections.abc import Callable +from importlib import util from pathlib import Path from typing import Literal import cv2 -import h5py import numpy as np from PIL import Image from torchvision import tv_tensors @@ -14,6 +14,11 @@ download_url, ) +if util.find_spec("h5py"): + import h5py + + h5py_installed = True + class NYUv2(VisionDataset): root: Path @@ -48,6 +53,12 @@ def __init__( max_depth (float): Maximum depth value. Defaults to 10. download (bool): Download dataset if not found. Defaults to False. """ + if not h5py_installed: # coverage: ignore + raise ImportError( + "The h5py library is not installed. Please install" + "torch_uncertainty with the image option:" + """pip install -U "torch_uncertainty[image]".""" + ) super().__init__(Path(root) / "NYUv2", transforms=transforms) self.min_depth = min_depth self.max_depth = max_depth diff --git a/torch_uncertainty/datasets/regression/uci_regression.py b/torch_uncertainty/datasets/regression/uci_regression.py index 59722abc..d6c2fbee 100644 --- a/torch_uncertainty/datasets/regression/uci_regression.py +++ b/torch_uncertainty/datasets/regression/uci_regression.py @@ -5,6 +5,8 @@ if util.find_spec("pandas"): import pandas as pd + pandas_installed = True + import torch import torch.nn.functional as F from torch.utils.data import Dataset @@ -225,9 +227,10 @@ def download(self) -> None: def _make_dataset(self) -> None: """Create dataset from extracted files.""" - if not util.find_spec("pandas"): + if not pandas_installed: raise ImportError( - "Please install pandas manually to use the UCI datasets." 
+ "Please install torch_uncertainty with the tabular option:" + """pip install -U "torch_uncertainty[tabular]".""" ) path = self.root / self.root_appendix / self.dataset_name if self.dataset_name == "boston": From cc4a04d7761e19bf96882286a5561031b0ded85c Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 17 Jul 2024 11:17:45 +0200 Subject: [PATCH 08/14] :bug: Fix AURC in multi-GPU settings --- .../baselines/classification/resnet.py | 4 +++ .../metrics/classification/risk_coverage.py | 30 +++++++++++-------- torch_uncertainty/models/resnet/std.py | 4 ++- 3 files changed, 25 insertions(+), 13 deletions(-) diff --git a/torch_uncertainty/baselines/classification/resnet.py b/torch_uncertainty/baselines/classification/resnet.py index d184cda0..47391ff6 100644 --- a/torch_uncertainty/baselines/classification/resnet.py +++ b/torch_uncertainty/baselines/classification/resnet.py @@ -56,6 +56,7 @@ def __init__( dropout_rate: float = 0.0, mixup_params: dict | None = None, last_layer_dropout: bool = False, + width_multiplier: float = 1.0, groups: int = 1, scale: float | None = None, alpha: int | None = None, @@ -113,6 +114,8 @@ def __init__( mixmode, dist_sim, kernel_tau_max, kernel_tau_std, mixup_alpha, and cutmix_alpha. If None, no augmentations. Defaults to ``None``. + width_multiplier (float, optional): Expansion factor affecting the width + of the estimators. Defaults to ``1.0`` groups (int, optional): Number of groups in convolutions. Defaults to ``1``. scale (float, optional): Expansion factor affecting the width of @@ -168,6 +171,7 @@ def __init__( "conv_bias": False, "dropout_rate": dropout_rate, "groups": groups, + "width_multiplier": width_multiplier, "in_channels": in_channels, "num_classes": num_classes, "style": style, diff --git a/torch_uncertainty/metrics/classification/risk_coverage.py b/torch_uncertainty/metrics/classification/risk_coverage.py index 46c40c6d..264d363f 100644 --- a/torch_uncertainty/metrics/classification/risk_coverage.py +++ b/torch_uncertainty/metrics/classification/risk_coverage.py @@ -3,9 +3,9 @@ import matplotlib.pyplot as plt import numpy as np import torch -from sklearn.metrics import auc from torch import Tensor from torchmetrics.metric import Metric +from torchmetrics.utilities.compute import _auc_compute from torchmetrics.utilities.data import dim_zero_cat from torchmetrics.utilities.plot import _AX_TYPE @@ -83,12 +83,15 @@ def compute(self) -> Tensor: Returns: Tensor: The AURC. 
""" - error_rates = self.partial_compute().cpu() + error_rates = self.partial_compute() num_samples = error_rates.size(0) - x = torch.arange(1, num_samples + 1, device="cpu") / num_samples - return torch.tensor([auc(x, error_rates)], device=self.device) / ( - 1 - 1 / num_samples + if num_samples < 2: + return torch.tensor([float("nan")], device=error_rates.device) + x = ( + torch.arange(1, num_samples + 1, device=error_rates.device) + / num_samples ) + return _auc_compute(x, error_rates) / (1 - 1 / num_samples) def plot( self, @@ -114,15 +117,15 @@ def plot( # Computation of AUSEC error_rates = self.partial_compute().cpu().flip(0) num_samples = error_rates.size(0) - rejection_rates = (np.arange(num_samples) / num_samples) * 100 - x = np.arange(num_samples) / num_samples - aurc = auc(x, error_rates) + x = torch.arange(num_samples) / num_samples + aurc = _auc_compute(x, error_rates).cpu().item() # reduce plot size plot_xs = np.arange(0.01, 100 + 0.01, 0.01) - xs = np.arange(start=1, stop=num_samples + 1, step=1) / num_samples - rejection_rates = np.interp(plot_xs, xs, rejection_rates) + xs = np.arange(start=1, stop=num_samples + 1) / num_samples + + rejection_rates = np.interp(plot_xs, xs, x * 100) error_rates = np.interp(plot_xs, xs, error_rates) # plot @@ -136,7 +139,7 @@ def plot( ax.text( 0.02, 0.95, - f"AUSEC={aurc:.3%}", + f"AUSEC={aurc:.2%}", color="black", ha="left", va="bottom", @@ -219,13 +222,16 @@ def compute(self) -> Tensor: scores = dim_zero_cat(self.scores) errors = dim_zero_cat(self.errors) num_samples = scores.size(0) + if num_samples < 1: + return torch.tensor([float("nan")], device=scores.device) error_rates = _aurc_rejection_rate_compute(scores, errors) admissible_risks = (error_rates > self.risk_threshold) * 1 max_cov_at_risk = admissible_risks.flip(0).argmin() + # check if max_cov_at_risk is really admissible, if not return nan risk = admissible_risks[max_cov_at_risk] if risk > self.risk_threshold: - return torch.tensor([float("nan")]) + return torch.tensor([float("nan")], device=scores.device) return 1 - max_cov_at_risk / num_samples diff --git a/torch_uncertainty/models/resnet/std.py b/torch_uncertainty/models/resnet/std.py index 0e643da7..cdf1303d 100644 --- a/torch_uncertainty/models/resnet/std.py +++ b/torch_uncertainty/models/resnet/std.py @@ -352,6 +352,7 @@ def resnet( arch: int, conv_bias: bool = True, dropout_rate: float = 0.0, + width_multiplier: float = 1.0, groups: int = 1, style: Literal["imagenet", "cifar"] = "imagenet", activation_fn: Callable = relu, @@ -368,6 +369,7 @@ def resnet( conv_bias (bool): Whether to use bias in convolutions. Defaults to ``True``. dropout_rate (float): Dropout rate. Defaults to 0. + width_multiplier (float): Width multiplier. Defaults to 1. groups (int): Number of groups in convolutions. Defaults to 1. style (bool, optional): Whether to use the ImageNet structure. Defaults to ``True``. 
@@ -387,7 +389,7 @@ def resnet( dropout_rate=dropout_rate, groups=groups, style=style, - in_planes=64, + in_planes=int(64 * width_multiplier), activation_fn=activation_fn, normalization_layer=normalization_layer, ) From bfb96f86320aaba3237cc7fa497ee0c49e5ac615 Mon Sep 17 00:00:00 2001 From: Adrien Lafage Date: Wed, 17 Jul 2024 11:26:23 +0200 Subject: [PATCH 09/14] :shirt: Update tutorial_mc_dropout.py --- auto_tutorials_source/tutorial_mc_dropout.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/auto_tutorials_source/tutorial_mc_dropout.py b/auto_tutorials_source/tutorial_mc_dropout.py index de8e829c..4bd8373e 100644 --- a/auto_tutorials_source/tutorial_mc_dropout.py +++ b/auto_tutorials_source/tutorial_mc_dropout.py @@ -72,7 +72,7 @@ # This is a classification problem, and we use CrossEntropyLoss as the (negative-log-)likelihood. # We define the training routine using the classification training routine from # torch_uncertainty.routines. We provide the number of classes -# the optimization recipe, and tell the routine that our model is an ensemble at evalutation time. +# the optimization recipe and tell the routine that our model is an ensemble at evaluation time. routine = ClassificationRoutine( num_classes=datamodule.num_classes, From 1fc442e9d6b0f2ff642ca079fb88be739cfb1218 Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 17 Jul 2024 11:40:27 +0200 Subject: [PATCH 10/14] :ok_hand: Fix tests & review --- torch_uncertainty/models/resnet/batched.py | 4 +++- torch_uncertainty/models/resnet/lpbnn.py | 2 ++ torch_uncertainty/models/resnet/masked.py | 4 +++- torch_uncertainty/models/resnet/mimo.py | 3 ++- torch_uncertainty/models/resnet/packed.py | 4 +++- torch_uncertainty/models/wrappers/mc_dropout.py | 5 ++--- 6 files changed, 15 insertions(+), 7 deletions(-) diff --git a/torch_uncertainty/models/resnet/batched.py b/torch_uncertainty/models/resnet/batched.py index 4b32d3f1..795cd13e 100644 --- a/torch_uncertainty/models/resnet/batched.py +++ b/torch_uncertainty/models/resnet/batched.py @@ -308,6 +308,7 @@ def batched_resnet( num_estimators: int, conv_bias: bool = True, dropout_rate: float = 0, + width_multiplier: float = 1.0, groups: int = 1, style: Literal["imagenet", "cifar"] = "imagenet", normalization_layer: type[nn.Module] = nn.BatchNorm2d, @@ -322,6 +323,7 @@ def batched_resnet( conv_bias (bool): Whether to use bias in convolutions. Defaults to ``True``. dropout_rate (float): Dropout rate. Defaults to 0. + width_multiplier (float): Width multiplier. Defaults to 1. groups (int): Number of groups within each estimator. style (bool, optional): Whether to use the ImageNet structure. Defaults to ``True``. 
@@ -341,6 +343,6 @@ def batched_resnet( dropout_rate=dropout_rate, groups=groups, style=style, - in_planes=64, + in_planes=int(64 * width_multiplier), normalization_layer=normalization_layer, ) diff --git a/torch_uncertainty/models/resnet/lpbnn.py b/torch_uncertainty/models/resnet/lpbnn.py index 83f22f58..36c4d103 100644 --- a/torch_uncertainty/models/resnet/lpbnn.py +++ b/torch_uncertainty/models/resnet/lpbnn.py @@ -322,6 +322,7 @@ def lpbnn_resnet( num_estimators: int, dropout_rate: float = 0, conv_bias: bool = True, + width_multiplier: float = 1.0, groups: int = 1, style: Literal["imagenet", "cifar"] = "imagenet", ) -> _LPBNNResNet: @@ -336,4 +337,5 @@ def lpbnn_resnet( conv_bias=conv_bias, groups=groups, style=style, + in_planes=int(64 * width_multiplier), ) diff --git a/torch_uncertainty/models/resnet/masked.py b/torch_uncertainty/models/resnet/masked.py index 6af599df..45398891 100644 --- a/torch_uncertainty/models/resnet/masked.py +++ b/torch_uncertainty/models/resnet/masked.py @@ -325,6 +325,7 @@ def masked_resnet( arch: int, num_estimators: int, scale: float, + width_multiplier: float = 1.0, groups: int = 1, conv_bias: bool = True, dropout_rate: float = 0, @@ -339,6 +340,7 @@ def masked_resnet( arch (int): The architecture of the ResNet. num_estimators (int): Number of estimators in the ensemble. scale (float): The scale of the mask. + width_multiplier (float): Width multiplier. Defaults to 1. groups (int): Number of groups within each estimator. Defaults to 1. conv_bias (bool): Whether to use bias in convolutions. Defaults to ``True``. @@ -361,6 +363,6 @@ def masked_resnet( conv_bias=conv_bias, dropout_rate=dropout_rate, style=style, - in_planes=64, + in_planes=int(64 * width_multiplier), normalization_layer=normalization_layer, ) diff --git a/torch_uncertainty/models/resnet/mimo.py b/torch_uncertainty/models/resnet/mimo.py index 05a25e14..bf16a933 100644 --- a/torch_uncertainty/models/resnet/mimo.py +++ b/torch_uncertainty/models/resnet/mimo.py @@ -57,6 +57,7 @@ def mimo_resnet( num_estimators: int, conv_bias: bool = True, dropout_rate: float = 0.0, + width_multiplier: float = 1.0, groups: int = 1, style: Literal["imagenet", "cifar"] = "imagenet", normalization_layer: type[nn.Module] = nn.BatchNorm2d, @@ -72,6 +73,6 @@ def mimo_resnet( dropout_rate=dropout_rate, groups=groups, style=style, - in_planes=64, + in_planes=int(64 * width_multiplier), normalization_layer=normalization_layer, ) diff --git a/torch_uncertainty/models/resnet/packed.py b/torch_uncertainty/models/resnet/packed.py index fc9787d5..1cdd2f98 100644 --- a/torch_uncertainty/models/resnet/packed.py +++ b/torch_uncertainty/models/resnet/packed.py @@ -395,6 +395,7 @@ def packed_resnet( alpha: int, gamma: int, conv_bias: bool = True, + width_multiplier: float = 1.0, groups: int = 1, dropout_rate: float = 0, style: Literal["imagenet", "cifar"] = "imagenet", @@ -413,6 +414,7 @@ def packed_resnet( num_estimators (int): Number of estimators in the ensemble. alpha (int): Expansion factor affecting the width of the estimators. gamma (int): Number of groups within each estimator. + width_multiplier (float): Width multiplier. Defaults to 1. groups (int): Number of groups within each estimator. style (bool, optional): Whether to use the ImageNet structure. Defaults to ``True``. 
@@ -436,7 +438,7 @@ def packed_resnet( groups=groups, num_classes=num_classes, style=style, - in_planes=64, + in_planes=int(64 * width_multiplier), normalization_layer=normalization_layer, ) if pretrained: # coverage: ignore diff --git a/torch_uncertainty/models/wrappers/mc_dropout.py b/torch_uncertainty/models/wrappers/mc_dropout.py index d56a6548..6bd92ac0 100644 --- a/torch_uncertainty/models/wrappers/mc_dropout.py +++ b/torch_uncertainty/models/wrappers/mc_dropout.py @@ -1,7 +1,6 @@ import torch from torch import Tensor, nn - -DROPOUT_MODULES = (nn.Dropout, nn.Dropout1d, nn.Dropout2d, nn.Dropout3d) +from torch.nn.modules.dropout import _DropoutNd class MCDropout(nn.Module): @@ -34,7 +33,7 @@ def __init__( super().__init__() filtered_modules = list( filter( - lambda m: isinstance(m, DROPOUT_MODULES), + lambda m: isinstance(m, _DropoutNd), model.modules(), ) ) From 607b73e528500b270f486629c4dd9f54d7d7bcca Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 17 Jul 2024 11:52:10 +0200 Subject: [PATCH 11/14] :ok_hand: Finish fixing review & update ruff --- pyproject.toml | 2 +- torch_uncertainty/datasets/nyu.py | 2 ++ .../datasets/regression/uci_regression.py | 5 +++- torch_uncertainty/post_processing/laplace.py | 2 ++ torch_uncertainty/transforms/corruptions.py | 30 +++++++++++++++++++ 5 files changed, 39 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 37f4068f..c7333169 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,7 +51,7 @@ image = ["scikit-image", "h5py"] tabular = ["pandas"] dev = [ "torch_uncertainty[image]", - "ruff==0.3.4", + "ruff==0.5.2", "pytest-cov", "pre-commit", "pre-commit-hooks", diff --git a/torch_uncertainty/datasets/nyu.py b/torch_uncertainty/datasets/nyu.py index 3c142a69..91987541 100644 --- a/torch_uncertainty/datasets/nyu.py +++ b/torch_uncertainty/datasets/nyu.py @@ -18,6 +18,8 @@ import h5py h5py_installed = True +else: + h5py_installed = False class NYUv2(VisionDataset): diff --git a/torch_uncertainty/datasets/regression/uci_regression.py b/torch_uncertainty/datasets/regression/uci_regression.py index d6c2fbee..e10ff6b5 100644 --- a/torch_uncertainty/datasets/regression/uci_regression.py +++ b/torch_uncertainty/datasets/regression/uci_regression.py @@ -6,6 +6,9 @@ import pandas as pd pandas_installed = True +else: + pandas_installed = False + import torch import torch.nn.functional as F @@ -227,7 +230,7 @@ def download(self) -> None: def _make_dataset(self) -> None: """Create dataset from extracted files.""" - if not pandas_installed: + if not pandas_installed: # coverage: ignore raise ImportError( "Please install torch_uncertainty with the tabular option:" """pip install -U "torch_uncertainty[tabular]".""" diff --git a/torch_uncertainty/post_processing/laplace.py b/torch_uncertainty/post_processing/laplace.py index e0b1203b..9d7e07d6 100644 --- a/torch_uncertainty/post_processing/laplace.py +++ b/torch_uncertainty/post_processing/laplace.py @@ -8,6 +8,8 @@ from laplace import Laplace laplace_installed = True +else: + laplace_installed = False class LaplaceApprox(nn.Module): diff --git a/torch_uncertainty/transforms/corruptions.py b/torch_uncertainty/transforms/corruptions.py index 8b3d674c..4b63c44b 100644 --- a/torch_uncertainty/transforms/corruptions.py +++ b/torch_uncertainty/transforms/corruptions.py @@ -5,6 +5,11 @@ if util.find_spec("cv2"): # coverage: ignore import cv2 + + cv2_installed = True +else: + cv2_installed = False + import numpy as np import torch from PIL import Image @@ -12,6 +17,11 @@ if 
util.find_spec("skimage"): # coverage: ignore from skimage.filters import gaussian from skimage.util import random_noise + + skimage_installed = True +else: + skimage_installed = False + from torch import Tensor, nn from torchvision.transforms import ( InterpolationMode, @@ -80,6 +90,11 @@ def __repr__(self) -> str: class ImpulseNoise(nn.Module): def __init__(self, severity: int) -> None: super().__init__() + if not skimage_installed: # coverage: ignore + raise ImportError( + "Please install torch_uncertainty with the image option:" + """pip install -U "torch_uncertainty[image]".""" + ) if not (0 <= severity <= 5): raise ValueError("Severity must be between 0 and 5.") if not isinstance(severity, int): @@ -128,6 +143,11 @@ def __repr__(self) -> str: class GaussianBlur(nn.Module): def __init__(self, severity: int) -> None: super().__init__() + if not skimage_installed: # coverage: ignore + raise ImportError( + "Please install torch_uncertainty with the image option:" + """pip install -U "torch_uncertainty[image]".""" + ) if not (0 <= severity <= 5): raise ValueError("Severity must be between 0 and 5.") if not isinstance(severity, int): @@ -152,6 +172,11 @@ def __repr__(self) -> str: class GlassBlur(nn.Module): # TODO: batch def __init__(self, severity: int) -> None: super().__init__() + if not skimage_installed or not cv2_installed: # coverage: ignore + raise ImportError( + "Please install torch_uncertainty with the image option:" + """pip install -U "torch_uncertainty[image]".""" + ) if not (0 <= severity <= 5): raise ValueError("Severity must be between 0 and 5.") if not isinstance(severity, int): @@ -203,6 +228,11 @@ def disk(radius: int, alias_blur: float = 0.1, dtype=np.float32): class DefocusBlur(nn.Module): def __init__(self, severity: int) -> None: super().__init__() + if not cv2_installed: # coverage: ignore + raise ImportError( + "Please install torch_uncertainty with the image option:" + """pip install -U "torch_uncertainty[image]".""" + ) if not (0 <= severity <= 5): raise ValueError("Severity must be between 0 and 5.") if not isinstance(severity, int): From 0003cb6ee9b2aad2bcb7c220c97337494dd6b4e7 Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 17 Jul 2024 11:59:56 +0200 Subject: [PATCH 12/14] :white_check_mark: Add RC tests --- tests/metrics/classification/test_risk_coverage.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/metrics/classification/test_risk_coverage.py b/tests/metrics/classification/test_risk_coverage.py index 3506479f..868443d6 100644 --- a/tests/metrics/classification/test_risk_coverage.py +++ b/tests/metrics/classification/test_risk_coverage.py @@ -52,6 +52,9 @@ def test_plot(self) -> None: assert ax.get_ylabel() == "Risk - Error Rate (%)" plt.close(fig) + metric = AURC() + assert metric(torch.zeros(1), torch.zeros(1)).isnan() + metric = AURC() metric.update(scores, values) fig, ax = metric.plot(plot_value=False) @@ -91,6 +94,9 @@ def test_compute_zero(self) -> None: metric = CovAtxRisk(risk_threshold=0.5) assert metric(probs, targets) == 1 + metric = CovAtxRisk(risk_threshold=0.5) + assert metric(torch.zeros(0), torch.zeros(0)).isnan() + def test_errors(self): with pytest.raises( TypeError, match="Expected threshold to be of type float" From d48384d85ed8ba3c267a69a813671335de404245 Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 17 Jul 2024 12:02:14 +0200 Subject: [PATCH 13/14] :shirt: Fix cov. 
ignores --- torch_uncertainty/datasets/nyu.py | 2 +- torch_uncertainty/datasets/regression/uci_regression.py | 2 +- torch_uncertainty/post_processing/laplace.py | 2 +- torch_uncertainty/transforms/corruptions.py | 8 ++++---- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/torch_uncertainty/datasets/nyu.py b/torch_uncertainty/datasets/nyu.py index 91987541..c67a944e 100644 --- a/torch_uncertainty/datasets/nyu.py +++ b/torch_uncertainty/datasets/nyu.py @@ -18,7 +18,7 @@ import h5py h5py_installed = True -else: +else: # coverage: ignore h5py_installed = False diff --git a/torch_uncertainty/datasets/regression/uci_regression.py b/torch_uncertainty/datasets/regression/uci_regression.py index e10ff6b5..0f4be30c 100644 --- a/torch_uncertainty/datasets/regression/uci_regression.py +++ b/torch_uncertainty/datasets/regression/uci_regression.py @@ -6,7 +6,7 @@ import pandas as pd pandas_installed = True -else: +else: # coverage: ignore pandas_installed = False diff --git a/torch_uncertainty/post_processing/laplace.py b/torch_uncertainty/post_processing/laplace.py index 9d7e07d6..fc1d2894 100644 --- a/torch_uncertainty/post_processing/laplace.py +++ b/torch_uncertainty/post_processing/laplace.py @@ -8,7 +8,7 @@ from laplace import Laplace laplace_installed = True -else: +else: # coverage: ignore laplace_installed = False diff --git a/torch_uncertainty/transforms/corruptions.py b/torch_uncertainty/transforms/corruptions.py index 4b63c44b..72bf6700 100644 --- a/torch_uncertainty/transforms/corruptions.py +++ b/torch_uncertainty/transforms/corruptions.py @@ -3,23 +3,23 @@ from importlib import util from io import BytesIO -if util.find_spec("cv2"): # coverage: ignore +if util.find_spec("cv2"): import cv2 cv2_installed = True -else: +else: # coverage: ignore cv2_installed = False import numpy as np import torch from PIL import Image -if util.find_spec("skimage"): # coverage: ignore +if util.find_spec("skimage"): from skimage.filters import gaussian from skimage.util import random_noise skimage_installed = True -else: +else: # coverage: ignore skimage_installed = False from torch import Tensor, nn From 3a709e6ea9c5efd507fd0c63563d5b455da9610f Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 17 Jul 2024 12:14:37 +0200 Subject: [PATCH 14/14] :zap: Bump version --- docs/source/conf.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 107ce618..a9a685d5 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -15,7 +15,7 @@ f"{datetime.now().year!s}, Adrien Lafage and Olivier Laurent" ) author = "Adrien Lafage and Olivier Laurent" -release = "0.2.1" +release = "0.2.1.post0" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/pyproject.toml b/pyproject.toml index c7333169..94e8415e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi" [project] name = "torch_uncertainty" -version = "0.2.1" +version = "0.2.1.post0" authors = [ { name = "ENSTA U2IS", email = "olivier.laurent@ensta-paris.fr" }, { name = "Adrien Lafage", email = "adrienlafage@outlook.com" },
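
Taken together, the MC-Dropout changes in this series define a simple wrapper contract: dropout modules declared in the constructor stay active in eval mode, and each input batch is evaluated num_estimators times. A minimal usage sketch follows (the import path is the one named in the MC-Dropout tutorial above; the small Sequential model and shapes are only illustrative):

import torch
from torch import nn

from torch_uncertainty.models import mc_dropout

net = nn.Sequential(
    nn.Linear(10, 32), nn.ReLU(), nn.Dropout(p=0.4), nn.Linear(32, 5)
)
mc_net = mc_dropout(net, num_estimators=8, last_layer=False, on_batch=True)

mc_net.train()
print(mc_net(torch.randn(2, 10)).shape)  # torch.Size([2, 5]): plain forward pass

mc_net.eval()  # dropout stays active; each input is evaluated 8 times
logits = mc_net(torch.randn(2, 10))
print(logits.shape)  # torch.Size([16, 5]), i.e. (num_estimators * batch, classes)
probs = logits.softmax(-1).view(8, 2, 5).mean(dim=0)  # average over the members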