@@ -4858,8 +4858,8 @@ def _b12x_gemm_fp4_runner(
48584858 """
48594859 import cutlass
48604860
4861- from .kernels .dense_blockscaled_gemm_sm120 import (
4862- Sm120BlockScaledDenseGemmKernel ,
4861+ from .kernels .dense_blockscaled_gemm_sm120_b12x import (
4862+ Sm120B12xBlockScaledDenseGemmKernel ,
48634863 )
48644864
48654865 cutlass_dtype_attr = _TORCH_TO_CUTLASS_DTYPE_ATTR .get (out_dtype )
@@ -4905,7 +4905,7 @@ def get_valid_tactics(
49054905 ]
49064906 swap_ab = False
49074907 for mma_tiler_mn in sm120_mma_tiler_candidates :
4908- if not Sm120BlockScaledDenseGemmKernel .can_implement (
4908+ if not Sm120B12xBlockScaledDenseGemmKernel .can_implement (
49094909 ab_dtype ,
49104910 sf_dtype ,
49114911 sf_vec_size ,
@@ -4945,11 +4945,10 @@ def forward(
49454945 batch_size = 1
49464946
49474947 if tactic is None or tactic == - 1 :
4948- _sm_count = torch .cuda .get_device_properties (
4949- a .device
4950- ).multi_processor_count
49514948 tactic = (
4952- _select_default_sm120_mma_tiler (m , n , _sm_count ),
4949+ _select_default_sm120_mma_tiler (
4950+ m , n , get_device_sm_count (a .device )
4951+ ),
49534952 (1 , 1 ),
49544953 False ,
49554954 False ,
@@ -4987,7 +4986,7 @@ def forward(
49874986 out_dtype ,
49884987 )
49894988
4990- make_kernel = lambda : Sm120BlockScaledDenseGemmKernel (
4989+ make_kernel = lambda : Sm120B12xBlockScaledDenseGemmKernel (
49914990 sf_vec_size ,
49924991 mma_tiler_mn ,
49934992 cluster_shape_mn ,