
feat: fmha fwd cutlass supports fp8 output #3177

Open
carlyou wants to merge 3 commits into flashinfer-ai:main from carlyou:feat--fmha-fwd-cutlass-fp8-out

Conversation


carlyou commented Apr 25, 2026

📌 Description

  • Extends SM100 CUTLASS FMHA dispatch to cover bf16/fp16 input with fp8_e4m3fn / fp8_e5m2 output.
  • Callers of BatchPrefillWithRaggedKVCacheWrapper(backend="cutlass") (notably vLLM's MLA prefill path) can now request fused-quantized output and skip the post-kernel BF16→FP8 quant op (see the usage sketch below).
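A minimal usage sketch of the new capability. It assumes plan() accepts an o_data_type argument, as this PR's planner changes suggest; shapes and the remaining argument names are illustrative, written from memory of the wrapper's public API rather than verified against it:

```python
import torch
import flashinfer

num_qo_heads, num_kv_heads, head_dim = 16, 16, 128
qo_indptr = torch.tensor([0, 128], dtype=torch.int32, device="cuda")
kv_indptr = torch.tensor([0, 128], dtype=torch.int32, device="cuda")

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(workspace, backend="cutlass")
wrapper.plan(
    qo_indptr,
    kv_indptr,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    causal=True,
    q_data_type=torch.bfloat16,
    kv_data_type=torch.bfloat16,
    o_data_type=torch.float8_e4m3fn,  # assumption: the fused-quantized output request added by this PR
)

q = torch.randn(128, num_qo_heads, head_dim, dtype=torch.bfloat16, device="cuda")
k = torch.randn(128, num_kv_heads, head_dim, dtype=torch.bfloat16, device="cuda")
v = torch.randn(128, num_kv_heads, head_dim, dtype=torch.bfloat16, device="cuda")
out = wrapper.run(q, k, v)  # out.dtype should now be torch.float8_e4m3fn
```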

🔍 Related Issues

#3178

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • pytest tests/attention/test_blackwell_fmha.py on B200

Reviewer Notes

Summary by CodeRabbit

  • Bug Fixes

    • Corrected FP8 output scale handling so kernels receive the expected quant multiplier.
    • Improved mixed input/output dtype dispatch to handle FP8 output cases correctly.
  • New Features

    • Enforced dtype compatibility checks to block unsupported q/kv/out combinations and surface clear errors.
    • Output allocation now respects the planned/cached output dtype for variable-length paths (see the allocation sketch after this list).
  • Tests

    • Added coverage for BF16/FP16 inputs with FP8 outputs, including ragged scenarios.
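A hypothetical condensation of the "output allocation respects planned dtype" point above; the class is a stand-in, and the attribute name comes from the review thread rather than verified code:

```python
import torch

class _WrapperSketch:
    """Illustrative stand-in for the ragged prefill wrapper's dtype caching."""

    def __init__(self, o_data_type: torch.dtype):
        # plan() caches the requested output dtype...
        self._cached_o_data_type = o_data_type  # attribute name per the review thread

    def allocate_out(self, q: torch.Tensor, head_dim_vo: int) -> torch.Tensor:
        # ...and run() allocates from the cached plan rather than q.dtype,
        # so an FP8 plan yields an FP8 output buffer.
        total_qo_len, num_qo_heads = q.shape[0], q.shape[1]
        return torch.empty(
            (total_qo_len, num_qo_heads, head_dim_vo),
            dtype=self._cached_o_data_type,
            device=q.device,
        )
```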


coderabbitai Bot commented Apr 25, 2026

📝 Walkthrough

The PR expands the SM100 FMHA FP8 dispatch logic, enforces dtype compatibility in Cutlass paged prefill planning, changes the kernel argument to the inverted output quant scale, and adds parametrized ragged prefill tests validating BF16/FP16→FP8 output behavior.

Changes

  • Kernel dispatch & scale (csrc/fmha_cutlass_sm100.cu, include/flashinfer/attention/blackwell/fmha_cutlass_sm100.cuh): Expanded DISPATCH_DTYPE_IN_OUT to handle FP8 o_data_type (e4m3fn/e5m2) distinctly; the kernel runner now computes and passes inv_o_scale = 1/o_scale (defaulting to 1.0 when the scale is non-positive) into the CUTLASS arguments (see the scale sketch after the sequence diagram below).
  • Prefill planning & execution (flashinfer/prefill.py): The Cutlass paged prefill planner now validates q_data_type == kv_data_type and restricts the allowed (q_data_type, o_data_type) pairs, raising ValueError for unsupported combinations; output allocation uses the planned/cached o_data_type and drops the duplicate/late dtype inference.
  • Tests (tests/attention/test_blackwell_fmha.py): Adds the parametrized test test_blackwell_cutlass_fmha_bf16_in_fp8_out(...) for ragged prefill with BF16/FP16 inputs and FP8 outputs, covering o_scale derivation, the kernel run, dequantization, and relaxed-tolerance validation, with conditional skips for unsupported shapes/devices (a sketch of this validation pattern follows the table).
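The dequant-and-compare pattern from the Tests row might look like the following minimal sketch. Function names and tolerance values are illustrative, not the actual test code; it assumes the kernel quantizes by multiplying with inv_o_scale = 1/o_scale, so dequantization multiplies back by o_scale:

```python
import torch

def derive_o_scale(ref: torch.Tensor, fp8_dtype=torch.float8_e4m3fn) -> float:
    # Pick a per-tensor scale so the reference output spans the FP8 range.
    return ref.abs().max().item() / torch.finfo(fp8_dtype).max

def check_fp8_output(out_fp8: torch.Tensor, ref: torch.Tensor, o_scale: float):
    # The kernel wrote out_true * (1 / o_scale); multiply back to dequantize.
    out = out_fp8.to(torch.float32) * o_scale
    # FP8 carries few mantissa bits, so compare with relaxed tolerances.
    torch.testing.assert_close(out, ref.to(torch.float32), rtol=5e-2, atol=5e-2)
```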

Sequence Diagram(s)

sequenceDiagram
  participant Test
  participant PrefillPlanner
  participant Dispatcher
  participant CUTLASSKernel
  participant Dequantizer
  participant Validator

  Test->>PrefillPlanner: plan(q_dtype, kv_dtype, o_dtype, shapes)
  PrefillPlanner-->>Test: cached o_data_type / plan
  Test->>Dispatcher: run(plan, o_scale, inputs)
  Dispatcher->>CUTLASSKernel: args(..., inv_o_scale = 1/o_scale, dtypes)
  CUTLASSKernel-->>Dequantizer: fp8_output
  Dequantizer-->>Validator: dequantized_output
  Validator-->>Test: compare with reference
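As the diagram notes, the runner hands CUTLASS the inverted scale. In Python terms, the guard described in the changes table reduces to something like this sketch (the real computation lives in the C++ runner):

```python
def compute_inv_o_scale(o_scale: float) -> float:
    # Invert the output quant scale; fall back to 1.0 (no scaling) when the
    # provided scale is non-positive, mirroring the runner behavior above.
    return 1.0 / o_scale if o_scale > 0.0 else 1.0
```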

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs

Suggested labels

run-ci

Suggested reviewers

  • sricketts
  • aleozlx
  • cyx-6
  • yongwww
  • yzh119
  • qsang-nv
  • nv-yunzheq
  • jimmyzho
  • bkryu
  • saltyminty

Poem

🐰 I hopped through bytes and kernel code,

FP8 flowers along the road,
Scales inverted, dtypes aligned,
Prefill plans and tests combined,
A tiny rabbit cheers: "Good build!" 🥕

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 40.00%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (4 passed)

  • Title check: ✅ Passed. The title 'feat: fmha fwd cutlass supports fp8 output' clearly and concisely summarizes the main change: adding FP8 output support to the CUTLASS FMHA forward implementation.
  • Description check: ✅ Passed. The description explains what the PR does, links a related issue, and completes the pre-commit checklist; tests are partially addressed, with one test listed.
  • Linked Issues check: ✅ Passed. Check skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check: ✅ Passed. Check skipped because no linked issues were found for this pull request.



gemini-code-assist Bot left a comment


Code Review

This pull request enables FP8 output support for the Blackwell CUTLASS FMHA backend when using BF16 or FP16 inputs. Key changes include updating the dispatch logic in the CUDA source, adding validation in the Python prefill module to ensure compatible data types, and inverting the output scale to align with the kernel's requirements. A new test case has been added to verify these combinations. I have no feedback to provide.

coderabbitai Bot left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/prefill.py`:
- Around line 3083-3108: The planned FP8 output dtype passed into planning is
not being used at execution: modify BatchPrefillWithRaggedKVCacheWrapper.run()
to accept and propagate the planned/output dtype (o_data_type) instead of
deriving out.dtype from q.dtype, pass that dtype through to fmha_varlen(), and
update fmha_varlen() to use the provided dtype_o when selecting/building the
CUTLASS module and when allocating the output tensor (instead of v.dtype or
q.dtype). Ensure all internal calls that allocate or infer the output tensor
(e.g., any out = torch.empty_like(...) or module.build(...) sites) take the
threaded dtype_o parameter so the FP8 output plan is honored end-to-end.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 640bb2a1-b197-469e-aa44-5bf56b77a155

📥 Commits

Reviewing files that changed from the base of the PR and between 5e1318c and 8c3297f.

📒 Files selected for processing (4)
  • csrc/fmha_cutlass_sm100.cu
  • flashinfer/prefill.py
  • include/flashinfer/attention/blackwell/fmha_cutlass_sm100.cuh
  • tests/attention/test_blackwell_fmha.py

Comment thread: flashinfer/prefill.py
coderabbitai Bot left a comment


Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/prefill.py (1)

3602-3617: ⚠️ Potential issue | 🟡 Minor

Direct fmha_varlen callers can't request FP8 output without pre-allocating out.

When invoked through BatchPrefillWithRaggedKVCacheWrapper.run() this is fine — out is always pre-allocated from self._cached_o_data_type (lines 3312-3325) before reaching here, so the FP8 plan flows through. But fmha_varlen is also a public-looking entry point (with @overloads), and for direct callers the only way to get FP8 output is to allocate out themselves; passing q as bf16/fp16 with out=None will silently fall back to q.dtype and build an FP16 module.

If direct FP8-output use isn't a supported path, please document that on fmha_varlen (or assert FP16-out for out is None); otherwise consider plumbing an explicit out_dtype parameter through to get_fmha_module/output allocation so it parallels the wrapper API. A workaround sketch for direct callers follows below.
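A hedged workaround sketch under the current behavior: pre-allocate the FP8 out so the dtype is honored. The import path and argument names are assumptions based on this thread, not verified API:

```python
import torch
from flashinfer.prefill import fmha_varlen  # import path is an assumption

total_len, num_heads, head_dim = 128, 16, 128
q = torch.randn(total_len, num_heads, head_dim, dtype=torch.bfloat16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)
qo_segment_offsets = torch.tensor([0, total_len], dtype=torch.int32, device="cuda")
kv_segment_offsets = torch.tensor([0, total_len], dtype=torch.int32, device="cuda")

# With out=None the output dtype silently falls back to q.dtype, so FP8
# output currently requires pre-allocating the buffer yourself.
out = torch.empty_like(q, dtype=torch.float8_e4m3fn)
fmha_varlen(q, k, v, qo_segment_offsets, kv_segment_offsets, out=out)
```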

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/prefill.py` around lines 3602 - 3617, fmha_varlen currently infers
out_dtype from out or q, which prevents callers from requesting FP8 output
unless they pre-allocate out; either make this explicit by adding an out_dtype
parameter to fmha_varlen (and propagate it through
BatchPrefillWithRaggedKVCacheWrapper.run, get_fmha_module and the output
allocation paths, using self._cached_o_data_type where appropriate) so callers
can request FP8 without preallocating out, or make the behavior explicit by
asserting out is not None when an FP8 out is requested (e.g., assert not (out is
None and desired FP8) or document/raise if out is None and q.dtype implies FP8),
and update get_fmha_module calls to use the new out_dtype variable instead of
deriving it from q when out is None.
🧹 Nitpick comments (1)
flashinfer/prefill.py (1)

3086-3095: Optional: hoist _SUPPORTED_CUTLASS_DTYPES to module scope.

This frozen set is rebuilt on every plan() invocation. Lifting it (and the equivalent kernel-macro comment) to a module-level constant, e.g. _SM100_FMHA_SUPPORTED_DTYPES = frozenset({...}) near get_fmha_module, makes the kernel↔Python contract easier to find and audit and trims a bit of per-plan allocation. Functionally equivalent; a sketch follows below.
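A sketch of the suggested hoist. The exact set membership beyond the bf16/fp16 → fp8 pairs named in this PR is an assumption, as is the validator helper:

```python
import torch

# Module-level constant near get_fmha_module, built once instead of per plan().
_SM100_FMHA_SUPPORTED_DTYPES = frozenset(
    {
        (torch.bfloat16, torch.bfloat16),
        (torch.bfloat16, torch.float8_e4m3fn),
        (torch.bfloat16, torch.float8_e5m2),
        (torch.float16, torch.float16),
        (torch.float16, torch.float8_e4m3fn),
        (torch.float16, torch.float8_e5m2),
    }
)

def _validate_dtypes(q_dtype, kv_dtype, o_dtype):
    # plan() would call this instead of rebuilding the set on every invocation.
    if q_dtype != kv_dtype:
        raise ValueError(f"q/kv dtype mismatch: {q_dtype} vs {kv_dtype}")
    if (q_dtype, o_dtype) not in _SM100_FMHA_SUPPORTED_DTYPES:
        raise ValueError(f"unsupported (q, o) dtype pair: ({q_dtype}, {o_dtype})")
```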

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/prefill.py` around lines 3086 - 3095, The set
_SUPPORTED_CUTLASS_DTYPES is being rebuilt inside plan() each call; hoist it to
module scope as a frozenset constant (e.g. _SM100_FMHA_SUPPORTED_DTYPES) near
get_fmha_module so the kernel↔Python dtype contract is centralized and avoids
per-plan allocations, then replace references to _SUPPORTED_CUTLASS_DTYPES in
plan() with the new module-level constant.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 5695f570-8579-4a8e-8e9b-5e7081789f4b

📥 Commits

Reviewing files that changed from the base of the PR and between 8c3297f and 39ba39a.

📒 Files selected for processing (1)
  • flashinfer/prefill.py
