Skip to content

Commit cc44cd1

Browse files
committed
added mmdetection3d and openpcdet integrations.
1 parent 9eb3b88 commit cc44cd1

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+7677
-0
lines changed

examples/README.md

+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Containers
2+
A docker image is created with all the required environment installed: `ioeddk/torchsparse_plugin_demo:latest`, including MMDetection3D, OpenPCDet, TorchSparse, plugins, and PyTorch based on the NVIDIA CUDA 12.1 image.
3+
The dataset is not included in the image and need to be bind mounted to the container when starting. Specifically with the following command:
4+
```bash
5+
docker run -it --gpus all --mount type=bind,source=<kitti_dataset_root>,target=/root/data/kitti --mount type=bind,source=<nuscenes_dataset_root>,target=/root/data/nuscenes ioeddk/torchsparse_plugin_demo:latest
6+
```
7+
The above is an example to mount the kitti dataset when starting the container.
8+
9+
Using this container is the simplest way to start the demo of this plugin since the all the dependencies are installed and the paths are configured. You can simply open `/root/repo/torchsparse-dev/examples/mmdetection3d/demo.ipynb` or `/root/repo/torchsparse-dev/examples/openpcdet/demo.ipynb` and run all cells to run the demo. The helper functions in the demo are defined to automatically load the pretrained checkpoints, do the conversions, and run the evaluation.
10+
11+
If not using the container, then please follow the tutorial below to run the demo. The same copy of demo is also in the demo notebook.
12+
13+
# Convert the Module Weights
14+
The dimensions of TorchSparse differs from the SpConv, so the parameter dimension conversion is required to use the TorchSparse backend. The conversion script can be found in `examples/converter.py`. The `convert_weighs` function has the header `def convert_weights(ckpt_before: str, ckpt_after: str, cfg_path: str, v_spconv: int = 1, framework: str = "mmdet3d")`:
15+
- `ckpt_before`: the pretrained checkpoint of your module, typically downloaded from the MMDetection3d and OpenPCDet model Zoo.
16+
- `ckpt_after`: the output path for the converted checkpoint.
17+
- `cfg_path`: the path to the config file of the MMdet3d or OPC model to be converted. It is requried since the converter create an instance of the model, find all the Sparse Convolution layers, and convert the weights of thay layer.
18+
- `v_spconv`: the version of the SpConv that the original model is build upon. Valud versions are 1 or 2.
19+
- `framework`: choose between `mmdet3d` and `openpc`.
20+
21+
## Example Conversion Commands
22+
parser.add_argument("--ckpt_before", help="Path to the SpConv checkpoint")
23+
parser.add_argument("--ckpt_after", help="Path to the output folder of the converted checkpoint.")
24+
parser.add_argument("--cfg_path", help="Path to the config file of the model")
25+
parser.add_argument("--v_spconv", default="1", help="SpConv version used for the weights. Can be one of 1 or 2")
26+
parser.add_argument("--framework", default="mmdet3d", help="From which framework does the model weight comes from, choose one of mmdet3d or openpc")
27+
### MMDetection3D
28+
```bash
29+
python examples/converter.py --ckpt_before ../mmdetection3d/models/PV-RCNN/pv_rcnn_8xb2-80e_kitti-3d-3class_20221117_234428-b384d22f.pth --cfg_path ../mmdetection3d/pv_rcnn/pv_rcnn_8xb2-80e_kitti-3d-3class.py --ckpt_after ./converted/PV-RCNN/pv_rcnn_8xb2-80e_kitti-3d-3class_20221117_234428-b384d22f.pth --v_spconv 1 --framework mmdet3d
30+
```
31+
32+
### OpenPCDet
33+
```bash
34+
python examples/converter.py --ckpt_before ../OpenPCDet/models/SECOND/second_7862.pth --cfg_path ../OpenPCDet/tools/cfgs/kitti_models/second.yaml --ckpt_after ./converted/SECOND/second_7862.pth --v_spconv 1 --framework openpc
35+
```
36+
37+
# Run evaluation.
38+
Use the `test.py` that comes with the MMDet3D or OPC to run the evaluation. Provide the converted checkpoint as the model weights. For MMDet3D models, you need to provide extra arguments to replace certain layers to be torchsparse's (see how to replace them in `examples/mmdetection3d/demo.ipynb`). For OpenPCDet, the config file with those layers replaced is in the `examples/openpcdet/cfgs`; to use them, see `examples/openpcdet/demo.ipynb`. An additional step is to add `import ts_plugin` in `mmdetection3d/tools/test.py` and add `import pcdet_plugin` in `OpenPCDet/tools/test.py` to activate the plugins before running the evaluation.
39+
40+
# Details
41+
Please see `examples/mmdetection3d/demo.ipynb` and `examples/openpcdet/demo.ipynb` for more details.

