forked from JunwookHeo/YOLO-OT
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathListContainer.py
146 lines (117 loc) · 4.8 KB
/
ListContainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import os
from VideoDataset import *
from YotDataset import *
from YotGtDataset import *
from RoloDataset import *
class VideoLoader:
    """Factory that builds one ``VideoDataset`` per video source."""

    @staticmethod
    def getDataset(path, label, seq_num, img_size, mode):
        """Construct a ``VideoDataset`` from the given arguments (positional)."""
        dataset_args = (path, label, seq_num, img_size, mode)
        return VideoDataset(*dataset_args)
class YotLoader:
    """Factory that builds one ``YotDataset`` per sequence folder."""

    @staticmethod
    def getDataset(path, label, seq_num, img_size, mode):
        """Construct a ``YotDataset`` from the given arguments (positional)."""
        dataset_args = (path, label, seq_num, img_size, mode)
        return YotDataset(*dataset_args)
class YotGtLoader:
    """Factory that builds one ``YotGtDataset`` per sequence folder."""

    @staticmethod
    def getDataset(path, label, seq_num, img_size, mode):
        """Construct a ``YotGtDataset`` from the given arguments (positional)."""
        dataset_args = (path, label, seq_num, img_size, mode)
        return YotGtDataset(*dataset_args)
class RoloLoader:
    """Factory that builds one ``RoloDataset`` per sequence folder."""

    @staticmethod
    def getDataset(path, label, seq_num, img_size, mode):
        """Construct a ``RoloDataset`` from the given arguments (positional)."""
        dataset_args = (path, label, seq_num, img_size, mode)
        return RoloDataset(*dataset_args)
class ListContainer:
    """Iterate over the datasets found under a root folder.

    Each iteration step builds the dataset for the next listed entry, using
    the loader selected by ``datatype``. It then wraps that dataset in a
    ``torch.utils.data.DataLoader``.

    Expected layouts per ``datatype``:

    * ``'video'``: ``path`` holds exactly two sub-folders, ``labels`` and
      ``videos`` (case-insensitive). Their files are paired by sorted
      position.
    * ``'yot'`` / ``'yotgt'`` / ``'rolo'``: ``path`` takes one of two forms.
      It may itself be a sequence folder, i.e. it contains ``yot_out`` /
      ``images`` / ``yolo_out`` respectively. Otherwise it is a folder of
      such sequence folders. Each sequence's label path is
      ``<sequence>/groundtruth_rect.txt``.
    """

    def __init__(self, datatype, path, batch_size, seq_num, img_size, mode):
        """Scan ``path`` and select the loader matching ``datatype``.

        Raises:
            FileNotFoundError: ``path`` is not a directory that can be listed.
            ValueError: ``datatype`` is unknown, or a ``'video'`` root does
                not contain exactly ``labels`` and ``videos`` sub-folders.
        """
        self.pos = 0
        self.path = path
        self.batch_size = batch_size
        self.seq_num = seq_num
        self.img_size = img_size
        self.mode = mode

        subdirs, _ = self._listing(path)
        paths = sorted(os.path.join(path, fn) for fn in subdirs)

        if datatype == 'video':
            # Look the folders up by name rather than by sorted position, so
            # a layout such as 'Videos' + 'labels' (uppercase sorts first)
            # is recognised too.
            by_name = {os.path.basename(p).lower(): p for p in paths}
            if len(paths) != 2 or set(by_name) != {'labels', 'videos'}:
                raise ValueError(
                    "video root must contain exactly 'labels' and 'videos' "
                    f"folders, found {sorted(subdirs)} in {path!r}")
            self.load_videos([by_name['labels'], by_name['videos']])
        elif datatype == 'rolo':
            self.load_rolo(paths)
        elif datatype == 'yot':
            self.load_yot(paths)
        elif datatype == 'yotgt':
            self.load_yotgt(paths)
        else:
            raise ValueError(f"unknown datatype: {datatype!r}")

    @staticmethod
    def _listing(path):
        """Return ``(dirnames, filenames)`` directly inside ``path``.

        Raises:
            FileNotFoundError: ``path`` cannot be listed. ``os.walk`` yields
                nothing in that case instead of raising.
        """
        for _, dirnames, filenames in os.walk(path):
            return dirnames, filenames
        raise FileNotFoundError(f"not a listable directory: {path!r}")

    def __iter__(self):
        # The container is its own single-pass iterator. Once exhausted it
        # yields nothing more unless ``pos`` is reset to 0.
        return self

    def __next__(self):
        """Return a DataLoader for the next listed entry.

        The label passed to the dataset is ``None`` when ``labels`` has no
        entry at this position.

        Raises:
            StopIteration: every entry of ``self.lists`` has been returned.
        """
        if self.pos >= len(self.lists):
            raise StopIteration
        pos = self.pos
        self.pos += 1

        label = self.labels[pos] if pos < len(self.labels) else None
        dataset = self.loader.getDataset(
            self.lists[pos], label, self.seq_num, self.img_size, self.mode)
        # NOTE(review): this module never imports torch directly. It
        # presumably arrives through one of the ``from ... import *`` lines;
        # confirm.
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=0,
            pin_memory=True,
            drop_last=True,
        )

    def load_videos(self, paths):
        """List video files and their label files.

        Args:
            paths: ``[labels_dir, videos_dir]``.
        """
        labels_dir, videos_dir = paths[0], paths[1]
        # os.walk reports files in file-system order. Sort both sides so
        # labels[i] and lists[i] pair up deterministically. This assumes
        # label and video files share a naming scheme; TODO confirm.
        self.labels = sorted(
            os.path.join(labels_dir, fn) for fn in self._listing(labels_dir)[1])
        self.lists = sorted(
            os.path.join(videos_dir, fn) for fn in self._listing(videos_dir)[1])
        self.loader = VideoLoader

    def _load_sequences(self, paths, marker):
        """Fill ``lists``/``labels`` with sequence folders containing ``marker``.

        If the root folder itself contains ``marker``, the root is the only
        sequence. Otherwise every folder in ``paths`` that contains
        ``marker`` becomes one entry, in the order given. Each entry's label
        is ``<sequence>/groundtruth_rect.txt``.
        """
        # join(..., '') appends one separator, which dirname() strips again.
        # The root is thus spelled exactly as dirname() of a child path is.
        root = os.path.dirname(os.path.join(self.path, ''))
        if os.path.exists(os.path.join(root, marker)):
            sequences = [root]
        else:
            sequences = [p for p in paths
                         if os.path.exists(os.path.join(p, marker))]
        self.lists = sequences
        self.labels = [os.path.join(p, "groundtruth_rect.txt")
                       for p in sequences]

    def load_yot(self, paths):
        """List sequences that contain a ``yot_out`` entry."""
        self._load_sequences(paths, 'yot_out')
        self.loader = YotLoader

    def load_yotgt(self, paths):
        """List sequences that contain an ``images`` entry."""
        self._load_sequences(paths, 'images')
        self.loader = YotGtLoader

    def load_rolo(self, paths):
        """List sequences that contain a ``yolo_out`` entry."""
        self._load_sequences(paths, 'yolo_out')
        self.loader = RoloLoader

    def get_list_info(self, pos):
        """Return the name of entry ``pos``.

        This is the sequence folder name, or the video file name for
        ``'video'`` data.
        """
        return os.path.basename(self.lists[pos])