Skip to content

Commit 69c3b51

Browse files
committed
Upgrade sdpa_kernel
1 parent b8c6f59 commit 69c3b51

2 files changed

Lines changed: 5 additions & 9 deletions

File tree

src/nets/common_former.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,7 @@
22

33
import torch
44
import torch.nn as nn
5-
import gin.torch
6-
7-
from .net import Net
5+
from torch.nn.attention import SDPBackend, sdpa_kernel
86

97

108
class MHAPyTorchScaledDotProduct(nn.Module):
@@ -38,9 +36,8 @@ def forward(self, x):
3836
queries, keys, values = qkv
3937

4038
use_dropout = 0.0 if not self.training else self.dropout
41-
with torch.backends.cuda.sdp_kernel(
42-
enable_flash=True, enable_math=False, enable_mem_efficient=False
43-
):
39+
40+
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
4441
context_vec = nn.functional.scaled_dot_product_attention(
4542
queries,
4643
keys,

src/nets/conformer.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import gin
55
import torch
66
import torch.nn as nn
7+
from torch.nn.attention import SDPBackend, sdpa_kernel
78
from .common_former import DeepNorm
89
from .rope import RotaryEmbedding
910

@@ -74,9 +75,7 @@ def forward(self, x):
7475
use_dropout = 0.0 if not self.training else self.dropout
7576

7677
if self.use_flash_attention:
77-
with torch.backends.cuda.sdp_kernel(
78-
enable_flash=True, enable_math=False, enable_mem_efficient=False
79-
):
78+
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
8079
context_vec = nn.functional.scaled_dot_product_attention(
8180
queries,
8281
keys,

0 commit comments

Comments
 (0)