Skip to content
434 changes: 434 additions & 0 deletions flash_attn/cute/benchmark_flash_attention_fp8.py

Large diffs are not rendered by default.

52 changes: 37 additions & 15 deletions flash_attn/cute/blackwell_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,24 @@
from flash_attn.cute.utils import parse_swizzle_from_pointer


def _tcgen05_mma_kind(op: cute.nvgpu.tcgen05.mma.MmaOp) -> str:
    """Return the PTX ``kind::`` suffix string for a tcgen05 MMA op.

    The op instance is matched against known tcgen05 MMA op classes in a
    fixed order (mirroring the original isinstance chain) and the matching
    PTX kind string is returned, e.g. ``"f16"`` for ``MmaF16BF16Op``.

    Raises:
        TypeError: if ``op`` is not an instance of any supported op class.
    """
    # (op class, PTX kind) pairs, checked in order — order matters if any
    # of these classes are related by inheritance.
    kind_table = (
        (tcgen05.mma.MmaF16BF16Op, "f16"),
        (tcgen05.mma.MmaTF32Op, "tf32"),
        (tcgen05.mma.MmaI8Op, "i8"),
        (tcgen05.mma.MmaFP8Op, "f8f6f4"),
        (tcgen05.mma.MmaMXF8Op, "mxf8f6f4"),
        (tcgen05.mma.MmaMXF4Op, "mxf4"),
        (tcgen05.mma.MmaMXF4NVF4Op, "mxf4nvf4"),
    )
    for op_cls, kind in kind_table:
        if isinstance(op, op_cls):
            return kind
    raise TypeError(f"Unsupported tcgen05 MMA op kind: {type(op).__name__}")


