File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree
Original file line number | Diff line number | Diff line change
   2    2
   3    3      import torch
   4    4      import torch.nn as nn
   5         - import gin.torch
   6         -
   7         - from .net import Net
        5    + from torch.nn.attention import SDPBackend, sdpa_kernel
   8    6
   9    7
  10    8      class MHAPyTorchScaledDotProduct(nn.Module):
@@ -38,9 +36,8 @@ def forward(self, x):
  38   36              queries, keys, values = qkv
  39   37
  40   38              use_dropout = 0.0 if not self.training else self.dropout
  41         -         with torch.backends.cuda.sdp_kernel(
  42         -             enable_flash=True, enable_math=False, enable_mem_efficient=False
  43         -         ):
       39    +
       40    +         with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
  44   41                  context_vec = nn.functional.scaled_dot_product_attention(
  45   42                      queries,
  46   43                      keys,
Original file line number | Diff line number | Diff line change
   4    4      import gin
   5    5      import torch
   6    6      import torch.nn as nn
        7    + from torch.nn.attention import SDPBackend, sdpa_kernel
   7    8      from .common_former import DeepNorm
   8    9      from .rope import RotaryEmbedding
   9   10
@@ -74,9 +75,7 @@ def forward(self, x):
  74   75              use_dropout = 0.0 if not self.training else self.dropout
  75   76
  76   77              if self.use_flash_attention:
  77         -             with torch.backends.cuda.sdp_kernel(
  78         -                 enable_flash=True, enable_math=False, enable_mem_efficient=False
  79         -             ):
       78    +             with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
  80   79                      context_vec = nn.functional.scaled_dot_product_attention(
  81   80                          queries,
  82   81                          keys,
You can’t perform that action at this time.
0 commit comments