diff --git a/csrc/xqa/utils.cuh b/csrc/xqa/utils.cuh index e8b56af0bb..2988d459ad 100644 --- a/csrc/xqa/utils.cuh +++ b/csrc/xqa/utils.cuh @@ -780,8 +780,8 @@ __device__ inline Vec convertKCacheWordToF16<__nv_bfloat16, __nv_fp Vec ret; // This needs CUDA Toolkit version >= 13.2 #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -#if (defined __CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 13) && \ - (defined __CUDACC_VER_MINOR__) && (__CUDACC_VER_MINOR__ >= 2) +#if (defined __CUDACC_VER_MAJOR__) && (defined __CUDACC_VER_MINOR__) && \ + ((__CUDACC_VER_MAJOR__ > 13) || ((__CUDACC_VER_MAJOR__ == 13) && (__CUDACC_VER_MINOR__ >= 2))) uint32_t src = i8data | (i8data >> 4); uint32_t(&dst)[2] = reinterpret_cast(ret); asm("{\n" diff --git a/flashinfer/attention.py b/flashinfer/attention.py index c4bc4f27dc..0f08330959 100644 --- a/flashinfer/attention.py +++ b/flashinfer/attention.py @@ -146,6 +146,9 @@ def run( v_scale: Optional[torch.Tensor] = None, logits_soft_cap: float = 0.0, profiler_buffer: Optional[torch.Tensor] = None, + kv_cache_sf: Optional[ + Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] + ] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: if profiler_buffer is None: if self._use_profiler: @@ -176,6 +179,13 @@ def run( # profiler_buffer is optional profiler_args = (profiler_buffer,) if self._use_profiler else () + # Unpack kv_cache_sf for NVFP4 (maybe_k_cache_sf, maybe_v_cache_sf) + k_cache_sf, v_cache_sf = ( + _unpack_paged_kv_cache(kv_cache_sf, self._kv_layout) + if kv_cache_sf is not None + else (None, None) + ) + self.module.run( self.float_workspace_buffer, self.int_workspace_buffer, @@ -194,7 +204,9 @@ def run( v_scale, sm_scale, logits_soft_cap, - # ADDITIONAL_FUNC_PARAMS + # ADDITIONAL_FUNC_PARAMS (maybe_k_cache_sf, maybe_v_cache_sf) + k_cache_sf, + v_cache_sf, # PROFILER_FUNC_PARAMS *profiler_args, ) diff --git a/flashinfer/decode.py b/flashinfer/decode.py index 822aca407c..5e34af8d61 100644 --- a/flashinfer/decode.py +++ 
b/flashinfer/decode.py @@ -444,9 +444,9 @@ def single_decode_with_kv_cache( q_scale : Optional[float] The calibration scale of query for fp8 input, if not provided, will be set to ``1.0``. k_scale : Optional[float] - The calibration scale of key for fp8 input, if not provided, will be set to ``1.0``. + The calibration scale of key for fp8 or nvfp4 input, if not provided, will be set to ``1.0``. v_scale : Optional[float] - The calibration scale of value for fp8 input, if not provided, will be set to ``1.0``. + The calibration scale of value for fp8 or nvfp4 input, if not provided, will be set to ``1.0``. window_left : int The left (inclusive) window size for the attention window, when set to ``-1``, the window size will be set to the full length of the sequence. Defaults to ``-1``. @@ -1192,7 +1192,9 @@ def run( sinks: Optional[torch.Tensor] = None, q_len_per_req: Optional[int] = 1, skip_softmax_threshold_scale_factor: Optional[float] = None, - kv_cache_sf: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + kv_cache_sf: Optional[ + Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] + ] = None, ) -> torch.Tensor: ... @overload @@ -1212,7 +1214,9 @@ def run( sinks: Optional[torch.Tensor] = None, q_len_per_req: Optional[int] = 1, skip_softmax_threshold_scale_factor: Optional[float] = None, - kv_cache_sf: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + kv_cache_sf: Optional[ + Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] + ] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: ... @flashinfer_api @@ -1232,7 +1236,9 @@ def run( sinks: Optional[torch.Tensor] = None, q_len_per_req: Optional[int] = 1, skip_softmax_threshold_scale_factor: Optional[float] = None, - kv_cache_sf: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + kv_cache_sf: Optional[ + Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] + ] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: r"""Compute batch decode attention between query and paged kv cache. 
@@ -1258,9 +1264,9 @@ def run( q_scale : Optional[float] The calibration scale of query for fp8 input, if not provided, will be set to ``1.0``. k_scale : Optional[float] - The calibration scale of key for fp8 input, if not provided, will be set to ``1.0``. + The calibration scale of key for fp8 or nvfp4 input, if not provided, will be set to ``1.0``. v_scale : Optional[float] - The calibration scale of value for fp8 input, if not provided, will be set to ``1.0``. + The calibration scale of value for fp8 or nvfp4 input, if not provided, will be set to ``1.0``. out : Optional[torch.Tensor] The output tensor, if not provided, will be allocated internally. lse : Optional[torch.Tensor] @@ -1278,6 +1284,19 @@ def run( If no value is provided, then standard attention is used. Setting the threshold to a higher value generally increases kernel performance at the cost of accuracy degradation. The actual threshold value equals the provided threshold_scale_factor divided by the context length. + kv_cache_sf : Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] + Per-block scale factors for NVFP4 KV cache. Accepts the same formats as + ``paged_kv_cache``: + + * a tuple ``(k_scales, v_scales)`` of 4-D tensors, each with shape: + ``[num_pages, page_size, num_kv_heads, head_dim // 16]`` if :attr:`kv_layout` is ``NHD``, + and ``[num_pages, num_kv_heads, page_size, head_dim // 16]`` if :attr:`kv_layout` is ``HND``. + * a single 5-D tensor with shape: + ``[num_pages, 2, page_size, num_kv_heads, head_dim // 16]`` if :attr:`kv_layout` is ``NHD``, + and ``[num_pages, 2, num_kv_heads, page_size, head_dim // 16]`` if :attr:`kv_layout` is ``HND``, + where dim 1 holds k (index 0) and v (index 1) scales. + + Both tensors have dtype ``torch.float8_e4m3fn``. 
Returns ------- Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] @@ -1295,18 +1314,11 @@ def run( k_cache.dtype == torch.uint8 or v_cache.dtype == torch.uint8 ) and kv_cache_sf is None: raise ValueError("kv_cache_sf must be provided for NVFP4 KV cache.") - key_block_scales = None - value_block_scales = None - if kv_cache_sf is not None: - if ( - not isinstance(kv_cache_sf, (tuple, list)) - or len(kv_cache_sf) != 2 - or not all(torch.is_tensor(x) for x in kv_cache_sf) - ): - raise TypeError( - "kv_cache_sf must be a tuple/list of two tensors: (k_scales, v_scales)." - ) - key_block_scales, value_block_scales = kv_cache_sf + key_block_scales, value_block_scales = ( + _unpack_paged_kv_cache(kv_cache_sf, self._kv_layout) + if kv_cache_sf is not None + else (None, None) + ) if self._kv_layout == "NHD": page_size = k_cache.shape[1] @@ -1448,20 +1460,44 @@ def run( rope_theta, 0, # token_pos_in_items_len self._workspace_size, - paged_kv_cache, - self._num_qo_heads, - self._num_kv_heads, - self._block_tables, - self._kv_lens_buffer, - page_size, - self._max_kv_len, - sinks, - key_block_scales, - value_block_scales, - skip_softmax_threshold_scale_factor, - True, # uses_shared_paged_kv_idx ] + if self._backend == "trtllm-gen": + # decode.py's trtllm-gen paged_run (get_trtllm_gen_decode_module) + # has a different optional-param layout than prefill.py's paged_run + run_args += [ + paged_kv_cache, + self._num_qo_heads, + self._num_kv_heads, + self._block_tables, + self._kv_lens_buffer, + page_size, + self._max_kv_len, + sinks, + key_block_scales, + value_block_scales, + skip_softmax_threshold_scale_factor, + True, # uses_shared_paged_kv_idx + ] + else: + run_args += [ + self._num_qo_heads, + self._num_kv_heads, + self._block_tables, + self._kv_lens_buffer, + page_size, + None, # max_q_len (not applicable for decode) + self._max_kv_len, + None, # batch_size (not applicable for decode) + None, # cum_seq_lens_q (not applicable for decode) + None, # cum_seq_lens_kv (not 
applicable for decode) + sinks, + key_block_scales, + value_block_scales, + skip_softmax_threshold_scale_factor, + True, # uses_shared_paged_kv_idx + ] + self._cached_module.paged_run(*run_args) else: # trtllm-gen does not need plan info @@ -2257,7 +2293,9 @@ def trtllm_batch_decode_with_kv_cache( max_q_len: Optional[int] = None, cum_seq_lens_q: Optional[torch.Tensor] = None, skip_softmax_threshold_scale_factor: Optional[float] = None, - kv_cache_sf: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + kv_cache_sf: Optional[ + Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] + ] = None, uses_shared_paged_kv_idx: bool = True, ) -> Union[torch.Tensor, FP4Tensor]: """ @@ -2267,11 +2305,15 @@ def trtllm_batch_decode_with_kv_cache( query tensor with shape [num_tokens, num_heads, head_dim], num_tokens = total query tokens in the batch. kv_cache : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] - If kv_cache is a single tensor, it should be a tensor with shape [num_pages, 1 or 2, num_kv_heads, page_size, head_dim] if :attr:`kv_layout` is ``HND``, - or [num_pages, 1 or 2, page_size, num_kv_heads, head_dim] if :attr:`kv_layout` is ``NHD``. - If kv_cache is a tuple of two tensors, it should be a tuple of two tensors with shape [num_pages, num_kv_heads, page_size, head_dim] if :attr:`kv_layout` is ``HND``, - or [num_pages, page_size, num_kv_heads, head_dim] if :attr:`kv_layout` is ``NHD``. - The first tensor is the key cache, and the second tensor is the value cache. + The paged KV-Cache stored as a tuple of tensors or a single tensor: + + * a tuple ``(k_cache, v_cache)`` of 4-D tensors, each with shape + ``[num_pages, num_kv_heads, page_size, head_dim]`` if :attr:`kv_layout` is ``HND``, + or ``[num_pages, page_size, num_kv_heads, head_dim]`` if :attr:`kv_layout` is ``NHD``. 
+ * a single 5-D tensor with shape + ``[num_pages, 2, num_kv_heads, page_size, head_dim]`` if :attr:`kv_layout` is ``HND``, + or ``[num_pages, 2, page_size, num_kv_heads, head_dim]`` if :attr:`kv_layout` is ``NHD``, + where dim 1 holds k (index 0) and v (index 1). **Contiguity requirements (trtllm-gen backend):** @@ -2361,10 +2403,17 @@ def trtllm_batch_decode_with_kv_cache( Setting the threshold to a higher value generally increases kernel performance at the cost of accuracy degradation. The actual threshold value equals the provided threshold_scale_factor divided by the context length. - kv_cache_sf : Optional[Tuple[torch.Tensor, torch.Tensor]] = None - Per-block scale factors for NVFP4 KV cache, as a tuple of ``(k_scales, v_scales)``. - Each scale tensor has shape ``[num_pages, num_kv_heads, page_size, head_dim // 16]`` - in HND layout, with dtype ``torch.float8_e4m3fn``. + kv_cache_sf : Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None + Per-block scale factors for NVFP4 KV cache. Accepts the same formats as + ``kv_cache``: + + * a tuple ``(k_scales, v_scales)`` of 4-D tensors, each with shape + ``[num_pages, num_kv_heads, page_size, head_dim // 16]`` in HND layout. + * a single 5-D tensor with shape + ``[num_pages, 2, num_kv_heads, page_size, head_dim // 16]`` in HND layout, + where dim 1 holds k (index 0) and v (index 1) scales. + + Both tensors have dtype ``torch.float8_e4m3fn``. **Contiguity requirements (trtllm-gen backend):** @@ -2386,18 +2435,7 @@ def trtllm_batch_decode_with_kv_cache( """ enable_pdl = device_support_pdl(query.device) if enable_pdl is None else enable_pdl - if isinstance(kv_cache, tuple): - k_cache, v_cache = kv_cache - else: - if kv_cache.shape[1] == 1: - k_cache, v_cache = kv_cache, kv_cache - else: - assert kv_cache.shape[1] == 2, ( - "When kv_cache is a single tensor, the second dimension must be 1 or 2" - ) - # NOTE(Zihao): unbind transforms [num_pages, 2, ...] 
to ([num_pages, ...], [num_pages, ...]) - # it doesn't change underlying storage - k_cache, v_cache = kv_cache.unbind(dim=1) + k_cache, v_cache = _unpack_paged_kv_cache(kv_cache, kv_layout) if ( k_cache.dtype == torch.uint8 or v_cache.dtype == torch.uint8 @@ -2409,18 +2447,12 @@ def trtllm_batch_decode_with_kv_cache( and kv_cache_sf is not None ) - k_block_scales = None - v_block_scales = None + k_block_scales, v_block_scales = ( + _unpack_paged_kv_cache(kv_cache_sf, kv_layout) + if kv_cache_sf is not None + else (None, None) + ) if is_nvfp4_kvcache: - if ( - not isinstance(kv_cache_sf, (tuple, list)) - or len(kv_cache_sf) != 2 - or not all(torch.is_tensor(x) for x in kv_cache_sf) - ): - raise TypeError( - "kv_cache_sf must be a tuple/list of two tensors: (k_scales, v_scales)." - ) - k_block_scales, v_block_scales = kv_cache_sf assert ( k_block_scales.dtype == torch.float8_e4m3fn and v_block_scales.dtype == torch.float8_e4m3fn @@ -2636,8 +2668,8 @@ def xqa_batch_decode_with_kv_cache( q_len_per_req: Optional[int] = 1, o_scale: Optional[float] = 1.0, mask: Optional[torch.Tensor] = None, - kv_cache_sf: Union[ - torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]] + kv_cache_sf: Optional[ + Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] ] = None, ) -> torch.Tensor: """ @@ -2647,10 +2679,15 @@ def xqa_batch_decode_with_kv_cache( query tensor with shape [num_tokens, num_heads, head_dim], num_tokens = batch_size * q_len_per_request kv_cache : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] - If kv_cache is a single tensor, it should be a tensor with shape [num_pages, 1 or 2, page_size, num_kv_heads, head_dim] if :attr:`kv_layout` is ``NHD``, - or [num_pages, 1 or 2, num_kv_heads, page_size, head_dim] if :attr:`kv_layout` is ``HND``. 
- If kv_cache is a tuple of two tensors, it should be a tuple of two tensors with shape [num_pages, page_size, num_kv_heads, head_dim] if :attr:`kv_layout` is ``NHD``, - or [num_pages, num_kv_heads, page_size, head_dim] if :attr:`kv_layout` is ``HND``. + The paged KV-Cache stored as a tuple of tensors or a single tensor: + + * a tuple ``(k_cache, v_cache)`` of 4-D tensors, each with shape + ``[num_pages, page_size, num_kv_heads, head_dim]`` if :attr:`kv_layout` is ``NHD``, + or ``[num_pages, num_kv_heads, page_size, head_dim]`` if :attr:`kv_layout` is ``HND``. + * a single 5-D tensor with shape + ``[num_pages, 2, page_size, num_kv_heads, head_dim]`` if :attr:`kv_layout` is ``NHD``, + or ``[num_pages, 2, num_kv_heads, page_size, head_dim]`` if :attr:`kv_layout` is ``HND``, + where dim 1 holds k (index 0) and v (index 1). workspace_buffer : torch.Tensor. Must be initialized to 0 for its first use. workspace @@ -2693,8 +2730,19 @@ def xqa_batch_decode_with_kv_cache( mask : Optional[torch.Tensor] = None causal attention mask for xqa speculative decoding. - kv_cache_sf : Optional[torch.Tensor] = None - KV cache scaling factors. Must provide when NVFP4 KV cache is used. + kv_cache_sf : Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None + Per-block scale factors for NVFP4 KV cache. Accepts the same formats as + ``kv_cache``: + + * a tuple ``(k_scales, v_scales)`` of 4-D tensors, each with shape + ``[num_pages, page_size, num_kv_heads, head_dim // 16]`` if :attr:`kv_layout` is ``NHD``, + or ``[num_pages, num_kv_heads, page_size, head_dim // 16]`` if :attr:`kv_layout` is ``HND``. + * a single 5-D tensor with shape + ``[num_pages, 2, page_size, num_kv_heads, head_dim // 16]`` if :attr:`kv_layout` is ``NHD``, + or ``[num_pages, 2, num_kv_heads, page_size, head_dim // 16]`` if :attr:`kv_layout` is ``HND``, + where dim 1 holds k (index 0) and v (index 1) scales. + + Both tensors have dtype ``torch.float8_e4m3fn``. 
Returns ------- @@ -2703,31 +2751,13 @@ def xqa_batch_decode_with_kv_cache( """ enable_pdl = device_support_pdl(query.device) if enable_pdl is None else enable_pdl - if isinstance(kv_cache, tuple): - k_cache, v_cache = kv_cache - else: - if kv_cache.shape[1] == 1: - k_cache, v_cache = kv_cache, kv_cache - else: - assert kv_cache.shape[1] == 2, ( - "When kv_cache is a single tensor, the second dimension must be 1 or 2" - ) - # NOTE(Zihao): unbind transforms [num_pages, 2, ...] to ([num_pages, ...], [num_pages, ...]) - # it doesn't change underlying storage - k_cache, v_cache = kv_cache.unbind(dim=1) - - k_cache_sf = None - v_cache_sf = None - if kv_cache_sf is not None: - if isinstance(kv_cache_sf, tuple): - k_cache_sf, v_cache_sf = kv_cache_sf - else: - assert kv_cache_sf.shape[1] == 2, ( - "When kv_cache is a single tensor, the second dimension must be 1 or 2" - ) - # NOTE(Zihao): unbind transforms [num_pages, 2, ...] to ([num_pages, ...], [num_pages, ...]) - # it doesn't change underlying storage - k_cache_sf, v_cache_sf = kv_cache_sf.unbind(dim=1) + k_cache, v_cache = _unpack_paged_kv_cache(kv_cache, kv_layout) + + k_cache_sf, v_cache_sf = ( + _unpack_paged_kv_cache(kv_cache_sf, kv_layout) + if kv_cache_sf is not None + else (None, None) + ) sm_count = get_device_sm_count(query.device) diff --git a/flashinfer/jit/attention/modules.py b/flashinfer/jit/attention/modules.py index 465f50ed0b..00acfd21a6 100755 --- a/flashinfer/jit/attention/modules.py +++ b/flashinfer/jit/attention/modules.py @@ -31,6 +31,7 @@ from ...jit.cubin_loader import get_artifact, get_meta_hash from ..utils import ( dtype_map, + dtype_map_kv, filename_safe_dtype_map, mask_mode_literal, pos_encoding_mode_literal, @@ -141,7 +142,7 @@ def gen_batch_mla_module( generated_config_path, config_templ.render( dtype_q=dtype_map[dtype_q], - dtype_kv=dtype_map[dtype_kv], + dtype_kv=dtype_map_kv[dtype_kv], dtype_o=dtype_map[dtype_o], dtype_idx=dtype_map[dtype_idx], head_dim_ckv=head_dim_ckv, @@ -169,7 
+170,7 @@ def gen_batch_mla_module( generated_config_path, config_templ.render( dtype_q=dtype_map[dtype_q], - dtype_kv=dtype_map[dtype_kv], + dtype_kv=dtype_map_kv[dtype_kv], dtype_o=dtype_map[dtype_o], dtype_idx=dtype_map[dtype_idx], head_dim_ckv=head_dim_ckv, @@ -278,7 +279,7 @@ def gen_batch_decode_mla_module( generated_config_path, config_templ.render( dtype_q=dtype_map[dtype_q], - dtype_kv=dtype_map[dtype_kv], + dtype_kv=dtype_map_kv[dtype_kv], dtype_o=dtype_map[dtype_o], dtype_idx=dtype_map[dtype_idx], head_dim_ckv=head_dim, @@ -518,8 +519,13 @@ def gen_single_prefill_module( if backend == "fa2": assert not fp8_enabled, "fp8 tensor core is not supported in fa2 backend" - additional_tensor_names = ["maybe_custom_mask", "maybe_alibi_slopes"] - additional_tensor_dtypes = ["uint8_t", "float"] + additional_tensor_names = [ + "maybe_custom_mask", + "maybe_alibi_slopes", + "maybe_k_cache_sf", + "maybe_v_cache_sf", + ] + additional_tensor_dtypes = ["uint8_t", "float", "uint8_t", "uint8_t"] additional_scalar_names = [ "logits_soft_cap", "sm_scale", @@ -755,7 +761,7 @@ def gen_customize_pod_module( "variant_name_p": variant_name_p, "variant_name_d": variant_name_d, "dtype_q": dtype_map[dtype_q], - "dtype_kv": dtype_map[dtype_kv], + "dtype_kv": dtype_map_kv[dtype_kv], "dtype_o": dtype_map[dtype_o], "idtype": dtype_map[dtype_idx], "head_dim_qk": head_dim, @@ -855,7 +861,7 @@ def gen_customize_batch_pod_module( "variant_name_p": variant_name_p, "variant_name_d": variant_name_d, "dtype_q": dtype_map[dtype_q], - "dtype_kv": dtype_map[dtype_kv], + "dtype_kv": dtype_map_kv[dtype_kv], "dtype_o": dtype_map[dtype_o], "idtype": dtype_map[dtype_idx], "head_dim_qk": head_dim, @@ -1001,6 +1007,8 @@ def gen_batch_prefill_module( "maybe_prefix_len_ptr", "maybe_token_pos_in_items_ptr", "maybe_max_item_len_ptr", + "maybe_k_cache_sf", + "maybe_v_cache_sf", ] additional_tensor_dtypes = [ "uint8_t", @@ -1009,6 +1017,8 @@ def gen_batch_prefill_module( "uint32_t", "uint16_t", "uint16_t", + 
"uint8_t", + "uint8_t", ] # NOTE(Zihao): int32_t should follow dtype_idx additional_scalar_names = [ "logits_soft_cap", @@ -1149,8 +1159,8 @@ def gen_batch_attention_module( use_profiler, ) - additional_tensor_names: List[str] = [] - additional_tensor_dtypes: List[str] = [] + additional_tensor_names: List[str] = ["maybe_k_cache_sf", "maybe_v_cache_sf"] + additional_tensor_dtypes: List[str] = ["uint8_t", "uint8_t"] additional_scalar_names: List[str] = [] additional_scalar_dtypes: List[str] = [] variant_name = f"StandardAttention<{str(use_logits_soft_cap).lower()}>" @@ -1221,7 +1231,7 @@ def gen_customize_single_decode_module( "variant_decl": variant_decl, "variant_name": variant_name, "dtype_q": dtype_map[dtype_q], - "dtype_kv": dtype_map[dtype_kv], + "dtype_kv": dtype_map_kv[dtype_kv], "dtype_o": dtype_map[dtype_o], "head_dim_qk": head_dim_qk, "head_dim_vo": head_dim_vo, @@ -1286,7 +1296,7 @@ def gen_customize_single_prefill_module( "variant_decl": variant_decl, "variant_name": variant_name, "dtype_q": dtype_map[dtype_q], - "dtype_kv": dtype_map[dtype_kv], + "dtype_kv": dtype_map_kv[dtype_kv], "dtype_o": dtype_map[dtype_o], "head_dim_qk": head_dim_qk, "head_dim_vo": head_dim_vo, @@ -1461,7 +1471,7 @@ def gen_customize_batch_decode_module( "variant_decl": variant_decl, "variant_name": variant_name, "dtype_q": dtype_map[dtype_q], - "dtype_kv": dtype_map[dtype_kv], + "dtype_kv": dtype_map_kv[dtype_kv], "dtype_o": dtype_map[dtype_o], "idtype": dtype_map[idtype], "head_dim_qk": head_dim_qk, @@ -1531,7 +1541,7 @@ def gen_customize_batch_prefill_module( "variant_decl": variant_decl, "variant_name": variant_name, "dtype_q": dtype_map[dtype_q], - "dtype_kv": dtype_map[dtype_kv], + "dtype_kv": dtype_map_kv[dtype_kv], "dtype_o": dtype_map[dtype_o], "idtype": dtype_map[idtype], "head_dim_qk": head_dim_qk, @@ -1819,7 +1829,7 @@ def gen_customize_batch_attention_module( "variant_decl": variant_decl, "variant_name": variant_name, "dtype_q": dtype_map[dtype_q], - "dtype_kv": 
dtype_map[dtype_kv], + "dtype_kv": dtype_map_kv[dtype_kv], "dtype_o": dtype_map[dtype_o], "idtype": dtype_map[idtype], "head_dim_qk": head_dim_qk, @@ -1828,13 +1838,26 @@ def gen_customize_batch_attention_module( "use_logits_soft_cap": str(use_logits_soft_cap).lower(), } gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri - (additional_params_decl, additional_func_params, additional_params_setter) = ( - generate_additional_params( - additional_tensor_names, - additional_tensor_dtypes, - additional_scalar_names, - additional_scalar_dtypes, - ) + (additional_params_decl, additional_func_params, _) = generate_additional_params( + additional_tensor_names, + additional_tensor_dtypes, + additional_scalar_names, + additional_scalar_dtypes, + ) + # batch_attention.cu loops over params[i], so generate a setter using params[i] syntax + # instead of the params.X syntax from generate_additional_params. + batch_additional_params_setter = " \\\n".join( + [ + ( + f"params[i].{var} = {var} ? static_cast<{dtype}*>({var}.value().data_ptr()): nullptr;" + if var.startswith("maybe") + else f"params[i].{var} = static_cast<{dtype}*>({var}.data_ptr());" + ) + for dtype, var in zip( + additional_tensor_dtypes, additional_tensor_names, strict=True + ) + ] + + [f"params[i].{var} = {var};" for var in additional_scalar_names] ) with open( jit_env.FLASHINFER_CSRC_DIR / "batch_attention_customize_config.jinja" @@ -1849,7 +1872,7 @@ def gen_customize_batch_attention_module( kwargs |= { "additional_params_decl": additional_params_decl, "additional_func_params": additional_func_params, - "additional_params_setter": additional_params_setter, + "additional_params_setter": batch_additional_params_setter, } generated_inc_str = config_templ.render( diff --git a/flashinfer/jit/utils.py b/flashinfer/jit/utils.py index 4e19212e14..befd37d867 100644 --- a/flashinfer/jit/utils.py +++ b/flashinfer/jit/utils.py @@ -43,6 +43,16 @@ def write_if_different(path: pathlib.Path, content: str) -> None: torch.uint64: 
"uint64_t", } +dtype_map_kv = { + torch.float16: "half", + torch.bfloat16: "nv_bfloat16", + torch.float8_e4m3fn: "__nv_fp8_e4m3", + torch.float8_e5m2: "__nv_fp8_e5m2", + torch.uint8: "__nv_fp4x2_e2m1", +} +if hasattr(torch, "float4_e2m1fn_x2"): + dtype_map_kv[torch.float4_e2m1fn_x2] = "__nv_fp4x2_e2m1" + dtype_cutlass_map = { torch.float16: "cutlass::half_t", torch.bfloat16: "cutlass::bfloat16_t", @@ -68,6 +78,8 @@ def write_if_different(path: pathlib.Path, content: str) -> None: torch.int64: "i64", torch.uint64: "u64", } +if hasattr(torch, "float4_e2m1fn_x2"): + filename_safe_dtype_map[torch.float4_e2m1fn_x2] = "fp4_e2m1" pos_encoding_mode_literal = { 0: "PosEncodingMode::kNone", diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index 4ec6a29e7d..52da91f9a8 100755 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -356,6 +356,8 @@ def run_single_prefill( scale_v: Optional[torch.Tensor], rope_scale: float, rope_theta: float, + maybe_k_cache_sf: Optional[torch.Tensor] = None, + maybe_v_cache_sf: Optional[torch.Tensor] = None, ) -> None: if backend == "fa3": scale_v_tensor, scale_v_scalar = _split_scale_param(scale_v) @@ -410,6 +412,8 @@ def run_single_prefill( window_left, maybe_packed_custom_mask, maybe_alibi_slopes, + maybe_k_cache_sf, + maybe_v_cache_sf, logits_soft_cap, sm_scale, 1.0 / rope_scale, # rope_rcp_scale @@ -434,6 +438,8 @@ def _fake_run_single_prefill( sm_scale: float, rope_scale: float, rope_theta: float, + maybe_k_cache_sf: Optional[torch.Tensor] = None, + maybe_v_cache_sf: Optional[torch.Tensor] = None, ) -> None: pass @@ -493,6 +499,8 @@ def ragged_run( rope_scale: float, rope_theta: float, token_pos_in_items_len: int, + maybe_k_cache_sf: Optional[torch.Tensor] = None, + maybe_v_cache_sf: Optional[torch.Tensor] = None, scale_q: Optional[torch.Tensor] = None, scale_k: Optional[torch.Tensor] = None, scale_v: Optional[torch.Tensor] = None, @@ -522,6 +530,8 @@ def ragged_run( maybe_prefix_len_ptr, maybe_token_pos_in_items_ptr, 
maybe_max_item_len_ptr, + maybe_k_cache_sf, + maybe_v_cache_sf, logits_soft_cap, sm_scale, 1.0 / rope_scale, # rope_rcp_scale @@ -614,6 +624,8 @@ def _fake_ragged_run( rope_scale: float, rope_theta: float, token_pos_in_items_len: int, + maybe_k_cache_sf: Optional[torch.Tensor] = None, + maybe_v_cache_sf: Optional[torch.Tensor] = None, ) -> None: pass @@ -740,6 +752,8 @@ def paged_run( maybe_prefix_len_ptr, maybe_token_pos_in_items_ptr, maybe_max_item_len_ptr, + key_block_scales, + value_block_scales, logits_soft_cap, sm_scale, 1.0 / rope_scale, # rope_rcp_scale @@ -1071,6 +1085,9 @@ def single_prefill_with_kv_cache( rope_theta: Optional[float] = None, backend: str = "auto", return_lse: Literal[False] = False, + kv_cache_sf: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + k_scale: Optional[float] = None, + v_scale: Optional[float] = None, ) -> torch.Tensor: ... @@ -1096,6 +1113,9 @@ def single_prefill_with_kv_cache( rope_theta: Optional[float] = None, backend: str = "auto", return_lse: Literal[True] = True, + kv_cache_sf: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + k_scale: Optional[float] = None, + v_scale: Optional[float] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: ... @@ -1121,6 +1141,9 @@ def single_prefill_with_kv_cache( rope_theta: Optional[float] = None, backend: str = "auto", return_lse: bool = False, + kv_cache_sf: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + k_scale: Optional[float] = None, + v_scale: Optional[float] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: r"""Prefill/Append attention with KV cache for single request, return the attention output. @@ -1194,6 +1217,15 @@ def single_prefill_with_kv_cache( device architecture and kernel availability. return_lse : bool Whether to return the log sum exp value of the attention logits. + kv_cache_sf : Optional[Tuple[torch.Tensor, torch.Tensor]] + Per-block scale factors for NVFP4 KV cache, as a tuple of ``(k_scales, v_scales)``. 
+ When provided, ``k`` and ``v`` are expected to be packed uint8 FP4 tensors with last + dimension ``head_dim // 2``, and the scale factors dequantize them before attention. + Both ``k_scales`` and ``v_scales`` use a linear (row-major) layout, and both have dtype ``torch.float8_e4m3fn``. + k_scale : Optional[Union[float, torch.Tensor]] + The calibration scale of key for fp8 or nvfp4 input, if not provided, will be set to ``1.0``. + v_scale : Optional[Union[float, torch.Tensor]] + The calibration scale of value for fp8 or nvfp4 input, if not provided, will be set to ``1.0``. Returns ------- @@ -1296,10 +1328,20 @@ def single_prefill_with_kv_cache( k.dtype, ) + # Unpack NVFP4 scale factors + k_sf, v_sf = None, None + if kv_cache_sf is not None: + k_sf, v_sf = kv_cache_sf + + if k_scale is not None: + sm_scale *= k_scale + # o_dtype should be provided for FP8 attention if o_dtype is None: o_dtype = q.dtype - out = torch.empty(q.shape[:-1] + v.shape[-1:], dtype=o_dtype, device=q.device) + # For NVFP4 KV (uint8 packed), last dim is head_dim//2; output uses q head_dim + out_head_dim = q.shape[-1] if kv_cache_sf is not None else v.shape[-1] + out = torch.empty(q.shape[:-1] + (out_head_dim,), dtype=o_dtype, device=q.device) module = get_single_prefill_module( backend, @@ -1307,7 +1349,7 @@ def single_prefill_with_kv_cache( k.dtype, out.dtype, q.shape[-1], # head_dim_qk - v.shape[-1], # head_dim_vo + out_head_dim, # head_dim_vo PosEncodingMode[pos_encoding_mode].value, window_left >= 0, # use_sliding_window logits_soft_cap > 0, # use_logits_soft_cap @@ -1335,8 +1377,17 @@ def single_prefill_with_kv_cache( scale_v, rope_scale, rope_theta, + k_sf, + v_sf, ) + is_float_one = isinstance(v_scale, float) and v_scale == 1.0 + if v_scale is not None and not is_float_one: + if is_float8(out): + out = (out.to(torch.float32) * v_scale).to(out.dtype) + else: + out *= v_scale + return (out, lse) if return_lse else out @@ -2110,7 +2161,9 @@ def run( enable_pdl: Optional[bool] = None, 
window_left: Optional[int] = None, sinks: Optional[torch.Tensor] = None, - kv_cache_sf: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + kv_cache_sf: Optional[ + Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] + ] = None, skip_softmax_threshold_scale_factor: Optional[float] = None, ) -> torch.Tensor: ... @@ -2128,7 +2181,9 @@ def run( enable_pdl: Optional[bool] = None, window_left: Optional[int] = None, sinks: Optional[torch.Tensor] = None, - kv_cache_sf: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + kv_cache_sf: Optional[ + Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] + ] = None, skip_softmax_threshold_scale_factor: Optional[float] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: ... @@ -2147,7 +2202,9 @@ def run( enable_pdl: Optional[bool] = None, window_left: Optional[int] = None, sinks: Optional[torch.Tensor] = None, - kv_cache_sf: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + kv_cache_sf: Optional[ + Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] + ] = None, skip_softmax_threshold_scale_factor: Optional[float] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: r"""Compute batch prefill/append attention between query and paged kv-cache. @@ -2175,9 +2232,9 @@ def run( q_scale : Optional[Union[float, torch.Tensor]] The calibration scale of query for fp8 input, if not provided, will be set to ``1.0``. k_scale : Optional[Union[float, torch.Tensor]] - The calibration scale of key for fp8 input, if not provided, will be set to ``1.0``. + The calibration scale of key for fp8 or nvfp4 input, if not provided, will be set to ``1.0``. v_scale : Optional[Union[float, torch.Tensor]] - The calibration scale of value for fp8 input, if not provided, will be set to ``1.0``. + The calibration scale of value for fp8 or nvfp4 input, if not provided, will be set to ``1.0``. out : Optional[torch.Tensor] The output tensor, if not provided, will be allocated internally. 
lse : Optional[torch.Tensor] @@ -2187,16 +2244,21 @@ def run( enable_pdl : bool Whether to enable Programmatic Dependent Launch (PDL). See https://docs.nvidia.com/cuda/cuda-c-programming-guide/#programmatic-dependent-launch-and-synchronization Only supported for >= sm90, and currently only for FA2 and CUDA core decode. - kv_cache_sf : Optional[Tuple[torch.Tensor, torch.Tensor]] - Per-block scale factors for NVFP4 KV cache, as a tuple of ``(k_scales, v_scales)``. - Scale tensors must follow the same :attr:`kv_layout` as the KV cache: + kv_cache_sf : Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] + Per-block scale factors for NVFP4 KV cache. Accepts the same formats as + ``paged_kv_cache``: - * **HND**: ``[num_pages, num_kv_heads, page_size, head_dim // 16]`` - * **NHD**: ``[num_pages, page_size, num_kv_heads, head_dim // 16]`` + * a tuple ``(k_scales, v_scales)`` of 4-D tensors, each with shape: + ``[num_pages, page_size, num_kv_heads, head_dim // 16]`` if :attr:`kv_layout` is ``NHD``, + and ``[num_pages, num_kv_heads, page_size, head_dim // 16]`` if :attr:`kv_layout` is ``HND``. + * a single 5-D tensor with shape: + ``[num_pages, 2, page_size, num_kv_heads, head_dim // 16]`` if :attr:`kv_layout` is ``NHD``, + and ``[num_pages, 2, num_kv_heads, page_size, head_dim // 16]`` if :attr:`kv_layout` is ``HND``, + where dim 1 holds k (index 0) and v (index 1) scales. Both tensors have dtype ``torch.float8_e4m3fn``. ``k_scales`` uses a linear (row-major) layout, while ``v_scales`` must use TRT-LLM's 4-token interleaved - layout within each ``[page_size, head_dim // 16]`` tile. Use + layout within each ``[page_size, head_dim // 16]`` tile if backend is `trtllm-gen`. Use :func:`flashinfer.fp4_quantization.nvfp4_quantize_paged_kv_cache` to produce correctly formatted scale factors. 
@@ -2237,18 +2299,11 @@ def run( k_cache.dtype == torch.uint8 or v_cache.dtype == torch.uint8 ) and kv_cache_sf is None: raise ValueError("kv_cache_sf must be provided for NVFP4 KV cache.") - key_block_scales = None - value_block_scales = None - if kv_cache_sf is not None: - if ( - not isinstance(kv_cache_sf, (tuple, list)) - or len(kv_cache_sf) != 2 - or not all(torch.is_tensor(x) for x in kv_cache_sf) - ): - raise TypeError( - "kv_cache_sf must be a tuple/list of two tensors: (k_scales, v_scales)." - ) - key_block_scales, value_block_scales = kv_cache_sf + key_block_scales, value_block_scales = ( + _unpack_paged_kv_cache(kv_cache_sf, self._kv_layout) + if kv_cache_sf is not None + else (None, None) + ) o_dtype = self._cached_o_data_type if out is not None and out.dtype != o_dtype: @@ -3171,6 +3226,9 @@ def run( lse: Optional[torch.Tensor] = None, return_lse: Literal[False] = False, enable_pdl: Optional[bool] = None, + kv_cache_sf: Optional[ + Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] + ] = None, ) -> torch.Tensor: ... @overload @@ -3184,6 +3242,9 @@ def run( lse: Optional[torch.Tensor] = None, return_lse: Literal[True] = True, enable_pdl: Optional[bool] = None, + kv_cache_sf: Optional[ + Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] + ] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: ... @flashinfer_api @@ -3201,6 +3262,9 @@ def run( lse: Optional[torch.Tensor] = None, return_lse: bool = False, enable_pdl: Optional[bool] = None, + kv_cache_sf: Optional[ + Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] + ] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: r"""Compute batch prefill/append attention between query and kv-cache stored as ragged tensor. @@ -3218,9 +3282,9 @@ def run( q_scale: Optional[float] The calibration scale of fp8 query, if not provided, will be set to ``1.0``. k_scale: Optional[float] - The calibration scale of fp8 key, if not provided, will be set to ``1.0``. 
+ The calibration scale of fp8 or nvfp4 key, if not provided, will be set to ``1.0``. v_scale: Optional[float] - The calibration scale of fp8 value, if not provided, will be set to ``1.0``. + The calibration scale of fp8 or nvfp4 value, if not provided, will be set to ``1.0``. o_scale: Optional[float] The calibration scale of output, if not provided, will be set to ``1.0``. out : Optional[torch.Tensor] @@ -3271,6 +3335,12 @@ def run( logits_soft_cap = 0.0 if sm_scale is None: sm_scale = 1.0 / math.sqrt(q.size(-1)) + # For NVFP4 KV, fuse q_scale and k_scale into sm_scale + if kv_cache_sf is not None: + if q_scale is not None: + sm_scale *= q_scale + if k_scale is not None: + sm_scale *= k_scale if rope_scale is None: rope_scale = 1.0 if rope_theta is None: @@ -3284,18 +3354,28 @@ def run( check_shape_dtype_device( lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse" ) + # Unpack kv_cache_sf for NVFP4 ragged KV + k_sf, v_sf = None, None + if kv_cache_sf is not None: + if isinstance(kv_cache_sf, tuple): + k_sf, v_sf = kv_cache_sf + else: + k_sf, v_sf = kv_cache_sf.unbind(dim=1) + + # For NVFP4 KV (uint8 packed), v last dim is head_dim//2; use q head_dim for output + out_head_dim = q.shape[-1] if kv_cache_sf is not None else v.shape[-1] if out is None: # when input dtype is fp8, we need to use bf16 output out_dtype = torch.bfloat16 if q.dtype.itemsize == 1 else q.dtype out = torch.empty( - q.shape[:-1] + v.shape[-1:], + q.shape[:-1] + (out_head_dim,), dtype=out_dtype, device=q.device, ) else: check_shape_dtype_device( out, - q.shape[:-1] + v.shape[-1:], + q.shape[:-1] + (out_head_dim,), self._cached_o_data_type, q.device, "out", @@ -3430,6 +3510,8 @@ def run( rope_scale, rope_theta, self._token_pos_in_items_len, + k_sf, + v_sf, ] # For FP8, append scale tensors if is_float8(q): @@ -3437,6 +3519,12 @@ def run( assert self._cached_module is not None, "cached module is not initialized" self._cached_module.ragged_run(*run_args) + + # Apply V scaling for NVFP4 ragged 
KV if v_scale is provided and not equal to 1.0 + is_float_one = isinstance(v_scale, float) and v_scale == 1.0 + if kv_cache_sf is not None and v_scale is not None and not is_float_one: + out *= v_scale + return (out, lse) if return_lse else out run_return_lse = functools.partialmethod(run, return_lse=True) @@ -3861,7 +3949,9 @@ def trtllm_batch_context_with_kv_cache( kv_layout: str = "HND", enable_pdl: Optional[bool] = None, sinks: Optional[List[torch.Tensor]] = None, - kv_cache_sf: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + kv_cache_sf: Optional[ + Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] + ] = None, skip_softmax_threshold_scale_factor: Optional[float] = None, uses_shared_paged_kv_idx: bool = True, ) -> Union[torch.Tensor, FP4Tensor]: @@ -3928,12 +4018,15 @@ def trtllm_batch_context_with_kv_cache( data copy overhead. Use ``HND`` for better performance. sinks : Optional[List[torch.Tensor]] = None additional value per head in the denominator of the softmax. - kv_cache_sf : Optional[Tuple[torch.Tensor, torch.Tensor]] = None - Per-block scale factors for NVFP4 KV cache, as a tuple of ``(k_scales, v_scales)``. - Scale tensors must follow the same :attr:`kv_layout` as the KV cache: + kv_cache_sf : Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None + Per-block scale factors for NVFP4 KV cache. Accepts the same formats as + ``kv_cache``: - * **HND**: ``[num_pages, num_kv_heads, page_size, head_dim // 16]`` - * **NHD**: ``[num_pages, page_size, num_kv_heads, head_dim // 16]`` + * a tuple ``(k_scales, v_scales)`` of 4-D tensors, each following + :attr:`kv_layout`: ``[num_pages, num_kv_heads, page_size, head_dim // 16]`` for + HND, or ``[num_pages, page_size, num_kv_heads, head_dim // 16]`` for NHD. + * a single 5-D tensor with shape ``[num_pages, 2, ...]`` matching the layout of + ``kv_cache``, split on dim 1 to yield k (index 0) and v (index 1) scales. Both tensors have dtype ``torch.float8_e4m3fn``. 
``k_scales`` uses a linear (row-major) layout, while ``v_scales`` must use TRT-LLM's 4-token interleaved @@ -3987,18 +4080,11 @@ def trtllm_batch_context_with_kv_cache( k_cache.dtype == torch.uint8 or v_cache.dtype == torch.uint8 ) and kv_cache_sf is None: raise ValueError("kv_cache_sf must be provided for NVFP4 KV cache.") - key_block_scales = None - value_block_scales = None - if kv_cache_sf is not None: - if ( - not isinstance(kv_cache_sf, (tuple, list)) - or len(kv_cache_sf) != 2 - or not all(torch.is_tensor(x) for x in kv_cache_sf) - ): - raise TypeError( - "kv_cache_sf must be a tuple/list of two tensors: (k_scales, v_scales)." - ) - key_block_scales, value_block_scales = kv_cache_sf + key_block_scales, value_block_scales = ( + _unpack_paged_kv_cache(kv_cache_sf, kv_layout) + if kv_cache_sf is not None + else (None, None) + ) # Convert NHD layout to HND if necessary if kv_layout == "NHD": diff --git a/flashinfer/quantization/fp4_quantization.py b/flashinfer/quantization/fp4_quantization.py index 4cd5cd34f3..d0068535d7 100644 --- a/flashinfer/quantization/fp4_quantization.py +++ b/flashinfer/quantization/fp4_quantization.py @@ -83,6 +83,68 @@ def _pad_scale_factors( ).contiguous() +# E2M1 lookup table: 16 possible 4-bit values (index = 4-bit code, value = float) +# Format: bit3=sign, bits2-0=magnitude (exponent+mantissa) +_E2M1_VALUES = torch.tensor( + [0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6], + dtype=torch.float32, +) +_e2m1_values_cache: dict = {} + + +def _get_e2m1_values(device: torch.device) -> torch.Tensor: + if device not in _e2m1_values_cache: + _e2m1_values_cache[device] = _E2M1_VALUES.to(device) + return _e2m1_values_cache[device] + + +def _e2m1_and_ufp8sf_scale_to_float_cpu( + e2m1_tensor: torch.Tensor, + ufp8_scale_tensor: torch.Tensor, + global_scale_tensor: Optional[torch.Tensor], + sf_vec_size: int, + ufp8_type: int, + is_sf_swizzled_layout: bool, +) -> torch.Tensor: + """Pure-PyTorch CPU-compatible dequantization 
fallback for arch < SM90. + + Only supports is_sf_swizzled_layout=False (linear SF layout). + """ + if is_sf_swizzled_layout: + raise NotImplementedError( + "CPU fallback for e2m1_and_ufp8sf_scale_to_float does not support " + "swizzled SF layout. Use a GPU with SM90+ for swizzled layout support." + ) + + device = e2m1_tensor.device + m, k_half = e2m1_tensor.shape + k = k_half * 2 + + # Unpack two E2M1 nibbles per byte: low nibble = even indices, high nibble = odd + fp4_vals = torch.empty(m, k, dtype=torch.uint8, device=device) + fp4_vals[:, 0::2] = e2m1_tensor & 0x0F + fp4_vals[:, 1::2] = (e2m1_tensor >> 4) & 0x0F + + # Map 4-bit codes to float via LUT + float_vals = _get_e2m1_values(device)[fp4_vals.long()] # [M, K] + + # Decode UFP8 scale factors + if ufp8_type == 1: + # E4M3: interpret raw bytes as float8_e4m3fn + sf_float = ufp8_scale_tensor.view(torch.float8_e4m3fn).float() + else: + # UE8M0: 2^(byte - 127) + sf_float = torch.pow(2.0, ufp8_scale_tensor.float() - 127.0) + + # Broadcast each SF over its sf_vec_size consecutive FP4 elements + sf_expanded = sf_float.repeat_interleave(sf_vec_size, dim=-1) # [M, K] + + # Apply global scale + gs = global_scale_tensor.float().item() if global_scale_tensor is not None else 1.0 + + return (float_vals * sf_expanded * gs).float() + + def gen_fp4_quantization_sm100_module() -> JitSpec: return gen_fp4_quantization_module(sm100a_nvcc_flags, "100") @@ -875,6 +937,16 @@ def e2m1_and_ufp8sf_scale_to_float( major, minor = get_compute_capability( torch.device("cuda:0") ) # select any cuda device to get a compute capability + if major * 10 + minor < 90: + # No kernel available; use pure-PyTorch fallback + return _e2m1_and_ufp8sf_scale_to_float_cpu( + e2m1_tensor, + ufp8_scale_tensor, + global_scale_tensor, + sf_vec_size, + ufp8_type, + is_sf_swizzled_layout, + ) device_arch = f"{major * 10 + minor}" return get_fp4_quantization_module( device_arch diff --git a/flashinfer/utils.py b/flashinfer/utils.py index 986702fa3b..09c0412b66 
100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -426,6 +426,9 @@ def is_fa3_backend_supported( torch.float8_e5m2, }: return False + # FA3 does not support NVFP4 KV cache (uint8 packed FP4). + if dtype_kv == torch.uint8: + return False return True diff --git a/include/flashinfer/attention/persistent.cuh b/include/flashinfer/attention/persistent.cuh index 86c65e1843..b489b0bc5f 100644 --- a/include/flashinfer/attention/persistent.cuh +++ b/include/flashinfer/attention/persistent.cuh @@ -61,9 +61,11 @@ __device__ __forceinline__ void prefetch_offest( lane_idx / KV_THR_LAYOUT_COL + KV_THR_LAYOUT_ROW * NUM_WARPS_Q * NUM_WARPS_KV * i; block_size.divmod(packed_block_iter, page_iter, entry_idx); + // FP4: GMEM is packed (2 FP4/byte), so the column byte offset is halved relative to fp8 + constexpr uint32_t fp4_pack_factor = is_fp4_type_v ? 2 : 1; kv_offset[i] = (packed_block_iter < packed_kv_bound ? indices[page_iter] : 0) * kv_stride_page + entry_idx * kv_stride_n + kv_head_idx * kv_stride_h + - (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(); + (lane_idx % KV_THR_LAYOUT_COL) * upcast_size() / fp4_pack_factor; } } @@ -210,6 +212,14 @@ struct BlockBatchPagedAttentionPersistent { DTypeKV* k = params.k; DTypeKV* v = params.v; IdType* kv_indices = params.kv_indices; + uint8_t* maybe_k_cache_sf = nullptr; + if constexpr (has_maybe_k_cache_sf_v) { + maybe_k_cache_sf = params.maybe_k_cache_sf; + } + uint8_t* maybe_v_cache_sf = nullptr; + if constexpr (has_maybe_v_cache_sf_v) { + maybe_v_cache_sf = params.maybe_v_cache_sf; + } float* partial_lse = params.partial_lse; IdType* work_indptr = params.work_indptr; @@ -316,10 +326,18 @@ struct BlockBatchPagedAttentionPersistent { page_produce_kv(smem_storage, &k_smem_offset_w, k, kv_start + kv_tile_idx * CTA_TILE_KV, thr_local_kv_offset, kv_end, warp_idx, lane_idx); + page_produce_kv_sf( + smem_storage, maybe_k_cache_sf, block_iter_base + kv_tile_idx * CTA_TILE_KV, + packed_kv_bound, kv_head_idx, k_stride_page, k_stride_h, 
k_stride_n, block_size, + kv_indices, kv_start + kv_tile_idx * CTA_TILE_KV, kv_end, warp_idx, lane_idx); cp_async::commit_group(); page_produce_kv(smem_storage, &v_smem_offset_w, v, kv_start + kv_tile_idx * CTA_TILE_KV, thr_local_kv_offset, kv_end, warp_idx, lane_idx); + page_produce_kv_sf( + smem_storage, maybe_v_cache_sf, block_iter_base + kv_tile_idx * CTA_TILE_KV, + packed_kv_bound, kv_head_idx, v_stride_page, v_stride_h, v_stride_n, block_size, + kv_indices, kv_start + kv_tile_idx * CTA_TILE_KV, kv_end, warp_idx, lane_idx); cp_async::commit_group(); // loop with mask @@ -332,7 +350,11 @@ struct BlockBatchPagedAttentionPersistent { cp_async::wait_group<1>(); __syncthreads(); - compute_qk(&q_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); + compute_qk(&q_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, + smem_storage->k_sf_smem + get_warp_idx_kv(tid.z) * + KTraits::NUM_MMA_KV * 16 * + KTraits::NUM_MMA_D_QK, + lane_idx, s_frag); if constexpr (AttentionVariant::use_logits_soft_cap) { logits_transform( params, variant, /*batch_idx=*/0, qo_packed_idx_base, @@ -353,16 +375,28 @@ struct BlockBatchPagedAttentionPersistent { page_produce_kv(smem_storage, &k_smem_offset_w, k, kv_start + (kv_tile_idx - 1) * CTA_TILE_KV, thr_local_kv_offset, kv_end, warp_idx, lane_idx); + page_produce_kv_sf( + smem_storage, maybe_k_cache_sf, block_iter_base + (kv_tile_idx - 1) * CTA_TILE_KV, + packed_kv_bound, kv_head_idx, k_stride_page, k_stride_h, k_stride_n, block_size, + kv_indices, kv_start + (kv_tile_idx - 1) * CTA_TILE_KV, kv_end, warp_idx, lane_idx); cp_async::commit_group(); cp_async::wait_group<1>(); __syncthreads(); - compute_sfm_v(&v_smem, &v_smem_offset_r, s_frag, o_frag, d); + compute_sfm_v(&v_smem, &v_smem_offset_r, + smem_storage->v_sf_smem + get_warp_idx_kv(tid.z) * + KTraits::NUM_MMA_KV * 16 * + KTraits::NUM_MMA_D_VO, + lane_idx, s_frag, o_frag, d); __syncthreads(); page_produce_kv(smem_storage, &v_smem_offset_w, v, kv_start + (kv_tile_idx - 1) * 
CTA_TILE_KV, thr_local_kv_offset, kv_end, warp_idx, lane_idx); + page_produce_kv_sf( + smem_storage, maybe_v_cache_sf, block_iter_base + (kv_tile_idx - 1) * CTA_TILE_KV, + packed_kv_bound, kv_head_idx, v_stride_page, v_stride_h, v_stride_n, block_size, + kv_indices, kv_start + (kv_tile_idx - 1) * CTA_TILE_KV, kv_end, warp_idx, lane_idx); cp_async::commit_group(); }); cp_async::wait_group<0>(); @@ -370,7 +404,11 @@ struct BlockBatchPagedAttentionPersistent { #pragma unroll for (; kv_tile_idx >= 0; --kv_tile_idx) { - compute_qk(&q_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); + compute_qk(&q_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, + smem_storage->k_sf_smem + get_warp_idx_kv(tid.z) * + KTraits::NUM_MMA_KV * 16 * + KTraits::NUM_MMA_D_QK, + lane_idx, s_frag); if constexpr (AttentionVariant::use_logits_soft_cap) { logits_transform( params, variant, /*batch_idx=*/0, qo_packed_idx_base, @@ -384,7 +422,11 @@ struct BlockBatchPagedAttentionPersistent { (kv_tile_idx * NUM_WARPS_KV + get_warp_idx_kv(tid.z)) * NUM_MMA_KV * 16, q_len, kv_len, kv_end, gqa_group_size, s_frag, tid, kv_head_idx); update_mdo_states(variant, s_frag, o_frag, m, d); - compute_sfm_v(&v_smem, &v_smem_offset_r, s_frag, o_frag, d); + compute_sfm_v(&v_smem, &v_smem_offset_r, + smem_storage->v_sf_smem + get_warp_idx_kv(tid.z) * + KTraits::NUM_MMA_KV * 16 * + KTraits::NUM_MMA_D_VO, + lane_idx, s_frag, o_frag, d); } __syncthreads(); diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index 5db013bb03..3e8f0e9b6f 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -20,6 +20,9 @@ #include #include #include +#if CUDA_VERSION >= 12080 +#include +#endif #include #include "../cp_async.cuh" @@ -45,12 +48,26 @@ DEFINE_HAS_MEMBER(maybe_prefix_len_ptr) DEFINE_HAS_MEMBER(maybe_token_pos_in_items_ptr) DEFINE_HAS_MEMBER(token_pos_in_items_len) DEFINE_HAS_MEMBER(maybe_max_item_len_ptr) 
+DEFINE_HAS_MEMBER(maybe_k_cache_sf) +DEFINE_HAS_MEMBER(maybe_v_cache_sf) + +// Type trait to detect packed NVFP4 KV cache types (__nv_fp4x2_e2m1 stores 2 FP4 per byte). +template +struct is_fp4_type : std::false_type {}; +#if CUDA_VERSION >= 12080 +template <> +struct is_fp4_type<__nv_fp4x2_e2m1> : std::true_type {}; +#endif +template +inline constexpr bool is_fp4_type_v = is_fp4_type::value; namespace cg = cooperative_groups; using cp_async::SharedMemFillMode; using mma::MMAMode; constexpr uint32_t WARP_SIZE = 32; +// Number of NVFP4 elements sharing one scale factor (UE4M3 byte). +constexpr uint32_t NVFP4_SF_VEC_SIZE = 16; constexpr uint32_t get_num_warps_q(const uint32_t cta_tile_q) { if (cta_tile_q > 16) { @@ -90,6 +107,14 @@ struct SharedStorageQKVO { }; alignas(16) DTypeO smem_o[CTA_TILE_Q * HEAD_DIM_VO]; }; + // Scale factors for NVFP4 KV cache: one UE4M3 byte per NVFP4_SF_VEC_SIZE elements. + // Sized to 1 when DTypeKV is not FP4 to avoid wasting shared memory. + alignas(16) std::conditional_t, + uint8_t[CTA_TILE_KV * HEAD_DIM_QK / NVFP4_SF_VEC_SIZE], + uint8_t[1]> k_sf_smem; + alignas(16) std::conditional_t, + uint8_t[CTA_TILE_KV * HEAD_DIM_VO / NVFP4_SF_VEC_SIZE], + uint8_t[1]> v_sf_smem; }; template smem const uint32_t kv_len, const dim3 tid = threadIdx) { // NOTE: for fp8, this function doesn't work for head_dim = 64 at the moment using DTypeKV = typename KTraits::DTypeKV; + constexpr bool IS_FP4 = is_fp4_type_v; constexpr uint32_t CTA_TILE_KV = KTraits::CTA_TILE_KV; constexpr uint32_t NUM_WARPS = KTraits::NUM_WARPS; constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; @@ -295,15 +321,21 @@ __device__ __forceinline__ void produce_kv(smem_t smem for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) { #pragma unroll for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DTypeKV)); ++j) { - smem.template load_128b_async(*smem_offset, *gptr, kv_idx < kv_len); + // FP4 GMEM rows are packed 2x denser; load 64b (upper 64b of smem slot zeroed). 
+ if constexpr (IS_FP4) { + smem.template load_64b_async(*smem_offset, *gptr, kv_idx < kv_len); + } else { + smem.template load_128b_async(*smem_offset, *gptr, kv_idx < kv_len); + } *smem_offset = smem.template advance_offset_by_column<8>(*smem_offset, j); - *gptr += 8 * upcast_size(); + *gptr += (IS_FP4 ? 4 : 8) * upcast_size(); } kv_idx += NUM_WARPS * 4; *smem_offset = smem.template advance_offset_by_row(*smem_offset) - sizeof(DTypeKV) * NUM_MMA_D; - *gptr += NUM_WARPS * 4 * stride_n - sizeof(DTypeKV) * NUM_MMA_D * upcast_size(); + *gptr += NUM_WARPS * 4 * stride_n - + (IS_FP4 ? 4 : 8) * upcast_size() * (NUM_MMA_D / (8 / sizeof(DTypeKV))); } *smem_offset -= CTA_TILE_KV * UPCAST_STRIDE; } else { @@ -312,7 +344,12 @@ __device__ __forceinline__ void produce_kv(smem_t smem static_assert(NUM_MMA_KV * 2 % NUM_WARPS_Q == 0); #pragma unroll for (uint32_t i = 0; i < NUM_MMA_KV * 2 / NUM_WARPS_Q; ++i) { - smem.load_128b_async(*smem_offset, *gptr, kv_idx < kv_len); + // FP4 GMEM rows are packed 2x denser; load 64b (upper 64b of smem slot zeroed). + if constexpr (IS_FP4) { + smem.template load_64b_async(*smem_offset, *gptr, kv_idx < kv_len); + } else { + smem.load_128b_async(*smem_offset, *gptr, kv_idx < kv_len); + } *smem_offset = smem.template advance_offset_by_row(*smem_offset); kv_idx += NUM_WARPS * 8; @@ -342,6 +379,10 @@ __device__ __forceinline__ void page_produce_kv(typename KTraits::SharedStorage* constexpr uint32_t NUM_MMA_D = produce_v ? KTraits::NUM_MMA_D_VO : KTraits::NUM_MMA_D_QK; constexpr uint32_t UPCAST_STRIDE = produce_v ? KTraits::UPCAST_STRIDE_V : KTraits::UPCAST_STRIDE_K; + // FP4 stores 2 elements per byte in GMEM (packed); SMEM uses 64b data + 64b zero per 128b slot. + // Use a 64b async load (cp.async with src-size=8) and advance GMEM pointer by half the normal + // amount, while SMEM addressing remains unchanged. 
+ constexpr bool IS_FP4 = is_fp4_type_v; if constexpr (KTraits::SWIZZLE_MODE_KV == SwizzleMode::k128B) { uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / 8; // NOTE: NUM_MMA_KV * 4 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 4 / num_warps @@ -351,9 +392,15 @@ __device__ __forceinline__ void page_produce_kv(typename KTraits::SharedStorage* DType* gptr = kv_ptr + thr_local_kv_offset[i]; #pragma unroll for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DType)); ++j) { - smem.load_128b_async(*smem_offset, gptr, kv_idx < kv_len); + if constexpr (IS_FP4) { + // Load 64b from packed GMEM into lower 64b of 128b SMEM slot (upper 64b zeroed) + smem.load_64b_async(*smem_offset, gptr, kv_idx < kv_len); + } else { + smem.load_128b_async(*smem_offset, gptr, kv_idx < kv_len); + } *smem_offset = smem.template advance_offset_by_column<8>(*smem_offset, j); - gptr += 8 * upcast_size(); + // FP4: GMEM row is HEAD_DIM/2 bytes wide (packed), so advance by half + gptr += (IS_FP4 ? 4 : 8) * upcast_size(); } kv_idx += NUM_WARPS * 4; *smem_offset = @@ -368,7 +415,11 @@ __device__ __forceinline__ void page_produce_kv(typename KTraits::SharedStorage* #pragma unroll for (uint32_t i = 0; i < NUM_MMA_KV * 2 / NUM_WARPS_Q; ++i) { DType* gptr = kv_ptr + thr_local_kv_offset[i]; - smem.load_128b_async(*smem_offset, gptr, kv_idx < kv_len); + if constexpr (IS_FP4) { + smem.load_64b_async(*smem_offset, gptr, kv_idx < kv_len); + } else { + smem.load_128b_async(*smem_offset, gptr, kv_idx < kv_len); + } kv_idx += NUM_WARPS * 8; *smem_offset = smem.template advance_offset_by_row(*smem_offset); @@ -377,6 +428,164 @@ __device__ __forceinline__ void page_produce_kv(typename KTraits::SharedStorage* } } +/*! + * \brief Load NVFP4 KV scale-factors for one CTA tile into shared memory. 
+ * + * Uses a fixed thread mapping independent of KV swizzle mode: each thread + * (thread_id = warp_idx * 32 + lane_idx) issues a 32-bit LDGSTS to load 4 consecutive + * SF bytes per iteration, advancing by NUM_WARPS * 128 bytes across iterations. + * The SF smem layout is a plain flat byte array — no swizzle. + * + * SF strides are KV byte strides divided by SF_CONTAINERS (= NVFP4_SF_VEC_SIZE/2 = 8), + * which is exact because all NVFP4-compatible head_dims are divisible by 16. + * No-op when KTraits::DTypeKV is not FP4. + * + * \tparam produce_v true → fill v_sf_smem, false → fill k_sf_smem. + * \tparam KTraits Kernel traits type. + * \tparam IdType Page index type (deduced from indices). + * \param smem_storage Shared storage holding k_sf_smem / v_sf_smem. + * \param sf_ptr Base pointer to the flat uint8_t SF array (K or V). + * \param packed_page_iter_base Packed page-iter for the start of this CTA tile. + * \param packed_kv_bound Upper bound for valid packed page-iters (last_indptr * page_size). + * \param kv_head_idx KV head index. + * \param kv_stride_page Byte stride per page in the KV tensor. + * \param kv_stride_h Byte stride per head in the KV tensor. + * \param kv_stride_n Byte stride per token in the KV tensor. + * \param page_size Page size (fast divisor). + * \param indices Page index array. + * \param kv_idx_base First KV row index for this tile within the chunk. + * \param kv_len Chunk size; rows at or beyond this are not loaded. + * \param warp_idx Global warp index within the CTA. + * \param lane_idx Lane index within the warp. 
+ */ +template +__device__ __forceinline__ void page_produce_kv_sf( + typename KTraits::SharedStorage* smem_storage, uint8_t* sf_ptr, + const uint32_t packed_page_iter_base, const uint32_t packed_kv_bound, + const uint32_t kv_head_idx, const uint32_t kv_stride_page, const uint32_t kv_stride_h, + const uint32_t kv_stride_n, const uint_fastdiv& page_size, const IdType* indices, + const uint32_t kv_idx_base, const uint32_t kv_len, const uint32_t warp_idx, + const uint32_t lane_idx) { + if constexpr (!is_fp4_type_v) return; + + constexpr uint32_t HEAD_DIM = produce_v ? KTraits::HEAD_DIM_VO : KTraits::HEAD_DIM_QK; + constexpr uint32_t SF_COLS = HEAD_DIM / NVFP4_SF_VEC_SIZE; // SF bytes per KV row + constexpr uint32_t NUM_WARPS = KTraits::NUM_WARPS; + constexpr uint32_t CTA_TILE_KV = KTraits::CTA_TILE_KV; + // DTypeKV containers per SF byte: NVFP4_SF_VEC_SIZE FP4 / 2 FP4-per-container. + constexpr uint32_t SF_CONTAINERS = NVFP4_SF_VEC_SIZE / 2; // = 8 + constexpr uint32_t SF_TOTAL_BYTES = CTA_TILE_KV * SF_COLS; + static_assert(SF_TOTAL_BYTES % 4 == 0, "SF smem size must be 4-byte aligned for 32-bit LDGSTS"); + // Each thread loads 4 SF bytes (32 bits) per iteration via LDGSTS.32. + constexpr uint32_t THREADS_PER_CTA = NUM_WARPS * 32; + constexpr uint32_t NUM_SF_ITERS = (SF_TOTAL_BYTES / 4 + THREADS_PER_CTA - 1) / THREADS_PER_CTA; + + uint8_t* sf_smem = produce_v ? smem_storage->v_sf_smem : smem_storage->k_sf_smem; + const uint32_t thread_id = warp_idx * 32 + lane_idx; + +#pragma unroll + for (uint32_t k = 0; k < NUM_SF_ITERS; ++k) { + const uint32_t flat_uint32_idx = thread_id + k * THREADS_PER_CTA; + const uint32_t flat_byte = flat_uint32_idx * 4; + // sf_smem_col is 4-byte aligned: flat_byte is a multiple of 4, and SF_COLS is a power of 2 + // (HEAD_DIM / 16), so flat_byte % SF_COLS is always a multiple of 4 (or 0 when SF_COLS < 4). 
+ const uint32_t sf_smem_row = flat_byte / SF_COLS; + const uint32_t sf_smem_col = flat_byte % SF_COLS; + // For k < NUM_SF_ITERS-1, (flat_byte < SF_TOTAL_BYTES) is always true (optimized away). + const bool in_bounds = (flat_byte < SF_TOTAL_BYTES) && (kv_idx_base + sf_smem_row < kv_len); + + // SF strides are KV byte strides / SF_CONTAINERS (1 SF byte per SF_CONTAINERS KV containers). + // packed_kv_bound guards indices[] access; returns offset 0 for out-of-range rows. + uint32_t page_iter, entry_idx; + const uint32_t packed_block_iter = packed_page_iter_base + sf_smem_row; + page_size.divmod(packed_block_iter, page_iter, entry_idx); + const size_t sf_gmem_offset = + static_cast(packed_block_iter < packed_kv_bound ? indices[page_iter] : 0) * + (kv_stride_page / SF_CONTAINERS) + + kv_head_idx * (kv_stride_h / SF_CONTAINERS) + entry_idx * (kv_stride_n / SF_CONTAINERS) + + sf_smem_col; + + // V SF must zero-fill out-of-bounds entries: compute_sfm_v reads SF for all CTA_TILE_KV rows + // including padding, and 0 (softmax weight) * NaN (uninitialized SF) = NaN (IEEE 754). + // K SF can use kNoFill since NaN K scores are replaced by -inf via logits_mask before + // update_mdo_states, so they never reach the accumulator. + constexpr auto fill_mode = + produce_v ? cp_async::SharedMemFillMode::kFillZero : cp_async::SharedMemFillMode::kNoFill; + cp_async::pred_load_32b(reinterpret_cast(sf_smem + flat_byte), + reinterpret_cast(sf_ptr + sf_gmem_offset), + in_bounds); + } +} + +/*! + * \brief Load NVFP4 KV scale-factors for one CTA tile (contiguous/ragged layout). + * + * Contiguous analog of page_produce_kv_sf — no page indirection. + * kv_abs_base is the absolute first token index for this CTA tile + * (kv_indptr[request_idx] + chunk_start for ragged, chunk_start for single prefill). + * SF strides are KV byte strides / SF_CONTAINERS (exact for all valid head_dims). + * No-op when DTypeKV is not FP4. + * + * \tparam produce_v true → fill v_sf_smem, false → fill k_sf_smem. 
+ * \tparam KTraits Kernel traits type. + * \param smem_storage Shared storage holding k_sf_smem / v_sf_smem. + * \param sf_ptr Base pointer to the flat uint8_t SF array (K or V). + * \param kv_abs_base Absolute first token index for this CTA tile. + * \param kv_head_idx KV head index. + * \param kv_stride_n Byte stride per token in the KV tensor. + * \param kv_stride_h Byte stride per head in the KV tensor. + * \param kv_idx_base First KV row index for this tile within the chunk. + * \param kv_len Chunk size; rows at or beyond this are not loaded. + * \param warp_idx Global warp index within the CTA. + * \param lane_idx Lane index within the warp. + */ +template +__device__ __forceinline__ void produce_kv_sf(typename KTraits::SharedStorage* smem_storage, + uint8_t* sf_ptr, const uint32_t kv_abs_base, + const uint32_t kv_head_idx, + const uint32_t kv_stride_n, + const uint32_t kv_stride_h, + const uint32_t kv_idx_base, const uint32_t kv_len, + const uint32_t warp_idx, const uint32_t lane_idx) { + if constexpr (!is_fp4_type_v) return; + + constexpr uint32_t HEAD_DIM = produce_v ? KTraits::HEAD_DIM_VO : KTraits::HEAD_DIM_QK; + constexpr uint32_t SF_COLS = HEAD_DIM / NVFP4_SF_VEC_SIZE; + constexpr uint32_t NUM_WARPS = KTraits::NUM_WARPS; + constexpr uint32_t CTA_TILE_KV = KTraits::CTA_TILE_KV; + // DTypeKV containers per SF byte: NVFP4_SF_VEC_SIZE FP4 / 2 FP4-per-container. + constexpr uint32_t SF_CONTAINERS = NVFP4_SF_VEC_SIZE / 2; // = 8 + constexpr uint32_t SF_TOTAL_BYTES = CTA_TILE_KV * SF_COLS; + static_assert(SF_TOTAL_BYTES % 4 == 0, "SF smem size must be 4-byte aligned for 32-bit LDGSTS"); + // Each thread loads 4 SF bytes (32 bits) per iteration via LDGSTS.32. + constexpr uint32_t THREADS_PER_CTA = NUM_WARPS * 32; + constexpr uint32_t NUM_SF_ITERS = (SF_TOTAL_BYTES / 4 + THREADS_PER_CTA - 1) / THREADS_PER_CTA; + + uint8_t* sf_smem = produce_v ? 
smem_storage->v_sf_smem : smem_storage->k_sf_smem; + const uint32_t thread_id = warp_idx * 32 + lane_idx; + const uint32_t sf_stride_n = kv_stride_n / SF_CONTAINERS; + const uint32_t sf_stride_h = kv_stride_h / SF_CONTAINERS; + +#pragma unroll + for (uint32_t i = 0; i < NUM_SF_ITERS; ++i) { + const uint32_t flat_byte = (thread_id + i * THREADS_PER_CTA) * 4; + const uint32_t sf_smem_row = flat_byte / SF_COLS; + const uint32_t sf_smem_col = flat_byte % SF_COLS; + const uint32_t abs_kv_row = kv_idx_base + sf_smem_row; + const bool in_bounds = (flat_byte < SF_TOTAL_BYTES) && (abs_kv_row < kv_len); + const size_t sf_gmem_offset = + in_bounds ? (static_cast(kv_abs_base + abs_kv_row) * sf_stride_n + + kv_head_idx * sf_stride_h + sf_smem_col) + : 0; + // Same rationale as page_produce_kv_sf: zero-fill V SF to prevent 0*NaN=NaN in compute_sfm_v. + constexpr auto fill_mode = + produce_v ? cp_async::SharedMemFillMode::kFillZero : cp_async::SharedMemFillMode::kNoFill; + cp_async::pred_load_32b(reinterpret_cast(sf_smem + flat_byte), + reinterpret_cast(sf_ptr + sf_gmem_offset), + in_bounds); + } +} + template __device__ __forceinline__ void init_rope_freq(float (*rope_freq)[4], const float rope_rcp_scale, const float rope_rcp_theta, @@ -614,8 +823,8 @@ __device__ __forceinline__ void k_smem_inplace_apply_rotary( template __device__ __forceinline__ void compute_qk( smem_t* q_smem, uint32_t* q_smem_offset_r, - smem_t* k_smem, uint32_t* k_smem_offset_r, - typename KTraits::DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][8]) { + smem_t* k_smem, uint32_t* k_smem_offset_r, uint8_t* k_sf_smem, + uint32_t lane_idx, typename KTraits::DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][8]) { constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q; constexpr uint32_t UPCAST_STRIDE_K = KTraits::UPCAST_STRIDE_K; uint32_t a_frag[KTraits::NUM_MMA_Q][4], b_frag[4]; @@ -635,16 +844,40 @@ __device__ __forceinline__ void compute_qk( #pragma unroll for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; 
++mma_kv) { if constexpr (sizeof(typename KTraits::DTypeKV) == 1) { - uint32_t b_frag_f8[2]; + uint32_t b_frag_quant[2]; if (mma_d % 2 == 0) { - k_smem->ldmatrix_m8n8x4_left_half(*k_smem_offset_r, b_frag_f8); + k_smem->ldmatrix_m8n8x4_left_half(*k_smem_offset_r, b_frag_quant); + } else { + k_smem->ldmatrix_m8n8x4_right_half(*k_smem_offset_r, b_frag_quant); + } + if constexpr (is_fp4_type_v) { + b_frag_quant[0] = frag_layout_swizzle_16b_to_4b(b_frag_quant[0]); + b_frag_quant[1] = frag_layout_swizzle_16b_to_4b(b_frag_quant[1]); } else { - k_smem->ldmatrix_m8n8x4_right_half(*k_smem_offset_r, b_frag_f8); + b_frag_quant[0] = frag_layout_swizzle_16b_to_8b(b_frag_quant[0]); + b_frag_quant[1] = frag_layout_swizzle_16b_to_8b(b_frag_quant[1]); } - b_frag_f8[0] = frag_layout_swizzle_16b_to_8b(b_frag_f8[0]); - b_frag_f8[1] = frag_layout_swizzle_16b_to_8b(b_frag_f8[1]); vec_cast::cast<8>( - (typename KTraits::DTypeQ*)b_frag, (typename KTraits::DTypeKV*)b_frag_f8); + (typename KTraits::DTypeQ*)b_frag, (typename KTraits::DTypeKV*)b_frag_quant); + if constexpr (is_fp4_type_v) { + // Apply scaling factors for K. + // SF smem is linear: sf[kv_row * SF_COLS + hd_group], SF_COLS = HEAD_DIM_QK/16. + // For m16n8k16 B layout, thread t's KV rows are t/4 and t/4+8 in the mma_kv tile. + // b_frag[0,1] share KV row (t/4), b_frag[2,3] share KV row (t/4+8). 
+ using DTypeQ_ = typename KTraits::DTypeQ; + using packed2_ = std::conditional_t, half2, __nv_bfloat162>; + constexpr uint32_t SF_COLS_K = KTraits::NUM_MMA_D_QK; // HEAD_DIM_QK / 16 + uint32_t sf_base = (mma_kv * 16 + lane_idx / 4) * SF_COLS_K + mma_d; + __nv_fp8_e4m3 sf_a_fp8, sf_b_fp8; + sf_a_fp8.__x = k_sf_smem[sf_base]; + sf_b_fp8.__x = k_sf_smem[sf_base + 8 * SF_COLS_K]; + packed2_ scale_a{static_cast(sf_a_fp8), static_cast(sf_a_fp8)}; + packed2_ scale_b{static_cast(sf_b_fp8), static_cast(sf_b_fp8)}; + *(packed2_*)&b_frag[0] = __hmul2(*(packed2_*)&b_frag[0], scale_a); + *(packed2_*)&b_frag[1] = __hmul2(*(packed2_*)&b_frag[1], scale_a); + *(packed2_*)&b_frag[2] = __hmul2(*(packed2_*)&b_frag[2], scale_b); + *(packed2_*)&b_frag[3] = __hmul2(*(packed2_*)&b_frag[3], scale_b); + } } else { k_smem->ldmatrix_m8n8x4(*k_smem_offset_r, b_frag); } @@ -954,8 +1187,8 @@ __device__ __forceinline__ void update_mdo_states( template __device__ __forceinline__ void compute_sfm_v( - smem_t* v_smem, uint32_t* v_smem_offset_r, - typename KTraits::DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][8], + smem_t* v_smem, uint32_t* v_smem_offset_r, uint8_t* v_sf_smem, + uint32_t lane_idx, typename KTraits::DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][8], float (*o_frag)[KTraits::NUM_MMA_D_VO][8], float (*d)[2]) { constexpr uint32_t UPCAST_STRIDE_V = KTraits::UPCAST_STRIDE_V; @@ -991,17 +1224,45 @@ __device__ __forceinline__ void compute_sfm_v( for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { uint32_t b_frag[4]; if constexpr (sizeof(typename KTraits::DTypeKV) == 1) { - uint32_t b_frag_f8[2]; + uint32_t b_frag_quant[2]; if (mma_d % 2 == 0) { - v_smem->ldmatrix_m8n8x4_trans_left_half(*v_smem_offset_r, b_frag_f8); + v_smem->ldmatrix_m8n8x4_trans_left_half(*v_smem_offset_r, b_frag_quant); } else { - v_smem->ldmatrix_m8n8x4_trans_right_half(*v_smem_offset_r, b_frag_f8); + v_smem->ldmatrix_m8n8x4_trans_right_half(*v_smem_offset_r, b_frag_quant); + } + + if constexpr (is_fp4_type_v) 
{ + b_frag_quant[0] = frag_layout_swizzle_16b_to_4b_trans(b_frag_quant[0]); + b_frag_quant[1] = frag_layout_swizzle_16b_to_4b_trans(b_frag_quant[1]); + } else { + b_frag_quant[0] = frag_layout_swizzle_16b_to_8b_trans(b_frag_quant[0]); + b_frag_quant[1] = frag_layout_swizzle_16b_to_8b_trans(b_frag_quant[1]); } - b_frag_f8[0] = frag_layout_swizzle_16b_to_8b_trans(b_frag_f8[0]); - b_frag_f8[1] = frag_layout_swizzle_16b_to_8b_trans(b_frag_f8[1]); vec_cast::cast<8>( - (typename KTraits::DTypeQ*)b_frag, (typename KTraits::DTypeKV*)b_frag_f8); + (typename KTraits::DTypeQ*)b_frag, (typename KTraits::DTypeKV*)b_frag_quant); swap(b_frag[1], b_frag[2]); + if constexpr (is_fp4_type_v) { + // Apply scaling factors for V. + // SF smem is linear: sf[kv_row * SF_COLS + hd_group], SF_COLS = HEAD_DIM_VO/16. + // For transposed B (V), thread t's KV rows are 2*(t%4)+{0,1} and 2*(t%4)+{8,9} + // in the mma_kv tile. After swap, b_frag[0,2] cover rows {r0, r0+1} and + // b_frag[1,3] cover rows {r0+8, r0+9}. Each half2 needs two distinct SFs. 
+ using DTypeQ_ = typename KTraits::DTypeQ; + using packed2_ = std::conditional_t, half2, __nv_bfloat162>; + constexpr uint32_t SF_COLS_V = KTraits::NUM_MMA_D_VO; // HEAD_DIM_VO / 16 + uint32_t sf_base = (mma_kv * 16 + 2 * (lane_idx % 4)) * SF_COLS_V + mma_d; + __nv_fp8_e4m3 sf0_fp8, sf1_fp8, sf2_fp8, sf3_fp8; + sf0_fp8.__x = v_sf_smem[sf_base]; + sf1_fp8.__x = v_sf_smem[sf_base + SF_COLS_V]; + sf2_fp8.__x = v_sf_smem[sf_base + 8 * SF_COLS_V]; + sf3_fp8.__x = v_sf_smem[sf_base + 9 * SF_COLS_V]; + packed2_ scale_lo{static_cast(sf0_fp8), static_cast(sf1_fp8)}; + packed2_ scale_hi{static_cast(sf2_fp8), static_cast(sf3_fp8)}; + *(packed2_*)&b_frag[0] = __hmul2(*(packed2_*)&b_frag[0], scale_lo); + *(packed2_*)&b_frag[1] = __hmul2(*(packed2_*)&b_frag[1], scale_hi); + *(packed2_*)&b_frag[2] = __hmul2(*(packed2_*)&b_frag[2], scale_lo); + *(packed2_*)&b_frag[3] = __hmul2(*(packed2_*)&b_frag[3], scale_hi); + } } else { v_smem->ldmatrix_m8n8x4_trans(*v_smem_offset_r, b_frag); } @@ -1381,6 +1642,15 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( const uint32_t v_stride_h = params.v_stride_h; const uint_fastdiv& group_size = params.group_size; + uint8_t* maybe_k_cache_sf = nullptr; + if constexpr (has_maybe_k_cache_sf_v) { + maybe_k_cache_sf = params.maybe_k_cache_sf; + } + uint8_t* maybe_v_cache_sf = nullptr; + if constexpr (has_maybe_v_cache_sf_v) { + maybe_v_cache_sf = params.maybe_v_cache_sf; + } + static_assert(sizeof(DTypeQ) == 2); const uint32_t lane_idx = tid.x, warp_idx = get_warp_idx(tid.y, tid.z); const uint32_t num_qo_heads = num_kv_heads * group_size; @@ -1455,14 +1725,17 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( : chunk_size) / CTA_TILE_KV; + constexpr uint32_t fp4_pack = is_fp4_type_v ? 
2 : 1; DTypeKV* k_ptr = k + (chunk_start + warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL) * k_stride_n + - kv_head_idx * k_stride_h + (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(); + kv_head_idx * k_stride_h + + (lane_idx % KV_THR_LAYOUT_COL) * upcast_size() / fp4_pack; DTypeKV* v_ptr = v + (chunk_start + warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL) * v_stride_n + - kv_head_idx * v_stride_h + (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(); + kv_head_idx * v_stride_h + + (lane_idx % KV_THR_LAYOUT_COL) * upcast_size() / fp4_pack; uint32_t k_smem_offset_r = k_smem.template get_permuted_offset( get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + 8 * (lane_idx / 16) + @@ -1476,11 +1749,17 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( v_smem_offset_w = v_smem.template get_permuted_offset( warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL, lane_idx % KV_THR_LAYOUT_COL); + // For single prefill, the absolute KV base is just chunk_start (no kv_indptr offset). 
+ const uint32_t kv_abs_base = chunk_start; produce_kv(k_smem, &k_smem_offset_w, &k_ptr, k_stride_n, 0, chunk_size, tid); + produce_kv_sf(&smem_storage, maybe_k_cache_sf, kv_abs_base, kv_head_idx, + k_stride_n, k_stride_h, 0, chunk_size, warp_idx, lane_idx); cp_async::commit_group(); produce_kv(v_smem, &v_smem_offset_w, &v_ptr, v_stride_n, 0, chunk_size, tid); + produce_kv_sf(&smem_storage, maybe_v_cache_sf, kv_abs_base, kv_head_idx, + v_stride_n, v_stride_h, 0, chunk_size, warp_idx, lane_idx); cp_async::commit_group(); #pragma unroll 1 @@ -1495,7 +1774,11 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( } // compute attention score - compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); + compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, + smem_storage.k_sf_smem + get_warp_idx_kv(tid.z) * + KTraits::NUM_MMA_KV * 16 * + KTraits::NUM_MMA_D_QK, + lane_idx, s_frag); uint32_t kv_idx_base = chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv(tid.z)) * NUM_MMA_KV * 16; logits_transform(params, variant, /*batch_idx=*/0, qo_packed_idx_base, kv_idx_base, @@ -1513,16 +1796,26 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( block.sync(); produce_kv( k_smem, &k_smem_offset_w, &k_ptr, k_stride_n, (iter + 1) * CTA_TILE_KV, chunk_size, tid); + produce_kv_sf(&smem_storage, maybe_k_cache_sf, kv_abs_base, kv_head_idx, + k_stride_n, k_stride_h, (iter + 1) * CTA_TILE_KV, chunk_size, + warp_idx, lane_idx); cp_async::commit_group(); cp_async::wait_group<1>(); block.sync(); // compute sfm*v - compute_sfm_v(&v_smem, &v_smem_offset_r, s_frag, o_frag, d); + compute_sfm_v(&v_smem, &v_smem_offset_r, + smem_storage.v_sf_smem + get_warp_idx_kv(tid.z) * + KTraits::NUM_MMA_KV * 16 * + KTraits::NUM_MMA_D_VO, + lane_idx, s_frag, o_frag, d); block.sync(); produce_kv( v_smem, &v_smem_offset_w, &v_ptr, v_stride_n, (iter + 1) * CTA_TILE_KV, chunk_size, tid); + produce_kv_sf(&smem_storage, maybe_v_cache_sf, kv_abs_base, 
kv_head_idx, + v_stride_n, v_stride_h, (iter + 1) * CTA_TILE_KV, chunk_size, + warp_idx, lane_idx); cp_async::commit_group(); } cp_async::wait_group<0>(); @@ -1769,6 +2062,15 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKV const uint32_t v_stride_h = params.v_stride_h; const uint_fastdiv& group_size = params.group_size; + uint8_t* maybe_k_cache_sf = nullptr; + if constexpr (has_maybe_k_cache_sf_v) { + maybe_k_cache_sf = params.maybe_k_cache_sf; + } + uint8_t* maybe_v_cache_sf = nullptr; + if constexpr (has_maybe_v_cache_sf_v) { + maybe_v_cache_sf = params.maybe_v_cache_sf; + } + static_assert(sizeof(DTypeQ) == 2); const uint32_t kv_chunk_size = *(params.kv_chunk_size_ptr); const dim3& tid = threadIdx; @@ -1800,6 +2102,8 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKV const uint32_t chunk_end = partition_kv ? min((kv_tile_idx + 1) * max_chunk_size + kv_start_idx, kv_len) : kv_len; const uint32_t chunk_size = chunk_end - chunk_start; + // Absolute first token index for this CTA tile (used by produce_kv_sf). + const uint32_t kv_abs_base = kv_indptr[request_idx] + chunk_start; DTypeQKAccum s_frag[NUM_MMA_Q][NUM_MMA_KV][8]; alignas(16) float o_frag[NUM_MMA_Q][NUM_MMA_D_VO][8]; @@ -1896,24 +2200,29 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKV warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL, lane_idx % KV_THR_LAYOUT_COL); + constexpr uint32_t fp4_pack = is_fp4_type_v ? 
2 : 1; DTypeKV* k_ptr = k + (kv_indptr[request_idx] + chunk_start + warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL) * k_stride_n + kv_head_idx * k_stride_h + - (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(); + (lane_idx % KV_THR_LAYOUT_COL) * upcast_size() / fp4_pack; DTypeKV* v_ptr = v + (kv_indptr[request_idx] + chunk_start + warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL) * v_stride_n + kv_head_idx * v_stride_h + - (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(); + (lane_idx % KV_THR_LAYOUT_COL) * upcast_size() / fp4_pack; produce_kv(k_smem, &k_smem_offset_w, &k_ptr, k_stride_n, 0, chunk_size, tid); + produce_kv_sf(&smem_storage, maybe_k_cache_sf, kv_abs_base, kv_head_idx, + k_stride_n, k_stride_h, 0, chunk_size, warp_idx, lane_idx); cp_async::commit_group(); produce_kv(v_smem, &v_smem_offset_w, &v_ptr, v_stride_n, 0, chunk_size, tid); + produce_kv_sf(&smem_storage, maybe_v_cache_sf, kv_abs_base, kv_head_idx, + v_stride_n, v_stride_h, 0, chunk_size, warp_idx, lane_idx); cp_async::commit_group(); #pragma unroll 1 @@ -1934,7 +2243,11 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKV } // compute attention score - compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); + compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, + smem_storage.k_sf_smem + get_warp_idx_kv(tid.z) * + KTraits::NUM_MMA_KV * 16 * + KTraits::NUM_MMA_D_QK, + lane_idx, s_frag); uint32_t kv_idx_base = chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv(tid.z)) * NUM_MMA_KV * 16; logits_transform(params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base, @@ -1953,16 +2266,26 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKV block.sync(); produce_kv( k_smem, &k_smem_offset_w, &k_ptr, k_stride_n, (iter + 1) * CTA_TILE_KV, chunk_size, tid); + produce_kv_sf(&smem_storage, maybe_k_cache_sf, kv_abs_base, kv_head_idx, + k_stride_n, k_stride_h, (iter + 1) * CTA_TILE_KV, 
chunk_size, + warp_idx, lane_idx); cp_async::commit_group(); cp_async::wait_group<1>(); block.sync(); // compute sfm*v - compute_sfm_v(&v_smem, &v_smem_offset_r, s_frag, o_frag, d); + compute_sfm_v(&v_smem, &v_smem_offset_r, + smem_storage.v_sf_smem + get_warp_idx_kv(tid.z) * + KTraits::NUM_MMA_KV * 16 * + KTraits::NUM_MMA_D_VO, + lane_idx, s_frag, o_frag, d); block.sync(); produce_kv( v_smem, &v_smem_offset_w, &v_ptr, v_stride_n, (iter + 1) * CTA_TILE_KV, chunk_size, tid); + produce_kv_sf(&smem_storage, maybe_v_cache_sf, kv_abs_base, kv_head_idx, + v_stride_n, v_stride_h, (iter + 1) * CTA_TILE_KV, chunk_size, + warp_idx, lane_idx); cp_async::commit_group(); } cp_async::wait_group<0>(); @@ -2086,6 +2409,14 @@ __device__ __forceinline__ void BatchPrefillWithPagedKVCacheDevice( if constexpr (has_maybe_max_item_len_ptr_v) { maybe_max_item_len_ptr = params.maybe_max_item_len_ptr; } + uint8_t* maybe_k_cache_sf = nullptr; + if constexpr (has_maybe_k_cache_sf_v) { + maybe_k_cache_sf = params.maybe_k_cache_sf; + } + uint8_t* maybe_v_cache_sf = nullptr; + if constexpr (has_maybe_v_cache_sf_v) { + maybe_v_cache_sf = params.maybe_v_cache_sf; + } static_assert(sizeof(DTypeQ) == 2); auto block = cg::this_thread_block(); @@ -2197,15 +2528,25 @@ __device__ __forceinline__ void BatchPrefillWithPagedKVCacheDevice( lane_idx / KV_THR_LAYOUT_COL + KV_THR_LAYOUT_ROW * NUM_WARPS_Q * NUM_WARPS_KV * i, page_iter, entry_idx); + // FP4: GMEM is packed (2 FP4/byte), so the column byte offset is halved relative to fp8 + constexpr uint32_t fp4_pack_factor = is_fp4_type_v ? 
2 : 1; thr_local_kv_offset[i] = paged_kv.protective_get_kv_offset( page_iter, kv_head_idx, entry_idx, - (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(), last_indptr); + (lane_idx % KV_THR_LAYOUT_COL) * upcast_size() / fp4_pack_factor, last_indptr); } page_produce_kv(&smem_storage, &k_smem_offset_w, paged_kv.k_data, 0, thr_local_kv_offset, chunk_size, warp_idx, lane_idx); + page_produce_kv_sf( + &smem_storage, maybe_k_cache_sf, packed_page_iter_base, last_indptr * paged_kv.page_size.d, + kv_head_idx, paged_kv.stride_page, paged_kv.stride_h, paged_kv.stride_n, paged_kv.page_size, + paged_kv.indices, 0, chunk_size, warp_idx, lane_idx); cp_async::commit_group(); page_produce_kv(&smem_storage, &v_smem_offset_w, paged_kv.v_data, 0, thr_local_kv_offset, chunk_size, warp_idx, lane_idx); + page_produce_kv_sf( + &smem_storage, maybe_v_cache_sf, packed_page_iter_base, last_indptr * paged_kv.page_size.d, + kv_head_idx, paged_kv.stride_page, paged_kv.stride_h, paged_kv.stride_n, paged_kv.page_size, + paged_kv.indices, 0, chunk_size, warp_idx, lane_idx); cp_async::commit_group(); uint32_t num_iterations_prefix; @@ -2281,9 +2622,11 @@ __device__ __forceinline__ void BatchPrefillWithPagedKVCacheDevice( lane_idx / KV_THR_LAYOUT_COL + KV_THR_LAYOUT_ROW * NUM_WARPS_Q * NUM_WARPS_KV * i, page_iter, entry_idx); + // FP4: GMEM is packed (2 FP4/byte), so the column byte offset is halved relative to fp8 + constexpr uint32_t fp4_pack_factor = is_fp4_type_v ? 
2 : 1; thr_local_kv_offset[i] = paged_kv.protective_get_kv_offset( page_iter, kv_head_idx, entry_idx, - (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(), last_indptr); + (lane_idx % KV_THR_LAYOUT_COL) * upcast_size() / fp4_pack_factor, last_indptr); } cp_async::wait_group<1>(); block.sync(); @@ -2297,7 +2640,11 @@ __device__ __forceinline__ void BatchPrefillWithPagedKVCacheDevice( } // compute attention score - compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); + compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, + smem_storage.k_sf_smem + get_warp_idx_kv(tid.z) * + KTraits::NUM_MMA_KV * 16 * + KTraits::NUM_MMA_D_QK, + lane_idx, s_frag); uint32_t kv_idx_base = chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv(tid.z)) * NUM_MMA_KV * 16; logits_transform(params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base, @@ -2340,17 +2687,31 @@ __device__ __forceinline__ void BatchPrefillWithPagedKVCacheDevice( page_produce_kv(&smem_storage, &k_smem_offset_w, paged_kv.k_data, (iter + 1) * CTA_TILE_KV, thr_local_kv_offset, chunk_size, warp_idx, lane_idx); + page_produce_kv_sf(&smem_storage, maybe_k_cache_sf, packed_page_iter_base, + last_indptr * paged_kv.page_size.d, kv_head_idx, + paged_kv.stride_page, paged_kv.stride_h, paged_kv.stride_n, + paged_kv.page_size, paged_kv.indices, + (iter + 1) * CTA_TILE_KV, chunk_size, warp_idx, lane_idx); cp_async::commit_group(); cp_async::wait_group<1>(); block.sync(); // compute sfm*v - compute_sfm_v(&v_smem, &v_smem_offset_r, s_frag, o_frag, d); + compute_sfm_v(&v_smem, &v_smem_offset_r, + smem_storage.v_sf_smem + get_warp_idx_kv(tid.z) * + KTraits::NUM_MMA_KV * 16 * + KTraits::NUM_MMA_D_VO, + lane_idx, s_frag, o_frag, d); block.sync(); page_produce_kv(&smem_storage, &v_smem_offset_w, paged_kv.v_data, (iter + 1) * CTA_TILE_KV, thr_local_kv_offset, chunk_size, warp_idx, lane_idx); + page_produce_kv_sf(&smem_storage, maybe_v_cache_sf, packed_page_iter_base, + last_indptr * paged_kv.page_size.d, 
kv_head_idx, + paged_kv.stride_page, paged_kv.stride_h, paged_kv.stride_n, + paged_kv.page_size, paged_kv.indices, + (iter + 1) * CTA_TILE_KV, chunk_size, warp_idx, lane_idx); cp_async::commit_group(); } cp_async::wait_group<0>(); diff --git a/include/flashinfer/cp_async.cuh b/include/flashinfer/cp_async.cuh index bd59cc58e3..4ea1de2f4a 100644 --- a/include/flashinfer/cp_async.cuh +++ b/include/flashinfer/cp_async.cuh @@ -182,6 +182,82 @@ __device__ __forceinline__ void pred_load(T* smem_ptr, const T* gmem_ptr, bool p } } +/*! + * \brief Like pred_load_128b but reads only 64 bits from global memory into the lower 64 bits of + * the 128-bit shared memory destination, zero-padding the upper 64 bits when predicate is true. + * Used for NVFP4 KV loading: GMEM stores 2 FP4 elements per byte (packed), so each SMEM + * 128-bit slot is loaded from 64 GMEM bits and padded with 64 bits of zeros. + */ +template +__device__ __forceinline__ void pred_load_128b_from_64b(T* smem_ptr, const T* gmem_ptr, + bool predicate) { +#ifdef FLASHINFER_CP_ASYNC_ENABLED + uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); + if constexpr (fill_mode == SharedMemFillMode::kFillZero) { + int src_in_bytes = predicate ? 8 : 0; + asm volatile("cp.async.ca.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), + "l"(gmem_ptr), "n"(8), "r"(src_in_bytes)); + + } else { + // kNoFill: only issue the copy if predicate is true; cp.async always zeros the upper 8 bytes + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.ca.shared.global [%1], [%2], %3, %4;\n" + "}\n" ::"r"((int)predicate), + "r"(smem_int_ptr), "l"(gmem_ptr), "n"(8), "n"(8)); + } +#else + if (predicate) { + uint64_t* smem_u64 = reinterpret_cast(smem_ptr); + smem_u64[0] = *reinterpret_cast(gmem_ptr); + smem_u64[1] = 0; + } else { + if constexpr (fill_mode == SharedMemFillMode::kFillZero) { + *((uint4*)smem_ptr) = make_uint4(0, 0, 0, 0); + } + } +#endif +} + +/*! 
+ * \brief Asynchronously copy 4 bytes from global memory to shared memory (LDGSTS.32). + * Uses cp.async.ca.shared.global with a 4-byte transfer size. + * When fill_mode==kFillZero and predicate is false, writes 0 to shared memory. + * When fill_mode==kNoFill and predicate is false, no operation is issued. + * \tparam fill_mode Whether to fill zero to shared memory when predicate is false. + * \param smem_ptr 4-byte aligned shared memory destination. + * \param gmem_ptr Global memory source. + * \param predicate Predicate value. + */ +template +__device__ __forceinline__ void pred_load_32b(uint32_t* smem_ptr, const uint32_t* gmem_ptr, + bool predicate) { +#ifdef FLASHINFER_CP_ASYNC_ENABLED + uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); + if constexpr (fill_mode == SharedMemFillMode::kFillZero) { + int src_in_bytes = predicate ? 4 : 0; + asm volatile("cp.async.ca.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), + "l"(gmem_ptr), "n"(4), "r"(src_in_bytes)); + } else { + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.ca.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)predicate), + "r"(smem_int_ptr), "l"(gmem_ptr), "n"(4)); + } +#else + if (predicate) { + *smem_ptr = *gmem_ptr; + } else if constexpr (fill_mode == SharedMemFillMode::kFillZero) { + *smem_ptr = 0; + } +#endif +} + } // namespace cp_async } // namespace flashinfer diff --git a/include/flashinfer/frag_layout_swizzle.cuh b/include/flashinfer/frag_layout_swizzle.cuh index 39cf92bcd9..68396025d1 100644 --- a/include/flashinfer/frag_layout_swizzle.cuh +++ b/include/flashinfer/frag_layout_swizzle.cuh @@ -38,4 +38,38 @@ __device__ __forceinline__ uint32_t frag_layout_swizzle_16b_to_8b_trans(uint32_t return x; } +// Convert 16b fragment layout to 4b fragment layout. 
+__device__ __forceinline__ uint32_t frag_layout_swizzle_16b_to_4b(uint32_t x) { + // Broadcast from the thread 0 of each group of 4 (thread t gets value from thread t & ~3) + uint32_t tmp0 = __shfl_sync(0xffffffff, x, threadIdx.x & ~0x3u); + // Similarly, broadcast from the thread 1 of each group of 4 + uint32_t tmp1 = __shfl_sync(0xffffffff, x, (threadIdx.x & ~0x3u) + 1); + // Select byte i = (threadIdx.x % 4) of each register and assemble them together. + uint32_t byte_idx = threadIdx.x & 0x3u; + x = __byte_perm(tmp0, tmp1, byte_idx * 0x0101u + 0x0400u); + return x; +} + +// Convert transposed 16b fragment layout to 4b (NVfp4) fragment layout. +// Counterpart to frag_layout_swizzle_16b_to_4b for the column-major (transposed) case. +__device__ __forceinline__ uint32_t frag_layout_swizzle_16b_to_4b_trans(uint32_t x) { + // Shuffle the data across threads. We group threads in a stride of 4: {i, i+4, i+8, i+12, ..., + // i+28} (i in {0,1,2,3}). Thread {i, i+4, i+8, i+12} receives data from thread i and i+8. Thread + // {i+16, i+20, i+24, i+28} receives data from thread i+4 and i+12. + unsigned src_thrd = (threadIdx.x & ~0x1cu) + ((threadIdx.x & 0x10u) >> 2); + uint32_t tmp0 = __shfl_sync(0xffffffff, x, src_thrd); + uint32_t tmp1 = __shfl_sync(0xffffffff, x, src_thrd + 8u); + // Select byte. Thread ((i / 8) % 2 == 0) selects [6,4,2,0] + // Thread ((i / 8) % 2 == 1) selects [7,5,3,1]. + uint32_t select_code = (threadIdx.x & 0x8u) ? 0x7531u : 0x6420u; + uint32_t tmp = __byte_perm(tmp0, tmp1, select_code); + // Right-shift by 4 bits to align 4b nibbles to the correct place. + tmp = tmp >> (threadIdx.x & 0x4u); + // At this point the 4b data are distributed in individual bytes. + // Pack them into byte 0 and byte 2 for efficient data conversion. 
+ tmp = tmp & 0x0F0F0F0F; + tmp = tmp | (tmp >> 4); + return tmp; +} + #endif // FLASHINFER_FRAG_LAYOUT_SWIZZLE_CUH_ diff --git a/include/flashinfer/permuted_smem.cuh b/include/flashinfer/permuted_smem.cuh index a63283ebdc..d0b4f71e18 100644 --- a/include/flashinfer/permuted_smem.cuh +++ b/include/flashinfer/permuted_smem.cuh @@ -173,6 +173,13 @@ struct smem_t { reinterpret_cast(gptr)); } + template + __device__ __forceinline__ void load_64b_async(uint32_t offset, const T* gptr, bool predicate) { + b128_t* smem_ptr = base + offset; + cp_async::pred_load_128b_from_64b( + smem_ptr, reinterpret_cast(gptr), predicate); + } + template __device__ __forceinline__ void store_128b(uint32_t offset, T* gptr) { *reinterpret_cast(gptr) = *(base + offset); diff --git a/include/flashinfer/vec_dtypes.cuh b/include/flashinfer/vec_dtypes.cuh index 0e8fd468fe..25c3b6fc60 100644 --- a/include/flashinfer/vec_dtypes.cuh +++ b/include/flashinfer/vec_dtypes.cuh @@ -414,6 +414,137 @@ struct vec_cast { } }; +#if defined(FLASHINFER_ENABLE_FP4_E2M1) && CUDA_VERSION >= 12080 +// Convert __nv_fp4x2_e2m1 (2 fp4 values per byte) to fp16. +// vec_size counts fp16 output elements; src has stride-2 layout: +// src[0] holds x0,x1 src[1] is padding +// src[2] holds x2,x3 src[3] is padding ... etc. +// Each valid byte encodes 2 fp4 values -> 2 fp16 via cvt.rn.f16x2.e2m1x2. +template <> +struct vec_cast { + template + FLASHINFER_INLINE static void cast(half* dst, const __nv_fp4x2_e2m1* src) { + static_assert(vec_size % 2 == 0, "vec_size must be even for fp4x2 dequantization"); +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + uint32_t y; + // Valid fp4x2 bytes are at even positions (stride 2); odd positions are padding. 
+ uint32_t b = reinterpret_cast(src)[i * 2]; + asm volatile( + "{\n" + ".reg .b8 fp4_byte;\n" + "mov.b32 {fp4_byte, _, _, _}, %1;\n" + "cvt.rn.f16x2.e2m1x2 %0, fp4_byte;\n" + "}" + : "=r"(y) + : "r"(b)); + reinterpret_cast(dst)[i] = y; + } +#else + // Software LUT fallback for arch < SM100. + // e2m1 encoding: bit[3]=sign, bit[2:0]=magnitude index in {0,0.5,1,1.5,2,3,4,6}. + // Each packed byte holds two fp4 values: bits[3:0]=first, bits[7:4]=second. + constexpr uint16_t lut[16] = { + 0x0000, // +0.0 + 0x3800, // +0.5 + 0x3C00, // +1.0 + 0x3E00, // +1.5 + 0x4000, // +2.0 + 0x4200, // +3.0 + 0x4400, // +4.0 + 0x4600, // +6.0 + 0x8000, // -0.0 + 0xB800, // -0.5 + 0xBC00, // -1.0 + 0xBE00, // -1.5 + 0xC000, // -2.0 + 0xC200, // -3.0 + 0xC400, // -4.0 + 0xC600, // -6.0 + }; +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + uint8_t b = reinterpret_cast(src)[i * 2]; + reinterpret_cast(dst)[i * 2 + 0] = lut[b & 0x0F]; + reinterpret_cast(dst)[i * 2 + 1] = lut[(b >> 4) & 0x0F]; + } +#endif + } +}; +template <> +struct vec_cast { + template + FLASHINFER_INLINE static void cast(nv_bfloat16* dst, const __nv_fp4x2_e2m1* src) { + static_assert(vec_size % 2 == 0, "vec_size must be even for fp4x2 dequantization"); +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + uint32_t y; + // Valid fp4x2 bytes are at even positions (stride 2); odd positions are padding. 
+ uint32_t b = reinterpret_cast(src)[i * 2]; +#if (defined __CUDACC_VER_MAJOR__) && (defined __CUDACC_VER_MINOR__) && \ + ((__CUDACC_VER_MAJOR__ > 13) || ((__CUDACC_VER_MAJOR__ == 13) && (__CUDACC_VER_MINOR__ >= 2))) + // cvt.rn.bf16x2.e2m1x2 requires CUDA Toolkit >= 13.2 + asm volatile( + "{\n" + ".reg .b8 fp4_byte;\n" + "mov.b32 {fp4_byte, _, _, _}, %1;\n" + "cvt.rn.bf16x2.e2m1x2 %0, fp4_byte;\n" + "}" + : "=r"(y) + : "r"(b)); +#else + // Fallback: convert e2m1 -> fp16 -> bf16 when cvt.rn.bf16x2.e2m1x2 is unavailable + uint32_t fp16x2; + asm volatile( + "{\n" + ".reg .b8 fp4_byte;\n" + "mov.b32 {fp4_byte, _, _, _}, %1;\n" + "cvt.rn.f16x2.e2m1x2 %0, fp4_byte;\n" + "}" + : "=r"(fp16x2) + : "r"(b)); + __half2 h2 = reinterpret_cast<__half2&>(fp16x2); + __nv_bfloat162 bf16x2 = __float22bfloat162_rn(__half22float2(h2)); + y = reinterpret_cast(bf16x2); +#endif + reinterpret_cast(dst)[i] = y; + } +#else + // Software LUT fallback for arch < SM100. + // e2m1 encoding: bit[3]=sign, bit[2:0]=magnitude index in {0,0.5,1,1.5,2,3,4,6}. + // Each packed byte holds two fp4 values: bits[3:0]=first, bits[7:4]=second. 
+ constexpr uint16_t lut[16] = { + 0x0000, // +0.0 + 0x3F00, // +0.5 + 0x3F80, // +1.0 + 0x3FC0, // +1.5 + 0x4000, // +2.0 + 0x4040, // +3.0 + 0x4080, // +4.0 + 0x40C0, // +6.0 + 0x8000, // -0.0 + 0xBF00, // -0.5 + 0xBF80, // -1.0 + 0xBFC0, // -1.5 + 0xC000, // -2.0 + 0xC040, // -3.0 + 0xC080, // -4.0 + 0xC0C0, // -6.0 + }; +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + uint8_t b = reinterpret_cast(src)[i * 2]; + reinterpret_cast(dst)[i * 2 + 0] = lut[b & 0x0F]; + reinterpret_cast(dst)[i * 2 + 1] = lut[(b >> 4) & 0x0F]; + } +#endif + } +}; + +#endif // FLASHINFER_ENABLE_FP4_E2M1 && CUDA_VERSION >= 12080 + template <> struct vec_cast { template diff --git a/tests/attention/test_batch_attention.py b/tests/attention/test_batch_attention.py index 1a0532b479..55d9fba440 100644 --- a/tests/attention/test_batch_attention.py +++ b/tests/attention/test_batch_attention.py @@ -23,6 +23,7 @@ gen_persistent_batch_attention_modules, gen_prefill_attention_modules, ) +from tests.test_helpers.utils_fp4 import create_nvfp4_kv, nvfp4_to_float from flashinfer.utils import get_compute_capability, has_flashinfer_jit_cache @@ -289,6 +290,139 @@ def test_batch_attention_correctness( ) +@pytest.mark.xfail( + get_compute_capability(torch.device(device="cuda"))[0] == 12, + reason="Expected failure for SM120/121 for now since the tile size/number of stages is too large.", +) +@pytest.mark.parametrize("batch_size", [1, 4]) +@pytest.mark.parametrize("kv_len", [128, 256]) +@pytest.mark.parametrize("qo_len", [64, 128]) +@pytest.mark.parametrize("page_size", [16, 64]) +@pytest.mark.parametrize("num_kv_heads", [1]) +@pytest.mark.parametrize("num_qo_heads", [1]) +@pytest.mark.parametrize("head_dim", [128]) +@pytest.mark.parametrize("causal", [False]) +@pytest.mark.parametrize("q_dtype", [torch.float16, torch.bfloat16]) +def test_batch_attention_nvfp4( + batch_size, + kv_len, + qo_len, + page_size, + num_kv_heads, + num_qo_heads, + head_dim, + causal, + q_dtype, +): + """Test 
BatchAttention with NVFP4 KV cache. + + KV cache layout (NHD): + kv_cache: [num_pages, 2, page_size, num_kv_heads, head_dim//2] uint8 (packed FP4x2) + kv_cache_sf: [num_pages, 2, page_size, num_kv_heads, head_dim//16] uint8 (FP8 SFs) + + Reference is computed by dequantizing the packed KV back to q_dtype and running + single_prefill_with_kv_cache per batch item. + """ + if qo_len > kv_len and causal: + pytest.skip("qo_len > kv_len and causal is not supported") + + kv_layout = "NHD" + torch.manual_seed(42) + + # --- query --- + q = torch.randn( + batch_size * qo_len, num_qo_heads, head_dim, device="cuda:0", dtype=q_dtype + ) + q_indptr_cpu = torch.arange(0, batch_size + 1, dtype=torch.int32) * qo_len + + # --- paged KV metadata --- + num_pages_per_seq = (kv_len + page_size - 1) // page_size + total_num_pages = num_pages_per_seq * batch_size + kv_indptr_cpu = ( + torch.arange(0, batch_size + 1, dtype=torch.int32) * num_pages_per_seq + ) + kv_indices_cpu = torch.arange(0, total_num_pages, dtype=torch.int32) + kv_last_page_len_cpu = torch.full( + (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32 + ) + kv_len_arr_cpu = torch.full((batch_size,), kv_len, dtype=torch.int32) + + # --- create NVFP4 KV pages directly (NHD: [num_pages, page_size, num_kv_heads, head_dim//2]) --- + kv_shape = (total_num_pages, page_size, num_kv_heads, head_dim // 2) + k_packed, k_sf, k_global_scale = create_nvfp4_kv(kv_shape, "cuda:0") + v_packed, v_sf, v_global_scale = create_nvfp4_kv(kv_shape, "cuda:0") + + # Dequantize for reference attention + k_dq = nvfp4_to_float(k_packed, k_sf, k_global_scale).to(q_dtype) + v_dq = nvfp4_to_float(v_packed, v_sf, v_global_scale).to(q_dtype) + + # Pack into combined tensors: + # kv_cache: [num_pages, 2, page_size, num_kv_heads, head_dim//2] + # kv_cache_sf: [num_pages, 2, page_size, num_kv_heads, head_dim//16] + kv_cache = torch.stack([k_packed, v_packed], dim=1) + kv_cache_sf = torch.stack([k_sf, v_sf], dim=1) + + # --- run BatchAttention --- + 
q_indptr_gpu = q_indptr_cpu.to("cuda:0") + kv_indptr_gpu = kv_indptr_cpu.to("cuda:0") + kv_indices_gpu = kv_indices_cpu.to("cuda:0") + kv_len_arr_gpu = kv_len_arr_cpu.to("cuda:0") + + wrapper = flashinfer.BatchAttention(kv_layout=kv_layout) + wrapper.plan( + q_indptr_gpu, + kv_indptr_gpu, + kv_indices_gpu, + kv_len_arr_gpu, + num_qo_heads, + num_kv_heads, + head_dim, + head_dim, + page_size, + causal=causal, + q_data_type=q_dtype, + kv_data_type=torch.uint8, + ) + o, _ = wrapper.run( + q, + kv_cache, + k_scale=k_global_scale.item(), + v_scale=v_global_scale.item(), + kv_cache_sf=kv_cache_sf, + ) + + # --- reference: single_prefill_with_kv_cache per batch item using dequantized KV --- + for i in range(batch_size): + qi = q[q_indptr_cpu[i] : q_indptr_cpu[i + 1]] + + full_pages_k = k_dq[kv_indptr_cpu[i] : kv_indptr_cpu[i + 1] - 1] + last_page_k = k_dq[kv_indptr_cpu[i + 1] - 1, : kv_last_page_len_cpu[i]] + ki = torch.cat( + [ + full_pages_k.reshape(-1, num_kv_heads, head_dim), + last_page_k.reshape(-1, num_kv_heads, head_dim), + ], + dim=0, + ) + + full_pages_v = v_dq[kv_indptr_cpu[i] : kv_indptr_cpu[i + 1] - 1] + last_page_v = v_dq[kv_indptr_cpu[i + 1] - 1, : kv_last_page_len_cpu[i]] + vi = torch.cat( + [ + full_pages_v.reshape(-1, num_kv_heads, head_dim), + last_page_v.reshape(-1, num_kv_heads, head_dim), + ], + dim=0, + ) + + o_ref_i = flashinfer.prefill.single_prefill_with_kv_cache( + qi, ki, vi, causal=causal, pos_encoding_mode="NONE", logits_soft_cap=0.0 + ) + o_i = o[q_indptr_cpu[i] : q_indptr_cpu[i + 1]] + + torch.testing.assert_close(o_i, o_ref_i, rtol=1e-1, atol=1e-1) + + if __name__ == "__main__": test_batch_attention_correctness( seq_len_pairs=[(1000, 1000)], @@ -296,8 +430,20 @@ def test_batch_attention_correctness( num_kv_heads=4, gqa_group_size=7, head_dim=128, + v_scale=2.0, causal=True, layout="NHD", test_dtype=torch.bfloat16, logits_soft_cap=0.0, ) + test_batch_attention_nvfp4( + batch_size=4, + kv_len=128, + qo_len=64, + page_size=16, + 
num_kv_heads=1, + num_qo_heads=1, + head_dim=128, + causal=False, + q_dtype=torch.float16, + ) diff --git a/tests/attention/test_batch_decode_kernels.py b/tests/attention/test_batch_decode_kernels.py index aed6d25246..5ab8c46205 100644 --- a/tests/attention/test_batch_decode_kernels.py +++ b/tests/attention/test_batch_decode_kernels.py @@ -20,6 +20,7 @@ gen_decode_attention_modules, gen_prefill_attention_modules, ) +from tests.test_helpers.utils_fp4 import create_nvfp4_kv, nvfp4_to_float from functools import partial import flashinfer from flashinfer.utils import has_flashinfer_jit_cache @@ -680,6 +681,135 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache( torch.testing.assert_close(o[i], o_ref_i, rtol=1e-3, atol=1e-3) +@pytest.mark.parametrize("batch_size", [12, 17, 128]) +@pytest.mark.parametrize("kv_len", [54, 97, 512, 2048]) +@pytest.mark.parametrize("page_size", [1, 8, 16]) +@pytest.mark.parametrize("num_kv_heads", [4]) +@pytest.mark.parametrize("num_qo_heads", [4, 32]) +@pytest.mark.parametrize("head_dim", [128, 256]) +@pytest.mark.parametrize("q_dtype", [torch.float16, torch.bfloat16]) +def test_batch_decode_with_paged_kv_cache_nvfp4( + batch_size, + kv_len, + page_size, + num_kv_heads, + num_qo_heads, + head_dim, + q_dtype, +): + """Test BatchDecodeWithPagedKVCacheWrapper with NVFP4 KV cache. + + KV cache layout (NHD): + kv_cache: [num_pages, 2, page_size, num_kv_heads, head_dim//2] uint8 (packed FP4x2) + kv_cache_sf: [num_pages, 2, page_size, num_kv_heads, head_dim//16] uint8 (FP8 SFs) + + Reference is computed by dequantizing the packed KV back to q_dtype and running + single_decode_with_kv_cache per batch item. 
+ """ + kv_layout = "NHD" + torch.manual_seed(42) + + # --- query: one token per request --- + q = torch.randn(batch_size, num_qo_heads, head_dim, device="cuda:0", dtype=q_dtype) + + # --- paged KV metadata --- + num_pages_per_seq = (kv_len + page_size - 1) // page_size + total_num_pages = num_pages_per_seq * batch_size + kv_indptr = ( + torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) + * num_pages_per_seq + ) + kv_indices = torch.arange(0, total_num_pages, device="cuda:0", dtype=torch.int32) + kv_last_page_len = torch.full( + (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32, device="cuda:0" + ) + + # --- create NVFP4 KV pages directly (NHD: [num_pages, page_size, num_kv_heads, head_dim//2]) --- + kv_shape = (total_num_pages, page_size, num_kv_heads, head_dim // 2) + k_packed, k_sf, k_global_scale = create_nvfp4_kv(kv_shape, "cuda:0") + v_packed, v_sf, v_global_scale = create_nvfp4_kv(kv_shape, "cuda:0") + + # Dequantize for reference attention + k_dq = nvfp4_to_float(k_packed, k_sf, k_global_scale).to(q_dtype) + v_dq = nvfp4_to_float(v_packed, v_sf, v_global_scale).to(q_dtype) + + # Pack into combined tensors: + # kv_cache: [num_pages, 2, page_size, num_kv_heads, head_dim//2] + # kv_cache_sf: [num_pages, 2, page_size, num_kv_heads, head_dim//16] + kv_cache = torch.stack([k_packed, v_packed], dim=1) # [P, 2, S, H, D//2] + kv_cache_sf = torch.stack([k_sf, v_sf], dim=1) # [P, 2, S, H, D//16] + + # --- run BatchDecodeWithPagedKVCacheWrapper --- + # NVFP4 KV is only supported via the tensor-cores path (FA2 prefill kernel), + # not the legacy FA2 decode kernel. + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda:0") + + # Pre-filling the workspace initializes the smem buffers to 0x7F (NaN as FP8), so that any SF addressing bug surfaces as NaN in the output.
+ workspace_buffer.fill_(0x7F) + + wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, kv_layout, use_tensor_cores=True + ) + wrapper.plan( + kv_indptr, + kv_indices, + kv_last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + pos_encoding_mode="NONE", + logits_soft_cap=0.0, + kv_data_type=torch.uint8, + q_data_type=q_dtype, + ) + o = wrapper.run( + q, + kv_cache, + k_scale=k_global_scale.item(), + v_scale=v_global_scale.item(), + kv_cache_sf=kv_cache_sf, + ) + + # --- reference: single_decode_with_kv_cache per batch item using dequantized KV --- + kv_indptr_cpu = kv_indptr.cpu() + kv_last_page_len_cpu = kv_last_page_len.cpu() + for i in range(batch_size): + qi = q[i] + + # Gather full (non-padded) KV for sequence i from pages + full_pages_k = k_dq[ + kv_indptr_cpu[i] : kv_indptr_cpu[i + 1] - 1 + ] # [p-1, S, H, D] + last_page_k = k_dq[ + kv_indptr_cpu[i + 1] - 1, : kv_last_page_len_cpu[i] + ] # [l, H, D] + ki = torch.cat( + [ + full_pages_k.reshape(-1, num_kv_heads, head_dim), + last_page_k.reshape(-1, num_kv_heads, head_dim), + ], + dim=0, + ) + + full_pages_v = v_dq[kv_indptr_cpu[i] : kv_indptr_cpu[i + 1] - 1] + last_page_v = v_dq[kv_indptr_cpu[i + 1] - 1, : kv_last_page_len_cpu[i]] + vi = torch.cat( + [ + full_pages_v.reshape(-1, num_kv_heads, head_dim), + last_page_v.reshape(-1, num_kv_heads, head_dim), + ], + dim=0, + ) + + o_ref_i = flashinfer.decode.single_decode_with_kv_cache( + qi, ki, vi, pos_encoding_mode="NONE", logits_soft_cap=0.0 + ) + + # NVFP4 is 4-bit; use relaxed tolerance + torch.testing.assert_close(o[i], o_ref_i, rtol=1e-1, atol=1e-1) + + if __name__ == "__main__": test_batch_decode_with_paged_kv_cache( 256, @@ -765,6 +895,7 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache( test_cuda_graph_batch_decode_with_paged_kv_cache( 12, 54, 8, 8, 8, 128, "HND", "NONE", torch.float16, torch.float8_e5m2, True ) + test_batch_decode_with_paged_kv_cache_nvfp4(4, 128, 64, 1, 1, 128, torch.float16) def 
test_single_decode_torch_compile_cuda_graph(): diff --git a/tests/attention/test_batch_prefill_kernels.py b/tests/attention/test_batch_prefill_kernels.py index 15eb3310ee..1a52230f1e 100644 --- a/tests/attention/test_batch_prefill_kernels.py +++ b/tests/attention/test_batch_prefill_kernels.py @@ -20,6 +20,7 @@ from tests.test_helpers.jit_utils import gen_prefill_attention_modules import flashinfer +from tests.test_helpers.utils_fp4 import create_nvfp4_kv, nvfp4_to_float from flashinfer.utils import has_flashinfer_jit_cache @@ -1026,6 +1027,239 @@ def create_2D_multi_item_mask_dense( numpy.testing.assert_allclose(o_i_np, o_ref_i_np, rtol=1e-3, atol=1e-3) +@pytest.mark.parametrize("batch_size", [1, 4]) +@pytest.mark.parametrize("kv_len", [128, 256]) +@pytest.mark.parametrize("qo_len", [64, 128]) +@pytest.mark.parametrize("page_size", [16, 64]) +@pytest.mark.parametrize("num_kv_heads", [1]) +@pytest.mark.parametrize("num_qo_heads", [1]) +@pytest.mark.parametrize("head_dim", [128]) +@pytest.mark.parametrize("causal", [False]) +@pytest.mark.parametrize("q_dtype", [torch.float16, torch.bfloat16]) +def test_batch_prefill_with_paged_kv_cache_nvfp4( + batch_size, + kv_len, + qo_len, + page_size, + num_kv_heads, + num_qo_heads, + head_dim, + causal, + q_dtype, +): + """Test BatchPrefillWithPagedKVCacheWrapper with NVFP4 KV cache. + + KV cache layout (NHD): + kv_cache: [num_pages, 2, page_size, num_kv_heads, head_dim//2] uint8 (packed FP4x2) + kv_cache_sf: [num_pages, 2, page_size, num_kv_heads, head_dim//16] uint8 (FP8 SFs) + + Reference is computed by dequantizing the packed KV back to q_dtype and running + single_prefill_with_kv_cache per batch item. 
+ """ + if qo_len > kv_len and causal: + pytest.skip("qo_len > kv_len and causal is not supported") + + kv_layout = "NHD" + torch.manual_seed(42) + + # --- query --- + q = torch.randn( + batch_size * qo_len, num_qo_heads, head_dim, device="cuda:0", dtype=q_dtype + ) + q_indptr_cpu = torch.arange(0, batch_size + 1, dtype=torch.int32) * qo_len + + # --- paged KV metadata --- + num_pages_per_seq = (kv_len + page_size - 1) // page_size + total_num_pages = num_pages_per_seq * batch_size + kv_indptr_cpu = ( + torch.arange(0, batch_size + 1, dtype=torch.int32) * num_pages_per_seq + ) + kv_indices_cpu = torch.arange(0, total_num_pages, dtype=torch.int32) + kv_last_page_len_cpu = torch.full( + (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32 + ) + + # --- create NVFP4 KV pages directly (NHD: [num_pages, page_size, num_kv_heads, head_dim//2]) --- + kv_shape = (total_num_pages, page_size, num_kv_heads, head_dim // 2) + k_packed, k_sf, k_global_scale = create_nvfp4_kv(kv_shape, "cuda:0") + v_packed, v_sf, v_global_scale = create_nvfp4_kv(kv_shape, "cuda:0") + + # Dequantize for reference attention + k_dq = nvfp4_to_float(k_packed, k_sf, k_global_scale).to(q_dtype) + v_dq = nvfp4_to_float(v_packed, v_sf, v_global_scale).to(q_dtype) + + # Pack into combined tensors: + # kv_cache: [num_pages, 2, page_size, num_kv_heads, head_dim//2] + # kv_cache_sf: [num_pages, 2, page_size, num_kv_heads, head_dim//16] + kv_cache = torch.stack([k_packed, v_packed], dim=1) # [P, 2, S, H, D//2] + kv_cache_sf = torch.stack([k_sf, v_sf], dim=1) # [P, 2, S, H, D//16] + + # --- run BatchPrefillWithPagedKVCacheWrapper --- + workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device="cuda:0") + q_indptr_gpu = q_indptr_cpu.to("cuda:0") + kv_indptr_gpu = kv_indptr_cpu.to("cuda:0") + kv_indices_gpu = kv_indices_cpu.to("cuda:0") + kv_last_page_len_gpu = kv_last_page_len_cpu.to("cuda:0") + + wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, 
kv_layout + ) + wrapper.plan( + q_indptr_gpu, + kv_indptr_gpu, + kv_indices_gpu, + kv_last_page_len_gpu, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + causal=causal, + pos_encoding_mode="NONE", + logits_soft_cap=0.0, + kv_data_type=torch.uint8, + q_data_type=q_dtype, + ) + o = wrapper.run( + q, + kv_cache, + k_scale=k_global_scale.item(), + v_scale=v_global_scale.item(), + kv_cache_sf=kv_cache_sf, + ) + + # --- reference: single_prefill_with_kv_cache per batch item using dequantized KV --- + for i in range(batch_size): + qi = q[q_indptr_cpu[i] : q_indptr_cpu[i + 1]] + + # Gather full (non-padded) KV for sequence i from pages + full_pages_k = k_dq[ + kv_indptr_cpu[i] : kv_indptr_cpu[i + 1] - 1 + ] # [p-1, S, H, D] + last_page_k = k_dq[ + kv_indptr_cpu[i + 1] - 1, : kv_last_page_len_cpu[i] + ] # [l, H, D] + ki = torch.cat( + [ + full_pages_k.reshape(-1, num_kv_heads, head_dim), + last_page_k.reshape(-1, num_kv_heads, head_dim), + ], + dim=0, + ) + + full_pages_v = v_dq[kv_indptr_cpu[i] : kv_indptr_cpu[i + 1] - 1] + last_page_v = v_dq[kv_indptr_cpu[i + 1] - 1, : kv_last_page_len_cpu[i]] + vi = torch.cat( + [ + full_pages_v.reshape(-1, num_kv_heads, head_dim), + last_page_v.reshape(-1, num_kv_heads, head_dim), + ], + dim=0, + ) + + o_ref_i = flashinfer.prefill.single_prefill_with_kv_cache( + qi, ki, vi, causal=causal, pos_encoding_mode="NONE", logits_soft_cap=0.0 + ) + o_i = o[q_indptr_cpu[i] : q_indptr_cpu[i + 1]] + + # NVFP4 is 4-bit; use relaxed tolerance + torch.testing.assert_close(o_i, o_ref_i, rtol=1e-1, atol=1e-1) + + +@pytest.mark.parametrize("batch_size", [1, 4]) +@pytest.mark.parametrize("kv_len", [128, 256]) +@pytest.mark.parametrize("qo_len", [64, 128]) +@pytest.mark.parametrize("num_kv_heads", [1]) +@pytest.mark.parametrize("num_qo_heads", [1]) +@pytest.mark.parametrize("head_dim", [128]) +@pytest.mark.parametrize("causal", [False]) +@pytest.mark.parametrize("q_dtype", [torch.float16, torch.bfloat16]) +def 
test_batch_prefill_with_ragged_kv_cache_nvfp4( + batch_size, + kv_len, + qo_len, + num_kv_heads, + num_qo_heads, + head_dim, + causal, + q_dtype, +): + """Test BatchPrefillWithRaggedKVCacheWrapper with NVFP4 KV cache. + + KV cache layout (NHD): + k/v: [total_kv_tokens, num_kv_heads, head_dim//2] uint8 (packed FP4x2) + k/v_sf: [total_kv_tokens, num_kv_heads, head_dim//16] uint8 (FP8 SFs) + + Reference is computed by dequantizing the packed KV back to q_dtype and running + single_prefill_with_kv_cache per batch item. + """ + if qo_len > kv_len and causal: + pytest.skip("qo_len > kv_len and causal is not supported") + + kv_layout = "NHD" + torch.manual_seed(42) + + # --- query --- + q = torch.randn( + batch_size * qo_len, num_qo_heads, head_dim, device="cuda:0", dtype=q_dtype + ) + q_indptr_cpu = torch.arange(0, batch_size + 1, dtype=torch.int32) * qo_len + q_indptr_gpu = q_indptr_cpu.to("cuda:0") + + # --- ragged KV metadata --- + kv_indptr_cpu = torch.arange(0, batch_size + 1, dtype=torch.int32) * kv_len + kv_indptr_gpu = kv_indptr_cpu.to("cuda:0") + total_kv_tokens = batch_size * kv_len + + # --- create NVFP4 ragged KV (NHD: [total_kv_tokens, num_kv_heads, head_dim//2]) --- + kv_shape = (total_kv_tokens, num_kv_heads, head_dim // 2) + k_packed, k_sf, k_global_scale = create_nvfp4_kv(kv_shape, "cuda:0") + v_packed, v_sf, v_global_scale = create_nvfp4_kv(kv_shape, "cuda:0") + + # Dequantize for reference attention + k_dq = nvfp4_to_float(k_packed, k_sf, k_global_scale).to(q_dtype) + v_dq = nvfp4_to_float(v_packed, v_sf, v_global_scale).to(q_dtype) + + # --- run BatchPrefillWithRaggedKVCacheWrapper --- + workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device="cuda:0") + wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper( + workspace_buffer, kv_layout + ) + wrapper.plan( + q_indptr_gpu, + kv_indptr_gpu, + num_qo_heads, + num_kv_heads, + head_dim, + causal=causal, + pos_encoding_mode="NONE", + logits_soft_cap=0.0, + 
kv_data_type=torch.uint8, + q_data_type=q_dtype, + ) + o = wrapper.run( + q, + k_packed, + v_packed, + k_scale=k_global_scale.item(), + v_scale=v_global_scale.item(), + kv_cache_sf=(k_sf, v_sf), + ) + + # --- reference: single_prefill_with_kv_cache per batch item using dequantized KV --- + for i in range(batch_size): + qi = q[q_indptr_cpu[i] : q_indptr_cpu[i + 1]] + ki = k_dq[kv_indptr_cpu[i] : kv_indptr_cpu[i + 1]] + vi = v_dq[kv_indptr_cpu[i] : kv_indptr_cpu[i + 1]] + + o_ref_i = flashinfer.prefill.single_prefill_with_kv_cache( + qi, ki, vi, causal=causal, pos_encoding_mode="NONE", logits_soft_cap=0.0 + ) + o_i = o[q_indptr_cpu[i] : q_indptr_cpu[i + 1]] + + # NVFP4 is 4-bit; use relaxed tolerance + torch.testing.assert_close(o_i, o_ref_i, rtol=1e-1, atol=1e-1) + + if __name__ == "__main__": test_batch_prefill_with_paged_kv_cache( 12, 54, 37, 1, 4, 4, 128, False, "NHD", "NONE", False, 0.0, True, True @@ -1049,6 +1283,12 @@ def create_2D_multi_item_mask_dense( test_batch_prefill_with_ragged_kv_cache_custom_mask( 1, 137, 137, 8, 8, 128, "NONE", 0.0, False ) + test_batch_prefill_with_paged_kv_cache_nvfp4( + 4, 128, 64, 64, 1, 1, 128, False, torch.float16 + ) + test_batch_prefill_with_ragged_kv_cache_nvfp4( + 4, 128, 64, 1, 1, 128, False, torch.float16 + ) def test_single_prefill_torch_compile_cuda_graph(): diff --git a/tests/attention/test_single_prefill.py b/tests/attention/test_single_prefill.py index d08a63d3c1..d1b920632d 100644 --- a/tests/attention/test_single_prefill.py +++ b/tests/attention/test_single_prefill.py @@ -4,6 +4,7 @@ import torch import flashinfer +from tests.test_helpers.utils_fp4 import create_nvfp4_kv, nvfp4_to_float def build_causal_mask(qo_len, kv_len): @@ -101,3 +102,61 @@ def test_sinqle_prefill_with_paged_kv_cache( o_ref = single_prefill_with_kv_cache_ref(q, k, v, causal=causal) torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3) + + +@pytest.mark.parametrize("kv_len", [128, 256]) +@pytest.mark.parametrize("qo_len", [64, 128]) 
+@pytest.mark.parametrize("num_kv_heads", [1]) +@pytest.mark.parametrize("num_qo_heads", [1]) +@pytest.mark.parametrize("head_dim", [128]) +@pytest.mark.parametrize("causal", [False]) +@pytest.mark.parametrize("q_dtype", [torch.float16, torch.bfloat16]) +def test_single_prefill_with_kv_cache_nvfp4( + kv_len, + qo_len, + num_kv_heads, + num_qo_heads, + head_dim, + causal, + q_dtype, +): + """Test single_prefill_with_kv_cache with NVFP4 KV cache (contiguous layout). + + KV layout (NHD): + k/v: [kv_len, num_kv_heads, head_dim//2] uint8 (packed FP4x2) + k/v_sf: [kv_len, num_kv_heads, head_dim//16] uint8 (FP8 scale factors) + + Reference uses dequantized KV with the standard fp16 kernel. + """ + if qo_len > kv_len and causal: + pytest.skip("qo_len > kv_len and causal is not supported") + + torch.manual_seed(42) + + q = torch.randn(qo_len, num_qo_heads, head_dim, device="cuda:0", dtype=q_dtype) + + kv_shape = (kv_len, num_kv_heads, head_dim // 2) + k_packed, k_sf, k_global_scale = create_nvfp4_kv(kv_shape, "cuda:0") + v_packed, v_sf, v_global_scale = create_nvfp4_kv(kv_shape, "cuda:0") + + k_dq = nvfp4_to_float(k_packed, k_sf, k_global_scale).to(q_dtype) + v_dq = nvfp4_to_float(v_packed, v_sf, v_global_scale).to(q_dtype) + + o = flashinfer.prefill.single_prefill_with_kv_cache( + q, + k_packed, + v_packed, + causal=causal, + pos_encoding_mode="NONE", + logits_soft_cap=0.0, + k_scale=k_global_scale.item(), + v_scale=v_global_scale.item(), + kv_cache_sf=(k_sf, v_sf), + ) + + o_ref = flashinfer.prefill.single_prefill_with_kv_cache( + q, k_dq, v_dq, causal=causal, pos_encoding_mode="NONE", logits_soft_cap=0.0 + ) + + # NVFP4 is 4-bit; use relaxed tolerance + torch.testing.assert_close(o, o_ref, rtol=1e-1, atol=1e-1) diff --git a/tests/test_helpers/utils_fp4.py b/tests/test_helpers/utils_fp4.py index f5aaa8e7ec..780030b827 100644 --- a/tests/test_helpers/utils_fp4.py +++ b/tests/test_helpers/utils_fp4.py @@ -98,3 +98,47 @@ def recover_swizzled_scales(scale, m, n, block_size, 
sf_start_index=0): tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5)) result = torch.reshape(tmp, (full_m, rounded_n)).to(torch.float32) return result[sf_start_index : sf_start_index + m, :scale_n] + + +def create_nvfp4_kv(shape, device): + """Create random NVFP4 KV data directly. + + Args: + shape: (..., head_dim//2) for packed data, where leading dims are e.g. + (total_num_pages, page_size, num_kv_heads, head_dim//2). + device: torch device. + + Returns: + packed: uint8 tensor of given shape, random with bits 3 and 7 cleared. + sf: uint8 tensor of shape (*shape[:-1], shape[-1]//8), random from [32, 40, 48, 56] + (FP8 e4m3 encoding of 0.125, 0.25, 0.5, 1.0). + global_scale: scalar tensor, 1.0. + """ + packed = torch.randint(0, 256, shape, dtype=torch.uint8, device=device) + packed &= 0x77 # clear bit 3 (0x08) and bit 7 (0x80) + + # head_dim//2 packed bytes → head_dim FP4 values; one SF per 16 FP4 values → head_dim//16 SFs + sf_shape = (*shape[:-1], shape[-1] // 8) + sf_choices = torch.tensor( + [56, 48, 40, 32], dtype=torch.uint8, device=device + ) # 1.0, 0.5, 0.25, 0.125 in FP8 e4m3 + sf_idx = torch.randint(0, 4, sf_shape, device=device) + sf = sf_choices[sf_idx] + + return packed, sf, torch.tensor(1.0, device=device) + + +def nvfp4_to_float(x, sf, global_sf): + """Dequantize NVFP4 (packed uint8 + FP8 SF) back to float32. + + x: (..., head_dim//2) uint8 packed FP4 + sf: (..., head_dim//16) uint8 FP8 scale factors, one per 16 FP4 elements + """ + from flashinfer.fp4_quantization import e2m1_and_ufp8sf_scale_to_float + + x_flat = x.reshape(-1, x.shape[-1]) + sf_flat = sf.reshape(-1, sf.shape[-1]) + x_dq = e2m1_and_ufp8sf_scale_to_float( + x_flat, sf_flat, global_sf, sf_vec_size=16, is_sf_swizzled_layout=False + ) + return x_dq.reshape(*x.shape[:-1], -1).to(x.device)