@@ -19,7 +19,6 @@ def __init__(
1919 qkv_bias = False ,
2020 use_rope = False ,
2121 max_len = 10000 ,
22- use_flash_attention = True ,
2322 ):
2423 super ().__init__ ()
2524
@@ -31,7 +30,6 @@ def __init__(
3130 self .d_in = d_in
3231 self .use_rope = use_rope
3332 self .rope_dim = self .head_dim
34- self .use_flash_attention = use_flash_attention
3533
3634 self .qkv = nn .Linear (d_in , 3 * d_out , bias = qkv_bias )
3735 self .proj = nn .Linear (d_in , d_out )
@@ -74,17 +72,11 @@ def forward(self, x):
7472
7573 use_dropout = 0.0 if not self .training else self .dropout
7674
77- if self .use_flash_attention :
78- with sdpa_kernel (SDPBackend .FLASH_ATTENTION ):
79- context_vec = nn .functional .scaled_dot_product_attention (
80- queries ,
81- keys ,
82- values ,
83- attn_mask = None ,
84- dropout_p = use_dropout ,
85- is_causal = True ,
86- )
87- else :
75+ sdp_backend = SDPBackend .DEFAULT
76+ if SDPBackend .is_supported (SDPBackend .FLASH_ATTENTION ):
77+ sdp_backend = SDPBackend .FLASH_ATTENTION
78+
79+ with sdpa_kernel (sdp_backend ):
8880 context_vec = nn .functional .scaled_dot_product_attention (
8981 queries ,
9082 keys ,
@@ -290,15 +282,13 @@ def __init__(
290282 alpha = 0.1 ,
291283 beta = 0.1 ,
292284 use_rope = False ,
293- use_flash_attention = True ,
294285 ):
295286 super (ConformerBlock , self ).__init__ ()
296287 self .feed_forward_residual_factor = feed_forward_residual_factor
297288 self .use_deepnorm = use_deepnorm
298289 self .alpha = alpha
299290 self .beta = beta
300291 self .use_rope = use_rope
301- self .use_flash_attention = use_flash_attention
302292
303293 self .ff1 = FeedForwardBlock (embed_dim , feed_forward_expansion_factor , dropout )
304294 self .attention = MHAPyTorchScaledDotProduct (
@@ -307,7 +297,6 @@ def __init__(
307297 num_heads = num_heads ,
308298 dropout = dropout ,
309299 use_rope = use_rope ,
310- use_flash_attention = self .use_flash_attention ,
311300 )
312301 self .conv_block = ConvBlock (embed_dim , conv_kernel_size , dropout )
313302 self .ff2 = FeedForwardBlock (embed_dim , feed_forward_expansion_factor , dropout )
@@ -399,7 +388,6 @@ def __init__(
399388 use_rope : bool ,
400389 num_patches : int ,
401390 patch_size : Tuple [int , int ] | None = None ,
402- use_flash_attention : bool = True ,
403391 ):
404392 super (Conformer , self ).__init__ ()
405393 self .embed_dim = embed_dim
@@ -414,7 +402,6 @@ def __init__(
414402 self .use_deepnorm = use_deepnorm
415403 self .use_rope = use_rope
416404 self .num_patches = num_patches
417- self .use_flash_attention = use_flash_attention
418405
419406 self .input_dropout = nn .Dropout (input_dropout )
420407
@@ -437,7 +424,6 @@ def __init__(
437424 alpha = self .alpha_deepnorm ,
438425 beta = self .beta_deepnorm ,
439426 use_rope = self .use_rope ,
440- use_flash_attention = self .use_flash_attention ,
441427 )
442428 for _ in range (depth )
443429 ]
0 commit comments