src/natten/profiler.py (1 addition, 1 deletion)
@@ -55,7 +55,7 @@
SUPPORTED_DTYPES = ["fp32", "bf16", "fp16"]
NATTEN_BACKENDS = ["cutlass-fna", "blackwell-fna", "hopper-fna", "flex-fna"]
NATTEN_FMHA_BACKENDS = ["cutlass-fmha", "blackwell-fmha", "hopper-fmha", "flex-fmha"]
SDPA_BACKENDS = ["xformers", "cudnn", "fav2"]
SDPA_BACKENDS = ["xformers", "cudnn", "fav2", "fa"]

SCHEDULE_MAP = {
"non": KernelSchedule.NonPersistent,
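Note: this hunk only registers `fa` as a selectable backend string; unlike the existing entries, it is dispatched to the standalone flash-attn package rather than through torch's SDPA backends (see ops.py below). A minimal validation sketch against this list; the resolve_backend helper is hypothetical and not part of this diff:

SDPA_BACKENDS = ["xformers", "cudnn", "fav2", "fa"]

def resolve_backend(name: str) -> str:
    # Hypothetical helper: normalize a user-supplied backend string and fail early
    # on anything the profiler does not recognize.
    name = name.strip().lower()
    if name not in SDPA_BACKENDS:
        raise ValueError(f"Unknown SDPA backend {name!r}; expected one of {SDPA_BACKENDS}.")
    return name

assert resolve_backend(" FA ") == "fa"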
src/natten/profiling_utils/ops.py (14 additions, 0 deletions)
@@ -27,10 +27,24 @@

from torch.nn.functional import scaled_dot_product_attention

try:
    from flash_attn import flash_attn_func
    HAS_FLASH_ATTN = True
except ImportError:
    # An exception raised inside an except handler is not caught by a sibling
    # except clause of the same try, so the fallback import needs its own try/except.
    try:
        from flash_attn_interface import flash_attn_func
        HAS_FLASH_ATTN = True
    except ImportError:
        HAS_FLASH_ATTN = False

def sdpa(q: Tensor, k: Tensor, v: Tensor, backend: str) -> Tensor:
    backends = []

    if backend == "fa":
        if not HAS_FLASH_ATTN:
            raise ValueError("Please install flash-attention before using the `fa` backend.")

        return flash_attn_func(q, k, v)

    if backend == "xformers":
        backends = [SDPBackend.EFFICIENT_ATTENTION]

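Note on the nested try in the import fallback above: in Python, an exception raised inside an except handler is not caught by a sibling except clause of the same try statement, so a flat try / except ImportError / except would let a failed flash_attn_interface import escape instead of setting HAS_FLASH_ATTN = False. A minimal, self-contained demonstration of the difference:

def flat_fallback():
    try:
        raise ImportError("primary import failed")
    except ImportError:
        # This raise escapes the whole try statement; the clause below never runs.
        raise ImportError("fallback import failed")
    except Exception:
        return "never reached"

def nested_fallback():
    try:
        raise ImportError("primary import failed")
    except ImportError:
        try:
            raise ImportError("fallback import failed")
        except ImportError:
            return "handled"

print(nested_fallback())  # handled

try:
    flat_fallback()
except ImportError as err:
    print(f"escaped: {err}")  # escaped: fallback import failed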
src/natten/profiling_utils/profiling.py (18 additions, 1 deletion)
@@ -435,6 +435,18 @@ def profile_na_with_torch(

    return out

def check_fa_installed():
    try:
        import flash_attn  # noqa: F401
        logger.info("Using FAv2.")
    except ImportError:
        # A failed fallback import is not caught by a sibling except clause,
        # so the FAv3 check needs its own try/except.
        try:
            import flash_attn_interface  # noqa: F401
            logger.info("Using FAv3.")
        except ImportError:
            raise ValueError("Please install flash-attention before using the `fa` backend.")


def _profile_fmha_with_torch(
    problem: Problem,
@@ -448,8 +460,13 @@ def _profile_fmha_with_torch(
        not problem.has_additional_kv
    ), "Profiling SDPA with additional KV is not supported."

    is_flash_attn = backend == "fa"

    if is_flash_attn:
        check_fa_installed()

    query, key, value, d_out, additional_kv = init_tensors(
        problem, flatten_sequence=True, heads_last=False  # torch SDPA is heads first :(
        problem, flatten_sequence=True, heads_last=is_flash_attn  # torch SDPA is heads first; flash-attn is heads last
    )

    def run_ops(query, key, value, d_out, backend):
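Note: the heads_last=is_flash_attn switch reflects the layout difference between the two paths. flash_attn_func takes tensors shaped (batch, seqlen, heads, head_dim), while torch.nn.functional.scaled_dot_product_attention expects the heads dimension before the sequence dimension, (batch, heads, seqlen, head_dim). A small stand-in sketch of the two layouts; make_qkv below is illustrative only and is not the real init_tensors:

import torch

def make_qkv(batch, heads, seqlen, head_dim, heads_last, device="cuda", dtype=torch.float16):
    # heads_last=True  -> (batch, seqlen, heads, head_dim): the layout flash_attn_func takes.
    # heads_last=False -> (batch, heads, seqlen, head_dim): the layout torch SDPA takes.
    shape = (batch, seqlen, heads, head_dim) if heads_last else (batch, heads, seqlen, head_dim)
    return tuple(torch.randn(shape, device=device, dtype=dtype) for _ in range(3))

q, k, v = make_qkv(2, 8, 1024, 64, heads_last=True)      # inputs for the `fa` path
q2, k2, v2 = make_qkv(2, 8, 1024, 64, heads_last=False)  # inputs for torch SDPA backends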