[FI] SM103 cuDNN BF16 GEMM does not support FP16 as an output datatype #203
Description
Describe the bug
We recently landed cuDNN GEMM support for BF16 GEMMs in FlashInfer: flashinfer-ai/flashinfer#2376
I noticed that when the output datatype is set to FP16 on SM103 (B300), cuDNN fails with an error (see the reproduce section below for the error). With a BF16 output datatype everything works, and FP16 output also passes fine on SM100 (B200).
The implementation is here; it's possible I'm doing something wrong: https://github.com/flashinfer-ai/flashinfer/blob/main/flashinfer/gemm/gemm_base.py#L2134
But I find it odd that it works on B200 and not B300.
Expected behavior
I expect the call not to error out, and to return a tensor with the requested datatype and the correct result.
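In other words, the output dtype should only affect the storage precision of the result, not the math itself. As a minimal pure-Python sketch of the reference semantics I'm comparing against (the name `gemm_ref` is hypothetical; inputs are plain nested lists, accumulation is in Python floats, and the final cast to FP16/BF16 is omitted):

```python
def gemm_ref(a, b):
    """Reference GEMM: c[i][j] = sum_k a[i][k] * b[k][j], accumulated in float.

    `a` is M x K and `b` is K x N, given as nested lists of floats.
    """
    k_dim = len(b)
    n = len(b[0])
    return [
        [sum(a[i][k] * b[k][j] for k in range(k_dim)) for j in range(n)]
        for i in range(len(a))
    ]

# Tiny sanity check of the reference itself.
a = [[1.0, 2.0], [3.0, 4.0]]
b = [[5.0, 6.0], [7.0, 8.0]]
print(gemm_ref(a, b))  # [[19.0, 22.0], [43.0, 50.0]]
```

Whatever dtype the output tensor is stored in, the values should match this reference up to that dtype's rounding.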
System Environment (please complete the following information):
- cudnn_frontend version: v1.18.0
- cudnn_backend version: n/a?
- GPU arch: B300
- cuda runtime version: 13.0
- cuda driver version: 580.82.07
- host compiler: n/a?
- OS: ubuntu 24.04.3
API logs
With log level 2.5 (GitHub Gists was down, so this one is hosted in a repo):
https://github.com/raayandhar/TRTLLM-scripts/blob/main/cudnn.log
With log level 2:
https://gist.github.com/raayandhar/8e5495df8fcc392219bac29646be4eff
To Reproduce
Steps to reproduce the behavior:
Here's how I reproduce the issue in FlashInfer:
(flashinfer) root@a0f94960d09c:/sgl-workspace/sglang/flashinfer# export CUDNN_LOGLEVEL_DBG=2.5 CUDNN_LOGDEST_DBG=cudnn.log
(flashinfer) root@a0f94960d09c:/sgl-workspace/sglang/flashinfer# python
Python 3.12.3 (main, Jan 8 2026, 11:30:50) [GCC 13.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> import flashinfer
>>> a = torch.randn([48, 64], device="cuda", dtype=torch.bfloat16)
>>> b = torch.randn([80, 64], device="cuda", dtype=torch.bfloat16).transpose(-2, -1)
>>> out = flashinfer.mm_bf16(a, b, out_dtype=torch.float16, backend="cudnn")
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/sgl-workspace/sglang/flashinfer/flashinfer/utils.py", line 1176, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/sgl-workspace/sglang/flashinfer/flashinfer/gemm/gemm_base.py", line 419, in mm_bf16
bf16_gemm_sm100(a, b, bias, pdl, out, workspace_buffer, backends)
File "/sgl-workspace/sglang/flashinfer/flashinfer/gemm/gemm_base.py", line 908, in bf16_gemm_sm100
runner(inputs=inputs, tactic=tactic)
File "/sgl-workspace/sglang/flashinfer/flashinfer/autotuner.py", line 217, in __call__
return self.forward(inputs, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/sgl-workspace/sglang/flashinfer/flashinfer/gemm/gemm_base.py", line 2247, in forward
_cudnn_gemm_bf16(workspace_buffer, a, b, out, tactic=tactic)
File "/sgl-workspace/sglang/flashinfer/flashinfer/gemm/gemm_base.py", line 2202, in _cudnn_gemm_bf16
graph = build_cudnn_gemm_bf16_graph(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/sgl-workspace/sglang/flashinfer/flashinfer/gemm/gemm_base.py", line 2161, in build_cudnn_gemm_bf16_graph
graph.build_plans()
cudnn._compiled_module.cudnnGraphNotSupportedError: [cudnn_frontend] Error: No valid execution plans built.
Additional context
Let me know if you need any other details / testing from me.