Describe the bug
Building the Video Swin Transformer in torch.bfloat16 and running it fails with the following error:
.........
File "/data4/Projects/video_captioning/my_project/experiments/modeling/swin_transformer.py", line 284, in forward
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
RuntimeError: expected scalar type BFloat16 but found Float
I have changed parts of this code in my project, but the corresponding line in swin_transformer.py of this repository is https://github.com/SwinTransformer/Video-Swin-Transformer/blob/master/mmaction/models/backbones/swin_transformer.py#L166,
and the function containing the bug is https://github.com/SwinTransformer/Video-Swin-Transformer/blob/master/mmaction/models/backbones/swin_transformer.py#L317.
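For context, a minimal sketch of the failure mode (tensor names are illustrative, not taken from the repository): PyTorch type promotion upcasts bfloat16 + float32 to float32, so once the float32 attn_mask is added to the bfloat16 attention scores, the later `attn @ v` mixes dtypes and raises.

```python
import torch

B_, N, C = 2, 4, 8
q = torch.randn(B_, N, C, dtype=torch.bfloat16)
v = torch.randn(B_, N, C, dtype=torch.bfloat16)

attn = q @ q.transpose(-2, -1)   # bfloat16 attention scores
mask = torch.zeros(B_, N, N)     # float32 by default, like compute_mask's img_mask
attn = attn + mask               # type promotion: the result is float32
print(attn.dtype)                # torch.float32

try:
    x = attn @ v                 # float32 @ bfloat16: dtypes no longer match
except RuntimeError as e:
    print(e)                     # "expected scalar type ..." as in the traceback
```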
Bug fix
One way to fix it is shown below:
- def compute_mask(D, H, W, window_size, shift_size, device):
+ def compute_mask(D, H, W, window_size, shift_size, device, dtype):
- img_mask = torch.zeros((1, D, H, W, 1), device=device) # 1 Dp Hp Wp 1
+ img_mask = torch.zeros((1, D, H, W, 1), device=device, dtype=dtype) # 1 Dp Hp Wp 1
Then change the line at https://github.com/SwinTransformer/Video-Swin-Transformer/blob/master/mmaction/models/backbones/swin_transformer.py#L405 to
attn_mask = compute_mask(Dp, Hp, Wp, window_size, shift_size, x.device, x.dtype)
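A minimal sketch showing that threading dtype through resolves the mismatch (the mask contents are simplified to zeros here; the real compute_mask fills in shifted-window indices, but the dtype handling is the point):

```python
import torch

def compute_mask(D, H, W, window_size, shift_size, device, dtype):
    # Simplified stand-in for the repository's helper: the key change is
    # creating img_mask in the activations' dtype instead of float32.
    img_mask = torch.zeros((1, D, H, W, 1), device=device, dtype=dtype)  # 1 Dp Hp Wp 1
    # ... window partitioning / index comparison elided ...
    return img_mask.reshape(1, -1)

x = torch.randn(2, 4, 8, dtype=torch.bfloat16)
attn = x @ x.transpose(-2, -1)                    # bfloat16 scores
attn_mask = compute_mask(1, 2, 2, (1, 2, 2), (0, 0, 0), x.device, x.dtype)
attn = attn + attn_mask                           # stays bfloat16, no promotion
out = attn @ x                                    # matmul now succeeds
print(out.dtype)                                  # torch.bfloat16
```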