This repository was archived by the owner on Apr 11, 2023. It is now read-only.

PyTorch Segmentation Dataset Loader

Custom segmentation dataset class for torchvision that applies data augmentation to both images and their segmentation masks.

Usage

The class can be used with torchvision.transforms:

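from torchvision import transforms

from utils import SegmentationDataset

transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        # `interpolation` and `fill` replace the deprecated `resample=2`
        # (2 = bilinear) and `fillcolor=0`, which newer torchvision releases
        # no longer accept.
        transforms.RandomAffine(
            degrees=15,
            translate=(0.05, 0.05),
            scale=(0.95, 1.05),
            interpolation=transforms.InterpolationMode.BILINEAR,
            fill=0,
        ),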
        transforms.ColorJitter(
            brightness=0.15,
            contrast=0.15,
            saturation=0.15,
            hue=0.05
        ),
        transforms.RandomVerticalFlip(),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)

dataset = SegmentationDataset(
    dir_images="./my_dataset/images/",
    dir_masks="./my_dataset/masks/",
    transform=transform,
)
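The dataset can then be batched with a standard `torch.utils.data.DataLoader`. The sketch below swaps in a hypothetical `ToySegmentationDataset` that generates random tensors, so it runs without `utils` or any image files; it assumes the real class returns `(image, mask)` tensor pairs, which the README does not confirm.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToySegmentationDataset(Dataset):
    """Hypothetical stand-in for utils.SegmentationDataset.

    Assumes each item is an (image, mask) pair of tensors, as ToTensor()
    at the end of the pipeline above would produce.
    """

    def __init__(self, length=8, size=224):
        self.length = length
        self.size = size

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        image = torch.rand(3, self.size, self.size)  # RGB image in [0, 1]
        mask = torch.randint(0, 2, (1, self.size, self.size)).float()  # binary mask
        return image, mask


dataset = ToySegmentationDataset()
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

for images, masks in loader:
    # Default collation stacks the per-item pairs into batched tensors.
    print(images.shape, masks.shape)
    # torch.Size([4, 3, 224, 224]) torch.Size([4, 1, 224, 224])
```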

Note

Helpful Links