Commit 1de1b97
bump version to 0.6.7 & fix api breaking changes (#2832)
<!-- .github/pull_request_template.md -->
## Description
Fix API-breaking changes for the 0.6.7 release.
## Related Issues (Gated-by PRs)
https://github.com/flashinfer-ai/flashinfer/issues?q=state%3Aopen%20label%3Av0.6.7
<!-- Link any related issues here -->
## Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.
### Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.
> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).
## Tests
- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).
## Reviewer Notes
**API changes since v0.6.6**
**PR #2520 + commit e35c19e** (fixed to be compatible)
- `xqa()`: added `k_sf_cache=None` and `v_sf_cache=None` as keyword-only parameters (after `*`). Backward-compatible (sketch below).
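A minimal sketch of why this is backward-compatible (toy function, not the real `xqa()` signature or tensor types):

```python
# Toy stand-in: keyword-only params after `*` with None defaults do not
# disturb existing positional call sites.
def xqa_sketch(q, kv_cache, *, k_sf_cache=None, v_sf_cache=None):
    return (q, kv_cache, k_sf_cache, v_sf_cache)

xqa_sketch("q", "kv")                                    # pre-0.6.7 call, still valid
xqa_sketch("q", "kv", k_sf_cache="ks", v_sf_cache="vs")  # new opt-in by keyword
```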
**PR #2618** (follow-up fix in PR #2730)
- `gated_delta_rule_mtp()`: `disable_state_update: bool = True` → `Optional[bool] = None`. Still defaults to `True` at runtime but emits a deprecation warning; the default will flip to `False` in 0.7.0 (idiom sketched below).
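The deprecation idiom as a self-contained sketch; this is the assumed shape of the logic, not FlashInfer's actual implementation:

```python
import warnings
from typing import Optional

def resolve_disable_state_update(disable_state_update: Optional[bool] = None) -> bool:
    # None means "the caller did not choose": warn and keep the 0.6.x
    # runtime default (True). In 0.7.0 the fallback is expected to flip
    # to False, so explicit callers are unaffected by the change.
    if disable_state_update is None:
        warnings.warn(
            "disable_state_update will default to False in 0.7.0; "
            "pass it explicitly.",
            DeprecationWarning,
        )
        disable_state_update = True
    return disable_state_update
```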
**PR #2775** (expected; cute DSL MoE cleanup)
- `blockscaled_contiguous_grouped_gemm_nvfp4()`: entire `@flashinfer_api`-decorated function deleted.
- `blockscaled_contiguous_grouped_gemm_swiglu_fusion_nvfp4()`: entire `@flashinfer_api`-decorated function deleted.
- `blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4()`: `@flashinfer_api` decorator removed; added `enable_pdl: bool = True` parameter.
- `blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4()`: `@flashinfer_api` decorator removed; added `enable_pdl: bool = True` parameter.
- `CuteDslMoEWrapper.__init__()`: added `enable_pdl: bool = True` parameter. Backward-compatible.
- `cute_dsl_fused_moe_nvfp4()`: added `enable_pdl: bool = True` parameter. Backward-compatible (pattern sketched below).
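A hedged sketch of the `enable_pdl` pattern (toy class, not the full `CuteDslMoEWrapper` constructor): a trailing parameter with a `True` default leaves existing call sites untouched, and callers opt out by keyword:

```python
class MoEWrapperSketch:
    def __init__(self, num_experts: int, top_k: int, enable_pdl: bool = True):
        # In the real wrapper, enable_pdl is threaded down to the
        # grouped-GEMM kernels (see the diff excerpts below).
        self.num_experts = num_experts
        self.top_k = top_k
        self.enable_pdl = enable_pdl

w_default = MoEWrapperSketch(8, 2)                   # pre-0.6.7 call, PDL on by default
w_no_pdl = MoEWrapperSketch(8, 2, enable_pdl=False)  # explicit opt-out
```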
**PR #2428**
- `rmsnorm_quant()`: `scale: float` → `scale: Union[float, torch.Tensor]`; return type `torch.Tensor` → `None` (contract sketched below).
- `fused_add_rmsnorm_quant()`: `scale: float` → `scale: Union[float, torch.Tensor]`.
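The new contract as a sketch with placeholder math (hypothetical name and argument order, not the fused kernel): `scale` may be a tensor; a float is auto-converted with a deprecation warning, per the release notes below; and the result is written into `out` rather than returned:

```python
import warnings
from typing import Union
import torch

def rmsnorm_quant_sketch(out: torch.Tensor, x: torch.Tensor,
                         scale: Union[float, torch.Tensor]) -> None:
    if isinstance(scale, float):  # deprecated path: auto-convert for compatibility
        warnings.warn("pass scale as a torch.Tensor", DeprecationWarning)
        scale = torch.tensor(scale, device=x.device)
    out.copy_(x * scale)  # stand-in for the fused RMSNorm+quant computation
```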
**Quantization functions (relocated, not removed)**
All quantization APIs (`fp4_quantize`, `block_scale_interleave`, `e2m1_and_ufp8sf_scale_to_float`, `shuffle_matrix_a`, `shuffle_matrix_sf_a`, `nvfp4_quantize`, `nvfp4_batched_quantize`, `scaled_fp4_grouped_quantize`, `mxfp4_quantize`, `mxfp4_dequantize`, `mxfp4_dequantize_host`, `mxfp8_quantize`, `mxfp8_dequantize_host`) were moved from `flashinfer/fp4_quantization.py` and `flashinfer/fp8_quantization.py` to `flashinfer/quantization/`. Signatures, `@flashinfer_api` decorators, and `__init__.py` exports are preserved. No breakage (import check below).
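Since the `__init__.py` exports are preserved, existing import sites should need no changes; a quick smoke check, assuming the top-level re-exports named in the list above:

```python
# Unchanged public import path after the move to flashinfer/quantization/:
from flashinfer import fp4_quantize, block_scale_interleave, mxfp8_quantize
```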
```diff
$ git diff v0.6.6 | grep -A20 "@flashinfer_api"
@flashinfer_api
@@ -1215,6 +1227,9 @@ class BatchDecodeWithPagedKVCacheWrapper:
sinks: Optional[torch.Tensor] = None,
q_len_per_req: Optional[int] = 1,
skip_softmax_threshold_scale_factor: Optional[float] = None,
+ kv_block_scales: Optional[
+ Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
+ ] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
r"""Compute batch decode attention between query and paged kv cache.
@@ -1273,6 +1288,15 @@ class BatchDecodeWithPagedKVCacheWrapper:
enable_pdl = device_support_pdl(q.device)
k_cache, v_cache = _unpack_paged_kv_cache(paged_kv_cache, self._kv_layout)
+ # Unpack kv_block_scales
+ key_block_scales = None
+ value_block_scales = None
+ if kv_block_scales is not None:
+ if isinstance(kv_block_scales, tuple):
+ key_block_scales, value_block_scales = kv_block_scales
--
-@flashinfer_api
-def fp4_quantize(
- input: torch.Tensor,
- global_scale: Optional[torch.Tensor] = None,
- sf_vec_size: int = 16,
- sf_use_ue8m0: bool = False,
- is_sf_swizzled_layout: bool = True,
- is_sf_8x4_layout: bool = False,
- enable_pdl: Optional[bool] = None,
-) -> Tuple[torch.Tensor, torch.Tensor]:
- """Quantize input tensor to FP4 format.
-
- This function implements FP4 quantization that converts input tensors to a compressed FP4 format
- with associated scale factors. It supports various input data types and scale factor layouts.
-
- Args:
- input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized.
- global_scale (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32.
- sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
- sf_use_ue8m0 (bool, optional): Whether to use UE8M0 format for scale factors. Defaults to False.
- is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True.
--
-@flashinfer_api
-def block_scale_interleave(unswizzled_sf: torch.Tensor) -> torch.Tensor:
- """Swizzle block scale tensor for FP4 format.
-
- This function swizzles the block scale tensor to optimize memory access patterns
- for FP4 operations. The output needs to be padded in the m dimension to be a multiple of 128.
-
- Args:
- unswizzled_sf (torch.Tensor): Input tensor with dtype uint8 or bfloat16.
-
- Returns:
- torch.Tensor: Swizzled tensor with the same shape as input.
-
- Raises:
- AssertionError: If input dtype is not uint8 or bfloat16.
- """
- # TODO(shuw): check input dtype is uint8
- assert (
- unswizzled_sf.dtype == torch.uint8 or unswizzled_sf.dtype == torch.bfloat16
- ), f"Input dtype must be uint8 or bfloat16, got {unswizzled_sf.dtype}"
-
--
-@flashinfer_api
-def e2m1_and_ufp8sf_scale_to_float(
- e2m1_tensor: torch.Tensor,
- ufp8_scale_tensor: torch.Tensor,
- global_scale_tensor: Optional[torch.Tensor] = None,
- sf_vec_size: int = 16,
- ufp8_type: int = 1,
- is_sf_swizzled_layout: bool = True,
-) -> torch.Tensor:
- """Convert E2M1 format tensor and UFP8 scale factors to float tensor.
-
- This function performs dequantization by converting a packed FP4 tensor in E2M1 format
- back to float values using the associated UFP8 scale factors and global scale.
-
- Args:
- e2m1_tensor (torch.Tensor): Packed FP4 tensor in E2M1 format of shape [M, K/2] with dtype uint8.
- ufp8_scale_tensor (torch.Tensor): Scale factors tensor in UFP8 format with dtype uint8.
- global_scale_tensor (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32.
- sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
- ufp8_type (int, optional): UFP8 scale factor type (0 for UE8M0, 1 for E4M3). Defaults to 1.
- is_sf_swizzled_layout (bool, optional): Whether scale factors use swizzled layout. Defaults to True.
--
-@flashinfer_api
-def shuffle_matrix_a(input_tensor: torch.Tensor, epilogue_tile_m: int) -> torch.Tensor:
- """
- PyTorch equivalent of trtllm-gen `shuffleMatrixA`
- """
- row_indices = get_shuffle_matrix_a_row_indices(input_tensor, epilogue_tile_m)
-
- return input_tensor[row_indices.to(input_tensor.device)]
-
-
-@flashinfer_api
-def shuffle_matrix_sf_a(
- input_tensor: torch.Tensor,
- epilogue_tile_m: int,
- num_elts_per_sf: int = 16,
-):
- """
- Cuda implementation of trtllm-gen `shuffleMatrixSfA` but with a caveat.
- `shuffleMatrixSfA` expects the input to be in 128x4 layout and then
- apply the same shuffling in `shuffleMatrixA` and writes out in 128x4
- layout.
- This function expects the input to be in linear layout. It's done this
- way because the scaling factors in the NVFP4 checkpoints are quantized
- and are in linear layout.
- This function doesn't add padding.
- """
-
- row_indices = get_shuffle_matrix_sf_a_row_indices(input_tensor, epilogue_tile_m)
-
- w_shuffled = input_tensor[row_indices.to(input_tensor.device)]
-
--
-@flashinfer_api
-def nvfp4_quantize(
- a,
- a_global_sf,
- sfLayout=SfLayout.layout_128x4,
- do_shuffle=False,
- sf_vec_size=16,
- enable_pdl=None,
-):
- """
- Quantize input tensor to NVFP4 format.
-
- Parameters:
- a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16.
- a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32.
- sfLayout (SfLayout, optional): Scale factor layout. Defaults to SfLayout.layout_128x4.
- do_shuffle (bool, optional): Whether to shuffle the scale factors. Defaults to False. Only TRTLLM backend needs to shuffle the tensor B scale factors.
- sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
- enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch).
- If None, automatically detects based on device capability. Defaults to None.
-
--
-@flashinfer_api
-def mxfp4_quantize(a):
- """
- Quantize input tensor to MXFP4 format.
-
- Parameters:
- a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16.
-
- Returns:
- Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- - Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2)
- - Scale factors tensor with shape determined by layout and sf_vec_size (uint8)
- """
- a_global_sf = (448 * 6) / a.float().abs().nan_to_num().max()
- a_fp4, a_sf = fp4_quantize(a.cuda(), a_global_sf.cuda(), 32, True, True)
- return a_fp4, a_sf
-
-
-@flashinfer_api
-def mxfp4_dequantize(a_fp4, a_sf):
- """
- Dequantize input tensor from MXFP4 format.
-
- Parameters:
- a_fp4 (torch.Tensor): Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2)
- a_sf (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size (uint8)
-
- Returns:
- torch.Tensor: Dequantized tensor of shape [M, K] with dtype float.
- """
- return e2m1_and_ufp8sf_scale_to_float(
- a_fp4.cpu().view(torch.uint8),
- a_sf.cpu().view(torch.uint8).reshape(-1),
- torch.tensor([1.0], device=a_fp4.device),
- 32,
- 0,
- True,
- )
-
--
-@flashinfer_api
-def mxfp4_dequantize_host(
- weight: torch.Tensor,
- scale: torch.Tensor,
- group_size: int = 32,
-) -> torch.Tensor:
- """
- Dequantize input tensor from MXFP4 format on host.
-
- Parameters:
- weight (torch.Tensor): Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2)
- scale (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size (uint8)
- group_size (int, optional): Group size for dequantization. Defaults to 32.
-
- Returns:
- torch.Tensor: Dequantized tensor of shape [M, K] with dtype float.
- """
- # NOTE(Zihao): the cpu op should be decouplied from cuda ops because it's device independent, should refactor this in the future
- major, minor = get_compute_capability(
- torch.device("cuda:0")
- ) # use any cuda device to get a compute capability
--
-@flashinfer_api
-def nvfp4_batched_quantize(
- a,
- a_global_sf,
- sf_vec_size=16,
-):
- """
- Quantize batched input tensor to NVFP4 format.
-
- Parameters:
- a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16.
- a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32.
- sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
-
- Returns:
- Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2
- - Scale factors tensor with shape determined by layout and sf_vec_size
- """
- major, minor = get_compute_capability(a.device)
- device_arch = f"{major * 10 + minor}"
--
-@flashinfer_api
-def scaled_fp4_grouped_quantize(
- a,
- mask,
- a_global_sf,
-):
- """
- quantize batched input tensor to NVFP4 format with mask.
- Parameters:
- a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16.
- a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32.
- mask (torch.Tensor): Mask tensor to apply before quantization.
- Returns:
- Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2
- - Scale factors tensor with shape determined by layout and sf_vec_size
- """
- major, minor = get_compute_capability(a.device)
- device_arch = f"{major * 10 + minor}"
- a_fp4, a_sf = get_fp4_quantization_module(
- device_arch
--
-@flashinfer_api
-def mxfp8_quantize(
- input: torch.Tensor,
- is_sf_swizzled_layout: bool = True,
- alignment: int = 32,
- enable_pdl: Optional[bool] = None,
-) -> Tuple[torch.Tensor, torch.Tensor]:
- """Quantize input tensor to MxFP8 format.
-
- This function implements MxFP8 quantization that converts input tensors to a compressed MxFP8 format
- with associated scale factors. It supports various input data types and scale factor layouts.
-
- Args:
- input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized.
- is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True.
- alignment (int, optional): sfVecSize. Defaults to 32.
- enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch).
- If None, automatically detects based on device capability. Defaults to None.
- Returns:
- Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- - Quantized tensor of shape [M, K] with dtype FLOAT8_E4M3
--
-@flashinfer_api
-def mxfp8_dequantize_host(
- input: torch.Tensor,
- scale_tensor: torch.Tensor,
- is_sf_swizzled_layout: bool = True,
-) -> torch.Tensor:
- """Dequantize input tensor from MxFP8 format.
-
- This function performs dequantization by converting a packed FP8 tensor in MxFP8 format
- back to float values using the associated scale factors.
-
- Args:
- input (torch.Tensor): Packed FP8 tensor in MxFP8 format of shape [M, K] with dtype FLOAT8_E4M3.
- scale_tensor (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size.
- is_sf_swizzled_layout (bool, optional): Whether scale factors use swizzled layout. Defaults to True.
-
- Returns:
- torch.Tensor: Dequantized float tensor of shape [M, K] with dtype float32.
-
- """
-
--
-@flashinfer_api
def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4(
a: torch.Tensor,
b: torch.Tensor,
@@ -323,6 +324,7 @@ def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4(
vectorized_f32: bool = True,
raster_along_m: bool = False,
sm_count: Optional[int] = None,
+ enable_pdl: bool = True,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Blockscaled Contiguous Gather Grouped GEMM with SwiGLU Fusion for MoE workloads.
@@ -423,7 +425,7 @@ def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4(
major, minor = get_compute_capability(a.device)
if major != 10:
raise ValueError(
- f"Blockscaled contiguous gather grouped GEMM with SwiGLU requires SM100 family (Blackwell: SM100, SM103, SM110). "
+ f"Blockscaled contiguous gather grouped GEMM with SwiGLU requires SM100 family (Blackwell: SM100, SM103). "
f"Got SM{major}{minor}."
)
--
-@flashinfer_api
-def blockscaled_contiguous_grouped_gemm_nvfp4(
- a: torch.Tensor,
- b: torch.Tensor,
- a_scale: torch.Tensor,
- b_scale: torch.Tensor,
- alpha: torch.Tensor,
- tile_idx_to_group_idx: torch.Tensor,
- num_non_exiting_tiles: torch.Tensor,
- out: Optional[torch.Tensor] = None,
- *,
- ab_dtype: str = "float4_e2m1fn",
- sf_dtype: str = "float8_e4m3fn",
- c_dtype: str = "bfloat16",
- sf_vec_size: int = 16,
- mma_tiler_mn: Tuple[int, int] = (128, 128),
- cluster_shape_mn: Tuple[int, int] = (1, 1),
- sm_count: Optional[int] = None,
-) -> torch.Tensor:
- """Blockscaled Contiguous Grouped GEMM for MoE workloads with NVFP4 quantization.
-
--
-@flashinfer_api
def blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4(
a: torch.Tensor,
b: torch.Tensor,
@@ -272,6 +279,7 @@ def blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4(
cluster_shape_mn: Tuple[int, int] = (2, 1),
raster_along_m: bool = False,
sm_count: Optional[int] = None,
+ enable_pdl: bool = True,
) -> torch.Tensor:
"""Blockscaled Contiguous Grouped GEMM with Finalize Fusion for MoE workloads.
@@ -298,7 +306,11 @@ def blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4(
expanded_idx = token_idx * topk + topk_idx. Invalid rows have -1.
token_final_scales: Router scaling factors, shape (seq_len, topk), float32/bf16/fp16
out: Optional output tensor, shape (seq_len, n). Created if None.
- This tensor is used for atomic accumulation, so it should be zero-initialized.
+ This tensor is used for atomic accumulation. If `out` is
+ provided, it must already be zero-initialized by the caller.
+ If `out` is None, this function allocates a zero-initialized
+ output tensor. Passing a non-zeroed `out` buffer will silently
--
-@flashinfer_api
-def blockscaled_contiguous_grouped_gemm_swiglu_fusion_nvfp4(
- a: torch.Tensor,
- b: torch.Tensor,
- a_scale: torch.Tensor,
- b_scale: torch.Tensor,
- alpha: torch.Tensor,
- tile_idx_to_group_idx: torch.Tensor,
- num_non_exiting_tiles: torch.Tensor,
- out: Optional[torch.Tensor] = None,
- out_scale: Optional[torch.Tensor] = None,
- global_scale: Optional[torch.Tensor] = None,
- *,
- ab_dtype: str = "float4_e2m1fn",
- sf_dtype: str = "float8_e4m3fn",
- c_dtype: str = "bfloat16",
- sf_vec_size: int = 16,
- mma_tiler_mn: Tuple[int, int] = (256, 128),
- cluster_shape_mn: Tuple[int, int] = (2, 1),
- vectorized_f32: bool = True,
- sm_count: Optional[int] = None,
--
@flashinfer_api
def __init__(
self,
@@ -347,6 +355,7 @@ class CuteDslMoEWrapper:
sf_vec_size: int = 16,
output_dtype: torch.dtype = torch.bfloat16,
device: str = "cuda",
+ enable_pdl: bool = True,
):
"""Initialize the MoE wrapper.
@@ -363,6 +372,7 @@ class CuteDslMoEWrapper:
sf_vec_size: Scale factor vector size. Default: 16.
output_dtype: Output data type. Default: torch.bfloat16.
device: Device for buffer allocation. Default: "cuda".
+ enable_pdl: Enable Programmatic Dependent Launch. Default: True.
"""
self.num_experts = num_experts
self.top_k = top_k
@@ -376,6 +386,7 @@ class CuteDslMoEWrapper:
self.sf_vec_size = sf_vec_size
--
@flashinfer_api
@@ -550,9 +570,10 @@ class CuteDslMoEWrapper:
f"num_tokens ({num_tokens}) exceeds max_num_tokens ({self.max_num_tokens})"
)
- # Allocate output buffer if not using pre-allocated one
+ # Slice the pre-allocated buffer to the active batch so that
+ # _moe_core_impl only zeros num_tokens rows, not max_num_tokens.
if self.use_cuda_graph:
- moe_output = self._moe_output
+ moe_output = self._moe_output[:num_tokens]
else:
moe_output = torch.empty(
(num_tokens, self.hidden_size),
@@ -627,6 +648,7 @@ def _cute_dsl_fused_moe_nvfp4_impl(
use_fused_finalize: bool = True,
moe_output: Optional[torch.Tensor] = None,
aux_stream: Optional[torch.cuda.Stream] = None,
+ enable_pdl: bool = True,
) -> torch.Tensor:
"""Internal implementation called by auto-tuner for functional API."""
--
@flashinfer_api
def cute_dsl_fused_moe_nvfp4(
x: torch.Tensor,
@@ -678,9 +702,12 @@ def cute_dsl_fused_moe_nvfp4(
use_fused_finalize: bool = True,
moe_output: Optional[torch.Tensor] = None,
aux_stream: Optional[torch.cuda.Stream] = None,
+ enable_pdl: bool = True,
) -> torch.Tensor:
"""Run fused MoE computation using CuteDSL NVFP4 kernels.
+ Supported architectures: SM100, SM103.
+
This is the simple functional API. For CUDA graph support, use
`CuteDslMoEWrapper` instead.
@@ -736,6 +763,7 @@ def cute_dsl_fused_moe_nvfp4(
local_expert_offset=local_expert_offset,
use_fused_finalize=use_fused_finalize,
output_dtype=output_dtype,
+ enable_pdl=enable_pdl,
--
@flashinfer_api
def gated_delta_rule_decode_pretranspose(
q: torch.Tensor,
@@ -1002,8 +174,9 @@ def gated_delta_rule_decode_pretranspose(
- State layout is v-major (K-last): [B, HV, V, K]. When state is bfloat16
and T in 1..4 with K=V=128, the gdn_decode_klast_bf16_state kernel is used
(supports both the direct ``state`` path and the pool+indices path).
- - pool+indices (``initial_state``/``initial_state_indices``) only supported
- via the bf16 fast path; float32 state raises an error.
+ - pool+indices (``initial_state``/``initial_state_indices``) supported on
+ both the bf16 fast path (T in 1..4, K=V=128) and the float32 legacy path
+ (T=1). The float32 path also supports negative indices for padding.
- Legacy path (float32 state, T=1): K and V must be multiples of 4.
"""
# Validate input shapes
@@ -1069,13 +242,17 @@ def gated_delta_rule_decode_pretranspose(
return_state = initial_state if use_pool else state
return output, return_state
- # Legacy path: T=1 only, float32 state (no pool+indices support)
- assert not use_pool, (
--
@flashinfer_api
def gated_delta_rule_mtp(
q: torch.Tensor,
@@ -2427,7 +489,7 @@ def gated_delta_rule_mtp(
scale: Optional[float] = None,
output: Optional[torch.Tensor] = None,
intermediate_states_buffer: Optional[torch.Tensor] = None,
- disable_state_update: bool = True,
+ disable_state_update: Optional[bool] = None,
use_qk_l2norm: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
@@ -2463,8 +525,15 @@ def gated_delta_rule_mtp(
intermediate_states_buffer (Optional[torch.Tensor]):
Buffer for caching intermediate states, shape ``[pool_size, T, HV, V, K]``.
If None, intermediate states are not cached.
- disable_state_update (bool):
- If True, the initial state is not updated. Default: ``True``.
+ disable_state_update (Optional[bool]):
+ If True, the initial state is not updated. Currently defaults to ``True``.
+ Please pass this argument explicitly; the default will change to ``False``
--
@flashinfer_api
@@ -60,16 +120,14 @@ def rmsnorm(
output: torch.Tensor
Normalized tensor, 2D shape (batch_size, hidden_size) or 3D shape (batch_size, num_heads, hidden_size).
"""
- if enable_pdl is None:
- enable_pdl = device_support_pdl(input.device)
if out is None:
out = torch.empty_like(input)
- _rmsnorm(out, input, weight, eps, enable_pdl)
+ _rmsnorm_impl(out, input, weight, eps, enable_pdl)
return out
@register_custom_op("flashinfer::rmsnorm", mutates_args=("out",))
-def _rmsnorm(
+def _rmsnorm_impl(
out: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
@@ -78,11 +136,21 @@ def _rmsnorm(
--
@flashinfer_api
def fmha_v2_prefill_deepseek(
query: torch.Tensor,
@@ -3865,18 +4029,11 @@ def fmha_v2_prefill_deepseek(
If return_lse is False, the output will be a single tensor.
"""
if not is_sm12x_supported(query.device):
- major, minor = get_compute_capability(query.device)
- if major == 12:
- min_cuda = "13.0" if minor >= 1 else "12.8"
- raise ValueError(
- f"fmha_v2_prefill_deepseek requires CUDA >= {min_cuda} "
- f"for SM12{minor}x GPUs."
- )
raise ValueError("fmha_v2_prefill_deepseek is only supported on SM12x GPUs.")
assert query.shape[3] == 192 and key.shape[3] == 192 and value.shape[3] == 128, (
"currently only support deepseek r1 192 query and 128 value"
)
- module = get_trtllm_fmha_v2_module()
+ module = get_trtllm_fmha_v2_sm120_module()
is_e4m3 = query.dtype == torch.float8_e4m3fn
--
+@flashinfer_api
+def trtllm_fmha_v2_prefill(
+ qkv: Union[
+ torch.Tensor,
+ Tuple[torch.Tensor, torch.Tensor],
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+ ],
+ input_layout: str,
+ workspace_buffer: torch.Tensor,
+ seq_lens: torch.Tensor,
+ max_q_len: int,
+ max_kv_len: int,
+ bmm1_scale: float,
+ bmm2_scale: float,
+ batch_size: int,
+ cum_seq_lens_q: torch.Tensor,
+ cum_seq_lens_kv: torch.Tensor,
+ block_tables: Optional[torch.Tensor] = None,
+ out: Optional[torch.Tensor] = None,
+ out_dtype: Optional[Union[torch.dtype, str]] = None,
+ sinks: Optional[List[torch.Tensor]] = None,
--
+@flashinfer_api
+def fp4_quantize(
+ input: torch.Tensor,
+ global_scale: Optional[torch.Tensor] = None,
+ sf_vec_size: int = 16,
+ sf_use_ue8m0: bool = False,
+ is_sf_swizzled_layout: bool = True,
+ is_sf_8x4_layout: bool = False,
+ enable_pdl: Optional[bool] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Quantize input tensor to FP4 format.
+
+ This function implements FP4 quantization that converts input tensors to a compressed FP4 format
+ with associated scale factors. It supports various input data types and scale factor layouts.
+
+ Args:
+ input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized.
+ global_scale (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32.
+ sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
+ sf_use_ue8m0 (bool, optional): Whether to use UE8M0 format for scale factors. Defaults to False.
+ is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True.
--
+@flashinfer_api
+def block_scale_interleave(unswizzled_sf: torch.Tensor) -> torch.Tensor:
+ """Swizzle block scale tensor for FP4 format.
+
+ This function swizzles the block scale tensor to optimize memory access patterns
+ for FP4 operations. The output needs to be padded in the m dimension to be a multiple of 128.
+
+ Args:
+ unswizzled_sf (torch.Tensor): Input tensor with dtype uint8 or bfloat16.
+
+ Returns:
+ torch.Tensor: Swizzled tensor with the same shape as input.
+
+ Raises:
+ AssertionError: If input dtype is not uint8 or bfloat16.
+ """
+ # TODO(shuw): check input dtype is uint8
+ assert (
+ unswizzled_sf.dtype == torch.uint8 or unswizzled_sf.dtype == torch.bfloat16
+ ), f"Input dtype must be uint8 or bfloat16, got {unswizzled_sf.dtype}"
+
--
+@flashinfer_api
+def e2m1_and_ufp8sf_scale_to_float(
+ e2m1_tensor: torch.Tensor,
+ ufp8_scale_tensor: torch.Tensor,
+ global_scale_tensor: Optional[torch.Tensor] = None,
+ sf_vec_size: int = 16,
+ ufp8_type: int = 1,
+ is_sf_swizzled_layout: bool = True,
+) -> torch.Tensor:
+ """Convert E2M1 format tensor and UFP8 scale factors to float tensor.
+
+ This function performs dequantization by converting a packed FP4 tensor in E2M1 format
+ back to float values using the associated UFP8 scale factors and global scale.
+
+ Args:
+ e2m1_tensor (torch.Tensor): Packed FP4 tensor in E2M1 format of shape [M, K/2] with dtype uint8.
+ ufp8_scale_tensor (torch.Tensor): Scale factors tensor in UFP8 format with dtype uint8.
+ global_scale_tensor (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32.
+ sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
+ ufp8_type (int, optional): UFP8 scale factor type (0 for UE8M0, 1 for E4M3). Defaults to 1.
+ is_sf_swizzled_layout (bool, optional): Whether scale factors use swizzled layout. Defaults to True.
--
+@flashinfer_api
+def shuffle_matrix_a(input_tensor: torch.Tensor, epilogue_tile_m: int) -> torch.Tensor:
+ """
+ PyTorch equivalent of trtllm-gen `shuffleMatrixA`
+ """
+ row_indices = get_shuffle_matrix_a_row_indices(input_tensor, epilogue_tile_m)
+
+ return input_tensor[row_indices.to(input_tensor.device)]
+
+
+@flashinfer_api
+def shuffle_matrix_sf_a(
+ input_tensor: torch.Tensor,
+ epilogue_tile_m: int,
+ num_elts_per_sf: int = 16,
+):
+ """
+ Cuda implementation of trtllm-gen `shuffleMatrixSfA` but with a caveat.
+ `shuffleMatrixSfA` expects the input to be in 128x4 layout and then
+ apply the same shuffling in `shuffleMatrixA` and writes out in 128x4
+ layout.
+ This function expects the input to be in linear layout. It's done this
+ way because the scaling factors in the NVFP4 checkpoints are quantized
+ and are in linear layout.
+ This function doesn't add padding.
+ """
+
+ row_indices = get_shuffle_matrix_sf_a_row_indices(input_tensor, epilogue_tile_m)
+
+ w_shuffled = input_tensor[row_indices.to(input_tensor.device)]
+
--
+@flashinfer_api
+def nvfp4_quantize(
+ a,
+ a_global_sf,
+ sfLayout=SfLayout.layout_128x4,
+ do_shuffle=False,
+ sf_vec_size=16,
+ enable_pdl=None,
+):
+ """
+ Quantize input tensor to NVFP4 format.
+
+ Parameters:
+ a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16.
+ a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32.
+ sfLayout (SfLayout, optional): Scale factor layout. Defaults to SfLayout.layout_128x4.
+ do_shuffle (bool, optional): Whether to shuffle the scale factors. Defaults to False. Only TRTLLM backend needs to shuffle the tensor B scale factors.
+ sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
+ enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch).
+ If None, automatically detects based on device capability. Defaults to None.
+
--
+@flashinfer_api
+def mxfp4_quantize(
+ a: torch.Tensor,
+ backend: str = "cuda",
+ enable_pdl: Optional[bool] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Quantize input tensor to MXFP4 format.
+
+ Parameters:
+ a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16.
+ backend (str, optional): Backend to use for quantization.
+ - "cuda": Use CUDA kernel (default, stable)
+ - "cute-dsl": Use CuTe-DSL kernel (requires SM100+, **experimental**)
+ enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic
+ Dependent Launch). Only used when backend="cute-dsl".
+ If None, automatically detects based on device capability.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+ - Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2)
--
+@flashinfer_api
+def mxfp4_dequantize(a_fp4, a_sf):
+ """
+ Dequantize input tensor from MXFP4 format.
+
+ Parameters:
+ a_fp4 (torch.Tensor): Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2)
+ a_sf (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size (uint8)
+
+ Returns:
+ torch.Tensor: Dequantized tensor of shape [M, K] with dtype float.
+ """
+ return e2m1_and_ufp8sf_scale_to_float(
+ a_fp4.cpu().view(torch.uint8),
+ a_sf.cpu().view(torch.uint8).reshape(-1),
+ torch.tensor([1.0], device=a_fp4.device),
+ 32,
+ 0,
+ True,
+ )
+
--
+@flashinfer_api
+def mxfp4_dequantize_host(
+ weight: torch.Tensor,
+ scale: torch.Tensor,
+ group_size: int = 32,
+) -> torch.Tensor:
+ """
+ Dequantize input tensor from MXFP4 format on host.
+
+ Parameters:
+ weight (torch.Tensor): Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2)
+ scale (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size (uint8)
+ group_size (int, optional): Group size for dequantization. Defaults to 32.
+
+ Returns:
+ torch.Tensor: Dequantized tensor of shape [M, K] with dtype float.
+ """
+ # NOTE(Zihao): the cpu op should be decouplied from cuda ops because it's device independent, should refactor this in the future
+ major, minor = get_compute_capability(
+ torch.device("cuda:0")
+ ) # use any cuda device to get a compute capability
--
+@flashinfer_api
+def nvfp4_batched_quantize(
+ a,
+ a_global_sf,
+ sf_vec_size=16,
+):
+ """
+ Quantize batched input tensor to NVFP4 format.
+
+ Parameters:
+ a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16.
+ a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32.
+ sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+ - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2
+ - Scale factors tensor with shape determined by layout and sf_vec_size
+ """
+ major, minor = get_compute_capability(a.device)
+ device_arch = f"{major * 10 + minor}"
--
+@flashinfer_api
+def nvfp4_quantize_paged_kv_cache(
+ k_cache: torch.Tensor,
+ v_cache: torch.Tensor,
+ kv_layout: str = "HND",
+ k_global_sf: Optional[torch.Tensor] = None,
+ v_global_sf: Optional[torch.Tensor] = None,
+) -> Tuple[
+ Tuple[torch.Tensor, torch.Tensor],
+ Tuple[torch.Tensor, torch.Tensor],
+ float,
+ float,
+]:
+ """Quantize paged KV cache to NVFP4 format for trtllm-gen MHA.
+
+ Quantizes BF16/FP16 K/V caches to NVFP4 with two-level scaling
+ (global FP32 + per-block FP8), and swizzles scale factors
+ for the SM100 trtllm-gen MHA kernel layout.
+
+ Args:
+ k_cache: Key cache tensor.
--
+@flashinfer_api
+def scaled_fp4_grouped_quantize(
+ a,
+ mask,
+ a_global_sf,
+):
+ """
+ quantize batched input tensor to NVFP4 format with mask.
+ Parameters:
+ a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16.
+ a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32.
+ mask (torch.Tensor): Mask tensor to apply before quantization.
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+ - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2
+ - Scale factors tensor with shape determined by layout and sf_vec_size
+ """
+ major, minor = get_compute_capability(a.device)
+ device_arch = f"{major * 10 + minor}"
+ a_fp4, a_sf = get_fp4_quantization_module(
+ device_arch
--
+@flashinfer_api
+def nvfp4_kv_dequantize(
+ fp4_data: torch.Tensor,
+ block_scales: torch.Tensor,
+ global_scale: torch.Tensor,
+ output_dtype: torch.dtype = torch.bfloat16,
+) -> torch.Tensor:
+ """GPU dequantization of NVFP4 KV cache data with linear block scale layout.
+
+ Requires SM80+.
+
+ Args:
+ fp4_data (torch.Tensor): Packed FP4 data of shape ``[M, K/2]`` with dtype uint8.
+ block_scales (torch.Tensor): Per-block FP8 E4M3 scales of shape ``[M, K/16]``
+ with dtype uint8.
+ global_scale (torch.Tensor): Global scale factor of shape ``[1]`` with dtype float32,
+ on the same CUDA device as fp4_data.
+ output_dtype (torch.dtype): Output dtype, either ``torch.bfloat16`` or ``torch.float16``.
+
+ Returns:
+ torch.Tensor: Dequantized tensor of shape ``[M, K]`` with the specified output dtype.
--
+@flashinfer_api
+def nvfp4_kv_quantize(
+ input: torch.Tensor,
+ global_scale: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """GPU quantization to NVFP4 KV cache format with linear block scale layout.
+
+ Requires SM100+ (Blackwell) for the cvt.rn.satfinite.e2m1x2.f32 PTX instruction.
+
+ Args:
+ input (torch.Tensor): Input tensor of shape [M, K] with dtype bf16 or fp16.
+ K must be divisible by 16.
+ global_scale (torch.Tensor): Global scale factor of shape ``[1]`` with dtype float32,
+ on the same CUDA device as input.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]:
+ - fp4_output: Packed FP4 data of shape ``[M, K/2]`` with dtype uint8.
+ - block_scales: Per-block FP8 E4M3 scales of shape ``[M, K/16]`` with dtype uint8.
+ """
+ M, K = input.shape
--
+@flashinfer_api
+def mxfp8_quantize(
+ input: torch.Tensor,
+ is_sf_swizzled_layout: bool = True,
+ alignment: int = 32,
+ enable_pdl: Optional[bool] = None,
+ backend: Literal["cuda", "cute-dsl"] = "cuda",
+ sf_swizzle_layout: Optional[SfLayout] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Quantize input tensor to MxFP8 format.
+
+ This function implements MxFP8 quantization that converts input tensors to a compressed MxFP8 format
+ with associated scale factors. It supports various input data types and scale factor layouts.
+
+ Args:
+ input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized.
+ is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True.
+ alignment (int, optional): sfVecSize. Defaults to 32.
+ enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch).
+ If None, automatically detects based on device capability (SM >= 9.0). Defaults to None.
+ backend (Literal["cuda", "cute-dsl"], optional): Backend to use for quantization. Options are:
--
+@flashinfer_api
+def mxfp8_dequantize_host(
+ input: torch.Tensor,
+ scale_tensor: torch.Tensor,
+ is_sf_swizzled_layout: bool = True,
+ sf_swizzle_layout: Optional[SfLayout] = None,
+) -> torch.Tensor:
+ """Dequantize input tensor from MxFP8 format.
+
+ This function performs dequantization by converting a packed FP8 tensor in MxFP8 format
+ back to float values using the associated scale factors.
+
+ Args:
+ input (torch.Tensor): Packed FP8 tensor in MxFP8 format of shape [M, K] with dtype FLOAT8_E4M3.
+ scale_tensor (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size.
+ is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True.
+ sf_swizzle_layout (Optional[SfLayout], optional): Swizzle layout for scale factors.
+ If provided,it overrides is_sf_swizzled_layout. Defaults to None.
+ Available options are 1. SfLayout.layout_128x4; 2. SfLayout.layout_linear.
+
+ Returns:
--
+@flashinfer_api
+def mxfp4_quantize_cute_dsl(
+ input: torch.Tensor,
+ enable_pdl: bool | None = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Quantize input tensor to MXFP4 format using CuTe-DSL kernel.
+
+ This is a GPU implementation matching FlashInfer's mxfp4_quantize() behavior:
+ - Global scale computed as (448 * 6) / max(|input|)
+ - UE8M0 scale factors
+ - E2M1 output format (4-bit, 2 values per byte)
+ - Swizzled (128x4) scale factor layout
+
+ The kernel is compiled once per (K, dtype, pdl) combination and handles
+ varying M (batch size) at runtime without recompilation.
+
+ Args:
+ input: Input tensor of shape [M, K] with dtype fp16/bf16
+ enable_pdl: Whether to enable PDL (Programmatic Dependent Launch).
+ If None, automatically detects based on device capability (SM >= 9.0).
--
+@flashinfer_api
+def mxfp8_quantize_cute_dsl(
+ input: torch.Tensor,
+ is_sf_swizzled_layout: bool = True,
+ alignment: int = 32,
+ enable_pdl: bool | None = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Quantize input tensor to MXFP8 format using CuTe-DSL kernel.
+
+ This is a GPU implementation with dual-path optimization:
+ - LINEAR layout: SF-block based iteration (fast)
+ - SWIZZLED layout: Row-based iteration with padding fast path (optimized)
+
+ The kernel is compiled once per (K, dtype, pdl) combination and handles
+ varying M (batch size) at runtime without recompilation.
+
+ Args:
+ input: Input tensor of shape [M, K] with dtype fp16/bf16
+ is_sf_swizzled_layout: Whether to use 128x4 swizzled layout (True) or linear (False)
+ alignment: Alignment for K dimension (default 32, must be multiple of SF_VEC_SIZE)
```
<!-- This is an auto-generated comment: release notes by coderabbit.ai -->
## Summary by CodeRabbit
* **Enhancements**
* Normalization now accepts scale as either a float or tensor; passing a
float emits a deprecation warning and is auto-converted for
compatibility.
* Attention/decoding API: cache-scale parameters are now optional
keyword-only arguments with sensible defaults, simplifying common call
patterns.
* **Tests**
* Tests updated to match the adjusted attention/decoding call signature.
* **Chores**
* Release version bumped to 0.6.7.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
Parent commit: c0dbc38
5 files changed, with 19 additions and 13 deletions.