Skip to content

Commit f219c89

Browse files
authored
[Cute,Sm100,Fwd] add MLA 64/512 with topk sparsity for MQA 128 heads (#2441)
* add mla 2cta with topk sparsity support
* add tma store O
* add clc option; performs worse than single tile
* enable clc for topk gather
* add producer tails
* add mla dsa to interface
* ruff format
* use tma store for varlen
* decouple sm stats from scale for smem
* add varlen tests
* credit monellz for kernel dump attributes utility
* add docstring for optional args, change default value of topk_indices_maybe_oob to None
* give default vals for new args in interface
* more rigorous tests; fix race condition on smem for rowmax
* add bandwidth calc and qv to benchmark script
* refactor interface per suggestions
* return more Nones for gradients
1 parent 09c93ea commit f219c89

13 files changed

Lines changed: 4702 additions & 119 deletions

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ repos:
1010
flash_bwd|
1111
flash_fwd|
1212
flash_fwd_sm100|
13+
flash_fwd_mla_sm100|
1314
interface|
1415
)\.py$
1516
- id: ruff-format

benchmarks/benchmark_attn.py

Lines changed: 146 additions & 29 deletions
Large diffs are not rendered by default.

flash_attn/cute/bench_utils.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,15 @@
1313

1414

1515
def flops(
16-
batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=False, window_size=(None, None)
16+
batch,
17+
nheads,
18+
seqlen_q,
19+
seqlen_k,
20+
headdim,
21+
headdim_v,
22+
causal=False,
23+
window_size=(None, None),
24+
has_qv=False,
1725
):
1826
if causal:
1927
avg_seqlen = (max(0, seqlen_k - seqlen_q) + seqlen_k) / 2
@@ -35,7 +43,37 @@ def flops(
3543
else torch.full_like(row_idx, seqlen_k - 1)
3644
)
3745
avg_seqlen = (col_right - col_left + 1).float().mean().item()
38-
return batch * nheads * 2 * seqlen_q * avg_seqlen * (headdim + headdim_v)
46+
eff_headdim = headdim + headdim_v if has_qv else headdim
47+
return batch * nheads * 2 * seqlen_q * avg_seqlen * (eff_headdim + headdim_v)
48+
49+
50+
# ── Bandwidth calculation ────────────────────────────────────────────────────
51+
52+
53+
def bandwidth_fwd_bytes(
    batch, nheads, nheads_kv, seqlen_q, seqlen_k, headdim, headdim_v, dtype_bytes=2, has_qv=False
):
    """HBM traffic for one attention forward pass: read Q,K,V + write O.

    Each term is an element count; the sum is scaled by ``dtype_bytes`` to
    give bytes. Every tensor is counted exactly once.

    Args:
        batch: Batch size.
        nheads: Number of query heads (used for the Q, Qv and O terms).
        nheads_kv: Number of key/value heads (used for the K and V terms).
        seqlen_q: Query sequence length.
        seqlen_k: Key/value sequence length.
        headdim: Head dimension used for the Q and K terms.
        headdim_v: Head dimension used for the V, O and Qv terms.
        dtype_bytes: Bytes per element applied to every term (default 2).
        has_qv: When true, add one extra query-side read of size
            ``batch * nheads * seqlen_q * headdim_v``.

    Returns:
        Total bytes: ``(q + qv + k + v + o) * dtype_bytes``.
    """
    q = batch * nheads * seqlen_q * headdim
    # The Qv term only contributes when the caller says it is present.
    qv = batch * nheads * seqlen_q * headdim_v if has_qv else 0
    k = batch * nheads_kv * seqlen_k * headdim
    v = batch * nheads_kv * seqlen_k * headdim_v
    o = batch * nheads * seqlen_q * headdim_v
    return (q + qv + k + v + o) * dtype_bytes
