Skip to content

Commit

Permalink
[Fix]: fix label type bug when using dbsampler (#111)
Browse files Browse the repository at this point in the history
* [Fix]: fix label type bug when using dbsampler

* Unify cat_id for more general usage

* fix CI bugs

* keep astype np.long
  • Loading branch information
ZwwWayne authored Sep 19, 2020
1 parent ee80116 commit 62ce67c
Show file tree
Hide file tree
Showing 8 changed files with 28 additions and 23 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ jobs:
coverage report -m
# Only upload coverage report for python3.7 && pytorch1.5
- name: Upload coverage to Codecov
if: ${{matrix.torch == '1.5.0' && matrix.python-version == '3.7'}}
if: ${{matrix.torch == '1.5.0+cu101' && matrix.python-version == '3.7'}}
uses: codecov/codecov-action@v1.0.10
with:
file: ./coverage.xml
Expand Down
5 changes: 3 additions & 2 deletions mmdet3d/datasets/custom_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(self,
self.box_type_3d, self.box_mode_3d = get_box_type(box_type_3d)

self.CLASSES = self.get_classes(classes)
self.cat2id = {name: i for i, name in enumerate(self.CLASSES)}
self.data_infos = self.load_annotations(self.ann_file)

if pipeline is not None:
Expand Down Expand Up @@ -300,7 +301,7 @@ def _set_group_flag(self):
"""Set flag according to image aspect ratio.
Images with aspect ratio greater than 1 will be set as group 1,
otherwise group 0. In 3D datasets, they are all the same, thus
are all zeros.
otherwise group 0. In 3D datasets, they are all the same, thus are all
zeros.
"""
self.flag = np.zeros(len(self), dtype=np.uint8)
16 changes: 9 additions & 7 deletions mmdet3d/datasets/dataset_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class CBGSDataset(object):
def __init__(self, dataset):
self.dataset = dataset
self.CLASSES = dataset.CLASSES
self.cat2id = {name: i for i, name in enumerate(self.CLASSES)}
self.sample_indices = self._get_sample_indices()
# self.dataset.data_infos = self.data_infos
if hasattr(self.dataset, 'flag'):
Expand All @@ -34,22 +35,23 @@ def _get_sample_indices(self):
Returns:
list[dict]: List of annotations after class sampling.
"""
class_sample_idxs = {name: [] for name in self.CLASSES}
class_sample_idxs = {cat_id: [] for cat_id in self.cat2id.values()}
for idx in range(len(self.dataset)):
class_sample_idx = self.dataset.get_cat_ids(idx)
for key in class_sample_idxs.keys():
class_sample_idxs[key] += class_sample_idx[key]
duplicated_samples = sum([len(v) for _, v in class_sample_idx.items()])
sample_cat_ids = self.dataset.get_cat_ids(idx)
for cat_id in sample_cat_ids:
class_sample_idxs[cat_id].append(idx)
duplicated_samples = sum(
[len(v) for _, v in class_sample_idxs.items()])
class_distribution = {
k: len(v) / duplicated_samples
for k, v in class_sample_idx.items()
for k, v in class_sample_idxs.items()
}

sample_indices = []

frac = 1.0 / len(self.CLASSES)
ratios = [frac / v for v in class_distribution.values()]
for cls_inds, ratio in zip(list(class_sample_idx.values()), ratios):
for cls_inds, ratio in zip(list(class_sample_idxs.values()), ratios):
sample_indices += np.random.choice(cls_inds,
int(len(cls_inds) *
ratio)).tolist()
Expand Down
7 changes: 4 additions & 3 deletions mmdet3d/datasets/nuscenes_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,17 +150,18 @@ def get_cat_ids(self, idx):
contains such boxes, store a list containing idx,
otherwise, store empty list.
"""
class_sample_idx = {name: [] for name in self.CLASSES}
info = self.data_infos[idx]
if self.use_valid_flag:
mask = info['valid_flag']
gt_names = set(info['gt_names'][mask])
else:
gt_names = set(info['gt_names'])

cat_ids = []
for name in gt_names:
if name in self.CLASSES:
class_sample_idx[name].append(idx)
return class_sample_idx
cat_ids.append(self.cat2id[name])
return cat_ids

def load_annotations(self, ann_file):
"""Load annotations from ann_file.
Expand Down
6 changes: 3 additions & 3 deletions mmdet3d/datasets/pipelines/dbsampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,9 +259,9 @@ def sample_all(self, gt_bboxes, gt_labels, img=None):
count += 1

s_points_list.append(s_points)
# gt_names = np.array([s['name'] for s in sampled]),
# gt_labels = np.array([self.cat2label(s) for s in gt_names])
gt_labels = np.array([self.cat2label[s['name']] for s in sampled])

gt_labels = np.array([self.cat2label[s['name']] for s in sampled],
dtype=np.long)
ret = {
'gt_labels_3d':
gt_labels,
Expand Down
2 changes: 1 addition & 1 deletion mmdet3d/datasets/pipelines/transforms_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def __call__(self, input_dict):
input_dict['img'] = sampled_dict['img']

input_dict['gt_bboxes_3d'] = gt_bboxes_3d
input_dict['gt_labels_3d'] = gt_labels_3d
input_dict['gt_labels_3d'] = gt_labels_3d.astype(np.long)
input_dict['points'] = points

return input_dict
Expand Down
12 changes: 6 additions & 6 deletions tests/test_dataset/test_dataset_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,13 @@ def test_getitem():
# and box_type_3d='Depth' in sunrgbd and scannet dataset.
box_type_3d='LiDAR'))
nus_dataset = build_dataset(dataset_cfg)
assert len(nus_dataset) == 10
assert len(nus_dataset) == 20
data = nus_dataset[0]
assert data['img_metas'].data['flip'] is True
assert data['img_metas'].data['pcd_horizontal_flip'] is True
assert data['points']._data.shape == (537, 5)

data = nus_dataset[1]
assert data['img_metas'].data['flip'] is False
assert data['img_metas'].data['pcd_horizontal_flip'] is False
assert data['points']._data.shape == (901, 5)

data = nus_dataset[1]
assert data['img_metas'].data['flip'] is True
assert data['img_metas'].data['pcd_horizontal_flip'] is True
assert data['points']._data.shape == (537, 5)
1 change: 1 addition & 0 deletions tests/test_pipeline/test_transforms_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def test_object_sample():
gt_labels.append(CLASSES.index(cat))
else:
gt_labels.append(-1)
gt_labels = np.array(gt_labels, dtype=np.long)
input_dict = dict(
points=points, gt_bboxes_3d=gt_bboxes_3d, gt_labels_3d=gt_labels)
input_dict = object_sample(input_dict)
Expand Down

0 comments on commit 62ce67c

Please sign in to comment.