Custom segmentation dataset class for torchvision.
Applies data augmentation to both images and their segmentation masks. It can be used with torchvision.transforms as follows:

    from torchvision import transforms

    from utils import SegmentationDataset

    transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            # Small random rotation, shift and zoom. `interpolation` and `fill`
            # replace the deprecated `resample=2` (bilinear) and `fillcolor=0`.
            transforms.RandomAffine(
                degrees=15,
                translate=(0.05, 0.05),
                scale=(0.95, 1.05),
                interpolation=transforms.InterpolationMode.BILINEAR,
                fill=0,
            ),
            # Photometric jitter; applied to images only.
            transforms.ColorJitter(
                brightness=0.15,
                contrast=0.15,
                saturation=0.15,
                hue=0.05,
            ),
            transforms.RandomVerticalFlip(),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]
    )

    dataset = SegmentationDataset(
        dir_images="./my_dataset/images/",
        dir_masks="./my_dataset/masks/",
        transform=transform,
    )

Normalize, Lambda, Pad, ColorJitter and RandomErasing won't be applied to masks by default.

Images from: https://www.ntu.edu.sg/home/asjfcai/Benchmark_Website/benchmark_index.html