Skip to content

Commit 54cc308

Browse files
committed
[feat] add tileshape selection
Signed-off-by: CarstyYou <186021327+CarstyYou@users.noreply.github.com>
1 parent 1cc45b1 commit 54cc308

File tree

2 files changed

+48
-60
lines changed

2 files changed

+48
-60
lines changed

cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm_kernel.cuh

Lines changed: 47 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -714,8 +714,11 @@ void gemm_dispatch_sm89(void* mat_a, void* mat_b, void* mat_d, float* scales_a,
714714
TLLM_CHECK_WITH_INFO(result == cudaSuccess, "sm89 gemm kernel runtime error: %s", cudaGetErrorString(result));
715715
}
716716

717-
void gemm_dispatch_sm120(void* mat_a, void* mat_b, void* mat_d, float* scales_a, float* scales_b, uint32_t shape_m,
718-
uint32_t shape_n, uint32_t shape_k, cudaStream_t stream, int num_device_sms = kNumDeviceSMs)
717+
template <int TileM, int TileN, int NumStages>
718+
void launch_sm120_gemm_kernel(__nv_fp8_e4m3* mat_a, int64_t ld_a, int64_t stride_a, __nv_fp8_e4m3* mat_b, int64_t ld_b,
719+
int64_t stride_b, __nv_bfloat16* mat_d, int64_t ld_d, int64_t stride_d, float* scales_a, int64_t stride_scales_a,
720+
float* scales_b, int64_t stride_scales_b, uint32_t num_problems, uint32_t shape_m, uint32_t shape_n,
721+
uint32_t shape_k, cudaStream_t stream, int num_device_sms = kNumDeviceSMs)
719722
{
720723
if (num_device_sms < 0)
721724
{
@@ -725,25 +728,19 @@ void gemm_dispatch_sm120(void* mat_a, void* mat_b, void* mat_d, float* scales_a,
725728
using ElementOutput = cute::bfloat16_t;
726729
using ElementAccum = float;
727730
using ElementBlockScale = int32_t;
728-
using KT = sm120_blockscaled_gemm::SM120BlockScaledBuilder<64, 128, 4>;
731+
using KT = sm120_blockscaled_gemm::SM120BlockScaledBuilder<TileM, TileN, NumStages>;
729732
using GemmKernel = sm120_blockscaled_gemm::SM120BlockScaledKernel<KT>;
730733
using Params = typename GemmKernel::Params;
731734
using Arguments = typename GemmKernel::Arguments;
732735
using ProblemShape = typename GemmKernel::ProblemShape;
733-
ProblemShape problem_shape = make_shape((int) shape_m, (int) shape_n, (int) shape_k, 1);
736+
ProblemShape problem_shape = make_shape((int) shape_m, (int) shape_n, (int) shape_k, (int) num_problems);
734737

735738
auto ptr_A = reinterpret_cast<ElementInput*>(mat_a);
736739
auto ptr_B = reinterpret_cast<ElementInput*>(mat_b);
737740
auto ptr_SFA = reinterpret_cast<ElementBlockScale*>(scales_a);
738741
auto ptr_SFB = reinterpret_cast<ElementBlockScale*>(scales_b);
739742
auto ptr_D = reinterpret_cast<ElementOutput*>(mat_d);
740743

741-
int64_t ld_a = static_cast<int64_t>(shape_k);
742-
int64_t ld_b = static_cast<int64_t>(shape_k);
743-
int64_t ld_d = static_cast<int64_t>(shape_n);
744-
int64_t stride_a = static_cast<int64_t>(shape_m) * ld_a;
745-
int64_t stride_b = static_cast<int64_t>(shape_n) * ld_b;
746-
int64_t stride_d = static_cast<int64_t>(shape_m) * ld_d;
747744
typename KT::StrideA dA = make_stride(ld_a, Int<1>{}, stride_a);
748745
typename KT::StrideB dB = make_stride(ld_b, Int<1>{}, stride_b);
749746
typename KT::StrideSFA dSFA = KT::deduce_sfa_layout(problem_shape).stride();
@@ -777,6 +774,35 @@ void gemm_dispatch_sm120(void* mat_a, void* mat_b, void* mat_d, float* scales_a,
777774
TLLM_CHECK_WITH_INFO(result == cudaSuccess, "sm120 gemm kernel runtime error: %s", cudaGetErrorString(result));
778775
}
779776

777+
void gemm_dispatch_sm120(void* mat_a, void* mat_b, void* mat_d, float* scales_a, float* scales_b, uint32_t shape_m,
778+
uint32_t shape_n, uint32_t shape_k, cudaStream_t stream, int num_device_sms = kNumDeviceSMs)
779+
{
780+
if (num_device_sms < 0)
781+
{
782+
num_device_sms = kNumDeviceSMs = tensorrt_llm::common::getMultiProcessorCount();
783+
}
784+
785+
auto* a = reinterpret_cast<__nv_fp8_e4m3*>(mat_a);
786+
auto* b = reinterpret_cast<__nv_fp8_e4m3*>(mat_b);
787+
auto* d = reinterpret_cast<__nv_bfloat16*>(mat_d);
788+
int64_t ld_a = shape_k;
789+
int64_t ld_b = shape_k;
790+
int64_t ld_d = shape_n;
791+
constexpr int64_t stride = 0;
792+
constexpr uint32_t num_problems = 1;
793+
794+
if (shape_m <= 64)
795+
{
796+
launch_sm120_gemm_kernel<32, 128, 4>(a, ld_a, stride, b, ld_b, stride, d, ld_d, stride, scales_a, stride,
797+
scales_b, stride, num_problems, shape_m, shape_n, shape_k, stream, num_device_sms);
798+
}
799+
else
800+
{
801+
launch_sm120_gemm_kernel<64, 128, 4>(a, ld_a, stride, b, ld_b, stride, d, ld_d, stride, scales_a, stride,
802+
scales_b, stride, num_problems, shape_m, shape_n, shape_k, stream, num_device_sms);
803+
}
804+
}
805+
780806
void fp8_gemm_run(__nv_fp8_e4m3* mat_a, int ld_a, __nv_fp8_e4m3* mat_b, int ld_b, __nv_bfloat16* mat_d, int ld_d,
781807
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, float* scales_a, float* scales_b, cudaStream_t stream)
782808
{
@@ -882,7 +908,7 @@ void grouped_gemm_dispatch_sm120(__nv_fp8_e4m3* mat_a, __nv_fp8_e4m3* mat_b, __n
882908
// so we can promise m_padded < max_shape_m_padded
883909
int64_t m_padded = sm120_blockscaled_gemm::compute_padded_offset(max_shape_m, num_problems);
884910

885-
using KT = sm120_blockscaled_gemm::SM120BlockScaledBuilder<64, 128, 4>;
911+
using KT = sm120_blockscaled_gemm::SM120BlockScaledBuilder<32, 128, 4>;
886912
using GemmKernel = sm120_blockscaled_gemm::SM120BlockScaledMoeKernel<KT>;
887913
using Params = typename GemmKernel::Params;
888914
using Arguments = typename GemmKernel::Arguments;
@@ -1104,54 +1130,17 @@ void strided_batch_gemm_dispatch_sm120(__nv_fp8_e4m3* mat_a, int ld_a, int strid
11041130
{
11051131
num_device_sms = kNumDeviceSMs = tensorrt_llm::common::getMultiProcessorCount();
11061132
}
1107-
using ElementInput = cute::float_e4m3_t;
1108-
using ElementOutput = cute::bfloat16_t;
1109-
using ElementAccum = float;
1110-
using ElementBlockScale = int32_t;
1111-
using KT = sm120_blockscaled_gemm::SM120BlockScaledBuilder<64, 128, 4>;
1112-
using GemmKernel = sm120_blockscaled_gemm::SM120BlockScaledKernel<KT>;
1113-
using Params = typename GemmKernel::Params;
1114-
using Arguments = typename GemmKernel::Arguments;
1115-
using ProblemShape = typename GemmKernel::ProblemShape;
1116-
ProblemShape problem_shape = make_shape((int) shape_m, (int) shape_n, (int) shape_k, (int) num_problems);
1117-
1118-
auto ptr_A = reinterpret_cast<ElementInput*>(mat_a);
1119-
auto ptr_B = reinterpret_cast<ElementInput*>(mat_b);
1120-
auto ptr_SFA = reinterpret_cast<ElementBlockScale*>(scales_a);
1121-
auto ptr_SFB = reinterpret_cast<ElementBlockScale*>(scales_b);
1122-
auto ptr_D = reinterpret_cast<ElementOutput*>(mat_d);
1123-
1124-
typename KT::StrideA dA = make_stride(static_cast<int64_t>(ld_a), Int<1>{}, static_cast<int64_t>(stride_a));
1125-
typename KT::StrideB dB = make_stride(static_cast<int64_t>(ld_b), Int<1>{}, static_cast<int64_t>(stride_b));
1126-
typename KT::StrideSFA dSFA = KT::deduce_sfa_layout(problem_shape).stride();
1127-
typename KT::StrideSFB dSFB = KT::deduce_sfb_layout(problem_shape).stride();
1128-
typename KT::StrideD dD = make_stride(static_cast<int64_t>(ld_d), Int<1>{}, static_cast<int64_t>(stride_d));
1129-
1130-
Arguments args = {ptr_A, dA, ptr_B, dB, ptr_SFA, dSFA, ptr_SFB, dSFB, ptr_D, dD};
1131-
1132-
Params kernel_params = GemmKernel::to_underlying_arguments(problem_shape, args);
1133-
auto kernel_ptr = &cutlass::device_kernel<GemmKernel>;
1134-
1135-
cudaFuncSetAttribute(kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, GemmKernel::kSmemSize);
1136-
auto result = cudaGetLastError();
1137-
TLLM_CHECK_WITH_INFO(result == cudaSuccess, "sm120 gemm kernel cannot launch: %s", cudaGetErrorString(result));
1138-
1139-
cudaLaunchConfig_t launch_config;
1140-
cudaLaunchAttribute attrs[1];
1141-
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
1142-
attrs[0].val.programmaticStreamSerializationAllowed = 1;
1143-
1144-
launch_config.gridDim = dim3(num_device_sms, 1, 1);
1145-
launch_config.blockDim = GemmKernel::get_block_shape();
1146-
launch_config.dynamicSmemBytes = GemmKernel::kSmemSize;
1147-
launch_config.stream = stream;
1148-
launch_config.attrs = attrs;
1149-
launch_config.numAttrs = 1;
11501133

1151-
cudaLaunchKernelEx(&launch_config, kernel_ptr, kernel_params);
1152-
1153-
result = cudaGetLastError();
1154-
TLLM_CHECK_WITH_INFO(result == cudaSuccess, "sm120 gemm kernel runtime error: %s", cudaGetErrorString(result));
1134+
if (shape_m <= 64)
1135+
{
1136+
launch_sm120_gemm_kernel<32, 128, 4>(mat_a, ld_a, stride_a, mat_b, ld_b, stride_b, mat_d, ld_d, stride_d,
1137+
scales_a, stride_scales_a, scales_b, 0, num_problems, shape_m, shape_n, shape_k, stream, num_device_sms);
1138+
}
1139+
else
1140+
{
1141+
launch_sm120_gemm_kernel<64, 128, 4>(mat_a, ld_a, stride_a, mat_b, ld_b, stride_b, mat_d, ld_d, stride_d,
1142+
scales_a, stride_scales_a, scales_b, 0, num_problems, shape_m, shape_n, shape_k, stream, num_device_sms);
1143+
}
11551144
}
11561145

11571146
void fp8_stride_batch_gemm_run(__nv_bfloat16 const* mat_a, __nv_fp8_e4m3* fp8_mat_a, float* scales_a, int ld_a,

tests/unittest/_torch/modules/test_fused_moe.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@
1818
from mpi4py import MPI
1919
from mpi4py.futures import MPIPoolExecutor
2020
from transformers.configuration_utils import PretrainedConfig
21-
from utils.util import (check_accuracy, getSMVersion, skip_blackwell,
22-
skip_blackwell_geforce,
21+
from utils.util import (check_accuracy, skip_blackwell, skip_blackwell_geforce,
2322
skip_neither_ada_nor_hopper_unittest, skip_no_hopper,
2423
skip_pre_blackwell, skip_pre_hopper)
2524

0 commit comments

Comments
 (0)