63+
64+
65+
def bandwidth_bwd_bytes(
    batch, nheads, nheads_kv, seqlen_q, seqlen_k, headdim, headdim_v, dtype_bytes=2
):
    """HBM traffic for one attention backward pass: read Q,K,V,dO + write dQ,dK,dV.

    Each term is an element count; the sum is scaled by ``dtype_bytes`` to
    give bytes. Each gradient is counted at the same size as the tensor it
    corresponds to.

    Args:
        batch: Batch size.
        nheads: Number of query heads (used for the Q, dO and dQ terms).
        nheads_kv: Number of key/value heads (used for the K, V, dK, dV terms).
        seqlen_q: Query sequence length.
        seqlen_k: Key/value sequence length.
        headdim: Head dimension used for the Q and K terms.
        headdim_v: Head dimension used for the V and dO terms.
        dtype_bytes: Bytes per element applied to every term (default 2).

    Returns:
        Total bytes: ``(q + k + v + do + dq + dk + dv) * dtype_bytes``.
    """
    q = batch * nheads * seqlen_q * headdim
    k = batch * nheads_kv * seqlen_k * headdim
    v = batch * nheads_kv * seqlen_k * headdim_v
    do = batch * nheads * seqlen_q * headdim_v
    # Gradients are written once each and match their forward tensors in size.
    dq = q
    dk = k
    dv = v
    return (q + k + v + do + dq + dk + dv) * dtype_bytes
3977

4078

4179
# ── Reference attention ─────────────────────────────────────────────────────

flash_attn/cute/cute_dsl_utils.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,38 @@ def get_broadcast_dims(tensor: torch.Tensor) -> Tuple[bool, ...]:
104104
patterns are not interchangeable.
105105
"""
106106
return tuple(s == 0 for s in tensor.stride())
107+
108+
109+
# credit: monellz (https://github.com/NVIDIA/cutlass/issues/2658#issuecomment-3630564264)
def dump_kernel_attributes(compiled_kernel):
    """Print the local-memory size and register count of a compiled kernel.

    The kernel's CUBIN is loaded as a CUDA library, the first enumerated
    kernel is resolved to a function handle, and two function attributes
    (``CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES`` and ``CU_FUNC_ATTRIBUTE_NUM_REGS``)
    are queried and printed to stdout. The library loaded here is unloaded
    again before the function returns.

    Args:
        compiled_kernel: Object exposing ``artifacts.CUBIN``; the CUBIN must
            have been retained at compile time (see the error message below).

    Raises:
        AssertionError: If ``compiled_kernel.artifacts.CUBIN`` is ``None``.
    """
    from cuda.bindings import driver
    from cutlass.utils import HardwareInfo
    import torch

    device_id = torch.cuda.current_device()
    hardware_info = HardwareInfo(device_id=device_id)
    check = hardware_info._checkCudaErrors

    cubin_data = compiled_kernel.artifacts.CUBIN
    if cubin_data is None:
        # Raised explicitly (not via `assert`) so the check survives `python -O`.
        raise AssertionError("cubin_data is None, need '--keep-cubin' option when compiling")

    cuda_library = check(driver.cuLibraryLoadData(cubin_data, None, None, 0, None, None, 0))
    try:
        kernels = check(driver.cuLibraryEnumerateKernels(1, cuda_library))
        kernel = check(driver.cuKernelGetFunction(kernels[0]))
        # more metrics: https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html#group__CUDA__EXEC_1g5e92a1b0d8d1b82cb00dcfb2de15961b
        local_size_bytes = check(
            driver.cuFuncGetAttribute(
                driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES,
                kernel,
            )
        )
        num_regs = check(
            driver.cuFuncGetAttribute(
                driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_NUM_REGS,
                kernel,
            )
        )
    finally:
        # Release the library handle loaded above so repeated calls do not
        # accumulate loaded copies of the CUBIN.
        check(driver.cuLibraryUnload(cuda_library))

    print("--- Kernel Info ---")
    print(f"local_size_bytes: {local_size_bytes}")
    print(f"num_regs: {num_regs}")
    print("--- End Kernel Info ---")

0 commit comments

Comments
 (0)