Prune PointPillars with mmdetection3d #5188
Description
I want to prune PointPillars in the mmdetection3d framework, but the model's forward takes more than one input, `def forward(self, img, img_metas, return_loss=True, **kwargs):`, which makes it unsuitable for speedup with a single dummy input: `ModelSpeedup(model, dummy_input=torch.rand([10, 3, 32, 32]).to(device), masks_file=masks).speedup_model()`.
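As far as I understand, ModelSpeedup traces the model by running it on `dummy_input`, so a forward that also needs `img_metas` and keyword arguments is hard to trace, while a module whose forward takes a single tensor is straightforward. A minimal sketch with a hypothetical `ToyBackbone` (not the real mmdetection3d backbone). The input shape is an assumption; the `dummy_input` passed to ModelSpeedup for the backbone would likewise need the backbone's pseudo-image shape from the config rather than `[10, 3, 32, 32]`.

```python
import torch
import torch.nn as nn


class ToyBackbone(nn.Module):
    """Hypothetical stand-in for model.backbone: one tensor in, a tuple of feature maps out."""

    def __init__(self, in_channels: int = 64):
        super().__init__()
        self.stage1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.stage2 = nn.Sequential(
            nn.Conv2d(64, 128, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )

    def forward(self, x: torch.Tensor):
        out1 = self.stage1(x)
        out2 = self.stage2(out1)
        return out1, out2


backbone = ToyBackbone().eval()
# Assumed pseudo-image shape [batch, channels, H, W]; replace with the real one.
dummy_input = torch.rand(1, 64, 64, 64)
# A single-tensor forward can be traced, unlike forward(img, img_metas, ...).
traced = torch.jit.trace(backbone, dummy_input)
feats = traced(dummy_input)
print([tuple(f.shape) for f in feats])  # [(1, 64, 32, 32), (1, 128, 16, 16)]
```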
So I prune and speed up only the backbone, `model.backbone`.
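Speedup turns masks into physically smaller layers. The pure-PyTorch sketch below shows that idea for one hypothetical conv → BN → conv chain; it is not NNI's implementation and assumes `groups == 1` and eval-mode BatchNorm. The helper `shrink_conv_bn_conv` is hypothetical. Removing output channels of the first conv forces the BN and the next conv's input channels to shrink as well, and the result should match the masked model numerically.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def shrink_conv_bn_conv(conv1: nn.Conv2d, bn1: nn.BatchNorm2d, conv2: nn.Conv2d, keep: torch.Tensor):
    """Rebuild conv1 -> bn1 -> conv2 without conv1's output channels where keep is False.

    Assumes groups == 1. bn1 and conv2's input channels are shrunk to match.
    """
    idx = torch.nonzero(keep).flatten()
    n = idx.numel()

    new_conv1 = nn.Conv2d(conv1.in_channels, n, conv1.kernel_size, conv1.stride,
                          conv1.padding, conv1.dilation, bias=conv1.bias is not None)
    new_conv1.weight.copy_(conv1.weight[idx])
    if conv1.bias is not None:
        new_conv1.bias.copy_(conv1.bias[idx])

    new_bn1 = nn.BatchNorm2d(n, eps=bn1.eps, momentum=bn1.momentum)
    new_bn1.weight.copy_(bn1.weight[idx])
    new_bn1.bias.copy_(bn1.bias[idx])
    new_bn1.running_mean.copy_(bn1.running_mean[idx])
    new_bn1.running_var.copy_(bn1.running_var[idx])

    new_conv2 = nn.Conv2d(n, conv2.out_channels, conv2.kernel_size, conv2.stride,
                          conv2.padding, conv2.dilation, bias=conv2.bias is not None)
    new_conv2.weight.copy_(conv2.weight[:, idx])  # drop the matching input channels
    if conv2.bias is not None:
        new_conv2.bias.copy_(conv2.bias)
    return new_conv1, new_bn1, new_conv2


torch.manual_seed(0)
conv1 = nn.Conv2d(16, 8, 3, padding=1)
bn1 = nn.BatchNorm2d(8)
conv2 = nn.Conv2d(8, 4, 3, padding=1)
keep = torch.tensor([True, False, True, True, False, True, True, True])

# Masked reference: zero the pruned channels (conv and BN) so they contribute nothing.
with torch.no_grad():
    conv1.weight[~keep] = 0
    conv1.bias[~keep] = 0
    bn1.weight[~keep] = 0
    bn1.bias[~keep] = 0
masked = nn.Sequential(conv1, bn1, nn.ReLU(), conv2).eval()

c1, b1, c2 = shrink_conv_bn_conv(conv1, bn1, conv2, keep)
small = nn.Sequential(c1, b1, nn.ReLU(), c2).eval()

x = torch.rand(2, 16, 8, 8)
print(torch.allclose(masked(x), small(x), atol=1e-5))
print(c1.out_channels, c2.in_channels)  # 6 6
```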
The steps are:
Step 1: load the original model.
Step 2: prune the model, speed it up, and save the sped-up model.
Step 3: load the sped-up model and fine-tune it.
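Steps 2 and 3 could look roughly like this, assuming the whole sped-up module is pickled with `torch.save` rather than saving only its `state_dict`. The pruned `nn.Sequential` here is a hypothetical stand-in for the real backbone. `weights_only` is a `torch.load` argument available since PyTorch 1.13, and recent versions default it to `True`, which refuses to unpickle full modules, hence the explicit `False`.

```python
import io

import torch
import torch.nn as nn

# Hypothetical sped-up backbone: its first layer was shrunk from 64 to 58 channels.
pruned_backbone = nn.Sequential(
    nn.Conv2d(64, 58, 3, padding=1, bias=False),
    nn.BatchNorm2d(58),
    nn.ReLU(),
)

# Step 2: save the whole module (structure + weights). A plain state_dict would
# not load into a model built from the original, unpruned config (shape mismatch).
buffer = io.BytesIO()  # stands in for a file such as "pruned_backbone.pth"
torch.save(pruned_backbone, buffer)

# Step 3: load it back; weights_only=False is required to unpickle a full module.
buffer.seek(0)
restored = torch.load(buffer, weights_only=False)
restored.train()  # switch to training mode for fine-tuning
print(restored[0].out_channels)  # 58
```

Pickling a full module stores class references, so the same class definitions must be importable when loading.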
I save the model in step 2 because, if I prune the model inside mmdetection3d, I cannot fine-tune it there immediately afterwards.
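One way step 3 might work, assuming the detector is rebuilt from its config as usual and the loaded pruned backbone is then assigned to `model.backbone` before fine-tuning starts. `ToyDetector` and all channel sizes are hypothetical. The sketch also shows a constraint worth checking: the neck still expects the original number of input channels, so the backbone's output layers must keep their width (for example, by excluding them from pruning).

```python
import torch
import torch.nn as nn


class ToyDetector(nn.Module):
    """Hypothetical detector with the backbone/neck split of an mmdetection3d model."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(64, 64, 3, padding=1), nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1), nn.ReLU(),
        )
        self.neck = nn.Conv2d(64, 32, 1)  # expects 64 input channels

    def forward(self, x):
        return self.neck(self.backbone(x))


# Rebuild the full model the usual way, then swap in the pruned backbone.
detector = ToyDetector()
pruned_backbone = nn.Sequential(                  # stands in for the module loaded from disk
    nn.Conv2d(64, 48, 3, padding=1), nn.ReLU(),   # inner width pruned 64 -> 48
    nn.Conv2d(48, 64, 3, padding=1), nn.ReLU(),   # last layer left unpruned
)
detector.backbone = pruned_backbone
out = detector(torch.rand(1, 64, 16, 16))
print(tuple(out.shape))  # (1, 32, 16, 16)
```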
The sparse ratio is 0.9 (cutting 10% of the weights with L1Norm), and accuracy drops by about 2%.
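For reference, L1-norm filter pruning ranks a conv layer's output filters by the sum of their absolute weights and masks the smallest ones. A minimal sketch, not NNI's `L1NormPruner` code; `l1_filter_mask` is hypothetical, and here `prune_ratio` is the fraction of filters removed (matching the "10%" above). Whether a pruner config's 0.9 means kept or removed depends on the library's convention and is worth double-checking.

```python
import torch
import torch.nn as nn


def l1_filter_mask(conv: nn.Conv2d, prune_ratio: float) -> torch.Tensor:
    """Return a boolean mask over output channels: False marks filters to prune.

    Filters are ranked by the L1 norm of their weights; the prune_ratio fraction
    with the smallest norms is masked out.
    """
    # L1 norm of each output filter: sum of |w| over (in_channels, kH, kW).
    norms = conv.weight.detach().abs().sum(dim=(1, 2, 3))
    num_pruned = int(prune_ratio * conv.out_channels)
    mask = torch.ones(conv.out_channels, dtype=torch.bool)
    if num_pruned > 0:
        pruned_idx = torch.argsort(norms)[:num_pruned]
        mask[pruned_idx] = False
    return mask


torch.manual_seed(0)
conv = nn.Conv2d(64, 64, 3, padding=1)
mask = l1_filter_mask(conv, prune_ratio=0.1)  # cut 10% of the filters
print(int((~mask).sum()))  # 6
```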
My question is: is this approach right?
PS: fine-tuning loss: