From b38110bbf3fdeeb689b8814039cd2314e8e48ce0 Mon Sep 17 00:00:00 2001 From: Farnaz <44510760+farnaznouraei@users.noreply.github.com> Date: Wed, 7 Apr 2021 14:49:01 -0400 Subject: [PATCH] Some Changes to Pascal Loader The current version has deprecated functions such as scipy.misc.imsave, and the resizing part generates unwanted values in the label masks. This triggers cuda to stop the training process due to dataloader issues. Edits include: - changing misc functions to cv2.imwrite and cv2.imread - use of nearest neighbor resize method for preserving label integers - adding test set functionality (which could be similar to val set or contain different images) - using aug_with_sbd flag to make augmentation with SBD optional - added test set functionalities (a separate test.txt file should be made in the directory) - some other minor changes Most important reason for proposal: It works!!! --- ptsemseg/loader/pascal_voc_loader.py | 316 +++++++++++++++++---------- 1 file changed, 198 insertions(+), 118 deletions(-) diff --git a/ptsemseg/loader/pascal_voc_loader.py b/ptsemseg/loader/pascal_voc_loader.py index 2f610964..8afec524 100644 --- a/ptsemseg/loader/pascal_voc_loader.py +++ b/ptsemseg/loader/pascal_voc_loader.py @@ -1,21 +1,23 @@ +# DATASET + import os from os.path import join as pjoin import collections import json import torch import numpy as np -import scipy.misc as m import scipy.io as io import matplotlib.pyplot as plt import glob +import cv2 from PIL import Image from tqdm import tqdm -from torch.utils import data +from torch.utils import data as dt from torchvision import transforms +import warnings - -class pascalVOCLoader(data.Dataset): +class pascalVOCLoader(dt.Dataset): """Data loader for the Pascal VOC semantic segmentation dataset. Annotations from both the original VOC data (which consist of RGB images @@ -28,79 +30,90 @@ class pascalVOCLoader(data.Dataset): is added as a subdirectory of the `SegmentationClass` folder in the original Pascal VOC data layout. - A total of five data splits are provided for working with the VOC data: + A total of four data splits are provided for working with the VOC data: train: The original VOC 2012 training data - 1464 images val: The original VOC 2012 validation data - 1449 images - trainval: The combination of `train` and `val` - 2913 images train_aug: The unique images present in both the train split and training images from SBD: - 8829 images (the unique members of the result of combining lists of length 1464 and 8498) - train_aug_val: The original VOC 2012 validation data minus the images - present in `train_aug` (This is done with the same logic as - the validation set used in FCN PAMI paper, but with VOC 2012 - rather than VOC 2011) - 904 images + test: an arbitrary set of images from the "val" set, written in a text file test.txt in the same directory """ def __init__( self, - root, - sbd_path=None, - split="train_aug", - is_transform=False, - img_size=512, + root="/gpfs/home/fnouraei/data/fnouraei/VOC/VOCdevkit/VOC2012/", # choose your main path as default + sbd_path="/gpfs/home/fnouraei/data/fnouraei/VOC/benchmark_RELEASE", # choose your SBD dir as default + split="train", # loader split {"train", "val", "test", "train_aug"} + is_transform=True, + img_size=(128,128), # choose default input size augmentations=None, img_norm=True, - test_mode=False, + aug_with_sbd = False, # choose whether or not to add Berkeley images to val and train sets by default ): self.root = root self.sbd_path = sbd_path + self.aug_with_sbd = aug_with_sbd self.split = split self.is_transform = is_transform self.augmentations = augmentations self.img_norm = img_norm - self.test_mode = test_mode self.n_classes = 21 self.mean = np.array([104.00699, 116.66877, 122.67892]) self.files = collections.defaultdict(list) self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size) - if not self.test_mode: - for split in ["train", "val", "trainval"]: - path = pjoin(self.root, "ImageSets/Segmentation", split + ".txt") - file_list = tuple(open(path, "r")) - file_list = [id_.rstrip() for id_ in file_list] - self.files[split] = file_list - self.setup_annotations() - - self.tf = transforms.Compose( - [ - transforms.ToTensor(), - transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), - ] - ) + + for split in ["train", "val", "test"]: + path = pjoin(self.root, "ImageSets","Segmentation", split + ".txt") + file_list = tuple(open(path, "r")) + file_list = [id_.rstrip() for id_ in file_list] + self.files[split] = file_list + self.setup_annotations() + + if not self.img_norm: + self.tf = transforms.ToTensor() + else: + self.tf = transforms.Compose( + [ + + transforms.ToTensor() + ,transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ]) + + def __len__(self): return len(self.files[self.split]) def __getitem__(self, index): + im_name = self.files[self.split][index] im_path = pjoin(self.root, "JPEGImages", im_name + ".jpg") lbl_path = pjoin(self.root, "SegmentationClass/pre_encoded", im_name + ".png") - im = Image.open(im_path) - lbl = Image.open(lbl_path) + if self.aug_with_sbd: + lbl_path = pjoin(self.root, "SegmentationClass/pre_encoded_aug", im_name + ".png") + im = cv2.imread(im_path, cv2.IMREAD_UNCHANGED) + im = cv2.cvtColor(np.array(im), cv2.COLOR_BGR2RGB) + lbl = cv2.imread(lbl_path, cv2.IMREAD_UNCHANGED) + + if self.augmentations is not None: im, lbl = self.augmentations(im, lbl) if self.is_transform: im, lbl = self.transform(im, lbl) + return im, lbl def transform(self, img, lbl): if self.img_size == ("same", "same"): pass else: - img = img.resize((self.img_size[0], self.img_size[1])) # uint8 with RGB mode - lbl = lbl.resize((self.img_size[0], self.img_size[1])) + img = cv2.resize(img, (self.img_size[0], self.img_size[1]), interpolation =cv2.INTER_NEAREST) + lbl = cv2.resize(lbl, (self.img_size[0], self.img_size[1]), interpolation =cv2.INTER_NEAREST) + + img = self.tf(img) + lbl = torch.from_numpy(np.array(lbl)).long() lbl[lbl == 255] = 0 return img, lbl @@ -111,32 +124,29 @@ def get_pascal_labels(self): Returns: np.ndarray with dimensions (21, 3) """ - return np.asarray( - [ - [0, 0, 0], - [128, 0, 0], - [0, 128, 0], - [128, 128, 0], - [0, 0, 128], - [128, 0, 128], - [0, 128, 128], - [128, 128, 128], - [64, 0, 0], - [192, 0, 0], - [64, 128, 0], - [192, 128, 0], - [64, 0, 128], - [192, 0, 128], - [64, 128, 128], - [192, 128, 128], - [0, 64, 0], - [128, 64, 0], - [0, 192, 0], - [128, 192, 0], - [0, 64, 128], - ] - ) - + return np.asarray( [ + [0, 0, 0], #0: bg + [128, 0, 0], #1: aeroplane + [0, 128, 0], #2: bicycle + [128, 128, 0], #3: bird + [0, 0, 128], #4: boat + [128, 0, 128], #5: bottle + [0, 128, 128], #6: bus + [128, 128, 128], #7: car + [64, 0, 0], #8: cat + [192, 0, 0], #9: chair + [64, 128, 0], #10: cow + [192, 128, 0], #11: dining table + [64, 0, 128], #12: dog + [192, 0, 128], #13: horse + [64, 128, 128], #14: motorbike + [192, 128, 128], #15: person + [0, 64, 0], #16: potted plant + [128, 64, 0], #17: sheep + [0, 192, 0], #18: sofa + [128, 192, 0], #19: train + [0, 64, 128], #20: tv/monitor + ] ) def encode_segmap(self, mask): """Encode segmentation label images as pascal classes @@ -153,6 +163,7 @@ def encode_segmap(self, mask): for ii, label in enumerate(self.get_pascal_labels()): label_mask[np.where(np.all(mask == label, axis=-1))[:2]] = ii label_mask = label_mask.astype(int) + return label_mask def decode_segmap(self, label_mask, plot=False): @@ -186,68 +197,137 @@ def decode_segmap(self, label_mask, plot=False): return rgb def setup_annotations(self): - """Sets up Berkley annotations by adding image indices to the - `train_aug` split and pre-encode all segmentation labels into the - common label_mask format (if this has not already been done). This - function also defines the `train_aug` and `train_aug_val` data splits - according to the description in the class docstring + """ + pre-encode all segmentation labels into the + common label_mask format (if this has not already been done). """ sbd_path = self.sbd_path - target_path = pjoin(self.root, "SegmentationClass/pre_encoded") + if not self.aug_with_sbd: + target_path = pjoin(self.root, "SegmentationClass/pre_encoded") + if self.aug_with_sbd: + target_path = pjoin(self.root, "SegmentationClass/pre_encoded_aug") + if not os.path.exists(target_path): os.makedirs(target_path) - path = pjoin(sbd_path, "dataset/train.txt") + + path = pjoin(sbd_path, "dataset", "train.txt") sbd_train_list = tuple(open(path, "r")) sbd_train_list = [id_.rstrip() for id_ in sbd_train_list] - train_aug = self.files["train"] + sbd_train_list + + if not self.aug_with_sbd: + train = self.files["train"] + + if self.aug_with_sbd: + train_aug = self.files["train"] + sbd_train_list + + val = self.files["val"] + test = self.files["test"] + + if self.aug_with_sbd: + # keep unique elements (stable) + train_aug = [train_aug[i] for i in sorted(np.unique(train_aug, return_index=True)[1])] + self.files["train_aug"] = train_aug + + pre_encoded = glob.glob(pjoin(target_path, "*.png")) + + if not self.aug_with_sbd: + expected = len(list(set().union(val,test,train))) + # make sure validation set is held out + #print("[Debug] val and train intersection: ", len(list(set().intersection(val,train)))) + else: + expected = len(list(set().union(val,test,train_aug))) + # make sure validation set is held out + #print("[Debug] val and train intersection: ", len(list(set().intersection(val,train_aug)))) + + + + if len(pre_encoded) < expected: + + if self.aug_with_sbd: + for ii in tqdm(sbd_train_list, desc="pre-encode SBD"): + fname = ii + ".png" + lbl_path = pjoin(sbd_path, "dataset/cls", ii + ".mat") + data = io.loadmat(lbl_path) + lbl = data["GTcls"][0]["Segmentation"][0].astype(np.int32) + cv2.imwrite(pjoin(target_path, fname), lbl) + + for split in ["train", "val", "test"]: + + for ii in tqdm(self.files[split], desc="pre-encode "+split): + fname = ii + ".png" + lbl_path = pjoin(self.root, "SegmentationClass", fname) + lbl = cv2.imread(lbl_path, cv2.IMREAD_UNCHANGED) + lbl = cv2.cvtColor(np.array(lbl), cv2.COLOR_BGR2RGB) + lbl = self.encode_segmap(lbl) + #print("[Debug] pre-encoded label before saving as png: ", np.unique(lbl)) + cv2.imwrite(pjoin(target_path, fname), lbl) + + pre_encoded = glob.glob(pjoin(target_path, "*.png")) + print("[info] num expected labels: {} num pre-encoded labels: {}".format(expected,len(pre_encoded))) - # keep unique elements (stable) - train_aug = [train_aug[i] for i in sorted(np.unique(train_aug, return_index=True)[1])] - self.files["train_aug"] = train_aug - set_diff = set(self.files["val"]) - set(train_aug) # remove overlap - self.files["train_aug_val"] = list(set_diff) + +""" +# Test Pascal Dataloader (JUPYTER NOTEBOOK) - pre_encoded = glob.glob(pjoin(target_path, "*.png")) - expected = np.unique(self.files["train_aug"] + self.files["val"]).size - - if len(pre_encoded) != expected: - print("Pre-encoding segmentation masks...") - for ii in tqdm(sbd_train_list): - lbl_path = pjoin(sbd_path, "dataset/cls", ii + ".mat") - data = io.loadmat(lbl_path) - lbl = data["GTcls"][0]["Segmentation"][0].astype(np.int32) - lbl = m.toimage(lbl, high=lbl.max(), low=lbl.min()) - m.imsave(pjoin(target_path, ii + ".png"), lbl) - - for ii in tqdm(self.files["trainval"]): - fname = ii + ".png" - lbl_path = pjoin(self.root, "SegmentationClass", fname) - lbl = self.encode_segmap(m.imread(lbl_path)) - lbl = m.toimage(lbl, high=lbl.max(), low=lbl.min()) - m.imsave(pjoin(target_path, fname), lbl) - - assert expected == 9733, "unexpected dataset sizes" - - -# Leave code for debugging purposes -# import ptsemseg.augmentations as aug -# if __name__ == '__main__': -# # local_path = '/home/meetshah1995/datasets/VOCdevkit/VOC2012/' -# bs = 4 -# augs = aug.Compose([aug.RandomRotate(10), aug.RandomHorizontallyFlip()]) -# dst = pascalVOCLoader(root=local_path, is_transform=True, augmentations=augs) -# trainloader = data.DataLoader(dst, batch_size=bs) -# for i, data in enumerate(trainloader): -# imgs, labels = data -# imgs = imgs.numpy()[:, ::-1, :, :] -# imgs = np.transpose(imgs, [0,2,3,1]) -# f, axarr = plt.subplots(bs, 2) -# for j in range(bs): -# axarr[j][0].imshow(imgs[j]) -# axarr[j][1].imshow(dst.decode_segmap(labels.numpy()[j])) -# plt.show() -# a = raw_input() -# if a == 'ex': -# break -# else: -# plt.close() +%matplotlib inline +import matplotlib.pyplot as plt +from ptsemseg.augmentations.augmentations import ( + AdjustContrast, + AdjustGamma, + AdjustBrightness, + AdjustSaturation, + AdjustHue, + RandomCrop, + RandomHorizontallyFlip, + RandomVerticallyFlip, + Scale, + RandomSized, + RandomSizedCrop, + RandomRotate, + RandomTranslate, + CenterCrop, + Compose, +) + +if __name__ == '__main__': + + local_path = '/gpfs/home/fnouraei/data/fnouraei/VOC/VOCdevkit/VOC2012/' + bs = BATCH_SIZE + if OVERFIT: + dset = pascalVOCLoader_reduced(root=local_path, split = 'train',img_size=(INPUT_SIZE,INPUT_SIZE) + , is_transform=True , aug_with_sbd=False + , augmentations= None) + else: + dset = pascalVOCLoader(root=local_path, split = 'train',img_size=(INPUT_SIZE,INPUT_SIZE) + , is_transform=True , aug_with_sbd=True + , augmentations= Compose([RandomRotate(10), RandomHorizontallyFlip(p=0.5) + ,RandomVerticallyFlip(p=0.5),RandomSizedCrop(125) + ,AdjustGamma(gamma = 2.2),AdjustContrast(cf=0.4)])) + + + print("[Debug] length of dataset: ",len(dset)) + + trainloader = dt.DataLoader(dset, batch_size=bs, shuffle=True) + + for i, data in enumerate(trainloader): # i = batch idx , data = (img , label) (as tensors - first dim is batch size) + img, label = data + img = img.numpy() + img = np.transpose(img, [0,2,3,1]) + print("[Debug] label values:",label.unique()) + + f, axarr = plt.subplots(bs, 2, figsize=(15, 15), dpi=80) + + for j in range(bs): + print("[Debug] batch item {} img shape: {} label shape: {}".format(bs,img[j].shape,label.numpy()[j].shape)) + axarr[j,0].imshow(img[j]) + axarr[j,1].imshow(dset.decode_segmap((label[j].numpy()))) + plt.show() + + plt.close() +""" + + + + + +