
refactor: Port upstream CUTLASS fixes and refactor grouped_gemm_nt_masked GEMM module location #2503

Merged
bkryu merged 5 commits into flashinfer-ai:main from bkryu:grouped_gemm_masked_refactor on Feb 6, 2026

Conversation

@bkryu (Collaborator) commented Feb 5, 2026

📌 Description

CUTLASS Upstream Updates
Ported the following commits from cutlass/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py:

  • 1cfbb53a: Fix SM100 block-scale gemm overlapping accumulator and threads_per_warp
    • Added self.threads_per_warp = 32 constant to avoid hardcoded magic numbers
    • Fixed num_acc_consumer_threads calculation (was missing threads_per_warp * multiplier)
    • Removed unnecessary elect_one context around acc_pipeline.consumer_release() calls
    • Updated all thread count calculations to use self.threads_per_warp consistently
  • acb45938: Update nvvm API call from nvvm enum to str
    • Changed cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) to cute.arch.fence_proxy("async.shared", space="cta") (both changes are sketched below)
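
To make these concrete, here is a hypothetical Python fragment. self.threads_per_warp and num_acc_consumer_threads come from the bullets above, but num_acc_warps and multiplier are placeholder names, not the kernel's actual variables:

# Illustrative sketch only; not the real kernel code.
class KernelConfigSketch:
    def __init__(self, num_acc_warps, multiplier):
        self.threads_per_warp = 32  # named constant instead of a magic number
        # The old calculation dropped the warp width; the ported fix
        # multiplies by threads_per_warp so the consumer thread count is correct.
        self.num_acc_consumer_threads = (
            num_acc_warps * self.threads_per_warp * multiplier
        )

# nvvm API change (enum arguments replaced by strings):
#   old: cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared,
#                              space=cute.arch.SharedSpace.shared_cta)
#   new: cute.arch.fence_proxy("async.shared", space="cta")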

Code Reorganization

  • Moved flashinfer/cute_dsl/blockscaled_gemm.py to flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py
  • Created flashinfer/gemm/kernels/__init__.py for the new module exports
  • Added backwards compatibility shim at flashinfer/cute_dsl/blockscaled_gemm.py that re-exports from the new location
  • Updated flashinfer/gemm/__init__.py to export CuTe-DSL kernels when available

All existing import paths continue to work:

# Old imports (still work for backwards compatibility)
from flashinfer.cute_dsl.blockscaled_gemm import grouped_gemm_nt_masked
from flashinfer.cute_dsl import grouped_gemm_nt_masked

# New imports (recommended)
from flashinfer.gemm import grouped_gemm_nt_masked
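
As a concrete illustration, the backwards-compatibility shim mentioned above can be as simple as a re-exporting module. This is a minimal sketch assuming only the module paths named in this PR; the actual shim also carries the deprecation notice discussed later in this thread:

# flashinfer/cute_dsl/blockscaled_gemm.py (compatibility shim, sketch)
from flashinfer.gemm.kernels.grouped_gemm_masked_blackwell import (  # noqa: F401
    Sm100BlockScaledPersistentDenseGemmKernel,
    create_scale_factor_tensor,
    grouped_gemm_nt_masked,
)

__all__ = [
    "Sm100BlockScaledPersistentDenseGemmKernel",
    "create_scale_factor_tensor",
    "grouped_gemm_nt_masked",
]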

Benchmarking results via bench_cute_dsl_blockscaled_gemm.py show no perf difference:

Before this PR
$ python3 bench_cute_dsl_blockscaled_gemm.py | grep Perf
 > Perf (num_groups=6, expected_m_per_group=1024, n=4096, k=7168):   98 us | 3531 TFLOPS | 1748 GB/s
 > Perf (num_groups=6, expected_m_per_group=1024, n=7168, k=2048):   59 us | 2787 TFLOPS | 2307 GB/s
 > Perf (num_groups=6, expected_m_per_group= 512, n=4096, k=7168):   62 us | 2700 TFLOPS | 2170 GB/s
 > Perf (num_groups=6, expected_m_per_group= 512, n=7168, k=2048):   37 us | 2371 TFLOPS | 2579 GB/s
 > Perf (num_groups=1, expected_m_per_group=1024, n=4096, k=7168):   29 us | 2641 TFLOPS | 1119 GB/s
 > Perf (num_groups=1, expected_m_per_group=1024, n=7168, k=2048):   16 us | 1718 TFLOPS | 1430 GB/s
 > Perf (num_groups=2, expected_m_per_group= 512, n=4096, k=7168):   24 us | 1862 TFLOPS | 1740 GB/s
 > Perf (num_groups=2, expected_m_per_group= 512, n=7168, k=2048):   18 us | 1931 TFLOPS | 1955 GB/s
 > Perf (num_groups=4, expected_m_per_group= 256, n=4096, k=7168):   36 us | 1913 TFLOPS | 2217 GB/s
 > Perf (num_groups=4, expected_m_per_group= 256, n=7168, k=2048):   22 us | 1457 TFLOPS | 2287 GB/s
 > Perf (num_groups=72, expected_m_per_group=   7, n=4096, k=7168):  219 us |  124 TFLOPS | 5455 GB/s
 > Perf (num_groups=72, expected_m_per_group=   7, n=7168, k=2048):  125 us |  108 TFLOPS | 4801 GB/s
 > Perf (num_groups=72, expected_m_per_group=  14, n=4096, k=7168):  219 us |  263 TFLOPS | 5474 GB/s
 > Perf (num_groups=72, expected_m_per_group=  14, n=7168, k=2048):  126 us |  234 TFLOPS | 4858 GB/s
 > Perf (num_groups=72, expected_m_per_group=  28, n=4096, k=7168):  220 us |  534 TFLOPS | 5519 GB/s
 > Perf (num_groups=72, expected_m_per_group=  28, n=7168, k=2048):  126 us |  461 TFLOPS | 4977 GB/s
 > Perf (num_groups=72, expected_m_per_group=  42, n=4096, k=7168):  219 us |  785 TFLOPS | 5581 GB/s
 > Perf (num_groups=72, expected_m_per_group=  42, n=7168, k=2048):  126 us |  710 TFLOPS | 5111 GB/s
 > Perf (num_groups=72, expected_m_per_group=  56, n=4096, k=7168):  220 us | 1036 TFLOPS | 5621 GB/s
 > Perf (num_groups=72, expected_m_per_group=  56, n=7168, k=2048):  125 us |  960 TFLOPS | 5245 GB/s
 > Perf (num_groups=72, expected_m_per_group=  85, n=4096, k=7168):  220 us | 1678 TFLOPS | 5765 GB/s
 > Perf (num_groups=72, expected_m_per_group=  85, n=7168, k=2048):  126 us | 1442 TFLOPS | 5487 GB/s
 > Perf (num_groups=72, expected_m_per_group= 113, n=4096, k=7168):  249 us | 2007 TFLOPS | 5189 GB/s
 > Perf (num_groups=72, expected_m_per_group= 113, n=7168, k=2048):  135 us | 1704 TFLOPS | 5296 GB/s
 > Perf (num_groups=36, expected_m_per_group=  14, n=4096, k=7168):  118 us |  243 TFLOPS | 5108 GB/s
 > Perf (num_groups=36, expected_m_per_group=  14, n=7168, k=2048):   68 us |  209 TFLOPS | 4510 GB/s
 > Perf (num_groups=36, expected_m_per_group=  28, n=4096, k=7168):  117 us |  493 TFLOPS | 5171 GB/s
 > Perf (num_groups=36, expected_m_per_group=  28, n=7168, k=2048):   68 us |  413 TFLOPS | 4614 GB/s
 > Perf (num_groups=36, expected_m_per_group=  56, n=4096, k=7168):  117 us |  989 TFLOPS | 5267 GB/s
 > Perf (num_groups=36, expected_m_per_group=  56, n=7168, k=2048):   68 us |  873 TFLOPS | 4860 GB/s
 > Perf (num_groups=36, expected_m_per_group=  85, n=4096, k=7168):  117 us | 1510 TFLOPS | 5378 GB/s
 > Perf (num_groups=36, expected_m_per_group=  85, n=7168, k=2048):   68 us | 1325 TFLOPS | 5103 GB/s
 > Perf (num_groups=36, expected_m_per_group= 113, n=4096, k=7168):  132 us | 1868 TFLOPS | 4882 GB/s
 > Perf (num_groups=36, expected_m_per_group= 113, n=7168, k=2048):   73 us | 1554 TFLOPS | 4896 GB/s
 > Perf (num_groups=36, expected_m_per_group= 170, n=4096, k=7168):  158 us | 2266 TFLOPS | 4239 GB/s
 > Perf (num_groups=36, expected_m_per_group= 170, n=7168, k=2048):   93 us | 1964 TFLOPS | 4220 GB/s
 > Perf (num_groups=36, expected_m_per_group= 227, n=4096, k=7168):  179 us | 2732 TFLOPS | 3887 GB/s
 > Perf (num_groups=36, expected_m_per_group= 227, n=7168, k=2048):  106 us | 2301 TFLOPS | 4016 GB/s
 > Perf (num_groups=18, expected_m_per_group=  28, n=4096, k=7168):   67 us |  431 TFLOPS | 4530 GB/s
 > Perf (num_groups=18, expected_m_per_group=  28, n=7168, k=2048):   38 us |  381 TFLOPS | 4069 GB/s
 > Perf (num_groups=18, expected_m_per_group=  56, n=4096, k=7168):   67 us |  896 TFLOPS | 4632 GB/s
 > Perf (num_groups=18, expected_m_per_group=  56, n=7168, k=2048):   38 us |  798 TFLOPS | 4300 GB/s
 > Perf (num_groups=18, expected_m_per_group= 113, n=4096, k=7168):   72 us | 1624 TFLOPS | 4495 GB/s
 > Perf (num_groups=18, expected_m_per_group= 113, n=7168, k=2048):   41 us | 1364 TFLOPS | 4378 GB/s
 > Perf (num_groups=18, expected_m_per_group= 170, n=4096, k=7168):   89 us | 2119 TFLOPS | 3792 GB/s
 > Perf (num_groups=18, expected_m_per_group= 170, n=7168, k=2048):   53 us | 1712 TFLOPS | 3686 GB/s
 > Perf (num_groups=18, expected_m_per_group= 227, n=4096, k=7168):   96 us | 2464 TFLOPS | 3613 GB/s
 > Perf (num_groups=18, expected_m_per_group= 227, n=7168, k=2048):   62 us | 2069 TFLOPS | 3499 GB/s
 > Perf (num_groups=18, expected_m_per_group= 341, n=4096, k=7168):  123 us | 2932 TFLOPS | 3036 GB/s
 > Perf (num_groups=18, expected_m_per_group= 341, n=7168, k=2048):   74 us | 2388 TFLOPS | 3272 GB/s
 > Perf (num_groups=18, expected_m_per_group= 455, n=4096, k=7168):  153 us | 3249 TFLOPS | 2615 GB/s
 > Perf (num_groups=18, expected_m_per_group= 455, n=7168, k=2048):   88 us | 2705 TFLOPS | 3123 GB/s
 > Perf (num_groups=9, expected_m_per_group=  56, n=4096, k=7168):   38 us |  742 TFLOPS | 4059 GB/s
 > Perf (num_groups=9, expected_m_per_group=  56, n=7168, k=2048):   24 us |  650 TFLOPS | 3454 GB/s
 > Perf (num_groups=9, expected_m_per_group= 113, n=4096, k=7168):   44 us | 1469 TFLOPS | 3674 GB/s
 > Perf (num_groups=9, expected_m_per_group= 113, n=7168, k=2048):   29 us | 1149 TFLOPS | 3191 GB/s
 > Perf (num_groups=9, expected_m_per_group= 227, n=4096, k=7168):   59 us | 2184 TFLOPS | 2973 GB/s
 > Perf (num_groups=9, expected_m_per_group= 227, n=7168, k=2048):   34 us | 1660 TFLOPS | 3089 GB/s
 > Perf (num_groups=9, expected_m_per_group= 341, n=4096, k=7168):   72 us | 2588 TFLOPS | 2611 GB/s
 > Perf (num_groups=9, expected_m_per_group= 341, n=7168, k=2048):   43 us | 2028 TFLOPS | 2811 GB/s
 > Perf (num_groups=9, expected_m_per_group= 455, n=4096, k=7168):   78 us | 2903 TFLOPS | 2510 GB/s
 > Perf (num_groups=9, expected_m_per_group= 455, n=7168, k=2048):   47 us | 2368 TFLOPS | 2823 GB/s
 > Perf (num_groups=9, expected_m_per_group= 682, n=4096, k=7168):  108 us | 3331 TFLOPS | 2072 GB/s
 > Perf (num_groups=9, expected_m_per_group= 682, n=7168, k=2048):   64 us | 2755 TFLOPS | 2620 GB/s
 > Perf (num_groups=9, expected_m_per_group= 910, n=4096, k=7168):  124 us | 3504 TFLOPS | 1927 GB/s
 > Perf (num_groups=9, expected_m_per_group= 910, n=7168, k=2048):   84 us | 2903 TFLOPS | 2412 GB/s
 > Perf (num_groups=8, expected_m_per_group=  64, n=4096, k=7168):   36 us |  827 TFLOPS | 3829 GB/s
 > Perf (num_groups=8, expected_m_per_group=  64, n=7168, k=2048):   23 us |  704 TFLOPS | 3289 GB/s
 > Perf (num_groups=8, expected_m_per_group= 128, n=4096, k=7168):   43 us | 1530 TFLOPS | 3397 GB/s
 > Perf (num_groups=8, expected_m_per_group= 128, n=7168, k=2048):   27 us | 1109 TFLOPS | 2998 GB/s
 > Perf (num_groups=8, expected_m_per_group= 256, n=4096, k=7168):   56 us | 2005 TFLOPS | 2782 GB/s
 > Perf (num_groups=8, expected_m_per_group= 256, n=7168, k=2048):   32 us | 1887 TFLOPS | 3092 GB/s
 > Perf (num_groups=8, expected_m_per_group= 384, n=4096, k=7168):   68 us | 2510 TFLOPS | 2464 GB/s
 > Perf (num_groups=8, expected_m_per_group= 384, n=7168, k=2048):   39 us | 2148 TFLOPS | 2827 GB/s
 > Perf (num_groups=8, expected_m_per_group= 512, n=4096, k=7168):   80 us | 3034 TFLOPS | 2293 GB/s
 > Perf (num_groups=8, expected_m_per_group= 512, n=7168, k=2048):   48 us | 2513 TFLOPS | 2694 GB/s
 > Perf (num_groups=8, expected_m_per_group= 768, n=4096, k=7168):  110 us | 3595 TFLOPS | 1954 GB/s
 > Perf (num_groups=8, expected_m_per_group= 768, n=7168, k=2048):   64 us | 2843 TFLOPS | 2540 GB/s
 > Perf (num_groups=8, expected_m_per_group=1024, n=4096, k=7168):  135 us | 3666 TFLOPS | 1742 GB/s
 > Perf (num_groups=8, expected_m_per_group=1024, n=7168, k=2048):   83 us | 2969 TFLOPS | 2360 GB/s
 > Perf (num_groups=6, expected_m_per_group=  85, n=4096, k=7168):   32 us |  986 TFLOPS | 3258 GB/s
 > Perf (num_groups=6, expected_m_per_group=  85, n=7168, k=2048):   18 us |  855 TFLOPS | 3139 GB/s
 > Perf (num_groups=6, expected_m_per_group= 170, n=4096, k=7168):   39 us | 1572 TFLOPS | 2837 GB/s
 > Perf (num_groups=6, expected_m_per_group= 170, n=7168, k=2048):   24 us | 1309 TFLOPS | 2786 GB/s
 > Perf (num_groups=6, expected_m_per_group= 341, n=4096, k=7168):   53 us | 2213 TFLOPS | 2343 GB/s
 > Perf (num_groups=6, expected_m_per_group= 341, n=7168, k=2048):   32 us | 1973 TFLOPS | 2583 GB/s
 > Perf (num_groups=6, expected_m_per_group= 512, n=4096, k=7168):   61 us | 2838 TFLOPS | 2223 GB/s
 > Perf (num_groups=6, expected_m_per_group= 512, n=7168, k=2048):   37 us | 2388 TFLOPS | 2593 GB/s
 > Perf (num_groups=6, expected_m_per_group= 682, n=4096, k=7168):   78 us | 3330 TFLOPS | 1959 GB/s
 > Perf (num_groups=6, expected_m_per_group= 682, n=7168, k=2048):   49 us | 2533 TFLOPS | 2355 GB/s
 > Perf (num_groups=6, expected_m_per_group=1024, n=4096, k=7168):  101 us | 3585 TFLOPS | 1725 GB/s
 > Perf (num_groups=6, expected_m_per_group=1024, n=7168, k=2048):   65 us | 2929 TFLOPS | 2303 GB/s
 > Perf (num_groups=6, expected_m_per_group=1365, n=4096, k=7168):  124 us | 3793 TFLOPS | 1588 GB/s
 > Perf (num_groups=6, expected_m_per_group=1365, n=7168, k=2048):   73 us | 2978 TFLOPS | 2247 GB/s
 > Perf (num_groups=4, expected_m_per_group= 128, n=4096, k=7168):   28 us | 1085 TFLOPS | 2592 GB/s
 > Perf (num_groups=4, expected_m_per_group= 128, n=7168, k=2048):   16 us |  999 TFLOPS | 2577 GB/s
 > Perf (num_groups=4, expected_m_per_group= 256, n=4096, k=7168):   30 us | 1710 TFLOPS | 2585 GB/s
 > Perf (num_groups=4, expected_m_per_group= 256, n=7168, k=2048):   19 us | 1567 TFLOPS | 2590 GB/s
 > Perf (num_groups=4, expected_m_per_group= 512, n=4096, k=7168):   50 us | 2484 TFLOPS | 1846 GB/s
 > Perf (num_groups=4, expected_m_per_group= 512, n=7168, k=2048):   30 us | 2345 TFLOPS | 2352 GB/s
 > Perf (num_groups=4, expected_m_per_group= 768, n=4096, k=7168):   57 us | 3060 TFLOPS | 1796 GB/s
 > Perf (num_groups=4, expected_m_per_group= 768, n=7168, k=2048):   30 us | 2417 TFLOPS | 2369 GB/s
 > Perf (num_groups=4, expected_m_per_group=1024, n=4096, k=7168):   75 us | 3526 TFLOPS | 1618 GB/s
 > Perf (num_groups=4, expected_m_per_group=1024, n=7168, k=2048):   39 us | 2569 TFLOPS | 2205 GB/s
 > Perf (num_groups=4, expected_m_per_group=1536, n=4096, k=7168):   98 us | 3851 TFLOPS | 1479 GB/s
 > Perf (num_groups=4, expected_m_per_group=1536, n=7168, k=2048):   63 us | 3045 TFLOPS | 2128 GB/s
 > Perf (num_groups=4, expected_m_per_group=2048, n=4096, k=7168):  130 us | 3846 TFLOPS | 1311 GB/s
 > Perf (num_groups=4, expected_m_per_group=2048, n=7168, k=2048):   82 us | 3143 TFLOPS | 2062 GB/s
After this PR
flashinfer/benchmarks$ python3 bench_cute_dsl_blockscaled_gemm.py  | grep Perf
 > Perf (num_groups=6, expected_m_per_group=1024, n=4096, k=7168):   98 us | 3521 TFLOPS | 1744 GB/s
 > Perf (num_groups=6, expected_m_per_group=1024, n=7168, k=2048):   59 us | 2793 TFLOPS | 2311 GB/s
 > Perf (num_groups=6, expected_m_per_group= 512, n=4096, k=7168):   62 us | 2701 TFLOPS | 2171 GB/s
 > Perf (num_groups=6, expected_m_per_group= 512, n=7168, k=2048):   37 us | 2370 TFLOPS | 2578 GB/s
 > Perf (num_groups=1, expected_m_per_group=1024, n=4096, k=7168):   29 us | 2641 TFLOPS | 1119 GB/s
 > Perf (num_groups=1, expected_m_per_group=1024, n=7168, k=2048):   16 us | 1720 TFLOPS | 1431 GB/s
 > Perf (num_groups=2, expected_m_per_group= 512, n=4096, k=7168):   24 us | 1865 TFLOPS | 1744 GB/s
 > Perf (num_groups=2, expected_m_per_group= 512, n=7168, k=2048):   18 us | 1933 TFLOPS | 1956 GB/s
 > Perf (num_groups=4, expected_m_per_group= 256, n=4096, k=7168):   36 us | 1913 TFLOPS | 2217 GB/s
 > Perf (num_groups=4, expected_m_per_group= 256, n=7168, k=2048):   22 us | 1458 TFLOPS | 2288 GB/s
 > Perf (num_groups=72, expected_m_per_group=   7, n=4096, k=7168):  219 us |  124 TFLOPS | 5459 GB/s
 > Perf (num_groups=72, expected_m_per_group=   7, n=7168, k=2048):  126 us |  108 TFLOPS | 4787 GB/s
 > Perf (num_groups=72, expected_m_per_group=  14, n=4096, k=7168):  220 us |  263 TFLOPS | 5469 GB/s
 > Perf (num_groups=72, expected_m_per_group=  14, n=7168, k=2048):  126 us |  234 TFLOPS | 4860 GB/s
 > Perf (num_groups=72, expected_m_per_group=  28, n=4096, k=7168):  220 us |  534 TFLOPS | 5522 GB/s
 > Perf (num_groups=72, expected_m_per_group=  28, n=7168, k=2048):  125 us |  461 TFLOPS | 4983 GB/s
 > Perf (num_groups=72, expected_m_per_group=  42, n=4096, k=7168):  220 us |  784 TFLOPS | 5577 GB/s
 > Perf (num_groups=72, expected_m_per_group=  42, n=7168, k=2048):  126 us |  709 TFLOPS | 5105 GB/s
 > Perf (num_groups=72, expected_m_per_group=  56, n=4096, k=7168):  220 us | 1038 TFLOPS | 5633 GB/s
 > Perf (num_groups=72, expected_m_per_group=  56, n=7168, k=2048):  126 us |  959 TFLOPS | 5239 GB/s
 > Perf (num_groups=72, expected_m_per_group=  85, n=4096, k=7168):  219 us | 1680 TFLOPS | 5770 GB/s
 > Perf (num_groups=72, expected_m_per_group=  85, n=7168, k=2048):  126 us | 1440 TFLOPS | 5479 GB/s
 > Perf (num_groups=72, expected_m_per_group= 113, n=4096, k=7168):  249 us | 2010 TFLOPS | 5196 GB/s
 > Perf (num_groups=72, expected_m_per_group= 113, n=7168, k=2048):  135 us | 1702 TFLOPS | 5290 GB/s
 > Perf (num_groups=36, expected_m_per_group=  14, n=4096, k=7168):  118 us |  243 TFLOPS | 5108 GB/s
 > Perf (num_groups=36, expected_m_per_group=  14, n=7168, k=2048):   68 us |  209 TFLOPS | 4502 GB/s
 > Perf (num_groups=36, expected_m_per_group=  28, n=4096, k=7168):  117 us |  493 TFLOPS | 5173 GB/s
 > Perf (num_groups=36, expected_m_per_group=  28, n=7168, k=2048):   67 us |  414 TFLOPS | 4623 GB/s
 > Perf (num_groups=36, expected_m_per_group=  56, n=4096, k=7168):  118 us |  988 TFLOPS | 5263 GB/s
 > Perf (num_groups=36, expected_m_per_group=  56, n=7168, k=2048):   68 us |  873 TFLOPS | 4862 GB/s
 > Perf (num_groups=36, expected_m_per_group=  85, n=4096, k=7168):  118 us | 1508 TFLOPS | 5372 GB/s
 > Perf (num_groups=36, expected_m_per_group=  85, n=7168, k=2048):   68 us | 1322 TFLOPS | 5092 GB/s
 > Perf (num_groups=36, expected_m_per_group= 113, n=4096, k=7168):  132 us | 1867 TFLOPS | 4880 GB/s
 > Perf (num_groups=36, expected_m_per_group= 113, n=7168, k=2048):   73 us | 1553 TFLOPS | 4892 GB/s
 > Perf (num_groups=36, expected_m_per_group= 170, n=4096, k=7168):  158 us | 2270 TFLOPS | 4247 GB/s
 > Perf (num_groups=36, expected_m_per_group= 170, n=7168, k=2048):   93 us | 1962 TFLOPS | 4217 GB/s
 > Perf (num_groups=36, expected_m_per_group= 227, n=4096, k=7168):  179 us | 2728 TFLOPS | 3881 GB/s
 > Perf (num_groups=36, expected_m_per_group= 227, n=7168, k=2048):  106 us | 2302 TFLOPS | 4019 GB/s
 > Perf (num_groups=18, expected_m_per_group=  28, n=4096, k=7168):   67 us |  431 TFLOPS | 4535 GB/s
 > Perf (num_groups=18, expected_m_per_group=  28, n=7168, k=2048):   39 us |  380 TFLOPS | 4052 GB/s
 > Perf (num_groups=18, expected_m_per_group=  56, n=4096, k=7168):   67 us |  896 TFLOPS | 4635 GB/s
 > Perf (num_groups=18, expected_m_per_group=  56, n=7168, k=2048):   39 us |  793 TFLOPS | 4275 GB/s
 > Perf (num_groups=18, expected_m_per_group= 113, n=4096, k=7168):   72 us | 1618 TFLOPS | 4480 GB/s
 > Perf (num_groups=18, expected_m_per_group= 113, n=7168, k=2048):   41 us | 1357 TFLOPS | 4358 GB/s
 > Perf (num_groups=18, expected_m_per_group= 170, n=4096, k=7168):   89 us | 2112 TFLOPS | 3779 GB/s
 > Perf (num_groups=18, expected_m_per_group= 170, n=7168, k=2048):   54 us | 1705 TFLOPS | 3670 GB/s
 > Perf (num_groups=18, expected_m_per_group= 227, n=4096, k=7168):   96 us | 2463 TFLOPS | 3610 GB/s
 > Perf (num_groups=18, expected_m_per_group= 227, n=7168, k=2048):   62 us | 2069 TFLOPS | 3499 GB/s
 > Perf (num_groups=18, expected_m_per_group= 341, n=4096, k=7168):  123 us | 2929 TFLOPS | 3032 GB/s
 > Perf (num_groups=18, expected_m_per_group= 341, n=7168, k=2048):   74 us | 2394 TFLOPS | 3281 GB/s
 > Perf (num_groups=18, expected_m_per_group= 455, n=4096, k=7168):  153 us | 3248 TFLOPS | 2615 GB/s
 > Perf (num_groups=18, expected_m_per_group= 455, n=7168, k=2048):   88 us | 2699 TFLOPS | 3116 GB/s
 > Perf (num_groups=9, expected_m_per_group=  56, n=4096, k=7168):   38 us |  742 TFLOPS | 4059 GB/s
 > Perf (num_groups=9, expected_m_per_group=  56, n=7168, k=2048):   24 us |  651 TFLOPS | 3456 GB/s
 > Perf (num_groups=9, expected_m_per_group= 113, n=4096, k=7168):   44 us | 1467 TFLOPS | 3670 GB/s
 > Perf (num_groups=9, expected_m_per_group= 113, n=7168, k=2048):   29 us | 1142 TFLOPS | 3171 GB/s
 > Perf (num_groups=9, expected_m_per_group= 227, n=4096, k=7168):   59 us | 2187 TFLOPS | 2977 GB/s
 > Perf (num_groups=9, expected_m_per_group= 227, n=7168, k=2048):   34 us | 1659 TFLOPS | 3087 GB/s
 > Perf (num_groups=9, expected_m_per_group= 341, n=4096, k=7168):   72 us | 2577 TFLOPS | 2601 GB/s
 > Perf (num_groups=9, expected_m_per_group= 341, n=7168, k=2048):   43 us | 2025 TFLOPS | 2807 GB/s
 > Perf (num_groups=9, expected_m_per_group= 455, n=4096, k=7168):   78 us | 2897 TFLOPS | 2504 GB/s
 > Perf (num_groups=9, expected_m_per_group= 455, n=7168, k=2048):   47 us | 2374 TFLOPS | 2830 GB/s
 > Perf (num_groups=9, expected_m_per_group= 682, n=4096, k=7168):  108 us | 3336 TFLOPS | 2075 GB/s
 > Perf (num_groups=9, expected_m_per_group= 682, n=7168, k=2048):   64 us | 2755 TFLOPS | 2620 GB/s
 > Perf (num_groups=9, expected_m_per_group= 910, n=4096, k=7168):  124 us | 3493 TFLOPS | 1921 GB/s
 > Perf (num_groups=9, expected_m_per_group= 910, n=7168, k=2048):   84 us | 2904 TFLOPS | 2412 GB/s
 > Perf (num_groups=8, expected_m_per_group=  64, n=4096, k=7168):   36 us |  830 TFLOPS | 3843 GB/s
 > Perf (num_groups=8, expected_m_per_group=  64, n=7168, k=2048):   23 us |  704 TFLOPS | 3289 GB/s
 > Perf (num_groups=8, expected_m_per_group= 128, n=4096, k=7168):   43 us | 1542 TFLOPS | 3424 GB/s
 > Perf (num_groups=8, expected_m_per_group= 128, n=7168, k=2048):   27 us | 1113 TFLOPS | 3010 GB/s
 > Perf (num_groups=8, expected_m_per_group= 256, n=4096, k=7168):   56 us | 2005 TFLOPS | 2781 GB/s
 > Perf (num_groups=8, expected_m_per_group= 256, n=7168, k=2048):   32 us | 1879 TFLOPS | 3078 GB/s
 > Perf (num_groups=8, expected_m_per_group= 384, n=4096, k=7168):   68 us | 2497 TFLOPS | 2452 GB/s
 > Perf (num_groups=8, expected_m_per_group= 384, n=7168, k=2048):   39 us | 2142 TFLOPS | 2820 GB/s
 > Perf (num_groups=8, expected_m_per_group= 512, n=4096, k=7168):   79 us | 3036 TFLOPS | 2294 GB/s
 > Perf (num_groups=8, expected_m_per_group= 512, n=7168, k=2048):   48 us | 2513 TFLOPS | 2695 GB/s
 > Perf (num_groups=8, expected_m_per_group= 768, n=4096, k=7168):  110 us | 3592 TFLOPS | 1952 GB/s
 > Perf (num_groups=8, expected_m_per_group= 768, n=7168, k=2048):   63 us | 2847 TFLOPS | 2543 GB/s
 > Perf (num_groups=8, expected_m_per_group=1024, n=4096, k=7168):  135 us | 3664 TFLOPS | 1741 GB/s
 > Perf (num_groups=8, expected_m_per_group=1024, n=7168, k=2048):   83 us | 2974 TFLOPS | 2364 GB/s
 > Perf (num_groups=6, expected_m_per_group=  85, n=4096, k=7168):   33 us |  984 TFLOPS | 3251 GB/s
 > Perf (num_groups=6, expected_m_per_group=  85, n=7168, k=2048):   18 us |  855 TFLOPS | 3137 GB/s
 > Perf (num_groups=6, expected_m_per_group= 170, n=4096, k=7168):   39 us | 1572 TFLOPS | 2837 GB/s
 > Perf (num_groups=6, expected_m_per_group= 170, n=7168, k=2048):   24 us | 1308 TFLOPS | 2784 GB/s
 > Perf (num_groups=6, expected_m_per_group= 341, n=4096, k=7168):   52 us | 2226 TFLOPS | 2357 GB/s
 > Perf (num_groups=6, expected_m_per_group= 341, n=7168, k=2048):   32 us | 1977 TFLOPS | 2588 GB/s
 > Perf (num_groups=6, expected_m_per_group= 512, n=4096, k=7168):   61 us | 2839 TFLOPS | 2224 GB/s
 > Perf (num_groups=6, expected_m_per_group= 512, n=7168, k=2048):   37 us | 2391 TFLOPS | 2596 GB/s
 > Perf (num_groups=6, expected_m_per_group= 682, n=4096, k=7168):   78 us | 3334 TFLOPS | 1961 GB/s
 > Perf (num_groups=6, expected_m_per_group= 682, n=7168, k=2048):   49 us | 2534 TFLOPS | 2356 GB/s
 > Perf (num_groups=6, expected_m_per_group=1024, n=4096, k=7168):  101 us | 3585 TFLOPS | 1725 GB/s
 > Perf (num_groups=6, expected_m_per_group=1024, n=7168, k=2048):   65 us | 2931 TFLOPS | 2304 GB/s
 > Perf (num_groups=6, expected_m_per_group=1365, n=4096, k=7168):  124 us | 3801 TFLOPS | 1592 GB/s
 > Perf (num_groups=6, expected_m_per_group=1365, n=7168, k=2048):   73 us | 2971 TFLOPS | 2242 GB/s
 > Perf (num_groups=4, expected_m_per_group= 128, n=4096, k=7168):   28 us | 1087 TFLOPS | 2596 GB/s
 > Perf (num_groups=4, expected_m_per_group= 128, n=7168, k=2048):   16 us |  994 TFLOPS | 2564 GB/s
 > Perf (num_groups=4, expected_m_per_group= 256, n=4096, k=7168):   30 us | 1709 TFLOPS | 2584 GB/s
 > Perf (num_groups=4, expected_m_per_group= 256, n=7168, k=2048):   19 us | 1564 TFLOPS | 2585 GB/s
 > Perf (num_groups=4, expected_m_per_group= 512, n=4096, k=7168):   50 us | 2488 TFLOPS | 1849 GB/s
 > Perf (num_groups=4, expected_m_per_group= 512, n=7168, k=2048):   30 us | 2350 TFLOPS | 2357 GB/s
 > Perf (num_groups=4, expected_m_per_group= 768, n=4096, k=7168):   57 us | 3055 TFLOPS | 1793 GB/s
 > Perf (num_groups=4, expected_m_per_group= 768, n=7168, k=2048):   30 us | 2420 TFLOPS | 2371 GB/s
 > Perf (num_groups=4, expected_m_per_group=1024, n=4096, k=7168):   75 us | 3531 TFLOPS | 1620 GB/s
 > Perf (num_groups=4, expected_m_per_group=1024, n=7168, k=2048):   39 us | 2566 TFLOPS | 2202 GB/s
 > Perf (num_groups=4, expected_m_per_group=1536, n=4096, k=7168):   97 us | 3861 TFLOPS | 1483 GB/s
 > Perf (num_groups=4, expected_m_per_group=1536, n=7168, k=2048):   63 us | 3049 TFLOPS | 2130 GB/s
 > Perf (num_groups=4, expected_m_per_group=2048, n=4096, k=7168):  130 us | 3839 TFLOPS | 1308 GB/s
 > Perf (num_groups=4, expected_m_per_group=2048, n=7168, k=2048):   82 us | 3147 TFLOPS | 2065 GB/s

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • CuTe-DSL GEMM kernels and a scale-factor utility are now available from the public gemm surface when CuTe-DSL is present.
  • Deprecation

    • Importing GEMM kernels from the old cute_dsl path now emits a deprecation notice; use flashinfer.gemm going forward.
  • Chores

    • Benchmarks now use GPU timing and report median calibrated times in seconds (with microseconds in the printed output).
  • Improvements

    • Kernel internals updated for improved thread handling and compatibility.

@coderabbitai (Contributor) commented Feb 5, 2026

📝 Walkthrough

Integrates CuTe‑DSL kernels into GEMM exports, parameterizes Blackwell masked GEMM threading, and updates a CuTe‑DSL blockscaled GEMM benchmark to use bench_gpu_time (median of samples in seconds) for calibrated performance and bandwidth metrics.

Changes

  • Benchmarking update (benchmarks/bench_cute_dsl_blockscaled_gemm.py): Replace Kineto timing with bench_gpu_time (returns multiple millisecond samples); compute the median and convert it to seconds (t_s); rename calibrated timing fields to _s; update calibration, TFLOPS, and GB/s calculations and output to use seconds.
  • Public API & CuTe‑DSL exports (flashinfer/cute_dsl/__init__.py, flashinfer/gemm/__init__.py): Deprecation note for the old import path; re-export create_scale_factor_tensor for compatibility; conditionally include CuTe‑DSL symbols (grouped_gemm_nt_masked, Sm100BlockScaledPersistentDenseGemmKernel, create_scale_factor_tensor) in __all__ when available.
  • Kernel threading & implementation (flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py): Add a self.threads_per_warp attribute; replace hardcoded per‑warp thread counts (32) with threads_per_warp across accumulator, TMEM store, and epilog/barrier calculations; change the utils import to flashinfer.cute_dsl.utils; replace fence_proxy usage with the string form; remove the elect_one wrapper.
  • Module placeholder (flashinfer/gemm/kernels/__init__.py): Add a module file with license header and a docstring describing internal GEMM kernels (no functional exports).
  • Test import update (tests/gemm/test_cute_dsl_blockscaled_gemm.py): Update imports to use flashinfer.gemm instead of flashinfer.cute_dsl.blockscaled_gemm; API usage unchanged.

Sequence Diagram(s)

(Skipped — changes do not introduce a new multi-component sequential flow that requires visualization.)

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~22 minutes

Possibly related PRs

Suggested labels

op: gemm

Suggested reviewers

  • Anerudhan
  • aleozlx
  • yongwww
  • nvmbreughe
  • kahyunnam
  • jimmyzho
  • jiahanc

Poem

🐰
I hopped through kernels, threads in queue,
Timers swapped for medians true,
Exports found a brighter cue,
Warps now count each careful view—
A tiny hop, the build feels new.

🚥 Pre-merge checks: 2 passed, 1 warning
❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: docstring coverage is 50.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
  • Title check ✅ Passed: the title clearly summarizes the main changes: porting CUTLASS fixes and refactoring the grouped_gemm_nt_masked module location from cute_dsl to gemm.
  • Description check ✅ Passed: the PR description comprehensively covers CUTLASS upstream ports, code reorganization, backward compatibility, and includes benchmark results demonstrating performance parity.




@gemini-code-assist (Contributor)

Summary of Changes

Hello @bkryu, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on enhancing the FlashInfer library by refactoring the location of key GEMM kernels and integrating upstream fixes from CUTLASS. These changes aim to improve the maintainability and correctness of the GPU-accelerated operations, alongside updating the benchmarking infrastructure for more accurate performance evaluation.

Highlights

  • Module Refactoring: The grouped_gemm_nt_masked GEMM module and related kernels (Sm100BlockScaledPersistentDenseGemmKernel, create_scale_factor_tensor) have been moved from flashinfer/cute_dsl/blockscaled_gemm.py to flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py for better organization and clarity within the codebase.
  • CUTLASS Fixes: Upstream CUTLASS fixes have been applied to the grouped_gemm_nt_masked kernel. These include standardizing the usage of threads_per_warp in thread count calculations, updating cute.arch.fence_proxy calls to a simplified syntax, and refining the pipeline consumer release logic by removing an unnecessary elect_one wrapper.
  • Benchmarking Improvements: The benchmarking script for block-scaled GEMM has been updated to utilize bench_gpu_time for more precise GPU time measurements, replacing the older bench_kineto utility. This change also incorporates numpy.median for robust performance aggregation.
  • Backward Compatibility: Imports have been adjusted in flashinfer/cute_dsl/__init__.py and flashinfer/gemm/__init__.py to ensure that the moved GEMM-related symbols remain accessible from their previous locations, maintaining backward compatibility for existing code.


Changelog
  • benchmarks/bench_cute_dsl_blockscaled_gemm.py
    • Added numpy import for numerical operations.
    • Replaced bench_kineto with bench_gpu_time for more accurate GPU performance measurement.
    • Updated benchmarking logic to use np.median for aggregating performance results.
  • flashinfer/cute_dsl/__init__.py
    • Added create_scale_factor_tensor to the re-exported symbols and __all__ list to maintain backward compatibility after module refactoring.
  • flashinfer/cute_dsl/blockscaled_gemm.py -> flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py
    • File renamed from blockscaled_gemm.py and moved to flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py.
    • Updated relative imports to absolute imports (e.g., from .utils to from flashinfer.cute_dsl.utils) due to the file's new location.
    • Introduced self.threads_per_warp constant and used it consistently in calculations for thread counts, improving code clarity and maintainability.
    • Updated the cute.arch.fence_proxy syntax from cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) to cute.arch.fence_proxy("async.shared", space="cta").
    • Removed the with cute.arch.elect_one(): wrapper from acc_pipeline.consumer_release call, simplifying pipeline management.
  • flashinfer/gemm/__init__.py
    • Conditionally imports grouped_gemm_nt_masked, Sm100BlockScaledPersistentDenseGemmKernel, and create_scale_factor_tensor from their new location in flashinfer/gemm/kernels.
    • Conditionally added these symbols to the __all__ list to make them discoverable.
Activity
  • No specific activity (comments, reviews, progress) has been recorded for this pull request yet.

@bkryu bkryu changed the title Port upstream CUTLASS fixes and refactor grouped_gemm_nt_masked GEMM … refactor: Port upstream CUTLASS fixes and refactor grouped_gemm_nt_masked GEMM module location Feb 5, 2026

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request ports upstream CUTLASS fixes and refactors the location of the grouped_gemm_nt_masked module. The changes are solid, including updating benchmarks to use more robust timing functions and applying several important fixes to the kernel implementation, such as correcting a thread count calculation and removing magic numbers. My only suggestion is to refactor the conditional imports in flashinfer/gemm/__init__.py to reduce code duplication.

@bkryu bkryu self-assigned this Feb 5, 2026

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 2

🤖 Fix all issues with AI agents
In `@benchmarks/bench_cute_dsl_blockscaled_gemm.py`:
- Around line 49-57: bench_gpu_time returns milliseconds but the code treats t
as seconds; fix by converting t to seconds immediately (e.g., t_s = t / 1e3) and
use t_s for all downstream computations (replace uses of t in TFLOPS and GB/s
formulas and the microsecond print). Specifically, update places referencing
bench_gpu_time result (variable t) and change the microsecond display from t *
1e6 to t * 1e3 (or use t_s * 1e6), and divide the TFLOPS and GB/s calculations
by 1e3 (i.e., use t_s instead of t) so TFLOPS and GB/s are computed with
seconds. Ensure all references to t (printing and performance calculations) use
the consistent t_s value.
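
A minimal sketch of the corrected logic under those assumptions (bench_gpu_time comes from flashinfer's testing utilities and returns per-iteration times in milliseconds; flop_count and byte_count are placeholder names):

import numpy as np

def report_perf(samples_ms, flop_count, byte_count):
    t_s = np.median(samples_ms) / 1e3  # median sample, converted to seconds once
    print(
        f"{t_s * 1e6:4.0f} us | "
        f"{flop_count / t_s / 1e12:4.0f} TFLOPS | "
        f"{byte_count / t_s / 1e9:4.0f} GB/s"
    )

# Usage (hypothetical): report_perf(bench_gpu_time(run_fn), flops, bytes_moved)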

In `@flashinfer/gemm/__init__.py`:
- Around line 63-68: The __all__ extension list is unsorted and triggers Ruff
RUF022; reorder the added symbols alphabetically so the list is sorted. In the
block that appends to __all__ (the one with grouped_gemm_nt_masked,
Sm100BlockScaledPersistentDenseGemmKernel, create_scale_factor_tensor),
rearrange the strings into alphabetical order (create_scale_factor_tensor,
grouped_gemm_nt_masked, Sm100BlockScaledPersistentDenseGemmKernel) so the
module-level __all__ remains lexicographically sorted.
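
For reference, the extension in the order the comment specifies:

__all__ += [
    "create_scale_factor_tensor",
    "grouped_gemm_nt_masked",
    "Sm100BlockScaledPersistentDenseGemmKernel",
]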


@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 2

🤖 Fix all issues with AI agents
In `@flashinfer/gemm/__init__.py`:
- Around line 26-44: The current broad try/except around importing
is_cute_dsl_available and the CuTe-DSL kernels swallows any ImportError from the
kernel module; change the logic so you only suppress the absence of the CuTe-DSL
utils but let kernel import errors surface: import is_cute_dsl_available inside
a narrow try/except (or set a fallback that returns False) and then, if
is_cute_dsl_available() is True, import grouped_gemm_nt_masked,
Sm100BlockScaledPersistentDenseGemmKernel, and create_scale_factor_tensor
normally (no broad try/except) and set _cute_dsl_kernels accordingly so genuine
import failures in those symbols are raised instead of being silenced.
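
A sketch of the narrower pattern being described; the module paths follow this PR, and the fallback stub is an assumption about how absence would be handled:

# Guard only the availability probe; let genuine kernel import errors surface.
try:
    from flashinfer.cute_dsl.utils import is_cute_dsl_available
except ImportError:
    def is_cute_dsl_available() -> bool:
        return False

if is_cute_dsl_available():
    # No broad try/except here: if CuTe-DSL is present, a failure importing
    # these symbols is a real bug and should raise.
    from flashinfer.gemm.kernels.grouped_gemm_masked_blackwell import (
        Sm100BlockScaledPersistentDenseGemmKernel,
        create_scale_factor_tensor,
        grouped_gemm_nt_masked,
    )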

In `@flashinfer/gemm/kernels/__init__.py`:
- Around line 33-38: The __all__ list appended inside the
is_cute_dsl_available() block is not lexicographically sorted; update the list
assigned to __all__ (the entries "grouped_gemm_nt_masked",
"Sm100BlockScaledPersistentDenseGemmKernel", "create_scale_factor_tensor") so
they are in sorted order (alphabetical) to satisfy RUF022, keeping the append
inside the is_cute_dsl_available() conditional and preserving the exact symbol
names.

@bkryu (Collaborator, Author) commented Feb 5, 2026

/bot run

@flashinfer-bot (Collaborator)

GitLab MR !300 has been created, and the CI pipeline #43379685 is currently running. I'll report back once the pipeline job completes.

@yzh119 (Collaborator) left a comment


LGTM, seems the performance before and after this PR looks similar?

@bkryu (Collaborator, Author) commented Feb 5, 2026

> LGTM, seems the performance before and after this PR looks similar?

Yes, you read it correctly. The performance seems identical 👍

@bkryu (Collaborator, Author) commented Feb 6, 2026

/bot stop

@flashinfer-bot (Collaborator)

The GitLab CI pipeline #43379685 has been cancelled.


@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tests/gemm/test_cute_dsl_blockscaled_gemm.py (1)

262-281: ⚠️ Potential issue | 🟡 Minor

Pre-existing bug: sm_count=132 is not a parameter of test_blockscaled_gemm_python_interface.

The __main__ block passes sm_count=132, but the function signature only has enable_dst_signals; sm_count is computed internally (line 95). Running this file directly will raise TypeError. This is pre-existing, but worth fixing while you're here.

Suggested fix
     test_blockscaled_gemm_python_interface(
         lm=(1, 1024),
         kn=(7168, 4096),
         ab_dtype="float4_e2m1fn",
         sf_dtype="float8_e8m0fnu",
         sf_vec_size=16,
         c_dtype="float16",
         a_major="k",
         b_major="k",
         c_major="n",
         fuse_alpha=False,
         alpha_dtype="float32",
         mma_tiler_mn=(128, 128),
         cluster_shape_mn=(2, 1),
         tolerance=1e-01,
         iterations=3,
-        sm_count=132,
         enable_dst_signals=True,
     )
🤖 Fix all issues with AI agents
In `@flashinfer/cute_dsl/__init__.py`:
- Around line 39-47: The module-level deprecation warning in
flashinfer/cute_dsl/__init__.py fires on any import; change it so the warning
only appears when deprecated GEMM symbols are actually accessed: either move the
warnings.warn call into the shim module (blockscaled_gemm.py) where
grouped_gemm_nt_masked, Sm100BlockScaledPersistentDenseGemmKernel, and
create_scale_factor_tensor are defined, or implement module-level __getattr__ in
flashinfer.cute_dsl.__init__ that checks the attribute name (e.g.,
"grouped_gemm_nt_masked", "Sm100BlockScaledPersistentDenseGemmKernel",
"create_scale_factor_tensor"), emits the DeprecationWarning with stacklevel=2,
then lazily imports and returns the requested symbol; leave other imports (like
is_cute_dsl_available or rmsnorm_fp4quant) untouched so they won't trigger the
GEMM deprecation.
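
A sketch of the module-level __getattr__ approach (PEP 562) the comment suggests; symbol and module names follow this PR, the rest is illustrative:

# flashinfer/cute_dsl/__init__.py (lazy-deprecation sketch)
import importlib
import warnings

_DEPRECATED_GEMM_SYMBOLS = {
    "grouped_gemm_nt_masked",
    "Sm100BlockScaledPersistentDenseGemmKernel",
    "create_scale_factor_tensor",
}

def __getattr__(name):
    # Called only for attributes not found normally, so the warning fires
    # when a deprecated GEMM symbol is accessed, not on every import.
    if name in _DEPRECATED_GEMM_SYMBOLS:
        warnings.warn(
            f"Importing {name} from flashinfer.cute_dsl is deprecated; "
            "import it from flashinfer.gemm instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        kernels = importlib.import_module(
            "flashinfer.gemm.kernels.grouped_gemm_masked_blackwell"
        )
        return getattr(kernels, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
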
🧹 Nitpick comments (1)
tests/gemm/test_cute_dsl_blockscaled_gemm.py (1)

83-88: Consider using flashinfer.utils.get_compute_capability for the architecture skip.

The device capability check uses torch.cuda.get_device_capability() directly with hardcoded tuples. The coding guidelines require test files to use flashinfer.utils functions (e.g., get_compute_capability) to skip tests on unsupported GPU architectures. This is pre-existing code, so it can be addressed separately. As per coding guidelines, tests/**/*.py: "Test files must use flashinfer.utils functions (get_compute_capability, is_sm90a_supported, etc.) to skip tests on unsupported GPU architectures."

@bkryu (Collaborator, Author) commented Feb 6, 2026

/bot run

@flashinfer-bot (Collaborator)

GitLab MR !300 has been updated with latest changes, and the CI pipeline #43391828 is currently running. I'll report back once the pipeline job completes.


@jimmyzho jimmyzho left a comment


lgtm!

@flashinfer-bot (Collaborator)

[FAILED] Pipeline #43391828: 12/20 passed

@bkryu bkryu merged commit 0342262 into flashinfer-ai:main Feb 6, 2026
33 checks passed
@bkryu bkryu deleted the grouped_gemm_masked_refactor branch February 6, 2026 17:22
@coderabbitai coderabbitai Bot mentioned this pull request Feb 11, 2026
