Description
Hello all. I have been trying to export mmaction's Video Swin Transformer model to ONNX. However, the script tools/deployment/pytorch2onnx.py provided in this repo was giving me the following errors:
error 1) Floating point exception (core dumped)
error 2) RuntimeError: input_shape_value == reshape_value || input_shape_value == 1 || reshape_value == 1INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/onnx/shape_type_inference.cpp":513, please report a bug to PyTorch. ONNX Expand input shape constraint not satisfied.
error 3) various other export failures
I also tried another repo's reimplementation, i.e. https://github.com/haofanwang/video-swin-transformer-pytorch, but ran into the same set of issues.
No luck. The default model ran inference correctly, but the ONNX export kept failing. Even torch.jit.script() failed.
So, after days of effort, I came up with a way to successfully export my trained Video Swin model to ONNX: copy the mmaction checkpoint weights into torchvision's SwinTransformer3d implementation and export that model instead. Here's the code. I hope this will help.
from torchvision.models.video.swin_transformer import SwinTransformer3d
import torch
from collections import OrderedDict
# Rebuild the Video Swin architecture in torchvision; these hyperparameters must match the mmaction training config.
torchvision_model = SwinTransformer3d(
    patch_size=[2, 4, 4],
    embed_dim=128,
    depths=[2, 2, 18, 2],
    num_heads=[4, 8, 16, 32],
    window_size=[16, 7, 7],
    mlp_ratio=4.0,
    dropout=0.0,
    attention_dropout=0.0,
    stochastic_depth_prob=0.1,
    num_classes=5)
# Load the trained mmaction checkpoint.
mmaction_weights = torch.load('../dl_model_ckpt_swin/frames/swin_last.pth')
assert len(torchvision_model.state_dict()) == len(mmaction_weights['state_dict']), \
    "mmaction video-swin weights' length doesn't match the torchvision video-swin model's architecture"
# print(torchvision_model)
############################
######## printing torchvision's swin state_dict without loading the checkpoint
# for k, key in enumerate(torchvision_model.state_dict()):
#     print(key, torchvision_model.state_dict()[key].shape)
# print('*' * 50)
# print()
######## printing mmaction's swin checkpoint state_dict
# for k, key in enumerate(mmaction_weights['state_dict']):
#     print(key, mmaction_weights['state_dict'][key].shape)
############################
########## asserting shapes of the two state dicts (keys are compared in order)
torchvision_model_keys = list(torchvision_model.state_dict())
mmaction_weight_keys = list(mmaction_weights['state_dict'])
for i in range(len(torchvision_model_keys)):
    shape_1 = torchvision_model.state_dict()[torchvision_model_keys[i]].shape
    shape_2 = mmaction_weights['state_dict'][mmaction_weight_keys[i]].shape
    if shape_1 != shape_2:
        print('shapes not matching')
        break
print('done')
############################ copying the actual weight values into the torchvision swin
new_torchvision_state_dict = OrderedDict()
for i in range(len(torchvision_model_keys)):
    new_torchvision_state_dict[torchvision_model_keys[i]] = mmaction_weights['state_dict'][mmaction_weight_keys[i]]
torchvision_model.load_state_dict(new_torchvision_state_dict)
print('done')
# Put the model in eval mode and run a sanity-check forward pass before exporting.
torchvision_model.eval()
input_shape = [1, 3, 8, 224, 224]
input_tensor = torch.randn(input_shape)
a = torchvision_model(input_tensor)
# torch.jit.script(torchvision_model, (input_tensor))
# torchvision_model = torch.compile(torchvision_model)
torch.onnx.export(
    torchvision_model,
    input_tensor,
    'video_swin.onnx',
    export_params=True,
    keep_initializers_as_inputs=True,
    verbose=True,
    opset_version=15)
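To check that the exported file actually behaves like the PyTorch model, here is a minimal verification sketch. It assumes onnxruntime and numpy are installed, and the 1e-3 tolerance is just an assumed threshold, not something from the original workflow:

import numpy as np
import onnxruntime as ort

# Run the PyTorch model and the exported ONNX model on the same random clip.
dummy = torch.randn(1, 3, 8, 224, 224)
with torch.no_grad():
    torch_out = torchvision_model(dummy).numpy()

session = ort.InferenceSession('video_swin.onnx', providers=['CPUExecutionProvider'])
input_name = session.get_inputs()[0].name
onnx_out = session.run(None, {input_name: dummy.numpy()})[0]

# The two outputs should agree within a small tolerance (1e-3 is an assumed value).
print('max abs diff:', np.abs(torch_out - onnx_out).max())
assert np.allclose(torch_out, onnx_out, atol=1e-3)

Note that because the export above doesn't pass dynamic_axes, the resulting graph expects exactly the [1, 3, 8, 224, 224] input shape used during tracing.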