[wip] feat: add bias support to TGV and CUTLASS BF16 GEMM#2329
Conversation
- Make `bias` parameter optional in `tgv_gemm_sm100()` with full validation
- Thread the `bias` parameter through the CUTLASS BF16 GEMM API (C++ and Python)
- Add a TODO for CUTLASS epilogue bias fusion (currently not applied)
- Maintain backward compatibility (`bias` defaults to `None`)

TGV GEMM bias support is fully functional. CUTLASS GEMM accepts `bias` in its API, but kernel fusion is pending.

Co-authored-by: Zihao Ye <yzh119@users.noreply.github.com>
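As a reference for what the new bias path computes, here is a minimal NumPy sketch (not FlashInfer's kernel; `gemm_with_bias` is a hypothetical name, and it uses a row-major second operand with `N = b.shape[-1]`, whereas the PR's `mat2` is column-major): the output is `A @ B` plus a 1-D bias of length N broadcast across the M rows, with `bias=None` preserving the old behavior.

```python
import numpy as np

# Reference semantics sketch: D = A @ B + bias, bias shape (N,)
# broadcast across the M rows; bias=None keeps the original GEMM.
def gemm_with_bias(a, b, bias=None):
    d = a @ b
    if bias is not None:
        assert bias.ndim == 1 and bias.shape[0] == b.shape[-1]
        d = d + bias  # broadcasts over the row (M) dimension
    return d

a = np.ones((2, 3), dtype=np.float32)
b = np.ones((3, 4), dtype=np.float32)
bias = np.arange(4, dtype=np.float32)
out = gemm_with_bias(a, b, bias)
assert out.shape == (2, 4)
assert np.allclose(out[0], 3.0 + bias)       # each row gets bias added
assert np.allclose(gemm_with_bias(a, b), 3)  # bias=None: old behavior
```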
📝 Walkthrough

Adds an optional `bias` parameter to the BF16 GEMM interfaces across the CUTLASS and TGV backends, threading it from the Python API through the CUDA runtime, template, and dispatcher layers, with validation logic for shape and dtype compatibility.
Sequence Diagram

```mermaid
sequenceDiagram
    participant PythonAPI as Python API<br/>(mm_bf16)
    participant Validation as Runtime<br/>Validation<br/>(csrc)
    participant Dispatch as Template<br/>Dispatcher
    participant Launcher as Kernel<br/>Launcher
    participant Kernel as CUTLASS<br/>Kernel
    PythonAPI->>Validation: Call bf16_gemm(mat1, mat2, bias, ...)
    activate Validation
    Validation->>Validation: Validate bias shape,<br/>dtype, device
    Validation->>Dispatch: Call dispatchToArch<br/>(A, B, D, bias, ...)
    deactivate Validation
    activate Dispatch
    Dispatch->>Dispatch: Route by architecture
    Dispatch->>Launcher: Call dispatchGemmClusterShapeSm100<br/>(A, B, D, bias, ...)
    deactivate Dispatch
    activate Launcher
    Launcher->>Launcher: Route by cluster shape
    Launcher->>Kernel: Launch<br/>genericBf16GemmKernelLauncherSm100<br/>(A, B, D, bias, ...)
    deactivate Launcher
    activate Kernel
    Note over Kernel: bias parameter<br/>reserved (TODO:<br/>epilogue fusion)
    Kernel-->>PythonAPI: Output computed
    deactivate Kernel
```
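The flow in the diagram can be sketched in plain Python (all function names here are illustrative stand-ins, not FlashInfer's API): validation happens once at the boundary, dispatch routes by architecture and cluster shape, and `bias` is threaded down to a launcher that, per this PR, only reserves it for the CUTLASS backend.

```python
# Hedged sketch of the validate -> dispatch -> launch chain above.
def launch_kernel(a, b, d, bias):
    # Mirrors the PR's TODO: bias is accepted but not yet fused
    # into the CUTLASS epilogue.
    _ = bias
    return "launched"

def dispatch_cluster_shape_sm100(a, b, d, bias):
    # Route by cluster shape (single path in this sketch).
    return launch_kernel(a, b, d, bias)

def dispatch_to_arch(a, b, d, bias, arch="sm100"):
    if arch == "sm100":
        return dispatch_cluster_shape_sm100(a, b, d, bias)
    raise NotImplementedError(arch)

def bf16_gemm(a, b, d, bias=None):
    # Boundary validation: bias length must match N.
    if bias is not None and len(bias) != len(b[0]):
        raise ValueError("bias length must match N")
    return dispatch_to_arch(a, b, d, bias)

assert bf16_gemm([[1.0]], [[2.0]], [[0.0]]) == "launched"
assert bf16_gemm([[1.0]], [[2.0]], [[0.0]], bias=[1.0]) == "launched"
```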
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: 4 passed, 1 failed (warning)
🪛 GitHub Actions: pre-commit
- clang-format check failed on `csrc/bf16_gemm_cutlass.cu` (files were modified by the hook).
- ruff-format check failed on `include/flashinfer/gemm/bf16_gemm_template_sm100.h` and `flashinfer/gemm/gemm_base.py` (files were reformatted by the hook).

🪛 Ruff (0.14.10), `flashinfer/gemm/gemm_base.py`
- Lines 1065-1067, 1069-1071, 1073-1075: TRY003 (avoid specifying long messages outside the exception class).
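For the TRY003 warnings, one conventional fix is to move the long message into a dedicated exception class so each raise site stays short. A minimal sketch (the names `BiasShapeError` and `check_bias_len` are hypothetical, not from this PR):

```python
# One way to satisfy Ruff TRY003: keep the long message inside the
# exception class instead of at each raise site.
class BiasShapeError(ValueError):
    def __init__(self, expected: int, got: int):
        super().__init__(f"bias must have shape ({expected},), got ({got},)")

def check_bias_len(expected: int, got: int) -> None:
    if expected != got:
        raise BiasShapeError(expected, got)

check_bias_len(4, 4)  # matching sizes pass silently
try:
    check_bias_len(4, 3)
except BiasShapeError as exc:
    message = str(exc)
assert "(4,)" in message and "(3,)" in message
```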
Summary of Changes

Hello @yzh119, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed.

This pull request introduces optional bias support for BFloat16 General Matrix Multiply (GEMM) operations across both the TGV and CUTLASS backends. The TGV implementation is complete, providing immediate functionality with robust input validation. For the CUTLASS backend, the necessary API changes have been made to thread the bias parameter through the system, including validation at the binding layer; however, the core kernel epilogue that would apply the bias is explicitly marked as a future enhancement, with a safeguard in place to inform users of this current limitation. The design ensures that existing codebases remain unaffected, as the bias parameter is optional.
```diff
-virtual void gemm(__nv_bfloat16 const* A, __nv_bfloat16 const* B, void* D, int m, int n, int k,
-                  int b, CutlassGemmConfig gemmConfig, char* workspacePtr,
+virtual void gemm(__nv_bfloat16 const* A, __nv_bfloat16 const* B, void* D, void* bias, int m,
```
@claude Why do we use `void*` for `bias` here? Is it designed to align with `D`'s data type?
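The question matters because a type-erased pointer is only meaningful if the consumer casts it back to the dtype the buffer was actually written with. A small NumPy illustration of the hazard (a sketch, not the FlashInfer code path; NumPy has no bfloat16, so float16 vs float32 stands in for a bias/output dtype mismatch):

```python
import numpy as np

# A "void*"-like view of a bias buffer: just raw bytes. Reinterpreting
# them with the wrong dtype yields garbage, which is why the bias dtype
# must match the dtype the kernel casts the pointer to.
bias_f16 = np.arange(4, dtype=np.float16)
raw = bias_f16.tobytes()                      # type-erased bytes
good = np.frombuffer(raw, dtype=np.float16)   # cast back to true dtype
bad = np.frombuffer(raw, dtype=np.float32)    # wrong cast

assert np.array_equal(good, bias_f16)
assert bad.shape[0] == 2  # 8 bytes reinterpreted as two float32s
assert not np.array_equal(bad, bias_f16[:2].astype(np.float32))
```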
Code Review

Thank you for this contribution! This PR adds bias support to TGV and CUTLASS BF16 GEMM operations. I've reviewed the changes and have the following feedback.

🎯 Overall Assessment

The implementation is well-structured, with a clear separation between fully functional TGV bias support and API-only CUTLASS support (with TODOs for kernel fusion). Backward compatibility is maintained, and the code follows FlashInfer conventions.

✅ Strengths
🔍 Issues & Recommendations

1. Critical: Missing Test Coverage

The PR includes no test updates, but the existing `mm_bf16` tests will need to skip configurations the CUTLASS backend does not yet support.
Recommendation:

```python
# In tests/gemm/test_mm_bf16.py:33-36
if backend == "cutlass" and enable_bias:
    pytest.skip(
        "mm_bf16 with CUTLASS backend does not yet support bias (epilogue fusion TODO)."
    )
if backend == "cutlass" and pdl:
    pytest.skip("mm_bf16 with CUTLASS backend does not support pdl arguments.")
```
Also verify that the existing test passes for TGV with bias enabled.

2. Validation Inconsistency

The C++ binding accepts either half-precision dtype for the bias:

```cpp
TVM_FFI_ICHECK(bias.value().dtype() == dl_bfloat16 || bias.value().dtype() == dl_float16)
```

But the Python layer requires the bias dtype to match the input:

```python
if bias.dtype != a.dtype:
    raise ValueError(...)
```

Recommendation: make the validation consistent unless there is a specific reason for the different behaviors. If CUTLASS won't support mixed precision, remove the `dl_float16` branch from the C++ check.

3. Bias Shape Validation Logic

The binding computes the expected bias length as:

```cpp
// For 2D gemm: bias shape should be (n,) where n = mat2.size(0)
// For 3D bmm: bias shape should be (n,) where n = mat2.size(1)
int64_t expected_bias_size = mat2.ndim() == 2 ? mat2.size(0) : mat2.size(1);
```

This is correct for column-major `mat2`. Recommendation: make the comment explicit about the layout:

```cpp
// mat2 is in column-major layout (transposed), so:
// For 2D: mat2.shape = (N, K), bias.shape should be (N,) = (mat2.size(0),)
// For 3D: mat2.shape = (B, N, K), bias.shape should be (N,) = (mat2.size(1),)
```

4. TGV Bias Shape Validation Mismatch

The Python layer validates:

```python
if bias.shape[0] != b.shape[1]:
    raise ValueError(...)
```

But the TGV GEMM documentation (include/flashinfer/gemm/tgv_gemm.cuh:985) says "Bias shape (M)", which seems incorrect: this validation, and the bias constructed by the test at tests/gemm/test_tgv_gemm.py:20, both suggest the bias should have shape `(N,)`. Recommendation: fix the comment in tgv_gemm.cuh:

```cpp
// Bias shape (N), where N is contiguous (broadcast across M dimension)
```

5. Unused Parameter Warning Suppression

The CUTLASS launcher should document why the parameter is currently unused:

```cpp
// TODO(flashinfer): Implement bias fusion via CUTLASS epilogue visitor tree
// Current implementation: bias parameter is validated but not applied in kernel
// For reference implementation, see TGV GEMM in tgv_gemm.cuh:738-835
(void)bias;  // Suppress unused parameter warning until epilogue fusion is implemented
```

🧪 Testing Strategy

Since this requires SM100 hardware (not available in CI), a manual testing checklist:

- TGV GEMM (fully functional)
- CUTLASS GEMM (API ready)

📝 Minor Suggestions

🚀 Next Steps

🎬 Conclusion

This is a well-implemented partial feature, with TGV fully functional and CUTLASS API-ready. The main gap is test coverage verification. Once the test updates are made and manual testing confirms correctness on SM100 hardware, this is ready to merge. The TODO markers are clear and prevent incorrect usage, which is the right approach for the incomplete CUTLASS implementation.

Status: Approve with minor changes requested (test updates + comment fixes)
Code Review
This pull request adds support for an optional bias tensor in TGV and CUTLASS BF16 GEMM operations. The changes are well-structured, threading the bias parameter through the Python and C++ layers. For the TGV backend, the implementation is complete with proper validation. For the CUTLASS backend, the API is prepared for future implementation of bias in the epilogue, with TODOs and safeguards in place.
My review focuses on improving the correctness of the bias validation logic in the CUTLASS C++ code.
```cpp
if (bias.has_value()) {
  TVM_FFI_ICHECK_EQ(bias.value().device().device_type, kDLCUDA) << "Bias tensor must be on CUDA";
  TVM_FFI_ICHECK_EQ(bias.value().ndim(), 1) << "Bias tensor must be 1D";
  // For 2D gemm: bias shape should be (n,) where n = mat2.size(0)
  // For 3D bmm: bias shape should be (n,) where n = mat2.size(1)
  int64_t expected_bias_size = mat2.ndim() == 2 ? mat2.size(0) : mat2.size(1);
  TVM_FFI_ICHECK_EQ(bias.value().size(0), expected_bias_size)
      << "Bias tensor size mismatch: expected " << expected_bias_size << ", got "
      << bias.value().size(0);
  TVM_FFI_ICHECK(bias.value().dtype() == dl_bfloat16 || bias.value().dtype() == dl_float16)
      << "Bias tensor must be bfloat16 or float16, got " << bias.value().dtype();
}
```
The validation for the bias tensor can be improved for correctness and clarity.

1. Dtype correctness: the current check `bias.value().dtype() == dl_bfloat16 || bias.value().dtype() == dl_float16` is too permissive. The `bias` tensor's data type must match the `out` tensor's data type, because in `runGemm` the `bias_ptr` is cast to `T*`, where `T` is the data type of the output tensor. A mismatch would lead to incorrect memory interpretation and likely incorrect results or crashes. The check should be changed to `bias.value().dtype() == out.dtype()`.

2. Code clarity: `bias.value()` is called multiple times. It's cleaner and slightly more efficient to store the result in a `const` reference and reuse it.

Here is a suggested implementation that addresses both points:
```cpp
if (bias.has_value()) {
  const auto& bias_tensor = bias.value();
  TVM_FFI_ICHECK_EQ(bias_tensor.device().device_type, kDLCUDA) << "Bias tensor must be on CUDA";
  TVM_FFI_ICHECK_EQ(bias_tensor.ndim(), 1) << "Bias tensor must be 1D";
  // For 2D gemm: bias shape should be (n,) where n = mat2.size(0)
  // For 3D bmm: bias shape should be (n,) where n = mat2.size(1)
  int64_t expected_bias_size = mat2.ndim() == 2 ? mat2.size(0) : mat2.size(1);
  TVM_FFI_ICHECK_EQ(bias_tensor.size(0), expected_bias_size)
      << "Bias tensor size mismatch: expected " << expected_bias_size << ", got "
      << bias_tensor.size(0);
  TVM_FFI_ICHECK(bias_tensor.dtype() == out.dtype())
      << "Bias tensor must have the same dtype as the output tensor, got "
      << bias_tensor.dtype() << " but output is " << out.dtype();
}
```
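The shape and dtype rules in the suggested checks can be mirrored in plain Python for quick sanity-testing (a sketch; `validate_bias` is a hypothetical helper, and dtypes are represented as strings rather than DLPack codes):

```python
# Python mirror of the suggested C++ validation: bias must be 1-D with
# length N (mat2 is column-major, so N = mat2.shape[0] for 2-D GEMM and
# mat2.shape[1] for batched 3-D GEMM), and its dtype must match out's.
def validate_bias(bias_shape, bias_dtype, mat2_shape, out_dtype):
    if len(bias_shape) != 1:
        raise ValueError("bias must be 1D")
    expected = mat2_shape[0] if len(mat2_shape) == 2 else mat2_shape[1]
    if bias_shape[0] != expected:
        raise ValueError(
            f"bias size mismatch: expected {expected}, got {bias_shape[0]}"
        )
    if bias_dtype != out_dtype:
        raise ValueError("bias dtype must match output dtype")

validate_bias((8,), "bfloat16", (8, 16), "bfloat16")     # 2D: N = mat2[0]
validate_bias((8,), "bfloat16", (4, 8, 16), "bfloat16")  # 3D: N = mat2[1]
try:
    validate_bias((8,), "float16", (8, 16), "bfloat16")  # dtype mismatch
    rejected = False
except ValueError:
    rejected = True
assert rejected
```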
@yzh119 not sure what your plans are for this; happy to take it over or contribute to this branch if that helps. It looks like Claude helped start the scaffolding but didn't actually implement the EVT bias-add part.
This PR implements bias support for both TGV and CUTLASS BF16 GEMM operations.
Changes
TGV GEMM (fully functional):
- `bias` parameter optional in `tgv_gemm_sm100()`, with full validation

CUTLASS GEMM (API ready, epilogue TODO):
- `bias` threaded through the C++ and Python API; kernel epilogue fusion is still pending

Backward Compatibility
- `bias` defaults to `None` (optional parameter), so existing call sites are unchanged

Testing
Note: Testing requires SM100 hardware which is not available in CI.
Fixes #2290
Generated with Claude Code