examples/converter.py

+240
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
"""This is the model converter to convert a SpConv model to TorchSparse model.
2+
"""
3+
import argparse
4+
import torch
5+
import re
6+
import logging
7+
import spconv.pytorch as spconv
8+
import logging
9+
10+
# Disable JIT because running OpenPCDet with JIT enabled will cause some import issue.
11+
torch.jit._state.disable()
12+
13+
# Works for SECOND
14+
def convert_weights_v2(key, model):
15+
"""Convert model weights for models build with SpConv v2.
16+
17+
:param key: _description_
18+
:type key: _type_
19+
:param model: _description_
20+
:type model: _type_
21+
:return: _description_
22+
:rtype: _type_
23+
"""
24+
new_key = key.replace(".weight", ".kernel")
25+
weights = model[key]
26+
oc, kx, ky, kz, ic = weights.shape
27+
28+
converted_weights = weights.reshape(oc, -1, ic)
29+
30+
converted_weights = converted_weights.permute(1, 0, 2)
31+
32+
if converted_weights.shape[0] == 1:
33+
converted_weights = converted_weights[0]
34+
elif converted_weights.shape[0] == 27:
35+
offsets = [list(range(kz)), list(range(ky)), list(range(kx))]
36+
kykx = ky * kx
37+
offsets = [
38+
(x * kykx + y * kx + z)
39+
for z in offsets[0]
40+
for y in offsets[1]
41+
for x in offsets[2]
42+
]
43+
offsets = torch.tensor(
44+
offsets, dtype=torch.int64, device=converted_weights.device
45+
)
46+
converted_weights = converted_weights[offsets]
47+
48+
converted_weights = converted_weights.permute(0,2,1)
49+
50+
return new_key, converted_weights
51+
52+
# Order for CenterPoint, PV-RCNN, and default, legacy SpConv
53+
def convert_weights_v1(key, model):
54+
"""Convert model weights for models implemented with SpConv v1
55+
56+
:param key: _description_
57+
:type key: _type_
58+
:param model: _description_
59+
:type model: _type_
60+
:return: _description_
61+
:rtype: _type_
62+
"""
63+
new_key = key.replace(".weight", ".kernel")
64+
weights = model[key]
65+
66+
kx, ky, kz, ic, oc = weights.shape
67+
68+
converted_weights = weights.reshape(-1, ic, oc)
69+
if converted_weights.shape[0] == 1:
70+
converted_weights = converted_weights[0]
71+
72+
elif converted_weights.shape[0] == 27:
73+
offsets = [list(range(kz)), list(range(ky)), list(range(kx))]
74+
kykx = ky * kx
75+
offsets = [
76+
(x * kykx + y * kx + z)
77+
for z in offsets[0]
78+
for y in offsets[1]
79+
for x in offsets[2]
80+
]
81+
offsets = torch.tensor(
82+
offsets, dtype=torch.int64, device=converted_weights.device
83+
)
84+
converted_weights = converted_weights[offsets]
85+
elif converted_weights.shape[0] == 3: # 3 is the case in PartA2.
86+
pass
87+
# offsets = torch.tensor(
88+
# [2, 1, 0], dtype=torch.int64, device=converted_weights.device
89+
# )
90+
# converted_weights = converted_weights[offsets]
91+
return new_key, converted_weights
92+
93+
def build_mmdet_model_from_cfg(cfg_path, ckpt_path):
94+
try:
95+
from mmdet3d.apis import init_model
96+
from mmengine.config import Config
97+
except:
98+
print("MMDetection3D is not installed. Please install MMDetection3D to use this function.")
99+
cfg = Config.fromfile(cfg_path)
100+
model = init_model(cfg, ckpt_path)
101+
return model
102+
103+
def build_opc_model_from_cfg(cfg_path):
104+
try:
105+
from pcdet.config import cfg, cfg_from_yaml_file
106+
from pcdet.datasets import build_dataloader
107+
from pcdet.models import build_network
108+
except Exception as e:
109+
print(e)
110+
raise ImportError("Failed to import OpenPCDet")
111+
cfg_from_yaml_file(cfg_path, cfg)
112+
test_set, test_loader, sampler = build_dataloader(
113+
dataset_cfg=cfg.DATA_CONFIG,
114+
class_names=cfg.CLASS_NAMES,
115+
batch_size=1,
116+
dist=False,
117+
training=False,
118+
logger=logging.Logger("Build Dataloader"),
119+
)
120+
121+
model = build_network(model_cfg=cfg.MODEL, num_class=len(cfg.CLASS_NAMES), dataset=test_set)
122+
return model
123+
124+
# Allow use the API to convert based on a passed in model.
125+
def convert_model_weights(ckpt_before, ckpt_after, model, legacy=False):
126+
127+
model_modules = {}
128+
for key, value in model.named_modules():
129+
model_modules[key] = value
130+
131+
cp_old = torch.load(ckpt_before, map_location="cpu")
132+
try:
133+
opc = False
134+
old_state_dict = cp_old["state_dict"]
135+
except:
136+
opc = True
137+
old_state_dict = cp_old["model_state"]
138+
139+
new_model = dict()
140+
141+
for state_dict_key in old_state_dict.keys():
142+
is_sparseconv_weight = False
143+
if state_dict_key.endswith(".weight"):
144+
if state_dict_key[:-len(".weight")] in model_modules.keys():
145+
if isinstance(model_modules[state_dict_key[:-len(".weight")]], (spconv.SparseConv3d, spconv.SubMConv3d, spconv.SparseInverseConv3d)):
146+
is_sparseconv_weight = True
147+
148+
if is_sparseconv_weight:
149+
# print(f"{state_dict_key} is a sparseconv weight")
150+
pass
151+
152+
if is_sparseconv_weight:
153+
if len(old_state_dict[state_dict_key].shape) == 5:
154+
if legacy:
155+
new_key, converted_weights = convert_weights_v1(state_dict_key, old_state_dict)
156+
else:
157+
new_key, converted_weights = convert_weights_v2(state_dict_key, old_state_dict)
158+
else:
159+
new_key = state_dict_key
160+
converted_weights = old_state_dict[state_dict_key]
161+
162+
new_model[new_key] = converted_weights
163+
164+
if opc:
165+
cp_old["model_state"] = new_model
166+
else:
167+
cp_old["state_dict"] = new_model
168+
torch.save(cp_old, ckpt_after)
169+
170+
171+
def convert_weights_cmd():
172+
"""Convert the weights of a model from SpConv to TorchSparse.
173+
174+
:param ckpt_before: Path to the SpConv checkpoint
175+
:type ckpt_before: str
176+
:param ckpt_after: Path to the output folder of the converted checkpoint.
177+
:type ckpt_after: str
178+
:param v_spconv: SpConv version used for the weights. Can be one of 1 or 2, defaults to "1"
179+
:type v_spconv: str, optional
180+
:param framework: From which framework does the model weight comes from, choose one of mmdet3d or openpc, defaults to "mmdet3d"
181+
:type framework: str, optional
182+
"""
183+
# ckpt_before, ckpt_after, v_spconv="1", framework="mmdet3d"
184+
185+
# argument parser
186+
parser = argparse.ArgumentParser(description="Convert SpConv model to TorchSparse model")
187+
parser.add_argument("--ckpt_before", help="Path to the SpConv checkpoint")
188+
parser.add_argument("--ckpt_after", help="Path to the output folder of the converted checkpoint.")
189+
parser.add_argument("--cfg_path", help="Path to the config file of the model")
190+
parser.add_argument("--v_spconv", default="1", help="SpConv version used for the weights. Can be one of 1 or 2")
191+
parser.add_argument("--framework", default="mmdet3d", help="From which framework does the model weight comes from, choose one of mmdet3d or openpc")
192+
args = parser.parse_args()
193+
194+
# Check the plugin argument
195+
assert args.framework in ['mmdet3d', 'openpc'], "plugin argument can only be mmdet3d or openpcdet"
196+
assert args.v_spconv in ['1', '2'], "v_spconv argument can only be 1 or 2"
197+
198+
legacy = True if args.v_spconv == "1" else False
199+
cfg_path = args.cfg_path
200+
201+
model = build_mmdet_model_from_cfg(cfg_path, args.ckpt_before) if args.framework == "mmdet3d" else build_opc_model_from_cfg(cfg_path)
202+
convert_model_weights(
203+
ckpt_before=args.ckpt_before,
204+
ckpt_after=args.ckpt_after,
205+
model=model,
206+
legacy=legacy)
207+
208+
209+
def convert_weights(ckpt_before: str, ckpt_after: str, cfg_path: str, v_spconv: int = 1, framework: str = "mmdet3d"):
210+
"""Convert the weights of a model from SpConv to TorchSparse.
211+
212+
:param ckpt_before: _description_
213+
:type ckpt_before: str
214+
:param ckpt_after: _description_
215+
:type ckpt_after: str
216+
:param cfg_path: _description_
217+
:type cfg_path: str
218+
:param v_spconv: _description_, defaults to 1
219+
:type v_spconv: int, optional
220+
:param framework: _description_, defaults to "mmdet3d"
221+
:type framework: str, optional
222+
"""
223+
224+
# Check the plugin argument
225+
assert framework in ['mmdet3d', 'openpc'], "plugin argument can only be mmdet3d or openpcdet"
226+
assert v_spconv in [1, 2], "v_spconv argument can only be 1 or 2"
227+
228+
legacy = True if v_spconv == 1 else False
229+
230+
model = build_mmdet_model_from_cfg(cfg_path, ckpt_before) if framework == "mmdet3d" else build_opc_model_from_cfg(cfg_path)
231+
convert_model_weights(
232+
ckpt_before=ckpt_before,
233+
ckpt_after=ckpt_after,
234+
model=model,
235+
legacy=legacy)
236+
237+
238+
if __name__ == "__main__":
239+
convert_weights_cmd()
240+
print("Conversion completed")

