Skip to content

Commit

Permalink
Merge pull request #116 from ENSTA-U2IS-AI/dev
Browse files Browse the repository at this point in the history
👕 Complete metric overhaul, improve PP handling & fix Laplace
  • Loading branch information
o-laurent authored Oct 2, 2024
2 parents 7df97be + 60fbe40 commit 74e641a
Show file tree
Hide file tree
Showing 7 changed files with 209 additions and 132 deletions.
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
f"{datetime.now().year!s}, Adrien Lafage and Olivier Laurent"
)
author = "Adrien Lafage and Olivier Laurent"
release = "0.2.2.post1"
release = "0.2.2.post2"

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"

[project]
name = "torch_uncertainty"
version = "0.2.2.post1"
version = "0.2.2.post2"
authors = [
{ name = "ENSTA U2IS", email = "olivier.laurent@ensta-paris.fr" },
{ name = "Adrien Lafage", email = "adrienlafage@outlook.com" },
Expand All @@ -18,7 +18,6 @@ keywords = [
"ensembles",
"neural-networks",
"predictive-uncertainty",
"pytorch",
"reliable-ai",
"trustworthy-machine-learning",
"uncertainty",
Expand All @@ -44,6 +43,7 @@ dependencies = [
"numpy<2",
"opencv-python",
"glest==0.0.1a0",
"rich>=10.2.2",
]

[project.optional-dependencies]
Expand Down
4 changes: 3 additions & 1 deletion torch_uncertainty/post_processing/laplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def __init__(
self.weight_subset = weight_subset
self.hessian_struct = hessian_struct
self.batch_size = batch_size
self.optimize_prior_precision = optimize_prior_precision

if model is not None:
self.set_model(model)
Expand All @@ -80,7 +81,8 @@ def set_model(self, model: nn.Module) -> None:
def fit(self, dataset: Dataset) -> None:
dl = DataLoader(dataset, batch_size=self.batch_size)
self.la.fit(train_loader=dl)
self.la.optimize_prior_precision(method="marglik")
if self.optimize_prior_precision:
self.la.optimize_prior_precision(method="marglik")

def forward(
self,
Expand Down
16 changes: 8 additions & 8 deletions torch_uncertainty/routines/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,23 +196,23 @@ def _init_metrics(self) -> None:
),
"sc/AURC": AURC(),
"sc/AUGRC": AUGRC(),
"sc/CovAt5Risk": CovAt5Risk(),
"sc/RiskAt80Cov": RiskAt80Cov(),
"sc/Cov@5Risk": CovAt5Risk(),
"sc/Risk@80Cov": RiskAt80Cov(),
},
compute_groups=[
["cls/Acc"],
["cls/Brier"],
["cls/NLL"],
["cal/ECE", "cal/aECE"],
["sc/AURC", "sc/AUGRC", "sc/CovAt5Risk", "sc/RiskAt80Cov"],
["sc/AURC", "sc/AUGRC", "sc/Cov@5Risk", "sc/Risk@80Cov"],
],
)

self.val_cls_metrics = cls_metrics.clone(prefix="val/")
self.test_cls_metrics = cls_metrics.clone(prefix="test/")

if self.post_processing is not None:
self.ts_cls_metrics = cls_metrics.clone(prefix="test/ts_")
self.post_cls_metrics = cls_metrics.clone(prefix="test/post/")

self.test_id_entropy = Entropy()

Expand Down Expand Up @@ -463,7 +463,7 @@ def test_step(
)
self.test_id_entropy(probs)
self.log(
"test/cls/entropy",
"test/cls/Entropy",
self.test_id_entropy,
on_epoch=True,
add_dataloader_idx=False,
Expand All @@ -486,7 +486,7 @@ def test_step(
pp_probs = F.softmax(pp_logits, dim=-1)
else:
pp_probs = pp_logits
self.ts_cls_metrics.update(pp_probs, targets)
self.post_cls_metrics.update(pp_probs, targets)

elif self.eval_ood and dataloader_idx == 1:
self.test_ood_metrics.update(ood_scores, torch.ones_like(targets))
Expand Down Expand Up @@ -529,7 +529,7 @@ def on_test_epoch_end(self) -> None:
)

if self.post_processing is not None:
tmp_metrics = self.ts_cls_metrics.compute()
tmp_metrics = self.post_cls_metrics.compute()
self.log_dict(tmp_metrics, sync_dist=True)
result_dict.update(tmp_metrics)

Expand Down Expand Up @@ -573,7 +573,7 @@ def on_test_epoch_end(self) -> None:
if self.post_processing is not None:
self.logger.experiment.add_figure(
"Reliabity diagram after calibration",
self.ts_cls_metrics["cal/ECE"].plot()[0],
self.post_cls_metrics["cal/ECE"].plot()[0],
)

# plot histograms of logits and likelihoods
Expand Down
24 changes: 12 additions & 12 deletions torch_uncertainty/routines/pixel_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,17 +99,17 @@ def __init__(

depth_metrics = MetricCollection(
{
"SILog": SILog(),
"log10": Log10(),
"ARE": MeanGTRelativeAbsoluteError(),
"RSRE": MeanGTRelativeSquaredError(squared=False),
"RMSE": MeanSquaredError(squared=False),
"RMSELog": MeanSquaredLogError(squared=False),
"iMAE": MeanAbsoluteErrorInverse(),
"iRMSE": MeanSquaredErrorInverse(squared=False),
"d1": ThresholdAccuracy(power=1),
"d2": ThresholdAccuracy(power=2),
"d3": ThresholdAccuracy(power=3),
"reg/SILog": SILog(),
"reg/log10": Log10(),
"reg/ARE": MeanGTRelativeAbsoluteError(),
"reg/RSRE": MeanGTRelativeSquaredError(squared=False),
"reg/RMSE": MeanSquaredError(squared=False),
"reg/RMSELog": MeanSquaredLogError(squared=False),
"reg/iMAE": MeanAbsoluteErrorInverse(),
"reg/iRMSE": MeanSquaredErrorInverse(squared=False),
"reg/d1": ThresholdAccuracy(power=1),
"reg/d2": ThresholdAccuracy(power=2),
"reg/d3": ThresholdAccuracy(power=3),
},
compute_groups=False,
)
Expand All @@ -119,7 +119,7 @@ def __init__(

if self.probabilistic:
depth_prob_metrics = MetricCollection(
{"NLL": DistributionNLL(reduction="mean")}
{"reg/NLL": DistributionNLL(reduction="mean")}
)
self.val_prob_metrics = depth_prob_metrics.clone(prefix="val/")
self.test_prob_metrics = depth_prob_metrics.clone(prefix="test/")
Expand Down
8 changes: 4 additions & 4 deletions torch_uncertainty/routines/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,9 @@ def __init__(

reg_metrics = MetricCollection(
{
"MAE": MeanAbsoluteError(),
"MSE": MeanSquaredError(squared=True),
"RMSE": MeanSquaredError(squared=False),
"reg/MAE": MeanAbsoluteError(),
"reg/MSE": MeanSquaredError(squared=True),
"reg/RMSE": MeanSquaredError(squared=False),
},
compute_groups=True,
)
Expand All @@ -96,7 +96,7 @@ def __init__(

if self.probabilistic:
reg_prob_metrics = MetricCollection(
{"NLL": DistributionNLL(reduction="mean")}
{"reg/NLL": DistributionNLL(reduction="mean")}
)
self.val_prob_metrics = reg_prob_metrics.clone(prefix="val/")
self.test_prob_metrics = reg_prob_metrics.clone(prefix="test/")
Expand Down
Loading

0 comments on commit 74e641a

Please sign in to comment.