2 changes: 1 addition & 1 deletion src/natten/profiler.py
@@ -55,7 +55,7 @@
 SUPPORTED_DTYPES = ["fp32", "bf16", "fp16"]
 NATTEN_BACKENDS = ["cutlass-fna", "blackwell-fna", "hopper-fna", "flex-fna"]
 NATTEN_FMHA_BACKENDS = ["cutlass-fmha", "blackwell-fmha", "hopper-fmha", "flex-fmha"]
-SDPA_BACKENDS = ["xformers", "cudnn", "fav2"]
+SDPA_BACKENDS = ["xformers", "cudnn", "fav2", "fa"]
 
 SCHEDULE_MAP = {
     "non": KernelSchedule.NonPersistent,
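For orientation, the sketch below spells out how these SDPA backend strings presumably map onto PyTorch's SDPBackend enum. Only the "xformers" mapping is confirmed by the diff further down; the "cudnn" and "fav2" entries are assumptions, and "fa" is the new entry that bypasses torch SDPA and dispatches to the external flash-attn package instead.

# Assumed mapping between profiler backend strings and torch SDPA backends.
# Only "xformers" -> EFFICIENT_ATTENTION is confirmed by this PR's diff;
# "cudnn" and "fav2" are assumptions, and "fa" bypasses torch SDPA entirely.
from torch.nn.attention import SDPBackend

ASSUMED_SDPA_BACKEND_MAP = {
    "xformers": SDPBackend.EFFICIENT_ATTENTION,
    "cudnn": SDPBackend.CUDNN_ATTENTION,
    "fav2": SDPBackend.FLASH_ATTENTION,
    # "fa": external flash-attn package, handled in ops.py (see below)
}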
17 changes: 15 additions & 2 deletions src/natten/profiling_utils/ops.py
@@ -27,10 +27,23 @@
 
 from torch.nn.functional import scaled_dot_product_attention
 
-
-def sdpa(q: Tensor, k: Tensor, v: Tensor, backend: str) -> Tensor:
+HAS_FLASH_ATTN = False
+try:
+    from flash_attn import flash_attn_func
+    HAS_FLASH_ATTN = True
+except ImportError:
+    from flash_attn_interface import flash_attn_func
+    HAS_FLASH_ATTN = True
+
+def fmha(q: Tensor, k: Tensor, v: Tensor, backend: str) -> Tensor:
     backends = []
 
+    if backend == "fa":
+        if not HAS_FLASH_ATTN:
+            raise ValueError("Please install flash-attention before using `fa` backend.")
+
+        return flash_attn_func(q, k, v)
+
     if backend == "xformers":
         backends = [SDPBackend.EFFICIENT_ATTENTION]
 
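One caveat about the import guard above: if neither flash_attn nor flash_attn_interface is importable, the fallback import in the except branch raises ImportError at module import time, so HAS_FLASH_ATTN never stays False and the "Please install flash-attention" error can never be reached. A minimal sketch of a more defensive guard, not part of this PR:

# Sketch only (not part of the PR): guard both imports so the module still
# loads when no flash-attention package is installed.
HAS_FLASH_ATTN = False
flash_attn_func = None

try:
    from flash_attn import flash_attn_func

    HAS_FLASH_ATTN = True
except ImportError:
    try:
        from flash_attn_interface import flash_attn_func

        HAS_FLASH_ATTN = True
    except ImportError:
        # Keep HAS_FLASH_ATTN = False; the `fa` backend raises a clear error later.
        pass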
10 changes: 7 additions & 3 deletions src/natten/profiling_utils/profiling.py
@@ -34,7 +34,7 @@
 from natten.utils.checks import check_all_args
 
 from .formatting import convert_to_natten_profiler_ops, Result
-from .ops import sdpa
+from .ops import fmha
 
 from .problem import Problem
 
@@ -448,15 +448,19 @@ def _profile_fmha_with_torch(
         not problem.has_additional_kv
     ), "Profiling SDPA with additional KV is not supported."
 
+    is_flash_attn = backend == "fa"
+
+    # PyTorch SDPA expects heads first, like it's still 2021. Flash expects heads last, like NATTEN.
+    heads_last_layout = is_flash_attn
     query, key, value, d_out, additional_kv = init_tensors(
-        problem, flatten_sequence=True, heads_last=False  # torch SDPA is heads first :(
+        problem, flatten_sequence=True, heads_last=heads_last_layout
     )
 
     def run_ops(query, key, value, d_out, backend):
         query.requires_grad = not disable_backward
         key.requires_grad = not disable_backward
         value.requires_grad = not disable_backward
-        out = sdpa(query, key, value, backend=backend)
+        out = fmha(query, key, value, backend=backend)
         if not disable_backward:
             out.backward(d_out)
 
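For reference, a small standalone illustration of the layout difference the comment above refers to (this is not NATTEN's init_tensors): torch's scaled_dot_product_attention takes heads-first tensors of shape [batch, heads, seqlen, head_dim], while flash_attn_func takes heads-last tensors of shape [batch, seqlen, heads, head_dim], so moving between the two conventions is a single transpose.

# Standalone layout illustration; not NATTEN code.
import torch

batch, heads, seqlen, head_dim = 2, 8, 1024, 64

# Heads-first layout, as torch.nn.functional.scaled_dot_product_attention expects.
q_heads_first = torch.randn(batch, heads, seqlen, head_dim)

# Heads-last layout, as flash_attn_func (and NATTEN) expect.
q_heads_last = q_heads_first.transpose(1, 2).contiguous()

assert q_heads_first.shape == (batch, heads, seqlen, head_dim)
assert q_heads_last.shape == (batch, seqlen, heads, head_dim)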