diff --git a/pyproject.toml b/pyproject.toml index 38708ec5..bb52c317 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,18 +25,16 @@ keywords = [ ] classifiers = [ "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", - "Programming Language :: Python :: 3 :: Only", ] dependencies = [ "timm", "lightning[pytorch-extra]>=2.0", "torchvision>=0.16", "einops", - "matplotlib", - "rich>=10.2.2", "seaborn", ] @@ -82,7 +80,7 @@ repository = "https://github.com/ENSTA-U2IS-AI/torch-uncertainty.git" name = "torch_uncertainty" [tool.ruff] -line-length = 80 +line-length = 100 target-version = "py310" lint.extend-select = [ "A", diff --git a/tests/_dummies/dataset.py b/tests/_dummies/dataset.py index 9f39ce05..0040bf50 100644 --- a/tests/_dummies/dataset.py +++ b/tests/_dummies/dataset.py @@ -65,9 +65,7 @@ def __init__( else: self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC - self.targets = torch.randint( - low=0, high=num_classes, size=(num_images,) - ) + self.targets = torch.randint(low=0, high=num_classes, size=(num_images,)) self.targets = torch.arange(start=0, end=num_classes).repeat( num_images // (num_classes) + 1 )[:num_images] @@ -123,10 +121,8 @@ def __init__( self.targets = [] input_shape = (num_samples, in_features) - if out_features != 1: - output_shape = (num_samples, out_features) - else: - output_shape = (num_samples,) + + output_shape = (num_samples, out_features) if out_features != 1 else (num_samples,) self.data = torch.rand( size=input_shape, diff --git a/tests/_dummies/model.py b/tests/_dummies/model.py index c6c2d64d..7ec2581b 100644 --- a/tests/_dummies/model.py +++ b/tests/_dummies/model.py @@ -60,9 +60,7 @@ def __init__( self.in_channels = in_channels self.num_classes = num_classes self.image_size = image_size - self.conv = nn.Conv2d( - in_channels, num_classes, kernel_size=3, padding=1 - ) + self.conv = nn.Conv2d(in_channels, num_classes, kernel_size=3, padding=1) self.dropout = nn.Dropout(p=dropout_rate) self.last_layer = last_layer diff --git a/tests/baselines/test_deep_ensembles.py b/tests/baselines/test_deep_ensembles.py index cd8642cf..8d6126d6 100644 --- a/tests/baselines/test_deep_ensembles.py +++ b/tests/baselines/test_deep_ensembles.py @@ -9,9 +9,7 @@ class TestDeepEnsembles: """Testing the Deep Ensembles baseline class.""" def test_failure(self): - with pytest.raises( - ValueError, match="Models must not be an empty list." - ): + with pytest.raises(ValueError, match="Models must not be an empty list."): DeepEnsemblesBaseline( log_path=".", checkpoint_ids=[], diff --git a/tests/datamodules/classification/test_cifar10.py b/tests/datamodules/classification/test_cifar10.py index 45c672c0..51e6a0d6 100644 --- a/tests/datamodules/classification/test_cifar10.py +++ b/tests/datamodules/classification/test_cifar10.py @@ -72,9 +72,7 @@ def test_cifar10_main(self): auto_augment="rand-m9-n2-mstd0.5", ) - with pytest.raises( - ValueError, match="CIFAR-H can only be used in testing." - ): + with pytest.raises(ValueError, match="CIFAR-H can only be used in testing."): dm = CIFAR10DataModule( root="./data/", batch_size=128, @@ -100,25 +98,21 @@ def test_cifar10_main(self): def test_cifar10_cv(self): dm = CIFAR10DataModule(root="./data/", batch_size=128) - dm.dataset = ( - lambda root, train, download, transform: DummyClassificationDataset( - root, - train=train, - download=download, - transform=transform, - num_images=20, - ) + dm.dataset = lambda root, train, download, transform: DummyClassificationDataset( + root, + train=train, + download=download, + transform=transform, + num_images=20, ) dm.make_cross_val_splits(2, 1) dm = CIFAR10DataModule(root="./data/", batch_size=128, val_split=0.1) - dm.dataset = ( - lambda root, train, download, transform: DummyClassificationDataset( - root, - train=train, - download=download, - transform=transform, - num_images=20, - ) + dm.dataset = lambda root, train, download, transform: DummyClassificationDataset( + root, + train=train, + download=download, + transform=transform, + num_images=20, ) dm.make_cross_val_splits(2, 1) diff --git a/tests/datamodules/classification/test_cifar100.py b/tests/datamodules/classification/test_cifar100.py index bdb32a56..47394cfd 100644 --- a/tests/datamodules/classification/test_cifar100.py +++ b/tests/datamodules/classification/test_cifar100.py @@ -58,35 +58,27 @@ def test_cifar100(self): randaugment=True, ) - dm = CIFAR100DataModule( - root="./data/", batch_size=128, randaugment=True - ) + dm = CIFAR100DataModule(root="./data/", batch_size=128, randaugment=True) - dm = CIFAR100DataModule( - root="./data/", batch_size=128, auto_augment="rand-m9-n2-mstd0.5" - ) + dm = CIFAR100DataModule(root="./data/", batch_size=128, auto_augment="rand-m9-n2-mstd0.5") def test_cifar100_cv(self): dm = CIFAR100DataModule(root="./data/", batch_size=128) - dm.dataset = ( - lambda root, train, download, transform: DummyClassificationDataset( - root, - train=train, - download=download, - transform=transform, - num_images=20, - ) + dm.dataset = lambda root, train, download, transform: DummyClassificationDataset( + root, + train=train, + download=download, + transform=transform, + num_images=20, ) dm.make_cross_val_splits(2, 1) dm = CIFAR100DataModule(root="./data/", batch_size=128, val_split=0.1) - dm.dataset = ( - lambda root, train, download, transform: DummyClassificationDataset( - root, - train=train, - download=download, - transform=transform, - num_images=20, - ) + dm.dataset = lambda root, train, download, transform: DummyClassificationDataset( + root, + train=train, + download=download, + transform=transform, + num_images=20, ) dm.make_cross_val_splits(2, 1) diff --git a/tests/datamodules/classification/test_imagenet.py b/tests/datamodules/classification/test_imagenet.py index 0d514d82..dd949cce 100644 --- a/tests/datamodules/classification/test_imagenet.py +++ b/tests/datamodules/classification/test_imagenet.py @@ -18,9 +18,7 @@ def test_imagenet(self): dm.prepare_data() dm.setup() - path = ( - Path(__file__).parent.resolve() / "../../assets/dummy_indices.yaml" - ) + path = Path(__file__).parent.resolve() / "../../assets/dummy_indices.yaml" dm = ImageNetDataModule(root="./data/", batch_size=128, val_split=path) dm.dataset = DummyClassificationDataset dm.ood_dataset = DummyClassificationDataset @@ -55,22 +53,16 @@ def test_imagenet(self): dm.setup("other") for test_alt in ["r", "o", "a"]: - dm = ImageNetDataModule( - root="./data/", batch_size=128, test_alt=test_alt - ) + dm = ImageNetDataModule(root="./data/", batch_size=128, test_alt=test_alt) with pytest.raises(ValueError): dm.setup() with pytest.raises(ValueError): - dm = ImageNetDataModule( - root="./data/", batch_size=128, test_alt="x" - ) + dm = ImageNetDataModule(root="./data/", batch_size=128, test_alt="x") for ood_ds in ["inaturalist", "imagenet-o", "textures", "openimage-o"]: - dm = ImageNetDataModule( - root="./data/", batch_size=128, ood_ds=ood_ds - ) + dm = ImageNetDataModule(root="./data/", batch_size=128, ood_ds=ood_ds) if ood_ds == "inaturalist": dm.eval_ood = True dm.dataset = DummyClassificationDataset @@ -80,9 +72,7 @@ def test_imagenet(self): dm.test_dataloader() with pytest.raises(ValueError): - dm = ImageNetDataModule( - root="./data/", batch_size=128, ood_ds="other" - ) + dm = ImageNetDataModule(root="./data/", batch_size=128, ood_ds="other") for procedure in ["ViT", "A3"]: dm = ImageNetDataModule( @@ -93,9 +83,7 @@ def test_imagenet(self): ) with pytest.raises(ValueError): - dm = ImageNetDataModule( - root="./data/", batch_size=128, procedure="A2" - ) + dm = ImageNetDataModule(root="./data/", batch_size=128, procedure="A2") with pytest.raises(FileNotFoundError): dm._verify_splits(split="test") diff --git a/tests/datamodules/classification/test_tiny_imagenet.py b/tests/datamodules/classification/test_tiny_imagenet.py index 6f70313f..8826d849 100644 --- a/tests/datamodules/classification/test_tiny_imagenet.py +++ b/tests/datamodules/classification/test_tiny_imagenet.py @@ -28,9 +28,7 @@ def test_tiny_imagenet(self): ) with pytest.raises(ValueError): - TinyImageNetDataModule( - root="./data/", batch_size=128, ood_ds="other" - ) + TinyImageNetDataModule(root="./data/", batch_size=128, ood_ds="other") dm.dataset = DummyClassificationDataset dm.ood_dataset = DummyClassificationDataset @@ -52,9 +50,7 @@ def test_tiny_imagenet(self): dm.setup("test") dm.test_dataloader() - dm = TinyImageNetDataModule( - root="./data/", batch_size=128, ood_ds="svhn" - ) + dm = TinyImageNetDataModule(root="./data/", batch_size=128, ood_ds="svhn") dm.dataset = DummyClassificationDataset dm.ood_dataset = DummyClassificationDataset dm.shift_dataset = DummyClassificationDataset @@ -70,9 +66,7 @@ def test_tiny_imagenet_cv(self): ) dm.make_cross_val_splits(2, 1) - dm = TinyImageNetDataModule( - root="./data/", batch_size=128, val_split=0.1 - ) + dm = TinyImageNetDataModule(root="./data/", batch_size=128, val_split=0.1) dm.dataset = lambda root, split, transform: DummyClassificationDataset( root, split=split, transform=transform, num_images=20 ) diff --git a/tests/datamodules/segmentation/test_camvid.py b/tests/datamodules/segmentation/test_camvid.py index eaccb088..f9017228 100644 --- a/tests/datamodules/segmentation/test_camvid.py +++ b/tests/datamodules/segmentation/test_camvid.py @@ -9,12 +9,8 @@ class TestCamVidDataModule: """Testing the CamVidDataModule datamodule.""" def test_camvid_main(self): - dm = CamVidDataModule( - root="./data/", batch_size=128, group_classes=False - ) - dm = CamVidDataModule( - root="./data/", batch_size=128, basic_augment=False - ) + dm = CamVidDataModule(root="./data/", batch_size=128, group_classes=False) + dm = CamVidDataModule(root="./data/", batch_size=128, basic_augment=False) assert dm.dataset == CamVid diff --git a/tests/datamodules/segmentation/test_cityscapes.py b/tests/datamodules/segmentation/test_cityscapes.py index 46781c8c..4cb46709 100644 --- a/tests/datamodules/segmentation/test_cityscapes.py +++ b/tests/datamodules/segmentation/test_cityscapes.py @@ -10,9 +10,7 @@ class TestCityscapesDataModule: def test_camvid_main(self): dm = CityscapesDataModule(root="./data/", batch_size=128) - dm = CityscapesDataModule( - root="./data/", batch_size=128, basic_augment=False - ) + dm = CityscapesDataModule(root="./data/", batch_size=128, basic_augment=False) assert dm.dataset == Cityscapes diff --git a/tests/datamodules/test_abstract_datamodule.py b/tests/datamodules/test_abstract_datamodule.py index 0f0bd64f..84a87293 100644 --- a/tests/datamodules/test_abstract_datamodule.py +++ b/tests/datamodules/test_abstract_datamodule.py @@ -31,9 +31,7 @@ def test_cv_main(self): dm.train = ds dm.val = ds dm.test = ds - cv_dm = CrossValDataModule( - "root", [0], [1], dm, 128, 0.0, 4, True, True - ) + cv_dm = CrossValDataModule("root", [0], [1], dm, 128, 0.0, 4, True, True) cv_dm.setup() cv_dm.setup("test") @@ -54,9 +52,7 @@ def test_errors(self): dm.train = ds dm.val = ds dm.test = ds - cv_dm = CrossValDataModule( - "root", [0], [1], dm, 128, 0.0, 4, True, True - ) + cv_dm = CrossValDataModule("root", [0], [1], dm, 128, 0.0, 4, True, True) with pytest.raises(NotImplementedError): cv_dm.setup() cv_dm._get_train_data() diff --git a/tests/datamodules/test_depth.py b/tests/datamodules/test_depth.py index 4305be05..733e2adf 100644 --- a/tests/datamodules/test_depth.py +++ b/tests/datamodules/test_depth.py @@ -13,9 +13,7 @@ class TestMUADDataModule: """Testing the MUADDataModule datamodule.""" def test_muad_main(self): - dm = MUADDataModule( - root="./data/", min_depth=0, max_depth=100, batch_size=128 - ) + dm = MUADDataModule(root="./data/", min_depth=0, max_depth=100, batch_size=128) assert dm.dataset == MUAD diff --git a/tests/datamodules/test_uci_regression.py b/tests/datamodules/test_uci_regression.py index aeda20ea..f6fe36b5 100644 --- a/tests/datamodules/test_uci_regression.py +++ b/tests/datamodules/test_uci_regression.py @@ -8,9 +8,7 @@ class TestUCIRegressionDataModule: """Testing the UCIRegressionDataModule datamodule class.""" def test_uci_regression(self): - dm = UCIRegressionDataModule( - dataset_name="kin8nm", root="./data/", batch_size=128 - ) + dm = UCIRegressionDataModule(dataset_name="kin8nm", root="./data/", batch_size=128) dm.dataset = partial(DummyRegressionDataset, num_samples=64) dm.prepare_data() diff --git a/tests/layers/test_bayesian.py b/tests/layers/test_bayesian.py index 420da049..66b09523 100644 --- a/tests/layers/test_bayesian.py +++ b/tests/layers/test_bayesian.py @@ -80,9 +80,7 @@ def test_conv1(self, feat_input_odd: torch.Tensor) -> None: assert out.shape == torch.Size([2, 10]) def test_conv1_even(self, feat_input_even: torch.Tensor) -> None: - layer = BayesConv1d( - 8, 2, kernel_size=1, sigma_init=0, padding_mode="reflect" - ) + layer = BayesConv1d(8, 2, kernel_size=1, sigma_init=0, padding_mode="reflect") print(layer) out = layer(feat_input_even) assert out.shape == torch.Size([2, 10]) @@ -94,9 +92,7 @@ def test_conv1_even(self, feat_input_even: torch.Tensor) -> None: def test_error(self): with pytest.raises(ValueError): - BayesConv1d( - 8, 2, kernel_size=1, sigma_init=0, padding_mode="random" - ) + BayesConv1d(8, 2, kernel_size=1, sigma_init=0, padding_mode="random") class TestBayesConv2d: @@ -115,9 +111,7 @@ def test_conv2(self, img_input_odd: torch.Tensor) -> None: layer.sample() def test_conv2_even(self, img_input_even: torch.Tensor) -> None: - layer = BayesConv2d( - 10, 2, kernel_size=1, sigma_init=0, padding_mode="reflect" - ) + layer = BayesConv2d(10, 2, kernel_size=1, sigma_init=0, padding_mode="reflect") print(layer) out = layer(img_input_even) assert out.shape == torch.Size([8, 2, 3, 3]) @@ -141,9 +135,7 @@ def test_conv3(self, cube_input_odd: torch.Tensor) -> None: assert out.shape == torch.Size([1, 2, 3, 3, 3]) def test_conv3_even(self, cube_input_even: torch.Tensor) -> None: - layer = BayesConv3d( - 10, 2, kernel_size=1, sigma_init=0, padding_mode="reflect" - ) + layer = BayesConv3d(10, 2, kernel_size=1, sigma_init=0, padding_mode="reflect") print(layer) out = layer(cube_input_even) assert out.shape == torch.Size([2, 2, 3, 3, 3]) @@ -192,17 +184,13 @@ def test_conv2(self, img_input_odd: torch.Tensor) -> None: out = layer(img_input_odd.repeat(4, 1, 1, 1)) assert out.shape == torch.Size([5 * 4, 2, 3, 3]) - layer = LPBNNConv2d( - 10, 2, kernel_size=1, num_estimators=4, bias=False, gamma=False - ) + layer = LPBNNConv2d(10, 2, kernel_size=1, num_estimators=4, bias=False, gamma=False) layer = layer.eval() out = layer(img_input_odd.repeat(4, 1, 1, 1)) assert out.shape == torch.Size([5 * 4, 2, 3, 3]) def test_conv2_even(self, img_input_even: torch.Tensor) -> None: - layer = LPBNNConv2d( - 10, 2, kernel_size=1, num_estimators=4, padding_mode="reflect" - ) + layer = LPBNNConv2d(10, 2, kernel_size=1, num_estimators=4, padding_mode="reflect") print(layer) out = layer(img_input_even.repeat(4, 1, 1, 1)) assert out.shape == torch.Size([8 * 4, 2, 3, 3]) diff --git a/tests/layers/test_packed.py b/tests/layers/test_packed.py index 11989804..7cc7fd1d 100644 --- a/tests/layers/test_packed.py +++ b/tests/layers/test_packed.py @@ -44,9 +44,7 @@ class TestPackedLinear: # Legacy tests def test_linear_one_estimator_no_rearrange(self, feat_input: torch.Tensor): - layer = PackedLinear( - 6, 2, alpha=1, num_estimators=1, rearrange=False, bias=False - ) + layer = PackedLinear(6, 2, alpha=1, num_estimators=1, rearrange=False, bias=False) out = layer(feat_input) assert out.shape == torch.Size([2, 1]) @@ -55,9 +53,7 @@ def test_linear_two_estimators_no_rearrange(self, feat_input: torch.Tensor): out = layer(feat_input) assert out.shape == torch.Size([2, 1]) - def test_linear_one_estimator_rearrange( - self, feat_input_one_rearrange: torch.Tensor - ): + def test_linear_one_estimator_rearrange(self, feat_input_one_rearrange: torch.Tensor): layer = PackedLinear(5, 2, alpha=1, num_estimators=1, rearrange=True) out = layer(feat_input_one_rearrange) assert out.shape == torch.Size([3, 2]) @@ -68,52 +64,32 @@ def test_linear_two_estimator_rearrange_not_divisible(self): out = layer(feat) assert out.shape == torch.Size([6, 1]) - def test_linear_full_implementation( - self, feat_input_16_features: torch.Tensor - ): - layer = PackedLinear( - 16, 4, alpha=1, num_estimators=1, implementation="full" - ) + def test_linear_full_implementation(self, feat_input_16_features: torch.Tensor): + layer = PackedLinear(16, 4, alpha=1, num_estimators=1, implementation="full") out = layer(feat_input_16_features) assert out.shape == torch.Size([2, 4]) - layer = PackedLinear( - 16, 4, alpha=1, num_estimators=2, implementation="full" - ) + layer = PackedLinear(16, 4, alpha=1, num_estimators=2, implementation="full") out = layer(feat_input_16_features) assert out.shape == torch.Size([2, 4]) - def test_linear_sparse_implementation( - self, feat_input_16_features: torch.Tensor - ): - layer = PackedLinear( - 16, 4, alpha=1, num_estimators=1, implementation="sparse" - ) + def test_linear_sparse_implementation(self, feat_input_16_features: torch.Tensor): + layer = PackedLinear(16, 4, alpha=1, num_estimators=1, implementation="sparse") out = layer(feat_input_16_features) assert out.shape == torch.Size([2, 4]) - layer = PackedLinear( - 16, 4, alpha=1, num_estimators=2, implementation="sparse" - ) + layer = PackedLinear(16, 4, alpha=1, num_estimators=2, implementation="sparse") out = layer(feat_input_16_features) assert out.shape == torch.Size([2, 4]) - def test_linear_einsum_implementation( - self, feat_input_16_features: torch.Tensor - ): - layer = PackedLinear( - 16, 4, alpha=1, num_estimators=1, implementation="einsum" - ) + def test_linear_einsum_implementation(self, feat_input_16_features: torch.Tensor): + layer = PackedLinear(16, 4, alpha=1, num_estimators=1, implementation="einsum") out = layer(feat_input_16_features) assert out.shape == torch.Size([2, 4]) - layer = PackedLinear( - 16, 4, alpha=1, num_estimators=2, implementation="einsum" - ) + layer = PackedLinear(16, 4, alpha=1, num_estimators=2, implementation="einsum") out = layer(feat_input_16_features) assert out.shape == torch.Size([2, 4]) def test_linear_extend(self): - _ = PackedConv2d( - 5, 3, kernel_size=1, alpha=1, num_estimators=2, gamma=1 - ) + _ = PackedConv2d(5, 3, kernel_size=1, alpha=1, num_estimators=2, gamma=1) def test_linear_failures(self): with pytest.raises(ValueError): @@ -132,14 +108,10 @@ def test_linear_failures(self): _ = PackedLinear(5, 2, alpha=1, num_estimators=-1, rearrange=True) with pytest.raises(TypeError): - _ = PackedLinear( - 5, 2, alpha=1, num_estimators=1, gamma=0.5, rearrange=True - ) + _ = PackedLinear(5, 2, alpha=1, num_estimators=1, gamma=0.5, rearrange=True) with pytest.raises(ValueError): - _ = PackedLinear( - 5, 2, alpha=1, num_estimators=1, gamma=-1, rearrange=True - ) + _ = PackedLinear(5, 2, alpha=1, num_estimators=1, gamma=-1, rearrange=True) with pytest.raises(AssertionError): _ = PackedLinear( @@ -152,9 +124,7 @@ def test_linear_failures(self): ) with pytest.raises(ValueError): - layer = PackedLinear( - 16, 4, alpha=1, num_estimators=1, implementation="full" - ) + layer = PackedLinear(16, 4, alpha=1, num_estimators=1, implementation="full") layer.implementation = "invalid" _ = layer(torch.rand((2, 16))) @@ -175,39 +145,29 @@ def test_conv_two_estimators(self, seq_input: torch.Tensor): assert out.shape == torch.Size([5, 2, 3]) def test_conv_one_estimator_gamma2(self, seq_input: torch.Tensor): - layer = PackedConv1d( - 6, 2, alpha=1, num_estimators=1, kernel_size=1, gamma=2 - ) + layer = PackedConv1d(6, 2, alpha=1, num_estimators=1, kernel_size=1, gamma=2) out = layer(seq_input) assert out.shape == torch.Size([5, 2, 3]) assert layer.conv.groups == 1 # and not 2 def test_conv_two_estimators_gamma2(self, seq_input: torch.Tensor): - layer = PackedConv1d( - 6, 2, alpha=1, num_estimators=2, kernel_size=1, gamma=2 - ) + layer = PackedConv1d(6, 2, alpha=1, num_estimators=2, kernel_size=1, gamma=2) out = layer(seq_input) assert out.shape == torch.Size([5, 2, 3]) assert layer.conv.groups == 2 # and not 4 def test_conv_extend(self): - _ = PackedConv1d( - 5, 3, kernel_size=1, alpha=1, num_estimators=2, gamma=1 - ) + _ = PackedConv1d(5, 3, kernel_size=1, alpha=1, num_estimators=2, gamma=1) def test_conv1_failures(self): with pytest.raises(ValueError): _ = PackedConv1d(5, 2, kernel_size=1, alpha=-1, num_estimators=1) with pytest.raises(TypeError): - _ = PackedConv1d( - 5, 2, kernel_size=1, alpha=1, num_estimators=1, gamma=0.5 - ) + _ = PackedConv1d(5, 2, kernel_size=1, alpha=1, num_estimators=1, gamma=0.5) with pytest.raises(ValueError): - _ = PackedConv1d( - 5, 2, kernel_size=1, alpha=1, num_estimators=1, gamma=-1 - ) + _ = PackedConv1d(5, 2, kernel_size=1, alpha=1, num_estimators=1, gamma=-1) class TestPackedConv2d: @@ -224,39 +184,29 @@ def test_conv_two_estimators(self, img_input: torch.Tensor): assert out.shape == torch.Size([5, 2, 3, 3]) def test_conv_one_estimator_gamma2(self, img_input: torch.Tensor): - layer = PackedConv2d( - 6, 2, alpha=1, num_estimators=1, kernel_size=1, gamma=2 - ) + layer = PackedConv2d(6, 2, alpha=1, num_estimators=1, kernel_size=1, gamma=2) out = layer(img_input) assert out.shape == torch.Size([5, 2, 3, 3]) assert layer.conv.groups == 1 # and not 2 def test_conv_two_estimators_gamma2(self, img_input: torch.Tensor): - layer = PackedConv2d( - 6, 2, alpha=1, num_estimators=2, kernel_size=1, gamma=2 - ) + layer = PackedConv2d(6, 2, alpha=1, num_estimators=2, kernel_size=1, gamma=2) out = layer(img_input) assert out.shape == torch.Size([5, 2, 3, 3]) assert layer.conv.groups == 2 # and not 4 def test_conv_extend(self): - _ = PackedConv2d( - 5, 3, kernel_size=1, alpha=1, num_estimators=2, gamma=1 - ) + _ = PackedConv2d(5, 3, kernel_size=1, alpha=1, num_estimators=2, gamma=1) def test_conv2_failures(self): with pytest.raises(ValueError): _ = PackedConv2d(5, 2, kernel_size=1, alpha=-1, num_estimators=1) with pytest.raises(TypeError): - _ = PackedConv2d( - 5, 2, kernel_size=1, alpha=1, num_estimators=1, gamma=0.5 - ) + _ = PackedConv2d(5, 2, kernel_size=1, alpha=1, num_estimators=1, gamma=0.5) with pytest.raises(ValueError): - _ = PackedConv2d( - 5, 2, kernel_size=1, alpha=1, num_estimators=1, gamma=-1 - ) + _ = PackedConv2d(5, 2, kernel_size=1, alpha=1, num_estimators=1, gamma=-1) class TestPackedConv3d: @@ -275,36 +225,26 @@ def test_conv_two_estimators(self, voxels_input: torch.Tensor): assert out.shape == torch.Size([5, 2, 3, 3, 3]) def test_conv_one_estimator_gamma2(self, voxels_input: torch.Tensor): - layer = PackedConv3d( - 6, 2, alpha=1, num_estimators=1, kernel_size=1, gamma=2 - ) + layer = PackedConv3d(6, 2, alpha=1, num_estimators=1, kernel_size=1, gamma=2) out = layer(voxels_input) assert out.shape == torch.Size([5, 2, 3, 3, 3]) assert layer.conv.groups == 1 # and not 2 def test_conv_two_estimators_gamma2(self, voxels_input: torch.Tensor): - layer = PackedConv3d( - 6, 2, alpha=1, num_estimators=2, kernel_size=1, gamma=2 - ) + layer = PackedConv3d(6, 2, alpha=1, num_estimators=2, kernel_size=1, gamma=2) out = layer(voxels_input) assert out.shape == torch.Size([5, 2, 3, 3, 3]) assert layer.conv.groups == 2 # and not 4 def test_conv_extend(self): - _ = PackedConv3d( - 5, 3, kernel_size=1, alpha=1, num_estimators=2, gamma=1 - ) + _ = PackedConv3d(5, 3, kernel_size=1, alpha=1, num_estimators=2, gamma=1) def test_conv3_failures(self): with pytest.raises(ValueError): _ = PackedConv3d(5, 2, kernel_size=1, alpha=-1, num_estimators=1) with pytest.raises(TypeError): - _ = PackedConv3d( - 5, 2, kernel_size=1, alpha=1, num_estimators=1, gamma=0.5 - ) + _ = PackedConv3d(5, 2, kernel_size=1, alpha=1, num_estimators=1, gamma=0.5) with pytest.raises(ValueError): - _ = PackedConv3d( - 5, 2, kernel_size=1, alpha=1, num_estimators=1, gamma=-1 - ) + _ = PackedConv3d(5, 2, kernel_size=1, alpha=1, num_estimators=1, gamma=-1) diff --git a/tests/losses/test_bayesian.py b/tests/losses/test_bayesian.py index 9e43c1f3..6d68e16d 100644 --- a/tests/losses/test_bayesian.py +++ b/tests/losses/test_bayesian.py @@ -48,14 +48,10 @@ def test_failures(self): model = BayesLinear(1, 1) criterion = nn.BCEWithLogitsLoss() - with pytest.raises( - TypeError, match="The inner_loss should be an instance of a class." - ): + with pytest.raises(TypeError, match="The inner_loss should be an instance of a class."): ELBOLoss(model, nn.BCEWithLogitsLoss, kl_weight=1, num_samples=1) - with pytest.raises( - ValueError, match="The KL weight should be non-negative. Got " - ): + with pytest.raises(ValueError, match="The KL weight should be non-negative. Got "): ELBOLoss(model, criterion, kl_weight=-1, num_samples=1) with pytest.raises( @@ -64,7 +60,5 @@ def test_failures(self): ): ELBOLoss(model, criterion, kl_weight=1, num_samples=-1) - with pytest.raises( - TypeError, match="The number of samples should be an integer. " - ): + with pytest.raises(TypeError, match="The number of samples should be an integer. "): ELBOLoss(model, criterion, kl_weight=1e-5, num_samples=1.5) diff --git a/tests/losses/test_classification.py b/tests/losses/test_classification.py index d4c4fc28..922414c0 100644 --- a/tests/losses/test_classification.py +++ b/tests/losses/test_classification.py @@ -14,9 +14,7 @@ class TestDECLoss: """Testing the DECLoss class.""" def test_main(self): - loss = DECLoss( - loss_type="mse", reg_weight=1e-2, annealing_step=1, reduction="sum" - ) + loss = DECLoss(loss_type="mse", reg_weight=1e-2, annealing_step=1, reduction="sum") loss(torch.tensor([[0.0, 0.0]]), torch.tensor([0]), current_epoch=1) loss = DECLoss(loss_type="mse", reg_weight=1e-2, annealing_step=1) loss(torch.tensor([[0.0, 0.0]]), torch.tensor([0]), current_epoch=0) @@ -32,9 +30,7 @@ def test_failures(self): ): DECLoss(reg_weight=-1) - with pytest.raises( - ValueError, match="The annealing step should be positive, but got " - ): + with pytest.raises(ValueError, match="The annealing step should be positive, but got "): DECLoss(annealing_step=0) loss = DECLoss(annealing_step=10) @@ -45,14 +41,10 @@ def test_failures(self): current_epoch=None, ) - with pytest.raises( - ValueError, match=" is not a valid value for reduction." - ): + with pytest.raises(ValueError, match=" is not a valid value for reduction."): DECLoss(reduction="median") - with pytest.raises( - ValueError, match="is not a valid value for mse/log/digamma loss." - ): + with pytest.raises(ValueError, match="is not a valid value for mse/log/digamma loss."): DECLoss(loss_type="regression") @@ -74,9 +66,7 @@ def test_failures(self): ): ConfidencePenaltyLoss(reg_weight=-1) - with pytest.raises( - ValueError, match="is not a valid value for reduction." - ): + with pytest.raises(ValueError, match="is not a valid value for reduction."): ConfidencePenaltyLoss(reduction="median") with pytest.raises( @@ -104,9 +94,7 @@ def test_failures(self): ): ConflictualLoss(reg_weight=-1) - with pytest.raises( - ValueError, match="is not a valid value for reduction." - ): + with pytest.raises(ValueError, match="is not a valid value for reduction."): ConflictualLoss(reduction="median") @@ -128,9 +116,7 @@ def test_failures(self): ): FocalLoss(gamma=-1) - with pytest.raises( - ValueError, match="is not a valid value for reduction." - ): + with pytest.raises(ValueError, match="is not a valid value for reduction."): FocalLoss(gamma=1, reduction="median") @@ -138,9 +124,7 @@ class TestBCEWithLogitsLSLoss: """Testing the BCEWithLogitsLSLoss class.""" def test_main(self): - loss = BCEWithLogitsLSLoss( - reduction="sum", label_smoothing=0.1, weight=torch.Tensor([1]) - ) + loss = BCEWithLogitsLSLoss(reduction="sum", label_smoothing=0.1, weight=torch.Tensor([1])) loss(torch.tensor([0.0]), torch.tensor([0])) loss = BCEWithLogitsLSLoss(reduction="mean", label_smoothing=0.6) loss(torch.tensor([0.0]), torch.tensor([0])) diff --git a/tests/losses/test_regression.py b/tests/losses/test_regression.py index 41f413a1..893845f9 100644 --- a/tests/losses/test_regression.py +++ b/tests/losses/test_regression.py @@ -27,9 +27,7 @@ class TestDERLoss: def test_main(self): loss = DERLoss(reg_weight=1e-2) layer = NormalInverseGamma - inputs = layer( - torch.ones(1), torch.ones(1), torch.ones(1), torch.ones(1) - ) + inputs = layer(torch.ones(1), torch.ones(1), torch.ones(1), torch.ones(1)) targets = torch.tensor([[1.0]], dtype=torch.float32) assert loss(inputs, targets) == pytest.approx(2 * math.log(2)) @@ -67,9 +65,7 @@ def test_failures(self): ): DERLoss(reg_weight=-1) - with pytest.raises( - ValueError, match="is not a valid value for reduction." - ): + with pytest.raises(ValueError, match="is not a valid value for reduction."): DERLoss(reg_weight=1.0, reduction="median") @@ -108,12 +104,8 @@ def test_main(self): ) == pytest.approx([0.0, 0.0]) def test_failures(self): - with pytest.raises( - ValueError, match="The beta parameter should be in range " - ): + with pytest.raises(ValueError, match="The beta parameter should be in range "): BetaNLL(beta=-1) - with pytest.raises( - ValueError, match="is not a valid value for reduction." - ): + with pytest.raises(ValueError, match="is not a valid value for reduction."): BetaNLL(beta=1.0, reduction="median") diff --git a/tests/metrics/classification/test_brier_score.py b/tests/metrics/classification/test_brier_score.py index b0d0b03f..518a6d5e 100644 --- a/tests/metrics/classification/test_brier_score.py +++ b/tests/metrics/classification/test_brier_score.py @@ -35,9 +35,7 @@ def vec2d_min_target() -> torch.Tensor: @pytest.fixture() def vec2d_5classes() -> torch.Tensor: - return torch.as_tensor( - [[0.2, 0.6, 0.1, 0.05, 0.05], [0.05, 0.25, 0.1, 0.3, 0.3]] - ) + return torch.as_tensor([[0.2, 0.6, 0.1, 0.05, 0.05], [0.05, 0.25, 0.1, 0.3, 0.3]]) @pytest.fixture() @@ -74,9 +72,7 @@ def vec3d_target1d() -> torch.Tensor: class TestBrierScore: """Testing the BrierScore metric class.""" - def test_compute( - self, vec2d_min: torch.Tensor, vec2d_min_target: torch.Tensor - ): + def test_compute(self, vec2d_min: torch.Tensor, vec2d_min_target: torch.Tensor): metric = BrierScore(num_classes=2) metric.update(vec2d_min, vec2d_min_target) assert metric.compute() == 0 @@ -85,16 +81,12 @@ def test_compute( metric.update(vec2d_min, vec2d_min_target) assert metric.compute() == 0 - def test_compute_max( - self, vec2d_max: torch.Tensor, vec2d_max_target: torch.Tensor - ): + def test_compute_max(self, vec2d_max: torch.Tensor, vec2d_max_target: torch.Tensor): metric = BrierScore(num_classes=2, reduction="sum") metric.update(vec2d_max, vec2d_max_target) assert metric.compute() == 0.5 - def test_compute_max_target1d( - self, vec2d_max: torch.Tensor, vec2d_max_target1d: torch.Tensor - ): + def test_compute_max_target1d(self, vec2d_max: torch.Tensor, vec2d_max_target1d: torch.Tensor): metric = BrierScore(num_classes=2, reduction="sum") metric.update(vec2d_max, vec2d_max_target1d) assert metric.compute() == 0.5 @@ -110,14 +102,7 @@ def test_compute_5classes( metric.update(vec2d_5classes, vec2d_5classes_target1d) assert ( metric.compute() / 2 - == 0.2**2 - + 0.6**2 - + 0.1**2 * 2 - + 0.95**2 - + 0.05**2 * 2 - + 0.25**2 - + 0.3**2 - + 0.7**2 + == 0.2**2 + 0.6**2 + 0.1**2 * 2 + 0.95**2 + 0.05**2 * 2 + 0.25**2 + 0.3**2 + 0.7**2 ) metric = BrierScore(num_classes=5, top_class=True, reduction="sum") @@ -160,9 +145,7 @@ def test_multiple_compute_none( metric.update(vec2d_max, vec2d_max_target) assert all(metric.compute() == torch.as_tensor([0, 0.5])) - def test_compute_3d_mean( - self, vec3d: torch.Tensor, vec3d_target: torch.Tensor - ): + def test_compute_3d_mean(self, vec3d: torch.Tensor, vec3d_target: torch.Tensor): """Test that the metric returns the mean of the BrierScore over the estimators. """ @@ -170,23 +153,17 @@ def test_compute_3d_mean( metric.update(vec3d, vec3d_target) assert metric.compute() == 1 - def test_compute_3d_sum( - self, vec3d: torch.Tensor, vec3d_target: torch.Tensor - ): + def test_compute_3d_sum(self, vec3d: torch.Tensor, vec3d_target: torch.Tensor): metric = BrierScore(num_classes=2, reduction="sum") metric.update(vec3d, vec3d_target) assert metric.compute() == 1 - def test_compute_3d_sum_target1d( - self, vec3d: torch.Tensor, vec3d_target1d: torch.Tensor - ): + def test_compute_3d_sum_target1d(self, vec3d: torch.Tensor, vec3d_target1d: torch.Tensor): metric = BrierScore(num_classes=2, reduction="sum") metric.update(vec3d, vec3d_target1d) assert metric.compute() == 1 - def test_compute_3d_to_2d( - self, vec3d: torch.Tensor, vec3d_target: torch.Tensor - ): + def test_compute_3d_to_2d(self, vec3d: torch.Tensor, vec3d_target: torch.Tensor): metric = BrierScore(num_classes=2, reduction="mean") vec3d = vec3d.mean(1) metric.update(vec3d, vec3d_target) @@ -198,7 +175,5 @@ def test_bad_input(self) -> None: metric.update(torch.ones(2, 2, 2, 2), torch.ones(2, 2, 2, 2)) def test_bad_argument(self): - with pytest.raises( - ValueError, match="Expected argument `reduction` to be one of" - ): + with pytest.raises(ValueError, match="Expected argument `reduction` to be one of"): _ = BrierScore(num_classes=2, reduction="geometric_mean") diff --git a/tests/metrics/classification/test_calibration.py b/tests/metrics/classification/test_calibration.py index 0d6ad4d9..615a3129 100644 --- a/tests/metrics/classification/test_calibration.py +++ b/tests/metrics/classification/test_calibration.py @@ -26,9 +26,7 @@ def test_plot_binary(self) -> None: def test_plot_multiclass( self, ) -> None: - metric = CalibrationError( - task="multiclass", num_bins=3, norm="l1", num_classes=3 - ) + metric = CalibrationError(task="multiclass", num_bins=3, norm="l1", num_classes=3) metric.update( torch.as_tensor( [ @@ -51,9 +49,7 @@ def test_plot_multiclass( def test_errors(self) -> None: with pytest.raises(TypeError, match="is expected to be `int`"): CalibrationError(task="multiclass", num_classes=None) - with pytest.raises( - ValueError, match="`n_bins` does not exist, use `num_bins`." - ): + with pytest.raises(ValueError, match="`n_bins` does not exist, use `num_bins`."): CalibrationError(task="multiclass", num_classes=2, n_bins=1) @@ -61,13 +57,9 @@ class TestAdaptiveCalibrationError: """Testing the AdaptiveCalibrationError metric class.""" def test_main(self) -> None: - ace = AdaptiveCalibrationError( - task="binary", num_bins=2, norm="l1", validate_args=True - ) + ace = AdaptiveCalibrationError(task="binary", num_bins=2, norm="l1", validate_args=True) - ace = AdaptiveCalibrationError( - task="binary", num_bins=2, norm="l1", validate_args=False - ) + ace = AdaptiveCalibrationError(task="binary", num_bins=2, norm="l1", validate_args=False) ece = CalibrationError(task="binary", num_bins=2, norm="l1") ace.update( torch.as_tensor([0.35, 0.35, 0.75, 0.75]), @@ -96,14 +88,11 @@ def test_main(self) -> None: validate_args=True, ) ace.update( - torch.as_tensor( - [[0.7, 0.3], [0.76, 0.24], [0.75, 0.25], [0.2, 0.8], [0.8, 0.2]] - ), + torch.as_tensor([[0.7, 0.3], [0.76, 0.24], [0.75, 0.25], [0.2, 0.8], [0.8, 0.2]]), torch.as_tensor([0, 0, 0, 1, 1]), ) assert ace.compute().item() ** 2 == pytest.approx( - 3 / 5 * (1 - 1 / 3 * (0.7 + 0.76 + 0.75)) ** 2 - + 2 / 5 * (0.8 - 0.5) ** 2 + 3 / 5 * (1 - 1 / 3 * (0.7 + 0.76 + 0.75)) ** 2 + 2 / 5 * (0.8 - 0.5) ** 2 ) ace = AdaptiveCalibrationError( @@ -114,9 +103,7 @@ def test_main(self) -> None: validate_args=False, ) ace.update( - torch.as_tensor( - [[0.7, 0.3], [0.76, 0.24], [0.75, 0.25], [0.2, 0.8], [0.8, 0.2]] - ), + torch.as_tensor([[0.7, 0.3], [0.76, 0.24], [0.75, 0.25], [0.2, 0.8], [0.8, 0.2]]), torch.as_tensor([0, 0, 0, 1, 1]), ) assert ace.compute().item() == pytest.approx(0.8 - 0.5) diff --git a/tests/metrics/classification/test_disagreement.py b/tests/metrics/classification/test_disagreement.py index f49c7cc3..1fa7d45c 100644 --- a/tests/metrics/classification/test_disagreement.py +++ b/tests/metrics/classification/test_disagreement.py @@ -34,18 +34,14 @@ def test_compute_agreement(self, agreement_probas: torch.Tensor): res = metric.compute() assert res == 0.0 - def test_compute_mixed( - self, disagreement_probas: torch.Tensor, agreement_probas: torch.Tensor - ): + def test_compute_mixed(self, disagreement_probas: torch.Tensor, agreement_probas: torch.Tensor): metric = Disagreement() metric.update(agreement_probas) metric.update(disagreement_probas) res = metric.compute() assert res == 0.5 - def test_compute_mixed_3_estimators( - self, disagreement_probas_3: torch.Tensor - ): + def test_compute_mixed_3_estimators(self, disagreement_probas_3: torch.Tensor): metric = Disagreement() metric.update(disagreement_probas_3) res = metric.compute() diff --git a/tests/metrics/classification/test_entropy.py b/tests/metrics/classification/test_entropy.py index 558184c4..0a119c30 100644 --- a/tests/metrics/classification/test_entropy.py +++ b/tests/metrics/classification/test_entropy.py @@ -42,27 +42,21 @@ def test_compute_max(self, vec2d_max: torch.Tensor): res = metric.compute() assert res == math.log(2) - def test_multiple_compute_sum( - self, vec2d_min: torch.Tensor, vec2d_max: torch.Tensor - ): + def test_multiple_compute_sum(self, vec2d_min: torch.Tensor, vec2d_max: torch.Tensor): metric = Entropy(reduction="sum") metric.update(vec2d_min) metric.update(vec2d_max) res = metric.compute() assert res == math.log(2) - def test_multiple_compute_mean( - self, vec2d_min: torch.Tensor, vec2d_max: torch.Tensor - ): + def test_multiple_compute_mean(self, vec2d_min: torch.Tensor, vec2d_max: torch.Tensor): metric = Entropy(reduction="mean") metric.update(vec2d_min) metric.update(vec2d_max) res = metric.compute() assert res == math.log(2) / 2 - def test_multiple_compute_none( - self, vec2d_min: torch.Tensor, vec2d_max: torch.Tensor - ): + def test_multiple_compute_none(self, vec2d_min: torch.Tensor, vec2d_max: torch.Tensor): metric = Entropy(reduction=None) metric.update(vec2d_min) metric.update(vec2d_max) diff --git a/tests/metrics/classification/test_fpr95.py b/tests/metrics/classification/test_fpr95.py index 3c10fd01..17609574 100644 --- a/tests/metrics/classification/test_fpr95.py +++ b/tests/metrics/classification/test_fpr95.py @@ -9,9 +9,7 @@ class TestFPR95: def test_compute_zero(self): metric = FPR95(pos_label=1) - metric.update( - torch.as_tensor([1] * 99 + [0.99]), torch.as_tensor([1] * 99 + [0]) - ) + metric.update(torch.as_tensor([1] * 99 + [0.99]), torch.as_tensor([1] * 99 + [0])) res = metric.compute() assert res == 0 @@ -26,17 +24,13 @@ def test_compute_half(self): def test_compute_one(self): metric = FPR95(pos_label=1) - metric.update( - torch.as_tensor([0.99] * 99 + [1]), torch.as_tensor([1] * 99 + [0]) - ) + metric.update(torch.as_tensor([0.99] * 99 + [1]), torch.as_tensor([1] * 99 + [0])) res = metric.compute() assert res == 1 def test_compute_nan(self): metric = FPR95(pos_label=1) - metric.update( - torch.as_tensor([0.1] * 50 + [0.4] * 50), torch.as_tensor([0] * 100) - ) + metric.update(torch.as_tensor([0.1] * 50 + [0.4] * 50), torch.as_tensor([0] * 100)) res = metric.compute() assert torch.isnan(res).all() diff --git a/tests/metrics/classification/test_grouping_loss.py b/tests/metrics/classification/test_grouping_loss.py index ffce34b1..9f40a480 100644 --- a/tests/metrics/classification/test_grouping_loss.py +++ b/tests/metrics/classification/test_grouping_loss.py @@ -11,9 +11,7 @@ def test_compute(self): metric = GroupingLoss() metric.update( torch.cat([torch.tensor([0, 1, 0, 1]), torch.ones(200) / 10]), - torch.cat( - [torch.tensor([0, 0, 1, 1]), torch.zeros(100), torch.ones(100)] - ).long(), + torch.cat([torch.tensor([0, 0, 1, 1]), torch.zeros(100), torch.ones(100)]).long(), torch.cat([torch.zeros((104, 10)), torch.ones((100, 10))]), ) metric.compute() diff --git a/tests/metrics/classification/test_mutual_information.py b/tests/metrics/classification/test_mutual_information.py index 99dde31d..22597f5e 100644 --- a/tests/metrics/classification/test_mutual_information.py +++ b/tests/metrics/classification/test_mutual_information.py @@ -32,9 +32,7 @@ def test_compute_agreement(self, agreement_probas: torch.Tensor): res = metric.compute() assert res == 0.0 - def test_compute_mixed( - self, disagreement_probas: torch.Tensor, agreement_probas: torch.Tensor - ): + def test_compute_mixed(self, disagreement_probas: torch.Tensor, agreement_probas: torch.Tensor): metric = MutualInformation(reduction="mean") metric.update(agreement_probas) metric.update(disagreement_probas) diff --git a/tests/metrics/classification/test_risk_coverage.py b/tests/metrics/classification/test_risk_coverage.py index 85981d5d..3db90223 100644 --- a/tests/metrics/classification/test_risk_coverage.py +++ b/tests/metrics/classification/test_risk_coverage.py @@ -29,9 +29,7 @@ def test_compute_binary(self) -> None: assert torch.isnan(metric(torch.Tensor([0.0]), torch.Tensor([1]))) def test_compute_multiclass(self) -> None: - probs = torch.Tensor( - [[0.1, 0.9], [0.2, 0.8], [0.3, 0.7], [0.4, 0.6], [0.2, 0.8]] - ) + probs = torch.Tensor([[0.1, 0.9], [0.2, 0.8], [0.3, 0.7], [0.4, 0.6], [0.2, 0.8]]) targets = torch.Tensor([1, 1, 1, 1, 1]).long() metric = AURC() assert metric(probs, targets).item() == pytest.approx(0) @@ -78,9 +76,7 @@ class TestCovAtxRisk: """Testing the CovAtxRisk metric class.""" def test_compute_zero(self) -> None: - probs = torch.Tensor( - [[0.9, 0.1], [0.8, 0.2], [0.7, 0.3], [0.6, 0.4], [0.8, 0.2]] - ) + probs = torch.Tensor([[0.9, 0.1], [0.8, 0.2], [0.7, 0.3], [0.6, 0.4], [0.8, 0.2]]) targets = torch.Tensor([1, 1, 1, 1, 1]) metric = CovAtxRisk(risk_threshold=0.5) # no cov for given risk @@ -105,13 +101,9 @@ def test_compute_zero(self) -> None: assert metric(torch.zeros(0), torch.zeros(0)).isnan() def test_errors(self): - with pytest.raises( - TypeError, match="Expected threshold to be of type float" - ): + with pytest.raises(TypeError, match="Expected threshold to be of type float"): CovAtxRisk(risk_threshold="0.5") - with pytest.raises( - ValueError, match="Threshold should be in the range" - ): + with pytest.raises(ValueError, match="Threshold should be in the range"): CovAtxRisk(risk_threshold=-0.5) @@ -119,9 +111,7 @@ class TestRiskAtxCov: """Testing the RiskAtxCov metric class.""" def test_compute_zero(self) -> None: - probs = torch.Tensor( - [[0.9, 0.1], [0.8, 0.2], [0.7, 0.3], [0.6, 0.4], [0.8, 0.2]] - ) + probs = torch.Tensor([[0.9, 0.1], [0.8, 0.2], [0.7, 0.3], [0.6, 0.4], [0.8, 0.2]]) targets = torch.Tensor([1, 1, 1, 1, 1]) metric = RiskAtxCov(cov_threshold=0.5) assert metric(probs, targets) == 1 diff --git a/tests/metrics/regression/test_depth.py b/tests/metrics/regression/test_depth.py index 4281ec5f..3d6a1f5a 100644 --- a/tests/metrics/regression/test_depth.py +++ b/tests/metrics/regression/test_depth.py @@ -34,9 +34,7 @@ def test_main(self): targets = torch.rand((10, 2)) metric.update(preds[:, 0], targets[:, 0]) metric.update(preds[:, 1], targets[:, 1]) - assert (torch.abs(preds - targets) / targets).mean() == pytest.approx( - metric.compute() - ) + assert (torch.abs(preds - targets) / targets).mean() == pytest.approx(metric.compute()) class TestMeanGTRelativeSquaredError: @@ -48,9 +46,9 @@ def test_main(self): targets = torch.rand((10, 2)) metric.update(preds[:, 0], targets[:, 0]) metric.update(preds[:, 1], targets[:, 1]) - assert torch.flatten( - (preds - targets) ** 2 / targets - ).mean() == pytest.approx(metric.compute()) + assert torch.flatten((preds - targets) ** 2 / targets).mean() == pytest.approx( + metric.compute() + ) class TestSILog: @@ -62,12 +60,9 @@ def test_main(self): targets = torch.rand((10, 2)).double() metric.update(preds[:, 0], targets[:, 0]) metric.update(preds[:, 1], targets[:, 1]) - mean_log_dists = torch.mean( - targets.flatten().log() - preds.flatten().log() - ) + mean_log_dists = torch.mean(targets.flatten().log() - preds.flatten().log()) assert torch.mean( - (preds.flatten().log() - targets.flatten().log() + mean_log_dists) - ** 2 + (preds.flatten().log() - targets.flatten().log() + mean_log_dists) ** 2 ) == pytest.approx(metric.compute()) metric = SILog(sqrt=True) @@ -75,12 +70,9 @@ def test_main(self): targets = torch.rand((10, 2)).double() metric.update(preds[:, 0], targets[:, 0]) metric.update(preds[:, 1], targets[:, 1]) - mean_log_dists = torch.mean( - targets.flatten().log() - preds.flatten().log() - ) + mean_log_dists = torch.mean(targets.flatten().log() - preds.flatten().log()) assert torch.mean( - (preds.flatten().log() - targets.flatten().log() + mean_log_dists) - ** 2 + (preds.flatten().log() - targets.flatten().log() + mean_log_dists) ** 2 ) ** 0.5 == pytest.approx(metric.compute()) @@ -96,9 +88,7 @@ def test_main(self): assert metric.compute() == 0.0 metric = ThresholdAccuracy(power=1, lmbda=1.25) - preds = torch.cat( - [torch.ones((10, 2)) * 1.2, torch.ones((10, 2))], dim=0 - ) + preds = torch.cat([torch.ones((10, 2)) * 1.2, torch.ones((10, 2))], dim=0) targets = torch.ones((20, 2)) * 1.3 metric.update(preds[:, 0], targets[:, 0]) metric.update(preds[:, 1], targets[:, 1]) @@ -120,6 +110,6 @@ def test_main(self): targets = torch.rand((10, 2)).double() metric.update(preds[:, 0], targets[:, 0]) metric.update(preds[:, 1], targets[:, 1]) - assert torch.mean( - (preds.log() - targets.log()).flatten() ** 2 - ) == pytest.approx(metric.compute()) + assert torch.mean((preds.log() - targets.log()).flatten() ** 2) == pytest.approx( + metric.compute() + ) diff --git a/tests/models/test_resnets.py b/tests/models/test_resnets.py index 31677f51..fe757591 100644 --- a/tests/models/test_resnets.py +++ b/tests/models/test_resnets.py @@ -44,12 +44,8 @@ class TestPackedResnet: def test_main(self): model = packed_resnet(1, 10, 20, 2, 2, 1) model = packed_resnet(1, 10, 152, 2, 2, 1) - assert model.check_config( - {"alpha": 2, "gamma": 1, "groups": 1, "num_estimators": 2} - ) - assert not model.check_config( - {"alpha": 1, "gamma": 1, "groups": 1, "num_estimators": 2} - ) + assert model.check_config({"alpha": 2, "gamma": 1, "groups": 1, "num_estimators": 2}) + assert not model.check_config({"alpha": 1, "gamma": 1, "groups": 1, "num_estimators": 2}) def test_error(self): with pytest.raises(ValueError): @@ -96,9 +92,7 @@ def test_main(self): def test_error(self): with pytest.raises(ValueError): lpbnn_resnet(1, 10, 20, 2, style="test") - with pytest.raises( - ValueError, match="Unknown ResNet architecture. Got" - ): + with pytest.raises(ValueError, match="Unknown ResNet architecture. Got"): lpbnn_resnet(1, 10, 42, 2, style="test") diff --git a/tests/models/test_wideresnets.py b/tests/models/test_wideresnets.py index 62045d63..1d50bb91 100644 --- a/tests/models/test_wideresnets.py +++ b/tests/models/test_wideresnets.py @@ -128,18 +128,14 @@ def test_main(self): ) with pytest.raises(ValueError): - batched_wideresnet28x10( - in_channels=1, num_classes=10, num_estimators=2, style="test" - ) + batched_wideresnet28x10(in_channels=1, num_classes=10, num_estimators=2, style="test") class TestMIMOWide: """Testing the WideResNet mimo class.""" def test_main(self): - model = mimo_wideresnet28x10( - in_channels=1, num_classes=10, num_estimators=2, style="cifar" - ) + model = mimo_wideresnet28x10(in_channels=1, num_classes=10, num_estimators=2, style="cifar") model(torch.rand((2, 1, 28, 28))) with pytest.raises(ValueError): @@ -153,6 +149,4 @@ def test_main(self): conv_bias=False, ) with pytest.raises(ValueError): - mimo_wideresnet28x10( - in_channels=1, num_classes=10, num_estimators=2, style="test" - ) + mimo_wideresnet28x10(in_channels=1, num_classes=10, num_estimators=2, style="test") diff --git a/tests/models/wrappers/test_mc_dropout.py b/tests/models/wrappers/test_mc_dropout.py index bf63d5b1..80ecd42c 100644 --- a/tests/models/wrappers/test_mc_dropout.py +++ b/tests/models/wrappers/test_mc_dropout.py @@ -35,22 +35,14 @@ def test_mc_dropout_eval(self): def test_mc_dropout_errors(self): model = dummy_model(10, 5, 0.1) - with pytest.raises( - ValueError, match="`num_estimators` must be strictly positive" - ): - MCDropout( - model=model, num_estimators=-1, last_layer=True, on_batch=True - ) + with pytest.raises(ValueError, match="`num_estimators` must be strictly positive"): + MCDropout(model=model, num_estimators=-1, last_layer=True, on_batch=True) dropout_model = mc_dropout(model, 5) - with pytest.raises( - TypeError, match="Training mode is expected to be boolean" - ): + with pytest.raises(TypeError, match="Training mode is expected to be boolean"): dropout_model.train(mode=1) - with pytest.raises( - TypeError, match="Training mode is expected to be boolean" - ): + with pytest.raises(TypeError, match="Training mode is expected to be boolean"): dropout_model.train(mode=None) model = dummy_model(10, 5, 0.0) diff --git a/tests/models/wrappers/test_swa.py b/tests/models/wrappers/test_swa.py index f590473b..1bc0da0e 100644 --- a/tests/models/wrappers/test_swa.py +++ b/tests/models/wrappers/test_swa.py @@ -28,13 +28,9 @@ def test_training(self): swa(torch.randn(1, 1)) def test_failures(self): - with pytest.raises( - ValueError, match="`cycle_start` must be non-negative." - ): + with pytest.raises(ValueError, match="`cycle_start` must be non-negative."): SWA(nn.Module(), cycle_start=-1, cycle_length=1) - with pytest.raises( - ValueError, match="`cycle_length` must be strictly positive." - ): + with pytest.raises(ValueError, match="`cycle_length` must be strictly positive."): SWA(nn.Module(), cycle_start=1, cycle_length=0) @@ -57,39 +53,27 @@ def test_training(self): swag.train() swag(torch.randn(1, 1)) swag.update_wrapper(0) - assert swag.swag_stats[ - "model.swag_stats.linear.weight_covariance_sqrt" - ].shape == (0, 10) + assert swag.swag_stats["model.swag_stats.linear.weight_covariance_sqrt"].shape == (0, 10) swag.bn_update(dl, "cpu") swag(torch.randn(1, 1)) swag.update_wrapper(1) - assert swag.swag_stats[ - "model.swag_stats.linear.weight_covariance_sqrt" - ].shape == (0, 10) + assert swag.swag_stats["model.swag_stats.linear.weight_covariance_sqrt"].shape == (0, 10) assert swag.num_avgd_models == 0 swag.bn_update(dl, "cpu") swag.update_wrapper(2) - assert swag.swag_stats[ - "model.swag_stats.linear.weight_covariance_sqrt" - ].shape == (1, 10) + assert swag.swag_stats["model.swag_stats.linear.weight_covariance_sqrt"].shape == (1, 10) swag.bn_update(dl, "cpu") swag(torch.randn(1, 1)) swag.update_wrapper(3) - assert swag.swag_stats[ - "model.swag_stats.linear.weight_covariance_sqrt" - ].shape == (2, 10) + assert swag.swag_stats["model.swag_stats.linear.weight_covariance_sqrt"].shape == (2, 10) swag.update_wrapper(4) assert swag.num_avgd_models == 3 - assert swag.swag_stats[ - "model.swag_stats.linear.weight_covariance_sqrt" - ].shape == (3, 10) + assert swag.swag_stats["model.swag_stats.linear.weight_covariance_sqrt"].shape == (3, 10) swag.update_wrapper(5) assert swag.num_avgd_models == 4 - assert swag.swag_stats[ - "model.swag_stats.linear.weight_covariance_sqrt" - ].shape == (3, 10) + assert swag.swag_stats["model.swag_stats.linear.weight_covariance_sqrt"].shape == (3, 10) swag.eval() swag(torch.randn(1, 1)) @@ -110,24 +94,16 @@ def test_state_dict(self): swag.load_state_dict(swag.state_dict()) def test_failures(self): - with pytest.raises( - NotImplementedError, match="Raise an issue if you need this feature" - ): + with pytest.raises(NotImplementedError, match="Raise an issue if you need this feature"): swag = SWAG(nn.Module(), scale=1, cycle_start=1, cycle_length=1) swag.sample(scale=1, block=True) with pytest.raises(ValueError, match="`scale` must be non-negative."): SWAG(nn.Module(), scale=-1, cycle_start=1, cycle_length=1) - with pytest.raises( - ValueError, match="`max_num_models` must be non-negative." - ): + with pytest.raises(ValueError, match="`max_num_models` must be non-negative."): SWAG(nn.Module(), max_num_models=-1, cycle_start=1, cycle_length=1) - with pytest.raises( - ValueError, match="`var_clamp` must be non-negative. " - ): + with pytest.raises(ValueError, match="`var_clamp` must be non-negative. "): SWAG(nn.Module(), var_clamp=-1, cycle_start=1, cycle_length=1) - swag = SWAG( - nn.Module(), cycle_start=1, cycle_length=1, diag_covariance=True - ) + swag = SWAG(nn.Module(), cycle_start=1, cycle_length=1, diag_covariance=True) with pytest.raises( ValueError, match="Cannot sample full rank from diagonal covariance matrix.", diff --git a/tests/post_processing/test_mc_batch_norm.py b/tests/post_processing/test_mc_batch_norm.py index 0d1acdca..bbe987ca 100644 --- a/tests/post_processing/test_mc_batch_norm.py +++ b/tests/post_processing/test_mc_batch_norm.py @@ -17,9 +17,7 @@ class TestMCBatchNorm: def test_main(self): """Test initialization.""" mc_model = lenet(1, 1, norm=partial(MCBatchNorm2d, num_estimators=2)) - stoch_model = MCBatchNorm( - mc_model, num_estimators=2, convert=False, mc_batch_size=1 - ) + stoch_model = MCBatchNorm(mc_model, num_estimators=2, convert=False, mc_batch_size=1) model = lenet(1, 1, norm=nn.BatchNorm2d) stoch_model = MCBatchNorm( @@ -42,9 +40,7 @@ def test_main(self): stoch_model.eval() stoch_model(torch.randn(1, 1, 20, 20)) - stoch_model = MCBatchNorm( - num_estimators=2, convert=False, mc_batch_size=1 - ) + stoch_model = MCBatchNorm(num_estimators=2, convert=False, mc_batch_size=1) stoch_model.set_model(mc_model) def test_errors(self): @@ -52,18 +48,14 @@ def test_errors(self): model = nn.Identity() with pytest.raises(ValueError): MCBatchNorm(model, num_estimators=0, convert=True) - with pytest.raises( - ValueError, match="mc_batch_size must be a positive integer" - ): + with pytest.raises(ValueError, match="mc_batch_size must be a positive integer"): MCBatchNorm(model, num_estimators=1, convert=True, mc_batch_size=-1) with pytest.raises(ValueError): MCBatchNorm(model, num_estimators=1, convert=False) with pytest.raises(ValueError): MCBatchNorm(model, num_estimators=1, convert=True) model = lenet(1, 1, norm=nn.BatchNorm2d) - stoch_model = MCBatchNorm( - model, num_estimators=4, convert=True, mc_batch_size=1 - ) + stoch_model = MCBatchNorm(model, num_estimators=4, convert=True, mc_batch_size=1) dataset = DummyClassificationDataset( "./", num_channels=1, diff --git a/tests/post_processing/test_scalers.py b/tests/post_processing/test_scalers.py index 98fc36ce..eda1356e 100644 --- a/tests/post_processing/test_scalers.py +++ b/tests/post_processing/test_scalers.py @@ -27,9 +27,7 @@ def test_fit_biased(self): calibration_set = list(zip(inputs, labels, strict=True)) - scaler = TemperatureScaler( - model=nn.Identity(), init_val=2, lr=1, max_iter=10 - ) + scaler = TemperatureScaler(model=nn.Identity(), init_val=2, lr=1, max_iter=10) assert scaler.temperature[0] == 2.0 scaler.fit(calibration_set) assert scaler.temperature[0] > 10 # best is +inf diff --git a/tests/routines/test_classification.py b/tests/routines/test_classification.py index 14fa99d9..5164206e 100644 --- a/tests/routines/test_classification.py +++ b/tests/routines/test_classification.py @@ -318,9 +318,7 @@ def test_two_estimator_two_classes_elbo_vr_logs(self): model = DummyClassificationBaseline( num_classes=dm.num_classes, in_channels=dm.num_channels, - loss=ELBOLoss( - None, nn.CrossEntropyLoss(), kl_weight=1.0, num_samples=4 - ), + loss=ELBOLoss(None, nn.CrossEntropyLoss(), kl_weight=1.0, num_samples=4), baseline_type="ensemble", ood_criterion="vr", eval_ood=True, @@ -362,9 +360,7 @@ def test_classification_failures(self): mixup_params=mixup_params, ) - with pytest.raises( - ValueError, match="num_calibration_bins must be at least 2, got" - ): + with pytest.raises(ValueError, match="num_calibration_bins must be at least 2, got"): ClassificationRoutine( model=nn.Identity(), num_classes=2, @@ -391,9 +387,7 @@ def test_classification_failures(self): model = dummy_model(1, 1, 0, with_feats=False) with pytest.raises(ValueError): - ClassificationRoutine( - num_classes=10, model=model, loss=None, eval_grouping_loss=True - ) + ClassificationRoutine(num_classes=10, model=model, loss=None, eval_grouping_loss=True) with pytest.raises( ValueError, diff --git a/tests/routines/test_pixel_regression.py b/tests/routines/test_pixel_regression.py index c60b8d7a..6b2bfdcf 100644 --- a/tests/routines/test_pixel_regression.py +++ b/tests/routines/test_pixel_regression.py @@ -27,9 +27,7 @@ def test_one_estimator_two_classes(self): ) root = Path(__file__).parent.absolute().parents[0] / "data" - dm = DummyPixelRegressionDataModule( - root=root, batch_size=5, output_dim=3 - ) + dm = DummyPixelRegressionDataModule(root=root, batch_size=5, output_dim=3) model = DummyPixelRegressionBaseline( probabilistic=False, @@ -73,9 +71,7 @@ def test_two_estimators_one_class(self): trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) root = Path(__file__).parent.absolute().parents[0] / "data" - dm = DummyPixelRegressionDataModule( - root=root, batch_size=4, output_dim=1 - ) + dm = DummyPixelRegressionDataModule(root=root, batch_size=4, output_dim=1) model = DummyPixelRegressionBaseline( probabilistic=False, diff --git a/tests/routines/test_segmentation.py b/tests/routines/test_segmentation.py index beed8a1d..fab7a27c 100644 --- a/tests/routines/test_segmentation.py +++ b/tests/routines/test_segmentation.py @@ -88,16 +88,10 @@ def test_two_estimators_two_classes(self): model(dm.get_test_set()[0][0]) def test_segmentation_errors(self): - with pytest.raises( - ValueError, match="num_classes must be at least 2, got" - ): - SegmentationRoutine( - model=nn.Identity(), num_classes=1, loss=nn.CrossEntropyLoss() - ) + with pytest.raises(ValueError, match="num_classes must be at least 2, got"): + SegmentationRoutine(model=nn.Identity(), num_classes=1, loss=nn.CrossEntropyLoss()) - with pytest.raises( - ValueError, match="metric_subsampling_rate must be in" - ): + with pytest.raises(ValueError, match="metric_subsampling_rate must be in"): SegmentationRoutine( model=nn.Identity(), num_classes=2, @@ -105,9 +99,7 @@ def test_segmentation_errors(self): metric_subsampling_rate=-1, ) - with pytest.raises( - ValueError, match="num_calibration_bins must be at least 2, got" - ): + with pytest.raises(ValueError, match="num_calibration_bins must be at least 2, got"): SegmentationRoutine( model=nn.Identity(), num_classes=2, diff --git a/tests/test_utils.py b/tests/test_utils.py index e6aceb06..60d9e9e9 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -53,9 +53,7 @@ class TestMisc: def test_csv_writer(self): root = Path(__file__).parent.resolve() csv_writer(root / "testlog" / "results.csv", {"a": 1.0, "b": 2.0}) - csv_writer( - root / "testlog" / "results.csv", {"a": 1.0, "b": 2.0, "c": 3.0} - ) + csv_writer(root / "testlog" / "results.csv", {"a": 1.0, "b": 2.0, "c": 3.0}) def test_plot_hist(self): conf = [torch.rand(20), torch.rand(20)] diff --git a/tests/transforms/test_mixup.py b/tests/transforms/test_mixup.py index 24d72c9a..01900fb4 100644 --- a/tests/transforms/test_mixup.py +++ b/tests/transforms/test_mixup.py @@ -59,19 +59,13 @@ class TestWarpingMixup: """Testing WarpingMixup augmentation.""" def test_batch_kernel_warpingmixup(self, batch_input): - mixup = WarpingMixup( - alpha=1.0, mode="batch", num_classes=2, apply_kernel=True - ) + mixup = WarpingMixup(alpha=1.0, mode="batch", num_classes=2, apply_kernel=True) _ = mixup(*batch_input, batch_input[0]) def test_elem_kernel_warpingmixup(self, batch_input): - mixup = WarpingMixup( - alpha=1.0, mode="elem", num_classes=2, apply_kernel=True - ) + mixup = WarpingMixup(alpha=1.0, mode="elem", num_classes=2, apply_kernel=True) _ = mixup(*batch_input, batch_input[0]) def test_elem_warpingmixup(self, batch_input): - mixup = WarpingMixup( - alpha=1.0, mode="elem", num_classes=2, apply_kernel=False - ) + mixup = WarpingMixup(alpha=1.0, mode="elem", num_classes=2, apply_kernel=False) _ = mixup(*batch_input, batch_input[0]) diff --git a/torch_uncertainty/baselines/classification/deep_ensembles.py b/torch_uncertainty/baselines/classification/deep_ensembles.py index eef4abe3..bb2cb96f 100644 --- a/torch_uncertainty/baselines/classification/deep_ensembles.py +++ b/torch_uncertainty/baselines/classification/deep_ensembles.py @@ -24,9 +24,7 @@ def __init__( eval_ood: bool = False, eval_shift: bool = False, eval_grouping_loss: bool = False, - ood_criterion: Literal[ - "msp", "logit", "energy", "entropy", "mi", "vr" - ] = "msp", + ood_criterion: Literal["msp", "logit", "energy", "entropy", "mi", "vr"] = "msp", log_plots: bool = False, calibration_set: Literal["val", "test"] = "val", ) -> None: @@ -36,9 +34,7 @@ def __init__( models = [] for version in checkpoint_ids: # coverage: ignore - ckpt_file, hparams_file = get_version( - root=log_path, version=version - ) + ckpt_file, hparams_file = get_version(root=log_path, version=version) trained_model = backbone_cls.load_from_checkpoint( checkpoint_path=ckpt_file, hparams_file=hparams_file, diff --git a/torch_uncertainty/baselines/classification/resnet.py b/torch_uncertainty/baselines/classification/resnet.py index f1c5c486..2bed36d2 100644 --- a/torch_uncertainty/baselines/classification/resnet.py +++ b/torch_uncertainty/baselines/classification/resnet.py @@ -66,9 +66,7 @@ def __init__( gamma: int = 1, rho: float = 1.0, batch_repeat: int = 1, - ood_criterion: Literal[ - "msp", "logit", "energy", "entropy", "mi", "vr" - ] = "msp", + ood_criterion: Literal["msp", "logit", "energy", "entropy", "mi", "vr"] = "msp", log_plots: bool = False, save_in_csv: bool = False, calibration_set: Literal["val", "test"] = "val", diff --git a/torch_uncertainty/baselines/classification/vgg.py b/torch_uncertainty/baselines/classification/vgg.py index 7dc65a59..4ba0bc6b 100644 --- a/torch_uncertainty/baselines/classification/vgg.py +++ b/torch_uncertainty/baselines/classification/vgg.py @@ -38,9 +38,7 @@ def __init__( groups: int = 1, alpha: int | None = None, gamma: int = 1, - ood_criterion: Literal[ - "msp", "logit", "energy", "entropy", "mi", "vr" - ] = "msp", + ood_criterion: Literal["msp", "logit", "energy", "entropy", "mi", "vr"] = "msp", log_plots: bool = False, save_in_csv: bool = False, calibration_set: Literal["val", "test"] = "val", diff --git a/torch_uncertainty/baselines/classification/wideresnet.py b/torch_uncertainty/baselines/classification/wideresnet.py index a51d08cf..f3d57fee 100644 --- a/torch_uncertainty/baselines/classification/wideresnet.py +++ b/torch_uncertainty/baselines/classification/wideresnet.py @@ -34,9 +34,7 @@ def __init__( num_classes: int, in_channels: int, loss: nn.Module, - version: Literal[ - "std", "mc-dropout", "packed", "batched", "masked", "mimo" - ], + version: Literal["std", "mc-dropout", "packed", "batched", "masked", "mimo"], style: str = "imagenet", num_estimators: int = 1, dropout_rate: float = 0.0, @@ -49,9 +47,7 @@ def __init__( gamma: int = 1, rho: float = 1.0, batch_repeat: int = 1, - ood_criterion: Literal[ - "msp", "logit", "energy", "entropy", "mi", "vr" - ] = "msp", + ood_criterion: Literal["msp", "logit", "energy", "entropy", "mi", "vr"] = "msp", log_plots: bool = False, save_in_csv: bool = False, calibration_set: Literal["val", "test"] = "val", diff --git a/torch_uncertainty/datamodules/abstract.py b/torch_uncertainty/datamodules/abstract.py index 75d303e4..0dc88033 100644 --- a/torch_uncertainty/datamodules/abstract.py +++ b/torch_uncertainty/datamodules/abstract.py @@ -89,9 +89,7 @@ def test_dataloader(self) -> list[DataLoader]: """ return [self._data_loader(self.test)] - def _data_loader( - self, dataset: Dataset, shuffle: bool = False - ) -> DataLoader: + def _data_loader(self, dataset: Dataset, shuffle: bool = False) -> DataLoader: """Create a dataloader for a given dataset. Args: @@ -121,9 +119,7 @@ def _get_train_data(self) -> ArrayLike: def _get_train_targets(self) -> ArrayLike: raise NotImplementedError - def make_cross_val_splits( - self, n_splits: int = 10, train_over: int = 4 - ) -> list: + def make_cross_val_splits(self, n_splits: int = 10, train_over: int = 4) -> list: self.setup("fit") skf = StratifiedKFold(n_splits) cv_dm = [] diff --git a/torch_uncertainty/datamodules/classification/cifar100.py b/torch_uncertainty/datamodules/classification/cifar100.py index 5b6a5487..2717cc12 100644 --- a/torch_uncertainty/datamodules/classification/cifar100.py +++ b/torch_uncertainty/datamodules/classification/cifar100.py @@ -87,9 +87,7 @@ def __init__( self.shift_severity = shift_severity - if (cutout is not None) + randaugment + int( - auto_augment is not None - ) > 1: + if (cutout is not None) + randaugment + int(auto_augment is not None) > 1: raise ValueError( "Only one data augmentation can be chosen at a time. Raise a " "GitHub issue if needed." diff --git a/torch_uncertainty/datamodules/classification/imagenet.py b/torch_uncertainty/datamodules/classification/imagenet.py index f56bc34b..d436fd16 100644 --- a/torch_uncertainty/datamodules/classification/imagenet.py +++ b/torch_uncertainty/datamodules/classification/imagenet.py @@ -139,9 +139,7 @@ def __init__( if basic_augment: basic_transform = T.Compose( [ - T.RandomResizedCrop( - train_size, interpolation=self.interpolation - ), + T.RandomResizedCrop(train_size, interpolation=self.interpolation), T.RandomHorizontalFlip(), ] ) @@ -232,9 +230,7 @@ def prepare_data(self) -> None: # coverage: ignore def setup(self, stage: Literal["fit", "test"] | None = None) -> None: if stage == "fit" or stage is None: if self.test_alt is not None: - raise ValueError( - "The test_alt argument is not supported for training." - ) + raise ValueError("The test_alt argument is not supported for training.") full = self.dataset( self.root, split="train", diff --git a/torch_uncertainty/datamodules/classification/mnist.py b/torch_uncertainty/datamodules/classification/mnist.py index 604ac8c2..a49fc168 100644 --- a/torch_uncertainty/datamodules/classification/mnist.py +++ b/torch_uncertainty/datamodules/classification/mnist.py @@ -78,16 +78,11 @@ def __init__( elif ood_ds == "notMNIST": self.ood_dataset = NotMNIST else: - raise ValueError( - f"`ood_ds` should be in {self.ood_datasets}. Got {ood_ds}." - ) + raise ValueError(f"`ood_ds` should be in {self.ood_datasets}. Got {ood_ds}.") self.shift_dataset = MNISTC self.shift_severity = 1 - if basic_augment: - basic_transform = T.RandomCrop(28, padding=4) - else: - basic_transform = nn.Identity() + basic_transform = T.RandomCrop(28, padding=4) if basic_augment else nn.Identity() main_transform = Cutout(cutout) if cutout else nn.Identity() diff --git a/torch_uncertainty/datamodules/classification/tiny_imagenet.py b/torch_uncertainty/datamodules/classification/tiny_imagenet.py index 542a10fd..3c3c7ec4 100644 --- a/torch_uncertainty/datamodules/classification/tiny_imagenet.py +++ b/torch_uncertainty/datamodules/classification/tiny_imagenet.py @@ -68,9 +68,7 @@ def __init__( elif ood_ds == "textures": self.ood_dataset = DTD else: - raise ValueError( - f"OOD dataset {ood_ds} not supported for TinyImageNet." - ) + raise ValueError(f"OOD dataset {ood_ds} not supported for TinyImageNet.") self.shift_dataset = TinyImageNetC if basic_augment: basic_transform = T.Compose( diff --git a/torch_uncertainty/datamodules/depth/base.py b/torch_uncertainty/datamodules/depth/base.py index 2e3f7865..59804c0a 100644 --- a/torch_uncertainty/datamodules/depth/base.py +++ b/torch_uncertainty/datamodules/depth/base.py @@ -84,9 +84,7 @@ def __init__( }, scale=True, ), - v2.Normalize( - mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] - ), + v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ] ) self.test_transform = v2.Compose( @@ -99,9 +97,7 @@ def __init__( }, scale=True, ), - v2.Normalize( - mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] - ), + v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ] ) @@ -112,9 +108,7 @@ def prepare_data(self) -> None: # coverage: ignore max_depth=self.max_depth, download=True, ) - self.dataset( - root=self.root, split="val", max_depth=self.max_depth, download=True - ) + self.dataset(root=self.root, split="val", max_depth=self.max_depth, download=True) def setup(self, stage: str | None = None) -> None: if stage == "fit" or stage is None: diff --git a/torch_uncertainty/datamodules/segmentation/camvid.py b/torch_uncertainty/datamodules/segmentation/camvid.py index bd4b06b7..08fd432a 100644 --- a/torch_uncertainty/datamodules/segmentation/camvid.py +++ b/torch_uncertainty/datamodules/segmentation/camvid.py @@ -116,9 +116,7 @@ def __init__( pad_if_needed=True, fill={tv_tensors.Image: 0, tv_tensors.Mask: 255}, ), - v2.ColorJitter( - brightness=0.5, contrast=0.5, saturation=0.5 - ), + v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5), v2.RandomHorizontalFlip(), ] ) diff --git a/torch_uncertainty/datamodules/segmentation/cityscapes.py b/torch_uncertainty/datamodules/segmentation/cityscapes.py index 75e89515..a6005893 100644 --- a/torch_uncertainty/datamodules/segmentation/cityscapes.py +++ b/torch_uncertainty/datamodules/segmentation/cityscapes.py @@ -129,9 +129,7 @@ def __init__( pad_if_needed=True, fill={tv_tensors.Image: 0, tv_tensors.Mask: 255}, ), - v2.ColorJitter( - brightness=0.5, contrast=0.5, saturation=0.5 - ), + v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5), v2.RandomHorizontalFlip(), ] ) diff --git a/torch_uncertainty/datamodules/segmentation/muad.py b/torch_uncertainty/datamodules/segmentation/muad.py index 5e967000..00f39251 100644 --- a/torch_uncertainty/datamodules/segmentation/muad.py +++ b/torch_uncertainty/datamodules/segmentation/muad.py @@ -151,12 +151,8 @@ def __init__( ) def prepare_data(self) -> None: # coverage: ignore - self.dataset( - root=self.root, split="train", target_type="semantic", download=True - ) - self.dataset( - root=self.root, split="val", target_type="semantic", download=True - ) + self.dataset(root=self.root, split="train", target_type="semantic", download=True) + self.dataset(root=self.root, split="val", target_type="semantic", download=True) def setup(self, stage: str | None = None) -> None: if stage == "fit" or stage is None: diff --git a/torch_uncertainty/datamodules/uci_regression.py b/torch_uncertainty/datamodules/uci_regression.py index 4d1b304e..6dae8899 100644 --- a/torch_uncertainty/datamodules/uci_regression.py +++ b/torch_uncertainty/datamodules/uci_regression.py @@ -55,9 +55,7 @@ def __init__( persistent_workers=persistent_workers, ) - self.dataset = partial( - UCIRegression, dataset_name=dataset_name, seed=split_seed - ) + self.dataset = partial(UCIRegression, dataset_name=dataset_name, seed=split_seed) self.input_shape = input_shape self.gen = Generator().manual_seed(split_seed) diff --git a/torch_uncertainty/datasets/classification/cifar/cifar_c.py b/torch_uncertainty/datasets/classification/cifar/cifar_c.py index 225410f8..ebafe7f8 100644 --- a/torch_uncertainty/datasets/classification/cifar/cifar_c.py +++ b/torch_uncertainty/datasets/classification/cifar/cifar_c.py @@ -97,9 +97,7 @@ def __init__( self.download() if not self._check_integrity(): - raise RuntimeError( - "Dataset not found. You can use download=True to download it." - ) + raise RuntimeError("Dataset not found. You can use download=True to download it.") super().__init__( root=self.root / self.base_folder, @@ -107,20 +105,15 @@ def __init__( target_transform=target_transform, ) if subset not in ["all", *self.cifarc_subsets]: - raise ValueError( - f"The subset '{subset}' does not exist in CIFAR-C." - ) + raise ValueError(f"The subset '{subset}' does not exist in CIFAR-C.") self.subset = subset self.shift_severity = shift_severity if shift_severity not in list(range(1, 6)): raise ValueError( - "Corruptions shift_severity should be chosen between 1 and 5 " - "included." + "Corruptions shift_severity should be chosen between 1 and 5 " "included." ) - samples, labels = self.make_dataset( - self.root, self.subset, self.shift_severity - ) + samples, labels = self.make_dataset(self.root, self.subset, self.shift_severity) self.samples = samples self.labels = labels.astype(np.int64) @@ -200,9 +193,7 @@ def download(self) -> None: if self._check_integrity(): logging.info("Files already downloaded and verified") return - download_and_extract_archive( - self.url, self.root, filename=self.filename, md5=self.tgz_md5 - ) + download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5) class CIFAR100C(CIFAR10C): diff --git a/torch_uncertainty/datasets/classification/cifar/cifar_h.py b/torch_uncertainty/datasets/classification/cifar/cifar_h.py index 168f8571..f8354a0b 100644 --- a/torch_uncertainty/datasets/classification/cifar/cifar_h.py +++ b/torch_uncertainty/datasets/classification/cifar/cifar_h.py @@ -27,10 +27,7 @@ class CIFAR10H(CIFAR10): """ h_test_list = ["cifar-10h-probs.npy", "7b41f73eee90fdefc73bfc820ab29ba8"] - h_url = ( - "https://github.com/jcpeterson/cifar-10h/raw/master/data/" - "cifar10h-probs.npy" - ) + h_url = "https://github.com/jcpeterson/cifar-10h/raw/master/data/" "cifar10h-probs.npy" def __init__( self, @@ -42,10 +39,7 @@ def __init__( ) -> None: if train: raise ValueError("CIFAR10H does not support training data.") - print( - "WARNING: CIFAR10H cannot be used with Classification routines " - "for now." - ) + print("WARNING: CIFAR10H cannot be used with Classification routines " "for now.") super().__init__( Path(root), train=False, @@ -59,13 +53,10 @@ def __init__( if not self._check_specific_integrity(): raise RuntimeError( - "Dataset not found or corrupted. You can use download=True to " - "download it." + "Dataset not found or corrupted. You can use download=True to " "download it." ) - self.targets = list( - torch.as_tensor(np.load(self.root / self.h_test_list[0])) - ) + self.targets = list(torch.as_tensor(np.load(self.root / self.h_test_list[0]))) def _check_specific_integrity(self) -> bool: filename, md5 = self.h_test_list diff --git a/torch_uncertainty/datasets/classification/cifar/cifar_n.py b/torch_uncertainty/datasets/classification/cifar/cifar_n.py index 069a081a..6f6f8c95 100644 --- a/torch_uncertainty/datasets/classification/cifar/cifar_n.py +++ b/torch_uncertainty/datasets/classification/cifar/cifar_n.py @@ -61,8 +61,7 @@ def __init__( if not self._check_specific_integrity(): raise RuntimeError( - "Dataset not found or corrupted. You can use download=True to " - "download it." + "Dataset not found or corrupted. You can use download=True to " "download it." ) self.targets = list(torch.load(self.root / self.filename)[file_arg]) @@ -113,8 +112,7 @@ def __init__( if not self._check_specific_integrity(): raise RuntimeError( - "Dataset not found or corrupted. You can use download=True to " - "download it." + "Dataset not found or corrupted. You can use download=True to " "download it." ) self.targets = list(torch.load(self.root / self.filename)[file_arg]) diff --git a/torch_uncertainty/datasets/classification/imagenet/base.py b/torch_uncertainty/datasets/classification/imagenet/base.py index c5229df7..722955cf 100644 --- a/torch_uncertainty/datasets/classification/imagenet/base.py +++ b/torch_uncertainty/datasets/classification/imagenet/base.py @@ -34,9 +34,7 @@ class ImageNetVariation(ImageFolder): root_appendix: str wnid_to_idx_url = "https://raw.githubusercontent.com/torch-uncertainty/dataset-metadata/main/classification/imagenet/classes.json" - wnid_to_idx_md5 = ( - "1bcf467b49f735dbeb745249eae6b133" # avoid replacement attack - ) + wnid_to_idx_md5 = "1bcf467b49f735dbeb745249eae6b133" # avoid replacement attack def __init__( self, @@ -54,8 +52,7 @@ def __init__( if not self._check_integrity(): raise RuntimeError( - "Dataset not found or corrupted. " - "You can use download=True to download it." + "Dataset not found or corrupted. " "You can use download=True to download it." ) super().__init__( @@ -97,13 +94,9 @@ def download(self) -> None: md5=self.tgz_md5, ) elif isinstance(self.filename, list): # ImageNet-C - for url, filename, md5 in zip( - self.url, self.filename, self.tgz_md5, strict=True - ): + for url, filename, md5 in zip(self.url, self.filename, self.tgz_md5, strict=True): # Check that this particular file is not already downloaded - if not check_integrity( - self.root / self.root_appendix / Path(filename), md5 - ): + if not check_integrity(self.root / self.root_appendix / Path(filename), md5): download_and_extract_archive( url, self.root, diff --git a/torch_uncertainty/datasets/classification/imagenet/tiny_imagenet_c.py b/torch_uncertainty/datasets/classification/imagenet/tiny_imagenet_c.py index 74be5f68..82886a2d 100644 --- a/torch_uncertainty/datasets/classification/imagenet/tiny_imagenet_c.py +++ b/torch_uncertainty/datasets/classification/imagenet/tiny_imagenet_c.py @@ -77,17 +77,13 @@ def __init__( self.download() if not self._check_integrity(): - raise RuntimeError( - "Dataset not found. You can use download=True to download it." - ) + raise RuntimeError("Dataset not found. You can use download=True to download it.") super().__init__( root=self.root / self.base_folder / "brightness/1/", transform=transform, ) if subset not in ["all", *self.subsets]: - raise ValueError( - f"The subset '{subset}' does not exist in TinyImageNet-C." - ) + raise ValueError(f"The subset '{subset}' does not exist in TinyImageNet-C.") self.subset = subset self.shift_severity = shift_severity @@ -142,9 +138,7 @@ def _make_c_dataset(self, subset: str, shift_severity: int) -> None: def _check_integrity(self) -> bool: """Check the integrity of the dataset.""" - for filename, md5 in list( - zip(self.filename, self.tgz_md5, strict=True) - ): + for filename, md5 in list(zip(self.filename, self.tgz_md5, strict=True)): if "extra" in filename: fpath = self.root / "Tiny-ImageNet-C" / filename else: @@ -158,9 +152,7 @@ def download(self) -> None: if self._check_integrity(): logging.info("Files already downloaded and verified") return - for filename, md5 in list( - zip(self.filename, self.tgz_md5, strict=True) - ): + for filename, md5 in list(zip(self.filename, self.tgz_md5, strict=True)): if "extra" in filename: download_and_extract_archive( self.url + filename, diff --git a/torch_uncertainty/datasets/classification/mnist_c.py b/torch_uncertainty/datasets/classification/mnist_c.py index ae1bf563..6d8086cb 100644 --- a/torch_uncertainty/datasets/classification/mnist_c.py +++ b/torch_uncertainty/datasets/classification/mnist_c.py @@ -79,9 +79,7 @@ def __init__( self.download() if not self._check_integrity(): - raise RuntimeError( - "Dataset not found. You can use download=True to download it." - ) + raise RuntimeError("Dataset not found. You can use download=True to download it.") super().__init__( root=self.root / self.base_folder, @@ -89,15 +87,11 @@ def __init__( target_transform=target_transform, ) if subset not in ["all", *self.mnistc_subsets]: - raise ValueError( - f"The subset '{subset}' does not exist in MNIST-C." - ) + raise ValueError(f"The subset '{subset}' does not exist in MNIST-C.") self.subset = subset if split not in ["train", "test"]: - raise ValueError( - f"The split '{split}' should be either 'train' or 'test'." - ) + raise ValueError(f"The split '{split}' should be either 'train' or 'test'.") self.split = split samples, labels = self.make_dataset(self.root, self.subset, self.split) @@ -171,6 +165,4 @@ def download(self) -> None: if self._check_integrity(): logging.info("Files already downloaded and verified") return - download_and_extract_archive( - self.url, self.root, filename=self.filename, md5=self.zip_md5 - ) + download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.zip_md5) diff --git a/torch_uncertainty/datasets/classification/not_mnist.py b/torch_uncertainty/datasets/classification/not_mnist.py index 8fa77b4c..1590d6be 100644 --- a/torch_uncertainty/datasets/classification/not_mnist.py +++ b/torch_uncertainty/datasets/classification/not_mnist.py @@ -49,9 +49,7 @@ def __init__( self.root = Path(root) if subset not in self.subsets: - raise ValueError( - f"The subset '{subset}' does not exist for notMNIST." - ) + raise ValueError(f"The subset '{subset}' does not exist for notMNIST.") ind = self.subsets.index(subset) self.url = self.url_base + "/" + self.filenames[ind] self.filename = self.filenames[ind] @@ -62,8 +60,7 @@ def __init__( if not self._check_integrity(): raise RuntimeError( - "Dataset not found or corrupted. You can use download=True to " - "download it." + "Dataset not found or corrupted. You can use download=True to " "download it." ) super().__init__( diff --git a/torch_uncertainty/datasets/classification/uci/bank_marketing.py b/torch_uncertainty/datasets/classification/uci/bank_marketing.py index 9f693d4f..7542ebeb 100644 --- a/torch_uncertainty/datasets/classification/uci/bank_marketing.py +++ b/torch_uncertainty/datasets/classification/uci/bank_marketing.py @@ -75,9 +75,7 @@ def download(self) -> None: filename="bank+marketing.zip", md5=self.md5_zip, ) - extract_archive( - self.root / "bank-additional.zip", self.root / "bank-marketing" - ) + extract_archive(self.root / "bank-additional.zip", self.root / "bank-marketing") def _make_dataset(self) -> None: """Create dataset from extracted files.""" diff --git a/torch_uncertainty/datasets/classification/uci/dota2_games.py b/torch_uncertainty/datasets/classification/uci/dota2_games.py index 1236995e..4f023620 100644 --- a/torch_uncertainty/datasets/classification/uci/dota2_games.py +++ b/torch_uncertainty/datasets/classification/uci/dota2_games.py @@ -34,9 +34,7 @@ class DOTA2Games(UCIClassificationDataset): """ md5_zip = "896623c082b062f56b9c49c6c1fc0bf7" - url = ( - "https://archive.ics.uci.edu/static/public/367/dota2+games+results.zip" - ) + url = "https://archive.ics.uci.edu/static/public/367/dota2+games+results.zip" dataset_name = "dota2+games+results" filename = "dota2Train.csv" num_features = 116 @@ -79,11 +77,7 @@ def download(self) -> None: def _make_dataset(self) -> None: """Create dataset from extracted files.""" - path = ( - self.root - / "dota2_games" - / ("dota2Train.csv" if self.train else "dota2Test.csv") - ) + path = self.root / "dota2_games" / ("dota2Train.csv" if self.train else "dota2Test.csv") data = pd.read_csv(path, sep=",", header=None) data[0] = np.where(data[0] == 1, 1, 0) diff --git a/torch_uncertainty/datasets/classification/uci/htru2.py b/torch_uncertainty/datasets/classification/uci/htru2.py index b161c720..5284f17e 100644 --- a/torch_uncertainty/datasets/classification/uci/htru2.py +++ b/torch_uncertainty/datasets/classification/uci/htru2.py @@ -60,11 +60,7 @@ def __init__( def _make_dataset(self) -> None: """Create dataset from extracted files.""" - data = pd.read_csv( - self.root / self.dataset_name / self.filename, sep=",", header=None - ) + data = pd.read_csv(self.root / self.dataset_name / self.filename, sep=",", header=None) self.targets = torch.as_tensor(data[8].values, dtype=torch.long) - self.data = torch.as_tensor( - data.drop(columns=[8]).values, dtype=torch.float32 - ) + self.data = torch.as_tensor(data.drop(columns=[8]).values, dtype=torch.float32) self.num_features = self.data.shape[1] diff --git a/torch_uncertainty/datasets/classification/uci/spam_base.py b/torch_uncertainty/datasets/classification/uci/spam_base.py index 159b5416..0d3fd175 100644 --- a/torch_uncertainty/datasets/classification/uci/spam_base.py +++ b/torch_uncertainty/datasets/classification/uci/spam_base.py @@ -60,11 +60,7 @@ def __init__( def _make_dataset(self) -> None: """Create dataset from extracted files.""" - data = pd.read_csv( - self.root / self.dataset_name / self.filename, sep=",", header=None - ) + data = pd.read_csv(self.root / self.dataset_name / self.filename, sep=",", header=None) self.targets = torch.as_tensor(data[57].values, dtype=torch.long) - self.data = torch.as_tensor( - data.drop(columns=[57]).values, dtype=torch.float32 - ) + self.data = torch.as_tensor(data.drop(columns=[57]).values, dtype=torch.float32) self.num_features = self.data.shape[1] diff --git a/torch_uncertainty/datasets/classification/uci/uci_classification.py b/torch_uncertainty/datasets/classification/uci/uci_classification.py index c976ad9f..9f6f60e0 100644 --- a/torch_uncertainty/datasets/classification/uci/uci_classification.py +++ b/torch_uncertainty/datasets/classification/uci/uci_classification.py @@ -82,9 +82,7 @@ def __init__( self.data = self.data[self.split_idx] self.targets = self.targets[self.split_idx] if not binary: - self.targets = torch.nn.functional.one_hot( - self.targets, num_classes=2 - ) + self.targets = torch.nn.functional.one_hot(self.targets, num_classes=2) def __len__(self) -> int: """Get the length of the dataset.""" diff --git a/torch_uncertainty/datasets/corrupted.py b/torch_uncertainty/datasets/corrupted.py index 8f021942..dbdff04a 100644 --- a/torch_uncertainty/datasets/corrupted.py +++ b/torch_uncertainty/datasets/corrupted.py @@ -21,9 +21,7 @@ def __init__( super().__init__() self.core_dataset = core_dataset if shift_severity <= 0: - raise ValueError( - f"Severity must be greater than 0. Got {shift_severity}." - ) + raise ValueError(f"Severity must be greater than 0. Got {shift_severity}.") self.shift_severity = shift_severity self.core_length = len(core_dataset) self.on_the_fly = on_the_fly @@ -50,9 +48,7 @@ def prepare_data(self): for corruption in tqdm(corruption_transforms): corruption_name = corruption.__name__.lower() (self.root / corruption_name).mkdir(parents=True) - self.save_corruption( - self.root / corruption_name, corruption(self.shift_severity) - ) + self.save_corruption(self.root / corruption_name, corruption(self.shift_severity)) def save_corruption(self, root: Path, corruption: nn.Module) -> None: for i in range(self.core_length): diff --git a/torch_uncertainty/datasets/fractals.py b/torch_uncertainty/datasets/fractals.py index d46358b5..329a2086 100644 --- a/torch_uncertainty/datasets/fractals.py +++ b/torch_uncertainty/datasets/fractals.py @@ -40,13 +40,10 @@ def __init__( if not self._check_integrity(): raise RuntimeError( - "Dataset not found or corrupted. You can use download=True to " - "download it." + "Dataset not found or corrupted. You can use download=True to " "download it." ) - super().__init__( - self.root, transform=transform, target_transform=target_transform - ) + super().__init__(self.root, transform=transform, target_transform=target_transform) def _check_integrity(self) -> bool: fpath = self.root / self.filename diff --git a/torch_uncertainty/datasets/frost.py b/torch_uncertainty/datasets/frost.py index 6e391b93..dbde6d89 100644 --- a/torch_uncertainty/datasets/frost.py +++ b/torch_uncertainty/datasets/frost.py @@ -44,8 +44,7 @@ def __init__( if not self._check_integrity(): raise RuntimeError( - "Dataset not found or corrupted. You can use download=True to " - "download it." + "Dataset not found or corrupted. You can use download=True to " "download it." ) super().__init__( diff --git a/torch_uncertainty/datasets/kitti.py b/torch_uncertainty/datasets/kitti.py index 5256f385..69dfe1ef 100644 --- a/torch_uncertainty/datasets/kitti.py +++ b/torch_uncertainty/datasets/kitti.py @@ -58,9 +58,7 @@ def __init__( self.max_depth = max_depth if split not in ["train", "val"]: - raise ValueError( - f"split must be one of ['train', 'val']. Got {split}." - ) + raise ValueError(f"split must be one of ['train', 'val']. Got {split}.") self.split = split @@ -86,13 +84,10 @@ def check_split_integrity(self, folder: str) -> bool: split_path = self.root / self.split return ( split_path.is_dir() - and len(list((split_path / folder).glob("*.png"))) - == self._num_samples[self.split] + and len(list((split_path / folder).glob("*.png"))) == self._num_samples[self.split] ) - def __getitem__( - self, index: int - ) -> tuple[tv_tensors.Image, tv_tensors.Mask]: + def __getitem__(self, index: int) -> tuple[tv_tensors.Image, tv_tensors.Mask]: """Get the sample at the given index. Args: @@ -105,9 +100,7 @@ def __getitem__( target = tv_tensors.Mask( F.pil_to_tensor(Image.open(self.targets[index])).squeeze(0) / 256.0 ) - target[(target <= self.min_depth) | (target > self.max_depth)] = float( - "nan" - ) + target[(target <= self.min_depth) | (target > self.max_depth)] = float("nan") if self.transforms is not None: image, target = self.transforms(image, target) @@ -119,12 +112,8 @@ def __len__(self) -> int: return self._num_samples[self.split] def _make_dataset(self) -> None: - self.samples = sorted( - (self.root / self.split / "leftImg8bit").glob("*.png") - ) - self.targets = sorted( - (self.root / self.split / "leftDepth").glob("*.png") - ) + self.samples = sorted((self.root / self.split / "leftImg8bit").glob("*.png")) + self.targets = sorted((self.root / self.split / "leftDepth").glob("*.png")) def _download_depth(self) -> None: """Download and extract the depth annotation dataset.""" @@ -147,9 +136,7 @@ def _download_depth(self) -> None: logging.info("Train files...") for file in tqdm(depth_files): exp_code = file.parents[3].name.split("_") - filecode = "_".join( - [exp_code[0], exp_code[1], exp_code[2], exp_code[4], file.name] - ) + filecode = "_".join([exp_code[0], exp_code[1], exp_code[2], exp_code[4], file.name]) shutil.copy(file, self.root / "train" / "leftDepth" / filecode) if (self.root / "val" / "leftDepth").exists(): @@ -161,9 +148,7 @@ def _download_depth(self) -> None: logging.info("Validation files...") for file in tqdm(depth_files): exp_code = file.parents[3].name.split("_") - filecode = "_".join( - [exp_code[0], exp_code[1], exp_code[2], exp_code[4], file.name] - ) + filecode = "_".join([exp_code[0], exp_code[1], exp_code[2], exp_code[4], file.name]) shutil.copy(file, self.root / "val" / "leftDepth" / filecode) shutil.rmtree(self.root / "tmp") @@ -190,16 +175,12 @@ def _download_raw(self, remove_unused: bool) -> None: logging.info("Re-structuring the raw data...") - samples_to_keep = list( - (self.root / "train" / "leftDepth").glob("*.png") - ) + samples_to_keep = list((self.root / "train" / "leftDepth").glob("*.png")) if (self.root / "train" / "leftImg8bit").exists(): shutil.rmtree(self.root / "train" / "leftImg8bit") - (self.root / "train" / "leftImg8bit").mkdir( - parents=True, exist_ok=False - ) + (self.root / "train" / "leftImg8bit").mkdir(parents=True, exist_ok=False) logging.info("Train files...") for sample in tqdm(samples_to_keep): @@ -216,17 +197,9 @@ def _download_raw(self, remove_unused: bool) -> None: ] ) raw_path = ( - self.root - / "raw" - / first_level - / second_level - / "image_02" - / "data" - / filecode[4] - ) - shutil.copy( - raw_path, self.root / "train" / "leftImg8bit" / sample.name + self.root / "raw" / first_level / second_level / "image_02" / "data" / filecode[4] ) + shutil.copy(raw_path, self.root / "train" / "leftImg8bit" / sample.name) samples_to_keep = list((self.root / "val" / "leftDepth").glob("*.png")) @@ -250,17 +223,9 @@ def _download_raw(self, remove_unused: bool) -> None: ] ) raw_path = ( - self.root - / "raw" - / first_level - / second_level - / "image_02" - / "data" - / filecode[4] - ) - shutil.copy( - raw_path, self.root / "val" / "leftImg8bit" / sample.name + self.root / "raw" / first_level / second_level / "image_02" / "data" / filecode[4] ) + shutil.copy(raw_path, self.root / "val" / "leftImg8bit" / sample.name) if remove_unused: shutil.rmtree(self.root / "raw") diff --git a/torch_uncertainty/datasets/muad.py b/torch_uncertainty/datasets/muad.py index 5d82085b..c21f9f67 100644 --- a/torch_uncertainty/datasets/muad.py +++ b/torch_uncertainty/datasets/muad.py @@ -98,9 +98,7 @@ def __init__( self.max_depth = max_depth if split not in ["train", "val"]: - raise ValueError( - f"split must be one of ['train', 'val']. Got {split}." - ) + raise ValueError(f"split must be one of ['train', 'val']. Got {split}.") self.split = split self.target_type = target_type @@ -112,10 +110,7 @@ def __init__( f"MUAD {split} split not found or incomplete. Set download=True to download it." ) - if ( - not self.check_split_integrity("leftLabel") - and target_type == "semantic" - ): + if not self.check_split_integrity("leftLabel") and target_type == "semantic": if download: self._download(split=split) else: @@ -123,10 +118,7 @@ def __init__( f"MUAD {split} split not found or incomplete. Set download=True to download it." ) - if ( - not self.check_split_integrity("leftDepth") - and target_type == "depth" - ): + if not self.check_split_integrity("leftDepth") and target_type == "depth": if download: self._download(split=f"{split}_depth") # Depth target for train are in a different folder @@ -154,11 +146,7 @@ def __init__( with (self.root / "classes.json").open() as file: self.classes = json.load(file) - train_id_to_color = [ - c["object_id"] - for c in self.classes - if c["train_id"] not in [-1, 255] - ] + train_id_to_color = [c["object_id"] for c in self.classes if c["train_id"] not in [-1, 255]] train_id_to_color.append([0, 0, 0]) self.train_id_to_color = np.array(train_id_to_color) @@ -178,11 +166,9 @@ def encode_target(self, target: Image.Image) -> Image.Image: out = torch.zeros_like(target[..., :1]) # convert target color to index for muad_class in self.classes: - out[ - ( - target == torch.tensor(muad_class["id"], dtype=target.dtype) - ).all(dim=-1) - ] = muad_class["train_id"] + out[(target == torch.tensor(muad_class["id"], dtype=target.dtype)).all(dim=-1)] = ( + muad_class["train_id"] + ) return F.to_pil_image(rearrange(out, "h w c -> c h w")) @@ -190,9 +176,7 @@ def decode_target(self, target: Image.Image) -> np.ndarray: target[target == 255] = 19 return self.train_id_to_color[target] - def __getitem__( - self, index: int - ) -> tuple[tv_tensors.Image, tv_tensors.Mask]: + def __getitem__(self, index: int) -> tuple[tv_tensors.Image, tv_tensors.Mask]: """Get the sample at the given index. Args: @@ -204,9 +188,7 @@ def __getitem__( """ image = tv_tensors.Image(Image.open(self.samples[index]).convert("RGB")) if self.target_type == "semantic": - target = tv_tensors.Mask( - self.encode_target(Image.open(self.targets[index])) - ) + target = tv_tensors.Mask(self.encode_target(Image.open(self.targets[index]))) else: os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" target = Image.fromarray( @@ -219,9 +201,7 @@ def __getitem__( # tv_tensor for depth maps (e.g. tv_tensors.DepthMap) target = np.asarray(target, np.float32) target = tv_tensors.Mask(400 * (1 - target)) # convert to meters - target[(target <= self.min_depth) | (target > self.max_depth)] = ( - float("nan") - ) + target[(target <= self.min_depth) | (target > self.max_depth)] = float("nan") if self.transforms is not None: image, target = self.transforms(image, target) @@ -232,8 +212,7 @@ def check_split_integrity(self, folder: str) -> bool: split_path = self.root / self.split return ( split_path.is_dir() - and len(list((split_path / folder).glob("**/*"))) - == self._num_samples[self.split] + and len(list((split_path / folder).glob("**/*"))) == self._num_samples[self.split] ) def __len__(self) -> int: @@ -248,8 +227,7 @@ def _make_dataset(self, path: Path) -> None: """ if "depth" in path.name: raise NotImplementedError( - "Depth mode is not implemented yet. Raise an issue " - "if you need it." + "Depth mode is not implemented yet. Raise an issue " "if you need it." ) self.samples = sorted((path / "leftImg8bit/").glob("**/*")) if self.target_type == "semantic": @@ -264,9 +242,7 @@ def _make_dataset(self, path: Path) -> None: def _download(self, split: str) -> None: """Download and extract the chosen split of the dataset.""" split_url = self.base_url + split + ".zip" - download_and_extract_archive( - split_url, self.root, md5=self.zip_md5[split] - ) + download_and_extract_archive(split_url, self.root, md5=self.zip_md5[split]) @property def color_palette(self) -> np.ndarray: diff --git a/torch_uncertainty/datasets/nyu.py b/torch_uncertainty/datasets/nyu.py index 15556bfa..1d02ecbb 100644 --- a/torch_uncertainty/datasets/nyu.py +++ b/torch_uncertainty/datasets/nyu.py @@ -79,9 +79,7 @@ def __init__( self.max_depth = max_depth if split not in ["train", "val"]: - raise ValueError( - f"split must be one of ['train', 'val']. Got {split}." - ) + raise ValueError(f"split must be one of ['train', 'val']. Got {split}.") self.split = split if not self._check_integrity(): @@ -112,9 +110,7 @@ def __getitem__(self, index: int): ) target = np.asarray(target, np.uint16) target = tv_tensors.Mask(target / 1e4) # convert to meters - target[(target <= self.min_depth) | (target > self.max_depth)] = float( - "nan" - ) + target[(target <= self.min_depth) | (target > self.max_depth)] = float("nan") if self.transforms is not None: image, target = self.transforms(image, target) return image, target @@ -145,9 +141,7 @@ def _download(self): md5=self.rgb_md5[self.split], ) if not check_integrity(self.root / "depth.mat", self.depth_md5): - download_url( - NYUv2.depth_url, self.root, "depth.mat", self.depth_md5 - ) + download_url(NYUv2.depth_url, self.root, "depth.mat", self.depth_md5) self._create_depth_files() def _create_depth_files(self): @@ -162,6 +156,4 @@ def _create_depth_files(self): img_id = i + 1 if img_id in ids: img = (depths[i] * 1e4).astype(np.uint16).T - Image.fromarray(img).save( - path / "depth" / f"nyu_depth_{str(img_id).zfill(4)}.png" - ) + Image.fromarray(img).save(path / "depth" / f"nyu_depth_{str(img_id).zfill(4)}.png") diff --git a/torch_uncertainty/datasets/regression/toy.py b/torch_uncertainty/datasets/regression/toy.py index d5c3cdfb..dca8bc84 100644 --- a/torch_uncertainty/datasets/regression/toy.py +++ b/torch_uncertainty/datasets/regression/toy.py @@ -26,8 +26,6 @@ def __init__( ) -> None: noise = (noise_mean, noise_std) - samples = torch.linspace( - lower_bound, upper_bound, num_samples - ).unsqueeze(-1) + samples = torch.linspace(lower_bound, upper_bound, num_samples).unsqueeze(-1) targets = samples**3 + torch.normal(*noise, size=samples.size()) super().__init__(samples, targets.squeeze(-1)) diff --git a/torch_uncertainty/datasets/regression/uci_regression.py b/torch_uncertainty/datasets/regression/uci_regression.py index 3a23ae8d..560d2136 100644 --- a/torch_uncertainty/datasets/regression/uci_regression.py +++ b/torch_uncertainty/datasets/regression/uci_regression.py @@ -117,23 +117,17 @@ class UCIRegression(Dataset): "4e6727f462779e2d396e8f7d2ddb79a3", ] urls = [ - "https://archive.ics.uci.edu/ml/machine-learning-databases/housing/" - "housing.data", - "https://archive.ics.uci.edu/static/public/165/concrete+compressive+" - "strength.zip", + "https://archive.ics.uci.edu/ml/machine-learning-databases/housing/" "housing.data", + "https://archive.ics.uci.edu/static/public/165/concrete+compressive+" "strength.zip", "https://archive.ics.uci.edu/static/public/242/energy+efficiency.zip", - "https://archive.ics.uci.edu/static/public/374/appliances+energy+" - "prediction.zip", + "https://archive.ics.uci.edu/static/public/374/appliances+energy+" "prediction.zip", "https://www.openml.org/data/get_csv/3626/dataset_2175_kin8nm.arff", - "https://raw.githubusercontent.com/luishpinto/cm-naval-propulsion-" - "plant/master/data.csv", - "https://archive.ics.uci.edu/static/public/294/combined+cycle+power+" - "plant.zip", + "https://raw.githubusercontent.com/luishpinto/cm-naval-propulsion-" "plant/master/data.csv", + "https://archive.ics.uci.edu/static/public/294/combined+cycle+power+" "plant.zip", "https://archive.ics.uci.edu/static/public/265/physicochemical+" "properties+of+protein+tertiary+structure.zip", "https://archive.ics.uci.edu/static/public/186/wine+quality.zip", - "https://archive.ics.uci.edu/static/public/243/yacht+" - "hydrodynamics.zip", + "https://archive.ics.uci.edu/static/public/243/yacht+" "hydrodynamics.zip", ] def __init__( @@ -197,10 +191,7 @@ def download(self) -> None: logging.info("Files already downloaded and verified") return if self.url is None: - raise ValueError( - f"The dataset {self.dataset_name} is not available for " - "download." - ) + raise ValueError(f"The dataset {self.dataset_name} is not available for " "download.") download_root = self.root / self.root_appendix / self.dataset_name if self.dataset_name == "boston": download_url( @@ -255,9 +246,7 @@ def _make_dataset(self) -> None: elif self.dataset_name == "kin8nm": array = pd.read_csv(path / "kin8nm.csv").to_numpy() elif self.dataset_name == "naval-propulsion-plant": - df = pd.read_csv( - path / "data.csv", header=None, sep=";", decimal="," - ) + df = pd.read_csv(path / "data.csv", header=None, sep=";", decimal=",") # convert Ex to 10^x and remove second target array = df.apply(pd.to_numeric, errors="coerce").to_numpy()[:, :-1] elif self.dataset_name == "protein": diff --git a/torch_uncertainty/datasets/segmentation/camvid.py b/torch_uncertainty/datasets/segmentation/camvid.py index 7957f688..937b83f7 100644 --- a/torch_uncertainty/datasets/segmentation/camvid.py +++ b/torch_uncertainty/datasets/segmentation/camvid.py @@ -133,8 +133,7 @@ def __init__( """ if split not in ["train", "val", "test", None]: raise ValueError( - f"Unknown split '{split}'. " - "Supported splits are ['train', 'val', 'test', None]" + f"Unknown split '{split}'. " "Supported splits are ['train', 'val', 'test', None]" ) super().__init__(root, transforms, None, None) @@ -154,18 +153,13 @@ def __init__( if not self._check_integrity(): raise RuntimeError( - "Dataset not found or corrupted. " - "You can use download=True to download it" + "Dataset not found or corrupted. " "You can use download=True to download it" ) # get filenames for split if split is None: - self.images = sorted( - (Path(self.root) / "camvid" / "raw").glob("*.png") - ) - self.targets = sorted( - (Path(self.root) / "camvid" / "label").glob("*.png") - ) + self.images = sorted((Path(self.root) / "camvid" / "raw").glob("*.png")) + self.targets = sorted((Path(self.root) / "camvid" / "label").glob("*.png")) else: with (Path(self.root) / "camvid" / "splits.json").open() as f: filenames = json.load(f)[split] @@ -173,18 +167,14 @@ def __init__( self.images = sorted( [ path - for path in (Path(self.root) / "camvid" / "raw").glob( - "*.png" - ) + for path in (Path(self.root) / "camvid" / "raw").glob("*.png") if path.stem in filenames ] ) self.targets = sorted( [ path - for path in (Path(self.root) / "camvid" / "label").glob( - "*.png" - ) + for path in (Path(self.root) / "camvid" / "label").glob("*.png") if path.stem[:-2] in filenames ] ) @@ -210,10 +200,7 @@ def encode_target(self, target: Image.Image) -> torch.Tensor: if self.group_classes and index != 255: index = self.class_to_superclass[index] target[ - ( - colored_target - == torch.tensor(camvid_class.color, dtype=target.dtype) - ).all(dim=-1) + (colored_target == torch.tensor(camvid_class.color, dtype=target.dtype)).all(dim=-1) ] = index return rearrange(target, "h w c -> c h w") @@ -231,18 +218,12 @@ def decode_target(self, target: torch.Tensor) -> Image.Image: if not self.group_classes: for camvid_class in self.classes: colored_target[ - ( - target - == torch.tensor(camvid_class.index, dtype=target.dtype) - ).all(dim=0) + (target == torch.tensor(camvid_class.index, dtype=target.dtype)).all(dim=0) ] = torch.tensor(camvid_class.color, dtype=target.dtype) else: for camvid_class in self.superclasses: colored_target[ - ( - target - == torch.tensor(camvid_class.index, dtype=target.dtype) - ).all(dim=0) + (target == torch.tensor(camvid_class.index, dtype=target.dtype)).all(dim=0) ] = torch.tensor(camvid_class.color, dtype=target.dtype) return F.to_pil_image(rearrange(colored_target, "h w c -> c h w")) @@ -253,9 +234,7 @@ def color_palette(self) -> list[tuple[int, int, int]]: return [camvid_class.color for camvid_class in self.superclasses] return [camvid_class.color for camvid_class in self.classes] - def __getitem__( - self, index: int - ) -> tuple[tv_tensors.Image, tv_tensors.Mask]: + def __getitem__(self, index: int) -> tuple[tv_tensors.Image, tv_tensors.Mask]: """Get the image and target at the given index. Args: @@ -265,9 +244,7 @@ def __getitem__( tuple[tv_tensors.Image, tv_tensors.Mask]: Image and target. """ image = tv_tensors.Image(Image.open(self.images[index]).convert("RGB")) - target = tv_tensors.Mask( - self.encode_target(Image.open(self.targets[index])) - ) + target = tv_tensors.Mask(self.encode_target(Image.open(self.targets[index]))) if self.transforms is not None: image, target = self.transforms(image, target) @@ -280,10 +257,7 @@ def __len__(self) -> int: def _check_integrity(self) -> bool: """Check if the dataset exists.""" - if ( - len(list((Path(self.root) / "camvid" / "raw").glob("*.png"))) - != self.num_samples["all"] - ): + if len(list((Path(self.root) / "camvid" / "raw").glob("*.png"))) != self.num_samples["all"]: return False if ( len(list((Path(self.root) / "camvid" / "label").glob("*.png"))) diff --git a/torch_uncertainty/datasets/segmentation/cityscapes.py b/torch_uncertainty/datasets/segmentation/cityscapes.py index 769174df..37505c1c 100644 --- a/torch_uncertainty/datasets/segmentation/cityscapes.py +++ b/torch_uncertainty/datasets/segmentation/cityscapes.py @@ -31,9 +31,7 @@ def __init__( transforms, ) train_id_to_color = [ - c.color - for c in self.classes - if (c.train_id != -1 and c.train_id != 255) + c.color for c in self.classes if (c.train_id != -1 and c.train_id != 255) ] train_id_to_color.append([0, 0, 0]) self.train_id_to_color = torch.tensor(train_id_to_color) @@ -54,10 +52,9 @@ def encode_target(cls, target: Image.Image) -> Image.Image: # convert target color to index for cityscapes_class in cls.classes: target[ - ( - colored_target - == torch.tensor(cityscapes_class.id, dtype=target.dtype) - ).all(dim=-1) + (colored_target == torch.tensor(cityscapes_class.id, dtype=target.dtype)).all( + dim=-1 + ) ] = cityscapes_class.train_id return F.to_pil_image(rearrange(target, "h w c -> c h w")) @@ -92,9 +89,7 @@ def __getitem__(self, index: int) -> tuple[Any, Any]: if t == "polygon": target = self._load_json(self.targets[index][i]) elif t == "semantic": - target = tv_tensors.Mask( - self.encode_target(Image.open(self.targets[index][i])) - ) + target = tv_tensors.Mask(self.encode_target(Image.open(self.targets[index][i]))) else: target = Image.open(self.targets[index][i]) @@ -107,9 +102,7 @@ def __getitem__(self, index: int) -> tuple[Any, Any]: return image, target - def plot_sample( - self, index: int, ax: _AX_TYPE | None = None - ) -> _PLOT_OUT_TYPE: + def plot_sample(self, index: int, ax: _AX_TYPE | None = None) -> _PLOT_OUT_TYPE: """Plot a sample from the dataset. Args: diff --git a/torch_uncertainty/layers/batch_ensemble.py b/torch_uncertainty/layers/batch_ensemble.py index ac641413..dde72139 100644 --- a/torch_uncertainty/layers/batch_ensemble.py +++ b/torch_uncertainty/layers/batch_ensemble.py @@ -102,16 +102,10 @@ def __init__( **factory_kwargs, ) - self.r_group = nn.Parameter( - torch.empty((num_estimators, in_features), **factory_kwargs) - ) - self.s_group = nn.Parameter( - torch.empty((num_estimators, out_features), **factory_kwargs) - ) + self.r_group = nn.Parameter(torch.empty((num_estimators, in_features), **factory_kwargs)) + self.s_group = nn.Parameter(torch.empty((num_estimators, out_features), **factory_kwargs)) if bias: - self.bias = nn.Parameter( - torch.empty((num_estimators, out_features), **factory_kwargs) - ) + self.bias = nn.Parameter(torch.empty((num_estimators, out_features), **factory_kwargs)) else: self.register_parameter("bias", None) self.reset_parameters() @@ -120,9 +114,7 @@ def reset_parameters(self) -> None: nn.init.normal_(self.r_group, mean=1.0, std=0.5) nn.init.normal_(self.s_group, mean=1.0, std=0.5) if self.bias is not None: - fan_in, _ = nn.init._calculate_fan_in_and_fan_out( - self.linear.weight - ) + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.linear.weight) bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 nn.init.uniform_(self.bias, -bound, bound) @@ -133,18 +125,10 @@ def forward(self, inputs: Tensor) -> Tensor: ) extra = batch_size % self.num_estimators - r_group = torch.repeat_interleave( - self.r_group, examples_per_estimator, dim=0 - ) - r_group = torch.cat( - [r_group, r_group[:extra]], dim=0 - ) # .unsqueeze(-1).unsqueeze(-1) - s_group = torch.repeat_interleave( - self.s_group, examples_per_estimator, dim=0 - ) - s_group = torch.cat( - [s_group, s_group[:extra]], dim=0 - ) # .unsqueeze(-1).unsqueeze(-1) + r_group = torch.repeat_interleave(self.r_group, examples_per_estimator, dim=0) + r_group = torch.cat([r_group, r_group[:extra]], dim=0) # .unsqueeze(-1).unsqueeze(-1) + s_group = torch.repeat_interleave(self.s_group, examples_per_estimator, dim=0) + s_group = torch.cat([s_group, s_group[:extra]], dim=0) # .unsqueeze(-1).unsqueeze(-1) if self.bias is not None: bias = torch.repeat_interleave( self.bias, @@ -155,9 +139,7 @@ def forward(self, inputs: Tensor) -> Tensor: else: bias = None - return self.linear(inputs * r_group) * s_group + ( - bias if bias is not None else 0 - ) + return self.linear(inputs * r_group) * s_group + (bias if bias is not None else 0) def extra_repr(self) -> str: return ( @@ -324,16 +306,10 @@ def __init__( bias=False, **factory_kwargs, ) - self.r_group = nn.Parameter( - torch.empty((num_estimators, in_channels), **factory_kwargs) - ) - self.s_group = nn.Parameter( - torch.empty((num_estimators, out_channels), **factory_kwargs) - ) + self.r_group = nn.Parameter(torch.empty((num_estimators, in_channels), **factory_kwargs)) + self.s_group = nn.Parameter(torch.empty((num_estimators, out_channels), **factory_kwargs)) if bias: - self.bias = nn.Parameter( - torch.empty((num_estimators, out_channels), **factory_kwargs) - ) + self.bias = nn.Parameter(torch.empty((num_estimators, out_channels), **factory_kwargs)) else: self.register_parameter("bias", None) @@ -365,9 +341,7 @@ def forward(self, inputs: Tensor) -> Tensor: .unsqueeze(-1) .unsqueeze(-1) ) - r_group = torch.cat( - [r_group, r_group[:extra]], dim=0 - ) # .unsqueeze(-1).unsqueeze(-1) + r_group = torch.cat([r_group, r_group[:extra]], dim=0) # .unsqueeze(-1).unsqueeze(-1) s_group = ( torch.repeat_interleave( self.s_group, @@ -402,9 +376,7 @@ def forward(self, inputs: Tensor) -> Tensor: else: bias = None - return self.conv(inputs * r_group) * s_group + ( - bias if bias is not None else 0 - ) + return self.conv(inputs * r_group) * s_group + (bias if bias is not None else 0) def extra_repr(self) -> str: return ( diff --git a/torch_uncertainty/layers/bayesian/bayes_conv.py b/torch_uncertainty/layers/bayesian/bayes_conv.py index d6122bb5..e328aa39 100644 --- a/torch_uncertainty/layers/bayesian/bayes_conv.py +++ b/torch_uncertainty/layers/bayesian/bayes_conv.py @@ -89,8 +89,7 @@ def __init__( if transposed: raise NotImplementedError( - "Bayesian transposed convolution not implemented yet. Raise an" - " issue if needed." + "Bayesian transposed convolution not implemented yet. Raise an" " issue if needed." ) self.in_channels = in_channels @@ -109,9 +108,7 @@ def __init__( self.groups = groups self.padding_mode = padding_mode - self._reversed_padding_repeated_twice = _reverse_repeat_tuple( - self.padding, 2 - ) + self._reversed_padding_repeated_twice = _reverse_repeat_tuple(self.padding, 2) self.weight_mu = Parameter( torch.empty( @@ -127,33 +124,21 @@ def __init__( ) if bias: - self.bias_mu = Parameter( - torch.empty(out_channels, **factory_kwargs) - ) - self.bias_sigma = Parameter( - torch.empty(out_channels, **factory_kwargs) - ) + self.bias_mu = Parameter(torch.empty(out_channels, **factory_kwargs)) + self.bias_sigma = Parameter(torch.empty(out_channels, **factory_kwargs)) else: self.register_parameter("bias_mu", None) self.register_parameter("bias_sigma", None) - self.weight_prior_dist = CenteredGaussianMixture( - prior_sigma_1, prior_sigma_2, prior_pi - ) + self.weight_prior_dist = CenteredGaussianMixture(prior_sigma_1, prior_sigma_2, prior_pi) if bias: - self.bias_prior_dist = CenteredGaussianMixture( - prior_sigma_1, prior_sigma_2, prior_pi - ) + self.bias_prior_dist = CenteredGaussianMixture(prior_sigma_1, prior_sigma_2, prior_pi) self.reset_parameters() - self.weight_sampler = TrainableDistribution( - self.weight_mu, self.weight_sigma - ) + self.weight_sampler = TrainableDistribution(self.weight_mu, self.weight_sigma) if bias: - self.bias_sampler = TrainableDistribution( - self.bias_mu, self.bias_sigma - ) + self.bias_sampler = TrainableDistribution(self.bias_mu, self.bias_sigma) def reset_parameters(self) -> None: # TODO: change init @@ -179,10 +164,7 @@ def sample(self) -> tuple[Tensor, Tensor | None]: return weight, bias def extra_repr(self) -> str: # coverage: ignore - s = ( - "{in_channels}, {out_channels}, kernel_size={kernel_size}" - ", stride={stride}" - ) + s = "{in_channels}, {out_channels}, kernel_size={kernel_size}" ", stride={stride}" if self.padding != (0,) * len(self.padding): s += ", padding={padding}" if self.dilation != (1,) * len(self.dilation): @@ -253,9 +235,7 @@ def __init__( **factory_kwargs, ) - def _conv_forward( - self, inputs: Tensor, weight: Tensor, bias: Tensor | None - ) -> Tensor: + def _conv_forward(self, inputs: Tensor, weight: Tensor, bias: Tensor | None) -> Tensor: if self.padding_mode != "zeros": return F.conv1d( F.pad( @@ -294,9 +274,7 @@ def forward(self, inputs: Tensor) -> Tensor: else: bias, bias_lposterior, bias_lprior = None, 0, 0 - self.lvposterior = ( - self.weight_sampler.log_posterior() + bias_lposterior - ) + self.lvposterior = self.weight_sampler.log_posterior() + bias_lposterior self.lprior = self.weight_prior_dist.log_prob(weight) + bias_lprior return self._conv_forward(inputs, weight, bias) @@ -350,9 +328,7 @@ def __init__( **factory_kwargs, ) - def _conv_forward( - self, inputs: Tensor, weight: Tensor, bias: Tensor | None - ) -> Tensor: + def _conv_forward(self, inputs: Tensor, weight: Tensor, bias: Tensor | None) -> Tensor: if self.padding_mode != "zeros": return F.conv2d( F.pad( @@ -391,9 +367,7 @@ def forward(self, inputs: Tensor) -> Tensor: else: bias, bias_lposterior, bias_lprior = None, 0, 0 - self.lvposterior = ( - self.weight_sampler.log_posterior() + bias_lposterior - ) + self.lvposterior = self.weight_sampler.log_posterior() + bias_lposterior self.lprior = self.weight_prior_dist.log_prob(weight) + bias_lprior return self._conv_forward(inputs, weight, bias) @@ -447,9 +421,7 @@ def __init__( **factory_kwargs, ) - def _conv_forward( - self, inputs: Tensor, weight: Tensor, bias: Tensor | None - ) -> Tensor: + def _conv_forward(self, inputs: Tensor, weight: Tensor, bias: Tensor | None) -> Tensor: if self.padding_mode != "zeros": return F.conv3d( F.pad( @@ -488,9 +460,7 @@ def forward(self, inputs: Tensor) -> Tensor: else: bias, bias_lposterior, bias_lprior = None, 0, 0 - self.lvposterior = ( - self.weight_sampler.log_posterior() + bias_lposterior - ) + self.lvposterior = self.weight_sampler.log_posterior() + bias_lposterior self.lprior = self.weight_prior_dist.log_prob(weight) + bias_lprior return self._conv_forward(inputs, weight, bias) diff --git a/torch_uncertainty/layers/bayesian/bayes_linear.py b/torch_uncertainty/layers/bayesian/bayes_linear.py index ff2247d2..6db022b9 100644 --- a/torch_uncertainty/layers/bayesian/bayes_linear.py +++ b/torch_uncertainty/layers/bayesian/bayes_linear.py @@ -64,40 +64,24 @@ def __init__( self.sigma_init = sigma_init self.frozen = frozen - self.weight_mu = nn.Parameter( - torch.empty((out_features, in_features), **factory_kwargs) - ) - self.weight_sigma = nn.Parameter( - torch.empty((out_features, in_features), **factory_kwargs) - ) + self.weight_mu = nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs)) + self.weight_sigma = nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs)) if bias: - self.bias_mu = nn.Parameter( - torch.empty(out_features, **factory_kwargs) - ) - self.bias_sigma = nn.Parameter( - torch.empty(out_features, **factory_kwargs) - ) + self.bias_mu = nn.Parameter(torch.empty(out_features, **factory_kwargs)) + self.bias_sigma = nn.Parameter(torch.empty(out_features, **factory_kwargs)) else: self.register_parameter("bias_mu", None) self.register_parameter("bias_log_sigma", None) self.reset_parameters() - self.weight_sampler = TrainableDistribution( - self.weight_mu, self.weight_sigma - ) + self.weight_sampler = TrainableDistribution(self.weight_mu, self.weight_sigma) if bias: - self.bias_sampler = TrainableDistribution( - self.bias_mu, self.bias_sigma - ) + self.bias_sampler = TrainableDistribution(self.bias_mu, self.bias_sigma) - self.weight_prior_dist = CenteredGaussianMixture( - prior_sigma_1, prior_sigma_2, prior_pi - ) + self.weight_prior_dist = CenteredGaussianMixture(prior_sigma_1, prior_sigma_2, prior_pi) if bias: - self.bias_prior_dist = CenteredGaussianMixture( - prior_sigma_1, prior_sigma_2, prior_pi - ) + self.bias_prior_dist = CenteredGaussianMixture(prior_sigma_1, prior_sigma_2, prior_pi) def reset_parameters(self) -> None: # TODO: change init diff --git a/torch_uncertainty/layers/bayesian/lpbnn.py b/torch_uncertainty/layers/bayesian/lpbnn.py index 2b53305a..996119f3 100644 --- a/torch_uncertainty/layers/bayesian/lpbnn.py +++ b/torch_uncertainty/layers/bayesian/lpbnn.py @@ -12,17 +12,11 @@ def check_lpbnn_parameters_consistency( hidden_size: int, std_factor: float, num_estimators: int ) -> None: if hidden_size < 1: - raise ValueError( - f"hidden_size must be greater than 0. Got {hidden_size}." - ) + raise ValueError(f"hidden_size must be greater than 0. Got {hidden_size}.") if std_factor < 0: - raise ValueError( - f"std_factor must be greater than 0. Got {std_factor}." - ) + raise ValueError(f"std_factor must be greater than 0. Got {std_factor}.") if num_estimators < 1: - raise ValueError( - f"num_estimators must be greater than 0. Got {num_estimators}." - ) + raise ValueError(f"num_estimators must be greater than 0. Got {num_estimators}.") def _sample(mu: Tensor, logvar: Tensor, std_factor: float) -> Tensor: @@ -72,9 +66,7 @@ def __init__( `Encoding the latent posterior of Bayesian Neural Networks for uncertainty quantification `_. """ - check_lpbnn_parameters_consistency( - hidden_size, std_factor, num_estimators - ) + check_lpbnn_parameters_consistency(hidden_size, std_factor, num_estimators) factory_kwargs = {"device": device, "dtype": dtype} super().__init__() @@ -87,33 +79,19 @@ def __init__( # for the KL Loss self.lprior = 0 - self.linear = nn.Linear( - in_features, out_features, bias=False, **factory_kwargs - ) + self.linear = nn.Linear(in_features, out_features, bias=False, **factory_kwargs) self.alpha = nn.Parameter( torch.empty((num_estimators, in_features), **factory_kwargs), requires_grad=False, ) - self.gamma = nn.Parameter( - torch.empty((num_estimators, out_features), **factory_kwargs) - ) - self.encoder = nn.Linear( - in_features, self.hidden_size, **factory_kwargs - ) - self.latent_mean = nn.Linear( - self.hidden_size, self.hidden_size, **factory_kwargs - ) - self.latent_logvar = nn.Linear( - self.hidden_size, self.hidden_size, **factory_kwargs - ) - self.decoder = nn.Linear( - self.hidden_size, in_features, **factory_kwargs - ) + self.gamma = nn.Parameter(torch.empty((num_estimators, out_features), **factory_kwargs)) + self.encoder = nn.Linear(in_features, self.hidden_size, **factory_kwargs) + self.latent_mean = nn.Linear(self.hidden_size, self.hidden_size, **factory_kwargs) + self.latent_logvar = nn.Linear(self.hidden_size, self.hidden_size, **factory_kwargs) + self.decoder = nn.Linear(self.hidden_size, in_features, **factory_kwargs) self.latent_loss = torch.zeros(1, **factory_kwargs) if bias: - self.bias = nn.Parameter( - torch.empty((num_estimators, out_features), **factory_kwargs) - ) + self.bias = nn.Parameter(torch.empty((num_estimators, out_features), **factory_kwargs)) else: self.register_parameter("bias", None) self.reset_parameters() @@ -127,9 +105,7 @@ def reset_parameters(self): self.latent_mean.reset_parameters() self.latent_logvar.reset_parameters() if self.bias is not None: - fan_in, _ = nn.init._calculate_fan_in_and_fan_out( - self.linear.weight - ) + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.linear.weight) bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 nn.init.uniform_(self.bias, -bound, bound) @@ -148,9 +124,7 @@ def forward(self, x: Tensor) -> Tensor: # Compute the latent loss if self.training: mse = F.mse_loss(alpha_sample, self.alpha) - kld = -0.5 * torch.sum( - 1 + latent_logvar - latent_mean**2 - torch.exp(latent_logvar) - ) + kld = -0.5 * torch.sum(1 + latent_logvar - latent_mean**2 - torch.exp(latent_logvar)) # For the KL Loss self.lvposterior = mse + kld @@ -218,9 +192,7 @@ def __init__( `Encoding the latent posterior of Bayesian Neural Networks for uncertainty quantification `_. """ - check_lpbnn_parameters_consistency( - hidden_size, std_factor, num_estimators - ) + check_lpbnn_parameters_consistency(hidden_size, std_factor, num_estimators) factory_kwargs = {"device": device, "dtype": dtype} super().__init__() @@ -248,31 +220,19 @@ def __init__( requires_grad=False, ) - self.encoder = nn.Linear( - in_channels, self.hidden_size, **factory_kwargs - ) - self.decoder = nn.Linear( - self.hidden_size, in_channels, **factory_kwargs - ) - self.latent_mean = nn.Linear( - self.hidden_size, self.hidden_size, **factory_kwargs - ) - self.latent_logvar = nn.Linear( - self.hidden_size, self.hidden_size, **factory_kwargs - ) + self.encoder = nn.Linear(in_channels, self.hidden_size, **factory_kwargs) + self.decoder = nn.Linear(self.hidden_size, in_channels, **factory_kwargs) + self.latent_mean = nn.Linear(self.hidden_size, self.hidden_size, **factory_kwargs) + self.latent_logvar = nn.Linear(self.hidden_size, self.hidden_size, **factory_kwargs) self.latent_loss = torch.zeros(1, **factory_kwargs) if gamma: - self.gamma = nn.Parameter( - torch.empty((num_estimators, out_channels), **factory_kwargs) - ) + self.gamma = nn.Parameter(torch.empty((num_estimators, out_channels), **factory_kwargs)) else: self.register_parameter("gamma", None) if bias: - self.bias = nn.Parameter( - torch.empty((num_estimators, out_channels), **factory_kwargs) - ) + self.bias = nn.Parameter(torch.empty((num_estimators, out_channels), **factory_kwargs)) else: self.register_parameter("bias", None) self.reset_parameters() @@ -306,36 +266,22 @@ def forward(self, x: Tensor) -> Tensor: # Compute the latent loss if self.training: mse = F.mse_loss(alpha_sample, self.alpha) - kld = -0.5 * torch.sum( - 1 + latent_logvar - latent_mean.pow(2) - latent_logvar.exp() - ) + kld = -0.5 * torch.sum(1 + latent_logvar - latent_mean.pow(2) - latent_logvar.exp()) # for the KL Loss self.lvposterior = mse + kld num_examples_per_model = int(x.size(0) / self.num_estimators) # Compute the output - alpha = ( - alpha_sample.repeat((num_examples_per_model, 1)) - .unsqueeze(-1) - .unsqueeze(-1) - ) + alpha = alpha_sample.repeat((num_examples_per_model, 1)).unsqueeze(-1).unsqueeze(-1) if self.gamma is not None: - gamma = ( - self.gamma.repeat((num_examples_per_model, 1)) - .unsqueeze(-1) - .unsqueeze(-1) - ) + gamma = self.gamma.repeat((num_examples_per_model, 1)).unsqueeze(-1).unsqueeze(-1) out = self.conv(x * alpha) * gamma else: out = self.conv(x * alpha) if self.bias is not None: - bias = ( - self.bias.repeat((num_examples_per_model, 1)) - .unsqueeze(-1) - .unsqueeze(-1) - ) + bias = self.bias.repeat((num_examples_per_model, 1)).unsqueeze(-1).unsqueeze(-1) out += bias return out diff --git a/torch_uncertainty/layers/bayesian/sampler.py b/torch_uncertainty/layers/bayesian/sampler.py index dd5e710a..46fed63e 100644 --- a/torch_uncertainty/layers/bayesian/sampler.py +++ b/torch_uncertainty/layers/bayesian/sampler.py @@ -25,9 +25,7 @@ def sample(self) -> Tensor: def log_posterior(self, weight: Tensor | None = None) -> Tensor: if self.weight is None or self.sigma is None: - raise ValueError( - "Sample the weights before querying the log posterior." - ) + raise ValueError("Sample the weights before querying the log posterior.") if weight is None: # coverage: ignore weight = self.weight diff --git a/torch_uncertainty/layers/channel_layer_norm.py b/torch_uncertainty/layers/channel_layer_norm.py index 69999324..dce247e2 100644 --- a/torch_uncertainty/layers/channel_layer_norm.py +++ b/torch_uncertainty/layers/channel_layer_norm.py @@ -49,9 +49,7 @@ def __init__( - Output: :math:`(N, *)` (same shape as input) """ - super().__init__( - normalized_shape, eps, elementwise_affine, bias, device, dtype - ) + super().__init__(normalized_shape, eps, elementwise_affine, bias, device, dtype) self.cback = ChannelBack() self.cfront = ChannelFront() diff --git a/torch_uncertainty/layers/filter_response_norm.py b/torch_uncertainty/layers/filter_response_norm.py index 8c9f3aee..ef306a8f 100644 --- a/torch_uncertainty/layers/filter_response_norm.py +++ b/torch_uncertainty/layers/filter_response_norm.py @@ -23,15 +23,13 @@ def __init__( super().__init__() if dimension < 1 or not isinstance(dimension, int): raise ValueError( - "dimension should be an integer greater or equal than 1. " - f"got {dimension}." + "dimension should be an integer greater or equal than 1. " f"got {dimension}." ) self.dimension = dimension if num_channels < 1 or not isinstance(num_channels, int): raise ValueError( - "num_channels should be an integer greater or equal than 1. " - f"got {num_channels}." + "num_channels should be an integer greater or equal than 1. " f"got {num_channels}." ) shape = (1, num_channels) + (1,) * dimension self.eps = eps @@ -41,18 +39,14 @@ def __init__( self.gamma = nn.Parameter(torch.ones(shape, device=device, dtype=dtype)) def forward(self, x: Tensor) -> Tensor: - nu2 = torch.mean( - x**2, dim=list(range(-self.dimension, 0)), keepdim=True - ) + nu2 = torch.mean(x**2, dim=list(range(-self.dimension, 0)), keepdim=True) x = x * torch.rsqrt(nu2 + self.eps) y = self.gamma * x + self.beta return torch.max(y, self.tau) class FilterResponseNorm1d(_FilterResponseNormNd): - def __init__( - self, num_channels: int, eps: float = 1e-6, device=None, dtype=None - ) -> None: + def __init__(self, num_channels: int, eps: float = 1e-6, device=None, dtype=None) -> None: """1-dimensional Filter Response Normalization layer. Args: @@ -71,9 +65,7 @@ def __init__( class FilterResponseNorm2d(_FilterResponseNormNd): - def __init__( - self, num_channels: int, eps: float = 1e-6, device=None, dtype=None - ) -> None: + def __init__(self, num_channels: int, eps: float = 1e-6, device=None, dtype=None) -> None: """2-dimensional Filter Response Normalization layer. Args: @@ -92,9 +84,7 @@ def __init__( class FilterResponseNorm3d(_FilterResponseNormNd): - def __init__( - self, num_channels: int, eps: float = 1e-6, device=None, dtype=None - ) -> None: + def __init__(self, num_channels: int, eps: float = 1e-6, device=None, dtype=None) -> None: """3-dimensional Filter Response Normalization layer. Args: diff --git a/torch_uncertainty/layers/masksembles.py b/torch_uncertainty/layers/masksembles.py index d8493c9e..d94d9539 100644 --- a/torch_uncertainty/layers/masksembles.py +++ b/torch_uncertainty/layers/masksembles.py @@ -117,9 +117,7 @@ def generation_wrapper(c: int, n: int, scale: float) -> np.ndarray: class Mask1d(nn.Module): - def __init__( - self, channels: int, num_masks: int, scale: float, **factory_kwargs - ) -> None: + def __init__(self, channels: int, num_masks: int, scale: float, **factory_kwargs) -> None: super().__init__() self.num_masks = num_masks @@ -139,9 +137,7 @@ def forward(self, inputs: Tensor) -> Tensor: class Mask2d(nn.Module): - def __init__( - self, channels: int, num_masks: int, scale: float, **factory_kwargs - ) -> None: + def __init__(self, channels: int, num_masks: int, scale: float, **factory_kwargs) -> None: super().__init__() self.num_masks = num_masks @@ -206,9 +202,7 @@ def __init__( if scale < 1: raise ValueError(f"Attribute `scale` should be >= 1, not {scale}.") - self.mask = Mask1d( - in_features, num_masks=num_estimators, scale=scale, **factory_kwargs - ) + self.mask = Mask1d(in_features, num_masks=num_estimators, scale=scale, **factory_kwargs) self.linear = nn.Linear( in_features=in_features, out_features=out_features, @@ -275,9 +269,7 @@ def __init__( if scale < 1: raise ValueError(f"Attribute `scale` should be >= 1, not {scale}.") - self.mask = Mask2d( - in_channels, num_masks=num_estimators, scale=scale, **factory_kwargs - ) + self.mask = Mask2d(in_channels, num_masks=num_estimators, scale=scale, **factory_kwargs) self.conv = nn.Conv2d( in_channels=in_channels, out_channels=out_channels, diff --git a/torch_uncertainty/layers/mc_batch_norm.py b/torch_uncertainty/layers/mc_batch_norm.py index 916dc6f8..4c6517a5 100644 --- a/torch_uncertainty/layers/mc_batch_norm.py +++ b/torch_uncertainty/layers/mc_batch_norm.py @@ -35,15 +35,11 @@ def __init__( self.register_buffer( "means", - torch.zeros( - num_estimators, num_features, device=device, dtype=dtype - ), + torch.zeros(num_estimators, num_features, device=device, dtype=dtype), ) self.register_buffer( "vars", - torch.zeros( - num_estimators, num_features, device=device, dtype=dtype - ), + torch.zeros(num_estimators, num_features, device=device, dtype=dtype), ) self.accumulate = True @@ -103,9 +99,7 @@ class MCBatchNorm1d(_MCBatchNorm): def _check_input_dim(self, inputs) -> None: if inputs.dim() != 2 and inputs.dim() != 3: - raise ValueError( - f"expected 2D or 3D input (got {inputs.dim()}D input)" - ) + raise ValueError(f"expected 2D or 3D input (got {inputs.dim()}D input)") class MCBatchNorm2d(_MCBatchNorm): @@ -129,9 +123,7 @@ class MCBatchNorm2d(_MCBatchNorm): def _check_input_dim(self, inputs) -> None: if inputs.dim() != 3 and inputs.dim() != 4: - raise ValueError( - f"expected 3D or 4D input (got {inputs.dim()}D input)" - ) + raise ValueError(f"expected 3D or 4D input (got {inputs.dim()}D input)") class MCBatchNorm3d(_MCBatchNorm): @@ -155,6 +147,4 @@ class MCBatchNorm3d(_MCBatchNorm): def _check_input_dim(self, inputs) -> None: if inputs.dim() != 4 and inputs.dim() != 5: - raise ValueError( - f"expected 4D or 5D input (got {inputs.dim()}D input)" - ) + raise ValueError(f"expected 4D or 5D input (got {inputs.dim()}D input)") diff --git a/torch_uncertainty/layers/packed.py b/torch_uncertainty/layers/packed.py index 3858300e..490f65ee 100644 --- a/torch_uncertainty/layers/packed.py +++ b/torch_uncertainty/layers/packed.py @@ -8,9 +8,7 @@ from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t -def check_packed_parameters_consistency( - alpha: float, gamma: int, num_estimators: int -) -> None: +def check_packed_parameters_consistency(alpha: float, gamma: int, num_estimators: int) -> None: """Check the consistency of the parameters of the Packed-Ensembles layers. Args: @@ -25,26 +23,18 @@ def check_packed_parameters_consistency( raise ValueError(f"Attribute `alpha` should be > 0, not {alpha}") if not isinstance(gamma, int): - raise TypeError( - f"Attribute `gamma` should be an int, not {type(gamma)}" - ) + raise TypeError(f"Attribute `gamma` should be an int, not {type(gamma)}") if gamma <= 0: raise ValueError(f"Attribute `gamma` should be >= 1, not {gamma}") if num_estimators is None: - raise ValueError( - "You must specify the value of the arg. `num_estimators`" - ) + raise ValueError("You must specify the value of the arg. `num_estimators`") if not isinstance(num_estimators, int): raise TypeError( - "Attribute `num_estimators` should be an int, not " - f"{type(num_estimators)}" + "Attribute `num_estimators` should be an int, not " f"{type(num_estimators)}" ) if num_estimators <= 0: - raise ValueError( - "Attribute `num_estimators` should be >= 1, not " - f"{num_estimators}" - ) + raise ValueError("Attribute `num_estimators` should be >= 1, not " f"{num_estimators}") class PackedLinear(nn.Module): @@ -120,22 +110,16 @@ def __init__( # Define the number of features of the underlying convolution extended_in_features = int(in_features * (1 if first else alpha)) - extended_out_features = int( - out_features * (num_estimators if last else alpha) - ) + extended_out_features = int(out_features * (num_estimators if last else alpha)) # Define the number of groups of the underlying convolution actual_groups = num_estimators * gamma if not first else 1 # fix if not divisible by groups if extended_in_features % actual_groups: - extended_in_features += num_estimators - extended_in_features % ( - actual_groups - ) + extended_in_features += num_estimators - extended_in_features % (actual_groups) if extended_out_features % actual_groups: - extended_out_features += num_estimators - extended_out_features % ( - actual_groups - ) + extended_out_features += num_estimators - extended_out_features % (actual_groups) # FIXME: This is a temporary check assert implementation in [ @@ -173,9 +157,7 @@ def __init__( self.groups = actual_groups if bias: - self.bias = nn.Parameter( - torch.empty(extended_out_features, **factory_kwargs) - ) + self.bias = nn.Parameter(torch.empty(extended_out_features, **factory_kwargs)) else: self.register_parameter("bias", None) @@ -193,9 +175,7 @@ def reset_parameters(self) -> None: nn.init.uniform_(self.bias, -bound, bound) if self.implementation == "sparse": - self.weight = nn.Parameter( - torch.block_diag(*self.weight).to_sparse() - ) + self.weight = nn.Parameter(torch.block_diag(*self.weight).to_sparse()) def _rearrange_forward(self, x: Tensor) -> Tensor: x = x.unsqueeze(-1) @@ -209,9 +189,7 @@ def forward(self, inputs: Tensor) -> Tensor: if self.implementation == "legacy": if self.rearrange: return self._rearrange_forward(inputs) - return F.conv1d( - inputs, self.weight, self.bias, 1, 0, 1, self.groups - ) + return F.conv1d(inputs, self.weight, self.bias, 1, 0, 1, self.groups) if self.implementation == "full": block_diag = torch.block_diag(*self.weight) return F.linear(inputs, block_diag, self.bias) @@ -302,30 +280,23 @@ def __init__( # Define the number of channels of the underlying convolution extended_in_channels = int(in_channels * (1 if first else alpha)) - extended_out_channels = int( - out_channels * (num_estimators if last else alpha) - ) + extended_out_channels = int(out_channels * (num_estimators if last else alpha)) # Define the number of groups of the underlying convolution actual_groups = 1 if first else gamma * groups * num_estimators while ( extended_in_channels % actual_groups != 0 - or extended_in_channels // actual_groups - < minimum_channels_per_group + or extended_in_channels // actual_groups < minimum_channels_per_group ) and actual_groups // (groups * num_estimators) > 1: gamma -= 1 actual_groups = gamma * groups * num_estimators # fix if not divisible by groups if extended_in_channels % actual_groups: - extended_in_channels += ( - num_estimators - extended_in_channels % actual_groups - ) + extended_in_channels += num_estimators - extended_in_channels % actual_groups if extended_out_channels % actual_groups: - extended_out_channels += ( - num_estimators - extended_out_channels % actual_groups - ) + extended_out_channels += num_estimators - extended_out_channels % actual_groups self.conv = nn.Conv1d( in_channels=extended_in_channels, @@ -430,30 +401,23 @@ def __init__( # Define the number of channels of the underlying convolution extended_in_channels = int(in_channels * (1 if first else alpha)) - extended_out_channels = int( - out_channels * (num_estimators if last else alpha) - ) + extended_out_channels = int(out_channels * (num_estimators if last else alpha)) # Define the number of groups of the underlying convolution actual_groups = 1 if first else gamma * groups * num_estimators while ( extended_in_channels % actual_groups != 0 - or extended_in_channels // actual_groups - < minimum_channels_per_group + or extended_in_channels // actual_groups < minimum_channels_per_group ) and actual_groups // (groups * num_estimators) > 1: gamma -= 1 actual_groups = gamma * groups * num_estimators # fix if not divisible by groups if extended_in_channels % actual_groups: - extended_in_channels += ( - num_estimators - extended_in_channels % actual_groups - ) + extended_in_channels += num_estimators - extended_in_channels % actual_groups if extended_out_channels % actual_groups: - extended_out_channels += ( - num_estimators - extended_out_channels % actual_groups - ) + extended_out_channels += num_estimators - extended_out_channels % actual_groups self.conv = nn.Conv2d( in_channels=extended_in_channels, @@ -558,30 +522,23 @@ def __init__( # Define the number of channels of the underlying convolution extended_in_channels = int(in_channels * (1 if first else alpha)) - extended_out_channels = int( - out_channels * (num_estimators if last else alpha) - ) + extended_out_channels = int(out_channels * (num_estimators if last else alpha)) # Define the number of groups of the underlying convolution actual_groups = 1 if first else gamma * groups * num_estimators while ( extended_in_channels % actual_groups != 0 - or extended_in_channels // actual_groups - < minimum_channels_per_group + or extended_in_channels // actual_groups < minimum_channels_per_group ) and actual_groups // (groups * num_estimators) > 1: gamma -= 1 actual_groups = gamma * groups * num_estimators # fix if not divisible by groups if extended_in_channels % actual_groups: - extended_in_channels += ( - num_estimators - extended_in_channels % actual_groups - ) + extended_in_channels += num_estimators - extended_in_channels % actual_groups if extended_out_channels % actual_groups: - extended_out_channels += ( - num_estimators - extended_out_channels % actual_groups - ) + extended_out_channels += num_estimators - extended_out_channels % actual_groups self.conv = nn.Conv3d( in_channels=extended_in_channels, diff --git a/torch_uncertainty/losses/bayesian.py b/torch_uncertainty/losses/bayesian.py index 3621a8f2..6c17913e 100644 --- a/torch_uncertainty/losses/bayesian.py +++ b/torch_uncertainty/losses/bayesian.py @@ -24,9 +24,7 @@ def _kl_div(self) -> Tensor: count = 0 for module in self.model.modules(): if isinstance(module, bayesian_modules): - kl_divergence = kl_divergence.to( - device=module.lvposterior.device - ) + kl_divergence = kl_divergence.to(device=module.lvposterior.device) kl_divergence += module.lvposterior - module.lprior count += 1 return kl_divergence / count @@ -88,27 +86,14 @@ def set_model(self, model: nn.Module | None) -> None: self._kl_div = KLDiv(model) -def _elbo_loss_checks( - inner_loss: nn.Module, kl_weight: float, num_samples: int -) -> None: +def _elbo_loss_checks(inner_loss: nn.Module, kl_weight: float, num_samples: int) -> None: if isinstance(inner_loss, type): - raise TypeError( - "The inner_loss should be an instance of a class." - f"Got {inner_loss}." - ) + raise TypeError("The inner_loss should be an instance of a class." f"Got {inner_loss}.") if kl_weight < 0: - raise ValueError( - f"The KL weight should be non-negative. Got {kl_weight}." - ) + raise ValueError(f"The KL weight should be non-negative. Got {kl_weight}.") if num_samples < 1: - raise ValueError( - "The number of samples should not be lower than 1." - f"Got {num_samples}." - ) + raise ValueError("The number of samples should not be lower than 1." f"Got {num_samples}.") if not isinstance(num_samples, int): - raise TypeError( - "The number of samples should be an integer. " - f"Got {type(num_samples)}." - ) + raise TypeError("The number of samples should be an integer. " f"Got {type(num_samples)}.") diff --git a/torch_uncertainty/losses/classification.py b/torch_uncertainty/losses/classification.py index 1ab56b8c..82c48280 100644 --- a/torch_uncertainty/losses/classification.py +++ b/torch_uncertainty/losses/classification.py @@ -31,16 +31,12 @@ def __init__( if reg_weight is not None and (reg_weight < 0): raise ValueError( - "The regularization weight should be non-negative, but got " - f"{reg_weight}." + "The regularization weight should be non-negative, but got " f"{reg_weight}." ) self.reg_weight = reg_weight if annealing_step is not None and (annealing_step <= 0): - raise ValueError( - "The annealing step should be positive, but got " - f"{annealing_step}." - ) + raise ValueError("The annealing step should be positive, but got " f"{annealing_step}.") self.annealing_step = annealing_step if reduction not in ("none", "mean", "sum") and reduction is not None: @@ -48,18 +44,14 @@ def __init__( self.reduction = reduction if loss_type not in ["mse", "log", "digamma"]: - raise ValueError( - f"{loss_type} is not a valid value for mse/log/digamma loss." - ) + raise ValueError(f"{loss_type} is not a valid value for mse/log/digamma loss.") self.loss_type = loss_type def _mse_loss(self, evidence: Tensor, targets: Tensor) -> Tensor: evidence = torch.relu(evidence) alpha = evidence + 1.0 strength = torch.sum(alpha, dim=1, keepdim=True) - loglikelihood_err = torch.sum( - (targets - (alpha / strength)) ** 2, dim=1, keepdim=True - ) + loglikelihood_err = torch.sum((targets - (alpha / strength)) ** 2, dim=1, keepdim=True) loglikelihood_var = torch.sum( alpha * (strength - alpha) / (strength * strength * (strength + 1)), dim=1, @@ -98,9 +90,7 @@ def _kldiv_reg( kl_alpha = (alpha - 1) * (1 - targets) + 1 - ones = torch.ones( - [1, num_classes], dtype=evidence.dtype, device=evidence.device - ) + ones = torch.ones([1, num_classes], dtype=evidence.dtype, device=evidence.device) sum_kl_alpha = torch.sum(kl_alpha, dim=1, keepdim=True) first_term = ( torch.lgamma(sum_kl_alpha) @@ -109,8 +99,7 @@ def _kldiv_reg( - torch.lgamma(ones.sum(dim=1, keepdim=True)) ) second_term = torch.sum( - (kl_alpha - ones) - * (torch.digamma(kl_alpha) - torch.digamma(sum_kl_alpha)), + (kl_alpha - ones) * (torch.digamma(kl_alpha) - torch.digamma(sum_kl_alpha)), dim=1, keepdim=True, ) @@ -122,11 +111,7 @@ def forward( targets: Tensor, current_epoch: int | None = None, ) -> Tensor: - if ( - self.annealing_step is not None - and self.annealing_step > 0 - and current_epoch is None - ): + if self.annealing_step is not None and self.annealing_step > 0 and current_epoch is None: raise ValueError( "The epoch num should be positive when \ annealing_step is settled, but got " @@ -134,9 +119,7 @@ def forward( ) if targets.ndim != 1: # if no mixup or cutmix - raise NotImplementedError( - "DECLoss does not yet support mixup/cutmix." - ) + raise NotImplementedError("DECLoss does not yet support mixup/cutmix.") # TODO: handle binary targets = F.one_hot(targets, num_classes=evidence.size()[-1]) @@ -154,9 +137,7 @@ def forward( else: annealing_coef = torch.min( input=torch.tensor(1.0, dtype=evidence.dtype), - other=torch.tensor( - current_epoch / self.annealing_step, dtype=evidence.dtype - ), + other=torch.tensor(current_epoch / self.annealing_step, dtype=evidence.dtype), ) loss_reg = self._kldiv_reg(evidence, targets) @@ -197,14 +178,11 @@ def __init__( self.reduction = reduction if eps < 0: - raise ValueError( - "The epsilon value should be non-negative, but got " f"{eps}." - ) + raise ValueError("The epsilon value should be non-negative, but got " f"{eps}.") self.eps = eps if reg_weight < 0: raise ValueError( - "The regularization weight should be non-negative, but got " - f"{reg_weight}." + "The regularization weight should be non-negative, but got " f"{reg_weight}." ) self.reg_weight = reg_weight @@ -220,9 +198,9 @@ def forward(self, logits: Tensor, targets: Tensor) -> Tensor: """ probs = F.softmax(logits, dim=1) ce_loss = F.cross_entropy(logits, targets, reduction=self.reduction) - reg_loss = torch.log( - torch.tensor(logits.shape[-1], device=probs.device) - ) + (probs * torch.log(probs + self.eps)).sum(dim=-1) + reg_loss = torch.log(torch.tensor(logits.shape[-1], device=probs.device)) + ( + probs * torch.log(probs + self.eps) + ).sum(dim=-1) if self.reduction == "sum": return ce_loss + self.reg_weight * reg_loss.sum() if self.reduction == "mean": @@ -255,8 +233,7 @@ def __init__( self.reduction = reduction if reg_weight < 0: raise ValueError( - "The regularization weight should be non-negative, but got " - f"{reg_weight}." + "The regularization weight should be non-negative, but got " f"{reg_weight}." ) self.reg_weight = reg_weight @@ -310,8 +287,7 @@ def __init__( if gamma < 0: raise ValueError( - "The gamma term of the focal loss should be non-negative, but got " - f"{gamma}." + "The gamma term of the focal loss should be non-negative, but got " f"{gamma}." ) self.gamma = gamma @@ -370,12 +346,8 @@ def forward(self, preds: Tensor, targets: Tensor) -> Tensor: if self.label_smoothing == 0.0: return super().forward(preds, targets.type_as(preds)) targets = targets.float() - targets = ( - targets * (1 - self.label_smoothing) + self.label_smoothing / 2 - ) - loss = targets * F.logsigmoid(preds) + (1 - targets) * F.logsigmoid( - -preds - ) + targets = targets * (1 - self.label_smoothing) + self.label_smoothing / 2 + loss = targets * F.logsigmoid(preds) + (1 - targets) * F.logsigmoid(-preds) if self.weight is not None: loss = loss * self.weight if self.reduction == "mean": diff --git a/torch_uncertainty/losses/regression.py b/torch_uncertainty/losses/regression.py index 99b9b9fd..32cbe02f 100644 --- a/torch_uncertainty/losses/regression.py +++ b/torch_uncertainty/losses/regression.py @@ -8,9 +8,7 @@ class DistributionNLLLoss(nn.Module): - def __init__( - self, reduction: Literal["mean", "sum"] | None = "mean" - ) -> None: + def __init__(self, reduction: Literal["mean", "sum"] | None = "mean") -> None: """Negative Log-Likelihood loss using given distributions as inputs. Args: @@ -46,9 +44,7 @@ def forward( class DERLoss(DistributionNLLLoss): - def __init__( - self, reg_weight: float, reduction: str | None = "mean" - ) -> None: + def __init__(self, reg_weight: float, reduction: str | None = "mean") -> None: """The Deep Evidential Regression loss. This loss combines the negative log-likelihood loss of the normal @@ -71,8 +67,7 @@ def __init__( if reg_weight < 0: raise ValueError( - "The regularization weight should be non-negative, but got " - f"{reg_weight}." + "The regularization weight should be non-negative, but got " f"{reg_weight}." ) self.reg_weight = reg_weight @@ -98,9 +93,7 @@ def forward( class BetaNLL(nn.Module): - def __init__( - self, beta: float = 0.5, reduction: str | None = "mean" - ) -> None: + def __init__(self, beta: float = 0.5, reduction: str | None = "mean") -> None: """The Beta Negative Log-likelihood loss. Args: @@ -118,22 +111,15 @@ def __init__( super().__init__() if beta < 0 or beta > 1: - raise ValueError( - "The beta parameter should be in range [0, 1], but got " - f"{beta}." - ) + raise ValueError("The beta parameter should be in range [0, 1], but got " f"{beta}.") self.beta = beta self.nll_loss = nn.GaussianNLLLoss(reduction="none") if reduction not in ("none", "mean", "sum"): raise ValueError(f"{reduction} is not a valid value for reduction.") self.reduction = reduction - def forward( - self, mean: Tensor, targets: Tensor, variance: Tensor - ) -> Tensor: - loss = self.nll_loss(mean, targets, variance) * ( - variance.detach() ** self.beta - ) + def forward(self, mean: Tensor, targets: Tensor, variance: Tensor) -> Tensor: + loss = self.nll_loss(mean, targets, variance) * (variance.detach() ** self.beta) if self.reduction == "mean": return loss.mean() diff --git a/torch_uncertainty/metrics/classification/adaptive_calibration_error.py b/torch_uncertainty/metrics/classification/adaptive_calibration_error.py index c1e066d9..3bcc47fe 100644 --- a/torch_uncertainty/metrics/classification/adaptive_calibration_error.py +++ b/torch_uncertainty/metrics/classification/adaptive_calibration_error.py @@ -70,9 +70,7 @@ def _ace_compute( Tensor: Adaptive Calibration error scalar. """ with torch.no_grad(): - acc_bin, conf_bin, prop_bin = _equal_binning_bucketize( - confidences, accuracies, num_bins - ) + acc_bin, conf_bin, prop_bin = _equal_binning_bucketize(confidences, accuracies, num_bins) if norm == "l1": return torch.sum(torch.abs(acc_bin - conf_bin) * prop_bin) @@ -111,9 +109,7 @@ def __init__( ) -> None: super().__init__(**kwargs) if ignore_index is not None: # coverage: ignore - raise ValueError( - "ignore_index is not supported for multiclass tasks." - ) + raise ValueError("ignore_index is not supported for multiclass tasks.") if validate_args: _binary_calibration_error_arg_validation(n_bins, norm, ignore_index) @@ -134,9 +130,7 @@ def compute(self) -> Tensor: """Compute metric.""" confidences = dim_zero_cat(self.confidences) accuracies = dim_zero_cat(self.accuracies) - return _ace_compute( - confidences, accuracies, self.n_bins, norm=self.norm - ) + return _ace_compute(confidences, accuracies, self.n_bins, norm=self.norm) class MulticlassAdaptiveCalibrationError(Metric): @@ -160,14 +154,10 @@ def __init__( ) -> None: super().__init__(**kwargs) if ignore_index is not None: # coverage: ignore - raise ValueError( - "ignore_index is not supported for multiclass tasks." - ) + raise ValueError("ignore_index is not supported for multiclass tasks.") if validate_args: - _multiclass_calibration_error_arg_validation( - num_classes, n_bins, norm, ignore_index - ) + _multiclass_calibration_error_arg_validation(num_classes, n_bins, norm, ignore_index) self.n_bins = n_bins self.norm = norm @@ -185,9 +175,7 @@ def compute(self) -> Tensor: """Compute metric.""" confidences = dim_zero_cat(self.confidences) accuracies = dim_zero_cat(self.accuracies) - return _ace_compute( - confidences, accuracies, self.n_bins, norm=self.norm - ) + return _ace_compute(confidences, accuracies, self.n_bins, norm=self.norm) class AdaptiveCalibrationError: diff --git a/torch_uncertainty/metrics/classification/brier_score.py b/torch_uncertainty/metrics/classification/brier_score.py index af3f490a..d474c6aa 100644 --- a/torch_uncertainty/metrics/classification/brier_score.py +++ b/torch_uncertainty/metrics/classification/brier_score.py @@ -71,9 +71,7 @@ def __init__( self.num_estimators = 1 if self.reduction in ["mean", "sum"]: - self.add_state( - "values", default=torch.tensor(0.0), dist_reduce_fx="sum" - ) + self.add_state("values", default=torch.tensor(0.0), dist_reduce_fx="sum") else: self.add_state("values", default=[], dist_reduce_fx="cat") self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") @@ -108,9 +106,7 @@ def update(self, probs: Tensor, target: Tensor) -> None: target = target.gather(-1, indices.unsqueeze(-1)).squeeze(-1) brier_score = F.mse_loss(probs, target, reduction="none") else: - brier_score = F.mse_loss(probs, target, reduction="none").sum( - dim=-1 - ) + brier_score = F.mse_loss(probs, target, reduction="none").sum(dim=-1) if self.reduction is None or self.reduction == "none": self.values.append(brier_score) diff --git a/torch_uncertainty/metrics/classification/calibration_error.py b/torch_uncertainty/metrics/classification/calibration_error.py index 13fd9448..790335e5 100644 --- a/torch_uncertainty/metrics/classification/calibration_error.py +++ b/torch_uncertainty/metrics/classification/calibration_error.py @@ -190,9 +190,7 @@ def custom_plot(self) -> _PLOT_OUT_TYPE: ) with torch.no_grad(): - acc_bin, conf_bin, prop_bin = _binning_bucketize( - confidences, accuracies, bin_boundaries - ) + acc_bin, conf_bin, prop_bin = _binning_bucketize(confidences, accuracies, bin_boundaries) np_acc_bin = acc_bin.cpu().numpy() np_conf_bin = conf_bin.cpu().numpy() diff --git a/torch_uncertainty/metrics/classification/categorical_nll.py b/torch_uncertainty/metrics/classification/categorical_nll.py index 6a08f6d2..b1a78ef0 100644 --- a/torch_uncertainty/metrics/classification/categorical_nll.py +++ b/torch_uncertainty/metrics/classification/categorical_nll.py @@ -75,9 +75,7 @@ def update(self, probs: Tensor, target: Tensor) -> None: target (Tensor): Ground truth labels. """ if self.reduction is None or self.reduction == "none": - self.values.append( - F.nll_loss(torch.log(probs), target, reduction="none") - ) + self.values.append(F.nll_loss(torch.log(probs), target, reduction="none")) else: self.values += F.nll_loss(torch.log(probs), target, reduction="sum") self.total += target.size(0) diff --git a/torch_uncertainty/metrics/classification/entropy.py b/torch_uncertainty/metrics/classification/entropy.py index a7eae7f6..ef5df817 100644 --- a/torch_uncertainty/metrics/classification/entropy.py +++ b/torch_uncertainty/metrics/classification/entropy.py @@ -55,9 +55,7 @@ def __init__( self.reduction = reduction if self.reduction in ["mean", "sum"]: - self.add_state( - "values", default=torch.tensor(0.0), dist_reduce_fx="sum" - ) + self.add_state("values", default=torch.tensor(0.0), dist_reduce_fx="sum") else: self.add_state("values", default=[], dist_reduce_fx="cat") self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") diff --git a/torch_uncertainty/metrics/classification/fpr.py b/torch_uncertainty/metrics/classification/fpr.py index 214daded..53e3e779 100644 --- a/torch_uncertainty/metrics/classification/fpr.py +++ b/torch_uncertainty/metrics/classification/fpr.py @@ -28,9 +28,7 @@ def __init__(self, recall_level: float, pos_label: int, **kwargs) -> None: super().__init__(**kwargs) if recall_level < 0 or recall_level > 1: - raise ValueError( - f"Recall level must be between 0 and 1. Got {recall_level}." - ) + raise ValueError(f"Recall level must be between 0 and 1. Got {recall_level}.") self.recall_level = recall_level self.pos_label = pos_label self.add_state("conf", [], dist_reduce_fx="cat") @@ -77,17 +75,13 @@ def compute(self) -> Tensor: threshold_idxs = torch.cat( [ distinct_value_indices, - torch.tensor( - [labels.shape[0] - 1], dtype=torch.long, device=self.device - ), + torch.tensor([labels.shape[0] - 1], dtype=torch.long, device=self.device), ] ) # accumulate the true positives with decreasing threshold true_pos = torch.cumsum(labels, dim=0)[threshold_idxs] - false_pos = ( - 1 + threshold_idxs - true_pos - ) # add one because of zero-based indexing + false_pos = 1 + threshold_idxs - true_pos # add one because of zero-based indexing # check that there is at least one OOD example if true_pos[-1] == 0: diff --git a/torch_uncertainty/metrics/classification/grouping_loss.py b/torch_uncertainty/metrics/classification/grouping_loss.py index 0a53ac0d..137ee264 100644 --- a/torch_uncertainty/metrics/classification/grouping_loss.py +++ b/torch_uncertainty/metrics/classification/grouping_loss.py @@ -17,9 +17,7 @@ class GLEstimator(GLEstimatorBase): - def fit( - self, probs: Tensor, targets: Tensor, features: Tensor - ) -> "GLEstimator": + def fit(self, probs: Tensor, targets: Tensor, features: Tensor) -> "GLEstimator": probs = probs.detach().cpu().numpy() features = features.detach().cpu().numpy() targets = (targets * 1).detach().cpu().numpy() diff --git a/torch_uncertainty/metrics/classification/mutual_information.py b/torch_uncertainty/metrics/classification/mutual_information.py index 0c6c738d..c1224eec 100644 --- a/torch_uncertainty/metrics/classification/mutual_information.py +++ b/torch_uncertainty/metrics/classification/mutual_information.py @@ -65,9 +65,7 @@ def __init__( self.reduction = reduction if self.reduction in ["mean", "sum"]: - self.add_state( - "values", default=torch.tensor(0.0), dist_reduce_fx="sum" - ) + self.add_state("values", default=torch.tensor(0.0), dist_reduce_fx="sum") else: self.add_state("values", default=[], dist_reduce_fx="cat") self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") diff --git a/torch_uncertainty/metrics/classification/risk_coverage.py b/torch_uncertainty/metrics/classification/risk_coverage.py index 33298abe..6f76ef1f 100644 --- a/torch_uncertainty/metrics/classification/risk_coverage.py +++ b/torch_uncertainty/metrics/classification/risk_coverage.py @@ -400,10 +400,6 @@ def __init__(self, **kwargs) -> None: def _risk_coverage_checks(threshold: float) -> None: if not isinstance(threshold, float): - raise TypeError( - f"Expected threshold to be of type float, but got {type(threshold)}" - ) + raise TypeError(f"Expected threshold to be of type float, but got {type(threshold)}") if threshold < 0 or threshold > 1: - raise ValueError( - f"Threshold should be in the range [0, 1], but got {threshold}." - ) + raise ValueError(f"Threshold should be in the range [0, 1], but got {threshold}.") diff --git a/torch_uncertainty/metrics/classification/variation_ratio.py b/torch_uncertainty/metrics/classification/variation_ratio.py index a4e7609e..ff1ffd73 100644 --- a/torch_uncertainty/metrics/classification/variation_ratio.py +++ b/torch_uncertainty/metrics/classification/variation_ratio.py @@ -64,10 +64,7 @@ def compute(self) -> Tensor: max_classes_per_est = probs_per_est.argmax(dim=-1) variation_ratio = ( 1 - - torch.sum( - max_classes_per_est == max_classes.unsqueeze(1), dim=-1 - ) - / n_estimators + - torch.sum(max_classes_per_est == max_classes.unsqueeze(1), dim=-1) / n_estimators ) if self.reduction == "mean": diff --git a/torch_uncertainty/metrics/regression/inverse.py b/torch_uncertainty/metrics/regression/inverse.py index d661d9d2..b80212ff 100644 --- a/torch_uncertainty/metrics/regression/inverse.py +++ b/torch_uncertainty/metrics/regression/inverse.py @@ -62,9 +62,7 @@ def __init__( def update(self, preds: Tensor, target: Tensor) -> None: """Update state with predictions and targets.""" - super().update( - 1 / (preds * self.unit_factor), 1 / (target * self.unit_factor) - ) + super().update(1 / (preds * self.unit_factor), 1 / (target * self.unit_factor)) class MeanAbsoluteErrorInverse(MeanAbsoluteError): @@ -101,6 +99,4 @@ def __init__(self, unit: str = "km", **kwargs) -> None: def update(self, preds: Tensor, target: Tensor) -> None: """Update state with predictions and targets.""" - super().update( - 1 / (preds * self.unit_factor), 1 / (target * self.unit_factor) - ) + super().update(1 / (preds * self.unit_factor), 1 / (target * self.unit_factor)) diff --git a/torch_uncertainty/metrics/regression/log10.py b/torch_uncertainty/metrics/regression/log10.py index 2885da79..526eb4f5 100644 --- a/torch_uncertainty/metrics/regression/log10.py +++ b/torch_uncertainty/metrics/regression/log10.py @@ -21,9 +21,7 @@ def __init__(self, **kwargs) -> None: `_. """ super().__init__(**kwargs) - self.add_state( - "values", default=torch.tensor(0.0), dist_reduce_fx="sum" - ) + self.add_state("values", default=torch.tensor(0.0), dist_reduce_fx="sum") self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") def update(self, pred: Tensor, target: Tensor) -> None: diff --git a/torch_uncertainty/metrics/regression/relative_error.py b/torch_uncertainty/metrics/regression/relative_error.py index 9362013a..2988015f 100644 --- a/torch_uncertainty/metrics/regression/relative_error.py +++ b/torch_uncertainty/metrics/regression/relative_error.py @@ -41,9 +41,7 @@ def update(self, pred: Tensor, target: Tensor) -> None: class MeanGTRelativeSquaredError(MeanSquaredError): - def __init__( - self, squared: bool = True, num_outputs: int = 1, **kwargs - ) -> None: + def __init__(self, squared: bool = True, num_outputs: int = 1, **kwargs) -> None: r"""Compute mean squared error relative to the Ground Truth (MSErel or SRE). .. math:: \text{MSErel} = \frac{1}{N}\sum_i^N \frac{(y_i - \hat{y_i})^2}{y_i} diff --git a/torch_uncertainty/metrics/regression/silog.py b/torch_uncertainty/metrics/regression/silog.py index b7c4cf0d..9db2873a 100644 --- a/torch_uncertainty/metrics/regression/silog.py +++ b/torch_uncertainty/metrics/regression/silog.py @@ -7,9 +7,7 @@ class SILog(Metric): - def __init__( - self, sqrt: bool = False, lmbda: float = 1.0, **kwargs: Any - ) -> None: + def __init__(self, sqrt: bool = False, lmbda: float = 1.0, **kwargs: Any) -> None: r"""The Scale-Invariant Logarithmic Loss metric. .. math:: \text{SILog} = \frac{1}{N} \sum_{i=1}^{N} \left(\log(y_i) - \log(\hat{y_i})\right)^2 - \left(\frac{1}{N} \sum_{i=1}^{N} \log(y_i) \right)^2, @@ -60,9 +58,7 @@ def compute(self) -> Tensor: """Compute the Scale-Invariant Logarithmic Loss.""" log_dists = dim_zero_cat(self.log_dists) sq_log_dists = dim_zero_cat(self.sq_log_dists) - out = sq_log_dists / self.total - self.lmbda * log_dists**2 / ( - self.total * self.total - ) + out = sq_log_dists / self.total - self.lmbda * log_dists**2 / (self.total * self.total) if self.sqrt: return torch.sqrt(out) return out diff --git a/torch_uncertainty/metrics/regression/threshold_accuracy.py b/torch_uncertainty/metrics/regression/threshold_accuracy.py index 29d5ce5b..7a36e758 100644 --- a/torch_uncertainty/metrics/regression/threshold_accuracy.py +++ b/torch_uncertainty/metrics/regression/threshold_accuracy.py @@ -17,25 +17,17 @@ def __init__(self, power: int, lmbda: float = 1.25, **kwargs) -> None: """ super().__init__(**kwargs) if power < 0: - raise ValueError( - f"Power must be greater than or equal to 0. Got {power}." - ) + raise ValueError(f"Power must be greater than or equal to 0. Got {power}.") self.power = power if lmbda < 1: - raise ValueError( - f"Lambda must be greater than or equal to 1. Got {lmbda}." - ) + raise ValueError(f"Lambda must be greater than or equal to 1. Got {lmbda}.") self.lmbda = lmbda - self.add_state( - "values", default=torch.tensor(0.0), dist_reduce_fx="sum" - ) + self.add_state("values", default=torch.tensor(0.0), dist_reduce_fx="sum") self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") def update(self, preds: Tensor, target: Tensor) -> None: """Update state with predictions and targets.""" - self.values += torch.sum( - torch.max(preds / target, target / preds) < self.lmbda**self.power - ) + self.values += torch.sum(torch.max(preds / target, target / preds) < self.lmbda**self.power) self.total += target.size(0) def compute(self) -> Tensor: diff --git a/torch_uncertainty/models/depth/bts.py b/torch_uncertainty/models/depth/bts.py index bead997f..64bae545 100644 --- a/torch_uncertainty/models/depth/bts.py +++ b/torch_uncertainty/models/depth/bts.py @@ -62,9 +62,7 @@ def __init__( self.norm_first = norm_first if norm_first: - self.first_norm = nn.BatchNorm2d( - in_channels, momentum=norm_momentum, **factory_kwargs - ) + self.first_norm = nn.BatchNorm2d(in_channels, momentum=norm_momentum, **factory_kwargs) self.conv1 = nn.Conv2d( in_channels=in_channels, @@ -75,9 +73,7 @@ def __init__( padding=0, **factory_kwargs, ) - self.norm = nn.BatchNorm2d( - out_channels * 2, momentum=norm_momentum, **factory_kwargs - ) + self.norm = nn.BatchNorm2d(out_channels * 2, momentum=norm_momentum, **factory_kwargs) self.conv2 = nn.Conv2d( in_channels=out_channels * 2, out_channels=out_channels, @@ -217,12 +213,8 @@ def forward(self, x: Tensor) -> Tensor: class LocalPlanarGuidance(nn.Module): def __init__(self, up_ratio: int) -> None: super().__init__() - self.register_buffer( - "u", torch.arange(up_ratio).reshape([1, 1, up_ratio]) - ) - self.register_buffer( - "v", torch.arange(up_ratio).reshape([1, up_ratio, 1]) - ) + self.register_buffer("u", torch.arange(up_ratio).reshape([1, 1, up_ratio])) + self.register_buffer("v", torch.arange(up_ratio).reshape([1, up_ratio, 1])) self.up_ratio = up_ratio def forward(self, x: Tensor) -> Tensor: @@ -245,9 +237,7 @@ def forward(self, x: Tensor) -> Tensor: v = (v - (self.up_ratio - 1) * 0.5) / self.up_ratio return x_expanded[:, 3, :, :] / ( - x_expanded[:, 0, :, :] * u - + x_expanded[:, 1, :, :] * v - + x_expanded[:, 2, :, :] + x_expanded[:, 0, :, :] * u + x_expanded[:, 1, :, :] * v + x_expanded[:, 2, :, :] ) @@ -281,15 +271,11 @@ def __init__(self, backbone_name: str, pretrained: bool) -> None: ) elif backbone_name == "resnext50": model = tv_models.resnext50_32x4d( - weights=ResNeXt50_32X4D_Weights.IMAGENET1K_V2 - if pretrained - else None + weights=ResNeXt50_32X4D_Weights.IMAGENET1K_V2 if pretrained else None ) else: # backbone_name == "resnext101": model = tv_models.resnext101_32x8d( - weights=ResNeXt101_32X8D_Weights.IMAGENET1K_V2 - if pretrained - else None + weights=ResNeXt101_32X8D_Weights.IMAGENET1K_V2 if pretrained else None ) if "res" in backbone_name: # remove classification heads from ResNets feat_names = resnet_feat_names @@ -310,9 +296,7 @@ def __init__( super().__init__() self.max_depth = max_depth - self.upconv5 = UpConv2d( - in_channels=feat_out_channels[4], out_channels=num_features - ) + self.upconv5 = UpConv2d(in_channels=feat_out_channels[4], out_channels=num_features) self.bn5 = nn.BatchNorm2d(num_features, momentum=0.01, affine=True) self.conv5 = nn.Conv2d( @@ -323,9 +307,7 @@ def __init__( padding=1, bias=False, ) - self.upconv4 = UpConv2d( - in_channels=num_features, out_channels=num_features // 2 - ) + self.upconv4 = UpConv2d(in_channels=num_features, out_channels=num_features // 2) self.bn4 = nn.BatchNorm2d(num_features // 2, momentum=0.01, affine=True) self.conv4 = nn.Conv2d( num_features // 2 + feat_out_channels[2], @@ -335,9 +317,7 @@ def __init__( padding=1, bias=False, ) - self.bn4_2 = nn.BatchNorm2d( - num_features // 2, momentum=0.01, affine=True - ) + self.bn4_2 = nn.BatchNorm2d(num_features // 2, momentum=0.01, affine=True) self.daspp_3 = AtrousBlock2d( num_features // 2, @@ -381,14 +361,10 @@ def __init__( ), nn.ELU(), ) - self.reduc8x8 = Reduction1x1( - num_features // 4, num_features // 4, self.max_depth - ) + self.reduc8x8 = Reduction1x1(num_features // 4, num_features // 4, self.max_depth) self.lpg8x8 = LocalPlanarGuidance(8) - self.upconv3 = UpConv2d( - in_channels=num_features // 4, out_channels=num_features // 4 - ) + self.upconv3 = UpConv2d(in_channels=num_features // 4, out_channels=num_features // 4) self.bn3 = nn.BatchNorm2d(num_features // 4, momentum=0.01, affine=True) self.conv3 = nn.Conv2d( num_features // 4 + feat_out_channels[1] + 1, @@ -398,14 +374,10 @@ def __init__( padding=1, bias=False, ) - self.reduc4x4 = Reduction1x1( - num_features // 4, num_features // 8, self.max_depth - ) + self.reduc4x4 = Reduction1x1(num_features // 4, num_features // 8, self.max_depth) self.lpg4x4 = LocalPlanarGuidance(4) - self.upconv2 = UpConv2d( - in_channels=num_features // 4, out_channels=num_features // 8 - ) + self.upconv2 = UpConv2d(in_channels=num_features // 4, out_channels=num_features // 8) self.bn2 = nn.BatchNorm2d(num_features // 8, momentum=0.01, affine=True) self.conv2 = nn.Conv2d( num_features // 8 + feat_out_channels[0] + 1, @@ -416,33 +388,23 @@ def __init__( bias=False, ) - self.reduc2x2 = Reduction1x1( - num_features // 8, num_features // 16, self.max_depth - ) + self.reduc2x2 = Reduction1x1(num_features // 8, num_features // 16, self.max_depth) self.lpg2x2 = LocalPlanarGuidance(2) - self.upconv1 = UpConv2d( - in_channels=num_features // 8, out_channels=num_features // 16 - ) + self.upconv1 = UpConv2d(in_channels=num_features // 8, out_channels=num_features // 16) self.reduc1x1 = Reduction1x1( num_features // 16, num_features // 32, self.max_depth, is_final=True, ) - self.conv1 = nn.Conv2d( - num_features // 16 + 4, num_features // 16, 3, 1, 1, bias=False - ) + self.conv1 = nn.Conv2d(num_features // 16 + 4, num_features // 16, 3, 1, 1, bias=False) self.output_channels = 1 if dist_layer in (NormalLayer, LaplaceLayer): self.output_channels = 2 elif dist_layer != nn.Identity: - raise ValueError( - f"Unsupported distribution layer. Got {dist_layer}." - ) - self.depth = nn.Conv2d( - num_features // 16, self.output_channels, 3, 1, 1, bias=False - ) + raise ValueError(f"Unsupported distribution layer. Got {dist_layer}.") + self.depth = nn.Conv2d(num_features // 16, self.output_channels, 3, 1, 1, bias=False) self.dist_layer = dist_layer(dim=1) def feat_forward(self, features: list[Tensor]) -> Tensor: @@ -462,23 +424,17 @@ def feat_forward(self, features: list[Tensor]) -> Tensor: concat4_4 = torch.cat([concat4_3, daspp_12], dim=1) daspp_18 = self.daspp_18(concat4_4) daspp_24 = self.daspp_24(torch.cat([concat4_4, daspp_18], dim=1)) - concat4_daspp = torch.cat( - [iconv4, daspp_3, daspp_6, daspp_12, daspp_18, daspp_24], dim=1 - ) + concat4_daspp = torch.cat([iconv4, daspp_3, daspp_6, daspp_12, daspp_18, daspp_24], dim=1) daspp_feat = self.daspp_conv(concat4_daspp) reduc8x8 = self.reduc8x8(daspp_feat) plane_normal_8x8 = reduc8x8[:, :3, :, :] plane_normal_8x8 = F.normalize(plane_normal_8x8, p=2, dim=1) plane_dist_8x8 = reduc8x8[:, 3, :, :] - plane_eq_8x8 = torch.cat( - [plane_normal_8x8, plane_dist_8x8.unsqueeze(1)], 1 - ) + plane_eq_8x8 = torch.cat([plane_normal_8x8, plane_dist_8x8.unsqueeze(1)], 1) depth_8x8 = self.lpg8x8(plane_eq_8x8) depth_8x8_scaled = depth_8x8.unsqueeze(1) / self.max_depth - depth_8x8_scaled_ds = F.interpolate( - depth_8x8_scaled, scale_factor=0.25, mode="nearest" - ) + depth_8x8_scaled_ds = F.interpolate(depth_8x8_scaled, scale_factor=0.25, mode="nearest") upconv3 = self.bn3(self.upconv3(daspp_feat)) # H/4 concat3 = torch.cat([upconv3, features[1], depth_8x8_scaled_ds], dim=1) @@ -488,29 +444,19 @@ def feat_forward(self, features: list[Tensor]) -> Tensor: plane_normal_4x4 = reduc4x4[:, :3, :, :] plane_normal_4x4 = F.normalize(plane_normal_4x4, p=2, dim=1) plane_dist_4x4 = reduc4x4[:, 3, :, :] - plane_eq_4x4 = torch.cat( - [plane_normal_4x4, plane_dist_4x4.unsqueeze(1)], 1 - ) + plane_eq_4x4 = torch.cat([plane_normal_4x4, plane_dist_4x4.unsqueeze(1)], 1) depth_4x4 = self.lpg4x4(plane_eq_4x4) depth_4x4_scaled = depth_4x4.unsqueeze(1) / self.max_depth - depth_4x4_scaled_ds = F.interpolate( - depth_4x4_scaled, scale_factor=0.5, mode="nearest" - ) + depth_4x4_scaled_ds = F.interpolate(depth_4x4_scaled, scale_factor=0.5, mode="nearest") upconv2 = self.bn2(self.upconv2(iconv3)) # H/2 - iconv2 = F.elu( - self.conv2( - torch.cat([upconv2, features[0], depth_4x4_scaled_ds], dim=1) - ) - ) + iconv2 = F.elu(self.conv2(torch.cat([upconv2, features[0], depth_4x4_scaled_ds], dim=1))) reduc2x2 = self.reduc2x2(iconv2) plane_normal_2x2 = reduc2x2[:, :3, :, :] plane_normal_2x2 = F.normalize(plane_normal_2x2, p=2, dim=1) plane_dist_2x2 = reduc2x2[:, 3, :, :] - plane_eq_2x2 = torch.cat( - [plane_normal_2x2, plane_dist_2x2.unsqueeze(1)], 1 - ) + plane_eq_2x2 = torch.cat([plane_normal_2x2, plane_dist_2x2.unsqueeze(1)], 1) depth_2x2 = self.lpg2x2(plane_eq_2x2) depth_2x2_scaled = depth_2x2.unsqueeze(1) / self.max_depth @@ -583,9 +529,7 @@ def __init__( self.max_depth = max_depth self.backbone = BTSBackbone(backbone_name, pretrained_backbone) - self.decoder = BTSDecoder( - max_depth, self.backbone.feat_out_channels, bts_size, dist_layer - ) + self.decoder = BTSDecoder(max_depth, self.backbone.feat_out_channels, bts_size, dist_layer) # TODO: Handle focal def forward(self, x: Tensor, focal: float | None = None) -> Tensor: @@ -607,9 +551,7 @@ def _bts( ) -> _BTS: if backbone_name not in bts_backbones: raise ValueError(f"Unsupported backbone. Got {backbone_name}.") - return _BTS( - backbone_name, max_depth, bts_size, dist_layer, pretrained_backbone - ) + return _BTS(backbone_name, max_depth, bts_size, dist_layer, pretrained_backbone) def bts_resnet50( diff --git a/torch_uncertainty/models/lenet.py b/torch_uncertainty/models/lenet.py index 55a4c772..4b76c9e6 100644 --- a/torch_uncertainty/models/lenet.py +++ b/torch_uncertainty/models/lenet.py @@ -33,20 +33,14 @@ def __init__( if norm == nn.Identity: self.norm1 = norm() self.norm2 = norm() - elif norm == nn.BatchNorm2d or ( - isinstance(norm, partial) and norm.func == MCBatchNorm2d - ): + elif norm == nn.BatchNorm2d or (isinstance(norm, partial) and norm.func == MCBatchNorm2d): batchnorm = True else: - raise ValueError( - f"norm must be nn.Identity or nn.BatchNorm2d. Got {norm}." - ) + raise ValueError(f"norm must be nn.Identity or nn.BatchNorm2d. Got {norm}.") self.dropout_rate = dropout_rate - self.conv1 = conv2d_layer( - in_channels, 6, (5, 5), groups=groups, **layer_args - ) + self.conv1 = conv2d_layer(in_channels, 6, (5, 5), groups=groups, **layer_args) if batchnorm: self.norm1 = norm(6) self.conv_dropout = nn.Dropout2d(p=dropout_rate) diff --git a/torch_uncertainty/models/mlp.py b/torch_uncertainty/models/mlp.py index 0dd17547..0f519a26 100644 --- a/torch_uncertainty/models/mlp.py +++ b/torch_uncertainty/models/mlp.py @@ -56,21 +56,15 @@ def __init__( layers.append(layer(in_features, num_outputs, **layer_args)) else: if layer == PackedLinear: - layers.append( - layer(in_features, hidden_dims[0], first=True, **layer_args) - ) + layers.append(layer(in_features, hidden_dims[0], first=True, **layer_args)) else: layers.append(layer(in_features, hidden_dims[0], **layer_args)) for i in range(1, len(hidden_dims)): - layers.append( - layer(hidden_dims[i - 1], hidden_dims[i], **layer_args) - ) + layers.append(layer(hidden_dims[i - 1], hidden_dims[i], **layer_args)) if layer == PackedLinear: - layers.append( - layer(hidden_dims[-1], num_outputs, last=True, **layer_args) - ) + layers.append(layer(hidden_dims[-1], num_outputs, last=True, **layer_args)) else: layers.append(layer(hidden_dims[-1], num_outputs, **layer_args)) self.layers = layers diff --git a/torch_uncertainty/models/resnet/batched.py b/torch_uncertainty/models/resnet/batched.py index 933b0749..d7fa9b68 100644 --- a/torch_uncertainty/models/resnet/batched.py +++ b/torch_uncertainty/models/resnet/batched.py @@ -192,9 +192,7 @@ def __init__( self.bn1 = normalization_layer(block_planes) if style == "imagenet": - self.optional_pool = nn.MaxPool2d( - kernel_size=3, stride=2, padding=1 - ) + self.optional_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) else: self.optional_pool = nn.Identity() @@ -332,9 +330,7 @@ def batched_resnet( Returns: _BatchedResNet: A BatchEnsemble-style ResNet. """ - block = ( - _BasicBlock if arch in [18, 20, 34, 44, 56, 110, 1202] else _Bottleneck - ) + block = _BasicBlock if arch in [18, 20, 34, 44, 56, 110, 1202] else _Bottleneck in_planes = 16 if arch in [20, 44, 56, 110, 1202] else 64 return _BatchedResNet( block=block, diff --git a/torch_uncertainty/models/resnet/lpbnn.py b/torch_uncertainty/models/resnet/lpbnn.py index 6d22720f..c1641bf1 100644 --- a/torch_uncertainty/models/resnet/lpbnn.py +++ b/torch_uncertainty/models/resnet/lpbnn.py @@ -198,9 +198,7 @@ def __init__( self.bn1 = normalization_layer(block_planes) if style == "imagenet": - self.optional_pool = nn.MaxPool2d( - kernel_size=3, stride=2, padding=1 - ) + self.optional_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) else: self.optional_pool = nn.Identity() @@ -326,9 +324,7 @@ def lpbnn_resnet( groups: int = 1, style: Literal["imagenet", "cifar"] = "imagenet", ) -> _LPBNNResNet: - block = ( - _BasicBlock if arch in [18, 20, 34, 44, 56, 110, 1202] else _Bottleneck - ) + block = _BasicBlock if arch in [18, 20, 34, 44, 56, 110, 1202] else _Bottleneck in_planes = 16 if arch in [20, 44, 56, 110, 1202] else 64 return _LPBNNResNet( block=block, diff --git a/torch_uncertainty/models/resnet/masked.py b/torch_uncertainty/models/resnet/masked.py index 4af8ab2a..30d28d10 100644 --- a/torch_uncertainty/models/resnet/masked.py +++ b/torch_uncertainty/models/resnet/masked.py @@ -202,9 +202,7 @@ def __init__( self.bn1 = normalization_layer(block_planes) if style == "imagenet": - self.optional_pool = nn.MaxPool2d( - kernel_size=3, stride=2, padding=1 - ) + self.optional_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) else: self.optional_pool = nn.Identity() @@ -351,9 +349,7 @@ def masked_resnet( Returns: _MaskedResNet: A Masksembles-style ResNet. """ - block = ( - _BasicBlock if arch in [18, 20, 34, 44, 56, 110, 1202] else _Bottleneck - ) + block = _BasicBlock if arch in [18, 20, 34, 44, 56, 110, 1202] else _Bottleneck in_planes = 16 if arch in [20, 44, 56, 110, 1202] else 64 return _MaskedResNet( block=block, diff --git a/torch_uncertainty/models/resnet/mimo.py b/torch_uncertainty/models/resnet/mimo.py index 11ee2228..1ecb1d2a 100644 --- a/torch_uncertainty/models/resnet/mimo.py +++ b/torch_uncertainty/models/resnet/mimo.py @@ -62,9 +62,7 @@ def mimo_resnet( style: Literal["imagenet", "cifar"] = "imagenet", normalization_layer: type[nn.Module] = nn.BatchNorm2d, ) -> _MIMOResNet: - block = ( - _BasicBlock if arch in [18, 20, 34, 44, 56, 110, 1202] else _Bottleneck - ) + block = _BasicBlock if arch in [18, 20, 34, 44, 56, 110, 1202] else _Bottleneck in_planes = 16 if arch in [20, 44, 56, 110, 1202] else 64 return _MIMOResNet( block=block, diff --git a/torch_uncertainty/models/resnet/packed.py b/torch_uncertainty/models/resnet/packed.py index 217a07a3..cf16343f 100644 --- a/torch_uncertainty/models/resnet/packed.py +++ b/torch_uncertainty/models/resnet/packed.py @@ -251,9 +251,7 @@ def __init__( self.bn1 = normalization_layer(block_planes * alpha) if style == "imagenet": - self.optional_pool = nn.MaxPool2d( - kernel_size=3, stride=2, padding=1 - ) + self.optional_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) else: self.optional_pool = nn.Identity() @@ -369,9 +367,7 @@ def forward(self, x: Tensor) -> Tensor: out = self.layer3(out) out = self.layer4(out) - out = rearrange( - out, "e (m c) h w -> (m e) c h w", m=self.num_estimators - ) + out = rearrange(out, "e (m c) h w -> (m e) c h w", m=self.num_estimators) out = self.pool(out) out = self.final_dropout(self.flatten(out)) @@ -425,9 +421,7 @@ def packed_resnet( Returns: _PackedResNet: A Packed-Ensembles ResNet. """ - block = ( - _BasicBlock if arch in [18, 20, 34, 44, 56, 110, 1202] else _Bottleneck - ) + block = _BasicBlock if arch in [18, 20, 34, 44, 56, 110, 1202] else _Bottleneck in_planes = 16 if arch in [20, 44, 56, 110, 1202] else 64 net = _PackedResNet( block=block, @@ -450,8 +444,6 @@ def packed_resnet( raise ValueError("No pretrained weights for this configuration") state_dict, config = load_hf(weights) if not net.check_config(config): - raise ValueError( - "Pretrained weights do not match current configuration." - ) + raise ValueError("Pretrained weights do not match current configuration.") net.load_state_dict(state_dict) return net diff --git a/torch_uncertainty/models/resnet/std.py b/torch_uncertainty/models/resnet/std.py index 1b9ddccd..1cd096bc 100644 --- a/torch_uncertainty/models/resnet/std.py +++ b/torch_uncertainty/models/resnet/std.py @@ -237,9 +237,7 @@ def __init__( self.bn1 = normalization_layer(block_planes) if style == "imagenet": - self.optional_pool = nn.MaxPool2d( - kernel_size=3, stride=2, padding=1 - ) + self.optional_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) else: self.optional_pool = nn.Identity() @@ -379,9 +377,7 @@ def resnet( Returns: _ResNet: The ResNet model. """ - block = ( - _BasicBlock if arch in [18, 20, 34, 44, 56, 110, 1202] else _Bottleneck - ) + block = _BasicBlock if arch in [18, 20, 34, 44, 56, 110, 1202] else _Bottleneck in_planes = 16 if arch in [20, 44, 56, 110, 1202] else 64 return _ResNet( block=block, diff --git a/torch_uncertainty/models/segmentation/deeplab.py b/torch_uncertainty/models/segmentation/deeplab.py index 7029b4bf..13bbfe8b 100644 --- a/torch_uncertainty/models/segmentation/deeplab.py +++ b/torch_uncertainty/models/segmentation/deeplab.py @@ -108,9 +108,7 @@ def __init__(self, in_channels: int, out_channels: int) -> None: """ super().__init__() self.pool = nn.AdaptiveAvgPool2d(1) - self.conv = nn.Conv2d( - in_channels, out_channels, kernel_size=1, bias=False - ) + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False) self.bn = nn.BatchNorm2d(out_channels) def forward(self, x: Tensor) -> Tensor: @@ -147,16 +145,13 @@ def __init__( ) ) modules += [ - InnerConv(in_channels, out_channels, dilation, separable) - for dilation in atrous_rates + InnerConv(in_channels, out_channels, dilation, separable) for dilation in atrous_rates ] modules.append(InnerPooling(in_channels, out_channels)) self.convs = nn.ModuleList(modules) self.projection = nn.Sequential( - nn.Conv2d( - 5 * out_channels, out_channels, kernel_size=1, bias=False - ), + nn.Conv2d(5 * out_channels, out_channels, kernel_size=1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), nn.Dropout(dropout_rate), @@ -264,9 +259,7 @@ def __init__( nn.BatchNorm2d(48), nn.ReLU(inplace=True), ) - self.atrous_spatial_pyramid_pool = ASPP( - in_channels, aspp_dilate, separable, dropout_rate - ) + self.atrous_spatial_pyramid_pool = ASPP(in_channels, aspp_dilate, separable, dropout_rate) if separable: self.conv = SeparableConv2d(304, 256, 3, padding=1, bias=False) else: @@ -283,9 +276,7 @@ def forward(self, features: list[Tensor]) -> Tensor: mode="bilinear", align_corners=False, ) - output_features = torch.cat( - [low_level_features, output_features], dim=1 - ) + output_features = torch.cat([low_level_features, output_features], dim=1) out = F.relu(self.bn(self.conv(output_features))) return self.classifier(out) @@ -328,13 +319,9 @@ def __init__( elif output_stride == 8: dilations = [12, 24, 36] else: - raise ValueError( - f"output_stride: {output_stride} is not supported." - ) + raise ValueError(f"output_stride: {output_stride} is not supported.") - self.backbone = DeepLabV3Backbone( - backbone_name, style, pretrained_backbone, norm_momentum - ) + self.backbone = DeepLabV3Backbone(backbone_name, style, pretrained_backbone, norm_momentum) if style == "v3": self.decoder = DeepLabV3Decoder( in_channels=2048, diff --git a/torch_uncertainty/models/segmentation/segformer.py b/torch_uncertainty/models/segmentation/segformer.py index cf71fce8..cf8f35d6 100644 --- a/torch_uncertainty/models/segmentation/segformer.py +++ b/torch_uncertainty/models/segmentation/segformer.py @@ -70,9 +70,7 @@ def __init__( sr_ratio: int = 1, ): super().__init__() - assert ( - dim % num_heads == 0 - ), f"dim {dim} should be divided by num_heads {num_heads}." + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." self.dim = dim self.num_heads = num_heads @@ -107,11 +105,7 @@ def _init_weights(self, m): def forward(self, x: Tensor, h: int, w: int): b, n, c = x.shape - q = ( - self.q(x) - .reshape(b, n, self.num_heads, c // self.num_heads) - .permute(0, 2, 1, 3) - ) + q = self.q(x).reshape(b, n, self.num_heads, c // self.num_heads).permute(0, 2, 1, 3) if self.sr_ratio > 1: x_ = x.permute(0, 2, 1).reshape(b, c, h, w) @@ -167,9 +161,7 @@ def __init__( ) # NOTE: drop path for stochastic depth, we shall see if this is better # than dropout here - self.drop_path = ( - DropPath(drop_path) if drop_path > 0.0 else nn.Identity() - ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = MLP( @@ -202,9 +194,7 @@ def forward(self, x, h, w): class OverlapPatchEmbed(nn.Module): """Image to Patch Embedding.""" - def __init__( - self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768 - ): + def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) @@ -559,38 +549,16 @@ def forward(self, inputs: Tensor) -> Tensor: n, _, _, _ = c4.shape - _c4 = ( - self.linear_c4(c4) - .permute(0, 2, 1) - .reshape(n, -1, c4.shape[2], c4.shape[3]) - ) - _c4 = resize( - _c4, size=c1.size()[2:], mode="bilinear", align_corners=False - ) + _c4 = self.linear_c4(c4).permute(0, 2, 1).reshape(n, -1, c4.shape[2], c4.shape[3]) + _c4 = resize(_c4, size=c1.size()[2:], mode="bilinear", align_corners=False) - _c3 = ( - self.linear_c3(c3) - .permute(0, 2, 1) - .reshape(n, -1, c3.shape[2], c3.shape[3]) - ) - _c3 = resize( - _c3, size=c1.size()[2:], mode="bilinear", align_corners=False - ) + _c3 = self.linear_c3(c3).permute(0, 2, 1).reshape(n, -1, c3.shape[2], c3.shape[3]) + _c3 = resize(_c3, size=c1.size()[2:], mode="bilinear", align_corners=False) - _c2 = ( - self.linear_c2(c2) - .permute(0, 2, 1) - .reshape(n, -1, c2.shape[2], c2.shape[3]) - ) - _c2 = resize( - _c2, size=c1.size()[2:], mode="bilinear", align_corners=False - ) + _c2 = self.linear_c2(c2).permute(0, 2, 1).reshape(n, -1, c2.shape[2], c2.shape[3]) + _c2 = resize(_c2, size=c1.size()[2:], mode="bilinear", align_corners=False) - _c1 = ( - self.linear_c1(c1) - .permute(0, 2, 1) - .reshape(n, -1, c1.shape[2], c1.shape[3]) - ) + _c1 = self.linear_c1(c1).permute(0, 2, 1).reshape(n, -1, c1.shape[2], c1.shape[3]) _c = self.fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1)) diff --git a/torch_uncertainty/models/vgg/base.py b/torch_uncertainty/models/vgg/base.py index 7a633b5b..96b48f06 100644 --- a/torch_uncertainty/models/vgg/base.py +++ b/torch_uncertainty/models/vgg/base.py @@ -46,9 +46,7 @@ def __init__( kernel_surface = 1 if self.linear_layer == PackedLinear: - last_linear = linear_layer( - 4096, num_classes, last=True, **model_kwargs - ) + last_linear = linear_layer(4096, num_classes, last=True, **model_kwargs) else: last_linear = linear_layer(4096, num_classes, **model_kwargs) @@ -69,9 +67,7 @@ def __init__( def _init_weights(self) -> None: for m in self.modules(): if isinstance(m, nn.Conv2d | PackedConv2d): - nn.init.kaiming_normal_( - m.weight, mode="fan_out", nonlinearity="relu" - ) + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") if m.bias is not None: # coverage: ignore nn.init.constant_(m.bias, 0) elif isinstance(m, nn.BatchNorm2d): # coverage: ignore diff --git a/torch_uncertainty/models/wideresnet/batched.py b/torch_uncertainty/models/wideresnet/batched.py index 120dc267..779bb6e3 100644 --- a/torch_uncertainty/models/wideresnet/batched.py +++ b/torch_uncertainty/models/wideresnet/batched.py @@ -128,9 +128,7 @@ def __init__( self.bn1 = normalization_layer(num_stages[0]) if style == "imagenet": - self.optional_pool = nn.MaxPool2d( - kernel_size=3, stride=2, padding=1 - ) + self.optional_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) else: self.optional_pool = nn.Identity() diff --git a/torch_uncertainty/models/wideresnet/masked.py b/torch_uncertainty/models/wideresnet/masked.py index 230441dd..4661d6d3 100644 --- a/torch_uncertainty/models/wideresnet/masked.py +++ b/torch_uncertainty/models/wideresnet/masked.py @@ -131,9 +131,7 @@ def __init__( self.bn1 = normalization_layer(num_stages[0]) if style == "imagenet": - self.optional_pool = nn.MaxPool2d( - kernel_size=3, stride=2, padding=1 - ) + self.optional_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) else: self.optional_pool = nn.Identity() @@ -181,9 +179,7 @@ def __init__( self.pool = nn.AdaptiveAvgPool2d(output_size=1) self.flatten = nn.Flatten(1) - self.linear = MaskedLinear( - num_stages[3], num_classes, num_estimators, scale=scale - ) + self.linear = MaskedLinear(num_stages[3], num_classes, num_estimators, scale=scale) def _wide_layer( self, diff --git a/torch_uncertainty/models/wideresnet/mimo.py b/torch_uncertainty/models/wideresnet/mimo.py index 3e6d9991..974da680 100644 --- a/torch_uncertainty/models/wideresnet/mimo.py +++ b/torch_uncertainty/models/wideresnet/mimo.py @@ -46,9 +46,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if not self.training: x = x.repeat(self.num_estimators, 1, 1, 1) out = rearrange(x, "(m b) c h w -> b (m c) h w", m=self.num_estimators) - return rearrange( - super().forward(out), "b (m d) -> (m b) d", m=self.num_estimators - ) + return rearrange(super().forward(out), "b (m d) -> (m b) d", m=self.num_estimators) def mimo_wideresnet28x10( diff --git a/torch_uncertainty/models/wideresnet/packed.py b/torch_uncertainty/models/wideresnet/packed.py index 25195925..d1ffbd21 100644 --- a/torch_uncertainty/models/wideresnet/packed.py +++ b/torch_uncertainty/models/wideresnet/packed.py @@ -144,9 +144,7 @@ def __init__( self.bn1 = normalization_layer(num_stages[0] * alpha) if style == "imagenet": - self.optional_pool = nn.MaxPool2d( - kernel_size=3, stride=2, padding=1 - ) + self.optional_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) else: self.optional_pool = nn.Identity() @@ -248,9 +246,7 @@ def feats_forward(self, x: Tensor) -> Tensor: out = self.layer1(out) out = self.layer2(out) out = self.layer3(out) - out = rearrange( - out, "e (m c) h w -> (m e) c h w", m=self.num_estimators - ) + out = rearrange(out, "e (m c) h w -> (m e) c h w", m=self.num_estimators) out = self.pool(out) return self.final_dropout(self.flatten(out)) diff --git a/torch_uncertainty/models/wideresnet/std.py b/torch_uncertainty/models/wideresnet/std.py index be4ca9b6..fb5cc67e 100644 --- a/torch_uncertainty/models/wideresnet/std.py +++ b/torch_uncertainty/models/wideresnet/std.py @@ -119,9 +119,7 @@ def __init__( self.bn1 = normalization_layer(num_stages[0]) if style == "imagenet": - self.optional_pool = nn.MaxPool2d( - kernel_size=3, stride=2, padding=1 - ) + self.optional_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) else: self.optional_pool = nn.Identity() diff --git a/torch_uncertainty/models/wrappers/checkpoint_ensemble.py b/torch_uncertainty/models/wrappers/checkpoint_ensemble.py index f2d4d869..ced789a9 100644 --- a/torch_uncertainty/models/wrappers/checkpoint_ensemble.py +++ b/torch_uncertainty/models/wrappers/checkpoint_ensemble.py @@ -59,9 +59,7 @@ def eval_forward(self, x: torch.Tensor) -> torch.Tensor: """ if not len(self.saved_models): return self.core_model.forward(x) - preds = torch.cat( - [model.forward(x) for model in self.saved_models], dim=0 - ) + preds = torch.cat([model.forward(x) for model in self.saved_models], dim=0) if self.use_final_checkpoint: model_forward = self.core_model.forward(x) preds = torch.cat([model_forward, preds], dim=0) diff --git a/torch_uncertainty/models/wrappers/deep_ensembles.py b/torch_uncertainty/models/wrappers/deep_ensembles.py index a72ae7c4..2a333e07 100644 --- a/torch_uncertainty/models/wrappers/deep_ensembles.py +++ b/torch_uncertainty/models/wrappers/deep_ensembles.py @@ -29,9 +29,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: where :math:`B` is the batch size, :math:`N` is the number of estimators, and :math:`C` is the number of classes. """ - return torch.cat( - [model.forward(x) for model in self.core_models], dim=0 - ) + return torch.cat([model.forward(x) for model in self.core_models], dim=0) class _RegDeepEnsembles(_DeepEnsembles): @@ -54,9 +52,7 @@ def forward(self, x: torch.Tensor) -> Distribution: Distribution: """ if self.probabilistic: - return cat_dist( - [model.forward(x) for model in self.core_models], dim=0 - ) + return cat_dist([model.forward(x) for model in self.core_models], dim=0) return super().forward(x) @@ -98,17 +94,11 @@ def deep_ensembles( """ if isinstance(models, list) and len(models) == 0: raise ValueError("Models must not be an empty list.") - if (isinstance(models, list) and len(models) == 1) or isinstance( - models, nn.Module - ): + if (isinstance(models, list) and len(models) == 1) or isinstance(models, nn.Module): if num_estimators is None: - raise ValueError( - "if models is a module, num_estimators must be specified." - ) + raise ValueError("if models is a module, num_estimators must be specified.") if num_estimators < 2: - raise ValueError( - f"num_estimators must be at least 2. Got {num_estimators}." - ) + raise ValueError(f"num_estimators must be at least 2. Got {num_estimators}.") if isinstance(models, list): models = models[0] @@ -121,21 +111,13 @@ def deep_ensembles( if hasattr(layer, "reset_parameters"): layer.reset_parameters() - elif ( - isinstance(models, list) - and len(models) > 1 - and num_estimators is not None - ): - raise ValueError( - "num_estimators must be None if you provided a non-singleton list." - ) + elif isinstance(models, list) and len(models) > 1 and num_estimators is not None: + raise ValueError("num_estimators must be None if you provided a non-singleton list.") if task in ("classification", "segmentation"): return _DeepEnsembles(models=models) if task in ("regression", "pixel_regression"): if probabilistic is None: - raise ValueError( - "probabilistic must be specified for regression models." - ) + raise ValueError("probabilistic must be specified for regression models.") return _RegDeepEnsembles(probabilistic=probabilistic, models=models) raise ValueError(f"Unknown task: {task}.") diff --git a/torch_uncertainty/models/wrappers/ema.py b/torch_uncertainty/models/wrappers/ema.py index 386fcca7..1e4c276c 100644 --- a/torch_uncertainty/models/wrappers/ema.py +++ b/torch_uncertainty/models/wrappers/ema.py @@ -33,9 +33,7 @@ def update_wrapper(self, epoch: int | None = None) -> None: self.core_model.parameters(), strict=False, ): - ema_param.data = ( - ema_param.data * self.momentum + param.data * self.remainder - ) + ema_param.data = ema_param.data * self.momentum + param.data * self.remainder def eval_forward(self, x: Tensor) -> Tensor: return self.ema_model.forward(x) @@ -48,6 +46,4 @@ def forward(self, x: Tensor) -> Tensor: def _ema_checks(momentum: float) -> None: if momentum < 0.0 or momentum >= 1.0: - raise ValueError( - f"`momentum` must be in the range [0, 1). Got {momentum}." - ) + raise ValueError(f"`momentum` must be in the range [0, 1). Got {momentum}.") diff --git a/torch_uncertainty/models/wrappers/mc_dropout.py b/torch_uncertainty/models/wrappers/mc_dropout.py index 6bd92ac0..ff6df387 100644 --- a/torch_uncertainty/models/wrappers/mc_dropout.py +++ b/torch_uncertainty/models/wrappers/mc_dropout.py @@ -88,9 +88,7 @@ def forward( x = x.repeat(self.num_estimators, 1, 1, 1) return self.core_model(x) # Else, for loop - return torch.cat( - [self.core_model(x) for _ in range(self.num_estimators)], dim=0 - ) + return torch.cat([self.core_model(x) for _ in range(self.num_estimators)], dim=0) def mc_dropout( @@ -118,9 +116,7 @@ def mc_dropout( ) -def _dropout_checks( - filtered_modules: list[nn.Module], num_estimators: int -) -> None: +def _dropout_checks(filtered_modules: list[nn.Module], num_estimators: int) -> None: if not filtered_modules: raise ValueError( "No dropout module found in the model. " @@ -128,10 +124,6 @@ def _dropout_checks( ) # Check that at least one module has > 0.0 dropout rate if not any(mod.p > 0.0 for mod in filtered_modules): - raise ValueError( - "At least one dropout module must have a dropout rate > 0.0." - ) + raise ValueError("At least one dropout module must have a dropout rate > 0.0.") if num_estimators <= 0: - raise ValueError( - "`num_estimators` must be strictly positive to use MC Dropout." - ) + raise ValueError("`num_estimators` must be strictly positive to use MC Dropout.") diff --git a/torch_uncertainty/models/wrappers/stochastic.py b/torch_uncertainty/models/wrappers/stochastic.py index 7f298a87..64ab8371 100644 --- a/torch_uncertainty/models/wrappers/stochastic.py +++ b/torch_uncertainty/models/wrappers/stochastic.py @@ -11,9 +11,7 @@ def __init__(self, model: nn.Module, num_samples: int) -> None: self.num_samples = num_samples def eval_forward(self, x: Tensor) -> Tensor: - return torch.cat( - [self.core_model.forward(x) for _ in range(self.num_samples)], dim=0 - ) + return torch.cat([self.core_model.forward(x) for _ in range(self.num_samples)], dim=0) def forward(self, x: Tensor) -> Tensor: if self.training: @@ -37,8 +35,7 @@ def sample(self, num_samples: int = 1) -> list[dict]: break # TODO: fix this model |= { - module_name + "." + key: val - for key, val in module.state_dict().items() + module_name + "." + key: val for key, val in module.state_dict().items() } return sampled_models diff --git a/torch_uncertainty/models/wrappers/swa.py b/torch_uncertainty/models/wrappers/swa.py index 27fbb20e..1239bfa2 100644 --- a/torch_uncertainty/models/wrappers/swa.py +++ b/torch_uncertainty/models/wrappers/swa.py @@ -42,10 +42,7 @@ def __init__( @torch.no_grad() def update_wrapper(self, epoch: int) -> None: - if ( - epoch >= self.cycle_start - and (epoch - self.cycle_start) % self.cycle_length == 0 - ): + if epoch >= self.cycle_start and (epoch - self.cycle_start) % self.cycle_length == 0: if self.swa_model is None: self.swa_model = copy.deepcopy(self.core_model) self.num_avgd_models = torch.tensor(1) @@ -55,9 +52,7 @@ def update_wrapper(self, epoch: int) -> None: self.core_model.parameters(), strict=False, ): - swa_param.data += (param.data - swa_param.data) / ( - self.num_avgd_models + 1 - ) + swa_param.data += (param.data - swa_param.data) / (self.num_avgd_models + 1) self.num_avgd_models += 1 self.need_bn_update = True @@ -73,18 +68,12 @@ def forward(self, x: Tensor) -> Tensor: def bn_update(self, loader: DataLoader, device) -> None: if self.need_bn_update and self.swa_model is not None: - torch.optim.swa_utils.update_bn( - loader, self.swa_model, device=device - ) + torch.optim.swa_utils.update_bn(loader, self.swa_model, device=device) self.need_bn_update = False def _swa_checks(cycle_start: int, cycle_length: int) -> None: if cycle_start < 0: - raise ValueError( - f"`cycle_start` must be non-negative. Got {cycle_start}." - ) + raise ValueError(f"`cycle_start` must be non-negative. Got {cycle_start}.") if cycle_length <= 0: - raise ValueError( - f"`cycle_length` must be strictly positive. Got {cycle_length}." - ) + raise ValueError(f"`cycle_length` must be strictly positive. Got {cycle_length}.") diff --git a/torch_uncertainty/models/wrappers/swag.py b/torch_uncertainty/models/wrappers/swag.py index e96bbe87..3132439d 100644 --- a/torch_uncertainty/models/wrappers/swag.py +++ b/torch_uncertainty/models/wrappers/swag.py @@ -94,9 +94,7 @@ def initialize_stats(self) -> None: if not self.diag_covariance: covariance_sqrt = torch.zeros((0, param.numel()), device="cpu") - self.swag_stats[self.prfx + name_p + "_covariance_sqrt"] = ( - covariance_sqrt - ) + self.swag_stats[self.prfx + name_p + "_covariance_sqrt"] = covariance_sqrt @torch.no_grad() def update_wrapper(self, epoch: int) -> None: @@ -109,10 +107,7 @@ def update_wrapper(self, epoch: int) -> None: Args: epoch (int): Current epoch. """ - if not ( - epoch > self.cycle_start - and (epoch - self.cycle_start) % self.cycle_length == 0 - ): + if not (epoch > self.cycle_start and (epoch - self.cycle_start) % self.cycle_length == 0): return for name_p, param in self.core_model.named_parameters(): @@ -120,9 +115,9 @@ def update_wrapper(self, epoch: int) -> None: squared_mean = self.swag_stats[self.prfx + name_p + "_sq_mean"] new_param = param.data.detach().cpu() - mean = mean * self.num_avgd_models / ( + mean = mean * self.num_avgd_models / (self.num_avgd_models + 1) + new_param / ( self.num_avgd_models + 1 - ) + new_param / (self.num_avgd_models + 1) + ) squared_mean = squared_mean * self.num_avgd_models / ( self.num_avgd_models + 1 ) + new_param**2 / (self.num_avgd_models + 1) @@ -131,22 +126,17 @@ def update_wrapper(self, epoch: int) -> None: self.swag_stats[self.prfx + name_p + "_sq_mean"] = squared_mean if not self.diag_covariance: - covariance_sqrt = self.swag_stats[ - self.prfx + name_p + "_covariance_sqrt" - ] + covariance_sqrt = self.swag_stats[self.prfx + name_p + "_covariance_sqrt"] dev = (new_param - mean).view(-1, 1).t() covariance_sqrt = torch.cat((covariance_sqrt, dev), dim=0) if self.num_avgd_models + 1 > self.max_num_models: covariance_sqrt = covariance_sqrt[1:, :] - self.swag_stats[self.prfx + name_p + "_covariance_sqrt"] = ( - covariance_sqrt - ) + self.swag_stats[self.prfx + name_p + "_covariance_sqrt"] = covariance_sqrt self.num_avgd_models += 1 self.samples = [ - self.sample(self.scale, self.diag_covariance) - for _ in range(self.num_estimators) + self.sample(self.scale, self.diag_covariance) for _ in range(self.num_estimators) ] self.need_bn_update = True self.fit = True @@ -189,17 +179,13 @@ def sample( if diag_covariance is None: diag_covariance = self.diag_covariance if not diag_covariance and self.diag_covariance: - raise ValueError( - "Cannot sample full rank from diagonal covariance matrix." - ) + raise ValueError("Cannot sample full rank from diagonal covariance matrix.") if not block: return self._fullrank_sample(scale, diag_covariance) raise NotImplementedError("Raise an issue if you need this feature.") - def _fullrank_sample( - self, scale: float, diagonal_covariance: bool - ) -> nn.Module: + def _fullrank_sample(self, scale: float, diagonal_covariance: bool) -> nn.Module: new_sample = copy.deepcopy(self.core_model) for name_p, param in new_sample.named_parameters(): @@ -207,17 +193,13 @@ def _fullrank_sample( sq_mean = self.swag_stats[self.prfx + name_p + "_sq_mean"] if not diagonal_covariance: - cov_mat_sqrt = self.swag_stats[ - self.prfx + name_p + "_covariance_sqrt" - ] + cov_mat_sqrt = self.swag_stats[self.prfx + name_p + "_covariance_sqrt"] var = torch.clamp(sq_mean - mean**2, self.var_clamp) var_sample = var.sqrt() * torch.randn_like(var, requires_grad=False) if not diagonal_covariance: - cov_sample = cov_mat_sqrt.t() @ torch.randn( - (cov_mat_sqrt.size(0),) - ) + cov_sample = cov_mat_sqrt.t() @ torch.randn((cov_mat_sqrt.size(0),)) cov_sample /= (self.max_num_models - 1) ** 0.5 var_sample += cov_sample.view_as(var_sample) @@ -230,9 +212,7 @@ def _save_to_state_dict(self, destination, prefix: str, keep_vars: bool): super()._save_to_state_dict(destination, prefix, keep_vars) destination |= self.swag_stats - def state_dict( - self, *args, destination=None, prefix="", keep_vars=False - ) -> Mapping: + def state_dict(self, *args, destination=None, prefix="", keep_vars=False) -> Mapping: """Add the SWAG statistics to the state dict.""" return self.swag_stats | super().state_dict( *args, destination=destination, prefix=prefix, keep_vars=keep_vars @@ -240,21 +220,16 @@ def state_dict( def _load_swag_stats(self, state_dict: Mapping): """Load the SWAG statistics from the state dict.""" - self.swag_stats = { - k: v for k, v in state_dict.items() if k in self.swag_stats - } + self.swag_stats = {k: v for k, v in state_dict.items() if k in self.swag_stats} for k in self.swag_stats: del state_dict[k] self.samples = [ - self.sample(self.scale, self.diag_covariance) - for _ in range(self.num_estimators) + self.sample(self.scale, self.diag_covariance) for _ in range(self.num_estimators) ] self.need_bn_update = True self.fit = True - def load_state_dict( - self, state_dict: Mapping, strict: bool = True, assign: bool = False - ): + def load_state_dict(self, state_dict: Mapping, strict: bool = True, assign: bool = False): self._load_swag_stats(state_dict) return super().load_state_dict(state_dict, strict, assign) @@ -269,8 +244,6 @@ def _swag_checks(scale: float, max_num_models: int, var_clamp: float) -> None: if scale < 0: raise ValueError(f"`scale` must be non-negative. Got {scale}.") if max_num_models < 0: - raise ValueError( - f"`max_num_models` must be non-negative. Got {max_num_models}." - ) + raise ValueError(f"`max_num_models` must be non-negative. Got {max_num_models}.") if var_clamp < 0: raise ValueError(f"`var_clamp` must be non-negative. Got {var_clamp}.") diff --git a/torch_uncertainty/optim_recipes.py b/torch_uncertainty/optim_recipes.py index 82b147c8..1479ce9c 100644 --- a/torch_uncertainty/optim_recipes.py +++ b/torch_uncertainty/optim_recipes.py @@ -197,9 +197,7 @@ def optim_imagenet_resnet50( } -def optim_imagenet_resnet50_a3( - model: nn.Module, effective_batch_size: int | None = None -) -> dict: +def optim_imagenet_resnet50_a3(model: nn.Module, effective_batch_size: int | None = None) -> dict: """Training procedure proposed in ResNet strikes back: An improved training procedure in timm. @@ -212,9 +210,7 @@ def optim_imagenet_resnet50_a3( dict: The optimizer and the scheduler for the training. """ if effective_batch_size is None: - logging.warning( - "Setting effective batch size to 2048 for steps computations !" - ) + logging.warning("Setting effective batch size to 2048 for steps computations !") effective_batch_size = 2048 optimizer = Lamb(model.parameters(), lr=0.008, weight_decay=0.02) @@ -350,8 +346,7 @@ def batch_ensemble_wrapper(model: nn.Module, optim_recipe: Callable) -> dict: ) param_core_tmp = list( filter( - lambda kv: (name_list[0] not in kv[0]) - and (name_list[1] not in kv[0]), + lambda kv: (name_list[0] not in kv[0]) and (name_list[1] not in kv[0]), model.named_parameters(), ) ) diff --git a/torch_uncertainty/post_processing/abnn.py b/torch_uncertainty/post_processing/abnn.py index 9dd4e79d..0ec24375 100644 --- a/torch_uncertainty/post_processing/abnn.py +++ b/torch_uncertainty/post_processing/abnn.py @@ -85,9 +85,7 @@ def __init__( self.weights = [] for _ in range(num_models): weight = torch.ones([num_classes]) - weight[torch.randperm(num_classes)[:num_rp_classes]] += ( - random_prior - 1 - ) + weight[torch.randperm(num_classes)[:num_rp_classes]] += random_prior - 1 self.weights.append(weight) def fit(self, dataset: Dataset) -> None: @@ -104,9 +102,7 @@ def fit(self, dataset: Dataset) -> None: ClassificationRoutine( num_classes=self.num_classes, model=mod, - loss=nn.CrossEntropyLoss( - weight=self.weights[i].to(device=self.device) - ), + loss=nn.CrossEntropyLoss(weight=self.weights[i].to(device=self.device)), optim_recipe=optim_abnn(mod, lr=self.base_lr), eval_ood=True, ) @@ -133,9 +129,7 @@ def fit(self, dataset: Dataset) -> None: for baseline in baselines: model = copy.deepcopy(source_model) model.load_state_dict(baseline.model.state_dict()) - final_models.extend( - [copy.deepcopy(model) for _ in range(self.num_samples)] - ) + final_models.extend([copy.deepcopy(model) for _ in range(self.num_samples)]) self.final_model = deep_ensembles(final_models) @@ -161,31 +155,21 @@ def _abnn_checks( batch_size, ) -> None: if random_prior < 0: - raise ValueError( - f"random_prior must be greater than 0. Got {random_prior}." - ) + raise ValueError(f"random_prior must be greater than 0. Got {random_prior}.") if batch_size < 1: - raise ValueError( - f"batch_size must be greater than 0. Got {batch_size}." - ) + raise ValueError(f"batch_size must be greater than 0. Got {batch_size}.") if max_epochs < 1: raise ValueError(f"epoch must be greater than 0. Got {max_epochs}.") if num_models < 1: - raise ValueError( - f"num_models must be greater than 0. Got {num_models}." - ) + raise ValueError(f"num_models must be greater than 0. Got {num_models}.") if num_samples < 1: - raise ValueError( - f"num_samples must be greater than 0. Got {num_samples}." - ) + raise ValueError(f"num_samples must be greater than 0. Got {num_samples}.") if alpha < 0: raise ValueError(f"alpha must be greater than 0. Got {alpha}.") if base_lr < 0: raise ValueError(f"base_lr must be greater than 0. Got {base_lr}.") if num_classes < 1: - raise ValueError( - f"num_classes must be greater than 0. Got {num_classes}." - ) + raise ValueError(f"num_classes must be greater than 0. Got {num_classes}.") def _replace_bn_layers(model: nn.Module, alpha: float) -> None: diff --git a/torch_uncertainty/post_processing/calibration/scaler.py b/torch_uncertainty/post_processing/calibration/scaler.py index 3dcf08e6..9a1a9ef6 100644 --- a/torch_uncertainty/post_processing/calibration/scaler.py +++ b/torch_uncertainty/post_processing/calibration/scaler.py @@ -62,9 +62,7 @@ def fit( """ logits_list = [] labels_list = [] - calibration_dl = DataLoader( - calibration_set, batch_size=32, shuffle=False, drop_last=False - ) + calibration_dl = DataLoader(calibration_set, batch_size=32, shuffle=False, drop_last=False) with torch.no_grad(): for inputs, labels in tqdm(calibration_dl, disable=not progress): logits = self.model(inputs.to(self.device)) @@ -73,9 +71,7 @@ def fit( all_logits = torch.cat(logits_list).detach().to(self.device) all_labels = torch.cat(labels_list).detach().to(self.device) - optimizer = optim.LBFGS( - self.temperature, lr=self.lr, max_iter=self.max_iter - ) + optimizer = optim.LBFGS(self.temperature, lr=self.lr, max_iter=self.max_iter) def calib_eval() -> float: optimizer.zero_grad() @@ -93,8 +89,7 @@ def calib_eval() -> float: def forward(self, inputs: Tensor) -> Tensor: if not self.trained: logging.error( - "TemperatureScaler has not been trained yet. Returning " - "manually tempered inputs." + "TemperatureScaler has not been trained yet. Returning " "manually tempered inputs." ) return self._scale(self.model(inputs)) diff --git a/torch_uncertainty/post_processing/calibration/temperature_scaler.py b/torch_uncertainty/post_processing/calibration/temperature_scaler.py index f334cbab..a9e938de 100644 --- a/torch_uncertainty/post_processing/calibration/temperature_scaler.py +++ b/torch_uncertainty/post_processing/calibration/temperature_scaler.py @@ -47,9 +47,7 @@ def set_temperature(self, val: float) -> None: if val <= 0: raise ValueError("Temperature value must be positive.") - self.temp = nn.Parameter( - torch.ones(1, device=self.device) * val, requires_grad=True - ) + self.temp = nn.Parameter(torch.ones(1, device=self.device) * val, requires_grad=True) def _scale(self, logits: Tensor) -> Tensor: """Scale the prediction with the optimal temperature. diff --git a/torch_uncertainty/post_processing/laplace.py b/torch_uncertainty/post_processing/laplace.py index 23589f8c..582a4bd8 100644 --- a/torch_uncertainty/post_processing/laplace.py +++ b/torch_uncertainty/post_processing/laplace.py @@ -22,9 +22,7 @@ def __init__( weight_subset="last_layer", hessian_struct="kron", pred_type: Literal["glm", "nn"] = "glm", - link_approx: Literal[ - "mc", "probit", "bridge", "bridge_norm" - ] = "probit", + link_approx: Literal["mc", "probit", "bridge", "bridge_norm"] = "probit", batch_size: int = 256, optimize_prior_precision: bool = True, ) -> None: @@ -90,6 +88,4 @@ def forward( self, x: Tensor, ) -> Tensor: - return self.la( - x, pred_type=self.pred_type, link_approx=self.link_approx - ) + return self.la(x, pred_type=self.pred_type, link_approx=self.link_approx) diff --git a/torch_uncertainty/post_processing/mc_batch_norm.py b/torch_uncertainty/post_processing/mc_batch_norm.py index b011a058..5357d954 100644 --- a/torch_uncertainty/post_processing/mc_batch_norm.py +++ b/torch_uncertainty/post_processing/mc_batch_norm.py @@ -49,17 +49,13 @@ def __init__( self._setup_model(model) def _setup_model(self, model): - _mcbn_checks( - model, self.num_estimators, self.mc_batch_size, self.convert - ) + _mcbn_checks(model, self.num_estimators, self.mc_batch_size, self.convert) self.model = deepcopy(model) # Is it necessary? self.model = self.model.eval() if self.convert: self._convert() if not has_mcbn(self.model): - raise ValueError( - "model does not contain any MCBatchNorm2d after conversion." - ) + raise ValueError("model does not contain any MCBatchNorm2d after conversion.") def set_model(self, model: nn.Module) -> None: self.model = model @@ -75,9 +71,7 @@ def fit(self, dataset: Dataset) -> None: This method is used to populate the MC BatchNorm layers. Use the training dataset. """ - self.dl = DataLoader( - dataset, batch_size=self.mc_batch_size, shuffle=True - ) + self.dl = DataLoader(dataset, batch_size=self.mc_batch_size, shuffle=True) self.counter = 0 self.reset_counters() self.set_accumulate(True) @@ -89,9 +83,7 @@ def fit(self, dataset: Dataset) -> None: self.set_accumulate(False) self.trained = True return - raise ValueError( - "The dataset is too small to populate the MC BatchNorm statistics." - ) + raise ValueError("The dataset is too small to populate the MC BatchNorm statistics.") def _est_forward(self, x: Tensor) -> Tensor: """Forward pass of a single estimator.""" @@ -106,13 +98,9 @@ def forward( if self.training: return self.model(x) if not self.trained: - raise RuntimeError( - "MCBatchNorm has not been trained. Call .fit() first." - ) + raise RuntimeError("MCBatchNorm has not been trained. Call .fit() first.") self.reset_counters() - return torch.cat( - [self._est_forward(x) for _ in range(self.num_estimators)], dim=0 - ) + return torch.cat([self._est_forward(x) for _ in range(self.num_estimators)], dim=0) def _convert(self) -> None: """Convert all BatchNorm2d layers to MCBatchNorm2d layers.""" @@ -176,15 +164,8 @@ def has_mcbn(model: nn.Module) -> bool: def _mcbn_checks(model, num_estimators, mc_batch_size, convert): if num_estimators < 1 or not isinstance(num_estimators, int): - raise ValueError( - f"num_estimators must be a positive integer, got {num_estimators}." - ) + raise ValueError(f"num_estimators must be a positive integer, got {num_estimators}.") if mc_batch_size < 1 or not isinstance(mc_batch_size, int): - raise ValueError( - f"mc_batch_size must be a positive integer, got {mc_batch_size}." - ) + raise ValueError(f"mc_batch_size must be a positive integer, got {mc_batch_size}.") if not convert and not has_mcbn(model): - raise ValueError( - "model does not contain any MCBatchNorm2d nor is not to be " - "converted." - ) + raise ValueError("model does not contain any MCBatchNorm2d nor is not to be " "converted.") diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index b0c37385..8d131fa8 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -72,9 +72,7 @@ def __init__( eval_ood: bool = False, eval_shift: bool = False, eval_grouping_loss: bool = False, - ood_criterion: Literal[ - "msp", "logit", "energy", "entropy", "mi", "vr" - ] = "msp", + ood_criterion: Literal["msp", "logit", "energy", "entropy", "mi", "vr"] = "msp", post_processing: PostProcessing | None = None, calibration_set: Literal["val", "test"] = "val", num_calibration_bins: int = 15, @@ -258,14 +256,10 @@ def _init_metrics(self) -> None: self.test_ood_ens_metrics = ens_metrics.clone(prefix="ood/ens_") if self.eval_shift: - self.test_shift_ens_metrics = ens_metrics.clone( - prefix="shift/ens_" - ) + self.test_shift_ens_metrics = ens_metrics.clone(prefix="shift/ens_") if self.eval_grouping_loss: - grouping_loss = MetricCollection( - {"cls/grouping_loss": GroupingLoss()} - ) + grouping_loss = MetricCollection({"cls/grouping_loss": GroupingLoss()}) self.val_grouping_loss = grouping_loss.clone(prefix="val/") self.test_grouping_loss = grouping_loss.clone(prefix="test/") @@ -317,9 +311,7 @@ def _init_mixup(self, mixup_params: dict | None) -> Callable: ) return Identity() - def _apply_mixup( - self, batch: tuple[Tensor, Tensor] - ) -> tuple[Tensor, Tensor]: + def _apply_mixup(self, batch: tuple[Tensor, Tensor]) -> tuple[Tensor, Tensor]: if not self.is_ensemble: if self.mixup_params["mixtype"] == "kernel_warping": if self.mixup_params["dist_sim"] == "emb": @@ -345,9 +337,7 @@ def on_validation_start(self) -> None: if self.needs_epoch_update and not self.trainer.sanity_checking: self.model.update_wrapper(self.current_epoch) if hasattr(self.model, "need_bn_update"): - self.model.bn_update( - self.trainer.train_dataloader, device=self.device - ) + self.model.bn_update(self.trainer.train_dataloader, device=self.device) def on_test_start(self) -> None: if self.post_processing is not None: @@ -364,9 +354,7 @@ def on_test_start(self) -> None: self.ood_logit_storage = [] if hasattr(self.model, "need_bn_update"): - self.model.bn_update( - self.trainer.train_dataloader, device=self.device - ) + self.model.bn_update(self.trainer.train_dataloader, device=self.device) def forward(self, inputs: Tensor, save_feats: bool = False) -> Tensor: """Forward pass of the model. @@ -390,9 +378,7 @@ def forward(self, inputs: Tensor, save_feats: bool = False) -> Tensor: logits = self.model(inputs) return logits - def training_step( - self, batch: tuple[Tensor, Tensor], batch_idx: int - ) -> STEP_OUTPUT: + def training_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> STEP_OUTPUT: batch = self._apply_mixup(batch) inputs, target = self.format_batch_fn(batch) @@ -414,9 +400,7 @@ def training_step( self.log("train_loss", loss, prog_bar=True, logger=True) return loss - def validation_step( - self, batch: tuple[Tensor, Tensor], batch_idx: int - ) -> None: + def validation_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> None: inputs, targets = batch logits = self.forward(inputs, save_feats=self.eval_grouping_loss) logits = rearrange(logits, "(m b) c -> b m c", b=targets.size(0)) @@ -441,14 +425,8 @@ def test_step( inputs, targets = batch logits = self.forward(inputs, save_feats=self.eval_grouping_loss) logits = rearrange(logits, "(n b) c -> b n c", b=targets.size(0)) - - if self.binary_cls: - probs_per_est = torch.sigmoid(logits) - else: - probs_per_est = F.softmax(logits, dim=-1) - + probs_per_est = torch.sigmoid(logits) if self.binary_cls else F.softmax(logits, dim=-1) probs = probs_per_est.mean(dim=1) - confs = probs.max(-1)[0] if self.ood_criterion == "logit": @@ -456,9 +434,7 @@ def test_step( elif self.ood_criterion == "energy": ood_scores = -logits.mean(dim=1).logsumexp(dim=-1) elif self.ood_criterion == "entropy": - ood_scores = ( - torch.special.entr(probs_per_est).sum(dim=-1).mean(dim=1) - ) + ood_scores = torch.special.entr(probs_per_est).sum(dim=-1).mean(dim=1) elif self.ood_criterion == "mi": mi_metric = MutualInformation(reduction="none") ood_scores = mi_metric(probs_per_est) @@ -477,9 +453,7 @@ def test_step( if self.eval_grouping_loss: self.test_grouping_loss.update(probs, targets, self.features) - self.log_dict( - self.test_cls_metrics, on_epoch=True, add_dataloader_idx=False - ) + self.log_dict(self.test_cls_metrics, on_epoch=True, add_dataloader_idx=False) self.test_id_entropy(probs) self.log( "test/cls/Entropy", @@ -492,9 +466,7 @@ def test_step( self.test_id_ens_metrics.update(probs_per_est) if self.eval_ood: - self.test_ood_metrics.update( - ood_scores, torch.zeros_like(targets) - ) + self.test_ood_metrics.update(ood_scores, torch.zeros_like(targets)) if self.id_logit_storage is not None: self.id_logit_storage.append(logits.detach().cpu()) @@ -548,9 +520,7 @@ def on_test_epoch_end(self) -> None: result_dict = self.test_cls_metrics.compute() # already logged - result_dict.update( - {"test/Entropy": self.test_id_entropy.compute()}, sync_dist=True - ) + result_dict.update({"test/Entropy": self.test_id_entropy.compute()}, sync_dist=True) if self.post_processing is not None: tmp_metrics = self.post_cls_metrics.compute() @@ -639,9 +609,7 @@ def on_test_epoch_end(self) -> None: "Histogram of the likelihoods", )[0] self.logger.experiment.add_figure("Logit Histogram", logits_fig) - self.logger.experiment.add_figure( - "Likelihood Histogram", probs_fig - ) + self.logger.experiment.add_figure("Likelihood Histogram", probs_fig) if self.save_in_csv: self.save_results_to_csv(result_dict) @@ -680,8 +648,7 @@ def _classification_routine_checks( if not is_ensemble and ood_criterion in ["mi", "vr"]: raise ValueError( - "You cannot use mutual information or variation ratio with a single" - " model." + "You cannot use mutual information or variation ratio with a single" " model." ) if is_ensemble and eval_grouping_loss: @@ -691,14 +658,12 @@ def _classification_routine_checks( if num_classes < 1: raise ValueError( - "The number of classes must be a positive integer >= 1." - f"Got {num_classes}." + "The number of classes must be a positive integer >= 1." f"Got {num_classes}." ) if eval_grouping_loss and not hasattr(model, "feats_forward"): raise ValueError( - "Your model must have a `feats_forward` method to compute the " - "grouping loss." + "Your model must have a `feats_forward` method to compute the " "grouping loss." ) if eval_grouping_loss and not ( @@ -710,9 +675,7 @@ def _classification_routine_checks( ) if num_calibration_bins < 2: - raise ValueError( - f"num_calibration_bins must be at least 2, got {num_calibration_bins}." - ) + raise ValueError(f"num_calibration_bins must be at least 2, got {num_calibration_bins}.") if mixup_params is not None and isinstance(format_batch_fn, RepeatTarget): raise ValueError( diff --git a/torch_uncertainty/routines/pixel_regression.py b/torch_uncertainty/routines/pixel_regression.py index a6749bf0..f81ff762 100644 --- a/torch_uncertainty/routines/pixel_regression.py +++ b/torch_uncertainty/routines/pixel_regression.py @@ -83,8 +83,7 @@ def __init__( _depth_routine_checks(output_dim, num_image_plot, log_plots) if eval_shift: raise NotImplementedError( - "Distribution shift evaluation not implemented yet. Raise an issue " - "if needed." + "Distribution shift evaluation not implemented yet. Raise an issue " "if needed." ) self.model = model @@ -126,9 +125,7 @@ def __init__( self.test_metrics = depth_metrics.clone(prefix="test/") if self.probabilistic: - depth_prob_metrics = MetricCollection( - {"reg/NLL": DistributionNLL(reduction="mean")} - ) + depth_prob_metrics = MetricCollection({"reg/NLL": DistributionNLL(reduction="mean")}) self.val_prob_metrics = depth_prob_metrics.clone(prefix="val/") self.test_prob_metrics = depth_prob_metrics.clone(prefix="test/") @@ -145,15 +142,11 @@ def on_validation_start(self) -> None: if self.needs_epoch_update and not self.trainer.sanity_checking: self.model.update_wrapper(self.current_epoch) if hasattr(self.model, "need_bn_update"): - self.model.bn_update( - self.trainer.train_dataloader, device=self.device - ) + self.model.bn_update(self.trainer.train_dataloader, device=self.device) def on_test_start(self) -> None: if hasattr(self.model, "need_bn_update"): - self.model.bn_update( - self.trainer.train_dataloader, device=self.device - ) + self.model.bn_update(self.trainer.train_dataloader, device=self.device) def forward(self, inputs: Tensor) -> Tensor | Distribution: """Forward pass of the routine. @@ -176,21 +169,14 @@ def forward(self, inputs: Tensor) -> Tensor | Distribution: pred = pred.squeeze(-1) return pred - def training_step( - self, batch: tuple[Tensor, Tensor], batch_idx: int - ) -> STEP_OUTPUT: + def training_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> STEP_OUTPUT: inputs, target = self.format_batch_fn(batch) if self.one_dim_depth: target = target.unsqueeze(1) dists = self.model(inputs) - if self.probabilistic: - out_shape = dist_size(dists)[-2:] - else: - out_shape = dists.shape[-2:] - target = F.resize( - target, out_shape, interpolation=F.InterpolationMode.NEAREST - ) + out_shape = dist_size(dists)[-2:] if self.probabilistic else dists.shape[-2:] + target = F.resize(target, out_shape, interpolation=F.InterpolationMode.NEAREST) padding_mask = torch.isnan(target) if self.probabilistic: loss = self.loss(dists, target, padding_mask) @@ -202,9 +188,7 @@ def training_step( self.log("train_loss", loss, prog_bar=True, logger=True) return loss - def validation_step( - self, batch: tuple[Tensor, Tensor], batch_idx: int - ) -> None: + def validation_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> None: inputs, targets = batch if self.one_dim_depth: targets = targets.unsqueeze(1) @@ -214,16 +198,10 @@ def validation_step( if self.probabilistic: ens_dist = Independent( - dist_rearrange( - preds, "(m b) c h w -> (b c h w) m", b=batch_size - ), + dist_rearrange(preds, "(m b) c h w -> (b c h w) m", b=batch_size), 0, ) - mix = Categorical( - torch.ones( - (dist_size(preds)[0] // batch_size), device=self.device - ) - ) + mix = Categorical(torch.ones((dist_size(preds)[0] // batch_size), device=self.device)) mixture = MixtureSameFamily(mix, ens_dist) preds = mixture.mean else: @@ -251,8 +229,7 @@ def test_step( ) -> None: if dataloader_idx != 0: raise NotImplementedError( - "Depth OOD detection not implemented yet. Raise an issue " - "if needed." + "Depth OOD detection not implemented yet. Raise an issue " "if needed." ) inputs, targets = batch if self.one_dim_depth: @@ -262,14 +239,8 @@ def test_step( preds = self.model(inputs) if self.probabilistic: - ens_dist = dist_rearrange( - preds, "(m b) c h w -> (b c h w) m", b=batch_size - ) - mix = Categorical( - torch.ones( - (dist_size(preds)[0] // batch_size), device=self.device - ) - ) + ens_dist = dist_rearrange(preds, "(m b) c h w -> (b c h w) m", b=batch_size) + mix = Categorical(torch.ones((dist_size(preds)[0] // batch_size), device=self.device)) mixture = MixtureSameFamily(mix, ens_dist) preds = mixture.mean else: @@ -278,9 +249,7 @@ def test_step( if batch_idx == 0 and self.log_plots: num_images = ( - self.num_image_plot - if self.num_image_plot < inputs.size(0) - else inputs.size(0) + self.num_image_plot if self.num_image_plot < inputs.size(0) else inputs.size(0) ) self._plot_depth( inputs[:num_images, ...], @@ -340,12 +309,8 @@ def _plot_depth( all_imgs = [] for i in range(inputs.size(0)): img = F.normalize(inputs[i, ...].cpu(), **self.inv_norm_params) - pred = colorize( - preds[i, 0, ...].cpu(), vmin=0, vmax=self.model.max_depth - ) - tgt = colorize( - target[i, 0, ...].cpu(), vmin=0, vmax=self.model.max_depth - ) + pred = colorize(preds[i, 0, ...].cpu(), vmin=0, vmax=self.model.max_depth) + tgt = colorize(target[i, 0, ...].cpu(), vmin=0, vmax=self.model.max_depth) all_imgs.extend([img, pred, tgt]) self.logger.experiment.add_image( @@ -380,12 +345,8 @@ def colorize( return torch.as_tensor(img).permute(2, 0, 1).float() / 255.0 -def _depth_routine_checks( - output_dim: int, num_image_plot: int, log_plots: bool -) -> None: +def _depth_routine_checks(output_dim: int, num_image_plot: int, log_plots: bool) -> None: if output_dim < 1: raise ValueError(f"output_dim must be positive, got {output_dim}.") if num_image_plot < 1 and log_plots: - raise ValueError( - f"num_image_plot must be positive, got {num_image_plot}." - ) + raise ValueError(f"num_image_plot must be positive, got {num_image_plot}.") diff --git a/torch_uncertainty/routines/regression.py b/torch_uncertainty/routines/regression.py index a90f08b2..e51a126b 100644 --- a/torch_uncertainty/routines/regression.py +++ b/torch_uncertainty/routines/regression.py @@ -72,8 +72,7 @@ def __init__( _regression_routine_checks(output_dim) if eval_shift: raise NotImplementedError( - "Distribution shift evaluation not implemented yet. Raise an issue " - "if needed." + "Distribution shift evaluation not implemented yet. Raise an issue " "if needed." ) self.model = model @@ -103,9 +102,7 @@ def __init__( self.test_metrics = reg_metrics.clone(prefix="test/") if self.probabilistic: - reg_prob_metrics = MetricCollection( - {"reg/NLL": DistributionNLL(reduction="mean")} - ) + reg_prob_metrics = MetricCollection({"reg/NLL": DistributionNLL(reduction="mean")}) self.val_prob_metrics = reg_prob_metrics.clone(prefix="val/") self.test_prob_metrics = reg_prob_metrics.clone(prefix="test/") @@ -124,15 +121,11 @@ def on_validation_start(self) -> None: if self.needs_epoch_update and not self.trainer.sanity_checking: self.model.update_wrapper(self.current_epoch) if hasattr(self.model, "need_bn_update"): - self.model.bn_update( - self.trainer.train_dataloader, device=self.device - ) + self.model.bn_update(self.trainer.train_dataloader, device=self.device) def on_test_start(self) -> None: if hasattr(self.model, "need_bn_update"): - self.model.bn_update( - self.trainer.train_dataloader, device=self.device - ) + self.model.bn_update(self.trainer.train_dataloader, device=self.device) def forward(self, inputs: Tensor) -> Tensor | Distribution: """Forward pass of the routine. @@ -159,9 +152,7 @@ def forward(self, inputs: Tensor) -> Tensor | Distribution: pred = pred.squeeze(-1) return pred - def training_step( - self, batch: tuple[Tensor, Tensor], batch_idx: int - ) -> STEP_OUTPUT: + def training_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> STEP_OUTPUT: inputs, targets = self.format_batch_fn(batch) if self.one_dim_regression: @@ -178,9 +169,7 @@ def training_step( self.log("train_loss", loss, prog_bar=True, logger=True) return loss - def validation_step( - self, batch: tuple[Tensor, Tensor], batch_idx: int - ) -> None: + def validation_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> None: inputs, targets = batch if self.one_dim_regression: targets = targets.unsqueeze(-1) @@ -190,11 +179,7 @@ def validation_step( if self.probabilistic: ens_dist = dist_rearrange(preds, "(m b) c -> (b c) m", b=batch_size) - mix = Categorical( - torch.ones( - dist_size(preds)[0] // batch_size, device=self.device - ) - ) + mix = Categorical(torch.ones(dist_size(preds)[0] // batch_size, device=self.device)) mixture = MixtureSameFamily(mix, ens_dist) preds = mixture.mean else: @@ -213,8 +198,7 @@ def test_step( ) -> None: if dataloader_idx != 0: raise NotImplementedError( - "Regression OOD detection not implemented yet. Raise an issue " - "if needed." + "Regression OOD detection not implemented yet. Raise an issue " "if needed." ) inputs, targets = batch @@ -226,11 +210,7 @@ def test_step( if self.probabilistic: ens_dist = dist_rearrange(preds, "(m b) c -> (b c) m", b=batch_size) - mix = Categorical( - torch.ones( - dist_size(preds)[0] // batch_size, device=self.device - ) - ) + mix = Categorical(torch.ones(dist_size(preds)[0] // batch_size, device=self.device)) mixture = MixtureSameFamily(mix, ens_dist) preds = mixture.mean else: diff --git a/torch_uncertainty/routines/segmentation.py b/torch_uncertainty/routines/segmentation.py index 185b004a..2ef10d73 100644 --- a/torch_uncertainty/routines/segmentation.py +++ b/torch_uncertainty/routines/segmentation.py @@ -76,8 +76,7 @@ def __init__( ) if eval_shift: raise NotImplementedError( - "Distribution shift evaluation not implemented yet. Raise an issue " - "if needed." + "Distribution shift evaluation not implemented yet. Raise an issue " "if needed." ) self.model = model @@ -103,14 +102,10 @@ def __init__( ) sbsmpl_seg_metrics = MetricCollection( { - "seg/mAcc": Accuracy( - task="multiclass", average="macro", num_classes=num_classes - ), + "seg/mAcc": Accuracy(task="multiclass", average="macro", num_classes=num_classes), "seg/Brier": BrierScore(num_classes=num_classes), "seg/NLL": CategoricalNLL(), - "seg/pixAcc": Accuracy( - task="multiclass", num_classes=num_classes - ), + "seg/pixAcc": Accuracy(task="multiclass", num_classes=num_classes), "cal/ECE": CalibrationError( task="multiclass", num_classes=num_classes, @@ -163,25 +158,17 @@ def on_validation_start(self) -> None: if self.needs_epoch_update and not self.trainer.sanity_checking: self.model.update_wrapper(self.current_epoch) if hasattr(self.model, "need_bn_update"): - self.model.bn_update( - self.trainer.train_dataloader, device=self.device - ) + self.model.bn_update(self.trainer.train_dataloader, device=self.device) def on_test_start(self) -> None: if hasattr(self.model, "need_bn_update"): - self.model.bn_update( - self.trainer.train_dataloader, device=self.device - ) + self.model.bn_update(self.trainer.train_dataloader, device=self.device) - def training_step( - self, batch: tuple[Tensor, Tensor], batch_idx: int - ) -> STEP_OUTPUT: + def training_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> STEP_OUTPUT: img, target = batch img, target = self.format_batch_fn((img, target)) logits = self.forward(img) - target = F.resize( - target, logits.shape[-2:], interpolation=F.InterpolationMode.NEAREST - ) + target = F.resize(target, logits.shape[-2:], interpolation=F.InterpolationMode.NEAREST) logits = rearrange(logits, "b c h w -> (b h w) c") target = target.flatten() valid_mask = target != 255 @@ -191,9 +178,7 @@ def training_step( self.log("train_loss", loss, prog_bar=True, logger=True) return loss - def validation_step( - self, batch: tuple[Tensor, Tensor], batch_idx: int - ) -> None: + def validation_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> None: img, targets = batch logits = self.forward(img) targets = F.resize( @@ -201,9 +186,7 @@ def validation_step( logits.shape[-2:], interpolation=F.InterpolationMode.NEAREST, ) - logits = rearrange( - logits, "(m b) c h w -> (b h w) m c", b=targets.size(0) - ) + logits = rearrange(logits, "(m b) c h w -> (b h w) m c", b=targets.size(0)) probs_per_est = logits.softmax(dim=-1) probs = probs_per_est.mean(dim=1) targets = targets.flatten() @@ -221,20 +204,13 @@ def test_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> None: interpolation=F.InterpolationMode.NEAREST, ) - logits = rearrange( - logits, "(m b) c h w -> b m c h w", b=targets.size(0) - ) + logits = rearrange(logits, "(m b) c h w -> b m c h w", b=targets.size(0)) probs_per_est = logits.softmax(dim=2) probs = probs_per_est.mean(dim=1) - if ( - self.log_plots - and len(self.sample_buffer) < self.num_samples_to_plot - ): + if self.log_plots and len(self.sample_buffer) < self.num_samples_to_plot: max_count = self.num_samples_to_plot - len(self.sample_buffer) - for i, (_img, _prb, _tgt) in enumerate( - zip(img, probs, targets, strict=False) - ): + for i, (_img, _prb, _tgt) in enumerate(zip(img, probs, targets, strict=False)): if i >= max_count: break _pred = _prb.argmax(dim=0, keepdim=True) @@ -281,18 +257,8 @@ def on_test_epoch_end(self) -> None: def log_segmentation_plots(self) -> None: """Builds and logs examples of segmentation plots from the test set.""" for i, (img, pred, tgt) in enumerate(self.sample_buffer): - pred = ( - pred - == torch.arange(self.num_classes, device=pred.device)[ - :, None, None - ] - ) - tgt = ( - tgt - == torch.arange(self.num_classes, device=tgt.device)[ - :, None, None - ] - ) + pred = pred == torch.arange(self.num_classes, device=pred.device)[:, None, None] + tgt = tgt == torch.arange(self.num_classes, device=tgt.device)[:, None, None] # Undo normalization on the image and convert to uint8. mean = torch.tensor(self.trainer.datamodule.mean, device=img.device) @@ -301,17 +267,10 @@ def log_segmentation_plots(self) -> None: img = ToDtype(torch.uint8, scale=True)(img) dataset = self.trainer.datamodule.test - if hasattr(dataset, "color_palette"): - color_palette = dataset.color_palette - else: - color_palette = None + color_palette = dataset.color_palette if hasattr(dataset, "color_palette") else None - pred_mask = draw_segmentation_masks( - img, pred, alpha=0.7, colors=color_palette - ) - gt_mask = draw_segmentation_masks( - img, tgt, alpha=0.7, colors=color_palette - ) + pred_mask = draw_segmentation_masks(img, pred, alpha=0.7, colors=color_palette) + gt_mask = draw_segmentation_masks(img, tgt, alpha=0.7, colors=color_palette) self.logger.experiment.add_figure( f"Segmentation results/{i}", @@ -339,6 +298,4 @@ def _segmentation_routine_checks( ) if num_calibration_bins < 2: - raise ValueError( - f"num_calibration_bins must be at least 2, got {num_calibration_bins}." - ) + raise ValueError(f"num_calibration_bins must be at least 2, got {num_calibration_bins}.") diff --git a/torch_uncertainty/transforms/batch.py b/torch_uncertainty/transforms/batch.py index dd96bba7..1426992a 100644 --- a/torch_uncertainty/transforms/batch.py +++ b/torch_uncertainty/transforms/batch.py @@ -13,27 +13,19 @@ def __init__(self, num_repeats: int) -> None: super().__init__() if not isinstance(num_repeats, int): - raise TypeError( - f"num_repeats must be an integer. Got {num_repeats}." - ) + raise TypeError(f"num_repeats must be an integer. Got {num_repeats}.") if num_repeats <= 0: - raise ValueError( - f"num_repeats must be greater than 0. Got {num_repeats}." - ) + raise ValueError(f"num_repeats must be greater than 0. Got {num_repeats}.") self.num_repeats = num_repeats def forward(self, batch: tuple[Tensor, Tensor]) -> tuple[Tensor, Tensor]: inputs, targets = batch - return inputs, targets.repeat( - self.num_repeats, *[1] * (targets.ndim - 1) - ) + return inputs, targets.repeat(self.num_repeats, *[1] * (targets.ndim - 1)) class MIMOBatchFormat(nn.Module): - def __init__( - self, num_estimators: int, rho: float = 0.0, batch_repeat: int = 1 - ) -> None: + def __init__(self, num_estimators: int, rho: float = 0.0, batch_repeat: int = 1) -> None: """Format the batch for MIMO training. Args: @@ -64,9 +56,9 @@ def shuffle(self, inputs: Tensor) -> Tensor: def forward(self, batch: tuple[Tensor, Tensor]) -> tuple[Tensor, Tensor]: inputs, targets = batch - indexes = torch.arange( - 0, inputs.shape[0], device=inputs.device, dtype=torch.int64 - ).repeat(self.batch_repeat) + indexes = torch.arange(0, inputs.shape[0], device=inputs.device, dtype=torch.int64).repeat( + self.batch_repeat + ) main_shuffle = self.shuffle(indexes) threshold_shuffle = int(main_shuffle.shape[0] * (1.0 - self.rho)) shuffle_indices = [ @@ -80,21 +72,13 @@ def forward(self, batch: tuple[Tensor, Tensor]) -> tuple[Tensor, Tensor]: for _ in range(self.num_estimators) ] inputs = torch.stack( - [ - torch.index_select(inputs, dim=0, index=indices) - for indices in shuffle_indices - ], + [torch.index_select(inputs, dim=0, index=indices) for indices in shuffle_indices], dim=0, ) targets = torch.stack( - [ - torch.index_select(targets, dim=0, index=indices) - for indices in shuffle_indices - ], + [torch.index_select(targets, dim=0, index=indices) for indices in shuffle_indices], dim=0, ) - inputs = rearrange( - inputs, "m b c h w -> (m b) c h w", m=self.num_estimators - ) + inputs = rearrange(inputs, "m b c h w -> (m b) c h w", m=self.num_estimators) targets = rearrange(targets, "m b -> (m b)", m=self.num_estimators) return inputs, targets diff --git a/torch_uncertainty/transforms/corruption.py b/torch_uncertainty/transforms/corruption.py index 70a59267..ffa5f491 100644 --- a/torch_uncertainty/transforms/corruption.py +++ b/torch_uncertainty/transforms/corruption.py @@ -210,20 +210,14 @@ def forward(self, img: Tensor) -> Tensor: img = torch.as_tensor(gaussian(img, sigma=self.sigma)) for _ in range(self.iterations): for h in range(img_size[0] - self.max_delta, self.max_delta, -1): - for w in range( - img_size[1] - self.max_delta, self.max_delta, -1 - ): - dx, dy = torch.randint( - -self.max_delta, self.max_delta, size=(2,) - ) + for w in range(img_size[1] - self.max_delta, self.max_delta, -1): + dx, dy = torch.randint(-self.max_delta, self.max_delta, size=(2,)) h_prime, w_prime = h + dy, w + dx img[h, w], img[h_prime, w_prime] = ( img[h_prime, w_prime], img[h, w], ) - return torch.clamp( - torch.as_tensor(gaussian(img, sigma=self.sigma)), 0, 1 - ) + return torch.clamp(torch.as_tensor(gaussian(img, sigma=self.sigma)), 0, 1) def disk(radius: int, alias_blur: float = 0.1, dtype=np.float32): @@ -267,9 +261,7 @@ def forward(self, img: Tensor) -> Tensor: sigma=self.sigma, angle=self.rng.uniform(-45, 45), ) - x = cv2.imdecode( - np.fromstring(x.make_blob(), np.uint8), cv2.IMREAD_UNCHANGED - ) + x = cv2.imdecode(np.fromstring(x.make_blob(), np.uint8), cv2.IMREAD_UNCHANGED) x = np.clip(x[..., [2, 1, 0]], 0, 255) return self.to_tensor(x) @@ -342,9 +334,9 @@ def forward(self, img: Tensor) -> Tensor: return img _, height, width = img.shape x = img.numpy() - snow_layer = self.rng.normal( - size=x.shape[1:], loc=self.mix[0], scale=self.mix[1] - )[..., np.newaxis] + snow_layer = self.rng.normal(size=x.shape[1:], loc=self.mix[0], scale=self.mix[1])[ + ..., np.newaxis + ] snow_layer = clipped_zoom(snow_layer, self.mix[2]) snow_layer[snow_layer < self.mix[3]] = 0 snow_layer = Image.fromarray( @@ -369,27 +361,18 @@ def forward(self, img: Tensor) -> Tensor: snow_layer = snow_layer[np.newaxis, ...] x = self.mix[6] * x + (1 - self.mix[6]) * np.maximum( x, - cv2.cvtColor(x.transpose([1, 2, 0]), cv2.COLOR_RGB2GRAY).reshape( - 1, height, width - ) - * 1.5 + cv2.cvtColor(x.transpose([1, 2, 0]), cv2.COLOR_RGB2GRAY).reshape(1, height, width) * 1.5 + 0.5, ) - return torch.clamp( - torch.as_tensor(x + snow_layer + np.rot90(snow_layer, k=2)), 0, 1 - ) + return torch.clamp(torch.as_tensor(x + snow_layer + np.rot90(snow_layer, k=2)), 0, 1) class Frost(TUCorruption): def __init__(self, severity: int) -> None: super().__init__(severity) self.rng = np.random.default_rng() - self.mix = [(1, 0.2), (1, 0.3), (0.9, 0.4), (0.85, 0.4), (0.75, 0.45)][ - severity - 1 - ] - self.frost_ds = FrostImages( - "./data", download=True, transform=ToTensor() - ) + self.mix = [(1, 0.2), (1, 0.3), (0.9, 0.4), (0.85, 0.4), (0.75, 0.45)][severity - 1] + self.frost_ds = FrostImages("./data", download=True, transform=ToTensor()) def forward(self, img: Tensor) -> Tensor: if self.severity == 0: @@ -436,15 +419,11 @@ def filldiamonds(): ldrsum = drgrid + np.roll(drgrid, 1, axis=0) lulsum = ulgrid + np.roll(ulgrid, -1, axis=1) ltsum = ldrsum + lulsum - maparray[0:mapsize:stepsize, stepsize // 2 : mapsize : stepsize] = ( - wibbledmean(ltsum) - ) + maparray[0:mapsize:stepsize, stepsize // 2 : mapsize : stepsize] = wibbledmean(ltsum) tdrsum = drgrid + np.roll(drgrid, 1, axis=1) tulsum = ulgrid + np.roll(ulgrid, -1, axis=0) ttsum = tdrsum + tulsum - maparray[stepsize // 2 : mapsize : stepsize, 0:mapsize:stepsize] = ( - wibbledmean(ttsum) - ) + maparray[stepsize // 2 : mapsize : stepsize, 0:mapsize:stepsize] = wibbledmean(ttsum) while stepsize >= 2: fillsquares() @@ -464,9 +443,7 @@ def __init__(self, severity: int, size: int = 256) -> None: self.resize = Resize((size, size), InterpolationMode.BICUBIC) else: raise ValueError(f"Size must be a power of 2. Got {size}.") - self.mix = [(1.5, 2), (2, 2), (2.5, 1.7), (2.5, 1.5), (3, 1.4)][ - severity - 1 - ] + self.mix = [(1.5, 2), (2, 2), (2.5, 1.7), (2.5, 1.5), (3, 1.4)][severity - 1] def forward(self, img: Tensor) -> Tensor: if self.severity == 0: @@ -478,13 +455,9 @@ def forward(self, img: Tensor) -> Tensor: max_val = img.max() fog = ( self.mix[0] - * plasma_fractal( - height=height, width=width, wibbledecay=self.mix[1] - )[:height, :width] - ) - final = torch.clamp( - (img + fog) * max_val / (max_val + self.mix[0]), 0, 1 + * plasma_fractal(height=height, width=width, wibbledecay=self.mix[1])[:height, :width] ) + final = torch.clamp((img + fog) * max_val / (max_val + self.mix[0]), 0, 1) return Resize((height, width), InterpolationMode.BICUBIC)(final) @@ -614,18 +587,14 @@ def forward(self, img: Tensor) -> Tensor: ).astype(np.float32) dx, dy = dx[..., np.newaxis], dy[..., np.newaxis] - x, y, z = np.meshgrid( - np.arange(shape[1]), np.arange(shape[0]), np.arange(shape[2]) - ) + x, y, z = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]), np.arange(shape[2])) indices = ( np.reshape(y + dy, (-1, 1)), np.reshape(x + dx, (-1, 1)), np.reshape(z, (-1, 1)), ) img = np.clip( - map_coordinates(image, indices, order=1, mode="reflect").reshape( - shape - ), + map_coordinates(image, indices, order=1, mode="reflect").reshape(shape), 0, 1, ) diff --git a/torch_uncertainty/transforms/cutout.py b/torch_uncertainty/transforms/cutout.py index 3af0e477..d4bc22b5 100644 --- a/torch_uncertainty/transforms/cutout.py +++ b/torch_uncertainty/transforms/cutout.py @@ -17,9 +17,7 @@ def __init__(self, length: int, value: int = 0) -> None: self.length = length if value < 0 or value > 255: - raise ValueError( - f"Cutout value must be between 0 and 255. Got {value}." - ) + raise ValueError(f"Cutout value must be between 0 and 255. Got {value}.") self.value = value def __call__(self, img: torch.Tensor) -> torch.Tensor: diff --git a/torch_uncertainty/transforms/image.py b/torch_uncertainty/transforms/image.py index ba1eb838..b4695489 100644 --- a/torch_uncertainty/transforms/image.py +++ b/torch_uncertainty/transforms/image.py @@ -32,9 +32,7 @@ class Posterize(nn.Module): level_type = int corruption_overlap = False - def forward( - self, img: Tensor | Image.Image, level: int - ) -> Tensor | Image.Image: + def forward(self, img: Tensor | Image.Image, level: int) -> Tensor | Image.Image: if level >= self.max_level: raise ValueError(f"Level must be less than {self.max_level}.") if level < 0: @@ -48,9 +46,7 @@ class Solarize(nn.Module): level_type = int corruption_overlap = False - def forward( - self, img: Tensor | Image.Image, level: float - ) -> Tensor | Image.Image: + def forward(self, img: Tensor | Image.Image, level: float) -> Tensor | Image.Image: if level >= self.max_level: raise ValueError(f"Level must be less than {self.max_level}.") if level < 0: @@ -78,12 +74,8 @@ def __init__( self.center = center self.fill = fill - def forward( - self, img: Tensor | Image.Image, level: float - ) -> Tensor | Image.Image: - if ( - self.random_direction and torch.rand(1).item() > 0.5 - ): # coverage: ignore + def forward(self, img: Tensor | Image.Image, level: float) -> Tensor | Image.Image: + if self.random_direction and torch.rand(1).item() > 0.5: # coverage: ignore level = -level return F.rotate( img, @@ -117,12 +109,8 @@ def __init__( self.center = center self.fill = fill - def forward( - self, img: Tensor | Image.Image, level: int - ) -> Tensor | Image.Image: - if ( - self.random_direction and torch.rand(1).item() > 0.5 - ): # coverage: ignore + def forward(self, img: Tensor | Image.Image, level: int) -> Tensor | Image.Image: + if self.random_direction and torch.rand(1).item() > 0.5: # coverage: ignore level = -level shear = [0.0, 0.0] shear[self.axis] = level @@ -160,12 +148,8 @@ def __init__( self.center = center self.fill = fill - def forward( - self, img: Tensor | Image.Image, level: float - ) -> Tensor | Image.Image: - if ( - self.random_direction and torch.rand(1).item() > 0.5 - ): # coverage: ignore + def forward(self, img: Tensor | Image.Image, level: float) -> Tensor | Image.Image: + if self.random_direction and torch.rand(1).item() > 0.5: # coverage: ignore level = -level translate = [0.0, 0.0] translate[self.axis] = level @@ -189,9 +173,7 @@ class Contrast(nn.Module): def __init__(self) -> None: super().__init__() - def forward( - self, img: Tensor | Image.Image, level: float - ) -> Tensor | Image.Image: + def forward(self, img: Tensor | Image.Image, level: float) -> Tensor | Image.Image: if level < 0: raise ValueError("Level must be greater than 0.") return F.adjust_contrast(img, level) @@ -232,9 +214,7 @@ class Sharpen(nn.Module): def __init__(self) -> None: super().__init__() - def forward( - self, img: Tensor | Image.Image, level: float - ) -> Tensor | Image.Image: + def forward(self, img: Tensor | Image.Image, level: float) -> Tensor | Image.Image: if level < 0: raise ValueError("Level must be greater than 0.") return F.adjust_sharpness(img, level) @@ -249,9 +229,7 @@ def __init__(self) -> None: """Color augmentation class.""" super().__init__() - def forward( - self, img: Tensor | Image.Image, level: float - ) -> Tensor | Image.Image: + def forward(self, img: Tensor | Image.Image, level: float) -> Tensor | Image.Image: if level < 0: raise ValueError("Level must be greater than 0.") pil_img = F.to_pil_image(img) if isinstance(img, Tensor) else img diff --git a/torch_uncertainty/transforms/mixup.py b/torch_uncertainty/transforms/mixup.py index 43c4283d..b020b8a8 100644 --- a/torch_uncertainty/transforms/mixup.py +++ b/torch_uncertainty/transforms/mixup.py @@ -18,9 +18,7 @@ def beta_warping(x, alpha_cdf: float = 1.0, eps: float = 1e-12) -> float: def sim_gauss_kernel(dist, tau_max: float = 1.0, tau_std: float = 0.5) -> float: - dist_rate = tau_max * np.exp( - -(dist - 1) / (np.mean(dist) * 2 * tau_std * tau_std) - ) + dist_rate = tau_max * np.exp(-(dist - 1) / (np.mean(dist) * 2 * tau_std * tau_std)) return 1 / (dist_rate + 1e-12) @@ -89,9 +87,7 @@ def sim_gauss_kernel(dist, tau_max: float = 1.0, tau_std: float = 0.5) -> float: # TODO: Should be a torchvision transform class AbstractMixup(nn.Module): - def __init__( - self, alpha: float = 1.0, mode: str = "batch", num_classes: int = 1000 - ) -> None: + def __init__(self, alpha: float = 1.0, mode: str = "batch", num_classes: int = 1000) -> None: super().__init__() self.rng = np.random.default_rng() self.alpha = alpha diff --git a/torch_uncertainty/transforms/pixmix.py b/torch_uncertainty/transforms/pixmix.py index 48f5e893..aff24c9e 100644 --- a/torch_uncertainty/transforms/pixmix.py +++ b/torch_uncertainty/transforms/pixmix.py @@ -69,9 +69,7 @@ def __init__( self.mixing_severity = mixing_severity if not all_ops: - allowed_augmentations = [ - aug for aug in augmentations if not aug.corruption_overlap - ] + allowed_augmentations = [aug for aug in augmentations if not aug.corruption_overlap] else: allowed_augmentations = augmentations @@ -95,9 +93,7 @@ def __call__(self, img: Image.Image) -> np.ndarray: # TODO: Fix mixed_op = self.rng.choice(mixings) - mixed = mixed_op( - np.array(mixed), np.array(aug_image_copy), self.mixing_severity - ) + mixed = mixed_op(np.array(mixed), np.array(aug_image_copy), self.mixing_severity) mixed = np.clip(mixed, 0, 1) return mixed diff --git a/torch_uncertainty/utils/checkpoints.py b/torch_uncertainty/utils/checkpoints.py index ac306005..5a69ee5e 100644 --- a/torch_uncertainty/utils/checkpoints.py +++ b/torch_uncertainty/utils/checkpoints.py @@ -1,9 +1,7 @@ from pathlib import Path -def get_version( - root: str | Path, version: int, checkpoint: int | None = None -) -> tuple[Path, Path]: +def get_version(root: str | Path, version: int, checkpoint: int | None = None) -> tuple[Path, Path]: """Find the path to the checkpoint corresponding to the version. Args: @@ -29,9 +27,7 @@ def get_version( else: ckpts = list(ckpt_folder.glob(f"epoch={checkpoint}-*.ckpt")) else: - raise FileNotFoundError( - f"The directory {root}/version_{version} does not exist." - ) + raise FileNotFoundError(f"The directory {root}/version_{version} does not exist.") file = ckpts[0] return (file.resolve(), (version_folder / "hparams.yaml").resolve()) diff --git a/torch_uncertainty/utils/cli.py b/torch_uncertainty/utils/cli.py index 12a0f9b4..8cfd79ef 100644 --- a/torch_uncertainty/utils/cli.py +++ b/torch_uncertainty/utils/cli.py @@ -17,9 +17,7 @@ class TUSaveConfigCallback(SaveConfigCallback): @override - def setup( - self, trainer: Trainer, pl_module: LightningModule, stage: str - ) -> None: + def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None: if self.already_saved: return @@ -31,9 +29,7 @@ def setup( if not self.overwrite: # check if the file exists on rank 0 - file_exists = ( - fs.isfile(config_path) if trainer.is_global_zero else False - ) + file_exists = fs.isfile(config_path) if trainer.is_global_zero else False # broadcast whether to fail to all ranks file_exists = trainer.strategy.broadcast(file_exists) if file_exists: # coverage: ignore @@ -63,16 +59,11 @@ def setup( class TULightningCLI(LightningCLI): def __init__( self, - model_class: ( - type[LightningModule] | Callable[..., LightningModule] | None - ) = None, + model_class: (type[LightningModule] | Callable[..., LightningModule] | None) = None, datamodule_class: ( - type[LightningDataModule] - | Callable[..., LightningDataModule] - | None + type[LightningDataModule] | Callable[..., LightningDataModule] | None ) = None, - save_config_callback: type[SaveConfigCallback] - | None = TUSaveConfigCallback, + save_config_callback: type[SaveConfigCallback] | None = TUSaveConfigCallback, save_config_kwargs: dict[str, Any] | None = None, trainer_class: type[Trainer] | Callable[..., Trainer] = TUTrainer, trainer_defaults: dict[str, Any] | None = None, @@ -120,9 +111,7 @@ def __init__( auto_configure_optimizers, ) - def add_default_arguments_to_parser( - self, parser: LightningArgumentParser - ) -> None: + def add_default_arguments_to_parser(self, parser: LightningArgumentParser) -> None: """Adds default arguments to the parser.""" parser.add_argument( "--eval_after_fit", diff --git a/torch_uncertainty/utils/distributions.py b/torch_uncertainty/utils/distributions.py index 1bf2e669..d3d161a8 100644 --- a/torch_uncertainty/utils/distributions.py +++ b/torch_uncertainty/utils/distributions.py @@ -24,8 +24,7 @@ def dist_size(distribution: Distribution) -> torch.Size: if isinstance(distribution, Normal | Laplace | NormalInverseGamma): return distribution.loc.size() raise NotImplementedError( - f"Size of {type(distribution)} distributions is not supported." - "Raise an issue if needed." + f"Size of {type(distribution)} distributions is not supported." "Raise an issue if needed." ) @@ -40,38 +39,21 @@ def cat_dist(distributions: list[Distribution], dim: int) -> Distribution: Distribution: The concatenated distributions. """ dist_type = type(distributions[0]) - if not all( - isinstance(distribution, dist_type) for distribution in distributions - ): + if not all(isinstance(distribution, dist_type) for distribution in distributions): raise ValueError("All distributions must have the same type.") if isinstance(distributions[0], Normal | Laplace): - locs = torch.cat( - [distribution.loc for distribution in distributions], dim=dim - ) - scales = torch.cat( - [distribution.scale for distribution in distributions], dim=dim - ) + locs = torch.cat([distribution.loc for distribution in distributions], dim=dim) + scales = torch.cat([distribution.scale for distribution in distributions], dim=dim) return dist_type(loc=locs, scale=scales) if isinstance(distributions[0], NormalInverseGamma): - locs = torch.cat( - [distribution.loc for distribution in distributions], dim=dim - ) - lmbdas = torch.cat( - [distribution.lmbda for distribution in distributions], dim=dim - ) - alphas = torch.cat( - [distribution.alpha for distribution in distributions], dim=dim - ) - betas = torch.cat( - [distribution.beta for distribution in distributions], dim=dim - ) - return NormalInverseGamma( - loc=locs, lmbda=lmbdas, alpha=alphas, beta=betas - ) + locs = torch.cat([distribution.loc for distribution in distributions], dim=dim) + lmbdas = torch.cat([distribution.lmbda for distribution in distributions], dim=dim) + alphas = torch.cat([distribution.alpha for distribution in distributions], dim=dim) + betas = torch.cat([distribution.beta for distribution in distributions], dim=dim) + return NormalInverseGamma(loc=locs, lmbda=lmbdas, alpha=alphas, beta=betas) raise NotImplementedError( - f"Concatenation of {dist_type} distributions is not supported." - "Raise an issue if needed." + f"Concatenation of {dist_type} distributions is not supported." "Raise an issue if needed." ) @@ -97,14 +79,11 @@ def dist_squeeze(distribution: Distribution, dim: int) -> Distribution: beta = distribution.beta.squeeze(dim) return NormalInverseGamma(loc=loc, lmbda=lmbda, alpha=alpha, beta=beta) raise NotImplementedError( - f"Squeezing of {dist_type} distributions is not supported." - "Raise an issue if needed." + f"Squeezing of {dist_type} distributions is not supported." "Raise an issue if needed." ) -def dist_rearrange( - distribution: Distribution, pattern: str, **axes_lengths: int -) -> Distribution: +def dist_rearrange(distribution: Distribution, pattern: str, **axes_lengths: int) -> Distribution: dist_type = type(distribution) if isinstance(distribution, Normal | Laplace): loc = rearrange(distribution.loc, pattern=pattern, **axes_lengths) @@ -139,9 +118,7 @@ def __init__( beta: Number | Tensor, validate_args: bool | None = None, ) -> None: - self.loc, self.lmbda, self.alpha, self.beta = broadcast_all( - loc, lmbda, alpha, beta - ) + self.loc, self.lmbda, self.alpha, self.beta = broadcast_all(loc, lmbda, alpha, beta) if ( isinstance(loc, Number) and isinstance(lmbda, Number) @@ -163,19 +140,13 @@ def mean(self) -> Tensor: return self.loc def mode(self) -> None: - raise NotImplementedError( - "NormalInverseGamma distribution has no mode." - ) + raise NotImplementedError("NormalInverseGamma distribution has no mode.") def stddev(self) -> None: - raise NotImplementedError( - "NormalInverseGamma distribution has no stddev." - ) + raise NotImplementedError("NormalInverseGamma distribution has no stddev.") def variance(self) -> None: - raise NotImplementedError( - "NormalInverseGamma distribution has no variance." - ) + raise NotImplementedError("NormalInverseGamma distribution has no variance.") @property def mean_loc(self) -> Tensor: @@ -196,8 +167,7 @@ def log_prob(self, value: Tensor) -> Tensor: return ( -0.5 * torch.log(torch.pi / self.lmbda) + self.alpha * gam.log() - - (self.alpha + 0.5) - * torch.log(gam + self.lmbda * (value - self.loc) ** 2) + - (self.alpha + 0.5) * torch.log(gam + self.lmbda * (value - self.loc) ** 2) - torch.lgamma(self.alpha) + torch.lgamma(self.alpha + 0.5) ) diff --git a/torch_uncertainty/utils/evaluation_loop.py b/torch_uncertainty/utils/evaluation_loop.py index 5bc729c7..dac1138c 100644 --- a/torch_uncertainty/utils/evaluation_loop.py +++ b/torch_uncertainty/utils/evaluation_loop.py @@ -85,12 +85,8 @@ def _print_results(results: list[_OUT_DICT], stage: str) -> None: if "cls" in metrics: table = Table() - table.add_column( - first_col_name, justify="center", style="cyan", width=12 - ) - table.add_column( - "Classification", justify="center", style="magenta", width=25 - ) + table.add_column(first_col_name, justify="center", style="cyan", width=12) + table.add_column("Classification", justify="center", style="magenta", width=25) cls_metrics = OrderedDict(sorted(metrics["cls"].items())) for metric, value in cls_metrics.items(): if metric in percentage_metrics: @@ -102,12 +98,8 @@ def _print_results(results: list[_OUT_DICT], stage: str) -> None: if "seg" in metrics: table = Table() - table.add_column( - first_col_name, justify="center", style="cyan", width=12 - ) - table.add_column( - "Segmentation", justify="center", style="magenta", width=25 - ) + table.add_column(first_col_name, justify="center", style="cyan", width=12) + table.add_column("Segmentation", justify="center", style="magenta", width=25) seg_metrics = OrderedDict(sorted(metrics["seg"].items())) for metric, value in seg_metrics.items(): if metric in percentage_metrics: @@ -119,12 +111,8 @@ def _print_results(results: list[_OUT_DICT], stage: str) -> None: if "reg" in metrics: table = Table() - table.add_column( - first_col_name, justify="center", style="cyan", width=12 - ) - table.add_column( - "Regression", justify="center", style="magenta", width=25 - ) + table.add_column(first_col_name, justify="center", style="cyan", width=12) + table.add_column("Regression", justify="center", style="magenta", width=25) reg_metrics = OrderedDict(sorted(metrics["reg"].items())) for metric, value in reg_metrics.items(): if metric in percentage_metrics: # coverage: ignore @@ -136,12 +124,8 @@ def _print_results(results: list[_OUT_DICT], stage: str) -> None: if "cal" in metrics: table = Table() - table.add_column( - first_col_name, justify="center", style="cyan", width=12 - ) - table.add_column( - "Calibration", justify="center", style="magenta", width=25 - ) + table.add_column(first_col_name, justify="center", style="cyan", width=12) + table.add_column("Calibration", justify="center", style="magenta", width=25) cal_metrics = OrderedDict(sorted(metrics["cal"].items())) for metric, value in cal_metrics.items(): if metric in percentage_metrics: @@ -153,12 +137,8 @@ def _print_results(results: list[_OUT_DICT], stage: str) -> None: if "ood" in metrics: table = Table() - table.add_column( - first_col_name, justify="center", style="cyan", width=12 - ) - table.add_column( - "OOD Detection", justify="center", style="magenta", width=25 - ) + table.add_column(first_col_name, justify="center", style="cyan", width=12) + table.add_column("OOD Detection", justify="center", style="magenta", width=25) ood_metrics = OrderedDict(sorted(metrics["ood"].items())) for metric, value in ood_metrics.items(): if metric in percentage_metrics: @@ -170,9 +150,7 @@ def _print_results(results: list[_OUT_DICT], stage: str) -> None: if "sc" in metrics: table = Table() - table.add_column( - first_col_name, justify="center", style="cyan", width=12 - ) + table.add_column(first_col_name, justify="center", style="cyan", width=12) table.add_column( "Selective Classification", justify="center", @@ -190,12 +168,8 @@ def _print_results(results: list[_OUT_DICT], stage: str) -> None: if "post" in metrics: table = Table() - table.add_column( - first_col_name, justify="center", style="cyan", width=12 - ) - table.add_column( - "Post-Processing", justify="center", style="magenta", width=25 - ) + table.add_column(first_col_name, justify="center", style="cyan", width=12) + table.add_column("Post-Processing", justify="center", style="magenta", width=25) post_metrics = OrderedDict(sorted(metrics["post"].items())) for metric, value in post_metrics.items(): if metric in percentage_metrics: @@ -207,9 +181,7 @@ def _print_results(results: list[_OUT_DICT], stage: str) -> None: if "shift" in metrics: table = Table() - table.add_column( - first_col_name, justify="center", style="cyan", width=12 - ) + table.add_column(first_col_name, justify="center", style="cyan", width=12) shift_severity = int(metrics["shift"]["shift_severity"]) table.add_column( f"Distribution Shift lvl{shift_severity}", diff --git a/torch_uncertainty/utils/hub.py b/torch_uncertainty/utils/hub.py index 73bd0f7f..fd522dc2 100644 --- a/torch_uncertainty/utils/hub.py +++ b/torch_uncertainty/utils/hub.py @@ -20,9 +20,7 @@ huggingface_hub_installed = False -def load_hf( - weight_id: str, version: int = 0 -) -> tuple[dict[str, torch.Tensor], dict[str, str]]: +def load_hf(weight_id: str, version: int = 0) -> tuple[dict[str, torch.Tensor], dict[str, str]]: """Load a model from the HuggingFace hub. Args: @@ -36,9 +34,7 @@ def load_hf( TorchUncertainty's weights are released under the Apache 2.0 license. """ if not huggingface_hub_installed: - raise ImportError( - "Please install huggingface_hub to use this function." - ) + raise ImportError("Please install huggingface_hub to use this function.") if not safetensors_installed: raise ImportError("Please install safetensors to use this function.") repo_id = f"torch-uncertainty/{weight_id}" @@ -57,9 +53,7 @@ def load_hf( try: weight_path = hf_hub_download(repo_id=repo_id, filename=filename) except EntryNotFoundError: - raise ValueError( - f"Model {weight_id}_{version} not found on HuggingFace." - ) from not_pt + raise ValueError(f"Model {weight_id}_{version} not found on HuggingFace.") from not_pt if pickle: weight = torch.load(weight_path, map_location=torch.device("cpu")) diff --git a/torch_uncertainty/utils/learning_rate.py b/torch_uncertainty/utils/learning_rate.py index 1e6aef85..b30602ce 100644 --- a/torch_uncertainty/utils/learning_rate.py +++ b/torch_uncertainty/utils/learning_rate.py @@ -22,8 +22,7 @@ def get_lr(self) -> list[float]: def _get_closed_form_lr(self) -> list[float]: return [ max( - base_lr - * (1 - self.last_epoch / self.total_iters) ** self.power, + base_lr * (1 - self.last_epoch / self.total_iters) ** self.power, self.min_lr, ) for base_lr in self.base_lrs diff --git a/torch_uncertainty/utils/to_hub_format.py b/torch_uncertainty/utils/to_hub_format.py index 1c3c24f9..1e1136d3 100644 --- a/torch_uncertainty/utils/to_hub_format.py +++ b/torch_uncertainty/utils/to_hub_format.py @@ -12,21 +12,11 @@ prog="to_hub_format", description="Post-process the checkpoints before the upload to HuggingFace", ) -parser.add_argument( - "--name", type=str, required=True, help="name of the checkpoint" -) -parser.add_argument( - "--path", type=Path, required=True, help="path to the checkpoint" -) -parser.add_argument( - "--version", type=int, default=0, help="version of the checkpoint" -) -parser.add_argument( - "--safe", action="store_true", help="whether to use safetensors" -) -parser.add_argument( - "--fp16", action="store_true", help="whether to use fp16 for the checkpoint" -) +parser.add_argument("--name", type=str, required=True, help="name of the checkpoint") +parser.add_argument("--path", type=Path, required=True, help="path to the checkpoint") +parser.add_argument("--version", type=int, default=0, help="version of the checkpoint") +parser.add_argument("--safe", action="store_true", help="whether to use safetensors") +parser.add_argument("--fp16", action="store_true", help="whether to use fp16 for the checkpoint") args = parser.parse_args() @@ -35,10 +25,7 @@ dtype = torch.float16 if args.fp16 else torch.float32 model = torch.load(args.path)["state_dict"] -model = { - key.replace("model.", ""): val.to(device="cpu", dtype=dtype) - for key, val in model.items() -} +model = {key.replace("model.", ""): val.to(device="cpu", dtype=dtype) for key, val in model.items()} output_name = args.name if args.version != 0: