Skip to content

Commit d7188dd

Browse files
committed
Select flash attention automatically when available
1 parent 58d7c8f commit d7188dd

2 files changed

Lines changed: 10 additions & 20 deletions

File tree

src/nets/common_former.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,11 @@ def forward(self, x):
3737

3838
use_dropout = 0.0 if not self.training else self.dropout
3939

40-
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
40+
sdp_backend = SDPBackend.DEFAULT
41+
if SDPBackend.is_supported(SDPBackend.FLASH_ATTENTION):
42+
sdp_backend = SDPBackend.FLASH_ATTENTION
43+
44+
with sdpa_kernel(sdp_backend):
4145
context_vec = nn.functional.scaled_dot_product_attention(
4246
queries,
4347
keys,

src/nets/conformer.py

Lines changed: 5 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,6 @@ def __init__(
1919
qkv_bias=False,
2020
use_rope=False,
2121
max_len=10000,
22-
use_flash_attention=True,
2322
):
2423
super().__init__()
2524

@@ -31,7 +30,6 @@ def __init__(
3130
self.d_in = d_in
3231
self.use_rope = use_rope
3332
self.rope_dim = self.head_dim
34-
self.use_flash_attention = use_flash_attention
3533

3634
self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)
3735
self.proj = nn.Linear(d_in, d_out)
@@ -74,17 +72,11 @@ def forward(self, x):
7472

7573
use_dropout = 0.0 if not self.training else self.dropout
7674

77-
if self.use_flash_attention:
78-
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
79-
context_vec = nn.functional.scaled_dot_product_attention(
80-
queries,
81-
keys,
82-
values,
83-
attn_mask=None,
84-
dropout_p=use_dropout,
85-
is_causal=True,
86-
)
87-
else:
75+
sdp_backend = SDPBackend.DEFAULT
76+
if SDPBackend.is_supported(SDPBackend.FLASH_ATTENTION):
77+
sdp_backend = SDPBackend.FLASH_ATTENTION
78+
79+
with sdpa_kernel(sdp_backend):
8880
context_vec = nn.functional.scaled_dot_product_attention(
8981
queries,
9082
keys,
@@ -290,15 +282,13 @@ def __init__(
290282
alpha=0.1,
291283
beta=0.1,
292284
use_rope=False,
293-
use_flash_attention=True,
294285
):
295286
super(ConformerBlock, self).__init__()
296287
self.feed_forward_residual_factor = feed_forward_residual_factor
297288
self.use_deepnorm = use_deepnorm
298289
self.alpha = alpha
299290
self.beta = beta
300291
self.use_rope = use_rope
301-
self.use_flash_attention = use_flash_attention
302292

303293
self.ff1 = FeedForwardBlock(embed_dim, feed_forward_expansion_factor, dropout)
304294
self.attention = MHAPyTorchScaledDotProduct(
@@ -307,7 +297,6 @@ def __init__(
307297
num_heads=num_heads,
308298
dropout=dropout,
309299
use_rope=use_rope,
310-
use_flash_attention=self.use_flash_attention,
311300
)
312301
self.conv_block = ConvBlock(embed_dim, conv_kernel_size, dropout)
313302
self.ff2 = FeedForwardBlock(embed_dim, feed_forward_expansion_factor, dropout)
@@ -399,7 +388,6 @@ def __init__(
399388
use_rope: bool,
400389
num_patches: int,
401390
patch_size: Tuple[int, int] | None = None,
402-
use_flash_attention: bool = True,
403391
):
404392
super(Conformer, self).__init__()
405393
self.embed_dim = embed_dim
@@ -414,7 +402,6 @@ def __init__(
414402
self.use_deepnorm = use_deepnorm
415403
self.use_rope = use_rope
416404
self.num_patches = num_patches
417-
self.use_flash_attention = use_flash_attention
418405

419406
self.input_dropout = nn.Dropout(input_dropout)
420407

@@ -437,7 +424,6 @@ def __init__(
437424
alpha=self.alpha_deepnorm,
438425
beta=self.beta_deepnorm,
439426
use_rope=self.use_rope,
440-
use_flash_attention=self.use_flash_attention,
441427
)
442428
for _ in range(depth)
443429
]

0 commit comments

Comments
 (0)