examples/mmdetection3d/README.md

+79
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# TorchSparse for MMDetection3D Plugin Demo
2+
3+
This tutorial demonstrates how to evaluate TorchSparse integrated MMDetection3D models. Follow the steps below to install dependencies, configure paths, convert model weights, and run the demo.
4+
5+
## Dependencies
6+
7+
1. **MMDetection3D Installation**: Follow the [MMDetection3D documentation](https://mmdetection3d.readthedocs.io/en/latest/get_started.html).
8+
2. **Dataset Preparation**: Pre-process the datasets as described [here](https://mmdetection3d.readthedocs.io/en/latest/user_guides/dataset_prepare.html).
9+
3. **TorchSparse Installation**: Install [TorchSparse](https://github.com/mit-han-lab/torchsparse).
10+
4. **Install TorchSparse Plugin for MMDetection3D**:
11+
1. Clone this repository.
12+
2. Navigate to `examples/mmdetection3d` and run `pip install -v -e .`.
13+
14+
## Notes
15+
16+
- For model evaluation, change the data root in the original MMDetection3D's model config to the full path of the corresponding dataset root.
17+
18+
## Steps
19+
20+
1. Install the dependencies.
21+
2. Specify the base paths and model registry.
22+
3. **IMPORTANT,** Activate the plugin: In `mmdetection3d/tools/test.py`, add `import ts_plugin` as the last import statement to activate the plugin.
23+
4. Run the evaluation.
24+
25+
## Supported Models
26+
27+
- SECOND
28+
- PV-RCNN
29+
- CenterPoint
30+
- Part-A2
31+
32+
## Convert Module Weights
33+
The dimensions of TorchSparse differ from SpConv, so parameter dimension conversion is required. You can use `convert_weights_cmd()` in converter.py as a command line tool or use `convert_weights()` as an API. Both functions have four parameters:
34+
35+
1. `ckpt_before`: Path to the input SpConv checkpoint file.
36+
2. `ckpt_after`: Path where the converted TorchSparse checkpoint will be saved.
37+
3. `cfg_path`: Path to the configuration mmdet3d file of the model.
38+
4. `v_spconv`: Version of SpConv used in the original model (1 or 2).
39+
5. `framework`: Choose between `'openpc'` and `'mmdet3d'`, default to `'mmdet3d'`.
40+
41+
These parameters allow the converter to locate the input model, specify the output location, understand the model's architecture, and apply the appropriate conversion method based for specific Sparse Conv layers.
42+
43+
Example conversion commands:
44+
```bash
45+
python examples/converter.py --ckpt_before ../mmdetection3d/models/PV-RCNN/pv_rcnn_8xb2-80e_kitti-3d-3class_20221117_234428-b384d22f.pth --cfg_path ../mmdetection3d/pv_rcnn/pv_rcnn_8xb2-80e_kitti-3d-3class.py --ckpt_after ./converted/PV-RCNN/pv_rcnn_8xb2-80e_kitti-3d-3class_20221117_234428-b384d22f.pth --v_spconv 1 --framework mmdet3d
46+
```
47+
48+
49+
# Run a demo
50+
In your Conda environment, run:
51+
```bash
52+
python <test_file_path> <cfg_path> <torchsparse_model_path> <cfg_options> --task lidar_det
53+
```
54+
55+
- `test_file_path`: The `tools/test.py` file in mmdet3d repository.
56+
- `cfg_path`: The path to the mmdet3d's model config for your model.
57+
- `torchsparse_model_path`: the path to the converted TorchSparse model checkpoint.
58+
- `cfg_options`: The plugin requires the use of MMDet3D cfg_options to tweak certain model layers to be the plugin layers. `cfg_options` examples are below:
59+
60+
## SECOND
61+
`cfg_options`:
62+
```bash
63+
"--cfg-options test_evaluator.pklfile_prefix=outputs/torchsparse/second --cfg-options model.middle_encoder.type=SparseEncoderTS"
64+
```
65+
66+
## PV-RCNN
67+
`cfg_options`:
68+
```bash
69+
"--cfg-options test_evaluator.pklfile_prefix=outputs/torchsparse/pv_rcnn --cfg-options model.middle_encoder.type=SparseEncoderTS --cfg-options model.points_encoder.type=VoxelSetAbstractionTS"
70+
```
71+
72+
### CenterPoint Voxel 0.1 Circular NMS
73+
74+
Update the path of the NuScenes dataset in the MMDetection3D dataset config `configs/_base_/datasets/nus-3d.py`.
75+
76+
`cfg_options`:
77+
```bash
78+
"--cfg-options model.pts_middle_encoder.type=SparseEncoderTS"
79+
```
+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
This folder contains the configs to carry out the demo in mmdetectino3d.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Default model conversion base folder for the demo. Please create the relative path to each specific model under this directory.

0 commit comments

Comments
 (0)