benchmarks/routines/gemm.py (8 changes: 4 additions & 4 deletions)

@@ -1524,9 +1524,9 @@ def testMmBf16(args):
     res = []

     out_dtype = dtype_str_to_torch_dtype(args.out_dtype)
-    if out_dtype not in [torch.bfloat16, torch.float16]:
+    if out_dtype not in [torch.bfloat16, torch.float16, torch.float32]:
         raise ValueError(
-            f"Unsupported output dtype: {args.out_dtype}. Supported dtypes are bfloat16 and float16."
+            f"Unsupported output dtype: {args.out_dtype}. Supported dtypes are bfloat16, float16, and float32."
         )

     ## Prepare input tensors
@@ -1744,9 +1744,9 @@ def testBmmBf16(args):
     res = []

     out_dtype = dtype_str_to_torch_dtype(args.out_dtype)
-    if out_dtype not in [torch.bfloat16, torch.float16]:
+    if out_dtype not in [torch.bfloat16, torch.float16, torch.float32]:
         raise ValueError(
-            f"Unsupported output dtype: {args.out_dtype}. Supported dtypes are bfloat16 and float16."
+            f"Unsupported output dtype: {args.out_dtype}. Supported dtypes are bfloat16, float16, and float32."
         )

     ## Prepare input tensors
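The same dtype check now appears verbatim in both testMmBf16 and testBmmBf16. A shared helper in the spirit of gemm_base.py's _validate_bf16_output_dtype would avoid the duplication; a minimal sketch (the helper name is hypothetical):

    import torch

    def _check_benchmark_out_dtype(out_dtype: torch.dtype, out_dtype_str: str) -> None:
        # Hypothetical shared helper mirroring the check duplicated above.
        if out_dtype not in (torch.bfloat16, torch.float16, torch.float32):
            raise ValueError(
                f"Unsupported output dtype: {out_dtype_str}. "
                "Supported dtypes are bfloat16, float16, and float32."
            )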
csrc/bf16_gemm_cutlass.cu (6 changes: 5 additions & 1 deletion)

@@ -39,6 +39,7 @@ namespace flashinfer {
 namespace gemm {
 template class CutlassBf16GemmRunner<__nv_bfloat16>;
 template class CutlassBf16GemmRunner<half>;
+template class CutlassBf16GemmRunner<float>;
 }  // namespace gemm
 }  // namespace flashinfer

@@ -134,8 +135,11 @@ void bf16_bmm_impl(TensorView mat1, TensorView mat2, TensorView out, TensorView
     case bfloat16_code:
       runGemm<__nv_bfloat16>(out, mat1, mat2, m, n, k, b, config, workspace_buffer);
       break;
+    case float32_code:
+      runGemm<float>(out, mat1, mat2, m, n, k, b, config, workspace_buffer);
+      break;
     default:
-      TVM_FFI_LOG_AND_THROW(NotImplementedError) << "out_dtype must be one of fp16/bf16.";
+      TVM_FFI_LOG_AND_THROW(NotImplementedError) << "out_dtype must be one of fp16/bf16/fp32.";
   }
 }

flashinfer/gemm/gemm_base.py (24 changes: 14 additions & 10 deletions)

@@ -249,7 +249,7 @@ def _tgv_gemm_requirement(
 ):
     if out_dtype != torch.bfloat16:
         raise ValueError(
-            "You cannot provide an output dtype to the TGV backend. Use the CUTLASS backend instead."
+            "You cannot provide an output dtype to the TGV backend. Use the CUTLASS or cuDNN backend instead."

  [Inline review thread on the line above]
  Collaborator: This seems to be a fix for old, incorrect information. Is it true?
  Contributor Author: Yes, we have tests for it. The exception is that it doesn't work on SM103...
  https://github.com/flashinfer-ai/flashinfer/blob/main/tests/gemm/test_mm_bf16.py#L51
  Do you think it's worth mentioning here?

         )
     return True

@@ -355,10 +355,12 @@ def mm_bf16(
         Whether to use persistant data loader mode. Enabled for TGV backend. Defaults to ``False``.

     out: Optional[torch.Tensor]
-        Out tensor, shape (m, n), bf16 or fp16. Enabled for CUTLASS backend. Defaults to ``None``.
+        Out tensor, shape (m, n), bf16, fp16, or fp32. Enabled for CUTLASS and cuDNN backends.
+        Defaults to ``None``.

     out_dtype: torch.dtype
-        Output dtype, bf16 or fp16. Enabled for CUTLASS and cuDNN backends. Defaults to ``torch.bfloat16``.
+        Output dtype, bf16, fp16, or fp32. Enabled for CUTLASS and cuDNN backends.
+        Defaults to ``torch.bfloat16``.

     backend: Literal["cudnn", "cutlass", "tgv", "auto"]
         The backend to use for the operation. Defaults to ``"cudnn"``.
@@ -370,7 +372,7 @@
     Returns
     -------
     torch.Tensor
-        Out tensor, shape (m, n), bf16 or fp16 in row-major layout.
+        Out tensor, shape (m, n), bf16, fp16, or fp32 in row-major layout.

     Examples
     --------
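A usage sketch of the widened output path (the import path and argument order are assumed from this docstring, and a CUDA device is required):

    import torch
    from flashinfer.gemm import mm_bf16  # import path assumed

    m, n, k = 16, 2048, 1024
    a = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)      # (m, k), row-major
    b = torch.randn(n, k, device="cuda", dtype=torch.bfloat16).t()  # (k, n), column-major
    out = mm_bf16(a, b, out_dtype=torch.float32)  # fp32 output, newly supported
    assert out.dtype == torch.float32 and out.shape == (m, n)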
@@ -534,18 +536,18 @@ def bmm_bf16(
         Weight tensor, shape (b, k, n), bf16 in column-major layout.

     out: Optional[torch.Tensor]
-        Out tensor, shape (b, m, n), bf16 or fp16, defaults to ``None``.
+        Out tensor, shape (b, m, n), bf16, fp16, or fp32, defaults to ``None``.

     out_dtype: torch.dtype
-        Output dtype, bf16 (default) or fp16.
+        Output dtype, bf16 (default), fp16, or fp32.

     backend: Literal["cudnn", "cutlass", "auto"]
         Backend to use, defaults to "cudnn". ``"auto"`` allows selecting the best tactic from all available backends when autotune is enabled.

     Returns
     -------
     torch.Tensor
-        Out tensor, shape (b, m, n), bf16 or fp16 in row-major layout.
+        Out tensor, shape (b, m, n), bf16, fp16, or fp32 in row-major layout.

     Examples
     --------
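Similarly for the batched variant, a sketch under the same assumptions (import path and argument order taken from the docstring above):

    import torch
    from flashinfer.gemm import bmm_bf16  # import path assumed

    batch, m, n, k = 4, 48, 80, 64
    a = torch.randn(batch, m, k, device="cuda", dtype=torch.bfloat16)                  # (b, m, k), row-major
    w = torch.randn(batch, n, k, device="cuda", dtype=torch.bfloat16).transpose(1, 2)  # (b, k, n), column-major
    out = bmm_bf16(a, w, out_dtype=torch.float32)  # fp32 output, newly supported
    assert out.dtype == torch.float32 and out.shape == (batch, m, n)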
@@ -1744,11 +1746,11 @@ def _validate_fp8_output_dtype(dtype: torch.dtype):


 def _validate_bf16_output_dtype(dtype: torch.dtype):
-    """Validate that the output dtype is either bf16 or fp16."""
-    if dtype not in (torch.bfloat16, torch.float16):
+    """Validate that the output dtype is bf16, fp16, or fp32."""
+    if dtype not in (torch.bfloat16, torch.float16, torch.float32):
         raise ValueError(
             f"Unsupported output dtype: {dtype}. "
-            f"Only torch.bfloat16 and torch.float16 are supported for BF16 GEMM operations."
+            f"Only torch.bfloat16, torch.float16, and torch.float32 are supported for BF16 GEMM operations."
         )
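A quick illustration of the widened validator's behavior (calling the private helper directly, purely for demonstration; the import path is assumed):

    import torch
    from flashinfer.gemm.gemm_base import _validate_bf16_output_dtype  # import path assumed

    _validate_bf16_output_dtype(torch.float32)  # previously raised; now passes
    try:
        _validate_bf16_output_dtype(torch.int32)
    except ValueError as err:
        print(err)  # lists bfloat16, float16, and float32 as the supported dtypes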


@@ -2077,6 +2079,8 @@ def _torch_data_type_to_cudnn_data_type(dtype: torch.dtype):
         return cudnn.data_type.BFLOAT16
     elif dtype == torch.float16:
         return cudnn.data_type.HALF
+    elif dtype == torch.float32:
+        return cudnn.data_type.FLOAT
     elif dtype == torch.float8_e4m3fn:
         return cudnn.data_type.FP8_E4M3
     elif dtype == torch.float8_e5m2:
flashinfer/jit/gemm/core.py (2 changes: 1 addition & 1 deletion)

@@ -272,7 +272,7 @@ def gen_gemm_sm100_module_cutlass_bf16() -> JitSpec:

     with open(jit_env.FLASHINFER_CSRC_DIR / "bf16_gemm_cutlass.jinja") as f:
         kernel_inst_templ = jinja2.Template(f.read())
-    dtype_list = ["__nv_bfloat16", "half"]
+    dtype_list = ["__nv_bfloat16", "half", "float"]
     cta_m_n_k_list = [
         (64, 64, 128),
         (64, 128, 128),
tests/gemm/test_bmm_bf16.py (2 changes: 1 addition & 1 deletion)

@@ -11,7 +11,7 @@
 @pytest.mark.parametrize("m", [48, 128])
 @pytest.mark.parametrize("n", [80, 64])
 @pytest.mark.parametrize("k", [64, 256])
-@pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16])
+@pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16, torch.float32])
 @pytest.mark.parametrize("backend", ["cutlass", "cudnn"])
 def test_bmm_bf16(b, m, n, k, res_dtype, backend):
     compute_capability = get_compute_capability(torch.device(device="cuda"))
tests/gemm/test_mm_bf16.py (4 changes: 2 additions & 2 deletions)

@@ -10,7 +10,7 @@
 @pytest.mark.parametrize("m", [1, 8, 16, 32, 64])
 @pytest.mark.parametrize("n", [1024, 2048, 4096])
 @pytest.mark.parametrize("k", [1024, 2048, 3072])
-@pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16])
+@pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16, torch.float32])
 @pytest.mark.parametrize("enable_bias", [True, False])
 @pytest.mark.parametrize("pdl", [True, False])
 @pytest.mark.parametrize("backend", ["cudnn", "cutlass", "tgv"])
@@ -44,7 +44,7 @@ def test_mm_bf16(
         pytest.skip(
             "mm_bf16 with CUTLASS backend does not support bias or pdl arguments."
         )
-    if res_dtype == torch.float16 and backend == "tgv":
+    if res_dtype != torch.bfloat16 and backend == "tgv":
         pytest.skip(
             "mm_bf16 with TGV backend does not support specifying non-bfloat16 result dtypes."
         )
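With the new float32 parametrization in place, both suites can be exercised on a machine with a supported GPU:

    pytest tests/gemm/test_mm_bf16.py tests/gemm/test_bmm_bf16.py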