Skip to content

Commit e16e85b

Browse files
committed
Commit
1 parent dbeb431 commit e16e85b

6 files changed

Lines changed: 17 additions & 18 deletions

File tree

flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -89,8 +89,8 @@
8989
st_global_u64,
9090
scatter_add_bf16x2,
9191
)
92-
from flashinfer.gemm.kernels.dense_blockscaled_gemm_sm120 import (
93-
Sm120BlockScaledDenseGemmKernel as DenseGemmKernel,
92+
from flashinfer.gemm.kernels.dense_blockscaled_gemm_sm120_b12x import (
93+
Sm120B12xBlockScaledDenseGemmKernel as DenseGemmKernel,
9494
)
9595

9696

flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,8 @@
122122
st_global_u64,
123123
scatter_add_bf16x2,
124124
)
125-
from flashinfer.gemm.kernels.dense_blockscaled_gemm_sm120 import (
126-
Sm120BlockScaledDenseGemmKernel as DenseGemmKernel,
125+
from flashinfer.gemm.kernels.dense_blockscaled_gemm_sm120_b12x import (
126+
Sm120B12xBlockScaledDenseGemmKernel as DenseGemmKernel,
127127
)
128128

129129

flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_static_kernel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,8 @@
120120
st_global_u64,
121121
scatter_add_bf16x2,
122122
)
123-
from flashinfer.gemm.kernels.dense_blockscaled_gemm_sm120 import (
124-
Sm120BlockScaledDenseGemmKernel as DenseGemmKernel,
123+
from flashinfer.gemm.kernels.dense_blockscaled_gemm_sm120_b12x import (
124+
Sm120B12xBlockScaledDenseGemmKernel as DenseGemmKernel,
125125
)
126126

127127

flashinfer/gemm/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,11 @@
6161
from flashinfer.cute_dsl.utils import is_cute_dsl_available
6262

6363
if is_cute_dsl_available():
64-
from .kernels.dense_blockscaled_gemm_sm120 import (
65-
Sm120BlockScaledDenseGemmKernel as Sm120BlockScaledDenseGemmKernel,
64+
from .kernels.dense_blockscaled_gemm_sm120_b12x import (
65+
Sm120B12xBlockScaledDenseGemmKernel as Sm120B12xBlockScaledDenseGemmKernel,
6666
)
6767

68-
_cute_dsl_kernels.append("Sm120BlockScaledDenseGemmKernel")
68+
_cute_dsl_kernels.append("Sm120B12xBlockScaledDenseGemmKernel")
6969
except ImportError:
7070
pass
7171

flashinfer/gemm/gemm_base.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4858,8 +4858,8 @@ def _b12x_gemm_fp4_runner(
48584858
"""
48594859
import cutlass
48604860

4861-
from .kernels.dense_blockscaled_gemm_sm120 import (
4862-
Sm120BlockScaledDenseGemmKernel,
4861+
from .kernels.dense_blockscaled_gemm_sm120_b12x import (
4862+
Sm120B12xBlockScaledDenseGemmKernel,
48634863
)
48644864

48654865
cutlass_dtype_attr = _TORCH_TO_CUTLASS_DTYPE_ATTR.get(out_dtype)
@@ -4905,7 +4905,7 @@ def get_valid_tactics(
49054905
]
49064906
swap_ab = False
49074907
for mma_tiler_mn in sm120_mma_tiler_candidates:
4908-
if not Sm120BlockScaledDenseGemmKernel.can_implement(
4908+
if not Sm120B12xBlockScaledDenseGemmKernel.can_implement(
49094909
ab_dtype,
49104910
sf_dtype,
49114911
sf_vec_size,
@@ -4945,11 +4945,10 @@ def forward(
49454945
batch_size = 1
49464946

49474947
if tactic is None or tactic == -1:
4948-
_sm_count = torch.cuda.get_device_properties(
4949-
a.device
4950-
).multi_processor_count
49514948
tactic = (
4952-
_select_default_sm120_mma_tiler(m, n, _sm_count),
4949+
_select_default_sm120_mma_tiler(
4950+
m, n, get_device_sm_count(a.device)
4951+
),
49534952
(1, 1),
49544953
False,
49554954
False,
@@ -4987,7 +4986,7 @@ def forward(
49874986
out_dtype,
49884987
)
49894988

4990-
make_kernel = lambda: Sm120BlockScaledDenseGemmKernel(
4989+
make_kernel = lambda: Sm120B12xBlockScaledDenseGemmKernel(
49914990
sf_vec_size,
49924991
mma_tiler_mn,
49934992
cluster_shape_mn,

flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py renamed to flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120_b12x.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1550,7 +1550,7 @@ def wrapper(
15501550

15511551

15521552
# Alias for FlashInfer integration
1553-
Sm120BlockScaledDenseGemmKernel = DenseGemmKernel
1553+
Sm120B12xBlockScaledDenseGemmKernel = DenseGemmKernel
15541554

15551555

15561556
class _DenseGemmLaunch:

0 commit comments

Comments
 (0)