Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion flashinfer/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,9 @@ def run(
v_scale: Optional[torch.Tensor] = None,
logits_soft_cap: float = 0.0,
profiler_buffer: Optional[torch.Tensor] = None,
kv_cache_sf: Optional[
Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
if profiler_buffer is None:
if self._use_profiler:
Expand Down Expand Up @@ -176,6 +179,13 @@ def run(
# profiler_buffer is optional
profiler_args = (profiler_buffer,) if self._use_profiler else ()

# Unpack kv_cache_sf for NVFP4 (maybe_k_cache_sf, maybe_v_cache_sf)
k_cache_sf, v_cache_sf = (
_unpack_paged_kv_cache(kv_cache_sf, self._kv_layout)
if kv_cache_sf is not None
else (None, None)
)

self.module.run(
self.float_workspace_buffer,
self.int_workspace_buffer,
Expand All @@ -194,7 +204,9 @@ def run(
v_scale,
sm_scale,
logits_soft_cap,
# ADDITIONAL_FUNC_PARAMS
# ADDITIONAL_FUNC_PARAMS (maybe_k_cache_sf, maybe_v_cache_sf)
k_cache_sf,
v_cache_sf,
# PROFILER_FUNC_PARAMS
*profiler_args,
)
Expand Down
28 changes: 20 additions & 8 deletions flashinfer/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -1298,15 +1298,16 @@ def run(
key_block_scales = None
value_block_scales = None
if kv_cache_sf is not None:
if (
not isinstance(kv_cache_sf, (tuple, list))
or len(kv_cache_sf) != 2
or not all(torch.is_tensor(x) for x in kv_cache_sf)
):
if isinstance(kv_cache_sf, (tuple, list)):
key_block_scales, value_block_scales = kv_cache_sf
elif torch.is_tensor(kv_cache_sf):
# stacked tensor [num_pages, 2, ...] β€” unbind along dim 1
key_block_scales, value_block_scales = kv_cache_sf.unbind(dim=1)
else:
raise TypeError(
"kv_cache_sf must be a tuple/list of two tensors: (k_scales, v_scales)."
"kv_cache_sf must be a tuple/list of two tensors or a stacked tensor "
"of shape [num_pages, 2, ...]: (k_scales, v_scales)."
Comment thread
Tom-Zheng marked this conversation as resolved.
Outdated
)
key_block_scales, value_block_scales = kv_cache_sf

if self._kv_layout == "NHD":
page_size = k_cache.shape[1]
Expand Down Expand Up @@ -1448,13 +1449,24 @@ def run(
rope_theta,
0, # token_pos_in_items_len
self._workspace_size,
paged_kv_cache,
]

if self._backend == "trtllm-gen":
# decode.py's trtllm-gen paged_run (get_trtllm_gen_decode_module)
# has a different optional-param layout than prefill.py's paged_run
run_args += [paged_kv_cache]
Comment thread
Tom-Zheng marked this conversation as resolved.
Outdated

run_args += [
self._num_qo_heads,
self._num_kv_heads,
self._block_tables,
self._kv_lens_buffer,
page_size,
None, # max_q_len (not applicable for decode)
Comment thread
qsang-nv marked this conversation as resolved.
Outdated
self._max_kv_len,
None, # batch_size (not applicable for decode)
None, # cum_seq_lens_q (not applicable for decode)
None, # cum_seq_lens_kv (not applicable for decode)
sinks,
key_block_scales,
value_block_scales,
Expand Down
67 changes: 45 additions & 22 deletions flashinfer/jit/attention/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from ...jit.cubin_loader import get_artifact, get_meta_hash
from ..utils import (
dtype_map,
dtype_map_kv,
filename_safe_dtype_map,
mask_mode_literal,
pos_encoding_mode_literal,
Expand Down Expand Up @@ -141,7 +142,7 @@ def gen_batch_mla_module(
generated_config_path,
config_templ.render(
dtype_q=dtype_map[dtype_q],
dtype_kv=dtype_map[dtype_kv],
dtype_kv=dtype_map_kv[dtype_kv],
dtype_o=dtype_map[dtype_o],
dtype_idx=dtype_map[dtype_idx],
head_dim_ckv=head_dim_ckv,
Expand Down Expand Up @@ -169,7 +170,7 @@ def gen_batch_mla_module(
generated_config_path,
config_templ.render(
dtype_q=dtype_map[dtype_q],
dtype_kv=dtype_map[dtype_kv],
dtype_kv=dtype_map_kv[dtype_kv],
dtype_o=dtype_map[dtype_o],
dtype_idx=dtype_map[dtype_idx],
head_dim_ckv=head_dim_ckv,
Expand Down Expand Up @@ -278,7 +279,7 @@ def gen_batch_decode_mla_module(
generated_config_path,
config_templ.render(
dtype_q=dtype_map[dtype_q],
dtype_kv=dtype_map[dtype_kv],
dtype_kv=dtype_map_kv[dtype_kv],
dtype_o=dtype_map[dtype_o],
dtype_idx=dtype_map[dtype_idx],
head_dim_ckv=head_dim,
Expand Down Expand Up @@ -518,8 +519,13 @@ def gen_single_prefill_module(

if backend == "fa2":
assert not fp8_enabled, "fp8 tensor core is not supported in fa2 backend"
additional_tensor_names = ["maybe_custom_mask", "maybe_alibi_slopes"]
additional_tensor_dtypes = ["uint8_t", "float"]
additional_tensor_names = [
"maybe_custom_mask",
"maybe_alibi_slopes",
"maybe_k_cache_sf",
"maybe_v_cache_sf",
]
additional_tensor_dtypes = ["uint8_t", "float", "uint8_t", "uint8_t"]
additional_scalar_names = [
"logits_soft_cap",
"sm_scale",
Expand Down Expand Up @@ -755,7 +761,7 @@ def gen_customize_pod_module(
"variant_name_p": variant_name_p,
"variant_name_d": variant_name_d,
"dtype_q": dtype_map[dtype_q],
"dtype_kv": dtype_map[dtype_kv],
"dtype_kv": dtype_map_kv[dtype_kv],
"dtype_o": dtype_map[dtype_o],
"idtype": dtype_map[dtype_idx],
"head_dim_qk": head_dim,
Expand Down Expand Up @@ -855,7 +861,7 @@ def gen_customize_batch_pod_module(
"variant_name_p": variant_name_p,
"variant_name_d": variant_name_d,
"dtype_q": dtype_map[dtype_q],
"dtype_kv": dtype_map[dtype_kv],
"dtype_kv": dtype_map_kv[dtype_kv],
"dtype_o": dtype_map[dtype_o],
"idtype": dtype_map[dtype_idx],
"head_dim_qk": head_dim,
Expand Down Expand Up @@ -1001,6 +1007,8 @@ def gen_batch_prefill_module(
"maybe_prefix_len_ptr",
"maybe_token_pos_in_items_ptr",
"maybe_max_item_len_ptr",
"maybe_k_cache_sf",
"maybe_v_cache_sf",
]
additional_tensor_dtypes = [
"uint8_t",
Expand All @@ -1009,6 +1017,8 @@ def gen_batch_prefill_module(
"uint32_t",
"uint16_t",
"uint16_t",
"uint8_t",
"uint8_t",
] # NOTE(Zihao): int32_t should follow dtype_idx
additional_scalar_names = [
"logits_soft_cap",
Expand Down Expand Up @@ -1149,8 +1159,8 @@ def gen_batch_attention_module(
use_profiler,
)

additional_tensor_names: List[str] = []
additional_tensor_dtypes: List[str] = []
additional_tensor_names: List[str] = ["maybe_k_cache_sf", "maybe_v_cache_sf"]
additional_tensor_dtypes: List[str] = ["uint8_t", "uint8_t"]
additional_scalar_names: List[str] = []
additional_scalar_dtypes: List[str] = []
variant_name = f"StandardAttention<{str(use_logits_soft_cap).lower()}>"
Expand Down Expand Up @@ -1221,7 +1231,7 @@ def gen_customize_single_decode_module(
"variant_decl": variant_decl,
"variant_name": variant_name,
"dtype_q": dtype_map[dtype_q],
"dtype_kv": dtype_map[dtype_kv],
"dtype_kv": dtype_map_kv[dtype_kv],
"dtype_o": dtype_map[dtype_o],
"head_dim_qk": head_dim_qk,
"head_dim_vo": head_dim_vo,
Expand Down Expand Up @@ -1286,7 +1296,7 @@ def gen_customize_single_prefill_module(
"variant_decl": variant_decl,
"variant_name": variant_name,
"dtype_q": dtype_map[dtype_q],
"dtype_kv": dtype_map[dtype_kv],
"dtype_kv": dtype_map_kv[dtype_kv],
"dtype_o": dtype_map[dtype_o],
"head_dim_qk": head_dim_qk,
"head_dim_vo": head_dim_vo,
Expand Down Expand Up @@ -1461,7 +1471,7 @@ def gen_customize_batch_decode_module(
"variant_decl": variant_decl,
"variant_name": variant_name,
"dtype_q": dtype_map[dtype_q],
"dtype_kv": dtype_map[dtype_kv],
"dtype_kv": dtype_map_kv[dtype_kv],
"dtype_o": dtype_map[dtype_o],
"idtype": dtype_map[idtype],
"head_dim_qk": head_dim_qk,
Expand Down Expand Up @@ -1531,7 +1541,7 @@ def gen_customize_batch_prefill_module(
"variant_decl": variant_decl,
"variant_name": variant_name,
"dtype_q": dtype_map[dtype_q],
"dtype_kv": dtype_map[dtype_kv],
"dtype_kv": dtype_map_kv[dtype_kv],
"dtype_o": dtype_map[dtype_o],
"idtype": dtype_map[idtype],
"head_dim_qk": head_dim_qk,
Expand Down Expand Up @@ -1819,7 +1829,7 @@ def gen_customize_batch_attention_module(
"variant_decl": variant_decl,
"variant_name": variant_name,
"dtype_q": dtype_map[dtype_q],
"dtype_kv": dtype_map[dtype_kv],
"dtype_kv": dtype_map_kv[dtype_kv],
"dtype_o": dtype_map[dtype_o],
"idtype": dtype_map[idtype],
"head_dim_qk": head_dim_qk,
Expand All @@ -1828,13 +1838,26 @@ def gen_customize_batch_attention_module(
"use_logits_soft_cap": str(use_logits_soft_cap).lower(),
}
gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri
(additional_params_decl, additional_func_params, additional_params_setter) = (
generate_additional_params(
additional_tensor_names,
additional_tensor_dtypes,
additional_scalar_names,
additional_scalar_dtypes,
)
(additional_params_decl, additional_func_params, _) = generate_additional_params(
additional_tensor_names,
additional_tensor_dtypes,
additional_scalar_names,
additional_scalar_dtypes,
)
# batch_attention.cu loops over params[i], so generate a setter using params[i] syntax
# instead of the params.X syntax from generate_additional_params.
batch_additional_params_setter = " \\\n".join(
[
(
f"params[i].{var} = {var} ? static_cast<{dtype}*>({var}.value().data_ptr()): nullptr;"
if var.startswith("maybe")
else f"params[i].{var} = static_cast<{dtype}*>({var}.data_ptr());"
)
for dtype, var in zip(
additional_tensor_dtypes, additional_tensor_names, strict=True
)
]
+ [f"params[i].{var} = {var};" for var in additional_scalar_names]
)
with open(
jit_env.FLASHINFER_CSRC_DIR / "batch_attention_customize_config.jinja"
Expand All @@ -1849,7 +1872,7 @@ def gen_customize_batch_attention_module(
kwargs |= {
"additional_params_decl": additional_params_decl,
"additional_func_params": additional_func_params,
"additional_params_setter": additional_params_setter,
"additional_params_setter": batch_additional_params_setter,
}

generated_inc_str = config_templ.render(
Expand Down
12 changes: 12 additions & 0 deletions flashinfer/jit/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,16 @@ def write_if_different(path: pathlib.Path, content: str) -> None:
torch.uint64: "uint64_t",
}

dtype_map_kv = {
torch.float16: "half",
torch.bfloat16: "nv_bfloat16",
torch.float8_e4m3fn: "__nv_fp8_e4m3",
torch.float8_e5m2: "__nv_fp8_e5m2",
torch.uint8: "__nv_fp4x2_e2m1",
}
if hasattr(torch, "float4_e2m1fn_x2"):
dtype_map_kv[torch.float4_e2m1fn_x2] = "__nv_fp4x2_e2m1"

dtype_cutlass_map = {
torch.float16: "cutlass::half_t",
torch.bfloat16: "cutlass::bfloat16_t",
Expand All @@ -68,6 +78,8 @@ def write_if_different(path: pathlib.Path, content: str) -> None:
torch.int64: "i64",
torch.uint64: "u64",
}
if hasattr(torch, "float4_e2m1fn_x2"):
filename_safe_dtype_map[torch.float4_e2m1fn_x2] = "fp4_e2m1"

pos_encoding_mode_literal = {
0: "PosEncodingMode::kNone",
Expand Down
Loading
Loading