Merge pull request #126 from ENSTA-U2IS-AI/dev
🎨 Rework Probabilistic Regression and add Packed-Transformer layers
o-laurent authored Jan 6, 2025
2 parents c7d870d + c16d53e commit 3a021d2
Showing 74 changed files with 4,771 additions and 690 deletions.
11 changes: 10 additions & 1 deletion .github/workflows/run-tests.yml
@@ -76,7 +76,7 @@ jobs:
- name: Test with pytest and compute coverage
if: steps.changed-files-specific.outputs.only_changed != 'true'
run: |
python3 -m pytest --cov --cov-report xml --durations 10
python3 -m pytest --cov --cov-report xml --durations 10 --junitxml=junit.xml
- name: Upload coverage to Codecov
if: steps.changed-files-specific.outputs.only_changed != 'true' && (github.event_name != 'pull_request' || github.base_ref == 'dev')
@@ -89,6 +89,15 @@ jobs:
name: CPU-coverage
env_vars: PYTHON_VERSION

- name: Upload test results to Codecov
if: steps.changed-files-specific.outputs.only_changed != 'true' && (github.event_name != 'pull_request' || github.base_ref == 'dev')
uses: codecov/test-results-action@v1
continue-on-error: true
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: cpu,pytest
env_vars: PYTHON_VERSION

- name: Test sphinx build without tutorials
if: steps.changed-files-specific.outputs.only_changed != 'true'
run: |
18 changes: 9 additions & 9 deletions auto_tutorials_source/tutorial_der_cubic.py
@@ -39,7 +39,8 @@
from torch_uncertainty.datasets.regression.toy import Cubic
from torch_uncertainty.losses import DERLoss
from torch_uncertainty.routines import RegressionRoutine
from torch_uncertainty.layers.distributions import NormalInverseGammaLayer
from torch_uncertainty.layers.distributions import NormalInverseGammaLinear
from torch_uncertainty.utils.distributions import get_dist_class

# %%
# 2. The Optimization Recipe
@@ -50,12 +51,11 @@ def optim_regression(
model: nn.Module,
learning_rate: float = 5e-4,
):
optimizer = optim.Adam(
return optim.Adam(
model.parameters(),
lr=learning_rate,
weight_decay=0,
)
return optimizer


# %%
@@ -64,7 +64,7 @@ def optim_regression(
#
# In the following, we create a trainer to train the model, the same synthetic regression
# datasets as in the original DER paper, and the model: a simple MLP with 2 hidden layers of 64 neurons each.
# Please note that this MLP finishes with a NormalInverseGammaLayer that interpret the outputs of the model
# Please note that this MLP finishes with a NormalInverseGammaLinear layer that interprets the outputs of the model
# as the parameters of a Normal Inverse Gamma distribution.

trainer = TUTrainer(accelerator="cpu", max_epochs=50) #, enable_progress_bar=False)
@@ -82,10 +82,9 @@ def optim_regression(
# model
model = mlp(
in_features=1,
num_outputs=4,
num_outputs=1,
hidden_dims=[64, 64],
final_layer=NormalInverseGammaLayer,
final_layer_args={"dim": 1},
dist_family="nig", # Normal Inverse Gamma
)

# %%
@@ -100,11 +99,11 @@ def optim_regression(
loss = DERLoss(reg_weight=1e-2)

routine = RegressionRoutine(
probabilistic=True,
output_dim=1,
model=model,
loss=loss,
optim_recipe=optim_regression(model),
dist_family="nig",
)

# %%
@@ -127,7 +126,8 @@ def optim_regression(
with torch.no_grad():
x = torch.linspace(-7, 7, 1000)

dists = model(x.unsqueeze(-1))
dist_params = model(x.unsqueeze(-1))
dists = get_dist_class("nig")(**dist_params)
means = dists.loc.squeeze(1)
variances = torch.sqrt(dists.variance_loc).squeeze(1)

2 changes: 1 addition & 1 deletion auto_tutorials_source/tutorial_from_de_to_pe.py
@@ -29,14 +29,14 @@
The dataset is automatically downloaded using torchvision. We then visualize a few images to get a sense of what we are working with.
"""
# Create the transforms for the images
# %%
import torch
import torchvision.transforms as T

# We set the number of epochs to some low value for the sake of time
max_epochs = 2

# Create the transforms for the images
train_transform = T.Compose(
[
T.ToTensor(),
170 changes: 170 additions & 0 deletions auto_tutorials_source/tutorial_probabilistic_regression.py
@@ -0,0 +1,170 @@
"""
Deep Probabilistic Regression
=============================
This tutorial aims to provide an overview of some utilities in TorchUncertainty for probabilistic regression.
Building a MLP for Probabilistic Regression using TorchUncertainty distribution layers
--------------------------------------------------------------------------------------
In this section we cover the building of a very simple MLP outputting Normal distribution parameters.
1. Loading the utilities
~~~~~~~~~~~~~~~~~~~~~~~~
We disable some logging and warnings to keep the output clean.
"""
# %%
import torch
from torch import nn

import logging
logging.getLogger("lightning.pytorch.utilities.rank_zero").setLevel(logging.WARNING)

import warnings
warnings.filterwarnings("ignore")

# %%
# 2. Building the MLP model
# ~~~~~~~~~~~~~~~~~~~~~~~~~
#
# To create an MLP model estimating a Normal distribution, we use the NormalLinear layer.
# This layer wraps nn.Linear and outputs the location and scale of a Normal distribution.
# Note that any other distribution layer from TorchUncertainty can be used in the same way.
from torch_uncertainty.layers.distributions import NormalLinear


class MLP(nn.Module):
def __init__(self, in_features: int, out_features: int):
super().__init__()
self.fc1 = nn.Linear(in_features, 50)
self.fc2 = NormalLinear(
base_layer=nn.Linear,
event_dim=out_features,
in_features=50,
)

def forward(self, x):
x = torch.relu(self.fc1(x))
return self.fc2(x)
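
# %%
# As a quick sanity check, we can run a forward pass on random data. This is a minimal
# sketch under two assumptions (both suggested by the other tutorials in this release):
# the distribution layers return a dictionary of distribution parameters, and
# ``get_dist_class`` accepts the "normal" family string to recover the matching
# ``torch.distributions`` class.
from torch_uncertainty.utils.distributions import get_dist_class

check_model = MLP(in_features=8, out_features=1)
with torch.no_grad():
    # dict of Normal parameters predicted for a random batch of 4 samples
    params = check_model(torch.randn(4, 8))
    dist = get_dist_class("normal")(**params)
print(type(dist).__name__, dist.mean.shape)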

# %%
# 3. Setting up the data
# ~~~~~~~~~~~~~~~~~~~~~~
#
# We use the UCI Kin8nm dataset, which is a regression dataset with 8 features and 8192 samples.
from torch_uncertainty.datamodules import UCIRegressionDataModule

# datamodule
datamodule = UCIRegressionDataModule(
root="data",
batch_size=32,
dataset_name="kin8nm",
)

# %%
# 4. Setting up the model and trainer
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

from torch_uncertainty import TUTrainer

trainer = TUTrainer(
accelerator="cpu",
max_epochs=5,
enable_progress_bar=False,
)

model = MLP(in_features=8, out_features=1)


# %%
# 5. The Loss, the Optimizer and the Training Routine
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# We use the DistributionNLLLoss to compute the negative log-likelihood of the Normal distribution.
# Note that this loss can be used with any Distribution from torch.distributions.
# For the optimizer, we use the Adam optimizer with a learning rate of 5e-3.
# Finally, we create a RegressionRoutine to train the model. We indicate that the output dimension is 1 and the distribution family is "normal".

from torch_uncertainty.losses import DistributionNLLLoss
from torch_uncertainty.routines import RegressionRoutine

loss = DistributionNLLLoss()

def optim_regression(
model: nn.Module,
learning_rate: float = 5e-3,
):
return torch.optim.Adam(
model.parameters(),
lr=learning_rate,
weight_decay=0,
)

routine = RegressionRoutine(
output_dim=1,
model=model,
loss=loss,
optim_recipe=optim_regression(model),
dist_family="normal",
)
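
# %%
# As a quick illustration of what this loss computes, here is a sketch on a toy
# distribution. Following the note above that the loss works with any Distribution from
# torch.distributions, we assume it can be called directly on a distribution and a
# target tensor; it then returns the negative log-likelihood of the targets under the
# predicted distribution.
from torch.distributions import Normal

toy_dist = Normal(loc=torch.zeros(3, 1), scale=torch.ones(3, 1))
toy_targets = torch.zeros(3, 1)
# For a standard Normal evaluated at its mean, the NLL is 0.5 * log(2 * pi) per sample.
print(loss(toy_dist, toy_targets))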


# %%
# 6. Training the model
# ~~~~~~~~~~~~~~~~~~~~~~

trainer.fit(model=routine, datamodule=datamodule)
results = trainer.test(model=routine, datamodule=datamodule)
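
# %%
# ``results`` follows the usual Lightning convention: a list with one dictionary of test
# metrics per test dataloader. A minimal way to display it:
for metric_name, metric_value in results[0].items():
    print(f"{metric_name}: {metric_value:.4f}")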

# %%
# 7. Benchmarking different distributions
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Our MLP model outputs the parameters of a Normal distribution. However, we may be interested in comparing the performance of different distributions.
# TorchUncertainty provides a simple way to do this using the get_dist_linear_layer() function.
# Let us rewrite the MLP model to use it.

from torch_uncertainty.layers.distributions import get_dist_linear_layer

class MLP(nn.Module):
def __init__(self, in_features: int, out_features: int, dist_family: str):
super().__init__()
self.fc1 = nn.Linear(in_features, 50)
dist_layer = get_dist_linear_layer(dist_family)
self.fc2 = dist_layer(
base_layer=nn.Linear,
event_dim=out_features,
in_features=50,
)

def forward(self, x):
x = torch.relu(self.fc1(x))
return self.fc2(x)

# %%
# We can now train the model with different distributions.
# Let us train the model with a Normal, Laplace, Student's t, and Cauchy distribution.
# Note that we use the mode as the point-wise estimate of the distribution, since the mean
# is not defined for the Cauchy distribution.
for dist_family in ["normal", "laplace", "student", "cauchy"]:
print("#" * 50)
print(f">>> Training with {dist_family} distribution")
print("#" * 50)
trainer = TUTrainer(
accelerator="cpu",
max_epochs=10,
enable_model_summary=False,
enable_progress_bar=False,
)
model = MLP(in_features=8, out_features=1, dist_family=dist_family)
routine = RegressionRoutine(
output_dim=1,
model=model,
loss=loss,
optim_recipe=optim_regression(model),
dist_family=dist_family,
dist_estimate="mode",
)
trainer.fit(model=routine, datamodule=datamodule)
trainer.test(model=routine, datamodule=datamodule)
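
# %%
# To compare the families side by side, the loop above can be adapted to keep the
# metrics returned by ``trainer.test`` (a sketch; the exact metric names are whatever
# the regression routine logs, such as the test NLL and MSE):
#
# .. code-block:: python
#
#     all_results = {}
#     for dist_family in ["normal", "laplace", "student", "cauchy"]:
#         # ... build the trainer, model, and routine as above, then:
#         all_results[dist_family] = trainer.test(model=routine, datamodule=datamodule)[0]
#
#     for family, metrics in all_results.items():
#         print(family, {name: round(value, 4) for name, value in metrics.items()})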
38 changes: 38 additions & 0 deletions docs/source/api.rst
@@ -125,6 +125,10 @@ Ensemble layers

PackedLinear
PackedConv2d
PackedMultiheadAttention
PackedLayerNorm
PackedTransformerEncoderLayer
PackedTransformerDecoderLayer
BatchLinear
BatchConv2d
MaskedLinear
@@ -148,6 +152,40 @@ Bayesian layers
LPBNNLinear
LPBNNConv2d


Density layers
^^^^^^^^^^^^^^

.. currentmodule:: torch_uncertainty.layers.distributions

Linear Layers
"""""""""""""

.. autosummary::
:toctree: generated/
:nosignatures:
:template: class.rst

NormalLinear
LaplaceLinear
CauchyLinear
StudentTLinear
NormalInverseGammaLinear

Convolution Layers
""""""""""""""""""

.. autosummary::
:toctree: generated/
:nosignatures:
:template: class.rst

NormalConvNd
LaplaceConvNd
CauchyConvNd
StudentTConvNd
NormalInverseGammaConvNd
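
A minimal usage sketch for the linear variants (the call signature follows the
probabilistic regression tutorial; these layers are assumed to return a dictionary of
distribution parameters rather than a ``Distribution`` object):

.. code-block:: python

    from torch import nn
    from torch_uncertainty.layers.distributions import NormalLinear

    # Normal head predicting the location and scale of a 1-dimensional target
    head = NormalLinear(base_layer=nn.Linear, event_dim=1, in_features=50)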

Models
------

2 changes: 1 addition & 1 deletion docs/source/conf.py
@@ -15,7 +15,7 @@
f"{datetime.now().year!s}, Adrien Lafage and Olivier Laurent"
)
author = "Adrien Lafage and Olivier Laurent"
release = "0.3.1"
release = "0.4.0"

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
@@ -0,0 +1,43 @@
# lightning.pytorch==2.1.3
seed_everything: false
eval_after_fit: true
trainer:
accelerator: gpu
devices: 1
precision: 16-mixed
max_epochs: 40
logger:
class_path: lightning.pytorch.loggers.TensorBoardLogger
init_args:
save_dir: logs/boston/mlp/laplace
name: standard
default_hp_metric: false
callbacks:
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
init_args:
monitor: val/reg/NLL
mode: min
save_last: true
- class_path: lightning.pytorch.callbacks.LearningRateMonitor
init_args:
logging_interval: step
- class_path: lightning.pytorch.callbacks.EarlyStopping
init_args:
monitor: val/reg/NLL
patience: 1000
check_finite: true
model:
output_dim: 1
in_features: 13
hidden_dims:
- 50
loss: torch_uncertainty.losses.DistributionNLLLoss
version: std
distribution: laplace
data:
root: ./data
batch_size: 128
dataset_name: boston
optimizer:
lr: 5e-3
weight_decay: 0