Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added mmdetection3d and openpcdet integrations. #334

Merged
merged 6 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ TorchSparse is a high-performance neural network library for point cloud process
Point cloud computation has become an increasingly more important workload for autonomous driving and other applications. Unlike dense 2D computation, point cloud convolution has **sparse** and **irregular** computation patterns and thus requires dedicated inference system support with specialized high-performance kernels. While existing point cloud deep learning libraries have developed different dataflows for convolution on point clouds, they assume a single dataflow throughout the execution of the entire model. In this work, we systematically analyze and improve existing dataflows. Our resulting system, TorchSparse, achieves **2.9x**, **3.3x**, **2.2x** and **1.7x** measured end-to-end speedup on an NVIDIA A100 GPU over the state-of-the-art MinkowskiEngine, SpConv 1.2, TorchSparse (MLSys) and SpConv v2 in inference respectively.

## News
**\[2024/11\]** TorchSparse++ is now supporting [MMDetection3D](https://github.com/open-mmlab/mmdetection3d) and [OpenPCDet](https://github.com/open-mmlab/OpenPCDet) via plugins! [A full demo](./examples/) is available.

**\[2023/11\]** TorchSparse++ has been adopted by [One-2-3-45++](https://arxiv.org/abs/2311.07885) from Prof. Hao Su's lab (UCSD) for 3D object generation!

Expand Down
36 changes: 36 additions & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Containers
A docker image is created with all the required environment installed: `ioeddk/torchsparse_plugin_demo:latest`, including MMDetection3D, OpenPCDet, TorchSparse, plugins, and PyTorch based on the NVIDIA CUDA 12.1 image.
The dataset is not included in the image and need to be bind mounted to the container when starting. Specifically with the following command:
```bash
docker run -it --gpus all --mount type=bind,source=<kitti_dataset_root>,target=/root/data/kitti --mount type=bind,source=<nuscenes_dataset_root>,target=/root/data/nuscenes ioeddk/torchsparse_plugin_demo:latest
```
The above is an example to mount the kitti dataset when starting the container.

Using this container is the simplest way to start the demo of this plugin since the all the dependencies are installed and the paths are configured. You can simply open `/root/repo/torchsparse-dev/examples/mmdetection3d/demo.ipynb` or `/root/repo/torchsparse-dev/examples/openpcdet/demo.ipynb` and run all cells to run the demo. The helper functions in the demo are defined to automatically load the pretrained checkpoints, do the conversions, and run the evaluation.

If not using the container, then please follow the tutorial below to run the demo. The same copy of demo is also in the demo notebook.

# Convert the Module Weights
The dimensions of TorchSparse differs from the SpConv, so the parameter dimension conversion is required to use the TorchSparse backend. The conversion script can be found in `examples/converter.py`. The `convert_weighs` function has the header `def convert_weights(ckpt_before: str, ckpt_after: str, cfg_path: str, v_spconv: int = 1, framework: str = "mmdet3d")`:
- `ckpt_before`: the pretrained checkpoint of your module, typically downloaded from the MMDetection3d and OpenPCDet model Zoo.
- `ckpt_after`: the output path for the converted checkpoint.
- `cfg_path`: the path to the config file of the MMdet3d or OPC model to be converted. It is requried since the converter create an instance of the model, find all the Sparse Convolution layers, and convert the weights of thay layer.
- `v_spconv`: the version of the SpConv that the original model is build upon. Valud versions are 1 or 2.
- `framework`: choose between `mmdet3d` and `openpc`.

## Example Conversion Commands
### MMDetection3D
```bash
python examples/converter.py --ckpt_before ../mmdetection3d/models/PV-RCNN/pv_rcnn_8xb2-80e_kitti-3d-3class_20221117_234428-b384d22f.pth --cfg_path ../mmdetection3d/pv_rcnn/pv_rcnn_8xb2-80e_kitti-3d-3class.py --ckpt_after ./converted/PV-RCNN/pv_rcnn_8xb2-80e_kitti-3d-3class_20221117_234428-b384d22f.pth --v_spconv 1 --framework mmdet3d
```

### OpenPCDet
```bash
python examples/converter.py --ckpt_before ../OpenPCDet/models/SECOND/second_7862.pth --cfg_path ../OpenPCDet/tools/cfgs/kitti_models/second.yaml --ckpt_after ./converted/SECOND/second_7862.pth --v_spconv 1 --framework openpc
```

# Run evaluation.
Use the `test.py` that comes with the MMDet3D or OPC to run the evaluation. Provide the converted checkpoint as the model weights. For MMDet3D models, you need to provide extra arguments to replace certain layers to be torchsparse's (see how to replace them in `examples/mmdetection3d/demo.ipynb`). For OpenPCDet, the config file with those layers replaced is in the `examples/openpcdet/cfgs`; to use them, see `examples/openpcdet/demo.ipynb`. An additional step is to add `import ts_plugin` in `mmdetection3d/tools/test.py` and add `import pcdet_plugin` in `OpenPCDet/tools/test.py` to activate the plugins before running the evaluation.

# Details
Please see `examples/mmdetection3d/demo.ipynb` and `examples/openpcdet/demo.ipynb` for more details.
240 changes: 240 additions & 0 deletions examples/converter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
"""This is the model converter to convert a SpConv model to TorchSparse model.
"""
import argparse
import torch
import re
import logging
import spconv.pytorch as spconv
import logging

# Disable JIT because running OpenPCDet with JIT enabled will cause some import issue.
torch.jit._state.disable()

# Works for SECOND
def convert_weights_v2(key, model):
"""Convert model weights for models build with SpConv v2.

:param key: _description_
:type key: _type_
:param model: _description_
:type model: _type_
:return: _description_
:rtype: _type_
"""
new_key = key.replace(".weight", ".kernel")
weights = model[key]
oc, kx, ky, kz, ic = weights.shape

converted_weights = weights.reshape(oc, -1, ic)

converted_weights = converted_weights.permute(1, 0, 2)

if converted_weights.shape[0] == 1:
converted_weights = converted_weights[0]
elif converted_weights.shape[0] == 27:
offsets = [list(range(kz)), list(range(ky)), list(range(kx))]
kykx = ky * kx
offsets = [
(x * kykx + y * kx + z)
for z in offsets[0]
for y in offsets[1]
for x in offsets[2]
]
offsets = torch.tensor(
offsets, dtype=torch.int64, device=converted_weights.device
)
converted_weights = converted_weights[offsets]

converted_weights = converted_weights.permute(0,2,1)

return new_key, converted_weights

# Order for CenterPoint, PV-RCNN, and default, legacy SpConv
def convert_weights_v1(key, model):
"""Convert model weights for models implemented with SpConv v1

:param key: _description_
:type key: _type_
:param model: _description_
:type model: _type_
:return: _description_
:rtype: _type_
"""
new_key = key.replace(".weight", ".kernel")
weights = model[key]

kx, ky, kz, ic, oc = weights.shape

converted_weights = weights.reshape(-1, ic, oc)
if converted_weights.shape[0] == 1:
converted_weights = converted_weights[0]

elif converted_weights.shape[0] == 27:
offsets = [list(range(kz)), list(range(ky)), list(range(kx))]
kykx = ky * kx
offsets = [
(x * kykx + y * kx + z)
for z in offsets[0]
for y in offsets[1]
for x in offsets[2]
]
offsets = torch.tensor(
offsets, dtype=torch.int64, device=converted_weights.device
)
converted_weights = converted_weights[offsets]
elif converted_weights.shape[0] == 3: # 3 is the case in PartA2.
pass
# offsets = torch.tensor(
# [2, 1, 0], dtype=torch.int64, device=converted_weights.device
# )
# converted_weights = converted_weights[offsets]
return new_key, converted_weights

def build_mmdet_model_from_cfg(cfg_path, ckpt_path):
try:
from mmdet3d.apis import init_model
from mmengine.config import Config
except:
print("MMDetection3D is not installed. Please install MMDetection3D to use this function.")
cfg = Config.fromfile(cfg_path)
model = init_model(cfg, ckpt_path)
return model

def build_opc_model_from_cfg(cfg_path):
try:
from pcdet.config import cfg, cfg_from_yaml_file
from pcdet.datasets import build_dataloader
from pcdet.models import build_network
except Exception as e:
print(e)
raise ImportError("Failed to import OpenPCDet")
cfg_from_yaml_file(cfg_path, cfg)
test_set, test_loader, sampler = build_dataloader(
dataset_cfg=cfg.DATA_CONFIG,
class_names=cfg.CLASS_NAMES,
batch_size=1,
dist=False,
training=False,
logger=logging.Logger("Build Dataloader"),
)

model = build_network(model_cfg=cfg.MODEL, num_class=len(cfg.CLASS_NAMES), dataset=test_set)
return model

# Allow use the API to convert based on a passed in model.
def convert_model_weights(ckpt_before, ckpt_after, model, legacy=False):

model_modules = {}
for key, value in model.named_modules():
model_modules[key] = value

cp_old = torch.load(ckpt_before, map_location="cpu")
try:
opc = False
old_state_dict = cp_old["state_dict"]
except:
opc = True
old_state_dict = cp_old["model_state"]

new_model = dict()

for state_dict_key in old_state_dict.keys():
is_sparseconv_weight = False
if state_dict_key.endswith(".weight"):
if state_dict_key[:-len(".weight")] in model_modules.keys():
if isinstance(model_modules[state_dict_key[:-len(".weight")]], (spconv.SparseConv3d, spconv.SubMConv3d, spconv.SparseInverseConv3d)):
is_sparseconv_weight = True

if is_sparseconv_weight:
# print(f"{state_dict_key} is a sparseconv weight")
pass

if is_sparseconv_weight:
if len(old_state_dict[state_dict_key].shape) == 5:
if legacy:
new_key, converted_weights = convert_weights_v1(state_dict_key, old_state_dict)
else:
new_key, converted_weights = convert_weights_v2(state_dict_key, old_state_dict)
else:
new_key = state_dict_key
converted_weights = old_state_dict[state_dict_key]

new_model[new_key] = converted_weights

if opc:
cp_old["model_state"] = new_model
else:
cp_old["state_dict"] = new_model
torch.save(cp_old, ckpt_after)


def convert_weights_cmd():
"""Convert the weights of a model from SpConv to TorchSparse.

:param ckpt_before: Path to the SpConv checkpoint
:type ckpt_before: str
:param ckpt_after: Path to the output folder of the converted checkpoint.
:type ckpt_after: str
:param v_spconv: SpConv version used for the weights. Can be one of 1 or 2, defaults to "1"
:type v_spconv: str, optional
:param framework: From which framework does the model weight comes from, choose one of mmdet3d or openpc, defaults to "mmdet3d"
:type framework: str, optional
"""
# ckpt_before, ckpt_after, v_spconv="1", framework="mmdet3d"

# argument parser
parser = argparse.ArgumentParser(description="Convert SpConv model to TorchSparse model")
parser.add_argument("--ckpt_before", help="Path to the SpConv checkpoint")
parser.add_argument("--ckpt_after", help="Path to the output folder of the converted checkpoint.")
parser.add_argument("--cfg_path", help="Path to the config file of the model")
parser.add_argument("--v_spconv", default="1", help="SpConv version used for the weights. Can be one of 1 or 2")
parser.add_argument("--framework", default="mmdet3d", help="From which framework does the model weight comes from, choose one of mmdet3d or openpc")
args = parser.parse_args()

# Check the plugin argument
assert args.framework in ['mmdet3d', 'openpc'], "plugin argument can only be mmdet3d or openpcdet"
assert args.v_spconv in ['1', '2'], "v_spconv argument can only be 1 or 2"

legacy = True if args.v_spconv == "1" else False
cfg_path = args.cfg_path

model = build_mmdet_model_from_cfg(cfg_path, args.ckpt_before) if args.framework == "mmdet3d" else build_opc_model_from_cfg(cfg_path)
convert_model_weights(
ckpt_before=args.ckpt_before,
ckpt_after=args.ckpt_after,
model=model,
legacy=legacy)


def convert_weights(ckpt_before: str, ckpt_after: str, cfg_path: str, v_spconv: int = 1, framework: str = "mmdet3d"):
"""Convert the weights of a model from SpConv to TorchSparse.

:param ckpt_before: _description_
:type ckpt_before: str
:param ckpt_after: _description_
:type ckpt_after: str
:param cfg_path: _description_
:type cfg_path: str
:param v_spconv: _description_, defaults to 1
:type v_spconv: int, optional
:param framework: _description_, defaults to "mmdet3d"
:type framework: str, optional
"""

# Check the plugin argument
assert framework in ['mmdet3d', 'openpc'], "plugin argument can only be mmdet3d or openpcdet"
assert v_spconv in [1, 2], "v_spconv argument can only be 1 or 2"

legacy = True if v_spconv == 1 else False

model = build_mmdet_model_from_cfg(cfg_path, ckpt_before) if framework == "mmdet3d" else build_opc_model_from_cfg(cfg_path)
convert_model_weights(
ckpt_before=ckpt_before,
ckpt_after=ckpt_after,
model=model,
legacy=legacy)


if __name__ == "__main__":
convert_weights_cmd()
print("Conversion completed")
79 changes: 79 additions & 0 deletions examples/mmdetection3d/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# TorchSparse for MMDetection3D Plugin Demo

This tutorial demonstrates how to evaluate TorchSparse integrated MMDetection3D models. Follow the steps below to install dependencies, configure paths, convert model weights, and run the demo.

## Dependencies

1. **MMDetection3D Installation**: Follow the [MMDetection3D documentation](https://mmdetection3d.readthedocs.io/en/latest/get_started.html).
2. **Dataset Preparation**: Pre-process the datasets as described [here](https://mmdetection3d.readthedocs.io/en/latest/user_guides/dataset_prepare.html).
3. **TorchSparse Installation**: Install [TorchSparse](https://github.com/mit-han-lab/torchsparse).
4. **Install TorchSparse Plugin for MMDetection3D**:
1. Clone this repository.
2. Navigate to `examples/mmdetection3d` and run `pip install -v -e .`.

## Notes

- For model evaluation, change the data root in the original MMDetection3D's model config to the full path of the corresponding dataset root.

## Steps

1. Install the dependencies.
2. Specify the base paths and model registry.
3. **IMPORTANT,** Activate the plugin: In `mmdetection3d/tools/test.py`, add `import ts_plugin` as the last import statement to activate the plugin.
4. Run the evaluation.

## Supported Models

- SECOND
- PV-RCNN
- CenterPoint
- Part-A2

## Convert Module Weights
The dimensions of TorchSparse differ from SpConv, so parameter dimension conversion is required. You can use `convert_weights_cmd()` in converter.py as a command line tool or use `convert_weights()` as an API. Both functions have four parameters:

1. `ckpt_before`: Path to the input SpConv checkpoint file.
2. `ckpt_after`: Path where the converted TorchSparse checkpoint will be saved.
3. `cfg_path`: Path to the configuration mmdet3d file of the model.
4. `v_spconv`: Version of SpConv used in the original model (1 or 2).
5. `framework`: Choose between `'openpc'` and `'mmdet3d'`, default to `'mmdet3d'`.

These parameters allow the converter to locate the input model, specify the output location, understand the model's architecture, and apply the appropriate conversion method based for specific Sparse Conv layers.

Example conversion commands:
```bash
python examples/converter.py --ckpt_before ../mmdetection3d/models/PV-RCNN/pv_rcnn_8xb2-80e_kitti-3d-3class_20221117_234428-b384d22f.pth --cfg_path ../mmdetection3d/pv_rcnn/pv_rcnn_8xb2-80e_kitti-3d-3class.py --ckpt_after ./converted/PV-RCNN/pv_rcnn_8xb2-80e_kitti-3d-3class_20221117_234428-b384d22f.pth --v_spconv 1 --framework mmdet3d
```


# Run a demo
In your Conda environment, run:
```bash
python <test_file_path> <cfg_path> <torchsparse_model_path> <cfg_options> --task lidar_det
```

- `test_file_path`: The `tools/test.py` file in mmdet3d repository.
- `cfg_path`: The path to the mmdet3d's model config for your model.
- `torchsparse_model_path`: the path to the converted TorchSparse model checkpoint.
- `cfg_options`: The plugin requires the use of MMDet3D cfg_options to tweak certain model layers to be the plugin layers. `cfg_options` examples are below:

## SECOND
`cfg_options`:
```bash
"--cfg-options test_evaluator.pklfile_prefix=outputs/torchsparse/second --cfg-options model.middle_encoder.type=SparseEncoderTS"
```

## PV-RCNN
`cfg_options`:
```bash
"--cfg-options test_evaluator.pklfile_prefix=outputs/torchsparse/pv_rcnn --cfg-options model.middle_encoder.type=SparseEncoderTS --cfg-options model.points_encoder.type=VoxelSetAbstractionTS"
```

### CenterPoint Voxel 0.1 Circular NMS

Update the path of the NuScenes dataset in the MMDetection3D dataset config `configs/_base_/datasets/nus-3d.py`.

`cfg_options`:
```bash
"--cfg-options model.pts_middle_encoder.type=SparseEncoderTS"
```
1 change: 1 addition & 0 deletions examples/mmdetection3d/configs/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
This folder contains the configs to carry out the demo in mmdetectino3d.
1 change: 1 addition & 0 deletions examples/mmdetection3d/converted_models/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Default model conversion base folder for the demo. Please create the relative path to each specific model under this directory.
Loading
Loading