@cute.jit
def gemm_w_idx(
tiled_mma: cute.TiledMma,
Expand Down Expand Up @@ -96,6 +114,7 @@ def gemm_ptx(
sA_layout = sA.layout if sA is not None else None
sB_layout = sB.layout
idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op))
kind = _tcgen05_mma_kind(op)
if const_expr(not is_ts):
sA_swizzle = parse_swizzle_from_pointer(sA.iterator)
smem_desc_base_a: int = const_expr(
Expand Down Expand Up @@ -165,7 +184,7 @@ def gemm_ptx(
f"mov.b64 smem_desc_a, {{$1, {hex(smem_desc_a_hi)}}};\n\t"
f"mov.b64 smem_desc_b, {{$2, {hex(smem_desc_b_hi)}}};\n\t"
"setp.ne.b32 p, $3, 0;\n\t"
f"tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, p;\n\t"
f"tcgen05.mma.cta_group::1.kind::{kind} [$0], smem_desc_a, smem_desc_b, idesc, p;\n\t"
"}\n",
"r,r,r,r",
has_side_effects=True,
Expand All @@ -186,7 +205,7 @@ def gemm_ptx(
".reg .b64 smem_desc_b;\n\t"
f"mov.b64 smem_desc_b, {{$2, {hex(smem_desc_b_hi)}}};\n\t"
"setp.ne.b32 p, $3, 0;\n\t"
f"tcgen05.mma.cta_group::1.kind::f16 [$0], [$1], smem_desc_b, {hex(idesc)}, p;\n\t"
f"tcgen05.mma.cta_group::1.kind::{kind} [$0], [$1], smem_desc_b, {hex(idesc)}, p;\n\t"
"}\n",
"r,r,r,r",
has_side_effects=True,
Expand All @@ -211,6 +230,7 @@ def gemm_ptx_loop(
sA_layout = sA.layout if sA is not None else tCrA.layout
sB_layout = sB.layout
idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op))
kind = _tcgen05_mma_kind(op)
if const_expr(not is_ts):
sA_swizzle = parse_swizzle_from_pointer(sA.iterator)
smem_desc_base_a: int = const_expr(
Expand Down Expand Up @@ -298,14 +318,14 @@ def gemm_ptx_loop(
f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
"setp.ne.b32 p, $3, 0;\n\t"
f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t"
f"@leader_thread tcgen05.mma.cta_group::1.kind::{kind} [$0], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t"
+ "".join(
(
f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t"
f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, 1;\n\t"
f"@leader_thread tcgen05.mma.cta_group::1.kind::{kind} [$0], smem_desc_a, smem_desc_b, idesc, 1;\n\t"
)
for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2]))
)
Expand Down Expand Up @@ -339,14 +359,14 @@ def gemm_ptx_loop(
f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
"setp.ne.b32 p, $3, 0;\n\t"
f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t"
f"@leader_thread tcgen05.mma.cta_group::1.kind::{kind} [$0], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t"
+ "".join(
(
# f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t"
f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
# f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, 1;\n\t"
f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t"
f"@leader_thread tcgen05.mma.cta_group::1.kind::{kind} [$0], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t"
)
for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2]))
)
Expand Down Expand Up @@ -380,6 +400,7 @@ def gemm_ptx_partial(
sA_layout = sA.layout if sA is not None else tCrA.layout
sB_layout = sB.layout
idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op))
kind = _tcgen05_mma_kind(op)
if const_expr(not is_ts):
sA_swizzle = parse_swizzle_from_pointer(sA.iterator)
smem_desc_base_a: int = const_expr(
Expand Down Expand Up @@ -463,7 +484,7 @@ def gemm_ptx_partial(
f"mov.b64 smem_desc_a, {{smem_desc_a_lo_start, smem_desc_a_hi}};\n\t"
f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t"
"setp.ne.b32 p, $2, 0;\n\t"
f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t"
f"@leader_thread tcgen05.mma.cta_group::1.kind::{kind} [tmem_acc], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t"
+ "".join(
(
# f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t"
Expand All @@ -472,7 +493,7 @@ def gemm_ptx_partial(
f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t"
f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, 1;\n\t"
f"@leader_thread tcgen05.mma.cta_group::1.kind::{kind} [tmem_acc], smem_desc_a, smem_desc_b, idesc, 1;\n\t"
)
for k in range(1, cute.size(tCrA.shape[2]))
)
Expand Down Expand Up @@ -536,15 +557,15 @@ def gemm_ptx_partial(
f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t"
"setp.ne.b32 p, $2, 0;\n\t"
f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t"
f"@leader_thread tcgen05.mma.cta_group::1.kind::{kind} [tmem_acc], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t"
+ "".join(
(
# f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t"
# f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t"
f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
# f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, 1;\n\t"
f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t"
f"@leader_thread tcgen05.mma.cta_group::1.kind::{kind} [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t"
)
for k in range(
1,
Expand All @@ -559,7 +580,7 @@ def gemm_ptx_partial(
(
f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t"
f"@leader_thread tcgen05.mma.cta_group::1.kind::{kind} [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t"
)
for k in range(cute.size(tCrA.shape[2]) // 4 * 3, cute.size(tCrA.shape[2]))
)
Expand Down Expand Up @@ -597,6 +618,7 @@ def gemm_ptx_partial1(
assert sA_layout is not None, "sA_layout must be provided when a_src is not TMEM"
assert sA_swizzle is not None, "sA_swizzle must be provided when a_src is not TMEM"
idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op))
kind = _tcgen05_mma_kind(op)
if const_expr(not is_ts):
smem_desc_base_a: int = const_expr(
sm100_desc.make_smem_desc_base(
Expand Down Expand Up @@ -690,14 +712,14 @@ def gemm_ptx_partial1(
f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
"setp.ne.b32 p, $4, 0;\n\t"
f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, {pred_str};\n\t"
f"@leader_thread tcgen05.mma.cta_group::1.kind::{kind} [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, {pred_str};\n\t"
+ "".join(
(
f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t"
f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, 1;\n\t"
f"@leader_thread tcgen05.mma.cta_group::1.kind::{kind} [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, 1;\n\t"
)
for k in range(1, cute.size(tCrA.shape[2]))
)
Expand Down Expand Up @@ -735,13 +757,13 @@ def gemm_ptx_partial1(
f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
"setp.ne.b32 p, $3, 0;\n\t"
f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {{$4, $5, $6, $7}}, {pred_str};\n\t"
f"@leader_thread tcgen05.mma.cta_group::1.kind::{kind} [$0], [tmem_a], smem_desc_b, idesc, {{$4, $5, $6, $7}}, {pred_str};\n\t"
+ "".join(
(
f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t"
f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {{$4, $5, $6, $7}}, 1;\n\t"
f"@leader_thread tcgen05.mma.cta_group::1.kind::{kind} [$0], [tmem_a], smem_desc_b, idesc, {{$4, $5, $6, $7}}, 1;\n\t"
)
for k in range(1, cute.size(tCrA.shape[2]))
)
Expand Down
7 changes: 5 additions & 2 deletions flash_attn/cute/block_sparse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,8 @@ def handle_block_sparse_empty_tile_correction_sm100(
o_corr_consumer_phase: Int32,
corr_epi_producer_phase: Int32,
softmax_scale_log2: Float32,
max_offset: Float32,
max_offset_scale: Float32,
mO_cur: Optional[cute.Tensor] = None,
gO: Optional[cute.Tensor] = None,
gmem_tiled_copy_O: Optional[cute.TiledCopy] = None,
Expand Down Expand Up @@ -695,10 +697,11 @@ def handle_block_sparse_empty_tile_correction_sm100(
if sink_val != -Float32.inf and (const_expr(not is_split_kv) or split_idx == 0):
if row_max_value == -Float32.inf:
row_max_value = sink_val * (LOG2_E / softmax_scale_log2)
row_sum_value = Float32(1.0)
row_sum_value = max_offset_scale
else:
row_sum_value = row_sum_value + cute.math.exp2(
sink_val * LOG2_E - row_max_value * softmax_scale_log2, fastmath=True
sink_val * LOG2_E - row_max_value * softmax_scale_log2 + max_offset,
fastmath=True,
)
if tidx < m_block_size:
scale_row_idx = tidx + stage * m_block_size
Expand Down
17 changes: 16 additions & 1 deletion flash_attn/cute/cute_dsl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
torch.float16: cutlass.Float16,
torch.bfloat16: cutlass.BFloat16,
torch.float32: cutlass.Float32,
torch.float8_e4m3fn: cutlass.Float8E4M3FN,
torch.float8_e5m2: cutlass.Float8E5M2,
}


Expand Down Expand Up @@ -144,7 +146,20 @@ def assume_tensor_aligned(t):

def to_cute_tensor(t, assumed_align=16, leading_dim=-1, fully_dynamic=False, enable_tvm_ffi=True):
"""Convert torch tensor to cute tensor for TVM FFI. leading_dim=-1 defaults to t.ndim-1."""
tensor = from_dlpack(t.detach(), assumed_align=assumed_align, enable_tvm_ffi=enable_tvm_ffi)
# NOTE: torch 2.9.1 doesn't support fp8 via DLPack but 2.11.0 nightly does
# currently export raw bytes as uint8 and tell cutlass correct type
# can directly export as fp8 when torch supports it
if t.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
tensor = from_dlpack(
t.view(torch.uint8).detach(),
assumed_align=assumed_align,
enable_tvm_ffi=enable_tvm_ffi,
)
tensor.element_type = (
cutlass.Float8E4M3FN if t.dtype == torch.float8_e4m3fn else cutlass.Float8E5M2
)
else:
tensor = from_dlpack(t.detach(), assumed_align=assumed_align, enable_tvm_ffi=enable_tvm_ffi)
if fully_dynamic:
return tensor.mark_layout_dynamic()
if leading_dim == -1:
Expand Down
Loading