@@ -714,8 +714,11 @@ void gemm_dispatch_sm89(void* mat_a, void* mat_b, void* mat_d, float* scales_a,
714714 TLLM_CHECK_WITH_INFO (result == cudaSuccess, " sm89 gemm kernel runtime error: %s" , cudaGetErrorString (result));
715715}
716716
717- void gemm_dispatch_sm120 (void * mat_a, void * mat_b, void * mat_d, float * scales_a, float * scales_b, uint32_t shape_m,
718- uint32_t shape_n, uint32_t shape_k, cudaStream_t stream, int num_device_sms = kNumDeviceSMs )
717+ template <int TileM, int TileN, int NumStages>
718+ void launch_sm120_gemm_kernel (__nv_fp8_e4m3* mat_a, int64_t ld_a, int64_t stride_a, __nv_fp8_e4m3* mat_b, int64_t ld_b,
719+ int64_t stride_b, __nv_bfloat16* mat_d, int64_t ld_d, int64_t stride_d, float * scales_a, int64_t stride_scales_a,
720+ float * scales_b, int64_t stride_scales_b, uint32_t num_problems, uint32_t shape_m, uint32_t shape_n,
721+ uint32_t shape_k, cudaStream_t stream, int num_device_sms = kNumDeviceSMs )
719722{
720723 if (num_device_sms < 0 )
721724 {
@@ -725,25 +728,19 @@ void gemm_dispatch_sm120(void* mat_a, void* mat_b, void* mat_d, float* scales_a,
725728 using ElementOutput = cute::bfloat16_t ;
726729 using ElementAccum = float ;
727730 using ElementBlockScale = int32_t ;
728- using KT = sm120_blockscaled_gemm::SM120BlockScaledBuilder<64 , 128 , 4 >;
731+ using KT = sm120_blockscaled_gemm::SM120BlockScaledBuilder<TileM, TileN, NumStages >;
729732 using GemmKernel = sm120_blockscaled_gemm::SM120BlockScaledKernel<KT>;
730733 using Params = typename GemmKernel::Params;
731734 using Arguments = typename GemmKernel::Arguments;
732735 using ProblemShape = typename GemmKernel::ProblemShape;
733- ProblemShape problem_shape = make_shape ((int ) shape_m, (int ) shape_n, (int ) shape_k, 1 );
736+ ProblemShape problem_shape = make_shape ((int ) shape_m, (int ) shape_n, (int ) shape_k, ( int ) num_problems );
734737
735738 auto ptr_A = reinterpret_cast <ElementInput*>(mat_a);
736739 auto ptr_B = reinterpret_cast <ElementInput*>(mat_b);
737740 auto ptr_SFA = reinterpret_cast <ElementBlockScale*>(scales_a);
738741 auto ptr_SFB = reinterpret_cast <ElementBlockScale*>(scales_b);
739742 auto ptr_D = reinterpret_cast <ElementOutput*>(mat_d);
740743
741- int64_t ld_a = static_cast <int64_t >(shape_k);
742- int64_t ld_b = static_cast <int64_t >(shape_k);
743- int64_t ld_d = static_cast <int64_t >(shape_n);
744- int64_t stride_a = static_cast <int64_t >(shape_m) * ld_a;
745- int64_t stride_b = static_cast <int64_t >(shape_n) * ld_b;
746- int64_t stride_d = static_cast <int64_t >(shape_m) * ld_d;
747744 typename KT::StrideA dA = make_stride (ld_a, Int<1 >{}, stride_a);
748745 typename KT::StrideB dB = make_stride (ld_b, Int<1 >{}, stride_b);
749746 typename KT::StrideSFA dSFA = KT::deduce_sfa_layout (problem_shape).stride ();
@@ -777,6 +774,35 @@ void gemm_dispatch_sm120(void* mat_a, void* mat_b, void* mat_d, float* scales_a,
777774 TLLM_CHECK_WITH_INFO (result == cudaSuccess, " sm120 gemm kernel runtime error: %s" , cudaGetErrorString (result));
778775}
779776
777+ void gemm_dispatch_sm120 (void * mat_a, void * mat_b, void * mat_d, float * scales_a, float * scales_b, uint32_t shape_m,
778+ uint32_t shape_n, uint32_t shape_k, cudaStream_t stream, int num_device_sms = kNumDeviceSMs )
779+ {
780+ if (num_device_sms < 0 )
781+ {
782+ num_device_sms = kNumDeviceSMs = tensorrt_llm::common::getMultiProcessorCount ();
783+ }
784+
785+ auto * a = reinterpret_cast <__nv_fp8_e4m3*>(mat_a);
786+ auto * b = reinterpret_cast <__nv_fp8_e4m3*>(mat_b);
787+ auto * d = reinterpret_cast <__nv_bfloat16*>(mat_d);
788+ int64_t ld_a = shape_k;
789+ int64_t ld_b = shape_k;
790+ int64_t ld_d = shape_n;
791+ constexpr int64_t stride = 0 ;
792+ constexpr uint32_t num_problems = 1 ;
793+
794+ if (shape_m <= 64 )
795+ {
796+ launch_sm120_gemm_kernel<32 , 128 , 4 >(a, ld_a, stride, b, ld_b, stride, d, ld_d, stride, scales_a, stride,
797+ scales_b, stride, num_problems, shape_m, shape_n, shape_k, stream, num_device_sms);
798+ }
799+ else
800+ {
801+ launch_sm120_gemm_kernel<64 , 128 , 4 >(a, ld_a, stride, b, ld_b, stride, d, ld_d, stride, scales_a, stride,
802+ scales_b, stride, num_problems, shape_m, shape_n, shape_k, stream, num_device_sms);
803+ }
804+ }
805+
780806void fp8_gemm_run (__nv_fp8_e4m3* mat_a, int ld_a, __nv_fp8_e4m3* mat_b, int ld_b, __nv_bfloat16* mat_d, int ld_d,
781807 uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, float * scales_a, float * scales_b, cudaStream_t stream)
782808{
@@ -882,7 +908,7 @@ void grouped_gemm_dispatch_sm120(__nv_fp8_e4m3* mat_a, __nv_fp8_e4m3* mat_b, __n
882908 // so we can promise m_padded < max_shape_m_padded
883909 int64_t m_padded = sm120_blockscaled_gemm::compute_padded_offset (max_shape_m, num_problems);
884910
885- using KT = sm120_blockscaled_gemm::SM120BlockScaledBuilder<64 , 128 , 4 >;
911+ using KT = sm120_blockscaled_gemm::SM120BlockScaledBuilder<32 , 128 , 4 >;
886912 using GemmKernel = sm120_blockscaled_gemm::SM120BlockScaledMoeKernel<KT>;
887913 using Params = typename GemmKernel::Params;
888914 using Arguments = typename GemmKernel::Arguments;
@@ -1104,54 +1130,17 @@ void strided_batch_gemm_dispatch_sm120(__nv_fp8_e4m3* mat_a, int ld_a, int strid
11041130 {
11051131 num_device_sms = kNumDeviceSMs = tensorrt_llm::common::getMultiProcessorCount ();
11061132 }
1107- using ElementInput = cute::float_e4m3_t ;
1108- using ElementOutput = cute::bfloat16_t ;
1109- using ElementAccum = float ;
1110- using ElementBlockScale = int32_t ;
1111- using KT = sm120_blockscaled_gemm::SM120BlockScaledBuilder<64 , 128 , 4 >;
1112- using GemmKernel = sm120_blockscaled_gemm::SM120BlockScaledKernel<KT>;
1113- using Params = typename GemmKernel::Params;
1114- using Arguments = typename GemmKernel::Arguments;
1115- using ProblemShape = typename GemmKernel::ProblemShape;
1116- ProblemShape problem_shape = make_shape ((int ) shape_m, (int ) shape_n, (int ) shape_k, (int ) num_problems);
1117-
1118- auto ptr_A = reinterpret_cast <ElementInput*>(mat_a);
1119- auto ptr_B = reinterpret_cast <ElementInput*>(mat_b);
1120- auto ptr_SFA = reinterpret_cast <ElementBlockScale*>(scales_a);
1121- auto ptr_SFB = reinterpret_cast <ElementBlockScale*>(scales_b);
1122- auto ptr_D = reinterpret_cast <ElementOutput*>(mat_d);
1123-
1124- typename KT::StrideA dA = make_stride (static_cast <int64_t >(ld_a), Int<1 >{}, static_cast <int64_t >(stride_a));
1125- typename KT::StrideB dB = make_stride (static_cast <int64_t >(ld_b), Int<1 >{}, static_cast <int64_t >(stride_b));
1126- typename KT::StrideSFA dSFA = KT::deduce_sfa_layout (problem_shape).stride ();
1127- typename KT::StrideSFB dSFB = KT::deduce_sfb_layout (problem_shape).stride ();
1128- typename KT::StrideD dD = make_stride (static_cast <int64_t >(ld_d), Int<1 >{}, static_cast <int64_t >(stride_d));
1129-
1130- Arguments args = {ptr_A, dA, ptr_B, dB, ptr_SFA, dSFA, ptr_SFB, dSFB, ptr_D, dD};
1131-
1132- Params kernel_params = GemmKernel::to_underlying_arguments (problem_shape, args);
1133- auto kernel_ptr = &cutlass::device_kernel<GemmKernel>;
1134-
1135- cudaFuncSetAttribute (kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, GemmKernel::kSmemSize );
1136- auto result = cudaGetLastError ();
1137- TLLM_CHECK_WITH_INFO (result == cudaSuccess, " sm120 gemm kernel cannot launch: %s" , cudaGetErrorString (result));
1138-
1139- cudaLaunchConfig_t launch_config;
1140- cudaLaunchAttribute attrs[1 ];
1141- attrs[0 ].id = cudaLaunchAttributeProgrammaticStreamSerialization;
1142- attrs[0 ].val .programmaticStreamSerializationAllowed = 1 ;
1143-
1144- launch_config.gridDim = dim3 (num_device_sms, 1 , 1 );
1145- launch_config.blockDim = GemmKernel::get_block_shape ();
1146- launch_config.dynamicSmemBytes = GemmKernel::kSmemSize ;
1147- launch_config.stream = stream;
1148- launch_config.attrs = attrs;
1149- launch_config.numAttrs = 1 ;
11501133
1151- cudaLaunchKernelEx (&launch_config, kernel_ptr, kernel_params);
1152-
1153- result = cudaGetLastError ();
1154- TLLM_CHECK_WITH_INFO (result == cudaSuccess, " sm120 gemm kernel runtime error: %s" , cudaGetErrorString (result));
1134+ if (shape_m <= 64 )
1135+ {
1136+ launch_sm120_gemm_kernel<32 , 128 , 4 >(mat_a, ld_a, stride_a, mat_b, ld_b, stride_b, mat_d, ld_d, stride_d,
1137+ scales_a, stride_scales_a, scales_b, 0 , num_problems, shape_m, shape_n, shape_k, stream, num_device_sms);
1138+ }
1139+ else
1140+ {
1141+ launch_sm120_gemm_kernel<64 , 128 , 4 >(mat_a, ld_a, stride_a, mat_b, ld_b, stride_b, mat_d, ld_d, stride_d,
1142+ scales_a, stride_scales_a, scales_b, 0 , num_problems, shape_m, shape_n, shape_k, stream, num_device_sms);
1143+ }
11551144}
11561145
11571146void fp8_stride_batch_gemm_run (__nv_bfloat16 const * mat_a, __nv_fp8_e4m3* fp8_mat_a, float * scales_a, int ld_a,
0 commit comments