diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index 6c6ccc62..1c74dd5d 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -76,7 +76,7 @@ jobs: - name: Test with pytest and compute coverage if: steps.changed-files-specific.outputs.only_changed != 'true' run: | - python3 -m pytest --cov --cov-report xml --durations 10 + python3 -m pytest --cov --cov-report xml --durations 10 --junitxml=junit.xml - name: Upload coverage to Codecov if: steps.changed-files-specific.outputs.only_changed != 'true' && (github.event_name != 'pull_request' || github.base_ref == 'dev') @@ -89,6 +89,15 @@ jobs: name: CPU-coverage env_vars: PYTHON_VERSION + - name: Upload test results to Codecov + if: steps.changed-files-specific.outputs.only_changed != 'true' && (github.event_name != 'pull_request' || github.base_ref == 'dev') + uses: codecov/test-results-action@v1 + continue-on-error: true + with: + token: ${{ secrets.CODECOV_TOKEN }} + flags: cpu,pytest + env_vars: PYTHON_VERSION + - name: Test sphinx build without tutorials if: steps.changed-files-specific.outputs.only_changed != 'true' run: | diff --git a/auto_tutorials_source/tutorial_der_cubic.py b/auto_tutorials_source/tutorial_der_cubic.py index 24ee20a3..0d9de4f9 100644 --- a/auto_tutorials_source/tutorial_der_cubic.py +++ b/auto_tutorials_source/tutorial_der_cubic.py @@ -39,7 +39,8 @@ from torch_uncertainty.datasets.regression.toy import Cubic from torch_uncertainty.losses import DERLoss from torch_uncertainty.routines import RegressionRoutine -from torch_uncertainty.layers.distributions import NormalInverseGammaLayer +from torch_uncertainty.layers.distributions import NormalInverseGammaLinear +from torch_uncertainty.utils.distributions import get_dist_class # %% # 2. The Optimization Recipe @@ -50,12 +51,11 @@ def optim_regression( model: nn.Module, learning_rate: float = 5e-4, ): - optimizer = optim.Adam( + return optim.Adam( model.parameters(), lr=learning_rate, weight_decay=0, ) - return optimizer # %% @@ -64,7 +64,7 @@ def optim_regression( # # In the following, we create a trainer to train the model, the same synthetic regression # datasets as in the original DER paper and the model, a simple MLP with 2 hidden layers of 64 neurons each. -# Please note that this MLP finishes with a NormalInverseGammaLayer that interpret the outputs of the model +# Please note that this MLP finishes with a NormalInverseGammaLinear that interpret the outputs of the model # as the parameters of a Normal Inverse Gamma distribution. 
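The patch above replaces the old `NormalInverseGammaLayer` with a `NormalInverseGammaLinear` head selected via `dist_family="nig"`. As a minimal sketch (added for illustration, not taken from the patch itself), the snippet below shows what such a head returns and how `get_dist_class("nig")` rebuilds a distribution from it; the constructor arguments and output keys follow the tests added in `tests/layers/test_distributions.py` further down.

```python
# Minimal sketch: the "nig" head outputs a dict of parameters that
# get_dist_class("nig") consumes to build the Normal Inverse Gamma distribution.
import torch
from torch import nn

from torch_uncertainty.layers.distributions import NormalInverseGammaLinear
from torch_uncertainty.utils.distributions import get_dist_class

head = NormalInverseGammaLinear(base_layer=nn.Linear, event_dim=1, in_features=64)
params = head(torch.rand(8, 64))        # keys: "loc", "lmbda", "alpha", "beta"
dist = get_dist_class("nig")(**params)  # rebuild the distribution from its parameters
print(dist.loc.shape)                   # torch.Size([8, 1])
```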
trainer = TUTrainer(accelerator="cpu", max_epochs=50) #, enable_progress_bar=False) @@ -82,10 +82,9 @@ def optim_regression( # model model = mlp( in_features=1, - num_outputs=4, + num_outputs=1, hidden_dims=[64, 64], - final_layer=NormalInverseGammaLayer, - final_layer_args={"dim": 1}, + dist_family="nig", # Normal Inverse Gamma ) # %% @@ -100,11 +99,11 @@ def optim_regression( loss = DERLoss(reg_weight=1e-2) routine = RegressionRoutine( - probabilistic=True, output_dim=1, model=model, loss=loss, optim_recipe=optim_regression(model), + dist_family="nig", ) # %% @@ -127,7 +126,8 @@ def optim_regression( with torch.no_grad(): x = torch.linspace(-7, 7, 1000) - dists = model(x.unsqueeze(-1)) + dist_params = model(x.unsqueeze(-1)) + dists = get_dist_class("nig")(**dist_params) means = dists.loc.squeeze(1) variances = torch.sqrt(dists.variance_loc).squeeze(1) diff --git a/auto_tutorials_source/tutorial_from_de_to_pe.py b/auto_tutorials_source/tutorial_from_de_to_pe.py index fd60e3f1..ea2180a2 100644 --- a/auto_tutorials_source/tutorial_from_de_to_pe.py +++ b/auto_tutorials_source/tutorial_from_de_to_pe.py @@ -29,7 +29,6 @@ The dataset is automatically downloaded using torchvision. We then visualize a few images to see a bit what we are working with. """ -# Create the transforms for the images # %% import torch import torchvision.transforms as T @@ -37,6 +36,7 @@ # We set the number of epochs to some low value for the sake of time max_epochs = 2 +# Create the transforms for the images train_transform = T.Compose( [ T.ToTensor(), diff --git a/auto_tutorials_source/tutorial_probabilistic_regression.py b/auto_tutorials_source/tutorial_probabilistic_regression.py new file mode 100644 index 00000000..c45cd020 --- /dev/null +++ b/auto_tutorials_source/tutorial_probabilistic_regression.py @@ -0,0 +1,170 @@ +""" +Deep Probabilistic Regression +============================= + +This tutorial aims to provide an overview of some utilities in TorchUncertainty for probabilistic regression. + +Building a MLP for Probabilistic Regression using TorchUncertainty distribution layers +-------------------------------------------------------------------------------------- + +In this section we cover the building of a very simple MLP outputting Normal distribution parameters. + +1. Loading the utilities +~~~~~~~~~~~~~~~~~~~~~~~~ + +We disable some logging and warnings to keep the output clean. +""" +# %% +import torch +from torch import nn + +import logging +logging.getLogger("lightning.pytorch.utilities.rank_zero").setLevel(logging.WARNING) + +import warnings +warnings.filterwarnings("ignore") + +# %% +# 2. Building the MLP model +# ~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# To create a MLP model estimating a Normal distribution, we use the NormalLinear layer. +# This layer is a wrapper around the nn.Linear layer, which outputs the location and scale of a Normal distribution. +# Note that any other distribution layer from TU can be used in the same way. +from torch_uncertainty.layers.distributions import NormalLinear + + +class MLP(nn.Module): + def __init__(self, in_features: int, out_features: int): + super().__init__() + self.fc1 = nn.Linear(in_features, 50) + self.fc2 = NormalLinear( + base_layer=nn.Linear, + event_dim=out_features, + in_features=50, + ) + + def forward(self, x): + x = torch.relu(self.fc1(x)) + return self.fc2(x) + +# %% +# 3. Setting up the data +# ~~~~~~~~~~~~~~~~~~~~~~ +# +# We use the UCI Kin8nm dataset, which is a regression dataset with 8 features and 8192 samples. 
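Since Kin8nm provides 8 input features, a quick illustrative check (not part of the tutorial) confirms that the MLP defined above accepts such inputs and that its `NormalLinear` head returns a dictionary of Normal parameters rather than a plain tensor; the shapes assume the 8-feature input and 1-dimensional target used below.

```python
# Illustrative check: the NormalLinear head returns a parameter dict, not a tensor.
import torch

check_model = MLP(in_features=8, out_features=1)
with torch.no_grad():
    params = check_model(torch.rand(4, 8))  # dummy batch of 4 samples
print(sorted(params.keys()))                # ['loc', 'scale']
print(params["loc"].shape)                  # torch.Size([4, 1])
```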
+from torch_uncertainty.datamodules import UCIRegressionDataModule + +# datamodule +datamodule = UCIRegressionDataModule( + root="data", + batch_size=32, + dataset_name="kin8nm", +) + +# %% +# 4. Setting up the model and trainer +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +from torch_uncertainty import TUTrainer + +trainer = TUTrainer( + accelerator="cpu", + max_epochs=5, + enable_progress_bar=False, +) + +model = MLP(in_features=8, out_features=1) + + +# %% +# 5. The Loss, the Optimizer and the Training Routine +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# We use the DistributionNLLLoss to compute the negative log-likelihood of the Normal distribution. +# Note that this loss can be used with any Distribution from torch.distributions. +# For the optimizer, we use the Adam optimizer with a learning rate of 5e-3. +# Finally, we create a RegressionRoutine to train the model. We indicate that the output dimension is 1 and the distribution family is "normal". + +from torch_uncertainty.losses import DistributionNLLLoss +from torch_uncertainty.routines import RegressionRoutine + +loss = DistributionNLLLoss() + +def optim_regression( + model: nn.Module, + learning_rate: float = 5e-3, +): + return torch.optim.Adam( + model.parameters(), + lr=learning_rate, + weight_decay=0, + ) + +routine = RegressionRoutine( + output_dim=1, + model=model, + loss=loss, + optim_recipe=optim_regression(model), + dist_family="normal", +) + + +# %% +# 6. Training the model +# ~~~~~~~~~~~~~~~~~~~~~~ + +trainer.fit(model=routine, datamodule=datamodule) +results = trainer.test(model=routine, datamodule=datamodule) + +# %% +# 7. Benchmarking different distributions +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Our MLP model assumes a Normal distribution as the output. However, we could be interested in comparing the performance of different distributions. +# TorchUncertainty provides a simple way to do this using the get_dist_linear_layer() function. +# Let us rewrite the MLP model to use it. + +from torch_uncertainty.layers.distributions import get_dist_linear_layer + +class MLP(nn.Module): + def __init__(self, in_features: int, out_features: int, dist_family: str): + super().__init__() + self.fc1 = nn.Linear(in_features, 50) + dist_layer = get_dist_linear_layer(dist_family) + self.fc2 = dist_layer( + base_layer=nn.Linear, + event_dim=out_features, + in_features=50, + ) + + def forward(self, x): + x = torch.relu(self.fc1(x)) + return self.fc2(x) + +# %% +# We can now train the model with different distributions. +# Let us train the model with a Normal, Laplace, Student's t, and Cauchy distribution. +# Note that we use the mode as the point-wise estimate of the distribution as the mean +# is not defined for the Cauchy distribution. 
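As a small illustration of the remark above (not part of the tutorial): in `torch.distributions`, the mean of a Cauchy distribution is NaN because it is undefined, whereas its mode (exposed as a property in recent PyTorch versions) is simply the location parameter, which is why `dist_estimate="mode"` is used in the loop below.

```python
# Why the mode is used as the point estimate: the Cauchy mean is undefined.
import torch
from torch.distributions import Cauchy

cauchy = Cauchy(loc=torch.tensor(0.0), scale=torch.tensor(1.0))
print(cauchy.mean)  # tensor(nan) -- no finite mean exists
print(cauchy.mode)  # tensor(0.)  -- equals loc (recent PyTorch versions)
```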
+for dist_family in ["normal", "laplace", "student", "cauchy"]: + print("#" * 50) + print(f">>> Training with {dist_family} distribution") + print("#" * 50) + trainer = TUTrainer( + accelerator="cpu", + max_epochs=10, + enable_model_summary=False, + enable_progress_bar=False, + ) + model = MLP(in_features=8, out_features=1, dist_family=dist_family) + routine = RegressionRoutine( + output_dim=1, + model=model, + loss=loss, + optim_recipe=optim_regression(model), + dist_family=dist_family, + dist_estimate="mode", + ) + trainer.fit(model=routine, datamodule=datamodule) + trainer.test(model=routine, datamodule=datamodule) diff --git a/docs/source/api.rst b/docs/source/api.rst index 4f505698..16f1f1c2 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -125,6 +125,10 @@ Ensemble layers PackedLinear PackedConv2d + PackedMultiheadAttention + PackedLayerNorm + PackedTransformerEncoderLayer + PackedTransformerDecoderLayer BatchLinear BatchConv2d MaskedLinear @@ -148,6 +152,40 @@ Bayesian layers LPBNNLinear LPBNNConv2d + +Density layers +^^^^^^^^^^^^^^ + +.. currentmodule:: torch_uncertainty.layers.distributions + +Linear Layers +""""""""""""" + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: class.rst + + NormalLinear + LaplaceLinear + CauchyLinear + StudentTLinear + NormalInverseGammaLinear + +Convolution Layers +"""""""""""""""""" + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: class.rst + + NormalConvNd + LaplaceConvNd + CauchyConvNd + StudentTConvNd + NormalInverseGammaConvNd + Models ------ diff --git a/docs/source/conf.py b/docs/source/conf.py index aa6f332a..062ed7b5 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -15,7 +15,7 @@ f"{datetime.now().year!s}, Adrien Lafage and Olivier Laurent" ) author = "Adrien Lafage and Olivier Laurent" -release = "0.3.1" +release = "0.4.0" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/experiments/regression/uci_datasets/configs/boston/mlp/laplace.yaml b/experiments/regression/uci_datasets/configs/boston/mlp/laplace.yaml new file mode 100644 index 00000000..f8adbf90 --- /dev/null +++ b/experiments/regression/uci_datasets/configs/boston/mlp/laplace.yaml @@ -0,0 +1,43 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 40 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/boston/mlp/laplace + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/reg/NLL + mode: min + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/NLL + patience: 1000 + check_finite: true +model: + output_dim: 1 + in_features: 13 + hidden_dims: + - 50 + loss: torch_uncertainty.losses.DistributionNLLLoss + version: std + distribution: laplace +data: + root: ./data + batch_size: 128 + dataset_name: boston +optimizer: + lr: 5e-3 + weight_decay: 0 diff --git a/experiments/regression/uci_datasets/configs/boston/mlp/normal.yaml b/experiments/regression/uci_datasets/configs/boston/mlp/normal.yaml new file mode 100644 index 00000000..95eaface --- /dev/null +++ 
b/experiments/regression/uci_datasets/configs/boston/mlp/normal.yaml @@ -0,0 +1,43 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 40 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/boston/mlp/normal + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/reg/NLL + mode: min + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/NLL + patience: 1000 + check_finite: true +model: + output_dim: 1 + in_features: 13 + hidden_dims: + - 50 + loss: torch_uncertainty.losses.DistributionNLLLoss + version: std + distribution: normal +data: + root: ./data + batch_size: 128 + dataset_name: boston +optimizer: + lr: 5e-3 + weight_decay: 0 diff --git a/experiments/regression/uci_datasets/configs/boston/mlp/point_wise.yaml b/experiments/regression/uci_datasets/configs/boston/mlp/point_wise.yaml new file mode 100644 index 00000000..90daf59b --- /dev/null +++ b/experiments/regression/uci_datasets/configs/boston/mlp/point_wise.yaml @@ -0,0 +1,42 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 40 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/boston/mlp/point_wise + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/reg/MSE + mode: min + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/MSE + patience: 1000 + check_finite: true +model: + output_dim: 1 + in_features: 13 + hidden_dims: + - 50 + loss: MSELoss + version: std +data: + root: ./data + batch_size: 128 + dataset_name: boston +optimizer: + lr: 5e-3 + weight_decay: 0 diff --git a/experiments/regression/uci_datasets/configs/concrete/mlp/laplace.yaml b/experiments/regression/uci_datasets/configs/concrete/mlp/laplace.yaml new file mode 100644 index 00000000..b6ff80c6 --- /dev/null +++ b/experiments/regression/uci_datasets/configs/concrete/mlp/laplace.yaml @@ -0,0 +1,43 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 40 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/concrete/mlp/laplace + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/reg/NLL + mode: min + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/NLL + patience: 1000 + check_finite: true +model: + output_dim: 1 + in_features: 8 + hidden_dims: + - 50 + loss: torch_uncertainty.losses.DistributionNLLLoss + version: std + distribution: laplace +data: + root: ./data + batch_size: 128 + dataset_name: concrete +optimizer: + lr: 5e-3 + weight_decay: 0 diff --git 
a/experiments/regression/uci_datasets/configs/concrete/mlp/normal.yaml b/experiments/regression/uci_datasets/configs/concrete/mlp/normal.yaml new file mode 100644 index 00000000..683333d6 --- /dev/null +++ b/experiments/regression/uci_datasets/configs/concrete/mlp/normal.yaml @@ -0,0 +1,43 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 40 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/concrete/mlp/normal + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/reg/NLL + mode: min + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/NLL + patience: 1000 + check_finite: true +model: + output_dim: 1 + in_features: 8 + hidden_dims: + - 50 + loss: torch_uncertainty.losses.DistributionNLLLoss + version: std + distribution: normal +data: + root: ./data + batch_size: 128 + dataset_name: concrete +optimizer: + lr: 5e-3 + weight_decay: 0 diff --git a/experiments/regression/uci_datasets/configs/concrete/mlp/point_wise.yaml b/experiments/regression/uci_datasets/configs/concrete/mlp/point_wise.yaml new file mode 100644 index 00000000..cff6fd10 --- /dev/null +++ b/experiments/regression/uci_datasets/configs/concrete/mlp/point_wise.yaml @@ -0,0 +1,42 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 40 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/concrete/mlp/point_wise + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/reg/MSE + mode: min + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/MSE + patience: 1000 + check_finite: true +model: + output_dim: 1 + in_features: 8 + hidden_dims: + - 50 + loss: MSELoss + version: std +data: + root: ./data + batch_size: 128 + dataset_name: concrete +optimizer: + lr: 5e-3 + weight_decay: 0 diff --git a/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/laplace.yaml b/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/laplace.yaml new file mode 100644 index 00000000..1837894f --- /dev/null +++ b/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/laplace.yaml @@ -0,0 +1,43 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 40 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/energy-efficiency/mlp/laplace + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/reg/NLL + mode: min + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/NLL + patience: 1000 + check_finite: true +model: + output_dim: 1 + in_features: 8 + hidden_dims: + - 50 + 
loss: torch_uncertainty.losses.DistributionNLLLoss + version: std + distribution: laplace +data: + root: ./data + batch_size: 128 + dataset_name: energy-efficiency +optimizer: + lr: 5e-3 + weight_decay: 0 diff --git a/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/normal.yaml b/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/normal.yaml new file mode 100644 index 00000000..da02570b --- /dev/null +++ b/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/normal.yaml @@ -0,0 +1,43 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 40 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/energy-efficiency/mlp/normal + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/reg/NLL + mode: min + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/NLL + patience: 1000 + check_finite: true +model: + output_dim: 1 + in_features: 8 + hidden_dims: + - 50 + loss: torch_uncertainty.losses.DistributionNLLLoss + version: std + distribution: normal +data: + root: ./data + batch_size: 128 + dataset_name: energy-efficiency +optimizer: + lr: 5e-3 + weight_decay: 0 diff --git a/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/point_wise.yaml b/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/point_wise.yaml new file mode 100644 index 00000000..cff6fd10 --- /dev/null +++ b/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/point_wise.yaml @@ -0,0 +1,42 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 40 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/concrete/mlp/point_wise + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/reg/MSE + mode: min + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/MSE + patience: 1000 + check_finite: true +model: + output_dim: 1 + in_features: 8 + hidden_dims: + - 50 + loss: MSELoss + version: std +data: + root: ./data + batch_size: 128 + dataset_name: concrete +optimizer: + lr: 5e-3 + weight_decay: 0 diff --git a/experiments/regression/uci_datasets/configs/laplace_mlp_kin8nm.yaml b/experiments/regression/uci_datasets/configs/kin8nm/mlp/laplace.yaml similarity index 85% rename from experiments/regression/uci_datasets/configs/laplace_mlp_kin8nm.yaml rename to experiments/regression/uci_datasets/configs/kin8nm/mlp/laplace.yaml index c906150c..be42d710 100644 --- a/experiments/regression/uci_datasets/configs/laplace_mlp_kin8nm.yaml +++ b/experiments/regression/uci_datasets/configs/kin8nm/mlp/laplace.yaml @@ -5,17 +5,17 @@ trainer: accelerator: gpu devices: 1 precision: 16-mixed - max_epochs: 10 + max_epochs: 40 logger: class_path: lightning.pytorch.loggers.TensorBoardLogger init_args: - save_dir: logs/gaussian_mlp_kin8nm + save_dir: logs/kin8nm/mlp/laplace name: standard default_hp_metric: false 
callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/NLL + monitor: val/reg/NLL mode: min save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,17 +23,17 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/NLL + monitor: val/reg/NLL patience: 1000 check_finite: true model: output_dim: 1 in_features: 8 hidden_dims: - - 100 + - 50 loss: torch_uncertainty.losses.DistributionNLLLoss version: std - distribution: laplace + dist_family: laplace data: root: ./data batch_size: 128 diff --git a/experiments/regression/uci_datasets/configs/gaussian_mlp_kin8nm.yaml b/experiments/regression/uci_datasets/configs/kin8nm/mlp/normal.yaml similarity index 85% rename from experiments/regression/uci_datasets/configs/gaussian_mlp_kin8nm.yaml rename to experiments/regression/uci_datasets/configs/kin8nm/mlp/normal.yaml index 9d6e17ae..b5553356 100644 --- a/experiments/regression/uci_datasets/configs/gaussian_mlp_kin8nm.yaml +++ b/experiments/regression/uci_datasets/configs/kin8nm/mlp/normal.yaml @@ -5,17 +5,17 @@ trainer: accelerator: gpu devices: 1 precision: 16-mixed - max_epochs: 10 + max_epochs: 40 logger: class_path: lightning.pytorch.loggers.TensorBoardLogger init_args: - save_dir: logs/gaussian_mlp_kin8nm + save_dir: logs/kin8nm/mlp/normal name: standard default_hp_metric: false callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/NLL + monitor: val/reg/NLL mode: min save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,17 +23,17 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/NLL + monitor: val/reg/NLL patience: 1000 check_finite: true model: output_dim: 1 in_features: 8 hidden_dims: - - 100 + - 50 loss: torch_uncertainty.losses.DistributionNLLLoss version: std - distribution: normal + dist_family: normal data: root: ./data batch_size: 128 diff --git a/experiments/regression/uci_datasets/configs/pw_mlp_kin8nm.yaml b/experiments/regression/uci_datasets/configs/kin8nm/mlp/point_wise.yaml similarity index 87% rename from experiments/regression/uci_datasets/configs/pw_mlp_kin8nm.yaml rename to experiments/regression/uci_datasets/configs/kin8nm/mlp/point_wise.yaml index ca09ac4a..fdc8fc44 100644 --- a/experiments/regression/uci_datasets/configs/pw_mlp_kin8nm.yaml +++ b/experiments/regression/uci_datasets/configs/kin8nm/mlp/point_wise.yaml @@ -5,17 +5,17 @@ trainer: accelerator: gpu devices: 1 precision: 16-mixed - max_epochs: 10 + max_epochs: 40 logger: class_path: lightning.pytorch.loggers.TensorBoardLogger init_args: - save_dir: logs/pw_mlp_kin8nm + save_dir: logs/kin8nm/mlp/point_wise name: standard default_hp_metric: false callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: val/MSE + monitor: val/reg/MSE mode: min save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,14 +23,14 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: val/MSE + monitor: val/reg/MSE patience: 1000 check_finite: true model: output_dim: 1 in_features: 8 hidden_dims: - - 100 + - 50 loss: MSELoss version: std data: diff --git a/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/laplace.yaml b/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/laplace.yaml new file mode 100644 index 
00000000..b5b66dfc --- /dev/null +++ b/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/laplace.yaml @@ -0,0 +1,43 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 40 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/naval-propulsion-plant/mlp/laplace + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/reg/NLL + mode: min + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/NLL + patience: 1000 + check_finite: true +model: + output_dim: 1 + in_features: 16 + hidden_dims: + - 50 + loss: torch_uncertainty.losses.DistributionNLLLoss + version: std + distribution: laplace +data: + root: ./data + batch_size: 128 + dataset_name: naval-propulsion-plant +optimizer: + lr: 5e-3 + weight_decay: 0 diff --git a/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/normal.yaml b/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/normal.yaml new file mode 100644 index 00000000..92169a3c --- /dev/null +++ b/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/normal.yaml @@ -0,0 +1,43 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 40 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/naval-propulsion-plant/mlp/normal + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/reg/NLL + mode: min + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/NLL + patience: 1000 + check_finite: true +model: + output_dim: 1 + in_features: 16 + hidden_dims: + - 50 + loss: torch_uncertainty.losses.DistributionNLLLoss + version: std + distribution: normal +data: + root: ./data + batch_size: 128 + dataset_name: naval-propulsion-plant +optimizer: + lr: 5e-3 + weight_decay: 0 diff --git a/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/point_wise.yaml b/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/point_wise.yaml new file mode 100644 index 00000000..d8e6eb8f --- /dev/null +++ b/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/point_wise.yaml @@ -0,0 +1,42 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 40 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/naval-propulsion-plant/mlp/point_wise + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/reg/MSE + mode: min + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/MSE + patience: 1000 + check_finite: true +model: + output_dim: 1 + in_features: 
16 + hidden_dims: + - 50 + loss: MSELoss + version: std +data: + root: ./data + batch_size: 128 + dataset_name: naval-propulsion-plant +optimizer: + lr: 5e-3 + weight_decay: 0 diff --git a/experiments/regression/uci_datasets/configs/power-plant/mlp/laplace.yaml b/experiments/regression/uci_datasets/configs/power-plant/mlp/laplace.yaml new file mode 100644 index 00000000..4c2ffd85 --- /dev/null +++ b/experiments/regression/uci_datasets/configs/power-plant/mlp/laplace.yaml @@ -0,0 +1,43 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 40 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/power-plant/mlp/laplace + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/reg/NLL + mode: min + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/NLL + patience: 1000 + check_finite: true +model: + output_dim: 1 + in_features: 4 + hidden_dims: + - 50 + loss: torch_uncertainty.losses.DistributionNLLLoss + version: std + distribution: laplace +data: + root: ./data + batch_size: 128 + dataset_name: power-plant +optimizer: + lr: 5e-3 + weight_decay: 0 diff --git a/experiments/regression/uci_datasets/configs/power-plant/mlp/normal.yaml b/experiments/regression/uci_datasets/configs/power-plant/mlp/normal.yaml new file mode 100644 index 00000000..6173c120 --- /dev/null +++ b/experiments/regression/uci_datasets/configs/power-plant/mlp/normal.yaml @@ -0,0 +1,43 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 40 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/power-plant/mlp/normal + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/reg/NLL + mode: min + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/NLL + patience: 1000 + check_finite: true +model: + output_dim: 1 + in_features: 4 + hidden_dims: + - 50 + loss: torch_uncertainty.losses.DistributionNLLLoss + version: std + distribution: normal +data: + root: ./data + batch_size: 128 + dataset_name: power-plant +optimizer: + lr: 5e-3 + weight_decay: 0 diff --git a/experiments/regression/uci_datasets/configs/power-plant/mlp/point_wise.yaml b/experiments/regression/uci_datasets/configs/power-plant/mlp/point_wise.yaml new file mode 100644 index 00000000..d0ec4670 --- /dev/null +++ b/experiments/regression/uci_datasets/configs/power-plant/mlp/point_wise.yaml @@ -0,0 +1,42 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 40 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/power-plant/mlp/point_wise + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/reg/MSE + mode: min + save_last: true + - class_path: 
lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/MSE + patience: 1000 + check_finite: true +model: + output_dim: 1 + in_features: 4 + hidden_dims: + - 50 + loss: MSELoss + version: std +data: + root: ./data + batch_size: 128 + dataset_name: power-plant +optimizer: + lr: 5e-3 + weight_decay: 0 diff --git a/experiments/regression/uci_datasets/configs/protein/mlp/laplace.yaml b/experiments/regression/uci_datasets/configs/protein/mlp/laplace.yaml new file mode 100644 index 00000000..8b794d78 --- /dev/null +++ b/experiments/regression/uci_datasets/configs/protein/mlp/laplace.yaml @@ -0,0 +1,43 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 40 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/concrete/mlp/laplace + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/reg/NLL + mode: min + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/NLL + patience: 1000 + check_finite: true +model: + output_dim: 1 + in_features: 9 + hidden_dims: + - 50 + loss: torch_uncertainty.losses.DistributionNLLLoss + version: std + distribution: laplace +data: + root: ./data + batch_size: 128 + dataset_name: concrete +optimizer: + lr: 5e-3 + weight_decay: 0 diff --git a/experiments/regression/uci_datasets/configs/protein/mlp/normal.yaml b/experiments/regression/uci_datasets/configs/protein/mlp/normal.yaml new file mode 100644 index 00000000..82dc62ea --- /dev/null +++ b/experiments/regression/uci_datasets/configs/protein/mlp/normal.yaml @@ -0,0 +1,43 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 40 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/concrete/mlp/normal + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/reg/NLL + mode: min + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/NLL + patience: 1000 + check_finite: true +model: + output_dim: 1 + in_features: 9 + hidden_dims: + - 50 + loss: torch_uncertainty.losses.DistributionNLLLoss + version: std + distribution: normal +data: + root: ./data + batch_size: 128 + dataset_name: concrete +optimizer: + lr: 5e-3 + weight_decay: 0 diff --git a/experiments/regression/uci_datasets/configs/protein/mlp/point_wise.yaml b/experiments/regression/uci_datasets/configs/protein/mlp/point_wise.yaml new file mode 100644 index 00000000..b984e681 --- /dev/null +++ b/experiments/regression/uci_datasets/configs/protein/mlp/point_wise.yaml @@ -0,0 +1,42 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 40 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/concrete/mlp/point_wise + name: standard + 
default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/reg/MSE + mode: min + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/MSE + patience: 1000 + check_finite: true +model: + output_dim: 1 + in_features: 9 + hidden_dims: + - 50 + loss: MSELoss + version: std +data: + root: ./data + batch_size: 128 + dataset_name: concrete +optimizer: + lr: 5e-3 + weight_decay: 0 diff --git a/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/laplace.yaml b/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/laplace.yaml new file mode 100644 index 00000000..32275d18 --- /dev/null +++ b/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/laplace.yaml @@ -0,0 +1,43 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 40 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/concrete/mlp/laplace + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/reg/NLL + mode: min + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/NLL + patience: 1000 + check_finite: true +model: + output_dim: 1 + in_features: 11 + hidden_dims: + - 50 + loss: torch_uncertainty.losses.DistributionNLLLoss + version: std + distribution: laplace +data: + root: ./data + batch_size: 128 + dataset_name: concrete +optimizer: + lr: 5e-3 + weight_decay: 0 diff --git a/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/normal.yaml b/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/normal.yaml new file mode 100644 index 00000000..188b8fb2 --- /dev/null +++ b/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/normal.yaml @@ -0,0 +1,43 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 40 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/concrete/mlp/normal + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/reg/NLL + mode: min + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/NLL + patience: 1000 + check_finite: true +model: + output_dim: 1 + in_features: 11 + hidden_dims: + - 50 + loss: torch_uncertainty.losses.DistributionNLLLoss + version: std + distribution: normal +data: + root: ./data + batch_size: 128 + dataset_name: concrete +optimizer: + lr: 5e-3 + weight_decay: 0 diff --git a/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/point_wise.yaml b/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/point_wise.yaml new file mode 100644 index 00000000..522e280e --- /dev/null +++ b/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/point_wise.yaml @@ -0,0 +1,42 @@ +# lightning.pytorch==2.1.3 +seed_everything: 
false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 40 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/concrete/mlp/point_wise + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/reg/MSE + mode: min + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/MSE + patience: 1000 + check_finite: true +model: + output_dim: 1 + in_features: 11 + hidden_dims: + - 50 + loss: MSELoss + version: std +data: + root: ./data + batch_size: 128 + dataset_name: concrete +optimizer: + lr: 5e-3 + weight_decay: 0 diff --git a/experiments/regression/uci_datasets/configs/yacht/mlp/laplace.yaml b/experiments/regression/uci_datasets/configs/yacht/mlp/laplace.yaml new file mode 100644 index 00000000..2d77e0cd --- /dev/null +++ b/experiments/regression/uci_datasets/configs/yacht/mlp/laplace.yaml @@ -0,0 +1,43 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 40 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/concrete/mlp/laplace + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/reg/NLL + mode: min + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/NLL + patience: 1000 + check_finite: true +model: + output_dim: 1 + in_features: 6 + hidden_dims: + - 50 + loss: torch_uncertainty.losses.DistributionNLLLoss + version: std + distribution: laplace +data: + root: ./data + batch_size: 128 + dataset_name: concrete +optimizer: + lr: 5e-3 + weight_decay: 0 diff --git a/experiments/regression/uci_datasets/configs/yacht/mlp/normal.yaml b/experiments/regression/uci_datasets/configs/yacht/mlp/normal.yaml new file mode 100644 index 00000000..c3641593 --- /dev/null +++ b/experiments/regression/uci_datasets/configs/yacht/mlp/normal.yaml @@ -0,0 +1,43 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 40 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/concrete/mlp/normal + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/reg/NLL + mode: min + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/NLL + patience: 1000 + check_finite: true +model: + output_dim: 1 + in_features: 6 + hidden_dims: + - 50 + loss: torch_uncertainty.losses.DistributionNLLLoss + version: std + distribution: normal +data: + root: ./data + batch_size: 128 + dataset_name: concrete +optimizer: + lr: 5e-3 + weight_decay: 0 diff --git a/experiments/regression/uci_datasets/configs/yacht/mlp/point_wise.yaml b/experiments/regression/uci_datasets/configs/yacht/mlp/point_wise.yaml new file mode 100644 index 00000000..e4002049 --- 
/dev/null +++ b/experiments/regression/uci_datasets/configs/yacht/mlp/point_wise.yaml @@ -0,0 +1,42 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 40 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/concrete/mlp/point_wise + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/reg/MSE + mode: min + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/MSE + patience: 1000 + check_finite: true +model: + output_dim: 1 + in_features: 6 + hidden_dims: + - 50 + loss: MSELoss + version: std +data: + root: ./data + batch_size: 128 + dataset_name: concrete +optimizer: + lr: 5e-3 + weight_decay: 0 diff --git a/experiments/regression/uci_datasets/readme.md b/experiments/regression/uci_datasets/readme.md index 3e0ec7b0..b016a9fe 100644 --- a/experiments/regression/uci_datasets/readme.md +++ b/experiments/regression/uci_datasets/readme.md @@ -1,17 +1,33 @@ # UCI Regression - Benchmark + +| Dataset | Number of Instances | Number of Features | +| --- | --- | --- | +| Boston Housing | 506 | 13 | +| Concrete Compression Strength | 1030 | 8 | +| Energy Efficiency | 768 | 8 | +| Kin8nm | 8192 | 8 | +| Naval Propulsion | 11,934 | 16 | +| Combined Cycle Power Plant | 9568 | 4 | +| Protein Structure | 45730 | 9 | +| Wine Quality (Red) | 1599 | 11 | +| Yacht Hydrodynamics | 308 | 6 | + + +> [!WARNING] +> Some datasets require installing additional packages. + + This folder contains the code to train models on the UCI regression datasets. The task is to predict (a) continuous target variable(s). 
-Three experiments are provided: +**General command to train a model:** ```bash -python mlp.py fit --config configs/pw_mlp_kin8nm.yaml +python mlp.py fit --config configs/{dataset}/{network}/{dist_family}.yaml ``` -```bash -python mlp.py fit --config configs/gaussian_mlp_kin8nm.yaml -``` +*Example:* ```bash -python mlp.py fit --config configs/laplace_mlp_kin8nm.yaml +python mlp.py fit --config configs/kin8nm/mlp/laplace.yaml ``` diff --git a/pyproject.toml b/pyproject.toml index bb52c317..27a4188e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi" [project] name = "torch_uncertainty" -version = "0.3.1" +version = "0.4.0" authors = [ { name = "ENSTA U2IS", email = "olivier.laurent@ensta-paris.fr" }, { name = "Adrien Lafage", email = "adrienlafage@outlook.com" }, diff --git a/tests/_dummies/baseline.py b/tests/_dummies/baseline.py index c76efc82..59291f14 100644 --- a/tests/_dummies/baseline.py +++ b/tests/_dummies/baseline.py @@ -2,11 +2,6 @@ from torch import nn -from torch_uncertainty.layers.distributions import ( - LaplaceLayer, - NormalInverseGammaLayer, - NormalLayer, -) from torch_uncertainty.models import EMA, SWA, deep_ensembles from torch_uncertainty.optim_recipes import optim_cifar10_resnet18 from torch_uncertainty.post_processing import TemperatureScaler @@ -110,34 +105,19 @@ def __new__( class DummyRegressionBaseline: def __new__( cls, - probabilistic: bool, in_features: int, output_dim: int, loss: nn.Module, baseline_type: str = "single", optim_recipe=None, - dist_type: str = "normal", + dist_family: str | None = "normal", ema: bool = False, swa: bool = False, ) -> RegressionRoutine: - if probabilistic: - if dist_type == "normal": - last_layer = NormalLayer(output_dim) - num_classes = output_dim * 2 - elif dist_type == "laplace": - last_layer = LaplaceLayer(output_dim) - num_classes = output_dim * 2 - else: # dist_type == "nig" - last_layer = NormalInverseGammaLayer(output_dim) 
- num_classes = output_dim * 4 - else: - last_layer = nn.Identity() - num_classes = output_dim - model = dummy_segmentation_model( - num_classes=num_classes, + num_classes=output_dim, in_channels=in_channels, image_size=image_size, - last_layer=last_layer, + dist_family=dist_family, ) if ema: model = EMA(model, momentum=0.99) @@ -261,25 +226,25 @@ def __new__( if baseline_type == "single": return PixelRegressionRoutine( - probabilistic=probabilistic, output_dim=output_dim, model=model, loss=loss, optim_recipe=optim_recipe(model), + dist_family=dist_family, ) # baseline_type == "ensemble": model = deep_ensembles( [model, copy.deepcopy(model)], task="pixel_regression", - probabilistic=probabilistic, + probabilistic=dist_family is not None, ) return PixelRegressionRoutine( - probabilistic=probabilistic, output_dim=output_dim, model=model, loss=loss, format_batch_fn=RepeatTarget(2), is_ensemble=True, optim_recipe=optim_recipe(model), + dist_family=dist_family, ) diff --git a/tests/_dummies/model.py b/tests/_dummies/model.py index 7ec2581b..02517617 100644 --- a/tests/_dummies/model.py +++ b/tests/_dummies/model.py @@ -1,6 +1,8 @@ import torch from torch import Tensor, nn +from torch_uncertainty.layers.distributions import get_dist_conv_layer, get_dist_linear_layer + __all__ = [ "dummy_model", ] @@ -12,17 +14,22 @@ def __init__( in_channels: int, num_classes: int, dropout_rate: float, - last_layer: nn.Module, + dist_family: str | None = None, ) -> None: super().__init__() self.in_channels = in_channels self.dropout_rate = dropout_rate - self.linear = nn.Linear( - 1, - num_classes, - ) - self.last_layer = last_layer + self.linear = nn.Linear(1, num_classes) + + if dist_family is None: + self.last_layer = nn.Linear(num_classes, num_classes) + else: + self.last_layer = get_dist_linear_layer(dist_family)( + base_layer=nn.Linear, + event_dim=num_classes, + in_features=num_classes, + ) self.dropout = nn.Dropout(p=dropout_rate) def forward(self, x: Tensor) -> Tensor: @@ -53,7 +60,7 @@ def __init__( num_classes: int, dropout_rate: float, image_size: int, - last_layer: nn.Module, + dist_family: str | None = None, ) -> None: super().__init__() self.dropout_rate = dropout_rate @@ -62,7 +69,16 @@ def __init__( self.image_size = image_size self.conv = nn.Conv2d(in_channels, num_classes, kernel_size=3, padding=1) self.dropout = nn.Dropout(p=dropout_rate) - self.last_layer = last_layer + if dist_family is None: + self.last_layer = nn.Identity() + else: + self.last_layer = get_dist_conv_layer(dist_family)( + base_layer=nn.Conv2d, + event_dim=num_classes, + in_channels=num_classes, + kernel_size=3, + padding=1, + ) def forward(self, x: Tensor) -> Tensor: return self.last_layer( @@ -87,7 +103,7 @@ def dummy_model( num_classes: int, dropout_rate: float = 0.0, with_feats: bool = True, - last_layer: nn.Module | None = None, + dist_family: str | None = None, ) -> _Dummy: """Dummy model for testing purposes. @@ -97,25 +113,23 @@ def dummy_model( num_estimators (int): Number of estimators in the ensemble. dropout_rate (float, optional): Dropout rate. Defaults to 0.0. with_feats (bool, optional): Whether to include features. Defaults to True. - last_layer (nn.Module, optional): Last layer of the model. Defaults to None. + dist_family (str, optional): Distribution family. Defaults to None. Returns: _Dummy: Dummy model. 
""" - if last_layer is None: - last_layer = nn.Identity() if with_feats: return _DummyWithFeats( in_channels=in_channels, num_classes=num_classes, dropout_rate=dropout_rate, - last_layer=last_layer, + dist_family=dist_family, ) return _Dummy( in_channels=in_channels, num_classes=num_classes, dropout_rate=dropout_rate, - last_layer=last_layer, + dist_family=dist_family, ) @@ -124,7 +138,7 @@ def dummy_segmentation_model( num_classes: int, image_size: int, dropout_rate: float = 0.0, - last_layer: nn.Module | None = None, + dist_family: str | None = None, ) -> nn.Module: """Dummy segmentation model for testing purposes. @@ -133,17 +147,15 @@ def dummy_segmentation_model( num_classes (int): Number of output classes. image_size (int): Size of the input image. dropout_rate (float, optional): Dropout rate. Defaults to 0.0. - last_layer (nn.Module, optional): Last layer of the model. Defaults to None. + dist_family (str, optional): Distribution family. Defaults to None. Returns: nn.Module: Dummy segmentation model. """ - if last_layer is None: - last_layer = nn.Identity() return _DummySegmentation( in_channels=in_channels, num_classes=num_classes, dropout_rate=dropout_rate, image_size=image_size, - last_layer=last_layer, + dist_family=dist_family, ) diff --git a/tests/baselines/test_standard.py b/tests/baselines/test_standard.py index f97a527c..e59e0a5c 100644 --- a/tests/baselines/test_standard.py +++ b/tests/baselines/test_standard.py @@ -106,14 +106,14 @@ def test_standard(self): hidden_dims=[1], ) _ = net(torch.rand(1, 3)) - for distribution in ["normal", "laplace", "nig"]: + for dist_family in ["normal", "laplace", "nig"]: MLPBaseline( in_features=3, output_dim=10, loss=nn.MSELoss(), version="std", hidden_dims=[1], - distribution=distribution, + dist_family=dist_family, ) def test_errors(self): diff --git a/tests/datamodules/segmentation/test_muad.py b/tests/datamodules/segmentation/test_muad.py index 862206f0..97d4f6d0 100644 --- a/tests/datamodules/segmentation/test_muad.py +++ b/tests/datamodules/segmentation/test_muad.py @@ -35,3 +35,7 @@ def test_camvid_main(self): dm.setup() dm.train_dataloader() dm.val_dataloader() + + def test_small_muad_accessibility(self): + dataset = MUAD(root="./data/", split="test", version="small", download=True) + assert len(dataset.samples) > 0, "Dataset is not found" diff --git a/tests/layers/test_distributions.py b/tests/layers/test_distributions.py index 63a52f27..cee3ebe5 100644 --- a/tests/layers/test_distributions.py +++ b/tests/layers/test_distributions.py @@ -1,22 +1,206 @@ import pytest +import torch from torch_uncertainty.layers.distributions import ( - LaplaceLayer, - NormalLayer, - TUDist, + get_dist_conv_layer, + get_dist_linear_layer, ) -class TestDistributions: - def test(self): - TUDist.__abstractmethods__ = set() - dist = TUDist(dim=1) - dist.forward(None) +@pytest.fixture() +def feat_input() -> torch.Tensor: + return torch.rand((3, 8)) # (B, Hin) + + +def img_input() -> torch.Tensor: + return torch.rand((3, 2, 32, 32)) # (B, C, H, W) + + +class TestDistributionLinear: + """Testing the distribution linear layer classes.""" + + def test_normal_linear(self, feat_input: torch.Tensor): + dist_layer = get_dist_linear_layer("normal") + layer = dist_layer( + base_layer=torch.nn.Linear, + event_dim=2, + in_features=8, + ) + out = layer(feat_input) + assert out.keys() == {"loc", "scale"} + assert out["loc"].shape == torch.Size([3, 2]) + assert out["scale"].shape == torch.Size([3, 2]) + + def test_laplace_linear(self, feat_input: torch.Tensor): + 
dist_layer = get_dist_linear_layer("laplace") + layer = dist_layer( + base_layer=torch.nn.Linear, + event_dim=2, + in_features=8, + ) + out = layer(feat_input) + assert out.keys() == {"loc", "scale"} + assert out["loc"].shape == torch.Size([3, 2]) + assert out["scale"].shape == torch.Size([3, 2]) + + def test_cauchy_linear(self, feat_input: torch.Tensor): + dist_layer = get_dist_linear_layer("cauchy") + layer = dist_layer( + base_layer=torch.nn.Linear, + event_dim=2, + in_features=8, + ) + out = layer(feat_input) + assert out.keys() == {"loc", "scale"} + assert out["loc"].shape == torch.Size([3, 2]) + assert out["scale"].shape == torch.Size([3, 2]) + + def test_student_linear(self, feat_input: torch.Tensor): + dist_layer = get_dist_linear_layer("student") + layer = dist_layer( + base_layer=torch.nn.Linear, + event_dim=2, + in_features=8, + ) + out = layer(feat_input) + assert out.keys() == {"loc", "scale", "df"} + assert out["loc"].shape == torch.Size([3, 2]) + assert out["scale"].shape == torch.Size([3, 2]) + assert out["df"].shape == torch.Size([3, 2]) + + layer = dist_layer( + base_layer=torch.nn.Linear, + event_dim=2, + in_features=8, + fixed_df=3.0, + ) + out = layer(feat_input) + assert out.keys() == {"loc", "scale", "df"} + assert out["loc"].shape == torch.Size([3, 2]) + assert out["scale"].shape == torch.Size([3, 2]) + assert out["df"].shape == torch.Size([3, 2]) + assert torch.allclose(out["df"], torch.tensor(3.0)) + + def test_nig_linear(self, feat_input: torch.Tensor): + dist_layer = get_dist_linear_layer("nig") + layer = dist_layer( + base_layer=torch.nn.Linear, + event_dim=2, + in_features=8, + ) + out = layer(feat_input) + assert out.keys() == {"loc", "lmbda", "alpha", "beta"} + assert out["loc"].shape == torch.Size([3, 2]) + assert out["lmbda"].shape == torch.Size([3, 2]) + assert out["alpha"].shape == torch.Size([3, 2]) + assert out["beta"].shape == torch.Size([3, 2]) + + def test_failures(self): + with pytest.raises(NotImplementedError): + get_dist_linear_layer("unknown") - def test_errors(self): - with pytest.raises(ValueError): - NormalLayer(-1, 1) with pytest.raises(ValueError): - NormalLayer(1, -1) + layer_class = get_dist_linear_layer("normal") + layer_class( + base_layer=torch.nn.Conv2d, + event_dim=2, + in_channels=5, + ) + + +class TestDistributionConv: + """Testing the distribution convolutional layer classes.""" + + def test_normal_conv(self): + dist_layer = get_dist_conv_layer("normal") + layer = dist_layer( + base_layer=torch.nn.Conv2d, + event_dim=2, + in_channels=2, + kernel_size=3, + ) + out = layer(img_input()) + assert out.keys() == {"loc", "scale"} + assert out["loc"].shape == torch.Size([3, 2, 30, 30]) + assert out["scale"].shape == torch.Size([3, 2, 30, 30]) + + def test_laplace_conv(self): + dist_layer = get_dist_conv_layer("laplace") + layer = dist_layer( + base_layer=torch.nn.Conv2d, + event_dim=2, + in_channels=2, + kernel_size=3, + ) + out = layer(img_input()) + assert out.keys() == {"loc", "scale"} + assert out["loc"].shape == torch.Size([3, 2, 30, 30]) + assert out["scale"].shape == torch.Size([3, 2, 30, 30]) + + def test_cauchy_conv(self): + dist_layer = get_dist_conv_layer("cauchy") + layer = dist_layer( + base_layer=torch.nn.Conv2d, + event_dim=2, + in_channels=2, + kernel_size=3, + ) + out = layer(img_input()) + assert out.keys() == {"loc", "scale"} + assert out["loc"].shape == torch.Size([3, 2, 30, 30]) + assert out["scale"].shape == torch.Size([3, 2, 30, 30]) + + def test_student_conv(self): + dist_layer = get_dist_conv_layer("student") + 
layer = dist_layer( + base_layer=torch.nn.Conv2d, + event_dim=2, + in_channels=2, + kernel_size=3, + ) + out = layer(img_input()) + assert out.keys() == {"loc", "scale", "df"} + assert out["loc"].shape == torch.Size([3, 2, 30, 30]) + assert out["scale"].shape == torch.Size([3, 2, 30, 30]) + assert out["df"].shape == torch.Size([3, 2, 30, 30]) + + layer = dist_layer( + base_layer=torch.nn.Conv2d, + event_dim=2, + in_channels=2, + kernel_size=3, + fixed_df=3.0, + ) + out = layer(img_input()) + assert out.keys() == {"loc", "scale", "df"} + assert out["loc"].shape == torch.Size([3, 2, 30, 30]) + assert out["scale"].shape == torch.Size([3, 2, 30, 30]) + assert out["df"].shape == torch.Size([3, 2, 30, 30]) + assert torch.allclose(out["df"], torch.tensor(3.0)) + + def test_nig_conv(self): + dist_layer = get_dist_conv_layer("nig") + layer = dist_layer( + base_layer=torch.nn.Conv2d, + event_dim=2, + in_channels=2, + kernel_size=3, + ) + out = layer(img_input()) + assert out.keys() == {"loc", "lmbda", "alpha", "beta"} + assert out["loc"].shape == torch.Size([3, 2, 30, 30]) + assert out["lmbda"].shape == torch.Size([3, 2, 30, 30]) + assert out["alpha"].shape == torch.Size([3, 2, 30, 30]) + assert out["beta"].shape == torch.Size([3, 2, 30, 30]) + + def test_failures(self): + with pytest.raises(NotImplementedError): + get_dist_conv_layer("unknown") + with pytest.raises(ValueError): - LaplaceLayer(1, -1) + layer_class = get_dist_conv_layer("normal") + layer_class( + base_layer=torch.nn.Linear, + event_dim=2, + in_features=5, + ) diff --git a/tests/layers/test_packed.py b/tests/layers/test_packed.py index 7cc7fd1d..cfcee746 100644 --- a/tests/layers/test_packed.py +++ b/tests/layers/test_packed.py @@ -1,11 +1,20 @@ import pytest import torch +from einops import repeat +from torch_uncertainty.layers.functional.packed import ( + packed_in_projection_packed, + packed_multi_head_attention_forward, +) from torch_uncertainty.layers.packed import ( PackedConv1d, PackedConv2d, PackedConv3d, + PackedLayerNorm, PackedLinear, + PackedMultiheadAttention, + PackedTransformerDecoderLayer, + PackedTransformerEncoderLayer, ) @@ -19,9 +28,14 @@ def feat_input_one_rearrange() -> torch.Tensor: return torch.rand((1 * 3, 5)) +@pytest.fixture() +def feat_multi_dim() -> torch.Tensor: + return torch.rand((1, 2, 3, 4, 6)) + + @pytest.fixture() def feat_input_16_features() -> torch.Tensor: - return torch.rand((2, 16)) + return torch.rand((3, 16)) @pytest.fixture() @@ -39,6 +53,81 @@ def voxels_input() -> torch.Tensor: return torch.rand((5, 6, 3, 3, 3)) +@pytest.fixture() +def unbatched_qkv() -> torch.Tensor: + return torch.rand((3, 6)) + + +@pytest.fixture() +def unbatched_q_kv() -> tuple[torch.Tensor, torch.Tensor]: + return torch.rand((3, 6)), torch.rand((4, 2)) + + +@pytest.fixture() +def unbatched_q_k_v() -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + return torch.rand((3, 6)), torch.rand((4, 2)), torch.rand((4, 4)) + + +@pytest.fixture() +def batched_qkv() -> torch.Tensor: + return torch.rand((2, 3, 6)) + + +@pytest.fixture() +def extended_batched_qkv() -> torch.Tensor: + expansion = 2 + return torch.rand((2, 3, 6 * expansion)) + + +@pytest.fixture() +def batched_q_kv() -> tuple[torch.Tensor, torch.Tensor]: + return ( + torch.rand((2, 3, 6)), + torch.rand((2, 4, 2)), + ) + + +@pytest.fixture() +def batched_q_k_v() -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + return ( + torch.rand((2, 3, 6)), + torch.rand((2, 4, 2)), + torch.rand((2, 4, 4)), + ) + + +@pytest.fixture() +def extended_batched_q_k_v() -> 
tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + expansion = 2 + return ( + torch.rand((2, 3, 6 * expansion)), + torch.rand((2, 4, 2 * expansion)), + torch.rand((2, 4, 4 * expansion)), + ) + + +@pytest.fixture() +def unbatched_tgt_memory() -> tuple[torch.Tensor, torch.Tensor]: + return torch.rand((3, 6)), torch.rand((4, 6)) + + +@pytest.fixture() +def batched_tgt_memory() -> tuple[torch.Tensor, torch.Tensor]: + return ( + torch.rand((2, 3, 6)), + torch.rand((2, 4, 6)), + ) + + +@pytest.fixture() +def extended_batched_tgt_memory() -> tuple[torch.Tensor, torch.Tensor]: + expansion = 2 + return ( + torch.rand((2, 3, 6 * expansion)), + torch.rand((2, 4, 6 * expansion)), + ) + + class TestPackedLinear: """Testing the PackedLinear layer class.""" @@ -64,29 +153,47 @@ def test_linear_two_estimator_rearrange_not_divisible(self): out = layer(feat) assert out.shape == torch.Size([6, 1]) - def test_linear_full_implementation(self, feat_input_16_features: torch.Tensor): - layer = PackedLinear(16, 4, alpha=1, num_estimators=1, implementation="full") + # Full implementation tests + def test_linear_full_implementation( + self, feat_input_16_features: torch.Tensor, feat_multi_dim: torch.Tensor + ): + layer = PackedLinear(16, 4, alpha=1, num_estimators=1, implementation="full", bias=False) out = layer(feat_input_16_features) - assert out.shape == torch.Size([2, 4]) + assert out.shape == torch.Size([3, 4]) layer = PackedLinear(16, 4, alpha=1, num_estimators=2, implementation="full") out = layer(feat_input_16_features) - assert out.shape == torch.Size([2, 4]) - - def test_linear_sparse_implementation(self, feat_input_16_features: torch.Tensor): + assert out.shape == torch.Size([3, 4]) + layer = PackedLinear(6, 2, alpha=1, num_estimators=1, implementation="full") + out = layer(feat_multi_dim) + assert out.shape == torch.Size([1, 2, 3, 4, 2]) + + # Sparse implementation tests + def test_linear_sparse_implementation( + self, feat_input_16_features: torch.Tensor, feat_multi_dim: torch.Tensor + ): layer = PackedLinear(16, 4, alpha=1, num_estimators=1, implementation="sparse") out = layer(feat_input_16_features) - assert out.shape == torch.Size([2, 4]) + assert out.shape == torch.Size([3, 4]) layer = PackedLinear(16, 4, alpha=1, num_estimators=2, implementation="sparse") out = layer(feat_input_16_features) - assert out.shape == torch.Size([2, 4]) - - def test_linear_einsum_implementation(self, feat_input_16_features: torch.Tensor): + assert out.shape == torch.Size([3, 4]) + layer = PackedLinear(6, 2, alpha=1, num_estimators=1, implementation="sparse") + out = layer(feat_multi_dim) + assert out.shape == torch.Size([1, 2, 3, 4, 2]) + + # Einsum implementation tests + def test_linear_einsum_implementation( + self, feat_input_16_features: torch.Tensor, feat_multi_dim: torch.Tensor + ): layer = PackedLinear(16, 4, alpha=1, num_estimators=1, implementation="einsum") out = layer(feat_input_16_features) - assert out.shape == torch.Size([2, 4]) + assert out.shape == torch.Size([3, 4]) layer = PackedLinear(16, 4, alpha=1, num_estimators=2, implementation="einsum") out = layer(feat_input_16_features) - assert out.shape == torch.Size([2, 4]) + assert out.shape == torch.Size([3, 4]) + layer = PackedLinear(6, 2, alpha=1, num_estimators=1, implementation="einsum") + out = layer(feat_multi_dim) + assert out.shape == torch.Size([1, 2, 3, 4, 2]) def test_linear_extend(self): _ = PackedConv2d(5, 3, kernel_size=1, alpha=1, num_estimators=2, gamma=1) @@ -248,3 +355,624 @@ def test_conv3_failures(self): with 
pytest.raises(ValueError): _ = PackedConv3d(5, 2, kernel_size=1, alpha=1, num_estimators=1, gamma=-1) + + +class TestPackedLayerNorm: + """Testing the PackedGroupNorm layer class.""" + + def test_one_estimator_forward(self, batched_qkv: torch.Tensor): + packed_layer_norm = PackedLayerNorm( + embed_dim=6, + num_estimators=1, + alpha=1, + ) + out = packed_layer_norm(batched_qkv) + assert out.shape == torch.Size([2, 3, 6]) + + +class TestPackedMultiheadAttention: + """Testing the PackedMultiheadAttention layer class.""" + + def test_one_estimator_qkv(self, unbatched_qkv: torch.Tensor, batched_qkv: torch.Tensor): + attn_mask = torch.zeros(1, 3, 3, dtype=torch.bool) + + layer = PackedMultiheadAttention( + embed_dim=6, + num_heads=1, + alpha=1, + num_estimators=1, + ) + out, _ = layer( + query=unbatched_qkv, + key=unbatched_qkv, + value=unbatched_qkv, + attn_mask=attn_mask, + ) + assert out.shape == torch.Size([3, 6]) + + unbatched_qkv = repeat(unbatched_qkv, "l h -> l b h", b=2) + attn_mask = torch.zeros(2, 3, 3, dtype=torch.bool) + out, _ = layer( + query=unbatched_qkv, + key=unbatched_qkv, + value=unbatched_qkv, + attn_mask=attn_mask, + is_causal=True, + ) + assert out.shape == torch.Size([3, 2, 6]) + + layer = PackedMultiheadAttention( + embed_dim=6, + num_heads=2, + alpha=1, + num_estimators=1, + batch_first=True, + bias=False, + ) + out, _ = layer( + query=batched_qkv, + key=batched_qkv, + value=batched_qkv, + ) + assert out.shape == torch.Size([2, 3, 6]) + + def test_one_estimator_q_kv( + self, + unbatched_q_kv: tuple[torch.Tensor, torch.Tensor], + batched_q_kv: tuple[torch.Tensor, torch.Tensor], + ): + layer = PackedMultiheadAttention( + embed_dim=6, + num_heads=2, + alpha=1, + num_estimators=1, + kdim=2, + vdim=2, + add_zero_attn=True, + ) + out, _ = layer( + query=unbatched_q_kv[0], + key=unbatched_q_kv[1], + value=unbatched_q_kv[1], + ) + assert out.shape == torch.Size([3, 6]) + unbatched_q_kv = tuple(repeat(seq, "l h -> l b h", b=2) for seq in unbatched_q_kv) + out, _ = layer( + query=unbatched_q_kv[0], + key=unbatched_q_kv[1], + value=unbatched_q_kv[1], + ) + assert out.shape == torch.Size([3, 2, 6]) + + layer = PackedMultiheadAttention( + embed_dim=6, + num_heads=2, + alpha=1, + num_estimators=1, + kdim=2, + vdim=2, + batch_first=True, + bias=False, + ) + out, _ = layer( + query=batched_q_kv[0], + key=batched_q_kv[1], + value=batched_q_kv[1], + ) + assert out.shape == torch.Size([2, 3, 6]) + + def test_one_estimator_q_k_v( + self, + unbatched_q_k_v: tuple[torch.Tensor, torch.Tensor, torch.Tensor], + batched_q_k_v: tuple[torch.Tensor, torch.Tensor, torch.Tensor], + ): + layer = PackedMultiheadAttention( + embed_dim=6, + num_heads=1, + alpha=1, + num_estimators=1, + kdim=2, + vdim=4, + add_bias_kv=True, + ) + + key_padding_mask = torch.zeros(4, dtype=torch.bool) + + out, _ = layer( + query=unbatched_q_k_v[0], + key=unbatched_q_k_v[1], + value=unbatched_q_k_v[2], + key_padding_mask=key_padding_mask, + ) + assert out.shape == torch.Size([3, 6]) + + unbatched_q_k_v = tuple(repeat(seq, "l h -> l b h", b=2) for seq in unbatched_q_k_v) + + out, _ = layer( + query=unbatched_q_k_v[0], + key=unbatched_q_k_v[1], + value=unbatched_q_k_v[2], + ) + assert out.shape == torch.Size([3, 2, 6]) + + layer = PackedMultiheadAttention( + embed_dim=6, + num_heads=2, + alpha=1, + num_estimators=1, + kdim=2, + vdim=4, + batch_first=True, + ) + + layer.eval() + + attn_mask = torch.zeros(3, 4, dtype=torch.bool) + key_padding_mask = torch.zeros(2, 4, dtype=torch.bool) + + out, _ = layer( + 
query=batched_q_k_v[0], + key=batched_q_k_v[1], + value=batched_q_k_v[2], + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + ) + assert out.shape == torch.Size([2, 3, 6]) + assert out.isfinite().all() + + def test_two_estimators_qkv(self, unbatched_qkv: torch.Tensor, batched_qkv: torch.Tensor): + layer = PackedMultiheadAttention( + embed_dim=6, + num_heads=3, + alpha=1, + num_estimators=2, + ) + out, _ = layer( + query=unbatched_qkv, + key=unbatched_qkv, + value=unbatched_qkv, + ) + assert out.shape == torch.Size([3, 6]) + + unbatched_qkv = repeat(unbatched_qkv, "l h -> l b h", b=2) + out, _ = layer( + query=unbatched_qkv, + key=unbatched_qkv, + value=unbatched_qkv, + ) + assert out.shape == torch.Size([3, 2, 6]) + + layer = PackedMultiheadAttention( + embed_dim=6, + num_heads=3, + alpha=1, + num_estimators=2, + batch_first=True, + ) + out, _ = layer( + query=batched_qkv, + key=batched_qkv, + value=batched_qkv, + ) + assert out.shape == torch.Size([2, 3, 6]) + + def test_two_estimators_q_kv( + self, + unbatched_q_kv: tuple[torch.Tensor, torch.Tensor], + batched_q_kv: tuple[torch.Tensor, torch.Tensor], + ): + layer = PackedMultiheadAttention( + embed_dim=6, + num_heads=3, + alpha=1, + num_estimators=2, + kdim=2, + vdim=2, + add_zero_attn=True, + ) + out, _ = layer( + query=unbatched_q_kv[0], + key=unbatched_q_kv[1], + value=unbatched_q_kv[1], + ) + assert out.shape == torch.Size([3, 6]) + unbatched_q_kv = tuple(repeat(seq, "l h -> l b h", b=2) for seq in unbatched_q_kv) + + attn_mask = torch.zeros(12, 3, 4, dtype=torch.bool) + key_padding_mask = torch.zeros(2, 4, dtype=torch.bool) + + out, _ = layer( + query=unbatched_q_kv[0], + key=unbatched_q_kv[1], + value=unbatched_q_kv[1], + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + ) + assert out.shape == torch.Size([3, 2, 6]) + + layer = PackedMultiheadAttention( + embed_dim=6, + num_heads=3, + alpha=1, + num_estimators=2, + kdim=2, + vdim=2, + batch_first=True, + ) + out, _ = layer( + query=batched_q_kv[0], + key=batched_q_kv[1], + value=batched_q_kv[1], + ) + assert out.shape == torch.Size([2, 3, 6]) + + def test_two_estimators_q_k_v( + self, + unbatched_q_k_v: tuple[torch.Tensor, torch.Tensor, torch.Tensor], + extended_batched_q_k_v: tuple[torch.Tensor, torch.Tensor, torch.Tensor], + ): + layer = PackedMultiheadAttention( + embed_dim=6, + num_heads=3, + alpha=1, + num_estimators=2, + kdim=2, + vdim=4, + add_bias_kv=True, + ) + out, _ = layer( + query=unbatched_q_k_v[0], + key=unbatched_q_k_v[1], + value=unbatched_q_k_v[2], + ) + assert out.shape == torch.Size([3, 6]) + + unbatched_q_k_v = tuple(repeat(seq, "l h -> l b h", b=2) for seq in unbatched_q_k_v) + + attn_mask = torch.zeros(3, 4, dtype=torch.bool) + key_padding_mask = torch.zeros(2, 4, dtype=torch.bool) + + out, _ = layer( + query=unbatched_q_k_v[0], + key=unbatched_q_k_v[1], + value=unbatched_q_k_v[2], + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + ) + assert out.shape == torch.Size([3, 2, 6]) + + layer = PackedMultiheadAttention( + embed_dim=6, + num_heads=3, + alpha=2, + num_estimators=2, + kdim=2, + vdim=4, + batch_first=True, + ) + out, _ = layer( + query=extended_batched_q_k_v[0], + key=extended_batched_q_k_v[1], + value=extended_batched_q_k_v[2], + ) + assert out.shape == torch.Size([2, 3, 12]) + + +class TestPackedTransformerEncoderLayer: + """Testing the PackedTransformerEncoderLayer class.""" + + def test_one_estimator(self, unbatched_qkv: torch.Tensor, batched_qkv: torch.Tensor): + layer = PackedTransformerEncoderLayer( + d_model=6, + 
dim_feedforward=12, + nhead=2, + alpha=1, + num_estimators=1, + norm_first=True, + first=True, + ) + out = layer( + src=unbatched_qkv, + ) + assert out.shape == torch.Size([3, 6]) + + unbatched_qkv = repeat(unbatched_qkv, "l h -> l b h", b=2) + out = layer( + src=unbatched_qkv, + ) + assert out.shape == torch.Size([3, 2, 6]) + + layer = PackedTransformerEncoderLayer( + d_model=6, + dim_feedforward=12, + nhead=2, + alpha=1, + num_estimators=1, + batch_first=True, + last=True, + activation=torch.nn.GELU(), + ) + out = layer( + src=batched_qkv, + ) + assert out.shape == torch.Size([2, 3, 6]) + + def test_two_estimators(self, unbatched_qkv: torch.Tensor, extended_batched_qkv: torch.Tensor): + layer = PackedTransformerEncoderLayer( + d_model=6, + dim_feedforward=12, + nhead=3, + alpha=1, + num_estimators=2, + activation=torch.nn.ELU(), + ) + out = layer( + src=unbatched_qkv, + ) + assert out.shape == torch.Size([3, 6]) + + unbatched_qkv = repeat(unbatched_qkv, "l h -> l b h", b=2) + out = layer( + src=unbatched_qkv, + ) + assert out.shape == torch.Size([3, 2, 6]) + + layer = PackedTransformerEncoderLayer( + d_model=6, + dim_feedforward=12, + nhead=3, + alpha=2, + num_estimators=2, + batch_first=True, + ) + out = layer( + src=extended_batched_qkv, + ) + assert out.shape == torch.Size([2, 3, 12]) + + +class TestPackedTransformerDecoderLayer: + """Testing the PackedTransformerDecoderLayer class.""" + + def test_one_estimator( + self, + unbatched_tgt_memory: tuple[torch.Tensor, torch.Tensor], + batched_tgt_memory: tuple[torch.Tensor, torch.Tensor], + ): + layer = PackedTransformerDecoderLayer( + d_model=6, + dim_feedforward=12, + nhead=2, + alpha=1, + num_estimators=1, + norm_first=True, + first=True, + ) + out = layer( + tgt=unbatched_tgt_memory[0], + memory=unbatched_tgt_memory[1], + ) + assert out.shape == torch.Size([3, 6]) + + unbatched_tgt_memory = tuple( + repeat(seq, "l h -> l b h", b=2) for seq in unbatched_tgt_memory + ) + out = layer( + tgt=unbatched_tgt_memory[0], + memory=unbatched_tgt_memory[1], + ) + assert out.shape == torch.Size([3, 2, 6]) + + layer = PackedTransformerDecoderLayer( + d_model=6, + dim_feedforward=12, + nhead=2, + alpha=1, + num_estimators=1, + batch_first=True, + last=True, + activation=torch.nn.GELU(), + bias=False, + ) + out = layer( + tgt=batched_tgt_memory[0], + memory=batched_tgt_memory[1], + ) + assert out.shape == torch.Size([2, 3, 6]) + + def test_two_estimators( + self, + unbatched_tgt_memory: tuple[torch.Tensor, torch.Tensor], + extended_batched_tgt_memory: tuple[torch.Tensor, torch.Tensor], + ): + layer = PackedTransformerDecoderLayer( + d_model=6, + dim_feedforward=12, + nhead=3, + alpha=1, + num_estimators=2, + activation=torch.nn.ELU(), + ) + out = layer( + tgt=unbatched_tgt_memory[0], + memory=unbatched_tgt_memory[1], + ) + assert out.shape == torch.Size([3, 6]) + + unbatched_tgt_memory = tuple( + repeat(seq, "l h -> l b h", b=2) for seq in unbatched_tgt_memory + ) + out = layer( + tgt=unbatched_tgt_memory[0], + memory=unbatched_tgt_memory[1], + ) + assert out.shape == torch.Size([3, 2, 6]) + + layer = PackedTransformerDecoderLayer( + d_model=6, + dim_feedforward=12, + nhead=3, + alpha=2, + num_estimators=2, + batch_first=True, + ) + out = layer( + tgt=extended_batched_tgt_memory[0], + memory=extended_batched_tgt_memory[1], + ) + assert out.shape == torch.Size([2, 3, 12]) + + +class TestPackedFunctional: + def test_packed_in_projection_packed( + self, + batched_qkv: torch.Tensor, + ): + proj_q, proj_k, proj_v = packed_in_projection_packed( + 
q=batched_qkv, + k=batched_qkv, + v=batched_qkv, + w=torch.rand((1, 18, 6)), + num_groups=1, + ) + assert proj_q.shape == torch.Size([2, 3, 6]) + assert proj_k.shape == torch.Size([2, 3, 6]) + assert proj_v.shape == torch.Size([2, 3, 6]) + + q_kv = torch.rand((2, 3, 6)), torch.rand((2, 4, 6)) + + proj_q, proj_k, proj_v = packed_in_projection_packed( + q=q_kv[0], + k=q_kv[1], + v=q_kv[1], + w=torch.rand((1, 18, 6)), + num_groups=1, + b=None, + ) + proj_q, proj_k, proj_v = packed_in_projection_packed( + q=q_kv[0], + k=q_kv[1], + v=q_kv[1], + w=torch.rand((1, 18, 6)), + num_groups=1, + b=torch.rand(18), + ) + + assert proj_q.shape == torch.Size([2, 3, 6]) + assert proj_k.shape == torch.Size([2, 4, 6]) + assert proj_v.shape == torch.Size([2, 4, 6]) + + q_k_v = torch.rand((2, 3, 6)), torch.rand((2, 4, 6)), torch.rand((2, 4, 6)) + + proj_q, proj_k, proj_v = packed_in_projection_packed( + q=q_k_v[0], + k=q_k_v[1], + v=q_k_v[2], + w=torch.rand((1, 18, 6)), + num_groups=1, + b=None, + ) + + proj_q, proj_k, proj_v = packed_in_projection_packed( + q=q_k_v[0], + k=q_k_v[1], + v=q_k_v[2], + w=torch.rand((1, 18, 6)), + num_groups=1, + b=torch.rand(18), + ) + + assert proj_q.shape == torch.Size([2, 3, 6]) + assert proj_k.shape == torch.Size([2, 4, 6]) + assert proj_v.shape == torch.Size([2, 4, 6]) + + def test_packed_multi_head_attention_forward_failures(self, unbatched_q_k_v: torch.Tensor): + q, k, v = unbatched_q_k_v + with pytest.raises(RuntimeError): + _ = packed_multi_head_attention_forward( + query=q, + key=k, + value=v, + embed_dim_to_check=6, + num_heads=2, + num_groups=1, + in_proj_weight=None, + in_proj_bias=torch.rand(18), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0.0, + out_proj_weight=torch.rand(1, 6, 6), + out_proj_bias=None, + is_causal=True, + attn_mask=None, + ) + + with pytest.raises(RuntimeError): + _ = packed_multi_head_attention_forward( + query=q, + key=k, + value=v, + embed_dim_to_check=6, + num_heads=2, + num_groups=1, + in_proj_weight=None, + in_proj_bias=torch.rand(18), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0.0, + out_proj_weight=torch.rand(1, 6, 6), + out_proj_bias=None, + attn_mask=torch.rand(2, 2), + use_separate_proj_weight=True, + q_proj_weight=torch.rand(1, 6, 6), + k_proj_weight=torch.rand(1, 6, 2), + v_proj_weight=torch.rand(1, 6, 4), + ) + + with pytest.raises(AssertionError): + _ = packed_multi_head_attention_forward( + query=q, + key=k, + value=v, + embed_dim_to_check=6, + num_heads=2, + num_groups=1, + in_proj_weight=None, + in_proj_bias=torch.rand(18), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0.0, + out_proj_weight=torch.rand(1, 6, 6), + out_proj_bias=None, + attn_mask=torch.rand(1, 1, 3, 4), + use_separate_proj_weight=True, + q_proj_weight=torch.rand(1, 6, 6), + k_proj_weight=torch.rand(1, 6, 2), + v_proj_weight=torch.rand(1, 6, 4), + ) + + with pytest.raises(AssertionError): + _ = packed_multi_head_attention_forward( + query=q, + key=k, + value=v, + embed_dim_to_check=6, + num_heads=2, + num_groups=1, + in_proj_weight=None, + in_proj_bias=torch.rand(18), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0.0, + out_proj_weight=torch.rand(1, 6, 6), + out_proj_bias=None, + attn_mask=torch.rand(1, 2, 2), + use_separate_proj_weight=True, + q_proj_weight=torch.rand(1, 6, 6), + k_proj_weight=torch.rand(1, 6, 2), + v_proj_weight=torch.rand(1, 6, 4), + ) diff --git a/tests/losses/test_bayesian.py b/tests/losses/test_bayesian.py index 6d68e16d..afca2e53 100644 --- 
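A brief, illustrative usage sketch of the packed attention layer, restricted to a configuration exercised by the tests above (batch-first, single estimator); it is not part of the diff itself:

import torch

from torch_uncertainty.layers.packed import PackedMultiheadAttention

# Batch-first, single-estimator packed self-attention on a (batch, seq, embed) tensor.
mha = PackedMultiheadAttention(embed_dim=6, num_heads=2, alpha=1, num_estimators=1, batch_first=True)
x = torch.rand(2, 3, 6)
out, _ = mha(query=x, key=x, value=x)
assert out.shape == torch.Size([2, 3, 6])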
a/tests/losses/test_bayesian.py +++ b/tests/losses/test_bayesian.py @@ -3,7 +3,8 @@ from torch import nn, optim from torch_uncertainty.layers.bayesian import BayesLinear -from torch_uncertainty.losses import ELBOLoss +from torch_uncertainty.layers.distributions import NormalLinear +from torch_uncertainty.losses import DistributionNLLLoss, ELBOLoss from torch_uncertainty.routines import RegressionRoutine @@ -23,13 +24,32 @@ def test_main(self): loss = ELBOLoss(model, criterion, kl_weight=1e-5, num_samples=1) loss(model(torch.randn(1, 1)), torch.randn(1, 1)) + def test_prob_regression_training_step(self): + model = NormalLinear(BayesLinear, event_dim=4, in_features=10) + criterion = DistributionNLLLoss() + loss = ELBOLoss(model, criterion, kl_weight=1e-5, num_samples=3, dist_family="normal") + + routine = RegressionRoutine( + output_dim=1, + model=model, + loss=loss, + dist_family="normal", + optim_recipe=optim.Adam( + model.parameters(), + lr=5e-4, + weight_decay=0, + ), + ) + inputs = torch.randn(1, 10) + targets = torch.randn(1, 4) + routine.training_step((inputs, targets), 0) + def test_training_step(self): model = BayesLinear(10, 4) criterion = nn.MSELoss() loss = ELBOLoss(model, criterion, kl_weight=1 / 50000, num_samples=3) routine = RegressionRoutine( - probabilistic=False, output_dim=4, model=model, loss=loss, diff --git a/tests/losses/test_regression.py b/tests/losses/test_regression.py index 893845f9..c20d0bff 100644 --- a/tests/losses/test_regression.py +++ b/tests/losses/test_regression.py @@ -2,14 +2,14 @@ import pytest import torch -from torch.distributions import Normal +from torch.distributions import Independent, Normal -from torch_uncertainty.layers.distributions import NormalInverseGamma from torch_uncertainty.losses import ( BetaNLL, DERLoss, DistributionNLLLoss, ) +from torch_uncertainty.utils.distributions import NormalInverseGamma class TestDistributionNLL: @@ -43,6 +43,8 @@ def test_main(self): torch.ones((2, 1)), ) + inputs = Independent(inputs, 0) + assert loss( inputs, targets, diff --git a/tests/models/test_mlps.py b/tests/models/test_mlps.py index 2e7a72e8..e811bea2 100644 --- a/tests/models/test_mlps.py +++ b/tests/models/test_mlps.py @@ -1,4 +1,3 @@ -from torch_uncertainty.layers.distributions import NormalLayer from torch_uncertainty.models.mlp import bayesian_mlp, mlp, packed_mlp @@ -10,8 +9,7 @@ def test_mlps(self): 1, 1, hidden_dims=[1, 1, 1], - final_layer=NormalLayer, - final_layer_args={"dim": 1}, + dist_family="normal", ) mlp(1, 1, hidden_dims=[]) packed_mlp(1, 1, hidden_dims=[]) diff --git a/tests/models/test_segformer.py b/tests/models/test_segformer.py index a9fbe0f2..a519ba82 100644 --- a/tests/models/test_segformer.py +++ b/tests/models/test_segformer.py @@ -12,8 +12,4 @@ class TestSegformer: def test_main(self): model = seg_former(10, 0) seg_former(10, 1) - seg_former(10, 2) - seg_former(10, 3) - seg_former(10, 4) - seg_former(10, 5) model(torch.randn(1, 3, 32, 32)) diff --git a/tests/models/wrappers/test_deep_ensembles.py b/tests/models/wrappers/test_deep_ensembles.py index 98e45c08..5c49f029 100644 --- a/tests/models/wrappers/test_deep_ensembles.py +++ b/tests/models/wrappers/test_deep_ensembles.py @@ -31,6 +31,16 @@ def test_list_singleton(self): with pytest.raises(ValueError): deep_ensembles([model_1], num_estimators=1) + def test_error_prob_regression(self): + # The output dicts will have different keys + model_1 = dummy_model(1, 2, dist_family="normal") + model_2 = dummy_model(1, 2, dist_family="nig") + + de = deep_ensembles([model_1, 
model_2], task="regression", probabilistic=True) + + with pytest.raises(ValueError): + de(torch.randn(5, 1)) + def test_errors(self): model_1 = dummy_model(1, 10) with pytest.raises(ValueError): diff --git a/tests/post_processing/test_laplace.py b/tests/post_processing/test_laplace.py index 6b798d6b..8b6249ea 100644 --- a/tests/post_processing/test_laplace.py +++ b/tests/post_processing/test_laplace.py @@ -23,7 +23,7 @@ def test_training(self): ds = TensorDataset(torch.randn(16, 1), torch.randn(16, 10)) la = LaplaceApprox( task="classification", - model=dummy_model(1, 10, last_layer=nn.Linear(10, 10)), + model=dummy_model(1, 10), ) la.fit(ds) la(torch.randn(1, 1)) diff --git a/tests/routines/test_pixel_regression.py b/tests/routines/test_pixel_regression.py index 6b2bfdcf..7cddac7c 100644 --- a/tests/routines/test_pixel_regression.py +++ b/tests/routines/test_pixel_regression.py @@ -30,7 +30,7 @@ def test_one_estimator_two_classes(self): dm = DummyPixelRegressionDataModule(root=root, batch_size=5, output_dim=3) model = DummyPixelRegressionBaseline( - probabilistic=False, + dist_family=None, in_channels=dm.num_channels, output_dim=dm.output_dim, image_size=dm.image_size, @@ -52,7 +52,7 @@ def test_one_estimator_two_classes(self): enable_checkpointing=False, ) model = DummyPixelRegressionBaseline( - probabilistic=True, + dist_family="normal", in_channels=dm.num_channels, output_dim=dm.output_dim, image_size=dm.image_size, @@ -74,7 +74,7 @@ def test_two_estimators_one_class(self): dm = DummyPixelRegressionDataModule(root=root, batch_size=4, output_dim=1) model = DummyPixelRegressionBaseline( - probabilistic=False, + dist_family=None, in_channels=dm.num_channels, output_dim=dm.output_dim, image_size=dm.image_size, @@ -90,7 +90,7 @@ def test_two_estimators_one_class(self): trainer = TUTrainer(accelerator="cpu", fast_dev_run=True, logger=None) model = DummyPixelRegressionBaseline( - probabilistic=True, + dist_family="normal", in_channels=dm.num_channels, output_dim=dm.output_dim, image_size=dm.image_size, @@ -110,18 +110,18 @@ def test_two_estimators_one_class(self): def test_depth_errors(self): with pytest.raises(ValueError, match="output_dim must be positive"): PixelRegressionRoutine( - probabilistic=False, model=nn.Identity(), output_dim=0, loss=nn.MSELoss(), + dist_family=None, ) with pytest.raises(ValueError, match="num_image_plot must be positive"): PixelRegressionRoutine( - probabilistic=False, model=nn.Identity(), output_dim=1, loss=nn.MSELoss(), + dist_family=None, num_image_plot=0, log_plots=True, ) diff --git a/tests/routines/test_regression.py b/tests/routines/test_regression.py index 8f1fa6ed..e491209e 100644 --- a/tests/routines/test_regression.py +++ b/tests/routines/test_regression.py @@ -1,6 +1,7 @@ from pathlib import Path import pytest +import torch from torch import nn from tests._dummies import DummyRegressionBaseline, DummyRegressionDataModule @@ -20,13 +21,13 @@ def test_one_estimator_one_output(self): dm = DummyRegressionDataModule(out_features=1, root=root, batch_size=4) model = DummyRegressionBaseline( - probabilistic=True, in_features=dm.in_features, output_dim=1, loss=DistributionNLLLoss(), optim_recipe=optim_cifar10_resnet18, baseline_type="single", ema=True, + dist_family="normal", ) trainer.fit(model, dm) @@ -36,13 +37,13 @@ def test_one_estimator_one_output(self): trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) model = DummyRegressionBaseline( - probabilistic=False, in_features=dm.in_features, output_dim=1, loss=nn.MSELoss(), 
optim_recipe=optim_cifar10_resnet18, baseline_type="single", swa=True, + dist_family=None, ) trainer.fit(model, dm) @@ -57,13 +58,12 @@ def test_one_estimator_two_outputs(self): dm = DummyRegressionDataModule(out_features=2, root=root, batch_size=4) model = DummyRegressionBaseline( - probabilistic=True, in_features=dm.in_features, output_dim=2, loss=DistributionNLLLoss(), optim_recipe=optim_cifar10_resnet18, baseline_type="single", - dist_type="laplace", + dist_family="laplace", ) trainer.fit(model, dm) trainer.validate(model, dm) @@ -72,12 +72,12 @@ def test_one_estimator_two_outputs(self): trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) model = DummyRegressionBaseline( - probabilistic=False, in_features=dm.in_features, output_dim=2, loss=nn.MSELoss(), optim_recipe=optim_cifar10_resnet18, baseline_type="single", + dist_family=None, ) trainer.fit(model, dm) trainer.validate(model, dm) @@ -91,13 +91,12 @@ def test_two_estimators_one_output(self): dm = DummyRegressionDataModule(out_features=1, root=root, batch_size=4) model = DummyRegressionBaseline( - probabilistic=True, in_features=dm.in_features, output_dim=1, loss=DistributionNLLLoss(), optim_recipe=optim_cifar10_resnet18, baseline_type="ensemble", - dist_type="nig", + dist_family="nig", ) trainer.fit(model, dm) trainer.validate(model, dm) @@ -106,12 +105,12 @@ def test_two_estimators_one_output(self): trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) model = DummyRegressionBaseline( - probabilistic=False, in_features=dm.in_features, output_dim=1, loss=nn.MSELoss(), optim_recipe=optim_cifar10_resnet18, baseline_type="ensemble", + dist_family=None, ) trainer.fit(model, dm) trainer.validate(model, dm) @@ -125,12 +124,12 @@ def test_two_estimators_two_outputs(self): dm = DummyRegressionDataModule(out_features=2, root=root, batch_size=4) model = DummyRegressionBaseline( - probabilistic=True, in_features=dm.in_features, output_dim=2, loss=DistributionNLLLoss(), optim_recipe=optim_cifar10_resnet18, baseline_type="ensemble", + dist_family="normal", ) trainer.fit(model, dm) trainer.validate(model, dm) @@ -139,12 +138,12 @@ def test_two_estimators_two_outputs(self): trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) model = DummyRegressionBaseline( - probabilistic=False, in_features=dm.in_features, output_dim=2, loss=nn.MSELoss(), optim_recipe=optim_cifar10_resnet18, baseline_type="ensemble", + dist_family=None, ) trainer.fit(model, dm) trainer.validate(model, dm) @@ -154,8 +153,17 @@ def test_two_estimators_two_outputs(self): def test_regression_failures(self): with pytest.raises(ValueError, match="output_dim must be positive"): RegressionRoutine( - probabilistic=True, + dist_family="normal", output_dim=0, model=nn.Identity(), loss=nn.MSELoss(), ) + + with pytest.raises(TypeError): + routine = RegressionRoutine( + dist_family="normal", + output_dim=1, + model=nn.Identity(), + loss=nn.MSELoss(), + ) + routine(torch.randn(1, 1)) diff --git a/tests/test_utils.py b/tests/test_utils.py index 60d9e9e9..2d2ddc44 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -6,7 +6,6 @@ HfHubHTTPError, RepositoryNotFoundError, ) -from torch.distributions import Laplace, Normal from torch_uncertainty.utils import ( csv_writer, @@ -78,12 +77,20 @@ def test_nig(self): ) _ = dist.mean, dist.mean_loc, dist.mean_variance, dist.variance_loc - def test_errors(self): - with pytest.raises(ValueError): - distributions.cat_dist( - [ - Normal(torch.tensor([0.0]), torch.tensor([1.0])), - Laplace(torch.tensor([0.0]), torch.tensor([1.0])), - ], 
- dim=0, - ) + def test_get_dist_class(self): + dist = distributions.get_dist_class("normal") + assert dist == torch.distributions.Normal + dist = distributions.get_dist_class("laplace") + assert dist == torch.distributions.Laplace + dist = distributions.get_dist_class("nig") + assert dist == distributions.NormalInverseGamma + dist = distributions.get_dist_class("cauchy") + assert dist == torch.distributions.Cauchy + dist = distributions.get_dist_class("student") + assert dist == torch.distributions.StudentT + + def test_get_dist_estimate(self): + dist = torch.distributions.Normal(0.0, 1.0) + mean = distributions.get_dist_estimate(dist, "mean") + mode = distributions.get_dist_estimate(dist, "mode") + assert mean == mode diff --git a/torch_uncertainty/baselines/depth/bts.py b/torch_uncertainty/baselines/depth/bts.py index 2f05e18b..31b54051 100644 --- a/torch_uncertainty/baselines/depth/bts.py +++ b/torch_uncertainty/baselines/depth/bts.py @@ -22,11 +22,12 @@ def __init__( version: Literal["std"], arch: int, max_depth: float, + dist_family: str | None = None, num_estimators: int = 1, pretrained_backbone: bool = True, ) -> None: params = { - "dist_layer": nn.Identity, + "dist_family": dist_family, "max_depth": max_depth, "pretrained_backbone": pretrained_backbone, } @@ -39,10 +40,10 @@ def __init__( model = self.versions[version][self.archs.index(arch)](**params) super().__init__( output_dim=1, - probabilistic=False, model=model, loss=loss, num_estimators=num_estimators, format_batch_fn=format_batch_fn, + dist_family=dist_family, ) self.save_hyperparameters(ignore=["loss"]) diff --git a/torch_uncertainty/baselines/regression/mlp.py b/torch_uncertainty/baselines/regression/mlp.py index 34cfdc21..c677466a 100644 --- a/torch_uncertainty/baselines/regression/mlp.py +++ b/torch_uncertainty/baselines/regression/mlp.py @@ -2,11 +2,6 @@ from torch import nn -from torch_uncertainty.layers.distributions import ( - LaplaceLayer, - NormalInverseGammaLayer, - NormalLayer, -) from torch_uncertainty.models.mlp import mlp, packed_mlp from torch_uncertainty.routines.regression import ( RegressionRoutine, @@ -30,37 +25,19 @@ def __init__( dropout_rate: float = 0.0, alpha: float | None = None, gamma: int = 1, - distribution: Literal["normal", "laplace", "nig"] | None = None, + dist_family: str | None = None, + dist_args: dict | None = None, ) -> None: r"""MLP baseline for regression providing support for various versions.""" - probabilistic = True params = { "dropout_rate": dropout_rate, "in_features": in_features, "num_outputs": output_dim, "hidden_dims": hidden_dims, + "dist_family": dist_family, + "dist_args": dist_args, } - if distribution == "normal": - final_layer = NormalLayer - final_layer_args = {"dim": output_dim} - params["num_outputs"] *= 2 - elif distribution == "laplace": - final_layer = LaplaceLayer - final_layer_args = {"dim": output_dim} - params["num_outputs"] *= 2 - elif distribution == "nig": - final_layer = NormalInverseGammaLayer - final_layer_args = {"dim": output_dim} - params["num_outputs"] *= 4 - else: # distribution is None: - probabilistic = False - final_layer = nn.Identity - final_layer_args = {} - - params["final_layer"] = final_layer - params["final_layer_args"] = final_layer_args - format_batch_fn = nn.Identity() if version not in self.versions: @@ -76,12 +53,11 @@ def __init__( model = self.versions[version](**params) - # version in self.versions: super().__init__( - probabilistic=probabilistic, output_dim=output_dim, model=model, loss=loss, + dist_family=dist_family, 
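For context, a hedged sketch of how the new dist_family plumbing fits together at the routine level; the small Sequential backbone is an illustrative assumption, while the layer, loss, and routine names are the ones used in this changeset:

from torch import nn, optim

from torch_uncertainty.layers.distributions import NormalLinear
from torch_uncertainty.losses import DistributionNLLLoss
from torch_uncertainty.routines import RegressionRoutine

# Illustrative backbone: a small MLP whose head outputs {"loc", "scale"} for a Normal.
model = nn.Sequential(
    nn.Linear(8, 50),
    nn.ReLU(),
    NormalLinear(base_layer=nn.Linear, event_dim=1, in_features=50),
)
routine = RegressionRoutine(
    output_dim=1,
    model=model,
    loss=DistributionNLLLoss(),
    dist_family="normal",
    optim_recipe=optim.Adam(model.parameters(), lr=5e-4),
)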
is_ensemble=version in ENSEMBLE_METHODS, format_batch_fn=format_batch_fn, ) diff --git a/torch_uncertainty/datasets/classification/__init__.py b/torch_uncertainty/datasets/classification/__init__.py index a0b496a5..7a3bb716 100644 --- a/torch_uncertainty/datasets/classification/__init__.py +++ b/torch_uncertainty/datasets/classification/__init__.py @@ -1,5 +1,6 @@ # ruff: noqa: F401 from .cifar import CIFAR10C, CIFAR10H, CIFAR10N, CIFAR100C, CIFAR100N +from .cub import CUB from .imagenet import ( ImageNetA, ImageNetC, diff --git a/torch_uncertainty/datasets/classification/cub.py b/torch_uncertainty/datasets/classification/cub.py new file mode 100644 index 00000000..1dfa1f0b --- /dev/null +++ b/torch_uncertainty/datasets/classification/cub.py @@ -0,0 +1,82 @@ +import logging +from collections.abc import Callable +from pathlib import Path + +import torch +from torch import Tensor +from torchvision.datasets import ImageFolder +from torchvision.datasets.utils import check_integrity, download_and_extract_archive + + +class CUB(ImageFolder): + base_folder = "CUB_200_2011/images" + url = "https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz" + filename = "CUB_200_2011.tgz" + tgz_md5 = "97eceeb196236b17998738112f37df78" + + def __init__( + self, + root: str | Path, + train: bool = True, + transform: Callable | None = None, + target_transform: Callable | None = None, + download: bool = False, + ): + """The Caltech-UCSD Birds-200-2011 dataset. + + Args: + root (str): Root directory of the dataset. + train (bool, optional): If True, creates dataset from training set, otherwise creates + from test set. Defaults to True. + transform (callable, optional): A function/transform that takes in a PIL image and + returns a transformed version, e.g., transforms.RandomCrop. Defaults to None. + target_transform (callable, optional): A function/transform that takes in the target + and transforms it. Defaults to None. + download (bool, optional): If True, downloads the dataset from the internet and puts it + in the root directory. If the dataset is already downloaded, it is not downloaded again. + Defaults to False. + + Reference: + Wah, C. and Branson, S. and Welinder, P. and Perona, P. and Belongie, S. The Caltech-UCSD + Birds-200-2011 Dataset. + """ + self.folder_root = Path(root) + self.train = train + if download: + self._download() + + if not self._check_integrity(): + raise RuntimeError( + "Dataset not found or corrupted. You can use download=True to " "download it."
+ ) + + super().__init__(Path(root) / "CUB_200_2011" / "images", transform, target_transform) + + training_idx = self._load_train_idx() + self.samples = [sample for i, sample in enumerate(self.samples) if training_idx[i] == train] + self._labels = [label for i, label in enumerate(self.targets) if training_idx[i] == train] + with Path(self.folder_root / "CUB_200_2011" / "classes.txt").open("r") as f: + self.class_names = [ + line.split(" ")[1].split(".")[1].replace("\n", "").replace("_", " ") for line in f + ] + + def _load_train_idx(self) -> Tensor: + is_training_img = [] + with (self.folder_root / "CUB_200_2011" / "train_test_split.txt").open("r") as f: + is_training_img = [int(line.split(" ")[1]) for line in f] + return torch.as_tensor(is_training_img) + + def _check_integrity(self) -> bool: + fpath = self.folder_root / self.filename + return check_integrity( + fpath, + self.tgz_md5, + ) + + def _download(self): + if self._check_integrity(): + logging.info("Files already downloaded and verified") + return + + download_and_extract_archive( + url=self.url, download_root=self.folder_root, filename=self.filename, md5=self.tgz_md5 + ) diff --git a/torch_uncertainty/datasets/classification/not_mnist.py b/torch_uncertainty/datasets/classification/not_mnist.py index 1590d6be..e0b28ae0 100644 --- a/torch_uncertainty/datasets/classification/not_mnist.py +++ b/torch_uncertainty/datasets/classification/not_mnist.py @@ -88,11 +88,3 @@ def download(self) -> None: md5=self.tgz_md5, ) logging.info("Downloaded %s to %s.", self.filename, self.root) - - def __getitem__(self, index: int) -> tuple[Any, Any]: - """Get the samples and targets of the dataset. - - Args: - index (int): The index of the sample to get. - """ - return super().__getitem__(index) diff --git a/torch_uncertainty/datasets/muad.py b/torch_uncertainty/datasets/muad.py index c21f9f67..73ea2ba7 100644 --- a/torch_uncertainty/datasets/muad.py +++ b/torch_uncertainty/datasets/muad.py @@ -1,11 +1,13 @@ -import json import logging import os import shutil from collections.abc import Callable from importlib import util from pathlib import Path -from typing import Literal +from typing import Literal, NamedTuple + +from huggingface_hub import hf_hub_download +from PIL import Image if util.find_spec("cv2"): import cv2 @@ -14,17 +16,17 @@ else: # coverage: ignore cv2_installed = False import numpy as np -import torch -from einops import rearrange -from PIL import Image from torchvision import tv_tensors from torchvision.datasets import VisionDataset from torchvision.datasets.utils import ( - check_integrity, download_and_extract_archive, - download_url, ) -from torchvision.transforms.v2 import functional as F + + +class MUADClass(NamedTuple): + name: str + id: int + color: tuple[int, int, int] class MUAD(VisionDataset): @@ -38,17 +40,55 @@ class MUAD(VisionDataset): "val": "957af9c1c36f0a85c33279e06b6cf8d8", "val_depth": "0282030d281aeffee3335f713ba12373", } + + small_muad_url = "ENSTA-U2IS/miniMUAD" + _num_samples = { - "train": 3420, - "val": 492, - "test": ..., + "full": { + "train": 3420, + "val": 492, + "test": ..., + }, + "small": { + "train": 400, + "val": 54, + "test": 112, + "ood": 20, + }, } + + classes = [ + MUADClass("road", 0, (128, 64, 128)), + MUADClass("sidewalk", 1, (244, 35, 232)), + MUADClass("building", 2, (70, 70, 70)), + MUADClass("wall", 3, (102, 102, 156)), + MUADClass("fence", 4, (190, 153, 153)), + MUADClass("pole", 5, (153, 153, 153)), + MUADClass("traffic_light", 6, (250, 170, 30)), + MUADClass("traffic_sign", 7, (220, 220, 
0)), + MUADClass("vegetation", 8, (107, 142, 35)), + MUADClass("terrain", 9, (152, 251, 152)), + MUADClass("sky", 10, (70, 130, 180)), + MUADClass("person", 11, (220, 20, 60)), + MUADClass("rider", 12, (255, 0, 0)), + MUADClass("car", 13, (0, 0, 142)), + MUADClass("truck", 14, (0, 0, 70)), + MUADClass("bus", 15, (0, 60, 100)), + MUADClass("train", 16, (0, 80, 100)), + MUADClass("motorcycle", 17, (0, 0, 230)), + MUADClass("bicycle", 18, (119, 11, 32)), + MUADClass("bear deer cow", 19, (255, 228, 196)), + MUADClass("garbage_bag stand_food trash_can", 20, (128, 128, 0)), + MUADClass("unlabeled", 21, (0, 0, 0)), # id 255 or 21 + ] + targets: list[Path] = [] def __init__( self, root: str | Path, - split: Literal["train", "val"], + split: Literal["train", "val", "test", "ood"], + version: Literal["small", "full"] = "full", min_depth: float | None = None, max_depth: float | None = None, target_type: Literal["semantic", "depth"] = "semantic", @@ -61,6 +101,8 @@ def __init__( root (str): Root directory of dataset where directory 'leftImg8bit' and 'leftLabel' or 'leftDepth' are located. split (str, optional): The image split to use, 'train' or 'val'. + version (str, optional): The version of the dataset to use, 'small' + or 'full'. Defaults to 'full'. min_depth (float, optional): The maximum depth value to use if target_type is 'depth'. Defaults to None. max_depth (float, optional): The maximum depth value to use if @@ -86,20 +128,25 @@ def __init__( "torch_uncertainty with the image option:" """pip install -U "torch_uncertainty[image]".""" ) + + if version == "small" and target_type == "depth": + raise ValueError("Depth target is not available for the small version of MUAD.") + logging.info( "MUAD is restricted to non-commercial use. By using MUAD, you " "agree to the terms and conditions." ) - super().__init__( - root=Path(root) / "MUAD", - transforms=transforms, - ) + + dataset_root = Path(root) / "MUAD" if version == "full" else Path(root) / "MUAD_small" + + super().__init__(dataset_root, transforms=transforms) self.min_depth = min_depth self.max_depth = max_depth - if split not in ["train", "val"]: + if split not in ["train", "val", "test", "ood"]: raise ValueError(f"split must be one of ['train', 'val']. Got {split}.") self.split = split + self.version = version self.target_type = target_type if not self.check_split_integrity("leftImg8bit"): @@ -133,49 +180,8 @@ def __init__( f"MUAD {split} split not found or incomplete. Set download=True to download it." ) - # Load classes metadata - cls_path = self.root / "classes.json" - if (not check_integrity(cls_path, self.classes_md5)) and download: - download_url( - self.classes_url, - self.root, - "classes.json", - self.classes_md5, - ) - - with (self.root / "classes.json").open() as file: - self.classes = json.load(file) - - train_id_to_color = [c["object_id"] for c in self.classes if c["train_id"] not in [-1, 255]] - train_id_to_color.append([0, 0, 0]) - self.train_id_to_color = np.array(train_id_to_color) - self._make_dataset(self.root / split) - def encode_target(self, target: Image.Image) -> Image.Image: - """Encode target image to tensor. - - Args: - target (Image.Image): Target PIL image. - - Returns: - torch.Tensor: Encoded target. 
- """ - target = F.pil_to_tensor(target) - target = rearrange(target, "c h w -> h w c") - out = torch.zeros_like(target[..., :1]) - # convert target color to index - for muad_class in self.classes: - out[(target == torch.tensor(muad_class["id"], dtype=target.dtype)).all(dim=-1)] = ( - muad_class["train_id"] - ) - - return F.to_pil_image(rearrange(out, "h w c -> c h w")) - - def decode_target(self, target: Image.Image) -> np.ndarray: - target[target == 255] = 19 - return self.train_id_to_color[target] - def __getitem__(self, index: int) -> tuple[tv_tensors.Image, tv_tensors.Mask]: """Get the sample at the given index. @@ -188,7 +194,7 @@ def __getitem__(self, index: int) -> tuple[tv_tensors.Image, tv_tensors.Mask]: """ image = tv_tensors.Image(Image.open(self.samples[index]).convert("RGB")) if self.target_type == "semantic": - target = tv_tensors.Mask(self.encode_target(Image.open(self.targets[index]))) + target = tv_tensors.Mask(Image.open(self.targets[index])) else: os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" target = Image.fromarray( @@ -211,13 +217,12 @@ def __getitem__(self, index: int) -> tuple[tv_tensors.Image, tv_tensors.Mask]: def check_split_integrity(self, folder: str) -> bool: split_path = self.root / self.split return ( - split_path.is_dir() - and len(list((split_path / folder).glob("**/*"))) == self._num_samples[self.split] + split_path.is_dir() and len(list((split_path / folder).glob("**/*"))) == self.__len__() ) def __len__(self) -> int: """The number of samples in the dataset.""" - return self._num_samples[self.split] + return self._num_samples[self.version][self.split] def _make_dataset(self, path: Path) -> None: """Create a list of samples and targets. @@ -241,9 +246,16 @@ def _make_dataset(self, path: Path) -> None: def _download(self, split: str) -> None: """Download and extract the chosen split of the dataset.""" - split_url = self.base_url + split + ".zip" - download_and_extract_archive(split_url, self.root, md5=self.zip_md5[split]) + if self.version == "small": + filename = f"{split}.zip" + downloaded_file = hf_hub_download( + repo_id=self.small_muad_url, filename=filename, repo_type="dataset" + ) + shutil.unpack_archive(downloaded_file, extract_dir=self.root) + else: + split_url = self.base_url + split + ".zip" + download_and_extract_archive(split_url, self.root, md5=self.zip_md5[split]) @property def color_palette(self) -> np.ndarray: - return self.train_id_to_color.tolist() + return [c.color for c in self.classes] diff --git a/torch_uncertainty/layers/__init__.py b/torch_uncertainty/layers/__init__.py index 210e0bea..689943ff 100644 --- a/torch_uncertainty/layers/__init__.py +++ b/torch_uncertainty/layers/__init__.py @@ -4,4 +4,13 @@ from .channel_layer_norm import ChannelLayerNorm from .masksembles import MaskedConv2d, MaskedLinear from .modules import Identity -from .packed import PackedConv1d, PackedConv2d, PackedConv3d, PackedLinear +from .packed import ( + PackedConv1d, + PackedConv2d, + PackedConv3d, + PackedLayerNorm, + PackedLinear, + PackedMultiheadAttention, + PackedTransformerDecoderLayer, + PackedTransformerEncoderLayer, +) diff --git a/torch_uncertainty/layers/distributions.py b/torch_uncertainty/layers/distributions.py index 4c7829c7..0236f49d 100644 --- a/torch_uncertainty/layers/distributions.py +++ b/torch_uncertainty/layers/distributions.py @@ -1,110 +1,552 @@ -from abc import ABC, abstractmethod +import inspect +import torch import torch.nn.functional as F from torch import Tensor, nn -from torch.distributions import Distribution, Laplace, 
Normal -from torch_uncertainty.utils.distributions import NormalInverseGamma +def get_dist_linear_layer(dist_family: str) -> type[nn.Module]: + if dist_family == "normal": + return NormalLinear + if dist_family == "laplace": + return LaplaceLinear + if dist_family == "cauchy": + return CauchyLinear + if dist_family == "student": + return StudentTLinear + if dist_family == "nig": + return NormalInverseGammaLinear + raise NotImplementedError( + f"{dist_family} distribution is not supported. Raise an issue if needed." + ) + + +def get_dist_conv_layer(dist_family: str) -> type[nn.Module]: + if dist_family == "normal": + return NormalConvNd + if dist_family == "laplace": + return LaplaceConvNd + if dist_family == "cauchy": + return CauchyConvNd + if dist_family == "student": + return StudentTConvNd + if dist_family == "nig": + return NormalInverseGammaConvNd + raise NotImplementedError( + f"{dist_family} distribution is not supported. Raise an issue if needed." + ) + + +class _ExpandOutputLinear(nn.Module): + """Abstract class for expanding the output of any nn.Module using an `out_features` argument. + + Args: + base_layer (type[nn.Module]): The base layer class. + event_dim (int): The number of event dimensions. + num_params (int): The number of parameters to output. For instance, the normal distribution + has 2 parameters (loc and scale). + **layer_args: Additional arguments for the base layer. + """ + + def __init__(self, base_layer: type[nn.Module], event_dim: int, num_params: int, **layer_args): + if "out_features" not in inspect.getfullargspec(base_layer.__init__).args: + raise ValueError(f"{base_layer.__name__} does not have an `out_features` argument.") -class TUDist(ABC, nn.Module): - def __init__(self, dim: int) -> None: super().__init__() - if dim < 1: - raise ValueError(f"dim must be positive, got {dim}.") - self.dim = dim + self.base_layer = base_layer(out_features=num_params * event_dim, **layer_args) + self.event_dim = event_dim + + def forward(self, x: Tensor) -> Tensor: + return self.base_layer(x) + + +class _ExpandOutputConvNd(nn.Module): + """Abstract class for expanding the output of any nn.Module using an `out_channels` argument. - @abstractmethod - def forward(self, x: Tensor) -> Distribution: - pass + Args: + base_layer (type[nn.Module]): The base layer class. + event_dim (int): The number of event dimensions. + num_params (int): The number of parameters to output. For instance, the normal distribution + has 2 parameters (loc and scale). + **layer_args: Additional arguments for the base layer. + """ + + def __init__(self, base_layer: type[nn.Module], event_dim: int, num_params: int, **layer_args): + if "out_channels" not in inspect.getfullargspec(base_layer.__init__).args: + raise ValueError(f"{base_layer.__name__} does not have an `out_channels` argument.") + super().__init__() + self.base_layer = base_layer(out_channels=num_params * event_dim, **layer_args) + self.event_dim = event_dim + + def forward(self, x: Tensor) -> Tensor: + return self.base_layer(x) -class NormalLayer(TUDist): - """Normal distribution layer. - Converts model outputs to Independent Normal distributions. +class _LocScaleLinear(_ExpandOutputLinear): + """Base Linear layer for any distribution with loc and scale parameters. Args: - dim (int): The number of independent dimensions for each prediction. - eps (float): The minimal value of the :attr:`scale` parameter. + base_layer (type[nn.Module]): The base layer class. + event_dim (int): The number of event dimensions. 
+ min_scale (float): The minimal value of the scale parameter. + **layer_args: Additional arguments for the base layer. """ - def __init__(self, dim: int, eps: float = 1e-6) -> None: - super().__init__(dim) - if eps <= 0: - raise ValueError(f"eps must be positive, got {eps}.") - self.eps = eps + def __init__( + self, + base_layer: type[nn.Module], + event_dim: int, + min_scale: float = 1e-6, + **layer_args, + ) -> None: + super().__init__( + base_layer=base_layer, + event_dim=event_dim, + num_params=2, + **layer_args, + ) + self.min_scale = min_scale - def forward(self, x: Tensor) -> Normal: - r"""Forward pass of the Normal distribution layer. + def forward(self, x: Tensor) -> dict[str, Tensor]: + x = super().forward(x) + loc = x[..., : self.event_dim] + scale = torch.clamp( + F.softplus(x[..., self.event_dim : 2 * self.event_dim]), min=self.min_scale + ) + return {"loc": loc, "scale": scale} - Args: - x (Tensor): A tensor of shape (:attr:`dim` :math:`\times`2). - Returns: - Normal: The output normal distribution. - """ - loc = x[:, : self.dim] - scale = F.softplus(x[:, self.dim :]) + self.eps - return Normal(loc, scale) +class _LocScaleConvNd(_ExpandOutputConvNd): + """Base Convolutional layer for any distribution with loc and scale parameters. + Args: + base_layer (type[nn.Module]): The base layer class. + event_dim (int): The number of event dimensions. + min_scale (float): The minimal value of the scale parameter. + **layer_args: Additional arguments for the base layer. + """ + + def __init__( + self, + base_layer: type[nn.Module], + event_dim: int, + min_scale: float = 1e-6, + **layer_args, + ) -> None: + super().__init__( + base_layer=base_layer, + event_dim=event_dim, + num_params=2, + **layer_args, + ) + self.min_scale = min_scale + + def forward(self, x: Tensor) -> dict[str, Tensor]: + x = super().forward(x) + loc = x[:, : self.event_dim] + scale = torch.clamp( + F.softplus(x[:, self.event_dim : 2 * self.event_dim]), min=self.min_scale + ) + return {"loc": loc, "scale": scale} -class LaplaceLayer(TUDist): - """Laplace distribution layer. - Converts model outputs to Independent Laplace distributions. +class NormalLinear(_LocScaleLinear): + r"""Normal Distribution Linear Density Layer. Args: - dim (int): The number of independent dimensions for each prediction. - eps (float): The minimal value of the :attr:`scale` parameter. + base_layer (type[nn.Module]): The base layer class. + event_dim (int): The number of event dimensions. + min_scale (float): The minimal value of the scale parameter. + **layer_args: Additional arguments for the base layer. + + Shape: + - Input: :math:`(\ast, H_{in})` where :math:`\ast` means any number of dimensions including + none and :math:`H_{in} = \text{in_features}`. + - Output: A dict with the following keys + + - ``"loc"``: The mean of the Normal distribution of shape :math:`(\ast, H_{out})` where + all but the last dimension are the same as the input and + :math:`H_{out} = \text{out_features}`. + - ``"scale"``: The standard deviation of the Normal distribution of shape + :math:`(\ast, H_{out})`. """ - def __init__(self, dim: int, eps: float = 1e-6) -> None: - super().__init__(dim) - if eps <= 0: - raise ValueError(f"eps must be positive, got {eps}.") - self.eps = eps - def forward(self, x: Tensor) -> Laplace: - r"""Forward pass of the Laplace distribution layer. +class NormalConvNd(_LocScaleConvNd): + r"""Normal Distribution Convolutional Density Layer. - Args: - x (Tensor): A tensor of shape (..., :attr:`dim` :math:`\times`2). 
+class NormalConvNd(_LocScaleConvNd): + r"""Normal Distribution Convolutional Density Layer. +
+ Args: + in_channels (int): The number of input channels. + out_channels (int): The number of event channels. + kernel_size (int | tuple[int]): The size of the convolutional kernel. + stride (int | tuple[int]): The stride of the convolution. + padding (int | tuple[int]): The padding of the convolution. + dilation (int | tuple[int]): The dilation of the convolution. + groups (int): The number of groups in the convolution. + min_scale (float): The minimal value of the scale parameter. + device (torch.device): The device where the layer is stored. + dtype (torch.dtype): The datatype of the layer. - Returns: - Laplace: The output Laplace distribution. - """ - loc = x[..., : self.dim] - scale = F.softplus(x[..., self.dim :]) + self.eps - return Laplace(loc, scale) + Shape: + - Input: :math:`(N, C_{in}, \ast)` where :math:`\ast` means any number of dimensions and + :math:`C_{in} = \text{in_channels}` and :math:`N` is the batch size. + - Output: A dict with the following keys + - ``"loc"``: The mean of the Normal distribution of shape :math:`(N, C_{out}, \ast)` where + :math:`C_{out} = \text{out_channels}`. + - ``"scale"``: The standard deviation of the Normal distribution of shape + :math:`(\ast, C_{out}, \ast)`. + """ -class NormalInverseGammaLayer(TUDist): - """Normal-Inverse-Gamma distribution layer. - Converts model outputs to Independent Normal-Inverse-Gamma distributions. +class LaplaceLinear(_LocScaleLinear): + r"""Laplace Distribution Linear Density Layer. Args: - dim (int): The number of independent dimensions for each prediction. - eps (float): The minimal values of the :attr:`lmbda`, :attr:`alpha`-1 - and :attr:`beta` parameters. + base_layer (type[nn.Module]): The base layer class. + event_dim (int): The number of event dimensions. + min_scale (float): The minimal value of the scale parameter. + **layer_args: Additional arguments for the base layer. + + Shape: + - Input: :math:`(\ast, H_{in})` where :math:`\ast` means any number of dimensions including + none and :math:`H_{in} = \text{in_features}`. + - Output: A dict with the following keys + + - ``"loc"``: The mean of the Laplace distribution of shape :math:`(\ast, H_{out})` where + all but the last dimension are the same as the input and + :math:`H_{out} = \text{out_features}`. + - ``"scale"``: The standard deviation of the Laplace distribution of shape + :math:`(\ast, H_{out})`. + """ + + +class LaplaceConvNd(_LocScaleConvNd): + r"""Laplace Distribution Convolutional Density Layer. + + Args: + in_channels (int): The number of input channels. + out_channels (int): The number of event channels. + kernel_size (int | tuple[int]): The size of the convolutional kernel. + stride (int | tuple[int]): The stride of the convolution. + padding (int | tuple[int]): The padding of the convolution. + dilation (int | tuple[int]): The dilation of the convolution. + groups (int): The number of groups in the convolution. + min_scale (float): The minimal value of the scale parameter. + device (torch.device): The device where the layer is stored. + dtype (torch.dtype): The datatype of the layer. + + Shape: + - Input: :math:`(N, C_{in}, \ast)` where :math:`\ast` means any number of dimensions and + :math:`C_{in} = \text{in_channels}` and :math:`N` is the batch size. + - Output: A dict with the following keys + + - ``"loc"``: The mean of the Laplace distribution of shape :math:`(N, C_{out}, \ast)` where + :math:`C_{out} = \text{out_channels}`. + - ``"scale"``: The standard deviation of the Laplace distribution of shape + :math:`(\ast, C_{out}, \ast)`. 
""" - def __init__(self, dim: int, eps: float = 1e-6) -> None: - super().__init__(dim) - self.eps = eps - def forward(self, x: Tensor) -> NormalInverseGamma: - r"""Forward pass of the NormalInverseGamma distribution layer. +class CauchyLinear(_LocScaleLinear): + r"""Cauchy Distribution Linear Density Layer. + + Args: + base_layer (type[nn.Module]): The base layer class. + event_dim (int): The number of event dimensions. + min_scale (float): The minimal value of the scale parameter. + **layer_args: Additional arguments for the base layer. + + Shape: + - Input: :math:`(\ast, H_{in})` where :math:`\ast` means any number of dimensions including + none and :math:`H_{in} = \text{in_features}`. + - Output: A dict with the following keys + + - ``"loc"``: The mean of the Cauchy distribution of shape :math:`(\ast, H_{out})` where + all but the last dimension are the same as the input and + :math:`H_{out} = \text{out_features}`. + - ``"scale"``: The standard deviation of the Cauchy distribution of shape + :math:`(\ast, H_{out})`. + """ + + +class CauchyConvNd(_LocScaleConvNd): + r"""Cauchy Distribution Convolutional Density Layer. + + Args: + in_channels (int): The number of input channels. + out_channels (int): The number of event channels. + kernel_size (int | tuple[int]): The size of the convolutional kernel. + stride (int | tuple[int]): The stride of the convolution. + padding (int | tuple[int]): The padding of the convolution. + dilation (int | tuple[int]): The dilation of the convolution. + groups (int): The number of groups in the convolution. + min_scale (float): The minimal value of the scale parameter. + device (torch.device): The device where the layer is stored. + dtype (torch.dtype): The datatype of the layer. + + Shape: + - Input: :math:`(N, C_{in}, \ast)` where :math:`\ast` means any number of dimensions and + :math:`C_{in} = \text{in_channels}` and :math:`N` is the batch size. + - Output: A dict with the following keys + + - ``"loc"``: The mean of the Cauchy distribution of shape :math:`(N, C_{out}, \ast)` where + :math:`C_{out} = \text{out_channels}`. + - ``"scale"``: The standard deviation of the Cauchy distribution of shape + :math:`(\ast, C_{out}, \ast)`. + """ + + +class StudentTLinear(_ExpandOutputLinear): + r"""Student's T-Distribution Linear Density Layer. + + Args: + base_layer (type[nn.Module]): The base layer class. + event_dim (int): The number of event dimensions. + min_scale (float): The minimal value of the scale parameter. + min_df (float): The minimal value of the degrees of freedom parameter. + fixed_df (float): If not None, the degrees of freedom parameter is fixed to this value. + Otherwise, it is learned. + **layer_args: Additional arguments for the base layer. + + Shape: + - Input: :math:`(\ast, H_{in})` where :math:`\ast` means any number of dimensions including + none and :math:`H_{in} = \text{in_features}`. + - Output: A dict with the following keys + + - ``"loc"``: The mean of the Student's t-distribution of shape :math:`(\ast, H_{out})` where + all but the last dimension are the same as the input and + :math:`H_{out} = \text{out_features}`. + - ``"scale"``: The standard deviation of the Student's t-distribution of shape + :math:`(\ast, H_{out})`. + - ``"df"``: The degrees of freedom of the Student's t distribution of shape + :math:`(\ast, H_{out})` or Number. 
+ """ + + def __init__( + self, + base_layer: type[nn.Module], + event_dim: int, + min_scale: float = 1e-6, + min_df: float = 2.0, + fixed_df: float | None = None, + **layer_args, + ) -> None: + super().__init__( + base_layer=base_layer, + event_dim=event_dim, + num_params=3 if fixed_df is None else 2, + **layer_args, + ) + + self.min_scale = min_scale + self.min_df = min_df + self.fixed_df = fixed_df + + def forward(self, x: Tensor) -> dict[str, Tensor]: + x = super().forward(x) + loc = x[..., : self.event_dim] + scale = torch.clamp( + F.softplus(x[..., self.event_dim : 2 * self.event_dim]), min=self.min_scale + ) + df = ( + torch.clamp(F.softplus(x[..., 2 * self.event_dim :]), min=self.min_df) + if self.fixed_df is None + else torch.full_like(loc, self.fixed_df) + ) + return {"loc": loc, "scale": scale, "df": df} + + +class StudentTConvNd(_ExpandOutputConvNd): + r"""Student's T-Distribution Convolutional Density Layer. + + Args: + base_layer (type[nn.Module]): The base layer class. + event_dim (int): The number of event dimensions. + min_scale (float): The minimal value of the scale parameter. + min_df (float): The minimal value of the degrees of freedom parameter. + fixed_df (float): If not None, the degrees of freedom parameter is fixed to this value. + Otherwise, it is learned. + **layer_args: Additional arguments for the base layer. + + Shape: + - Input: :math:`(N, C_{in}, \ast)` where :math:`\ast` means any number of dimensions and + :math:`C_{in} = \text{in_channels}` and :math:`N` is the batch size. + - Output: A dict with the following keys + + - ``"loc"``: The mean of the Student's t-distribution of shape :math:`(N, C_{out}, \ast)` where + :math:`C_{out} = \text{out_channels}`. + - ``"scale"``: The standard deviation of the Student's t-distribution of shape + :math:`(\ast, C_{out}, \ast)`. + - ``"df"``: The degrees of freedom of the Student's t distribution of shape + :math:`(\ast, C_{out}, \ast)`. + """ + + def __init__( + self, + base_layer: type[nn.Module], + event_dim: int, + min_scale: float = 1e-6, + min_df: float = 2.0, + fixed_df: float | None = None, + **layer_args, + ) -> None: + super().__init__( + base_layer=base_layer, + event_dim=event_dim, + num_params=3 if fixed_df is None else 2, + **layer_args, + ) + + self.min_scale = min_scale + self.min_df = min_df + self.fixed_df = fixed_df + + def forward(self, x: Tensor) -> dict[str, Tensor]: + x = super().forward(x) + loc = x[:, : self.event_dim] + scale = torch.clamp( + F.softplus(x[:, self.event_dim : 2 * self.event_dim]), min=self.min_scale + ) + df = ( + torch.clamp(F.softplus(x[:, 2 * self.event_dim :]), min=self.min_df) + if self.fixed_df is None + else torch.full_like(loc, self.fixed_df) + ) + return {"loc": loc, "scale": scale, "df": df} + + +class NormalInverseGammaLinear(_ExpandOutputLinear): + r"""Normal-Inverse-Gamma Distribution Linear Density Layer. + + Args: + base_layer (type[nn.Module]): The base layer class. + event_dim (int): The number of event dimensions. + min_lmbda (float): The minimal value of the :math:`\lambda` parameter. + min_alpha (float): The minimal value of the :math:`\alpha` parameter. + min_beta (float): The minimal value of the :math:`\beta` parameter. + **layer_args: Additional arguments for the base layer. + + Shape: + - Input: :math:`(\ast, H_{in})` where :math:`\ast` means any number of dimensions including + none and :math:`H_{in} = \text{in_features}`. 
+ - Output: A dict with the following keys + + - ``"loc"``: The mean of the Normal-Inverse-Gamma distribution of shape :math:`(\ast, H_{out})` where + all but the last dimension are the same as the input and + :math:`H_{out} = \text{out_features}`. + - ``"lmbda"``: The lambda parameter of the Normal-Inverse-Gamma distribution of shape + :math:`(\ast, H_{out})`. + - ``"alpha"``: The alpha parameter of the Normal-Inverse-Gamma distribution of shape + :math:`(\ast, H_{out})`. + - ``"beta"``: The beta parameter of the Normal-Inverse-Gamma distribution of shape + :math:`(\ast, H_{out})`. + + Source: + - `Normal-Inverse-Gamma Distribution `_ + """ + + def __init__( + self, + base_layer: type[nn.Module], + event_dim: int, + min_lmbda: float = 1e-6, + min_alpha: float = 1e-6, + min_beta: float = 1e-6, + **layer_args, + ) -> None: + super().__init__( + base_layer=base_layer, + event_dim=event_dim, + num_params=4, + **layer_args, + ) + + self.min_lmbda = min_lmbda + self.min_alpha = min_alpha + self.min_beta = min_beta + + def forward(self, x: Tensor) -> dict[str, Tensor]: + x = super().forward(x) + loc = x[..., : self.event_dim] + lmbda = torch.clamp( + F.softplus(x[..., self.event_dim : 2 * self.event_dim]), min=self.min_lmbda + ) + alpha = 1 + torch.clamp( + F.softplus(x[..., 2 * self.event_dim : 3 * self.event_dim]), min=self.min_alpha + ) + beta = torch.clamp(F.softplus(x[..., 3 * self.event_dim :]), min=self.min_beta) + return { + "loc": loc, + "lmbda": lmbda, + "alpha": alpha, + "beta": beta, + } + + +class NormalInverseGammaConvNd(_ExpandOutputConvNd): + r"""Normal-Inverse-Gamma Distribution Convolutional Density Layer. + + Args: + base_layer (type[nn.Module]): The base layer class. + event_dim (int): The number of event dimensions. + min_lmbda (float): The minimal value of the :math:`\lambda` parameter. + min_alpha (float): The minimal value of the :math:`\alpha` parameter. + min_beta (float): The minimal value of the :math:`\beta` parameter. + **layer_args: Additional arguments for the base layer. + + Shape: + - Input: :math:`(N, C_{in}, \ast)` where :math:`\ast` means any number of dimensions and + :math:`C_{in} = \text{in_channels}` and :math:`N` is the batch size. + - Output: A dict with the following keys + + - ``"loc"``: The mean of the Normal-Inverse-Gamma distribution of shape :math:`(N, C_{out}, \ast)` where + :math:`C_{out} = \text{out_channels}`. + - ``"lmbda"``: The lambda parameter of the Normal-Inverse-Gamma distribution of shape + :math:`(N, C_{out}, \ast)`. + - ``"alpha"``: The alpha parameter of the Normal-Inverse-Gamma distribution of shape + :math:`(N, C_{out}, \ast)`. + - ``"beta"``: The beta parameter of the Normal-Inverse-Gamma distribution of shape + :math:`(N, C_{out}, \ast)`. + + Source: + - `Normal-Inverse-Gamma Distribution `_ + """ + + def __init__( + self, + base_layer: type[nn.Module], + event_dim: int, + min_lmbda: float = 1e-6, + min_alpha: float = 1e-6, + min_beta: float = 1e-6, + **layer_args, + ) -> None: + super().__init__( + base_layer=base_layer, + event_dim=event_dim, + num_params=4, + **layer_args, + ) - Args: - x (Tensor): A tensor of shape (:attr:`dim` :math:`\times`4). + self.min_lmbda = min_lmbda + self.min_alpha = min_alpha + self.min_beta = min_beta - Returns: - NormalInverseGamma: The output NormalInverseGamma distribution. 
- """ - loc = x[..., : self.dim] - lmbda = F.softplus(x[..., self.dim : 2 * self.dim]) + self.eps - alpha = 1 + F.softplus(x[..., 2 * self.dim : 3 * self.dim]) + self.eps - beta = F.softplus(x[..., 3 * self.dim :]) + self.eps - return NormalInverseGamma(loc, lmbda, alpha, beta) + def forward(self, x: Tensor) -> dict[str, Tensor]: + x = super().forward(x) + loc = x[:, : self.event_dim] + lmbda = torch.clamp( + F.softplus(x[:, self.event_dim : 2 * self.event_dim]), min=self.min_lmbda + ) + alpha = 1 + torch.clamp( + F.softplus(x[:, 2 * self.event_dim : 3 * self.event_dim]), min=self.min_alpha + ) + beta = torch.clamp(F.softplus(x[:, 3 * self.event_dim :]), min=self.min_beta) + return { + "loc": loc, + "lmbda": lmbda, + "alpha": alpha, + "beta": beta, + } diff --git a/torch_uncertainty/layers/functional/__init__.py b/torch_uncertainty/layers/functional/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/torch_uncertainty/layers/functional/packed.py b/torch_uncertainty/layers/functional/packed.py new file mode 100644 index 00000000..c962531e --- /dev/null +++ b/torch_uncertainty/layers/functional/packed.py @@ -0,0 +1,480 @@ +import torch +import torch.nn.functional as F +from einops import rearrange +from torch import Tensor + + +def packed_linear( + inputs: Tensor, + weight: Tensor, + num_groups: int, + implementation: str, + bias: Tensor | None = None, +) -> Tensor: + r"""Applies a packed linear transformation to the incoming data using the given implementation. + + Args: + inputs (Tensor): :math:`(\star, \text{in\_features})` where :math:`\star` is any number of + additional dimensions including none. + weight (Tensor): :math:(\text{num\_groups}, \frac{\text{out\_features}}{\text{num\_groups}}, \frac{\text{in\_features}}{\text{num\_groups}})`. + num_groups (int): number of groups to split the input. + implementation (str): the implementation of the packed linear operation. Three + implementations are currently supported: + - "full": creates a block diagonal matrix from the weight tensor and applies the linear + transformation using `torch.nn.functional.linear`. + - "sparse": uses a sparse weight tensor directly to apply the linear transformation. + - "einsum": uses `torch.einsum` to apply the packed linear transformation. + rearrange (bool, optional): _description_. Defaults to True. + bias (Tensor | None, optional): _description_. Defaults to None. + + Returns: + Tensor: + """ + if implementation == "full": + block_diag = torch.block_diag(*weight) + return F.linear(inputs, block_diag, bias) + if implementation == "sparse": + out = inputs @ weight.transpose(0, 1) + if bias is not None: + out += bias + return out + if implementation == "einsum": + out = torch.einsum( + "...ki,kij->...kj", + rearrange(inputs, "... (m d) -> ... 
m d", m=num_groups), + weight.transpose(1, 2), + ).flatten(start_dim=-2) + if bias is not None: + out += bias + return out + raise ValueError(f"Unknown implementation: {implementation}") + + +def packed_in_projection( + q: Tensor, + k: Tensor, + v: Tensor, + w_q: Tensor, + w_k: Tensor, + w_v: Tensor, + num_groups: int, + implementation: str = "full", + b_q: Tensor | None = None, + b_k: Tensor | None = None, + b_v: Tensor | None = None, +) -> tuple[Tensor, Tensor, Tensor]: + emb_q, emb_k, emb_v = q.size(-1), k.size(-1), v.size(-1) + assert w_q.shape == ( + num_groups, + emb_q // num_groups, + emb_q // num_groups, + ), f"expecting query weights shape of {(emb_q, emb_q)}, but got {w_q.shape}" + assert w_k.shape == ( + num_groups, + emb_q // num_groups, + emb_k // num_groups, + ), f"expecting key weights shape of {(emb_q, emb_k)}, but got {w_k.shape}" + assert w_v.shape == ( + num_groups, + emb_q // num_groups, + emb_v // num_groups, + ), f"expecting value weights shape of {(emb_q, emb_v)}, but got {w_v.shape}" + assert b_q is None or b_q.shape == ( + emb_q, + ), f"expecting query bias shape of {(emb_q,)}, but got {b_q.shape}" + assert b_k is None or b_k.shape == ( + emb_q, + ), f"expecting key bias shape of {(emb_k,)}, but got {b_k.shape}" + assert b_v is None or b_v.shape == ( + emb_q, + ), f"expecting value bias shape of {(emb_v,)}, but got {b_v.shape}" + + return ( + packed_linear(q, w_q, num_groups, implementation, b_q), + packed_linear(k, w_k, num_groups, implementation, b_k), + packed_linear(v, w_v, num_groups, implementation, b_v), + ) + + +def packed_in_projection_packed( + q: Tensor, + k: Tensor, + v: Tensor, + w: Tensor, + num_groups: int, + implementation: str = "full", + b: Tensor | None = None, +) -> tuple[Tensor, Tensor, Tensor]: + emb = q.size(-1) + if k is v: + if q is k: + # self-attention + proj = packed_linear( + inputs=q, weight=w, num_groups=num_groups, implementation=implementation, bias=b + ) + # reshape to 3, emb and not emb, 3 is deliberate for better memory + # coalescing and keeping same order as chunk() + proj = ( + proj.unflatten(-1, (3, emb)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous() + ) + return proj[0], proj[1], proj[2] + + # encoder-decoder attention + _tmp_dim = w.size(-1) + w_q, w_kv = w.split([_tmp_dim, 2 * _tmp_dim], dim=1) + if b is None: + b_q = b_kv = None + else: + b_q, b_kv = b.split([emb, 2 * emb]) + q_proj = packed_linear( + inputs=q, weight=w_q, num_groups=num_groups, implementation=implementation, bias=b_q + ) + kv_proj = packed_linear( + inputs=k, weight=w_kv, num_groups=num_groups, implementation=implementation, bias=b_kv + ) + # reshape to 2, emb and not emb, 2 is deliberate for better memory + # coalescing and keeping same order as chunk() + kv_proj = ( + kv_proj.unflatten(-1, (2, emb)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous() + ) + return q_proj, kv_proj[0], kv_proj[1] + + w_q, w_k, w_v = w.chunk(3, dim=1) + if b is None: + b_q = b_k = b_v = None + else: + b_q, b_k, b_v = b.chunk(3) + return ( + packed_linear( + inputs=q, weight=w_q, num_groups=num_groups, implementation=implementation, bias=b_q + ), + packed_linear( + inputs=k, weight=w_k, num_groups=num_groups, implementation=implementation, bias=b_k + ), + packed_linear( + inputs=v, weight=w_v, num_groups=num_groups, implementation=implementation, bias=b_v + ), + ) + + +def packed_multi_head_attention_forward( # noqa: D417 + query: Tensor, + key: Tensor, + value: Tensor, + embed_dim_to_check: int, + num_heads: int, + num_groups: int, + in_proj_weight: Tensor | 
None,
+    in_proj_bias: Tensor | None,
+    bias_k: Tensor | None,
+    bias_v: Tensor | None,
+    add_zero_attn: bool,
+    dropout_p: float,
+    out_proj_weight: Tensor,
+    out_proj_bias: Tensor | None,
+    implementation: str = "einsum",
+    training: bool = True,
+    key_padding_mask: Tensor | None = None,
+    need_weights: bool = False,  # TODO: add support
+    attn_mask: Tensor | None = None,
+    use_separate_proj_weight: bool = False,
+    q_proj_weight: Tensor | None = None,
+    k_proj_weight: Tensor | None = None,
+    v_proj_weight: Tensor | None = None,
+    static_k: Tensor | None = None,
+    static_v: Tensor | None = None,
+    average_attn_weights: bool = True,  # TODO: add support # noqa: ARG001
+    is_causal: bool = False,
+) -> tuple[Tensor, Tensor | None]:
+    r"""Parallel Multihead Attention (pMHA) with packed inputs.
+
+    Args:
+        query, key, value: map a query and a set of key-value pairs to an output.
+            See "Attention Is All You Need" for more details.
+        embed_dim_to_check: total dimension of the model.
+        num_heads: parallel attention heads.
+        num_groups: number of groups to split the packed projections into.
+        in_proj_weight, in_proj_bias: input projection weight and bias.
+        bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
+        add_zero_attn: add a new batch of zeros to the key and
+            value sequences at dim=1.
+        dropout_p: probability of an element to be zeroed.
+        out_proj_weight, out_proj_bias: the output projection weight and bias.
+        implementation (str, optional): the implementation of the packed linear operation. Three
+            implementations are currently supported:
+            - ``"full"``: creates a block diagonal matrix from the weight tensor and applies the
+                linear transformation using `torch.nn.functional.linear`.
+            - ``"sparse"``: uses a sparse weight tensor directly to apply the linear
+                transformation.
+            - ``"einsum"``: uses `torch.einsum` to apply the packed linear transformation.
+        training: apply dropout if ``True``.
+        key_padding_mask: if provided, specified padding elements in the key will
+            be ignored by the attention. This is a binary mask. When the value is ``True``,
+            the corresponding value on the attention layer will be filled with ``-inf``.
+        need_weights: output ``attn_output_weights``. Not supported yet: keep the
+            default ``False``, otherwise a ``NotImplementedError`` is raised further down.
+        attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
+            the batches while a 3D mask allows specifying a different mask for the entries of each batch.
+        is_causal: If specified, applies a causal mask as attention mask, and ignores
+            attn_mask for computing scaled dot product attention.
+            Default: ``False``.
+            .. warning::
+                is_causal provides a hint that attn_mask is the causal mask.
+                Providing incorrect hints can result in incorrect execution,
+                including forward and backward compatibility.
+        use_separate_proj_weight: if ``True``, the function accepts separate projection weights
+            for query, key, and value. If ``False``, ``in_proj_weight`` is used, which is
+            a combination of ``q_proj_weight``, ``k_proj_weight``, and ``v_proj_weight``.
+        q_proj_weight, k_proj_weight, v_proj_weight: separate input projection weights, used
+            when ``use_separate_proj_weight`` is ``True``.
+        static_k, static_v: static key and value used for attention operators.
+        average_attn_weights: If ``True``, indicates that the returned ``attn_weights`` should be averaged across heads.
+            Otherwise, ``attn_weights`` are provided separately per head.
Note that this flag only has an effect + when ``need_weights=True.``. Default to ``True``. + + Shape: + Inputs: + - query: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, E)` or :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, E)` or :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(S)` or :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a FloatTensor is provided, it will be directly added to the value. + If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked + positions. If a BoolTensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, + N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. + - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, + N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. + + Outputs: + - attn_output: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: Only returned when ``need_weights=True``. If ``average_attn_weights=True``, returns + attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or + :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and + :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per + head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`. + + References: + Implementation of the packed multi-head attention is based on the PyTorch implementation of the + `torch.nn.MultiheadAttention` module. The implementation is adapted to support packed inputs. + """ + is_batched = F._mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads) + + # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input + # is batched, run the computation and before returning squeeze the + # batch dimension so that the output doesn't carry this temporary batch dimension. 
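Before the attention body below, a quick numerical check (illustrative, not part of the diff) that the ``"full"`` and ``"einsum"`` branches of ``packed_linear`` compute the same grouped affine map, assuming the weight layout ``(num_groups, out_per_group, in_per_group)`` documented earlier:

import torch
import torch.nn.functional as F
from einops import rearrange

num_groups, in_features, out_features = 2, 8, 6
weight = torch.randn(num_groups, out_features // num_groups, in_features // num_groups)
x = torch.randn(5, in_features)

# "full": build one block-diagonal matrix and use a single F.linear call
full = F.linear(x, torch.block_diag(*weight))

# "einsum": split the features into groups, contract per group, flatten back
grouped = rearrange(x, "... (m d) -> ... m d", m=num_groups)
einsum = torch.einsum("...ki,kij->...kj", grouped, weight.transpose(1, 2)).flatten(start_dim=-2)

assert torch.allclose(full, einsum, atol=1e-6)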
+ if not is_batched: + # unsqueeze if the input is unbatched + query = query.unsqueeze(1) + key = key.unsqueeze(1) + value = value.unsqueeze(1) + if key_padding_mask is not None: + key_padding_mask = key_padding_mask.unsqueeze(0) + + # set up shape vars + tgt_len, bsz, embed_dim = query.shape + src_len, _, _ = key.shape + + key_padding_mask = F._canonical_mask( + mask=key_padding_mask, + mask_name="key_padding_mask", + other_type=F._none_or_dtype(attn_mask), + other_name="attn_mask", + target_type=query.dtype, + ) + + if is_causal and attn_mask is None: + raise RuntimeError( + "Need attn_mask if specifying the is_causal hint. " + "You may use the Transformer module method " + "`generate_square_subsequent_mask` to create this mask." + ) + + if is_causal and key_padding_mask is None and not need_weights: + # when we have a kpm or need weights, we need attn_mask + # Otherwise, we use the is_causal hint go as is_causal + # indicator to SDPA. + attn_mask = None + else: + attn_mask = F._canonical_mask( + mask=attn_mask, + mask_name="attn_mask", + other_type=None, + other_name="", + target_type=query.dtype, + check_other=False, + ) + + if key_padding_mask is not None: + # We have the attn_mask, and use that to merge kpm into it. + # Turn off use of is_causal hint, as the merged mask is no + # longer causal. + is_causal = False + + assert ( + embed_dim == embed_dim_to_check + ), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}" + if isinstance(embed_dim, Tensor): + # embed_dim can be a tensor when JIT tracing + head_dim = embed_dim.div(num_heads, rounding_mode="trunc") + else: + head_dim = embed_dim // num_heads + assert ( + head_dim * num_heads == embed_dim + ), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}" + if use_separate_proj_weight: + # allow MHA to have different embedding dimensions when separate projection weights are used + assert ( + key.shape[:2] == value.shape[:2] + ), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}" + else: + assert ( + key.shape == value.shape + ), f"key shape {key.shape} does not match value shape {value.shape}" + + # + # compute in-projection + # + if not use_separate_proj_weight: + assert ( + in_proj_weight is not None + ), "use_separate_proj_weight is False but in_proj_weight is None" + q, k, v = packed_in_projection_packed( + q=query, k=key, v=value, w=in_proj_weight, num_groups=num_groups, b=in_proj_bias + ) + else: + assert ( + q_proj_weight is not None + ), "use_separate_proj_weight is True but q_proj_weight is None" + assert ( + k_proj_weight is not None + ), "use_separate_proj_weight is True but k_proj_weight is None" + assert ( + v_proj_weight is not None + ), "use_separate_proj_weight is True but v_proj_weight is None" + if in_proj_bias is None: + b_q = b_k = b_v = None + else: + b_q, b_k, b_v = in_proj_bias.chunk(3) + + q, k, v = packed_in_projection( + q=query, + k=key, + v=value, + w_q=q_proj_weight, + w_k=k_proj_weight, + w_v=v_proj_weight, + num_groups=num_groups, + implementation=implementation, + b_q=b_q, + b_k=b_k, + b_v=b_v, + ) + + # prep attention mask + if attn_mask is not None: + # ensure attn_mask's dim is 3 + if attn_mask.dim() == 2: + correct_2d_size = (tgt_len, src_len) + if attn_mask.shape != correct_2d_size: + raise RuntimeError( + f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}." 
+ ) + attn_mask = attn_mask.unsqueeze(0) + elif attn_mask.dim() == 3: + correct_3d_size = (bsz * num_heads, tgt_len, src_len) + if attn_mask.shape != correct_3d_size: + # unreachable code due to the check above (F._mha_shape_check, l.274) + raise RuntimeError( + f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}." + ) + else: + # unreachable code due to the check above (F._mha_shape_check, l.274) + raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported") + + if bias_k is not None and bias_v is not None: + assert static_k is None, "bias cannot be added to static key." + assert static_v is None, "bias cannot be added to static value." + k = torch.cat([k, bias_k.repeat(1, bsz, 1)], dim=0) + v = torch.cat([v, bias_v.repeat(1, bsz, 1)], dim=0) + if attn_mask is not None: + attn_mask = F.pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = F.pad(key_padding_mask, (0, 1)) + else: + assert bias_k is None, "bias_k is not None" + assert bias_v is None, "bias_v is not None" + + # + # reshape q, k, v for multihead attention and make em batch first + # + q = rearrange(q, "l b (h d) -> b h l d", h=num_heads) + k = rearrange(k, "s b (h d) -> b h s d", h=num_heads) + v = rearrange(v, "s b (h d) -> b h s d", h=num_heads) + + if add_zero_attn: + zero_attn_shape = (bsz, num_heads, 1, head_dim) + k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=2) + v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=2) + if attn_mask is not None: + attn_mask = F.pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = F.pad(key_padding_mask, (0, 1)) + + src_len = k.size(2) + + # merge key padding and attention masks + if key_padding_mask is not None: + assert key_padding_mask.shape == ( + bsz, + src_len, + ), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}" + key_padding_mask = ( + key_padding_mask.view(bsz, 1, 1, src_len) + .expand(-1, num_heads, -1, -1) + .reshape(bsz * num_heads, 1, src_len) + ) + attn_mask = key_padding_mask if attn_mask is None else attn_mask + key_padding_mask + + # adjust dropout probability + if not training: + dropout_p = 0.0 + + if need_weights: + raise NotImplementedError("need_weights is not supported yet") + + # attn_mask can be either (L,S) or (N*num_key_heads, L, S) + # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S) + # in order to match the input for SDPA of (N, num_key_heads, L, S) + if attn_mask is not None: + if attn_mask.size(0) == 1 and attn_mask.dim() == 3: + attn_mask = attn_mask.unsqueeze(0) + else: + attn_mask = attn_mask.view(bsz, num_heads, -1, src_len) + + attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal) + + attn_output = rearrange(attn_output, "b h l d -> (l b) (h d)") + + attn_output = packed_linear( + attn_output, out_proj_weight, num_groups, implementation, out_proj_bias + ) + + attn_output = rearrange(attn_output, "(l b) d -> l b d", l=tgt_len) + if not is_batched: + # squeeze the output if input was unbatched + attn_output = attn_output.squeeze(1) + return attn_output, None diff --git a/torch_uncertainty/layers/packed.py b/torch_uncertainty/layers/packed.py index 490f65ee..71537b82 100644 --- a/torch_uncertainty/layers/packed.py +++ b/torch_uncertainty/layers/packed.py @@ -1,4 +1,5 @@ import math +from collections.abc import Callable from typing import Any import torch @@ -7,6 +8,8 @@ from torch.nn 
import functional as F from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t +from .functional.packed import packed_linear, packed_multi_head_attention_forward + def check_packed_parameters_consistency(alpha: float, gamma: int, num_estimators: int) -> None: """Check the consistency of the parameters of the Packed-Ensembles layers. @@ -190,18 +193,13 @@ def forward(self, inputs: Tensor) -> Tensor: if self.rearrange: return self._rearrange_forward(inputs) return F.conv1d(inputs, self.weight, self.bias, 1, 0, 1, self.groups) - if self.implementation == "full": - block_diag = torch.block_diag(*self.weight) - return F.linear(inputs, block_diag, self.bias) - if self.implementation == "sparse": - return (inputs @ self.weight.transpose(0, 1)) + self.bias - if self.implementation == "einsum": - return torch.einsum( - "bki,kij->bkj", - inputs.view(-1, self.groups, self.in_features), - self.weight.transpose(1, 2), - ).flatten(start_dim=-2, end_dim=-1) - raise ValueError(f"Unknown implementation: {self.implementation}") + return packed_linear( + inputs=inputs, + weight=self.weight, + num_groups=self.groups, + implementation=self.implementation, + bias=self.bias, + ) class PackedConv1d(nn.Module): @@ -565,3 +563,876 @@ def weight(self) -> Tensor: def bias(self) -> Tensor | None: r"""The bias of the underlying convolutional layer.""" return self.conv.bias + + +class PackedLayerNorm(nn.GroupNorm): + def __init__( + self, + embed_dim: int, + num_estimators: int, + alpha: float, + eps: float = 1e-5, + affine: bool = True, + device=None, + dtype=None, + ) -> None: + r"""Packed-Ensembles-style LayerNorm layer. + + Args: + embed_dim (int): the number of features in the input tensor. + num_estimators (int): the number of estimators in the ensemble. + alpha (float): the width multiplier of the layer. + eps (float, optional): a value added to the denominator for numerical stability. Defaults + to 1e-5. + affine (bool, optional): a boolean value that when set to ``True``, this module has + learnable per_channel affine parameters initialized to ones (for weights) and zeros + (for biases). Defaults to ``True``. + device (torch.device, optional): the device to use for the layer's parameters. Defaults + to ``None``. + dtype (torch.dtype, optional): the dtype to use for the layer's parameters. Defaults to + ``None``. + + Shape: + - Input: :math:`(B, *)` where :math:`*` means any number of additional dimensions. + - Output: :math:`(B, *)` (same shape as input) + """ + super().__init__( + num_groups=num_estimators, + num_channels=int(embed_dim * alpha), + eps=eps, + affine=affine, + device=device, + dtype=dtype, + ) + + def forward(self, inputs: Tensor) -> Tensor: + x = rearrange(inputs, "b ... h -> b h ...") + x = F.group_norm( + x, + self.num_groups, + self.weight, + self.bias, + self.eps, + ) + return rearrange(x, "b h ... -> b ... h") + + +class PackedMultiheadAttention(nn.Module): + __constants__ = ["batch_first"] + bias_k: Tensor | None + bias_v: Tensor | None + + def __init__( + self, + embed_dim: int, + num_heads: int, + alpha: float, + num_estimators: int, + gamma: int = 1, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + kdim: int | None = None, + vdim: int | None = None, + batch_first=False, + first=False, + last=False, + device=None, + dtype=None, + ) -> None: + r"""Packed-Ensembles-style MultiheadAttention layer. + + Args: + embed_dim (int): Size of the embedding dimension. + num_heads (int): Number of parallel attention heads. 
+ alpha (float): The width multiplier of the embedding dimension. + num_estimators (int): The number of estimators packed in the layer. + gamma (int, optional): Defaults to ``1``. + dropout (float, optional): Dropout probability on ``attn_output_weights``. Defaults to ``0.0`` + (no dropout). + bias (bool, optional): Ì specified, adds bias to input / output projection layers. + Defaults to ``True``. + add_bias_kv (bool, optional): If specified, adds bias to the key and value sequences at + ``dim=0``. Defaults to ``False``. + add_zero_attn (bool, optional): If specified, adds a new batch of zeros to the key and + value sequences at ``dim=1``. Defaults to ``False``. + kdim (int | None, optional): Total number of features for keys. Defaults to ``None`` + (uses ``kdim=embed_dim``). + vdim (int | None, optional): Total number of features for values. Defaults to ``None`` + (uses ``vdim=embed_dim``). + batch_first (bool, optional): If ``True``, then the input and output tensors are provided + as (batch, seq, feature). Defaults to ``False`` (seq, batch, feature). + first (bool, optional): Whether this is the first layer of the network. Defaults to + ``False``. + last (bool, optional): Whether this is the last layer of the network. Defaults to + ``False``. + device (torch.device, optional): The device to use for the layer's parameters. Defaults + to ``None``. + dtype (torch.dtype, optional): The dtype to use for the layer's parameters. Defaults to + ``None``. + + Reference: + - `Attention Is All You Need `_: Original Multihead Attention formulation. + - `Hierarchical Light Tranformer Ensembles for Multimodal Trajectory Forecasting `_ + : Packed-Ensembles-style Multihead Attention formulation. + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.embed_dim = int(embed_dim * alpha) + + augmentation = 1 if first else alpha + in_embed_dim = int(embed_dim * augmentation) + self.kdim = int(self.kdim * augmentation) + self.vdim = int(self.vdim * augmentation) + + self.num_groups = 1 if first else num_estimators * gamma + + self.num_heads = num_heads * self.num_groups + self.dropout = dropout + self.batch_first = batch_first + self.head_dim = self.embed_dim // self.num_heads + assert ( + self.head_dim * self.num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + + self.num_estimators = num_estimators + self.alpha = alpha + self.gamma = gamma + + if not self._qkv_same_embed_dim: + self.q_proj_weight = nn.Parameter( + torch.empty( + ( + self.num_groups, + self.embed_dim // self.num_groups, + in_embed_dim // self.num_groups, + ), + **factory_kwargs, + ) + ) + self.k_proj_weight = nn.Parameter( + torch.empty( + ( + self.num_groups, + self.embed_dim // self.num_groups, + self.kdim // self.num_groups, + ), + **factory_kwargs, + ) + ) + self.v_proj_weight = nn.Parameter( + torch.empty( + ( + self.num_groups, + self.embed_dim // self.num_groups, + self.vdim // self.num_groups, + ), + **factory_kwargs, + ) + ) + self.register_parameter("in_proj_weight", None) + else: + self.in_proj_weight = nn.Parameter( + torch.empty( + ( + self.num_groups, + 3 * self.embed_dim // self.num_groups, + in_embed_dim // self.num_groups, + ), + **factory_kwargs, + ) + ) + self.register_parameter("q_proj_weight", None) + self.register_parameter("k_proj_weight", None) + 
self.register_parameter("v_proj_weight", None) + + if bias: + self.in_proj_bias = nn.Parameter(torch.empty(3 * self.embed_dim, **factory_kwargs)) + else: + self.register_parameter("in_proj_bias", None) + + if add_bias_kv: + self.bias_k = nn.Parameter(torch.empty((1, 1, self.embed_dim), **factory_kwargs)) + self.bias_v = nn.Parameter(torch.empty((1, 1, self.embed_dim), **factory_kwargs)) + else: + self.bias_k = self.bias_v = None + + self.out_proj = PackedLinear( + in_features=embed_dim, + out_features=embed_dim, + alpha=alpha, + num_estimators=num_estimators, + gamma=gamma, + implementation="einsum", + bias=bias, + first=False, + last=last, + **factory_kwargs, + ) + + self.add_zero_attn = add_zero_attn + + self._reset_parameters() + + def _reset_parameters(self): + if self._qkv_same_embed_dim: + for i in range(self.in_proj_weight.size(0)): + nn.init.xavier_uniform_(self.in_proj_weight[i]) + else: + for i in range(self.q_proj_weight.size(0)): + nn.init.xavier_uniform_(self.q_proj_weight[i]) + nn.init.xavier_uniform_(self.k_proj_weight[i]) + nn.init.xavier_uniform_(self.v_proj_weight[i]) + + if self.in_proj_bias is not None: + nn.init.constant_(self.in_proj_bias, 0.0) + nn.init.constant_(self.out_proj.bias, 0.0) + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + key_padding_mask: Tensor | None = None, + need_weights: bool = False, + attn_mask: Tensor | None = None, + average_attn_weights: bool = True, + is_causal: bool = False, + ) -> tuple[Tensor, None]: + r"""Computes attention outputs given query, key, and value tensors. + + Args: + query (Tensor): Query embeddings of shape :math:`(L, E_q)` for unbatched input, + :math:`(L, B, E_q)` when ``batch_first=False`` or :math:`(B, L, E_q)` when + ``batch_first=True``, where :math:`L` is the target sequence length, :math:`B` is + the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``. + key (Tensor): Key embeddingd of shape :math:`(S, E_k)` for unbatched input, + :math:`(S, B, E_k)` when ``batch_first=False`` or :math:`(B, S, E_k)` when + ``batch_first=True``, where :math:`S` is the source sequence length, :math:`B` is + the batch size and :math:`E_k` is the key embedding dimension ``kdim``. + value (Tensor): Value embeddings of shape :math:`(S, E_v)` for unbatched input, + :math:`(S, B, E_v)` when ``batch_first=False`` or :math:`(B, S, E_v)` when + ``batch_first=True``, where :math:`S` is the source sequence length, :math:`B` is + the batch size and :math:`E_v` is the value embedding dimension ``vdim``. + key_padding_mask (Tensor | None, optional): If specified, a mask of shape + :math:`(B, S)` indicating which elements within ``key`` to ignore for the purpose + of attention (i.e. treat as "padding"). For unbatched `query`, shape should be + :math:`(S)`. Binary and float masks are supported. For a binary mask, a ``True`` + value indicates that the corresponding ``key`` value will be ignored for the + purpose of attention. For a float mask, it will be directly added to the + corresponding ``key`` value. Defaults to ``None``. + need_weights (bool, optional): If specified, returns ``attn_output_weights`` in + addition to ``attn_outputs``. Set ``need_weights=False`` to use the optimized + ``scale_dot_product_attention`` and achieve the best performance for MHA. + Defaults to ``False``. + attn_mask (Tensor | None, optional): If specified, a 2D or 3D mask preventing attention + to certain positions. 
Must be of shape :math:`(L,S)` or + :math:`(B \times \text{num_heads}, L, S)`, where :math:`B` is the batch size, :math:`L` + is the target sequence length, and :math:`S` is the source sequence length. A 2D mask + will be broadcasted across the batch while a 3D mask allows for a different mask for + each entry in the batch. Binary and float masks are supported. For a binary mask, a + ``True`` value indicates that the corresponding position is not allowed to attend to. + For a float mask, the mask values will be added to the attention weight. If both + ``attn_mask`` and ``key_padding_mask`` are provided, their types should match. + Defaults to ``None``. + average_attn_weights (bool, optional): If ``True``, indicates that the returned + ``attn_weights`` should be averaged across heads. Otherwise, ``attn_weights`` are + provided separately per head. Note that this flag only has an effect when + ``need_weights=True``. Defaults to ``True``. + is_causal (bool, optional): _description_. Defaults to ``False``. + + Warning: + ``need_weights=True`` and therefore ``average_attn_weights`` are not supported yet thus + have no effect. + + Returns: + tuple[Tensor, None]: + - *attn_output* (Tensor): The output tensor of shape :math:`(L, E_q)`, :math:`(L, B, E_q)` + or :math:`(B, L, E_q)` where :math:`L` is the target sequence length, :math:`B` is + the batch size, and :math:`E_q` is the embedding dimension ``embed_dim``. + - *attn_output_weights* (None): Always ``None`` has we do not support + ``need_weights=True`` yet. + """ + is_batched = query.dim() == 3 + + key_padding_mask = F._canonical_mask( + mask=key_padding_mask, + mask_name="key_padding_mask", + other_type=F._none_or_dtype(attn_mask), + other_name="attn_mask", + target_type=query.dtype, + ) + + attn_mask = F._canonical_mask( + mask=attn_mask, + mask_name="attn_mask", + other_type=None, + other_name="", + target_type=query.dtype, + check_other=False, + ) + + if self.batch_first and is_batched: + # make sure that the transpose op does not affect the "is" property + if key is value: + if query is key: + query = key = value = query.transpose(1, 0) + else: + query, key = (x.transpose(1, 0) for x in (query, key)) + value = key + else: + query, key, value = (x.transpose(1, 0) for x in (query, key, value)) + + if not self._qkv_same_embed_dim: + ( + attn_output, + _, + ) = packed_multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + self.num_groups, + self.in_proj_weight, + self.in_proj_bias, + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj_weight, + k_proj_weight=self.k_proj_weight, + v_proj_weight=self.v_proj_weight, + average_attn_weights=average_attn_weights, + is_causal=is_causal, + ) + else: + ( + attn_output, + _, + ) = packed_multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + self.num_groups, + self.in_proj_weight, + self.in_proj_bias, + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + average_attn_weights=average_attn_weights, + is_causal=is_causal, + ) + if self.batch_first and is_batched: + return attn_output.transpose(1, 0), None + return 
attn_output, None + + +class PackedTransformerEncoderLayer(nn.Module): + __constants__ = ["batch_first", "norm_first"] + + def __init__( + self, + d_model: int, + nhead: int, + alpha: float, + num_estimators: int, + gamma: int = 1, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: Callable[[Tensor], Tensor] = F.relu, + layer_norm_eps: float = 1e-5, + bias: bool = True, + batch_first: bool = False, + norm_first: bool = False, + first: bool = False, + last: bool = False, + device=None, + dtype=None, + ) -> None: + r"""Packed-Ensembles-style TransformerEncoderLayer (made up of self-attention followed by a + feedforward network). + + Args: + d_model (int): the number of expected features in the input. + nhead (int): the number of heads in the multiheadattention models. + alpha (float): the width multiplier of the layer. + num_estimators (int): the number of estimators packed in the layer. + gamma (int, optional): Defaults to ``1``. + dim_feedforward (int, optional): the dimension of the feedforward network model. Defaults + to ``2048``. + dropout (float, optional): the dropout value. Defaults to ``0.1``. + activation (Callable[[Tensor], Tensor], optional): the activation function of the + intermediate layer, that is a unary callable. Defaults to ``F.relu``. + layer_norm_eps (float, optional): the eps value in layer normalization components. Defaults + to ``1e-5``. + bias (bool, optional): If ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an + additive bias. Defaults to ``True``. + batch_first (bool, optional): If ``True``, then the input and output tensors are provided + as :math:`(\text{batch}, \text{seq}, \text{d_model})`. Defaults to ``False`` + :math:`(\text{seq}, \text{batch}, \text{d_model})`. + norm_first (bool, optional): If ``True``, the layer norm is done prior to attention and + feedforward operations, respectively. Otherwise, it is done after. Defaults to + ``False``. + first (bool, optional): Whether this is the first layer of the network. Defaults to + ``False``. + last (bool, optional): Whether this is the last layer of the network. Defaults to + ``False``. + device (torch.device, optional): The device to use for the layer's parameters. Defaults + to ``None``. + dtype (torch.dtype, optional): The dtype to use for the layer's parameters. Defaults to + ``None``. + + Reference: + - `Attention Is All You Need `_: Original Multihead Attention formulation. + - `Hierarchical Light Tranformer Ensembles for Multimodal Trajectory Forecasting `_ + : Packed-Ensembles-style Multihead Attention formulation. 
+ """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + self.self_attn = PackedMultiheadAttention( + embed_dim=d_model, + num_heads=nhead, + alpha=alpha, + num_estimators=num_estimators, + bias=bias, + gamma=gamma, + dropout=dropout, + batch_first=batch_first, + first=first, + **factory_kwargs, + ) + + self.linear1 = PackedLinear( + in_features=d_model, + out_features=dim_feedforward, + alpha=alpha, + num_estimators=num_estimators, + gamma=gamma, + implementation="einsum", + bias=bias, + **factory_kwargs, + ) + self.dropout = nn.Dropout(dropout) + self.linear2 = PackedLinear( + in_features=dim_feedforward, + out_features=d_model, + alpha=alpha, + num_estimators=num_estimators, + gamma=gamma, + implementation="einsum", + last=last, + bias=bias, + **factory_kwargs, + ) + + self.norm_first = norm_first + if self.norm_first and first: + self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps) + else: + self.norm1 = PackedLayerNorm( + embed_dim=d_model, + num_estimators=num_estimators, + alpha=alpha, + eps=layer_norm_eps, + **factory_kwargs, + ) + + if not self.norm_first and last: + self.norm2 = PackedLayerNorm( + embed_dim=d_model, + num_estimators=num_estimators, + alpha=alpha, + eps=layer_norm_eps, + **factory_kwargs, + ) + else: + self.norm2 = PackedLayerNorm( + embed_dim=d_model, + num_estimators=num_estimators, + alpha=alpha, + eps=layer_norm_eps, + **factory_kwargs, + ) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + if activation is F.relu or isinstance(activation, torch.nn.ReLU): + self.activation_relu_or_gelu = 1 + elif activation is F.gelu or isinstance(activation, torch.nn.GELU): + self.activation_relu_or_gelu = 2 + else: + self.activation_relu_or_gelu = 0 + self.activation = activation + + def forward( + self, + src: Tensor, + src_mask: Tensor | None = None, + src_key_padding_mask: Tensor | None = None, + is_causal: bool = False, + ) -> Tensor: + r"""Pass the input through the encoder layer. + + Args: + src (Tensor): The sequence to the encoder layer. Shape: :math:`(B, L, E)` or + :math:`(L, B, E)`. + src_mask (Tensor | None, optional): The mask for the ``src`` sequence. Defaults to ``None``. + src_key_padding_mask (Tensor | None, optional): The mask for the ``src`` keys per + batch. Defaults to ``None``. + is_causal (bool, optional): If specified, applies a causal mask as ``src_mask``. + Defaults to ``False``. Warning: ``is_causal`` provides a hint the ``src_mask`` is + a causal mask. Providing incorrect hints can result in incorrect execution, + including forward and backward compatibility. + + Returns: + Tensor: The output of the encoder layer. Shape: :math:`(B, L, E)` or :math:`(L, B, E)`. 
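A usage sketch for the encoder layer defined in this hunk (illustrative, not part of the diff). With the default ``first=False``/``last=False``, the inner widths are multiplied by ``alpha``, so the sketch assumes the token features fed to the layer are already ``d_model * alpha`` wide:

import torch
from torch_uncertainty.layers.packed import PackedTransformerEncoderLayer

d_model, alpha, num_estimators = 64, 2, 4
layer = PackedTransformerEncoderLayer(
    d_model=d_model,
    nhead=4,
    alpha=alpha,
    num_estimators=num_estimators,
    batch_first=True,
)
tokens = torch.rand(8, 16, d_model * alpha)   # (batch, seq, widened features)
out = layer(tokens)                           # same shape: torch.Size([8, 16, 128])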
+ """ + src_key_padding_mask = F._canonical_mask( + mask=src_key_padding_mask, + mask_name="src_key_padding_mask", + other_type=F._none_or_dtype(src_mask), + other_name="src_mask", + target_type=src.dtype, + ) + + src_mask = F._canonical_mask( + mask=src_mask, + mask_name="src_mask", + other_type=None, + other_name="", + target_type=src.dtype, + check_other=False, + ) + + x = src + if self.norm_first: + x = x + self._sa_block( + self.norm1(x), + src_mask, + src_key_padding_mask, + is_causal=is_causal, + ) + x = x + self._ff_block(self.norm2(x)) + else: + x = self.norm1( + x + self._sa_block(x, src_mask, src_key_padding_mask, is_causal=is_causal) + ) + x = self.norm2(x + self._ff_block(x)) + + return x + + # self-attention block + def _sa_block( + self, + x: Tensor, + attn_mask: Tensor | None, + key_padding_mask: Tensor | None, + is_causal: bool = False, + ) -> Tensor: + x = self.self_attn( + x, + x, + x, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + is_causal=is_causal, + )[0] + return self.dropout1(x) + + # feed-forward block + def _ff_block(self, x: Tensor) -> Tensor: + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + return self.dropout2(x) + + +class PackedTransformerDecoderLayer(nn.Module): + __constants__ = ["batch_first", "norm_first"] + + def __init__( + self, + d_model: int, + nhead: int, + alpha: int, + num_estimators: int, + gamma: int = 1, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: Callable[[Tensor], Tensor] = F.relu, + layer_norm_eps: float = 1e-5, + batch_first: bool = False, + norm_first: bool = False, + first: bool = False, + last: bool = False, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + r"""Packed-Ensembles-style TransformerDecoderLayer (made up of self-attention, multi-head + attention, and feedforward network). + + Args: + d_model (int): the number of expected features in the input. + nhead (int): the number of heads in the multiheadattention models. + alpha (float): the width multiplier of the layer. + num_estimators (int): the number of estimators packed in the layer. + gamma (int, optional): Defaults to ``1``. + dim_feedforward (int, optional): the dimension of the feedforward network model. Defaults + to ``2048``. + dropout (float, optional): the dropout value. Defaults to ``0.1``. + activation (Callable[[Tensor], Tensor], optional): the activation function of the + intermediate layer, that is a unary callable. Defaults to ``F.relu``. + layer_norm_eps (float, optional): the eps value in layer normalization components. Defaults + to ``1e-5``. + bias (bool, optional): If ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an + additive bias. Defaults to ``True``. + batch_first (bool, optional): If ``True``, then the input and output tensors are provided + as :math:`(\text{batch}, \text{seq}, \text{d_model})`. Defaults to ``False`` + :math:`(\text{seq}, \text{batch}, \text{d_model})`. + norm_first (bool, optional): If ``True``, the layer norm is done prior to attention and + feedforward operations, respectively. Otherwise, it is done after. Defaults to + ``False``. + first (bool, optional): Whether this is the first layer of the network. Defaults to + ``False``. + last (bool, optional): Whether this is the last layer of the network. Defaults to + ``False``. + device (torch.device, optional): The device to use for the layer's parameters. Defaults + to ``None``. + dtype (torch.dtype, optional): The dtype to use for the layer's parameters. Defaults to + ``None``. 
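The decoder layer follows the same width convention; its ``forward(tgt, memory, ...)`` signature appears further down in this hunk. A short sketch under the same assumptions as the encoder example above (illustrative, not part of the diff):

import torch
from torch_uncertainty.layers.packed import PackedTransformerDecoderLayer

dec = PackedTransformerDecoderLayer(d_model=64, nhead=4, alpha=2, num_estimators=4, batch_first=True)
tgt = torch.rand(8, 10, 128)      # widened target tokens
memory = torch.rand(8, 16, 128)   # widened encoder output
out = dec(tgt, memory)            # same shape as tgt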
+ + Reference: + - `Attention Is All You Need `_: Original Multihead Attention formulation. + - `Hierarchical Light Tranformer Ensembles for Multimodal Trajectory Forecasting `_ + : Packed-Ensembles-style Multihead Attention formulation. + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + self.self_attn = PackedMultiheadAttention( + embed_dim=d_model, + num_heads=nhead, + alpha=alpha, + num_estimators=num_estimators, + gamma=gamma, + dropout=dropout, + bias=bias, + batch_first=batch_first, + first=first, + **factory_kwargs, + ) + + self.multihead_attn = PackedMultiheadAttention( + embed_dim=d_model, + num_heads=nhead, + alpha=alpha, + num_estimators=num_estimators, + gamma=gamma, + dropout=dropout, + bias=bias, + batch_first=batch_first, + **factory_kwargs, + ) + + self.linear1 = PackedLinear( + in_features=d_model, + out_features=dim_feedforward, + alpha=alpha, + num_estimators=num_estimators, + gamma=gamma, + implementation="einsum", + bias=bias, + **factory_kwargs, + ) + self.dropout = nn.Dropout(dropout) + self.linear2 = PackedLinear( + in_features=dim_feedforward, + out_features=d_model, + alpha=alpha, + num_estimators=num_estimators, + gamma=gamma, + implementation="einsum", + bias=bias, + last=last, + **factory_kwargs, + ) + + self.norm_first = norm_first + if self.norm_first and first: + self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps) + else: + self.norm1 = PackedLayerNorm( + embed_dim=d_model, + num_estimators=num_estimators, + alpha=alpha, + eps=layer_norm_eps, + **factory_kwargs, + ) + + self.norm2 = PackedLayerNorm( + embed_dim=d_model, + num_estimators=num_estimators, + alpha=alpha, + eps=layer_norm_eps, + **factory_kwargs, + ) + + if not self.norm_first and last: + self.norm3 = PackedLayerNorm( + embed_dim=d_model, + num_estimators=num_estimators, + alpha=num_estimators, + eps=layer_norm_eps, + **factory_kwargs, + ) + else: + self.norm3 = PackedLayerNorm( + embed_dim=d_model, + num_estimators=num_estimators, + alpha=alpha, + eps=layer_norm_eps, + **factory_kwargs, + ) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + if activation is F.relu or isinstance(activation, torch.nn.ReLU): + self.activation_relu_or_gelu = 1 + elif activation is F.gelu or isinstance(activation, torch.nn.GELU): + self.activation_relu_or_gelu = 2 + else: + self.activation_relu_or_gelu = 0 + self.activation = activation + + def forward( + self, + tgt: Tensor, + memory: Tensor, + tgt_mask: Tensor | None = None, + memory_mask: Tensor | None = None, + tgt_key_padding_mask: Tensor | None = None, + memory_key_padding_mask: Tensor | None = None, + tgt_is_causal: bool = False, + memory_is_causal: bool = False, + ) -> Tensor: + r"""Pass the input (and mask) through the decoder layer. + + Args: + tgt (Tensor): The sequence to the decoder layer. Shape: :math:`(B, L, E)` or + :math:`(L, B, E)`. + memory (Tensor): The sequence from the last layer of the encoder. Shape: + :math:`(B, S, E)` or :math:`(S, B, E)`. + tgt_mask (Tensor | None, optional): The mask for the ``tgt`` sequence. Defaults to + ``None``. + memory_mask (Tensor | None, optional): The mask for the ``memory`` sequence. Defaults + to ``None``. + tgt_key_padding_mask (Tensor | None, optional): The mask for the ``tgt`` keys per + batch. Defaults to ``None``. + memory_key_padding_mask (Tensor | None, optional): The mask for the ``memory`` keys per + batch. Defaults to ``None``. 
+ tgt_is_causal (bool, optional): If specified, applies a causal mask as ``tgt_mask``. + Defaults to ``False``. Warning: ``tgt_is_causal`` provides a hint the ``tgt_mask`` + is a causal mask. Providing incorrect hints can result in incorrect execution, + including forward and backward compatibility. + memory_is_causal (bool, optional): If specified, applies a causal mask as ``memory_mask``. + Defaults to ``False``. Warning: ``memory_is_causal`` provides a hint the ``memory_mask`` + is a causal mask. Providing incorrect hints can result in incorrect execution, + including forward and backward compatibility. + + Returns: + Tensor: The output of the encoder layer. Shape: :math:`(B, L, E)` or :math:`(L, B, E)`. + """ + x = tgt + if self.norm_first: + x = x + self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal) + x = x + self._mha_block( + self.norm2(x), + memory, + memory_mask, + memory_key_padding_mask, + memory_is_causal, + ) + x = x + self._ff_block(self.norm3(x)) + else: + x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal)) + x = self.norm2( + x + + self._mha_block( + x, + memory, + memory_mask, + memory_key_padding_mask, + memory_is_causal, + ) + ) + x = self.norm3(x + self._ff_block(x)) + + return x + + # self-attention block + def _sa_block( + self, + x: Tensor, + attn_mask: Tensor | None, + key_padding_mask: Tensor | None, + is_causal: bool = False, + ) -> Tensor: + x = self.self_attn( + x, + x, + x, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + is_causal=is_causal, + )[0] + return self.dropout1(x) + + # multi-head attention block + def _mha_block( + self, + x: Tensor, + memory: Tensor, + attn_mask: Tensor | None, + key_padding_mask: Tensor | None, + is_causal: bool = False, + ) -> Tensor: + x = self.multihead_attn( + x, + memory, + memory, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + is_causal=is_causal, + )[0] + return self.dropout2(x) + + # feed-forward block + def _ff_block(self, x: Tensor) -> Tensor: + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + return self.dropout3(x) diff --git a/torch_uncertainty/losses/bayesian.py b/torch_uncertainty/losses/bayesian.py index 6c17913e..3e0e632e 100644 --- a/torch_uncertainty/losses/bayesian.py +++ b/torch_uncertainty/losses/bayesian.py @@ -1,7 +1,9 @@ import torch from torch import Tensor, nn +from torch.distributions import Independent from torch_uncertainty.layers.bayesian import bayesian_modules +from torch_uncertainty.utils.distributions import get_dist_class class KLDiv(nn.Module): @@ -37,6 +39,7 @@ def __init__( inner_loss: nn.Module, kl_weight: float, num_samples: int, + dist_family: str | None = None, ) -> None: """The Evidence Lower Bound (ELBO) loss for Bayesian Neural Networks. @@ -48,6 +51,8 @@ def __init__( inner_loss (nn.Module): The loss function to use during training kl_weight (float): The weight of the KL divergence term num_samples (int): The number of samples to use for the ELBO loss + dist_family (str, optional): The distribution family to use for the + output of the model. ``None`` means point-wise prediction. Defaults to ``None``. 
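To make the new ``dist_family`` plumbing concrete (illustrative, not part of the diff; a plain ``Normal`` stands in for whatever ``get_dist_class`` returns): the model's dict of parameters becomes a distribution, ``Independent(..., 1)`` folds the last dimension into the event, and the inner loss can then call ``log_prob`` on it, exactly as in the forward shown just below.

import torch
from torch.distributions import Independent, Normal

params = {"loc": torch.zeros(8, 1), "scale": torch.ones(8, 1)}   # what a probabilistic head returns
targets = torch.randn(8, 1)

dist = Independent(Normal(**params), 1)      # batch of 8, event dim of size 1
nll = -dist.log_prob(targets).mean()         # log_prob has shape (8,) before the mean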
Note: Set the model to None if you use the ELBOLoss within @@ -60,6 +65,7 @@ def __init__( self.inner_loss = inner_loss self.kl_weight = kl_weight self.num_samples = num_samples + self.dist_family = dist_family def forward(self, inputs: Tensor, targets: Tensor) -> Tensor: """Gather the KL divergence from the Bayesian modules and aggregate @@ -73,9 +79,13 @@ def forward(self, inputs: Tensor, targets: Tensor) -> Tensor: Tensor: The aggregated ELBO loss """ aggregated_elbo = torch.zeros(1, device=inputs.device) + dist_class = get_dist_class(self.dist_family) if self.dist_family is not None else None for _ in range(self.num_samples): - logits = self.model(inputs) - aggregated_elbo += self.inner_loss(logits, targets) + out = self.model(inputs) + if dist_class is not None: + # Wrap the distribution in an Independent distribution for log_prob computation. + out = Independent(dist_class(**out), 1) + aggregated_elbo += self.inner_loss(out, targets) # TODO: This shouldn't be necessary aggregated_elbo += self.kl_weight * self._kl_div().to(inputs.device) return aggregated_elbo / self.num_samples diff --git a/torch_uncertainty/losses/regression.py b/torch_uncertainty/losses/regression.py index 32cbe02f..888de286 100644 --- a/torch_uncertainty/losses/regression.py +++ b/torch_uncertainty/losses/regression.py @@ -2,7 +2,7 @@ import torch from torch import Tensor, nn -from torch.distributions import Distribution +from torch.distributions import Distribution, Independent from torch_uncertainty.utils.distributions import NormalInverseGamma @@ -34,12 +34,12 @@ def forward( """ loss = -dist.log_prob(targets) if padding_mask is not None: - loss = loss.masked_fill(padding_mask, 0.0) + loss = loss.masked_fill(padding_mask, float("nan")) if self.reduction == "mean": - loss = loss.mean() + loss = loss.nanmean() elif self.reduction == "sum": - loss = loss.sum() + loss = loss.nansum() return loss @@ -71,7 +71,10 @@ def __init__(self, reg_weight: float, reduction: str | None = "mean") -> None: ) self.reg_weight = reg_weight - def _reg(self, dist: NormalInverseGamma, targets: Tensor) -> Tensor: + def _reg(self, dist: NormalInverseGamma | Independent, targets: Tensor) -> Tensor: + if isinstance(dist, Independent): + dist = dist.base_dist + return torch.norm(targets - dist.loc, 1, dim=1, keepdim=True) * ( 2 * dist.lmbda + dist.alpha ) diff --git a/torch_uncertainty/metrics/classification/mean_iou.py b/torch_uncertainty/metrics/classification/mean_iou.py index 54dd5a0b..5d9ef56d 100644 --- a/torch_uncertainty/metrics/classification/mean_iou.py +++ b/torch_uncertainty/metrics/classification/mean_iou.py @@ -1,13 +1,9 @@ -from typing import Literal - from torch import Tensor from torchmetrics.classification.stat_scores import MulticlassStatScores from torchmetrics.utilities.compute import _safe_divide class MeanIntersectionOverUnion(MulticlassStatScores): - """Compute the MeanIntersection over Union (IoU) score.""" - is_differentiable: bool = False higher_is_better: bool = True full_state_update: bool = False @@ -16,16 +12,44 @@ def __init__( self, num_classes: int, top_k: int = 1, - multidim_average: Literal["global", "samplewise"] = "global", ignore_index: int | None = None, validate_args: bool = True, **kwargs, ) -> None: + r"""Computes Mean Intersection over Union (IoU) score. + + Args: + num_classes (int): Integer specifying the number of classes. + top_k (int, optional): Number of highest probability or logit score predictions + considered to find the correct label. 
Only works when ``preds`` contain + probabilities/logits. Defaults to ``1``. + ignore_index (int | None, optional): Specifies a target value that is ignored and does + not contribute to the metric calculation. Defaults to ``None``. + validate_args (bool, optional): Bool indicating if input arguments and tensors should + be validated for correctness. Set to ``False`` for faster computations. Defaults to + ``True``. + **kwargs: kwargs: Additional keyword arguments, see + `Advanced metric settings `_ + for more info. + + Shape: + As input to ``forward`` and ``update`` the metric accepts the following input: + + - ``preds`` (:class:`~torch.Tensor`): An int tensor of shape ``(B, ...)`` or float tensor of shape ``(B, C, ..)``. + If preds is a floating point we apply ``torch.argmax`` along the ``C`` dimension to automatically convert + probabilities/logits into an int tensor. + - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(B, ...)``. + + As output to ``forward`` and ``compute`` the metric returns the following output: + + - ``mean_iou`` (:class:`~torch.Tensor`): The computed Mean Intersection over Union (IoU) score. + A tensor containing a single float value. + """ super().__init__( num_classes, top_k, "macro", - multidim_average, + "global", ignore_index, validate_args, **kwargs, @@ -34,4 +58,5 @@ def __init__( def compute(self) -> Tensor: """Compute the Means Intersection over Union (MIoU) based on saved inputs.""" tp, fp, _, fn = self._final_state() - return _safe_divide(tp, tp + fp + fn).mean() + + return _safe_divide(tp, tp + fp + fn, zero_division=float("nan")).nanmean() diff --git a/torch_uncertainty/metrics/regression/nll.py b/torch_uncertainty/metrics/regression/nll.py index 9db4dd31..94af660e 100644 --- a/torch_uncertainty/metrics/regression/nll.py +++ b/torch_uncertainty/metrics/regression/nll.py @@ -21,20 +21,20 @@ def update( """ nlog_prob = -dist.log_prob(target) if padding_mask is not None: - nlog_prob = nlog_prob.masked_fill(padding_mask, 0.0) + nlog_prob = nlog_prob.masked_fill(padding_mask, float("nan")) if self.reduction is None or self.reduction == "none": self.values.append(nlog_prob) else: - self.values += nlog_prob.sum() - self.total += target.size(0) + self.values += nlog_prob.nansum() + self.total += padding_mask.sum() if padding_mask is not None else target.numel() def compute(self) -> Tensor: """Computes NLL based on inputs passed in to ``update`` previously.""" values = dim_zero_cat(self.values) if self.reduction == "sum": - return values.sum(dim=-1) + return values.nansum() if self.reduction == "mean": - return values.sum(dim=-1) / self.total + return values.nansum() / self.total # reduction is None or "none" return values diff --git a/torch_uncertainty/models/depth/bts.py b/torch_uncertainty/models/depth/bts.py index 64bae545..10e105f5 100644 --- a/torch_uncertainty/models/depth/bts.py +++ b/torch_uncertainty/models/depth/bts.py @@ -4,7 +4,6 @@ import torch import torchvision.models as tv_models from torch import Tensor, nn -from torch.distributions import Distribution from torch.nn import functional as F from torchvision.models.densenet import DenseNet121_Weights, DenseNet161_Weights from torchvision.models.resnet import ( @@ -14,7 +13,7 @@ ResNeXt101_32X8D_Weights, ) -from torch_uncertainty.layers.distributions import LaplaceLayer, NormalLayer +from torch_uncertainty.layers.distributions import get_dist_conv_layer from torch_uncertainty.models.utils import Backbone resnet_feat_out_channels = [64, 256, 512, 1024, 2048] @@ -286,12 +285,22 @@ def 
__init__(self, backbone_name: str, pretrained: bool) -> None: class BTSDecoder(nn.Module): + """BTS decoder. + + Args: + max_depth (float): The maximum predicted depth. + feat_out_channels (list[int]): The number of output channels from the backbone. + num_features (int): The number of features to use in the decoder. + dist_family (str | None, optional): The distribution family name. ``None`` means point-wise + prediction. Defaults to ``None``. + """ + def __init__( self, max_depth: float, feat_out_channels: list[int], num_features: int, - dist_layer: type[nn.Module], + dist_family: str | None = None, ): super().__init__() self.max_depth = max_depth @@ -400,12 +409,20 @@ def __init__( ) self.conv1 = nn.Conv2d(num_features // 16 + 4, num_features // 16, 3, 1, 1, bias=False) self.output_channels = 1 - if dist_layer in (NormalLayer, LaplaceLayer): - self.output_channels = 2 - elif dist_layer != nn.Identity: - raise ValueError(f"Unsupported distribution layer. Got {dist_layer}.") - self.depth = nn.Conv2d(num_features // 16, self.output_channels, 3, 1, 1, bias=False) - self.dist_layer = dist_layer(dim=1) + + if dist_family is not None: + dist_layer_class = get_dist_conv_layer(dist_family) + self.depth = dist_layer_class( + base_layer=nn.Conv2d, + event_dim=self.output_channels, + in_channels=num_features // 16, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ) + else: + self.depth = nn.Conv2d(num_features // 16, self.output_channels, 3, 1, 1, bias=False) def feat_forward(self, features: list[Tensor]) -> Tensor: dense_features = F.relu(features[4]) @@ -474,22 +491,27 @@ def feat_forward(self, features: list[Tensor]) -> Tensor: ) return F.elu(self.conv1(concat1)) - def forward(self, features: list[Tensor]) -> Tensor | Distribution: + def forward(self, features: list[Tensor]) -> Tensor | dict[str, Tensor]: """Forward pass. Args: features (list[Tensor]): list of the features from the backbone. Note: - Depending of the :attr:`dist_layer` of the backbone, the output can - be a distribution or a single tensor. + Depending of the :attr:`dist_family` of the backbone, the output can + be a dictionnary of distribution parameters or a single tensor. """ # TODO: handle focal out = self.depth(self.feat_forward(features)) - if self.output_channels != 1: - loc = self.max_depth * F.sigmoid(out[:, 0, :, :]) - scale = self.max_depth * out[:, 1, :, :] - out = self.dist_layer(torch.stack([loc, scale], -1)) + if isinstance(out, dict): + if "loc" not in out or "scale" not in out: + raise ValueError( + "Expected 'loc' and 'scale' in the output dictionary.", + "Consider raising an issue on the repository if you need support ", + "for distributions that do not use location and scale.", + ) + out["loc"] = self.max_depth * F.sigmoid(out["loc"]) + out["scale"] = self.max_depth * out["scale"] else: out = self.max_depth * F.sigmoid(out) return out @@ -508,7 +530,7 @@ def __init__( ], max_depth: float, bts_size: int = 512, - dist_layer: type[nn.Module] = nn.Identity, + dist_family: str | None = None, pretrained_backbone: bool = True, ) -> None: """BTS model. @@ -517,7 +539,7 @@ def __init__( backbone_name (str): Name of the encoding backbone. max_depth (float): Maximum predicted depth. bts_size (int): BTS feature size. Defaults to 512. - dist_layer (nn.Module): Distribution layer for probabilistic depth + dist_family (str): Distribution family name. Defaults to None. estimation. Defaults to nn.Identity. pretrained_backbone (bool): Use a pretrained backbone. Defaults to True. 
@@ -529,7 +551,7 @@ def __init__( self.max_depth = max_depth self.backbone = BTSBackbone(backbone_name, pretrained_backbone) - self.decoder = BTSDecoder(max_depth, self.backbone.feat_out_channels, bts_size, dist_layer) + self.decoder = BTSDecoder(max_depth, self.backbone.feat_out_channels, bts_size, dist_family) # TODO: Handle focal def forward(self, x: Tensor, focal: float | None = None) -> Tensor: @@ -546,18 +568,18 @@ def _bts( backbone_name: str, max_depth: float, bts_size: int = 512, - dist_layer: type[nn.Module] = nn.Identity, + dist_family: str | None = None, pretrained_backbone: bool = True, ) -> _BTS: if backbone_name not in bts_backbones: raise ValueError(f"Unsupported backbone. Got {backbone_name}.") - return _BTS(backbone_name, max_depth, bts_size, dist_layer, pretrained_backbone) + return _BTS(backbone_name, max_depth, bts_size, dist_family, pretrained_backbone) def bts_resnet50( max_depth: float, bts_size: int = 512, - dist_layer: type[nn.Module] = nn.Identity, + dist_family: str | None = None, pretrained_backbone: bool = True, ) -> _BTS: """BTS model with ResNet-50 backbone. @@ -565,15 +587,14 @@ def bts_resnet50( Args: max_depth (float): Maximum predicted depth. bts_size (int): BTS feature size. Defaults to 512. - dist_layer (nn.Module): Distribution layer for probabilistic depth - estimation. Defaults to nn.Identity. + dist_family (str): Distribution family name. Defaults to None. pretrained_backbone (bool): Use a pretrained backbone. Defaults to True. """ return _bts( "resnet50", max_depth, bts_size=bts_size, - dist_layer=dist_layer, + dist_family=dist_family, pretrained_backbone=pretrained_backbone, ) @@ -581,7 +602,7 @@ def bts_resnet50( def bts_resnet101( max_depth: float, bts_size: int = 512, - dist_layer: type[nn.Module] = nn.Identity, + dist_family: str | None = None, pretrained_backbone: bool = True, ) -> _BTS: """BTS model with ResNet-101 backbone. @@ -589,14 +610,13 @@ def bts_resnet101( Args: max_depth (float): Maximum predicted depth. bts_size (int): BTS feature size. Defaults to 512. - dist_layer (nn.Module): Distribution layer for probabilistic depth - estimation. Defaults to nn.Identity. + dist_family (str): Distribution family name. Defaults to None. pretrained_backbone (bool): Use a pretrained backbone. Defaults to True. """ return _bts( "resnet101", max_depth, bts_size=bts_size, - dist_layer=dist_layer, + dist_family=dist_family, pretrained_backbone=pretrained_backbone, ) diff --git a/torch_uncertainty/models/mlp.py b/torch_uncertainty/models/mlp.py index 0f519a26..cda24d09 100644 --- a/torch_uncertainty/models/mlp.py +++ b/torch_uncertainty/models/mlp.py @@ -4,6 +4,7 @@ from torch import Tensor, nn from torch_uncertainty.layers.bayesian import BayesLinear +from torch_uncertainty.layers.distributions import get_dist_linear_layer from torch_uncertainty.layers.packed import PackedLinear from torch_uncertainty.models import StochasticModel @@ -19,9 +20,9 @@ def __init__( layer: type[nn.Module], activation: Callable, layer_args: dict, - final_layer: type[nn.Module], - final_layer_args: dict, dropout_rate: float, + dist_family: str | None, + dist_args: dict, ) -> None: """Multi-layer perceptron class. @@ -32,9 +33,10 @@ def __init__( layer (nn.Module): Layer class. activation (Callable): Activation function. layer_args (Dict): Arguments for the layer class. - final_layer (nn.Module): Final layer class for distribution regression. - final_layer_args (Dict): Arguments for the final layer class. dropout_rate (float): Dropout probability. 
+ dist_family (str, optional): Distribution family name. ``None`` means point-wise + prediction. Defaults to ``None``. + dist_args (Dict, optional): Arguments for the distribution layer class. """ super().__init__() self.activation = activation @@ -43,38 +45,49 @@ def __init__( if len(hidden_dims) == 0: if layer == PackedLinear: - layers.append( - layer( - in_features, - num_outputs, - first=True, - last=True, - **layer_args, - ) - ) - else: - layers.append(layer(in_features, num_outputs, **layer_args)) + layer_args |= {"first": True, "last": True} + + self.final_layer = layer( + in_features=in_features, out_features=num_outputs, **layer_args + ) else: if layer == PackedLinear: - layers.append(layer(in_features, hidden_dims[0], first=True, **layer_args)) - else: - layers.append(layer(in_features, hidden_dims[0], **layer_args)) + layer_args |= {"first": True, "last": False} + + layers.append(layer(in_features=in_features, out_features=hidden_dims[0], **layer_args)) + + if layer == PackedLinear: + layer_args |= {"first": False} for i in range(1, len(hidden_dims)): - layers.append(layer(hidden_dims[i - 1], hidden_dims[i], **layer_args)) + layers.append( + layer(in_features=hidden_dims[i - 1], out_features=hidden_dims[i], **layer_args) + ) if layer == PackedLinear: - layers.append(layer(hidden_dims[-1], num_outputs, last=True, **layer_args)) + layer_args |= {"last": True} + + if dist_family is not None: + dist_layer_class = get_dist_linear_layer(dist_family) + self.final_layer = dist_layer_class( + base_layer=layer, + event_dim=num_outputs, + in_features=hidden_dims[-1], + **layer_args, + **dist_args, + ) else: - layers.append(layer(hidden_dims[-1], num_outputs, **layer_args)) + self.final_layer = layer( + in_features=hidden_dims[-1], out_features=num_outputs, **layer_args + ) + self.layers = layers - self.final_layer = final_layer(**final_layer_args) - def forward(self, x: Tensor) -> Tensor: - for layer in self.layers[:-1]: + def forward(self, x: Tensor) -> Tensor | dict[str, Tensor]: + for layer in self.layers: x = F.dropout(layer(x), p=self.dropout_rate, training=self.training) x = self.activation(x) - return self.final_layer(self.layers[-1](x)) + return self.final_layer(x) def _mlp( @@ -83,27 +96,23 @@ def _mlp( num_outputs: int, hidden_dims: list[int], num_samples: int = 16, - layer_args: dict | None = None, layer: type[nn.Module] = nn.Linear, + layer_args: dict | None = None, activation: Callable = F.relu, - final_layer: type[nn.Module] = nn.Identity, - final_layer_args: dict | None = None, dropout_rate: float = 0.0, + dist_family: str | None = None, + dist_args: dict | None = None, ) -> _MLP | StochasticModel: - if layer_args is None: - layer_args = {} - if final_layer_args is None: - final_layer_args = {} model = _MLP( in_features=in_features, num_outputs=num_outputs, hidden_dims=hidden_dims, - layer_args=layer_args, + layer_args=layer_args or {}, layer=layer, activation=activation, - final_layer=final_layer, - final_layer_args=final_layer_args, dropout_rate=dropout_rate, + dist_family=dist_family, + dist_args=dist_args or {}, ) if stochastic: return StochasticModel(model, num_samples) @@ -114,11 +123,10 @@ def mlp( in_features: int, num_outputs: int, hidden_dims: list[int], - layer: type[nn.Module] = nn.Linear, activation: Callable = F.relu, - final_layer: type[nn.Module] = nn.Identity, - final_layer_args: dict | None = None, dropout_rate: float = 0.0, + dist_family: str | None = None, + dist_args: dict | None = None, ) -> _MLP: """Multi-layer perceptron. 
@@ -126,13 +134,12 @@ def mlp( in_features (int): Number of input features. num_outputs (int): Number of output features. hidden_dims (list[int]): Number of features in each hidden layer. - layer (nn.Module, optional): Layer type. Defaults to nn.Linear. activation (Callable, optional): Activation function. Defaults to F.relu. - final_layer (nn.Module, optional): Final layer class for distribution - regression. Defaults to nn.Identity. - final_layer_args (Dict, optional): Arguments for the final layer class. dropout_rate (float, optional): Dropout probability. Defaults to 0.0. + dist_family (str, optional): Distribution family. Defaults to None. + dist_args (Dict, optional): Arguments for the distribution layer class. Defaults + to None. Returns: _MLP: A Multi-Layer-Perceptron model. @@ -142,11 +149,10 @@ def mlp( in_features=in_features, num_outputs=num_outputs, hidden_dims=hidden_dims, - layer=layer, activation=activation, - final_layer=final_layer, - final_layer_args=final_layer_args, dropout_rate=dropout_rate, + dist_family=dist_family, + dist_args=dist_args, ) @@ -158,9 +164,9 @@ def packed_mlp( alpha: float = 2, gamma: float = 1, activation: Callable = F.relu, - final_layer: type[nn.Module] = nn.Identity, - final_layer_args: dict | None = None, dropout_rate: float = 0.0, + dist_family: str | None = None, + dist_args: dict | None = None, ) -> _MLP: layer_args = { "num_estimators": num_estimators, @@ -175,9 +181,9 @@ def packed_mlp( layer=PackedLinear, activation=activation, layer_args=layer_args, - final_layer=final_layer, - final_layer_args=final_layer_args, dropout_rate=dropout_rate, + dist_family=dist_family, + dist_args=dist_args, ) @@ -187,9 +193,9 @@ def bayesian_mlp( hidden_dims: list[int], num_samples: int = 16, activation: Callable = F.relu, - final_layer: type[nn.Module] = nn.Identity, - final_layer_args: dict | None = None, dropout_rate: float = 0.0, + dist_family: str | None = None, + dist_args: dict | None = None, ) -> StochasticModel: return _mlp( stochastic=True, @@ -199,7 +205,7 @@ def bayesian_mlp( hidden_dims=hidden_dims, layer=BayesLinear, activation=activation, - final_layer=final_layer, - final_layer_args=final_layer_args, dropout_rate=dropout_rate, + dist_family=dist_family, + dist_args=dist_args, ) diff --git a/torch_uncertainty/models/segmentation/segformer.py b/torch_uncertainty/models/segmentation/segformer.py index cf8f35d6..f303e0fb 100644 --- a/torch_uncertainty/models/segmentation/segformer.py +++ b/torch_uncertainty/models/segmentation/segformer.py @@ -435,7 +435,7 @@ def _get_embed_dims(arch: int) -> list[int]: return [64, 128, 320, 512] -def _get_depths(arch: int) -> list[int]: +def _get_depths(arch: int) -> list[int]: # coverage: ignore if arch == 0 or arch == 1: return [2, 2, 2, 2] if arch == 2: diff --git a/torch_uncertainty/models/wrappers/deep_ensembles.py b/torch_uncertainty/models/wrappers/deep_ensembles.py index 2a333e07..4c49a3d5 100644 --- a/torch_uncertainty/models/wrappers/deep_ensembles.py +++ b/torch_uncertainty/models/wrappers/deep_ensembles.py @@ -3,9 +3,6 @@ import torch from torch import nn -from torch.distributions import Distribution - -from torch_uncertainty.utils.distributions import cat_dist class _DeepEnsembles(nn.Module): @@ -42,17 +39,23 @@ def __init__( super().__init__(models) self.probabilistic = probabilistic - def forward(self, x: torch.Tensor) -> Distribution: + def forward(self, x: torch.Tensor) -> torch.Tensor | dict[str, torch.Tensor]: r"""Return the logits of the ensemble. Args: x (Tensor): The input of the model. 
Returns: - Distribution: + Tensor | dict[str, Tensor]: The output of the model with shape :math:`(N \times B, *)` + where :math:`B` is the batch size, :math:`N` is the number of estimators, and + :math:`*` is any other dimension. """ if self.probabilistic: - return cat_dist([model.forward(x) for model in self.core_models], dim=0) + out = [model.forward(x) for model in self.core_models] + key_set = {tuple(o.keys()) for o in out} + if len(key_set) != 1: + raise ValueError("The output of the models must have the same keys.") + return {k: torch.cat([o[k] for o in out], dim=0) for k in key_set.pop()} return super().forward(x) @@ -107,7 +110,7 @@ def deep_ensembles( if reset_model_parameters: for model in models: - for layer in model.children(): + for layer in model.modules(): if hasattr(layer, "reset_parameters"): layer.reset_parameters() diff --git a/torch_uncertainty/models/wrappers/swag.py b/torch_uncertainty/models/wrappers/swag.py index 3132439d..af81eb5d 100644 --- a/torch_uncertainty/models/wrappers/swag.py +++ b/torch_uncertainty/models/wrappers/swag.py @@ -39,7 +39,7 @@ def __init__( model (nn.Module): PyTorch model to be trained. cycle_start (int): Begininning of the first SWAG averaging cycle. cycle_length (int): Number of epochs between SWAG updates. The - first update occurs at :attr:`cycle_start`+:attr:`cycle_length`. + first update occurs at :attr:`cycle_start` + :attr:`cycle_length`. scale (float, optional): Scale of the Gaussian. Defaults to 1.0. diag_covariance (bool, optional): Whether to use a diagonal covariance. Defaults to False. diff --git a/torch_uncertainty/routines/pixel_regression.py b/torch_uncertainty/routines/pixel_regression.py index f81ff762..dc6f8a85 100644 --- a/torch_uncertainty/routines/pixel_regression.py +++ b/torch_uncertainty/routines/pixel_regression.py @@ -34,9 +34,8 @@ STEP_UPDATE_MODEL, ) from torch_uncertainty.utils.distributions import ( - dist_rearrange, - dist_size, - dist_squeeze, + get_dist_class, + get_dist_estimate, ) @@ -50,8 +49,9 @@ def __init__( self, model: nn.Module, output_dim: int, - probabilistic: bool, loss: nn.Module, + dist_family: str | None = None, + dist_estimate: str = "mean", is_ensemble: bool = False, format_batch_fn: nn.Module | None = None, optim_recipe: dict | Optimizer | None = None, @@ -59,14 +59,17 @@ def __init__( num_image_plot: int = 4, log_plots: bool = False, ) -> None: - """Routine for training & testing on **pixel regression** tasks. + r"""Routine for training & testing on **pixel regression** tasks. Args: model (nn.Module): Model to train. output_dim (int): Number of outputs of the model. - probabilistic (bool): Whether the model is probabilistic, i.e., - outputs a PyTorch distribution. loss (nn.Module): Loss function to optimize the :attr:`model`. + dist_family (str, optional): The distribution family to use for + probabilistic pixel regression. If ``None`` then point-wise regression. + Defaults to ``None``. + dist_estimate (str, optional): The estimate to use when computing the + point-wise metrics. Defaults to ``"mean"``. is_ensemble (bool, optional): Whether the model is an ensemble. Defaults to ``False``. 
optim_recipe (dict or Optimizer, optional): The optimizer and @@ -89,7 +92,9 @@ def __init__( self.model = model self.output_dim = output_dim self.one_dim_depth = output_dim == 1 - self.probabilistic = probabilistic + self.dist_family = dist_family + self.dist_estimate = dist_estimate + self.probabilistic = dist_family is not None self.loss = loss self.num_image_plot = num_image_plot self.is_ensemble = is_ensemble @@ -163,7 +168,7 @@ def forward(self, inputs: Tensor) -> Tensor | Distribution: pred = self.model(inputs) if self.probabilistic: if not self.is_ensemble: - pred = dist_squeeze(pred, -1) + pred = {k: v.squeeze(-1) for k, v in pred.items()} else: if not self.is_ensemble: pred = pred.squeeze(-1) @@ -174,39 +179,52 @@ def training_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> STEP_OU if self.one_dim_depth: target = target.unsqueeze(1) - dists = self.model(inputs) - out_shape = dist_size(dists)[-2:] if self.probabilistic else dists.shape[-2:] + out = self.model(inputs) + out_shape = out[next(iter(out))].shape[-2:] if self.probabilistic else out.shape[-2:] target = F.resize(target, out_shape, interpolation=F.InterpolationMode.NEAREST) - padding_mask = torch.isnan(target) + target = rearrange(target, "b c h w -> b h w c") + padding_mask = torch.isnan(target).any(dim=-1) if self.probabilistic: + dist_params = {k: rearrange(v, "b c h w -> b h w c") for k, v in out.items()} + # Adding the Independent wrapper to the distribution to compute correctly the + # log-likelihood given a target. Here the last dimension is the event dimension. + # When computing the log-likelihood, the values are summed over the event dimension. + dists = Independent(get_dist_class(self.dist_family)(**dist_params), 1) loss = self.loss(dists, target, padding_mask) else: - loss = self.loss(dists[padding_mask], target[padding_mask]) + out = rearrange(out, "b c h w -> b h w c") + loss = self.loss(out[padding_mask], target[padding_mask]) if self.needs_step_update: self.model.update_wrapper(self.current_epoch) self.log("train_loss", loss, prog_bar=True, logger=True) return loss + def evaluation_forward(self, inputs: Tensor) -> tuple[Tensor, Distribution | None]: + batch_size = inputs.size(0) + preds = self.model(inputs) + + if self.probabilistic: + dist_params = { + k: rearrange(v, "(m b) c h w -> b h w m c", b=batch_size) for k, v in preds.items() + } + # Adding the Independent wrapper to the distribution to create a MixtureSameFamily. + # As required by the torch.distributions API, the last dimension is the event dimension. 
+ comp = Independent(get_dist_class(self.dist_family)(**dist_params), 1) + mix = Categorical(torch.ones(comp.batch_shape, device=self.device)) + mixture = MixtureSameFamily(mix, comp) + preds = get_dist_estimate(comp, self.dist_estimate).mean(-2) + return preds, mixture + + preds = rearrange(preds, "(m b) c h w -> b m h w c", b=batch_size) + return preds.mean(dim=1), None + def validation_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> None: inputs, targets = batch if self.one_dim_depth: targets = targets.unsqueeze(1) - batch_size = targets.size(0) - targets = rearrange(targets, "b c h w -> (b c h w)") - preds = self.model(inputs) - - if self.probabilistic: - ens_dist = Independent( - dist_rearrange(preds, "(m b) c h w -> (b c h w) m", b=batch_size), - 0, - ) - mix = Categorical(torch.ones((dist_size(preds)[0] // batch_size), device=self.device)) - mixture = MixtureSameFamily(mix, ens_dist) - preds = mixture.mean - else: - preds = rearrange(preds, "(m b) c h w -> (b c h w) m", b=batch_size) - preds = preds.mean(dim=1) + targets = rearrange(targets, "b c h w -> b h w c") + preds, dist = self.evaluation_forward(inputs) if batch_idx == 0 and self.log_plots: self._plot_depth( @@ -216,10 +234,10 @@ def validation_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> None: stage="val", ) - padding_mask = torch.isnan(targets) + padding_mask = torch.isnan(targets).any(dim=-1) self.val_metrics.update(preds[padding_mask], targets[padding_mask]) - if self.probabilistic: - self.val_prob_metrics.update(mixture, targets, padding_mask) + if isinstance(dist, Distribution): + self.val_prob_metrics.update(dist, targets, padding_mask) def test_step( self, @@ -234,18 +252,8 @@ def test_step( inputs, targets = batch if self.one_dim_depth: targets = targets.unsqueeze(1) - batch_size = targets.size(0) - targets = rearrange(targets, "b c h w -> (b c h w)") - preds = self.model(inputs) - - if self.probabilistic: - ens_dist = dist_rearrange(preds, "(m b) c h w -> (b c h w) m", b=batch_size) - mix = Categorical(torch.ones((dist_size(preds)[0] // batch_size), device=self.device)) - mixture = MixtureSameFamily(mix, ens_dist) - preds = mixture.mean - else: - preds = rearrange(preds, "(m b) c h w -> (b c h w) m", b=batch_size) - preds = preds.mean(dim=1) + targets = rearrange(targets, "b c h w -> b h w c") + preds, dist = self.evaluation_forward(inputs) if batch_idx == 0 and self.log_plots: num_images = ( @@ -258,10 +266,10 @@ def test_step( stage="test", ) - padding_mask = torch.isnan(targets) + padding_mask = torch.isnan(targets).any(dim=-1) self.test_metrics.update(preds[padding_mask], targets[padding_mask]) - if self.probabilistic: - self.test_prob_metrics.update(mixture, targets, padding_mask) + if isinstance(dist, Distribution): + self.test_prob_metrics.update(dist, targets, padding_mask) def on_validation_epoch_end(self) -> None: res_dict = self.val_metrics.compute() diff --git a/torch_uncertainty/routines/regression.py b/torch_uncertainty/routines/regression.py index e51a126b..29bdae7c 100644 --- a/torch_uncertainty/routines/regression.py +++ b/torch_uncertainty/routines/regression.py @@ -6,6 +6,7 @@ from torch.distributions import ( Categorical, Distribution, + Independent, MixtureSameFamily, ) from torch.optim import Optimizer @@ -20,9 +21,8 @@ STEP_UPDATE_MODEL, ) from torch_uncertainty.utils.distributions import ( - dist_rearrange, - dist_size, - dist_squeeze, + get_dist_class, + get_dist_estimate, ) @@ -31,8 +31,9 @@ def __init__( self, model: nn.Module, output_dim: int, - probabilistic: 
bool, loss: nn.Module, + dist_family: str | None = None, + dist_estimate: str = "mean", is_ensemble: bool = False, optim_recipe: dict | Optimizer | None = None, eval_shift: bool = False, @@ -43,9 +44,12 @@ def __init__( Args: model (torch.nn.Module): Model to train. output_dim (int): Number of outputs of the model. - probabilistic (bool): Whether the model is probabilistic, i.e., - outputs a PyTorch distribution. loss (torch.nn.Module): Loss function to optimize the :attr:`model`. + dist_family (str, optional): The distribution family to use for + probabilistic regression. If ``None`` then point-wise regression. + Defaults to ``None``. + dist_estimate (str, optional): The estimate to use when computing the + point-wise metrics. Defaults to ``"mean"``. is_ensemble (bool, optional): Whether the model is an ensemble. Defaults to ``False``. optim_recipe (dict or torch.optim.Optimizer, optional): The optimizer and @@ -76,7 +80,9 @@ def __init__( ) self.model = model - self.probabilistic = probabilistic + self.dist_family = dist_family + self.dist_estimate = dist_estimate + self.probabilistic = dist_family is not None self.output_dim = output_dim self.loss = loss self.is_ensemble = is_ensemble @@ -141,10 +147,16 @@ def forward(self, inputs: Tensor) -> Tensor | Distribution: """ pred = self.model(inputs) if self.probabilistic: - if self.one_dim_regression: - pred = dist_squeeze(pred, -1) - if not self.is_ensemble: - pred = dist_squeeze(pred, -1) + if isinstance(pred, dict): + if self.one_dim_regression: + pred = {k: v.squeeze(-1) for k, v in pred.items()} + if not self.is_ensemble: + pred = {k: v.squeeze(-1) for k, v in pred.items()} + else: + raise TypeError( + "If the model is probabilistic, the output must be a dictionary ", + "of PyTorch distributions.", + ) else: if self.one_dim_regression: pred = pred.squeeze(-1) @@ -161,34 +173,50 @@ def training_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> STEP_OU if isinstance(self.loss, ELBOLoss): loss = self.loss(inputs, targets) else: - dists = self.model(inputs) - loss = self.loss(dists, targets) + out = self.model(inputs) + if self.probabilistic: + # Adding the Independent wrapper to the distribution to compute correctly the + # log-likelihood given a target. Here the last dimension is the event dimension. + # When computing the log-likelihood, the values are summed over the event + # dimension. + dists = Independent(get_dist_class(self.dist_family)(**out), 1) + loss = self.loss(dists, targets) + else: + loss = self.loss(out, targets) if self.needs_step_update: self.model.update_wrapper(self.current_epoch) self.log("train_loss", loss, prog_bar=True, logger=True) return loss + def evaluation_forward(self, inputs: Tensor) -> tuple[Tensor, Distribution | None]: + batch_size = inputs.size(0) + preds = self.model(inputs) + + if self.probabilistic: + dist_params = { + k: rearrange(v, "(m b) c -> b m c", b=batch_size) for k, v in preds.items() + } + # Adding the Independent wrapper to the distribution to create a MixtureSameFamily. + # As required by the torch.distributions API, the last dimension is the event dimension. 
+ comp = Independent(get_dist_class(self.dist_family)(**dist_params), 1) + mix = Categorical(torch.ones(comp.batch_shape, device=self.device)) + dist = MixtureSameFamily(mix, comp) + preds = get_dist_estimate(comp, self.dist_estimate).mean(1) + return preds, dist + + preds = rearrange(preds, "(m b) c -> b m c", b=batch_size) + return preds.mean(dim=1), None + def validation_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> None: inputs, targets = batch if self.one_dim_regression: targets = targets.unsqueeze(-1) - batch_size = targets.size(0) - targets = rearrange(targets, "b c -> (b c)") - preds = self.model(inputs) - - if self.probabilistic: - ens_dist = dist_rearrange(preds, "(m b) c -> (b c) m", b=batch_size) - mix = Categorical(torch.ones(dist_size(preds)[0] // batch_size, device=self.device)) - mixture = MixtureSameFamily(mix, ens_dist) - preds = mixture.mean - else: - preds = rearrange(preds, "(m b) c -> (b c) m", b=batch_size) - preds = preds.mean(dim=1) + preds, dist = self.evaluation_forward(inputs) self.val_metrics.update(preds, targets) - if self.probabilistic: - self.val_prob_metrics.update(mixture, targets) + if isinstance(dist, Distribution): + self.val_prob_metrics.update(dist, targets) def test_step( self, @@ -204,22 +232,11 @@ def test_step( inputs, targets = batch if self.one_dim_regression: targets = targets.unsqueeze(-1) - batch_size = targets.size(0) - targets = rearrange(targets, "b c -> (b c)") - preds = self.model(inputs) - - if self.probabilistic: - ens_dist = dist_rearrange(preds, "(m b) c -> (b c) m", b=batch_size) - mix = Categorical(torch.ones(dist_size(preds)[0] // batch_size, device=self.device)) - mixture = MixtureSameFamily(mix, ens_dist) - preds = mixture.mean - else: - preds = rearrange(preds, "(m b) c -> (b c) m", b=batch_size) - preds = preds.mean(dim=1) + preds, dist = self.evaluation_forward(inputs) self.test_metrics.update(preds, targets) - if self.probabilistic: - self.test_prob_metrics.update(mixture, targets) + if isinstance(dist, Distribution): + self.test_prob_metrics.update(dist, targets) def on_validation_epoch_end(self) -> None: res_dict = self.val_metrics.compute() diff --git a/torch_uncertainty/utils/distributions.py b/torch_uncertainty/utils/distributions.py index d3d161a8..3e115e93 100644 --- a/torch_uncertainty/utils/distributions.py +++ b/torch_uncertainty/utils/distributions.py @@ -1,102 +1,58 @@ from numbers import Number import torch -from einops import rearrange from torch import Tensor from torch.distributions import ( + Cauchy, Distribution, Laplace, Normal, + StudentT, constraints, ) from torch.distributions.utils import broadcast_all -def dist_size(distribution: Distribution) -> torch.Size: - """Get the size of the distribution. +def get_dist_class(dist_family: str) -> type[Distribution]: + """Get the distribution class from a string. Args: - distribution (Distribution): The distribution. + dist_family (str): The distribution family. Returns: - torch.Size: The size of the distribution. + type[Distribution]: The distribution class. """ - if isinstance(distribution, Normal | Laplace | NormalInverseGamma): - return distribution.loc.size() + if dist_family == "normal": + return Normal + if dist_family == "laplace": + return Laplace + if dist_family == "nig": + return NormalInverseGamma + if dist_family == "cauchy": + return Cauchy + if dist_family == "student": + return StudentT raise NotImplementedError( - f"Size of {type(distribution)} distributions is not supported." "Raise an issue if needed." 
+ f"{dist_family} distribution is not supported." "Raise an issue if needed." ) -def cat_dist(distributions: list[Distribution], dim: int) -> Distribution: - """Concatenate a list of distributions into a single distribution. +def get_dist_estimate(dist: Distribution, dist_estimate: str) -> Tensor: + """Get a point-wise prediction from a distribution. Args: - distributions (list[Distribution]): The list of distributions. - dim (int): The dimension to concatenate. + dist (Distribution): The distribution. + dist_estimate (str): The estimate to use. Returns: - Distribution: The concatenated distributions. + Tensor: The estimated value. """ - dist_type = type(distributions[0]) - if not all(isinstance(distribution, dist_type) for distribution in distributions): - raise ValueError("All distributions must have the same type.") - - if isinstance(distributions[0], Normal | Laplace): - locs = torch.cat([distribution.loc for distribution in distributions], dim=dim) - scales = torch.cat([distribution.scale for distribution in distributions], dim=dim) - return dist_type(loc=locs, scale=scales) - if isinstance(distributions[0], NormalInverseGamma): - locs = torch.cat([distribution.loc for distribution in distributions], dim=dim) - lmbdas = torch.cat([distribution.lmbda for distribution in distributions], dim=dim) - alphas = torch.cat([distribution.alpha for distribution in distributions], dim=dim) - betas = torch.cat([distribution.beta for distribution in distributions], dim=dim) - return NormalInverseGamma(loc=locs, lmbda=lmbdas, alpha=alphas, beta=betas) + if dist_estimate == "mean": + return dist.mean + if dist_estimate == "mode": + return dist.mode raise NotImplementedError( - f"Concatenation of {dist_type} distributions is not supported." "Raise an issue if needed." - ) - - -def dist_squeeze(distribution: Distribution, dim: int) -> Distribution: - """Squeeze the distribution along a given dimension. - - Args: - distribution (Distribution): The distribution to squeeze. - dim (int): The dimension to squeeze. - - Returns: - Distribution: The squeezed distribution. - """ - dist_type = type(distribution) - if isinstance(distribution, Normal | Laplace): - loc = distribution.loc.squeeze(dim) - scale = distribution.scale.squeeze(dim) - return dist_type(loc=loc, scale=scale) - if isinstance(distribution, NormalInverseGamma): - loc = distribution.loc.squeeze(dim) - lmbda = distribution.lmbda.squeeze(dim) - alpha = distribution.alpha.squeeze(dim) - beta = distribution.beta.squeeze(dim) - return NormalInverseGamma(loc=loc, lmbda=lmbda, alpha=alpha, beta=beta) - raise NotImplementedError( - f"Squeezing of {dist_type} distributions is not supported." "Raise an issue if needed." 
- ) - - -def dist_rearrange(distribution: Distribution, pattern: str, **axes_lengths: int) -> Distribution: - dist_type = type(distribution) - if isinstance(distribution, Normal | Laplace): - loc = rearrange(distribution.loc, pattern=pattern, **axes_lengths) - scale = rearrange(distribution.scale, pattern=pattern, **axes_lengths) - return dist_type(loc=loc, scale=scale) - if isinstance(distribution, NormalInverseGamma): - loc = rearrange(distribution.loc, pattern=pattern, **axes_lengths) - lmbda = rearrange(distribution.lmbda, pattern=pattern, **axes_lengths) - alpha = rearrange(distribution.alpha, pattern=pattern, **axes_lengths) - beta = rearrange(distribution.beta, pattern=pattern, **axes_lengths) - return NormalInverseGamma(loc=loc, lmbda=lmbda, alpha=alpha, beta=beta) - raise NotImplementedError( - f"Rearrange of {dist_type} is not supported. Raise an issue if needed." + f"{dist_estimate} estimate is not supported." "Raise an issue if needed." ) diff --git a/torch_uncertainty/utils/evaluation_loop.py b/torch_uncertainty/utils/evaluation_loop.py index dac1138c..d079cae4 100644 --- a/torch_uncertainty/utils/evaluation_loop.py +++ b/torch_uncertainty/utils/evaluation_loop.py @@ -7,6 +7,29 @@ from rich import get_console from rich.console import Group from rich.table import Table +from torch import Tensor + +PERCENTAGE_METRICS = [ + "Acc", + "AUPR", + "AUROC", + "FPR95", + "Cov@5Risk", + "Risk@80Cov", + "pixAcc", + "mIoU", + "AURC", + "AUGRC", + "mAcc", +] + + +def _add_row(table: Table, metric_name: str, value: Tensor) -> None: + if metric_name in PERCENTAGE_METRICS: + value = value * 100 + table.add_row(metric_name, f"{value.item():.2f}%") + else: + table.add_row(metric_name, f"{value.item():.5f}") class TUEvaluationLoop(_EvaluationLoop): @@ -20,21 +43,6 @@ def _print_results(results: list[_OUT_DICT], stage: str) -> None: # test/post: Post-Processing Metrics # test/seg: Segmentation Metrics - # In percentage - percentage_metrics = [ - "Acc", - "AUPR", - "AUROC", - "FPR95", - "Cov@5Risk", - "Risk@80Cov", - "pixAcc", - "mIoU", - "AURC", - "AUGRC", - "mAcc", - ] - metrics = {} for result in results: for key, value in result.items(): @@ -88,12 +96,8 @@ def _print_results(results: list[_OUT_DICT], stage: str) -> None: table.add_column(first_col_name, justify="center", style="cyan", width=12) table.add_column("Classification", justify="center", style="magenta", width=25) cls_metrics = OrderedDict(sorted(metrics["cls"].items())) - for metric, value in cls_metrics.items(): - if metric in percentage_metrics: - value = value * 100 - table.add_row(metric, f"{value.item():.2f}%") - else: - table.add_row(metric, f"{value.item():.5f}") + for metric_name, value in cls_metrics.items(): + _add_row(table, metric_name, value) tables.append(table) if "seg" in metrics: @@ -101,12 +105,8 @@ def _print_results(results: list[_OUT_DICT], stage: str) -> None: table.add_column(first_col_name, justify="center", style="cyan", width=12) table.add_column("Segmentation", justify="center", style="magenta", width=25) seg_metrics = OrderedDict(sorted(metrics["seg"].items())) - for metric, value in seg_metrics.items(): - if metric in percentage_metrics: - value = value * 100 - table.add_row(metric, f"{value.item():.2f}%") - else: - table.add_row(metric, f"{value.item():.5f}") + for metric_name, value in seg_metrics.items(): + _add_row(table, metric_name, value) tables.append(table) if "reg" in metrics: @@ -114,12 +114,8 @@ def _print_results(results: list[_OUT_DICT], stage: str) -> None: table.add_column(first_col_name, 
justify="center", style="cyan", width=12) table.add_column("Regression", justify="center", style="magenta", width=25) reg_metrics = OrderedDict(sorted(metrics["reg"].items())) - for metric, value in reg_metrics.items(): - if metric in percentage_metrics: # coverage: ignore - value = value * 100 - table.add_row(metric, f"{value.item():.2f}%") - else: - table.add_row(metric, f"{value.item():.5f}") + for metric_name, value in reg_metrics.items(): + _add_row(table, metric_name, value) tables.append(table) if "cal" in metrics: @@ -127,12 +123,8 @@ def _print_results(results: list[_OUT_DICT], stage: str) -> None: table.add_column(first_col_name, justify="center", style="cyan", width=12) table.add_column("Calibration", justify="center", style="magenta", width=25) cal_metrics = OrderedDict(sorted(metrics["cal"].items())) - for metric, value in cal_metrics.items(): - if metric in percentage_metrics: - value = value * 100 - table.add_row(metric, f"{value.item():.2f}%") - else: - table.add_row(metric, f"{value.item():.5f}") + for metric_name, value in cal_metrics.items(): + _add_row(table, metric_name, value) tables.append(table) if "ood" in metrics: @@ -140,12 +132,8 @@ def _print_results(results: list[_OUT_DICT], stage: str) -> None: table.add_column(first_col_name, justify="center", style="cyan", width=12) table.add_column("OOD Detection", justify="center", style="magenta", width=25) ood_metrics = OrderedDict(sorted(metrics["ood"].items())) - for metric, value in ood_metrics.items(): - if metric in percentage_metrics: - value = value * 100 - table.add_row(metric, f"{value.item():.2f}%") - else: - table.add_row(metric, f"{value.item():.5f}") + for metric_name, value in ood_metrics.items(): + _add_row(table, metric_name, value) tables.append(table) if "sc" in metrics: @@ -158,12 +146,8 @@ def _print_results(results: list[_OUT_DICT], stage: str) -> None: width=25, ) sc_metrics = OrderedDict(sorted(metrics["sc"].items())) - for metric, value in sc_metrics.items(): - if metric in percentage_metrics: - value = value * 100 - table.add_row(metric, f"{value.item():.2f}%") - else: - table.add_row(metric, f"{value.item():.5f}") + for metric_name, value in sc_metrics.items(): + _add_row(table, metric_name, value) tables.append(table) if "post" in metrics: @@ -171,12 +155,8 @@ def _print_results(results: list[_OUT_DICT], stage: str) -> None: table.add_column(first_col_name, justify="center", style="cyan", width=12) table.add_column("Post-Processing", justify="center", style="magenta", width=25) post_metrics = OrderedDict(sorted(metrics["post"].items())) - for metric, value in post_metrics.items(): - if metric in percentage_metrics: - value = value * 100 - table.add_row(metric, f"{value.item():.2f}%") - else: - table.add_row(metric, f"{value.item():.5f}") + for metric_name, value in post_metrics.items(): + _add_row(table, metric_name, value) tables.append(table) if "shift" in metrics: @@ -190,14 +170,10 @@ def _print_results(results: list[_OUT_DICT], stage: str) -> None: width=25, ) shift_metrics = OrderedDict(sorted(metrics["shift"].items())) - for metric, value in shift_metrics.items(): - if metric == "shift_severity": + for metric_name, value in shift_metrics.items(): + if metric_name == "shift_severity": continue - if metric in percentage_metrics: - value = value * 100 - table.add_row(metric, f"{value.item():.2f}%") - else: - table.add_row(metric, f"{value.item():.5f}") + _add_row(table, metric_name, value) tables.append(table) console = get_console()
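The diffs above replace distribution-returning models with models that return dictionaries of distribution parameters, which are rebuilt on demand via get_dist_class and wrapped in Independent before computing log-likelihoods. To make that workflow concrete, the following sketch (not part of the patch) shows the pieces fitting together end to end. The import paths, the loc/scale key names, and the hidden-layer sizes are assumptions inferred from the modified files; dist_family="normal" is one of the families handled by get_dist_class in this patch.

import torch
from torch.distributions import Independent

from torch_uncertainty.models.mlp import mlp
from torch_uncertainty.utils.distributions import get_dist_class, get_dist_estimate

# Build an MLP whose final layer emits Normal distribution parameters
# instead of point predictions (dist_family=None would give the latter).
model = mlp(in_features=8, num_outputs=1, hidden_dims=[50], dist_family="normal")

x = torch.randn(32, 8)
dist_params = model(x)  # dict of tensors, assumed here to be {"loc": ..., "scale": ...}

# Rebuild the distribution and wrap it in Independent so that log_prob sums
# over the event (output) dimension, as the routines above do during training.
dist = Independent(get_dist_class("normal")(**dist_params), 1)
targets = torch.randn(32, 1)
nll = -dist.log_prob(targets).mean()

# Point-wise prediction for the deterministic metrics (here the mean).
point_pred = get_dist_estimate(dist, "mean")

Keeping the model output as a plain dict of tensors is also what lets deep_ensembles concatenate estimator outputs with torch.cat and lets the routines reshape them with rearrange, which is why the cat_dist, dist_squeeze, and dist_rearrange helpers could be removed from torch_uncertainty/utils/distributions.py.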