diff --git a/pep8.sh b/pep8.sh new file mode 100755 index 0000000..0418328 --- /dev/null +++ b/pep8.sh @@ -0,0 +1 @@ +autopep8 --in-place --select=E1,E2,E3,W1,W2 torchxrayvision/**.py torchxrayvision/**/**.py torchxrayvision/**/**/**.py diff --git a/torchxrayvision/autoencoders.py b/torchxrayvision/autoencoders.py index 5305e06..fab3c7c 100644 --- a/torchxrayvision/autoencoders.py +++ b/torchxrayvision/autoencoders.py @@ -1,21 +1,19 @@ import torch import torch.nn as nn -import torchvision -from torch.nn import Module -import urllib import pathlib -import torch.nn.functional as F import os -import numpy as np +import sys +import requests -model_urls = {} +model_urls = {} model_urls['101-elastic'] = { "description": 'This model was trained on the datasets: nih pc rsna mimic_ch chex datasets.', "weights_url": 'https://github.com/mlmed/torchxrayvision/releases/download/v1/nihpcrsnamimic_ch-resnet101-2-ae-test2-elastic-e250.pt', - "image_range": [-1024,1024], - "class":"ResNetAE101" - } + "image_range": [-1024, 1024], + "class": "ResNetAE101" +} + class Bottleneck(nn.Module): expansion = 4 @@ -104,11 +102,12 @@ def forward(self, x): return out + # source: https://github.com/ycszen/pytorch-segmentation/blob/master/resnet.py class _ResNetAE(nn.Module): def __init__(self, downblock, upblock, num_layers, n_classes): super(_ResNetAE, self).__init__() - + self.in_channels = 64 self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False) @@ -121,7 +120,7 @@ def __init__(self, downblock, upblock, num_layers, n_classes): self.layer3 = self._make_downlayer(downblock, 256, num_layers[2], stride=2) self.layer4 = self._make_downlayer(downblock, 128, num_layers[3], stride=6) - self.uplayer1 = self._make_up_block(upblock, 128, num_layers[3], stride=6) + self.uplayer1 = self._make_up_block(upblock, 128, num_layers[3], stride=6) self.uplayer2 = self._make_up_block(upblock, 64, num_layers[2], stride=2) self.uplayer3 = self._make_up_block(upblock, 32, num_layers[1], stride=2) self.uplayer4 = self._make_up_block(upblock, 16, num_layers[0], stride=2) @@ -133,13 +132,13 @@ def __init__(self, downblock, upblock, num_layers, n_classes): self.uplayer_top = DeconvBottleneck(self.in_channels, 64, 1, 2, upsample) self.conv1_1 = nn.ConvTranspose2d(64, n_classes, kernel_size=1, stride=1, bias=False) - + def __repr__(self): if self.weights != None: return "XRV-ResNetAE-{}".format(self.weights) else: return "XRV-ResNetAE" - + def _make_downlayer(self, block, init_channels, num_layer, stride=1): downsample = None if stride != 1 or self.in_channels != init_channels * block.expansion: @@ -160,13 +159,13 @@ def _make_up_block(self, block, init_channels, num_layer, stride=1): # expansion = block.expansion if stride != 1 or self.in_channels != init_channels * 2: upsample = nn.Sequential( - nn.ConvTranspose2d(self.in_channels, init_channels * 2,kernel_size=1, stride=stride, bias=False, output_padding=1), + nn.ConvTranspose2d(self.in_channels, init_channels * 2, kernel_size=1, stride=stride, bias=False, output_padding=1), nn.BatchNorm2d(init_channels * 2), ) layers = [] for i in range(1, num_layer): layers.append(block(self.in_channels, init_channels, 4)) - + layers.append(block(self.in_channels, init_channels, 2, stride, upsample)) self.in_channels = init_channels * 2 return nn.Sequential(*layers) @@ -182,11 +181,11 @@ def encode(self, x): x = self.layer3(x) x = self.layer4(x) return x - + def features(self, x): return self.encode(x) - def decode(self, x, image_size=[1,1,224,224]): + def decode(self, x, 
image_size=[1, 1, 224, 224]): x = self.uplayer1(x) x = self.uplayer2(x) x = self.uplayer3(x) @@ -195,14 +194,15 @@ def decode(self, x, image_size=[1,1,224,224]): x = self.conv1_1(x, output_size=image_size) return x - + def forward(self, x): ret = {} ret["z"] = z = self.encode(x) ret["out"] = self.decode(z, x.size()) return ret - + + def ResNetAE50(**kwargs): return _ResNetAE(Bottleneck, DeconvBottleneck, [3, 4, 6, 3], 1, **kwargs) @@ -212,21 +212,21 @@ def ResNetAE101(**kwargs): def ResNetAE(weights=None): - + if weights == None: return ResNetAE101() - + if not weights in model_urls.keys(): raise Exception("weights value must be in {}".format(list(model_urls.keys()))) - + method_to_call = globals()[model_urls[weights]["class"]] ae = method_to_call() - - ## load pretrained models + + # load pretrained models url = model_urls[weights]["weights_url"] weights_filename = os.path.basename(url) - weights_storage_folder = os.path.expanduser(os.path.join("~",".torchxrayvision","models_data")) - weights_filename_local = os.path.expanduser(os.path.join(weights_storage_folder,weights_filename)) + weights_storage_folder = os.path.expanduser(os.path.join("~", ".torchxrayvision", "models_data")) + weights_filename_local = os.path.expanduser(os.path.join(weights_storage_folder, weights_filename)) if not os.path.isfile(weights_filename_local): print("Downloading weights...") @@ -239,17 +239,15 @@ def ResNetAE(weights=None): ae.load_state_dict(state_dict) except Exception as e: print("Loading failure. Check weights file:", weights_filename_local) - raise(e) - + raise (e) + ae = ae.eval() - + ae.weights = weights ae.description = model_urls[weights]["description"] - + return ae - -import sys -import requests + # from here https://sumit-ghosh.com/articles/python-download-progress-bar/ def download(url, filename): @@ -262,11 +260,10 @@ def download(url, filename): else: downloaded = 0 total = int(total) - for data in response.iter_content(chunk_size=max(int(total/1000), 1024*1024)): + for data in response.iter_content(chunk_size=max(int(total / 1000), 1024 * 1024)): downloaded += len(data) f.write(data) - done = int(50*downloaded/total) - sys.stdout.write('\r[{}{}]'.format('█' * done, '.' * (50-done))) + done = int(50 * downloaded / total) + sys.stdout.write('\r[{}{}]'.format('█' * done, '.' * (50 - done))) sys.stdout.flush() sys.stdout.write('\n') - \ No newline at end of file diff --git a/torchxrayvision/baseline_models/chexpert/__init__.py b/torchxrayvision/baseline_models/chexpert/__init__.py index e655c0f..163a0eb 100644 --- a/torchxrayvision/baseline_models/chexpert/__init__.py +++ b/torchxrayvision/baseline_models/chexpert/__init__.py @@ -1,17 +1,8 @@ import sys, os thisfolder = os.path.dirname(__file__) -sys.path.insert(0,thisfolder) -import torch -import csv -import numpy as np -import json -import argparse -import urllib -import pathlib +sys.path.insert(0, thisfolder) import torch import torch.nn as nn -import torch.nn.functional as F -import torchvision.transforms as transforms from .model import Tasks2Models @@ -21,52 +12,50 @@ class DenseNet(nn.Module): CheXpert: A Large Chest Radiograph Dataset with Uncertainty Labels and Expert Comparison. AAAI Conference on Artificial Intelligence. http://arxiv.org/abs/1901.07031 - + Setting num_models less than 30 will load a subset of the ensemble. - + Modified for torchxrayvision to maintain the pytorch gradient tape and also to provide the features() argument. 
- + Weights can be found: https://academictorrents.com/details/5c7ee21e6770308f2d2b4bd829e896dbd9d3ee87 https://archive.org/download/torchxrayvision_chexpert_weights/chexpert_weights.zip """ def __init__(self, weights_zip="", num_models=30): - + super(DenseNet, self).__init__() url = "https://academictorrents.com/details/5c7ee21e6770308f2d2b4bd829e896dbd9d3ee87" self.weights_zip = weights_zip self.num_models = num_models - + if self.weights_zip == "": raise Exception("Need to specify weights_zip file location. You can download them from {}".format(url)) - + self.use_gpu = torch.cuda.is_available() dirname = os.path.dirname(os.path.realpath(__file__)) - self.model = Tasks2Models(os.path.join(dirname, 'predict_configs.json'), - weights_zip=self.weights_zip, - num_models=self.num_models, - dynamic=False, - use_gpu=self.use_gpu) + self.model = Tasks2Models(os.path.join(dirname, 'predict_configs.json'), + weights_zip=self.weights_zip, + num_models=self.num_models, + dynamic=False, + use_gpu=self.use_gpu) self.upsample = nn.Upsample(size=(320, 320), mode='bilinear', align_corners=False) - self.pathologies = ["Atelectasis", "Cardiomegaly", "Consolidation", "Edema", "Effusion"] - - + def forward(self, x): x = x.repeat(1, 3, 1, 1) x = self.upsample(x) - - #expecting values between [-1024,1024] - x = x/512 - #now between [-2,2] for this model - + + # expecting values between [-1024,1024] + x = x / 512 + # now between [-2,2] for this model + outputs = [] - for sample in x: #sorry hard to make parallel + for sample in x: # sorry hard to make parallel all_task2prob = {} for tasks in self.model: task2prob = self.model.infer(sample.unsqueeze(0), tasks) @@ -76,30 +65,29 @@ def forward(self, x): output = [all_task2prob[patho] for patho in ["Atelectasis", "Cardiomegaly", "Consolidation", "Edema", "Pleural Effusion"]] output = torch.stack(output) outputs.append(output) - + return torch.stack(outputs) - + def features(self, x): x = x.repeat(1, 3, 1, 1) x = self.upsample(x) - - #expecting values between [-1024,1024] - x = x/512 - #now between [-2,2] for this model - + + # expecting values between [-1024,1024] + x = x / 512 + # now between [-2,2] for this model + outputs = [] - for sample in x: #sorry hard to make parallel + for sample in x: # sorry hard to make parallel all_feats = [] for tasks in self.model: task2prob = self.model.features(sample.unsqueeze(0), tasks) all_feats.append(task2prob) feats = torch.stack(all_feats) outputs.append(feats.flatten()) - + return torch.stack(outputs) - - + def __repr__(self): return "CheXpert-DenseNet121-ensemble" - - + + diff --git a/torchxrayvision/baseline_models/chexpert/model.py b/torchxrayvision/baseline_models/chexpert/model.py index b25dcfc..23d1788 100644 --- a/torchxrayvision/baseline_models/chexpert/model.py +++ b/torchxrayvision/baseline_models/chexpert/model.py @@ -1,17 +1,13 @@ -import os import json import torch import torch.nn as nn import torch.nn.functional as F - -import numpy as np - -from collections import OrderedDict from torchvision import models import zipfile import io import tqdm + def uncertain_logits_to_probs(logits): """Convert explicit uncertainty modeling logits to probabilities P(is_abnormal). 
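The helper above implements the CheXpert-style explicit uncertainty head, where each task emits three logits and P(is_abnormal) is read off a softmax over them. A minimal sketch of the idea, assuming a (negative, uncertain, positive) class ordering — the ordering and reshape are assumptions for illustration, not taken from this diff:

    import torch
    import torch.nn.functional as F

    def uncertain_logits_to_probs_sketch(logits: torch.Tensor) -> torch.Tensor:
        # logits: (batch, num_tasks * 3), three uncertainty classes per task
        b, n_times_3 = logits.shape
        per_task = logits.view(b, n_times_3 // 3, 3)
        probs = F.softmax(per_task, dim=-1)  # normalize over the 3 classes
        return probs[..., 2]                 # P(is_abnormal): "positive" mass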
@@ -39,6 +35,7 @@ class Model(nn.Module): """Models from TorchVision's GitHub page of pretrained neural networks: https://github.com/pytorch/vision/tree/master/torchvision/models """ + def __init__(self, model_fn, task_sequence, model_uncertainty, use_gpu): super(Model, self).__init__() @@ -47,7 +44,7 @@ def __init__(self, model_fn, task_sequence, model_uncertainty, use_gpu): self.use_gpu = use_gpu # Set pretrained to False to avoid loading weights which will be overwritten - self.model = model_fn(pretrained=False) + self.model = model_fn(pretrained=False) self.pool = nn.AdaptiveAvgPool2d(1) @@ -66,20 +63,18 @@ def forward(self, x): x = self.model.classifier(x) return x - + def features2(self, x): features = self.model.features(x) out = F.relu(features, inplace=True) out = F.adaptive_avg_pool2d(out, (1, 1)).view(features.size(0), -1) return out - def infer(self, x, tasks): preds = self(x) - probs = self.get_probs(preds)[0] - + task2results = {} for task in tasks: @@ -98,15 +93,13 @@ def __init__(self, task_sequence, model_uncertainty, use_gpu): def load_individual(weights_zip, ckpt_path, model_uncertainty, use_gpu=False): - #print(ckpt_path) with weights_zip.open(ckpt_path) as file: - + stream = io.BytesIO(file.read()) ckpt_dict = torch.load(stream, map_location="cpu") - + device = 'cuda:0' if use_gpu else 'cpu' - #ckpt_path = os.path.join(os.path.dirname(__file__), ckpt_path) - + # Build model, load parameters task_sequence = ckpt_dict['task_sequence'] model = DenseNet121(task_sequence, model_uncertainty, use_gpu) @@ -121,9 +114,10 @@ class Tasks2Models(object): Main attribute is a (task tuple) -> {iterator, list} dictionary, which loads models iteratively depending on the specified task. - """ + """ + def __init__(self, config_path, weights_zip, num_models=1, dynamic=True, use_gpu=False): - + super(Tasks2Models).__init__() self.get_config(config_path) @@ -187,7 +181,7 @@ def get_config(self, config_path): raise ValueError('Invalid configuration: {} = {} (expected "max" or "mean")'.format('aggregation_method', agg_method)) def model_iterator(self, model_dicts, num_models, desc=""): - + def iterator(): for model_dict in model_dicts[:num_models]: @@ -201,7 +195,7 @@ def iterator(): return iterator def model_list(self, model_dicts, num_models, desc=""): - + loaded_models = [] toiter = tqdm.tqdm(model_dicts[:num_models]) toiter.set_description(desc) @@ -218,7 +212,7 @@ def iterator(): return iterator def infer(self, img, tasks): - + ensemble_probs = [] model_iterable = self.tasks2models[tasks] @@ -239,11 +233,11 @@ def infer(self, img, tasks): for task in tasks: ensemble_probs = task2ensemble_results[task] task2results[task] = self.aggregation_fn(torch.stack(ensemble_probs), dim=0) - + assert all([task in task2results for task in tasks]), "Not all tasks in task2results" return task2results - + def features(self, img, tasks): """ Return shape is [3, 30, 1, 1024] diff --git a/torchxrayvision/baseline_models/jfhealthcare/__init__.py b/torchxrayvision/baseline_models/jfhealthcare/__init__.py index 78dfeac..050d758 100644 --- a/torchxrayvision/baseline_models/jfhealthcare/__init__.py +++ b/torchxrayvision/baseline_models/jfhealthcare/__init__.py @@ -1,18 +1,12 @@ import sys, os thisfolder = os.path.dirname(__file__) -sys.path.insert(0,thisfolder) -import torch -import csv -import numpy as np +sys.path.insert(0, thisfolder) from .model import classifier import json -import argparse -import urllib import pathlib import torch import torch.nn as nn -import torch.nn.functional as F -import 
torchvision.transforms as transforms + class DenseNet(nn.Module): """ @@ -27,74 +21,76 @@ class DenseNet(nn.Module): archivePrefix={arXiv}, primaryClass={cs.CV} } - + """ def __init__(self, apply_sigmoid=True): - + super(DenseNet, self).__init__() self.apply_sigmoid = apply_sigmoid - + with open(os.path.join(thisfolder, 'config/example.json')) as f: self.cfg = json.load(f) - + class Struct: def __init__(self, **entries): self.__dict__.update(entries) - + self.cfg = Struct(**self.cfg) - + model = classifier.Classifier(self.cfg) model = nn.DataParallel(model).eval() - + url = "https://github.com/mlmed/torchxrayvision/releases/download/v1/baseline_models_jfhealthcare-DenseNet121_pre_train.pth" - + weights_filename = os.path.basename(url) - weights_storage_folder = os.path.expanduser(os.path.join("~",".torchxrayvision","models_data")) - self.weights_filename_local = os.path.expanduser(os.path.join(weights_storage_folder,weights_filename)) + weights_storage_folder = os.path.expanduser(os.path.join("~", ".torchxrayvision", "models_data")) + self.weights_filename_local = os.path.expanduser(os.path.join(weights_storage_folder, weights_filename)) if not os.path.isfile(self.weights_filename_local): print("Downloading weights...") print("If this fails you can run `wget {} -O {}`".format(url, self.weights_filename_local)) pathlib.Path(weights_storage_folder).mkdir(parents=True, exist_ok=True) download(url, self.weights_filename_local) - + try: - ckpt = torch.load(self.weights_filename_local , map_location="cpu") + ckpt = torch.load(self.weights_filename_local, map_location="cpu") model.module.load_state_dict(ckpt) except Exception as e: print("Loading failure. Check weights file:", self.weights_filename_local) - raise(e) - + raise (e) + self.model = model self.upsample = nn.Upsample(size=(512, 512), mode='bilinear', align_corners=False) - + self.pathologies = ["Cardiomegaly", 'Edema', 'Consolidation', 'Atelectasis', 'Effusion'] - - + def forward(self, x): x = x.repeat(1, 3, 1, 1) x = self.upsample(x) - - #expecting values between [-1024,1024] - x = x/512 - #now between [-2,2] for this model - + + # expecting values between [-1024,1024] + x = x / 512 + # now between [-2,2] for this model + y, _ = self.model(x) - y = torch.cat(y,1) - + y = torch.cat(y, 1) + if self.apply_sigmoid: y = torch.sigmoid(y) - + return y - + def __repr__(self): return "jfhealthcare-DenseNet121" - + + import sys import requests # from here https://sumit-ghosh.com/articles/python-download-progress-bar/ + + def download(url, filename): with open(filename, 'wb') as f: response = requests.get(url, stream=True) @@ -105,12 +101,11 @@ def download(url, filename): else: downloaded = 0 total = int(total) - for data in response.iter_content(chunk_size=max(int(total/1000), 1024*1024)): + for data in response.iter_content(chunk_size=max(int(total / 1000), 1024 * 1024)): downloaded += len(data) f.write(data) - done = int(50*downloaded/total) - sys.stdout.write('\r[{}{}]'.format('█' * done, '.' * (50-done))) + done = int(50 * downloaded / total) + sys.stdout.write('\r[{}{}]'.format('█' * done, '.' * (50 - done))) sys.stdout.flush() sys.stdout.write('\n') - - \ No newline at end of file + diff --git a/torchxrayvision/datasets.py b/torchxrayvision/datasets.py index f25de71..1e30416 100644 --- a/torchxrayvision/datasets.py +++ b/torchxrayvision/datasets.py @@ -74,11 +74,11 @@ def apply_transforms(sample, transform, seed=None) -> Dict: This way data augmentation will work for segmentation or other tasks which use masks information. 
""" - + if seed is None: MAX_RAND_VAL = 2147483647 seed = np.random.randint(MAX_RAND_VAL) - + if transform is not None: random.seed(seed) torch.random.manual_seed(seed) @@ -159,7 +159,7 @@ def limit_to_selected_views(self, views): self.csv.view.fillna("UNKNOWN", inplace=True) if "*" not in views: - self.csv = self.csv[self.csv["view"].isin(self.views)] # Select the view + self.csv = self.csv[self.csv["view"].isin(self.views)] # Select the view class MergeDataset(Dataset): @@ -173,9 +173,9 @@ def __init__(self, datasets, seed=0, label_concat=False): self.offset = np.zeros(0) currentoffset = 0 for i, dataset in enumerate(datasets): - self.which_dataset = np.concatenate([self.which_dataset, np.zeros(len(dataset))+i]) + self.which_dataset = np.concatenate([self.which_dataset, np.zeros(len(dataset)) + i]) self.length += len(dataset) - self.offset = np.concatenate([self.offset, np.zeros(len(dataset))+currentoffset]) + self.offset = np.concatenate([self.offset, np.zeros(len(dataset)) + currentoffset]) currentoffset += len(dataset) if dataset.pathologies != self.pathologies: raise Exception("incorrect pathology alignment") @@ -188,10 +188,10 @@ def __init__(self, datasets, seed=0, label_concat=False): self.which_dataset = self.which_dataset.astype(int) if label_concat: - new_labels = np.zeros([self.labels.shape[0], self.labels.shape[1]*len(datasets)])*np.nan + new_labels = np.zeros([self.labels.shape[0], self.labels.shape[1] * len(datasets)]) * np.nan for i, shift in enumerate(self.which_dataset): size = self.labels.shape[1] - new_labels[i, shift*size:shift*size+size] = self.labels[i] + new_labels[i, shift * size:shift * size + size] = self.labels[i] self.labels = new_labels try: @@ -209,11 +209,10 @@ def __setattr__(self, name, value): object.__setattr__(self, name, value) - def string(self): s = self.__class__.__name__ + " num_samples={}\n".format(len(self)) for i, d in enumerate(self.datasets): - if i < len(self.datasets)-1: + if i < len(self.datasets) - 1: s += "├{} ".format(i) + d.string().replace("\n", "\n| ") + "\n" else: s += "└{} ".format(i) + d.string().replace("\n", "\n ") + "\n" @@ -223,7 +222,7 @@ def __len__(self): return self.length def __getitem__(self, idx): - item = self.datasets[int(self.which_dataset[idx])][idx - int(self.offset[idx])] + item = self.datasets[int(self.which_dataset[idx])][idx - int(self.offset[idx])] item["lab"] = self.labels[idx] item["source"] = self.which_dataset[idx] return item @@ -303,6 +302,7 @@ class NIH_Dataset(Dataset): Download resized (224x224) images here: https://academictorrents.com/details/e615d3aebce373f1dc8bd9d11064da55bdadede0 """ + def __init__(self, imgpath, csvpath=os.path.join(datapath, "Data_Entry_2017_v2020.csv.gz"), @@ -314,7 +314,7 @@ def __init__(self, seed=0, unique_patients=True, pathology_masks=False - ): + ): super(NIH_Dataset, self).__init__() np.random.seed(seed) # Reset the seed so all runs are the same. 
@@ -347,8 +347,8 @@ def __init__(self, ####### pathology masks ######## # load nih pathology masks self.pathology_maskscsv = pd.read_csv(bbox_list_path, - names=["Image Index","Finding Label","x","y","w","h","_1","_2","_3"], - skiprows=1) + names=["Image Index", "Finding Label", "x", "y", "w", "h", "_1", "_2", "_3"], + skiprows=1) # change label name to match self.pathology_maskscsv.loc[self.pathology_maskscsv["Finding Label"] == "Infiltrate", "Finding Label"] = "Infiltration" @@ -363,17 +363,17 @@ def __init__(self, self.labels = np.asarray(self.labels).T self.labels = self.labels.astype(np.float32) - ########## add consistent csv values + # add consistent csv values # offset_day_int - #self.csv["offset_day_int"] = + # self.csv["offset_day_int"] = # patientid self.csv["patientid"] = self.csv["Patient ID"].astype(str) - + # age - self.csv['age_years'] = self.csv['Patient Age']*1.0 - + self.csv['age_years'] = self.csv['Patient Age'] * 1.0 + # sex self.csv['sex_male'] = self.csv['Patient Gender'] == 'M' self.csv['sex_female'] = self.csv['Patient Gender'] == 'F' @@ -405,7 +405,7 @@ def __getitem__(self, idx): def get_mask_dict(self, image_name, this_size): base_size = 1024 - scale = this_size/base_size + scale = this_size / base_size images_with_masks = self.pathology_maskscsv[self.pathology_maskscsv["Image Index"] == image_name] path_mask = {} @@ -415,11 +415,11 @@ def get_mask_dict(self, image_name, this_size): # Don't add masks for labels we don't have if row["Finding Label"] in self.pathologies: - mask = np.zeros([this_size,this_size]) - xywh = np.asarray([row.x,row.y,row.w,row.h]) - xywh = xywh*scale + mask = np.zeros([this_size, this_size]) + xywh = np.asarray([row.x, row.y, row.w, row.h]) + xywh = xywh * scale xywh = xywh.astype(int) - mask[xywh[1]:xywh[1]+xywh[3],xywh[0]:xywh[0]+xywh[2]] = 1 + mask[xywh[1]:xywh[1] + xywh[3], xywh[0]:xywh[0] + xywh[2]] = 1 # Resize so image resizing works mask = mask[None, :, :] @@ -447,6 +447,7 @@ class RSNA_Pneumonia_Dataset(Dataset): JPG files stored here: https://academictorrents.com/details/95588a735c9ae4d123f3ca408e56570409bcf2a9 """ + def __init__(self, imgpath, csvpath=os.path.join(datapath, "kaggle_stage_2_train_labels.csv.zip"), @@ -459,7 +460,7 @@ def __init__(self, unique_patients=True, pathology_masks=False, extension=".jpg" - ): + ): super(RSNA_Pneumonia_Dataset, self).__init__() np.random.seed(seed) # Reset the seed so all runs are the same. @@ -473,8 +474,7 @@ def __init__(self, self.pathologies = sorted(self.pathologies) self.extension = extension - self.use_pydicom=( extension == ".dcm" ) - + self.use_pydicom = (extension == ".dcm") # Load data self.csvpath = csvpath @@ -498,7 +498,7 @@ def __init__(self, # Get our classes. 
self.labels = [] self.labels.append(self.csv["Target"].values) - self.labels.append(self.csv["Target"].values) #same labels for both + self.labels.append(self.csv["Target"].values) # same labels for both # set if we have masks self.csv["has_masks"] = ~np.isnan(self.csv["x"]) @@ -506,7 +506,7 @@ def __init__(self, self.labels = np.asarray(self.labels).T self.labels = self.labels.astype(np.float32) - ########## add consistent csv values + # add consistent csv values # offset_day_int # TODO: merge with NIH metadata to get dates for images @@ -534,7 +534,7 @@ def __getitem__(self, idx): except ImportError as e: raise Exception("Please install pydicom to work with this dataset") - img=pydicom.filereader.dcmread(img_path).pixel_array + img = pydicom.filereader.dcmread(img_path).pixel_array else: img = imread(img_path) @@ -551,24 +551,24 @@ def __getitem__(self, idx): def get_mask_dict(self, image_name, this_size): base_size = 1024 - scale = this_size/base_size + scale = this_size / base_size images_with_masks = self.raw_csv[self.raw_csv["patientId"] == image_name] path_mask = {} # All masks are for both pathologies for patho in ["Pneumonia", "Lung Opacity"]: - mask = np.zeros([this_size,this_size]) + mask = np.zeros([this_size, this_size]) # Don't add masks for labels we don't have if patho in self.pathologies: for i in range(len(images_with_masks)): row = images_with_masks.iloc[i] - xywh = np.asarray([row.x,row.y,row.width,row.height]) - xywh = xywh*scale + xywh = np.asarray([row.x, row.y, row.width, row.height]) + xywh = xywh * scale xywh = xywh.astype(int) - mask[xywh[1]:xywh[1]+xywh[3],xywh[0]:xywh[0]+xywh[2]] = 1 + mask[xywh[1]:xywh[1] + xywh[3], xywh[0]:xywh[0] + xywh[2]] = 1 # Resize so image resizing works mask = mask[None, :, :] @@ -595,6 +595,7 @@ class NIH_Google_Dataset(Dataset): NIH data can be downloaded here: https://academictorrents.com/details/e615d3aebce373f1dc8bd9d11064da55bdadede0 """ + def __init__(self, imgpath, csvpath=os.path.join(datapath, "google2019_nih-chest-xray-labels.csv.gz"), @@ -604,7 +605,7 @@ def __init__(self, nrows=None, seed=0, unique_patients=True - ): + ): super(NIH_Google_Dataset, self).__init__() np.random.seed(seed) # Reset the seed so all runs are the same. @@ -688,6 +689,7 @@ class PC_Dataset(Dataset): Download resized (224x224) images here (recropped): https://academictorrents.com/details/96ebb4f92b85929eadfb16761f310a6d04105797 """ + def __init__(self, imgpath, csvpath=os.path.join(datapath, "PADCHEST_chest_x_ray_images_labels_160K_01.02.19.csv.gz"), @@ -697,7 +699,7 @@ def __init__(self, flat_dir=True, seed=0, unique_patients=True - ): + ): super(PC_Dataset, self).__init__() np.random.seed(seed) # Reset the seed so all runs are the same. 
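The get_mask_dict implementations above (NIH and RSNA) share one rasterization step: scale the box from base_size=1024 annotation coordinates by this_size/base_size, then fill the slice. Isolated as a standalone sketch:

    import numpy as np

    def bbox_to_mask(x, y, w, h, this_size=224, base_size=1024):
        # boxes are annotated in base_size-pixel coordinates
        scale = this_size / base_size
        xywh = (np.asarray([x, y, w, h]) * scale).astype(int)
        mask = np.zeros([this_size, this_size])
        mask[xywh[1]:xywh[1] + xywh[3], xywh[0]:xywh[0] + xywh[2]] = 1
        return mask

    print(bbox_to_mask(512, 256, 128, 64).sum())  # 28 * 14 = 392.0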
@@ -705,13 +707,13 @@ def __init__(self, self.pathologies = ["Atelectasis", "Consolidation", "Infiltration", "Pneumothorax", "Edema", "Emphysema", "Fibrosis", "Effusion", "Pneumonia", "Pleural_Thickening", - "Cardiomegaly", "Nodule", "Mass", "Hernia","Fracture", + "Cardiomegaly", "Nodule", "Mass", "Hernia", "Fracture", "Granuloma", "Flattened Diaphragm", "Bronchiectasis", "Aortic Elongation", "Scoliosis", "Hilar Enlargement", "Tuberculosis", "Air Trapping", "Costophrenic Angle Blunting", "Aortic Atheromatosis", "Hemidiaphragm Elevation", - "Support Devices", "Tube'"] # the Tube' is intentional + "Support Devices", "Tube'"] # the Tube' is intentional self.pathologies = sorted(self.pathologies) @@ -731,7 +733,7 @@ def __init__(self, "pulmonary artery enlargement"] mapping["Support Devices"] = ["device", "pacemaker"] - mapping["Tube'"] = ["stent'"] ## the ' is to select findings which end in that word + mapping["Tube'"] = ["stent'"] # the ' is to select findings which end in that word self.imgpath = imgpath self.transform = transform @@ -743,7 +745,7 @@ def __init__(self, self.csv = pd.read_csv(self.csvpath, low_memory=False) # Standardize view names - self.csv.loc[self.csv["Projection"].isin(["AP_horizontal"]),"Projection"] = "AP Supine" + self.csv.loc[self.csv["Projection"].isin(["AP_horizontal"]), "Projection"] = "AP Supine" self.csv["view"] = self.csv['Projection'] self.limit_to_selected_views(views) @@ -769,7 +771,7 @@ def __init__(self, self.csv = self.csv.groupby("PatientID").first().reset_index() # Filter out age < 10 (paper published 2019) - self.csv = self.csv[(2019-self.csv.PatientBirth > 10)] + self.csv = self.csv[(2019 - self.csv.PatientBirth > 10)] # Get our classes. self.labels = [] @@ -785,18 +787,18 @@ def __init__(self, self.pathologies[self.pathologies.index("Tube'")] = "Tube" - ########## add consistent csv values + # add consistent csv values # offset_day_int dt = pd.to_datetime(self.csv["StudyDate_DICOM"], format="%Y%m%d") - self.csv["offset_day_int"] = dt.astype(int)// 10**9 // 86400 + self.csv["offset_day_int"] = dt.astype(int) // 10**9 // 86400 # patientid self.csv["patientid"] = self.csv["PatientID"].astype(str) - + # age self.csv['age_years'] = (2017 - self.csv['PatientBirth']) - + # sex self.csv['sex_male'] = self.csv['PatientSex_DICOM'] == 'M' self.csv['sex_female'] = self.csv['PatientSex_DICOM'] == 'F' @@ -813,7 +815,7 @@ def __getitem__(self, idx): sample["lab"] = self.labels[idx] imgid = self.csv['ImageID'].iloc[idx] - img_path = os.path.join(self.imgpath,imgid) + img_path = os.path.join(self.imgpath, imgid) img = imread(img_path) sample["img"] = normalize(img, maxval=65535, reshape=True) @@ -823,6 +825,7 @@ def __getitem__(self, idx): return sample + class CheX_Dataset(Dataset): """CheXpert Dataset @@ -838,6 +841,7 @@ class CheX_Dataset(Dataset): A small validation set is provided with the data as well, but is so tiny, it not included here. """ + def __init__(self, imgpath, csvpath=os.path.join(datapath, "chexpert_train.csv.gz"), @@ -847,7 +851,7 @@ def __init__(self, flat_dir=True, seed=0, unique_patients=True - ): + ): super(CheX_Dataset, self).__init__() np.random.seed(seed) # Reset the seed so all runs are the same. 
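The offset_day_int line above relies on pandas datetimes being int64 nanoseconds since the epoch, so integer division by 10**9 yields seconds and by 86400 yields whole days:

    import pandas as pd

    dt = pd.to_datetime(pd.Series(["20190102"]), format="%Y%m%d")
    # int64 nanoseconds -> seconds -> days since 1970-01-01
    offset_day_int = dt.astype("int64") // 10**9 // 86400
    print(int(offset_day_int.iloc[0]))  # 17898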
@@ -875,14 +879,14 @@ def __init__(self, self.csv = pd.read_csv(self.csvpath) self.views = views - self.csv["view"] = self.csv["Frontal/Lateral"] # Assign view column - self.csv.loc[(self.csv["view"] == "Frontal"), "view"] = self.csv["AP/PA"] # If Frontal change with the corresponding value in the AP/PA column otherwise remains Lateral - self.csv["view"] = self.csv["view"].replace({'Lateral': "L"}) # Rename Lateral with L + self.csv["view"] = self.csv["Frontal/Lateral"] # Assign view column + self.csv.loc[(self.csv["view"] == "Frontal"), "view"] = self.csv["AP/PA"] # If Frontal change with the corresponding value in the AP/PA column otherwise remains Lateral + self.csv["view"] = self.csv["view"].replace({'Lateral': "L"}) # Rename Lateral with L self.limit_to_selected_views(views) if unique_patients: - self.csv["PatientID"] = self.csv["Path"].str.extract(pat = r'(patient\d+)') + self.csv["PatientID"] = self.csv["Path"].str.extract(pat=r'(patient\d+)') self.csv = self.csv.groupby("PatientID").first().reset_index() # Get our classes. @@ -904,7 +908,7 @@ def __init__(self, # Rename pathologies self.pathologies = list(np.char.replace(self.pathologies, "Pleural Effusion", "Effusion")) - ########## add consistent csv values + # add consistent csv values # offset_day_int @@ -917,20 +921,18 @@ def __init__(self, raise NotImplemented patientid = patientid.str.split("/study", expand=True)[0] - patientid = patientid.str.replace("patient","") - + patientid = patientid.str.replace("patient", "") + # patientid self.csv["patientid"] = patientid - + # age - self.csv['age_years'] = self.csv['Age']*1.0 + self.csv['age_years'] = self.csv['Age'] * 1.0 self.csv['Age'][(self.csv['Age'] == 0)] = None - + # sex self.csv['sex_male'] = self.csv['Sex'] == 'Male' self.csv['sex_female'] = self.csv['Sex'] == 'Female' - - def string(self): return self.__class__.__name__ + " num_samples={} views={} data_aug={}".format(len(self), self.views, self.data_aug) @@ -944,7 +946,7 @@ def __getitem__(self, idx): sample["lab"] = self.labels[idx] imgid = self.csv['Path'].iloc[idx] - imgid = imgid.replace("CheXpert-v1.0-small/","") + imgid = imgid.replace("CheXpert-v1.0-small/", "") img_path = os.path.join(self.imgpath, imgid) img = imread(img_path) @@ -955,6 +957,7 @@ def __getitem__(self, idx): return sample + class MIMIC_Dataset(Dataset): """MIMIC-CXR Dataset @@ -967,6 +970,7 @@ class MIMIC_Dataset(Dataset): Dataset website here: https://physionet.org/content/mimic-cxr-jpg/2.0.0/ """ + def __init__(self, imgpath, csvpath, @@ -977,7 +981,7 @@ def __init__(self, flat_dir=True, seed=0, unique_patients=True - ): + ): super(MIMIC_Dataset, self).__init__() np.random.seed(seed) # Reset the seed so all runs are the same. 
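The CheXpert Path handling above derives a patient key from the file path. A small demonstration, assuming the standard CheXpert-v1.0-small layout (the sample path is illustrative):

    import pandas as pd

    csv = pd.DataFrame({"Path": [
        "CheXpert-v1.0-small/train/patient00001/study1/view1_frontal.jpg"]})
    # grouping key used when unique_patients=True
    csv["PatientID"] = csv["Path"].str.extract(pat=r'(patient\d+)')
    # the consistent "patientid" column then strips the prefix
    patientid = csv["PatientID"].str.replace("patient", "", regex=False)
    print(csv["PatientID"].iloc[0], patientid.iloc[0])  # patient00001 00001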
@@ -1036,7 +1040,7 @@ def __init__(self, # Rename pathologies self.pathologies = np.char.replace(self.pathologies, "Pleural Effusion", "Effusion") - ########## add consistent csv values + # add consistent csv values # offset_day_int self.csv["offset_day_int"] = self.csv["StudyDate"] @@ -1069,6 +1073,7 @@ def __getitem__(self, idx): return sample + class Openi_Dataset(Dataset): """OpenI Dataset @@ -1086,6 +1091,7 @@ class Openi_Dataset(Dataset): Download images: https://academictorrents.com/details/5a3a439df24931f410fac269b87b050203d9467d """ + def __init__(self, imgpath, xmlpath=os.path.join(datapath, "NLMCXR_reports.tgz"), dicomcsv_path=os.path.join(datapath, "nlmcxr_dicom_metadata.csv.gz"), @@ -1097,7 +1103,7 @@ def __init__(self, imgpath, nrows=None, seed=0, unique_patients=True - ): + ): super(Openi_Dataset, self).__init__() import xml @@ -1184,10 +1190,10 @@ def __init__(self, imgpath, self.pathologies = np.char.replace(self.pathologies, "Opacity", "Lung Opacity") self.pathologies = np.char.replace(self.pathologies, "Lesion", "Lung Lesion") - ########## add consistent csv values + # add consistent csv values # offset_day_int - #self.csv["offset_day_int"] = + # self.csv["offset_day_int"] = # patientid self.csv["patientid"] = self.csv["uid"].astype(str) @@ -1204,7 +1210,7 @@ def __getitem__(self, idx): sample["lab"] = self.labels[idx] imageid = self.csv.iloc[idx].imageid - img_path = os.path.join(self.imgpath,imageid + ".png") + img_path = os.path.join(self.imgpath, imageid + ".png") img = imread(img_path) sample["img"] = normalize(img, maxval=255, reshape=True) @@ -1232,6 +1238,7 @@ class COVID19_Dataset(Dataset): """ dataset_url = "https://github.com/ieee8023/covid-chestxray-dataset" + def __init__(self, imgpath=os.path.join(thispath, "covid-chestxray-dataset", "images"), csvpath=os.path.join(thispath, "covid-chestxray-dataset", "metadata.csv"), @@ -1243,7 +1250,7 @@ def __init__(self, seed=0, unique_patients=True, semantic_masks=False - ): + ): super(COVID19_Dataset, self).__init__() np.random.seed(seed) # Reset the seed so all runs are the same. @@ -1285,7 +1292,7 @@ def __init__(self, temp = zipfile.ZipFile(self.semantic_masks_v7labs_lungs_path) self.semantic_masks_v7labs_lungs_namelist = temp.namelist() - ########## add consistent csv values + # add consistent csv values # offset_day_int self.csv["offset_day_int"] = self.csv["offset"] @@ -1362,7 +1369,7 @@ def __init__(self, data_aug=None, seed=0, views=["PA"] - ): + ): """ Args: img_path (str): Path to `MontgomerySet` or `ChinaSet_AllFiles` folder @@ -1389,7 +1396,7 @@ def __init__(self, self.csv["view"] = "PA" self.limit_to_selected_views(views) - self.labels = self.csv["label"].values.reshape(-1,1) + self.labels = self.csv["label"].values.reshape(-1, 1) self.pathologies = ["Tuberculosis"] def string(self): @@ -1415,6 +1422,7 @@ def __getitem__(self, idx): return sample + class SIIM_Pneumothorax_Dataset(Dataset): """SIIM Pneumothorax Dataset @@ -1433,7 +1441,7 @@ def __init__(self, seed=0, unique_patients=True, pathology_masks=False - ): + ): super(SIIM_Pneumothorax_Dataset, self).__init__() np.random.seed(seed) # Reset the seed so all runs are the same. 
self.imgpath = imgpath @@ -1462,7 +1470,7 @@ def __init__(self, file_map = {} for root, directories, files in os.walk(self.imgpath, followlinks=False): for filename in files: - filePath = os.path.join(root,filename) + filePath = os.path.join(root, filename) file_map[filename] = filePath _cache_dict["siim_file_map"] = file_map self.file_map = _cache_dict["siim_file_map"] @@ -1506,7 +1514,7 @@ def get_pathology_mask_dict(self, image_name, this_size): # From kaggle code def rle2mask(rle, width, height): - mask= np.zeros(width* height) + mask = np.zeros(width * height) array = np.asarray([int(x) for x in rle.split()]) starts = array[0::2] lengths = array[1::2] @@ -1514,7 +1522,7 @@ def rle2mask(rle, width, height): current_position = 0 for index, start in enumerate(starts): current_position += start - mask[current_position:current_position+lengths[index]] = 1 + mask[current_position:current_position + lengths[index]] = 1 current_position += lengths[index] return mask.reshape(width, height) @@ -1522,17 +1530,17 @@ def rle2mask(rle, width, height): if len(images_with_masks) > 0: # Using a for loop so it is consistent with the other code for patho in ["Pneumothorax"]: - mask = np.zeros([this_size,this_size]) + mask = np.zeros([this_size, this_size]) # don't add masks for labels we don't have if patho in self.pathologies: for i in range(len(images_with_masks)): row = images_with_masks.iloc[i] - mask = rle2mask(row[" EncodedPixels"],base_size,base_size) + mask = rle2mask(row[" EncodedPixels"], base_size, base_size) mask = mask.T mask = skimage.transform.resize(mask, (this_size, this_size), mode='constant', order=0) - mask = mask.round() #make 0,1 + mask = mask.round() # make 0,1 # reshape so image resizing works mask = mask[None, :, :] @@ -1550,6 +1558,7 @@ class VinBrain_Dataset(Dataset): https://www.kaggle.com/c/vinbigdata-chest-xray-abnormalities-detection """ + def __init__(self, imgpath, csvpath=os.path.join(datapath, "vinbigdata-train.csv.gz"), @@ -1558,7 +1567,7 @@ def __init__(self, data_aug=None, seed=0, pathology_masks=False - ): + ): super(VinBrain_Dataset, self).__init__() np.random.seed(seed) # Reset the seed so all runs are the same. @@ -1641,7 +1650,7 @@ def __getitem__(self, idx): bitdepth = 8 if mode == "MONOCHROME1": - img = -1*img + 2**float(bitdepth) + img = -1 * img + 2**float(bitdepth) elif mode == "MONOCHROME2": pass else: @@ -1686,13 +1695,13 @@ class StonyBrookCOVID_Dataset(Dataset): """ def __init__(self, - imgpath, # path to CXR_images_scored - csvpath, # path to ralo-dataset-metadata.csv + imgpath, # path to CXR_images_scored + csvpath, # path to ralo-dataset-metadata.csv transform=None, data_aug=None, views=["AP"], seed=0 - ): + ): super(StonyBrookCOVID_Dataset, self).__init__() np.random.seed(seed) # Reset the seed so all runs are the same. 
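A quick check of the rle2mask helper above; note the encoding uses relative starts, i.e. each start is an offset from the end of the previous run over the flattened width*height array:

    import numpy as np

    def rle2mask(rle, width, height):
        # same logic as the kaggle-derived helper above
        mask = np.zeros(width * height)
        array = np.asarray([int(x) for x in rle.split()])
        starts, lengths = array[0::2], array[1::2]
        current_position = 0
        for start, length in zip(starts, lengths):
            current_position += start  # relative start
            mask[current_position:current_position + length] = 1
            current_position += length
        return mask.reshape(width, height)

    print(rle2mask("3 2 4 3", 4, 4))  # flat pixels 3-4 and 9-11 set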
@@ -1705,10 +1714,10 @@ def __init__(self, self.csv = pd.read_csv(self.csvpath, skiprows=1) self.MAXVAL = 255 # Range [0 255] - self.pathologies = ["Geographic Extent","Lung Opacity"] + self.pathologies = ["Geographic Extent", "Lung Opacity"] - self.csv["Geographic Extent"] = (self.csv["Total GEOGRAPHIC"] + self.csv["Total GEOGRAPHIC.1"])/2 - self.csv["Lung Opacity"] = (self.csv["Total OPACITY"] + self.csv["Total OPACITY.1"])/2 + self.csv["Geographic Extent"] = (self.csv["Total GEOGRAPHIC"] + self.csv["Total GEOGRAPHIC.1"]) / 2 + self.csv["Lung Opacity"] = (self.csv["Total OPACITY"] + self.csv["Total OPACITY.1"]) / 2 self.labels = [] self.labels.append(self.csv["Geographic Extent"]) @@ -1717,13 +1726,13 @@ def __init__(self, self.labels = np.asarray(self.labels).T self.labels = self.labels.astype(np.float32) - ########## add consistent csv values + # add consistent csv values # offset_day_int - date_col = self.csv["Exam_DateTime"].str.split("_",expand=True)[0] + date_col = self.csv["Exam_DateTime"].str.split("_", expand=True)[0] dt = pd.to_datetime(date_col, format="%Y%m%d") - self.csv["offset_day_int"] = dt.astype(int)// 10**9 // 86400 + self.csv["offset_day_int"] = dt.astype(int) // 10**9 // 86400 # patientid self.csv["patientid"] = self.csv["Subject_ID"].astype(str) @@ -1753,6 +1762,7 @@ def __getitem__(self, idx): return sample + class ObjectCXR_Dataset(Dataset): """ObjectCXR Dataset @@ -1762,13 +1772,14 @@ class ObjectCXR_Dataset(Dataset): https://academictorrents.com/details/fdc91f11d7010f7259a05403fc9d00079a09f5d5 """ + def __init__(self, imgzippath, csvpath, transform=None, data_aug=None, seed=0 - ): + ): super(ObjectCXR_Dataset, self).__init__() np.random.seed(seed) # Reset the seed so all runs are the same. @@ -1838,11 +1849,11 @@ def __call__(self, img): warnings.simplefilter("ignore") return skimage.transform.resize(img, (1, self.size, self.size), mode='constant', preserve_range=True).astype(np.float32) elif self.engine == "cv2": - import cv2 # pip install opencv-python - return cv2.resize(img[0,:,:], + import cv2 # pip install opencv-python + return cv2.resize(img[0, :, :], (self.size, self.size), - interpolation = cv2.INTER_AREA - ).reshape(1,self.size,self.size).astype(np.float32) + interpolation=cv2.INTER_AREA + ).reshape(1, self.size, self.size).astype(np.float32) else: raise Exception("Unknown engine, Must be skimage (default) or cv2.") @@ -1850,7 +1861,7 @@ def __call__(self, img): class XRayCenterCrop(object): def crop_center(self, img): _, y, x = img.shape - crop_size = np.min([y,x]) + crop_size = np.min([y, x]) startx = x // 2 - (crop_size // 2) starty = y // 2 - (crop_size // 2) return img[:, starty:starty + crop_size, startx:startx + crop_size] @@ -1865,6 +1876,7 @@ class CovariateDataset(Dataset): Viviano et al. Saliency is a Possible Red Herring When Diagnosing Poor Generalization https://arxiv.org/abs/1910.00199 """ + def __init__(self, d1, d1_target, d2, d2_target, @@ -1874,7 +1886,7 @@ def __init__(self, nsamples=None, splits=[0.5, 0.25, 0.25], verbose=False - ): + ): super(CovariateDataset, self).__init__() self.splits = np.array(splits) @@ -1889,8 +1901,8 @@ def __init__(self, np.random.seed(seed) # Reset the seed so all runs are the same. 
all_imageids = np.concatenate([np.arange(len(self.d1)), - np.arange(len(self.d2))]).astype(int) - + np.arange(len(self.d2))]).astype(int) + all_idx = np.arange(len(all_imageids)).astype(int) all_labels = np.concatenate([d1_target, @@ -1900,11 +1912,11 @@ def __init__(self, np.ones(len(self.d2))]).astype(int) idx_sick = all_labels == 1 - n_per_category = np.min([sum(idx_sick[all_site==0]), - sum(idx_sick[all_site==1]), - sum(~idx_sick[all_site==0]), - sum(~idx_sick[all_site==1])]) - + n_per_category = np.min([sum(idx_sick[all_site == 0]), + sum(idx_sick[all_site == 1]), + sum(~idx_sick[all_site == 0]), + sum(~idx_sick[all_site == 1])]) + all_csv = pd.concat([d1.csv, d2.csv]) all_csv['site'] = all_site all_csv['label'] = all_labels @@ -1912,24 +1924,24 @@ def __init__(self, if verbose: print("n_per_category={}".format(n_per_category)) - all_0_neg = all_idx[np.where((all_site==0) & (all_labels==0))] + all_0_neg = all_idx[np.where((all_site == 0) & (all_labels == 0))] all_0_neg = np.random.choice(all_0_neg, n_per_category, replace=False) - all_0_pos = all_idx[np.where((all_site==0) & (all_labels==1))] + all_0_pos = all_idx[np.where((all_site == 0) & (all_labels == 1))] all_0_pos = np.random.choice(all_0_pos, n_per_category, replace=False) - all_1_neg = all_idx[np.where((all_site==1) & (all_labels==0))] + all_1_neg = all_idx[np.where((all_site == 1) & (all_labels == 0))] all_1_neg = np.random.choice(all_1_neg, n_per_category, replace=False) - all_1_pos = all_idx[np.where((all_site==1) & (all_labels==1))] + all_1_pos = all_idx[np.where((all_site == 1) & (all_labels == 1))] all_1_pos = np.random.choice(all_1_pos, n_per_category, replace=False) # TRAIN train_0_neg = np.random.choice( - all_0_neg, int(n_per_category*ratio*splits[0]*2), replace=False) + all_0_neg, int(n_per_category * ratio * splits[0] * 2), replace=False) train_0_pos = np.random.choice( - all_0_pos, int(n_per_category*(1-ratio)*splits[0]*2), replace=False) + all_0_pos, int(n_per_category * (1 - ratio) * splits[0] * 2), replace=False) train_1_neg = np.random.choice( - all_1_neg, int(n_per_category*(1-ratio)*splits[0]*2), replace=False) + all_1_neg, int(n_per_category * (1 - ratio) * splits[0] * 2), replace=False) train_1_pos = np.random.choice( - all_1_pos, int(n_per_category*ratio*splits[0]*2), replace=False) + all_1_pos, int(n_per_category * ratio * splits[0] * 2), replace=False) # REDUCE POST-TRAIN all_0_neg = np.setdiff1d(all_0_neg, train_0_neg) @@ -1939,23 +1951,23 @@ def __init__(self, if verbose: print("TRAIN (ratio={:.2}): neg={}, pos={}, d1_pos/neg={}/{}, d2_pos/neg={}/{}".format( - ratio, - len(train_0_neg)+len(train_1_neg), - len(train_0_pos)+len(train_1_pos), - len(train_0_pos), - len(train_0_neg), - len(train_1_pos), - len(train_1_neg))) + ratio, + len(train_0_neg) + len(train_1_neg), + len(train_0_pos) + len(train_1_pos), + len(train_0_pos), + len(train_0_neg), + len(train_1_pos), + len(train_1_neg))) # VALID valid_0_neg = np.random.choice( - all_0_neg, int(n_per_category*(1-ratio)*splits[1]*2), replace=False) + all_0_neg, int(n_per_category * (1 - ratio) * splits[1] * 2), replace=False) valid_0_pos = np.random.choice( - all_0_pos, int(n_per_category*ratio*splits[1]*2), replace=False) + all_0_pos, int(n_per_category * ratio * splits[1] * 2), replace=False) valid_1_neg = np.random.choice( - all_1_neg, int(n_per_category*ratio*splits[1]*2), replace=False) + all_1_neg, int(n_per_category * ratio * splits[1] * 2), replace=False) valid_1_pos = np.random.choice( - all_1_pos, int(n_per_category*(1-ratio)*splits[1]*2), 
replace=False) + all_1_pos, int(n_per_category * (1 - ratio) * splits[1] * 2), replace=False) # REDUCE POST-VALID all_0_neg = np.setdiff1d(all_0_neg, valid_0_neg) @@ -1965,13 +1977,13 @@ def __init__(self, if verbose: print("VALID (ratio={:.2}): neg={}, pos={}, d1_pos/neg={}/{}, d2_pos/neg={}/{}".format( - 1-ratio, - len(valid_0_neg)+len(valid_1_neg), - len(valid_0_pos)+len(valid_1_pos), - len(valid_0_pos), - len(valid_0_neg), - len(valid_1_pos), - len(valid_1_neg))) + 1 - ratio, + len(valid_0_neg) + len(valid_1_neg), + len(valid_0_pos) + len(valid_1_pos), + len(valid_0_pos), + len(valid_0_neg), + len(valid_1_pos), + len(valid_1_neg))) # TEST test_0_neg = all_0_neg @@ -1981,21 +1993,20 @@ def __init__(self, if verbose: print("TEST (ratio={:.2}): neg={}, pos={}, d1_pos/neg={}/{}, d2_pos/neg={}/{}".format( - 1-ratio, - len(test_0_neg)+len(test_1_neg), - len(test_0_pos)+len(test_1_pos), - len(test_0_pos), - len(test_0_neg), - len(test_1_pos), - len(test_1_neg))) - + 1 - ratio, + len(test_0_neg) + len(test_1_neg), + len(test_0_pos) + len(test_1_pos), + len(test_0_pos), + len(test_0_neg), + len(test_1_pos), + len(test_1_neg))) def _reduce_nsamples(nsamples, a, b, c, d): if nsamples: - a = a[:int(np.floor(nsamples/4))] - b = b[:int(np.ceil(nsamples/4))] - c = c[:int(np.ceil(nsamples/4))] - d = d[:int(np.floor(nsamples/4))] + a = a[:int(np.floor(nsamples / 4))] + b = b[:int(np.ceil(nsamples / 4))] + c = c[:int(np.ceil(nsamples / 4))] + d = d[:int(np.floor(nsamples / 4))] return (a, b, c, d) @@ -2014,10 +2025,9 @@ def _reduce_nsamples(nsamples, a, b, c, d): self.select_idx = np.concatenate([a, b, c, d]) self.imageids = all_imageids[self.select_idx] self.pathologies = ["Custom"] - self.labels = all_labels[self.select_idx].reshape(-1,1) + self.labels = all_labels[self.select_idx].reshape(-1, 1) self.site = all_site[self.select_idx] self.csv = all_csv.iloc[self.select_idx] - def __repr__(self): pprint.pprint(self.totals()) diff --git a/torchxrayvision/models.py b/torchxrayvision/models.py index 3de4819..d5a90fa 100644 --- a/torchxrayvision/models.py +++ b/torchxrayvision/models.py @@ -9,7 +9,8 @@ import numpy as np from collections import OrderedDict from . 
import datasets -import warnings; warnings.filterwarnings("ignore") +import warnings +warnings.filterwarnings("ignore") model_urls = {} @@ -17,53 +18,53 @@ model_urls['all'] = { "description": 'This model was trained on the datasets: nih-pc-chex-mimic_ch-google-openi-rsna and is described here: https://arxiv.org/abs/2002.02497', "weights_url": 'https://github.com/mlmed/torchxrayvision/releases/download/v1/nih-pc-chex-mimic_ch-google-openi-kaggle-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt', - "labels":[ 'Atelectasis', 'Consolidation', 'Infiltration', 'Pneumothorax', 'Edema', 'Emphysema', 'Fibrosis', 'Effusion', 'Pneumonia', 'Pleural_Thickening', 'Cardiomegaly', 'Nodule', 'Mass', 'Hernia', 'Lung Lesion', 'Fracture', 'Lung Opacity', 'Enlarged Cardiomediastinum' ], - "op_threshs":[0.07422872, 0.038290843, 0.09814756, 0.0098118475, 0.023601074, 0.0022490358, 0.010060724, 0.103246614, 0.056810737, 0.026791653, 0.050318155, 0.023985857, 0.01939503, 0.042889766, 0.053369623, 0.035975814, 0.20204692, 0.05015312], - "ppv80_thres":[0.72715247, 0.8885005, 0.92493945, 0.6527224, 0.68707734, 0.46127197, 0.7272054, 0.6127343, 0.9878492, 0.61979693, 0.66309816, 0.7853459, 0.930661, 0.93645346, 0.6788558, 0.6547198, 0.61614525, 0.8489876] + "labels": ['Atelectasis', 'Consolidation', 'Infiltration', 'Pneumothorax', 'Edema', 'Emphysema', 'Fibrosis', 'Effusion', 'Pneumonia', 'Pleural_Thickening', 'Cardiomegaly', 'Nodule', 'Mass', 'Hernia', 'Lung Lesion', 'Fracture', 'Lung Opacity', 'Enlarged Cardiomediastinum'], + "op_threshs": [0.07422872, 0.038290843, 0.09814756, 0.0098118475, 0.023601074, 0.0022490358, 0.010060724, 0.103246614, 0.056810737, 0.026791653, 0.050318155, 0.023985857, 0.01939503, 0.042889766, 0.053369623, 0.035975814, 0.20204692, 0.05015312], + "ppv80_thres": [0.72715247, 0.8885005, 0.92493945, 0.6527224, 0.68707734, 0.46127197, 0.7272054, 0.6127343, 0.9878492, 0.61979693, 0.66309816, 0.7853459, 0.930661, 0.93645346, 0.6788558, 0.6547198, 0.61614525, 0.8489876] } model_urls['densenet121-res224-all'] = model_urls['all'] model_urls['nih'] = { - "weights_url":'https://github.com/mlmed/torchxrayvision/releases/download/v1/nih-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt', - "labels":[ 'Atelectasis', 'Consolidation', 'Infiltration', 'Pneumothorax', 'Edema', 'Emphysema', 'Fibrosis', 'Effusion', 'Pneumonia', 'Pleural_Thickening', 'Cardiomegaly', 'Nodule', 'Mass', 'Hernia', '', '', '', '' ], - "op_threshs":[0.039117552, 0.0034529066, 0.11396341, 0.0057298196, 0.00045666535, 0.0018880932, 0.012037827, 0.038744126, 0.0037213727, 0.014730946, 0.016149804, 0.054241467, 0.037198864, 0.0004403434, np.nan, np.nan, np.nan, np.nan], + "weights_url": 'https://github.com/mlmed/torchxrayvision/releases/download/v1/nih-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt', + "labels": ['Atelectasis', 'Consolidation', 'Infiltration', 'Pneumothorax', 'Edema', 'Emphysema', 'Fibrosis', 'Effusion', 'Pneumonia', 'Pleural_Thickening', 'Cardiomegaly', 'Nodule', 'Mass', 'Hernia', '', '', '', ''], + "op_threshs": [0.039117552, 0.0034529066, 0.11396341, 0.0057298196, 0.00045666535, 0.0018880932, 0.012037827, 0.038744126, 0.0037213727, 0.014730946, 0.016149804, 0.054241467, 0.037198864, 0.0004403434, np.nan, np.nan, np.nan, np.nan], } model_urls['densenet121-res224-nih'] = model_urls['nih'] model_urls['pc'] = { - "weights_url":'https://github.com/mlmed/torchxrayvision/releases/download/v1/pc-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt', - "labels":[ 'Atelectasis', 'Consolidation', 
'Infiltration', 'Pneumothorax', 'Edema', 'Emphysema', 'Fibrosis', 'Effusion', 'Pneumonia', 'Pleural_Thickening', 'Cardiomegaly', 'Nodule', 'Mass', 'Hernia', '', 'Fracture', '', '' ], + "weights_url": 'https://github.com/mlmed/torchxrayvision/releases/download/v1/pc-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt', + "labels": ['Atelectasis', 'Consolidation', 'Infiltration', 'Pneumothorax', 'Edema', 'Emphysema', 'Fibrosis', 'Effusion', 'Pneumonia', 'Pleural_Thickening', 'Cardiomegaly', 'Nodule', 'Mass', 'Hernia', '', 'Fracture', '', ''], "op_threshs": [0.031012505, 0.013347598, 0.081435576, 0.001262615, 0.002587246, 0.0035944257, 0.0023071, 0.055412333, 0.044385884, 0.042766232, 0.043258056, 0.037629247, 0.005658899, 0.0091741895, np.nan, 0.026507627, np.nan, np.nan] } model_urls['densenet121-res224-pc'] = model_urls['pc'] model_urls['chex'] = { - "weights_url":'https://github.com/mlmed/torchxrayvision/releases/download/v1/chex-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt', - "labels":[ 'Atelectasis', 'Consolidation', '', 'Pneumothorax', 'Edema', '', '', 'Effusion', 'Pneumonia', '', 'Cardiomegaly', '', '', '', 'Lung Lesion', 'Fracture', 'Lung Opacity', 'Enlarged Cardiomediastinum' ], + "weights_url": 'https://github.com/mlmed/torchxrayvision/releases/download/v1/chex-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt', + "labels": ['Atelectasis', 'Consolidation', '', 'Pneumothorax', 'Edema', '', '', 'Effusion', 'Pneumonia', '', 'Cardiomegaly', '', '', '', 'Lung Lesion', 'Fracture', 'Lung Opacity', 'Enlarged Cardiomediastinum'], "op_threshs": [0.1988969, 0.05710573, np.nan, 0.0531293, 0.1435217, np.nan, np.nan, 0.27212676, 0.07749717, np.nan, 0.19712369, np.nan, np.nan, np.nan, 0.09932402, 0.09273402, 0.3270967, 0.10888247], } model_urls['densenet121-res224-chex'] = model_urls['chex'] model_urls['rsna'] = { - "weights_url":'https://github.com/mlmed/torchxrayvision/releases/download/v1/kaggle-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt', - "labels":[ '', '', '', '', '', '', '', '', 'Pneumonia', '', '', '', '', '', '', '', 'Lung Opacity', '' ], + "weights_url": 'https://github.com/mlmed/torchxrayvision/releases/download/v1/kaggle-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt', + "labels": ['', '', '', '', '', '', '', '', 'Pneumonia', '', '', '', '', '', '', '', 'Lung Opacity', ''], "op_threshs": [np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, 0.13486601, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, 0.13511065, np.nan] } model_urls['densenet121-res224-rsna'] = model_urls['rsna'] model_urls['mimic_nb'] = { - "weights_url":'https://github.com/mlmed/torchxrayvision/releases/download/v1/mimic_nb-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt', - "labels":[ 'Atelectasis', 'Consolidation', '', 'Pneumothorax', 'Edema', '', '', 'Effusion', 'Pneumonia', '', 'Cardiomegaly', '', '', '', 'Lung Lesion', 'Fracture', 'Lung Opacity', 'Enlarged Cardiomediastinum' ], + "weights_url": 'https://github.com/mlmed/torchxrayvision/releases/download/v1/mimic_nb-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt', + "labels": ['Atelectasis', 'Consolidation', '', 'Pneumothorax', 'Edema', '', '', 'Effusion', 'Pneumonia', '', 'Cardiomegaly', '', '', '', 'Lung Lesion', 'Fracture', 'Lung Opacity', 'Enlarged Cardiomediastinum'], "op_threshs": [0.08558747, 0.011884617, np.nan, 0.0040595434, 0.010733786, np.nan, np.nan, 0.118761964, 0.022924708, np.nan, 0.06358637, np.nan, np.nan, np.nan, 0.022143636, 0.017476924, 0.1258702, 0.014020768], } 
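These op_threshs arrays are consumed by op_norm at the bottom of this file: each sigmoid output is remapped so that its operating point lands at 0.5, and NaN thresholds are passed through at 0.5. A sketch of that normalization matching the non-NaN masking described there — illustrative, not a verbatim copy:

    import torch

    def op_norm_sketch(outputs, op_threshs):
        op_threshs = op_threshs.expand(outputs.shape[0], -1)
        out = torch.zeros(outputs.shape, device=outputs.device) + 0.5
        valid = ~torch.isnan(op_threshs)  # skip tasks without a threshold
        lo = (outputs < op_threshs) & valid
        hi = (outputs >= op_threshs) & valid
        out[lo] = outputs[lo] / (op_threshs[lo] * 2)                  # [0,t) -> [0,0.5)
        out[hi] = 1 - (1 - outputs[hi]) / ((1 - op_threshs[hi]) * 2)  # [t,1] -> [0.5,1]
        return out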
model_urls['densenet121-res224-mimic_nb'] = model_urls['mimic_nb'] model_urls['mimic_ch'] = { - "weights_url":'https://github.com/mlmed/torchxrayvision/releases/download/v1/mimic_ch-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt', - "labels":[ 'Atelectasis', 'Consolidation', '', 'Pneumothorax', 'Edema', '', '', 'Effusion', 'Pneumonia', '', 'Cardiomegaly', '', '', '', 'Lung Lesion', 'Fracture', 'Lung Opacity', 'Enlarged Cardiomediastinum' ], + "weights_url": 'https://github.com/mlmed/torchxrayvision/releases/download/v1/mimic_ch-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt', + "labels": ['Atelectasis', 'Consolidation', '', 'Pneumothorax', 'Edema', '', '', 'Effusion', 'Pneumonia', '', 'Cardiomegaly', '', '', '', 'Lung Lesion', 'Fracture', 'Lung Opacity', 'Enlarged Cardiomediastinum'], "op_threshs": [0.09121389, 0.010573786, np.nan, 0.005023008, 0.003698257, np.nan, np.nan, 0.08001232, 0.037242252, np.nan, 0.05006329, np.nan, np.nan, np.nan, 0.019866971, 0.03823637, 0.11303808, 0.0069147074], } model_urls['densenet121-res224-mimic_ch'] = model_urls['mimic_ch'] @@ -71,7 +72,7 @@ model_urls['resnet50-res512-all'] = { "description": 'This model was trained on the datasets pc-nih-rsna-siim-vin at a 512x512 resolution.', "weights_url": 'https://github.com/mlmed/torchxrayvision/releases/download/v1/pc-nih-rsna-siim-vin-resnet50-test512-e400-state.pt', - "labels":[ 'Atelectasis', 'Consolidation', 'Infiltration', 'Pneumothorax', 'Edema', 'Emphysema', 'Fibrosis', 'Effusion', 'Pneumonia', 'Pleural_Thickening', 'Cardiomegaly', 'Nodule', 'Mass', 'Hernia', 'Lung Lesion', 'Fracture', 'Lung Opacity', 'Enlarged Cardiomediastinum' ], + "labels": ['Atelectasis', 'Consolidation', 'Infiltration', 'Pneumothorax', 'Edema', 'Emphysema', 'Fibrosis', 'Effusion', 'Pneumonia', 'Pleural_Thickening', 'Cardiomegaly', 'Nodule', 'Mass', 'Hernia', 'Lung Lesion', 'Fracture', 'Lung Opacity', 'Enlarged Cardiomediastinum'], "op_threshs": [0.51570356, 0.50444704, 0.53787947, 0.50723547, 0.5025118, 0.5035252, 0.5038076, 0.51862943, 0.5078151, 0.50724894, 0.5056339, 0.510706, 0.5053923, 0.5020846, np.nan, 0.5080557, 0.5138526, np.nan], "ppv80_thres": [0.690908, 0.720028, 0.7303882, 0.7235838, 0.6787441, 0.7304924, 0.73105824, 0.6839408, 0.7241559, 0.7219969, 0.6346738, 0.72764945, 0.7285066, 0.5735704, np.nan, 0.69684714, 0.7135549, np.nan] } @@ -131,39 +132,38 @@ class DenseNet(nn.Module): num_classes (int) - number of classification classes """ - def __init__(self, - growth_rate=32, - block_config=(6, 12, 24, 16), - num_init_features=64, + def __init__(self, + growth_rate=32, + block_config=(6, 12, 24, 16), + num_init_features=64, bn_size=4, - drop_rate=0, - num_classes=len(datasets.default_pathologies), - in_channels=1, - weights=None, - op_threshs=None, + drop_rate=0, + num_classes=len(datasets.default_pathologies), + in_channels=1, + weights=None, + op_threshs=None, apply_sigmoid=False - ): + ): + + super(DenseNet, self).__init__() - super(DenseNet, self).__init__() - self.apply_sigmoid = apply_sigmoid self.weights = weights - + if self.weights is not None: if not self.weights in model_urls.keys(): possible_weights = [k for k in model_urls.keys() if k.startswith("densenet")] raise Exception("Weights value must be in {}".format(possible_weights)) - + # set to be what this model is trained to predict self.pathologies = model_urls[weights]["labels"] - + # if different from default number of classes - if num_classes != len(datasets.default_pathologies): + if num_classes != len(datasets.default_pathologies): 
raise ValueError("num_classes and weights cannot both be specified. The weights loaded will define the own number of output classes.") - + num_classes = len(self.pathologies) - - + # First convolution self.features = nn.Sequential(OrderedDict([ ('conv0', nn.Conv2d(in_channels, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)), @@ -199,13 +199,13 @@ def __init__(self, nn.init.constant_(m.bias, 0) elif isinstance(m, nn.Linear): nn.init.constant_(m.bias, 0) - + # needs to be register_buffer here so it will go to cuda/cpu easily self.register_buffer('op_threshs', op_threshs) - + if self.weights != None: self.weights_filename_local = get_weights(weights) - + try: savedmodel = torch.load(self.weights_filename_local, map_location='cpu') # patch to load old models https://github.com/pytorch/pytorch/issues/42242 @@ -217,12 +217,12 @@ def __init__(self, except Exception as e: print("Loading failure. Check weights file:", self.weights_filename_local) raise e - + self.eval() - + if "op_threshs" in model_urls[weights]: self.op_threshs = torch.tensor(model_urls[weights]["op_threshs"]) - + self.upsample = nn.Upsample(size=(224, 224), mode='bilinear', align_corners=False) def __repr__(self): @@ -230,44 +230,44 @@ def __repr__(self): return "XRV-DenseNet121-{}".format(self.weights) else: return "XRV-DenseNet" - + def features2(self, x): x = fix_resolution(x, 224, self) warn_normalization(x) - + features = self.features(x) out = F.relu(features, inplace=True) out = F.adaptive_avg_pool2d(out, (1, 1)).view(features.size(0), -1) return out - + def forward(self, x): x = fix_resolution(x, 224, self) - + features = self.features2(x) out = self.classifier(features) - + if hasattr(self, 'apply_sigmoid') and self.apply_sigmoid: out = torch.sigmoid(out) - - if hasattr(self,"op_threshs") and (self.op_threshs != None): + + if hasattr(self, "op_threshs") and (self.op_threshs != None): out = torch.sigmoid(out) out = op_norm(out, self.op_threshs) return out - + ########################## class ResNet(nn.Module): def __init__(self, weights: str = None, apply_sigmoid: bool = False): - super(ResNet, self).__init__() - + super(ResNet, self).__init__() + self.weights = weights self.apply_sigmoid = apply_sigmoid - + if not self.weights in model_urls.keys(): possible_weights = [k for k in model_urls.keys() if k.startswith("resnet")] raise Exception("Weights value must be in {}".format(possible_weights)) - + self.weights_filename_local = get_weights(weights) self.weights_dict = model_urls[weights] self.pathologies = model_urls[weights]["labels"] @@ -281,18 +281,18 @@ def __init__(self, weights: str = None, apply_sigmoid: bool = False): self.model = torchvision.models.resnet50(num_classes=len(self.weights_dict["labels"]), pretrained=False) # patch for single channel self.model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False) - + try: self.model.load_state_dict(torch.load(self.weights_filename_local)) except Exception as e: print("Loading failure. 
Check weights file:", self.weights_filename_local) raise e - + if "op_threshs" in model_urls[weights]: self.register_buffer('op_threshs', torch.tensor(model_urls[weights]["op_threshs"])) - + self.upsample = nn.Upsample(size=(512, 512), mode='bilinear', align_corners=False) - + self.eval() def __repr__(self): @@ -300,11 +300,11 @@ def __repr__(self): return "XRV-ResNet-{}".format(self.weights) else: return "XRV-ResNet" - + def features(self, x): x = fix_resolution(x, 512, self) warn_normalization(x) - + x = self.model.conv1(x) x = self.model.bn1(x) x = self.model.relu(x) @@ -318,30 +318,32 @@ def features(self, x): x = self.model.avgpool(x) x = torch.flatten(x, 1) return x - + def forward(self, x): x = fix_resolution(x, 512, self) warn_normalization(x) - + out = self.model(x) - + if hasattr(self, 'apply_sigmoid') and self.apply_sigmoid: out = torch.sigmoid(out) - - if hasattr(self,"op_threshs") and (self.op_threshs != None): + + if hasattr(self, "op_threshs") and (self.op_threshs != None): out = torch.sigmoid(out) out = op_norm(out, self.op_threshs) return out - + warning_log = {} + + def fix_resolution(x, resolution: int, model: nn.Module): """Check resolution of input and resize to match requested.""" - + # just skip it if upsample was removed somehow if not hasattr(model, 'upsample') or (model.upsample == None): return x - + if (x.shape[2] != resolution) | (x.shape[3] != resolution): if not hash(model) in warning_log: print("Warning: Input size ({}x{}) is not the native resolution ({}x{}) for this model. A resize will be performed but this could impact performance.".format(x.shape[2], x.shape[3], resolution, resolution)) @@ -349,13 +351,14 @@ def fix_resolution(x, resolution: int, model: nn.Module): return model.upsample(x) return x + def warn_normalization(x): """Check normalization of input and warn if possibly wrong. When processing an image that may likely not have the correct normalization we can issue a warning. But running min and max on every image/batch is costly so we only do it on the first image/batch. """ - + # Only run this check on the first image so we don't hurt performance. if not "norm_check" in warning_log: x_min = x.min() @@ -365,11 +368,10 @@ def warn_normalization(x): warning_log["norm_correct"] = False else: warning_log["norm_correct"] = True - + warning_log["norm_check"] = True - - - + + def op_norm(outputs, op_threshs): """Normalize outputs according to operating points for a given model. Args: @@ -379,20 +381,20 @@ def op_norm(outputs, op_threshs): outputs_new: normalized outputs, torch.Size(batch_size, num_tasks) """ # expand to batch size so we can do parallel comp - op_threshs = op_threshs.expand(outputs.shape[0],-1) - + op_threshs = op_threshs.expand(outputs.shape[0], -1) + # initial values will be 0.5 - outputs_new = torch.zeros(outputs.shape, device = outputs.device)+0.5 - + outputs_new = torch.zeros(outputs.shape, device=outputs.device) + 0.5 + # only select non-nan elements otherwise the gradient breaks - mask_leq = (outputs