
Commit 2ea7730

fp8 unit tests, test refactor
1 parent da8852e · commit 2ea7730

6 files changed: 257 additions & 570 deletions


csrc/fmha_v2_jit_binding.cu

Lines changed: 5 additions & 5 deletions
@@ -28,11 +28,11 @@ void fmha_v2_run(ffi::TensorView q, ffi::TensorView k, ffi::TensorView v, ffi::T
                  ffi::TensorView workspace_buffer, size_t workspace_buffer_size_in_bytes,
                  Optional<ffi::TensorView> maybe_block_tables, int page_size,
                  ffi::TensorView seq_lens, ffi::TensorView cum_seq_lens_q,
-                 ffi::TensorView cum_seq_lens_kv, const std::string& input_layout_str, int max_q_len,
-                 int max_kv_len, int batch_size, int total_q_tokens, int total_kv_tokens,
-                 const std::string& mask_mode_str, float scale_softmax, float scale_bmm1, float scale_bmm2,
-                 int window_left, int chunked_attention_size, bool has_alibi,
-                 float softcapping_scale, ffi::TensorView scale_bmm2_d,
+                 ffi::TensorView cum_seq_lens_kv, const std::string& input_layout_str,
+                 int max_q_len, int max_kv_len, int batch_size, int total_q_tokens,
+                 int total_kv_tokens, const std::string& mask_mode_str, float scale_softmax,
+                 float scale_bmm1, float scale_bmm2, int window_left, int chunked_attention_size,
+                 bool has_alibi, float softcapping_scale, ffi::TensorView scale_bmm2_d,
                  Optional<ffi::TensorView> softmax_stats, Optional<ffi::TensorView> sinks);

 // FMHAv2 attention operator

csrc/fmha_v2_run.cu

Lines changed: 7 additions & 9 deletions
@@ -291,7 +291,6 @@ static inline Data_type dltype_to_data_type(DLDataType dtype) {
   return DATA_TYPE_FP16;
 }

-
 static inline Attention_mask_type string_to_mask_type(const std::string& s) {
   if (s == "padding") return Attention_mask_type::PADDING;
   if (s == "causal") return Attention_mask_type::CAUSAL;
@@ -309,7 +308,6 @@ static inline Attention_input_layout string_to_input_layout(const std::string& s
   return Attention_input_layout::Q_PAGED_KV;  // default
 }

-
 void fmha_v2_run(
     ffi::TensorView q,  // [batch, s_q, num_heads, head_dim]
     ffi::TensorView k,  // [batch, s_kv, num_kv_heads, head_dim]
@@ -321,12 +319,11 @@ void fmha_v2_run(
     ffi::TensorView seq_lens,         // [batch]
     ffi::TensorView cum_seq_lens_q,   // [batch + 1]
     ffi::TensorView cum_seq_lens_kv,  // [batch + 1]
-    const std::string& input_layout_str,
-    int max_q_len, int max_kv_len, int batch_size, int total_q_tokens,
-    int total_kv_tokens,  // Totals from cum_seq_lens (computed in Python)
-    const std::string& mask_mode_str,
-    float scale_softmax, float scale_bmm1, float scale_bmm2, int window_left,
-    int chunked_attention_size, bool has_alibi, float softcapping_scale,
+    const std::string& input_layout_str, int max_q_len, int max_kv_len, int batch_size,
+    int total_q_tokens,
+    int total_kv_tokens,  // Totals from cum_seq_lens (computed in Python)
+    const std::string& mask_mode_str, float scale_softmax, float scale_bmm1, float scale_bmm2,
+    int window_left, int chunked_attention_size, bool has_alibi, float softcapping_scale,
     ffi::TensorView scale_bmm2_d,             // Pre-populated scale_bmm2 on device [1] int32
     Optional<ffi::TensorView> softmax_stats,  // Optional [batch, s_q, num_heads, 2] for (max, sum)
     Optional<ffi::TensorView> sinks) {
@@ -473,7 +470,8 @@ void fmha_v2_run(
   std::tie(warps_m, warps_n, warps_k) = get_warps(launch_params, sm, data_type, s, b, d, 2);

   // Debug output for warps
-  printf("DEBUG: get_warps returned warps_m=%zu, warps_n=%zu, warps_k=%zu\n", warps_m, warps_n, warps_k);
+  printf("DEBUG: get_warps returned warps_m=%zu, warps_n=%zu, warps_k=%zu\n", warps_m, warps_n,
+         warps_k);
   printf("DEBUG: launch_params: flash_attention=%d, warp_specialization=%d, use_tma=%d\n",
          launch_params.flash_attention, launch_params.warp_specialization, launch_params.use_tma);
   printf("DEBUG: data_type=%d, sm=%d, s=%zu, d=%zu\n", int(data_type), sm, s, d);

flashinfer/jit/attention/fmha_v2/fmha_library.py

Lines changed: 6 additions & 5 deletions
@@ -1129,15 +1129,13 @@ def gen_cta_spec(spec):
     return api_code


-def generate_jit_sources(input_layout: str) -> list:
-    uri = "trtllm_fmha_v2"
+def generate_jit_sources(uri: str, input_layout: str, input_dtype: str, output_dtype: str) -> list:
     gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri
     source_paths = []
     specs_names = []
-    dtype_values = ["fp16", "bf16", "e4m3"]
     head_size_qk_values = [16, 32, 64, 128, 256, 512]
     head_size_qk_warpspec_values = [32, 40, 48, 64, 80, 96, 104, 128, 160, 192, 256]
-
+
     # 0 means head_size_v = head_size_qk (required for flash_valid)
     head_size_v_values = [0]
     map_input_layout = {
@@ -1146,8 +1144,11 @@ def generate_jit_sources(input_layout: str) -> list:
         "separate_q_k_v": InputLayout.SEPARATE_Q_K_V,
         "contiguous_q_kv": InputLayout.CONTIGUOUS_Q_KV,
     }
+
     input_layout_values = [map_input_layout[input_layout.lower()]]
-    output_dtype_values = ["fp16", "bf16"]
+    dtype_values = [input_dtype]
+    output_dtype_values = [output_dtype] if output_dtype is not None else [None]
+
     is_mla_values = [False]

     enable_attn_logit_softcapping_values = [True, False]
flashinfer/jit/attention/modules.py

Lines changed: 15 additions & 4 deletions
@@ -1936,17 +1936,28 @@ def gen_trtllm_fmha_v2_sm120_module(device: torch.device) -> JitSpec:
 )


-def gen_fmha_v2_module(input_layout: str) -> JitSpec:
-    uri = "trtllm_fmha_v2"
-
+def gen_fmha_v2_module(input_layout: str, input_dtype: torch.dtype, output_dtype: torch.dtype = None) -> JitSpec:
     # Setup generated source directory
+    if output_dtype is None:
+        output_dtype = input_dtype
+
+    dtype_map = {
+        torch.float16: "fp16",
+        torch.bfloat16: "bf16",
+        torch.float8_e4m3fn: "e4m3",
+    }
+    input_dtype_str = dtype_map[input_dtype]
+    output_dtype_str = dtype_map[output_dtype] if output_dtype is not None else None
+
+    uri = f"trtllm_fmha_v2_{input_layout.lower()}_{input_dtype_str}_{output_dtype_str}"
+
     gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri
     gen_directory.mkdir(parents=True, exist_ok=True)

     # Source directories
     csrc_dir = jit_env.FLASHINFER_CSRC_DIR
     fmha_v2_src_dir = csrc_dir / "fmha_v2"
-    source_paths = generate_jit_sources(input_layout)
+    source_paths = generate_jit_sources(uri, input_layout, input_dtype_str, output_dtype_str)

     # copy static fmha_v2_run.cu
     static_run_path = csrc_dir / "fmha_v2_run.cu"
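
To make the naming scheme above concrete, here is a minimal sketch of the per-configuration URI that gen_fmha_v2_module now derives from the layout and dtypes and forwards to generate_jit_sources. The fmha_v2_uri helper below is hypothetical (not part of FlashInfer); it only mirrors the dtype_map and f-string added in the diff.

# Illustrative sketch only; fmha_v2_uri is a made-up helper mirroring the diff above.
import torch

_DTYPE_MAP = {
    torch.float16: "fp16",
    torch.bfloat16: "bf16",
    torch.float8_e4m3fn: "e4m3",
}

def fmha_v2_uri(input_layout: str, input_dtype: torch.dtype, output_dtype: torch.dtype = None) -> str:
    # As in the diff, the output dtype falls back to the input dtype when not given.
    if output_dtype is None:
        output_dtype = input_dtype
    return (
        f"trtllm_fmha_v2_{input_layout.lower()}"
        f"_{_DTYPE_MAP[input_dtype]}_{_DTYPE_MAP[output_dtype]}"
    )

print(fmha_v2_uri("PACKED_QKV", torch.float16))                        # trtllm_fmha_v2_packed_qkv_fp16_fp16
print(fmha_v2_uri("Q_PAGED_KV", torch.float8_e4m3fn, torch.bfloat16))  # trtllm_fmha_v2_q_paged_kv_e4m3_bf16

Because the URI also names the generated-source directory (FLASHINFER_GEN_SRC_DIR / uri), each layout/dtype combination gets its own set of generated kernels instead of sharing the old fixed "trtllm_fmha_v2" directory.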

flashinfer/prefill.py

Lines changed: 13 additions & 3 deletions
@@ -3873,8 +3873,8 @@ def fmha_v2_prefill_deepseek(


 @functools.cache
-def get_trtllm_fmha_v2_module(input_layout: str):
-    return gen_fmha_v2_module(input_layout).build_and_load()
+def get_trtllm_fmha_v2_module(input_layout: str, input_dtype: torch.dtype, output_dtype: torch.dtype = None):
+    return gen_fmha_v2_module(input_layout, input_dtype, output_dtype).build_and_load()


 @flashinfer_api
@@ -4018,6 +4018,16 @@ def trtllm_fmha_v2_prefill(
     elif len(qkv) == 3:
         input_layout = "SEPARATE_Q_K_V"
         query, k_cache, v_cache = qkv
+        if hasattr(torch, "float8_e4m3fn") and query.dtype == torch.float8_e4m3fn:
+            raise ValueError(
+                "FP8 (e4m3) is not supported for the SEPARATE_Q_K_V input layout. "
+                "Use PACKED_QKV, CONTIGUOUS_Q_KV, or Q_PAGED_KV layout instead."
+            )
+        if logits_soft_cap_scale is not None and logits_soft_cap_scale > 0:
+            raise ValueError(
+                "Logits soft capping is not supported for the SEPARATE_Q_K_V input layout. "
+                "Use PACKED_QKV, CONTIGUOUS_Q_KV, or Q_PAGED_KV layout instead."
+            )

     else:
         raise ValueError(
@@ -4102,7 +4112,7 @@ def trtllm_fmha_v2_prefill(
         logits_soft_cap_scale if logits_soft_cap_scale is not None else 0.0
     )

-    module = get_trtllm_fmha_v2_module(input_layout)
+    module = get_trtllm_fmha_v2_module(input_layout, query.dtype, o_dtype if query.dtype == torch.float8_e4m3fn else None)
     total_q_tokens = int(cum_seq_lens_q[-1].item())
     total_kv_tokens = int(cum_seq_lens_kv[-1].item())

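
As a rough usage sketch of the dispatch change above: the cached getter is now keyed on the input and output dtypes as well as the layout, so fp16 and fp8 calls build and cache separate JIT modules, and o_dtype is forwarded only for e4m3 inputs, mirroring the call site in the diff. The snippet below is illustrative; get_module is a stand-in that returns its key instead of building anything, and assumes a PyTorch build with torch.float8_e4m3fn.

# Minimal sketch of the new caching behaviour (get_module is a placeholder for
# get_trtllm_fmha_v2_module and does not actually build a kernel module).
import functools
import torch

@functools.cache
def get_module(input_layout: str, input_dtype: torch.dtype, output_dtype: torch.dtype = None):
    return (input_layout, input_dtype, output_dtype)  # stand-in for gen_fmha_v2_module(...).build_and_load()

m_fp16 = get_module("PACKED_QKV", torch.float16)                       # o_dtype omitted for non-fp8 inputs
m_fp8 = get_module("PACKED_QKV", torch.float8_e4m3fn, torch.bfloat16)  # o_dtype forwarded only for e4m3
assert m_fp16 != m_fp8  # distinct cache keys -> distinct builds

Note that, per the validation added above, fp8 (e4m3) inputs and logits soft capping are both rejected when the SEPARATE_Q_K_V layout is used.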
