Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 18 additions & 11 deletions benchmarks/bench_cute_dsl_blockscaled_gemm.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import json
import random

import numpy as np
import cutlass
from flashinfer.cute_dsl.blockscaled_gemm import (
from flashinfer.gemm import (
create_scale_factor_tensor,
grouped_gemm_nt_masked, # deepgemm-like python interface for DLFW integration
)
import torch
import cutlass.torch as cutlass_torch
from flashinfer.cute_dsl.utils import get_cutlass_dtype
from flashinfer.testing.utils import bench_kineto, count_bytes
from flashinfer.testing.utils import bench_gpu_time, count_bytes


ab_dtype = "float4_e2m1fn"
Expand Down Expand Up @@ -44,28 +46,33 @@ def test_func():
alpha_dtype="float32",
)

t = bench_kineto(
times = bench_gpu_time(
test_func,
"Sm100BlockScaledPersistentDenseGemmKernel",
suppress_kineto_output=True,
dry_run_iters=10,
repeat_iters=30,
enable_cupti=True,
use_cuda_graph=False,
cold_l2_cache=True,
)
t_ms = np.median(times) # bench_gpu_time returns milliseconds
t_s = t_ms / 1e3 # convert to seconds for downstream calculations

valid_m = data["masked_m"].sum().item()
t_calibrated = t / valid_m * (expected_m_per_group * num_groups)
t_calibrated_s = t_s / valid_m * (expected_m_per_group * num_groups)

tflops = 2 * valid_m * n * k / t / 1e12
tflops = 2 * valid_m * n * k / t_s / 1e12
gb_per_s = (
(
count_bytes(data["a"], data["c"]) * valid_m / (max_m * num_groups)
+ count_bytes(data["b"])
)
/ 1e9
/ t
/ t_s
)

print(
f" > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}): "
f"{t * 1e6:4.0f} us | {tflops:4.0f} TFLOPS | {gb_per_s:4.0f} GB/s"
f"{t_s * 1e6:4.0f} us | {tflops:4.0f} TFLOPS | {gb_per_s:4.0f} GB/s"
)

metrics = dict(
Expand All @@ -74,8 +81,8 @@ def test_func():
valid_m=valid_m,
n=n,
k=k,
t_us_raw=t * 1e6,
t_us_calibrated=t_calibrated * 1e6,
t_us_raw=t_s * 1e6,
t_us_calibrated=t_calibrated_s * 1e6,
tflops=tflops,
gb_per_s=gb_per_s,
)
Expand Down
13 changes: 12 additions & 1 deletion flashinfer/cute_dsl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,25 @@
===========================

This module provides high-performance GPU kernels implemented using NVIDIA CuTe-DSL.

.. deprecated::
Importing GEMM kernels (``grouped_gemm_nt_masked``,
``Sm100BlockScaledPersistentDenseGemmKernel``, ``create_scale_factor_tensor``)
from ``flashinfer.cute_dsl`` is deprecated.
Use ``flashinfer.gemm`` instead. The old import paths will be
removed in a future release.
"""

from .utils import is_cute_dsl_available, make_ptr, get_cutlass_dtype, get_num_sm

# Conditionally import CuTe-DSL kernels
if is_cute_dsl_available():
# Deprecated GEMM symbols: re-exported for backwards compatibility.
# Use flashinfer.gemm instead.
from .blockscaled_gemm import (
grouped_gemm_nt_masked,
Sm100BlockScaledPersistentDenseGemmKernel,
create_scale_factor_tensor,
)
from .rmsnorm_fp4quant import (
rmsnorm_fp4quant,
Expand All @@ -46,9 +56,10 @@

if is_cute_dsl_available():
__all__ += [
# Blockscaled GEMM
# Blockscaled GEMM (deprecated, use flashinfer.gemm instead)
"grouped_gemm_nt_masked",
"Sm100BlockScaledPersistentDenseGemmKernel",
"create_scale_factor_tensor",
# RMSNorm + FP4 Quantization
"rmsnorm_fp4quant",
"RMSNormFP4QuantKernel",
Expand Down
Loading
Loading