From 11c53a046ea7335cf3b132d61c22af81290748d5 Mon Sep 17 00:00:00 2001 From: Joseph Paul Cohen Date: Sun, 22 Jan 2023 11:55:31 -0800 Subject: [PATCH] pep8 cleanup --- torchxrayvision/autoencoders.py | 6 +++--- .../baseline_models/chestx_det/__init__.py | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/torchxrayvision/autoencoders.py b/torchxrayvision/autoencoders.py index 5299ae4..08b34f2 100644 --- a/torchxrayvision/autoencoders.py +++ b/torchxrayvision/autoencoders.py @@ -15,6 +15,7 @@ "class": "ResNetAE101" } + class Bottleneck(nn.Module): expansion = 4 @@ -171,13 +172,12 @@ def _make_up_block(self, block, init_channels, num_layer, stride=1): return nn.Sequential(*layers) def encode(self, x, check_resolution=True): - + if check_resolution and hasattr(self, 'weights_metadata'): resolution = self.weights_metadata['resolution'] if (x.shape[2] != resolution) | (x.shape[3] != resolution): raise ValueError("Input size ({}x{}) is not the native resolution ({}x{}) for this model. Set check_resolution=False on the encode function to override this error.".format(x.shape[2], x.shape[3], resolution, resolution)) - - + x = self.conv1(x) x = self.bn1(x) x = self.relu(x) diff --git a/torchxrayvision/baseline_models/chestx_det/__init__.py b/torchxrayvision/baseline_models/chestx_det/__init__.py index e935c55..1fdd7ff 100644 --- a/torchxrayvision/baseline_models/chestx_det/__init__.py +++ b/torchxrayvision/baseline_models/chestx_det/__init__.py @@ -41,7 +41,7 @@ class PSPNet(nn.Module): def __init__(self): super(PSPNet, self).__init__() - + self.transform = torchvision.transforms.Compose([ torchvision.transforms.Normalize( [0.485, 0.456, 0.406], @@ -50,9 +50,9 @@ def __init__(self): ]) self._targets = ['Left Clavicle', 'Right Clavicle', 'Left Scapula', 'Right Scapula', - 'Left Lung', 'Right Lung', 'Left Hilus Pulmonis', 'Right Hilus Pulmonis', - 'Heart', 'Aorta', 'Facies Diaphragmatica', 'Mediastinum', 'Weasand', 'Spine'] - + 'Left Lung', 'Right Lung', 'Left Hilus Pulmonis', 'Right Hilus Pulmonis', + 'Heart', 'Aorta', 'Facies Diaphragmatica', 'Mediastinum', 'Weasand', 'Spine'] + model = pspnet(len(self.targets)) url = "https://github.com/mlmed/torchxrayvision/releases/download/v1/pspnet_chestxray_best_model_4.pth" @@ -74,7 +74,7 @@ def __init__(self): except Exception as e: print("Loading failure. Check weights file:", self.weights_filename_local) raise (e) - + model.eval() self.model = model self.upsample = nn.Upsample(size=(512, 512), mode='bilinear', align_corners=False) @@ -90,7 +90,7 @@ def forward(self, x): # expecting values between [-1024,1024] x = (x + 1024) / 2048 - + # now between [0,1] for this model preprocessing x = self.transform(x) y = self.model(x)