Skip to content

Commit 06a1f42

Browse files
authored
Fix ViViT Transformer not passing use_flash_attn to Attention and duplicate mask reshape (#360)
Two related bugs in vivit.py: 1. Transformer.__init__ accepted use_flash_attn but never forwarded it to the Attention modules it creates. Since Attention defaults to use_flash_attn=True, setting use_flash_attn=False on ViViT had no effect on the factorized_encoder variant's spatial and temporal transformers. 2. Attention.forward reshaped the mask from 2D to 4D before the flash/non-flash branch (line 82), then attempted to reshape it again inside the non-flash branch (line 92). When the non-flash code path is actually reached with a mask, einops raises an error because the mask is already 4D. These bugs masked each other: bug #1 prevented bug #2 from triggering because the non-flash path was never taken even when requested. Fix: pass use_flash_attn through to Attention in Transformer.__init__, and remove the redundant second mask rearrange in the non-flash branch.
1 parent 6ae6a3a commit 06a1f42

1 file changed

Lines changed: 1 addition & 2 deletions

File tree

vit_pytorch/vivit.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,6 @@ def forward(self, x, mask = None):
8989
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
9090

9191
if exists(mask):
92-
mask = rearrange(mask, 'b j -> b 1 1 j')
9392
dots = dots.masked_fill(~mask, -torch.finfo(dots.dtype).max)
9493

9594
attn = self.attend(dots)
@@ -109,7 +108,7 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., use_flash
109108
self.layers = ModuleList([])
110109
for _ in range(depth):
111110
self.layers.append(nn.ModuleList([
112-
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
111+
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_flash_attn = use_flash_attn),
113112
FeedForward(dim, mlp_dim, dropout = dropout)
114113
]))
115114

0 commit comments

Comments
 (0)