-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataset.py
38 lines (26 loc) · 1.06 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from PIL import Image
import os
from torch.utils.data import Dataset
class myDataset(Dataset):
def __init__(self, root_A, root_B, transform=None):
self.root_A = root_A
self.root_B = root_B
self.transform = transform
self.A_images = os.listdir(root_A)
self.B_images = os.listdir(root_B)
self.length_dataset = max(len(self.A_images), len(self.B_images))
self.A_len = len(self.A_images)
self.B_len = len(self.B_images)
def __len__(self):
return self.length_dataset
def __getitem__(self, index):
A_img = self.A_images[index % self.A_len]
B_img = self.B_images[index % self.B_len]
A_path = os.path.join(self.root_A, A_img)
B_path = os.path.join(self.root_B, B_img)
A_img = Image.open(A_path).convert("RGB")
B_img = Image.open(B_path).convert("RGB")
if self.transform:
A_img = self.transform(A_img)
B_img = self.transform(B_img)
return A_img, B_img