Commit 06a1f42
authored
Fix ViViT Transformer not passing use_flash_attn to Attention and duplicate mask reshape (#360)
Two related bugs in vivit.py:
1. Transformer.__init__ accepted use_flash_attn but never forwarded it to the
Attention modules it creates. Since Attention defaults to use_flash_attn=True,
setting use_flash_attn=False on ViViT had no effect on the factorized_encoder
variant's spatial and temporal transformers.
2. Attention.forward reshaped the mask from 2D to 4D before the flash/non-flash
branch (line 82), then attempted to reshape it again inside the non-flash
branch (line 92). When the non-flash code path is actually reached with a
mask, einops raises an error because the mask is already 4D.
These bugs masked each other: bug #1 prevented bug #2 from triggering because
the non-flash path was never taken even when requested.
Fix: pass use_flash_attn through to Attention in Transformer.__init__, and
remove the redundant second mask rearrange in the non-flash branch.1 parent 6ae6a3a commit 06a1f42
1 file changed
Lines changed: 1 addition & 2 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
89 | 89 | | |
90 | 90 | | |
91 | 91 | | |
92 | | - | |
93 | 92 | | |
94 | 93 | | |
95 | 94 | | |
| |||
109 | 108 | | |
110 | 109 | | |
111 | 110 | | |
112 | | - | |
| 111 | + | |
113 | 112 | | |
114 | 113 | | |
115 | 114 | | |
| |||
0 commit comments