From 8a854cfa544f92ac133dae3f6070abdff8350d2a Mon Sep 17 00:00:00 2001
From: Brian Ryu
Date: Thu, 5 Feb 2026 18:09:09 +0000
Subject: [PATCH 1/5] Port upstream CUTLASS fixes and refactor
 grouped_gemm_nt_masked GEMM module location

---
 benchmarks/bench_cute_dsl_blockscaled_gemm.py | 14 +++++--
 flashinfer/cute_dsl/__init__.py               |  3 ++
 flashinfer/gemm/__init__.py                   | 26 ++++++++++++
 .../kernels/grouped_gemm_masked_blackwell.py} | 40 +++++++++++--------
 4 files changed, 63 insertions(+), 20 deletions(-)
 rename flashinfer/{cute_dsl/blockscaled_gemm.py => gemm/kernels/grouped_gemm_masked_blackwell.py} (99%)

diff --git a/benchmarks/bench_cute_dsl_blockscaled_gemm.py b/benchmarks/bench_cute_dsl_blockscaled_gemm.py
index fb444b019d..0852ac5fee 100644
--- a/benchmarks/bench_cute_dsl_blockscaled_gemm.py
+++ b/benchmarks/bench_cute_dsl_blockscaled_gemm.py
@@ -1,5 +1,7 @@
 import json
 import random
+
+import numpy as np
 import cutlass
 from flashinfer.cute_dsl.blockscaled_gemm import (
     create_scale_factor_tensor,
@@ -8,7 +10,7 @@
 import torch
 import cutlass.torch as cutlass_torch
 from flashinfer.cute_dsl.utils import get_cutlass_dtype
-from flashinfer.testing.utils import bench_kineto, count_bytes
+from flashinfer.testing.utils import bench_gpu_time, count_bytes
 
 
 ab_dtype = "float4_e2m1fn"
@@ -44,11 +46,15 @@ def test_func():
         alpha_dtype="float32",
     )
 
-    t = bench_kineto(
+    times = bench_gpu_time(
         test_func,
-        "Sm100BlockScaledPersistentDenseGemmKernel",
-        suppress_kineto_output=True,
+        dry_run_iters=10,
+        repeat_iters=30,
+        enable_cupti=True,
+        use_cuda_graph=False,
+        cold_l2_cache=True,
     )
+    t = np.median(times)
 
     valid_m = data["masked_m"].sum().item()
     t_calibrated = t / valid_m * (expected_m_per_group * num_groups)
diff --git a/flashinfer/cute_dsl/__init__.py b/flashinfer/cute_dsl/__init__.py
index 31c5120d6c..8886324473 100644
--- a/flashinfer/cute_dsl/__init__.py
+++ b/flashinfer/cute_dsl/__init__.py
@@ -22,9 +22,11 @@
 
 # Conditionally import CuTe-DSL kernels
 if is_cute_dsl_available():
+    # Re-export from new location for backwards compatibility
     from .blockscaled_gemm import (
         grouped_gemm_nt_masked,
         Sm100BlockScaledPersistentDenseGemmKernel,
+        create_scale_factor_tensor,
     )
     from .rmsnorm_fp4quant import (
         rmsnorm_fp4quant,
@@ -49,6 +51,7 @@
     # Blockscaled GEMM
     "grouped_gemm_nt_masked",
     "Sm100BlockScaledPersistentDenseGemmKernel",
+    "create_scale_factor_tensor",
     # RMSNorm + FP4 Quantization
     "rmsnorm_fp4quant",
     "RMSNormFP4QuantKernel",
diff --git a/flashinfer/gemm/__init__.py b/flashinfer/gemm/__init__.py
index bd30c178dc..d8132c167a 100644
--- a/flashinfer/gemm/__init__.py
+++ b/flashinfer/gemm/__init__.py
@@ -23,6 +23,19 @@
     mm_M1_16_K7168_N256 as mm_M1_16_K7168_N256,
 )
 
+# Import CuTe-DSL kernels if available
+try:
+    from flashinfer.cute_dsl.utils import is_cute_dsl_available
+
+    if is_cute_dsl_available():
+        from .kernels import (
+            grouped_gemm_nt_masked as grouped_gemm_nt_masked,
+            Sm100BlockScaledPersistentDenseGemmKernel as Sm100BlockScaledPersistentDenseGemmKernel,
+            create_scale_factor_tensor as create_scale_factor_tensor,
+        )
+except ImportError:
+    pass
+
 __all__ = [
     "SegmentGEMMWrapper",
     "bmm_bf16",
@@ -42,3 +55,16 @@
     "fp8_blockscale_gemm_sm90",
     "mm_M1_16_K7168_N128",
     "mm_M1_16_K7168_N256",
 ]
+
+# Add CuTe-DSL kernels to __all__ if available
+try:
+    from flashinfer.cute_dsl.utils import is_cute_dsl_available
+
+    if is_cute_dsl_available():
+        __all__ += [
+            "grouped_gemm_nt_masked",
+            "Sm100BlockScaledPersistentDenseGemmKernel",
+            "create_scale_factor_tensor",
+        ]
+except ImportError:
+    pass
diff --git a/flashinfer/cute_dsl/blockscaled_gemm.py b/flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py
similarity index 99%
rename from flashinfer/cute_dsl/blockscaled_gemm.py
rename to flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py
index d1843a33aa..fb6cb6360c 100644
--- a/flashinfer/cute_dsl/blockscaled_gemm.py
+++ b/flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py
@@ -57,7 +57,12 @@
 from flashinfer.utils import get_compute_capability
 from flashinfer.api_logging import flashinfer_api
 from cutlass.utils.static_persistent_tile_scheduler import WorkTileInfo
-from .utils import get_cutlass_dtype, cutlass_to_torch_dtype, get_num_sm, make_ptr
+from flashinfer.cute_dsl.utils import (
+    get_cutlass_dtype,
+    cutlass_to_torch_dtype,
+    get_num_sm,
+    make_ptr,
+)
 
 from typing import Callable, List
 
@@ -556,7 +561,8 @@ def __init__(
         )
         self.mma_warp_id = 4
         self.tma_warp_id = 5
-        self.threads_per_cta = 32 * len(
+        self.threads_per_warp = 32
+        self.threads_per_cta = self.threads_per_warp * len(
             (self.mma_warp_id, self.tma_warp_id, *self.epilog_warp_id)
         )
         # Set barrier id for cta sync, epilogue sync and tmem ptr sync
@@ -1062,8 +1068,10 @@ def kernel(
         # Initialize acc_pipeline (barrier) and states
         acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
-        num_acc_consumer_threads = len(self.epilog_warp_id) * (
-            2 if use_2cta_instrs else 1
+        num_acc_consumer_threads = (
+            self.threads_per_warp
+            * len(self.epilog_warp_id)
+            * (2 if use_2cta_instrs else 1)
         )
         acc_pipeline_consumer_group = pipeline.CooperativeGroup(
             pipeline.Agent.Thread, num_acc_consumer_threads
         )
@@ -1374,7 +1382,9 @@ def kernel(
         #
         # Bar sync for retrieve tensor memory ptr from shared mem
         #
-        tmem_ptr_read_threads = 32 * len((self.mma_warp_id, *self.epilog_warp_id))
+        tmem_ptr_read_threads = self.threads_per_warp * len(
+            (self.mma_warp_id, *self.epilog_warp_id)
+        )
         cute.arch.barrier(
             barrier_id=self.tmem_ptr_sync_bar_id,
             number_of_threads=tmem_ptr_read_threads,
@@ -1587,7 +1597,9 @@ def kernel(
         #
         # Bar sync for retrieve tensor memory ptr from shared memory
         #
-        tmem_ptr_read_threads = 32 * len((self.mma_warp_id, *self.epilog_warp_id))
+        tmem_ptr_read_threads = self.threads_per_warp * len(
+            (self.mma_warp_id, *self.epilog_warp_id)
+        )
         cute.arch.barrier(
             barrier_id=self.tmem_ptr_sync_bar_id,
             number_of_threads=tmem_ptr_read_threads,
@@ -1639,8 +1651,8 @@ def kernel(
         # Threads/warps participating in tma store pipeline
         c_producer_group = pipeline.CooperativeGroup(
             pipeline.Agent.Thread,
-            32 * len(self.epilog_warp_id),
-            32 * len(self.epilog_warp_id),
+            self.threads_per_warp * len(self.epilog_warp_id),
+            self.threads_per_warp * len(self.epilog_warp_id),
         )
         c_pipeline = pipeline.PipelineTmaStore.create(
             num_stages=self.num_c_stage,
@@ -1723,11 +1735,8 @@ def kernel(
                 tRS_sC[(None, None, None, c_buffer)],
             )
             # Fence and barrier to make sure shared memory store is visible to TMA store
-            cute.arch.fence_proxy(
-                cute.arch.ProxyKind.async_shared,
-                space=cute.arch.SharedSpace.shared_cta,
-            )
-            epilog_threads = 32 * len(self.epilog_warp_id)
+            cute.arch.fence_proxy("async.shared", space="cta")
+            epilog_threads = self.threads_per_warp * len(self.epilog_warp_id)
             cute.arch.barrier(
                 barrier_id=self.epilog_sync_bar_id,
                 number_of_threads=epilog_threads,
@@ -1793,8 +1802,7 @@ def kernel(
             #
             # Async arrive accumulator buffer empty
             #
-            with cute.arch.elect_one():
-                acc_pipeline.consumer_release(acc_consumer_state)
+            acc_pipeline.consumer_release(acc_consumer_state)
             acc_consumer_state.advance()
 
             #
@@ -1812,7 +1820,7 @@ def kernel(
         #
         if warp_idx == self.epilog_warp_id[0]:
             cute.arch.relinquish_tmem_alloc_permit(is_two_cta=use_2cta_instrs)
-            epilog_threads = 32 * len(self.epilog_warp_id)
+            epilog_threads = self.threads_per_warp * len(self.epilog_warp_id)
             cute.arch.barrier(
                 barrier_id=self.epilog_sync_bar_id, number_of_threads=epilog_threads
             )
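
Note: CuTe-DSL CooperativeGroup sizes and barrier arrivals count threads, and two of the expressions touched above counted warps. The sketch below is a plain-Python restatement of the corrected arithmetic, not kernel code; the warp tuple and flag values are illustrative, mirroring the kernel's four epilogue warps.

THREADS_PER_WARP = 32
epilog_warp_id = (0, 1, 2, 3)  # four epilogue warps, as in the kernel
use_2cta_instrs = True         # two CTAs cooperate on one MMA

# Old expression counted warps: len(epilog_warp_id) * 2 == 8 "consumers".
# Fixed expression counts threads, so the pipeline waits for all of them.
num_acc_consumer_threads = (
    THREADS_PER_WARP * len(epilog_warp_id) * (2 if use_2cta_instrs else 1)
)
assert num_acc_consumer_threads == 256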

From d58d2ce89d45356d7bc38c0d284a047570d110ab Mon Sep 17 00:00:00 2001
From: Brian Ryu
Date: Thu, 5 Feb 2026 19:50:59 +0000
Subject: [PATCH 2/5] Respond to review comments: restore backwards-compat
 module and simplify conditional exports

---
 flashinfer/cute_dsl/blockscaled_gemm.py | 38 +++++++++++++++++++++++++
 flashinfer/gemm/__init__.py             | 22 ++++++--------
 flashinfer/gemm/kernels/__init__.py     | 38 +++++++++++++++++++++++++
 3 files changed, 84 insertions(+), 14 deletions(-)
 create mode 100644 flashinfer/cute_dsl/blockscaled_gemm.py
 create mode 100644 flashinfer/gemm/kernels/__init__.py

diff --git a/flashinfer/cute_dsl/blockscaled_gemm.py b/flashinfer/cute_dsl/blockscaled_gemm.py
new file mode 100644
index 0000000000..b2cfb486ea
--- /dev/null
+++ b/flashinfer/cute_dsl/blockscaled_gemm.py
@@ -0,0 +1,38 @@
+# Copyright (c) 2025 by FlashInfer team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Backwards compatibility module.
+
+This module has been moved to flashinfer.gemm.kernels.grouped_gemm_masked_blackwell.
+All imports are re-exported here for backwards compatibility.
+"""
+
+# Re-export everything from the new location
+from flashinfer.gemm.kernels.grouped_gemm_masked_blackwell import (
+    grouped_gemm_nt_masked,
+    Sm100BlockScaledPersistentDenseGemmKernel,
+    create_scale_factor_tensor,
+    get_cute_dsl_compiled_masked_gemm_kernel,
+    MaskedSchedulerParams,
+    MaskedScheduler,
+)
+
+__all__ = [
+    "grouped_gemm_nt_masked",
+    "Sm100BlockScaledPersistentDenseGemmKernel",
+    "create_scale_factor_tensor",
+    "get_cute_dsl_compiled_masked_gemm_kernel",
+    "MaskedSchedulerParams",
+    "MaskedScheduler",
+]
diff --git a/flashinfer/gemm/__init__.py b/flashinfer/gemm/__init__.py
index d8132c167a..e3eea1eaf9 100644
--- a/flashinfer/gemm/__init__.py
+++ b/flashinfer/gemm/__init__.py
@@ -24,6 +24,7 @@
 )
 
 # Import CuTe-DSL kernels if available
+_cute_dsl_kernels = []
 try:
     from flashinfer.cute_dsl.utils import is_cute_dsl_available
 
@@ -33,6 +34,12 @@
             Sm100BlockScaledPersistentDenseGemmKernel as Sm100BlockScaledPersistentDenseGemmKernel,
             create_scale_factor_tensor as create_scale_factor_tensor,
         )
+
+        _cute_dsl_kernels = [
+            "grouped_gemm_nt_masked",
+            "Sm100BlockScaledPersistentDenseGemmKernel",
+            "create_scale_factor_tensor",
+        ]
 except ImportError:
     pass
 
@@ -54,17 +61,4 @@
     "fp8_blockscale_gemm_sm90",
     "mm_M1_16_K7168_N128",
     "mm_M1_16_K7168_N256",
-]
-
-# Add CuTe-DSL kernels to __all__ if available
-try:
-    from flashinfer.cute_dsl.utils import is_cute_dsl_available
-
-    if is_cute_dsl_available():
-        __all__ += [
-            "grouped_gemm_nt_masked",
-            "Sm100BlockScaledPersistentDenseGemmKernel",
-            "create_scale_factor_tensor",
-        ]
-except ImportError:
-    pass
+] + _cute_dsl_kernels
diff --git a/flashinfer/gemm/kernels/__init__.py b/flashinfer/gemm/kernels/__init__.py
new file mode 100644
index 0000000000..b2afdcac52
--- /dev/null
+++ b/flashinfer/gemm/kernels/__init__.py
@@ -0,0 +1,38 @@
+# Copyright (c) 2025 by FlashInfer team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+FlashInfer GEMM Kernels
+=======================
+
+This module provides high-performance GPU GEMM kernels implemented using NVIDIA CuTe-DSL.
+"""
+
+from flashinfer.cute_dsl.utils import is_cute_dsl_available
+
+# Conditionally import CuTe-DSL kernels
+if is_cute_dsl_available():
+    from .grouped_gemm_masked_blackwell import (
+        grouped_gemm_nt_masked,
+        Sm100BlockScaledPersistentDenseGemmKernel,
+        create_scale_factor_tensor,
+    )
+
+__all__ = []
+
+if is_cute_dsl_available():
+    __all__ += [
+        "grouped_gemm_nt_masked",
+        "Sm100BlockScaledPersistentDenseGemmKernel",
+        "create_scale_factor_tensor",
+    ]
+""" + +from flashinfer.cute_dsl.utils import is_cute_dsl_available + +# Conditionally import CuTe-DSL kernels +if is_cute_dsl_available(): + from .grouped_gemm_masked_blackwell import ( + grouped_gemm_nt_masked, + Sm100BlockScaledPersistentDenseGemmKernel, + create_scale_factor_tensor, + ) + +__all__ = [] + +if is_cute_dsl_available(): + __all__ += [ + "grouped_gemm_nt_masked", + "Sm100BlockScaledPersistentDenseGemmKernel", + "create_scale_factor_tensor", + ] From b75a4f8de15da6daa1f08d4ba7234c2e7cefad53 Mon Sep 17 00:00:00 2001 From: Brian Ryu Date: Thu, 5 Feb 2026 20:07:20 +0000 Subject: [PATCH 3/5] Fix benchmark off by 1000x issue --- benchmarks/bench_cute_dsl_blockscaled_gemm.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/benchmarks/bench_cute_dsl_blockscaled_gemm.py b/benchmarks/bench_cute_dsl_blockscaled_gemm.py index 0852ac5fee..e24b43b39b 100644 --- a/benchmarks/bench_cute_dsl_blockscaled_gemm.py +++ b/benchmarks/bench_cute_dsl_blockscaled_gemm.py @@ -54,24 +54,25 @@ def test_func(): use_cuda_graph=False, cold_l2_cache=True, ) - t = np.median(times) + t_ms = np.median(times) # bench_gpu_time returns milliseconds + t_s = t_ms / 1e3 # convert to seconds for downstream calculations valid_m = data["masked_m"].sum().item() - t_calibrated = t / valid_m * (expected_m_per_group * num_groups) + t_calibrated_s = t_s / valid_m * (expected_m_per_group * num_groups) - tflops = 2 * valid_m * n * k / t / 1e12 + tflops = 2 * valid_m * n * k / t_s / 1e12 gb_per_s = ( ( count_bytes(data["a"], data["c"]) * valid_m / (max_m * num_groups) + count_bytes(data["b"]) ) / 1e9 - / t + / t_s ) print( f" > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}): " - f"{t * 1e6:4.0f} us | {tflops:4.0f} TFLOPS | {gb_per_s:4.0f} GB/s" + f"{t_s * 1e6:4.0f} us | {tflops:4.0f} TFLOPS | {gb_per_s:4.0f} GB/s" ) metrics = dict( @@ -80,8 +81,8 @@ def test_func(): valid_m=valid_m, n=n, k=k, - t_us_raw=t * 1e6, - t_us_calibrated=t_calibrated * 1e6, + t_us_raw=t_s * 1e6, + t_us_calibrated=t_calibrated_s * 1e6, tflops=tflops, gb_per_s=gb_per_s, ) From cac838a7cb3981abd2119eda0d0a71375e1b4717 Mon Sep 17 00:00:00 2001 From: Brian Ryu Date: Fri, 6 Feb 2026 00:00:43 +0000 Subject: [PATCH 4/5] Add deprecation notice for non-recommended import path --- benchmarks/bench_cute_dsl_blockscaled_gemm.py | 2 +- flashinfer/cute_dsl/__init__.py | 18 +++++++++++++ flashinfer/cute_dsl/blockscaled_gemm.py | 15 +++++++++++ flashinfer/gemm/__init__.py | 2 +- flashinfer/gemm/kernels/__init__.py | 26 +++---------------- tests/gemm/test_cute_dsl_blockscaled_gemm.py | 2 +- 6 files changed, 40 insertions(+), 25 deletions(-) diff --git a/benchmarks/bench_cute_dsl_blockscaled_gemm.py b/benchmarks/bench_cute_dsl_blockscaled_gemm.py index e24b43b39b..162a5f07b9 100644 --- a/benchmarks/bench_cute_dsl_blockscaled_gemm.py +++ b/benchmarks/bench_cute_dsl_blockscaled_gemm.py @@ -3,7 +3,7 @@ import numpy as np import cutlass -from flashinfer.cute_dsl.blockscaled_gemm import ( +from flashinfer.gemm import ( create_scale_factor_tensor, grouped_gemm_nt_masked, # deepgemm-like python interface for DLFW integration ) diff --git a/flashinfer/cute_dsl/__init__.py b/flashinfer/cute_dsl/__init__.py index 8886324473..c97a1f1001 100644 --- a/flashinfer/cute_dsl/__init__.py +++ b/flashinfer/cute_dsl/__init__.py @@ -16,8 +16,15 @@ =========================== This module provides high-performance GPU kernels implemented using NVIDIA CuTe-DSL. + +.. 

From cac838a7cb3981ac7119eda0d0a71375e1b4717 Mon Sep 17 00:00:00 2001
From: Brian Ryu
Date: Fri, 6 Feb 2026 00:00:43 +0000
Subject: [PATCH 4/5] Add deprecation notice for non-recommended import path

---
 benchmarks/bench_cute_dsl_blockscaled_gemm.py |  2 +-
 flashinfer/cute_dsl/__init__.py               | 18 +++++++++++++
 flashinfer/cute_dsl/blockscaled_gemm.py       | 15 +++++++++++
 flashinfer/gemm/__init__.py                   |  2 +-
 flashinfer/gemm/kernels/__init__.py           | 26 +++----------------
 tests/gemm/test_cute_dsl_blockscaled_gemm.py  |  2 +-
 6 files changed, 40 insertions(+), 25 deletions(-)

diff --git a/benchmarks/bench_cute_dsl_blockscaled_gemm.py b/benchmarks/bench_cute_dsl_blockscaled_gemm.py
index e24b43b39b..162a5f07b9 100644
--- a/benchmarks/bench_cute_dsl_blockscaled_gemm.py
+++ b/benchmarks/bench_cute_dsl_blockscaled_gemm.py
@@ -3,7 +3,7 @@
 
 import numpy as np
 import cutlass
-from flashinfer.cute_dsl.blockscaled_gemm import (
+from flashinfer.gemm import (
     create_scale_factor_tensor,
     grouped_gemm_nt_masked,  # deepgemm-like python interface for DLFW integration
 )
diff --git a/flashinfer/cute_dsl/__init__.py b/flashinfer/cute_dsl/__init__.py
index 8886324473..c97a1f1001 100644
--- a/flashinfer/cute_dsl/__init__.py
+++ b/flashinfer/cute_dsl/__init__.py
@@ -16,8 +16,15 @@
 ===========================
 
 This module provides high-performance GPU kernels implemented using NVIDIA CuTe-DSL.
+
+.. deprecated::
+    Importing GEMM kernels from ``flashinfer.cute_dsl`` is deprecated.
+    Use ``flashinfer.gemm`` instead. The old import paths will be
+    removed in a future release.
 """
 
+import warnings
+
 from .utils import is_cute_dsl_available, make_ptr, get_cutlass_dtype, get_num_sm
 
 # Conditionally import CuTe-DSL kernels
@@ -28,6 +35,17 @@
         Sm100BlockScaledPersistentDenseGemmKernel,
         create_scale_factor_tensor,
     )
+
+    warnings.warn(
+        "Importing GEMM kernels (grouped_gemm_nt_masked, "
+        "Sm100BlockScaledPersistentDenseGemmKernel, create_scale_factor_tensor) "
+        "from flashinfer.cute_dsl is deprecated. "
+        "Use flashinfer.gemm instead. "
+        "The old import paths will be removed in a future release.",
+        DeprecationWarning,
+        stacklevel=2,
+    )
+
     from .rmsnorm_fp4quant import (
         rmsnorm_fp4quant,
diff --git a/flashinfer/cute_dsl/blockscaled_gemm.py b/flashinfer/cute_dsl/blockscaled_gemm.py
index b2cfb486ea..51a4d6f74d 100644
--- a/flashinfer/cute_dsl/blockscaled_gemm.py
+++ b/flashinfer/cute_dsl/blockscaled_gemm.py
@@ -15,9 +15,24 @@
 Backwards compatibility module.
 
 This module has been moved to flashinfer.gemm.kernels.grouped_gemm_masked_blackwell.
+Import from ``flashinfer.gemm`` for the public API.
 All imports are re-exported here for backwards compatibility.
+
+.. deprecated::
+    ``flashinfer.cute_dsl.blockscaled_gemm`` is deprecated.
+    Use ``flashinfer.gemm`` instead. This module will be removed in a future release.
 """
 
+import warnings
+
+warnings.warn(
+    "flashinfer.cute_dsl.blockscaled_gemm is deprecated. "
+    "Use flashinfer.gemm instead. "
+    "This module will be removed in a future release.",
+    DeprecationWarning,
+    stacklevel=2,
+)
+
 # Re-export everything from the new location
 from flashinfer.gemm.kernels.grouped_gemm_masked_blackwell import (
     grouped_gemm_nt_masked,
diff --git a/flashinfer/gemm/__init__.py b/flashinfer/gemm/__init__.py
index e3eea1eaf9..3bb70486d4 100644
--- a/flashinfer/gemm/__init__.py
+++ b/flashinfer/gemm/__init__.py
@@ -29,7 +29,7 @@
     from flashinfer.cute_dsl.utils import is_cute_dsl_available
 
     if is_cute_dsl_available():
-        from .kernels import (
+        from .kernels.grouped_gemm_masked_blackwell import (
             grouped_gemm_nt_masked as grouped_gemm_nt_masked,
             Sm100BlockScaledPersistentDenseGemmKernel as Sm100BlockScaledPersistentDenseGemmKernel,
             create_scale_factor_tensor as create_scale_factor_tensor,
diff --git a/flashinfer/gemm/kernels/__init__.py b/flashinfer/gemm/kernels/__init__.py
index b2afdcac52..28492745b4 100644
--- a/flashinfer/gemm/kernels/__init__.py
+++ b/flashinfer/gemm/kernels/__init__.py
@@ -12,27 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
-FlashInfer GEMM Kernels
-=======================
+FlashInfer GEMM Kernels (internal)
+==================================
 
-This module provides high-performance GPU GEMM kernels implemented using NVIDIA CuTe-DSL.
+Internal module containing GPU GEMM kernel implementations.
+Import from ``flashinfer.gemm`` for the public API.
 """
-
-from flashinfer.cute_dsl.utils import is_cute_dsl_available
-
-# Conditionally import CuTe-DSL kernels
-if is_cute_dsl_available():
-    from .grouped_gemm_masked_blackwell import (
-        grouped_gemm_nt_masked,
-        Sm100BlockScaledPersistentDenseGemmKernel,
-        create_scale_factor_tensor,
-    )
-
-__all__ = []
-
-if is_cute_dsl_available():
-    __all__ += [
-        "grouped_gemm_nt_masked",
-        "Sm100BlockScaledPersistentDenseGemmKernel",
-        "create_scale_factor_tensor",
-    ]
diff --git a/tests/gemm/test_cute_dsl_blockscaled_gemm.py b/tests/gemm/test_cute_dsl_blockscaled_gemm.py
index 30a59260d2..b8492d81d6 100644
--- a/tests/gemm/test_cute_dsl_blockscaled_gemm.py
+++ b/tests/gemm/test_cute_dsl_blockscaled_gemm.py
@@ -12,7 +12,7 @@
 import torch
 from cutlass.cute.runtime import from_dlpack
 
-from flashinfer.cute_dsl.blockscaled_gemm import (
+from flashinfer.gemm import (
     Sm100BlockScaledPersistentDenseGemmKernel,  # not used in python interface
     grouped_gemm_nt_masked,  # deepgemm-like python interface for DLFW integration
     create_scale_factor_tensor,
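
Note: a sketch of how the import-time warning added in patch 4 would surface to a caller; it assumes a CuTe-DSL-enabled install, and Python hides DeprecationWarning by default outside __main__, so callers typically opt in. Patch 5 below removes the runtime warning again, keeping only the docstring notices.

import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # opt in to DeprecationWarning
    import flashinfer.cute_dsl.blockscaled_gemm  # noqa: F401

# Module-level warnings fire on first import only; cached modules do not re-warn.
assert any(issubclass(w.category, DeprecationWarning) for w in caught)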
""" - -from flashinfer.cute_dsl.utils import is_cute_dsl_available - -# Conditionally import CuTe-DSL kernels -if is_cute_dsl_available(): - from .grouped_gemm_masked_blackwell import ( - grouped_gemm_nt_masked, - Sm100BlockScaledPersistentDenseGemmKernel, - create_scale_factor_tensor, - ) - -__all__ = [] - -if is_cute_dsl_available(): - __all__ += [ - "grouped_gemm_nt_masked", - "Sm100BlockScaledPersistentDenseGemmKernel", - "create_scale_factor_tensor", - ] diff --git a/tests/gemm/test_cute_dsl_blockscaled_gemm.py b/tests/gemm/test_cute_dsl_blockscaled_gemm.py index 30a59260d2..b8492d81d6 100644 --- a/tests/gemm/test_cute_dsl_blockscaled_gemm.py +++ b/tests/gemm/test_cute_dsl_blockscaled_gemm.py @@ -12,7 +12,7 @@ import torch from cutlass.cute.runtime import from_dlpack -from flashinfer.cute_dsl.blockscaled_gemm import ( +from flashinfer.gemm import ( Sm100BlockScaledPersistentDenseGemmKernel, # not used in python interface grouped_gemm_nt_masked, # deepgemm-like python interface for DLFW integration create_scale_factor_tensor, From a8f6dda0805c4b1827ff42c19ed95ea4c5e4645b Mon Sep 17 00:00:00 2001 From: Brian Ryu Date: Fri, 6 Feb 2026 00:17:29 +0000 Subject: [PATCH 5/5] Remove warnings.warn --- flashinfer/cute_dsl/__init__.py | 22 ++++++---------------- flashinfer/cute_dsl/blockscaled_gemm.py | 10 ---------- 2 files changed, 6 insertions(+), 26 deletions(-) diff --git a/flashinfer/cute_dsl/__init__.py b/flashinfer/cute_dsl/__init__.py index c97a1f1001..1cd587b5aa 100644 --- a/flashinfer/cute_dsl/__init__.py +++ b/flashinfer/cute_dsl/__init__.py @@ -18,34 +18,24 @@ This module provides high-performance GPU kernels implemented using NVIDIA CuTe-DSL. .. deprecated:: - Importing GEMM kernels from ``flashinfer.cute_dsl`` is deprecated. + Importing GEMM kernels (``grouped_gemm_nt_masked``, + ``Sm100BlockScaledPersistentDenseGemmKernel``, ``create_scale_factor_tensor``) + from ``flashinfer.cute_dsl`` is deprecated. Use ``flashinfer.gemm`` instead. The old import paths will be removed in a future release. """ -import warnings - from .utils import is_cute_dsl_available, make_ptr, get_cutlass_dtype, get_num_sm # Conditionally import CuTe-DSL kernels if is_cute_dsl_available(): - # Re-export from new location for backwards compatibility + # Deprecated GEMM symbols: re-exported for backwards compatibility. + # Use flashinfer.gemm instead. from .blockscaled_gemm import ( grouped_gemm_nt_masked, Sm100BlockScaledPersistentDenseGemmKernel, create_scale_factor_tensor, ) - - warnings.warn( - "Importing GEMM kernels (grouped_gemm_nt_masked, " - "Sm100BlockScaledPersistentDenseGemmKernel, create_scale_factor_tensor) " - "from flashinfer.cute_dsl is deprecated. " - "Use flashinfer.gemm instead. " - "The old import paths will be removed in a future release.", - DeprecationWarning, - stacklevel=2, - ) - from .rmsnorm_fp4quant import ( rmsnorm_fp4quant, RMSNormFP4QuantKernel, @@ -66,7 +56,7 @@ if is_cute_dsl_available(): __all__ += [ - # Blockscaled GEMM + # Blockscaled GEMM (deprecated, use flashinfer.gemm instead) "grouped_gemm_nt_masked", "Sm100BlockScaledPersistentDenseGemmKernel", "create_scale_factor_tensor", diff --git a/flashinfer/cute_dsl/blockscaled_gemm.py b/flashinfer/cute_dsl/blockscaled_gemm.py index 51a4d6f74d..e7a37ab9d7 100644 --- a/flashinfer/cute_dsl/blockscaled_gemm.py +++ b/flashinfer/cute_dsl/blockscaled_gemm.py @@ -23,16 +23,6 @@ Use ``flashinfer.gemm`` instead. This module will be removed in a future release. 
""" -import warnings - -warnings.warn( - "flashinfer.cute_dsl.blockscaled_gemm is deprecated. " - "Use flashinfer.gemm instead. " - "This module will be removed in a future release.", - DeprecationWarning, - stacklevel=2, -) - # Re-export everything from the new location from flashinfer.gemm.kernels.grouped_gemm_masked_blackwell import ( grouped_gemm_nt_masked,