
[FI] SM103 cuDNN BF16 GEMM does not support FP16 as an output datatype #203

@raayandhar

Describe the bug

We recently landed cuDNN support for BF16 GEMMs in FlashInfer: flashinfer-ai/flashinfer#2376

I noticed that when the output datatype is set to FP16 on SM103 (B300), cuDNN fails with an error (see To Reproduce below for the full traceback). With a BF16 output datatype everything works, and FP16 output also works fine on SM100 (B200).

The implementation is here; it's possible I'm doing something wrong: https://github.com/flashinfer-ai/flashinfer/blob/main/flashinfer/gemm/gemm_base.py#L2134
But I find it odd that it works on B200 and not on B300.
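To help isolate whether this is a cuDNN limitation rather than something on the FlashInfer side, here is a sketch of a standalone repro using the cudnn-frontend Python API directly, with no FlashInfer involved. This is an untested sketch written from the cudnn-frontend Python samples (the exact argument names may need adjusting), and it requires a GPU with the `cudnn` frontend package installed.

```python
# Standalone sketch: BF16 x BF16 matmul with an FP16 output, built directly
# with the cudnn-frontend Python API. On B300 (SM103) I would expect
# build_plans() below to raise cudnnGraphNotSupportedError, matching the
# failure seen through FlashInfer; on B200 (SM100) it should succeed.
import cudnn

handle = cudnn.create_handle()
graph = cudnn.pygraph(
    handle=handle,
    io_data_type=cudnn.data_type.BFLOAT16,
    compute_data_type=cudnn.data_type.FLOAT,
)

# cuDNN matmul operands are 3-D: [batch, m, k] x [batch, k, n].
# b uses column-major strides, mirroring the transpose in the repro below.
a = graph.tensor(name="a", dim=[1, 48, 64], stride=[48 * 64, 64, 1],
                 data_type=cudnn.data_type.BFLOAT16)
b = graph.tensor(name="b", dim=[1, 64, 80], stride=[64 * 80, 1, 64],
                 data_type=cudnn.data_type.BFLOAT16)

c = graph.matmul(name="matmul", A=a, B=b)
c.set_output(True).set_data_type(cudnn.data_type.HALF)  # FP16 output

graph.validate()
graph.build_operation_graph()
graph.create_execution_plans([cudnn.heur_mode.A])
graph.check_support()
graph.build_plans()  # expected failure point on SM103
```

If this minimal graph also fails on SM103, the issue is in cuDNN's kernel coverage for BF16-in/FP16-out GEMMs on that architecture rather than in the FlashInfer integration.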

Expected behavior
The call should not error out; it should return an FP16 tensor with the expected result, as it does on SM100.

System Environment (please complete the following information):

  • cudnn_frontend version: v1.18.0
  • cudnn_backend version: n/a?
  • GPU arch: B300
  • cuda runtime version: 13.0
  • cuda driver version: 580.82.07
  • host compiler: n/a?
  • OS: ubuntu 24.04.3

API logs

With log level 2.5 (GitHub Gists was down, so this one is hosted in a repo):
https://github.com/raayandhar/TRTLLM-scripts/blob/main/cudnn.log

With log level 2:
https://gist.github.com/raayandhar/8e5495df8fcc392219bac29646be4eff

To Reproduce
Steps to reproduce the behavior:

Here's how I reproduce the issue in FlashInfer:

(flashinfer) root@a0f94960d09c:/sgl-workspace/sglang/flashinfer# export CUDNN_LOGLEVEL_DBG=2.5 CUDNN_LOGDEST_DBG=cudnn.log                                                                                        
(flashinfer) root@a0f94960d09c:/sgl-workspace/sglang/flashinfer# python
Python 3.12.3 (main, Jan  8 2026, 11:30:50) [GCC 13.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> import flashinfer
>>> a = torch.randn([48, 64], device="cuda", dtype=torch.bfloat16)
>>> b = torch.randn([80, 64], device="cuda", dtype=torch.bfloat16).transpose(-2, -1)
>>> out = flashinfer.mm_bf16(a, b, out_dtype=torch.float16, backend="cudnn")
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/sgl-workspace/sglang/flashinfer/flashinfer/utils.py", line 1176, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/flashinfer/flashinfer/gemm/gemm_base.py", line 419, in mm_bf16
    bf16_gemm_sm100(a, b, bias, pdl, out, workspace_buffer, backends)
  File "/sgl-workspace/sglang/flashinfer/flashinfer/gemm/gemm_base.py", line 908, in bf16_gemm_sm100
    runner(inputs=inputs, tactic=tactic)
  File "/sgl-workspace/sglang/flashinfer/flashinfer/autotuner.py", line 217, in __call__
    return self.forward(inputs, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/flashinfer/flashinfer/gemm/gemm_base.py", line 2247, in forward
    _cudnn_gemm_bf16(workspace_buffer, a, b, out, tactic=tactic)
  File "/sgl-workspace/sglang/flashinfer/flashinfer/gemm/gemm_base.py", line 2202, in _cudnn_gemm_bf16
    graph = build_cudnn_gemm_bf16_graph(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/flashinfer/flashinfer/gemm/gemm_base.py", line 2161, in build_cudnn_gemm_bf16_graph
    graph.build_plans()
cudnn._compiled_module.cudnnGraphNotSupportedError: [cudnn_frontend] Error: No valid execution plans built.

Additional context
Let me know if you need any other details / testing from me.
