Skip to content

Commit 1de1b97

Browse files
authored
bump version to 0.6.7 & fix api breaking changes (#2832)
<!-- .github/pull_request_template.md --> ## πŸ“Œ Description fix api breaking changes for 0.6.7 release ## πŸ” Related Issues (Gated-by PRs) https://github.com/flashinfer-ai/flashinfer/issues?q=state%3Aopen%20label%3Av0.6.7 <!-- Link any related issues here --> ## πŸš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### βœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## πŸ§ͺ Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes **API changes review** API changes since v0.6.6 PR #2520 + commit e35c19e (fixed to be compatible) Function: xqa() Change: Added k_sf_cache=None, v_sf_cache=None as keyword-only params (after *). Backward-compatible. PR #2618 (has PR #2730 to fix it) Function: gated_delta_rule_mtp() Change: disable_state_update: bool = True β†’ Optional[bool] = None. Still defaults to True at runtime but emits a deprecation warning; will flip to False in 0.7.0. PR #2775 (expected β€” cute DSL MoE cleanup) Function: blockscaled_contiguous_grouped_gemm_nvfp4() Change: Entire @flashinfer_api decorated function deleted. Function: blockscaled_contiguous_grouped_gemm_swiglu_fusion_nvfp4() Change: Entire @flashinfer_api decorated function deleted. Function: blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4() Change: @flashinfer_api decorator removed; added enable_pdl: bool = True param. Function: blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4() Change: @flashinfer_api decorator removed; added enable_pdl: bool = True param. Function: CuteDslMoEWrapper.__init__() Change: Added enable_pdl: bool = True param. Backward-compatible. Function: cute_dsl_fused_moe_nvfp4() Change: Added enable_pdl: bool = True param. Backward-compatible. PR #2428 Function: rmsnorm_quant() Change: scale: float β†’ scale: Union[float, torch.Tensor]; return type torch.Tensor β†’ None. Function: fused_add_rmsnorm_quant() Change: scale: float β†’ scale: Union[float, torch.Tensor]. Quantization functions (relocated, not removed) All quantization APIs (fp4_quantize, block_scale_interleave, e2m1_and_ufp8sf_scale_to_float, shuffle_matrix_a, shuffle_matrix_sf_a, nvfp4_quantize, nvfp4_batched_quantize, scaled_fp4_grouped_quantize, mxfp4_quantize, mxfp4_dequantize, mxfp4_dequantize_host, mxfp8_quantize, mxfp8_dequantize_host) were moved from flashinfer/fp4_quantization.py and flashinfer/fp8_quantization.py to flashinfer/quantization/. Signatures, @flashinfer_api decorators, and __init__.py exports are preserved. No breakage. ```diff $ git diff v0.6.6 | grep -A20 "@flashinfer_api" @flashinfer_api @@ -1215,6 +1227,9 @@ class BatchDecodeWithPagedKVCacheWrapper: sinks: Optional[torch.Tensor] = None, q_len_per_req: Optional[int] = 1, skip_softmax_threshold_scale_factor: Optional[float] = None, + kv_block_scales: Optional[ + Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] + ] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: r"""Compute batch decode attention between query and paged kv cache. @@ -1273,6 +1288,15 @@ class BatchDecodeWithPagedKVCacheWrapper: enable_pdl = device_support_pdl(q.device) k_cache, v_cache = _unpack_paged_kv_cache(paged_kv_cache, self._kv_layout) + # Unpack kv_block_scales + key_block_scales = None + value_block_scales = None + if kv_block_scales is not None: + if isinstance(kv_block_scales, tuple): + key_block_scales, value_block_scales = kv_block_scales -- -@flashinfer_api -def fp4_quantize( - input: torch.Tensor, - global_scale: Optional[torch.Tensor] = None, - sf_vec_size: int = 16, - sf_use_ue8m0: bool = False, - is_sf_swizzled_layout: bool = True, - is_sf_8x4_layout: bool = False, - enable_pdl: Optional[bool] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: - """Quantize input tensor to FP4 format. - - This function implements FP4 quantization that converts input tensors to a compressed FP4 format - with associated scale factors. It supports various input data types and scale factor layouts. - - Args: - input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized. - global_scale (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32. - sf_vec_size (int, optional): Scale factor vector size. Defaults to 16. - sf_use_ue8m0 (bool, optional): Whether to use UE8M0 format for scale factors. Defaults to False. - is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True. -- -@flashinfer_api -def block_scale_interleave(unswizzled_sf: torch.Tensor) -> torch.Tensor: - """Swizzle block scale tensor for FP4 format. - - This function swizzles the block scale tensor to optimize memory access patterns - for FP4 operations. The output needs to be padded in the m dimension to be a multiple of 128. - - Args: - unswizzled_sf (torch.Tensor): Input tensor with dtype uint8 or bfloat16. - - Returns: - torch.Tensor: Swizzled tensor with the same shape as input. - - Raises: - AssertionError: If input dtype is not uint8 or bfloat16. - """ - # TODO(shuw): check input dtype is uint8 - assert ( - unswizzled_sf.dtype == torch.uint8 or unswizzled_sf.dtype == torch.bfloat16 - ), f"Input dtype must be uint8 or bfloat16, got {unswizzled_sf.dtype}" - -- -@flashinfer_api -def e2m1_and_ufp8sf_scale_to_float( - e2m1_tensor: torch.Tensor, - ufp8_scale_tensor: torch.Tensor, - global_scale_tensor: Optional[torch.Tensor] = None, - sf_vec_size: int = 16, - ufp8_type: int = 1, - is_sf_swizzled_layout: bool = True, -) -> torch.Tensor: - """Convert E2M1 format tensor and UFP8 scale factors to float tensor. - - This function performs dequantization by converting a packed FP4 tensor in E2M1 format - back to float values using the associated UFP8 scale factors and global scale. - - Args: - e2m1_tensor (torch.Tensor): Packed FP4 tensor in E2M1 format of shape [M, K/2] with dtype uint8. - ufp8_scale_tensor (torch.Tensor): Scale factors tensor in UFP8 format with dtype uint8. - global_scale_tensor (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32. - sf_vec_size (int, optional): Scale factor vector size. Defaults to 16. - ufp8_type (int, optional): UFP8 scale factor type (0 for UE8M0, 1 for E4M3). Defaults to 1. - is_sf_swizzled_layout (bool, optional): Whether scale factors use swizzled layout. Defaults to True. -- -@flashinfer_api -def shuffle_matrix_a(input_tensor: torch.Tensor, epilogue_tile_m: int) -> torch.Tensor: - """ - PyTorch equivalent of trtllm-gen `shuffleMatrixA` - """ - row_indices = get_shuffle_matrix_a_row_indices(input_tensor, epilogue_tile_m) - - return input_tensor[row_indices.to(input_tensor.device)] - - -@flashinfer_api -def shuffle_matrix_sf_a( - input_tensor: torch.Tensor, - epilogue_tile_m: int, - num_elts_per_sf: int = 16, -): - """ - Cuda implementation of trtllm-gen `shuffleMatrixSfA` but with a caveat. - `shuffleMatrixSfA` expects the input to be in 128x4 layout and then - apply the same shuffling in `shuffleMatrixA` and writes out in 128x4 - layout. - This function expects the input to be in linear layout. It's done this - way because the scaling factors in the NVFP4 checkpoints are quantized - and are in linear layout. - This function doesn't add padding. - """ - - row_indices = get_shuffle_matrix_sf_a_row_indices(input_tensor, epilogue_tile_m) - - w_shuffled = input_tensor[row_indices.to(input_tensor.device)] - -- -@flashinfer_api -def nvfp4_quantize( - a, - a_global_sf, - sfLayout=SfLayout.layout_128x4, - do_shuffle=False, - sf_vec_size=16, - enable_pdl=None, -): - """ - Quantize input tensor to NVFP4 format. - - Parameters: - a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16. - a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32. - sfLayout (SfLayout, optional): Scale factor layout. Defaults to SfLayout.layout_128x4. - do_shuffle (bool, optional): Whether to shuffle the scale factors. Defaults to False. Only TRTLLM backend needs to shuffle the tensor B scale factors. - sf_vec_size (int, optional): Scale factor vector size. Defaults to 16. - enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch). - If None, automatically detects based on device capability. Defaults to None. - -- -@flashinfer_api -def mxfp4_quantize(a): - """ - Quantize input tensor to MXFP4 format. - - Parameters: - a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2) - - Scale factors tensor with shape determined by layout and sf_vec_size (uint8) - """ - a_global_sf = (448 * 6) / a.float().abs().nan_to_num().max() - a_fp4, a_sf = fp4_quantize(a.cuda(), a_global_sf.cuda(), 32, True, True) - return a_fp4, a_sf - - -@flashinfer_api -def mxfp4_dequantize(a_fp4, a_sf): - """ - Dequantize input tensor from MXFP4 format. - - Parameters: - a_fp4 (torch.Tensor): Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2) - a_sf (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size (uint8) - - Returns: - torch.Tensor: Dequantized tensor of shape [M, K] with dtype float. - """ - return e2m1_and_ufp8sf_scale_to_float( - a_fp4.cpu().view(torch.uint8), - a_sf.cpu().view(torch.uint8).reshape(-1), - torch.tensor([1.0], device=a_fp4.device), - 32, - 0, - True, - ) - -- -@flashinfer_api -def mxfp4_dequantize_host( - weight: torch.Tensor, - scale: torch.Tensor, - group_size: int = 32, -) -> torch.Tensor: - """ - Dequantize input tensor from MXFP4 format on host. - - Parameters: - weight (torch.Tensor): Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2) - scale (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size (uint8) - group_size (int, optional): Group size for dequantization. Defaults to 32. - - Returns: - torch.Tensor: Dequantized tensor of shape [M, K] with dtype float. - """ - # NOTE(Zihao): the cpu op should be decouplied from cuda ops because it's device independent, should refactor this in the future - major, minor = get_compute_capability( - torch.device("cuda:0") - ) # use any cuda device to get a compute capability -- -@flashinfer_api -def nvfp4_batched_quantize( - a, - a_global_sf, - sf_vec_size=16, -): - """ - Quantize batched input tensor to NVFP4 format. - - Parameters: - a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16. - a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32. - sf_vec_size (int, optional): Scale factor vector size. Defaults to 16. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2 - - Scale factors tensor with shape determined by layout and sf_vec_size - """ - major, minor = get_compute_capability(a.device) - device_arch = f"{major * 10 + minor}" -- -@flashinfer_api -def scaled_fp4_grouped_quantize( - a, - mask, - a_global_sf, -): - """ - quantize batched input tensor to NVFP4 format with mask. - Parameters: - a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16. - a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32. - mask (torch.Tensor): Mask tensor to apply before quantization. - Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2 - - Scale factors tensor with shape determined by layout and sf_vec_size - """ - major, minor = get_compute_capability(a.device) - device_arch = f"{major * 10 + minor}" - a_fp4, a_sf = get_fp4_quantization_module( - device_arch -- -@flashinfer_api -def mxfp8_quantize( - input: torch.Tensor, - is_sf_swizzled_layout: bool = True, - alignment: int = 32, - enable_pdl: Optional[bool] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: - """Quantize input tensor to MxFP8 format. - - This function implements MxFP8 quantization that converts input tensors to a compressed MxFP8 format - with associated scale factors. It supports various input data types and scale factor layouts. - - Args: - input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized. - is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True. - alignment (int, optional): sfVecSize. Defaults to 32. - enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch). - If None, automatically detects based on device capability. Defaults to None. - Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - Quantized tensor of shape [M, K] with dtype FLOAT8_E4M3 -- -@flashinfer_api -def mxfp8_dequantize_host( - input: torch.Tensor, - scale_tensor: torch.Tensor, - is_sf_swizzled_layout: bool = True, -) -> torch.Tensor: - """Dequantize input tensor from MxFP8 format. - - This function performs dequantization by converting a packed FP8 tensor in MxFP8 format - back to float values using the associated scale factors. - - Args: - input (torch.Tensor): Packed FP8 tensor in MxFP8 format of shape [M, K] with dtype FLOAT8_E4M3. - scale_tensor (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size. - is_sf_swizzled_layout (bool, optional): Whether scale factors use swizzled layout. Defaults to True. - - Returns: - torch.Tensor: Dequantized float tensor of shape [M, K] with dtype float32. - - """ - -- -@flashinfer_api def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4( a: torch.Tensor, b: torch.Tensor, @@ -323,6 +324,7 @@ def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4( vectorized_f32: bool = True, raster_along_m: bool = False, sm_count: Optional[int] = None, + enable_pdl: bool = True, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Blockscaled Contiguous Gather Grouped GEMM with SwiGLU Fusion for MoE workloads. @@ -423,7 +425,7 @@ def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4( major, minor = get_compute_capability(a.device) if major != 10: raise ValueError( - f"Blockscaled contiguous gather grouped GEMM with SwiGLU requires SM100 family (Blackwell: SM100, SM103, SM110). " + f"Blockscaled contiguous gather grouped GEMM with SwiGLU requires SM100 family (Blackwell: SM100, SM103). " f"Got SM{major}{minor}." ) -- -@flashinfer_api -def blockscaled_contiguous_grouped_gemm_nvfp4( - a: torch.Tensor, - b: torch.Tensor, - a_scale: torch.Tensor, - b_scale: torch.Tensor, - alpha: torch.Tensor, - tile_idx_to_group_idx: torch.Tensor, - num_non_exiting_tiles: torch.Tensor, - out: Optional[torch.Tensor] = None, - *, - ab_dtype: str = "float4_e2m1fn", - sf_dtype: str = "float8_e4m3fn", - c_dtype: str = "bfloat16", - sf_vec_size: int = 16, - mma_tiler_mn: Tuple[int, int] = (128, 128), - cluster_shape_mn: Tuple[int, int] = (1, 1), - sm_count: Optional[int] = None, -) -> torch.Tensor: - """Blockscaled Contiguous Grouped GEMM for MoE workloads with NVFP4 quantization. - -- -@flashinfer_api def blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4( a: torch.Tensor, b: torch.Tensor, @@ -272,6 +279,7 @@ def blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4( cluster_shape_mn: Tuple[int, int] = (2, 1), raster_along_m: bool = False, sm_count: Optional[int] = None, + enable_pdl: bool = True, ) -> torch.Tensor: """Blockscaled Contiguous Grouped GEMM with Finalize Fusion for MoE workloads. @@ -298,7 +306,11 @@ def blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4( expanded_idx = token_idx * topk + topk_idx. Invalid rows have -1. token_final_scales: Router scaling factors, shape (seq_len, topk), float32/bf16/fp16 out: Optional output tensor, shape (seq_len, n). Created if None. - This tensor is used for atomic accumulation, so it should be zero-initialized. + This tensor is used for atomic accumulation. If `out` is + provided, it must already be zero-initialized by the caller. + If `out` is None, this function allocates a zero-initialized + output tensor. Passing a non-zeroed `out` buffer will silently -- -@flashinfer_api -def blockscaled_contiguous_grouped_gemm_swiglu_fusion_nvfp4( - a: torch.Tensor, - b: torch.Tensor, - a_scale: torch.Tensor, - b_scale: torch.Tensor, - alpha: torch.Tensor, - tile_idx_to_group_idx: torch.Tensor, - num_non_exiting_tiles: torch.Tensor, - out: Optional[torch.Tensor] = None, - out_scale: Optional[torch.Tensor] = None, - global_scale: Optional[torch.Tensor] = None, - *, - ab_dtype: str = "float4_e2m1fn", - sf_dtype: str = "float8_e4m3fn", - c_dtype: str = "bfloat16", - sf_vec_size: int = 16, - mma_tiler_mn: Tuple[int, int] = (256, 128), - cluster_shape_mn: Tuple[int, int] = (2, 1), - vectorized_f32: bool = True, - sm_count: Optional[int] = None, -- @flashinfer_api def __init__( self, @@ -347,6 +355,7 @@ class CuteDslMoEWrapper: sf_vec_size: int = 16, output_dtype: torch.dtype = torch.bfloat16, device: str = "cuda", + enable_pdl: bool = True, ): """Initialize the MoE wrapper. @@ -363,6 +372,7 @@ class CuteDslMoEWrapper: sf_vec_size: Scale factor vector size. Default: 16. output_dtype: Output data type. Default: torch.bfloat16. device: Device for buffer allocation. Default: "cuda". + enable_pdl: Enable Programmatic Dependent Launch. Default: True. """ self.num_experts = num_experts self.top_k = top_k @@ -376,6 +386,7 @@ class CuteDslMoEWrapper: self.sf_vec_size = sf_vec_size -- @flashinfer_api @@ -550,9 +570,10 @@ class CuteDslMoEWrapper: f"num_tokens ({num_tokens}) exceeds max_num_tokens ({self.max_num_tokens})" ) - # Allocate output buffer if not using pre-allocated one + # Slice the pre-allocated buffer to the active batch so that + # _moe_core_impl only zeros num_tokens rows, not max_num_tokens. if self.use_cuda_graph: - moe_output = self._moe_output + moe_output = self._moe_output[:num_tokens] else: moe_output = torch.empty( (num_tokens, self.hidden_size), @@ -627,6 +648,7 @@ def _cute_dsl_fused_moe_nvfp4_impl( use_fused_finalize: bool = True, moe_output: Optional[torch.Tensor] = None, aux_stream: Optional[torch.cuda.Stream] = None, + enable_pdl: bool = True, ) -> torch.Tensor: """Internal implementation called by auto-tuner for functional API.""" -- @flashinfer_api def cute_dsl_fused_moe_nvfp4( x: torch.Tensor, @@ -678,9 +702,12 @@ def cute_dsl_fused_moe_nvfp4( use_fused_finalize: bool = True, moe_output: Optional[torch.Tensor] = None, aux_stream: Optional[torch.cuda.Stream] = None, + enable_pdl: bool = True, ) -> torch.Tensor: """Run fused MoE computation using CuteDSL NVFP4 kernels. + Supported architectures: SM100, SM103. + This is the simple functional API. For CUDA graph support, use `CuteDslMoEWrapper` instead. @@ -736,6 +763,7 @@ def cute_dsl_fused_moe_nvfp4( local_expert_offset=local_expert_offset, use_fused_finalize=use_fused_finalize, output_dtype=output_dtype, + enable_pdl=enable_pdl, -- @flashinfer_api def gated_delta_rule_decode_pretranspose( q: torch.Tensor, @@ -1002,8 +174,9 @@ def gated_delta_rule_decode_pretranspose( - State layout is v-major (K-last): [B, HV, V, K]. When state is bfloat16 and T in 1..4 with K=V=128, the gdn_decode_klast_bf16_state kernel is used (supports both the direct ``state`` path and the pool+indices path). - - pool+indices (``initial_state``/``initial_state_indices``) only supported - via the bf16 fast path; float32 state raises an error. + - pool+indices (``initial_state``/``initial_state_indices``) supported on + both the bf16 fast path (T in 1..4, K=V=128) and the float32 legacy path + (T=1). The float32 path also supports negative indices for padding. - Legacy path (float32 state, T=1): K and V must be multiples of 4. """ # Validate input shapes @@ -1069,13 +242,17 @@ def gated_delta_rule_decode_pretranspose( return_state = initial_state if use_pool else state return output, return_state - # Legacy path: T=1 only, float32 state (no pool+indices support) - assert not use_pool, ( -- @flashinfer_api def gated_delta_rule_mtp( q: torch.Tensor, @@ -2427,7 +489,7 @@ def gated_delta_rule_mtp( scale: Optional[float] = None, output: Optional[torch.Tensor] = None, intermediate_states_buffer: Optional[torch.Tensor] = None, - disable_state_update: bool = True, + disable_state_update: Optional[bool] = None, use_qk_l2norm: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor]: """ @@ -2463,8 +525,15 @@ def gated_delta_rule_mtp( intermediate_states_buffer (Optional[torch.Tensor]): Buffer for caching intermediate states, shape ``[pool_size, T, HV, V, K]``. If None, intermediate states are not cached. - disable_state_update (bool): - If True, the initial state is not updated. Default: ``True``. + disable_state_update (Optional[bool]): + If True, the initial state is not updated. Currently defaults to ``True``. + Please pass this argument explicitly β€” the default will change to ``False`` -- @flashinfer_api @@ -60,16 +120,14 @@ def rmsnorm( output: torch.Tensor Normalized tensor, 2D shape (batch_size, hidden_size) or 3D shape (batch_size, num_heads, hidden_size). """ - if enable_pdl is None: - enable_pdl = device_support_pdl(input.device) if out is None: out = torch.empty_like(input) - _rmsnorm(out, input, weight, eps, enable_pdl) + _rmsnorm_impl(out, input, weight, eps, enable_pdl) return out @register_custom_op("flashinfer::rmsnorm", mutates_args=("out",)) -def _rmsnorm( +def _rmsnorm_impl( out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, @@ -78,11 +136,21 @@ def _rmsnorm( -- @flashinfer_api def fmha_v2_prefill_deepseek( query: torch.Tensor, @@ -3865,18 +4029,11 @@ def fmha_v2_prefill_deepseek( If return_lse is False, the output will be a single tensor. """ if not is_sm12x_supported(query.device): - major, minor = get_compute_capability(query.device) - if major == 12: - min_cuda = "13.0" if minor >= 1 else "12.8" - raise ValueError( - f"fmha_v2_prefill_deepseek requires CUDA >= {min_cuda} " - f"for SM12{minor}x GPUs." - ) raise ValueError("fmha_v2_prefill_deepseek is only supported on SM12x GPUs.") assert query.shape[3] == 192 and key.shape[3] == 192 and value.shape[3] == 128, ( "currently only support deepseek r1 192 query and 128 value" ) - module = get_trtllm_fmha_v2_module() + module = get_trtllm_fmha_v2_sm120_module() is_e4m3 = query.dtype == torch.float8_e4m3fn -- +@flashinfer_api +def trtllm_fmha_v2_prefill( + qkv: Union[ + torch.Tensor, + Tuple[torch.Tensor, torch.Tensor], + Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + ], + input_layout: str, + workspace_buffer: torch.Tensor, + seq_lens: torch.Tensor, + max_q_len: int, + max_kv_len: int, + bmm1_scale: float, + bmm2_scale: float, + batch_size: int, + cum_seq_lens_q: torch.Tensor, + cum_seq_lens_kv: torch.Tensor, + block_tables: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + out_dtype: Optional[Union[torch.dtype, str]] = None, + sinks: Optional[List[torch.Tensor]] = None, -- +@flashinfer_api +def fp4_quantize( + input: torch.Tensor, + global_scale: Optional[torch.Tensor] = None, + sf_vec_size: int = 16, + sf_use_ue8m0: bool = False, + is_sf_swizzled_layout: bool = True, + is_sf_8x4_layout: bool = False, + enable_pdl: Optional[bool] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Quantize input tensor to FP4 format. + + This function implements FP4 quantization that converts input tensors to a compressed FP4 format + with associated scale factors. It supports various input data types and scale factor layouts. + + Args: + input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized. + global_scale (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32. + sf_vec_size (int, optional): Scale factor vector size. Defaults to 16. + sf_use_ue8m0 (bool, optional): Whether to use UE8M0 format for scale factors. Defaults to False. + is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True. -- +@flashinfer_api +def block_scale_interleave(unswizzled_sf: torch.Tensor) -> torch.Tensor: + """Swizzle block scale tensor for FP4 format. + + This function swizzles the block scale tensor to optimize memory access patterns + for FP4 operations. The output needs to be padded in the m dimension to be a multiple of 128. + + Args: + unswizzled_sf (torch.Tensor): Input tensor with dtype uint8 or bfloat16. + + Returns: + torch.Tensor: Swizzled tensor with the same shape as input. + + Raises: + AssertionError: If input dtype is not uint8 or bfloat16. + """ + # TODO(shuw): check input dtype is uint8 + assert ( + unswizzled_sf.dtype == torch.uint8 or unswizzled_sf.dtype == torch.bfloat16 + ), f"Input dtype must be uint8 or bfloat16, got {unswizzled_sf.dtype}" + -- +@flashinfer_api +def e2m1_and_ufp8sf_scale_to_float( + e2m1_tensor: torch.Tensor, + ufp8_scale_tensor: torch.Tensor, + global_scale_tensor: Optional[torch.Tensor] = None, + sf_vec_size: int = 16, + ufp8_type: int = 1, + is_sf_swizzled_layout: bool = True, +) -> torch.Tensor: + """Convert E2M1 format tensor and UFP8 scale factors to float tensor. + + This function performs dequantization by converting a packed FP4 tensor in E2M1 format + back to float values using the associated UFP8 scale factors and global scale. + + Args: + e2m1_tensor (torch.Tensor): Packed FP4 tensor in E2M1 format of shape [M, K/2] with dtype uint8. + ufp8_scale_tensor (torch.Tensor): Scale factors tensor in UFP8 format with dtype uint8. + global_scale_tensor (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32. + sf_vec_size (int, optional): Scale factor vector size. Defaults to 16. + ufp8_type (int, optional): UFP8 scale factor type (0 for UE8M0, 1 for E4M3). Defaults to 1. + is_sf_swizzled_layout (bool, optional): Whether scale factors use swizzled layout. Defaults to True. -- +@flashinfer_api +def shuffle_matrix_a(input_tensor: torch.Tensor, epilogue_tile_m: int) -> torch.Tensor: + """ + PyTorch equivalent of trtllm-gen `shuffleMatrixA` + """ + row_indices = get_shuffle_matrix_a_row_indices(input_tensor, epilogue_tile_m) + + return input_tensor[row_indices.to(input_tensor.device)] + + +@flashinfer_api +def shuffle_matrix_sf_a( + input_tensor: torch.Tensor, + epilogue_tile_m: int, + num_elts_per_sf: int = 16, +): + """ + Cuda implementation of trtllm-gen `shuffleMatrixSfA` but with a caveat. + `shuffleMatrixSfA` expects the input to be in 128x4 layout and then + apply the same shuffling in `shuffleMatrixA` and writes out in 128x4 + layout. + This function expects the input to be in linear layout. It's done this + way because the scaling factors in the NVFP4 checkpoints are quantized + and are in linear layout. + This function doesn't add padding. + """ + + row_indices = get_shuffle_matrix_sf_a_row_indices(input_tensor, epilogue_tile_m) + + w_shuffled = input_tensor[row_indices.to(input_tensor.device)] + -- +@flashinfer_api +def nvfp4_quantize( + a, + a_global_sf, + sfLayout=SfLayout.layout_128x4, + do_shuffle=False, + sf_vec_size=16, + enable_pdl=None, +): + """ + Quantize input tensor to NVFP4 format. + + Parameters: + a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16. + a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32. + sfLayout (SfLayout, optional): Scale factor layout. Defaults to SfLayout.layout_128x4. + do_shuffle (bool, optional): Whether to shuffle the scale factors. Defaults to False. Only TRTLLM backend needs to shuffle the tensor B scale factors. + sf_vec_size (int, optional): Scale factor vector size. Defaults to 16. + enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch). + If None, automatically detects based on device capability. Defaults to None. + -- +@flashinfer_api +def mxfp4_quantize( + a: torch.Tensor, + backend: str = "cuda", + enable_pdl: Optional[bool] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantize input tensor to MXFP4 format. + + Parameters: + a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16. + backend (str, optional): Backend to use for quantization. + - "cuda": Use CUDA kernel (default, stable) + - "cute-dsl": Use CuTe-DSL kernel (requires SM100+, **experimental**) + enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic + Dependent Launch). Only used when backend="cute-dsl". + If None, automatically detects based on device capability. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2) -- +@flashinfer_api +def mxfp4_dequantize(a_fp4, a_sf): + """ + Dequantize input tensor from MXFP4 format. + + Parameters: + a_fp4 (torch.Tensor): Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2) + a_sf (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size (uint8) + + Returns: + torch.Tensor: Dequantized tensor of shape [M, K] with dtype float. + """ + return e2m1_and_ufp8sf_scale_to_float( + a_fp4.cpu().view(torch.uint8), + a_sf.cpu().view(torch.uint8).reshape(-1), + torch.tensor([1.0], device=a_fp4.device), + 32, + 0, + True, + ) + -- +@flashinfer_api +def mxfp4_dequantize_host( + weight: torch.Tensor, + scale: torch.Tensor, + group_size: int = 32, +) -> torch.Tensor: + """ + Dequantize input tensor from MXFP4 format on host. + + Parameters: + weight (torch.Tensor): Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2) + scale (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size (uint8) + group_size (int, optional): Group size for dequantization. Defaults to 32. + + Returns: + torch.Tensor: Dequantized tensor of shape [M, K] with dtype float. + """ + # NOTE(Zihao): the cpu op should be decouplied from cuda ops because it's device independent, should refactor this in the future + major, minor = get_compute_capability( + torch.device("cuda:0") + ) # use any cuda device to get a compute capability -- +@flashinfer_api +def nvfp4_batched_quantize( + a, + a_global_sf, + sf_vec_size=16, +): + """ + Quantize batched input tensor to NVFP4 format. + + Parameters: + a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16. + a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32. + sf_vec_size (int, optional): Scale factor vector size. Defaults to 16. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2 + - Scale factors tensor with shape determined by layout and sf_vec_size + """ + major, minor = get_compute_capability(a.device) + device_arch = f"{major * 10 + minor}" -- +@flashinfer_api +def nvfp4_quantize_paged_kv_cache( + k_cache: torch.Tensor, + v_cache: torch.Tensor, + kv_layout: str = "HND", + k_global_sf: Optional[torch.Tensor] = None, + v_global_sf: Optional[torch.Tensor] = None, +) -> Tuple[ + Tuple[torch.Tensor, torch.Tensor], + Tuple[torch.Tensor, torch.Tensor], + float, + float, +]: + """Quantize paged KV cache to NVFP4 format for trtllm-gen MHA. + + Quantizes BF16/FP16 K/V caches to NVFP4 with two-level scaling + (global FP32 + per-block FP8), and swizzles scale factors + for the SM100 trtllm-gen MHA kernel layout. + + Args: + k_cache: Key cache tensor. -- +@flashinfer_api +def scaled_fp4_grouped_quantize( + a, + mask, + a_global_sf, +): + """ + quantize batched input tensor to NVFP4 format with mask. + Parameters: + a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16. + a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32. + mask (torch.Tensor): Mask tensor to apply before quantization. + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2 + - Scale factors tensor with shape determined by layout and sf_vec_size + """ + major, minor = get_compute_capability(a.device) + device_arch = f"{major * 10 + minor}" + a_fp4, a_sf = get_fp4_quantization_module( + device_arch -- +@flashinfer_api +def nvfp4_kv_dequantize( + fp4_data: torch.Tensor, + block_scales: torch.Tensor, + global_scale: torch.Tensor, + output_dtype: torch.dtype = torch.bfloat16, +) -> torch.Tensor: + """GPU dequantization of NVFP4 KV cache data with linear block scale layout. + + Requires SM80+. + + Args: + fp4_data (torch.Tensor): Packed FP4 data of shape ``[M, K/2]`` with dtype uint8. + block_scales (torch.Tensor): Per-block FP8 E4M3 scales of shape ``[M, K/16]`` + with dtype uint8. + global_scale (torch.Tensor): Global scale factor of shape ``[1]`` with dtype float32, + on the same CUDA device as fp4_data. + output_dtype (torch.dtype): Output dtype, either ``torch.bfloat16`` or ``torch.float16``. + + Returns: + torch.Tensor: Dequantized tensor of shape ``[M, K]`` with the specified output dtype. -- +@flashinfer_api +def nvfp4_kv_quantize( + input: torch.Tensor, + global_scale: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """GPU quantization to NVFP4 KV cache format with linear block scale layout. + + Requires SM100+ (Blackwell) for the cvt.rn.satfinite.e2m1x2.f32 PTX instruction. + + Args: + input (torch.Tensor): Input tensor of shape [M, K] with dtype bf16 or fp16. + K must be divisible by 16. + global_scale (torch.Tensor): Global scale factor of shape ``[1]`` with dtype float32, + on the same CUDA device as input. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - fp4_output: Packed FP4 data of shape ``[M, K/2]`` with dtype uint8. + - block_scales: Per-block FP8 E4M3 scales of shape ``[M, K/16]`` with dtype uint8. + """ + M, K = input.shape -- +@flashinfer_api +def mxfp8_quantize( + input: torch.Tensor, + is_sf_swizzled_layout: bool = True, + alignment: int = 32, + enable_pdl: Optional[bool] = None, + backend: Literal["cuda", "cute-dsl"] = "cuda", + sf_swizzle_layout: Optional[SfLayout] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Quantize input tensor to MxFP8 format. + + This function implements MxFP8 quantization that converts input tensors to a compressed MxFP8 format + with associated scale factors. It supports various input data types and scale factor layouts. + + Args: + input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized. + is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True. + alignment (int, optional): sfVecSize. Defaults to 32. + enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch). + If None, automatically detects based on device capability (SM >= 9.0). Defaults to None. + backend (Literal["cuda", "cute-dsl"], optional): Backend to use for quantization. Options are: -- +@flashinfer_api +def mxfp8_dequantize_host( + input: torch.Tensor, + scale_tensor: torch.Tensor, + is_sf_swizzled_layout: bool = True, + sf_swizzle_layout: Optional[SfLayout] = None, +) -> torch.Tensor: + """Dequantize input tensor from MxFP8 format. + + This function performs dequantization by converting a packed FP8 tensor in MxFP8 format + back to float values using the associated scale factors. + + Args: + input (torch.Tensor): Packed FP8 tensor in MxFP8 format of shape [M, K] with dtype FLOAT8_E4M3. + scale_tensor (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size. + is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True. + sf_swizzle_layout (Optional[SfLayout], optional): Swizzle layout for scale factors. + If provided,it overrides is_sf_swizzled_layout. Defaults to None. + Available options are 1. SfLayout.layout_128x4; 2. SfLayout.layout_linear. + + Returns: -- +@flashinfer_api +def mxfp4_quantize_cute_dsl( + input: torch.Tensor, + enable_pdl: bool | None = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantize input tensor to MXFP4 format using CuTe-DSL kernel. + + This is a GPU implementation matching FlashInfer's mxfp4_quantize() behavior: + - Global scale computed as (448 * 6) / max(|input|) + - UE8M0 scale factors + - E2M1 output format (4-bit, 2 values per byte) + - Swizzled (128x4) scale factor layout + + The kernel is compiled once per (K, dtype, pdl) combination and handles + varying M (batch size) at runtime without recompilation. + + Args: + input: Input tensor of shape [M, K] with dtype fp16/bf16 + enable_pdl: Whether to enable PDL (Programmatic Dependent Launch). + If None, automatically detects based on device capability (SM >= 9.0). -- +@flashinfer_api +def mxfp8_quantize_cute_dsl( + input: torch.Tensor, + is_sf_swizzled_layout: bool = True, + alignment: int = 32, + enable_pdl: bool | None = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantize input tensor to MXFP8 format using CuTe-DSL kernel. + + This is a GPU implementation with dual-path optimization: + - LINEAR layout: SF-block based iteration (fast) + - SWIZZLED layout: Row-based iteration with padding fast path (optimized) + + The kernel is compiled once per (K, dtype, pdl) combination and handles + varying M (batch size) at runtime without recompilation. + + Args: + input: Input tensor of shape [M, K] with dtype fp16/bf16 + is_sf_swizzled_layout: Whether to use 128x4 swizzled layout (True) or linear (False) + alignment: Alignment for K dimension (default 32, must be multiple of SF_VEC_SIZE) ``` <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Enhancements** * Normalization now accepts scale as either a float or tensor; passing a float emits a deprecation warning and is auto-converted for compatibility. * Attention/decoding API: cache-scale parameters are now optional keyword-only arguments with sensible defaults, simplifying common call patterns. * **Tests** * Tests updated to match the adjusted attention/decoding call signature. * **Chores** * Release version bumped to 0.6.7. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent c0dbc38 commit 1de1b97

5 files changed

Lines changed: 19 additions & 13 deletions

File tree

β€Žflashinfer/decode.pyβ€Ž

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2715,15 +2715,15 @@ def xqa_batch_decode_with_kv_cache(
27152715
query_new,
27162716
k_cache,
27172717
v_cache,
2718-
k_cache_sf,
2719-
v_cache_sf,
27202718
block_tables,
27212719
seq_lens_new,
27222720
out_4d,
27232721
scratch,
27242722
semaphore,
27252723
num_kv_heads,
27262724
page_size,
2725+
k_sf_cache=k_cache_sf,
2726+
v_sf_cache=v_cache_sf,
27272727
sinks=sinks_new,
27282728
q_scale=q_scale_value,
27292729
kv_scale=kv_scale_value,

β€Žflashinfer/norm/__init__.pyβ€Ž

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626

2727
import functools
2828
import os
29-
from typing import Optional
29+
import warnings
30+
from typing import Optional, Union
3031

3132
import torch
3233

@@ -62,11 +63,17 @@ def get_norm_module():
6263

6364

6465
def _normalize_scale_tensor(
65-
scale: torch.Tensor, ref_tensor: torch.Tensor
66+
scale: Union[float, torch.Tensor], ref_tensor: torch.Tensor
6667
) -> torch.Tensor:
67-
"""Normalize quantization scale tensor to 1D shape (1,) on target device."""
68+
"""Normalize quantization scale to 1D tensor of shape (1,) on target device."""
6869
if not isinstance(scale, torch.Tensor):
69-
raise TypeError(f"scale must be torch.Tensor, got {type(scale)}")
70+
warnings.warn(
71+
"Passing scale as a float is deprecated and will be removed in a future "
72+
"release. Use a torch.Tensor of shape (1,) instead.",
73+
FutureWarning,
74+
stacklevel=3,
75+
)
76+
scale = torch.tensor([scale], dtype=torch.float32, device=ref_tensor.device)
7077
if scale.device != ref_tensor.device:
7178
scale = scale.to(ref_tensor.device)
7279
if scale.dtype != torch.float32:
@@ -159,7 +166,7 @@ def rmsnorm_quant(
159166
out: torch.Tensor,
160167
input: torch.Tensor,
161168
weight: torch.Tensor,
162-
scale: torch.Tensor,
169+
scale: Union[float, torch.Tensor],
163170
eps: float = 1e-6,
164171
enable_pdl: Optional[bool] = None,
165172
) -> None:
@@ -268,7 +275,7 @@ def fused_add_rmsnorm_quant(
268275
input: torch.Tensor,
269276
residual: torch.Tensor,
270277
weight: torch.Tensor,
271-
scale: torch.Tensor,
278+
scale: Union[float, torch.Tensor],
272279
eps: float = 1e-6,
273280
enable_pdl: Optional[bool] = None,
274281
) -> None:

β€Žflashinfer/xqa.pyβ€Ž

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,6 @@ def xqa(
155155
q: torch.Tensor,
156156
k_cache: torch.Tensor,
157157
v_cache: torch.Tensor,
158-
k_sf_cache: Optional[torch.Tensor],
159-
v_sf_cache: Optional[torch.Tensor],
160158
page_table: torch.Tensor,
161159
seq_lens: torch.Tensor,
162160
output: torch.Tensor,
@@ -174,6 +172,9 @@ def xqa(
174172
rcp_out_scale: float = 1.0,
175173
q_seq_len: int = 1,
176174
mask: Optional[torch.Tensor] = None,
175+
*,
176+
k_sf_cache: Optional[torch.Tensor] = None,
177+
v_sf_cache: Optional[torch.Tensor] = None,
177178
) -> None:
178179
r"""Apply attention with paged KV cache using XQA kernel.
179180
Parameters

β€Žtests/attention/test_xqa.pyβ€Ž

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -346,8 +346,6 @@ def test_xqa(
346346
q_heads,
347347
cache_k_heads.to(torch.float8_e4m3fn) if fp8_kv_cache else cache_k_heads,
348348
cache_v_heads.to(torch.float8_e4m3fn) if fp8_kv_cache else cache_v_heads,
349-
None,
350-
None,
351349
page_list_arg,
352350
seq_len_list,
353351
output,

β€Žversion.txtβ€Ž

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
0.6.6
1+
0.6.7

0 commit comments

Comments
Β (0)