
feat: FP32 dtype output for BF16 matmuls (CUTLASS & cuDNN)#2644

Merged
bkryu merged 1 commit into flashinfer-ai:main from raayandhar:users/rdhar/bf16_mm_fp32_out
Mar 18, 2026

Conversation

@raayandhar
Contributor

@raayandhar raayandhar commented Feb 26, 2026

📌 Description

Adds support for FP32 output dtype for mm_bf16 and bmm_bf16 on the CUTLASS and cuDNN backends. I'm not familiar enough with the TGV kernel to know whether or how to support it on that backend.

🔍 Related Issues

#2624

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • BF16-based matrix ops (mm_bf16, bmm_bf16) now allow float32 outputs in addition to bfloat16 and float16; supported across applicable backends.
  • Tests

    • Tests extended to cover float32 outputs for BF16/GEMM operations.
  • Documentation

    • User-facing docs and validation messages updated to list bf16, fp16, fp32 as valid output dtypes.

@coderabbitai
Contributor

coderabbitai Bot commented Feb 26, 2026

No actionable comments were generated in the recent review. 🎉


📥 Commits

Reviewing files that changed from the base of the PR and between 00c9897 and 78061e4.

📒 Files selected for processing (6)
  • benchmarks/routines/gemm.py
  • csrc/bf16_gemm_cutlass.cu
  • flashinfer/gemm/gemm_base.py
  • flashinfer/jit/gemm/core.py
  • tests/gemm/test_bmm_bf16.py
  • tests/gemm/test_mm_bf16.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • tests/gemm/test_bmm_bf16.py
  • flashinfer/gemm/gemm_base.py

📝 Walkthrough

Walkthrough

Adds fp32 (torch.float32) as a supported output dtype for BF16 GEMM across CUTLASS runtime, Python validation/JIT, benchmarks, and tests; updates dispatch, dtype mappings, docs, and tests to include fp32.

Changes

Cohort / File(s) Summary
CUTLASS Backend Implementation
csrc/bf16_gemm_cutlass.cu
Adds float dispatch branch, instantiates CutlassBf16GemmRunner<float>, and updates error message to include fp32.
Python API and Validation
flashinfer/gemm/gemm_base.py
Accepts fp32 in _validate_bf16_output_dtype, maps torch.float32 to cuDNN FLOAT, and updates docstrings for mm_bf16/bmm_bf16.
JIT Kernel Generation
flashinfer/jit/gemm/core.py
Adds "float" to dtype_list for BF16 Cutlass kernel generation (generates fp32 kernel variant).
Tests & Benchmarks
benchmarks/routines/gemm.py, tests/gemm/test_mm_bf16.py, tests/gemm/test_bmm_bf16.py
Extends test/benchmark parameterization to include torch.float32, and updates skip logic and error assertions accordingly.
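The validation change in gemm_base.py boils down to widening the accepted output-dtype set. Here is a minimal, self-contained sketch of that rule — dtype names are plain strings rather than torch dtypes, and the function is a hypothetical stand-in mirroring `_validate_bf16_output_dtype`, not the library's actual code:

```python
# Sketch only: strings stand in for torch.bfloat16 / torch.float16 / torch.float32.
SUPPORTED_OUT_DTYPES = ("bfloat16", "float16", "float32")  # "float32" newly accepted

def validate_bf16_output_dtype(out_dtype: str) -> str:
    """Reject any output dtype outside the supported set (bf16, fp16, fp32)."""
    if out_dtype not in SUPPORTED_OUT_DTYPES:
        raise ValueError(
            f"Unsupported output dtype {out_dtype!r}; "
            f"expected one of {SUPPORTED_OUT_DTYPES} (bf16, fp16, fp32)."
        )
    return out_dtype

print(validate_bf16_output_dtype("float32"))  # prints: float32
```

The cuDNN side of the change is analogous: the dtype map gains one more entry (torch.float32 → cudnn FLOAT), so the same validation gate covers both backends.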

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant PythonAPI as Python API (gemm_base)
    participant JIT as JIT (gen_gemm_sm100_module_cutlass_bf16)
    participant Native as Native/C++ (bf16_gemm_cutlass.cu)
    participant GPU

    User->>PythonAPI: Call mm_bf16/bmm_bf16(out_dtype=float32)
    PythonAPI->>PythonAPI: validate out_dtype (accept fp32)
    PythonAPI->>JIT: request kernel variant (dtype_list includes "float")
    JIT->>Native: load/compile kernel variant (fp32 variant)
    Native->>Native: dispatch runGemm<float> branch
    Native->>GPU: execute CUTLASS kernel (fp32 output path)
    GPU-->>Native: results
    Native-->>PythonAPI: return tensor (fp32)
    PythonAPI-->>User: deliver output tensor

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Suggested labels

op: gemm

Suggested reviewers

  • yongwww
  • nvmbreughe
  • aleozlx
  • djmmoss
  • jimmyzho
  • jiahanc
  • cyx-6
  • yzh119

Poem

🐰 A soft hop for dtype delight,
bfloat, half, and now float take flight,
Kernels compiled and tests all sing,
CUTLASS hums — three types in spring! ✨

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage — ⚠️ Warning: docstring coverage is 45.45%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (2 passed)

  • Title check — ✅ Passed: the pull request title accurately and concisely describes the main change: adding FP32 dtype output support for BF16 matmul operations across the CUTLASS and cuDNN backends.
  • Description check — ✅ Passed: the description covers the main objective, links the related issue, and confirms that pre-commit checks and tests have been completed.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


@raayandhar
Contributor Author

(flashinfer) root@36978a460ca6:/flashinfer# pytest tests/gemm/test_mm_bf16.py
========================================================= test session starts ==========================================================
platform linux -- Python 3.12.3, pytest-9.0.2, pluggy-1.6.0
rootdir: /flashinfer
configfile: pytest.ini
collected 1620 items

tests/gemm/test_mm_bf16.py sssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [  6%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 14%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 22%]
ssssssssssssssssssssssssssssssssssssssssssssssss................................................................................ [ 29%]
.......................................................sssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 37%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 45%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 53%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss.................................................... [ 61%]
................................................................................................................................ [ 69%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss...................................... [ 77%]
.......ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss............................... [ 85%]
..............ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss........................ [ 93%]
.....................ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss                  [100%]

============================================ 450 passed, 1170 skipped in 200.02s (0:03:20) =============================================
(flashinfer) root@36978a460ca6:/flashinfer# pytest tests/gemm/test_bmm_bf16.py
========================================================= test session starts ==========================================================
platform linux -- Python 3.12.3, pytest-9.0.2, pluggy-1.6.0
rootdir: /flashinfer
configfile: pytest.ini
collected 96 items

tests/gemm/test_bmm_bf16.py ................................................................................................     [100%]

========================================================== 96 passed in 9.07s ==========================================================

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces support for FP32 as an output data type for BF16 matrix multiplication operations (mm_bf16 and bmm_bf16) when utilizing the CUTLASS and cuDNN backends. This enhancement allows users to perform BF16 computations and receive results in higher precision, addressing a common requirement for certain deep learning workloads. The changes span across core C++ CUDA implementations, Python API definitions, JIT compilation, and associated benchmarks and tests.

Highlights

  • FP32 Output Support: Enabled FP32 as an output data type for BF16 matrix multiplication operations (mm_bf16 and bmm_bf16) in both CUTLASS and cuDNN backends.
  • Backend Integration: Integrated FP32 output support into the C++ CUDA kernels, Python API definitions, and JIT compilation logic for BF16 GEMM operations.
  • Testing and Benchmarking: Updated benchmark scripts and extended unit tests to validate the new FP32 output functionality for BF16 GEMM operations.


Changelog
  • benchmarks/routines/gemm.py
    • Allowed float32 as a supported output dtype for testMmBf16 and testBmmBf16 functions.
  • csrc/bf16_gemm_cutlass.cu
    • Instantiated CutlassBf16GemmRunner for float type.
    • Added a case to bf16_bmm_impl to handle float32_code for GEMM execution.
    • Updated the error message for unsupported output dtypes to include fp32.
  • flashinfer/gemm/gemm_base.py
    • Updated the error message for the TGV backend to mention cuDNN as an alternative.
    • Modified docstrings for mm_bf16 and bmm_bf16 to reflect fp32 support for out tensor and out_dtype.
    • Updated the _validate_bf16_output_dtype function to include torch.float32 as a valid output dtype.
    • Added mapping for torch.float32 to cudnn.data_type.FLOAT in _torch_data_type_to_cudnn_data_type.
  • flashinfer/jit/gemm/core.py
    • Included float in the dtype_list for generating CUTLASS BF16 GEMM kernels.
  • tests/gemm/test_bmm_bf16.py
    • Added torch.float32 to the res_dtype parameter list for test_bmm_bf16.
  • tests/gemm/test_mm_bf16.py
    • Added torch.float32 to the res_dtype parameter list for test_mm_bf16.
    • Updated the skip condition for the TGV backend to correctly handle all non-bfloat16 result dtypes.
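The practical motivation for the fp32 output path: BF16 GEMMs typically accumulate in fp32 inside the kernel, so writing the result out in fp32 avoids one final rounding step back to bfloat16. A small pure-Python illustration of that rounding loss — bf16 rounding is simulated by round-to-nearest-even on the float32 bit pattern, and Python's float stands in for the fp32 accumulator; this sketches the numerics, not the library's kernel code:

```python
import struct

def to_bf16(x: float) -> float:
    """Round a float to the nearest bfloat16 value (keep top 16 bits of fp32)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    rounded = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", rounded))[0]

# Dot product of bf16 inputs with a wide accumulator.
a = [to_bf16(v) for v in (1.0, 1e-3, 1e-3, 1e-3)]
b = [to_bf16(v) for v in (1.0, 1.0, 1.0, 1.0)]
acc_fp32 = sum(x * y for x, y in zip(a, b))   # wide accumulator, ~1.003
acc_bf16_out = to_bf16(acc_fp32)              # bf16 output rounds back to 1.0

print(acc_fp32, acc_bf16_out)
```

With bf16's 8-bit mantissa the spacing around 1.0 is 2^-7 ≈ 0.0078, so the small contributions from the 1e-3 terms vanish when the result is stored as bf16; an fp32 output retains them.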
Activity
  • Pre-commit checks have been successfully run and passed.
  • Tests have been added/updated and are all passing.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request adds support for FP32 output data type for BF16 matrix multiplications (mm_bf16 and bmm_bf16) for both CUTLASS and cuDNN backends. The changes are comprehensive, touching the C++ implementation, Python bindings, JIT compilation, documentation, and tests. The implementation correctly adds template instantiations and dispatch logic for float output in the CUTLASS backend, and updates the type mapping for cuDNN. The tests are properly extended to cover the new FP32 output functionality. The changes are well-executed and appear correct.

@bkryu bkryu added the run-ci label Feb 27, 2026
@bkryu
Collaborator

bkryu commented Feb 27, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !353 has been created, and the CI pipeline #44988476 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #44988476: 10/20 passed

@raayandhar
Contributor Author

@bkryu when you get the chance, could you rerun CI?

@raayandhar
Contributor Author

@bkryu bumping again

@bkryu
Collaborator

bkryu commented Mar 17, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !353 has been created, and the CI pipeline #46375530 is currently running. I'll report back once the pipeline job completes.

Collaborator

@bkryu bkryu left a comment


LGTM

@flashinfer-bot
Collaborator

[CANCELING] Pipeline #46375530: canceled

@bkryu
Collaborator

bkryu commented Mar 17, 2026

/bot run

@bkryu
Collaborator

bkryu commented Mar 17, 2026

@nv-yunzheq, can you help review this PR?

@flashinfer-bot
Collaborator

GitLab MR !353 has been updated with latest changes, and the CI pipeline #46376023 is currently running. I'll report back once the pipeline job completes.

Collaborator

@nv-yunzheq nv-yunzheq left a comment


LGTM.

    if out_dtype != torch.bfloat16:
        raise ValueError(
-           "You cannot provide an output dtype to the TGV backend. Use the CUTLASS backend instead."
+           "You cannot provide an output dtype to the TGV backend. Use the CUTLASS or cuDNN backend instead."
Collaborator


This seems to be a fix for previously incorrect information. Is that right?

Contributor Author


Yes, we have tests for it.
The exception is SM103, where it doesn't work...
https://github.com/flashinfer-ai/flashinfer/blob/main/tests/gemm/test_mm_bf16.py#L51
Worth mentioning that here, do you think?

@flashinfer-bot
Collaborator

[FAILED] Pipeline #46376023: 13/20 passed

@bkryu bkryu merged commit b92da57 into flashinfer-ai:main Mar 18, 2026
32 of 33 checks passed
frankwang28 pushed a commit to frankwang28/flashinfer that referenced this pull request Mar 18, 2026
ameynaik-hub pushed a commit to ameynaik-hub/flashinfer that referenced this pull request Mar 18, 2026