
Steps to convert mmaction's video-swin-transformer to ONNX successfully #89

Description

@gigasurgeon

Hello all. I have been trying to export mmaction's video-swin transformer model to ONNX. However, the script tools/deployment/pytorch2onnx.py provided in this repo was giving me the following errors:

error 1) Floating point exception (core dumped)
error 2) RuntimeError: input_shape_value == reshape_value || input_shape_value == 1 || reshape_value == 1INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/onnx/shape_type_inference.cpp":513, please report a bug to PyTorch. ONNX Expand input shape constraint not satisfied.
error 3) other issues

I also tried another repo's reimplementation, https://github.com/haofanwang/video-swin-transformer-pytorch, but hit the same set of issues.
No luck. The default model ran inference just fine, but the ONNX export kept failing, and even torch.jit.script() failed.
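For context, the failing path boils down to roughly the following (a sketch from memory of what tools/deployment/pytorch2onnx.py does under the hood; the config path is a placeholder and exact shapes/flags may differ for your setup):

import torch
from mmaction.apis import init_recognizer

config = 'configs/recognition/swin/my_swin_config.py'      # placeholder
checkpoint = '../dl_model_ckpt_swin/frames/swin_last.pth'

model = init_recognizer(config, checkpoint, device='cpu')
model.forward = model.forward_dummy  # bypass the data-dict train/test interface

# (batch, num_crops, C, T, H, W) -- forward_dummy flattens the leading dims
dummy_input = torch.randn(1, 1, 3, 8, 224, 224)

# this is the call that crashed for me with the errors listed above
torch.onnx.export(model, dummy_input, 'swin_mmaction.onnx', opset_version=11)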

So, after days of effort, I came up with a way to successfully export my trained video-swin model to ONNX: build the equivalent torchvision SwinTransformer3d and copy the mmaction weights into it. Here's the code. I hope this helps.

from torchvision.models.video.swin_transformer import SwinTransformer3d
import torch
from collections import OrderedDict


# Swin-B hyper-parameters; these must match the config the mmaction
# checkpoint was trained with, or the state_dict shapes won't line up
torchvision_model = SwinTransformer3d(
        patch_size=[2, 4, 4],
        embed_dim=128,
        depths=[2, 2, 18, 2],
        num_heads=[4, 8, 16, 32],
        window_size=[16, 7, 7],
        mlp_ratio=4.0,
        dropout=0.0,
        attention_dropout=0.0,
        stochastic_depth_prob=0.1,
        num_classes=5)


# checkpoint produced by mmaction training
mmaction_weights = torch.load('../dl_model_ckpt_swin/frames/swin_last.pth')

assert len(torchvision_model.state_dict()) == len(mmaction_weights['state_dict']), "mmaction video-swin checkpoint's parameter count doesn't match torchvision's video-swin architecture"

# print(torchvision_model)

############################
######## (optional) print both state_dicts to eyeball the key ordering

# for k in torchvision_model.state_dict():
#     print(k, torchvision_model.state_dict()[k].shape)

# print('*' * 50)

# for k in mmaction_weights['state_dict']:
#     print(k, mmaction_weights['state_dict'][k].shape)

############################


########## asserting shapes of the two state_dicts, position by position

torchvision_sd = torchvision_model.state_dict()
torchvision_model_keys = list(torchvision_sd)
mmaction_weight_keys = list(mmaction_weights['state_dict'])

for tv_key, mm_key in zip(torchvision_model_keys, mmaction_weight_keys):
    shape_1 = torchvision_sd[tv_key].shape
    shape_2 = mmaction_weights['state_dict'][mm_key].shape

    if shape_1 != shape_2:
        print(f'shape mismatch: {tv_key} {shape_1} vs {mm_key} {shape_2}')
        break
else:
    print('all shapes match')


############################ copying the actual weight values into the
############################ torchvision swin, matching keys by position

new_torchvision_state_dict = OrderedDict()

for tv_key, mm_key in zip(torchvision_model_keys, mmaction_weight_keys):
    new_torchvision_state_dict[tv_key] = mmaction_weights['state_dict'][mm_key]

torchvision_model.load_state_dict(new_torchvision_state_dict)


# (optional) spot-check one tensor to confirm the copied values made it in:
# print(torchvision_model.state_dict()[torchvision_model_keys[-1]])

print('weights loaded')

input_shape = [1, 3, 8, 224, 224]  # (batch, channels, frames, height, width)
input_tensor = torch.randn(input_shape)

torchvision_model.eval()  # disable dropout/stochastic depth before tracing
torch_out = torchvision_model(input_tensor)  # sanity-check the forward pass
# torch.jit.script(torchvision_model)
# torchvision_model = torch.compile(torchvision_model)


torch.onnx.export(
        torchvision_model,
        input_tensor,
        'video_swin.onnx',
        export_params=True,
        keep_initializers_as_inputs=True,
        verbose=True,
        opset_version=15)
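
To sanity-check the exported graph, you can run it through onnxruntime and compare with the eager PyTorch output (a minimal sketch; assumes onnxruntime is installed and reuses input_tensor / torch_out from above):

import numpy as np
import onnxruntime as ort

# run the exported model on the same random input
sess = ort.InferenceSession('video_swin.onnx', providers=['CPUExecutionProvider'])
input_name = sess.get_inputs()[0].name
onnx_out = sess.run(None, {input_name: input_tensor.numpy()})[0]

# outputs should agree with PyTorch up to float tolerance
print('max abs diff:', np.max(np.abs(onnx_out - torch_out.detach().numpy())))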
