diff --git a/.github/workflows/build-docs.yml b/.github/workflows/build-docs.yml index f28dcb77..084adb66 100644 --- a/.github/workflows/build-docs.yml +++ b/.github/workflows/build-docs.yml @@ -40,7 +40,7 @@ jobs: - name: Install dependencies run: | - python3 -m pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cpu + python3 -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu python3 -m pip install .[image,dev,docs] - name: Sphinx build diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index 86564ecc..6c6ccc62 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -64,7 +64,7 @@ jobs: - name: Install dependencies if: steps.changed-files-specific.outputs.only_changed != 'true' run: | - python3 -m pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cpu + python3 -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu python3 -m pip install .[all] - name: Check style & format diff --git a/auto_tutorials_source/tutorial_bayesian.py b/auto_tutorials_source/tutorial_bayesian.py index d50c7bf7..939e83c1 100644 --- a/auto_tutorials_source/tutorial_bayesian.py +++ b/auto_tutorials_source/tutorial_bayesian.py @@ -20,7 +20,7 @@ Training a Bayesian LeNet using TorchUncertainty models and Lightning --------------------------------------------------------------------- -In this part, we train a bayesian LeNet, based on the model and routines already implemented in TU. +In this part, we train a Bayesian LeNet, based on the model and routines already implemented in TU. 1. Loading the utilities ~~~~~~~~~~~~~~~~~~~~~~~~ @@ -30,13 +30,13 @@ - the Trainer from Lightning - the model: bayesian_lenet, which lies in the torch_uncertainty.model - the classification training routine from torch_uncertainty.routines -- the bayesian objective: the ELBOLoss, which lies in the torch_uncertainty.losses file +- the Bayesian objective: the ELBOLoss, which lies in the torch_uncertainty.losses file - the datamodule that handles dataloaders: MNISTDataModule from torch_uncertainty.datamodules -We will also need to define an optimizer using torch.optim, the -neural network utils from torch.nn, as well as the partial util to provide -the modified default arguments for the ELBO loss. +We will also need to define an optimizer using torch.optim and Pytorch's +neural network utils from torch.nn. """ +# %% from pathlib import Path from lightning.pytorch import Trainer @@ -94,7 +94,7 @@ def optim_lenet(model: nn.Module): loss = ELBOLoss( model=model, inner_loss=nn.CrossEntropyLoss(), - kl_weight=1 / 50000, + kl_weight=1 / 10000, num_samples=3, ) diff --git a/auto_tutorials_source/tutorial_corruption.py b/auto_tutorials_source/tutorial_corruption.py index 01b6a2d9..4eecd6ff 100644 --- a/auto_tutorials_source/tutorial_corruption.py +++ b/auto_tutorials_source/tutorial_corruption.py @@ -2,15 +2,16 @@ Corrupting Images with TorchUncertainty to Benchmark Robustness =============================================================== -This tutorial shows the impact of the different corruptions available in the -TorchUncertainty library. These corruptions were first proposed in the paper +This tutorial shows the impact of the different corruption transforms available in the +TorchUncertainty library. These corruption transforms were first proposed in the paper Benchmarking Neural Network Robustness to Common Corruptions and Perturbations by Dan Hendrycks and Thomas Dietterich. For this tutorial, we will only load the corruption transforms available in -torch_uncertainty.transforms.corruptions. We also need to load utilities from +torch_uncertainty.transforms.corruption. We also need to load utilities from torchvision and matplotlib. """ +# %% from torchvision.datasets import CIFAR10 from torchvision.transforms import Compose, ToTensor, Resize @@ -60,7 +61,7 @@ def show_images(transforms): # %% # 1. Noise Corruptions # ~~~~~~~~~~~~~~~~~~~~ -from torch_uncertainty.transforms.corruptions import ( +from torch_uncertainty.transforms.corruption import ( GaussianNoise, ShotNoise, ImpulseNoise, @@ -79,7 +80,7 @@ def show_images(transforms): # %% # 2. Blur Corruptions # ~~~~~~~~~~~~~~~~~~~~ -from torch_uncertainty.transforms.corruptions import ( +from torch_uncertainty.transforms.corruption import ( GaussianBlur, GlassBlur, DefocusBlur, @@ -96,7 +97,7 @@ def show_images(transforms): # %% # 3. Other Corruptions # ~~~~~~~~~~~~~~~~~~~~ -from torch_uncertainty.transforms.corruptions import ( +from torch_uncertainty.transforms.corruption import ( JPEGCompression, Pixelate, Frost, diff --git a/auto_tutorials_source/tutorial_der_cubic.py b/auto_tutorials_source/tutorial_der_cubic.py index 96d72375..a30b49d5 100644 --- a/auto_tutorials_source/tutorial_der_cubic.py +++ b/auto_tutorials_source/tutorial_der_cubic.py @@ -29,6 +29,7 @@ We also need to define an optimizer using torch.optim and the neural network utils within torch.nn. """ +# %% import torch from lightning.pytorch import Trainer from lightning import LightningDataModule diff --git a/auto_tutorials_source/tutorial_evidential_classification.py b/auto_tutorials_source/tutorial_evidential_classification.py index dccda568..cd124f5d 100644 --- a/auto_tutorials_source/tutorial_evidential_classification.py +++ b/auto_tutorials_source/tutorial_evidential_classification.py @@ -24,6 +24,7 @@ We also need to define an optimizer using torch.optim, the neural network utils within torch.nn. """ +# %% from pathlib import Path import torch diff --git a/auto_tutorials_source/tutorial_from_de_to_pe.py b/auto_tutorials_source/tutorial_from_de_to_pe.py index 24933566..55de3735 100644 --- a/auto_tutorials_source/tutorial_from_de_to_pe.py +++ b/auto_tutorials_source/tutorial_from_de_to_pe.py @@ -30,6 +30,7 @@ The dataset is automatically downloaded using torchvision. We then visualize a few images to see a bit what we are working with. """ # Create the transforms for the images +# %% import torch import torchvision.transforms as T @@ -241,7 +242,7 @@ def optim_recipe(model, lr_mult: float = 1.0): # We have put the pre-trained models on Hugging Face that you can download with the utility function # "hf_hub_download" imported just below. These models are trained for 75 epochs and are therefore not # comparable to the all the other models trained in this notebook. The pretrained models can be seen -# `here `_ and TorchUncertainty's are `here `_. +# on `HuggingFace `_ and TorchUncertainty's are `here `_. from torch_uncertainty.utils.hub import hf_hub_download @@ -297,7 +298,7 @@ def optim_recipe(model, lr_mult: float = 1.0): # This modification is particularly useful when the ensemble size is large, as it is often the case in practice. # # We will need to update the model and replace the layers with their Packed equivalents. You can find the -# documentation of the Packed-Linear layer `here `_, +# documentation of the Packed-Linear layer using this `link `_, # and the Packed-Conv2D, `here `_. import torch diff --git a/auto_tutorials_source/tutorial_mc_batch_norm.py b/auto_tutorials_source/tutorial_mc_batch_norm.py index a8bed883..886ed9cf 100644 --- a/auto_tutorials_source/tutorial_mc_batch_norm.py +++ b/auto_tutorials_source/tutorial_mc_batch_norm.py @@ -22,6 +22,7 @@ We also need import the neural network utils within `torch.nn`. """ +# %% from pathlib import Path from lightning import Trainer @@ -98,7 +99,7 @@ # 6. Testing the Model # ~~~~~~~~~~~~~~~~~~~~ # Now that the model is trained, let's test it on MNIST. Don't forget to call -# .eval() to enable Monte Carlo batch normalization at inference. +# .eval() to enable Monte Carlo batch normalization at evaluation (sometimes called inference). # In this tutorial, we plot the most uncertain images, i.e. the images for which # the variance of the predictions is the highest. # Please note that we apply a reshape to the logits to determine the dimension corresponding to the ensemble diff --git a/auto_tutorials_source/tutorial_mc_dropout.py b/auto_tutorials_source/tutorial_mc_dropout.py index 4bd8373e..b8f01fb0 100644 --- a/auto_tutorials_source/tutorial_mc_dropout.py +++ b/auto_tutorials_source/tutorial_mc_dropout.py @@ -28,7 +28,7 @@ We also need import the neural network utils within `torch.nn`. """ - +# %% from pathlib import Path from torch_uncertainty.utils import TUTrainer diff --git a/auto_tutorials_source/tutorial_scaler.py b/auto_tutorials_source/tutorial_scaler.py index fdbfc469..ceaaa036 100644 --- a/auto_tutorials_source/tutorial_scaler.py +++ b/auto_tutorials_source/tutorial_scaler.py @@ -25,6 +25,7 @@ If you use the classification routine, the plots will be automatically available in the tensorboard logs if you use the `log_plots` flag. """ +# %% from torch_uncertainty.datamodules import CIFAR100DataModule from torch_uncertainty.metrics import CalibrationError from torch_uncertainty.models.resnet import resnet diff --git a/docs/source/api.rst b/docs/source/api.rst index ed5e07ce..20d1a0ea 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -156,8 +156,6 @@ Models Wrappers ^^^^^^^^ - - Functions """"""""" @@ -188,30 +186,82 @@ Metrics Classification ^^^^^^^^^^^^^^ - .. currentmodule:: torch_uncertainty.metrics.classification +Proper Scores +""""""""""""" + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: class.rst + + BrierScore + CategoricalNLL + +Out-of-Distribution Detection +""""""""""""""""""""""""""""" + .. autosummary:: :toctree: generated/ :nosignatures: :template: class.rst AURC - AUSE + FPRx FPR95 + + +Selective Classification +"""""""""""""""""""""""" + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: class.rst + + AUGRC + RiskAtxCov + RiskAt80Cov + CovAtxRisk + CovAt5Risk + +Calibration +""""""""""" + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: class.rst + AdaptiveCalibrationError - BrierScore CalibrationError - CategoricalNLL - CovAt5Risk + +Diversity +""""""""" + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: class.rst + Disagreement Entropy - GroupingLoss - MeanIntersectionOverUnion MutualInformation - RiskAt80Cov VariationRatio + +Others +"""""" + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: class.rst + + AUSE + GroupingLoss + Regression ^^^^^^^^^^ @@ -232,6 +282,18 @@ Regression SILog ThresholdAccuracy +Segmentation +^^^^^^^^^^^^ + +.. currentmodule:: torch_uncertainty.metrics.classification + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: class.rst + + MeanIntersectionOverUnion + Losses ------ diff --git a/docs/source/cli_guide.rst b/docs/source/cli_guide.rst index 24129fe6..0b888ea9 100644 --- a/docs/source/cli_guide.rst +++ b/docs/source/cli_guide.rst @@ -89,7 +89,7 @@ This command will display the available subcommands of the CLI tool. fit Runs the full optimization routine. validate Perform one evaluation epoch over the validation set. test Perform one evaluation epoch over the test set. - predict Run inference on your data. + predict Run evaluation on your data. You can execute whichever subcommand you like and set up all your hyperparameters directly using the command line diff --git a/docs/source/conf.py b/docs/source/conf.py index a9a685d5..418b398a 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -15,7 +15,7 @@ f"{datetime.now().year!s}, Adrien Lafage and Olivier Laurent" ) author = "Adrien Lafage and Olivier Laurent" -release = "0.2.1.post0" +release = "0.2.2" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/docs/source/references.rst b/docs/source/references.rst index 89829c16..490e080a 100644 --- a/docs/source/references.rst +++ b/docs/source/references.rst @@ -243,6 +243,19 @@ For Laplace Approximation, consider citing: * Authors: *Erik Daxberger, Agustinus Kristiadi, Alexander Immer, Runa Eschenhagen, Matthias Bauer, and Philipp Hennig* * Paper: `NeurIPS 2021 `__. +Losses +------ + +Conflictual Loss +^^^^^^^^^^^^^^^^ + +For the conflictual loss, consider citing: + +**On the Calibration of Epistemic Uncertainty: Principles, Paradoxes and Conflictual Loss** + +* Authors: *Mohammed Fellaji, Frédéric Pennerath, Brieuc Conan-Guez, and Miguel Couceiro* +* Paper: `ArXiv 2024 `__. + +Area Under the Generalized Risk-Coverage curve +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +For the area under the generalized risk-coverage curve, consider citing: + +**Overcoming Common Flaws in the Evaluation of Selective Classification Systems** + +* Authors: *Jeremias Traub, Till J. Bungert, Carsten T. Lüth, Michael Baumgartner, Klaus H. Maier-Hein, Lena Maier-Hein, and Paul F Jaeger* +* Paper: `ArXiv `__. + + Grouping Loss ^^^^^^^^^^^^^ diff --git a/experiments/classification/mnist/configs/lenet_swa.yaml b/experiments/classification/mnist/configs/lenet_swa.yaml index fa3eb77d..2274bdb5 100644 --- a/experiments/classification/mnist/configs/lenet_swa.yaml +++ b/experiments/classification/mnist/configs/lenet_swa.yaml @@ -57,7 +57,7 @@ optimizer: weight_decay: 5e-4 nesterov: true lr_scheduler: - class_path: torch_uncertainty.optim_recipes.FullSWALR + class_path: torch_uncertainty.optim_recipes.CosineSWALR init_args: milestone: 20 swa_lr: 0.01 diff --git a/experiments/classification/mnist/configs/lenet_swag.yaml b/experiments/classification/mnist/configs/lenet_swag.yaml index 292b49f0..ddff0067 100644 --- a/experiments/classification/mnist/configs/lenet_swag.yaml +++ b/experiments/classification/mnist/configs/lenet_swag.yaml @@ -57,7 +57,7 @@ optimizer: weight_decay: 5e-4 nesterov: true lr_scheduler: - class_path: torch_uncertainty.optim_recipes.FullSWALR + class_path: torch_uncertainty.optim_recipes.CosineSWALR init_args: milestone: 10 swa_lr: 0.01 diff --git a/pyproject.toml b/pyproject.toml index 94e8415e..60ee19af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi" [project] name = "torch_uncertainty" -version = "0.2.1.post0" +version = "0.2.2" authors = [ { name = "ENSTA U2IS", email = "olivier.laurent@ensta-paris.fr" }, { name = "Adrien Lafage", email = "adrienlafage@outlook.com" }, @@ -32,7 +32,7 @@ classifiers = [ ] dependencies = [ "timm", - "lightning[pytorch-extra]", + "lightning[pytorch-extra]>=2.0", "torchvision>=0.16", "tensorboard", "einops", diff --git a/tests/datamodules/classification/test_cifar10.py b/tests/datamodules/classification/test_cifar10.py index df12f214..64944684 100644 --- a/tests/datamodules/classification/test_cifar10.py +++ b/tests/datamodules/classification/test_cifar10.py @@ -13,7 +13,7 @@ def test_cifar10_main(self): dm = CIFAR10DataModule(root="./data/", batch_size=128, cutout=16) assert dm.dataset == CIFAR10 - assert isinstance(dm.train_transform.transforms[2], Cutout) + assert isinstance(dm.train_transform.transforms[1], Cutout) dm.dataset = DummyClassificationDataset dm.ood_dataset = DummyClassificationDataset diff --git a/tests/datamodules/classification/test_cifar100.py b/tests/datamodules/classification/test_cifar100.py index e24af243..f2f00aa2 100644 --- a/tests/datamodules/classification/test_cifar100.py +++ b/tests/datamodules/classification/test_cifar100.py @@ -13,7 +13,7 @@ def test_cifar100(self): dm = CIFAR100DataModule(root="./data/", batch_size=128, cutout=16) assert dm.dataset == CIFAR100 - assert isinstance(dm.train_transform.transforms[2], Cutout) + assert isinstance(dm.train_transform.transforms[1], Cutout) dm.dataset = DummyClassificationDataset dm.ood_dataset = DummyClassificationDataset diff --git a/tests/datamodules/classification/test_mnist.py b/tests/datamodules/classification/test_mnist.py index f52c9abf..ba30fad3 100644 --- a/tests/datamodules/classification/test_mnist.py +++ b/tests/datamodules/classification/test_mnist.py @@ -20,7 +20,7 @@ def test_mnist_cutout(self): ) assert dm.dataset == MNIST - assert isinstance(dm.train_transform.transforms[0], Cutout) + assert isinstance(dm.train_transform.transforms[1], Cutout) dm = MNISTDataModule( root="./data/", @@ -29,7 +29,7 @@ def test_mnist_cutout(self): cutout=0, val_split=0, ) - assert isinstance(dm.train_transform.transforms[0], nn.Identity) + assert isinstance(dm.train_transform.transforms[1], nn.Identity) with pytest.raises(ValueError): MNISTDataModule(root="./data/", batch_size=128, ood_ds="other") diff --git a/tests/layers/test_filter_response_norm.py b/tests/layers/test_norm.py similarity index 83% rename from tests/layers/test_filter_response_norm.py rename to tests/layers/test_norm.py index e1f58eb1..89fe23eb 100644 --- a/tests/layers/test_filter_response_norm.py +++ b/tests/layers/test_norm.py @@ -1,6 +1,7 @@ import pytest import torch +from torch_uncertainty.layers.channel_layer_norm import ChannelLayerNorm from torch_uncertainty.layers.filter_response_norm import ( FilterResponseNorm1d, FilterResponseNorm2d, @@ -62,3 +63,14 @@ def test_errors(self): layer2d(torch.randn(1, 1, 1, 1, 20)) with pytest.raises(ValueError): layer3d(torch.randn(1, 1, 1, 1, 1, 20)) + + +class TestChannelLayerNorm: + """Testing the FRN2d layer.""" + + def test_main(self): + """Test initialization.""" + cln = ChannelLayerNorm(1) + cln(torch.randn(1, 1, 4, 4)) + cln = ChannelLayerNorm(18) + cln(torch.randn(1, 18, 2, 3)) diff --git a/tests/losses/__init__.py b/tests/losses/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/losses/test_bayesian.py b/tests/losses/test_bayesian.py new file mode 100644 index 00000000..9e43c1f3 --- /dev/null +++ b/tests/losses/test_bayesian.py @@ -0,0 +1,70 @@ +import pytest +import torch +from torch import nn, optim + +from torch_uncertainty.layers.bayesian import BayesLinear +from torch_uncertainty.losses import ELBOLoss +from torch_uncertainty.routines import RegressionRoutine + + +class TestELBOLoss: + """Testing the ELBOLoss class.""" + + def test_main(self): + model = BayesLinear(1, 1) + criterion = nn.BCEWithLogitsLoss() + loss = ELBOLoss(model, criterion, kl_weight=1e-5, num_samples=1) + loss(model(torch.randn(1, 1)), torch.randn(1, 1)) + + model = nn.Linear(1, 1) + criterion = nn.BCEWithLogitsLoss() + + ELBOLoss(None, criterion, kl_weight=1e-5, num_samples=1) + loss = ELBOLoss(model, criterion, kl_weight=1e-5, num_samples=1) + loss(model(torch.randn(1, 1)), torch.randn(1, 1)) + + def test_training_step(self): + model = BayesLinear(10, 4) + criterion = nn.MSELoss() + loss = ELBOLoss(model, criterion, kl_weight=1 / 50000, num_samples=3) + + routine = RegressionRoutine( + probabilistic=False, + output_dim=4, + model=model, + loss=loss, + optim_recipe=optim.Adam( + model.parameters(), + lr=5e-4, + weight_decay=0, + ), + ) + + inputs = torch.randn(1, 10) + targets = torch.randn(1, 4) + routine.training_step((inputs, targets), 0) + + def test_failures(self): + model = BayesLinear(1, 1) + criterion = nn.BCEWithLogitsLoss() + + with pytest.raises( + TypeError, match="The inner_loss should be an instance of a class." + ): + ELBOLoss(model, nn.BCEWithLogitsLoss, kl_weight=1, num_samples=1) + + with pytest.raises( + ValueError, match="The KL weight should be non-negative. Got " + ): + ELBOLoss(model, criterion, kl_weight=-1, num_samples=1) + + with pytest.raises( + ValueError, + match="The number of samples should not be lower than 1.", + ): + ELBOLoss(model, criterion, kl_weight=1, num_samples=-1) + + with pytest.raises( + TypeError, match="The number of samples should be an integer. " + ): + ELBOLoss(model, criterion, kl_weight=1e-5, num_samples=1.5) diff --git a/tests/losses/test_classification.py b/tests/losses/test_classification.py new file mode 100644 index 00000000..f5bb2400 --- /dev/null +++ b/tests/losses/test_classification.py @@ -0,0 +1,108 @@ +import pytest +import torch + +from torch_uncertainty.losses import ( + ConfidencePenaltyLoss, + ConflictualLoss, + DECLoss, +) + + +class TestDECLoss: + """Testing the DECLoss class.""" + + def test_main(self): + loss = DECLoss( + loss_type="mse", reg_weight=1e-2, annealing_step=1, reduction="sum" + ) + loss(torch.tensor([[0.0, 0.0]]), torch.tensor([0]), current_epoch=1) + loss = DECLoss(loss_type="mse", reg_weight=1e-2, annealing_step=1) + loss(torch.tensor([[0.0, 0.0]]), torch.tensor([0]), current_epoch=0) + loss = DECLoss(loss_type="log", reg_weight=1e-2, reduction="none") + loss(torch.tensor([[0.0, 0.0]]), torch.tensor([0])) + loss = DECLoss(loss_type="digamma") + loss(torch.tensor([[0.0, 0.0]]), torch.tensor([0])) + + def test_failures(self): + with pytest.raises( + ValueError, + match="The regularization weight should be non-negative, but got", + ): + DECLoss(reg_weight=-1) + + with pytest.raises( + ValueError, match="The annealing step should be positive, but got " + ): + DECLoss(annealing_step=0) + + loss = DECLoss(annealing_step=10) + with pytest.raises(ValueError): + loss( + torch.tensor([[0.0, 0.0]]), + torch.tensor([0]), + current_epoch=None, + ) + + with pytest.raises( + ValueError, match=" is not a valid value for reduction." + ): + DECLoss(reduction="median") + + with pytest.raises( + ValueError, match="is not a valid value for mse/log/digamma loss." + ): + DECLoss(loss_type="regression") + + +class TestConfidencePenaltyLoss: + """Testing the ConfidencePenaltyLoss class.""" + + def test_main(self): + loss = ConfidencePenaltyLoss(reg_weight=1e-2, reduction="sum") + loss(torch.tensor([[0.0, 0.0]]), torch.tensor([0])) + loss = ConfidencePenaltyLoss(reg_weight=1e-2) + loss(torch.tensor([[0.0, 0.0]]), torch.tensor([0])) + loss = ConfidencePenaltyLoss(reg_weight=1e-2, reduction=None) + loss(torch.tensor([[0.0, 0.0]]), torch.tensor([0])) + + def test_failures(self): + with pytest.raises( + ValueError, + match="The regularization weight should be non-negative, but got", + ): + ConfidencePenaltyLoss(reg_weight=-1) + + with pytest.raises( + ValueError, match="is not a valid value for reduction." + ): + ConfidencePenaltyLoss(reduction="median") + + with pytest.raises( + ValueError, + match="The epsilon value should be non-negative, but got", + ): + ConfidencePenaltyLoss(eps=-1) + + +class TestConflictualLoss: + """Testing the ConflictualLoss class.""" + + def test_main(self): + loss = ConflictualLoss(reg_weight=1e-2, reduction="sum") + loss(torch.tensor([[0.0, 0.0]]), torch.tensor([0])) + loss = ConflictualLoss(reg_weight=1e-2) + loss(torch.tensor([[0.0, 0.0]]), torch.tensor([0])) + loss = ConflictualLoss(reg_weight=1e-2, reduction=None) + loss(torch.tensor([[0.0, 0.0]]), torch.tensor([0])) + + def test_failures(self): + with pytest.raises( + ValueError, + match="The regularization weight should be non-negative, but got", + ): + ConflictualLoss(reg_weight=-1) + + with pytest.raises( + ValueError, match="is not a valid value for reduction." + ): + ConflictualLoss(reduction="median") diff --git a/tests/losses/test_regression.py b/tests/losses/test_regression.py new file mode 100644 index 00000000..41f413a1 --- /dev/null +++ b/tests/losses/test_regression.py @@ -0,0 +1,119 @@ +import math + +import pytest +import torch +from torch.distributions import Normal + +from torch_uncertainty.layers.distributions import NormalInverseGamma +from torch_uncertainty.losses import ( + BetaNLL, + DERLoss, + DistributionNLLLoss, +) + + +class TestDistributionNLL: + """Testing the DistributionNLLLoss class.""" + + def test_sum(self): + loss = DistributionNLLLoss(reduction="sum") + dist = Normal(0, 1) + loss(dist, torch.tensor([0.0])) + + +class TestDERLoss: + """Testing the DERLoss class.""" + + def test_main(self): + loss = DERLoss(reg_weight=1e-2) + layer = NormalInverseGamma + inputs = layer( + torch.ones(1), torch.ones(1), torch.ones(1), torch.ones(1) + ) + targets = torch.tensor([[1.0]], dtype=torch.float32) + + assert loss(inputs, targets) == pytest.approx(2 * math.log(2)) + + loss = DERLoss( + reg_weight=1e-2, + reduction="sum", + ) + inputs = layer( + torch.ones((2, 1)), + torch.ones((2, 1)), + torch.ones((2, 1)), + torch.ones((2, 1)), + ) + + assert loss( + inputs, + targets, + ) == pytest.approx(4 * math.log(2)) + + loss = DERLoss( + reg_weight=1e-2, + reduction="none", + ) + + assert loss( + inputs, + targets, + ) == pytest.approx([2 * math.log(2), 2 * math.log(2)]) + + def test_failures(self): + with pytest.raises( + ValueError, + match="The regularization weight should be non-negative, but got ", + ): + DERLoss(reg_weight=-1) + + with pytest.raises( + ValueError, match="is not a valid value for reduction." + ): + DERLoss(reg_weight=1.0, reduction="median") + + +class TestBetaNLL: + """Testing the BetaNLL class.""" + + def test_main(self): + loss = BetaNLL(beta=0.5) + + inputs = torch.tensor([[1.0, 1.0]], dtype=torch.float32) + targets = torch.tensor([[1.0]], dtype=torch.float32) + + assert loss(*inputs.split(1, dim=-1), targets) == 0 + + loss = BetaNLL( + beta=0.5, + reduction="sum", + ) + + assert ( + loss( + *inputs.repeat(2, 1).split(1, dim=-1), + targets.repeat(2, 1), + ) + == 0 + ) + + loss = BetaNLL( + beta=0.5, + reduction="none", + ) + + assert loss( + *inputs.repeat(2, 1).split(1, dim=-1), + targets.repeat(2, 1), + ) == pytest.approx([0.0, 0.0]) + + def test_failures(self): + with pytest.raises( + ValueError, match="The beta parameter should be in range " + ): + BetaNLL(beta=-1) + + with pytest.raises( + ValueError, match="is not a valid value for reduction." + ): + BetaNLL(beta=1.0, reduction="median") diff --git a/tests/metrics/classification/test_calibration.py b/tests/metrics/classification/test_calibration.py index 3ad5e3f3..cf94da78 100644 --- a/tests/metrics/classification/test_calibration.py +++ b/tests/metrics/classification/test_calibration.py @@ -58,6 +58,7 @@ def test_main(self) -> None: ace = AdaptiveCalibrationError( task="binary", num_bins=2, norm="l1", validate_args=True ) + ace = AdaptiveCalibrationError( task="binary", num_bins=2, norm="l1", validate_args=False ) @@ -112,7 +113,20 @@ def test_main(self) -> None: ), torch.as_tensor([0, 0, 0, 1, 1]), ) - assert ace.compute().item() ** 2 == pytest.approx((0.8 - 0.5) ** 2) + assert ace.compute().item() == pytest.approx(0.8 - 0.5) + + ace = AdaptiveCalibrationError(task="binary", num_bins=3, norm="l2") + ece = CalibrationError(task="binary", num_bins=3, norm="l2") + + ace.update( + torch.as_tensor([0.12, 0.26, 0.70, 0.71, 0.91, 0.92]), + torch.as_tensor([0, 1, 0, 0, 1, 1]), + ) + ece.update( + torch.as_tensor([0.12, 0.26, 0.70, 0.71, 0.91, 0.92]), + torch.as_tensor([0, 1, 0, 0, 1, 1]), + ) + assert ace.compute().item() > ece.compute().item() def test_errors(self) -> None: with pytest.raises(TypeError, match="is expected to be `int`"): diff --git a/tests/metrics/classification/test_fpr95.py b/tests/metrics/classification/test_fpr95.py index 99bb0dc3..3c10fd01 100644 --- a/tests/metrics/classification/test_fpr95.py +++ b/tests/metrics/classification/test_fpr95.py @@ -32,6 +32,14 @@ def test_compute_one(self): res = metric.compute() assert res == 1 + def test_compute_nan(self): + metric = FPR95(pos_label=1) + metric.update( + torch.as_tensor([0.1] * 50 + [0.4] * 50), torch.as_tensor([0] * 100) + ) + res = metric.compute() + assert torch.isnan(res).all() + def test_error(self): with pytest.raises(ValueError): FPRx(recall_level=1.2, pos_label=1) diff --git a/tests/metrics/classification/test_risk_coverage.py b/tests/metrics/classification/test_risk_coverage.py index 868443d6..63e82f43 100644 --- a/tests/metrics/classification/test_risk_coverage.py +++ b/tests/metrics/classification/test_risk_coverage.py @@ -40,6 +40,12 @@ def test_compute_multiclass(self) -> None: value = (0 * 0.4 + 0.25 * 0.2 / 2 + 0.25 * 0.2 + 0.15 * 0.2 / 2) / 0.8 assert metric(probs, targets).item() == pytest.approx(value) + def test_compute_nan(self) -> None: + probs = torch.as_tensor([[0.1, 0.9]]) + targets = torch.as_tensor([1]).long() + metric = AURC() + assert torch.isnan(metric(probs, targets)).all() + def test_plot(self) -> None: scores = torch.as_tensor([0.2, 0.1, 0.5, 0.3, 0.4]) values = torch.as_tensor([0.1, 0.2, 0.3, 0.4, 0.5]) diff --git a/tests/models/test_resnets.py b/tests/models/test_resnets.py index 44c2cd3c..31677f51 100644 --- a/tests/models/test_resnets.py +++ b/tests/models/test_resnets.py @@ -9,6 +9,7 @@ packed_resnet, resnet, ) +from torch_uncertainty.models.resnet.utils import get_resnet_num_blocks class TestResnet: @@ -21,6 +22,11 @@ def test_main(self): model(torch.randn(1, 1, 32, 32)) model.feats_forward(torch.randn(1, 1, 32, 32)) + get_resnet_num_blocks(44) + get_resnet_num_blocks(56) + get_resnet_num_blocks(110) + get_resnet_num_blocks(1202) + def test_mc_dropout(self): resnet(1, 10, arch=20, conv_bias=False, style="cifar") model = resnet(1, 10, arch=50).eval() diff --git a/tests/test_losses.py b/tests/test_losses.py deleted file mode 100644 index f368e6cc..00000000 --- a/tests/test_losses.py +++ /dev/null @@ -1,184 +0,0 @@ -import math - -import pytest -import torch -from torch import nn -from torch.distributions import Normal - -from torch_uncertainty.layers.bayesian import BayesLinear -from torch_uncertainty.layers.distributions import NormalInverseGamma -from torch_uncertainty.losses import ( - BetaNLL, - DECLoss, - DERLoss, - DistributionNLLLoss, - ELBOLoss, -) - - -class TestDistributionNLL: - """Testing the DistributionNLLLoss class.""" - - def test_sum(self): - loss = DistributionNLLLoss(reduction="sum") - dist = Normal(0, 1) - loss(dist, torch.tensor([0.0])) - - -class TestELBOLoss: - """Testing the ELBOLoss class.""" - - def test_main(self): - model = BayesLinear(1, 1) - criterion = nn.BCEWithLogitsLoss() - loss = ELBOLoss(model, criterion, kl_weight=1e-5, num_samples=1) - loss(model(torch.randn(1, 1)), torch.randn(1, 1)) - - model = nn.Linear(1, 1) - criterion = nn.BCEWithLogitsLoss() - - ELBOLoss(None, criterion, kl_weight=1e-5, num_samples=1) - loss = ELBOLoss(model, criterion, kl_weight=1e-5, num_samples=1) - loss(model(torch.randn(1, 1)), torch.randn(1, 1)) - - def test_failures(self): - model = BayesLinear(1, 1) - criterion = nn.BCEWithLogitsLoss() - - with pytest.raises(TypeError): - ELBOLoss(model, nn.BCEWithLogitsLoss, kl_weight=1, num_samples=1) - - with pytest.raises(ValueError): - ELBOLoss(model, criterion, kl_weight=-1, num_samples=1) - - with pytest.raises(ValueError): - ELBOLoss(model, criterion, kl_weight=1, num_samples=-1) - - with pytest.raises(TypeError): - ELBOLoss(model, criterion, kl_weight=1e-5, num_samples=1.5) - - -class TestDERLoss: - """Testing the DERLoss class.""" - - def test_main(self): - loss = DERLoss(reg_weight=1e-2) - layer = NormalInverseGamma - inputs = layer( - torch.ones(1), torch.ones(1), torch.ones(1), torch.ones(1) - ) - targets = torch.tensor([[1.0]], dtype=torch.float32) - - assert loss(inputs, targets) == pytest.approx(2 * math.log(2)) - - loss = DERLoss( - reg_weight=1e-2, - reduction="sum", - ) - inputs = layer( - torch.ones((2, 1)), - torch.ones((2, 1)), - torch.ones((2, 1)), - torch.ones((2, 1)), - ) - - assert loss( - inputs, - targets, - ) == pytest.approx(4 * math.log(2)) - - loss = DERLoss( - reg_weight=1e-2, - reduction="none", - ) - - assert loss( - inputs, - targets, - ) == pytest.approx([2 * math.log(2), 2 * math.log(2)]) - - def test_failures(self): - with pytest.raises(ValueError): - DERLoss(reg_weight=-1) - - with pytest.raises(ValueError): - DERLoss(reg_weight=1.0, reduction="median") - - -class TestDECLoss: - """Testing the DECLoss class.""" - - def test_main(self): - loss = DECLoss( - loss_type="mse", reg_weight=1e-2, annealing_step=1, reduction="sum" - ) - loss(torch.tensor([[0.0, 0.0]]), torch.tensor([0]), current_epoch=1) - loss = DECLoss(loss_type="mse", reg_weight=1e-2, annealing_step=1) - loss(torch.tensor([[0.0, 0.0]]), torch.tensor([0]), current_epoch=0) - loss = DECLoss(loss_type="log", reg_weight=1e-2, reduction="none") - loss(torch.tensor([[0.0, 0.0]]), torch.tensor([0])) - loss = DECLoss(loss_type="digamma") - loss(torch.tensor([[0.0, 0.0]]), torch.tensor([0])) - - def test_failures(self): - with pytest.raises(ValueError): - DECLoss(reg_weight=-1) - - with pytest.raises(ValueError): - DECLoss(annealing_step=0) - - loss = DECLoss(annealing_step=10) - with pytest.raises(ValueError): - loss( - torch.tensor([[0.0, 0.0]]), - torch.tensor([0]), - current_epoch=None, - ) - - with pytest.raises(ValueError): - DECLoss(reduction="median") - - with pytest.raises(ValueError): - DECLoss(loss_type="regression") - - -class TestBetaNLL: - """Testing the BetaNLL class.""" - - def test_main(self): - loss = BetaNLL(beta=0.5) - - inputs = torch.tensor([[1.0, 1.0]], dtype=torch.float32) - targets = torch.tensor([[1.0]], dtype=torch.float32) - - assert loss(*inputs.split(1, dim=-1), targets) == 0 - - loss = BetaNLL( - beta=0.5, - reduction="sum", - ) - - assert ( - loss( - *inputs.repeat(2, 1).split(1, dim=-1), - targets.repeat(2, 1), - ) - == 0 - ) - - loss = BetaNLL( - beta=0.5, - reduction="none", - ) - - assert loss( - *inputs.repeat(2, 1).split(1, dim=-1), - targets.repeat(2, 1), - ) == pytest.approx([0.0, 0.0]) - - def test_failures(self): - with pytest.raises(ValueError): - BetaNLL(beta=-1) - - with pytest.raises(ValueError): - BetaNLL(beta=1.0, reduction="median") diff --git a/tests/test_optim_recipes.py b/tests/test_optim_recipes.py index b6d15863..ab4b455d 100644 --- a/tests/test_optim_recipes.py +++ b/tests/test_optim_recipes.py @@ -2,12 +2,28 @@ import pytest import torch -from torch_uncertainty.optim_recipes import FullSWALR, get_procedure, optim_abnn +from torch_uncertainty.optim_recipes import ( + CosineAnnealingWarmup, + CosineSWALR, + get_procedure, + optim_abnn, +) -class TestFullSWALR: +class TestCosineAnnealingWarmup: + def test_full_cosine_annealing_warmup(self): + CosineAnnealingWarmup( + torch.optim.SGD(torch.nn.Linear(1, 1).parameters(), lr=1e-3), + warmup_start_factor=0.1, + warmup_epochs=5, + max_epochs=100, + eta_min=1e-5, + ) + + +class TestCosineSWALR: def test_full_swa_lr(self): - FullSWALR( + CosineSWALR( torch.optim.SGD(torch.nn.Linear(1, 1).parameters(), lr=1e-3), swa_lr=1, milestone=12, diff --git a/tests/transforms/test_corruptions.py b/tests/transforms/test_corruption.py similarity index 86% rename from tests/transforms/test_corruptions.py rename to tests/transforms/test_corruption.py index 46b07ce3..4d979f89 100644 --- a/tests/transforms/test_corruptions.py +++ b/tests/transforms/test_corruption.py @@ -1,7 +1,8 @@ import pytest import torch +from requests.exceptions import HTTPError -from torch_uncertainty.transforms.corruptions import ( +from torch_uncertainty.transforms.corruption import ( DefocusBlur, Frost, GaussianBlur, @@ -127,13 +128,19 @@ def test_pixelate(self): print(transform) def test_frost(self): - with pytest.raises(ValueError): - _ = Frost(-1) - with pytest.raises(TypeError): - _ = Frost(0.1) - inputs = torch.rand(3, 32, 32) - transform = Frost(1) - transform(inputs) - transform = Frost(0) - transform(inputs) - print(transform) + try: + Frost(1) + frost_ok = True + except HTTPError: + frost_ok = False + if frost_ok: + with pytest.raises(ValueError): + _ = Frost(-1) + with pytest.raises(TypeError): + _ = Frost(0.1) + inputs = torch.rand(3, 32, 32) + transform = Frost(1) + transform(inputs) + transform = Frost(0) + transform(inputs) + print(transform) diff --git a/torch_uncertainty/baselines/classification/resnet.py b/torch_uncertainty/baselines/classification/resnet.py index 47391ff6..00ea94ce 100644 --- a/torch_uncertainty/baselines/classification/resnet.py +++ b/torch_uncertainty/baselines/classification/resnet.py @@ -34,7 +34,7 @@ class ResNetBaseline(ClassificationRoutine): "mimo": mimo_resnet, "mc-dropout": resnet, } - archs = [18, 20, 34, 50, 101, 152] + archs = [18, 20, 34, 44, 50, 56, 101, 110, 152, 1202] def __init__( self, @@ -52,6 +52,7 @@ def __init__( ], arch: int, style: str = "imagenet", + normalization_layer: type[nn.Module] = nn.BatchNorm2d, num_estimators: int = 1, dropout_rate: float = 0.0, mixup_params: dict | None = None, @@ -106,6 +107,8 @@ def __init__( style (str, optional): Which ResNet style to use. Defaults to ``imagenet``. + normalization_layer (type[nn.Module], optional): Normalization layer + to use. Defaults to ``nn.BatchNorm2d``. num_estimators (int, optional): Number of estimators in the ensemble. Only used if :attr:`version` is either ``"packed"``, ``"batched"``, ``"masked"`` or ``"mc-dropout"`` Defaults to ``None``. @@ -175,6 +178,7 @@ def __init__( "in_channels": in_channels, "num_classes": num_classes, "style": style, + "normalization_layer": normalization_layer, } format_batch_fn = nn.Identity() diff --git a/torch_uncertainty/datamodules/classification/cifar10.py b/torch_uncertainty/datamodules/classification/cifar10.py index 1e5eda4a..a87d61c0 100644 --- a/torch_uncertainty/datamodules/classification/cifar10.py +++ b/torch_uncertainty/datamodules/classification/cifar10.py @@ -29,6 +29,7 @@ def __init__( eval_ood: bool = False, val_split: float | None = None, num_workers: int = 1, + basic_augment: bool = True, cutout: int | None = None, auto_augment: str | None = None, test_alt: Literal["c", "h"] | None = None, @@ -47,6 +48,8 @@ def __init__( to ``0.0``. num_workers (int): Number of workers to use for data loading. Defaults to ``1``. + basic_augment (bool): Whether to apply base augmentations. Defaults to + ``True``. cutout (int): Size of cutout to apply to images. Defaults to ``None``. randaugment (bool): Whether to apply RandAugment. Defaults to ``False``. @@ -89,6 +92,16 @@ def __init__( "GitHub issue if needed." ) + if basic_augment: + basic_transform = T.Compose( + [ + T.RandomCrop(32, padding=4), + T.RandomHorizontalFlip(), + ] + ) + else: + basic_transform = nn.Identity() + if cutout: main_transform = Cutout(cutout) elif auto_augment: @@ -98,8 +111,7 @@ def __init__( self.train_transform = T.Compose( [ - T.RandomCrop(32, padding=4), - T.RandomHorizontalFlip(), + basic_transform, main_transform, T.ToTensor(), T.Normalize( diff --git a/torch_uncertainty/datamodules/classification/cifar100.py b/torch_uncertainty/datamodules/classification/cifar100.py index 373430bd..fa759853 100644 --- a/torch_uncertainty/datamodules/classification/cifar100.py +++ b/torch_uncertainty/datamodules/classification/cifar100.py @@ -29,6 +29,7 @@ def __init__( batch_size: int, eval_ood: bool = False, val_split: float | None = None, + basic_augment: bool = True, cutout: int | None = None, randaugment: bool = False, auto_augment: str | None = None, @@ -48,6 +49,8 @@ def __init__( batch_size (int): Number of samples per batch. val_split (float): Share of samples to use for validation. Defaults to ``0.0``. + basic_augment (bool): Whether to apply base augmentations. Defaults to + ``True``. cutout (int): Size of cutout to apply to images. Defaults to ``None``. randaugment (bool): Whether to apply RandAugment. Defaults to ``False``. @@ -93,6 +96,16 @@ def __init__( "GitHub issue if needed." ) + if basic_augment: + basic_transform = T.Compose( + [ + T.RandomCrop(32, padding=4), + T.RandomHorizontalFlip(), + ] + ) + else: + basic_transform = nn.Identity() + if cutout: main_transform = Cutout(cutout) elif randaugment: @@ -104,8 +117,7 @@ def __init__( self.train_transform = T.Compose( [ - T.RandomCrop(32, padding=4), - T.RandomHorizontalFlip(), + basic_transform, main_transform, T.ToTensor(), T.ConvertImageDtype(torch.float32), diff --git a/torch_uncertainty/datamodules/classification/imagenet.py b/torch_uncertainty/datamodules/classification/imagenet.py index 6d35303c..1e19ed4a 100644 --- a/torch_uncertainty/datamodules/classification/imagenet.py +++ b/torch_uncertainty/datamodules/classification/imagenet.py @@ -49,6 +49,7 @@ def __init__( procedure: str | None = None, train_size: int = 224, interpolation: str = "bilinear", + basic_augment: bool = True, rand_augment_opt: str | None = None, num_workers: int = 1, pin_memory: bool = True, @@ -71,6 +72,8 @@ def __init__( train_size (int): Size of training images. Defaults to ``224``. interpolation (str): Interpolation method for the Resize Crops. Defaults to ``"bilinear"``. + basic_augment (bool): Whether to apply base augmentations. Defaults to + ``True``. rand_augment_opt (str): Which RandAugment to use. Defaults to ``None``. num_workers (int): Number of workers to use for data loading. Defaults to ``1``. @@ -123,6 +126,18 @@ def __init__( self.procedure = procedure + if basic_augment: + basic_transform = T.Compose( + [ + T.RandomResizedCrop( + train_size, interpolation=self.interpolation + ), + T.RandomHorizontalFlip(), + ] + ) + else: + basic_transform = nn.Identity() + if self.procedure is None: if rand_augment_opt is not None: main_transform = rand_augment_transform(rand_augment_opt, {}) @@ -144,10 +159,7 @@ def __init__( self.train_transform = T.Compose( [ - T.RandomResizedCrop( - train_size, interpolation=self.interpolation - ), - T.RandomHorizontalFlip(), + basic_transform, main_transform, T.ToTensor(), T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), diff --git a/torch_uncertainty/datamodules/classification/mnist.py b/torch_uncertainty/datamodules/classification/mnist.py index b411f502..9be45ea6 100644 --- a/torch_uncertainty/datamodules/classification/mnist.py +++ b/torch_uncertainty/datamodules/classification/mnist.py @@ -27,6 +27,7 @@ def __init__( ood_ds: Literal["fashion", "notMNIST"] = "fashion", val_split: float | None = None, num_workers: int = 1, + basic_augment: bool = True, cutout: int | None = None, test_alt: Literal["c"] | None = None, pin_memory: bool = True, @@ -45,6 +46,8 @@ def __init__( to ``0.0``. num_workers (int): Number of workers to use for data loading. Defaults to ``1``. + basic_augment (bool): Whether to apply base augmentations. Defaults to + ``True``. cutout (int): Size of cutout to apply to images. Defaults to ``None``. test_alt (str): Which test set to use. Defaults to ``None``. pin_memory (bool): Whether to pin memory. Defaults to ``True``. @@ -78,13 +81,18 @@ def __init__( f"`ood_ds` should be in {self.ood_datasets}. Got {ood_ds}." ) + if basic_augment: + basic_transform = T.RandomCrop(28, padding=4) + else: + basic_transform = nn.Identity() + main_transform = Cutout(cutout) if cutout else nn.Identity() self.train_transform = T.Compose( [ + basic_transform, main_transform, T.ToTensor(), - T.RandomCrop(28, padding=4), T.Normalize((0.1307,), (0.3081,)), ] ) diff --git a/torch_uncertainty/datamodules/classification/tiny_imagenet.py b/torch_uncertainty/datamodules/classification/tiny_imagenet.py index 49506d48..bec3025d 100644 --- a/torch_uncertainty/datamodules/classification/tiny_imagenet.py +++ b/torch_uncertainty/datamodules/classification/tiny_imagenet.py @@ -30,6 +30,7 @@ def __init__( val_split: float | None = None, ood_ds: str = "svhn", interpolation: str = "bilinear", + basic_augment: bool = True, rand_augment_opt: str | None = None, num_workers: int = 1, pin_memory: bool = True, @@ -44,7 +45,6 @@ def __init__( persistent_workers=persistent_workers, ) # TODO: COMPUTE STATS - self.eval_ood = eval_ood self.ood_ds = ood_ds self.interpolation = interpolation_modes_from_str(interpolation) @@ -62,6 +62,16 @@ def __init__( f"OOD dataset {ood_ds} not supported for TinyImageNet." ) + if basic_augment: + basic_transform = T.Compose( + [ + T.RandomCrop(64, padding=4), + T.RandomHorizontalFlip(), + ] + ) + else: + basic_transform = nn.Identity() + if rand_augment_opt is not None: main_transform = rand_augment_transform(rand_augment_opt, {}) else: @@ -69,8 +79,7 @@ def __init__( self.train_transform = T.Compose( [ - T.RandomCrop(64, padding=4), - T.RandomHorizontalFlip(), + basic_transform, main_transform, T.ToTensor(), T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), diff --git a/torch_uncertainty/datamodules/depth/base.py b/torch_uncertainty/datamodules/depth/base.py index 34c69c89..df067d97 100644 --- a/torch_uncertainty/datamodules/depth/base.py +++ b/torch_uncertainty/datamodules/depth/base.py @@ -43,7 +43,7 @@ def __init__( of length :math:`1`, it will be interpreted as :math:`(\text{size[0]},\text{size[1]})`. eval_size (sequence or int, optional): Desired input image and - depth mask sizes during inference. If size is an int, + depth mask sizes during evaluation. If size is an int, smaller edge of the images will be matched to this number, i.e., :math:`\text{height}>\text{width}`, then image will be rescaled to :math:`(\text{size}\times\text{height}/\text{width},\text{size})`. diff --git a/torch_uncertainty/datamodules/depth/kitti.py b/torch_uncertainty/datamodules/depth/kitti.py index c5035893..69227769 100644 --- a/torch_uncertainty/datamodules/depth/kitti.py +++ b/torch_uncertainty/datamodules/depth/kitti.py @@ -37,7 +37,7 @@ def __init__( of length :math:`1`, it will be interpreted as :math:`(\text{size[0]},\text{size[1]})`. Defaults to ``(375, 1242)``. eval_size (sequence or int, optional): Desired input image and - depth mask sizes during inference. If size is an int, + depth mask sizes during evaluation. If size is an int, smaller edge of the images will be matched to this number, i.e., :math:`\text{height}>\text{width}`, then image will be rescaled to :math:`(\text{size}\times\text{height}/\text{width},\text{size})`. diff --git a/torch_uncertainty/datamodules/depth/muad.py b/torch_uncertainty/datamodules/depth/muad.py index cf4f6cde..032a4292 100644 --- a/torch_uncertainty/datamodules/depth/muad.py +++ b/torch_uncertainty/datamodules/depth/muad.py @@ -37,7 +37,7 @@ def __init__( of length :math:`1`, it will be interpreted as :math:`(\text{size[0]},\text{size[1]})`. Defaults to ``1024``. eval_size (sequence or int, optional): Desired input image and - depth mask sizes during inference. If size is an int, + depth mask sizes during evaluation. If size is an int, smaller edge of the images will be matched to this number, i.e., :math:`\text{height}>\text{width}`, then image will be rescaled to :math:`(\text{size}\times\text{height}/\text{width},\text{size})`. diff --git a/torch_uncertainty/datamodules/depth/nyu.py b/torch_uncertainty/datamodules/depth/nyu.py index ec925ffa..077badff 100644 --- a/torch_uncertainty/datamodules/depth/nyu.py +++ b/torch_uncertainty/datamodules/depth/nyu.py @@ -37,7 +37,7 @@ def __init__( of length :math:`1`, it will be interpreted as :math:`(\text{size[0]},\text{size[1]})`. Defaults to ``(416, 544)``. eval_size (sequence or int, optional): Desired input image and - depth mask sizes during inference. If size is an int, + depth mask sizes during evaluation. If size is an int, smaller edge of the images will be matched to this number, i.e., :math:`\text{height}>\text{width}`, then image will be rescaled to :math:`(\text{size}\times\text{height}/\text{width},\text{size})`. diff --git a/torch_uncertainty/datamodules/segmentation/cityscapes.py b/torch_uncertainty/datamodules/segmentation/cityscapes.py index baee3d4b..ea4bea8e 100644 --- a/torch_uncertainty/datamodules/segmentation/cityscapes.py +++ b/torch_uncertainty/datamodules/segmentation/cityscapes.py @@ -36,7 +36,7 @@ def __init__( of length :math:`1`, it will be interpreted as :math:`(\text{size[0]},\text{size[1]})`. Defaults to ``1024``. eval_size (sequence or int, optional): Desired input image and - segmentation mask sizes during inference. If size is an int, + segmentation mask sizes during evaluation. If size is an int, smaller edge of the images will be matched to this number, i.e., :math:`\text{height}>\text{width}`, then image will be rescaled to :math:`(\text{size}\times\text{height}/\text{width},\text{size})`. diff --git a/torch_uncertainty/datasets/classification/cifar/cifar_c.py b/torch_uncertainty/datasets/classification/cifar/cifar_c.py index 10f9f230..b10fa0b9 100644 --- a/torch_uncertainty/datasets/classification/cifar/cifar_c.py +++ b/torch_uncertainty/datasets/classification/cifar/cifar_c.py @@ -1,3 +1,4 @@ +import logging from collections.abc import Callable from pathlib import Path @@ -197,7 +198,7 @@ def _check_integrity(self) -> bool: def download(self) -> None: """Download the dataset.""" if self._check_integrity(): - print("Files already downloaded and verified.") + logging.info("Files already downloaded and verified") return download_and_extract_archive( self.url, self.root, filename=self.filename, md5=self.tgz_md5 diff --git a/torch_uncertainty/datasets/classification/imagenet/base.py b/torch_uncertainty/datasets/classification/imagenet/base.py index 7d69d0f9..c5229df7 100644 --- a/torch_uncertainty/datasets/classification/imagenet/base.py +++ b/torch_uncertainty/datasets/classification/imagenet/base.py @@ -1,4 +1,5 @@ import json +import logging from collections.abc import Callable from pathlib import Path @@ -85,7 +86,7 @@ def _check_integrity(self) -> bool: def download(self) -> None: """Download and extract dataset.""" if self._check_integrity(): - print("Files already downloaded and verified") + logging.info("Files already downloaded and verified") return if isinstance(self.filename, str): download_and_extract_archive( diff --git a/torch_uncertainty/datasets/classification/imagenet/tiny_imagenet.py b/torch_uncertainty/datasets/classification/imagenet/tiny_imagenet.py index 553fbd1b..0e42331e 100644 --- a/torch_uncertainty/datasets/classification/imagenet/tiny_imagenet.py +++ b/torch_uncertainty/datasets/classification/imagenet/tiny_imagenet.py @@ -24,8 +24,11 @@ def __init__( ) -> None: self.root = Path(root) / "tiny-imagenet-200" + if split not in ["train", "val", "test"]: + raise ValueError(f"Split {split} is not supported.") + self.split = split - self.label_idx = 1 # from [image, id, nid, box] + self.label_idx = 1 self.transform = transform self.target_transform = target_transform diff --git a/torch_uncertainty/datasets/classification/imagenet/tiny_imagenet_c.py b/torch_uncertainty/datasets/classification/imagenet/tiny_imagenet_c.py index 1f0bcc38..762ff346 100644 --- a/torch_uncertainty/datasets/classification/imagenet/tiny_imagenet_c.py +++ b/torch_uncertainty/datasets/classification/imagenet/tiny_imagenet_c.py @@ -1,3 +1,4 @@ +import logging from collections.abc import Callable from pathlib import Path @@ -155,7 +156,7 @@ def _check_integrity(self) -> bool: def download(self) -> None: """Download the dataset.""" if self._check_integrity(): - print("Files already downloaded and verified.") + logging.info("Files already downloaded and verified") return for filename, md5 in list( zip(self.filename, self.tgz_md5, strict=True) diff --git a/torch_uncertainty/datasets/classification/mnist_c.py b/torch_uncertainty/datasets/classification/mnist_c.py index 65febcf9..ae1bf563 100644 --- a/torch_uncertainty/datasets/classification/mnist_c.py +++ b/torch_uncertainty/datasets/classification/mnist_c.py @@ -1,3 +1,4 @@ +import logging from collections.abc import Callable from pathlib import Path from typing import Any, Literal @@ -168,7 +169,7 @@ def _check_integrity(self) -> bool: def download(self) -> None: """Download the dataset.""" if self._check_integrity(): - print("Files already downloaded and verified.") + logging.info("Files already downloaded and verified") return download_and_extract_archive( self.url, self.root, filename=self.filename, md5=self.zip_md5 diff --git a/torch_uncertainty/datasets/classification/not_mnist.py b/torch_uncertainty/datasets/classification/not_mnist.py index 9bd27f8c..8fa77b4c 100644 --- a/torch_uncertainty/datasets/classification/not_mnist.py +++ b/torch_uncertainty/datasets/classification/not_mnist.py @@ -1,3 +1,4 @@ +import logging from collections.abc import Callable from pathlib import Path from typing import Any, Literal @@ -80,7 +81,7 @@ def _check_integrity(self) -> bool: def download(self) -> None: if self._check_integrity(): - print("Files already downloaded and verified") + logging.info("Files already downloaded and verified") return download_and_extract_archive( @@ -89,7 +90,7 @@ def download(self) -> None: filename=self.filename, md5=self.tgz_md5, ) - print(f"Downloaded {self.filename} to {self.root}") + logging.info("Downloaded %s to %s.", self.filename, self.root) def __getitem__(self, index: int) -> tuple[Any, Any]: """Get the samples and targets of the dataset. diff --git a/torch_uncertainty/datasets/classification/openimage_o.py b/torch_uncertainty/datasets/classification/openimage_o.py index 2cc9104a..14c839de 100644 --- a/torch_uncertainty/datasets/classification/openimage_o.py +++ b/torch_uncertainty/datasets/classification/openimage_o.py @@ -1,3 +1,4 @@ +import logging from collections.abc import Callable from pathlib import Path @@ -69,7 +70,7 @@ def _check_integrity(self) -> bool: def download(self) -> None: if self._check_integrity(): - print("Files already downloaded and verified") + logging.info("Files already downloaded and verified") return download_and_extract_archive( diff --git a/torch_uncertainty/datasets/fractals.py b/torch_uncertainty/datasets/fractals.py index c609dd9d..d46358b5 100644 --- a/torch_uncertainty/datasets/fractals.py +++ b/torch_uncertainty/datasets/fractals.py @@ -1,3 +1,4 @@ +import logging from collections.abc import Callable from pathlib import Path from typing import Any @@ -56,7 +57,7 @@ def _check_integrity(self) -> bool: def download(self) -> None: if self._check_integrity(): - print("Files already downloaded and verified") + logging.info("Files already downloaded and verified") return download_file_from_google_drive( diff --git a/torch_uncertainty/datasets/frost.py b/torch_uncertainty/datasets/frost.py index 9cdc533e..6e391b93 100644 --- a/torch_uncertainty/datasets/frost.py +++ b/torch_uncertainty/datasets/frost.py @@ -1,3 +1,4 @@ +import logging from collections.abc import Callable from pathlib import Path from typing import Any @@ -63,7 +64,7 @@ def _check_integrity(self) -> bool: def download(self) -> None: if self._check_integrity(): - print("Files already downloaded and verified") + logging.info("Files already downloaded and verified") return download_and_extract_archive( @@ -72,7 +73,7 @@ def download(self) -> None: filename=self.filename, md5=self.zip_md5, ) - print(f"Downloaded {self.filename} to {self.root}") + logging.info("Downloaded %s to %s.", self.filename, self.root) def __getitem__(self, index: int) -> Any: """Get the samples of the dataset. diff --git a/torch_uncertainty/datasets/kitti.py b/torch_uncertainty/datasets/kitti.py index f2b2a35f..5256f385 100644 --- a/torch_uncertainty/datasets/kitti.py +++ b/torch_uncertainty/datasets/kitti.py @@ -1,4 +1,5 @@ import json +import logging import shutil from collections.abc import Callable from pathlib import Path @@ -38,7 +39,7 @@ def __init__( download: bool = False, remove_unused: bool = False, ) -> None: - print( + logging.info( "KITTIDepth is copyrighted by the Karlsruhe Institute of Technology " "(KIT) and the Toyota Technological Institute at Chicago (TTIC). " "By using KITTIDepth, you agree to the terms and conditions of the " @@ -135,7 +136,7 @@ def _download_depth(self) -> None: md5=self.depth_md5, ) - print("Re-structuring the depth annotations...") + logging.info("Re-structuring the depth annotations...") if (self.root / "train" / "leftDepth").exists(): shutil.rmtree(self.root / "train" / "leftDepth") @@ -143,7 +144,7 @@ def _download_depth(self) -> None: (self.root / "train" / "leftDepth").mkdir(parents=True, exist_ok=False) depth_files = list((self.root).glob("**/tmp/train/**/image_02/*.png")) - print("Train files:") + logging.info("Train files...") for file in tqdm(depth_files): exp_code = file.parents[3].name.split("_") filecode = "_".join( @@ -157,7 +158,7 @@ def _download_depth(self) -> None: (self.root / "val" / "leftDepth").mkdir(parents=True, exist_ok=False) depth_files = list((self.root).glob("**/tmp/val/**/image_02/*.png")) - print("Validation files:") + logging.info("Validation files...") for file in tqdm(depth_files): exp_code = file.parents[3].name.split("_") filecode = "_".join( @@ -179,7 +180,7 @@ def _download_raw(self, remove_unused: bool) -> None: raw_filenames = json.load(file) for filename in tqdm(raw_filenames): - print(self.raw_url + filename) + logging.info("%s", self.raw_url + filename) download_and_extract_archive( self.raw_url + filename, download_root=self.root, @@ -187,7 +188,7 @@ def _download_raw(self, remove_unused: bool) -> None: md5=None, ) - print("Re-structuring the raw data...") + logging.info("Re-structuring the raw data...") samples_to_keep = list( (self.root / "train" / "leftDepth").glob("*.png") @@ -200,7 +201,7 @@ def _download_raw(self, remove_unused: bool) -> None: parents=True, exist_ok=False ) - print("Train files:") + logging.info("Train files...") for sample in tqdm(samples_to_keep): filecode = sample.name.split("_") first_level = "_".join([filecode[0], filecode[1], filecode[2]]) @@ -234,7 +235,7 @@ def _download_raw(self, remove_unused: bool) -> None: (self.root / "val" / "leftImg8bit").mkdir(parents=True, exist_ok=False) - print("Validation files:") + logging.info("Validation files...") for sample in tqdm(samples_to_keep): filecode = sample.name.split("_") first_level = "_".join([filecode[0], filecode[1], filecode[2]]) diff --git a/torch_uncertainty/datasets/muad.py b/torch_uncertainty/datasets/muad.py index 9cde371a..f27ffe57 100644 --- a/torch_uncertainty/datasets/muad.py +++ b/torch_uncertainty/datasets/muad.py @@ -1,4 +1,5 @@ import json +import logging import os import shutil from collections.abc import Callable @@ -73,7 +74,7 @@ def __init__( MUAD cannot be used for commercial purposes. Read MUAD's license carefully before using it and verify that you can comply. """ - print( + logging.info( "MUAD is restricted to non-commercial use. By using MUAD, you " "agree to the terms and conditions." ) diff --git a/torch_uncertainty/datasets/regression/uci_regression.py b/torch_uncertainty/datasets/regression/uci_regression.py index 0f4be30c..3a23ae8d 100644 --- a/torch_uncertainty/datasets/regression/uci_regression.py +++ b/torch_uncertainty/datasets/regression/uci_regression.py @@ -1,3 +1,4 @@ +import logging from collections.abc import Callable from importlib import util from pathlib import Path @@ -193,7 +194,7 @@ def _compute_statistics(self) -> None: def download(self) -> None: """Download and extract dataset.""" if self._check_integrity(): - print("Files already downloaded and verified") + logging.info("Files already downloaded and verified") return if self.url is None: raise ValueError( diff --git a/torch_uncertainty/datasets/segmentation/camvid.py b/torch_uncertainty/datasets/segmentation/camvid.py index 5a25c821..5bf3c7fa 100644 --- a/torch_uncertainty/datasets/segmentation/camvid.py +++ b/torch_uncertainty/datasets/segmentation/camvid.py @@ -1,4 +1,5 @@ import json +import logging import shutil from collections.abc import Callable from pathlib import Path @@ -219,7 +220,7 @@ def _check_integrity(self) -> bool: def download(self) -> None: """Download the CamVid data if it doesn't exist already.""" if self._check_integrity(): - print("Files already downloaded and verified") + logging.info("Files already downloaded and verified") return if (Path(self.root) / self.base_folder).exists(): diff --git a/torch_uncertainty/datasets/segmentation/cityscapes.py b/torch_uncertainty/datasets/segmentation/cityscapes.py index 234a6ee5..97e48ef0 100644 --- a/torch_uncertainty/datasets/segmentation/cityscapes.py +++ b/torch_uncertainty/datasets/segmentation/cityscapes.py @@ -5,11 +5,11 @@ from PIL import Image from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE from torchvision import tv_tensors -from torchvision.datasets import Cityscapes as OriginalCityscapes +from torchvision.datasets import Cityscapes as TVCityscapes from torchvision.transforms.v2 import functional as F -class Cityscapes(OriginalCityscapes): +class Cityscapes(TVCityscapes): def encode_target(self, target: Image.Image) -> Image.Image: """Encode target image to tensor. diff --git a/torch_uncertainty/layers/__init__.py b/torch_uncertainty/layers/__init__.py index f91746bd..210e0bea 100644 --- a/torch_uncertainty/layers/__init__.py +++ b/torch_uncertainty/layers/__init__.py @@ -1,6 +1,7 @@ # ruff: noqa: F401 from .batch_ensemble import BatchConv2d, BatchLinear from .bayesian import BayesConv1d, BayesConv2d, BayesConv3d, BayesLinear +from .channel_layer_norm import ChannelLayerNorm from .masksembles import MaskedConv2d, MaskedLinear from .modules import Identity from .packed import PackedConv1d, PackedConv2d, PackedConv3d, PackedLinear diff --git a/torch_uncertainty/layers/bayesian/bayes_conv.py b/torch_uncertainty/layers/bayesian/bayes_conv.py index 3584ba77..d6122bb5 100644 --- a/torch_uncertainty/layers/bayesian/bayes_conv.py +++ b/torch_uncertainty/layers/bayesian/bayes_conv.py @@ -173,7 +173,7 @@ def unfreeze(self) -> None: self.frozen = False def sample(self) -> tuple[Tensor, Tensor | None]: - """Sample the bayesian layer's posterior.""" + """Sample the Bayesian layer's posterior.""" weight = self.weight_sampler.sample() bias = self.bias_sampler.sample() if self.bias_mu is not None else None return weight, bias diff --git a/torch_uncertainty/layers/bayesian/bayes_linear.py b/torch_uncertainty/layers/bayesian/bayes_linear.py index 2c9f15c4..ff2247d2 100644 --- a/torch_uncertainty/layers/bayesian/bayes_linear.py +++ b/torch_uncertainty/layers/bayesian/bayes_linear.py @@ -140,7 +140,7 @@ def unfreeze(self) -> None: self.frozen = False def sample(self) -> tuple[Tensor, Tensor | None]: - """Sample the bayesian layer's posterior.""" + """Sample the Bayesian layer's posterior.""" weight = self.weight_sampler.sample() bias = self.bias_sampler.sample() if self.bias_mu is not None else None return weight, bias diff --git a/torch_uncertainty/layers/channel_layer_norm.py b/torch_uncertainty/layers/channel_layer_norm.py new file mode 100644 index 00000000..69999324 --- /dev/null +++ b/torch_uncertainty/layers/channel_layer_norm.py @@ -0,0 +1,59 @@ +import torch +from torch import Tensor +from torch.nn import LayerNorm + +from .utils import ChannelBack, ChannelFront + + +class ChannelLayerNorm(LayerNorm): + def __init__( + self, + normalized_shape: int | list[int], + eps: float = 0.00001, + elementwise_affine: bool = True, + bias: bool = True, + device: torch.device | str | None = None, + dtype: torch.dtype | str | None = None, + ) -> None: + r"""Layer normalization over the channel dimension. + + Args: + normalized_shape (int or list or torch.Size): input shape from an expected input + of size + + .. math:: + [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1] + \times \ldots \times \text{normalized\_shape}[-1]] + + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the channel dimension which is expected to be of that specific size. + eps (float): a value added to the denominator for numerical stability. Default: 1e-5 + elementwise_affine (bool): a boolean value that when set to ``True``, this module + has learnable per-element affine parameters initialized to ones (for weights) + and zeros (for biases). Default: ``True``. + bias (bool): If set to ``False``, the layer will not learn an additive bias (only relevant if + :attr:`elementwise_affine` is ``True``). Default: ``True``. + device (torch.device or str or None): the desired device of the module. + dtype (torch.dtype or str or None): the desired floating point type of the module. + + Attributes: + weight: the learnable weights of the module of shape + :math:`\text{normalized\_shape}` when :attr:`elementwise_affine` is set to ``True``. + The values are initialized to 1. + bias: the learnable bias of the module of shape + :math:`\text{normalized\_shape}` when :attr:`elementwise_affine` is set to ``True``. + The values are initialized to 0. + + Shape: + - Input: :math:`(N, *)` + - Output: :math:`(N, *)` (same shape as input) + + """ + super().__init__( + normalized_shape, eps, elementwise_affine, bias, device, dtype + ) + self.cback = ChannelBack() + self.cfront = ChannelFront() + + def forward(self, inputs: Tensor) -> Tensor: + return self.cfront(super().forward(self.cback(inputs))) diff --git a/torch_uncertainty/layers/utils.py b/torch_uncertainty/layers/utils.py new file mode 100644 index 00000000..050d56d9 --- /dev/null +++ b/torch_uncertainty/layers/utils.py @@ -0,0 +1,12 @@ +from einops import rearrange +from torch import Tensor, nn + + +class ChannelBack(nn.Module): + def forward(self, x: Tensor) -> Tensor: + return rearrange(x, "b c h w -> b h w c") + + +class ChannelFront(nn.Module): + def forward(self, x: Tensor) -> Tensor: + return rearrange(x, "b h w c -> b c h w") diff --git a/torch_uncertainty/losses/__init__.py b/torch_uncertainty/losses/__init__.py new file mode 100644 index 00000000..318295e1 --- /dev/null +++ b/torch_uncertainty/losses/__init__.py @@ -0,0 +1,4 @@ +# ruff: noqa: F401 +from .bayesian import ELBOLoss, KLDiv +from .classification import ConfidencePenaltyLoss, ConflictualLoss, DECLoss +from .regression import BetaNLL, DERLoss, DistributionNLLLoss diff --git a/torch_uncertainty/losses/bayesian.py b/torch_uncertainty/losses/bayesian.py new file mode 100644 index 00000000..3621a8f2 --- /dev/null +++ b/torch_uncertainty/losses/bayesian.py @@ -0,0 +1,114 @@ +import torch +from torch import Tensor, nn + +from torch_uncertainty.layers.bayesian import bayesian_modules + + +class KLDiv(nn.Module): + def __init__(self, model: nn.Module) -> None: + """KL divergence loss for Bayesian Neural Networks. Gathers the KL from the + modules computed in the forward passes. + + Args: + model (nn.Module): Bayesian Neural Network + """ + super().__init__() + self.model = model + + def forward(self) -> Tensor: + return self._kl_div() + + def _kl_div(self) -> Tensor: + """Gathers pre-computed KL-Divergences from :attr:`model`.""" + kl_divergence = torch.zeros(1) + count = 0 + for module in self.model.modules(): + if isinstance(module, bayesian_modules): + kl_divergence = kl_divergence.to( + device=module.lvposterior.device + ) + kl_divergence += module.lvposterior - module.lprior + count += 1 + return kl_divergence / count + + +class ELBOLoss(nn.Module): + def __init__( + self, + model: nn.Module | None, + inner_loss: nn.Module, + kl_weight: float, + num_samples: int, + ) -> None: + """The Evidence Lower Bound (ELBO) loss for Bayesian Neural Networks. + + ELBO loss for Bayesian Neural Networks. Use this loss function with the + objective that you seek to minimize as :attr:`inner_loss`. + + Args: + model (nn.Module): The Bayesian Neural Network to compute the loss for + inner_loss (nn.Module): The loss function to use during training + kl_weight (float): The weight of the KL divergence term + num_samples (int): The number of samples to use for the ELBO loss + + Note: + Set the model to None if you use the ELBOLoss within + the ClassificationRoutine. It will get filled automatically. + """ + super().__init__() + _elbo_loss_checks(inner_loss, kl_weight, num_samples) + self.set_model(model) + + self.inner_loss = inner_loss + self.kl_weight = kl_weight + self.num_samples = num_samples + + def forward(self, inputs: Tensor, targets: Tensor) -> Tensor: + """Gather the KL divergence from the Bayesian modules and aggregate + the ELBO loss for a given network. + + Args: + inputs (Tensor): The inputs of the Bayesian Neural Network + targets (Tensor): The target values + + Returns: + Tensor: The aggregated ELBO loss + """ + aggregated_elbo = torch.zeros(1, device=inputs.device) + for _ in range(self.num_samples): + logits = self.model(inputs) + aggregated_elbo += self.inner_loss(logits, targets) + # TODO: This shouldn't be necessary + aggregated_elbo += self.kl_weight * self._kl_div().to(inputs.device) + return aggregated_elbo / self.num_samples + + def set_model(self, model: nn.Module | None) -> None: + self.model = model + if model is not None: + self._kl_div = KLDiv(model) + + +def _elbo_loss_checks( + inner_loss: nn.Module, kl_weight: float, num_samples: int +) -> None: + if isinstance(inner_loss, type): + raise TypeError( + "The inner_loss should be an instance of a class." + f"Got {inner_loss}." + ) + + if kl_weight < 0: + raise ValueError( + f"The KL weight should be non-negative. Got {kl_weight}." + ) + + if num_samples < 1: + raise ValueError( + "The number of samples should not be lower than 1." + f"Got {num_samples}." + ) + if not isinstance(num_samples, int): + raise TypeError( + "The number of samples should be an integer. " + f"Got {type(num_samples)}." + ) diff --git a/torch_uncertainty/losses.py b/torch_uncertainty/losses/classification.py similarity index 50% rename from torch_uncertainty/losses.py rename to torch_uncertainty/losses/classification.py index c82ab210..0b9230b9 100644 --- a/torch_uncertainty/losses.py +++ b/torch_uncertainty/losses/classification.py @@ -1,258 +1,7 @@ -from typing import Literal - import torch from torch import Tensor, nn -from torch.distributions import Distribution from torch.nn import functional as F -from torch_uncertainty.layers.bayesian import bayesian_modules -from torch_uncertainty.utils.distributions import NormalInverseGamma - - -class DistributionNLLLoss(nn.Module): - def __init__( - self, reduction: Literal["mean", "sum"] | None = "mean" - ) -> None: - """Negative Log-Likelihood loss using given distributions as inputs. - - Args: - reduction (str, optional): specifies the reduction to apply to the - output:``'none'`` | ``'mean'`` | ``'sum'``. Defaults to "mean". - """ - super().__init__() - self.reduction = reduction - - def forward( - self, - dist: Distribution, - targets: Tensor, - padding_mask: Tensor | None = None, - ) -> Tensor: - """Compute the NLL of the targets given predicted distributions. - - Args: - dist (Distribution): The predicted distributions - targets (Tensor): The target values - padding_mask (Tensor, optional): The padding mask. Defaults to None. - Sets the loss to 0 for padded values. - """ - loss = -dist.log_prob(targets) - if padding_mask is not None: - loss = loss.masked_fill(padding_mask, 0.0) - - if self.reduction == "mean": - loss = loss.mean() - elif self.reduction == "sum": - loss = loss.sum() - return loss - - -class KLDiv(nn.Module): - def __init__(self, model: nn.Module) -> None: - """KL divergence loss for Bayesian Neural Networks. Gathers the KL from the - modules computed in the forward passes. - - Args: - model (nn.Module): Bayesian Neural Network - """ - super().__init__() - self.model = model - - def forward(self) -> Tensor: - return self._kl_div() - - def _kl_div(self) -> Tensor: - """Gathers pre-computed KL-Divergences from :attr:`model`.""" - kl_divergence = torch.zeros(1) - count = 0 - for module in self.model.modules(): - if isinstance(module, bayesian_modules): - kl_divergence = kl_divergence.to( - device=module.lvposterior.device - ) - kl_divergence += module.lvposterior - module.lprior - count += 1 - return kl_divergence / count - - -class ELBOLoss(nn.Module): - def __init__( - self, - model: nn.Module | None, - inner_loss: nn.Module, - kl_weight: float, - num_samples: int, - ) -> None: - """The Evidence Lower Bound (ELBO) loss for Bayesian Neural Networks. - - ELBO loss for Bayesian Neural Networks. Use this loss function with the - objective that you seek to minimize as :attr:`inner_loss`. - - Args: - model (nn.Module): The Bayesian Neural Network to compute the loss for - inner_loss (nn.Module): The loss function to use during training - kl_weight (float): The weight of the KL divergence term - num_samples (int): The number of samples to use for the ELBO loss - - Note: - Set the model to None if you use the ELBOLoss within - the ClassificationRoutine. It will get filled automatically. - """ - super().__init__() - _elbo_loss_checks(inner_loss, kl_weight, num_samples) - self.set_model(model) - - self.inner_loss = inner_loss - self.kl_weight = kl_weight - self.num_samples = num_samples - - def forward(self, inputs: Tensor, targets: Tensor) -> Tensor: - """Gather the KL divergence from the bayesian modules and aggregate - the ELBO loss for a given network. - - Args: - inputs (Tensor): The inputs of the Bayesian Neural Network - targets (Tensor): The target values - - Returns: - Tensor: The aggregated ELBO loss - """ - aggregated_elbo = torch.zeros(1, device=inputs.device) - for _ in range(self.num_samples): - logits = self.model(inputs) - aggregated_elbo += self.inner_loss(logits, targets) - # TODO: This shouldn't be necessary - aggregated_elbo += self.kl_weight * self._kl_div().to(inputs.device) - return aggregated_elbo / self.num_samples - - def set_model(self, model: nn.Module | None) -> None: - self.model = model - if model is not None: - self._kl_div = KLDiv(model) - - -def _elbo_loss_checks( - inner_loss: nn.Module, kl_weight: float, num_samples: int -) -> None: - if isinstance(inner_loss, type): - raise TypeError( - "The inner_loss should be an instance of a class." - f"Got {inner_loss}." - ) - - if kl_weight < 0: - raise ValueError( - f"The KL weight should be non-negative. Got {kl_weight}." - ) - - if num_samples < 1: - raise ValueError( - "The number of samples should not be lower than 1." - f"Got {num_samples}." - ) - if not isinstance(num_samples, int): - raise TypeError( - "The number of samples should be an integer. " - f"Got {type(num_samples)}." - ) - - -class DERLoss(DistributionNLLLoss): - def __init__( - self, reg_weight: float, reduction: str | None = "mean" - ) -> None: - """The Deep Evidential loss. - - This loss combines the negative log-likelihood loss of the normal - inverse gamma distribution and a weighted regularization term. - - Args: - reg_weight (float): The weight of the regularization term. - reduction (str, optional): specifies the reduction to apply to the - output:``'none'`` | ``'mean'`` | ``'sum'``. - - Reference: - Amini, A., Schwarting, W., Soleimany, A., & Rus, D. (2019). Deep - evidential regression. https://arxiv.org/abs/1910.02600. - """ - super().__init__(reduction=None) - - if reduction not in (None, "none", "mean", "sum"): - raise ValueError(f"{reduction} is not a valid value for reduction.") - self.final_reduction = reduction - - if reg_weight < 0: - raise ValueError( - "The regularization weight should be non-negative, but got " - f"{reg_weight}." - ) - self.reg_weight = reg_weight - - def _reg(self, dist: NormalInverseGamma, targets: Tensor) -> Tensor: - return torch.norm(targets - dist.loc, 1, dim=1, keepdim=True) * ( - 2 * dist.lmbda + dist.alpha - ) - - def forward( - self, - dist: NormalInverseGamma, - targets: Tensor, - ) -> Tensor: - loss_nll = super().forward(dist, targets) - loss_reg = self._reg(dist, targets) - loss = loss_nll + self.reg_weight * loss_reg - - if self.final_reduction == "mean": - return loss.mean() - if self.final_reduction == "sum": - return loss.sum() - return loss - - -class BetaNLL(nn.Module): - def __init__( - self, beta: float = 0.5, reduction: str | None = "mean" - ) -> None: - """The Beta Negative Log-likelihood loss. - - Args: - beta (float): TParameter from range [0, 1] controlling relative - weighting between data points, where `0` corresponds to - high weight on low error points and `1` to an equal weighting. - reduction (str, optional): specifies the reduction to apply to the - output:``'none'`` | ``'mean'`` | ``'sum'``. - - Reference: - Seitzer, M., Tavakoli, A., Antic, D., & Martius, G. (2022). On the - pitfalls of heteroscedastic uncertainty estimation with probabilistic - neural networks. https://arxiv.org/abs/2203.09168. - """ - super().__init__() - - if beta < 0 or beta > 1: - raise ValueError( - "The beta parameter should be in range [0, 1], but got " - f"{beta}." - ) - self.beta = beta - self.nll_loss = nn.GaussianNLLLoss(reduction="none") - if reduction not in ("none", "mean", "sum"): - raise ValueError(f"{reduction} is not a valid value for reduction.") - self.reduction = reduction - - def forward( - self, mean: Tensor, targets: Tensor, variance: Tensor - ) -> Tensor: - loss = self.nll_loss(mean, targets, variance) * ( - variance.detach() ** self.beta - ) - - if self.reduction == "mean": - return loss.mean() - if self.reduction == "sum": - return loss.sum() - return loss - class DECLoss(nn.Module): def __init__( @@ -294,7 +43,7 @@ def __init__( ) self.annealing_step = annealing_step - if reduction not in ("none", "mean", "sum"): + if reduction not in ("none", "mean", "sum") and reduction is not None: raise ValueError(f"{reduction} is not a valid value for reduction.") self.reduction = reduction @@ -417,3 +166,116 @@ def forward( elif self.reduction == "sum": loss = loss.sum() return loss + + +class ConfidencePenaltyLoss(nn.Module): + def __init__( + self, + reg_weight: float = 1, + reduction: str | None = "mean", + eps: float = 1e-6, + ) -> None: + """The Confidence Penalty Loss. + + Args: + reg_weight (float, optional): The weight of the regularization term. + reduction (str, optional): specifies the reduction to apply to the + output:``'none'`` | ``'mean'`` | ``'sum'``. Defaults to "mean". + eps (float, optional): A small value to avoid numerical instability. + Defaults to 1e-6. + + Reference: + Gabriel Pereyra: Regularizing neural networks by penalizing + confident output distributions. https://arxiv.org/pdf/1701.06548. + + """ + super().__init__() + if reduction is None: + reduction = "none" + if reduction not in ("none", "mean", "sum"): + raise ValueError(f"{reduction} is not a valid value for reduction.") + self.reduction = reduction + if eps < 0: + raise ValueError( + "The epsilon value should be non-negative, but got " f"{eps}." + ) + self.eps = eps + if reg_weight < 0: + raise ValueError( + "The regularization weight should be non-negative, but got " + f"{reg_weight}." + ) + self.reg_weight = reg_weight + + def forward(self, logits: Tensor, targets: Tensor) -> Tensor: + """Compute the Confidence Penalty loss. + + Args: + logits (Tensor): The inputs of the Bayesian Neural Network + targets (Tensor): The target values + + Returns: + Tensor: The Confidence Penalty loss + """ + probs = F.softmax(logits, dim=1) + ce_loss = F.cross_entropy(logits, targets, reduction=self.reduction) + reg_loss = torch.log( + torch.tensor(logits.shape[-1], device=probs.device) + ) + (probs * torch.log(probs + self.eps)).sum(dim=-1) + if self.reduction == "sum": + return ce_loss + self.reg_weight * reg_loss.sum() + if self.reduction == "mean": + return ce_loss + self.reg_weight * reg_loss.mean() + return ce_loss + self.reg_weight * reg_loss + + +class ConflictualLoss(nn.Module): + def __init__( + self, + reg_weight: float = 1, + reduction: str | None = "mean", + ) -> None: + r"""The Conflictual Loss. + + Args: + reg_weight (float, optional): The weight of the regularization term. + reduction (str, optional): specifies the reduction to apply to the + output:``'none'`` | ``'mean'`` | ``'sum'``. + + Reference: + `Mohammed Fellaji et al. On the Calibration of Epistemic Uncertainty: + Principles, Paradoxes and Conflictual Loss `_. + """ + super().__init__() + if reduction is None: + reduction = "none" + if reduction not in ("none", "mean", "sum"): + raise ValueError(f"{reduction} is not a valid value for reduction.") + self.reduction = reduction + if reg_weight < 0: + raise ValueError( + "The regularization weight should be non-negative, but got " + f"{reg_weight}." + ) + self.reg_weight = reg_weight + + def forward(self, logits: Tensor, targets: Tensor) -> Tensor: + """Compute the conflictual loss. + + Args: + logits (Tensor): The outputs of the model. + targets (Tensor): The target values. + + Returns: + Tensor: The conflictual loss. + """ + class_index = torch.randint( + 0, logits.shape[-1], (1,), dtype=torch.long, device=logits.device + ) + ce_loss = F.cross_entropy(logits, targets, reduction=self.reduction) + reg_loss = -F.log_softmax(logits, dim=1)[:, class_index] + if self.reduction == "sum": + return ce_loss + self.reg_weight * reg_loss.sum() + if self.reduction == "mean": + return ce_loss + self.reg_weight * reg_loss.mean() + return ce_loss + self.reg_weight * reg_loss diff --git a/torch_uncertainty/losses/regression.py b/torch_uncertainty/losses/regression.py new file mode 100644 index 00000000..99b9b9fd --- /dev/null +++ b/torch_uncertainty/losses/regression.py @@ -0,0 +1,142 @@ +from typing import Literal + +import torch +from torch import Tensor, nn +from torch.distributions import Distribution + +from torch_uncertainty.utils.distributions import NormalInverseGamma + + +class DistributionNLLLoss(nn.Module): + def __init__( + self, reduction: Literal["mean", "sum"] | None = "mean" + ) -> None: + """Negative Log-Likelihood loss using given distributions as inputs. + + Args: + reduction (str, optional): specifies the reduction to apply to the + output:``'none'`` | ``'mean'`` | ``'sum'``. Defaults to "mean". + """ + super().__init__() + self.reduction = reduction + + def forward( + self, + dist: Distribution, + targets: Tensor, + padding_mask: Tensor | None = None, + ) -> Tensor: + """Compute the NLL of the targets given predicted distributions. + + Args: + dist (Distribution): The predicted distributions + targets (Tensor): The target values + padding_mask (Tensor, optional): The padding mask. Defaults to None. + Sets the loss to 0 for padded values. + """ + loss = -dist.log_prob(targets) + if padding_mask is not None: + loss = loss.masked_fill(padding_mask, 0.0) + + if self.reduction == "mean": + loss = loss.mean() + elif self.reduction == "sum": + loss = loss.sum() + return loss + + +class DERLoss(DistributionNLLLoss): + def __init__( + self, reg_weight: float, reduction: str | None = "mean" + ) -> None: + """The Deep Evidential Regression loss. + + This loss combines the negative log-likelihood loss of the normal + inverse gamma distribution and a weighted regularization term. + + Args: + reg_weight (float): The weight of the regularization term. + reduction (str, optional): specifies the reduction to apply to the + output:``'none'`` | ``'mean'`` | ``'sum'``. + + Reference: + Amini, A., Schwarting, W., Soleimany, A., & Rus, D. (2019). Deep + evidential regression. https://arxiv.org/abs/1910.02600. + """ + super().__init__(reduction=None) + + if reduction not in ("none", "mean", "sum") and reduction is not None: + raise ValueError(f"{reduction} is not a valid value for reduction.") + self.der_reduction = reduction + + if reg_weight < 0: + raise ValueError( + "The regularization weight should be non-negative, but got " + f"{reg_weight}." + ) + self.reg_weight = reg_weight + + def _reg(self, dist: NormalInverseGamma, targets: Tensor) -> Tensor: + return torch.norm(targets - dist.loc, 1, dim=1, keepdim=True) * ( + 2 * dist.lmbda + dist.alpha + ) + + def forward( + self, + dist: NormalInverseGamma, + targets: Tensor, + ) -> Tensor: + loss_nll = super().forward(dist, targets) + loss_reg = self._reg(dist, targets) + loss = loss_nll + self.reg_weight * loss_reg + + if self.der_reduction == "mean": + return loss.mean() + if self.der_reduction == "sum": + return loss.sum() + return loss + + +class BetaNLL(nn.Module): + def __init__( + self, beta: float = 0.5, reduction: str | None = "mean" + ) -> None: + """The Beta Negative Log-likelihood loss. + + Args: + beta (float): TParameter from range [0, 1] controlling relative + weighting between data points, where `0` corresponds to + high weight on low error points and `1` to an equal weighting. + reduction (str, optional): specifies the reduction to apply to the + output:``'none'`` | ``'mean'`` | ``'sum'``. + + Reference: + Seitzer, M., Tavakoli, A., Antic, D., & Martius, G. (2022). On the + pitfalls of heteroscedastic uncertainty estimation with probabilistic + neural networks. https://arxiv.org/abs/2203.09168. + """ + super().__init__() + + if beta < 0 or beta > 1: + raise ValueError( + "The beta parameter should be in range [0, 1], but got " + f"{beta}." + ) + self.beta = beta + self.nll_loss = nn.GaussianNLLLoss(reduction="none") + if reduction not in ("none", "mean", "sum"): + raise ValueError(f"{reduction} is not a valid value for reduction.") + self.reduction = reduction + + def forward( + self, mean: Tensor, targets: Tensor, variance: Tensor + ) -> Tensor: + loss = self.nll_loss(mean, targets, variance) * ( + variance.detach() ** self.beta + ) + + if self.reduction == "mean": + return loss.mean() + if self.reduction == "sum": + return loss.sum() + return loss diff --git a/torch_uncertainty/metrics/__init__.py b/torch_uncertainty/metrics/__init__.py index ee1a63b9..52e55366 100644 --- a/torch_uncertainty/metrics/__init__.py +++ b/torch_uncertainty/metrics/__init__.py @@ -1,5 +1,6 @@ # ruff: noqa: F401 from .classification import ( + AUGRC, AURC, AUSE, FPR95, @@ -8,6 +9,7 @@ CalibrationError, CategoricalNLL, CovAt5Risk, + CovAtxRisk, Disagreement, Entropy, GroupingLoss, diff --git a/torch_uncertainty/metrics/classification/__init__.py b/torch_uncertainty/metrics/classification/__init__.py index de375588..0e454888 100644 --- a/torch_uncertainty/metrics/classification/__init__.py +++ b/torch_uncertainty/metrics/classification/__init__.py @@ -5,10 +5,17 @@ from .categorical_nll import CategoricalNLL from .disagreement import Disagreement from .entropy import Entropy -from .fpr95 import FPR95, FPRx +from .fpr import FPR95, FPRx from .grouping_loss import GroupingLoss from .mean_iou import MeanIntersectionOverUnion from .mutual_information import MutualInformation -from .risk_coverage import AURC, CovAt5Risk, CovAtxRisk, RiskAt80Cov, RiskAtxCov +from .risk_coverage import ( + AUGRC, + AURC, + CovAt5Risk, + CovAtxRisk, + RiskAt80Cov, + RiskAtxCov, +) from .sparsification import AUSE from .variation_ratio import VariationRatio diff --git a/torch_uncertainty/metrics/classification/adaptive_calibration_error.py b/torch_uncertainty/metrics/classification/adaptive_calibration_error.py index 4f4c4850..c1e066d9 100644 --- a/torch_uncertainty/metrics/classification/adaptive_calibration_error.py +++ b/torch_uncertainty/metrics/classification/adaptive_calibration_error.py @@ -64,7 +64,7 @@ def _ace_compute( norm: Norm function to use when computing calibration error. Defaults to "l1". debias: Apply debiasing to L2 norm computation as in - `Verified Uncertainty Calibration`_. Defaults to False. + `Verified Uncertainty Calibration`. Defaults to False. Returns: Tensor: Adaptive Calibration error scalar. diff --git a/torch_uncertainty/metrics/classification/fpr95.py b/torch_uncertainty/metrics/classification/fpr.py similarity index 50% rename from torch_uncertainty/metrics/classification/fpr95.py rename to torch_uncertainty/metrics/classification/fpr.py index f7e4a660..214daded 100644 --- a/torch_uncertainty/metrics/classification/fpr95.py +++ b/torch_uncertainty/metrics/classification/fpr.py @@ -1,4 +1,3 @@ -import numpy as np import torch from torch import Tensor from torchmetrics import Metric @@ -23,7 +22,8 @@ def __init__(self, recall_level: float, pos_label: int, **kwargs) -> None: kwargs: Additional arguments to pass to the metric class. Reference: - Inpired by https://github.com/hendrycks/anomaly-seg. + Improved from https://github.com/hendrycks/anomaly-seg and + translated to torch. """ super().__init__(**kwargs) @@ -47,68 +47,69 @@ def update(self, conf: Tensor, target: Tensor) -> None: Args: conf (Tensor): The confidence scores. - target (Tensor): The target labels. + target (Tensor): The target labels, 0 if ID, 1 if OOD. """ self.conf.append(conf) self.targets.append(target) def compute(self) -> Tensor: - """Compute the actual False Positive Rate at x% Recall. + """Compute the False Positive Rate at x% Recall. Returns: Tensor: The value of the FPRx. """ - conf = dim_zero_cat(self.conf).cpu().numpy() - targets = dim_zero_cat(self.targets).cpu().numpy() - - # out_labels is an array of 0s and 1s - 0 if IOD 1 if OOD - out_labels = targets == self.pos_label - - in_scores = conf[np.logical_not(out_labels)] - out_scores = conf[out_labels] - - neg = np.array(in_scores[:]).reshape((-1, 1)) - pos = np.array(out_scores[:]).reshape((-1, 1)) - examples = np.squeeze(np.vstack((pos, neg))) - labels = np.zeros(len(examples), dtype=np.int32) - labels[: len(pos)] += 1 - - # make labels a boolean vector, True if OOD - labels = labels == self.pos_label - - # sort scores and corresponding truth values - desc_score_indices = np.argsort(examples, kind="mergesort")[::-1] - examples = examples[desc_score_indices] - labels = labels[desc_score_indices] - - # examples typically has many tied values. Here we extract - # the indices associated with the distinct values. We also - # concatenate a value for the end of the curve. - distinct_value_indices = np.where(np.diff(examples))[0] - threshold_idxs = np.r_[distinct_value_indices, labels.shape[0] - 1] + conf = dim_zero_cat(self.conf) + targets = dim_zero_cat(self.targets) + + # map examples and labels to OOD first + indx = torch.argsort(targets, descending=True) + examples = conf[indx] + labels = torch.zeros_like(targets, dtype=torch.bool, device=self.device) + labels[: torch.count_nonzero(targets)] = True + + # sort examples and labels by decreasing confidence + desc_scores_indx = torch.argsort(examples, descending=True) + examples = examples[desc_scores_indx] + labels = labels[desc_scores_indx] + + # Get the indices of the distinct values + distinct_value_indices = torch.where(torch.diff(examples))[0] + threshold_idxs = torch.cat( + [ + distinct_value_indices, + torch.tensor( + [labels.shape[0] - 1], dtype=torch.long, device=self.device + ), + ] + ) # accumulate the true positives with decreasing threshold - tps = np.cumsum(labels)[threshold_idxs] - fps = 1 + threshold_idxs - tps # add one because of zero-based indexing - - thresholds = examples[threshold_idxs] - - recall = tps / tps[-1] - - last_ind = tps.searchsorted(tps[-1]) - sl = slice(last_ind, None, -1) # [last_ind::-1] - recall, fps, tps, thresholds = ( - np.r_[recall[sl], 1], - np.r_[fps[sl], 0], - np.r_[tps[sl], 0], - thresholds[sl], + true_pos = torch.cumsum(labels, dim=0)[threshold_idxs] + false_pos = ( + 1 + threshold_idxs - true_pos + ) # add one because of zero-based indexing + + # check that there is at least one OOD example + if true_pos[-1] == 0: + return torch.tensor([torch.nan], device=self.device) + + recall = true_pos / true_pos[-1] + + last_ind = torch.searchsorted(true_pos, true_pos[-1]) + recall = torch.cat( + [ + recall[: last_ind + 1].flip(0), + torch.tensor([1.0], device=self.device), + ] ) - - cutoff = np.argmin(np.abs(recall - self.recall_level)) - - return torch.tensor( - fps[cutoff] / (np.sum(np.logical_not(labels))), dtype=torch.float32 + false_pos = torch.cat( + [ + false_pos[: last_ind + 1].flip(0), + torch.tensor([0.0], device=self.device), + ] ) + cutoff = torch.argmin(torch.abs(recall - self.recall_level)) + return false_pos[cutoff] / (~labels).sum() class FPR95(FPRx): @@ -116,7 +117,6 @@ def __init__(self, pos_label: int, **kwargs) -> None: """The False Positive Rate at 95% Recall metric. Args: - recall_level (float): The recall level at which to compute the FPR. pos_label (int): The positive label. kwargs: Additional arguments to pass to the metric class. """ diff --git a/torch_uncertainty/metrics/classification/grouping_loss.py b/torch_uncertainty/metrics/classification/grouping_loss.py index da9eab41..155e111f 100644 --- a/torch_uncertainty/metrics/classification/grouping_loss.py +++ b/torch_uncertainty/metrics/classification/grouping_loss.py @@ -117,7 +117,7 @@ def update(self, probs: Tensor, target: Tensor, features: Tensor) -> None: f"{features.shape}." ) - def compute(self) -> torch.Tensor: + def compute(self) -> Tensor: """Compute the final Brier score based on inputs passed to ``update``. Returns: diff --git a/torch_uncertainty/metrics/classification/risk_coverage.py b/torch_uncertainty/metrics/classification/risk_coverage.py index 264d363f..dced8409 100644 --- a/torch_uncertainty/metrics/classification/risk_coverage.py +++ b/torch_uncertainty/metrics/classification/risk_coverage.py @@ -4,7 +4,7 @@ import numpy as np import torch from torch import Tensor -from torchmetrics.metric import Metric +from torchmetrics import Metric from torchmetrics.utilities.compute import _auc_compute from torchmetrics.utilities.data import dim_zero_cat from torchmetrics.utilities.plot import _AX_TYPE @@ -86,12 +86,9 @@ def compute(self) -> Tensor: error_rates = self.partial_compute() num_samples = error_rates.size(0) if num_samples < 2: - return torch.tensor([float("nan")], device=error_rates.device) - x = ( - torch.arange(1, num_samples + 1, device=error_rates.device) - / num_samples - ) - return _auc_compute(x, error_rates) / (1 - 1 / num_samples) + return torch.tensor([float("nan")], device=self.device) + cov = torch.arange(1, num_samples + 1, device=self.device) / num_samples + return _auc_compute(cov, error_rates) / (1 - 1 / num_samples) def plot( self, @@ -100,7 +97,7 @@ def plot( name: str | None = None, ) -> tuple[plt.Figure | None, plt.Axes]: """Plot the risk-cov. curve corresponding to the inputs passed to - ``update``, and the oracle risk-cov. curve. + ``update``. Args: ax (Axes | None, optional): An matplotlib axis object. If provided @@ -114,7 +111,7 @@ def plot( """ fig, ax = plt.subplots(figsize=(6, 6)) if ax is None else (None, ax) - # Computation of AUSEC + # Computation of AURC error_rates = self.partial_compute().cpu().flip(0) num_samples = error_rates.size(0) @@ -139,7 +136,7 @@ def plot( ax.text( 0.02, 0.95, - f"AUSEC={aurc:.2%}", + f"AURC={aurc:.2%}", color="black", ha="left", va="bottom", @@ -166,13 +163,119 @@ def _aurc_rejection_rate_compute( scores (Tensor): uncertainty scores of shape :math:`(B,)` errors (Tensor): binary errors of shape :math:`(B,)` """ - num_samples = scores.size(0) errors = errors[scores.argsort(descending=True)] return errors.cumsum(dim=-1) / torch.arange( - 1, num_samples + 1, dtype=scores.dtype, device=scores.device + 1, scores.size(0) + 1, dtype=scores.dtype, device=scores.device ) +class AUGRC(AURC): + """Area Under the Generalized Risk-Coverage curve. + + The Area Under the Generalized Risk-Coverage curve (AUGRC) for + Selective Classification (SC) performance assessment. It avoids putting too much + weight on the most confident samples. + + As input to ``forward`` and ``update`` the metric accepts the following input: + + - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape + ``(N, ...)`` containing probabilities for each observation. + - ``target`` (:class:`~torch.Tensor`): An int tensor of shape + ``(N, ...)`` containing ground-truth labels. + + As output to ``forward`` and ``compute`` the metric returns the + following output: + + - ``augrc`` (:class:`~torch.Tensor`): A scalar tensor containing the + area under the risk-coverage curve + + Args: + kwargs: Additional keyword arguments. + + Reference: + Traub et al. Overcoming Common Flaws in the Evaluation of Selective + Classification Systems. ArXiv. + """ + + def compute(self) -> Tensor: + """Compute the Area Under the Generalized Risk-Coverage curve (AUGRC). + + Normalize the AUGRC as if its support was between 0 and 1. This has an + impact on the AUGRC when the number of samples is small. + + Returns: + Tensor: The AUGRC. + """ + error_rates = self.partial_compute() + num_samples = error_rates.size(0) + if num_samples < 2: + return torch.tensor([float("nan")], device=self.device) + cov = torch.arange(1, num_samples + 1, device=self.device) / num_samples + return _auc_compute(cov, error_rates * cov) / (1 - 1 / num_samples) + + def plot( + self, + ax: _AX_TYPE | None = None, + plot_value: bool = True, + name: str | None = None, + ) -> tuple[plt.Figure | None, plt.Axes]: + """Plot the generalized risk-cov. curve corresponding to the inputs passed to + ``update``. + + Args: + ax (Axes | None, optional): An matplotlib axis object. If provided + will add plot to this axis. Defaults to None. + plot_value (bool, optional): Whether to print the AURC value on the + plot. Defaults to True. + name (str | None, optional): Name of the model. Defaults to None. + + Returns: + tuple[[Figure | None], Axes]: Figure object and Axes object + """ + fig, ax = plt.subplots(figsize=(6, 6)) if ax is None else (None, ax) + + # Computation of AUGRC + error_rates = self.partial_compute().cpu().flip(0) + num_samples = error_rates.size(0) + cov = torch.arange(num_samples) / num_samples + + augrc = _auc_compute(cov, error_rates * cov).cpu().item() + + # reduce plot size + plot_covs = np.arange(0.01, 100 + 0.01, 0.01) + covs = np.arange(start=1, stop=num_samples + 1) / num_samples + + rejection_rates = np.interp(plot_covs, covs, cov * 100) + error_rates = np.interp(plot_covs, covs, error_rates * covs[::-1] * 100) + + # plot + ax.plot( + 100 - rejection_rates, + error_rates, + label="Model" if name is None else name, + ) + + if plot_value: + ax.text( + 0.02, + 0.95, + f"AUGRC={augrc:.2%}", + color="black", + ha="left", + va="bottom", + transform=ax.transAxes, + ) + plt.grid(True, linestyle="--", alpha=0.7, zorder=0) + ax.set_xlabel("Coverage (%)", fontsize=16) + ax.set_ylabel("Generalized Risk (%)", fontsize=16) + ax.set_xlim(0, 100) + ax.set_ylim(0, 100) + ax.set_aspect("equal", "box") + ax.legend(loc="upper right") + fig.tight_layout() + return fig, ax + + class CovAtxRisk(Metric): is_differentiable: bool = False higher_is_better: bool = False @@ -182,7 +285,7 @@ class CovAtxRisk(Metric): errors: list[Tensor] def __init__(self, risk_threshold: float, **kwargs) -> None: - r"""`Coverage at x Risk`_. + r"""Coverage at x Risk. If there are multiple coverage values corresponding to the given risk, i.e., the risk(coverage) is not monotonic, the coverage at x risk is @@ -223,7 +326,7 @@ def compute(self) -> Tensor: errors = dim_zero_cat(self.errors) num_samples = scores.size(0) if num_samples < 1: - return torch.tensor([float("nan")], device=scores.device) + return torch.tensor([float("nan")], device=self.device) error_rates = _aurc_rejection_rate_compute(scores, errors) admissible_risks = (error_rates > self.risk_threshold) * 1 max_cov_at_risk = admissible_risks.flip(0).argmin() @@ -231,7 +334,7 @@ def compute(self) -> Tensor: # check if max_cov_at_risk is really admissible, if not return nan risk = admissible_risks[max_cov_at_risk] if risk > self.risk_threshold: - return torch.tensor([float("nan")], device=scores.device) + return torch.tensor([float("nan")], device=self.device) return 1 - max_cov_at_risk / num_samples diff --git a/torch_uncertainty/models/lenet.py b/torch_uncertainty/models/lenet.py index 36b175c0..55a4c772 100644 --- a/torch_uncertainty/models/lenet.py +++ b/torch_uncertainty/models/lenet.py @@ -25,7 +25,6 @@ def __init__( norm: type[nn.Module], groups: int, dropout_rate: float, - last_layer_dropout: bool, ) -> None: super().__init__() self.activation = activation @@ -44,7 +43,6 @@ def __init__( ) self.dropout_rate = dropout_rate - self.last_layer_dropout = last_layer_dropout self.conv1 = conv2d_layer( in_channels, 6, (5, 5), groups=groups, **layer_args @@ -88,7 +86,6 @@ def _lenet( norm: type[nn.Module] = nn.Identity, groups: int = 1, dropout_rate: float = 0.0, - last_layer_dropout: bool = False, ) -> _LeNet | StochasticModel: model = _LeNet( in_channels=in_channels, @@ -100,7 +97,6 @@ def _lenet( groups=groups, layer_args=layer_args, dropout_rate=dropout_rate, - last_layer_dropout=last_layer_dropout, ) if stochastic: return StochasticModel(model, num_samples) @@ -114,7 +110,6 @@ def lenet( norm: type[nn.Module] = nn.Identity, groups: int = 1, dropout_rate: float = 0.0, - last_layer_dropout: bool = False, ) -> _LeNet: return _lenet( stochastic=False, @@ -127,7 +122,6 @@ def lenet( norm=norm, groups=groups, dropout_rate=dropout_rate, - last_layer_dropout=last_layer_dropout, ) diff --git a/torch_uncertainty/models/mlp.py b/torch_uncertainty/models/mlp.py index d0fdee07..720ce7f0 100644 --- a/torch_uncertainty/models/mlp.py +++ b/torch_uncertainty/models/mlp.py @@ -39,7 +39,6 @@ def __init__( super().__init__() self.activation = activation self.dropout_rate = dropout_rate - layers = nn.ModuleList() if len(hidden_dims) == 0: diff --git a/torch_uncertainty/models/resnet/batched.py b/torch_uncertainty/models/resnet/batched.py index 795cd13e..cd89e63c 100644 --- a/torch_uncertainty/models/resnet/batched.py +++ b/torch_uncertainty/models/resnet/batched.py @@ -332,7 +332,10 @@ def batched_resnet( Returns: _BatchedResNet: A BatchEnsemble-style ResNet. """ - block = _BasicBlock if arch in [18, 20, 34] else _Bottleneck + block = ( + _BasicBlock if arch in [18, 20, 34, 44, 56, 110, 1202] else _Bottleneck + ) + in_planes = 16 if arch in [20, 44, 56, 110, 1202] else 64 return _BatchedResNet( block=block, num_blocks=get_resnet_num_blocks(arch), @@ -343,6 +346,6 @@ def batched_resnet( dropout_rate=dropout_rate, groups=groups, style=style, - in_planes=int(64 * width_multiplier), + in_planes=int(in_planes * width_multiplier), normalization_layer=normalization_layer, ) diff --git a/torch_uncertainty/models/resnet/lpbnn.py b/torch_uncertainty/models/resnet/lpbnn.py index 36c4d103..b79de57c 100644 --- a/torch_uncertainty/models/resnet/lpbnn.py +++ b/torch_uncertainty/models/resnet/lpbnn.py @@ -326,7 +326,10 @@ def lpbnn_resnet( groups: int = 1, style: Literal["imagenet", "cifar"] = "imagenet", ) -> _LPBNNResNet: - block = _BasicBlock if arch in [18, 20, 34] else _Bottleneck + block = ( + _BasicBlock if arch in [18, 20, 34, 44, 56, 110, 1202] else _Bottleneck + ) + in_planes = 16 if arch in [20, 44, 56, 110, 1202] else 64 return _LPBNNResNet( block=block, num_blocks=get_resnet_num_blocks(arch), @@ -337,5 +340,5 @@ def lpbnn_resnet( conv_bias=conv_bias, groups=groups, style=style, - in_planes=int(64 * width_multiplier), + in_planes=int(in_planes * width_multiplier), ) diff --git a/torch_uncertainty/models/resnet/masked.py b/torch_uncertainty/models/resnet/masked.py index 45398891..04117e67 100644 --- a/torch_uncertainty/models/resnet/masked.py +++ b/torch_uncertainty/models/resnet/masked.py @@ -351,7 +351,10 @@ def masked_resnet( Returns: _MaskedResNet: A Masksembles-style ResNet. """ - block = _BasicBlock if arch in [18, 20, 34] else _Bottleneck + block = ( + _BasicBlock if arch in [18, 20, 34, 44, 56, 110, 1202] else _Bottleneck + ) + in_planes = 16 if arch in [20, 44, 56, 110, 1202] else 64 return _MaskedResNet( block=block, num_blocks=get_resnet_num_blocks(arch), @@ -363,6 +366,6 @@ def masked_resnet( conv_bias=conv_bias, dropout_rate=dropout_rate, style=style, - in_planes=int(64 * width_multiplier), + in_planes=int(in_planes * width_multiplier), normalization_layer=normalization_layer, ) diff --git a/torch_uncertainty/models/resnet/mimo.py b/torch_uncertainty/models/resnet/mimo.py index bf16a933..11ee2228 100644 --- a/torch_uncertainty/models/resnet/mimo.py +++ b/torch_uncertainty/models/resnet/mimo.py @@ -62,7 +62,10 @@ def mimo_resnet( style: Literal["imagenet", "cifar"] = "imagenet", normalization_layer: type[nn.Module] = nn.BatchNorm2d, ) -> _MIMOResNet: - block = _BasicBlock if arch in [18, 20, 34] else _Bottleneck + block = ( + _BasicBlock if arch in [18, 20, 34, 44, 56, 110, 1202] else _Bottleneck + ) + in_planes = 16 if arch in [20, 44, 56, 110, 1202] else 64 return _MIMOResNet( block=block, num_blocks=get_resnet_num_blocks(arch), @@ -73,6 +76,6 @@ def mimo_resnet( dropout_rate=dropout_rate, groups=groups, style=style, - in_planes=int(64 * width_multiplier), + in_planes=int(in_planes * width_multiplier), normalization_layer=normalization_layer, ) diff --git a/torch_uncertainty/models/resnet/packed.py b/torch_uncertainty/models/resnet/packed.py index 1cdd2f98..4bf170d8 100644 --- a/torch_uncertainty/models/resnet/packed.py +++ b/torch_uncertainty/models/resnet/packed.py @@ -46,7 +46,7 @@ def __init__( in_planes: int, planes: int, stride: int, - alpha: float, + alpha: int, num_estimators: int, gamma: int, conv_bias: bool, @@ -116,7 +116,7 @@ def __init__( in_planes: int, planes: int, stride: int, - alpha: float, + alpha: int, num_estimators: int, gamma: int, conv_bias: bool, @@ -333,7 +333,7 @@ def _make_layer( planes: int, num_blocks: int, stride: int, - alpha: float, + alpha: int, num_estimators: int, conv_bias: bool, dropout_rate: float, @@ -425,7 +425,10 @@ def packed_resnet( Returns: _PackedResNet: A Packed-Ensembles ResNet. """ - block = _BasicBlock if arch in [18, 20, 34] else _Bottleneck + block = ( + _BasicBlock if arch in [18, 20, 34, 44, 56, 110, 1202] else _Bottleneck + ) + in_planes = 16 if arch in [20, 44, 56, 110, 1202] else 64 net = _PackedResNet( block=block, num_blocks=get_resnet_num_blocks(arch), @@ -438,7 +441,7 @@ def packed_resnet( groups=groups, num_classes=num_classes, style=style, - in_planes=int(64 * width_multiplier), + in_planes=int(in_planes * width_multiplier), normalization_layer=normalization_layer, ) if pretrained: # coverage: ignore diff --git a/torch_uncertainty/models/resnet/std.py b/torch_uncertainty/models/resnet/std.py index cdf1303d..b07e7fc6 100644 --- a/torch_uncertainty/models/resnet/std.py +++ b/torch_uncertainty/models/resnet/std.py @@ -19,9 +19,9 @@ def __init__( stride: int, dropout_rate: float, groups: int, + conv_bias: bool, activation_fn: Callable, normalization_layer: type[nn.Module], - conv_bias: bool, ) -> None: super().__init__() self.activation_fn = activation_fn @@ -358,7 +358,7 @@ def resnet( activation_fn: Callable = relu, normalization_layer: type[nn.Module] = nn.BatchNorm2d, ) -> _ResNet: - """ResNet-18 model. + """ResNet model. Args: in_channels (int): Number of input channels. @@ -366,20 +366,22 @@ def resnet( arch (int): The architecture of the ResNet. conv_bias (bool): Whether to use bias in convolutions. Defaults to ``True``. - conv_bias (bool): Whether to use bias in convolutions. Defaults to - ``True``. - dropout_rate (float): Dropout rate. Defaults to 0. - width_multiplier (float): Width multiplier. Defaults to 1. + dropout_rate (float): Dropout rate. Defaults to 0.0. + width_multiplier (float): Width multiplier. Defaults to 1.0. groups (int): Number of groups in convolutions. Defaults to 1. style (bool, optional): Whether to use the ImageNet structure. Defaults to ``True``. - activation_fn (Callable, optional): Activation function. + activation_fn (Callable, optional): Activation function. Defaults to + ``torch.nn.functional.relu``. normalization_layer (nn.Module, optional): Normalization layer. Returns: _ResNet: The ResNet model. """ - block = _BasicBlock if arch in [18, 20, 34] else _Bottleneck + block = ( + _BasicBlock if arch in [18, 20, 34, 44, 56, 110, 1202] else _Bottleneck + ) + in_planes = 16 if arch in [20, 44, 56, 110, 1202] else 64 return _ResNet( block=block, num_blocks=get_resnet_num_blocks(arch), @@ -389,7 +391,7 @@ def resnet( dropout_rate=dropout_rate, groups=groups, style=style, - in_planes=int(64 * width_multiplier), + in_planes=int(in_planes * width_multiplier), activation_fn=activation_fn, normalization_layer=normalization_layer, ) diff --git a/torch_uncertainty/models/resnet/utils.py b/torch_uncertainty/models/resnet/utils.py index 0e082509..caf2cde9 100644 --- a/torch_uncertainty/models/resnet/utils.py +++ b/torch_uncertainty/models/resnet/utils.py @@ -5,10 +5,18 @@ def get_resnet_num_blocks(arch: int) -> list[int]: num_blocks = [3, 3, 3] elif arch == 34 or arch == 50: num_blocks = [3, 4, 6, 3] + elif arch == 44: + num_blocks = [7, 7, 7] + elif arch == 56: + num_blocks = [9, 9, 9] elif arch == 101: num_blocks = [3, 4, 23, 3] + elif arch == 110: + num_blocks = [18, 18, 18] elif arch == 152: num_blocks = [3, 8, 36, 3] + elif arch == 1202: + num_blocks = [200, 200, 200] else: raise ValueError(f"Unknown ResNet architecture. Got {arch}.") return num_blocks diff --git a/torch_uncertainty/models/segmentation/segformer.py b/torch_uncertainty/models/segmentation/segformer.py index 6c34dfcb..e13258e8 100644 --- a/torch_uncertainty/models/segmentation/segformer.py +++ b/torch_uncertainty/models/segmentation/segformer.py @@ -1,3 +1,4 @@ +import logging import math from functools import partial @@ -522,11 +523,14 @@ def resize( and (output_h - 1) % (input_h - 1) and (output_w - 1) % (input_w - 1) ): - print( - f"When align_corners={align_corners}, " + logging.info( + "When align_corners=%s, " "the output would more aligned if " - f"input size {(input_h, input_w)} is `x+1` and " - f"out size {(output_h, output_w)} is `nx+1`", + "input size %s is `x+1` and " + "out size %s is `nx+1`", + align_corners, + (input_h, input_w), + (output_h, output_w), ) if isinstance(size, torch.Size): size = tuple(int(x) for x in size) diff --git a/torch_uncertainty/models/utils.py b/torch_uncertainty/models/utils.py index 87fd65ee..ecf06466 100644 --- a/torch_uncertainty/models/utils.py +++ b/torch_uncertainty/models/utils.py @@ -1,4 +1,5 @@ from torch import Tensor, nn +from torch.nn.modules.batchnorm import _BatchNorm class Backbone(nn.Module): @@ -27,14 +28,20 @@ def forward(self, x: Tensor) -> list[Tensor]: """ feature = x features = [] - for k, v in self.model._modules.items(): - feature = v(feature) - if k in self.feat_names: + for key, layer in self.model._modules.items(): + feature = layer(feature) + if key in self.feat_names: features.append(feature) return features def set_bn_momentum(model: nn.Module, momentum: float) -> None: + """Set the momentum of all batch normalization layers in the model. + + Args: + model (nn.Module): Model. + momentum (float): Momentum of the batch normalization layers. + """ for m in model.modules(): - if isinstance(m, nn.BatchNorm2d): + if isinstance(m, _BatchNorm): m.momentum = momentum diff --git a/torch_uncertainty/models/wideresnet/batched.py b/torch_uncertainty/models/wideresnet/batched.py index b27b8f99..792c0e46 100644 --- a/torch_uncertainty/models/wideresnet/batched.py +++ b/torch_uncertainty/models/wideresnet/batched.py @@ -1,7 +1,8 @@ +from collections.abc import Callable from typing import Literal -import torch.nn.functional as F from torch import Tensor, nn +from torch.nn.functional import relu from torch_uncertainty.layers import BatchConv2d, BatchLinear @@ -15,20 +16,22 @@ def __init__( self, in_planes: int, planes: int, - conv_bias: bool, dropout_rate: float, stride: int, num_estimators: int, groups: int, + conv_bias: bool, + activation_fn: Callable, ) -> None: super().__init__() + self.activation_fn = activation_fn self.conv1 = BatchConv2d( in_planes, planes, kernel_size=3, num_estimators=num_estimators, - groups=groups, padding=1, + groups=groups, bias=conv_bias, ) self.dropout = nn.Dropout2d(p=dropout_rate) @@ -38,11 +41,13 @@ def __init__( planes, kernel_size=3, num_estimators=num_estimators, - groups=groups, stride=stride, padding=1, + groups=groups, bias=conv_bias, ) + self.bn2 = nn.BatchNorm2d(planes) + self.shortcut = nn.Sequential() if stride != 1 or in_planes != planes: self.shortcut = nn.Sequential( @@ -51,19 +56,17 @@ def __init__( planes, kernel_size=1, num_estimators=num_estimators, - groups=groups, stride=stride, + groups=groups, bias=conv_bias, ), ) - self.bn2 = nn.BatchNorm2d(planes) - def forward(self, x: Tensor) -> Tensor: - out = F.relu(self.bn1(self.dropout(self.conv1(x)))) + out = self.activation_fn(self.bn1(self.dropout(self.conv1(x)))) out = self.conv2(out) out += self.shortcut(x) - return F.relu(self.bn2(out)) + return self.activation_fn(self.bn2(out)) class _BatchWideResNet(nn.Module): @@ -78,17 +81,22 @@ def __init__( dropout_rate: float, groups: int = 1, style: Literal["imagenet", "cifar"] = "imagenet", + activation_fn: Callable = relu, ) -> None: super().__init__() self.num_estimators = num_estimators + self.activation_fn = activation_fn self.in_planes = 16 if (depth - 4) % 6 != 0: - raise ValueError("Wide-resnet depth should be 6n+4.") + raise ValueError(f"Wide-resnet depth should be 6n+4. Got {depth}.") num_blocks = (depth - 4) // 6 - k = widen_factor - - num_stages = [16, 16 * k, 32 * k, 64 * k] + num_stages = [ + 16, + 16 * widen_factor, + 32 * widen_factor, + 64 * widen_factor, + ] if style == "imagenet": self.conv1 = BatchConv2d( @@ -99,7 +107,7 @@ def __init__( kernel_size=7, stride=2, padding=3, - bias=True, + bias=conv_bias, ) elif style == "cifar": self.conv1 = BatchConv2d( @@ -110,7 +118,7 @@ def __init__( kernel_size=3, stride=1, padding=1, - bias=True, + bias=conv_bias, ) else: raise ValueError(f"Unknown WideResNet style: {style}. ") @@ -128,37 +136,39 @@ def __init__( _WideBasicBlock, num_stages[1], num_blocks=num_blocks, - conv_bias=conv_bias, dropout_rate=dropout_rate, stride=1, num_estimators=self.num_estimators, groups=groups, + conv_bias=conv_bias, + activation_fn=activation_fn, ) self.layer2 = self._wide_layer( _WideBasicBlock, num_stages[2], num_blocks=num_blocks, - conv_bias=conv_bias, dropout_rate=dropout_rate, stride=2, num_estimators=self.num_estimators, groups=groups, + conv_bias=conv_bias, + activation_fn=activation_fn, ) self.layer3 = self._wide_layer( _WideBasicBlock, num_stages[3], num_blocks=num_blocks, - conv_bias=conv_bias, dropout_rate=dropout_rate, stride=2, num_estimators=self.num_estimators, groups=groups, + conv_bias=conv_bias, + activation_fn=activation_fn, ) self.dropout = nn.Dropout(p=dropout_rate) self.pool = nn.AdaptiveAvgPool2d(output_size=1) self.flatten = nn.Flatten(1) - self.linear = BatchLinear( num_stages[3], num_classes, @@ -170,11 +180,12 @@ def _wide_layer( block: type[nn.Module], planes: int, num_blocks: int, - conv_bias: bool, dropout_rate: float, stride: int, num_estimators: int, groups: int, + conv_bias: bool, + activation_fn: Callable, ) -> nn.Module: strides = [stride] + [1] * (int(num_blocks) - 1) layers = [] @@ -189,22 +200,24 @@ def _wide_layer( stride=stride, num_estimators=num_estimators, groups=groups, + activation_fn=activation_fn, ) ) self.in_planes = planes - return nn.Sequential(*layers) - def forward(self, x: Tensor) -> Tensor: + def feats_forward(self, x: Tensor) -> Tensor: out = x.repeat(self.num_estimators, 1, 1, 1) - out = F.relu(self.bn1(self.conv1(out))) + out = self.activation_fn(self.bn1(self.conv1(out))) out = self.optional_pool(out) out = self.layer1(out) out = self.layer2(out) out = self.layer3(out) out = self.pool(out) - out = self.dropout(self.flatten(out)) - return self.linear(out) + return self.dropout(self.flatten(out)) + + def forward(self, x: Tensor) -> Tensor: + return self.linear(self.feats_forward(x)) def batched_wideresnet28x10( @@ -221,11 +234,11 @@ def batched_wideresnet28x10( Args: in_channels (int): Number of input channels. num_estimators (int): Number of estimators in the ensemble. - groups (int): Number of groups in the convolutions. conv_bias (bool): Whether to use bias in convolutions. Defaults to ``True``. dropout_rate (float, optional): Dropout rate. Defaults to ``0.3``. num_classes (int): Number of classes to predict. + groups (int): Number of groups in the convolutions. Defaults to ``1``. style (bool, optional): Whether to use the ImageNet structure. Defaults to ``True``. diff --git a/torch_uncertainty/models/wideresnet/masked.py b/torch_uncertainty/models/wideresnet/masked.py index b0d00581..3a90be81 100644 --- a/torch_uncertainty/models/wideresnet/masked.py +++ b/torch_uncertainty/models/wideresnet/masked.py @@ -1,7 +1,8 @@ +from collections.abc import Callable from typing import Literal -import torch.nn.functional as F from torch import Tensor, nn +from torch.nn.functional import relu from torch_uncertainty.layers import MaskedConv2d, MaskedLinear @@ -21,17 +22,19 @@ def __init__( num_estimators: int, scale: float, groups: int, + activation_fn: Callable, ) -> None: super().__init__() + self.activation_fn = activation_fn self.conv1 = MaskedConv2d( in_planes, planes, kernel_size=3, num_estimators=num_estimators, padding=1, - bias=conv_bias, scale=scale, groups=groups, + bias=conv_bias, ) self.dropout = nn.Dropout2d(p=dropout_rate) self.bn1 = nn.BatchNorm2d(planes) @@ -42,10 +45,12 @@ def __init__( num_estimators=num_estimators, stride=stride, padding=1, - bias=conv_bias, scale=scale, groups=groups, + bias=conv_bias, ) + self.bn2 = nn.BatchNorm2d(planes) + self.shortcut = nn.Sequential() if stride != 1 or in_planes != planes: self.shortcut = nn.Sequential( @@ -55,18 +60,17 @@ def __init__( kernel_size=1, num_estimators=num_estimators, stride=stride, - bias=conv_bias, scale=scale, groups=groups, + bias=conv_bias, ), ) - self.bn2 = nn.BatchNorm2d(planes) def forward(self, x: Tensor) -> Tensor: - out = F.relu(self.bn1(self.dropout(self.conv1(x)))) + out = self.activation_fn(self.bn1(self.dropout(self.conv1(x)))) out = self.conv2(out) out += self.shortcut(x) - return F.relu(self.bn2(out)) + return self.activation_fn(self.bn2(out)) class _MaskedWideResNet(nn.Module): @@ -82,17 +86,22 @@ def __init__( scale: float = 2.0, groups: int = 1, style: Literal["imagenet", "cifar"] = "imagenet", + activation_fn: Callable = relu, ) -> None: super().__init__() self.num_estimators = num_estimators + self.activation_fn = activation_fn self.in_planes = 16 if (depth - 4) % 6 != 0: - raise ValueError("Wide-resnet depth should be 6n+4.") + raise ValueError(f"Wide-resnet depth should be 6n+4. Got {depth}.") num_blocks = (depth - 4) // 6 - k = widen_factor - - num_stages = [16, 16 * k, 32 * k, 64 * k] + num_stages = [ + 16, + 16 * widen_factor, + 32 * widen_factor, + 64 * widen_factor, + ] if style == "imagenet": self.conv1 = nn.Conv2d( @@ -136,6 +145,7 @@ def __init__( num_estimators=self.num_estimators, scale=scale, groups=groups, + activation_fn=activation_fn, ) self.layer2 = self._wide_layer( _WideBasicBlock, @@ -147,6 +157,7 @@ def __init__( num_estimators=self.num_estimators, scale=scale, groups=groups, + activation_fn=activation_fn, ) self.layer3 = self._wide_layer( _WideBasicBlock, @@ -158,6 +169,7 @@ def __init__( num_estimators=self.num_estimators, scale=scale, groups=groups, + activation_fn=activation_fn, ) self.dropout = nn.Dropout(p=dropout_rate) @@ -179,6 +191,7 @@ def _wide_layer( num_estimators: int, scale: float = 2.0, groups: int = 1, + activation_fn: Callable = relu, ) -> nn.Module: strides = [stride] + [1] * (int(num_blocks) - 1) layers = [] @@ -194,22 +207,24 @@ def _wide_layer( dropout_rate=dropout_rate, scale=scale, groups=groups, + activation_fn=activation_fn, ) ) self.in_planes = planes - return nn.Sequential(*layers) - def forward(self, x: Tensor) -> Tensor: + def feats_forward(self, x: Tensor) -> Tensor: out = x.repeat(self.num_estimators, 1, 1, 1) - out = F.relu(self.bn1(self.conv1(out))) + out = self.activation_fn(self.bn1(self.conv1(out))) out = self.optional_pool(out) out = self.layer1(out) out = self.layer2(out) out = self.layer3(out) out = self.pool(out) - out = self.dropout(self.flatten(out)) - return self.linear(out) + return self.dropout(self.flatten(out)) + + def forward(self, x: Tensor) -> Tensor: + return self.linear(self.feats_forward(x)) def masked_wideresnet28x10( @@ -217,9 +232,9 @@ def masked_wideresnet28x10( num_classes: int, num_estimators: int, scale: float, - groups: int, conv_bias: bool = True, dropout_rate: float = 0.3, + groups: int = 1, style: Literal["imagenet", "cifar"] = "imagenet", ) -> _MaskedWideResNet: """Masksembles of Wide-ResNet-28x10. @@ -229,10 +244,11 @@ def masked_wideresnet28x10( num_classes (int): Number of classes to predict. num_estimators (int): Number of estimators in the ensemble. scale (float): Expansion factor affecting the width of the estimators. - groups (int): Number of groups within each estimator. conv_bias (bool): Whether to use bias in convolutions. Defaults to ``True``. dropout_rate (float, optional): Dropout rate. Defaults to ``0.3``. + groups (int): Number of groups within each estimator. Defaults to + ``1``. style (bool, optional): Whether to use the ImageNet structure. Defaults to ``True``. diff --git a/torch_uncertainty/models/wideresnet/mimo.py b/torch_uncertainty/models/wideresnet/mimo.py index c3a25e0a..edb9a588 100644 --- a/torch_uncertainty/models/wideresnet/mimo.py +++ b/torch_uncertainty/models/wideresnet/mimo.py @@ -1,7 +1,9 @@ +from collections.abc import Callable from typing import Literal import torch from einops import rearrange +from torch.nn.functional import relu from .std import _WideResNet @@ -22,6 +24,7 @@ def __init__( dropout_rate: float, groups: int = 1, style: Literal["imagenet", "cifar"] = "imagenet", + activation_fn: Callable = relu, ) -> None: super().__init__( depth, @@ -32,25 +35,26 @@ def __init__( dropout_rate=dropout_rate, groups=groups, style=style, + activation_fn=activation_fn, ) - self.num_estimators = num_estimators def forward(self, x: torch.Tensor) -> torch.Tensor: if not self.training: x = x.repeat(self.num_estimators, 1, 1, 1) out = rearrange(x, "(m b) c h w -> b (m c) h w", m=self.num_estimators) - out = super().forward(out) - return rearrange(out, "b (m d) -> (m b) d", m=self.num_estimators) + return rearrange( + super().forward(out), "b (m d) -> (m b) d", m=self.num_estimators + ) def mimo_wideresnet28x10( in_channels: int, num_classes: int, num_estimators: int, - groups: int = 1, conv_bias: bool = True, dropout_rate: float = 0.3, + groups: int = 1, style: Literal["imagenet", "cifar"] = "imagenet", ) -> _MIMOWideResNet: return _MIMOWideResNet( diff --git a/torch_uncertainty/models/wideresnet/packed.py b/torch_uncertainty/models/wideresnet/packed.py index 8d16cecc..60fcc7cf 100644 --- a/torch_uncertainty/models/wideresnet/packed.py +++ b/torch_uncertainty/models/wideresnet/packed.py @@ -1,8 +1,10 @@ +from collections.abc import Callable from typing import Literal import torch.nn.functional as F from einops import rearrange from torch import Tensor, nn +from torch.nn.functional import relu from torch_uncertainty.layers import PackedConv2d, PackedLinear @@ -88,17 +90,22 @@ def __init__( gamma: int = 1, groups: int = 1, style: Literal["imagenet", "cifar"] = "imagenet", + activation_fn: Callable = relu, ) -> None: super().__init__() self.num_estimators = num_estimators + self.activation_fn = activation_fn self.in_planes = 16 if (depth - 4) % 6 != 0: - raise ValueError("Wide-resnet depth should be 6n+4.") + raise ValueError(f"Wide-resnet depth should be 6n+4. Got {depth}.") num_blocks = int((depth - 4) / 6) - k = widen_factor - - num_stages = [16, 16 * k, 32 * k, 64 * k] + num_stages = [ + 16, + 16 * widen_factor, + 32 * widen_factor, + 64 * widen_factor, + ] if style == "imagenet": self.conv1 = PackedConv2d( @@ -220,11 +227,10 @@ def _wide_layer( ) ) self.in_planes = planes - return nn.Sequential(*layers) - def forward(self, x: Tensor) -> Tensor: - out = F.relu(self.bn1(self.conv1(x))) + def feats_forward(self, x: Tensor) -> Tensor: + out = self.activation_fn(self.bn1(self.conv1(x))) out = self.optional_pool(out) out = self.layer1(out) out = self.layer2(out) @@ -233,8 +239,10 @@ def forward(self, x: Tensor) -> Tensor: out, "e (m c) h w -> (m e) c h w", m=self.num_estimators ) out = self.pool(out) - out = self.dropout(self.flatten(out)) - return self.linear(out) + return self.dropout(self.flatten(out)) + + def forward(self, x: Tensor) -> Tensor: + return self.linear(self.feats_forward(x)) def packed_wideresnet28x10( @@ -243,9 +251,9 @@ def packed_wideresnet28x10( num_estimators: int, alpha: int, gamma: int, - groups: int = 1, conv_bias: bool = True, dropout_rate: float = 0.3, + groups: int = 1, style: Literal["imagenet", "cifar"] = "imagenet", ) -> _PackedWideResNet: """Packed-Ensembles of Wide-ResNet-28x10. diff --git a/torch_uncertainty/models/wideresnet/std.py b/torch_uncertainty/models/wideresnet/std.py index bd3d6a76..963b4d60 100644 --- a/torch_uncertainty/models/wideresnet/std.py +++ b/torch_uncertainty/models/wideresnet/std.py @@ -1,7 +1,8 @@ +from collections.abc import Callable from typing import Literal -import torch.nn.functional as F from torch import Tensor, nn +from torch.nn.functional import relu __all__ = [ "wideresnet28x10", @@ -14,17 +15,20 @@ def __init__( in_planes: int, planes: int, dropout_rate: float, - stride: int = 1, - groups: int = 1, + stride: int, + groups: int, + conv_bias: bool, + activation_fn: Callable, ) -> None: super().__init__() + self.activation_fn = activation_fn self.conv1 = nn.Conv2d( in_planes, planes, kernel_size=3, padding=1, groups=groups, - bias=False, + bias=conv_bias, ) self.dropout = nn.Dropout2d(p=dropout_rate) self.bn1 = nn.BatchNorm2d(planes) @@ -35,8 +39,10 @@ def __init__( stride=stride, padding=1, groups=groups, - bias=False, + bias=conv_bias, ) + self.bn2 = nn.BatchNorm2d(planes) + self.shortcut = nn.Sequential() if stride != 1 or in_planes != planes: self.shortcut = nn.Sequential( @@ -46,28 +52,18 @@ def __init__( kernel_size=1, stride=stride, groups=groups, - bias=True, + bias=conv_bias, ), ) - self.bn2 = nn.BatchNorm2d(planes) def forward(self, x: Tensor) -> Tensor: - out = F.relu(self.bn1(self.dropout(self.conv1(x)))) + out = self.activation_fn(self.bn1(self.dropout(self.conv1(x)))) out = self.conv2(out) out += self.shortcut(x) - return F.relu(self.bn2(out)) + return self.activation_fn(self.bn2(out)) class _WideResNet(nn.Module): - """WideResNet from `Wide Residual Networks`. - - Note: - if `dropout_rate` and `num_estimators` are set, the model will sample - from the dropout distribution during inference. If `last_layer_dropout` - is set, only the last layer will be sampled from the dropout - distribution during inference. - """ - def __init__( self, depth: int, @@ -78,17 +74,22 @@ def __init__( dropout_rate: float, groups: int = 1, style: Literal["imagenet", "cifar"] = "imagenet", + activation_fn: Callable = relu, ) -> None: super().__init__() - self.in_planes = 16 self.dropout_rate = dropout_rate + self.activation_fn = activation_fn + self.in_planes = 16 if (depth - 4) % 6 != 0: raise ValueError(f"Wide-resnet depth should be 6n+4. Got {depth}.") num_blocks = int((depth - 4) / 6) - k = widen_factor - - num_stages = [16, 16 * k, 32 * k, 64 * k] + num_stages = [ + 16, + 16 * widen_factor, + 32 * widen_factor, + 64 * widen_factor, + ] if style == "imagenet": self.conv1 = nn.Conv2d( @@ -129,6 +130,8 @@ def __init__( dropout_rate=dropout_rate, stride=1, groups=groups, + activation_fn=activation_fn, + conv_bias=conv_bias, ) self.layer2 = self._wide_layer( WideBasicBlock, @@ -137,6 +140,8 @@ def __init__( dropout_rate=dropout_rate, stride=2, groups=groups, + activation_fn=activation_fn, + conv_bias=conv_bias, ) self.layer3 = self._wide_layer( WideBasicBlock, @@ -145,12 +150,12 @@ def __init__( dropout_rate=dropout_rate, stride=2, groups=groups, + activation_fn=activation_fn, + conv_bias=conv_bias, ) - self.dropout = nn.Dropout(p=dropout_rate) self.pool = nn.AdaptiveAvgPool2d(output_size=1) self.flatten = nn.Flatten(1) - self.linear = nn.Linear( num_stages[3], num_classes, @@ -164,6 +169,8 @@ def _wide_layer( dropout_rate: float, stride: int, groups: int, + conv_bias: bool, + activation_fn: Callable, ) -> nn.Module: strides = [stride] + [1] * (int(num_blocks) - 1) layers = [] @@ -171,25 +178,26 @@ def _wide_layer( for stride in strides: layers.append( block( - self.in_planes, - planes, - dropout_rate, - stride, - groups, + in_planes=self.in_planes, + planes=planes, + stride=stride, + dropout_rate=dropout_rate, + groups=groups, + conv_bias=conv_bias, + activation_fn=activation_fn, ) ) self.in_planes = planes - return nn.Sequential(*layers) def feats_forward(self, x: Tensor) -> Tensor: - out = F.relu(self.bn1(self.conv1(x))) + out = self.activation_fn(self.bn1(self.conv1(x))) out = self.optional_pool(out) out = self.layer1(out) out = self.layer2(out) out = self.layer3(out) out = self.pool(out) - return self.flatten(out) + return self.dropout(self.flatten(out)) def forward(self, x: Tensor) -> Tensor: return self.linear(self.feats_forward(x)) @@ -198,10 +206,11 @@ def forward(self, x: Tensor) -> Tensor: def wideresnet28x10( in_channels: int, num_classes: int, - groups: int = 1, conv_bias: bool = True, dropout_rate: float = 0.3, + groups: int = 1, style: Literal["imagenet", "cifar"] = "imagenet", + activation_fn: Callable = relu, ) -> _WideResNet: """Wide-ResNet-28x10 from `Wide Residual Networks `_. @@ -216,6 +225,8 @@ def wideresnet28x10( dropout_rate (float, optional): Dropout rate. Defaults to ``0.3``. style (bool, optional): Whether to use the ImageNet structure. Defaults to ``True``. + activation_fn (Callable, optional): Activation function. Defaults to + ``torch.nn.functional.relu``. Returns: _Wide: A Wide-ResNet-28x10. @@ -229,4 +240,5 @@ def wideresnet28x10( num_classes=num_classes, groups=groups, style=style, + activation_fn=activation_fn, ) diff --git a/torch_uncertainty/optim_recipes.py b/torch_uncertainty/optim_recipes.py index 8d648400..82b147c8 100644 --- a/torch_uncertainty/optim_recipes.py +++ b/torch_uncertainty/optim_recipes.py @@ -1,3 +1,4 @@ +import logging from collections.abc import Callable from functools import partial from typing import Literal @@ -6,7 +7,13 @@ from timm.optim import Lamb from torch import nn, optim from torch.optim import Optimizer -from torch.optim.lr_scheduler import LRScheduler +from torch.optim.lr_scheduler import ( + CosineAnnealingLR, + LinearLR, + LRScheduler, + MultiStepLR, + SequentialLR, +) def optim_abnn( @@ -24,7 +31,7 @@ def optim_abnn( weight_decay=weight_decay, nesterov=nesterov, ) - scheduler = optim.lr_scheduler.MultiStepLR( + scheduler = MultiStepLR( optimizer, milestones=[1, 4], gamma=0.1, @@ -43,7 +50,7 @@ def optim_cifar10_resnet18( weight_decay=5e-4, nesterov=True, ) - scheduler = optim.lr_scheduler.MultiStepLR( + scheduler = MultiStepLR( optimizer, milestones=[25, 50], gamma=0.1, @@ -64,7 +71,7 @@ def optim_cifar10_resnet50( weight_decay=5e-4, nesterov=True, ) - scheduler = optim.lr_scheduler.MultiStepLR( + scheduler = MultiStepLR( optimizer, milestones=[60, 120, 160], gamma=0.2, @@ -83,7 +90,7 @@ def optim_cifar10_wideresnet( weight_decay=5e-4, nesterov=True, ) - scheduler = optim.lr_scheduler.MultiStepLR( + scheduler = MultiStepLR( optimizer, milestones=[60, 120, 160], gamma=0.2, @@ -100,7 +107,7 @@ def optim_cifar10_vgg16( lr=0.005, weight_decay=1e-6, ) - scheduler = optim.lr_scheduler.MultiStepLR( + scheduler = MultiStepLR( optimizer, milestones=[25, 50], gamma=0.1, @@ -118,7 +125,7 @@ def optim_cifar100_resnet18( weight_decay=5e-4, nesterov=True, ) - scheduler = optim.lr_scheduler.MultiStepLR( + scheduler = MultiStepLR( optimizer, milestones=[25, 50], gamma=0.1, @@ -139,7 +146,7 @@ def optim_cifar100_resnet50( weight_decay=5e-4, nesterov=True, ) - scheduler = optim.lr_scheduler.MultiStepLR( + scheduler = MultiStepLR( optimizer, milestones=[60, 120, 160], gamma=0.2, @@ -158,7 +165,7 @@ def optim_cifar100_vgg16( weight_decay=1e-4, nesterov=True, ) - scheduler = optim.lr_scheduler.MultiStepLR( + scheduler = MultiStepLR( optimizer, milestones=[60, 120, 160], gamma=0.2, @@ -182,9 +189,7 @@ def optim_imagenet_resnet50( weight_decay=3.0517578125e-05, nesterov=False, ) - scheduler = optim.lr_scheduler.CosineAnnealingLR( - optimizer, num_epochs, eta_min=end_lr - ) + scheduler = CosineAnnealingLR(optimizer, num_epochs, eta_min=end_lr) return { "optimizer": optimizer, "lr_scheduler": scheduler, @@ -207,23 +212,25 @@ def optim_imagenet_resnet50_a3( dict: The optimizer and the scheduler for the training. """ if effective_batch_size is None: - print("Setting effective batch size to 2048 for steps computations !") + logging.warning( + "Setting effective batch size to 2048 for steps computations !" + ) effective_batch_size = 2048 optimizer = Lamb(model.parameters(), lr=0.008, weight_decay=0.02) - warmup = optim.lr_scheduler.LinearLR( + warmup = LinearLR( optimizer, start_factor=1e-4, end_factor=1, total_iters=5 * (1281167 // effective_batch_size + 1), ) - cosine_scheduler = optim.lr_scheduler.CosineAnnealingLR( + cosine_scheduler = CosineAnnealingLR( optimizer, eta_min=1e-6, T_max=105 * (1281167 // effective_batch_size + 1), ) - scheduler = optim.lr_scheduler.SequentialLR( + scheduler = SequentialLR( optimizer, schedulers=[warmup, cosine_scheduler], milestones=[5 * (1281167 // effective_batch_size + 1)], @@ -249,7 +256,7 @@ def optim_cifar10_resnet34( weight_decay=1e-4, nesterov=True, ) - scheduler = optim.lr_scheduler.MultiStepLR( + scheduler = MultiStepLR( optimizer, milestones=[100, 150], gamma=0.1, @@ -267,7 +274,7 @@ def optim_cifar100_resnet34( weight_decay=1e-4, nesterov=True, ) - scheduler = optim.lr_scheduler.MultiStepLR( + scheduler = MultiStepLR( optimizer, milestones=[100, 150], gamma=0.1, @@ -292,7 +299,7 @@ def optim_tinyimagenet_resnet34( weight_decay=1e-4, nesterov=True, ) - scheduler = optim.lr_scheduler.MultiStepLR( + scheduler = MultiStepLR( optimizer, milestones=[40, 60], gamma=0.1, @@ -317,7 +324,7 @@ def optim_tinyimagenet_resnet50( weight_decay=1e-4, nesterov=True, ) - scheduler = optim.lr_scheduler.MultiStepLR( + scheduler = MultiStepLR( optimizer, milestones=[40, 60], gamma=0.1, @@ -433,7 +440,75 @@ def get_procedure( return procedure -class FullSWALR(torch.optim.lr_scheduler.SequentialLR): +class WarmupScheduler(SequentialLR): + def __init__( + self, + optimizer: Optimizer, + base_scheduler: type[LRScheduler], + warmup_start_factor: float, + warmup_epochs: int, + scheduler_args: dict[str, float], + ) -> None: + """Scheduler with linear warmup. + + Args: + optimizer (Optimizer): The optimizer to be used.* + base_scheduler (type[LRScheduler]): The base scheduler class to use after + the warmup. + warmup_start_factor (float): The multiplicative factor to apply to + the learning rate at the start of the warmup. + warmup_epochs (int): The number of epochs to warmup the learning + rate. + scheduler_args (dict[str, float]): The arguments to pass to the base + scheduler. + """ + warmup_scheduler = LinearLR( + optimizer, + start_factor=warmup_start_factor, + end_factor=1, + total_iters=warmup_epochs, + ) + base_scheduler = base_scheduler(optimizer, **scheduler_args) + super().__init__( + optimizer=optimizer, + schedulers=[warmup_scheduler, base_scheduler], + milestones=[warmup_epochs], + ) + + +class CosineAnnealingWarmup(WarmupScheduler): + def __init__( + self, + optimizer: Optimizer, + warmup_start_factor: float, + warmup_epochs: int, + max_epochs: int, + eta_min: float = 0, + ) -> None: + """Cosine annealing scheduler with linear warmup. + + Args: + optimizer (Optimizer): The optimizer to be used. + warmup_start_factor (float): The multiplicative factor to apply to + the learning rate at the start of the warmup. + warmup_epochs (int): The number of epochs to warmup the learning + rate. + max_epochs (int): The total number of epochs including warmup. + eta_min (float): The minimum learning rate. + """ + super().__init__( + optimizer=optimizer, + base_scheduler=CosineAnnealingLR, + warmup_start_factor=warmup_start_factor, + warmup_epochs=warmup_epochs, + scheduler_args={ + "T_max": max_epochs - warmup_epochs, + "eta_min": eta_min, + }, + ) + + +class CosineSWALR(SequentialLR): def __init__( self, optimizer: Optimizer, @@ -457,7 +532,7 @@ def __init__( optim_eta_min (float): The minimum learning rate for the first optimizer. anneal_strategy (Literal["cos", "linear"]): The strategy to anneal the learning rate. """ - optim_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optim_scheduler = CosineAnnealingLR( optimizer=optimizer, T_max=milestone, eta_min=optim_eta_min ) swa_scheduler = torch.optim.swa_utils.SWALR( diff --git a/torch_uncertainty/post_processing/calibration/scaler.py b/torch_uncertainty/post_processing/calibration/scaler.py index d3400dfe..3dcf08e6 100644 --- a/torch_uncertainty/post_processing/calibration/scaler.py +++ b/torch_uncertainty/post_processing/calibration/scaler.py @@ -1,3 +1,4 @@ +import logging from typing import Literal import torch @@ -91,7 +92,7 @@ def calib_eval() -> float: @torch.no_grad() def forward(self, inputs: Tensor) -> Tensor: if not self.trained: - print( + logging.error( "TemperatureScaler has not been trained yet. Returning " "manually tempered inputs." ) diff --git a/torch_uncertainty/post_processing/laplace.py b/torch_uncertainty/post_processing/laplace.py index fc1d2894..dac1b0ba 100644 --- a/torch_uncertainty/post_processing/laplace.py +++ b/torch_uncertainty/post_processing/laplace.py @@ -4,6 +4,8 @@ from torch import Tensor, nn from torch.utils.data import DataLoader, Dataset +from .abstract import PostProcessing + if util.find_spec("laplace"): from laplace import Laplace @@ -12,7 +14,7 @@ laplace_installed = False -class LaplaceApprox(nn.Module): +class LaplaceApprox(PostProcessing): def __init__( self, task: Literal["classification", "regression"], @@ -61,9 +63,10 @@ def __init__( self.batch_size = batch_size if model is not None: - self._setup_model(model) + self.set_model(model) - def _setup_model(self, model) -> None: + def set_model(self, model: nn.Module) -> None: + super().set_model(model) self.la = Laplace( model=model, likelihood=self.task, @@ -71,9 +74,6 @@ def _setup_model(self, model) -> None: hessian_structure=self.hessian_struct, ) - def set_model(self, model: nn.Module) -> None: - self._setup_model(model) - def fit(self, dataset: Dataset) -> None: dl = DataLoader(dataset, batch_size=self.batch_size) self.la.fit(train_loader=dl) diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 2d45976a..b7e0262b 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -20,6 +20,7 @@ from torch_uncertainty.layers import Identity from torch_uncertainty.losses import DECLoss, ELBOLoss from torch_uncertainty.metrics import ( + AUGRC, AURC, FPR95, BrierScore, @@ -194,6 +195,7 @@ def _init_metrics(self) -> None: num_classes=self.num_classes, ), "sc/AURC": AURC(), + "sc/AUGRC": AUGRC(), "sc/CovAt5Risk": CovAt5Risk(), "sc/RiskAt80Cov": RiskAt80Cov(), }, @@ -202,7 +204,7 @@ def _init_metrics(self) -> None: ["cls/Brier"], ["cls/NLL"], ["cal/ECE", "cal/aECE"], - ["sc/AURC", "sc/CovAt5Risk", "sc/RiskAt80Cov"], + ["sc/AURC", "sc/AUGRC", "sc/CovAt5Risk", "sc/RiskAt80Cov"], ], ) @@ -552,6 +554,10 @@ def on_test_epoch_end(self) -> None: "Risk-Coverage curve", self.test_cls_metrics["sc/AURC"].plot()[0], ) + self.logger.experiment.add_figure( + "Generalized Risk-Coverage curve", + self.test_cls_metrics["sc/AUGRC"].plot()[0], + ) if self.post_processing is not None: self.logger.experiment.add_figure( diff --git a/torch_uncertainty/routines/regression.py b/torch_uncertainty/routines/regression.py index 2beeb435..b118590a 100644 --- a/torch_uncertainty/routines/regression.py +++ b/torch_uncertainty/routines/regression.py @@ -11,6 +11,7 @@ from torch.optim import Optimizer from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection +from torch_uncertainty.losses import ELBOLoss from torch_uncertainty.metrics import ( DistributionNLL, ) @@ -154,12 +155,16 @@ def training_step( self, batch: tuple[Tensor, Tensor], batch_idx: int ) -> STEP_OUTPUT: inputs, targets = self.format_batch_fn(batch) - dists = self.model(inputs) if self.one_dim_regression: targets = targets.unsqueeze(-1) - loss = self.loss(dists, targets) + if isinstance(self.loss, ELBOLoss): + loss = self.loss(inputs, targets) + else: + dists = self.model(inputs) + loss = self.loss(dists, targets) + if self.needs_step_update: self.model.update_wrapper(self.current_epoch) self.log("train_loss", loss) @@ -182,7 +187,6 @@ def validation_step( dist_size(preds)[0] // batch_size, device=self.device ) ) - print(ens_dist, type(ens_dist)) mixture = MixtureSameFamily(mix, ens_dist) preds = mixture.mean else: diff --git a/torch_uncertainty/routines/segmentation.py b/torch_uncertainty/routines/segmentation.py index f3ece492..966553d1 100644 --- a/torch_uncertainty/routines/segmentation.py +++ b/torch_uncertainty/routines/segmentation.py @@ -9,6 +9,7 @@ from torchvision.transforms.v2 import functional as F from torch_uncertainty.metrics import ( + AUGRC, AURC, BrierScore, CalibrationError, @@ -108,8 +109,16 @@ def __init__( num_classes=num_classes, ), "sc/AURC": AURC(), + "sc/AUGRC": AUGRC(), }, - compute_groups=False, + compute_groups=[ + ["seg/mAcc"], + ["seg/Brier"], + ["seg/NLL"], + ["seg/pixAcc"], + ["cal/ECE", "cal/aECE"], + ["sc/AURC", "sc/AUGRC"], + ], ) self.val_seg_metrics = seg_metrics.clone(prefix="val/") @@ -222,6 +231,10 @@ def on_test_epoch_end(self) -> None: "Risk-Coverage curve", self.test_sbsmpl_seg_metrics["sc/AURC"].plot()[0], ) + self.logger.experiment.add_figure( + "Generalized Risk-Coverage curve", + self.test_sbsmpl_seg_metrics["sc/AUGRC"].plot()[0], + ) def subsample(self, pred: Tensor, target: Tensor) -> tuple[Tensor, Tensor]: total_size = target.size(0) diff --git a/torch_uncertainty/transforms/corruptions.py b/torch_uncertainty/transforms/corruption.py similarity index 100% rename from torch_uncertainty/transforms/corruptions.py rename to torch_uncertainty/transforms/corruption.py