From b217e4ed983bf1e0406d93ad2904f352c3b722df Mon Sep 17 00:00:00 2001 From: alafage Date: Thu, 30 Nov 2023 16:24:00 +0100 Subject: [PATCH 001/148] :tada: Initial commit for segmentation support --- .../datasets/segmentation/__init__.py | 0 .../datasets/segmentation/camvid.py | 40 +++++++++++++++++++ 2 files changed, 40 insertions(+) create mode 100644 torch_uncertainty/datasets/segmentation/__init__.py create mode 100644 torch_uncertainty/datasets/segmentation/camvid.py diff --git a/torch_uncertainty/datasets/segmentation/__init__.py b/torch_uncertainty/datasets/segmentation/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/torch_uncertainty/datasets/segmentation/camvid.py b/torch_uncertainty/datasets/segmentation/camvid.py new file mode 100644 index 00000000..808022b2 --- /dev/null +++ b/torch_uncertainty/datasets/segmentation/camvid.py @@ -0,0 +1,40 @@ +from collections.abc import Callable +from typing import NamedTuple + +from torchvision.datasets import VisionDataset + + +class CamVidClass(NamedTuple): + name: str + index: int + color: tuple[int, int, int] + + +class CamVid(VisionDataset): + # Notes: some classes are not used here + classes = [ + CamVidClass("sky", 0, (128, 128, 128)), + CamVidClass("building", 1, (128, 0, 0)), + CamVidClass("pole", 2, (192, 192, 128)), + CamVidClass("road_marking", 3, (255, 69, 0)), + CamVidClass("road", 4, (128, 64, 128)), + CamVidClass("pavement", 5, (60, 40, 222)), + CamVidClass("tree", 6, (128, 128, 0)), + CamVidClass("sign_symbol", 7, (192, 128, 128)), + CamVidClass("fence", 8, (64, 64, 128)), + CamVidClass("car", 9, (64, 0, 128)), + CamVidClass("pedestrian", 10, (64, 64, 0)), + CamVidClass("bicyclist", 11, (0, 128, 192)), + CamVidClass("unlabelled", 12, (0, 0, 0)), + ] + + def __init__( + self, + root: str, + split: str = "train", + transform: Callable | None = None, + target_transform: Callable | None = None, + transforms: Callable | None = None, + ) -> None: + """`CamVid `_ Dataset.""" + super().__init__(root, transforms, transform, target_transform) From 016421e8ec70d579be5f34c29a486747e5d1d1b0 Mon Sep 17 00:00:00 2001 From: alafage Date: Thu, 30 Nov 2023 17:39:36 +0100 Subject: [PATCH 002/148] :sparkles: CamVid segmentation dataset support --- tests/datasets/segmentation/__init__.py | 0 tests/datasets/segmentation/test_camvid.py | 11 ++ .../datasets/segmentation/__init__.py | 2 + .../datasets/segmentation/camvid.py | 114 +++++++++++++++++- 4 files changed, 122 insertions(+), 5 deletions(-) create mode 100644 tests/datasets/segmentation/__init__.py create mode 100644 tests/datasets/segmentation/test_camvid.py diff --git a/tests/datasets/segmentation/__init__.py b/tests/datasets/segmentation/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/datasets/segmentation/test_camvid.py b/tests/datasets/segmentation/test_camvid.py new file mode 100644 index 00000000..2777ad4d --- /dev/null +++ b/tests/datasets/segmentation/test_camvid.py @@ -0,0 +1,11 @@ +import pytest + +from torch_uncertainty.datasets.segmentation import CamVid + + +class TestCamVid: + """Testing the CamVid dataset class.""" + + def test_nodataset(self): + with pytest.raises(RuntimeError): + _ = CamVid("./.data") diff --git a/torch_uncertainty/datasets/segmentation/__init__.py b/torch_uncertainty/datasets/segmentation/__init__.py index e69de29b..90f7bad4 100644 --- a/torch_uncertainty/datasets/segmentation/__init__.py +++ b/torch_uncertainty/datasets/segmentation/__init__.py @@ -0,0 +1,2 @@ +# ruff: noqa: F401 +from .camvid import CamVid diff --git a/torch_uncertainty/datasets/segmentation/camvid.py b/torch_uncertainty/datasets/segmentation/camvid.py index 808022b2..5c370353 100644 --- a/torch_uncertainty/datasets/segmentation/camvid.py +++ b/torch_uncertainty/datasets/segmentation/camvid.py @@ -1,7 +1,12 @@ +import shutil from collections.abc import Callable +from pathlib import Path from typing import NamedTuple +from PIL import Image +from torchvision import tv_tensors from torchvision.datasets import VisionDataset +from torchvision.datasets.utils import download_and_extract_archive class CamVidClass(NamedTuple): @@ -28,13 +33,112 @@ class CamVid(VisionDataset): CamVidClass("unlabelled", 12, (0, 0, 0)), ] + urls = { + "raw": "http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/files/701_StillsRaw_full.zip", + "label": "http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/data/LabeledApproved_full.zip", + } + filenames = { + "raw": "701_StillsRaw_full.zip", + "label": "LabeledApproved_full.zip", + } + base_folder = "camvid" + num_samples = 701 + def __init__( self, root: str, - split: str = "train", - transform: Callable | None = None, - target_transform: Callable | None = None, transforms: Callable | None = None, + download: bool = False, ) -> None: - """`CamVid `_ Dataset.""" - super().__init__(root, transforms, transform, target_transform) + """`CamVid `_ Dataset. + + Args: + root (str): Root directory of dataset where ``camvid/`` exists or + will be saved to if download is set to ``True``. + transforms (callable, optional): A function/transform that takes + input sample and its target as entry and returns a transformed + version. + download (bool, optional): If true, downloads the dataset from the + internet and puts it in root directory. If dataset is already + downloaded, it is not downloaded again. + """ + super().__init__(root, transforms, None, None) + + if download: + self.download() + + if not self._check_integrity(): + raise RuntimeError( + "Dataset not found or corrupted. " + "You can use download=True to download it" + ) + + self.images = sorted((Path(self.root) / "camvid" / "raw").glob("*.png")) + self.targets = sorted( + (Path(self.root) / "camvid" / "label").glob("*.png") + ) + + def __getitem__(self, index: int) -> tuple: + """Get image and target at index. + + Args: + index (int): Index + + Returns: + tuple: (image, target) where target is the segmentation mask. + """ + image = tv_tensors.Image(Image.open(self.images[index]).convert("RGB")) + target = tv_tensors.Mask(Image.open(self.targets[index])) + + if self.transforms is not None: + image, target = self.transforms(image, target) + + return image, target + + def __len__(self) -> int: + """Return the number of samples.""" + return self.num_samples + + def _check_integrity(self) -> bool: + """Check if the dataset exists.""" + if ( + len(list((Path(self.root) / "camvid" / "raw").glob("*.png"))) + != self.num_samples + ): + return False + if ( + len(list((Path(self.root) / "camvid" / "label").glob("*.png"))) + != self.num_samples + ): + return False + return True + + def download(self) -> None: + """Download the CamVid data if it doesn't exist already.""" + if self._check_integrity(): + print("Files already downloaded and verified") + return + + if Path(self.root) / self.base_folder: + shutil.rmtree(Path(self.root) / self.base_folder) + + download_and_extract_archive( + self.urls["raw"], + self.root, + extract_root=Path(self.root) / "camvid", + filename=self.filenames["raw"], + ) + (Path(self.root) / "camvid" / "701_StillsRaw_full").replace( + Path(self.root) / "camvid" / "raw" + ) + download_and_extract_archive( + self.urls["label"], + self.root, + extract_root=Path(self.root) / "camvid" / "label", + filename=self.filenames["label"], + ) + + +if __name__ == "__main__": + dataset = CamVid("data", download=True) + print(dataset) From f26c9bc9772af186aa5052354326a0be83854352 Mon Sep 17 00:00:00 2001 From: alafage Date: Fri, 8 Dec 2023 17:33:18 +0100 Subject: [PATCH 003/148] :hammer: Add split argument to CamVid dataset Following splits from https://github.com/alexgkendall/SegNet-Tutorial --- .../datasets/segmentation/camvid.py | 63 +- .../datasets/segmentation/camvid_splits.json | 709 ++++++++++++++++++ 2 files changed, 762 insertions(+), 10 deletions(-) create mode 100644 torch_uncertainty/datasets/segmentation/camvid_splits.json diff --git a/torch_uncertainty/datasets/segmentation/camvid.py b/torch_uncertainty/datasets/segmentation/camvid.py index 5c370353..e95b4205 100644 --- a/torch_uncertainty/datasets/segmentation/camvid.py +++ b/torch_uncertainty/datasets/segmentation/camvid.py @@ -1,7 +1,8 @@ +import json import shutil from collections.abc import Callable from pathlib import Path -from typing import NamedTuple +from typing import Literal, NamedTuple from PIL import Image from torchvision import tv_tensors @@ -42,11 +43,17 @@ class CamVid(VisionDataset): "label": "LabeledApproved_full.zip", } base_folder = "camvid" - num_samples = 701 + num_samples = { + "train": 367, + "val": 101, + "test": 233, + "all": 701, + } def __init__( self, root: str, + split: Literal["train", "val", "test"] | None = None, transforms: Callable | None = None, download: bool = False, ) -> None: @@ -55,6 +62,8 @@ def __init__( Args: root (str): Root directory of dataset where ``camvid/`` exists or will be saved to if download is set to ``True``. + split (str, optional): The dataset split, supports ``train``, + ``val`` and ``test``. Default: ``None``. transforms (callable, optional): A function/transform that takes input sample and its target as entry and returns a transformed version. @@ -62,6 +71,12 @@ def __init__( internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. """ + if split not in ["train", "val", "test", None]: + raise ValueError( + f"Unknown split '{split}'. " + "Supported splits are ['train', 'val', 'test', None]" + ) + super().__init__(root, transforms, None, None) if download: @@ -73,10 +88,38 @@ def __init__( "You can use download=True to download it" ) - self.images = sorted((Path(self.root) / "camvid" / "raw").glob("*.png")) - self.targets = sorted( - (Path(self.root) / "camvid" / "label").glob("*.png") - ) + # get filenames for split + if split is None: + self.images = sorted( + (Path(self.root) / "camvid" / "raw").glob("*.png") + ) + self.targets = sorted( + (Path(self.root) / "camvid" / "label").glob("*.png") + ) + else: + with (Path(__file__).parent / "camvid_splits.json").open() as f: + filenames = json.load(f)[split] + + self.images = sorted( + [ + path + for path in (Path(self.root) / "camvid" / "raw").glob( + "*.png" + ) + if path.stem in filenames + ] + ) + self.targets = sorted( + [ + path + for path in (Path(self.root) / "camvid" / "label").glob( + "*.png" + ) + if path.stem in filenames + ] + ) + + self.split = split if split is not None else "all" def __getitem__(self, index: int) -> tuple: """Get image and target at index. @@ -97,18 +140,18 @@ def __getitem__(self, index: int) -> tuple: def __len__(self) -> int: """Return the number of samples.""" - return self.num_samples + return self.num_samples[self.split] def _check_integrity(self) -> bool: """Check if the dataset exists.""" if ( len(list((Path(self.root) / "camvid" / "raw").glob("*.png"))) - != self.num_samples + != self.num_samples["all"] ): return False if ( len(list((Path(self.root) / "camvid" / "label").glob("*.png"))) - != self.num_samples + != self.num_samples["all"] ): return False return True @@ -140,5 +183,5 @@ def download(self) -> None: if __name__ == "__main__": - dataset = CamVid("data", download=True) + dataset = CamVid("data", split=None, download=True) print(dataset) diff --git a/torch_uncertainty/datasets/segmentation/camvid_splits.json b/torch_uncertainty/datasets/segmentation/camvid_splits.json new file mode 100644 index 00000000..6ea0bcd8 --- /dev/null +++ b/torch_uncertainty/datasets/segmentation/camvid_splits.json @@ -0,0 +1,709 @@ +{ + "train": [ + "0001TP_006690", + "0001TP_006720", + "0001TP_006750", + "0001TP_006780", + "0001TP_006810", + "0001TP_006840", + "0001TP_006870", + "0001TP_006900", + "0001TP_006930", + "0001TP_006960", + "0001TP_006990", + "0001TP_007020", + "0001TP_007050", + "0001TP_007080", + "0001TP_007110", + "0001TP_007140", + "0001TP_007170", + "0001TP_007200", + "0001TP_007230", + "0001TP_007260", + "0001TP_007290", + "0001TP_007320", + "0001TP_007350", + "0001TP_007380", + "0001TP_007410", + "0001TP_007440", + "0001TP_007470", + "0001TP_007500", + "0001TP_007530", + "0001TP_007560", + "0001TP_007590", + "0001TP_007620", + "0001TP_007650", + "0001TP_007680", + "0001TP_007710", + "0001TP_007740", + "0001TP_007770", + "0001TP_007800", + "0001TP_007830", + "0001TP_007860", + "0001TP_007890", + "0001TP_007920", + "0001TP_007950", + "0001TP_007980", + "0001TP_008010", + "0001TP_008040", + "0001TP_008070", + "0001TP_008100", + "0001TP_008130", + "0001TP_008160", + "0001TP_008190", + "0001TP_008220", + "0001TP_008250", + "0001TP_008280", + "0001TP_008310", + "0001TP_008340", + "0001TP_008370", + "0001TP_008400", + "0001TP_008430", + "0001TP_008460", + "0001TP_008490", + "0001TP_008520", + "0006R0_f00930", + "0006R0_f00960", + "0006R0_f00990", + "0006R0_f01020", + "0006R0_f01050", + "0006R0_f01080", + "0006R0_f01110", + "0006R0_f01140", + "0006R0_f01170", + "0006R0_f01200", + "0006R0_f01230", + "0006R0_f01260", + "0006R0_f01290", + "0006R0_f01320", + "0006R0_f01350", + "0006R0_f01380", + "0006R0_f01410", + "0006R0_f01440", + "0006R0_f01470", + "0006R0_f01500", + "0006R0_f01530", + "0006R0_f01560", + "0006R0_f01590", + "0006R0_f01620", + "0006R0_f01650", + "0006R0_f01680", + "0006R0_f01710", + "0006R0_f01740", + "0006R0_f01770", + "0006R0_f01800", + "0006R0_f01830", + "0006R0_f01860", + "0006R0_f01890", + "0006R0_f01920", + "0006R0_f01950", + "0006R0_f01980", + "0006R0_f02010", + "0006R0_f02040", + "0006R0_f02070", + "0006R0_f02100", + "0006R0_f02130", + "0006R0_f02160", + "0006R0_f02190", + "0006R0_f02220", + "0006R0_f02250", + "0006R0_f02280", + "0006R0_f02310", + "0006R0_f02340", + "0006R0_f02370", + "0006R0_f02400", + "0006R0_f02430", + "0006R0_f02460", + "0006R0_f02490", + "0006R0_f02520", + "0006R0_f02550", + "0006R0_f02580", + "0006R0_f02610", + "0006R0_f02640", + "0006R0_f02670", + "0006R0_f02700", + "0006R0_f02730", + "0006R0_f02760", + "0006R0_f02790", + "0006R0_f02820", + "0006R0_f02850", + "0006R0_f02880", + "0006R0_f02910", + "0006R0_f02940", + "0006R0_f02970", + "0006R0_f03000", + "0006R0_f03030", + "0006R0_f03060", + "0006R0_f03090", + "0006R0_f03120", + "0006R0_f03150", + "0006R0_f03180", + "0006R0_f03210", + "0006R0_f03240", + "0006R0_f03270", + "0006R0_f03300", + "0006R0_f03330", + "0006R0_f03360", + "0006R0_f03390", + "0006R0_f03420", + "0006R0_f03450", + "0006R0_f03480", + "0006R0_f03510", + "0006R0_f03540", + "0006R0_f03570", + "0006R0_f03600", + "0006R0_f03630", + "0006R0_f03660", + "0006R0_f03690", + "0006R0_f03720", + "0006R0_f03750", + "0006R0_f03780", + "0006R0_f03810", + "0006R0_f03840", + "0006R0_f03870", + "0006R0_f03900", + "0006R0_f03930", + "0016E5_00390", + "0016E5_00420", + "0016E5_00450", + "0016E5_00480", + "0016E5_00510", + "0016E5_00540", + "0016E5_00570", + "0016E5_00600", + "0016E5_00630", + "0016E5_00660", + "0016E5_00690", + "0016E5_00720", + "0016E5_00750", + "0016E5_00780", + "0016E5_00810", + "0016E5_00840", + "0016E5_00870", + "0016E5_00901", + "0016E5_00930", + "0016E5_00960", + "0016E5_00990", + "0016E5_01020", + "0016E5_01050", + "0016E5_01080", + "0016E5_01110", + "0016E5_01140", + "0016E5_01170", + "0016E5_01200", + "0016E5_01230", + "0016E5_01260", + "0016E5_01290", + "0016E5_01320", + "0016E5_01350", + "0016E5_01380", + "0016E5_01410", + "0016E5_01440", + "0016E5_01470", + "0016E5_01500", + "0016E5_01530", + "0016E5_01560", + "0016E5_01590", + "0016E5_01620", + "0016E5_01650", + "0016E5_01680", + "0016E5_01710", + "0016E5_01740", + "0016E5_01770", + "0016E5_01800", + "0016E5_01830", + "0016E5_01860", + "0016E5_01890", + "0016E5_01920", + "0016E5_01950", + "0016E5_01980", + "0016E5_02010", + "0016E5_02040", + "0016E5_02070", + "0016E5_02100", + "0016E5_02130", + "0016E5_02160", + "0016E5_02190", + "0016E5_02220", + "0016E5_02250", + "0016E5_02280", + "0016E5_02310", + "0016E5_02340", + "0016E5_02370", + "0016E5_02400", + "0016E5_04350", + "0016E5_04380", + "0016E5_04410", + "0016E5_04440", + "0016E5_04470", + "0016E5_04500", + "0016E5_04530", + "0016E5_04560", + "0016E5_04590", + "0016E5_04620", + "0016E5_04650", + "0016E5_04680", + "0016E5_04710", + "0016E5_04740", + "0016E5_04770", + "0016E5_04800", + "0016E5_04830", + "0016E5_04860", + "0016E5_04890", + "0016E5_04920", + "0016E5_04950", + "0016E5_04980", + "0016E5_05010", + "0016E5_05040", + "0016E5_05070", + "0016E5_05100", + "0016E5_05130", + "0016E5_05160", + "0016E5_05190", + "0016E5_05220", + "0016E5_05250", + "0016E5_05280", + "0016E5_05310", + "0016E5_05340", + "0016E5_05370", + "0016E5_05400", + "0016E5_05430", + "0016E5_05460", + "0016E5_05490", + "0016E5_05520", + "0016E5_05550", + "0016E5_05580", + "0016E5_05610", + "0016E5_05640", + "0016E5_05670", + "0016E5_05700", + "0016E5_05730", + "0016E5_05760", + "0016E5_05790", + "0016E5_05820", + "0016E5_05850", + "0016E5_05880", + "0016E5_05910", + "0016E5_05940", + "0016E5_05970", + "0016E5_06000", + "0016E5_06030", + "0016E5_06060", + "0016E5_06090", + "0016E5_06120", + "0016E5_06150", + "0016E5_06180", + "0016E5_06210", + "0016E5_06240", + "0016E5_06270", + "0016E5_06300", + "0016E5_06330", + "0016E5_06360", + "0016E5_06390", + "0016E5_06420", + "0016E5_06450", + "0016E5_06480", + "0016E5_06510", + "0016E5_06540", + "0016E5_06570", + "0016E5_06600", + "0016E5_06630", + "0016E5_06660", + "0016E5_06690", + "0016E5_06720", + "0016E5_06750", + "0016E5_06780", + "0016E5_06810", + "0016E5_06840", + "0016E5_06870", + "0016E5_06900", + "0016E5_06930", + "0016E5_06960", + "0016E5_06990", + "0016E5_07020", + "0016E5_07050", + "0016E5_07080", + "0016E5_07110", + "0016E5_07140", + "0016E5_07170", + "0016E5_07200", + "0016E5_07230", + "0016E5_07260", + "0016E5_07290", + "0016E5_07320", + "0016E5_07350", + "0016E5_07380", + "0016E5_07410", + "0016E5_07440", + "0016E5_07470", + "0016E5_07500", + "0016E5_07530", + "0016E5_07560", + "0016E5_07590", + "0016E5_07620", + "0016E5_07650", + "0016E5_07680", + "0016E5_07710", + "0016E5_07740", + "0016E5_07770", + "0016E5_07800", + "0016E5_07830", + "0016E5_07860", + "0016E5_07890", + "0016E5_07920", + "0016E5_08190", + "0016E5_08220", + "0016E5_08250", + "0016E5_08280", + "0016E5_08310", + "0016E5_08340", + "0016E5_08370", + "0016E5_08400", + "0016E5_08430", + "0016E5_08460", + "0016E5_08490", + "0016E5_08520", + "0016E5_08550", + "0016E5_08580", + "0016E5_08610", + "0016E5_08640" + ], + "val": [ + "0016E5_07959", + "0016E5_07961", + "0016E5_07963", + "0016E5_07965", + "0016E5_07967", + "0016E5_07969", + "0016E5_07971", + "0016E5_07973", + "0016E5_07975", + "0016E5_07977", + "0016E5_07979", + "0016E5_07981", + "0016E5_07983", + "0016E5_07985", + "0016E5_07987", + "0016E5_07989", + "0016E5_07991", + "0016E5_07993", + "0016E5_07995", + "0016E5_07997", + "0016E5_07999", + "0016E5_08001", + "0016E5_08003", + "0016E5_08005", + "0016E5_08007", + "0016E5_08009", + "0016E5_08011", + "0016E5_08013", + "0016E5_08015", + "0016E5_08017", + "0016E5_08019", + "0016E5_08021", + "0016E5_08023", + "0016E5_08025", + "0016E5_08027", + "0016E5_08029", + "0016E5_08031", + "0016E5_08033", + "0016E5_08035", + "0016E5_08037", + "0016E5_08039", + "0016E5_08041", + "0016E5_08043", + "0016E5_08045", + "0016E5_08047", + "0016E5_08049", + "0016E5_08051", + "0016E5_08053", + "0016E5_08055", + "0016E5_08057", + "0016E5_08059", + "0016E5_08061", + "0016E5_08063", + "0016E5_08065", + "0016E5_08067", + "0016E5_08069", + "0016E5_08071", + "0016E5_08073", + "0016E5_08075", + "0016E5_08077", + "0016E5_08079", + "0016E5_08081", + "0016E5_08083", + "0016E5_08085", + "0016E5_08087", + "0016E5_08089", + "0016E5_08091", + "0016E5_08093", + "0016E5_08095", + "0016E5_08097", + "0016E5_08099", + "0016E5_08101", + "0016E5_08103", + "0016E5_08105", + "0016E5_08107", + "0016E5_08109", + "0016E5_08111", + "0016E5_08113", + "0016E5_08115", + "0016E5_08117", + "0016E5_08119", + "0016E5_08121", + "0016E5_08123", + "0016E5_08125", + "0016E5_08127", + "0016E5_08129", + "0016E5_08131", + "0016E5_08133", + "0016E5_08135", + "0016E5_08137", + "0016E5_08139", + "0016E5_08141", + "0016E5_08143", + "0016E5_08145", + "0016E5_08147", + "0016E5_08149", + "0016E5_08151", + "0016E5_08153", + "0016E5_08155", + "0016E5_08157", + "0016E5_08159" + ], + "test": [ + "0001TP_008550", + "0001TP_008580", + "0001TP_008610", + "0001TP_008640", + "0001TP_008670", + "0001TP_008700", + "0001TP_008730", + "0001TP_008760", + "0001TP_008790", + "0001TP_008820", + "0001TP_008850", + "0001TP_008880", + "0001TP_008910", + "0001TP_008940", + "0001TP_008970", + "0001TP_009000", + "0001TP_009030", + "0001TP_009060", + "0001TP_009090", + "0001TP_009120", + "0001TP_009150", + "0001TP_009180", + "0001TP_009210", + "0001TP_009240", + "0001TP_009270", + "0001TP_009300", + "0001TP_009330", + "0001TP_009360", + "0001TP_009390", + "0001TP_009420", + "0001TP_009450", + "0001TP_009480", + "0001TP_009510", + "0001TP_009540", + "0001TP_009570", + "0001TP_009600", + "0001TP_009630", + "0001TP_009660", + "0001TP_009690", + "0001TP_009720", + "0001TP_009750", + "0001TP_009780", + "0001TP_009810", + "0001TP_009840", + "0001TP_009870", + "0001TP_009900", + "0001TP_009930", + "0001TP_009960", + "0001TP_009990", + "0001TP_010020", + "0001TP_010050", + "0001TP_010080", + "0001TP_010110", + "0001TP_010140", + "0001TP_010170", + "0001TP_010200", + "0001TP_010230", + "0001TP_010260", + "0001TP_010290", + "0001TP_010320", + "0001TP_010350", + "0001TP_010380", + "Seq05VD_f00000", + "Seq05VD_f00030", + "Seq05VD_f00060", + "Seq05VD_f00090", + "Seq05VD_f00120", + "Seq05VD_f00150", + "Seq05VD_f00180", + "Seq05VD_f00210", + "Seq05VD_f00240", + "Seq05VD_f00270", + "Seq05VD_f00300", + "Seq05VD_f00330", + "Seq05VD_f00360", + "Seq05VD_f00390", + "Seq05VD_f00420", + "Seq05VD_f00450", + "Seq05VD_f00480", + "Seq05VD_f00510", + "Seq05VD_f00540", + "Seq05VD_f00570", + "Seq05VD_f00600", + "Seq05VD_f00630", + "Seq05VD_f00660", + "Seq05VD_f00690", + "Seq05VD_f00720", + "Seq05VD_f00750", + "Seq05VD_f00780", + "Seq05VD_f00810", + "Seq05VD_f00840", + "Seq05VD_f00870", + "Seq05VD_f00900", + "Seq05VD_f00930", + "Seq05VD_f00960", + "Seq05VD_f00990", + "Seq05VD_f01020", + "Seq05VD_f01050", + "Seq05VD_f01080", + "Seq05VD_f01110", + "Seq05VD_f01140", + "Seq05VD_f01170", + "Seq05VD_f01200", + "Seq05VD_f01230", + "Seq05VD_f01260", + "Seq05VD_f01290", + "Seq05VD_f01320", + "Seq05VD_f01350", + "Seq05VD_f01380", + "Seq05VD_f01410", + "Seq05VD_f01440", + "Seq05VD_f01470", + "Seq05VD_f01500", + "Seq05VD_f01530", + "Seq05VD_f01560", + "Seq05VD_f01590", + "Seq05VD_f01620", + "Seq05VD_f01650", + "Seq05VD_f01680", + "Seq05VD_f01710", + "Seq05VD_f01740", + "Seq05VD_f01770", + "Seq05VD_f01800", + "Seq05VD_f01830", + "Seq05VD_f01860", + "Seq05VD_f01890", + "Seq05VD_f01920", + "Seq05VD_f01950", + "Seq05VD_f01980", + "Seq05VD_f02010", + "Seq05VD_f02040", + "Seq05VD_f02070", + "Seq05VD_f02100", + "Seq05VD_f02130", + "Seq05VD_f02160", + "Seq05VD_f02190", + "Seq05VD_f02220", + "Seq05VD_f02250", + "Seq05VD_f02280", + "Seq05VD_f02310", + "Seq05VD_f02340", + "Seq05VD_f02370", + "Seq05VD_f02400", + "Seq05VD_f02430", + "Seq05VD_f02460", + "Seq05VD_f02490", + "Seq05VD_f02520", + "Seq05VD_f02550", + "Seq05VD_f02580", + "Seq05VD_f02610", + "Seq05VD_f02640", + "Seq05VD_f02670", + "Seq05VD_f02700", + "Seq05VD_f02730", + "Seq05VD_f02760", + "Seq05VD_f02790", + "Seq05VD_f02820", + "Seq05VD_f02850", + "Seq05VD_f02880", + "Seq05VD_f02910", + "Seq05VD_f02940", + "Seq05VD_f02970", + "Seq05VD_f03000", + "Seq05VD_f03030", + "Seq05VD_f03060", + "Seq05VD_f03090", + "Seq05VD_f03120", + "Seq05VD_f03150", + "Seq05VD_f03180", + "Seq05VD_f03210", + "Seq05VD_f03240", + "Seq05VD_f03270", + "Seq05VD_f03300", + "Seq05VD_f03330", + "Seq05VD_f03360", + "Seq05VD_f03390", + "Seq05VD_f03420", + "Seq05VD_f03450", + "Seq05VD_f03480", + "Seq05VD_f03510", + "Seq05VD_f03540", + "Seq05VD_f03570", + "Seq05VD_f03600", + "Seq05VD_f03630", + "Seq05VD_f03660", + "Seq05VD_f03690", + "Seq05VD_f03720", + "Seq05VD_f03750", + "Seq05VD_f03780", + "Seq05VD_f03810", + "Seq05VD_f03840", + "Seq05VD_f03870", + "Seq05VD_f03900", + "Seq05VD_f03930", + "Seq05VD_f03960", + "Seq05VD_f03990", + "Seq05VD_f04020", + "Seq05VD_f04050", + "Seq05VD_f04080", + "Seq05VD_f04110", + "Seq05VD_f04140", + "Seq05VD_f04170", + "Seq05VD_f04200", + "Seq05VD_f04230", + "Seq05VD_f04260", + "Seq05VD_f04290", + "Seq05VD_f04320", + "Seq05VD_f04350", + "Seq05VD_f04380", + "Seq05VD_f04410", + "Seq05VD_f04440", + "Seq05VD_f04470", + "Seq05VD_f04500", + "Seq05VD_f04530", + "Seq05VD_f04560", + "Seq05VD_f04590", + "Seq05VD_f04620", + "Seq05VD_f04650", + "Seq05VD_f04680", + "Seq05VD_f04710", + "Seq05VD_f04740", + "Seq05VD_f04770", + "Seq05VD_f04800", + "Seq05VD_f04830", + "Seq05VD_f04860", + "Seq05VD_f04890", + "Seq05VD_f04920", + "Seq05VD_f04950", + "Seq05VD_f04980", + "Seq05VD_f05010", + "Seq05VD_f05040", + "Seq05VD_f05070", + "Seq05VD_f05100" + ] +} From a4b82498941b7e9d568652149d7a51f93eb776a6 Mon Sep 17 00:00:00 2001 From: alafage Date: Wed, 13 Dec 2023 15:30:57 +0100 Subject: [PATCH 004/148] :fire: Remove camvid_splits.json to download it instead --- .../datasets/segmentation/camvid.py | 21 +- .../datasets/segmentation/camvid_splits.json | 709 ------------------ 2 files changed, 19 insertions(+), 711 deletions(-) delete mode 100644 torch_uncertainty/datasets/segmentation/camvid_splits.json diff --git a/torch_uncertainty/datasets/segmentation/camvid.py b/torch_uncertainty/datasets/segmentation/camvid.py index e95b4205..a6a1ddbd 100644 --- a/torch_uncertainty/datasets/segmentation/camvid.py +++ b/torch_uncertainty/datasets/segmentation/camvid.py @@ -7,7 +7,10 @@ from PIL import Image from torchvision import tv_tensors from torchvision.datasets import VisionDataset -from torchvision.datasets.utils import download_and_extract_archive +from torchvision.datasets.utils import ( + download_and_extract_archive, + download_url, +) class CamVidClass(NamedTuple): @@ -37,7 +40,11 @@ class CamVid(VisionDataset): urls = { "raw": "http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/files/701_StillsRaw_full.zip", "label": "http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/data/LabeledApproved_full.zip", + "splits": "https://raw.githubusercontent.com/torch-uncertainty/dataset-metadata/main/segmentation/camvid/splits.json", } + + splits_md5 = "db45289aaa83c60201391b11e78c6382" + filenames = { "raw": "701_StillsRaw_full.zip", "label": "LabeledApproved_full.zip", @@ -97,7 +104,7 @@ def __init__( (Path(self.root) / "camvid" / "label").glob("*.png") ) else: - with (Path(__file__).parent / "camvid_splits.json").open() as f: + with (Path(self.root) / "camvid" / "splits.json").open() as f: filenames = json.load(f)[split] self.images = sorted( @@ -154,6 +161,10 @@ def _check_integrity(self) -> bool: != self.num_samples["all"] ): return False + + if not (Path(self.root) / "camvid" / "splits.json").exists(): + return False + return True def download(self) -> None: @@ -180,6 +191,12 @@ def download(self) -> None: extract_root=Path(self.root) / "camvid" / "label", filename=self.filenames["label"], ) + download_url( + self.urls["splits"], + Path(self.root) / "camvid", + filename="splits.json", + md5=self.splits_md5, + ) if __name__ == "__main__": diff --git a/torch_uncertainty/datasets/segmentation/camvid_splits.json b/torch_uncertainty/datasets/segmentation/camvid_splits.json deleted file mode 100644 index 6ea0bcd8..00000000 --- a/torch_uncertainty/datasets/segmentation/camvid_splits.json +++ /dev/null @@ -1,709 +0,0 @@ -{ - "train": [ - "0001TP_006690", - "0001TP_006720", - "0001TP_006750", - "0001TP_006780", - "0001TP_006810", - "0001TP_006840", - "0001TP_006870", - "0001TP_006900", - "0001TP_006930", - "0001TP_006960", - "0001TP_006990", - "0001TP_007020", - "0001TP_007050", - "0001TP_007080", - "0001TP_007110", - "0001TP_007140", - "0001TP_007170", - "0001TP_007200", - "0001TP_007230", - "0001TP_007260", - "0001TP_007290", - "0001TP_007320", - "0001TP_007350", - "0001TP_007380", - "0001TP_007410", - "0001TP_007440", - "0001TP_007470", - "0001TP_007500", - "0001TP_007530", - "0001TP_007560", - "0001TP_007590", - "0001TP_007620", - "0001TP_007650", - "0001TP_007680", - "0001TP_007710", - "0001TP_007740", - "0001TP_007770", - "0001TP_007800", - "0001TP_007830", - "0001TP_007860", - "0001TP_007890", - "0001TP_007920", - "0001TP_007950", - "0001TP_007980", - "0001TP_008010", - "0001TP_008040", - "0001TP_008070", - "0001TP_008100", - "0001TP_008130", - "0001TP_008160", - "0001TP_008190", - "0001TP_008220", - "0001TP_008250", - "0001TP_008280", - "0001TP_008310", - "0001TP_008340", - "0001TP_008370", - "0001TP_008400", - "0001TP_008430", - "0001TP_008460", - "0001TP_008490", - "0001TP_008520", - "0006R0_f00930", - "0006R0_f00960", - "0006R0_f00990", - "0006R0_f01020", - "0006R0_f01050", - "0006R0_f01080", - "0006R0_f01110", - "0006R0_f01140", - "0006R0_f01170", - "0006R0_f01200", - "0006R0_f01230", - "0006R0_f01260", - "0006R0_f01290", - "0006R0_f01320", - "0006R0_f01350", - "0006R0_f01380", - "0006R0_f01410", - "0006R0_f01440", - "0006R0_f01470", - "0006R0_f01500", - "0006R0_f01530", - "0006R0_f01560", - "0006R0_f01590", - "0006R0_f01620", - "0006R0_f01650", - "0006R0_f01680", - "0006R0_f01710", - "0006R0_f01740", - "0006R0_f01770", - "0006R0_f01800", - "0006R0_f01830", - "0006R0_f01860", - "0006R0_f01890", - "0006R0_f01920", - "0006R0_f01950", - "0006R0_f01980", - "0006R0_f02010", - "0006R0_f02040", - "0006R0_f02070", - "0006R0_f02100", - "0006R0_f02130", - "0006R0_f02160", - "0006R0_f02190", - "0006R0_f02220", - "0006R0_f02250", - "0006R0_f02280", - "0006R0_f02310", - "0006R0_f02340", - "0006R0_f02370", - "0006R0_f02400", - "0006R0_f02430", - "0006R0_f02460", - "0006R0_f02490", - "0006R0_f02520", - "0006R0_f02550", - "0006R0_f02580", - "0006R0_f02610", - "0006R0_f02640", - "0006R0_f02670", - "0006R0_f02700", - "0006R0_f02730", - "0006R0_f02760", - "0006R0_f02790", - "0006R0_f02820", - "0006R0_f02850", - "0006R0_f02880", - "0006R0_f02910", - "0006R0_f02940", - "0006R0_f02970", - "0006R0_f03000", - "0006R0_f03030", - "0006R0_f03060", - "0006R0_f03090", - "0006R0_f03120", - "0006R0_f03150", - "0006R0_f03180", - "0006R0_f03210", - "0006R0_f03240", - "0006R0_f03270", - "0006R0_f03300", - "0006R0_f03330", - "0006R0_f03360", - "0006R0_f03390", - "0006R0_f03420", - "0006R0_f03450", - "0006R0_f03480", - "0006R0_f03510", - "0006R0_f03540", - "0006R0_f03570", - "0006R0_f03600", - "0006R0_f03630", - "0006R0_f03660", - "0006R0_f03690", - "0006R0_f03720", - "0006R0_f03750", - "0006R0_f03780", - "0006R0_f03810", - "0006R0_f03840", - "0006R0_f03870", - "0006R0_f03900", - "0006R0_f03930", - "0016E5_00390", - "0016E5_00420", - "0016E5_00450", - "0016E5_00480", - "0016E5_00510", - "0016E5_00540", - "0016E5_00570", - "0016E5_00600", - "0016E5_00630", - "0016E5_00660", - "0016E5_00690", - "0016E5_00720", - "0016E5_00750", - "0016E5_00780", - "0016E5_00810", - "0016E5_00840", - "0016E5_00870", - "0016E5_00901", - "0016E5_00930", - "0016E5_00960", - "0016E5_00990", - "0016E5_01020", - "0016E5_01050", - "0016E5_01080", - "0016E5_01110", - "0016E5_01140", - "0016E5_01170", - "0016E5_01200", - "0016E5_01230", - "0016E5_01260", - "0016E5_01290", - "0016E5_01320", - "0016E5_01350", - "0016E5_01380", - "0016E5_01410", - "0016E5_01440", - "0016E5_01470", - "0016E5_01500", - "0016E5_01530", - "0016E5_01560", - "0016E5_01590", - "0016E5_01620", - "0016E5_01650", - "0016E5_01680", - "0016E5_01710", - "0016E5_01740", - "0016E5_01770", - "0016E5_01800", - "0016E5_01830", - "0016E5_01860", - "0016E5_01890", - "0016E5_01920", - "0016E5_01950", - "0016E5_01980", - "0016E5_02010", - "0016E5_02040", - "0016E5_02070", - "0016E5_02100", - "0016E5_02130", - "0016E5_02160", - "0016E5_02190", - "0016E5_02220", - "0016E5_02250", - "0016E5_02280", - "0016E5_02310", - "0016E5_02340", - "0016E5_02370", - "0016E5_02400", - "0016E5_04350", - "0016E5_04380", - "0016E5_04410", - "0016E5_04440", - "0016E5_04470", - "0016E5_04500", - "0016E5_04530", - "0016E5_04560", - "0016E5_04590", - "0016E5_04620", - "0016E5_04650", - "0016E5_04680", - "0016E5_04710", - "0016E5_04740", - "0016E5_04770", - "0016E5_04800", - "0016E5_04830", - "0016E5_04860", - "0016E5_04890", - "0016E5_04920", - "0016E5_04950", - "0016E5_04980", - "0016E5_05010", - "0016E5_05040", - "0016E5_05070", - "0016E5_05100", - "0016E5_05130", - "0016E5_05160", - "0016E5_05190", - "0016E5_05220", - "0016E5_05250", - "0016E5_05280", - "0016E5_05310", - "0016E5_05340", - "0016E5_05370", - "0016E5_05400", - "0016E5_05430", - "0016E5_05460", - "0016E5_05490", - "0016E5_05520", - "0016E5_05550", - "0016E5_05580", - "0016E5_05610", - "0016E5_05640", - "0016E5_05670", - "0016E5_05700", - "0016E5_05730", - "0016E5_05760", - "0016E5_05790", - "0016E5_05820", - "0016E5_05850", - "0016E5_05880", - "0016E5_05910", - "0016E5_05940", - "0016E5_05970", - "0016E5_06000", - "0016E5_06030", - "0016E5_06060", - "0016E5_06090", - "0016E5_06120", - "0016E5_06150", - "0016E5_06180", - "0016E5_06210", - "0016E5_06240", - "0016E5_06270", - "0016E5_06300", - "0016E5_06330", - "0016E5_06360", - "0016E5_06390", - "0016E5_06420", - "0016E5_06450", - "0016E5_06480", - "0016E5_06510", - "0016E5_06540", - "0016E5_06570", - "0016E5_06600", - "0016E5_06630", - "0016E5_06660", - "0016E5_06690", - "0016E5_06720", - "0016E5_06750", - "0016E5_06780", - "0016E5_06810", - "0016E5_06840", - "0016E5_06870", - "0016E5_06900", - "0016E5_06930", - "0016E5_06960", - "0016E5_06990", - "0016E5_07020", - "0016E5_07050", - "0016E5_07080", - "0016E5_07110", - "0016E5_07140", - "0016E5_07170", - "0016E5_07200", - "0016E5_07230", - "0016E5_07260", - "0016E5_07290", - "0016E5_07320", - "0016E5_07350", - "0016E5_07380", - "0016E5_07410", - "0016E5_07440", - "0016E5_07470", - "0016E5_07500", - "0016E5_07530", - "0016E5_07560", - "0016E5_07590", - "0016E5_07620", - "0016E5_07650", - "0016E5_07680", - "0016E5_07710", - "0016E5_07740", - "0016E5_07770", - "0016E5_07800", - "0016E5_07830", - "0016E5_07860", - "0016E5_07890", - "0016E5_07920", - "0016E5_08190", - "0016E5_08220", - "0016E5_08250", - "0016E5_08280", - "0016E5_08310", - "0016E5_08340", - "0016E5_08370", - "0016E5_08400", - "0016E5_08430", - "0016E5_08460", - "0016E5_08490", - "0016E5_08520", - "0016E5_08550", - "0016E5_08580", - "0016E5_08610", - "0016E5_08640" - ], - "val": [ - "0016E5_07959", - "0016E5_07961", - "0016E5_07963", - "0016E5_07965", - "0016E5_07967", - "0016E5_07969", - "0016E5_07971", - "0016E5_07973", - "0016E5_07975", - "0016E5_07977", - "0016E5_07979", - "0016E5_07981", - "0016E5_07983", - "0016E5_07985", - "0016E5_07987", - "0016E5_07989", - "0016E5_07991", - "0016E5_07993", - "0016E5_07995", - "0016E5_07997", - "0016E5_07999", - "0016E5_08001", - "0016E5_08003", - "0016E5_08005", - "0016E5_08007", - "0016E5_08009", - "0016E5_08011", - "0016E5_08013", - "0016E5_08015", - "0016E5_08017", - "0016E5_08019", - "0016E5_08021", - "0016E5_08023", - "0016E5_08025", - "0016E5_08027", - "0016E5_08029", - "0016E5_08031", - "0016E5_08033", - "0016E5_08035", - "0016E5_08037", - "0016E5_08039", - "0016E5_08041", - "0016E5_08043", - "0016E5_08045", - "0016E5_08047", - "0016E5_08049", - "0016E5_08051", - "0016E5_08053", - "0016E5_08055", - "0016E5_08057", - "0016E5_08059", - "0016E5_08061", - "0016E5_08063", - "0016E5_08065", - "0016E5_08067", - "0016E5_08069", - "0016E5_08071", - "0016E5_08073", - "0016E5_08075", - "0016E5_08077", - "0016E5_08079", - "0016E5_08081", - "0016E5_08083", - "0016E5_08085", - "0016E5_08087", - "0016E5_08089", - "0016E5_08091", - "0016E5_08093", - "0016E5_08095", - "0016E5_08097", - "0016E5_08099", - "0016E5_08101", - "0016E5_08103", - "0016E5_08105", - "0016E5_08107", - "0016E5_08109", - "0016E5_08111", - "0016E5_08113", - "0016E5_08115", - "0016E5_08117", - "0016E5_08119", - "0016E5_08121", - "0016E5_08123", - "0016E5_08125", - "0016E5_08127", - "0016E5_08129", - "0016E5_08131", - "0016E5_08133", - "0016E5_08135", - "0016E5_08137", - "0016E5_08139", - "0016E5_08141", - "0016E5_08143", - "0016E5_08145", - "0016E5_08147", - "0016E5_08149", - "0016E5_08151", - "0016E5_08153", - "0016E5_08155", - "0016E5_08157", - "0016E5_08159" - ], - "test": [ - "0001TP_008550", - "0001TP_008580", - "0001TP_008610", - "0001TP_008640", - "0001TP_008670", - "0001TP_008700", - "0001TP_008730", - "0001TP_008760", - "0001TP_008790", - "0001TP_008820", - "0001TP_008850", - "0001TP_008880", - "0001TP_008910", - "0001TP_008940", - "0001TP_008970", - "0001TP_009000", - "0001TP_009030", - "0001TP_009060", - "0001TP_009090", - "0001TP_009120", - "0001TP_009150", - "0001TP_009180", - "0001TP_009210", - "0001TP_009240", - "0001TP_009270", - "0001TP_009300", - "0001TP_009330", - "0001TP_009360", - "0001TP_009390", - "0001TP_009420", - "0001TP_009450", - "0001TP_009480", - "0001TP_009510", - "0001TP_009540", - "0001TP_009570", - "0001TP_009600", - "0001TP_009630", - "0001TP_009660", - "0001TP_009690", - "0001TP_009720", - "0001TP_009750", - "0001TP_009780", - "0001TP_009810", - "0001TP_009840", - "0001TP_009870", - "0001TP_009900", - "0001TP_009930", - "0001TP_009960", - "0001TP_009990", - "0001TP_010020", - "0001TP_010050", - "0001TP_010080", - "0001TP_010110", - "0001TP_010140", - "0001TP_010170", - "0001TP_010200", - "0001TP_010230", - "0001TP_010260", - "0001TP_010290", - "0001TP_010320", - "0001TP_010350", - "0001TP_010380", - "Seq05VD_f00000", - "Seq05VD_f00030", - "Seq05VD_f00060", - "Seq05VD_f00090", - "Seq05VD_f00120", - "Seq05VD_f00150", - "Seq05VD_f00180", - "Seq05VD_f00210", - "Seq05VD_f00240", - "Seq05VD_f00270", - "Seq05VD_f00300", - "Seq05VD_f00330", - "Seq05VD_f00360", - "Seq05VD_f00390", - "Seq05VD_f00420", - "Seq05VD_f00450", - "Seq05VD_f00480", - "Seq05VD_f00510", - "Seq05VD_f00540", - "Seq05VD_f00570", - "Seq05VD_f00600", - "Seq05VD_f00630", - "Seq05VD_f00660", - "Seq05VD_f00690", - "Seq05VD_f00720", - "Seq05VD_f00750", - "Seq05VD_f00780", - "Seq05VD_f00810", - "Seq05VD_f00840", - "Seq05VD_f00870", - "Seq05VD_f00900", - "Seq05VD_f00930", - "Seq05VD_f00960", - "Seq05VD_f00990", - "Seq05VD_f01020", - "Seq05VD_f01050", - "Seq05VD_f01080", - "Seq05VD_f01110", - "Seq05VD_f01140", - "Seq05VD_f01170", - "Seq05VD_f01200", - "Seq05VD_f01230", - "Seq05VD_f01260", - "Seq05VD_f01290", - "Seq05VD_f01320", - "Seq05VD_f01350", - "Seq05VD_f01380", - "Seq05VD_f01410", - "Seq05VD_f01440", - "Seq05VD_f01470", - "Seq05VD_f01500", - "Seq05VD_f01530", - "Seq05VD_f01560", - "Seq05VD_f01590", - "Seq05VD_f01620", - "Seq05VD_f01650", - "Seq05VD_f01680", - "Seq05VD_f01710", - "Seq05VD_f01740", - "Seq05VD_f01770", - "Seq05VD_f01800", - "Seq05VD_f01830", - "Seq05VD_f01860", - "Seq05VD_f01890", - "Seq05VD_f01920", - "Seq05VD_f01950", - "Seq05VD_f01980", - "Seq05VD_f02010", - "Seq05VD_f02040", - "Seq05VD_f02070", - "Seq05VD_f02100", - "Seq05VD_f02130", - "Seq05VD_f02160", - "Seq05VD_f02190", - "Seq05VD_f02220", - "Seq05VD_f02250", - "Seq05VD_f02280", - "Seq05VD_f02310", - "Seq05VD_f02340", - "Seq05VD_f02370", - "Seq05VD_f02400", - "Seq05VD_f02430", - "Seq05VD_f02460", - "Seq05VD_f02490", - "Seq05VD_f02520", - "Seq05VD_f02550", - "Seq05VD_f02580", - "Seq05VD_f02610", - "Seq05VD_f02640", - "Seq05VD_f02670", - "Seq05VD_f02700", - "Seq05VD_f02730", - "Seq05VD_f02760", - "Seq05VD_f02790", - "Seq05VD_f02820", - "Seq05VD_f02850", - "Seq05VD_f02880", - "Seq05VD_f02910", - "Seq05VD_f02940", - "Seq05VD_f02970", - "Seq05VD_f03000", - "Seq05VD_f03030", - "Seq05VD_f03060", - "Seq05VD_f03090", - "Seq05VD_f03120", - "Seq05VD_f03150", - "Seq05VD_f03180", - "Seq05VD_f03210", - "Seq05VD_f03240", - "Seq05VD_f03270", - "Seq05VD_f03300", - "Seq05VD_f03330", - "Seq05VD_f03360", - "Seq05VD_f03390", - "Seq05VD_f03420", - "Seq05VD_f03450", - "Seq05VD_f03480", - "Seq05VD_f03510", - "Seq05VD_f03540", - "Seq05VD_f03570", - "Seq05VD_f03600", - "Seq05VD_f03630", - "Seq05VD_f03660", - "Seq05VD_f03690", - "Seq05VD_f03720", - "Seq05VD_f03750", - "Seq05VD_f03780", - "Seq05VD_f03810", - "Seq05VD_f03840", - "Seq05VD_f03870", - "Seq05VD_f03900", - "Seq05VD_f03930", - "Seq05VD_f03960", - "Seq05VD_f03990", - "Seq05VD_f04020", - "Seq05VD_f04050", - "Seq05VD_f04080", - "Seq05VD_f04110", - "Seq05VD_f04140", - "Seq05VD_f04170", - "Seq05VD_f04200", - "Seq05VD_f04230", - "Seq05VD_f04260", - "Seq05VD_f04290", - "Seq05VD_f04320", - "Seq05VD_f04350", - "Seq05VD_f04380", - "Seq05VD_f04410", - "Seq05VD_f04440", - "Seq05VD_f04470", - "Seq05VD_f04500", - "Seq05VD_f04530", - "Seq05VD_f04560", - "Seq05VD_f04590", - "Seq05VD_f04620", - "Seq05VD_f04650", - "Seq05VD_f04680", - "Seq05VD_f04710", - "Seq05VD_f04740", - "Seq05VD_f04770", - "Seq05VD_f04800", - "Seq05VD_f04830", - "Seq05VD_f04860", - "Seq05VD_f04890", - "Seq05VD_f04920", - "Seq05VD_f04950", - "Seq05VD_f04980", - "Seq05VD_f05010", - "Seq05VD_f05040", - "Seq05VD_f05070", - "Seq05VD_f05100" - ] -} From 0158cdd4351af402dc967f46e1db6e453141624e Mon Sep 17 00:00:00 2001 From: alafage Date: Wed, 13 Dec 2023 17:49:41 +0100 Subject: [PATCH 005/148] :construction: Init CamVid DataModule class --- .../datamodules/segmentation/__init__.py | 0 .../datamodules/segmentation/camvid.py | 69 +++++++++++++++++++ 2 files changed, 69 insertions(+) create mode 100644 torch_uncertainty/datamodules/segmentation/__init__.py create mode 100644 torch_uncertainty/datamodules/segmentation/camvid.py diff --git a/torch_uncertainty/datamodules/segmentation/__init__.py b/torch_uncertainty/datamodules/segmentation/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/torch_uncertainty/datamodules/segmentation/camvid.py b/torch_uncertainty/datamodules/segmentation/camvid.py new file mode 100644 index 00000000..c7cdd9ef --- /dev/null +++ b/torch_uncertainty/datamodules/segmentation/camvid.py @@ -0,0 +1,69 @@ +from argparse import ArgumentParser +from pathlib import Path +from typing import Any + +from torch_uncertainty.datamodules.abstract import AbstractDataModule +from torch_uncertainty.datasets.segmentation import CamVid + + +class CamVidDataModule(AbstractDataModule): + def __init__( + self, + root: str | Path, + batch_size: int, + num_workers: int = 1, + pin_memory: bool = True, + persistent_workers: bool = True, + **kwargs, + ) -> None: + super().__init__( + root=root, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=persistent_workers, + ) + + self.dataset = CamVid + + self.transform_train = ... + self.transform_test = ... + + def prepare_data(self) -> None: # coverage: ignore + self.dataset(root=self.root, download=True) + + def setup(self, stage: str | None = None) -> None: + if stage == "fit" or stage is None: + self.train = self.dataset( + root=self.root, + split="train", + download=False, + transform=self.transform_train, + ) + self.val = self.dataset( + root=self.root, + split="val", + download=False, + transform=self.transform_test, + ) + elif stage == "test": + self.test = self.dataset( + root=self.root, + split="test", + download=False, + transform=self.transform_test, + ) + else: + raise ValueError(f"Stage {stage} is not supported.") + + @classmethod + def add_argparse_args( + cls, + parent_parser: ArgumentParser, + **kwargs: Any, + ) -> ArgumentParser: + p = parent_parser.add_argument_group("datamodule") + p.add_argument("--root", type=str, default="./data/") + p.add_argument("--batch_size", type=int, default=128) + p.add_argument("--num_workers", type=int, default=4) + return parent_parser From a0a5bd479a6ca779b4b8f4931e723f938d167a7a Mon Sep 17 00:00:00 2001 From: alafage Date: Wed, 20 Dec 2023 10:26:58 +0100 Subject: [PATCH 006/148] :construction: Add Intersection over Union metric --- torch_uncertainty/metrics/iou.py | 34 ++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 torch_uncertainty/metrics/iou.py diff --git a/torch_uncertainty/metrics/iou.py b/torch_uncertainty/metrics/iou.py new file mode 100644 index 00000000..c6259ed1 --- /dev/null +++ b/torch_uncertainty/metrics/iou.py @@ -0,0 +1,34 @@ +from einops import rearrange +from torch import Tensor +from torchmetrics.classification.stat_scores import MulticlassStatScores +from torchmetrics.utilities.compute import _safe_divide + + +class IntersectionOverUnion(MulticlassStatScores): + is_differentiable: bool = False + higher_is_better: bool = True + full_state_update: bool = False + + def update(self, preds: Tensor, target: Tensor) -> None: + """Update state with predictions and targets. + + Args: + preds (Tensor): prediction images of shape :math:`(B, H, W)` or + :math:`(B, C, H, W)`. + target (Tensor): target images of shape :math:`(B, H, W)`. + """ + if preds.ndim == 3: + preds = preds.flatten() + if preds.ndim == 4: + preds = rearrange(preds, "b c h w -> (b h w) c") + + target = target.flatten() + + super().update(preds, target) + + def compute(self) -> Tensor: + """Compute the Intersection over Union (IoU) based on inputs passed to + ``update``. + """ + tp, fp, _, fn = self._final_state() + return _safe_divide(tp, tp + fp + fn) From aa5339501cd5efbedc8b25094c7bbfa49a0949eb Mon Sep 17 00:00:00 2001 From: alafage Date: Fri, 19 Jan 2024 10:33:25 +0100 Subject: [PATCH 007/148] :construction: Segmentation and LightningCLI support (wip) --- .../cifar10/configs/resnet18.yaml | 36 + experiments/classification/cifar10/readme.md | 17 + experiments/classification/cifar10/resnet.py | 78 +- pyproject.toml | 2 +- torch_uncertainty/baselines/__init__.py | 7 +- .../baselines/classification/__init__.py | 5 +- .../baselines/classification/resnet.py | 195 ++- .../baselines/classification/vgg.py | 446 ++--- .../baselines/classification/wideresnet.py | 498 +++--- .../baselines/segmentation/__init__.py | 0 torch_uncertainty/datamodules/abstract.py | 2 +- torch_uncertainty/datamodules/cifar10.py | 2 - .../datamodules/segmentation/camvid.py | 11 +- torch_uncertainty/lightning_cli.py | 47 + torch_uncertainty/metrics/__init__.py | 1 + torch_uncertainty/routines/classification.py | 1501 +++++++++++------ torch_uncertainty/routines/segmentation.py | 70 + 17 files changed, 1730 insertions(+), 1188 deletions(-) create mode 100644 experiments/classification/cifar10/configs/resnet18.yaml create mode 100644 experiments/classification/cifar10/readme.md create mode 100644 torch_uncertainty/baselines/segmentation/__init__.py create mode 100644 torch_uncertainty/lightning_cli.py create mode 100644 torch_uncertainty/routines/segmentation.py diff --git a/experiments/classification/cifar10/configs/resnet18.yaml b/experiments/classification/cifar10/configs/resnet18.yaml new file mode 100644 index 00000000..7caa52e6 --- /dev/null +++ b/experiments/classification/cifar10/configs/resnet18.yaml @@ -0,0 +1,36 @@ +# lightning.pytorch==2.1.3 +seed_everything: true +trainer: + precision: 16 + max_epochs: 75 +model: + num_classes: 10 + in_channels: 3 + loss: + - class_path: torch.nn.CrossEntropyLoss + version: "vanilla" + arch: 18 + style: cifar +data: + root: null + evaluate_ood: null + batch_size: null + val_split: 0.0 + num_workers: 1 + cutout: null + auto_augment: null + test_alt: null + corruption_severity: 1 + num_dataloaders: 1 + pin_memory: true + persistent_workers: true +optimizer: + lr: 0.05 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 25 + - 50 + gamma: 0.1 diff --git a/experiments/classification/cifar10/readme.md b/experiments/classification/cifar10/readme.md new file mode 100644 index 00000000..bb5b06b4 --- /dev/null +++ b/experiments/classification/cifar10/readme.md @@ -0,0 +1,17 @@ +# CIFAR10 - Benchmark + +This folder contains the code to train models on the CIFAR10 dataset. The task is to classify images into $10$ classes. + +## ResNet-backbone models + +`torch-uncertainty` leverages [LightningCLI](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.cli.LightningCLI.html#lightning.pytorch.cli.LightningCLI) the configurable command line tool for pytorch-lightning. To ease the train of models, we provide a set of predefined configurations for the CIFAR10 dataset. The configurations are located in the `configs` folder. + +**Train** + +```bash +``` + +**Evaluate** + +```bash +``` diff --git a/experiments/classification/cifar10/resnet.py b/experiments/classification/cifar10/resnet.py index a4b1b76a..7fcf7e41 100644 --- a/experiments/classification/cifar10/resnet.py +++ b/experiments/classification/cifar10/resnet.py @@ -1,72 +1,24 @@ -from pathlib import Path +import torch +from lightning.pytorch.cli import LightningArgumentParser, LightningCLI +from lightning.pytorch.loggers import TensorBoardLogger # noqa: F401 -from torch import nn - -from torch_uncertainty import cli_main, init_args from torch_uncertainty.baselines import ResNet from torch_uncertainty.datamodules import CIFAR10DataModule -from torch_uncertainty.optimization_procedures import get_procedure -from torch_uncertainty.utils import csv_writer - -if __name__ == "__main__": - args = init_args(ResNet, CIFAR10DataModule) - if args.root == "./data/": - root = Path(__file__).parent.absolute().parents[2] - else: - root = Path(args.root) +from torch_uncertainty.lightning_cli import MySaveConfigCallback - if args.exp_name == "": - args.exp_name = f"{args.version}-resnet{args.arch}-cifar10" - # datamodule - args.root = str(root / "data") - dm = CIFAR10DataModule(**vars(args)) +class MyLightningCLI(LightningCLI): + def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + parser.add_optimizer_args(torch.optim.SGD) + parser.add_lr_scheduler_args(torch.optim.lr_scheduler.MultiStepLR) - if args.opt_temp_scaling: - calibration_set = dm.get_test_set - elif args.val_temp_scaling: - calibration_set = dm.get_val_set - else: - calibration_set = None - if args.use_cv: - list_dm = dm.make_cross_val_splits(args.n_splits, args.train_over) - list_model = [ - ResNet( - num_classes=list_dm[i].dm.num_classes, - in_channels=list_dm[i].dm.num_channels, - loss=nn.CrossEntropyLoss, - optimization_procedure=get_procedure( - f"resnet{args.arch}", "cifar10", args.version - ), - style="cifar", - calibration_set=calibration_set, - **vars(args), - ) - for i in range(len(list_dm)) - ] +def cli_main(): + _ = MyLightningCLI( + ResNet, CIFAR10DataModule, save_config_callback=MySaveConfigCallback + ) - results = cli_main( - list_model, list_dm, args.exp_dir, args.exp_name, args - ) - else: - # model - model = ResNet( - num_classes=dm.num_classes, - in_channels=dm.num_channels, - loss=nn.CrossEntropyLoss, - optimization_procedure=get_procedure( - f"resnet{args.arch}", "cifar10", args.version - ), - style="cifar", - calibration_set=calibration_set, - **vars(args), - ) - results = cli_main(model, dm, args.exp_dir, args.exp_name, args) - - for dict_result in results: - csv_writer( - Path(args.exp_dir) / Path(args.exp_name) / "results.csv", - dict_result, - ) +if __name__ == "__main__": + torch.set_float32_matmul_precision("medium") + cli_main() diff --git a/pyproject.toml b/pyproject.toml index 67c5dfb3..4f2acc47 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ classifiers = [ ] dependencies = [ "timm", - "pytorch-lightning<2", + "lightning[pytorch-extra]", "tensorboard", "einops", "torchinfo", diff --git a/torch_uncertainty/baselines/__init__.py b/torch_uncertainty/baselines/__init__.py index 71fe93e3..4bfa4e54 100644 --- a/torch_uncertainty/baselines/__init__.py +++ b/torch_uncertainty/baselines/__init__.py @@ -1,5 +1,6 @@ # ruff: noqa: F401 from .classification.resnet import ResNet -from .classification.vgg import VGG -from .classification.wideresnet import WideResNet -from .deep_ensembles import DeepEnsembles + +# from .classification.vgg import VGG +# from .classification.wideresnet import WideResNet +# from .deep_ensembles import DeepEnsembles diff --git a/torch_uncertainty/baselines/classification/__init__.py b/torch_uncertainty/baselines/classification/__init__.py index 1326c2e3..65873c97 100644 --- a/torch_uncertainty/baselines/classification/__init__.py +++ b/torch_uncertainty/baselines/classification/__init__.py @@ -1,4 +1,5 @@ # ruff: noqa: F401 from .resnet import ResNet -from .vgg import VGG -from .wideresnet import WideResNet + +# from .vgg import VGG +# from .wideresnet import WideResNet diff --git a/torch_uncertainty/baselines/classification/resnet.py b/torch_uncertainty/baselines/classification/resnet.py index 2243be64..fb4e381d 100644 --- a/torch_uncertainty/baselines/classification/resnet.py +++ b/torch_uncertainty/baselines/classification/resnet.py @@ -1,21 +1,8 @@ -from argparse import ArgumentParser, BooleanOptionalAction -from pathlib import Path -from typing import Any, Literal +from collections.abc import Callable +from typing import Literal -import torch -from pytorch_lightning import LightningModule -from pytorch_lightning.core.saving import ( - load_hparams_from_tags_csv, - load_hparams_from_yaml, -) from torch import nn -from torch_uncertainty.baselines.utils.parser_addons import ( - add_masked_specific_args, - add_mimo_specific_args, - add_packed_specific_args, - add_resnet_specific_args, -) from torch_uncertainty.models.resnet import ( batched_resnet18, batched_resnet34, @@ -44,13 +31,14 @@ resnet152, ) from torch_uncertainty.routines.classification import ( - ClassificationEnsemble, - ClassificationSingle, + # ClassificationEnsemble, + # ClassificationSingle, + ClassificationRoutine, ) from torch_uncertainty.transforms import MIMOBatchFormat, RepeatTarget -class ResNet: +class ResNet(ClassificationRoutine): single = ["vanilla"] ensemble = ["packed", "batched", "masked", "mimo", "mc-dropout"] versions = { @@ -87,12 +75,12 @@ class ResNet: } archs = [18, 34, 50, 101, 152] - def __new__( - cls, + def __init__( + self, num_classes: int, in_channels: int, loss: type[nn.Module], - optimization_procedure: Any, + # optimization_procedure: Any, version: Literal[ "vanilla", "mc-dropout", @@ -103,8 +91,15 @@ def __new__( ], arch: int, style: str = "imagenet", - num_estimators: int | None = None, + num_estimators: int = 1, dropout_rate: float = 0.0, + mixtype: str = "erm", + mixmode: str = "elem", + dist_sim: str = "emb", + kernel_tau_max: float = 1.0, + kernel_tau_std: float = 0.5, + mixup_alpha: float = 0, + cutmix_alpha: float = 0, groups: int = 1, scale: float | None = None, alpha: float | None = None, @@ -115,9 +110,11 @@ def __new__( use_logits: bool = False, use_mi: bool = False, use_variation_ratio: bool = False, + log_plots: bool = False, + calibration_set: Callable | None = None, + evaluate_ood: bool = False, pretrained: bool = False, - **kwargs, - ) -> LightningModule: + ) -> None: r"""ResNet backbone baseline for classification providing support for various versions and architectures. @@ -154,6 +151,13 @@ def __new__( Only used if :attr:`version` is either ``"packed"``, ``"batched"``, ``"masked"`` or ``"mc-dropout"`` Defaults to ``None``. dropout_rate (float, optional): Dropout rate. Defaults to ``0.0``. + mixtype (str, optional): _description_ + mixmode (str, optional): _description_ + dist_sim (str, optional): _description_ + kernel_tau_max (float, optional): _description_ + kernel_tau_std (float, optional): _description_ + mixup_alpha (float, optional): _description_ + cutmix_alpha (float, optional): _description_ groups (int, optional): Number of groups in convolutions. Defaults to ``1``. scale (float, optional): Expansion factor affecting the width of the @@ -178,10 +182,15 @@ def __new__( information as the OOD criterion or not. Defaults to ``False``. use_variation_ratio (bool, optional): Indicates whether to use the variation ratio as the OOD criterion or not. Defaults to ``False``. + log_plots (bool, optional): Indicates whether to log the plots or not. + Defaults to ``False``. + calibration_set (Callable, optional): Calibration set. Defaults to + ``None``. + evaluate_ood (bool, optional): Indicates whether to evaluate the OOD + detection or not. Defaults to ``False``. pretrained (bool, optional): Indicates whether to use the pretrained weights or not. Only used if :attr:`version` is ``"packed"``. Defaults to ``False``. - **kwargs: Additional arguments. Raises: ValueError: If :attr:`version` is not either ``"vanilla"``, @@ -199,7 +208,7 @@ def __new__( format_batch_fn = nn.Identity() - if version not in cls.versions: + if version not in self.versions: raise ValueError(f"Unknown version: {version}") if version == "vanilla": @@ -252,75 +261,89 @@ def __new__( batch_repeat=batch_repeat, ) - model = cls.versions[version][cls.archs.index(arch)](**params) - kwargs.update(params) - kwargs.update({"version": version, "arch": arch}) + model = self.versions[version][self.archs.index(arch)](**params) + # kwargs.update(params) + # kwargs.update({"version": version, "arch": arch}) # routine specific parameters - if version in cls.single: - return ClassificationSingle( - model=model, - loss=loss, - optimization_procedure=optimization_procedure, - format_batch_fn=format_batch_fn, - use_entropy=use_entropy, - use_logits=use_logits, - **kwargs, - ) - # version in cls.ensemble - return ClassificationEnsemble( + # if version in cls.single: + # return ClassificationSingle( + # model=model, + # loss=loss, + # # optimization_procedure=optimization_procedure, + # format_batch_fn=format_batch_fn, + # use_entropy=use_entropy, + # use_logits=use_logits, + # # **kwargs, + # ) + # # version in cls.ensemble + super().__init__( + num_classes=num_classes, model=model, loss=loss, - optimization_procedure=optimization_procedure, + num_estimators=num_estimators, format_batch_fn=format_batch_fn, + mixtype=mixtype, + mixmode=mixmode, + dist_sim=dist_sim, + kernel_tau_max=kernel_tau_max, + kernel_tau_std=kernel_tau_std, + mixup_alpha=mixup_alpha, + cutmix_alpha=cutmix_alpha, + evaluate_ood=evaluate_ood, use_entropy=use_entropy, use_logits=use_logits, use_mi=use_mi, use_variation_ratio=use_variation_ratio, - **kwargs, + log_plots=log_plots, + calibration_set=calibration_set, ) - @classmethod - def load_from_checkpoint( - cls, - checkpoint_path: str | Path, - hparams_file: str | Path, - **kwargs, - ) -> LightningModule: # coverage: ignore - if hparams_file is not None: - extension = str(hparams_file).split(".")[-1] - if extension.lower() == "csv": - hparams = load_hparams_from_tags_csv(hparams_file) - elif extension.lower() in ("yml", "yaml"): - hparams = load_hparams_from_yaml(hparams_file) - else: - raise ValueError( - ".csv, .yml or .yaml is required for `hparams_file`" - ) + self.save_hyperparameters( + ignore=[ + "log_plots", + ] + ) - hparams.update(kwargs) - checkpoint = torch.load(checkpoint_path) - obj = cls(**hparams) - obj.load_state_dict(checkpoint["state_dict"]) - return obj + # @classmethod + # def load_from_checkpoint( + # cls, + # checkpoint_path: str | Path, + # hparams_file: str | Path, + # **kwargs, + # ) -> LightningModule: # coverage: ignore + # if hparams_file is not None: + # extension = str(hparams_file).split(".")[-1] + # if extension.lower() == "csv": + # hparams = load_hparams_from_tags_csv(hparams_file) + # elif extension.lower() in ("yml", "yaml"): + # hparams = load_hparams_from_yaml(hparams_file) + # else: + # raise ValueError(".csv, .yml or .yaml is required for `hparams_file`") - @classmethod - def add_model_specific_args(cls, parser: ArgumentParser) -> ArgumentParser: - parser = ClassificationEnsemble.add_model_specific_args(parser) - parser = add_resnet_specific_args(parser) - parser = add_packed_specific_args(parser) - parser = add_masked_specific_args(parser) - parser = add_mimo_specific_args(parser) - parser.add_argument( - "--version", - type=str, - choices=cls.versions.keys(), - default="vanilla", - help=f"Variation of ResNet. Choose among: {cls.versions.keys()}", - ) - parser.add_argument( - "--pretrained", - dest="pretrained", - action=BooleanOptionalAction, - default=False, - ) - return parser + # hparams.update(kwargs) + # checkpoint = torch.load(checkpoint_path) + # obj = cls(**hparams) + # obj.load_state_dict(checkpoint["state_dict"]) + # return obj + + # @classmethod + # def add_model_specific_args(cls, parser: ArgumentParser) -> ArgumentParser: + # parser = ClassificationEnsemble.add_model_specific_args(parser) + # parser = add_resnet_specific_args(parser) + # parser = add_packed_specific_args(parser) + # parser = add_masked_specific_args(parser) + # parser = add_mimo_specific_args(parser) + # parser.add_argument( + # "--version", + # type=str, + # choices=cls.versions.keys(), + # default="vanilla", + # help=f"Variation of ResNet. Choose among: {cls.versions.keys()}", + # ) + # parser.add_argument( + # "--pretrained", + # dest="pretrained", + # action=BooleanOptionalAction, + # default=False, + # ) + # return parser diff --git a/torch_uncertainty/baselines/classification/vgg.py b/torch_uncertainty/baselines/classification/vgg.py index 0db8e546..7f6adea7 100644 --- a/torch_uncertainty/baselines/classification/vgg.py +++ b/torch_uncertainty/baselines/classification/vgg.py @@ -1,223 +1,223 @@ -from argparse import ArgumentParser -from pathlib import Path -from typing import Any, Literal - -import torch -from pytorch_lightning import LightningModule -from pytorch_lightning.core.saving import ( - load_hparams_from_tags_csv, - load_hparams_from_yaml, -) -from torch import nn - -from torch_uncertainty.baselines.utils.parser_addons import ( - add_packed_specific_args, - add_vgg_specific_args, -) -from torch_uncertainty.models.vgg import ( - packed_vgg11, - packed_vgg13, - packed_vgg16, - packed_vgg19, - vgg11, - vgg13, - vgg16, - vgg19, -) -from torch_uncertainty.routines.classification import ( - ClassificationEnsemble, - ClassificationSingle, -) -from torch_uncertainty.transforms import RepeatTarget - - -class VGG: - single = ["vanilla"] - ensemble = ["mc-dropout", "packed"] - versions = { - "vanilla": [vgg11, vgg13, vgg16, vgg19], - "mc-dropout": [vgg11, vgg13, vgg16, vgg19], - "packed": [ - packed_vgg11, - packed_vgg13, - packed_vgg16, - packed_vgg19, - ], - } - archs = [11, 13, 16, 19] - - def __new__( - cls, - num_classes: int, - in_channels: int, - loss: type[nn.Module], - optimization_procedure: Any, - version: Literal["vanilla", "mc-dropout", "packed"], - arch: int, - num_estimators: int | None = None, - dropout_rate: float = 0.0, - style: str = "imagenet", - groups: int = 1, - alpha: float | None = None, - gamma: int = 1, - use_entropy: bool = False, - use_logits: bool = False, - use_mi: bool = False, - use_variation_ratio: bool = False, - **kwargs, - ) -> LightningModule: - r"""VGG backbone baseline for classification providing support for - various versions and architectures. - - Args: - num_classes (int): Number of classes to predict. - in_channels (int): Number of input channels. - loss (nn.Module): Training loss. - optimization_procedure (Any): Optimization procedure, corresponds to - what expect the `LightningModule.configure_optimizers() - `_ - method. - version (str): - Determines which VGG version to use: - - - ``"vanilla"``: original VGG - - ``"mc-dropout"``: Monte Carlo Dropout VGG - - ``"packed"``: Packed-Ensembles VGG - - arch (int): - Determines which VGG architecture to use: - - - ``11``: VGG-11 - - ``13``: VGG-13 - - ``16``: VGG-16 - - ``19``: VGG-19 - - style (str, optional): Which VGG style to use. Defaults to - ``imagenet``. - num_estimators (int, optional): Number of estimators in the ensemble. - Only used if :attr:`version` is either ``"packed"``, ``"batched"`` - or ``"masked"`` Defaults to ``None``. - dropout_rate (float, optional): Dropout rate. Defaults to ``0.0``. - groups (int, optional): Number of groups in convolutions. Defaults to - ``1``. - alpha (float, optional): Expansion factor affecting the width of the - estimators. Only used if :attr:`version` is ``"packed"``. Defaults - to ``None``. - gamma (int, optional): Number of groups within each estimator. Only - used if :attr:`version` is ``"packed"`` and scales with - :attr:`groups`. Defaults to ``1s``. - use_entropy (bool, optional): Indicates whether to use the entropy - values as the OOD criterion or not. Defaults to ``False``. - use_logits (bool, optional): Indicates whether to use the logits as the - OOD criterion or not. Defaults to ``False``. - use_mi (bool, optional): Indicates whether to use the mutual - information as the OOD criterion or not. Defaults to ``False``. - use_variation_ratio (bool, optional): Indicates whether to use the - variation ratio as the OOD criterion or not. Defaults to ``False``. - **kwargs: Additional arguments to be passed to the - Raises: - ValueError: If :attr:`version` is not either ``"vanilla"``, - ``"packed"``, ``"batched"`` or ``"masked"``. - - Returns: - LightningModule: VGG baseline ready for training and evaluation. - """ - params = { - "in_channels": in_channels, - "num_classes": num_classes, - "style": style, - "groups": groups, - } - - if version not in cls.versions: - raise ValueError(f"Unknown version: {version}") - - format_batch_fn = nn.Identity() - - if version == "vanilla": - params.update( - { - "dropout_rate": dropout_rate, - } - ) - elif version == "mc-dropout": - params.update( - { - "dropout_rate": dropout_rate, - "num_estimators": num_estimators, - } - ) - elif version == "packed": - params.update( - { - "num_estimators": num_estimators, - "alpha": alpha, - "style": style, - "gamma": gamma, - } - ) - format_batch_fn = RepeatTarget(num_repeats=num_estimators) - - model = cls.versions[version][cls.archs.index(arch)](**params) - kwargs.update(params) - # routine specific parameters - if version in cls.single: - return ClassificationSingle( - model=model, - loss=loss, - optimization_procedure=optimization_procedure, - format_batch_fn=format_batch_fn, - use_entropy=use_entropy, - use_logits=use_logits, - **kwargs, - ) - # version in cls.ensemble - return ClassificationEnsemble( - model=model, - loss=loss, - optimization_procedure=optimization_procedure, - format_batch_fn=format_batch_fn, - use_entropy=use_entropy, - use_logits=use_logits, - use_mi=use_mi, - use_variation_ratio=use_variation_ratio, - **kwargs, - ) - - @classmethod - def load_from_checkpoint( - cls, - checkpoint_path: str | Path, - hparams_file: str | Path, - **kwargs, - ) -> LightningModule: # coverage: ignore - if hparams_file is not None: - extension = str(hparams_file).split(".")[-1] - if extension.lower() == "csv": - hparams = load_hparams_from_tags_csv(hparams_file) - elif extension.lower() in ("yml", "yaml"): - hparams = load_hparams_from_yaml(hparams_file) - else: - raise ValueError( - ".csv, .yml or .yaml is required for `hparams_file`" - ) - - hparams.update(kwargs) - checkpoint = torch.load(checkpoint_path) - obj = cls(**hparams) - obj.load_state_dict(checkpoint["state_dict"]) - return obj - - @classmethod - def add_model_specific_args(cls, parser: ArgumentParser) -> ArgumentParser: - parser = ClassificationEnsemble.add_model_specific_args(parser) - parser = add_vgg_specific_args(parser) - parser = add_packed_specific_args(parser) - parser.add_argument( - "--version", - type=str, - choices=cls.versions.keys(), - default="vanilla", - help=f"Variation of VGG. Choose among: {cls.versions.keys()}", - ) - return parser +# from argparse import ArgumentParser +# from pathlib import Path +# from typing import Any, Literal + +# import torch +# from pytorch_lightning import LightningModule +# from pytorch_lightning.core.saving import ( +# load_hparams_from_tags_csv, +# load_hparams_from_yaml, +# ) +# from torch import nn + +# from torch_uncertainty.baselines.utils.parser_addons import ( +# add_packed_specific_args, +# add_vgg_specific_args, +# ) +# from torch_uncertainty.models.vgg import ( +# packed_vgg11, +# packed_vgg13, +# packed_vgg16, +# packed_vgg19, +# vgg11, +# vgg13, +# vgg16, +# vgg19, +# ) +# from torch_uncertainty.routines.classification import ( +# ClassificationEnsemble, +# ClassificationSingle, +# ) +# from torch_uncertainty.transforms import RepeatTarget + + +# class VGG: +# single = ["vanilla"] +# ensemble = ["mc-dropout", "packed"] +# versions = { +# "vanilla": [vgg11, vgg13, vgg16, vgg19], +# "mc-dropout": [vgg11, vgg13, vgg16, vgg19], +# "packed": [ +# packed_vgg11, +# packed_vgg13, +# packed_vgg16, +# packed_vgg19, +# ], +# } +# archs = [11, 13, 16, 19] + +# def __new__( +# cls, +# num_classes: int, +# in_channels: int, +# loss: type[nn.Module], +# optimization_procedure: Any, +# version: Literal["vanilla", "mc-dropout", "packed"], +# arch: int, +# num_estimators: int | None = None, +# dropout_rate: float = 0.0, +# style: str = "imagenet", +# groups: int = 1, +# alpha: float | None = None, +# gamma: int = 1, +# use_entropy: bool = False, +# use_logits: bool = False, +# use_mi: bool = False, +# use_variation_ratio: bool = False, +# **kwargs, +# ) -> LightningModule: +# r"""VGG backbone baseline for classification providing support for +# various versions and architectures. + +# Args: +# num_classes (int): Number of classes to predict. +# in_channels (int): Number of input channels. +# loss (nn.Module): Training loss. +# optimization_procedure (Any): Optimization procedure, corresponds to +# what expect the `LightningModule.configure_optimizers() +# `_ +# method. +# version (str): +# Determines which VGG version to use: + +# - ``"vanilla"``: original VGG +# - ``"mc-dropout"``: Monte Carlo Dropout VGG +# - ``"packed"``: Packed-Ensembles VGG + +# arch (int): +# Determines which VGG architecture to use: + +# - ``11``: VGG-11 +# - ``13``: VGG-13 +# - ``16``: VGG-16 +# - ``19``: VGG-19 + +# style (str, optional): Which VGG style to use. Defaults to +# ``imagenet``. +# num_estimators (int, optional): Number of estimators in the ensemble. +# Only used if :attr:`version` is either ``"packed"``, ``"batched"`` +# or ``"masked"`` Defaults to ``None``. +# dropout_rate (float, optional): Dropout rate. Defaults to ``0.0``. +# groups (int, optional): Number of groups in convolutions. Defaults to +# ``1``. +# alpha (float, optional): Expansion factor affecting the width of the +# estimators. Only used if :attr:`version` is ``"packed"``. Defaults +# to ``None``. +# gamma (int, optional): Number of groups within each estimator. Only +# used if :attr:`version` is ``"packed"`` and scales with +# :attr:`groups`. Defaults to ``1s``. +# use_entropy (bool, optional): Indicates whether to use the entropy +# values as the OOD criterion or not. Defaults to ``False``. +# use_logits (bool, optional): Indicates whether to use the logits as the +# OOD criterion or not. Defaults to ``False``. +# use_mi (bool, optional): Indicates whether to use the mutual +# information as the OOD criterion or not. Defaults to ``False``. +# use_variation_ratio (bool, optional): Indicates whether to use the +# variation ratio as the OOD criterion or not. Defaults to ``False``. +# **kwargs: Additional arguments to be passed to the +# Raises: +# ValueError: If :attr:`version` is not either ``"vanilla"``, +# ``"packed"``, ``"batched"`` or ``"masked"``. + +# Returns: +# LightningModule: VGG baseline ready for training and evaluation. +# """ +# params = { +# "in_channels": in_channels, +# "num_classes": num_classes, +# "style": style, +# "groups": groups, +# } + +# if version not in cls.versions: +# raise ValueError(f"Unknown version: {version}") + +# format_batch_fn = nn.Identity() + +# if version == "vanilla": +# params.update( +# { +# "dropout_rate": dropout_rate, +# } +# ) +# elif version == "mc-dropout": +# params.update( +# { +# "dropout_rate": dropout_rate, +# "num_estimators": num_estimators, +# } +# ) +# elif version == "packed": +# params.update( +# { +# "num_estimators": num_estimators, +# "alpha": alpha, +# "style": style, +# "gamma": gamma, +# } +# ) +# format_batch_fn = RepeatTarget(num_repeats=num_estimators) + +# model = cls.versions[version][cls.archs.index(arch)](**params) +# kwargs.update(params) +# # routine specific parameters +# if version in cls.single: +# return ClassificationSingle( +# model=model, +# loss=loss, +# optimization_procedure=optimization_procedure, +# format_batch_fn=format_batch_fn, +# use_entropy=use_entropy, +# use_logits=use_logits, +# **kwargs, +# ) +# # version in cls.ensemble +# return ClassificationEnsemble( +# model=model, +# loss=loss, +# optimization_procedure=optimization_procedure, +# format_batch_fn=format_batch_fn, +# use_entropy=use_entropy, +# use_logits=use_logits, +# use_mi=use_mi, +# use_variation_ratio=use_variation_ratio, +# **kwargs, +# ) + +# @classmethod +# def load_from_checkpoint( +# cls, +# checkpoint_path: str | Path, +# hparams_file: str | Path, +# **kwargs, +# ) -> LightningModule: # coverage: ignore +# if hparams_file is not None: +# extension = str(hparams_file).split(".")[-1] +# if extension.lower() == "csv": +# hparams = load_hparams_from_tags_csv(hparams_file) +# elif extension.lower() in ("yml", "yaml"): +# hparams = load_hparams_from_yaml(hparams_file) +# else: +# raise ValueError( +# ".csv, .yml or .yaml is required for `hparams_file`" +# ) + +# hparams.update(kwargs) +# checkpoint = torch.load(checkpoint_path) +# obj = cls(**hparams) +# obj.load_state_dict(checkpoint["state_dict"]) +# return obj + +# @classmethod +# def add_model_specific_args(cls, parser: ArgumentParser) -> ArgumentParser: +# parser = ClassificationEnsemble.add_model_specific_args(parser) +# parser = add_vgg_specific_args(parser) +# parser = add_packed_specific_args(parser) +# parser.add_argument( +# "--version", +# type=str, +# choices=cls.versions.keys(), +# default="vanilla", +# help=f"Variation of VGG. Choose among: {cls.versions.keys()}", +# ) +# return parser diff --git a/torch_uncertainty/baselines/classification/wideresnet.py b/torch_uncertainty/baselines/classification/wideresnet.py index 1a1830ec..37ad16f4 100644 --- a/torch_uncertainty/baselines/classification/wideresnet.py +++ b/torch_uncertainty/baselines/classification/wideresnet.py @@ -1,267 +1,267 @@ -from argparse import ArgumentParser, BooleanOptionalAction -from pathlib import Path -from typing import Any, Literal +# from argparse import ArgumentParser, BooleanOptionalAction +# from pathlib import Path +# from typing import Any, Literal -import torch -from pytorch_lightning import LightningModule -from pytorch_lightning.core.saving import ( - load_hparams_from_tags_csv, - load_hparams_from_yaml, -) -from torch import nn +# import torch +# from pytorch_lightning import LightningModule +# from pytorch_lightning.core.saving import ( +# load_hparams_from_tags_csv, +# load_hparams_from_yaml, +# ) +# from torch import nn -from torch_uncertainty.baselines.utils.parser_addons import ( - add_masked_specific_args, - add_mimo_specific_args, - add_packed_specific_args, - add_wideresnet_specific_args, -) -from torch_uncertainty.models.wideresnet import ( - batched_wideresnet28x10, - masked_wideresnet28x10, - mimo_wideresnet28x10, - packed_wideresnet28x10, - wideresnet28x10, -) -from torch_uncertainty.routines.classification import ( - ClassificationEnsemble, - ClassificationSingle, -) -from torch_uncertainty.transforms import MIMOBatchFormat, RepeatTarget +# from torch_uncertainty.baselines.utils.parser_addons import ( +# add_masked_specific_args, +# add_mimo_specific_args, +# add_packed_specific_args, +# add_wideresnet_specific_args, +# ) +# from torch_uncertainty.models.wideresnet import ( +# batched_wideresnet28x10, +# masked_wideresnet28x10, +# mimo_wideresnet28x10, +# packed_wideresnet28x10, +# wideresnet28x10, +# ) +# from torch_uncertainty.routines.classification import ( +# ClassificationEnsemble, +# ClassificationSingle, +# ) +# from torch_uncertainty.transforms import MIMOBatchFormat, RepeatTarget -class WideResNet: - single = ["vanilla"] - ensemble = ["packed", "batched", "masked", "mimo", "mc-dropout"] - versions = { - "vanilla": [wideresnet28x10], - "mc-dropout": [wideresnet28x10], - "packed": [packed_wideresnet28x10], - "batched": [batched_wideresnet28x10], - "masked": [masked_wideresnet28x10], - "mimo": [mimo_wideresnet28x10], - } +# class WideResNet: +# single = ["vanilla"] +# ensemble = ["packed", "batched", "masked", "mimo", "mc-dropout"] +# versions = { +# "vanilla": [wideresnet28x10], +# "mc-dropout": [wideresnet28x10], +# "packed": [packed_wideresnet28x10], +# "batched": [batched_wideresnet28x10], +# "masked": [masked_wideresnet28x10], +# "mimo": [mimo_wideresnet28x10], +# } - def __new__( - cls, - num_classes: int, - in_channels: int, - loss: type[nn.Module], - optimization_procedure: Any, - version: Literal[ - "vanilla", "mc-dropout", "packed", "batched", "masked", "mimo" - ], - style: str = "imagenet", - num_estimators: int | None = None, - dropout_rate: float = 0.0, - groups: int | None = None, - scale: float | None = None, - alpha: int | None = None, - gamma: int | None = None, - rho: float = 1.0, - batch_repeat: int = 1, - use_entropy: bool = False, - use_logits: bool = False, - use_mi: bool = False, - use_variation_ratio: bool = False, - # pretrained: bool = False, - **kwargs, - ) -> LightningModule: - r"""Wide-ResNet28x10 backbone baseline for classification providing support - for various versions. +# def __new__( +# cls, +# num_classes: int, +# in_channels: int, +# loss: type[nn.Module], +# optimization_procedure: Any, +# version: Literal[ +# "vanilla", "mc-dropout", "packed", "batched", "masked", "mimo" +# ], +# style: str = "imagenet", +# num_estimators: int | None = None, +# dropout_rate: float = 0.0, +# groups: int | None = None, +# scale: float | None = None, +# alpha: int | None = None, +# gamma: int | None = None, +# rho: float = 1.0, +# batch_repeat: int = 1, +# use_entropy: bool = False, +# use_logits: bool = False, +# use_mi: bool = False, +# use_variation_ratio: bool = False, +# # pretrained: bool = False, +# **kwargs, +# ) -> LightningModule: +# r"""Wide-ResNet28x10 backbone baseline for classification providing support +# for various versions. - Args: - num_classes (int): Number of classes to predict. - in_channels (int): Number of input channels. - loss (nn.Module): Training loss. - optimization_procedure (Any): Optimization procedure, corresponds to - what expect the `LightningModule.configure_optimizers() - `_ - method. - version (str): - Determines which Wide-ResNet version to use: +# Args: +# num_classes (int): Number of classes to predict. +# in_channels (int): Number of input channels. +# loss (nn.Module): Training loss. +# optimization_procedure (Any): Optimization procedure, corresponds to +# what expect the `LightningModule.configure_optimizers() +# `_ +# method. +# version (str): +# Determines which Wide-ResNet version to use: - - ``"vanilla"``: original Wide-ResNet - - ``"mc-dropout"``: Monte Carlo Dropout Wide-ResNet - - ``"packed"``: Packed-Ensembles Wide-ResNet - - ``"batched"``: BatchEnsemble Wide-ResNet - - ``"masked"``: Masksemble Wide-ResNet - - ``"mimo"``: MIMO Wide-ResNet +# - ``"vanilla"``: original Wide-ResNet +# - ``"mc-dropout"``: Monte Carlo Dropout Wide-ResNet +# - ``"packed"``: Packed-Ensembles Wide-ResNet +# - ``"batched"``: BatchEnsemble Wide-ResNet +# - ``"masked"``: Masksemble Wide-ResNet +# - ``"mimo"``: MIMO Wide-ResNet - style (bool, optional): (str, optional): Which ResNet style to use. - Defaults to ``imagenet``. - num_estimators (int, optional): Number of estimators in the ensemble. - Only used if :attr:`version` is either ``"packed"``, ``"batched"`` - or ``"masked"`` Defaults to ``None``. - dropout_rate (float, optional): Dropout rate. Defaults to ``0.0``. - groups (int, optional): Number of groups in convolutions. Defaults to - ``1``. - scale (float, optional): Expansion factor affecting the width of the - estimators. Only used if :attr:`version` is ``"masked"``. Defaults - to ``None``. - alpha (float, optional): Expansion factor affecting the width of the - estimators. Only used if :attr:`version` is ``"packed"``. Defaults - to ``None``. - gamma (int, optional): Number of groups within each estimator. Only - used if :attr:`version` is ``"packed"`` and scales with - :attr:`groups`. Defaults to ``1s``. - rho (float, optional): Probability that all estimators share the same - input. Only used if :attr:`version` is ``"mimo"``. Defaults to - ``1``. - batch_repeat (int, optional): Number of times to repeat the batch. Only - used if :attr:`version` is ``"mimo"``. Defaults to ``1``. - use_entropy (bool, optional): Indicates whether to use the entropy - values as the OOD criterion or not. Defaults to ``False``. - use_logits (bool, optional): Indicates whether to use the logits as the - OOD criterion or not. Defaults to ``False``. - use_mi (bool, optional): Indicates whether to use the mutual - information as the OOD criterion or not. Defaults to ``False``. - use_variation_ratio (bool, optional): Indicates whether to use the - variation ratio as the OOD criterion or not. Defaults to ``False``. - pretrained (bool, optional): Indicates whether to use the pretrained - weights or not. Only used if :attr:`version` is ``"packed"``. - Defaults to ``False``. - **kwargs: Additional arguments. +# style (bool, optional): (str, optional): Which ResNet style to use. +# Defaults to ``imagenet``. +# num_estimators (int, optional): Number of estimators in the ensemble. +# Only used if :attr:`version` is either ``"packed"``, ``"batched"`` +# or ``"masked"`` Defaults to ``None``. +# dropout_rate (float, optional): Dropout rate. Defaults to ``0.0``. +# groups (int, optional): Number of groups in convolutions. Defaults to +# ``1``. +# scale (float, optional): Expansion factor affecting the width of the +# estimators. Only used if :attr:`version` is ``"masked"``. Defaults +# to ``None``. +# alpha (float, optional): Expansion factor affecting the width of the +# estimators. Only used if :attr:`version` is ``"packed"``. Defaults +# to ``None``. +# gamma (int, optional): Number of groups within each estimator. Only +# used if :attr:`version` is ``"packed"`` and scales with +# :attr:`groups`. Defaults to ``1s``. +# rho (float, optional): Probability that all estimators share the same +# input. Only used if :attr:`version` is ``"mimo"``. Defaults to +# ``1``. +# batch_repeat (int, optional): Number of times to repeat the batch. Only +# used if :attr:`version` is ``"mimo"``. Defaults to ``1``. +# use_entropy (bool, optional): Indicates whether to use the entropy +# values as the OOD criterion or not. Defaults to ``False``. +# use_logits (bool, optional): Indicates whether to use the logits as the +# OOD criterion or not. Defaults to ``False``. +# use_mi (bool, optional): Indicates whether to use the mutual +# information as the OOD criterion or not. Defaults to ``False``. +# use_variation_ratio (bool, optional): Indicates whether to use the +# variation ratio as the OOD criterion or not. Defaults to ``False``. +# pretrained (bool, optional): Indicates whether to use the pretrained +# weights or not. Only used if :attr:`version` is ``"packed"``. +# Defaults to ``False``. +# **kwargs: Additional arguments. - Raises: - ValueError: If :attr:`version` is not either ``"vanilla"``, - ``"packed"``, ``"batched"`` or ``"masked"``. +# Raises: +# ValueError: If :attr:`version` is not either ``"vanilla"``, +# ``"packed"``, ``"batched"`` or ``"masked"``. - Returns: - LightningModule: Wide-ResNet baseline ready for training and - evaluation. - """ - params = { - "in_channels": in_channels, - "num_classes": num_classes, - "style": style, - "groups": groups, - } +# Returns: +# LightningModule: Wide-ResNet baseline ready for training and +# evaluation. +# """ +# params = { +# "in_channels": in_channels, +# "num_classes": num_classes, +# "style": style, +# "groups": groups, +# } - format_batch_fn = nn.Identity() +# format_batch_fn = nn.Identity() - if version not in cls.versions: - raise ValueError(f"Unknown version: {version}") +# if version not in cls.versions: +# raise ValueError(f"Unknown version: {version}") - # version specific params - if version == "vanilla": - params.update( - { - "dropout_rate": dropout_rate, - } - ) - elif version == "mc-dropout": - params.update( - { - "dropout_rate": dropout_rate, - "num_estimators": num_estimators, - } - ) - elif version == "packed": - params.update( - { - "num_estimators": num_estimators, - "alpha": alpha, - "gamma": gamma, - } - ) - format_batch_fn = RepeatTarget(num_repeats=num_estimators) - elif version == "batched": - params.update( - { - "num_estimators": num_estimators, - } - ) - format_batch_fn = RepeatTarget(num_repeats=num_estimators) - elif version == "masked": - params.update( - { - "num_estimators": num_estimators, - "scale": scale, - } - ) - format_batch_fn = RepeatTarget(num_repeats=num_estimators) - elif version == "mimo": - params.update( - { - "num_estimators": num_estimators, - } - ) - format_batch_fn = MIMOBatchFormat( - num_estimators=num_estimators, - rho=rho, - batch_repeat=batch_repeat, - ) +# # version specific params +# if version == "vanilla": +# params.update( +# { +# "dropout_rate": dropout_rate, +# } +# ) +# elif version == "mc-dropout": +# params.update( +# { +# "dropout_rate": dropout_rate, +# "num_estimators": num_estimators, +# } +# ) +# elif version == "packed": +# params.update( +# { +# "num_estimators": num_estimators, +# "alpha": alpha, +# "gamma": gamma, +# } +# ) +# format_batch_fn = RepeatTarget(num_repeats=num_estimators) +# elif version == "batched": +# params.update( +# { +# "num_estimators": num_estimators, +# } +# ) +# format_batch_fn = RepeatTarget(num_repeats=num_estimators) +# elif version == "masked": +# params.update( +# { +# "num_estimators": num_estimators, +# "scale": scale, +# } +# ) +# format_batch_fn = RepeatTarget(num_repeats=num_estimators) +# elif version == "mimo": +# params.update( +# { +# "num_estimators": num_estimators, +# } +# ) +# format_batch_fn = MIMOBatchFormat( +# num_estimators=num_estimators, +# rho=rho, +# batch_repeat=batch_repeat, +# ) - model = cls.versions[version][0](**params) - kwargs.update(params) - # routine specific parameters - if version in cls.single: - return ClassificationSingle( - model=model, - loss=loss, - optimization_procedure=optimization_procedure, - format_batch_fn=format_batch_fn, - use_entropy=use_entropy, - use_logits=use_logits, - **kwargs, - ) - # version in cls.ensemble - return ClassificationEnsemble( - model=model, - loss=loss, - optimization_procedure=optimization_procedure, - format_batch_fn=format_batch_fn, - use_entropy=use_entropy, - use_logits=use_logits, - use_mi=use_mi, - use_variation_ratio=use_variation_ratio, - **kwargs, - ) +# model = cls.versions[version][0](**params) +# kwargs.update(params) +# # routine specific parameters +# if version in cls.single: +# return ClassificationSingle( +# model=model, +# loss=loss, +# optimization_procedure=optimization_procedure, +# format_batch_fn=format_batch_fn, +# use_entropy=use_entropy, +# use_logits=use_logits, +# **kwargs, +# ) +# # version in cls.ensemble +# return ClassificationEnsemble( +# model=model, +# loss=loss, +# optimization_procedure=optimization_procedure, +# format_batch_fn=format_batch_fn, +# use_entropy=use_entropy, +# use_logits=use_logits, +# use_mi=use_mi, +# use_variation_ratio=use_variation_ratio, +# **kwargs, +# ) - @classmethod - def load_from_checkpoint( - cls, - checkpoint_path: str | Path, - hparams_file: str | Path, - **kwargs, - ) -> LightningModule: # coverage: ignore - if hparams_file is not None: - extension = str(hparams_file).split(".")[-1] - if extension.lower() == "csv": - hparams = load_hparams_from_tags_csv(hparams_file) - elif extension.lower() in ("yml", "yaml"): - hparams = load_hparams_from_yaml(hparams_file) - else: - raise ValueError( - ".csv, .yml or .yaml is required for `hparams_file`" - ) +# @classmethod +# def load_from_checkpoint( +# cls, +# checkpoint_path: str | Path, +# hparams_file: str | Path, +# **kwargs, +# ) -> LightningModule: # coverage: ignore +# if hparams_file is not None: +# extension = str(hparams_file).split(".")[-1] +# if extension.lower() == "csv": +# hparams = load_hparams_from_tags_csv(hparams_file) +# elif extension.lower() in ("yml", "yaml"): +# hparams = load_hparams_from_yaml(hparams_file) +# else: +# raise ValueError( +# ".csv, .yml or .yaml is required for `hparams_file`" +# ) - hparams.update(kwargs) - checkpoint = torch.load(checkpoint_path) - obj = cls(**hparams) - obj.load_state_dict(checkpoint["state_dict"]) - return obj +# hparams.update(kwargs) +# checkpoint = torch.load(checkpoint_path) +# obj = cls(**hparams) +# obj.load_state_dict(checkpoint["state_dict"]) +# return obj - @classmethod - def add_model_specific_args(cls, parser: ArgumentParser) -> ArgumentParser: - parser = ClassificationEnsemble.add_model_specific_args(parser) - parser = add_wideresnet_specific_args(parser) - parser = add_packed_specific_args(parser) - parser = add_masked_specific_args(parser) - parser = add_mimo_specific_args(parser) - parser.add_argument( - "--version", - type=str, - choices=cls.versions.keys(), - default="vanilla", - help=f"Variation of WideResNet. Choose among: {cls.versions.keys()}", - ) - parser.add_argument( - "--pretrained", - dest="pretrained", - action=BooleanOptionalAction, - default=False, - ) +# @classmethod +# def add_model_specific_args(cls, parser: ArgumentParser) -> ArgumentParser: +# parser = ClassificationEnsemble.add_model_specific_args(parser) +# parser = add_wideresnet_specific_args(parser) +# parser = add_packed_specific_args(parser) +# parser = add_masked_specific_args(parser) +# parser = add_mimo_specific_args(parser) +# parser.add_argument( +# "--version", +# type=str, +# choices=cls.versions.keys(), +# default="vanilla", +# help=f"Variation of WideResNet. Choose among: {cls.versions.keys()}", +# ) +# parser.add_argument( +# "--pretrained", +# dest="pretrained", +# action=BooleanOptionalAction, +# default=False, +# ) - return parser +# return parser diff --git a/torch_uncertainty/baselines/segmentation/__init__.py b/torch_uncertainty/baselines/segmentation/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/torch_uncertainty/datamodules/abstract.py b/torch_uncertainty/datamodules/abstract.py index 2854a95f..59c6b956 100644 --- a/torch_uncertainty/datamodules/abstract.py +++ b/torch_uncertainty/datamodules/abstract.py @@ -2,8 +2,8 @@ from pathlib import Path from typing import Any +from lightning.pytorch.core import LightningDataModule from numpy.typing import ArrayLike -from pytorch_lightning import LightningDataModule from sklearn.model_selection import StratifiedKFold from torch.utils.data import DataLoader, Dataset from torch.utils.data.sampler import SubsetRandomSampler diff --git a/torch_uncertainty/datamodules/cifar10.py b/torch_uncertainty/datamodules/cifar10.py index 932bc49d..2cf268c8 100644 --- a/torch_uncertainty/datamodules/cifar10.py +++ b/torch_uncertainty/datamodules/cifar10.py @@ -37,7 +37,6 @@ def __init__( num_dataloaders: int = 1, pin_memory: bool = True, persistent_workers: bool = True, - **kwargs, ) -> None: """DataModule for CIFAR10. @@ -60,7 +59,6 @@ def __init__( pin_memory (bool): Whether to pin memory. Defaults to ``True``. persistent_workers (bool): Whether to use persistent workers. Defaults to ``True``. - kwargs: Additional arguments. """ super().__init__( root=root, diff --git a/torch_uncertainty/datamodules/segmentation/camvid.py b/torch_uncertainty/datamodules/segmentation/camvid.py index c7cdd9ef..fb483c75 100644 --- a/torch_uncertainty/datamodules/segmentation/camvid.py +++ b/torch_uncertainty/datamodules/segmentation/camvid.py @@ -2,6 +2,8 @@ from pathlib import Path from typing import Any +from torchvision.transforms import v2 + from torch_uncertainty.datamodules.abstract import AbstractDataModule from torch_uncertainty.datasets.segmentation import CamVid @@ -14,7 +16,6 @@ def __init__( num_workers: int = 1, pin_memory: bool = True, persistent_workers: bool = True, - **kwargs, ) -> None: super().__init__( root=root, @@ -26,8 +27,12 @@ def __init__( self.dataset = CamVid - self.transform_train = ... - self.transform_test = ... + self.transform_train = v2.Compose( + [v2.Resize((360, 480), interpolation=v2.InterpolationMode.NEAREST)] + ) + self.transform_test = v2.Compose( + [v2.Resize((360, 480), interpolation=v2.InterpolationMode.NEAREST)] + ) def prepare_data(self) -> None: # coverage: ignore self.dataset(root=self.root, download=True) diff --git a/torch_uncertainty/lightning_cli.py b/torch_uncertainty/lightning_cli.py new file mode 100644 index 00000000..242d54f5 --- /dev/null +++ b/torch_uncertainty/lightning_cli.py @@ -0,0 +1,47 @@ +from pathlib import Path + +from lightning.fabric.utilities.cloud_io import get_filesystem +from lightning.pytorch import LightningModule, Trainer +from lightning.pytorch.cli import SaveConfigCallback +from typing_extensions import override + + +class MySaveConfigCallback(SaveConfigCallback): + @override + def setup( + self, trainer: Trainer, pl_module: LightningModule, stage: str + ) -> None: + if self.already_saved: + return + + if self.save_to_log_dir and stage == "fit": + log_dir = trainer.log_dir # this broadcasts the directory + assert log_dir is not None + config_path = Path(log_dir) / self.config_filename + fs = get_filesystem(log_dir) + + if not self.overwrite: + # check if the file exists on rank 0 + file_exists = ( + fs.isfile(config_path) if trainer.is_global_zero else False + ) + # broadcast whether to fail to all ranks + file_exists = trainer.strategy.broadcast(file_exists) + if file_exists: + # TODO: complete error description + raise RuntimeError("TODO") + + if trainer.is_global_zero: + fs.makedirs(log_dir, exist_ok=True) + self.parser.save( + self.config, + config_path, + skip_none=False, + overwrite=self.overwrite, + multifile=self.multifile, + ) + if trainer.is_global_zero: + self.save_config(trainer, pl_module, stage) + self.already_saved = True + + self.already_saved = trainer.strategy.broadcast(self.already_saved) diff --git a/torch_uncertainty/metrics/__init__.py b/torch_uncertainty/metrics/__init__.py index 938e7226..606e205c 100644 --- a/torch_uncertainty/metrics/__init__.py +++ b/torch_uncertainty/metrics/__init__.py @@ -4,6 +4,7 @@ from .disagreement import Disagreement from .entropy import Entropy from .fpr95 import FPR95 +from .iou import IntersectionOverUnion from .mutual_information import MutualInformation from .nll import GaussianNegativeLogLikelihood, NegativeLogLikelihood from .sparsification import AUSE diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 22e764d0..8f72e5d7 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -1,15 +1,20 @@ -from argparse import ArgumentParser, Namespace +# from argparse import ArgumentParser, Namespace from collections.abc import Callable from functools import partial -from typing import Any -import pytorch_lightning as pl import torch import torch.nn.functional as F -from einops import rearrange + +# import pytorch_lightning as pl +from lightning.pytorch import LightningModule +from lightning.pytorch.utilities.types import STEP_OUTPUT + +# from einops import rearrange from pytorch_lightning.loggers import TensorBoardLogger -from pytorch_lightning.utilities.memory import get_model_size_mb -from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT + +# from pytorch_lightning.utilities.memory import get_model_size_mb +# from lightning.pytorch.utilities import get_model_size_mb +# from pytorch_lightning.utilities.types import STEP_OUTPUT from timm.data import Mixup as timm_Mixup from torch import Tensor, nn from torchmetrics import Accuracy, MetricCollection @@ -29,18 +34,789 @@ NegativeLogLikelihood, VariationRatio, ) -from torch_uncertainty.plotting_utils import plot_hist + +# from torch_uncertainty.plotting_utils import plot_hist from torch_uncertainty.post_processing import TemperatureScaler from torch_uncertainty.transforms import Mixup, MixupIO, RegMixup, WarpingMixup - -class ClassificationSingle(pl.LightningModule): +# class ClassificationSingle(pl.LightningModule): +# def __init__( +# self, +# num_classes: int, +# model: nn.Module, +# loss: type[nn.Module], +# optimization_procedure: Any, +# format_batch_fn: nn.Module | None = None, +# mixtype: str = "erm", +# mixmode: str = "elem", +# dist_sim: str = "emb", +# kernel_tau_max: float = 1.0, +# kernel_tau_std: float = 0.5, +# mixup_alpha: float = 0, +# cutmix_alpha: float = 0, +# evaluate_ood: bool = False, +# use_entropy: bool = False, +# use_logits: bool = False, +# log_plots: bool = False, +# calibration_set: Callable | None = None, +# **kwargs, +# ) -> None: +# """Classification routine for single models. + +# Args: +# num_classes (int): Number of classes. +# model (nn.Module): Model to train. +# loss (type[nn.Module]): Loss function. +# optimization_procedure (Any): Optimization procedure. +# format_batch_fn (nn.Module, optional): Function to format the batch. +# Defaults to :class:`torch.nn.Identity()`. +# mixtype (str, optional): Mixup type. Defaults to ``"erm"``. +# mixmode (str, optional): Mixup mode. Defaults to ``"elem"``. +# dist_sim (str, optional): Distance similarity. Defaults to ``"emb"``. +# kernel_tau_max (float, optional): Maximum value for the kernel tau. +# Defaults to 1.0. +# kernel_tau_std (float, optional): Standard deviation for the kernel tau. +# Defaults to 0.5. +# mixup_alpha (float, optional): Alpha parameter for Mixup. Defaults to 0. +# cutmix_alpha (float, optional): Alpha parameter for Cutmix. +# Defaults to 0. +# evaluate_ood (bool, optional): Indicates whether to evaluate the OOD +# detection performance or not. Defaults to ``False``. +# use_entropy (bool, optional): Indicates whether to use the entropy +# values as the OOD criterion or not. Defaults to ``False``. +# use_logits (bool, optional): Indicates whether to use the logits as the +# OOD criterion or not. Defaults to ``False``. +# log_plots (bool, optional): Indicates whether to log plots from +# metrics. Defaults to ``False``. +# calibration_set (Callable, optional): Function to get the calibration +# set. Defaults to ``None``. +# kwargs (Any): Additional arguments. + +# Note: +# The default OOD criterion is the softmax confidence score. + +# Warning: +# Make sure at most only one of :attr:`use_entropy` and :attr:`use_logits` +# attributes is set to ``True``. Otherwise a :class:`ValueError()` will +# be raised. +# """ +# super().__init__() + +# if format_batch_fn is None: +# format_batch_fn = nn.Identity() + +# self.save_hyperparameters( +# ignore=[ +# "model", +# "loss", +# "optimization_procedure", +# "format_batch_fn", +# "calibration_set", +# ] +# ) + +# if (use_logits + use_entropy) > 1: +# raise ValueError("You cannot choose more than one OOD criterion.") + +# self.num_classes = num_classes +# self.evaluate_ood = evaluate_ood +# self.use_logits = use_logits +# self.use_entropy = use_entropy +# self.log_plots = log_plots + +# self.calibration_set = calibration_set + +# self.binary_cls = num_classes == 1 + +# self.model = model +# self.loss = loss +# self.optimization_procedure = optimization_procedure +# # batch format +# self.format_batch_fn = format_batch_fn + +# # metrics +# if self.binary_cls: +# cls_metrics = MetricCollection( +# { +# "acc": Accuracy(task="binary"), +# "ece": CE(task="binary"), +# "brier": BrierScore(num_classes=1), +# }, +# compute_groups=False, +# ) +# else: +# cls_metrics = MetricCollection( +# { +# "nll": NegativeLogLikelihood(), +# "acc": Accuracy(task="multiclass", num_classes=self.num_classes), +# "ece": CE(task="multiclass", num_classes=self.num_classes), +# "brier": BrierScore(num_classes=self.num_classes), +# }, +# compute_groups=False, +# ) + +# self.val_cls_metrics = cls_metrics.clone(prefix="hp/val_") +# self.test_cls_metrics = cls_metrics.clone(prefix="hp/test_") + +# if self.calibration_set is not None: +# self.ts_cls_metrics = cls_metrics.clone(prefix="hp/ts_") + +# self.test_entropy_id = Entropy() + +# if self.evaluate_ood: +# ood_metrics = MetricCollection( +# { +# "fpr95": FPR95(pos_label=1), +# "auroc": BinaryAUROC(), +# "aupr": BinaryAveragePrecision(), +# }, +# compute_groups=[["auroc", "aupr"], ["fpr95"]], +# ) +# self.test_ood_metrics = ood_metrics.clone(prefix="hp/test_") +# self.test_entropy_ood = Entropy() + +# if mixup_alpha < 0 or cutmix_alpha < 0: +# raise ValueError( +# "Cutmix alpha and Mixup alpha must be positive." +# f"Got {mixup_alpha} and {cutmix_alpha}." +# ) + +# self.mixtype = mixtype +# self.mixmode = mixmode +# self.dist_sim = dist_sim + +# self.mixup = self.init_mixup( +# mixup_alpha, cutmix_alpha, kernel_tau_max, kernel_tau_std +# ) + +# # Handle ELBO special cases +# self.is_elbo = isinstance(self.loss, partial) and self.loss.func == ELBOLoss + +# # DEC +# self.is_dec = self.loss == DECLoss or ( +# isinstance(self.loss, partial) and self.loss.func == DECLoss +# ) + +# def configure_optimizers(self) -> Any: +# return self.optimization_procedure(self) + +# @property +# def criterion(self) -> nn.Module: +# if self.is_elbo: +# self.loss = partial(self.loss, model=self.model) +# return self.loss() + +# def forward(self, inputs: Tensor) -> Tensor: +# return self.model.forward(inputs) + +# def on_train_start(self) -> None: +# # hyperparameters for performances +# param = {} +# param["storage"] = f"{get_model_size_mb(self)} MB" +# if self.logger is not None: # coverage: ignore +# self.logger.log_hyperparams( +# Namespace(**param), +# { +# "hp/val_nll": 0, +# "hp/val_acc": 0, +# "hp/test_acc": 0, +# "hp/test_nll": 0, +# "hp/test_ece": 0, +# "hp/test_brier": 0, +# "hp/test_entropy_id": 0, +# "hp/test_entropy_ood": 0, +# "hp/test_aupr": 0, +# "hp/test_auroc": 0, +# "hp/test_fpr95": 0, +# "hp/ts_test_nll": 0, +# "hp/ts_test_ece": 0, +# "hp/ts_test_brier": 0, +# }, +# ) + +# def training_step( +# self, batch: tuple[Tensor, Tensor], batch_idx: int +# ) -> STEP_OUTPUT: +# if self.mixtype == "kernel_warping": +# if self.dist_sim == "emb": +# with torch.no_grad(): +# feats = self.model.feats_forward(batch[0]).detach() + +# batch = self.mixup(*batch, feats) +# elif self.dist_sim == "inp": +# batch = self.mixup(*batch, batch[0]) +# else: +# batch = self.mixup(*batch) + +# inputs, targets = self.format_batch_fn(batch) + +# if self.is_elbo: +# loss = self.criterion(inputs, targets) +# else: +# logits = self.forward(inputs) +# # BCEWithLogitsLoss expects float targets +# if self.binary_cls and self.loss == nn.BCEWithLogitsLoss: +# logits = logits.squeeze(-1) +# targets = targets.float() + +# if not self.is_dec: +# loss = self.criterion(logits, targets) +# else: +# loss = self.criterion(logits, targets, self.current_epoch) +# self.log("train_loss", loss) +# return loss + +# def validation_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> None: +# inputs, targets = batch +# logits = self.forward(inputs) + +# if self.binary_cls: +# probs = torch.sigmoid(logits).squeeze(-1) +# else: +# probs = F.softmax(logits, dim=-1) + +# self.val_cls_metrics.update(probs, targets) + +# def validation_epoch_end(self, outputs: EPOCH_OUTPUT | list[EPOCH_OUTPUT]) -> None: +# self.log_dict(self.val_cls_metrics.compute()) +# self.val_cls_metrics.reset() + +# def on_test_start(self) -> None: +# if self.calibration_set is not None: +# self.scaler = TemperatureScaler(device=self.device).fit( +# model=self.model, calibration_set=self.calibration_set() +# ) +# self.cal_model = torch.nn.Sequential(self.model, self.scaler) +# else: +# self.scaler = None +# self.cal_model = None + +# def test_step( +# self, +# batch: tuple[Tensor, Tensor], +# batch_idx: int, +# dataloader_idx: int | None = 0, +# ) -> Tensor: +# inputs, targets = batch +# logits = self.forward(inputs) + +# if self.binary_cls: +# probs = torch.sigmoid(logits).squeeze(-1) +# else: +# probs = F.softmax(logits, dim=-1) + +# # self.cal_plot.update(probs, targets) +# confs = probs.max(dim=-1)[0] + +# if self.use_logits: +# ood_scores = -logits.max(dim=-1)[0] +# elif self.use_entropy: +# ood_scores = torch.special.entr(probs).sum(dim=-1) +# else: +# ood_scores = -confs + +# if ( +# self.calibration_set is not None +# and self.scaler is not None +# and self.cal_model is not None +# ): +# cal_logits = self.cal_model(inputs) +# cal_probs = F.softmax(cal_logits, dim=-1) +# self.ts_cls_metrics.update(cal_probs, targets) + +# if dataloader_idx == 0: +# self.test_cls_metrics.update(probs, targets) +# self.test_entropy_id(probs) +# self.log( +# "hp/test_entropy_id", +# self.test_entropy_id, +# on_epoch=True, +# add_dataloader_idx=False, +# ) +# if self.evaluate_ood: +# self.test_ood_metrics.update(ood_scores, torch.zeros_like(targets)) +# elif self.evaluate_ood and dataloader_idx == 1: +# self.test_ood_metrics.update(ood_scores, torch.ones_like(targets)) +# self.test_entropy_ood(probs) +# self.log( +# "hp/test_entropy_ood", +# self.test_entropy_ood, +# on_epoch=True, +# add_dataloader_idx=False, +# ) +# return logits + +# def test_epoch_end(self, outputs: EPOCH_OUTPUT | list[EPOCH_OUTPUT]) -> None: +# self.log_dict( +# self.test_cls_metrics.compute(), +# ) + +# if ( +# self.calibration_set is not None +# and self.scaler is not None +# and self.cal_model is not None +# ): +# self.log_dict(self.ts_cls_metrics.compute()) +# self.ts_cls_metrics.reset() + +# if self.evaluate_ood: +# self.log_dict( +# self.test_ood_metrics.compute(), +# ) +# self.test_ood_metrics.reset() + +# if isinstance(self.logger, TensorBoardLogger) and self.log_plots: +# self.logger.experiment.add_figure( +# "Calibration Plot", self.test_cls_metrics["ece"].plot()[0] +# ) + +# if self.evaluate_ood: +# id_logits = torch.cat(outputs[0], 0).float().cpu() +# ood_logits = torch.cat(outputs[1], 0).float().cpu() + +# id_probs = F.softmax(id_logits, dim=-1) +# ood_probs = F.softmax(ood_logits, dim=-1) + +# logits_fig = plot_hist( +# [id_logits.max(-1).values, ood_logits.max(-1).values], +# 20, +# "Histogram of the logits", +# )[0] +# probs_fig = plot_hist( +# [id_probs.max(-1).values, ood_probs.max(-1).values], +# 20, +# "Histogram of the likelihoods", +# )[0] +# self.logger.experiment.add_figure("Logit Histogram", logits_fig) +# self.logger.experiment.add_figure("Likelihood Histogram", probs_fig) + +# self.test_cls_metrics.reset() + +# def init_mixup( +# self, +# mixup_alpha: float, +# cutmix_alpha: float, +# kernel_tau_max: float, +# kernel_tau_std: float, +# ) -> Callable: +# if self.mixtype == "timm": +# return timm_Mixup( +# mixup_alpha=mixup_alpha, +# cutmix_alpha=cutmix_alpha, +# mode=self.mixmode, +# num_classes=self.num_classes, +# ) +# if self.mixtype == "mixup": +# return Mixup( +# alpha=mixup_alpha, +# mode=self.mixmode, +# num_classes=self.num_classes, +# ) +# if self.mixtype == "mixup_io": +# return MixupIO( +# alpha=mixup_alpha, +# mode=self.mixmode, +# num_classes=self.num_classes, +# ) +# if self.mixtype == "regmixup": +# return RegMixup( +# alpha=mixup_alpha, +# mode=self.mixmode, +# num_classes=self.num_classes, +# ) +# if self.mixtype == "kernel_warping": +# return WarpingMixup( +# alpha=mixup_alpha, +# mode=self.mixmode, +# num_classes=self.num_classes, +# apply_kernel=True, +# tau_max=kernel_tau_max, +# tau_std=kernel_tau_std, +# ) +# return lambda x, y: (x, y) + +# @staticmethod +# def add_model_specific_args( +# parent_parser: ArgumentParser, +# ) -> ArgumentParser: +# """Defines the routine's attributes via command-line options. + +# Args: +# parent_parser (ArgumentParser): Parent parser to be completed. + +# Adds: +# - ``--entropy``: sets :attr:`use_entropy` to ``True``. +# - ``--logits``: sets :attr:`use_logits` to ``True``. +# - ``--mixup_alpha``: sets :attr:`mixup_alpha` for Mixup +# - ``--cutmix_alpha``: sets :attr:`cutmix_alpha` for Cutmix +# - ``--mixtype``: sets :attr:`mixtype` for Mixup +# - ``--mixmode``: sets :attr:`mixmode` for Mixup +# - ``--dist_sim``: sets :attr:`dist_sim` for Mixup +# - ``--kernel_tau_max``: sets :attr:`kernel_tau_max` for Mixup +# - ``--kernel_tau_std``: sets :attr:`kernel_tau_std` for Mixup +# """ +# parent_parser.add_argument("--entropy", dest="use_entropy", action="store_true") +# parent_parser.add_argument("--logits", dest="use_logits", action="store_true") + +# # Mixup args +# parent_parser.add_argument( +# "--mixup_alpha", dest="mixup_alpha", type=float, default=0 +# ) +# parent_parser.add_argument( +# "--cutmix_alpha", dest="cutmix_alpha", type=float, default=0 +# ) +# parent_parser.add_argument("--mixtype", dest="mixtype", type=str, default="erm") +# parent_parser.add_argument( +# "--mixmode", dest="mixmode", type=str, default="elem" +# ) +# parent_parser.add_argument( +# "--dist_sim", dest="dist_sim", type=str, default="emb" +# ) +# parent_parser.add_argument( +# "--kernel_tau_max", dest="kernel_tau_max", type=float, default=1.0 +# ) +# parent_parser.add_argument( +# "--kernel_tau_std", dest="kernel_tau_std", type=float, default=0.5 +# ) +# return parent_parser + + +# class ClassificationEnsemble(ClassificationSingle): +# def __init__( +# self, +# num_classes: int, +# model: nn.Module, +# loss: type[nn.Module], +# optimization_procedure: Any, +# num_estimators: int, +# format_batch_fn: nn.Module | None = None, +# mixtype: str = "erm", +# mixmode: str = "elem", +# dist_sim: str = "emb", +# kernel_tau_max: float = 1.0, +# kernel_tau_std: float = 0.5, +# mixup_alpha: float = 0, +# cutmix_alpha: float = 0, +# evaluate_ood: bool = False, +# use_entropy: bool = False, +# use_logits: bool = False, +# use_mi: bool = False, +# use_variation_ratio: bool = False, +# log_plots: bool = False, +# **kwargs, +# ) -> None: +# """Classification routine for ensemble models. + +# Args: +# num_classes (int): Number of classes. +# model (nn.Module): Model to train. +# loss (type[nn.Module]): Loss function. +# optimization_procedure (Any): Optimization procedure. +# num_estimators (int): Number of estimators in the ensemble. +# format_batch_fn (nn.Module, optional): Function to format the batch. +# Defaults to :class:`torch.nn.Identity()`. +# mixtype (str, optional): Mixup type. Defaults to ``"erm"``. +# mixmode (str, optional): Mixup mode. Defaults to ``"elem"``. +# dist_sim (str, optional): Distance similarity. Defaults to ``"emb"``. +# kernel_tau_max (float, optional): Maximum value for the kernel tau. +# Defaults to 1.0. +# kernel_tau_std (float, optional): Standard deviation for the kernel tau. +# Defaults to 0.5. +# mixup_alpha (float, optional): Alpha parameter for Mixup. Defaults to 0. +# cutmix_alpha (float, optional): Alpha parameter for Cutmix. +# Defaults to 0. +# evaluate_ood (bool, optional): Indicates whether to evaluate the OOD +# detection performance or not. Defaults to ``False``. +# use_entropy (bool, optional): Indicates whether to use the entropy +# values as the OOD criterion or not. Defaults to ``False``. +# use_logits (bool, optional): Indicates whether to use the logits as the +# OOD criterion or not. Defaults to ``False``. +# use_mi (bool, optional): Indicates whether to use the mutual +# information as the OOD criterion or not. Defaults to ``False``. +# use_variation_ratio (bool, optional): Indicates whether to use the +# variation ratio as the OOD criterion or not. Defaults to ``False``. +# log_plots (bool, optional): Indicates whether to log plots from +# metrics. Defaults to ``False``. +# calibration_set (Callable, optional): Function to get the calibration +# set. Defaults to ``None``. +# kwargs (Any): Additional arguments. + +# Note: +# The default OOD criterion is the averaged softmax confidence score. + +# Warning: +# Make sure at most only one of :attr:`use_entropy`, :attr:`use_logits` +# , :attr:`use_mi`, and :attr:`use_variation_ratio` attributes is set to +# ``True``. Otherwise a :class:`ValueError()` will be raised. +# """ +# super().__init__( +# num_classes=num_classes, +# model=model, +# loss=loss, +# optimization_procedure=optimization_procedure, +# format_batch_fn=format_batch_fn, +# mixtype=mixtype, +# mixmode=mixmode, +# dist_sim=dist_sim, +# kernel_tau_max=kernel_tau_max, +# kernel_tau_std=kernel_tau_std, +# mixup_alpha=mixup_alpha, +# cutmix_alpha=cutmix_alpha, +# evaluate_ood=evaluate_ood, +# use_entropy=use_entropy, +# use_logits=use_logits, +# **kwargs, +# ) + +# self.num_estimators = num_estimators + +# self.use_mi = use_mi +# self.use_variation_ratio = use_variation_ratio +# self.log_plots = log_plots + +# if ( +# self.use_logits + self.use_entropy + self.use_mi + self.use_variation_ratio +# ) > 1: +# raise ValueError("You cannot choose more than one OOD criterion.") + +# # metrics for ensembles only +# ens_metrics = MetricCollection( +# { +# "disagreement": Disagreement(), +# "mi": MutualInformation(), +# "entropy": Entropy(), +# } +# ) +# self.test_id_ens_metrics = ens_metrics.clone(prefix="hp/test_id_ens_") + +# if self.evaluate_ood: +# self.test_ood_ens_metrics = ens_metrics.clone(prefix="hp/test_ood_ens_") + +# def on_train_start(self) -> None: +# param = {} +# param["storage"] = f"{get_model_size_mb(self)} MB" +# if self.logger is not None: # coverage: ignore +# self.logger.log_hyperparams( +# Namespace(**param), +# { +# "hp/val_nll": 0, +# "hp/val_acc": 0, +# "hp/test_acc": 0, +# "hp/test_nll": 0, +# "hp/test_ece": 0, +# "hp/test_brier": 0, +# "hp/test_entropy_id": 0, +# "hp/test_entropy_ood": 0, +# "hp/test_aupr": 0, +# "hp/test_auroc": 0, +# "hp/test_fpr95": 0, +# "hp/test_id_ens_disagreement": 0, +# "hp/test_id_ens_mi": 0, +# "hp/test_id_ens_entropy": 0, +# "hp/test_ood_ens_disagreement": 0, +# "hp/test_ood_ens_mi": 0, +# "hp/test_ood_ens_entropy": 0, +# }, +# ) + +# def training_step( +# self, batch: tuple[Tensor, Tensor], batch_idx: int +# ) -> STEP_OUTPUT: +# batch = self.mixup(*batch) +# # eventual input repeat is done in the model +# inputs, targets = self.format_batch_fn(batch) + +# if self.is_elbo: +# loss = self.criterion(inputs, targets) +# else: +# logits = self.forward(inputs) +# # BCEWithLogitsLoss expects float targets +# if self.binary_cls and self.loss == nn.BCEWithLogitsLoss: +# logits = logits.squeeze(-1) +# targets = targets.float() + +# if not self.is_dec: +# loss = self.criterion(logits, targets) +# else: +# loss = self.criterion(logits, targets, self.current_epoch) + +# self.log("train_loss", loss) +# return loss + +# def validation_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> None: +# inputs, targets = batch +# logits = self.forward(inputs) +# logits = rearrange(logits, "(m b) c -> b m c", m=self.num_estimators) +# if self.binary_cls: +# probs_per_est = torch.sigmoid(logits).squeeze(-1) +# else: +# probs_per_est = F.softmax(logits, dim=-1) + +# probs = probs_per_est.mean(dim=1) +# self.val_cls_metrics.update(probs, targets) + +# def test_step( +# self, +# batch: tuple[Tensor, Tensor], +# batch_idx: int, +# dataloader_idx: int | None = 0, +# ) -> Tensor: +# inputs, targets = batch +# logits = self.forward(inputs) +# logits = rearrange(logits, "(n b) c -> b n c", n=self.num_estimators) + +# if self.binary_cls: +# probs_per_est = torch.sigmoid(logits) +# else: +# probs_per_est = F.softmax(logits, dim=-1) + +# probs = probs_per_est.mean(dim=1) +# # self.cal_plot.update(probs, targets) +# confs = probs.max(-1)[0] + +# if self.use_logits: +# ood_scores = -logits.mean(dim=1).max(dim=-1)[0] +# elif self.use_entropy: +# ood_scores = torch.special.entr(probs_per_est).sum(dim=-1).mean(dim=1) +# elif self.use_mi: +# mi_metric = MutualInformation(reduction="none") +# ood_scores = mi_metric(probs_per_est) +# elif self.use_variation_ratio: +# vr_metric = VariationRatio(reduction="none", probabilistic=False) +# ood_scores = vr_metric(probs_per_est.transpose(0, 1)) +# else: +# ood_scores = -confs + +# if dataloader_idx == 0: +# # squeeze if binary classification only for binary metrics +# self.test_cls_metrics.update( +# probs.squeeze(-1) if self.binary_cls else probs, +# targets, +# ) +# self.test_entropy_id(probs) + +# self.test_id_ens_metrics.update(probs_per_est) +# self.log( +# "hp/test_entropy_id", +# self.test_entropy_id, +# on_epoch=True, +# add_dataloader_idx=False, +# ) + +# if self.evaluate_ood: +# self.test_ood_metrics.update(ood_scores, torch.zeros_like(targets)) +# elif self.evaluate_ood and dataloader_idx == 1: +# self.test_ood_metrics.update(ood_scores, torch.ones_like(targets)) +# self.test_entropy_ood(probs) +# self.test_ood_ens_metrics.update(probs_per_est) +# self.log( +# "hp/test_entropy_ood", +# self.test_entropy_ood, +# on_epoch=True, +# add_dataloader_idx=False, +# ) +# return logits + +# def test_epoch_end(self, outputs: EPOCH_OUTPUT | list[EPOCH_OUTPUT]) -> None: +# self.log_dict( +# self.test_cls_metrics.compute(), +# ) + +# self.log_dict( +# self.test_id_ens_metrics.compute(), +# ) + +# if self.evaluate_ood: +# self.log_dict( +# self.test_ood_metrics.compute(), +# ) +# self.log_dict( +# self.test_ood_ens_metrics.compute(), +# ) + +# self.test_ood_metrics.reset() +# self.test_ood_ens_metrics.reset() + +# if isinstance(self.logger, TensorBoardLogger) and self.log_plots: +# self.logger.experiment.add_figure( +# "Calibration Plot", self.test_cls_metrics["ece"].plot()[0] +# ) + +# if self.evaluate_ood: +# id_logits = torch.cat(outputs[0], 0).float().cpu() +# ood_logits = torch.cat(outputs[1], 0).float().cpu() + +# id_probs = F.softmax(id_logits, dim=-1) +# ood_probs = F.softmax(ood_logits, dim=-1) + +# logits_fig = plot_hist( +# [ +# id_logits.mean(1).max(-1).values, +# ood_logits.mean(1).max(-1).values, +# ], +# 20, +# "Histogram of the logits", +# )[0] +# probs_fig = plot_hist( +# [ +# id_probs.mean(1).max(-1).values, +# ood_probs.mean(1).max(-1).values, +# ], +# 20, +# "Histogram of the likelihoods", +# )[0] +# self.logger.experiment.add_figure("Logit Histogram", logits_fig) +# self.logger.experiment.add_figure("Likelihood Histogram", probs_fig) + +# self.test_cls_metrics.reset() +# self.test_id_ens_metrics.reset() + +# @staticmethod +# def add_model_specific_args( +# parent_parser: ArgumentParser, +# ) -> ArgumentParser: +# """Defines the routine's attributes via command-line options. + +# Adds: +# - ``--entropy``: sets :attr:`use_entropy` to ``True``. +# - ``--logits``: sets :attr:`use_logits` to ``True``. +# - ``--mutual_information``: sets :attr:`use_mi` to ``True``. +# - ``--variation_ratio``: sets :attr:`use_variation_ratio` to ``True``. +# - ``--num_estimators``: sets :attr:`num_estimators`. +# """ +# parent_parser = ClassificationSingle.add_model_specific_args(parent_parser) +# # FIXME: should be a str to choose among the available OOD criteria +# # rather than a boolean, but it is not possible since +# # ClassificationSingle and ClassificationEnsemble have different OOD +# # criteria. +# parent_parser.add_argument( +# "--mutual_information", +# dest="use_mi", +# action="store_true", +# default=False, +# ) +# parent_parser.add_argument( +# "--variation_ratio", +# dest="use_variation_ratio", +# action="store_true", +# default=False, +# ) +# parent_parser.add_argument( +# "--num_estimators", +# type=int, +# default=None, +# help="Number of estimators for ensemble", +# ) +# return parent_parser + + +class ClassificationRoutine(LightningModule): def __init__( self, num_classes: int, model: nn.Module, loss: type[nn.Module], - optimization_procedure: Any, + num_estimators: int, format_batch_fn: nn.Module | None = None, mixtype: str = "erm", mixmode: str = "elem", @@ -52,80 +828,73 @@ def __init__( evaluate_ood: bool = False, use_entropy: bool = False, use_logits: bool = False, + use_mi: bool = False, + use_variation_ratio: bool = False, log_plots: bool = False, calibration_set: Callable | None = None, - **kwargs, ) -> None: - """Classification routine for single models. + """Classification routine. Args: - num_classes (int): Number of classes. - model (nn.Module): Model to train. - loss (type[nn.Module]): Loss function. - optimization_procedure (Any): Optimization procedure. - format_batch_fn (nn.Module, optional): Function to format the batch. - Defaults to :class:`torch.nn.Identity()`. - mixtype (str, optional): Mixup type. Defaults to ``"erm"``. - mixmode (str, optional): Mixup mode. Defaults to ``"elem"``. - dist_sim (str, optional): Distance similarity. Defaults to ``"emb"``. - kernel_tau_max (float, optional): Maximum value for the kernel tau. - Defaults to 1.0. - kernel_tau_std (float, optional): Standard deviation for the kernel tau. - Defaults to 0.5. - mixup_alpha (float, optional): Alpha parameter for Mixup. Defaults to 0. - cutmix_alpha (float, optional): Alpha parameter for Cutmix. - Defaults to 0. - evaluate_ood (bool, optional): Indicates whether to evaluate the OOD - detection performance or not. Defaults to ``False``. - use_entropy (bool, optional): Indicates whether to use the entropy - values as the OOD criterion or not. Defaults to ``False``. - use_logits (bool, optional): Indicates whether to use the logits as the - OOD criterion or not. Defaults to ``False``. - log_plots (bool, optional): Indicates whether to log plots from - metrics. Defaults to ``False``. - calibration_set (Callable, optional): Function to get the calibration - set. Defaults to ``None``. - kwargs (Any): Additional arguments. - - Note: - The default OOD criterion is the softmax confidence score. - - Warning: - Make sure at most only one of :attr:`use_entropy` and :attr:`use_logits` - attributes is set to ``True``. Otherwise a :class:`ValueError()` will - be raised. + num_classes (int): _description_ + model (nn.Module): _description_ + loss (type[nn.Module]): _description_ + num_estimators (int): _description_ + format_batch_fn (nn.Module | None, optional): _description_. Defaults to None. + mixtype (str, optional): _description_. Defaults to "erm". + mixmode (str, optional): _description_. Defaults to "elem". + dist_sim (str, optional): _description_. Defaults to "emb". + kernel_tau_max (float, optional): _description_. Defaults to 1.0. + kernel_tau_std (float, optional): _description_. Defaults to 0.5. + mixup_alpha (float, optional): _description_. Defaults to 0. + cutmix_alpha (float, optional): _description_. Defaults to 0. + evaluate_ood (bool, optional): _description_. Defaults to False. + use_entropy (bool, optional): _description_. Defaults to False. + use_logits (bool, optional): _description_. Defaults to False. + use_mi (bool, optional): _description_. Defaults to False. + use_variation_ratio (bool, optional): _description_. Defaults to False. + log_plots (bool, optional): _description_. Defaults to False. + calibration_set (Callable | None, optional): _description_. Defaults to None. + + Raises: + ValueError: _description_ + ValueError: _description_ + ValueError: _description_ + ValueError: _description_ """ super().__init__() if format_batch_fn is None: format_batch_fn = nn.Identity() - self.save_hyperparameters( - ignore=[ - "model", - "loss", - "optimization_procedure", - "format_batch_fn", - "calibration_set", - ] - ) - - if (use_logits + use_entropy) > 1: + if (use_logits + use_entropy + use_mi + use_variation_ratio) > 1: raise ValueError("You cannot choose more than one OOD criterion.") + if not isinstance(num_estimators, int) and num_estimators < 1: + raise ValueError( + "The number of estimators must be a positive integer >= 1." + f"Got {num_estimators}." + ) + + if num_estimators == 1 and (use_mi or use_variation_ratio): + raise ValueError( + "You cannot use mutual information or variation ratio with a single" + " model." + ) + self.num_classes = num_classes + self.num_estimators = num_estimators self.evaluate_ood = evaluate_ood self.use_logits = use_logits self.use_entropy = use_entropy + self.use_mi = use_mi + self.use_variation_ratio = use_variation_ratio self.log_plots = log_plots - self.calibration_set = calibration_set - self.binary_cls = num_classes == 1 self.model = model self.loss = loss - self.optimization_procedure = optimization_procedure # batch format self.format_batch_fn = format_batch_fn @@ -172,19 +941,19 @@ def __init__( self.test_ood_metrics = ood_metrics.clone(prefix="hp/test_") self.test_entropy_ood = Entropy() - if mixup_alpha < 0 or cutmix_alpha < 0: - raise ValueError( - "Cutmix alpha and Mixup alpha must be positive." - f"Got {mixup_alpha} and {cutmix_alpha}." - ) - self.mixtype = mixtype self.mixmode = mixmode self.dist_sim = dist_sim + if num_estimators == 1: + if mixup_alpha < 0 or cutmix_alpha < 0: + raise ValueError( + "Cutmix alpha and Mixup alpha must be positive." + f"Got {mixup_alpha} and {cutmix_alpha}." + ) - self.mixup = self.init_mixup( - mixup_alpha, cutmix_alpha, kernel_tau_max, kernel_tau_std - ) + self.mixup = self.init_mixup( + mixup_alpha, cutmix_alpha, kernel_tau_max, kernel_tau_std + ) # Handle ELBO special cases self.is_elbo = ( @@ -196,210 +965,23 @@ def __init__( isinstance(self.loss, partial) and self.loss.func == DECLoss ) - def configure_optimizers(self) -> Any: - return self.optimization_procedure(self) - - @property - def criterion(self) -> nn.Module: - if self.is_elbo: - self.loss = partial(self.loss, model=self.model) - return self.loss() - - def forward(self, inputs: Tensor) -> Tensor: - return self.model.forward(inputs) - - def on_train_start(self) -> None: - # hyperparameters for performances - param = {} - param["storage"] = f"{get_model_size_mb(self)} MB" - if self.logger is not None: # coverage: ignore - self.logger.log_hyperparams( - Namespace(**param), + # metrics for ensembles only + if self.num_estimators > 1: + ens_metrics = MetricCollection( { - "hp/val_nll": 0, - "hp/val_acc": 0, - "hp/test_acc": 0, - "hp/test_nll": 0, - "hp/test_ece": 0, - "hp/test_brier": 0, - "hp/test_entropy_id": 0, - "hp/test_entropy_ood": 0, - "hp/test_aupr": 0, - "hp/test_auroc": 0, - "hp/test_fpr95": 0, - "hp/ts_test_nll": 0, - "hp/ts_test_ece": 0, - "hp/ts_test_brier": 0, - }, + "disagreement": Disagreement(), + "mi": MutualInformation(), + "entropy": Entropy(), + } ) - - def training_step( - self, batch: tuple[Tensor, Tensor], batch_idx: int - ) -> STEP_OUTPUT: - if self.mixtype == "kernel_warping": - if self.dist_sim == "emb": - with torch.no_grad(): - feats = self.model.feats_forward(batch[0]).detach() - - batch = self.mixup(*batch, feats) - elif self.dist_sim == "inp": - batch = self.mixup(*batch, batch[0]) - else: - batch = self.mixup(*batch) - - inputs, targets = self.format_batch_fn(batch) - - if self.is_elbo: - loss = self.criterion(inputs, targets) - else: - logits = self.forward(inputs) - # BCEWithLogitsLoss expects float targets - if self.binary_cls and self.loss == nn.BCEWithLogitsLoss: - logits = logits.squeeze(-1) - targets = targets.float() - - if not self.is_dec: - loss = self.criterion(logits, targets) - else: - loss = self.criterion(logits, targets, self.current_epoch) - self.log("train_loss", loss) - return loss - - def validation_step( - self, batch: tuple[Tensor, Tensor], batch_idx: int - ) -> None: - inputs, targets = batch - logits = self.forward(inputs) - - if self.binary_cls: - probs = torch.sigmoid(logits).squeeze(-1) - else: - probs = F.softmax(logits, dim=-1) - - self.val_cls_metrics.update(probs, targets) - - def validation_epoch_end( - self, outputs: EPOCH_OUTPUT | list[EPOCH_OUTPUT] - ) -> None: - self.log_dict(self.val_cls_metrics.compute()) - self.val_cls_metrics.reset() - - def on_test_start(self) -> None: - if self.calibration_set is not None: - self.scaler = TemperatureScaler(device=self.device).fit( - model=self.model, calibration_set=self.calibration_set() + self.test_id_ens_metrics = ens_metrics.clone( + prefix="hp/test_id_ens_" ) - self.cal_model = torch.nn.Sequential(self.model, self.scaler) - else: - self.scaler = None - self.cal_model = None - - def test_step( - self, - batch: tuple[Tensor, Tensor], - batch_idx: int, - dataloader_idx: int | None = 0, - ) -> Tensor: - inputs, targets = batch - logits = self.forward(inputs) - - if self.binary_cls: - probs = torch.sigmoid(logits).squeeze(-1) - else: - probs = F.softmax(logits, dim=-1) - - # self.cal_plot.update(probs, targets) - confs = probs.max(dim=-1)[0] - - if self.use_logits: - ood_scores = -logits.max(dim=-1)[0] - elif self.use_entropy: - ood_scores = torch.special.entr(probs).sum(dim=-1) - else: - ood_scores = -confs - if ( - self.calibration_set is not None - and self.scaler is not None - and self.cal_model is not None - ): - cal_logits = self.cal_model(inputs) - cal_probs = F.softmax(cal_logits, dim=-1) - self.ts_cls_metrics.update(cal_probs, targets) - - if dataloader_idx == 0: - self.test_cls_metrics.update(probs, targets) - self.test_entropy_id(probs) - self.log( - "hp/test_entropy_id", - self.test_entropy_id, - on_epoch=True, - add_dataloader_idx=False, - ) if self.evaluate_ood: - self.test_ood_metrics.update( - ood_scores, torch.zeros_like(targets) + self.test_ood_ens_metrics = ens_metrics.clone( + prefix="hp/test_ood_ens_" ) - elif self.evaluate_ood and dataloader_idx == 1: - self.test_ood_metrics.update(ood_scores, torch.ones_like(targets)) - self.test_entropy_ood(probs) - self.log( - "hp/test_entropy_ood", - self.test_entropy_ood, - on_epoch=True, - add_dataloader_idx=False, - ) - return logits - - def test_epoch_end( - self, outputs: EPOCH_OUTPUT | list[EPOCH_OUTPUT] - ) -> None: - self.log_dict( - self.test_cls_metrics.compute(), - ) - - if ( - self.calibration_set is not None - and self.scaler is not None - and self.cal_model is not None - ): - self.log_dict(self.ts_cls_metrics.compute()) - self.ts_cls_metrics.reset() - - if self.evaluate_ood: - self.log_dict( - self.test_ood_metrics.compute(), - ) - self.test_ood_metrics.reset() - - if isinstance(self.logger, TensorBoardLogger) and self.log_plots: - self.logger.experiment.add_figure( - "Calibration Plot", self.test_cls_metrics["ece"].plot()[0] - ) - - if self.evaluate_ood: - id_logits = torch.cat(outputs[0], 0).float().cpu() - ood_logits = torch.cat(outputs[1], 0).float().cpu() - - id_probs = F.softmax(id_logits, dim=-1) - ood_probs = F.softmax(ood_logits, dim=-1) - - logits_fig = plot_hist( - [id_logits.max(-1).values, ood_logits.max(-1).values], - 20, - "Histogram of the logits", - )[0] - probs_fig = plot_hist( - [id_probs.max(-1).values, ood_probs.max(-1).values], - 20, - "Histogram of the likelihoods", - )[0] - self.logger.experiment.add_figure("Logit Histogram", logits_fig) - self.logger.experiment.add_figure( - "Likelihood Histogram", probs_fig - ) - - self.test_cls_metrics.reset() def init_mixup( self, @@ -444,206 +1026,53 @@ def init_mixup( ) return lambda x, y: (x, y) - @staticmethod - def add_model_specific_args( - parent_parser: ArgumentParser, - ) -> ArgumentParser: - """Defines the routine's attributes via command-line options. - - Args: - parent_parser (ArgumentParser): Parent parser to be completed. - - Adds: - - ``--entropy``: sets :attr:`use_entropy` to ``True``. - - ``--logits``: sets :attr:`use_logits` to ``True``. - - ``--mixup_alpha``: sets :attr:`mixup_alpha` for Mixup - - ``--cutmix_alpha``: sets :attr:`cutmix_alpha` for Cutmix - - ``--mixtype``: sets :attr:`mixtype` for Mixup - - ``--mixmode``: sets :attr:`mixmode` for Mixup - - ``--dist_sim``: sets :attr:`dist_sim` for Mixup - - ``--kernel_tau_max``: sets :attr:`kernel_tau_max` for Mixup - - ``--kernel_tau_std``: sets :attr:`kernel_tau_std` for Mixup - """ - parent_parser.add_argument( - "--entropy", dest="use_entropy", action="store_true" - ) - parent_parser.add_argument( - "--logits", dest="use_logits", action="store_true" - ) - - # Mixup args - parent_parser.add_argument( - "--mixup_alpha", dest="mixup_alpha", type=float, default=0 - ) - parent_parser.add_argument( - "--cutmix_alpha", dest="cutmix_alpha", type=float, default=0 - ) - parent_parser.add_argument( - "--mixtype", dest="mixtype", type=str, default="erm" - ) - parent_parser.add_argument( - "--mixmode", dest="mixmode", type=str, default="elem" - ) - parent_parser.add_argument( - "--dist_sim", dest="dist_sim", type=str, default="emb" - ) - parent_parser.add_argument( - "--kernel_tau_max", dest="kernel_tau_max", type=float, default=1.0 - ) - parent_parser.add_argument( - "--kernel_tau_std", dest="kernel_tau_std", type=float, default=0.5 - ) - return parent_parser - - -class ClassificationEnsemble(ClassificationSingle): - def __init__( - self, - num_classes: int, - model: nn.Module, - loss: type[nn.Module], - optimization_procedure: Any, - num_estimators: int, - format_batch_fn: nn.Module | None = None, - mixtype: str = "erm", - mixmode: str = "elem", - dist_sim: str = "emb", - kernel_tau_max: float = 1.0, - kernel_tau_std: float = 0.5, - mixup_alpha: float = 0, - cutmix_alpha: float = 0, - evaluate_ood: bool = False, - use_entropy: bool = False, - use_logits: bool = False, - use_mi: bool = False, - use_variation_ratio: bool = False, - log_plots: bool = False, - **kwargs, - ) -> None: - """Classification routine for ensemble models. - - Args: - num_classes (int): Number of classes. - model (nn.Module): Model to train. - loss (type[nn.Module]): Loss function. - optimization_procedure (Any): Optimization procedure. - num_estimators (int): Number of estimators in the ensemble. - format_batch_fn (nn.Module, optional): Function to format the batch. - Defaults to :class:`torch.nn.Identity()`. - mixtype (str, optional): Mixup type. Defaults to ``"erm"``. - mixmode (str, optional): Mixup mode. Defaults to ``"elem"``. - dist_sim (str, optional): Distance similarity. Defaults to ``"emb"``. - kernel_tau_max (float, optional): Maximum value for the kernel tau. - Defaults to 1.0. - kernel_tau_std (float, optional): Standard deviation for the kernel tau. - Defaults to 0.5. - mixup_alpha (float, optional): Alpha parameter for Mixup. Defaults to 0. - cutmix_alpha (float, optional): Alpha parameter for Cutmix. - Defaults to 0. - evaluate_ood (bool, optional): Indicates whether to evaluate the OOD - detection performance or not. Defaults to ``False``. - use_entropy (bool, optional): Indicates whether to use the entropy - values as the OOD criterion or not. Defaults to ``False``. - use_logits (bool, optional): Indicates whether to use the logits as the - OOD criterion or not. Defaults to ``False``. - use_mi (bool, optional): Indicates whether to use the mutual - information as the OOD criterion or not. Defaults to ``False``. - use_variation_ratio (bool, optional): Indicates whether to use the - variation ratio as the OOD criterion or not. Defaults to ``False``. - log_plots (bool, optional): Indicates whether to log plots from - metrics. Defaults to ``False``. - calibration_set (Callable, optional): Function to get the calibration - set. Defaults to ``None``. - kwargs (Any): Additional arguments. - - Note: - The default OOD criterion is the averaged softmax confidence score. - - Warning: - Make sure at most only one of :attr:`use_entropy`, :attr:`use_logits` - , :attr:`use_mi`, and :attr:`use_variation_ratio` attributes is set to - ``True``. Otherwise a :class:`ValueError()` will be raised. - """ - super().__init__( - num_classes=num_classes, - model=model, - loss=loss, - optimization_procedure=optimization_procedure, - format_batch_fn=format_batch_fn, - mixtype=mixtype, - mixmode=mixmode, - dist_sim=dist_sim, - kernel_tau_max=kernel_tau_max, - kernel_tau_std=kernel_tau_std, - mixup_alpha=mixup_alpha, - cutmix_alpha=cutmix_alpha, - evaluate_ood=evaluate_ood, - use_entropy=use_entropy, - use_logits=use_logits, - **kwargs, - ) - - self.num_estimators = num_estimators - - self.use_mi = use_mi - self.use_variation_ratio = use_variation_ratio - self.log_plots = log_plots - - if ( - self.use_logits - + self.use_entropy - + self.use_mi - + self.use_variation_ratio - ) > 1: - raise ValueError("You cannot choose more than one OOD criterion.") + def on_train_start(self) -> None: + init_metrics = {k: 0 for k in self.val_cls_metrics} + init_metrics.update({k: 0 for k in self.test_cls_metrics}) - # metrics for ensembles only - ens_metrics = MetricCollection( - { - "disagreement": Disagreement(), - "mi": MutualInformation(), - "entropy": Entropy(), - } - ) - self.test_id_ens_metrics = ens_metrics.clone(prefix="hp/test_id_ens_") + # self.hparams.storage = f"{get_model_size_mb(self)} MB" - if self.evaluate_ood: - self.test_ood_ens_metrics = ens_metrics.clone( - prefix="hp/test_ood_ens_" - ) - - def on_train_start(self) -> None: - param = {} - param["storage"] = f"{get_model_size_mb(self)} MB" if self.logger is not None: # coverage: ignore self.logger.log_hyperparams( - Namespace(**param), - { - "hp/val_nll": 0, - "hp/val_acc": 0, - "hp/test_acc": 0, - "hp/test_nll": 0, - "hp/test_ece": 0, - "hp/test_brier": 0, - "hp/test_entropy_id": 0, - "hp/test_entropy_ood": 0, - "hp/test_aupr": 0, - "hp/test_auroc": 0, - "hp/test_fpr95": 0, - "hp/test_id_ens_disagreement": 0, - "hp/test_id_ens_mi": 0, - "hp/test_id_ens_entropy": 0, - "hp/test_ood_ens_disagreement": 0, - "hp/test_ood_ens_mi": 0, - "hp/test_ood_ens_entropy": 0, - }, + self.hparams, + init_metrics, + ) + + def on_test_start(self) -> None: + if self.calibration_set is not None: + self.scaler = TemperatureScaler(device=self.device).fit( + model=self.model, calibration_set=self.calibration_set() ) + self.cal_model = torch.nn.Sequential(self.model, self.scaler) + else: + self.scaler = None + self.cal_model = None + + @property + def criterion(self) -> nn.Module: + if self.is_elbo: + self.loss = partial(self.loss, model=self.model) + return self.loss() + + def forward(self, inputs: Tensor) -> Tensor: + return self.model.forward(inputs) def training_step( self, batch: tuple[Tensor, Tensor], batch_idx: int ) -> STEP_OUTPUT: - batch = self.mixup(*batch) - # eventual input repeat is done in the model + # Mixup only for single models + if self.num_estimators == 1: + if self.mixtype == "kernel_warping": + if self.dist_sim == "emb": + with torch.no_grad(): + feats = self.model.feats_forward(batch[0]).detach() + + batch = self.mixup(*batch, feats) + elif self.dist_sim == "inp": + batch = self.mixup(*batch, batch[0]) + else: + batch = self.mixup(*batch) + inputs, targets = self.format_batch_fn(batch) if self.is_elbo: @@ -667,33 +1096,36 @@ def validation_step( self, batch: tuple[Tensor, Tensor], batch_idx: int ) -> None: inputs, targets = batch - logits = self.forward(inputs) - logits = rearrange(logits, "(m b) c -> b m c", m=self.num_estimators) + logits = self.forward(inputs) # (b, c) or (m, b, c) + if logits.ndim == 2: + logits = logits.unsqueeze(0) + if self.binary_cls: probs_per_est = torch.sigmoid(logits).squeeze(-1) else: probs_per_est = F.softmax(logits, dim=-1) - probs = probs_per_est.mean(dim=1) + probs = probs_per_est.mean(dim=0) self.val_cls_metrics.update(probs, targets) def test_step( self, batch: tuple[Tensor, Tensor], batch_idx: int, - dataloader_idx: int | None = 0, - ) -> Tensor: + dataloader_idx: int = 0, + ) -> None: inputs, targets = batch logits = self.forward(inputs) - logits = rearrange(logits, "(n b) c -> b n c", n=self.num_estimators) + if logits.ndim == 2: + logits = logits.unsqueeze(0) if self.binary_cls: - probs_per_est = torch.sigmoid(logits) + probs_per_est = torch.sigmoid(logits).squeeze(-1) else: probs_per_est = F.softmax(logits, dim=-1) - probs = probs_per_est.mean(dim=1) - # self.cal_plot.update(probs, targets) + probs = probs_per_est.mean(dim=0) + confs = probs.max(-1)[0] if self.use_logits: @@ -711,6 +1143,17 @@ def test_step( else: ood_scores = -confs + # Scaling for single models + if ( + self.num_estimators == 1 + and self.calibration_set is not None + and self.scaler is not None + and self.cal_model is not None + ): + cal_logits = self.cal_model(inputs) + cal_probs = F.softmax(cal_logits, dim=-1) + self.ts_cls_metrics.update(cal_probs, targets) + if dataloader_idx == 0: # squeeze if binary classification only for binary metrics self.test_cls_metrics.update( @@ -718,8 +1161,6 @@ def test_step( targets, ) self.test_entropy_id(probs) - - self.test_id_ens_metrics.update(probs_per_est) self.log( "hp/test_entropy_id", self.test_entropy_id, @@ -727,116 +1168,66 @@ def test_step( add_dataloader_idx=False, ) - if self.evaluate_ood: - self.test_ood_metrics.update( - ood_scores, torch.zeros_like(targets) - ) + if self.num_estimators > 1: + self.test_id_ens_metrics.update(probs_per_est) + + if self.evaluate_ood: + self.test_ood_metrics.update( + ood_scores, torch.zeros_like(targets) + ) + elif self.evaluate_ood and dataloader_idx == 1: self.test_ood_metrics.update(ood_scores, torch.ones_like(targets)) self.test_entropy_ood(probs) - self.test_ood_ens_metrics.update(probs_per_est) self.log( "hp/test_entropy_ood", self.test_entropy_ood, on_epoch=True, add_dataloader_idx=False, ) - return logits + if self.num_estimators > 1: + self.test_ood_ens_metrics.update(probs_per_est) - def test_epoch_end( - self, outputs: EPOCH_OUTPUT | list[EPOCH_OUTPUT] - ) -> None: + def on_validation_epoch_end(self) -> None: + self.log_dict(self.val_cls_metrics.compute()) + self.val_cls_metrics.reset() + + def on_test_epoch_end(self) -> None: self.log_dict( self.test_cls_metrics.compute(), ) - self.log_dict( - self.test_id_ens_metrics.compute(), - ) + if ( + self.num_estimators == 1 + and self.calibration_set is not None + and self.scaler is not None + and self.cal_model is not None + ): + self.log_dict(self.ts_cls_metrics.compute()) + self.ts_cls_metrics.reset() - if self.evaluate_ood: + if self.num_estimators > 1: self.log_dict( - self.test_ood_metrics.compute(), + self.test_id_ens_metrics.compute(), ) + self.test_id_ens_metrics.reset() + + if self.evaluate_ood: self.log_dict( - self.test_ood_ens_metrics.compute(), + self.test_ood_metrics.compute(), ) - self.test_ood_metrics.reset() - self.test_ood_ens_metrics.reset() + if self.num_estimators > 1: + self.log_dict( + self.test_ood_ens_metrics.compute(), + ) + self.test_ood_ens_metrics.reset() if isinstance(self.logger, TensorBoardLogger) and self.log_plots: self.logger.experiment.add_figure( "Calibration Plot", self.test_cls_metrics["ece"].plot()[0] ) - if self.evaluate_ood: - id_logits = torch.cat(outputs[0], 0).float().cpu() - ood_logits = torch.cat(outputs[1], 0).float().cpu() - - id_probs = F.softmax(id_logits, dim=-1) - ood_probs = F.softmax(ood_logits, dim=-1) - - logits_fig = plot_hist( - [ - id_logits.mean(1).max(-1).values, - ood_logits.mean(1).max(-1).values, - ], - 20, - "Histogram of the logits", - )[0] - probs_fig = plot_hist( - [ - id_probs.mean(1).max(-1).values, - ood_probs.mean(1).max(-1).values, - ], - 20, - "Histogram of the likelihoods", - )[0] - self.logger.experiment.add_figure("Logit Histogram", logits_fig) - self.logger.experiment.add_figure( - "Likelihood Histogram", probs_fig - ) + # TODO: plot histograms of logits and likelihoods self.test_cls_metrics.reset() - self.test_id_ens_metrics.reset() - - @staticmethod - def add_model_specific_args( - parent_parser: ArgumentParser, - ) -> ArgumentParser: - """Defines the routine's attributes via command-line options. - - Adds: - - ``--entropy``: sets :attr:`use_entropy` to ``True``. - - ``--logits``: sets :attr:`use_logits` to ``True``. - - ``--mutual_information``: sets :attr:`use_mi` to ``True``. - - ``--variation_ratio``: sets :attr:`use_variation_ratio` to ``True``. - - ``--num_estimators``: sets :attr:`num_estimators`. - """ - parent_parser = ClassificationSingle.add_model_specific_args( - parent_parser - ) - # FIXME: should be a str to choose among the available OOD criteria - # rather than a boolean, but it is not possible since - # ClassificationSingle and ClassificationEnsemble have different OOD - # criteria. - parent_parser.add_argument( - "--mutual_information", - dest="use_mi", - action="store_true", - default=False, - ) - parent_parser.add_argument( - "--variation_ratio", - dest="use_variation_ratio", - action="store_true", - default=False, - ) - parent_parser.add_argument( - "--num_estimators", - type=int, - default=None, - help="Number of estimators for ensemble", - ) - return parent_parser diff --git a/torch_uncertainty/routines/segmentation.py b/torch_uncertainty/routines/segmentation.py new file mode 100644 index 00000000..30cc9e80 --- /dev/null +++ b/torch_uncertainty/routines/segmentation.py @@ -0,0 +1,70 @@ +from lightning.pytorch import LightningModule +from lightning.pytorch.utilities.types import STEP_OUTPUT +from torch import Tensor, nn +from torchmetrics import MetricCollection + +from torch_uncertainty.metrics import IntersectionOverUnion + + +class SegmentationRoutine(LightningModule): + def __init__( + self, + num_classes: int, + model: nn.Module, + loss: nn.Module | None, + ) -> None: + super().__init__() + + self.num_classes = num_classes + self.model = model + self.loss = loss + + self.metric_to_monitor = "hp/val_iou" + + # metrics + seg_metrics = MetricCollection( + { + "iou": IntersectionOverUnion(num_classes=num_classes), + } + ) + + self.val_seg_metrics = seg_metrics.clone(prefix="hp/val_") + self.test_seg_metrics = seg_metrics.clone(prefix="hp/test_") + + def forward(self, img: Tensor) -> Tensor: + return self.model(img) + + def on_train_start(self) -> None: + init_metrics = {k: 0 for k in self.val_seg_metrics} + init_metrics.update({k: 0 for k in self.test_seg_metrics}) + + self.logger.log_hyperparams(self.hparams, init_metrics) + + def training_step( + self, batch: tuple[Tensor, Tensor], batch_idx: int + ) -> STEP_OUTPUT: + img, target = batch + pred = self.forward(img) + loss = self.loss(pred, target) + self.log("train_loss", loss) + return loss + + def validation_step( + self, batch: tuple[Tensor, Tensor], batch_idx: int + ) -> None: + img, target = batch + pred = self.forward(img) + self.val_seg_metrics.update(pred, target) + + def test_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> None: + img, target = batch + pred = self.forward(img) + self.test_seg_metrics.update(pred, target) + + def on_validation_epoch_end(self) -> None: + self.log_dict(self.val_seg_metrics.compute()) + self.val_seg_metrics.reset() + + def on_test_epoch_end(self) -> None: + self.log_dict(self.test_seg_metrics.compute()) + self.test_seg_metrics.reset() From 508c12deecc9807802d2ae47fcb14483f37001a2 Mon Sep 17 00:00:00 2001 From: alafage Date: Sat, 27 Jan 2024 16:34:01 +0100 Subject: [PATCH 008/148] :construction: Massive changes to upgrade Lightning 2.0 - :hammer: Rework classification experiment folder - :hammer: Only one Classification routine for single and ensemble models --- .../cifar10/configs/resnet.yaml | 34 + .../cifar10/configs/resnet18.yaml | 36 - .../cifar10/configs/resnet18/batched.yaml | 49 + .../cifar10/configs/resnet18/masked.yaml | 50 + .../cifar10/configs/resnet18/mimo.yaml | 50 + .../cifar10/configs/resnet18/packed.yaml | 51 + .../cifar10/configs/resnet18/standard.yaml | 48 + .../cifar10/configs/resnet50/batched.yaml | 50 + .../cifar10/configs/resnet50/masked.yaml | 51 + .../cifar10/configs/resnet50/mimo.yaml | 51 + .../cifar10/configs/resnet50/packed.yaml | 52 + .../cifar10/configs/resnet50/standard.yaml | 49 + .../cifar10/configs/wideresnet28x10.yaml | 47 + .../configs/wideresnet28x10/batched.yaml | 49 + .../configs/wideresnet28x10/masked.yaml | 50 + .../cifar10/configs/wideresnet28x10/mimo.yaml | 50 + .../configs/wideresnet28x10/packed.yaml | 51 + .../configs/wideresnet28x10/standard.yaml | 48 + experiments/classification/cifar10/readme.md | 55 +- experiments/classification/cifar10/resnet.py | 19 +- experiments/classification/cifar10/vgg.py | 45 +- .../classification/cifar10/wideresnet.py | 41 +- experiments/classification/cifar100/readme.md | 3 + pyproject.toml | 1 + tests/_dummies/baseline.py | 32 +- tests/baselines/test_batched.py | 19 +- tests/baselines/test_deep_ensembles.py | 14 +- tests/baselines/test_masked.py | 21 +- tests/baselines/test_mc_dropout.py | 18 +- tests/baselines/test_mimo.py | 19 +- tests/baselines/test_others.py | 6 +- tests/baselines/test_packed.py | 14 +- tests/baselines/test_standard.py | 9 +- tests/datamodules/test_cifar10_datamodule.py | 120 ++- tests/routines/test_classification.py | 699 ++++++------- tests/routines/test_regression.py | 320 +++--- tests/test_cli.py | 751 +++++++------- torch_uncertainty/__init__.py | 502 +++++----- torch_uncertainty/baselines/__init__.py | 6 +- .../baselines/classification/__init__.py | 6 +- .../classification/deep_ensembles.py | 64 ++ .../baselines/classification/resnet.py | 118 +-- .../baselines/classification/vgg.py | 420 ++++---- .../baselines/classification/wideresnet.py | 494 +++++---- torch_uncertainty/baselines/deep_ensembles.py | 110 -- torch_uncertainty/datamodules/cifar10.py | 23 +- torch_uncertainty/datamodules/cifar100.py | 24 +- torch_uncertainty/lightning_cli.py | 47 - torch_uncertainty/metrics/calibration.py | 8 +- torch_uncertainty/models/wideresnet/mimo.py | 1 - torch_uncertainty/routines/classification.py | 946 +++--------------- torch_uncertainty/routines/segmentation.py | 4 +- torch_uncertainty/utils/__init__.py | 1 + torch_uncertainty/utils/cli.py | 128 +++ 54 files changed, 2930 insertions(+), 3044 deletions(-) create mode 100644 experiments/classification/cifar10/configs/resnet.yaml delete mode 100644 experiments/classification/cifar10/configs/resnet18.yaml create mode 100644 experiments/classification/cifar10/configs/resnet18/batched.yaml create mode 100644 experiments/classification/cifar10/configs/resnet18/masked.yaml create mode 100644 experiments/classification/cifar10/configs/resnet18/mimo.yaml create mode 100644 experiments/classification/cifar10/configs/resnet18/packed.yaml create mode 100644 experiments/classification/cifar10/configs/resnet18/standard.yaml create mode 100644 experiments/classification/cifar10/configs/resnet50/batched.yaml create mode 100644 experiments/classification/cifar10/configs/resnet50/masked.yaml create mode 100644 experiments/classification/cifar10/configs/resnet50/mimo.yaml create mode 100644 experiments/classification/cifar10/configs/resnet50/packed.yaml create mode 100644 experiments/classification/cifar10/configs/resnet50/standard.yaml create mode 100644 experiments/classification/cifar10/configs/wideresnet28x10.yaml create mode 100644 experiments/classification/cifar10/configs/wideresnet28x10/batched.yaml create mode 100644 experiments/classification/cifar10/configs/wideresnet28x10/masked.yaml create mode 100644 experiments/classification/cifar10/configs/wideresnet28x10/mimo.yaml create mode 100644 experiments/classification/cifar10/configs/wideresnet28x10/packed.yaml create mode 100644 experiments/classification/cifar10/configs/wideresnet28x10/standard.yaml create mode 100644 experiments/classification/cifar100/readme.md create mode 100644 torch_uncertainty/baselines/classification/deep_ensembles.py delete mode 100644 torch_uncertainty/baselines/deep_ensembles.py delete mode 100644 torch_uncertainty/lightning_cli.py create mode 100644 torch_uncertainty/utils/cli.py diff --git a/experiments/classification/cifar10/configs/resnet.yaml b/experiments/classification/cifar10/configs/resnet.yaml new file mode 100644 index 00000000..21352497 --- /dev/null +++ b/experiments/classification/cifar10/configs/resnet.yaml @@ -0,0 +1,34 @@ +# lightning.pytorch==2.1.3 +seed_everything: true +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/ + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: hp/val_acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: hp/val_acc + patience: 1000 + check_finite: true +model: + num_classes: 10 + in_channels: 3 + loss: torch.nn.CrossEntropyLoss + style: cifar +data: + root: ./data + batch_size: 128 diff --git a/experiments/classification/cifar10/configs/resnet18.yaml b/experiments/classification/cifar10/configs/resnet18.yaml deleted file mode 100644 index 7caa52e6..00000000 --- a/experiments/classification/cifar10/configs/resnet18.yaml +++ /dev/null @@ -1,36 +0,0 @@ -# lightning.pytorch==2.1.3 -seed_everything: true -trainer: - precision: 16 - max_epochs: 75 -model: - num_classes: 10 - in_channels: 3 - loss: - - class_path: torch.nn.CrossEntropyLoss - version: "vanilla" - arch: 18 - style: cifar -data: - root: null - evaluate_ood: null - batch_size: null - val_split: 0.0 - num_workers: 1 - cutout: null - auto_augment: null - test_alt: null - corruption_severity: 1 - num_dataloaders: 1 - pin_memory: true - persistent_workers: true -optimizer: - lr: 0.05 - momentum: 0.9 - weight_decay: 5e-4 - nesterov: true -lr_scheduler: - milestones: - - 25 - - 50 - gamma: 0.1 diff --git a/experiments/classification/cifar10/configs/resnet18/batched.yaml b/experiments/classification/cifar10/configs/resnet18/batched.yaml new file mode 100644 index 00000000..aa2c92e1 --- /dev/null +++ b/experiments/classification/cifar10/configs/resnet18/batched.yaml @@ -0,0 +1,49 @@ +# lightning.pytorch==2.1.3 +seed_everything: true +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 75 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/resnet18 + name: batched + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: hp/val_acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: hp/val_acc + patience: 1000 + check_finite: true +model: + num_classes: 10 + in_channels: 3 + loss: torch.nn.CrossEntropyLoss + version: batched + arch: 18 + style: cifar + num_estimators: 4 +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.05 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 25 + - 50 + gamma: 0.1 diff --git a/experiments/classification/cifar10/configs/resnet18/masked.yaml b/experiments/classification/cifar10/configs/resnet18/masked.yaml new file mode 100644 index 00000000..9d92ef08 --- /dev/null +++ b/experiments/classification/cifar10/configs/resnet18/masked.yaml @@ -0,0 +1,50 @@ +# lightning.pytorch==2.1.3 +seed_everything: true +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 75 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/resnet18 + name: masked + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: hp/val_acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: hp/val_acc + patience: 1000 + check_finite: true +model: + num_classes: 10 + in_channels: 3 + loss: torch.nn.CrossEntropyLoss + version: masked + arch: 18 + style: cifar + num_estimators: 4 + scale: 2 +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.05 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 25 + - 50 + gamma: 0.1 diff --git a/experiments/classification/cifar10/configs/resnet18/mimo.yaml b/experiments/classification/cifar10/configs/resnet18/mimo.yaml new file mode 100644 index 00000000..92414eb4 --- /dev/null +++ b/experiments/classification/cifar10/configs/resnet18/mimo.yaml @@ -0,0 +1,50 @@ +# lightning.pytorch==2.1.3 +seed_everything: true +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 75 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/resnet18 + name: mimo + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: hp/val_acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: hp/val_acc + patience: 1000 + check_finite: true +model: + num_classes: 10 + in_channels: 3 + loss: torch.nn.CrossEntropyLoss + version: mimo + arch: 18 + style: cifar + num_estimators: 4 + rho: 1.0 +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.05 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 25 + - 50 + gamma: 0.1 diff --git a/experiments/classification/cifar10/configs/resnet18/packed.yaml b/experiments/classification/cifar10/configs/resnet18/packed.yaml new file mode 100644 index 00000000..a9e9479e --- /dev/null +++ b/experiments/classification/cifar10/configs/resnet18/packed.yaml @@ -0,0 +1,51 @@ +# lightning.pytorch==2.1.3 +seed_everything: true +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 75 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/resnet18 + name: packed + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: hp/val_acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: hp/val_acc + patience: 1000 + check_finite: true +model: + num_classes: 10 + in_channels: 3 + loss: torch.nn.CrossEntropyLoss + version: packed + arch: 18 + style: cifar + num_estimators: 4 + alpha: 2 + gamma: 2 +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.05 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 25 + - 50 + gamma: 0.1 diff --git a/experiments/classification/cifar10/configs/resnet18/standard.yaml b/experiments/classification/cifar10/configs/resnet18/standard.yaml new file mode 100644 index 00000000..e6cea671 --- /dev/null +++ b/experiments/classification/cifar10/configs/resnet18/standard.yaml @@ -0,0 +1,48 @@ +# lightning.pytorch==2.1.3 +seed_everything: true +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 75 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/resnet18 + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: hp/val_acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: hp/val_acc + patience: 1000 + check_finite: true +model: + num_classes: 10 + in_channels: 3 + loss: torch.nn.CrossEntropyLoss + version: vanilla + arch: 18 + style: cifar +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.05 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 25 + - 50 + gamma: 0.1 diff --git a/experiments/classification/cifar10/configs/resnet50/batched.yaml b/experiments/classification/cifar10/configs/resnet50/batched.yaml new file mode 100644 index 00000000..ec519698 --- /dev/null +++ b/experiments/classification/cifar10/configs/resnet50/batched.yaml @@ -0,0 +1,50 @@ +# lightning.pytorch==2.1.3 +seed_everything: true +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 200 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/resnet50 + name: batched + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: hp/val_acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: hp/val_acc + patience: 1000 + check_finite: true +model: + num_classes: 10 + in_channels: 3 + loss: torch.nn.CrossEntropyLoss + version: batched + arch: 50 + style: cifar + num_estimators: 4 +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.08 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 60 + - 120 + - 160 + gamma: 0.2 diff --git a/experiments/classification/cifar10/configs/resnet50/masked.yaml b/experiments/classification/cifar10/configs/resnet50/masked.yaml new file mode 100644 index 00000000..42efac31 --- /dev/null +++ b/experiments/classification/cifar10/configs/resnet50/masked.yaml @@ -0,0 +1,51 @@ +# lightning.pytorch==2.1.3 +seed_everything: true +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 200 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/resnet50 + name: masked + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: hp/val_acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: hp/val_acc + patience: 1000 + check_finite: true +model: + num_classes: 10 + in_channels: 3 + loss: torch.nn.CrossEntropyLoss + version: masked + arch: 50 + style: cifar + num_estimators: 4 + scale: 2 +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.1 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 60 + - 120 + - 160 + gamma: 0.2 diff --git a/experiments/classification/cifar10/configs/resnet50/mimo.yaml b/experiments/classification/cifar10/configs/resnet50/mimo.yaml new file mode 100644 index 00000000..906207ba --- /dev/null +++ b/experiments/classification/cifar10/configs/resnet50/mimo.yaml @@ -0,0 +1,51 @@ +# lightning.pytorch==2.1.3 +seed_everything: true +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 200 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/resnet50 + name: mimo + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: hp/val_acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: hp/val_acc + patience: 1000 + check_finite: true +model: + num_classes: 10 + in_channels: 3 + loss: torch.nn.CrossEntropyLoss + version: mimo + arch: 50 + style: cifar + num_estimators: 4 + rho: 1.0 +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.1 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 60 + - 120 + - 160 + gamma: 0.2 diff --git a/experiments/classification/cifar10/configs/resnet50/packed.yaml b/experiments/classification/cifar10/configs/resnet50/packed.yaml new file mode 100644 index 00000000..4a8b057e --- /dev/null +++ b/experiments/classification/cifar10/configs/resnet50/packed.yaml @@ -0,0 +1,52 @@ +# lightning.pytorch==2.1.3 +seed_everything: true +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 200 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/resnet50 + name: packed + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: hp/val_acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: hp/val_acc + patience: 1000 + check_finite: true +model: + num_classes: 10 + in_channels: 3 + loss: torch.nn.CrossEntropyLoss + version: packed + arch: 50 + style: cifar + num_estimators: 4 + alpha: 2 + gamma: 2 +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.1 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 60 + - 120 + - 160 + gamma: 0.2 diff --git a/experiments/classification/cifar10/configs/resnet50/standard.yaml b/experiments/classification/cifar10/configs/resnet50/standard.yaml new file mode 100644 index 00000000..f57e9d07 --- /dev/null +++ b/experiments/classification/cifar10/configs/resnet50/standard.yaml @@ -0,0 +1,49 @@ +# lightning.pytorch==2.1.3 +seed_everything: true +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 200 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/resnet50 + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: hp/val_acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: hp/val_acc + patience: 1000 + check_finite: true +model: + num_classes: 10 + in_channels: 3 + loss: torch.nn.CrossEntropyLoss + version: vanilla + arch: 50 + style: cifar +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.1 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 60 + - 120 + - 160 + gamma: 0.2 diff --git a/experiments/classification/cifar10/configs/wideresnet28x10.yaml b/experiments/classification/cifar10/configs/wideresnet28x10.yaml new file mode 100644 index 00000000..a8c05f36 --- /dev/null +++ b/experiments/classification/cifar10/configs/wideresnet28x10.yaml @@ -0,0 +1,47 @@ +# lightning.pytorch==2.1.3 +seed_everything: true +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 200 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/wideresnet28x10 + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: hp/val_acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: hp/val_acc + patience: 1000 + check_finite: true +model: + num_classes: 10 + in_channels: 3 + loss: torch.nn.CrossEntropyLoss + style: cifar +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.1 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 60 + - 120 + - 160 + gamma: 0.2 diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/batched.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/batched.yaml new file mode 100644 index 00000000..48a9d817 --- /dev/null +++ b/experiments/classification/cifar10/configs/wideresnet28x10/batched.yaml @@ -0,0 +1,49 @@ +# lightning.pytorch==2.1.3 +seed_everything: true +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 200 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/resnet18 + name: batched + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: hp/val_acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: hp/val_acc + patience: 1000 + check_finite: true +model: + num_classes: 10 + in_channels: 3 + loss: torch.nn.CrossEntropyLoss + version: batched + style: cifar + num_estimators: 4 +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.1 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 60 + - 120 + - 160 + gamma: 0.1 diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/masked.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/masked.yaml new file mode 100644 index 00000000..3fb765dd --- /dev/null +++ b/experiments/classification/cifar10/configs/wideresnet28x10/masked.yaml @@ -0,0 +1,50 @@ +# lightning.pytorch==2.1.3 +seed_everything: true +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 200 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/wideresnet28x10 + name: masked + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: hp/val_acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: hp/val_acc + patience: 1000 + check_finite: true +model: + num_classes: 10 + in_channels: 3 + loss: torch.nn.CrossEntropyLoss + version: masked + style: cifar + num_estimators: 4 + scale: 2 +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.1 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 60 + - 120 + - 160 + gamma: 0.1 diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/mimo.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/mimo.yaml new file mode 100644 index 00000000..e9149c0e --- /dev/null +++ b/experiments/classification/cifar10/configs/wideresnet28x10/mimo.yaml @@ -0,0 +1,50 @@ +# lightning.pytorch==2.1.3 +seed_everything: true +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 200 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/wideresnet28x10 + name: mimo + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: hp/val_acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: hp/val_acc + patience: 1000 + check_finite: true +model: + num_classes: 10 + in_channels: 3 + loss: torch.nn.CrossEntropyLoss + version: mimo + style: cifar + num_estimators: 4 + rho: 1.0 +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.1 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 60 + - 120 + - 160 + gamma: 0.1 diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/packed.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/packed.yaml new file mode 100644 index 00000000..10c757b7 --- /dev/null +++ b/experiments/classification/cifar10/configs/wideresnet28x10/packed.yaml @@ -0,0 +1,51 @@ +# lightning.pytorch==2.1.3 +seed_everything: true +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 200 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/wideresnet28x10 + name: packed + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: hp/val_acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: hp/val_acc + patience: 1000 + check_finite: true +model: + num_classes: 10 + in_channels: 3 + loss: torch.nn.CrossEntropyLoss + version: packed + style: cifar + num_estimators: 4 + alpha: 2 + gamma: 2 +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.1 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 60 + - 120 + - 160 + gamma: 0.2 diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/standard.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/standard.yaml new file mode 100644 index 00000000..78b8a74b --- /dev/null +++ b/experiments/classification/cifar10/configs/wideresnet28x10/standard.yaml @@ -0,0 +1,48 @@ +# lightning.pytorch==2.1.3 +seed_everything: true +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 200 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/wideresnet28x10 + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: hp/val_acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: hp/val_acc + patience: 1000 + check_finite: true +model: + num_classes: 10 + in_channels: 3 + loss: torch.nn.CrossEntropyLoss + version: vanilla + style: cifar +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.1 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 60 + - 120 + - 160 + gamma: 0.2 diff --git a/experiments/classification/cifar10/readme.md b/experiments/classification/cifar10/readme.md index bb5b06b4..286b73a1 100644 --- a/experiments/classification/cifar10/readme.md +++ b/experiments/classification/cifar10/readme.md @@ -4,14 +4,63 @@ This folder contains the code to train models on the CIFAR10 dataset. The task i ## ResNet-backbone models -`torch-uncertainty` leverages [LightningCLI](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.cli.LightningCLI.html#lightning.pytorch.cli.LightningCLI) the configurable command line tool for pytorch-lightning. To ease the train of models, we provide a set of predefined configurations for the CIFAR10 dataset. The configurations are located in the `configs` folder. +`torch-uncertainty` leverages [LightningCLI](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.cli.LightningCLI.html#lightning.pytorch.cli.LightningCLI) the configurable command line tool for pytorch-lightning. To ease the train of models, we provide a set of predefined configurations for the CIFAR10 dataset (corresponding to the experiments reported in [Packed-Ensembles for Efficient Uncertainty Estimation](https://arxiv.org/abs/2210.09184)). The configurations are located in the `configs` folder. -**Train** +*Examples:* + +* Training a standard ResNet18 model as in [Packed-Ensembles for Efficient Uncertainty Estimation](https://arxiv.org/abs/2210.09184): + +```bash +python resnet.py fit --config configs/resnet18/standard.yaml +``` + +* Training Packed-Ensembles ResNet50 model as in [Packed-Ensembles for Efficient Uncertainty Estimation](https://arxiv.org/abs/2210.09184): ```bash +python resnet.py fit --config configs/resnet50/packed.yaml ``` -**Evaluate** + +**Note:** In addition we provide a default resnet config file (`configs/resnet.yaml`) to enable the training of any ResNet model. Here a basic example to train a MIMO ResNet101 model with $4$ estimators and $\rho=1.0$: ```bash +python resnet.py fit --config configs/resnet.yaml --model.arch 101 --model.version mimo --model.num_estimators 4 --model.rho 1.0 ``` + + + +## Available configurations: + +### ResNet + +||ResNet18|ResNet34|ResNet50|ResNet101|ResNet152| +|---|---|---|---|---|---| +|Standard|✅|✅|✅|✅|✅| +|Packed-Ensembles|✅|✅|✅|✅|✅| +|BatchEnsemble|✅|✅|✅|✅|✅| +|Masked-Ensembles|✅|✅|✅|✅|✅| +|MIMO|✅|✅|✅|✅|✅| +|MC Dropout|✅|✅|✅|✅|✅| + + +### WideResNet + +||WideResNet28-10| +|---|---| +|Standard|✅| +|Packed-Ensembles|✅| +|BatchEnsemble|✅| +|Masked-Ensembles|✅| +|MIMO|✅| +|MC Dropout|✅| + +### VGG + +||VGG11|VGG13|VGG16|VGG19| +|---|---|---|---|---| +|Standard|✅|✅|✅|✅| +|Packed-Ensembles|✅|✅|✅|✅| +|BatchEnsemble||||| +|Masked-Ensembles||||| +|MIMO||||| +|MC Dropout|✅|✅|✅|✅| diff --git a/experiments/classification/cifar10/resnet.py b/experiments/classification/cifar10/resnet.py index 7fcf7e41..d8cdb179 100644 --- a/experiments/classification/cifar10/resnet.py +++ b/experiments/classification/cifar10/resnet.py @@ -1,24 +1,25 @@ import torch -from lightning.pytorch.cli import LightningArgumentParser, LightningCLI +from lightning.pytorch.cli import LightningArgumentParser from lightning.pytorch.loggers import TensorBoardLogger # noqa: F401 -from torch_uncertainty.baselines import ResNet +from torch_uncertainty.baselines.classification import ResNet from torch_uncertainty.datamodules import CIFAR10DataModule -from torch_uncertainty.lightning_cli import MySaveConfigCallback +from torch_uncertainty.utils import TULightningCLI -class MyLightningCLI(LightningCLI): +class ResNetCLI(TULightningCLI): def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: parser.add_optimizer_args(torch.optim.SGD) parser.add_lr_scheduler_args(torch.optim.lr_scheduler.MultiStepLR) -def cli_main(): - _ = MyLightningCLI( - ResNet, CIFAR10DataModule, save_config_callback=MySaveConfigCallback - ) +def cli_main() -> ResNetCLI: + return ResNetCLI(ResNet, CIFAR10DataModule) if __name__ == "__main__": torch.set_float32_matmul_precision("medium") - cli_main() + cli = cli_main() + + if cli.subcommand == "fit" and cli._get(cli.config, "eval_after_fit"): + cli.trainer.test(datamodule=cli.datamodule, ckpt_path="best") diff --git a/experiments/classification/cifar10/vgg.py b/experiments/classification/cifar10/vgg.py index a6103551..606e3b36 100644 --- a/experiments/classification/cifar10/vgg.py +++ b/experiments/classification/cifar10/vgg.py @@ -1,35 +1,24 @@ -from pathlib import Path +import torch +from lightning.pytorch.cli import LightningArgumentParser +from lightning.pytorch.loggers import TensorBoardLogger # noqa: F401 -from torch import nn - -from torch_uncertainty import cli_main, init_args -from torch_uncertainty.baselines import VGG +from torch_uncertainty.baselines.classification import VGG from torch_uncertainty.datamodules import CIFAR10DataModule -from torch_uncertainty.optimization_procedures import get_procedure +from torch_uncertainty.utils import TULightningCLI -if __name__ == "__main__": - args = init_args(VGG, CIFAR10DataModule) - if args.root == "./data/": - root = Path(__file__).parent.absolute().parents[2] - else: - root = Path(args.root) - net_name = f"{args.version}-vgg{args.arch}-cifar10" +class ResNetCLI(TULightningCLI): + def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + parser.add_optimizer_args(torch.optim.Adam) + parser.add_lr_scheduler_args(torch.optim.lr_scheduler.MultiStepLR) - # datamodule - args.root = str(root / "data") - dm = CIFAR10DataModule(**vars(args)) - # model - model = VGG( - num_classes=dm.num_classes, - in_channels=dm.num_channels, - loss=nn.CrossEntropyLoss, - optimization_procedure=get_procedure( - f"vgg{args.arch}", "cifar10", args.version - ), - style="cifar", - **vars(args), - ) +def cli_main() -> ResNetCLI: + return ResNetCLI(VGG, CIFAR10DataModule) - cli_main(model, dm, root, net_name, args) + +if __name__ == "__main__": + torch.set_float32_matmul_precision("medium") + cli = cli_main() + if cli.subcommand == "fit" and cli._get(cli.config, "eval_after_fit"): + cli.trainer.test(datamodule=cli.datamodule, ckpt_path="best") diff --git a/experiments/classification/cifar10/wideresnet.py b/experiments/classification/cifar10/wideresnet.py index 50efd480..0573ee48 100644 --- a/experiments/classification/cifar10/wideresnet.py +++ b/experiments/classification/cifar10/wideresnet.py @@ -1,33 +1,24 @@ -from pathlib import Path +import torch +from lightning.pytorch.cli import LightningArgumentParser +from lightning.pytorch.loggers import TensorBoardLogger # noqa: F401 -from torch import nn - -from torch_uncertainty import cli_main, init_args -from torch_uncertainty.baselines import WideResNet +from torch_uncertainty.baselines.classification import WideResNet from torch_uncertainty.datamodules import CIFAR10DataModule -from torch_uncertainty.optimization_procedures import get_procedure +from torch_uncertainty.utils import TULightningCLI -if __name__ == "__main__": - root = Path(__file__).parent.absolute().parents[2] - args = init_args(WideResNet, CIFAR10DataModule) +class ResNetCLI(TULightningCLI): + def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + parser.add_optimizer_args(torch.optim.SGD) + parser.add_lr_scheduler_args(torch.optim.lr_scheduler.MultiStepLR) - net_name = f"{args.version}-wideresnet28x10-cifar10" - # datamodule - args.root = str(root / "data") - dm = CIFAR10DataModule(**vars(args)) +def cli_main() -> ResNetCLI: + return ResNetCLI(WideResNet, CIFAR10DataModule) - # model - model = WideResNet( - num_classes=dm.num_classes, - in_channels=dm.num_channels, - loss=nn.CrossEntropyLoss, - optimization_procedure=get_procedure( - "wideresnet28x10", "cifar10", args.version - ), - style="cifar", - **vars(args), - ) - cli_main(model, dm, root, net_name, args) +if __name__ == "__main__": + torch.set_float32_matmul_precision("medium") + cli = cli_main() + if cli.subcommand == "fit" and cli._get(cli.config, "eval_after_fit"): + cli.trainer.test(datamodule=cli.datamodule, ckpt_path="best") diff --git a/experiments/classification/cifar100/readme.md b/experiments/classification/cifar100/readme.md new file mode 100644 index 00000000..5bbe475a --- /dev/null +++ b/experiments/classification/cifar100/readme.md @@ -0,0 +1,3 @@ +# CIFAR100 - Benchmark + +TODO diff --git a/pyproject.toml b/pyproject.toml index 4f2acc47..f9af6d7f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ classifiers = [ dependencies = [ "timm", "lightning[pytorch-extra]", + "torchvision>=0.16", "tensorboard", "einops", "torchinfo", diff --git a/tests/_dummies/baseline.py b/tests/_dummies/baseline.py index dd1edae0..48f3753e 100644 --- a/tests/_dummies/baseline.py +++ b/tests/_dummies/baseline.py @@ -1,12 +1,10 @@ -from argparse import ArgumentParser from typing import Any from pytorch_lightning import LightningModule from torch import nn from torch_uncertainty.routines.classification import ( - ClassificationEnsemble, - ClassificationSingle, + ClassificationRoutine, ) from torch_uncertainty.routines.regression import ( RegressionEnsemble, @@ -34,7 +32,7 @@ def __new__( ) if baseline_type == "single": - return ClassificationSingle( + return ClassificationRoutine( num_classes=num_classes, model=model, loss=loss, @@ -45,7 +43,7 @@ def __new__( ) # baseline_type == "ensemble": kwargs["num_estimators"] = 2 - return ClassificationEnsemble( + return ClassificationRoutine( num_classes=num_classes, model=model, loss=loss, @@ -55,12 +53,12 @@ def __new__( **kwargs, ) - @classmethod - def add_model_specific_args( - cls, - parser: ArgumentParser, - ) -> ArgumentParser: - return ClassificationEnsemble.add_model_specific_args(parser) + # @classmethod + # def add_model_specific_args( + # cls, + # parser: ArgumentParser, + # ) -> ArgumentParser: + # return ClassificationEnsemble.add_model_specific_args(parser) class DummyRegressionBaseline: @@ -101,9 +99,9 @@ def __new__( **kwargs, ) - @classmethod - def add_model_specific_args( - cls, - parser: ArgumentParser, - ) -> ArgumentParser: - return ClassificationEnsemble.add_model_specific_args(parser) + # @classmethod + # def add_model_specific_args( + # cls, + # parser: ArgumentParser, + # ) -> ArgumentParser: + # return ClassificationEnsemble.add_model_specific_args(parser) diff --git a/tests/baselines/test_batched.py b/tests/baselines/test_batched.py index b6cd1fae..12980409 100644 --- a/tests/baselines/test_batched.py +++ b/tests/baselines/test_batched.py @@ -2,12 +2,13 @@ from torch import nn from torchinfo import summary -from torch_uncertainty.baselines import ResNet, WideResNet -from torch_uncertainty.optimization_procedures import ( - optim_cifar10_wideresnet, - optim_cifar100_resnet18, - optim_cifar100_resnet50, -) +from torch_uncertainty.baselines.classification import ResNet, WideResNet + +# from torch_uncertainty.optimization_procedures import ( +# optim_cifar10_wideresnet, +# optim_cifar100_resnet18, +# optim_cifar100_resnet50, +# ) class TestBatchedBaseline: @@ -18,7 +19,6 @@ def test_batched_18(self): num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar100_resnet18, version="batched", arch=18, style="cifar", @@ -29,7 +29,6 @@ def test_batched_18(self): summary(net) _ = net.criterion - _ = net.configure_optimizers() _ = net(torch.rand(1, 3, 32, 32)) def test_batched_50(self): @@ -37,7 +36,6 @@ def test_batched_50(self): num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar100_resnet50, version="batched", arch=50, style="imagenet", @@ -48,7 +46,6 @@ def test_batched_50(self): summary(net) _ = net.criterion - _ = net.configure_optimizers() _ = net(torch.rand(1, 3, 40, 40)) @@ -60,7 +57,6 @@ def test_batched(self): num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_wideresnet, version="batched", style="cifar", num_estimators=4, @@ -70,5 +66,4 @@ def test_batched(self): summary(net) _ = net.criterion - _ = net.configure_optimizers() _ = net(torch.rand(1, 3, 32, 32)) diff --git a/tests/baselines/test_deep_ensembles.py b/tests/baselines/test_deep_ensembles.py index d5521847..8ecbdd81 100644 --- a/tests/baselines/test_deep_ensembles.py +++ b/tests/baselines/test_deep_ensembles.py @@ -1,6 +1,4 @@ -from argparse import ArgumentParser - -from torch_uncertainty.baselines import DeepEnsembles +from torch_uncertainty.baselines.classification import DeepEnsembles class TestDeepEnsembles: @@ -8,16 +6,10 @@ class TestDeepEnsembles: def test_standard(self): DeepEnsembles( - task="classification", log_path=".", checkpoint_ids=[], backbone="resnet", - in_channels=3, num_classes=10, - version="vanilla", - arch=18, - style="cifar", - groups=1, ) - parser = ArgumentParser() - DeepEnsembles.add_model_specific_args(parser) + # parser = ArgumentParser() + # DeepEnsembles.add_model_specific_args(parser) diff --git a/tests/baselines/test_masked.py b/tests/baselines/test_masked.py index e992bdea..4fc82060 100644 --- a/tests/baselines/test_masked.py +++ b/tests/baselines/test_masked.py @@ -3,12 +3,13 @@ from torch import nn from torchinfo import summary -from torch_uncertainty.baselines import ResNet, WideResNet -from torch_uncertainty.optimization_procedures import ( - optim_cifar10_wideresnet, - optim_cifar100_resnet18, - optim_cifar100_resnet50, -) +from torch_uncertainty.baselines.classification import ResNet, WideResNet + +# from torch_uncertainty.optimization_procedures import ( +# optim_cifar10_wideresnet, +# optim_cifar100_resnet18, +# optim_cifar100_resnet50, +# ) class TestMaskedBaseline: @@ -19,7 +20,6 @@ def test_masked_18(self): num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar100_resnet18, version="masked", arch=18, style="cifar", @@ -31,7 +31,6 @@ def test_masked_18(self): summary(net) _ = net.criterion - _ = net.configure_optimizers() _ = net(torch.rand(1, 3, 32, 32)) def test_masked_50(self): @@ -39,7 +38,6 @@ def test_masked_50(self): num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar100_resnet50, version="masked", arch=50, style="imagenet", @@ -51,7 +49,6 @@ def test_masked_50(self): summary(net) _ = net.criterion - _ = net.configure_optimizers() _ = net(torch.rand(1, 3, 40, 40)) def test_masked_scale_lt_1(self): @@ -60,7 +57,6 @@ def test_masked_scale_lt_1(self): num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar100_resnet18, version="masked", arch=18, style="cifar", @@ -75,7 +71,6 @@ def test_masked_groups_lt_1(self): num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar100_resnet18, version="masked", arch=18, style="cifar", @@ -93,7 +88,6 @@ def test_masked(self): num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_wideresnet, version="masked", style="cifar", num_estimators=4, @@ -104,5 +98,4 @@ def test_masked(self): summary(net) _ = net.criterion - _ = net.configure_optimizers() _ = net(torch.rand(1, 3, 32, 32)) diff --git a/tests/baselines/test_mc_dropout.py b/tests/baselines/test_mc_dropout.py index 80056857..e18f708b 100644 --- a/tests/baselines/test_mc_dropout.py +++ b/tests/baselines/test_mc_dropout.py @@ -2,11 +2,12 @@ from torch import nn from torchinfo import summary -from torch_uncertainty.baselines import VGG, ResNet, WideResNet -from torch_uncertainty.optimization_procedures import ( - optim_cifar10_resnet18, - optim_cifar10_wideresnet, -) +from torch_uncertainty.baselines.classification import VGG, ResNet, WideResNet + +# from torch_uncertainty.optimization_procedures import ( +# optim_cifar10_resnet18, +# optim_cifar10_wideresnet, +# ) class TestStandardBaseline: @@ -17,7 +18,6 @@ def test_standard(self): num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet18, version="mc-dropout", num_estimators=4, arch=18, @@ -27,7 +27,6 @@ def test_standard(self): summary(net) _ = net.criterion - net.configure_optimizers() net(torch.rand(1, 3, 32, 32)) @@ -39,7 +38,6 @@ def test_standard(self): num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_wideresnet, version="mc-dropout", num_estimators=4, style="cifar", @@ -48,7 +46,6 @@ def test_standard(self): summary(net) _ = net.criterion - net.configure_optimizers() net(torch.rand(1, 3, 32, 32)) @@ -60,7 +57,6 @@ def test_standard(self): num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet18, version="mc-dropout", num_estimators=4, arch=11, @@ -70,14 +66,12 @@ def test_standard(self): summary(net) _ = net.criterion - net.configure_optimizers() net(torch.rand(1, 3, 32, 32)) net = VGG( num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet18, version="mc-dropout", num_estimators=4, arch=11, diff --git a/tests/baselines/test_mimo.py b/tests/baselines/test_mimo.py index cf4a29cc..3adf8bd6 100644 --- a/tests/baselines/test_mimo.py +++ b/tests/baselines/test_mimo.py @@ -2,12 +2,13 @@ from torch import nn from torchinfo import summary -from torch_uncertainty.baselines import ResNet, WideResNet -from torch_uncertainty.optimization_procedures import ( - optim_cifar10_resnet18, - optim_cifar10_resnet50, - optim_cifar10_wideresnet, -) +from torch_uncertainty.baselines.classification import ResNet, WideResNet + +# from torch_uncertainty.optimization_procedures import ( +# optim_cifar10_resnet18, +# optim_cifar10_resnet50, +# optim_cifar10_wideresnet, +# ) class TestMIMOBaseline: @@ -18,7 +19,6 @@ def test_mimo_50(self): num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet50, version="mimo", arch=50, style="cifar", @@ -31,7 +31,6 @@ def test_mimo_50(self): summary(net) _ = net.criterion - _ = net.configure_optimizers() _ = net(torch.rand(1, 3, 32, 32)) def test_mimo_18(self): @@ -39,7 +38,6 @@ def test_mimo_18(self): num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet18, version="mimo", arch=18, style="imagenet", @@ -52,7 +50,6 @@ def test_mimo_18(self): summary(net) _ = net.criterion - _ = net.configure_optimizers() _ = net(torch.rand(1, 3, 40, 40)) @@ -64,7 +61,6 @@ def test_mimo(self): num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_wideresnet, version="mimo", style="cifar", num_estimators=4, @@ -76,5 +72,4 @@ def test_mimo(self): summary(net) _ = net.criterion - _ = net.configure_optimizers() _ = net(torch.rand(1, 3, 32, 32)) diff --git a/tests/baselines/test_others.py b/tests/baselines/test_others.py index 06f497cf..42544023 100644 --- a/tests/baselines/test_others.py +++ b/tests/baselines/test_others.py @@ -1,11 +1,10 @@ import pytest from torch import nn -from torch_uncertainty.baselines import VGG, ResNet, WideResNet +from torch_uncertainty.baselines.classification import VGG, ResNet, WideResNet from torch_uncertainty.baselines.regression import MLP from torch_uncertainty.optimization_procedures import ( optim_cifar10_resnet18, - optim_cifar10_wideresnet, ) @@ -18,7 +17,6 @@ def test_standard(self): num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet18, version="prior", arch=18, style="cifar", @@ -35,7 +33,6 @@ def test_standard(self): num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_wideresnet, version="prior", style="cifar", groups=1, @@ -51,7 +48,6 @@ def test_standard(self): num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet18, version="prior", arch=11, groups=1, diff --git a/tests/baselines/test_packed.py b/tests/baselines/test_packed.py index d7ec73bc..b32f31bc 100644 --- a/tests/baselines/test_packed.py +++ b/tests/baselines/test_packed.py @@ -3,12 +3,10 @@ from torch import nn from torchinfo import summary -from torch_uncertainty.baselines import VGG, ResNet, WideResNet +from torch_uncertainty.baselines.classification import VGG, ResNet, WideResNet from torch_uncertainty.baselines.regression import MLP from torch_uncertainty.optimization_procedures import ( optim_cifar10_resnet18, - optim_cifar10_resnet50, - optim_cifar10_wideresnet, ) @@ -20,7 +18,6 @@ def test_packed_50(self): num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet50, version="packed", arch=50, style="cifar", @@ -33,7 +30,6 @@ def test_packed_50(self): summary(net) _ = net.criterion - _ = net.configure_optimizers() _ = net(torch.rand(1, 3, 32, 32)) def test_packed_18(self): @@ -41,7 +37,6 @@ def test_packed_18(self): num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet18, version="packed", arch=18, style="imagenet", @@ -54,7 +49,6 @@ def test_packed_18(self): summary(net) _ = net.criterion - _ = net.configure_optimizers() _ = net(torch.rand(1, 3, 40, 40)) def test_packed_alpha_lt_0(self): @@ -63,7 +57,6 @@ def test_packed_alpha_lt_0(self): num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet50, version="packed", arch=50, style="cifar", @@ -79,7 +72,6 @@ def test_packed_gamma_lt_1(self): num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet50, version="packed", arch=50, style="cifar", @@ -98,7 +90,6 @@ def test_packed(self): num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_wideresnet, version="packed", style="cifar", num_estimators=4, @@ -110,7 +101,6 @@ def test_packed(self): summary(net) _ = net.criterion - _ = net.configure_optimizers() _ = net(torch.rand(1, 3, 32, 32)) @@ -123,7 +113,6 @@ def test_packed(self): in_channels=3, arch=13, loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet50, version="packed", num_estimators=4, alpha=2, @@ -134,7 +123,6 @@ def test_packed(self): summary(net) _ = net.criterion - _ = net.configure_optimizers() _ = net(torch.rand(2, 3, 32, 32)) diff --git a/tests/baselines/test_standard.py b/tests/baselines/test_standard.py index 2da3ef14..ed1f9382 100644 --- a/tests/baselines/test_standard.py +++ b/tests/baselines/test_standard.py @@ -2,11 +2,10 @@ from torch import nn from torchinfo import summary -from torch_uncertainty.baselines import VGG, ResNet, WideResNet +from torch_uncertainty.baselines.classification import VGG, ResNet, WideResNet from torch_uncertainty.baselines.regression import MLP from torch_uncertainty.optimization_procedures import ( optim_cifar10_resnet18, - optim_cifar10_wideresnet, ) @@ -18,7 +17,6 @@ def test_standard(self): num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet18, version="vanilla", arch=18, style="cifar", @@ -27,7 +25,6 @@ def test_standard(self): summary(net) _ = net.criterion - _ = net.configure_optimizers() _ = net(torch.rand(1, 3, 32, 32)) @@ -39,7 +36,6 @@ def test_standard(self): num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_wideresnet, version="vanilla", style="cifar", groups=1, @@ -47,7 +43,6 @@ def test_standard(self): summary(net) _ = net.criterion - _ = net.configure_optimizers() _ = net(torch.rand(1, 3, 32, 32)) @@ -59,7 +54,6 @@ def test_standard(self): num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet18, version="vanilla", arch=11, groups=1, @@ -67,7 +61,6 @@ def test_standard(self): summary(net) _ = net.criterion - _ = net.configure_optimizers() _ = net(torch.rand(1, 3, 32, 32)) diff --git a/tests/datamodules/test_cifar10_datamodule.py b/tests/datamodules/test_cifar10_datamodule.py index 439a44ac..09f6ff0f 100644 --- a/tests/datamodules/test_cifar10_datamodule.py +++ b/tests/datamodules/test_cifar10_datamodule.py @@ -1,5 +1,3 @@ -from argparse import ArgumentParser - import pytest from torchvision.datasets import CIFAR10 @@ -12,14 +10,11 @@ class TestCIFAR10DataModule: """Testing the CIFAR10DataModule datamodule class.""" def test_cifar10_main(self): - parser = ArgumentParser() - parser = CIFAR10DataModule.add_argparse_args(parser) - - # Simulate that cutout is set to 8 - args = parser.parse_args("") - args.cutout = 16 + # parser = ArgumentParser() + # parser = CIFAR10DataModule.add_argparse_args(parser) - dm = CIFAR10DataModule(**vars(args)) + # Simulate that cutout is set to 16 + dm = CIFAR10DataModule(root="./data/", batch_size=128, cutout=16) assert dm.dataset == CIFAR10 assert isinstance(dm.transform_train.transforms[2], Cutout) @@ -48,64 +43,81 @@ def test_cifar10_main(self): dm.setup("test") dm.test_dataloader() - args.test_alt = "c" - dm = CIFAR10DataModule(**vars(args)) + dm = CIFAR10DataModule( + root="./data/", batch_size=128, cutout=16, test_alt="c" + ) dm.dataset = DummyClassificationDataset with pytest.raises(ValueError): dm.setup() - args.test_alt = "h" - dm = CIFAR10DataModule(**vars(args)) + dm = CIFAR10DataModule( + root="./data/", batch_size=128, cutout=16, test_alt="h" + ) dm.dataset = DummyClassificationDataset dm.setup("test") - args.test_alt = None - args.num_dataloaders = 2 - args.val_split = 0.1 - dm = CIFAR10DataModule(**vars(args)) + dm = CIFAR10DataModule( + root="./data/", + batch_size=128, + cutout=16, + num_dataloaders=2, + val_split=0.1, + ) dm.dataset = DummyClassificationDataset dm.ood_dataset = DummyClassificationDataset dm.setup() dm.setup("test") dm.train_dataloader() - args.cutout = 8 - args.auto_augment = "rand-m9-n2-mstd0.5" + # args.cutout = 8 + # args.auto_augment = "rand-m9-n2-mstd0.5" with pytest.raises(ValueError): - dm = CIFAR10DataModule(**vars(args)) - - args.cutout = None - args.auto_augment = "rand-m9-n2-mstd0.5" - dm = CIFAR10DataModule(**vars(args)) - - def test_cifar10_cv(self): - parser = ArgumentParser() - parser = CIFAR10DataModule.add_argparse_args(parser) - - # Simulate that cutout is set to 8 - args = parser.parse_args("") - - dm = CIFAR10DataModule(**vars(args)) - dm.dataset = ( - lambda root, train, download, transform: DummyClassificationDataset( - root, - train=train, - download=download, - transform=transform, - num_images=20, - ) - ) - dm.make_cross_val_splits(2, 1) - - args.val_split = 0.1 - dm = CIFAR10DataModule(**vars(args)) - dm.dataset = ( - lambda root, train, download, transform: DummyClassificationDataset( - root, - train=train, - download=download, - transform=transform, - num_images=20, + dm = CIFAR10DataModule( + root="./data/", + batch_size=128, + cutout=8, + num_dataloaders=2, + val_split=0.1, + auto_augment="rand-m9-n2-mstd0.5", ) + + dm = CIFAR10DataModule( + root="./data/", + batch_size=128, + cutout=None, + num_dataloaders=2, + val_split=0.1, + auto_augment="rand-m9-n2-mstd0.5", ) - dm.make_cross_val_splits(2, 1) + + # def test_cifar10_cv(self): + # parser = ArgumentParser() + # parser = CIFAR10DataModule.add_argparse_args(parser) + + # # Simulate that cutout is set to 8 + # args = parser.parse_args("") + + # dm = CIFAR10DataModule(**vars(args)) + # dm.dataset = ( + # lambda root, train, download, transform: DummyClassificationDataset( + # root, + # train=train, + # download=download, + # transform=transform, + # num_images=20, + # ) + # ) + # dm.make_cross_val_splits(2, 1) + + # args.val_split = 0.1 + # dm = CIFAR10DataModule(**vars(args)) + # dm.dataset = ( + # lambda root, train, download, transform: DummyClassificationDataset( + # root, + # train=train, + # download=download, + # transform=transform, + # num_images=20, + # ) + # ) + # dm.make_cross_val_splits(2, 1) diff --git a/tests/routines/test_classification.py b/tests/routines/test_classification.py index 2c793fbf..db5c58ff 100644 --- a/tests/routines/test_classification.py +++ b/tests/routines/test_classification.py @@ -1,368 +1,331 @@ -from functools import partial -from pathlib import Path - -import pytest -from cli_test_helpers import ArgvContext -from torch import nn - -from tests._dummies import ( - DummyClassificationBaseline, - DummyClassificationDataModule, - DummyClassificationDataset, -) -from torch_uncertainty import cli_main, init_args -from torch_uncertainty.losses import DECLoss, ELBOLoss -from torch_uncertainty.optimization_procedures import optim_cifar10_resnet18 -from torch_uncertainty.routines.classification import ( - ClassificationEnsemble, - ClassificationSingle, -) - - -class TestClassificationSingle: - """Testing the classification routine with a single model.""" - - def test_cli_main_dummy_binary(self): - root = Path(__file__).parent.absolute().parents[0] - with ArgvContext("file.py"): - args = init_args( - DummyClassificationBaseline, DummyClassificationDataModule - ) - - args.root = str(root / "data") - dm = DummyClassificationDataModule(num_classes=1, **vars(args)) - - model = DummyClassificationBaseline( - num_classes=dm.num_classes, - in_channels=dm.num_channels, - loss=nn.BCEWithLogitsLoss, - optimization_procedure=optim_cifar10_resnet18, - baseline_type="single", - **vars(args), - ) - cli_main(model, dm, root, "logs/dummy", args) - - with ArgvContext("file.py", "--logits"): - args = init_args( - DummyClassificationBaseline, DummyClassificationDataModule - ) - - args.root = str(root / "data") - dm = DummyClassificationDataModule(num_classes=1, **vars(args)) - - model = DummyClassificationBaseline( - num_classes=dm.num_classes, - in_channels=dm.num_channels, - loss=nn.BCEWithLogitsLoss, - optimization_procedure=optim_cifar10_resnet18, - baseline_type="single", - **vars(args), - ) - cli_main(model, dm, root, "logs/dummy", args) - - def test_cli_main_dummy_ood(self): - root = Path(__file__).parent.absolute().parents[0] - with ArgvContext("file.py", "--fast_dev_run"): - args = init_args( - DummyClassificationBaseline, DummyClassificationDataModule - ) - - args.root = str(root / "data") - dm = DummyClassificationDataModule(**vars(args)) - loss = partial( - ELBOLoss, - criterion=nn.CrossEntropyLoss(), - kl_weight=1e-5, - num_samples=2, - ) - model = DummyClassificationBaseline( - num_classes=dm.num_classes, - in_channels=dm.num_channels, - loss=loss, - optimization_procedure=optim_cifar10_resnet18, - baseline_type="single", - **vars(args), - ) - cli_main(model, dm, root, "logs/dummy", args) - - with ArgvContext( - "file.py", - "--evaluate_ood", - "--entropy", - ): - args = init_args( - DummyClassificationBaseline, DummyClassificationDataModule - ) - - args.root = str(root / "data") - dm = DummyClassificationDataModule(**vars(args)) - - model = DummyClassificationBaseline( - num_classes=dm.num_classes, - in_channels=dm.num_channels, - loss=DECLoss, - optimization_procedure=optim_cifar10_resnet18, - baseline_type="single", - **vars(args), - ) - cli_main(model, dm, root, "logs/dummy", args) - - with ArgvContext( - "file.py", - "--evaluate_ood", - "--entropy", - "--cutmix_alpha", - "0.5", - "--mixtype", - "timm", - ): - args = init_args( - DummyClassificationBaseline, DummyClassificationDataModule - ) - - args.root = str(root / "data") - dm = DummyClassificationDataModule(**vars(args)) - - model = DummyClassificationBaseline( - num_classes=dm.num_classes, - in_channels=dm.num_channels, - loss=DECLoss, - optimization_procedure=optim_cifar10_resnet18, - baseline_type="single", - **vars(args), - ) - with pytest.raises(NotImplementedError): - cli_main(model, dm, root, "logs/dummy", args) - - def test_cli_main_dummy_mixup_ts_cv(self): - root = Path(__file__).parent.absolute().parents[0] - with ArgvContext( - "file.py", - "--mixtype", - "kernel_warping", - "--mixup_alpha", - "1.", - "--dist_sim", - "inp", - "--val_temp_scaling", - "--use_cv", - ): - args = init_args( - DummyClassificationBaseline, DummyClassificationDataModule - ) - - args.root = str(root / "data") - dm = DummyClassificationDataModule(num_classes=10, **vars(args)) - dm.dataset = ( - lambda root, - num_channels, - num_classes, - image_size, - transform: DummyClassificationDataset( - root, - num_channels=num_channels, - num_classes=num_classes, - image_size=image_size, - transform=transform, - num_images=20, - ) - ) - - list_dm = dm.make_cross_val_splits(2, 1) - list_model = [ - DummyClassificationBaseline( - num_classes=list_dm[i].dm.num_classes, - in_channels=list_dm[i].dm.num_channels, - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet18, - baseline_type="single", - calibration_set=dm.get_val_set, - **vars(args), - ) - for i in range(len(list_dm)) - ] - - cli_main(list_model, list_dm, root, "logs/dummy", args) - - with ArgvContext( - "file.py", - "--mixtype", - "kernel_warping", - "--mixup_alpha", - "1.", - "--dist_sim", - "emb", - "--val_temp_scaling", - "--use_cv", - ): - args = init_args( - DummyClassificationBaseline, DummyClassificationDataModule - ) - - args.root = str(root / "data") - dm = DummyClassificationDataModule(num_classes=10, **vars(args)) - dm.dataset = ( - lambda root, - num_channels, - num_classes, - image_size, - transform: DummyClassificationDataset( - root, - num_channels=num_channels, - num_classes=num_classes, - image_size=image_size, - transform=transform, - num_images=20, - ) - ) - - list_dm = dm.make_cross_val_splits(2, 1) - list_model = [] - for i in range(len(list_dm)): - list_model.append( - DummyClassificationBaseline( - num_classes=list_dm[i].dm.num_classes, - in_channels=list_dm[i].dm.num_channels, - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet18, - baseline_type="single", - calibration_set=dm.get_val_set, - **vars(args), - ) - ) - - cli_main(list_model, list_dm, root, "logs/dummy", args) - - def test_classification_failures(self): - with pytest.raises(ValueError): - ClassificationSingle( - 10, nn.Module(), None, None, use_entropy=True, use_logits=True - ) - - with pytest.raises(ValueError): - ClassificationSingle(10, nn.Module(), None, None, cutmix_alpha=-1) - - -class TestClassificationEnsemble: - """Testing the classification routine with an ensemble model.""" - - def test_cli_main_dummy_binary(self): - root = Path(__file__).parent.absolute().parents[0] - with ArgvContext("file.py"): - args = init_args( - DummyClassificationBaseline, DummyClassificationDataModule - ) - - # datamodule - args.root = str(root / "data") - dm = DummyClassificationDataModule(num_classes=1, **vars(args)) - loss = partial( - ELBOLoss, - criterion=nn.CrossEntropyLoss(), - kl_weight=1e-5, - num_samples=1, - ) - model = DummyClassificationBaseline( - num_classes=dm.num_classes, - in_channels=dm.num_channels, - loss=loss, - optimization_procedure=optim_cifar10_resnet18, - baseline_type="ensemble", - **vars(args), - ) - - cli_main(model, dm, root, "logs/dummy", args) - - with ArgvContext("file.py", "--mutual_information"): - args = init_args( - DummyClassificationBaseline, DummyClassificationDataModule - ) - - # datamodule - args.root = str(root / "data") - dm = DummyClassificationDataModule(num_classes=1, **vars(args)) - - model = DummyClassificationBaseline( - num_classes=dm.num_classes, - in_channels=dm.num_channels, - loss=nn.BCEWithLogitsLoss, - optimization_procedure=optim_cifar10_resnet18, - baseline_type="ensemble", - **vars(args), - ) - - cli_main(model, dm, root, "logs/dummy", args) - - def test_cli_main_dummy_ood(self): - root = Path(__file__).parent.absolute().parents[0] - with ArgvContext("file.py", "--logits"): - args = init_args( - DummyClassificationBaseline, DummyClassificationDataModule - ) - - # datamodule - args.root = str(root / "data") - dm = DummyClassificationDataModule(**vars(args)) - - model = DummyClassificationBaseline( - num_classes=dm.num_classes, - in_channels=dm.num_channels, - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet18, - baseline_type="ensemble", - **vars(args), - ) - - cli_main(model, dm, root, "logs/dummy", args) - - with ArgvContext("file.py", "--evaluate_ood", "--entropy"): - args = init_args( - DummyClassificationBaseline, DummyClassificationDataModule - ) - - # datamodule - args.root = str(root / "data") - dm = DummyClassificationDataModule(**vars(args)) - - model = DummyClassificationBaseline( - num_classes=dm.num_classes, - in_channels=dm.num_channels, - loss=DECLoss, - optimization_procedure=optim_cifar10_resnet18, - baseline_type="ensemble", - **vars(args), - ) - - cli_main(model, dm, root, "logs/dummy", args) - - with ArgvContext("file.py", "--evaluate_ood", "--variation_ratio"): - args = init_args( - DummyClassificationBaseline, DummyClassificationDataModule - ) - - # datamodule - args.root = str(root / "data") - dm = DummyClassificationDataModule(**vars(args)) - - model = DummyClassificationBaseline( - num_classes=dm.num_classes, - in_channels=dm.num_channels, - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet18, - baseline_type="ensemble", - **vars(args), - ) - - cli_main(model, dm, root, "logs/dummy", args) - - def test_classification_failures(self): - with pytest.raises(ValueError): - ClassificationEnsemble( - 10, - nn.Module(), - None, - None, - 2, - use_entropy=True, - use_variation_ratio=True, - ) +# from functools import partial +# from pathlib import Path + +# import pytest +# from cli_test_helpers import ArgvContext +# from torch import nn + +# from tests._dummies import ( +# DummyClassificationBaseline, +# DummyClassificationDataModule, +# DummyClassificationDataset, +# ) +# from torch_uncertainty import cli_main, init_args +# from torch_uncertainty.losses import DECLoss, ELBOLoss +# from torch_uncertainty.optimization_procedures import optim_cifar10_resnet18 +# from torch_uncertainty.routines.classification import ( +# ClassificationRoutine, +# ) + + +# class TestClassificationSingle: +# """Testing the classification routine with a single model.""" + +# def test_cli_main_dummy_binary(self): +# root = Path(__file__).parent.absolute().parents[0] +# with ArgvContext("file.py"): +# args = init_args(DummyClassificationBaseline, DummyClassificationDataModule) + +# args.root = str(root / "data") +# dm = DummyClassificationDataModule(num_classes=1, **vars(args)) + +# model = DummyClassificationBaseline( +# num_classes=dm.num_classes, +# in_channels=dm.num_channels, +# loss=nn.BCEWithLogitsLoss, +# optimization_procedure=optim_cifar10_resnet18, +# baseline_type="single", +# **vars(args), +# ) +# cli_main(model, dm, root, "logs/dummy", args) + +# with ArgvContext("file.py", "--logits"): +# args = init_args(DummyClassificationBaseline, DummyClassificationDataModule) + +# args.root = str(root / "data") +# dm = DummyClassificationDataModule(num_classes=1, **vars(args)) + +# model = DummyClassificationBaseline( +# num_classes=dm.num_classes, +# in_channels=dm.num_channels, +# loss=nn.BCEWithLogitsLoss, +# optimization_procedure=optim_cifar10_resnet18, +# baseline_type="single", +# **vars(args), +# ) +# cli_main(model, dm, root, "logs/dummy", args) + +# def test_cli_main_dummy_ood(self): +# root = Path(__file__).parent.absolute().parents[0] +# with ArgvContext("file.py", "--fast_dev_run"): +# args = init_args(DummyClassificationBaseline, DummyClassificationDataModule) + +# args.root = str(root / "data") +# dm = DummyClassificationDataModule(**vars(args)) +# loss = partial( +# ELBOLoss, +# criterion=nn.CrossEntropyLoss(), +# kl_weight=1e-5, +# num_samples=2, +# ) +# model = DummyClassificationBaseline( +# num_classes=dm.num_classes, +# in_channels=dm.num_channels, +# loss=loss, +# optimization_procedure=optim_cifar10_resnet18, +# baseline_type="single", +# **vars(args), +# ) +# cli_main(model, dm, root, "logs/dummy", args) + +# with ArgvContext( +# "file.py", +# "--evaluate_ood", +# "--entropy", +# ): +# args = init_args(DummyClassificationBaseline, DummyClassificationDataModule) + +# args.root = str(root / "data") +# dm = DummyClassificationDataModule(**vars(args)) + +# model = DummyClassificationBaseline( +# num_classes=dm.num_classes, +# in_channels=dm.num_channels, +# loss=DECLoss, +# optimization_procedure=optim_cifar10_resnet18, +# baseline_type="single", +# **vars(args), +# ) +# cli_main(model, dm, root, "logs/dummy", args) + +# with ArgvContext( +# "file.py", +# "--evaluate_ood", +# "--entropy", +# "--cutmix_alpha", +# "0.5", +# "--mixtype", +# "timm", +# ): +# args = init_args(DummyClassificationBaseline, DummyClassificationDataModule) + +# args.root = str(root / "data") +# dm = DummyClassificationDataModule(**vars(args)) + +# model = DummyClassificationBaseline( +# num_classes=dm.num_classes, +# in_channels=dm.num_channels, +# loss=DECLoss, +# optimization_procedure=optim_cifar10_resnet18, +# baseline_type="single", +# **vars(args), +# ) +# with pytest.raises(NotImplementedError): +# cli_main(model, dm, root, "logs/dummy", args) + +# def test_cli_main_dummy_mixup_ts_cv(self): +# root = Path(__file__).parent.absolute().parents[0] +# with ArgvContext( +# "file.py", +# "--mixtype", +# "kernel_warping", +# "--mixup_alpha", +# "1.", +# "--dist_sim", +# "inp", +# "--val_temp_scaling", +# "--use_cv", +# ): +# args = init_args(DummyClassificationBaseline, DummyClassificationDataModule) + +# args.root = str(root / "data") +# dm = DummyClassificationDataModule(num_classes=10, **vars(args)) +# dm.dataset = lambda root, num_channels, num_classes, image_size, transform: DummyClassificationDataset( +# root, +# num_channels=num_channels, +# num_classes=num_classes, +# image_size=image_size, +# transform=transform, +# num_images=20, +# ) + +# list_dm = dm.make_cross_val_splits(2, 1) +# list_model = [ +# DummyClassificationBaseline( +# num_classes=list_dm[i].dm.num_classes, +# in_channels=list_dm[i].dm.num_channels, +# loss=nn.CrossEntropyLoss, +# optimization_procedure=optim_cifar10_resnet18, +# baseline_type="single", +# calibration_set=dm.get_val_set, +# **vars(args), +# ) +# for i in range(len(list_dm)) +# ] + +# cli_main(list_model, list_dm, root, "logs/dummy", args) + +# with ArgvContext( +# "file.py", +# "--mixtype", +# "kernel_warping", +# "--mixup_alpha", +# "1.", +# "--dist_sim", +# "emb", +# "--val_temp_scaling", +# "--use_cv", +# ): +# args = init_args(DummyClassificationBaseline, DummyClassificationDataModule) + +# args.root = str(root / "data") +# dm = DummyClassificationDataModule(num_classes=10, **vars(args)) +# dm.dataset = lambda root, num_channels, num_classes, image_size, transform: DummyClassificationDataset( +# root, +# num_channels=num_channels, +# num_classes=num_classes, +# image_size=image_size, +# transform=transform, +# num_images=20, +# ) + +# list_dm = dm.make_cross_val_splits(2, 1) +# list_model = [] +# for i in range(len(list_dm)): +# list_model.append( +# DummyClassificationBaseline( +# num_classes=list_dm[i].dm.num_classes, +# in_channels=list_dm[i].dm.num_channels, +# loss=nn.CrossEntropyLoss, +# optimization_procedure=optim_cifar10_resnet18, +# baseline_type="single", +# calibration_set=dm.get_val_set, +# **vars(args), +# ) +# ) + +# cli_main(list_model, list_dm, root, "logs/dummy", args) + +# def test_classification_failures(self): +# with pytest.raises(ValueError): +# ClassificationRoutine( +# 10, nn.Module(), None, None, use_entropy=True, use_logits=True +# ) + +# with pytest.raises(ValueError): +# ClassificationRoutine(10, nn.Module(), None, None, cutmix_alpha=-1) + + +# class TestClassificationEnsemble: +# """Testing the classification routine with an ensemble model.""" + +# def test_cli_main_dummy_binary(self): +# root = Path(__file__).parent.absolute().parents[0] +# with ArgvContext("file.py"): +# args = init_args(DummyClassificationBaseline, DummyClassificationDataModule) + +# # datamodule +# args.root = str(root / "data") +# dm = DummyClassificationDataModule(num_classes=1, **vars(args)) +# loss = partial( +# ELBOLoss, +# criterion=nn.CrossEntropyLoss(), +# kl_weight=1e-5, +# num_samples=1, +# ) +# model = DummyClassificationBaseline( +# num_classes=dm.num_classes, +# in_channels=dm.num_channels, +# loss=loss, +# optimization_procedure=optim_cifar10_resnet18, +# baseline_type="ensemble", +# **vars(args), +# ) + +# cli_main(model, dm, root, "logs/dummy", args) + +# with ArgvContext("file.py", "--mutual_information"): +# args = init_args(DummyClassificationBaseline, DummyClassificationDataModule) + +# # datamodule +# args.root = str(root / "data") +# dm = DummyClassificationDataModule(num_classes=1, **vars(args)) + +# model = DummyClassificationBaseline( +# num_classes=dm.num_classes, +# in_channels=dm.num_channels, +# loss=nn.BCEWithLogitsLoss, +# optimization_procedure=optim_cifar10_resnet18, +# baseline_type="ensemble", +# **vars(args), +# ) + +# cli_main(model, dm, root, "logs/dummy", args) + +# def test_cli_main_dummy_ood(self): +# root = Path(__file__).parent.absolute().parents[0] +# with ArgvContext("file.py", "--logits"): +# args = init_args(DummyClassificationBaseline, DummyClassificationDataModule) + +# # datamodule +# args.root = str(root / "data") +# dm = DummyClassificationDataModule(**vars(args)) + +# model = DummyClassificationBaseline( +# num_classes=dm.num_classes, +# in_channels=dm.num_channels, +# loss=nn.CrossEntropyLoss, +# optimization_procedure=optim_cifar10_resnet18, +# baseline_type="ensemble", +# **vars(args), +# ) + +# cli_main(model, dm, root, "logs/dummy", args) + +# with ArgvContext("file.py", "--evaluate_ood", "--entropy"): +# args = init_args(DummyClassificationBaseline, DummyClassificationDataModule) + +# # datamodule +# args.root = str(root / "data") +# dm = DummyClassificationDataModule(**vars(args)) + +# model = DummyClassificationBaseline( +# num_classes=dm.num_classes, +# in_channels=dm.num_channels, +# loss=DECLoss, +# optimization_procedure=optim_cifar10_resnet18, +# baseline_type="ensemble", +# **vars(args), +# ) + +# cli_main(model, dm, root, "logs/dummy", args) + +# with ArgvContext("file.py", "--evaluate_ood", "--variation_ratio"): +# args = init_args(DummyClassificationBaseline, DummyClassificationDataModule) + +# # datamodule +# args.root = str(root / "data") +# dm = DummyClassificationDataModule(**vars(args)) + +# model = DummyClassificationBaseline( +# num_classes=dm.num_classes, +# in_channels=dm.num_channels, +# loss=nn.CrossEntropyLoss, +# optimization_procedure=optim_cifar10_resnet18, +# baseline_type="ensemble", +# **vars(args), +# ) + +# cli_main(model, dm, root, "logs/dummy", args) + +# def test_classification_failures(self): +# with pytest.raises(ValueError): +# ClassificationRoutine( +# 10, +# nn.Module(), +# None, +# None, +# 2, +# use_entropy=True, +# use_variation_ratio=True, +# ) diff --git a/tests/routines/test_regression.py b/tests/routines/test_regression.py index 193e01f5..70ca4b48 100644 --- a/tests/routines/test_regression.py +++ b/tests/routines/test_regression.py @@ -1,160 +1,160 @@ -from functools import partial -from pathlib import Path - -import pytest -from cli_test_helpers import ArgvContext -from torch import nn - -from tests._dummies import DummyRegressionBaseline, DummyRegressionDataModule -from torch_uncertainty import cli_main, init_args -from torch_uncertainty.losses import BetaNLL, NIGLoss -from torch_uncertainty.optimization_procedures import optim_cifar10_resnet18 - - -class TestRegressionSingle: - """Testing the Regression routine with a single model.""" - - def test_cli_main_dummy_dist(self): - root = Path(__file__).parent.absolute().parents[0] - with ArgvContext("file.py"): - args = init_args(DummyRegressionBaseline, DummyRegressionDataModule) - - # datamodule - args.root = str(root / "data") - dm = DummyRegressionDataModule(out_features=1, **vars(args)) - - model = DummyRegressionBaseline( - in_features=dm.in_features, - out_features=2, - loss=nn.GaussianNLLLoss, - optimization_procedure=optim_cifar10_resnet18, - baseline_type="single", - dist_estimation=2, - **vars(args), - ) - - cli_main(model, dm, root, "logs/dummy", args) - - def test_cli_main_dummy_dist_der(self): - root = Path(__file__).parent.absolute().parents[0] - with ArgvContext("file.py"): - args = init_args(DummyRegressionBaseline, DummyRegressionDataModule) - - # datamodule - args.root = str(root / "data") - dm = DummyRegressionDataModule(out_features=1, **vars(args)) - - loss = partial( - NIGLoss, - reg_weight=1e-2, - ) - - model = DummyRegressionBaseline( - in_features=dm.in_features, - out_features=4, - loss=loss, - optimization_procedure=optim_cifar10_resnet18, - baseline_type="single", - dist_estimation=4, - **vars(args), - ) - - cli_main(model, dm, root, "logs/dummy_der", args) - - def test_cli_main_dummy_dist_betanll(self): - root = Path(__file__).parent.absolute().parents[0] - with ArgvContext("file.py"): - args = init_args(DummyRegressionBaseline, DummyRegressionDataModule) - - # datamodule - args.root = str(root / "data") - dm = DummyRegressionDataModule(out_features=1, **vars(args)) - - loss = partial( - BetaNLL, - beta=0.5, - ) - - model = DummyRegressionBaseline( - in_features=dm.in_features, - out_features=2, - loss=loss, - optimization_procedure=optim_cifar10_resnet18, - baseline_type="single", - dist_estimation=2, - **vars(args), - ) - - cli_main(model, dm, root, "logs/dummy_betanll", args) - - def test_cli_main_dummy(self): - root = Path(__file__).parent.absolute().parents[0] - with ArgvContext("file.py"): - args = init_args(DummyRegressionBaseline, DummyRegressionDataModule) - - # datamodule - args.root = str(root / "data") - dm = DummyRegressionDataModule(out_features=2, **vars(args)) - - model = DummyRegressionBaseline( - in_features=dm.in_features, - out_features=dm.out_features, - loss=nn.MSELoss, - optimization_procedure=optim_cifar10_resnet18, - baseline_type="single", - **vars(args), - ) - - cli_main(model, dm, root, "logs/dummy", args) - - def test_regression_failures(self): - with pytest.raises(ValueError): - DummyRegressionBaseline( - in_features=10, - out_features=3, - loss=nn.GaussianNLLLoss, - optimization_procedure=optim_cifar10_resnet18, - dist_estimation=4, - ) - - with pytest.raises(ValueError): - DummyRegressionBaseline( - in_features=10, - out_features=3, - loss=nn.GaussianNLLLoss, - optimization_procedure=optim_cifar10_resnet18, - dist_estimation=-4, - ) - - with pytest.raises(TypeError): - DummyRegressionBaseline( - in_features=10, - out_features=4, - loss=nn.GaussianNLLLoss, - optimization_procedure=optim_cifar10_resnet18, - dist_estimation=4.2, - ) - - -class TestRegressionEnsemble: - """Testing the Regression routine with an ensemble model.""" - - def test_cli_main_dummy(self): - root = Path(__file__).parent.absolute().parents[0] - with ArgvContext("file.py"): - args = init_args(DummyRegressionBaseline, DummyRegressionDataModule) - - # datamodule - args.root = str(root / "data") - dm = DummyRegressionDataModule(out_features=1, **vars(args)) - - model = DummyRegressionBaseline( - in_features=dm.in_features, - out_features=dm.out_features, - loss=nn.MSELoss, - optimization_procedure=optim_cifar10_resnet18, - baseline_type="ensemble", - **vars(args), - ) - - cli_main(model, dm, root, "logs/dummy", args) +# from functools import partial +# from pathlib import Path + +# import pytest +# from cli_test_helpers import ArgvContext +# from torch import nn + +# from tests._dummies import DummyRegressionBaseline, DummyRegressionDataModule +# from torch_uncertainty import cli_main, init_args +# from torch_uncertainty.losses import BetaNLL, NIGLoss +# from torch_uncertainty.optimization_procedures import optim_cifar10_resnet18 + + +# class TestRegressionSingle: +# """Testing the Regression routine with a single model.""" + +# def test_cli_main_dummy_dist(self): +# root = Path(__file__).parent.absolute().parents[0] +# with ArgvContext("file.py"): +# args = init_args(DummyRegressionBaseline, DummyRegressionDataModule) + +# # datamodule +# args.root = str(root / "data") +# dm = DummyRegressionDataModule(out_features=1, **vars(args)) + +# model = DummyRegressionBaseline( +# in_features=dm.in_features, +# out_features=2, +# loss=nn.GaussianNLLLoss, +# optimization_procedure=optim_cifar10_resnet18, +# baseline_type="single", +# dist_estimation=2, +# **vars(args), +# ) + +# cli_main(model, dm, root, "logs/dummy", args) + +# def test_cli_main_dummy_dist_der(self): +# root = Path(__file__).parent.absolute().parents[0] +# with ArgvContext("file.py"): +# args = init_args(DummyRegressionBaseline, DummyRegressionDataModule) + +# # datamodule +# args.root = str(root / "data") +# dm = DummyRegressionDataModule(out_features=1, **vars(args)) + +# loss = partial( +# NIGLoss, +# reg_weight=1e-2, +# ) + +# model = DummyRegressionBaseline( +# in_features=dm.in_features, +# out_features=4, +# loss=loss, +# optimization_procedure=optim_cifar10_resnet18, +# baseline_type="single", +# dist_estimation=4, +# **vars(args), +# ) + +# cli_main(model, dm, root, "logs/dummy_der", args) + +# def test_cli_main_dummy_dist_betanll(self): +# root = Path(__file__).parent.absolute().parents[0] +# with ArgvContext("file.py"): +# args = init_args(DummyRegressionBaseline, DummyRegressionDataModule) + +# # datamodule +# args.root = str(root / "data") +# dm = DummyRegressionDataModule(out_features=1, **vars(args)) + +# loss = partial( +# BetaNLL, +# beta=0.5, +# ) + +# model = DummyRegressionBaseline( +# in_features=dm.in_features, +# out_features=2, +# loss=loss, +# optimization_procedure=optim_cifar10_resnet18, +# baseline_type="single", +# dist_estimation=2, +# **vars(args), +# ) + +# cli_main(model, dm, root, "logs/dummy_betanll", args) + +# def test_cli_main_dummy(self): +# root = Path(__file__).parent.absolute().parents[0] +# with ArgvContext("file.py"): +# args = init_args(DummyRegressionBaseline, DummyRegressionDataModule) + +# # datamodule +# args.root = str(root / "data") +# dm = DummyRegressionDataModule(out_features=2, **vars(args)) + +# model = DummyRegressionBaseline( +# in_features=dm.in_features, +# out_features=dm.out_features, +# loss=nn.MSELoss, +# optimization_procedure=optim_cifar10_resnet18, +# baseline_type="single", +# **vars(args), +# ) + +# cli_main(model, dm, root, "logs/dummy", args) + +# def test_regression_failures(self): +# with pytest.raises(ValueError): +# DummyRegressionBaseline( +# in_features=10, +# out_features=3, +# loss=nn.GaussianNLLLoss, +# optimization_procedure=optim_cifar10_resnet18, +# dist_estimation=4, +# ) + +# with pytest.raises(ValueError): +# DummyRegressionBaseline( +# in_features=10, +# out_features=3, +# loss=nn.GaussianNLLLoss, +# optimization_procedure=optim_cifar10_resnet18, +# dist_estimation=-4, +# ) + +# with pytest.raises(TypeError): +# DummyRegressionBaseline( +# in_features=10, +# out_features=4, +# loss=nn.GaussianNLLLoss, +# optimization_procedure=optim_cifar10_resnet18, +# dist_estimation=4.2, +# ) + + +# class TestRegressionEnsemble: +# """Testing the Regression routine with an ensemble model.""" + +# def test_cli_main_dummy(self): +# root = Path(__file__).parent.absolute().parents[0] +# with ArgvContext("file.py"): +# args = init_args(DummyRegressionBaseline, DummyRegressionDataModule) + +# # datamodule +# args.root = str(root / "data") +# dm = DummyRegressionDataModule(out_features=1, **vars(args)) + +# model = DummyRegressionBaseline( +# in_features=dm.in_features, +# out_features=dm.out_features, +# loss=nn.MSELoss, +# optimization_procedure=optim_cifar10_resnet18, +# baseline_type="ensemble", +# **vars(args), +# ) + +# cli_main(model, dm, root, "logs/dummy", args) diff --git a/tests/test_cli.py b/tests/test_cli.py index 24349c5f..5f4c3a01 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,386 +1,365 @@ -import sys -from pathlib import Path - -import pytest -from cli_test_helpers import ArgvContext -from torch import nn - -from torch_uncertainty import cli_main, init_args -from torch_uncertainty.baselines import VGG, ResNet, WideResNet -from torch_uncertainty.baselines.regression import MLP -from torch_uncertainty.datamodules import CIFAR10DataModule, UCIDataModule -from torch_uncertainty.optimization_procedures import ( - optim_cifar10_resnet18, - optim_cifar10_vgg16, - optim_cifar10_wideresnet, - optim_regression, -) -from torch_uncertainty.utils.misc import csv_writer - -from ._dummies.dataset import DummyClassificationDataset - - -class TestCLI: - """Testing the CLI function.""" - - def test_cli_main_resnet(self): - root = Path(__file__).parent.absolute().parents[0] - with ArgvContext("file.py"): - args = init_args(ResNet, CIFAR10DataModule) - - # datamodule - args.root = str(root / "data") - dm = CIFAR10DataModule(**vars(args)) - - # Simulate that summary is True & the only argument - args.summary = True - - model = ResNet( - num_classes=dm.num_classes, - in_channels=dm.num_channels, - style="cifar", - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet18, - **vars(args), - ) - - results = cli_main(model, dm, root, "std", args) - results_path = root / "tests" / "logs" - if not results_path.exists(): - results_path.mkdir(parents=True) - for dict_result in results: - csv_writer( - results_path / "results.csv", - dict_result, - ) - # Test if file already exists - for dict_result in results: - csv_writer( - results_path / "results.csv", - dict_result, - ) - - def test_cli_main_other_arguments(self): - root = Path(__file__).parent.absolute().parents[0] - with ArgvContext( - "file.py", "--seed", "42", "--max_epochs", "1", "--channels_last" - ): - print(sys.orig_argv, sys.argv) - args = init_args(ResNet, CIFAR10DataModule) - - # datamodule - args.root = root / "data" - dm = CIFAR10DataModule(**vars(args)) - - # Simulate that summary is True & the only argument - args.summary = True - - model = ResNet( - num_classes=dm.num_classes, - in_channels=dm.num_channels, - style="cifar", - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet18, - **vars(args), - ) - - cli_main(model, dm, root, "std", args) - - def test_cli_main_wideresnet(self): - root = Path(__file__).parent.absolute().parents[0] - with ArgvContext("file.py"): - args = init_args(WideResNet, CIFAR10DataModule) - - # datamodule - args.root = root / "data" - dm = CIFAR10DataModule(**vars(args)) - - args.summary = True - - model = WideResNet( - num_classes=dm.num_classes, - in_channels=dm.num_channels, - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_wideresnet, - **vars(args), - ) - - cli_main(model, dm, root, "std", args) - - def test_cli_main_vgg(self): - root = Path(__file__).parent.absolute().parents[0] - with ArgvContext("file.py"): - args = init_args(VGG, CIFAR10DataModule) - - # datamodule - args.root = root / "data" - dm = CIFAR10DataModule(**vars(args)) - - args.summary = True - - model = VGG( - num_classes=dm.num_classes, - in_channels=dm.num_channels, - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_vgg16, - **vars(args), - ) - - cli_main(model, dm, root, "std", args) - - def test_cli_main_mlp(self): - root = str(Path(__file__).parent.absolute().parents[0]) - with ArgvContext("file.py"): - args = init_args(MLP, UCIDataModule) - - # datamodule - args.root = root + "/data" - dm = UCIDataModule( - dataset_name="kin8nm", input_shape=(1, 5), **vars(args) - ) - - args.summary = True - - model = MLP( - num_outputs=1, - in_features=5, - hidden_dims=[], - dist_estimation=1, - loss=nn.MSELoss, - optimization_procedure=optim_regression, - **vars(args), - ) - - cli_main(model, dm, root, "std", args) - - args.test = True - cli_main(model, dm, root, "std", args) - - def test_cli_other_training_task(self): - root = Path(__file__).parent.absolute().parents[0] - with ArgvContext("file.py"): - args = init_args(MLP, UCIDataModule) - - # datamodule - args.root = root / "data" - dm = UCIDataModule( - dataset_name="kin8nm", input_shape=(1, 5), **vars(args) - ) - - dm.training_task = "time-series-regression" - - args.summary = True - - model = MLP( - num_outputs=1, - in_features=5, - hidden_dims=[], - dist_estimation=1, - loss=nn.MSELoss, - optimization_procedure=optim_regression, - **vars(args), - ) - with pytest.raises(ValueError): - cli_main(model, dm, root, "std", args) - - def test_cli_cv_ts(self): - root = Path(__file__).parent.absolute().parents[0] - with ArgvContext("file.py", "--use_cv", "--channels_last"): - args = init_args(ResNet, CIFAR10DataModule) - - # datamodule - args.root = str(root / "data") - dm = CIFAR10DataModule(**vars(args)) - - # Simulate that summary is True & the only argument - args.summary = True - - dm.dataset = ( - lambda root, - train, - download, - transform: DummyClassificationDataset( - root, - train=train, - download=download, - transform=transform, - num_images=20, - ) - ) - - list_dm = dm.make_cross_val_splits(2, 1) - list_model = [ - ResNet( - num_classes=list_dm[i].dm.num_classes, - in_channels=list_dm[i].dm.num_channels, - style="cifar", - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet18, - **vars(args), - ) - for i in range(len(list_dm)) - ] - - cli_main(list_model, list_dm, root, "std", args) - - with ArgvContext("file.py", "--use_cv", "--mixtype", "mixup"): - args = init_args(ResNet, CIFAR10DataModule) - - # datamodule - args.root = str(root / "data") - dm = CIFAR10DataModule(**vars(args)) - - # Simulate that summary is True & the only argument - args.summary = True - - dm.dataset = ( - lambda root, - train, - download, - transform: DummyClassificationDataset( - root, - train=train, - download=download, - transform=transform, - num_images=20, - ) - ) - - list_dm = dm.make_cross_val_splits(2, 1) - list_model = [] - for i in range(len(list_dm)): - list_model.append( - ResNet( - num_classes=list_dm[i].dm.num_classes, - in_channels=list_dm[i].dm.num_channels, - style="cifar", - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet18, - **vars(args), - ) - ) - - cli_main(list_model, list_dm, root, "std", args) - - with ArgvContext("file.py", "--use_cv", "--mixtype", "mixup_io"): - args = init_args(ResNet, CIFAR10DataModule) - - # datamodule - args.root = str(root / "data") - dm = CIFAR10DataModule(**vars(args)) - - # Simulate that summary is True & the only argument - args.summary = True - - dm.dataset = ( - lambda root, - train, - download, - transform: DummyClassificationDataset( - root, - train=train, - download=download, - transform=transform, - num_images=20, - ) - ) - - list_dm = dm.make_cross_val_splits(2, 1) - list_model = [] - for i in range(len(list_dm)): - list_model.append( - ResNet( - num_classes=list_dm[i].dm.num_classes, - in_channels=list_dm[i].dm.num_channels, - style="cifar", - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet18, - **vars(args), - ) - ) - - cli_main(list_model, list_dm, root, "std", args) - - with ArgvContext("file.py", "--use_cv", "--mixtype", "regmixup"): - args = init_args(ResNet, CIFAR10DataModule) - - # datamodule - args.root = str(root / "data") - dm = CIFAR10DataModule(**vars(args)) - - # Simulate that summary is True & the only argument - args.summary = True - - dm.dataset = ( - lambda root, - train, - download, - transform: DummyClassificationDataset( - root, - train=train, - download=download, - transform=transform, - num_images=20, - ) - ) - - list_dm = dm.make_cross_val_splits(2, 1) - list_model = [] - for i in range(len(list_dm)): - list_model.append( - ResNet( - num_classes=list_dm[i].dm.num_classes, - in_channels=list_dm[i].dm.num_channels, - style="cifar", - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet18, - **vars(args), - ) - ) - - cli_main(list_model, list_dm, root, "std", args) - - with ArgvContext( - "file.py", "--use_cv", "--mixtype", "kernel_warping" - ): - args = init_args(ResNet, CIFAR10DataModule) - - # datamodule - args.root = str(root / "data") - dm = CIFAR10DataModule(**vars(args)) - - # Simulate that summary is True & the only argument - args.summary = True - - dm.dataset = ( - lambda root, - train, - download, - transform: DummyClassificationDataset( - root, - train=train, - download=download, - transform=transform, - num_images=20, - ) - ) - - list_dm = dm.make_cross_val_splits(2, 1) - list_model = [] - for i in range(len(list_dm)): - list_model.append( - ResNet( - num_classes=list_dm[i].dm.num_classes, - in_channels=list_dm[i].dm.num_channels, - style="cifar", - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet18, - **vars(args), - ) - ) - - cli_main(list_model, list_dm, root, "std", args) - - def test_init_args_void(self): - with ArgvContext("file.py"): - init_args() +# import sys +# from pathlib import Path + +# import pytest +# from cli_test_helpers import ArgvContext +# from torch import nn + +# from torch_uncertainty import cli_main, init_args +# from torch_uncertainty.baselines.classification import VGG, ResNet, WideResNet +# from torch_uncertainty.baselines.regression import MLP +# from torch_uncertainty.datamodules import CIFAR10DataModule, UCIDataModule +# from torch_uncertainty.optimization_procedures import ( +# optim_cifar10_resnet18, +# optim_cifar10_vgg16, +# optim_cifar10_wideresnet, +# optim_regression, +# ) +# from torch_uncertainty.utils.misc import csv_writer + +# from ._dummies.dataset import DummyClassificationDataset + + +# class TestCLI: +# """Testing the CLI function.""" + +# def test_cli_main_resnet(self): +# root = Path(__file__).parent.absolute().parents[0] +# with ArgvContext("file.py"): +# args = init_args(ResNet, CIFAR10DataModule) + +# # datamodule +# args.root = str(root / "data") +# dm = CIFAR10DataModule(**vars(args)) + +# # Simulate that summary is True & the only argument +# args.summary = True + +# model = ResNet( +# num_classes=dm.num_classes, +# in_channels=dm.num_channels, +# style="cifar", +# loss=nn.CrossEntropyLoss, +# optimization_procedure=optim_cifar10_resnet18, +# **vars(args), +# ) + +# results = cli_main(model, dm, root, "std", args) +# results_path = root / "tests" / "logs" +# if not results_path.exists(): +# results_path.mkdir(parents=True) +# for dict_result in results: +# csv_writer( +# results_path / "results.csv", +# dict_result, +# ) +# # Test if file already exists +# for dict_result in results: +# csv_writer( +# results_path / "results.csv", +# dict_result, +# ) + +# def test_cli_main_other_arguments(self): +# root = Path(__file__).parent.absolute().parents[0] +# with ArgvContext( +# "file.py", "--seed", "42", "--max_epochs", "1", "--channels_last" +# ): +# print(sys.orig_argv, sys.argv) +# args = init_args(ResNet, CIFAR10DataModule) + +# # datamodule +# args.root = root / "data" +# dm = CIFAR10DataModule(**vars(args)) + +# # Simulate that summary is True & the only argument +# args.summary = True + +# model = ResNet( +# num_classes=dm.num_classes, +# in_channels=dm.num_channels, +# style="cifar", +# loss=nn.CrossEntropyLoss, +# optimization_procedure=optim_cifar10_resnet18, +# **vars(args), +# ) + +# cli_main(model, dm, root, "std", args) + +# def test_cli_main_wideresnet(self): +# root = Path(__file__).parent.absolute().parents[0] +# with ArgvContext("file.py"): +# args = init_args(WideResNet, CIFAR10DataModule) + +# # datamodule +# args.root = root / "data" +# dm = CIFAR10DataModule(**vars(args)) + +# args.summary = True + +# model = WideResNet( +# num_classes=dm.num_classes, +# in_channels=dm.num_channels, +# loss=nn.CrossEntropyLoss, +# optimization_procedure=optim_cifar10_wideresnet, +# **vars(args), +# ) + +# cli_main(model, dm, root, "std", args) + +# def test_cli_main_vgg(self): +# root = Path(__file__).parent.absolute().parents[0] +# with ArgvContext("file.py"): +# args = init_args(VGG, CIFAR10DataModule) + +# # datamodule +# args.root = root / "data" +# dm = CIFAR10DataModule(**vars(args)) + +# args.summary = True + +# model = VGG( +# num_classes=dm.num_classes, +# in_channels=dm.num_channels, +# loss=nn.CrossEntropyLoss, +# optimization_procedure=optim_cifar10_vgg16, +# **vars(args), +# ) + +# cli_main(model, dm, root, "std", args) + +# def test_cli_main_mlp(self): +# root = str(Path(__file__).parent.absolute().parents[0]) +# with ArgvContext("file.py"): +# args = init_args(MLP, UCIDataModule) + +# # datamodule +# args.root = root + "/data" +# dm = UCIDataModule(dataset_name="kin8nm", input_shape=(1, 5), **vars(args)) + +# args.summary = True + +# model = MLP( +# num_outputs=1, +# in_features=5, +# hidden_dims=[], +# dist_estimation=1, +# loss=nn.MSELoss, +# optimization_procedure=optim_regression, +# **vars(args), +# ) + +# cli_main(model, dm, root, "std", args) + +# args.test = True +# cli_main(model, dm, root, "std", args) + +# def test_cli_other_training_task(self): +# root = Path(__file__).parent.absolute().parents[0] +# with ArgvContext("file.py"): +# args = init_args(MLP, UCIDataModule) + +# # datamodule +# args.root = root / "data" +# dm = UCIDataModule(dataset_name="kin8nm", input_shape=(1, 5), **vars(args)) + +# dm.training_task = "time-series-regression" + +# args.summary = True + +# model = MLP( +# num_outputs=1, +# in_features=5, +# hidden_dims=[], +# dist_estimation=1, +# loss=nn.MSELoss, +# optimization_procedure=optim_regression, +# **vars(args), +# ) +# with pytest.raises(ValueError): +# cli_main(model, dm, root, "std", args) + +# def test_cli_cv_ts(self): +# root = Path(__file__).parent.absolute().parents[0] +# with ArgvContext("file.py", "--use_cv", "--channels_last"): +# args = init_args(ResNet, CIFAR10DataModule) + +# # datamodule +# args.root = str(root / "data") +# dm = CIFAR10DataModule(**vars(args)) + +# # Simulate that summary is True & the only argument +# args.summary = True + +# dm.dataset = ( +# lambda root, train, download, transform: DummyClassificationDataset( +# root, +# train=train, +# download=download, +# transform=transform, +# num_images=20, +# ) +# ) + +# list_dm = dm.make_cross_val_splits(2, 1) +# list_model = [ +# ResNet( +# num_classes=list_dm[i].dm.num_classes, +# in_channels=list_dm[i].dm.num_channels, +# style="cifar", +# loss=nn.CrossEntropyLoss, +# optimization_procedure=optim_cifar10_resnet18, +# **vars(args), +# ) +# for i in range(len(list_dm)) +# ] + +# cli_main(list_model, list_dm, root, "std", args) + +# with ArgvContext("file.py", "--use_cv", "--mixtype", "mixup"): +# args = init_args(ResNet, CIFAR10DataModule) + +# # datamodule +# args.root = str(root / "data") +# dm = CIFAR10DataModule(**vars(args)) + +# # Simulate that summary is True & the only argument +# args.summary = True + +# dm.dataset = ( +# lambda root, train, download, transform: DummyClassificationDataset( +# root, +# train=train, +# download=download, +# transform=transform, +# num_images=20, +# ) +# ) + +# list_dm = dm.make_cross_val_splits(2, 1) +# list_model = [] +# for i in range(len(list_dm)): +# list_model.append( +# ResNet( +# num_classes=list_dm[i].dm.num_classes, +# in_channels=list_dm[i].dm.num_channels, +# style="cifar", +# loss=nn.CrossEntropyLoss, +# optimization_procedure=optim_cifar10_resnet18, +# **vars(args), +# ) +# ) + +# cli_main(list_model, list_dm, root, "std", args) + +# with ArgvContext("file.py", "--use_cv", "--mixtype", "mixup_io"): +# args = init_args(ResNet, CIFAR10DataModule) + +# # datamodule +# args.root = str(root / "data") +# dm = CIFAR10DataModule(**vars(args)) + +# # Simulate that summary is True & the only argument +# args.summary = True + +# dm.dataset = ( +# lambda root, train, download, transform: DummyClassificationDataset( +# root, +# train=train, +# download=download, +# transform=transform, +# num_images=20, +# ) +# ) + +# list_dm = dm.make_cross_val_splits(2, 1) +# list_model = [] +# for i in range(len(list_dm)): +# list_model.append( +# ResNet( +# num_classes=list_dm[i].dm.num_classes, +# in_channels=list_dm[i].dm.num_channels, +# style="cifar", +# loss=nn.CrossEntropyLoss, +# optimization_procedure=optim_cifar10_resnet18, +# **vars(args), +# ) +# ) + +# cli_main(list_model, list_dm, root, "std", args) + +# with ArgvContext("file.py", "--use_cv", "--mixtype", "regmixup"): +# args = init_args(ResNet, CIFAR10DataModule) + +# # datamodule +# args.root = str(root / "data") +# dm = CIFAR10DataModule(**vars(args)) + +# # Simulate that summary is True & the only argument +# args.summary = True + +# dm.dataset = ( +# lambda root, train, download, transform: DummyClassificationDataset( +# root, +# train=train, +# download=download, +# transform=transform, +# num_images=20, +# ) +# ) + +# list_dm = dm.make_cross_val_splits(2, 1) +# list_model = [] +# for i in range(len(list_dm)): +# list_model.append( +# ResNet( +# num_classes=list_dm[i].dm.num_classes, +# in_channels=list_dm[i].dm.num_channels, +# style="cifar", +# loss=nn.CrossEntropyLoss, +# optimization_procedure=optim_cifar10_resnet18, +# **vars(args), +# ) +# ) + +# cli_main(list_model, list_dm, root, "std", args) + +# with ArgvContext("file.py", "--use_cv", "--mixtype", "kernel_warping"): +# args = init_args(ResNet, CIFAR10DataModule) + +# # datamodule +# args.root = str(root / "data") +# dm = CIFAR10DataModule(**vars(args)) + +# # Simulate that summary is True & the only argument +# args.summary = True + +# dm.dataset = ( +# lambda root, train, download, transform: DummyClassificationDataset( +# root, +# train=train, +# download=download, +# transform=transform, +# num_images=20, +# ) +# ) + +# list_dm = dm.make_cross_val_splits(2, 1) +# list_model = [] +# for i in range(len(list_dm)): +# list_model.append( +# ResNet( +# num_classes=list_dm[i].dm.num_classes, +# in_channels=list_dm[i].dm.num_channels, +# style="cifar", +# loss=nn.CrossEntropyLoss, +# optimization_procedure=optim_cifar10_resnet18, +# **vars(args), +# ) +# ) + +# cli_main(list_model, list_dm, root, "std", args) + +# def test_init_args_void(self): +# with ArgvContext("file.py"): +# init_args() diff --git a/torch_uncertainty/__init__.py b/torch_uncertainty/__init__.py index 50b41a5b..0fbfcea8 100644 --- a/torch_uncertainty/__init__.py +++ b/torch_uncertainty/__init__.py @@ -1,254 +1,250 @@ # ruff: noqa: F401 -from argparse import ArgumentParser, Namespace -from collections import defaultdict -from pathlib import Path -from typing import Any, Optional - -import numpy as np -import pytorch_lightning as pl -import torch -from pytorch_lightning.callbacks import LearningRateMonitor -from pytorch_lightning.callbacks.early_stopping import EarlyStopping -from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint -from pytorch_lightning.loggers.tensorboard import TensorBoardLogger -from torchinfo import summary - -from .datamodules.abstract import AbstractDataModule -from .utils import get_version - - -def init_args( - network: Any = None, - datamodule: type[pl.LightningDataModule] | None = None, -) -> Namespace: - parser = ArgumentParser("torch-uncertainty") - parser.add_argument( - "--seed", - type=int, - default=None, - help="Random seed to make the training deterministic.", - ) - parser.add_argument( - "--test", - type=int, - default=None, - help="Run in test mode. Set to the checkpoint version number to test.", - ) - parser.add_argument( - "--ckpt", type=int, default=None, help="The number of the checkpoint" - ) - parser.add_argument( - "--summary", - dest="summary", - action="store_true", - help="Print model summary", - ) - parser.add_argument("--log_graph", dest="log_graph", action="store_true") - parser.add_argument( - "--channels_last", - action="store_true", - help="Use channels last memory format", - ) - parser.add_argument( - "--enable_resume", - action="store_true", - help="Allow resuming the training (save optimizer's states)", - ) - parser.add_argument( - "--exp_dir", - type=str, - default="logs/", - help="Directory to store experiment files", - ) - parser.add_argument( - "--exp_name", - type=str, - default="", - help="Name of the experiment folder", - ) - parser.add_argument( - "--opt_temp_scaling", - action="store_true", - default=False, - help="Compute optimal temperature on the test set", - ) - parser.add_argument( - "--val_temp_scaling", - action="store_true", - default=False, - help="Compute temperature on the validation set", - ) - parser = pl.Trainer.add_argparse_args(parser) - if network is not None: - parser = network.add_model_specific_args(parser) - - if datamodule is not None: - parser = datamodule.add_argparse_args(parser) - - return parser.parse_args() - - -def cli_main( - network: pl.LightningModule | list[pl.LightningModule], - datamodule: AbstractDataModule | list[AbstractDataModule], - root: Path | str, - net_name: str, - args: Namespace, -) -> list[dict]: - if isinstance(root, str): - root = Path(root) - - if isinstance(datamodule, list): - training_task = datamodule[0].dm.training_task - else: - training_task = datamodule.training_task - if training_task == "classification": - monitor = "hp/val_acc" - mode = "max" - elif training_task == "regression": - monitor = "hp/val_mse" - mode = "min" - else: - raise ValueError("Unknown problem type.") - - if args.test is None and args.max_epochs is None: - print( - "Setting max_epochs to 1 for testing purposes. Set max_epochs " - "manually to train the model." - ) - args.max_epochs = 1 - - if isinstance(args.seed, int): - pl.seed_everything(args.seed, workers=True) - - if args.channels_last: - if isinstance(network, list): - for i in range(len(network)): - network[i] = network[i].to(memory_format=torch.channels_last) - else: - network = network.to(memory_format=torch.channels_last) - - if hasattr(args, "use_cv") and args.use_cv: - test_values = [] - for i in range(len(datamodule)): - print( - f"Starting fold {i} out of {args.train_over} of a {args.n_splits}-fold CV." - ) - - # logger - tb_logger = TensorBoardLogger( - str(root), - name=net_name, - default_hp_metric=False, - log_graph=args.log_graph, - version=f"fold_{i}", - ) - - # callbacks - save_checkpoints = ModelCheckpoint( - dirpath=tb_logger.log_dir, - monitor=monitor, - mode=mode, - save_last=True, - save_weights_only=not args.enable_resume, - ) - - # Select the best model, monitor the lr and stop if NaN - callbacks = [ - save_checkpoints, - LearningRateMonitor(logging_interval="step"), - EarlyStopping( - monitor=monitor, patience=np.inf, check_finite=True - ), - ] - - trainer = pl.Trainer.from_argparse_args( - args, - callbacks=callbacks, - logger=tb_logger, - deterministic=(args.seed is not None), - inference_mode=not ( - args.opt_temp_scaling or args.val_temp_scaling - ), - ) - if args.summary: - summary( - network[i], - input_size=list(datamodule[i].dm.input_shape).insert(0, 1), - ) - test_values.append({}) - else: - trainer.fit(network[i], datamodule[i]) - test_values.append( - trainer.test(datamodule=datamodule[i], ckpt_path="last")[0] - ) - - all_test_values = defaultdict(list) - for test_value in test_values: - for key in test_value: - all_test_values[key].append(test_value[key]) - - avg_test_values = {} - for key in all_test_values: - avg_test_values[key] = np.mean(all_test_values[key]) - - return [avg_test_values] - - # logger - tb_logger = TensorBoardLogger( - str(root), - name=net_name, - default_hp_metric=False, - log_graph=args.log_graph, - version=args.test, - ) - - # callbacks - save_checkpoints = ModelCheckpoint( - monitor=monitor, - mode=mode, - save_last=True, - save_weights_only=not args.enable_resume, - ) - - # Select the best model, monitor the lr and stop if NaN - callbacks = [ - save_checkpoints, - LearningRateMonitor(logging_interval="step"), - EarlyStopping(monitor=monitor, patience=np.inf, check_finite=True), - ] - - # trainer - trainer = pl.Trainer.from_argparse_args( - args, - callbacks=callbacks, - logger=tb_logger, - deterministic=(args.seed is not None), - inference_mode=not (args.opt_temp_scaling or args.val_temp_scaling), - ) - if args.summary: - summary( - network, - input_size=list(datamodule.input_shape).insert(0, 1), - ) - test_values = [{}] - elif args.test is not None: - if args.test >= 0: - ckpt_file, _ = get_version( - root=(root / net_name), - version=args.test, - checkpoint=args.ckpt, - ) - test_values = trainer.test( - network, datamodule=datamodule, ckpt_path=str(ckpt_file) - ) - else: - test_values = trainer.test(network, datamodule=datamodule) - else: - # training and testing - trainer.fit(network, datamodule) - if args.fast_dev_run is False: - test_values = trainer.test(datamodule=datamodule, ckpt_path="best") - else: - test_values = [{}] - return test_values +# from argparse import ArgumentParser, Namespace +# from collections import defaultdict +# from pathlib import Path +# from typing import Any + +# import numpy as np +# import pytorch_lightning as pl +# import torch +# from pytorch_lightning.callbacks import LearningRateMonitor +# from pytorch_lightning.callbacks.early_stopping import EarlyStopping +# from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint +# from pytorch_lightning.loggers.tensorboard import TensorBoardLogger +# from torchinfo import summary + +# from .datamodules.abstract import AbstractDataModule +# from .utils import get_version + + +# def init_args( +# network: Any = None, +# datamodule: type[pl.LightningDataModule] | None = None, +# ) -> Namespace: +# parser = ArgumentParser("torch-uncertainty") +# parser.add_argument( +# "--seed", +# type=int, +# default=None, +# help="Random seed to make the training deterministic.", +# ) +# parser.add_argument( +# "--test", +# type=int, +# default=None, +# help="Run in test mode. Set to the checkpoint version number to test.", +# ) +# parser.add_argument( +# "--ckpt", type=int, default=None, help="The number of the checkpoint" +# ) +# parser.add_argument( +# "--summary", +# dest="summary", +# action="store_true", +# help="Print model summary", +# ) +# parser.add_argument("--log_graph", dest="log_graph", action="store_true") +# parser.add_argument( +# "--channels_last", +# action="store_true", +# help="Use channels last memory format", +# ) +# parser.add_argument( +# "--enable_resume", +# action="store_true", +# help="Allow resuming the training (save optimizer's states)", +# ) +# parser.add_argument( +# "--exp_dir", +# type=str, +# default="logs/", +# help="Directory to store experiment files", +# ) +# parser.add_argument( +# "--exp_name", +# type=str, +# default="", +# help="Name of the experiment folder", +# ) +# parser.add_argument( +# "--opt_temp_scaling", +# action="store_true", +# default=False, +# help="Compute optimal temperature on the test set", +# ) +# parser.add_argument( +# "--val_temp_scaling", +# action="store_true", +# default=False, +# help="Compute temperature on the validation set", +# ) +# parser = pl.Trainer.add_argparse_args(parser) +# if network is not None: +# parser = network.add_model_specific_args(parser) + +# if datamodule is not None: +# parser = datamodule.add_argparse_args(parser) + +# return parser.parse_args() + + +# def cli_main( +# network: pl.LightningModule | list[pl.LightningModule], +# datamodule: AbstractDataModule | list[AbstractDataModule], +# root: Path | str, +# net_name: str, +# args: Namespace, +# ) -> list[dict]: +# if isinstance(root, str): +# root = Path(root) + +# if isinstance(datamodule, list): +# training_task = datamodule[0].dm.training_task +# else: +# training_task = datamodule.training_task +# if training_task == "classification": +# monitor = "hp/val_acc" +# mode = "max" +# elif training_task == "regression": +# monitor = "hp/val_mse" +# mode = "min" +# else: +# raise ValueError("Unknown problem type.") + +# if args.test is None and args.max_epochs is None: +# print( +# "Setting max_epochs to 1 for testing purposes. Set max_epochs " +# "manually to train the model." +# ) +# args.max_epochs = 1 + +# if isinstance(args.seed, int): +# pl.seed_everything(args.seed, workers=True) + +# if args.channels_last: +# if isinstance(network, list): +# for i in range(len(network)): +# network[i] = network[i].to(memory_format=torch.channels_last) +# else: +# network = network.to(memory_format=torch.channels_last) + +# if hasattr(args, "use_cv") and args.use_cv: +# test_values = [] +# for i in range(len(datamodule)): +# print( +# f"Starting fold {i} out of {args.train_over} of a {args.n_splits}-fold CV." +# ) + +# # logger +# tb_logger = TensorBoardLogger( +# str(root), +# name=net_name, +# default_hp_metric=False, +# log_graph=args.log_graph, +# version=f"fold_{i}", +# ) + +# # callbacks +# save_checkpoints = ModelCheckpoint( +# dirpath=tb_logger.log_dir, +# monitor=monitor, +# mode=mode, +# save_last=True, +# save_weights_only=not args.enable_resume, +# ) + +# # Select the best model, monitor the lr and stop if NaN +# callbacks = [ +# save_checkpoints, +# LearningRateMonitor(logging_interval="step"), +# EarlyStopping(monitor=monitor, patience=np.inf, check_finite=True), +# ] + +# trainer = pl.Trainer.from_argparse_args( +# args, +# callbacks=callbacks, +# logger=tb_logger, +# deterministic=(args.seed is not None), +# inference_mode=not (args.opt_temp_scaling or args.val_temp_scaling), +# ) +# if args.summary: +# summary( +# network[i], +# input_size=list(datamodule[i].dm.input_shape).insert(0, 1), +# ) +# test_values.append({}) +# else: +# trainer.fit(network[i], datamodule[i]) +# test_values.append( +# trainer.test(datamodule=datamodule[i], ckpt_path="last")[0] +# ) + +# all_test_values = defaultdict(list) +# for test_value in test_values: +# for key in test_value: +# all_test_values[key].append(test_value[key]) + +# avg_test_values = {} +# for key in all_test_values: +# avg_test_values[key] = np.mean(all_test_values[key]) + +# return [avg_test_values] + +# # logger +# tb_logger = TensorBoardLogger( +# str(root), +# name=net_name, +# default_hp_metric=False, +# log_graph=args.log_graph, +# version=args.test, +# ) + +# # callbacks +# save_checkpoints = ModelCheckpoint( +# monitor=monitor, +# mode=mode, +# save_last=True, +# save_weights_only=not args.enable_resume, +# ) + +# # Select the best model, monitor the lr and stop if NaN +# callbacks = [ +# save_checkpoints, +# LearningRateMonitor(logging_interval="step"), +# EarlyStopping(monitor=monitor, patience=np.inf, check_finite=True), +# ] + +# # trainer +# trainer = pl.Trainer.from_argparse_args( +# args, +# callbacks=callbacks, +# logger=tb_logger, +# deterministic=(args.seed is not None), +# inference_mode=not (args.opt_temp_scaling or args.val_temp_scaling), +# ) +# if args.summary: +# summary( +# network, +# input_size=list(datamodule.input_shape).insert(0, 1), +# ) +# test_values = [{}] +# elif args.test is not None: +# if args.test >= 0: +# ckpt_file, _ = get_version( +# root=(root / net_name), +# version=args.test, +# checkpoint=args.ckpt, +# ) +# test_values = trainer.test( +# network, datamodule=datamodule, ckpt_path=str(ckpt_file) +# ) +# else: +# test_values = trainer.test(network, datamodule=datamodule) +# else: +# # training and testing +# trainer.fit(network, datamodule) +# if args.fast_dev_run is False: +# test_values = trainer.test(datamodule=datamodule, ckpt_path="best") +# else: +# test_values = [{}] +# return test_values diff --git a/torch_uncertainty/baselines/__init__.py b/torch_uncertainty/baselines/__init__.py index 4bfa4e54..c44ec481 100644 --- a/torch_uncertainty/baselines/__init__.py +++ b/torch_uncertainty/baselines/__init__.py @@ -1,6 +1,6 @@ # ruff: noqa: F401 -from .classification.resnet import ResNet - +# from .classification import ResNet # from .classification.vgg import VGG -# from .classification.wideresnet import WideResNet +# from .classification import WideResNet + # from .deep_ensembles import DeepEnsembles diff --git a/torch_uncertainty/baselines/classification/__init__.py b/torch_uncertainty/baselines/classification/__init__.py index 65873c97..bc3deec4 100644 --- a/torch_uncertainty/baselines/classification/__init__.py +++ b/torch_uncertainty/baselines/classification/__init__.py @@ -1,5 +1,5 @@ # ruff: noqa: F401 +from .deep_ensembles import DeepEnsembles from .resnet import ResNet - -# from .vgg import VGG -# from .wideresnet import WideResNet +from .vgg import VGG +from .wideresnet import WideResNet diff --git a/torch_uncertainty/baselines/classification/deep_ensembles.py b/torch_uncertainty/baselines/classification/deep_ensembles.py new file mode 100644 index 00000000..866c39af --- /dev/null +++ b/torch_uncertainty/baselines/classification/deep_ensembles.py @@ -0,0 +1,64 @@ +from pathlib import Path +from typing import Literal + +from torch_uncertainty.models import deep_ensembles +from torch_uncertainty.routines.classification import ClassificationRoutine +from torch_uncertainty.utils import get_version + +from . import VGG, ResNet, WideResNet + + +class DeepEnsembles(ClassificationRoutine): + backbones = { + "resnet": ResNet, + "vgg": VGG, + "wideresnet": WideResNet, + } + + def __init__( + self, + num_classes: int, + log_path: str | Path, + checkpoint_ids: list[int], + backbone: Literal["resnet", "vgg", "wideresnet"], + evaluate_ood: bool = False, + use_entropy: bool = False, + use_logits: bool = False, + use_mi: bool = False, + use_variation_ratio: bool = False, + log_plots: bool = False, + calibration_set: Literal["val", "test"] | None = None, + ) -> None: + if isinstance(log_path, str): + log_path = Path(log_path) + + backbone_cls = self.backbones[backbone] + + models = [] + for version in checkpoint_ids: # coverage: ignore + ckpt_file, hparams_file = get_version( + root=log_path, version=version + ) + trained_model = backbone_cls.load_from_checkpoint( + checkpoint_path=ckpt_file, + hparams_file=hparams_file, + loss=None, + optimization_procedure=None, + ).eval() + models.append(trained_model.model) + + de = deep_ensembles(models=models) + + super().__init__( + num_classes=num_classes, + model=de, + loss=None, + num_estimators=de.num_estimators, + evaluate_ood=evaluate_ood, + use_entropy=use_entropy, + use_logits=use_logits, + use_mi=use_mi, + use_variation_ratio=use_variation_ratio, + log_plots=log_plots, + calibration_set=calibration_set, + ) diff --git a/torch_uncertainty/baselines/classification/resnet.py b/torch_uncertainty/baselines/classification/resnet.py index fb4e381d..211f8169 100644 --- a/torch_uncertainty/baselines/classification/resnet.py +++ b/torch_uncertainty/baselines/classification/resnet.py @@ -1,4 +1,3 @@ -from collections.abc import Callable from typing import Literal from torch import nn @@ -30,11 +29,7 @@ resnet101, resnet152, ) -from torch_uncertainty.routines.classification import ( - # ClassificationEnsemble, - # ClassificationSingle, - ClassificationRoutine, -) +from torch_uncertainty.routines.classification import ClassificationRoutine from torch_uncertainty.transforms import MIMOBatchFormat, RepeatTarget @@ -80,7 +75,6 @@ def __init__( num_classes: int, in_channels: int, loss: type[nn.Module], - # optimization_procedure: Any, version: Literal[ "vanilla", "mc-dropout", @@ -102,7 +96,7 @@ def __init__( cutmix_alpha: float = 0, groups: int = 1, scale: float | None = None, - alpha: float | None = None, + alpha: int | None = None, gamma: int = 1, rho: float = 1.0, batch_repeat: int = 1, @@ -111,7 +105,8 @@ def __init__( use_mi: bool = False, use_variation_ratio: bool = False, log_plots: bool = False, - calibration_set: Callable | None = None, + save_in_csv: bool = False, + calibration_set: Literal["val", "test"] | None = None, evaluate_ood: bool = False, pretrained: bool = False, ) -> None: @@ -151,21 +146,25 @@ def __init__( Only used if :attr:`version` is either ``"packed"``, ``"batched"``, ``"masked"`` or ``"mc-dropout"`` Defaults to ``None``. dropout_rate (float, optional): Dropout rate. Defaults to ``0.0``. - mixtype (str, optional): _description_ - mixmode (str, optional): _description_ - dist_sim (str, optional): _description_ - kernel_tau_max (float, optional): _description_ - kernel_tau_std (float, optional): _description_ - mixup_alpha (float, optional): _description_ - cutmix_alpha (float, optional): _description_ - groups (int, optional): Number of groups in convolutions. Defaults to - ``1``. - scale (float, optional): Expansion factor affecting the width of the - estimators. Only used if :attr:`version` is ``"masked"``. Defaults - to ``None``. - alpha (float, optional): Expansion factor affecting the width of the - estimators. Only used if :attr:`version` is ``"packed"``. Defaults - to ``None``. + mixtype (str, optional): Mixup type. Defaults to ``"erm"``. + mixmode (str, optional): Mixup mode. Defaults to ``"elem"``. + dist_sim (str, optional): Distance similarity. Defaults to ``"emb"``. + kernel_tau_max (float, optional): Maximum value for the kernel tau. + Defaults to ``1.0``. + kernel_tau_std (float, optional): Standard deviation for the kernel + tau. Defaults to ``0.5``. + mixup_alpha (float, optional): Alpha parameter for Mixup. Defaults + to ``0``. + cutmix_alpha (float, optional): Alpha parameter for CutMix. + Defaults to ``0``. + groups (int, optional): Number of groups in convolutions. Defaults + to ``1``. + scale (float, optional): Expansion factor affecting the width of + the estimators. Only used if :attr:`version` is ``"masked"``. + Defaults to ``None``. + alpha (int, optional): Expansion factor affecting the width of the + estimators. Only used if :attr:`version` is ``"packed"``. + Defaults to ``None``. gamma (int, optional): Number of groups within each estimator. Only used if :attr:`version` is ``"packed"`` and scales with :attr:`groups`. Defaults to ``1``. @@ -184,10 +183,12 @@ def __init__( variation ratio as the OOD criterion or not. Defaults to ``False``. log_plots (bool, optional): Indicates whether to log the plots or not. Defaults to ``False``. + save_in_csv (bool, optional): Indicates whether to save the results in + a csv file or not. Defaults to ``False``. calibration_set (Callable, optional): Calibration set. Defaults to ``None``. - evaluate_ood (bool, optional): Indicates whether to evaluate the OOD - detection or not. Defaults to ``False``. + evaluate_ood (bool, optional): Indicates whether to evaluate the + OOD detection or not. Defaults to ``False``. pretrained (bool, optional): Indicates whether to use the pretrained weights or not. Only used if :attr:`version` is ``"packed"``. Defaults to ``False``. @@ -262,20 +263,6 @@ def __init__( ) model = self.versions[version][self.archs.index(arch)](**params) - # kwargs.update(params) - # kwargs.update({"version": version, "arch": arch}) - # routine specific parameters - # if version in cls.single: - # return ClassificationSingle( - # model=model, - # loss=loss, - # # optimization_procedure=optimization_procedure, - # format_batch_fn=format_batch_fn, - # use_entropy=use_entropy, - # use_logits=use_logits, - # # **kwargs, - # ) - # # version in cls.ensemble super().__init__( num_classes=num_classes, model=model, @@ -295,55 +282,6 @@ def __init__( use_mi=use_mi, use_variation_ratio=use_variation_ratio, log_plots=log_plots, + save_in_csv=save_in_csv, calibration_set=calibration_set, ) - - self.save_hyperparameters( - ignore=[ - "log_plots", - ] - ) - - # @classmethod - # def load_from_checkpoint( - # cls, - # checkpoint_path: str | Path, - # hparams_file: str | Path, - # **kwargs, - # ) -> LightningModule: # coverage: ignore - # if hparams_file is not None: - # extension = str(hparams_file).split(".")[-1] - # if extension.lower() == "csv": - # hparams = load_hparams_from_tags_csv(hparams_file) - # elif extension.lower() in ("yml", "yaml"): - # hparams = load_hparams_from_yaml(hparams_file) - # else: - # raise ValueError(".csv, .yml or .yaml is required for `hparams_file`") - - # hparams.update(kwargs) - # checkpoint = torch.load(checkpoint_path) - # obj = cls(**hparams) - # obj.load_state_dict(checkpoint["state_dict"]) - # return obj - - # @classmethod - # def add_model_specific_args(cls, parser: ArgumentParser) -> ArgumentParser: - # parser = ClassificationEnsemble.add_model_specific_args(parser) - # parser = add_resnet_specific_args(parser) - # parser = add_packed_specific_args(parser) - # parser = add_masked_specific_args(parser) - # parser = add_mimo_specific_args(parser) - # parser.add_argument( - # "--version", - # type=str, - # choices=cls.versions.keys(), - # default="vanilla", - # help=f"Variation of ResNet. Choose among: {cls.versions.keys()}", - # ) - # parser.add_argument( - # "--pretrained", - # dest="pretrained", - # action=BooleanOptionalAction, - # default=False, - # ) - # return parser diff --git a/torch_uncertainty/baselines/classification/vgg.py b/torch_uncertainty/baselines/classification/vgg.py index 7f6adea7..9637dc64 100644 --- a/torch_uncertainty/baselines/classification/vgg.py +++ b/torch_uncertainty/baselines/classification/vgg.py @@ -1,223 +1,197 @@ -# from argparse import ArgumentParser -# from pathlib import Path -# from typing import Any, Literal - -# import torch -# from pytorch_lightning import LightningModule -# from pytorch_lightning.core.saving import ( -# load_hparams_from_tags_csv, -# load_hparams_from_yaml, -# ) -# from torch import nn - -# from torch_uncertainty.baselines.utils.parser_addons import ( -# add_packed_specific_args, -# add_vgg_specific_args, -# ) -# from torch_uncertainty.models.vgg import ( -# packed_vgg11, -# packed_vgg13, -# packed_vgg16, -# packed_vgg19, -# vgg11, -# vgg13, -# vgg16, -# vgg19, -# ) -# from torch_uncertainty.routines.classification import ( -# ClassificationEnsemble, -# ClassificationSingle, -# ) -# from torch_uncertainty.transforms import RepeatTarget - - -# class VGG: -# single = ["vanilla"] -# ensemble = ["mc-dropout", "packed"] -# versions = { -# "vanilla": [vgg11, vgg13, vgg16, vgg19], -# "mc-dropout": [vgg11, vgg13, vgg16, vgg19], -# "packed": [ -# packed_vgg11, -# packed_vgg13, -# packed_vgg16, -# packed_vgg19, -# ], -# } -# archs = [11, 13, 16, 19] - -# def __new__( -# cls, -# num_classes: int, -# in_channels: int, -# loss: type[nn.Module], -# optimization_procedure: Any, -# version: Literal["vanilla", "mc-dropout", "packed"], -# arch: int, -# num_estimators: int | None = None, -# dropout_rate: float = 0.0, -# style: str = "imagenet", -# groups: int = 1, -# alpha: float | None = None, -# gamma: int = 1, -# use_entropy: bool = False, -# use_logits: bool = False, -# use_mi: bool = False, -# use_variation_ratio: bool = False, -# **kwargs, -# ) -> LightningModule: -# r"""VGG backbone baseline for classification providing support for -# various versions and architectures. - -# Args: -# num_classes (int): Number of classes to predict. -# in_channels (int): Number of input channels. -# loss (nn.Module): Training loss. -# optimization_procedure (Any): Optimization procedure, corresponds to -# what expect the `LightningModule.configure_optimizers() -# `_ -# method. -# version (str): -# Determines which VGG version to use: - -# - ``"vanilla"``: original VGG -# - ``"mc-dropout"``: Monte Carlo Dropout VGG -# - ``"packed"``: Packed-Ensembles VGG - -# arch (int): -# Determines which VGG architecture to use: - -# - ``11``: VGG-11 -# - ``13``: VGG-13 -# - ``16``: VGG-16 -# - ``19``: VGG-19 - -# style (str, optional): Which VGG style to use. Defaults to -# ``imagenet``. -# num_estimators (int, optional): Number of estimators in the ensemble. -# Only used if :attr:`version` is either ``"packed"``, ``"batched"`` -# or ``"masked"`` Defaults to ``None``. -# dropout_rate (float, optional): Dropout rate. Defaults to ``0.0``. -# groups (int, optional): Number of groups in convolutions. Defaults to -# ``1``. -# alpha (float, optional): Expansion factor affecting the width of the -# estimators. Only used if :attr:`version` is ``"packed"``. Defaults -# to ``None``. -# gamma (int, optional): Number of groups within each estimator. Only -# used if :attr:`version` is ``"packed"`` and scales with -# :attr:`groups`. Defaults to ``1s``. -# use_entropy (bool, optional): Indicates whether to use the entropy -# values as the OOD criterion or not. Defaults to ``False``. -# use_logits (bool, optional): Indicates whether to use the logits as the -# OOD criterion or not. Defaults to ``False``. -# use_mi (bool, optional): Indicates whether to use the mutual -# information as the OOD criterion or not. Defaults to ``False``. -# use_variation_ratio (bool, optional): Indicates whether to use the -# variation ratio as the OOD criterion or not. Defaults to ``False``. -# **kwargs: Additional arguments to be passed to the -# Raises: -# ValueError: If :attr:`version` is not either ``"vanilla"``, -# ``"packed"``, ``"batched"`` or ``"masked"``. - -# Returns: -# LightningModule: VGG baseline ready for training and evaluation. -# """ -# params = { -# "in_channels": in_channels, -# "num_classes": num_classes, -# "style": style, -# "groups": groups, -# } - -# if version not in cls.versions: -# raise ValueError(f"Unknown version: {version}") - -# format_batch_fn = nn.Identity() - -# if version == "vanilla": -# params.update( -# { -# "dropout_rate": dropout_rate, -# } -# ) -# elif version == "mc-dropout": -# params.update( -# { -# "dropout_rate": dropout_rate, -# "num_estimators": num_estimators, -# } -# ) -# elif version == "packed": -# params.update( -# { -# "num_estimators": num_estimators, -# "alpha": alpha, -# "style": style, -# "gamma": gamma, -# } -# ) -# format_batch_fn = RepeatTarget(num_repeats=num_estimators) - -# model = cls.versions[version][cls.archs.index(arch)](**params) -# kwargs.update(params) -# # routine specific parameters -# if version in cls.single: -# return ClassificationSingle( -# model=model, -# loss=loss, -# optimization_procedure=optimization_procedure, -# format_batch_fn=format_batch_fn, -# use_entropy=use_entropy, -# use_logits=use_logits, -# **kwargs, -# ) -# # version in cls.ensemble -# return ClassificationEnsemble( -# model=model, -# loss=loss, -# optimization_procedure=optimization_procedure, -# format_batch_fn=format_batch_fn, -# use_entropy=use_entropy, -# use_logits=use_logits, -# use_mi=use_mi, -# use_variation_ratio=use_variation_ratio, -# **kwargs, -# ) - -# @classmethod -# def load_from_checkpoint( -# cls, -# checkpoint_path: str | Path, -# hparams_file: str | Path, -# **kwargs, -# ) -> LightningModule: # coverage: ignore -# if hparams_file is not None: -# extension = str(hparams_file).split(".")[-1] -# if extension.lower() == "csv": -# hparams = load_hparams_from_tags_csv(hparams_file) -# elif extension.lower() in ("yml", "yaml"): -# hparams = load_hparams_from_yaml(hparams_file) -# else: -# raise ValueError( -# ".csv, .yml or .yaml is required for `hparams_file`" -# ) - -# hparams.update(kwargs) -# checkpoint = torch.load(checkpoint_path) -# obj = cls(**hparams) -# obj.load_state_dict(checkpoint["state_dict"]) -# return obj - -# @classmethod -# def add_model_specific_args(cls, parser: ArgumentParser) -> ArgumentParser: -# parser = ClassificationEnsemble.add_model_specific_args(parser) -# parser = add_vgg_specific_args(parser) -# parser = add_packed_specific_args(parser) -# parser.add_argument( -# "--version", -# type=str, -# choices=cls.versions.keys(), -# default="vanilla", -# help=f"Variation of VGG. Choose among: {cls.versions.keys()}", -# ) -# return parser +from typing import Literal + +from torch import nn +from torch.nn.modules import Module + +from torch_uncertainty.models.vgg import ( + packed_vgg11, + packed_vgg13, + packed_vgg16, + packed_vgg19, + vgg11, + vgg13, + vgg16, + vgg19, +) +from torch_uncertainty.routines.classification import ClassificationRoutine +from torch_uncertainty.transforms import RepeatTarget + + +class VGG(ClassificationRoutine): + single = ["vanilla"] + ensemble = ["mc-dropout", "packed"] + versions = { + "vanilla": [vgg11, vgg13, vgg16, vgg19], + "mc-dropout": [vgg11, vgg13, vgg16, vgg19], + "packed": [ + packed_vgg11, + packed_vgg13, + packed_vgg16, + packed_vgg19, + ], + } + archs = [11, 13, 16, 19] + + def __init__( + self, + num_classes: int, + in_channels: int, + loss: type[Module], + version: Literal["vanilla", "mc-dropout", "packed"], + arch: int, + style: str = "imagenet", + num_estimators: int = 1, + dropout_rate: float = 0.0, + last_layer_dropout: bool = False, + mixtype: str = "erm", + mixmode: str = "elem", + dist_sim: str = "emb", + kernel_tau_max: float = 1, + kernel_tau_std: float = 0.5, + mixup_alpha: float = 0, + cutmix_alpha: float = 0, + groups: int = 1, + alpha: int | None = None, + gamma: int = 1, + use_entropy: bool = False, + use_logits: bool = False, + use_mi: bool = False, + use_variation_ratio: bool = False, + log_plots: bool = False, + save_in_csv: bool = False, + calibration_set: Literal["val", "test"] | None = None, + evaluate_ood: bool = False, + ) -> None: + r"""VGG backbone baseline for classification providing support for + various versions and architectures. + + Args: + num_classes (int): Number of classes to predict. + in_channels (int): Number of input channels. + loss (nn.Module): Training loss. + version (str): + Determines which VGG version to use: + + - ``"vanilla"``: original VGG + - ``"mc-dropout"``: Monte Carlo Dropout VGG + - ``"packed"``: Packed-Ensembles VGG + + arch (int): + Determines which VGG architecture to use: + + - ``11``: VGG-11 + - ``13``: VGG-13 + - ``16``: VGG-16 + - ``19``: VGG-19 + + style (str, optional): Which VGG style to use. Defaults to + ``imagenet``. + num_estimators (int, optional): Number of estimators in the ensemble. + Only used if :attr:`version` is either ``"packed"``, ``"batched"`` + or ``"masked"`` Defaults to ``None``. + dropout_rate (float, optional): Dropout rate. Defaults to ``0.0``. + last_layer_dropout (bool, optional): Indicates whether to apply dropout + to the last layer or not. Defaults to ``False``. + mixtype (str, optional): Mixup type. Defaults to ``"erm"``. + mixmode (str, optional): Mixup mode. Defaults to ``"elem"``. + dist_sim (str, optional): Distance similarity. Defaults to ``"emb"``. + kernel_tau_max (float, optional): Maximum value for the kernel tau. + Defaults to ``1.0``. + kernel_tau_std (float, optional): Standard deviation for the kernel + tau. Defaults to ``0.5``. + mixup_alpha (float, optional): Alpha parameter for Mixup. Defaults + to ``0``. + cutmix_alpha (float, optional): Alpha parameter for CutMix. + Defaults to ``0``. + groups (int, optional): Number of groups in convolutions. Defaults to + ``1``. + alpha (float, optional): Expansion factor affecting the width of the + estimators. Only used if :attr:`version` is ``"packed"``. Defaults + to ``None``. + gamma (int, optional): Number of groups within each estimator. Only + used if :attr:`version` is ``"packed"`` and scales with + :attr:`groups`. Defaults to ``1s``. + use_entropy (bool, optional): Indicates whether to use the entropy + values as the OOD criterion or not. Defaults to ``False``. + use_logits (bool, optional): Indicates whether to use the logits as the + OOD criterion or not. Defaults to ``False``. + use_mi (bool, optional): Indicates whether to use the mutual + information as the OOD criterion or not. Defaults to ``False``. + use_variation_ratio (bool, optional): Indicates whether to use the + variation ratio as the OOD criterion or not. Defaults to ``False``. + log_plots (bool, optional): Indicates whether to log the plots or not. + Defaults to ``False``. + save_in_csv (bool, optional): Indicates whether to save the results in + a csv file or not. Defaults to ``False``. + calibration_set (Callable, optional): Calibration set. Defaults to + ``None``. + evaluate_ood (bool, optional): Indicates whether to evaluate the + OOD detection or not. Defaults to ``False``. + + Raises: + ValueError: If :attr:`version` is not either ``"vanilla"``, + ``"packed"``, ``"batched"`` or ``"masked"``. + + Returns: + LightningModule: VGG baseline ready for training and evaluation. + """ + params = { + "in_channels": in_channels, + "num_classes": num_classes, + "style": style, + "groups": groups, + } + + if version not in self.versions: + raise ValueError(f"Unknown version: {version}") + + format_batch_fn = nn.Identity() + + if version == "vanilla": + params.update( + { + "dropout_rate": dropout_rate, + } + ) + elif version == "mc-dropout": + params.update( + { + "dropout_rate": dropout_rate, + "num_estimators": num_estimators, + "last_layer_dropout": last_layer_dropout, + } + ) + elif version == "packed": + params.update( + { + "num_estimators": num_estimators, + "alpha": alpha, + "style": style, + "gamma": gamma, + } + ) + format_batch_fn = RepeatTarget(num_repeats=num_estimators) + + model = self.versions[version][self.archs.index(arch)](**params) + super().__init__( + num_classes=num_classes, + model=model, + loss=loss, + num_estimators=num_estimators, + format_batch_fn=format_batch_fn, + mixtype=mixtype, + mixmode=mixmode, + dist_sim=dist_sim, + kernel_tau_max=kernel_tau_max, + kernel_tau_std=kernel_tau_std, + mixup_alpha=mixup_alpha, + cutmix_alpha=cutmix_alpha, + evaluate_ood=evaluate_ood, + use_entropy=use_entropy, + use_logits=use_logits, + use_mi=use_mi, + use_variation_ratio=use_variation_ratio, + log_plots=log_plots, + save_in_csv=save_in_csv, + calibration_set=calibration_set, + ) diff --git a/torch_uncertainty/baselines/classification/wideresnet.py b/torch_uncertainty/baselines/classification/wideresnet.py index 37ad16f4..4d7f0ab9 100644 --- a/torch_uncertainty/baselines/classification/wideresnet.py +++ b/torch_uncertainty/baselines/classification/wideresnet.py @@ -1,267 +1,227 @@ -# from argparse import ArgumentParser, BooleanOptionalAction -# from pathlib import Path -# from typing import Any, Literal - -# import torch -# from pytorch_lightning import LightningModule -# from pytorch_lightning.core.saving import ( -# load_hparams_from_tags_csv, -# load_hparams_from_yaml, -# ) -# from torch import nn - -# from torch_uncertainty.baselines.utils.parser_addons import ( -# add_masked_specific_args, -# add_mimo_specific_args, -# add_packed_specific_args, -# add_wideresnet_specific_args, -# ) -# from torch_uncertainty.models.wideresnet import ( -# batched_wideresnet28x10, -# masked_wideresnet28x10, -# mimo_wideresnet28x10, -# packed_wideresnet28x10, -# wideresnet28x10, -# ) -# from torch_uncertainty.routines.classification import ( -# ClassificationEnsemble, -# ClassificationSingle, -# ) -# from torch_uncertainty.transforms import MIMOBatchFormat, RepeatTarget - - -# class WideResNet: -# single = ["vanilla"] -# ensemble = ["packed", "batched", "masked", "mimo", "mc-dropout"] -# versions = { -# "vanilla": [wideresnet28x10], -# "mc-dropout": [wideresnet28x10], -# "packed": [packed_wideresnet28x10], -# "batched": [batched_wideresnet28x10], -# "masked": [masked_wideresnet28x10], -# "mimo": [mimo_wideresnet28x10], -# } - -# def __new__( -# cls, -# num_classes: int, -# in_channels: int, -# loss: type[nn.Module], -# optimization_procedure: Any, -# version: Literal[ -# "vanilla", "mc-dropout", "packed", "batched", "masked", "mimo" -# ], -# style: str = "imagenet", -# num_estimators: int | None = None, -# dropout_rate: float = 0.0, -# groups: int | None = None, -# scale: float | None = None, -# alpha: int | None = None, -# gamma: int | None = None, -# rho: float = 1.0, -# batch_repeat: int = 1, -# use_entropy: bool = False, -# use_logits: bool = False, -# use_mi: bool = False, -# use_variation_ratio: bool = False, -# # pretrained: bool = False, -# **kwargs, -# ) -> LightningModule: -# r"""Wide-ResNet28x10 backbone baseline for classification providing support -# for various versions. - -# Args: -# num_classes (int): Number of classes to predict. -# in_channels (int): Number of input channels. -# loss (nn.Module): Training loss. -# optimization_procedure (Any): Optimization procedure, corresponds to -# what expect the `LightningModule.configure_optimizers() -# `_ -# method. -# version (str): -# Determines which Wide-ResNet version to use: - -# - ``"vanilla"``: original Wide-ResNet -# - ``"mc-dropout"``: Monte Carlo Dropout Wide-ResNet -# - ``"packed"``: Packed-Ensembles Wide-ResNet -# - ``"batched"``: BatchEnsemble Wide-ResNet -# - ``"masked"``: Masksemble Wide-ResNet -# - ``"mimo"``: MIMO Wide-ResNet - -# style (bool, optional): (str, optional): Which ResNet style to use. -# Defaults to ``imagenet``. -# num_estimators (int, optional): Number of estimators in the ensemble. -# Only used if :attr:`version` is either ``"packed"``, ``"batched"`` -# or ``"masked"`` Defaults to ``None``. -# dropout_rate (float, optional): Dropout rate. Defaults to ``0.0``. -# groups (int, optional): Number of groups in convolutions. Defaults to -# ``1``. -# scale (float, optional): Expansion factor affecting the width of the -# estimators. Only used if :attr:`version` is ``"masked"``. Defaults -# to ``None``. -# alpha (float, optional): Expansion factor affecting the width of the -# estimators. Only used if :attr:`version` is ``"packed"``. Defaults -# to ``None``. -# gamma (int, optional): Number of groups within each estimator. Only -# used if :attr:`version` is ``"packed"`` and scales with -# :attr:`groups`. Defaults to ``1s``. -# rho (float, optional): Probability that all estimators share the same -# input. Only used if :attr:`version` is ``"mimo"``. Defaults to -# ``1``. -# batch_repeat (int, optional): Number of times to repeat the batch. Only -# used if :attr:`version` is ``"mimo"``. Defaults to ``1``. -# use_entropy (bool, optional): Indicates whether to use the entropy -# values as the OOD criterion or not. Defaults to ``False``. -# use_logits (bool, optional): Indicates whether to use the logits as the -# OOD criterion or not. Defaults to ``False``. -# use_mi (bool, optional): Indicates whether to use the mutual -# information as the OOD criterion or not. Defaults to ``False``. -# use_variation_ratio (bool, optional): Indicates whether to use the -# variation ratio as the OOD criterion or not. Defaults to ``False``. -# pretrained (bool, optional): Indicates whether to use the pretrained -# weights or not. Only used if :attr:`version` is ``"packed"``. -# Defaults to ``False``. -# **kwargs: Additional arguments. - -# Raises: -# ValueError: If :attr:`version` is not either ``"vanilla"``, -# ``"packed"``, ``"batched"`` or ``"masked"``. - -# Returns: -# LightningModule: Wide-ResNet baseline ready for training and -# evaluation. -# """ -# params = { -# "in_channels": in_channels, -# "num_classes": num_classes, -# "style": style, -# "groups": groups, -# } - -# format_batch_fn = nn.Identity() - -# if version not in cls.versions: -# raise ValueError(f"Unknown version: {version}") - -# # version specific params -# if version == "vanilla": -# params.update( -# { -# "dropout_rate": dropout_rate, -# } -# ) -# elif version == "mc-dropout": -# params.update( -# { -# "dropout_rate": dropout_rate, -# "num_estimators": num_estimators, -# } -# ) -# elif version == "packed": -# params.update( -# { -# "num_estimators": num_estimators, -# "alpha": alpha, -# "gamma": gamma, -# } -# ) -# format_batch_fn = RepeatTarget(num_repeats=num_estimators) -# elif version == "batched": -# params.update( -# { -# "num_estimators": num_estimators, -# } -# ) -# format_batch_fn = RepeatTarget(num_repeats=num_estimators) -# elif version == "masked": -# params.update( -# { -# "num_estimators": num_estimators, -# "scale": scale, -# } -# ) -# format_batch_fn = RepeatTarget(num_repeats=num_estimators) -# elif version == "mimo": -# params.update( -# { -# "num_estimators": num_estimators, -# } -# ) -# format_batch_fn = MIMOBatchFormat( -# num_estimators=num_estimators, -# rho=rho, -# batch_repeat=batch_repeat, -# ) - -# model = cls.versions[version][0](**params) -# kwargs.update(params) -# # routine specific parameters -# if version in cls.single: -# return ClassificationSingle( -# model=model, -# loss=loss, -# optimization_procedure=optimization_procedure, -# format_batch_fn=format_batch_fn, -# use_entropy=use_entropy, -# use_logits=use_logits, -# **kwargs, -# ) -# # version in cls.ensemble -# return ClassificationEnsemble( -# model=model, -# loss=loss, -# optimization_procedure=optimization_procedure, -# format_batch_fn=format_batch_fn, -# use_entropy=use_entropy, -# use_logits=use_logits, -# use_mi=use_mi, -# use_variation_ratio=use_variation_ratio, -# **kwargs, -# ) - -# @classmethod -# def load_from_checkpoint( -# cls, -# checkpoint_path: str | Path, -# hparams_file: str | Path, -# **kwargs, -# ) -> LightningModule: # coverage: ignore -# if hparams_file is not None: -# extension = str(hparams_file).split(".")[-1] -# if extension.lower() == "csv": -# hparams = load_hparams_from_tags_csv(hparams_file) -# elif extension.lower() in ("yml", "yaml"): -# hparams = load_hparams_from_yaml(hparams_file) -# else: -# raise ValueError( -# ".csv, .yml or .yaml is required for `hparams_file`" -# ) - -# hparams.update(kwargs) -# checkpoint = torch.load(checkpoint_path) -# obj = cls(**hparams) -# obj.load_state_dict(checkpoint["state_dict"]) -# return obj - -# @classmethod -# def add_model_specific_args(cls, parser: ArgumentParser) -> ArgumentParser: -# parser = ClassificationEnsemble.add_model_specific_args(parser) -# parser = add_wideresnet_specific_args(parser) -# parser = add_packed_specific_args(parser) -# parser = add_masked_specific_args(parser) -# parser = add_mimo_specific_args(parser) -# parser.add_argument( -# "--version", -# type=str, -# choices=cls.versions.keys(), -# default="vanilla", -# help=f"Variation of WideResNet. Choose among: {cls.versions.keys()}", -# ) -# parser.add_argument( -# "--pretrained", -# dest="pretrained", -# action=BooleanOptionalAction, -# default=False, -# ) - -# return parser +from typing import Literal + +from torch import nn + +from torch_uncertainty.models.wideresnet import ( + batched_wideresnet28x10, + masked_wideresnet28x10, + mimo_wideresnet28x10, + packed_wideresnet28x10, + wideresnet28x10, +) +from torch_uncertainty.routines.classification import ( + ClassificationRoutine, +) +from torch_uncertainty.transforms import MIMOBatchFormat, RepeatTarget + + +class WideResNet(ClassificationRoutine): + single = ["vanilla"] + ensemble = ["packed", "batched", "masked", "mimo", "mc-dropout"] + versions = { + "vanilla": [wideresnet28x10], + "mc-dropout": [wideresnet28x10], + "packed": [packed_wideresnet28x10], + "batched": [batched_wideresnet28x10], + "masked": [masked_wideresnet28x10], + "mimo": [mimo_wideresnet28x10], + } + + def __init__( + self, + num_classes: int, + in_channels: int, + loss: type[nn.Module], + version: Literal[ + "vanilla", "mc-dropout", "packed", "batched", "masked", "mimo" + ], + style: str = "imagenet", + num_estimators: int = 1, + dropout_rate: float = 0.0, + mixtype: str = "erm", + mixmode: str = "elem", + dist_sim: str = "emb", + kernel_tau_max: float = 1.0, + kernel_tau_std: float = 0.5, + mixup_alpha: float = 0, + cutmix_alpha: float = 0, + groups: int = 1, + scale: float | None = None, + alpha: int | None = None, + gamma: int = 1, + rho: float = 1.0, + batch_repeat: int = 1, + use_entropy: bool = False, + use_logits: bool = False, + use_mi: bool = False, + use_variation_ratio: bool = False, + log_plots: bool = False, + save_in_csv: bool = False, + calibration_set: Literal["val", "test"] | None = None, + evaluate_ood: bool = False, + # pretrained: bool = False, + ) -> None: + r"""Wide-ResNet28x10 backbone baseline for classification providing support + for various versions. + + Args: + num_classes (int): Number of classes to predict. + in_channels (int): Number of input channels. + loss (nn.Module): Training loss. + optimization_procedure (Any): Optimization procedure, corresponds to + what expect the `LightningModule.configure_optimizers() + `_ + method. + version (str): + Determines which Wide-ResNet version to use: + + - ``"vanilla"``: original Wide-ResNet + - ``"mc-dropout"``: Monte Carlo Dropout Wide-ResNet + - ``"packed"``: Packed-Ensembles Wide-ResNet + - ``"batched"``: BatchEnsemble Wide-ResNet + - ``"masked"``: Masksemble Wide-ResNet + - ``"mimo"``: MIMO Wide-ResNet + + style (bool, optional): (str, optional): Which ResNet style to use. + Defaults to ``imagenet``. + num_estimators (int, optional): Number of estimators in the ensemble. + Only used if :attr:`version` is either ``"packed"``, ``"batched"`` + or ``"masked"`` Defaults to ``None``. + dropout_rate (float, optional): Dropout rate. Defaults to ``0.0``. + mixtype (str, optional): Mixup type. Defaults to ``"erm"``. + mixmode (str, optional): Mixup mode. Defaults to ``"elem"``. + dist_sim (str, optional): Distance similarity. Defaults to ``"emb"``. + kernel_tau_max (float, optional): Maximum value for the kernel tau. + Defaults to ``1.0``. + kernel_tau_std (float, optional): Standard deviation for the kernel + tau. Defaults to ``0.5``. + mixup_alpha (float, optional): Alpha parameter for Mixup. Defaults + to ``0``. + cutmix_alpha (float, optional): Alpha parameter for CutMix. + Defaults to ``0``. + groups (int, optional): Number of groups in convolutions. Defaults to + ``1``. + scale (float, optional): Expansion factor affecting the width of the + estimators. Only used if :attr:`version` is ``"masked"``. Defaults + to ``None``. + alpha (float, optional): Expansion factor affecting the width of the + estimators. Only used if :attr:`version` is ``"packed"``. Defaults + to ``None``. + gamma (int, optional): Number of groups within each estimator. Only + used if :attr:`version` is ``"packed"`` and scales with + :attr:`groups`. Defaults to ``1s``. + rho (float, optional): Probability that all estimators share the same + input. Only used if :attr:`version` is ``"mimo"``. Defaults to + ``1``. + batch_repeat (int, optional): Number of times to repeat the batch. Only + used if :attr:`version` is ``"mimo"``. Defaults to ``1``. + use_entropy (bool, optional): Indicates whether to use the entropy + values as the OOD criterion or not. Defaults to ``False``. + use_logits (bool, optional): Indicates whether to use the logits as the + OOD criterion or not. Defaults to ``False``. + use_mi (bool, optional): Indicates whether to use the mutual + information as the OOD criterion or not. Defaults to ``False``. + use_variation_ratio (bool, optional): Indicates whether to use the + variation ratio as the OOD criterion or not. Defaults to ``False``. + log_plots (bool, optional): Indicates whether to log the plots or not. + Defaults to ``False``. + save_in_csv (bool, optional): Indicates whether to save the results in + a csv file or not. Defaults to ``False``. + calibration_set (Callable, optional): Calibration set. Defaults to + ``None``. + evaluate_ood (bool, optional): Indicates whether to evaluate the + OOD detection or not. Defaults to ``False``. + + Raises: + ValueError: If :attr:`version` is not either ``"vanilla"``, + ``"packed"``, ``"batched"`` or ``"masked"``. + + Returns: + LightningModule: Wide-ResNet baseline ready for training and + evaluation. + """ + params = { + "in_channels": in_channels, + "num_classes": num_classes, + "style": style, + "groups": groups, + } + + format_batch_fn = nn.Identity() + + if version not in self.versions: + raise ValueError(f"Unknown version: {version}") + + # version specific params + if version == "vanilla": + params.update( + { + "dropout_rate": dropout_rate, + } + ) + elif version == "mc-dropout": + params.update( + { + "dropout_rate": dropout_rate, + "num_estimators": num_estimators, + } + ) + elif version == "packed": + params.update( + { + "num_estimators": num_estimators, + "alpha": alpha, + "gamma": gamma, + } + ) + format_batch_fn = RepeatTarget(num_repeats=num_estimators) + elif version == "batched": + params.update( + { + "num_estimators": num_estimators, + } + ) + format_batch_fn = RepeatTarget(num_repeats=num_estimators) + elif version == "masked": + params.update( + { + "num_estimators": num_estimators, + "scale": scale, + } + ) + format_batch_fn = RepeatTarget(num_repeats=num_estimators) + elif version == "mimo": + params.update( + { + "num_estimators": num_estimators, + } + ) + format_batch_fn = MIMOBatchFormat( + num_estimators=num_estimators, + rho=rho, + batch_repeat=batch_repeat, + ) + + model = self.versions[version][0](**params) + super().__init__( + num_classes=num_classes, + model=model, + loss=loss, + num_estimators=num_estimators, + format_batch_fn=format_batch_fn, + mixtype=mixtype, + mixmode=mixmode, + dist_sim=dist_sim, + kernel_tau_max=kernel_tau_max, + kernel_tau_std=kernel_tau_std, + mixup_alpha=mixup_alpha, + cutmix_alpha=cutmix_alpha, + evaluate_ood=evaluate_ood, + use_entropy=use_entropy, + use_logits=use_logits, + use_mi=use_mi, + use_variation_ratio=use_variation_ratio, + log_plots=log_plots, + save_in_csv=save_in_csv, + calibration_set=calibration_set, + ) diff --git a/torch_uncertainty/baselines/deep_ensembles.py b/torch_uncertainty/baselines/deep_ensembles.py deleted file mode 100644 index a7ddbaa1..00000000 --- a/torch_uncertainty/baselines/deep_ensembles.py +++ /dev/null @@ -1,110 +0,0 @@ -from argparse import ArgumentParser -from pathlib import Path -from typing import Literal - -from pytorch_lightning import LightningModule - -from torch_uncertainty.models import deep_ensembles -from torch_uncertainty.routines.classification import ClassificationEnsemble -from torch_uncertainty.routines.regression import RegressionEnsemble -from torch_uncertainty.utils import get_version - -from .classification import VGG, ResNet, WideResNet -from .regression import MLP - - -class DeepEnsembles: - backbones = { - "mlp": MLP, - "resnet": ResNet, - "vgg": VGG, - "wideresnet": WideResNet, - } - - def __new__( - cls, - task: Literal["classification", "regression"], - log_path: str | Path, - checkpoint_ids: list[int], - backbone: Literal["mlp", "resnet", "vgg", "wideresnet"], - # num_estimators: int, - in_channels: int | None = None, - num_classes: int | None = None, - use_entropy: bool = False, - use_logits: bool = False, - use_mi: bool = False, - use_variation_ratio: bool = False, - **kwargs, - ) -> LightningModule: - if isinstance(log_path, str): - log_path = Path(log_path) - - backbone_cls = cls.backbones[backbone] - - models = [] - for version in checkpoint_ids: # coverage: ignore - ckpt_file, hparams_file = get_version( - root=log_path, version=version - ) - trained_model = backbone_cls.load_from_checkpoint( - checkpoint_path=ckpt_file, - hparams_file=hparams_file, - loss=None, - optimization_procedure=None, - ).eval() - models.append(trained_model.model) - - de = deep_ensembles(models=models) - - if task == "classification": - return ClassificationEnsemble( - in_channels=in_channels, - num_classes=num_classes, - model=de, - loss=None, - optimization_procedure=None, - num_estimators=de.num_estimators, - use_entropy=use_entropy, - use_logits=use_logits, - use_mi=use_mi, - use_variation_ratio=use_variation_ratio, - ) - # task == "regression": - return RegressionEnsemble( - model=de, - loss=None, - optimization_procedure=None, - dist_estimation=2, - num_estimators=de.num_estimators, - mode="mean", - ) - - @classmethod - def add_model_specific_args(cls, parser: ArgumentParser) -> ArgumentParser: - parser = ClassificationEnsemble.add_model_specific_args(parser) - parser.add_argument( - "--task", - type=str, - choices=["classification", "regression"], - help="Task to be performed", - ) - parser.add_argument( - "--backbone", - type=str, - choices=cls.backbones.keys(), - help="Backbone architecture", - required=True, - ) - parser.add_argument( - "--versions", - type=int, - nargs="+", - help="Versions of the model to be ensembled", - ) - parser.add_argument( - "--log_path", - type=str, - help="Root directory of the models", - required=True, - ) - return parser diff --git a/torch_uncertainty/datamodules/cifar10.py b/torch_uncertainty/datamodules/cifar10.py index 2cf268c8..755a84ef 100644 --- a/torch_uncertainty/datamodules/cifar10.py +++ b/torch_uncertainty/datamodules/cifar10.py @@ -1,6 +1,5 @@ -from argparse import ArgumentParser from pathlib import Path -from typing import Any, Literal +from typing import Literal import numpy as np import torchvision.transforms as T @@ -26,8 +25,8 @@ class CIFAR10DataModule(AbstractDataModule): def __init__( self, root: str | Path, - evaluate_ood: bool, batch_size: int, + evaluate_ood: bool = False, val_split: float = 0.0, num_workers: int = 1, cutout: int | None = None, @@ -222,21 +221,3 @@ def _get_train_targets(self) -> ArrayLike: if self.val_split: return np.array(self.train.dataset.targets)[self.train.indices] return np.array(self.train.targets) - - @classmethod - def add_argparse_args( - cls, - parent_parser: ArgumentParser, - **kwargs: Any, - ) -> ArgumentParser: - p = super().add_argparse_args(parent_parser) - - # Arguments for CIFAR10 - p.add_argument("--cutout", type=int, default=0) - p.add_argument("--auto_augment", type=str) - p.add_argument("--test_alt", choices=["c", "h"], default=None) - p.add_argument( - "--severity", dest="corruption_severity", type=int, default=None - ) - p.add_argument("--evaluate_ood", action="store_true") - return parent_parser diff --git a/torch_uncertainty/datamodules/cifar100.py b/torch_uncertainty/datamodules/cifar100.py index 631b671b..080a16ab 100644 --- a/torch_uncertainty/datamodules/cifar100.py +++ b/torch_uncertainty/datamodules/cifar100.py @@ -1,6 +1,5 @@ -from argparse import ArgumentParser from pathlib import Path -from typing import Any, Literal +from typing import Literal import numpy as np import torch @@ -27,8 +26,8 @@ class CIFAR100DataModule(AbstractDataModule): def __init__( self, root: str | Path, - evaluate_ood: bool, batch_size: int, + evaluate_ood: bool = False, val_split: float = 0.0, num_workers: int = 1, cutout: int | None = None, @@ -224,22 +223,3 @@ def _get_train_targets(self) -> ArrayLike: if self.val_split: return np.array(self.train.dataset.targets)[self.train.indices] return np.array(self.train.targets) - - @classmethod - def add_argparse_args( - cls, - parent_parser: ArgumentParser, - **kwargs: Any, - ) -> ArgumentParser: - p = super().add_argparse_args(parent_parser) - - # Arguments for CIFAR100 - p.add_argument("--cutout", type=int, default=0) - p.add_argument("--randaugment", dest="randaugment", action="store_true") - p.add_argument("--auto_augment", type=str) - p.add_argument("--test_alt", choices=["c"], default=None) - p.add_argument( - "--severity", dest="corruption_severity", type=int, default=1 - ) - p.add_argument("--evaluate_ood", action="store_true") - return parent_parser diff --git a/torch_uncertainty/lightning_cli.py b/torch_uncertainty/lightning_cli.py deleted file mode 100644 index 242d54f5..00000000 --- a/torch_uncertainty/lightning_cli.py +++ /dev/null @@ -1,47 +0,0 @@ -from pathlib import Path - -from lightning.fabric.utilities.cloud_io import get_filesystem -from lightning.pytorch import LightningModule, Trainer -from lightning.pytorch.cli import SaveConfigCallback -from typing_extensions import override - - -class MySaveConfigCallback(SaveConfigCallback): - @override - def setup( - self, trainer: Trainer, pl_module: LightningModule, stage: str - ) -> None: - if self.already_saved: - return - - if self.save_to_log_dir and stage == "fit": - log_dir = trainer.log_dir # this broadcasts the directory - assert log_dir is not None - config_path = Path(log_dir) / self.config_filename - fs = get_filesystem(log_dir) - - if not self.overwrite: - # check if the file exists on rank 0 - file_exists = ( - fs.isfile(config_path) if trainer.is_global_zero else False - ) - # broadcast whether to fail to all ranks - file_exists = trainer.strategy.broadcast(file_exists) - if file_exists: - # TODO: complete error description - raise RuntimeError("TODO") - - if trainer.is_global_zero: - fs.makedirs(log_dir, exist_ok=True) - self.parser.save( - self.config, - config_path, - skip_none=False, - overwrite=self.overwrite, - multifile=self.multifile, - ) - if trainer.is_global_zero: - self.save_config(trainer, pl_module, stage) - self.already_saved = True - - self.already_saved = trainer.strategy.broadcast(self.already_saved) diff --git a/torch_uncertainty/metrics/calibration.py b/torch_uncertainty/metrics/calibration.py index 55fc443c..c32787f4 100644 --- a/torch_uncertainty/metrics/calibration.py +++ b/torch_uncertainty/metrics/calibration.py @@ -76,8 +76,8 @@ class MulticlassCE(MulticlassCalibrationError): # noqa: N818 def plot(self, ax: _AX_TYPE | None = None) -> _PLOT_OUT_TYPE: fig, ax = plt.subplots() if ax is None else (None, ax) - conf = dim_zero_cat(self.confidences) - acc = dim_zero_cat(self.accuracies) + conf = dim_zero_cat(self.confidences).cpu() + acc = dim_zero_cat(self.accuracies).cpu() bin_width = 1 / self.n_bins @@ -98,9 +98,9 @@ def plot(self, ax: _AX_TYPE | None = None) -> _PLOT_OUT_TYPE: acc.unsqueeze(1) * torch.nn.functional.one_hot(inverse).float(), 0, ) - / (val_oh.T @ counts + 1e-6).float() + / (val_oh.T.float() @ counts.float() + 1e-6) ) - counts_all = (val_oh.T @ counts).float() + counts_all = val_oh.T.float() @ counts.float() total = torch.sum(counts) plt.rc("axes", axisbelow=True) diff --git a/torch_uncertainty/models/wideresnet/mimo.py b/torch_uncertainty/models/wideresnet/mimo.py index 5906e885..e9881472 100644 --- a/torch_uncertainty/models/wideresnet/mimo.py +++ b/torch_uncertainty/models/wideresnet/mimo.py @@ -36,7 +36,6 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: if not self.training: x = x.repeat(self.num_estimators, 1, 1, 1) - out = rearrange(x, "(m b) c h w -> b (m c) h w", m=self.num_estimators) out = super().forward(out) return rearrange(out, "b (m d) -> (m b) d", m=self.num_estimators) diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 8f72e5d7..bb68f3b7 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -1,20 +1,14 @@ -# from argparse import ArgumentParser, Namespace from collections.abc import Callable from functools import partial +from pathlib import Path +from typing import Literal import torch import torch.nn.functional as F - -# import pytorch_lightning as pl +from einops import rearrange from lightning.pytorch import LightningModule +from lightning.pytorch.loggers import TensorBoardLogger from lightning.pytorch.utilities.types import STEP_OUTPUT - -# from einops import rearrange -from pytorch_lightning.loggers import TensorBoardLogger - -# from pytorch_lightning.utilities.memory import get_model_size_mb -# from lightning.pytorch.utilities import get_model_size_mb -# from pytorch_lightning.utilities.types import STEP_OUTPUT from timm.data import Mixup as timm_Mixup from torch import Tensor, nn from torchmetrics import Accuracy, MetricCollection @@ -34,780 +28,10 @@ NegativeLogLikelihood, VariationRatio, ) - -# from torch_uncertainty.plotting_utils import plot_hist +from torch_uncertainty.plotting_utils import plot_hist from torch_uncertainty.post_processing import TemperatureScaler from torch_uncertainty.transforms import Mixup, MixupIO, RegMixup, WarpingMixup - -# class ClassificationSingle(pl.LightningModule): -# def __init__( -# self, -# num_classes: int, -# model: nn.Module, -# loss: type[nn.Module], -# optimization_procedure: Any, -# format_batch_fn: nn.Module | None = None, -# mixtype: str = "erm", -# mixmode: str = "elem", -# dist_sim: str = "emb", -# kernel_tau_max: float = 1.0, -# kernel_tau_std: float = 0.5, -# mixup_alpha: float = 0, -# cutmix_alpha: float = 0, -# evaluate_ood: bool = False, -# use_entropy: bool = False, -# use_logits: bool = False, -# log_plots: bool = False, -# calibration_set: Callable | None = None, -# **kwargs, -# ) -> None: -# """Classification routine for single models. - -# Args: -# num_classes (int): Number of classes. -# model (nn.Module): Model to train. -# loss (type[nn.Module]): Loss function. -# optimization_procedure (Any): Optimization procedure. -# format_batch_fn (nn.Module, optional): Function to format the batch. -# Defaults to :class:`torch.nn.Identity()`. -# mixtype (str, optional): Mixup type. Defaults to ``"erm"``. -# mixmode (str, optional): Mixup mode. Defaults to ``"elem"``. -# dist_sim (str, optional): Distance similarity. Defaults to ``"emb"``. -# kernel_tau_max (float, optional): Maximum value for the kernel tau. -# Defaults to 1.0. -# kernel_tau_std (float, optional): Standard deviation for the kernel tau. -# Defaults to 0.5. -# mixup_alpha (float, optional): Alpha parameter for Mixup. Defaults to 0. -# cutmix_alpha (float, optional): Alpha parameter for Cutmix. -# Defaults to 0. -# evaluate_ood (bool, optional): Indicates whether to evaluate the OOD -# detection performance or not. Defaults to ``False``. -# use_entropy (bool, optional): Indicates whether to use the entropy -# values as the OOD criterion or not. Defaults to ``False``. -# use_logits (bool, optional): Indicates whether to use the logits as the -# OOD criterion or not. Defaults to ``False``. -# log_plots (bool, optional): Indicates whether to log plots from -# metrics. Defaults to ``False``. -# calibration_set (Callable, optional): Function to get the calibration -# set. Defaults to ``None``. -# kwargs (Any): Additional arguments. - -# Note: -# The default OOD criterion is the softmax confidence score. - -# Warning: -# Make sure at most only one of :attr:`use_entropy` and :attr:`use_logits` -# attributes is set to ``True``. Otherwise a :class:`ValueError()` will -# be raised. -# """ -# super().__init__() - -# if format_batch_fn is None: -# format_batch_fn = nn.Identity() - -# self.save_hyperparameters( -# ignore=[ -# "model", -# "loss", -# "optimization_procedure", -# "format_batch_fn", -# "calibration_set", -# ] -# ) - -# if (use_logits + use_entropy) > 1: -# raise ValueError("You cannot choose more than one OOD criterion.") - -# self.num_classes = num_classes -# self.evaluate_ood = evaluate_ood -# self.use_logits = use_logits -# self.use_entropy = use_entropy -# self.log_plots = log_plots - -# self.calibration_set = calibration_set - -# self.binary_cls = num_classes == 1 - -# self.model = model -# self.loss = loss -# self.optimization_procedure = optimization_procedure -# # batch format -# self.format_batch_fn = format_batch_fn - -# # metrics -# if self.binary_cls: -# cls_metrics = MetricCollection( -# { -# "acc": Accuracy(task="binary"), -# "ece": CE(task="binary"), -# "brier": BrierScore(num_classes=1), -# }, -# compute_groups=False, -# ) -# else: -# cls_metrics = MetricCollection( -# { -# "nll": NegativeLogLikelihood(), -# "acc": Accuracy(task="multiclass", num_classes=self.num_classes), -# "ece": CE(task="multiclass", num_classes=self.num_classes), -# "brier": BrierScore(num_classes=self.num_classes), -# }, -# compute_groups=False, -# ) - -# self.val_cls_metrics = cls_metrics.clone(prefix="hp/val_") -# self.test_cls_metrics = cls_metrics.clone(prefix="hp/test_") - -# if self.calibration_set is not None: -# self.ts_cls_metrics = cls_metrics.clone(prefix="hp/ts_") - -# self.test_entropy_id = Entropy() - -# if self.evaluate_ood: -# ood_metrics = MetricCollection( -# { -# "fpr95": FPR95(pos_label=1), -# "auroc": BinaryAUROC(), -# "aupr": BinaryAveragePrecision(), -# }, -# compute_groups=[["auroc", "aupr"], ["fpr95"]], -# ) -# self.test_ood_metrics = ood_metrics.clone(prefix="hp/test_") -# self.test_entropy_ood = Entropy() - -# if mixup_alpha < 0 or cutmix_alpha < 0: -# raise ValueError( -# "Cutmix alpha and Mixup alpha must be positive." -# f"Got {mixup_alpha} and {cutmix_alpha}." -# ) - -# self.mixtype = mixtype -# self.mixmode = mixmode -# self.dist_sim = dist_sim - -# self.mixup = self.init_mixup( -# mixup_alpha, cutmix_alpha, kernel_tau_max, kernel_tau_std -# ) - -# # Handle ELBO special cases -# self.is_elbo = isinstance(self.loss, partial) and self.loss.func == ELBOLoss - -# # DEC -# self.is_dec = self.loss == DECLoss or ( -# isinstance(self.loss, partial) and self.loss.func == DECLoss -# ) - -# def configure_optimizers(self) -> Any: -# return self.optimization_procedure(self) - -# @property -# def criterion(self) -> nn.Module: -# if self.is_elbo: -# self.loss = partial(self.loss, model=self.model) -# return self.loss() - -# def forward(self, inputs: Tensor) -> Tensor: -# return self.model.forward(inputs) - -# def on_train_start(self) -> None: -# # hyperparameters for performances -# param = {} -# param["storage"] = f"{get_model_size_mb(self)} MB" -# if self.logger is not None: # coverage: ignore -# self.logger.log_hyperparams( -# Namespace(**param), -# { -# "hp/val_nll": 0, -# "hp/val_acc": 0, -# "hp/test_acc": 0, -# "hp/test_nll": 0, -# "hp/test_ece": 0, -# "hp/test_brier": 0, -# "hp/test_entropy_id": 0, -# "hp/test_entropy_ood": 0, -# "hp/test_aupr": 0, -# "hp/test_auroc": 0, -# "hp/test_fpr95": 0, -# "hp/ts_test_nll": 0, -# "hp/ts_test_ece": 0, -# "hp/ts_test_brier": 0, -# }, -# ) - -# def training_step( -# self, batch: tuple[Tensor, Tensor], batch_idx: int -# ) -> STEP_OUTPUT: -# if self.mixtype == "kernel_warping": -# if self.dist_sim == "emb": -# with torch.no_grad(): -# feats = self.model.feats_forward(batch[0]).detach() - -# batch = self.mixup(*batch, feats) -# elif self.dist_sim == "inp": -# batch = self.mixup(*batch, batch[0]) -# else: -# batch = self.mixup(*batch) - -# inputs, targets = self.format_batch_fn(batch) - -# if self.is_elbo: -# loss = self.criterion(inputs, targets) -# else: -# logits = self.forward(inputs) -# # BCEWithLogitsLoss expects float targets -# if self.binary_cls and self.loss == nn.BCEWithLogitsLoss: -# logits = logits.squeeze(-1) -# targets = targets.float() - -# if not self.is_dec: -# loss = self.criterion(logits, targets) -# else: -# loss = self.criterion(logits, targets, self.current_epoch) -# self.log("train_loss", loss) -# return loss - -# def validation_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> None: -# inputs, targets = batch -# logits = self.forward(inputs) - -# if self.binary_cls: -# probs = torch.sigmoid(logits).squeeze(-1) -# else: -# probs = F.softmax(logits, dim=-1) - -# self.val_cls_metrics.update(probs, targets) - -# def validation_epoch_end(self, outputs: EPOCH_OUTPUT | list[EPOCH_OUTPUT]) -> None: -# self.log_dict(self.val_cls_metrics.compute()) -# self.val_cls_metrics.reset() - -# def on_test_start(self) -> None: -# if self.calibration_set is not None: -# self.scaler = TemperatureScaler(device=self.device).fit( -# model=self.model, calibration_set=self.calibration_set() -# ) -# self.cal_model = torch.nn.Sequential(self.model, self.scaler) -# else: -# self.scaler = None -# self.cal_model = None - -# def test_step( -# self, -# batch: tuple[Tensor, Tensor], -# batch_idx: int, -# dataloader_idx: int | None = 0, -# ) -> Tensor: -# inputs, targets = batch -# logits = self.forward(inputs) - -# if self.binary_cls: -# probs = torch.sigmoid(logits).squeeze(-1) -# else: -# probs = F.softmax(logits, dim=-1) - -# # self.cal_plot.update(probs, targets) -# confs = probs.max(dim=-1)[0] - -# if self.use_logits: -# ood_scores = -logits.max(dim=-1)[0] -# elif self.use_entropy: -# ood_scores = torch.special.entr(probs).sum(dim=-1) -# else: -# ood_scores = -confs - -# if ( -# self.calibration_set is not None -# and self.scaler is not None -# and self.cal_model is not None -# ): -# cal_logits = self.cal_model(inputs) -# cal_probs = F.softmax(cal_logits, dim=-1) -# self.ts_cls_metrics.update(cal_probs, targets) - -# if dataloader_idx == 0: -# self.test_cls_metrics.update(probs, targets) -# self.test_entropy_id(probs) -# self.log( -# "hp/test_entropy_id", -# self.test_entropy_id, -# on_epoch=True, -# add_dataloader_idx=False, -# ) -# if self.evaluate_ood: -# self.test_ood_metrics.update(ood_scores, torch.zeros_like(targets)) -# elif self.evaluate_ood and dataloader_idx == 1: -# self.test_ood_metrics.update(ood_scores, torch.ones_like(targets)) -# self.test_entropy_ood(probs) -# self.log( -# "hp/test_entropy_ood", -# self.test_entropy_ood, -# on_epoch=True, -# add_dataloader_idx=False, -# ) -# return logits - -# def test_epoch_end(self, outputs: EPOCH_OUTPUT | list[EPOCH_OUTPUT]) -> None: -# self.log_dict( -# self.test_cls_metrics.compute(), -# ) - -# if ( -# self.calibration_set is not None -# and self.scaler is not None -# and self.cal_model is not None -# ): -# self.log_dict(self.ts_cls_metrics.compute()) -# self.ts_cls_metrics.reset() - -# if self.evaluate_ood: -# self.log_dict( -# self.test_ood_metrics.compute(), -# ) -# self.test_ood_metrics.reset() - -# if isinstance(self.logger, TensorBoardLogger) and self.log_plots: -# self.logger.experiment.add_figure( -# "Calibration Plot", self.test_cls_metrics["ece"].plot()[0] -# ) - -# if self.evaluate_ood: -# id_logits = torch.cat(outputs[0], 0).float().cpu() -# ood_logits = torch.cat(outputs[1], 0).float().cpu() - -# id_probs = F.softmax(id_logits, dim=-1) -# ood_probs = F.softmax(ood_logits, dim=-1) - -# logits_fig = plot_hist( -# [id_logits.max(-1).values, ood_logits.max(-1).values], -# 20, -# "Histogram of the logits", -# )[0] -# probs_fig = plot_hist( -# [id_probs.max(-1).values, ood_probs.max(-1).values], -# 20, -# "Histogram of the likelihoods", -# )[0] -# self.logger.experiment.add_figure("Logit Histogram", logits_fig) -# self.logger.experiment.add_figure("Likelihood Histogram", probs_fig) - -# self.test_cls_metrics.reset() - -# def init_mixup( -# self, -# mixup_alpha: float, -# cutmix_alpha: float, -# kernel_tau_max: float, -# kernel_tau_std: float, -# ) -> Callable: -# if self.mixtype == "timm": -# return timm_Mixup( -# mixup_alpha=mixup_alpha, -# cutmix_alpha=cutmix_alpha, -# mode=self.mixmode, -# num_classes=self.num_classes, -# ) -# if self.mixtype == "mixup": -# return Mixup( -# alpha=mixup_alpha, -# mode=self.mixmode, -# num_classes=self.num_classes, -# ) -# if self.mixtype == "mixup_io": -# return MixupIO( -# alpha=mixup_alpha, -# mode=self.mixmode, -# num_classes=self.num_classes, -# ) -# if self.mixtype == "regmixup": -# return RegMixup( -# alpha=mixup_alpha, -# mode=self.mixmode, -# num_classes=self.num_classes, -# ) -# if self.mixtype == "kernel_warping": -# return WarpingMixup( -# alpha=mixup_alpha, -# mode=self.mixmode, -# num_classes=self.num_classes, -# apply_kernel=True, -# tau_max=kernel_tau_max, -# tau_std=kernel_tau_std, -# ) -# return lambda x, y: (x, y) - -# @staticmethod -# def add_model_specific_args( -# parent_parser: ArgumentParser, -# ) -> ArgumentParser: -# """Defines the routine's attributes via command-line options. - -# Args: -# parent_parser (ArgumentParser): Parent parser to be completed. - -# Adds: -# - ``--entropy``: sets :attr:`use_entropy` to ``True``. -# - ``--logits``: sets :attr:`use_logits` to ``True``. -# - ``--mixup_alpha``: sets :attr:`mixup_alpha` for Mixup -# - ``--cutmix_alpha``: sets :attr:`cutmix_alpha` for Cutmix -# - ``--mixtype``: sets :attr:`mixtype` for Mixup -# - ``--mixmode``: sets :attr:`mixmode` for Mixup -# - ``--dist_sim``: sets :attr:`dist_sim` for Mixup -# - ``--kernel_tau_max``: sets :attr:`kernel_tau_max` for Mixup -# - ``--kernel_tau_std``: sets :attr:`kernel_tau_std` for Mixup -# """ -# parent_parser.add_argument("--entropy", dest="use_entropy", action="store_true") -# parent_parser.add_argument("--logits", dest="use_logits", action="store_true") - -# # Mixup args -# parent_parser.add_argument( -# "--mixup_alpha", dest="mixup_alpha", type=float, default=0 -# ) -# parent_parser.add_argument( -# "--cutmix_alpha", dest="cutmix_alpha", type=float, default=0 -# ) -# parent_parser.add_argument("--mixtype", dest="mixtype", type=str, default="erm") -# parent_parser.add_argument( -# "--mixmode", dest="mixmode", type=str, default="elem" -# ) -# parent_parser.add_argument( -# "--dist_sim", dest="dist_sim", type=str, default="emb" -# ) -# parent_parser.add_argument( -# "--kernel_tau_max", dest="kernel_tau_max", type=float, default=1.0 -# ) -# parent_parser.add_argument( -# "--kernel_tau_std", dest="kernel_tau_std", type=float, default=0.5 -# ) -# return parent_parser - - -# class ClassificationEnsemble(ClassificationSingle): -# def __init__( -# self, -# num_classes: int, -# model: nn.Module, -# loss: type[nn.Module], -# optimization_procedure: Any, -# num_estimators: int, -# format_batch_fn: nn.Module | None = None, -# mixtype: str = "erm", -# mixmode: str = "elem", -# dist_sim: str = "emb", -# kernel_tau_max: float = 1.0, -# kernel_tau_std: float = 0.5, -# mixup_alpha: float = 0, -# cutmix_alpha: float = 0, -# evaluate_ood: bool = False, -# use_entropy: bool = False, -# use_logits: bool = False, -# use_mi: bool = False, -# use_variation_ratio: bool = False, -# log_plots: bool = False, -# **kwargs, -# ) -> None: -# """Classification routine for ensemble models. - -# Args: -# num_classes (int): Number of classes. -# model (nn.Module): Model to train. -# loss (type[nn.Module]): Loss function. -# optimization_procedure (Any): Optimization procedure. -# num_estimators (int): Number of estimators in the ensemble. -# format_batch_fn (nn.Module, optional): Function to format the batch. -# Defaults to :class:`torch.nn.Identity()`. -# mixtype (str, optional): Mixup type. Defaults to ``"erm"``. -# mixmode (str, optional): Mixup mode. Defaults to ``"elem"``. -# dist_sim (str, optional): Distance similarity. Defaults to ``"emb"``. -# kernel_tau_max (float, optional): Maximum value for the kernel tau. -# Defaults to 1.0. -# kernel_tau_std (float, optional): Standard deviation for the kernel tau. -# Defaults to 0.5. -# mixup_alpha (float, optional): Alpha parameter for Mixup. Defaults to 0. -# cutmix_alpha (float, optional): Alpha parameter for Cutmix. -# Defaults to 0. -# evaluate_ood (bool, optional): Indicates whether to evaluate the OOD -# detection performance or not. Defaults to ``False``. -# use_entropy (bool, optional): Indicates whether to use the entropy -# values as the OOD criterion or not. Defaults to ``False``. -# use_logits (bool, optional): Indicates whether to use the logits as the -# OOD criterion or not. Defaults to ``False``. -# use_mi (bool, optional): Indicates whether to use the mutual -# information as the OOD criterion or not. Defaults to ``False``. -# use_variation_ratio (bool, optional): Indicates whether to use the -# variation ratio as the OOD criterion or not. Defaults to ``False``. -# log_plots (bool, optional): Indicates whether to log plots from -# metrics. Defaults to ``False``. -# calibration_set (Callable, optional): Function to get the calibration -# set. Defaults to ``None``. -# kwargs (Any): Additional arguments. - -# Note: -# The default OOD criterion is the averaged softmax confidence score. - -# Warning: -# Make sure at most only one of :attr:`use_entropy`, :attr:`use_logits` -# , :attr:`use_mi`, and :attr:`use_variation_ratio` attributes is set to -# ``True``. Otherwise a :class:`ValueError()` will be raised. -# """ -# super().__init__( -# num_classes=num_classes, -# model=model, -# loss=loss, -# optimization_procedure=optimization_procedure, -# format_batch_fn=format_batch_fn, -# mixtype=mixtype, -# mixmode=mixmode, -# dist_sim=dist_sim, -# kernel_tau_max=kernel_tau_max, -# kernel_tau_std=kernel_tau_std, -# mixup_alpha=mixup_alpha, -# cutmix_alpha=cutmix_alpha, -# evaluate_ood=evaluate_ood, -# use_entropy=use_entropy, -# use_logits=use_logits, -# **kwargs, -# ) - -# self.num_estimators = num_estimators - -# self.use_mi = use_mi -# self.use_variation_ratio = use_variation_ratio -# self.log_plots = log_plots - -# if ( -# self.use_logits + self.use_entropy + self.use_mi + self.use_variation_ratio -# ) > 1: -# raise ValueError("You cannot choose more than one OOD criterion.") - -# # metrics for ensembles only -# ens_metrics = MetricCollection( -# { -# "disagreement": Disagreement(), -# "mi": MutualInformation(), -# "entropy": Entropy(), -# } -# ) -# self.test_id_ens_metrics = ens_metrics.clone(prefix="hp/test_id_ens_") - -# if self.evaluate_ood: -# self.test_ood_ens_metrics = ens_metrics.clone(prefix="hp/test_ood_ens_") - -# def on_train_start(self) -> None: -# param = {} -# param["storage"] = f"{get_model_size_mb(self)} MB" -# if self.logger is not None: # coverage: ignore -# self.logger.log_hyperparams( -# Namespace(**param), -# { -# "hp/val_nll": 0, -# "hp/val_acc": 0, -# "hp/test_acc": 0, -# "hp/test_nll": 0, -# "hp/test_ece": 0, -# "hp/test_brier": 0, -# "hp/test_entropy_id": 0, -# "hp/test_entropy_ood": 0, -# "hp/test_aupr": 0, -# "hp/test_auroc": 0, -# "hp/test_fpr95": 0, -# "hp/test_id_ens_disagreement": 0, -# "hp/test_id_ens_mi": 0, -# "hp/test_id_ens_entropy": 0, -# "hp/test_ood_ens_disagreement": 0, -# "hp/test_ood_ens_mi": 0, -# "hp/test_ood_ens_entropy": 0, -# }, -# ) - -# def training_step( -# self, batch: tuple[Tensor, Tensor], batch_idx: int -# ) -> STEP_OUTPUT: -# batch = self.mixup(*batch) -# # eventual input repeat is done in the model -# inputs, targets = self.format_batch_fn(batch) - -# if self.is_elbo: -# loss = self.criterion(inputs, targets) -# else: -# logits = self.forward(inputs) -# # BCEWithLogitsLoss expects float targets -# if self.binary_cls and self.loss == nn.BCEWithLogitsLoss: -# logits = logits.squeeze(-1) -# targets = targets.float() - -# if not self.is_dec: -# loss = self.criterion(logits, targets) -# else: -# loss = self.criterion(logits, targets, self.current_epoch) - -# self.log("train_loss", loss) -# return loss - -# def validation_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> None: -# inputs, targets = batch -# logits = self.forward(inputs) -# logits = rearrange(logits, "(m b) c -> b m c", m=self.num_estimators) -# if self.binary_cls: -# probs_per_est = torch.sigmoid(logits).squeeze(-1) -# else: -# probs_per_est = F.softmax(logits, dim=-1) - -# probs = probs_per_est.mean(dim=1) -# self.val_cls_metrics.update(probs, targets) - -# def test_step( -# self, -# batch: tuple[Tensor, Tensor], -# batch_idx: int, -# dataloader_idx: int | None = 0, -# ) -> Tensor: -# inputs, targets = batch -# logits = self.forward(inputs) -# logits = rearrange(logits, "(n b) c -> b n c", n=self.num_estimators) - -# if self.binary_cls: -# probs_per_est = torch.sigmoid(logits) -# else: -# probs_per_est = F.softmax(logits, dim=-1) - -# probs = probs_per_est.mean(dim=1) -# # self.cal_plot.update(probs, targets) -# confs = probs.max(-1)[0] - -# if self.use_logits: -# ood_scores = -logits.mean(dim=1).max(dim=-1)[0] -# elif self.use_entropy: -# ood_scores = torch.special.entr(probs_per_est).sum(dim=-1).mean(dim=1) -# elif self.use_mi: -# mi_metric = MutualInformation(reduction="none") -# ood_scores = mi_metric(probs_per_est) -# elif self.use_variation_ratio: -# vr_metric = VariationRatio(reduction="none", probabilistic=False) -# ood_scores = vr_metric(probs_per_est.transpose(0, 1)) -# else: -# ood_scores = -confs - -# if dataloader_idx == 0: -# # squeeze if binary classification only for binary metrics -# self.test_cls_metrics.update( -# probs.squeeze(-1) if self.binary_cls else probs, -# targets, -# ) -# self.test_entropy_id(probs) - -# self.test_id_ens_metrics.update(probs_per_est) -# self.log( -# "hp/test_entropy_id", -# self.test_entropy_id, -# on_epoch=True, -# add_dataloader_idx=False, -# ) - -# if self.evaluate_ood: -# self.test_ood_metrics.update(ood_scores, torch.zeros_like(targets)) -# elif self.evaluate_ood and dataloader_idx == 1: -# self.test_ood_metrics.update(ood_scores, torch.ones_like(targets)) -# self.test_entropy_ood(probs) -# self.test_ood_ens_metrics.update(probs_per_est) -# self.log( -# "hp/test_entropy_ood", -# self.test_entropy_ood, -# on_epoch=True, -# add_dataloader_idx=False, -# ) -# return logits - -# def test_epoch_end(self, outputs: EPOCH_OUTPUT | list[EPOCH_OUTPUT]) -> None: -# self.log_dict( -# self.test_cls_metrics.compute(), -# ) - -# self.log_dict( -# self.test_id_ens_metrics.compute(), -# ) - -# if self.evaluate_ood: -# self.log_dict( -# self.test_ood_metrics.compute(), -# ) -# self.log_dict( -# self.test_ood_ens_metrics.compute(), -# ) - -# self.test_ood_metrics.reset() -# self.test_ood_ens_metrics.reset() - -# if isinstance(self.logger, TensorBoardLogger) and self.log_plots: -# self.logger.experiment.add_figure( -# "Calibration Plot", self.test_cls_metrics["ece"].plot()[0] -# ) - -# if self.evaluate_ood: -# id_logits = torch.cat(outputs[0], 0).float().cpu() -# ood_logits = torch.cat(outputs[1], 0).float().cpu() - -# id_probs = F.softmax(id_logits, dim=-1) -# ood_probs = F.softmax(ood_logits, dim=-1) - -# logits_fig = plot_hist( -# [ -# id_logits.mean(1).max(-1).values, -# ood_logits.mean(1).max(-1).values, -# ], -# 20, -# "Histogram of the logits", -# )[0] -# probs_fig = plot_hist( -# [ -# id_probs.mean(1).max(-1).values, -# ood_probs.mean(1).max(-1).values, -# ], -# 20, -# "Histogram of the likelihoods", -# )[0] -# self.logger.experiment.add_figure("Logit Histogram", logits_fig) -# self.logger.experiment.add_figure("Likelihood Histogram", probs_fig) - -# self.test_cls_metrics.reset() -# self.test_id_ens_metrics.reset() - -# @staticmethod -# def add_model_specific_args( -# parent_parser: ArgumentParser, -# ) -> ArgumentParser: -# """Defines the routine's attributes via command-line options. - -# Adds: -# - ``--entropy``: sets :attr:`use_entropy` to ``True``. -# - ``--logits``: sets :attr:`use_logits` to ``True``. -# - ``--mutual_information``: sets :attr:`use_mi` to ``True``. -# - ``--variation_ratio``: sets :attr:`use_variation_ratio` to ``True``. -# - ``--num_estimators``: sets :attr:`num_estimators`. -# """ -# parent_parser = ClassificationSingle.add_model_specific_args(parent_parser) -# # FIXME: should be a str to choose among the available OOD criteria -# # rather than a boolean, but it is not possible since -# # ClassificationSingle and ClassificationEnsemble have different OOD -# # criteria. -# parent_parser.add_argument( -# "--mutual_information", -# dest="use_mi", -# action="store_true", -# default=False, -# ) -# parent_parser.add_argument( -# "--variation_ratio", -# dest="use_variation_ratio", -# action="store_true", -# default=False, -# ) -# parent_parser.add_argument( -# "--num_estimators", -# type=int, -# default=None, -# help="Number of estimators for ensemble", -# ) -# return parent_parser +from torch_uncertainty.utils import csv_writer class ClassificationRoutine(LightningModule): @@ -831,7 +55,8 @@ def __init__( use_mi: bool = False, use_variation_ratio: bool = False, log_plots: bool = False, - calibration_set: Callable | None = None, + save_in_csv: bool = False, + calibration_set: Literal["val", "test"] | None = None, ) -> None: """Classification routine. @@ -854,7 +79,8 @@ def __init__( use_mi (bool, optional): _description_. Defaults to False. use_variation_ratio (bool, optional): _description_. Defaults to False. log_plots (bool, optional): _description_. Defaults to False. - calibration_set (Callable | None, optional): _description_. Defaults to None. + save_in_csv (bool, optional): _description_. Defaults to False. + calibration_set (str | None, optional): _description_. Defaults to None. Raises: ValueError: _description_ @@ -890,6 +116,7 @@ def __init__( self.use_mi = use_mi self.use_variation_ratio = use_variation_ratio self.log_plots = log_plots + self.save_in_csv = save_in_csv self.calibration_set = calibration_set self.binary_cls = num_classes == 1 @@ -921,11 +148,11 @@ def __init__( compute_groups=False, ) - self.val_cls_metrics = cls_metrics.clone(prefix="hp/val_") - self.test_cls_metrics = cls_metrics.clone(prefix="hp/test_") + self.val_cls_metrics = cls_metrics.clone(prefix="val_") + self.test_cls_metrics = cls_metrics.clone(prefix="test_") if self.calibration_set is not None: - self.ts_cls_metrics = cls_metrics.clone(prefix="hp/ts_") + self.ts_cls_metrics = cls_metrics.clone(prefix="ts_") self.test_entropy_id = Entropy() @@ -938,7 +165,7 @@ def __init__( }, compute_groups=[["auroc", "aupr"], ["fpr95"]], ) - self.test_ood_metrics = ood_metrics.clone(prefix="hp/test_") + self.test_ood_metrics = ood_metrics.clone(prefix="test_") self.test_entropy_ood = Entropy() self.mixtype = mixtype @@ -974,15 +201,16 @@ def __init__( "entropy": Entropy(), } ) - self.test_id_ens_metrics = ens_metrics.clone( - prefix="hp/test_id_ens_" - ) + self.test_id_ens_metrics = ens_metrics.clone(prefix="test_id_ens_") if self.evaluate_ood: self.test_ood_ens_metrics = ens_metrics.clone( - prefix="hp/test_ood_ens_" + prefix="test_ood_ens_" ) + self.id_logit_storage = None + self.ood_logit_storage = None + def init_mixup( self, mixup_alpha: float, @@ -1030,8 +258,6 @@ def on_train_start(self) -> None: init_metrics = {k: 0 for k in self.val_cls_metrics} init_metrics.update({k: 0 for k in self.test_cls_metrics}) - # self.hparams.storage = f"{get_model_size_mb(self)} MB" - if self.logger is not None: # coverage: ignore self.logger.log_hyperparams( self.hparams, @@ -1039,15 +265,31 @@ def on_train_start(self) -> None: ) def on_test_start(self) -> None: - if self.calibration_set is not None: + if isinstance(self.calibration_set, str) and self.calibration_set in [ + "val", + "test", + ]: + dataset = ( + self.trainer.datamodule.val_dataloader().dataset + if self.calibration_set == "val" + else self.trainer.datamodule.test_dataloader().dataset + ) self.scaler = TemperatureScaler(device=self.device).fit( - model=self.model, calibration_set=self.calibration_set() + model=self.model, calibration_set=dataset ) self.cal_model = torch.nn.Sequential(self.model, self.scaler) else: self.scaler = None self.cal_model = None + if ( + self.evaluate_ood + and self.log_plots + and isinstance(self.logger, TensorBoardLogger) + ): + self.id_logit_storage = [] + self.ood_logit_storage = [] + @property def criterion(self) -> nn.Module: if self.is_elbo: @@ -1096,16 +338,15 @@ def validation_step( self, batch: tuple[Tensor, Tensor], batch_idx: int ) -> None: inputs, targets = batch - logits = self.forward(inputs) # (b, c) or (m, b, c) - if logits.ndim == 2: - logits = logits.unsqueeze(0) + logits = self.forward(inputs) # (m*b, c) + logits = rearrange(logits, "(m b) c -> b m c", m=self.num_estimators) if self.binary_cls: probs_per_est = torch.sigmoid(logits).squeeze(-1) else: probs_per_est = F.softmax(logits, dim=-1) - probs = probs_per_est.mean(dim=0) + probs = probs_per_est.mean(dim=1) self.val_cls_metrics.update(probs, targets) def test_step( @@ -1115,16 +356,15 @@ def test_step( dataloader_idx: int = 0, ) -> None: inputs, targets = batch - logits = self.forward(inputs) - if logits.ndim == 2: - logits = logits.unsqueeze(0) + logits = self.forward(inputs) # (m*b, c) + logits = rearrange(logits, "(m b) c -> b m c", m=self.num_estimators) if self.binary_cls: probs_per_est = torch.sigmoid(logits).squeeze(-1) else: probs_per_est = F.softmax(logits, dim=-1) - probs = probs_per_est.mean(dim=0) + probs = probs_per_est.mean(dim=1) confs = probs.max(-1)[0] @@ -1156,13 +396,16 @@ def test_step( if dataloader_idx == 0: # squeeze if binary classification only for binary metrics - self.test_cls_metrics.update( + self.test_cls_metrics( probs.squeeze(-1) if self.binary_cls else probs, targets, ) + self.log_dict( + self.test_cls_metrics, on_epoch=True, add_dataloader_idx=False + ) self.test_entropy_id(probs) self.log( - "hp/test_entropy_id", + "test_entropy_id", self.test_entropy_id, on_epoch=True, add_dataloader_idx=False, @@ -1171,16 +414,19 @@ def test_step( if self.num_estimators > 1: self.test_id_ens_metrics.update(probs_per_est) - if self.evaluate_ood: - self.test_ood_metrics.update( - ood_scores, torch.zeros_like(targets) - ) + if self.evaluate_ood: + self.test_ood_metrics.update( + ood_scores, torch.zeros_like(targets) + ) + + if self.id_logit_storage is not None: + self.id_logit_storage.append(logits.detach().cpu()) elif self.evaluate_ood and dataloader_idx == 1: self.test_ood_metrics.update(ood_scores, torch.ones_like(targets)) self.test_entropy_ood(probs) self.log( - "hp/test_entropy_ood", + "test_entropy_ood", self.test_entropy_ood, on_epoch=True, add_dataloader_idx=False, @@ -1188,14 +434,19 @@ def test_step( if self.num_estimators > 1: self.test_ood_ens_metrics.update(probs_per_est) + if self.ood_logit_storage is not None: + self.ood_logit_storage.append(logits.detach().cpu()) + def on_validation_epoch_end(self) -> None: self.log_dict(self.val_cls_metrics.compute()) self.val_cls_metrics.reset() def on_test_epoch_end(self) -> None: - self.log_dict( - self.test_cls_metrics.compute(), - ) + # already logged + result_dict = self.test_cls_metrics.compute() + + # already logged + result_dict.update({"test_entropy_id": self.test_entropy_id.compute()}) if ( self.num_estimators == 1 @@ -1203,24 +454,32 @@ def on_test_epoch_end(self) -> None: and self.scaler is not None and self.cal_model is not None ): - self.log_dict(self.ts_cls_metrics.compute()) + tmp_metrics = self.ts_cls_metrics.compute() + self.log_dict(tmp_metrics) + result_dict.update(tmp_metrics) self.ts_cls_metrics.reset() if self.num_estimators > 1: - self.log_dict( - self.test_id_ens_metrics.compute(), - ) + tmp_metrics = self.test_id_ens_metrics.compute() + self.log_dict(tmp_metrics) + result_dict.update(tmp_metrics) self.test_id_ens_metrics.reset() if self.evaluate_ood: - self.log_dict( - self.test_ood_metrics.compute(), - ) + tmp_metrics = self.test_ood_metrics.compute() + self.log_dict(tmp_metrics) + result_dict.update(tmp_metrics) self.test_ood_metrics.reset() + + # already logged + result_dict.update( + {"test_entropy_ood": self.test_entropy_ood.compute()} + ) + if self.num_estimators > 1: - self.log_dict( - self.test_ood_ens_metrics.compute(), - ) + tmp_metrics = self.test_ood_ens_metrics.compute() + self.log_dict(tmp_metrics) + result_dict.update(tmp_metrics) self.test_ood_ens_metrics.reset() if isinstance(self.logger, TensorBoardLogger) and self.log_plots: @@ -1228,6 +487,41 @@ def on_test_epoch_end(self) -> None: "Calibration Plot", self.test_cls_metrics["ece"].plot()[0] ) - # TODO: plot histograms of logits and likelihoods + # plot histograms of logits and likelihoods + if self.evaluate_ood: + id_logits = torch.cat(self.id_logit_storage, dim=0) + ood_logits = torch.cat(self.ood_logit_storage, dim=0) + + id_probs = F.softmax(id_logits, dim=-1) + ood_probs = F.softmax(ood_logits, dim=-1) + + logits_fig = plot_hist( + [ + id_logits.mean(1).max(-1).values, + ood_logits.mean(1).max(-1).values, + ], + 20, + "Histogram of the logits", + )[0] + probs_fig = plot_hist( + [ + id_probs.mean(1).max(-1).values, + ood_probs.mean(1).max(-1).values, + ], + 20, + "Histogram of the likelihoods", + )[0] + self.logger.experiment.add_figure("Logit Histogram", logits_fig) + self.logger.experiment.add_figure( + "Likelihood Histogram", probs_fig + ) - self.test_cls_metrics.reset() + if self.save_in_csv: + self.save_results_to_csv(result_dict) + + def save_results_to_csv(self, results: dict[str, float]) -> None: + if self.logger is not None: + csv_writer( + Path(self.logger.log_dir) / "results.csv", + results, + ) diff --git a/torch_uncertainty/routines/segmentation.py b/torch_uncertainty/routines/segmentation.py index 30cc9e80..c160c3e0 100644 --- a/torch_uncertainty/routines/segmentation.py +++ b/torch_uncertainty/routines/segmentation.py @@ -11,7 +11,9 @@ def __init__( self, num_classes: int, model: nn.Module, - loss: nn.Module | None, + loss: nn.Module, + num_estimators: int, + format_batch_fn: nn.Module | None = None, ) -> None: super().__init__() diff --git a/torch_uncertainty/utils/__init__.py b/torch_uncertainty/utils/__init__.py index e6b70312..af7813b4 100644 --- a/torch_uncertainty/utils/__init__.py +++ b/torch_uncertainty/utils/__init__.py @@ -1,4 +1,5 @@ # ruff: noqa: F401 from .checkpoints import get_version +from .cli import TULightningCLI from .hub import load_hf from .misc import csv_writer diff --git a/torch_uncertainty/utils/cli.py b/torch_uncertainty/utils/cli.py new file mode 100644 index 00000000..695b3b32 --- /dev/null +++ b/torch_uncertainty/utils/cli.py @@ -0,0 +1,128 @@ +from collections.abc import Callable +from pathlib import Path +from typing import Any + +from lightning.fabric.utilities.cloud_io import get_filesystem +from lightning.pytorch import LightningDataModule, LightningModule, Trainer +from lightning.pytorch.cli import ( + ArgsType, + LightningArgumentParser, + LightningCLI, + SaveConfigCallback, +) +from typing_extensions import override + + +class TUSaveConfigCallback(SaveConfigCallback): + @override + def setup( + self, trainer: Trainer, pl_module: LightningModule, stage: str + ) -> None: + if self.already_saved: + return + + if self.save_to_log_dir and stage == "fit": + log_dir = trainer.log_dir # this broadcasts the directory + assert log_dir is not None + config_path = Path(log_dir) / self.config_filename + fs = get_filesystem(log_dir) + + if not self.overwrite: + # check if the file exists on rank 0 + file_exists = ( + fs.isfile(config_path) if trainer.is_global_zero else False + ) + # broadcast whether to fail to all ranks + file_exists = trainer.strategy.broadcast(file_exists) + if file_exists: + raise RuntimeError( + f"{self.__class__.__name__} expected {config_path} to NOT exist. Aborting to avoid overwriting" + " results of a previous run. You can delete the previous config file," + " set `LightningCLI(save_config_callback=None)` to disable config saving," + ' or set `LightningCLI(save_config_kwargs={"overwrite": True})` to overwrite the config file.' + ) + + if trainer.is_global_zero: + fs.makedirs(log_dir, exist_ok=True) + self.parser.save( + self.config, + config_path, + skip_none=False, + overwrite=self.overwrite, + multifile=self.multifile, + ) + if trainer.is_global_zero: + self.save_config(trainer, pl_module, stage) + self.already_saved = True + + self.already_saved = trainer.strategy.broadcast(self.already_saved) + + +class TULightningCLI(LightningCLI): + def __init__( + self, + model_class: type[LightningModule] + | Callable[..., LightningModule] + | None = None, + datamodule_class: type[LightningDataModule] + | Callable[..., LightningDataModule] + | None = None, + save_config_callback: type[SaveConfigCallback] + | None = TUSaveConfigCallback, + save_config_kwargs: dict[str, Any] | None = None, + trainer_class: type[Trainer] | Callable[..., Trainer] = Trainer, + trainer_defaults: dict[str, Any] | None = None, + seed_everything_default: bool | int = True, + parser_kwargs: dict[str, Any] | dict[str, dict[str, Any]] | None = None, + subclass_mode_model: bool = False, + subclass_mode_data: bool = False, + args: ArgsType = None, + run: bool = True, + auto_configure_optimizers: bool = True, + eval_after_fit_default: bool = False, + ) -> None: + """Custom LightningCLI for torch-uncertainty. + + Args: + model_class (type[LightningModule] | Callable[..., LightningModule] | None, optional): _description_. Defaults to None. + datamodule_class (type[LightningDataModule] | Callable[..., LightningDataModule] | None, optional): _description_. Defaults to None. + save_config_callback (type[SaveConfigCallback] | None, optional): _description_. Defaults to TUSaveConfigCallback. + save_config_kwargs (dict[str, Any] | None, optional): _description_. Defaults to None. + trainer_class (type[Trainer] | Callable[..., Trainer], optional): _description_. Defaults to Trainer. + trainer_defaults (dict[str, Any] | None, optional): _description_. Defaults to None. + seed_everything_default (bool | int, optional): _description_. Defaults to True. + parser_kwargs (dict[str, Any] | dict[str, dict[str, Any]] | None, optional): _description_. Defaults to None. + subclass_mode_model (bool, optional): _description_. Defaults to False. + subclass_mode_data (bool, optional): _description_. Defaults to False. + args (ArgsType, optional): _description_. Defaults to None. + run (bool, optional): _description_. Defaults to True. + auto_configure_optimizers (bool, optional): _description_. Defaults to True. + eval_after_fit_default (bool, optional): _description_. Defaults to False. + """ + self.eval_after_fit_default = eval_after_fit_default + super().__init__( + model_class, + datamodule_class, + save_config_callback, + save_config_kwargs, + trainer_class, + trainer_defaults, + seed_everything_default, + parser_kwargs, + subclass_mode_model, + subclass_mode_data, + args, + run, + auto_configure_optimizers, + ) + + def add_default_arguments_to_parser( + self, parser: LightningArgumentParser + ) -> None: + """Adds default arguments to the parser.""" + parser.add_argument( + "--eval_after_fit", + action="store_true", + default=self.eval_after_fit_default, + ) + super().add_default_arguments_to_parser(parser) From 798e610a6cc5727d71f4fe48e008fea6d8c1f3ba Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 14 Feb 2024 16:34:52 +0100 Subject: [PATCH 009/148] :zap: Update ruff parameters to comply with deprecation --- pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e5018677..9f7ad239 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,7 +75,7 @@ name = "torch_uncertainty" [tool.ruff] line-length = 80 target-version = "py310" -extend-select = [ +lint.extend-select = [ "A", "B", "C4", @@ -100,7 +100,7 @@ extend-select = [ "TRY", "YTT", ] -ignore = [ +lint.ignore = [ "B017", "D100", "D101", @@ -143,7 +143,7 @@ exclude = [ "venv", ] -[tool.ruff.pydocstyle] +[tool.ruff.lint.pydocstyle] convention = "google" [tool.coverage.run] From 649ae0a6d2c31d1da2f6749e685cb47dcfba0a72 Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 14 Feb 2024 17:02:55 +0100 Subject: [PATCH 010/148] :hammer: Rename the return_features argument --- torch_uncertainty/routines/classification.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 0e597f97..c4552aac 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -340,18 +340,18 @@ def criterion(self) -> nn.Module: self.loss = partial(self.loss, model=self.model) return self.loss() - def forward(self, inputs: Tensor, return_features: bool = False) -> Tensor: + def forward(self, inputs: Tensor, save_feats: bool = False) -> Tensor: """Forward pass of the model. Args: inputs (Tensor): Input tensor. - return_features (bool, optional): Whether to store the features or + save_feats (bool, optional): Whether to store the features or not. Defaults to ``False``. Note: The features are stored in the :attr:`features` attribute. """ - if return_features: + if save_feats: self.features = self.model.feats_forward(inputs) if hasattr(self.model, "classification_head"): # coverage: ignore logits = self.model.classification_head(self.features) @@ -423,7 +423,7 @@ def test_step( ) -> None: inputs, targets = batch logits = self.forward( - inputs, return_features=self.eval_grouping_loss + inputs, save_feats=self.eval_grouping_loss ) # (m*b, c) if logits.size(0) % self.num_estimators != 0: # coverage: ignore raise ValueError( From 3b10897966a81c1710f5c2182da9680785fbca39 Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 14 Feb 2024 17:03:57 +0100 Subject: [PATCH 011/148] :shirt: Improve baselines codes and routines --- tests/_dummies/baseline.py | 29 +++++-------------- .../baselines/classification/resnet.py | 15 ++++------ .../baselines/classification/vgg.py | 20 +++++-------- .../baselines/classification/wideresnet.py | 15 ++++------ torch_uncertainty/routines/__init__.py | 3 ++ 5 files changed, 30 insertions(+), 52 deletions(-) diff --git a/tests/_dummies/baseline.py b/tests/_dummies/baseline.py index 576b12bd..93b2175b 100644 --- a/tests/_dummies/baseline.py +++ b/tests/_dummies/baseline.py @@ -1,15 +1,8 @@ -from typing import Any from pytorch_lightning import LightningModule from torch import nn -from torch_uncertainty.routines.classification import ( - ClassificationRoutine, -) -from torch_uncertainty.routines.regression import ( - RegressionEnsemble, - RegressionSingle, -) +from torch_uncertainty.routines import ClassificationRoutine, RegressionRoutine from torch_uncertainty.transforms import RepeatTarget from .model import dummy_model @@ -21,11 +14,9 @@ def __new__( num_classes: int, in_channels: int, loss: type[nn.Module], - optimization_procedure: Any, baseline_type: str = "single", with_feats: bool = True, with_linear: bool = True, - **kwargs, ) -> LightningModule: model = dummy_model( in_channels=in_channels, @@ -40,21 +31,18 @@ def __new__( num_classes=num_classes, model=model, loss=loss, - optimization_procedure=optimization_procedure, format_batch_fn=nn.Identity(), log_plots=True, - **kwargs, + num_estimators = 1 ) # baseline_type == "ensemble": - kwargs["num_estimators"] = 2 return ClassificationRoutine( num_classes=num_classes, model=model, loss=loss, - optimization_procedure=optimization_procedure, format_batch_fn=RepeatTarget(2), log_plots=True, - **kwargs, + num_estimators = 2 ) # @classmethod @@ -71,7 +59,6 @@ def __new__( in_features: int, out_features: int, loss: type[nn.Module], - optimization_procedure: Any, baseline_type: str = "single", dist_estimation: int = 1, **kwargs, @@ -83,24 +70,22 @@ def __new__( ) if baseline_type == "single": - return RegressionSingle( + return RegressionRoutine( out_features=out_features, model=model, loss=loss, - optimization_procedure=optimization_procedure, dist_estimation=dist_estimation, - **kwargs, + num_estimators=1 ) # baseline_type == "ensemble": kwargs["num_estimators"] = 2 - return RegressionEnsemble( + return RegressionRoutine( model=model, loss=loss, - optimization_procedure=optimization_procedure, dist_estimation=dist_estimation, mode="mean", out_features=out_features, - **kwargs, + num_estimators=2 ) # @classmethod diff --git a/torch_uncertainty/baselines/classification/resnet.py b/torch_uncertainty/baselines/classification/resnet.py index 18fe67e5..cb8cafc9 100644 --- a/torch_uncertainty/baselines/classification/resnet.py +++ b/torch_uncertainty/baselines/classification/resnet.py @@ -246,28 +246,25 @@ def __init__( raise ValueError(f"Unknown version: {version}") if version in self.ensemble: - params.update( - { + params |= { "num_estimators": num_estimators, } - ) + if version != "mc-dropout": format_batch_fn = RepeatTarget(num_repeats=num_estimators) if version == "packed": - params.update( - { + params |= { "alpha": alpha, "gamma": gamma, "pretrained": pretrained, } - ) + elif version == "masked": - params.update( - { + params |= { "scale": scale, } - ) + elif version == "mimo": format_batch_fn = MIMOBatchFormat( num_estimators=num_estimators, diff --git a/torch_uncertainty/baselines/classification/vgg.py b/torch_uncertainty/baselines/classification/vgg.py index 42732396..b41200d0 100644 --- a/torch_uncertainty/baselines/classification/vgg.py +++ b/torch_uncertainty/baselines/classification/vgg.py @@ -148,37 +148,33 @@ def __init__( format_batch_fn = nn.Identity() if version == "std": - params.update( - { + params |= { "dropout_rate": dropout_rate, } - ) + elif version == "mc-dropout": - params.update( - { + params |= { "dropout_rate": dropout_rate, "num_estimators": num_estimators, } - ) + if version in self.ensemble: - params.update( - { + params |= { "num_estimators": num_estimators, } - ) + if version != "mc-dropout": format_batch_fn = RepeatTarget(num_repeats=num_estimators) if version == "packed": - params.update( - { + params |= { "alpha": alpha, "style": style, "gamma": gamma, } - ) + if version == "mc-dropout": # std VGGs don't have `num_estimators` del params["num_estimators"] diff --git a/torch_uncertainty/baselines/classification/wideresnet.py b/torch_uncertainty/baselines/classification/wideresnet.py index d7cf5a56..7eea2a3f 100644 --- a/torch_uncertainty/baselines/classification/wideresnet.py +++ b/torch_uncertainty/baselines/classification/wideresnet.py @@ -158,27 +158,24 @@ def __init__( raise ValueError(f"Unknown version: {version}") if version in self.ensemble: - params.update( - { + params |= { "num_estimators": num_estimators, } - ) + if version != "mc-dropout": format_batch_fn = RepeatTarget(num_repeats=num_estimators) if version == "packed": - params.update( - { + params |= { "alpha": alpha, "gamma": gamma, } - ) + elif version == "masked": - params.update( - { + params |= { "scale": scale, } - ) + elif version == "mimo": format_batch_fn = MIMOBatchFormat( num_estimators=num_estimators, diff --git a/torch_uncertainty/routines/__init__.py b/torch_uncertainty/routines/__init__.py index e69de29b..f850e325 100644 --- a/torch_uncertainty/routines/__init__.py +++ b/torch_uncertainty/routines/__init__.py @@ -0,0 +1,3 @@ +# ruff: noqa: F401 +from .classification import ClassificationRoutine +from .regression import RegressionRoutine \ No newline at end of file From 24b93fa9c38ee794c2b73f12e5d8b90a085384f8 Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 14 Feb 2024 17:04:29 +0100 Subject: [PATCH 012/148] :bug: Fix regression routine and MLP baseline --- tests/baselines/test_deep_ensembles.py | 4 +- .../baselines/classification/__init__.py | 2 +- torch_uncertainty/baselines/regression/mlp.py | 89 ++------ torch_uncertainty/routines/regression.py | 199 ++++-------------- 4 files changed, 61 insertions(+), 233 deletions(-) diff --git a/tests/baselines/test_deep_ensembles.py b/tests/baselines/test_deep_ensembles.py index 060877d4..bd989cdc 100644 --- a/tests/baselines/test_deep_ensembles.py +++ b/tests/baselines/test_deep_ensembles.py @@ -1,4 +1,6 @@ -from torch_uncertainty.baselines.classification.deep_ensembles import DeepEnsembles +from torch_uncertainty.baselines.classification.deep_ensembles import ( + DeepEnsembles, +) class TestDeepEnsembles: diff --git a/torch_uncertainty/baselines/classification/__init__.py b/torch_uncertainty/baselines/classification/__init__.py index 466f3686..1326c2e3 100644 --- a/torch_uncertainty/baselines/classification/__init__.py +++ b/torch_uncertainty/baselines/classification/__init__.py @@ -1,4 +1,4 @@ # ruff: noqa: F401 from .resnet import ResNet from .vgg import VGG -from .wideresnet import WideResNet \ No newline at end of file +from .wideresnet import WideResNet diff --git a/torch_uncertainty/baselines/regression/mlp.py b/torch_uncertainty/baselines/regression/mlp.py index 0cf0c6c6..16da62bb 100644 --- a/torch_uncertainty/baselines/regression/mlp.py +++ b/torch_uncertainty/baselines/regression/mlp.py @@ -1,44 +1,30 @@ -from argparse import ArgumentParser -from pathlib import Path -from typing import Any, Literal +from typing import Literal -import torch -from pytorch_lightning import LightningModule -from pytorch_lightning.core.saving import ( - load_hparams_from_tags_csv, - load_hparams_from_yaml, -) from torch import nn -from torch_uncertainty.baselines.utils.parser_addons import ( - add_packed_specific_args, -) from torch_uncertainty.models.mlp import mlp, packed_mlp from torch_uncertainty.routines.regression import ( - RegressionEnsemble, - RegressionSingle, + RegressionRoutine, ) -class MLP: +class MLP(RegressionRoutine): single = ["std"] ensemble = ["packed"] versions = {"std": mlp, "packed": packed_mlp} - def __new__( - cls, + def __init__( + self, num_outputs: int, in_features: int, loss: type[nn.Module], - optimization_procedure: Any, version: Literal["std", "packed"], hidden_dims: list[int], - dist_estimation: int, num_estimators: int | None = 1, alpha: float | None = None, gamma: int = 1, **kwargs, - ) -> LightningModule: + ) -> None: r"""MLP baseline for regression providing support for various versions.""" params = { "in_features": in_features, @@ -53,66 +39,17 @@ def __new__( "gamma": gamma, } - if version not in cls.versions: + if version not in self.versions: raise ValueError(f"Unknown version: {version}") - model = cls.versions[version](**params) + model = self.versions[version](**params) - kwargs.update(params) - kwargs.update({"version": version}) - # routine specific parameters - if version in cls.single: - return RegressionSingle( - model=model, - loss=loss, - optimization_procedure=optimization_procedure, - dist_estimation=dist_estimation, - **kwargs, - ) - # version in cls.versions.keys(): - return RegressionEnsemble( + # version in self.versions: + super().__init__( model=model, loss=loss, - optimization_procedure=optimization_procedure, - dist_estimation=dist_estimation, + num_estimators=num_estimators, + dist_estimation=num_outputs, mode="mean", - **kwargs, - ) - return None - - @classmethod - def load_from_checkpoint( - cls, - checkpoint_path: str | Path, - hparams_file: str | Path, - **kwargs, - ) -> LightningModule: # coverage: ignore - if hparams_file is not None: - extension = str(hparams_file).split(".")[-1] - if extension.lower() == "csv": - hparams = load_hparams_from_tags_csv(hparams_file) - elif extension.lower() in ("yml", "yaml"): - hparams = load_hparams_from_yaml(hparams_file) - else: - raise ValueError( - ".csv, .yml or .yaml is required for `hparams_file`" - ) - - hparams.update(kwargs) - checkpoint = torch.load(checkpoint_path) - obj = cls(**hparams) - obj.load_state_dict(checkpoint["state_dict"]) - return obj - - @classmethod - def add_model_specific_args(cls, parser: ArgumentParser) -> ArgumentParser: - parser = RegressionEnsemble.add_model_specific_args(parser) - parser = add_packed_specific_args(parser) - parser.add_argument( - "--version", - type=str, - choices=cls.versions.keys(), - default="std", - help=f"Variation of MLP. Choose among: {cls.versions.keys()}", ) - return parser + self.save_hyperparameters() diff --git a/torch_uncertainty/routines/regression.py b/torch_uncertainty/routines/regression.py index b210497b..a30da3cb 100644 --- a/torch_uncertainty/routines/regression.py +++ b/torch_uncertainty/routines/regression.py @@ -1,31 +1,30 @@ -from argparse import ArgumentParser -from typing import Any, Literal +from typing import Literal -import pytorch_lightning as pl -import torch import torch.nn.functional as F from einops import rearrange -from torch import nn +from lightning.pytorch.utilities.types import STEP_OUTPUT +from pytorch_lightning import LightningModule +from torch import Tensor, nn from torchmetrics import MeanSquaredError, MetricCollection from torch_uncertainty.metrics.nll import GaussianNegativeLogLikelihood -class RegressionSingle(pl.LightningModule): +class RegressionRoutine(LightningModule): def __init__( self, + dist_estimation: int, model: nn.Module, loss: type[nn.Module], - optimization_procedure: Any, - dist_estimation: int, - **kwargs, + num_estimators: int, + mode: Literal["mean", "mixture"], + out_features: int | None = 1, ) -> None: + print("Regression is Work in progress. Raise an issue if interested.") super().__init__() self.model = model self.loss = loss - self.optimization_procedure = optimization_procedure - # metrics if isinstance(dist_estimation, int): if dist_estimation <= 0: @@ -68,25 +67,36 @@ def __init__( self.val_metrics = reg_metrics.clone(prefix="reg_val/") self.test_metrics = reg_metrics.clone(prefix="reg_test/") - def configure_optimizers(self) -> Any: - return self.optimization_procedure(self) - - @property - def criterion(self) -> nn.Module: - return self.loss() + if mode == "mixture": + raise NotImplementedError( + "Mixture of gaussians not implemented yet. Raise an issue if " + "needed." + ) - def forward(self, inputs: torch.Tensor) -> torch.Tensor: - return self.model.forward(inputs) + self.mode = mode + self.num_estimators = num_estimators + self.out_features = out_features def on_train_start(self) -> None: # hyperparameters for performances - param = {} - param["storage"] = f"{get_model_size_mb(self)} MB" + init_metrics = {k: 0 for k in self.val_metrics} + init_metrics.update({k: 0 for k in self.test_metrics}) + + def forward(self, inputs: Tensor) -> Tensor: + return self.model.forward(inputs) + + @property + def criterion(self) -> nn.Module: + return self.loss() def training_step( - self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int - ): + self, batch: tuple[Tensor, Tensor], batch_idx: int + )-> STEP_OUTPUT: inputs, targets = batch + + # eventual input repeat is done in the model + targets = targets.repeat((self.num_estimators, 1)) + logits = self.forward(inputs) if self.dist_estimation == 4: @@ -106,120 +116,7 @@ def training_step( return loss def validation_step( - self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int - ) -> None: - inputs, targets = batch - logits = self.forward(inputs) - if self.dist_estimation == 4: - means = logits[..., 0] - alpha = 1 + F.softplus(logits[..., 2]) - beta = F.softplus(logits[..., 3]) - variances = beta / (alpha - 1) - self.val_metrics.gnll.update(means, targets, variances) - - targets = targets.view(means.size()) - elif self.dist_estimation == 2: - means = logits[..., 0] - variances = F.softplus(logits[..., 1]) - self.val_metrics.gnll.update(means, targets, variances) - - if means.ndim == 1: - means = means.unsqueeze(-1) - else: - means = logits.squeeze(-1) - - self.val_metrics.mse.update(means, targets) - - def validation_epoch_end( - self, outputs - ) -> None: - self.log_dict(self.val_metrics.compute()) - self.val_metrics.reset() - - def test_step( - self, - batch: tuple[torch.Tensor, torch.Tensor], - batch_idx: int, - ) -> None: - inputs, targets = batch - logits = self.forward(inputs) - - if self.dist_estimation == 4: - means = logits[..., 0] - alpha = 1 + F.softplus(logits[..., 2]) - beta = F.softplus(logits[..., 3]) - variances = beta / (alpha - 1) - self.test_metrics.gnll.update(means, targets, variances) - - targets = targets.view(means.size()) - elif self.dist_estimation == 2: - means = logits[..., 0] - variances = F.softplus(logits[..., 1]) - self.test_metrics.gnll.update(means, targets, variances) - - if means.ndim == 1: - means = means.unsqueeze(-1) - else: - means = logits.squeeze(-1) - - self.test_metrics.mse.update(means, targets) - - def test_epoch_end( - self, outputs - ) -> None: - self.log_dict( - self.test_metrics.compute(), - ) - self.test_metrics.reset() - - @staticmethod - def add_model_specific_args( - parent_parser: ArgumentParser, - ) -> ArgumentParser: - return parent_parser - - -class RegressionEnsemble(RegressionSingle): - def __init__( - self, - model: nn.Module, - loss: type[nn.Module], - optimization_procedure: Any, - dist_estimation: int, - num_estimators: int, - mode: Literal["mean", "mixture"], - out_features: int | None = 1, - **kwargs, - ) -> None: - super().__init__( - model=model, - loss=loss, - optimization_procedure=optimization_procedure, - dist_estimation=dist_estimation, - **kwargs, - ) - - if mode == "mixture": - raise NotImplementedError( - "Mixture of gaussians not implemented yet. Raise an issue if " - "needed." - ) - - self.mode = mode - self.num_estimators = num_estimators - self.out_features = out_features - - def training_step( - self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int - ): - inputs, targets = batch - - # eventual input repeat is done in the model - targets = targets.repeat((self.num_estimators, 1)) - return super().training_step((inputs, targets), batch_idx) - - def validation_step( - self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int + self, batch: tuple[Tensor, Tensor], batch_idx: int ) -> None: inputs, targets = batch logits = self.forward(inputs) @@ -250,9 +147,9 @@ def validation_step( def test_step( self, - batch: tuple[torch.Tensor, torch.Tensor], + batch: tuple[Tensor, Tensor], batch_idx: int, - dataloader_idx: int | None = 0, + dataloader_idx: int = 0, ) -> None: if dataloader_idx != 0: raise NotImplementedError( @@ -287,20 +184,12 @@ def test_step( self.test_metrics.mse.update(means, targets) - @staticmethod - def add_model_specific_args( - parent_parser: ArgumentParser, - ) -> ArgumentParser: - """Defines the routine's attributes via command-line options. - - Adds: - - ``--num_estimators``: sets :attr:`num_estimators`. - """ - parent_parser = RegressionSingle.add_model_specific_args(parent_parser) - parent_parser.add_argument( - "--num_estimators", - type=int, - default=None, - help="Number of estimators for ensemble", + def validation_epoch_end(self, outputs) -> None: + self.log_dict(self.val_metrics.compute()) + self.val_metrics.reset() + + def test_epoch_end(self, outputs) -> None: + self.log_dict( + self.test_metrics.compute(), ) - return parent_parser + self.test_metrics.reset() From 42d3835ce2fb347dc77da045a82ab81dc280dbf4 Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 14 Feb 2024 17:12:17 +0100 Subject: [PATCH 013/148] :shirt: Lint --- tests/_dummies/baseline.py | 9 +++---- .../baselines/classification/resnet.py | 22 +++++++-------- .../baselines/classification/vgg.py | 27 +++++++++---------- .../baselines/classification/wideresnet.py | 20 +++++++------- torch_uncertainty/routines/__init__.py | 2 +- torch_uncertainty/routines/regression.py | 4 +-- 6 files changed, 40 insertions(+), 44 deletions(-) diff --git a/tests/_dummies/baseline.py b/tests/_dummies/baseline.py index 93b2175b..fcffee6e 100644 --- a/tests/_dummies/baseline.py +++ b/tests/_dummies/baseline.py @@ -1,4 +1,3 @@ - from pytorch_lightning import LightningModule from torch import nn @@ -33,7 +32,7 @@ def __new__( loss=loss, format_batch_fn=nn.Identity(), log_plots=True, - num_estimators = 1 + num_estimators=1, ) # baseline_type == "ensemble": return ClassificationRoutine( @@ -42,7 +41,7 @@ def __new__( loss=loss, format_batch_fn=RepeatTarget(2), log_plots=True, - num_estimators = 2 + num_estimators=2, ) # @classmethod @@ -75,7 +74,7 @@ def __new__( model=model, loss=loss, dist_estimation=dist_estimation, - num_estimators=1 + num_estimators=1, ) # baseline_type == "ensemble": kwargs["num_estimators"] = 2 @@ -85,7 +84,7 @@ def __new__( dist_estimation=dist_estimation, mode="mean", out_features=out_features, - num_estimators=2 + num_estimators=2, ) # @classmethod diff --git a/torch_uncertainty/baselines/classification/resnet.py b/torch_uncertainty/baselines/classification/resnet.py index cb8cafc9..0147aa31 100644 --- a/torch_uncertainty/baselines/classification/resnet.py +++ b/torch_uncertainty/baselines/classification/resnet.py @@ -247,24 +247,24 @@ def __init__( if version in self.ensemble: params |= { - "num_estimators": num_estimators, - } - + "num_estimators": num_estimators, + } + if version != "mc-dropout": format_batch_fn = RepeatTarget(num_repeats=num_estimators) if version == "packed": params |= { - "alpha": alpha, - "gamma": gamma, - "pretrained": pretrained, - } - + "alpha": alpha, + "gamma": gamma, + "pretrained": pretrained, + } + elif version == "masked": params |= { - "scale": scale, - } - + "scale": scale, + } + elif version == "mimo": format_batch_fn = MIMOBatchFormat( num_estimators=num_estimators, diff --git a/torch_uncertainty/baselines/classification/vgg.py b/torch_uncertainty/baselines/classification/vgg.py index b41200d0..de1ca9e4 100644 --- a/torch_uncertainty/baselines/classification/vgg.py +++ b/torch_uncertainty/baselines/classification/vgg.py @@ -149,32 +149,29 @@ def __init__( if version == "std": params |= { - "dropout_rate": dropout_rate, - } - + "dropout_rate": dropout_rate, + } + elif version == "mc-dropout": params |= { - "dropout_rate": dropout_rate, - "num_estimators": num_estimators, - } - + "dropout_rate": dropout_rate, + "num_estimators": num_estimators, + } if version in self.ensemble: params |= { - "num_estimators": num_estimators, - } - + "num_estimators": num_estimators, + } if version != "mc-dropout": format_batch_fn = RepeatTarget(num_repeats=num_estimators) if version == "packed": params |= { - "alpha": alpha, - "style": style, - "gamma": gamma, - } - + "alpha": alpha, + "style": style, + "gamma": gamma, + } if version == "mc-dropout": # std VGGs don't have `num_estimators` del params["num_estimators"] diff --git a/torch_uncertainty/baselines/classification/wideresnet.py b/torch_uncertainty/baselines/classification/wideresnet.py index 7eea2a3f..b7496158 100644 --- a/torch_uncertainty/baselines/classification/wideresnet.py +++ b/torch_uncertainty/baselines/classification/wideresnet.py @@ -159,23 +159,23 @@ def __init__( if version in self.ensemble: params |= { - "num_estimators": num_estimators, - } - + "num_estimators": num_estimators, + } + if version != "mc-dropout": format_batch_fn = RepeatTarget(num_repeats=num_estimators) if version == "packed": params |= { - "alpha": alpha, - "gamma": gamma, - } - + "alpha": alpha, + "gamma": gamma, + } + elif version == "masked": params |= { - "scale": scale, - } - + "scale": scale, + } + elif version == "mimo": format_batch_fn = MIMOBatchFormat( num_estimators=num_estimators, diff --git a/torch_uncertainty/routines/__init__.py b/torch_uncertainty/routines/__init__.py index f850e325..2513e15e 100644 --- a/torch_uncertainty/routines/__init__.py +++ b/torch_uncertainty/routines/__init__.py @@ -1,3 +1,3 @@ # ruff: noqa: F401 from .classification import ClassificationRoutine -from .regression import RegressionRoutine \ No newline at end of file +from .regression import RegressionRoutine diff --git a/torch_uncertainty/routines/regression.py b/torch_uncertainty/routines/regression.py index a30da3cb..bc11a60d 100644 --- a/torch_uncertainty/routines/regression.py +++ b/torch_uncertainty/routines/regression.py @@ -91,7 +91,7 @@ def criterion(self) -> nn.Module: def training_step( self, batch: tuple[Tensor, Tensor], batch_idx: int - )-> STEP_OUTPUT: + ) -> STEP_OUTPUT: inputs, targets = batch # eventual input repeat is done in the model @@ -187,7 +187,7 @@ def test_step( def validation_epoch_end(self, outputs) -> None: self.log_dict(self.val_metrics.compute()) self.val_metrics.reset() - + def test_epoch_end(self, outputs) -> None: self.log_dict( self.test_metrics.compute(), From 50f59a081dc6bf4fcbb54047418a67ede7ea091d Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 14 Feb 2024 17:32:41 +0100 Subject: [PATCH 014/148] :sparkles: Add & fix configs --- .../cifar10/configs/resnet.yaml | 4 +- .../cifar10/configs/resnet18/batched.yaml | 4 +- .../cifar10/configs/resnet18/masked.yaml | 4 +- .../cifar10/configs/resnet18/mimo.yaml | 4 +- .../cifar10/configs/resnet18/packed.yaml | 4 +- .../cifar10/configs/resnet18/standard.yaml | 6 +-- .../cifar10/configs/resnet50/batched.yaml | 4 +- .../cifar10/configs/resnet50/masked.yaml | 4 +- .../cifar10/configs/resnet50/mimo.yaml | 4 +- .../cifar10/configs/resnet50/packed.yaml | 4 +- .../cifar10/configs/resnet50/standard.yaml | 6 +-- .../cifar10/configs/wideresnet28x10.yaml | 15 +----- .../configs/wideresnet28x10/batched.yaml | 4 +- .../configs/wideresnet28x10/masked.yaml | 4 +- .../cifar10/configs/wideresnet28x10/mimo.yaml | 4 +- .../configs/wideresnet28x10/packed.yaml | 4 +- .../configs/wideresnet28x10/standard.yaml | 6 +-- .../cifar100/configs/resnet.yaml | 34 ++++++++++++ .../cifar100/configs/resnet18/batched.yaml | 49 +++++++++++++++++ .../cifar100/configs/resnet18/masked.yaml | 50 ++++++++++++++++++ .../cifar100/configs/resnet18/mimo.yaml | 50 ++++++++++++++++++ .../cifar100/configs/resnet18/packed.yaml | 51 ++++++++++++++++++ .../cifar100/configs/resnet18/standard.yaml | 48 +++++++++++++++++ .../cifar100/configs/resnet50/batched.yaml | 50 ++++++++++++++++++ .../cifar100/configs/resnet50/masked.yaml | 51 ++++++++++++++++++ .../cifar100/configs/resnet50/mimo.yaml | 51 ++++++++++++++++++ .../cifar100/configs/resnet50/packed.yaml | 52 +++++++++++++++++++ .../cifar100/configs/resnet50/standard.yaml | 49 +++++++++++++++++ 28 files changed, 572 insertions(+), 48 deletions(-) create mode 100644 experiments/classification/cifar100/configs/resnet.yaml create mode 100644 experiments/classification/cifar100/configs/resnet18/batched.yaml create mode 100644 experiments/classification/cifar100/configs/resnet18/masked.yaml create mode 100644 experiments/classification/cifar100/configs/resnet18/mimo.yaml create mode 100644 experiments/classification/cifar100/configs/resnet18/packed.yaml create mode 100644 experiments/classification/cifar100/configs/resnet18/standard.yaml create mode 100644 experiments/classification/cifar100/configs/resnet50/batched.yaml create mode 100644 experiments/classification/cifar100/configs/resnet50/masked.yaml create mode 100644 experiments/classification/cifar100/configs/resnet50/mimo.yaml create mode 100644 experiments/classification/cifar100/configs/resnet50/packed.yaml create mode 100644 experiments/classification/cifar100/configs/resnet50/standard.yaml diff --git a/experiments/classification/cifar10/configs/resnet.yaml b/experiments/classification/cifar10/configs/resnet.yaml index 21352497..fc1197d7 100644 --- a/experiments/classification/cifar10/configs/resnet.yaml +++ b/experiments/classification/cifar10/configs/resnet.yaml @@ -13,7 +13,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: hp/val_acc + monitor: cls_val/acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -21,7 +21,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: hp/val_acc + monitor: cls_val/acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet18/batched.yaml b/experiments/classification/cifar10/configs/resnet18/batched.yaml index aa2c92e1..34ab9fc5 100644 --- a/experiments/classification/cifar10/configs/resnet18/batched.yaml +++ b/experiments/classification/cifar10/configs/resnet18/batched.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: hp/val_acc + monitor: cls_val/acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: hp/val_acc + monitor: cls_val/acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet18/masked.yaml b/experiments/classification/cifar10/configs/resnet18/masked.yaml index 9d92ef08..f84533ba 100644 --- a/experiments/classification/cifar10/configs/resnet18/masked.yaml +++ b/experiments/classification/cifar10/configs/resnet18/masked.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: hp/val_acc + monitor: cls_val/acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: hp/val_acc + monitor: cls_val/acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet18/mimo.yaml b/experiments/classification/cifar10/configs/resnet18/mimo.yaml index 92414eb4..8577ab3b 100644 --- a/experiments/classification/cifar10/configs/resnet18/mimo.yaml +++ b/experiments/classification/cifar10/configs/resnet18/mimo.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: hp/val_acc + monitor: cls_val/acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: hp/val_acc + monitor: cls_val/acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet18/packed.yaml b/experiments/classification/cifar10/configs/resnet18/packed.yaml index a9e9479e..1e6853df 100644 --- a/experiments/classification/cifar10/configs/resnet18/packed.yaml +++ b/experiments/classification/cifar10/configs/resnet18/packed.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: hp/val_acc + monitor: cls_val/acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: hp/val_acc + monitor: cls_val/acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet18/standard.yaml b/experiments/classification/cifar10/configs/resnet18/standard.yaml index e6cea671..2d835a6a 100644 --- a/experiments/classification/cifar10/configs/resnet18/standard.yaml +++ b/experiments/classification/cifar10/configs/resnet18/standard.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: hp/val_acc + monitor: cls_val/acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,14 +23,14 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: hp/val_acc + monitor: cls_val/acc patience: 1000 check_finite: true model: num_classes: 10 in_channels: 3 loss: torch.nn.CrossEntropyLoss - version: vanilla + version: standard arch: 18 style: cifar data: diff --git a/experiments/classification/cifar10/configs/resnet50/batched.yaml b/experiments/classification/cifar10/configs/resnet50/batched.yaml index ec519698..60841124 100644 --- a/experiments/classification/cifar10/configs/resnet50/batched.yaml +++ b/experiments/classification/cifar10/configs/resnet50/batched.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: hp/val_acc + monitor: cls_val/acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: hp/val_acc + monitor: cls_val/acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet50/masked.yaml b/experiments/classification/cifar10/configs/resnet50/masked.yaml index 42efac31..eea7ac81 100644 --- a/experiments/classification/cifar10/configs/resnet50/masked.yaml +++ b/experiments/classification/cifar10/configs/resnet50/masked.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: hp/val_acc + monitor: cls_val/acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: hp/val_acc + monitor: cls_val/acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet50/mimo.yaml b/experiments/classification/cifar10/configs/resnet50/mimo.yaml index 906207ba..cd6681a3 100644 --- a/experiments/classification/cifar10/configs/resnet50/mimo.yaml +++ b/experiments/classification/cifar10/configs/resnet50/mimo.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: hp/val_acc + monitor: cls_val/acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: hp/val_acc + monitor: cls_val/acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet50/packed.yaml b/experiments/classification/cifar10/configs/resnet50/packed.yaml index 4a8b057e..816ddb33 100644 --- a/experiments/classification/cifar10/configs/resnet50/packed.yaml +++ b/experiments/classification/cifar10/configs/resnet50/packed.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: hp/val_acc + monitor: cls_val/acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: hp/val_acc + monitor: cls_val/acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet50/standard.yaml b/experiments/classification/cifar10/configs/resnet50/standard.yaml index f57e9d07..e3ff7407 100644 --- a/experiments/classification/cifar10/configs/resnet50/standard.yaml +++ b/experiments/classification/cifar10/configs/resnet50/standard.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: hp/val_acc + monitor: cls_val/acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,14 +23,14 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: hp/val_acc + monitor: cls_val/acc patience: 1000 check_finite: true model: num_classes: 10 in_channels: 3 loss: torch.nn.CrossEntropyLoss - version: vanilla + version: standard arch: 50 style: cifar data: diff --git a/experiments/classification/cifar10/configs/wideresnet28x10.yaml b/experiments/classification/cifar10/configs/wideresnet28x10.yaml index a8c05f36..f9423cfb 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: hp/val_acc + monitor: cls_val/acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: hp/val_acc + monitor: cls_val/acc patience: 1000 check_finite: true model: @@ -34,14 +34,3 @@ model: data: root: ./data batch_size: 128 -optimizer: - lr: 0.1 - momentum: 0.9 - weight_decay: 5e-4 - nesterov: true -lr_scheduler: - milestones: - - 60 - - 120 - - 160 - gamma: 0.2 diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/batched.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/batched.yaml index 48a9d817..59bb6213 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10/batched.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10/batched.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: hp/val_acc + monitor: cls_val/acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: hp/val_acc + monitor: cls_val/acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/masked.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/masked.yaml index 3fb765dd..363cf464 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10/masked.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10/masked.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: hp/val_acc + monitor: cls_val/acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: hp/val_acc + monitor: cls_val/acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/mimo.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/mimo.yaml index e9149c0e..b186909b 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10/mimo.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10/mimo.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: hp/val_acc + monitor: cls_val/acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: hp/val_acc + monitor: cls_val/acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/packed.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/packed.yaml index 10c757b7..0ae5c3ca 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10/packed.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10/packed.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: hp/val_acc + monitor: cls_val/acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: hp/val_acc + monitor: cls_val/acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/standard.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/standard.yaml index 78b8a74b..200c571c 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10/standard.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10/standard.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: hp/val_acc + monitor: cls_val/acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,14 +23,14 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: hp/val_acc + monitor: cls_val/acc patience: 1000 check_finite: true model: num_classes: 10 in_channels: 3 loss: torch.nn.CrossEntropyLoss - version: vanilla + version: standard style: cifar data: root: ./data diff --git a/experiments/classification/cifar100/configs/resnet.yaml b/experiments/classification/cifar100/configs/resnet.yaml new file mode 100644 index 00000000..fc1197d7 --- /dev/null +++ b/experiments/classification/cifar100/configs/resnet.yaml @@ -0,0 +1,34 @@ +# lightning.pytorch==2.1.3 +seed_everything: true +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/ + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: cls_val/acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: cls_val/acc + patience: 1000 + check_finite: true +model: + num_classes: 10 + in_channels: 3 + loss: torch.nn.CrossEntropyLoss + style: cifar +data: + root: ./data + batch_size: 128 diff --git a/experiments/classification/cifar100/configs/resnet18/batched.yaml b/experiments/classification/cifar100/configs/resnet18/batched.yaml new file mode 100644 index 00000000..358e0f62 --- /dev/null +++ b/experiments/classification/cifar100/configs/resnet18/batched.yaml @@ -0,0 +1,49 @@ +# lightning.pytorch==2.1.3 +seed_everything: true +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 75 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/resnet18 + name: batched + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: cls_val/acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: cls_val/acc + patience: 1000 + check_finite: true +model: + num_classes: 100 + in_channels: 3 + loss: torch.nn.CrossEntropyLoss + version: batched + arch: 18 + style: cifar + num_estimators: 4 +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.05 + momentum: 0.9 + weight_decay: 1e-4 + nesterov: true +lr_scheduler: + milestones: + - 25 + - 50 + gamma: 0.1 diff --git a/experiments/classification/cifar100/configs/resnet18/masked.yaml b/experiments/classification/cifar100/configs/resnet18/masked.yaml new file mode 100644 index 00000000..93c57ae4 --- /dev/null +++ b/experiments/classification/cifar100/configs/resnet18/masked.yaml @@ -0,0 +1,50 @@ +# lightning.pytorch==2.1.3 +seed_everything: true +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 75 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/resnet18 + name: masked + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: cls_val/acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: cls_val/acc + patience: 1000 + check_finite: true +model: + num_classes: 100 + in_channels: 3 + loss: torch.nn.CrossEntropyLoss + version: masked + arch: 18 + style: cifar + num_estimators: 4 + scale: 2 +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.05 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 25 + - 50 + gamma: 0.1 diff --git a/experiments/classification/cifar100/configs/resnet18/mimo.yaml b/experiments/classification/cifar100/configs/resnet18/mimo.yaml new file mode 100644 index 00000000..3e6da8e8 --- /dev/null +++ b/experiments/classification/cifar100/configs/resnet18/mimo.yaml @@ -0,0 +1,50 @@ +# lightning.pytorch==2.1.3 +seed_everything: true +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 75 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/resnet18 + name: mimo + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: cls_val/acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: cls_val/acc + patience: 1000 + check_finite: true +model: + num_classes: 100 + in_channels: 3 + loss: torch.nn.CrossEntropyLoss + version: mimo + arch: 18 + style: cifar + num_estimators: 4 + rho: 1.0 +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.05 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 25 + - 50 + gamma: 0.1 diff --git a/experiments/classification/cifar100/configs/resnet18/packed.yaml b/experiments/classification/cifar100/configs/resnet18/packed.yaml new file mode 100644 index 00000000..39f384db --- /dev/null +++ b/experiments/classification/cifar100/configs/resnet18/packed.yaml @@ -0,0 +1,51 @@ +# lightning.pytorch==2.1.3 +seed_everything: true +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 75 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/resnet18 + name: packed + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: cls_val/acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: cls_val/acc + patience: 1000 + check_finite: true +model: + num_classes: 100 + in_channels: 3 + loss: torch.nn.CrossEntropyLoss + version: packed + arch: 18 + style: cifar + num_estimators: 4 + alpha: 2 + gamma: 2 +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.05 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 25 + - 50 + gamma: 0.1 diff --git a/experiments/classification/cifar100/configs/resnet18/standard.yaml b/experiments/classification/cifar100/configs/resnet18/standard.yaml new file mode 100644 index 00000000..41452d61 --- /dev/null +++ b/experiments/classification/cifar100/configs/resnet18/standard.yaml @@ -0,0 +1,48 @@ +# lightning.pytorch==2.1.3 +seed_everything: true +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 75 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/resnet18 + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: cls_val/acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: cls_val/acc + patience: 1000 + check_finite: true +model: + num_classes: 100 + in_channels: 3 + loss: torch.nn.CrossEntropyLoss + version: standard + arch: 18 + style: cifar +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.05 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 25 + - 50 + gamma: 0.1 diff --git a/experiments/classification/cifar100/configs/resnet50/batched.yaml b/experiments/classification/cifar100/configs/resnet50/batched.yaml new file mode 100644 index 00000000..1a344716 --- /dev/null +++ b/experiments/classification/cifar100/configs/resnet50/batched.yaml @@ -0,0 +1,50 @@ +# lightning.pytorch==2.1.3 +seed_everything: true +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 200 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/resnet50 + name: batched + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: cls_val/acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: cls_val/acc + patience: 1000 + check_finite: true +model: + num_classes: 100 + in_channels: 3 + loss: torch.nn.CrossEntropyLoss + version: batched + arch: 50 + style: cifar + num_estimators: 4 +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.08 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 60 + - 120 + - 160 + gamma: 0.2 diff --git a/experiments/classification/cifar100/configs/resnet50/masked.yaml b/experiments/classification/cifar100/configs/resnet50/masked.yaml new file mode 100644 index 00000000..be065061 --- /dev/null +++ b/experiments/classification/cifar100/configs/resnet50/masked.yaml @@ -0,0 +1,51 @@ +# lightning.pytorch==2.1.3 +seed_everything: true +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 200 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/resnet50 + name: masked + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: cls_val/acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: cls_val/acc + patience: 1000 + check_finite: true +model: + num_classes: 100 + in_channels: 3 + loss: torch.nn.CrossEntropyLoss + version: masked + arch: 50 + style: cifar + num_estimators: 4 + scale: 2 +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.1 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 60 + - 120 + - 160 + gamma: 0.2 diff --git a/experiments/classification/cifar100/configs/resnet50/mimo.yaml b/experiments/classification/cifar100/configs/resnet50/mimo.yaml new file mode 100644 index 00000000..76f4f8b8 --- /dev/null +++ b/experiments/classification/cifar100/configs/resnet50/mimo.yaml @@ -0,0 +1,51 @@ +# lightning.pytorch==2.1.3 +seed_everything: true +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 200 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/resnet50 + name: mimo + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: cls_val/acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: cls_val/acc + patience: 1000 + check_finite: true +model: + num_classes: 100 + in_channels: 3 + loss: torch.nn.CrossEntropyLoss + version: mimo + arch: 50 + style: cifar + num_estimators: 4 + rho: 1.0 +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.1 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 60 + - 120 + - 160 + gamma: 0.2 diff --git a/experiments/classification/cifar100/configs/resnet50/packed.yaml b/experiments/classification/cifar100/configs/resnet50/packed.yaml new file mode 100644 index 00000000..6bc068e7 --- /dev/null +++ b/experiments/classification/cifar100/configs/resnet50/packed.yaml @@ -0,0 +1,52 @@ +# lightning.pytorch==2.1.3 +seed_everything: true +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 200 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/resnet50 + name: packed + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: cls_val/acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: cls_val/acc + patience: 1000 + check_finite: true +model: + num_classes: 100 + in_channels: 3 + loss: torch.nn.CrossEntropyLoss + version: packed + arch: 50 + style: cifar + num_estimators: 4 + alpha: 2 + gamma: 2 +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.1 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 60 + - 120 + - 160 + gamma: 0.2 diff --git a/experiments/classification/cifar100/configs/resnet50/standard.yaml b/experiments/classification/cifar100/configs/resnet50/standard.yaml new file mode 100644 index 00000000..d7580437 --- /dev/null +++ b/experiments/classification/cifar100/configs/resnet50/standard.yaml @@ -0,0 +1,49 @@ +# lightning.pytorch==2.1.3 +seed_everything: true +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 200 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/resnet50 + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: cls_val/acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: cls_val/acc + patience: 1000 + check_finite: true +model: + num_classes: 100 + in_channels: 3 + loss: torch.nn.CrossEntropyLoss + version: standard + arch: 50 + style: cifar +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.1 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 60 + - 120 + - 160 + gamma: 0.2 From 42fd89b31ef7cc3774f34df71eda6eed9ff65355 Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 14 Feb 2024 17:39:51 +0100 Subject: [PATCH 015/148] :hammer: Refactor datamodules --- torch_uncertainty/datamodules/__init__.py | 11 ++++++----- .../datamodules/{ => classification}/cifar10.py | 3 +-- .../datamodules/{ => classification}/cifar100.py | 3 +-- .../datamodules/{ => classification}/imagenet.py | 0 .../datamodules/{ => classification}/mnist.py | 0 .../datamodules/{ => classification}/tiny_imagenet.py | 3 +-- .../datamodules/segmentation/__init__.py | 2 ++ 7 files changed, 11 insertions(+), 11 deletions(-) rename torch_uncertainty/datamodules/{ => classification}/cifar10.py (99%) rename torch_uncertainty/datamodules/{ => classification}/cifar100.py (99%) rename torch_uncertainty/datamodules/{ => classification}/imagenet.py (100%) rename torch_uncertainty/datamodules/{ => classification}/mnist.py (100%) rename torch_uncertainty/datamodules/{ => classification}/tiny_imagenet.py (99%) diff --git a/torch_uncertainty/datamodules/__init__.py b/torch_uncertainty/datamodules/__init__.py index 90b9f0eb..4c236ef5 100644 --- a/torch_uncertainty/datamodules/__init__.py +++ b/torch_uncertainty/datamodules/__init__.py @@ -1,7 +1,8 @@ # ruff: noqa: F401 -from .cifar10 import CIFAR10DataModule -from .cifar100 import CIFAR100DataModule -from .imagenet import ImageNetDataModule -from .mnist import MNISTDataModule -from .tiny_imagenet import TinyImageNetDataModule +from .classification.cifar10 import CIFAR10DataModule +from .classification.cifar100 import CIFAR100DataModule +from .classification.imagenet import ImageNetDataModule +from .classification.mnist import MNISTDataModule +from .classification.tiny_imagenet import TinyImageNetDataModule +from .segmentation import CamVidDataModule from .uci_regression import UCIDataModule diff --git a/torch_uncertainty/datamodules/cifar10.py b/torch_uncertainty/datamodules/classification/cifar10.py similarity index 99% rename from torch_uncertainty/datamodules/cifar10.py rename to torch_uncertainty/datamodules/classification/cifar10.py index 0f0766fd..0ab6948c 100644 --- a/torch_uncertainty/datamodules/cifar10.py +++ b/torch_uncertainty/datamodules/classification/cifar10.py @@ -9,12 +9,11 @@ from torch.utils.data import DataLoader, random_split from torchvision.datasets import CIFAR10, SVHN +from torch_uncertainty.datamodules.abstract import AbstractDataModule from torch_uncertainty.datasets import AggregatedDataset from torch_uncertainty.datasets.classification import CIFAR10C, CIFAR10H from torch_uncertainty.transforms import Cutout -from .abstract import AbstractDataModule - class CIFAR10DataModule(AbstractDataModule): num_classes = 10 diff --git a/torch_uncertainty/datamodules/cifar100.py b/torch_uncertainty/datamodules/classification/cifar100.py similarity index 99% rename from torch_uncertainty/datamodules/cifar100.py rename to torch_uncertainty/datamodules/classification/cifar100.py index ede0da5e..b4ae095c 100644 --- a/torch_uncertainty/datamodules/cifar100.py +++ b/torch_uncertainty/datamodules/classification/cifar100.py @@ -10,12 +10,11 @@ from torch.utils.data import DataLoader, random_split from torchvision.datasets import CIFAR100, SVHN +from torch_uncertainty.datamodules.abstract import AbstractDataModule from torch_uncertainty.datasets import AggregatedDataset from torch_uncertainty.datasets.classification import CIFAR100C from torch_uncertainty.transforms import Cutout -from .abstract import AbstractDataModule - class CIFAR100DataModule(AbstractDataModule): num_classes = 100 diff --git a/torch_uncertainty/datamodules/imagenet.py b/torch_uncertainty/datamodules/classification/imagenet.py similarity index 100% rename from torch_uncertainty/datamodules/imagenet.py rename to torch_uncertainty/datamodules/classification/imagenet.py diff --git a/torch_uncertainty/datamodules/mnist.py b/torch_uncertainty/datamodules/classification/mnist.py similarity index 100% rename from torch_uncertainty/datamodules/mnist.py rename to torch_uncertainty/datamodules/classification/mnist.py diff --git a/torch_uncertainty/datamodules/tiny_imagenet.py b/torch_uncertainty/datamodules/classification/tiny_imagenet.py similarity index 99% rename from torch_uncertainty/datamodules/tiny_imagenet.py rename to torch_uncertainty/datamodules/classification/tiny_imagenet.py index a6ad7d08..a537efff 100644 --- a/torch_uncertainty/datamodules/tiny_imagenet.py +++ b/torch_uncertainty/datamodules/classification/tiny_imagenet.py @@ -9,10 +9,9 @@ from torch.utils.data import ConcatDataset, DataLoader from torchvision.datasets import DTD, SVHN +from torch_uncertainty.datamodules.abstract import AbstractDataModule from torch_uncertainty.datasets.classification import ImageNetO, TinyImageNet -from .abstract import AbstractDataModule - class TinyImageNetDataModule(AbstractDataModule): num_classes = 200 diff --git a/torch_uncertainty/datamodules/segmentation/__init__.py b/torch_uncertainty/datamodules/segmentation/__init__.py index e69de29b..008b7252 100644 --- a/torch_uncertainty/datamodules/segmentation/__init__.py +++ b/torch_uncertainty/datamodules/segmentation/__init__.py @@ -0,0 +1,2 @@ +# ruff: noqa: F401 +from .camvid import CamVidDataModule From dbcf99d35dd30768ef6e9e82189a50eafc7530fd Mon Sep 17 00:00:00 2001 From: Olivier Date: Thu, 15 Feb 2024 10:35:56 +0100 Subject: [PATCH 016/148] :zap: Update all datamodules & tests --- tests/datamodules/test_cifar100_datamodule.py | 64 ++++++++----------- tests/datamodules/test_imagenet_datamodule.py | 56 ++++++++-------- tests/datamodules/test_mnist_datamodule.py | 25 ++------ .../test_tiny_imagenet_datamodule.py | 40 +++++------- .../test_uci_regression_datamodule.py | 10 +-- torch_uncertainty/datamodules/abstract.py | 19 +----- .../datamodules/classification/imagenet.py | 23 +------ .../datamodules/classification/mnist.py | 18 +----- .../classification/tiny_imagenet.py | 52 +++++++-------- .../datamodules/uci_regression.py | 12 ---- 10 files changed, 108 insertions(+), 211 deletions(-) diff --git a/tests/datamodules/test_cifar100_datamodule.py b/tests/datamodules/test_cifar100_datamodule.py index 5ca47529..e032d332 100644 --- a/tests/datamodules/test_cifar100_datamodule.py +++ b/tests/datamodules/test_cifar100_datamodule.py @@ -1,6 +1,3 @@ -from argparse import ArgumentParser -from pathlib import Path - import pytest from torchvision.datasets import CIFAR100 @@ -13,14 +10,7 @@ class TestCIFAR100DataModule: """Testing the CIFAR100DataModule datamodule class.""" def test_cifar100(self): - parser = ArgumentParser() - parser = CIFAR100DataModule.add_argparse_args(parser) - - # Simulate that cutout is set to 8 - args = parser.parse_args("") - args.cutout = 8 - - dm = CIFAR100DataModule(**vars(args)) + dm = CIFAR100DataModule(root="./data/", batch_size=128, cutout=16) assert dm.dataset == CIFAR100 assert isinstance(dm.train_transform.transforms[2], Cutout) @@ -41,18 +31,21 @@ def test_cifar100(self): dm.setup("test") dm.test_dataloader() - args.test_alt = "c" - args.cutout = 0 - args.root = Path(args.root) - dm = CIFAR100DataModule(**vars(args)) + dm = CIFAR100DataModule( + root="./data/", batch_size=128, cutout=0, test_alt="c" + ) dm.dataset = DummyClassificationDataset with pytest.raises(ValueError): dm.setup() - args.test_alt = None - args.num_dataloaders = 2 - args.val_split = 0.1 - dm = CIFAR100DataModule(**vars(args)) + dm = CIFAR100DataModule( + root="./data/", + batch_size=128, + cutout=0, + test_alt=None, + val_split=0.1, + num_dataloaders=2, + ) dm.dataset = DummyClassificationDataset dm.ood_dataset = DummyClassificationDataset dm.ood_dataset = DummyClassificationDataset @@ -62,27 +55,25 @@ def test_cifar100(self): with pytest.raises(ValueError): dm.setup("other") - args.num_dataloaders = 1 - args.cutout = 8 - args.randaugment = True with pytest.raises(ValueError): - dm = CIFAR100DataModule(**vars(args)) + dm = CIFAR100DataModule( + root="./data/", + batch_size=128, + num_dataloaders=1, + cutout=8, + randaugment=True, + ) - args.cutout = None - dm = CIFAR100DataModule(**vars(args)) - args.randaugment = False + dm = CIFAR100DataModule( + root="./data/", batch_size=128, randaugment=True + ) - args.auto_augment = "rand-m9-n2-mstd0.5" - dm = CIFAR100DataModule(**vars(args)) + dm = CIFAR100DataModule( + root="./data/", batch_size=128, auto_augment="rand-m9-n2-mstd0.5" + ) def test_cifar100_cv(self): - parser = ArgumentParser() - parser = CIFAR100DataModule.add_argparse_args(parser) - - # Simulate that cutout is set to 8 - args = parser.parse_args("") - - dm = CIFAR100DataModule(**vars(args)) + dm = CIFAR100DataModule(root="./data/", batch_size=128) dm.dataset = ( lambda root, train, download, transform: DummyClassificationDataset( root, @@ -94,8 +85,7 @@ def test_cifar100_cv(self): ) dm.make_cross_val_splits(2, 1) - args.val_split = 0.1 - dm = CIFAR100DataModule(**vars(args)) + dm = CIFAR100DataModule(root="./data/", batch_size=128, val_split=0.1) dm.dataset = ( lambda root, train, download, transform: DummyClassificationDataset( root, diff --git a/tests/datamodules/test_imagenet_datamodule.py b/tests/datamodules/test_imagenet_datamodule.py index 23af2e63..e37ea442 100644 --- a/tests/datamodules/test_imagenet_datamodule.py +++ b/tests/datamodules/test_imagenet_datamodule.py @@ -1,5 +1,4 @@ -import pathlib -from argparse import ArgumentParser +from pathlib import Path import pytest from torchvision.datasets import ImageNet @@ -12,12 +11,7 @@ class TestImageNetDataModule: """Testing the ImageNetDataModule datamodule class.""" def test_imagenet(self): - parser = ArgumentParser() - parser = ImageNetDataModule.add_argparse_args(parser) - - args = parser.parse_args("") - args.val_split = 0.1 - dm = ImageNetDataModule(**vars(args)) + dm = ImageNetDataModule(root="./data/", batch_size=128, val_split=0.1) assert dm.dataset == ImageNet @@ -27,11 +21,8 @@ def test_imagenet(self): dm.setup() dm.setup("test") - args.val_split = ( - pathlib.Path(__file__).parent.resolve() - / "../assets/dummy_indices.yaml" - ) - dm = ImageNetDataModule(**vars(args)) + path = Path(__file__).parent.resolve() / "../assets/dummy_indices.yaml" + dm = ImageNetDataModule(root="./data/", batch_size=128, val_split=path) dm.dataset = DummyClassificationDataset dm.ood_dataset = DummyClassificationDataset dm.setup("fit") @@ -55,21 +46,22 @@ def test_imagenet(self): dm.setup("other") for test_alt in ["r", "o", "a"]: - args.test_alt = test_alt - dm = ImageNetDataModule(**vars(args)) + dm = ImageNetDataModule( + root="./data/", batch_size=128, test_alt=test_alt + ) with pytest.raises(ValueError): dm.setup() - args.test_alt = "x" with pytest.raises(ValueError): - dm = ImageNetDataModule(**vars(args)) - - args.test_alt = None + dm = ImageNetDataModule( + root="./data/", batch_size=128, test_alt="x" + ) for ood_ds in ["inaturalist", "imagenet-o", "textures", "openimage-o"]: - args.ood_ds = ood_ds - dm = ImageNetDataModule(**vars(args)) + dm = ImageNetDataModule( + root="./data/", batch_size=128, ood_ds=ood_ds + ) if ood_ds == "inaturalist": dm.eval_ood = True dm.dataset = DummyClassificationDataset @@ -78,21 +70,23 @@ def test_imagenet(self): dm.setup("test") dm.test_dataloader() - args.ood_ds = "other" with pytest.raises(ValueError): - dm = ImageNetDataModule(**vars(args)) - - args.ood_ds = "svhn" + dm = ImageNetDataModule( + root="./data/", batch_size=128, ood_ds="other" + ) for procedure in ["ViT", "A3"]: - args.procedure = procedure - dm = ImageNetDataModule(**vars(args)) + dm = ImageNetDataModule( + root="./data/", + batch_size=128, + ood_ds="svhn", + procedure=procedure, + ) - args.procedure = "A2" with pytest.raises(ValueError): - dm = ImageNetDataModule(**vars(args)) + dm = ImageNetDataModule( + root="./data/", batch_size=128, procedure="A2" + ) - args.procedure = None - args.rand_augment_opt = "rand-m9-n2-mstd0.5" with pytest.raises(FileNotFoundError): dm._verify_splits(split="test") diff --git a/tests/datamodules/test_mnist_datamodule.py b/tests/datamodules/test_mnist_datamodule.py index 36255381..234eb613 100644 --- a/tests/datamodules/test_mnist_datamodule.py +++ b/tests/datamodules/test_mnist_datamodule.py @@ -1,6 +1,3 @@ -from argparse import ArgumentParser -from pathlib import Path - import pytest from torch import nn from torchvision.datasets import MNIST @@ -14,28 +11,20 @@ class TestMNISTDataModule: """Testing the MNISTDataModule datamodule class.""" def test_mnist_cutout(self): - parser = ArgumentParser() - parser = MNISTDataModule.add_argparse_args(parser) - - # Simulate that cutout is set to 16 - args = parser.parse_args("") - args.cutout = 16 - args.val_split = 0.1 - dm = MNISTDataModule(**vars(args)) + dm = MNISTDataModule( + root="./data/", batch_size=128, cutout=16, val_split=0.1 + ) assert dm.dataset == MNIST assert isinstance(dm.train_transform.transforms[0], Cutout) - args.root = Path(args.root) - args.ood_ds = "not" - args.cutout = 0 - args.val_split = 0 - dm = MNISTDataModule(**vars(args)) + dm = MNISTDataModule( + root="./data/", batch_size=128, ood_ds="not", cutout=0, val_split=0 + ) assert isinstance(dm.train_transform.transforms[0], nn.Identity) - args.ood_ds = "other" with pytest.raises(ValueError): - MNISTDataModule(**vars(args)) + MNISTDataModule(root="./data/", batch_size=128, ood_ds="other") dm.dataset = DummyClassificationDataset dm.ood_dataset = DummyClassificationDataset diff --git a/tests/datamodules/test_tiny_imagenet_datamodule.py b/tests/datamodules/test_tiny_imagenet_datamodule.py index 954e07ef..ca519347 100644 --- a/tests/datamodules/test_tiny_imagenet_datamodule.py +++ b/tests/datamodules/test_tiny_imagenet_datamodule.py @@ -1,5 +1,3 @@ -from argparse import ArgumentParser - import pytest from tests._dummies.dataset import DummyClassificationDataset @@ -11,24 +9,25 @@ class TestTinyImageNetDataModule: """Testing the TinyImageNetDataModule datamodule class.""" def test_tiny_imagenet(self): - parser = ArgumentParser() - parser = TinyImageNetDataModule.add_argparse_args(parser) - - args = parser.parse_args("") - dm = TinyImageNetDataModule(**vars(args)) + dm = TinyImageNetDataModule(root="./data/", batch_size=128) assert dm.dataset == TinyImageNet - args.rand_augment_opt = "rand-m9-n3-mstd0.5" - args.ood_ds = "imagenet-o" - dm = TinyImageNetDataModule(**vars(args)) + dm = TinyImageNetDataModule( + root="./data/", + batch_size=128, + rand_augment_opt="rand-m9-n3-mstd0.5", + ood_ds="imagenet-o", + ) - args.ood_ds = "textures" - dm = TinyImageNetDataModule(**vars(args)) + dm = TinyImageNetDataModule( + root="./data/", batch_size=128, ood_ds="textures" + ) - args.ood_ds = "other" with pytest.raises(ValueError): - TinyImageNetDataModule(**vars(args)) + TinyImageNetDataModule( + root="./data/", batch_size=128, ood_ds="other" + ) dm.dataset = DummyClassificationDataset dm.ood_dataset = DummyClassificationDataset @@ -50,20 +49,15 @@ def test_tiny_imagenet(self): dm.test_dataloader() def test_tiny_imagenet_cv(self): - parser = ArgumentParser() - parser = TinyImageNetDataModule.add_argparse_args(parser) - - # Simulate that cutout is set to 8 - args = parser.parse_args("") - - dm = TinyImageNetDataModule(**vars(args)) + dm = TinyImageNetDataModule(root="./data/", batch_size=128) dm.dataset = lambda root, split, transform: DummyClassificationDataset( root, split=split, transform=transform, num_images=20 ) dm.make_cross_val_splits(2, 1) - args.val_split = 0.1 - dm = TinyImageNetDataModule(**vars(args)) + dm = TinyImageNetDataModule( + root="./data/", batch_size=128, val_split=0.1 + ) dm.dataset = lambda root, split, transform: DummyClassificationDataset( root, split=split, transform=transform, num_images=20 ) diff --git a/tests/datamodules/test_uci_regression_datamodule.py b/tests/datamodules/test_uci_regression_datamodule.py index 7094ce31..9c8155fa 100644 --- a/tests/datamodules/test_uci_regression_datamodule.py +++ b/tests/datamodules/test_uci_regression_datamodule.py @@ -1,4 +1,3 @@ -from argparse import ArgumentParser from functools import partial from tests._dummies.dataset import DummyRegressionDataset @@ -9,12 +8,9 @@ class TestUCIDataModule: """Testing the UCIDataModule datamodule class.""" def test_uci_regression(self): - parser = ArgumentParser() - parser = UCIDataModule.add_argparse_args(parser) - - args = parser.parse_args("") - - dm = UCIDataModule(dataset_name="kin8nm", **vars(args)) + dm = UCIDataModule( + dataset_name="kin8nm", root="./data/", batch_size=128 + ) dm.dataset = partial(DummyRegressionDataset, num_samples=64) dm.prepare_data() diff --git a/torch_uncertainty/datamodules/abstract.py b/torch_uncertainty/datamodules/abstract.py index bada3daf..c5c95616 100644 --- a/torch_uncertainty/datamodules/abstract.py +++ b/torch_uncertainty/datamodules/abstract.py @@ -1,6 +1,5 @@ -from argparse import ArgumentParser from pathlib import Path -from typing import Any, Literal +from typing import Literal from lightning.pytorch.core import LightningDataModule from numpy.typing import ArrayLike @@ -148,22 +147,6 @@ def make_cross_val_splits( return cv_dm - @classmethod - def add_argparse_args( - cls, - parent_parser: ArgumentParser, - **kwargs: Any, - ) -> ArgumentParser: - p = parent_parser.add_argument_group("datamodule") - p.add_argument("--root", type=str, default="./data/") - p.add_argument("--batch_size", type=int, default=128) - p.add_argument("--val_split", type=float, default=None) - p.add_argument("--num_workers", type=int, default=4) - p.add_argument("--use_cv", action="store_true") - p.add_argument("--n_splits", type=int, default=10) - p.add_argument("--train_over", type=int, default=4) - return parent_parser - class CrossValDataModule(AbstractDataModule): def __init__( diff --git a/torch_uncertainty/datamodules/classification/imagenet.py b/torch_uncertainty/datamodules/classification/imagenet.py index f8935d44..e756b49d 100644 --- a/torch_uncertainty/datamodules/classification/imagenet.py +++ b/torch_uncertainty/datamodules/classification/imagenet.py @@ -1,7 +1,6 @@ import copy -from argparse import ArgumentParser from pathlib import Path -from typing import Any, Literal +from typing import Literal import torchvision.transforms as T import yaml @@ -49,7 +48,6 @@ def __init__( num_workers: int = 1, pin_memory: bool = True, persistent_workers: bool = True, - **kwargs, ) -> None: """DataModule for ImageNet. @@ -257,25 +255,6 @@ def test_dataloader(self) -> list[DataLoader]: dataloader.append(self._data_loader(self.ood)) return dataloader - @classmethod - def add_argparse_args( - cls, - parent_parser: ArgumentParser, - **kwargs: Any, - ) -> ArgumentParser: - p = super().add_argparse_args(parent_parser) - - # Arguments for ImageNet - p.add_argument("--eval-ood", action="store_true") - p.add_argument("--ood_ds", choices=cls.ood_datasets, default="svhn") - p.add_argument("--test_alt", choices=cls.test_datasets, default=None) - p.add_argument("--procedure", choices=["ViT", "A3"], default=None) - p.add_argument("--train_size", type=int, default=224) - p.add_argument( - "--rand_augment", dest="rand_augment_opt", type=str, default=None - ) - return parent_parser - def read_indices(path: Path) -> list[str]: # coverage: ignore """Read a file and return its lines as a list. diff --git a/torch_uncertainty/datamodules/classification/mnist.py b/torch_uncertainty/datamodules/classification/mnist.py index 19e5154a..e4ef8a1f 100644 --- a/torch_uncertainty/datamodules/classification/mnist.py +++ b/torch_uncertainty/datamodules/classification/mnist.py @@ -1,6 +1,5 @@ -from argparse import ArgumentParser from pathlib import Path -from typing import Any, Literal +from typing import Literal import torchvision.transforms as T from torch import nn @@ -31,7 +30,6 @@ def __init__( test_alt: Literal["c"] | None = None, pin_memory: bool = True, persistent_workers: bool = True, - **kwargs, ) -> None: """DataModule for MNIST. @@ -158,17 +156,3 @@ def test_dataloader(self) -> list[DataLoader]: if self.eval_ood: dataloader.append(self._data_loader(self.ood)) return dataloader - - @classmethod - def add_argparse_args( - cls, - parent_parser: ArgumentParser, - **kwargs: Any, - ) -> ArgumentParser: - p = super().add_argparse_args(parent_parser) - - # Arguments for MNIST - p.add_argument("--eval-ood", action="store_true") - p.add_argument("--ood_ds", choices=cls.ood_datasets, default="fashion") - p.add_argument("--test_alt", choices=["c"], default=None) - return parent_parser diff --git a/torch_uncertainty/datamodules/classification/tiny_imagenet.py b/torch_uncertainty/datamodules/classification/tiny_imagenet.py index a537efff..35f4894b 100644 --- a/torch_uncertainty/datamodules/classification/tiny_imagenet.py +++ b/torch_uncertainty/datamodules/classification/tiny_imagenet.py @@ -1,12 +1,12 @@ -from argparse import ArgumentParser from pathlib import Path -from typing import Any, Literal +from typing import Literal +import numpy as np import torchvision.transforms as T from numpy.typing import ArrayLike from timm.data.auto_augment import rand_augment_transform from torch import nn -from torch.utils.data import ConcatDataset, DataLoader +from torch.utils.data import ConcatDataset, DataLoader, random_split from torchvision.datasets import DTD, SVHN from torch_uncertainty.datamodules.abstract import AbstractDataModule @@ -23,12 +23,12 @@ def __init__( root: str | Path, batch_size: int, eval_ood: bool = False, + val_split: float | None = None, ood_ds: str = "svhn", rand_augment_opt: str | None = None, num_workers: int = 1, pin_memory: bool = True, persistent_workers: bool = True, - **kwargs, ) -> None: super().__init__( root=root, @@ -40,6 +40,7 @@ def __init__( # TODO: COMPUTE STATS self.eval_ood = eval_ood + self.val_split = val_split self.ood_ds = ood_ds self.dataset = TinyImageNet @@ -120,16 +121,26 @@ def prepare_data(self) -> None: # coverage: ignore def setup(self, stage: Literal["fit", "test"] | None = None) -> None: if stage == "fit" or stage is None: - self.train = self.dataset( + full = self.dataset( self.root, split="train", transform=self.train_transform, ) - self.val = self.dataset( - self.root, - split="val", - transform=self.test_transform, - ) + if self.val_split: + self.train, self.val = random_split( + full, + [ + 1 - self.val_split, + self.val_split, + ], + ) + else: + self.train = full + self.val = self.dataset( + self.root, + split="val", + transform=self.test_transform, + ) elif stage == "test": self.test = self.dataset( self.root, @@ -199,22 +210,11 @@ def test_dataloader(self) -> list[DataLoader]: return dataloader def _get_train_data(self) -> ArrayLike: + if self.val_split: + return self.train.dataset.samples[self.train.indices] return self.train.samples def _get_train_targets(self) -> ArrayLike: - return self.train.label_data - - @classmethod - def add_argparse_args( - cls, - parent_parser: ArgumentParser, - **kwargs: Any, - ) -> ArgumentParser: - p = super().add_argparse_args(parent_parser) - - # Arguments for Tiny Imagenet - p.add_argument( - "--rand_augment", dest="rand_augment_opt", type=str, default=None - ) - p.add_argument("--eval-ood", action="store_true") - return parent_parser + if self.val_split: + return np.array(self.train.dataset.label_data)[self.train.indices] + return np.array(self.train.label_data) diff --git a/torch_uncertainty/datamodules/uci_regression.py b/torch_uncertainty/datamodules/uci_regression.py index 1fe028c7..69f49b85 100644 --- a/torch_uncertainty/datamodules/uci_regression.py +++ b/torch_uncertainty/datamodules/uci_regression.py @@ -1,7 +1,5 @@ -from argparse import ArgumentParser from functools import partial from pathlib import Path -from typing import Any from torch import Generator from torch.utils.data import random_split @@ -96,13 +94,3 @@ def setup(self, stage: str | None = None) -> None: # DataLoader: UCI Regression test dataloader. # """ # return self._data_loader(self.test) - - @classmethod - def add_argparse_args( - cls, - parent_parser: ArgumentParser, - **kwargs: Any, - ) -> ArgumentParser: - super().add_argparse_args(parent_parser) - - return parent_parser From fcc3799c7eefd1e9ac7d7b9eca20496ad1a5c9df Mon Sep 17 00:00:00 2001 From: Olivier Date: Thu, 15 Feb 2024 13:41:04 +0100 Subject: [PATCH 017/148] :hammer: Rework ood_criterion as a string --- tests/datamodules/test_cifar10_datamodule.py | 55 ++++++++----------- .../classification/deep_ensembles.py | 11 +--- .../baselines/classification/resnet.py | 27 ++++----- .../baselines/classification/vgg.py | 26 ++++----- .../baselines/classification/wideresnet.py | 27 ++++----- torch_uncertainty/metrics/iou.py | 6 +- torch_uncertainty/routines/classification.py | 42 ++++++-------- 7 files changed, 80 insertions(+), 114 deletions(-) diff --git a/tests/datamodules/test_cifar10_datamodule.py b/tests/datamodules/test_cifar10_datamodule.py index c9697b7e..a66357a2 100644 --- a/tests/datamodules/test_cifar10_datamodule.py +++ b/tests/datamodules/test_cifar10_datamodule.py @@ -90,34 +90,27 @@ def test_cifar10_main(self): auto_augment="rand-m9-n2-mstd0.5", ) - # def test_cifar10_cv(self): - # parser = ArgumentParser() - # parser = CIFAR10DataModule.add_argparse_args(parser) - - # # Simulate that cutout is set to 8 - # args = parser.parse_args("") - - # dm = CIFAR10DataModule(**vars(args)) - # dm.dataset = ( - # lambda root, train, download, transform: DummyClassificationDataset( - # root, - # train=train, - # download=download, - # transform=transform, - # num_images=20, - # ) - # ) - # dm.make_cross_val_splits(2, 1) - - # args.val_split = 0.1 - # dm = CIFAR10DataModule(**vars(args)) - # dm.dataset = ( - # lambda root, train, download, transform: DummyClassificationDataset( - # root, - # train=train, - # download=download, - # transform=transform, - # num_images=20, - # ) - # ) - # dm.make_cross_val_splits(2, 1) + def test_cifar100_cv(self): + dm = CIFAR10DataModule(root="./data/", batch_size=128) + dm.dataset = ( + lambda root, train, download, transform: DummyClassificationDataset( + root, + train=train, + download=download, + transform=transform, + num_images=20, + ) + ) + dm.make_cross_val_splits(2, 1) + + dm = CIFAR10DataModule(root="./data/", batch_size=128, val_split=0.1) + dm.dataset = ( + lambda root, train, download, transform: DummyClassificationDataset( + root, + train=train, + download=download, + transform=transform, + num_images=20, + ) + ) + dm.make_cross_val_splits(2, 1) diff --git a/torch_uncertainty/baselines/classification/deep_ensembles.py b/torch_uncertainty/baselines/classification/deep_ensembles.py index 6cbd42bf..a353c1db 100644 --- a/torch_uncertainty/baselines/classification/deep_ensembles.py +++ b/torch_uncertainty/baselines/classification/deep_ensembles.py @@ -22,10 +22,8 @@ def __init__( checkpoint_ids: list[int], backbone: Literal["resnet", "vgg", "wideresnet"], eval_ood: bool = False, - use_entropy: bool = False, - use_logits: bool = False, - use_mi: bool = False, - use_variation_ratio: bool = False, + eval_grouping_loss: bool = False, + ood_criterion: Literal["msp", "logits", "entropy", "mi", "VR"] = "msp", log_plots: bool = False, calibration_set: Literal["val", "test"] | None = None, ) -> None: @@ -54,10 +52,7 @@ def __init__( loss=None, num_estimators=de.num_estimators, eval_ood=eval_ood, - use_entropy=use_entropy, - use_logits=use_logits, - use_mi=use_mi, - use_variation_ratio=use_variation_ratio, + ood_criterion=ood_criterion, log_plots=log_plots, calibration_set=calibration_set, ) diff --git a/torch_uncertainty/baselines/classification/resnet.py b/torch_uncertainty/baselines/classification/resnet.py index 0147aa31..784904b7 100644 --- a/torch_uncertainty/baselines/classification/resnet.py +++ b/torch_uncertainty/baselines/classification/resnet.py @@ -125,14 +125,12 @@ def __init__( gamma: int = 1, rho: float = 1.0, batch_repeat: int = 1, - use_entropy: bool = False, - use_logits: bool = False, - use_mi: bool = False, - use_variation_ratio: bool = False, + ood_criterion: Literal["msp", "logit", "entropy", "mi", "vr"] = "msp", log_plots: bool = False, save_in_csv: bool = False, calibration_set: Literal["val", "test"] | None = None, eval_ood: bool = False, + eval_grouping_loss: bool = False, pretrained: bool = False, ) -> None: r"""ResNet backbone baseline for classification providing support for @@ -204,14 +202,11 @@ def __init__( ``1``. batch_repeat (int, optional): Number of times to repeat the batch. Only used if :attr:`version` is ``"mimo"``. Defaults to ``1``. - use_entropy (bool, optional): Indicates whether to use the entropy - values as the OOD criterion or not. Defaults to ``False``. - use_logits (bool, optional): Indicates whether to use the logits as the - OOD criterion or not. Defaults to ``False``. - use_mi (bool, optional): Indicates whether to use the mutual - information as the OOD criterion or not. Defaults to ``False``. - use_variation_ratio (bool, optional): Indicates whether to use the - variation ratio as the OOD criterion or not. Defaults to ``False``. + ood_criterion (str, optional): OOD criterion. Defaults to ``"msp"``. + MSP is the maximum softmax probability, logit is the maximum + logit, entropy is the entropy of the mean prediction, mi is the + mutual information of the ensemble and vr is the variation ratio + of the ensemble. log_plots (bool, optional): Indicates whether to log the plots or not. Defaults to ``False``. save_in_csv (bool, optional): Indicates whether to save the results in @@ -220,6 +215,8 @@ def __init__( ``None``. eval_ood (bool, optional): Indicates whether to evaluate the OOD detection or not. Defaults to ``False``. + eval_grouping_loss (bool, optional): Indicates whether to evaluate the + grouping loss or not. Defaults to ``False``. pretrained (bool, optional): Indicates whether to use the pretrained weights or not. Only used if :attr:`version` is ``"packed"``. Defaults to ``False``. @@ -298,10 +295,8 @@ def __init__( mixup_alpha=mixup_alpha, cutmix_alpha=cutmix_alpha, eval_ood=eval_ood, - use_entropy=use_entropy, - use_logits=use_logits, - use_mi=use_mi, - use_variation_ratio=use_variation_ratio, + eval_grouping_loss=eval_grouping_loss, + ood_criterion=ood_criterion, log_plots=log_plots, save_in_csv=save_in_csv, calibration_set=calibration_set, diff --git a/torch_uncertainty/baselines/classification/vgg.py b/torch_uncertainty/baselines/classification/vgg.py index de1ca9e4..cc6e2655 100644 --- a/torch_uncertainty/baselines/classification/vgg.py +++ b/torch_uncertainty/baselines/classification/vgg.py @@ -53,14 +53,12 @@ def __init__( groups: int = 1, alpha: int | None = None, gamma: int = 1, - use_entropy: bool = False, - use_logits: bool = False, - use_mi: bool = False, - use_variation_ratio: bool = False, + ood_criterion: Literal["msp", "logit", "entropy", "mi", "vr"] = "msp", log_plots: bool = False, save_in_csv: bool = False, calibration_set: Literal["val", "test"] | None = None, eval_ood: bool = False, + eval_grouping_loss: bool = False, ) -> None: r"""VGG backbone baseline for classification providing support for various versions and architectures. @@ -110,14 +108,11 @@ def __init__( gamma (int, optional): Number of groups within each estimator. Only used if :attr:`version` is ``"packed"`` and scales with :attr:`groups`. Defaults to ``1s``. - use_entropy (bool, optional): Indicates whether to use the entropy - values as the OOD criterion or not. Defaults to ``False``. - use_logits (bool, optional): Indicates whether to use the logits as the - OOD criterion or not. Defaults to ``False``. - use_mi (bool, optional): Indicates whether to use the mutual - information as the OOD criterion or not. Defaults to ``False``. - use_variation_ratio (bool, optional): Indicates whether to use the - variation ratio as the OOD criterion or not. Defaults to ``False``. + ood_criterion (str, optional): OOD criterion. Defaults to ``"msp"``. + MSP is the maximum softmax probability, logit is the maximum + logit, entropy is the entropy of the mean prediction, mi is the + mutual information of the ensemble and vr is the variation ratio + of the ensemble. log_plots (bool, optional): Indicates whether to log the plots or not. Defaults to ``False``. save_in_csv (bool, optional): Indicates whether to save the results in @@ -126,6 +121,8 @@ def __init__( ``None``. eval_ood (bool, optional): Indicates whether to evaluate the OOD detection or not. Defaults to ``False``. + eval_grouping_loss (bool, optional): Indicates whether to evaluate the + grouping loss or not. Defaults to ``False``. Raises: ValueError: If :attr:`version` is not either ``"std"``, @@ -196,10 +193,7 @@ def __init__( mixup_alpha=mixup_alpha, cutmix_alpha=cutmix_alpha, eval_ood=eval_ood, - use_entropy=use_entropy, - use_logits=use_logits, - use_mi=use_mi, - use_variation_ratio=use_variation_ratio, + ood_criterion=ood_criterion, log_plots=log_plots, save_in_csv=save_in_csv, calibration_set=calibration_set, diff --git a/torch_uncertainty/baselines/classification/wideresnet.py b/torch_uncertainty/baselines/classification/wideresnet.py index b7496158..5ce5026a 100644 --- a/torch_uncertainty/baselines/classification/wideresnet.py +++ b/torch_uncertainty/baselines/classification/wideresnet.py @@ -53,14 +53,12 @@ def __init__( gamma: int = 1, rho: float = 1.0, batch_repeat: int = 1, - use_entropy: bool = False, - use_logits: bool = False, - use_mi: bool = False, - use_variation_ratio: bool = False, + ood_criterion: Literal["msp", "logit", "entropy", "mi", "vr"] = "msp", log_plots: bool = False, save_in_csv: bool = False, calibration_set: Literal["val", "test"] | None = None, eval_ood: bool = False, + eval_grouping_loss: bool = False, # pretrained: bool = False, ) -> None: r"""Wide-ResNet28x10 backbone baseline for classification providing support @@ -118,14 +116,11 @@ def __init__( ``1``. batch_repeat (int, optional): Number of times to repeat the batch. Only used if :attr:`version` is ``"mimo"``. Defaults to ``1``. - use_entropy (bool, optional): Indicates whether to use the entropy - values as the OOD criterion or not. Defaults to ``False``. - use_logits (bool, optional): Indicates whether to use the logits as the - OOD criterion or not. Defaults to ``False``. - use_mi (bool, optional): Indicates whether to use the mutual - information as the OOD criterion or not. Defaults to ``False``. - use_variation_ratio (bool, optional): Indicates whether to use the - variation ratio as the OOD criterion or not. Defaults to ``False``. + ood_criterion (str, optional): OOD criterion. Defaults to ``"msp"``. + MSP is the maximum softmax probability, logit is the maximum + logit, entropy is the entropy of the mean prediction, mi is the + mutual information of the ensemble and vr is the variation ratio + of the ensemble. log_plots (bool, optional): Indicates whether to log the plots or not. Defaults to ``False``. save_in_csv (bool, optional): Indicates whether to save the results in @@ -134,6 +129,8 @@ def __init__( ``None``. eval_ood (bool, optional): Indicates whether to evaluate the OOD detection or not. Defaults to ``False``. + eval_grouping_loss (bool, optional): Indicates whether to evaluate the + grouping loss or not. Defaults to ``False``. Raises: ValueError: If :attr:`version` is not either ``"std"``, @@ -209,10 +206,8 @@ def __init__( mixup_alpha=mixup_alpha, cutmix_alpha=cutmix_alpha, eval_ood=eval_ood, - use_entropy=use_entropy, - use_logits=use_logits, - use_mi=use_mi, - use_variation_ratio=use_variation_ratio, + eval_grouping_loss=eval_grouping_loss, + ood_criterion=ood_criterion, log_plots=log_plots, save_in_csv=save_in_csv, calibration_set=calibration_set, diff --git a/torch_uncertainty/metrics/iou.py b/torch_uncertainty/metrics/iou.py index c6259ed1..a6695b1f 100644 --- a/torch_uncertainty/metrics/iou.py +++ b/torch_uncertainty/metrics/iou.py @@ -5,6 +5,8 @@ class IntersectionOverUnion(MulticlassStatScores): + """Compute the Intersection over Union (IoU) score.""" + is_differentiable: bool = False higher_is_better: bool = True full_state_update: bool = False @@ -27,8 +29,6 @@ def update(self, preds: Tensor, target: Tensor) -> None: super().update(preds, target) def compute(self) -> Tensor: - """Compute the Intersection over Union (IoU) based on inputs passed to - ``update``. - """ + """Compute the Intersection over Union (IoU) based on saved inputs.""" tp, fp, _, fn = self._final_state() return _safe_divide(tp, tp + fp + fn) diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index c4552aac..5bf25f2b 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -52,10 +52,7 @@ def __init__( cutmix_alpha: float = 0, eval_ood: bool = False, eval_grouping_loss: bool = False, - use_entropy: bool = False, - use_logits: bool = False, - use_mi: bool = False, - use_variation_ratio: bool = False, + ood_criterion: Literal["msp", "logit", "entropy", "mi", "vr"] = "msp", log_plots: bool = False, save_in_csv: bool = False, calibration_set: Literal["val", "test"] | None = None, @@ -84,14 +81,11 @@ def __init__( detection performance or not. Defaults to ``False``. eval_grouping_loss (bool, optional): Indicates whether to evaluate the grouping loss or not. Defaults to ``False``. - use_entropy (bool, optional): Indicates whether to use the entropy - values as the OOD criterion or not. Defaults to ``False``. - use_logits (bool, optional): Indicates whether to use the logits as the - OOD criterion or not. Defaults to ``False``. - use_mi (bool, optional): Indicates whether to use the mutual - information as the OOD criterion or not. Defaults to ``False``. - use_variation_ratio (bool, optional): Indicates whether to use the - variation ratio as the OOD criterion or not. Defaults to ``False``. + ood_criterion (str, optional): OOD criterion. Defaults to ``"msp"``. + MSP is the maximum softmax probability, logit is the maximum + logit, entropy is the entropy of the mean prediction, mi is the + mutual information of the ensemble and vr is the variation ratio + of the ensemble. log_plots (bool, optional): Indicates whether to log plots from metrics. Defaults to ``False``. save_in_csv(bool, optional): __TODO__ @@ -109,16 +103,19 @@ def __init__( if format_batch_fn is None: format_batch_fn = nn.Identity() - if (use_logits + use_entropy + use_mi + use_variation_ratio) > 1: - raise ValueError("You cannot choose more than one OOD criterion.") - if not isinstance(num_estimators, int) and num_estimators < 1: raise ValueError( "The number of estimators must be a positive integer >= 1." f"Got {num_estimators}." ) - if num_estimators == 1 and (use_mi or use_variation_ratio): + if ood_criterion not in ["msp", "logit", "entropy", "mi", "vr"]: + raise ValueError( + "The OOD criterion must be one of 'msp', 'logit', 'entropy'," + f" 'mi' or 'vr'. Got {ood_criterion}." + ) + + if num_estimators == 1 and ood_criterion in ["mi", "vr"]: raise ValueError( "You cannot use mutual information or variation ratio with a single" " model." @@ -147,10 +144,7 @@ def __init__( self.num_estimators = num_estimators self.eval_ood = eval_ood self.eval_grouping_loss = eval_grouping_loss - self.use_logits = use_logits - self.use_entropy = use_entropy - self.use_mi = use_mi - self.use_variation_ratio = use_variation_ratio + self.ood_criterion = ood_criterion self.log_plots = log_plots self.save_in_csv = save_in_csv self.calibration_set = calibration_set @@ -443,16 +437,16 @@ def test_step( confs = probs.max(-1)[0] - if self.use_logits: + if self.criterion == "logit": ood_scores = -logits.mean(dim=1).max(dim=-1)[0] - elif self.use_entropy: + elif self.criterion == "entropy": ood_scores = ( torch.special.entr(probs_per_est).sum(dim=-1).mean(dim=1) ) - elif self.use_mi: + elif self.criterion == "mi": mi_metric = MutualInformation(reduction="none") ood_scores = mi_metric(probs_per_est) - elif self.use_variation_ratio: + elif self.criterion == "vr": vr_metric = VariationRatio(reduction="none", probabilistic=False) ood_scores = vr_metric(probs_per_est.transpose(0, 1)) else: From 91cfc31bcae2e44261e3cad41f75be8b89aff1c1 Mon Sep 17 00:00:00 2001 From: Olivier Date: Thu, 15 Feb 2024 14:24:34 +0100 Subject: [PATCH 018/148] :hammer: Refactor datamodules tests --- tests/datamodules/classification/__init__.py | 0 .../test_abstract_datamodule.py | 14 +++++++++----- .../test_cifar100_datamodule.py | 0 .../test_cifar10_datamodule.py | 2 +- .../test_imagenet_datamodule.py | 4 +++- .../{ => classification}/test_mnist_datamodule.py | 0 .../test_tiny_imagenet_datamodule.py | 0 .../test_uci_regression_datamodule.py | 0 8 files changed, 13 insertions(+), 7 deletions(-) create mode 100644 tests/datamodules/classification/__init__.py rename tests/datamodules/{ => classification}/test_abstract_datamodule.py (75%) rename tests/datamodules/{ => classification}/test_cifar100_datamodule.py (100%) rename tests/datamodules/{ => classification}/test_cifar10_datamodule.py (99%) rename tests/datamodules/{ => classification}/test_imagenet_datamodule.py (96%) rename tests/datamodules/{ => classification}/test_mnist_datamodule.py (100%) rename tests/datamodules/{ => classification}/test_tiny_imagenet_datamodule.py (100%) rename tests/datamodules/{ => classification}/test_uci_regression_datamodule.py (100%) diff --git a/tests/datamodules/classification/__init__.py b/tests/datamodules/classification/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/datamodules/test_abstract_datamodule.py b/tests/datamodules/classification/test_abstract_datamodule.py similarity index 75% rename from tests/datamodules/test_abstract_datamodule.py rename to tests/datamodules/classification/test_abstract_datamodule.py index 02c5b8e8..c756d7be 100644 --- a/tests/datamodules/test_abstract_datamodule.py +++ b/tests/datamodules/classification/test_abstract_datamodule.py @@ -13,7 +13,7 @@ class TestAbstractDataModule: """Testing the AbstractDataModule class.""" def test_errors(self): - dm = AbstractDataModule("root", 128, 4, True, True) + dm = AbstractDataModule("root", 128, 0.0, 4, True, True) with pytest.raises(NotImplementedError): dm.setup() dm._get_train_data() @@ -24,12 +24,14 @@ class TestCrossValDataModule: """Testing the CrossValDataModule class.""" def test_cv_main(self): - dm = AbstractDataModule("root", 128, 4, True, True) + dm = AbstractDataModule("root", 128, 0.0, 4, True, True) ds = DummyClassificationDataset(Path("root")) dm.train = ds dm.val = ds dm.test = ds - cv_dm = CrossValDataModule("root", [0], [1], dm, 128, 4, True, True) + cv_dm = CrossValDataModule( + "root", [0], [1], dm, 128, 0.0, 4, True, True + ) cv_dm.setup() cv_dm.setup("test") @@ -44,12 +46,14 @@ def test_cv_main(self): cv_dm.test_dataloader() def test_errors(self): - dm = AbstractDataModule("root", 128, 4, True, True) + dm = AbstractDataModule("root", 128, 0.0, 4, True, True) ds = DummyClassificationDataset(Path("root")) dm.train = ds dm.val = ds dm.test = ds - cv_dm = CrossValDataModule("root", [0], [1], dm, 128, 4, True, True) + cv_dm = CrossValDataModule( + "root", [0], [1], dm, 128, 0.0, 4, True, True + ) with pytest.raises(NotImplementedError): cv_dm.setup() cv_dm._get_train_data() diff --git a/tests/datamodules/test_cifar100_datamodule.py b/tests/datamodules/classification/test_cifar100_datamodule.py similarity index 100% rename from tests/datamodules/test_cifar100_datamodule.py rename to tests/datamodules/classification/test_cifar100_datamodule.py diff --git a/tests/datamodules/test_cifar10_datamodule.py b/tests/datamodules/classification/test_cifar10_datamodule.py similarity index 99% rename from tests/datamodules/test_cifar10_datamodule.py rename to tests/datamodules/classification/test_cifar10_datamodule.py index a66357a2..3bc5931f 100644 --- a/tests/datamodules/test_cifar10_datamodule.py +++ b/tests/datamodules/classification/test_cifar10_datamodule.py @@ -90,7 +90,7 @@ def test_cifar10_main(self): auto_augment="rand-m9-n2-mstd0.5", ) - def test_cifar100_cv(self): + def test_cifar10_cv(self): dm = CIFAR10DataModule(root="./data/", batch_size=128) dm.dataset = ( lambda root, train, download, transform: DummyClassificationDataset( diff --git a/tests/datamodules/test_imagenet_datamodule.py b/tests/datamodules/classification/test_imagenet_datamodule.py similarity index 96% rename from tests/datamodules/test_imagenet_datamodule.py rename to tests/datamodules/classification/test_imagenet_datamodule.py index e37ea442..4689c2d9 100644 --- a/tests/datamodules/test_imagenet_datamodule.py +++ b/tests/datamodules/classification/test_imagenet_datamodule.py @@ -21,7 +21,9 @@ def test_imagenet(self): dm.setup() dm.setup("test") - path = Path(__file__).parent.resolve() / "../assets/dummy_indices.yaml" + path = ( + Path(__file__).parent.resolve() / "../../assets/dummy_indices.yaml" + ) dm = ImageNetDataModule(root="./data/", batch_size=128, val_split=path) dm.dataset = DummyClassificationDataset dm.ood_dataset = DummyClassificationDataset diff --git a/tests/datamodules/test_mnist_datamodule.py b/tests/datamodules/classification/test_mnist_datamodule.py similarity index 100% rename from tests/datamodules/test_mnist_datamodule.py rename to tests/datamodules/classification/test_mnist_datamodule.py diff --git a/tests/datamodules/test_tiny_imagenet_datamodule.py b/tests/datamodules/classification/test_tiny_imagenet_datamodule.py similarity index 100% rename from tests/datamodules/test_tiny_imagenet_datamodule.py rename to tests/datamodules/classification/test_tiny_imagenet_datamodule.py diff --git a/tests/datamodules/test_uci_regression_datamodule.py b/tests/datamodules/classification/test_uci_regression_datamodule.py similarity index 100% rename from tests/datamodules/test_uci_regression_datamodule.py rename to tests/datamodules/classification/test_uci_regression_datamodule.py From 08ec11da957213733f07090ba1f565eff052fdf1 Mon Sep 17 00:00:00 2001 From: Olivier Date: Thu, 15 Feb 2024 14:24:57 +0100 Subject: [PATCH 019/148] :zap: Update CamVid dm & add test --- tests/datamodules/test_camvid.py | 27 +++++++++++++++++++ .../datamodules/segmentation/camvid.py | 17 ++---------- 2 files changed, 29 insertions(+), 15 deletions(-) create mode 100644 tests/datamodules/test_camvid.py diff --git a/tests/datamodules/test_camvid.py b/tests/datamodules/test_camvid.py new file mode 100644 index 00000000..e9cf05ff --- /dev/null +++ b/tests/datamodules/test_camvid.py @@ -0,0 +1,27 @@ +import pytest + +from tests._dummies.dataset import DummyClassificationDataset +from torch_uncertainty.datamodules.segmentation import CamVidDataModule +from torch_uncertainty.datasets.segmentation import CamVid + + +class TestCamVidDataModule: + """Testing the CamVidDataModule datamodule.""" + + def test_camvid_main(self): + # parser = ArgumentParser() + # parser = CIFAR10DataModule.add_argparse_args(parser) + + # Simulate that cutout is set to 16 + dm = CamVidDataModule(root="./data/", batch_size=128) + + assert dm.dataset == CamVid + + dm.dataset = DummyClassificationDataset + + dm.prepare_data() + dm.setup() + dm.setup("test") + + with pytest.raises(ValueError): + dm.setup("xxx") diff --git a/torch_uncertainty/datamodules/segmentation/camvid.py b/torch_uncertainty/datamodules/segmentation/camvid.py index fb483c75..b8186dd1 100644 --- a/torch_uncertainty/datamodules/segmentation/camvid.py +++ b/torch_uncertainty/datamodules/segmentation/camvid.py @@ -1,6 +1,4 @@ -from argparse import ArgumentParser from pathlib import Path -from typing import Any from torchvision.transforms import v2 @@ -13,6 +11,7 @@ def __init__( self, root: str | Path, batch_size: int, + val_split: float = 0.0, num_workers: int = 1, pin_memory: bool = True, persistent_workers: bool = True, @@ -20,11 +19,11 @@ def __init__( super().__init__( root=root, batch_size=batch_size, + val_split=val_split, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=persistent_workers, ) - self.dataset = CamVid self.transform_train = v2.Compose( @@ -60,15 +59,3 @@ def setup(self, stage: str | None = None) -> None: ) else: raise ValueError(f"Stage {stage} is not supported.") - - @classmethod - def add_argparse_args( - cls, - parent_parser: ArgumentParser, - **kwargs: Any, - ) -> ArgumentParser: - p = parent_parser.add_argument_group("datamodule") - p.add_argument("--root", type=str, default="./data/") - p.add_argument("--batch_size", type=int, default=128) - p.add_argument("--num_workers", type=int, default=4) - return parent_parser From bd4cb60f8eddec18042b3f1d626b83047179a6db Mon Sep 17 00:00:00 2001 From: Olivier Date: Thu, 15 Feb 2024 14:25:24 +0100 Subject: [PATCH 020/148] :hammer: Refactor datamodules' val_split arg. --- torch_uncertainty/datamodules/abstract.py | 40 +++++++++---------- .../datamodules/classification/cifar10.py | 1 + .../datamodules/classification/cifar100.py | 4 +- .../datamodules/classification/imagenet.py | 1 + .../datamodules/classification/mnist.py | 2 +- .../classification/tiny_imagenet.py | 4 +- .../datamodules/uci_regression.py | 5 +-- 7 files changed, 27 insertions(+), 30 deletions(-) diff --git a/torch_uncertainty/datamodules/abstract.py b/torch_uncertainty/datamodules/abstract.py index c5c95616..9925f6e2 100644 --- a/torch_uncertainty/datamodules/abstract.py +++ b/torch_uncertainty/datamodules/abstract.py @@ -18,10 +18,10 @@ def __init__( self, root: str | Path, batch_size: int, - num_workers: int = 1, - pin_memory: bool = True, - persistent_workers: bool = True, - **kwargs, + val_split: float, + num_workers: int, + pin_memory: bool, + persistent_workers: bool, ) -> None: """Abstract DataModule class. @@ -32,17 +32,16 @@ def __init__( Args: root (str): Root directory of the datasets. batch_size (int): Number of samples per batch. - num_workers (int): Number of workers to use for data loading. Defaults - to ``1``. - pin_memory (bool): Whether to pin memory. Defaults to ``True``. - persistent_workers (bool): Whether to use persistent workers. Defaults - to ``True``. - kwargs (Any): Other arguments. + val_split (float): Share of samples to use for validation. + num_workers (int): Number of workers to use for data loading. + pin_memory (bool): Whether to pin memory. + persistent_workers (bool): Whether to use persistent workers. """ super().__init__() self.root = Path(root) self.batch_size = batch_size + self.val_split = val_split self.num_workers = num_workers self.pin_memory = pin_memory @@ -139,6 +138,7 @@ def make_cross_val_splits( val_idx=val_idx, datamodule=self, batch_size=self.batch_size, + val_split=self.val_split, num_workers=self.num_workers, pin_memory=self.pin_memory, persistent_workers=self.persistent_workers, @@ -156,18 +156,18 @@ def __init__( val_idx: ArrayLike, datamodule: AbstractDataModule, batch_size: int, - num_workers: int = 1, - pin_memory: bool = True, - persistent_workers: bool = True, - **kwargs, + val_split: float, + num_workers: int, + pin_memory: bool, + persistent_workers: bool, ) -> None: super().__init__( - root, - batch_size, - num_workers, - pin_memory, - persistent_workers, - **kwargs, + root=root, + batch_size=batch_size, + val_split=val_split, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=persistent_workers, ) self.train_idx = train_idx diff --git a/torch_uncertainty/datamodules/classification/cifar10.py b/torch_uncertainty/datamodules/classification/cifar10.py index 0ab6948c..01e50dfc 100644 --- a/torch_uncertainty/datamodules/classification/cifar10.py +++ b/torch_uncertainty/datamodules/classification/cifar10.py @@ -61,6 +61,7 @@ def __init__( super().__init__( root=root, batch_size=batch_size, + val_split=val_split, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=persistent_workers, diff --git a/torch_uncertainty/datamodules/classification/cifar100.py b/torch_uncertainty/datamodules/classification/cifar100.py index b4ae095c..a03b48bc 100644 --- a/torch_uncertainty/datamodules/classification/cifar100.py +++ b/torch_uncertainty/datamodules/classification/cifar100.py @@ -37,7 +37,6 @@ def __init__( num_workers: int = 1, pin_memory: bool = True, persistent_workers: bool = True, - **kwargs, ) -> None: """DataModule for CIFAR100. @@ -61,18 +60,17 @@ def __init__( pin_memory (bool): Whether to pin memory. Defaults to ``True``. persistent_workers (bool): Whether to use persistent workers. Defaults to ``True``. - kwargs: Additional arguments. """ super().__init__( root=root, batch_size=batch_size, + val_split=val_split, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=persistent_workers, ) self.eval_ood = eval_ood - self.val_split = val_split self.num_dataloaders = num_dataloaders if test_alt == "c": diff --git a/torch_uncertainty/datamodules/classification/imagenet.py b/torch_uncertainty/datamodules/classification/imagenet.py index e756b49d..dc6775f2 100644 --- a/torch_uncertainty/datamodules/classification/imagenet.py +++ b/torch_uncertainty/datamodules/classification/imagenet.py @@ -75,6 +75,7 @@ def __init__( super().__init__( root=Path(root), batch_size=batch_size, + val_split=val_split, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=persistent_workers, diff --git a/torch_uncertainty/datamodules/classification/mnist.py b/torch_uncertainty/datamodules/classification/mnist.py index e4ef8a1f..24f15a9f 100644 --- a/torch_uncertainty/datamodules/classification/mnist.py +++ b/torch_uncertainty/datamodules/classification/mnist.py @@ -54,6 +54,7 @@ def __init__( super().__init__( root=root, batch_size=batch_size, + val_split=val_split, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=persistent_workers, @@ -61,7 +62,6 @@ def __init__( self.eval_ood = eval_ood self.batch_size = batch_size - self.val_split = val_split if test_alt == "c": self.dataset = MNISTC diff --git a/torch_uncertainty/datamodules/classification/tiny_imagenet.py b/torch_uncertainty/datamodules/classification/tiny_imagenet.py index 35f4894b..91929005 100644 --- a/torch_uncertainty/datamodules/classification/tiny_imagenet.py +++ b/torch_uncertainty/datamodules/classification/tiny_imagenet.py @@ -23,7 +23,7 @@ def __init__( root: str | Path, batch_size: int, eval_ood: bool = False, - val_split: float | None = None, + val_split: float = 0.0, ood_ds: str = "svhn", rand_augment_opt: str | None = None, num_workers: int = 1, @@ -33,6 +33,7 @@ def __init__( super().__init__( root=root, batch_size=batch_size, + val_split=val_split, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=persistent_workers, @@ -40,7 +41,6 @@ def __init__( # TODO: COMPUTE STATS self.eval_ood = eval_ood - self.val_split = val_split self.ood_ds = ood_ds self.dataset = TinyImageNet diff --git a/torch_uncertainty/datamodules/uci_regression.py b/torch_uncertainty/datamodules/uci_regression.py index 69f49b85..0d8f96dc 100644 --- a/torch_uncertainty/datamodules/uci_regression.py +++ b/torch_uncertainty/datamodules/uci_regression.py @@ -23,7 +23,6 @@ def __init__( persistent_workers: bool = True, input_shape: tuple[int, ...] | None = None, split_seed: int = 42, - **kwargs, ) -> None: """The UCI regression datasets. @@ -46,18 +45,16 @@ def __init__( ``None``. split_seed (int, optional): The seed to use for splitting the dataset. Defaults to ``42``. - **kwargs: Additional arguments. """ super().__init__( root=root, batch_size=batch_size, + val_split=val_split, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=persistent_workers, ) - self.val_split = val_split - self.dataset = partial( UCIRegression, dataset_name=dataset_name, seed=split_seed ) From c8d3e389c0be8bd2c857d319f45ea3e7dd500cb9 Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 4 Mar 2024 15:14:01 +0100 Subject: [PATCH 021/148] :sparkles: Add energy --- .../classification/deep_ensembles.py | 4 ++- .../baselines/classification/resnet.py | 4 ++- .../baselines/classification/vgg.py | 4 ++- .../baselines/classification/wideresnet.py | 4 ++- torch_uncertainty/routines/classification.py | 33 ++++++++++++------- 5 files changed, 34 insertions(+), 15 deletions(-) diff --git a/torch_uncertainty/baselines/classification/deep_ensembles.py b/torch_uncertainty/baselines/classification/deep_ensembles.py index a353c1db..4ffd8405 100644 --- a/torch_uncertainty/baselines/classification/deep_ensembles.py +++ b/torch_uncertainty/baselines/classification/deep_ensembles.py @@ -23,7 +23,9 @@ def __init__( backbone: Literal["resnet", "vgg", "wideresnet"], eval_ood: bool = False, eval_grouping_loss: bool = False, - ood_criterion: Literal["msp", "logits", "entropy", "mi", "VR"] = "msp", + ood_criterion: Literal[ + "msp", "logits", "energy", "entropy", "mi", "VR" + ] = "msp", log_plots: bool = False, calibration_set: Literal["val", "test"] | None = None, ) -> None: diff --git a/torch_uncertainty/baselines/classification/resnet.py b/torch_uncertainty/baselines/classification/resnet.py index 784904b7..7efda160 100644 --- a/torch_uncertainty/baselines/classification/resnet.py +++ b/torch_uncertainty/baselines/classification/resnet.py @@ -125,7 +125,9 @@ def __init__( gamma: int = 1, rho: float = 1.0, batch_repeat: int = 1, - ood_criterion: Literal["msp", "logit", "entropy", "mi", "vr"] = "msp", + ood_criterion: Literal[ + "msp", "logit", "energy", "entropy", "mi", "vr" + ] = "msp", log_plots: bool = False, save_in_csv: bool = False, calibration_set: Literal["val", "test"] | None = None, diff --git a/torch_uncertainty/baselines/classification/vgg.py b/torch_uncertainty/baselines/classification/vgg.py index cc6e2655..39f048f3 100644 --- a/torch_uncertainty/baselines/classification/vgg.py +++ b/torch_uncertainty/baselines/classification/vgg.py @@ -53,7 +53,9 @@ def __init__( groups: int = 1, alpha: int | None = None, gamma: int = 1, - ood_criterion: Literal["msp", "logit", "entropy", "mi", "vr"] = "msp", + ood_criterion: Literal[ + "msp", "logit", "energy", "entropy", "mi", "vr" + ] = "msp", log_plots: bool = False, save_in_csv: bool = False, calibration_set: Literal["val", "test"] | None = None, diff --git a/torch_uncertainty/baselines/classification/wideresnet.py b/torch_uncertainty/baselines/classification/wideresnet.py index 5ce5026a..c1734a4a 100644 --- a/torch_uncertainty/baselines/classification/wideresnet.py +++ b/torch_uncertainty/baselines/classification/wideresnet.py @@ -53,7 +53,9 @@ def __init__( gamma: int = 1, rho: float = 1.0, batch_repeat: int = 1, - ood_criterion: Literal["msp", "logit", "entropy", "mi", "vr"] = "msp", + ood_criterion: Literal[ + "msp", "logit", "energy", "entropy", "mi", "vr" + ] = "msp", log_plots: bool = False, save_in_csv: bool = False, calibration_set: Literal["val", "test"] | None = None, diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 5bf25f2b..13ab8091 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -52,7 +52,9 @@ def __init__( cutmix_alpha: float = 0, eval_ood: bool = False, eval_grouping_loss: bool = False, - ood_criterion: Literal["msp", "logit", "entropy", "mi", "vr"] = "msp", + ood_criterion: Literal[ + "msp", "logit", "energy", "entropy", "mi", "vr" + ] = "msp", log_plots: bool = False, save_in_csv: bool = False, calibration_set: Literal["val", "test"] | None = None, @@ -83,9 +85,9 @@ def __init__( grouping loss or not. Defaults to ``False``. ood_criterion (str, optional): OOD criterion. Defaults to ``"msp"``. MSP is the maximum softmax probability, logit is the maximum - logit, entropy is the entropy of the mean prediction, mi is the - mutual information of the ensemble and vr is the variation ratio - of the ensemble. + logit, energy the logsumexp of the mean logits, entropy the + entropy of the mean prediction, mi is the mutual information of + the ensemble and vr is the variation ratio of the ensemble. log_plots (bool, optional): Indicates whether to log plots from metrics. Defaults to ``False``. save_in_csv(bool, optional): __TODO__ @@ -109,9 +111,16 @@ def __init__( f"Got {num_estimators}." ) - if ood_criterion not in ["msp", "logit", "entropy", "mi", "vr"]: + if ood_criterion not in [ + "msp", + "logit", + "energy", + "entropy", + "mi", + "vr", + ]: raise ValueError( - "The OOD criterion must be one of 'msp', 'logit', 'entropy'," + "The OOD criterion must be one of 'msp', 'logit', 'energy', 'entropy'," f" 'mi' or 'vr'. Got {ood_criterion}." ) @@ -437,16 +446,18 @@ def test_step( confs = probs.max(-1)[0] - if self.criterion == "logit": - ood_scores = -logits.mean(dim=1).max(dim=-1)[0] - elif self.criterion == "entropy": + if self.ood_criterion == "logit": + ood_scores = -logits.mean(dim=1).max(dim=-1).values + elif self.ood_criterion == "energy": + ood_scores = -logits.mean(dim=1).logsumexp(dim=-1) + elif self.ood_criterion == "entropy": ood_scores = ( torch.special.entr(probs_per_est).sum(dim=-1).mean(dim=1) ) - elif self.criterion == "mi": + elif self.ood_criterion == "mi": mi_metric = MutualInformation(reduction="none") ood_scores = mi_metric(probs_per_est) - elif self.criterion == "vr": + elif self.ood_criterion == "vr": vr_metric = VariationRatio(reduction="none", probabilistic=False) ood_scores = vr_metric(probs_per_est.transpose(0, 1)) else: From b38bd39a2d4d3a2c6dfc4b99bbd2100f1531916b Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 13 Mar 2024 09:54:54 +0100 Subject: [PATCH 022/148] :fire: Finish removing ArgvContext in tutorials --- .../tutorial_evidential_classification.py | 12 +++--------- .../tutorial_mc_batch_norm.py | 16 +++------------- auto_tutorials_source/tutorial_mc_dropout.py | 18 +++--------------- 3 files changed, 9 insertions(+), 37 deletions(-) diff --git a/auto_tutorials_source/tutorial_evidential_classification.py b/auto_tutorials_source/tutorial_evidential_classification.py index a3a44e17..e7e72bbf 100644 --- a/auto_tutorials_source/tutorial_evidential_classification.py +++ b/auto_tutorials_source/tutorial_evidential_classification.py @@ -39,11 +39,11 @@ # We also import sys to override the command line arguments. import os +import sys from functools import partial from pathlib import Path import torch -from cli_test_helpers import ArgvContext from torch import nn, optim @@ -71,14 +71,8 @@ def optim_lenet(model: nn.Module) -> dict: root = Path(os.path.abspath("")) # We mock the arguments for the trainer. Replace with 25 epochs on your machine. -with ArgvContext( - "file.py", - "--max_epochs", - "5", - "--enable_progress_bar", - "True", -): - args = init_args(datamodule=MNISTDataModule) +sys.argv = ["file.py", "--max_epochs", "5", "--enable_progress_bar", "True"] +args = init_args(datamodule=MNISTDataModule) net_name = "logs/dec-lenet-mnist" diff --git a/auto_tutorials_source/tutorial_mc_batch_norm.py b/auto_tutorials_source/tutorial_mc_batch_norm.py index b2ed7d4e..3fa05908 100644 --- a/auto_tutorials_source/tutorial_mc_batch_norm.py +++ b/auto_tutorials_source/tutorial_mc_batch_norm.py @@ -36,10 +36,10 @@ # We also import sys to override the command line arguments. import os +import sys from pathlib import Path from torch import nn -from cli_test_helpers import ArgvContext # %% # 2. Creating the necessary variables @@ -57,18 +57,8 @@ root = Path(os.path.abspath("")) # We mock the arguments for the trainer -with ArgvContext( - "file.py", - "--max_epochs", - "1", - "--enable_progress_bar", - "False", - "--num_estimators", - "8", - "--max_epochs", - "2" -): - args = init_args(network=ResNet, datamodule=MNISTDataModule) +sys.argv = ["file.py", "--max_epochs", "1", "--enable_progress_bar", "False", "--num_estimators", "8", "--max_epochs", "2"] +args = init_args(network=ResNet, datamodule=MNISTDataModule) net_name = "logs/lenet-mnist" diff --git a/auto_tutorials_source/tutorial_mc_dropout.py b/auto_tutorials_source/tutorial_mc_dropout.py index 02b291e2..91514492 100644 --- a/auto_tutorials_source/tutorial_mc_dropout.py +++ b/auto_tutorials_source/tutorial_mc_dropout.py @@ -42,10 +42,10 @@ # We also import sys to override the command line arguments. import os +import sys from pathlib import Path from torch import nn -from cli_test_helpers import ArgvContext # %% # 2. Creating the necessary variables @@ -63,20 +63,8 @@ root = Path(os.path.abspath("")) # We mock the arguments for the trainer -with ArgvContext( - "file.py", - "--max_epochs", - "1", - "--enable_progress_bar", - "False", - "--dropout_rate", - "0.6", - "--num_estimators", - "16", - "--max_epochs", - "2" -): - args = init_args(network=ResNet, datamodule=MNISTDataModule) +sys.argv = ["file.py", "--max_epochs", "1", "--enable_progress_bar", "False", "--dropout_rate", "0.6", "--num_estimators", "16", "--max_epochs", "2"] +args = init_args(network=ResNet, datamodule=MNISTDataModule) net_name = "logs/mc-dropout-lenet-mnist" From 333f7cf17e92f53ffb38a99541be45d7757e7096 Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 13 Mar 2024 22:44:50 +0100 Subject: [PATCH 023/148] :bug: Remove double max_epochs --- auto_tutorials_source/tutorial_mc_batch_norm.py | 2 +- auto_tutorials_source/tutorial_mc_dropout.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/auto_tutorials_source/tutorial_mc_batch_norm.py b/auto_tutorials_source/tutorial_mc_batch_norm.py index 3fa05908..d4a9e2bd 100644 --- a/auto_tutorials_source/tutorial_mc_batch_norm.py +++ b/auto_tutorials_source/tutorial_mc_batch_norm.py @@ -57,7 +57,7 @@ root = Path(os.path.abspath("")) # We mock the arguments for the trainer -sys.argv = ["file.py", "--max_epochs", "1", "--enable_progress_bar", "False", "--num_estimators", "8", "--max_epochs", "2"] +sys.argv = ["file.py", "--enable_progress_bar", "False", "--num_estimators", "8", "--max_epochs", "2"] args = init_args(network=ResNet, datamodule=MNISTDataModule) net_name = "logs/lenet-mnist" diff --git a/auto_tutorials_source/tutorial_mc_dropout.py b/auto_tutorials_source/tutorial_mc_dropout.py index 91514492..62915574 100644 --- a/auto_tutorials_source/tutorial_mc_dropout.py +++ b/auto_tutorials_source/tutorial_mc_dropout.py @@ -63,7 +63,7 @@ root = Path(os.path.abspath("")) # We mock the arguments for the trainer -sys.argv = ["file.py", "--max_epochs", "1", "--enable_progress_bar", "False", "--dropout_rate", "0.6", "--num_estimators", "16", "--max_epochs", "2"] +sys.argv = ["file.py", "--enable_progress_bar", "False", "--dropout_rate", "0.6", "--num_estimators", "16", "--max_epochs", "2"] args = init_args(network=ResNet, datamodule=MNISTDataModule) net_name = "logs/mc-dropout-lenet-mnist" From 0613d44ff1503d6d2b086809a55fb632637ade4a Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 15 Mar 2024 16:37:21 +0100 Subject: [PATCH 024/148] :shirt: Update the Quickstart page --- docs/source/quickstart.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index e27461f2..eef966d5 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -68,16 +68,16 @@ trains any ResNet architecture on CIFAR10: # model model = ResNet( num_classes=dm.num_classes, - in_channels=dm.in_channels, + in_channels=dm.num_channels, loss=nn.CrossEntropyLoss(), optimization_procedure=get_procedure( f"resnet{args.arch}", "cifar10", args.version ), - imagenet_structure=False, + style="cifar", **vars(args), ) - cli_main(model, dm, root, net_name, args) + cli_main(model, dm, args.exp_dir, args.exp_name, args) Run this model with, for instance: From 45ac5114ae1eeca7098ccd8e7049a99e18eb3bca Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 15 Mar 2024 21:45:48 +0100 Subject: [PATCH 025/148] :hammer: Rework the regression routine --- torch_uncertainty/layers/distributions.py | 61 ++++++++ torch_uncertainty/losses.py | 7 +- torch_uncertainty/metrics/nll.py | 34 ++++- torch_uncertainty/routines/classification.py | 1 - torch_uncertainty/routines/regression.py | 150 ++++--------------- 5 files changed, 128 insertions(+), 125 deletions(-) create mode 100644 torch_uncertainty/layers/distributions.py diff --git a/torch_uncertainty/layers/distributions.py b/torch_uncertainty/layers/distributions.py new file mode 100644 index 00000000..74e782bd --- /dev/null +++ b/torch_uncertainty/layers/distributions.py @@ -0,0 +1,61 @@ +import torch.nn.functional as F +from torch import Tensor, distributions, nn + + +class AbstractDistLayer(nn.Module): + def __init__(self, dim: int) -> None: + super().__init__() + if dim < 1: + raise ValueError(f"dim must be positive, got {dim}.") + self.dim = dim + + def forward(self, x: Tensor) -> distributions.Distribution: + raise NotImplementedError + + +class IndptNormalDistLayer(AbstractDistLayer): + def __init__(self, dim: int, min_scale: float = 1e-3) -> None: + super().__init__(dim) + if min_scale <= 0: + raise ValueError(f"min_scale must be positive, got {min_scale}.") + self.min_scale = min_scale + + def forward(self, x: Tensor) -> distributions.Normal: + """Forward pass of the independent normal distribution layer. + + Args: + x (Tensor): The input tensor of shape (dx2). + + Returns: + distributions.Normal: The independent normal distribution. + """ + loc = x[:, : self.dim] + scale = F.softplus(x[:, self.dim :]) + self.min_scale + if self.dim == 1: + loc = loc.squeeze(1) + scale = scale.squeeze(1) + return distributions.Normal(loc, scale) + + +class IndptLaplaceDistLayer(AbstractDistLayer): + def __init__(self, dim: int, min_scale: float = 1e-3) -> None: + super().__init__(dim) + if min_scale <= 0: + raise ValueError(f"min_scale must be positive, got {min_scale}.") + self.min_scale = min_scale + + def forward(self, x: Tensor) -> distributions.Laplace: + """Forward pass of the independent normal distribution layer. + + Args: + x (Tensor): The input tensor of shape (dx2). + + Returns: + distributions.Laplace: The independent Laplace distribution. + """ + loc = x[:, : self.dim] + scale = F.softplus(x[:, self.dim :]) + self.min_scale + if self.dim == 1: + loc = loc.squeeze(1) + scale = scale.squeeze(1) + return distributions.Laplace(loc, scale) diff --git a/torch_uncertainty/losses.py b/torch_uncertainty/losses.py index 8afe9661..6e60153a 100644 --- a/torch_uncertainty/losses.py +++ b/torch_uncertainty/losses.py @@ -1,5 +1,5 @@ import torch -from torch import Tensor, nn +from torch import Tensor, distributions, nn from torch.nn import functional as F from .layers.bayesian import bayesian_modules @@ -376,3 +376,8 @@ def forward( elif self.reduction == "sum": loss = loss.sum() return loss + + +class DistributionNLL(nn.Module): + def forward(self, dist: distributions.Distribution, target: Tensor): + return -dist.log_prob(target).mean() diff --git a/torch_uncertainty/metrics/nll.py b/torch_uncertainty/metrics/nll.py index 08c98bb9..14dac386 100644 --- a/torch_uncertainty/metrics/nll.py +++ b/torch_uncertainty/metrics/nll.py @@ -7,9 +7,9 @@ class NegativeLogLikelihood(Metric): - is_differentiable: bool = False - higher_is_better: bool | None = False - full_state_update: bool = False + is_differentiabled = False + higher_is_better = False + full_state_update = False def __init__( self, @@ -140,3 +140,31 @@ def update( mean, target, var, reduction="sum" ) self.total += target.size(0) + + +class DistributionNLL(NegativeLogLikelihood): + def update( + self, dists: torch.distributions.Distribution, target: torch.Tensor + ) -> None: + """Update state with the predicted distributions and the targets. + + Args: + dists (torch.distributions.Distribution): Predicted distributions. + target (torch.Tensor): Ground truth labels. + """ + if self.reduction is None or self.reduction == "none": + self.values.append(-dists.log_prob(target)) + else: + self.values += -dists.log_prob(target).sum() + self.total += target.size(0) + + def compute(self) -> torch.Tensor: + """Computes NLL based on inputs passed in to ``update`` previously.""" + values = dim_zero_cat(self.values) + + if self.reduction == "sum": + return values.sum(dim=-1) + if self.reduction == "mean": + return values.sum(dim=-1) / self.total + # reduction is None or "none" + return values diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 13ab8091..086aaf08 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -66,7 +66,6 @@ def __init__( model (nn.Module): Model to train. loss (type[nn.Module]): Loss function. num_estimators (int): _description_ - optimization_procedure (Any): Optimization procedure. format_batch_fn (nn.Module, optional): Function to format the batch. Defaults to :class:`torch.nn.Identity()`. mixtype (str, optional): Mixup type. Defaults to ``"erm"``. diff --git a/torch_uncertainty/routines/regression.py b/torch_uncertainty/routines/regression.py index bc11a60d..6a5fe8da 100644 --- a/torch_uncertainty/routines/regression.py +++ b/torch_uncertainty/routines/regression.py @@ -1,87 +1,49 @@ -from typing import Literal - -import torch.nn.functional as F -from einops import rearrange +from lightning.pytorch import LightningModule from lightning.pytorch.utilities.types import STEP_OUTPUT -from pytorch_lightning import LightningModule from torch import Tensor, nn from torchmetrics import MeanSquaredError, MetricCollection -from torch_uncertainty.metrics.nll import GaussianNegativeLogLikelihood +from torch_uncertainty.metrics.nll import DistributionNLL class RegressionRoutine(LightningModule): def __init__( self, - dist_estimation: int, + num_features: int, model: nn.Module, loss: type[nn.Module], num_estimators: int, - mode: Literal["mean", "mixture"], - out_features: int | None = 1, + format_batch_fn: nn.Module | None = None, ) -> None: - print("Regression is Work in progress. Raise an issue if interested.") super().__init__() self.model = model self.loss = loss - # metrics - if isinstance(dist_estimation, int): - if dist_estimation <= 0: - raise ValueError( - "Expected the argument ``dist_estimation`` to be integer " - f" larger than 0, but got {dist_estimation}." - ) - else: - raise TypeError( - "Expected the argument ``dist_estimation`` to be integer, but " - f"got {type(dist_estimation)}" - ) - - out_features = list(self.model.parameters())[-1].size(0) - if dist_estimation > out_features: - raise ValueError( - "Expected argument ``dist_estimation`` to be an int lower or " - f"equal than the size of the output layer, but got " - f"{dist_estimation} and {out_features}." - ) - - self.dist_estimation = dist_estimation - - if dist_estimation in (4, 2): - reg_metrics = MetricCollection( - { - "mse": MeanSquaredError(squared=True), - "gnll": GaussianNegativeLogLikelihood(), - }, - compute_groups=False, - ) - else: - reg_metrics = MetricCollection( - { - "mse": MeanSquaredError(squared=True), - }, - compute_groups=False, - ) - + self.format_batch_fn = format_batch_fn + + reg_metrics = MetricCollection( + { + "mse": MeanSquaredError(squared=True), + "nll": DistributionNLL(), + }, + compute_groups=False, + ) self.val_metrics = reg_metrics.clone(prefix="reg_val/") self.test_metrics = reg_metrics.clone(prefix="reg_test/") - if mode == "mixture": - raise NotImplementedError( - "Mixture of gaussians not implemented yet. Raise an issue if " - "needed." - ) - - self.mode = mode self.num_estimators = num_estimators - self.out_features = out_features def on_train_start(self) -> None: # hyperparameters for performances init_metrics = {k: 0 for k in self.val_metrics} init_metrics.update({k: 0 for k in self.test_metrics}) + if self.logger is not None: # coverage: ignore + self.logger.log_hyperparams( + self.hparams, + init_metrics, + ) + def forward(self, inputs: Tensor) -> Tensor: return self.model.forward(inputs) @@ -99,18 +61,7 @@ def training_step( logits = self.forward(inputs) - if self.dist_estimation == 4: - means, v, alpha, beta = logits.split(1, dim=-1) - v = F.softplus(v) - alpha = 1 + F.softplus(alpha) - beta = F.softplus(beta) - loss = self.criterion(means, v, alpha, beta, targets) - elif self.dist_estimation == 2: - means = logits[..., 0] - variances = F.softplus(logits[..., 1]) - loss = self.criterion(means, targets, variances) - else: - loss = self.criterion(logits, targets) + loss = self.criterion(logits, targets) self.log("train_loss", loss) return loss @@ -119,31 +70,15 @@ def validation_step( self, batch: tuple[Tensor, Tensor], batch_idx: int ) -> None: inputs, targets = batch - logits = self.forward(inputs) - - if self.out_features == 1: - logits = rearrange( - logits, "(m b) dist -> b m dist", m=self.num_estimators - ) - else: - logits = rearrange( - logits, - "(m b) (f dist) -> b f m dist", - m=self.num_estimators, - f=self.out_features, - ) - if self.mode == "mean": - logits = logits.mean(dim=1) + dists = self.forward(inputs) - if self.dist_estimation == 2: - means = logits[..., 0] - variances = F.softplus(logits[..., 1]) - self.val_metrics.gnll.update(means, targets, variances) - else: - means = logits + self.val_metrics.mse.update(dists.loc, targets) + self.val_metrics.nll.update(dists, targets) - self.val_metrics.mse.update(means, targets) + def on_validation_epoch_end(self) -> None: + self.log_dict(self.val_metrics.compute()) + self.val_metrics.reset() def test_step( self, @@ -158,37 +93,12 @@ def test_step( ) inputs, targets = batch - logits = self.forward(inputs) + dists = self.forward(inputs) - if self.out_features == 1: - logits = rearrange( - logits, "(m b) dist -> b m dist", m=self.num_estimators - ) - else: - logits = rearrange( - logits, - "(m b) (f dist) -> b f m dist", - m=self.num_estimators, - f=self.out_features, - ) - - if self.mode == "mean": - logits = logits.mean(dim=1) - - if self.dist_estimation == 2: - means = logits[..., 0] - variances = F.softplus(logits[..., 1]) - self.test_metrics.gnll.update(means, targets, variances) - else: - means = logits - - self.test_metrics.mse.update(means, targets) - - def validation_epoch_end(self, outputs) -> None: - self.log_dict(self.val_metrics.compute()) - self.val_metrics.reset() + self.test_metrics.mse.update(dists.loc, targets) + self.test_metrics.nll.update(dists, targets) - def test_epoch_end(self, outputs) -> None: + def on_test_epoch_end(self) -> None: self.log_dict( self.test_metrics.compute(), ) From 904881ba40931d1a346d533111e28fa1c08f7269 Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 15 Mar 2024 21:46:07 +0100 Subject: [PATCH 026/148] :zap: Ruff with updated rules --- torch_uncertainty/layers/bayesian/bayes_conv.py | 4 +--- torch_uncertainty/layers/bayesian/bayes_linear.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/torch_uncertainty/layers/bayesian/bayes_conv.py b/torch_uncertainty/layers/bayesian/bayes_conv.py index 450d9f63..95060560 100644 --- a/torch_uncertainty/layers/bayesian/bayes_conv.py +++ b/torch_uncertainty/layers/bayesian/bayes_conv.py @@ -84,9 +84,7 @@ def __init__( valid_padding_modes = {"zeros", "reflect", "replicate", "circular"} if padding_mode not in valid_padding_modes: raise ValueError( - "padding_mode must be one of {}, but got '{}'".format( - valid_padding_modes, padding_mode - ) + f"padding_mode must be one of {valid_padding_modes}, but got '{padding_mode}'" ) if transposed: diff --git a/torch_uncertainty/layers/bayesian/bayes_linear.py b/torch_uncertainty/layers/bayesian/bayes_linear.py index 9dd1d06e..f722f842 100644 --- a/torch_uncertainty/layers/bayesian/bayes_linear.py +++ b/torch_uncertainty/layers/bayesian/bayes_linear.py @@ -146,6 +146,4 @@ def sample(self) -> tuple[Tensor, Tensor | None]: return weight, bias def extra_repr(self) -> str: - return "in_features={}, out_features={}, bias={}".format( - self.in_features, self.out_features, self.bias_mu is not None - ) + return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias_mu is not None}" From 1a7fd29c770172d06deeaf677e949756fc21ceb1 Mon Sep 17 00:00:00 2001 From: alafage Date: Sat, 16 Mar 2024 01:33:00 +0100 Subject: [PATCH 027/148] :bug: Remove duplicate model initialization --- torch_uncertainty/baselines/classification/resnet.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torch_uncertainty/baselines/classification/resnet.py b/torch_uncertainty/baselines/classification/resnet.py index 7efda160..21e5c825 100644 --- a/torch_uncertainty/baselines/classification/resnet.py +++ b/torch_uncertainty/baselines/classification/resnet.py @@ -282,7 +282,6 @@ def __init__( last_layer=last_layer_dropout, ) - model = self.versions[version][self.archs.index(arch)](**params) super().__init__( num_classes=num_classes, model=model, From b8e83d9e411b58fec96940e2f4366402dcc8a24f Mon Sep 17 00:00:00 2001 From: alafage Date: Sat, 16 Mar 2024 01:35:57 +0100 Subject: [PATCH 028/148] :hammer: Change seed_everything to ``False`` in classification config files --- experiments/classification/cifar10/configs/resnet.yaml | 2 +- .../classification/cifar10/configs/resnet18/batched.yaml | 2 +- experiments/classification/cifar10/configs/resnet18/masked.yaml | 2 +- experiments/classification/cifar10/configs/resnet18/mimo.yaml | 2 +- experiments/classification/cifar10/configs/resnet18/packed.yaml | 2 +- .../classification/cifar10/configs/resnet18/standard.yaml | 2 +- .../classification/cifar10/configs/resnet50/batched.yaml | 2 +- experiments/classification/cifar10/configs/resnet50/masked.yaml | 2 +- experiments/classification/cifar10/configs/resnet50/mimo.yaml | 2 +- experiments/classification/cifar10/configs/resnet50/packed.yaml | 2 +- .../classification/cifar10/configs/resnet50/standard.yaml | 2 +- experiments/classification/cifar10/configs/wideresnet28x10.yaml | 2 +- .../classification/cifar10/configs/wideresnet28x10/batched.yaml | 2 +- .../classification/cifar10/configs/wideresnet28x10/masked.yaml | 2 +- .../classification/cifar10/configs/wideresnet28x10/mimo.yaml | 2 +- .../classification/cifar10/configs/wideresnet28x10/packed.yaml | 2 +- .../cifar10/configs/wideresnet28x10/standard.yaml | 2 +- experiments/classification/cifar100/configs/resnet.yaml | 2 +- .../classification/cifar100/configs/resnet18/batched.yaml | 2 +- .../classification/cifar100/configs/resnet18/masked.yaml | 2 +- experiments/classification/cifar100/configs/resnet18/mimo.yaml | 2 +- .../classification/cifar100/configs/resnet18/packed.yaml | 2 +- .../classification/cifar100/configs/resnet18/standard.yaml | 2 +- .../classification/cifar100/configs/resnet50/batched.yaml | 2 +- .../classification/cifar100/configs/resnet50/masked.yaml | 2 +- experiments/classification/cifar100/configs/resnet50/mimo.yaml | 2 +- .../classification/cifar100/configs/resnet50/packed.yaml | 2 +- .../classification/cifar100/configs/resnet50/standard.yaml | 2 +- 28 files changed, 28 insertions(+), 28 deletions(-) diff --git a/experiments/classification/cifar10/configs/resnet.yaml b/experiments/classification/cifar10/configs/resnet.yaml index fc1197d7..fb396273 100644 --- a/experiments/classification/cifar10/configs/resnet.yaml +++ b/experiments/classification/cifar10/configs/resnet.yaml @@ -1,5 +1,5 @@ # lightning.pytorch==2.1.3 -seed_everything: true +seed_everything: false eval_after_fit: true trainer: accelerator: gpu diff --git a/experiments/classification/cifar10/configs/resnet18/batched.yaml b/experiments/classification/cifar10/configs/resnet18/batched.yaml index 34ab9fc5..9596dc65 100644 --- a/experiments/classification/cifar10/configs/resnet18/batched.yaml +++ b/experiments/classification/cifar10/configs/resnet18/batched.yaml @@ -1,5 +1,5 @@ # lightning.pytorch==2.1.3 -seed_everything: true +seed_everything: false eval_after_fit: true trainer: accelerator: gpu diff --git a/experiments/classification/cifar10/configs/resnet18/masked.yaml b/experiments/classification/cifar10/configs/resnet18/masked.yaml index f84533ba..958c8c25 100644 --- a/experiments/classification/cifar10/configs/resnet18/masked.yaml +++ b/experiments/classification/cifar10/configs/resnet18/masked.yaml @@ -1,5 +1,5 @@ # lightning.pytorch==2.1.3 -seed_everything: true +seed_everything: false eval_after_fit: true trainer: accelerator: gpu diff --git a/experiments/classification/cifar10/configs/resnet18/mimo.yaml b/experiments/classification/cifar10/configs/resnet18/mimo.yaml index 8577ab3b..c642e877 100644 --- a/experiments/classification/cifar10/configs/resnet18/mimo.yaml +++ b/experiments/classification/cifar10/configs/resnet18/mimo.yaml @@ -1,5 +1,5 @@ # lightning.pytorch==2.1.3 -seed_everything: true +seed_everything: false eval_after_fit: true trainer: accelerator: gpu diff --git a/experiments/classification/cifar10/configs/resnet18/packed.yaml b/experiments/classification/cifar10/configs/resnet18/packed.yaml index 1e6853df..c6c4ecd4 100644 --- a/experiments/classification/cifar10/configs/resnet18/packed.yaml +++ b/experiments/classification/cifar10/configs/resnet18/packed.yaml @@ -1,5 +1,5 @@ # lightning.pytorch==2.1.3 -seed_everything: true +seed_everything: false eval_after_fit: true trainer: accelerator: gpu diff --git a/experiments/classification/cifar10/configs/resnet18/standard.yaml b/experiments/classification/cifar10/configs/resnet18/standard.yaml index 2d835a6a..e930813d 100644 --- a/experiments/classification/cifar10/configs/resnet18/standard.yaml +++ b/experiments/classification/cifar10/configs/resnet18/standard.yaml @@ -1,5 +1,5 @@ # lightning.pytorch==2.1.3 -seed_everything: true +seed_everything: false eval_after_fit: true trainer: accelerator: gpu diff --git a/experiments/classification/cifar10/configs/resnet50/batched.yaml b/experiments/classification/cifar10/configs/resnet50/batched.yaml index 60841124..396a268f 100644 --- a/experiments/classification/cifar10/configs/resnet50/batched.yaml +++ b/experiments/classification/cifar10/configs/resnet50/batched.yaml @@ -1,5 +1,5 @@ # lightning.pytorch==2.1.3 -seed_everything: true +seed_everything: false eval_after_fit: true trainer: accelerator: gpu diff --git a/experiments/classification/cifar10/configs/resnet50/masked.yaml b/experiments/classification/cifar10/configs/resnet50/masked.yaml index eea7ac81..195c8338 100644 --- a/experiments/classification/cifar10/configs/resnet50/masked.yaml +++ b/experiments/classification/cifar10/configs/resnet50/masked.yaml @@ -1,5 +1,5 @@ # lightning.pytorch==2.1.3 -seed_everything: true +seed_everything: false eval_after_fit: true trainer: accelerator: gpu diff --git a/experiments/classification/cifar10/configs/resnet50/mimo.yaml b/experiments/classification/cifar10/configs/resnet50/mimo.yaml index cd6681a3..939f2897 100644 --- a/experiments/classification/cifar10/configs/resnet50/mimo.yaml +++ b/experiments/classification/cifar10/configs/resnet50/mimo.yaml @@ -1,5 +1,5 @@ # lightning.pytorch==2.1.3 -seed_everything: true +seed_everything: false eval_after_fit: true trainer: accelerator: gpu diff --git a/experiments/classification/cifar10/configs/resnet50/packed.yaml b/experiments/classification/cifar10/configs/resnet50/packed.yaml index 816ddb33..ac99f2f0 100644 --- a/experiments/classification/cifar10/configs/resnet50/packed.yaml +++ b/experiments/classification/cifar10/configs/resnet50/packed.yaml @@ -1,5 +1,5 @@ # lightning.pytorch==2.1.3 -seed_everything: true +seed_everything: false eval_after_fit: true trainer: accelerator: gpu diff --git a/experiments/classification/cifar10/configs/resnet50/standard.yaml b/experiments/classification/cifar10/configs/resnet50/standard.yaml index e3ff7407..6e0e719e 100644 --- a/experiments/classification/cifar10/configs/resnet50/standard.yaml +++ b/experiments/classification/cifar10/configs/resnet50/standard.yaml @@ -1,5 +1,5 @@ # lightning.pytorch==2.1.3 -seed_everything: true +seed_everything: false eval_after_fit: true trainer: accelerator: gpu diff --git a/experiments/classification/cifar10/configs/wideresnet28x10.yaml b/experiments/classification/cifar10/configs/wideresnet28x10.yaml index f9423cfb..8d88ad09 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10.yaml @@ -1,5 +1,5 @@ # lightning.pytorch==2.1.3 -seed_everything: true +seed_everything: false eval_after_fit: true trainer: accelerator: gpu diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/batched.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/batched.yaml index 59bb6213..eeca402d 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10/batched.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10/batched.yaml @@ -1,5 +1,5 @@ # lightning.pytorch==2.1.3 -seed_everything: true +seed_everything: false eval_after_fit: true trainer: accelerator: gpu diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/masked.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/masked.yaml index 363cf464..74a0b950 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10/masked.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10/masked.yaml @@ -1,5 +1,5 @@ # lightning.pytorch==2.1.3 -seed_everything: true +seed_everything: false eval_after_fit: true trainer: accelerator: gpu diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/mimo.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/mimo.yaml index b186909b..782c1202 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10/mimo.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10/mimo.yaml @@ -1,5 +1,5 @@ # lightning.pytorch==2.1.3 -seed_everything: true +seed_everything: false eval_after_fit: true trainer: accelerator: gpu diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/packed.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/packed.yaml index 0ae5c3ca..e3af37c5 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10/packed.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10/packed.yaml @@ -1,5 +1,5 @@ # lightning.pytorch==2.1.3 -seed_everything: true +seed_everything: false eval_after_fit: true trainer: accelerator: gpu diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/standard.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/standard.yaml index 200c571c..ebfd0f2f 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10/standard.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10/standard.yaml @@ -1,5 +1,5 @@ # lightning.pytorch==2.1.3 -seed_everything: true +seed_everything: false eval_after_fit: true trainer: accelerator: gpu diff --git a/experiments/classification/cifar100/configs/resnet.yaml b/experiments/classification/cifar100/configs/resnet.yaml index fc1197d7..fb396273 100644 --- a/experiments/classification/cifar100/configs/resnet.yaml +++ b/experiments/classification/cifar100/configs/resnet.yaml @@ -1,5 +1,5 @@ # lightning.pytorch==2.1.3 -seed_everything: true +seed_everything: false eval_after_fit: true trainer: accelerator: gpu diff --git a/experiments/classification/cifar100/configs/resnet18/batched.yaml b/experiments/classification/cifar100/configs/resnet18/batched.yaml index 358e0f62..4410892d 100644 --- a/experiments/classification/cifar100/configs/resnet18/batched.yaml +++ b/experiments/classification/cifar100/configs/resnet18/batched.yaml @@ -1,5 +1,5 @@ # lightning.pytorch==2.1.3 -seed_everything: true +seed_everything: false eval_after_fit: true trainer: accelerator: gpu diff --git a/experiments/classification/cifar100/configs/resnet18/masked.yaml b/experiments/classification/cifar100/configs/resnet18/masked.yaml index 93c57ae4..3bd98005 100644 --- a/experiments/classification/cifar100/configs/resnet18/masked.yaml +++ b/experiments/classification/cifar100/configs/resnet18/masked.yaml @@ -1,5 +1,5 @@ # lightning.pytorch==2.1.3 -seed_everything: true +seed_everything: false eval_after_fit: true trainer: accelerator: gpu diff --git a/experiments/classification/cifar100/configs/resnet18/mimo.yaml b/experiments/classification/cifar100/configs/resnet18/mimo.yaml index 3e6da8e8..ee3efcb9 100644 --- a/experiments/classification/cifar100/configs/resnet18/mimo.yaml +++ b/experiments/classification/cifar100/configs/resnet18/mimo.yaml @@ -1,5 +1,5 @@ # lightning.pytorch==2.1.3 -seed_everything: true +seed_everything: false eval_after_fit: true trainer: accelerator: gpu diff --git a/experiments/classification/cifar100/configs/resnet18/packed.yaml b/experiments/classification/cifar100/configs/resnet18/packed.yaml index 39f384db..2a0d7c47 100644 --- a/experiments/classification/cifar100/configs/resnet18/packed.yaml +++ b/experiments/classification/cifar100/configs/resnet18/packed.yaml @@ -1,5 +1,5 @@ # lightning.pytorch==2.1.3 -seed_everything: true +seed_everything: false eval_after_fit: true trainer: accelerator: gpu diff --git a/experiments/classification/cifar100/configs/resnet18/standard.yaml b/experiments/classification/cifar100/configs/resnet18/standard.yaml index 41452d61..235a6382 100644 --- a/experiments/classification/cifar100/configs/resnet18/standard.yaml +++ b/experiments/classification/cifar100/configs/resnet18/standard.yaml @@ -1,5 +1,5 @@ # lightning.pytorch==2.1.3 -seed_everything: true +seed_everything: false eval_after_fit: true trainer: accelerator: gpu diff --git a/experiments/classification/cifar100/configs/resnet50/batched.yaml b/experiments/classification/cifar100/configs/resnet50/batched.yaml index 1a344716..38331ee5 100644 --- a/experiments/classification/cifar100/configs/resnet50/batched.yaml +++ b/experiments/classification/cifar100/configs/resnet50/batched.yaml @@ -1,5 +1,5 @@ # lightning.pytorch==2.1.3 -seed_everything: true +seed_everything: false eval_after_fit: true trainer: accelerator: gpu diff --git a/experiments/classification/cifar100/configs/resnet50/masked.yaml b/experiments/classification/cifar100/configs/resnet50/masked.yaml index be065061..09168144 100644 --- a/experiments/classification/cifar100/configs/resnet50/masked.yaml +++ b/experiments/classification/cifar100/configs/resnet50/masked.yaml @@ -1,5 +1,5 @@ # lightning.pytorch==2.1.3 -seed_everything: true +seed_everything: false eval_after_fit: true trainer: accelerator: gpu diff --git a/experiments/classification/cifar100/configs/resnet50/mimo.yaml b/experiments/classification/cifar100/configs/resnet50/mimo.yaml index 76f4f8b8..387e6dd9 100644 --- a/experiments/classification/cifar100/configs/resnet50/mimo.yaml +++ b/experiments/classification/cifar100/configs/resnet50/mimo.yaml @@ -1,5 +1,5 @@ # lightning.pytorch==2.1.3 -seed_everything: true +seed_everything: false eval_after_fit: true trainer: accelerator: gpu diff --git a/experiments/classification/cifar100/configs/resnet50/packed.yaml b/experiments/classification/cifar100/configs/resnet50/packed.yaml index 6bc068e7..3813fba7 100644 --- a/experiments/classification/cifar100/configs/resnet50/packed.yaml +++ b/experiments/classification/cifar100/configs/resnet50/packed.yaml @@ -1,5 +1,5 @@ # lightning.pytorch==2.1.3 -seed_everything: true +seed_everything: false eval_after_fit: true trainer: accelerator: gpu diff --git a/experiments/classification/cifar100/configs/resnet50/standard.yaml b/experiments/classification/cifar100/configs/resnet50/standard.yaml index d7580437..7f236d4f 100644 --- a/experiments/classification/cifar100/configs/resnet50/standard.yaml +++ b/experiments/classification/cifar100/configs/resnet50/standard.yaml @@ -1,5 +1,5 @@ # lightning.pytorch==2.1.3 -seed_everything: true +seed_everything: false eval_after_fit: true trainer: accelerator: gpu From 97090cc4944ae591778564405260d3c0f88b7ad4 Mon Sep 17 00:00:00 2001 From: alafage Date: Sat, 16 Mar 2024 01:36:59 +0100 Subject: [PATCH 029/148] :hammer: Change IOU metric to Mean_IOU - slight modifications in ``SegmentationRoutine`` --- torch_uncertainty/metrics/__init__.py | 2 +- .../metrics/{iou.py => mean_iou.py} | 8 +++--- torch_uncertainty/routines/segmentation.py | 27 ++++++++++++------- 3 files changed, 22 insertions(+), 15 deletions(-) rename torch_uncertainty/metrics/{iou.py => mean_iou.py} (78%) diff --git a/torch_uncertainty/metrics/__init__.py b/torch_uncertainty/metrics/__init__.py index 9236190f..3d132fbe 100644 --- a/torch_uncertainty/metrics/__init__.py +++ b/torch_uncertainty/metrics/__init__.py @@ -5,7 +5,7 @@ from .entropy import Entropy from .fpr95 import FPR95 from .grouping_loss import GroupingLoss -from .iou import IntersectionOverUnion +from .mean_iou import MeanIntersectionOverUnion from .mutual_information import MutualInformation from .nll import GaussianNegativeLogLikelihood, NegativeLogLikelihood from .sparsification import AUSE diff --git a/torch_uncertainty/metrics/iou.py b/torch_uncertainty/metrics/mean_iou.py similarity index 78% rename from torch_uncertainty/metrics/iou.py rename to torch_uncertainty/metrics/mean_iou.py index a6695b1f..1a3dec82 100644 --- a/torch_uncertainty/metrics/iou.py +++ b/torch_uncertainty/metrics/mean_iou.py @@ -4,8 +4,8 @@ from torchmetrics.utilities.compute import _safe_divide -class IntersectionOverUnion(MulticlassStatScores): - """Compute the Intersection over Union (IoU) score.""" +class MeanIntersectionOverUnion(MulticlassStatScores): + """Compute the MeanIntersection over Union (IoU) score.""" is_differentiable: bool = False higher_is_better: bool = True @@ -29,6 +29,6 @@ def update(self, preds: Tensor, target: Tensor) -> None: super().update(preds, target) def compute(self) -> Tensor: - """Compute the Intersection over Union (IoU) based on saved inputs.""" + """Compute the Means Intersection over Union (MIoU) based on saved inputs.""" tp, fp, _, fn = self._final_state() - return _safe_divide(tp, tp + fp + fn) + return _safe_divide(tp, tp + fp + fn).mean() diff --git a/torch_uncertainty/routines/segmentation.py b/torch_uncertainty/routines/segmentation.py index c160c3e0..2d060fa4 100644 --- a/torch_uncertainty/routines/segmentation.py +++ b/torch_uncertainty/routines/segmentation.py @@ -3,7 +3,7 @@ from torch import Tensor, nn from torchmetrics import MetricCollection -from torch_uncertainty.metrics import IntersectionOverUnion +from torch_uncertainty.metrics import MeanIntersectionOverUnion class SegmentationRoutine(LightningModule): @@ -11,27 +11,34 @@ def __init__( self, num_classes: int, model: nn.Module, - loss: nn.Module, + loss: type[nn.Module], num_estimators: int, format_batch_fn: nn.Module | None = None, ) -> None: super().__init__() + if format_batch_fn is None: + format_batch_fn = nn.Identity() + self.num_classes = num_classes self.model = model self.loss = loss - self.metric_to_monitor = "hp/val_iou" + self.metric_to_monitor = "val/mean_iou" # metrics seg_metrics = MetricCollection( { - "iou": IntersectionOverUnion(num_classes=num_classes), + "mean_iou": MeanIntersectionOverUnion(num_classes=num_classes), } ) - self.val_seg_metrics = seg_metrics.clone(prefix="hp/val_") - self.test_seg_metrics = seg_metrics.clone(prefix="hp/test_") + self.val_seg_metrics = seg_metrics.clone(prefix="val/") + self.test_seg_metrics = seg_metrics.clone(prefix="test/") + + @property + def criterion(self) -> nn.Module: + return self.loss() def forward(self, img: Tensor) -> Tensor: return self.model(img) @@ -47,16 +54,16 @@ def training_step( ) -> STEP_OUTPUT: img, target = batch pred = self.forward(img) - loss = self.loss(pred, target) + loss = self.criterion(pred, target) self.log("train_loss", loss) return loss def validation_step( self, batch: tuple[Tensor, Tensor], batch_idx: int ) -> None: - img, target = batch - pred = self.forward(img) - self.val_seg_metrics.update(pred, target) + img, targets = batch + logits = self.forward(img) + self.val_seg_metrics.update(logits, targets) def test_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> None: img, target = batch From 05c2ff22abcc0205469587ce6216862b37d74446 Mon Sep 17 00:00:00 2001 From: alafage Date: Sat, 16 Mar 2024 01:40:02 +0100 Subject: [PATCH 030/148] :bug: Fix CamVid Dataset to work with ``torchvision.transforms.v2`` --- .../datasets/segmentation/camvid.py | 66 ++++++++++++++++--- 1 file changed, 58 insertions(+), 8 deletions(-) diff --git a/torch_uncertainty/datasets/segmentation/camvid.py b/torch_uncertainty/datasets/segmentation/camvid.py index a6a1ddbd..211f0ec5 100644 --- a/torch_uncertainty/datasets/segmentation/camvid.py +++ b/torch_uncertainty/datasets/segmentation/camvid.py @@ -4,6 +4,8 @@ from pathlib import Path from typing import Literal, NamedTuple +import torch +from einops import rearrange, repeat from PIL import Image from torchvision import tv_tensors from torchvision.datasets import VisionDataset @@ -11,6 +13,7 @@ download_and_extract_archive, download_url, ) +from torchvision.transforms.v2 import functional as F class CamVidClass(NamedTuple): @@ -61,7 +64,7 @@ def __init__( self, root: str, split: Literal["train", "val", "test"] | None = None, - transforms: Callable | None = None, + transform: Callable | None = None, download: bool = False, ) -> None: """`CamVid `_ Dataset. @@ -71,7 +74,7 @@ def __init__( will be saved to if download is set to ``True``. split (str, optional): The dataset split, supports ``train``, ``val`` and ``test``. Default: ``None``. - transforms (callable, optional): A function/transform that takes + transform (callable, optional): A function/transform that takes input sample and its target as entry and returns a transformed version. download (bool, optional): If true, downloads the dataset from the @@ -84,7 +87,7 @@ def __init__( "Supported splits are ['train', 'val', 'test', None]" ) - super().__init__(root, transforms, None, None) + super().__init__(root, transform, None, None) if download: self.download() @@ -122,12 +125,57 @@ def __init__( for path in (Path(self.root) / "camvid" / "label").glob( "*.png" ) - if path.stem in filenames + if path.stem[:-2] in filenames ] ) + self.transform = transform self.split = split if split is not None else "all" + def encode_target(self, target: Image.Image) -> torch.Tensor: + """Encode target image to tensor. + + Args: + target (Image.Image): Target image. + + Returns: + torch.Tensor: Encoded target. + """ + colored_target = F.pil_to_tensor(target) + colored_target = rearrange(colored_target, "c h w -> h w c") + target = torch.zeros_like(colored_target[..., :1]) + # convert target color to index + for camvid_class in self.classes: + target[ + ( + colored_target + == torch.tensor(camvid_class.color, dtype=target.dtype) + ).all(dim=-1) + ] = camvid_class.index + + return rearrange(target, "h w c -> c h w").squeeze(0) + + def decode_target(self, target: torch.Tensor) -> Image.Image: + """Decode target tensor to image. + + Args: + target (torch.Tensor): Target tensor. + + Returns: + Image.Image: Decoded target. + """ + colored_target = repeat(target.clone(), "h w -> h w 3", c=3) + + for camvid_class in self.classes: + colored_target[ + ( + target + == torch.tensor(camvid_class.index, dtype=target.dtype) + ).all(dim=0) + ] = torch.tensor(camvid_class.color, dtype=target.dtype) + + return F.to_pil_image(rearrange(colored_target, "h w c -> c h w")) + def __getitem__(self, index: int) -> tuple: """Get image and target at index. @@ -138,10 +186,12 @@ def __getitem__(self, index: int) -> tuple: tuple: (image, target) where target is the segmentation mask. """ image = tv_tensors.Image(Image.open(self.images[index]).convert("RGB")) - target = tv_tensors.Mask(Image.open(self.targets[index])) + target = tv_tensors.Mask( + self.encode_target(Image.open(self.targets[index])) + ) - if self.transforms is not None: - image, target = self.transforms(image, target) + if self.transform is not None: + image, target = self.transform(image, target) return image, target @@ -173,7 +223,7 @@ def download(self) -> None: print("Files already downloaded and verified") return - if Path(self.root) / self.base_folder: + if (Path(self.root) / self.base_folder).exists(): shutil.rmtree(Path(self.root) / self.base_folder) download_and_extract_archive( From 6c68e01b4c0458a5b4add2d0de76d8e8f13b6ad3 Mon Sep 17 00:00:00 2001 From: alafage Date: Sat, 16 Mar 2024 01:41:46 +0100 Subject: [PATCH 031/148] :hammer: Update transform methods in CamVid DataModule --- .../datamodules/segmentation/camvid.py | 30 +++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/torch_uncertainty/datamodules/segmentation/camvid.py b/torch_uncertainty/datamodules/segmentation/camvid.py index b8186dd1..bfcd4bb8 100644 --- a/torch_uncertainty/datamodules/segmentation/camvid.py +++ b/torch_uncertainty/datamodules/segmentation/camvid.py @@ -1,5 +1,7 @@ from pathlib import Path +import torch +from torchvision import tv_tensors from torchvision.transforms import v2 from torch_uncertainty.datamodules.abstract import AbstractDataModule @@ -27,10 +29,34 @@ def __init__( self.dataset = CamVid self.transform_train = v2.Compose( - [v2.Resize((360, 480), interpolation=v2.InterpolationMode.NEAREST)] + [ + v2.Resize( + (360, 480), interpolation=v2.InterpolationMode.NEAREST + ), + v2.ToDtype( + dtype={ + tv_tensors.Image: torch.float32, + tv_tensors.Mask: torch.int64, + "others": None, + }, + scale=True, + ), + ] ) self.transform_test = v2.Compose( - [v2.Resize((360, 480), interpolation=v2.InterpolationMode.NEAREST)] + [ + v2.Resize( + (360, 480), interpolation=v2.InterpolationMode.NEAREST + ), + v2.ToDtype( + dtype={ + tv_tensors.Image: torch.float32, + tv_tensors.Mask: torch.int64, + "others": None, + }, + scale=True, + ), + ] ) def prepare_data(self) -> None: # coverage: ignore From 682705efaec2ebbe7119aecb6265970cf2995923 Mon Sep 17 00:00:00 2001 From: alafage Date: Sat, 16 Mar 2024 01:53:12 +0100 Subject: [PATCH 032/148] :sparkles: Add SegFormer models --- .../baselines/segmentation/__init__.py | 2 + .../baselines/segmentation/segformer.py | 110 +++ .../models/segmentation/__init__.py | 0 .../models/segmentation/segformer/__init__.py | 2 + .../models/segmentation/segformer/std.py | 880 ++++++++++++++++++ 5 files changed, 994 insertions(+) create mode 100644 torch_uncertainty/baselines/segmentation/segformer.py create mode 100644 torch_uncertainty/models/segmentation/__init__.py create mode 100644 torch_uncertainty/models/segmentation/segformer/__init__.py create mode 100644 torch_uncertainty/models/segmentation/segformer/std.py diff --git a/torch_uncertainty/baselines/segmentation/__init__.py b/torch_uncertainty/baselines/segmentation/__init__.py index e69de29b..d9e05601 100644 --- a/torch_uncertainty/baselines/segmentation/__init__.py +++ b/torch_uncertainty/baselines/segmentation/__init__.py @@ -0,0 +1,2 @@ +# ruff: noqa: F401 +from .segformer import SegFormer diff --git a/torch_uncertainty/baselines/segmentation/segformer.py b/torch_uncertainty/baselines/segmentation/segformer.py new file mode 100644 index 00000000..4c95bdaf --- /dev/null +++ b/torch_uncertainty/baselines/segmentation/segformer.py @@ -0,0 +1,110 @@ +from typing import Literal + +from torch import Tensor, nn +from torchvision.transforms.v2 import functional as F + +from torch_uncertainty.models.segmentation.segformer import ( + segformer_b0, + segformer_b1, + segformer_b2, + segformer_b3, + segformer_b4, + segformer_b5, +) +from torch_uncertainty.routines.segmentation import SegmentationRoutine + + +class SegFormer(SegmentationRoutine): + single = ["std"] + versions = { + "std": [ + segformer_b0, + segformer_b1, + segformer_b2, + segformer_b3, + segformer_b4, + segformer_b5, + ] + } + archs = [0, 1, 2, 3, 4, 5] + + def __init__( + self, + num_classes: int, + loss: type[nn.Module], + version: Literal["std"], + arch: int, + num_estimators: int = 1, + ) -> None: + r"""SegFormer backbone baseline for segmentation providing support for + various versions and architectures. + + Args: + num_classes (int): Number of classes to predict. + loss (type[Module]): Training loss. + version (str): + Determines which SegFormer version to use. Options are: + + - ``"std"``: original SegFormer + + arch (int): + Determines which architecture to use. Options are: + + - ``0``: SegFormer-B0 + - ``1``: SegFormer-B1 + - ``2``: SegFormer-B2 + - ``3``: SegFormer-B3 + - ``4``: SegFormer-B4 + - ``5``: SegFormer-B5 + + num_estimators (int, optional): _description_. Defaults to 1. + """ + params = { + "num_classes": num_classes, + } + + format_batch_fn = nn.Identity() + + if version not in self.versions: + raise ValueError(f"Unknown version {version}") + + model = self.versions[version][self.archs.index(arch)](**params) + + super().__init__( + num_classes=num_classes, + model=model, + loss=loss, + num_estimators=num_estimators, + format_batch_fn=format_batch_fn, + ) + self.save_hyperparameters() + + def training_step( + self, batch: tuple[Tensor, Tensor], batch_idx: int + ) -> Tensor: + img, target = batch + logits = self.forward(img) + target = F.resize( + target, logits.shape[-2:], interpolation=F.InterpolationMode.NEAREST + ) + loss = self.criterion(logits, target) + self.log("train_loss", loss) + return loss + + def validation_step( + self, batch: tuple[Tensor, Tensor], batch_idx: int + ) -> None: + img, target = batch + logits = self.forward(img) + target = F.resize( + target, logits.shape[-2:], interpolation=F.InterpolationMode.NEAREST + ) + self.val_seg_metrics.update(logits, target) + + def test_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> None: + img, target = batch + logits = self.forward(img) + target = F.resize( + target, logits.shape[-2:], interpolation=F.InterpolationMode.NEAREST + ) + self.test_seg_metrics.update(logits, target) diff --git a/torch_uncertainty/models/segmentation/__init__.py b/torch_uncertainty/models/segmentation/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/torch_uncertainty/models/segmentation/segformer/__init__.py b/torch_uncertainty/models/segmentation/segformer/__init__.py new file mode 100644 index 00000000..dc3fb2ee --- /dev/null +++ b/torch_uncertainty/models/segmentation/segformer/__init__.py @@ -0,0 +1,2 @@ +# ruff: noqa: F401, F403 +from .std import * diff --git a/torch_uncertainty/models/segmentation/segformer/std.py b/torch_uncertainty/models/segmentation/segformer/std.py new file mode 100644 index 00000000..1b735760 --- /dev/null +++ b/torch_uncertainty/models/segmentation/segformer/std.py @@ -0,0 +1,880 @@ +# --------------------------------------------------------------- +# Copyright (c) 2021, NVIDIA Corporation. All rights reserved. +# +# This work is licensed under the NVIDIA Source Code License +# --------------------------------------------------------------- + +import math +import warnings +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class DWConv(nn.Module): + def __init__(self, dim=768): + super().__init__() + self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, x, h, w): + b, _, c = x.shape + x = x.transpose(1, 2).view(b, c, h, w) + x = self.dwconv(x) + return x.flatten(2).transpose(1, 2) + + +class Mlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.dwconv = DWConv(hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, h, w): + x = self.fc1(x) + x = self.dwconv(x, h, w) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + return self.drop(x) + + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + sr_ratio=1, + ): + super().__init__() + assert ( + dim % num_heads == 0 + ), f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = nn.LayerNorm(dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, h, w): + b, n, c = x.shape + q = ( + self.q(x) + .reshape(b, n, self.num_heads, c // self.num_heads) + .permute(0, 2, 1, 3) + ) + + if self.sr_ratio > 1: + x_ = x.permute(0, 2, 1).reshape(b, c, h, w) + x_ = self.sr(x_).reshape(b, c, -1).permute(0, 2, 1) + x_ = self.norm(x_) + kv = ( + self.kv(x_) + .reshape(b, -1, 2, self.num_heads, c // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + else: + kv = ( + self.kv(x) + .reshape(b, -1, 2, self.num_heads, c // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + k, v = kv[0], kv[1] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(b, n, c) + x = self.proj(x) + return self.proj_drop(x) + + +class Block(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + sr_ratio=1, + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + sr_ratio=sr_ratio, + ) + # NOTE: drop path for stochastic depth, we shall see if this is better + # than dropout here + self.drop_path = ( + DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + ) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, h, w): + x = x + self.drop_path(self.attn(self.norm1(x), h, w)) + return x + self.drop_path(self.mlp(self.norm2(x), h, w)) + + +class OverlapPatchEmbed(nn.Module): + """Image to Patch Embedding.""" + + def __init__( + self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768 + ): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + + self.img_size = img_size + self.patch_size = patch_size + self.h, self.w = ( + img_size[0] // patch_size[0], + img_size[1] // patch_size[1], + ) + self.num_patches = self.h * self.w + self.proj = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=patch_size, + stride=stride, + padding=(patch_size[0] // 2, patch_size[1] // 2), + ) + self.norm = nn.LayerNorm(embed_dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x): + x = self.proj(x) + _, _, h, w = x.shape + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + + return x, h, w + + +class SegFormerSegmentationHead(nn.Module): + def __init__(self, channels: int, num_classes: int, num_features: int = 4): + super().__init__() + self.fuse = nn.Sequential( + nn.Conv2d( + channels * num_features, channels, kernel_size=1, bias=False + ), + nn.ReLU(), + nn.BatchNorm2d(channels), + ) + self.predict = nn.Conv2d(channels, num_classes, kernel_size=1) + + def forward(self, features): + x = torch.cat(features, dim=1) + x = self.fuse(x) + return self.predict(x) + + +class MixVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + embed_dims=None, + num_heads=None, + mlp_ratios=None, + qkv_bias=False, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + norm_layer=nn.LayerNorm, + depths=None, + sr_ratios=None, + ): + if sr_ratios is None: + sr_ratios = [8, 4, 2, 1] + if depths is None: + depths = [3, 4, 6, 3] + if mlp_ratios is None: + mlp_ratios = [4, 4, 4, 4] + if num_heads is None: + num_heads = [1, 2, 4, 8] + if embed_dims is None: + embed_dims = [64, 128, 256, 512] + super().__init__() + self.num_classes = num_classes + self.depths = depths + + # patch_embed + self.patch_embed1 = OverlapPatchEmbed( + img_size=img_size, + patch_size=7, + stride=4, + in_chans=in_chans, + embed_dim=embed_dims[0], + ) + self.patch_embed2 = OverlapPatchEmbed( + img_size=img_size // 4, + patch_size=3, + stride=2, + in_chans=embed_dims[0], + embed_dim=embed_dims[1], + ) + self.patch_embed3 = OverlapPatchEmbed( + img_size=img_size // 8, + patch_size=3, + stride=2, + in_chans=embed_dims[1], + embed_dim=embed_dims[2], + ) + self.patch_embed4 = OverlapPatchEmbed( + img_size=img_size // 16, + patch_size=3, + stride=2, + in_chans=embed_dims[2], + embed_dim=embed_dims[3], + ) + + # transformer encoder + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule + cur = 0 + self.block1 = nn.ModuleList( + [ + Block( + dim=embed_dims[0], + num_heads=num_heads[0], + mlp_ratio=mlp_ratios[0], + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[cur + i], + norm_layer=norm_layer, + sr_ratio=sr_ratios[0], + ) + for i in range(depths[0]) + ] + ) + self.norm1 = norm_layer(embed_dims[0]) + + cur += depths[0] + self.block2 = nn.ModuleList( + [ + Block( + dim=embed_dims[1], + num_heads=num_heads[1], + mlp_ratio=mlp_ratios[1], + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[cur + i], + norm_layer=norm_layer, + sr_ratio=sr_ratios[1], + ) + for i in range(depths[1]) + ] + ) + self.norm2 = norm_layer(embed_dims[1]) + + cur += depths[1] + self.block3 = nn.ModuleList( + [ + Block( + dim=embed_dims[2], + num_heads=num_heads[2], + mlp_ratio=mlp_ratios[2], + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[cur + i], + norm_layer=norm_layer, + sr_ratio=sr_ratios[2], + ) + for i in range(depths[2]) + ] + ) + self.norm3 = norm_layer(embed_dims[2]) + + cur += depths[2] + self.block4 = nn.ModuleList( + [ + Block( + dim=embed_dims[3], + num_heads=num_heads[3], + mlp_ratio=mlp_ratios[3], + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[cur + i], + norm_layer=norm_layer, + sr_ratio=sr_ratios[3], + ) + for i in range(depths[3]) + ] + ) + self.norm4 = norm_layer(embed_dims[3]) + + # classification head + # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 + # else nn.Identity() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def reset_drop_path(self, drop_path_rate): + dpr = [ + x.item() + for x in torch.linspace(0, drop_path_rate, sum(self.depths)) + ] + cur = 0 + for i in range(self.depths[0]): + self.block1[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[0] + for i in range(self.depths[1]): + self.block2[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[1] + for i in range(self.depths[2]): + self.block3[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[2] + for i in range(self.depths[3]): + self.block4[i].drop_path.drop_prob = dpr[cur + i] + + def freeze_patch_emb(self): + self.patch_embed1.requires_grad = False + + @torch.jit.ignore + def no_weight_decay(self): + return { + "pos_embed1", + "pos_embed2", + "pos_embed3", + "pos_embed4", + "cls_token", + } # has pos_embed may be better + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=""): + self.num_classes = num_classes + self.head = ( + nn.Linear(self.embed_dim, num_classes) + if num_classes > 0 + else nn.Identity() + ) + + def forward_features(self, x): + b = x.shape[0] + outs = [] + + # stage 1 + x, h, w = self.patch_embed1(x) + for _i, blk in enumerate(self.block1): + x = blk(x, h, w) + x = self.norm1(x) + x = x.reshape(b, h, w, -1).permute(0, 3, 1, 2).contiguous() + outs.append(x) + + # stage 2 + x, h, w = self.patch_embed2(x) + for _i, blk in enumerate(self.block2): + x = blk(x, h, w) + x = self.norm2(x) + x = x.reshape(b, h, w, -1).permute(0, 3, 1, 2).contiguous() + outs.append(x) + + # stage 3 + x, h, w = self.patch_embed3(x) + for _i, blk in enumerate(self.block3): + x = blk(x, h, w) + x = self.norm3(x) + x = x.reshape(b, h, w, -1).permute(0, 3, 1, 2).contiguous() + outs.append(x) + + # stage 4 + x, h, w = self.patch_embed4(x) + for _i, blk in enumerate(self.block4): + x = blk(x, h, w) + x = self.norm4(x) + x = x.reshape(b, h, w, -1).permute(0, 3, 1, 2).contiguous() + outs.append(x) + + return outs + + def forward(self, x): + return self.forward_features(x) + # x = self.head(x) + + +class MitB0(MixVisionTransformer): + def __init__(self, **kwargs): + super().__init__( + patch_size=4, + embed_dims=[32, 64, 160, 256], + num_heads=[1, 2, 5, 8], + mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=[2, 2, 2, 2], + sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, + drop_path_rate=0.1, + ) + + +class MitB1(MixVisionTransformer): + def __init__(self, **kwargs): + super().__init__( + patch_size=4, + embed_dims=[64, 128, 320, 512], + num_heads=[1, 2, 5, 8], + mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=[2, 2, 2, 2], + sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, + drop_path_rate=0.1, + ) + + +class MitB2(MixVisionTransformer): + def __init__(self, **kwargs): + super().__init__( + patch_size=4, + embed_dims=[64, 128, 320, 512], + num_heads=[1, 2, 5, 8], + mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=[3, 4, 6, 3], + sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, + drop_path_rate=0.1, + ) + + +class MitB3(MixVisionTransformer): + def __init__(self, **kwargs): + super().__init__( + patch_size=4, + embed_dims=[64, 128, 320, 512], + num_heads=[1, 2, 5, 8], + mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=[3, 4, 18, 3], + sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, + drop_path_rate=0.1, + ) + + +class MitB4(MixVisionTransformer): + def __init__(self, **kwargs): + super().__init__( + patch_size=4, + embed_dims=[64, 128, 320, 512], + num_heads=[1, 2, 5, 8], + mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=[3, 8, 27, 3], + sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, + drop_path_rate=0.1, + ) + + +class MitB5(MixVisionTransformer): + def __init__(self, **kwargs): + super().__init__( + patch_size=4, + embed_dims=[64, 128, 320, 512], + num_heads=[1, 2, 5, 8], + mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=[3, 6, 40, 3], + sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, + drop_path_rate=0.1, + ) + + +class MLPHead(nn.Module): + """Linear Embedding.""" + + def __init__(self, input_dim=2048, embed_dim=768): + super().__init__() + self.proj = nn.Linear(input_dim, embed_dim) + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) + return self.proj(x) + + +def resize( + inputs, + size=None, + scale_factor=None, + mode="nearest", + align_corners=None, + warning=True, +): + if warning and size is not None and align_corners: + input_h, input_w = tuple(int(x) for x in inputs.shape[2:]) + output_h, output_w = tuple(int(x) for x in size) + if (output_h > input_h or output_w > output_h) and ( + (output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1) + and (output_h - 1) % (input_h - 1) + and (output_w - 1) % (input_w - 1) + ): + warnings.warn( + f"When align_corners={align_corners}, " + "the output would more aligned if " + f"input size {(input_h, input_w)} is `x+1` and " + f"out size {(output_h, output_w)} is `nx+1`", + stacklevel=2, + ) + if isinstance(size, torch.Size): + size = tuple(int(x) for x in size) + return F.interpolate(inputs, size, scale_factor, mode, align_corners) + + +class SegFormerHead(nn.Module): + """SegFormer: Simple and Efficient Design for Semantic Segmentation with + Transformers. + """ + + def __init__( + self, + in_channels, + feature_strides, + decoder_params, + num_classes, + dropout_ratio=0.1, + ): + super().__init__() + self.in_channels = in_channels + assert len(feature_strides) == len(self.in_channels) + assert min(feature_strides) == feature_strides[0] + self.feature_strides = feature_strides + self.num_classes = num_classes + # self.in_index = [0, 1, 2, 3], + + ( + c1_in_channels, + c2_in_channels, + c3_in_channels, + c4_in_channels, + ) = self.in_channels + + embedding_dim = decoder_params["embed_dim"] + + self.linear_c4 = MLPHead( + input_dim=c4_in_channels, embed_dim=embedding_dim + ) + self.linear_c3 = MLPHead( + input_dim=c3_in_channels, embed_dim=embedding_dim + ) + self.linear_c2 = MLPHead( + input_dim=c2_in_channels, embed_dim=embedding_dim + ) + self.linear_c1 = MLPHead( + input_dim=c1_in_channels, embed_dim=embedding_dim + ) + + self.fuse = nn.Sequential( + nn.Conv2d( + embedding_dim * 4, embedding_dim, kernel_size=1, bias=False + ), + nn.ReLU(), + nn.BatchNorm2d(embedding_dim), + ) + + self.linear_pred = nn.Conv2d( + embedding_dim, self.num_classes, kernel_size=1 + ) + + if dropout_ratio > 0: + self.dropout = nn.Dropout2d(dropout_ratio) + else: + self.dropout = None + + def forward(self, inputs): + # x = [inputs[i] for i in self.in_index] # len=4, 1/4,1/8,1/16,1/32 + c1, c2, c3, c4 = inputs[0], inputs[1], inputs[2], inputs[3] + + n, _, h, w = c4.shape + + _c4 = ( + self.linear_c4(c4) + .permute(0, 2, 1) + .reshape(n, -1, c4.shape[2], c4.shape[3]) + ) + _c4 = resize( + _c4, size=c1.size()[2:], mode="bilinear", align_corners=False + ) + + _c3 = ( + self.linear_c3(c3) + .permute(0, 2, 1) + .reshape(n, -1, c3.shape[2], c3.shape[3]) + ) + _c3 = resize( + _c3, size=c1.size()[2:], mode="bilinear", align_corners=False + ) + + _c2 = ( + self.linear_c2(c2) + .permute(0, 2, 1) + .reshape(n, -1, c2.shape[2], c2.shape[3]) + ) + _c2 = resize( + _c2, size=c1.size()[2:], mode="bilinear", align_corners=False + ) + + _c1 = ( + self.linear_c1(c1) + .permute(0, 2, 1) + .reshape(n, -1, c1.shape[2], c1.shape[3]) + ) + + _c = self.fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1)) + + x = self.dropout(_c) + return self.linear_pred(x) + + +class _SegFormer(nn.Module): + def __init__( + self, + in_channels, + feature_strides, + decoder_params, + num_classes, + dropout_ratio, + mit: nn.Module, + ): + super().__init__() + + self.encoder = mit() + self.head = SegFormerHead( + in_channels, + feature_strides, + decoder_params, + num_classes, + dropout_ratio, + ) + + def forward(self, x): + features = self.encoder(x) + return self.head(features) + + +def segformer_b0(num_classes: int): + return _SegFormer( + in_channels=[32, 64, 160, 256], + feature_strides=[4, 8, 16, 32], + decoder_params={"embed_dim": 256}, + num_classes=num_classes, + dropout_ratio=0.1, + mit=MitB0, + ) + + +def segformer_b1(num_classes: int): + return _SegFormer( + in_channels=[64, 128, 320, 512], + feature_strides=[4, 8, 16, 32], + decoder_params={"embed_dim": 512}, + num_classes=num_classes, + dropout_ratio=0.1, + mit=MitB1, + ) + + +def segformer_b2(num_classes: int): + return _SegFormer( + in_channels=[64, 128, 320, 512], + feature_strides=[4, 8, 16, 32], + decoder_params={"embed_dim": 512}, + num_classes=num_classes, + dropout_ratio=0.1, + mit=MitB2, + ) + + +def segformer_b3(num_classes: int): + return _SegFormer( + in_channels=[64, 128, 320, 512], + feature_strides=[4, 8, 16, 32], + decoder_params={"embed_dim": 512}, + num_classes=num_classes, + dropout_ratio=0.1, + mit=MitB3, + ) + + +def segformer_b4(num_classes: int): + return _SegFormer( + in_channels=[64, 128, 320, 512], + feature_strides=[4, 8, 16, 32], + decoder_params={"embed_dim": 512}, + num_classes=num_classes, + dropout_ratio=0.1, + mit=MitB4, + ) + + +def segformer_b5(num_classes: int): + return _SegFormer( + in_channels=[64, 128, 320, 512], + feature_strides=[4, 8, 16, 32], + decoder_params={"embed_dim": 512}, + num_classes=num_classes, + dropout_ratio=0.1, + mit=MitB5, + ) + + +if __name__ == "__main__": + x = torch.randn((1, 3, 224, 224)) + model = segformer_b0() + print(model(x).size()) # torch.Size([1, 19, 56, 56]) From 094cc371fa3923a08e6d42c526a6d746705828f1 Mon Sep 17 00:00:00 2001 From: alafage Date: Sat, 16 Mar 2024 01:53:57 +0100 Subject: [PATCH 033/148] :sparkles: Running SegFormerB0 on CamVid - init segmentation experiments --- .../camvid/configs/segformer.yaml | 23 +++++++++++++++++ experiments/segmentation/camvid/segformer.py | 25 +++++++++++++++++++ experiments/segmentation/readme.md | 1 + 3 files changed, 49 insertions(+) create mode 100644 experiments/segmentation/camvid/configs/segformer.yaml create mode 100644 experiments/segmentation/camvid/segformer.py create mode 100644 experiments/segmentation/readme.md diff --git a/experiments/segmentation/camvid/configs/segformer.yaml b/experiments/segmentation/camvid/configs/segformer.yaml new file mode 100644 index 00000000..16767e87 --- /dev/null +++ b/experiments/segmentation/camvid/configs/segformer.yaml @@ -0,0 +1,23 @@ +# lightning.pytorch==2.1.3 +eval_after_fit: true +seed_everything: true +trainer: + accelerator: gpu + devices: 1 + precision: bf16-mixed +model: + num_classes: 13 + loss: torch.nn.CrossEntropyLoss + version: std + arch: 0 + num_estimators: 1 +data: + root: ./data + batch_size: 16 + num_workers: 20 +optimizer: + lr: 0.01 +lr_scheduler: + milestones: + - 30 + gamma: 0.1 diff --git a/experiments/segmentation/camvid/segformer.py b/experiments/segmentation/camvid/segformer.py new file mode 100644 index 00000000..b38f0d32 --- /dev/null +++ b/experiments/segmentation/camvid/segformer.py @@ -0,0 +1,25 @@ +import torch +from lightning.pytorch.cli import LightningArgumentParser +from lightning.pytorch.loggers import TensorBoardLogger # noqa: F401 + +from torch_uncertainty.baselines.segmentation import SegFormer +from torch_uncertainty.datamodules import CamVidDataModule +from torch_uncertainty.utils import TULightningCLI + + +class SegFormerCLI(TULightningCLI): + def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + parser.add_optimizer_args(torch.optim.SGD) + parser.add_lr_scheduler_args(torch.optim.lr_scheduler.MultiStepLR) + + +def cli_main() -> SegFormerCLI: + return SegFormerCLI(SegFormer, CamVidDataModule) + + +if __name__ == "__main__": + torch.set_float32_matmul_precision("medium") + cli = cli_main() + + if cli.subcommand == "fit" and cli._get(cli.config, "eval_after_fit"): + cli.trainer.test(datamodule=cli.datamodule, ckpt_path="best") diff --git a/experiments/segmentation/readme.md b/experiments/segmentation/readme.md new file mode 100644 index 00000000..e8ef0698 --- /dev/null +++ b/experiments/segmentation/readme.md @@ -0,0 +1 @@ +# Segmentation Benchmarks From 5b5e43360a82629de4a49f9ae029494162ca01c7 Mon Sep 17 00:00:00 2001 From: alafage Date: Sat, 16 Mar 2024 15:21:25 +0100 Subject: [PATCH 034/148] :bug: ``stage is None`` now sets up the test set for Classification DataModules --- torch_uncertainty/datamodules/classification/cifar10.py | 2 +- torch_uncertainty/datamodules/classification/cifar100.py | 2 +- torch_uncertainty/datamodules/classification/imagenet.py | 2 +- torch_uncertainty/datamodules/classification/mnist.py | 2 +- torch_uncertainty/datamodules/classification/tiny_imagenet.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/torch_uncertainty/datamodules/classification/cifar10.py b/torch_uncertainty/datamodules/classification/cifar10.py index 01e50dfc..73efc2ae 100644 --- a/torch_uncertainty/datamodules/classification/cifar10.py +++ b/torch_uncertainty/datamodules/classification/cifar10.py @@ -163,7 +163,7 @@ def setup(self, stage: Literal["fit", "test"] | None = None) -> None: download=False, transform=self.test_transform, ) - elif stage == "test": + if stage == "test" or stage is None: if self.test_alt is None: self.test = self.dataset( self.root, diff --git a/torch_uncertainty/datamodules/classification/cifar100.py b/torch_uncertainty/datamodules/classification/cifar100.py index a03b48bc..3f75dde7 100644 --- a/torch_uncertainty/datamodules/classification/cifar100.py +++ b/torch_uncertainty/datamodules/classification/cifar100.py @@ -163,7 +163,7 @@ def setup(self, stage: Literal["fit", "test"] | None = None) -> None: download=False, transform=self.test_transform, ) - elif stage == "test": + if stage == "test" or stage is None: if self.test_alt is None: self.test = self.dataset( self.root, diff --git a/torch_uncertainty/datamodules/classification/imagenet.py b/torch_uncertainty/datamodules/classification/imagenet.py index dc6775f2..9f19eec5 100644 --- a/torch_uncertainty/datamodules/classification/imagenet.py +++ b/torch_uncertainty/datamodules/classification/imagenet.py @@ -221,7 +221,7 @@ def setup(self, stage: Literal["fit", "test"] | None = None) -> None: split="val", transform=self.test_transform, ) - elif stage == "test": + if stage == "test" or stage is None: self.test = self.dataset( self.root, split="val", diff --git a/torch_uncertainty/datamodules/classification/mnist.py b/torch_uncertainty/datamodules/classification/mnist.py index 24f15a9f..e4bd6107 100644 --- a/torch_uncertainty/datamodules/classification/mnist.py +++ b/torch_uncertainty/datamodules/classification/mnist.py @@ -127,7 +127,7 @@ def setup(self, stage: Literal["fit", "test"] | None = None) -> None: download=False, transform=self.test_transform, ) - elif stage == "test": + if stage == "test" or stage is None: self.test = self.dataset( self.root, train=False, diff --git a/torch_uncertainty/datamodules/classification/tiny_imagenet.py b/torch_uncertainty/datamodules/classification/tiny_imagenet.py index 91929005..f323fc31 100644 --- a/torch_uncertainty/datamodules/classification/tiny_imagenet.py +++ b/torch_uncertainty/datamodules/classification/tiny_imagenet.py @@ -141,7 +141,7 @@ def setup(self, stage: Literal["fit", "test"] | None = None) -> None: split="val", transform=self.test_transform, ) - elif stage == "test": + if stage == "test" or stage is None: self.test = self.dataset( self.root, split="val", From bbbcb7c8fcace7d0e285e1bd9c0e58a84c5f8385 Mon Sep 17 00:00:00 2001 From: Olivier Date: Sat, 16 Mar 2024 16:19:23 +0100 Subject: [PATCH 035/148] :sparkles: Enable ensembles in regression routine --- torch_uncertainty/routines/regression.py | 74 +++++++++++++++++++----- 1 file changed, 60 insertions(+), 14 deletions(-) diff --git a/torch_uncertainty/routines/regression.py b/torch_uncertainty/routines/regression.py index 6a5fe8da..11d869fd 100644 --- a/torch_uncertainty/routines/regression.py +++ b/torch_uncertainty/routines/regression.py @@ -1,38 +1,64 @@ +import torch from lightning.pytorch import LightningModule from lightning.pytorch.utilities.types import STEP_OUTPUT from torch import Tensor, nn -from torchmetrics import MeanSquaredError, MetricCollection +from torch.distributions import ( + Categorical, + Independent, + MixtureSameFamily, +) +from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection from torch_uncertainty.metrics.nll import DistributionNLL +from torch_uncertainty.utils.distributions import to_ens_dist class RegressionRoutine(LightningModule): def __init__( self, - num_features: int, + num_outputs: int, model: nn.Module, loss: type[nn.Module], - num_estimators: int, + num_estimators: int = 1, format_batch_fn: nn.Module | None = None, ) -> None: super().__init__() self.model = model self.loss = loss + + if format_batch_fn is None: + format_batch_fn = nn.Identity() + self.format_batch_fn = format_batch_fn reg_metrics = MetricCollection( { - "mse": MeanSquaredError(squared=True), - "nll": DistributionNLL(), + "mae": MeanAbsoluteError(), + "mse": MeanSquaredError(squared=False), + "nll": DistributionNLL(reduction="mean"), }, compute_groups=False, ) self.val_metrics = reg_metrics.clone(prefix="reg_val/") self.test_metrics = reg_metrics.clone(prefix="reg_test/") + if num_estimators < 1: + raise ValueError( + f"num_estimators must be positive, got {num_estimators}." + ) self.num_estimators = num_estimators + if num_outputs < 1: + raise ValueError( + f"num_outputs must be positive, got {num_outputs}." + ) + self.num_outputs = num_outputs + + self.one_dim_regression = False + if num_outputs == 1: + self.one_dim_regression = True + def on_train_start(self) -> None: # hyperparameters for performances init_metrics = {k: 0 for k in self.val_metrics} @@ -54,14 +80,14 @@ def criterion(self) -> nn.Module: def training_step( self, batch: tuple[Tensor, Tensor], batch_idx: int ) -> STEP_OUTPUT: - inputs, targets = batch + inputs, targets = self.format_batch_fn(batch) - # eventual input repeat is done in the model - targets = targets.repeat((self.num_estimators, 1)) + dists = self.forward(inputs) - logits = self.forward(inputs) + if self.one_dim_regression: + targets = targets.unsqueeze(-1) - loss = self.criterion(logits, targets) + loss = self.criterion(dists, targets) self.log("train_loss", loss) return loss @@ -73,8 +99,18 @@ def validation_step( dists = self.forward(inputs) - self.val_metrics.mse.update(dists.loc, targets) - self.val_metrics.nll.update(dists, targets) + ens_dist = Independent( + to_ens_dist(dists, num_estimators=self.num_estimators), 1 + ) + mix = Categorical(torch.ones(self.num_estimators, device=self.device)) + mixture = MixtureSameFamily(mix, ens_dist) + + if self.one_dim_regression: + targets = targets.unsqueeze(-1) + + self.val_metrics.mse.update(mixture.mean, targets) + self.val_metrics.mae.update(mixture.mean, targets) + self.val_metrics.nll.update(mixture, targets) def on_validation_epoch_end(self) -> None: self.log_dict(self.val_metrics.compute()) @@ -94,9 +130,19 @@ def test_step( inputs, targets = batch dists = self.forward(inputs) + ens_dist = Independent( + to_ens_dist(dists, num_estimators=self.num_estimators), 1 + ) + + mix = Categorical(torch.ones(self.num_estimators, device=self.device)) + mixture = MixtureSameFamily(mix, ens_dist) + + if self.one_dim_regression: + targets = targets.unsqueeze(-1) - self.test_metrics.mse.update(dists.loc, targets) - self.test_metrics.nll.update(dists, targets) + self.test_metrics.mae.update(mixture.mean, targets) + self.test_metrics.mse.update(mixture.mean, targets) + self.test_metrics.nll.update(mixture, targets) def on_test_epoch_end(self) -> None: self.log_dict( From 1eadeb8af463ff741d16cf6b026bfeb7e3f594c0 Mon Sep 17 00:00:00 2001 From: Olivier Date: Sat, 16 Mar 2024 16:19:42 +0100 Subject: [PATCH 036/148] :sparkles: Add distribution utils --- torch_uncertainty/layers/distributions.py | 29 ++++++-------- torch_uncertainty/utils/distributions.py | 48 +++++++++++++++++++++++ 2 files changed, 60 insertions(+), 17 deletions(-) create mode 100644 torch_uncertainty/utils/distributions.py diff --git a/torch_uncertainty/layers/distributions.py b/torch_uncertainty/layers/distributions.py index 74e782bd..f7514c69 100644 --- a/torch_uncertainty/layers/distributions.py +++ b/torch_uncertainty/layers/distributions.py @@ -1,5 +1,6 @@ import torch.nn.functional as F -from torch import Tensor, distributions, nn +from torch import Tensor, nn +from torch.distributions import Distribution, Laplace, Normal class AbstractDistLayer(nn.Module): @@ -9,53 +10,47 @@ def __init__(self, dim: int) -> None: raise ValueError(f"dim must be positive, got {dim}.") self.dim = dim - def forward(self, x: Tensor) -> distributions.Distribution: + def forward(self, x: Tensor) -> Distribution: raise NotImplementedError -class IndptNormalDistLayer(AbstractDistLayer): +class IndptNormalLayer(AbstractDistLayer): def __init__(self, dim: int, min_scale: float = 1e-3) -> None: super().__init__(dim) if min_scale <= 0: raise ValueError(f"min_scale must be positive, got {min_scale}.") self.min_scale = min_scale - def forward(self, x: Tensor) -> distributions.Normal: + def forward(self, x: Tensor) -> Normal: """Forward pass of the independent normal distribution layer. Args: x (Tensor): The input tensor of shape (dx2). Returns: - distributions.Normal: The independent normal distribution. + Normal: The independent normal distribution. """ loc = x[:, : self.dim] scale = F.softplus(x[:, self.dim :]) + self.min_scale - if self.dim == 1: - loc = loc.squeeze(1) - scale = scale.squeeze(1) - return distributions.Normal(loc, scale) + return Normal(loc, scale) -class IndptLaplaceDistLayer(AbstractDistLayer): +class IndptLaplaceLayer(AbstractDistLayer): def __init__(self, dim: int, min_scale: float = 1e-3) -> None: super().__init__(dim) if min_scale <= 0: raise ValueError(f"min_scale must be positive, got {min_scale}.") self.min_scale = min_scale - def forward(self, x: Tensor) -> distributions.Laplace: - """Forward pass of the independent normal distribution layer. + def forward(self, x: Tensor) -> Laplace: + """Forward pass of the independent Laplace distribution layer. Args: x (Tensor): The input tensor of shape (dx2). Returns: - distributions.Laplace: The independent Laplace distribution. + Laplace: The independent Laplace distribution. """ loc = x[:, : self.dim] scale = F.softplus(x[:, self.dim :]) + self.min_scale - if self.dim == 1: - loc = loc.squeeze(1) - scale = scale.squeeze(1) - return distributions.Laplace(loc, scale) + return Laplace(loc, scale) diff --git a/torch_uncertainty/utils/distributions.py b/torch_uncertainty/utils/distributions.py new file mode 100644 index 00000000..be68dee2 --- /dev/null +++ b/torch_uncertainty/utils/distributions.py @@ -0,0 +1,48 @@ +import torch +from einops import rearrange +from torch.distributions import Distribution, Laplace, Normal + + +def cat_dist(distributions: list[Distribution]) -> Distribution: + r"""Concatenate a list of distributions into a single distribution. + + Args: + distributions (list[Distribution]): The list of distributions. + + Returns: + Distribution: The concatenated distributions. + """ + dist_type = type(distributions[0]) + if not all( + isinstance(distribution, dist_type) for distribution in distributions + ): + raise ValueError("All distributions must have the same type.") + + if isinstance(distributions[0], Normal | Laplace): + locs = torch.cat( + [distribution.loc for distribution in distributions], dim=0 + ) + scales = torch.cat( + [distribution.scale for distribution in distributions], dim=0 + ) + return dist_type(loc=locs, scale=scales) + raise NotImplementedError( + f"Concatenation of {dist_type} distributions is not supported." + "Raise an issue if needed." + ) + + +def to_ens_dist( + distribution: Distribution, num_estimators: int = 1 +) -> Distribution: + dist_type = type(distribution) + if isinstance(distribution, Normal | Laplace): + loc = rearrange(distribution.loc, "(n b) c -> b n c", n=num_estimators) + scale = rearrange( + distribution.scale, "(n b) c -> b n c", n=num_estimators + ) + return dist_type(loc=loc, scale=scale) + raise NotImplementedError( + f"Ensemble distribution of {dist_type} is not supported." + "Raise an issue if needed." + ) From 0ba3ad93058a7285fa1e204e1a6158f7ddd0ca5d Mon Sep 17 00:00:00 2001 From: Olivier Date: Sat, 16 Mar 2024 16:20:13 +0100 Subject: [PATCH 037/148] hammer: Update MLP and Deep Ensembles for distribution regression --- torch_uncertainty/baselines/regression/mlp.py | 17 +++--- torch_uncertainty/models/deep_ensembles.py | 39 +++++++++++++- torch_uncertainty/models/mlp.py | 53 ++++++++++++++----- 3 files changed, 87 insertions(+), 22 deletions(-) diff --git a/torch_uncertainty/baselines/regression/mlp.py b/torch_uncertainty/baselines/regression/mlp.py index 16da62bb..25d94e39 100644 --- a/torch_uncertainty/baselines/regression/mlp.py +++ b/torch_uncertainty/baselines/regression/mlp.py @@ -6,6 +6,7 @@ from torch_uncertainty.routines.regression import ( RegressionRoutine, ) +from torch_uncertainty.transforms.batch import RepeatTarget class MLP(RegressionRoutine): @@ -21,35 +22,39 @@ def __init__( version: Literal["std", "packed"], hidden_dims: list[int], num_estimators: int | None = 1, + dropout_rate: float = 0.0, alpha: float | None = None, gamma: int = 1, - **kwargs, ) -> None: r"""MLP baseline for regression providing support for various versions.""" params = { + "dropout_rate": dropout_rate, "in_features": in_features, "num_outputs": num_outputs, "hidden_dims": hidden_dims, } + format_batch_fn = nn.Identity() + + if version not in self.versions: + raise ValueError(f"Unknown version: {version}") + if version == "packed": params |= { "alpha": alpha, "num_estimators": num_estimators, "gamma": gamma, } - - if version not in self.versions: - raise ValueError(f"Unknown version: {version}") + format_batch_fn = RepeatTarget(num_repeats=num_estimators) model = self.versions[version](**params) # version in self.versions: super().__init__( + num_outputs=num_outputs, model=model, loss=loss, num_estimators=num_estimators, - dist_estimation=num_outputs, - mode="mean", + format_batch_fn=format_batch_fn, ) self.save_hyperparameters() diff --git a/torch_uncertainty/models/deep_ensembles.py b/torch_uncertainty/models/deep_ensembles.py index d82152df..1b8dc88f 100644 --- a/torch_uncertainty/models/deep_ensembles.py +++ b/torch_uncertainty/models/deep_ensembles.py @@ -1,10 +1,14 @@ import copy +from typing import Literal import torch from torch import nn +from torch.distributions import Distribution +from torch_uncertainty.utils.distributions import cat_dist -class _DeepEnsembles(nn.Module): + +class _ClsDeepEnsembles(nn.Module): def __init__( self, models: list[nn.Module], @@ -29,15 +33,40 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.cat([model.forward(x) for model in self.models], dim=0) +class _RegDeepEnsembles(nn.Module): + def __init__( + self, + models: list[nn.Module], + ) -> None: + """Create a deep ensembles from a list of models.""" + super().__init__() + + self.models = nn.ModuleList(models) + self.num_estimators = len(models) + + def forward(self, x: torch.Tensor) -> Distribution: + r"""Return the logits of the ensemble. + + Args: + x (Tensor): The input of the model. + + Returns: + Distribution: + """ + return cat_dist([model.forward(x) for model in self.models]) + + def deep_ensembles( models: list[nn.Module] | nn.Module, num_estimators: int | None = None, + task: Literal["classification", "regression"] = "classification", ) -> nn.Module: """Build a Deep Ensembles out of the original models. Args: models (list[nn.Module] | nn.Module): The model to be ensembled. num_estimators (int | None): The number of estimators in the ensemble. + task (Literal["classification", "regression"]): The model task. Returns: nn.Module: The ensembled model. @@ -82,4 +111,10 @@ def deep_ensembles( "num_estimators must be None if you provided a non-singleton list." ) - return _DeepEnsembles(models=models) + if task == "classification": + return _ClsDeepEnsembles(models=models) + if task == "regression": + return _RegDeepEnsembles(models=models) + raise ValueError( + f"task must be either 'classification' or 'regression'. Got {task}." + ) diff --git a/torch_uncertainty/models/mlp.py b/torch_uncertainty/models/mlp.py index 8e0cefbc..c6e82ca3 100644 --- a/torch_uncertainty/models/mlp.py +++ b/torch_uncertainty/models/mlp.py @@ -19,7 +19,9 @@ def __init__( layer: type[nn.Module], activation: Callable, layer_args: dict, - dropout: float, + final_layer: nn.Module, + final_layer_args: dict, + dropout_rate: float, ) -> None: """Multi-layer perceptron class. @@ -30,11 +32,13 @@ def __init__( layer (nn.Module): Layer class. activation (Callable): Activation function. layer_args (Dict): Arguments for the layer class. - dropout (float): Dropout probability. + final_layer (nn.Module): Final layer class for distribution regression. + final_layer_args (Dict): Arguments for the final layer class. + dropout_rate (float): Dropout probability. """ super().__init__() self.activation = activation - self.dropout = dropout + self.dropout_rate = dropout_rate layers = nn.ModuleList() @@ -70,12 +74,12 @@ def __init__( ) else: layers.append(layer(hidden_dims[-1], num_outputs, **layer_args)) - self.layers = layers + self.final_layer = final_layer(**final_layer_args) def forward(self, x: Tensor) -> Tensor: for layer in self.layers[:-1]: - x = F.dropout(layer(x), p=self.dropout, training=self.training) + x = F.dropout(layer(x), p=self.dropout_rate, training=self.training) x = self.activation(x) return self.layers[-1](x) @@ -93,10 +97,14 @@ def _mlp( layer_args: dict | None = None, layer: type[nn.Module] = nn.Linear, activation: Callable = F.relu, - dropout: float = 0.0, + final_layer: nn.Module = nn.Identity, + final_layer_args: dict | None = None, + dropout_rate: float = 0.0, ) -> _MLP | _StochasticMLP: if layer_args is None: layer_args = {} + if final_layer_args is None: + final_layer_args = {} model = _MLP if not stochastic else _StochasticMLP return model( in_features=in_features, @@ -105,7 +113,9 @@ def _mlp( layer_args=layer_args, layer=layer, activation=activation, - dropout=dropout, + final_layer=final_layer, + final_layer_args=final_layer_args, + dropout_rate=dropout_rate, ) @@ -115,7 +125,9 @@ def mlp( hidden_dims: list[int], layer: type[nn.Module] = nn.Linear, activation: Callable = F.relu, - dropout: float = 0.0, + final_layer: nn.Module = nn.Identity, + final_layer_args: dict | None = None, + dropout_rate: float = 0.0, ) -> _MLP: """Multi-layer perceptron. @@ -126,7 +138,10 @@ def mlp( layer (nn.Module, optional): Layer type. Defaults to nn.Linear. activation (Callable, optional): Activation function. Defaults to F.relu. - dropout (float, optional): Dropout probability. Defaults to 0.0. + final_layer (nn.Module, optional): Final layer class for distribution + regression. Defaults to nn.Identity. + final_layer_args (Dict, optional): Arguments for the final layer class. + dropout_rate (float, optional): Dropout probability. Defaults to 0.0. Returns: _MLP: A Multi-Layer-Perceptron model. @@ -138,7 +153,9 @@ def mlp( hidden_dims=hidden_dims, layer=layer, activation=activation, - dropout=dropout, + final_layer=final_layer, + final_layer_args=final_layer_args, + dropout_rate=dropout_rate, ) @@ -150,7 +167,9 @@ def packed_mlp( alpha: float = 2, gamma: float = 1, activation: Callable = F.relu, - dropout: float = 0.0, + final_layer: nn.Module = nn.Identity, + final_layer_args: dict | None = None, + dropout_rate: float = 0.0, ) -> _MLP: layer_args = { "num_estimators": num_estimators, @@ -165,7 +184,9 @@ def packed_mlp( layer=PackedLinear, activation=activation, layer_args=layer_args, - dropout=dropout, + final_layer=final_layer, + final_layer_args=final_layer_args, + dropout_rate=dropout_rate, ) @@ -174,7 +195,9 @@ def bayesian_mlp( num_outputs: int, hidden_dims: list[int], activation: Callable = F.relu, - dropout: float = 0.0, + final_layer: nn.Module = nn.Identity, + final_layer_args: dict | None = None, + dropout_rate: float = 0.0, ) -> _StochasticMLP: return _mlp( stochastic=True, @@ -183,5 +206,7 @@ def bayesian_mlp( hidden_dims=hidden_dims, layer=BayesLinear, activation=activation, - dropout=dropout, + final_layer=final_layer, + final_layer_args=final_layer_args, + dropout_rate=dropout_rate, ) From 7892984167af7bf2c0c904f8edf9a953f5d91a56 Mon Sep 17 00:00:00 2001 From: Olivier Date: Sat, 16 Mar 2024 16:20:34 +0100 Subject: [PATCH 038/148] :heavy_check_mark: Update tests --- tests/baselines/test_packed.py | 5 ----- tests/baselines/test_standard.py | 2 -- 2 files changed, 7 deletions(-) diff --git a/tests/baselines/test_packed.py b/tests/baselines/test_packed.py index c74947d0..0404bd1b 100644 --- a/tests/baselines/test_packed.py +++ b/tests/baselines/test_packed.py @@ -5,9 +5,6 @@ from torch_uncertainty.baselines.classification import VGG, ResNet, WideResNet from torch_uncertainty.baselines.regression import MLP -from torch_uncertainty.optimization_procedures import ( - optim_cifar10_resnet18, -) class TestPackedBaseline: @@ -133,13 +130,11 @@ def test_packed(self): in_features=3, num_outputs=10, loss=nn.MSELoss, - optimization_procedure=optim_cifar10_resnet18, version="packed", hidden_dims=[1], num_estimators=2, alpha=2, gamma=1, - dist_estimation=1, ) summary(net) diff --git a/tests/baselines/test_standard.py b/tests/baselines/test_standard.py index 60d853cb..a76ee35f 100644 --- a/tests/baselines/test_standard.py +++ b/tests/baselines/test_standard.py @@ -111,7 +111,6 @@ def test_standard(self): loss=nn.MSELoss, version="std", hidden_dims=[1], - dist_estimation=1, ) summary(net) @@ -130,5 +129,4 @@ def test_errors(self): loss=nn.MSELoss, version="test", hidden_dims=[1], - dist_estimation=1, ) From e6d4059c0464d33b220da8fd6b049f17e6853e80 Mon Sep 17 00:00:00 2001 From: Olivier Date: Sat, 16 Mar 2024 16:26:09 +0100 Subject: [PATCH 039/148] :zap: Prepare version update --- docs/source/conf.py | 7 +++++-- pyproject.toml | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index c442fdbc..f86e98bd 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -10,9 +10,12 @@ # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information project = "TorchUncertainty" -copyright = f"{datetime.utcnow().year!s}, Adrien Lafage and Olivier Laurent" # noqa: A001 + +copyright = ( # noqa: A001 + f"{datetime.now(datetime.UTC).year!s}, Adrien Lafage and Olivier Laurent" +) author = "Adrien Lafage and Olivier Laurent" -release = "0.1.6" +release = "0.2.0" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/pyproject.toml b/pyproject.toml index 9f7ad239..6836a04e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi" [project] name = "torch_uncertainty" -version = "0.1.6" +version = "0.2.0" authors = [ { name = "ENSTA U2IS", email = "olivier.laurent@ensta-paris.fr" }, { name = "Adrien Lafage", email = "adrienlafage@outlook.com" }, From 810c11dbc2271ba8a42125c3b501a380d2a2e6b7 Mon Sep 17 00:00:00 2001 From: alafage Date: Sat, 16 Mar 2024 17:51:55 +0100 Subject: [PATCH 040/148] :bug: Fix ``setup()`` in Classification DataModules --- torch_uncertainty/datamodules/classification/cifar10.py | 2 +- torch_uncertainty/datamodules/classification/cifar100.py | 2 +- torch_uncertainty/datamodules/classification/imagenet.py | 2 +- torch_uncertainty/datamodules/classification/mnist.py | 2 +- torch_uncertainty/datamodules/classification/tiny_imagenet.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/torch_uncertainty/datamodules/classification/cifar10.py b/torch_uncertainty/datamodules/classification/cifar10.py index 73efc2ae..b6536e90 100644 --- a/torch_uncertainty/datamodules/classification/cifar10.py +++ b/torch_uncertainty/datamodules/classification/cifar10.py @@ -184,7 +184,7 @@ def setup(self, stage: Literal["fit", "test"] | None = None) -> None: download=False, transform=self.test_transform, ) - else: + if stage not in ["fit", "test", None]: raise ValueError(f"Stage {stage} is not supported.") def train_dataloader(self) -> DataLoader: diff --git a/torch_uncertainty/datamodules/classification/cifar100.py b/torch_uncertainty/datamodules/classification/cifar100.py index 3f75dde7..a6b68e56 100644 --- a/torch_uncertainty/datamodules/classification/cifar100.py +++ b/torch_uncertainty/datamodules/classification/cifar100.py @@ -184,7 +184,7 @@ def setup(self, stage: Literal["fit", "test"] | None = None) -> None: download=False, transform=self.test_transform, ) - else: + if stage not in ["fit", "test", None]: raise ValueError(f"Stage {stage} is not supported.") def train_dataloader(self) -> DataLoader: diff --git a/torch_uncertainty/datamodules/classification/imagenet.py b/torch_uncertainty/datamodules/classification/imagenet.py index 9f19eec5..20196b1a 100644 --- a/torch_uncertainty/datamodules/classification/imagenet.py +++ b/torch_uncertainty/datamodules/classification/imagenet.py @@ -227,7 +227,7 @@ def setup(self, stage: Literal["fit", "test"] | None = None) -> None: split="val", transform=self.test_transform, ) - else: + if stage not in ["fit", "test", None]: raise ValueError(f"Stage {stage} is not supported.") if self.eval_ood: diff --git a/torch_uncertainty/datamodules/classification/mnist.py b/torch_uncertainty/datamodules/classification/mnist.py index e4bd6107..cd4997c3 100644 --- a/torch_uncertainty/datamodules/classification/mnist.py +++ b/torch_uncertainty/datamodules/classification/mnist.py @@ -134,7 +134,7 @@ def setup(self, stage: Literal["fit", "test"] | None = None) -> None: download=False, transform=self.test_transform, ) - else: + if stage not in ["fit", "test", None]: raise ValueError(f"Stage {stage} is not supported.") if self.eval_ood: diff --git a/torch_uncertainty/datamodules/classification/tiny_imagenet.py b/torch_uncertainty/datamodules/classification/tiny_imagenet.py index f323fc31..f99144fb 100644 --- a/torch_uncertainty/datamodules/classification/tiny_imagenet.py +++ b/torch_uncertainty/datamodules/classification/tiny_imagenet.py @@ -147,7 +147,7 @@ def setup(self, stage: Literal["fit", "test"] | None = None) -> None: split="val", transform=self.test_transform, ) - else: + if stage not in ["fit", "test", None]: raise ValueError(f"Stage {stage} is not supported.") if self.eval_ood: From 1c412bb47b405056fcacfbd5113f6802998ffc2e Mon Sep 17 00:00:00 2001 From: Olivier Date: Sat, 16 Mar 2024 17:20:51 +0100 Subject: [PATCH 041/148] :bug: Fix conf bug --- docs/source/conf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index f86e98bd..4dc73558 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -12,7 +12,7 @@ project = "TorchUncertainty" copyright = ( # noqa: A001 - f"{datetime.now(datetime.UTC).year!s}, Adrien Lafage and Olivier Laurent" + f"{datetime.now().year!s}, Adrien Lafage and Olivier Laurent" ) author = "Adrien Lafage and Olivier Laurent" release = "0.2.0" From 833920522068ce6aa8b9e4bba3a6255bb9cfadbf Mon Sep 17 00:00:00 2001 From: Olivier Date: Sat, 16 Mar 2024 19:59:31 +0100 Subject: [PATCH 042/148] :white_check_mark: Add tests for the regression routine --- tests/_dummies/baseline.py | 41 ++--- tests/_dummies/datamodule.py | 25 +-- tests/_dummies/model.py | 19 ++- tests/routines/test_regression.py | 189 ++++++----------------- torch_uncertainty/routines/regression.py | 6 + torch_uncertainty/transforms/batch.py | 2 +- torch_uncertainty/utils/distributions.py | 21 +++ 7 files changed, 104 insertions(+), 199 deletions(-) diff --git a/tests/_dummies/baseline.py b/tests/_dummies/baseline.py index fcffee6e..20f1c6c5 100644 --- a/tests/_dummies/baseline.py +++ b/tests/_dummies/baseline.py @@ -1,6 +1,10 @@ +import copy + from pytorch_lightning import LightningModule from torch import nn +from torch_uncertainty.layers.distributions import IndptNormalLayer +from torch_uncertainty.models.deep_ensembles import deep_ensembles from torch_uncertainty.routines import ClassificationRoutine, RegressionRoutine from torch_uncertainty.transforms import RepeatTarget @@ -44,52 +48,39 @@ def __new__( num_estimators=2, ) - # @classmethod - # def add_model_specific_args( - # cls, - # parser: ArgumentParser, - # ) -> ArgumentParser: - # return ClassificationEnsemble.add_model_specific_args(parser) - class DummyRegressionBaseline: def __new__( cls, in_features: int, - out_features: int, + num_outputs: int, loss: type[nn.Module], baseline_type: str = "single", - dist_estimation: int = 1, - **kwargs, + optimization_procedure=None, ) -> LightningModule: + kwargs = {} model = dummy_model( in_channels=in_features, - num_classes=out_features, - num_estimators=1 + int(baseline_type == "ensemble"), + num_classes=num_outputs * 2, + num_estimators=1, + last_layer=IndptNormalLayer(num_outputs), ) - if baseline_type == "single": return RegressionRoutine( - out_features=out_features, + num_outputs=num_outputs * 2, model=model, loss=loss, - dist_estimation=dist_estimation, num_estimators=1, + optimization_procedure=optimization_procedure, ) # baseline_type == "ensemble": kwargs["num_estimators"] = 2 + model = deep_ensembles([model, copy.deepcopy(model)], task="regression") return RegressionRoutine( + num_outputs=num_outputs * 2, model=model, loss=loss, - dist_estimation=dist_estimation, - mode="mean", - out_features=out_features, num_estimators=2, + optimization_procedure=optimization_procedure, + format_batch_fn=RepeatTarget(2), ) - - # @classmethod - # def add_model_specific_args( - # cls, - # parser: ArgumentParser, - # ) -> ArgumentParser: - # return ClassificationEnsemble.add_model_specific_args(parser) diff --git a/tests/_dummies/datamodule.py b/tests/_dummies/datamodule.py index a6b15d2a..040f2e35 100644 --- a/tests/_dummies/datamodule.py +++ b/tests/_dummies/datamodule.py @@ -116,13 +116,11 @@ class DummyRegressionDataModule(AbstractDataModule): def __init__( self, root: str | Path, - eval_ood: bool, batch_size: int, out_features: int = 2, num_workers: int = 1, pin_memory: bool = True, persistent_workers: bool = True, - **kwargs, ) -> None: super().__init__( root=root, @@ -130,9 +128,9 @@ def __init__( num_workers=num_workers, pin_memory=pin_memory, persistent_workers=persistent_workers, + val_split=0, ) - self.eval_ood = eval_ood self.out_features = out_features self.dataset = DummyRegressionDataset @@ -162,25 +160,6 @@ def setup(self, stage: str | None = None) -> None: out_features=self.out_features, transform=self.test_transform, ) - if self.eval_ood: - self.ood = self.ood_dataset( - self.root, - out_features=self.out_features, - transform=self.test_transform, - ) def test_dataloader(self) -> DataLoader | list[DataLoader]: - dataloader = [self._data_loader(self.test)] - if self.eval_ood: - dataloader.append(self._data_loader(self.ood)) - return dataloader - - @classmethod - def add_argparse_args( - cls, - parent_parser: ArgumentParser, - **kwargs: Any, - ) -> ArgumentParser: - p = super().add_argparse_args(parent_parser) - p.add_argument("--eval-ood", action="store_true") - return parent_parser + return [self._data_loader(self.test)] diff --git a/tests/_dummies/model.py b/tests/_dummies/model.py index 51b65a77..09a27806 100644 --- a/tests/_dummies/model.py +++ b/tests/_dummies/model.py @@ -14,6 +14,7 @@ def __init__( num_estimators: int, dropout_rate: float, with_linear: bool, + last_layer: nn.Module, ) -> None: super().__init__() self.dropout_rate = dropout_rate @@ -28,15 +29,19 @@ def __init__( 1, num_classes, ) + self.last_layer = last_layer self.dropout = nn.Dropout(p=dropout_rate) self.num_estimators = num_estimators def forward(self, x: Tensor) -> Tensor: - return self.dropout( - self.linear( - torch.ones( - (x.shape[0] * self.num_estimators, 1), dtype=torch.float32 + return self.last_layer( + self.dropout( + self.linear( + torch.ones( + (x.shape[0] * self.num_estimators, 1), + dtype=torch.float32, + ) ) ) ) @@ -54,6 +59,7 @@ def dummy_model( dropout_rate: float = 0.0, with_feats: bool = True, with_linear: bool = True, + last_layer=None, ) -> _Dummy: """Dummy model for testing purposes. @@ -65,10 +71,13 @@ def dummy_model( with_feats (bool, optional): Whether to include features. Defaults to True. with_linear (bool, optional): Whether to include a linear layer. Defaults to True. + last_layer ([type], optional): Last layer of the model. Defaults to None. Returns: _Dummy: Dummy model. """ + if last_layer is None: + last_layer = nn.Identity() if with_feats: return _DummyWithFeats( in_channels=in_channels, @@ -76,6 +85,7 @@ def dummy_model( num_estimators=num_estimators, dropout_rate=dropout_rate, with_linear=with_linear, + last_layer=last_layer, ) return _Dummy( in_channels=in_channels, @@ -83,4 +93,5 @@ def dummy_model( num_estimators=num_estimators, dropout_rate=dropout_rate, with_linear=with_linear, + last_layer=last_layer, ) diff --git a/tests/routines/test_regression.py b/tests/routines/test_regression.py index 70ca4b48..1533c930 100644 --- a/tests/routines/test_regression.py +++ b/tests/routines/test_regression.py @@ -1,160 +1,57 @@ -# from functools import partial -# from pathlib import Path +from pathlib import Path -# import pytest -# from cli_test_helpers import ArgvContext -# from torch import nn +import pytest +from lightning.pytorch import Trainer +from torch import nn -# from tests._dummies import DummyRegressionBaseline, DummyRegressionDataModule -# from torch_uncertainty import cli_main, init_args -# from torch_uncertainty.losses import BetaNLL, NIGLoss -# from torch_uncertainty.optimization_procedures import optim_cifar10_resnet18 +from tests._dummies import DummyRegressionBaseline, DummyRegressionDataModule +from torch_uncertainty.losses import DistributionNLL +from torch_uncertainty.optimization_procedures import optim_cifar10_resnet18 +from torch_uncertainty.routines import RegressionRoutine -# class TestRegressionSingle: -# """Testing the Regression routine with a single model.""" +class TestRegression: + """Testing the Regression routine.""" -# def test_cli_main_dummy_dist(self): -# root = Path(__file__).parent.absolute().parents[0] -# with ArgvContext("file.py"): -# args = init_args(DummyRegressionBaseline, DummyRegressionDataModule) + def test_main_one_estimator(self): + trainer = Trainer(accelerator="cpu", fast_dev_run=True) -# # datamodule -# args.root = str(root / "data") -# dm = DummyRegressionDataModule(out_features=1, **vars(args)) + root = Path(__file__).parent.absolute().parents[0] / "data" + # datamodule + dm = DummyRegressionDataModule(out_features=1, root=root, batch_size=4) -# model = DummyRegressionBaseline( -# in_features=dm.in_features, -# out_features=2, -# loss=nn.GaussianNLLLoss, -# optimization_procedure=optim_cifar10_resnet18, -# baseline_type="single", -# dist_estimation=2, -# **vars(args), -# ) + model = DummyRegressionBaseline( + in_features=dm.in_features, + num_outputs=1, + loss=DistributionNLL, + optimization_procedure=optim_cifar10_resnet18, + baseline_type="single", + ) -# cli_main(model, dm, root, "logs/dummy", args) + trainer.fit(model, dm) + trainer.test(model, dm) -# def test_cli_main_dummy_dist_der(self): -# root = Path(__file__).parent.absolute().parents[0] -# with ArgvContext("file.py"): -# args = init_args(DummyRegressionBaseline, DummyRegressionDataModule) + def test_main_two_estimators(self): + trainer = Trainer(accelerator="cpu", fast_dev_run=True) -# # datamodule -# args.root = str(root / "data") -# dm = DummyRegressionDataModule(out_features=1, **vars(args)) + root = Path(__file__).parent.absolute().parents[0] / "data" + # datamodule + dm = DummyRegressionDataModule(out_features=2, root=root, batch_size=4) -# loss = partial( -# NIGLoss, -# reg_weight=1e-2, -# ) + model = DummyRegressionBaseline( + in_features=dm.in_features, + num_outputs=2, + loss=DistributionNLL, + optimization_procedure=optim_cifar10_resnet18, + baseline_type="ensemble", + ) -# model = DummyRegressionBaseline( -# in_features=dm.in_features, -# out_features=4, -# loss=loss, -# optimization_procedure=optim_cifar10_resnet18, -# baseline_type="single", -# dist_estimation=4, -# **vars(args), -# ) + trainer.fit(model, dm) + trainer.test(model, dm) -# cli_main(model, dm, root, "logs/dummy_der", args) + def test_regression_failures(self): + with pytest.raises(ValueError): + RegressionRoutine(1, nn.Identity(), nn.MSELoss, num_estimators=0) -# def test_cli_main_dummy_dist_betanll(self): -# root = Path(__file__).parent.absolute().parents[0] -# with ArgvContext("file.py"): -# args = init_args(DummyRegressionBaseline, DummyRegressionDataModule) - -# # datamodule -# args.root = str(root / "data") -# dm = DummyRegressionDataModule(out_features=1, **vars(args)) - -# loss = partial( -# BetaNLL, -# beta=0.5, -# ) - -# model = DummyRegressionBaseline( -# in_features=dm.in_features, -# out_features=2, -# loss=loss, -# optimization_procedure=optim_cifar10_resnet18, -# baseline_type="single", -# dist_estimation=2, -# **vars(args), -# ) - -# cli_main(model, dm, root, "logs/dummy_betanll", args) - -# def test_cli_main_dummy(self): -# root = Path(__file__).parent.absolute().parents[0] -# with ArgvContext("file.py"): -# args = init_args(DummyRegressionBaseline, DummyRegressionDataModule) - -# # datamodule -# args.root = str(root / "data") -# dm = DummyRegressionDataModule(out_features=2, **vars(args)) - -# model = DummyRegressionBaseline( -# in_features=dm.in_features, -# out_features=dm.out_features, -# loss=nn.MSELoss, -# optimization_procedure=optim_cifar10_resnet18, -# baseline_type="single", -# **vars(args), -# ) - -# cli_main(model, dm, root, "logs/dummy", args) - -# def test_regression_failures(self): -# with pytest.raises(ValueError): -# DummyRegressionBaseline( -# in_features=10, -# out_features=3, -# loss=nn.GaussianNLLLoss, -# optimization_procedure=optim_cifar10_resnet18, -# dist_estimation=4, -# ) - -# with pytest.raises(ValueError): -# DummyRegressionBaseline( -# in_features=10, -# out_features=3, -# loss=nn.GaussianNLLLoss, -# optimization_procedure=optim_cifar10_resnet18, -# dist_estimation=-4, -# ) - -# with pytest.raises(TypeError): -# DummyRegressionBaseline( -# in_features=10, -# out_features=4, -# loss=nn.GaussianNLLLoss, -# optimization_procedure=optim_cifar10_resnet18, -# dist_estimation=4.2, -# ) - - -# class TestRegressionEnsemble: -# """Testing the Regression routine with an ensemble model.""" - -# def test_cli_main_dummy(self): -# root = Path(__file__).parent.absolute().parents[0] -# with ArgvContext("file.py"): -# args = init_args(DummyRegressionBaseline, DummyRegressionDataModule) - -# # datamodule -# args.root = str(root / "data") -# dm = DummyRegressionDataModule(out_features=1, **vars(args)) - -# model = DummyRegressionBaseline( -# in_features=dm.in_features, -# out_features=dm.out_features, -# loss=nn.MSELoss, -# optimization_procedure=optim_cifar10_resnet18, -# baseline_type="ensemble", -# **vars(args), -# ) - -# cli_main(model, dm, root, "logs/dummy", args) + with pytest.raises(ValueError): + RegressionRoutine(0, nn.Identity(), nn.MSELoss, num_estimators=1) diff --git a/torch_uncertainty/routines/regression.py b/torch_uncertainty/routines/regression.py index 11d869fd..79dcd0b3 100644 --- a/torch_uncertainty/routines/regression.py +++ b/torch_uncertainty/routines/regression.py @@ -21,6 +21,7 @@ def __init__( loss: type[nn.Module], num_estimators: int = 1, format_batch_fn: nn.Module | None = None, + optimization_procedure=None, ) -> None: super().__init__() @@ -59,6 +60,11 @@ def __init__( if num_outputs == 1: self.one_dim_regression = True + self.optimization_procedure = optimization_procedure + + def configure_optimizers(self): + return self.optimization_procedure(self.model) + def on_train_start(self) -> None: # hyperparameters for performances init_metrics = {k: 0 for k in self.val_metrics} diff --git a/torch_uncertainty/transforms/batch.py b/torch_uncertainty/transforms/batch.py index 29e035e8..600cea3d 100644 --- a/torch_uncertainty/transforms/batch.py +++ b/torch_uncertainty/transforms/batch.py @@ -25,7 +25,7 @@ def __init__(self, num_repeats: int) -> None: def forward(self, batch: tuple[Tensor, Tensor]) -> tuple[Tensor, Tensor]: inputs, targets = batch - return inputs, targets.repeat(self.num_repeats) + return inputs, targets.repeat_interleave(self.num_repeats, dim=0) class MIMOBatchFormat(nn.Module): diff --git a/torch_uncertainty/utils/distributions.py b/torch_uncertainty/utils/distributions.py index be68dee2..d69b2f8e 100644 --- a/torch_uncertainty/utils/distributions.py +++ b/torch_uncertainty/utils/distributions.py @@ -32,6 +32,27 @@ def cat_dist(distributions: list[Distribution]) -> Distribution: ) +def squeeze_dist(distribution: Distribution, dim: int) -> Distribution: + """Squeeze the distribution along a given dimension. + + Args: + distribution (Distribution): The distribution to squeeze. + dim (int): The dimension to squeeze. + + Returns: + Distribution: The squeezed distribution. + """ + dist_type = type(distribution) + if isinstance(distribution, Normal | Laplace): + loc = distribution.loc.squeeze(dim) + scale = distribution.scale.squeeze(dim) + return dist_type(loc=loc, scale=scale) + raise NotImplementedError( + f"Squeezing of {dist_type} distributions is not supported." + "Raise an issue if needed." + ) + + def to_ens_dist( distribution: Distribution, num_estimators: int = 1 ) -> Distribution: From 422e21cc7fcc0f50bbb3732376385113b4c3bd68 Mon Sep 17 00:00:00 2001 From: Olivier Date: Sat, 16 Mar 2024 20:11:54 +0100 Subject: [PATCH 043/148] :white_check_mark: Fix tests --- tests/baselines/test_packed.py | 1 - tests/baselines/test_standard.py | 1 - 2 files changed, 2 deletions(-) diff --git a/tests/baselines/test_packed.py b/tests/baselines/test_packed.py index 0404bd1b..250e38cc 100644 --- a/tests/baselines/test_packed.py +++ b/tests/baselines/test_packed.py @@ -139,5 +139,4 @@ def test_packed(self): summary(net) _ = net.criterion - _ = net.configure_optimizers() _ = net(torch.rand(1, 3)) diff --git a/tests/baselines/test_standard.py b/tests/baselines/test_standard.py index a76ee35f..c52f438b 100644 --- a/tests/baselines/test_standard.py +++ b/tests/baselines/test_standard.py @@ -115,7 +115,6 @@ def test_standard(self): summary(net) _ = net.criterion - _ = net.configure_optimizers() _ = net(torch.rand(1, 3)) parser = ArgumentParser() From 16814302dd32823bd95a263c9b036fa7e9f370b6 Mon Sep 17 00:00:00 2001 From: Olivier Date: Sat, 16 Mar 2024 20:18:30 +0100 Subject: [PATCH 044/148] :zap: Allow sphinx6 --- .gitignore | 1 + auto_tutorials_source/tutorial_corruptions.py | 2 +- auto_tutorials_source/tutorial_pe_cifar10.py | 2 +- pyproject.toml | 2 +- 4 files changed, 4 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index 6ed77954..8e9d6307 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,7 @@ docs/*/auto_tutorials/ *.pth *.ckpt *.out +sg_execution_times.rst # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/auto_tutorials_source/tutorial_corruptions.py b/auto_tutorials_source/tutorial_corruptions.py index 709fa0a7..6713834f 100644 --- a/auto_tutorials_source/tutorial_corruptions.py +++ b/auto_tutorials_source/tutorial_corruptions.py @@ -105,7 +105,7 @@ def show_images(transform): #%% # 10. Frost -# ~~~~~~~~ +# ~~~~~~~~~ from torch_uncertainty.transforms.corruptions import Frost show_images(Frost) diff --git a/auto_tutorials_source/tutorial_pe_cifar10.py b/auto_tutorials_source/tutorial_pe_cifar10.py index 8a4af0e7..52820064 100644 --- a/auto_tutorials_source/tutorial_pe_cifar10.py +++ b/auto_tutorials_source/tutorial_pe_cifar10.py @@ -119,7 +119,7 @@ def imshow(img): # %% # 2. Define a Packed-Ensemble from a standard classifier -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ # First we define a standard classifier for CIFAR10 for reference. We will use a # convolutional neural network. diff --git a/pyproject.toml b/pyproject.toml index 6836a04e..80eb453f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,7 +53,7 @@ dev = [ "cli-test-helpers", ] docs = [ - "sphinx<6", + "sphinx<7", "tu_sphinx_theme", "sphinx-copybutton", "sphinx-gallery", From 5c428afbb837a1e29f6fe84d6a356e7560995cff Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 18 Mar 2024 09:31:41 +0100 Subject: [PATCH 045/148] :hammer: Rename readmes --- experiments/classification/{README.md => readme.md} | 0 experiments/{README.md => readme.md} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename experiments/classification/{README.md => readme.md} (100%) rename experiments/{README.md => readme.md} (100%) diff --git a/experiments/classification/README.md b/experiments/classification/readme.md similarity index 100% rename from experiments/classification/README.md rename to experiments/classification/readme.md diff --git a/experiments/README.md b/experiments/readme.md similarity index 100% rename from experiments/README.md rename to experiments/readme.md From 0986bda27c6dc2dd281c5f7d46f6624423885a13 Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 18 Mar 2024 09:32:52 +0100 Subject: [PATCH 046/148] :bug: Add optimization proc. to cls routine --- .../cifar10/configs/resnet18/standard.yaml | 2 +- .../cifar10/configs/resnet50/standard.yaml | 2 +- .../cifar10/configs/wideresnet28x10/standard.yaml | 2 +- torch_uncertainty/routines/classification.py | 13 +++++++++---- 4 files changed, 12 insertions(+), 7 deletions(-) diff --git a/experiments/classification/cifar10/configs/resnet18/standard.yaml b/experiments/classification/cifar10/configs/resnet18/standard.yaml index e930813d..d6ab70f9 100644 --- a/experiments/classification/cifar10/configs/resnet18/standard.yaml +++ b/experiments/classification/cifar10/configs/resnet18/standard.yaml @@ -30,7 +30,7 @@ model: num_classes: 10 in_channels: 3 loss: torch.nn.CrossEntropyLoss - version: standard + version: std arch: 18 style: cifar data: diff --git a/experiments/classification/cifar10/configs/resnet50/standard.yaml b/experiments/classification/cifar10/configs/resnet50/standard.yaml index 6e0e719e..02743adb 100644 --- a/experiments/classification/cifar10/configs/resnet50/standard.yaml +++ b/experiments/classification/cifar10/configs/resnet50/standard.yaml @@ -30,7 +30,7 @@ model: num_classes: 10 in_channels: 3 loss: torch.nn.CrossEntropyLoss - version: standard + version: std arch: 50 style: cifar data: diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/standard.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/standard.yaml index ebfd0f2f..875ec995 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10/standard.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10/standard.yaml @@ -30,7 +30,7 @@ model: num_classes: 10 in_channels: 3 loss: torch.nn.CrossEntropyLoss - version: standard + version: std style: cifar data: root: ./data diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 086aaf08..b9a67005 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -43,6 +43,7 @@ def __init__( loss: type[nn.Module], num_estimators: int, format_batch_fn: nn.Module | None = None, + optimization_procedure=None, mixtype: str = "erm", mixmode: str = "elem", dist_sim: str = "emb", @@ -68,6 +69,7 @@ def __init__( num_estimators (int): _description_ format_batch_fn (nn.Module, optional): Function to format the batch. Defaults to :class:`torch.nn.Identity()`. + optimization_procedure (optional): Training recipe. Defaults to None. mixtype (str, optional): Mixup type. Defaults to ``"erm"``. mixmode (str, optional): Mixup mode. Defaults to ``"elem"``. dist_sim (str, optional): Distance similarity. Defaults to ``"emb"``. @@ -160,8 +162,8 @@ def __init__( self.model = model self.loss = loss - # batch format self.format_batch_fn = format_batch_fn + self.optimization_procedure = optimization_procedure # metrics if self.binary_cls: @@ -300,6 +302,9 @@ def init_mixup( ) return nn.Identity() + def configure_optimizers(self): + return self.optimization_procedure(self.model) + def on_train_start(self) -> None: init_metrics = {k: 0 for k in self.val_cls_metrics} init_metrics.update({k: 0 for k in self.test_cls_metrics}) @@ -374,11 +379,11 @@ def training_step( with torch.no_grad(): feats = self.model.feats_forward(batch[0]).detach() - batch = self.mixup(*batch, feats) + batch = self.mixup(batch, feats) elif self.dist_sim == "inp": - batch = self.mixup(*batch, batch[0]) + batch = self.mixup(batch, batch[0]) else: - batch = self.mixup(*batch) + batch = self.mixup(batch) inputs, targets = self.format_batch_fn(batch) From b367f5ceb3dfd78aeacc5286abc4b3a9f0741d3a Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 18 Mar 2024 10:16:32 +0100 Subject: [PATCH 047/148] :white_check_mark: Improve regression tests & routine --- tests/_dummies/baseline.py | 6 +- tests/_dummies/dataset.py | 5 +- tests/routines/test_regression.py | 48 +++++++++-- torch_uncertainty/baselines/regression/mlp.py | 1 + torch_uncertainty/routines/regression.py | 86 +++++++++++++------ torch_uncertainty/utils/distributions.py | 2 +- 6 files changed, 114 insertions(+), 34 deletions(-) diff --git a/tests/_dummies/baseline.py b/tests/_dummies/baseline.py index 20f1c6c5..74d56e08 100644 --- a/tests/_dummies/baseline.py +++ b/tests/_dummies/baseline.py @@ -67,7 +67,8 @@ def __new__( ) if baseline_type == "single": return RegressionRoutine( - num_outputs=num_outputs * 2, + probabilistic=True, + num_outputs=num_outputs, model=model, loss=loss, num_estimators=1, @@ -77,7 +78,8 @@ def __new__( kwargs["num_estimators"] = 2 model = deep_ensembles([model, copy.deepcopy(model)], task="regression") return RegressionRoutine( - num_outputs=num_outputs * 2, + probabilistic=True, + num_outputs=num_outputs, model=model, loss=loss, num_estimators=2, diff --git a/tests/_dummies/dataset.py b/tests/_dummies/dataset.py index 9c35c7d2..59183811 100644 --- a/tests/_dummies/dataset.py +++ b/tests/_dummies/dataset.py @@ -122,7 +122,10 @@ def __init__( self.targets = [] input_shape = (num_samples, in_features) - output_shape = (num_samples, out_features) + if out_features != 1: + output_shape = (num_samples, out_features) + else: + output_shape = (num_samples,) self.data = torch.rand( size=input_shape, diff --git a/tests/routines/test_regression.py b/tests/routines/test_regression.py index 1533c930..00c783c7 100644 --- a/tests/routines/test_regression.py +++ b/tests/routines/test_regression.py @@ -13,11 +13,10 @@ class TestRegression: """Testing the Regression routine.""" - def test_main_one_estimator(self): + def test_one_estimator_one_output(self): trainer = Trainer(accelerator="cpu", fast_dev_run=True) root = Path(__file__).parent.absolute().parents[0] / "data" - # datamodule dm = DummyRegressionDataModule(out_features=1, root=root, batch_size=4) model = DummyRegressionBaseline( @@ -31,11 +30,44 @@ def test_main_one_estimator(self): trainer.fit(model, dm) trainer.test(model, dm) - def test_main_two_estimators(self): + def test_one_estimator_two_outputs(self): + trainer = Trainer(accelerator="cpu", fast_dev_run=True) + + root = Path(__file__).parent.absolute().parents[0] / "data" + dm = DummyRegressionDataModule(out_features=2, root=root, batch_size=4) + + model = DummyRegressionBaseline( + in_features=dm.in_features, + num_outputs=2, + loss=DistributionNLL, + optimization_procedure=optim_cifar10_resnet18, + baseline_type="single", + ) + + trainer.fit(model, dm) + trainer.test(model, dm) + + def test_two_estimators_one_output(self): + trainer = Trainer(accelerator="cpu", fast_dev_run=True) + + root = Path(__file__).parent.absolute().parents[0] / "data" + dm = DummyRegressionDataModule(out_features=1, root=root, batch_size=4) + + model = DummyRegressionBaseline( + in_features=dm.in_features, + num_outputs=1, + loss=DistributionNLL, + optimization_procedure=optim_cifar10_resnet18, + baseline_type="ensemble", + ) + + trainer.fit(model, dm) + trainer.test(model, dm) + + def test_two_estimators_two_outputs(self): trainer = Trainer(accelerator="cpu", fast_dev_run=True) root = Path(__file__).parent.absolute().parents[0] / "data" - # datamodule dm = DummyRegressionDataModule(out_features=2, root=root, batch_size=4) model = DummyRegressionBaseline( @@ -51,7 +83,11 @@ def test_main_two_estimators(self): def test_regression_failures(self): with pytest.raises(ValueError): - RegressionRoutine(1, nn.Identity(), nn.MSELoss, num_estimators=0) + RegressionRoutine( + True, 1, nn.Identity(), nn.MSELoss, num_estimators=0 + ) with pytest.raises(ValueError): - RegressionRoutine(0, nn.Identity(), nn.MSELoss, num_estimators=1) + RegressionRoutine( + True, 0, nn.Identity(), nn.MSELoss, num_estimators=1 + ) diff --git a/torch_uncertainty/baselines/regression/mlp.py b/torch_uncertainty/baselines/regression/mlp.py index 25d94e39..1dc2ba80 100644 --- a/torch_uncertainty/baselines/regression/mlp.py +++ b/torch_uncertainty/baselines/regression/mlp.py @@ -51,6 +51,7 @@ def __init__( # version in self.versions: super().__init__( + probabilistic=False, num_outputs=num_outputs, model=model, loss=loss, diff --git a/torch_uncertainty/routines/regression.py b/torch_uncertainty/routines/regression.py index 79dcd0b3..98edaf80 100644 --- a/torch_uncertainty/routines/regression.py +++ b/torch_uncertainty/routines/regression.py @@ -10,12 +10,13 @@ from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection from torch_uncertainty.metrics.nll import DistributionNLL -from torch_uncertainty.utils.distributions import to_ens_dist +from torch_uncertainty.utils.distributions import squeeze_dist, to_ensemble_dist class RegressionRoutine(LightningModule): def __init__( self, + probabilistic: bool, num_outputs: int, model: nn.Module, loss: type[nn.Module], @@ -23,8 +24,28 @@ def __init__( format_batch_fn: nn.Module | None = None, optimization_procedure=None, ) -> None: + """Regression routine for PyTorch Lightning. + + Args: + probabilistic (bool): Whether the model is probabilistic, i.e., + outputs a PyTorch distribution. + num_outputs (int): The number of outputs of the model. + model (nn.Module): The model to train. + loss (type[nn.Module]): The loss function to use. + num_estimators (int, optional): The number of estimators for the + ensemble. Defaults to 1. + format_batch_fn (nn.Module, optional): The function to format the + batch. Defaults to None. + optimization_procedure (optional): The optimization procedure + to use. Defaults to None. + + Warning: + If :attr:`probabilistic` is True, the model must output a `PyTorch + distribution _`. + """ super().__init__() + self.probabilistic = probabilistic self.model = model self.loss = loss @@ -37,10 +58,12 @@ def __init__( { "mae": MeanAbsoluteError(), "mse": MeanSquaredError(squared=False), - "nll": DistributionNLL(reduction="mean"), }, compute_groups=False, ) + if self.probabilistic: + reg_metrics["nll"] = DistributionNLL(reduction="mean") + self.val_metrics = reg_metrics.clone(prefix="reg_val/") self.test_metrics = reg_metrics.clone(prefix="reg_test/") @@ -77,7 +100,13 @@ def on_train_start(self) -> None: ) def forward(self, inputs: Tensor) -> Tensor: - return self.model.forward(inputs) + pred = self.model(inputs) + if self.probabilistic: + if self.one_dim_regression: + pred = squeeze_dist(pred, -1) + if self.num_estimators == 1: + pred = squeeze_dist(pred, -1) + return pred @property def criterion(self) -> nn.Module: @@ -87,8 +116,7 @@ def training_step( self, batch: tuple[Tensor, Tensor], batch_idx: int ) -> STEP_OUTPUT: inputs, targets = self.format_batch_fn(batch) - - dists = self.forward(inputs) + dists = self.model(inputs) if self.one_dim_regression: targets = targets.unsqueeze(-1) @@ -102,21 +130,26 @@ def validation_step( self, batch: tuple[Tensor, Tensor], batch_idx: int ) -> None: inputs, targets = batch + pred = self.model(inputs) - dists = self.forward(inputs) - - ens_dist = Independent( - to_ens_dist(dists, num_estimators=self.num_estimators), 1 - ) - mix = Categorical(torch.ones(self.num_estimators, device=self.device)) - mixture = MixtureSameFamily(mix, ens_dist) + if self.probabilistic: + ens_dist = Independent( + to_ensemble_dist(pred, num_estimators=self.num_estimators), 1 + ) + mix = Categorical( + torch.ones(self.num_estimators, device=self.device) + ) + mixture = MixtureSameFamily(mix, ens_dist) + pred = mixture.mean if self.one_dim_regression: + print("one dim") targets = targets.unsqueeze(-1) - self.val_metrics.mse.update(mixture.mean, targets) - self.val_metrics.mae.update(mixture.mean, targets) - self.val_metrics.nll.update(mixture, targets) + self.val_metrics.mse.update(pred, targets) + self.val_metrics.mae.update(pred, targets) + if self.probabilistic: + self.val_metrics.nll.update(mixture, targets) def on_validation_epoch_end(self) -> None: self.log_dict(self.val_metrics.compute()) @@ -135,20 +168,25 @@ def test_step( ) inputs, targets = batch - dists = self.forward(inputs) - ens_dist = Independent( - to_ens_dist(dists, num_estimators=self.num_estimators), 1 - ) + pred = self.model(inputs) - mix = Categorical(torch.ones(self.num_estimators, device=self.device)) - mixture = MixtureSameFamily(mix, ens_dist) + if self.probabilistic: + ens_dist = Independent( + to_ensemble_dist(pred, num_estimators=self.num_estimators), 1 + ) + mix = Categorical( + torch.ones(self.num_estimators, device=self.device) + ) + mixture = MixtureSameFamily(mix, ens_dist) + pred = mixture.mean if self.one_dim_regression: targets = targets.unsqueeze(-1) - self.test_metrics.mae.update(mixture.mean, targets) - self.test_metrics.mse.update(mixture.mean, targets) - self.test_metrics.nll.update(mixture, targets) + self.test_metrics.mse.update(pred, targets) + self.test_metrics.mae.update(pred, targets) + if self.probabilistic: + self.test_metrics.nll.update(mixture, targets) def on_test_epoch_end(self) -> None: self.log_dict( diff --git a/torch_uncertainty/utils/distributions.py b/torch_uncertainty/utils/distributions.py index d69b2f8e..cb8f7cfa 100644 --- a/torch_uncertainty/utils/distributions.py +++ b/torch_uncertainty/utils/distributions.py @@ -53,7 +53,7 @@ def squeeze_dist(distribution: Distribution, dim: int) -> Distribution: ) -def to_ens_dist( +def to_ensemble_dist( distribution: Distribution, num_estimators: int = 1 ) -> Distribution: dist_type = type(distribution) From 1fa1ae29654d28af47bcc734b51f2fb17d3cfdcb Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 18 Mar 2024 10:52:01 +0100 Subject: [PATCH 048/148] :sparkles: Enable pointwise reg. ensembles --- tests/_dummies/baseline.py | 19 ++++--- tests/routines/test_regression.py | 54 ++++++++++++++++++ torch_uncertainty/models/deep_ensembles.py | 32 ++++++----- torch_uncertainty/routines/regression.py | 64 +++++++++++++++------- torch_uncertainty/utils/distributions.py | 7 ++- 5 files changed, 134 insertions(+), 42 deletions(-) diff --git a/tests/_dummies/baseline.py b/tests/_dummies/baseline.py index 74d56e08..f556bdd6 100644 --- a/tests/_dummies/baseline.py +++ b/tests/_dummies/baseline.py @@ -52,22 +52,24 @@ def __new__( class DummyRegressionBaseline: def __new__( cls, + probabilistic: bool, in_features: int, num_outputs: int, loss: type[nn.Module], baseline_type: str = "single", optimization_procedure=None, ) -> LightningModule: - kwargs = {} model = dummy_model( in_channels=in_features, - num_classes=num_outputs * 2, + num_classes=num_outputs * 2 if probabilistic else num_outputs, num_estimators=1, - last_layer=IndptNormalLayer(num_outputs), + last_layer=IndptNormalLayer(num_outputs) + if probabilistic + else nn.Identity(), ) if baseline_type == "single": return RegressionRoutine( - probabilistic=True, + probabilistic=probabilistic, num_outputs=num_outputs, model=model, loss=loss, @@ -75,10 +77,13 @@ def __new__( optimization_procedure=optimization_procedure, ) # baseline_type == "ensemble": - kwargs["num_estimators"] = 2 - model = deep_ensembles([model, copy.deepcopy(model)], task="regression") + model = deep_ensembles( + [model, copy.deepcopy(model)], + task="regression", + probabilistic=probabilistic, + ) return RegressionRoutine( - probabilistic=True, + probabilistic=probabilistic, num_outputs=num_outputs, model=model, loss=loss, diff --git a/tests/routines/test_regression.py b/tests/routines/test_regression.py index 00c783c7..613335b2 100644 --- a/tests/routines/test_regression.py +++ b/tests/routines/test_regression.py @@ -20,6 +20,21 @@ def test_one_estimator_one_output(self): dm = DummyRegressionDataModule(out_features=1, root=root, batch_size=4) model = DummyRegressionBaseline( + probabilistic=True, + in_features=dm.in_features, + num_outputs=1, + loss=DistributionNLL, + optimization_procedure=optim_cifar10_resnet18, + baseline_type="single", + ) + + trainer.fit(model, dm) + trainer.test(model, dm) + + model(dm.get_test_set()[0][0]) + + model = DummyRegressionBaseline( + probabilistic=False, in_features=dm.in_features, num_outputs=1, loss=DistributionNLL, @@ -37,6 +52,19 @@ def test_one_estimator_two_outputs(self): dm = DummyRegressionDataModule(out_features=2, root=root, batch_size=4) model = DummyRegressionBaseline( + probabilistic=True, + in_features=dm.in_features, + num_outputs=2, + loss=DistributionNLL, + optimization_procedure=optim_cifar10_resnet18, + baseline_type="single", + ) + + trainer.fit(model, dm) + trainer.test(model, dm) + + model = DummyRegressionBaseline( + probabilistic=False, in_features=dm.in_features, num_outputs=2, loss=DistributionNLL, @@ -54,6 +82,19 @@ def test_two_estimators_one_output(self): dm = DummyRegressionDataModule(out_features=1, root=root, batch_size=4) model = DummyRegressionBaseline( + probabilistic=True, + in_features=dm.in_features, + num_outputs=1, + loss=DistributionNLL, + optimization_procedure=optim_cifar10_resnet18, + baseline_type="ensemble", + ) + + trainer.fit(model, dm) + trainer.test(model, dm) + + model = DummyRegressionBaseline( + probabilistic=False, in_features=dm.in_features, num_outputs=1, loss=DistributionNLL, @@ -71,6 +112,19 @@ def test_two_estimators_two_outputs(self): dm = DummyRegressionDataModule(out_features=2, root=root, batch_size=4) model = DummyRegressionBaseline( + probabilistic=True, + in_features=dm.in_features, + num_outputs=2, + loss=DistributionNLL, + optimization_procedure=optim_cifar10_resnet18, + baseline_type="ensemble", + ) + + trainer.fit(model, dm) + trainer.test(model, dm) + + model = DummyRegressionBaseline( + probabilistic=False, in_features=dm.in_features, num_outputs=2, loss=DistributionNLL, diff --git a/torch_uncertainty/models/deep_ensembles.py b/torch_uncertainty/models/deep_ensembles.py index 1b8dc88f..c5b80eac 100644 --- a/torch_uncertainty/models/deep_ensembles.py +++ b/torch_uncertainty/models/deep_ensembles.py @@ -8,12 +8,12 @@ from torch_uncertainty.utils.distributions import cat_dist -class _ClsDeepEnsembles(nn.Module): +class _DeepEnsembles(nn.Module): def __init__( self, models: list[nn.Module], ) -> None: - """Create a deep ensembles from a list of models.""" + """Create a classification deep ensembles from a list of models.""" super().__init__() self.models = nn.ModuleList(models) @@ -33,16 +33,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.cat([model.forward(x) for model in self.models], dim=0) -class _RegDeepEnsembles(nn.Module): +class _RegDeepEnsembles(_DeepEnsembles): def __init__( self, + probabilistic: bool, models: list[nn.Module], ) -> None: - """Create a deep ensembles from a list of models.""" - super().__init__() + """Create a regression deep ensembles from a list of models.""" + super().__init__(models) - self.models = nn.ModuleList(models) - self.num_estimators = len(models) + self.probabilistic = probabilistic def forward(self, x: torch.Tensor) -> Distribution: r"""Return the logits of the ensemble. @@ -53,13 +53,16 @@ def forward(self, x: torch.Tensor) -> Distribution: Returns: Distribution: """ - return cat_dist([model.forward(x) for model in self.models]) + if self.probabilistic: + return cat_dist([model.forward(x) for model in self.models], dim=0) + return super().forward(x) def deep_ensembles( models: list[nn.Module] | nn.Module, num_estimators: int | None = None, task: Literal["classification", "regression"] = "classification", + probabilistic=None, ) -> nn.Module: """Build a Deep Ensembles out of the original models. @@ -67,6 +70,7 @@ def deep_ensembles( models (list[nn.Module] | nn.Module): The model to be ensembled. num_estimators (int | None): The number of estimators in the ensemble. task (Literal["classification", "regression"]): The model task. + probabilistic (bool): Whether the regression model is probabilistic. Returns: nn.Module: The ensembled model. @@ -112,9 +116,11 @@ def deep_ensembles( ) if task == "classification": - return _ClsDeepEnsembles(models=models) + return _DeepEnsembles(models=models) if task == "regression": - return _RegDeepEnsembles(models=models) - raise ValueError( - f"task must be either 'classification' or 'regression'. Got {task}." - ) + if probabilistic is None: + raise ValueError( + "probabilistic must be specified for regression models." + ) + return _RegDeepEnsembles(probabilistic=probabilistic, models=models) + raise ValueError(f"Unknown task: {task}.") diff --git a/torch_uncertainty/routines/regression.py b/torch_uncertainty/routines/regression.py index 98edaf80..f846228a 100644 --- a/torch_uncertainty/routines/regression.py +++ b/torch_uncertainty/routines/regression.py @@ -1,4 +1,5 @@ import torch +from einops import rearrange from lightning.pytorch import LightningModule from lightning.pytorch.utilities.types import STEP_OUTPUT from torch import Tensor, nn @@ -42,6 +43,10 @@ def __init__( Warning: If :attr:`probabilistic` is True, the model must output a `PyTorch distribution _`. + + Warning: + You must define :attr:`optimization_procedure` if you do not use + the CLI. """ super().__init__() @@ -100,16 +105,33 @@ def on_train_start(self) -> None: ) def forward(self, inputs: Tensor) -> Tensor: + """Forward pass of the routine. + + The forward pass automatically squeezes the output if the regression + is one-dimensional and if the routine contains a single model. + + Args: + inputs (Tensor): The input tensor. + + Returns: + Tensor: The output tensor. + """ pred = self.model(inputs) if self.probabilistic: if self.one_dim_regression: pred = squeeze_dist(pred, -1) if self.num_estimators == 1: pred = squeeze_dist(pred, -1) + else: + if self.one_dim_regression: + pred = pred.squeeze(-1) + if self.num_estimators == 1: + pred = pred.squeeze(-1) return pred @property def criterion(self) -> nn.Module: + """The loss function of the routine.""" return self.loss() def training_step( @@ -130,26 +152,27 @@ def validation_step( self, batch: tuple[Tensor, Tensor], batch_idx: int ) -> None: inputs, targets = batch - pred = self.model(inputs) + if self.one_dim_regression: + targets = targets.unsqueeze(-1) + preds = self.model(inputs) if self.probabilistic: ens_dist = Independent( - to_ensemble_dist(pred, num_estimators=self.num_estimators), 1 + to_ensemble_dist(preds, num_estimators=self.num_estimators), 1 ) mix = Categorical( torch.ones(self.num_estimators, device=self.device) ) mixture = MixtureSameFamily(mix, ens_dist) - pred = mixture.mean + self.val_metrics.nll.update(mixture, targets) - if self.one_dim_regression: - print("one dim") - targets = targets.unsqueeze(-1) + preds = mixture.mean + else: + preds = rearrange(preds, "(m b) c -> b m c", m=self.num_estimators) + preds = preds.mean(dim=1) - self.val_metrics.mse.update(pred, targets) - self.val_metrics.mae.update(pred, targets) - if self.probabilistic: - self.val_metrics.nll.update(mixture, targets) + self.val_metrics.mse.update(preds, targets) + self.val_metrics.mae.update(preds, targets) def on_validation_epoch_end(self) -> None: self.log_dict(self.val_metrics.compute()) @@ -168,25 +191,28 @@ def test_step( ) inputs, targets = batch - pred = self.model(inputs) + if self.one_dim_regression: + targets = targets.unsqueeze(-1) + preds = self.model(inputs) if self.probabilistic: ens_dist = Independent( - to_ensemble_dist(pred, num_estimators=self.num_estimators), 1 + to_ensemble_dist(preds, num_estimators=self.num_estimators), 1 ) mix = Categorical( torch.ones(self.num_estimators, device=self.device) ) mixture = MixtureSameFamily(mix, ens_dist) - pred = mixture.mean + self.test_metrics.nll.update(mixture, targets) - if self.one_dim_regression: - targets = targets.unsqueeze(-1) + preds = mixture.mean - self.test_metrics.mse.update(pred, targets) - self.test_metrics.mae.update(pred, targets) - if self.probabilistic: - self.test_metrics.nll.update(mixture, targets) + else: + preds = rearrange(preds, "(m b) c -> b m c", m=self.num_estimators) + preds = preds.mean(dim=1) + + self.test_metrics.mse.update(preds, targets) + self.test_metrics.mae.update(preds, targets) def on_test_epoch_end(self) -> None: self.log_dict( diff --git a/torch_uncertainty/utils/distributions.py b/torch_uncertainty/utils/distributions.py index cb8f7cfa..c87617d2 100644 --- a/torch_uncertainty/utils/distributions.py +++ b/torch_uncertainty/utils/distributions.py @@ -3,11 +3,12 @@ from torch.distributions import Distribution, Laplace, Normal -def cat_dist(distributions: list[Distribution]) -> Distribution: +def cat_dist(distributions: list[Distribution], dim: int) -> Distribution: r"""Concatenate a list of distributions into a single distribution. Args: distributions (list[Distribution]): The list of distributions. + dim (int): The dimension to concatenate. Returns: Distribution: The concatenated distributions. @@ -20,10 +21,10 @@ def cat_dist(distributions: list[Distribution]) -> Distribution: if isinstance(distributions[0], Normal | Laplace): locs = torch.cat( - [distribution.loc for distribution in distributions], dim=0 + [distribution.loc for distribution in distributions], dim=dim ) scales = torch.cat( - [distribution.scale for distribution in distributions], dim=0 + [distribution.scale for distribution in distributions], dim=dim ) return dist_type(loc=locs, scale=scales) raise NotImplementedError( From 8bfc3b835fd70eb2022052c68fe47a90b29215f9 Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 18 Mar 2024 11:06:28 +0100 Subject: [PATCH 049/148] :shirt: Make num_estimators default to 1 in routines --- torch_uncertainty/routines/classification.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index b9a67005..f4ff4655 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -41,7 +41,7 @@ def __init__( num_classes: int, model: nn.Module, loss: type[nn.Module], - num_estimators: int, + num_estimators: int = 1, format_batch_fn: nn.Module | None = None, optimization_procedure=None, mixtype: str = "erm", @@ -66,7 +66,8 @@ def __init__( num_classes (int): Number of classes. model (nn.Module): Model to train. loss (type[nn.Module]): Loss function. - num_estimators (int): _description_ + num_estimators (int, optional): Number of estimators for the + ensemble. Defaults to 1. format_batch_fn (nn.Module, optional): Function to format the batch. Defaults to :class:`torch.nn.Identity()`. optimization_procedure (optional): Training recipe. Defaults to None. From bb75d94e4e21df039d8c0441603ddbdf17595109 Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 18 Mar 2024 11:52:47 +0100 Subject: [PATCH 050/148] :sparkles: Rework 4 tutorials --- auto_tutorials_source/tutorial_bayesian.py | 100 ++++++++---------- .../tutorial_evidential_classification.py | 70 +++++------- .../tutorial_mc_batch_norm.py | 94 +++++++--------- auto_tutorials_source/tutorial_mc_dropout.py | 71 ++++++------- 4 files changed, 138 insertions(+), 197 deletions(-) diff --git a/auto_tutorials_source/tutorial_bayesian.py b/auto_tutorials_source/tutorial_bayesian.py index 64c8fce8..dee6c76b 100644 --- a/auto_tutorials_source/tutorial_bayesian.py +++ b/auto_tutorials_source/tutorial_bayesian.py @@ -2,58 +2,56 @@ Train a Bayesian Neural Network in Three Minutes ================================================ -In this tutorial, we will train a Bayesian Neural Network (BNN) LeNet classifier on the MNIST dataset. +In this tutorial, we will train a variational inference Bayesian Neural Network (BNN) LeNet classifier on the MNIST dataset. Foreword on Bayesian Neural Networks ------------------------------------ -Bayesian Neural Networks (BNNs) are a class of neural networks that can estimate the uncertainty of their predictions via uncertainty on their weights. This is achieved by considering the weights of the neural network as random variables, and by learning their posterior distribution. This is in contrast to standard neural networks, which only learn a single set of weights, which can be seen as Dirac distributions on the weights. +Bayesian Neural Networks (BNNs) are a class of neural networks that estimate the uncertainty on their predictions via uncertainty +on their weights. This is achieved by considering the weights of the neural network as random variables, and by learning their +posterior distribution. This is in contrast to standard neural networks, which only learn a single set of weights, which can be +seen as Dirac distributions on the weights. For more information on Bayesian Neural Networks, we refer the reader to the following resources: - Weight Uncertainty in Neural Networks `ICML2015 `_ - Hands-on Bayesian Neural Networks - a Tutorial for Deep Learning Users `IEEE Computational Intelligence Magazine `_ -Training a Bayesian LeNet using TorchUncertainty models and PyTorch Lightning ------------------------------------------------------------------------------ +Training a Bayesian LeNet using TorchUncertainty models and Lightning +--------------------------------------------------------------------- In this part, we train a bayesian LeNet, based on the model and routines already implemented in TU. 1. Loading the utilities ~~~~~~~~~~~~~~~~~~~~~~~~ -To train a BNN using TorchUncertainty, we have to load the following utilities from TorchUncertainty: +To train a BNN using TorchUncertainty, we have to load the following modules: -- the cli handler: cli_main and argument parser: init_args -- the model: bayesian_lenet, which lies in the torch_uncertainty.model module -- the classification training routine in the torch_uncertainty.training.classification module +- the Trainer from Lightning +- the model: bayesian_lenet, which lies in the torch_uncertainty.model +- the classification training routine from torch_uncertainty.routines - the bayesian objective: the ELBOLoss, which lies in the torch_uncertainty.losses file -- the datamodule that handles dataloaders: MNISTDataModule, which lies in the torch_uncertainty.datamodule -""" - -from torch_uncertainty import cli_main, init_args -from torch_uncertainty.datamodules import MNISTDataModule -from torch_uncertainty.losses import ELBOLoss -from torch_uncertainty.models.lenet import bayesian_lenet -from torch_uncertainty.routines.classification import ClassificationSingle +- the datamodule that handles dataloaders: MNISTDataModule from torch_uncertainty.datamodules +We will also need to define an optimizer using torch.optim, the +neural network utils from torch.nn, as well as the partial util to provide +the modified default arguments for the ELBO loss. +""" # %% -# We will also need to define an optimizer using torch.optim as well as the -# neural network utils withing torch.nn, as well as the partial util to provide -# the modified default arguments for the ELBO loss. -# -# We also import sys to override the command line arguments. - -import os from functools import partial from pathlib import Path -import sys +from lightning.pytorch import Trainer from torch import nn, optim +from torch_uncertainty.datamodules import MNISTDataModule +from torch_uncertainty.losses import ELBOLoss +from torch_uncertainty.models.lenet import bayesian_lenet +from torch_uncertainty.routines import ClassificationRoutine + # %% -# 2. Creating the Optimizer Wrapper -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# 2. The Optimization Recipe +# ~~~~~~~~~~~~~~~~~~~~~~~~~~ # We will use the Adam optimizer with the default learning rate of 0.001. @@ -69,26 +67,19 @@ def optim_lenet(model: nn.Module) -> dict: # 3. Creating the necessary variables # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -# In the following, we will need to define the root of the datasets and the -# logs, and to fake-parse the arguments needed for using the PyTorch Lightning -# Trainer. We also create the datamodule that handles the MNIST dataset, -# dataloaders and transforms. Finally, we create the model using the -# blueprint from torch_uncertainty.models. - -root = Path(os.path.abspath("")) +# In the following, we define the Lightning trainer, the root of the datasets and the logs. +# We also create the datamodule that handles the MNIST dataset, dataloaders and transforms. +# Please note that the datamodules can also handle OOD detection by setting the eval_ood +# parameter to True. Finally, we create the model using the blueprint from torch_uncertainty.models. -# We mock the arguments for the trainer -sys.argv = ["file.py", "--max_epochs", "1", "--enable_progress_bar", "False"] -args = init_args(datamodule=MNISTDataModule) - -net_name = "logs/bayesian-lenet-mnist" +trainer = Trainer(accelerator="cpu", enable_progress_bar=False, max_epochs=1) # datamodule -args.root = str(root / "data") -dm = MNISTDataModule(**vars(args)) +root = Path("") / "data" +datamodule = MNISTDataModule(root = root, batch_size=128, eval_ood=False) # model -model = bayesian_lenet(dm.num_channels, dm.num_classes) +model = bayesian_lenet(datamodule.num_channels, datamodule.num_classes) # %% # 4. The Loss and the Training Routine @@ -99,8 +90,8 @@ def optim_lenet(model: nn.Module) -> dict: # library. As we are train a classification model, we use the CrossEntropyLoss # as the likelihood. # We then define the training routine using the classification training routine -# from torch_uncertainty.training.classification. We provide the model, the ELBO -# loss and the optimizer, as well as all the default arguments. +# from torch_uncertainty.classification. We provide the model, the ELBO +# loss and the optimizer to the routine. loss = partial( ELBOLoss, @@ -110,13 +101,11 @@ def optim_lenet(model: nn.Module) -> dict: num_samples=3, ) -baseline = ClassificationSingle( +routine = ClassificationRoutine( model=model, - num_classes=dm.num_classes, - in_channels=dm.num_channels, + num_classes=datamodule.num_classes, loss=loss, optimization_procedure=optim_lenet, - **vars(args), ) # %% @@ -124,14 +113,14 @@ def optim_lenet(model: nn.Module) -> dict: # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # Now that we have prepared all of this, we just have to gather everything in -# the main function and to train the model using the PyTorch Lightning Trainer. -# Specifically, it needs the baseline, that includes the model as well as the -# training routine, the datamodule, the root for the datasets and the logs, the -# name of the model for the logs and all the training arguments. +# the main function and to train the model using the Lightning Trainer. +# Specifically, it needs the routine, that includes the model as well as the +# training/eval logic and the datamodule # The dataset will be downloaded automatically in the root/data folder, and the # logs will be saved in the root/logs folder. -results = cli_main(baseline, dm, root, net_name, args) +trainer.fit(model=routine, datamodule=datamodule) +trainer.test(model=routine, datamodule=datamodule) # %% # 6. Testing the Model @@ -140,19 +129,20 @@ def optim_lenet(model: nn.Module) -> dict: # Now that the model is trained, let's test it on MNIST import matplotlib.pyplot as plt +import numpy as np import torch import torchvision -import numpy as np - def imshow(img): npimg = img.numpy() plt.imshow(np.transpose(npimg, (1, 2, 0))) + plt.axis("off") + plt.tight_layout() plt.show() -dataiter = iter(dm.val_dataloader()) +dataiter = iter(datamodule.val_dataloader()) images, labels = next(dataiter) # print images diff --git a/auto_tutorials_source/tutorial_evidential_classification.py b/auto_tutorials_source/tutorial_evidential_classification.py index e7e72bbf..5a551da8 100644 --- a/auto_tutorials_source/tutorial_evidential_classification.py +++ b/auto_tutorials_source/tutorial_evidential_classification.py @@ -16,36 +16,28 @@ To train a LeNet with the DEC loss function using TorchUncertainty, we have to load the following utilities from TorchUncertainty: -- the cli handler: cli_main and argument parser: init_args +- the Trainer from Lightning - the model: LeNet, which lies in torch_uncertainty.models -- the classification training routine in the torch_uncertainty.training.classification module -- the evidential objective: the DECLoss, which lies in the torch_uncertainty.losses file -- the datamodule that handles dataloaders: MNISTDataModule, which lies in the torch_uncertainty.datamodule -""" - -# %% -from torch_uncertainty import cli_main, init_args -from torch_uncertainty.models.lenet import lenet -from torch_uncertainty.routines.classification import ClassificationSingle -from torch_uncertainty.losses import DECLoss -from torch_uncertainty.datamodules import MNISTDataModule +- the classification training routine in the torch_uncertainty.routines +- the evidential objective: the DECLoss from torch_uncertainty.losses +- the datamodule that handles dataloaders & transforms: MNISTDataModule from torch_uncertainty.datamodules +We also need to define an optimizer using torch.optim, the neural network utils within torch.nn, as well as the partial util to provide +the modified default arguments for the DEC loss. +""" # %% -# We also need to define an optimizer using torch.optim as well as the -# neural network utils withing torch.nn, as well as the partial util to provide -# the modified default arguments for the DEC loss. -# -# We also import sys to override the command line arguments. - -import os -import sys from functools import partial from pathlib import Path import torch +from lightning.pytorch import Trainer from torch import nn, optim +from torch_uncertainty.datamodules import MNISTDataModule +from torch_uncertainty.losses import DECLoss +from torch_uncertainty.models.lenet import lenet +from torch_uncertainty.routines import ClassificationRoutine # %% # 2. Creating the Optimizer Wrapper @@ -59,7 +51,6 @@ def optim_lenet(model: nn.Module) -> dict: ) return {"optimizer": optimizer, "lr_scheduler": exp_lr_scheduler} - # %% # 3. Creating the necessary variables # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -68,22 +59,15 @@ def optim_lenet(model: nn.Module) -> dict: # fake-parse the arguments needed for using the PyTorch Lightning Trainer. We # also use the same MNIST classification example as that used in the # original DEC paper. We only train for 5 epochs for the sake of time. -root = Path(os.path.abspath("")) - -# We mock the arguments for the trainer. Replace with 25 epochs on your machine. -sys.argv = ["file.py", "--max_epochs", "5", "--enable_progress_bar", "True"] -args = init_args(datamodule=MNISTDataModule) - -net_name = "logs/dec-lenet-mnist" +trainer = Trainer(accelerator="cpu", max_epochs=5, enable_progress_bar=False) # datamodule -args.root = str(root / "data") -dm = MNISTDataModule(**vars(args)) - +root = Path() / "data" +datamodule = MNISTDataModule(root=root, batch_size=128) model = lenet( - in_channels=dm.num_channels, - num_classes=dm.num_classes, + in_channels=datamodule.num_channels, + num_classes=datamodule.num_classes, ) # %% @@ -102,20 +86,19 @@ def optim_lenet(model: nn.Module) -> dict: reg_weight=1e-2, ) -baseline = ClassificationSingle( +routine = ClassificationRoutine( model=model, - num_classes=dm.num_classes, - in_channels=dm.num_channels, + num_classes=datamodule.num_classes, loss=loss, optimization_procedure=optim_lenet, - **vars(args), ) # %% # 5. Gathering Everything and Training the Model # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -results = cli_main(baseline, dm, root, net_name, args) +trainer.fit(model=routine, datamodule=datamodule) +trainer.test(model=routine, datamodule=datamodule) # %% # 6. Testing the Model @@ -123,12 +106,10 @@ def optim_lenet(model: nn.Module) -> dict: # Now that the model is trained, let's test it on MNIST. import matplotlib.pyplot as plt -import torch +import numpy as np import torchvision import torchvision.transforms.functional as F -import numpy as np - def imshow(img) -> None: npimg = img.numpy() @@ -148,7 +129,7 @@ def rotated_mnist(angle: int) -> None: imshow(torchvision.utils.make_grid(rotated_images[:4, ...])) print("Ground truth: ", " ".join(f"{labels[j]}" for j in range(4))) - evidence = baseline(rotated_images) + evidence = routine(rotated_images) alpha = torch.relu(evidence) + 1 strength = torch.sum(alpha, dim=1, keepdim=True) probs = alpha / strength @@ -161,16 +142,15 @@ def rotated_mnist(angle: int) -> None: ) -dataiter = iter(dm.val_dataloader()) +dataiter = iter(datamodule.val_dataloader()) images, labels = next(dataiter) with torch.no_grad(): - baseline.eval() + routine.eval() rotated_mnist(0) rotated_mnist(45) rotated_mnist(90) - # %% # References # ---------- diff --git a/auto_tutorials_source/tutorial_mc_batch_norm.py b/auto_tutorials_source/tutorial_mc_batch_norm.py index d4a9e2bd..fec45261 100644 --- a/auto_tutorials_source/tutorial_mc_batch_norm.py +++ b/auto_tutorials_source/tutorial_mc_batch_norm.py @@ -13,93 +13,72 @@ First, we have to load the following utilities from TorchUncertainty: -- the cli handler: cli_main and argument parser: init_args +- the Trainer from Lightning - the datamodule that handles dataloaders: MNISTDataModule, which lies in the torch_uncertainty.datamodule - the model: LeNet, which lies in torch_uncertainty.models - the mc-batch-norm wrapper: mc_dropout, which lies in torch_uncertainty.models - a resnet baseline to get the command line arguments: ResNet, which lies in torch_uncertainty.baselines - the classification training routine in the torch_uncertainty.training.classification module - the optimizer wrapper in the torch_uncertainty.optimization_procedures module. -""" -# %% -from torch_uncertainty import cli_main, init_args -from torch_uncertainty.datamodules import MNISTDataModule -from torch_uncertainty.models.lenet import lenet -from torch_uncertainty.post_processing.mc_batch_norm import MCBatchNorm -from torch_uncertainty.baselines.classification import ResNet -from torch_uncertainty.routines.classification import ClassificationSingle -from torch_uncertainty.optimization_procedures import optim_cifar10_resnet18 +We also need import the neural network utils withing `torch.nn`. +""" # %% -# We will also need import the neural network utils withing `torch.nn`. -# -# We also import sys to override the command line arguments. - -import os -import sys from pathlib import Path +from lightning import Trainer from torch import nn +from torch_uncertainty.datamodules import MNISTDataModule +from torch_uncertainty.models.lenet import lenet +from torch_uncertainty.optimization_procedures import optim_cifar10_resnet18 +from torch_uncertainty.post_processing.mc_batch_norm import MCBatchNorm +from torch_uncertainty.routines import ClassificationRoutine + # %% # 2. Creating the necessary variables # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -# -# In the following, we will need to define the root of the datasets and the -# logs, and to fake-parse the arguments needed for using the PyTorch Lightning -# Trainer. We also create the datamodule that handles the MNIST dataset, -# dataloaders and transforms. We create the model using the -# blueprint from torch_uncertainty.models and we wrap it into mc-dropout. -# -# It is important to specify the arguments ``version`` as ``mc-dropout``, -# ``num_estimators`` and the ``dropout_rate`` to use Monte Carlo dropout. - -root = Path(os.path.abspath("")) - -# We mock the arguments for the trainer -sys.argv = ["file.py", "--enable_progress_bar", "False", "--num_estimators", "8", "--max_epochs", "2"] -args = init_args(network=ResNet, datamodule=MNISTDataModule) +# In the following, we define the root of the datasets and the +# logs. We also create the datamodule that handles the MNIST dataset +# dataloaders and transforms. -net_name = "logs/lenet-mnist" +trainer = Trainer(accelerator="cpu", max_epochs=2, enable_progress_bar=False) # datamodule -args.root = str(root / "data") -dm = MNISTDataModule(**vars(args)) +root = Path("") / "data" +datamodule = MNISTDataModule(root, batch_size=128) model = lenet( - in_channels=dm.num_channels, - num_classes=dm.num_classes, + in_channels=datamodule.num_channels, + num_classes=datamodule.num_classes, norm = nn.BatchNorm2d, ) # %% # 3. The Loss and the Training Routine # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -# This is a classification problem, and we use CrossEntropyLoss as the likelihood. +# This is a classification problem, and we use CrossEntropyLoss as likelihood. # We define the training routine using the classification training routine from -# torch_uncertainty.training.classification. We provide the number of classes -# and channels, the optimizer wrapper, the dropout rate, and the number of -# forward passes to perform through the network, as well as all the default -# arguments. +# torch_uncertainty.training.classification. We provide the number of classes, +# and the optimization recipe. -baseline = ClassificationSingle( - num_classes=dm.num_classes, +routine = ClassificationRoutine( + num_classes=datamodule.num_classes, model=model, loss=nn.CrossEntropyLoss, optimization_procedure=optim_cifar10_resnet18, - **vars(args), ) # %% -# 5. Gathering Everything and Training the Model +# 4. Gathering Everything and Training the Model # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -results = cli_main(baseline, dm, root, net_name, args) - +trainer.fit(model=routine, datamodule=datamodule) +trainer.test(model=routine, datamodule=datamodule) # %% -# 6. Wrapping the Model in a MCBatchNorm +# 5. Wrapping the Model in a MCBatchNorm # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # We can now wrap the model in a MCBatchNorm to add stochasticity to the # predictions. We specify that the BatchNorm layers are to be converted to @@ -109,12 +88,12 @@ # The authors suggest 32 as a good value for ``mc_batch_size`` but we use 4 here # to highlight the effect of stochasticity on the predictions. -baseline.model = MCBatchNorm(baseline.model, num_estimators=8, convert=True, mc_batch_size=32) -baseline.model.fit(dm.train) -baseline.eval() +routine.model = MCBatchNorm(routine.model, num_estimators=8, convert=True, mc_batch_size=4) +routine.model.fit(datamodule.train) +routine.eval() # %% -# 7. Testing the Model +# 6. Testing the Model # ~~~~~~~~~~~~~~~~~~~~ # Now that the model is trained, let's test it on MNIST. Don't forget to call # .eval() to enable Monte Carlo batch normalization at inference. @@ -122,27 +101,28 @@ # the variance of the predictions is the highest. import matplotlib.pyplot as plt +import numpy as np import torch import torchvision -import numpy as np - def imshow(img): npimg = img.numpy() plt.imshow(np.transpose(npimg, (1, 2, 0))) + plt.axis("off") + plt.tight_layout() plt.show() -dataiter = iter(dm.val_dataloader()) +dataiter = iter(datamodule.val_dataloader()) images, labels = next(dataiter) # print images imshow(torchvision.utils.make_grid(images[:4, ...])) print("Ground truth: ", " ".join(f"{labels[j]}" for j in range(4))) -baseline.eval() -logits = baseline(images).reshape(8, 128, 10) +routine.eval() +logits = routine(images).reshape(8, 128, 10) probs = torch.nn.functional.softmax(logits, dim=-1) diff --git a/auto_tutorials_source/tutorial_mc_dropout.py b/auto_tutorials_source/tutorial_mc_dropout.py index 62915574..e3bf59c0 100644 --- a/auto_tutorials_source/tutorial_mc_dropout.py +++ b/auto_tutorials_source/tutorial_mc_dropout.py @@ -19,33 +19,27 @@ First, we have to load the following utilities from TorchUncertainty: -- the cli handler: cli_main and argument parser: init_args +- the Trainer from Lightning - the datamodule that handles dataloaders: MNISTDataModule, which lies in the torch_uncertainty.datamodule - the model: LeNet, which lies in torch_uncertainty.models - the mc-dropout wrapper: mc_dropout, which lies in torch_uncertainty.models - a resnet baseline to get the command line arguments: ResNet, which lies in torch_uncertainty.baselines - the classification training routine in the torch_uncertainty.training.classification module - the optimizer wrapper in the torch_uncertainty.optimization_procedures module. + +We also need import the neural network utils within `torch.nn`. """ # %% -from torch_uncertainty import cli_main, init_args +from pathlib import Path + +from lightning.pytorch import Trainer +from torch import nn + from torch_uncertainty.datamodules import MNISTDataModule from torch_uncertainty.models.lenet import lenet from torch_uncertainty.models.mc_dropout import mc_dropout -from torch_uncertainty.baselines.classification import ResNet -from torch_uncertainty.routines.classification import ClassificationEnsemble from torch_uncertainty.optimization_procedures import optim_cifar10_resnet18 - -# %% -# We will also need import the neural network utils withing `torch.nn`. -# -# We also import sys to override the command line arguments. - -import os -import sys -from pathlib import Path - -from torch import nn +from torch_uncertainty.routines import ClassificationRoutine # %% # 2. Creating the necessary variables @@ -60,26 +54,20 @@ # It is important to specify the arguments ``version`` as ``mc-dropout``, # ``num_estimators`` and the ``dropout_rate`` to use Monte Carlo dropout. -root = Path(os.path.abspath("")) - -# We mock the arguments for the trainer -sys.argv = ["file.py", "--enable_progress_bar", "False", "--dropout_rate", "0.6", "--num_estimators", "16", "--max_epochs", "2"] -args = init_args(network=ResNet, datamodule=MNISTDataModule) - -net_name = "logs/mc-dropout-lenet-mnist" +trainer = Trainer(accelerator="cpu", max_epochs=2, enable_progress_bar=False) # datamodule -args.root = str(root / "data") -dm = MNISTDataModule(**vars(args)) +root = Path("") / "data" +datamodule = MNISTDataModule(root=root, batch_size=128) model = lenet( - in_channels=dm.num_channels, - num_classes=dm.num_classes, - dropout_rate=args.dropout_rate, + in_channels=datamodule.num_channels, + num_classes=datamodule.num_classes, + dropout_rate=0.6, ) -mc_model = mc_dropout(model, num_estimators=args.num_estimators, last_layer=0.0) +mc_model = mc_dropout(model, num_estimators=16, last_layer=False) # %% # 3. The Loss and the Training Routine @@ -91,48 +79,51 @@ # forward passes to perform through the network, as well as all the default # arguments. -baseline = ClassificationEnsemble( - num_classes=dm.num_classes, +routine = ClassificationRoutine( + num_classes=datamodule.num_classes, model=mc_model, loss=nn.CrossEntropyLoss, optimization_procedure=optim_cifar10_resnet18, - **vars(args), + num_estimators=16, + ) # %% -# 5. Gathering Everything and Training the Model +# 4. Gathering Everything and Training the Model # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -results = cli_main(baseline, dm, root, net_name, args) +trainer.fit(model=routine, datamodule=datamodule) +trainer.test(model=routine, datamodule=datamodule) # %% -# 6. Testing the Model +# 5. Testing the Model # ~~~~~~~~~~~~~~~~~~~~ # Now that the model is trained, let's test it on MNIST. Don't forget to call # .eval() to enable dropout at inference. import matplotlib.pyplot as plt +import numpy as np import torch import torchvision -import numpy as np - def imshow(img): npimg = img.numpy() plt.imshow(np.transpose(npimg, (1, 2, 0))) + plt.axis("off") + plt.tight_layout() plt.show() -dataiter = iter(dm.val_dataloader()) +dataiter = iter(datamodule.val_dataloader()) images, labels = next(dataiter) # print images imshow(torchvision.utils.make_grid(images[:4, ...])) print("Ground truth: ", " ".join(f"{labels[j]}" for j in range(4))) -baseline.eval() -logits = baseline(images).reshape(16, 128, 10) +routine.eval() +logits = routine(images).reshape(16, 128, 10) probs = torch.nn.functional.softmax(logits, dim=-1) @@ -140,7 +131,7 @@ def imshow(img): for j in range(4): values, predicted = torch.max(probs[:, j], 1) print( - f"Predicted digits for the image {j}: ", + f"Predicted digits for the image {j+1}: ", " ".join([str(image_id.item()) for image_id in predicted]), ) From 856bed6484ba242f90a74a8f39494deccc526441 Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 18 Mar 2024 12:02:12 +0100 Subject: [PATCH 051/148] :white_check_mark: Improve regression coverage --- tests/_dummies/baseline.py | 12 ++++++++++-- tests/layers/test_distributions.py | 16 ++++++++++++++++ tests/routines/test_regression.py | 20 +++++++++++++------- 3 files changed, 39 insertions(+), 9 deletions(-) create mode 100644 tests/layers/test_distributions.py diff --git a/tests/_dummies/baseline.py b/tests/_dummies/baseline.py index f556bdd6..f52dfe87 100644 --- a/tests/_dummies/baseline.py +++ b/tests/_dummies/baseline.py @@ -3,7 +3,10 @@ from pytorch_lightning import LightningModule from torch import nn -from torch_uncertainty.layers.distributions import IndptNormalLayer +from torch_uncertainty.layers.distributions import ( + IndptLaplaceLayer, + IndptNormalLayer, +) from torch_uncertainty.models.deep_ensembles import deep_ensembles from torch_uncertainty.routines import ClassificationRoutine, RegressionRoutine from torch_uncertainty.transforms import RepeatTarget @@ -58,12 +61,17 @@ def __new__( loss: type[nn.Module], baseline_type: str = "single", optimization_procedure=None, + dist_type: str = "normal", ) -> LightningModule: model = dummy_model( in_channels=in_features, num_classes=num_outputs * 2 if probabilistic else num_outputs, num_estimators=1, - last_layer=IndptNormalLayer(num_outputs) + last_layer=( + IndptNormalLayer(num_outputs) + if dist_type == "normal" + else IndptLaplaceLayer(num_outputs) + ) if probabilistic else nn.Identity(), ) diff --git a/tests/layers/test_distributions.py b/tests/layers/test_distributions.py new file mode 100644 index 00000000..46e5e0b4 --- /dev/null +++ b/tests/layers/test_distributions.py @@ -0,0 +1,16 @@ +import pytest + +from torch_uncertainty.layers.distributions import ( + IndptLaplaceLayer, + IndptNormalLayer, +) + + +class TestDistributions: + def test_errors(self): + with pytest.raises(ValueError): + IndptNormalLayer(-1, 1) + with pytest.raises(ValueError): + IndptNormalLayer(1, -1) + with pytest.raises(ValueError): + IndptLaplaceLayer(1, -1) diff --git a/tests/routines/test_regression.py b/tests/routines/test_regression.py index 613335b2..1ef3387d 100644 --- a/tests/routines/test_regression.py +++ b/tests/routines/test_regression.py @@ -29,8 +29,8 @@ def test_one_estimator_one_output(self): ) trainer.fit(model, dm) + trainer.validate(model, dm) trainer.test(model, dm) - model(dm.get_test_set()[0][0]) model = DummyRegressionBaseline( @@ -43,7 +43,9 @@ def test_one_estimator_one_output(self): ) trainer.fit(model, dm) + trainer.validate(model, dm) trainer.test(model, dm) + model(dm.get_test_set()[0][0]) def test_one_estimator_two_outputs(self): trainer = Trainer(accelerator="cpu", fast_dev_run=True) @@ -58,10 +60,11 @@ def test_one_estimator_two_outputs(self): loss=DistributionNLL, optimization_procedure=optim_cifar10_resnet18, baseline_type="single", + dist_type="laplace", ) - trainer.fit(model, dm) trainer.test(model, dm) + model(dm.get_test_set()[0][0]) model = DummyRegressionBaseline( probabilistic=False, @@ -71,9 +74,9 @@ def test_one_estimator_two_outputs(self): optimization_procedure=optim_cifar10_resnet18, baseline_type="single", ) - trainer.fit(model, dm) trainer.test(model, dm) + model(dm.get_test_set()[0][0]) def test_two_estimators_one_output(self): trainer = Trainer(accelerator="cpu", fast_dev_run=True) @@ -88,10 +91,11 @@ def test_two_estimators_one_output(self): loss=DistributionNLL, optimization_procedure=optim_cifar10_resnet18, baseline_type="ensemble", + dist_type="laplace", ) - trainer.fit(model, dm) trainer.test(model, dm) + model(dm.get_test_set()[0][0]) model = DummyRegressionBaseline( probabilistic=False, @@ -101,9 +105,9 @@ def test_two_estimators_one_output(self): optimization_procedure=optim_cifar10_resnet18, baseline_type="ensemble", ) - trainer.fit(model, dm) trainer.test(model, dm) + model(dm.get_test_set()[0][0]) def test_two_estimators_two_outputs(self): trainer = Trainer(accelerator="cpu", fast_dev_run=True) @@ -119,9 +123,10 @@ def test_two_estimators_two_outputs(self): optimization_procedure=optim_cifar10_resnet18, baseline_type="ensemble", ) - trainer.fit(model, dm) + trainer.validate(model, dm) trainer.test(model, dm) + model(dm.get_test_set()[0][0]) model = DummyRegressionBaseline( probabilistic=False, @@ -131,9 +136,10 @@ def test_two_estimators_two_outputs(self): optimization_procedure=optim_cifar10_resnet18, baseline_type="ensemble", ) - trainer.fit(model, dm) + trainer.validate(model, dm) trainer.test(model, dm) + model(dm.get_test_set()[0][0]) def test_regression_failures(self): with pytest.raises(ValueError): From 466e6e1cdb287395503d4b6c2f46208433ecae40 Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 18 Mar 2024 12:11:05 +0100 Subject: [PATCH 052/148] :fire: Remove cli_test_helpers --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 80eb453f..8baf0296 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,6 @@ dev = [ "pytest-cov", "pre-commit", "pre-commit-hooks", - "cli-test-helpers", ] docs = [ "sphinx<7", From 5ddf1b12f7a6693716916c2eb5071fb748974806 Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 18 Mar 2024 12:27:46 +0100 Subject: [PATCH 053/148] :white_check_mark: First classification tests --- tests/_dummies/baseline.py | 3 + tests/_dummies/datamodule.py | 16 +---- tests/routines/test_classification.py | 88 ++++++++++++++++++++++++--- 3 files changed, 84 insertions(+), 23 deletions(-) diff --git a/tests/_dummies/baseline.py b/tests/_dummies/baseline.py index f52dfe87..fac8e15f 100644 --- a/tests/_dummies/baseline.py +++ b/tests/_dummies/baseline.py @@ -21,6 +21,7 @@ def __new__( in_channels: int, loss: type[nn.Module], baseline_type: str = "single", + optimization_procedure=None, with_feats: bool = True, with_linear: bool = True, ) -> LightningModule: @@ -39,6 +40,7 @@ def __new__( loss=loss, format_batch_fn=nn.Identity(), log_plots=True, + optimization_procedure=optimization_procedure, num_estimators=1, ) # baseline_type == "ensemble": @@ -46,6 +48,7 @@ def __new__( num_classes=num_classes, model=model, loss=loss, + optimization_procedure=optimization_procedure, format_batch_fn=RepeatTarget(2), log_plots=True, num_estimators=2, diff --git a/tests/_dummies/datamodule.py b/tests/_dummies/datamodule.py index 040f2e35..0b81ccc0 100644 --- a/tests/_dummies/datamodule.py +++ b/tests/_dummies/datamodule.py @@ -1,6 +1,4 @@ -from argparse import ArgumentParser from pathlib import Path -from typing import Any import numpy as np import torchvision.transforms as T @@ -20,17 +18,17 @@ class DummyClassificationDataModule(AbstractDataModule): def __init__( self, root: str | Path, - eval_ood: bool, batch_size: int, num_classes: int = 2, num_workers: int = 1, + eval_ood: bool = False, pin_memory: bool = True, persistent_workers: bool = True, num_images: int = 2, - **kwargs, ) -> None: super().__init__( root=root, + val_split=None, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, @@ -98,16 +96,6 @@ def _get_train_data(self) -> ArrayLike: def _get_train_targets(self) -> ArrayLike: return np.array(self.train.targets) - @classmethod - def add_argparse_args( - cls, - parent_parser: ArgumentParser, - **kwargs: Any, - ) -> ArgumentParser: - p = super().add_argparse_args(parent_parser) - p.add_argument("--eval-ood", action="store_true") - return parent_parser - class DummyRegressionDataModule(AbstractDataModule): in_features = 4 diff --git a/tests/routines/test_classification.py b/tests/routines/test_classification.py index a62f1fdf..f200ceba 100644 --- a/tests/routines/test_classification.py +++ b/tests/routines/test_classification.py @@ -1,3 +1,82 @@ +from pathlib import Path + +from lightning import Trainer +from torch import nn + +from tests._dummies import ( + DummyClassificationBaseline, + DummyClassificationDataModule, +) +from torch_uncertainty.optimization_procedures import optim_cifar10_resnet18 + + +class TestClassificationSingle: + """Testing the classification routine with a single model.""" + + def test_one_estimator_binary(self): + trainer = Trainer(accelerator="cpu", fast_dev_run=True) + + dm = DummyClassificationDataModule( + root=Path(), + batch_size=16, + num_classes=1, + num_images=100, + ) + model = DummyClassificationBaseline( + in_channels=dm.num_channels, + num_classes=dm.num_classes, + loss=nn.BCEWithLogitsLoss, + optimization_procedure=optim_cifar10_resnet18, + baseline_type="single", + ) + + trainer.fit(model, dm) + trainer.validate(model, dm) + trainer.test(model, dm) + model(dm.get_test_set()[0][0]) + + def test_one_estimator_two_classes(self): + trainer = Trainer(accelerator="cpu", fast_dev_run=True) + + dm = DummyClassificationDataModule( + root=Path(), + batch_size=16, + num_classes=2, + num_images=100, + ) + model = DummyClassificationBaseline( + num_classes=dm.num_classes, + in_channels=dm.num_channels, + loss=nn.CrossEntropyLoss, + optimization_procedure=optim_cifar10_resnet18, + baseline_type="single", + ) + + trainer.fit(model, dm) + trainer.validate(model, dm) + trainer.test(model, dm) + model(dm.get_test_set()[0][0]) + + +# def test_two_estimators_binary(self): +# trainer = Trainer(accelerator="cpu", fast_dev_run=True) + +# dm = DummyClassificationDataModule( +# root=Path(""), batch_size=16, num_classes=1, num_images=100, +# ) +# model = DummyClassificationBaseline( +# in_channels=dm.num_channels, +# num_classes=dm.num_classes, +# loss=nn.BCEWithLogitsLoss, +# optimization_procedure=optim_cifar10_resnet18, +# baseline_type="ensemble", +# ) + +# trainer.fit(model, dm) +# trainer.validate(model, dm) +# trainer.test(model, dm) +# model(dm.get_test_set()[0][0]) + # from functools import partial # from pathlib import Path @@ -13,15 +92,6 @@ # ) # from torch_uncertainty import cli_main, init_args # from torch_uncertainty.losses import DECLoss, ELBOLoss -# from torch_uncertainty.optimization_procedures import optim_cifar10_resnet18 -# from torch_uncertainty.routines.classification import ( -# ClassificationEnsemble, -# ClassificationSingle, -# ) - - -# class TestClassificationSingle: -# """Testing the classification routine with a single model.""" # def test_cli_main_dummy_binary(self): # root = Path(__file__).parent.absolute().parents[0] From a801821bf22db01b9ee39b3724c218fd4c8cb0a8 Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 18 Mar 2024 14:36:11 +0100 Subject: [PATCH 054/148] :white_check_mark: Improve cls coverage --- tests/_dummies/baseline.py | 11 +- tests/routines/test_classification.py | 279 ++++++------------- torch_uncertainty/routines/classification.py | 4 +- 3 files changed, 90 insertions(+), 204 deletions(-) diff --git a/tests/_dummies/baseline.py b/tests/_dummies/baseline.py index fac8e15f..f2fce4b0 100644 --- a/tests/_dummies/baseline.py +++ b/tests/_dummies/baseline.py @@ -20,20 +20,21 @@ def __new__( num_classes: int, in_channels: int, loss: type[nn.Module], - baseline_type: str = "single", + ensemble=False, optimization_procedure=None, with_feats: bool = True, with_linear: bool = True, + ood_criterion: str = "msp", ) -> LightningModule: model = dummy_model( in_channels=in_channels, num_classes=num_classes, - num_estimators=1 + int(baseline_type == "ensemble"), + num_estimators=1 + int(ensemble), with_feats=with_feats, with_linear=with_linear, ) - if baseline_type == "single": + if not ensemble: return ClassificationRoutine( num_classes=num_classes, model=model, @@ -42,8 +43,9 @@ def __new__( log_plots=True, optimization_procedure=optimization_procedure, num_estimators=1, + ood_criterion=ood_criterion, ) - # baseline_type == "ensemble": + # ensemble return ClassificationRoutine( num_classes=num_classes, model=model, @@ -52,6 +54,7 @@ def __new__( format_batch_fn=RepeatTarget(2), log_plots=True, num_estimators=2, + ood_criterion=ood_criterion, ) diff --git a/tests/routines/test_classification.py b/tests/routines/test_classification.py index f200ceba..849ad938 100644 --- a/tests/routines/test_classification.py +++ b/tests/routines/test_classification.py @@ -1,13 +1,16 @@ from pathlib import Path +import pytest from lightning import Trainer from torch import nn from tests._dummies import ( DummyClassificationBaseline, DummyClassificationDataModule, + dummy_model, ) from torch_uncertainty.optimization_procedures import optim_cifar10_resnet18 +from torch_uncertainty.routines import ClassificationRoutine class TestClassificationSingle: @@ -27,7 +30,31 @@ def test_one_estimator_binary(self): num_classes=dm.num_classes, loss=nn.BCEWithLogitsLoss, optimization_procedure=optim_cifar10_resnet18, - baseline_type="single", + ensemble=False, + ood_criterion="msp", + ) + + trainer.fit(model, dm) + trainer.validate(model, dm) + trainer.test(model, dm) + model(dm.get_test_set()[0][0]) + + def test_two_estimators_binary(self): + trainer = Trainer(accelerator="cpu", fast_dev_run=True) + + dm = DummyClassificationDataModule( + root=Path(), + batch_size=16, + num_classes=1, + num_images=100, + ) + model = DummyClassificationBaseline( + in_channels=dm.num_channels, + num_classes=dm.num_classes, + loss=nn.BCEWithLogitsLoss, + optimization_procedure=optim_cifar10_resnet18, + ensemble=True, + ood_criterion="logit", ) trainer.fit(model, dm) @@ -38,6 +65,30 @@ def test_one_estimator_binary(self): def test_one_estimator_two_classes(self): trainer = Trainer(accelerator="cpu", fast_dev_run=True) + dm = DummyClassificationDataModule( + root=Path(), + batch_size=16, + num_classes=2, + num_images=100, + eval_ood=True, + ) + model = DummyClassificationBaseline( + num_classes=dm.num_classes, + in_channels=dm.num_channels, + loss=nn.CrossEntropyLoss, + optimization_procedure=optim_cifar10_resnet18, + ensemble=False, + ood_criterion="entropy", + ) + + trainer.fit(model, dm) + trainer.validate(model, dm) + trainer.test(model, dm) + model(dm.get_test_set()[0][0]) + + def test_two_estimators_two_classes(self): + trainer = Trainer(accelerator="cpu", fast_dev_run=True) + dm = DummyClassificationDataModule( root=Path(), batch_size=16, @@ -49,7 +100,8 @@ def test_one_estimator_two_classes(self): in_channels=dm.num_channels, loss=nn.CrossEntropyLoss, optimization_procedure=optim_cifar10_resnet18, - baseline_type="single", + ensemble=True, + ood_criterion="energy", ) trainer.fit(model, dm) @@ -57,25 +109,31 @@ def test_one_estimator_two_classes(self): trainer.test(model, dm) model(dm.get_test_set()[0][0]) + def test_classification_failures(self): + with pytest.raises(ValueError): + ClassificationRoutine(10, nn.Module(), None, ood_criterion="other") -# def test_two_estimators_binary(self): -# trainer = Trainer(accelerator="cpu", fast_dev_run=True) + with pytest.raises(ValueError): + ClassificationRoutine(10, nn.Module(), None, cutmix_alpha=-1) -# dm = DummyClassificationDataModule( -# root=Path(""), batch_size=16, num_classes=1, num_images=100, -# ) -# model = DummyClassificationBaseline( -# in_channels=dm.num_channels, -# num_classes=dm.num_classes, -# loss=nn.BCEWithLogitsLoss, -# optimization_procedure=optim_cifar10_resnet18, -# baseline_type="ensemble", -# ) + with pytest.raises(ValueError): + ClassificationRoutine( + 10, nn.Module(), None, eval_grouping_loss=True + ) + + with pytest.raises(NotImplementedError): + ClassificationRoutine( + 10, nn.Module(), None, 2, eval_grouping_loss=True + ) + + model = dummy_model(1, 1, 1, 0, with_feats=False, with_linear=True) + with pytest.raises(ValueError): + ClassificationRoutine(10, model, None, eval_grouping_loss=True) + + model = dummy_model(1, 1, 1, 0, with_feats=True, with_linear=False) + with pytest.raises(ValueError): + ClassificationRoutine(10, model, None, eval_grouping_loss=True) -# trainer.fit(model, dm) -# trainer.validate(model, dm) -# trainer.test(model, dm) -# model(dm.get_test_set()[0][0]) # from functools import partial # from pathlib import Path @@ -93,88 +151,6 @@ def test_one_estimator_two_classes(self): # from torch_uncertainty import cli_main, init_args # from torch_uncertainty.losses import DECLoss, ELBOLoss -# def test_cli_main_dummy_binary(self): -# root = Path(__file__).parent.absolute().parents[0] -# with ArgvContext("file.py"): -# args = init_args(DummyClassificationBaseline, DummyClassificationDataModule) - -# args.root = str(root / "data") -# args.eval_grouping_loss = True -# dm = DummyClassificationDataModule( -# num_classes=1, num_images=100, **vars(args) -# ) - -# model = DummyClassificationBaseline( -# num_classes=dm.num_classes, -# in_channels=dm.num_channels, -# loss=nn.BCEWithLogitsLoss, -# optimization_procedure=optim_cifar10_resnet18, -# baseline_type="single", -# **vars(args), -# ) -# cli_main(model, dm, root, "logs/dummy", args) - -# with ArgvContext("file.py", "--logits"): -# args = init_args(DummyClassificationBaseline, DummyClassificationDataModule) - -# args.root = str(root / "data") -# dm = DummyClassificationDataModule(num_classes=1, **vars(args)) - -# model = DummyClassificationBaseline( -# num_classes=dm.num_classes, -# in_channels=dm.num_channels, -# loss=nn.BCEWithLogitsLoss, -# optimization_procedure=optim_cifar10_resnet18, -# baseline_type="single", -# **vars(args), -# ) -# cli_main(model, dm, root, "logs/dummy", args) - -# def test_cli_main_dummy_ood(self): -# root = Path(__file__).parent.absolute().parents[0] -# with ArgvContext("file.py", "--fast_dev_run"): -# args = init_args(DummyClassificationBaseline, DummyClassificationDataModule) - -# args.root = str(root / "data") -# dm = DummyClassificationDataModule(**vars(args)) -# loss = partial( -# ELBOLoss, -# criterion=nn.CrossEntropyLoss(), -# kl_weight=1e-5, -# num_samples=2, -# ) -# model = DummyClassificationBaseline( -# num_classes=dm.num_classes, -# in_channels=dm.num_channels, -# loss=loss, -# optimization_procedure=optim_cifar10_resnet18, -# baseline_type="single", -# **vars(args), -# ) -# cli_main(model, dm, root, "logs/dummy", args) - -# with ArgvContext( -# "file.py", -# "--eval-ood", -# "--entropy", -# ): -# args = init_args( -# DummyClassificationBaseline, DummyClassificationDataModule -# ) - -# args.root = str(root / "data") -# dm = DummyClassificationDataModule(**vars(args)) - -# model = DummyClassificationBaseline( -# num_classes=dm.num_classes, -# in_channels=dm.num_channels, -# loss=DECLoss, -# optimization_procedure=optim_cifar10_resnet18, -# baseline_type="single", -# **vars(args), -# ) -# cli_main(model, dm, root, "logs/dummy", args) - # with ArgvContext( # "file.py", # "--eval-ood", @@ -196,7 +172,7 @@ def test_one_estimator_two_classes(self): # in_channels=dm.num_channels, # loss=DECLoss, # optimization_procedure=optim_cifar10_resnet18, -# baseline_type="single", +# ensemble=False, # **vars(args), # ) # with pytest.raises(NotImplementedError): @@ -242,7 +218,7 @@ def test_one_estimator_two_classes(self): # in_channels=list_dm[i].dm.num_channels, # loss=nn.CrossEntropyLoss, # optimization_procedure=optim_cifar10_resnet18, -# baseline_type="single", +# ensemble=False, # calibration_set=dm.get_val_set, # **vars(args), # ) @@ -291,67 +267,13 @@ def test_one_estimator_two_classes(self): # in_channels=list_dm[i].dm.num_channels, # loss=nn.CrossEntropyLoss, # optimization_procedure=optim_cifar10_resnet18, -# baseline_type="single", +# ensemble=False, # calibration_set=dm.get_val_set, # **vars(args), # ) # ) # cli_main(list_model, list_dm, root, "logs/dummy", args) - -# def test_classification_failures(self): -# with pytest.raises(ValueError): -# ClassificationRoutine( -# 10, nn.Module(), None, None, use_entropy=True, use_logits=True -# ) - -# with pytest.raises(ValueError): -# ClassificationRoutine(10, nn.Module(), None, None, cutmix_alpha=-1) - -# with pytest.raises(ValueError): -# ClassificationSingle( -# 10, nn.Module(), None, None, eval_grouping_loss=True -# ) - -# model = dummy_model(1, 1, 1, 0, with_feats=False, with_linear=True) - -# with pytest.raises(ValueError): -# ClassificationSingle(10, model, None, None, eval_grouping_loss=True) - -# model = dummy_model(1, 1, 1, 0, with_feats=True, with_linear=False) - -# with pytest.raises(ValueError): -# ClassificationSingle(10, model, None, None, eval_grouping_loss=True) - - -# class TestClassificationEnsemble: -# """Testing the classification routine with an ensemble model.""" - -# def test_cli_main_dummy_binary(self): -# root = Path(__file__).parent.absolute().parents[0] -# with ArgvContext("file.py"): -# args = init_args(DummyClassificationBaseline, DummyClassificationDataModule) - -# # datamodule -# args.root = str(root / "data") -# dm = DummyClassificationDataModule(num_classes=1, **vars(args)) -# loss = partial( -# ELBOLoss, -# criterion=nn.CrossEntropyLoss(), -# kl_weight=1e-5, -# num_samples=1, -# ) -# model = DummyClassificationBaseline( -# num_classes=dm.num_classes, -# in_channels=dm.num_channels, -# loss=loss, -# optimization_procedure=optim_cifar10_resnet18, -# baseline_type="ensemble", -# **vars(args), -# ) - -# cli_main(model, dm, root, "logs/dummy", args) - # with ArgvContext("file.py", "--mutual_information"): # args = init_args(DummyClassificationBaseline, DummyClassificationDataModule) @@ -364,51 +286,12 @@ def test_one_estimator_two_classes(self): # in_channels=dm.num_channels, # loss=nn.BCEWithLogitsLoss, # optimization_procedure=optim_cifar10_resnet18, -# baseline_type="ensemble", +# ensemble=True, # **vars(args), # ) # cli_main(model, dm, root, "logs/dummy", args) -# def test_cli_main_dummy_ood(self): -# root = Path(__file__).parent.absolute().parents[0] -# with ArgvContext("file.py", "--logits"): -# args = init_args(DummyClassificationBaseline, DummyClassificationDataModule) - -# # datamodule -# args.root = str(root / "data") -# dm = DummyClassificationDataModule(**vars(args)) - -# model = DummyClassificationBaseline( -# num_classes=dm.num_classes, -# in_channels=dm.num_channels, -# loss=nn.CrossEntropyLoss, -# optimization_procedure=optim_cifar10_resnet18, -# baseline_type="ensemble", -# **vars(args), -# ) - -# cli_main(model, dm, root, "logs/dummy", args) - -# with ArgvContext("file.py", "--eval-ood", "--entropy"): -# args = init_args( -# DummyClassificationBaseline, DummyClassificationDataModule -# ) - -# # datamodule -# args.root = str(root / "data") -# dm = DummyClassificationDataModule(**vars(args)) - -# model = DummyClassificationBaseline( -# num_classes=dm.num_classes, -# in_channels=dm.num_channels, -# loss=DECLoss, -# optimization_procedure=optim_cifar10_resnet18, -# baseline_type="ensemble", -# **vars(args), -# ) - -# cli_main(model, dm, root, "logs/dummy", args) # with ArgvContext("file.py", "--eval-ood", "--variation_ratio"): # args = init_args( @@ -424,7 +307,7 @@ def test_one_estimator_two_classes(self): # in_channels=dm.num_channels, # loss=nn.CrossEntropyLoss, # optimization_procedure=optim_cifar10_resnet18, -# baseline_type="ensemble", +# ensemble=True, # **vars(args), # ) diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index f4ff4655..eaaa6da3 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -132,7 +132,7 @@ def __init__( " model." ) - if num_estimators == 1 and eval_grouping_loss: + if num_estimators != 1 and eval_grouping_loss: raise NotImplementedError( "Groupng loss for ensembles is not yet implemented. Raise an issue if needed." ) @@ -443,7 +443,7 @@ def test_step( logits = rearrange(logits, "(n b) c -> b n c", n=self.num_estimators) if self.binary_cls: - probs_per_est = torch.sigmoid(logits).squeeze(-1) + probs_per_est = torch.sigmoid(logits) else: probs_per_est = F.softmax(logits, dim=-1) From ad857ccc5bf5bb39f624704179bf1a1d0cfe0532 Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 18 Mar 2024 14:55:06 +0100 Subject: [PATCH 055/148] :white_check_mark: slightly improve cls. cov. --- tests/_dummies/baseline.py | 6 ++++++ tests/routines/test_classification.py | 10 ++++++++++ torch_uncertainty/routines/classification.py | 2 +- 3 files changed, 17 insertions(+), 1 deletion(-) diff --git a/tests/_dummies/baseline.py b/tests/_dummies/baseline.py index f2fce4b0..b7172881 100644 --- a/tests/_dummies/baseline.py +++ b/tests/_dummies/baseline.py @@ -25,6 +25,8 @@ def __new__( with_feats: bool = True, with_linear: bool = True, ood_criterion: str = "msp", + eval_ood: bool = False, + eval_grouping_loss: bool = False, ) -> LightningModule: model = dummy_model( in_channels=in_channels, @@ -44,6 +46,8 @@ def __new__( optimization_procedure=optimization_procedure, num_estimators=1, ood_criterion=ood_criterion, + eval_ood=eval_ood, + eval_grouping_loss=eval_grouping_loss, ) # ensemble return ClassificationRoutine( @@ -55,6 +59,8 @@ def __new__( log_plots=True, num_estimators=2, ood_criterion=ood_criterion, + eval_ood=eval_ood, + eval_grouping_loss=eval_grouping_loss, ) diff --git a/tests/routines/test_classification.py b/tests/routines/test_classification.py index 849ad938..685b54eb 100644 --- a/tests/routines/test_classification.py +++ b/tests/routines/test_classification.py @@ -79,6 +79,7 @@ def test_one_estimator_two_classes(self): optimization_procedure=optim_cifar10_resnet18, ensemble=False, ood_criterion="entropy", + eval_ood=True, ) trainer.fit(model, dm) @@ -110,6 +111,15 @@ def test_two_estimators_two_classes(self): model(dm.get_test_set()[0][0]) def test_classification_failures(self): + # num_estimators + with pytest.raises(ValueError): + ClassificationRoutine(10, nn.Module(), None, num_estimators=-1) + # single & MI + with pytest.raises(ValueError): + ClassificationRoutine( + 10, nn.Module(), None, num_estimators=1, ood_criterion="mi" + ) + with pytest.raises(ValueError): ClassificationRoutine(10, nn.Module(), None, ood_criterion="other") diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index eaaa6da3..6544d3a5 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -107,7 +107,7 @@ def __init__( if format_batch_fn is None: format_batch_fn = nn.Identity() - if not isinstance(num_estimators, int) and num_estimators < 1: + if not isinstance(num_estimators, int) or num_estimators < 1: raise ValueError( "The number of estimators must be a positive integer >= 1." f"Got {num_estimators}." From 3144fc4bbb4084fa4d40d97a92e0b40f9e9b4b46 Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 18 Mar 2024 15:10:14 +0100 Subject: [PATCH 056/148] :heavy_check_mark: Fix deep ensembles test --- tests/baselines/test_deep_ensembles.py | 17 ++++++++++------- .../baselines/classification/deep_ensembles.py | 1 + 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/tests/baselines/test_deep_ensembles.py b/tests/baselines/test_deep_ensembles.py index bd989cdc..e559fe58 100644 --- a/tests/baselines/test_deep_ensembles.py +++ b/tests/baselines/test_deep_ensembles.py @@ -1,3 +1,5 @@ +import pytest + from torch_uncertainty.baselines.classification.deep_ensembles import ( DeepEnsembles, ) @@ -6,10 +8,11 @@ class TestDeepEnsembles: """Testing the Deep Ensembles baseline class.""" - def test_standard(self): - DeepEnsembles( - log_path=".", - checkpoint_ids=[], - backbone="resnet", - num_classes=10, - ) + def test_failure(self): + with pytest.raises(ValueError): + DeepEnsembles( + log_path=".", + checkpoint_ids=[], + backbone="resnet", + num_classes=10, + ) diff --git a/torch_uncertainty/baselines/classification/deep_ensembles.py b/torch_uncertainty/baselines/classification/deep_ensembles.py index 4ffd8405..727cd696 100644 --- a/torch_uncertainty/baselines/classification/deep_ensembles.py +++ b/torch_uncertainty/baselines/classification/deep_ensembles.py @@ -54,6 +54,7 @@ def __init__( loss=None, num_estimators=de.num_estimators, eval_ood=eval_ood, + eval_grouping_loss=eval_grouping_loss, ood_criterion=ood_criterion, log_plots=log_plots, calibration_set=calibration_set, From 9d3aff0a813642f53e2a49cecbfba594001b1441 Mon Sep 17 00:00:00 2001 From: alafage Date: Mon, 18 Mar 2024 16:31:26 +0100 Subject: [PATCH 057/148] :bug: Fix fast_dev_run raising error --- experiments/classification/cifar10/resnet.py | 7 +++++-- experiments/classification/cifar10/vgg.py | 6 +++++- experiments/classification/cifar10/wideresnet.py | 6 +++++- 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/experiments/classification/cifar10/resnet.py b/experiments/classification/cifar10/resnet.py index d8cdb179..5610bcbe 100644 --- a/experiments/classification/cifar10/resnet.py +++ b/experiments/classification/cifar10/resnet.py @@ -20,6 +20,9 @@ def cli_main() -> ResNetCLI: if __name__ == "__main__": torch.set_float32_matmul_precision("medium") cli = cli_main() - - if cli.subcommand == "fit" and cli._get(cli.config, "eval_after_fit"): + if ( + (not cli.trainer.fast_dev_run) + and cli.subcommand == "fit" + and cli._get(cli.config, "eval_after_fit") + ): cli.trainer.test(datamodule=cli.datamodule, ckpt_path="best") diff --git a/experiments/classification/cifar10/vgg.py b/experiments/classification/cifar10/vgg.py index 606e3b36..90af58c4 100644 --- a/experiments/classification/cifar10/vgg.py +++ b/experiments/classification/cifar10/vgg.py @@ -20,5 +20,9 @@ def cli_main() -> ResNetCLI: if __name__ == "__main__": torch.set_float32_matmul_precision("medium") cli = cli_main() - if cli.subcommand == "fit" and cli._get(cli.config, "eval_after_fit"): + if ( + (not cli.trainer.fast_dev_run) + and cli.subcommand == "fit" + and cli._get(cli.config, "eval_after_fit") + ): cli.trainer.test(datamodule=cli.datamodule, ckpt_path="best") diff --git a/experiments/classification/cifar10/wideresnet.py b/experiments/classification/cifar10/wideresnet.py index 0573ee48..fdb47fa1 100644 --- a/experiments/classification/cifar10/wideresnet.py +++ b/experiments/classification/cifar10/wideresnet.py @@ -20,5 +20,9 @@ def cli_main() -> ResNetCLI: if __name__ == "__main__": torch.set_float32_matmul_precision("medium") cli = cli_main() - if cli.subcommand == "fit" and cli._get(cli.config, "eval_after_fit"): + if ( + (not cli.trainer.fast_dev_run) + and cli.subcommand == "fit" + and cli._get(cli.config, "eval_after_fit") + ): cli.trainer.test(datamodule=cli.datamodule, ckpt_path="best") From c35ddc50ef79ccdea09ba8f1ad8186c3afcb5125 Mon Sep 17 00:00:00 2001 From: alafage Date: Mon, 18 Mar 2024 16:34:31 +0100 Subject: [PATCH 058/148] :hammer: Update Segmentation routine - :bug: CamVid setup bug --- .../camvid/configs/segformer.yaml | 7 ++-- experiments/segmentation/camvid/segformer.py | 7 ++-- .../baselines/segmentation/segformer.py | 16 +++++++-- .../datamodules/segmentation/camvid.py | 17 ++++----- .../datasets/segmentation/camvid.py | 36 +++++++++---------- torch_uncertainty/routines/segmentation.py | 29 ++++++++++----- 6 files changed, 68 insertions(+), 44 deletions(-) diff --git a/experiments/segmentation/camvid/configs/segformer.yaml b/experiments/segmentation/camvid/configs/segformer.yaml index 16767e87..9146c141 100644 --- a/experiments/segmentation/camvid/configs/segformer.yaml +++ b/experiments/segmentation/camvid/configs/segformer.yaml @@ -1,12 +1,11 @@ # lightning.pytorch==2.1.3 eval_after_fit: true -seed_everything: true +seed_everything: false trainer: accelerator: gpu devices: 1 - precision: bf16-mixed model: - num_classes: 13 + num_classes: 12 loss: torch.nn.CrossEntropyLoss version: std arch: 0 @@ -19,5 +18,5 @@ optimizer: lr: 0.01 lr_scheduler: milestones: - - 30 + - 30 gamma: 0.1 diff --git a/experiments/segmentation/camvid/segformer.py b/experiments/segmentation/camvid/segformer.py index b38f0d32..90369f08 100644 --- a/experiments/segmentation/camvid/segformer.py +++ b/experiments/segmentation/camvid/segformer.py @@ -20,6 +20,9 @@ def cli_main() -> SegFormerCLI: if __name__ == "__main__": torch.set_float32_matmul_precision("medium") cli = cli_main() - - if cli.subcommand == "fit" and cli._get(cli.config, "eval_after_fit"): + if ( + (not cli.trainer.fast_dev_run) + and cli.subcommand == "fit" + and cli._get(cli.config, "eval_after_fit") + ): cli.trainer.test(datamodule=cli.datamodule, ckpt_path="best") diff --git a/torch_uncertainty/baselines/segmentation/segformer.py b/torch_uncertainty/baselines/segmentation/segformer.py index 4c95bdaf..fe32a034 100644 --- a/torch_uncertainty/baselines/segmentation/segformer.py +++ b/torch_uncertainty/baselines/segmentation/segformer.py @@ -1,5 +1,6 @@ from typing import Literal +from einops import rearrange from torch import Tensor, nn from torchvision.transforms.v2 import functional as F @@ -87,7 +88,10 @@ def training_step( target = F.resize( target, logits.shape[-2:], interpolation=F.InterpolationMode.NEAREST ) - loss = self.criterion(logits, target) + logits = rearrange(logits, "b c h w -> (b h w) c") + target = target.flatten() + valid_mask = target != 255 + loss = self.criterion(logits[valid_mask], target[valid_mask]) self.log("train_loss", loss) return loss @@ -99,7 +103,10 @@ def validation_step( target = F.resize( target, logits.shape[-2:], interpolation=F.InterpolationMode.NEAREST ) - self.val_seg_metrics.update(logits, target) + logits = rearrange(logits, "b c h w -> (b h w) c") + target = target.flatten() + valid_mask = target != 255 + self.val_seg_metrics.update(logits[valid_mask], target[valid_mask]) def test_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> None: img, target = batch @@ -107,4 +114,7 @@ def test_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> None: target = F.resize( target, logits.shape[-2:], interpolation=F.InterpolationMode.NEAREST ) - self.test_seg_metrics.update(logits, target) + logits = rearrange(logits, "b c h w -> (b h w) c") + target = target.flatten() + valid_mask = target != 255 + self.test_seg_metrics.update(logits[valid_mask], target[valid_mask]) diff --git a/torch_uncertainty/datamodules/segmentation/camvid.py b/torch_uncertainty/datamodules/segmentation/camvid.py index bfcd4bb8..83c36483 100644 --- a/torch_uncertainty/datamodules/segmentation/camvid.py +++ b/torch_uncertainty/datamodules/segmentation/camvid.py @@ -13,7 +13,7 @@ def __init__( self, root: str | Path, batch_size: int, - val_split: float = 0.0, + val_split: float = 0.0, # FIXME: not used for now num_workers: int = 1, pin_memory: bool = True, persistent_workers: bool = True, @@ -28,7 +28,7 @@ def __init__( ) self.dataset = CamVid - self.transform_train = v2.Compose( + self.train_transform = v2.Compose( [ v2.Resize( (360, 480), interpolation=v2.InterpolationMode.NEAREST @@ -43,7 +43,7 @@ def __init__( ), ] ) - self.transform_test = v2.Compose( + self.test_transform = v2.Compose( [ v2.Resize( (360, 480), interpolation=v2.InterpolationMode.NEAREST @@ -68,20 +68,21 @@ def setup(self, stage: str | None = None) -> None: root=self.root, split="train", download=False, - transform=self.transform_train, + transforms=self.train_transform, ) self.val = self.dataset( root=self.root, split="val", download=False, - transform=self.transform_test, + transforms=self.test_transform, ) - elif stage == "test": + if stage == "test" or stage is None: self.test = self.dataset( root=self.root, split="test", download=False, - transform=self.transform_test, + transforms=self.test_transform, ) - else: + + if stage not in ["fit", "test", None]: raise ValueError(f"Stage {stage} is not supported.") diff --git a/torch_uncertainty/datasets/segmentation/camvid.py b/torch_uncertainty/datasets/segmentation/camvid.py index 211f0ec5..e017fd1d 100644 --- a/torch_uncertainty/datasets/segmentation/camvid.py +++ b/torch_uncertainty/datasets/segmentation/camvid.py @@ -64,7 +64,7 @@ def __init__( self, root: str, split: Literal["train", "val", "test"] | None = None, - transform: Callable | None = None, + transforms: Callable | None = None, download: bool = False, ) -> None: """`CamVid `_ Dataset. @@ -74,9 +74,9 @@ def __init__( will be saved to if download is set to ``True``. split (str, optional): The dataset split, supports ``train``, ``val`` and ``test``. Default: ``None``. - transform (callable, optional): A function/transform that takes + transforms (callable, optional): A function/transform that takes input sample and its target as entry and returns a transformed - version. + version. Default: ``None``. download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. @@ -87,7 +87,7 @@ def __init__( "Supported splits are ['train', 'val', 'test', None]" ) - super().__init__(root, transform, None, None) + super().__init__(root, transforms, None, None) if download: self.download() @@ -129,14 +129,14 @@ def __init__( ] ) - self.transform = transform + # self.transforms = transforms self.split = split if split is not None else "all" def encode_target(self, target: Image.Image) -> torch.Tensor: """Encode target image to tensor. Args: - target (Image.Image): Target image. + target (Image.Image): Target PIL image. Returns: torch.Tensor: Encoded target. @@ -146,12 +146,13 @@ def encode_target(self, target: Image.Image) -> torch.Tensor: target = torch.zeros_like(colored_target[..., :1]) # convert target color to index for camvid_class in self.classes: + index = camvid_class.index if camvid_class.index != 12 else 255 target[ ( colored_target == torch.tensor(camvid_class.color, dtype=target.dtype) ).all(dim=-1) - ] = camvid_class.index + ] = index return rearrange(target, "h w c -> c h w").squeeze(0) @@ -162,7 +163,7 @@ def decode_target(self, target: torch.Tensor) -> Image.Image: target (torch.Tensor): Target tensor. Returns: - Image.Image: Decoded target. + Image.Image: Decoded target as a PIL.Image. """ colored_target = repeat(target.clone(), "h w -> h w 3", c=3) @@ -176,22 +177,24 @@ def decode_target(self, target: torch.Tensor) -> Image.Image: return F.to_pil_image(rearrange(colored_target, "h w c -> c h w")) - def __getitem__(self, index: int) -> tuple: - """Get image and target at index. + def __getitem__( + self, index: int + ) -> tuple[tv_tensors.Image, tv_tensors.Mask]: + """Get the image and target at the given index. Args: - index (int): Index + index (int): Sample index. Returns: - tuple: (image, target) where target is the segmentation mask. + tuple[tv_tensors.Image, tv_tensors.Mask]: Image and target. """ image = tv_tensors.Image(Image.open(self.images[index]).convert("RGB")) target = tv_tensors.Mask( self.encode_target(Image.open(self.targets[index])) ) - if self.transform is not None: - image, target = self.transform(image, target) + if self.transforms is not None: + image, target = self.transforms(image, target) return image, target @@ -247,8 +250,3 @@ def download(self) -> None: filename="splits.json", md5=self.splits_md5, ) - - -if __name__ == "__main__": - dataset = CamVid("data", split=None, download=True) - print(dataset) diff --git a/torch_uncertainty/routines/segmentation.py b/torch_uncertainty/routines/segmentation.py index 2d060fa4..d2dd126d 100644 --- a/torch_uncertainty/routines/segmentation.py +++ b/torch_uncertainty/routines/segmentation.py @@ -1,7 +1,8 @@ +from einops import rearrange from lightning.pytorch import LightningModule from lightning.pytorch.utilities.types import STEP_OUTPUT from torch import Tensor, nn -from torchmetrics import MetricCollection +from torchmetrics import Accuracy, MetricCollection from torch_uncertainty.metrics import MeanIntersectionOverUnion @@ -29,8 +30,10 @@ def __init__( # metrics seg_metrics = MetricCollection( { + "acc": Accuracy(task="multiclass", num_classes=num_classes), "mean_iou": MeanIntersectionOverUnion(num_classes=num_classes), - } + }, + compute_groups=[["acc", "mean_iou"]], ) self.val_seg_metrics = seg_metrics.clone(prefix="val/") @@ -53,22 +56,32 @@ def training_step( self, batch: tuple[Tensor, Tensor], batch_idx: int ) -> STEP_OUTPUT: img, target = batch - pred = self.forward(img) - loss = self.criterion(pred, target) + logits = self.forward(img) + logits = rearrange(logits, "b c h w -> (b h w) c") + target = target.flatten() + valid_mask = target != 255 + loss = self.criterion(logits[valid_mask], target[valid_mask]) self.log("train_loss", loss) return loss def validation_step( self, batch: tuple[Tensor, Tensor], batch_idx: int ) -> None: - img, targets = batch + img, target = batch + # (B, num_classes, H, W) logits = self.forward(img) - self.val_seg_metrics.update(logits, targets) + logits = rearrange(logits, "b c h w -> (b h w) c") + target = target.flatten() + valid_mask = target != 255 + self.val_seg_metrics.update(logits[valid_mask], target[valid_mask]) def test_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> None: img, target = batch - pred = self.forward(img) - self.test_seg_metrics.update(pred, target) + logits = self.forward(img) + logits = rearrange(logits, "b c h w -> (b h w) c") + target = target.flatten() + valid_mask = target != 255 + self.test_seg_metrics.update(logits[valid_mask], target[valid_mask]) def on_validation_epoch_end(self) -> None: self.log_dict(self.val_seg_metrics.compute()) From b99d0b8a56d200be5ae03c35257c48dd6275bc42 Mon Sep 17 00:00:00 2001 From: alafage Date: Mon, 18 Mar 2024 16:39:12 +0100 Subject: [PATCH 059/148] :sparkles: Add Cityscapes support - SegFormer B0 fully running on Cityscapes --- .../cityscapes/configs/segformer.yaml | 26 ++++ .../segmentation/cityscapes/segformer.py | 28 ++++ torch_uncertainty/datamodules/__init__.py | 2 +- .../datamodules/segmentation/__init__.py | 1 + .../datamodules/segmentation/cityscapes.py | 130 ++++++++++++++++++ .../datasets/segmentation/__init__.py | 1 + .../datasets/segmentation/cityscapes.py | 76 ++++++++++ torch_uncertainty/transforms/__init__.py | 1 + torch_uncertainty/transforms/image.py | 82 ++++++++++- 9 files changed, 345 insertions(+), 2 deletions(-) create mode 100644 experiments/segmentation/cityscapes/configs/segformer.yaml create mode 100644 experiments/segmentation/cityscapes/segformer.py create mode 100644 torch_uncertainty/datamodules/segmentation/cityscapes.py create mode 100644 torch_uncertainty/datasets/segmentation/cityscapes.py diff --git a/experiments/segmentation/cityscapes/configs/segformer.yaml b/experiments/segmentation/cityscapes/configs/segformer.yaml new file mode 100644 index 00000000..cce57b1f --- /dev/null +++ b/experiments/segmentation/cityscapes/configs/segformer.yaml @@ -0,0 +1,26 @@ +# lightning.pytorch==2.2.0 +eval_after_fit: true +seed_everything: false +trainer: + accelerator: gpu + devices: 1 + max_steps: 160000 +model: + num_classes: 19 + loss: torch.nn.CrossEntropyLoss + version: std + arch: 0 + num_estimators: 1 +data: + root: ./data + batch_size: 8 + crop_size: 1024 + inference_size: + - 1024 + - 2048 + num_workers: 30 +optimizer: + lr: 6e-5 +lr_scheduler: + step_size: 10000 + gamma: 0.1 diff --git a/experiments/segmentation/cityscapes/segformer.py b/experiments/segmentation/cityscapes/segformer.py new file mode 100644 index 00000000..ecbb12f2 --- /dev/null +++ b/experiments/segmentation/cityscapes/segformer.py @@ -0,0 +1,28 @@ +import torch +from lightning.pytorch.cli import LightningArgumentParser +from lightning.pytorch.loggers import TensorBoardLogger # noqa: F401 + +from torch_uncertainty.baselines.segmentation import SegFormer +from torch_uncertainty.datamodules import CityscapesDataModule +from torch_uncertainty.utils import TULightningCLI + + +class SegFormerCLI(TULightningCLI): + def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + parser.add_optimizer_args(torch.optim.AdamW) + parser.add_lr_scheduler_args(torch.optim.lr_scheduler.StepLR) + + +def cli_main() -> SegFormerCLI: + return SegFormerCLI(SegFormer, CityscapesDataModule) + + +if __name__ == "__main__": + torch.set_float32_matmul_precision("medium") + cli = cli_main() + if ( + (not cli.trainer.fast_dev_run) + and cli.subcommand == "fit" + and cli._get(cli.config, "eval_after_fit") + ): + cli.trainer.test(datamodule=cli.datamodule, ckpt_path="best") diff --git a/torch_uncertainty/datamodules/__init__.py b/torch_uncertainty/datamodules/__init__.py index 4c236ef5..24859701 100644 --- a/torch_uncertainty/datamodules/__init__.py +++ b/torch_uncertainty/datamodules/__init__.py @@ -4,5 +4,5 @@ from .classification.imagenet import ImageNetDataModule from .classification.mnist import MNISTDataModule from .classification.tiny_imagenet import TinyImageNetDataModule -from .segmentation import CamVidDataModule +from .segmentation import CamVidDataModule, CityscapesDataModule from .uci_regression import UCIDataModule diff --git a/torch_uncertainty/datamodules/segmentation/__init__.py b/torch_uncertainty/datamodules/segmentation/__init__.py index 008b7252..7c0d0a8c 100644 --- a/torch_uncertainty/datamodules/segmentation/__init__.py +++ b/torch_uncertainty/datamodules/segmentation/__init__.py @@ -1,2 +1,3 @@ # ruff: noqa: F401 from .camvid import CamVidDataModule +from .cityscapes import CityscapesDataModule diff --git a/torch_uncertainty/datamodules/segmentation/cityscapes.py b/torch_uncertainty/datamodules/segmentation/cityscapes.py new file mode 100644 index 00000000..2ddde6fe --- /dev/null +++ b/torch_uncertainty/datamodules/segmentation/cityscapes.py @@ -0,0 +1,130 @@ +import copy +from pathlib import Path + +import torch +from torch.nn.common_types import _size_2_t +from torch.nn.modules.utils import _pair +from torch.utils.data import random_split +from torchvision import datasets, tv_tensors +from torchvision.transforms import v2 + +from torch_uncertainty.datamodules.abstract import AbstractDataModule +from torch_uncertainty.datasets.segmentation import Cityscapes +from torch_uncertainty.transforms import RandomRescale + + +class CityscapesDataModule(AbstractDataModule): + def __init__( + self, + root: str | Path, + batch_size: int, + crop_size: _size_2_t = 1024, + inference_size: _size_2_t = (1024, 2048), + val_split: float | None = None, + num_workers: int = 1, + pin_memory: bool = True, + persistent_workers: bool = True, + ) -> None: + super().__init__( + root=root, + batch_size=batch_size, + val_split=val_split, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=persistent_workers, + ) + + self.dataset = Cityscapes + self.mode = "fine" + self.crop_size = _pair(crop_size) + self.inference_size = _pair(inference_size) + + self.train_transform = v2.Compose( + [ + v2.ToImage(), + RandomRescale(min_scale=0.5, max_scale=2.0, antialias=True), + v2.RandomCrop(size=self.crop_size, pad_if_needed=True), + v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5), + v2.RandomHorizontalFlip(), + v2.ToDtype( + dtype={ + tv_tensors.Image: torch.float32, + tv_tensors.Mask: torch.int64, + "others": None, + }, + scale=True, + ), + v2.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) + self.test_transform = v2.Compose( + [ + v2.ToImage(), + v2.Resize(size=self.inference_size, antialias=True), + v2.ToDtype( + dtype={ + tv_tensors.Image: torch.float32, + tv_tensors.Mask: torch.int64, + "others": None, + }, + scale=True, + ), + v2.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) + + def prepare_data(self) -> None: # coverage: ignore + self.dataset(root=self.root, split="train", mode=self.mode) + + def setup(self, stage: str | None = None) -> None: + if stage == "fit" or stage is None: + full = datasets.wrap_dataset_for_transforms_v2( + self.dataset( + root=self.root, + split="train", + mode=self.mode, + target_type="semantic", + transforms=self.train_transform, + ) + ) + + if self.val_split is not None: + self.train, val = random_split( + full, + [ + 1 - self.val_split, + self.val_split, + ], + ) + # FIXME: memory cost issues might arise here + self.val = copy.deepcopy(val) + self.val.dataset.transforms = self.test_transform + else: + self.train = full + self.val = datasets.wrap_dataset_for_transforms_v2( + self.dataset( + root=self.root, + split="val", + mode=self.mode, + target_type="semantic", + transforms=self.test_transform, + ) + ) + + if stage == "test" or stage is None: + self.test = datasets.wrap_dataset_for_transforms_v2( + self.dataset( + root=self.root, + split="val", + mode=self.mode, + target_type="semantic", + transforms=self.test_transform, + ) + ) + + if stage not in ["fit", "test", None]: + raise ValueError(f"Stage {stage} is not supported.") diff --git a/torch_uncertainty/datasets/segmentation/__init__.py b/torch_uncertainty/datasets/segmentation/__init__.py index 90f7bad4..11d4f9fd 100644 --- a/torch_uncertainty/datasets/segmentation/__init__.py +++ b/torch_uncertainty/datasets/segmentation/__init__.py @@ -1,2 +1,3 @@ # ruff: noqa: F401 from .camvid import CamVid +from .cityscapes import Cityscapes diff --git a/torch_uncertainty/datasets/segmentation/cityscapes.py b/torch_uncertainty/datasets/segmentation/cityscapes.py new file mode 100644 index 00000000..f4835fc2 --- /dev/null +++ b/torch_uncertainty/datasets/segmentation/cityscapes.py @@ -0,0 +1,76 @@ +from typing import Any + +import torch +from einops import rearrange +from PIL import Image +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE +from torchvision.datasets import Cityscapes as OriginalCityscapes +from torchvision.transforms.v2 import functional as F + + +class Cityscapes(OriginalCityscapes): + def encode_target(self, target: Image.Image) -> Image.Image: + """Encode target image to tensor. + + Args: + target (Image.Image): Target PIL image. + + Returns: + torch.Tensor: Encoded target. + """ + colored_target = F.pil_to_tensor(target) + colored_target = rearrange(colored_target, "c h w -> h w c") + target = torch.zeros_like(colored_target[..., :1]) + # convert target color to index + for cityscapes_class in self.classes: + target[ + ( + colored_target + == torch.tensor(cityscapes_class.id, dtype=target.dtype) + ).all(dim=-1) + ] = cityscapes_class.train_id + + return F.to_pil_image(rearrange(target, "h w c -> c h w")) + + def __getitem__(self, index: int) -> tuple[Any, Any]: + """Args: + index (int): Index + Returns: + tuple: (image, target) where target is a tuple of all target types + if ``target_type`` is a list with more + than one item. Otherwise, target is a json object if + ``target_type="polygon"``, else the image segmentation. + """ + image = Image.open(self.images[index]).convert("RGB") + + targets: Any = [] + for i, t in enumerate(self.target_type): + if t == "polygon": + target = self._load_json(self.targets[index][i]) + elif t == "semantic": + target = self.encode_target(Image.open(self.targets[index][i])) + else: + target = Image.open(self.targets[index][i]) + + targets.append(target) + + target = tuple(targets) if len(targets) > 1 else targets[0] + + if self.transforms is not None: + image, target = self.transforms(image, target) + + return image, target + + def plot_sample( + self, index: int, ax: _AX_TYPE | None = None + ) -> _PLOT_OUT_TYPE: + """Plot a sample from the dataset. + + Args: + index: The index of the sample to plot. + ax: Optional matplotlib axis to plot on. + + Returns: + The axis on which the sample was plotted. + """ + raise NotImplementedError("This method is not implemented yet.") diff --git a/torch_uncertainty/transforms/__init__.py b/torch_uncertainty/transforms/__init__.py index b2339f3c..d3aae6ec 100644 --- a/torch_uncertainty/transforms/__init__.py +++ b/torch_uncertainty/transforms/__init__.py @@ -8,6 +8,7 @@ Contrast, Equalize, Posterize, + RandomRescale, Rotate, Sharpen, Shear, diff --git a/torch_uncertainty/transforms/image.py b/torch_uncertainty/transforms/image.py index 09c72903..2e8707eb 100644 --- a/torch_uncertainty/transforms/image.py +++ b/torch_uncertainty/transforms/image.py @@ -1,7 +1,11 @@ +from typing import Any + import torch -import torchvision.transforms.functional as F +import torchvision.transforms.v2.functional as F from PIL import Image, ImageEnhance from torch import Tensor, nn +from torchvision.transforms.v2 import InterpolationMode, Transform +from torchvision.transforms.v2._utils import query_size class AutoContrast(nn.Module): @@ -242,3 +246,79 @@ def forward( if isinstance(img, Tensor): img: Image.Image = F.to_pil_image(img) return ImageEnhance.Color(img).enhance(level) + + +class RandomRescale(Transform): + """Randomly rescale the input. + + This transformation can be used together with ``RandomCrop`` as data augmentations to train + models on image segmentation task. + + Output spatial size is randomly sampled from the interval ``[min_size, max_size]``: + + .. code-block:: python + + scale = uniform_sample(min_scale, max_scale) + output_width = input_width * scale + output_height = input_height * scale + + If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`, + :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.) + it can have arbitrary number of leading batch dimensions. For example, + the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape. + + Args: + min_scale (int): Minimum scale for random sampling + max_scale (int): Maximum scale for random sampling + interpolation (InterpolationMode, optional): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``, + ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported. + The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well. + antialias (bool, optional): Whether to apply antialiasing. + It only affects **tensors** with bilinear or bicubic modes and it is + ignored otherwise: on PIL images, antialiasing is always applied on + bilinear or bicubic modes; on other modes (for PIL images and + tensors), antialiasing makes no sense and this parameter is ignored. + Possible values are: + + - ``True`` (default): will apply antialiasing for bilinear or bicubic modes. + Other mode aren't affected. This is probably what you want to use. + - ``False``: will not apply antialiasing for tensors on any mode. PIL + images are still antialiased on bilinear or bicubic modes, because + PIL doesn't support no antialias. + - ``None``: equivalent to ``False`` for tensors and ``True`` for + PIL images. This value exists for legacy reasons and you probably + don't want to use it unless you really know what you are doing. + + The default value changed from ``None`` to ``True`` in + v0.17, for the PIL and Tensor backends to be consistent. + """ + + def __init__( + self, + min_scale: int, + max_scale: int, + interpolation: InterpolationMode | int = InterpolationMode.BILINEAR, + antialias: bool | None = True, + ) -> None: + super().__init__() + self.min_scale = min_scale + self.max_scale = max_scale + self.interpolation = interpolation + self.antialias = antialias + + def _get_params(self, flat_inputs: list[Any]) -> dict[str, Any]: + height, width = query_size(flat_inputs) + scale = torch.rand(1) + scale = self.min_scale + scale * (self.max_scale - self.min_scale) + return {"size": (int(height * scale), int(width * scale))} + + def _transform(self, inpt: Any, params: dict[str, Any]) -> Any: + return self._call_kernel( + F.resize, + inpt, + params["size"], + interpolation=self.interpolation, + antialias=self.antialias, + ) From bd17ce8095b3b528cac6a2411b320182f40323bf Mon Sep 17 00:00:00 2001 From: alafage Date: Mon, 18 Mar 2024 16:43:02 +0100 Subject: [PATCH 060/148] :bug: Validation subset use test_transform instead of train_transform --- torch_uncertainty/datamodules/classification/cifar10.py | 8 ++++++-- torch_uncertainty/datamodules/classification/cifar100.py | 8 ++++++-- torch_uncertainty/datamodules/classification/imagenet.py | 5 ++++- torch_uncertainty/datamodules/classification/mnist.py | 8 ++++++-- .../datamodules/classification/tiny_imagenet.py | 8 ++++++-- torch_uncertainty/datamodules/segmentation/camvid.py | 2 +- 6 files changed, 29 insertions(+), 10 deletions(-) diff --git a/torch_uncertainty/datamodules/classification/cifar10.py b/torch_uncertainty/datamodules/classification/cifar10.py index b6536e90..90afed91 100644 --- a/torch_uncertainty/datamodules/classification/cifar10.py +++ b/torch_uncertainty/datamodules/classification/cifar10.py @@ -1,3 +1,4 @@ +import copy from pathlib import Path from typing import Literal @@ -26,7 +27,7 @@ def __init__( root: str | Path, batch_size: int, eval_ood: bool = False, - val_split: float = 0.0, + val_split: float | None = None, num_workers: int = 1, cutout: int | None = None, auto_augment: str | None = None, @@ -148,13 +149,16 @@ def setup(self, stage: Literal["fit", "test"] | None = None) -> None: transform=self.train_transform, ) if self.val_split: - self.train, self.val = random_split( + self.train, val = random_split( full, [ 1 - self.val_split, self.val_split, ], ) + # FIXME: memory cost issues might arise here + self.val = copy.deepcopy(val) + self.val.dataset.transform = self.test_transform else: self.train = full self.val = self.dataset( diff --git a/torch_uncertainty/datamodules/classification/cifar100.py b/torch_uncertainty/datamodules/classification/cifar100.py index a6b68e56..b2ea294b 100644 --- a/torch_uncertainty/datamodules/classification/cifar100.py +++ b/torch_uncertainty/datamodules/classification/cifar100.py @@ -1,3 +1,4 @@ +import copy from pathlib import Path from typing import Literal @@ -27,7 +28,7 @@ def __init__( root: str | Path, batch_size: int, eval_ood: bool = False, - val_split: float = 0.0, + val_split: float | None = None, cutout: int | None = None, randaugment: bool = False, auto_augment: str | None = None, @@ -148,13 +149,16 @@ def setup(self, stage: Literal["fit", "test"] | None = None) -> None: transform=self.train_transform, ) if self.val_split: - self.train, self.val = random_split( + self.train, val = random_split( full, [ 1 - self.val_split, self.val_split, ], ) + # FIXME: memory cost issues might arise here + self.val = copy.deepcopy(val) + self.val.dataset.transform = self.test_transform else: self.train = full self.val = self.dataset( diff --git a/torch_uncertainty/datamodules/classification/imagenet.py b/torch_uncertainty/datamodules/classification/imagenet.py index 20196b1a..aaebeccb 100644 --- a/torch_uncertainty/datamodules/classification/imagenet.py +++ b/torch_uncertainty/datamodules/classification/imagenet.py @@ -202,13 +202,16 @@ def setup(self, stage: Literal["fit", "test"] | None = None) -> None: transform=self.train_transform, ) if self.val_split and isinstance(self.val_split, float): - self.train, self.val = random_split( + self.train, val = random_split( full, [ 1 - self.val_split, self.val_split, ], ) + # FIXME: memory cost issues might arise here + self.val = copy.deepcopy(val) + self.val.dataset.transform = self.test_transform elif isinstance(self.val_split, Path): self.train = Subset(full, self.train_indices) # TODO: improve the performance diff --git a/torch_uncertainty/datamodules/classification/mnist.py b/torch_uncertainty/datamodules/classification/mnist.py index cd4997c3..0d9fe147 100644 --- a/torch_uncertainty/datamodules/classification/mnist.py +++ b/torch_uncertainty/datamodules/classification/mnist.py @@ -1,3 +1,4 @@ +import copy from pathlib import Path from typing import Literal @@ -24,7 +25,7 @@ def __init__( batch_size: int, eval_ood: bool = False, ood_ds: Literal["fashion", "not"] = "fashion", - val_split: float = 0.0, + val_split: float | None = None, num_workers: int = 1, cutout: int | None = None, test_alt: Literal["c"] | None = None, @@ -112,13 +113,16 @@ def setup(self, stage: Literal["fit", "test"] | None = None) -> None: transform=self.train_transform, ) if self.val_split: - self.train, self.val = random_split( + self.train, val = random_split( full, [ 1 - self.val_split, self.val_split, ], ) + # FIXME: memory cost issues might arise here + self.val = copy.deepcopy(val) + self.val.dataset.transform = self.test_transform else: self.train = full self.val = self.dataset( diff --git a/torch_uncertainty/datamodules/classification/tiny_imagenet.py b/torch_uncertainty/datamodules/classification/tiny_imagenet.py index f99144fb..71334368 100644 --- a/torch_uncertainty/datamodules/classification/tiny_imagenet.py +++ b/torch_uncertainty/datamodules/classification/tiny_imagenet.py @@ -1,3 +1,4 @@ +import copy from pathlib import Path from typing import Literal @@ -23,7 +24,7 @@ def __init__( root: str | Path, batch_size: int, eval_ood: bool = False, - val_split: float = 0.0, + val_split: float | None = None, ood_ds: str = "svhn", rand_augment_opt: str | None = None, num_workers: int = 1, @@ -127,13 +128,16 @@ def setup(self, stage: Literal["fit", "test"] | None = None) -> None: transform=self.train_transform, ) if self.val_split: - self.train, self.val = random_split( + self.train, val = random_split( full, [ 1 - self.val_split, self.val_split, ], ) + # FIXME: memory cost issues might arise here + self.val = copy.deepcopy(val) + self.val.dataset.transform = self.test_transform else: self.train = full self.val = self.dataset( diff --git a/torch_uncertainty/datamodules/segmentation/camvid.py b/torch_uncertainty/datamodules/segmentation/camvid.py index 83c36483..cf248669 100644 --- a/torch_uncertainty/datamodules/segmentation/camvid.py +++ b/torch_uncertainty/datamodules/segmentation/camvid.py @@ -13,7 +13,7 @@ def __init__( self, root: str | Path, batch_size: int, - val_split: float = 0.0, # FIXME: not used for now + val_split: float | None = None, # FIXME: not used for now num_workers: int = 1, pin_memory: bool = True, persistent_workers: bool = True, From ed636dd6b86e60f9d5c618824521155a41d5c8c4 Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 18 Mar 2024 15:53:14 +0100 Subject: [PATCH 061/148] :hammer: Rename opt. proc to optim_recipe --- auto_tutorials_source/tutorial_bayesian.py | 2 +- .../tutorial_evidential_classification.py | 2 +- .../tutorial_mc_batch_norm.py | 6 ++--- auto_tutorials_source/tutorial_mc_dropout.py | 6 ++--- docs/source/quickstart.rst | 10 ++++---- experiments/classification/cifar100/resnet.py | 4 ++-- experiments/classification/cifar100/vgg.py | 6 ++--- .../classification/cifar100/wideresnet.py | 6 ++--- .../classification/mnist/bayesian_lenet.py | 4 ++-- experiments/classification/mnist/lenet.py | 4 ++-- .../classification/tiny-imagenet/resnet.py | 6 ++--- .../regression/uci_datasets/mlp-kin8nm.py | 2 +- tests/_dummies/baseline.py | 12 +++++----- tests/baselines/test_batched.py | 2 +- tests/baselines/test_masked.py | 2 +- tests/baselines/test_mc_dropout.py | 2 +- tests/baselines/test_mimo.py | 2 +- tests/routines/test_classification.py | 20 ++++++++-------- tests/routines/test_regression.py | 18 +++++++------- tests/test_cli.py | 24 +++++++++---------- tests/test_optimization_procedures.py | 2 +- .../classification/deep_ensembles.py | 2 +- .../baselines/classification/resnet.py | 2 +- .../baselines/classification/wideresnet.py | 2 +- ...ization_procedures.py => optim_recipes.py} | 18 ++++++-------- torch_uncertainty/routines/classification.py | 8 +++---- torch_uncertainty/routines/regression.py | 10 ++++---- 27 files changed, 88 insertions(+), 96 deletions(-) rename torch_uncertainty/{optimization_procedures.py => optim_recipes.py} (95%) diff --git a/auto_tutorials_source/tutorial_bayesian.py b/auto_tutorials_source/tutorial_bayesian.py index dee6c76b..81f9d06f 100644 --- a/auto_tutorials_source/tutorial_bayesian.py +++ b/auto_tutorials_source/tutorial_bayesian.py @@ -105,7 +105,7 @@ def optim_lenet(model: nn.Module) -> dict: model=model, num_classes=datamodule.num_classes, loss=loss, - optimization_procedure=optim_lenet, + optim_recipe=optim_lenet, ) # %% diff --git a/auto_tutorials_source/tutorial_evidential_classification.py b/auto_tutorials_source/tutorial_evidential_classification.py index 5a551da8..0f8a890d 100644 --- a/auto_tutorials_source/tutorial_evidential_classification.py +++ b/auto_tutorials_source/tutorial_evidential_classification.py @@ -90,7 +90,7 @@ def optim_lenet(model: nn.Module) -> dict: model=model, num_classes=datamodule.num_classes, loss=loss, - optimization_procedure=optim_lenet, + optim_recipe=optim_lenet, ) # %% diff --git a/auto_tutorials_source/tutorial_mc_batch_norm.py b/auto_tutorials_source/tutorial_mc_batch_norm.py index fec45261..16145683 100644 --- a/auto_tutorials_source/tutorial_mc_batch_norm.py +++ b/auto_tutorials_source/tutorial_mc_batch_norm.py @@ -19,7 +19,7 @@ - the mc-batch-norm wrapper: mc_dropout, which lies in torch_uncertainty.models - a resnet baseline to get the command line arguments: ResNet, which lies in torch_uncertainty.baselines - the classification training routine in the torch_uncertainty.training.classification module -- the optimizer wrapper in the torch_uncertainty.optimization_procedures module. +- the optimizer wrapper in the torch_uncertainty.optim_recipes module. We also need import the neural network utils withing `torch.nn`. """ @@ -31,7 +31,7 @@ from torch_uncertainty.datamodules import MNISTDataModule from torch_uncertainty.models.lenet import lenet -from torch_uncertainty.optimization_procedures import optim_cifar10_resnet18 +from torch_uncertainty.optim_recipes import optim_cifar10_resnet18 from torch_uncertainty.post_processing.mc_batch_norm import MCBatchNorm from torch_uncertainty.routines import ClassificationRoutine @@ -67,7 +67,7 @@ num_classes=datamodule.num_classes, model=model, loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet18, + optim_recipe=optim_cifar10_resnet18, ) # %% diff --git a/auto_tutorials_source/tutorial_mc_dropout.py b/auto_tutorials_source/tutorial_mc_dropout.py index e3bf59c0..49e56b67 100644 --- a/auto_tutorials_source/tutorial_mc_dropout.py +++ b/auto_tutorials_source/tutorial_mc_dropout.py @@ -25,7 +25,7 @@ - the mc-dropout wrapper: mc_dropout, which lies in torch_uncertainty.models - a resnet baseline to get the command line arguments: ResNet, which lies in torch_uncertainty.baselines - the classification training routine in the torch_uncertainty.training.classification module -- the optimizer wrapper in the torch_uncertainty.optimization_procedures module. +- the optimizer wrapper in the torch_uncertainty.optim_recipes module. We also need import the neural network utils within `torch.nn`. """ @@ -38,7 +38,7 @@ from torch_uncertainty.datamodules import MNISTDataModule from torch_uncertainty.models.lenet import lenet from torch_uncertainty.models.mc_dropout import mc_dropout -from torch_uncertainty.optimization_procedures import optim_cifar10_resnet18 +from torch_uncertainty.optim_recipes import optim_cifar10_resnet18 from torch_uncertainty.routines import ClassificationRoutine # %% @@ -83,7 +83,7 @@ num_classes=datamodule.num_classes, model=mc_model, loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet18, + optim_recipe=optim_cifar10_resnet18, num_estimators=16, ) diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index eef966d5..2d3182ac 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -27,9 +27,9 @@ routine, which takes as arguments: sets with again its arguments and logic. CIFAR-10/100, ImageNet, and ImageNet-200 are available, for instance. * a PyTorch loss such as the torch.nn.CrossEntropyLoss -* a dictionary containing the optimization procedure, namely a scheduler and +* a dictionary containing the optimization recipe, namely a scheduler and an optimizer. Many procedures are available at - `torch_uncertainty/optimization_procedures.py `_ + `torch_uncertainty/optim_recipes.py `_ * the path to the data and logs folder, in the example below, the root of the library * and finally, the name of your model (used for logs) @@ -53,7 +53,7 @@ trains any ResNet architecture on CIFAR10: from torch_uncertainty import cli_main, init_args from torch_uncertainty.baselines import ResNet from torch_uncertainty.datamodules import CIFAR10DataModule - from torch_uncertainty.optimization_procedures import get_procedure + from torch_uncertainty.optim_recipes import get_procedure root = Path(__file__).parent.absolute().parents[1] @@ -70,7 +70,7 @@ trains any ResNet architecture on CIFAR10: num_classes=dm.num_classes, in_channels=dm.num_channels, loss=nn.CrossEntropyLoss(), - optimization_procedure=get_procedure( + optim_recipe=get_procedure( f"resnet{args.arch}", "cifar10", args.version ), style="cifar", @@ -86,7 +86,7 @@ Run this model with, for instance: python3 resnet.py --version std --arch 18 --accelerator gpu --device 1 --benchmark True --max_epochs 75 --precision 16 You may replace the architecture (which should be a Lightning Module), the -Datamodule (a Lightning Datamodule), the loss or the optimization procedure to your likings. +Datamodule (a Lightning Datamodule), the loss or the optimization recipe to your likings. Using the PyTorch-based models ------------------------------ diff --git a/experiments/classification/cifar100/resnet.py b/experiments/classification/cifar100/resnet.py index 19d9ea36..54fc0665 100644 --- a/experiments/classification/cifar100/resnet.py +++ b/experiments/classification/cifar100/resnet.py @@ -5,7 +5,7 @@ from torch_uncertainty import cli_main, init_args from torch_uncertainty.baselines import ResNet from torch_uncertainty.datamodules import CIFAR100DataModule -from torch_uncertainty.optimization_procedures import get_procedure +from torch_uncertainty.optim_recipes import get_procedure if __name__ == "__main__": args = init_args(ResNet, CIFAR100DataModule) @@ -26,7 +26,7 @@ num_classes=dm.num_classes, in_channels=dm.num_channels, loss=nn.CrossEntropyLoss, - optimization_procedure=get_procedure( + optim_recipe=get_procedure( f"resnet{args.arch}", "cifar100", args.version ), style="cifar", diff --git a/experiments/classification/cifar100/vgg.py b/experiments/classification/cifar100/vgg.py index f218f83c..78d18960 100644 --- a/experiments/classification/cifar100/vgg.py +++ b/experiments/classification/cifar100/vgg.py @@ -5,7 +5,7 @@ from torch_uncertainty import cli_main, init_args from torch_uncertainty.baselines import VGG from torch_uncertainty.datamodules import CIFAR100DataModule -from torch_uncertainty.optimization_procedures import get_procedure +from torch_uncertainty.optim_recipes import get_procedure if __name__ == "__main__": args = init_args(VGG, CIFAR100DataModule) @@ -26,9 +26,7 @@ num_classes=dm.num_classes, in_channels=dm.num_channels, loss=nn.CrossEntropyLoss, - optimization_procedure=get_procedure( - f"vgg{args.arch}", "cifar100", args.version - ), + optim_recipe=get_procedure(f"vgg{args.arch}", "cifar100", args.version), style="cifar", **vars(args), ) diff --git a/experiments/classification/cifar100/wideresnet.py b/experiments/classification/cifar100/wideresnet.py index 729f07ba..2dec2161 100644 --- a/experiments/classification/cifar100/wideresnet.py +++ b/experiments/classification/cifar100/wideresnet.py @@ -5,7 +5,7 @@ from torch_uncertainty import cli_main, init_args from torch_uncertainty.baselines import WideResNet from torch_uncertainty.datamodules import CIFAR100DataModule -from torch_uncertainty.optimization_procedures import get_procedure +from torch_uncertainty.optim_recipes import get_procedure if __name__ == "__main__": args = init_args(WideResNet, CIFAR100DataModule) @@ -26,9 +26,7 @@ num_classes=dm.num_classes, in_channels=dm.num_channels, loss=nn.CrossEntropyLoss, - optimization_procedure=get_procedure( - "wideresnet28x10", "cifar100", args.version - ), + optim_recipe=get_procedure("wideresnet28x10", "cifar100", args.version), style="cifar", **vars(args), ) diff --git a/experiments/classification/mnist/bayesian_lenet.py b/experiments/classification/mnist/bayesian_lenet.py index 4789407c..f1030963 100644 --- a/experiments/classification/mnist/bayesian_lenet.py +++ b/experiments/classification/mnist/bayesian_lenet.py @@ -11,7 +11,7 @@ def optim_lenet(model: nn.Module) -> dict: - """Optimization procedure for LeNet. + """Optimization recipe for LeNet. Uses Adam default hyperparameters. @@ -55,7 +55,7 @@ def optim_lenet(model: nn.Module) -> dict: num_classes=dm.num_classes, in_channels=dm.num_channels, loss=loss, - optimization_procedure=optim_lenet, + optim_recipe=optim_lenet, **vars(args), ) diff --git a/experiments/classification/mnist/lenet.py b/experiments/classification/mnist/lenet.py index 0514c892..dc5f9636 100644 --- a/experiments/classification/mnist/lenet.py +++ b/experiments/classification/mnist/lenet.py @@ -9,7 +9,7 @@ def optim_lenet(model: nn.Module) -> dict: - """Optimization procedure for LeNet. + """Optimization recipe for LeNet. Uses Adam default hyperparameters. @@ -45,7 +45,7 @@ def optim_lenet(model: nn.Module) -> dict: num_classes=dm.num_classes, in_channels=dm.num_channels, loss=nn.CrossEntropyLoss, - optimization_procedure=optim_lenet, + optim_recipe=optim_lenet, **vars(args), ) diff --git a/experiments/classification/tiny-imagenet/resnet.py b/experiments/classification/tiny-imagenet/resnet.py index 390223eb..c0f381a7 100644 --- a/experiments/classification/tiny-imagenet/resnet.py +++ b/experiments/classification/tiny-imagenet/resnet.py @@ -5,7 +5,7 @@ from torch_uncertainty import cli_main, init_args from torch_uncertainty.baselines import ResNet from torch_uncertainty.datamodules import TinyImageNetDataModule -from torch_uncertainty.optimization_procedures import get_procedure +from torch_uncertainty.optim_recipes import get_procedure from torch_uncertainty.utils import csv_writer @@ -50,7 +50,7 @@ def optim_tiny(model: nn.Module) -> dict: num_classes=list_dm[i].dm.num_classes, in_channels=list_dm[i].dm.num_channels, loss=nn.CrossEntropyLoss, - optimization_procedure=get_procedure( + optim_recipe=get_procedure( f"resnet{args.arch}", "tiny-imagenet", args.version ), style="cifar", @@ -69,7 +69,7 @@ def optim_tiny(model: nn.Module) -> dict: num_classes=dm.num_classes, in_channels=dm.num_channels, loss=nn.CrossEntropyLoss, - optimization_procedure=get_procedure( + optim_recipe=get_procedure( f"resnet{args.arch}", "tiny-imagenet", args.version ), calibration_set=calibration_set, diff --git a/experiments/regression/uci_datasets/mlp-kin8nm.py b/experiments/regression/uci_datasets/mlp-kin8nm.py index d96979a6..ed796f94 100644 --- a/experiments/regression/uci_datasets/mlp-kin8nm.py +++ b/experiments/regression/uci_datasets/mlp-kin8nm.py @@ -40,7 +40,7 @@ def optim_regression( in_features=8, hidden_dims=[100], loss=nn.GaussianNLLLoss, - optimization_procedure=optim_regression, + optim_recipe=optim_regression, dist_estimation=2, **vars(args), ) diff --git a/tests/_dummies/baseline.py b/tests/_dummies/baseline.py index b7172881..d8097f1f 100644 --- a/tests/_dummies/baseline.py +++ b/tests/_dummies/baseline.py @@ -21,7 +21,7 @@ def __new__( in_channels: int, loss: type[nn.Module], ensemble=False, - optimization_procedure=None, + optim_recipe=None, with_feats: bool = True, with_linear: bool = True, ood_criterion: str = "msp", @@ -43,7 +43,7 @@ def __new__( loss=loss, format_batch_fn=nn.Identity(), log_plots=True, - optimization_procedure=optimization_procedure, + optim_recipe=optim_recipe, num_estimators=1, ood_criterion=ood_criterion, eval_ood=eval_ood, @@ -54,7 +54,7 @@ def __new__( num_classes=num_classes, model=model, loss=loss, - optimization_procedure=optimization_procedure, + optim_recipe=optim_recipe, format_batch_fn=RepeatTarget(2), log_plots=True, num_estimators=2, @@ -72,7 +72,7 @@ def __new__( num_outputs: int, loss: type[nn.Module], baseline_type: str = "single", - optimization_procedure=None, + optim_recipe=None, dist_type: str = "normal", ) -> LightningModule: model = dummy_model( @@ -94,7 +94,7 @@ def __new__( model=model, loss=loss, num_estimators=1, - optimization_procedure=optimization_procedure, + optim_recipe=optim_recipe, ) # baseline_type == "ensemble": model = deep_ensembles( @@ -108,6 +108,6 @@ def __new__( model=model, loss=loss, num_estimators=2, - optimization_procedure=optimization_procedure, + optim_recipe=optim_recipe, format_batch_fn=RepeatTarget(2), ) diff --git a/tests/baselines/test_batched.py b/tests/baselines/test_batched.py index 12980409..bf8abab7 100644 --- a/tests/baselines/test_batched.py +++ b/tests/baselines/test_batched.py @@ -4,7 +4,7 @@ from torch_uncertainty.baselines.classification import ResNet, WideResNet -# from torch_uncertainty.optimization_procedures import ( +# from torch_uncertainty.optim_recipes import ( # optim_cifar10_wideresnet, # optim_cifar100_resnet18, # optim_cifar100_resnet50, diff --git a/tests/baselines/test_masked.py b/tests/baselines/test_masked.py index 4fc82060..c12a880c 100644 --- a/tests/baselines/test_masked.py +++ b/tests/baselines/test_masked.py @@ -5,7 +5,7 @@ from torch_uncertainty.baselines.classification import ResNet, WideResNet -# from torch_uncertainty.optimization_procedures import ( +# from torch_uncertainty.optim_recipes import ( # optim_cifar10_wideresnet, # optim_cifar100_resnet18, # optim_cifar100_resnet50, diff --git a/tests/baselines/test_mc_dropout.py b/tests/baselines/test_mc_dropout.py index 12f0725d..5eeaa8c6 100644 --- a/tests/baselines/test_mc_dropout.py +++ b/tests/baselines/test_mc_dropout.py @@ -4,7 +4,7 @@ from torch_uncertainty.baselines.classification import VGG, ResNet, WideResNet -# from torch_uncertainty.optimization_procedures import ( +# from torch_uncertainty.optim_recipes import ( # optim_cifar10_resnet18, # optim_cifar10_wideresnet, # ) diff --git a/tests/baselines/test_mimo.py b/tests/baselines/test_mimo.py index 3adf8bd6..5db246d6 100644 --- a/tests/baselines/test_mimo.py +++ b/tests/baselines/test_mimo.py @@ -4,7 +4,7 @@ from torch_uncertainty.baselines.classification import ResNet, WideResNet -# from torch_uncertainty.optimization_procedures import ( +# from torch_uncertainty.optim_recipes import ( # optim_cifar10_resnet18, # optim_cifar10_resnet50, # optim_cifar10_wideresnet, diff --git a/tests/routines/test_classification.py b/tests/routines/test_classification.py index 685b54eb..30df930c 100644 --- a/tests/routines/test_classification.py +++ b/tests/routines/test_classification.py @@ -9,7 +9,7 @@ DummyClassificationDataModule, dummy_model, ) -from torch_uncertainty.optimization_procedures import optim_cifar10_resnet18 +from torch_uncertainty.optim_recipes import optim_cifar10_resnet18 from torch_uncertainty.routines import ClassificationRoutine @@ -29,7 +29,7 @@ def test_one_estimator_binary(self): in_channels=dm.num_channels, num_classes=dm.num_classes, loss=nn.BCEWithLogitsLoss, - optimization_procedure=optim_cifar10_resnet18, + optim_recipe=optim_cifar10_resnet18, ensemble=False, ood_criterion="msp", ) @@ -52,7 +52,7 @@ def test_two_estimators_binary(self): in_channels=dm.num_channels, num_classes=dm.num_classes, loss=nn.BCEWithLogitsLoss, - optimization_procedure=optim_cifar10_resnet18, + optim_recipe=optim_cifar10_resnet18, ensemble=True, ood_criterion="logit", ) @@ -76,7 +76,7 @@ def test_one_estimator_two_classes(self): num_classes=dm.num_classes, in_channels=dm.num_channels, loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet18, + optim_recipe=optim_cifar10_resnet18, ensemble=False, ood_criterion="entropy", eval_ood=True, @@ -100,7 +100,7 @@ def test_two_estimators_two_classes(self): num_classes=dm.num_classes, in_channels=dm.num_channels, loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet18, + optim_recipe=optim_cifar10_resnet18, ensemble=True, ood_criterion="energy", ) @@ -181,7 +181,7 @@ def test_classification_failures(self): # num_classes=dm.num_classes, # in_channels=dm.num_channels, # loss=DECLoss, -# optimization_procedure=optim_cifar10_resnet18, +# optim_recipe=optim_cifar10_resnet18, # ensemble=False, # **vars(args), # ) @@ -227,7 +227,7 @@ def test_classification_failures(self): # num_classes=list_dm[i].dm.num_classes, # in_channels=list_dm[i].dm.num_channels, # loss=nn.CrossEntropyLoss, -# optimization_procedure=optim_cifar10_resnet18, +# optim_recipe=optim_cifar10_resnet18, # ensemble=False, # calibration_set=dm.get_val_set, # **vars(args), @@ -276,7 +276,7 @@ def test_classification_failures(self): # num_classes=list_dm[i].dm.num_classes, # in_channels=list_dm[i].dm.num_channels, # loss=nn.CrossEntropyLoss, -# optimization_procedure=optim_cifar10_resnet18, +# optim_recipe=optim_cifar10_resnet18, # ensemble=False, # calibration_set=dm.get_val_set, # **vars(args), @@ -295,7 +295,7 @@ def test_classification_failures(self): # num_classes=dm.num_classes, # in_channels=dm.num_channels, # loss=nn.BCEWithLogitsLoss, -# optimization_procedure=optim_cifar10_resnet18, +# optim_recipe=optim_cifar10_resnet18, # ensemble=True, # **vars(args), # ) @@ -316,7 +316,7 @@ def test_classification_failures(self): # num_classes=dm.num_classes, # in_channels=dm.num_channels, # loss=nn.CrossEntropyLoss, -# optimization_procedure=optim_cifar10_resnet18, +# optim_recipe=optim_cifar10_resnet18, # ensemble=True, # **vars(args), # ) diff --git a/tests/routines/test_regression.py b/tests/routines/test_regression.py index 1ef3387d..a00c2541 100644 --- a/tests/routines/test_regression.py +++ b/tests/routines/test_regression.py @@ -6,7 +6,7 @@ from tests._dummies import DummyRegressionBaseline, DummyRegressionDataModule from torch_uncertainty.losses import DistributionNLL -from torch_uncertainty.optimization_procedures import optim_cifar10_resnet18 +from torch_uncertainty.optim_recipes import optim_cifar10_resnet18 from torch_uncertainty.routines import RegressionRoutine @@ -24,7 +24,7 @@ def test_one_estimator_one_output(self): in_features=dm.in_features, num_outputs=1, loss=DistributionNLL, - optimization_procedure=optim_cifar10_resnet18, + optim_recipe=optim_cifar10_resnet18, baseline_type="single", ) @@ -38,7 +38,7 @@ def test_one_estimator_one_output(self): in_features=dm.in_features, num_outputs=1, loss=DistributionNLL, - optimization_procedure=optim_cifar10_resnet18, + optim_recipe=optim_cifar10_resnet18, baseline_type="single", ) @@ -58,7 +58,7 @@ def test_one_estimator_two_outputs(self): in_features=dm.in_features, num_outputs=2, loss=DistributionNLL, - optimization_procedure=optim_cifar10_resnet18, + optim_recipe=optim_cifar10_resnet18, baseline_type="single", dist_type="laplace", ) @@ -71,7 +71,7 @@ def test_one_estimator_two_outputs(self): in_features=dm.in_features, num_outputs=2, loss=DistributionNLL, - optimization_procedure=optim_cifar10_resnet18, + optim_recipe=optim_cifar10_resnet18, baseline_type="single", ) trainer.fit(model, dm) @@ -89,7 +89,7 @@ def test_two_estimators_one_output(self): in_features=dm.in_features, num_outputs=1, loss=DistributionNLL, - optimization_procedure=optim_cifar10_resnet18, + optim_recipe=optim_cifar10_resnet18, baseline_type="ensemble", dist_type="laplace", ) @@ -102,7 +102,7 @@ def test_two_estimators_one_output(self): in_features=dm.in_features, num_outputs=1, loss=DistributionNLL, - optimization_procedure=optim_cifar10_resnet18, + optim_recipe=optim_cifar10_resnet18, baseline_type="ensemble", ) trainer.fit(model, dm) @@ -120,7 +120,7 @@ def test_two_estimators_two_outputs(self): in_features=dm.in_features, num_outputs=2, loss=DistributionNLL, - optimization_procedure=optim_cifar10_resnet18, + optim_recipe=optim_cifar10_resnet18, baseline_type="ensemble", ) trainer.fit(model, dm) @@ -133,7 +133,7 @@ def test_two_estimators_two_outputs(self): in_features=dm.in_features, num_outputs=2, loss=DistributionNLL, - optimization_procedure=optim_cifar10_resnet18, + optim_recipe=optim_cifar10_resnet18, baseline_type="ensemble", ) trainer.fit(model, dm) diff --git a/tests/test_cli.py b/tests/test_cli.py index cd584b1f..c67391cc 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -9,7 +9,7 @@ # from torch_uncertainty.baselines.classification import VGG, ResNet, WideResNet # from torch_uncertainty.baselines.regression import MLP # from torch_uncertainty.datamodules import CIFAR10DataModule, UCIDataModule -# from torch_uncertainty.optimization_procedures import ( +# from torch_uncertainty.optim_recipes import ( # optim_cifar10_resnet18, # optim_cifar10_vgg16, # optim_cifar10_wideresnet, @@ -40,7 +40,7 @@ # in_channels=dm.num_channels, # style="cifar", # loss=nn.CrossEntropyLoss, -# optimization_procedure=optim_cifar10_resnet18, +# optim_recipe=optim_cifar10_resnet18, # **vars(args), # ) @@ -80,7 +80,7 @@ # in_channels=dm.num_channels, # style="cifar", # loss=nn.CrossEntropyLoss, -# optimization_procedure=optim_cifar10_resnet18, +# optim_recipe=optim_cifar10_resnet18, # **vars(args), # ) @@ -101,7 +101,7 @@ # num_classes=dm.num_classes, # in_channels=dm.num_channels, # loss=nn.CrossEntropyLoss, -# optimization_procedure=optim_cifar10_wideresnet, +# optim_recipe=optim_cifar10_wideresnet, # **vars(args), # ) @@ -122,7 +122,7 @@ # num_classes=dm.num_classes, # in_channels=dm.num_channels, # loss=nn.CrossEntropyLoss, -# optimization_procedure=optim_cifar10_vgg16, +# optim_recipe=optim_cifar10_vgg16, # **vars(args), # ) @@ -145,7 +145,7 @@ # hidden_dims=[], # dist_estimation=1, # loss=nn.MSELoss, -# optimization_procedure=optim_regression, +# optim_recipe=optim_regression, # **vars(args), # ) @@ -173,7 +173,7 @@ # hidden_dims=[], # dist_estimation=1, # loss=nn.MSELoss, -# optimization_procedure=optim_regression, +# optim_recipe=optim_regression, # **vars(args), # ) # with pytest.raises(ValueError): @@ -208,7 +208,7 @@ # in_channels=list_dm[i].dm.num_channels, # style="cifar", # loss=nn.CrossEntropyLoss, -# optimization_procedure=optim_cifar10_resnet18, +# optim_recipe=optim_cifar10_resnet18, # **vars(args), # ) # for i in range(len(list_dm)) @@ -245,7 +245,7 @@ # in_channels=list_dm[i].dm.num_channels, # style="cifar", # loss=nn.CrossEntropyLoss, -# optimization_procedure=optim_cifar10_resnet18, +# optim_recipe=optim_cifar10_resnet18, # **vars(args), # ) # ) @@ -281,7 +281,7 @@ # in_channels=list_dm[i].dm.num_channels, # style="cifar", # loss=nn.CrossEntropyLoss, -# optimization_procedure=optim_cifar10_resnet18, +# optim_recipe=optim_cifar10_resnet18, # **vars(args), # ) # ) @@ -317,7 +317,7 @@ # in_channels=list_dm[i].dm.num_channels, # style="cifar", # loss=nn.CrossEntropyLoss, -# optimization_procedure=optim_cifar10_resnet18, +# optim_recipe=optim_cifar10_resnet18, # **vars(args), # ) # ) @@ -353,7 +353,7 @@ # in_channels=list_dm[i].dm.num_channels, # style="cifar", # loss=nn.CrossEntropyLoss, -# optimization_procedure=optim_cifar10_resnet18, +# optim_recipe=optim_cifar10_resnet18, # **vars(args), # ) # ) diff --git a/tests/test_optimization_procedures.py b/tests/test_optimization_procedures.py index 8a120057..74522afc 100644 --- a/tests/test_optimization_procedures.py +++ b/tests/test_optimization_procedures.py @@ -4,7 +4,7 @@ from torch_uncertainty.models.resnet import resnet18, resnet34, resnet50 from torch_uncertainty.models.vgg import vgg16 from torch_uncertainty.models.wideresnet import wideresnet28x10 -from torch_uncertainty.optimization_procedures import ( +from torch_uncertainty.optim_recipes import ( get_procedure, optim_regression, ) diff --git a/torch_uncertainty/baselines/classification/deep_ensembles.py b/torch_uncertainty/baselines/classification/deep_ensembles.py index 727cd696..f3dc8efb 100644 --- a/torch_uncertainty/baselines/classification/deep_ensembles.py +++ b/torch_uncertainty/baselines/classification/deep_ensembles.py @@ -42,7 +42,7 @@ def __init__( checkpoint_path=ckpt_file, hparams_file=hparams_file, loss=None, - optimization_procedure=None, + optim_recipe=None, ).eval() models.append(trained_model.model) diff --git a/torch_uncertainty/baselines/classification/resnet.py b/torch_uncertainty/baselines/classification/resnet.py index 21e5c825..1144dfde 100644 --- a/torch_uncertainty/baselines/classification/resnet.py +++ b/torch_uncertainty/baselines/classification/resnet.py @@ -142,7 +142,7 @@ def __init__( num_classes (int): Number of classes to predict. in_channels (int): Number of input channels. loss (nn.Module): Training loss. - optimization_procedure (Any): Optimization procedure, corresponds to + optim_recipe (Any): optimization recipe, corresponds to what expect the `LightningModule.configure_optimizers() `_ method. diff --git a/torch_uncertainty/baselines/classification/wideresnet.py b/torch_uncertainty/baselines/classification/wideresnet.py index c1734a4a..9b46a19d 100644 --- a/torch_uncertainty/baselines/classification/wideresnet.py +++ b/torch_uncertainty/baselines/classification/wideresnet.py @@ -70,7 +70,7 @@ def __init__( num_classes (int): Number of classes to predict. in_channels (int): Number of input channels. loss (nn.Module): Training loss. - optimization_procedure (Any): Optimization procedure, corresponds to + optim_recipe (Any): optimization recipe, corresponds to what expect the `LightningModule.configure_optimizers() `_ method. diff --git a/torch_uncertainty/optimization_procedures.py b/torch_uncertainty/optim_recipes.py similarity index 95% rename from torch_uncertainty/optimization_procedures.py rename to torch_uncertainty/optim_recipes.py index 07d7f21d..9782a5e1 100644 --- a/torch_uncertainty/optimization_procedures.py +++ b/torch_uncertainty/optim_recipes.py @@ -270,7 +270,7 @@ def optim_cifar100_resnet34( def optim_tinyimagenet_resnet34( model: nn.Module, ) -> dict[str, Optimizer | LRScheduler]: - """Optimization procedure from 'The Devil is in the Margin: Margin-based + """Optimization recipe from 'The Devil is in the Margin: Margin-based Label Smoothing for Network Calibration', (CVPR 2022, https://arxiv.org/abs/2111.15430): 'We train for 100 epochs with a learning rate of 0.1 for the first @@ -295,7 +295,7 @@ def optim_tinyimagenet_resnet34( def optim_tinyimagenet_resnet50( model: nn.Module, ) -> dict[str, Optimizer | LRScheduler]: - """Optimization procedure from 'The Devil is in the Margin: Margin-based + """Optimization recipe from 'The Devil is in the Margin: Margin-based Label Smoothing for Network Calibration', (CVPR 2022, https://arxiv.org/abs/2111.15430): 'We train for 100 epochs with a learning rate of 0.1 for the first @@ -332,10 +332,8 @@ def optim_regression( } -def batch_ensemble_wrapper( - model: nn.Module, optimization_procedure: Callable -) -> dict: - procedure = optimization_procedure(model) +def batch_ensemble_wrapper(model: nn.Module, optim_recipe: Callable) -> dict: + procedure = optim_recipe(model) param_optimizer = procedure["optimizer"] scheduler = procedure["lr_scheduler"] @@ -379,7 +377,7 @@ def get_procedure( method: str = "", imagenet_recipe: str | None = None, ) -> Callable: - """Get the optimization procedure for a given architecture and dataset. + """Get the optimization recipe for a given architecture and dataset. Args: arch_name (str): The name of the architecture. @@ -389,7 +387,7 @@ def get_procedure( ImageNet. Defaults to None. Returns: - callable: The optimization procedure. + callable: The optimization recipe. """ if arch_name in ["resnet18", "resnet20"]: if ds_name == "cifar10": @@ -437,8 +435,6 @@ def get_procedure( raise NotImplementedError(f"No recipe for architecture: {arch_name}.") if method == "batched": - procedure = partial( - batch_ensemble_wrapper, optimization_procedure=procedure - ) + procedure = partial(batch_ensemble_wrapper, optim_recipe=procedure) return procedure diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 6544d3a5..4de13148 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -43,7 +43,7 @@ def __init__( loss: type[nn.Module], num_estimators: int = 1, format_batch_fn: nn.Module | None = None, - optimization_procedure=None, + optim_recipe=None, mixtype: str = "erm", mixmode: str = "elem", dist_sim: str = "emb", @@ -70,7 +70,7 @@ def __init__( ensemble. Defaults to 1. format_batch_fn (nn.Module, optional): Function to format the batch. Defaults to :class:`torch.nn.Identity()`. - optimization_procedure (optional): Training recipe. Defaults to None. + optim_recipe (optional): Training recipe. Defaults to None. mixtype (str, optional): Mixup type. Defaults to ``"erm"``. mixmode (str, optional): Mixup mode. Defaults to ``"elem"``. dist_sim (str, optional): Distance similarity. Defaults to ``"emb"``. @@ -164,7 +164,7 @@ def __init__( self.model = model self.loss = loss self.format_batch_fn = format_batch_fn - self.optimization_procedure = optimization_procedure + self.optim_recipe = optim_recipe # metrics if self.binary_cls: @@ -304,7 +304,7 @@ def init_mixup( return nn.Identity() def configure_optimizers(self): - return self.optimization_procedure(self.model) + return self.optim_recipe(self.model) def on_train_start(self) -> None: init_metrics = {k: 0 for k in self.val_cls_metrics} diff --git a/torch_uncertainty/routines/regression.py b/torch_uncertainty/routines/regression.py index f846228a..86f3ae2b 100644 --- a/torch_uncertainty/routines/regression.py +++ b/torch_uncertainty/routines/regression.py @@ -23,7 +23,7 @@ def __init__( loss: type[nn.Module], num_estimators: int = 1, format_batch_fn: nn.Module | None = None, - optimization_procedure=None, + optim_recipe=None, ) -> None: """Regression routine for PyTorch Lightning. @@ -37,7 +37,7 @@ def __init__( ensemble. Defaults to 1. format_batch_fn (nn.Module, optional): The function to format the batch. Defaults to None. - optimization_procedure (optional): The optimization procedure + optim_recipe (optional): The optimization recipe to use. Defaults to None. Warning: @@ -45,7 +45,7 @@ def __init__( distribution _`. Warning: - You must define :attr:`optimization_procedure` if you do not use + You must define :attr:`optim_recipe` if you do not use the CLI. """ super().__init__() @@ -88,10 +88,10 @@ def __init__( if num_outputs == 1: self.one_dim_regression = True - self.optimization_procedure = optimization_procedure + self.optim_recipe = optim_recipe def configure_optimizers(self): - return self.optimization_procedure(self.model) + return self.optim_recipe(self.model) def on_train_start(self) -> None: # hyperparameters for performances From 2472ba1c1b54a4964bec1e63b5113b067e45fed3 Mon Sep 17 00:00:00 2001 From: Olivier Date: Tue, 19 Mar 2024 13:54:58 +0100 Subject: [PATCH 062/148] :sparkles: Add Normal Inverse Gamma dist. --- torch_uncertainty/datasets/regression/toy.py | 4 +- torch_uncertainty/layers/distributions.py | 23 +++++ torch_uncertainty/losses.py | 85 +++++++++-------- torch_uncertainty/models/mlp.py | 2 +- torch_uncertainty/utils/distributions.py | 99 +++++++++++++++++++- 5 files changed, 168 insertions(+), 45 deletions(-) diff --git a/torch_uncertainty/datasets/regression/toy.py b/torch_uncertainty/datasets/regression/toy.py index 161ddb8f..d5c3cdfb 100644 --- a/torch_uncertainty/datasets/regression/toy.py +++ b/torch_uncertainty/datasets/regression/toy.py @@ -28,6 +28,6 @@ def __init__( samples = torch.linspace( lower_bound, upper_bound, num_samples - ).unsqueeze(1) + ).unsqueeze(-1) targets = samples**3 + torch.normal(*noise, size=samples.size()) - super().__init__(samples, targets) + super().__init__(samples, targets.squeeze(-1)) diff --git a/torch_uncertainty/layers/distributions.py b/torch_uncertainty/layers/distributions.py index f7514c69..cf08278d 100644 --- a/torch_uncertainty/layers/distributions.py +++ b/torch_uncertainty/layers/distributions.py @@ -2,6 +2,8 @@ from torch import Tensor, nn from torch.distributions import Distribution, Laplace, Normal +from torch_uncertainty.utils.distributions import NormalInverseGamma + class AbstractDistLayer(nn.Module): def __init__(self, dim: int) -> None: @@ -54,3 +56,24 @@ def forward(self, x: Tensor) -> Laplace: loc = x[:, : self.dim] scale = F.softplus(x[:, self.dim :]) + self.min_scale return Laplace(loc, scale) + + +class IndptNormalInverseGammaLayer(AbstractDistLayer): + def __init__(self, dim: int, eps: float = 1e-6) -> None: + super().__init__(dim) + self.eps = eps + + def forward(self, x: Tensor) -> Laplace: + """Forward pass of the independent Laplace distribution layer. + + Args: + x (Tensor): The input tensor of shape (dx2). + + Returns: + Laplace: The independent Laplace distribution. + """ + loc = x[:, : self.dim] + lmbda = F.softplus(x[:, self.dim : 2 * self.dim]) + self.eps + alpha = 1 + F.softplus(x[:, 2 * self.dim : 3 * self.dim]) + self.eps + beta = F.softplus(x[:, 3 * self.dim :]) + self.eps + return NormalInverseGamma(loc, lmbda, alpha, beta) diff --git a/torch_uncertainty/losses.py b/torch_uncertainty/losses.py index 6e60153a..31b51afe 100644 --- a/torch_uncertainty/losses.py +++ b/torch_uncertainty/losses.py @@ -1,8 +1,34 @@ +from typing import Literal + import torch from torch import Tensor, distributions, nn from torch.nn import functional as F -from .layers.bayesian import bayesian_modules +from torch_uncertainty.layers.bayesian import bayesian_modules +from torch_uncertainty.utils.distributions import NormalInverseGamma + + +class DistributionNLL(nn.Module): + def __init__( + self, reduction: Literal["mean", "sum"] | None = "mean" + ) -> None: + """Negative Log-Likelihood loss for a given distribution. + + Args: + reduction (str, optional): specifies the reduction to apply to the + output:``'none'`` | ``'mean'`` | ``'sum'``. Defaults to "mean". + + """ + super().__init__() + self.reduction = reduction + + def forward(self, dist: distributions.Distribution, target: Tensor): + loss = -dist.log_prob(target) + if self.reduction == "mean": + loss = loss.mean() + elif self.reduction == "sum": + loss = loss.sum() + return loss class KLDiv(nn.Module): @@ -98,12 +124,15 @@ def forward(self, inputs: Tensor, targets: Tensor) -> Tensor: return aggregated_elbo / self.num_samples -class NIGLoss(nn.Module): +class DERLoss(DistributionNLL): def __init__( self, reg_weight: float, reduction: str | None = "mean" ) -> None: """The Normal Inverse-Gamma loss. + This loss combines the negative log-likelihood loss of the normal + inverse gamma distribution and a weighted regularization term. + Args: reg_weight (float): The weight of the regularization term. reduction (str, optional): specifies the reduction to apply to the @@ -113,7 +142,11 @@ def __init__( Amini, A., Schwarting, W., Soleimany, A., & Rus, D. (2019). Deep evidential regression. https://arxiv.org/abs/1910.02600. """ - super().__init__() + super().__init__(reduction=None) + + if reduction not in (None, "none", "mean", "sum"): + raise ValueError(f"{reduction} is not a valid value for reduction.") + self.final_reduction = reduction if reg_weight < 0: raise ValueError( @@ -121,49 +154,24 @@ def __init__( f"{reg_weight}." ) self.reg_weight = reg_weight - if reduction not in ("none", "mean", "sum"): - raise ValueError(f"{reduction} is not a valid value for reduction.") - self.reduction = reduction - def _nig_nll( - self, - gamma: Tensor, - v: Tensor, - alpha: Tensor, - beta: Tensor, - targets: Tensor, - ) -> Tensor: - gam = 2 * beta * (1 + v) - return ( - 0.5 * torch.log(torch.pi / v) - - alpha * gam.log() - + (alpha + 0.5) * torch.log(gam + v * (targets - gamma) ** 2) - + torch.lgamma(alpha) - - torch.lgamma(alpha + 0.5) - ) - - def _nig_reg( - self, gamma: Tensor, v: Tensor, alpha: Tensor, targets: Tensor - ) -> Tensor: - return torch.norm(targets - gamma, 1, dim=1, keepdim=True) * ( - 2 * v + alpha + def _reg(self, dist: NormalInverseGamma, targets: Tensor) -> Tensor: + return torch.norm(targets - dist.loc, 1, dim=1, keepdim=True) * ( + 2 * dist.lmbda + dist.alpha ) def forward( self, - gamma: Tensor, - v: Tensor, - alpha: Tensor, - beta: Tensor, + dist: NormalInverseGamma, targets: Tensor, ) -> Tensor: - loss_nll = self._nig_nll(gamma, v, alpha, beta, targets) - loss_reg = self._nig_reg(gamma, v, alpha, targets) + loss_nll = super().forward(dist, targets) + loss_reg = self._reg(dist, targets) loss = loss_nll + self.reg_weight * loss_reg - if self.reduction == "mean": + if self.final_reduction == "mean": return loss.mean() - if self.reduction == "sum": + if self.final_reduction == "sum": return loss.sum() return loss @@ -376,8 +384,3 @@ def forward( elif self.reduction == "sum": loss = loss.sum() return loss - - -class DistributionNLL(nn.Module): - def forward(self, dist: distributions.Distribution, target: Tensor): - return -dist.log_prob(target).mean() diff --git a/torch_uncertainty/models/mlp.py b/torch_uncertainty/models/mlp.py index c6e82ca3..deac7f23 100644 --- a/torch_uncertainty/models/mlp.py +++ b/torch_uncertainty/models/mlp.py @@ -81,7 +81,7 @@ def forward(self, x: Tensor) -> Tensor: for layer in self.layers[:-1]: x = F.dropout(layer(x), p=self.dropout_rate, training=self.training) x = self.activation(x) - return self.layers[-1](x) + return self.final_layer(self.layers[-1](x)) @stochastic_model diff --git a/torch_uncertainty/utils/distributions.py b/torch_uncertainty/utils/distributions.py index c87617d2..2a1f3aa6 100644 --- a/torch_uncertainty/utils/distributions.py +++ b/torch_uncertainty/utils/distributions.py @@ -1,6 +1,10 @@ +from numbers import Number + import torch from einops import rearrange -from torch.distributions import Distribution, Laplace, Normal +from torch import Tensor +from torch.distributions import Distribution, Laplace, Normal, constraints +from torch.distributions.utils import broadcast_all def cat_dist(distributions: list[Distribution], dim: int) -> Distribution: @@ -48,6 +52,12 @@ def squeeze_dist(distribution: Distribution, dim: int) -> Distribution: loc = distribution.loc.squeeze(dim) scale = distribution.scale.squeeze(dim) return dist_type(loc=loc, scale=scale) + if isinstance(distribution, NormalInverseGamma): + loc = distribution.loc.squeeze(dim) + lmbda = distribution.lmbda.squeeze(dim) + alpha = distribution.alpha.squeeze(dim) + beta = distribution.beta.squeeze(dim) + return dist_type(loc=loc, lmbda=lmbda, alpha=alpha, beta=beta) raise NotImplementedError( f"Squeezing of {dist_type} distributions is not supported." "Raise an issue if needed." @@ -64,7 +74,94 @@ def to_ensemble_dist( distribution.scale, "(n b) c -> b n c", n=num_estimators ) return dist_type(loc=loc, scale=scale) + if isinstance(distribution, NormalInverseGamma): + loc = rearrange(distribution.loc, "(n b) c -> b n c", n=num_estimators) + lmbda = rearrange( + distribution.lmbda, "(n b) c -> b n c", n=num_estimators + ) + alpha = rearrange( + distribution.alpha, "(n b) c -> b n c", n=num_estimators + ) + beta = rearrange( + distribution.beta, "(n b) c -> b n c", n=num_estimators + ) + return dist_type(loc=loc, lmbda=lmbda, alpha=alpha, beta=beta) raise NotImplementedError( f"Ensemble distribution of {dist_type} is not supported." "Raise an issue if needed." ) + + +class NormalInverseGamma(Distribution): + arg_constraints = { + "loc": constraints.real, + "lmbda": constraints.positive, + "alpha": constraints.greater_than(1), + "beta": constraints.positive, + } + support = constraints.real + has_rsample = False + + def __init__(self, loc, lmbda, alpha, beta, validate_args=None): + self.loc, self.lmbda, self.alpha, self.beta = broadcast_all( + loc, lmbda, alpha, beta + ) + if ( + isinstance(loc, Number) + and isinstance(lmbda, Number) + and isinstance(alpha, Number) + and isinstance(beta, Number) + ): + batch_shape = torch.Size() + else: + batch_shape = self.loc.size() + super().__init__(batch_shape, validate_args=validate_args) + + @property + def mean(self): + """Impromper mean of the NormalInverseGamma distribution. + + This value is necessary to perform point-wise predictions in the + regression routine. + """ + return self.loc + + def mode(self): + raise NotImplementedError( + "Mode is not meaningful for the NormalInverseGamma distribution" + ) + + def stddev(self): + raise NotImplementedError( + "Standard deviation is not meaningful for the NormalInverseGamma distribution" + ) + + def variance(self): + raise NotImplementedError( + "Variance is not meaningful for the NormalInverseGamma distribution" + ) + + @property + def mean_loc(self) -> Tensor: + return self.loc + + @property + def mean_variance(self) -> Tensor: + return self.beta / (self.alpha - 1) + + @property + def variance_loc(self) -> Tensor: + return self.beta / (self.alpha - 1) / self.lmbda + + def log_prob(self, value: Tensor) -> Tensor: + if self._validate_args: + self._validate_sample(value) + gam: Tensor = 2 * self.beta * (1 + self.lmbda) + return ( + -0.5 * torch.log(torch.pi / self.lmbda) + + self.alpha * gam.log() + - (self.alpha + 0.5) + * torch.log(gam + self.lmbda * (value - self.loc) ** 2) + - torch.lgamma(self.alpha) + + torch.lgamma(self.alpha + 0.5) + ) From d470ab2d22797b748377243515d5576a77dc37aa Mon Sep 17 00:00:00 2001 From: Olivier Date: Tue, 19 Mar 2024 13:55:44 +0100 Subject: [PATCH 063/148] :book: Update the DER tutorial --- auto_tutorials_source/tutorial_der_cubic.py | 129 +++++++++----------- 1 file changed, 60 insertions(+), 69 deletions(-) diff --git a/auto_tutorials_source/tutorial_der_cubic.py b/auto_tutorials_source/tutorial_der_cubic.py index 3470689b..45cbe122 100644 --- a/auto_tutorials_source/tutorial_der_cubic.py +++ b/auto_tutorials_source/tutorial_der_cubic.py @@ -2,9 +2,14 @@ Deep Evidential Regression on a Toy Example =========================================== -This tutorial aims to provide an introductory overview of Deep Evidential Regression (DER) using a practical example. We demonstrate an application of DER by tackling the toy-problem of fitting :math:`y=x^3` using a Multi-Layer Perceptron (MLP) neural network model. The output layer of the MLP has four outputs, and is trained by minimizing the Normal Inverse-Gamma (NIG) loss function. +This tutorial provides an introduction to probabilistic regression in TorchUncertainty. -DER represents an evidential approach to quantifying uncertainty in neural network regression models. This method involves introducing prior distributions over the parameters of the Gaussian likelihood function. Then, the MLP model estimate the parameters of the evidential distribution. +More specifically, we present Deep Evidential Regression (DER) using a practical example. We demonstrate an application of DER by tackling the toy-problem of fitting :math:`y=x^3` using a Multi-Layer Perceptron (MLP) neural network model. +The output layer of the MLP provides a NormalInverseGamma distribution which is used to optimize the model, trhough its negative log-likelihood. + +DER represents an evidential approach to quantifying epistemic and aleatoric uncertainty in neural network regression models. +This method involves introducing prior distributions over the parameters of the Gaussian likelihood function. +Then, the MLP model estimates the parameters of this evidential distribution. Training a MLP with DER using TorchUncertainty models and PyTorch Lightning --------------------------------------------------------------------------- @@ -14,42 +19,36 @@ 1. Loading the utilities ~~~~~~~~~~~~~~~~~~~~~~~~ -To train a MLP with the NIG loss function using TorchUncertainty, we have to load the following utilities from TorchUncertainty: +To train a MLP with the DER loss function using TorchUncertainty, we have to load the following modules: -- the cli handler: cli_main and argument parser: init_args -- the model: mlp, which lies in the torch_uncertainty.baselines.regression.mlp module. -- the regression training routine in the torch_uncertainty.routines.regression module. -- the evidential objective: the NIGLoss, which lies in the torch_uncertainty.losses file -- a dataset that generates samples from a noisy cubic function: Cubic, which lies in the torch_uncertainty.datasets.regression -""" +- the Trainer from Lightning +- the model: mlp from torch_uncertainty.models.mlp +- the regression training routine from torch_uncertainty.routines +- the evidential objective: the DERLoss from torch_uncertainty.losses. This loss contains the classic NLL loss and a regularization term. +- a dataset that generates samples from a noisy cubic function: Cubic from torch_uncertainty.datasets.regression -from pytorch_lightning import LightningDataModule -from torch_uncertainty import cli_main, init_args -from torch_uncertainty.baselines.regression.mlp import mlp -from torch_uncertainty.datasets.regression.toy import Cubic -from torch_uncertainty.losses import NIGLoss -from torch_uncertainty.routines.regression import RegressionSingle +We also need to define an optimizer using torch.optim, the neural network utils within torch.nn, as well as the partial util to provide +the modified default arguments for the DER loss. +""" # %% -# We also need to define an optimizer using torch.optim as well as the -# neural network utils withing torch.nn, as well as the partial util to provide -# the modified default arguments for the NIG loss. -# -# We also import sys to override the command line arguments. - -import os -import sys from functools import partial -from pathlib import Path import torch +from lightning.pytorch import Trainer +from lightning import LightningDataModule from torch import nn, optim -# %% -# 2. Creating the Optimizer Wrapper -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -# We use the Adam optimizer with the default learning rate of 0.001. +from torch_uncertainty.models.mlp import mlp +from torch_uncertainty.datasets.regression.toy import Cubic +from torch_uncertainty.losses import DERLoss +from torch_uncertainty.routines import RegressionRoutine +from torch_uncertainty.layers.distributions import IndptNormalInverseGammaLayer +# %% +# 2. The Optimization Recipe +# ~~~~~~~~~~~~~~~~~~~~~~~~~~ +# We use the Adam optimizer with a rate of 5e-4. def optim_regression( model: nn.Module, @@ -69,85 +68,77 @@ def optim_regression( # 3. Creating the necessary variables # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -# In the following, we need to define the root of the logs, and to -# fake-parse the arguments needed for using the PyTorch Lightning Trainer. We -# also use the same synthetic regression task example as that used in the -# original DER paper. - -root = Path(os.path.abspath("")) +# In the following, we create a trainer to train the model, the same synthetic regression +# datasets as in the original DER paper and the model, a simple MLP with 2 hidden layers of 64 neurons each. +# Please note that this MLP finishes with a IndptNormalInverseGammaLayer that interpret the outputs of the model +# as the parameters of a Normal Inverse Gamma distribution. -# We mock the arguments for the trainer -sys.argv = ["file.py", "--max_epochs", "50", "--enable_progress_bar", "False"] -args = init_args() - -net_name = "logs/der-mlp-cubic" +trainer = Trainer(accelerator="cpu", max_epochs=50)#, enable_progress_bar=False) # dataset train_ds = Cubic(num_samples=1000) val_ds = Cubic(num_samples=300) -test_ds = train_ds # datamodule datamodule = LightningDataModule.from_datasets( - train_ds, val_dataset=val_ds, test_dataset=test_ds, batch_size=32 + train_ds, val_dataset=val_ds, test_dataset=val_ds, batch_size=32 ) datamodule.training_task = "regression" # model -model = mlp(in_features=1, num_outputs=4, hidden_dims=[64, 64]) +model = mlp( + in_features=1, + num_outputs=4, + hidden_dims=[64, 64], + final_layer=IndptNormalInverseGammaLayer, + final_layer_args={"dim": 1}, +) # %% # 4. The Loss and the Training Routine # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Next, we need to define the loss to be used during training. To do this, we -# redefine the default parameters for the NIG loss using the partial +# set the weight of the regularizer of the DER Loss in advance using the partial # function from functools. After that, we define the training routine using -# the regression training routine from torch_uncertainty.routines.regression. In -# this routine, we provide the model, the NIG loss, and the optimizer, -# along with the dist_estimation parameter, which refers to the number of -# distribution parameters, and all the default arguments. +# the probabilistic regression training routine from torch_uncertainty.routines. In +# this routine, we provide the model, the DER loss, and the optimization recipe. loss = partial( - NIGLoss, + DERLoss, reg_weight=1e-2, ) -baseline = RegressionSingle( +routine = RegressionRoutine( + probabilistic=True, + num_outputs=1, model=model, loss=loss, - optimization_procedure=optim_regression, - dist_estimation=4, - **vars(args), + optim_recipe=optim_regression, ) # %% # 5. Gathering Everything and Training the Model # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# Finally, we train the model using the trainer and the regression routine. We also +# test the model using the same trainer -results = cli_main(baseline, datamodule, root, net_name, args) +trainer.fit(model=routine, datamodule=datamodule) +trainer.test(model=routine, datamodule=datamodule) # %% # 6. Testing the Model # ~~~~~~~~~~~~~~~~~~~~ +# We can now test the model by plotting the predictions and the uncertainty estimates. +# In this specific case, we can reproduce the results of the paper. import matplotlib.pyplot as plt -from torch.nn import functional as F with torch.no_grad(): - x = torch.linspace(-7, 7, 1000).unsqueeze(-1) - - logits = model(x) - means, v, alpha, beta = logits.split(1, dim=-1) - - v = F.softplus(v) - alpha = 1 + F.softplus(alpha) - beta = F.softplus(beta) - - vars = torch.sqrt(beta / (v * (alpha - 1))) + x = torch.linspace(-7, 7, 1000) - means.squeeze_(1) - vars.squeeze_(1) - x.squeeze_(1) + dists = model(x.unsqueeze(-1)) + means = dists.loc.squeeze(1) + variances = torch.sqrt(dists.variance_loc).squeeze(1) fig, ax = plt.subplots(1, 1) ax.plot(x, x**3, "--r", label="ground truth", zorder=3) @@ -155,8 +146,8 @@ def optim_regression( for k in torch.linspace(0, 4, 4): ax.fill_between( x, - means - k * vars, - means + k * vars, + means - k * variances, + means + k * variances, linewidth=0, alpha=0.3, edgecolor=None, From 344832a04374fbfbb6ba2eecd66cb6028a13afc2 Mon Sep 17 00:00:00 2001 From: Olivier Date: Tue, 19 Mar 2024 14:07:08 +0100 Subject: [PATCH 064/148] :heavy_check_mark: Adapt tests to new implementation --- tests/test_losses.py | 39 ++++++++++++++---------- torch_uncertainty/utils/distributions.py | 2 +- 2 files changed, 24 insertions(+), 17 deletions(-) diff --git a/tests/test_losses.py b/tests/test_losses.py index e907b7ba..24eeb06a 100644 --- a/tests/test_losses.py +++ b/tests/test_losses.py @@ -5,7 +5,8 @@ from torch import nn from torch_uncertainty.layers.bayesian import BayesLinear -from torch_uncertainty.losses import BetaNLL, DECLoss, ELBOLoss, NIGLoss +from torch_uncertainty.layers.distributions import NormalInverseGamma +from torch_uncertainty.losses import BetaNLL, DECLoss, DERLoss, ELBOLoss class TestELBOLoss: @@ -44,44 +45,50 @@ def test_no_bayes(self): class TestNIGLoss: - """Testing the NIGLoss class.""" + """Testing the DERLoss class.""" def test_main(self): - loss = NIGLoss(reg_weight=1e-2) - - inputs = torch.tensor([[1.0, 1.0, 1.0, 1.0]], dtype=torch.float32) + loss = DERLoss(reg_weight=1e-2) + layer = NormalInverseGamma + inputs = layer( + torch.ones(1), torch.ones(1), torch.ones(1), torch.ones(1) + ) targets = torch.tensor([[1.0]], dtype=torch.float32) - assert loss(*inputs.split(1, dim=-1), targets) == pytest.approx( - 2 * math.log(2) - ) + assert loss(inputs, targets) == pytest.approx(2 * math.log(2)) - loss = NIGLoss( + loss = DERLoss( reg_weight=1e-2, reduction="sum", ) + inputs = layer( + torch.ones((2, 1)), + torch.ones((2, 1)), + torch.ones((2, 1)), + torch.ones((2, 1)), + ) assert loss( - *inputs.repeat(2, 1).split(1, dim=-1), - targets.repeat(2, 1), + inputs, + targets, ) == pytest.approx(4 * math.log(2)) - loss = NIGLoss( + loss = DERLoss( reg_weight=1e-2, reduction="none", ) assert loss( - *inputs.repeat(2, 1).split(1, dim=-1), - targets.repeat(2, 1), + inputs, + targets, ) == pytest.approx([2 * math.log(2), 2 * math.log(2)]) def test_failures(self): with pytest.raises(ValueError): - NIGLoss(reg_weight=-1) + DERLoss(reg_weight=-1) with pytest.raises(ValueError): - NIGLoss(reg_weight=1.0, reduction="median") + DERLoss(reg_weight=1.0, reduction="median") class TestDECLoss: diff --git a/torch_uncertainty/utils/distributions.py b/torch_uncertainty/utils/distributions.py index 2a1f3aa6..26422a35 100644 --- a/torch_uncertainty/utils/distributions.py +++ b/torch_uncertainty/utils/distributions.py @@ -96,7 +96,7 @@ class NormalInverseGamma(Distribution): arg_constraints = { "loc": constraints.real, "lmbda": constraints.positive, - "alpha": constraints.greater_than(1), + "alpha": constraints.greater_than_eq(1), "beta": constraints.positive, } support = constraints.real From b2399fb24eb2ab665e5b6456e0efbc7d5aa389e0 Mon Sep 17 00:00:00 2001 From: Olivier Date: Tue, 19 Mar 2024 15:13:51 +0100 Subject: [PATCH 065/148] :shirt: Small improvement of the scaling tutorial --- auto_tutorials_source/tutorial_scaler.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/auto_tutorials_source/tutorial_scaler.py b/auto_tutorials_source/tutorial_scaler.py index b67f21ca..a4f6fae9 100644 --- a/auto_tutorials_source/tutorial_scaler.py +++ b/auto_tutorials_source/tutorial_scaler.py @@ -3,10 +3,9 @@ ====================================================== In this tutorial, we use *TorchUncertainty* to improve the calibration -of the top-label predictions -and the reliability of the underlying neural network. +of the top-label predictions and the reliability of the underlying neural network. -We also see how to use the datamodules outside any Lightning trainers, +We also see how to use the datamodules outside any Lightning trainers, and how to use TorchUncertainty's models. 1. Loading the Utilities @@ -17,11 +16,12 @@ - torch for its objects - the "calibration error" metric to compute the ECE and evaluate the top-label calibration - the CIFAR-100 datamodule to handle the data -- a ResNet 18 as starting model +- a ResNet 18 as starting model - the temperature scaler to improve the top-label calibration - a utility to download hf models easily -- the calibration plot to visualize the calibration. If you use the classification routine, - the plots will be automatically available in the tensorboard logs. +- the calibration plot to visualize the calibration. + +If you use the classification routine, the plots will be automatically available in the tensorboard logs. """ from torch_uncertainty.datamodules import CIFAR100DataModule @@ -52,7 +52,8 @@ # # To get the dataloader from the datamodule, just call prepare_data, setup, and # extract the first element of the test dataloader list. There are more than one -# element if `:attr:eval_ood` is True. +# element if eval_ood is True: the dataloader of in-distribution data and the dataloader +# of out-of-distribution data. Otherwise, it is a list of 1 element. dm = CIFAR100DataModule(root="./data", eval_ood=False, batch_size=32) dm.prepare_data() @@ -61,7 +62,6 @@ # Get the full test dataloader (unused in this tutorial) dataloader = dm.test_dataloader()[0] - # %% # 4. Iterating on the Dataloader and Computing the ECE # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -93,8 +93,7 @@ ece.update(probs, target) # Compute & print the calibration error -cal = ece.compute() -print(f"ECE before scaling - {cal*100:.3}%.") +print(f"ECE before scaling - {ece.compute():.3%}.") # %% # We also compute and plot the top-label calibration figure. We see that the @@ -133,8 +132,7 @@ probs = logits.softmax(-1) ece.update(probs, target) -cal = ece.compute() -print(f"ECE after scaling - {cal*100:.3}%.") +print(f"ECE after scaling - {ece.compute():.3%}.") # %% # We finally compute and plot the scaled top-label calibration figure. We see From dc86e5a86f29f56156a4464efb24735d36891402 Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 20 Mar 2024 12:15:08 +0100 Subject: [PATCH 066/148] :white_check_mark: Small coverage improvement --- tests/_dummies/baseline.py | 25 ++++++++---- tests/routines/test_regression.py | 2 +- tests/test_utils.py | 47 +++++++++++++++++------ torch_uncertainty/layers/distributions.py | 4 +- torch_uncertainty/utils/distributions.py | 14 +++++++ 5 files changed, 70 insertions(+), 22 deletions(-) diff --git a/tests/_dummies/baseline.py b/tests/_dummies/baseline.py index d8097f1f..f51f8c96 100644 --- a/tests/_dummies/baseline.py +++ b/tests/_dummies/baseline.py @@ -5,6 +5,7 @@ from torch_uncertainty.layers.distributions import ( IndptLaplaceLayer, + IndptNormalInverseGammaLayer, IndptNormalLayer, ) from torch_uncertainty.models.deep_ensembles import deep_ensembles @@ -75,17 +76,25 @@ def __new__( optim_recipe=None, dist_type: str = "normal", ) -> LightningModule: + if probabilistic: + if dist_type == "normal": + last_layer = IndptNormalLayer(num_outputs) + num_classes = num_outputs * 2 + elif dist_type == "laplace": + last_layer = IndptLaplaceLayer(num_outputs) + num_classes = num_outputs * 2 + else: # dist_type == "nig" + last_layer = IndptNormalInverseGammaLayer(num_outputs) + num_classes = num_outputs * 4 + else: + last_layer = nn.Identity() + num_classes = num_outputs + model = dummy_model( in_channels=in_features, - num_classes=num_outputs * 2 if probabilistic else num_outputs, + num_classes=num_classes, num_estimators=1, - last_layer=( - IndptNormalLayer(num_outputs) - if dist_type == "normal" - else IndptLaplaceLayer(num_outputs) - ) - if probabilistic - else nn.Identity(), + last_layer=last_layer, ) if baseline_type == "single": return RegressionRoutine( diff --git a/tests/routines/test_regression.py b/tests/routines/test_regression.py index a00c2541..4de12166 100644 --- a/tests/routines/test_regression.py +++ b/tests/routines/test_regression.py @@ -91,7 +91,7 @@ def test_two_estimators_one_output(self): loss=DistributionNLL, optim_recipe=optim_cifar10_resnet18, baseline_type="ensemble", - dist_type="laplace", + dist_type="nig", ) trainer.fit(model, dm) trainer.test(model, dm) diff --git a/tests/test_utils.py b/tests/test_utils.py index 8fc5b131..2d8c0e1b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,35 +1,60 @@ from pathlib import Path import pytest +import torch +from torch.distributions import Laplace, Normal -from torch_uncertainty import utils +from torch_uncertainty.utils import distributions, get_version, hub class TestUtils: """Testing utils methods.""" - def test_getversion_log_success(self): - utils.get_version("tests/testlog", version=42) - utils.get_version(Path("tests/testlog"), version=42) + def test_get_version_log_success(self): + get_version("tests/testlog", version=42) + get_version(Path("tests/testlog"), version=42) - utils.get_version("tests/testlog", version=42, checkpoint=45) + get_version("tests/testlog", version=42, checkpoint=45) def test_getversion_log_failure(self): with pytest.raises(Exception): - utils.get_version("tests/testlog", version=52) + get_version("tests/testlog", version=52) class TestHub: """Testing hub methods.""" def test_hub_exists(self): - utils.hub.load_hf("test") - utils.hub.load_hf("test", version=1) - utils.hub.load_hf("test", version=2) + hub.load_hf("test") + hub.load_hf("test", version=1) + hub.load_hf("test", version=2) def test_hub_notexists(self): with pytest.raises(Exception): - utils.hub.load_hf("tests") + hub.load_hf("tests") with pytest.raises(ValueError): - utils.hub.load_hf("test", version=42) + hub.load_hf("test", version=42) + + +class TestDistributions: + """Testing distributions methods.""" + + def test_nig(self): + dist = distributions.NormalInverseGamma( + torch.tensor(0.0), + torch.tensor(1.1), + torch.tensor(1.1), + torch.tensor(1.1), + ) + _ = dist.mean, dist.mean_loc, dist.mean_variance, dist.variance_loc + + def test_errors(self): + with pytest.raises(ValueError): + distributions.cat_dist( + [ + Normal(torch.tensor([0.0]), torch.tensor([1.0])), + Laplace(torch.tensor([0.0]), torch.tensor([1.0])), + ], + dim=0, + ) diff --git a/torch_uncertainty/layers/distributions.py b/torch_uncertainty/layers/distributions.py index cf08278d..70667299 100644 --- a/torch_uncertainty/layers/distributions.py +++ b/torch_uncertainty/layers/distributions.py @@ -17,7 +17,7 @@ def forward(self, x: Tensor) -> Distribution: class IndptNormalLayer(AbstractDistLayer): - def __init__(self, dim: int, min_scale: float = 1e-3) -> None: + def __init__(self, dim: int, min_scale: float = 1e-6) -> None: super().__init__(dim) if min_scale <= 0: raise ValueError(f"min_scale must be positive, got {min_scale}.") @@ -38,7 +38,7 @@ def forward(self, x: Tensor) -> Normal: class IndptLaplaceLayer(AbstractDistLayer): - def __init__(self, dim: int, min_scale: float = 1e-3) -> None: + def __init__(self, dim: int, min_scale: float = 1e-6) -> None: super().__init__(dim) if min_scale <= 0: raise ValueError(f"min_scale must be positive, got {min_scale}.") diff --git a/torch_uncertainty/utils/distributions.py b/torch_uncertainty/utils/distributions.py index 26422a35..fd6fd39c 100644 --- a/torch_uncertainty/utils/distributions.py +++ b/torch_uncertainty/utils/distributions.py @@ -31,6 +31,20 @@ def cat_dist(distributions: list[Distribution], dim: int) -> Distribution: [distribution.scale for distribution in distributions], dim=dim ) return dist_type(loc=locs, scale=scales) + if isinstance(distributions[0], NormalInverseGamma): + locs = torch.cat( + [distribution.loc for distribution in distributions], dim=dim + ) + lmbdas = torch.cat( + [distribution.lmbda for distribution in distributions], dim=dim + ) + alphas = torch.cat( + [distribution.alpha for distribution in distributions], dim=dim + ) + betas = torch.cat( + [distribution.beta for distribution in distributions], dim=dim + ) + return dist_type(loc=locs, lmbda=lmbdas, alpha=alphas, beta=betas) raise NotImplementedError( f"Concatenation of {dist_type} distributions is not supported." "Raise an issue if needed." From bf220ccb83094836acdeeae04d3895c7808a4b57 Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 20 Mar 2024 12:28:13 +0100 Subject: [PATCH 067/148] :hammer: Add 'Baseline' to the baseline class names --- .../tutorial_mc_batch_norm.py | 3 +-- auto_tutorials_source/tutorial_mc_dropout.py | 3 +-- .../classification/cifar10/deep_ensembles.py | 6 ++--- experiments/classification/cifar10/resnet.py | 4 +-- experiments/classification/cifar10/vgg.py | 4 +-- .../classification/cifar10/wideresnet.py | 4 +-- .../classification/cifar100/deep_ensembles.py | 6 ++--- experiments/classification/cifar100/resnet.py | 6 ++--- experiments/classification/cifar100/vgg.py | 6 ++--- .../classification/cifar100/wideresnet.py | 6 ++--- .../classification/tiny-imagenet/resnet.py | 8 +++--- .../regression/uci_datasets/deep_ensemble.py | 6 ++--- tests/baselines/test_batched.py | 17 ++++++------- tests/baselines/test_deep_ensembles.py | 4 +-- tests/baselines/test_masked.py | 21 +++++++--------- tests/baselines/test_mc_dropout.py | 25 +++++++++---------- tests/baselines/test_mimo.py | 11 +++++--- tests/baselines/test_packed.py | 18 +++++++------ tests/baselines/test_standard.py | 24 ++++++++++-------- torch_uncertainty/baselines/__init__.py | 6 ----- .../baselines/classification/__init__.py | 6 ++--- .../classification/deep_ensembles.py | 10 ++++---- .../baselines/classification/resnet.py | 2 +- .../baselines/classification/vgg.py | 2 +- .../baselines/classification/wideresnet.py | 2 +- 25 files changed, 103 insertions(+), 107 deletions(-) diff --git a/auto_tutorials_source/tutorial_mc_batch_norm.py b/auto_tutorials_source/tutorial_mc_batch_norm.py index 16145683..1919b306 100644 --- a/auto_tutorials_source/tutorial_mc_batch_norm.py +++ b/auto_tutorials_source/tutorial_mc_batch_norm.py @@ -17,9 +17,8 @@ - the datamodule that handles dataloaders: MNISTDataModule, which lies in the torch_uncertainty.datamodule - the model: LeNet, which lies in torch_uncertainty.models - the mc-batch-norm wrapper: mc_dropout, which lies in torch_uncertainty.models -- a resnet baseline to get the command line arguments: ResNet, which lies in torch_uncertainty.baselines - the classification training routine in the torch_uncertainty.training.classification module -- the optimizer wrapper in the torch_uncertainty.optim_recipes module. +- an optimization recipe in the torch_uncertainty.optim_recipes module. We also need import the neural network utils withing `torch.nn`. """ diff --git a/auto_tutorials_source/tutorial_mc_dropout.py b/auto_tutorials_source/tutorial_mc_dropout.py index 49e56b67..4506ce7d 100644 --- a/auto_tutorials_source/tutorial_mc_dropout.py +++ b/auto_tutorials_source/tutorial_mc_dropout.py @@ -23,9 +23,8 @@ - the datamodule that handles dataloaders: MNISTDataModule, which lies in the torch_uncertainty.datamodule - the model: LeNet, which lies in torch_uncertainty.models - the mc-dropout wrapper: mc_dropout, which lies in torch_uncertainty.models -- a resnet baseline to get the command line arguments: ResNet, which lies in torch_uncertainty.baselines - the classification training routine in the torch_uncertainty.training.classification module -- the optimizer wrapper in the torch_uncertainty.optim_recipes module. +- an optimization recipe in the torch_uncertainty.optim_recipes module. We also need import the neural network utils within `torch.nn`. """ diff --git a/experiments/classification/cifar10/deep_ensembles.py b/experiments/classification/cifar10/deep_ensembles.py index f497092c..d7316811 100644 --- a/experiments/classification/cifar10/deep_ensembles.py +++ b/experiments/classification/cifar10/deep_ensembles.py @@ -1,11 +1,11 @@ from pathlib import Path from torch_uncertainty import cli_main, init_args -from torch_uncertainty.baselines import DeepEnsembles +from torch_uncertainty.baselines import DeepEnsemblesBaseline from torch_uncertainty.datamodules import CIFAR10DataModule if __name__ == "__main__": - args = init_args(DeepEnsembles, CIFAR10DataModule) + args = init_args(DeepEnsemblesBaseline, CIFAR10DataModule) if args.root == "./data/": root = Path(__file__).parent.absolute().parents[2] else: @@ -19,7 +19,7 @@ # model args.task = "classification" - model = DeepEnsembles( + model = DeepEnsemblesBaseline( **vars(args), num_classes=dm.num_classes, in_channels=dm.num_channels, diff --git a/experiments/classification/cifar10/resnet.py b/experiments/classification/cifar10/resnet.py index 5610bcbe..66fc6cc4 100644 --- a/experiments/classification/cifar10/resnet.py +++ b/experiments/classification/cifar10/resnet.py @@ -2,7 +2,7 @@ from lightning.pytorch.cli import LightningArgumentParser from lightning.pytorch.loggers import TensorBoardLogger # noqa: F401 -from torch_uncertainty.baselines.classification import ResNet +from torch_uncertainty.baselines.classification import ResNetBaseline from torch_uncertainty.datamodules import CIFAR10DataModule from torch_uncertainty.utils import TULightningCLI @@ -14,7 +14,7 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: def cli_main() -> ResNetCLI: - return ResNetCLI(ResNet, CIFAR10DataModule) + return ResNetCLI(ResNetBaseline, CIFAR10DataModule) if __name__ == "__main__": diff --git a/experiments/classification/cifar10/vgg.py b/experiments/classification/cifar10/vgg.py index 90af58c4..393ecbc0 100644 --- a/experiments/classification/cifar10/vgg.py +++ b/experiments/classification/cifar10/vgg.py @@ -2,7 +2,7 @@ from lightning.pytorch.cli import LightningArgumentParser from lightning.pytorch.loggers import TensorBoardLogger # noqa: F401 -from torch_uncertainty.baselines.classification import VGG +from torch_uncertainty.baselines.classification import VGGBaseline from torch_uncertainty.datamodules import CIFAR10DataModule from torch_uncertainty.utils import TULightningCLI @@ -14,7 +14,7 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: def cli_main() -> ResNetCLI: - return ResNetCLI(VGG, CIFAR10DataModule) + return ResNetCLI(VGGBaseline, CIFAR10DataModule) if __name__ == "__main__": diff --git a/experiments/classification/cifar10/wideresnet.py b/experiments/classification/cifar10/wideresnet.py index fdb47fa1..e30c4a8f 100644 --- a/experiments/classification/cifar10/wideresnet.py +++ b/experiments/classification/cifar10/wideresnet.py @@ -2,7 +2,7 @@ from lightning.pytorch.cli import LightningArgumentParser from lightning.pytorch.loggers import TensorBoardLogger # noqa: F401 -from torch_uncertainty.baselines.classification import WideResNet +from torch_uncertainty.baselines.classification import WideResNetBaseline from torch_uncertainty.datamodules import CIFAR10DataModule from torch_uncertainty.utils import TULightningCLI @@ -14,7 +14,7 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: def cli_main() -> ResNetCLI: - return ResNetCLI(WideResNet, CIFAR10DataModule) + return ResNetCLI(WideResNetBaseline, CIFAR10DataModule) if __name__ == "__main__": diff --git a/experiments/classification/cifar100/deep_ensembles.py b/experiments/classification/cifar100/deep_ensembles.py index 69a419a8..3a1ed65f 100644 --- a/experiments/classification/cifar100/deep_ensembles.py +++ b/experiments/classification/cifar100/deep_ensembles.py @@ -1,11 +1,11 @@ from pathlib import Path from torch_uncertainty import cli_main, init_args -from torch_uncertainty.baselines import DeepEnsembles +from torch_uncertainty.baselines import DeepEnsemblesBaseline from torch_uncertainty.datamodules import CIFAR100DataModule if __name__ == "__main__": - args = init_args(DeepEnsembles, CIFAR100DataModule) + args = init_args(DeepEnsemblesBaseline, CIFAR100DataModule) if args.root == "./data/": root = Path(__file__).parent.absolute().parents[2] else: @@ -19,7 +19,7 @@ # model args.task = "classification" - model = DeepEnsembles( + model = DeepEnsemblesBaseline( **vars(args), num_classes=dm.num_classes, in_channels=dm.num_channels, diff --git a/experiments/classification/cifar100/resnet.py b/experiments/classification/cifar100/resnet.py index 54fc0665..32926e7c 100644 --- a/experiments/classification/cifar100/resnet.py +++ b/experiments/classification/cifar100/resnet.py @@ -3,12 +3,12 @@ from torch import nn from torch_uncertainty import cli_main, init_args -from torch_uncertainty.baselines import ResNet +from torch_uncertainty.baselines import ResNetBaseline from torch_uncertainty.datamodules import CIFAR100DataModule from torch_uncertainty.optim_recipes import get_procedure if __name__ == "__main__": - args = init_args(ResNet, CIFAR100DataModule) + args = init_args(ResNetBaseline, CIFAR100DataModule) if args.root == "./data/": root = Path(__file__).parent.absolute().parents[2] else: @@ -22,7 +22,7 @@ dm = CIFAR100DataModule(**vars(args)) # model - model = ResNet( + model = ResNetBaseline( num_classes=dm.num_classes, in_channels=dm.num_channels, loss=nn.CrossEntropyLoss, diff --git a/experiments/classification/cifar100/vgg.py b/experiments/classification/cifar100/vgg.py index 78d18960..dafdabc9 100644 --- a/experiments/classification/cifar100/vgg.py +++ b/experiments/classification/cifar100/vgg.py @@ -3,12 +3,12 @@ from torch import nn from torch_uncertainty import cli_main, init_args -from torch_uncertainty.baselines import VGG +from torch_uncertainty.baselines import VGGBaseline from torch_uncertainty.datamodules import CIFAR100DataModule from torch_uncertainty.optim_recipes import get_procedure if __name__ == "__main__": - args = init_args(VGG, CIFAR100DataModule) + args = init_args(VGGBaseline, CIFAR100DataModule) if args.root == "./data/": root = Path(__file__).parent.absolute().parents[2] else: @@ -22,7 +22,7 @@ dm = CIFAR100DataModule(**vars(args)) # model - model = VGG( + model = VGGBaseline( num_classes=dm.num_classes, in_channels=dm.num_channels, loss=nn.CrossEntropyLoss, diff --git a/experiments/classification/cifar100/wideresnet.py b/experiments/classification/cifar100/wideresnet.py index 2dec2161..f2114908 100644 --- a/experiments/classification/cifar100/wideresnet.py +++ b/experiments/classification/cifar100/wideresnet.py @@ -3,12 +3,12 @@ from torch import nn from torch_uncertainty import cli_main, init_args -from torch_uncertainty.baselines import WideResNet +from torch_uncertainty.baselines import WideResNetBaseline from torch_uncertainty.datamodules import CIFAR100DataModule from torch_uncertainty.optim_recipes import get_procedure if __name__ == "__main__": - args = init_args(WideResNet, CIFAR100DataModule) + args = init_args(WideResNetBaseline, CIFAR100DataModule) if args.root == "./data/": root = Path(__file__).parent.absolute().parents[2] else: @@ -22,7 +22,7 @@ dm = CIFAR100DataModule(**vars(args)) # model - model = WideResNet( + model = WideResNetBaseline( num_classes=dm.num_classes, in_channels=dm.num_channels, loss=nn.CrossEntropyLoss, diff --git a/experiments/classification/tiny-imagenet/resnet.py b/experiments/classification/tiny-imagenet/resnet.py index c0f381a7..bd84ef6b 100644 --- a/experiments/classification/tiny-imagenet/resnet.py +++ b/experiments/classification/tiny-imagenet/resnet.py @@ -3,7 +3,7 @@ from torch import nn, optim from torch_uncertainty import cli_main, init_args -from torch_uncertainty.baselines import ResNet +from torch_uncertainty.baselines import ResNetBaseline from torch_uncertainty.datamodules import TinyImageNetDataModule from torch_uncertainty.optim_recipes import get_procedure from torch_uncertainty.utils import csv_writer @@ -22,7 +22,7 @@ def optim_tiny(model: nn.Module) -> dict: if __name__ == "__main__": - args = init_args(ResNet, TinyImageNetDataModule) + args = init_args(ResNetBaseline, TinyImageNetDataModule) if args.root == "./data/": root = Path(__file__).parent.absolute().parents[2] else: @@ -46,7 +46,7 @@ def optim_tiny(model: nn.Module) -> dict: if args.use_cv: list_dm = dm.make_cross_val_splits(args.n_splits, args.train_over) list_model = [ - ResNet( + ResNetBaseline( num_classes=list_dm[i].dm.num_classes, in_channels=list_dm[i].dm.num_channels, loss=nn.CrossEntropyLoss, @@ -65,7 +65,7 @@ def optim_tiny(model: nn.Module) -> dict: ) else: # model - model = ResNet( + model = ResNetBaseline( num_classes=dm.num_classes, in_channels=dm.num_channels, loss=nn.CrossEntropyLoss, diff --git a/experiments/regression/uci_datasets/deep_ensemble.py b/experiments/regression/uci_datasets/deep_ensemble.py index ef7217f3..6628e6e6 100644 --- a/experiments/regression/uci_datasets/deep_ensemble.py +++ b/experiments/regression/uci_datasets/deep_ensemble.py @@ -1,11 +1,11 @@ from pathlib import Path from torch_uncertainty import cli_main, init_args -from torch_uncertainty.baselines import DeepEnsembles +from torch_uncertainty.baselines import DeepEnsemblesBaseline from torch_uncertainty.datamodules import UCIDataModule if __name__ == "__main__": - args = init_args(DeepEnsembles, UCIDataModule) + args = init_args(DeepEnsemblesBaseline, UCIDataModule) if args.root == "./data/": root = Path(__file__).parent.absolute().parents[2] else: @@ -19,7 +19,7 @@ # model args.task = "regression" - model = DeepEnsembles( + model = DeepEnsemblesBaseline( **vars(args), ) diff --git a/tests/baselines/test_batched.py b/tests/baselines/test_batched.py index bf8abab7..da3e5130 100644 --- a/tests/baselines/test_batched.py +++ b/tests/baselines/test_batched.py @@ -2,20 +2,17 @@ from torch import nn from torchinfo import summary -from torch_uncertainty.baselines.classification import ResNet, WideResNet - -# from torch_uncertainty.optim_recipes import ( -# optim_cifar10_wideresnet, -# optim_cifar100_resnet18, -# optim_cifar100_resnet50, -# ) +from torch_uncertainty.baselines.classification import ( + ResNetBaseline, + WideResNetBaseline, +) class TestBatchedBaseline: """Testing the BatchedResNet baseline class.""" def test_batched_18(self): - net = ResNet( + net = ResNetBaseline( num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, @@ -32,7 +29,7 @@ def test_batched_18(self): _ = net(torch.rand(1, 3, 32, 32)) def test_batched_50(self): - net = ResNet( + net = ResNetBaseline( num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, @@ -53,7 +50,7 @@ class TestBatchedWideBaseline: """Testing the BatchedWideResNet baseline class.""" def test_batched(self): - net = WideResNet( + net = WideResNetBaseline( num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, diff --git a/tests/baselines/test_deep_ensembles.py b/tests/baselines/test_deep_ensembles.py index e559fe58..fbb7a512 100644 --- a/tests/baselines/test_deep_ensembles.py +++ b/tests/baselines/test_deep_ensembles.py @@ -1,7 +1,7 @@ import pytest from torch_uncertainty.baselines.classification.deep_ensembles import ( - DeepEnsembles, + DeepEnsemblesBaseline, ) @@ -10,7 +10,7 @@ class TestDeepEnsembles: def test_failure(self): with pytest.raises(ValueError): - DeepEnsembles( + DeepEnsemblesBaseline( log_path=".", checkpoint_ids=[], backbone="resnet", diff --git a/tests/baselines/test_masked.py b/tests/baselines/test_masked.py index c12a880c..9a51094c 100644 --- a/tests/baselines/test_masked.py +++ b/tests/baselines/test_masked.py @@ -3,20 +3,17 @@ from torch import nn from torchinfo import summary -from torch_uncertainty.baselines.classification import ResNet, WideResNet - -# from torch_uncertainty.optim_recipes import ( -# optim_cifar10_wideresnet, -# optim_cifar100_resnet18, -# optim_cifar100_resnet50, -# ) +from torch_uncertainty.baselines.classification import ( + ResNetBaseline, + WideResNetBaseline, +) class TestMaskedBaseline: """Testing the MaskedResNet baseline class.""" def test_masked_18(self): - net = ResNet( + net = ResNetBaseline( num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, @@ -34,7 +31,7 @@ def test_masked_18(self): _ = net(torch.rand(1, 3, 32, 32)) def test_masked_50(self): - net = ResNet( + net = ResNetBaseline( num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, @@ -53,7 +50,7 @@ def test_masked_50(self): def test_masked_scale_lt_1(self): with pytest.raises(Exception): - _ = ResNet( + _ = ResNetBaseline( num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, @@ -67,7 +64,7 @@ def test_masked_scale_lt_1(self): def test_masked_groups_lt_1(self): with pytest.raises(Exception): - _ = ResNet( + _ = ResNetBaseline( num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, @@ -84,7 +81,7 @@ class TestMaskedWideBaseline: """Testing the MaskedWideResNet baseline class.""" def test_masked(self): - net = WideResNet( + net = WideResNetBaseline( num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, diff --git a/tests/baselines/test_mc_dropout.py b/tests/baselines/test_mc_dropout.py index 5eeaa8c6..b1c3995a 100644 --- a/tests/baselines/test_mc_dropout.py +++ b/tests/baselines/test_mc_dropout.py @@ -2,19 +2,18 @@ from torch import nn from torchinfo import summary -from torch_uncertainty.baselines.classification import VGG, ResNet, WideResNet - -# from torch_uncertainty.optim_recipes import ( -# optim_cifar10_resnet18, -# optim_cifar10_wideresnet, -# ) +from torch_uncertainty.baselines.classification import ( + ResNetBaseline, + VGGBaseline, + WideResNetBaseline, +) class TestStandardBaseline: - """Testing the ResNet baseline class.""" + """Testing the ResNetBaseline baseline class.""" def test_standard(self): - net = ResNet( + net = ResNetBaseline( num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, @@ -32,10 +31,10 @@ def test_standard(self): class TestStandardWideBaseline: - """Testing the WideResNet baseline class.""" + """Testing the WideResNetBaseline baseline class.""" def test_standard(self): - net = WideResNet( + net = WideResNetBaseline( num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, @@ -52,10 +51,10 @@ def test_standard(self): class TestStandardVGGBaseline: - """Testing the VGG baseline class.""" + """Testing the VGGBaseline baseline class.""" def test_standard(self): - net = VGG( + net = VGGBaseline( num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, @@ -71,7 +70,7 @@ def test_standard(self): _ = net.criterion net(torch.rand(1, 3, 32, 32)) - net = VGG( + net = VGGBaseline( num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, diff --git a/tests/baselines/test_mimo.py b/tests/baselines/test_mimo.py index 5db246d6..b3d48c28 100644 --- a/tests/baselines/test_mimo.py +++ b/tests/baselines/test_mimo.py @@ -2,7 +2,10 @@ from torch import nn from torchinfo import summary -from torch_uncertainty.baselines.classification import ResNet, WideResNet +from torch_uncertainty.baselines.classification import ( + ResNetBaseline, + WideResNetBaseline, +) # from torch_uncertainty.optim_recipes import ( # optim_cifar10_resnet18, @@ -15,7 +18,7 @@ class TestMIMOBaseline: """Testing the MIMOResNet baseline class.""" def test_mimo_50(self): - net = ResNet( + net = ResNetBaseline( num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, @@ -34,7 +37,7 @@ def test_mimo_50(self): _ = net(torch.rand(1, 3, 32, 32)) def test_mimo_18(self): - net = ResNet( + net = ResNetBaseline( num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, @@ -57,7 +60,7 @@ class TestMIMOWideBaseline: """Testing the PackedWideResNet baseline class.""" def test_mimo(self): - net = WideResNet( + net = WideResNetBaseline( num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, diff --git a/tests/baselines/test_packed.py b/tests/baselines/test_packed.py index 250e38cc..746df90d 100644 --- a/tests/baselines/test_packed.py +++ b/tests/baselines/test_packed.py @@ -3,7 +3,11 @@ from torch import nn from torchinfo import summary -from torch_uncertainty.baselines.classification import VGG, ResNet, WideResNet +from torch_uncertainty.baselines.classification import ( + ResNetBaseline, + VGGBaseline, + WideResNetBaseline, +) from torch_uncertainty.baselines.regression import MLP @@ -11,7 +15,7 @@ class TestPackedBaseline: """Testing the PackedResNet baseline class.""" def test_packed_50(self): - net = ResNet( + net = ResNetBaseline( num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, @@ -30,7 +34,7 @@ def test_packed_50(self): _ = net(torch.rand(1, 3, 32, 32)) def test_packed_18(self): - net = ResNet( + net = ResNetBaseline( num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, @@ -50,7 +54,7 @@ def test_packed_18(self): def test_packed_exception(self): with pytest.raises(Exception): - _ = ResNet( + _ = ResNetBaseline( num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, @@ -64,7 +68,7 @@ def test_packed_exception(self): ) with pytest.raises(Exception): - _ = ResNet( + _ = ResNetBaseline( num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, @@ -82,7 +86,7 @@ class TestPackedWideBaseline: """Testing the PackedWideResNet baseline class.""" def test_packed(self): - net = WideResNet( + net = WideResNetBaseline( num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, @@ -104,7 +108,7 @@ class TestPackedVGGBaseline: """Testing the PackedWideResNet baseline class.""" def test_packed(self): - net = VGG( + net = VGGBaseline( num_classes=10, in_channels=3, arch=13, diff --git a/tests/baselines/test_standard.py b/tests/baselines/test_standard.py index c52f438b..bccc60c9 100644 --- a/tests/baselines/test_standard.py +++ b/tests/baselines/test_standard.py @@ -5,7 +5,11 @@ from torch import nn from torchinfo import summary -from torch_uncertainty.baselines.classification import VGG, ResNet, WideResNet +from torch_uncertainty.baselines.classification import ( + ResNetBaseline, + VGGBaseline, + WideResNetBaseline, +) from torch_uncertainty.baselines.regression import MLP from torch_uncertainty.baselines.utils.parser_addons import ( add_mlp_specific_args, @@ -13,10 +17,10 @@ class TestStandardBaseline: - """Testing the ResNet baseline class.""" + """Testing the ResNetBaseline baseline class.""" def test_standard(self): - net = ResNet( + net = ResNetBaseline( num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, @@ -32,7 +36,7 @@ def test_standard(self): def test_errors(self): with pytest.raises(ValueError): - ResNet( + ResNetBaseline( num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, @@ -44,10 +48,10 @@ def test_errors(self): class TestStandardWideBaseline: - """Testing the WideResNet baseline class.""" + """Testing the WideResNetBaseline baseline class.""" def test_standard(self): - net = WideResNet( + net = WideResNetBaseline( num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, @@ -62,7 +66,7 @@ def test_standard(self): def test_errors(self): with pytest.raises(ValueError): - WideResNet( + WideResNetBaseline( num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, @@ -73,10 +77,10 @@ def test_errors(self): class TestStandardVGGBaseline: - """Testing the VGG baseline class.""" + """Testing the VGGBaseline baseline class.""" def test_standard(self): - net = VGG( + net = VGGBaseline( num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, @@ -91,7 +95,7 @@ def test_standard(self): def test_errors(self): with pytest.raises(ValueError): - VGG( + VGGBaseline( num_classes=10, in_channels=3, loss=nn.CrossEntropyLoss, diff --git a/torch_uncertainty/baselines/__init__.py b/torch_uncertainty/baselines/__init__.py index c44ec481..e69de29b 100644 --- a/torch_uncertainty/baselines/__init__.py +++ b/torch_uncertainty/baselines/__init__.py @@ -1,6 +0,0 @@ -# ruff: noqa: F401 -# from .classification import ResNet -# from .classification.vgg import VGG -# from .classification import WideResNet - -# from .deep_ensembles import DeepEnsembles diff --git a/torch_uncertainty/baselines/classification/__init__.py b/torch_uncertainty/baselines/classification/__init__.py index 1326c2e3..e080ee4e 100644 --- a/torch_uncertainty/baselines/classification/__init__.py +++ b/torch_uncertainty/baselines/classification/__init__.py @@ -1,4 +1,4 @@ # ruff: noqa: F401 -from .resnet import ResNet -from .vgg import VGG -from .wideresnet import WideResNet +from .resnet import ResNetBaseline +from .vgg import VGGBaseline +from .wideresnet import WideResNetBaseline diff --git a/torch_uncertainty/baselines/classification/deep_ensembles.py b/torch_uncertainty/baselines/classification/deep_ensembles.py index f3dc8efb..2b1e3ae2 100644 --- a/torch_uncertainty/baselines/classification/deep_ensembles.py +++ b/torch_uncertainty/baselines/classification/deep_ensembles.py @@ -5,14 +5,14 @@ from torch_uncertainty.routines.classification import ClassificationRoutine from torch_uncertainty.utils import get_version -from . import VGG, ResNet, WideResNet +from . import ResNetBaseline, VGGBaseline, WideResNetBaseline -class DeepEnsembles(ClassificationRoutine): +class DeepEnsemblesBaseline(ClassificationRoutine): backbones = { - "resnet": ResNet, - "vgg": VGG, - "wideresnet": WideResNet, + "resnet": ResNetBaseline, + "vgg": VGGBaseline, + "wideresnet": WideResNetBaseline, } def __init__( diff --git a/torch_uncertainty/baselines/classification/resnet.py b/torch_uncertainty/baselines/classification/resnet.py index 1144dfde..b34ea48a 100644 --- a/torch_uncertainty/baselines/classification/resnet.py +++ b/torch_uncertainty/baselines/classification/resnet.py @@ -39,7 +39,7 @@ from torch_uncertainty.transforms import MIMOBatchFormat, RepeatTarget -class ResNet(ClassificationRoutine): +class ResNetBaseline(ClassificationRoutine): single = ["std"] ensemble = ["packed", "batched", "masked", "mc-dropout", "mimo"] versions = { diff --git a/torch_uncertainty/baselines/classification/vgg.py b/torch_uncertainty/baselines/classification/vgg.py index 39f048f3..f3cd194a 100644 --- a/torch_uncertainty/baselines/classification/vgg.py +++ b/torch_uncertainty/baselines/classification/vgg.py @@ -17,7 +17,7 @@ from torch_uncertainty.transforms import RepeatTarget -class VGG(ClassificationRoutine): +class VGGBaseline(ClassificationRoutine): single = ["std"] ensemble = ["mc-dropout", "packed"] versions = { diff --git a/torch_uncertainty/baselines/classification/wideresnet.py b/torch_uncertainty/baselines/classification/wideresnet.py index 9b46a19d..d324fba2 100644 --- a/torch_uncertainty/baselines/classification/wideresnet.py +++ b/torch_uncertainty/baselines/classification/wideresnet.py @@ -16,7 +16,7 @@ from torch_uncertainty.transforms import MIMOBatchFormat, RepeatTarget -class WideResNet(ClassificationRoutine): +class WideResNetBaseline(ClassificationRoutine): single = ["std"] ensemble = ["packed", "batched", "masked", "mimo", "mc-dropout"] versions = { From fb892dc1aa5c91c75f63ce34ec18d39d1c2e5766 Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 20 Mar 2024 14:39:42 +0100 Subject: [PATCH 068/148] :sparkles: Various NLL improvements - :fire: Delete Gaussian NLL metric - :white_check_mark: Improve coverage - :hammer: Rename NegativeLog-Likelihood to CategoricalNLL --- docs/source/api.rst | 4 +- tests/metrics/test_nll.py | 62 ++++++++--------- tests/models/test_deep_ensembles.py | 8 ++- tests/routines/test_regression.py | 18 ++--- tests/test_losses.py | 18 ++++- tests/test_utils.py | 6 ++ torch_uncertainty/losses.py | 4 +- torch_uncertainty/metrics/__init__.py | 2 +- torch_uncertainty/metrics/nll.py | 70 +++----------------- torch_uncertainty/routines/classification.py | 4 +- torch_uncertainty/utils/distributions.py | 2 +- 11 files changed, 83 insertions(+), 115 deletions(-) diff --git a/docs/source/api.rst b/docs/source/api.rst index 6ae95cf7..8efe44e1 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -197,8 +197,8 @@ Metrics Disagreement Entropy MutualInformation - NegativeLogLikelihood - GaussianNegativeLogLikelihood + CategoricalNLL + DistributionNLL FPR95 Losses diff --git a/tests/metrics/test_nll.py b/tests/metrics/test_nll.py index 2ec65a6b..dc157047 100644 --- a/tests/metrics/test_nll.py +++ b/tests/metrics/test_nll.py @@ -1,61 +1,51 @@ import pytest import torch +from torch.distributions import Normal -from torch_uncertainty.metrics import ( - GaussianNegativeLogLikelihood, - NegativeLogLikelihood, -) +from torch_uncertainty.metrics import CategoricalNLL, DistributionNLL -@pytest.fixture() -def probs_zero() -> torch.Tensor: - return torch.as_tensor([[1, 0.0], [0.0, 1.0]]) +class TestCategoricalNegativeLogLikelihood: + """Testing the CategoricalNLL metric class.""" + def test_compute_zero(self) -> None: + probs = torch.as_tensor([[1, 0.0], [0.0, 1.0]]) + targets = torch.as_tensor([0, 1]) -@pytest.fixture() -def targets_zero() -> torch.Tensor: - return torch.as_tensor([0, 1]) - - -class TestNegativeLogLikelihood: - """Testing the NegativeLogLikelihood metric class.""" - - def test_compute_zero( - self, probs_zero: torch.Tensor, targets_zero: torch.Tensor - ) -> None: - metric = NegativeLogLikelihood() - metric.update(probs_zero, targets_zero) + metric = CategoricalNLL() + metric.update(probs, targets) res = metric.compute() assert res == 0 - metric = NegativeLogLikelihood(reduction="none") - metric.update(probs_zero, targets_zero) + metric = CategoricalNLL(reduction="none") + metric.update(probs, targets) res_sum = metric.compute() assert torch.all(res_sum == torch.zeros(2)) def test_bad_argument(self) -> None: with pytest.raises(Exception): - _ = NegativeLogLikelihood(reduction="geometric_mean") + _ = CategoricalNLL(reduction="geometric_mean") -class TestGaussianNegativeLogLikelihood: - """Testing the NegativeLogLikelihood metric class.""" +class TestDistributionNLL: + """Testing the TestDistributionNLL metric class.""" def test_compute_zero(self) -> None: - metric = GaussianNegativeLogLikelihood() + metric = DistributionNLL(reduction="mean") means = torch.as_tensor([1, 10]).float() - variances = torch.as_tensor([1, 2]).float() + stds = torch.as_tensor([1, 2]).float() targets = torch.as_tensor([1, 10]).float() - metric.update(means, targets, variances) + dist = Normal(means, stds) + metric.update(dist, targets) res_mean = metric.compute() - assert res_mean == torch.log(variances).mean() / 2 + assert res_mean == torch.mean(torch.log(2 * torch.pi * (stds**2)) / 2) - metric = GaussianNegativeLogLikelihood(reduction="sum") - metric.update(means, targets, variances) + metric = DistributionNLL(reduction="sum") + metric.update(dist, targets) res_sum = metric.compute() - assert res_sum == torch.log(variances).sum() / 2 + assert res_sum == torch.log(2 * torch.pi * (stds**2)).sum() / 2 - metric = GaussianNegativeLogLikelihood(reduction="none") - metric.update(means, targets, variances) - res_sum = metric.compute() - assert torch.all(res_sum == torch.log(variances) / 2) + metric = DistributionNLL(reduction="none") + metric.update(dist, targets) + res_all = metric.compute() + assert torch.all(res_all == torch.log(2 * torch.pi * (stds**2)) / 2) diff --git a/tests/models/test_deep_ensembles.py b/tests/models/test_deep_ensembles.py index 7d791414..330098b2 100644 --- a/tests/models/test_deep_ensembles.py +++ b/tests/models/test_deep_ensembles.py @@ -31,7 +31,7 @@ def test_list_singleton(self): with pytest.raises(ValueError): deep_ensembles([model_1], num_estimators=1) - def test_model_and_no_num_estimator(self): + def test_errors(self): model_1 = dummy_model(1, 10, 1) with pytest.raises(ValueError): deep_ensembles(model_1, num_estimators=None) @@ -41,3 +41,9 @@ def test_model_and_no_num_estimator(self): with pytest.raises(ValueError): deep_ensembles(model_1, num_estimators=1) + + with pytest.raises(ValueError): + deep_ensembles(model_1, num_estimators=1, task="regression") + + with pytest.raises(ValueError): + deep_ensembles(model_1, num_estimators=1, task="other") diff --git a/tests/routines/test_regression.py b/tests/routines/test_regression.py index 4de12166..5526cfc7 100644 --- a/tests/routines/test_regression.py +++ b/tests/routines/test_regression.py @@ -5,7 +5,7 @@ from torch import nn from tests._dummies import DummyRegressionBaseline, DummyRegressionDataModule -from torch_uncertainty.losses import DistributionNLL +from torch_uncertainty.losses import DistributionNLLLoss from torch_uncertainty.optim_recipes import optim_cifar10_resnet18 from torch_uncertainty.routines import RegressionRoutine @@ -23,7 +23,7 @@ def test_one_estimator_one_output(self): probabilistic=True, in_features=dm.in_features, num_outputs=1, - loss=DistributionNLL, + loss=DistributionNLLLoss, optim_recipe=optim_cifar10_resnet18, baseline_type="single", ) @@ -37,7 +37,7 @@ def test_one_estimator_one_output(self): probabilistic=False, in_features=dm.in_features, num_outputs=1, - loss=DistributionNLL, + loss=DistributionNLLLoss, optim_recipe=optim_cifar10_resnet18, baseline_type="single", ) @@ -57,7 +57,7 @@ def test_one_estimator_two_outputs(self): probabilistic=True, in_features=dm.in_features, num_outputs=2, - loss=DistributionNLL, + loss=DistributionNLLLoss, optim_recipe=optim_cifar10_resnet18, baseline_type="single", dist_type="laplace", @@ -70,7 +70,7 @@ def test_one_estimator_two_outputs(self): probabilistic=False, in_features=dm.in_features, num_outputs=2, - loss=DistributionNLL, + loss=DistributionNLLLoss, optim_recipe=optim_cifar10_resnet18, baseline_type="single", ) @@ -88,7 +88,7 @@ def test_two_estimators_one_output(self): probabilistic=True, in_features=dm.in_features, num_outputs=1, - loss=DistributionNLL, + loss=DistributionNLLLoss, optim_recipe=optim_cifar10_resnet18, baseline_type="ensemble", dist_type="nig", @@ -101,7 +101,7 @@ def test_two_estimators_one_output(self): probabilistic=False, in_features=dm.in_features, num_outputs=1, - loss=DistributionNLL, + loss=DistributionNLLLoss, optim_recipe=optim_cifar10_resnet18, baseline_type="ensemble", ) @@ -119,7 +119,7 @@ def test_two_estimators_two_outputs(self): probabilistic=True, in_features=dm.in_features, num_outputs=2, - loss=DistributionNLL, + loss=DistributionNLLLoss, optim_recipe=optim_cifar10_resnet18, baseline_type="ensemble", ) @@ -132,7 +132,7 @@ def test_two_estimators_two_outputs(self): probabilistic=False, in_features=dm.in_features, num_outputs=2, - loss=DistributionNLL, + loss=DistributionNLLLoss, optim_recipe=optim_cifar10_resnet18, baseline_type="ensemble", ) diff --git a/tests/test_losses.py b/tests/test_losses.py index 24eeb06a..ff671938 100644 --- a/tests/test_losses.py +++ b/tests/test_losses.py @@ -3,10 +3,26 @@ import pytest import torch from torch import nn +from torch.distributions import Normal from torch_uncertainty.layers.bayesian import BayesLinear from torch_uncertainty.layers.distributions import NormalInverseGamma -from torch_uncertainty.losses import BetaNLL, DECLoss, DERLoss, ELBOLoss +from torch_uncertainty.losses import ( + BetaNLL, + DECLoss, + DERLoss, + DistributionNLLLoss, + ELBOLoss, +) + + +class TestDistributionNLL: + """Testing the DistributionNLLLoss class.""" + + def test_sum(self): + loss = DistributionNLLLoss(reduction="sum") + dist = Normal(0, 1) + loss(dist, torch.tensor([0.0])) class TestELBOLoss: diff --git a/tests/test_utils.py b/tests/test_utils.py index 2d8c0e1b..728f288d 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -41,6 +41,12 @@ class TestDistributions: """Testing distributions methods.""" def test_nig(self): + dist = distributions.NormalInverseGamma( + 0.0, + 1.1, + 1.1, + 1.1, + ) dist = distributions.NormalInverseGamma( torch.tensor(0.0), torch.tensor(1.1), diff --git a/torch_uncertainty/losses.py b/torch_uncertainty/losses.py index 31b51afe..a2f7e863 100644 --- a/torch_uncertainty/losses.py +++ b/torch_uncertainty/losses.py @@ -8,7 +8,7 @@ from torch_uncertainty.utils.distributions import NormalInverseGamma -class DistributionNLL(nn.Module): +class DistributionNLLLoss(nn.Module): def __init__( self, reduction: Literal["mean", "sum"] | None = "mean" ) -> None: @@ -124,7 +124,7 @@ def forward(self, inputs: Tensor, targets: Tensor) -> Tensor: return aggregated_elbo / self.num_samples -class DERLoss(DistributionNLL): +class DERLoss(DistributionNLLLoss): def __init__( self, reg_weight: float, reduction: str | None = "mean" ) -> None: diff --git a/torch_uncertainty/metrics/__init__.py b/torch_uncertainty/metrics/__init__.py index 3d132fbe..0f89fb96 100644 --- a/torch_uncertainty/metrics/__init__.py +++ b/torch_uncertainty/metrics/__init__.py @@ -7,6 +7,6 @@ from .grouping_loss import GroupingLoss from .mean_iou import MeanIntersectionOverUnion from .mutual_information import MutualInformation -from .nll import GaussianNegativeLogLikelihood, NegativeLogLikelihood +from .nll import CategoricalNLL, DistributionNLL from .sparsification import AUSE from .variation_ratio import VariationRatio diff --git a/torch_uncertainty/metrics/nll.py b/torch_uncertainty/metrics/nll.py index 14dac386..d433f43e 100644 --- a/torch_uncertainty/metrics/nll.py +++ b/torch_uncertainty/metrics/nll.py @@ -2,11 +2,12 @@ import torch import torch.nn.functional as F +from torch import Tensor, distributions from torchmetrics import Metric from torchmetrics.utilities.data import dim_zero_cat -class NegativeLogLikelihood(Metric): +class CategoricalNLL(Metric): is_differentiabled = False higher_is_better = False full_state_update = False @@ -66,12 +67,12 @@ def __init__( self.add_state("values", default=[], dist_reduce_fx="cat") self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") - def update(self, probs: torch.Tensor, target: torch.Tensor) -> None: + def update(self, probs: Tensor, target: Tensor) -> None: """Update state with prediction probabilities and targets. Args: - probs (torch.Tensor): Probabilities from the model. - target (torch.Tensor): Ground truth labels. + probs (Tensor): Probabilities from the model. + target (Tensor): Ground truth labels. """ if self.reduction is None or self.reduction == "none": self.values.append( @@ -81,7 +82,7 @@ def update(self, probs: torch.Tensor, target: torch.Tensor) -> None: self.values += F.nll_loss(torch.log(probs), target, reduction="sum") self.total += target.size(0) - def compute(self) -> torch.Tensor: + def compute(self) -> Tensor: """Computes NLL based on inputs passed in to ``update`` previously.""" values = dim_zero_cat(self.values) @@ -93,64 +94,13 @@ def compute(self) -> torch.Tensor: return values -class GaussianNegativeLogLikelihood(NegativeLogLikelihood): - """The Gaussian Negative Log Likelihood Metric. - - Args: - reduction (str, optional): Determines how to reduce over the - :math:`B`/batch dimension: - - - ``'mean'`` [default]: Averages score across samples - - ``'sum'``: Sum score across samples - - ``'none'`` or ``None``: Returns score per sample - - kwargs: Additional keyword arguments, see `Advanced metric settings - `_. - - Inputs: - - :attr:`mean`: :math:`(B, D)` - - :attr:`target`: :math:`(B, D)` - - :attr:`var`: :math:`(B, D)` - - where :math:`B` is the batch size and :math:`D` is the number of - dimensions. :math:`D` is optional. - - Raises: - ValueError: - If :attr:`reduction` is not one of ``'mean'``, ``'sum'``, - ``'none'`` or ``None``. - """ - - def update( - self, mean: torch.Tensor, target: torch.Tensor, var: torch.Tensor - ) -> None: - """Update state with prediction mean, targets, and prediction varoance. - - Args: - mean (torch.Tensor): Probabilities from the model. - target (torch.Tensor): Ground truth labels. - var (torch.Tensor): Predicted variance from the model. - """ - if self.reduction is None or self.reduction == "none": - self.values.append( - F.gaussian_nll_loss(mean, target, var, reduction="none") - ) - else: - self.values += F.gaussian_nll_loss( - mean, target, var, reduction="sum" - ) - self.total += target.size(0) - - -class DistributionNLL(NegativeLogLikelihood): - def update( - self, dists: torch.distributions.Distribution, target: torch.Tensor - ) -> None: +class DistributionNLL(CategoricalNLL): + def update(self, dists: distributions.Distribution, target: Tensor) -> None: """Update state with the predicted distributions and the targets. Args: dists (torch.distributions.Distribution): Predicted distributions. - target (torch.Tensor): Ground truth labels. + target (Tensor): Ground truth labels. """ if self.reduction is None or self.reduction == "none": self.values.append(-dists.log_prob(target)) @@ -158,7 +108,7 @@ def update( self.values += -dists.log_prob(target).sum() self.total += target.size(0) - def compute(self) -> torch.Tensor: + def compute(self) -> Tensor: """Computes NLL based on inputs passed in to ``update`` previously.""" values = dim_zero_cat(self.values) diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 4de13148..373fb0db 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -22,11 +22,11 @@ CE, FPR95, BrierScore, + CategoricalNLL, Disagreement, Entropy, GroupingLoss, MutualInformation, - NegativeLogLikelihood, VariationRatio, ) from torch_uncertainty.plotting_utils import plot_hist @@ -179,7 +179,7 @@ def __init__( else: cls_metrics = MetricCollection( { - "nll": NegativeLogLikelihood(), + "nll": CategoricalNLL(), "acc": Accuracy( task="multiclass", num_classes=self.num_classes ), diff --git a/torch_uncertainty/utils/distributions.py b/torch_uncertainty/utils/distributions.py index fd6fd39c..62a66bdd 100644 --- a/torch_uncertainty/utils/distributions.py +++ b/torch_uncertainty/utils/distributions.py @@ -168,7 +168,7 @@ def variance_loc(self) -> Tensor: return self.beta / (self.alpha - 1) / self.lmbda def log_prob(self, value: Tensor) -> Tensor: - if self._validate_args: + if self._validate_args: # coverage: ignore self._validate_sample(value) gam: Tensor = 2 * self.beta * (1 + self.lmbda) return ( From d81591e9b9403216a0f4669b3a0a41111416b9f8 Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 20 Mar 2024 15:50:18 +0100 Subject: [PATCH 069/148] :fire: Remove parser_addons --- torch_uncertainty/baselines/utils/__init__.py | 0 .../baselines/utils/parser_addons.py | 139 ------------------ 2 files changed, 139 deletions(-) delete mode 100644 torch_uncertainty/baselines/utils/__init__.py delete mode 100644 torch_uncertainty/baselines/utils/parser_addons.py diff --git a/torch_uncertainty/baselines/utils/__init__.py b/torch_uncertainty/baselines/utils/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/torch_uncertainty/baselines/utils/parser_addons.py b/torch_uncertainty/baselines/utils/parser_addons.py deleted file mode 100644 index 763aa908..00000000 --- a/torch_uncertainty/baselines/utils/parser_addons.py +++ /dev/null @@ -1,139 +0,0 @@ -from argparse import ArgumentParser - - -def add_resnet_specific_args(parser: ArgumentParser) -> ArgumentParser: - """Add ResNet specific arguments to parser. - - Args: - parser (ArgumentParser): Argument parser. - - Adds the following arguments: - --arch (int): Architecture of ResNet. Choose among: [18, 34, 50, 101, 152] - --dropout_rate (float): Dropout rate. - --groups (int): Number of groups. - """ - # style_choices = ["cifar", "imagenet", "robust"] - archs = [18, 20, 34, 50, 101, 152] - parser.add_argument( - "--arch", - type=int, - choices=archs, - default=18, - help=f"Architecture of ResNet. Choose among: {archs}", - ) - parser.add_argument( - "--dropout_rate", - type=float, - default=0.0, - help="Dropout rate", - ) - parser.add_argument( - "--groups", - type=int, - default=1, - help="Number of groups", - ) - return parser - - -def add_vgg_specific_args(parser: ArgumentParser) -> ArgumentParser: - # style_choices = ["cifar", "imagenet", "robust"] - archs = [11, 13, 16, 19] - parser.add_argument( - "--arch", - type=int, - choices=archs, - default=11, - help=f"Architecture of VGG. Choose among: {archs}", - ) - parser.add_argument( - "--groups", - type=int, - default=1, - help="Number of groups", - ) - parser.add_argument( - "--dropout_rate", - type=float, - default=0.1, - help="Dropout rate", - ) - return parser - - -def add_wideresnet_specific_args(parser: ArgumentParser) -> ArgumentParser: - # style_choices = ["cifar", "imagenet"] - parser.add_argument( - "--dropout_rate", - type=float, - default=0.3, - help="Dropout rate", - ) - parser.add_argument( - "--groups", - type=int, - default=1, - help="Number of groups", - ) - return parser - - -def add_mlp_specific_args(parser: ArgumentParser) -> ArgumentParser: - parser.add_argument( - "--dropout_rate", - type=float, - default=0.1, - help="Dropout rate", - ) - return parser - - -def add_packed_specific_args(parser: ArgumentParser) -> ArgumentParser: - parser.add_argument( - "--alpha", - type=int, - default=None, - help="Alpha for Packed-Ensembles", - ) - parser.add_argument( - "--gamma", - type=int, - default=1, - help="Gamma for Packed-Ensembles", - ) - return parser - - -def add_masked_specific_args(parser: ArgumentParser) -> ArgumentParser: - parser.add_argument( - "--scale", - type=float, - default=None, - help="Scale for Masksembles", - ) - return parser - - -def add_mimo_specific_args(parser: ArgumentParser) -> ArgumentParser: - parser.add_argument( - "--rho", - type=float, - default=0.0, - help="Rho for MIMO", - ) - parser.add_argument( - "--batch_repeat", - type=int, - default=1, - help="Batch repeat for MIMO", - ) - return parser - - -def add_mc_dropout_specific_args(parser: ArgumentParser) -> ArgumentParser: - parser.add_argument( - "--last_layer_dropout", - action="store_true", - help="Whether to apply dropout to the last layer only", - ) - return parser From 275cde3e89079c3f3a0d50fd2b6ead88459fc5ca Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 20 Mar 2024 15:50:48 +0100 Subject: [PATCH 070/148] :sparkles: Add reset_parameters to deep_ensembles and improve cov. --- tests/baselines/test_standard.py | 8 -------- tests/metrics/test_nll.py | 5 +++++ tests/models/test_deep_ensembles.py | 8 ++++---- tests/models/test_mlps.py | 15 +++++++++++---- torch_uncertainty/models/deep_ensembles.py | 10 ++++++++++ 5 files changed, 30 insertions(+), 16 deletions(-) diff --git a/tests/baselines/test_standard.py b/tests/baselines/test_standard.py index bccc60c9..ce5add04 100644 --- a/tests/baselines/test_standard.py +++ b/tests/baselines/test_standard.py @@ -1,5 +1,3 @@ -from argparse import ArgumentParser - import pytest import torch from torch import nn @@ -11,9 +9,6 @@ WideResNetBaseline, ) from torch_uncertainty.baselines.regression import MLP -from torch_uncertainty.baselines.utils.parser_addons import ( - add_mlp_specific_args, -) class TestStandardBaseline: @@ -121,9 +116,6 @@ def test_standard(self): _ = net.criterion _ = net(torch.rand(1, 3)) - parser = ArgumentParser() - add_mlp_specific_args(parser) - def test_errors(self): with pytest.raises(ValueError): MLP( diff --git a/tests/metrics/test_nll.py b/tests/metrics/test_nll.py index dc157047..bfd60bc9 100644 --- a/tests/metrics/test_nll.py +++ b/tests/metrics/test_nll.py @@ -22,6 +22,11 @@ def test_compute_zero(self) -> None: res_sum = metric.compute() assert torch.all(res_sum == torch.zeros(2)) + metric = CategoricalNLL(reduction="sum") + metric.update(probs, targets) + res_sum = metric.compute() + assert torch.all(res_sum == torch.zeros(1)) + def test_bad_argument(self) -> None: with pytest.raises(Exception): _ = CategoricalNLL(reduction="geometric_mean") diff --git a/tests/models/test_deep_ensembles.py b/tests/models/test_deep_ensembles.py index 330098b2..2c3df7c2 100644 --- a/tests/models/test_deep_ensembles.py +++ b/tests/models/test_deep_ensembles.py @@ -25,8 +25,8 @@ def test_list_and_num_estimators(self): def test_list_singleton(self): model_1 = dummy_model(1, 10, 1) - deep_ensembles([model_1], num_estimators=2) - deep_ensembles(model_1, num_estimators=2) + deep_ensembles([model_1], num_estimators=2, reset_model_parameters=True) + deep_ensembles(model_1, num_estimators=2, reset_model_parameters=False) with pytest.raises(ValueError): deep_ensembles([model_1], num_estimators=1) @@ -43,7 +43,7 @@ def test_errors(self): deep_ensembles(model_1, num_estimators=1) with pytest.raises(ValueError): - deep_ensembles(model_1, num_estimators=1, task="regression") + deep_ensembles(model_1, num_estimators=2, task="regression") with pytest.raises(ValueError): - deep_ensembles(model_1, num_estimators=1, task="other") + deep_ensembles(model_1, num_estimators=2, task="other") diff --git a/tests/models/test_mlps.py b/tests/models/test_mlps.py index afb56d8d..93aaabc0 100644 --- a/tests/models/test_mlps.py +++ b/tests/models/test_mlps.py @@ -1,11 +1,18 @@ -from torch_uncertainty.models.mlp import bayesian_mlp, packed_mlp +from torch_uncertainty.layers.distributions import IndptNormalLayer +from torch_uncertainty.models.mlp import bayesian_mlp, mlp, packed_mlp class TestMLPModel: """Testing the mlp models.""" - def test_packed(self): + def test_mlps(self): + mlp( + 1, + 1, + hidden_dims=[1, 1, 1], + final_layer=IndptNormalLayer, + final_layer_args={"dim": 1}, + ) + mlp(1, 1, hidden_dims=[]) packed_mlp(1, 1, hidden_dims=[]) - - def test_bayesian(self): bayesian_mlp(1, 1, hidden_dims=[1, 1, 1]) diff --git a/torch_uncertainty/models/deep_ensembles.py b/torch_uncertainty/models/deep_ensembles.py index c5b80eac..66435947 100644 --- a/torch_uncertainty/models/deep_ensembles.py +++ b/torch_uncertainty/models/deep_ensembles.py @@ -63,6 +63,7 @@ def deep_ensembles( num_estimators: int | None = None, task: Literal["classification", "regression"] = "classification", probabilistic=None, + reset_model_parameters: bool = False, ) -> nn.Module: """Build a Deep Ensembles out of the original models. @@ -71,6 +72,8 @@ def deep_ensembles( num_estimators (int | None): The number of estimators in the ensemble. task (Literal["classification", "regression"]): The model task. probabilistic (bool): Whether the regression model is probabilistic. + reset_model_parameters (bool): Whether to reset the model parameters + when :attr:models is a module or a list of length 1. Returns: nn.Module: The ensembled model. @@ -106,6 +109,13 @@ def deep_ensembles( models = models[0] models = [copy.deepcopy(models) for _ in range(num_estimators)] + + if reset_model_parameters: + for model in models: + for layer in model.children(): + if hasattr(layer, "reset_parameters"): + layer.reset_parameters() + elif ( isinstance(models, list) and len(models) > 1 From f09153a05f6fd5b2e36a712cad676bdf902031f8 Mon Sep 17 00:00:00 2001 From: alafage Date: Thu, 21 Mar 2024 02:16:02 +0100 Subject: [PATCH 071/148] :hammer: Update Cityscapes ``__get_item__()`` to use ``tv_tensor`` --- .../datamodules/segmentation/cityscapes.py | 38 ++++++++----------- .../datasets/segmentation/camvid.py | 2 +- .../datasets/segmentation/cityscapes.py | 11 ++++-- 3 files changed, 25 insertions(+), 26 deletions(-) diff --git a/torch_uncertainty/datamodules/segmentation/cityscapes.py b/torch_uncertainty/datamodules/segmentation/cityscapes.py index 2ddde6fe..6a937827 100644 --- a/torch_uncertainty/datamodules/segmentation/cityscapes.py +++ b/torch_uncertainty/datamodules/segmentation/cityscapes.py @@ -5,7 +5,7 @@ from torch.nn.common_types import _size_2_t from torch.nn.modules.utils import _pair from torch.utils.data import random_split -from torchvision import datasets, tv_tensors +from torchvision import tv_tensors from torchvision.transforms import v2 from torch_uncertainty.datamodules.abstract import AbstractDataModule @@ -82,14 +82,12 @@ def prepare_data(self) -> None: # coverage: ignore def setup(self, stage: str | None = None) -> None: if stage == "fit" or stage is None: - full = datasets.wrap_dataset_for_transforms_v2( - self.dataset( - root=self.root, - split="train", - mode=self.mode, - target_type="semantic", - transforms=self.train_transform, - ) + full = self.dataset( + root=self.root, + split="train", + mode=self.mode, + target_type="semantic", + transforms=self.train_transform, ) if self.val_split is not None: @@ -105,25 +103,21 @@ def setup(self, stage: str | None = None) -> None: self.val.dataset.transforms = self.test_transform else: self.train = full - self.val = datasets.wrap_dataset_for_transforms_v2( - self.dataset( - root=self.root, - split="val", - mode=self.mode, - target_type="semantic", - transforms=self.test_transform, - ) - ) - - if stage == "test" or stage is None: - self.test = datasets.wrap_dataset_for_transforms_v2( - self.dataset( + self.val = self.dataset( root=self.root, split="val", mode=self.mode, target_type="semantic", transforms=self.test_transform, ) + + if stage == "test" or stage is None: + self.test = self.dataset( + root=self.root, + split="val", + mode=self.mode, + target_type="semantic", + transforms=self.test_transform, ) if stage not in ["fit", "test", None]: diff --git a/torch_uncertainty/datasets/segmentation/camvid.py b/torch_uncertainty/datasets/segmentation/camvid.py index e017fd1d..13d94fee 100644 --- a/torch_uncertainty/datasets/segmentation/camvid.py +++ b/torch_uncertainty/datasets/segmentation/camvid.py @@ -154,7 +154,7 @@ def encode_target(self, target: Image.Image) -> torch.Tensor: ).all(dim=-1) ] = index - return rearrange(target, "h w c -> c h w").squeeze(0) + return rearrange(target, "h w c -> c h w") def decode_target(self, target: torch.Tensor) -> Image.Image: """Decode target tensor to image. diff --git a/torch_uncertainty/datasets/segmentation/cityscapes.py b/torch_uncertainty/datasets/segmentation/cityscapes.py index f4835fc2..234a6ee5 100644 --- a/torch_uncertainty/datasets/segmentation/cityscapes.py +++ b/torch_uncertainty/datasets/segmentation/cityscapes.py @@ -4,6 +4,7 @@ from einops import rearrange from PIL import Image from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE +from torchvision import tv_tensors from torchvision.datasets import Cityscapes as OriginalCityscapes from torchvision.transforms.v2 import functional as F @@ -33,7 +34,9 @@ def encode_target(self, target: Image.Image) -> Image.Image: return F.to_pil_image(rearrange(target, "h w c -> c h w")) def __getitem__(self, index: int) -> tuple[Any, Any]: - """Args: + """Get the sample at the given index. + + Args: index (int): Index Returns: tuple: (image, target) where target is a tuple of all target types @@ -41,14 +44,16 @@ def __getitem__(self, index: int) -> tuple[Any, Any]: than one item. Otherwise, target is a json object if ``target_type="polygon"``, else the image segmentation. """ - image = Image.open(self.images[index]).convert("RGB") + image = tv_tensors.Image(Image.open(self.images[index]).convert("RGB")) targets: Any = [] for i, t in enumerate(self.target_type): if t == "polygon": target = self._load_json(self.targets[index][i]) elif t == "semantic": - target = self.encode_target(Image.open(self.targets[index][i])) + target = tv_tensors.Mask( + self.encode_target(Image.open(self.targets[index][i])) + ) else: target = Image.open(self.targets[index][i]) From ce1e711f4f7d155b78ae4f09d48a65538b618e3d Mon Sep 17 00:00:00 2001 From: alafage Date: Thu, 21 Mar 2024 02:22:09 +0100 Subject: [PATCH 072/148] :sparkles: Rework MUAD dataset and enable Segformer training on it --- .../segmentation/muad/configs/segformer.yaml | 26 ++++ experiments/segmentation/muad/segformer.py | 28 ++++ .../datamodules/segmentation/__init__.py | 1 + .../datamodules/segmentation/muad.py | 125 ++++++++++++++++++ torch_uncertainty/datasets/__init__.py | 1 + torch_uncertainty/datasets/muad.py | 104 ++++++++++----- .../models/segmentation/segformer/std.py | 6 - torch_uncertainty/routines/__init__.py | 1 + 8 files changed, 252 insertions(+), 40 deletions(-) create mode 100644 experiments/segmentation/muad/configs/segformer.yaml create mode 100644 experiments/segmentation/muad/segformer.py create mode 100644 torch_uncertainty/datamodules/segmentation/muad.py diff --git a/experiments/segmentation/muad/configs/segformer.yaml b/experiments/segmentation/muad/configs/segformer.yaml new file mode 100644 index 00000000..cce57b1f --- /dev/null +++ b/experiments/segmentation/muad/configs/segformer.yaml @@ -0,0 +1,26 @@ +# lightning.pytorch==2.2.0 +eval_after_fit: true +seed_everything: false +trainer: + accelerator: gpu + devices: 1 + max_steps: 160000 +model: + num_classes: 19 + loss: torch.nn.CrossEntropyLoss + version: std + arch: 0 + num_estimators: 1 +data: + root: ./data + batch_size: 8 + crop_size: 1024 + inference_size: + - 1024 + - 2048 + num_workers: 30 +optimizer: + lr: 6e-5 +lr_scheduler: + step_size: 10000 + gamma: 0.1 diff --git a/experiments/segmentation/muad/segformer.py b/experiments/segmentation/muad/segformer.py new file mode 100644 index 00000000..5fc741ba --- /dev/null +++ b/experiments/segmentation/muad/segformer.py @@ -0,0 +1,28 @@ +import torch +from lightning.pytorch.cli import LightningArgumentParser +from lightning.pytorch.loggers import TensorBoardLogger # noqa: F401 + +from torch_uncertainty.baselines.segmentation import SegFormer +from torch_uncertainty.datamodules.segmentation import MUADDataModule +from torch_uncertainty.utils import TULightningCLI + + +class SegFormerCLI(TULightningCLI): + def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + parser.add_optimizer_args(torch.optim.AdamW) + parser.add_lr_scheduler_args(torch.optim.lr_scheduler.StepLR) + + +def cli_main() -> SegFormerCLI: + return SegFormerCLI(SegFormer, MUADDataModule) + + +if __name__ == "__main__": + torch.set_float32_matmul_precision("medium") + cli = cli_main() + if ( + (not cli.trainer.fast_dev_run) + and cli.subcommand == "fit" + and cli._get(cli.config, "eval_after_fit") + ): + cli.trainer.test(datamodule=cli.datamodule, ckpt_path="best") diff --git a/torch_uncertainty/datamodules/segmentation/__init__.py b/torch_uncertainty/datamodules/segmentation/__init__.py index 7c0d0a8c..b4f55984 100644 --- a/torch_uncertainty/datamodules/segmentation/__init__.py +++ b/torch_uncertainty/datamodules/segmentation/__init__.py @@ -1,3 +1,4 @@ # ruff: noqa: F401 from .camvid import CamVidDataModule from .cityscapes import CityscapesDataModule +from .muad import MUADDataModule diff --git a/torch_uncertainty/datamodules/segmentation/muad.py b/torch_uncertainty/datamodules/segmentation/muad.py new file mode 100644 index 00000000..0b567a1e --- /dev/null +++ b/torch_uncertainty/datamodules/segmentation/muad.py @@ -0,0 +1,125 @@ +import copy +from pathlib import Path + +import torch +from torch.nn.common_types import _size_2_t +from torch.nn.modules.utils import _pair +from torch.utils.data import random_split +from torchvision import tv_tensors +from torchvision.transforms import v2 + +from torch_uncertainty.datamodules.abstract import AbstractDataModule +from torch_uncertainty.datasets import MUAD +from torch_uncertainty.transforms import RandomRescale + + +class MUADDataModule(AbstractDataModule): + def __init__( + self, + root: str | Path, + batch_size: int, + crop_size: _size_2_t = 1024, + inference_size: _size_2_t = (1024, 2048), + val_split: float | None = None, + num_workers: int = 1, + pin_memory: bool = True, + persistent_workers: bool = True, + ) -> None: + super().__init__( + root=root, + batch_size=batch_size, + val_split=val_split, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=persistent_workers, + ) + + self.dataset = MUAD + self.crop_size = _pair(crop_size) + self.inference_size = _pair(inference_size) + + self.train_transform = v2.Compose( + [ + v2.ToImage(), + RandomRescale(min_scale=0.5, max_scale=2.0, antialias=True), + v2.RandomCrop(size=self.crop_size, pad_if_needed=True), + v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5), + v2.RandomHorizontalFlip(), + v2.ToDtype( + dtype={ + tv_tensors.Image: torch.float32, + tv_tensors.Mask: torch.int64, + "others": None, + }, + scale=True, + ), + v2.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) + self.test_transform = v2.Compose( + [ + v2.ToImage(), + v2.Resize(size=self.inference_size, antialias=True), + v2.ToDtype( + dtype={ + tv_tensors.Image: torch.float32, + tv_tensors.Mask: torch.int64, + "others": None, + }, + scale=True, + ), + v2.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) + + def prepare_data(self) -> None: # coverage: ignore + self.dataset( + root=self.root, split="train", target_type="semantic", download=True + ) + self.dataset( + root=self.root, split="val", target_type="semantic", download=True + ) + + def setup(self, stage: str | None = None) -> None: + if stage == "fit" or stage is None: + full = self.dataset( + root=self.root, + split="train", + target_type="semantic", + transforms=self.train_transform, + ) + + if self.val_split is not None: + self.train, val = random_split( + full, + [ + 1 - self.val_split, + self.val_split, + ], + ) + # FIXME: memory cost issues might arise here + self.val = copy.deepcopy(val) + self.val.dataset.transforms = self.test_transform + else: + self.train = full + self.val = self.dataset( + root=self.root, + split="val", + target_type="semantic", + transforms=self.test_transform, + ) + + if stage == "test" or stage is None: + self.test = self.dataset( + root=self.root, + split="val", + target_type="semantic", + transforms=self.test_transform, + ) + + if stage not in ["fit", "test", None]: + raise ValueError(f"Stage {stage} is not supported.") diff --git a/torch_uncertainty/datasets/__init__.py b/torch_uncertainty/datasets/__init__.py index 5aa7ef67..732334a0 100644 --- a/torch_uncertainty/datasets/__init__.py +++ b/torch_uncertainty/datasets/__init__.py @@ -1,3 +1,4 @@ # ruff: noqa: F401 from .aggregated_dataset import AggregatedDataset from .frost import FrostImages +from .muad import MUAD diff --git a/torch_uncertainty/datasets/muad.py b/torch_uncertainty/datasets/muad.py index b7babdb6..05d79143 100644 --- a/torch_uncertainty/datasets/muad.py +++ b/torch_uncertainty/datasets/muad.py @@ -1,18 +1,20 @@ import json from collections.abc import Callable from pathlib import Path -from typing import Literal +from typing import Any, Literal -import cv2 as cv -import matplotlib.pyplot as plt import numpy as np +import torch +from einops import rearrange from PIL import Image +from torchvision import tv_tensors from torchvision.datasets import VisionDataset from torchvision.datasets.utils import ( check_integrity, download_and_extract_archive, download_url, ) +from torchvision.transforms.v2 import functional as F class MUAD(VisionDataset): @@ -34,7 +36,8 @@ def __init__( self, root: str | Path, split: Literal["train", "val", "train_depth", "val_depth"], - transform: Callable | None = None, + target_type: Literal["semantic", "depth"] = "semantic", + transforms: Callable | None = None, download: bool = False, ) -> None: """The MUAD Dataset. @@ -44,7 +47,9 @@ def __init__( and 'leftLabel' are located. split (str, optional): The image split to use, 'train', 'val', 'train_depth' or 'val_depth'. - transform (callable, optional): A function/transform that takes in + target_type (str, optional): The type of target to use, 'semantic' + or 'depth'. + transforms (callable, optional): A function/transform that takes in a tuple of PIL images and returns a transformed version. download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory. If dataset is already @@ -63,23 +68,30 @@ def __init__( ) super().__init__( root=Path(root) / "MUAD", - transform=transform, + transforms=transforms, ) - if split not in ["train", "val", "train_depth", "val_depth"]: + if split not in ["train", "val"]: raise ValueError( - "split must be one of ['train', 'val', 'train_depth', " - f"'val_depth']. Got {split}." + f"split must be one of ['train', 'val']. Got {split}." ) self.split = split + self.target_type = target_type split_path = self.root / (split + ".zip") - if not check_integrity(split_path, self.zip_md5[split]) and download: - self._download() + if (not check_integrity(split_path, self.zip_md5[split])) and download: + self._download(split=self.split) + + if ( + self.target_type == "depth" + and not check_integrity(split_path, self.zip_md5[split + "_depth"]) + and download + ): + self._download(split=f"{split}_depth") # Load classes metadata cls_path = self.root / "classes.json" - if not check_integrity(cls_path, self.classes_md5) and download: + if (not check_integrity(cls_path, self.classes_md5)) and download: download_url( self.classes_url, self.root, @@ -100,32 +112,50 @@ def __init__( self._make_dataset(self.root / split) + def encode_target(self, smnt: Image.Image) -> Image.Image: + """Encode target image to tensor. + + Args: + smnt (Image.Image): Target PIL image. + + Returns: + torch.Tensor: Encoded target. + """ + smnt = F.pil_to_tensor(smnt) + smnt = rearrange(smnt, "c h w -> h w c") + target = torch.zeros_like(smnt[..., :1]) + # convert target color to index + for muad_class in self.classes: + target[ + ( + smnt == torch.tensor(muad_class["id"], dtype=target.dtype) + ).all(dim=-1) + ] = muad_class["train_id"] + + return F.to_pil_image(rearrange(target, "h w c -> c h w")) + def decode_target(self, target: Image.Image) -> np.ndarray: target[target == 255] = 19 return self.train_id_to_color[target] - def __getitem__(self, index: int) -> tuple[Image.Image, Image.Image]: - """Get the image and its segmentation target.""" - img_path = self.samples[index] - seg_path = self.targets[index] - - image = cv.imread(img_path) - image = cv.cvtColor(image, cv.COLOR_BGR2RGB) + def __getitem__(self, index: int) -> tuple[Any, Any]: + """Get the sample at the given index. - segm = plt.imread(seg_path) * 255.0 - target = np.zeros((segm.shape[0], segm.shape[1])) + 255.0 + Args: + index (int): Index - for c in self.classes: - upper = np.array(c["train_id"]) - mask = cv.inRange(segm, upper, upper) - target[mask == 255] = c["train_id"] - target = target.astype(np.uint8) - target = Image.fromarray(target) + Returns: + tuple: (image, target) where target is either a segmentation mask + or a depth map. + """ + image = tv_tensors.Image(Image.open(self.samples[index]).convert("RGB")) + target = tv_tensors.Mask( + self.encode_target(Image.open(self.targets[index])) + ) - image = Image.fromarray(image) + if self.transforms is not None: + image, target = self.transforms(image, target) - if self.transform: - image, target = self.transform(image, target) return image, target def __len__(self) -> int: @@ -144,11 +174,17 @@ def _make_dataset(self, path: Path) -> None: "if you need it." ) self.samples = list((path / "leftImg8bit/").glob("**/*")) - self.targets = list((path / "leftLabel/").glob("**/*")) + if self.target_type == "semantic": + self.targets = list((path / "leftLabel/").glob("**/*")) + else: + raise NotImplementedError( + "Depth regression mode is not implemented yet. Raise an issue " + "if you need it." + ) - def _download(self): + def _download(self, split: str): """Download and extract the chosen split of the dataset.""" - split_url = self.base_url + self.split + ".zip" + split_url = self.base_url + split + ".zip" download_and_extract_archive( - split_url, self.root, md5=self.zip_md5[self.split] + split_url, self.root, md5=self.zip_md5[split] ) diff --git a/torch_uncertainty/models/segmentation/segformer/std.py b/torch_uncertainty/models/segmentation/segformer/std.py index 1b735760..a723de64 100644 --- a/torch_uncertainty/models/segmentation/segformer/std.py +++ b/torch_uncertainty/models/segmentation/segformer/std.py @@ -872,9 +872,3 @@ def segformer_b5(num_classes: int): dropout_ratio=0.1, mit=MitB5, ) - - -if __name__ == "__main__": - x = torch.randn((1, 3, 224, 224)) - model = segformer_b0() - print(model(x).size()) # torch.Size([1, 19, 56, 56]) diff --git a/torch_uncertainty/routines/__init__.py b/torch_uncertainty/routines/__init__.py index 2513e15e..41b7ea80 100644 --- a/torch_uncertainty/routines/__init__.py +++ b/torch_uncertainty/routines/__init__.py @@ -1,3 +1,4 @@ # ruff: noqa: F401 from .classification import ClassificationRoutine from .regression import RegressionRoutine +from .segmentation import SegmentationRoutine From 1c1e67f56adde19943160639111e3ce0e61007f0 Mon Sep 17 00:00:00 2001 From: alafage Date: Thu, 21 Mar 2024 02:24:28 +0100 Subject: [PATCH 073/148] :hammer: Update Segmentation routine to handle format_batch_fn and ensembles --- torch_uncertainty/routines/segmentation.py | 27 +++++++++++++++++----- 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/torch_uncertainty/routines/segmentation.py b/torch_uncertainty/routines/segmentation.py index d2dd126d..5a22d8a5 100644 --- a/torch_uncertainty/routines/segmentation.py +++ b/torch_uncertainty/routines/segmentation.py @@ -13,7 +13,8 @@ def __init__( num_classes: int, model: nn.Module, loss: type[nn.Module], - num_estimators: int, + num_estimators: int = 1, + optim_recipe=None, format_batch_fn: nn.Module | None = None, ) -> None: super().__init__() @@ -24,6 +25,9 @@ def __init__( self.num_classes = num_classes self.model = model self.loss = loss + self.num_estimators = num_estimators + self.format_batch_fn = format_batch_fn + self.optim_recipe = optim_recipe self.metric_to_monitor = "val/mean_iou" @@ -39,6 +43,9 @@ def __init__( self.val_seg_metrics = seg_metrics.clone(prefix="val/") self.test_seg_metrics = seg_metrics.clone(prefix="test/") + def configure_optimizers(self): + return self.optim_recipe(self.model) + @property def criterion(self) -> nn.Module: return self.loss() @@ -56,6 +63,7 @@ def training_step( self, batch: tuple[Tensor, Tensor], batch_idx: int ) -> STEP_OUTPUT: img, target = batch + img, target = self.format_batch_fn((img, target)) logits = self.forward(img) logits = rearrange(logits, "b c h w -> (b h w) c") target = target.flatten() @@ -68,20 +76,27 @@ def validation_step( self, batch: tuple[Tensor, Tensor], batch_idx: int ) -> None: img, target = batch - # (B, num_classes, H, W) logits = self.forward(img) - logits = rearrange(logits, "b c h w -> (b h w) c") + logits = rearrange( + logits, "(m b) c h w -> (b h w) m c", m=self.num_estimators + ) + probs_per_est = logits.softmax(dim=-1) + probs = probs_per_est.mean(dim=1) target = target.flatten() valid_mask = target != 255 - self.val_seg_metrics.update(logits[valid_mask], target[valid_mask]) + self.val_seg_metrics.update(probs[valid_mask], target[valid_mask]) def test_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> None: img, target = batch logits = self.forward(img) - logits = rearrange(logits, "b c h w -> (b h w) c") + logits = rearrange( + logits, "(m b) c h w -> (b h w) m c", m=self.num_estimators + ) + probs_per_est = logits.softmax(dim=-1) + probs = probs_per_est.mean(dim=1) target = target.flatten() valid_mask = target != 255 - self.test_seg_metrics.update(logits[valid_mask], target[valid_mask]) + self.test_seg_metrics.update(probs[valid_mask], target[valid_mask]) def on_validation_epoch_end(self) -> None: self.log_dict(self.val_seg_metrics.compute()) From 874bdb6fe42fed1b8846c814f32de42297399b02 Mon Sep 17 00:00:00 2001 From: alafage Date: Thu, 21 Mar 2024 02:26:24 +0100 Subject: [PATCH 074/148] :white_check_mark: Add tests for segmentation --- tests/_dummies/baseline.py | 43 ++++++++- tests/_dummies/datamodule.py | 96 ++++++++++++++++++- tests/_dummies/dataset.py | 56 +++++++++++ tests/_dummies/model.py | 55 +++++++++++ tests/baselines/test_standard.py | 26 +++++ .../classification/test_cifar10_datamodule.py | 2 - tests/datamodules/segmentation/__init__.py | 0 .../{ => segmentation}/test_camvid.py | 18 ++-- .../segmentation/test_cityscapes.py | 31 ++++++ tests/datamodules/segmentation/test_muad.py | 31 ++++++ .../datasets/segmentation/test_cityscapes.py | 11 +++ tests/datasets/test_muad.py | 11 +++ tests/models/test_segformer.py | 25 +++++ 13 files changed, 392 insertions(+), 13 deletions(-) create mode 100644 tests/datamodules/segmentation/__init__.py rename tests/datamodules/{ => segmentation}/test_camvid.py (61%) create mode 100644 tests/datamodules/segmentation/test_cityscapes.py create mode 100644 tests/datamodules/segmentation/test_muad.py create mode 100644 tests/datasets/segmentation/test_cityscapes.py create mode 100644 tests/datasets/test_muad.py create mode 100644 tests/models/test_segformer.py diff --git a/tests/_dummies/baseline.py b/tests/_dummies/baseline.py index f51f8c96..2668e93b 100644 --- a/tests/_dummies/baseline.py +++ b/tests/_dummies/baseline.py @@ -9,10 +9,14 @@ IndptNormalLayer, ) from torch_uncertainty.models.deep_ensembles import deep_ensembles -from torch_uncertainty.routines import ClassificationRoutine, RegressionRoutine +from torch_uncertainty.routines import ( + ClassificationRoutine, + RegressionRoutine, + SegmentationRoutine, +) from torch_uncertainty.transforms import RepeatTarget -from .model import dummy_model +from .model import dummy_model, dummy_segmentation_model class DummyClassificationBaseline: @@ -120,3 +124,38 @@ def __new__( optim_recipe=optim_recipe, format_batch_fn=RepeatTarget(2), ) + + +class DummySegmentationBaseline: + def __new__( + cls, + in_channels: int, + num_classes: int, + loss: type[nn.Module], + ensemble: bool = False, + ) -> LightningModule: + model = dummy_segmentation_model( + in_channels=in_channels, + num_classes=num_classes, + num_estimators=1 + int(ensemble), + ) + + if not ensemble: + return SegmentationRoutine( + num_classes=num_classes, + model=model, + loss=loss, + format_batch_fn=nn.Identity(), + optim_recipe=None, + num_estimators=1, + ) + + # ensemble + return SegmentationRoutine( + num_classes=num_classes, + model=model, + loss=loss, + format_batch_fn=RepeatTarget(2), + optim_recipe=None, + num_estimators=2, + ) diff --git a/tests/_dummies/datamodule.py b/tests/_dummies/datamodule.py index 0b81ccc0..fb05740c 100644 --- a/tests/_dummies/datamodule.py +++ b/tests/_dummies/datamodule.py @@ -1,13 +1,19 @@ from pathlib import Path import numpy as np -import torchvision.transforms as T +import torch +import torchvision.transforms.v2 as T from numpy.typing import ArrayLike from torch.utils.data import DataLoader +from torchvision import tv_tensors from torch_uncertainty.datamodules.abstract import AbstractDataModule -from .dataset import DummyClassificationDataset, DummyRegressionDataset +from .dataset import ( + DummyClassificationDataset, + DummyRegressionDataset, + DummySegmentationDataset, +) class DummyClassificationDataModule(AbstractDataModule): @@ -151,3 +157,89 @@ def setup(self, stage: str | None = None) -> None: def test_dataloader(self) -> DataLoader | list[DataLoader]: return [self._data_loader(self.test)] + + +class DummySegmentationDataModule(AbstractDataModule): + num_channels = 3 + training_task = "segmentation" + + def __init__( + self, + root: str | Path, + batch_size: int, + num_classes: int = 2, + num_workers: int = 1, + pin_memory: bool = True, + persistent_workers: bool = True, + num_images: int = 2, + ) -> None: + super().__init__( + root=root, + val_split=None, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=persistent_workers, + ) + + self.num_classes = num_classes + self.num_images = num_images + + self.dataset = DummySegmentationDataset + + self.train_transform = T.ToDtype( + dtype={ + tv_tensors.Image: torch.float32, + tv_tensors.Mask: torch.int64, + "others": None, + }, + scale=True, + ) + self.test_transform = T.ToDtype( + dtype={ + tv_tensors.Image: torch.float32, + tv_tensors.Mask: torch.int64, + "others": None, + }, + scale=True, + ) + + def prepare_data(self) -> None: + pass + + def setup(self, stage: str | None = None) -> None: + if stage == "fit" or stage is None: + self.train = self.dataset( + self.root, + num_channels=self.num_channels, + num_classes=self.num_classes, + image_size=self.image_size, + transforms=self.train_transform, + num_images=self.num_images, + ) + self.val = self.dataset( + self.root, + num_channels=self.num_channels, + num_classes=self.num_classes, + image_size=self.image_size, + transforms=self.test_transform, + num_images=self.num_images, + ) + elif stage == "test": + self.test = self.dataset( + self.root, + num_channels=self.num_channels, + num_classes=self.num_classes, + image_size=self.image_size, + transforms=self.test_transform, + num_images=self.num_images, + ) + + def test_dataloader(self) -> DataLoader | list[DataLoader]: + return [self._data_loader(self.test)] + + def _get_train_data(self) -> ArrayLike: + return self.train.data + + def _get_train_targets(self) -> ArrayLike: + return np.array(self.train.targets) diff --git a/tests/_dummies/dataset.py b/tests/_dummies/dataset.py index 59183811..4aa858ef 100644 --- a/tests/_dummies/dataset.py +++ b/tests/_dummies/dataset.py @@ -6,6 +6,7 @@ import torch from PIL import Image from torch.utils.data import Dataset +from torchvision import tv_tensors class DummyClassificationDataset(Dataset): @@ -156,3 +157,58 @@ def __getitem__(self, index: int) -> tuple[Any, Any]: def __len__(self) -> int: return len(self.data) + + +class DummySegmentationDataset(Dataset): + def __init__( + self, + root: Path, + split: str = "train", + transforms: Callable[..., Any] | None = None, + num_channels: int = 3, + image_size: int = 4, + num_classes: int = 10, + num_images: int = 2, + **kwargs: Any, + ) -> None: + super().__init__() + + self.root = root + self.split = split + self.transforms = transforms + + self.data: Any = [] + self.targets = [] + + if num_channels == 1: + img_shape = (num_images, image_size, image_size) + else: + img_shape = (num_images, num_channels, image_size, image_size) + + smnt_shape = (num_images, 1, image_size, image_size) + + self.data = np.random.randint( + low=0, + high=255, + size=img_shape, + dtype=np.uint8, + ) + + self.targets = np.random.randint( + low=0, + high=num_classes, + size=smnt_shape, + dtype=np.uint8, + ) + + def __getitem__(self, index: int) -> tuple[Any, Any]: + img = tv_tensors.Image(self.data[index]) + target = tv_tensors.Mask(self.targets[index]) + + if self.transforms is not None: + img, target = self.transforms(img, target) + + return img, target + + def __len__(self) -> int: + return len(self.data) diff --git a/tests/_dummies/model.py b/tests/_dummies/model.py index 09a27806..9e1552e6 100644 --- a/tests/_dummies/model.py +++ b/tests/_dummies/model.py @@ -52,6 +52,35 @@ def feats_forward(self, x: Tensor) -> Tensor: return self.forward(x) +class _DummySegmentation(nn.Module): + def __init__( + self, + in_channels: int, + num_classes: int, + dropout_rate: float, + num_estimators: int, + ) -> None: + super().__init__() + self.dropout_rate = dropout_rate + + self.conv = nn.Conv2d( + in_channels, num_classes, kernel_size=3, padding=1 + ) + self.dropout = nn.Dropout(p=dropout_rate) + + self.num_estimators = num_estimators + + def forward(self, x: Tensor) -> Tensor: + return self.dropout( + self.conv( + torch.ones( + (x.shape[0] * self.num_estimators, 1, 32, 32), + dtype=torch.float32, + ) + ) + ) + + def dummy_model( in_channels: int, num_classes: int, @@ -95,3 +124,29 @@ def dummy_model( with_linear=with_linear, last_layer=last_layer, ) + + +def dummy_segmentation_model( + in_channels: int, + num_classes: int, + dropout_rate: float = 0.0, + num_estimators: int = 1, +) -> nn.Module: + """Dummy segmentation model for testing purposes. + + Args: + in_channels (int): Number of input channels. + num_classes (int): Number of output classes. + dropout_rate (float, optional): Dropout rate. Defaults to 0.0. + num_estimators (int, optional): Number of estimators in the ensemble. + Defaults to 1. + + Returns: + nn.Module: Dummy segmentation model. + """ + return _DummySegmentation( + in_channels=in_channels, + num_classes=num_classes, + dropout_rate=dropout_rate, + num_estimators=num_estimators, + ) diff --git a/tests/baselines/test_standard.py b/tests/baselines/test_standard.py index ce5add04..ce3225b3 100644 --- a/tests/baselines/test_standard.py +++ b/tests/baselines/test_standard.py @@ -9,6 +9,7 @@ WideResNetBaseline, ) from torch_uncertainty.baselines.regression import MLP +from torch_uncertainty.baselines.segmentation import SegFormer class TestStandardBaseline: @@ -125,3 +126,28 @@ def test_errors(self): version="test", hidden_dims=[1], ) + + +class TestStandardSegFormerBaseline: + """Testing the SegFormer baseline class.""" + + def test_standard(self): + net = SegFormer( + num_classes=10, + loss=nn.CrossEntropyLoss, + version="std", + arch=0, + ) + summary(net) + + _ = net.criterion + _ = net(torch.rand(1, 3, 32, 32)) + + def test_errors(self): + with pytest.raises(ValueError): + SegFormer( + num_classes=10, + loss=nn.CrossEntropyLoss, + version="test", + arch=0, + ) diff --git a/tests/datamodules/classification/test_cifar10_datamodule.py b/tests/datamodules/classification/test_cifar10_datamodule.py index 3bc5931f..405dc6b9 100644 --- a/tests/datamodules/classification/test_cifar10_datamodule.py +++ b/tests/datamodules/classification/test_cifar10_datamodule.py @@ -24,7 +24,6 @@ def test_cifar10_main(self): dm.prepare_data() dm.setup() - dm.setup("test") with pytest.raises(ValueError): dm.setup("xxx") @@ -66,7 +65,6 @@ def test_cifar10_main(self): dm.dataset = DummyClassificationDataset dm.ood_dataset = DummyClassificationDataset dm.setup() - dm.setup("test") dm.train_dataloader() # args.cutout = 8 diff --git a/tests/datamodules/segmentation/__init__.py b/tests/datamodules/segmentation/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/datamodules/test_camvid.py b/tests/datamodules/segmentation/test_camvid.py similarity index 61% rename from tests/datamodules/test_camvid.py rename to tests/datamodules/segmentation/test_camvid.py index e9cf05ff..9ccf4d0c 100644 --- a/tests/datamodules/test_camvid.py +++ b/tests/datamodules/segmentation/test_camvid.py @@ -1,6 +1,6 @@ import pytest -from tests._dummies.dataset import DummyClassificationDataset +from tests._dummies.dataset import DummySegmentationDataset from torch_uncertainty.datamodules.segmentation import CamVidDataModule from torch_uncertainty.datasets.segmentation import CamVid @@ -9,19 +9,23 @@ class TestCamVidDataModule: """Testing the CamVidDataModule datamodule.""" def test_camvid_main(self): - # parser = ArgumentParser() - # parser = CIFAR10DataModule.add_argparse_args(parser) - - # Simulate that cutout is set to 16 dm = CamVidDataModule(root="./data/", batch_size=128) assert dm.dataset == CamVid - dm.dataset = DummyClassificationDataset + dm.dataset = DummySegmentationDataset dm.prepare_data() dm.setup() - dm.setup("test") with pytest.raises(ValueError): dm.setup("xxx") + + # test abstract methods + dm.get_train_set() + dm.get_val_set() + dm.get_test_set() + + dm.train_dataloader() + dm.val_dataloader() + dm.test_dataloader() diff --git a/tests/datamodules/segmentation/test_cityscapes.py b/tests/datamodules/segmentation/test_cityscapes.py new file mode 100644 index 00000000..3dcd74ec --- /dev/null +++ b/tests/datamodules/segmentation/test_cityscapes.py @@ -0,0 +1,31 @@ +import pytest + +from tests._dummies.dataset import DummySegmentationDataset +from torch_uncertainty.datamodules.segmentation import CityscapesDataModule +from torch_uncertainty.datasets.segmentation import Cityscapes + + +class TestCityscapesDataModule: + """Testing the CityscapesDataModule datamodule.""" + + def test_camvid_main(self): + dm = CityscapesDataModule(root="./data/", batch_size=128) + + assert dm.dataset == Cityscapes + + dm.dataset = DummySegmentationDataset + + dm.prepare_data() + dm.setup() + + with pytest.raises(ValueError): + dm.setup("xxx") + + # test abstract methods + dm.get_train_set() + dm.get_val_set() + dm.get_test_set() + + dm.train_dataloader() + dm.val_dataloader() + dm.test_dataloader() diff --git a/tests/datamodules/segmentation/test_muad.py b/tests/datamodules/segmentation/test_muad.py new file mode 100644 index 00000000..314b172c --- /dev/null +++ b/tests/datamodules/segmentation/test_muad.py @@ -0,0 +1,31 @@ +import pytest + +from tests._dummies.dataset import DummySegmentationDataset +from torch_uncertainty.datamodules.segmentation import MUADDataModule +from torch_uncertainty.datasets import MUAD + + +class TestMUADDataModule: + """Testing the MUADDataModule datamodule.""" + + def test_camvid_main(self): + dm = MUADDataModule(root="./data/", batch_size=128) + + assert dm.dataset == MUAD + + dm.dataset = DummySegmentationDataset + + dm.prepare_data() + dm.setup() + + with pytest.raises(ValueError): + dm.setup("xxx") + + # test abstract methods + dm.get_train_set() + dm.get_val_set() + dm.get_test_set() + + dm.train_dataloader() + dm.val_dataloader() + dm.test_dataloader() diff --git a/tests/datasets/segmentation/test_cityscapes.py b/tests/datasets/segmentation/test_cityscapes.py new file mode 100644 index 00000000..c9b9e6f5 --- /dev/null +++ b/tests/datasets/segmentation/test_cityscapes.py @@ -0,0 +1,11 @@ +import pytest + +from torch_uncertainty.datasets.segmentation import Cityscapes + + +class TestCityscapes: + """Testing the Cityscapes dataset class.""" + + def test_nodataset(self): + with pytest.raises(RuntimeError): + _ = Cityscapes("./.data") diff --git a/tests/datasets/test_muad.py b/tests/datasets/test_muad.py new file mode 100644 index 00000000..3a431f3f --- /dev/null +++ b/tests/datasets/test_muad.py @@ -0,0 +1,11 @@ +import pytest + +from torch_uncertainty.datasets import MUAD + + +class TestMUAD: + """Testing the MUAD dataset class.""" + + def test_nodataset(self): + with pytest.raises(FileNotFoundError): + _ = MUAD("./.data", split="train") diff --git a/tests/models/test_segformer.py b/tests/models/test_segformer.py new file mode 100644 index 00000000..3e518129 --- /dev/null +++ b/tests/models/test_segformer.py @@ -0,0 +1,25 @@ +import torch + +from torch_uncertainty.models.segmentation.segformer import ( + segformer_b0, + segformer_b1, + segformer_b2, + segformer_b3, + segformer_b4, + segformer_b5, +) + + +class TestSegformer: + """Testing the Segformer class.""" + + def test_main(self): + segformer_b1(10) + segformer_b2(10) + segformer_b3(10) + segformer_b4(10) + segformer_b5(10) + + model = segformer_b0(10) + with torch.no_grad(): + model(torch.randn(1, 3, 32, 32)) From 2aaf0cea6464a5b9bc1836ca2d28b6dfa86a2f16 Mon Sep 17 00:00:00 2001 From: Olivier Date: Thu, 21 Mar 2024 10:34:37 +0100 Subject: [PATCH 075/148] :fire: Remove useless code from __init__.py --- torch_uncertainty/__init__.py | 252 ---------------------------------- 1 file changed, 252 deletions(-) diff --git a/torch_uncertainty/__init__.py b/torch_uncertainty/__init__.py index 867ac13b..e69de29b 100644 --- a/torch_uncertainty/__init__.py +++ b/torch_uncertainty/__init__.py @@ -1,252 +0,0 @@ -# ruff: noqa: F401 -# from argparse import ArgumentParser, Namespace -# from collections import defaultdict -# from pathlib import Path -# from typing import Any - -# import numpy as np -# import pytorch_lightning as pl -# import torch -# from pytorch_lightning.callbacks import ( -# EarlyStopping, -# LearningRateMonitor, -# ModelCheckpoint, -# ) -# from pytorch_lightning.loggers.tensorboard import TensorBoardLogger -# from torchinfo import summary - -# from .datamodules.abstract import AbstractDataModule -# from .utils import get_version - - -# def init_args( -# network: Any = None, -# datamodule: type[pl.LightningDataModule] | None = None, -# ) -> Namespace: -# parser = ArgumentParser("torch-uncertainty") -# parser.add_argument( -# "--seed", -# type=int, -# default=None, -# help="Random seed to make the training deterministic.", -# ) -# parser.add_argument( -# "--test", -# type=int, -# default=None, -# help="Run in test mode. Set to the checkpoint version number to test.", -# ) -# parser.add_argument( -# "--ckpt", type=int, default=None, help="The number of the checkpoint" -# ) -# parser.add_argument( -# "--summary", -# dest="summary", -# action="store_true", -# help="Print model summary", -# ) -# parser.add_argument("--log_graph", dest="log_graph", action="store_true") -# parser.add_argument( -# "--channels_last", -# action="store_true", -# help="Use channels last memory format", -# ) -# parser.add_argument( -# "--enable_resume", -# action="store_true", -# help="Allow resuming the training (save optimizer's states)", -# ) -# parser.add_argument( -# "--exp_dir", -# type=str, -# default="logs/", -# help="Directory to store experiment files", -# ) -# parser.add_argument( -# "--exp_name", -# type=str, -# default="", -# help="Name of the experiment folder", -# ) -# parser.add_argument( -# "--opt_temp_scaling", -# action="store_true", -# default=False, -# help="Compute optimal temperature on the test set", -# ) -# parser.add_argument( -# "--val_temp_scaling", -# action="store_true", -# default=False, -# help="Compute temperature on the validation set", -# ) -# parser = pl.Trainer.add_argparse_args(parser) -# if network is not None: -# parser = network.add_model_specific_args(parser) - -# if datamodule is not None: -# parser = datamodule.add_argparse_args(parser) - -# return parser.parse_args() - - -# def cli_main( -# network: pl.LightningModule | list[pl.LightningModule], -# datamodule: AbstractDataModule | list[AbstractDataModule], -# root: Path | str, -# net_name: str, -# args: Namespace, -# ) -> list[dict]: -# if isinstance(root, str): -# root = Path(root) - -# if isinstance(datamodule, list): -# training_task = datamodule[0].dm.training_task -# else: -# training_task = datamodule.training_task -# if training_task == "classification": -# monitor = "cls_val/acc" -# mode = "max" -# elif training_task == "regression": -# monitor = "reg_val/mse" -# mode = "min" -# else: -# raise ValueError("Unknown problem type.") - -# if args.test is None and args.max_epochs is None: -# print( -# "Setting max_epochs to 1 for testing purposes. Set max_epochs" -# " manually to train the model." -# ) -# args.max_epochs = 1 - -# if isinstance(args.seed, int): -# pl.seed_everything(args.seed, workers=True) - -# if args.channels_last: -# if isinstance(network, list): -# for i in range(len(network)): -# network[i] = network[i].to(memory_format=torch.channels_last) -# else: -# network = network.to(memory_format=torch.channels_last) - -# if hasattr(args, "use_cv") and args.use_cv: -# test_values = [] -# for i in range(len(datamodule)): -# print( -# f"Starting fold {i} out of {args.train_over} of a {args.n_splits}-fold CV." -# ) - -# # logger -# tb_logger = TensorBoardLogger( -# str(root), -# name=net_name, -# default_hp_metric=False, -# log_graph=args.log_graph, -# version=f"fold_{i}", -# ) - -# # callbacks -# save_checkpoints = ModelCheckpoint( -# dirpath=tb_logger.log_dir, -# monitor=monitor, -# mode=mode, -# save_last=True, -# save_weights_only=not args.enable_resume, -# ) - -# # Select the best model, monitor the lr and stop if NaN -# callbacks = [ -# save_checkpoints, -# LearningRateMonitor(logging_interval="step"), -# EarlyStopping(monitor=monitor, patience=np.inf, check_finite=True), -# ] - -# trainer = pl.Trainer.from_argparse_args( -# args, -# callbacks=callbacks, -# logger=tb_logger, -# deterministic=(args.seed is not None), -# inference_mode=not (args.opt_temp_scaling or args.val_temp_scaling), -# ) -# if args.summary: -# summary( -# network[i], -# input_size=list(datamodule[i].dm.input_shape).insert(0, 1), -# ) -# test_values.append({}) -# else: -# trainer.fit(network[i], datamodule[i]) -# test_values.append( -# trainer.test(datamodule=datamodule[i], ckpt_path="last")[0] -# ) - -# all_test_values = defaultdict(list) -# for test_value in test_values: -# for key in test_value: -# all_test_values[key].append(test_value[key]) - -# avg_test_values = {} -# for key in all_test_values: -# avg_test_values[key] = np.mean(all_test_values[key]) - -# return [avg_test_values] - -# # logger -# tb_logger = TensorBoardLogger( -# str(root), -# name=net_name, -# default_hp_metric=False, -# log_graph=args.log_graph, -# version=args.test, -# ) - -# # callbacks -# save_checkpoints = ModelCheckpoint( -# monitor=monitor, -# mode=mode, -# save_last=True, -# save_weights_only=not args.enable_resume, -# ) - -# # Select the best model, monitor the lr and stop if NaN -# callbacks = [ -# save_checkpoints, -# LearningRateMonitor(logging_interval="step"), -# EarlyStopping(monitor=monitor, patience=np.inf, check_finite=True), -# ] - -# trainer -# trainer = pl.Trainer.from_argparse_args( -# args, -# callbacks=callbacks, -# logger=tb_logger, -# deterministic=(args.seed is not None), -# inference_mode=not (args.opt_temp_scaling or args.val_temp_scaling), -# ) -# if args.summary: -# summary( -# network, -# input_size=list(datamodule.input_shape).insert(0, 1), -# ) -# test_values = [{}] -# elif args.test is not None: -# if args.test >= 0: -# ckpt_file, _ = get_version( -# root=(root / net_name), -# version=args.test, -# checkpoint=args.ckpt, -# ) -# test_values = trainer.test( -# network, datamodule=datamodule, ckpt_path=str(ckpt_file) -# ) -# else: -# test_values = trainer.test(network, datamodule=datamodule) -# else: -# # training and testing -# trainer.fit(network, datamodule) -# if not args.fast_dev_run: -# test_values = trainer.test(datamodule=datamodule, ckpt_path="best") -# else: -# test_values = [{}] -# return test_values From cefff22e32effc59a709d6fc9555ec1edff458c8 Mon Sep 17 00:00:00 2001 From: Olivier Date: Thu, 21 Mar 2024 10:36:09 +0100 Subject: [PATCH 076/148] :hammer: Rename layers & add docstrings --- auto_tutorials_source/tutorial_der_cubic.py | 6 +- tests/_dummies/baseline.py | 12 +-- tests/layers/test_distributions.py | 10 +-- tests/layers/test_filter_response_norm.py | 4 +- tests/models/test_mlps.py | 4 +- torch_uncertainty/layers/distributions.py | 76 +++++++++++++------ .../layers/filter_response_norm.py | 8 +- torch_uncertainty/layers/packed.py | 9 ++- torch_uncertainty/losses.py | 22 ++++-- torch_uncertainty/optim_recipes.py | 6 +- 10 files changed, 99 insertions(+), 58 deletions(-) diff --git a/auto_tutorials_source/tutorial_der_cubic.py b/auto_tutorials_source/tutorial_der_cubic.py index 45cbe122..8d00cde5 100644 --- a/auto_tutorials_source/tutorial_der_cubic.py +++ b/auto_tutorials_source/tutorial_der_cubic.py @@ -43,7 +43,7 @@ from torch_uncertainty.datasets.regression.toy import Cubic from torch_uncertainty.losses import DERLoss from torch_uncertainty.routines import RegressionRoutine -from torch_uncertainty.layers.distributions import IndptNormalInverseGammaLayer +from torch_uncertainty.layers.distributions import NormalInverseGammaLayer # %% # 2. The Optimization Recipe @@ -70,7 +70,7 @@ def optim_regression( # # In the following, we create a trainer to train the model, the same synthetic regression # datasets as in the original DER paper and the model, a simple MLP with 2 hidden layers of 64 neurons each. -# Please note that this MLP finishes with a IndptNormalInverseGammaLayer that interpret the outputs of the model +# Please note that this MLP finishes with a NormalInverseGammaLayer that interpret the outputs of the model # as the parameters of a Normal Inverse Gamma distribution. trainer = Trainer(accelerator="cpu", max_epochs=50)#, enable_progress_bar=False) @@ -90,7 +90,7 @@ def optim_regression( in_features=1, num_outputs=4, hidden_dims=[64, 64], - final_layer=IndptNormalInverseGammaLayer, + final_layer=NormalInverseGammaLayer, final_layer_args={"dim": 1}, ) diff --git a/tests/_dummies/baseline.py b/tests/_dummies/baseline.py index 2668e93b..e1a45ca1 100644 --- a/tests/_dummies/baseline.py +++ b/tests/_dummies/baseline.py @@ -4,9 +4,9 @@ from torch import nn from torch_uncertainty.layers.distributions import ( - IndptLaplaceLayer, - IndptNormalInverseGammaLayer, - IndptNormalLayer, + LaplaceLayer, + NormalInverseGammaLayer, + NormalLayer, ) from torch_uncertainty.models.deep_ensembles import deep_ensembles from torch_uncertainty.routines import ( @@ -82,13 +82,13 @@ def __new__( ) -> LightningModule: if probabilistic: if dist_type == "normal": - last_layer = IndptNormalLayer(num_outputs) + last_layer = NormalLayer(num_outputs) num_classes = num_outputs * 2 elif dist_type == "laplace": - last_layer = IndptLaplaceLayer(num_outputs) + last_layer = LaplaceLayer(num_outputs) num_classes = num_outputs * 2 else: # dist_type == "nig" - last_layer = IndptNormalInverseGammaLayer(num_outputs) + last_layer = NormalInverseGammaLayer(num_outputs) num_classes = num_outputs * 4 else: last_layer = nn.Identity() diff --git a/tests/layers/test_distributions.py b/tests/layers/test_distributions.py index 46e5e0b4..b1fbf4bd 100644 --- a/tests/layers/test_distributions.py +++ b/tests/layers/test_distributions.py @@ -1,16 +1,16 @@ import pytest from torch_uncertainty.layers.distributions import ( - IndptLaplaceLayer, - IndptNormalLayer, + LaplaceLayer, + NormalLayer, ) class TestDistributions: def test_errors(self): with pytest.raises(ValueError): - IndptNormalLayer(-1, 1) + NormalLayer(-1, 1) with pytest.raises(ValueError): - IndptNormalLayer(1, -1) + NormalLayer(1, -1) with pytest.raises(ValueError): - IndptLaplaceLayer(1, -1) + LaplaceLayer(1, -1) diff --git a/tests/layers/test_filter_response_norm.py b/tests/layers/test_filter_response_norm.py index bde2d534..e1f58eb1 100644 --- a/tests/layers/test_filter_response_norm.py +++ b/tests/layers/test_filter_response_norm.py @@ -5,7 +5,7 @@ FilterResponseNorm1d, FilterResponseNorm2d, FilterResponseNorm3d, - FilterResponseNormNd, + _FilterResponseNormNd, ) from torch_uncertainty.layers.mc_batch_norm import ( MCBatchNorm1d, @@ -27,7 +27,7 @@ def test_main(self): def test_errors(self): """Test errors.""" with pytest.raises(ValueError): - FilterResponseNormNd(-1, 1) + _FilterResponseNormNd(-1, 1) with pytest.raises(ValueError): FilterResponseNorm2d(0) with pytest.raises(ValueError): diff --git a/tests/models/test_mlps.py b/tests/models/test_mlps.py index 93aaabc0..2e7a72e8 100644 --- a/tests/models/test_mlps.py +++ b/tests/models/test_mlps.py @@ -1,4 +1,4 @@ -from torch_uncertainty.layers.distributions import IndptNormalLayer +from torch_uncertainty.layers.distributions import NormalLayer from torch_uncertainty.models.mlp import bayesian_mlp, mlp, packed_mlp @@ -10,7 +10,7 @@ def test_mlps(self): 1, 1, hidden_dims=[1, 1, 1], - final_layer=IndptNormalLayer, + final_layer=NormalLayer, final_layer_args={"dim": 1}, ) mlp(1, 1, hidden_dims=[]) diff --git a/torch_uncertainty/layers/distributions.py b/torch_uncertainty/layers/distributions.py index 70667299..108cc8c4 100644 --- a/torch_uncertainty/layers/distributions.py +++ b/torch_uncertainty/layers/distributions.py @@ -5,7 +5,7 @@ from torch_uncertainty.utils.distributions import NormalInverseGamma -class AbstractDistLayer(nn.Module): +class _AbstractDist(nn.Module): def __init__(self, dim: int) -> None: super().__init__() if dim < 1: @@ -16,61 +16,89 @@ def forward(self, x: Tensor) -> Distribution: raise NotImplementedError -class IndptNormalLayer(AbstractDistLayer): - def __init__(self, dim: int, min_scale: float = 1e-6) -> None: +class NormalLayer(_AbstractDist): + """Normal distribution layer. + + Converts model outputs to Independent Normal distributions. + + Args: + dim (int): The number of independent dimensions for each prediction. + eps (float): The minimal value of the :attr:`scale` parameter. + """ + + def __init__(self, dim: int, eps: float = 1e-6) -> None: super().__init__(dim) - if min_scale <= 0: - raise ValueError(f"min_scale must be positive, got {min_scale}.") - self.min_scale = min_scale + if eps <= 0: + raise ValueError(f"eps must be positive, got {eps}.") + self.eps = eps def forward(self, x: Tensor) -> Normal: - """Forward pass of the independent normal distribution layer. + r"""Forward pass of the Normal distribution layer. Args: - x (Tensor): The input tensor of shape (dx2). + x (Tensor): A tensor of shape (:attr:`dim` :math:`\times`2). Returns: - Normal: The independent normal distribution. + Normal: The output normal distribution. """ loc = x[:, : self.dim] - scale = F.softplus(x[:, self.dim :]) + self.min_scale + scale = F.softplus(x[:, self.dim :]) + self.eps return Normal(loc, scale) -class IndptLaplaceLayer(AbstractDistLayer): - def __init__(self, dim: int, min_scale: float = 1e-6) -> None: +class LaplaceLayer(_AbstractDist): + """Laplace distribution layer. + + Converts model outputs to Independent Laplace distributions. + + Args: + dim (int): The number of independent dimensions for each prediction. + eps (float): The minimal value of the :attr:`scale` parameter. + """ + + def __init__(self, dim: int, eps: float = 1e-6) -> None: super().__init__(dim) - if min_scale <= 0: - raise ValueError(f"min_scale must be positive, got {min_scale}.") - self.min_scale = min_scale + if eps <= 0: + raise ValueError(f"eps must be positive, got {eps}.") + self.eps = eps def forward(self, x: Tensor) -> Laplace: - """Forward pass of the independent Laplace distribution layer. + r"""Forward pass of the Laplace distribution layer. Args: - x (Tensor): The input tensor of shape (dx2). + x (Tensor): A tensor of shape (:attr:`dim` :math:`\times`2). Returns: - Laplace: The independent Laplace distribution. + Laplace: The output Laplace distribution. """ loc = x[:, : self.dim] - scale = F.softplus(x[:, self.dim :]) + self.min_scale + scale = F.softplus(x[:, self.dim :]) + self.eps return Laplace(loc, scale) -class IndptNormalInverseGammaLayer(AbstractDistLayer): +class NormalInverseGammaLayer(_AbstractDist): + """Normal-Inverse-Gamma distribution layer. + + Converts model outputs to Independent Normal-Inverse-Gamma distributions. + + Args: + dim (int): The number of independent dimensions for each prediction. + eps (float): The minimal values of the :attr:`lmbda`, :attr:`alpha`-1 + and :attr:`beta` parameters. + """ + def __init__(self, dim: int, eps: float = 1e-6) -> None: super().__init__(dim) self.eps = eps - def forward(self, x: Tensor) -> Laplace: - """Forward pass of the independent Laplace distribution layer. + def forward(self, x: Tensor) -> NormalInverseGamma: + r"""Forward pass of the NormalInverseGamma distribution layer. Args: - x (Tensor): The input tensor of shape (dx2). + x (Tensor): A tensor of shape (:attr:`dim` :math:`\times`4). Returns: - Laplace: The independent Laplace distribution. + NormalInverseGamma: The output NormalInverseGamma distribution. """ loc = x[:, : self.dim] lmbda = F.softplus(x[:, self.dim : 2 * self.dim]) + self.eps diff --git a/torch_uncertainty/layers/filter_response_norm.py b/torch_uncertainty/layers/filter_response_norm.py index ade98582..8c9f3aee 100644 --- a/torch_uncertainty/layers/filter_response_norm.py +++ b/torch_uncertainty/layers/filter_response_norm.py @@ -2,7 +2,7 @@ from torch import Tensor, nn -class FilterResponseNormNd(nn.Module): +class _FilterResponseNormNd(nn.Module): def __init__( self, dimension: int, @@ -49,7 +49,7 @@ def forward(self, x: Tensor) -> Tensor: return torch.max(y, self.tau) -class FilterResponseNorm1d(FilterResponseNormNd): +class FilterResponseNorm1d(_FilterResponseNormNd): def __init__( self, num_channels: int, eps: float = 1e-6, device=None, dtype=None ) -> None: @@ -70,7 +70,7 @@ def __init__( ) -class FilterResponseNorm2d(FilterResponseNormNd): +class FilterResponseNorm2d(_FilterResponseNormNd): def __init__( self, num_channels: int, eps: float = 1e-6, device=None, dtype=None ) -> None: @@ -91,7 +91,7 @@ def __init__( ) -class FilterResponseNorm3d(FilterResponseNormNd): +class FilterResponseNorm3d(_FilterResponseNormNd): def __init__( self, num_channels: int, eps: float = 1e-6, device=None, dtype=None ) -> None: diff --git a/torch_uncertainty/layers/packed.py b/torch_uncertainty/layers/packed.py index 38bcb7db..336c5576 100644 --- a/torch_uncertainty/layers/packed.py +++ b/torch_uncertainty/layers/packed.py @@ -8,6 +8,13 @@ def check_packed_parameters_consistency( alpha: float, num_estimators: int, gamma: int ) -> None: + """Check the consistency of the parameters of the Packed-Ensembles layers. + + Args: + alpha (float): The width multiplier of the layer. + num_estimators (int): The number of estimators in the ensemble. + gamma (int): The number of groups in the ensemble. + """ if alpha is None: raise ValueError("You must specify the value of the arg. `alpha`") @@ -92,7 +99,7 @@ def __init__( this constraint. Note: - The input should be of size (`batch_size`, :attr:`in_features`, 1, + The input should be of shape (`batch_size`, :attr:`in_features`, 1, 1). The (often) necessary rearrange operation is executed by default. """ diff --git a/torch_uncertainty/losses.py b/torch_uncertainty/losses.py index a2f7e863..eeb6bd32 100644 --- a/torch_uncertainty/losses.py +++ b/torch_uncertainty/losses.py @@ -1,7 +1,8 @@ from typing import Literal import torch -from torch import Tensor, distributions, nn +from torch import Tensor, nn +from torch.distributions import Distribution from torch.nn import functional as F from torch_uncertainty.layers.bayesian import bayesian_modules @@ -12,18 +13,23 @@ class DistributionNLLLoss(nn.Module): def __init__( self, reduction: Literal["mean", "sum"] | None = "mean" ) -> None: - """Negative Log-Likelihood loss for a given distribution. + """Negative Log-Likelihood loss using given distributions as inputs. Args: reduction (str, optional): specifies the reduction to apply to the output:``'none'`` | ``'mean'`` | ``'sum'``. Defaults to "mean". - """ super().__init__() self.reduction = reduction - def forward(self, dist: distributions.Distribution, target: Tensor): - loss = -dist.log_prob(target) + def forward(self, dist: Distribution, targets: Tensor) -> Tensor: + """Compute the NLL of the targets given predicted distributions. + + Args: + dist (Distribution): The predicted distributions + targets (Tensor): The target values + """ + loss = -dist.log_prob(targets) if self.reduction == "mean": loss = loss.mean() elif self.reduction == "sum": @@ -106,11 +112,11 @@ def __init__( self.num_samples = num_samples def forward(self, inputs: Tensor, targets: Tensor) -> Tensor: - """Gather the kl divergence from the bayesian modules and aggregate + """Gather the KL divergence from the bayesian modules and aggregate the ELBO loss for a given network. Args: - inputs (Tensor): The *inputs* of the Bayesian Neural Network + inputs (Tensor): The inputs of the Bayesian Neural Network targets (Tensor): The target values Returns: @@ -128,7 +134,7 @@ class DERLoss(DistributionNLLLoss): def __init__( self, reg_weight: float, reduction: str | None = "mean" ) -> None: - """The Normal Inverse-Gamma loss. + """The Deep Evidential loss. This loss combines the negative log-likelihood loss of the normal inverse gamma distribution and a weighted regularization term. diff --git a/torch_uncertainty/optim_recipes.py b/torch_uncertainty/optim_recipes.py index 9782a5e1..4c8be004 100644 --- a/torch_uncertainty/optim_recipes.py +++ b/torch_uncertainty/optim_recipes.py @@ -8,19 +8,19 @@ __all__ = [ "optim_cifar10_resnet18", + "optim_cifar10_resnet34", "optim_cifar10_resnet50", "optim_cifar10_wideresnet", "optim_cifar10_vgg16", "optim_cifar100_resnet18", + "optim_cifar100_resnet34", "optim_cifar100_resnet50", "optim_cifar100_vgg16", "optim_imagenet_resnet50", "optim_imagenet_resnet50_a3", - "optim_regression", - "optim_cifar10_resnet34", - "optim_cifar100_resnet34", "optim_tinyimagenet_resnet34", "optim_tinyimagenet_resnet50", + "optim_regression", ] From deb1cea7c4cbc8ec8cba18e4d381e3d1ba9bb30a Mon Sep 17 00:00:00 2001 From: Olivier Date: Thu, 21 Mar 2024 11:25:27 +0100 Subject: [PATCH 077/148] :fire: Remove useless import --- experiments/classification/cifar10/resnet.py | 1 - experiments/classification/cifar10/vgg.py | 1 - .../classification/cifar10/wideresnet.py | 1 - experiments/segmentation/camvid/segformer.py | 1 - .../segmentation/cityscapes/segformer.py | 1 - experiments/segmentation/muad/segformer.py | 1 - tests/_dummies/baseline.py | 20 +++++++++---------- 7 files changed, 10 insertions(+), 16 deletions(-) diff --git a/experiments/classification/cifar10/resnet.py b/experiments/classification/cifar10/resnet.py index 66fc6cc4..6deddd4c 100644 --- a/experiments/classification/cifar10/resnet.py +++ b/experiments/classification/cifar10/resnet.py @@ -1,6 +1,5 @@ import torch from lightning.pytorch.cli import LightningArgumentParser -from lightning.pytorch.loggers import TensorBoardLogger # noqa: F401 from torch_uncertainty.baselines.classification import ResNetBaseline from torch_uncertainty.datamodules import CIFAR10DataModule diff --git a/experiments/classification/cifar10/vgg.py b/experiments/classification/cifar10/vgg.py index 393ecbc0..a4614e3a 100644 --- a/experiments/classification/cifar10/vgg.py +++ b/experiments/classification/cifar10/vgg.py @@ -1,6 +1,5 @@ import torch from lightning.pytorch.cli import LightningArgumentParser -from lightning.pytorch.loggers import TensorBoardLogger # noqa: F401 from torch_uncertainty.baselines.classification import VGGBaseline from torch_uncertainty.datamodules import CIFAR10DataModule diff --git a/experiments/classification/cifar10/wideresnet.py b/experiments/classification/cifar10/wideresnet.py index e30c4a8f..03870002 100644 --- a/experiments/classification/cifar10/wideresnet.py +++ b/experiments/classification/cifar10/wideresnet.py @@ -1,6 +1,5 @@ import torch from lightning.pytorch.cli import LightningArgumentParser -from lightning.pytorch.loggers import TensorBoardLogger # noqa: F401 from torch_uncertainty.baselines.classification import WideResNetBaseline from torch_uncertainty.datamodules import CIFAR10DataModule diff --git a/experiments/segmentation/camvid/segformer.py b/experiments/segmentation/camvid/segformer.py index 90369f08..a3e376a7 100644 --- a/experiments/segmentation/camvid/segformer.py +++ b/experiments/segmentation/camvid/segformer.py @@ -1,6 +1,5 @@ import torch from lightning.pytorch.cli import LightningArgumentParser -from lightning.pytorch.loggers import TensorBoardLogger # noqa: F401 from torch_uncertainty.baselines.segmentation import SegFormer from torch_uncertainty.datamodules import CamVidDataModule diff --git a/experiments/segmentation/cityscapes/segformer.py b/experiments/segmentation/cityscapes/segformer.py index ecbb12f2..7472e33c 100644 --- a/experiments/segmentation/cityscapes/segformer.py +++ b/experiments/segmentation/cityscapes/segformer.py @@ -1,6 +1,5 @@ import torch from lightning.pytorch.cli import LightningArgumentParser -from lightning.pytorch.loggers import TensorBoardLogger # noqa: F401 from torch_uncertainty.baselines.segmentation import SegFormer from torch_uncertainty.datamodules import CityscapesDataModule diff --git a/experiments/segmentation/muad/segformer.py b/experiments/segmentation/muad/segformer.py index 5fc741ba..1ee20710 100644 --- a/experiments/segmentation/muad/segformer.py +++ b/experiments/segmentation/muad/segformer.py @@ -1,6 +1,5 @@ import torch from lightning.pytorch.cli import LightningArgumentParser -from lightning.pytorch.loggers import TensorBoardLogger # noqa: F401 from torch_uncertainty.baselines.segmentation import SegFormer from torch_uncertainty.datamodules.segmentation import MUADDataModule diff --git a/tests/_dummies/baseline.py b/tests/_dummies/baseline.py index e1a45ca1..f72d39b3 100644 --- a/tests/_dummies/baseline.py +++ b/tests/_dummies/baseline.py @@ -74,7 +74,7 @@ def __new__( cls, probabilistic: bool, in_features: int, - num_outputs: int, + output_dim: int, loss: type[nn.Module], baseline_type: str = "single", optim_recipe=None, @@ -82,17 +82,17 @@ def __new__( ) -> LightningModule: if probabilistic: if dist_type == "normal": - last_layer = NormalLayer(num_outputs) - num_classes = num_outputs * 2 + last_layer = NormalLayer(output_dim) + num_classes = output_dim * 2 elif dist_type == "laplace": - last_layer = LaplaceLayer(num_outputs) - num_classes = num_outputs * 2 + last_layer = LaplaceLayer(output_dim) + num_classes = output_dim * 2 else: # dist_type == "nig" - last_layer = NormalInverseGammaLayer(num_outputs) - num_classes = num_outputs * 4 + last_layer = NormalInverseGammaLayer(output_dim) + num_classes = output_dim * 4 else: last_layer = nn.Identity() - num_classes = num_outputs + num_classes = output_dim model = dummy_model( in_channels=in_features, @@ -103,7 +103,7 @@ def __new__( if baseline_type == "single": return RegressionRoutine( probabilistic=probabilistic, - num_outputs=num_outputs, + output_dim=output_dim, model=model, loss=loss, num_estimators=1, @@ -117,7 +117,7 @@ def __new__( ) return RegressionRoutine( probabilistic=probabilistic, - num_outputs=num_outputs, + output_dim=output_dim, model=model, loss=loss, num_estimators=2, From 4b1d1c22815dd8146c8e4bca6eeb780e24384fd2 Mon Sep 17 00:00:00 2001 From: Olivier Date: Thu, 21 Mar 2024 11:28:23 +0100 Subject: [PATCH 078/148] :hammer: Add Baseline to MLP baseline & rename parameter --- docs/source/api.rst | 19 ++++++++-- tests/baselines/test_packed.py | 6 +-- tests/baselines/test_standard.py | 10 ++--- tests/routines/test_regression.py | 16 ++++---- .../baselines/regression/__init__.py | 2 +- torch_uncertainty/baselines/regression/mlp.py | 37 ++++++++++++++++--- torch_uncertainty/routines/regression.py | 14 +++---- 7 files changed, 70 insertions(+), 34 deletions(-) diff --git a/docs/source/api.rst b/docs/source/api.rst index 8efe44e1..690e9cc9 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -18,9 +18,9 @@ Classification :nosignatures: :template: class.rst - ResNet - VGG - WideResNet + ResNetBaseline + VGGBaseline + WideResNetBaseline .. currentmodule:: torch_uncertainty.baselines.regression @@ -32,7 +32,18 @@ Regression :nosignatures: :template: class.rst - MLP + MLPBaseline + +Segmentation +^^^^^^^^^^^^ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: class.rst + + SegformerBaseline + .. Models .. ------ diff --git a/tests/baselines/test_packed.py b/tests/baselines/test_packed.py index 746df90d..6a57325c 100644 --- a/tests/baselines/test_packed.py +++ b/tests/baselines/test_packed.py @@ -8,7 +8,7 @@ VGGBaseline, WideResNetBaseline, ) -from torch_uncertainty.baselines.regression import MLP +from torch_uncertainty.baselines.regression import MLPBaseline class TestPackedBaseline: @@ -130,9 +130,9 @@ class TestPackedMLPBaseline: """Testing the Packed MLP baseline class.""" def test_packed(self): - net = MLP( + net = MLPBaseline( in_features=3, - num_outputs=10, + output_dim=10, loss=nn.MSELoss, version="packed", hidden_dims=[1], diff --git a/tests/baselines/test_standard.py b/tests/baselines/test_standard.py index ce3225b3..addad5a4 100644 --- a/tests/baselines/test_standard.py +++ b/tests/baselines/test_standard.py @@ -8,7 +8,7 @@ VGGBaseline, WideResNetBaseline, ) -from torch_uncertainty.baselines.regression import MLP +from torch_uncertainty.baselines.regression import MLPBaseline from torch_uncertainty.baselines.segmentation import SegFormer @@ -105,9 +105,9 @@ class TestStandardMLPBaseline: """Testing the MLP baseline class.""" def test_standard(self): - net = MLP( + net = MLPBaseline( in_features=3, - num_outputs=10, + output_dim=10, loss=nn.MSELoss, version="std", hidden_dims=[1], @@ -119,9 +119,9 @@ def test_standard(self): def test_errors(self): with pytest.raises(ValueError): - MLP( + MLPBaseline( in_features=3, - num_outputs=10, + output_dim=10, loss=nn.MSELoss, version="test", hidden_dims=[1], diff --git a/tests/routines/test_regression.py b/tests/routines/test_regression.py index 5526cfc7..119f9cf6 100644 --- a/tests/routines/test_regression.py +++ b/tests/routines/test_regression.py @@ -22,7 +22,7 @@ def test_one_estimator_one_output(self): model = DummyRegressionBaseline( probabilistic=True, in_features=dm.in_features, - num_outputs=1, + output_dim=1, loss=DistributionNLLLoss, optim_recipe=optim_cifar10_resnet18, baseline_type="single", @@ -36,7 +36,7 @@ def test_one_estimator_one_output(self): model = DummyRegressionBaseline( probabilistic=False, in_features=dm.in_features, - num_outputs=1, + output_dim=1, loss=DistributionNLLLoss, optim_recipe=optim_cifar10_resnet18, baseline_type="single", @@ -56,7 +56,7 @@ def test_one_estimator_two_outputs(self): model = DummyRegressionBaseline( probabilistic=True, in_features=dm.in_features, - num_outputs=2, + output_dim=2, loss=DistributionNLLLoss, optim_recipe=optim_cifar10_resnet18, baseline_type="single", @@ -69,7 +69,7 @@ def test_one_estimator_two_outputs(self): model = DummyRegressionBaseline( probabilistic=False, in_features=dm.in_features, - num_outputs=2, + output_dim=2, loss=DistributionNLLLoss, optim_recipe=optim_cifar10_resnet18, baseline_type="single", @@ -87,7 +87,7 @@ def test_two_estimators_one_output(self): model = DummyRegressionBaseline( probabilistic=True, in_features=dm.in_features, - num_outputs=1, + output_dim=1, loss=DistributionNLLLoss, optim_recipe=optim_cifar10_resnet18, baseline_type="ensemble", @@ -100,7 +100,7 @@ def test_two_estimators_one_output(self): model = DummyRegressionBaseline( probabilistic=False, in_features=dm.in_features, - num_outputs=1, + output_dim=1, loss=DistributionNLLLoss, optim_recipe=optim_cifar10_resnet18, baseline_type="ensemble", @@ -118,7 +118,7 @@ def test_two_estimators_two_outputs(self): model = DummyRegressionBaseline( probabilistic=True, in_features=dm.in_features, - num_outputs=2, + output_dim=2, loss=DistributionNLLLoss, optim_recipe=optim_cifar10_resnet18, baseline_type="ensemble", @@ -131,7 +131,7 @@ def test_two_estimators_two_outputs(self): model = DummyRegressionBaseline( probabilistic=False, in_features=dm.in_features, - num_outputs=2, + output_dim=2, loss=DistributionNLLLoss, optim_recipe=optim_cifar10_resnet18, baseline_type="ensemble", diff --git a/torch_uncertainty/baselines/regression/__init__.py b/torch_uncertainty/baselines/regression/__init__.py index 9320254f..b4a1391a 100644 --- a/torch_uncertainty/baselines/regression/__init__.py +++ b/torch_uncertainty/baselines/regression/__init__.py @@ -1,2 +1,2 @@ # ruff: noqa: F401 -from .mlp import MLP +from .mlp import MLPBaseline diff --git a/torch_uncertainty/baselines/regression/mlp.py b/torch_uncertainty/baselines/regression/mlp.py index 1dc2ba80..5c523072 100644 --- a/torch_uncertainty/baselines/regression/mlp.py +++ b/torch_uncertainty/baselines/regression/mlp.py @@ -2,6 +2,11 @@ from torch import nn +from torch_uncertainty.layers.distributions import ( + LaplaceLayer, + NormalInverseGammaLayer, + NormalLayer, +) from torch_uncertainty.models.mlp import mlp, packed_mlp from torch_uncertainty.routines.regression import ( RegressionRoutine, @@ -9,14 +14,14 @@ from torch_uncertainty.transforms.batch import RepeatTarget -class MLP(RegressionRoutine): +class MLPBaseline(RegressionRoutine): single = ["std"] ensemble = ["packed"] versions = {"std": mlp, "packed": packed_mlp} def __init__( self, - num_outputs: int, + output_dim: int, in_features: int, loss: type[nn.Module], version: Literal["std", "packed"], @@ -25,15 +30,37 @@ def __init__( dropout_rate: float = 0.0, alpha: float | None = None, gamma: int = 1, + distribution: Literal["normal", "laplace", "nig"] | None = None, ) -> None: r"""MLP baseline for regression providing support for various versions.""" + probabilistic = True params = { "dropout_rate": dropout_rate, "in_features": in_features, - "num_outputs": num_outputs, + "num_outputs": output_dim, "hidden_dims": hidden_dims, } + if distribution == "normal": + final_layer = NormalLayer + final_layer_args = {"dim": output_dim} + params["num_outputs"] *= 2 + elif distribution == "laplace": + final_layer = LaplaceLayer + final_layer_args = {"dim": output_dim} + params["num_outputs"] *= 2 + elif distribution == "nig": + final_layer = NormalInverseGammaLayer + final_layer_args = {"dim": output_dim} + params["num_outputs"] *= 4 + elif distribution is None: + probabilistic = False + final_layer = nn.Identity + final_layer_args = {} + + params["final_layer"] = final_layer + params["final_layer_args"] = final_layer_args + format_batch_fn = nn.Identity() if version not in self.versions: @@ -51,8 +78,8 @@ def __init__( # version in self.versions: super().__init__( - probabilistic=False, - num_outputs=num_outputs, + probabilistic=probabilistic, + output_dim=output_dim, model=model, loss=loss, num_estimators=num_estimators, diff --git a/torch_uncertainty/routines/regression.py b/torch_uncertainty/routines/regression.py index 86f3ae2b..6b11b555 100644 --- a/torch_uncertainty/routines/regression.py +++ b/torch_uncertainty/routines/regression.py @@ -18,7 +18,7 @@ class RegressionRoutine(LightningModule): def __init__( self, probabilistic: bool, - num_outputs: int, + output_dim: int, model: nn.Module, loss: type[nn.Module], num_estimators: int = 1, @@ -30,7 +30,7 @@ def __init__( Args: probabilistic (bool): Whether the model is probabilistic, i.e., outputs a PyTorch distribution. - num_outputs (int): The number of outputs of the model. + output_dim (int): The number of outputs of the model. model (nn.Module): The model to train. loss (type[nn.Module]): The loss function to use. num_estimators (int, optional): The number of estimators for the @@ -78,14 +78,12 @@ def __init__( ) self.num_estimators = num_estimators - if num_outputs < 1: - raise ValueError( - f"num_outputs must be positive, got {num_outputs}." - ) - self.num_outputs = num_outputs + if output_dim < 1: + raise ValueError(f"output_dim must be positive, got {output_dim}.") + self.output_dim = output_dim self.one_dim_regression = False - if num_outputs == 1: + if output_dim == 1: self.one_dim_regression = True self.optim_recipe = optim_recipe From 1af083c9b3d3f85e000ac79e2d2bd09f0ba9c062 Mon Sep 17 00:00:00 2001 From: Olivier Date: Thu, 21 Mar 2024 11:32:39 +0100 Subject: [PATCH 079/148] :wrapped_gift: Add MLP regression configs --- .../configs/gaussian_mlp_kin8nm.yaml | 43 +++++++++++++++++ .../configs/laplace_mlp_kin8nm.yaml | 43 +++++++++++++++++ .../uci_datasets/configs/pw_mlp_kin8nm.yaml | 42 ++++++++++++++++ .../regression/uci_datasets/mlp-kin8nm.py | 48 ------------------- experiments/regression/uci_datasets/mlp.py | 26 ++++++++++ 5 files changed, 154 insertions(+), 48 deletions(-) create mode 100644 experiments/regression/uci_datasets/configs/gaussian_mlp_kin8nm.yaml create mode 100644 experiments/regression/uci_datasets/configs/laplace_mlp_kin8nm.yaml create mode 100644 experiments/regression/uci_datasets/configs/pw_mlp_kin8nm.yaml delete mode 100644 experiments/regression/uci_datasets/mlp-kin8nm.py create mode 100644 experiments/regression/uci_datasets/mlp.py diff --git a/experiments/regression/uci_datasets/configs/gaussian_mlp_kin8nm.yaml b/experiments/regression/uci_datasets/configs/gaussian_mlp_kin8nm.yaml new file mode 100644 index 00000000..8fd87576 --- /dev/null +++ b/experiments/regression/uci_datasets/configs/gaussian_mlp_kin8nm.yaml @@ -0,0 +1,43 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 10 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/gaussian_mlp_kin8nm + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: reg_val/nll + mode: min + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: reg_val/nll + patience: 1000 + check_finite: true +model: + output_dim: 1 + in_features: 8 + hidden_dims: + - 100 + loss: torch_uncertainty.losses.DistributionNLLLoss + version: std + distribution: normal +data: + root: ./data + batch_size: 128 + dataset_name: kin8nm +optimizer: + lr: 5e-3 + weight_decay: 0 diff --git a/experiments/regression/uci_datasets/configs/laplace_mlp_kin8nm.yaml b/experiments/regression/uci_datasets/configs/laplace_mlp_kin8nm.yaml new file mode 100644 index 00000000..3c42f176 --- /dev/null +++ b/experiments/regression/uci_datasets/configs/laplace_mlp_kin8nm.yaml @@ -0,0 +1,43 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 10 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/gaussian_mlp_kin8nm + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: reg_val/nll + mode: min + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: reg_val/nll + patience: 1000 + check_finite: true +model: + output_dim: 1 + in_features: 8 + hidden_dims: + - 100 + loss: torch_uncertainty.losses.DistributionNLLLoss + version: std + distribution: laplace +data: + root: ./data + batch_size: 128 + dataset_name: kin8nm +optimizer: + lr: 5e-3 + weight_decay: 0 diff --git a/experiments/regression/uci_datasets/configs/pw_mlp_kin8nm.yaml b/experiments/regression/uci_datasets/configs/pw_mlp_kin8nm.yaml new file mode 100644 index 00000000..e1a4cda3 --- /dev/null +++ b/experiments/regression/uci_datasets/configs/pw_mlp_kin8nm.yaml @@ -0,0 +1,42 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 10 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/pw_mlp_kin8nm + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: reg_val/mse + mode: min + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: reg_val/mse + patience: 1000 + check_finite: true +model: + output_dim: 1 + in_features: 8 + hidden_dims: + - 100 + loss: torch.nn.MSELoss + version: std +data: + root: ./data + batch_size: 128 + dataset_name: kin8nm +optimizer: + lr: 5e-3 + weight_decay: 0 diff --git a/experiments/regression/uci_datasets/mlp-kin8nm.py b/experiments/regression/uci_datasets/mlp-kin8nm.py deleted file mode 100644 index ed796f94..00000000 --- a/experiments/regression/uci_datasets/mlp-kin8nm.py +++ /dev/null @@ -1,48 +0,0 @@ -from pathlib import Path - -from torch import nn, optim - -from torch_uncertainty import cli_main, init_args -from torch_uncertainty.baselines.regression.mlp import MLP -from torch_uncertainty.datamodules import UCIDataModule - - -def optim_regression( - model: nn.Module, - learning_rate: float = 5e-3, -) -> dict: - optimizer = optim.Adam( - model.parameters(), - lr=learning_rate, - weight_decay=0, - ) - return { - "optimizer": optimizer, - } - - -if __name__ == "__main__": - args = init_args(MLP, UCIDataModule) - if args.root == "./data/": - root = Path(__file__).parent.absolute().parents[2] - else: - root = Path(args.root) - - net_name = "mlp-kin8nm" - - # datamodule - args.root = str(root / "data") - dm = UCIDataModule(dataset_name="kin8nm", **vars(args)) - - # model - model = MLP( - num_outputs=2, - in_features=8, - hidden_dims=[100], - loss=nn.GaussianNLLLoss, - optim_recipe=optim_regression, - dist_estimation=2, - **vars(args), - ) - - cli_main(model, dm, root, net_name, args) diff --git a/experiments/regression/uci_datasets/mlp.py b/experiments/regression/uci_datasets/mlp.py new file mode 100644 index 00000000..a0605472 --- /dev/null +++ b/experiments/regression/uci_datasets/mlp.py @@ -0,0 +1,26 @@ +import torch +from lightning.pytorch.cli import LightningArgumentParser + +from torch_uncertainty.baselines.regression import MLPBaseline +from torch_uncertainty.datamodules import UCIDataModule +from torch_uncertainty.utils import TULightningCLI + + +class MLPCLI(TULightningCLI): + def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + parser.add_optimizer_args(torch.optim.Adam) + + +def cli_main() -> MLPCLI: + return MLPCLI(MLPBaseline, UCIDataModule) + + +if __name__ == "__main__": + torch.set_float32_matmul_precision("medium") + cli = cli_main() + if ( + (not cli.trainer.fast_dev_run) + and cli.subcommand == "fit" + and cli._get(cli.config, "eval_after_fit") + ): + cli.trainer.test(datamodule=cli.datamodule, ckpt_path="best") From 90049a2e323b0e85f81c21322a5998c58d0790df Mon Sep 17 00:00:00 2001 From: alafage Date: Thu, 21 Mar 2024 11:43:21 +0100 Subject: [PATCH 080/148] :hammer: Handle target resizing within Segmentation Routine --- .../baselines/segmentation/segformer.py | 43 +------------------ torch_uncertainty/routines/segmentation.py | 10 +++++ 2 files changed, 11 insertions(+), 42 deletions(-) diff --git a/torch_uncertainty/baselines/segmentation/segformer.py b/torch_uncertainty/baselines/segmentation/segformer.py index fe32a034..517fbc54 100644 --- a/torch_uncertainty/baselines/segmentation/segformer.py +++ b/torch_uncertainty/baselines/segmentation/segformer.py @@ -1,8 +1,6 @@ from typing import Literal -from einops import rearrange -from torch import Tensor, nn -from torchvision.transforms.v2 import functional as F +from torch import nn from torch_uncertainty.models.segmentation.segformer import ( segformer_b0, @@ -79,42 +77,3 @@ def __init__( format_batch_fn=format_batch_fn, ) self.save_hyperparameters() - - def training_step( - self, batch: tuple[Tensor, Tensor], batch_idx: int - ) -> Tensor: - img, target = batch - logits = self.forward(img) - target = F.resize( - target, logits.shape[-2:], interpolation=F.InterpolationMode.NEAREST - ) - logits = rearrange(logits, "b c h w -> (b h w) c") - target = target.flatten() - valid_mask = target != 255 - loss = self.criterion(logits[valid_mask], target[valid_mask]) - self.log("train_loss", loss) - return loss - - def validation_step( - self, batch: tuple[Tensor, Tensor], batch_idx: int - ) -> None: - img, target = batch - logits = self.forward(img) - target = F.resize( - target, logits.shape[-2:], interpolation=F.InterpolationMode.NEAREST - ) - logits = rearrange(logits, "b c h w -> (b h w) c") - target = target.flatten() - valid_mask = target != 255 - self.val_seg_metrics.update(logits[valid_mask], target[valid_mask]) - - def test_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> None: - img, target = batch - logits = self.forward(img) - target = F.resize( - target, logits.shape[-2:], interpolation=F.InterpolationMode.NEAREST - ) - logits = rearrange(logits, "b c h w -> (b h w) c") - target = target.flatten() - valid_mask = target != 255 - self.test_seg_metrics.update(logits[valid_mask], target[valid_mask]) diff --git a/torch_uncertainty/routines/segmentation.py b/torch_uncertainty/routines/segmentation.py index 5a22d8a5..3b6124a4 100644 --- a/torch_uncertainty/routines/segmentation.py +++ b/torch_uncertainty/routines/segmentation.py @@ -3,6 +3,7 @@ from lightning.pytorch.utilities.types import STEP_OUTPUT from torch import Tensor, nn from torchmetrics import Accuracy, MetricCollection +from torchvision.transforms.v2 import functional as F from torch_uncertainty.metrics import MeanIntersectionOverUnion @@ -65,6 +66,9 @@ def training_step( img, target = batch img, target = self.format_batch_fn((img, target)) logits = self.forward(img) + target = F.resize( + target, logits.shape[-2:], interpolation=F.InterpolationMode.NEAREST + ) logits = rearrange(logits, "b c h w -> (b h w) c") target = target.flatten() valid_mask = target != 255 @@ -77,6 +81,9 @@ def validation_step( ) -> None: img, target = batch logits = self.forward(img) + target = F.resize( + target, logits.shape[-2:], interpolation=F.InterpolationMode.NEAREST + ) logits = rearrange( logits, "(m b) c h w -> (b h w) m c", m=self.num_estimators ) @@ -89,6 +96,9 @@ def validation_step( def test_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> None: img, target = batch logits = self.forward(img) + target = F.resize( + target, logits.shape[-2:], interpolation=F.InterpolationMode.NEAREST + ) logits = rearrange( logits, "(m b) c h w -> (b h w) m c", m=self.num_estimators ) From 5bd811db0b4b648a92369e1d257de00fea357929 Mon Sep 17 00:00:00 2001 From: Olivier Date: Thu, 21 Mar 2024 11:39:55 +0100 Subject: [PATCH 081/148] :wrapped_gift: Update CIFAR-100 configs --- .../cifar100/configs/resnet18/batched.yaml | 2 +- .../cifar100/configs/resnet18/masked.yaml | 4 +- .../cifar100/configs/resnet18/mimo.yaml | 4 +- .../cifar100/configs/resnet18/packed.yaml | 4 +- .../cifar100/configs/resnet18/standard.yaml | 4 +- experiments/classification/cifar100/resnet.py | 49 ++++++++----------- experiments/classification/cifar100/vgg.py | 47 ++++++++---------- .../classification/cifar100/wideresnet.py | 47 ++++++++---------- 8 files changed, 69 insertions(+), 92 deletions(-) diff --git a/experiments/classification/cifar100/configs/resnet18/batched.yaml b/experiments/classification/cifar100/configs/resnet18/batched.yaml index 4410892d..0c14021f 100644 --- a/experiments/classification/cifar100/configs/resnet18/batched.yaml +++ b/experiments/classification/cifar100/configs/resnet18/batched.yaml @@ -46,4 +46,4 @@ lr_scheduler: milestones: - 25 - 50 - gamma: 0.1 + gamma: 0.2 diff --git a/experiments/classification/cifar100/configs/resnet18/masked.yaml b/experiments/classification/cifar100/configs/resnet18/masked.yaml index 3bd98005..c1d6f261 100644 --- a/experiments/classification/cifar100/configs/resnet18/masked.yaml +++ b/experiments/classification/cifar100/configs/resnet18/masked.yaml @@ -41,10 +41,10 @@ data: optimizer: lr: 0.05 momentum: 0.9 - weight_decay: 5e-4 + weight_decay: 1e-4 nesterov: true lr_scheduler: milestones: - 25 - 50 - gamma: 0.1 + gamma: 0.2 diff --git a/experiments/classification/cifar100/configs/resnet18/mimo.yaml b/experiments/classification/cifar100/configs/resnet18/mimo.yaml index ee3efcb9..13fc9d20 100644 --- a/experiments/classification/cifar100/configs/resnet18/mimo.yaml +++ b/experiments/classification/cifar100/configs/resnet18/mimo.yaml @@ -41,10 +41,10 @@ data: optimizer: lr: 0.05 momentum: 0.9 - weight_decay: 5e-4 + weight_decay: 1e-4 nesterov: true lr_scheduler: milestones: - 25 - 50 - gamma: 0.1 + gamma: 0.2 diff --git a/experiments/classification/cifar100/configs/resnet18/packed.yaml b/experiments/classification/cifar100/configs/resnet18/packed.yaml index 2a0d7c47..1dc954d1 100644 --- a/experiments/classification/cifar100/configs/resnet18/packed.yaml +++ b/experiments/classification/cifar100/configs/resnet18/packed.yaml @@ -42,10 +42,10 @@ data: optimizer: lr: 0.05 momentum: 0.9 - weight_decay: 5e-4 + weight_decay: 1e-4 nesterov: true lr_scheduler: milestones: - 25 - 50 - gamma: 0.1 + gamma: 0.2 diff --git a/experiments/classification/cifar100/configs/resnet18/standard.yaml b/experiments/classification/cifar100/configs/resnet18/standard.yaml index 235a6382..bcb1c7e9 100644 --- a/experiments/classification/cifar100/configs/resnet18/standard.yaml +++ b/experiments/classification/cifar100/configs/resnet18/standard.yaml @@ -39,10 +39,10 @@ data: optimizer: lr: 0.05 momentum: 0.9 - weight_decay: 5e-4 + weight_decay: 1e-4 nesterov: true lr_scheduler: milestones: - 25 - 50 - gamma: 0.1 + gamma: 0.2 diff --git a/experiments/classification/cifar100/resnet.py b/experiments/classification/cifar100/resnet.py index 32926e7c..0c3a0068 100644 --- a/experiments/classification/cifar100/resnet.py +++ b/experiments/classification/cifar100/resnet.py @@ -1,36 +1,27 @@ -from pathlib import Path +import torch +from lightning.pytorch.cli import LightningArgumentParser -from torch import nn - -from torch_uncertainty import cli_main, init_args -from torch_uncertainty.baselines import ResNetBaseline +from torch_uncertainty.baselines.classification import ResNetBaseline from torch_uncertainty.datamodules import CIFAR100DataModule -from torch_uncertainty.optim_recipes import get_procedure +from torch_uncertainty.utils import TULightningCLI -if __name__ == "__main__": - args = init_args(ResNetBaseline, CIFAR100DataModule) - if args.root == "./data/": - root = Path(__file__).parent.absolute().parents[2] - else: - root = Path(args.root) - if args.exp_name == "": - args.exp_name = f"{args.version}-resnet{args.arch}-cifar100" +class ResNetCLI(TULightningCLI): + def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + parser.add_optimizer_args(torch.optim.SGD) + parser.add_lr_scheduler_args(torch.optim.lr_scheduler.MultiStepLR) - # datamodule - args.root = str(root / "data") - dm = CIFAR100DataModule(**vars(args)) - # model - model = ResNetBaseline( - num_classes=dm.num_classes, - in_channels=dm.num_channels, - loss=nn.CrossEntropyLoss, - optim_recipe=get_procedure( - f"resnet{args.arch}", "cifar100", args.version - ), - style="cifar", - **vars(args), - ) +def cli_main() -> ResNetCLI: + return ResNetCLI(ResNetBaseline, CIFAR100DataModule) - cli_main(model, dm, args.exp_dir, args.exp_name, args) + +if __name__ == "__main__": + torch.set_float32_matmul_precision("medium") + cli = cli_main() + if ( + (not cli.trainer.fast_dev_run) + and cli.subcommand == "fit" + and cli._get(cli.config, "eval_after_fit") + ): + cli.trainer.test(datamodule=cli.datamodule, ckpt_path="best") diff --git a/experiments/classification/cifar100/vgg.py b/experiments/classification/cifar100/vgg.py index dafdabc9..1936f809 100644 --- a/experiments/classification/cifar100/vgg.py +++ b/experiments/classification/cifar100/vgg.py @@ -1,34 +1,27 @@ -from pathlib import Path +import torch +from lightning.pytorch.cli import LightningArgumentParser -from torch import nn - -from torch_uncertainty import cli_main, init_args -from torch_uncertainty.baselines import VGGBaseline +from torch_uncertainty.baselines.classification import VGGBaseline from torch_uncertainty.datamodules import CIFAR100DataModule -from torch_uncertainty.optim_recipes import get_procedure +from torch_uncertainty.utils import TULightningCLI -if __name__ == "__main__": - args = init_args(VGGBaseline, CIFAR100DataModule) - if args.root == "./data/": - root = Path(__file__).parent.absolute().parents[2] - else: - root = Path(args.root) - if args.exp_name == "": - args.exp_name = f"{args.version}-vgg{args.arch}-cifar100" +class ResNetCLI(TULightningCLI): + def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + parser.add_optimizer_args(torch.optim.Adam) + parser.add_lr_scheduler_args(torch.optim.lr_scheduler.MultiStepLR) - # datamodule - args.root = str(root / "data") - dm = CIFAR100DataModule(**vars(args)) - # model - model = VGGBaseline( - num_classes=dm.num_classes, - in_channels=dm.num_channels, - loss=nn.CrossEntropyLoss, - optim_recipe=get_procedure(f"vgg{args.arch}", "cifar100", args.version), - style="cifar", - **vars(args), - ) +def cli_main() -> ResNetCLI: + return ResNetCLI(VGGBaseline, CIFAR100DataModule) - cli_main(model, dm, args.exp_dir, args.exp_name, args) + +if __name__ == "__main__": + torch.set_float32_matmul_precision("medium") + cli = cli_main() + if ( + (not cli.trainer.fast_dev_run) + and cli.subcommand == "fit" + and cli._get(cli.config, "eval_after_fit") + ): + cli.trainer.test(datamodule=cli.datamodule, ckpt_path="best") diff --git a/experiments/classification/cifar100/wideresnet.py b/experiments/classification/cifar100/wideresnet.py index f2114908..49b9a227 100644 --- a/experiments/classification/cifar100/wideresnet.py +++ b/experiments/classification/cifar100/wideresnet.py @@ -1,34 +1,27 @@ -from pathlib import Path +import torch +from lightning.pytorch.cli import LightningArgumentParser -from torch import nn - -from torch_uncertainty import cli_main, init_args -from torch_uncertainty.baselines import WideResNetBaseline +from torch_uncertainty.baselines.classification import WideResNetBaseline from torch_uncertainty.datamodules import CIFAR100DataModule -from torch_uncertainty.optim_recipes import get_procedure +from torch_uncertainty.utils import TULightningCLI -if __name__ == "__main__": - args = init_args(WideResNetBaseline, CIFAR100DataModule) - if args.root == "./data/": - root = Path(__file__).parent.absolute().parents[2] - else: - root = Path(args.root) - if args.exp_name == "": - args.exp_name = f"{args.version}-wideresnet28x10-cifar100" +class ResNetCLI(TULightningCLI): + def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + parser.add_optimizer_args(torch.optim.SGD) + parser.add_lr_scheduler_args(torch.optim.lr_scheduler.MultiStepLR) - # datamodule - args.root = str(root / "data") - dm = CIFAR100DataModule(**vars(args)) - # model - model = WideResNetBaseline( - num_classes=dm.num_classes, - in_channels=dm.num_channels, - loss=nn.CrossEntropyLoss, - optim_recipe=get_procedure("wideresnet28x10", "cifar100", args.version), - style="cifar", - **vars(args), - ) +def cli_main() -> ResNetCLI: + return ResNetCLI(WideResNetBaseline, CIFAR100DataModule) - cli_main(model, dm, args.exp_dir, args.exp_name, args) + +if __name__ == "__main__": + torch.set_float32_matmul_precision("medium") + cli = cli_main() + if ( + (not cli.trainer.fast_dev_run) + and cli.subcommand == "fit" + and cli._get(cli.config, "eval_after_fit") + ): + cli.trainer.test(datamodule=cli.datamodule, ckpt_path="best") From a9d7d88f81c7b4f1e28d5ca02f84a8646ef516aa Mon Sep 17 00:00:00 2001 From: Olivier Date: Thu, 21 Mar 2024 13:38:58 +0100 Subject: [PATCH 082/148] :books: Improve docs & misc. --- auto_tutorials_source/tutorial_der_cubic.py | 2 +- auto_tutorials_source/tutorial_mc_dropout.py | 2 +- docs/source/api.rst | 209 ++++++++---------- docs/source/contributing.rst | 2 - docs/source/installation.rst | 8 +- docs/source/introduction_uncertainty.rst | 1 - ...on_procedures.py => test_optim_recipes.py} | 5 - .../baselines/classification/resnet.py | 2 +- .../baselines/classification/vgg.py | 2 +- .../baselines/classification/wideresnet.py | 2 +- .../baselines/segmentation/segformer.py | 24 +- torch_uncertainty/models/__init__.py | 1 + torch_uncertainty/models/deep_ensembles.py | 4 +- .../models/segmentation/segformer/std.py | 12 +- torch_uncertainty/optim_recipes.py | 16 -- 15 files changed, 128 insertions(+), 164 deletions(-) rename tests/{test_optimization_procedures.py => test_optim_recipes.py} (95%) diff --git a/auto_tutorials_source/tutorial_der_cubic.py b/auto_tutorials_source/tutorial_der_cubic.py index 8d00cde5..941ef998 100644 --- a/auto_tutorials_source/tutorial_der_cubic.py +++ b/auto_tutorials_source/tutorial_der_cubic.py @@ -110,7 +110,7 @@ def optim_regression( routine = RegressionRoutine( probabilistic=True, - num_outputs=1, + output_dim=1, model=model, loss=loss, optim_recipe=optim_regression, diff --git a/auto_tutorials_source/tutorial_mc_dropout.py b/auto_tutorials_source/tutorial_mc_dropout.py index 4506ce7d..29a3c4c7 100644 --- a/auto_tutorials_source/tutorial_mc_dropout.py +++ b/auto_tutorials_source/tutorial_mc_dropout.py @@ -36,7 +36,7 @@ from torch_uncertainty.datamodules import MNISTDataModule from torch_uncertainty.models.lenet import lenet -from torch_uncertainty.models.mc_dropout import mc_dropout +from torch_uncertainty.models import mc_dropout from torch_uncertainty.optim_recipes import optim_cifar10_resnet18 from torch_uncertainty.routines import ClassificationRoutine diff --git a/docs/source/api.rst b/docs/source/api.rst index 690e9cc9..39099970 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -1,14 +1,16 @@ -API reference +API Reference ============= .. currentmodule:: torch_uncertainty -Baselines ---------- +Routines +-------- -This API provides lightning-based models that can be easily trained and evaluated. +The routine are the main building blocks of the library. They define the framework +in which the models are trained and evaluated. They allow for easy computation of different +metrics crucial for uncertainty estimation in different contexts, namely classification, regression and segmentation. -.. currentmodule:: torch_uncertainty.baselines.classification +.. currentmodule:: torch_uncertainty.routines Classification ^^^^^^^^^^^^^^ @@ -18,11 +20,7 @@ Classification :nosignatures: :template: class.rst - ResNetBaseline - VGGBaseline - WideResNetBaseline - -.. currentmodule:: torch_uncertainty.baselines.regression + ClassificationRoutine Regression ^^^^^^^^^^ @@ -32,7 +30,7 @@ Regression :nosignatures: :template: class.rst - MLPBaseline + RegressionRoutine Segmentation ^^^^^^^^^^^^ @@ -42,119 +40,52 @@ Segmentation :nosignatures: :template: class.rst - SegformerBaseline - - -.. Models -.. ------ - -.. This section encapsulates the model backbones currently supported by the library. - -.. ResNet -.. ^^^^^^ - -.. .. currentmodule:: torch_uncertainty.models.resnet - -.. Concerning ResNet backbones, we provide building functions for ResNet18, ResNet34, -.. ResNet50, ResNet101 and, ResNet152 (from `Deep Residual Learning for Image Recognition -.. `_, CVPR 2016). - -.. Standard -.. ~~~~~~~ - -.. .. autosummary:: -.. :toctree: generated/ -.. :nosignatures: - -.. resnet18 -.. resnet34 -.. resnet50 -.. resnet101 -.. resnet152 - -.. Packed-Ensembles -.. ~~~~~~~~~~~~~~~~ - -.. .. autosummary:: -.. :toctree: generated/ -.. :nosignatures: - -.. packed_resnet18 -.. packed_resnet34 -.. packed_resnet50 -.. packed_resnet101 -.. packed_resnet152 - -.. Masksembles -.. ~~~~~~~~~~~ - -.. .. autosummary:: -.. :toctree: generated/ -.. :nosignatures: - -.. masked_resnet18 -.. masked_resnet34 -.. masked_resnet50 -.. masked_resnet101 -.. masked_resnet152 - -.. BatchEnsemble -.. ~~~~~~~~~~~~~ + SegmentationRoutine -.. .. autosummary:: -.. :toctree: generated/ -.. :nosignatures: - -.. batched_resnet18 -.. batched_resnet34 -.. batched_resnet50 -.. batched_resnet101 -.. batched_resnet152 - -.. Wide-ResNet -.. ^^^^^^^^^^^ - -.. .. currentmodule:: torch_uncertainty.models.wideresnet +Baselines +--------- -.. Concerning Wide-ResNet backbones, we provide building functions for Wide-ResNet28x10 -.. (from `Wide Residual Networks `_, British -.. Machine Vision Conference 2016). +TorchUncertainty provide lightning-based models that can be easily trained and evaluated. +These models inherit from the routines and are specifically designed to benchmark +different methods in similar settings, here with constant architectures. -.. Standard -.. ~~~~~~~ +.. currentmodule:: torch_uncertainty.baselines.classification -.. .. autosummary:: -.. :toctree: generated/ -.. :nosignatures: +Classification +^^^^^^^^^^^^^^ -.. wideresnet28x10 +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: class.rst -.. Packed-Ensembles -.. ~~~~~~~~~~~~~~~~ + ResNetBaseline + VGGBaseline + WideResNetBaseline -.. .. autosummary:: -.. :toctree: generated/ -.. :nosignatures: +.. currentmodule:: torch_uncertainty.baselines.regression -.. packed_wideresnet28x10 +Regression +^^^^^^^^^^ -.. Masksembles -.. ~~~~~~~~~~~ +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: class.rst -.. .. autosummary:: -.. :toctree: generated/ -.. :nosignatures: + MLPBaseline -.. masked_wideresnet28x10 +.. currentmodule:: torch_uncertainty.baselines.segmentation -.. BatchEnsemble -.. ~~~~~~~~~~~~~ +Segmentation +^^^^^^^^^^^^ -.. .. autosummary:: -.. :toctree: generated/ -.. :nosignatures: +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: class.rst -.. batched_wideresnet28x10 + SegFormerBaseline Layers ------ @@ -192,6 +123,30 @@ Bayesian layers BayesConv2d BayesConv3d +Models +------ + +.. currentmodule:: torch_uncertainty.models + +Deep Ensembles +^^^^^^^^^^^^^^ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: class.rst + + deep_ensembles + +Monte Carlo Dropout + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: class.rst + + mc_dropout + Metrics ------- @@ -222,10 +177,10 @@ Losses :nosignatures: :template: class.rst + DistributionNLLLoss KLDiv ELBOLoss BetaNLL - NIGLoss DECLoss Post-Processing Methods @@ -241,12 +196,24 @@ Post-Processing Methods TemperatureScaler VectorScaler MatrixScaler + MCBatchNorm Datamodules ----------- +.. currentmodule:: torch_uncertainty.datamodules.abstract + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: class.rst + + .. currentmodule:: torch_uncertainty.datamodules +Classification +^^^^^^^^^^^^^^ + .. autosummary:: :toctree: generated/ :nosignatures: @@ -257,4 +224,24 @@ Datamodules MNISTDataModule TinyImageNetDataModule ImageNetDataModule + +Regression +^^^^^^^^^^ +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: class.rst + UCIDataModule + +Segmentation +^^^^^^^^^^^^ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: class.rst + + CamVidDataModule + CityscapesDataModule + diff --git a/docs/source/contributing.rst b/docs/source/contributing.rst index 4a822d4c..466479b6 100644 --- a/docs/source/contributing.rst +++ b/docs/source/contributing.rst @@ -4,7 +4,6 @@ Contributing .. role:: bash(code) :language: bash - TorchUncertainty is in early development stage. We are looking for contributors to help us build a comprehensive library for uncertainty quantification in PyTorch. @@ -54,7 +53,6 @@ Then navigate to :bash:`./docs` and build the documentation with: make html - Optionally, specify :bash:`html-noplot` instead of :bash:`html` to avoid running the tutorials. Guidelines diff --git a/docs/source/installation.rst b/docs/source/installation.rst index d2153e31..a05ae32e 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -5,7 +5,7 @@ Installation :language: bash -You can install the package from PyPI or from source. Choose the latter if you +You can install the package either from PyPI or from source. Choose the latter if you want to access the files included the `experiments `_ folder or if you want to contribute to the project. @@ -59,11 +59,11 @@ Options You can install the package with the following options: * dev: includes all the dependencies for the development of the package -including ruff and the pre-commits hooks. + including ruff and the pre-commits hooks. * docs: includes all the dependencies for the documentation of the package -based on sphinx + based on sphinx * image: includes all the dependencies for the image processing module -including opencv and scikit-image + including opencv and scikit-image * tabular: includes pandas * all: includes all the aforementioned dependencies diff --git a/docs/source/introduction_uncertainty.rst b/docs/source/introduction_uncertainty.rst index 6e2d226a..4e92c4fe 100644 --- a/docs/source/introduction_uncertainty.rst +++ b/docs/source/introduction_uncertainty.rst @@ -22,7 +22,6 @@ it may not be a good idea to trust these predictions. Let's see why in more deta The Overconfidence of Neural Networks ------------------------------------- - References ---------- diff --git a/tests/test_optimization_procedures.py b/tests/test_optim_recipes.py similarity index 95% rename from tests/test_optimization_procedures.py rename to tests/test_optim_recipes.py index 74522afc..48fcd06f 100644 --- a/tests/test_optimization_procedures.py +++ b/tests/test_optim_recipes.py @@ -6,7 +6,6 @@ from torch_uncertainty.models.wideresnet import wideresnet28x10 from torch_uncertainty.optim_recipes import ( get_procedure, - optim_regression, ) @@ -72,10 +71,6 @@ def test_optim_imagenet_resnet50(self): model = resnet50(in_channels=3, num_classes=1000) procedure(model) - def test_optim_regression(self): - model = resnet18(in_channels=3, num_classes=1) - optim_regression(model) - def test_optim_unknown(self): with pytest.raises(NotImplementedError): _ = get_procedure("unknown", "cifar100") diff --git a/torch_uncertainty/baselines/classification/resnet.py b/torch_uncertainty/baselines/classification/resnet.py index b34ea48a..4dbfa466 100644 --- a/torch_uncertainty/baselines/classification/resnet.py +++ b/torch_uncertainty/baselines/classification/resnet.py @@ -2,7 +2,7 @@ from torch import nn -from torch_uncertainty.models.mc_dropout import mc_dropout +from torch_uncertainty.models import mc_dropout from torch_uncertainty.models.resnet import ( batched_resnet18, batched_resnet20, diff --git a/torch_uncertainty/baselines/classification/vgg.py b/torch_uncertainty/baselines/classification/vgg.py index f3cd194a..1942211c 100644 --- a/torch_uncertainty/baselines/classification/vgg.py +++ b/torch_uncertainty/baselines/classification/vgg.py @@ -2,7 +2,7 @@ from torch import nn -from torch_uncertainty.models.mc_dropout import mc_dropout +from torch_uncertainty.models import mc_dropout from torch_uncertainty.models.vgg import ( packed_vgg11, packed_vgg13, diff --git a/torch_uncertainty/baselines/classification/wideresnet.py b/torch_uncertainty/baselines/classification/wideresnet.py index d324fba2..9c3e108b 100644 --- a/torch_uncertainty/baselines/classification/wideresnet.py +++ b/torch_uncertainty/baselines/classification/wideresnet.py @@ -2,7 +2,7 @@ from torch import nn -from torch_uncertainty.models.mc_dropout import mc_dropout +from torch_uncertainty.models import mc_dropout from torch_uncertainty.models.wideresnet import ( batched_wideresnet28x10, masked_wideresnet28x10, diff --git a/torch_uncertainty/baselines/segmentation/segformer.py b/torch_uncertainty/baselines/segmentation/segformer.py index d330ea1d..07ae5043 100644 --- a/torch_uncertainty/baselines/segmentation/segformer.py +++ b/torch_uncertainty/baselines/segmentation/segformer.py @@ -3,12 +3,12 @@ from torch import nn from torch_uncertainty.models.segmentation.segformer import ( - segformer_b0, - segformer_b1, - segformer_b2, - segformer_b3, - segformer_b4, - segformer_b5, + seg_former_b0, + seg_former_b1, + seg_former_b2, + seg_former_b3, + seg_former_b4, + seg_former_b5, ) from torch_uncertainty.routines.segmentation import SegmentationRoutine @@ -17,12 +17,12 @@ class SegFormerBaseline(SegmentationRoutine): single = ["std"] versions = { "std": [ - segformer_b0, - segformer_b1, - segformer_b2, - segformer_b3, - segformer_b4, - segformer_b5, + seg_former_b0, + seg_former_b1, + seg_former_b2, + seg_former_b3, + seg_former_b4, + seg_former_b5, ] } archs = [0, 1, 2, 3, 4, 5] diff --git a/torch_uncertainty/models/__init__.py b/torch_uncertainty/models/__init__.py index afc7480f..08dfc824 100644 --- a/torch_uncertainty/models/__init__.py +++ b/torch_uncertainty/models/__init__.py @@ -1,2 +1,3 @@ # ruff: noqa: F401 from .deep_ensembles import deep_ensembles +from .mc_dropout import mc_dropout diff --git a/torch_uncertainty/models/deep_ensembles.py b/torch_uncertainty/models/deep_ensembles.py index 66435947..8a1895c2 100644 --- a/torch_uncertainty/models/deep_ensembles.py +++ b/torch_uncertainty/models/deep_ensembles.py @@ -64,7 +64,7 @@ def deep_ensembles( task: Literal["classification", "regression"] = "classification", probabilistic=None, reset_model_parameters: bool = False, -) -> nn.Module: +) -> _DeepEnsembles: """Build a Deep Ensembles out of the original models. Args: @@ -76,7 +76,7 @@ def deep_ensembles( when :attr:models is a module or a list of length 1. Returns: - nn.Module: The ensembled model. + _DeepEnsembles: The ensembled model. Raises: ValueError: If :attr:num_estimators is not specified and :attr:models diff --git a/torch_uncertainty/models/segmentation/segformer/std.py b/torch_uncertainty/models/segmentation/segformer/std.py index a723de64..49bd41d9 100644 --- a/torch_uncertainty/models/segmentation/segformer/std.py +++ b/torch_uncertainty/models/segmentation/segformer/std.py @@ -808,7 +808,7 @@ def forward(self, x): return self.head(features) -def segformer_b0(num_classes: int): +def seg_former_b0(num_classes: int): return _SegFormer( in_channels=[32, 64, 160, 256], feature_strides=[4, 8, 16, 32], @@ -819,7 +819,7 @@ def segformer_b0(num_classes: int): ) -def segformer_b1(num_classes: int): +def seg_former_b1(num_classes: int): return _SegFormer( in_channels=[64, 128, 320, 512], feature_strides=[4, 8, 16, 32], @@ -830,7 +830,7 @@ def segformer_b1(num_classes: int): ) -def segformer_b2(num_classes: int): +def seg_former_b2(num_classes: int): return _SegFormer( in_channels=[64, 128, 320, 512], feature_strides=[4, 8, 16, 32], @@ -841,7 +841,7 @@ def segformer_b2(num_classes: int): ) -def segformer_b3(num_classes: int): +def seg_former_b3(num_classes: int): return _SegFormer( in_channels=[64, 128, 320, 512], feature_strides=[4, 8, 16, 32], @@ -852,7 +852,7 @@ def segformer_b3(num_classes: int): ) -def segformer_b4(num_classes: int): +def seg_former_b4(num_classes: int): return _SegFormer( in_channels=[64, 128, 320, 512], feature_strides=[4, 8, 16, 32], @@ -863,7 +863,7 @@ def segformer_b4(num_classes: int): ) -def segformer_b5(num_classes: int): +def seg_former_b5(num_classes: int): return _SegFormer( in_channels=[64, 128, 320, 512], feature_strides=[4, 8, 16, 32], diff --git a/torch_uncertainty/optim_recipes.py b/torch_uncertainty/optim_recipes.py index 4c8be004..7ee39586 100644 --- a/torch_uncertainty/optim_recipes.py +++ b/torch_uncertainty/optim_recipes.py @@ -20,7 +20,6 @@ "optim_imagenet_resnet50_a3", "optim_tinyimagenet_resnet34", "optim_tinyimagenet_resnet50", - "optim_regression", ] @@ -317,21 +316,6 @@ def optim_tinyimagenet_resnet50( return {"optimizer": optimizer, "lr_scheduler": scheduler} -def optim_regression( - model: nn.Module, - learning_rate: float = 1e-2, -) -> dict: - optimizer = optim.SGD( - model.parameters(), - lr=learning_rate, - weight_decay=0, - ) - return { - "optimizer": optimizer, - "monitor": "reg/val_nll", - } - - def batch_ensemble_wrapper(model: nn.Module, optim_recipe: Callable) -> dict: procedure = optim_recipe(model) param_optimizer = procedure["optimizer"] From e0c9b535814d8a35b4fd308bc9e81a7c0f84af1f Mon Sep 17 00:00:00 2001 From: alafage Date: Thu, 21 Mar 2024 14:11:52 +0100 Subject: [PATCH 083/148] :hammer: SegFormer baseline renaming into SegFormerBaseline --- experiments/segmentation/camvid/segformer.py | 4 ++-- experiments/segmentation/cityscapes/segformer.py | 4 ++-- experiments/segmentation/muad/segformer.py | 4 ++-- torch_uncertainty/baselines/segmentation/__init__.py | 2 +- torch_uncertainty/baselines/segmentation/segformer.py | 2 +- 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/experiments/segmentation/camvid/segformer.py b/experiments/segmentation/camvid/segformer.py index a3e376a7..a42756c3 100644 --- a/experiments/segmentation/camvid/segformer.py +++ b/experiments/segmentation/camvid/segformer.py @@ -1,7 +1,7 @@ import torch from lightning.pytorch.cli import LightningArgumentParser -from torch_uncertainty.baselines.segmentation import SegFormer +from torch_uncertainty.baselines.segmentation import SegFormerBaseline from torch_uncertainty.datamodules import CamVidDataModule from torch_uncertainty.utils import TULightningCLI @@ -13,7 +13,7 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: def cli_main() -> SegFormerCLI: - return SegFormerCLI(SegFormer, CamVidDataModule) + return SegFormerCLI(SegFormerBaseline, CamVidDataModule) if __name__ == "__main__": diff --git a/experiments/segmentation/cityscapes/segformer.py b/experiments/segmentation/cityscapes/segformer.py index 7472e33c..7dab5755 100644 --- a/experiments/segmentation/cityscapes/segformer.py +++ b/experiments/segmentation/cityscapes/segformer.py @@ -1,7 +1,7 @@ import torch from lightning.pytorch.cli import LightningArgumentParser -from torch_uncertainty.baselines.segmentation import SegFormer +from torch_uncertainty.baselines.segmentation import SegFormerBaseline from torch_uncertainty.datamodules import CityscapesDataModule from torch_uncertainty.utils import TULightningCLI @@ -13,7 +13,7 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: def cli_main() -> SegFormerCLI: - return SegFormerCLI(SegFormer, CityscapesDataModule) + return SegFormerCLI(SegFormerBaseline, CityscapesDataModule) if __name__ == "__main__": diff --git a/experiments/segmentation/muad/segformer.py b/experiments/segmentation/muad/segformer.py index 1ee20710..67ad9564 100644 --- a/experiments/segmentation/muad/segformer.py +++ b/experiments/segmentation/muad/segformer.py @@ -1,7 +1,7 @@ import torch from lightning.pytorch.cli import LightningArgumentParser -from torch_uncertainty.baselines.segmentation import SegFormer +from torch_uncertainty.baselines.segmentation import SegFormerBaseline from torch_uncertainty.datamodules.segmentation import MUADDataModule from torch_uncertainty.utils import TULightningCLI @@ -13,7 +13,7 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: def cli_main() -> SegFormerCLI: - return SegFormerCLI(SegFormer, MUADDataModule) + return SegFormerCLI(SegFormerBaseline, MUADDataModule) if __name__ == "__main__": diff --git a/torch_uncertainty/baselines/segmentation/__init__.py b/torch_uncertainty/baselines/segmentation/__init__.py index d9e05601..fe2488e4 100644 --- a/torch_uncertainty/baselines/segmentation/__init__.py +++ b/torch_uncertainty/baselines/segmentation/__init__.py @@ -1,2 +1,2 @@ # ruff: noqa: F401 -from .segformer import SegFormer +from .segformer import SegFormerBaseline diff --git a/torch_uncertainty/baselines/segmentation/segformer.py b/torch_uncertainty/baselines/segmentation/segformer.py index 517fbc54..d330ea1d 100644 --- a/torch_uncertainty/baselines/segmentation/segformer.py +++ b/torch_uncertainty/baselines/segmentation/segformer.py @@ -13,7 +13,7 @@ from torch_uncertainty.routines.segmentation import SegmentationRoutine -class SegFormer(SegmentationRoutine): +class SegFormerBaseline(SegmentationRoutine): single = ["std"] versions = { "std": [ From 69b05898075c297c0471f1da53ddf33bfb6585fb Mon Sep 17 00:00:00 2001 From: alafage Date: Thu, 21 Mar 2024 14:31:25 +0100 Subject: [PATCH 084/148] :white_mark_check: Improve coverage --- tests/_dummies/__init__.py | 18 +- tests/_dummies/baseline.py | 30 +- tests/_dummies/datamodule.py | 3 + tests/_dummies/model.py | 30 +- tests/baselines/test_standard.py | 6 +- tests/models/test_deep_ensembles.py | 12 +- tests/models/test_mc_dropout.py | 10 +- tests/routines/test_classification.py | 211 +--------- tests/routines/test_segmentation.py | 62 +++ tests/test_cli.py | 397 ++----------------- tests/test_utils.py | 13 +- torch_uncertainty/datasets/muad.py | 16 +- torch_uncertainty/models/deep_ensembles.py | 6 +- torch_uncertainty/routines/classification.py | 6 + torch_uncertainty/routines/segmentation.py | 13 +- torch_uncertainty/utils/misc.py | 3 +- 16 files changed, 214 insertions(+), 622 deletions(-) create mode 100644 tests/routines/test_segmentation.py diff --git a/tests/_dummies/__init__.py b/tests/_dummies/__init__.py index db6a25de..ac5d4d0d 100644 --- a/tests/_dummies/__init__.py +++ b/tests/_dummies/__init__.py @@ -1,6 +1,18 @@ # ruff: noqa: F401 -from .baseline import DummyClassificationBaseline, DummyRegressionBaseline -from .datamodule import DummyClassificationDataModule, DummyRegressionDataModule -from .dataset import DummyClassificationDataset, DummyRegressionDataset +from .baseline import ( + DummyClassificationBaseline, + DummyRegressionBaseline, + DummySegmentationBaseline, +) +from .datamodule import ( + DummyClassificationDataModule, + DummyRegressionDataModule, + DummySegmentationDataModule, +) +from .dataset import ( + DummyClassificationDataset, + DummyRegressionDataset, + DummySegmentationDataset, +) from .model import dummy_model from .transform import DummyTransform diff --git a/tests/_dummies/baseline.py b/tests/_dummies/baseline.py index f72d39b3..07f48f3b 100644 --- a/tests/_dummies/baseline.py +++ b/tests/_dummies/baseline.py @@ -25,7 +25,7 @@ def __new__( num_classes: int, in_channels: int, loss: type[nn.Module], - ensemble=False, + baseline_type: str = "single", optim_recipe=None, with_feats: bool = True, with_linear: bool = True, @@ -36,12 +36,11 @@ def __new__( model = dummy_model( in_channels=in_channels, num_classes=num_classes, - num_estimators=1 + int(ensemble), with_feats=with_feats, with_linear=with_linear, ) - if not ensemble: + if baseline_type == "single": return ClassificationRoutine( num_classes=num_classes, model=model, @@ -54,7 +53,11 @@ def __new__( eval_ood=eval_ood, eval_grouping_loss=eval_grouping_loss, ) - # ensemble + # baseline_type == "ensemble": + model = deep_ensembles( + [model, copy.deepcopy(model)], + task="classification", + ) return ClassificationRoutine( num_classes=num_classes, model=model, @@ -97,7 +100,6 @@ def __new__( model = dummy_model( in_channels=in_features, num_classes=num_classes, - num_estimators=1, last_layer=last_layer, ) if baseline_type == "single": @@ -131,31 +133,37 @@ def __new__( cls, in_channels: int, num_classes: int, + image_size: int, loss: type[nn.Module], - ensemble: bool = False, + baseline_type: bool = False, + optim_recipe=None, ) -> LightningModule: model = dummy_segmentation_model( in_channels=in_channels, num_classes=num_classes, - num_estimators=1 + int(ensemble), + image_size=image_size, ) - if not ensemble: + if baseline_type == "single": return SegmentationRoutine( num_classes=num_classes, model=model, loss=loss, format_batch_fn=nn.Identity(), - optim_recipe=None, num_estimators=1, + optim_recipe=optim_recipe, ) - # ensemble + # baseline_type == "ensemble": + model = deep_ensembles( + [model, copy.deepcopy(model)], + task="segmentation", + ) return SegmentationRoutine( num_classes=num_classes, model=model, loss=loss, format_batch_fn=RepeatTarget(2), - optim_recipe=None, num_estimators=2, + optim_recipe=optim_recipe, ) diff --git a/tests/_dummies/datamodule.py b/tests/_dummies/datamodule.py index fb05740c..9cf0ab77 100644 --- a/tests/_dummies/datamodule.py +++ b/tests/_dummies/datamodule.py @@ -169,6 +169,7 @@ def __init__( batch_size: int, num_classes: int = 2, num_workers: int = 1, + image_size: int = 4, pin_memory: bool = True, persistent_workers: bool = True, num_images: int = 2, @@ -183,7 +184,9 @@ def __init__( ) self.num_classes = num_classes + self.num_channels = 3 self.num_images = num_images + self.image_size = image_size self.dataset = DummySegmentationDataset diff --git a/tests/_dummies/model.py b/tests/_dummies/model.py index 9e1552e6..17682920 100644 --- a/tests/_dummies/model.py +++ b/tests/_dummies/model.py @@ -11,7 +11,6 @@ def __init__( self, in_channels: int, num_classes: int, - num_estimators: int, dropout_rate: float, with_linear: bool, last_layer: nn.Module, @@ -32,14 +31,12 @@ def __init__( self.last_layer = last_layer self.dropout = nn.Dropout(p=dropout_rate) - self.num_estimators = num_estimators - def forward(self, x: Tensor) -> Tensor: return self.last_layer( self.dropout( self.linear( torch.ones( - (x.shape[0] * self.num_estimators, 1), + (x.shape[0], 1), dtype=torch.float32, ) ) @@ -58,23 +55,28 @@ def __init__( in_channels: int, num_classes: int, dropout_rate: float, - num_estimators: int, + image_size: int, ) -> None: super().__init__() self.dropout_rate = dropout_rate - + self.in_channels = in_channels + self.num_classes = num_classes + self.image_size = image_size self.conv = nn.Conv2d( in_channels, num_classes, kernel_size=3, padding=1 ) self.dropout = nn.Dropout(p=dropout_rate) - self.num_estimators = num_estimators - def forward(self, x: Tensor) -> Tensor: return self.dropout( self.conv( torch.ones( - (x.shape[0] * self.num_estimators, 1, 32, 32), + ( + x.shape[0], + self.in_channels, + self.image_size, + self.image_size, + ), dtype=torch.float32, ) ) @@ -84,7 +86,6 @@ def forward(self, x: Tensor) -> Tensor: def dummy_model( in_channels: int, num_classes: int, - num_estimators: int, dropout_rate: float = 0.0, with_feats: bool = True, with_linear: bool = True, @@ -111,7 +112,6 @@ def dummy_model( return _DummyWithFeats( in_channels=in_channels, num_classes=num_classes, - num_estimators=num_estimators, dropout_rate=dropout_rate, with_linear=with_linear, last_layer=last_layer, @@ -119,7 +119,6 @@ def dummy_model( return _Dummy( in_channels=in_channels, num_classes=num_classes, - num_estimators=num_estimators, dropout_rate=dropout_rate, with_linear=with_linear, last_layer=last_layer, @@ -129,17 +128,16 @@ def dummy_model( def dummy_segmentation_model( in_channels: int, num_classes: int, + image_size: int, dropout_rate: float = 0.0, - num_estimators: int = 1, ) -> nn.Module: """Dummy segmentation model for testing purposes. Args: in_channels (int): Number of input channels. num_classes (int): Number of output classes. + image_size (int): Size of the input image. dropout_rate (float, optional): Dropout rate. Defaults to 0.0. - num_estimators (int, optional): Number of estimators in the ensemble. - Defaults to 1. Returns: nn.Module: Dummy segmentation model. @@ -148,5 +146,5 @@ def dummy_segmentation_model( in_channels=in_channels, num_classes=num_classes, dropout_rate=dropout_rate, - num_estimators=num_estimators, + image_size=image_size, ) diff --git a/tests/baselines/test_standard.py b/tests/baselines/test_standard.py index addad5a4..7912954f 100644 --- a/tests/baselines/test_standard.py +++ b/tests/baselines/test_standard.py @@ -9,7 +9,7 @@ WideResNetBaseline, ) from torch_uncertainty.baselines.regression import MLPBaseline -from torch_uncertainty.baselines.segmentation import SegFormer +from torch_uncertainty.baselines.segmentation import SegFormerBaseline class TestStandardBaseline: @@ -132,7 +132,7 @@ class TestStandardSegFormerBaseline: """Testing the SegFormer baseline class.""" def test_standard(self): - net = SegFormer( + net = SegFormerBaseline( num_classes=10, loss=nn.CrossEntropyLoss, version="std", @@ -145,7 +145,7 @@ def test_standard(self): def test_errors(self): with pytest.raises(ValueError): - SegFormer( + SegFormerBaseline( num_classes=10, loss=nn.CrossEntropyLoss, version="test", diff --git a/tests/models/test_deep_ensembles.py b/tests/models/test_deep_ensembles.py index 2c3df7c2..98e45c08 100644 --- a/tests/models/test_deep_ensembles.py +++ b/tests/models/test_deep_ensembles.py @@ -9,21 +9,21 @@ class TestDeepEnsemblesModel: """Testing the deep_ensembles function.""" def test_main(self): - model_1 = dummy_model(1, 10, 1) - model_2 = dummy_model(1, 10, 1) + model_1 = dummy_model(1, 10) + model_2 = dummy_model(1, 10) de = deep_ensembles([model_1, model_2]) # Check B N C assert de(torch.randn(3, 4, 4)).shape == (6, 10) def test_list_and_num_estimators(self): - model_1 = dummy_model(1, 10, 1) - model_2 = dummy_model(1, 10, 1) + model_1 = dummy_model(1, 10) + model_2 = dummy_model(1, 10) with pytest.raises(ValueError): deep_ensembles([model_1, model_2], num_estimators=2) def test_list_singleton(self): - model_1 = dummy_model(1, 10, 1) + model_1 = dummy_model(1, 10) deep_ensembles([model_1], num_estimators=2, reset_model_parameters=True) deep_ensembles(model_1, num_estimators=2, reset_model_parameters=False) @@ -32,7 +32,7 @@ def test_list_singleton(self): deep_ensembles([model_1], num_estimators=1) def test_errors(self): - model_1 = dummy_model(1, 10, 1) + model_1 = dummy_model(1, 10) with pytest.raises(ValueError): deep_ensembles(model_1, num_estimators=None) diff --git a/tests/models/test_mc_dropout.py b/tests/models/test_mc_dropout.py index 59f084af..b0cd9327 100644 --- a/tests/models/test_mc_dropout.py +++ b/tests/models/test_mc_dropout.py @@ -9,7 +9,7 @@ class TestMCDropout: """Testing the MC Dropout class.""" def test_mc_dropout_train(self): - model = dummy_model(10, 5, 1, 0.1) + model = dummy_model(10, 5, 0.1) dropout_model = mc_dropout(model, num_estimators=5) dropout_model.train() assert dropout_model.training @@ -21,14 +21,14 @@ def test_mc_dropout_train(self): dropout_model(torch.rand(1, 10)) def test_mc_dropout_eval(self): - model = dummy_model(10, 5, 1, 0.1) + model = dummy_model(10, 5, 0.1) dropout_model = mc_dropout(model, num_estimators=5) dropout_model.eval() assert not dropout_model.training dropout_model(torch.rand(1, 10)) def test_mc_dropout_errors(self): - model = dummy_model(10, 5, 1, 0.1) + model = dummy_model(10, 5, 0.1) with pytest.raises(ValueError): _MCDropout(model=model, num_estimators=-1, last_layer=True) @@ -47,10 +47,10 @@ def test_mc_dropout_errors(self): with pytest.raises(ValueError): dropout_model = mc_dropout(model, 5) - model = dummy_model(10, 5, 1, 0.1) + model = dummy_model(10, 5, 0.1) with pytest.raises(ValueError): dropout_model = mc_dropout(model, None) - model = dummy_model(10, 5, 1, dropout_rate=0) + model = dummy_model(10, 5, dropout_rate=0) with pytest.raises(ValueError): dropout_model = mc_dropout(model, None) diff --git a/tests/routines/test_classification.py b/tests/routines/test_classification.py index 30df930c..2e5e1c78 100644 --- a/tests/routines/test_classification.py +++ b/tests/routines/test_classification.py @@ -13,8 +13,8 @@ from torch_uncertainty.routines import ClassificationRoutine -class TestClassificationSingle: - """Testing the classification routine with a single model.""" +class TestClassification: + """Testing the classification routine.""" def test_one_estimator_binary(self): trainer = Trainer(accelerator="cpu", fast_dev_run=True) @@ -30,7 +30,7 @@ def test_one_estimator_binary(self): num_classes=dm.num_classes, loss=nn.BCEWithLogitsLoss, optim_recipe=optim_cifar10_resnet18, - ensemble=False, + baseline_type="single", ood_criterion="msp", ) @@ -53,7 +53,7 @@ def test_two_estimators_binary(self): num_classes=dm.num_classes, loss=nn.BCEWithLogitsLoss, optim_recipe=optim_cifar10_resnet18, - ensemble=True, + baseline_type="single", ood_criterion="logit", ) @@ -77,7 +77,7 @@ def test_one_estimator_two_classes(self): in_channels=dm.num_channels, loss=nn.CrossEntropyLoss, optim_recipe=optim_cifar10_resnet18, - ensemble=False, + baseline_type="single", ood_criterion="entropy", eval_ood=True, ) @@ -101,7 +101,7 @@ def test_two_estimators_two_classes(self): in_channels=dm.num_channels, loss=nn.CrossEntropyLoss, optim_recipe=optim_cifar10_resnet18, - ensemble=True, + baseline_type="ensemble", ood_criterion="energy", ) @@ -114,12 +114,14 @@ def test_classification_failures(self): # num_estimators with pytest.raises(ValueError): ClassificationRoutine(10, nn.Module(), None, num_estimators=-1) + # num_classes + with pytest.raises(ValueError): + ClassificationRoutine(0, nn.Module(), None) # single & MI with pytest.raises(ValueError): ClassificationRoutine( 10, nn.Module(), None, num_estimators=1, ood_criterion="mi" ) - with pytest.raises(ValueError): ClassificationRoutine(10, nn.Module(), None, ood_criterion="other") @@ -136,201 +138,10 @@ def test_classification_failures(self): 10, nn.Module(), None, 2, eval_grouping_loss=True ) - model = dummy_model(1, 1, 1, 0, with_feats=False, with_linear=True) + model = dummy_model(1, 1, 0, with_feats=False, with_linear=True) with pytest.raises(ValueError): ClassificationRoutine(10, model, None, eval_grouping_loss=True) - model = dummy_model(1, 1, 1, 0, with_feats=True, with_linear=False) + model = dummy_model(1, 1, 0, with_feats=True, with_linear=False) with pytest.raises(ValueError): ClassificationRoutine(10, model, None, eval_grouping_loss=True) - - -# from functools import partial -# from pathlib import Path - -# import pytest -# from cli_test_helpers import ArgvContext -# from torch import nn - -# from tests._dummies import ( -# DummyClassificationBaseline, -# DummyClassificationDataModule, -# DummyClassificationDataset, -# dummy_model, -# ) -# from torch_uncertainty import cli_main, init_args -# from torch_uncertainty.losses import DECLoss, ELBOLoss - -# with ArgvContext( -# "file.py", -# "--eval-ood", -# "--entropy", -# "--cutmix_alpha", -# "0.5", -# "--mixtype", -# "timm", -# ): -# args = init_args( -# DummyClassificationBaseline, DummyClassificationDataModule -# ) - -# args.root = str(root / "data") -# dm = DummyClassificationDataModule(**vars(args)) - -# model = DummyClassificationBaseline( -# num_classes=dm.num_classes, -# in_channels=dm.num_channels, -# loss=DECLoss, -# optim_recipe=optim_cifar10_resnet18, -# ensemble=False, -# **vars(args), -# ) -# with pytest.raises(NotImplementedError): -# cli_main(model, dm, root, "logs/dummy", args) - -# def test_cli_main_dummy_mixup_ts_cv(self): -# root = Path(__file__).parent.absolute().parents[0] -# with ArgvContext( -# "file.py", -# "--mixtype", -# "kernel_warping", -# "--mixup_alpha", -# "1.", -# "--dist_sim", -# "inp", -# "--val_temp_scaling", -# "--use_cv", -# ): -# args = init_args(DummyClassificationBaseline, DummyClassificationDataModule) - -# args.root = str(root / "data") -# dm = DummyClassificationDataModule(num_classes=10, **vars(args)) -# dm.dataset = ( -# lambda root, -# num_channels, -# num_classes, -# image_size, -# transform, -# num_images: DummyClassificationDataset( -# root, -# num_channels=num_channels, -# num_classes=num_classes, -# image_size=image_size, -# transform=transform, -# num_images=20, -# ) -# ) - -# list_dm = dm.make_cross_val_splits(2, 1) -# list_model = [ -# DummyClassificationBaseline( -# num_classes=list_dm[i].dm.num_classes, -# in_channels=list_dm[i].dm.num_channels, -# loss=nn.CrossEntropyLoss, -# optim_recipe=optim_cifar10_resnet18, -# ensemble=False, -# calibration_set=dm.get_val_set, -# **vars(args), -# ) -# for i in range(len(list_dm)) -# ] - -# cli_main(list_model, list_dm, root, "logs/dummy", args) - -# with ArgvContext( -# "file.py", -# "--mixtype", -# "kernel_warping", -# "--mixup_alpha", -# "1.", -# "--dist_sim", -# "emb", -# "--val_temp_scaling", -# "--use_cv", -# ): -# args = init_args(DummyClassificationBaseline, DummyClassificationDataModule) - -# args.root = str(root / "data") -# dm = DummyClassificationDataModule(num_classes=10, **vars(args)) -# dm.dataset = ( -# lambda root, -# num_channels, -# num_classes, -# image_size, -# transform, -# num_images: DummyClassificationDataset( -# root, -# num_channels=num_channels, -# num_classes=num_classes, -# image_size=image_size, -# transform=transform, -# num_images=20, -# ) -# ) - -# list_dm = dm.make_cross_val_splits(2, 1) -# list_model = [] -# for i in range(len(list_dm)): -# list_model.append( -# DummyClassificationBaseline( -# num_classes=list_dm[i].dm.num_classes, -# in_channels=list_dm[i].dm.num_channels, -# loss=nn.CrossEntropyLoss, -# optim_recipe=optim_cifar10_resnet18, -# ensemble=False, -# calibration_set=dm.get_val_set, -# **vars(args), -# ) -# ) - -# cli_main(list_model, list_dm, root, "logs/dummy", args) -# with ArgvContext("file.py", "--mutual_information"): -# args = init_args(DummyClassificationBaseline, DummyClassificationDataModule) - -# # datamodule -# args.root = str(root / "data") -# dm = DummyClassificationDataModule(num_classes=1, **vars(args)) - -# model = DummyClassificationBaseline( -# num_classes=dm.num_classes, -# in_channels=dm.num_channels, -# loss=nn.BCEWithLogitsLoss, -# optim_recipe=optim_cifar10_resnet18, -# ensemble=True, -# **vars(args), -# ) - -# cli_main(model, dm, root, "logs/dummy", args) - - -# with ArgvContext("file.py", "--eval-ood", "--variation_ratio"): -# args = init_args( -# DummyClassificationBaseline, DummyClassificationDataModule -# ) - -# # datamodule -# args.root = str(root / "data") -# dm = DummyClassificationDataModule(**vars(args)) - -# model = DummyClassificationBaseline( -# num_classes=dm.num_classes, -# in_channels=dm.num_channels, -# loss=nn.CrossEntropyLoss, -# optim_recipe=optim_cifar10_resnet18, -# ensemble=True, -# **vars(args), -# ) - -# cli_main(model, dm, root, "logs/dummy", args) - -# def test_classification_failures(self): -# with pytest.raises(ValueError): -# ClassificationRoutine( -# 10, -# nn.Module(), -# None, -# None, -# 2, -# use_entropy=True, -# use_variation_ratio=True, -# ) diff --git a/tests/routines/test_segmentation.py b/tests/routines/test_segmentation.py new file mode 100644 index 00000000..965693b3 --- /dev/null +++ b/tests/routines/test_segmentation.py @@ -0,0 +1,62 @@ +from pathlib import Path + +import pytest +from lightning.pytorch import Trainer +from torch import nn + +from tests._dummies import ( + DummySegmentationBaseline, + DummySegmentationDataModule, +) +from torch_uncertainty.optim_recipes import optim_cifar10_resnet18 +from torch_uncertainty.routines import SegmentationRoutine + + +class TestSegmentation: + def test_one_estimator_two_classes(self): + trainer = Trainer(accelerator="cpu", fast_dev_run=True) + + root = Path(__file__).parent.absolute().parents[0] / "data" + dm = DummySegmentationDataModule(root=root, batch_size=4, num_classes=2) + + model = DummySegmentationBaseline( + in_channels=dm.num_channels, + num_classes=dm.num_classes, + image_size=dm.image_size, + loss=nn.CrossEntropyLoss, + baseline_type="single", + optim_recipe=optim_cifar10_resnet18, + ) + + trainer.fit(model, dm) + trainer.validate(model, dm) + trainer.test(model, dm) + model(dm.get_test_set()[0][0]) + + def test_two_estimators_two_classes(self): + trainer = Trainer(accelerator="cpu", fast_dev_run=True) + + root = Path(__file__).parent.absolute().parents[0] / "data" + dm = DummySegmentationDataModule(root=root, batch_size=4, num_classes=2) + + model = DummySegmentationBaseline( + in_channels=dm.num_channels, + num_classes=dm.num_classes, + image_size=dm.image_size, + loss=nn.CrossEntropyLoss, + baseline_type="ensemble", + optim_recipe=optim_cifar10_resnet18, + ) + + trainer.fit(model, dm) + trainer.validate(model, dm) + trainer.test(model, dm) + model(dm.get_test_set()[0][0]) + + def test_segmentation_failures(self): + with pytest.raises(ValueError): + SegmentationRoutine( + 2, nn.Identity(), nn.CrossEntropyLoss(), num_estimators=0 + ) + with pytest.raises(ValueError): + SegmentationRoutine(1, nn.Identity(), nn.CrossEntropyLoss()) diff --git a/tests/test_cli.py b/tests/test_cli.py index c67391cc..6612f5ec 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,365 +1,32 @@ -# import sys -# from pathlib import Path - -# import pytest -# from cli_test_helpers import ArgvContext -# from torch import nn - -# from torch_uncertainty import cli_main, init_args -# from torch_uncertainty.baselines.classification import VGG, ResNet, WideResNet -# from torch_uncertainty.baselines.regression import MLP -# from torch_uncertainty.datamodules import CIFAR10DataModule, UCIDataModule -# from torch_uncertainty.optim_recipes import ( -# optim_cifar10_resnet18, -# optim_cifar10_vgg16, -# optim_cifar10_wideresnet, -# optim_regression, -# ) -# from torch_uncertainty.utils.misc import csv_writer - -# from ._dummies.dataset import DummyClassificationDataset - - -# class TestCLI: -# """Testing the CLI function.""" - -# def test_cli_main_resnet(self): -# root = Path(__file__).parent.absolute().parents[0] -# with ArgvContext("file.py"): -# args = init_args(ResNet, CIFAR10DataModule) - -# # datamodule -# args.root = str(root / "data") -# dm = CIFAR10DataModule(**vars(args)) - -# # Simulate that summary is True & the only argument -# args.summary = True - -# model = ResNet( -# num_classes=dm.num_classes, -# in_channels=dm.num_channels, -# style="cifar", -# loss=nn.CrossEntropyLoss, -# optim_recipe=optim_cifar10_resnet18, -# **vars(args), -# ) - -# results = cli_main(model, dm, root, "std", args) -# results_path = root / "tests" / "logs" -# if not results_path.exists(): -# results_path.mkdir(parents=True) -# for dict_result in results: -# csv_writer( -# results_path / "results.csv", -# dict_result, -# ) -# # Test if file already exists -# for dict_result in results: -# csv_writer( -# results_path / "results.csv", -# dict_result, -# ) - -# def test_cli_main_other_arguments(self): -# root = Path(__file__).parent.absolute().parents[0] -# with ArgvContext( -# "file.py", "--seed", "42", "--max_epochs", "1", "--channels_last" -# ): -# print(sys.orig_argv, sys.argv) -# args = init_args(ResNet, CIFAR10DataModule) - -# # datamodule -# args.root = root / "data" -# dm = CIFAR10DataModule(**vars(args)) - -# # Simulate that summary is True & the only argument -# args.summary = True - -# model = ResNet( -# num_classes=dm.num_classes, -# in_channels=dm.num_channels, -# style="cifar", -# loss=nn.CrossEntropyLoss, -# optim_recipe=optim_cifar10_resnet18, -# **vars(args), -# ) - -# cli_main(model, dm, root, "std", args) - -# def test_cli_main_wideresnet(self): -# root = Path(__file__).parent.absolute().parents[0] -# with ArgvContext("file.py"): -# args = init_args(WideResNet, CIFAR10DataModule) - -# # datamodule -# args.root = root / "data" -# dm = CIFAR10DataModule(**vars(args)) - -# args.summary = True - -# model = WideResNet( -# num_classes=dm.num_classes, -# in_channels=dm.num_channels, -# loss=nn.CrossEntropyLoss, -# optim_recipe=optim_cifar10_wideresnet, -# **vars(args), -# ) - -# cli_main(model, dm, root, "std", args) - -# def test_cli_main_vgg(self): -# root = Path(__file__).parent.absolute().parents[0] -# with ArgvContext("file.py"): -# args = init_args(VGG, CIFAR10DataModule) - -# # datamodule -# args.root = root / "data" -# dm = CIFAR10DataModule(**vars(args)) - -# args.summary = True - -# model = VGG( -# num_classes=dm.num_classes, -# in_channels=dm.num_channels, -# loss=nn.CrossEntropyLoss, -# optim_recipe=optim_cifar10_vgg16, -# **vars(args), -# ) - -# cli_main(model, dm, root, "std", args) - -# def test_cli_main_mlp(self): -# root = str(Path(__file__).parent.absolute().parents[0]) -# with ArgvContext("file.py"): -# args = init_args(MLP, UCIDataModule) - -# # datamodule -# args.root = root + "/data" -# dm = UCIDataModule(dataset_name="kin8nm", input_shape=(1, 5), **vars(args)) - -# args.summary = True - -# model = MLP( -# num_outputs=1, -# in_features=5, -# hidden_dims=[], -# dist_estimation=1, -# loss=nn.MSELoss, -# optim_recipe=optim_regression, -# **vars(args), -# ) - -# cli_main(model, dm, root, "std", args) - -# args.test = True -# cli_main(model, dm, root, "std", args) - -# def test_cli_other_training_task(self): -# root = Path(__file__).parent.absolute().parents[0] -# with ArgvContext("file.py"): -# args = init_args(MLP, UCIDataModule) - -# # datamodule -# args.root = root / "data" -# dm = UCIDataModule(dataset_name="kin8nm", input_shape=(1, 5), **vars(args)) - -# dm.training_task = "time-series-regression" - -# args.summary = True - -# model = MLP( -# num_outputs=1, -# in_features=5, -# hidden_dims=[], -# dist_estimation=1, -# loss=nn.MSELoss, -# optim_recipe=optim_regression, -# **vars(args), -# ) -# with pytest.raises(ValueError): -# cli_main(model, dm, root, "std", args) - -# def test_cli_cv_ts(self): -# root = Path(__file__).parent.absolute().parents[0] -# with ArgvContext("file.py", "--use_cv", "--channels_last"): -# args = init_args(ResNet, CIFAR10DataModule) - -# # datamodule -# args.root = str(root / "data") -# dm = CIFAR10DataModule(**vars(args)) - -# # Simulate that summary is True & the only argument -# args.summary = True - -# dm.dataset = ( -# lambda root, train, download, transform: DummyClassificationDataset( -# root, -# train=train, -# download=download, -# transform=transform, -# num_images=20, -# ) -# ) - -# list_dm = dm.make_cross_val_splits(2, 1) -# list_model = [ -# ResNet( -# num_classes=list_dm[i].dm.num_classes, -# in_channels=list_dm[i].dm.num_channels, -# style="cifar", -# loss=nn.CrossEntropyLoss, -# optim_recipe=optim_cifar10_resnet18, -# **vars(args), -# ) -# for i in range(len(list_dm)) -# ] - -# cli_main(list_model, list_dm, root, "std", args) - -# with ArgvContext("file.py", "--use_cv", "--mixtype", "mixup"): -# args = init_args(ResNet, CIFAR10DataModule) - -# # datamodule -# args.root = str(root / "data") -# dm = CIFAR10DataModule(**vars(args)) - -# # Simulate that summary is True & the only argument -# args.summary = True - -# dm.dataset = ( -# lambda root, train, download, transform: DummyClassificationDataset( -# root, -# train=train, -# download=download, -# transform=transform, -# num_images=20, -# ) -# ) - -# list_dm = dm.make_cross_val_splits(2, 1) -# list_model = [] -# for i in range(len(list_dm)): -# list_model.append( -# ResNet( -# num_classes=list_dm[i].dm.num_classes, -# in_channels=list_dm[i].dm.num_channels, -# style="cifar", -# loss=nn.CrossEntropyLoss, -# optim_recipe=optim_cifar10_resnet18, -# **vars(args), -# ) -# ) - -# cli_main(list_model, list_dm, root, "std", args) - -# with ArgvContext("file.py", "--use_cv", "--mixtype", "mixup_io"): -# args = init_args(ResNet, CIFAR10DataModule) - -# # datamodule -# args.root = str(root / "data") -# dm = CIFAR10DataModule(**vars(args)) - -# # Simulate that summary is True & the only argument -# args.summary = True - -# dm.dataset = ( -# lambda root, train, download, transform: DummyClassificationDataset( -# root, -# train=train, -# download=download, -# transform=transform, -# num_images=20, -# ) -# ) - -# list_dm = dm.make_cross_val_splits(2, 1) -# list_model = [] -# for i in range(len(list_dm)): -# list_model.append( -# ResNet( -# num_classes=list_dm[i].dm.num_classes, -# in_channels=list_dm[i].dm.num_channels, -# style="cifar", -# loss=nn.CrossEntropyLoss, -# optim_recipe=optim_cifar10_resnet18, -# **vars(args), -# ) -# ) - -# cli_main(list_model, list_dm, root, "std", args) - -# with ArgvContext("file.py", "--use_cv", "--mixtype", "regmixup"): -# args = init_args(ResNet, CIFAR10DataModule) - -# # datamodule -# args.root = str(root / "data") -# dm = CIFAR10DataModule(**vars(args)) - -# # Simulate that summary is True & the only argument -# args.summary = True - -# dm.dataset = ( -# lambda root, train, download, transform: DummyClassificationDataset( -# root, -# train=train, -# download=download, -# transform=transform, -# num_images=20, -# ) -# ) - -# list_dm = dm.make_cross_val_splits(2, 1) -# list_model = [] -# for i in range(len(list_dm)): -# list_model.append( -# ResNet( -# num_classes=list_dm[i].dm.num_classes, -# in_channels=list_dm[i].dm.num_channels, -# style="cifar", -# loss=nn.CrossEntropyLoss, -# optim_recipe=optim_cifar10_resnet18, -# **vars(args), -# ) -# ) - -# cli_main(list_model, list_dm, root, "std", args) - -# with ArgvContext("file.py", "--use_cv", "--mixtype", "kernel_warping"): -# args = init_args(ResNet, CIFAR10DataModule) - -# # datamodule -# args.root = str(root / "data") -# dm = CIFAR10DataModule(**vars(args)) - -# # Simulate that summary is True & the only argument -# args.summary = True - -# dm.dataset = ( -# lambda root, train, download, transform: DummyClassificationDataset( -# root, -# train=train, -# download=download, -# transform=transform, -# num_images=20, -# ) -# ) - -# list_dm = dm.make_cross_val_splits(2, 1) -# list_model = [] -# for i in range(len(list_dm)): -# list_model.append( -# ResNet( -# num_classes=list_dm[i].dm.num_classes, -# in_channels=list_dm[i].dm.num_channels, -# style="cifar", -# loss=nn.CrossEntropyLoss, -# optim_recipe=optim_cifar10_resnet18, -# **vars(args), -# ) -# ) - -# cli_main(list_model, list_dm, root, "std", args) - -# def test_init_args_void(self): -# with ArgvContext("file.py"): -# init_args() +from cli_test_helpers import ArgvContext + +from torch_uncertainty.baselines.classification import ResNetBaseline +from torch_uncertainty.datamodules import CIFAR10DataModule +from torch_uncertainty.utils.cli import TULightningCLI, TUSaveConfigCallback + + +class TestCLI: + """Testing torch-uncertainty CLI.""" + + def test_cli_init(self): + """Test CLI initialization.""" + with ArgvContext( + "file.py", + "--model.in_channels", + "3", + "--model.num_classes", + "10", + "--model.version", + "std", + "--model.arch", + "18", + "--model.loss", + "torch.nn.CrossEntropyLoss", + "--data.root", + "./data", + "--data.batch_size", + "32", + ): + cli = TULightningCLI(ResNetBaseline, CIFAR10DataModule, run=False) + assert cli.eval_after_fit_default is False + assert cli.save_config_callback == TUSaveConfigCallback diff --git a/tests/test_utils.py b/tests/test_utils.py index 728f288d..d7dba834 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -4,7 +4,7 @@ import torch from torch.distributions import Laplace, Normal -from torch_uncertainty.utils import distributions, get_version, hub +from torch_uncertainty.utils import csv_writer, distributions, get_version, hub class TestUtils: @@ -37,6 +37,17 @@ def test_hub_notexists(self): hub.load_hf("test", version=42) +class TestMisc: + """Testing misc methods.""" + + def test_csv_writer(self): + root = Path(__file__).parent.resolve() + csv_writer(root / "logs" / "results.csv", {"a": 1.0, "b": 2.0}) + csv_writer( + root / "logs" / "results.csv", {"a": 1.0, "b": 2.0, "c": 3.0} + ) + + class TestDistributions: """Testing distributions methods.""" diff --git a/torch_uncertainty/datasets/muad.py b/torch_uncertainty/datasets/muad.py index 05d79143..0f5f2ed6 100644 --- a/torch_uncertainty/datasets/muad.py +++ b/torch_uncertainty/datasets/muad.py @@ -112,27 +112,27 @@ def __init__( self._make_dataset(self.root / split) - def encode_target(self, smnt: Image.Image) -> Image.Image: + def encode_target(self, target: Image.Image) -> Image.Image: """Encode target image to tensor. Args: - smnt (Image.Image): Target PIL image. + target (Image.Image): Target PIL image. Returns: torch.Tensor: Encoded target. """ - smnt = F.pil_to_tensor(smnt) - smnt = rearrange(smnt, "c h w -> h w c") - target = torch.zeros_like(smnt[..., :1]) + target = F.pil_to_tensor(target) + target = rearrange(target, "c h w -> h w c") + out = torch.zeros_like(target[..., :1]) # convert target color to index for muad_class in self.classes: - target[ + out[ ( - smnt == torch.tensor(muad_class["id"], dtype=target.dtype) + target == torch.tensor(muad_class["id"], dtype=target.dtype) ).all(dim=-1) ] = muad_class["train_id"] - return F.to_pil_image(rearrange(target, "h w c -> c h w")) + return F.to_pil_image(rearrange(out, "h w c -> c h w")) def decode_target(self, target: Image.Image) -> np.ndarray: target[target == 255] = 19 diff --git a/torch_uncertainty/models/deep_ensembles.py b/torch_uncertainty/models/deep_ensembles.py index 66435947..71dd3f1c 100644 --- a/torch_uncertainty/models/deep_ensembles.py +++ b/torch_uncertainty/models/deep_ensembles.py @@ -61,7 +61,9 @@ def forward(self, x: torch.Tensor) -> Distribution: def deep_ensembles( models: list[nn.Module] | nn.Module, num_estimators: int | None = None, - task: Literal["classification", "regression"] = "classification", + task: Literal[ + "classification", "regression", "segmentation" + ] = "classification", probabilistic=None, reset_model_parameters: bool = False, ) -> nn.Module: @@ -125,7 +127,7 @@ def deep_ensembles( "num_estimators must be None if you provided a non-singleton list." ) - if task == "classification": + if task in ("classification", "segmentation"): return _DeepEnsembles(models=models) if task == "regression": if probabilistic is None: diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 373fb0db..ba0cebcb 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -137,6 +137,12 @@ def __init__( "Groupng loss for ensembles is not yet implemented. Raise an issue if needed." ) + if num_classes < 1: + raise ValueError( + "The number of classes must be a positive integer >= 1." + f"Got {num_classes}." + ) + if eval_grouping_loss and not hasattr(model, "feats_forward"): raise ValueError( "Your model must have a `feats_forward` method to compute the " diff --git a/torch_uncertainty/routines/segmentation.py b/torch_uncertainty/routines/segmentation.py index 3b6124a4..26e93ca8 100644 --- a/torch_uncertainty/routines/segmentation.py +++ b/torch_uncertainty/routines/segmentation.py @@ -20,16 +20,27 @@ def __init__( ) -> None: super().__init__() + if num_estimators < 1: + raise ValueError( + f"num_estimators must be positive, got {num_estimators}." + ) + + if num_classes < 2: + raise ValueError( + f"num_classes must be at least 2, got {num_classes}." + ) + if format_batch_fn is None: format_batch_fn = nn.Identity() self.num_classes = num_classes self.model = model self.loss = loss - self.num_estimators = num_estimators self.format_batch_fn = format_batch_fn self.optim_recipe = optim_recipe + self.num_estimators = num_estimators + self.metric_to_monitor = "val/mean_iou" # metrics diff --git a/torch_uncertainty/utils/misc.py b/torch_uncertainty/utils/misc.py index 33a43c92..7052e243 100644 --- a/torch_uncertainty/utils/misc.py +++ b/torch_uncertainty/utils/misc.py @@ -15,7 +15,8 @@ def csv_writer(path: Path, dic: dict) -> None: else: append_mode = False rw_mode = "w" - + print(f"Writing to {path}") + print(f"Append mode: {append_mode}") # Write dic with path.open(rw_mode) as csvfile: writer = csv.writer(csvfile, delimiter=",") From 68644af8b3188a4f5ec4116a57ad0a54a70d88b5 Mon Sep 17 00:00:00 2001 From: Olivier Date: Thu, 21 Mar 2024 13:54:00 +0100 Subject: [PATCH 085/148] :bug: Fix tests --- .gitignore | 3 ++- tests/models/test_segformer.py | 24 ++++++++++++------------ tests/test_cli.py | 14 +++++++------- tests/test_utils.py | 4 ++-- 4 files changed, 23 insertions(+), 22 deletions(-) diff --git a/.gitignore b/.gitignore index 8e9d6307..4aa0c186 100644 --- a/.gitignore +++ b/.gitignore @@ -8,7 +8,8 @@ docs/*/auto_tutorials/ *.pth *.ckpt *.out -sg_execution_times.rst +docs/source/sg_execution_times.rst +test**/*.csv # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/tests/models/test_segformer.py b/tests/models/test_segformer.py index 3e518129..f69439d8 100644 --- a/tests/models/test_segformer.py +++ b/tests/models/test_segformer.py @@ -1,12 +1,12 @@ import torch from torch_uncertainty.models.segmentation.segformer import ( - segformer_b0, - segformer_b1, - segformer_b2, - segformer_b3, - segformer_b4, - segformer_b5, + seg_former_b0, + seg_former_b1, + seg_former_b2, + seg_former_b3, + seg_former_b4, + seg_former_b5, ) @@ -14,12 +14,12 @@ class TestSegformer: """Testing the Segformer class.""" def test_main(self): - segformer_b1(10) - segformer_b2(10) - segformer_b3(10) - segformer_b4(10) - segformer_b5(10) + seg_former_b1(10) + seg_former_b2(10) + seg_former_b3(10) + seg_former_b4(10) + seg_former_b5(10) - model = segformer_b0(10) + model = seg_former_b0(10) with torch.no_grad(): model(torch.randn(1, 3, 32, 32)) diff --git a/tests/test_cli.py b/tests/test_cli.py index 6612f5ec..424cdc5c 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,4 +1,4 @@ -from cli_test_helpers import ArgvContext +import sys from torch_uncertainty.baselines.classification import ResNetBaseline from torch_uncertainty.datamodules import CIFAR10DataModule @@ -10,7 +10,7 @@ class TestCLI: def test_cli_init(self): """Test CLI initialization.""" - with ArgvContext( + sys.argv = [ "file.py", "--model.in_channels", "3", @@ -25,8 +25,8 @@ def test_cli_init(self): "--data.root", "./data", "--data.batch_size", - "32", - ): - cli = TULightningCLI(ResNetBaseline, CIFAR10DataModule, run=False) - assert cli.eval_after_fit_default is False - assert cli.save_config_callback == TUSaveConfigCallback + "4", + ] + cli = TULightningCLI(ResNetBaseline, CIFAR10DataModule, run=False) + assert cli.eval_after_fit_default is False + assert cli.save_config_callback == TUSaveConfigCallback diff --git a/tests/test_utils.py b/tests/test_utils.py index d7dba834..bebafb22 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -42,9 +42,9 @@ class TestMisc: def test_csv_writer(self): root = Path(__file__).parent.resolve() - csv_writer(root / "logs" / "results.csv", {"a": 1.0, "b": 2.0}) + csv_writer(root / "testlog" / "results.csv", {"a": 1.0, "b": 2.0}) csv_writer( - root / "logs" / "results.csv", {"a": 1.0, "b": 2.0, "c": 3.0} + root / "testlog" / "results.csv", {"a": 1.0, "b": 2.0, "c": 3.0} ) From 429aabdb41d2cb1c04cfdb98cb086bae87117c5c Mon Sep 17 00:00:00 2001 From: Olivier Date: Thu, 21 Mar 2024 14:08:12 +0100 Subject: [PATCH 086/148] :white_check_mark: Improve baseline cov. --- tests/baselines/test_standard.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/baselines/test_standard.py b/tests/baselines/test_standard.py index 7912954f..457b5e8e 100644 --- a/tests/baselines/test_standard.py +++ b/tests/baselines/test_standard.py @@ -117,6 +117,16 @@ def test_standard(self): _ = net.criterion _ = net(torch.rand(1, 3)) + for distribution in ["normal", "laplace", "nig"]: + MLPBaseline( + in_features=3, + output_dim=10, + loss=nn.MSELoss, + version="std", + hidden_dims=[1], + distribution=distribution, + ) + def test_errors(self): with pytest.raises(ValueError): MLPBaseline( From ae59a500d172d48034f9e56de848d43c5f4bf643 Mon Sep 17 00:00:00 2001 From: Olivier Date: Thu, 21 Mar 2024 14:15:07 +0100 Subject: [PATCH 087/148] :hammer: Move plot utils --- torch_uncertainty/plotting_utils.py | 43 ------------------- torch_uncertainty/routines/classification.py | 3 +- torch_uncertainty/utils/__init__.py | 2 +- torch_uncertainty/utils/misc.py | 44 ++++++++++++++++++++ 4 files changed, 46 insertions(+), 46 deletions(-) delete mode 100644 torch_uncertainty/plotting_utils.py diff --git a/torch_uncertainty/plotting_utils.py b/torch_uncertainty/plotting_utils.py deleted file mode 100644 index c9bdb337..00000000 --- a/torch_uncertainty/plotting_utils.py +++ /dev/null @@ -1,43 +0,0 @@ -import matplotlib.pyplot as plt -import torch -from matplotlib.axes import Axes -from matplotlib.figure import Figure - - -def plot_hist( - conf: list[torch.Tensor], - bins: int = 20, - title: str = "Histogram with 'auto' bins", - dpi: int = 60, -) -> tuple[Figure, Axes]: - """Plot a confidence histogram. - - Args: - conf (Any): The confidence values. - bins (int, optional): The number of bins. Defaults to 20. - title (str, optional): The title of the plot. Defaults to "Histogram - with 'auto' bins". - dpi (int, optional): The dpi of the plot. Defaults to 60. - - Returns: - Tuple[Figure, Axes]: The figure and axes of the plot. - """ - plt.rc("axes", axisbelow=True) - fig, ax = plt.subplots(1, figsize=(7, 5), dpi=dpi) - for i in [1, 0]: - ax.hist( - conf[i], - bins=bins, - density=True, - label=["In-distribution", "Out-of-Distribution"][i], - alpha=0.4, - linewidth=1, - edgecolor=["#0d559f", "#d45f00"][i], - color=["#1f77b4", "#ff7f0e"][i], - ) - - ax.set_title(title) - plt.grid(True, linestyle="--", alpha=0.7, zorder=0) - plt.legend() - fig.tight_layout() - return fig, ax diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index ba0cebcb..c088ca40 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -29,10 +29,9 @@ MutualInformation, VariationRatio, ) -from torch_uncertainty.plotting_utils import plot_hist from torch_uncertainty.post_processing import TemperatureScaler from torch_uncertainty.transforms import Mixup, MixupIO, RegMixup, WarpingMixup -from torch_uncertainty.utils import csv_writer +from torch_uncertainty.utils import csv_writer, plot_hist class ClassificationRoutine(LightningModule): diff --git a/torch_uncertainty/utils/__init__.py b/torch_uncertainty/utils/__init__.py index af7813b4..b5d0b1d4 100644 --- a/torch_uncertainty/utils/__init__.py +++ b/torch_uncertainty/utils/__init__.py @@ -2,4 +2,4 @@ from .checkpoints import get_version from .cli import TULightningCLI from .hub import load_hf -from .misc import csv_writer +from .misc import csv_writer, plot_hist diff --git a/torch_uncertainty/utils/misc.py b/torch_uncertainty/utils/misc.py index 7052e243..5b87647a 100644 --- a/torch_uncertainty/utils/misc.py +++ b/torch_uncertainty/utils/misc.py @@ -1,6 +1,11 @@ import csv from pathlib import Path +import matplotlib.pyplot as plt +import torch +from matplotlib.axes import Axes +from matplotlib.figure import Figure + def csv_writer(path: Path, dic: dict) -> None: """Write a dictionary to a csv file. @@ -24,3 +29,42 @@ def csv_writer(path: Path, dic: dict) -> None: if append_mode is False: writer.writerow(dic.keys()) writer.writerow([f"{elem:.4f}" for elem in dic.values()]) + + +def plot_hist( + conf: list[torch.Tensor], + bins: int = 20, + title: str = "Histogram with 'auto' bins", + dpi: int = 60, +) -> tuple[Figure, Axes]: + """Plot a confidence histogram. + + Args: + conf (Any): The confidence values. + bins (int, optional): The number of bins. Defaults to 20. + title (str, optional): The title of the plot. Defaults to "Histogram + with 'auto' bins". + dpi (int, optional): The dpi of the plot. Defaults to 60. + + Returns: + Tuple[Figure, Axes]: The figure and axes of the plot. + """ + plt.rc("axes", axisbelow=True) + fig, ax = plt.subplots(1, figsize=(7, 5), dpi=dpi) + for i in [1, 0]: + ax.hist( + conf[i], + bins=bins, + density=True, + label=["In-distribution", "Out-of-Distribution"][i], + alpha=0.4, + linewidth=1, + edgecolor=["#0d559f", "#d45f00"][i], + color=["#1f77b4", "#ff7f0e"][i], + ) + + ax.set_title(title) + plt.grid(True, linestyle="--", alpha=0.7, zorder=0) + plt.legend() + fig.tight_layout() + return fig, ax From 0676e49d40b9f66f66b17bb7a2da2d9a9a07182a Mon Sep 17 00:00:00 2001 From: alafage Date: Thu, 21 Mar 2024 15:10:06 +0100 Subject: [PATCH 088/148] :white_check_mark: Improve cli coverage --- tests/_dummies/baseline.py | 4 ++-- tests/test_cli.py | 1 + torch_uncertainty/utils/cli.py | 14 ++++++++------ 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/tests/_dummies/baseline.py b/tests/_dummies/baseline.py index 07f48f3b..bee2b77c 100644 --- a/tests/_dummies/baseline.py +++ b/tests/_dummies/baseline.py @@ -135,7 +135,7 @@ def __new__( num_classes: int, image_size: int, loss: type[nn.Module], - baseline_type: bool = False, + baseline_type: str = "single", optim_recipe=None, ) -> LightningModule: model = dummy_segmentation_model( @@ -149,7 +149,7 @@ def __new__( num_classes=num_classes, model=model, loss=loss, - format_batch_fn=nn.Identity(), + format_batch_fn=None, num_estimators=1, optim_recipe=optim_recipe, ) diff --git a/tests/test_cli.py b/tests/test_cli.py index 424cdc5c..4d570858 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -30,3 +30,4 @@ def test_cli_init(self): cli = TULightningCLI(ResNetBaseline, CIFAR10DataModule, run=False) assert cli.eval_after_fit_default is False assert cli.save_config_callback == TUSaveConfigCallback + cli.instantiate_trainer() diff --git a/torch_uncertainty/utils/cli.py b/torch_uncertainty/utils/cli.py index 695b3b32..e04a5ec0 100644 --- a/torch_uncertainty/utils/cli.py +++ b/torch_uncertainty/utils/cli.py @@ -61,12 +61,14 @@ def setup( class TULightningCLI(LightningCLI): def __init__( self, - model_class: type[LightningModule] - | Callable[..., LightningModule] - | None = None, - datamodule_class: type[LightningDataModule] - | Callable[..., LightningDataModule] - | None = None, + model_class: ( + type[LightningModule] | Callable[..., LightningModule] | None + ) = None, + datamodule_class: ( + type[LightningDataModule] + | Callable[..., LightningDataModule] + | None + ) = None, save_config_callback: type[SaveConfigCallback] | None = TUSaveConfigCallback, save_config_kwargs: dict[str, Any] | None = None, From cf2ca5486fbe98fb7216d017276c4fc3c33cd4dc Mon Sep 17 00:00:00 2001 From: alafage Date: Thu, 21 Mar 2024 15:44:24 +0100 Subject: [PATCH 089/148] :hammer: Loss rework to be an instance Co-authored-by: Olivier Laurent --- auto_tutorials_source/tutorial_bayesian.py | 7 ++--- .../tutorial_evidential_classification.py | 13 ++++----- .../tutorial_mc_batch_norm.py | 10 ++++--- auto_tutorials_source/tutorial_mc_dropout.py | 2 +- .../cifar10/configs/resnet.yaml | 3 +- .../cifar10/configs/resnet18/batched.yaml | 3 +- .../cifar10/configs/resnet18/masked.yaml | 3 +- .../cifar10/configs/resnet18/mimo.yaml | 3 +- .../cifar10/configs/resnet18/packed.yaml | 3 +- .../cifar10/configs/resnet18/standard.yaml | 3 +- .../cifar10/configs/resnet50/batched.yaml | 3 +- .../cifar10/configs/resnet50/masked.yaml | 3 +- .../cifar10/configs/resnet50/mimo.yaml | 3 +- .../cifar10/configs/resnet50/packed.yaml | 3 +- .../cifar10/configs/resnet50/standard.yaml | 3 +- .../cifar10/configs/wideresnet28x10.yaml | 3 +- .../configs/wideresnet28x10/batched.yaml | 3 +- .../configs/wideresnet28x10/masked.yaml | 3 +- .../cifar10/configs/wideresnet28x10/mimo.yaml | 3 +- .../configs/wideresnet28x10/packed.yaml | 3 +- .../configs/wideresnet28x10/standard.yaml | 3 +- .../cifar100/configs/resnet.yaml | 3 +- .../cifar100/configs/resnet18/batched.yaml | 3 +- .../cifar100/configs/resnet18/masked.yaml | 3 +- .../cifar100/configs/resnet18/mimo.yaml | 3 +- .../cifar100/configs/resnet18/packed.yaml | 3 +- .../cifar100/configs/resnet18/standard.yaml | 3 +- .../cifar100/configs/resnet50/batched.yaml | 3 +- .../cifar100/configs/resnet50/masked.yaml | 3 +- .../cifar100/configs/resnet50/mimo.yaml | 3 +- .../cifar100/configs/resnet50/packed.yaml | 3 +- .../cifar100/configs/resnet50/standard.yaml | 3 +- experiments/classification/mnist/lenet.py | 2 +- .../classification/tiny-imagenet/resnet.py | 4 +-- .../configs/gaussian_mlp_kin8nm.yaml | 3 +- .../configs/laplace_mlp_kin8nm.yaml | 3 +- .../uci_datasets/configs/pw_mlp_kin8nm.yaml | 3 +- .../camvid/configs/segformer.yaml | 3 +- .../cityscapes/configs/segformer.yaml | 3 +- .../segmentation/muad/configs/segformer.yaml | 3 +- tests/baselines/test_batched.py | 6 ++-- tests/baselines/test_masked.py | 10 +++---- tests/baselines/test_mc_dropout.py | 8 +++--- tests/baselines/test_mimo.py | 6 ++-- tests/baselines/test_packed.py | 19 +++++-------- tests/baselines/test_standard.py | 27 ++++++++---------- tests/routines/test_classification.py | 8 +++--- tests/routines/test_regression.py | 16 +++++------ tests/routines/test_segmentation.py | 4 +-- tests/test_cli.py | 2 +- .../baselines/classification/resnet.py | 2 +- .../baselines/classification/vgg.py | 2 +- .../baselines/classification/wideresnet.py | 2 +- torch_uncertainty/baselines/regression/mlp.py | 2 +- .../baselines/segmentation/segformer.py | 2 +- torch_uncertainty/losses.py | 12 ++++++-- torch_uncertainty/routines/classification.py | 28 +++++++------------ torch_uncertainty/routines/regression.py | 11 ++------ torch_uncertainty/routines/segmentation.py | 6 +--- 59 files changed, 161 insertions(+), 152 deletions(-) diff --git a/auto_tutorials_source/tutorial_bayesian.py b/auto_tutorials_source/tutorial_bayesian.py index 81f9d06f..28f523d6 100644 --- a/auto_tutorials_source/tutorial_bayesian.py +++ b/auto_tutorials_source/tutorial_bayesian.py @@ -37,8 +37,8 @@ neural network utils from torch.nn, as well as the partial util to provide the modified default arguments for the ELBO loss. """ + # %% -from functools import partial from pathlib import Path from lightning.pytorch import Trainer @@ -76,7 +76,7 @@ def optim_lenet(model: nn.Module) -> dict: # datamodule root = Path("") / "data" -datamodule = MNISTDataModule(root = root, batch_size=128, eval_ood=False) +datamodule = MNISTDataModule(root=root, batch_size=128, eval_ood=False) # model model = bayesian_lenet(datamodule.num_channels, datamodule.num_classes) @@ -93,8 +93,7 @@ def optim_lenet(model: nn.Module) -> dict: # from torch_uncertainty.classification. We provide the model, the ELBO # loss and the optimizer to the routine. -loss = partial( - ELBOLoss, +loss = ELBOLoss( model=model, criterion=nn.CrossEntropyLoss(), kl_weight=1 / 50000, diff --git a/auto_tutorials_source/tutorial_evidential_classification.py b/auto_tutorials_source/tutorial_evidential_classification.py index 0f8a890d..d1aa9a3f 100644 --- a/auto_tutorials_source/tutorial_evidential_classification.py +++ b/auto_tutorials_source/tutorial_evidential_classification.py @@ -39,6 +39,7 @@ from torch_uncertainty.models.lenet import lenet from torch_uncertainty.routines import ClassificationRoutine + # %% # 2. Creating the Optimizer Wrapper # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -46,11 +47,10 @@ # with the default learning rate of 0.001 and a step scheduler. def optim_lenet(model: nn.Module) -> dict: optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.005) - exp_lr_scheduler = optim.lr_scheduler.StepLR( - optimizer, step_size=7, gamma=0.1 - ) + exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1) return {"optimizer": optimizer, "lr_scheduler": exp_lr_scheduler} + # %% # 3. Creating the necessary variables # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -81,10 +81,7 @@ def optim_lenet(model: nn.Module) -> dict: # In this routine, we provide the model, the DEC loss, the optimizer, # and all the default arguments. -loss = partial( - DECLoss, - reg_weight=1e-2, -) +loss = DECLoss(reg_weight=1e-2) routine = ClassificationRoutine( model=model, @@ -125,7 +122,7 @@ def rotated_mnist(angle: int) -> None: """ rotated_images = F.rotate(images, angle) # print rotated images - plt.axis('off') + plt.axis("off") imshow(torchvision.utils.make_grid(rotated_images[:4, ...])) print("Ground truth: ", " ".join(f"{labels[j]}" for j in range(4))) diff --git a/auto_tutorials_source/tutorial_mc_batch_norm.py b/auto_tutorials_source/tutorial_mc_batch_norm.py index 1919b306..273c7ba8 100644 --- a/auto_tutorials_source/tutorial_mc_batch_norm.py +++ b/auto_tutorials_source/tutorial_mc_batch_norm.py @@ -51,7 +51,7 @@ model = lenet( in_channels=datamodule.num_channels, num_classes=datamodule.num_classes, - norm = nn.BatchNorm2d, + norm=nn.BatchNorm2d, ) # %% @@ -59,13 +59,13 @@ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # This is a classification problem, and we use CrossEntropyLoss as likelihood. # We define the training routine using the classification training routine from -# torch_uncertainty.training.classification. We provide the number of classes, +# torch_uncertainty.training.classification. We provide the number of classes, # and the optimization recipe. routine = ClassificationRoutine( num_classes=datamodule.num_classes, model=model, - loss=nn.CrossEntropyLoss, + loss=nn.CrossEntropyLoss(), optim_recipe=optim_cifar10_resnet18, ) @@ -87,7 +87,9 @@ # The authors suggest 32 as a good value for ``mc_batch_size`` but we use 4 here # to highlight the effect of stochasticity on the predictions. -routine.model = MCBatchNorm(routine.model, num_estimators=8, convert=True, mc_batch_size=4) +routine.model = MCBatchNorm( + routine.model, num_estimators=8, convert=True, mc_batch_size=4 +) routine.model.fit(datamodule.train) routine.eval() diff --git a/auto_tutorials_source/tutorial_mc_dropout.py b/auto_tutorials_source/tutorial_mc_dropout.py index 29a3c4c7..75f8aefb 100644 --- a/auto_tutorials_source/tutorial_mc_dropout.py +++ b/auto_tutorials_source/tutorial_mc_dropout.py @@ -81,7 +81,7 @@ routine = ClassificationRoutine( num_classes=datamodule.num_classes, model=mc_model, - loss=nn.CrossEntropyLoss, + loss=nn.CrossEntropyLoss(), optim_recipe=optim_cifar10_resnet18, num_estimators=16, diff --git a/experiments/classification/cifar10/configs/resnet.yaml b/experiments/classification/cifar10/configs/resnet.yaml index fb396273..2ba51027 100644 --- a/experiments/classification/cifar10/configs/resnet.yaml +++ b/experiments/classification/cifar10/configs/resnet.yaml @@ -27,7 +27,8 @@ trainer: model: num_classes: 10 in_channels: 3 - loss: torch.nn.CrossEntropyLoss + loss: + class_path: torch.nn.CrossEntropyLoss style: cifar data: root: ./data diff --git a/experiments/classification/cifar10/configs/resnet18/batched.yaml b/experiments/classification/cifar10/configs/resnet18/batched.yaml index 9596dc65..59369531 100644 --- a/experiments/classification/cifar10/configs/resnet18/batched.yaml +++ b/experiments/classification/cifar10/configs/resnet18/batched.yaml @@ -29,7 +29,8 @@ trainer: model: num_classes: 10 in_channels: 3 - loss: torch.nn.CrossEntropyLoss + loss: + class_path: torch.nn.CrossEntropyLoss version: batched arch: 18 style: cifar diff --git a/experiments/classification/cifar10/configs/resnet18/masked.yaml b/experiments/classification/cifar10/configs/resnet18/masked.yaml index 958c8c25..79aa2fe7 100644 --- a/experiments/classification/cifar10/configs/resnet18/masked.yaml +++ b/experiments/classification/cifar10/configs/resnet18/masked.yaml @@ -29,7 +29,8 @@ trainer: model: num_classes: 10 in_channels: 3 - loss: torch.nn.CrossEntropyLoss + loss: + class_path: torch.nn.CrossEntropyLoss version: masked arch: 18 style: cifar diff --git a/experiments/classification/cifar10/configs/resnet18/mimo.yaml b/experiments/classification/cifar10/configs/resnet18/mimo.yaml index c642e877..d73cb421 100644 --- a/experiments/classification/cifar10/configs/resnet18/mimo.yaml +++ b/experiments/classification/cifar10/configs/resnet18/mimo.yaml @@ -29,7 +29,8 @@ trainer: model: num_classes: 10 in_channels: 3 - loss: torch.nn.CrossEntropyLoss + loss: + class_path: torch.nn.CrossEntropyLoss version: mimo arch: 18 style: cifar diff --git a/experiments/classification/cifar10/configs/resnet18/packed.yaml b/experiments/classification/cifar10/configs/resnet18/packed.yaml index c6c4ecd4..e920b354 100644 --- a/experiments/classification/cifar10/configs/resnet18/packed.yaml +++ b/experiments/classification/cifar10/configs/resnet18/packed.yaml @@ -29,7 +29,8 @@ trainer: model: num_classes: 10 in_channels: 3 - loss: torch.nn.CrossEntropyLoss + loss: + class_path: torch.nn.CrossEntropyLoss version: packed arch: 18 style: cifar diff --git a/experiments/classification/cifar10/configs/resnet18/standard.yaml b/experiments/classification/cifar10/configs/resnet18/standard.yaml index d6ab70f9..0184abf1 100644 --- a/experiments/classification/cifar10/configs/resnet18/standard.yaml +++ b/experiments/classification/cifar10/configs/resnet18/standard.yaml @@ -29,7 +29,8 @@ trainer: model: num_classes: 10 in_channels: 3 - loss: torch.nn.CrossEntropyLoss + loss: + class_path: torch.nn.CrossEntropyLoss version: std arch: 18 style: cifar diff --git a/experiments/classification/cifar10/configs/resnet50/batched.yaml b/experiments/classification/cifar10/configs/resnet50/batched.yaml index 396a268f..1352eb88 100644 --- a/experiments/classification/cifar10/configs/resnet50/batched.yaml +++ b/experiments/classification/cifar10/configs/resnet50/batched.yaml @@ -29,7 +29,8 @@ trainer: model: num_classes: 10 in_channels: 3 - loss: torch.nn.CrossEntropyLoss + loss: + class_path: torch.nn.CrossEntropyLoss version: batched arch: 50 style: cifar diff --git a/experiments/classification/cifar10/configs/resnet50/masked.yaml b/experiments/classification/cifar10/configs/resnet50/masked.yaml index 195c8338..dea33597 100644 --- a/experiments/classification/cifar10/configs/resnet50/masked.yaml +++ b/experiments/classification/cifar10/configs/resnet50/masked.yaml @@ -29,7 +29,8 @@ trainer: model: num_classes: 10 in_channels: 3 - loss: torch.nn.CrossEntropyLoss + loss: + class_path: torch.nn.CrossEntropyLoss version: masked arch: 50 style: cifar diff --git a/experiments/classification/cifar10/configs/resnet50/mimo.yaml b/experiments/classification/cifar10/configs/resnet50/mimo.yaml index 939f2897..d9575e9f 100644 --- a/experiments/classification/cifar10/configs/resnet50/mimo.yaml +++ b/experiments/classification/cifar10/configs/resnet50/mimo.yaml @@ -29,7 +29,8 @@ trainer: model: num_classes: 10 in_channels: 3 - loss: torch.nn.CrossEntropyLoss + loss: + class_path: torch.nn.CrossEntropyLoss version: mimo arch: 50 style: cifar diff --git a/experiments/classification/cifar10/configs/resnet50/packed.yaml b/experiments/classification/cifar10/configs/resnet50/packed.yaml index ac99f2f0..aa4d4e76 100644 --- a/experiments/classification/cifar10/configs/resnet50/packed.yaml +++ b/experiments/classification/cifar10/configs/resnet50/packed.yaml @@ -29,7 +29,8 @@ trainer: model: num_classes: 10 in_channels: 3 - loss: torch.nn.CrossEntropyLoss + loss: + class_path: torch.nn.CrossEntropyLoss version: packed arch: 50 style: cifar diff --git a/experiments/classification/cifar10/configs/resnet50/standard.yaml b/experiments/classification/cifar10/configs/resnet50/standard.yaml index 02743adb..f24e039e 100644 --- a/experiments/classification/cifar10/configs/resnet50/standard.yaml +++ b/experiments/classification/cifar10/configs/resnet50/standard.yaml @@ -29,7 +29,8 @@ trainer: model: num_classes: 10 in_channels: 3 - loss: torch.nn.CrossEntropyLoss + loss: + class_path: torch.nn.CrossEntropyLoss version: std arch: 50 style: cifar diff --git a/experiments/classification/cifar10/configs/wideresnet28x10.yaml b/experiments/classification/cifar10/configs/wideresnet28x10.yaml index 8d88ad09..82e91c72 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10.yaml @@ -29,7 +29,8 @@ trainer: model: num_classes: 10 in_channels: 3 - loss: torch.nn.CrossEntropyLoss + loss: + class_path: torch.nn.CrossEntropyLoss style: cifar data: root: ./data diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/batched.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/batched.yaml index eeca402d..74e806db 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10/batched.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10/batched.yaml @@ -29,7 +29,8 @@ trainer: model: num_classes: 10 in_channels: 3 - loss: torch.nn.CrossEntropyLoss + loss: + class_path: torch.nn.CrossEntropyLoss version: batched style: cifar num_estimators: 4 diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/masked.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/masked.yaml index 74a0b950..437b9243 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10/masked.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10/masked.yaml @@ -29,7 +29,8 @@ trainer: model: num_classes: 10 in_channels: 3 - loss: torch.nn.CrossEntropyLoss + loss: + class_path: torch.nn.CrossEntropyLoss version: masked style: cifar num_estimators: 4 diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/mimo.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/mimo.yaml index 782c1202..45bc95cd 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10/mimo.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10/mimo.yaml @@ -29,7 +29,8 @@ trainer: model: num_classes: 10 in_channels: 3 - loss: torch.nn.CrossEntropyLoss + loss: + class_path: torch.nn.CrossEntropyLoss version: mimo style: cifar num_estimators: 4 diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/packed.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/packed.yaml index e3af37c5..2fec727e 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10/packed.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10/packed.yaml @@ -29,7 +29,8 @@ trainer: model: num_classes: 10 in_channels: 3 - loss: torch.nn.CrossEntropyLoss + loss: + class_path: torch.nn.CrossEntropyLoss version: packed style: cifar num_estimators: 4 diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/standard.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/standard.yaml index 875ec995..15c3a848 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10/standard.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10/standard.yaml @@ -29,7 +29,8 @@ trainer: model: num_classes: 10 in_channels: 3 - loss: torch.nn.CrossEntropyLoss + loss: + class_path: torch.nn.CrossEntropyLoss version: std style: cifar data: diff --git a/experiments/classification/cifar100/configs/resnet.yaml b/experiments/classification/cifar100/configs/resnet.yaml index fb396273..2ba51027 100644 --- a/experiments/classification/cifar100/configs/resnet.yaml +++ b/experiments/classification/cifar100/configs/resnet.yaml @@ -27,7 +27,8 @@ trainer: model: num_classes: 10 in_channels: 3 - loss: torch.nn.CrossEntropyLoss + loss: + class_path: torch.nn.CrossEntropyLoss style: cifar data: root: ./data diff --git a/experiments/classification/cifar100/configs/resnet18/batched.yaml b/experiments/classification/cifar100/configs/resnet18/batched.yaml index 0c14021f..8c8c0d77 100644 --- a/experiments/classification/cifar100/configs/resnet18/batched.yaml +++ b/experiments/classification/cifar100/configs/resnet18/batched.yaml @@ -29,7 +29,8 @@ trainer: model: num_classes: 100 in_channels: 3 - loss: torch.nn.CrossEntropyLoss + loss: + class_path: torch.nn.CrossEntropyLoss version: batched arch: 18 style: cifar diff --git a/experiments/classification/cifar100/configs/resnet18/masked.yaml b/experiments/classification/cifar100/configs/resnet18/masked.yaml index c1d6f261..e184c07d 100644 --- a/experiments/classification/cifar100/configs/resnet18/masked.yaml +++ b/experiments/classification/cifar100/configs/resnet18/masked.yaml @@ -29,7 +29,8 @@ trainer: model: num_classes: 100 in_channels: 3 - loss: torch.nn.CrossEntropyLoss + loss: + class_path: torch.nn.CrossEntropyLoss version: masked arch: 18 style: cifar diff --git a/experiments/classification/cifar100/configs/resnet18/mimo.yaml b/experiments/classification/cifar100/configs/resnet18/mimo.yaml index 13fc9d20..983dec22 100644 --- a/experiments/classification/cifar100/configs/resnet18/mimo.yaml +++ b/experiments/classification/cifar100/configs/resnet18/mimo.yaml @@ -29,7 +29,8 @@ trainer: model: num_classes: 100 in_channels: 3 - loss: torch.nn.CrossEntropyLoss + loss: + class_path: torch.nn.CrossEntropyLoss version: mimo arch: 18 style: cifar diff --git a/experiments/classification/cifar100/configs/resnet18/packed.yaml b/experiments/classification/cifar100/configs/resnet18/packed.yaml index 1dc954d1..099c93b7 100644 --- a/experiments/classification/cifar100/configs/resnet18/packed.yaml +++ b/experiments/classification/cifar100/configs/resnet18/packed.yaml @@ -29,7 +29,8 @@ trainer: model: num_classes: 100 in_channels: 3 - loss: torch.nn.CrossEntropyLoss + loss: + class_path: torch.nn.CrossEntropyLoss version: packed arch: 18 style: cifar diff --git a/experiments/classification/cifar100/configs/resnet18/standard.yaml b/experiments/classification/cifar100/configs/resnet18/standard.yaml index bcb1c7e9..8de85cc4 100644 --- a/experiments/classification/cifar100/configs/resnet18/standard.yaml +++ b/experiments/classification/cifar100/configs/resnet18/standard.yaml @@ -29,7 +29,8 @@ trainer: model: num_classes: 100 in_channels: 3 - loss: torch.nn.CrossEntropyLoss + loss: + class_path: torch.nn.CrossEntropyLoss version: standard arch: 18 style: cifar diff --git a/experiments/classification/cifar100/configs/resnet50/batched.yaml b/experiments/classification/cifar100/configs/resnet50/batched.yaml index 38331ee5..752158e2 100644 --- a/experiments/classification/cifar100/configs/resnet50/batched.yaml +++ b/experiments/classification/cifar100/configs/resnet50/batched.yaml @@ -29,7 +29,8 @@ trainer: model: num_classes: 100 in_channels: 3 - loss: torch.nn.CrossEntropyLoss + loss: + class_path: torch.nn.CrossEntropyLoss version: batched arch: 50 style: cifar diff --git a/experiments/classification/cifar100/configs/resnet50/masked.yaml b/experiments/classification/cifar100/configs/resnet50/masked.yaml index 09168144..306dc50c 100644 --- a/experiments/classification/cifar100/configs/resnet50/masked.yaml +++ b/experiments/classification/cifar100/configs/resnet50/masked.yaml @@ -29,7 +29,8 @@ trainer: model: num_classes: 100 in_channels: 3 - loss: torch.nn.CrossEntropyLoss + loss: + class_path: torch.nn.CrossEntropyLoss version: masked arch: 50 style: cifar diff --git a/experiments/classification/cifar100/configs/resnet50/mimo.yaml b/experiments/classification/cifar100/configs/resnet50/mimo.yaml index 387e6dd9..a24e9693 100644 --- a/experiments/classification/cifar100/configs/resnet50/mimo.yaml +++ b/experiments/classification/cifar100/configs/resnet50/mimo.yaml @@ -29,7 +29,8 @@ trainer: model: num_classes: 100 in_channels: 3 - loss: torch.nn.CrossEntropyLoss + loss: + class_path: torch.nn.CrossEntropyLoss version: mimo arch: 50 style: cifar diff --git a/experiments/classification/cifar100/configs/resnet50/packed.yaml b/experiments/classification/cifar100/configs/resnet50/packed.yaml index 3813fba7..4c5384fc 100644 --- a/experiments/classification/cifar100/configs/resnet50/packed.yaml +++ b/experiments/classification/cifar100/configs/resnet50/packed.yaml @@ -29,7 +29,8 @@ trainer: model: num_classes: 100 in_channels: 3 - loss: torch.nn.CrossEntropyLoss + loss: + class_path: torch.nn.CrossEntropyLoss version: packed arch: 50 style: cifar diff --git a/experiments/classification/cifar100/configs/resnet50/standard.yaml b/experiments/classification/cifar100/configs/resnet50/standard.yaml index 7f236d4f..5ae26e1f 100644 --- a/experiments/classification/cifar100/configs/resnet50/standard.yaml +++ b/experiments/classification/cifar100/configs/resnet50/standard.yaml @@ -29,7 +29,8 @@ trainer: model: num_classes: 100 in_channels: 3 - loss: torch.nn.CrossEntropyLoss + loss: + class_path: torch.nn.CrossEntropyLoss version: standard arch: 50 style: cifar diff --git a/experiments/classification/mnist/lenet.py b/experiments/classification/mnist/lenet.py index dc5f9636..450f72c2 100644 --- a/experiments/classification/mnist/lenet.py +++ b/experiments/classification/mnist/lenet.py @@ -44,7 +44,7 @@ def optim_lenet(model: nn.Module) -> dict: model=model, num_classes=dm.num_classes, in_channels=dm.num_channels, - loss=nn.CrossEntropyLoss, + loss=nn.CrossEntropyLoss(), optim_recipe=optim_lenet, **vars(args), ) diff --git a/experiments/classification/tiny-imagenet/resnet.py b/experiments/classification/tiny-imagenet/resnet.py index bd84ef6b..77be476f 100644 --- a/experiments/classification/tiny-imagenet/resnet.py +++ b/experiments/classification/tiny-imagenet/resnet.py @@ -49,7 +49,7 @@ def optim_tiny(model: nn.Module) -> dict: ResNetBaseline( num_classes=list_dm[i].dm.num_classes, in_channels=list_dm[i].dm.num_channels, - loss=nn.CrossEntropyLoss, + loss=nn.CrossEntropyLoss(), optim_recipe=get_procedure( f"resnet{args.arch}", "tiny-imagenet", args.version ), @@ -68,7 +68,7 @@ def optim_tiny(model: nn.Module) -> dict: model = ResNetBaseline( num_classes=dm.num_classes, in_channels=dm.num_channels, - loss=nn.CrossEntropyLoss, + loss=nn.CrossEntropyLoss(), optim_recipe=get_procedure( f"resnet{args.arch}", "tiny-imagenet", args.version ), diff --git a/experiments/regression/uci_datasets/configs/gaussian_mlp_kin8nm.yaml b/experiments/regression/uci_datasets/configs/gaussian_mlp_kin8nm.yaml index 8fd87576..225eb6ee 100644 --- a/experiments/regression/uci_datasets/configs/gaussian_mlp_kin8nm.yaml +++ b/experiments/regression/uci_datasets/configs/gaussian_mlp_kin8nm.yaml @@ -31,7 +31,8 @@ model: in_features: 8 hidden_dims: - 100 - loss: torch_uncertainty.losses.DistributionNLLLoss + loss: + class_path: torch_uncertainty.losses.DistributionNLLLoss version: std distribution: normal data: diff --git a/experiments/regression/uci_datasets/configs/laplace_mlp_kin8nm.yaml b/experiments/regression/uci_datasets/configs/laplace_mlp_kin8nm.yaml index 3c42f176..4f4d5345 100644 --- a/experiments/regression/uci_datasets/configs/laplace_mlp_kin8nm.yaml +++ b/experiments/regression/uci_datasets/configs/laplace_mlp_kin8nm.yaml @@ -31,7 +31,8 @@ model: in_features: 8 hidden_dims: - 100 - loss: torch_uncertainty.losses.DistributionNLLLoss + loss: + class_path: torch_uncertainty.losses.DistributionNLLLoss version: std distribution: laplace data: diff --git a/experiments/regression/uci_datasets/configs/pw_mlp_kin8nm.yaml b/experiments/regression/uci_datasets/configs/pw_mlp_kin8nm.yaml index e1a4cda3..1cb32c36 100644 --- a/experiments/regression/uci_datasets/configs/pw_mlp_kin8nm.yaml +++ b/experiments/regression/uci_datasets/configs/pw_mlp_kin8nm.yaml @@ -31,7 +31,8 @@ model: in_features: 8 hidden_dims: - 100 - loss: torch.nn.MSELoss + loss: + class_path: torch.nn.MSELoss version: std data: root: ./data diff --git a/experiments/segmentation/camvid/configs/segformer.yaml b/experiments/segmentation/camvid/configs/segformer.yaml index 9146c141..9ec8fca7 100644 --- a/experiments/segmentation/camvid/configs/segformer.yaml +++ b/experiments/segmentation/camvid/configs/segformer.yaml @@ -6,7 +6,8 @@ trainer: devices: 1 model: num_classes: 12 - loss: torch.nn.CrossEntropyLoss + loss: + class_path: torch.nn.CrossEntropyLoss version: std arch: 0 num_estimators: 1 diff --git a/experiments/segmentation/cityscapes/configs/segformer.yaml b/experiments/segmentation/cityscapes/configs/segformer.yaml index cce57b1f..366450a8 100644 --- a/experiments/segmentation/cityscapes/configs/segformer.yaml +++ b/experiments/segmentation/cityscapes/configs/segformer.yaml @@ -7,7 +7,8 @@ trainer: max_steps: 160000 model: num_classes: 19 - loss: torch.nn.CrossEntropyLoss + loss: + class_path: torch.nn.CrossEntropyLoss version: std arch: 0 num_estimators: 1 diff --git a/experiments/segmentation/muad/configs/segformer.yaml b/experiments/segmentation/muad/configs/segformer.yaml index cce57b1f..366450a8 100644 --- a/experiments/segmentation/muad/configs/segformer.yaml +++ b/experiments/segmentation/muad/configs/segformer.yaml @@ -7,7 +7,8 @@ trainer: max_steps: 160000 model: num_classes: 19 - loss: torch.nn.CrossEntropyLoss + loss: + class_path: torch.nn.CrossEntropyLoss version: std arch: 0 num_estimators: 1 diff --git a/tests/baselines/test_batched.py b/tests/baselines/test_batched.py index da3e5130..36c9915e 100644 --- a/tests/baselines/test_batched.py +++ b/tests/baselines/test_batched.py @@ -15,7 +15,7 @@ def test_batched_18(self): net = ResNetBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, + loss=nn.CrossEntropyLoss(), version="batched", arch=18, style="cifar", @@ -32,7 +32,7 @@ def test_batched_50(self): net = ResNetBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, + loss=nn.CrossEntropyLoss(), version="batched", arch=50, style="imagenet", @@ -53,7 +53,7 @@ def test_batched(self): net = WideResNetBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, + loss=nn.CrossEntropyLoss(), version="batched", style="cifar", num_estimators=4, diff --git a/tests/baselines/test_masked.py b/tests/baselines/test_masked.py index 9a51094c..5cefa5f5 100644 --- a/tests/baselines/test_masked.py +++ b/tests/baselines/test_masked.py @@ -16,7 +16,7 @@ def test_masked_18(self): net = ResNetBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, + loss=nn.CrossEntropyLoss(), version="masked", arch=18, style="cifar", @@ -34,7 +34,7 @@ def test_masked_50(self): net = ResNetBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, + loss=nn.CrossEntropyLoss(), version="masked", arch=50, style="imagenet", @@ -53,7 +53,7 @@ def test_masked_scale_lt_1(self): _ = ResNetBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, + loss=nn.CrossEntropyLoss(), version="masked", arch=18, style="cifar", @@ -67,7 +67,7 @@ def test_masked_groups_lt_1(self): _ = ResNetBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, + loss=nn.CrossEntropyLoss(), version="masked", arch=18, style="cifar", @@ -84,7 +84,7 @@ def test_masked(self): net = WideResNetBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, + loss=nn.CrossEntropyLoss(), version="masked", style="cifar", num_estimators=4, diff --git a/tests/baselines/test_mc_dropout.py b/tests/baselines/test_mc_dropout.py index b1c3995a..06006c2d 100644 --- a/tests/baselines/test_mc_dropout.py +++ b/tests/baselines/test_mc_dropout.py @@ -16,7 +16,7 @@ def test_standard(self): net = ResNetBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, + loss=nn.CrossEntropyLoss(), version="mc-dropout", dropout_rate=0.1, num_estimators=4, @@ -37,7 +37,7 @@ def test_standard(self): net = WideResNetBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, + loss=nn.CrossEntropyLoss(), version="mc-dropout", dropout_rate=0.1, num_estimators=4, @@ -57,7 +57,7 @@ def test_standard(self): net = VGGBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, + loss=nn.CrossEntropyLoss(), version="mc-dropout", dropout_rate=0.1, num_estimators=4, @@ -73,7 +73,7 @@ def test_standard(self): net = VGGBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, + loss=nn.CrossEntropyLoss(), version="mc-dropout", num_estimators=4, arch=11, diff --git a/tests/baselines/test_mimo.py b/tests/baselines/test_mimo.py index b3d48c28..5e191128 100644 --- a/tests/baselines/test_mimo.py +++ b/tests/baselines/test_mimo.py @@ -21,7 +21,7 @@ def test_mimo_50(self): net = ResNetBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, + loss=nn.CrossEntropyLoss(), version="mimo", arch=50, style="cifar", @@ -40,7 +40,7 @@ def test_mimo_18(self): net = ResNetBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, + loss=nn.CrossEntropyLoss(), version="mimo", arch=18, style="imagenet", @@ -63,7 +63,7 @@ def test_mimo(self): net = WideResNetBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, + loss=nn.CrossEntropyLoss(), version="mimo", style="cifar", num_estimators=4, diff --git a/tests/baselines/test_packed.py b/tests/baselines/test_packed.py index 6a57325c..ee772623 100644 --- a/tests/baselines/test_packed.py +++ b/tests/baselines/test_packed.py @@ -18,7 +18,7 @@ def test_packed_50(self): net = ResNetBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, + loss=nn.CrossEntropyLoss(), version="packed", arch=50, style="cifar", @@ -30,14 +30,13 @@ def test_packed_50(self): summary(net) - _ = net.criterion _ = net(torch.rand(1, 3, 32, 32)) def test_packed_18(self): net = ResNetBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, + loss=nn.CrossEntropyLoss(), version="packed", arch=18, style="imagenet", @@ -49,7 +48,6 @@ def test_packed_18(self): summary(net) - _ = net.criterion _ = net(torch.rand(1, 3, 40, 40)) def test_packed_exception(self): @@ -57,7 +55,7 @@ def test_packed_exception(self): _ = ResNetBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, + loss=nn.CrossEntropyLoss(), version="packed", arch=50, style="cifar", @@ -71,7 +69,7 @@ def test_packed_exception(self): _ = ResNetBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, + loss=nn.CrossEntropyLoss(), version="packed", arch=50, style="cifar", @@ -89,7 +87,7 @@ def test_packed(self): net = WideResNetBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, + loss=nn.CrossEntropyLoss(), version="packed", style="cifar", num_estimators=4, @@ -100,7 +98,6 @@ def test_packed(self): summary(net) - _ = net.criterion _ = net(torch.rand(1, 3, 32, 32)) @@ -112,7 +109,7 @@ def test_packed(self): num_classes=10, in_channels=3, arch=13, - loss=nn.CrossEntropyLoss, + loss=nn.CrossEntropyLoss(), version="packed", num_estimators=4, alpha=2, @@ -122,7 +119,6 @@ def test_packed(self): summary(net) - _ = net.criterion _ = net(torch.rand(2, 3, 32, 32)) @@ -133,7 +129,7 @@ def test_packed(self): net = MLPBaseline( in_features=3, output_dim=10, - loss=nn.MSELoss, + loss=nn.MSELoss(), version="packed", hidden_dims=[1], num_estimators=2, @@ -142,5 +138,4 @@ def test_packed(self): ) summary(net) - _ = net.criterion _ = net(torch.rand(1, 3)) diff --git a/tests/baselines/test_standard.py b/tests/baselines/test_standard.py index 457b5e8e..0f8ccffa 100644 --- a/tests/baselines/test_standard.py +++ b/tests/baselines/test_standard.py @@ -19,7 +19,7 @@ def test_standard(self): net = ResNetBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, + loss=nn.CrossEntropyLoss(), version="std", arch=18, style="cifar", @@ -27,7 +27,6 @@ def test_standard(self): ) summary(net) - _ = net.criterion _ = net(torch.rand(1, 3, 32, 32)) def test_errors(self): @@ -35,7 +34,7 @@ def test_errors(self): ResNetBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, + loss=nn.CrossEntropyLoss(), version="test", arch=18, style="cifar", @@ -50,14 +49,13 @@ def test_standard(self): net = WideResNetBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, + loss=nn.CrossEntropyLoss(), version="std", style="cifar", groups=1, ) summary(net) - _ = net.criterion _ = net(torch.rand(1, 3, 32, 32)) def test_errors(self): @@ -65,7 +63,7 @@ def test_errors(self): WideResNetBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, + loss=nn.CrossEntropyLoss(), version="test", style="cifar", groups=1, @@ -79,14 +77,13 @@ def test_standard(self): net = VGGBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, + loss=nn.CrossEntropyLoss(), version="std", arch=11, groups=1, ) summary(net) - _ = net.criterion _ = net(torch.rand(1, 3, 32, 32)) def test_errors(self): @@ -94,7 +91,7 @@ def test_errors(self): VGGBaseline( num_classes=10, in_channels=3, - loss=nn.CrossEntropyLoss, + loss=nn.CrossEntropyLoss(), version="test", arch=11, groups=1, @@ -108,20 +105,19 @@ def test_standard(self): net = MLPBaseline( in_features=3, output_dim=10, - loss=nn.MSELoss, + loss=nn.MSELoss(), version="std", hidden_dims=[1], ) summary(net) - _ = net.criterion _ = net(torch.rand(1, 3)) for distribution in ["normal", "laplace", "nig"]: MLPBaseline( in_features=3, output_dim=10, - loss=nn.MSELoss, + loss=nn.MSELoss(), version="std", hidden_dims=[1], distribution=distribution, @@ -132,7 +128,7 @@ def test_errors(self): MLPBaseline( in_features=3, output_dim=10, - loss=nn.MSELoss, + loss=nn.MSELoss(), version="test", hidden_dims=[1], ) @@ -144,20 +140,19 @@ class TestStandardSegFormerBaseline: def test_standard(self): net = SegFormerBaseline( num_classes=10, - loss=nn.CrossEntropyLoss, + loss=nn.CrossEntropyLoss(), version="std", arch=0, ) summary(net) - _ = net.criterion _ = net(torch.rand(1, 3, 32, 32)) def test_errors(self): with pytest.raises(ValueError): SegFormerBaseline( num_classes=10, - loss=nn.CrossEntropyLoss, + loss=nn.CrossEntropyLoss(), version="test", arch=0, ) diff --git a/tests/routines/test_classification.py b/tests/routines/test_classification.py index 2e5e1c78..d6c6349b 100644 --- a/tests/routines/test_classification.py +++ b/tests/routines/test_classification.py @@ -28,7 +28,7 @@ def test_one_estimator_binary(self): model = DummyClassificationBaseline( in_channels=dm.num_channels, num_classes=dm.num_classes, - loss=nn.BCEWithLogitsLoss, + loss=nn.BCEWithLogitsLoss(), optim_recipe=optim_cifar10_resnet18, baseline_type="single", ood_criterion="msp", @@ -51,7 +51,7 @@ def test_two_estimators_binary(self): model = DummyClassificationBaseline( in_channels=dm.num_channels, num_classes=dm.num_classes, - loss=nn.BCEWithLogitsLoss, + loss=nn.BCEWithLogitsLoss(), optim_recipe=optim_cifar10_resnet18, baseline_type="single", ood_criterion="logit", @@ -75,7 +75,7 @@ def test_one_estimator_two_classes(self): model = DummyClassificationBaseline( num_classes=dm.num_classes, in_channels=dm.num_channels, - loss=nn.CrossEntropyLoss, + loss=nn.CrossEntropyLoss(), optim_recipe=optim_cifar10_resnet18, baseline_type="single", ood_criterion="entropy", @@ -99,7 +99,7 @@ def test_two_estimators_two_classes(self): model = DummyClassificationBaseline( num_classes=dm.num_classes, in_channels=dm.num_channels, - loss=nn.CrossEntropyLoss, + loss=nn.CrossEntropyLoss(), optim_recipe=optim_cifar10_resnet18, baseline_type="ensemble", ood_criterion="energy", diff --git a/tests/routines/test_regression.py b/tests/routines/test_regression.py index 119f9cf6..c22c799d 100644 --- a/tests/routines/test_regression.py +++ b/tests/routines/test_regression.py @@ -23,7 +23,7 @@ def test_one_estimator_one_output(self): probabilistic=True, in_features=dm.in_features, output_dim=1, - loss=DistributionNLLLoss, + loss=DistributionNLLLoss(), optim_recipe=optim_cifar10_resnet18, baseline_type="single", ) @@ -37,7 +37,7 @@ def test_one_estimator_one_output(self): probabilistic=False, in_features=dm.in_features, output_dim=1, - loss=DistributionNLLLoss, + loss=DistributionNLLLoss(), optim_recipe=optim_cifar10_resnet18, baseline_type="single", ) @@ -57,7 +57,7 @@ def test_one_estimator_two_outputs(self): probabilistic=True, in_features=dm.in_features, output_dim=2, - loss=DistributionNLLLoss, + loss=DistributionNLLLoss(), optim_recipe=optim_cifar10_resnet18, baseline_type="single", dist_type="laplace", @@ -70,7 +70,7 @@ def test_one_estimator_two_outputs(self): probabilistic=False, in_features=dm.in_features, output_dim=2, - loss=DistributionNLLLoss, + loss=DistributionNLLLoss(), optim_recipe=optim_cifar10_resnet18, baseline_type="single", ) @@ -88,7 +88,7 @@ def test_two_estimators_one_output(self): probabilistic=True, in_features=dm.in_features, output_dim=1, - loss=DistributionNLLLoss, + loss=DistributionNLLLoss(), optim_recipe=optim_cifar10_resnet18, baseline_type="ensemble", dist_type="nig", @@ -101,7 +101,7 @@ def test_two_estimators_one_output(self): probabilistic=False, in_features=dm.in_features, output_dim=1, - loss=DistributionNLLLoss, + loss=DistributionNLLLoss(), optim_recipe=optim_cifar10_resnet18, baseline_type="ensemble", ) @@ -119,7 +119,7 @@ def test_two_estimators_two_outputs(self): probabilistic=True, in_features=dm.in_features, output_dim=2, - loss=DistributionNLLLoss, + loss=DistributionNLLLoss(), optim_recipe=optim_cifar10_resnet18, baseline_type="ensemble", ) @@ -132,7 +132,7 @@ def test_two_estimators_two_outputs(self): probabilistic=False, in_features=dm.in_features, output_dim=2, - loss=DistributionNLLLoss, + loss=DistributionNLLLoss(), optim_recipe=optim_cifar10_resnet18, baseline_type="ensemble", ) diff --git a/tests/routines/test_segmentation.py b/tests/routines/test_segmentation.py index 965693b3..8c7a715c 100644 --- a/tests/routines/test_segmentation.py +++ b/tests/routines/test_segmentation.py @@ -23,7 +23,7 @@ def test_one_estimator_two_classes(self): in_channels=dm.num_channels, num_classes=dm.num_classes, image_size=dm.image_size, - loss=nn.CrossEntropyLoss, + loss=nn.CrossEntropyLoss(), baseline_type="single", optim_recipe=optim_cifar10_resnet18, ) @@ -43,7 +43,7 @@ def test_two_estimators_two_classes(self): in_channels=dm.num_channels, num_classes=dm.num_classes, image_size=dm.image_size, - loss=nn.CrossEntropyLoss, + loss=nn.CrossEntropyLoss(), baseline_type="ensemble", optim_recipe=optim_cifar10_resnet18, ) diff --git a/tests/test_cli.py b/tests/test_cli.py index 4d570858..8da93a96 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -20,7 +20,7 @@ def test_cli_init(self): "std", "--model.arch", "18", - "--model.loss", + "--model.loss.class_path", "torch.nn.CrossEntropyLoss", "--data.root", "./data", diff --git a/torch_uncertainty/baselines/classification/resnet.py b/torch_uncertainty/baselines/classification/resnet.py index 4dbfa466..03871bc9 100644 --- a/torch_uncertainty/baselines/classification/resnet.py +++ b/torch_uncertainty/baselines/classification/resnet.py @@ -98,7 +98,7 @@ def __init__( self, num_classes: int, in_channels: int, - loss: type[nn.Module], + loss: nn.Module, version: Literal[ "std", "mc-dropout", diff --git a/torch_uncertainty/baselines/classification/vgg.py b/torch_uncertainty/baselines/classification/vgg.py index 1942211c..5ef60bd3 100644 --- a/torch_uncertainty/baselines/classification/vgg.py +++ b/torch_uncertainty/baselines/classification/vgg.py @@ -36,7 +36,7 @@ def __init__( self, num_classes: int, in_channels: int, - loss: type[nn.Module], + loss: nn.Module, version: Literal["std", "mc-dropout", "packed"], arch: int, style: str = "imagenet", diff --git a/torch_uncertainty/baselines/classification/wideresnet.py b/torch_uncertainty/baselines/classification/wideresnet.py index 9c3e108b..b160c8ff 100644 --- a/torch_uncertainty/baselines/classification/wideresnet.py +++ b/torch_uncertainty/baselines/classification/wideresnet.py @@ -32,7 +32,7 @@ def __init__( self, num_classes: int, in_channels: int, - loss: type[nn.Module], + loss: nn.Module, version: Literal[ "std", "mc-dropout", "packed", "batched", "masked", "mimo" ], diff --git a/torch_uncertainty/baselines/regression/mlp.py b/torch_uncertainty/baselines/regression/mlp.py index 5c523072..44bc787e 100644 --- a/torch_uncertainty/baselines/regression/mlp.py +++ b/torch_uncertainty/baselines/regression/mlp.py @@ -23,7 +23,7 @@ def __init__( self, output_dim: int, in_features: int, - loss: type[nn.Module], + loss: nn.Module, version: Literal["std", "packed"], hidden_dims: list[int], num_estimators: int | None = 1, diff --git a/torch_uncertainty/baselines/segmentation/segformer.py b/torch_uncertainty/baselines/segmentation/segformer.py index 07ae5043..55ed92a5 100644 --- a/torch_uncertainty/baselines/segmentation/segformer.py +++ b/torch_uncertainty/baselines/segmentation/segformer.py @@ -30,7 +30,7 @@ class SegFormerBaseline(SegmentationRoutine): def __init__( self, num_classes: int, - loss: type[nn.Module], + loss: nn.Module, version: Literal["std"], arch: int, num_estimators: int = 1, diff --git a/torch_uncertainty/losses.py b/torch_uncertainty/losses.py index eeb6bd32..8fd4b10f 100644 --- a/torch_uncertainty/losses.py +++ b/torch_uncertainty/losses.py @@ -66,7 +66,7 @@ def _kl_div(self) -> Tensor: class ELBOLoss(nn.Module): def __init__( self, - model: nn.Module, + model: nn.Module | None, criterion: nn.Module, kl_weight: float, num_samples: int, @@ -83,8 +83,9 @@ def __init__( num_samples (int): The number of samples to use for the ELBO loss """ super().__init__() - self.model = model - self._kl_div = KLDiv(model) + + if model is not None: + self.set_model(model) if isinstance(criterion, type): raise TypeError( @@ -129,6 +130,11 @@ def forward(self, inputs: Tensor, targets: Tensor) -> Tensor: aggregated_elbo += self.kl_weight * self._kl_div() return aggregated_elbo / self.num_samples + def set_model(self, model: nn.Module) -> None: + self.model = model + if model is not None: + self._kl_div = KLDiv(model) + class DERLoss(DistributionNLLLoss): def __init__( diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index c088ca40..2cfd6294 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -1,5 +1,4 @@ from collections.abc import Callable -from functools import partial from pathlib import Path from typing import Literal @@ -39,7 +38,7 @@ def __init__( self, num_classes: int, model: nn.Module, - loss: type[nn.Module], + loss: nn.Module, num_estimators: int = 1, format_batch_fn: nn.Module | None = None, optim_recipe=None, @@ -237,15 +236,10 @@ def __init__( prefix="gpl/test_" ) - # Handle ELBO special cases - self.is_elbo = ( - isinstance(self.loss, partial) and self.loss.func == ELBOLoss - ) - - # Deep Evidential Classification - self.is_dec = self.loss == DECLoss or ( - isinstance(self.loss, partial) and self.loss.func == DECLoss - ) + self.is_elbo = isinstance(self.loss, ELBOLoss) + if self.is_elbo: + self.loss.set_model(self.model) + self.is_dec = isinstance(self.loss, DECLoss) # metrics for ensembles only if self.num_estimators > 1: @@ -349,9 +343,7 @@ def on_test_start(self) -> None: @property def criterion(self) -> nn.Module: - if self.is_elbo: - self.loss = partial(self.loss, model=self.model) - return self.loss() + return self.loss def forward(self, inputs: Tensor, save_feats: bool = False) -> Tensor: """Forward pass of the model. @@ -394,18 +386,18 @@ def training_step( inputs, targets = self.format_batch_fn(batch) if self.is_elbo: - loss = self.criterion(inputs, targets) + loss = self.loss(inputs, targets) else: logits = self.forward(inputs) # BCEWithLogitsLoss expects float targets - if self.binary_cls and self.loss == nn.BCEWithLogitsLoss: + if self.binary_cls and isinstance(self.loss, nn.BCEWithLogitsLoss): logits = logits.squeeze(-1) targets = targets.float() if not self.is_dec: - loss = self.criterion(logits, targets) + loss = self.loss(logits, targets) else: - loss = self.criterion(logits, targets, self.current_epoch) + loss = self.loss(logits, targets, self.current_epoch) self.log("train_loss", loss) return loss diff --git a/torch_uncertainty/routines/regression.py b/torch_uncertainty/routines/regression.py index 6b11b555..657859c3 100644 --- a/torch_uncertainty/routines/regression.py +++ b/torch_uncertainty/routines/regression.py @@ -82,9 +82,7 @@ def __init__( raise ValueError(f"output_dim must be positive, got {output_dim}.") self.output_dim = output_dim - self.one_dim_regression = False - if output_dim == 1: - self.one_dim_regression = True + self.one_dim_regression = output_dim == 1 self.optim_recipe = optim_recipe @@ -127,11 +125,6 @@ def forward(self, inputs: Tensor) -> Tensor: pred = pred.squeeze(-1) return pred - @property - def criterion(self) -> nn.Module: - """The loss function of the routine.""" - return self.loss() - def training_step( self, batch: tuple[Tensor, Tensor], batch_idx: int ) -> STEP_OUTPUT: @@ -141,7 +134,7 @@ def training_step( if self.one_dim_regression: targets = targets.unsqueeze(-1) - loss = self.criterion(dists, targets) + loss = self.loss(dists, targets) self.log("train_loss", loss) return loss diff --git a/torch_uncertainty/routines/segmentation.py b/torch_uncertainty/routines/segmentation.py index 26e93ca8..2ed053d7 100644 --- a/torch_uncertainty/routines/segmentation.py +++ b/torch_uncertainty/routines/segmentation.py @@ -58,10 +58,6 @@ def __init__( def configure_optimizers(self): return self.optim_recipe(self.model) - @property - def criterion(self) -> nn.Module: - return self.loss() - def forward(self, img: Tensor) -> Tensor: return self.model(img) @@ -83,7 +79,7 @@ def training_step( logits = rearrange(logits, "b c h w -> (b h w) c") target = target.flatten() valid_mask = target != 255 - loss = self.criterion(logits[valid_mask], target[valid_mask]) + loss = self.loss(logits[valid_mask], target[valid_mask]) self.log("train_loss", loss) return loss From 00795649d16cff73294732bc15ffea91287b1425 Mon Sep 17 00:00:00 2001 From: Olivier Date: Thu, 21 Mar 2024 15:52:33 +0100 Subject: [PATCH 090/148] :white_check_mark: Ignore unreachable code in tests --- torch_uncertainty/baselines/classification/deep_ensembles.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_uncertainty/baselines/classification/deep_ensembles.py b/torch_uncertainty/baselines/classification/deep_ensembles.py index 2b1e3ae2..fd6bc6c1 100644 --- a/torch_uncertainty/baselines/classification/deep_ensembles.py +++ b/torch_uncertainty/baselines/classification/deep_ensembles.py @@ -59,4 +59,4 @@ def __init__( log_plots=log_plots, calibration_set=calibration_set, ) - self.save_hyperparameters() + self.save_hyperparameters() # coverage: ignore From 16e37ba4a4cf3eaf3b3063ad7b2f0a449dfb8ea9 Mon Sep 17 00:00:00 2001 From: alafage Date: Thu, 21 Mar 2024 16:32:27 +0100 Subject: [PATCH 091/148] :white_check_mark: Improve DataModule coverage - Add create_train_val_split() --- .../classification/test_mnist_datamodule.py | 3 +-- tests/datamodules/segmentation/test_camvid.py | 6 ++++++ .../segmentation/test_cityscapes.py | 6 ++++++ tests/test_utils.py | 12 ++++++++++- tests/transforms/test_image.py | 20 +++++++++++++++++++ .../datamodules/classification/cifar10.py | 16 ++++++--------- .../datamodules/classification/cifar100.py | 15 +++++--------- .../datamodules/classification/imagenet.py | 14 +++++-------- .../datamodules/classification/mnist.py | 15 +++++--------- .../classification/tiny_imagenet.py | 15 +++++--------- .../datamodules/segmentation/cityscapes.py | 14 ++++--------- .../datamodules/segmentation/muad.py | 14 ++++--------- torch_uncertainty/metrics/mean_iou.py | 18 ----------------- torch_uncertainty/utils/__init__.py | 2 +- torch_uncertainty/utils/misc.py | 16 +++++++++++++-- 15 files changed, 93 insertions(+), 93 deletions(-) diff --git a/tests/datamodules/classification/test_mnist_datamodule.py b/tests/datamodules/classification/test_mnist_datamodule.py index 234eb613..0da67431 100644 --- a/tests/datamodules/classification/test_mnist_datamodule.py +++ b/tests/datamodules/classification/test_mnist_datamodule.py @@ -31,7 +31,6 @@ def test_mnist_cutout(self): dm.prepare_data() dm.setup() - dm.setup("test") dm.train_dataloader() dm.val_dataloader() @@ -43,5 +42,5 @@ def test_mnist_cutout(self): dm.eval_ood = True dm.val_split = 0.1 dm.prepare_data() - dm.setup("test") + dm.setup() dm.test_dataloader() diff --git a/tests/datamodules/segmentation/test_camvid.py b/tests/datamodules/segmentation/test_camvid.py index 9ccf4d0c..23b1d4d3 100644 --- a/tests/datamodules/segmentation/test_camvid.py +++ b/tests/datamodules/segmentation/test_camvid.py @@ -29,3 +29,9 @@ def test_camvid_main(self): dm.train_dataloader() dm.val_dataloader() dm.test_dataloader() + + dm.val_split = 0.1 + dm.prepare_data() + dm.setup() + dm.train_dataloader() + dm.val_dataloader() diff --git a/tests/datamodules/segmentation/test_cityscapes.py b/tests/datamodules/segmentation/test_cityscapes.py index 3dcd74ec..25b0bbd1 100644 --- a/tests/datamodules/segmentation/test_cityscapes.py +++ b/tests/datamodules/segmentation/test_cityscapes.py @@ -29,3 +29,9 @@ def test_camvid_main(self): dm.train_dataloader() dm.val_dataloader() dm.test_dataloader() + + dm.val_split = 0.1 + dm.prepare_data() + dm.setup() + dm.train_dataloader() + dm.val_dataloader() diff --git a/tests/test_utils.py b/tests/test_utils.py index bebafb22..fa89df36 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -4,7 +4,13 @@ import torch from torch.distributions import Laplace, Normal -from torch_uncertainty.utils import csv_writer, distributions, get_version, hub +from torch_uncertainty.utils import ( + csv_writer, + distributions, + get_version, + hub, + plot_hist, +) class TestUtils: @@ -47,6 +53,10 @@ def test_csv_writer(self): root / "testlog" / "results.csv", {"a": 1.0, "b": 2.0, "c": 3.0} ) + def test_plot_hist(self): + conf = [torch.rand(20), torch.rand(20)] + plot_hist(conf, bins=10, title="test") + class TestDistributions: """Testing distributions methods.""" diff --git a/tests/transforms/test_image.py b/tests/transforms/test_image.py index 56eeb560..872707d9 100644 --- a/tests/transforms/test_image.py +++ b/tests/transforms/test_image.py @@ -2,6 +2,7 @@ import pytest import torch from PIL import Image +from torchvision import tv_tensors from torch_uncertainty.transforms import ( AutoContrast, @@ -11,6 +12,7 @@ Equalize, MIMOBatchFormat, Posterize, + RandomRescale, RepeatTarget, Rotate, Sharpen, @@ -26,6 +28,16 @@ def img_input() -> torch.Tensor: return Image.fromarray(imarray.astype("uint8")).convert("RGB") +@pytest.fixture() +def tv_tensors_input() -> tuple[torch.Tensor, torch.Tensor]: + imarray1 = np.random.rand(3, 28, 28) * 255 + imarray2 = np.random.rand(1, 28, 28) * 255 + return ( + tv_tensors.Image(imarray1.astype("uint8")), + tv_tensors.Mask(imarray2.astype("uint8")), + ) + + @pytest.fixture() def batch_input() -> tuple[torch.Tensor, torch.Tensor]: imgs = torch.rand(2, 3, 28, 28) @@ -210,3 +222,11 @@ def test_failures(self): with pytest.raises(ValueError): _ = MIMOBatchFormat(1, 0, 0) + + +class TestRandomRescale: + """Testing the RandomRescale transform.""" + + def test_tv_tensors(self, tv_tensors_input): + aug = RandomRescale(0.5, 2.0) + _ = aug(tv_tensors_input) diff --git a/torch_uncertainty/datamodules/classification/cifar10.py b/torch_uncertainty/datamodules/classification/cifar10.py index 90afed91..45452115 100644 --- a/torch_uncertainty/datamodules/classification/cifar10.py +++ b/torch_uncertainty/datamodules/classification/cifar10.py @@ -1,4 +1,3 @@ -import copy from pathlib import Path from typing import Literal @@ -7,13 +6,14 @@ from numpy.typing import ArrayLike from timm.data.auto_augment import rand_augment_transform from torch import nn -from torch.utils.data import DataLoader, random_split +from torch.utils.data import DataLoader from torchvision.datasets import CIFAR10, SVHN from torch_uncertainty.datamodules.abstract import AbstractDataModule from torch_uncertainty.datasets import AggregatedDataset from torch_uncertainty.datasets.classification import CIFAR10C, CIFAR10H from torch_uncertainty.transforms import Cutout +from torch_uncertainty.utils import create_train_val_split class CIFAR10DataModule(AbstractDataModule): @@ -149,16 +149,12 @@ def setup(self, stage: Literal["fit", "test"] | None = None) -> None: transform=self.train_transform, ) if self.val_split: - self.train, val = random_split( + self.train, self.val = create_train_val_split( full, - [ - 1 - self.val_split, - self.val_split, - ], + self.val_split, + self.test_transform, ) - # FIXME: memory cost issues might arise here - self.val = copy.deepcopy(val) - self.val.dataset.transform = self.test_transform + else: self.train = full self.val = self.dataset( diff --git a/torch_uncertainty/datamodules/classification/cifar100.py b/torch_uncertainty/datamodules/classification/cifar100.py index b2ea294b..bc5a3691 100644 --- a/torch_uncertainty/datamodules/classification/cifar100.py +++ b/torch_uncertainty/datamodules/classification/cifar100.py @@ -1,4 +1,3 @@ -import copy from pathlib import Path from typing import Literal @@ -8,13 +7,14 @@ from numpy.typing import ArrayLike from timm.data.auto_augment import rand_augment_transform from torch import nn -from torch.utils.data import DataLoader, random_split +from torch.utils.data import DataLoader from torchvision.datasets import CIFAR100, SVHN from torch_uncertainty.datamodules.abstract import AbstractDataModule from torch_uncertainty.datasets import AggregatedDataset from torch_uncertainty.datasets.classification import CIFAR100C from torch_uncertainty.transforms import Cutout +from torch_uncertainty.utils import create_train_val_split class CIFAR100DataModule(AbstractDataModule): @@ -149,16 +149,11 @@ def setup(self, stage: Literal["fit", "test"] | None = None) -> None: transform=self.train_transform, ) if self.val_split: - self.train, val = random_split( + self.train, self.val = create_train_val_split( full, - [ - 1 - self.val_split, - self.val_split, - ], + self.val_split, + self.test_transform, ) - # FIXME: memory cost issues might arise here - self.val = copy.deepcopy(val) - self.val.dataset.transform = self.test_transform else: self.train = full self.val = self.dataset( diff --git a/torch_uncertainty/datamodules/classification/imagenet.py b/torch_uncertainty/datamodules/classification/imagenet.py index aaebeccb..8f89a23b 100644 --- a/torch_uncertainty/datamodules/classification/imagenet.py +++ b/torch_uncertainty/datamodules/classification/imagenet.py @@ -7,7 +7,7 @@ from timm.data.auto_augment import rand_augment_transform from timm.data.mixup import Mixup from torch import nn -from torch.utils.data import DataLoader, Subset, random_split +from torch.utils.data import DataLoader, Subset from torchvision.datasets import DTD, SVHN, ImageNet, INaturalist from torch_uncertainty.datamodules.abstract import AbstractDataModule @@ -17,6 +17,7 @@ ImageNetR, OpenImageO, ) +from torch_uncertainty.utils.misc import create_train_val_split class ImageNetDataModule(AbstractDataModule): @@ -202,16 +203,11 @@ def setup(self, stage: Literal["fit", "test"] | None = None) -> None: transform=self.train_transform, ) if self.val_split and isinstance(self.val_split, float): - self.train, val = random_split( + self.train, self.val = create_train_val_split( full, - [ - 1 - self.val_split, - self.val_split, - ], + self.val_split, + self.test_transform, ) - # FIXME: memory cost issues might arise here - self.val = copy.deepcopy(val) - self.val.dataset.transform = self.test_transform elif isinstance(self.val_split, Path): self.train = Subset(full, self.train_indices) # TODO: improve the performance diff --git a/torch_uncertainty/datamodules/classification/mnist.py b/torch_uncertainty/datamodules/classification/mnist.py index 0d9fe147..77a6f4f5 100644 --- a/torch_uncertainty/datamodules/classification/mnist.py +++ b/torch_uncertainty/datamodules/classification/mnist.py @@ -1,15 +1,15 @@ -import copy from pathlib import Path from typing import Literal import torchvision.transforms as T from torch import nn -from torch.utils.data import DataLoader, random_split +from torch.utils.data import DataLoader from torchvision.datasets import MNIST, FashionMNIST from torch_uncertainty.datamodules.abstract import AbstractDataModule from torch_uncertainty.datasets.classification import MNISTC, NotMNIST from torch_uncertainty.transforms import Cutout +from torch_uncertainty.utils import create_train_val_split class MNISTDataModule(AbstractDataModule): @@ -113,16 +113,11 @@ def setup(self, stage: Literal["fit", "test"] | None = None) -> None: transform=self.train_transform, ) if self.val_split: - self.train, val = random_split( + self.train, self.val = create_train_val_split( full, - [ - 1 - self.val_split, - self.val_split, - ], + self.val_split, + self.test_transform, ) - # FIXME: memory cost issues might arise here - self.val = copy.deepcopy(val) - self.val.dataset.transform = self.test_transform else: self.train = full self.val = self.dataset( diff --git a/torch_uncertainty/datamodules/classification/tiny_imagenet.py b/torch_uncertainty/datamodules/classification/tiny_imagenet.py index 71334368..5430264d 100644 --- a/torch_uncertainty/datamodules/classification/tiny_imagenet.py +++ b/torch_uncertainty/datamodules/classification/tiny_imagenet.py @@ -1,4 +1,3 @@ -import copy from pathlib import Path from typing import Literal @@ -7,11 +6,12 @@ from numpy.typing import ArrayLike from timm.data.auto_augment import rand_augment_transform from torch import nn -from torch.utils.data import ConcatDataset, DataLoader, random_split +from torch.utils.data import ConcatDataset, DataLoader from torchvision.datasets import DTD, SVHN from torch_uncertainty.datamodules.abstract import AbstractDataModule from torch_uncertainty.datasets.classification import ImageNetO, TinyImageNet +from torch_uncertainty.utils import create_train_val_split class TinyImageNetDataModule(AbstractDataModule): @@ -128,16 +128,11 @@ def setup(self, stage: Literal["fit", "test"] | None = None) -> None: transform=self.train_transform, ) if self.val_split: - self.train, val = random_split( + self.train, self.val = create_train_val_split( full, - [ - 1 - self.val_split, - self.val_split, - ], + self.val_split, + self.test_transform, ) - # FIXME: memory cost issues might arise here - self.val = copy.deepcopy(val) - self.val.dataset.transform = self.test_transform else: self.train = full self.val = self.dataset( diff --git a/torch_uncertainty/datamodules/segmentation/cityscapes.py b/torch_uncertainty/datamodules/segmentation/cityscapes.py index 6a937827..57e0db0d 100644 --- a/torch_uncertainty/datamodules/segmentation/cityscapes.py +++ b/torch_uncertainty/datamodules/segmentation/cityscapes.py @@ -1,16 +1,15 @@ -import copy from pathlib import Path import torch from torch.nn.common_types import _size_2_t from torch.nn.modules.utils import _pair -from torch.utils.data import random_split from torchvision import tv_tensors from torchvision.transforms import v2 from torch_uncertainty.datamodules.abstract import AbstractDataModule from torch_uncertainty.datasets.segmentation import Cityscapes from torch_uncertainty.transforms import RandomRescale +from torch_uncertainty.utils.misc import create_train_val_split class CityscapesDataModule(AbstractDataModule): @@ -91,16 +90,11 @@ def setup(self, stage: str | None = None) -> None: ) if self.val_split is not None: - self.train, val = random_split( + self.train, self.val = create_train_val_split( full, - [ - 1 - self.val_split, - self.val_split, - ], + self.val_split, + self.test_transform, ) - # FIXME: memory cost issues might arise here - self.val = copy.deepcopy(val) - self.val.dataset.transforms = self.test_transform else: self.train = full self.val = self.dataset( diff --git a/torch_uncertainty/datamodules/segmentation/muad.py b/torch_uncertainty/datamodules/segmentation/muad.py index 0b567a1e..acf0535c 100644 --- a/torch_uncertainty/datamodules/segmentation/muad.py +++ b/torch_uncertainty/datamodules/segmentation/muad.py @@ -1,16 +1,15 @@ -import copy from pathlib import Path import torch from torch.nn.common_types import _size_2_t from torch.nn.modules.utils import _pair -from torch.utils.data import random_split from torchvision import tv_tensors from torchvision.transforms import v2 from torch_uncertainty.datamodules.abstract import AbstractDataModule from torch_uncertainty.datasets import MUAD from torch_uncertainty.transforms import RandomRescale +from torch_uncertainty.utils.misc import create_train_val_split class MUADDataModule(AbstractDataModule): @@ -94,16 +93,11 @@ def setup(self, stage: str | None = None) -> None: ) if self.val_split is not None: - self.train, val = random_split( + self.train, self.val = create_train_val_split( full, - [ - 1 - self.val_split, - self.val_split, - ], + self.val_split, + self.test_transform, ) - # FIXME: memory cost issues might arise here - self.val = copy.deepcopy(val) - self.val.dataset.transforms = self.test_transform else: self.train = full self.val = self.dataset( diff --git a/torch_uncertainty/metrics/mean_iou.py b/torch_uncertainty/metrics/mean_iou.py index 1a3dec82..95c5b8a0 100644 --- a/torch_uncertainty/metrics/mean_iou.py +++ b/torch_uncertainty/metrics/mean_iou.py @@ -1,4 +1,3 @@ -from einops import rearrange from torch import Tensor from torchmetrics.classification.stat_scores import MulticlassStatScores from torchmetrics.utilities.compute import _safe_divide @@ -11,23 +10,6 @@ class MeanIntersectionOverUnion(MulticlassStatScores): higher_is_better: bool = True full_state_update: bool = False - def update(self, preds: Tensor, target: Tensor) -> None: - """Update state with predictions and targets. - - Args: - preds (Tensor): prediction images of shape :math:`(B, H, W)` or - :math:`(B, C, H, W)`. - target (Tensor): target images of shape :math:`(B, H, W)`. - """ - if preds.ndim == 3: - preds = preds.flatten() - if preds.ndim == 4: - preds = rearrange(preds, "b c h w -> (b h w) c") - - target = target.flatten() - - super().update(preds, target) - def compute(self) -> Tensor: """Compute the Means Intersection over Union (MIoU) based on saved inputs.""" tp, fp, _, fn = self._final_state() diff --git a/torch_uncertainty/utils/__init__.py b/torch_uncertainty/utils/__init__.py index b5d0b1d4..ad7a12db 100644 --- a/torch_uncertainty/utils/__init__.py +++ b/torch_uncertainty/utils/__init__.py @@ -2,4 +2,4 @@ from .checkpoints import get_version from .cli import TULightningCLI from .hub import load_hf -from .misc import csv_writer, plot_hist +from .misc import create_train_val_split, csv_writer, plot_hist diff --git a/torch_uncertainty/utils/misc.py b/torch_uncertainty/utils/misc.py index 5b87647a..9a134ccf 100644 --- a/torch_uncertainty/utils/misc.py +++ b/torch_uncertainty/utils/misc.py @@ -1,10 +1,13 @@ +import copy import csv +from collections.abc import Callable from pathlib import Path import matplotlib.pyplot as plt import torch from matplotlib.axes import Axes from matplotlib.figure import Figure +from torch.utils.data import Dataset, random_split def csv_writer(path: Path, dic: dict) -> None: @@ -20,8 +23,6 @@ def csv_writer(path: Path, dic: dict) -> None: else: append_mode = False rw_mode = "w" - print(f"Writing to {path}") - print(f"Append mode: {append_mode}") # Write dic with path.open(rw_mode) as csvfile: writer = csv.writer(csvfile, delimiter=",") @@ -68,3 +69,14 @@ def plot_hist( plt.legend() fig.tight_layout() return fig, ax + + +def create_train_val_split( + dataset: Dataset, + val_split_rate: float, + val_transforms: Callable | None = None, +): + train, val = random_split(dataset, [1 - val_split_rate, val_split_rate]) + val = copy.deepcopy(val) + val.dataset.transform = val_transforms + return train, val From e6d8baf8068660489d382b63ac6074005067ab4e Mon Sep 17 00:00:00 2001 From: alafage Date: Thu, 21 Mar 2024 17:00:04 +0100 Subject: [PATCH 092/148] :white_check_mark: Improve cli coverage hopefully --- tests/test_cli.py | 6 +++++- torch_uncertainty/utils/cli.py | 1 + 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/test_cli.py b/tests/test_cli.py index 8da93a96..f696fcbf 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -26,8 +26,12 @@ def test_cli_init(self): "./data", "--data.batch_size", "4", + "--trainer.callbacks+=ModelCheckpoint", + "--trainer.callbacks.monitor=cls_val/acc", + "--trainer.callbacks.mode=max", ] cli = TULightningCLI(ResNetBaseline, CIFAR10DataModule, run=False) assert cli.eval_after_fit_default is False assert cli.save_config_callback == TUSaveConfigCallback - cli.instantiate_trainer() + assert isinstance(cli.trainer.callbacks[0], TUSaveConfigCallback) + cli.trainer.callbacks[0].setup(cli.trainer, cli.model, stage="fit") diff --git a/torch_uncertainty/utils/cli.py b/torch_uncertainty/utils/cli.py index e04a5ec0..e2b50156 100644 --- a/torch_uncertainty/utils/cli.py +++ b/torch_uncertainty/utils/cli.py @@ -117,6 +117,7 @@ def __init__( run, auto_configure_optimizers, ) + print("IN TU CLI INIT") def add_default_arguments_to_parser( self, parser: LightningArgumentParser From 1bc4f765d668eabd59a3229ebcbd3aa43d3ec9ed Mon Sep 17 00:00:00 2001 From: alafage Date: Thu, 21 Mar 2024 17:08:32 +0100 Subject: [PATCH 093/148] :white_check_mark: Slightly improves cli coverage --- tests/test_cli.py | 2 ++ torch_uncertainty/utils/cli.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_cli.py b/tests/test_cli.py index f696fcbf..edce26d4 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -35,3 +35,5 @@ def test_cli_init(self): assert cli.save_config_callback == TUSaveConfigCallback assert isinstance(cli.trainer.callbacks[0], TUSaveConfigCallback) cli.trainer.callbacks[0].setup(cli.trainer, cli.model, stage="fit") + cli.trainer.callbacks[0].already_saved = True + cli.trainer.callbacks[0].setup(cli.trainer, cli.model, stage="fit") diff --git a/torch_uncertainty/utils/cli.py b/torch_uncertainty/utils/cli.py index e2b50156..d1871b91 100644 --- a/torch_uncertainty/utils/cli.py +++ b/torch_uncertainty/utils/cli.py @@ -34,7 +34,7 @@ def setup( ) # broadcast whether to fail to all ranks file_exists = trainer.strategy.broadcast(file_exists) - if file_exists: + if file_exists: # coverage: ignore raise RuntimeError( f"{self.__class__.__name__} expected {config_path} to NOT exist. Aborting to avoid overwriting" " results of a previous run. You can delete the previous config file," From 5d77f3de30a9f927dab8a341b1b18eaa7f9bea4a Mon Sep 17 00:00:00 2001 From: Olivier Date: Thu, 21 Mar 2024 16:24:35 +0100 Subject: [PATCH 094/148] :wrench: Try fixing codecov issues --- codecov.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/codecov.yml b/codecov.yml index 3c1d530a..c5a92954 100644 --- a/codecov.yml +++ b/codecov.yml @@ -6,3 +6,6 @@ coverage: patch: default: target: 95% + +codecov: + disable_default_path_fixes: true From 1d2dc94976ee0c2f3ac242ba177a093477bbdbed Mon Sep 17 00:00:00 2001 From: Olivier Date: Thu, 21 Mar 2024 16:38:54 +0100 Subject: [PATCH 095/148] :white_check_mark: Complete losses cov. --- tests/test_losses.py | 14 ++++++-------- torch_uncertainty/losses.py | 3 +-- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/tests/test_losses.py b/tests/test_losses.py index ff671938..d88ba2a1 100644 --- a/tests/test_losses.py +++ b/tests/test_losses.py @@ -31,9 +31,14 @@ class TestELBOLoss: def test_main(self): model = BayesLinear(1, 1) criterion = nn.BCEWithLogitsLoss() - loss = ELBOLoss(model, criterion, kl_weight=1e-5, num_samples=1) + loss(model(torch.randn(1, 1)), torch.randn(1, 1)) + + model = nn.Linear(1, 1) + criterion = nn.BCEWithLogitsLoss() + ELBOLoss(None, criterion, kl_weight=1e-5, num_samples=1) + loss = ELBOLoss(model, criterion, kl_weight=1e-5, num_samples=1) loss(model(torch.randn(1, 1)), torch.randn(1, 1)) def test_failures(self): @@ -52,13 +57,6 @@ def test_failures(self): with pytest.raises(TypeError): ELBOLoss(model, criterion, kl_weight=1e-5, num_samples=1.5) - def test_no_bayes(self): - model = nn.Linear(1, 1) - criterion = nn.BCEWithLogitsLoss() - - loss = ELBOLoss(model, criterion, kl_weight=1e-5, num_samples=1) - loss(model(torch.randn(1, 1)), torch.randn(1, 1)) - class TestNIGLoss: """Testing the DERLoss class.""" diff --git a/torch_uncertainty/losses.py b/torch_uncertainty/losses.py index 8fd4b10f..bd7319c9 100644 --- a/torch_uncertainty/losses.py +++ b/torch_uncertainty/losses.py @@ -84,8 +84,7 @@ def __init__( """ super().__init__() - if model is not None: - self.set_model(model) + self.set_model(model) if isinstance(criterion, type): raise TypeError( From fe4903e3f8c8a4be65ba4cc07730380102e79ab4 Mon Sep 17 00:00:00 2001 From: alafage Date: Fri, 22 Mar 2024 00:23:22 +0100 Subject: [PATCH 096/148] :bug: Fix TemperatureScaler in ClassificationRoutine - Improve ClassificationRoutine coverage --- tests/_dummies/baseline.py | 3 ++ tests/routines/test_classification.py | 31 ++++++++++++++++++- .../baselines/classification/resnet.py | 2 +- .../baselines/classification/vgg.py | 2 +- .../baselines/classification/wideresnet.py | 3 +- torch_uncertainty/baselines/regression/mlp.py | 2 +- .../baselines/segmentation/segformer.py | 2 +- torch_uncertainty/routines/classification.py | 17 +++++----- 8 files changed, 45 insertions(+), 17 deletions(-) diff --git a/tests/_dummies/baseline.py b/tests/_dummies/baseline.py index bee2b77c..07f9ce79 100644 --- a/tests/_dummies/baseline.py +++ b/tests/_dummies/baseline.py @@ -32,6 +32,7 @@ def __new__( ood_criterion: str = "msp", eval_ood: bool = False, eval_grouping_loss: bool = False, + calibrate: bool = False, ) -> LightningModule: model = dummy_model( in_channels=in_channels, @@ -52,6 +53,7 @@ def __new__( ood_criterion=ood_criterion, eval_ood=eval_ood, eval_grouping_loss=eval_grouping_loss, + calibration_set="val" if calibrate else None, ) # baseline_type == "ensemble": model = deep_ensembles( @@ -69,6 +71,7 @@ def __new__( ood_criterion=ood_criterion, eval_ood=eval_ood, eval_grouping_loss=eval_grouping_loss, + calibration_set="val" if calibrate else None, ) diff --git a/tests/routines/test_classification.py b/tests/routines/test_classification.py index d6c6349b..ecc9a27d 100644 --- a/tests/routines/test_classification.py +++ b/tests/routines/test_classification.py @@ -87,7 +87,34 @@ def test_one_estimator_two_classes(self): trainer.test(model, dm) model(dm.get_test_set()[0][0]) - def test_two_estimators_two_classes(self): + def test_one_estimator_two_classes_calibrated_with_ood(self): + trainer = Trainer(accelerator="cpu", fast_dev_run=True, logger=True) + + dm = DummyClassificationDataModule( + root=Path(), + batch_size=16, + num_classes=2, + num_images=100, + eval_ood=True, + ) + model = DummyClassificationBaseline( + num_classes=dm.num_classes, + in_channels=dm.num_channels, + loss=nn.CrossEntropyLoss(), + optim_recipe=optim_cifar10_resnet18, + baseline_type="single", + ood_criterion="entropy", + eval_ood=True, + # eval_grouping_loss=True, + calibrate=True, + ) + + trainer.fit(model, dm) + trainer.validate(model, dm) + trainer.test(model, dm) + model(dm.get_test_set()[0][0]) + + def test_two_estimators_two_classes_with_ood(self): trainer = Trainer(accelerator="cpu", fast_dev_run=True) dm = DummyClassificationDataModule( @@ -95,6 +122,7 @@ def test_two_estimators_two_classes(self): batch_size=16, num_classes=2, num_images=100, + eval_ood=True, ) model = DummyClassificationBaseline( num_classes=dm.num_classes, @@ -103,6 +131,7 @@ def test_two_estimators_two_classes(self): optim_recipe=optim_cifar10_resnet18, baseline_type="ensemble", ood_criterion="energy", + eval_ood=True, ) trainer.fit(model, dm) diff --git a/torch_uncertainty/baselines/classification/resnet.py b/torch_uncertainty/baselines/classification/resnet.py index 03871bc9..989bb9d8 100644 --- a/torch_uncertainty/baselines/classification/resnet.py +++ b/torch_uncertainty/baselines/classification/resnet.py @@ -302,4 +302,4 @@ def __init__( save_in_csv=save_in_csv, calibration_set=calibration_set, ) - self.save_hyperparameters() + self.save_hyperparameters(ignore=["loss"]) diff --git a/torch_uncertainty/baselines/classification/vgg.py b/torch_uncertainty/baselines/classification/vgg.py index 5ef60bd3..f4b3d573 100644 --- a/torch_uncertainty/baselines/classification/vgg.py +++ b/torch_uncertainty/baselines/classification/vgg.py @@ -200,4 +200,4 @@ def __init__( save_in_csv=save_in_csv, calibration_set=calibration_set, ) - self.save_hyperparameters() + self.save_hyperparameters(ignore=["loss"]) diff --git a/torch_uncertainty/baselines/classification/wideresnet.py b/torch_uncertainty/baselines/classification/wideresnet.py index b160c8ff..0f00b855 100644 --- a/torch_uncertainty/baselines/classification/wideresnet.py +++ b/torch_uncertainty/baselines/classification/wideresnet.py @@ -214,5 +214,4 @@ def __init__( save_in_csv=save_in_csv, calibration_set=calibration_set, ) - - self.save_hyperparameters() + self.save_hyperparameters(ignore=["loss"]) diff --git a/torch_uncertainty/baselines/regression/mlp.py b/torch_uncertainty/baselines/regression/mlp.py index 44bc787e..93ec8b87 100644 --- a/torch_uncertainty/baselines/regression/mlp.py +++ b/torch_uncertainty/baselines/regression/mlp.py @@ -85,4 +85,4 @@ def __init__( num_estimators=num_estimators, format_batch_fn=format_batch_fn, ) - self.save_hyperparameters() + self.save_hyperparameters(ignore=["loss"]) diff --git a/torch_uncertainty/baselines/segmentation/segformer.py b/torch_uncertainty/baselines/segmentation/segformer.py index 55ed92a5..49be67d7 100644 --- a/torch_uncertainty/baselines/segmentation/segformer.py +++ b/torch_uncertainty/baselines/segmentation/segformer.py @@ -76,4 +76,4 @@ def __init__( num_estimators=num_estimators, format_batch_fn=format_batch_fn, ) - self.save_hyperparameters() + self.save_hyperparameters(ignore=["loss"]) diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 2cfd6294..17ca1ae5 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -6,7 +6,7 @@ import torch.nn.functional as F from einops import rearrange from lightning.pytorch import LightningModule -from lightning.pytorch.loggers import TensorBoardLogger +from lightning.pytorch.loggers import Logger from lightning.pytorch.utilities.types import STEP_OUTPUT from timm.data import Mixup as timm_Mixup from torch import Tensor, nn @@ -325,19 +325,16 @@ def on_test_start(self) -> None: if self.calibration_set == "val" else self.trainer.datamodule.test_dataloader().dataset ) - self.scaler = TemperatureScaler(device=self.device).fit( - model=self.model, calibration_set=dataset - ) + with torch.inference_mode(False): + self.scaler = TemperatureScaler( + model=self.model, device=self.device + ).fit(calibration_set=dataset) self.cal_model = torch.nn.Sequential(self.model, self.scaler) else: self.scaler = None self.cal_model = None - if ( - self.eval_ood - and self.log_plots - and isinstance(self.logger, TensorBoardLogger) - ): + if self.eval_ood and self.log_plots and isinstance(self.logger, Logger): self.id_logit_storage = [] self.ood_logit_storage = [] @@ -574,7 +571,7 @@ def on_test_epoch_end(self) -> None: result_dict.update(tmp_metrics) self.test_ood_ens_metrics.reset() - if isinstance(self.logger, TensorBoardLogger) and self.log_plots: + if isinstance(self.logger, Logger) and self.log_plots: self.logger.experiment.add_figure( "Calibration Plot", self.test_cls_metrics["ece"].plot()[0] ) From d4dbc40978b0d1745c76741383d15ee3a2297ec0 Mon Sep 17 00:00:00 2001 From: alafage Date: Fri, 22 Mar 2024 00:48:00 +0100 Subject: [PATCH 097/148] :bug: Fix not setted up save_feats in validation_step ClassificationRoutine --- tests/_dummies/model.py | 5 ++++- torch_uncertainty/routines/classification.py | 4 +++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/_dummies/model.py b/tests/_dummies/model.py index 17682920..e0d3232c 100644 --- a/tests/_dummies/model.py +++ b/tests/_dummies/model.py @@ -46,7 +46,10 @@ def forward(self, x: Tensor) -> Tensor: class _DummyWithFeats(_Dummy): def feats_forward(self, x: Tensor) -> Tensor: - return self.forward(x) + return torch.ones( + (x.shape[0], 1), + dtype=torch.float32, + ) class _DummySegmentation(nn.Module): diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 17ca1ae5..07214155 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -403,7 +403,9 @@ def validation_step( self, batch: tuple[Tensor, Tensor], batch_idx: int ) -> None: inputs, targets = batch - logits = self.forward(inputs) # (m*b, c) + logits = self.forward( + inputs, save_feats=self.eval_grouping_loss + ) # (m*b, c) logits = rearrange(logits, "(m b) c -> b m c", m=self.num_estimators) if self.binary_cls: From f56e69f806688c867b2343f0bca097367a85f707 Mon Sep 17 00:00:00 2001 From: alafage Date: Fri, 22 Mar 2024 00:52:38 +0100 Subject: [PATCH 098/148] :white_check_mark: Improve MUADDataModule coverage --- tests/datamodules/segmentation/test_muad.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/datamodules/segmentation/test_muad.py b/tests/datamodules/segmentation/test_muad.py index 314b172c..862206f0 100644 --- a/tests/datamodules/segmentation/test_muad.py +++ b/tests/datamodules/segmentation/test_muad.py @@ -29,3 +29,9 @@ def test_camvid_main(self): dm.train_dataloader() dm.val_dataloader() dm.test_dataloader() + + dm.val_split = 0.1 + dm.prepare_data() + dm.setup() + dm.train_dataloader() + dm.val_dataloader() From f235438f040f9f6140ffdc5f85e678f3752768a2 Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 22 Mar 2024 10:46:25 +0100 Subject: [PATCH 099/148] :hammer: Rework routines & adapt tests and tutorials --- auto_tutorials_source/tutorial_bayesian.py | 2 +- auto_tutorials_source/tutorial_der_cubic.py | 2 +- .../tutorial_evidential_classification.py | 2 +- .../tutorial_mc_batch_norm.py | 2 +- auto_tutorials_source/tutorial_mc_dropout.py | 2 +- tests/_dummies/baseline.py | 12 +- tests/baselines/test_batched.py | 6 - tests/baselines/test_masked.py | 9 +- tests/baselines/test_mc_dropout.py | 6 - tests/baselines/test_mimo.py | 6 - tests/baselines/test_packed.py | 4 - tests/baselines/test_standard.py | 5 - tests/routines/test_classification.py | 42 ++++-- tests/routines/test_segmentation.py | 9 +- torch_uncertainty/routines/classification.py | 137 ++++++++++-------- torch_uncertainty/routines/regression.py | 52 +++---- torch_uncertainty/routines/segmentation.py | 57 +++++--- 17 files changed, 191 insertions(+), 164 deletions(-) diff --git a/auto_tutorials_source/tutorial_bayesian.py b/auto_tutorials_source/tutorial_bayesian.py index 28f523d6..38b3689a 100644 --- a/auto_tutorials_source/tutorial_bayesian.py +++ b/auto_tutorials_source/tutorial_bayesian.py @@ -104,7 +104,7 @@ def optim_lenet(model: nn.Module) -> dict: model=model, num_classes=datamodule.num_classes, loss=loss, - optim_recipe=optim_lenet, + optim_recipe=optim_lenet(model), ) # %% diff --git a/auto_tutorials_source/tutorial_der_cubic.py b/auto_tutorials_source/tutorial_der_cubic.py index 941ef998..f293a65b 100644 --- a/auto_tutorials_source/tutorial_der_cubic.py +++ b/auto_tutorials_source/tutorial_der_cubic.py @@ -113,7 +113,7 @@ def optim_regression( output_dim=1, model=model, loss=loss, - optim_recipe=optim_regression, + optim_recipe=optim_regression(model), ) # %% diff --git a/auto_tutorials_source/tutorial_evidential_classification.py b/auto_tutorials_source/tutorial_evidential_classification.py index d1aa9a3f..cef81083 100644 --- a/auto_tutorials_source/tutorial_evidential_classification.py +++ b/auto_tutorials_source/tutorial_evidential_classification.py @@ -87,7 +87,7 @@ def optim_lenet(model: nn.Module) -> dict: model=model, num_classes=datamodule.num_classes, loss=loss, - optim_recipe=optim_lenet, + optim_recipe=optim_lenet(model), ) # %% diff --git a/auto_tutorials_source/tutorial_mc_batch_norm.py b/auto_tutorials_source/tutorial_mc_batch_norm.py index 273c7ba8..0da3b4b2 100644 --- a/auto_tutorials_source/tutorial_mc_batch_norm.py +++ b/auto_tutorials_source/tutorial_mc_batch_norm.py @@ -66,7 +66,7 @@ num_classes=datamodule.num_classes, model=model, loss=nn.CrossEntropyLoss(), - optim_recipe=optim_cifar10_resnet18, + optim_recipe=optim_cifar10_resnet18(model), ) # %% diff --git a/auto_tutorials_source/tutorial_mc_dropout.py b/auto_tutorials_source/tutorial_mc_dropout.py index 75f8aefb..7600dec4 100644 --- a/auto_tutorials_source/tutorial_mc_dropout.py +++ b/auto_tutorials_source/tutorial_mc_dropout.py @@ -82,7 +82,7 @@ num_classes=datamodule.num_classes, model=mc_model, loss=nn.CrossEntropyLoss(), - optim_recipe=optim_cifar10_resnet18, + optim_recipe=optim_cifar10_resnet18(mc_model), num_estimators=16, ) diff --git a/tests/_dummies/baseline.py b/tests/_dummies/baseline.py index 07f9ce79..c043db6e 100644 --- a/tests/_dummies/baseline.py +++ b/tests/_dummies/baseline.py @@ -48,7 +48,7 @@ def __new__( loss=loss, format_batch_fn=nn.Identity(), log_plots=True, - optim_recipe=optim_recipe, + optim_recipe=optim_recipe(model), num_estimators=1, ood_criterion=ood_criterion, eval_ood=eval_ood, @@ -64,7 +64,7 @@ def __new__( num_classes=num_classes, model=model, loss=loss, - optim_recipe=optim_recipe, + optim_recipe=optim_recipe(model), format_batch_fn=RepeatTarget(2), log_plots=True, num_estimators=2, @@ -112,7 +112,7 @@ def __new__( model=model, loss=loss, num_estimators=1, - optim_recipe=optim_recipe, + optim_recipe=optim_recipe(model), ) # baseline_type == "ensemble": model = deep_ensembles( @@ -126,7 +126,7 @@ def __new__( model=model, loss=loss, num_estimators=2, - optim_recipe=optim_recipe, + optim_recipe=optim_recipe(model), format_batch_fn=RepeatTarget(2), ) @@ -154,7 +154,7 @@ def __new__( loss=loss, format_batch_fn=None, num_estimators=1, - optim_recipe=optim_recipe, + optim_recipe=optim_recipe(model), ) # baseline_type == "ensemble": @@ -168,5 +168,5 @@ def __new__( loss=loss, format_batch_fn=RepeatTarget(2), num_estimators=2, - optim_recipe=optim_recipe, + optim_recipe=optim_recipe(model), ) diff --git a/tests/baselines/test_batched.py b/tests/baselines/test_batched.py index 36c9915e..ef208523 100644 --- a/tests/baselines/test_batched.py +++ b/tests/baselines/test_batched.py @@ -24,8 +24,6 @@ def test_batched_18(self): ) summary(net) - - _ = net.criterion _ = net(torch.rand(1, 3, 32, 32)) def test_batched_50(self): @@ -41,8 +39,6 @@ def test_batched_50(self): ) summary(net) - - _ = net.criterion _ = net(torch.rand(1, 3, 40, 40)) @@ -61,6 +57,4 @@ def test_batched(self): ) summary(net) - - _ = net.criterion _ = net(torch.rand(1, 3, 32, 32)) diff --git a/tests/baselines/test_masked.py b/tests/baselines/test_masked.py index 5cefa5f5..6989230a 100644 --- a/tests/baselines/test_masked.py +++ b/tests/baselines/test_masked.py @@ -26,8 +26,6 @@ def test_masked_18(self): ) summary(net) - - _ = net.criterion _ = net(torch.rand(1, 3, 32, 32)) def test_masked_50(self): @@ -44,11 +42,9 @@ def test_masked_50(self): ) summary(net) - - _ = net.criterion _ = net(torch.rand(1, 3, 40, 40)) - def test_masked_scale_lt_1(self): + def test_masked_errors(self): with pytest.raises(Exception): _ = ResNetBaseline( num_classes=10, @@ -62,7 +58,6 @@ def test_masked_scale_lt_1(self): groups=1, ) - def test_masked_groups_lt_1(self): with pytest.raises(Exception): _ = ResNetBaseline( num_classes=10, @@ -93,6 +88,4 @@ def test_masked(self): ) summary(net) - - _ = net.criterion _ = net(torch.rand(1, 3, 32, 32)) diff --git a/tests/baselines/test_mc_dropout.py b/tests/baselines/test_mc_dropout.py index 06006c2d..dca61c3c 100644 --- a/tests/baselines/test_mc_dropout.py +++ b/tests/baselines/test_mc_dropout.py @@ -25,8 +25,6 @@ def test_standard(self): groups=1, ) summary(net) - - _ = net.criterion net(torch.rand(1, 3, 32, 32)) @@ -45,8 +43,6 @@ def test_standard(self): groups=1, ) summary(net) - - _ = net.criterion net(torch.rand(1, 3, 32, 32)) @@ -66,8 +62,6 @@ def test_standard(self): last_layer_dropout=True, ) summary(net) - - _ = net.criterion net(torch.rand(1, 3, 32, 32)) net = VGGBaseline( diff --git a/tests/baselines/test_mimo.py b/tests/baselines/test_mimo.py index 5e191128..4b7e6231 100644 --- a/tests/baselines/test_mimo.py +++ b/tests/baselines/test_mimo.py @@ -32,8 +32,6 @@ def test_mimo_50(self): ).eval() summary(net) - - _ = net.criterion _ = net(torch.rand(1, 3, 32, 32)) def test_mimo_18(self): @@ -51,8 +49,6 @@ def test_mimo_18(self): ).eval() summary(net) - - _ = net.criterion _ = net(torch.rand(1, 3, 40, 40)) @@ -73,6 +69,4 @@ def test_mimo(self): ).eval() summary(net) - - _ = net.criterion _ = net(torch.rand(1, 3, 32, 32)) diff --git a/tests/baselines/test_packed.py b/tests/baselines/test_packed.py index ee772623..34ed4a6f 100644 --- a/tests/baselines/test_packed.py +++ b/tests/baselines/test_packed.py @@ -47,7 +47,6 @@ def test_packed_18(self): ) summary(net) - _ = net(torch.rand(1, 3, 40, 40)) def test_packed_exception(self): @@ -97,7 +96,6 @@ def test_packed(self): ) summary(net) - _ = net(torch.rand(1, 3, 32, 32)) @@ -118,7 +116,6 @@ def test_packed(self): ) summary(net) - _ = net(torch.rand(2, 3, 32, 32)) @@ -137,5 +134,4 @@ def test_packed(self): gamma=1, ) summary(net) - _ = net(torch.rand(1, 3)) diff --git a/tests/baselines/test_standard.py b/tests/baselines/test_standard.py index 0f8ccffa..77cb948f 100644 --- a/tests/baselines/test_standard.py +++ b/tests/baselines/test_standard.py @@ -26,7 +26,6 @@ def test_standard(self): groups=1, ) summary(net) - _ = net(torch.rand(1, 3, 32, 32)) def test_errors(self): @@ -55,7 +54,6 @@ def test_standard(self): groups=1, ) summary(net) - _ = net(torch.rand(1, 3, 32, 32)) def test_errors(self): @@ -83,7 +81,6 @@ def test_standard(self): groups=1, ) summary(net) - _ = net(torch.rand(1, 3, 32, 32)) def test_errors(self): @@ -110,7 +107,6 @@ def test_standard(self): hidden_dims=[1], ) summary(net) - _ = net(torch.rand(1, 3)) for distribution in ["normal", "laplace", "nig"]: @@ -145,7 +141,6 @@ def test_standard(self): arch=0, ) summary(net) - _ = net(torch.rand(1, 3, 32, 32)) def test_errors(self): diff --git a/tests/routines/test_classification.py b/tests/routines/test_classification.py index ecc9a27d..31acb368 100644 --- a/tests/routines/test_classification.py +++ b/tests/routines/test_classification.py @@ -142,35 +142,59 @@ def test_two_estimators_two_classes_with_ood(self): def test_classification_failures(self): # num_estimators with pytest.raises(ValueError): - ClassificationRoutine(10, nn.Module(), None, num_estimators=-1) + ClassificationRoutine( + num_classes=10, model=nn.Module(), loss=None, num_estimators=-1 + ) # num_classes with pytest.raises(ValueError): - ClassificationRoutine(0, nn.Module(), None) + ClassificationRoutine(num_classes=0, model=nn.Module(), loss=None) # single & MI with pytest.raises(ValueError): ClassificationRoutine( - 10, nn.Module(), None, num_estimators=1, ood_criterion="mi" + num_classes=10, + model=nn.Module(), + loss=None, + num_estimators=1, + ood_criterion="mi", ) with pytest.raises(ValueError): - ClassificationRoutine(10, nn.Module(), None, ood_criterion="other") + ClassificationRoutine( + num_classes=10, + model=nn.Module(), + loss=None, + ood_criterion="other", + ) with pytest.raises(ValueError): - ClassificationRoutine(10, nn.Module(), None, cutmix_alpha=-1) + ClassificationRoutine( + num_classes=10, model=nn.Module(), loss=None, cutmix_alpha=-1 + ) with pytest.raises(ValueError): ClassificationRoutine( - 10, nn.Module(), None, eval_grouping_loss=True + num_classes=10, + model=nn.Module(), + loss=None, + eval_grouping_loss=True, ) with pytest.raises(NotImplementedError): ClassificationRoutine( - 10, nn.Module(), None, 2, eval_grouping_loss=True + num_classes=10, + model=nn.Module(), + loss=None, + num_estimators=2, + eval_grouping_loss=True, ) model = dummy_model(1, 1, 0, with_feats=False, with_linear=True) with pytest.raises(ValueError): - ClassificationRoutine(10, model, None, eval_grouping_loss=True) + ClassificationRoutine( + num_classes=10, model=model, loss=None, eval_grouping_loss=True + ) model = dummy_model(1, 1, 0, with_feats=True, with_linear=False) with pytest.raises(ValueError): - ClassificationRoutine(10, model, None, eval_grouping_loss=True) + ClassificationRoutine( + num_classes=10, model=model, loss=None, eval_grouping_loss=True + ) diff --git a/tests/routines/test_segmentation.py b/tests/routines/test_segmentation.py index 8c7a715c..1801a7c1 100644 --- a/tests/routines/test_segmentation.py +++ b/tests/routines/test_segmentation.py @@ -56,7 +56,12 @@ def test_two_estimators_two_classes(self): def test_segmentation_failures(self): with pytest.raises(ValueError): SegmentationRoutine( - 2, nn.Identity(), nn.CrossEntropyLoss(), num_estimators=0 + model=nn.Identity(), + num_classes=2, + loss=nn.CrossEntropyLoss(), + num_estimators=0, ) with pytest.raises(ValueError): - SegmentationRoutine(1, nn.Identity(), nn.CrossEntropyLoss()) + SegmentationRoutine( + model=nn.Identity(), num_classes=1, loss=nn.CrossEntropyLoss() + ) diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 07214155..b2047295 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -10,6 +10,7 @@ from lightning.pytorch.utilities.types import STEP_OUTPUT from timm.data import Mixup as timm_Mixup from torch import Tensor, nn +from torch.optim import Optimizer from torchmetrics import Accuracy, MetricCollection from torchmetrics.classification import ( BinaryAUROC, @@ -36,12 +37,12 @@ class ClassificationRoutine(LightningModule): def __init__( self, - num_classes: int, model: nn.Module, + num_classes: int, loss: nn.Module, num_estimators: int = 1, format_batch_fn: nn.Module | None = None, - optim_recipe=None, + optim_recipe: dict | Optimizer | None = None, mixtype: str = "erm", mixmode: str = "elem", dist_sim: str = "emb", @@ -58,17 +59,18 @@ def __init__( save_in_csv: bool = False, calibration_set: Literal["val", "test"] | None = None, ) -> None: - """Classification routine. + """Classification routine for Lightning. Args: - num_classes (int): Number of classes. model (nn.Module): Model to train. - loss (type[nn.Module]): Loss function. + num_classes (int): Number of classes. + loss (type[nn.Module]): Loss function to optimize the :attr:`model`. num_estimators (int, optional): Number of estimators for the - ensemble. Defaults to 1. + ensemble. Defaults to 1 (single model). format_batch_fn (nn.Module, optional): Function to format the batch. Defaults to :class:`torch.nn.Identity()`. - optim_recipe (optional): Training recipe. Defaults to None. + optim_recipe (dict | Optimizer, optional): The optimizer and + optionally the scheduler to use. Defaults to ``None``. mixtype (str, optional): Mixup type. Defaults to ``"erm"``. mixmode (str, optional): Mixup mode. Defaults to ``"elem"``. dist_sim (str, optional): Distance similarity. Defaults to ``"emb"``. @@ -99,62 +101,23 @@ def __init__( ValueError: _description_ ValueError: _description_ ValueError: _description_ + + Warning: + You must define :attr:`optim_recipe` if you do not use + the CLI. """ super().__init__() + _classification_routine_checks( + model=model, + num_classes=num_classes, + num_estimators=num_estimators, + ood_criterion=ood_criterion, + eval_grouping_loss=eval_grouping_loss, + ) if format_batch_fn is None: format_batch_fn = nn.Identity() - if not isinstance(num_estimators, int) or num_estimators < 1: - raise ValueError( - "The number of estimators must be a positive integer >= 1." - f"Got {num_estimators}." - ) - - if ood_criterion not in [ - "msp", - "logit", - "energy", - "entropy", - "mi", - "vr", - ]: - raise ValueError( - "The OOD criterion must be one of 'msp', 'logit', 'energy', 'entropy'," - f" 'mi' or 'vr'. Got {ood_criterion}." - ) - - if num_estimators == 1 and ood_criterion in ["mi", "vr"]: - raise ValueError( - "You cannot use mutual information or variation ratio with a single" - " model." - ) - - if num_estimators != 1 and eval_grouping_loss: - raise NotImplementedError( - "Groupng loss for ensembles is not yet implemented. Raise an issue if needed." - ) - - if num_classes < 1: - raise ValueError( - "The number of classes must be a positive integer >= 1." - f"Got {num_classes}." - ) - - if eval_grouping_loss and not hasattr(model, "feats_forward"): - raise ValueError( - "Your model must have a `feats_forward` method to compute the " - "grouping loss." - ) - - if eval_grouping_loss and not ( - hasattr(model, "classification_head") or hasattr(model, "linear") - ): - raise ValueError( - "Your model must have a `classification_head` or `linear` " - "attribute to compute the grouping loss." - ) - self.num_classes = num_classes self.num_estimators = num_estimators self.eval_ood = eval_ood @@ -303,7 +266,7 @@ def init_mixup( return nn.Identity() def configure_optimizers(self): - return self.optim_recipe(self.model) + return self.optim_recipe def on_train_start(self) -> None: init_metrics = {k: 0 for k in self.val_cls_metrics} @@ -338,10 +301,6 @@ def on_test_start(self) -> None: self.id_logit_storage = [] self.ood_logit_storage = [] - @property - def criterion(self) -> nn.Module: - return self.loss - def forward(self, inputs: Tensor, save_feats: bool = False) -> Tensor: """Forward pass of the model. @@ -616,3 +575,57 @@ def save_results_to_csv(self, results: dict[str, float]) -> None: Path(self.logger.log_dir) / "results.csv", results, ) + + +def _classification_routine_checks( + model, num_classes, num_estimators, ood_criterion, eval_grouping_loss +): + if not isinstance(num_estimators, int) or num_estimators < 1: + raise ValueError( + "The number of estimators must be a positive integer >= 1." + f"Got {num_estimators}." + ) + + if ood_criterion not in [ + "msp", + "logit", + "energy", + "entropy", + "mi", + "vr", + ]: + raise ValueError( + "The OOD criterion must be one of 'msp', 'logit', 'energy', 'entropy'," + f" 'mi' or 'vr'. Got {ood_criterion}." + ) + + if num_estimators == 1 and ood_criterion in ["mi", "vr"]: + raise ValueError( + "You cannot use mutual information or variation ratio with a single" + " model." + ) + + if num_estimators != 1 and eval_grouping_loss: + raise NotImplementedError( + "Groupng loss for ensembles is not yet implemented. Raise an issue if needed." + ) + + if num_classes < 1: + raise ValueError( + "The number of classes must be a positive integer >= 1." + f"Got {num_classes}." + ) + + if eval_grouping_loss and not hasattr(model, "feats_forward"): + raise ValueError( + "Your model must have a `feats_forward` method to compute the " + "grouping loss." + ) + + if eval_grouping_loss and not ( + hasattr(model, "classification_head") or hasattr(model, "linear") + ): + raise ValueError( + "Your model must have a `classification_head` or `linear` " + "attribute to compute the grouping loss." + ) diff --git a/torch_uncertainty/routines/regression.py b/torch_uncertainty/routines/regression.py index 657859c3..1cefa4f7 100644 --- a/torch_uncertainty/routines/regression.py +++ b/torch_uncertainty/routines/regression.py @@ -8,6 +8,7 @@ Independent, MixtureSameFamily, ) +from torch.optim import Optimizer from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection from torch_uncertainty.metrics.nll import DistributionNLL @@ -17,28 +18,28 @@ class RegressionRoutine(LightningModule): def __init__( self, - probabilistic: bool, - output_dim: int, model: nn.Module, + output_dim: int, + probabilistic: bool, loss: type[nn.Module], num_estimators: int = 1, + optim_recipe: dict | Optimizer | None = None, format_batch_fn: nn.Module | None = None, - optim_recipe=None, ) -> None: - """Regression routine for PyTorch Lightning. + """Regression routine for Lightning. Args: + model (nn.Module): Model to train. probabilistic (bool): Whether the model is probabilistic, i.e., outputs a PyTorch distribution. - output_dim (int): The number of outputs of the model. - model (nn.Module): The model to train. - loss (type[nn.Module]): The loss function to use. + output_dim (int): Number of outputs of the model. + loss (type[nn.Module]): Loss function to optimize the :attr:`model`. num_estimators (int, optional): The number of estimators for the - ensemble. Defaults to 1. + ensemble. Defaults to 1 (single model). + optim_recipe (dict | Optimizer, optional): The optimizer and + optionally the scheduler to use. Defaults to ``None``. format_batch_fn (nn.Module, optional): The function to format the batch. Defaults to None. - optim_recipe (optional): The optimization recipe - to use. Defaults to None. Warning: If :attr:`probabilistic` is True, the model must output a `PyTorch @@ -49,14 +50,18 @@ def __init__( the CLI. """ super().__init__() + _regression_routine_checks(num_estimators, output_dim) - self.probabilistic = probabilistic self.model = model + self.probabilistic = probabilistic + self.output_dim = output_dim self.loss = loss + self.num_estimators = num_estimators if format_batch_fn is None: format_batch_fn = nn.Identity() + self.optim_recipe = optim_recipe self.format_batch_fn = format_batch_fn reg_metrics = MetricCollection( @@ -72,25 +77,12 @@ def __init__( self.val_metrics = reg_metrics.clone(prefix="reg_val/") self.test_metrics = reg_metrics.clone(prefix="reg_test/") - if num_estimators < 1: - raise ValueError( - f"num_estimators must be positive, got {num_estimators}." - ) - self.num_estimators = num_estimators - - if output_dim < 1: - raise ValueError(f"output_dim must be positive, got {output_dim}.") - self.output_dim = output_dim - self.one_dim_regression = output_dim == 1 - self.optim_recipe = optim_recipe - def configure_optimizers(self): - return self.optim_recipe(self.model) + return self.optim_recipe def on_train_start(self) -> None: - # hyperparameters for performances init_metrics = {k: 0 for k in self.val_metrics} init_metrics.update({k: 0 for k in self.test_metrics}) @@ -210,3 +202,13 @@ def on_test_epoch_end(self) -> None: self.test_metrics.compute(), ) self.test_metrics.reset() + + +def _regression_routine_checks(num_estimators, output_dim): + if num_estimators < 1: + raise ValueError( + f"num_estimators must be positive, got {num_estimators}." + ) + + if output_dim < 1: + raise ValueError(f"output_dim must be positive, got {output_dim}.") diff --git a/torch_uncertainty/routines/segmentation.py b/torch_uncertainty/routines/segmentation.py index 2ed053d7..d9e28065 100644 --- a/torch_uncertainty/routines/segmentation.py +++ b/torch_uncertainty/routines/segmentation.py @@ -2,6 +2,7 @@ from lightning.pytorch import LightningModule from lightning.pytorch.utilities.types import STEP_OUTPUT from torch import Tensor, nn +from torch.optim import Optimizer from torchmetrics import Accuracy, MetricCollection from torchvision.transforms.v2 import functional as F @@ -11,37 +12,43 @@ class SegmentationRoutine(LightningModule): def __init__( self, - num_classes: int, model: nn.Module, + num_classes: int, loss: type[nn.Module], num_estimators: int = 1, - optim_recipe=None, + optim_recipe: dict | Optimizer | None = None, format_batch_fn: nn.Module | None = None, ) -> None: + """Segmentation routine for Lightning. + + Args: + model (nn.Module): Model to train. + num_classes (int): Number of classes in the segmentation task. + loss (type[nn.Module]): Loss function to optimize the :attr:`model`. + num_estimators (int, optional): The number of estimators for the + ensemble. Defaults to 1 (single model). + optim_recipe (dict | Optimizer, optional): The optimizer and + optionally the scheduler to use. Defaults to ``None``. + format_batch_fn (nn.Module, optional): The function to format the + batch. Defaults to None. + + Warning: + You must define :attr:`optim_recipe` if you do not use + the CLI. + """ super().__init__() + _segmentation_routine_checks(num_estimators, num_classes) - if num_estimators < 1: - raise ValueError( - f"num_estimators must be positive, got {num_estimators}." - ) - - if num_classes < 2: - raise ValueError( - f"num_classes must be at least 2, got {num_classes}." - ) + self.model = model + self.num_classes = num_classes + self.loss = loss + self.num_estimators = num_estimators if format_batch_fn is None: format_batch_fn = nn.Identity() - self.num_classes = num_classes - self.model = model - self.loss = loss - self.format_batch_fn = format_batch_fn self.optim_recipe = optim_recipe - - self.num_estimators = num_estimators - - self.metric_to_monitor = "val/mean_iou" + self.format_batch_fn = format_batch_fn # metrics seg_metrics = MetricCollection( @@ -56,7 +63,7 @@ def __init__( self.test_seg_metrics = seg_metrics.clone(prefix="test/") def configure_optimizers(self): - return self.optim_recipe(self.model) + return self.optim_recipe def forward(self, img: Tensor) -> Tensor: return self.model(img) @@ -122,3 +129,13 @@ def on_validation_epoch_end(self) -> None: def on_test_epoch_end(self) -> None: self.log_dict(self.test_seg_metrics.compute()) self.test_seg_metrics.reset() + + +def _segmentation_routine_checks(num_estimators, num_classes): + if num_estimators < 1: + raise ValueError( + f"num_estimators must be positive, got {num_estimators}." + ) + + if num_classes < 2: + raise ValueError(f"num_classes must be at least 2, got {num_classes}.") From c0e37355be61ca81bbe2b9c64e33d7e17115e9fd Mon Sep 17 00:00:00 2001 From: alafage Date: Fri, 22 Mar 2024 10:51:44 +0100 Subject: [PATCH 100/148] :white_check_mark: Cover grouping loss in ClassificationRoutine --- tests/routines/test_classification.py | 4 ++-- torch_uncertainty/utils/cli.py | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/routines/test_classification.py b/tests/routines/test_classification.py index ecc9a27d..19dd6b2c 100644 --- a/tests/routines/test_classification.py +++ b/tests/routines/test_classification.py @@ -92,7 +92,7 @@ def test_one_estimator_two_classes_calibrated_with_ood(self): dm = DummyClassificationDataModule( root=Path(), - batch_size=16, + batch_size=19, # lower than 19 it doesn't work :'( num_classes=2, num_images=100, eval_ood=True, @@ -105,7 +105,7 @@ def test_one_estimator_two_classes_calibrated_with_ood(self): baseline_type="single", ood_criterion="entropy", eval_ood=True, - # eval_grouping_loss=True, + eval_grouping_loss=True, calibrate=True, ) diff --git a/torch_uncertainty/utils/cli.py b/torch_uncertainty/utils/cli.py index d1871b91..899bcee9 100644 --- a/torch_uncertainty/utils/cli.py +++ b/torch_uncertainty/utils/cli.py @@ -117,7 +117,6 @@ def __init__( run, auto_configure_optimizers, ) - print("IN TU CLI INIT") def add_default_arguments_to_parser( self, parser: LightningArgumentParser From a9a3c9cd6dd2c0c3da3763b9d89488b86046e47e Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 22 Mar 2024 11:01:17 +0100 Subject: [PATCH 101/148] :white_check_mark: Slightly improve coverage --- .../test_abstract_datamodule.py | 3 +++ tests/metrics/test_grouping_loss.py | 13 ++++++++----- torch_uncertainty/baselines/regression/mlp.py | 2 +- torch_uncertainty/datamodules/abstract.py | 2 ++ torch_uncertainty/metrics/grouping_loss.py | 6 +----- 5 files changed, 15 insertions(+), 11 deletions(-) rename tests/datamodules/{classification => }/test_abstract_datamodule.py (95%) diff --git a/tests/datamodules/classification/test_abstract_datamodule.py b/tests/datamodules/test_abstract_datamodule.py similarity index 95% rename from tests/datamodules/classification/test_abstract_datamodule.py rename to tests/datamodules/test_abstract_datamodule.py index c756d7be..7b0f5e66 100644 --- a/tests/datamodules/classification/test_abstract_datamodule.py +++ b/tests/datamodules/test_abstract_datamodule.py @@ -58,3 +58,6 @@ def test_errors(self): cv_dm.setup() cv_dm._get_train_data() cv_dm._get_train_targets() + + with pytest.raises(ValueError): + cv_dm.setup("other") diff --git a/tests/metrics/test_grouping_loss.py b/tests/metrics/test_grouping_loss.py index fe1c4376..7d2fec31 100644 --- a/tests/metrics/test_grouping_loss.py +++ b/tests/metrics/test_grouping_loss.py @@ -4,15 +4,18 @@ from torch_uncertainty.metrics import GroupingLoss -@pytest.fixture() -def disagreement_probas_3() -> torch.Tensor: - return torch.as_tensor([[[0.0, 1.0], [0.0, 1.0], [1.0, 0.0]]]) - - class TestGroupingLoss: """Testing the GroupingLoss metric class.""" def test_compute(self): + metric = GroupingLoss() + metric.update( + torch.rand(100), + (torch.rand(100) > 0.3).long(), + torch.rand((100, 10)), + ) + metric.compute() + metric = GroupingLoss() metric.update( torch.ones((100, 4, 10)) / 10, diff --git a/torch_uncertainty/baselines/regression/mlp.py b/torch_uncertainty/baselines/regression/mlp.py index 93ec8b87..02e3c658 100644 --- a/torch_uncertainty/baselines/regression/mlp.py +++ b/torch_uncertainty/baselines/regression/mlp.py @@ -53,7 +53,7 @@ def __init__( final_layer = NormalInverseGammaLayer final_layer_args = {"dim": output_dim} params["num_outputs"] *= 4 - elif distribution is None: + else: # distribution is None: probabilistic = False final_layer = nn.Identity final_layer_args = {} diff --git a/torch_uncertainty/datamodules/abstract.py b/torch_uncertainty/datamodules/abstract.py index 9925f6e2..1da19ced 100644 --- a/torch_uncertainty/datamodules/abstract.py +++ b/torch_uncertainty/datamodules/abstract.py @@ -180,6 +180,8 @@ def setup(self, stage: str | None = None) -> None: self.val = self.dm.val elif stage == "test": self.test = self.val + else: + raise ValueError(f"Stage {stage} not supported.") def _data_loader(self, dataset: Dataset, idx: ArrayLike) -> DataLoader: return DataLoader( diff --git a/torch_uncertainty/metrics/grouping_loss.py b/torch_uncertainty/metrics/grouping_loss.py index cfa23367..062bff45 100644 --- a/torch_uncertainty/metrics/grouping_loss.py +++ b/torch_uncertainty/metrics/grouping_loss.py @@ -33,15 +33,11 @@ def __init__( Inputs: - :attr:`probs`: :math:`(B, C)` or :math:`(B, N, C)` - :attr:`target`: :math:`(B)` or :math:`(B, C)` + - :attr:`features`: :math:`(B, F)` or :math:`(B, N, F)` where :math:`B` is the batch size, :math:`C` is the number of classes and :math:`N` is the number of estimators. - Note: - If :attr:`probs` is a 3d tensor, then the metric computes the mean of - the Brier score over the estimators ie. :math:`t = \frac{1}{N} - \sum_{i=0}^{N-1} BrierScore(probs[:,i,:], target)`. - Warning: Make sure that the probabilities in :attr:`probs` are normalized to sum to one. From 4f2d6f713434cf8b1eac87a845a33ac600765d2a Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 22 Mar 2024 11:20:16 +0100 Subject: [PATCH 102/148] :white_check_mark: Improve dm coverage --- .../classification/test_cifar100_datamodule.py | 1 + .../classification/test_imagenet_datamodule.py | 13 +++++++++++-- .../classification/test_mnist_datamodule.py | 2 ++ .../classification/test_tiny_imagenet_datamodule.py | 9 +++++++++ 4 files changed, 23 insertions(+), 2 deletions(-) diff --git a/tests/datamodules/classification/test_cifar100_datamodule.py b/tests/datamodules/classification/test_cifar100_datamodule.py index e032d332..487e7214 100644 --- a/tests/datamodules/classification/test_cifar100_datamodule.py +++ b/tests/datamodules/classification/test_cifar100_datamodule.py @@ -35,6 +35,7 @@ def test_cifar100(self): root="./data/", batch_size=128, cutout=0, test_alt="c" ) dm.dataset = DummyClassificationDataset + dm.setup("test") with pytest.raises(ValueError): dm.setup() diff --git a/tests/datamodules/classification/test_imagenet_datamodule.py b/tests/datamodules/classification/test_imagenet_datamodule.py index 4689c2d9..80b73aff 100644 --- a/tests/datamodules/classification/test_imagenet_datamodule.py +++ b/tests/datamodules/classification/test_imagenet_datamodule.py @@ -12,9 +12,7 @@ class TestImageNetDataModule: def test_imagenet(self): dm = ImageNetDataModule(root="./data/", batch_size=128, val_split=0.1) - assert dm.dataset == ImageNet - dm.dataset = DummyClassificationDataset dm.ood_dataset = DummyClassificationDataset dm.prepare_data() @@ -44,6 +42,13 @@ def test_imagenet(self): dm.setup("test") dm.test_dataloader() + ImageNetDataModule( + root="./data/", + batch_size=128, + val_split=path, + rand_augment_opt="rand-m9-n1", + ) + with pytest.raises(ValueError): dm.setup("other") @@ -92,3 +97,7 @@ def test_imagenet(self): with pytest.raises(FileNotFoundError): dm._verify_splits(split="test") + + with pytest.raises(FileNotFoundError): + dm.root = Path("./tests/testlog") + dm._verify_splits(split="test") diff --git a/tests/datamodules/classification/test_mnist_datamodule.py b/tests/datamodules/classification/test_mnist_datamodule.py index 0da67431..1707409a 100644 --- a/tests/datamodules/classification/test_mnist_datamodule.py +++ b/tests/datamodules/classification/test_mnist_datamodule.py @@ -26,6 +26,8 @@ def test_mnist_cutout(self): with pytest.raises(ValueError): MNISTDataModule(root="./data/", batch_size=128, ood_ds="other") + MNISTDataModule(root="./data/", batch_size=128, test_alt="c") + dm.dataset = DummyClassificationDataset dm.ood_dataset = DummyClassificationDataset diff --git a/tests/datamodules/classification/test_tiny_imagenet_datamodule.py b/tests/datamodules/classification/test_tiny_imagenet_datamodule.py index ca519347..5885fdb3 100644 --- a/tests/datamodules/classification/test_tiny_imagenet_datamodule.py +++ b/tests/datamodules/classification/test_tiny_imagenet_datamodule.py @@ -48,6 +48,15 @@ def test_tiny_imagenet(self): dm.setup("test") dm.test_dataloader() + dm = TinyImageNetDataModule( + root="./data/", batch_size=128, ood_ds="svhn" + ) + dm.dataset = DummyClassificationDataset + dm.ood_dataset = DummyClassificationDataset + dm.eval_ood = True + dm.prepare_data() + dm.setup("test") + def test_tiny_imagenet_cv(self): dm = TinyImageNetDataModule(root="./data/", batch_size=128) dm.dataset = lambda root, split, transform: DummyClassificationDataset( From 49d3eb3c0871f4b7862285b58fb9a56d2b995773 Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 22 Mar 2024 12:01:41 +0100 Subject: [PATCH 103/148] :ok_hand: Solve some comments --- auto_tutorials_source/tutorial_bayesian.py | 2 +- experiments/classification/cifar10/readme.md | 2 - .../classification/mnist/bayesian_lenet.py | 2 +- tests/baselines/test_mimo.py | 6 -- .../baselines/segmentation/segformer.py | 3 +- torch_uncertainty/losses.py | 63 +++++++++++-------- torch_uncertainty/metrics/nll.py | 2 +- torch_uncertainty/routines/classification.py | 9 +-- 8 files changed, 43 insertions(+), 46 deletions(-) diff --git a/auto_tutorials_source/tutorial_bayesian.py b/auto_tutorials_source/tutorial_bayesian.py index 38b3689a..04f1202e 100644 --- a/auto_tutorials_source/tutorial_bayesian.py +++ b/auto_tutorials_source/tutorial_bayesian.py @@ -95,7 +95,7 @@ def optim_lenet(model: nn.Module) -> dict: loss = ELBOLoss( model=model, - criterion=nn.CrossEntropyLoss(), + inner_loss=nn.CrossEntropyLoss(), kl_weight=1 / 50000, num_samples=3, ) diff --git a/experiments/classification/cifar10/readme.md b/experiments/classification/cifar10/readme.md index 286b73a1..6fdfb043 100644 --- a/experiments/classification/cifar10/readme.md +++ b/experiments/classification/cifar10/readme.md @@ -27,8 +27,6 @@ python resnet.py fit --config configs/resnet50/packed.yaml python resnet.py fit --config configs/resnet.yaml --model.arch 101 --model.version mimo --model.num_estimators 4 --model.rho 1.0 ``` - - ## Available configurations: ### ResNet diff --git a/experiments/classification/mnist/bayesian_lenet.py b/experiments/classification/mnist/bayesian_lenet.py index f1030963..05a7c17e 100644 --- a/experiments/classification/mnist/bayesian_lenet.py +++ b/experiments/classification/mnist/bayesian_lenet.py @@ -45,7 +45,7 @@ def optim_lenet(model: nn.Module) -> dict: # hyperparameters are from blitz. loss = partial( ELBOLoss, - criterion=nn.CrossEntropyLoss(), + inner_loss=nn.CrossEntropyLoss(), kl_weight=1 / 50000, num_samples=3, ) diff --git a/tests/baselines/test_mimo.py b/tests/baselines/test_mimo.py index 4b7e6231..18c83a08 100644 --- a/tests/baselines/test_mimo.py +++ b/tests/baselines/test_mimo.py @@ -7,12 +7,6 @@ WideResNetBaseline, ) -# from torch_uncertainty.optim_recipes import ( -# optim_cifar10_resnet18, -# optim_cifar10_resnet50, -# optim_cifar10_wideresnet, -# ) - class TestMIMOBaseline: """Testing the MIMOResNet baseline class.""" diff --git a/torch_uncertainty/baselines/segmentation/segformer.py b/torch_uncertainty/baselines/segmentation/segformer.py index 49be67d7..1e8185a1 100644 --- a/torch_uncertainty/baselines/segmentation/segformer.py +++ b/torch_uncertainty/baselines/segmentation/segformer.py @@ -56,7 +56,8 @@ def __init__( - ``4``: SegFormer-B4 - ``5``: SegFormer-B5 - num_estimators (int, optional): _description_. Defaults to 1. + num_estimators (int, optional): Number of estimators in the + ensemble. Defaults to 1 (single model). """ params = { "num_classes": num_classes, diff --git a/torch_uncertainty/losses.py b/torch_uncertainty/losses.py index bd7319c9..ed552924 100644 --- a/torch_uncertainty/losses.py +++ b/torch_uncertainty/losses.py @@ -67,48 +67,31 @@ class ELBOLoss(nn.Module): def __init__( self, model: nn.Module | None, - criterion: nn.Module, + inner_loss: nn.Module, kl_weight: float, num_samples: int, ) -> None: """The Evidence Lower Bound (ELBO) loss for Bayesian Neural Networks. ELBO loss for Bayesian Neural Networks. Use this loss function with the - objective that you seek to minimize as :attr:`criterion`. + objective that you seek to minimize as :attr:`inner_loss`. Args: model (nn.Module): The Bayesian Neural Network to compute the loss for - criterion (nn.Module): The loss function to use during training + inner_loss (nn.Module): The loss function to use during training kl_weight (float): The weight of the KL divergence term num_samples (int): The number of samples to use for the ELBO loss + + Note: + Set the model to None if you use the ELBOLoss within + the ClassificationRoutine. It will get filled automatically. """ super().__init__() - + _elbo_loss_checks(inner_loss, kl_weight, num_samples) self.set_model(model) - if isinstance(criterion, type): - raise TypeError( - "The criterion should be an instance of a class." - f"Got {criterion}." - ) - self.criterion = criterion - - if kl_weight < 0: - raise ValueError( - f"The KL weight should be non-negative. Got {kl_weight}." - ) + self.inner_loss = inner_loss self.kl_weight = kl_weight - - if num_samples < 1: - raise ValueError( - "The number of samples should not be lower than 1." - f"Got {num_samples}." - ) - if not isinstance(num_samples, int): - raise TypeError( - "The number of samples should be an integer. " - f"Got {type(num_samples)}." - ) self.num_samples = num_samples def forward(self, inputs: Tensor, targets: Tensor) -> Tensor: @@ -125,7 +108,7 @@ def forward(self, inputs: Tensor, targets: Tensor) -> Tensor: aggregated_elbo = torch.zeros(1, device=inputs.device) for _ in range(self.num_samples): logits = self.model(inputs) - aggregated_elbo += self.criterion(logits, targets) + aggregated_elbo += self.inner_loss(logits, targets) aggregated_elbo += self.kl_weight * self._kl_div() return aggregated_elbo / self.num_samples @@ -135,6 +118,32 @@ def set_model(self, model: nn.Module) -> None: self._kl_div = KLDiv(model) +def _elbo_loss_checks( + inner_loss: nn.Module, kl_weight: float, num_samples: int +): + if isinstance(inner_loss, type): + raise TypeError( + "The inner_loss should be an instance of a class." + f"Got {inner_loss}." + ) + + if kl_weight < 0: + raise ValueError( + f"The KL weight should be non-negative. Got {kl_weight}." + ) + + if num_samples < 1: + raise ValueError( + "The number of samples should not be lower than 1." + f"Got {num_samples}." + ) + if not isinstance(num_samples, int): + raise TypeError( + "The number of samples should be an integer. " + f"Got {type(num_samples)}." + ) + + class DERLoss(DistributionNLLLoss): def __init__( self, reg_weight: float, reduction: str | None = "mean" diff --git a/torch_uncertainty/metrics/nll.py b/torch_uncertainty/metrics/nll.py index d433f43e..98df27e8 100644 --- a/torch_uncertainty/metrics/nll.py +++ b/torch_uncertainty/metrics/nll.py @@ -8,7 +8,7 @@ class CategoricalNLL(Metric): - is_differentiabled = False + is_differentiable = False higher_is_better = False full_state_update = False diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index b2047295..05fcda0f 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -92,16 +92,11 @@ def __init__( the ensemble and vr is the variation ratio of the ensemble. log_plots (bool, optional): Indicates whether to log plots from metrics. Defaults to ``False``. - save_in_csv(bool, optional): __TODO__ + save_in_csv(bool, optional): Save the results in csv. Defaults to + ``False``. calibration_set (Callable, optional): Function to get the calibration set. Defaults to ``None``. - Raises: - ValueError: _description_ - ValueError: _description_ - ValueError: _description_ - ValueError: _description_ - Warning: You must define :attr:`optim_recipe` if you do not use the CLI. From fe0b08569932ba311dddbe35103549f2a55b02e1 Mon Sep 17 00:00:00 2001 From: alafage Date: Fri, 22 Mar 2024 12:37:15 +0100 Subject: [PATCH 104/148] :white_check_mark: Slightly improve ClassificationRoutine coverage --- tests/_dummies/baseline.py | 3 ++ tests/routines/test_classification.py | 40 +++++++++++++++++++++++++-- 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/tests/_dummies/baseline.py b/tests/_dummies/baseline.py index c043db6e..8fe08bca 100644 --- a/tests/_dummies/baseline.py +++ b/tests/_dummies/baseline.py @@ -33,6 +33,7 @@ def __new__( eval_ood: bool = False, eval_grouping_loss: bool = False, calibrate: bool = False, + save_in_csv: bool = False, ) -> LightningModule: model = dummy_model( in_channels=in_channels, @@ -54,6 +55,7 @@ def __new__( eval_ood=eval_ood, eval_grouping_loss=eval_grouping_loss, calibration_set="val" if calibrate else None, + save_in_csv=save_in_csv, ) # baseline_type == "ensemble": model = deep_ensembles( @@ -72,6 +74,7 @@ def __new__( eval_ood=eval_ood, eval_grouping_loss=eval_grouping_loss, calibration_set="val" if calibrate else None, + save_in_csv=save_in_csv, ) diff --git a/tests/routines/test_classification.py b/tests/routines/test_classification.py index 26661ddc..4f65e800 100644 --- a/tests/routines/test_classification.py +++ b/tests/routines/test_classification.py @@ -9,6 +9,7 @@ DummyClassificationDataModule, dummy_model, ) +from torch_uncertainty.losses import ELBOLoss from torch_uncertainty.optim_recipes import optim_cifar10_resnet18 from torch_uncertainty.routines import ClassificationRoutine @@ -103,7 +104,7 @@ def test_one_estimator_two_classes_calibrated_with_ood(self): loss=nn.CrossEntropyLoss(), optim_recipe=optim_cifar10_resnet18, baseline_type="single", - ood_criterion="entropy", + ood_criterion="energy", eval_ood=True, eval_grouping_loss=True, calibrate=True, @@ -130,8 +131,43 @@ def test_two_estimators_two_classes_with_ood(self): loss=nn.CrossEntropyLoss(), optim_recipe=optim_cifar10_resnet18, baseline_type="ensemble", - ood_criterion="energy", + ood_criterion="mi", + eval_ood=True, + ) + + trainer.fit(model, dm) + trainer.validate(model, dm) + trainer.test(model, dm) + model(dm.get_test_set()[0][0]) + + def test_two_estimator_two_classes(self): + trainer = Trainer( + accelerator="cpu", + max_epochs=1, + limit_train_batches=1, + limit_val_batches=1, + limit_test_batches=1, + enable_checkpointing=False, + ) + + dm = DummyClassificationDataModule( + root=Path(), + batch_size=16, + num_classes=2, + num_images=100, + eval_ood=True, + ) + model = DummyClassificationBaseline( + num_classes=dm.num_classes, + in_channels=dm.num_channels, + loss=ELBOLoss( + None, nn.CrossEntropyLoss(), kl_weight=1.0, num_samples=4 + ), + optim_recipe=optim_cifar10_resnet18, + baseline_type="ensemble", + ood_criterion="vr", eval_ood=True, + save_in_csv=True, ) trainer.fit(model, dm) From 4c8be66828593e173a294975ef71c49a190bfddd Mon Sep 17 00:00:00 2001 From: alafage Date: Fri, 22 Mar 2024 13:13:26 +0100 Subject: [PATCH 105/148] :bug: Fix mixup execution and improve coverage --- tests/_dummies/baseline.py | 14 ++ tests/routines/test_classification.py | 168 ++++++++++++++++++- torch_uncertainty/layers/__init__.py | 1 + torch_uncertainty/layers/modules.py | 11 ++ torch_uncertainty/routines/classification.py | 9 +- 5 files changed, 197 insertions(+), 6 deletions(-) create mode 100644 torch_uncertainty/layers/modules.py diff --git a/tests/_dummies/baseline.py b/tests/_dummies/baseline.py index 8fe08bca..b650f180 100644 --- a/tests/_dummies/baseline.py +++ b/tests/_dummies/baseline.py @@ -34,6 +34,13 @@ def __new__( eval_grouping_loss: bool = False, calibrate: bool = False, save_in_csv: bool = False, + mixtype: str = "erm", + mixmode: str = "elem", + dist_sim: str = "emb", + kernel_tau_max: float = 1, + kernel_tau_std: float = 0.5, + mixup_alpha: float = 0, + cutmix_alpha: float = 0, ) -> LightningModule: model = dummy_model( in_channels=in_channels, @@ -51,6 +58,13 @@ def __new__( log_plots=True, optim_recipe=optim_recipe(model), num_estimators=1, + mixtype=mixtype, + mixmode=mixmode, + dist_sim=dist_sim, + kernel_tau_max=kernel_tau_max, + kernel_tau_std=kernel_tau_std, + mixup_alpha=mixup_alpha, + cutmix_alpha=cutmix_alpha, ood_criterion=ood_criterion, eval_ood=eval_ood, eval_grouping_loss=eval_grouping_loss, diff --git a/tests/routines/test_classification.py b/tests/routines/test_classification.py index 4f65e800..9204dcff 100644 --- a/tests/routines/test_classification.py +++ b/tests/routines/test_classification.py @@ -88,6 +88,170 @@ def test_one_estimator_two_classes(self): trainer.test(model, dm) model(dm.get_test_set()[0][0]) + def test_one_estimator_two_classes_timm(self): + trainer = Trainer(accelerator="cpu", fast_dev_run=True) + + dm = DummyClassificationDataModule( + root=Path(), + batch_size=16, + num_classes=2, + num_images=100, + eval_ood=True, + ) + model = DummyClassificationBaseline( + num_classes=dm.num_classes, + in_channels=dm.num_channels, + loss=nn.CrossEntropyLoss(), + optim_recipe=optim_cifar10_resnet18, + baseline_type="single", + ood_criterion="entropy", + eval_ood=True, + mixtype="timm", + mixup_alpha=1.0, + cutmix_alpha=0.5, + ) + + trainer.fit(model, dm) + trainer.validate(model, dm) + trainer.test(model, dm) + model(dm.get_test_set()[0][0]) + + def test_one_estimator_two_classes_mixup(self): + trainer = Trainer(accelerator="cpu", fast_dev_run=True) + + dm = DummyClassificationDataModule( + root=Path(), + batch_size=16, + num_classes=2, + num_images=100, + eval_ood=True, + ) + model = DummyClassificationBaseline( + num_classes=dm.num_classes, + in_channels=dm.num_channels, + loss=nn.CrossEntropyLoss(), + optim_recipe=optim_cifar10_resnet18, + baseline_type="single", + ood_criterion="entropy", + eval_ood=True, + mixtype="mixup", + mixup_alpha=1.0, + ) + + trainer.fit(model, dm) + trainer.validate(model, dm) + trainer.test(model, dm) + model(dm.get_test_set()[0][0]) + + def test_one_estimator_two_classes_mixup_io(self): + trainer = Trainer(accelerator="cpu", fast_dev_run=True) + + dm = DummyClassificationDataModule( + root=Path(), + batch_size=16, + num_classes=2, + num_images=100, + eval_ood=True, + ) + model = DummyClassificationBaseline( + num_classes=dm.num_classes, + in_channels=dm.num_channels, + loss=nn.CrossEntropyLoss(), + optim_recipe=optim_cifar10_resnet18, + baseline_type="single", + ood_criterion="entropy", + eval_ood=True, + mixtype="mixup_io", + mixup_alpha=1.0, + ) + + trainer.fit(model, dm) + trainer.validate(model, dm) + trainer.test(model, dm) + model(dm.get_test_set()[0][0]) + + def test_one_estimator_two_classes_regmixup(self): + trainer = Trainer(accelerator="cpu", fast_dev_run=True) + + dm = DummyClassificationDataModule( + root=Path(), + batch_size=16, + num_classes=2, + num_images=100, + eval_ood=True, + ) + model = DummyClassificationBaseline( + num_classes=dm.num_classes, + in_channels=dm.num_channels, + loss=nn.CrossEntropyLoss(), + optim_recipe=optim_cifar10_resnet18, + baseline_type="single", + ood_criterion="entropy", + eval_ood=True, + mixtype="regmixup", + mixup_alpha=1.0, + ) + + trainer.fit(model, dm) + trainer.validate(model, dm) + trainer.test(model, dm) + model(dm.get_test_set()[0][0]) + + def test_one_estimator_two_classes_kernel_warping_emb(self): + trainer = Trainer(accelerator="cpu", fast_dev_run=True) + + dm = DummyClassificationDataModule( + root=Path(), + batch_size=16, + num_classes=2, + num_images=100, + eval_ood=True, + ) + model = DummyClassificationBaseline( + num_classes=dm.num_classes, + in_channels=dm.num_channels, + loss=nn.CrossEntropyLoss(), + optim_recipe=optim_cifar10_resnet18, + baseline_type="single", + ood_criterion="entropy", + eval_ood=True, + mixtype="kernel_warping", + mixup_alpha=0.5, + ) + + trainer.fit(model, dm) + trainer.validate(model, dm) + trainer.test(model, dm) + model(dm.get_test_set()[0][0]) + + def test_one_estimator_two_classes_kernel_warping_inp(self): + trainer = Trainer(accelerator="cpu", fast_dev_run=True) + + dm = DummyClassificationDataModule( + root=Path(), + batch_size=16, + num_classes=2, + num_images=100, + eval_ood=True, + ) + model = DummyClassificationBaseline( + num_classes=dm.num_classes, + in_channels=dm.num_channels, + loss=nn.CrossEntropyLoss(), + optim_recipe=optim_cifar10_resnet18, + baseline_type="single", + ood_criterion="entropy", + eval_ood=True, + mixtype="kernel_warping", + mixmode="inp", + mixup_alpha=0.5, + ) + + trainer.fit(model, dm) + trainer.validate(model, dm) + trainer.test(model, dm) + model(dm.get_test_set()[0][0]) + def test_one_estimator_two_classes_calibrated_with_ood(self): trainer = Trainer(accelerator="cpu", fast_dev_run=True, logger=True) @@ -115,7 +279,7 @@ def test_one_estimator_two_classes_calibrated_with_ood(self): trainer.test(model, dm) model(dm.get_test_set()[0][0]) - def test_two_estimators_two_classes_with_ood(self): + def test_two_estimators_two_classes_mi(self): trainer = Trainer(accelerator="cpu", fast_dev_run=True) dm = DummyClassificationDataModule( @@ -140,7 +304,7 @@ def test_two_estimators_two_classes_with_ood(self): trainer.test(model, dm) model(dm.get_test_set()[0][0]) - def test_two_estimator_two_classes(self): + def test_two_estimator_two_classes_elbo_vr_logs(self): trainer = Trainer( accelerator="cpu", max_epochs=1, diff --git a/torch_uncertainty/layers/__init__.py b/torch_uncertainty/layers/__init__.py index 64fd39fa..f91746bd 100644 --- a/torch_uncertainty/layers/__init__.py +++ b/torch_uncertainty/layers/__init__.py @@ -2,4 +2,5 @@ from .batch_ensemble import BatchConv2d, BatchLinear from .bayesian import BayesConv1d, BayesConv2d, BayesConv3d, BayesLinear from .masksembles import MaskedConv2d, MaskedLinear +from .modules import Identity from .packed import PackedConv1d, PackedConv2d, PackedConv3d, PackedLinear diff --git a/torch_uncertainty/layers/modules.py b/torch_uncertainty/layers/modules.py new file mode 100644 index 00000000..c2e9a6e3 --- /dev/null +++ b/torch_uncertainty/layers/modules.py @@ -0,0 +1,11 @@ +from typing import Any + +from torch import nn + + +class Identity(nn.Module): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__() + + def forward(self, *args) -> Any: + return args diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 05fcda0f..5ff3a8ae 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -17,6 +17,7 @@ BinaryAveragePrecision, ) +from torch_uncertainty.layers import Identity from torch_uncertainty.losses import DECLoss, ELBOLoss from torch_uncertainty.metrics import ( CE, @@ -258,7 +259,7 @@ def init_mixup( tau_max=kernel_tau_max, tau_std=kernel_tau_std, ) - return nn.Identity() + return Identity() def configure_optimizers(self): return self.optim_recipe @@ -328,11 +329,11 @@ def training_step( with torch.no_grad(): feats = self.model.feats_forward(batch[0]).detach() - batch = self.mixup(batch, feats) + batch = self.mixup(*batch, feats) elif self.dist_sim == "inp": - batch = self.mixup(batch, batch[0]) + batch = self.mixup(*batch, batch[0]) else: - batch = self.mixup(batch) + batch = self.mixup(*batch) inputs, targets = self.format_batch_fn(batch) From 398554789871e5ed64166032198bca3318e4050d Mon Sep 17 00:00:00 2001 From: alafage Date: Fri, 22 Mar 2024 13:23:19 +0100 Subject: [PATCH 106/148] :bug: Fix wrong testing function --- tests/routines/test_classification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/routines/test_classification.py b/tests/routines/test_classification.py index 9204dcff..05cc2f84 100644 --- a/tests/routines/test_classification.py +++ b/tests/routines/test_classification.py @@ -243,7 +243,7 @@ def test_one_estimator_two_classes_kernel_warping_inp(self): ood_criterion="entropy", eval_ood=True, mixtype="kernel_warping", - mixmode="inp", + dist_sim="inp", mixup_alpha=0.5, ) From 603c464ae3fa3f14817bd7052502776f992d87d0 Mon Sep 17 00:00:00 2001 From: alafage Date: Fri, 22 Mar 2024 13:41:29 +0100 Subject: [PATCH 107/148] :white_check_mark: Slightly improve coverage and remove dead code in segformer std --- tests/routines/test_classification.py | 4 +- .../models/segmentation/segformer/std.py | 67 ------------------- 2 files changed, 2 insertions(+), 69 deletions(-) diff --git a/tests/routines/test_classification.py b/tests/routines/test_classification.py index 05cc2f84..dea1ee61 100644 --- a/tests/routines/test_classification.py +++ b/tests/routines/test_classification.py @@ -9,7 +9,7 @@ DummyClassificationDataModule, dummy_model, ) -from torch_uncertainty.losses import ELBOLoss +from torch_uncertainty.losses import DECLoss, ELBOLoss from torch_uncertainty.optim_recipes import optim_cifar10_resnet18 from torch_uncertainty.routines import ClassificationRoutine @@ -292,7 +292,7 @@ def test_two_estimators_two_classes_mi(self): model = DummyClassificationBaseline( num_classes=dm.num_classes, in_channels=dm.num_channels, - loss=nn.CrossEntropyLoss(), + loss=DECLoss(1, 1e-2), optim_recipe=optim_cifar10_resnet18, baseline_type="ensemble", ood_criterion="mi", diff --git a/torch_uncertainty/models/segmentation/segformer/std.py b/torch_uncertainty/models/segmentation/segformer/std.py index 49bd41d9..83c86e02 100644 --- a/torch_uncertainty/models/segmentation/segformer/std.py +++ b/torch_uncertainty/models/segmentation/segformer/std.py @@ -267,24 +267,6 @@ def forward(self, x): return x, h, w -class SegFormerSegmentationHead(nn.Module): - def __init__(self, channels: int, num_classes: int, num_features: int = 4): - super().__init__() - self.fuse = nn.Sequential( - nn.Conv2d( - channels * num_features, channels, kernel_size=1, bias=False - ), - nn.ReLU(), - nn.BatchNorm2d(channels), - ) - self.predict = nn.Conv2d(channels, num_classes, kernel_size=1) - - def forward(self, features): - x = torch.cat(features, dim=1) - x = self.fuse(x) - return self.predict(x) - - class MixVisionTransformer(nn.Module): def __init__( self, @@ -432,10 +414,6 @@ def __init__( ) self.norm4 = norm_layer(embed_dims[3]) - # classification head - # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 - # else nn.Identity() - self.apply(self._init_weights) def _init_weights(self, m): @@ -453,51 +431,6 @@ def _init_weights(self, m): if m.bias is not None: m.bias.data.zero_() - def reset_drop_path(self, drop_path_rate): - dpr = [ - x.item() - for x in torch.linspace(0, drop_path_rate, sum(self.depths)) - ] - cur = 0 - for i in range(self.depths[0]): - self.block1[i].drop_path.drop_prob = dpr[cur + i] - - cur += self.depths[0] - for i in range(self.depths[1]): - self.block2[i].drop_path.drop_prob = dpr[cur + i] - - cur += self.depths[1] - for i in range(self.depths[2]): - self.block3[i].drop_path.drop_prob = dpr[cur + i] - - cur += self.depths[2] - for i in range(self.depths[3]): - self.block4[i].drop_path.drop_prob = dpr[cur + i] - - def freeze_patch_emb(self): - self.patch_embed1.requires_grad = False - - @torch.jit.ignore - def no_weight_decay(self): - return { - "pos_embed1", - "pos_embed2", - "pos_embed3", - "pos_embed4", - "cls_token", - } # has pos_embed may be better - - def get_classifier(self): - return self.head - - def reset_classifier(self, num_classes, global_pool=""): - self.num_classes = num_classes - self.head = ( - nn.Linear(self.embed_dim, num_classes) - if num_classes > 0 - else nn.Identity() - ) - def forward_features(self, x): b = x.shape[0] outs = [] From d9e261e4ed5db25725094d62f2900373cad30f6e Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 22 Mar 2024 13:43:04 +0100 Subject: [PATCH 108/148] :bug: Fix tutorials --- auto_tutorials_source/tutorial_der_cubic.py | 21 +++++++------------ .../tutorial_evidential_classification.py | 4 ++-- 2 files changed, 9 insertions(+), 16 deletions(-) diff --git a/auto_tutorials_source/tutorial_der_cubic.py b/auto_tutorials_source/tutorial_der_cubic.py index f293a65b..b6bc91e1 100644 --- a/auto_tutorials_source/tutorial_der_cubic.py +++ b/auto_tutorials_source/tutorial_der_cubic.py @@ -27,13 +27,9 @@ - the evidential objective: the DERLoss from torch_uncertainty.losses. This loss contains the classic NLL loss and a regularization term. - a dataset that generates samples from a noisy cubic function: Cubic from torch_uncertainty.datasets.regression -We also need to define an optimizer using torch.optim, the neural network utils within torch.nn, as well as the partial util to provide -the modified default arguments for the DER loss. +We also need to define an optimizer using torch.optim and the neural network utils within torch.nn. """ - # %% -from functools import partial - import torch from lightning.pytorch import Trainer from lightning import LightningDataModule @@ -98,15 +94,12 @@ def optim_regression( # 4. The Loss and the Training Routine # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Next, we need to define the loss to be used during training. To do this, we -# set the weight of the regularizer of the DER Loss in advance using the partial -# function from functools. After that, we define the training routine using -# the probabilistic regression training routine from torch_uncertainty.routines. In -# this routine, we provide the model, the DER loss, and the optimization recipe. - -loss = partial( - DERLoss, - reg_weight=1e-2, -) +# set the weight of the regularizer of the DER Loss. After that, we define the +# training routine using the probabilistic regression training routine from +# torch_uncertainty.routines. In this routine, we provide the model, the DER +# loss, and the optimization recipe. + +loss = DERLoss(reg_weight=1e-2) routine = RegressionRoutine( probabilistic=True, diff --git a/auto_tutorials_source/tutorial_evidential_classification.py b/auto_tutorials_source/tutorial_evidential_classification.py index cef81083..1b780361 100644 --- a/auto_tutorials_source/tutorial_evidential_classification.py +++ b/auto_tutorials_source/tutorial_evidential_classification.py @@ -58,8 +58,8 @@ def optim_lenet(model: nn.Module) -> dict: # In the following, we need to define the root of the logs, and to # fake-parse the arguments needed for using the PyTorch Lightning Trainer. We # also use the same MNIST classification example as that used in the -# original DEC paper. We only train for 5 epochs for the sake of time. -trainer = Trainer(accelerator="cpu", max_epochs=5, enable_progress_bar=False) +# original DEC paper. We only train for 3 epochs for the sake of time. +trainer = Trainer(accelerator="cpu", max_epochs=3, enable_progress_bar=False) # datamodule root = Path() / "data" From ee075848b358f10a9b19519f3521b6f76310c0ed Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 22 Mar 2024 14:02:13 +0100 Subject: [PATCH 109/148] :book: Improve Readme --- README.md | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 45472d62..a49c775c 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ [![Discord Badge](https://dcbadge.vercel.app/api/server/HMCawt5MJu?compact=true&style=flat)](https://discord.gg/HMCawt5MJu) -_TorchUncertainty_ is a package designed to help you leverage uncertainty quantification techniques and make your deep neural networks more reliable. It aims at being collaborative and including as many methods as possible, so reach out to add yours! +_TorchUncertainty_ is a package designed to help you leverage [uncertainty quantification techniques](https://github.com/ENSTA-U2IS-AI/awesome-uncertainty-deeplearning) and make your deep neural networks more reliable. It aims at being collaborative and including as many methods as possible, so reach out to add yours! :construction: _TorchUncertainty_ is in early development :construction: - expect changes, but reach out and contribute if you are interested in the project! **Please raise an issue if you have any bugs or difficulties and join the [discord server](https://discord.gg/HMCawt5MJu).** @@ -19,24 +19,24 @@ _TorchUncertainty_ is a package designed to help you leverage uncertainty quanti This package provides a multi-level API, including: +- easy-to-use ⚡️ lightning **uncertainty-aware** training & evaluation routines for **4 tasks**: classification, probabilistic and pointwise regression, and segmentation. - ready-to-train baselines on research datasets, such as ImageNet and CIFAR -- deep learning baselines available for training on your datasets - [pretrained weights](https://huggingface.co/torch-uncertainty) for these baselines on ImageNet and CIFAR (work in progress 🚧). -- layers available for use in your networks -- scikit-learn style post-processing methods such as Temperature Scaling +- **layers**, **models**, **metrics**, & **losses** available for use in your networks +- scikit-learn style post-processing methods such as Temperature Scaling. -See the [Reference page](https://torch-uncertainty.github.io/references.html) or the [API reference](https://torch-uncertainty.github.io/api.html) for a more exhaustive list of the implemented methods, datasets, metrics, etc. +Have a look at the [Reference page](https://torch-uncertainty.github.io/references.html) or the [API reference](https://torch-uncertainty.github.io/api.html) for a more exhaustive list of the implemented methods, datasets, metrics, etc. -## Installation +## ⚙️ Installation -Install the desired PyTorch version in your environment. +TorchUncertainty requires Python 3.10 or greater. Install the desired PyTorch version in your environment. Then, install the package from PyPI: ```sh pip install torch-uncertainty ``` -If you aim to contribute, have a look at the [contribution page](https://torch-uncertainty.github.io/contributing.html). +The installation procedure for contributors is different: have a look at the [contribution page](https://torch-uncertainty.github.io/contributing.html). ## Getting Started and Documentation @@ -46,6 +46,8 @@ A quickstart is available at [torch-uncertainty.github.io/quickstart](https://to ## Implemented methods +TorchUncertainty currently supports **Classification**, **probabilistic** and pointwise **Regression** and **Segmentation**. + ### Baselines To date, the following deep learning baselines have been implemented: @@ -75,7 +77,7 @@ To date, the following post-processing methods have been implemented: ## Tutorials -We provide the following tutorials in our documentation: +The following tutorials willWe provide the following tutorials in our documentation: - [From a Standard Classifier to a Packed-Ensemble](https://torch-uncertainty.github.io/auto_tutorials/tutorial_pe_cifar10.html) - [Training a Bayesian Neural Network in 3 minutes](https://torch-uncertainty.github.io/auto_tutorials/tutorial_bayesian.html) @@ -84,10 +86,6 @@ We provide the following tutorials in our documentation: - [Training a LeNet with Monte-Carlo Dropout](https://torch-uncertainty.github.io/auto_tutorials/tutorial_mc_dropout.html) - [Training a LeNet with Deep Evidential Classification](https://torch-uncertainty.github.io/auto_tutorials/tutorial_evidential_classification.html) -## Awesome Uncertainty repositories - -You may find a lot of papers about modern uncertainty estimation techniques on the [Awesome Uncertainty in Deep Learning](https://github.com/ENSTA-U2IS-AI/awesome-uncertainty-deeplearning). - ## Other References This package also contains the official implementation of Packed-Ensembles. From d337c1b581624c979c5fecaa5494a5b54b47baea Mon Sep 17 00:00:00 2001 From: Olivier Date: Sat, 23 Mar 2024 10:50:49 +0100 Subject: [PATCH 110/148] :shirt: Small fixes --- torch_uncertainty/models/deep_ensembles.py | 2 -- torch_uncertainty/models/mc_dropout.py | 7 ++++--- torch_uncertainty/routines/classification.py | 8 ++++++-- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/torch_uncertainty/models/deep_ensembles.py b/torch_uncertainty/models/deep_ensembles.py index e970bd15..a76beb99 100644 --- a/torch_uncertainty/models/deep_ensembles.py +++ b/torch_uncertainty/models/deep_ensembles.py @@ -15,7 +15,6 @@ def __init__( ) -> None: """Create a classification deep ensembles from a list of models.""" super().__init__() - self.models = nn.ModuleList(models) self.num_estimators = len(models) @@ -41,7 +40,6 @@ def __init__( ) -> None: """Create a regression deep ensembles from a list of models.""" super().__init__(models) - self.probabilistic = probabilistic def forward(self, x: torch.Tensor) -> Distribution: diff --git a/torch_uncertainty/models/mc_dropout.py b/torch_uncertainty/models/mc_dropout.py index d3ac77af..355fe43e 100644 --- a/torch_uncertainty/models/mc_dropout.py +++ b/torch_uncertainty/models/mc_dropout.py @@ -13,7 +13,8 @@ def __init__( last_layer (bool): whether to apply dropout to the last layer only. Warning: - The underlying models must have a `dropout_rate` attribute. + The underlying models must have a non-zero :attr:`dropout_rate` + attribute. Warning: For the `last-layer` option to work properly, the model must @@ -70,7 +71,7 @@ def train(self, mode: bool = True) -> nn.Module: def forward( self, x: Tensor, - ) -> tuple[Tensor, Tensor]: + ) -> Tensor: if not self.training: x = x.repeat(self.num_estimators, 1, 1, 1) return self.model(x) @@ -85,7 +86,7 @@ def mc_dropout( model (nn.Module): model to wrap num_estimators (int): number of estimators to use last_layer (bool, optional): whether to apply dropout to the last - layer. Defaults to False. + layer only. Defaults to False. """ return _MCDropout( model=model, num_estimators=num_estimators, last_layer=last_layer diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 5ff3a8ae..88969d7b 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -210,7 +210,7 @@ def __init__( } ) - self.test_id_ens_metrics = ens_metrics.clone(prefix="ood/ens_") + self.test_id_ens_metrics = ens_metrics.clone(prefix="cls_test/ens_") if self.eval_ood: self.test_ood_ens_metrics = ens_metrics.clone(prefix="ood/ens_") @@ -574,7 +574,11 @@ def save_results_to_csv(self, results: dict[str, float]) -> None: def _classification_routine_checks( - model, num_classes, num_estimators, ood_criterion, eval_grouping_loss + model: nn.Module, + num_classes: int, + num_estimators: int, + ood_criterion: str, + eval_grouping_loss: bool, ): if not isinstance(num_estimators, int) or num_estimators < 1: raise ValueError( From 7978844c9f7428812402371976589aa54996dd9c Mon Sep 17 00:00:00 2001 From: Olivier Date: Sat, 23 Mar 2024 11:22:39 +0100 Subject: [PATCH 111/148] :bug: Fix calibration in cls routine --- torch_uncertainty/routines/classification.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 88969d7b..eb467aa2 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -101,6 +101,10 @@ def __init__( Warning: You must define :attr:`optim_recipe` if you do not use the CLI. + + Warning: + You must provide a datamodule to the trainer or use the CLI for if + :attr:`calibration_set` is not ``None``. """ super().__init__() _classification_routine_checks( @@ -288,7 +292,7 @@ def on_test_start(self) -> None: self.scaler = TemperatureScaler( model=self.model, device=self.device ).fit(calibration_set=dataset) - self.cal_model = torch.nn.Sequential(self.model, self.scaler) + self.cal_model = self.scaler else: self.scaler = None self.cal_model = None From 7541a646659cc250cd435e6fc51ff8dfab6c2459 Mon Sep 17 00:00:00 2001 From: Olivier Date: Sat, 23 Mar 2024 11:31:18 +0100 Subject: [PATCH 112/148] :shirt: Small calibration improvements --- auto_tutorials_source/tutorial_scaler.py | 6 +++++- torch_uncertainty/routines/classification.py | 5 ++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/auto_tutorials_source/tutorial_scaler.py b/auto_tutorials_source/tutorial_scaler.py index a4f6fae9..2d927b10 100644 --- a/auto_tutorials_source/tutorial_scaler.py +++ b/auto_tutorials_source/tutorial_scaler.py @@ -5,7 +5,11 @@ In this tutorial, we use *TorchUncertainty* to improve the calibration of the top-label predictions and the reliability of the underlying neural network. -We also see how to use the datamodules outside any Lightning trainers, +This tutorial provides extensive details on how to use the TemperatureScaler +class, however, this is done automatically in the classification routine when setting +the `calibration_set` to val or test. + +Through this tutorial, we also see how to use the datamodules outside any Lightning trainers, and how to use TorchUncertainty's models. 1. Loading the Utilities diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index eb467aa2..77e281f8 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -160,7 +160,7 @@ def __init__( self.test_cls_metrics = cls_metrics.clone(prefix="cls_test/") if self.calibration_set is not None: - self.ts_cls_metrics = cls_metrics.clone(prefix="ts_") + self.ts_cls_metrics = cls_metrics.clone(prefix="cls_test/ts_") self.test_entropy_id = Entropy() @@ -289,10 +289,9 @@ def on_test_start(self) -> None: else self.trainer.datamodule.test_dataloader().dataset ) with torch.inference_mode(False): - self.scaler = TemperatureScaler( + self.cal_model = TemperatureScaler( model=self.model, device=self.device ).fit(calibration_set=dataset) - self.cal_model = self.scaler else: self.scaler = None self.cal_model = None From c729298a4261067fbd6d0f9268950402109ea8e9 Mon Sep 17 00:00:00 2001 From: Olivier Date: Sat, 23 Mar 2024 11:33:36 +0100 Subject: [PATCH 113/148] :bug: Fix small bug --- torch_uncertainty/routines/classification.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 77e281f8..61e76f8a 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -293,7 +293,6 @@ def on_test_start(self) -> None: model=self.model, device=self.device ).fit(calibration_set=dataset) else: - self.scaler = None self.cal_model = None if self.eval_ood and self.log_plots and isinstance(self.logger, Logger): @@ -426,7 +425,6 @@ def test_step( if ( self.num_estimators == 1 and self.calibration_set is not None - and self.scaler is not None and self.cal_model is not None ): cal_logits = self.cal_model(inputs) @@ -497,7 +495,6 @@ def on_test_epoch_end(self) -> None: if ( self.num_estimators == 1 and self.calibration_set is not None - and self.scaler is not None and self.cal_model is not None ): tmp_metrics = self.ts_cls_metrics.compute() From 54269de7667959235174a4402e28e222bcbf79fe Mon Sep 17 00:00:00 2001 From: Olivier Date: Sat, 23 Mar 2024 12:01:39 +0100 Subject: [PATCH 114/148] :disappointed: Rollback sphinx to 5.x --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 8baf0296..bbc8e223 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,7 @@ dev = [ "pre-commit-hooks", ] docs = [ - "sphinx<7", + "sphinx<6", "tu_sphinx_theme", "sphinx-copybutton", "sphinx-gallery", From 988791c91fbcc528ab07a3e821ccdd55ed6cd08e Mon Sep 17 00:00:00 2001 From: Olivier Date: Sat, 23 Mar 2024 16:51:54 +0100 Subject: [PATCH 115/148] :wrench: Set ruff version to avoid potential CI issues --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index bbc8e223..d9a521d7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,7 @@ dependencies = [ [project.optional-dependencies] dev = [ - "ruff", + "ruff==0.3.4", "pytest-cov", "pre-commit", "pre-commit-hooks", From f75e7b5de113f613cebe403a1ced03dd4475d6a2 Mon Sep 17 00:00:00 2001 From: Olivier Date: Sun, 24 Mar 2024 11:07:49 +0100 Subject: [PATCH 116/148] :shirt: Improve overall code quality --- .github/workflows/run-tests.yml | 4 ++-- docs/source/conf.py | 4 ---- .../classification/tiny-imagenet/resnet.py | 1 - pyproject.toml | 8 ++++++-- tests/_dummies/dataset.py | 3 --- tests/_dummies/model.py | 1 + tests/baselines/test_masked.py | 4 ++-- tests/baselines/test_packed.py | 4 ++-- .../classification/test_cifar10_datamodule.py | 6 ------ tests/layers/test_mask.py | 2 +- tests/metrics/test_brier_score.py | 6 ++++-- tests/metrics/test_entropy.py | 2 +- tests/metrics/test_mutual_information.py | 2 +- tests/metrics/test_nll.py | 2 +- tests/metrics/test_variation_ratio.py | 2 +- tests/test_utils.py | 5 +++-- .../baselines/classification/vgg.py | 1 + .../baselines/classification/wideresnet.py | 1 - .../datamodules/uci_regression.py | 10 +--------- .../datasets/classification/cifar/cifar_c.py | 4 ++-- .../datasets/classification/cifar/cifar_h.py | 4 +--- .../datasets/classification/cifar/cifar_n.py | 8 ++------ .../datasets/classification/imagenet/base.py | 1 + .../classification/imagenet/tiny_imagenet.py | 8 ++++---- .../datasets/classification/mnist_c.py | 4 +--- .../datasets/classification/openimage_o.py | 2 +- torch_uncertainty/datasets/muad.py | 2 +- .../datasets/regression/uci_regression.py | 6 ++---- .../datasets/segmentation/camvid.py | 6 +----- torch_uncertainty/layers/batch_ensemble.py | 12 +----------- .../layers/bayesian/bayes_conv.py | 4 ++-- .../layers/bayesian/bayes_linear.py | 2 +- torch_uncertainty/layers/mc_batch_norm.py | 6 +++--- torch_uncertainty/layers/modules.py | 1 + torch_uncertainty/losses.py | 4 ++-- torch_uncertainty/metrics/fpr95.py | 1 - torch_uncertainty/metrics/grouping_loss.py | 4 +++- torch_uncertainty/metrics/sparsification.py | 2 +- torch_uncertainty/models/deep_ensembles.py | 8 +++----- torch_uncertainty/models/lenet.py | 2 +- torch_uncertainty/models/mlp.py | 2 +- torch_uncertainty/models/resnet/batched.py | 2 +- torch_uncertainty/models/resnet/masked.py | 8 ++++---- torch_uncertainty/models/resnet/std.py | 1 + .../models/segmentation/segformer/std.py | 19 +++++++++---------- torch_uncertainty/models/wideresnet/std.py | 2 +- torch_uncertainty/optim_recipes.py | 2 +- .../post_processing/calibration/scaler.py | 2 +- .../post_processing/mc_batch_norm.py | 2 +- torch_uncertainty/routines/classification.py | 8 ++++---- torch_uncertainty/routines/regression.py | 8 ++++---- torch_uncertainty/routines/segmentation.py | 8 ++++---- torch_uncertainty/transforms/corruptions.py | 12 ++++++------ torch_uncertainty/transforms/mixup.py | 1 + torch_uncertainty/utils/distributions.py | 17 ++++++++++++----- torch_uncertainty/utils/misc.py | 2 +- 56 files changed, 113 insertions(+), 142 deletions(-) diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index b3872575..308a3f84 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -70,8 +70,8 @@ jobs: - name: Check style & format if: steps.changed-files-specific.outputs.only_changed != 'true' run: | - python3 -m ruff check torch_uncertainty tests --no-fix - python3 -m ruff format torch_uncertainty tests --check + python3 -m ruff check torch_uncertainty --no-fix + python3 -m ruff format torch_uncertainty --check - name: Test with pytest and compute coverage if: steps.changed-files-specific.outputs.only_changed != 'true' diff --git a/docs/source/conf.py b/docs/source/conf.py index 4dc73558..3a03317b 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -86,7 +86,3 @@ html_static_path = ["_static"] html_style = "css/custom.css" -# html_css_files = [ -# 'https://cdn.jsdelivr.net/npm/katex@0.10.0-beta/dist/katex.min.css', -# 'css/custom.css' -# ] diff --git a/experiments/classification/tiny-imagenet/resnet.py b/experiments/classification/tiny-imagenet/resnet.py index 77be476f..e003ae84 100644 --- a/experiments/classification/tiny-imagenet/resnet.py +++ b/experiments/classification/tiny-imagenet/resnet.py @@ -28,7 +28,6 @@ def optim_tiny(model: nn.Module) -> dict: else: root = Path(args.root) - # net_name = f"{args.version}-resnet{args.arch}-tiny-imagenet" if args.exp_name == "": args.exp_name = f"{args.version}-resnet{args.arch}-tinyimagenet" diff --git a/pyproject.toml b/pyproject.toml index d9a521d7..4d003ce5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,9 +76,12 @@ line-length = 80 target-version = "py310" lint.extend-select = [ "A", + "ARG", "B", "C4", "D", + "ERA", + "F", "G", "I", "ISC", @@ -88,19 +91,20 @@ lint.extend-select = [ "PIE", "PTH", "PYI", + "Q", "RET", "RUF", "RSE", "S", "SIM", - "UP", "TCH", "TID", "TRY", + "UP", "YTT", ] lint.ignore = [ - "B017", + "ARG002", "D100", "D101", "D102", diff --git a/tests/_dummies/dataset.py b/tests/_dummies/dataset.py index 4aa858ef..6e78b423 100644 --- a/tests/_dummies/dataset.py +++ b/tests/_dummies/dataset.py @@ -38,7 +38,6 @@ def __init__( image_size: int = 4, num_classes: int = 10, num_images: int = 2, - **kwargs: Any, ) -> None: self.root = root self.train = train # training set or test set @@ -112,7 +111,6 @@ def __init__( in_features: int = 3, out_features: int = 10, num_samples: int = 2, - **kwargs: Any, ) -> None: self.root = root self.train = train # training set or test set @@ -169,7 +167,6 @@ def __init__( image_size: int = 4, num_classes: int = 10, num_images: int = 2, - **kwargs: Any, ) -> None: super().__init__() diff --git a/tests/_dummies/model.py b/tests/_dummies/model.py index e0d3232c..2e29e2b5 100644 --- a/tests/_dummies/model.py +++ b/tests/_dummies/model.py @@ -16,6 +16,7 @@ def __init__( last_layer: nn.Module, ) -> None: super().__init__() + self.in_channels = in_channels self.dropout_rate = dropout_rate if with_linear: diff --git a/tests/baselines/test_masked.py b/tests/baselines/test_masked.py index 6989230a..3fd48ebf 100644 --- a/tests/baselines/test_masked.py +++ b/tests/baselines/test_masked.py @@ -45,7 +45,7 @@ def test_masked_50(self): _ = net(torch.rand(1, 3, 40, 40)) def test_masked_errors(self): - with pytest.raises(Exception): + with pytest.raises(ValueError): _ = ResNetBaseline( num_classes=10, in_channels=3, @@ -58,7 +58,7 @@ def test_masked_errors(self): groups=1, ) - with pytest.raises(Exception): + with pytest.raises(ValueError): _ = ResNetBaseline( num_classes=10, in_channels=3, diff --git a/tests/baselines/test_packed.py b/tests/baselines/test_packed.py index 34ed4a6f..c8331119 100644 --- a/tests/baselines/test_packed.py +++ b/tests/baselines/test_packed.py @@ -50,7 +50,7 @@ def test_packed_18(self): _ = net(torch.rand(1, 3, 40, 40)) def test_packed_exception(self): - with pytest.raises(Exception): + with pytest.raises(ValueError): _ = ResNetBaseline( num_classes=10, in_channels=3, @@ -64,7 +64,7 @@ def test_packed_exception(self): groups=1, ) - with pytest.raises(Exception): + with pytest.raises(ValueError): _ = ResNetBaseline( num_classes=10, in_channels=3, diff --git a/tests/datamodules/classification/test_cifar10_datamodule.py b/tests/datamodules/classification/test_cifar10_datamodule.py index 405dc6b9..df12f214 100644 --- a/tests/datamodules/classification/test_cifar10_datamodule.py +++ b/tests/datamodules/classification/test_cifar10_datamodule.py @@ -10,10 +10,6 @@ class TestCIFAR10DataModule: """Testing the CIFAR10DataModule datamodule class.""" def test_cifar10_main(self): - # parser = ArgumentParser() - # parser = CIFAR10DataModule.add_argparse_args(parser) - - # Simulate that cutout is set to 16 dm = CIFAR10DataModule(root="./data/", batch_size=128, cutout=16) assert dm.dataset == CIFAR10 @@ -67,8 +63,6 @@ def test_cifar10_main(self): dm.setup() dm.train_dataloader() - # args.cutout = 8 - # args.auto_augment = "rand-m9-n2-mstd0.5" with pytest.raises(ValueError): dm = CIFAR10DataModule( root="./data/", diff --git a/tests/layers/test_mask.py b/tests/layers/test_mask.py index 972d3f7f..bf8e2c2d 100644 --- a/tests/layers/test_mask.py +++ b/tests/layers/test_mask.py @@ -34,7 +34,7 @@ def test_linear_one_estimator(self, feat_input_odd: torch.Tensor): def test_linear_two_estimators_odd(self, feat_input_odd: torch.Tensor): layer = MaskedLinear(10, 2, num_estimators=2, scale=2) - with pytest.raises(Exception): + with pytest.raises(RuntimeError): _ = layer(feat_input_odd) def test_linear_two_estimators_even(self, feat_input_even: torch.Tensor): diff --git a/tests/metrics/test_brier_score.py b/tests/metrics/test_brier_score.py index 559109e3..b0d0b03f 100644 --- a/tests/metrics/test_brier_score.py +++ b/tests/metrics/test_brier_score.py @@ -193,10 +193,12 @@ def test_compute_3d_to_2d( assert metric.compute() == 0.5 def test_bad_input(self) -> None: - with pytest.raises(Exception): + with pytest.raises(ValueError): metric = BrierScore(num_classes=2, reduction="none") metric.update(torch.ones(2, 2, 2, 2), torch.ones(2, 2, 2, 2)) def test_bad_argument(self): - with pytest.raises(Exception): + with pytest.raises( + ValueError, match="Expected argument `reduction` to be one of" + ): _ = BrierScore(num_classes=2, reduction="geometric_mean") diff --git a/tests/metrics/test_entropy.py b/tests/metrics/test_entropy.py index 0239e13c..558184c4 100644 --- a/tests/metrics/test_entropy.py +++ b/tests/metrics/test_entropy.py @@ -83,5 +83,5 @@ def test_compute_3d_to_2d(self, vec3d: torch.Tensor): assert res == math.log(2) def test_bad_argument(self): - with pytest.raises(Exception): + with pytest.raises(ValueError): _ = Entropy("geometric_mean") diff --git a/tests/metrics/test_mutual_information.py b/tests/metrics/test_mutual_information.py index 14acf90f..99dde31d 100644 --- a/tests/metrics/test_mutual_information.py +++ b/tests/metrics/test_mutual_information.py @@ -55,5 +55,5 @@ def test_compute_mixed( assert res[1] == pytest.approx(math.log(2), 1e-5) def test_bad_argument(self): - with pytest.raises(Exception): + with pytest.raises(ValueError): _ = MutualInformation("geometric_mean") diff --git a/tests/metrics/test_nll.py b/tests/metrics/test_nll.py index bfd60bc9..316d6c93 100644 --- a/tests/metrics/test_nll.py +++ b/tests/metrics/test_nll.py @@ -28,7 +28,7 @@ def test_compute_zero(self) -> None: assert torch.all(res_sum == torch.zeros(1)) def test_bad_argument(self) -> None: - with pytest.raises(Exception): + with pytest.raises(ValueError): _ = CategoricalNLL(reduction="geometric_mean") diff --git a/tests/metrics/test_variation_ratio.py b/tests/metrics/test_variation_ratio.py index 263b25f9..10936f2d 100644 --- a/tests/metrics/test_variation_ratio.py +++ b/tests/metrics/test_variation_ratio.py @@ -52,5 +52,5 @@ def test_compute_disagreement( assert res == pytest.approx(0.8, 1e-6) def test_bad_argument(self): - with pytest.raises(Exception): + with pytest.raises(ValueError): _ = VariationRatio(reduction="geometric_mean") diff --git a/tests/test_utils.py b/tests/test_utils.py index fa89df36..0ce5c482 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,6 +2,7 @@ import pytest import torch +from huggingface_hub.utils._errors import RepositoryNotFoundError from torch.distributions import Laplace, Normal from torch_uncertainty.utils import ( @@ -23,7 +24,7 @@ def test_get_version_log_success(self): get_version("tests/testlog", version=42, checkpoint=45) def test_getversion_log_failure(self): - with pytest.raises(Exception): + with pytest.raises(FileNotFoundError): get_version("tests/testlog", version=52) @@ -36,7 +37,7 @@ def test_hub_exists(self): hub.load_hf("test", version=2) def test_hub_notexists(self): - with pytest.raises(Exception): + with pytest.raises(RepositoryNotFoundError): hub.load_hf("tests") with pytest.raises(ValueError): diff --git a/torch_uncertainty/baselines/classification/vgg.py b/torch_uncertainty/baselines/classification/vgg.py index f4b3d573..9c429ea1 100644 --- a/torch_uncertainty/baselines/classification/vgg.py +++ b/torch_uncertainty/baselines/classification/vgg.py @@ -199,5 +199,6 @@ def __init__( log_plots=log_plots, save_in_csv=save_in_csv, calibration_set=calibration_set, + eval_grouping_loss=eval_grouping_loss, ) self.save_hyperparameters(ignore=["loss"]) diff --git a/torch_uncertainty/baselines/classification/wideresnet.py b/torch_uncertainty/baselines/classification/wideresnet.py index 0f00b855..ffda0d48 100644 --- a/torch_uncertainty/baselines/classification/wideresnet.py +++ b/torch_uncertainty/baselines/classification/wideresnet.py @@ -61,7 +61,6 @@ def __init__( calibration_set: Literal["val", "test"] | None = None, eval_ood: bool = False, eval_grouping_loss: bool = False, - # pretrained: bool = False, ) -> None: r"""Wide-ResNet28x10 backbone baseline for classification providing support for various versions. diff --git a/torch_uncertainty/datamodules/uci_regression.py b/torch_uncertainty/datamodules/uci_regression.py index 0d8f96dc..66571959 100644 --- a/torch_uncertainty/datamodules/uci_regression.py +++ b/torch_uncertainty/datamodules/uci_regression.py @@ -65,6 +65,7 @@ def prepare_data(self) -> None: """Download the dataset.""" self.dataset(root=self.root, download=True) + # ruff: noqa: ARG002 def setup(self, stage: str | None = None) -> None: """Split the datasets into train, val, and test.""" full = self.dataset( @@ -82,12 +83,3 @@ def setup(self, stage: str | None = None) -> None: ) if self.val_split == 0: self.val = self.test - - # Change by default test_dataloader -> List[DataLoader] - # def test_dataloader(self) -> DataLoader: - # """Get the test dataloader for UCI Regression. - - # Return: - # DataLoader: UCI Regression test dataloader. - # """ - # return self._data_loader(self.test) diff --git a/torch_uncertainty/datasets/classification/cifar/cifar_c.py b/torch_uncertainty/datasets/classification/cifar/cifar_c.py index c98563e8..10f9f230 100644 --- a/torch_uncertainty/datasets/classification/cifar/cifar_c.py +++ b/torch_uncertainty/datasets/classification/cifar/cifar_c.py @@ -1,8 +1,8 @@ from collections.abc import Callable from pathlib import Path -from typing import Any import numpy as np +from torch import Tensor from torchvision.datasets import VisionDataset from torchvision.datasets.utils import ( check_integrity, @@ -169,7 +169,7 @@ def __len__(self) -> int: """The number of samples in the dataset.""" return self.labels.shape[0] - def __getitem__(self, index: int) -> Any: + def __getitem__(self, index: int) -> tuple[np.ndarray | Tensor, int]: """Get the samples and targets of the dataset. Args: diff --git a/torch_uncertainty/datasets/classification/cifar/cifar_h.py b/torch_uncertainty/datasets/classification/cifar/cifar_h.py index 4c0ae7f6..168f8571 100644 --- a/torch_uncertainty/datasets/classification/cifar/cifar_h.py +++ b/torch_uncertainty/datasets/classification/cifar/cifar_h.py @@ -70,9 +70,7 @@ def __init__( def _check_specific_integrity(self) -> bool: filename, md5 = self.h_test_list fpath = self.root / filename - if not check_integrity(fpath, md5): - return False - return True + return check_integrity(fpath, md5) def download_h(self) -> None: download_url( diff --git a/torch_uncertainty/datasets/classification/cifar/cifar_n.py b/torch_uncertainty/datasets/classification/cifar/cifar_n.py index 56a74704..069a081a 100644 --- a/torch_uncertainty/datasets/classification/cifar/cifar_n.py +++ b/torch_uncertainty/datasets/classification/cifar/cifar_n.py @@ -70,9 +70,7 @@ def __init__( def _check_specific_integrity(self) -> bool: filename, md5 = self.n_test_list fpath = self.root / filename - if not check_integrity(fpath, md5): - return False - return True + return check_integrity(fpath, md5) def download_n(self) -> None: download_and_extract_archive( @@ -124,9 +122,7 @@ def __init__( def _check_specific_integrity(self) -> bool: filename, md5 = self.n_test_list fpath = self.root / filename - if not check_integrity(fpath, md5): - return False - return True + return check_integrity(fpath, md5) def download_n(self) -> None: download_and_extract_archive( diff --git a/torch_uncertainty/datasets/classification/imagenet/base.py b/torch_uncertainty/datasets/classification/imagenet/base.py index a33ae285..891bfb9a 100644 --- a/torch_uncertainty/datasets/classification/imagenet/base.py +++ b/torch_uncertainty/datasets/classification/imagenet/base.py @@ -49,6 +49,7 @@ def __init__( self.download() self.root = Path(root) + self.split = split if not self._check_integrity(): raise RuntimeError( diff --git a/torch_uncertainty/datasets/classification/imagenet/tiny_imagenet.py b/torch_uncertainty/datasets/classification/imagenet/tiny_imagenet.py index 5dd08ffe..553fbd1b 100644 --- a/torch_uncertainty/datasets/classification/imagenet/tiny_imagenet.py +++ b/torch_uncertainty/datasets/classification/imagenet/tiny_imagenet.py @@ -51,7 +51,7 @@ def make_dataset(self) -> None: self.samples = samples self.label_data = torch.as_tensor(labels).long() - def _add_channels(self, img): + def _add_channels(self, img: np.ndarray) -> np.ndarray: while len(img.shape) < 3: # third axis is the channels img = np.expand_dims(img, axis=-1) while (img.shape[-1]) < 3: @@ -78,7 +78,7 @@ def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]: return sample, target - def _make_paths(self): + def _make_paths(self) -> list[tuple[Path, int]]: self.ids = [] with self.wnids_path.open() as idf: for nid in idf: @@ -103,7 +103,7 @@ def _make_paths(self): label_id = self.ids.index(nid) with anno_path.open() as annof: for line in annof: - fname, x0, y0, x1, y1 = line.split() + fname, _, _, _, _ = line.split() fname = imgs_path / fname paths.append((fname, label_id)) @@ -111,7 +111,7 @@ def _make_paths(self): val_path = self.root / "val" with (val_path / "val_annotations.txt").open() as valf: for line in valf: - fname, nid, x0, y0, x1, y1 = line.split() + fname, nid, _, _, _, _ = line.split() fname = val_path / "images" / fname label_id = self.ids.index(nid) paths.append((fname, label_id)) diff --git a/torch_uncertainty/datasets/classification/mnist_c.py b/torch_uncertainty/datasets/classification/mnist_c.py index 3a72e404..65febcf9 100644 --- a/torch_uncertainty/datasets/classification/mnist_c.py +++ b/torch_uncertainty/datasets/classification/mnist_c.py @@ -163,9 +163,7 @@ def __getitem__(self, index: int) -> Any: def _check_integrity(self) -> bool: """Check the integrity of the dataset.""" fpath = self.root / self.filename - if not check_integrity(fpath, self.zip_md5): - return False - return True + return check_integrity(fpath, self.zip_md5) def download(self) -> None: """Download the dataset.""" diff --git a/torch_uncertainty/datasets/classification/openimage_o.py b/torch_uncertainty/datasets/classification/openimage_o.py index 4ec40f50..2cc9104a 100644 --- a/torch_uncertainty/datasets/classification/openimage_o.py +++ b/torch_uncertainty/datasets/classification/openimage_o.py @@ -46,6 +46,7 @@ def __init__( Wang H., et al. In CVPR 2022. """ self.root = Path(root) + self.split = split self.transform = transform self.target_transform = target_transform @@ -78,4 +79,3 @@ def download(self) -> None: filename=self.filename, md5=self.md5sum, ) - print(f"Downloaded {self.filename} to {self.root}") diff --git a/torch_uncertainty/datasets/muad.py b/torch_uncertainty/datasets/muad.py index 0f5f2ed6..477d9116 100644 --- a/torch_uncertainty/datasets/muad.py +++ b/torch_uncertainty/datasets/muad.py @@ -182,7 +182,7 @@ def _make_dataset(self, path: Path) -> None: "if you need it." ) - def _download(self, split: str): + def _download(self, split: str) -> None: """Download and extract the chosen split of the dataset.""" split_url = self.base_url + split + ".zip" download_and_extract_archive( diff --git a/torch_uncertainty/datasets/regression/uci_regression.py b/torch_uncertainty/datasets/regression/uci_regression.py index ab3fd160..59722abc 100644 --- a/torch_uncertainty/datasets/regression/uci_regression.py +++ b/torch_uncertainty/datasets/regression/uci_regression.py @@ -174,11 +174,11 @@ def _check_integrity(self) -> bool: self.md5, ) - def _standardize(self): + def _standardize(self) -> None: self.data = (self.data - self.data_mean) / self.data_std self.targets = (self.targets - self.target_mean) / self.target_std - def _compute_statistics(self): + def _compute_statistics(self) -> None: self.data_mean = self.data.mean(axis=0) self.data_std = self.data.std(axis=0) self.data_std[self.data_std == 0] = 1 @@ -253,8 +253,6 @@ def _make_dataset(self) -> None: ) # convert Ex to 10^x and remove second target array = df.apply(pd.to_numeric, errors="coerce").to_numpy()[:, :-1] - # elif self.dataset_name == "power-plant": - # array = pd.read_excel(path / "Folds5x2_pp.xlsx").to_numpy() elif self.dataset_name == "protein": array = pd.read_csv( path / "CASP.csv", diff --git a/torch_uncertainty/datasets/segmentation/camvid.py b/torch_uncertainty/datasets/segmentation/camvid.py index 13d94fee..5a25c821 100644 --- a/torch_uncertainty/datasets/segmentation/camvid.py +++ b/torch_uncertainty/datasets/segmentation/camvid.py @@ -129,7 +129,6 @@ def __init__( ] ) - # self.transforms = transforms self.split = split if split is not None else "all" def encode_target(self, target: Image.Image) -> torch.Tensor: @@ -215,10 +214,7 @@ def _check_integrity(self) -> bool: ): return False - if not (Path(self.root) / "camvid" / "splits.json").exists(): - return False - - return True + return (Path(self.root) / "camvid" / "splits.json").exists() def download(self) -> None: """Download the CamVid data if it doesn't exist already.""" diff --git a/torch_uncertainty/layers/batch_ensemble.py b/torch_uncertainty/layers/batch_ensemble.py index 10b83a62..b4169ae5 100644 --- a/torch_uncertainty/layers/batch_ensemble.py +++ b/torch_uncertainty/layers/batch_ensemble.py @@ -110,10 +110,6 @@ def __init__( self.reset_parameters() def reset_parameters(self) -> None: - # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with - # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see - # https://github.com/pytorch/pytorch/issues/57109 - # nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) nn.init.normal_(self.r_group, mean=1.0, std=0.5) nn.init.normal_(self.s_group, mean=1.0, std=0.5) if self.bias is not None: @@ -335,12 +331,6 @@ def __init__( self.reset_parameters() def reset_parameters(self) -> None: - # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with - # uniform(-1/sqrt(k), 1/sqrt(k)), where - # k = weight.size(1) * prod(*kernel_size) - # For more details see: - # https://github.com/pytorch/pytorch/issues/15314#issuecomment-477448573 - # nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) nn.init.normal_(self.r_group, mean=1.0, std=0.5) nn.init.normal_(self.s_group, mean=1.0, std=0.5) if self.bias is not None: @@ -408,7 +398,7 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: bias if bias is not None else 0 ) - def extra_repr(self): + def extra_repr(self) -> str: s = ( "{in_channels}, {out_channels}, kernel_size={kernel_size}" ", num_estimators={num_estimators}, stride={stride}" diff --git a/torch_uncertainty/layers/bayesian/bayes_conv.py b/torch_uncertainty/layers/bayesian/bayes_conv.py index 95060560..d9fc4df4 100644 --- a/torch_uncertainty/layers/bayesian/bayes_conv.py +++ b/torch_uncertainty/layers/bayesian/bayes_conv.py @@ -178,7 +178,7 @@ def sample(self) -> tuple[Tensor, Tensor | None]: bias = self.bias_sampler.sample() if self.bias_mu is not None else None return weight, bias - def extra_repr(self): # coverage: ignore + def extra_repr(self) -> str: # coverage: ignore s = ( "{in_channels}, {out_channels}, kernel_size={kernel_size}" ", stride={stride}" @@ -197,7 +197,7 @@ def extra_repr(self): # coverage: ignore s += ", padding_mode={padding_mode}" return s.format(**self.__dict__) - def __setstate__(self, state): + def __setstate__(self, state) -> None: super().__setstate__(state) if not hasattr(self, "padding_mode"): # coverage: ignore self.padding_mode = "zeros" diff --git a/torch_uncertainty/layers/bayesian/bayes_linear.py b/torch_uncertainty/layers/bayesian/bayes_linear.py index f722f842..074f8554 100644 --- a/torch_uncertainty/layers/bayesian/bayes_linear.py +++ b/torch_uncertainty/layers/bayesian/bayes_linear.py @@ -113,7 +113,7 @@ def forward(self, inputs: Tensor) -> Tensor: return self._frozen_forward(inputs) return self._forward(inputs) - def _frozen_forward(self, inputs): + def _frozen_forward(self, inputs) -> Tensor: return F.linear(inputs, self.weight_mu, self.bias_mu) def _forward(self, inputs: Tensor) -> Tensor: diff --git a/torch_uncertainty/layers/mc_batch_norm.py b/torch_uncertainty/layers/mc_batch_norm.py index 1dd5907a..9a68e633 100644 --- a/torch_uncertainty/layers/mc_batch_norm.py +++ b/torch_uncertainty/layers/mc_batch_norm.py @@ -101,7 +101,7 @@ class MCBatchNorm1d(_MCBatchNorm): Check MCBatchNorm in torch_uncertainty/post_processing/. """ - def _check_input_dim(self, inputs): + def _check_input_dim(self, inputs) -> None: if inputs.dim() != 2 and inputs.dim() != 3: raise ValueError( f"expected 2D or 3D input (got {inputs.dim()}D input)" @@ -127,7 +127,7 @@ class MCBatchNorm2d(_MCBatchNorm): Check MCBatchNorm in torch_uncertainty/post_processing/. """ - def _check_input_dim(self, inputs): + def _check_input_dim(self, inputs) -> None: if inputs.dim() != 3 and inputs.dim() != 4: raise ValueError( f"expected 3D or 4D input (got {inputs.dim()}D input)" @@ -153,7 +153,7 @@ class MCBatchNorm3d(_MCBatchNorm): Check MCBatchNorm in torch_uncertainty/post_processing/. """ - def _check_input_dim(self, inputs): + def _check_input_dim(self, inputs) -> None: if inputs.dim() != 4 and inputs.dim() != 5: raise ValueError( f"expected 4D or 5D input (got {inputs.dim()}D input)" diff --git a/torch_uncertainty/layers/modules.py b/torch_uncertainty/layers/modules.py index c2e9a6e3..a5b9b18e 100644 --- a/torch_uncertainty/layers/modules.py +++ b/torch_uncertainty/layers/modules.py @@ -4,6 +4,7 @@ class Identity(nn.Module): + # ruff: noqa: ARG002 def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__() diff --git a/torch_uncertainty/losses.py b/torch_uncertainty/losses.py index ed552924..55aeb91a 100644 --- a/torch_uncertainty/losses.py +++ b/torch_uncertainty/losses.py @@ -120,7 +120,7 @@ def set_model(self, model: nn.Module) -> None: def _elbo_loss_checks( inner_loss: nn.Module, kl_weight: float, num_samples: int -): +) -> None: if isinstance(inner_loss, type): raise TypeError( "The inner_loss should be an instance of a class." @@ -375,7 +375,7 @@ def forward( raise NotImplementedError( "DECLoss does not yet support mixup/cutmix." ) - # else: # TODO: handle binary + # TODO: handle binary targets = F.one_hot(targets, num_classes=evidence.size()[-1]) if self.loss_type == "mse": diff --git a/torch_uncertainty/metrics/fpr95.py b/torch_uncertainty/metrics/fpr95.py index 1f351aed..2108406c 100644 --- a/torch_uncertainty/metrics/fpr95.py +++ b/torch_uncertainty/metrics/fpr95.py @@ -82,7 +82,6 @@ def compute(self) -> Tensor: in_scores = conf[np.logical_not(out_labels)] out_scores = conf[out_labels] - # pos = OOD neg = np.array(in_scores[:]).reshape((-1, 1)) pos = np.array(out_scores[:]).reshape((-1, 1)) examples = np.squeeze(np.vstack((pos, neg))) diff --git a/torch_uncertainty/metrics/grouping_loss.py b/torch_uncertainty/metrics/grouping_loss.py index 062bff45..da9eab41 100644 --- a/torch_uncertainty/metrics/grouping_loss.py +++ b/torch_uncertainty/metrics/grouping_loss.py @@ -7,7 +7,9 @@ class GLEstimator(GLEstimatorBase): - def fit(self, probs: Tensor, targets: Tensor, features: Tensor): + def fit( + self, probs: Tensor, targets: Tensor, features: Tensor + ) -> "GLEstimator": probs = probs.detach().cpu().numpy() features = features.detach().cpu().numpy() targets = (targets * 1).detach().cpu().numpy() diff --git a/torch_uncertainty/metrics/sparsification.py b/torch_uncertainty/metrics/sparsification.py index a647f182..822cc0cf 100644 --- a/torch_uncertainty/metrics/sparsification.py +++ b/torch_uncertainty/metrics/sparsification.py @@ -19,7 +19,7 @@ class AUSE(Metric): scores: list[Tensor] errors: list[Tensor] - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: """The Area Under the Sparsification Error curve (AUSE) metric to estimate the quality of the uncertainty estimates, i.e., how much they coincide with the true errors. diff --git a/torch_uncertainty/models/deep_ensembles.py b/torch_uncertainty/models/deep_ensembles.py index a76beb99..637e43d6 100644 --- a/torch_uncertainty/models/deep_ensembles.py +++ b/torch_uncertainty/models/deep_ensembles.py @@ -62,7 +62,7 @@ def deep_ensembles( task: Literal[ "classification", "regression", "segmentation" ] = "classification", - probabilistic=None, + probabilistic: bool | None = None, reset_model_parameters: bool = False, ) -> _DeepEnsembles: """Build a Deep Ensembles out of the original models. @@ -91,10 +91,8 @@ def deep_ensembles( Simple and scalable predictive uncertainty estimation using deep ensembles. In NeurIPS, 2017. """ - if ( - isinstance(models, list) - and len(models) == 1 - or isinstance(models, nn.Module) + if (isinstance(models, list) and len(models) == 1) or isinstance( + models, nn.Module ): if num_estimators is None: raise ValueError( diff --git a/torch_uncertainty/models/lenet.py b/torch_uncertainty/models/lenet.py index 61832ce7..fcb9663e 100644 --- a/torch_uncertainty/models/lenet.py +++ b/torch_uncertainty/models/lenet.py @@ -10,7 +10,7 @@ from torch_uncertainty.layers.packed import PackedConv2d, PackedLinear from torch_uncertainty.models.utils import stochastic_model -__all__ = ["lenet", "packed_lenet", "bayesian_lenet"] +__all__ = ["bayesian_lenet", "lenet", "packed_lenet"] class _LeNet(nn.Module): diff --git a/torch_uncertainty/models/mlp.py b/torch_uncertainty/models/mlp.py index deac7f23..a822343d 100644 --- a/torch_uncertainty/models/mlp.py +++ b/torch_uncertainty/models/mlp.py @@ -7,7 +7,7 @@ from torch_uncertainty.layers.packed import PackedLinear from torch_uncertainty.models.utils import stochastic_model -__all__ = ["mlp", "packed_mlp", "bayesian_mlp"] +__all__ = ["bayesian_mlp", "mlp", "packed_mlp"] class _MLP(nn.Module): diff --git a/torch_uncertainty/models/resnet/batched.py b/torch_uncertainty/models/resnet/batched.py index fcdc1a5b..52b3fc1f 100644 --- a/torch_uncertainty/models/resnet/batched.py +++ b/torch_uncertainty/models/resnet/batched.py @@ -300,7 +300,7 @@ def _make_layer( self.in_planes = planes * block.expansion return nn.Sequential(*layers) - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: out = x.repeat(self.num_estimators, 1, 1, 1) out = F.relu(self.bn1(self.conv1(out))) out = self.optional_pool(out) diff --git a/torch_uncertainty/models/resnet/masked.py b/torch_uncertainty/models/resnet/masked.py index 8c8ce3c3..ea61606d 100644 --- a/torch_uncertainty/models/resnet/masked.py +++ b/torch_uncertainty/models/resnet/masked.py @@ -105,7 +105,7 @@ def __init__( num_estimators=num_estimators, scale=scale, groups=groups, - bias=False, + bias=conv_bias, ) self.bn1 = normalization_layer(planes) self.conv2 = MaskedConv2d( @@ -117,7 +117,7 @@ def __init__( stride=stride, padding=1, groups=groups, - bias=False, + bias=conv_bias, ) self.dropout = nn.Dropout2d(p=dropout_rate) self.bn2 = normalization_layer(planes) @@ -128,7 +128,7 @@ def __init__( num_estimators=num_estimators, scale=scale, groups=groups, - bias=False, + bias=conv_bias, ) self.bn3 = normalization_layer(self.expansion * planes) @@ -143,7 +143,7 @@ def __init__( scale=scale, stride=stride, groups=groups, - bias=False, + bias=conv_bias, ), normalization_layer(self.expansion * planes), ) diff --git a/torch_uncertainty/models/resnet/std.py b/torch_uncertainty/models/resnet/std.py index 148740d1..0eeea7ba 100644 --- a/torch_uncertainty/models/resnet/std.py +++ b/torch_uncertainty/models/resnet/std.py @@ -143,6 +143,7 @@ def forward(self, x: Tensor) -> Tensor: return self.activation_fn(out) +# ruff: noqa: ERA001 # class Robust_Bottleneck(nn.Module): # """Robust _Bottleneck from "Can CNNs be more robust than transformers?" # This corresponds to ResNet-Up-Inverted-DW in the paper. diff --git a/torch_uncertainty/models/segmentation/segformer/std.py b/torch_uncertainty/models/segmentation/segformer/std.py index 83c86e02..3881e055 100644 --- a/torch_uncertainty/models/segmentation/segformer/std.py +++ b/torch_uncertainty/models/segmentation/segformer/std.py @@ -471,11 +471,10 @@ def forward_features(self, x): def forward(self, x): return self.forward_features(x) - # x = self.head(x) class MitB0(MixVisionTransformer): - def __init__(self, **kwargs): + def __init__(self): super().__init__( patch_size=4, embed_dims=[32, 64, 160, 256], @@ -491,7 +490,7 @@ def __init__(self, **kwargs): class MitB1(MixVisionTransformer): - def __init__(self, **kwargs): + def __init__(self): super().__init__( patch_size=4, embed_dims=[64, 128, 320, 512], @@ -507,7 +506,7 @@ def __init__(self, **kwargs): class MitB2(MixVisionTransformer): - def __init__(self, **kwargs): + def __init__(self): super().__init__( patch_size=4, embed_dims=[64, 128, 320, 512], @@ -523,7 +522,7 @@ def __init__(self, **kwargs): class MitB3(MixVisionTransformer): - def __init__(self, **kwargs): + def __init__(self): super().__init__( patch_size=4, embed_dims=[64, 128, 320, 512], @@ -539,7 +538,7 @@ def __init__(self, **kwargs): class MitB4(MixVisionTransformer): - def __init__(self, **kwargs): + def __init__(self): super().__init__( patch_size=4, embed_dims=[64, 128, 320, 512], @@ -555,7 +554,7 @@ def __init__(self, **kwargs): class MitB5(MixVisionTransformer): - def __init__(self, **kwargs): + def __init__(self): super().__init__( patch_size=4, embed_dims=[64, 128, 320, 512], @@ -629,7 +628,7 @@ def __init__( assert min(feature_strides) == feature_strides[0] self.feature_strides = feature_strides self.num_classes = num_classes - # self.in_index = [0, 1, 2, 3], + # --- self in_index [0, 1, 2, 3] ( c1_in_channels, @@ -671,10 +670,10 @@ def __init__( self.dropout = None def forward(self, inputs): - # x = [inputs[i] for i in self.in_index] # len=4, 1/4,1/8,1/16,1/32 + # x [inputs[i] for i in self.in_index] # len=4, 1/4,1/8,1/16,1/32 c1, c2, c3, c4 = inputs[0], inputs[1], inputs[2], inputs[3] - n, _, h, w = c4.shape + n, _, _, _ = c4.shape _c4 = ( self.linear_c4(c4) diff --git a/torch_uncertainty/models/wideresnet/std.py b/torch_uncertainty/models/wideresnet/std.py index 3c943eaa..3e14b2c8 100644 --- a/torch_uncertainty/models/wideresnet/std.py +++ b/torch_uncertainty/models/wideresnet/std.py @@ -163,7 +163,7 @@ def _wide_layer( num_blocks: int, dropout_rate: float, stride: int, - groups, + groups: int, ) -> nn.Module: strides = [stride] + [1] * (int(num_blocks) - 1) layers = [] diff --git a/torch_uncertainty/optim_recipes.py b/torch_uncertainty/optim_recipes.py index 7ee39586..a413b02c 100644 --- a/torch_uncertainty/optim_recipes.py +++ b/torch_uncertainty/optim_recipes.py @@ -10,8 +10,8 @@ "optim_cifar10_resnet18", "optim_cifar10_resnet34", "optim_cifar10_resnet50", - "optim_cifar10_wideresnet", "optim_cifar10_vgg16", + "optim_cifar10_wideresnet", "optim_cifar100_resnet18", "optim_cifar100_resnet34", "optim_cifar100_resnet50", diff --git a/torch_uncertainty/post_processing/calibration/scaler.py b/torch_uncertainty/post_processing/calibration/scaler.py index c2d0c368..f4141214 100644 --- a/torch_uncertainty/post_processing/calibration/scaler.py +++ b/torch_uncertainty/post_processing/calibration/scaler.py @@ -96,7 +96,7 @@ def calib_eval() -> float: def forward(self, inputs: Tensor) -> Tensor: if not self.trained: print( - "TemperatureScaler has not been trained yet. Returning a " + "TemperatureScaler has not been trained yet. Returning " "manually tempered inputs." ) return self._scale(self.model(inputs)) diff --git a/torch_uncertainty/post_processing/mc_batch_norm.py b/torch_uncertainty/post_processing/mc_batch_norm.py index ec2276dd..d99fdd7c 100644 --- a/torch_uncertainty/post_processing/mc_batch_norm.py +++ b/torch_uncertainty/post_processing/mc_batch_norm.py @@ -62,7 +62,7 @@ def __init__( "model does not contain any MCBatchNorm2d after conversion." ) - def fit(self, dataset: Dataset): + def fit(self, dataset: Dataset) -> None: """Fit the model on the dataset. Args: diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 61e76f8a..23877f06 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -265,12 +265,12 @@ def init_mixup( ) return Identity() - def configure_optimizers(self): + def configure_optimizers(self) -> Optimizer | dict: return self.optim_recipe def on_train_start(self) -> None: - init_metrics = {k: 0 for k in self.val_cls_metrics} - init_metrics.update({k: 0 for k in self.test_cls_metrics}) + init_metrics = dict.fromkeys(self.val_cls_metrics, 0) + init_metrics.update(dict.fromkeys(self.test_cls_metrics, 0)) if self.logger is not None: # coverage: ignore self.logger.log_hyperparams( @@ -579,7 +579,7 @@ def _classification_routine_checks( num_estimators: int, ood_criterion: str, eval_grouping_loss: bool, -): +) -> None: if not isinstance(num_estimators, int) or num_estimators < 1: raise ValueError( "The number of estimators must be a positive integer >= 1." diff --git a/torch_uncertainty/routines/regression.py b/torch_uncertainty/routines/regression.py index 1cefa4f7..59af9e92 100644 --- a/torch_uncertainty/routines/regression.py +++ b/torch_uncertainty/routines/regression.py @@ -79,12 +79,12 @@ def __init__( self.one_dim_regression = output_dim == 1 - def configure_optimizers(self): + def configure_optimizers(self) -> Optimizer | dict: return self.optim_recipe def on_train_start(self) -> None: - init_metrics = {k: 0 for k in self.val_metrics} - init_metrics.update({k: 0 for k in self.test_metrics}) + init_metrics = dict.fromkeys(self.val_metrics, 0) + init_metrics.update(dict.fromkeys(self.test_metrics, 0)) if self.logger is not None: # coverage: ignore self.logger.log_hyperparams( @@ -204,7 +204,7 @@ def on_test_epoch_end(self) -> None: self.test_metrics.reset() -def _regression_routine_checks(num_estimators, output_dim): +def _regression_routine_checks(num_estimators: int, output_dim: int) -> None: if num_estimators < 1: raise ValueError( f"num_estimators must be positive, got {num_estimators}." diff --git a/torch_uncertainty/routines/segmentation.py b/torch_uncertainty/routines/segmentation.py index d9e28065..75b86770 100644 --- a/torch_uncertainty/routines/segmentation.py +++ b/torch_uncertainty/routines/segmentation.py @@ -62,15 +62,15 @@ def __init__( self.val_seg_metrics = seg_metrics.clone(prefix="val/") self.test_seg_metrics = seg_metrics.clone(prefix="test/") - def configure_optimizers(self): + def configure_optimizers(self) -> Optimizer | dict: return self.optim_recipe def forward(self, img: Tensor) -> Tensor: return self.model(img) def on_train_start(self) -> None: - init_metrics = {k: 0 for k in self.val_seg_metrics} - init_metrics.update({k: 0 for k in self.test_seg_metrics}) + init_metrics = dict.fromkeys(self.val_seg_metrics, 0) + init_metrics.update(dict.fromkeys(self.test_seg_metrics, 0)) self.logger.log_hyperparams(self.hparams, init_metrics) @@ -131,7 +131,7 @@ def on_test_epoch_end(self) -> None: self.test_seg_metrics.reset() -def _segmentation_routine_checks(num_estimators, num_classes): +def _segmentation_routine_checks(num_estimators: int, num_classes: int) -> None: if num_estimators < 1: raise ValueError( f"num_estimators must be positive, got {num_estimators}." diff --git a/torch_uncertainty/transforms/corruptions.py b/torch_uncertainty/transforms/corruptions.py index 193e5898..5235c6fe 100644 --- a/torch_uncertainty/transforms/corruptions.py +++ b/torch_uncertainty/transforms/corruptions.py @@ -24,16 +24,16 @@ from torch_uncertainty.datasets import FrostImages __all__ = [ - "GaussianNoise", - "ShotNoise", - "ImpulseNoise", - "SpeckleNoise", + "DefocusBlur", + "Frost", "GaussianBlur", + "GaussianNoise", "GlassBlur", - "DefocusBlur", + "ImpulseNoise", "JPEGCompression", "Pixelate", - "Frost", + "ShotNoise", + "SpeckleNoise", ] diff --git a/torch_uncertainty/transforms/mixup.py b/torch_uncertainty/transforms/mixup.py index 4447b2bc..64c57cac 100644 --- a/torch_uncertainty/transforms/mixup.py +++ b/torch_uncertainty/transforms/mixup.py @@ -16,6 +16,7 @@ def sim_gauss_kernel(dist, tau_max: float = 1.0, tau_std: float = 0.5) -> float: return 1 / (dist_rate + 1e-12) +# ruff: noqa: ERA001 # def tensor_linspace(start: Tensor, stop: Tensor, num: int): # """ # Creates a tensor of shape [num, *start.shape] whose values are evenly diff --git a/torch_uncertainty/utils/distributions.py b/torch_uncertainty/utils/distributions.py index 62a66bdd..5cbdb5a5 100644 --- a/torch_uncertainty/utils/distributions.py +++ b/torch_uncertainty/utils/distributions.py @@ -116,7 +116,14 @@ class NormalInverseGamma(Distribution): support = constraints.real has_rsample = False - def __init__(self, loc, lmbda, alpha, beta, validate_args=None): + def __init__( + self, + loc: Number | Tensor, + lmbda: Number | Tensor, + alpha: Number | Tensor, + beta: Number | Tensor, + validate_args: bool | None = None, + ) -> None: self.loc, self.lmbda, self.alpha, self.beta = broadcast_all( loc, lmbda, alpha, beta ) @@ -132,7 +139,7 @@ def __init__(self, loc, lmbda, alpha, beta, validate_args=None): super().__init__(batch_shape, validate_args=validate_args) @property - def mean(self): + def mean(self) -> Tensor: """Impromper mean of the NormalInverseGamma distribution. This value is necessary to perform point-wise predictions in the @@ -140,17 +147,17 @@ def mean(self): """ return self.loc - def mode(self): + def mode(self) -> None: raise NotImplementedError( "Mode is not meaningful for the NormalInverseGamma distribution" ) - def stddev(self): + def stddev(self) -> None: raise NotImplementedError( "Standard deviation is not meaningful for the NormalInverseGamma distribution" ) - def variance(self): + def variance(self) -> None: raise NotImplementedError( "Variance is not meaningful for the NormalInverseGamma distribution" ) diff --git a/torch_uncertainty/utils/misc.py b/torch_uncertainty/utils/misc.py index 9a134ccf..ab5d697d 100644 --- a/torch_uncertainty/utils/misc.py +++ b/torch_uncertainty/utils/misc.py @@ -75,7 +75,7 @@ def create_train_val_split( dataset: Dataset, val_split_rate: float, val_transforms: Callable | None = None, -): +) -> tuple[Dataset, Dataset]: train, val = random_split(dataset, [1 - val_split_rate, val_split_rate]) val = copy.deepcopy(val) val.dataset.transform = val_transforms From ff2974407d9ec6ae2c2dc949f2f76cc04cadcd26 Mon Sep 17 00:00:00 2001 From: Olivier Date: Sun, 24 Mar 2024 11:19:50 +0100 Subject: [PATCH 117/148] :bug: Fix tests --- tests/_dummies/dataset.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/_dummies/dataset.py b/tests/_dummies/dataset.py index 6e78b423..a227ecdd 100644 --- a/tests/_dummies/dataset.py +++ b/tests/_dummies/dataset.py @@ -38,6 +38,7 @@ def __init__( image_size: int = 4, num_classes: int = 10, num_images: int = 2, + **args, ) -> None: self.root = root self.train = train # training set or test set @@ -111,6 +112,7 @@ def __init__( in_features: int = 3, out_features: int = 10, num_samples: int = 2, + **args, ) -> None: self.root = root self.train = train # training set or test set @@ -167,6 +169,7 @@ def __init__( image_size: int = 4, num_classes: int = 10, num_images: int = 2, + **args, ) -> None: super().__init__() From fe9e1d03061e49cbc8c121ddcbf22dd7b35c4848 Mon Sep 17 00:00:00 2001 From: Olivier Date: Sun, 24 Mar 2024 12:25:16 +0100 Subject: [PATCH 118/148] :bug: Remove randomness from GL tests --- tests/metrics/test_grouping_loss.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/tests/metrics/test_grouping_loss.py b/tests/metrics/test_grouping_loss.py index 7d2fec31..ffce34b1 100644 --- a/tests/metrics/test_grouping_loss.py +++ b/tests/metrics/test_grouping_loss.py @@ -10,25 +10,27 @@ class TestGroupingLoss: def test_compute(self): metric = GroupingLoss() metric.update( - torch.rand(100), - (torch.rand(100) > 0.3).long(), - torch.rand((100, 10)), + torch.cat([torch.tensor([0, 1, 0, 1]), torch.ones(200) / 10]), + torch.cat( + [torch.tensor([0, 0, 1, 1]), torch.zeros(100), torch.ones(100)] + ).long(), + torch.cat([torch.zeros((104, 10)), torch.ones((100, 10))]), ) metric.compute() metric = GroupingLoss() metric.update( - torch.ones((100, 4, 10)) / 10, - torch.arange(100), - torch.rand((100, 4, 10)), + torch.ones((200, 4, 10)), + torch.cat([torch.arange(100), torch.arange(100)]), + torch.cat([torch.zeros((100, 4, 10)), torch.ones((100, 4, 10))]), ) metric.compute() metric.reset() metric.update( - torch.ones((100, 10)) / 10, - torch.nn.functional.one_hot(torch.arange(100)), - torch.rand((100, 10)), + torch.ones((200, 10)) / 10, + torch.nn.functional.one_hot(torch.arange(200)), + torch.cat([torch.zeros((100, 10)), torch.ones((1004, 10))]), ) def test_errors(self): From 2b3704bec0e6d741b07f989aa0be4b772a446f07 Mon Sep 17 00:00:00 2001 From: Olivier Date: Sun, 24 Mar 2024 12:57:07 +0100 Subject: [PATCH 119/148] :books: Slightly improve documentation --- .github/workflows/build-docs.yml | 2 +- .github/workflows/run-tests.yml | 2 +- README.md | 16 +++--- docs/source/index.rst | 4 +- docs/source/quickstart.rst | 89 ++++++++++++++++++++++++++++++-- 5 files changed, 97 insertions(+), 16 deletions(-) diff --git a/.github/workflows/build-docs.yml b/.github/workflows/build-docs.yml index d8db9c81..f28dcb77 100644 --- a/.github/workflows/build-docs.yml +++ b/.github/workflows/build-docs.yml @@ -30,7 +30,7 @@ jobs: run: | echo "PYTHON_VERSION=$(python -c "import platform; print(platform.python_version())")" - - name: Cache folder for Torch Uncertainty + - name: Cache folder for TorchUncertainty uses: actions/cache@v3 id: cache-folder with: diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index 308a3f84..c404a106 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -52,7 +52,7 @@ jobs: LICENSE .gitignore - - name: Cache folder for Torch Uncertainty + - name: Cache folder for TorchUncertainty if: steps.changed-files-specific.outputs.only_changed != 'true' uses: actions/cache@v4 id: cache-folder diff --git a/README.md b/README.md index a49c775c..f84e4d3e 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@
-![Torch Uncertainty Logo](https://github.com/ENSTA-U2IS-AI/torch-uncertainty/blob/main/docs/source/_static/images/torch_uncertainty.png) +![TorchUncertaintyLogo](https://github.com/ENSTA-U2IS-AI/torch-uncertainty/blob/main/docs/source/_static/images/torch_uncertainty.png) [![pypi](https://img.shields.io/pypi/v/torch_uncertainty.svg)](https://pypi.python.org/pypi/torch_uncertainty) [![tests](https://github.com/ENSTA-U2IS-AI/torch-uncertainty/actions/workflows/run-tests.yml/badge.svg?branch=main&event=push)](https://github.com/ENSTA-U2IS-AI/torch-uncertainty/actions/workflows/run-tests.yml) @@ -15,6 +15,8 @@ _TorchUncertainty_ is a package designed to help you leverage [uncertainty quant :construction: _TorchUncertainty_ is in early development :construction: - expect changes, but reach out and contribute if you are interested in the project! **Please raise an issue if you have any bugs or difficulties and join the [discord server](https://discord.gg/HMCawt5MJu).** +Our webpage and documentation is available here: [torch-uncertainty.github.io](https://torch-uncertainty.github.io). + --- This package provides a multi-level API, including: @@ -38,13 +40,11 @@ pip install torch-uncertainty The installation procedure for contributors is different: have a look at the [contribution page](https://torch-uncertainty.github.io/contributing.html). -## Getting Started and Documentation - -Please find the documentation at [torch-uncertainty.github.io](https://torch-uncertainty.github.io). +## :racehorse: Quckstart -A quickstart is available at [torch-uncertainty.github.io/quickstart](https://torch-uncertainty.github.io/quickstart.html). +We make a quickstart available at [torch-uncertainty.github.io/quickstart](https://torch-uncertainty.github.io/quickstart.html). -## Implemented methods +## :books: Implemented methods TorchUncertainty currently supports **Classification**, **probabilistic** and pointwise **Regression** and **Segmentation**. @@ -57,7 +57,7 @@ To date, the following deep learning baselines have been implemented: - BatchEnsemble - Masksembles - MIMO -- Packed-Ensembles (see [blog post](https://medium.com/@adrien.lafage/make-your-neural-networks-more-reliable-with-packed-ensembles-7ad0b737a873)) - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_pe_cifar10.html) +- Packed-Ensembles (see [Blog post](https://medium.com/@adrien.lafage/make-your-neural-networks-more-reliable-with-packed-ensembles-7ad0b737a873)) - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_pe_cifar10.html) - Bayesian Neural Networks :construction: Work in progress :construction: - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_bayesian.html) - Regression with Beta Gaussian NLL Loss - Deep Evidential Classification & Regression - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_evidential_classification.html) @@ -77,7 +77,7 @@ To date, the following post-processing methods have been implemented: ## Tutorials -The following tutorials willWe provide the following tutorials in our documentation: +Our documentation contains the following tutorials: - [From a Standard Classifier to a Packed-Ensemble](https://torch-uncertainty.github.io/auto_tutorials/tutorial_pe_cifar10.html) - [Training a Bayesian Neural Network in 3 minutes](https://torch-uncertainty.github.io/auto_tutorials/tutorial_bayesian.html) diff --git a/docs/source/index.rst b/docs/source/index.rst index aebd2502..ff18d0a6 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -1,9 +1,9 @@ -.. Torch Uncertainty documentation master file, created by +.. TorchUncertainty documentation master file, created by sphinx-quickstart on Wed Feb 1 18:07:01 2023. You can adapt this file completely to your liking, but it should at least contain the root `toctree` directive. -Torch Uncertainty +TorchUncertainty ================= .. role:: bash(code) diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index 2d3182ac..8997e546 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -4,10 +4,91 @@ Quickstart .. role:: bash(code) :language: bash -Torch Uncertainty comes with different usage levels ranging from specific -PyTorch layers to ready to train Lightning-based models. The following -presents a short introduction to each one of them. Let's start with the -highest-level usage. +TorchUncertainty is centered around **uncertainty-aware** training and evaluation routines. +These routines make it very easy to: + +- train ensembles-like methods (Deep Ensembles, Packed-Ensembles, MIMO, Masksembles, etc) +- compute and monitor uncertainty metrics: calibration, out-of-distribution detection, proper scores, grouping loss, etc. +- leverage calibration methods automatically during evaluation + +Yet, we take account that their will be as many different uses of TorchUncertainty as there are of users. +This page provides ideas on how to benefit from TorchUncertainty at all levels: from ready-to-train lightning-based models to using only specific +PyTorch layers. + +Training with TorchUncertainty's Uncertainty-aware Routines +----------------------------------------------------------- + +Let's have a look at the `Classification routine `_. + +.. code:: python + from lightning.pytorch import LightningModule + + class ClassificationRoutine(LightningModule): + def __init__( + self, + model: nn.Module, + num_classes: int, + loss: nn.Module, + num_estimators: int = 1, + format_batch_fn: nn.Module | None = None, + optim_recipe: dict | Optimizer | None = None, + mixtype: str = "erm", + mixmode: str = "elem", + dist_sim: str = "emb", + kernel_tau_max: float = 1.0, + kernel_tau_std: float = 0.5, + mixup_alpha: float = 0, + cutmix_alpha: float = 0, + eval_ood: bool = False, + eval_grouping_loss: bool = False, + ood_criterion: Literal[ + "msp", "logit", "energy", "entropy", "mi", "vr" + ] = "msp", + log_plots: bool = False, + save_in_csv: bool = False, + calibration_set: Literal["val", "test"] | None = None, + ) -> None: + + +Building your First Routine +^^^^^^^^^^^^^^^^^^^^^^^^^^^ +This routine is a wrapper of any custom or TorchUncertainty classification model. To use it, +just build your model and pass it to the routine as argument along with the optimization criterion (the loss) +as well as the number of classes that we use for torch metrics. + +.. code:: python + model = MyModel(num_classes=10) + routine = ClassificationRoutine(model, num_classes=10, loss=nn.CrossEntropyLoss()) + + +Training with the Routine +^^^^^^^^^^^^^^^^^^^^^^^^^ + +To train with this routine, you will first need to create a lightning Trainer and have either a lightning datamodule +or PyTorch dataloaders. When benchmarking models, we advise to use lightning datamodules that will automatically handle +train/val/test splits, out-of-distribution detection and dataset shift. For this example, let us use TorchUncertainty's +CIFAR10 datamodule. Please keep in mind that you could use your own datamodule or dataloaders. + +.. code:: python + from torch_uncertainty.datamodules import CIFAR10DataModule + from pytorch_lightning import Trainer + + dm = CIFAR10DataModule(root="data", batch_size=32) + trainer = Trainer(gpus=1, max_epochs=100) + trainer.fit(routine, dm) + trainer.eval(routine, dm) + +Here it is, you have trained your first model with TorchUncertainty! As a result, you will get access to various metrics +measuring the ability of your model to handle uncertainty. + +More metrics +^^^^^^^^^^^^ + +With TorchUncertainty datamodules, you can easily test models on out-of-distribution datasets, by +setting the `eval_ood` parameter to True. You can also evaluate the grouping loss by setting `eval_grouping_loss` to True. +Finally, you can calibrate your model using the `calibration_set` parameter. In this case, you will get +metrics for but the uncalibrated and calibrated models: the metrics corresponding to the temperature scaled +model will begin with `ts_`. Using the Lightning-based CLI tool ---------------------------------- From f7299e8a2edd16196c796900e2d0618748541ab9 Mon Sep 17 00:00:00 2001 From: alafage Date: Sun, 24 Mar 2024 19:08:53 +0100 Subject: [PATCH 120/148] :bulb: Add docstrings for Segmentation datamodules --- docs/source/api.rst | 2 +- torch_uncertainty/datamodules/__init__.py | 2 +- .../datamodules/segmentation/camvid.py | 48 ++++++++++-- .../datamodules/segmentation/cityscapes.py | 74 +++++++++++++++++++ .../datamodules/segmentation/muad.py | 74 +++++++++++++++++++ .../datasets/segmentation/camvid.py | 29 ++++---- 6 files changed, 207 insertions(+), 22 deletions(-) diff --git a/docs/source/api.rst b/docs/source/api.rst index 39099970..5a01cad4 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -244,4 +244,4 @@ Segmentation CamVidDataModule CityscapesDataModule - + MUADDataModule diff --git a/torch_uncertainty/datamodules/__init__.py b/torch_uncertainty/datamodules/__init__.py index 24859701..d1ab23b4 100644 --- a/torch_uncertainty/datamodules/__init__.py +++ b/torch_uncertainty/datamodules/__init__.py @@ -4,5 +4,5 @@ from .classification.imagenet import ImageNetDataModule from .classification.mnist import MNISTDataModule from .classification.tiny_imagenet import TinyImageNetDataModule -from .segmentation import CamVidDataModule, CityscapesDataModule +from .segmentation import CamVidDataModule, CityscapesDataModule, MUADDataModule from .uci_regression import UCIDataModule diff --git a/torch_uncertainty/datamodules/segmentation/camvid.py b/torch_uncertainty/datamodules/segmentation/camvid.py index cf248669..cf722978 100644 --- a/torch_uncertainty/datamodules/segmentation/camvid.py +++ b/torch_uncertainty/datamodules/segmentation/camvid.py @@ -9,6 +9,46 @@ class CamVidDataModule(AbstractDataModule): + r"""DataModule for the CamVid dataset. + + Args: + root (str or Path): Root directory of the datasets. + batch_size (int): Number of samples per batch. + val_split (float or None, optional): Share of training samples to use + for validation. Defaults to ``None``. + num_workers (int, optional): Number of dataloaders to use. Defaults to + ``1``. + pin_memory (bool, optional): Whether to pin memory. Defaults to + ``True``. + persistent_workers (bool, optional): Whether to use persistent workers. + Defaults to ``True``. + + Note: + This datamodule injects the following transforms into the training and + validation/test datasets: + + .. code-block:: python + + from torchvision.transforms import v2 + + v2.Compose( + [ + v2.Resize((360, 480)), + v2.ToDtype( + dtype={ + tv_tensors.Image: torch.float32, + tv_tensors.Mask: torch.int64, + "others": None, + }, + scale=True, + ), + ] + ) + + + + """ + def __init__( self, root: str | Path, @@ -30,9 +70,7 @@ def __init__( self.train_transform = v2.Compose( [ - v2.Resize( - (360, 480), interpolation=v2.InterpolationMode.NEAREST - ), + v2.Resize((360, 480)), v2.ToDtype( dtype={ tv_tensors.Image: torch.float32, @@ -45,9 +83,7 @@ def __init__( ) self.test_transform = v2.Compose( [ - v2.Resize( - (360, 480), interpolation=v2.InterpolationMode.NEAREST - ), + v2.Resize((360, 480)), v2.ToDtype( dtype={ tv_tensors.Image: torch.float32, diff --git a/torch_uncertainty/datamodules/segmentation/cityscapes.py b/torch_uncertainty/datamodules/segmentation/cityscapes.py index 57e0db0d..79693cbc 100644 --- a/torch_uncertainty/datamodules/segmentation/cityscapes.py +++ b/torch_uncertainty/datamodules/segmentation/cityscapes.py @@ -13,6 +13,80 @@ class CityscapesDataModule(AbstractDataModule): + r"""DataModule for the Cityscapes dataset. + + Args: + root (str or Path): Root directory of the datasets. + batch_size (int): Number of samples per batch. + crop_size (sequence or int, optional): Desired input image and + segmentation mask sizes during training. If :attr:`crop_size` is an + int instead of sequence like :math:`(H, W)`, a square crop + :math:`(\text{size},\text{size})` is made. If provided a sequence + of length :math:`1`, it will be interpreted as + :math:`(\text{size[0]},\text{size[1]})`. Defaults to ``1024``. + inference_size (sequence or int, optional): Desired input image and + segmentation mask sizes during inference. If size is an int, + smaller edge of the images will be matched to this number, i.e., + :math:`\text{height}>\text{width}`, then image will be rescaled to + :math:`(\text{size}\times\text{height}/\text{width},\text{size})`. + Defaults to ``(1024,2048)``. + val_split (float or None, optional): Share of training samples to use + for validation. Defaults to ``None``. + num_workers (int, optional): Number of dataloaders to use. Defaults to + ``1``. + pin_memory (bool, optional): Whether to pin memory. Defaults to + ``True``. + persistent_workers (bool, optional): Whether to use persistent workers. + Defaults to ``True``. + + + Note: + This datamodule injects the following transforms into the training and + validation/test datasets: + + Training transforms: + + .. code-block:: python + + from torchvision.transforms import v2 + + v2.Compose([ + v2.ToImage(), + RandomRescale(min_scale=0.5, max_scale=2.0, antialias=True), + v2.RandomCrop(size=crop_size, pad_if_needed=True), + v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5), + v2.RandomHorizontalFlip(), + v2.ToDtype({ + tv_tensors.Image: torch.float32, + tv_tensors.Mask: torch.int64, + "others": None + }, scale=True), + v2.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]) + + Validation/Test transforms: + + .. code-block:: python + + from torchvision.transforms import v2 + + v2.Compose([ + v2.ToImage(), + v2.Resize(size=inference_size, antialias=True), + v2.ToDtype({ + tv_tensors.Image: torch.float32, + tv_tensors.Mask: torch.int64, + "others": None + }, scale=True), + v2.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]) + + This behavior can be modified by overriding ``self.train_transform`` + and ``self.test_transform`` after initialization. + """ + def __init__( self, root: str | Path, diff --git a/torch_uncertainty/datamodules/segmentation/muad.py b/torch_uncertainty/datamodules/segmentation/muad.py index acf0535c..910d5da2 100644 --- a/torch_uncertainty/datamodules/segmentation/muad.py +++ b/torch_uncertainty/datamodules/segmentation/muad.py @@ -13,6 +13,80 @@ class MUADDataModule(AbstractDataModule): + r"""DataModule for the MUAD dataset. + + Args: + root (str or Path): Root directory of the datasets. + batch_size (int): Number of samples per batch. + crop_size (sequence or int, optional): Desired input image and + segmentation mask sizes during training. If :attr:`crop_size` is an + int instead of sequence like :math:`(H, W)`, a square crop + :math:`(\text{size},\text{size})` is made. If provided a sequence + of length :math:`1`, it will be interpreted as + :math:`(\text{size[0]},\text{size[1]})`. Defaults to ``1024``. + inference_size (sequence or int, optional): Desired input image and + segmentation mask sizes during inference. If size is an int, + smaller edge of the images will be matched to this number, i.e., + :math:`\text{height}>\text{width}`, then image will be rescaled to + :math:`(\text{size}\times\text{height}/\text{width},\text{size})`. + Defaults to ``(1024,2048)``. + val_split (float or None, optional): Share of training samples to use + for validation. Defaults to ``None``. + num_workers (int, optional): Number of dataloaders to use. Defaults to + ``1``. + pin_memory (bool, optional): Whether to pin memory. Defaults to + ``True``. + persistent_workers (bool, optional): Whether to use persistent workers. + Defaults to ``True``. + + + Note: + This datamodule injects the following transforms into the training and + validation/test datasets: + + Training transforms: + + .. code-block:: python + + from torchvision.transforms import v2 + + v2.Compose([ + v2.ToImage(), + RandomRescale(min_scale=0.5, max_scale=2.0, antialias=True), + v2.RandomCrop(size=crop_size, pad_if_needed=True), + v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5), + v2.RandomHorizontalFlip(), + v2.ToDtype({ + tv_tensors.Image: torch.float32, + tv_tensors.Mask: torch.int64, + "others": None + }, scale=True), + v2.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]) + + Validation/Test transforms: + + .. code-block:: python + + from torchvision.transforms import v2 + + v2.Compose([ + v2.ToImage(), + v2.Resize(size=inference_size, antialias=True), + v2.ToDtype({ + tv_tensors.Image: torch.float32, + tv_tensors.Mask: torch.int64, + "others": None + }, scale=True), + v2.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]) + + This behavior can be modified by overriding ``self.train_transform`` + and ``self.test_transform`` after initialization. + """ + def __init__( self, root: str | Path, diff --git a/torch_uncertainty/datasets/segmentation/camvid.py b/torch_uncertainty/datasets/segmentation/camvid.py index 5a25c821..6cc068ea 100644 --- a/torch_uncertainty/datasets/segmentation/camvid.py +++ b/torch_uncertainty/datasets/segmentation/camvid.py @@ -23,6 +23,21 @@ class CamVidClass(NamedTuple): class CamVid(VisionDataset): + """`CamVid `_ Dataset. + + Args: + root (str): Root directory of dataset where ``camvid/`` exists or + will be saved to if download is set to ``True``. + split (str, optional): The dataset split, supports ``train``, + ``val`` and ``test``. Default: ``None``. + transforms (callable, optional): A function/transform that takes + input sample and its target as entry and returns a transformed + version. Default: ``None``. + download (bool, optional): If true, downloads the dataset from the + internet and puts it in root directory. If dataset is already + downloaded, it is not downloaded again. + """ + # Notes: some classes are not used here classes = [ CamVidClass("sky", 0, (128, 128, 128)), @@ -67,20 +82,6 @@ def __init__( transforms: Callable | None = None, download: bool = False, ) -> None: - """`CamVid `_ Dataset. - - Args: - root (str): Root directory of dataset where ``camvid/`` exists or - will be saved to if download is set to ``True``. - split (str, optional): The dataset split, supports ``train``, - ``val`` and ``test``. Default: ``None``. - transforms (callable, optional): A function/transform that takes - input sample and its target as entry and returns a transformed - version. Default: ``None``. - download (bool, optional): If true, downloads the dataset from the - internet and puts it in root directory. If dataset is already - downloaded, it is not downloaded again. - """ if split not in ["train", "val", "test", None]: raise ValueError( f"Unknown split '{split}'. " From 32cd8adedef20927b0713d9b850c569e3b71a586 Mon Sep 17 00:00:00 2001 From: alafage Date: Mon, 25 Mar 2024 01:04:02 +0100 Subject: [PATCH 121/148] :sparkles: Add CLI guide in the documentation - Small docstring and documentation fixes --- docs/source/cli_guide.rst | 251 ++++++++++++++++++ docs/source/index.rst | 1 + docs/source/quickstart.rst | 166 +++++------- .../cifar10/configs/resnet.yaml | 3 +- .../datamodules/segmentation/camvid.py | 79 +++--- .../datamodules/segmentation/cityscapes.py | 147 +++++----- .../datamodules/segmentation/muad.py | 147 +++++----- .../datasets/segmentation/camvid.py | 29 +- torch_uncertainty/layers/batch_ensemble.py | 18 +- torch_uncertainty/routines/classification.py | 50 ++-- torch_uncertainty/routines/regression.py | 12 +- 11 files changed, 561 insertions(+), 342 deletions(-) create mode 100644 docs/source/cli_guide.rst diff --git a/docs/source/cli_guide.rst b/docs/source/cli_guide.rst new file mode 100644 index 00000000..ec111cf1 --- /dev/null +++ b/docs/source/cli_guide.rst @@ -0,0 +1,251 @@ +CLI Guide +========= + +Introduction to the Lightning CLI +--------------------------------- + +The Lightning CLI tool eases the implementation of a CLI to instanciate models to train and evaluate them on +some data. The CLI tool is a wrapper around the ``Trainer`` class and provides a set of subcommands to train +and test a ``LightningModule`` on a ``LightningDataModule``. To better match our needs, we created an inherited +class from the ``LightningCLI`` class, namely ``TULightningCLI``. + +.. note:: + ``TULightningCLI`` adds a new argument to the ``LightningCLI`` class: :attr:`eval_after_fit` to know whether + an evaluation on the test set should be performed after the training phase. + +Let's see how to implement the CLI, by checking out the ``experiments/classification/cifar10/resnet.py``. + +.. code:: python + + import torch + from lightning.pytorch.cli import LightningArgumentParser + + from torch_uncertainty.baselines.classification import ResNetBaseline + from torch_uncertainty.datamodules import CIFAR10DataModule + from torch_uncertainty.utils import TULightningCLI + + + class ResNetCLI(TULightningCLI): + def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + parser.add_optimizer_args(torch.optim.SGD) + parser.add_lr_scheduler_args(torch.optim.lr_scheduler.MultiStepLR) + + + def cli_main() -> ResNetCLI: + return ResNetCLI(ResNetBaseline, CIFAR10DataModule) + + + if __name__ == "__main__": + cli = cli_main() + if ( + (not cli.trainer.fast_dev_run) + and cli.subcommand == "fit" + and cli._get(cli.config, "eval_after_fit") + ): + cli.trainer.test(datamodule=cli.datamodule, ckpt_path="best") + +This file enables both training and testing ResNet architectures on the CIFAR-10 dataset. +The ``ResNetCLI`` class inherits from the ``TULightningCLI`` class and implements the +``add_arguments_to_parser`` method to add the optimizer and learning rate scheduler arguments +into the parser. In this case, we use the ``torch.optim.SGD`` optimizer and the +``torch.optim.lr_scheduler.MultiStepLR`` learning rate scheduler. + +.. code:: python + + class ResNetCLI(TULightningCLI): + def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + parser.add_optimizer_args(torch.optim.SGD) + parser.add_lr_scheduler_args(torch.optim.lr_scheduler.MultiStepLR) + +The ``LightningCLI`` takes a ``LightningModule`` and a ``LightningDataModule`` as arguments. +Here the ``cli_main`` function creates an instance of the ``ResNetCLI`` class by taking the ``ResNetBaseline`` +model and the ``CIFAR10DataModule`` as arguments. + +.. code:: python + + def cli_main() -> ResNetCLI: + return ResNetCLI(ResNetBaseline, CIFAR10DataModule) + +.. note:: + + The ``ResNetBaseline`` is a subclass of the ``ClassificationRoutine`` seemlessly instanciating a + ResNet model based on a :attr:`version` and an :attr:`arch` to be passed to the routine. + +Depending on the CLI subcommand calling ``cli_main()`` will either train or test the model on the using +the CIFAR-10 dataset. But what are these subcommands? + +.. code:: bash + + python resnet.py --help + +This command will display the available subcommands of the CLI tool. + +.. code:: bash + + subcommands: + For more details of each subcommand, add it as an argument followed by --help. + + Available subcommands: + fit Runs the full optimization routine. + validate Perform one evaluation epoch over the validation set. + test Perform one evaluation epoch over the test set. + predict Run inference on your data. + +You can execute whichever subcommand you like and set up all your hyperparameters directly using the command line + +.. code:: bash + + python resnet.py fit --trainer.max_epochs 75 --trainer.accelerators gpu --trainer.devices 1 --model.version std --model.arch 18 --model.in_channels 3 --model.num_classes 10 --model.loss CrossEntropyLoss --model.style cifar --data.root ./data --data.batch_size 128 --optimizer.lr 0.05 --lr_scheduler.milestones [25,50] + +All arguments in the ``__init__()`` methods of the ``Trainer``, ``LightningModule`` (here ``ResNetBaseline``), +``LightningDataModule`` (here ``CIFAR10DataModule``), ``torch.optim.SGD``, and ``torch.optim.lr_scheduler.MultiStepLR`` +classes are configurable using the CLI tool using the ``--trainer``, ``--model``, ``--data``, ``--optimizer``, and +``--lr_scheduler`` prefixes, respectively. + +However for a large number of hyperparameters, it is not practical to pass them all in the command line. +It is more convenient to use configuration files to store these hyperparameters and ease the burden of +repeating them each time you want to train or test a model. Let's see how to do that. + +.. note:: + + Note that ``Pytorch`` classes are supported by the CLI tool, so you can use them directly: ``--model.loss CrossEntropyLoss`` + and they would be automatically instanciated by the CLI tool with their default arguments (i.e., ``CrossEntropyLoss()``). + +.. tip:: + + Add the following after calling ``cli=cli_main()`` to eventually evaluate the model on the test set + after training, if the ``eval_after_fit`` argument is set to ``True`` and ``trainer.fast_dev_run`` + is set to ``False``. + + .. code:: python + + if ( + (not cli.trainer.fast_dev_run) + and cli.subcommand == "fit" + and cli._get(cli.config, "eval_after_fit") + ): + cli.trainer.test(datamodule=cli.datamodule, ckpt_path="best") + +Configuration files +------------------- + +By default the ``LightningCLI`` support configuration files in the YAML format (learn more about this format +`here `_). +Taking the previous example, we can create a configuration file named ``config.yaml`` with the following content: + +.. code:: yaml + + # config.yaml + eval_after_fit: true + trainer: + max_epochs: 75 + accelerators: gpu + devices: 1 + model: + version: std + arch: 18 + in_channels: 3 + num_classes: 10 + loss: CrossEntropyLoss + style: cifar + data: + root: ./data + batch_size: 128 + optimizer: + lr: 0.05 + lr_scheduler: + milestones: + - 25 + - 50 + +Then, we can run the following command to train the model: + +.. code:: bash + + python resnet.py fit --config config.yaml + +By default, executing the command above will store the experiment results in a directory named ``lightning_logs``, +and the last state of the model will be saved in a directory named ``lightning_logs/version_{int}/checkpoints``. +In addition, all arguments passed to instanciate the ``Trainer``, ``ResNetBaseline``, ``CIFAR10DataModule``, +``torch.optim.SGD``, and ``torch.optim.lr_scheduler.MultiStepLR`` classes will be saved in a file named +``lightning_logs/version_{int}/config.yaml``. When testing the model, we advise to use this configuration file +to ensure that the same hyperparameters are used for training and testing. + +.. code:: bash + + python resnet.py test --config lightning_logs/version_{int}/config.yaml --ckpt_path lightning_logs/version_{int}/checkpoints/{filename}.ckpt + +Experiment folder usage +----------------------- + +Now that we have seen how to implement the CLI tool and how to use configuration files, let explore the +configurations available in the ``experiments`` directory. The ``experiments`` directory is +mainly organized as follows: + +.. code:: bash + + experiments + ├── classification + │ ├── cifar10 + │ │ ├── configs + │ │ ├── resnet.py + │ │ ├── vgg.py + │ │ └── wideresnet.py + │ └── cifar100 + │ ├── configs + │ ├── resnet.py + │ ├── vgg.py + │ └── wideresnet.py + ├── regression + │ └── uci_datasets + │ ├── configs + │ └── mlp.py + └── segmentation + ├── cityscapes + │ ├── configs + │ └── segformer.py + └── muad + ├── configs + └── segformer.py + +For each task (**classification**, **regression**, and **segmentation**), we have a directory containing the datasets +(e.g., CIFAR10, CIFAR100, UCI datasets, Cityscapes, and Muad) and for each dataset, we have a directory containing +the configuration files and the CLI files for different backbones. + +You can directly use the CLI files with the command line or use the predefined configuration files to train and test +the models. The configuration files are stored in the ``configs``. For example, the configuration file for the classic +ResNet-18 model on the CIFAR-10 dataset is stored in the ``experiments/classification/cifar10/configs/resnet18/standard.yaml`` +file. For the Packed ResNet-18 model on the CIFAR-10 dataset, the configuration file is stored in the +``experiments/classification/cifar10/configs/resnet18/packed.yaml`` file. + +If you are interested in using a ResNet model but want to choose some of the hyperparameters using the command line, +you can use the configuration file and override the hyperparameters using the command line. For example, to train +a ResNet-18 model on the CIFAR-10 dataset with a batch size of :math:`256`, you can use the following command: + +.. code:: bash + + python resnet.py fit --config configs/resnet18/standard.yaml --data.batch_size 256 + +To use the weights argument of the ``torch.nn.CrossEntropyLoss`` class, you can use the following command: + +.. code:: bash + + python resnet.py fit --config configs/resnet18/standard.yaml --model.loss CrossEntropyLoss --model.loss.weight Tensor --model.loss.weight.dict_kwargs.data [1,2,3,4,5,6,7,8,9,10] + + +In addition, we provide a default configuration file for some backbones in the ``configs`` directory. For example, +``experiments/classification/cifar10/configs/resnet.yaml`` contains the default hyperparameters to train a ResNet model +on the CIFAR-10 dataset. Yet, some hyperparameters are purposely missing to be set by the user using the command line. + +For instance, to train a Packed ResNet-34 model on the CIFAR-10 dataset with :math:`4` estimators and a :math:`\alpha` value of :math:`2`, +you can use the following command: + +.. code:: bash + + python resnet.py fit --config configs/resnet.yaml --trainer.max_epochs 75 --model.version packed --model.arch 34 --model.num_estimators 4 --model.alpha 2 --optimizer.lr 0.05 --lr_scheduler.milestones [25,50] + + +.. tip:: + + Explore the `Lightning CLI docs `_ to learn more about the CLI tool, + the available arguments, and how to use them with configuration files. diff --git a/docs/source/index.rst b/docs/source/index.rst index ff18d0a6..09a9d53e 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -45,6 +45,7 @@ the following paper: quickstart introduction_uncertainty auto_tutorials/index + cli_guide api contributing references diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index 8997e546..18960abb 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -21,44 +21,54 @@ Training with TorchUncertainty's Uncertainty-aware Routines Let's have a look at the `Classification routine `_. .. code:: python - from lightning.pytorch import LightningModule - - class ClassificationRoutine(LightningModule): - def __init__( - self, - model: nn.Module, - num_classes: int, - loss: nn.Module, - num_estimators: int = 1, - format_batch_fn: nn.Module | None = None, - optim_recipe: dict | Optimizer | None = None, - mixtype: str = "erm", - mixmode: str = "elem", - dist_sim: str = "emb", - kernel_tau_max: float = 1.0, - kernel_tau_std: float = 0.5, - mixup_alpha: float = 0, - cutmix_alpha: float = 0, - eval_ood: bool = False, - eval_grouping_loss: bool = False, - ood_criterion: Literal[ - "msp", "logit", "energy", "entropy", "mi", "vr" - ] = "msp", - log_plots: bool = False, - save_in_csv: bool = False, - calibration_set: Literal["val", "test"] | None = None, - ) -> None: + + from lightning.pytorch import LightningModule + + class ClassificationRoutine(LightningModule): + def __init__( + self, + model: nn.Module, + num_classes: int, + loss: nn.Module, + num_estimators: int = 1, + format_batch_fn: nn.Module | None = None, + optim_recipe: dict | Optimizer | None = None, + mixtype: str = "erm", + mixmode: str = "elem", + dist_sim: str = "emb", + kernel_tau_max: float = 1.0, + kernel_tau_std: float = 0.5, + mixup_alpha: float = 0, + cutmix_alpha: float = 0, + eval_ood: bool = False, + eval_grouping_loss: bool = False, + ood_criterion: Literal[ + "msp", "logit", "energy", "entropy", "mi", "vr" + ] = "msp", + log_plots: bool = False, + save_in_csv: bool = False, + calibration_set: Literal["val", "test"] | None = None, + ) -> None: + ... Building your First Routine ^^^^^^^^^^^^^^^^^^^^^^^^^^^ This routine is a wrapper of any custom or TorchUncertainty classification model. To use it, -just build your model and pass it to the routine as argument along with the optimization criterion (the loss) -as well as the number of classes that we use for torch metrics. +just build your model and pass it to the routine as argument along with an optimization recipe +and the loss as well as the number of classes that we use for torch metrics. .. code:: python + + from torch import nn, optim + model = MyModel(num_classes=10) - routine = ClassificationRoutine(model, num_classes=10, loss=nn.CrossEntropyLoss()) + routine = ClassificationRoutine( + model, + num_classes=10, + loss=nn.CrossEntropyLoss(), + optim_recipe=optim.Adam(model.parameters(), lr=1e-3), + ) Training with the Routine @@ -70,13 +80,14 @@ train/val/test splits, out-of-distribution detection and dataset shift. For this CIFAR10 datamodule. Please keep in mind that you could use your own datamodule or dataloaders. .. code:: python + from torch_uncertainty.datamodules import CIFAR10DataModule from pytorch_lightning import Trainer dm = CIFAR10DataModule(root="data", batch_size=32) trainer = Trainer(gpus=1, max_epochs=100) trainer.fit(routine, dm) - trainer.eval(routine, dm) + trainer.test(routine, dm) Here it is, you have trained your first model with TorchUncertainty! As a result, you will get access to various metrics measuring the ability of your model to handle uncertainty. @@ -85,89 +96,44 @@ More metrics ^^^^^^^^^^^^ With TorchUncertainty datamodules, you can easily test models on out-of-distribution datasets, by -setting the `eval_ood` parameter to True. You can also evaluate the grouping loss by setting `eval_grouping_loss` to True. -Finally, you can calibrate your model using the `calibration_set` parameter. In this case, you will get +setting the ``eval_ood`` parameter to ``True``. You can also evaluate the grouping loss by setting ``eval_grouping_loss`` to ``True``. +Finally, you can calibrate your model using the ``calibration_set`` parameter. In this case, you will get metrics for but the uncalibrated and calibrated models: the metrics corresponding to the temperature scaled -model will begin with `ts_`. +model will begin with ``ts_``. + +---- -Using the Lightning-based CLI tool +Using the Lightning CLI tool ---------------------------------- Procedure ^^^^^^^^^ -The library provides a full-fledged trainer which can be used directly, via -CLI. To do so, create a file in the experiments folder and use the `cli_main` -routine, which takes as arguments: - -* a Lightning Module corresponding to the model, its own arguments, and - forward/validation/test logic. For instance, you might use already available - modules, such as the Packed-Ensembles-style ResNet available at - `torch_uncertainty/baselines/packed/resnet.py `_ -* a Lightning DataModule corresponding to the training, validation, and test - sets with again its arguments and logic. CIFAR-10/100, ImageNet, and - ImageNet-200 are available, for instance. -* a PyTorch loss such as the torch.nn.CrossEntropyLoss -* a dictionary containing the optimization recipe, namely a scheduler and - an optimizer. Many procedures are available at - `torch_uncertainty/optim_recipes.py `_ - -* the path to the data and logs folder, in the example below, the root of the library -* and finally, the name of your model (used for logs) - -Move to the directory containing your file and execute the code with :bash:`python3 experiment.py`. -Add lightning arguments such as :bash:`--accelerator gpu --devices "0, 1" --benchmark True` -for multi-gpu training and cuDNN benchmark, etc. - -Example -^^^^^^^ - -The following code - `available in the experiments folder `_ - -trains any ResNet architecture on CIFAR10: - -.. code:: python +The library leverages the `Lightning CLI tool `_ +to provide a simple way to train models and evaluate them, while insuring reproducibility via configuration files. +Under the ``experiment`` folder, you will find scripts for the three application tasks covered by the library: +classification, regression and segmentation. Take the most out of the CLI by checking our `CLI Guide `_. - from pathlib import Path +.. note:: - from torch import nn + In particular, the ``experiments/classification`` folder contains scripts to reproduce the experiments covered + in the paper: *Packed-Ensembles for Efficient Uncertainty Estimation*, O. Laurent & A. Lafage, et al., in ICLR 2023. - from torch_uncertainty import cli_main, init_args - from torch_uncertainty.baselines import ResNet - from torch_uncertainty.datamodules import CIFAR10DataModule - from torch_uncertainty.optim_recipes import get_procedure - root = Path(__file__).parent.absolute().parents[1] - args = init_args(ResNet, CIFAR10DataModule) - - net_name = f"{args.version}-resnet{args.arch}-cifar10" - - # datamodule - args.root = str(root / "data") - dm = CIFAR10DataModule(**vars(args)) - - # model - model = ResNet( - num_classes=dm.num_classes, - in_channels=dm.num_channels, - loss=nn.CrossEntropyLoss(), - optim_recipe=get_procedure( - f"resnet{args.arch}", "cifar10", args.version - ), - style="cifar", - **vars(args), - ) - - cli_main(model, dm, args.exp_dir, args.exp_name, args) +Example +^^^^^^^ -Run this model with, for instance: +Training a model with the Lightning CLI tool is as simple as running the following command: .. code:: bash - python3 resnet.py --version std --arch 18 --accelerator gpu --device 1 --benchmark True --max_epochs 75 --precision 16 + # in pyjam/experiments/classification/cifar10 + python resnet.py fit --config configs/resnet18/standard.yaml + +Which trains a classic ResNet18 model on CIFAR10 with the settings used in the Packed-Ensembles paper. -You may replace the architecture (which should be a Lightning Module), the -Datamodule (a Lightning Datamodule), the loss or the optimization recipe to your likings. +---- Using the PyTorch-based models ------------------------------ @@ -199,6 +165,8 @@ backbone with the following code: num_classes = 10, ) +---- + Using the PyTorch-based layers ------------------------------ @@ -216,7 +184,7 @@ issue on the GitHub repository! .. tip:: - Do not hesitate to go to the API reference to get better explanations on the + Do not hesitate to go to the `API Reference `_ to get better explanations on the layer usage. Example @@ -259,6 +227,8 @@ code: packed_net = PackedNet() +---- + Other usage ----------- diff --git a/experiments/classification/cifar10/configs/resnet.yaml b/experiments/classification/cifar10/configs/resnet.yaml index 2ba51027..dbbe41e9 100644 --- a/experiments/classification/cifar10/configs/resnet.yaml +++ b/experiments/classification/cifar10/configs/resnet.yaml @@ -27,8 +27,7 @@ trainer: model: num_classes: 10 in_channels: 3 - loss: - class_path: torch.nn.CrossEntropyLoss + loss: CrossEntropyLoss style: cifar data: root: ./data diff --git a/torch_uncertainty/datamodules/segmentation/camvid.py b/torch_uncertainty/datamodules/segmentation/camvid.py index cf722978..4a4aee65 100644 --- a/torch_uncertainty/datamodules/segmentation/camvid.py +++ b/torch_uncertainty/datamodules/segmentation/camvid.py @@ -9,46 +9,6 @@ class CamVidDataModule(AbstractDataModule): - r"""DataModule for the CamVid dataset. - - Args: - root (str or Path): Root directory of the datasets. - batch_size (int): Number of samples per batch. - val_split (float or None, optional): Share of training samples to use - for validation. Defaults to ``None``. - num_workers (int, optional): Number of dataloaders to use. Defaults to - ``1``. - pin_memory (bool, optional): Whether to pin memory. Defaults to - ``True``. - persistent_workers (bool, optional): Whether to use persistent workers. - Defaults to ``True``. - - Note: - This datamodule injects the following transforms into the training and - validation/test datasets: - - .. code-block:: python - - from torchvision.transforms import v2 - - v2.Compose( - [ - v2.Resize((360, 480)), - v2.ToDtype( - dtype={ - tv_tensors.Image: torch.float32, - tv_tensors.Mask: torch.int64, - "others": None, - }, - scale=True, - ), - ] - ) - - - - """ - def __init__( self, root: str | Path, @@ -58,6 +18,45 @@ def __init__( pin_memory: bool = True, persistent_workers: bool = True, ) -> None: + r"""DataModule for the CamVid dataset. + + Args: + root (str or Path): Root directory of the datasets. + batch_size (int): Number of samples per batch. + val_split (float or None, optional): Share of training samples to use + for validation. Defaults to ``None``. + num_workers (int, optional): Number of dataloaders to use. Defaults to + ``1``. + pin_memory (bool, optional): Whether to pin memory. Defaults to + ``True``. + persistent_workers (bool, optional): Whether to use persistent workers. + Defaults to ``True``. + + Note: + This datamodule injects the following transforms into the training and + validation/test datasets: + + .. code-block:: python + + from torchvision.transforms import v2 + + v2.Compose( + [ + v2.Resize((360, 480)), + v2.ToDtype( + dtype={ + tv_tensors.Image: torch.float32, + tv_tensors.Mask: torch.int64, + "others": None, + }, + scale=True, + ), + ] + ) + + This behavior can be modified by overriding ``self.train_transform`` + and ``self.test_transform`` after initialization. + """ super().__init__( root=root, batch_size=batch_size, diff --git a/torch_uncertainty/datamodules/segmentation/cityscapes.py b/torch_uncertainty/datamodules/segmentation/cityscapes.py index 79693cbc..08a036c1 100644 --- a/torch_uncertainty/datamodules/segmentation/cityscapes.py +++ b/torch_uncertainty/datamodules/segmentation/cityscapes.py @@ -13,80 +13,6 @@ class CityscapesDataModule(AbstractDataModule): - r"""DataModule for the Cityscapes dataset. - - Args: - root (str or Path): Root directory of the datasets. - batch_size (int): Number of samples per batch. - crop_size (sequence or int, optional): Desired input image and - segmentation mask sizes during training. If :attr:`crop_size` is an - int instead of sequence like :math:`(H, W)`, a square crop - :math:`(\text{size},\text{size})` is made. If provided a sequence - of length :math:`1`, it will be interpreted as - :math:`(\text{size[0]},\text{size[1]})`. Defaults to ``1024``. - inference_size (sequence or int, optional): Desired input image and - segmentation mask sizes during inference. If size is an int, - smaller edge of the images will be matched to this number, i.e., - :math:`\text{height}>\text{width}`, then image will be rescaled to - :math:`(\text{size}\times\text{height}/\text{width},\text{size})`. - Defaults to ``(1024,2048)``. - val_split (float or None, optional): Share of training samples to use - for validation. Defaults to ``None``. - num_workers (int, optional): Number of dataloaders to use. Defaults to - ``1``. - pin_memory (bool, optional): Whether to pin memory. Defaults to - ``True``. - persistent_workers (bool, optional): Whether to use persistent workers. - Defaults to ``True``. - - - Note: - This datamodule injects the following transforms into the training and - validation/test datasets: - - Training transforms: - - .. code-block:: python - - from torchvision.transforms import v2 - - v2.Compose([ - v2.ToImage(), - RandomRescale(min_scale=0.5, max_scale=2.0, antialias=True), - v2.RandomCrop(size=crop_size, pad_if_needed=True), - v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5), - v2.RandomHorizontalFlip(), - v2.ToDtype({ - tv_tensors.Image: torch.float32, - tv_tensors.Mask: torch.int64, - "others": None - }, scale=True), - v2.Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225]) - ]) - - Validation/Test transforms: - - .. code-block:: python - - from torchvision.transforms import v2 - - v2.Compose([ - v2.ToImage(), - v2.Resize(size=inference_size, antialias=True), - v2.ToDtype({ - tv_tensors.Image: torch.float32, - tv_tensors.Mask: torch.int64, - "others": None - }, scale=True), - v2.Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225]) - ]) - - This behavior can be modified by overriding ``self.train_transform`` - and ``self.test_transform`` after initialization. - """ - def __init__( self, root: str | Path, @@ -98,6 +24,79 @@ def __init__( pin_memory: bool = True, persistent_workers: bool = True, ) -> None: + r"""DataModule for the Cityscapes dataset. + + Args: + root (str or Path): Root directory of the datasets. + batch_size (int): Number of samples per batch. + crop_size (sequence or int, optional): Desired input image and + segmentation mask sizes during training. If :attr:`crop_size` is an + int instead of sequence like :math:`(H, W)`, a square crop + :math:`(\text{size},\text{size})` is made. If provided a sequence + of length :math:`1`, it will be interpreted as + :math:`(\text{size[0]},\text{size[1]})`. Defaults to ``1024``. + inference_size (sequence or int, optional): Desired input image and + segmentation mask sizes during inference. If size is an int, + smaller edge of the images will be matched to this number, i.e., + :math:`\text{height}>\text{width}`, then image will be rescaled to + :math:`(\text{size}\times\text{height}/\text{width},\text{size})`. + Defaults to ``(1024,2048)``. + val_split (float or None, optional): Share of training samples to use + for validation. Defaults to ``None``. + num_workers (int, optional): Number of dataloaders to use. Defaults to + ``1``. + pin_memory (bool, optional): Whether to pin memory. Defaults to + ``True``. + persistent_workers (bool, optional): Whether to use persistent workers. + Defaults to ``True``. + + + Note: + This datamodule injects the following transforms into the training and + validation/test datasets: + + Training transforms: + + .. code-block:: python + + from torchvision.transforms import v2 + + v2.Compose([ + v2.ToImage(), + RandomRescale(min_scale=0.5, max_scale=2.0, antialias=True), + v2.RandomCrop(size=crop_size, pad_if_needed=True), + v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5), + v2.RandomHorizontalFlip(), + v2.ToDtype({ + tv_tensors.Image: torch.float32, + tv_tensors.Mask: torch.int64, + "others": None + }, scale=True), + v2.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]) + + Validation/Test transforms: + + .. code-block:: python + + from torchvision.transforms import v2 + + v2.Compose([ + v2.ToImage(), + v2.Resize(size=inference_size, antialias=True), + v2.ToDtype({ + tv_tensors.Image: torch.float32, + tv_tensors.Mask: torch.int64, + "others": None + }, scale=True), + v2.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]) + + This behavior can be modified by overriding ``self.train_transform`` + and ``self.test_transform`` after initialization. + """ super().__init__( root=root, batch_size=batch_size, diff --git a/torch_uncertainty/datamodules/segmentation/muad.py b/torch_uncertainty/datamodules/segmentation/muad.py index 910d5da2..dae5d20e 100644 --- a/torch_uncertainty/datamodules/segmentation/muad.py +++ b/torch_uncertainty/datamodules/segmentation/muad.py @@ -13,80 +13,6 @@ class MUADDataModule(AbstractDataModule): - r"""DataModule for the MUAD dataset. - - Args: - root (str or Path): Root directory of the datasets. - batch_size (int): Number of samples per batch. - crop_size (sequence or int, optional): Desired input image and - segmentation mask sizes during training. If :attr:`crop_size` is an - int instead of sequence like :math:`(H, W)`, a square crop - :math:`(\text{size},\text{size})` is made. If provided a sequence - of length :math:`1`, it will be interpreted as - :math:`(\text{size[0]},\text{size[1]})`. Defaults to ``1024``. - inference_size (sequence or int, optional): Desired input image and - segmentation mask sizes during inference. If size is an int, - smaller edge of the images will be matched to this number, i.e., - :math:`\text{height}>\text{width}`, then image will be rescaled to - :math:`(\text{size}\times\text{height}/\text{width},\text{size})`. - Defaults to ``(1024,2048)``. - val_split (float or None, optional): Share of training samples to use - for validation. Defaults to ``None``. - num_workers (int, optional): Number of dataloaders to use. Defaults to - ``1``. - pin_memory (bool, optional): Whether to pin memory. Defaults to - ``True``. - persistent_workers (bool, optional): Whether to use persistent workers. - Defaults to ``True``. - - - Note: - This datamodule injects the following transforms into the training and - validation/test datasets: - - Training transforms: - - .. code-block:: python - - from torchvision.transforms import v2 - - v2.Compose([ - v2.ToImage(), - RandomRescale(min_scale=0.5, max_scale=2.0, antialias=True), - v2.RandomCrop(size=crop_size, pad_if_needed=True), - v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5), - v2.RandomHorizontalFlip(), - v2.ToDtype({ - tv_tensors.Image: torch.float32, - tv_tensors.Mask: torch.int64, - "others": None - }, scale=True), - v2.Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225]) - ]) - - Validation/Test transforms: - - .. code-block:: python - - from torchvision.transforms import v2 - - v2.Compose([ - v2.ToImage(), - v2.Resize(size=inference_size, antialias=True), - v2.ToDtype({ - tv_tensors.Image: torch.float32, - tv_tensors.Mask: torch.int64, - "others": None - }, scale=True), - v2.Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225]) - ]) - - This behavior can be modified by overriding ``self.train_transform`` - and ``self.test_transform`` after initialization. - """ - def __init__( self, root: str | Path, @@ -98,6 +24,79 @@ def __init__( pin_memory: bool = True, persistent_workers: bool = True, ) -> None: + r"""DataModule for the MUAD dataset. + + Args: + root (str or Path): Root directory of the datasets. + batch_size (int): Number of samples per batch. + crop_size (sequence or int, optional): Desired input image and + segmentation mask sizes during training. If :attr:`crop_size` is an + int instead of sequence like :math:`(H, W)`, a square crop + :math:`(\text{size},\text{size})` is made. If provided a sequence + of length :math:`1`, it will be interpreted as + :math:`(\text{size[0]},\text{size[1]})`. Defaults to ``1024``. + inference_size (sequence or int, optional): Desired input image and + segmentation mask sizes during inference. If size is an int, + smaller edge of the images will be matched to this number, i.e., + :math:`\text{height}>\text{width}`, then image will be rescaled to + :math:`(\text{size}\times\text{height}/\text{width},\text{size})`. + Defaults to ``(1024,2048)``. + val_split (float or None, optional): Share of training samples to use + for validation. Defaults to ``None``. + num_workers (int, optional): Number of dataloaders to use. Defaults to + ``1``. + pin_memory (bool, optional): Whether to pin memory. Defaults to + ``True``. + persistent_workers (bool, optional): Whether to use persistent workers. + Defaults to ``True``. + + + Note: + This datamodule injects the following transforms into the training and + validation/test datasets: + + Training transforms: + + .. code-block:: python + + from torchvision.transforms import v2 + + v2.Compose([ + v2.ToImage(), + RandomRescale(min_scale=0.5, max_scale=2.0, antialias=True), + v2.RandomCrop(size=crop_size, pad_if_needed=True), + v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5), + v2.RandomHorizontalFlip(), + v2.ToDtype({ + tv_tensors.Image: torch.float32, + tv_tensors.Mask: torch.int64, + "others": None + }, scale=True), + v2.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]) + + Validation/Test transforms: + + .. code-block:: python + + from torchvision.transforms import v2 + + v2.Compose([ + v2.ToImage(), + v2.Resize(size=inference_size, antialias=True), + v2.ToDtype({ + tv_tensors.Image: torch.float32, + tv_tensors.Mask: torch.int64, + "others": None + }, scale=True), + v2.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]) + + This behavior can be modified by overriding ``self.train_transform`` + and ``self.test_transform`` after initialization. + """ super().__init__( root=root, batch_size=batch_size, diff --git a/torch_uncertainty/datasets/segmentation/camvid.py b/torch_uncertainty/datasets/segmentation/camvid.py index 6cc068ea..5a25c821 100644 --- a/torch_uncertainty/datasets/segmentation/camvid.py +++ b/torch_uncertainty/datasets/segmentation/camvid.py @@ -23,21 +23,6 @@ class CamVidClass(NamedTuple): class CamVid(VisionDataset): - """`CamVid `_ Dataset. - - Args: - root (str): Root directory of dataset where ``camvid/`` exists or - will be saved to if download is set to ``True``. - split (str, optional): The dataset split, supports ``train``, - ``val`` and ``test``. Default: ``None``. - transforms (callable, optional): A function/transform that takes - input sample and its target as entry and returns a transformed - version. Default: ``None``. - download (bool, optional): If true, downloads the dataset from the - internet and puts it in root directory. If dataset is already - downloaded, it is not downloaded again. - """ - # Notes: some classes are not used here classes = [ CamVidClass("sky", 0, (128, 128, 128)), @@ -82,6 +67,20 @@ def __init__( transforms: Callable | None = None, download: bool = False, ) -> None: + """`CamVid `_ Dataset. + + Args: + root (str): Root directory of dataset where ``camvid/`` exists or + will be saved to if download is set to ``True``. + split (str, optional): The dataset split, supports ``train``, + ``val`` and ``test``. Default: ``None``. + transforms (callable, optional): A function/transform that takes + input sample and its target as entry and returns a transformed + version. Default: ``None``. + download (bool, optional): If true, downloads the dataset from the + internet and puts it in root directory. If dataset is already + downloaded, it is not downloaded again. + """ if split not in ["train", "val", "test", None]: raise ValueError( f"Unknown split '{split}'. " diff --git a/torch_uncertainty/layers/batch_ensemble.py b/torch_uncertainty/layers/batch_ensemble.py index b4169ae5..f616aff3 100644 --- a/torch_uncertainty/layers/batch_ensemble.py +++ b/torch_uncertainty/layers/batch_ensemble.py @@ -26,7 +26,7 @@ def __init__( dtype=None, ) -> None: r"""Applies a linear transformation using BatchEnsemble method to the - incoming data: :math:`y=(x\circ \hat{r_{group}})W^{T}\circ \hat{s_{group}} + \hat{b}`. + incoming data: :math:`y=(x\circ \widehat{r_{group}})W^{T}\circ \widehat{s_{group}} + \widehat{b}`. Args: in_features (int): size of each input sample. @@ -70,9 +70,9 @@ def __init__( Shape: - Input: :math:`(N, H_{in})` where :math:`N` is the batch size and - :math:`H_{in} = \text{in_features}`. + :math:`H_{in} = \text{in_features}`. - Output: :math:`(N, H_{out})` where - :math:`H_{out} = \text{out_features}`. + :math:`H_{out} = \text{out_features}`. Warning: Make sure that :attr:`num_estimators` divides :attr:`out_features` when calling :func:`forward()`. @@ -208,12 +208,12 @@ def __init__( :math:`(N, C_{out}, H_{out}, W_{out})` can be precisely described as: .. math:: - \text{out}(N_i, C_{\text{out}_j})=\ - &\hat{b}(N_i,C_{\text{out}_j}) - +\hat{s_group}(N_{i},C_{\text{out}_j}) \\ - &\times \sum_{k = 0}^{C_{\text{in}} - 1} - \text{weight}(C_{\text{out}_j}, k)\star (\text{input}(N_i, k) - \times \hat{r_group}(N_i, k)) + \text{out}(N_i, C_{\text{out}_j})=\ + &\widehat{b}(N_i,C_{\text{out}_j}) + +\widehat{s_{group}}(N_{i},C_{\text{out}_j}) \\ + &\times \sum_{k = 0}^{C_{\text{in}} - 1} + \text{weight}(C_{\text{out}_j}, k)\star (\text{input}(N_i, k) + \times \widehat{r_{group}}(N_i, k)) Reference: Introduced by the paper `BatchEnsemble: An Alternative Approach to diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 23877f06..09cb0c0d 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -60,51 +60,53 @@ def __init__( save_in_csv: bool = False, calibration_set: Literal["val", "test"] | None = None, ) -> None: - """Classification routine for Lightning. + r"""Routine for efficient training and testing on **classification tasks** + using LightningModule. Args: - model (nn.Module): Model to train. + model (torch.nn.Module): Model to train. num_classes (int): Number of classes. - loss (type[nn.Module]): Loss function to optimize the :attr:`model`. + loss (torch.nn.Module): Loss function to optimize the :attr:`model`. num_estimators (int, optional): Number of estimators for the - ensemble. Defaults to 1 (single model). - format_batch_fn (nn.Module, optional): Function to format the batch. + ensemble. Defaults to ``1`` (single model). + format_batch_fn (torch.nn.Module, optional): Function to format the batch. Defaults to :class:`torch.nn.Identity()`. - optim_recipe (dict | Optimizer, optional): The optimizer and + optim_recipe (dict or torch.optim.Optimizer, optional): The optimizer and optionally the scheduler to use. Defaults to ``None``. mixtype (str, optional): Mixup type. Defaults to ``"erm"``. mixmode (str, optional): Mixup mode. Defaults to ``"elem"``. dist_sim (str, optional): Distance similarity. Defaults to ``"emb"``. kernel_tau_max (float, optional): Maximum value for the kernel tau. - Defaults to 1.0. + Defaults to ``1.0``. kernel_tau_std (float, optional): Standard deviation for the kernel tau. - Defaults to 0.5. - mixup_alpha (float, optional): Alpha parameter for Mixup. Defaults to 0. + Defaults to ``0.5``. + mixup_alpha (float, optional): Alpha parameter for Mixup. Defaults to ``0``. cutmix_alpha (float, optional): Alpha parameter for Cutmix. - Defaults to 0. + Defaults to ``0``. eval_ood (bool, optional): Indicates whether to evaluate the OOD detection performance or not. Defaults to ``False``. eval_grouping_loss (bool, optional): Indicates whether to evaluate the grouping loss or not. Defaults to ``False``. - ood_criterion (str, optional): OOD criterion. Defaults to ``"msp"``. - MSP is the maximum softmax probability, logit is the maximum - logit, energy the logsumexp of the mean logits, entropy the - entropy of the mean prediction, mi is the mutual information of - the ensemble and vr is the variation ratio of the ensemble. + ood_criterion (str, optional): OOD criterion. Available options are + + - ``"msp"`` (default): Maximum softmax probability. + - ``"logit"``: Maximum logit. + - ``"energy"``: Logsumexp of the mean logits. + - ``"entropy"``: Entropy of the mean prediction. + - ``"mi"``: Mutual information of the ensemble. + - ``"vr"``: Variation ratio of the ensemble. + log_plots (bool, optional): Indicates whether to log plots from metrics. Defaults to ``False``. save_in_csv(bool, optional): Save the results in csv. Defaults to ``False``. - calibration_set (Callable, optional): Function to get the calibration - set. Defaults to ``None``. - - Warning: - You must define :attr:`optim_recipe` if you do not use - the CLI. + calibration_set (str, optional): The calibration dataset to use for + scaling. If not ``None``, it uses either the validation set when + set to ``"val"`` or the test set when set to ``"test"``. + Defaults to ``None``. Warning: - You must provide a datamodule to the trainer or use the CLI for if - :attr:`calibration_set` is not ``None``. + You must define :attr:`optim_recipe` if you do not use the CLI. """ super().__init__() _classification_routine_checks( @@ -308,7 +310,7 @@ def forward(self, inputs: Tensor, save_feats: bool = False) -> Tensor: not. Defaults to ``False``. Note: - The features are stored in the :attr:`features` attribute. + The features are stored in the :attr:`self.features` attribute. """ if save_feats: self.features = self.model.feats_forward(inputs) diff --git a/torch_uncertainty/routines/regression.py b/torch_uncertainty/routines/regression.py index 59af9e92..8e7bc21b 100644 --- a/torch_uncertainty/routines/regression.py +++ b/torch_uncertainty/routines/regression.py @@ -29,21 +29,21 @@ def __init__( """Regression routine for Lightning. Args: - model (nn.Module): Model to train. + model (torch.nn.Module): Model to train. probabilistic (bool): Whether the model is probabilistic, i.e., outputs a PyTorch distribution. output_dim (int): Number of outputs of the model. - loss (type[nn.Module]): Loss function to optimize the :attr:`model`. + loss (torch.nn.Module): Loss function to optimize the :attr:`model`. num_estimators (int, optional): The number of estimators for the ensemble. Defaults to 1 (single model). - optim_recipe (dict | Optimizer, optional): The optimizer and + optim_recipe (dict or torch.optim.Optimizer, optional): The optimizer and optionally the scheduler to use. Defaults to ``None``. - format_batch_fn (nn.Module, optional): The function to format the - batch. Defaults to None. + format_batch_fn (torch.nn.Module, optional): The function to format the + batch. Defaults to ``None``. Warning: If :attr:`probabilistic` is True, the model must output a `PyTorch - distribution _`. + distribution `_. Warning: You must define :attr:`optim_recipe` if you do not use From 41df36205ccdf97a837a6ff80be56789e9a35832 Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 25 Mar 2024 09:36:20 +0100 Subject: [PATCH 122/148] :bug: Fix typos --- auto_tutorials_source/tutorial_der_cubic.py | 2 +- auto_tutorials_source/tutorial_mc_batch_norm.py | 2 +- torch_uncertainty/metrics/entropy.py | 2 +- torch_uncertainty/metrics/variation_ratio.py | 4 ++-- torch_uncertainty/transforms/mixup.py | 4 ++-- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/auto_tutorials_source/tutorial_der_cubic.py b/auto_tutorials_source/tutorial_der_cubic.py index b6bc91e1..b77a0a4d 100644 --- a/auto_tutorials_source/tutorial_der_cubic.py +++ b/auto_tutorials_source/tutorial_der_cubic.py @@ -5,7 +5,7 @@ This tutorial provides an introduction to probabilistic regression in TorchUncertainty. More specifically, we present Deep Evidential Regression (DER) using a practical example. We demonstrate an application of DER by tackling the toy-problem of fitting :math:`y=x^3` using a Multi-Layer Perceptron (MLP) neural network model. -The output layer of the MLP provides a NormalInverseGamma distribution which is used to optimize the model, trhough its negative log-likelihood. +The output layer of the MLP provides a NormalInverseGamma distribution which is used to optimize the model, through its negative log-likelihood. DER represents an evidential approach to quantifying epistemic and aleatoric uncertainty in neural network regression models. This method involves introducing prior distributions over the parameters of the Gaussian likelihood function. diff --git a/auto_tutorials_source/tutorial_mc_batch_norm.py b/auto_tutorials_source/tutorial_mc_batch_norm.py index 0da3b4b2..217e69ed 100644 --- a/auto_tutorials_source/tutorial_mc_batch_norm.py +++ b/auto_tutorials_source/tutorial_mc_batch_norm.py @@ -20,7 +20,7 @@ - the classification training routine in the torch_uncertainty.training.classification module - an optimization recipe in the torch_uncertainty.optim_recipes module. -We also need import the neural network utils withing `torch.nn`. +We also need import the neural network utils within `torch.nn`. """ # %% from pathlib import Path diff --git a/torch_uncertainty/metrics/entropy.py b/torch_uncertainty/metrics/entropy.py index dabb4cb4..a7eae7f6 100644 --- a/torch_uncertainty/metrics/entropy.py +++ b/torch_uncertainty/metrics/entropy.py @@ -16,7 +16,7 @@ def __init__( **kwargs: Any, ) -> None: """The Shannon Entropy Metric to estimate the confidence of a single model - or the mean confidence accross estimators. + or the mean confidence across estimators. Args: reduction (str, optional): Determines how to reduce over the diff --git a/torch_uncertainty/metrics/variation_ratio.py b/torch_uncertainty/metrics/variation_ratio.py index faaddce3..cdf05c89 100644 --- a/torch_uncertainty/metrics/variation_ratio.py +++ b/torch_uncertainty/metrics/variation_ratio.py @@ -52,7 +52,7 @@ def compute(self) -> Tensor: n_estimators = probs_per_est.shape[1] probs = probs_per_est.mean(dim=1) - # best class for exemple + # best class for example max_classes = probs.argmax(dim=-1) if self.probabilistic: @@ -61,7 +61,7 @@ def compute(self) -> Tensor: torch.arange(probs_per_est.size(0)), max_classes ].mean(dim=1) else: - # best class for (exemple, estimator) + # best class for (example, estimator) max_classes_per_est = probs_per_est.argmax(dim=-1) variation_ratio = ( 1 diff --git a/torch_uncertainty/transforms/mixup.py b/torch_uncertainty/transforms/mixup.py index 64c57cac..ec51ae34 100644 --- a/torch_uncertainty/transforms/mixup.py +++ b/torch_uncertainty/transforms/mixup.py @@ -21,7 +21,7 @@ def sim_gauss_kernel(dist, tau_max: float = 1.0, tau_std: float = 0.5) -> float: # """ # Creates a tensor of shape [num, *start.shape] whose values are evenly # spaced from start to end, inclusive. -# Replicates but the multi-dimensional bahaviour of numpy.linspace in PyTorch. +# Replicates but the multi-dimensional behaviour of numpy.linspace in PyTorch. # """ # # create a tensor of 'num' steps from 0 to 1 # steps = torch.arange(num, dtype=torch.float32, device=start.device) / ( @@ -32,7 +32,7 @@ def sim_gauss_kernel(dist, tau_max: float = 1.0, tau_std: float = 0.5) -> float: # # to allow for broadcastings # # using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here # # but torchscript -# # "cannot statically infer the expected size of a list in this contex", +# # "cannot statically infer the expected size of a list in this context", # # hence the code below # for i in range(start.ndim): # steps = steps.unsqueeze(-1) From 0a112052553a2502a5be3883ed241aba2abc8b5e Mon Sep 17 00:00:00 2001 From: alafage Date: Mon, 25 Mar 2024 12:20:52 +0100 Subject: [PATCH 123/148] :bulb: Improve docstrings --- torch_uncertainty/layers/batch_ensemble.py | 13 +++++++++--- torch_uncertainty/metrics/sparsification.py | 2 +- torch_uncertainty/routines/classification.py | 5 +++++ torch_uncertainty/routines/regression.py | 14 +++++++++---- torch_uncertainty/routines/segmentation.py | 22 +++++++++++++------- 5 files changed, 40 insertions(+), 16 deletions(-) diff --git a/torch_uncertainty/layers/batch_ensemble.py b/torch_uncertainty/layers/batch_ensemble.py index f616aff3..6022f40b 100644 --- a/torch_uncertainty/layers/batch_ensemble.py +++ b/torch_uncertainty/layers/batch_ensemble.py @@ -25,8 +25,13 @@ def __init__( device=None, dtype=None, ) -> None: - r"""Applies a linear transformation using BatchEnsemble method to the - incoming data: :math:`y=(x\circ \widehat{r_{group}})W^{T}\circ \widehat{s_{group}} + \widehat{b}`. + r"""BatchEnsemble-style Linear layer. + + Applies a linear transformation using BatchEnsemble method to the incoming + data. + + .. math:: + y=(x\circ \widehat{r_{group}})W^{T}\circ \widehat{s_{group}} + \widehat{b} Args: in_features (int): size of each input sample. @@ -200,7 +205,9 @@ def __init__( device=None, dtype=None, ) -> None: - r"""Applies a 2d convolution over an input signal composed of several input + r"""BatchEnsemble-style Conv2d layer. + + Applies a 2d convolution over an input signal composed of several input planes using BatchEnsemble method to the incoming data. In the simplest case, the output value of the layer with input size diff --git a/torch_uncertainty/metrics/sparsification.py b/torch_uncertainty/metrics/sparsification.py index 822cc0cf..c843f442 100644 --- a/torch_uncertainty/metrics/sparsification.py +++ b/torch_uncertainty/metrics/sparsification.py @@ -35,7 +35,7 @@ def __init__(self, **kwargs) -> None: Inputs: - :attr:`scores`: Uncertainty scores of shape :math:`(B,)`. A higher - score means a higher uncertainty. + score means a higher uncertainty. - :attr:`errors`: Binary errors of shape :math:`(B,)`, where :math:`B` is the batch size. diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 09cb0c0d..b8aa7bf1 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -107,6 +107,11 @@ def __init__( Warning: You must define :attr:`optim_recipe` if you do not use the CLI. + + Note: + :attr:`optim_recipe` can be anything that can be returned by + :meth:`LightningModule.configure_optimizers()`. Find more details + `here `_. """ super().__init__() _classification_routine_checks( diff --git a/torch_uncertainty/routines/regression.py b/torch_uncertainty/routines/regression.py index 8e7bc21b..8d6718cc 100644 --- a/torch_uncertainty/routines/regression.py +++ b/torch_uncertainty/routines/regression.py @@ -21,21 +21,22 @@ def __init__( model: nn.Module, output_dim: int, probabilistic: bool, - loss: type[nn.Module], + loss: nn.Module, num_estimators: int = 1, optim_recipe: dict | Optimizer | None = None, format_batch_fn: nn.Module | None = None, ) -> None: - """Regression routine for Lightning. + r"""Routine for efficient training and testing on **regression tasks** + using LightningModule. Args: model (torch.nn.Module): Model to train. + output_dim (int): Number of outputs of the model. probabilistic (bool): Whether the model is probabilistic, i.e., outputs a PyTorch distribution. - output_dim (int): Number of outputs of the model. loss (torch.nn.Module): Loss function to optimize the :attr:`model`. num_estimators (int, optional): The number of estimators for the - ensemble. Defaults to 1 (single model). + ensemble. Defaults to ``1`` (single model). optim_recipe (dict or torch.optim.Optimizer, optional): The optimizer and optionally the scheduler to use. Defaults to ``None``. format_batch_fn (torch.nn.Module, optional): The function to format the @@ -48,6 +49,11 @@ def __init__( Warning: You must define :attr:`optim_recipe` if you do not use the CLI. + + Note: + :attr:`optim_recipe` can be anything that can be returned by + :meth:`LightningModule.configure_optimizers()`. Find more details + `here `_. """ super().__init__() _regression_routine_checks(num_estimators, output_dim) diff --git a/torch_uncertainty/routines/segmentation.py b/torch_uncertainty/routines/segmentation.py index 75b86770..b7b7149e 100644 --- a/torch_uncertainty/routines/segmentation.py +++ b/torch_uncertainty/routines/segmentation.py @@ -14,27 +14,33 @@ def __init__( self, model: nn.Module, num_classes: int, - loss: type[nn.Module], + loss: nn.Module, num_estimators: int = 1, optim_recipe: dict | Optimizer | None = None, format_batch_fn: nn.Module | None = None, ) -> None: - """Segmentation routine for Lightning. + """Routine for efficient training and testing on **segmentation tasks** + using LightningModule. Args: - model (nn.Module): Model to train. + model (torch.nn.Module): Model to train. num_classes (int): Number of classes in the segmentation task. - loss (type[nn.Module]): Loss function to optimize the :attr:`model`. + loss (torch.nn.Module): Loss function to optimize the :attr:`model`. num_estimators (int, optional): The number of estimators for the - ensemble. Defaults to 1 (single model). - optim_recipe (dict | Optimizer, optional): The optimizer and + ensemble. Defaults to ̀`1̀` (single model). + optim_recipe (dict or Optimizer, optional): The optimizer and optionally the scheduler to use. Defaults to ``None``. - format_batch_fn (nn.Module, optional): The function to format the - batch. Defaults to None. + format_batch_fn (torch.nn.Module, optional): The function to format the + batch. Defaults to ``None``. Warning: You must define :attr:`optim_recipe` if you do not use the CLI. + + Note: + :attr:`optim_recipe` can be anything that can be returned by + :meth:`LightningModule.configure_optimizers()`. Find more details + `here `_. """ super().__init__() _segmentation_routine_checks(num_estimators, num_classes) From 24bb1c540d8f1519837757f19041745b4c144d5c Mon Sep 17 00:00:00 2001 From: alafage Date: Mon, 25 Mar 2024 12:22:22 +0100 Subject: [PATCH 124/148] :sparkles: TU has now its own Trainer for a more tailored result display --- torch_uncertainty/utils/__init__.py | 1 + torch_uncertainty/utils/cli.py | 4 +- torch_uncertainty/utils/evaluation_loop.py | 126 +++++++++++++++++++++ torch_uncertainty/utils/trainer.py | 19 ++++ 4 files changed, 149 insertions(+), 1 deletion(-) create mode 100644 torch_uncertainty/utils/evaluation_loop.py create mode 100644 torch_uncertainty/utils/trainer.py diff --git a/torch_uncertainty/utils/__init__.py b/torch_uncertainty/utils/__init__.py index ad7a12db..de0547c7 100644 --- a/torch_uncertainty/utils/__init__.py +++ b/torch_uncertainty/utils/__init__.py @@ -3,3 +3,4 @@ from .cli import TULightningCLI from .hub import load_hf from .misc import create_train_val_split, csv_writer, plot_hist +from .trainer import TUTrainer diff --git a/torch_uncertainty/utils/cli.py b/torch_uncertainty/utils/cli.py index 899bcee9..8b8659c4 100644 --- a/torch_uncertainty/utils/cli.py +++ b/torch_uncertainty/utils/cli.py @@ -12,6 +12,8 @@ ) from typing_extensions import override +from torch_uncertainty.utils.trainer import TUTrainer + class TUSaveConfigCallback(SaveConfigCallback): @override @@ -72,7 +74,7 @@ def __init__( save_config_callback: type[SaveConfigCallback] | None = TUSaveConfigCallback, save_config_kwargs: dict[str, Any] | None = None, - trainer_class: type[Trainer] | Callable[..., Trainer] = Trainer, + trainer_class: type[Trainer] | Callable[..., Trainer] = TUTrainer, trainer_defaults: dict[str, Any] | None = None, seed_everything_default: bool | int = True, parser_kwargs: dict[str, Any] | dict[str, dict[str, Any]] | None = None, diff --git a/torch_uncertainty/utils/evaluation_loop.py b/torch_uncertainty/utils/evaluation_loop.py new file mode 100644 index 00000000..faf10e85 --- /dev/null +++ b/torch_uncertainty/utils/evaluation_loop.py @@ -0,0 +1,126 @@ +import os +import shutil +import sys +from typing import Any + +from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE +from lightning.pytorch.loops.evaluation_loop import _EvaluationLoop +from lightning.pytorch.trainer.connectors.logger_connector.result import ( + _OUT_DICT, +) +from lightning_utilities.core.apply_func import apply_to_collection +from torch import Tensor + + +class TUEvaluationLoop(_EvaluationLoop): + @staticmethod + def _print_results(results: list[_OUT_DICT], stage: str) -> None: + # remove the dl idx suffix + results = [ + {k.split("/dataloader_idx_")[0]: v for k, v in result.items()} + for result in results + ] + metrics_paths = { + k + for keys in apply_to_collection( + results, dict, _EvaluationLoop._get_keys + ) + for k in keys + } + if not metrics_paths: + return + + metrics_strs = [":".join(metric) for metric in metrics_paths] + # sort both lists based on metrics_strs + metrics_strs, metrics_paths = zip( + *sorted(zip(metrics_strs, metrics_paths, strict=False)), + strict=False, + ) + + if len(results) == 2: + headers = ["In-Distribution", "Out-of-Distribution"] + else: + headers = [f"DataLoader {i}" for i in range(len(results))] + + # fallback is useful for testing of printed output + term_size = shutil.get_terminal_size(fallback=(120, 30)).columns or 120 + max_length = int( + min( + max( + len(max(metrics_strs, key=len)), + len(max(headers, key=len)), + 25, + ), + term_size / 2, + ) + ) + + rows: list[list[Any]] = [[] for _ in metrics_paths] + + for result in results: + for metric, row in zip(metrics_paths, rows, strict=False): + val = _EvaluationLoop._find_value(result, metric) + if val is not None: + if isinstance(val, Tensor): + val = val.item() if val.numel() == 1 else val.tolist() + row.append(f"{val:.3f}") + else: + row.append(" ") + + # keep one column with max length for metrics + num_cols = int((term_size - max_length) / max_length) + + for i in range(0, len(headers), num_cols): + table_headers = headers[i : (i + num_cols)] + table_rows = [row[i : (i + num_cols)] for row in rows] + + table_headers.insert(0, f"{stage} Metric".capitalize()) + + if _RICH_AVAILABLE: + from rich import get_console + from rich.table import Column, Table + + columns = [ + Column( + h, justify="center", style="magenta", width=max_length + ) + for h in table_headers + ] + columns[0].style = "cyan" + + table = Table(*columns) + for metric, row in zip(metrics_strs, table_rows, strict=False): + row.insert(0, metric) + table.add_row(*row) + + console = get_console() + console.print(table) + else: + row_format = f"{{:^{max_length}}}" * len(table_headers) + half_term_size = int(term_size / 2) + + try: + # some terminals do not support this character + if sys.stdout.encoding is not None: + "─".encode(sys.stdout.encoding) + except UnicodeEncodeError: + bar_character = "-" + else: + bar_character = "─" + bar = bar_character * term_size + + lines = [bar, row_format.format(*table_headers).rstrip(), bar] + for metric, row in zip(metrics_strs, table_rows, strict=False): + # deal with column overflow + if len(metric) > half_term_size: + while len(metric) > half_term_size: + row_metric = metric[:half_term_size] + metric = metric[half_term_size:] + lines.append( + row_format.format(row_metric, *row).rstrip() + ) + lines.append(row_format.format(metric, " ").rstrip()) + else: + lines.append(row_format.format(metric, *row).rstrip()) + lines.append(bar) + print(os.linesep.join(lines)) diff --git a/torch_uncertainty/utils/trainer.py b/torch_uncertainty/utils/trainer.py new file mode 100644 index 00000000..e1b09f7d --- /dev/null +++ b/torch_uncertainty/utils/trainer.py @@ -0,0 +1,19 @@ +from lightning.pytorch import Trainer +from lightning.pytorch.trainer.states import ( + RunningStage, + TrainerFn, +) + +from torch_uncertainty.utils.evaluation_loop import TUEvaluationLoop + + +class TUTrainer(Trainer): + def __init__(self, inference_mode: bool = True, **kwargs): + super().__init__(inference_mode=inference_mode, **kwargs) + + self.test_loop = TUEvaluationLoop( + self, + TrainerFn.TESTING, + RunningStage.TESTING, + inference_mode=inference_mode, + ) From db12508e9fbd2e0593cc17abcea50023aca05a96 Mon Sep 17 00:00:00 2001 From: alafage Date: Mon, 25 Mar 2024 12:29:22 +0100 Subject: [PATCH 125/148] :white_check_mark: Routine tests now use the TUTrainer class --- tests/routines/test_classification.py | 26 +++++++++++++------------- tests/routines/test_regression.py | 10 +++++----- tests/routines/test_segmentation.py | 6 +++--- 3 files changed, 21 insertions(+), 21 deletions(-) diff --git a/tests/routines/test_classification.py b/tests/routines/test_classification.py index dea1ee61..9b22d898 100644 --- a/tests/routines/test_classification.py +++ b/tests/routines/test_classification.py @@ -1,7 +1,6 @@ from pathlib import Path import pytest -from lightning import Trainer from torch import nn from tests._dummies import ( @@ -12,13 +11,14 @@ from torch_uncertainty.losses import DECLoss, ELBOLoss from torch_uncertainty.optim_recipes import optim_cifar10_resnet18 from torch_uncertainty.routines import ClassificationRoutine +from torch_uncertainty.utils import TUTrainer class TestClassification: """Testing the classification routine.""" def test_one_estimator_binary(self): - trainer = Trainer(accelerator="cpu", fast_dev_run=True) + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) dm = DummyClassificationDataModule( root=Path(), @@ -41,7 +41,7 @@ def test_one_estimator_binary(self): model(dm.get_test_set()[0][0]) def test_two_estimators_binary(self): - trainer = Trainer(accelerator="cpu", fast_dev_run=True) + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) dm = DummyClassificationDataModule( root=Path(), @@ -64,7 +64,7 @@ def test_two_estimators_binary(self): model(dm.get_test_set()[0][0]) def test_one_estimator_two_classes(self): - trainer = Trainer(accelerator="cpu", fast_dev_run=True) + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) dm = DummyClassificationDataModule( root=Path(), @@ -89,7 +89,7 @@ def test_one_estimator_two_classes(self): model(dm.get_test_set()[0][0]) def test_one_estimator_two_classes_timm(self): - trainer = Trainer(accelerator="cpu", fast_dev_run=True) + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) dm = DummyClassificationDataModule( root=Path(), @@ -117,7 +117,7 @@ def test_one_estimator_two_classes_timm(self): model(dm.get_test_set()[0][0]) def test_one_estimator_two_classes_mixup(self): - trainer = Trainer(accelerator="cpu", fast_dev_run=True) + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) dm = DummyClassificationDataModule( root=Path(), @@ -144,7 +144,7 @@ def test_one_estimator_two_classes_mixup(self): model(dm.get_test_set()[0][0]) def test_one_estimator_two_classes_mixup_io(self): - trainer = Trainer(accelerator="cpu", fast_dev_run=True) + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) dm = DummyClassificationDataModule( root=Path(), @@ -171,7 +171,7 @@ def test_one_estimator_two_classes_mixup_io(self): model(dm.get_test_set()[0][0]) def test_one_estimator_two_classes_regmixup(self): - trainer = Trainer(accelerator="cpu", fast_dev_run=True) + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) dm = DummyClassificationDataModule( root=Path(), @@ -198,7 +198,7 @@ def test_one_estimator_two_classes_regmixup(self): model(dm.get_test_set()[0][0]) def test_one_estimator_two_classes_kernel_warping_emb(self): - trainer = Trainer(accelerator="cpu", fast_dev_run=True) + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) dm = DummyClassificationDataModule( root=Path(), @@ -225,7 +225,7 @@ def test_one_estimator_two_classes_kernel_warping_emb(self): model(dm.get_test_set()[0][0]) def test_one_estimator_two_classes_kernel_warping_inp(self): - trainer = Trainer(accelerator="cpu", fast_dev_run=True) + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) dm = DummyClassificationDataModule( root=Path(), @@ -253,7 +253,7 @@ def test_one_estimator_two_classes_kernel_warping_inp(self): model(dm.get_test_set()[0][0]) def test_one_estimator_two_classes_calibrated_with_ood(self): - trainer = Trainer(accelerator="cpu", fast_dev_run=True, logger=True) + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True, logger=True) dm = DummyClassificationDataModule( root=Path(), @@ -280,7 +280,7 @@ def test_one_estimator_two_classes_calibrated_with_ood(self): model(dm.get_test_set()[0][0]) def test_two_estimators_two_classes_mi(self): - trainer = Trainer(accelerator="cpu", fast_dev_run=True) + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) dm = DummyClassificationDataModule( root=Path(), @@ -305,7 +305,7 @@ def test_two_estimators_two_classes_mi(self): model(dm.get_test_set()[0][0]) def test_two_estimator_two_classes_elbo_vr_logs(self): - trainer = Trainer( + trainer = TUTrainer( accelerator="cpu", max_epochs=1, limit_train_batches=1, diff --git a/tests/routines/test_regression.py b/tests/routines/test_regression.py index c22c799d..2c7eb469 100644 --- a/tests/routines/test_regression.py +++ b/tests/routines/test_regression.py @@ -1,20 +1,20 @@ from pathlib import Path import pytest -from lightning.pytorch import Trainer from torch import nn from tests._dummies import DummyRegressionBaseline, DummyRegressionDataModule from torch_uncertainty.losses import DistributionNLLLoss from torch_uncertainty.optim_recipes import optim_cifar10_resnet18 from torch_uncertainty.routines import RegressionRoutine +from torch_uncertainty.utils import TUTrainer class TestRegression: """Testing the Regression routine.""" def test_one_estimator_one_output(self): - trainer = Trainer(accelerator="cpu", fast_dev_run=True) + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) root = Path(__file__).parent.absolute().parents[0] / "data" dm = DummyRegressionDataModule(out_features=1, root=root, batch_size=4) @@ -48,7 +48,7 @@ def test_one_estimator_one_output(self): model(dm.get_test_set()[0][0]) def test_one_estimator_two_outputs(self): - trainer = Trainer(accelerator="cpu", fast_dev_run=True) + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) root = Path(__file__).parent.absolute().parents[0] / "data" dm = DummyRegressionDataModule(out_features=2, root=root, batch_size=4) @@ -79,7 +79,7 @@ def test_one_estimator_two_outputs(self): model(dm.get_test_set()[0][0]) def test_two_estimators_one_output(self): - trainer = Trainer(accelerator="cpu", fast_dev_run=True) + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) root = Path(__file__).parent.absolute().parents[0] / "data" dm = DummyRegressionDataModule(out_features=1, root=root, batch_size=4) @@ -110,7 +110,7 @@ def test_two_estimators_one_output(self): model(dm.get_test_set()[0][0]) def test_two_estimators_two_outputs(self): - trainer = Trainer(accelerator="cpu", fast_dev_run=True) + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) root = Path(__file__).parent.absolute().parents[0] / "data" dm = DummyRegressionDataModule(out_features=2, root=root, batch_size=4) diff --git a/tests/routines/test_segmentation.py b/tests/routines/test_segmentation.py index 1801a7c1..7e03b673 100644 --- a/tests/routines/test_segmentation.py +++ b/tests/routines/test_segmentation.py @@ -1,7 +1,6 @@ from pathlib import Path import pytest -from lightning.pytorch import Trainer from torch import nn from tests._dummies import ( @@ -10,11 +9,12 @@ ) from torch_uncertainty.optim_recipes import optim_cifar10_resnet18 from torch_uncertainty.routines import SegmentationRoutine +from torch_uncertainty.utils import TUTrainer class TestSegmentation: def test_one_estimator_two_classes(self): - trainer = Trainer(accelerator="cpu", fast_dev_run=True) + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) root = Path(__file__).parent.absolute().parents[0] / "data" dm = DummySegmentationDataModule(root=root, batch_size=4, num_classes=2) @@ -34,7 +34,7 @@ def test_one_estimator_two_classes(self): model(dm.get_test_set()[0][0]) def test_two_estimators_two_classes(self): - trainer = Trainer(accelerator="cpu", fast_dev_run=True) + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) root = Path(__file__).parent.absolute().parents[0] / "data" dm = DummySegmentationDataModule(root=root, batch_size=4, num_classes=2) From 883aabc23f7694ef8b3b44449c33d74e4dbdeaf7 Mon Sep 17 00:00:00 2001 From: alafage Date: Mon, 25 Mar 2024 12:36:34 +0100 Subject: [PATCH 126/148] :rotating_light: Add coverage ignore in TUEvaluationLoop --- torch_uncertainty/utils/evaluation_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_uncertainty/utils/evaluation_loop.py b/torch_uncertainty/utils/evaluation_loop.py index faf10e85..99d8f6af 100644 --- a/torch_uncertainty/utils/evaluation_loop.py +++ b/torch_uncertainty/utils/evaluation_loop.py @@ -95,7 +95,7 @@ def _print_results(results: list[_OUT_DICT], stage: str) -> None: console = get_console() console.print(table) - else: + else: # coverage: ignore row_format = f"{{:^{max_length}}}" * len(table_headers) half_term_size = int(term_size / 2) From b82524d0b8f05c7ec1814195106fa52973b80dfa Mon Sep 17 00:00:00 2001 From: alafage Date: Mon, 25 Mar 2024 13:27:38 +0100 Subject: [PATCH 127/148] :sparkles: Add ECE, Brier and NLL to Segmentation Routine --- torch_uncertainty/routines/segmentation.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/torch_uncertainty/routines/segmentation.py b/torch_uncertainty/routines/segmentation.py index b7b7149e..eb99e1e8 100644 --- a/torch_uncertainty/routines/segmentation.py +++ b/torch_uncertainty/routines/segmentation.py @@ -6,7 +6,12 @@ from torchmetrics import Accuracy, MetricCollection from torchvision.transforms.v2 import functional as F -from torch_uncertainty.metrics import MeanIntersectionOverUnion +from torch_uncertainty.metrics import ( + CE, + BrierScore, + CategoricalNLL, + MeanIntersectionOverUnion, +) class SegmentationRoutine(LightningModule): @@ -60,13 +65,16 @@ def __init__( seg_metrics = MetricCollection( { "acc": Accuracy(task="multiclass", num_classes=num_classes), + "ece": CE(task="multiclass", num_classes=num_classes), "mean_iou": MeanIntersectionOverUnion(num_classes=num_classes), + "brier": BrierScore(num_classes=num_classes), + "nll": CategoricalNLL(), }, - compute_groups=[["acc", "mean_iou"]], + compute_groups=[["acc", "mean_iou"], ["ece"], ["brier"], ["nll"]], ) - self.val_seg_metrics = seg_metrics.clone(prefix="val/") - self.test_seg_metrics = seg_metrics.clone(prefix="test/") + self.val_seg_metrics = seg_metrics.clone(prefix="seg_val/") + self.test_seg_metrics = seg_metrics.clone(prefix="seg_test/") def configure_optimizers(self) -> Optimizer | dict: return self.optim_recipe From d568e6b79e463aedfbb99b59483aa8e432302947 Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 25 Mar 2024 13:35:24 +0100 Subject: [PATCH 128/148] :sparkles: Add FPRx metric --- tests/metrics/test_fpr95.py | 61 +++++++++--------------------- torch_uncertainty/metrics/fpr95.py | 43 +++++++++++++++++---- 2 files changed, 54 insertions(+), 50 deletions(-) diff --git a/tests/metrics/test_fpr95.py b/tests/metrics/test_fpr95.py index 6015d484..46b40dc8 100644 --- a/tests/metrics/test_fpr95.py +++ b/tests/metrics/test_fpr95.py @@ -1,62 +1,37 @@ import pytest import torch -from torch_uncertainty.metrics import FPR95 - - -@pytest.fixture() -def confs_zero() -> torch.Tensor: - return torch.as_tensor([1] * 99 + [0.99]) - - -@pytest.fixture() -def target_zero() -> torch.Tensor: - return torch.as_tensor([1] * 99 + [0]) - - -@pytest.fixture() -def confs_half() -> torch.Tensor: - return torch.as_tensor([0.9] * 100 + [0.95] * 50 + [0.85] * 50) - - -@pytest.fixture() -def target_half() -> torch.Tensor: - return torch.as_tensor([1] * 100 + [0] * 100) - - -@pytest.fixture() -def confs_one() -> torch.Tensor: - return torch.as_tensor([0.99] * 99 + [1]) - - -@pytest.fixture() -def target_one() -> torch.Tensor: - return torch.as_tensor([1] * 99 + [0]) +from torch_uncertainty.metrics.fpr95 import FPR95, FPRx class TestFPR95: """Testing the Entropy metric class.""" - def test_compute_zero( - self, confs_zero: torch.Tensor, target_zero: torch.Tensor - ): + def test_compute_zero(self): metric = FPR95(pos_label=1) - metric.update(confs_zero, target_zero) + metric.update( + torch.as_tensor([1] * 99 + [0.99]), torch.as_tensor([1] * 99 + [0]) + ) res = metric.compute() assert res == 0 - def test_compute_half( - self, confs_half: torch.Tensor, target_half: torch.Tensor - ): + def test_compute_half(self): metric = FPR95(pos_label=1) - metric.update(confs_half, target_half) + metric.update( + torch.as_tensor([0.9] * 100 + [0.95] * 50 + [0.85] * 50), + torch.as_tensor([1] * 100 + [0] * 100), + ) res = metric.compute() assert res == 0.5 - def test_compute_one( - self, confs_one: torch.Tensor, target_one: torch.Tensor - ): + def test_compute_one(self): metric = FPR95(pos_label=1) - metric.update(confs_one, target_one) + metric.update( + torch.as_tensor([0.99] * 99 + [1]), torch.as_tensor([1] * 99 + [0]) + ) res = metric.compute() assert res == 1 + + def test_error(self): + with pytest.raises(ValueError): + FPRx(recall_level=1.2, pos_label=1) diff --git a/torch_uncertainty/metrics/fpr95.py b/torch_uncertainty/metrics/fpr95.py index 2108406c..87a1b93a 100644 --- a/torch_uncertainty/metrics/fpr95.py +++ b/torch_uncertainty/metrics/fpr95.py @@ -38,7 +38,7 @@ def stable_cumsum(arr: ArrayLike, rtol: float = 1e-05, atol: float = 1e-08): return out -class FPR95(Metric): +class FPRx(Metric): is_differentiable: bool = False higher_is_better: bool = False full_state_update: bool = False @@ -46,29 +46,46 @@ class FPR95(Metric): conf: list[Tensor] targets: list[Tensor] - def __init__(self, pos_label: int, **kwargs) -> None: - """The False Positive Rate at 95% Recall metric.""" + def __init__(self, recall_level: float, pos_label: int, **kwargs) -> None: + """The False Positive Rate at x% Recall metric. + + Args: + recall_level (float): The recall level at which to compute the FPR. + pos_label (int): The positive label. + kwargs: Additional arguments to pass to the metric class. + """ super().__init__(**kwargs) + if recall_level < 0 or recall_level > 1: + raise ValueError( + f"Recall level must be between 0 and 1. Got {recall_level}." + ) + self.recall_level = recall_level self.pos_label = pos_label self.add_state("conf", [], dist_reduce_fx="cat") self.add_state("targets", [], dist_reduce_fx="cat") rank_zero_warn( - "Metric `FPR95` will save all targets and predictions" + f"Metric `FPR{int(recall_level*100)}` will save all targets and predictions" " in buffer. For large datasets this may lead to large memory" " footprint." ) def update(self, conf: Tensor, target: Tensor) -> None: + """Update the metric state. + + Args: + conf (Tensor): The confidence scores. + target (Tensor): The target labels. + """ self.conf.append(conf) self.targets.append(target) def compute(self) -> Tensor: - r"""Compute the actual False Positive Rate at 95% Recall. + r"""Compute the actual False Positive Rate at x% Recall. Returns: - Tensor: The value of the FPR95. + Tensor: The value of the FPRx. Reference: Inpired by https://github.com/hendrycks/anomaly-seg. @@ -119,8 +136,20 @@ def compute(self) -> Tensor: thresholds[sl], ) - cutoff = np.argmin(np.abs(recall - 0.95)) + cutoff = np.argmin(np.abs(recall - self.recall_level)) return torch.tensor( fps[cutoff] / (np.sum(np.logical_not(labels))), dtype=torch.float32 ) + + +class FPR95(FPRx): + def __init__(self, pos_label: int, **kwargs) -> None: + """The False Positive Rate at 95% Recall metric. + + Args: + recall_level (float): The recall level at which to compute the FPR. + pos_label (int): The positive label. + kwargs: Additional arguments to pass to the metric class. + """ + super().__init__(recall_level=0.95, pos_label=pos_label, **kwargs) From e977d59b9b195e8cf8f56e35ed57111e5fdcfe0c Mon Sep 17 00:00:00 2001 From: Adrien Lafage Date: Mon, 25 Mar 2024 18:57:24 +0100 Subject: [PATCH 129/148] :bug: Fix typo in README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index f84e4d3e..b90edf6f 100644 --- a/README.md +++ b/README.md @@ -40,7 +40,7 @@ pip install torch-uncertainty The installation procedure for contributors is different: have a look at the [contribution page](https://torch-uncertainty.github.io/contributing.html). -## :racehorse: Quckstart +## :racehorse: Quickstart We make a quickstart available at [torch-uncertainty.github.io/quickstart](https://torch-uncertainty.github.io/quickstart.html). From 07aadb85815969360b951a84cbded05482c41f4f Mon Sep 17 00:00:00 2001 From: Adrien Lafage Date: Mon, 25 Mar 2024 19:44:13 +0100 Subject: [PATCH 130/148] :bug: Fix typo in tutorial_mc_dropout.py --- auto_tutorials_source/tutorial_mc_dropout.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/auto_tutorials_source/tutorial_mc_dropout.py b/auto_tutorials_source/tutorial_mc_dropout.py index 7600dec4..d8902bfe 100644 --- a/auto_tutorials_source/tutorial_mc_dropout.py +++ b/auto_tutorials_source/tutorial_mc_dropout.py @@ -73,7 +73,7 @@ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # This is a classification problem, and we use CrossEntropyLoss as the likelihood. # We define the training routine using the classification training routine from -# torch_uncertainty.training.classification. We provide the number of classes +# torch_uncertainty.routines.classification. We provide the number of classes # and channels, the optimizer wrapper, the dropout rate, and the number of # forward passes to perform through the network, as well as all the default # arguments. From 9e62a1bfe507f8ce54d1be0052191f2cd989733a Mon Sep 17 00:00:00 2001 From: alafage Date: Mon, 25 Mar 2024 23:27:58 +0100 Subject: [PATCH 131/148] :bug: Fix fill value for tv_tensors.Mask when cropping in Segmentation DataModule --- docs/source/api.rst | 2 ++ experiments/segmentation/camvid/segformer.py | 2 +- experiments/segmentation/cityscapes/segformer.py | 2 +- torch_uncertainty/datamodules/__init__.py | 2 +- .../datamodules/segmentation/cityscapes.py | 6 +++++- torch_uncertainty/datamodules/segmentation/muad.py | 10 ++++++---- 6 files changed, 16 insertions(+), 8 deletions(-) diff --git a/docs/source/api.rst b/docs/source/api.rst index 5a01cad4..9c815767 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -234,6 +234,8 @@ Regression UCIDataModule +.. currentmodule:: torch_uncertainty.datamodules.segmentation + Segmentation ^^^^^^^^^^^^ diff --git a/experiments/segmentation/camvid/segformer.py b/experiments/segmentation/camvid/segformer.py index a42756c3..8eecfb50 100644 --- a/experiments/segmentation/camvid/segformer.py +++ b/experiments/segmentation/camvid/segformer.py @@ -2,7 +2,7 @@ from lightning.pytorch.cli import LightningArgumentParser from torch_uncertainty.baselines.segmentation import SegFormerBaseline -from torch_uncertainty.datamodules import CamVidDataModule +from torch_uncertainty.datamodules.segmentation import CamVidDataModule from torch_uncertainty.utils import TULightningCLI diff --git a/experiments/segmentation/cityscapes/segformer.py b/experiments/segmentation/cityscapes/segformer.py index 7dab5755..2b7fe992 100644 --- a/experiments/segmentation/cityscapes/segformer.py +++ b/experiments/segmentation/cityscapes/segformer.py @@ -2,7 +2,7 @@ from lightning.pytorch.cli import LightningArgumentParser from torch_uncertainty.baselines.segmentation import SegFormerBaseline -from torch_uncertainty.datamodules import CityscapesDataModule +from torch_uncertainty.datamodules.segmentation import CityscapesDataModule from torch_uncertainty.utils import TULightningCLI diff --git a/torch_uncertainty/datamodules/__init__.py b/torch_uncertainty/datamodules/__init__.py index d1ab23b4..24859701 100644 --- a/torch_uncertainty/datamodules/__init__.py +++ b/torch_uncertainty/datamodules/__init__.py @@ -4,5 +4,5 @@ from .classification.imagenet import ImageNetDataModule from .classification.mnist import MNISTDataModule from .classification.tiny_imagenet import TinyImageNetDataModule -from .segmentation import CamVidDataModule, CityscapesDataModule, MUADDataModule +from .segmentation import CamVidDataModule, CityscapesDataModule from .uci_regression import UCIDataModule diff --git a/torch_uncertainty/datamodules/segmentation/cityscapes.py b/torch_uncertainty/datamodules/segmentation/cityscapes.py index 08a036c1..f35bd65d 100644 --- a/torch_uncertainty/datamodules/segmentation/cityscapes.py +++ b/torch_uncertainty/datamodules/segmentation/cityscapes.py @@ -115,7 +115,11 @@ def __init__( [ v2.ToImage(), RandomRescale(min_scale=0.5, max_scale=2.0, antialias=True), - v2.RandomCrop(size=self.crop_size, pad_if_needed=True), + v2.RandomCrop( + size=self.crop_size, + pad_if_needed=True, + fill={tv_tensors.Image: 0, tv_tensors.Mask: 255}, + ), v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5), v2.RandomHorizontalFlip(), v2.ToDtype( diff --git a/torch_uncertainty/datamodules/segmentation/muad.py b/torch_uncertainty/datamodules/segmentation/muad.py index dae5d20e..c126b05e 100644 --- a/torch_uncertainty/datamodules/segmentation/muad.py +++ b/torch_uncertainty/datamodules/segmentation/muad.py @@ -24,7 +24,7 @@ def __init__( pin_memory: bool = True, persistent_workers: bool = True, ) -> None: - r"""DataModule for the MUAD dataset. + r"""Segmentation DataModule for the MUAD dataset. Args: root (str or Path): Root directory of the datasets. @@ -112,9 +112,12 @@ def __init__( self.train_transform = v2.Compose( [ - v2.ToImage(), RandomRescale(min_scale=0.5, max_scale=2.0, antialias=True), - v2.RandomCrop(size=self.crop_size, pad_if_needed=True), + v2.RandomCrop( + size=self.crop_size, + pad_if_needed=True, + fill={tv_tensors.Image: 0, tv_tensors.Mask: 255}, + ), v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5), v2.RandomHorizontalFlip(), v2.ToDtype( @@ -132,7 +135,6 @@ def __init__( ) self.test_transform = v2.Compose( [ - v2.ToImage(), v2.Resize(size=self.inference_size, antialias=True), v2.ToDtype( dtype={ From e993cfe1e5ba5d51ac8ac62c9947e0f5bae25feb Mon Sep 17 00:00:00 2001 From: alafage Date: Mon, 25 Mar 2024 23:32:11 +0100 Subject: [PATCH 132/148] :sparkles: Add MUAD monocular depth DataModule --- .../datamodules/depth_regression/__init__.py | 2 + .../datamodules/depth_regression/muad.py | 144 ++++++++++++++++++ torch_uncertainty/datasets/muad.py | 96 +++++++++--- 3 files changed, 221 insertions(+), 21 deletions(-) create mode 100644 torch_uncertainty/datamodules/depth_regression/__init__.py create mode 100644 torch_uncertainty/datamodules/depth_regression/muad.py diff --git a/torch_uncertainty/datamodules/depth_regression/__init__.py b/torch_uncertainty/datamodules/depth_regression/__init__.py new file mode 100644 index 00000000..dc94a8cb --- /dev/null +++ b/torch_uncertainty/datamodules/depth_regression/__init__.py @@ -0,0 +1,2 @@ +# ruff: noqa: F401 +from .muad import MUADDataModule diff --git a/torch_uncertainty/datamodules/depth_regression/muad.py b/torch_uncertainty/datamodules/depth_regression/muad.py new file mode 100644 index 00000000..a751926e --- /dev/null +++ b/torch_uncertainty/datamodules/depth_regression/muad.py @@ -0,0 +1,144 @@ +from pathlib import Path + +import torch +from torch.nn.common_types import _size_2_t +from torch.nn.modules.utils import _pair +from torchvision import tv_tensors +from torchvision.transforms import v2 + +from torch_uncertainty.datamodules.abstract import AbstractDataModule +from torch_uncertainty.datasets import MUAD +from torch_uncertainty.transforms import RandomRescale +from torch_uncertainty.utils.misc import create_train_val_split + + +class MUADDataModule(AbstractDataModule): + def __init__( + self, + root: str | Path, + batch_size: int, + crop_size: _size_2_t = 1024, + inference_size: _size_2_t = (1024, 2048), + val_split: float | None = None, + num_workers: int = 1, + pin_memory: bool = True, + persistent_workers: bool = True, + ) -> None: + r"""Segmentation DataModule for the MUAD dataset. + + Args: + root (str or Path): Root directory of the datasets. + batch_size (int): Number of samples per batch. + crop_size (sequence or int, optional): Desired input image and + segmentation mask sizes during training. If :attr:`crop_size` is an + int instead of sequence like :math:`(H, W)`, a square crop + :math:`(\text{size},\text{size})` is made. If provided a sequence + of length :math:`1`, it will be interpreted as + :math:`(\text{size[0]},\text{size[1]})`. Defaults to ``1024``. + inference_size (sequence or int, optional): Desired input image and + segmentation mask sizes during inference. If size is an int, + smaller edge of the images will be matched to this number, i.e., + :math:`\text{height}>\text{width}`, then image will be rescaled to + :math:`(\text{size}\times\text{height}/\text{width},\text{size})`. + Defaults to ``(1024,2048)``. + val_split (float or None, optional): Share of training samples to use + for validation. Defaults to ``None``. + num_workers (int, optional): Number of dataloaders to use. Defaults to + ``1``. + pin_memory (bool, optional): Whether to pin memory. Defaults to + ``True``. + persistent_workers (bool, optional): Whether to use persistent workers. + Defaults to ``True``. + """ + super().__init__( + root=root, + batch_size=batch_size, + val_split=val_split, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=persistent_workers, + ) + + self.dataset = MUAD + self.crop_size = _pair(crop_size) + self.inference_size = _pair(inference_size) + + self.train_transform = v2.Compose( + [ + RandomRescale(min_scale=0.5, max_scale=2.0, antialias=True), + v2.RandomCrop( + size=self.crop_size, + pad_if_needed=True, + fill={tv_tensors.Image: 0, tv_tensors.Mask: -float("inf")}, + ), + v2.RandomHorizontalFlip(), + v2.ToDtype( + dtype={ + tv_tensors.Image: torch.float32, + "others": None, + }, + scale=True, + ), + v2.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) + self.test_transform = v2.Compose( + [ + v2.Resize(size=self.inference_size, antialias=True), + v2.ToDtype( + dtype={ + tv_tensors.Image: torch.float32, + "others": None, + }, + scale=True, + ), + v2.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) + + def prepare_data(self) -> None: # coverage: ignore + self.dataset( + root=self.root, split="train", target_type="depth", download=True + ) + self.dataset( + root=self.root, split="val", target_type="depth", download=True + ) + + def setup(self, stage: str | None = None) -> None: + if stage == "fit" or stage is None: + full = self.dataset( + root=self.root, + split="train", + target_type="depth", + transforms=self.train_transform, + ) + + if self.val_split is not None: + self.train, self.val = create_train_val_split( + full, + self.val_split, + self.test_transform, + ) + else: + self.train = full + self.val = self.dataset( + root=self.root, + split="val", + target_type="depth", + transforms=self.test_transform, + ) + + if stage == "test" or stage is None: + self.test = self.dataset( + root=self.root, + split="val", + target_type="depth", + transforms=self.test_transform, + ) + + if stage not in ["fit", "test", None]: + raise ValueError(f"Stage {stage} is not supported.") diff --git a/torch_uncertainty/datasets/muad.py b/torch_uncertainty/datasets/muad.py index 477d9116..9c51a9f6 100644 --- a/torch_uncertainty/datasets/muad.py +++ b/torch_uncertainty/datasets/muad.py @@ -1,8 +1,11 @@ import json +import os +import shutil from collections.abc import Callable from pathlib import Path from typing import Any, Literal +import cv2 import numpy as np import torch from einops import rearrange @@ -28,14 +31,18 @@ class MUAD(VisionDataset): "val": "957af9c1c36f0a85c33279e06b6cf8d8", "val_depth": "0282030d281aeffee3335f713ba12373", } - samples: list[Path] = [] + _num_samples = { + "train": 3420, + "val": 492, + "test": ..., + } targets: list[Path] = [] # TODO: Add depth regression mode def __init__( self, root: str | Path, - split: Literal["train", "val", "train_depth", "val_depth"], + split: Literal["train", "val"], target_type: Literal["semantic", "depth"] = "semantic", transforms: Callable | None = None, download: bool = False, @@ -44,9 +51,8 @@ def __init__( Args: root (str): Root directory of dataset where directory 'leftImg8bit' - and 'leftLabel' are located. - split (str, optional): The image split to use, 'train', 'val', - 'train_depth' or 'val_depth'. + and 'leftLabel' or 'leftDepth' are located. + split (str, optional): The image split to use, 'train' or 'val'. target_type (str, optional): The type of target to use, 'semantic' or 'depth'. transforms (callable, optional): A function/transform that takes in @@ -78,16 +84,42 @@ def __init__( self.split = split self.target_type = target_type - split_path = self.root / (split + ".zip") - if (not check_integrity(split_path, self.zip_md5[split])) and download: - self._download(split=self.split) + if not self.check_split_integrity("leftImg8bit"): + if download: + self._download(split=split) + else: + raise FileNotFoundError( + f"MUAD {split} split not found or incomplete. Set download=True to download it." + ) + + if ( + not self.check_split_integrity("leftLabel") + and target_type == "semantic" + ): + if download: + self._download(split=split) + else: + raise FileNotFoundError( + f"MUAD {split} split not found or incomplete. Set download=True to download it." + ) if ( - self.target_type == "depth" - and not check_integrity(split_path, self.zip_md5[split + "_depth"]) - and download + not self.check_split_integrity("leftDepth") + and target_type == "depth" ): - self._download(split=f"{split}_depth") + if download: + self._download(split=f"{split}_depth") + # FIXME: Depth target for train are in a different folder + # thus we move them to the correct folder + if split == "train": + shutil.move( + self.root / f"{split}_depth", + self.root / split / "leftDepth", + ) + else: + raise FileNotFoundError( + f"MUAD {split} split not found or incomplete. Set download=True to download it." + ) # Load classes metadata cls_path = self.root / "classes.json" @@ -149,18 +181,39 @@ def __getitem__(self, index: int) -> tuple[Any, Any]: or a depth map. """ image = tv_tensors.Image(Image.open(self.samples[index]).convert("RGB")) - target = tv_tensors.Mask( - self.encode_target(Image.open(self.targets[index])) - ) + if self.target_type == "semantic": + target = tv_tensors.Mask( + self.encode_target(Image.open(self.targets[index])) + ) + else: + os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" + target = Image.fromarray( + cv2.imread( + str(self.targets[index]), + cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH, + ) + ) + # TODO: in the long tun it would be better to use a custom + # tv_tensor for depth maps (e.g. tv_tensors.DepthMap) + target = np.asarray(target, np.float32) + target = tv_tensors.Mask(400 * (1 - target)) # convert to meters if self.transforms is not None: image, target = self.transforms(image, target) return image, target + def check_split_integrity(self, folder: str) -> bool: + split_path = self.root / self.split + return ( + split_path.is_dir() + and len(list((split_path / folder).glob("**/*"))) + == self._num_samples[self.split] + ) + def __len__(self) -> int: """The number of samples in the dataset.""" - return len(self.samples) + return self._num_samples[self.split] def _make_dataset(self, path: Path) -> None: """Create a list of samples and targets. @@ -173,13 +226,14 @@ def _make_dataset(self, path: Path) -> None: "Depth regression mode is not implemented yet. Raise an issue " "if you need it." ) - self.samples = list((path / "leftImg8bit/").glob("**/*")) + self.samples = sorted((path / "leftImg8bit/").glob("**/*")) if self.target_type == "semantic": - self.targets = list((path / "leftLabel/").glob("**/*")) + self.targets = sorted((path / "leftLabel/").glob("**/*")) + elif self.target_type == "depth": + self.targets = sorted((path / "leftDepth/").glob("**/*")) else: - raise NotImplementedError( - "Depth regression mode is not implemented yet. Raise an issue " - "if you need it." + raise ValueError( + f"target_type must be one of ['semantic', 'depth']. Got {self.target_type}." ) def _download(self, split: str) -> None: From 64dd48a1eba3974d8c85976a1c94974946c3c2bc Mon Sep 17 00:00:00 2001 From: alafage Date: Mon, 25 Mar 2024 23:40:01 +0100 Subject: [PATCH 133/148] :white_check_mark: Add MUADDataModule tests for depth regression --- tests/_dummies/dataset.py | 56 +++++++++++++++++++ .../datamodules/depth_regression/__init__.py | 0 .../datamodules/depth_regression/test_muad.py | 37 ++++++++++++ 3 files changed, 93 insertions(+) create mode 100644 tests/datamodules/depth_regression/__init__.py create mode 100644 tests/datamodules/depth_regression/test_muad.py diff --git a/tests/_dummies/dataset.py b/tests/_dummies/dataset.py index a227ecdd..f14970a4 100644 --- a/tests/_dummies/dataset.py +++ b/tests/_dummies/dataset.py @@ -212,3 +212,59 @@ def __getitem__(self, index: int) -> tuple[Any, Any]: def __len__(self) -> int: return len(self.data) + + +class DummyDepthRegressionDataset(Dataset): + def __init__( + self, + root: Path, + split: str = "train", + transforms: Callable[..., Any] | None = None, + num_channels: int = 3, + image_size: int = 4, + num_images: int = 2, + **args, + ) -> None: + super().__init__() + + self.root = root + self.split = split + self.transforms = transforms + + self.data: Any = [] + self.targets = [] + + if num_channels == 1: + img_shape = (num_images, image_size, image_size) + else: + img_shape = (num_images, num_channels, image_size, image_size) + + smnt_shape = (num_images, 1, image_size, image_size) + + self.data = np.random.randint( + low=0, + high=255, + size=img_shape, + dtype=np.uint8, + ) + + self.targets = ( + np.random.uniform( + low=0, + high=1, + size=smnt_shape, + ) + * 100 + ) + + def __getitem__(self, index: int) -> tuple[Any, Any]: + img = tv_tensors.Image(self.data[index]) + target = tv_tensors.Mask(self.targets[index]) + + if self.transforms is not None: + img, target = self.transforms(img, target) + + return img, target + + def __len__(self) -> int: + return len(self.data) diff --git a/tests/datamodules/depth_regression/__init__.py b/tests/datamodules/depth_regression/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/datamodules/depth_regression/test_muad.py b/tests/datamodules/depth_regression/test_muad.py new file mode 100644 index 00000000..db24dd65 --- /dev/null +++ b/tests/datamodules/depth_regression/test_muad.py @@ -0,0 +1,37 @@ +import pytest + +from tests._dummies.dataset import DummyDepthRegressionDataset +from torch_uncertainty.datamodules.depth_regression import MUADDataModule +from torch_uncertainty.datasets import MUAD + + +class TestMUADDataModule: + """Testing the MUADDataModule datamodule.""" + + def test_camvid_main(self): + dm = MUADDataModule(root="./data/", batch_size=128) + + assert dm.dataset == MUAD + + dm.dataset = DummyDepthRegressionDataset + + dm.prepare_data() + dm.setup() + + with pytest.raises(ValueError): + dm.setup("xxx") + + # test abstract methods + dm.get_train_set() + dm.get_val_set() + dm.get_test_set() + + dm.train_dataloader() + dm.val_dataloader() + dm.test_dataloader() + + dm.val_split = 0.1 + dm.prepare_data() + dm.setup() + dm.train_dataloader() + dm.val_dataloader() From c6d97938d7ed4ccb476f7b2e73a468e40b11e7be Mon Sep 17 00:00:00 2001 From: Olivier Date: Tue, 26 Mar 2024 19:59:40 +0100 Subject: [PATCH 134/148] :shirt: Switch to 5 significant digits --- torch_uncertainty/utils/evaluation_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_uncertainty/utils/evaluation_loop.py b/torch_uncertainty/utils/evaluation_loop.py index 99d8f6af..ac11c02a 100644 --- a/torch_uncertainty/utils/evaluation_loop.py +++ b/torch_uncertainty/utils/evaluation_loop.py @@ -63,7 +63,7 @@ def _print_results(results: list[_OUT_DICT], stage: str) -> None: if val is not None: if isinstance(val, Tensor): val = val.item() if val.numel() == 1 else val.tolist() - row.append(f"{val:.3f}") + row.append(f"{val:.5f}") else: row.append(" ") From a805db43316d82573583a69f523652399178b96f Mon Sep 17 00:00:00 2001 From: Olivier Date: Tue, 26 Mar 2024 20:00:12 +0100 Subject: [PATCH 135/148] :hammer: Simplify the regression routine --- torch_uncertainty/routines/regression.py | 37 ++++++++++++++---------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/torch_uncertainty/routines/regression.py b/torch_uncertainty/routines/regression.py index 8d6718cc..3560173a 100644 --- a/torch_uncertainty/routines/regression.py +++ b/torch_uncertainty/routines/regression.py @@ -11,7 +11,7 @@ from torch.optim import Optimizer from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection -from torch_uncertainty.metrics.nll import DistributionNLL +from torch_uncertainty.metrics.regression.nll import DistributionNLL from torch_uncertainty.utils.distributions import squeeze_dist, to_ensemble_dist @@ -73,16 +73,22 @@ def __init__( reg_metrics = MetricCollection( { "mae": MeanAbsoluteError(), - "mse": MeanSquaredError(squared=False), + "mse": MeanSquaredError(squared=True), + "rmse": MeanSquaredError(squared=False), }, - compute_groups=False, + compute_groups=True, ) - if self.probabilistic: - reg_metrics["nll"] = DistributionNLL(reduction="mean") self.val_metrics = reg_metrics.clone(prefix="reg_val/") self.test_metrics = reg_metrics.clone(prefix="reg_test/") + if self.probabilistic: + reg_prob_metrics = MetricCollection( + DistributionNLL(reduction="mean") + ) + self.val_prob_metrics = reg_prob_metrics.clone(prefix="reg_val/") + self.test_prob_metrics = reg_prob_metrics.clone(prefix="reg_test/") + self.one_dim_regression = output_dim == 1 def configure_optimizers(self) -> Optimizer | dict: @@ -91,6 +97,9 @@ def configure_optimizers(self) -> Optimizer | dict: def on_train_start(self) -> None: init_metrics = dict.fromkeys(self.val_metrics, 0) init_metrics.update(dict.fromkeys(self.test_metrics, 0)) + if self.probabilistic: + init_metrics.update(dict.fromkeys(self.val_prob_metrics, 0)) + init_metrics.update(dict.fromkeys(self.test_prob_metrics, 0)) if self.logger is not None: # coverage: ignore self.logger.log_hyperparams( @@ -133,7 +142,6 @@ def training_step( targets = targets.unsqueeze(-1) loss = self.loss(dists, targets) - self.log("train_loss", loss) return loss @@ -153,15 +161,15 @@ def validation_step( torch.ones(self.num_estimators, device=self.device) ) mixture = MixtureSameFamily(mix, ens_dist) - self.val_metrics.nll.update(mixture, targets) - preds = mixture.mean else: preds = rearrange(preds, "(m b) c -> b m c", m=self.num_estimators) preds = preds.mean(dim=1) - self.val_metrics.mse.update(preds, targets) - self.val_metrics.mae.update(preds, targets) + self.val_metrics.update(preds, targets) + + if self.probabilistic: + self.val_prob_metrics.update(mixture, targets) def on_validation_epoch_end(self) -> None: self.log_dict(self.val_metrics.compute()) @@ -192,16 +200,15 @@ def test_step( torch.ones(self.num_estimators, device=self.device) ) mixture = MixtureSameFamily(mix, ens_dist) - self.test_metrics.nll.update(mixture, targets) - preds = mixture.mean - else: preds = rearrange(preds, "(m b) c -> b m c", m=self.num_estimators) preds = preds.mean(dim=1) - self.test_metrics.mse.update(preds, targets) - self.test_metrics.mae.update(preds, targets) + self.test_metrics.update(preds, targets) + + if self.probabilistic: + self.val_prob_metrics.update(mixture, targets) def on_test_epoch_end(self) -> None: self.log_dict( From 5770bdc96bf4014cb94bcf4559b12667a3e53134 Mon Sep 17 00:00:00 2001 From: Olivier Date: Tue, 26 Mar 2024 20:08:53 +0100 Subject: [PATCH 136/148] :hammer: Refactor metrics folder --- torch_uncertainty/metrics/__init__.py | 32 ++++++++++++------- .../metrics/classification/__init__.py | 12 +++++++ .../{ => classification}/brier_score.py | 2 +- .../{ => classification}/calibration.py | 0 .../{ => classification}/disagreement.py | 0 .../metrics/{ => classification}/entropy.py | 0 .../metrics/{ => classification}/fpr95.py | 0 .../{ => classification}/grouping_loss.py | 0 .../metrics/{ => classification}/mean_iou.py | 0 .../mutual_information.py | 0 .../metrics/{ => classification}/nll.py | 28 +--------------- .../{ => classification}/sparsification.py | 0 .../{ => classification}/variation_ratio.py | 0 torch_uncertainty/metrics/regression/nll.py | 30 +++++++++++++++++ 14 files changed, 65 insertions(+), 39 deletions(-) create mode 100644 torch_uncertainty/metrics/classification/__init__.py rename torch_uncertainty/metrics/{ => classification}/brier_score.py (99%) rename torch_uncertainty/metrics/{ => classification}/calibration.py (100%) rename torch_uncertainty/metrics/{ => classification}/disagreement.py (100%) rename torch_uncertainty/metrics/{ => classification}/entropy.py (100%) rename torch_uncertainty/metrics/{ => classification}/fpr95.py (100%) rename torch_uncertainty/metrics/{ => classification}/grouping_loss.py (100%) rename torch_uncertainty/metrics/{ => classification}/mean_iou.py (100%) rename torch_uncertainty/metrics/{ => classification}/mutual_information.py (100%) rename torch_uncertainty/metrics/{ => classification}/nll.py (75%) rename torch_uncertainty/metrics/{ => classification}/sparsification.py (100%) rename torch_uncertainty/metrics/{ => classification}/variation_ratio.py (100%) create mode 100644 torch_uncertainty/metrics/regression/nll.py diff --git a/torch_uncertainty/metrics/__init__.py b/torch_uncertainty/metrics/__init__.py index 0f89fb96..5387b56e 100644 --- a/torch_uncertainty/metrics/__init__.py +++ b/torch_uncertainty/metrics/__init__.py @@ -1,12 +1,22 @@ # ruff: noqa: F401 -from .brier_score import BrierScore -from .calibration import CE -from .disagreement import Disagreement -from .entropy import Entropy -from .fpr95 import FPR95 -from .grouping_loss import GroupingLoss -from .mean_iou import MeanIntersectionOverUnion -from .mutual_information import MutualInformation -from .nll import CategoricalNLL, DistributionNLL -from .sparsification import AUSE -from .variation_ratio import VariationRatio +from .classification import ( + AUSE, + CE, + FPR95, + BrierScore, + CategoricalNLL, + Disagreement, + Entropy, + GroupingLoss, + MeanIntersectionOverUnion, + MutualInformation, + VariationRatio, +) +from .regression import ( + DistributionNLL, + Log10, + MeanGTRelativeAbsoluteError, + MeanGTRelativeSquaredError, + SILog, + ThresholdAccuracy, +) diff --git a/torch_uncertainty/metrics/classification/__init__.py b/torch_uncertainty/metrics/classification/__init__.py new file mode 100644 index 00000000..df6078c9 --- /dev/null +++ b/torch_uncertainty/metrics/classification/__init__.py @@ -0,0 +1,12 @@ +# ruff: noqa: F401 +from .brier_score import BrierScore +from .calibration import CE +from .disagreement import Disagreement +from .entropy import Entropy +from .fpr95 import FPR95 +from .grouping_loss import GroupingLoss +from .mean_iou import MeanIntersectionOverUnion +from .mutual_information import MutualInformation +from .nll import CategoricalNLL +from .sparsification import AUSE +from .variation_ratio import VariationRatio diff --git a/torch_uncertainty/metrics/brier_score.py b/torch_uncertainty/metrics/classification/brier_score.py similarity index 99% rename from torch_uncertainty/metrics/brier_score.py rename to torch_uncertainty/metrics/classification/brier_score.py index 4d67450a..43b12f2c 100644 --- a/torch_uncertainty/metrics/brier_score.py +++ b/torch_uncertainty/metrics/classification/brier_score.py @@ -8,7 +8,7 @@ class BrierScore(Metric): - is_differentiable: bool = False + is_differentiable: bool = True higher_is_better: bool | None = False full_state_update: bool = False diff --git a/torch_uncertainty/metrics/calibration.py b/torch_uncertainty/metrics/classification/calibration.py similarity index 100% rename from torch_uncertainty/metrics/calibration.py rename to torch_uncertainty/metrics/classification/calibration.py diff --git a/torch_uncertainty/metrics/disagreement.py b/torch_uncertainty/metrics/classification/disagreement.py similarity index 100% rename from torch_uncertainty/metrics/disagreement.py rename to torch_uncertainty/metrics/classification/disagreement.py diff --git a/torch_uncertainty/metrics/entropy.py b/torch_uncertainty/metrics/classification/entropy.py similarity index 100% rename from torch_uncertainty/metrics/entropy.py rename to torch_uncertainty/metrics/classification/entropy.py diff --git a/torch_uncertainty/metrics/fpr95.py b/torch_uncertainty/metrics/classification/fpr95.py similarity index 100% rename from torch_uncertainty/metrics/fpr95.py rename to torch_uncertainty/metrics/classification/fpr95.py diff --git a/torch_uncertainty/metrics/grouping_loss.py b/torch_uncertainty/metrics/classification/grouping_loss.py similarity index 100% rename from torch_uncertainty/metrics/grouping_loss.py rename to torch_uncertainty/metrics/classification/grouping_loss.py diff --git a/torch_uncertainty/metrics/mean_iou.py b/torch_uncertainty/metrics/classification/mean_iou.py similarity index 100% rename from torch_uncertainty/metrics/mean_iou.py rename to torch_uncertainty/metrics/classification/mean_iou.py diff --git a/torch_uncertainty/metrics/mutual_information.py b/torch_uncertainty/metrics/classification/mutual_information.py similarity index 100% rename from torch_uncertainty/metrics/mutual_information.py rename to torch_uncertainty/metrics/classification/mutual_information.py diff --git a/torch_uncertainty/metrics/nll.py b/torch_uncertainty/metrics/classification/nll.py similarity index 75% rename from torch_uncertainty/metrics/nll.py rename to torch_uncertainty/metrics/classification/nll.py index 98df27e8..6a08f6d2 100644 --- a/torch_uncertainty/metrics/nll.py +++ b/torch_uncertainty/metrics/classification/nll.py @@ -2,7 +2,7 @@ import torch import torch.nn.functional as F -from torch import Tensor, distributions +from torch import Tensor from torchmetrics import Metric from torchmetrics.utilities.data import dim_zero_cat @@ -92,29 +92,3 @@ def compute(self) -> Tensor: return values.sum(dim=-1) / self.total # reduction is None or "none" return values - - -class DistributionNLL(CategoricalNLL): - def update(self, dists: distributions.Distribution, target: Tensor) -> None: - """Update state with the predicted distributions and the targets. - - Args: - dists (torch.distributions.Distribution): Predicted distributions. - target (Tensor): Ground truth labels. - """ - if self.reduction is None or self.reduction == "none": - self.values.append(-dists.log_prob(target)) - else: - self.values += -dists.log_prob(target).sum() - self.total += target.size(0) - - def compute(self) -> Tensor: - """Computes NLL based on inputs passed in to ``update`` previously.""" - values = dim_zero_cat(self.values) - - if self.reduction == "sum": - return values.sum(dim=-1) - if self.reduction == "mean": - return values.sum(dim=-1) / self.total - # reduction is None or "none" - return values diff --git a/torch_uncertainty/metrics/sparsification.py b/torch_uncertainty/metrics/classification/sparsification.py similarity index 100% rename from torch_uncertainty/metrics/sparsification.py rename to torch_uncertainty/metrics/classification/sparsification.py diff --git a/torch_uncertainty/metrics/variation_ratio.py b/torch_uncertainty/metrics/classification/variation_ratio.py similarity index 100% rename from torch_uncertainty/metrics/variation_ratio.py rename to torch_uncertainty/metrics/classification/variation_ratio.py diff --git a/torch_uncertainty/metrics/regression/nll.py b/torch_uncertainty/metrics/regression/nll.py new file mode 100644 index 00000000..9b2f9c3a --- /dev/null +++ b/torch_uncertainty/metrics/regression/nll.py @@ -0,0 +1,30 @@ +from torch import Tensor, distributions +from torchmetrics.utilities.data import dim_zero_cat + +from torch_uncertainty.metrics import CategoricalNLL + + +class DistributionNLL(CategoricalNLL): + def update(self, dist: distributions.Distribution, target: Tensor) -> None: + """Update state with the predicted distributions and the targets. + + Args: + dist (torch.distributions.Distribution): Predicted distributions. + target (Tensor): Ground truth labels. + """ + if self.reduction is None or self.reduction == "none": + self.values.append(-dist.log_prob(target)) + else: + self.values += -dist.log_prob(target).sum() + self.total += target.size(0) + + def compute(self) -> Tensor: + """Computes NLL based on inputs passed in to ``update`` previously.""" + values = dim_zero_cat(self.values) + + if self.reduction == "sum": + return values.sum(dim=-1) + if self.reduction == "mean": + return values.sum(dim=-1) / self.total + # reduction is None or "none" + return values From 99d1d87967f5e7acd46160929a5aa41a38ebaddf Mon Sep 17 00:00:00 2001 From: Olivier Date: Tue, 26 Mar 2024 20:09:29 +0100 Subject: [PATCH 137/148] :sparkles: Add monocular depth estimation metrics --- .../metrics/regression/__init__.py | 9 +++ torch_uncertainty/metrics/regression/log10.py | 38 +++++++++++ .../metrics/regression/relative_error.py | 68 +++++++++++++++++++ torch_uncertainty/metrics/regression/silog.py | 38 +++++++++++ .../metrics/regression/threshold_accuracy.py | 43 ++++++++++++ 5 files changed, 196 insertions(+) create mode 100644 torch_uncertainty/metrics/regression/__init__.py create mode 100644 torch_uncertainty/metrics/regression/log10.py create mode 100644 torch_uncertainty/metrics/regression/relative_error.py create mode 100644 torch_uncertainty/metrics/regression/silog.py create mode 100644 torch_uncertainty/metrics/regression/threshold_accuracy.py diff --git a/torch_uncertainty/metrics/regression/__init__.py b/torch_uncertainty/metrics/regression/__init__.py new file mode 100644 index 00000000..7e917aa7 --- /dev/null +++ b/torch_uncertainty/metrics/regression/__init__.py @@ -0,0 +1,9 @@ +# ruff: noqa: F401 +from .log10 import Log10 +from .nll import DistributionNLL +from .relative_error import ( + MeanGTRelativeAbsoluteError, + MeanGTRelativeSquaredError, +) +from .silog import SILog +from .threshold_accuracy import ThresholdAccuracy diff --git a/torch_uncertainty/metrics/regression/log10.py b/torch_uncertainty/metrics/regression/log10.py new file mode 100644 index 00000000..1c83cd39 --- /dev/null +++ b/torch_uncertainty/metrics/regression/log10.py @@ -0,0 +1,38 @@ +import torch +from torch import Tensor +from torchmetrics import Metric +from torchmetrics.utilities.data import dim_zero_cat + + +class Log10(Metric): + def __init__(self, **kwargs) -> None: + r"""The Log10 metric. + + .. math:: \text{Log10} = \frac{1}{N} \sum_{i=1}^{N} \log_{10}(y_i) - \log_{10}(\hat{y_i}) + + where :math:`N` is the batch size, :math:`y_i` is a tensor of target values and :math:`\hat{y_i}` is a tensor of prediction. + + Inputs: + - :attr:`preds`: :math:`(N)` + - :attr:`target`: :math:`(N)` + + where :math:`N` is the batch size. + + Args: + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + """ + super().__init__(**kwargs) + self.add_state( + "values", default=torch.tensor(0.0), dist_reduce_fx="sum" + ) + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + + def update(self, pred: Tensor, target: Tensor) -> None: + """Update state with predictions and targets.""" + self.values += torch.sum(pred.log10() - target.log10()) + self.total += target.size(0) + + def compute(self) -> Tensor: + """Compute the Log10 metric.""" + values = dim_zero_cat(self.values) + return values / self.total diff --git a/torch_uncertainty/metrics/regression/relative_error.py b/torch_uncertainty/metrics/regression/relative_error.py new file mode 100644 index 00000000..27ac1eb4 --- /dev/null +++ b/torch_uncertainty/metrics/regression/relative_error.py @@ -0,0 +1,68 @@ +import torch +from torch import Tensor +from torchmetrics import MeanAbsoluteError, MeanSquaredError + + +class MeanGTRelativeAbsoluteError(MeanAbsoluteError): + def __init__(self, **kwargs) -> None: + r"""`Compute Mean Absolute Error relative to the Ground Truth`_ (MAErel or ARE). + + .. math:: \text{MAErel} = \frac{1}{N}\sum_i^N \frac{| y_i - \hat{y_i} |}{y_i} + + where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions. + + As input to ``forward`` and ``update`` the metric accepts the following input: + + - ``preds`` (:class:`~torch.Tensor`): Predictions from model + - ``target`` (:class:`~torch.Tensor`): Ground truth values + + As output of ``forward`` and ``compute`` the metric returns the following output: + + - ``rel_mean_absolute_error`` (:class:`~torch.Tensor`): A tensor with the + relative mean absolute error over the state + + Args: + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Reference: + As in e.g. From big to small: Multi-scale local planar guidance for monocular depth estimation + """ + super().__init__(**kwargs) + + def update(self, pred: Tensor, target: Tensor) -> None: + """Update state with predictions and targets.""" + return super().update(pred / target, torch.ones_like(target)) + + +class MeanGTRelativeSquaredError(MeanSquaredError): + def __init__( + self, squared: bool = True, num_outputs: int = 1, **kwargs + ) -> None: + r"""Compute `mean squared error relative to the Ground Truth`_ (MSErel or SRE). + + .. math:: \text{MSErel} = \frac{1}{N}\sum_i^N \frac{(y_i - \hat{y_i})^2}{y_i} + + Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions. + + As input to ``forward`` and ``update`` the metric accepts the following input: + + - ``preds`` (:class:`~torch.Tensor`): Predictions from model + - ``target`` (:class:`~torch.Tensor`): Ground truth values + + As output of ``forward`` and ``compute`` the metric returns the following output: + + - ``rel_mean_squared_error`` (:class:`~torch.Tensor`): A tensor with the relative mean squared error + + Args: + squared: If True returns MSErel value, if False returns RMSErel value. + num_outputs: Number of outputs in multioutput setting + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Reference: + As in e.g. From big to small: Multi-scale local planar guidance for monocular depth estimation + """ + super().__init__(squared, num_outputs, **kwargs) + + def update(self, pred: Tensor, target: Tensor) -> None: + """Update state with predictions and targets.""" + return super().update(pred / torch.sqrt(target), torch.sqrt(target)) diff --git a/torch_uncertainty/metrics/regression/silog.py b/torch_uncertainty/metrics/regression/silog.py new file mode 100644 index 00000000..c6bea3b7 --- /dev/null +++ b/torch_uncertainty/metrics/regression/silog.py @@ -0,0 +1,38 @@ +from typing import Any + +import torch +from torch import Tensor +from torchmetrics import Metric +from torchmetrics.utilities.data import dim_zero_cat + + +class SILog(Metric): + def __init__(self, **kwargs: Any) -> None: + r"""The Scale-Invariant Logarithmic Loss metric. + + .. math:: \text{SILog} = \frac{1}{N} \sum_{i=1}^{N} \left(\log(y_i) - \log(\hat{y_i})\right)^2 - \left(\frac{1}{N} \sum_{i=1}^{N} \log(y_i) \right)^2 + + where :math:`N` is the batch size, :math:`y_i` is a tensor of target values and :math:`\hat{y_i}` is a tensor of prediction. + + Inputs: + - :attr:`pred`: :math:`(N)` + - :attr:`target`: :math:`(N)` + + where :math:`N` is the batch size. + + Reference: + Depth Map Prediction from a Single Image using a Multi-Scale Deep Network. + """ + super().__init__(**kwargs) + self.add_state("log_dists", default=[], dist_reduce_fx="cat") + + def update(self, pred: Tensor, target: Tensor) -> None: + """Update state with predictions and targets.""" + self.log_dists.append(torch.flatten(pred.log() - target.log())) + + def compute(self) -> Tensor: + """Compute the Scale-Invariant Logarithmic Loss.""" + log_dists = dim_zero_cat(self.log_dists) + return torch.mean(log_dists**2) - torch.sum(log_dists) ** 2 / ( + log_dists.size(0) * log_dists.size(0) + ) diff --git a/torch_uncertainty/metrics/regression/threshold_accuracy.py b/torch_uncertainty/metrics/regression/threshold_accuracy.py new file mode 100644 index 00000000..68068ad8 --- /dev/null +++ b/torch_uncertainty/metrics/regression/threshold_accuracy.py @@ -0,0 +1,43 @@ +import torch +from torch import Tensor +from torchmetrics import Metric +from torchmetrics.utilities.data import dim_zero_cat + + +class ThresholdAccuracy(Metric): + def __init__(self, power: int, lmbda: float = 1.25, **kwargs) -> None: + r"""The Threshold Accuracy metric, a.k.a. d1, d2, d3. + + Args: + power: The power to raise the threshold to. Often in [1, 2, 3]. + lmbda: The threshold to compare the max of ratio of predictions + to targets and its inverse to. Defaults to 1.25. + kwargs: Additional arguments to pass to the metric class. + """ + super().__init__(**kwargs) + if power < 0: + raise ValueError( + f"Power must be greater than or equal to 0. Got {power}." + ) + self.power = power + if lmbda < 1: + raise ValueError( + f"Lambda must be greater than or equal to 1. Got {lmbda}." + ) + self.lmbda = lmbda + self.add_state( + "values", default=torch.tensor(0.0), dist_reduce_fx="sum" + ) + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor) -> None: + """Update state with predictions and targets.""" + self.values += torch.sum( + torch.max(preds / target, target / preds) < self.lmbda**self.power + ) + self.total += target.size(0) + + def compute(self) -> Tensor: + """Compute the Threshold Accuracy.""" + values = dim_zero_cat(self.values) + return values / self.total From 6d7ec606f89c9a984d4fd47231ee0a7ae25c879a Mon Sep 17 00:00:00 2001 From: Olivier Date: Tue, 26 Mar 2024 20:09:49 +0100 Subject: [PATCH 138/148] :books: Add metrics to doc --- docs/source/api.rst | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/docs/source/api.rst b/docs/source/api.rst index 9c815767..24abb1ef 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -159,13 +159,18 @@ Metrics AUSE BrierScore + CategoricalNLL CE Disagreement - Entropy - MutualInformation - CategoricalNLL DistributionNLL + Entropy FPR95 + Log10 + MeanGTRelativeAbsoluteError + MeanGTRelativeSquaredError + MutualInformation + SILog + ThresholdAccuracy Losses ------ From fddf8c0fe7f3386dc83779a219c3d94dab77b231 Mon Sep 17 00:00:00 2001 From: Olivier Date: Tue, 26 Mar 2024 20:10:42 +0100 Subject: [PATCH 139/148] :white_check_mark: Refactor metrics' tests and add depth metrics' tests --- tests/_dummies/dataset.py | 2 +- .../datamodules/depth_regression/test_muad.py | 4 +- tests/metrics/classification/__init__.py | 0 .../{ => classification}/test_brier_score.py | 0 .../{ => classification}/test_calibration.py | 53 ++++------ .../{ => classification}/test_disagreement.py | 0 .../{ => classification}/test_entropy.py | 0 .../{ => classification}/test_fpr95.py | 2 +- .../test_grouping_loss.py | 0 .../test_mutual_information.py | 0 .../test_sparsification.py | 0 .../test_variation_ratio.py | 0 tests/metrics/regression/__init__.py | 0 .../test_depth_estimation_metrics.py | 97 +++++++++++++++++++ tests/metrics/{ => regression}/test_nll.py | 0 15 files changed, 119 insertions(+), 39 deletions(-) create mode 100644 tests/metrics/classification/__init__.py rename tests/metrics/{ => classification}/test_brier_score.py (100%) rename tests/metrics/{ => classification}/test_calibration.py (51%) rename tests/metrics/{ => classification}/test_disagreement.py (100%) rename tests/metrics/{ => classification}/test_entropy.py (100%) rename tests/metrics/{ => classification}/test_fpr95.py (93%) rename tests/metrics/{ => classification}/test_grouping_loss.py (100%) rename tests/metrics/{ => classification}/test_mutual_information.py (100%) rename tests/metrics/{ => classification}/test_sparsification.py (100%) rename tests/metrics/{ => classification}/test_variation_ratio.py (100%) create mode 100644 tests/metrics/regression/__init__.py create mode 100644 tests/metrics/regression/test_depth_estimation_metrics.py rename tests/metrics/{ => regression}/test_nll.py (100%) diff --git a/tests/_dummies/dataset.py b/tests/_dummies/dataset.py index f14970a4..3e5e4024 100644 --- a/tests/_dummies/dataset.py +++ b/tests/_dummies/dataset.py @@ -214,7 +214,7 @@ def __len__(self) -> int: return len(self.data) -class DummyDepthRegressionDataset(Dataset): +class DummyDepthDataset(Dataset): def __init__( self, root: Path, diff --git a/tests/datamodules/depth_regression/test_muad.py b/tests/datamodules/depth_regression/test_muad.py index db24dd65..58023861 100644 --- a/tests/datamodules/depth_regression/test_muad.py +++ b/tests/datamodules/depth_regression/test_muad.py @@ -1,6 +1,6 @@ import pytest -from tests._dummies.dataset import DummyDepthRegressionDataset +from tests._dummies.dataset import DummyDepthDataset from torch_uncertainty.datamodules.depth_regression import MUADDataModule from torch_uncertainty.datasets import MUAD @@ -13,7 +13,7 @@ def test_camvid_main(self): assert dm.dataset == MUAD - dm.dataset = DummyDepthRegressionDataset + dm.dataset = DummyDepthDataset dm.prepare_data() dm.setup() diff --git a/tests/metrics/classification/__init__.py b/tests/metrics/classification/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/metrics/test_brier_score.py b/tests/metrics/classification/test_brier_score.py similarity index 100% rename from tests/metrics/test_brier_score.py rename to tests/metrics/classification/test_brier_score.py diff --git a/tests/metrics/test_calibration.py b/tests/metrics/classification/test_calibration.py similarity index 51% rename from tests/metrics/test_calibration.py rename to tests/metrics/classification/test_calibration.py index 8f807f28..fb3c2035 100644 --- a/tests/metrics/test_calibration.py +++ b/tests/metrics/classification/test_calibration.py @@ -5,41 +5,15 @@ from torch_uncertainty.metrics import CE -@pytest.fixture -def preds_binary() -> torch.Tensor: - return torch.as_tensor([0.25, 0.25, 0.55, 0.75, 0.75]) - - -@pytest.fixture -def targets_binary() -> torch.Tensor: - return torch.as_tensor([0, 0, 1, 1, 1]) - - -@pytest.fixture -def preds_multiclass() -> torch.Tensor: - return torch.as_tensor( - [ - [0.25, 0.20, 0.55], - [0.55, 0.05, 0.40], - [0.10, 0.30, 0.60], - [0.90, 0.05, 0.05], - ] - ) - - -@pytest.fixture -def targets_multiclass() -> torch.Tensor: - return torch.as_tensor([0, 1, 2, 0]) - - class TestCE: """Testing the CE metric class.""" - def test_plot_binary( - self, preds_binary: torch.Tensor, targets_binary: torch.Tensor - ) -> None: + def test_plot_binary(self) -> None: metric = CE(task="binary", n_bins=2, norm="l1") - metric.update(preds_binary, targets_binary) + metric.update( + torch.as_tensor([0.25, 0.25, 0.55, 0.75, 0.75]), + torch.as_tensor([0, 0, 1, 1, 1]), + ) fig, ax = metric.plot() assert isinstance(fig, plt.Figure) assert isinstance(ax, plt.Axes) @@ -48,10 +22,20 @@ def test_plot_binary( plt.close(fig) def test_plot_multiclass( - self, preds_multiclass: torch.Tensor, targets_multiclass: torch.Tensor + self, ) -> None: metric = CE(task="multiclass", n_bins=3, norm="l1", num_classes=3) - metric.update(preds_multiclass, targets_multiclass) + metric.update( + torch.as_tensor( + [ + [0.25, 0.20, 0.55], + [0.55, 0.05, 0.40], + [0.10, 0.30, 0.60], + [0.90, 0.05, 0.05], + ] + ), + torch.as_tensor([0, 1, 2, 0]), + ) fig, ax = metric.plot() assert isinstance(fig, plt.Figure) assert isinstance(ax, plt.Axes) @@ -59,10 +43,9 @@ def test_plot_multiclass( assert ax.get_ylabel() == "Success Rate (%)" plt.close(fig) - def test_bad_task_argument(self) -> None: + def test_errors(self) -> None: with pytest.raises(ValueError): _ = CE(task="geometric_mean") - def test_bad_num_classes_argument(self) -> None: with pytest.raises(ValueError): _ = CE(task="multiclass", num_classes=1.5) diff --git a/tests/metrics/test_disagreement.py b/tests/metrics/classification/test_disagreement.py similarity index 100% rename from tests/metrics/test_disagreement.py rename to tests/metrics/classification/test_disagreement.py diff --git a/tests/metrics/test_entropy.py b/tests/metrics/classification/test_entropy.py similarity index 100% rename from tests/metrics/test_entropy.py rename to tests/metrics/classification/test_entropy.py diff --git a/tests/metrics/test_fpr95.py b/tests/metrics/classification/test_fpr95.py similarity index 93% rename from tests/metrics/test_fpr95.py rename to tests/metrics/classification/test_fpr95.py index 46b40dc8..e94e785c 100644 --- a/tests/metrics/test_fpr95.py +++ b/tests/metrics/classification/test_fpr95.py @@ -1,7 +1,7 @@ import pytest import torch -from torch_uncertainty.metrics.fpr95 import FPR95, FPRx +from torch_uncertainty.metrics.classification.fpr95 import FPR95, FPRx class TestFPR95: diff --git a/tests/metrics/test_grouping_loss.py b/tests/metrics/classification/test_grouping_loss.py similarity index 100% rename from tests/metrics/test_grouping_loss.py rename to tests/metrics/classification/test_grouping_loss.py diff --git a/tests/metrics/test_mutual_information.py b/tests/metrics/classification/test_mutual_information.py similarity index 100% rename from tests/metrics/test_mutual_information.py rename to tests/metrics/classification/test_mutual_information.py diff --git a/tests/metrics/test_sparsification.py b/tests/metrics/classification/test_sparsification.py similarity index 100% rename from tests/metrics/test_sparsification.py rename to tests/metrics/classification/test_sparsification.py diff --git a/tests/metrics/test_variation_ratio.py b/tests/metrics/classification/test_variation_ratio.py similarity index 100% rename from tests/metrics/test_variation_ratio.py rename to tests/metrics/classification/test_variation_ratio.py diff --git a/tests/metrics/regression/__init__.py b/tests/metrics/regression/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/metrics/regression/test_depth_estimation_metrics.py b/tests/metrics/regression/test_depth_estimation_metrics.py new file mode 100644 index 00000000..fd0d3ccf --- /dev/null +++ b/tests/metrics/regression/test_depth_estimation_metrics.py @@ -0,0 +1,97 @@ +import pytest +import torch + +from torch_uncertainty.metrics import ( + Log10, + MeanGTRelativeAbsoluteError, + MeanGTRelativeSquaredError, + SILog, + ThresholdAccuracy, +) + + +class TestLog10: + """Testing the Log10 metric.""" + + def test_main(self): + metric = Log10() + preds = torch.rand((10, 2)).double() + targets = torch.rand((10, 2)).double() + metric.update(preds[:, 0], targets[:, 0]) + metric.update(preds[:, 1], targets[:, 1]) + assert torch.mean( + preds.log10().flatten() - targets.log10().flatten() + ) == pytest.approx(metric.compute()) + + +class TestMeanGTRelativeAbsoluteError: + """Testing the MeanGTRelativeAbsoluteError metric.""" + + def test_main(self): + metric = MeanGTRelativeAbsoluteError() + preds = torch.rand((10, 2)) + targets = torch.rand((10, 2)) + metric.update(preds[:, 0], targets[:, 0]) + metric.update(preds[:, 1], targets[:, 1]) + assert (torch.abs(preds - targets) / targets).mean() == pytest.approx( + metric.compute() + ) + + +class TestMeanGTRelativeSquaredError: + """Testing the MeanGTRelativeSquaredError metric.""" + + def test_main(self): + metric = MeanGTRelativeSquaredError() + preds = torch.rand((10, 2)) + targets = torch.rand((10, 2)) + metric.update(preds[:, 0], targets[:, 0]) + metric.update(preds[:, 1], targets[:, 1]) + assert torch.flatten( + (preds - targets) ** 2 / targets + ).mean() == pytest.approx(metric.compute()) + + +class TestSILog: + """Testing the SILog metric.""" + + def test_main(self): + metric = SILog() + preds = torch.rand((10, 2)).double() + targets = torch.rand((10, 2)).double() + metric.update(preds[:, 0], targets[:, 0]) + metric.update(preds[:, 1], targets[:, 1]) + mean_log_dists = torch.mean( + targets.flatten().log() - preds.flatten().log() + ) + assert torch.mean( + (preds.flatten().log() - targets.flatten().log() + mean_log_dists) + ** 2 + ) == pytest.approx(metric.compute()) + + +class TestThresholdAccuracy: + """Testing the ThresholdAccuracy metric.""" + + def test_main(self): + metric = ThresholdAccuracy(power=1, lmbda=1.25) + preds = torch.ones((10, 2)) + targets = torch.ones((10, 2)) * 1.3 + metric.update(preds[:, 0], targets[:, 0]) + metric.update(preds[:, 1], targets[:, 1]) + assert metric.compute() == 0.0 + + metric = ThresholdAccuracy(power=1, lmbda=1.25) + preds = torch.cat( + [torch.ones((10, 2)) * 1.2, torch.ones((10, 2))], dim=0 + ) + targets = torch.ones((20, 2)) * 1.3 + metric.update(preds[:, 0], targets[:, 0]) + metric.update(preds[:, 1], targets[:, 1]) + assert metric.compute() == 0.5 + + def test_error(self): + with pytest.raises(ValueError, match="Power must be"): + ThresholdAccuracy(power=-1) + with pytest.raises(ValueError, match="Lambda must be"): + ThresholdAccuracy(power=1, lmbda=0.5) diff --git a/tests/metrics/test_nll.py b/tests/metrics/regression/test_nll.py similarity index 100% rename from tests/metrics/test_nll.py rename to tests/metrics/regression/test_nll.py From c20e7635d464d8f4efb1ffa354a369c378c5b893 Mon Sep 17 00:00:00 2001 From: Olivier Date: Tue, 26 Mar 2024 20:31:04 +0100 Subject: [PATCH 140/148] :sparkles: Add MSELog --- .../test_depth_estimation_metrics.py | 15 ++++++++ torch_uncertainty/metrics/__init__.py | 1 + .../metrics/regression/__init__.py | 1 + .../metrics/regression/mse_log.py | 34 +++++++++++++++++++ 4 files changed, 51 insertions(+) create mode 100644 torch_uncertainty/metrics/regression/mse_log.py diff --git a/tests/metrics/regression/test_depth_estimation_metrics.py b/tests/metrics/regression/test_depth_estimation_metrics.py index fd0d3ccf..0c1fbea0 100644 --- a/tests/metrics/regression/test_depth_estimation_metrics.py +++ b/tests/metrics/regression/test_depth_estimation_metrics.py @@ -5,6 +5,7 @@ Log10, MeanGTRelativeAbsoluteError, MeanGTRelativeSquaredError, + MeanSquaredLogError, SILog, ThresholdAccuracy, ) @@ -95,3 +96,17 @@ def test_error(self): ThresholdAccuracy(power=-1) with pytest.raises(ValueError, match="Lambda must be"): ThresholdAccuracy(power=1, lmbda=0.5) + + +class TestMeanSquaredLogError: + """Testing the MeanSquaredLogError metric.""" + + def test_main(self): + metric = MeanSquaredLogError() + preds = torch.rand((10, 2)).double() + targets = torch.rand((10, 2)).double() + metric.update(preds[:, 0], targets[:, 0]) + metric.update(preds[:, 1], targets[:, 1]) + assert torch.mean( + (preds.log() - targets.log()).flatten() ** 2 + ) == pytest.approx(metric.compute()) diff --git a/torch_uncertainty/metrics/__init__.py b/torch_uncertainty/metrics/__init__.py index 5387b56e..207d0c9b 100644 --- a/torch_uncertainty/metrics/__init__.py +++ b/torch_uncertainty/metrics/__init__.py @@ -17,6 +17,7 @@ Log10, MeanGTRelativeAbsoluteError, MeanGTRelativeSquaredError, + MeanSquaredLogError, SILog, ThresholdAccuracy, ) diff --git a/torch_uncertainty/metrics/regression/__init__.py b/torch_uncertainty/metrics/regression/__init__.py index 7e917aa7..50f26c74 100644 --- a/torch_uncertainty/metrics/regression/__init__.py +++ b/torch_uncertainty/metrics/regression/__init__.py @@ -1,5 +1,6 @@ # ruff: noqa: F401 from .log10 import Log10 +from .mse_log import MeanSquaredLogError from .nll import DistributionNLL from .relative_error import ( MeanGTRelativeAbsoluteError, diff --git a/torch_uncertainty/metrics/regression/mse_log.py b/torch_uncertainty/metrics/regression/mse_log.py new file mode 100644 index 00000000..caae3186 --- /dev/null +++ b/torch_uncertainty/metrics/regression/mse_log.py @@ -0,0 +1,34 @@ +from torch import Tensor +from torchmetrics import MeanSquaredError + + +class MeanSquaredLogError(MeanSquaredError): + def __init__(self, squared: bool = True, **kwargs) -> None: + r"""`Compute MeanSquaredLogError`_ (MSELog). + + .. math:: \text{MSELog} = \frac{1}{N}\sum_i^N (\log \hat{y_i} - \log y_i)^2 + + where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions. + + As input to ``forward`` and ``update`` the metric accepts the following input: + + - ``preds`` (:class:`~torch.Tensor`): Predictions from model + - ``target`` (:class:`~torch.Tensor`): Ground truth values + + As output of ``forward`` and ``compute`` the metric returns the following output: + + - ``mse_log`` (:class:`~torch.Tensor`): A tensor with the + relative mean absolute error over the state + + Args: + squared: If True returns MSELog value, if False returns EMSELog value. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Reference: + As in e.g. From big to small: Multi-scale local planar guidance for monocular depth estimation + """ + super().__init__(squared, **kwargs) + + def update(self, pred: Tensor, target: Tensor) -> None: + """Update state with predictions and targets.""" + return super().update(pred.log(), target.log()) From 36c836414ed0205b7c8319c0b37854bd2eeb5458 Mon Sep 17 00:00:00 2001 From: Olivier Date: Tue, 26 Mar 2024 20:35:35 +0100 Subject: [PATCH 141/148] :bug: Also log prob metrics in reg routine --- torch_uncertainty/routines/regression.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/torch_uncertainty/routines/regression.py b/torch_uncertainty/routines/regression.py index 3560173a..24cbb0a4 100644 --- a/torch_uncertainty/routines/regression.py +++ b/torch_uncertainty/routines/regression.py @@ -167,13 +167,17 @@ def validation_step( preds = preds.mean(dim=1) self.val_metrics.update(preds, targets) - if self.probabilistic: self.val_prob_metrics.update(mixture, targets) def on_validation_epoch_end(self) -> None: self.log_dict(self.val_metrics.compute()) self.val_metrics.reset() + if self.probabilistic: + self.log_dict( + self.val_prob_metrics.compute(), + ) + self.val_prob_metrics.reset() def test_step( self, @@ -206,9 +210,8 @@ def test_step( preds = preds.mean(dim=1) self.test_metrics.update(preds, targets) - if self.probabilistic: - self.val_prob_metrics.update(mixture, targets) + self.test_prob_metrics.update(mixture, targets) def on_test_epoch_end(self) -> None: self.log_dict( @@ -216,6 +219,12 @@ def on_test_epoch_end(self) -> None: ) self.test_metrics.reset() + if self.probabilistic: + self.log_dict( + self.test_prob_metrics.compute(), + ) + self.test_prob_metrics.reset() + def _regression_routine_checks(num_estimators: int, output_dim: int) -> None: if num_estimators < 1: From 0fd00ce363e1ddc73d43d1e98f2a3d25b29d2dc4 Mon Sep 17 00:00:00 2001 From: Olivier Date: Tue, 26 Mar 2024 21:35:54 +0100 Subject: [PATCH 142/148] :shirt: Add capital letters to metrics' logs & misc --- experiments/readme.md | 10 +++++++++- experiments/regression/uci_datasets/readme.md | 17 +++++++++++++++++ torch_uncertainty/datasets/muad.py | 1 - torch_uncertainty/routines/classification.py | 18 +++++++++--------- 4 files changed, 35 insertions(+), 11 deletions(-) create mode 100644 experiments/regression/uci_datasets/readme.md diff --git a/experiments/readme.md b/experiments/readme.md index 19d6dd0a..0035c5a7 100644 --- a/experiments/readme.md +++ b/experiments/readme.md @@ -1,11 +1,19 @@ # Experiments -Torch-Uncertainty proposes various benchmarks to evaluate the uncertainty estimation methods. +Torch-Uncertainty proposes various benchmarks to evaluate uncertainty quantification methods. ## Classification *Work in progress* +## Segmentation + +*Work in progress* + ## Regression *Work in progress* + +## Monocular Depth Estimation + +*Work in progress* diff --git a/experiments/regression/uci_datasets/readme.md b/experiments/regression/uci_datasets/readme.md new file mode 100644 index 00000000..3e0ec7b0 --- /dev/null +++ b/experiments/regression/uci_datasets/readme.md @@ -0,0 +1,17 @@ +# UCI Regression - Benchmark + +This folder contains the code to train models on the UCI regression datasets. The task is to predict (a) continuous target variable(s). + +Three experiments are provided: + +```bash +python mlp.py fit --config configs/pw_mlp_kin8nm.yaml +``` + +```bash +python mlp.py fit --config configs/gaussian_mlp_kin8nm.yaml +``` + +```bash +python mlp.py fit --config configs/laplace_mlp_kin8nm.yaml +``` diff --git a/torch_uncertainty/datasets/muad.py b/torch_uncertainty/datasets/muad.py index 9c51a9f6..ffe842e8 100644 --- a/torch_uncertainty/datasets/muad.py +++ b/torch_uncertainty/datasets/muad.py @@ -38,7 +38,6 @@ class MUAD(VisionDataset): } targets: list[Path] = [] - # TODO: Add depth regression mode def __init__( self, root: str | Path, diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index b8aa7bf1..5c15c2de 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -145,7 +145,7 @@ def __init__( cls_metrics = MetricCollection( { "acc": Accuracy(task="binary"), - "ece": CE(task="binary"), + "ECE": CE(task="binary"), "brier": BrierScore(num_classes=1), }, compute_groups=False, @@ -153,11 +153,11 @@ def __init__( else: cls_metrics = MetricCollection( { - "nll": CategoricalNLL(), + "NLL": CategoricalNLL(), "acc": Accuracy( task="multiclass", num_classes=self.num_classes ), - "ece": CE(task="multiclass", num_classes=self.num_classes), + "ECE": CE(task="multiclass", num_classes=self.num_classes), "brier": BrierScore(num_classes=self.num_classes), }, compute_groups=False, @@ -174,11 +174,11 @@ def __init__( if self.eval_ood: ood_metrics = MetricCollection( { - "fpr95": FPR95(pos_label=1), - "auroc": BinaryAUROC(), - "aupr": BinaryAveragePrecision(), + "FPR95": FPR95(pos_label=1), + "AUROC": BinaryAUROC(), + "AUPR": BinaryAveragePrecision(), }, - compute_groups=[["auroc", "aupr"], ["fpr95"]], + compute_groups=[["AUROC", "AUPR"], ["FPR95"]], ) self.test_ood_metrics = ood_metrics.clone(prefix="ood/") self.test_entropy_ood = Entropy() @@ -216,7 +216,7 @@ def __init__( ens_metrics = MetricCollection( { "disagreement": Disagreement(), - "mi": MutualInformation(), + "MI": MutualInformation(), "entropy": Entropy(), } ) @@ -537,7 +537,7 @@ def on_test_epoch_end(self) -> None: if isinstance(self.logger, Logger) and self.log_plots: self.logger.experiment.add_figure( - "Calibration Plot", self.test_cls_metrics["ece"].plot()[0] + "Calibration Plot", self.test_cls_metrics["ECE"].plot()[0] ) # plot histograms of logits and likelihoods From 5fe7361d293c29c0731243f3f3f38e04d283017b Mon Sep 17 00:00:00 2001 From: alafage Date: Wed, 27 Mar 2024 11:07:56 +0100 Subject: [PATCH 143/148] :zap: Update loss definition in config files --- experiments/classification/cifar10/configs/resnet.yaml | 2 +- .../classification/cifar10/configs/resnet18/batched.yaml | 3 +-- .../classification/cifar10/configs/resnet18/masked.yaml | 3 +-- experiments/classification/cifar10/configs/resnet18/mimo.yaml | 3 +-- .../classification/cifar10/configs/resnet18/packed.yaml | 3 +-- .../classification/cifar10/configs/resnet18/standard.yaml | 3 +-- .../classification/cifar10/configs/resnet50/batched.yaml | 3 +-- .../classification/cifar10/configs/resnet50/masked.yaml | 3 +-- experiments/classification/cifar10/configs/resnet50/mimo.yaml | 3 +-- .../classification/cifar10/configs/resnet50/packed.yaml | 3 +-- .../classification/cifar10/configs/resnet50/standard.yaml | 3 +-- .../classification/cifar10/configs/wideresnet28x10.yaml | 4 +--- .../cifar10/configs/wideresnet28x10/batched.yaml | 3 +-- .../cifar10/configs/wideresnet28x10/masked.yaml | 3 +-- .../classification/cifar10/configs/wideresnet28x10/mimo.yaml | 3 +-- .../cifar10/configs/wideresnet28x10/packed.yaml | 3 +-- .../cifar10/configs/wideresnet28x10/standard.yaml | 3 +-- experiments/classification/cifar100/configs/resnet.yaml | 3 +-- .../classification/cifar100/configs/resnet18/batched.yaml | 3 +-- .../classification/cifar100/configs/resnet18/masked.yaml | 3 +-- .../classification/cifar100/configs/resnet18/mimo.yaml | 3 +-- .../classification/cifar100/configs/resnet18/packed.yaml | 3 +-- .../classification/cifar100/configs/resnet18/standard.yaml | 3 +-- .../classification/cifar100/configs/resnet50/batched.yaml | 3 +-- .../classification/cifar100/configs/resnet50/masked.yaml | 3 +-- .../classification/cifar100/configs/resnet50/mimo.yaml | 3 +-- .../classification/cifar100/configs/resnet50/packed.yaml | 3 +-- .../classification/cifar100/configs/resnet50/standard.yaml | 3 +-- experiments/segmentation/camvid/configs/segformer.yaml | 3 +-- experiments/segmentation/cityscapes/configs/segformer.yaml | 3 +-- experiments/segmentation/muad/configs/segformer.yaml | 3 +-- .../{depth_regression => depth_estimation}/__init__.py | 0 .../{depth_regression => depth_estimation}/test_muad.py | 2 +- .../{depth_regression => depth_estimation}/__init__.py | 0 .../{depth_regression => depth_estimation}/muad.py | 0 35 files changed, 32 insertions(+), 63 deletions(-) rename tests/datamodules/{depth_regression => depth_estimation}/__init__.py (100%) rename tests/datamodules/{depth_regression => depth_estimation}/test_muad.py (92%) rename torch_uncertainty/datamodules/{depth_regression => depth_estimation}/__init__.py (100%) rename torch_uncertainty/datamodules/{depth_regression => depth_estimation}/muad.py (100%) diff --git a/experiments/classification/cifar10/configs/resnet.yaml b/experiments/classification/cifar10/configs/resnet.yaml index dbbe41e9..4585471b 100644 --- a/experiments/classification/cifar10/configs/resnet.yaml +++ b/experiments/classification/cifar10/configs/resnet.yaml @@ -8,7 +8,7 @@ trainer: logger: class_path: lightning.pytorch.loggers.TensorBoardLogger init_args: - save_dir: logs/ + save_dir: logs/resnet default_hp_metric: false callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint diff --git a/experiments/classification/cifar10/configs/resnet18/batched.yaml b/experiments/classification/cifar10/configs/resnet18/batched.yaml index 59369531..e5534a98 100644 --- a/experiments/classification/cifar10/configs/resnet18/batched.yaml +++ b/experiments/classification/cifar10/configs/resnet18/batched.yaml @@ -29,8 +29,7 @@ trainer: model: num_classes: 10 in_channels: 3 - loss: - class_path: torch.nn.CrossEntropyLoss + loss: CrossEntropyLoss version: batched arch: 18 style: cifar diff --git a/experiments/classification/cifar10/configs/resnet18/masked.yaml b/experiments/classification/cifar10/configs/resnet18/masked.yaml index 79aa2fe7..93e32f64 100644 --- a/experiments/classification/cifar10/configs/resnet18/masked.yaml +++ b/experiments/classification/cifar10/configs/resnet18/masked.yaml @@ -29,8 +29,7 @@ trainer: model: num_classes: 10 in_channels: 3 - loss: - class_path: torch.nn.CrossEntropyLoss + loss: CrossEntropyLoss version: masked arch: 18 style: cifar diff --git a/experiments/classification/cifar10/configs/resnet18/mimo.yaml b/experiments/classification/cifar10/configs/resnet18/mimo.yaml index d73cb421..11ba94c6 100644 --- a/experiments/classification/cifar10/configs/resnet18/mimo.yaml +++ b/experiments/classification/cifar10/configs/resnet18/mimo.yaml @@ -29,8 +29,7 @@ trainer: model: num_classes: 10 in_channels: 3 - loss: - class_path: torch.nn.CrossEntropyLoss + loss: CrossEntropyLoss version: mimo arch: 18 style: cifar diff --git a/experiments/classification/cifar10/configs/resnet18/packed.yaml b/experiments/classification/cifar10/configs/resnet18/packed.yaml index e920b354..f0450fda 100644 --- a/experiments/classification/cifar10/configs/resnet18/packed.yaml +++ b/experiments/classification/cifar10/configs/resnet18/packed.yaml @@ -29,8 +29,7 @@ trainer: model: num_classes: 10 in_channels: 3 - loss: - class_path: torch.nn.CrossEntropyLoss + loss: CrossEntropyLoss version: packed arch: 18 style: cifar diff --git a/experiments/classification/cifar10/configs/resnet18/standard.yaml b/experiments/classification/cifar10/configs/resnet18/standard.yaml index 0184abf1..06c57899 100644 --- a/experiments/classification/cifar10/configs/resnet18/standard.yaml +++ b/experiments/classification/cifar10/configs/resnet18/standard.yaml @@ -29,8 +29,7 @@ trainer: model: num_classes: 10 in_channels: 3 - loss: - class_path: torch.nn.CrossEntropyLoss + loss: CrossEntropyLoss version: std arch: 18 style: cifar diff --git a/experiments/classification/cifar10/configs/resnet50/batched.yaml b/experiments/classification/cifar10/configs/resnet50/batched.yaml index 1352eb88..dcc80071 100644 --- a/experiments/classification/cifar10/configs/resnet50/batched.yaml +++ b/experiments/classification/cifar10/configs/resnet50/batched.yaml @@ -29,8 +29,7 @@ trainer: model: num_classes: 10 in_channels: 3 - loss: - class_path: torch.nn.CrossEntropyLoss + loss: CrossEntropyLoss version: batched arch: 50 style: cifar diff --git a/experiments/classification/cifar10/configs/resnet50/masked.yaml b/experiments/classification/cifar10/configs/resnet50/masked.yaml index dea33597..85d5beec 100644 --- a/experiments/classification/cifar10/configs/resnet50/masked.yaml +++ b/experiments/classification/cifar10/configs/resnet50/masked.yaml @@ -29,8 +29,7 @@ trainer: model: num_classes: 10 in_channels: 3 - loss: - class_path: torch.nn.CrossEntropyLoss + loss: CrossEntropyLoss version: masked arch: 50 style: cifar diff --git a/experiments/classification/cifar10/configs/resnet50/mimo.yaml b/experiments/classification/cifar10/configs/resnet50/mimo.yaml index d9575e9f..c5afffc6 100644 --- a/experiments/classification/cifar10/configs/resnet50/mimo.yaml +++ b/experiments/classification/cifar10/configs/resnet50/mimo.yaml @@ -29,8 +29,7 @@ trainer: model: num_classes: 10 in_channels: 3 - loss: - class_path: torch.nn.CrossEntropyLoss + loss: CrossEntropyLoss version: mimo arch: 50 style: cifar diff --git a/experiments/classification/cifar10/configs/resnet50/packed.yaml b/experiments/classification/cifar10/configs/resnet50/packed.yaml index aa4d4e76..da305e2a 100644 --- a/experiments/classification/cifar10/configs/resnet50/packed.yaml +++ b/experiments/classification/cifar10/configs/resnet50/packed.yaml @@ -29,8 +29,7 @@ trainer: model: num_classes: 10 in_channels: 3 - loss: - class_path: torch.nn.CrossEntropyLoss + loss: CrossEntropyLoss version: packed arch: 50 style: cifar diff --git a/experiments/classification/cifar10/configs/resnet50/standard.yaml b/experiments/classification/cifar10/configs/resnet50/standard.yaml index f24e039e..37115d8b 100644 --- a/experiments/classification/cifar10/configs/resnet50/standard.yaml +++ b/experiments/classification/cifar10/configs/resnet50/standard.yaml @@ -29,8 +29,7 @@ trainer: model: num_classes: 10 in_channels: 3 - loss: - class_path: torch.nn.CrossEntropyLoss + loss: CrossEntropyLoss version: std arch: 50 style: cifar diff --git a/experiments/classification/cifar10/configs/wideresnet28x10.yaml b/experiments/classification/cifar10/configs/wideresnet28x10.yaml index 82e91c72..0f2952ab 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10.yaml @@ -10,7 +10,6 @@ trainer: class_path: lightning.pytorch.loggers.TensorBoardLogger init_args: save_dir: logs/wideresnet28x10 - name: standard default_hp_metric: false callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint @@ -29,8 +28,7 @@ trainer: model: num_classes: 10 in_channels: 3 - loss: - class_path: torch.nn.CrossEntropyLoss + loss: CrossEntropyLoss style: cifar data: root: ./data diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/batched.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/batched.yaml index 74e806db..dddfdd90 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10/batched.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10/batched.yaml @@ -29,8 +29,7 @@ trainer: model: num_classes: 10 in_channels: 3 - loss: - class_path: torch.nn.CrossEntropyLoss + loss: CrossEntropyLoss version: batched style: cifar num_estimators: 4 diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/masked.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/masked.yaml index 437b9243..204209c2 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10/masked.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10/masked.yaml @@ -29,8 +29,7 @@ trainer: model: num_classes: 10 in_channels: 3 - loss: - class_path: torch.nn.CrossEntropyLoss + loss: CrossEntropyLoss version: masked style: cifar num_estimators: 4 diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/mimo.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/mimo.yaml index 45bc95cd..5746e940 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10/mimo.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10/mimo.yaml @@ -29,8 +29,7 @@ trainer: model: num_classes: 10 in_channels: 3 - loss: - class_path: torch.nn.CrossEntropyLoss + loss: CrossEntropyLoss version: mimo style: cifar num_estimators: 4 diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/packed.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/packed.yaml index 2fec727e..834ed7b9 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10/packed.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10/packed.yaml @@ -29,8 +29,7 @@ trainer: model: num_classes: 10 in_channels: 3 - loss: - class_path: torch.nn.CrossEntropyLoss + loss: CrossEntropyLoss version: packed style: cifar num_estimators: 4 diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/standard.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/standard.yaml index 15c3a848..e555353e 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10/standard.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10/standard.yaml @@ -29,8 +29,7 @@ trainer: model: num_classes: 10 in_channels: 3 - loss: - class_path: torch.nn.CrossEntropyLoss + loss: CrossEntropyLoss version: std style: cifar data: diff --git a/experiments/classification/cifar100/configs/resnet.yaml b/experiments/classification/cifar100/configs/resnet.yaml index 2ba51027..dbbe41e9 100644 --- a/experiments/classification/cifar100/configs/resnet.yaml +++ b/experiments/classification/cifar100/configs/resnet.yaml @@ -27,8 +27,7 @@ trainer: model: num_classes: 10 in_channels: 3 - loss: - class_path: torch.nn.CrossEntropyLoss + loss: CrossEntropyLoss style: cifar data: root: ./data diff --git a/experiments/classification/cifar100/configs/resnet18/batched.yaml b/experiments/classification/cifar100/configs/resnet18/batched.yaml index 8c8c0d77..847c16ff 100644 --- a/experiments/classification/cifar100/configs/resnet18/batched.yaml +++ b/experiments/classification/cifar100/configs/resnet18/batched.yaml @@ -29,8 +29,7 @@ trainer: model: num_classes: 100 in_channels: 3 - loss: - class_path: torch.nn.CrossEntropyLoss + loss: CrossEntropyLoss version: batched arch: 18 style: cifar diff --git a/experiments/classification/cifar100/configs/resnet18/masked.yaml b/experiments/classification/cifar100/configs/resnet18/masked.yaml index e184c07d..d0c7a4a7 100644 --- a/experiments/classification/cifar100/configs/resnet18/masked.yaml +++ b/experiments/classification/cifar100/configs/resnet18/masked.yaml @@ -29,8 +29,7 @@ trainer: model: num_classes: 100 in_channels: 3 - loss: - class_path: torch.nn.CrossEntropyLoss + loss: CrossEntropyLoss version: masked arch: 18 style: cifar diff --git a/experiments/classification/cifar100/configs/resnet18/mimo.yaml b/experiments/classification/cifar100/configs/resnet18/mimo.yaml index 983dec22..868e4f3e 100644 --- a/experiments/classification/cifar100/configs/resnet18/mimo.yaml +++ b/experiments/classification/cifar100/configs/resnet18/mimo.yaml @@ -29,8 +29,7 @@ trainer: model: num_classes: 100 in_channels: 3 - loss: - class_path: torch.nn.CrossEntropyLoss + loss: CrossEntropyLoss version: mimo arch: 18 style: cifar diff --git a/experiments/classification/cifar100/configs/resnet18/packed.yaml b/experiments/classification/cifar100/configs/resnet18/packed.yaml index 099c93b7..8e3e8cc1 100644 --- a/experiments/classification/cifar100/configs/resnet18/packed.yaml +++ b/experiments/classification/cifar100/configs/resnet18/packed.yaml @@ -29,8 +29,7 @@ trainer: model: num_classes: 100 in_channels: 3 - loss: - class_path: torch.nn.CrossEntropyLoss + loss: CrossEntropyLoss version: packed arch: 18 style: cifar diff --git a/experiments/classification/cifar100/configs/resnet18/standard.yaml b/experiments/classification/cifar100/configs/resnet18/standard.yaml index 8de85cc4..91fa7f08 100644 --- a/experiments/classification/cifar100/configs/resnet18/standard.yaml +++ b/experiments/classification/cifar100/configs/resnet18/standard.yaml @@ -29,8 +29,7 @@ trainer: model: num_classes: 100 in_channels: 3 - loss: - class_path: torch.nn.CrossEntropyLoss + loss: CrossEntropyLoss version: standard arch: 18 style: cifar diff --git a/experiments/classification/cifar100/configs/resnet50/batched.yaml b/experiments/classification/cifar100/configs/resnet50/batched.yaml index 752158e2..ce44c862 100644 --- a/experiments/classification/cifar100/configs/resnet50/batched.yaml +++ b/experiments/classification/cifar100/configs/resnet50/batched.yaml @@ -29,8 +29,7 @@ trainer: model: num_classes: 100 in_channels: 3 - loss: - class_path: torch.nn.CrossEntropyLoss + loss: CrossEntropyLoss version: batched arch: 50 style: cifar diff --git a/experiments/classification/cifar100/configs/resnet50/masked.yaml b/experiments/classification/cifar100/configs/resnet50/masked.yaml index 306dc50c..656de967 100644 --- a/experiments/classification/cifar100/configs/resnet50/masked.yaml +++ b/experiments/classification/cifar100/configs/resnet50/masked.yaml @@ -29,8 +29,7 @@ trainer: model: num_classes: 100 in_channels: 3 - loss: - class_path: torch.nn.CrossEntropyLoss + loss: CrossEntropyLoss version: masked arch: 50 style: cifar diff --git a/experiments/classification/cifar100/configs/resnet50/mimo.yaml b/experiments/classification/cifar100/configs/resnet50/mimo.yaml index a24e9693..80a3152b 100644 --- a/experiments/classification/cifar100/configs/resnet50/mimo.yaml +++ b/experiments/classification/cifar100/configs/resnet50/mimo.yaml @@ -29,8 +29,7 @@ trainer: model: num_classes: 100 in_channels: 3 - loss: - class_path: torch.nn.CrossEntropyLoss + loss: CrossEntropyLoss version: mimo arch: 50 style: cifar diff --git a/experiments/classification/cifar100/configs/resnet50/packed.yaml b/experiments/classification/cifar100/configs/resnet50/packed.yaml index 4c5384fc..5fd414dc 100644 --- a/experiments/classification/cifar100/configs/resnet50/packed.yaml +++ b/experiments/classification/cifar100/configs/resnet50/packed.yaml @@ -29,8 +29,7 @@ trainer: model: num_classes: 100 in_channels: 3 - loss: - class_path: torch.nn.CrossEntropyLoss + loss: CrossEntropyLoss version: packed arch: 50 style: cifar diff --git a/experiments/classification/cifar100/configs/resnet50/standard.yaml b/experiments/classification/cifar100/configs/resnet50/standard.yaml index 5ae26e1f..b78b2476 100644 --- a/experiments/classification/cifar100/configs/resnet50/standard.yaml +++ b/experiments/classification/cifar100/configs/resnet50/standard.yaml @@ -29,8 +29,7 @@ trainer: model: num_classes: 100 in_channels: 3 - loss: - class_path: torch.nn.CrossEntropyLoss + loss: CrossEntropyLoss version: standard arch: 50 style: cifar diff --git a/experiments/segmentation/camvid/configs/segformer.yaml b/experiments/segmentation/camvid/configs/segformer.yaml index 9ec8fca7..7cbb001b 100644 --- a/experiments/segmentation/camvid/configs/segformer.yaml +++ b/experiments/segmentation/camvid/configs/segformer.yaml @@ -6,8 +6,7 @@ trainer: devices: 1 model: num_classes: 12 - loss: - class_path: torch.nn.CrossEntropyLoss + loss: CrossEntropyLoss version: std arch: 0 num_estimators: 1 diff --git a/experiments/segmentation/cityscapes/configs/segformer.yaml b/experiments/segmentation/cityscapes/configs/segformer.yaml index 366450a8..b2abf11e 100644 --- a/experiments/segmentation/cityscapes/configs/segformer.yaml +++ b/experiments/segmentation/cityscapes/configs/segformer.yaml @@ -7,8 +7,7 @@ trainer: max_steps: 160000 model: num_classes: 19 - loss: - class_path: torch.nn.CrossEntropyLoss + loss: CrossEntropyLoss version: std arch: 0 num_estimators: 1 diff --git a/experiments/segmentation/muad/configs/segformer.yaml b/experiments/segmentation/muad/configs/segformer.yaml index 366450a8..b2abf11e 100644 --- a/experiments/segmentation/muad/configs/segformer.yaml +++ b/experiments/segmentation/muad/configs/segformer.yaml @@ -7,8 +7,7 @@ trainer: max_steps: 160000 model: num_classes: 19 - loss: - class_path: torch.nn.CrossEntropyLoss + loss: CrossEntropyLoss version: std arch: 0 num_estimators: 1 diff --git a/tests/datamodules/depth_regression/__init__.py b/tests/datamodules/depth_estimation/__init__.py similarity index 100% rename from tests/datamodules/depth_regression/__init__.py rename to tests/datamodules/depth_estimation/__init__.py diff --git a/tests/datamodules/depth_regression/test_muad.py b/tests/datamodules/depth_estimation/test_muad.py similarity index 92% rename from tests/datamodules/depth_regression/test_muad.py rename to tests/datamodules/depth_estimation/test_muad.py index 58023861..cce2088a 100644 --- a/tests/datamodules/depth_regression/test_muad.py +++ b/tests/datamodules/depth_estimation/test_muad.py @@ -1,7 +1,7 @@ import pytest from tests._dummies.dataset import DummyDepthDataset -from torch_uncertainty.datamodules.depth_regression import MUADDataModule +from torch_uncertainty.datamodules.depth_estimation import MUADDataModule from torch_uncertainty.datasets import MUAD diff --git a/torch_uncertainty/datamodules/depth_regression/__init__.py b/torch_uncertainty/datamodules/depth_estimation/__init__.py similarity index 100% rename from torch_uncertainty/datamodules/depth_regression/__init__.py rename to torch_uncertainty/datamodules/depth_estimation/__init__.py diff --git a/torch_uncertainty/datamodules/depth_regression/muad.py b/torch_uncertainty/datamodules/depth_estimation/muad.py similarity index 100% rename from torch_uncertainty/datamodules/depth_regression/muad.py rename to torch_uncertainty/datamodules/depth_estimation/muad.py From c7a81da07e3e6b158656b567b34a5f789a00ec1f Mon Sep 17 00:00:00 2001 From: alafage Date: Wed, 27 Mar 2024 11:45:38 +0100 Subject: [PATCH 144/148] :zap: Update metric logging names for better legibility --- .../classification/cifar10/configs/resnet.yaml | 4 ++-- .../cifar10/configs/resnet18/batched.yaml | 4 ++-- .../cifar10/configs/resnet18/masked.yaml | 4 ++-- .../cifar10/configs/resnet18/mimo.yaml | 4 ++-- .../cifar10/configs/resnet18/packed.yaml | 4 ++-- .../cifar10/configs/resnet18/standard.yaml | 4 ++-- .../cifar10/configs/resnet50/batched.yaml | 4 ++-- .../cifar10/configs/resnet50/masked.yaml | 4 ++-- .../cifar10/configs/resnet50/mimo.yaml | 4 ++-- .../cifar10/configs/resnet50/packed.yaml | 4 ++-- .../cifar10/configs/resnet50/standard.yaml | 4 ++-- .../cifar10/configs/wideresnet28x10.yaml | 4 ++-- .../cifar10/configs/wideresnet28x10/batched.yaml | 4 ++-- .../cifar10/configs/wideresnet28x10/masked.yaml | 4 ++-- .../cifar10/configs/wideresnet28x10/mimo.yaml | 4 ++-- .../cifar10/configs/wideresnet28x10/packed.yaml | 4 ++-- .../cifar10/configs/wideresnet28x10/standard.yaml | 4 ++-- .../classification/cifar100/configs/resnet.yaml | 4 ++-- .../cifar100/configs/resnet18/batched.yaml | 4 ++-- .../cifar100/configs/resnet18/masked.yaml | 4 ++-- .../cifar100/configs/resnet18/mimo.yaml | 4 ++-- .../cifar100/configs/resnet18/packed.yaml | 4 ++-- .../cifar100/configs/resnet18/standard.yaml | 4 ++-- .../cifar100/configs/resnet50/batched.yaml | 4 ++-- .../cifar100/configs/resnet50/masked.yaml | 4 ++-- .../cifar100/configs/resnet50/mimo.yaml | 4 ++-- .../cifar100/configs/resnet50/packed.yaml | 4 ++-- .../cifar100/configs/resnet50/standard.yaml | 4 ++-- .../uci_datasets/configs/gaussian_mlp_kin8nm.yaml | 7 +++---- .../uci_datasets/configs/laplace_mlp_kin8nm.yaml | 7 +++---- .../uci_datasets/configs/pw_mlp_kin8nm.yaml | 7 +++---- torch_uncertainty/routines/classification.py | 12 ++++++------ torch_uncertainty/routines/regression.py | 8 ++++---- torch_uncertainty/routines/segmentation.py | 12 ++++++------ 34 files changed, 81 insertions(+), 84 deletions(-) diff --git a/experiments/classification/cifar10/configs/resnet.yaml b/experiments/classification/cifar10/configs/resnet.yaml index 4585471b..aa053391 100644 --- a/experiments/classification/cifar10/configs/resnet.yaml +++ b/experiments/classification/cifar10/configs/resnet.yaml @@ -13,7 +13,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/acc + monitor: cls_val/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -21,7 +21,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/acc + monitor: cls_val/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet18/batched.yaml b/experiments/classification/cifar10/configs/resnet18/batched.yaml index e5534a98..e71130f9 100644 --- a/experiments/classification/cifar10/configs/resnet18/batched.yaml +++ b/experiments/classification/cifar10/configs/resnet18/batched.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/acc + monitor: cls_val/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/acc + monitor: cls_val/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet18/masked.yaml b/experiments/classification/cifar10/configs/resnet18/masked.yaml index 93e32f64..202ba0c4 100644 --- a/experiments/classification/cifar10/configs/resnet18/masked.yaml +++ b/experiments/classification/cifar10/configs/resnet18/masked.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/acc + monitor: cls_val/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/acc + monitor: cls_val/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet18/mimo.yaml b/experiments/classification/cifar10/configs/resnet18/mimo.yaml index 11ba94c6..e45988db 100644 --- a/experiments/classification/cifar10/configs/resnet18/mimo.yaml +++ b/experiments/classification/cifar10/configs/resnet18/mimo.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/acc + monitor: cls_val/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/acc + monitor: cls_val/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet18/packed.yaml b/experiments/classification/cifar10/configs/resnet18/packed.yaml index f0450fda..79bd47f3 100644 --- a/experiments/classification/cifar10/configs/resnet18/packed.yaml +++ b/experiments/classification/cifar10/configs/resnet18/packed.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/acc + monitor: cls_val/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/acc + monitor: cls_val/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet18/standard.yaml b/experiments/classification/cifar10/configs/resnet18/standard.yaml index 06c57899..b5406a28 100644 --- a/experiments/classification/cifar10/configs/resnet18/standard.yaml +++ b/experiments/classification/cifar10/configs/resnet18/standard.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/acc + monitor: cls_val/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/acc + monitor: cls_val/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet50/batched.yaml b/experiments/classification/cifar10/configs/resnet50/batched.yaml index dcc80071..7133cc5f 100644 --- a/experiments/classification/cifar10/configs/resnet50/batched.yaml +++ b/experiments/classification/cifar10/configs/resnet50/batched.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/acc + monitor: cls_val/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/acc + monitor: cls_val/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet50/masked.yaml b/experiments/classification/cifar10/configs/resnet50/masked.yaml index 85d5beec..00eaf9c3 100644 --- a/experiments/classification/cifar10/configs/resnet50/masked.yaml +++ b/experiments/classification/cifar10/configs/resnet50/masked.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/acc + monitor: cls_val/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/acc + monitor: cls_val/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet50/mimo.yaml b/experiments/classification/cifar10/configs/resnet50/mimo.yaml index c5afffc6..d7d23ccd 100644 --- a/experiments/classification/cifar10/configs/resnet50/mimo.yaml +++ b/experiments/classification/cifar10/configs/resnet50/mimo.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/acc + monitor: cls_val/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/acc + monitor: cls_val/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet50/packed.yaml b/experiments/classification/cifar10/configs/resnet50/packed.yaml index da305e2a..2ecc4e6a 100644 --- a/experiments/classification/cifar10/configs/resnet50/packed.yaml +++ b/experiments/classification/cifar10/configs/resnet50/packed.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/acc + monitor: cls_val/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/acc + monitor: cls_val/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet50/standard.yaml b/experiments/classification/cifar10/configs/resnet50/standard.yaml index 37115d8b..1797df73 100644 --- a/experiments/classification/cifar10/configs/resnet50/standard.yaml +++ b/experiments/classification/cifar10/configs/resnet50/standard.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/acc + monitor: cls_val/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/acc + monitor: cls_val/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/wideresnet28x10.yaml b/experiments/classification/cifar10/configs/wideresnet28x10.yaml index 0f2952ab..fb1bea00 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10.yaml @@ -14,7 +14,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/acc + monitor: cls_val/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -22,7 +22,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/acc + monitor: cls_val/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/batched.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/batched.yaml index dddfdd90..f4010902 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10/batched.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10/batched.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/acc + monitor: cls_val/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/acc + monitor: cls_val/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/masked.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/masked.yaml index 204209c2..ae31197b 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10/masked.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10/masked.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/acc + monitor: cls_val/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/acc + monitor: cls_val/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/mimo.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/mimo.yaml index 5746e940..31a09775 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10/mimo.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10/mimo.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/acc + monitor: cls_val/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/acc + monitor: cls_val/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/packed.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/packed.yaml index 834ed7b9..a46c6fac 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10/packed.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10/packed.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/acc + monitor: cls_val/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/acc + monitor: cls_val/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/standard.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/standard.yaml index e555353e..c5cd566f 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10/standard.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10/standard.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/acc + monitor: cls_val/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/acc + monitor: cls_val/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar100/configs/resnet.yaml b/experiments/classification/cifar100/configs/resnet.yaml index dbbe41e9..d72a2c2b 100644 --- a/experiments/classification/cifar100/configs/resnet.yaml +++ b/experiments/classification/cifar100/configs/resnet.yaml @@ -13,7 +13,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/acc + monitor: cls_val/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -21,7 +21,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/acc + monitor: cls_val/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar100/configs/resnet18/batched.yaml b/experiments/classification/cifar100/configs/resnet18/batched.yaml index 847c16ff..61393563 100644 --- a/experiments/classification/cifar100/configs/resnet18/batched.yaml +++ b/experiments/classification/cifar100/configs/resnet18/batched.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/acc + monitor: cls_val/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/acc + monitor: cls_val/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar100/configs/resnet18/masked.yaml b/experiments/classification/cifar100/configs/resnet18/masked.yaml index d0c7a4a7..31f6e2a8 100644 --- a/experiments/classification/cifar100/configs/resnet18/masked.yaml +++ b/experiments/classification/cifar100/configs/resnet18/masked.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/acc + monitor: cls_val/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/acc + monitor: cls_val/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar100/configs/resnet18/mimo.yaml b/experiments/classification/cifar100/configs/resnet18/mimo.yaml index 868e4f3e..7a3aec17 100644 --- a/experiments/classification/cifar100/configs/resnet18/mimo.yaml +++ b/experiments/classification/cifar100/configs/resnet18/mimo.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/acc + monitor: cls_val/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/acc + monitor: cls_val/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar100/configs/resnet18/packed.yaml b/experiments/classification/cifar100/configs/resnet18/packed.yaml index 8e3e8cc1..4e14cce9 100644 --- a/experiments/classification/cifar100/configs/resnet18/packed.yaml +++ b/experiments/classification/cifar100/configs/resnet18/packed.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/acc + monitor: cls_val/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/acc + monitor: cls_val/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar100/configs/resnet18/standard.yaml b/experiments/classification/cifar100/configs/resnet18/standard.yaml index 91fa7f08..f8e9b821 100644 --- a/experiments/classification/cifar100/configs/resnet18/standard.yaml +++ b/experiments/classification/cifar100/configs/resnet18/standard.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/acc + monitor: cls_val/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/acc + monitor: cls_val/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar100/configs/resnet50/batched.yaml b/experiments/classification/cifar100/configs/resnet50/batched.yaml index ce44c862..69259b96 100644 --- a/experiments/classification/cifar100/configs/resnet50/batched.yaml +++ b/experiments/classification/cifar100/configs/resnet50/batched.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/acc + monitor: cls_val/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/acc + monitor: cls_val/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar100/configs/resnet50/masked.yaml b/experiments/classification/cifar100/configs/resnet50/masked.yaml index 656de967..a1707666 100644 --- a/experiments/classification/cifar100/configs/resnet50/masked.yaml +++ b/experiments/classification/cifar100/configs/resnet50/masked.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/acc + monitor: cls_val/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/acc + monitor: cls_val/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar100/configs/resnet50/mimo.yaml b/experiments/classification/cifar100/configs/resnet50/mimo.yaml index 80a3152b..987a632d 100644 --- a/experiments/classification/cifar100/configs/resnet50/mimo.yaml +++ b/experiments/classification/cifar100/configs/resnet50/mimo.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/acc + monitor: cls_val/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/acc + monitor: cls_val/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar100/configs/resnet50/packed.yaml b/experiments/classification/cifar100/configs/resnet50/packed.yaml index 5fd414dc..954caf11 100644 --- a/experiments/classification/cifar100/configs/resnet50/packed.yaml +++ b/experiments/classification/cifar100/configs/resnet50/packed.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/acc + monitor: cls_val/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/acc + monitor: cls_val/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar100/configs/resnet50/standard.yaml b/experiments/classification/cifar100/configs/resnet50/standard.yaml index b78b2476..575b6e6f 100644 --- a/experiments/classification/cifar100/configs/resnet50/standard.yaml +++ b/experiments/classification/cifar100/configs/resnet50/standard.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/acc + monitor: cls_val/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/acc + monitor: cls_val/Acc patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/gaussian_mlp_kin8nm.yaml b/experiments/regression/uci_datasets/configs/gaussian_mlp_kin8nm.yaml index 225eb6ee..2e9b056d 100644 --- a/experiments/regression/uci_datasets/configs/gaussian_mlp_kin8nm.yaml +++ b/experiments/regression/uci_datasets/configs/gaussian_mlp_kin8nm.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: reg_val/nll + monitor: reg_val/NLL mode: min save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: reg_val/nll + monitor: reg_val/NLL patience: 1000 check_finite: true model: @@ -31,8 +31,7 @@ model: in_features: 8 hidden_dims: - 100 - loss: - class_path: torch_uncertainty.losses.DistributionNLLLoss + loss: torch_uncertainty.losses.DistributionNLLLoss version: std distribution: normal data: diff --git a/experiments/regression/uci_datasets/configs/laplace_mlp_kin8nm.yaml b/experiments/regression/uci_datasets/configs/laplace_mlp_kin8nm.yaml index 4f4d5345..d95e09a1 100644 --- a/experiments/regression/uci_datasets/configs/laplace_mlp_kin8nm.yaml +++ b/experiments/regression/uci_datasets/configs/laplace_mlp_kin8nm.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: reg_val/nll + monitor: reg_val/NLL mode: min save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: reg_val/nll + monitor: reg_val/NLL patience: 1000 check_finite: true model: @@ -31,8 +31,7 @@ model: in_features: 8 hidden_dims: - 100 - loss: - class_path: torch_uncertainty.losses.DistributionNLLLoss + loss: torch_uncertainty.losses.DistributionNLLLoss version: std distribution: laplace data: diff --git a/experiments/regression/uci_datasets/configs/pw_mlp_kin8nm.yaml b/experiments/regression/uci_datasets/configs/pw_mlp_kin8nm.yaml index 1cb32c36..b6ce9fad 100644 --- a/experiments/regression/uci_datasets/configs/pw_mlp_kin8nm.yaml +++ b/experiments/regression/uci_datasets/configs/pw_mlp_kin8nm.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: reg_val/mse + monitor: reg_val/MSE mode: min save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: reg_val/mse + monitor: reg_val/MSE patience: 1000 check_finite: true model: @@ -31,8 +31,7 @@ model: in_features: 8 hidden_dims: - 100 - loss: - class_path: torch.nn.MSELoss + loss: MSELoss version: std data: root: ./data diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 5c15c2de..29019873 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -144,9 +144,9 @@ def __init__( if self.binary_cls: cls_metrics = MetricCollection( { - "acc": Accuracy(task="binary"), + "Acc": Accuracy(task="binary"), "ECE": CE(task="binary"), - "brier": BrierScore(num_classes=1), + "Brier": BrierScore(num_classes=1), }, compute_groups=False, ) @@ -154,11 +154,11 @@ def __init__( cls_metrics = MetricCollection( { "NLL": CategoricalNLL(), - "acc": Accuracy( + "Acc": Accuracy( task="multiclass", num_classes=self.num_classes ), "ECE": CE(task="multiclass", num_classes=self.num_classes), - "brier": BrierScore(num_classes=self.num_classes), + "Brier": BrierScore(num_classes=self.num_classes), }, compute_groups=False, ) @@ -215,9 +215,9 @@ def __init__( if self.num_estimators > 1: ens_metrics = MetricCollection( { - "disagreement": Disagreement(), + "Disagreement": Disagreement(), "MI": MutualInformation(), - "entropy": Entropy(), + "Entropy": Entropy(), } ) diff --git a/torch_uncertainty/routines/regression.py b/torch_uncertainty/routines/regression.py index 24cbb0a4..66367ab4 100644 --- a/torch_uncertainty/routines/regression.py +++ b/torch_uncertainty/routines/regression.py @@ -72,9 +72,9 @@ def __init__( reg_metrics = MetricCollection( { - "mae": MeanAbsoluteError(), - "mse": MeanSquaredError(squared=True), - "rmse": MeanSquaredError(squared=False), + "MAE": MeanAbsoluteError(), + "MSE": MeanSquaredError(squared=True), + "RMSE": MeanSquaredError(squared=False), }, compute_groups=True, ) @@ -84,7 +84,7 @@ def __init__( if self.probabilistic: reg_prob_metrics = MetricCollection( - DistributionNLL(reduction="mean") + {"NLL": DistributionNLL(reduction="mean")} ) self.val_prob_metrics = reg_prob_metrics.clone(prefix="reg_val/") self.test_prob_metrics = reg_prob_metrics.clone(prefix="reg_test/") diff --git a/torch_uncertainty/routines/segmentation.py b/torch_uncertainty/routines/segmentation.py index eb99e1e8..a3227dcf 100644 --- a/torch_uncertainty/routines/segmentation.py +++ b/torch_uncertainty/routines/segmentation.py @@ -64,13 +64,13 @@ def __init__( # metrics seg_metrics = MetricCollection( { - "acc": Accuracy(task="multiclass", num_classes=num_classes), - "ece": CE(task="multiclass", num_classes=num_classes), - "mean_iou": MeanIntersectionOverUnion(num_classes=num_classes), - "brier": BrierScore(num_classes=num_classes), - "nll": CategoricalNLL(), + "Acc": Accuracy(task="multiclass", num_classes=num_classes), + "ECE": CE(task="multiclass", num_classes=num_classes), + "mIoU": MeanIntersectionOverUnion(num_classes=num_classes), + "Brier": BrierScore(num_classes=num_classes), + "NLL": CategoricalNLL(), }, - compute_groups=[["acc", "mean_iou"], ["ece"], ["brier"], ["nll"]], + compute_groups=[["Acc", "mIoU"], ["ECE"], ["Brier"], ["NLL"]], ) self.val_seg_metrics = seg_metrics.clone(prefix="seg_val/") From a1a93190e961df3809e89d7eda4bffa9e7feaba0 Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 27 Mar 2024 11:05:15 +0100 Subject: [PATCH 145/148] :sparkles: Add lmbda parameter to SILog --- torch_uncertainty/metrics/regression/log10.py | 2 -- torch_uncertainty/metrics/regression/silog.py | 17 ++++++++++++----- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/torch_uncertainty/metrics/regression/log10.py b/torch_uncertainty/metrics/regression/log10.py index 1c83cd39..acd9a0e1 100644 --- a/torch_uncertainty/metrics/regression/log10.py +++ b/torch_uncertainty/metrics/regression/log10.py @@ -16,8 +16,6 @@ def __init__(self, **kwargs) -> None: - :attr:`preds`: :math:`(N)` - :attr:`target`: :math:`(N)` - where :math:`N` is the batch size. - Args: kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. """ diff --git a/torch_uncertainty/metrics/regression/silog.py b/torch_uncertainty/metrics/regression/silog.py index c6bea3b7..370b7036 100644 --- a/torch_uncertainty/metrics/regression/silog.py +++ b/torch_uncertainty/metrics/regression/silog.py @@ -7,7 +7,7 @@ class SILog(Metric): - def __init__(self, **kwargs: Any) -> None: + def __init__(self, lmbda: float = 1, **kwargs: Any) -> None: r"""The Scale-Invariant Logarithmic Loss metric. .. math:: \text{SILog} = \frac{1}{N} \sum_{i=1}^{N} \left(\log(y_i) - \log(\hat{y_i})\right)^2 - \left(\frac{1}{N} \sum_{i=1}^{N} \log(y_i) \right)^2 @@ -18,12 +18,18 @@ def __init__(self, **kwargs: Any) -> None: - :attr:`pred`: :math:`(N)` - :attr:`target`: :math:`(N)` - where :math:`N` is the batch size. + Args: + lmbda: The regularization parameter on the variance of error (default 1). + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Reference: Depth Map Prediction from a Single Image using a Multi-Scale Deep Network. + David Eigen, Christian Puhrsch, Rob Fergus. NeurIPS 2014. + From Big to Small: Multi-Scale Local Planar Guidance for Monocular Depth Estimation. + Jin Han Lee, Myung-Kyu Han, Dong Wook Ko and Il Hong Suh. For the lambda parameter. """ super().__init__(**kwargs) + self.lmbda = lmbda self.add_state("log_dists", default=[], dist_reduce_fx="cat") def update(self, pred: Tensor, target: Tensor) -> None: @@ -33,6 +39,7 @@ def update(self, pred: Tensor, target: Tensor) -> None: def compute(self) -> Tensor: """Compute the Scale-Invariant Logarithmic Loss.""" log_dists = dim_zero_cat(self.log_dists) - return torch.mean(log_dists**2) - torch.sum(log_dists) ** 2 / ( - log_dists.size(0) * log_dists.size(0) - ) + num_samples = log_dists.size(0) + return torch.mean(log_dists**2) - self.lmbda * torch.sum( + log_dists + ) ** 2 / (num_samples * num_samples) From 23c3bdd43305a5bf594fdd5c9d837abf02c4c39c Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 27 Mar 2024 11:15:18 +0100 Subject: [PATCH 146/148] :sparkles: Add dist_rearrange --- torch_uncertainty/routines/regression.py | 12 +++++++++--- torch_uncertainty/utils/distributions.py | 24 ++++++++---------------- 2 files changed, 17 insertions(+), 19 deletions(-) diff --git a/torch_uncertainty/routines/regression.py b/torch_uncertainty/routines/regression.py index 66367ab4..3124856d 100644 --- a/torch_uncertainty/routines/regression.py +++ b/torch_uncertainty/routines/regression.py @@ -12,7 +12,7 @@ from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection from torch_uncertainty.metrics.regression.nll import DistributionNLL -from torch_uncertainty.utils.distributions import squeeze_dist, to_ensemble_dist +from torch_uncertainty.utils.distributions import dist_rearrange, squeeze_dist class RegressionRoutine(LightningModule): @@ -155,7 +155,10 @@ def validation_step( if self.probabilistic: ens_dist = Independent( - to_ensemble_dist(preds, num_estimators=self.num_estimators), 1 + dist_rearrange( + preds, "(m b) c -> b m c", m=self.num_estimators + ), + 1, ) mix = Categorical( torch.ones(self.num_estimators, device=self.device) @@ -198,7 +201,10 @@ def test_step( if self.probabilistic: ens_dist = Independent( - to_ensemble_dist(preds, num_estimators=self.num_estimators), 1 + dist_rearrange( + preds, "(m b) c -> b m c", m=self.num_estimators + ), + 1, ) mix = Categorical( torch.ones(self.num_estimators, device=self.device) diff --git a/torch_uncertainty/utils/distributions.py b/torch_uncertainty/utils/distributions.py index 5cbdb5a5..ab2336bc 100644 --- a/torch_uncertainty/utils/distributions.py +++ b/torch_uncertainty/utils/distributions.py @@ -78,27 +78,19 @@ def squeeze_dist(distribution: Distribution, dim: int) -> Distribution: ) -def to_ensemble_dist( - distribution: Distribution, num_estimators: int = 1 +def dist_rearrange( + distribution: Distribution, pattern: str, **axes_lengths: int ) -> Distribution: dist_type = type(distribution) if isinstance(distribution, Normal | Laplace): - loc = rearrange(distribution.loc, "(n b) c -> b n c", n=num_estimators) - scale = rearrange( - distribution.scale, "(n b) c -> b n c", n=num_estimators - ) + loc = rearrange(distribution.loc, pattern=pattern, **axes_lengths) + scale = rearrange(distribution.scale, pattern=pattern, **axes_lengths) return dist_type(loc=loc, scale=scale) if isinstance(distribution, NormalInverseGamma): - loc = rearrange(distribution.loc, "(n b) c -> b n c", n=num_estimators) - lmbda = rearrange( - distribution.lmbda, "(n b) c -> b n c", n=num_estimators - ) - alpha = rearrange( - distribution.alpha, "(n b) c -> b n c", n=num_estimators - ) - beta = rearrange( - distribution.beta, "(n b) c -> b n c", n=num_estimators - ) + loc = rearrange(distribution.loc, pattern=pattern, **axes_lengths) + lmbda = rearrange(distribution.lmbda, pattern=pattern, **axes_lengths) + alpha = rearrange(distribution.alpha, pattern=pattern, **axes_lengths) + beta = rearrange(distribution.beta, pattern=pattern, **axes_lengths) return dist_type(loc=loc, lmbda=lmbda, alpha=alpha, beta=beta) raise NotImplementedError( f"Ensemble distribution of {dist_type} is not supported." From 6c769de1d63cd53c110db21d60345f42c285d48d Mon Sep 17 00:00:00 2001 From: alafage Date: Wed, 27 Mar 2024 12:44:04 +0100 Subject: [PATCH 147/148] :fire: Remove useless datamodule.setup("test") in test files --- tests/datamodules/classification/test_cifar100_datamodule.py | 1 - tests/datamodules/classification/test_imagenet_datamodule.py | 1 - .../datamodules/classification/test_tiny_imagenet_datamodule.py | 1 - .../datamodules/classification/test_uci_regression_datamodule.py | 1 - 4 files changed, 4 deletions(-) diff --git a/tests/datamodules/classification/test_cifar100_datamodule.py b/tests/datamodules/classification/test_cifar100_datamodule.py index 487e7214..e24af243 100644 --- a/tests/datamodules/classification/test_cifar100_datamodule.py +++ b/tests/datamodules/classification/test_cifar100_datamodule.py @@ -20,7 +20,6 @@ def test_cifar100(self): dm.prepare_data() dm.setup() - dm.setup("test") dm.train_dataloader() dm.val_dataloader() diff --git a/tests/datamodules/classification/test_imagenet_datamodule.py b/tests/datamodules/classification/test_imagenet_datamodule.py index 80b73aff..9088a701 100644 --- a/tests/datamodules/classification/test_imagenet_datamodule.py +++ b/tests/datamodules/classification/test_imagenet_datamodule.py @@ -17,7 +17,6 @@ def test_imagenet(self): dm.ood_dataset = DummyClassificationDataset dm.prepare_data() dm.setup() - dm.setup("test") path = ( Path(__file__).parent.resolve() / "../../assets/dummy_indices.yaml" diff --git a/tests/datamodules/classification/test_tiny_imagenet_datamodule.py b/tests/datamodules/classification/test_tiny_imagenet_datamodule.py index 5885fdb3..007b5f4d 100644 --- a/tests/datamodules/classification/test_tiny_imagenet_datamodule.py +++ b/tests/datamodules/classification/test_tiny_imagenet_datamodule.py @@ -34,7 +34,6 @@ def test_tiny_imagenet(self): dm.prepare_data() dm.setup() - dm.setup("test") dm.train_dataloader() dm.val_dataloader() diff --git a/tests/datamodules/classification/test_uci_regression_datamodule.py b/tests/datamodules/classification/test_uci_regression_datamodule.py index 9c8155fa..1297666c 100644 --- a/tests/datamodules/classification/test_uci_regression_datamodule.py +++ b/tests/datamodules/classification/test_uci_regression_datamodule.py @@ -16,7 +16,6 @@ def test_uci_regression(self): dm.prepare_data() dm.val_split = 0.5 dm.setup() - dm.setup("test") dm.train_dataloader() dm.val_dataloader() From 06b37c399b136fac03aa9f072037ae5d668b5571 Mon Sep 17 00:00:00 2001 From: alafage Date: Wed, 27 Mar 2024 13:11:54 +0100 Subject: [PATCH 148/148] :bug: Fix typo in test file --- tests/test_losses.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_losses.py b/tests/test_losses.py index d88ba2a1..f368e6cc 100644 --- a/tests/test_losses.py +++ b/tests/test_losses.py @@ -58,7 +58,7 @@ def test_failures(self): ELBOLoss(model, criterion, kl_weight=1e-5, num_samples=1.5) -class TestNIGLoss: +class TestDERLoss: """Testing the DERLoss class.""" def test_main(self):