Skip to content

Commit c0cf5a3

Browse files
authored
[None][feat] Optimize 6KD fp8 blockscale gemm (#11502)
Signed-off-by: CarstyYou <186021327+CarstyYou@users.noreply.github.com>
1 parent 0507609 commit c0cf5a3

File tree

13 files changed

+1680
-339
lines changed

13 files changed

+1680
-339
lines changed

cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm.cu

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,23 @@ void CutlassFp8BlockScaleGemmRunner<ElementA, ElementB, ElementD>::moeGemm(void*
136136
}
137137
}
138138

139+
int arch = tensorrt_llm::common::getSMVersion();
140+
if (arch == 120)
141+
{
142+
if constexpr (std::is_same_v<ElementA, __nv_bfloat16> && std::is_same_v<ElementB, __nv_fp8_e4m3>)
143+
{
144+
fp8_grouped_gemm_run(reinterpret_cast<__nv_bfloat16 const*>(mat_a), fp8_mat_a, per_token_per_128c_scales,
145+
nullptr, fp8_mat_b, per_block_scales, reinterpret_cast<__nv_bfloat16*>(mat_d), problem_m_offsets,
146+
num_problems, expected_m, max_shape_m_4_align_, max_shape_m_32_align_padded_, shape_n, shape_k, stream,
147+
internal_quantize_a, internal_quantize_b);
148+
}
149+
else
150+
{
151+
TLLM_THROW("sm120 fp8 blockscale moe gemm only supports ElementA=bfloat16, ElementB=fp8_e4m3.");
152+
}
153+
return;
154+
}
155+
139156
#ifdef COMPILE_HOPPER_TMA_GEMMS
140157
if constexpr (std::is_same_v<ElementA, __nv_bfloat16> && std::is_same_v<ElementB, __nv_bfloat16>)
141158
{

cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm_kernel.cuh

Lines changed: 158 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include "fp8_blockscale_mma_utils.cuh"
3232
#include "fp8_blockscale_tma_utils.cuh"
3333
#include "sm120_blockwise_gemm/sm120_fp8_gemm_1d1d.cuh"
34+
#include "sm120_blockwise_gemm/sm120_fp8_moe_gemm_1d1d.cuh"
3435
#include "tensorrt_llm/common/config.h"
3536
#include "tensorrt_llm/common/cudaTypeUtils.cuh"
3637
#include "tensorrt_llm/common/cudaUtils.h"
@@ -713,8 +714,11 @@ void gemm_dispatch_sm89(void* mat_a, void* mat_b, void* mat_d, float* scales_a,
713714
TLLM_CHECK_WITH_INFO(result == cudaSuccess, "sm89 gemm kernel runtime error: %s", cudaGetErrorString(result));
714715
}
715716

716-
void gemm_dispatch_sm120(void* mat_a, void* mat_b, void* mat_d, float* scales_a, float* scales_b, uint32_t shape_m,
717-
uint32_t shape_n, uint32_t shape_k, cudaStream_t stream, int num_device_sms = kNumDeviceSMs)
717+
template <int TileM, int TileN, int NumStages>
718+
void launch_sm120_gemm_kernel(__nv_fp8_e4m3* mat_a, int64_t ld_a, int64_t stride_a, __nv_fp8_e4m3* mat_b, int64_t ld_b,
719+
int64_t stride_b, __nv_bfloat16* mat_d, int64_t ld_d, int64_t stride_d, float* scales_a, int64_t stride_scales_a,
720+
float* scales_b, int64_t stride_scales_b, uint32_t num_problems, uint32_t shape_m, uint32_t shape_n,
721+
uint32_t shape_k, cudaStream_t stream, int num_device_sms = kNumDeviceSMs)
718722
{
719723
if (num_device_sms < 0)
720724
{
@@ -724,26 +728,19 @@ void gemm_dispatch_sm120(void* mat_a, void* mat_b, void* mat_d, float* scales_a,
724728
using ElementOutput = cute::bfloat16_t;
725729
using ElementAccum = float;
726730
using ElementBlockScale = int32_t;
727-
using KT = sm120_blockscaled_gemm::SM120BlockScaledBuilder<32, 128>;
731+
using KT = sm120_blockscaled_gemm::SM120BlockScaledBuilder<TileM, TileN, NumStages>;
728732
using GemmKernel = sm120_blockscaled_gemm::SM120BlockScaledKernel<KT>;
729733
using Params = typename GemmKernel::Params;
730734
using Arguments = typename GemmKernel::Arguments;
731735
using ProblemShape = typename GemmKernel::ProblemShape;
732-
ProblemShape problem_shape = make_shape((int) shape_m, (int) shape_n, (int) shape_k, 1);
736+
ProblemShape problem_shape = make_shape((int) shape_m, (int) shape_n, (int) shape_k, (int) num_problems);
733737

734738
auto ptr_A = reinterpret_cast<ElementInput*>(mat_a);
735739
auto ptr_B = reinterpret_cast<ElementInput*>(mat_b);
736740
auto ptr_SFA = reinterpret_cast<ElementBlockScale*>(scales_a);
737741
auto ptr_SFB = reinterpret_cast<ElementBlockScale*>(scales_b);
738742
auto ptr_D = reinterpret_cast<ElementOutput*>(mat_d);
739743

740-
int32_t ld_a = shape_k;
741-
int32_t stride_a = shape_m * shape_k;
742-
int32_t ld_b = shape_k;
743-
int32_t stride_b = shape_n * shape_k;
744-
int32_t ld_d = shape_n;
745-
int32_t stride_d = shape_m * shape_n;
746-
747744
typename KT::StrideA dA = make_stride(ld_a, Int<1>{}, stride_a);
748745
typename KT::StrideB dB = make_stride(ld_b, Int<1>{}, stride_b);
749746
typename KT::StrideSFA dSFA = KT::deduce_sfa_layout(problem_shape).stride();
@@ -764,7 +761,7 @@ void gemm_dispatch_sm120(void* mat_a, void* mat_b, void* mat_d, float* scales_a,
764761
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
765762
attrs[0].val.programmaticStreamSerializationAllowed = 1;
766763

767-
launch_config.gridDim = GemmKernel::get_grid_shape(kernel_params);
764+
launch_config.gridDim = dim3(num_device_sms, 1, 1);
768765
launch_config.blockDim = GemmKernel::get_block_shape();
769766
launch_config.dynamicSmemBytes = GemmKernel::kSmemSize;
770767
launch_config.stream = stream;
@@ -777,6 +774,35 @@ void gemm_dispatch_sm120(void* mat_a, void* mat_b, void* mat_d, float* scales_a,
777774
TLLM_CHECK_WITH_INFO(result == cudaSuccess, "sm120 gemm kernel runtime error: %s", cudaGetErrorString(result));
778775
}
779776

777+
// Dense (single-problem) fp8 block-scale GEMM entry point for SM120.
// Casts the type-erased pointers to their fp8/bf16 element types and forwards
// to launch_sm120_gemm_kernel with a tile shape chosen from shape_m.
// A and B both carry a leading dimension of shape_k; D's leading dimension is
// shape_n. With a single problem, all batch strides are unused and passed as 0.
void gemm_dispatch_sm120(void* mat_a, void* mat_b, void* mat_d, float* scales_a, float* scales_b, uint32_t shape_m,
    uint32_t shape_n, uint32_t shape_k, cudaStream_t stream, int num_device_sms = kNumDeviceSMs)
{
    // Lazily query (and cache) the device SM count on first use.
    if (num_device_sms < 0)
    {
        num_device_sms = kNumDeviceSMs = tensorrt_llm::common::getMultiProcessorCount();
    }

    auto* fp8_a = reinterpret_cast<__nv_fp8_e4m3*>(mat_a);
    auto* fp8_b = reinterpret_cast<__nv_fp8_e4m3*>(mat_b);
    auto* bf16_d = reinterpret_cast<__nv_bfloat16*>(mat_d);

    int64_t const lda = shape_k;
    int64_t const ldb = shape_k;
    int64_t const ldd = shape_n;
    // One problem only; every batch stride (data and scales) is irrelevant.
    constexpr int64_t kNoStride = 0;
    constexpr uint32_t kOneProblem = 1;

    // Small-M problems use a 32-row tile; larger ones a 64-row tile.
    if (shape_m <= 64)
    {
        launch_sm120_gemm_kernel<32, 128, 4>(fp8_a, lda, kNoStride, fp8_b, ldb, kNoStride, bf16_d, ldd, kNoStride,
            scales_a, kNoStride, scales_b, kNoStride, kOneProblem, shape_m, shape_n, shape_k, stream, num_device_sms);
    }
    else
    {
        launch_sm120_gemm_kernel<64, 128, 4>(fp8_a, lda, kNoStride, fp8_b, ldb, kNoStride, bf16_d, ldd, kNoStride,
            scales_a, kNoStride, scales_b, kNoStride, kOneProblem, shape_m, shape_n, shape_k, stream, num_device_sms);
    }
}
805+
780806
void fp8_gemm_run(__nv_fp8_e4m3* mat_a, int ld_a, __nv_fp8_e4m3* mat_b, int ld_b, __nv_bfloat16* mat_d, int ld_d,
781807
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, float* scales_a, float* scales_b, cudaStream_t stream)
782808
{
@@ -866,6 +892,81 @@ void grouped_gemm_dispatch(__nv_fp8_e4m3* mat_a, __nv_fp8_e4m3* mat_b, __nv_bflo
866892
}
867893
}
868894

895+
// Grouped (MoE) fp8 block-scale GEMM dispatch for SM120.
//
// A holds all experts' token rows stacked into one [max_shape_m, shape_k]
// buffer; per-problem row ranges come from problem_m_offsets (num_problems + 1
// prefix offsets, forwarded into the kernel arguments). B is a per-problem
// [shape_n, shape_k] weight batch, D the stacked [max_shape_m, shape_n] output.
// scales_a / scales_b are the 1x128 block scale factors, reinterpreted to the
// kernel's ElementSFLoad type.
//
// NOTE(review): expected_m and max_shape_m_padded are not read by this body;
// max_shape_m_padded appears only in the padding comment below — confirm they
// are intentionally reserved for future tile-size selection.
// Launch uses a persistent grid of one block per SM with programmatic stream
// serialization; errors are surfaced via TLLM_CHECK_WITH_INFO after the
// attribute set and after the launch.
void grouped_gemm_dispatch_sm120(__nv_fp8_e4m3* mat_a, __nv_fp8_e4m3* mat_b, __nv_bfloat16* mat_d,
    uint32_t num_problems, int64_t const* problem_m_offsets, uint32_t expected_m, uint32_t max_shape_m,
    uint32_t max_shape_m_padded, uint32_t shape_n, uint32_t shape_k, float* scales_a, float* scales_b,
    cudaStream_t stream, int num_device_sms = kNumDeviceSMs)
{
    // Lazily query (and cache) the device SM count on first use.
    if (num_device_sms < 0)
    {
        num_device_sms = kNumDeviceSMs = tensorrt_llm::common::getMultiProcessorCount();
    }

    int64_t total_tokens = static_cast<int64_t>(max_shape_m);
    // Padded token count used only to size the scale-factor layouts:
    // max_shape_m_padded = (max_shape_m + num_problems * 31) / 32 * 32
    // m_padded = (total_tokens + num_problems * 3) / 4 * 4;
    // so we can guarantee m_padded < max_shape_m_padded
    int64_t m_padded = sm120_blockscaled_gemm::compute_padded_offset(max_shape_m, num_problems);

    // Fixed 32x128 tile with 4 pipeline stages for the MoE kernel.
    using KT = sm120_blockscaled_gemm::SM120BlockScaledBuilder<32, 128, 4>;
    using GemmKernel = sm120_blockscaled_gemm::SM120BlockScaledMoeKernel<KT>;
    using Params = typename GemmKernel::Params;
    using Arguments = typename GemmKernel::Arguments;
    using ProblemShape = typename GemmKernel::ProblemShape;

    // (M, N, K, L): M is the stacked token count, L the number of experts.
    ProblemShape problem_shape = make_shape(static_cast<int>(total_tokens), static_cast<int>(shape_n),
        static_cast<int>(shape_k), static_cast<int>(num_problems));

    auto ptr_A = reinterpret_cast<typename KT::ElementA*>(mat_a);
    auto ptr_B = reinterpret_cast<typename KT::ElementB*>(mat_b);
    auto ptr_SFA = reinterpret_cast<typename KT::ElementSFLoad*>(scales_a);
    auto ptr_SFB = reinterpret_cast<typename KT::ElementSFLoad*>(scales_b);
    auto ptr_D = reinterpret_cast<typename KT::ElementD*>(mat_d);

    // K-major A/B and N-major D; batch strides span the whole stacked buffer
    // for A/D and one [shape_n, shape_k] weight matrix for B.
    int64_t ld_a = static_cast<int64_t>(shape_k);
    int64_t ld_b = static_cast<int64_t>(shape_k);
    int64_t ld_d = static_cast<int64_t>(shape_n);
    int64_t stride_a = total_tokens * ld_a;
    int64_t stride_b = static_cast<int64_t>(shape_n) * ld_b;
    int64_t stride_d = total_tokens * ld_d;

    typename KT::StrideA dA = make_stride(ld_a, Int<1>{}, stride_a);
    typename KT::StrideB dB = make_stride(ld_b, Int<1>{}, stride_b);
    // Scale-factor layouts are deduced from the padded M; SFA is shared across
    // problems (batch extent 1) while SFB is per-problem.
    auto sfa_shape = make_shape(static_cast<int>(m_padded), static_cast<int>(shape_n), static_cast<int>(shape_k), 1);
    typename KT::StrideSFA dSFA = KT::deduce_sfa_layout(sfa_shape).stride();
    auto sfb_shape = make_shape(static_cast<int>(m_padded), static_cast<int>(shape_n), static_cast<int>(shape_k),
        static_cast<int>(num_problems));
    typename KT::StrideSFB dSFB = KT::deduce_sfb_layout(sfb_shape).stride();
    typename KT::StrideD dD = make_stride(ld_d, Int<1>{}, stride_d);

    // const_cast: the kernel argument struct takes a mutable pointer, but the
    // offsets are only read on device.
    Arguments args
        = {ptr_A, dA, ptr_B, dB, ptr_SFA, dSFA, ptr_SFB, dSFB, ptr_D, dD, const_cast<int64_t*>(problem_m_offsets)};

    Params kernel_params = GemmKernel::to_underlying_arguments(problem_shape, args);
    auto kernel_ptr = &cutlass::device_kernel<GemmKernel>;

    // Opt in to > default dynamic shared memory before launching; the check
    // below picks up any failure from this call via cudaGetLastError().
    cudaFuncSetAttribute(kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, GemmKernel::kSmemSize);
    auto result = cudaGetLastError();
    TLLM_CHECK_WITH_INFO(result == cudaSuccess, "sm120 moe gemm kernel cannot launch: %s", cudaGetErrorString(result));

    cudaLaunchConfig_t launch_config;
    cudaLaunchAttribute attrs[1];
    // Allow the next grid to start while this one drains (PDL).
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrs[0].val.programmaticStreamSerializationAllowed = 1;

    // Persistent kernel: one block per SM regardless of problem size.
    launch_config.gridDim = dim3(num_device_sms, 1, 1);
    launch_config.blockDim = GemmKernel::get_block_shape();
    launch_config.dynamicSmemBytes = GemmKernel::kSmemSize;
    launch_config.stream = stream;
    launch_config.attrs = attrs;
    launch_config.numAttrs = 1;

    cudaLaunchKernelEx(&launch_config, kernel_ptr, kernel_params);

    result = cudaGetLastError();
    TLLM_CHECK_WITH_INFO(result == cudaSuccess, "sm120 moe gemm kernel runtime error: %s", cudaGetErrorString(result));
}
969+
869970
void fp8_grouped_gemm_run(__nv_bfloat16 const* mat_a, __nv_fp8_e4m3* fp8_mat_a, float* scales_a,
870971
__nv_bfloat16 const* mat_b, __nv_fp8_e4m3* fp8_mat_b, float* scales_b, __nv_bfloat16* mat_d,
871972
int64_t const* problem_m_offsets, int num_problems, int64_t expected_m, int64_t max_shape_m,
@@ -877,6 +978,41 @@ void fp8_grouped_gemm_run(__nv_bfloat16 const* mat_a, __nv_fp8_e4m3* fp8_mat_a,
877978
kNumDeviceSMs = tensorrt_llm::common::getMultiProcessorCount();
878979
}
879980

981+
int arch = tensorrt_llm::common::getSMVersion();
982+
983+
if (arch == 120)
984+
{
985+
if (internal_quantize_a)
986+
{
987+
constexpr int WarpsPerBlock = 4;
988+
int num_k_blocks = div_up(shape_k, 512);
989+
int64_t num_token_blocks = div_up(max_shape_m, static_cast<int64_t>(WarpsPerBlock));
990+
int64_t scale_leading_dim = sm120_blockscaled_gemm::compute_padded_offset(max_shape_m, num_problems);
991+
992+
constexpr int kBlocksPerSM = 8;
993+
int64_t max_blocks = static_cast<int64_t>(kNumDeviceSMs) * kBlocksPerSM;
994+
int num_blocks_y = static_cast<int>(std::min(num_token_blocks, max_blocks));
995+
996+
dim3 grid(num_k_blocks, num_blocks_y);
997+
dim3 block(WarpsPerBlock * 32);
998+
int smem_size = (num_problems + 1) * sizeof(int64_t);
999+
auto scale_kernel
1000+
= sm120_blockscaled_gemm::scale_1x128_kernel_sm120<__nv_bfloat16, __nv_fp8_e4m3, WarpsPerBlock>;
1001+
cudaFuncSetAttribute(scale_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
1002+
scale_kernel<<<grid, block, smem_size, stream>>>(fp8_mat_a, reinterpret_cast<int32_t*>(scales_a), mat_a,
1003+
problem_m_offsets, num_problems, shape_k, scale_leading_dim);
1004+
}
1005+
if (internal_quantize_b)
1006+
{
1007+
TLLM_CHECK_WITH_INFO(false, "sm120 moe gemm kernel does not support internal_quantize_b");
1008+
return;
1009+
}
1010+
1011+
grouped_gemm_dispatch_sm120(fp8_mat_a, fp8_mat_b, mat_d, num_problems, problem_m_offsets, expected_m,
1012+
max_shape_m, max_shape_m_padded, shape_n, shape_k, scales_a, scales_b, stream);
1013+
return;
1014+
}
1015+
8801016
if (internal_quantize_a)
8811017
{
8821018
constexpr int NumThreads = 256;
@@ -994,54 +1130,17 @@ void strided_batch_gemm_dispatch_sm120(__nv_fp8_e4m3* mat_a, int ld_a, int strid
9941130
{
9951131
num_device_sms = kNumDeviceSMs = tensorrt_llm::common::getMultiProcessorCount();
9961132
}
997-
using ElementInput = cute::float_e4m3_t;
998-
using ElementOutput = cute::bfloat16_t;
999-
using ElementAccum = float;
1000-
using ElementBlockScale = int32_t;
1001-
using KT = sm120_blockscaled_gemm::SM120BlockScaledBuilder<32, 128>;
1002-
using GemmKernel = sm120_blockscaled_gemm::SM120BlockScaledKernel<KT>;
1003-
using Params = typename GemmKernel::Params;
1004-
using Arguments = typename GemmKernel::Arguments;
1005-
using ProblemShape = typename GemmKernel::ProblemShape;
1006-
ProblemShape problem_shape = make_shape((int) shape_m, (int) shape_n, (int) shape_k, (int) num_problems);
1007-
1008-
auto ptr_A = reinterpret_cast<ElementInput*>(mat_a);
1009-
auto ptr_B = reinterpret_cast<ElementInput*>(mat_b);
1010-
auto ptr_SFA = reinterpret_cast<ElementBlockScale*>(scales_a);
1011-
auto ptr_SFB = reinterpret_cast<ElementBlockScale*>(scales_b);
1012-
auto ptr_D = reinterpret_cast<ElementOutput*>(mat_d);
1013-
1014-
typename KT::StrideA dA = make_stride(ld_a, Int<1>{}, stride_a);
1015-
typename KT::StrideB dB = make_stride(ld_b, Int<1>{}, stride_b);
1016-
typename KT::StrideSFA dSFA = KT::deduce_sfa_layout(problem_shape).stride();
1017-
typename KT::StrideSFB dSFB = KT::deduce_sfb_layout(problem_shape).stride();
1018-
typename KT::StrideD dD = make_stride(ld_d, Int<1>{}, stride_d);
1019-
1020-
Arguments args = {ptr_A, dA, ptr_B, dB, ptr_SFA, dSFA, ptr_SFB, dSFB, ptr_D, dD};
1021-
1022-
Params kernel_params = GemmKernel::to_underlying_arguments(problem_shape, args);
1023-
auto kernel_ptr = &cutlass::device_kernel<GemmKernel>;
1024-
1025-
cudaFuncSetAttribute(kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, GemmKernel::kSmemSize);
1026-
auto result = cudaGetLastError();
1027-
TLLM_CHECK_WITH_INFO(result == cudaSuccess, "sm120 gemm kernel cannot launch: %s", cudaGetErrorString(result));
1028-
1029-
cudaLaunchConfig_t launch_config;
1030-
cudaLaunchAttribute attrs[1];
1031-
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
1032-
attrs[0].val.programmaticStreamSerializationAllowed = 1;
10331133

1034-
launch_config.gridDim = GemmKernel::get_grid_shape(kernel_params);
1035-
launch_config.blockDim = GemmKernel::get_block_shape();
1036-
launch_config.dynamicSmemBytes = GemmKernel::kSmemSize;
1037-
launch_config.stream = stream;
1038-
launch_config.attrs = attrs;
1039-
launch_config.numAttrs = 1;
1040-
1041-
cudaLaunchKernelEx(&launch_config, kernel_ptr, kernel_params);
1042-
1043-
result = cudaGetLastError();
1044-
TLLM_CHECK_WITH_INFO(result == cudaSuccess, "sm120 gemm kernel runtime error: %s", cudaGetErrorString(result));
1134+
if (shape_m <= 64)
1135+
{
1136+
launch_sm120_gemm_kernel<32, 128, 4>(mat_a, ld_a, stride_a, mat_b, ld_b, stride_b, mat_d, ld_d, stride_d,
1137+
scales_a, stride_scales_a, scales_b, 0, num_problems, shape_m, shape_n, shape_k, stream, num_device_sms);
1138+
}
1139+
else
1140+
{
1141+
launch_sm120_gemm_kernel<64, 128, 4>(mat_a, ld_a, stride_a, mat_b, ld_b, stride_b, mat_d, ld_d, stride_d,
1142+
scales_a, stride_scales_a, scales_b, 0, num_problems, shape_m, shape_n, shape_k, stream, num_device_sms);
1143+
}
10451144
}
10461145

10471146
void fp8_stride_batch_gemm_run(__nv_bfloat16 const* mat_a, __nv_fp8_e4m3* fp8_mat_a, float* scales_a, int ld_a,

0 commit comments

Comments
 (0)