Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,23 @@ void CutlassFp8BlockScaleGemmRunner<ElementA, ElementB, ElementD>::moeGemm(void*
}
}

int arch = tensorrt_llm::common::getSMVersion();
if (arch == 120)
{
if constexpr (std::is_same_v<ElementA, __nv_bfloat16> && std::is_same_v<ElementB, __nv_fp8_e4m3>)
{
fp8_grouped_gemm_run(reinterpret_cast<__nv_bfloat16 const*>(mat_a), fp8_mat_a, per_token_per_128c_scales,
nullptr, fp8_mat_b, per_block_scales, reinterpret_cast<__nv_bfloat16*>(mat_d), problem_m_offsets,
num_problems, expected_m, max_shape_m_4_align_, max_shape_m_32_align_padded_, shape_n, shape_k, stream,
internal_quantize_a, internal_quantize_b);
}
else
{
TLLM_THROW("sm120 fp8 blockscale moe gemm only supports ElementA=bfloat16, ElementB=fp8_e4m3.");
}
return;
}

#ifdef COMPILE_HOPPER_TMA_GEMMS
if constexpr (std::is_same_v<ElementA, __nv_bfloat16> && std::is_same_v<ElementB, __nv_bfloat16>)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "fp8_blockscale_mma_utils.cuh"
#include "fp8_blockscale_tma_utils.cuh"
#include "sm120_blockwise_gemm/sm120_fp8_gemm_1d1d.cuh"
#include "sm120_blockwise_gemm/sm120_fp8_moe_gemm_1d1d.cuh"
#include "tensorrt_llm/common/config.h"
#include "tensorrt_llm/common/cudaTypeUtils.cuh"
#include "tensorrt_llm/common/cudaUtils.h"
Expand Down Expand Up @@ -724,7 +725,7 @@ void gemm_dispatch_sm120(void* mat_a, void* mat_b, void* mat_d, float* scales_a,
using ElementOutput = cute::bfloat16_t;
using ElementAccum = float;
using ElementBlockScale = int32_t;
using KT = sm120_blockscaled_gemm::SM120BlockScaledBuilder<32, 128>;
using KT = sm120_blockscaled_gemm::SM120BlockScaledBuilder<64, 128, 4>;
using GemmKernel = sm120_blockscaled_gemm::SM120BlockScaledKernel<KT>;
using Params = typename GemmKernel::Params;
using Arguments = typename GemmKernel::Arguments;
Expand All @@ -737,13 +738,12 @@ void gemm_dispatch_sm120(void* mat_a, void* mat_b, void* mat_d, float* scales_a,
auto ptr_SFB = reinterpret_cast<ElementBlockScale*>(scales_b);
auto ptr_D = reinterpret_cast<ElementOutput*>(mat_d);

int32_t ld_a = shape_k;
int32_t stride_a = shape_m * shape_k;
int32_t ld_b = shape_k;
int32_t stride_b = shape_n * shape_k;
int32_t ld_d = shape_n;
int32_t stride_d = shape_m * shape_n;

int64_t ld_a = static_cast<int64_t>(shape_k);
int64_t ld_b = static_cast<int64_t>(shape_k);
int64_t ld_d = static_cast<int64_t>(shape_n);
int64_t stride_a = static_cast<int64_t>(shape_m) * ld_a;
int64_t stride_b = static_cast<int64_t>(shape_n) * ld_b;
int64_t stride_d = static_cast<int64_t>(shape_m) * ld_d;
typename KT::StrideA dA = make_stride(ld_a, Int<1>{}, stride_a);
typename KT::StrideB dB = make_stride(ld_b, Int<1>{}, stride_b);
typename KT::StrideSFA dSFA = KT::deduce_sfa_layout(problem_shape).stride();
Expand All @@ -764,7 +764,7 @@ void gemm_dispatch_sm120(void* mat_a, void* mat_b, void* mat_d, float* scales_a,
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = 1;

launch_config.gridDim = GemmKernel::get_grid_shape(kernel_params);
launch_config.gridDim = dim3(num_device_sms, 1, 1);
launch_config.blockDim = GemmKernel::get_block_shape();
launch_config.dynamicSmemBytes = GemmKernel::kSmemSize;
launch_config.stream = stream;
Expand Down Expand Up @@ -866,6 +866,81 @@ void grouped_gemm_dispatch(__nv_fp8_e4m3* mat_a, __nv_fp8_e4m3* mat_b, __nv_bflo
}
}

// Dispatches the SM120 block-scaled FP8 grouped (MoE) GEMM.
//
// mat_a holds the fp8 activations of all experts concatenated along M (row
// ranges delimited by problem_m_offsets), mat_b holds the per-expert fp8
// weights, and mat_d receives the bf16 output. scales_a / scales_b are the
// block scale factors, reinterpreted to the kernel's load element type.
// The kernel is launched persistently with exactly one thread block per SM.
// NOTE(review): expected_m and max_shape_m_padded are currently unused here;
// the padded scale leading dimension is re-derived from max_shape_m below —
// confirm it matches the padding the caller used when writing scales_a.
void grouped_gemm_dispatch_sm120(__nv_fp8_e4m3* mat_a, __nv_fp8_e4m3* mat_b, __nv_bfloat16* mat_d,
    uint32_t num_problems, int64_t const* problem_m_offsets, uint32_t expected_m, uint32_t max_shape_m,
    uint32_t max_shape_m_padded, uint32_t shape_n, uint32_t shape_k, float* scales_a, float* scales_b,
    cudaStream_t stream, int num_device_sms = kNumDeviceSMs)
{
    // A negative SM count means "not queried yet": fetch it once and cache it
    // in the module-level kNumDeviceSMs for later calls.
    if (num_device_sms < 0)
    {
        num_device_sms = kNumDeviceSMs = tensorrt_llm::common::getMultiProcessorCount();
    }

    int64_t const total_tokens = static_cast<int64_t>(max_shape_m);
    // Padded M for the scale-factor leading dimension. Given
    //   max_shape_m_padded = (max_shape_m + num_problems * 31) / 32 * 32
    //   padded_m           = (total_tokens + num_problems * 3) / 4 * 4
    // padded_m never exceeds max_shape_m_padded.
    int64_t const padded_m = sm120_blockscaled_gemm::compute_padded_offset(max_shape_m, num_problems);

    using KT = sm120_blockscaled_gemm::SM120BlockScaledBuilder<64, 128, 4>;
    using GemmKernel = sm120_blockscaled_gemm::SM120BlockScaledMoeKernel<KT>;
    using Params = typename GemmKernel::Params;
    using Arguments = typename GemmKernel::Arguments;
    using ProblemShape = typename GemmKernel::ProblemShape;

    ProblemShape problem_shape = make_shape(static_cast<int>(total_tokens), static_cast<int>(shape_n),
        static_cast<int>(shape_k), static_cast<int>(num_problems));

    auto a_ptr = reinterpret_cast<typename KT::ElementA*>(mat_a);
    auto b_ptr = reinterpret_cast<typename KT::ElementB*>(mat_b);
    auto sfa_ptr = reinterpret_cast<typename KT::ElementSFLoad*>(scales_a);
    auto sfb_ptr = reinterpret_cast<typename KT::ElementSFLoad*>(scales_b);
    auto d_ptr = reinterpret_cast<typename KT::ElementD*>(mat_d);

    // Row-major leading dimensions and batch strides, kept 64-bit so large
    // token counts cannot overflow the stride arithmetic.
    int64_t const lda = static_cast<int64_t>(shape_k);
    int64_t const ldb = static_cast<int64_t>(shape_k);
    int64_t const ldd = static_cast<int64_t>(shape_n);

    typename KT::StrideA stride_a = make_stride(lda, Int<1>{}, total_tokens * lda);
    typename KT::StrideB stride_b = make_stride(ldb, Int<1>{}, static_cast<int64_t>(shape_n) * ldb);
    typename KT::StrideD stride_d = make_stride(ldd, Int<1>{}, total_tokens * ldd);

    // Scale-factor layouts are deduced from padded shapes: the SFA shape uses
    // a batch of 1 (one concatenated token dimension), SFB one per expert.
    auto sfa_shape = make_shape(static_cast<int>(padded_m), static_cast<int>(shape_n), static_cast<int>(shape_k), 1);
    typename KT::StrideSFA stride_sfa = KT::deduce_sfa_layout(sfa_shape).stride();
    auto sfb_shape = make_shape(static_cast<int>(padded_m), static_cast<int>(shape_n), static_cast<int>(shape_k),
        static_cast<int>(num_problems));
    typename KT::StrideSFB stride_sfb = KT::deduce_sfb_layout(sfb_shape).stride();

    Arguments args = {a_ptr, stride_a, b_ptr, stride_b, sfa_ptr, stride_sfa, sfb_ptr, stride_sfb, d_ptr, stride_d,
        const_cast<int64_t*>(problem_m_offsets)};

    Params kernel_params = GemmKernel::to_underlying_arguments(problem_shape, args);
    auto kernel_ptr = &cutlass::device_kernel<GemmKernel>;

    // The kernel needs more dynamic shared memory than the default cap; opt
    // in before launching and surface any failure immediately.
    cudaFuncSetAttribute(kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, GemmKernel::kSmemSize);
    auto result = cudaGetLastError();
    TLLM_CHECK_WITH_INFO(result == cudaSuccess, "sm120 moe gemm kernel cannot launch: %s", cudaGetErrorString(result));

    // Persistent launch: a grid of num_device_sms blocks, with programmatic
    // stream serialization allowed on the launch.
    cudaLaunchAttribute launch_attrs[1];
    launch_attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    launch_attrs[0].val.programmaticStreamSerializationAllowed = 1;

    cudaLaunchConfig_t launch_config;
    launch_config.gridDim = dim3(num_device_sms, 1, 1);
    launch_config.blockDim = GemmKernel::get_block_shape();
    launch_config.dynamicSmemBytes = GemmKernel::kSmemSize;
    launch_config.stream = stream;
    launch_config.attrs = launch_attrs;
    launch_config.numAttrs = 1;

    cudaLaunchKernelEx(&launch_config, kernel_ptr, kernel_params);

    result = cudaGetLastError();
    TLLM_CHECK_WITH_INFO(result == cudaSuccess, "sm120 moe gemm kernel runtime error: %s", cudaGetErrorString(result));
}

void fp8_grouped_gemm_run(__nv_bfloat16 const* mat_a, __nv_fp8_e4m3* fp8_mat_a, float* scales_a,
__nv_bfloat16 const* mat_b, __nv_fp8_e4m3* fp8_mat_b, float* scales_b, __nv_bfloat16* mat_d,
int64_t const* problem_m_offsets, int num_problems, int64_t expected_m, int64_t max_shape_m,
Expand All @@ -877,6 +952,36 @@ void fp8_grouped_gemm_run(__nv_bfloat16 const* mat_a, __nv_fp8_e4m3* fp8_mat_a,
kNumDeviceSMs = tensorrt_llm::common::getMultiProcessorCount();
}

int arch = tensorrt_llm::common::getSMVersion();

if (arch == 120)
{
if (internal_quantize_a)
{
constexpr int WarpsPerBlock = 4;
int num_k_blocks = div_up(shape_k, 512);
int num_token_blocks = div_up(max_shape_m, static_cast<int64_t>(WarpsPerBlock));
int64_t scale_leading_dim = sm120_blockscaled_gemm::compute_padded_offset(max_shape_m, num_problems);
dim3 grid(num_k_blocks, num_token_blocks);
dim3 block(WarpsPerBlock * 32);
int smem_size = (num_problems + 1) * sizeof(int64_t);
auto scale_kernel
= sm120_blockscaled_gemm::scale_1x128_kernel_sm120<__nv_bfloat16, __nv_fp8_e4m3, WarpsPerBlock>;
cudaFuncSetAttribute(scale_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
scale_kernel<<<grid, block, smem_size, stream>>>(fp8_mat_a, reinterpret_cast<int32_t*>(scales_a), mat_a,
problem_m_offsets, num_problems, shape_k, scale_leading_dim);
}
if (internal_quantize_b)
{
TLLM_CHECK_WITH_INFO(false, "sm120 moe gemm kernel does not support internal_quantize_b");
return;
}

grouped_gemm_dispatch_sm120(fp8_mat_a, fp8_mat_b, mat_d, num_problems, problem_m_offsets, expected_m,
max_shape_m, max_shape_m_padded, shape_n, shape_k, scales_a, scales_b, stream);
return;
}

if (internal_quantize_a)
{
constexpr int NumThreads = 256;
Expand Down Expand Up @@ -998,7 +1103,7 @@ void strided_batch_gemm_dispatch_sm120(__nv_fp8_e4m3* mat_a, int ld_a, int strid
using ElementOutput = cute::bfloat16_t;
using ElementAccum = float;
using ElementBlockScale = int32_t;
using KT = sm120_blockscaled_gemm::SM120BlockScaledBuilder<32, 128>;
using KT = sm120_blockscaled_gemm::SM120BlockScaledBuilder<64, 128, 4>;
using GemmKernel = sm120_blockscaled_gemm::SM120BlockScaledKernel<KT>;
using Params = typename GemmKernel::Params;
using Arguments = typename GemmKernel::Arguments;
Expand All @@ -1011,11 +1116,11 @@ void strided_batch_gemm_dispatch_sm120(__nv_fp8_e4m3* mat_a, int ld_a, int strid
auto ptr_SFB = reinterpret_cast<ElementBlockScale*>(scales_b);
auto ptr_D = reinterpret_cast<ElementOutput*>(mat_d);

typename KT::StrideA dA = make_stride(ld_a, Int<1>{}, stride_a);
typename KT::StrideB dB = make_stride(ld_b, Int<1>{}, stride_b);
typename KT::StrideA dA = make_stride(static_cast<int64_t>(ld_a), Int<1>{}, static_cast<int64_t>(stride_a));
typename KT::StrideB dB = make_stride(static_cast<int64_t>(ld_b), Int<1>{}, static_cast<int64_t>(stride_b));
typename KT::StrideSFA dSFA = KT::deduce_sfa_layout(problem_shape).stride();
typename KT::StrideSFB dSFB = KT::deduce_sfb_layout(problem_shape).stride();
typename KT::StrideD dD = make_stride(ld_d, Int<1>{}, stride_d);
typename KT::StrideD dD = make_stride(static_cast<int64_t>(ld_d), Int<1>{}, static_cast<int64_t>(stride_d));

Arguments args = {ptr_A, dA, ptr_B, dB, ptr_SFA, dSFA, ptr_SFB, dSFB, ptr_D, dD};

Expand All @@ -1031,7 +1136,7 @@ void strided_batch_gemm_dispatch_sm120(__nv_fp8_e4m3* mat_a, int ld_a, int strid
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = 1;

launch_config.gridDim = GemmKernel::get_grid_shape(kernel_params);
launch_config.gridDim = dim3(num_device_sms, 1, 1);
launch_config.blockDim = GemmKernel::get_block_shape();
launch_config.dynamicSmemBytes = GemmKernel::kSmemSize;
launch_config.stream = stream;
Expand Down
Loading