3131#include " fp8_blockscale_mma_utils.cuh"
3232#include " fp8_blockscale_tma_utils.cuh"
3333#include " sm120_blockwise_gemm/sm120_fp8_gemm_1d1d.cuh"
34+ #include " sm120_blockwise_gemm/sm120_fp8_moe_gemm_1d1d.cuh"
3435#include " tensorrt_llm/common/config.h"
3536#include " tensorrt_llm/common/cudaTypeUtils.cuh"
3637#include " tensorrt_llm/common/cudaUtils.h"
@@ -713,8 +714,11 @@ void gemm_dispatch_sm89(void* mat_a, void* mat_b, void* mat_d, float* scales_a,
713714 TLLM_CHECK_WITH_INFO (result == cudaSuccess, " sm89 gemm kernel runtime error: %s" , cudaGetErrorString (result));
714715}
715716
716- void gemm_dispatch_sm120 (void * mat_a, void * mat_b, void * mat_d, float * scales_a, float * scales_b, uint32_t shape_m,
717- uint32_t shape_n, uint32_t shape_k, cudaStream_t stream, int num_device_sms = kNumDeviceSMs )
717+ template <int TileM, int TileN, int NumStages>
718+ void launch_sm120_gemm_kernel (__nv_fp8_e4m3* mat_a, int64_t ld_a, int64_t stride_a, __nv_fp8_e4m3* mat_b, int64_t ld_b,
719+ int64_t stride_b, __nv_bfloat16* mat_d, int64_t ld_d, int64_t stride_d, float * scales_a, int64_t stride_scales_a,
720+ float * scales_b, int64_t stride_scales_b, uint32_t num_problems, uint32_t shape_m, uint32_t shape_n,
721+ uint32_t shape_k, cudaStream_t stream, int num_device_sms = kNumDeviceSMs )
718722{
719723 if (num_device_sms < 0 )
720724 {
@@ -724,26 +728,19 @@ void gemm_dispatch_sm120(void* mat_a, void* mat_b, void* mat_d, float* scales_a,
724728 using ElementOutput = cute::bfloat16_t ;
725729 using ElementAccum = float ;
726730 using ElementBlockScale = int32_t ;
727- using KT = sm120_blockscaled_gemm::SM120BlockScaledBuilder<32 , 128 >;
731+ using KT = sm120_blockscaled_gemm::SM120BlockScaledBuilder<TileM, TileN, NumStages >;
728732 using GemmKernel = sm120_blockscaled_gemm::SM120BlockScaledKernel<KT>;
729733 using Params = typename GemmKernel::Params;
730734 using Arguments = typename GemmKernel::Arguments;
731735 using ProblemShape = typename GemmKernel::ProblemShape;
732- ProblemShape problem_shape = make_shape ((int ) shape_m, (int ) shape_n, (int ) shape_k, 1 );
736+ ProblemShape problem_shape = make_shape ((int ) shape_m, (int ) shape_n, (int ) shape_k, ( int ) num_problems );
733737
734738 auto ptr_A = reinterpret_cast <ElementInput*>(mat_a);
735739 auto ptr_B = reinterpret_cast <ElementInput*>(mat_b);
736740 auto ptr_SFA = reinterpret_cast <ElementBlockScale*>(scales_a);
737741 auto ptr_SFB = reinterpret_cast <ElementBlockScale*>(scales_b);
738742 auto ptr_D = reinterpret_cast <ElementOutput*>(mat_d);
739743
740- int32_t ld_a = shape_k;
741- int32_t stride_a = shape_m * shape_k;
742- int32_t ld_b = shape_k;
743- int32_t stride_b = shape_n * shape_k;
744- int32_t ld_d = shape_n;
745- int32_t stride_d = shape_m * shape_n;
746-
747744 typename KT::StrideA dA = make_stride (ld_a, Int<1 >{}, stride_a);
748745 typename KT::StrideB dB = make_stride (ld_b, Int<1 >{}, stride_b);
749746 typename KT::StrideSFA dSFA = KT::deduce_sfa_layout (problem_shape).stride ();
@@ -764,7 +761,7 @@ void gemm_dispatch_sm120(void* mat_a, void* mat_b, void* mat_d, float* scales_a,
764761 attrs[0 ].id = cudaLaunchAttributeProgrammaticStreamSerialization;
765762 attrs[0 ].val .programmaticStreamSerializationAllowed = 1 ;
766763
767- launch_config.gridDim = GemmKernel::get_grid_shape (kernel_params );
764+ launch_config.gridDim = dim3 (num_device_sms, 1 , 1 );
768765 launch_config.blockDim = GemmKernel::get_block_shape ();
769766 launch_config.dynamicSmemBytes = GemmKernel::kSmemSize ;
770767 launch_config.stream = stream;
@@ -777,6 +774,35 @@ void gemm_dispatch_sm120(void* mat_a, void* mat_b, void* mat_d, float* scales_a,
777774 TLLM_CHECK_WITH_INFO (result == cudaSuccess, " sm120 gemm kernel runtime error: %s" , cudaGetErrorString (result));
778775}
779776
// Dispatches a single (non-grouped) FP8 block-scaled GEMM on SM120.
//
// mat_a: [shape_m, shape_k] fp8_e4m3, row-major (ld = shape_k)
// mat_b: [shape_n, shape_k] fp8_e4m3, row-major (ld = shape_k)
// mat_d: [shape_m, shape_n] bf16 output (ld = shape_n)
// scales_a / scales_b: per-block (1x128) scaling factors consumed by the kernel.
// num_device_sms: pass a negative value to auto-detect (result is cached in the
//                 module-level kNumDeviceSMs).
//
// Tile-M is selected from the problem's M extent: small-M problems use a
// 32-wide M tile to keep more CTAs resident; larger M uses 64.
void gemm_dispatch_sm120(void* mat_a, void* mat_b, void* mat_d, float* scales_a, float* scales_b, uint32_t shape_m,
    uint32_t shape_n, uint32_t shape_k, cudaStream_t stream, int num_device_sms = kNumDeviceSMs)
{
    if (num_device_sms < 0)
    {
        // Lazily query and cache the SM count for the persistent-kernel grid size.
        num_device_sms = kNumDeviceSMs = tensorrt_llm::common::getMultiProcessorCount();
    }

    auto* a = reinterpret_cast<__nv_fp8_e4m3*>(mat_a);
    auto* b = reinterpret_cast<__nv_fp8_e4m3*>(mat_b);
    auto* d = reinterpret_cast<__nv_bfloat16*>(mat_d);

    // Row-major leading dimensions for the single problem.
    int64_t ld_a = shape_k;
    int64_t ld_b = shape_k;
    int64_t ld_d = shape_n;

    // Single-problem case: batch strides are unused, so pass 0.
    constexpr int64_t stride = 0;
    constexpr uint32_t num_problems = 1;

    if (shape_m <= 64)
    {
        launch_sm120_gemm_kernel<32, 128, 4>(a, ld_a, stride, b, ld_b, stride, d, ld_d, stride, scales_a, stride,
            scales_b, stride, num_problems, shape_m, shape_n, shape_k, stream, num_device_sms);
    }
    else
    {
        launch_sm120_gemm_kernel<64, 128, 4>(a, ld_a, stride, b, ld_b, stride, d, ld_d, stride, scales_a, stride,
            scales_b, stride, num_problems, shape_m, shape_n, shape_k, stream, num_device_sms);
    }
}
805+
780806void fp8_gemm_run (__nv_fp8_e4m3* mat_a, int ld_a, __nv_fp8_e4m3* mat_b, int ld_b, __nv_bfloat16* mat_d, int ld_d,
781807 uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, float * scales_a, float * scales_b, cudaStream_t stream)
782808{
@@ -866,6 +892,81 @@ void grouped_gemm_dispatch(__nv_fp8_e4m3* mat_a, __nv_fp8_e4m3* mat_b, __nv_bflo
866892 }
867893}
868894
// Dispatches the SM120 FP8 block-scaled grouped (MoE) GEMM.
//
// mat_a: concatenated activations for all experts, [max_shape_m, shape_k] fp8.
// mat_b: per-expert weights, [num_problems, shape_n, shape_k] fp8.
// mat_d: concatenated outputs, [max_shape_m, shape_n] bf16.
// problem_m_offsets: device array (num_problems + 1 entries implied by the
//                    kernel contract) giving each expert's token range in A/D.
// expected_m / max_shape_m_padded: currently unused by this dispatch path;
//                    kept for signature parity with the non-SM120 dispatcher.
// num_device_sms: pass a negative value to auto-detect (cached in kNumDeviceSMs).
//
// Launches a persistent kernel: one CTA per SM, work distribution handled
// inside SM120BlockScaledMoeKernel via problem_m_offsets.
void grouped_gemm_dispatch_sm120(__nv_fp8_e4m3* mat_a, __nv_fp8_e4m3* mat_b, __nv_bfloat16* mat_d,
    uint32_t num_problems, int64_t const* problem_m_offsets, uint32_t expected_m, uint32_t max_shape_m,
    uint32_t max_shape_m_padded, uint32_t shape_n, uint32_t shape_k, float* scales_a, float* scales_b,
    cudaStream_t stream, int num_device_sms = kNumDeviceSMs)
{
    if (num_device_sms < 0)
    {
        num_device_sms = kNumDeviceSMs = tensorrt_llm::common::getMultiProcessorCount();
    }

    int64_t total_tokens = static_cast<int64_t>(max_shape_m);
    // max_shape_m_padded = (max_shape_m + num_problems * 31) / 32 * 32
    // m_padded = (total_tokens + num_problems * 3) / 4 * 4;
    // so we can promise m_padded < max_shape_m_padded
    int64_t m_padded = sm120_blockscaled_gemm::compute_padded_offset(max_shape_m, num_problems);

    using KT = sm120_blockscaled_gemm::SM120BlockScaledBuilder<32, 128, 4>;
    using GemmKernel = sm120_blockscaled_gemm::SM120BlockScaledMoeKernel<KT>;
    using Params = typename GemmKernel::Params;
    using Arguments = typename GemmKernel::Arguments;
    using ProblemShape = typename GemmKernel::ProblemShape;

    ProblemShape problem_shape = make_shape(static_cast<int>(total_tokens), static_cast<int>(shape_n),
        static_cast<int>(shape_k), static_cast<int>(num_problems));

    auto ptr_A = reinterpret_cast<typename KT::ElementA*>(mat_a);
    auto ptr_B = reinterpret_cast<typename KT::ElementB*>(mat_b);
    auto ptr_SFA = reinterpret_cast<typename KT::ElementSFLoad*>(scales_a);
    auto ptr_SFB = reinterpret_cast<typename KT::ElementSFLoad*>(scales_b);
    auto ptr_D = reinterpret_cast<typename KT::ElementD*>(mat_d);

    // Row-major leading dimensions; A/D are shared across experts (strided by
    // the full token count), B is strided per expert.
    int64_t ld_a = static_cast<int64_t>(shape_k);
    int64_t ld_b = static_cast<int64_t>(shape_k);
    int64_t ld_d = static_cast<int64_t>(shape_n);
    int64_t stride_a = total_tokens * ld_a;
    int64_t stride_b = static_cast<int64_t>(shape_n) * ld_b;
    int64_t stride_d = total_tokens * ld_d;

    typename KT::StrideA dA = make_stride(ld_a, Int<1>{}, stride_a);
    typename KT::StrideB dB = make_stride(ld_b, Int<1>{}, stride_b);
    // Scale-factor layouts are deduced from padded shapes: SFA uses the padded
    // token count (single logical batch), SFB is laid out per expert.
    auto sfa_shape = make_shape(static_cast<int>(m_padded), static_cast<int>(shape_n), static_cast<int>(shape_k), 1);
    typename KT::StrideSFA dSFA = KT::deduce_sfa_layout(sfa_shape).stride();
    auto sfb_shape = make_shape(static_cast<int>(m_padded), static_cast<int>(shape_n), static_cast<int>(shape_k),
        static_cast<int>(num_problems));
    typename KT::StrideSFB dSFB = KT::deduce_sfb_layout(sfb_shape).stride();
    typename KT::StrideD dD = make_stride(ld_d, Int<1>{}, stride_d);

    Arguments args
        = {ptr_A, dA, ptr_B, dB, ptr_SFA, dSFA, ptr_SFB, dSFB, ptr_D, dD, const_cast<int64_t*>(problem_m_offsets)};

    Params kernel_params = GemmKernel::to_underlying_arguments(problem_shape, args);
    auto kernel_ptr = &cutlass::device_kernel<GemmKernel>;

    // Opt in to >48KB dynamic shared memory before launch.
    cudaFuncSetAttribute(kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, GemmKernel::kSmemSize);
    auto result = cudaGetLastError();
    TLLM_CHECK_WITH_INFO(result == cudaSuccess, "sm120 moe gemm kernel cannot launch: %s", cudaGetErrorString(result));

    cudaLaunchConfig_t launch_config;
    cudaLaunchAttribute attrs[1];
    // Allow PDL-style overlap with the preceding kernel on this stream.
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrs[0].val.programmaticStreamSerializationAllowed = 1;

    // Persistent grid: one CTA per SM.
    launch_config.gridDim = dim3(num_device_sms, 1, 1);
    launch_config.blockDim = GemmKernel::get_block_shape();
    launch_config.dynamicSmemBytes = GemmKernel::kSmemSize;
    launch_config.stream = stream;
    launch_config.attrs = attrs;
    launch_config.numAttrs = 1;

    cudaLaunchKernelEx(&launch_config, kernel_ptr, kernel_params);

    result = cudaGetLastError();
    TLLM_CHECK_WITH_INFO(result == cudaSuccess, "sm120 moe gemm kernel runtime error: %s", cudaGetErrorString(result));
}
969+
869970void fp8_grouped_gemm_run (__nv_bfloat16 const * mat_a, __nv_fp8_e4m3* fp8_mat_a, float * scales_a,
870971 __nv_bfloat16 const * mat_b, __nv_fp8_e4m3* fp8_mat_b, float * scales_b, __nv_bfloat16* mat_d,
871972 int64_t const * problem_m_offsets, int num_problems, int64_t expected_m, int64_t max_shape_m,
@@ -877,6 +978,41 @@ void fp8_grouped_gemm_run(__nv_bfloat16 const* mat_a, __nv_fp8_e4m3* fp8_mat_a,
877978 kNumDeviceSMs = tensorrt_llm::common::getMultiProcessorCount ();
878979 }
879980
981+ int arch = tensorrt_llm::common::getSMVersion ();
982+
983+ if (arch == 120 )
984+ {
985+ if (internal_quantize_a)
986+ {
987+ constexpr int WarpsPerBlock = 4 ;
988+ int num_k_blocks = div_up (shape_k, 512 );
989+ int64_t num_token_blocks = div_up (max_shape_m, static_cast <int64_t >(WarpsPerBlock));
990+ int64_t scale_leading_dim = sm120_blockscaled_gemm::compute_padded_offset (max_shape_m, num_problems);
991+
992+ constexpr int kBlocksPerSM = 8 ;
993+ int64_t max_blocks = static_cast <int64_t >(kNumDeviceSMs ) * kBlocksPerSM ;
994+ int num_blocks_y = static_cast <int >(std::min (num_token_blocks, max_blocks));
995+
996+ dim3 grid (num_k_blocks, num_blocks_y);
997+ dim3 block (WarpsPerBlock * 32 );
998+ int smem_size = (num_problems + 1 ) * sizeof (int64_t );
999+ auto scale_kernel
1000+ = sm120_blockscaled_gemm::scale_1x128_kernel_sm120<__nv_bfloat16, __nv_fp8_e4m3, WarpsPerBlock>;
1001+ cudaFuncSetAttribute (scale_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
1002+ scale_kernel<<<grid, block, smem_size, stream>>> (fp8_mat_a, reinterpret_cast <int32_t *>(scales_a), mat_a,
1003+ problem_m_offsets, num_problems, shape_k, scale_leading_dim);
1004+ }
1005+ if (internal_quantize_b)
1006+ {
1007+ TLLM_CHECK_WITH_INFO (false , " sm120 moe gemm kernel does not support internal_quantize_b" );
1008+ return ;
1009+ }
1010+
1011+ grouped_gemm_dispatch_sm120 (fp8_mat_a, fp8_mat_b, mat_d, num_problems, problem_m_offsets, expected_m,
1012+ max_shape_m, max_shape_m_padded, shape_n, shape_k, scales_a, scales_b, stream);
1013+ return ;
1014+ }
1015+
8801016 if (internal_quantize_a)
8811017 {
8821018 constexpr int NumThreads = 256 ;
@@ -994,54 +1130,17 @@ void strided_batch_gemm_dispatch_sm120(__nv_fp8_e4m3* mat_a, int ld_a, int strid
9941130 {
9951131 num_device_sms = kNumDeviceSMs = tensorrt_llm::common::getMultiProcessorCount ();
9961132 }
997- using ElementInput = cute::float_e4m3_t ;
998- using ElementOutput = cute::bfloat16_t ;
999- using ElementAccum = float ;
1000- using ElementBlockScale = int32_t ;
1001- using KT = sm120_blockscaled_gemm::SM120BlockScaledBuilder<32 , 128 >;
1002- using GemmKernel = sm120_blockscaled_gemm::SM120BlockScaledKernel<KT>;
1003- using Params = typename GemmKernel::Params;
1004- using Arguments = typename GemmKernel::Arguments;
1005- using ProblemShape = typename GemmKernel::ProblemShape;
1006- ProblemShape problem_shape = make_shape ((int ) shape_m, (int ) shape_n, (int ) shape_k, (int ) num_problems);
1007-
1008- auto ptr_A = reinterpret_cast <ElementInput*>(mat_a);
1009- auto ptr_B = reinterpret_cast <ElementInput*>(mat_b);
1010- auto ptr_SFA = reinterpret_cast <ElementBlockScale*>(scales_a);
1011- auto ptr_SFB = reinterpret_cast <ElementBlockScale*>(scales_b);
1012- auto ptr_D = reinterpret_cast <ElementOutput*>(mat_d);
1013-
1014- typename KT::StrideA dA = make_stride (ld_a, Int<1 >{}, stride_a);
1015- typename KT::StrideB dB = make_stride (ld_b, Int<1 >{}, stride_b);
1016- typename KT::StrideSFA dSFA = KT::deduce_sfa_layout (problem_shape).stride ();
1017- typename KT::StrideSFB dSFB = KT::deduce_sfb_layout (problem_shape).stride ();
1018- typename KT::StrideD dD = make_stride (ld_d, Int<1 >{}, stride_d);
1019-
1020- Arguments args = {ptr_A, dA, ptr_B, dB, ptr_SFA, dSFA, ptr_SFB, dSFB, ptr_D, dD};
1021-
1022- Params kernel_params = GemmKernel::to_underlying_arguments (problem_shape, args);
1023- auto kernel_ptr = &cutlass::device_kernel<GemmKernel>;
1024-
1025- cudaFuncSetAttribute (kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, GemmKernel::kSmemSize );
1026- auto result = cudaGetLastError ();
1027- TLLM_CHECK_WITH_INFO (result == cudaSuccess, " sm120 gemm kernel cannot launch: %s" , cudaGetErrorString (result));
1028-
1029- cudaLaunchConfig_t launch_config;
1030- cudaLaunchAttribute attrs[1 ];
1031- attrs[0 ].id = cudaLaunchAttributeProgrammaticStreamSerialization;
1032- attrs[0 ].val .programmaticStreamSerializationAllowed = 1 ;
10331133
1034- launch_config.gridDim = GemmKernel::get_grid_shape (kernel_params);
1035- launch_config.blockDim = GemmKernel::get_block_shape ();
1036- launch_config.dynamicSmemBytes = GemmKernel::kSmemSize ;
1037- launch_config.stream = stream;
1038- launch_config.attrs = attrs;
1039- launch_config.numAttrs = 1 ;
1040-
1041- cudaLaunchKernelEx (&launch_config, kernel_ptr, kernel_params);
1042-
1043- result = cudaGetLastError ();
1044- TLLM_CHECK_WITH_INFO (result == cudaSuccess, " sm120 gemm kernel runtime error: %s" , cudaGetErrorString (result));
1134+ if (shape_m <= 64 )
1135+ {
1136+ launch_sm120_gemm_kernel<32 , 128 , 4 >(mat_a, ld_a, stride_a, mat_b, ld_b, stride_b, mat_d, ld_d, stride_d,
1137+ scales_a, stride_scales_a, scales_b, 0 , num_problems, shape_m, shape_n, shape_k, stream, num_device_sms);
1138+ }
1139+ else
1140+ {
1141+ launch_sm120_gemm_kernel<64 , 128 , 4 >(mat_a, ld_a, stride_a, mat_b, ld_b, stride_b, mat_d, ld_d, stride_d,
1142+ scales_a, stride_scales_a, scales_b, 0 , num_problems, shape_m, shape_n, shape_k, stream, num_device_sms);
1143+ }
10451144}
10461145
10471146void fp8_stride_batch_gemm_run (__nv_bfloat16 const * mat_a, __nv_fp8_e4m3* fp8_mat_a, float * scales_a, int ld_a,
0 commit comments