From 579b0799da2357ce075b02b7271c9fdf27a5abf6 Mon Sep 17 00:00:00 2001 From: Wenwei Zhang <40779233+ZwwWayne@users.noreply.github.com> Date: Sun, 20 Sep 2020 21:04:09 +0800 Subject: [PATCH] Bump to V0.6.0 (#118) * Add gitlab CI back * clean isort * Update gitlab CI version * Update mmcv install * fix unit test bug * waymo * Use new flake8 * Update mmdet3d/core/evaluation/waymo_utils/prediction_kitti_to_waymo.py, tools/data_converter/waymo_converter.py files * Add baseline configs for waymo * fix linting * yapf reformat * update waymo results * Update waymo model zoo and docs * Bump v0.6.0 * Fix a minor bug when converting waymo data * Fix cmds in the waymo doc * Fix setup.cfg to pass isort test * Fix waymo configs * Update model zoo link & doc * update version date * clean ci Co-authored-by: wangtai Co-authored-by: Tai-Wang --- .dev_scripts/gather_models.py | 5 +- README.md | 5 +- configs/_base_/datasets/waymoD5-3d-3class.py | 127 +++++ configs/_base_/datasets/waymoD5-3d-car.py | 125 +++++ .../models/hv_pointpillars_secfpn_waymo.py | 109 ++++ configs/centerpoint/README.md | 2 +- configs/h3dnet/README.md | 2 +- configs/pointpillars/README.md | 18 + ...rs_secfpn_sbn_2x16_2x_waymoD5-3d-3class.py | 6 + ...llars_secfpn_sbn_2x16_2x_waymoD5-3d-car.py | 35 ++ data/scannet/scannet_utils.py | 3 +- docs/changelog.md | 50 ++ docs/getting_started.md | 63 ++- docs/install.md | 17 +- docs/model_zoo.md | 5 +- docs/waymo.md | 153 +++++ mmdet3d/core/evaluation/kitti_utils/eval.py | 43 +- .../waymo_utils/prediction_kitti_to_waymo.py | 255 +++++++++ mmdet3d/datasets/__init__.py | 3 +- mmdet3d/datasets/kitti_dataset.py | 21 +- mmdet3d/datasets/pipelines/dbsampler.py | 1 - mmdet3d/datasets/waymo_dataset.py | 525 ++++++++++++++++++ mmdet3d/models/detectors/base.py | 4 +- mmdet3d/models/detectors/two_stage.py | 4 +- mmdet3d/models/roi_heads/base_3droi_head.py | 4 +- .../furthest_point_sample.py | 4 +- mmdet3d/ops/spconv/structure.py | 5 +- requirements/runtime.txt | 2 + setup.cfg | 2 +- tests/test_detectors.py | 8 +- tests/test_forward.py | 4 +- tests/test_heads.py | 24 +- tools/create_data.py | 51 ++ tools/data_converter/create_gt_database.py | 25 + tools/data_converter/kitti_converter.py | 139 ++++- tools/data_converter/kitti_data_utils.py | 229 +++++++- tools/data_converter/waymo_converter.py | 510 +++++++++++++++++ tools/fuse_conv_bn.py | 1 - 38 files changed, 2470 insertions(+), 119 deletions(-) create mode 100644 configs/_base_/datasets/waymoD5-3d-3class.py create mode 100644 configs/_base_/datasets/waymoD5-3d-car.py create mode 100644 configs/_base_/models/hv_pointpillars_secfpn_waymo.py create mode 100644 configs/pointpillars/hv_pointpillars_secfpn_sbn_2x16_2x_waymoD5-3d-3class.py create mode 100644 configs/pointpillars/hv_pointpillars_secfpn_sbn_2x16_2x_waymoD5-3d-car.py create mode 100644 docs/waymo.md create mode 100644 mmdet3d/core/evaluation/waymo_utils/prediction_kitti_to_waymo.py create mode 100644 mmdet3d/datasets/waymo_dataset.py create mode 100644 tools/data_converter/waymo_converter.py diff --git a/.dev_scripts/gather_models.py b/.dev_scripts/gather_models.py index 2141b77afa..44aff49b23 100644 --- a/.dev_scripts/gather_models.py +++ b/.dev_scripts/gather_models.py @@ -1,12 +1,11 @@ import argparse import glob import json -import os.path as osp +import mmcv import shutil import subprocess - -import mmcv import torch +from os import path as osp # build schedule look-up table to automatically find the final model SCHEDULES_LUT = { diff --git a/README.md b/README.md index 0ace7ff922..375071d579 
100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ [![license](https://img.shields.io/github/license/open-mmlab/mmdetection3d.svg)](https://github.com/open-mmlab/mmdetection3d/blob/master/LICENSE) -**News**: We released the codebase v0.1.0. +**News**: We released the codebase v0.6.0. Documentation: https://mmdetection3d.readthedocs.io/ @@ -56,7 +56,7 @@ This project is released under the [Apache 2.0 license](LICENSE). ## Changelog -v0.1.0 was released in 9/7/2020. +v0.6.0 was released in 20/9/2020. Please refer to [changelog.md](docs/changelog.md) for details and release history. ## Benchmark and model zoo @@ -74,6 +74,7 @@ Results and models are available in the [model zoo](docs/model_zoo.md). | 3DSSD | ✗ | ✗ | ✗ | ✓ | ✗ | ✗ | ✗ | | Part-A2 | ☐ | ☐ | ☐ | ✗ | ☐ | ✓ | ☐ | | MVXNet | ☐ | ☐ | ☐ | ✗ | ☐ | ✓ | ☐ | +| CenterPoint | ☐ | ☐ | ☐ | ✗ | ☐ | ✓ | ☐ | Other features - [x] [Dynamic Voxelization](configs/carafe/README.md) diff --git a/configs/_base_/datasets/waymoD5-3d-3class.py b/configs/_base_/datasets/waymoD5-3d-3class.py new file mode 100644 index 0000000000..4595bdebf9 --- /dev/null +++ b/configs/_base_/datasets/waymoD5-3d-3class.py @@ -0,0 +1,127 @@ +# dataset settings +# D5 in the config name means the whole dataset is divided into 5 folds +# We only use one fold for efficient experiments +dataset_type = 'WaymoDataset' +data_root = 'data/waymo/kitti_format/' +file_client_args = dict(backend='disk') +# Uncomment the following if use ceph or other file clients. +# See https://mmcv.readthedocs.io/en/latest/api.html#mmcv.fileio.FileClient +# for more details. +# file_client_args = dict( +# backend='petrel', path_mapping=dict(data='s3://waymo_data/')) + +class_names = ['Car', 'Pedestrian', 'Cyclist'] +point_cloud_range = [-74.88, -74.88, -2, 74.88, 74.88, 4] +input_modality = dict(use_lidar=True, use_camera=False) +db_sampler = dict( + data_root=data_root, + info_path=data_root + 'waymo_dbinfos_train.pkl', + rate=1.0, + prepare=dict( + filter_by_difficulty=[-1], + filter_by_min_points=dict(Car=5, Pedestrian=10, Cyclist=10)), + classes=class_names, + sample_groups=dict(Car=15, Pedestrian=10, Cyclist=10), + points_loader=dict( + type='LoadPointsFromFile', + load_dim=5, + use_dim=[0, 1, 2, 3, 4], + file_client_args=file_client_args)) + +train_pipeline = [ + dict( + type='LoadPointsFromFile', + load_dim=6, + use_dim=5, + file_client_args=file_client_args), + dict( + type='LoadAnnotations3D', + with_bbox_3d=True, + with_label_3d=True, + file_client_args=file_client_args), + dict(type='ObjectSample', db_sampler=db_sampler), + dict( + type='RandomFlip3D', + sync_2d=False, + flip_ratio_bev_horizontal=0.5, + flip_ratio_bev_vertical=0.5), + dict( + type='GlobalRotScaleTrans', + rot_range=[-0.78539816, 0.78539816], + scale_ratio_range=[0.95, 1.05]), + dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range), + dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range), + dict(type='PointShuffle'), + dict(type='DefaultFormatBundle3D', class_names=class_names), + dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d']) +] +test_pipeline = [ + dict( + type='LoadPointsFromFile', + load_dim=6, + use_dim=5, + file_client_args=file_client_args), + dict( + type='MultiScaleFlipAug3D', + img_scale=(1333, 800), + pts_scale_ratio=1, + flip=False, + transforms=[ + dict( + type='GlobalRotScaleTrans', + rot_range=[0, 0], + scale_ratio_range=[1., 1.], + translation_std=[0, 0, 0]), + dict(type='RandomFlip3D'), + dict( + type='PointsRangeFilter', 
point_cloud_range=point_cloud_range), + dict( + type='DefaultFormatBundle3D', + class_names=class_names, + with_label=False), + dict(type='Collect3D', keys=['points']) + ]) +] + +data = dict( + samples_per_gpu=2, + workers_per_gpu=4, + train=dict( + type='RepeatDataset', + times=2, + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file=data_root + 'waymo_infos_train.pkl', + split='training', + pipeline=train_pipeline, + modality=input_modality, + classes=class_names, + test_mode=False, + # we use box_type_3d='LiDAR' in kitti and nuscenes dataset + # and box_type_3d='Depth' in sunrgbd and scannet dataset. + box_type_3d='LiDAR', + # load one frame every five frames + load_interval=5)), + val=dict( + type=dataset_type, + data_root=data_root, + ann_file=data_root + 'waymo_infos_val.pkl', + split='training', + pipeline=test_pipeline, + modality=input_modality, + classes=class_names, + test_mode=True, + box_type_3d='LiDAR'), + test=dict( + type=dataset_type, + data_root=data_root, + ann_file=data_root + 'waymo_infos_val.pkl', + split='training', + pipeline=test_pipeline, + modality=input_modality, + classes=class_names, + test_mode=True, + box_type_3d='LiDAR')) + +evaluation = dict(interval=24) diff --git a/configs/_base_/datasets/waymoD5-3d-car.py b/configs/_base_/datasets/waymoD5-3d-car.py new file mode 100644 index 0000000000..16b30c86d1 --- /dev/null +++ b/configs/_base_/datasets/waymoD5-3d-car.py @@ -0,0 +1,125 @@ +# dataset settings +# D5 in the config name means the whole dataset is divided into 5 folds +# We only use one fold for efficient experiments +dataset_type = 'WaymoDataset' +data_root = 'data/waymo/kitti_format/' +file_client_args = dict(backend='disk') +# Uncomment the following if use ceph or other file clients. +# See https://mmcv.readthedocs.io/en/latest/api.html#mmcv.fileio.FileClient +# for more details. 
+# file_client_args = dict( +# backend='petrel', path_mapping=dict(data='s3://waymo_data/')) + +class_names = ['Car'] +point_cloud_range = [-74.88, -74.88, -2, 74.88, 74.88, 4] +input_modality = dict(use_lidar=True, use_camera=False) +db_sampler = dict( + data_root=data_root, + info_path=data_root + 'waymo_dbinfos_train.pkl', + rate=1.0, + prepare=dict(filter_by_difficulty=[-1], filter_by_min_points=dict(Car=5)), + classes=class_names, + sample_groups=dict(Car=15), + points_loader=dict( + type='LoadPointsFromFile', + load_dim=5, + use_dim=[0, 1, 2, 3, 4], + file_client_args=file_client_args)) + +train_pipeline = [ + dict( + type='LoadPointsFromFile', + load_dim=6, + use_dim=5, + file_client_args=file_client_args), + dict( + type='LoadAnnotations3D', + with_bbox_3d=True, + with_label_3d=True, + file_client_args=file_client_args), + dict(type='ObjectSample', db_sampler=db_sampler), + dict( + type='RandomFlip3D', + sync_2d=False, + flip_ratio_bev_horizontal=0.5, + flip_ratio_bev_vertical=0.5), + dict( + type='GlobalRotScaleTrans', + rot_range=[-0.78539816, 0.78539816], + scale_ratio_range=[0.95, 1.05]), + dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range), + dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range), + dict(type='PointShuffle'), + dict(type='DefaultFormatBundle3D', class_names=class_names), + dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d']) +] +test_pipeline = [ + dict( + type='LoadPointsFromFile', + load_dim=6, + use_dim=5, + file_client_args=file_client_args), + dict( + type='MultiScaleFlipAug3D', + img_scale=(1333, 800), + pts_scale_ratio=1, + flip=False, + transforms=[ + dict( + type='GlobalRotScaleTrans', + rot_range=[0, 0], + scale_ratio_range=[1., 1.], + translation_std=[0, 0, 0]), + dict(type='RandomFlip3D'), + dict( + type='PointsRangeFilter', point_cloud_range=point_cloud_range), + dict( + type='DefaultFormatBundle3D', + class_names=class_names, + with_label=False), + dict(type='Collect3D', keys=['points']) + ]) +] + +data = dict( + samples_per_gpu=2, + workers_per_gpu=4, + train=dict( + type='RepeatDataset', + times=2, + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file=data_root + 'waymo_infos_train.pkl', + split='training', + pipeline=train_pipeline, + modality=input_modality, + classes=class_names, + test_mode=False, + # we use box_type_3d='LiDAR' in kitti and nuscenes dataset + # and box_type_3d='Depth' in sunrgbd and scannet dataset. + box_type_3d='LiDAR', + # load one frame every five frames + load_interval=5)), + val=dict( + type=dataset_type, + data_root=data_root, + ann_file=data_root + 'waymo_infos_val.pkl', + split='training', + pipeline=test_pipeline, + modality=input_modality, + classes=class_names, + test_mode=True, + box_type_3d='LiDAR'), + test=dict( + type=dataset_type, + data_root=data_root, + ann_file=data_root + 'waymo_infos_val.pkl', + split='training', + pipeline=test_pipeline, + modality=input_modality, + classes=class_names, + test_mode=True, + box_type_3d='LiDAR')) + +evaluation = dict(interval=24) diff --git a/configs/_base_/models/hv_pointpillars_secfpn_waymo.py b/configs/_base_/models/hv_pointpillars_secfpn_waymo.py new file mode 100644 index 0000000000..066a36ac56 --- /dev/null +++ b/configs/_base_/models/hv_pointpillars_secfpn_waymo.py @@ -0,0 +1,109 @@ +# model settings +# Voxel size for voxel encoder +# Usually voxel size is changed consistently with the point cloud range +# If point cloud range is modified, do remember to change all related +# keys in the config. 
+voxel_size = [0.32, 0.32, 6] +model = dict( + type='MVXFasterRCNN', + pts_voxel_layer=dict( + max_num_points=20, + point_cloud_range=[-74.88, -74.88, -2, 74.88, 74.88, 4], + voxel_size=voxel_size, + max_voxels=(32000, 32000)), + pts_voxel_encoder=dict( + type='HardVFE', + in_channels=5, + feat_channels=[64], + with_distance=False, + voxel_size=voxel_size, + with_cluster_center=True, + with_voxel_center=True, + point_cloud_range=[-74.88, -74.88, -2, 74.88, 74.88, 4], + norm_cfg=dict(type='naiveSyncBN1d', eps=1e-3, momentum=0.01)), + pts_middle_encoder=dict( + type='PointPillarsScatter', in_channels=64, output_shape=[468, 468]), + pts_backbone=dict( + type='SECOND', + in_channels=64, + norm_cfg=dict(type='naiveSyncBN2d', eps=1e-3, momentum=0.01), + layer_nums=[3, 5, 5], + layer_strides=[1, 2, 2], + out_channels=[64, 128, 256]), + pts_neck=dict( + type='SECONDFPN', + norm_cfg=dict(type='naiveSyncBN2d', eps=1e-3, momentum=0.01), + in_channels=[64, 128, 256], + upsample_strides=[1, 2, 4], + out_channels=[128, 128, 128]), + pts_bbox_head=dict( + type='Anchor3DHead', + num_classes=3, + in_channels=384, + feat_channels=384, + use_direction_classifier=True, + anchor_generator=dict( + type='AlignedAnchor3DRangeGenerator', + ranges=[[-74.88, -74.88, -0.0345, 74.88, 74.88, -0.0345], + [-74.88, -74.88, -0.1188, 74.88, 74.88, -0.1188], + [-74.88, -74.88, 0, 74.88, 74.88, 0]], + sizes=[ + [2.08, 4.73, 1.77], # car + [0.84, 1.81, 1.77], # cyclist + [0.84, 0.91, 1.74] # pedestrian + ], + rotations=[0, 1.57], + reshape_out=False), + diff_rad_by_sin=True, + dir_offset=0.7854, # pi/4 + dir_limit_offset=0, + bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder', code_size=7), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0), + loss_dir=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.2))) +# model training and testing settings +train_cfg = dict( + pts=dict( + assigner=[ + dict( # car + type='MaxIoUAssigner', + iou_calculator=dict(type='BboxOverlapsNearest3D'), + pos_iou_thr=0.55, + neg_iou_thr=0.4, + min_pos_iou=0.4, + ignore_iof_thr=-1), + dict( # cyclist + type='MaxIoUAssigner', + iou_calculator=dict(type='BboxOverlapsNearest3D'), + pos_iou_thr=0.5, + neg_iou_thr=0.3, + min_pos_iou=0.3, + ignore_iof_thr=-1), + dict( # pedestrian + type='MaxIoUAssigner', + iou_calculator=dict(type='BboxOverlapsNearest3D'), + pos_iou_thr=0.5, + neg_iou_thr=0.3, + min_pos_iou=0.3, + ignore_iof_thr=-1), + ], + allowed_border=0, + code_weight=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + pos_weight=-1, + debug=False)) + +test_cfg = dict( + pts=dict( + use_rotate_nms=True, + nms_across_levels=False, + nms_pre=4096, + nms_thr=0.25, + score_thr=0.1, + min_bbox_size=0, + max_num=500)) diff --git a/configs/centerpoint/README.md b/configs/centerpoint/README.md index 56a1e06951..153bfb3955 100644 --- a/configs/centerpoint/README.md +++ b/configs/centerpoint/README.md @@ -48,6 +48,6 @@ We follow the below style to name config files. 
Contributors are advised to foll |[SECFPN](./centerpoint_0075voxel_second_secfpn_dcn_4x8_cyclic_20e_nus.py)|voxel (0.075)|✓|✗|||||| |[SECFPN](./centerpoint_0075voxel_second_secfpn_dcn_circlenms_4x8_cyclic_20e_nus.py)|voxel (0.075)|✓|✓|||||| |[SECFPN](./centerpoint_02pillar_second_secfpn_4x8_cyclic_20e_nus.py)|pillar (0.2)|✗|✗|||||| -|[SECFPN](./centerpoint_02pillar_second_secfpn_circlenms_4x8_cyclic_20e_nus.py)|pillar (0.2)|✗|✓|||||| +|[SECFPN](./centerpoint_02pillar_second_secfpn_circlenms_4x8_cyclic_20e_nus.py)|pillar (0.2)|✗|✓|||48.72|59.40|| |[SECFPN](./centerpoint_02pillar_second_secfpn_dcn_4x8_cyclic_20e_nus.py)|pillar (0.2)|✓|✗|||||| |[SECFPN](./centerpoint_02pillar_second_secfpn_dcn_circlenms_4x8_cyclic_20e_nus.py)|pillar (0.2)|✓|✓|||||| diff --git a/configs/h3dnet/README.md b/configs/h3dnet/README.md index c3023bcb92..33ffd63681 100644 --- a/configs/h3dnet/README.md +++ b/configs/h3dnet/README.md @@ -16,4 +16,4 @@ We implement H3DNet and provide the result and checkpoints on ScanNet datasets. ### ScanNet | Backbone | Lr schd | Mem (GB) | Inf time (fps) | AP@0.25 |AP@0.5| Download | | :---------: | :-----: | :------: | :------------: | :----: |:----: | :------: | -| [MultiBackbone](./h3dnet_scannet-3d-18class.py) | 3x |7.9||66.43|48.01|| +| [MultiBackbone](./h3dnet_scannet-3d-18class.py) | 3x |7.9||66.43|48.01|[model](https://openmmlab.oss-accelerate.aliyuncs.com/mmdetection3d/v0.1.0_models/h3dnet/h3dnet_scannet-3d-18class/h3dnet_scannet-3d-18class_20200830_000136-02e36246.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmdetection3d/v0.1.0_models/h3dnet/h3dnet_scannet-3d-18class/h3dnet_scannet-3d-18class_20200830_000136.log.json) | diff --git a/configs/pointpillars/README.md b/configs/pointpillars/README.md index 5852bbb962..0b2fd35ae3 100644 --- a/configs/pointpillars/README.md +++ b/configs/pointpillars/README.md @@ -37,3 +37,21 @@ We implement PointPillars and provide the results and checkpoints on KITTI and n | :---------: | :-----: | :------: | :------------: | :----: |:----: | :------: | |[SECFPN](./hv_pointpillars_secfpn_sbn-all_4x8_2x_lyft-3d.py)|2x|||13.4|13.4|| |[FPN](./hv_pointpillars_fpn_sbn-all_4x8_2x_lyft-3d.py)|2x|||14.0|14.2|| + +### Waymo + +| Backbone | Load Interval | Class | Lr schd | Mem (GB) | Inf time (fps) | mAP@L1 | mAPH@L1 | mAP@L2 | **mAPH@L2** | Download | +| :-------: | :-----------: |:-----:| :------:| :------: | :------------: | :----: | :-----: | :-----: | :-----: | :------: | +| [SECFPN](./hv_pointpillars_secfpn_sbn_2x16_2x_waymoD5-3d-car.py)|5|Car|2x|7.76||70.2|69.6|62.6|62.1|[model](https://openmmlab.oss-accelerate.aliyuncs.com/mmdetection3d/v0.1.0_models/pointpillars/hv_pointpillars_secfpn_sbn_2x16_2x_waymoD5-3d-car/hv_pointpillars_secfpn_sbn_2x16_2x_waymoD5-3d-car_20200901_204315-302fc3e7.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmdetection3d/v0.1.0_models/pointpillars/hv_pointpillars_secfpn_sbn_2x16_2x_waymoD5-3d-car/hv_pointpillars_secfpn_sbn_2x16_2x_waymoD5-3d-car_20200901_204315.log.json)| +| [SECFPN](./hv_pointpillars_secfpn_sbn_2x16_2x_waymoD5-3d-3class.py)|5|3 Class|2x|8.12||64.7|57.6|58.4|52.1|[model](https://openmmlab.oss-accelerate.aliyuncs.com/mmdetection3d/v0.1.0_models/pointpillars/hv_pointpillars_secfpn_sbn_2x16_2x_waymoD5-3d-3class/hv_pointpillars_secfpn_sbn_2x16_2x_waymoD5-3d-3class_20200831_204144-d1a706b1.pth) | 
[log](https://openmmlab.oss-accelerate.aliyuncs.com/mmdetection3d/v0.1.0_models/pointpillars/hv_pointpillars_secfpn_sbn_2x16_2x_waymoD5-3d-3class/hv_pointpillars_secfpn_sbn_2x16_2x_waymoD5-3d-3class_20200831_204144.log.json)| +| above @ Car|||2x|8.12||68.5|67.9|60.1|59.6| | +| above @ Pedestrian|||2x|8.12||67.8|50.6|59.6|44.3| | +| above @ Cyclist|||2x|8.12||57.7|54.4|55.5|52.4| | + +Note: + +- **Metric**: For models trained with 3 classes, the average APH@L2 (mAPH@L2) of all the categories is reported and used to rank the model. For models trained with only 1 class, the APH@L2 is reported and used to rank the model. +- **Data Split**: Here we provide several baselines for the Waymo dataset, adapted from the original PointPillars implementation. Specifically, we divide the dataset into 5 folds (denoted as D5 in the config names) for efficient experiments. +Using the complete dataset can boost performance considerably, especially for the detection of cyclists and pedestrians, where an improvement of more than 5 mAP or mAPH can be expected. A more complete benchmark with more models and methods is coming soon. +- **Implementation Details**: We basically follow the implementation in the [paper](https://arxiv.org/pdf/1912.04838.pdf) in terms of the network architecture (having a +stride of 1 for the first convolutional block). Different settings of voxelization, data augmentation and hyperparameters make these baselines outperform those in the paper by about 7 mAP for cars and 4 mAP for pedestrians with only a subset of the whole dataset. diff --git a/configs/pointpillars/hv_pointpillars_secfpn_sbn_2x16_2x_waymoD5-3d-3class.py b/configs/pointpillars/hv_pointpillars_secfpn_sbn_2x16_2x_waymoD5-3d-3class.py new file mode 100644 index 0000000000..e4f1ce5cda --- /dev/null +++ b/configs/pointpillars/hv_pointpillars_secfpn_sbn_2x16_2x_waymoD5-3d-3class.py @@ -0,0 +1,6 @@ +_base_ = [ + '../_base_/models/hv_pointpillars_secfpn_waymo.py', + '../_base_/datasets/waymoD5-3d-3class.py', + '../_base_/schedules/schedule_2x.py', + '../_base_/default_runtime.py', +] diff --git a/configs/pointpillars/hv_pointpillars_secfpn_sbn_2x16_2x_waymoD5-3d-car.py b/configs/pointpillars/hv_pointpillars_secfpn_sbn_2x16_2x_waymoD5-3d-car.py new file mode 100644 index 0000000000..21e267d0c3 --- /dev/null +++ b/configs/pointpillars/hv_pointpillars_secfpn_sbn_2x16_2x_waymoD5-3d-car.py @@ -0,0 +1,35 @@ +_base_ = [ + '../_base_/models/hv_pointpillars_secfpn_waymo.py', + '../_base_/datasets/waymoD5-3d-car.py', + '../_base_/schedules/schedule_2x.py', + '../_base_/default_runtime.py', +] + +# model settings +model = dict( + type='MVXFasterRCNN', + pts_bbox_head=dict( + type='Anchor3DHead', + num_classes=1, + anchor_generator=dict( + type='AlignedAnchor3DRangeGenerator', + ranges=[[-74.88, -74.88, -0.0345, 74.88, 74.88, -0.0345]], + sizes=[[2.08, 4.73, 1.77]], + rotations=[0, 1.57], + reshape_out=True))) + +# model training and testing settings +train_cfg = dict( + _delete_=True, + pts=dict( + assigner=dict( + type='MaxIoUAssigner', + iou_calculator=dict(type='BboxOverlapsNearest3D'), + pos_iou_thr=0.55, + neg_iou_thr=0.4, + min_pos_iou=0.4, + ignore_iof_thr=-1), + allowed_border=0, + code_weight=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + pos_weight=-1, + debug=False)) diff --git a/data/scannet/scannet_utils.py b/data/scannet/scannet_utils.py index 46e160b496..5813098f89 100644 --- a/data/scannet/scannet_utils.py +++ b/data/scannet/scannet_utils.py @@ -8,9 +8,8 @@ """ import csv -import os - import numpy as np +import os from plyfile import
PlyData diff --git a/docs/changelog.md b/docs/changelog.md index 33f720129e..612380328c 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,4 +1,54 @@ ## Changelog +### v0.6.0 (20/9/2020) + +#### Highlights + +- Support new methods [H3DNet](https://arxiv.org/abs/2006.05682), [3DSSD](https://arxiv.org/abs/2002.10187), [CenterPoint](https://arxiv.org/abs/2006.11275). +- Support new datasets [Waymo](https://waymo.com/open/) (with PointPillars baselines) and [nuImages](https://www.nuscenes.org/nuimages) (with Mask R-CNN and Cascade Mask R-CNN baselines). +- Support batch inference +- Support PyTorch 1.6 +- The `mmdet3d` package has been published to PyPI since v0.5.0. You can install mmdet3d through `pip install mmdet3d`. + +#### Backwards Incompatible Changes + +- Support batch inference (#95, #103, #116): MMDetection3D v0.6.0 migrates to support batch inference based on MMDetection >= v2.4.0. This change influences all the test APIs in MMDetection3D and downstream codebases. +- Start to use the collect environment function from MMCV (#113): MMDetection3D v0.6.0 migrates to use the `collect_env` function in MMCV. +`get_compiler_version` and `get_compiling_cuda_version` compiled in `mmdet3d.ops.utils` are removed. Please import these two functions from `mmcv.ops`. + +#### Bug Fixes + +- Rename CosineAnealing to CosineAnnealing (#57) +- Fix a device inconsistency bug in 3D IoU computation (#69) +- Fix a minor bug in json2csv of the Lyft dataset (#78) +- Add missing test data for pointnet modules (#85) +- Fix the `use_valid_flag` bug in `CustomDataset` (#106) + +#### New Features + +- Support the [nuImages](https://www.nuscenes.org/nuimages) dataset by converting it into COCO format, and release Mask R-CNN and Cascade Mask R-CNN baseline models (#91, #94) +- Support publishing to PyPI via GitHub Actions (#17, #19, #25, #39, #40) +- Support CBGSDataset and make it generally applicable to all the supported datasets (#75, #94) +- Support [H3DNet](https://arxiv.org/abs/2006.05682) and release models on the ScanNet dataset (#53, #58, #105) +- Support Fusion Point Sampling used in [3DSSD](https://arxiv.org/abs/2002.10187) (#66) +- Add `BackgroundPointsFilter` to filter background points in the data pipeline (#84) +- Support pointnet2 with multi-scale grouping in the backbone and refactor pointnets (#82) +- Support dilated ball query used in [3DSSD](https://arxiv.org/abs/2002.10187) (#96) +- Support [3DSSD](https://arxiv.org/abs/2002.10187) and release models on the KITTI dataset (#83, #100, #104) +- Support [CenterPoint](https://arxiv.org/abs/2006.11275) and release models on the nuScenes dataset (#49, #92) +- Support the [Waymo](https://waymo.com/open/) dataset and release PointPillars baseline models (#118) +- Allow `LoadPointsFromMultiSweeps` to pad empty sweeps and select multiple sweeps randomly (#67) + +#### Improvements + +- Fix all warnings and bugs in PyTorch 1.6.0 (#70, #72) +- Update issue templates (#43) +- Update unit tests (#20, #24, #30) +- Update documentation for using `ply` format point cloud data (#41) +- Use points loader to load point cloud data in ground truth (GT) samplers (#87) +- Unify version file of OpenMMLab projects by using `version.py` (#112) +- Remove unnecessary data preprocessing commands of SUN RGB-D dataset (#110) + ### v0.5.0 (9/7/2020) + MMDetection3D is released.
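As a migration aid for the backwards-incompatible `collect_env` change above, the import switch is a one-liner. A minimal sketch, assuming only that both helpers keep their names under `mmcv.ops` as the changelog states:

```python
# Before v0.6.0 these helpers were compiled into mmdet3d.ops.utils:
# from mmdet3d.ops.utils import get_compiler_version, get_compiling_cuda_version
# From v0.6.0 on, import them from mmcv.ops instead.
from mmcv.ops import get_compiler_version, get_compiling_cuda_version

print(get_compiler_version())
print(get_compiling_cuda_version())
```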
diff --git a/docs/getting_started.md b/docs/getting_started.md index da8143708d..00036e60b3 100644 --- a/docs/getting_started.md +++ b/docs/getting_started.md @@ -1,6 +1,6 @@ # Getting Started -This page provides basic tutorials about the usage of MMDetection. +This page provides basic tutorials about the usage of MMDetection3D. For installation instructions, please see [install.md](install.md). ## Prepare datasets @@ -31,6 +31,14 @@ mmdetection3d │ │ │ ├── image_2 │ │ │ ├── label_2 │ │ │ ├── velodyne +│ ├── waymo +│ │ ├── waymo_format +│ │ │ ├── training +│ │ │ ├── validation +│ │ │ ├── testing +│ │ │ ├── gt.bin +│ │ ├── kitti_format +│ │ │ ├── ImageSets │ ├── lyft │ │ ├── v1.01-train │ │ │ ├── v1.01-train (train_data) @@ -62,12 +70,6 @@ mmdetection3d ``` -Download nuScenes V1.0 full dataset data [HERE]( https://www.nuscenes.org/download). Prepare nuscenes data by running - -```bash -python tools/create_data.py nuscenes --root-path ./data/nuscenes --out-dir ./data/nuscenes --extra-tag nuscenes -``` - Download KITTI 3D detection data [HERE](http://www.cvlibs.net/datasets/kitti/eval_object.php?obj_benchmark=3d). Prepare KITTI data by running ```bash @@ -82,6 +84,20 @@ wget -c https://raw.githubusercontent.com/traveller59/second.pytorch/master/sec python tools/create_data.py kitti --root-path ./data/kitti --out-dir ./data/kitti --extra-tag kitti ``` +Download Waymo open dataset V1.2 [HERE](https://waymo.com/open/download/) and its data split [HERE](https://drive.google.com/drive/folders/18BVuF_RYJF0NjZpt8SnfzANiakoRMf0o?usp=sharing). Then put the tfrecord files into the corresponding folders in `data/waymo/waymo_format/` and put the data split txt files into `data/waymo/kitti_format/ImageSets`. Download the ground truth bin file for the validation set [HERE](https://console.cloud.google.com/storage/browser/waymo_open_dataset_v_1_2_0/validation/ground_truth_objects) and put it into `data/waymo/waymo_format/`. Tip: you can use `gsutil` to download this large-scale dataset from the command line; you can take this [tool](https://github.com/RalphMao/Waymo-Dataset-Tool) as an example for more details. Subsequently, prepare Waymo data by running + +```bash +python tools/create_data.py waymo --root-path ./data/waymo/ --out-dir ./data/waymo/ --workers 128 --extra-tag waymo +``` + +Note that if your local disk does not have enough space for saving the converted data, you can change `out-dir` to anywhere else. Just remember to create folders and prepare the data there in advance, and link them back to `data/waymo/kitti_format` after the data conversion. + +Download nuScenes V1.0 full dataset data [HERE](https://www.nuscenes.org/download). Prepare nuScenes data by running + +```bash +python tools/create_data.py nuscenes --root-path ./data/nuscenes --out-dir ./data/nuscenes --extra-tag nuscenes +``` + Download Lyft 3D detection data [HERE](https://www.kaggle.com/c/3d-object-detection-for-autonomous-vehicles/data). Prepare Lyft data by running ```bash @@ -180,6 +196,39 @@ Assume that you have already downloaded the checkpoints to the directory `checkp The generated results will be under the `./second_kitti_results` directory. +7. Test PointPillars on Lyft with 8 GPUs, generate the pkl files and make a submission to the leaderboard.
+ + ```shell + ./tools/slurm_test.sh ${PARTITION} ${JOB_NAME} configs/pointpillars/hv_pointpillars_fpn_sbn-2x8_2x_lyft-3d.py \ + checkpoints/hv_pointpillars_fpn_sbn-2x8_2x_lyft-3d_latest.pth --out results/pp_lyft/results_challenge.pkl \ + --format-only --options 'jsonfile_prefix=results/pp_lyft/results_challenge' \ + 'csv_path=results/pp_lyft/results_challenge.csv' + ``` + + **Notice**: To generate submissions on Lyft, `csv_path` must be given in the options. After generating the csv file, you can make a submission with the kaggle commands given on the [website](https://www.kaggle.com/c/3d-object-detection-for-autonomous-vehicles/submit). + +8. Test PointPillars on Waymo with 8 GPUs and evaluate the mAP with Waymo metrics. + + ```shell + ./tools/slurm_test.sh ${PARTITION} ${JOB_NAME} configs/pointpillars/hv_pointpillars_secfpn_sbn-2x16_2x_waymo-3d-car.py \ + checkpoints/hv_pointpillars_secfpn_sbn-2x16_2x_waymo-3d-car_latest.pth --out results/waymo-car/results_eval.pkl \ + --eval waymo --options 'pklfile_prefix=results/waymo-car/kitti_results' \ + 'submission_prefix=results/waymo-car/kitti_results' + ``` + + **Notice**: For evaluation on Waymo, please follow the [instructions](https://github.com/waymo-research/waymo-open-dataset/blob/master/docs/quick_start.md) to build the binary file `compute_detection_metrics_main` for metrics computation and put it into `mmdet3d/core/evaluation/waymo_utils/`. (Sometimes when using bazel to build `compute_detection_metrics_main`, an error `'round' is not a member of 'std'` may appear. We just need to remove the `std::` before `round` in that file.) `pklfile_prefix` should be given in the options for the bin file generation. For metrics, `waymo` is the recommended official evaluation protocol. Currently, evaluating with the `kitti` option is adapted from KITTI, and the results for each difficulty are not exactly the same as defined by KITTI. Instead, most objects are currently marked with difficulty 0, which will be fixed in the future. The reasons for its instability include the heavy computation required for evaluation, the lack of occlusion and truncation in the converted data, a different definition of difficulty and a different method of computing average precision. + +9. Test PointPillars on Waymo with 8 GPUs, generate the bin files and make a submission to the leaderboard. + + ```shell + ./tools/slurm_test.sh ${PARTITION} ${JOB_NAME} configs/pointpillars/hv_pointpillars_secfpn_sbn-2x16_2x_waymo-3d-car.py \ + checkpoints/hv_pointpillars_secfpn_sbn-2x16_2x_waymo-3d-car_latest.pth --out results/waymo-car/results_eval.pkl \ + --format-only --options 'pklfile_prefix=results/waymo-car/kitti_results' \ + 'submission_prefix=results/waymo-car/kitti_results' + ``` + + **Notice**: After generating the bin file, you can simply build the binary file `create_submission` and use it to create a submission file by following the [instructions](https://github.com/waymo-research/waymo-open-dataset/blob/master/docs/quick_start.md). For evaluation on the validation set with the eval server, you can also generate a submission in the same way.
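For reference, the bin files mentioned above contain Waymo `Object` protos converted from KITTI-format predictions by `mmdet3d/core/evaluation/waymo_utils/prediction_kitti_to_waymo.py` (added in this PR). The essential convention change is the box heading; the sketch below mirrors that converter's logic with plain NumPy:

```python
import numpy as np

def kitti_ry_to_waymo_heading(rotation_y):
    """Convert a KITTI rotation_y angle into a Waymo heading.

    Mirrors prediction_kitti_to_waymo.py: negate the angle shifted
    by pi/2, then wrap the result into [-pi, pi].
    """
    heading = -(rotation_y + np.pi / 2)
    while heading < -np.pi:
        heading += 2 * np.pi
    while heading > np.pi:
        heading -= 2 * np.pi
    return heading
```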
+ ### Visualization To see the SUNRGBD, ScanNet or KITTI points and detection results, you can run the following command diff --git a/docs/install.md b/docs/install.md index bd4a43269b..c4880701a8 100644 --- a/docs/install.md +++ b/docs/install.md @@ -49,6 +49,7 @@ c. Install [MMCV](https://mmcv.readthedocs.io/en/latest/). *mmcv-full* is necessary since MMDetection3D relies on MMDetection, and CUDA ops in *mmcv-full* are required. The pre-built *mmcv-full* can be installed by running: (available versions can be found [here](https://mmcv.readthedocs.io/en/latest/#install-with-pip)) + ```shell pip install mmcv-full==latest+torch1.5.0+cu101 -f https://openmmlab.oss-accelerate.aliyuncs.com/mmcv/dist/index.html ``` @@ -74,11 +75,16 @@ pip install -r requirements/build.txt pip install -v -e . # or "python setup.py develop" ``` -If you build mmdetection on macOS, replace the last command with +**Important**: + +1. The required versions of MMCV and MMDetection for different versions of MMDetection3D are listed below. Please install the correct versions of MMCV and MMDetection to avoid installation issues. + +| MMDetection3D version | MMDetection version | MMCV version | +|:-------------------:|:-------------------:|:-------------------:| +| master | mmdet>=2.4.0 | mmcv-full>=1.1.1, <=1.2| +| 0.6.0 | mmdet>=2.4.0 | mmcv-full>=1.1.1, <=1.2| +| 0.5.0 | 2.3.0 | mmcv-full==1.0.5| -``` -CC=clang CXX=clang++ CFLAGS='-stdlib=libc++' pip install -e . -``` e. Clone the MMDetection3D repository. @@ -100,7 +106,7 @@ It is recommended that you run step d each time you pull some updates from githu > Important: Be sure to remove the `./build` folder if you reinstall mmdet with a different CUDA/PyTorch version. - ``` + ```shell pip uninstall mmdet3d rm -rf ./build find . -name "*.so" | xargs rm @@ -115,7 +121,6 @@ you can install it before installing MMCV. 5. The code cannot be built in a CPU-only environment (where CUDA isn't available) for now. - ### A from-scratch setup script Here is a full script for setting up mmdetection3d with conda. diff --git a/docs/model_zoo.md b/docs/model_zoo.md index 0713b0756f..2f2a9e6ab4 100644 --- a/docs/model_zoo.md +++ b/docs/model_zoo.md @@ -6,7 +6,6 @@ - For fair comparison with other codebases, we report the GPU memory as the maximum value of `torch.cuda.max_memory_allocated()` for all 8 GPUs. Note that this value is usually less than what `nvidia-smi` shows. - We report the inference time as the total time of network forwarding and post-processing, excluding the data loading time. Results are obtained with the script [benchmark.py](https://github.com/open-mmlab/mmdetection/blob/master/tools/benchmark.py) which computes the average time on 2000 images. - ## Baselines ### SECOND @@ -46,3 +45,7 @@ Please refer to [H3DNet](https://github.com/open-mmlab/mmdetection3d/blob/master ### 3DSSD Please refer to [3DSSD](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/3dssd) for details. + +### CenterPoint + +Please refer to [CenterPoint](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/centerpoint) for details.
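As a quick sanity check against the compatibility table above, you can print the installed versions; a minimal sketch, assuming `mmcv`, `mmdet` and `mmdet3d` are already importable:

```python
# Compare the printed versions against the table in docs/install.md:
# mmdet3d master/0.6.0 expects mmdet>=2.4.0 and mmcv-full>=1.1.1, <=1.2.
import mmcv
import mmdet
import mmdet3d

print('mmcv:   ', mmcv.__version__)
print('mmdet:  ', mmdet.__version__)
print('mmdet3d:', mmdet3d.__version__)
```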
diff --git a/docs/waymo.md b/docs/waymo.md new file mode 100644 index 0000000000..db24588dfe --- /dev/null +++ b/docs/waymo.md @@ -0,0 +1,153 @@ +# A Brief Tutorial for Waymo Dataset + +This page provides specific tutorials about the usage of MMDetection3D for the Waymo dataset. + +## Prepare datasets + +As with other datasets, it is recommended to symlink the dataset root to `$MMDETECTION3D/data`. +Since the original Waymo data format is based on `tfrecord`, we need to preprocess the raw data for convenient use in training and evaluation. Our approach is to convert it into KITTI format. + +The folder structure should be organized as follows before our processing. + + ``` +mmdetection3d ├── mmdet3d ├── tools ├── configs ├── data │ ├── waymo │ │ ├── waymo_format │ │ │ ├── training │ │ │ ├── validation │ │ │ ├── testing │ │ │ ├── gt.bin │ │ ├── kitti_format │ │ │ ├── ImageSets + ``` + +You can download Waymo open dataset V1.2 [HERE](https://waymo.com/open/download/) and its data split [HERE](https://drive.google.com/drive/folders/18BVuF_RYJF0NjZpt8SnfzANiakoRMf0o?usp=sharing). Then put the tfrecord files into the corresponding folders in `data/waymo/waymo_format/` and put the data split txt files into `data/waymo/kitti_format/ImageSets`. Download the ground truth bin file for the validation set [HERE](https://console.cloud.google.com/storage/browser/waymo_open_dataset_v_1_2_0/validation/ground_truth_objects) and put it into `data/waymo/waymo_format/`. Tip: you can use `gsutil` to download this large-scale dataset from the command line; you can take this [tool](https://github.com/RalphMao/Waymo-Dataset-Tool) as an example for more details. Subsequently, prepare Waymo data by running + +```bash +python tools/create_data.py waymo --root-path ./data/waymo/ --out-dir ./data/waymo/ --workers 128 --extra-tag waymo +``` + +Note that if your local disk does not have enough space for saving the converted data, you can change `out-dir` to anywhere else. Just remember to create folders and prepare the data there in advance, and link them back to `data/waymo/kitti_format` after the data conversion. + +After the data conversion, the folder structure and info files should be organized as below. + +``` +mmdetection3d +├── mmdet3d +├── tools +├── configs +├── data +│ ├── waymo +│ │ ├── waymo_format +│ │ │ ├── training +│ │ │ ├── validation +│ │ │ ├── testing +│ │ │ ├── gt.bin +│ │ ├── kitti_format +│ │ │ ├── ImageSets +│ │ │ ├── training +│ │ │ │ ├── calib +│ │ │ │ ├── image_0 +│ │ │ │ ├── image_1 +│ │ │ │ ├── image_2 +│ │ │ │ ├── image_3 +│ │ │ │ ├── image_4 +│ │ │ │ ├── label_0 +│ │ │ │ ├── label_1 +│ │ │ │ ├── label_2 +│ │ │ │ ├── label_3 +│ │ │ │ ├── label_4 +│ │ │ │ ├── label_all +│ │ │ │ ├── pose +│ │ │ │ ├── velodyne +│ │ │ ├── testing +│ │ │ │ ├── (the same as training) +│ │ │ ├── waymo_gt_database +│ │ │ ├── waymo_infos_trainval.pkl +│ │ │ ├── waymo_infos_train.pkl +│ │ │ ├── waymo_infos_val.pkl +│ │ │ ├── waymo_infos_test.pkl +│ │ │ ├── waymo_dbinfos_train.pkl + +``` + +Because there are several cameras, we store the corresponding images and the labels that can be projected to each camera respectively, and we save the pose for further use with point clouds of consecutive frames. We use a naming convention `{a}{bbb}{ccc}` for the data of each frame, where `a` is the split prefix (`0` for training, `1` for validation and `2` for testing), `bbb` is the segment index and `ccc` is the frame index. You can easily locate the required frame according to this naming rule. We gather the data for training and validation together as in KITTI and store the indices for the different sets in the ImageSets files. + +## Training + +Since there are many similar frames in the original dataset, we can basically use a subset to train our model as a starting point. In our preliminary baselines, we load one frame out of every five, and thanks to our hyperparameter settings and data augmentation, we obtain better results than those reported in the original dataset [paper](https://arxiv.org/pdf/1912.04838.pdf). For more details about the configuration and performance, please refer to the README.md in `configs/pointpillars/`. A more complete benchmark based on other settings and methods is coming soon.
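For example, moving from the D5 subset back to the full training split only requires overriding `load_interval` in a derived config. A sketch under the config layout added in this PR (the derived file name here is hypothetical; note that `train` is a `RepeatDataset`, so the option lives one level down in `train.dataset`):

```python
# hypothetical config: train on every Waymo frame instead of one in five
_base_ = ['./hv_pointpillars_secfpn_sbn_2x16_2x_waymoD5-3d-car.py']

# override the load_interval set to 5 in the waymoD5 base dataset config
data = dict(train=dict(dataset=dict(load_interval=1)))
```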
+ +## Evaluation + +For evaluation on Waymo, please follow the [instructions](https://github.com/waymo-research/waymo-open-dataset/blob/master/docs/quick_start.md) to build the binary file `compute_detection_metrics_main` for metrics computation and put it into `mmdet3d/core/evaluation/waymo_utils/`. Basically, you can follow the commands below to install bazel and build the file. + + ```shell + git clone https://github.com/waymo-research/waymo-open-dataset.git waymo-od + cd waymo-od + git checkout remotes/origin/master + + sudo apt-get install --assume-yes pkg-config zip g++ zlib1g-dev unzip python3 python3-pip + wget https://github.com/bazelbuild/bazel/releases/download/0.28.0/bazel-0.28.0-installer-linux-x86_64.sh + sudo bash bazel-0.28.0-installer-linux-x86_64.sh + sudo apt install build-essential + + ./configure.sh + bazel clean + + bazel build waymo_open_dataset/metrics/tools/compute_detection_metrics_main + cp bazel-bin/waymo_open_dataset/metrics/tools/compute_detection_metrics_main ../mmdetection3d/mmdet3d/core/evaluation/waymo_utils/ + ``` + +Then you can evaluate your models on Waymo. An example of evaluating PointPillars on Waymo with 8 GPUs using Waymo metrics is as follows. + + ```shell + ./tools/slurm_test.sh ${PARTITION} ${JOB_NAME} configs/pointpillars/hv_pointpillars_secfpn_sbn-2x16_2x_waymo-3d-car.py \ + checkpoints/hv_pointpillars_secfpn_sbn-2x16_2x_waymo-3d-car_latest.pth --out results/waymo-car/results_eval.pkl \ + --eval waymo --options 'pklfile_prefix=results/waymo-car/kitti_results' \ + 'submission_prefix=results/waymo-car/kitti_results' + ``` + +`pklfile_prefix` should be given in the options if the bin file needs to be generated. For metrics, `waymo` is the recommended official evaluation protocol. Currently, evaluating with the `kitti` option is adapted from KITTI, and the results for each difficulty are not exactly the same as defined by KITTI. Instead, most objects are currently marked with difficulty 0, which will be fixed in the future. The reasons for its instability include the heavy computation required for evaluation, the lack of occlusion and truncation in the converted data, a different definition of difficulty and a different method of computing average precision. + +**Notice**: + +1. Sometimes when using bazel to build `compute_detection_metrics_main`, an error `'round' is not a member of 'std'` may appear. We just need to remove the `std::` before `round` in that file. + +2. Since a single evaluation takes a rather long time, we recommend evaluating only once at the end of model training. + +3. To use TensorFlow with CUDA 9, it is recommended to compile it from source. Apart from the official tutorials, you can refer to this [link](https://github.com/SmileTM/Tensorflow2.X-GPU-CUDA9.0) for possibly suitable precompiled packages and useful information about compiling it from source.
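When inspecting evaluation results frame by frame, the `{a}{bbb}{ccc}` naming rule described in the data-preparation section can be decoded back into split, segment and frame. A small illustrative helper (not part of the codebase):

```python
def decode_waymo_sample_idx(sample_idx):
    """Decode the {a}{bbb}{ccc} naming rule of converted Waymo frames.

    a: split prefix (0 training, 1 validation, 2 testing),
    bbb: segment index, ccc: frame index within the segment.
    """
    s = f'{int(sample_idx):07d}'
    return dict(split=int(s[0]), segment=int(s[1:4]), frame=int(s[4:7]))


# e.g. the point cloud 1023045.bin comes from the validation split,
# segment 23, frame 45
print(decode_waymo_sample_idx(1023045))
```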
+ +## Testing and making a submission + +An example of testing PointPillars on Waymo with 8 GPUs, generating the bin files and making a submission to the leaderboard is as follows. + + ```shell + ./tools/slurm_test.sh ${PARTITION} ${JOB_NAME} configs/pointpillars/hv_pointpillars_secfpn_sbn-2x16_2x_waymo-3d-car.py \ + checkpoints/hv_pointpillars_secfpn_sbn-2x16_2x_waymo-3d-car_latest.pth --out results/waymo-car/results_eval.pkl \ + --format-only --options 'pklfile_prefix=results/waymo-car/kitti_results' \ + 'submission_prefix=results/waymo-car/kitti_results' + ``` + +After generating the bin file, you can simply build the binary file `create_submission` and use it to create a submission file by following the [instructions](https://github.com/waymo-research/waymo-open-dataset/blob/master/docs/quick_start.md). Basically, here are some example commands. + + ```shell + cd ../waymo-od/ + bazel build waymo_open_dataset/metrics/tools/create_submission + cp bazel-bin/waymo_open_dataset/metrics/tools/create_submission ../mmdetection3d/mmdet3d/core/evaluation/waymo_utils/ + vim waymo_open_dataset/metrics/tools/submission.txtpb # set the metadata information + cp waymo_open_dataset/metrics/tools/submission.txtpb ../mmdetection3d/mmdet3d/core/evaluation/waymo_utils/ + + cd ../mmdetection3d + # suppose the result bin is in `results/waymo-car/submission` + mmdet3d/core/evaluation/waymo_utils/create_submission --input_filenames='results/waymo-car/kitti_results_test.bin' --output_filename='results/waymo-car/submission/model' --submission_filename='mmdet3d/core/evaluation/waymo_utils/submission.txtpb' + + tar cvf results/waymo-car/submission/my_model.tar results/waymo-car/submission/my_model/ + gzip results/waymo-car/submission/my_model.tar + ``` + +For evaluation on the validation set with the eval server, you can also generate a submission in the same way. Make sure you change the fields in `submission.txtpb` before running the command above. diff --git a/mmdet3d/core/evaluation/kitti_utils/eval.py b/mmdet3d/core/evaluation/kitti_utils/eval.py index 505414794f..7980446b92 100644 --- a/mmdet3d/core/evaluation/kitti_utils/eval.py +++ b/mmdet3d/core/evaluation/kitti_utils/eval.py @@ -591,19 +591,21 @@ def do_eval(gt_annos, eval_types=['bbox', 'bev', '3d']): # min_overlaps: [num_minoverlap, metric, num_class] difficultys = [0, 1, 2] - ret = eval_class( - gt_annos, - dt_annos, - current_classes, - difficultys, - 0, - min_overlaps, - compute_aos=('aos' in eval_types)) - # ret: [num_class, num_diff, num_minoverlap, num_sample_points] - mAP_bbox = get_mAP(ret['precision']) + mAP_bbox = None mAP_aos = None - if 'aos' in eval_types: - mAP_aos = get_mAP(ret['orientation']) + if 'bbox' in eval_types: + ret = eval_class( + gt_annos, + dt_annos, + current_classes, + difficultys, + 0, + min_overlaps, + compute_aos=('aos' in eval_types)) + # ret: [num_class, num_diff, num_minoverlap, num_sample_points] + mAP_bbox = get_mAP(ret['precision']) + if 'aos' in eval_types: + mAP_aos = get_mAP(ret['orientation']) mAP_bev = None if 'bev' in eval_types: @@ -654,7 +656,9 @@ def kitti_eval(gt_annos, Returns: tuple: String and dict of evaluation results.
""" - assert 'bbox' in eval_types, 'must evaluate bbox at least' + assert len(eval_types) > 0, 'must contain at least one evaluation type' + if 'aos' in eval_types: + assert 'bbox' in eval_types, 'must evaluate bbox when evaluating aos' overlap_0_7 = np.array([[0.7, 0.5, 0.5, 0.7, 0.5], [0.7, 0.5, 0.5, 0.7, 0.5], [0.7, 0.5, 0.5, 0.7, 0.5]]) @@ -683,12 +687,19 @@ def kitti_eval(gt_annos, result = '' # check whether alpha is valid compute_aos = False + pred_alpha = False + valid_alpha_gt = False for anno in dt_annos: if anno['alpha'].shape[0] != 0: - if anno['alpha'][0] != -10: - compute_aos = True - eval_types.append('aos') + pred_alpha = True + break + for anno in gt_annos: + if anno['alpha'][0] != -10: + valid_alpha_gt = True break + compute_aos = (pred_alpha and valid_alpha_gt) + if compute_aos: + eval_types.append('aos') mAPbbox, mAPbev, mAP3d, mAPaos = do_eval(gt_annos, dt_annos, current_classes, min_overlaps, diff --git a/mmdet3d/core/evaluation/waymo_utils/prediction_kitti_to_waymo.py b/mmdet3d/core/evaluation/waymo_utils/prediction_kitti_to_waymo.py new file mode 100644 index 0000000000..0f851dda6b --- /dev/null +++ b/mmdet3d/core/evaluation/waymo_utils/prediction_kitti_to_waymo.py @@ -0,0 +1,255 @@ +r"""Adapted from `Waymo to KITTI converter + `_. +""" + +import mmcv +import numpy as np +import tensorflow as tf +from glob import glob +from os.path import join +from waymo_open_dataset import dataset_pb2 as open_dataset +from waymo_open_dataset import label_pb2 +from waymo_open_dataset.protos import metrics_pb2 + + +class KITTI2Waymo(object): + """KITTI predictions to Waymo converter. + + This class serves as the converter to change predictions from KITTI to + Waymo format. + + Args: + kitti_result_files (list[dict]): Predictions in KITTI format. + waymo_tfrecords_dir (str): Directory to load waymo raw data. + waymo_results_save_dir (str): Directory to save converted predictions + in waymo format (.bin files). + waymo_results_final_path (str): Path to save combined + predictions in waymo format (.bin file), like 'a/b/c.bin'. + prefix (str): Prefix of filename. In general, 0 for training, 1 for + validation and 2 for testing. + workers (str): Number of parallel processes. 
+ """ + + def __init__(self, + kitti_result_files, + waymo_tfrecords_dir, + waymo_results_save_dir, + waymo_results_final_path, + prefix, + workers=64): + + self.kitti_result_files = kitti_result_files + self.waymo_tfrecords_dir = waymo_tfrecords_dir + self.waymo_results_save_dir = waymo_results_save_dir + self.waymo_results_final_path = waymo_results_final_path + self.prefix = prefix + self.workers = int(workers) + self.name2idx = {} + for idx, result in enumerate(kitti_result_files): + if len(result['sample_idx']) > 0: + self.name2idx[str(result['sample_idx'][0])] = idx + + # turn on eager execution for older tensorflow versions + if int(tf.__version__.split('.')[0]) < 2: + tf.enable_eager_execution() + + self.k2w_cls_map = { + 'Car': label_pb2.Label.TYPE_VEHICLE, + 'Pedestrian': label_pb2.Label.TYPE_PEDESTRIAN, + 'Sign': label_pb2.Label.TYPE_SIGN, + 'Cyclist': label_pb2.Label.TYPE_CYCLIST, + } + + self.T_ref_to_front_cam = np.array([[0.0, 0.0, 1.0, 0.0], + [-1.0, 0.0, 0.0, 0.0], + [0.0, -1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0]]) + + self.get_file_names() + self.create_folder() + + def get_file_names(self): + """Get file names of waymo raw data.""" + self.waymo_tfrecord_pathnames = sorted( + glob(join(self.waymo_tfrecords_dir, '*.tfrecord'))) + print(len(self.waymo_tfrecord_pathnames), 'tfrecords found.') + + def create_folder(self): + """Create folder for data conversion.""" + mmcv.mkdir_or_exist(self.waymo_results_save_dir) + + def parse_objects(self, kitti_result, T_k2w, context_name, + frame_timestamp_micros): + """Parse one prediction with several instances in kitti format and + convert them to `Object` proto. + + Args: + kitti_result (dict): Predictions in kitti format. + + - name (np.ndarray): Class labels of predictions. + - dimensions (np.ndarray): Height, width, length of boxes. + - location (np.ndarray): Bottom center of boxes (x, y, z). + - rotation_y (np.ndarray): Orientation of boxes. + - score (np.ndarray): Scores of predictions. + T_k2w (np.ndarray): Transformation matrix from kitti to waymo. + context_name (str): Context name of the frame. + frame_timestamp_micros (int): Frame timestamp. + + Returns: + :obj:`Object`: Predictions in waymo dataset Object proto. + """ + + def parse_one_object(instance_idx): + """Parse one instance in kitti format and convert them to + `Object` proto. + + Args: + instance_idx (int): Index of the instance to be converted. + + Returns: + :obj:`Object`: Predicted instance in waymo dataset \ + Object proto. 
+ """ + cls = kitti_result['name'][instance_idx] + length = round(kitti_result['dimensions'][instance_idx, 0], 4) + height = round(kitti_result['dimensions'][instance_idx, 1], 4) + width = round(kitti_result['dimensions'][instance_idx, 2], 4) + x = round(kitti_result['location'][instance_idx, 0], 4) + y = round(kitti_result['location'][instance_idx, 1], 4) + z = round(kitti_result['location'][instance_idx, 2], 4) + rotation_y = round(kitti_result['rotation_y'][instance_idx], 4) + score = round(kitti_result['score'][instance_idx], 4) + + # y: downwards; move box origin from bottom center (kitti) to + # true center (waymo) + y -= height / 2 + # frame transformation: kitti -> waymo + x, y, z = self.transform(T_k2w, x, y, z) + + # different conventions + heading = -(rotation_y + np.pi / 2) + while heading < -np.pi: + heading += 2 * np.pi + while heading > np.pi: + heading -= 2 * np.pi + + box = label_pb2.Label.Box() + box.center_x = x + box.center_y = y + box.center_z = z + box.length = length + box.width = width + box.height = height + box.heading = heading + + o = metrics_pb2.Object() + o.object.box.CopyFrom(box) + o.object.type = self.k2w_cls_map[cls] + o.score = score + + o.context_name = context_name + o.frame_timestamp_micros = frame_timestamp_micros + + return o + + objects = metrics_pb2.Objects() + + for instance_idx in range(len(kitti_result['name'])): + o = parse_one_object(instance_idx) + objects.objects.append(o) + + return objects + + def convert_one(self, file_idx): + """Convert action for single file. + + Args: + file_idx (int): Index of the file to be converted. + """ + file_pathname = self.waymo_tfrecord_pathnames[file_idx] + file_data = tf.data.TFRecordDataset(file_pathname, compression_type='') + + for frame_num, frame_data in enumerate(file_data): + frame = open_dataset.Frame() + frame.ParseFromString(bytearray(frame_data.numpy())) + + filename = f'{self.prefix}{file_idx:03d}{frame_num:03d}' + + for camera in frame.context.camera_calibrations: + # FRONT = 1, see dataset.proto for details + if camera.name == 1: + T_front_cam_to_vehicle = np.array( + camera.extrinsic.transform).reshape(4, 4) + + T_k2w = T_front_cam_to_vehicle @ self.T_ref_to_front_cam + + context_name = frame.context.name + frame_timestamp_micros = frame.timestamp_micros + + if filename in self.name2idx: + kitti_result = \ + self.kitti_result_files[self.name2idx[filename]] + objects = self.parse_objects(kitti_result, T_k2w, context_name, + frame_timestamp_micros) + else: + print(filename, 'not found.') + objects = metrics_pb2.Objects() + + with open( + join(self.waymo_results_save_dir, f'{filename}.bin'), + 'wb') as f: + f.write(objects.SerializeToString()) + + def convert(self): + """Convert action.""" + print('Start converting ...') + mmcv.track_parallel_progress(self.convert_one, range(len(self)), + self.workers) + print('\nFinished ...') + + # combine all files into one .bin + pathnames = sorted(glob(join(self.waymo_results_save_dir, '*.bin'))) + combined = self.combine(pathnames) + + with open(self.waymo_results_final_path, 'wb') as f: + f.write(combined.SerializeToString()) + + def __len__(self): + """Length of the filename list.""" + return len(self.waymo_tfrecord_pathnames) + + def transform(self, T, x, y, z): + """Transform the coordinates with matrix T. + + Args: + T (np.ndarray): Transformation matrix. + x(float): Coordinate in x axis. + y(float): Coordinate in y axis. + z(float): Coordinate in z axis. + + Returns: + list: Coordinates after transformation. 
+ """ + pt_bef = np.array([x, y, z, 1.0]).reshape(4, 1) + pt_aft = np.matmul(T, pt_bef) + return pt_aft[:3].flatten().tolist() + + def combine(self, pathnames): + """Combine predictions in waymo format for each sample together. + + Args: + pathnames (str): Paths to save predictions. + + Returns: + :obj:`Objects`: Combined predictions in Objects proto. + """ + combined = metrics_pb2.Objects() + + for pathname in pathnames: + objects = metrics_pb2.Objects() + with open(pathname, 'rb') as f: + objects.ParseFromString(f.read()) + for o in objects.objects: + combined.objects.append(o) + + return combined diff --git a/mmdet3d/datasets/__init__.py b/mmdet3d/datasets/__init__.py index e9649545da..f1976c9a9f 100644 --- a/mmdet3d/datasets/__init__.py +++ b/mmdet3d/datasets/__init__.py @@ -12,6 +12,7 @@ RandomFlip3D) from .scannet_dataset import ScanNetDataset from .sunrgbd_dataset import SUNRGBDDataset +from .waymo_dataset import WaymoDataset __all__ = [ 'KittiDataset', 'GroupSampler', 'DistributedGroupSampler', @@ -21,5 +22,5 @@ 'ObjectRangeFilter', 'PointsRangeFilter', 'Collect3D', 'LoadPointsFromFile', 'NormalizePointsColor', 'IndoorPointSample', 'LoadAnnotations3D', 'SUNRGBDDataset', 'ScanNetDataset', 'Custom3DDataset', - 'LoadPointsFromMultiSweeps', 'BackgroundPointsFilter' + 'LoadPointsFromMultiSweeps', 'WaymoDataset', 'BackgroundPointsFilter' ] diff --git a/mmdet3d/datasets/kitti_dataset.py b/mmdet3d/datasets/kitti_dataset.py index d5ac1d799c..d7c57c6f32 100644 --- a/mmdet3d/datasets/kitti_dataset.py +++ b/mmdet3d/datasets/kitti_dataset.py @@ -44,6 +44,8 @@ class KittiDataset(Custom3DDataset): Defaults to True. test_mode (bool, optional): Whether the dataset is in test mode. Defaults to False. + pcd_limit_range (list): The range of point cloud used to filter + invalid predicted boxes. Default: [0, -40, -3, 70.4, 40, 0.0]. """ CLASSES = ('car', 'pedestrian', 'cyclist') @@ -57,7 +59,8 @@ def __init__(self, modality=None, box_type_3d='LiDAR', filter_empty_gt=True, - test_mode=False): + test_mode=False, + pcd_limit_range=[0, -40, -3, 70.4, 40, 0.0]): super().__init__( data_root=data_root, ann_file=ann_file, @@ -68,9 +71,10 @@ def __init__(self, filter_empty_gt=filter_empty_gt, test_mode=test_mode) + self.split = split self.root_split = os.path.join(self.data_root, split) assert self.modality is not None - self.pcd_limit_range = [0, -40, -3, 70.4, 40, 0.0] + self.pcd_limit_range = pcd_limit_range self.pts_prefix = pts_prefix def _get_pts_filename(self, idx): @@ -157,7 +161,6 @@ def get_ann_info(self, index): dims = annos['dimensions'] rots = annos['rotation_y'] gt_names = annos['name'] - # print(gt_names, len(loc)) gt_bboxes_3d = np.concatenate([loc, dims, rots[..., np.newaxis]], axis=1).astype(np.float32) @@ -167,7 +170,6 @@ def get_ann_info(self, index): gt_bboxes = annos['bbox'] selected = self.drop_arrays_by_name(gt_names, ['DontCare']) - # gt_bboxes_3d = gt_bboxes_3d[selected].astype('float32') gt_bboxes = gt_bboxes[selected].astype('float32') gt_names = gt_names[selected] @@ -177,7 +179,7 @@ def get_ann_info(self, index): gt_labels.append(self.CLASSES.index(cat)) else: gt_labels.append(-1) - gt_labels = np.array(gt_labels) + gt_labels = np.array(gt_labels).astype(np.int64) gt_labels_3d = copy.deepcopy(gt_labels) anns_results = dict( @@ -372,7 +374,8 @@ def bbox2result_kitti(self, Returns: list[dict]: A list of dictionaries with the kitti format. 
""" - assert len(net_outputs) == len(self.data_infos) + assert len(net_outputs) == len(self.data_infos), \ + 'invalid list length of network outputs' if submission_prefix is not None: mmcv.mkdir_or_exist(submission_prefix) @@ -465,7 +468,7 @@ def bbox2result_kitti(self, if not pklfile_prefix.endswith(('.pkl', '.pickle')): out = f'{pklfile_prefix}.pkl' mmcv.dump(det_annos, out) - print('Result is saved to %s' % out) + print(f'Result is saved to {out}.') return det_annos @@ -487,8 +490,8 @@ def bbox2result_kitti2d(self, Returns: list[dict]: A list of dictionaries have the kitti format """ - assert len(net_outputs) == len(self.data_infos) - + assert len(net_outputs) == len(self.data_infos), \ + 'invalid list length of network outputs' det_annos = [] print('\nConverting prediction to KITTI format') for i, bboxes_per_sample in enumerate( diff --git a/mmdet3d/datasets/pipelines/dbsampler.py b/mmdet3d/datasets/pipelines/dbsampler.py index f4f86c5a91..24bf402f09 100644 --- a/mmdet3d/datasets/pipelines/dbsampler.py +++ b/mmdet3d/datasets/pipelines/dbsampler.py @@ -251,7 +251,6 @@ def sample_all(self, gt_bboxes, gt_labels, img=None): file_path = os.path.join( self.data_root, info['path']) if self.data_root else info['path'] - results = dict(pts_filename=file_path) s_points = self.points_loader(results)['points'] s_points[:, :3] += info['box3d_lidar'][:3] diff --git a/mmdet3d/datasets/waymo_dataset.py b/mmdet3d/datasets/waymo_dataset.py new file mode 100644 index 0000000000..9bc7ea5ffe --- /dev/null +++ b/mmdet3d/datasets/waymo_dataset.py @@ -0,0 +1,525 @@ +import mmcv +import numpy as np +import os +import tempfile +import torch +from mmcv.utils import print_log +from os import path as osp + +from mmdet.datasets import DATASETS +from ..core.bbox import Box3DMode, points_cam2img +from ..core.evaluation.waymo_utils.prediction_kitti_to_waymo import KITTI2Waymo +from .kitti_dataset import KittiDataset + + +@DATASETS.register_module() +class WaymoDataset(KittiDataset): + """Waymo Dataset. + + This class serves as the API for experiments on the Waymo Dataset. + + Please refer to ``_for data downloading. + It is recommended to symlink the dataset root to $MMDETECTION3D/data and + organize them as the doc shows. + + Args: + data_root (str): Path of dataset root. + ann_file (str): Path of annotation file. + split (str): Split of input data. + pts_prefix (str, optional): Prefix of points files. + Defaults to 'velodyne'. + pipeline (list[dict], optional): Pipeline used for data processing. + Defaults to None. + classes (tuple[str], optional): Classes used in the dataset. + Defaults to None. + modality (dict, optional): Modality to specify the sensor data used + as input. Defaults to None. + box_type_3d (str, optional): Type of 3D box of this dataset. + Based on the `box_type_3d`, the dataset will encapsulate the box + to its original format then converted them to `box_type_3d`. + Defaults to 'LiDAR' in this dataset. Available options includes + + - 'LiDAR': box in LiDAR coordinates + - 'Depth': box in depth coordinates, usually for indoor dataset + - 'Camera': box in camera coordinates + filter_empty_gt (bool, optional): Whether to filter empty GT. + Defaults to True. + test_mode (bool, optional): Whether the dataset is in test mode. + Defaults to False. + pcd_limit_range (list): The range of point cloud used to filter + invalid predicted boxes. Default: [-85, -85, -5, 85, 85, 5]. 
+ """ + + CLASSES = ('Car', 'Cyclist', 'Pedestrian') + + def __init__(self, + data_root, + ann_file, + split, + pts_prefix='velodyne', + pipeline=None, + classes=None, + modality=None, + box_type_3d='LiDAR', + filter_empty_gt=True, + test_mode=False, + load_interval=1, + pcd_limit_range=[-85, -85, -5, 85, 85, 5]): + super().__init__( + data_root=data_root, + ann_file=ann_file, + split=split, + pts_prefix=pts_prefix, + pipeline=pipeline, + classes=classes, + modality=modality, + box_type_3d=box_type_3d, + filter_empty_gt=filter_empty_gt, + test_mode=test_mode, + pcd_limit_range=pcd_limit_range) + + # to load a subset, just set the load_interval in the dataset config + self.data_infos = self.data_infos[::load_interval] + if hasattr(self, 'flag'): + self.flag = self.flag[::load_interval] + + def _get_pts_filename(self, idx): + pts_filename = osp.join(self.root_split, self.pts_prefix, + f'{idx:07d}.bin') + return pts_filename + + def get_data_info(self, index): + """Get data info according to the given index. + + Args: + index (int): Index of the sample data to get. + + Returns: + dict: Standard input_dict consists of the + data information. + + - sample_idx (str): sample index + - pts_filename (str): filename of point clouds + - img_prefix (str | None): prefix of image files + - img_info (dict): image info + - lidar2img (list[np.ndarray], optional): transformations from + lidar to different cameras + - ann_info (dict): annotation info + """ + info = self.data_infos[index] + sample_idx = info['image']['image_idx'] + img_filename = os.path.join(self.data_root, + info['image']['image_path']) + + # TODO: consider use torch.Tensor only + rect = info['calib']['R0_rect'].astype(np.float32) + Trv2c = info['calib']['Tr_velo_to_cam'].astype(np.float32) + P0 = info['calib']['P0'].astype(np.float32) + lidar2img = P0 @ rect @ Trv2c + + pts_filename = self._get_pts_filename(sample_idx) + input_dict = dict( + sample_idx=sample_idx, + pts_filename=pts_filename, + img_prefix=None, + img_info=dict(filename=img_filename), + lidar2img=lidar2img) + + if not self.test_mode: + annos = self.get_ann_info(index) + input_dict['ann_info'] = annos + + return input_dict + + def format_results(self, + outputs, + pklfile_prefix=None, + submission_prefix=None, + data_format='waymo'): + """Format the results to pkl file. + + Args: + outputs (list[dict]): Testing results of the dataset. + pklfile_prefix (str | None): The prefix of pkl files. It includes + the file path and the prefix of filename, e.g., "a/b/prefix". + If not specified, a temp file will be created. Default: None. + submission_prefix (str | None): The prefix of submitted files. It + includes the file path and the prefix of filename, e.g., + "a/b/prefix". If not specified, a temp file will be created. + Default: None. + data_format (str | None): Output data format. Default: 'waymo'. + Another supported choice is 'kitti'. + + Returns: + tuple: (result_files, tmp_dir), result_files is a dict containing + the json filepaths, tmp_dir is the temporal directory created + for saving json files when jsonfile_prefix is not specified. 
+ """ + if pklfile_prefix is None: + tmp_dir = tempfile.TemporaryDirectory() + pklfile_prefix = osp.join(tmp_dir.name, 'results') + else: + tmp_dir = None + + assert ('waymo' in data_format or 'kitti' in data_format), \ + f'invalid data_format {data_format}' + + if (not isinstance(outputs[0], dict)) or 'img_bbox' in outputs[0]: + raise TypeError('Not supported type for reformat results.') + elif 'pts_bbox' in outputs[0]: + result_files = dict() + for name in outputs[0]: + results_ = [out[name] for out in outputs] + pklfile_prefix_ = pklfile_prefix + name + if submission_prefix is not None: + submission_prefix_ = f'{submission_prefix}_{name}' + else: + submission_prefix_ = None + result_files_ = self.bbox2result_kitti(results_, self.CLASSES, + pklfile_prefix_, + submission_prefix_) + result_files[name] = result_files_ + else: + result_files = self.bbox2result_kitti(outputs, self.CLASSES, + pklfile_prefix, + submission_prefix) + if 'waymo' in data_format: + waymo_root = osp.join( + self.data_root.split('kitti_format')[0], 'waymo_format') + if self.split == 'training': + waymo_tfrecords_dir = osp.join(waymo_root, 'validation') + prefix = '1' + elif self.split == 'testing': + waymo_tfrecords_dir = osp.join(waymo_root, 'testing') + prefix = '2' + else: + raise ValueError('Not supported split value.') + save_tmp_dir = tempfile.TemporaryDirectory() + waymo_results_save_dir = save_tmp_dir.name + waymo_results_final_path = f'{pklfile_prefix}.bin' + if 'pts_bbox' in result_files: + converter = KITTI2Waymo(result_files['pts_bbox'], + waymo_tfrecords_dir, + waymo_results_save_dir, + waymo_results_final_path, prefix) + else: + converter = KITTI2Waymo(result_files, waymo_tfrecords_dir, + waymo_results_save_dir, + waymo_results_final_path, prefix) + converter.convert() + save_tmp_dir.cleanup() + + return result_files, tmp_dir + + def evaluate(self, + results, + metric='waymo', + logger=None, + pklfile_prefix=None, + submission_prefix=None, + show=False, + out_dir=None): + """Evaluation in KITTI protocol. + + Args: + results (list[dict]): Testing results of the dataset. + metric (str | list[str]): Metrics to be evaluated. + Default: 'waymo'. Another supported metric is 'kitti'. + logger (logging.Logger | str | None): Logger used for printing + related information during evaluation. Default: None. + pklfile_prefix (str | None): The prefix of pkl files. It includes + the file path and the prefix of filename, e.g., "a/b/prefix". + If not specified, a temp file will be created. Default: None. + submission_prefix (str | None): The prefix of submission datas. + If not specified, the submission data will not be generated. + show (bool): Whether to visualize. + Default: False. + out_dir (str): Path to save the visualization results. + Default: None. 
+ + Returns: + dict[str: float]: results of each evaluation metric + """ + assert ('waymo' in metric or 'kitti' in metric), \ + f'invalid metric {metric}' + if 'kitti' in metric: + result_files, tmp_dir = self.format_results( + results, + pklfile_prefix, + submission_prefix, + data_format='kitti') + from mmdet3d.core.evaluation import kitti_eval + gt_annos = [info['annos'] for info in self.data_infos] + + if isinstance(result_files, dict): + ap_dict = dict() + for name, result_files_ in result_files.items(): + eval_types = ['bev', '3d'] + ap_result_str, ap_dict_ = kitti_eval( + gt_annos, + result_files_, + self.CLASSES, + eval_types=eval_types) + for ap_type, ap in ap_dict_.items(): + ap_dict[f'{name}/{ap_type}'] = float( + '{:.4f}'.format(ap)) + + print_log( + f'Results of {name}:\n' + ap_result_str, logger=logger) + + else: + ap_result_str, ap_dict = kitti_eval( + gt_annos, + result_files, + self.CLASSES, + eval_types=['bev', '3d']) + print_log('\n' + ap_result_str, logger=logger) + if 'waymo' in metric: + waymo_root = osp.join( + self.data_root.split('kitti_format')[0], 'waymo_format') + if pklfile_prefix is None: + eval_tmp_dir = tempfile.TemporaryDirectory() + pklfile_prefix = osp.join(eval_tmp_dir.name, 'results') + else: + eval_tmp_dir = None + result_files, tmp_dir = self.format_results( + results, + pklfile_prefix, + submission_prefix, + data_format='waymo') + import subprocess + ret_bytes = subprocess.check_output( + 'mmdet3d/core/evaluation/waymo_utils/' + + f'compute_detection_metrics_main {pklfile_prefix}.bin ' + + f'{waymo_root}/gt.bin', + shell=True) + ret_texts = ret_bytes.decode('utf-8') + print_log(ret_texts) + # parse the text to get ap_dict + ap_dict = { + 'Vehicle/L1 mAP': 0, + 'Vehicle/L1 mAPH': 0, + 'Vehicle/L2 mAP': 0, + 'Vehicle/L2 mAPH': 0, + 'Pedestrian/L1 mAP': 0, + 'Pedestrian/L1 mAPH': 0, + 'Pedestrian/L2 mAP': 0, + 'Pedestrian/L2 mAPH': 0, + 'Sign/L1 mAP': 0, + 'Sign/L1 mAPH': 0, + 'Sign/L2 mAP': 0, + 'Sign/L2 mAPH': 0, + 'Cyclist/L1 mAP': 0, + 'Cyclist/L1 mAPH': 0, + 'Cyclist/L2 mAP': 0, + 'Cyclist/L2 mAPH': 0 + } + mAP_splits = ret_texts.split('mAP ') + mAPH_splits = ret_texts.split('mAPH ') + for idx, key in enumerate(ap_dict.keys()): + split_idx = int(idx / 2) + 1 + if idx % 2 == 0: # mAP + ap_dict[key] = float(mAP_splits[split_idx].split(']')[0]) + else: # mAPH + ap_dict[key] = float(mAPH_splits[split_idx].split(']')[0]) + if eval_tmp_dir is not None: + eval_tmp_dir.cleanup() + + if tmp_dir is not None: + tmp_dir.cleanup() + + if show: + self.show(results, out_dir) + return ap_dict + + def bbox2result_kitti(self, + net_outputs, + class_names, + pklfile_prefix=None, + submission_prefix=None): + """Convert results to kitti format for evaluation and test submission. + + Args: + net_outputs (List[np.ndarray]): list of array storing the + bbox and score + class_nanes (List[String]): A list of class names + pklfile_prefix (str | None): The prefix of pkl file. + submission_prefix (str | None): The prefix of submission file. 
+ + Returns: + List[dict]: A list of dict have the kitti 3d format + """ + assert len(net_outputs) == len(self.data_infos), \ + 'invalid list length of network outputs' + if submission_prefix is not None: + mmcv.mkdir_or_exist(submission_prefix) + + det_annos = [] + print('\nConverting prediction to KITTI format') + for idx, pred_dicts in enumerate( + mmcv.track_iter_progress(net_outputs)): + annos = [] + info = self.data_infos[idx] + sample_idx = info['image']['image_idx'] + image_shape = info['image']['image_shape'][:2] + + box_dict = self.convert_valid_bboxes(pred_dicts, info) + if len(box_dict['bbox']) > 0: + box_2d_preds = box_dict['bbox'] + box_preds = box_dict['box3d_camera'] + scores = box_dict['scores'] + box_preds_lidar = box_dict['box3d_lidar'] + label_preds = box_dict['label_preds'] + + anno = { + 'name': [], + 'truncated': [], + 'occluded': [], + 'alpha': [], + 'bbox': [], + 'dimensions': [], + 'location': [], + 'rotation_y': [], + 'score': [] + } + + for box, box_lidar, bbox, score, label in zip( + box_preds, box_preds_lidar, box_2d_preds, scores, + label_preds): + bbox[2:] = np.minimum(bbox[2:], image_shape[::-1]) + bbox[:2] = np.maximum(bbox[:2], [0, 0]) + anno['name'].append(class_names[int(label)]) + anno['truncated'].append(0.0) + anno['occluded'].append(0) + anno['alpha'].append( + -np.arctan2(-box_lidar[1], box_lidar[0]) + box[6]) + anno['bbox'].append(bbox) + anno['dimensions'].append(box[3:6]) + anno['location'].append(box[:3]) + anno['rotation_y'].append(box[6]) + anno['score'].append(score) + + anno = {k: np.stack(v) for k, v in anno.items()} + annos.append(anno) + + if submission_prefix is not None: + curr_file = f'{submission_prefix}/{sample_idx:07d}.txt' + with open(curr_file, 'w') as f: + bbox = anno['bbox'] + loc = anno['location'] + dims = anno['dimensions'] # lhw -> hwl + + for idx in range(len(bbox)): + print( + '{} -1 -1 {:.4f} {:.4f} {:.4f} {:.4f} ' + '{:.4f} {:.4f} {:.4f} ' + '{:.4f} {:.4f} {:.4f} {:.4f} {:.4f} {:.4f}'. + format(anno['name'][idx], anno['alpha'][idx], + bbox[idx][0], bbox[idx][1], + bbox[idx][2], bbox[idx][3], + dims[idx][1], dims[idx][2], + dims[idx][0], loc[idx][0], loc[idx][1], + loc[idx][2], anno['rotation_y'][idx], + anno['score'][idx]), + file=f) + else: + annos.append({ + 'name': np.array([]), + 'truncated': np.array([]), + 'occluded': np.array([]), + 'alpha': np.array([]), + 'bbox': np.zeros([0, 4]), + 'dimensions': np.zeros([0, 3]), + 'location': np.zeros([0, 3]), + 'rotation_y': np.array([]), + 'score': np.array([]), + }) + annos[-1]['sample_idx'] = np.array( + [sample_idx] * len(annos[-1]['score']), dtype=np.int64) + + det_annos += annos + + if pklfile_prefix is not None: + if not pklfile_prefix.endswith(('.pkl', '.pickle')): + out = f'{pklfile_prefix}.pkl' + mmcv.dump(det_annos, out) + print(f'Result is saved to {out}.') + + return det_annos + + def convert_valid_bboxes(self, box_dict, info): + """Convert the boxes into valid format. + + Args: + box_dict (dict): Bounding boxes to be converted. + + - boxes_3d (:obj:``LiDARInstance3DBoxes``): 3D bounding boxes. + - scores_3d (np.ndarray): Scores of predicted boxes. + - labels_3d (np.ndarray): Class labels of predicted boxes. + info (dict): Dataset information dictionary. + + Returns: + dict: Valid boxes after conversion. + + - bbox (np.ndarray): 2D bounding boxes (in camera 0). + - box3d_camera (np.ndarray): 3D boxes in camera coordinates. + - box3d_lidar (np.ndarray): 3D boxes in lidar coordinates. + - scores (np.ndarray): Scores of predicted boxes. 
+ - label_preds (np.ndarray): Class labels of predicted boxes. + - sample_idx (np.ndarray): Sample index. + """ + # TODO: refactor this function + box_preds = box_dict['boxes_3d'] + scores = box_dict['scores_3d'] + labels = box_dict['labels_3d'] + sample_idx = info['image']['image_idx'] + # TODO: remove the hack of yaw + box_preds.limit_yaw(offset=0.5, period=np.pi * 2) + + if len(box_preds) == 0: + return dict( + bbox=np.zeros([0, 4]), + box3d_camera=np.zeros([0, 7]), + box3d_lidar=np.zeros([0, 7]), + scores=np.zeros([0]), + label_preds=np.zeros([0, 4]), + sample_idx=sample_idx) + + rect = info['calib']['R0_rect'].astype(np.float32) + Trv2c = info['calib']['Tr_velo_to_cam'].astype(np.float32) + P0 = info['calib']['P0'].astype(np.float32) + P0 = box_preds.tensor.new_tensor(P0) + + box_preds_camera = box_preds.convert_to(Box3DMode.CAM, rect @ Trv2c) + + box_corners = box_preds_camera.corners + box_corners_in_image = points_cam2img(box_corners, P0) + # box_corners_in_image: [N, 8, 2] + minxy = torch.min(box_corners_in_image, dim=1)[0] + maxxy = torch.max(box_corners_in_image, dim=1)[0] + box_2d_preds = torch.cat([minxy, maxxy], dim=1) + # Post-processing + # check box_preds + limit_range = box_preds.tensor.new_tensor(self.pcd_limit_range) + valid_pcd_inds = ((box_preds.center > limit_range[:3]) & + (box_preds.center < limit_range[3:])) + valid_inds = valid_pcd_inds.all(-1) + + if valid_inds.sum() > 0: + return dict( + bbox=box_2d_preds[valid_inds, :].numpy(), + box3d_camera=box_preds_camera[valid_inds].tensor.numpy(), + box3d_lidar=box_preds[valid_inds].tensor.numpy(), + scores=scores[valid_inds].numpy(), + label_preds=labels[valid_inds].numpy(), + sample_idx=sample_idx, + ) + else: + return dict( + bbox=np.zeros([0, 4]), + box3d_camera=np.zeros([0, 7]), + box3d_lidar=np.zeros([0, 7]), + scores=np.zeros([0]), + label_preds=np.zeros([0, 4]), + sample_idx=sample_idx, + ) diff --git a/mmdet3d/models/detectors/base.py b/mmdet3d/models/detectors/base.py index 7e2655ba58..2c033c10ee 100644 --- a/mmdet3d/models/detectors/base.py +++ b/mmdet3d/models/detectors/base.py @@ -48,8 +48,8 @@ def forward(self, return_loss=True, **kwargs): Note this setting will change the expected inputs. When `return_loss=True`, img and img_metas are single-nested (i.e. - torch.Tensor and list[dict]), and when `resturn_loss=False`, img and - img_metas should be double nested (i.e. list[torch.Tensor], + torch.Tensor and list[dict]), and when `resturn_loss=False`, img + and img_metas should be double nested (i.e. list[torch.Tensor], list[list[dict]]), with the outer list indicating test time augmentations. """ diff --git a/mmdet3d/models/detectors/two_stage.py b/mmdet3d/models/detectors/two_stage.py index fe8760d0e9..9284b67727 100644 --- a/mmdet3d/models/detectors/two_stage.py +++ b/mmdet3d/models/detectors/two_stage.py @@ -7,8 +7,8 @@ class TwoStage3DDetector(Base3DDetector, TwoStageDetector): """Base class of two-stage 3D detector. It inherits original ``:class:TwoStageDetector`` and - ``:class:Base3DDetector``. This class could serve as a base class for all - two-stage 3D detectors. + ``:class:Base3DDetector``. This class could serve as a base class + for all two-stage 3D detectors. 
""" def __init__(self, **kwargs): diff --git a/mmdet3d/models/roi_heads/base_3droi_head.py b/mmdet3d/models/roi_heads/base_3droi_head.py index 21809a7f32..be088c7c0b 100644 --- a/mmdet3d/models/roi_heads/base_3droi_head.py +++ b/mmdet3d/models/roi_heads/base_3droi_head.py @@ -93,7 +93,7 @@ def simple_test(self, def aug_test(self, x, proposal_list, img_metas, rescale=False, **kwargs): """Test with augmentations. - If rescale is False, then returned bboxes and masks will fit the scale - of imgs[0]. + If rescale is False, then returned bboxes and masks will fit the + scale of imgs[0]. """ pass diff --git a/mmdet3d/ops/furthest_point_sample/furthest_point_sample.py b/mmdet3d/ops/furthest_point_sample/furthest_point_sample.py index d03b57c7ac..6d7b2b5974 100644 --- a/mmdet3d/ops/furthest_point_sample/furthest_point_sample.py +++ b/mmdet3d/ops/furthest_point_sample/furthest_point_sample.py @@ -7,8 +7,8 @@ class FurthestPointSampling(Function): """Furthest Point Sampling. - Uses iterative furthest point sampling to select a set of features whose - corresponding points have the furthest distance. + Uses iterative furthest point sampling to select a set of features + whose corresponding points have the furthest distance. """ @staticmethod diff --git a/mmdet3d/ops/spconv/structure.py b/mmdet3d/ops/spconv/structure.py index 6e59af28ee..31e224b4e0 100644 --- a/mmdet3d/ops/spconv/structure.py +++ b/mmdet3d/ops/spconv/structure.py @@ -5,8 +5,9 @@ def scatter_nd(indices, updates, shape): """pytorch edition of tensorflow scatter_nd. - this function don't contain except handle code. so use this carefully when - indice repeats, don't support repeat add which is supported in tensorflow. + this function don't contain except handle code. so use this + carefully when indice repeats, don't support repeat add which is + supported in tensorflow. """ ret = torch.zeros(*shape, dtype=updates.dtype, device=updates.device) ndim = indices.shape[-1] diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 02d680cbab..59cc56ae75 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -7,3 +7,5 @@ plyfile # by default we also use tensorboard to log results tensorboard trimesh>=2.35.39,<2.35.40 +scikit-image +waymo-open-dataset-tf-2-1-0==1.2.0 diff --git a/setup.cfg b/setup.cfg index 3f017ece7d..a0521648c2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,6 +8,6 @@ line_length = 79 multi_line_output = 0 known_standard_library = setuptools known_first_party = mmdet,mmdet3d -known_third_party = load_scannet_data,lyft_dataset_sdk,m2r,matplotlib,mmcv,nuimages,numba,numpy,nuscenes,pandas,plyfile,pycocotools,pyquaternion,pytest,recommonmark,scannet_utils,scipy,seaborn,shapely,skimage,terminaltables,torch,trimesh +known_third_party = load_scannet_data,lyft_dataset_sdk,m2r,matplotlib,mmcv,nuimages,numba,numpy,nuscenes,pandas,plyfile,pycocotools,pyquaternion,pytest,recommonmark,scannet_utils,scipy,seaborn,shapely,skimage,tensorflow,terminaltables,torch,trimesh,waymo_open_dataset no_lines_before = STDLIB,LOCALFOLDER default_section = THIRDPARTY diff --git a/tests/test_detectors.py b/tests/test_detectors.py index 5382f20500..73f2379d0e 100644 --- a/tests/test_detectors.py +++ b/tests/test_detectors.py @@ -44,8 +44,8 @@ def _get_config_module(fname): def _get_model_cfg(fname): """Grab configs necessary to create a model. - These are deep copied to allow for safe modification of parameters without - influencing other tests. 
+ These are deep copied to allow for safe modification of parameters + without influencing other tests. """ config = _get_config_module(fname) model = copy.deepcopy(config.model) @@ -56,8 +56,8 @@ def _get_model_cfg(fname): def _get_detector_cfg(fname): """Grab configs necessary to create a detector. - These are deep copied to allow for safe modification of parameters without - influencing other tests. + These are deep copied to allow for safe modification of parameters + without influencing other tests. """ import mmcv config = _get_config_module(fname) diff --git a/tests/test_forward.py b/tests/test_forward.py index 9e416cae54..9766e5fb4a 100644 --- a/tests/test_forward.py +++ b/tests/test_forward.py @@ -37,8 +37,8 @@ def _get_config_module(fname): def _get_detector_cfg(fname): """Grab configs necessary to create a detector. - These are deep copied to allow for safe modification of parameters without - influencing other tests. + These are deep copied to allow for safe modification of parameters + without influencing other tests. """ import mmcv config = _get_config_module(fname) diff --git a/tests/test_heads.py b/tests/test_heads.py index 561e28e00e..1bd8a50d7c 100644 --- a/tests/test_heads.py +++ b/tests/test_heads.py @@ -46,8 +46,8 @@ def _get_config_module(fname): def _get_head_cfg(fname): """Grab configs necessary to create a bbox_head. - These are deep copied to allow for safe modification of parameters without - influencing other tests. + These are deep copied to allow for safe modification of parameters + without influencing other tests. """ import mmcv config = _get_config_module(fname) @@ -64,8 +64,8 @@ def _get_head_cfg(fname): def _get_rpn_head_cfg(fname): """Grab configs necessary to create a rpn_head. - These are deep copied to allow for safe modification of parameters without - influencing other tests. + These are deep copied to allow for safe modification of parameters + without influencing other tests. """ import mmcv config = _get_config_module(fname) @@ -82,8 +82,8 @@ def _get_rpn_head_cfg(fname): def _get_roi_head_cfg(fname): """Grab configs necessary to create a roi_head. - These are deep copied to allow for safe modification of parameters without - influencing other tests. + These are deep copied to allow for safe modification of parameters + without influencing other tests. """ import mmcv config = _get_config_module(fname) @@ -100,8 +100,8 @@ def _get_roi_head_cfg(fname): def _get_pts_bbox_head_cfg(fname): """Grab configs necessary to create a pts_bbox_head. - These are deep copied to allow for safe modification of parameters without - influencing other tests. + These are deep copied to allow for safe modification of parameters + without influencing other tests. """ import mmcv config = _get_config_module(fname) @@ -118,8 +118,8 @@ def _get_pts_bbox_head_cfg(fname): def _get_vote_head_cfg(fname): """Grab configs necessary to create a vote_head. - These are deep copied to allow for safe modification of parameters without - influencing other tests. + These are deep copied to allow for safe modification of parameters + without influencing other tests. """ import mmcv config = _get_config_module(fname) @@ -136,8 +136,8 @@ def _get_vote_head_cfg(fname): def _get_parta2_bbox_head_cfg(fname): """Grab configs necessary to create a parta2_bbox_head. - These are deep copied to allow for safe modification of parameters without - influencing other tests. + These are deep copied to allow for safe modification of parameters + without influencing other tests. 
""" config = _get_config_module(fname) model = copy.deepcopy(config.model) diff --git a/tools/create_data.py b/tools/create_data.py index 7e78b328d1..9a71b16d8d 100644 --- a/tools/create_data.py +++ b/tools/create_data.py @@ -5,6 +5,7 @@ from tools.data_converter import kitti_converter as kitti from tools.data_converter import lyft_converter as lyft_converter from tools.data_converter import nuscenes_converter as nuscenes_converter +from tools.data_converter import waymo_converter as waymo from tools.data_converter.create_gt_database import create_groundtruth_database @@ -133,6 +134,48 @@ def sunrgbd_data_prep(root_path, info_prefix, out_dir, workers): root_path, info_prefix, out_dir, workers=workers) +def waymo_data_prep(root_path, + info_prefix, + version, + out_dir, + workers, + max_sweeps=5): + """Prepare the info file for waymo dataset. + + Args: + root_path (str): Path of dataset root. + info_prefix (str): The prefix of info filenames. + out_dir (str): Output directory of the generated info file. + workers (int): Number of threads to be used. + max_sweeps (int): Number of input consecutive frames. Default: 5 \ + Here we store pose information of these frames for later use. + """ + splits = ['training', 'validation', 'testing'] + for i, split in enumerate(splits): + load_dir = osp.join(root_path, 'waymo_format', split) + if split == 'validation': + save_dir = osp.join(out_dir, 'kitti_format', 'training') + else: + save_dir = osp.join(out_dir, 'kitti_format', split) + converter = waymo.Waymo2KITTI( + load_dir, + save_dir, + prefix=str(i), + workers=workers, + test_mode=(split == 'test')) + converter.convert() + # Generate waymo infos + out_dir = osp.join(out_dir, 'kitti_format') + kitti.create_waymo_info_file(out_dir, info_prefix, max_sweeps=max_sweeps) + create_groundtruth_database( + 'WaymoDataset', + out_dir, + info_prefix, + f'{out_dir}/{info_prefix}_infos_train.pkl', + relative_path=False, + with_mask=False) + + parser = argparse.ArgumentParser(description='Data converter arg parser') parser.add_argument('dataset', metavar='kitti', help='name of the dataset') parser.add_argument( @@ -213,6 +256,14 @@ def sunrgbd_data_prep(root_path, info_prefix, out_dir, workers): dataset_name='LyftDataset', out_dir=args.out_dir, max_sweeps=args.max_sweeps) + elif args.dataset == 'waymo': + waymo_data_prep( + root_path=args.root_path, + info_prefix=args.extra_tag, + version=args.version, + out_dir=args.out_dir, + workers=args.workers, + max_sweeps=args.max_sweeps) elif args.dataset == 'scannet': scannet_data_prep( root_path=args.root_path, diff --git a/tools/data_converter/create_gt_database.py b/tools/data_converter/create_gt_database.py index 2150d20c02..b540ea4ca9 100644 --- a/tools/data_converter/create_gt_database.py +++ b/tools/data_converter/create_gt_database.py @@ -183,6 +183,31 @@ def create_groundtruth_database(dataset_class_name, with_bbox_3d=True, with_label_3d=True) ]) + + elif dataset_class_name == 'WaymoDataset': + file_client_args = dict(backend='disk') + dataset_cfg.update( + test_mode=False, + split='training', + modality=dict( + use_lidar=True, + use_depth=False, + use_lidar_intensity=True, + use_camera=False, + ), + pipeline=[ + dict( + type='LoadPointsFromFile', + load_dim=6, + use_dim=5, + file_client_args=file_client_args), + dict( + type='LoadAnnotations3D', + with_bbox_3d=True, + with_label_3d=True, + file_client_args=file_client_args) + ]) + dataset = build_dataset(dataset_cfg) if database_save_path is None: diff --git a/tools/data_converter/kitti_converter.py 
b/tools/data_converter/kitti_converter.py index 3b9163e08d..94a6770e5f 100644 --- a/tools/data_converter/kitti_converter.py +++ b/tools/data_converter/kitti_converter.py @@ -1,10 +1,9 @@ +import mmcv import numpy as np -import pickle -from mmcv import track_iter_progress from pathlib import Path from mmdet3d.core.bbox import box_np_ops -from .kitti_data_utils import get_kitti_image_info +from .kitti_data_utils import get_kitti_image_info, get_waymo_image_info def convert_to_kitti_info_version2(info): @@ -43,7 +42,7 @@ def _calculate_num_points_in_gt(data_path, relative_path, remove_outside=True, num_features=4): - for info in track_iter_progress(infos): + for info in mmcv.track_iter_progress(infos): pc_info = info['point_cloud'] image_info = info['image'] calib = info['calib'] @@ -80,7 +79,7 @@ def _calculate_num_points_in_gt(data_path, def create_kitti_info_file(data_path, - pkl_prefix='kitti_', + pkl_prefix='kitti', save_path=None, relative_path=True): """Create info file of KITTI dataset. @@ -113,8 +112,7 @@ def create_kitti_info_file(data_path, _calculate_num_points_in_gt(data_path, kitti_infos_train, relative_path) filename = save_path / f'{pkl_prefix}_infos_train.pkl' print(f'Kitti info train file is saved to {filename}') - with open(filename, 'wb') as f: - pickle.dump(kitti_infos_train, f) + mmcv.dump(kitti_infos_train, filename) kitti_infos_val = get_kitti_image_info( data_path, training=True, @@ -125,12 +123,10 @@ def create_kitti_info_file(data_path, _calculate_num_points_in_gt(data_path, kitti_infos_val, relative_path) filename = save_path / f'{pkl_prefix}_infos_val.pkl' print(f'Kitti info val file is saved to {filename}') - with open(filename, 'wb') as f: - pickle.dump(kitti_infos_val, f) + mmcv.dump(kitti_infos_val, filename) filename = save_path / f'{pkl_prefix}_infos_trainval.pkl' print(f'Kitti info trainval file is saved to {filename}') - with open(filename, 'wb') as f: - pickle.dump(kitti_infos_train + kitti_infos_val, f) + mmcv.dump(kitti_infos_train + kitti_infos_val, filename) kitti_infos_test = get_kitti_image_info( data_path, @@ -142,18 +138,109 @@ def create_kitti_info_file(data_path, relative_path=relative_path) filename = save_path / f'{pkl_prefix}_infos_test.pkl' print(f'Kitti info test file is saved to {filename}') - with open(filename, 'wb') as f: - pickle.dump(kitti_infos_test, f) + mmcv.dump(kitti_infos_test, filename) + + +def create_waymo_info_file(data_path, + pkl_prefix='waymo', + save_path=None, + relative_path=True, + max_sweeps=5): + """Create info file of waymo dataset. + + Given the raw data, generate its related info file in pkl format. + + Args: + data_path (str): Path of the data root. + pkl_prefix (str): Prefix of the info file to be generated. + save_path (str | None): Path to save the info file. + relative_path (bool): Whether to use relative path. + max_sweeps (int): Max sweeps before the detection frame to be used. + """ + imageset_folder = Path(data_path) / 'ImageSets' + train_img_ids = _read_imageset_file(str(imageset_folder / 'train.txt')) + val_img_ids = _read_imageset_file(str(imageset_folder / 'val.txt')) + test_img_ids = _read_imageset_file(str(imageset_folder / 'test.txt')) + + print('Generate info. 
this may take several minutes.') + if save_path is None: + save_path = Path(data_path) + else: + save_path = Path(save_path) + waymo_infos_train = get_waymo_image_info( + data_path, + training=True, + velodyne=True, + calib=True, + pose=True, + image_ids=train_img_ids, + relative_path=relative_path, + max_sweeps=max_sweeps) + _calculate_num_points_in_gt( + data_path, + waymo_infos_train, + relative_path, + num_features=6, + remove_outside=False) + filename = save_path / f'{pkl_prefix}_infos_train.pkl' + print(f'Waymo info train file is saved to {filename}') + mmcv.dump(waymo_infos_train, filename) + waymo_infos_val = get_waymo_image_info( + data_path, + training=True, + velodyne=True, + calib=True, + pose=True, + image_ids=val_img_ids, + relative_path=relative_path, + max_sweeps=max_sweeps) + _calculate_num_points_in_gt( + data_path, + waymo_infos_val, + relative_path, + num_features=6, + remove_outside=False) + filename = save_path / f'{pkl_prefix}_infos_val.pkl' + print(f'Waymo info val file is saved to {filename}') + mmcv.dump(waymo_infos_val, filename) + filename = save_path / f'{pkl_prefix}_infos_trainval.pkl' + print(f'Waymo info trainval file is saved to {filename}') + mmcv.dump(waymo_infos_train + waymo_infos_val, filename) + waymo_infos_test = get_waymo_image_info( + data_path, + training=False, + label_info=False, + velodyne=True, + calib=True, + pose=True, + image_ids=test_img_ids, + relative_path=relative_path, + max_sweeps=max_sweeps) + filename = save_path / f'{pkl_prefix}_infos_test.pkl' + print(f'Waymo info test file is saved to {filename}') + mmcv.dump(waymo_infos_test, filename) def _create_reduced_point_cloud(data_path, info_path, save_path=None, - back=False): - with open(info_path, 'rb') as f: - kitti_infos = pickle.load(f) + back=False, + num_features=4, + front_camera_id=2): + """Create reduced point clouds for given info. - for info in track_iter_progress(kitti_infos): + Args: + data_path (str): Path of original data. + info_path (str): Path of data info. + save_path (str | None): Path to save reduced point cloud data. + Default: None. + back (bool): Whether to flip the points to back. + num_features (int): Number of point features. Default: 4. + front_camera_id (int): The referenced/front camera ID. Default: 2. + """ + kitti_infos = mmcv.load(info_path) + + for info in mmcv.track_iter_progress(kitti_infos): pc_info = info['point_cloud'] image_info = info['image'] calib = info['calib'] @@ -161,9 +248,13 @@ def _create_reduced_point_cloud(data_path, v_path = pc_info['velodyne_path'] v_path = Path(data_path) / v_path points_v = np.fromfile( - str(v_path), dtype=np.float32, count=-1).reshape([-1, 4]) + str(v_path), dtype=np.float32, + count=-1).reshape([-1, num_features]) rect = calib['R0_rect'] - P2 = calib['P2'] + if front_camera_id == 2: + P2 = calib['P2'] + else: + P2 = calib[f'P{str(front_camera_id)}'] Trv2c = calib['Tr_velo_to_cam'] # first remove z < 0 points # keep = points_v[:, -1] > 0 @@ -196,10 +287,10 @@ def create_reduced_point_cloud(data_path, test_info_path=None, save_path=None, with_back=False): - """Create reduced point cloud info file. + """Create reduced point clouds for training/validation/testing. Args: - data_path (str): Path of original infos. + data_path (str): Path of original data. pkl_prefix (str): Prefix of info files. train_info_path (str | None): Path of training set info. Default: None. @@ -207,8 +298,8 @@ def create_reduced_point_cloud(data_path, Default: None. test_info_path (str | None): Path of test set info. Default: None. 
- save_path (str | None): Path to save reduced info. - with_back (bool | None): Whether to create backup info. + save_path (str | None): Path to save reduced point cloud data. + with_back (bool): Whether to flip the points to back. """ if train_info_path is None: train_info_path = Path(data_path) / f'{pkl_prefix}_infos_train.pkl' @@ -219,7 +310,7 @@ def create_reduced_point_cloud(data_path, print('create reduced point cloud for training set') _create_reduced_point_cloud(data_path, train_info_path, save_path) - print('create reduced point cloud for validatin set') + print('create reduced point cloud for validation set') _create_reduced_point_cloud(data_path, val_info_path, save_path) print('create reduced point cloud for testing set') _create_reduced_point_cloud(data_path, test_info_path, save_path) diff --git a/tools/data_converter/kitti_data_utils.py b/tools/data_converter/kitti_data_utils.py index ba95c2eec8..e456260f2d 100644 --- a/tools/data_converter/kitti_data_utils.py +++ b/tools/data_converter/kitti_data_utils.py @@ -1,12 +1,16 @@ import numpy as np from collections import OrderedDict from concurrent import futures as futures +from os import path as osp from pathlib import Path from skimage import io -def get_image_index_str(img_idx): - return '{:06d}'.format(img_idx) +def get_image_index_str(img_idx, use_prefix_id=False): + if use_prefix_id: + return '{:07d}'.format(img_idx) + else: + return '{:06d}'.format(img_idx) def get_kitti_info_path(idx, @@ -15,8 +19,9 @@ def get_kitti_info_path(idx, file_tail='.png', training=True, relative_path=True, - exist_check=True): - img_idx_str = get_image_index_str(idx) + exist_check=True, + use_prefix_id=False): + img_idx_str = get_image_index_str(idx, use_prefix_id) img_idx_str += file_tail prefix = Path(prefix) if training: @@ -35,36 +40,52 @@ def get_image_path(idx, prefix, training=True, relative_path=True, - exist_check=True): - return get_kitti_info_path(idx, prefix, 'image_2', '.png', training, - relative_path, exist_check) + exist_check=True, + info_type='image_2', + use_prefix_id=False): + return get_kitti_info_path(idx, prefix, info_type, '.png', training, + relative_path, exist_check, use_prefix_id) def get_label_path(idx, prefix, training=True, relative_path=True, - exist_check=True): - return get_kitti_info_path(idx, prefix, 'label_2', '.txt', training, - relative_path, exist_check) + exist_check=True, + info_type='label_2', + use_prefix_id=False): + return get_kitti_info_path(idx, prefix, info_type, '.txt', training, + relative_path, exist_check, use_prefix_id) def get_velodyne_path(idx, prefix, training=True, relative_path=True, - exist_check=True): + exist_check=True, + use_prefix_id=False): return get_kitti_info_path(idx, prefix, 'velodyne', '.bin', training, - relative_path, exist_check) + relative_path, exist_check, use_prefix_id) def get_calib_path(idx, prefix, training=True, relative_path=True, - exist_check=True): + exist_check=True, + use_prefix_id=False): return get_kitti_info_path(idx, prefix, 'calib', '.txt', training, - relative_path, exist_check) + relative_path, exist_check, use_prefix_id) + + +def get_pose_path(idx, + prefix, + training=True, + relative_path=True, + exist_check=True, + use_prefix_id=False): + return get_kitti_info_path(idx, prefix, 'pose', '.txt', training, + relative_path, exist_check, use_prefix_id) def get_label_anno(label_path): @@ -126,7 +147,6 @@ def get_kitti_image_info(path, num_worker=8, relative_path=True, with_imageshape=True): - # image_infos = [] """ KITTI annotation format version 2: { 
@@ -241,6 +261,185 @@ def map_func(idx): return list(image_infos) +def get_waymo_image_info(path, + training=True, + label_info=True, + velodyne=False, + calib=False, + pose=False, + image_ids=7481, + extend_matrix=True, + num_worker=8, + relative_path=True, + with_imageshape=True, + max_sweeps=5): + """ + Waymo annotation format version like KITTI: + { + [optional]points: [N, 3+] point cloud + [optional, for kitti]image: { + image_idx: ... + image_path: ... + image_shape: ... + } + point_cloud: { + num_features: 6 + velodyne_path: ... + } + [optional, for kitti]calib: { + R0_rect: ... + Tr_velo_to_cam0: ... + P0: ... + } + annos: { + location: [num_gt, 3] array + dimensions: [num_gt, 3] array + rotation_y: [num_gt] angle array + name: [num_gt] ground truth name array + [optional]difficulty: kitti difficulty + [optional]group_ids: used for multi-part object + } + } + """ + root_path = Path(path) + if not isinstance(image_ids, list): + image_ids = list(range(image_ids)) + + def map_func(idx): + info = {} + pc_info = {'num_features': 6} + calib_info = {} + + image_info = {'image_idx': idx} + annotations = None + if velodyne: + pc_info['velodyne_path'] = get_velodyne_path( + idx, path, training, relative_path, use_prefix_id=True) + points = np.fromfile( + Path(path) / pc_info['velodyne_path'], dtype=np.float32) + points = np.copy(points).reshape(-1, pc_info['num_features']) + info['timestamp'] = np.int64(points[0, -1]) + # values of the last dim are all the timestamp + image_info['image_path'] = get_image_path( + idx, + path, + training, + relative_path, + info_type='image_0', + use_prefix_id=True) + if with_imageshape: + img_path = image_info['image_path'] + if relative_path: + img_path = str(root_path / img_path) + image_info['image_shape'] = np.array( + io.imread(img_path).shape[:2], dtype=np.int32) + if label_info: + label_path = get_label_path( + idx, + path, + training, + relative_path, + info_type='label_all', + use_prefix_id=True) + if relative_path: + label_path = str(root_path / label_path) + annotations = get_label_anno(label_path) + info['image'] = image_info + info['point_cloud'] = pc_info + if calib: + calib_path = get_calib_path( + idx, path, training, relative_path=False, use_prefix_id=True) + with open(calib_path, 'r') as f: + lines = f.readlines() + P0 = np.array([float(info) for info in lines[0].split(' ')[1:13] + ]).reshape([3, 4]) + P1 = np.array([float(info) for info in lines[1].split(' ')[1:13] + ]).reshape([3, 4]) + P2 = np.array([float(info) for info in lines[2].split(' ')[1:13] + ]).reshape([3, 4]) + P3 = np.array([float(info) for info in lines[3].split(' ')[1:13] + ]).reshape([3, 4]) + P4 = np.array([float(info) for info in lines[4].split(' ')[1:13] + ]).reshape([3, 4]) + if extend_matrix: + P0 = _extend_matrix(P0) + P1 = _extend_matrix(P1) + P2 = _extend_matrix(P2) + P3 = _extend_matrix(P3) + P4 = _extend_matrix(P4) + R0_rect = np.array([ + float(info) for info in lines[5].split(' ')[1:10] + ]).reshape([3, 3]) + if extend_matrix: + rect_4x4 = np.zeros([4, 4], dtype=R0_rect.dtype) + rect_4x4[3, 3] = 1. 
+ rect_4x4[:3, :3] = R0_rect + else: + rect_4x4 = R0_rect + + Tr_velo_to_cam = np.array([ + float(info) for info in lines[6].split(' ')[1:13] + ]).reshape([3, 4]) + if extend_matrix: + Tr_velo_to_cam = _extend_matrix(Tr_velo_to_cam) + calib_info['P0'] = P0 + calib_info['P1'] = P1 + calib_info['P2'] = P2 + calib_info['P3'] = P3 + calib_info['P4'] = P4 + calib_info['R0_rect'] = rect_4x4 + calib_info['Tr_velo_to_cam'] = Tr_velo_to_cam + info['calib'] = calib_info + if pose: + pose_path = get_pose_path( + idx, path, training, relative_path=False, use_prefix_id=True) + info['pose'] = np.loadtxt(pose_path) + + if annotations is not None: + info['annos'] = annotations + info['annos']['camera_id'] = info['annos'].pop('score') + add_difficulty_to_annos(info) + + sweeps = [] + prev_idx = idx + while len(sweeps) < max_sweeps: + prev_info = {} + prev_idx -= 1 + prev_info['velodyne_path'] = get_velodyne_path( + prev_idx, + path, + training, + relative_path, + exist_check=False, + use_prefix_id=True) + if_prev_exists = osp.exists( + Path(path) / prev_info['velodyne_path']) + if if_prev_exists: + prev_points = np.fromfile( + Path(path) / prev_info['velodyne_path'], dtype=np.float32) + prev_points = np.copy(prev_points).reshape( + -1, pc_info['num_features']) + prev_info['timestamp'] = np.int64(prev_points[0, -1]) + prev_pose_path = get_pose_path( + prev_idx, + path, + training, + relative_path=False, + use_prefix_id=True) + prev_info['pose'] = np.loadtxt(prev_pose_path) + sweeps.append(prev_info) + else: + break + info['sweeps'] = sweeps + + return info + + with futures.ThreadPoolExecutor(num_worker) as executor: + image_infos = executor.map(map_func, image_ids) + + return list(image_infos) + + def kitti_anno_to_label_file(annos, folder): folder = Path(folder) for anno in annos: diff --git a/tools/data_converter/waymo_converter.py b/tools/data_converter/waymo_converter.py new file mode 100644 index 0000000000..30e39ac287 --- /dev/null +++ b/tools/data_converter/waymo_converter.py @@ -0,0 +1,510 @@ +r"""Adapted from `Waymo to KITTI converter + `_. +""" + +import mmcv +import numpy as np +import tensorflow as tf +from glob import glob +from os.path import join +from waymo_open_dataset import dataset_pb2 +from waymo_open_dataset import dataset_pb2 as open_dataset +from waymo_open_dataset.utils import range_image_utils, transform_utils +from waymo_open_dataset.utils.frame_utils import \ + parse_range_image_and_camera_projection + + +class Waymo2KITTI(object): + """Waymo to KITTI converter. + + This class serves as the converter to change the waymo raw data to KITTI + format. + + Args: + load_dir (str): Directory to load waymo raw data. + save_dir (str): Directory to save data in KITTI format. + prefix (str): Prefix of filename. In general, 0 for training, 1 for + validation and 2 for testing. + workers (str): Number of workers for the parallel process. + test_mode (bool): Whether in the test_mode. Default: False. 
+ """ + + def __init__(self, + load_dir, + save_dir, + prefix, + workers=64, + test_mode=False): + self.filter_empty_3dboxes = True + self.filter_no_label_zone_points = True + + self.selected_waymo_classes = ['VEHICLE', 'PEDESTRIAN', 'CYCLIST'] + + # Only data collected in specific locations will be converted + # If set None, this filter is disabled + # Available options: location_sf (main dataset) + self.selected_waymo_locations = None + self.save_track_id = False + + # turn on eager execution for older tensorflow versions + if int(tf.__version__.split('.')[0]) < 2: + tf.enable_eager_execution() + + self.lidar_list = [ + '_FRONT', '_FRONT_RIGHT', '_FRONT_LEFT', '_SIDE_RIGHT', + '_SIDE_LEFT' + ] + self.type_list = [ + 'UNKNOWN', 'VEHICLE', 'PEDESTRIAN', 'SIGN', 'CYCLIST' + ] + self.waymo_to_kitti_class_map = { + 'UNKNOWN': 'DontCare', + 'PEDESTRIAN': 'Pedestrian', + 'VEHICLE': 'Car', + 'CYCLIST': 'Cyclist', + 'SIGN': 'Sign' # not in kitti + } + + self.load_dir = load_dir + self.save_dir = save_dir + self.prefix = prefix + self.workers = int(workers) + self.test_mode = test_mode + + self.tfrecord_pathnames = sorted( + glob(join(self.load_dir, '*.tfrecord'))) + + self.label_save_dir = f'{self.save_dir}/label_' + self.label_all_save_dir = f'{self.save_dir}/label_all' + self.image_save_dir = f'{self.save_dir}/image_' + self.calib_save_dir = f'{self.save_dir}/calib' + self.point_cloud_save_dir = f'{self.save_dir}/velodyne' + self.pose_save_dir = f'{self.save_dir}/pose' + + self.create_folder() + + def convert(self): + """Convert action.""" + print('Start converting ...') + mmcv.track_parallel_progress(self.convert_one, range(len(self)), + self.workers) + print('\nFinished ...') + + def convert_one(self, file_idx): + """Convert action for single file. + + Args: + file_idx (int): Index of the file to be converted. + """ + pathname = self.tfrecord_pathnames[file_idx] + dataset = tf.data.TFRecordDataset(pathname, compression_type='') + + for frame_idx, data in enumerate(dataset): + + frame = open_dataset.Frame() + frame.ParseFromString(bytearray(data.numpy())) + if (self.selected_waymo_locations is not None + and frame.context.stats.location + not in self.selected_waymo_locations): + continue + + self.save_image(frame, file_idx, frame_idx) + self.save_calib(frame, file_idx, frame_idx) + self.save_lidar(frame, file_idx, frame_idx) + self.save_pose(frame, file_idx, frame_idx) + + if not self.test_mode: + self.save_label(frame, file_idx, frame_idx) + + def __len__(self): + """Length of the filename list.""" + return len(self.tfrecord_pathnames) + + def save_image(self, frame, file_idx, frame_idx): + """Parse and save the images in png format. + + Args: + frame (:obj:`Frame`): Open dataset frame proto. + file_idx (int): Current file index. + frame_idx (int): Current frame index. + """ + for img in frame.images: + img_path = f'{self.image_save_dir}{str(img.name - 1)}/' + \ + f'{self.prefix}{str(file_idx).zfill(3)}' + \ + f'{str(frame_idx).zfill(3)}.png' + img = mmcv.imfrombytes(img.image) + mmcv.imwrite(img, img_path) + + def save_calib(self, frame, file_idx, frame_idx): + """Parse and save the calibration data. + + Args: + frame (:obj:`Frame`): Open dataset frame proto. + file_idx (int): Current file index. + frame_idx (int): Current frame index. 
+ """ + # waymo front camera to kitti reference camera + T_front_cam_to_ref = np.array([[0.0, -1.0, 0.0], [0.0, 0.0, -1.0], + [1.0, 0.0, 0.0]]) + camera_calibs = [] + R0_rect = [f'{i:e}' for i in np.eye(3).flatten()] + Tr_velo_to_cams = [] + calib_context = '' + + for camera in frame.context.camera_calibrations: + # extrinsic parameters + T_cam_to_vehicle = np.array(camera.extrinsic.transform).reshape( + 4, 4) + T_vehicle_to_cam = np.linalg.inv(T_cam_to_vehicle) + Tr_velo_to_cam = \ + self.cart_to_homo(T_front_cam_to_ref) @ T_vehicle_to_cam + if camera.name == 1: # FRONT = 1, see dataset.proto for details + self.T_velo_to_front_cam = Tr_velo_to_cam.copy() + Tr_velo_to_cam = Tr_velo_to_cam[:3, :].reshape((12, )) + Tr_velo_to_cams.append([f'{i:e}' for i in Tr_velo_to_cam]) + + # intrinsic parameters + camera_calib = np.zeros((3, 4)) + camera_calib[0, 0] = camera.intrinsic[0] + camera_calib[1, 1] = camera.intrinsic[1] + camera_calib[0, 2] = camera.intrinsic[2] + camera_calib[1, 2] = camera.intrinsic[3] + camera_calib[2, 2] = 1 + camera_calib = list(camera_calib.reshape(12)) + camera_calib = [f'{i:e}' for i in camera_calib] + camera_calibs.append(camera_calib) + + # all camera ids are saved as id-1 in the result because + # camera 0 is unknown in the proto + for i in range(5): + calib_context += 'P' + str(i) + ': ' + \ + ' '.join(camera_calibs[i]) + '\n' + calib_context += 'R0_rect' + ': ' + ' '.join(R0_rect) + '\n' + for i in range(5): + calib_context += 'Tr_velo_to_cam_' + str(i) + ': ' + \ + ' '.join(Tr_velo_to_cams[i]) + '\n' + + with open( + f'{self.calib_save_dir}/{self.prefix}' + + f'{str(file_idx).zfill(3)}{str(frame_idx).zfill(3)}.txt', + 'w+') as fp_calib: + fp_calib.write(calib_context) + fp_calib.close() + + def save_lidar(self, frame, file_idx, frame_idx): + """Parse and save the lidar data in psd format. + + Args: + frame (:obj:`Frame`): Open dataset frame proto. + file_idx (int): Current file index. + frame_idx (int): Current frame index. + """ + range_images, camera_projections, range_image_top_pose = \ + parse_range_image_and_camera_projection(frame) + + # First return + points_0, cp_points_0, intensity_0, elongation_0 = \ + self.convert_range_image_to_point_cloud( + frame, + range_images, + camera_projections, + range_image_top_pose, + ri_index=0 + ) + points_0 = np.concatenate(points_0, axis=0) + intensity_0 = np.concatenate(intensity_0, axis=0) + elongation_0 = np.concatenate(elongation_0, axis=0) + + # Second return + points_1, cp_points_1, intensity_1, elongation_1 = \ + self.convert_range_image_to_point_cloud( + frame, + range_images, + camera_projections, + range_image_top_pose, + ri_index=1 + ) + points_1 = np.concatenate(points_1, axis=0) + intensity_1 = np.concatenate(intensity_1, axis=0) + elongation_1 = np.concatenate(elongation_1, axis=0) + + points = np.concatenate([points_0, points_1], axis=0) + intensity = np.concatenate([intensity_0, intensity_1], axis=0) + elongation = np.concatenate([elongation_0, elongation_1], axis=0) + timestamp = frame.timestamp_micros * np.ones_like(intensity) + + # concatenate x,y,z, intensity, elongation, timestamp (6-dim) + point_cloud = np.column_stack( + (points, intensity, elongation, timestamp)) + + pc_path = f'{self.point_cloud_save_dir}/{self.prefix}' + \ + f'{str(file_idx).zfill(3)}{str(frame_idx).zfill(3)}.bin' + point_cloud.astype(np.float32).tofile(pc_path) + + def save_label(self, frame, file_idx, frame_idx): + """Parse and save the label data in txt format. 
+ The relation between waymo and kitti coordinates is noteworthy: + 1. x, y, z correspond to l, w, h (waymo) -> l, h, w (kitti) + 2. x-y-z: front-left-up (waymo) -> right-down-front(kitti) + 3. bbox origin at volumetric center (waymo) -> bottom center (kitti) + 4. rotation: +x around y-axis (kitti) -> +x around z-axis (waymo) + + Args: + frame (:obj:`Frame`): Open dataset frame proto. + file_idx (int): Current file index. + frame_idx (int): Current frame index. + """ + fp_label_all = open( + f'{self.label_all_save_dir}/{self.prefix}' + + f'{str(file_idx).zfill(3)}{str(frame_idx).zfill(3)}.txt', 'w+') + id_to_bbox = dict() + id_to_name = dict() + for labels in frame.projected_lidar_labels: + name = labels.name + for label in labels.labels: + # TODO: need a workaround as bbox may not belong to front cam + bbox = [ + label.box.center_x - label.box.length / 2, + label.box.center_y - label.box.width / 2, + label.box.center_x + label.box.length / 2, + label.box.center_y + label.box.width / 2 + ] + id_to_bbox[label.id] = bbox + id_to_name[label.id] = name - 1 + + for obj in frame.laser_labels: + bounding_box = None + name = None + id = obj.id + for lidar in self.lidar_list: + if id + lidar in id_to_bbox: + bounding_box = id_to_bbox.get(id + lidar) + name = str(id_to_name.get(id + lidar)) + break + + if bounding_box is None or name is None: + name = '0' + bounding_box = (0, 0, 0, 0) + + my_type = self.type_list[obj.type] + + if my_type not in self.selected_waymo_classes: + continue + + if self.filter_empty_3dboxes and obj.num_lidar_points_in_box < 1: + continue + + my_type = self.waymo_to_kitti_class_map[my_type] + + height = obj.box.height + width = obj.box.width + length = obj.box.length + + x = obj.box.center_x + y = obj.box.center_y + z = obj.box.center_z - height / 2 + + # project bounding box to the virtual reference frame + pt_ref = self.T_velo_to_front_cam @ \ + np.array([x, y, z, 1]).reshape((4, 1)) + x, y, z, _ = pt_ref.flatten().tolist() + + rotation_y = -obj.box.heading - np.pi / 2 + track_id = obj.id + + # not available + truncated = 0 + occluded = 0 + alpha = -10 + + line = my_type + \ + ' {} {} {} {} {} {} {} {} {} {} {} {} {} {}\n'.format( + round(truncated, 2), occluded, round(alpha, 2), + round(bounding_box[0], 2), round(bounding_box[1], 2), + round(bounding_box[2], 2), round(bounding_box[3], 2), + round(height, 2), round(width, 2), round(length, 2), + round(x, 2), round(y, 2), round(z, 2), + round(rotation_y, 2)) + + if self.save_track_id: + line_all = line[:-1] + ' ' + name + ' ' + track_id + '\n' + else: + line_all = line[:-1] + ' ' + name + '\n' + + fp_label = open( + f'{self.label_save_dir}{name}/{self.prefix}' + + f'{str(file_idx).zfill(3)}{str(frame_idx).zfill(3)}.txt', 'a') + fp_label.write(line) + fp_label.close() + + fp_label_all.write(line_all) + + fp_label_all.close() + + def save_pose(self, frame, file_idx, frame_idx): + """Parse and save the pose data. + + Note that SDC's own pose is not included in the regular training + of KITTI dataset. KITTI raw dataset contains ego motion files + but are not often used. Pose is important for algorithms that + take advantage of the temporal information. + + Args: + frame (:obj:`Frame`): Open dataset frame proto. + file_idx (int): Current file index. + frame_idx (int): Current frame index. 
+ """ + pose = np.array(frame.pose.transform).reshape(4, 4) + np.savetxt( + join(f'{self.pose_save_dir}/{self.prefix}' + + f'{str(file_idx).zfill(3)}{str(frame_idx).zfill(3)}.txt'), + pose) + + def create_folder(self): + """Create folder for data preprocessing.""" + if not self.test_mode: + dir_list1 = [ + self.label_all_save_dir, self.calib_save_dir, + self.point_cloud_save_dir, self.pose_save_dir + ] + dir_list2 = [self.label_save_dir, self.image_save_dir] + else: + dir_list1 = [ + self.calib_save_dir, self.point_cloud_save_dir, + self.pose_save_dir + ] + dir_list2 = [self.image_save_dir] + for d in dir_list1: + mmcv.mkdir_or_exist(d) + for d in dir_list2: + for i in range(5): + mmcv.mkdir_or_exist(f'{d}{str(i)}') + + def convert_range_image_to_point_cloud(self, + frame, + range_images, + camera_projections, + range_image_top_pose, + ri_index=0): + """Convert range images to point cloud. + + Args: + frame (:obj:`Frame`): Open dataset frame. + range_images (dict): Mapping from laser_name to list of two + range images corresponding with two returns. + camera_projections (dict): Mapping from laser_name to list of two + camera projections corresponding with two returns. + range_image_top_pose (:obj:`Transform`): Range image pixel pose for + top lidar. + ri_index (int): 0 for the first return, 1 for the second return. + Default: 0. + + Returns: + tuple[list[np.ndarray]]: (List of points with shape [N, 3], + camera projections of points with shape [N, 6], intensity + with shape [N, 1], elongation with shape [N, 1]). All the + lists have the length of lidar numbers (5). + """ + calibrations = sorted( + frame.context.laser_calibrations, key=lambda c: c.name) + points = [] + cp_points = [] + intensity = [] + elongation = [] + + frame_pose = tf.convert_to_tensor( + value=np.reshape(np.array(frame.pose.transform), [4, 4])) + # [H, W, 6] + range_image_top_pose_tensor = tf.reshape( + tf.convert_to_tensor(value=range_image_top_pose.data), + range_image_top_pose.shape.dims) + # [H, W, 3, 3] + range_image_top_pose_tensor_rotation = \ + transform_utils.get_rotation_matrix( + range_image_top_pose_tensor[..., 0], + range_image_top_pose_tensor[..., 1], + range_image_top_pose_tensor[..., 2]) + range_image_top_pose_tensor_translation = \ + range_image_top_pose_tensor[..., 3:] + range_image_top_pose_tensor = transform_utils.get_transform( + range_image_top_pose_tensor_rotation, + range_image_top_pose_tensor_translation) + for c in calibrations: + range_image = range_images[c.name][ri_index] + if len(c.beam_inclinations) == 0: + beam_inclinations = range_image_utils.compute_inclination( + tf.constant( + [c.beam_inclination_min, c.beam_inclination_max]), + height=range_image.shape.dims[0]) + else: + beam_inclinations = tf.constant(c.beam_inclinations) + + beam_inclinations = tf.reverse(beam_inclinations, axis=[-1]) + extrinsic = np.reshape(np.array(c.extrinsic.transform), [4, 4]) + + range_image_tensor = tf.reshape( + tf.convert_to_tensor(value=range_image.data), + range_image.shape.dims) + pixel_pose_local = None + frame_pose_local = None + if c.name == dataset_pb2.LaserName.TOP: + pixel_pose_local = range_image_top_pose_tensor + pixel_pose_local = tf.expand_dims(pixel_pose_local, axis=0) + frame_pose_local = tf.expand_dims(frame_pose, axis=0) + range_image_mask = range_image_tensor[..., 0] > 0 + + if self.filter_no_label_zone_points: + nlz_mask = range_image_tensor[..., 3] != 1.0 # 1.0: in NLZ + range_image_mask = range_image_mask & nlz_mask + + range_image_cartesian = \ + 
range_image_utils.extract_point_cloud_from_range_image( + tf.expand_dims(range_image_tensor[..., 0], axis=0), + tf.expand_dims(extrinsic, axis=0), + tf.expand_dims(tf.convert_to_tensor( + value=beam_inclinations), axis=0), + pixel_pose=pixel_pose_local, + frame_pose=frame_pose_local) + + range_image_cartesian = tf.squeeze(range_image_cartesian, axis=0) + points_tensor = tf.gather_nd(range_image_cartesian, + tf.compat.v1.where(range_image_mask)) + + cp = camera_projections[c.name][ri_index] + cp_tensor = tf.reshape( + tf.convert_to_tensor(value=cp.data), cp.shape.dims) + cp_points_tensor = tf.gather_nd( + cp_tensor, tf.compat.v1.where(range_image_mask)) + points.append(points_tensor.numpy()) + cp_points.append(cp_points_tensor.numpy()) + + intensity_tensor = tf.gather_nd(range_image_tensor[..., 1], + tf.where(range_image_mask)) + intensity.append(intensity_tensor.numpy()) + + elongation_tensor = tf.gather_nd(range_image_tensor[..., 2], + tf.where(range_image_mask)) + elongation.append(elongation_tensor.numpy()) + + return points, cp_points, intensity, elongation + + def cart_to_homo(self, mat): + """Convert transformation matrix in Cartesian coordinates to + homogeneous format. + + Args: + mat (np.ndarray): Transformation matrix in Cartesian. + The input matrix shape is 3x3 or 3x4. + + Returns: + np.ndarray: Transformation matrix in homogeneous format. + The matrix shape is 4x4. + """ + ret = np.eye(4) + if mat.shape == (3, 3): + ret[:3, :3] = mat + elif mat.shape == (3, 4): + ret[:3, :] = mat + else: + raise ValueError(mat.shape) + return ret diff --git a/tools/fuse_conv_bn.py b/tools/fuse_conv_bn.py index c00219b7b9..d393672a2d 100644 --- a/tools/fuse_conv_bn.py +++ b/tools/fuse_conv_bn.py @@ -1,5 +1,4 @@ import argparse - import torch import torch.nn as nn from mmcv.runner import save_checkpoint