feat: fmha fwd cutlass supports fp8 output #3177
carlyou wants to merge 3 commits into flashinfer-ai:main from
Conversation
📝 Walkthrough

The PR expands the SM100 FMHA FP8 dispatch logic, enforces dtype compatibility in Cutlass paged prefill planning, changes the kernel argument to use the inverted output quant scale, and adds parametrized ragged prefill tests validating BF16/FP16→FP8 output behavior.

Changes
Sequence Diagram(s)

sequenceDiagram
participant Test
participant PrefillPlanner
participant Dispatcher
participant CUTLASSKernel
participant Dequantizer
participant Validator
Test->>PrefillPlanner: plan(q_dtype, kv_dtype, o_dtype, shapes)
PrefillPlanner-->>Test: cached o_data_type / plan
Test->>Dispatcher: run(plan, o_scale, inputs)
Dispatcher->>CUTLASSKernel: args(..., inv_o_scale = 1/o_scale, dtypes)
CUTLASSKernel-->>Dequantizer: fp8_output
Dequantizer-->>Validator: dequantized_output
Validator-->>Test: compare with reference
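For orientation, a minimal torch-only sketch of the quantize/dequantize/validate portion of this flow is below; the `o_scale` value, tensor shape, and tolerances are illustrative assumptions, not the PR's actual test parameters.

```python
import torch

# Assumed values for illustration only.
o_scale = 2.0
ref_out = torch.randn(16, 8, 128, dtype=torch.bfloat16)

# Kernel side (conceptually): multiply by 1/o_scale, then cast to FP8.
fp8_out = (ref_out.float() * (1.0 / o_scale)).to(torch.float8_e4m3fn)

# Test side: dequantize the FP8 output and compare against the reference,
# with tolerances loose enough for e4m3 precision.
dequantized = fp8_out.float() * o_scale
torch.testing.assert_close(dequantized, ref_out.float(), rtol=1e-1, atol=1e-1)
```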
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs
Suggested labels
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Code Review
This pull request enables FP8 output support for the Blackwell CUTLASS FMHA backend when using BF16 or FP16 inputs. Key changes include updating the dispatch logic in the CUDA source, adding validation in the Python prefill module to ensure compatible data types, and inverting the output scale to align with the kernel's requirements. A new test case has been added to verify these combinations. I have no feedback to provide.
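As a rough illustration of the scale inversion mentioned above (the helper name below is a placeholder, not the kernel's actual interface): the caller supplies `o_scale` as a dequantization factor, so the host passes its reciprocal and the epilogue only multiplies before the FP8 cast.

```python
def make_kernel_scale(o_scale: float) -> float:
    # Hypothetical helper: the epilogue conceptually computes
    # out_fp8 = cast_fp8(acc * inv_o_scale), so the inversion happens once on
    # the host instead of as a per-element division on the device.
    return 1.0 / o_scale
```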
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/prefill.py`:
- Around line 3083-3108: The planned FP8 output dtype passed into planning is
not being used at execution: modify BatchPrefillWithRaggedKVCacheWrapper.run()
to accept and propagate the planned/output dtype (o_data_type) instead of
deriving out.dtype from q.dtype, pass that dtype through to fmha_varlen(), and
update fmha_varlen() to use the provided dtype_o when selecting/building the
CUTLASS module and when allocating the output tensor (instead of v.dtype or
q.dtype). Ensure all internal calls that allocate or infer the output tensor
(e.g., any out = torch.empty_like(...) or module.build(...) sites) take the
threaded dtype_o parameter so the FP8 output plan is honored end-to-end.
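A rough sketch of what the suggested fix could look like follows; the class is heavily abridged, and names such as `_cached_o_data_type` and `dtype_o` are taken from the comment above rather than verified against the actual wrapper code.

```python
import torch

class RaggedPrefillWrapperSketch:
    """Abridged stand-in for BatchPrefillWithRaggedKVCacheWrapper."""

    def plan(self, *, q_data_type, kv_data_type, o_data_type, **kwargs):
        # Cache the planned output dtype so run() does not rediscover it from q.
        self._cached_o_data_type = o_data_type

    def run(self, q, k, v, out=None):
        if out is None:
            # Allocate from the planned dtype (possibly FP8), not q.dtype.
            out = torch.empty(q.shape, dtype=self._cached_o_data_type, device=q.device)
        # The real code would then call fmha_varlen(..., out=out, dtype_o=out.dtype).
        return out
```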
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 640bb2a1-b197-469e-aa44-5bf56b77a155
📒 Files selected for processing (4)
csrc/fmha_cutlass_sm100.cu
flashinfer/prefill.py
include/flashinfer/attention/blackwell/fmha_cutlass_sm100.cuh
tests/attention/test_blackwell_fmha.py
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/prefill.py (1)
3602-3617: ⚠️ Potential issue | 🟡 Minor

Direct `fmha_varlen` callers can't request FP8 output without pre-allocating `out`.

When invoked through `BatchPrefillWithRaggedKVCacheWrapper.run()` this is fine: `out` is always pre-allocated from `self._cached_o_data_type` (lines 3312-3325) before reaching here, so the FP8 plan flows through. But `fmha_varlen` is also a public-looking entry point (with `@overload`s), and for direct callers the only way to get FP8 output is to allocate `out` themselves; passing `q` as bf16/fp16 with `out=None` will silently fall back to `q.dtype` and build an FP16 module.

If direct FP8-output use isn't a supported path, please document that on `fmha_varlen` (or assert FP16-out for `out is None`); otherwise consider plumbing an explicit `out_dtype` parameter through to `get_fmha_module` / output allocation so it parallels the wrapper API.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/prefill.py` around lines 3602 - 3617, fmha_varlen currently infers out_dtype from out or q, which prevents callers from requesting FP8 output unless they pre-allocate out; either make this explicit by adding an out_dtype parameter to fmha_varlen (and propagate it through BatchPrefillWithRaggedKVCacheWrapper.run, get_fmha_module and the output allocation paths, using self._cached_o_data_type where appropriate) so callers can request FP8 without preallocating out, or make the behavior explicit by asserting out is not None when an FP8 out is requested (e.g., assert not (out is None and desired FP8) or document/raise if out is None and q.dtype implies FP8), and update get_fmha_module calls to use the new out_dtype variable instead of deriving it from q when out is None.
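One possible shape for the explicit `out_dtype` plumbing described above, as a sketch only (the real `fmha_varlen` has many more parameters, and `get_fmha_module`'s signature is not reproduced here):

```python
import torch

def fmha_varlen_sketch(q, k, v, out=None, out_dtype=None):
    # Resolve the output dtype once: an explicit argument wins, then a
    # pre-allocated `out`, then the old fallback to q.dtype.
    if out_dtype is not None:
        dtype_o = out_dtype
    elif out is not None:
        dtype_o = out.dtype
    else:
        dtype_o = q.dtype
    if out is None:
        out = torch.empty(q.shape, dtype=dtype_o, device=q.device)
    # module = get_fmha_module(q.dtype, k.dtype, dtype_o, ...)  # build for dtype_o
    return out
```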
🧹 Nitpick comments (1)
flashinfer/prefill.py (1)
3086-3095: Optional: hoist `_SUPPORTED_CUTLASS_DTYPES` to module scope.

This frozen set is rebuilt on every `plan()` invocation. Lifting it (and the equivalent kernel-macro comment) to a module-level constant, e.g. `_SM100_FMHA_SUPPORTED_DTYPES = frozenset({...})` near `get_fmha_module`, makes the kernel↔Python contract easier to find/audit and trims a bit of per-plan allocation. Functionally equivalent.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/prefill.py` around lines 3086 - 3095, The set _SUPPORTED_CUTLASS_DTYPES is being rebuilt inside plan() each call; hoist it to module scope as a frozenset constant (e.g. _SM100_FMHA_SUPPORTED_DTYPES) near get_fmha_module so the kernel↔Python dtype contract is centralized and avoids per-plan allocations, then replace references to _SUPPORTED_CUTLASS_DTYPES in plan() with the new module-level constant.
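A sketch of the hoisted constant is below; the exact dtype combinations are assumptions, since the authoritative list lives in the kernel dispatch macros.

```python
import torch

# Hypothetical module-level constant near get_fmha_module; the (q, kv, out)
# dtype triples are illustrative, the kernel macros are the source of truth.
_SM100_FMHA_SUPPORTED_DTYPES = frozenset(
    {
        (torch.bfloat16, torch.bfloat16, torch.bfloat16),
        (torch.bfloat16, torch.bfloat16, torch.float8_e4m3fn),
        (torch.float16, torch.float16, torch.float16),
        (torch.float16, torch.float16, torch.float8_e4m3fn),
    }
)

def _check_cutlass_dtypes(q_dtype, kv_dtype, o_dtype):
    # plan() would call this instead of rebuilding the set on every invocation.
    if (q_dtype, kv_dtype, o_dtype) not in _SM100_FMHA_SUPPORTED_DTYPES:
        raise ValueError(
            f"Unsupported dtype combination for the cutlass backend: "
            f"{q_dtype}/{kv_dtype}/{o_dtype}"
        )
```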
📌 Description
`BatchPrefillWithRaggedKVCacheWrapper(backend="cutlass")` (notably vLLM's MLA prefill path) can now request fused-quantized output and skip the post-kernel BF16→FP8 quant op.

🔍 Related Issues
#3178
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- Installed the hooks with `pre-commit install`.
- Ran `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- Ran `pytest tests/attention/test_blackwell_fmha.py` on B200.

Reviewer Notes
Summary by CodeRabbit
Bug Fixes
New Features
Tests