Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
224 changes: 224 additions & 0 deletions cpp/tensorrt_llm/kernels/fusedCatFp8.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
/*
* Copyright (c) 2022-2026, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "fusedCatFp8.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/config.h"
#include "tensorrt_llm/common/cudaUtils.h"

#include <cuda_bf16.h>
#include <cuda_fp8.h>

#include <cfloat>
#include <cmath>
#include <cstdint>

TRTLLM_NAMESPACE_BEGIN

namespace kernels
{

namespace
{

// Compile-time configuration. One warp processes one 128-element row, so each
// of the 32 lanes owns ELEMS_PER_THREAD consecutive elements.
constexpr int HEAD_DIM = 128; // Fixed for DSV3.2 indexer
constexpr int WARP_SIZE = 32; // One warp per row
constexpr int ELEMS_PER_THREAD = 4; // 128 / 32 = 4 elements per thread
constexpr int ROWS_PER_BLOCK = 8; // Process 8 rows per block for occupancy
constexpr float INV_FP8_E4M3_MAX = 1.0f / 448.0f; // 448 = largest finite FP8 E4M3 value
constexpr float MIN_AMAX = 1.0e-12f; // Floor on row amax so the scale is never zero

/// Butterfly (XOR) max reduction across a full warp: after log2(WARP_SIZE)
/// exchange steps every lane holds the warp-wide maximum. Requires all 32
/// lanes of the warp to be active (full participation mask).
__device__ __forceinline__ float warpReduceMax(float val)
{
#pragma unroll
    for (int delta = WARP_SIZE / 2; delta >= 1; delta >>= 1)
    {
        float other = __shfl_xor_sync(0xFFFFFFFF, val, delta);
        val = fmaxf(val, other);
    }
    return val;
}

/// Helper union for vectorized BF16 loads: one 8-byte int2 load aliases four
/// BF16 values, viewed as two __nv_bfloat162 pairs for paired conversion.
/// The source address must be 8-byte aligned.
union BF16x4
{
    int2 vec;
    __nv_bfloat162 bf16x2[2];
};

/// Helper union for vectorized FP8 stores: four 1-byte FP8 values packed into
/// a single uint32_t so the row can be written with one 4-byte store.
union FP8x4
{
    uint32_t u32;
    __nv_fp8_e4m3 fp8[4];
};

/// Fused kernel: cat + FP8 quantization.
///
/// Grid: (ceil(M / ROWS_PER_BLOCK),)
/// Block: (WARP_SIZE * ROWS_PER_BLOCK,) i.e., (256,)
///
/// Each warp handles one row. Within a warp:
/// - Thread t handles elements [4t, 4t+1, 4t+2, 4t+3] of the 128-dim row.
/// - Loads from pe or nope based on element index (vectorized 8-byte loads).
/// - FP8 quantizes with per-row scale (vectorized 4-byte stores).
///
/// Templated on UseUe8m0 to eliminate branch divergence.
template <bool UseUe8m0>
__global__ __launch_bounds__(WARP_SIZE* ROWS_PER_BLOCK) void fusedCatFp8Kernel(__nv_fp8_e4m3* __restrict__ fp8_out,
    float* __restrict__ scale_out, __nv_bfloat16 const* __restrict__ pe, __nv_bfloat16 const* __restrict__ nope,
    int32_t M, int32_t pe_dim, int32_t nope_dim, int32_t pe_row_stride, int32_t nope_row_stride)
{
    int warp_in_block = threadIdx.x / WARP_SIZE;
    int lane = threadIdx.x % WARP_SIZE;
    int row = blockIdx.x * ROWS_PER_BLOCK + warp_in_block;

    // Warp-uniform exit: every lane of a warp computes the same `row`, so either
    // the whole warp returns or none does — the full-mask __shfl_xor_sync inside
    // warpReduceMax below never waits on an exited lane.
    if (row >= M)
    {
        return;
    }

    // ---- Stage 1: Load + Concat (vectorized 8-byte loads) ----
    // pe_dim is guaranteed to be a multiple of ELEMS_PER_THREAD by the host check,
    // so each thread's 4 elements come entirely from pe or entirely from nope.
    // Use branchless pointer selection (compiles to SELP) to avoid warp divergence.
    float v0, v1, v2, v3;
    {
        int base = lane * ELEMS_PER_THREAD;
        // 64-bit row offsets: row * stride can exceed int32 for large inputs.
        __nv_bfloat16 const* pe_row = pe + static_cast<int64_t>(row) * pe_row_stride;
        __nv_bfloat16 const* nope_row = nope + static_cast<int64_t>(row) * nope_row_stride;

        bool from_pe = (base < pe_dim);
        __nv_bfloat16 const* src = from_pe ? pe_row : nope_row;
        int col = from_pe ? base : (base - pe_dim);

        BF16x4 loaded;
        loaded.vec = *reinterpret_cast<int2 const*>(src + col);

        float2 f0 = __bfloat1622float2(loaded.bf16x2[0]);
        float2 f1 = __bfloat1622float2(loaded.bf16x2[1]);
        v0 = f0.x;
        v1 = f0.y;
        v2 = f1.x;
        v3 = f1.y;
    }

    // ---- Stage 2: FP8 Quantization (1x128 block = entire row) ----
    float local_max = fmaxf(fmaxf(fabsf(v0), fabsf(v1)), fmaxf(fabsf(v2), fabsf(v3)));
    float amax = warpReduceMax(local_max);
    // Clamp to MIN_AMAX so an all-zero row still yields a finite, non-zero scale.
    amax = fmaxf(amax, MIN_AMAX);

    float scale;
    if constexpr (UseUe8m0)
    {
        // UE8M0: scale = 2^ceil(log2(amax / FP8_MAX)) via IEEE 754 bit manipulation.
        // This replaces ceilf(log2f(...)) + exp2f(...) with integer ops.
        float ratio = amax * INV_FP8_E4M3_MAX;
        uint32_t bits = __float_as_uint(ratio);
        uint32_t mantissa = bits & 0x007FFFFFu;
        uint32_t exp_bits = bits & 0x7F800000u;
        // If mantissa is non-zero, round exponent up to next power of 2
        if (mantissa != 0u)
        {
            exp_bits += 0x00800000u;
        }
        scale = __uint_as_float(exp_bits);
    }
    else
    {
        scale = amax * INV_FP8_E4M3_MAX;
    }

    // Use hardware approximate reciprocal (MUFU.RCP, ~2^-23 relative error).
    // This is more than sufficient for FP8 E4M3 quantization (3 mantissa bits).
    // Avoids the expensive Newton-Raphson refinement of __frcp_rn.
    float inv_scale;
    asm("rcp.approx.ftz.f32 %0, %1;" : "=f"(inv_scale) : "f"(scale));

    // Quantize to FP8 — the __nv_fp8_e4m3 conversion saturates, and
    // |val/scale| <= amax/scale <= FP8_MAX by construction, so no explicit
    // clamp is needed beyond rounding edge cases the conversion absorbs.
    auto quantize = [&](float val) -> __nv_fp8_e4m3
    {
        float scaled = val * inv_scale;
        return __nv_fp8_e4m3(scaled);
    };

    // ---- Stage 3: Store (vectorized 4-byte FP8 store) ----
    FP8x4 packed;
    packed.fp8[0] = quantize(v0);
    packed.fp8[1] = quantize(v1);
    packed.fp8[2] = quantize(v2);
    packed.fp8[3] = quantize(v3);

    // BUGFIX: use 64-bit output indexing. row * HEAD_DIM overflows int32 once
    // M exceeds 2^31 / 128 (~16.8M rows), which large batch*seq products can
    // reach; the input-side row offsets above already use int64.
    int64_t base_out = static_cast<int64_t>(row) * HEAD_DIM + lane * ELEMS_PER_THREAD;
    *reinterpret_cast<uint32_t*>(fp8_out + base_out) = packed.u32;

    // One scale per row, written by lane 0 only.
    if (lane == 0)
    {
        scale_out[row] = scale;
    }
}

} // anonymous namespace

/// Host launcher: validates layout preconditions, then dispatches the
/// templated kernel (one warp per row, ROWS_PER_BLOCK rows per CTA) on the
/// caller-provided stream. No-op when M == 0.
void invokeFusedCatFp8(__nv_fp8_e4m3* fp8_out, float* scale_out, __nv_bfloat16 const* pe, __nv_bfloat16 const* nope,
    int32_t M, int32_t pe_dim, int32_t nope_dim, int32_t head_dim, int32_t pe_row_stride, int32_t nope_row_stride,
    bool use_ue8m0, cudaStream_t stream)
{
    // Nothing to do for empty input.
    if (M == 0)
    {
        return;
    }

    // Layout preconditions required by the kernel's vectorized access pattern.
    TLLM_CHECK_WITH_INFO(head_dim == HEAD_DIM, "fusedCatFp8: head_dim must be 128, got %d", head_dim);
    TLLM_CHECK_WITH_INFO(pe_dim + nope_dim == head_dim, "fusedCatFp8: pe_dim (%d) + nope_dim (%d) != head_dim (%d)",
        pe_dim, nope_dim, head_dim);
    TLLM_CHECK_WITH_INFO((head_dim & (head_dim - 1)) == 0, "fusedCatFp8: head_dim must be power of 2");
    TLLM_CHECK_WITH_INFO(pe_dim % ELEMS_PER_THREAD == 0,
        "fusedCatFp8: pe_dim (%d) must be a multiple of %d for vectorized access", pe_dim, ELEMS_PER_THREAD);
    TLLM_CHECK_WITH_INFO(
        pe_row_stride >= pe_dim, "fusedCatFp8: pe_row_stride (%d) must be >= pe_dim (%d)", pe_row_stride, pe_dim);
    TLLM_CHECK_WITH_INFO(nope_row_stride >= nope_dim, "fusedCatFp8: nope_row_stride (%d) must be >= nope_dim (%d)",
        nope_row_stride, nope_dim);
    TLLM_CHECK_WITH_INFO(pe_row_stride % ELEMS_PER_THREAD == 0,
        "fusedCatFp8: pe_row_stride (%d) must be a multiple of %d for aligned vectorized access", pe_row_stride,
        ELEMS_PER_THREAD);
    TLLM_CHECK_WITH_INFO(nope_row_stride % ELEMS_PER_THREAD == 0,
        "fusedCatFp8: nope_row_stride (%d) must be a multiple of %d for aligned vectorized access", nope_row_stride,
        ELEMS_PER_THREAD);

    // Select the instantiation up front; both share one launch site.
    auto* kernelFn = use_ue8m0 ? fusedCatFp8Kernel<true> : fusedCatFp8Kernel<false>;

    // One warp per row; ceil-div rows over ROWS_PER_BLOCK warps per block.
    int const rowBlocks = (M + ROWS_PER_BLOCK - 1) / ROWS_PER_BLOCK;
    dim3 const gridDims(rowBlocks);
    dim3 const blockDims(WARP_SIZE * ROWS_PER_BLOCK); // 256 threads per block

    kernelFn<<<gridDims, blockDims, 0, stream>>>(
        fp8_out, scale_out, pe, nope, M, pe_dim, nope_dim, pe_row_stride, nope_row_stride);

    TLLM_CUDA_CHECK(cudaGetLastError());
}

} // namespace kernels

TRTLLM_NAMESPACE_END
63 changes: 63 additions & 0 deletions cpp/tensorrt_llm/kernels/fusedCatFp8.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Copyright (c) 2022-2026, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "tensorrt_llm/common/config.h"
#include "tensorrt_llm/common/cudaUtils.h"

#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>

TRTLLM_NAMESPACE_BEGIN

namespace kernels
{

/// Fused concat + FP8 1x128 quantization.
///
/// Given two BF16 input matrices `pe` [M, pe_dim] and `nope` [M, nope_dim],
/// this kernel concatenates them along the last dimension (pe first, nope second),
/// then quantizes each row to FP8 E4M3 with one scale factor per row.
///
/// Inputs need not be fully contiguous — only the innermost dimension must be
/// contiguous (stride 1). The row stride for each input is provided explicitly
/// via pe_row_stride / nope_row_stride, which allows processing non-contiguous
/// views (e.g. from torch.split()) without a prior contiguous copy. Rows are
/// addressed as base + row * row_stride, so all rows must share one stride.
///
/// @param fp8_out Output FP8 data [M, head_dim], row-major.
/// @param scale_out Output scales [M, 1], float32. When use_ue8m0 is true,
/// the scale is a power of two (UE8M0-representable), still
/// stored as an ordinary float32.
/// @param pe Input PE part, BF16. Each row has pe_dim contiguous elements.
/// The base pointer must be 8-byte aligned (vectorized loads).
/// @param nope Input non-PE part, BF16. Each row has nope_dim contiguous
/// elements. Base pointer must be 8-byte aligned.
/// @param M Number of rows (product of all dims except the last).
/// @param pe_dim Dimension of PE input (must satisfy pe_dim + nope_dim == head_dim,
/// and be a multiple of 4).
/// @param nope_dim Dimension of non-PE input.
/// @param head_dim Total head dimension (must be 128).
/// @param pe_row_stride Stride (in elements) between consecutive rows of pe.
/// For contiguous layout this equals pe_dim; for non-contiguous
/// views (e.g. from torch.split) it may be larger.
/// @param nope_row_stride Stride (in elements) between consecutive rows of nope.
/// @param use_ue8m0 If true, use UE8M0 (power-of-two) scale format.
/// @param stream CUDA stream (defaults to the legacy default stream).
void invokeFusedCatFp8(__nv_fp8_e4m3* fp8_out, float* scale_out, __nv_bfloat16 const* pe, __nv_bfloat16 const* nope,
    int32_t M, int32_t pe_dim, int32_t nope_dim, int32_t head_dim, int32_t pe_row_stride, int32_t nope_row_stride,
    bool use_ue8m0, cudaStream_t stream = 0);

} // namespace kernels

TRTLLM_NAMESPACE_END
3 changes: 2 additions & 1 deletion cpp/tensorrt_llm/thop/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION &
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION &
# AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
Expand Down Expand Up @@ -88,6 +88,7 @@ add_library(
fp8PerTensorScaleMoe.cpp
fp4BlockScaleMoe.cpp
noAuxTcOp.cpp
fusedCatFp8Op.cpp
IndexerKCacheScatterOp.cpp
IndexerTopKOp.cpp
ncclCommunicatorOp.cpp
Expand Down
87 changes: 87 additions & 0 deletions cpp/tensorrt_llm/thop/fusedCatFp8Op.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* Copyright (c) 2022-2026, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "tensorrt_llm/kernels/fusedCatFp8.h"
#include "tensorrt_llm/thop/thUtils.h"

#include <ATen/cuda/EmptyTensor.h>

#include <cstdint>
#include <limits>

TRTLLM_NAMESPACE_BEGIN

namespace torch_ext
{

/// Torch binding for the fused cat + FP8 quantization kernel.
///
/// @param pe BF16 tensor [..., pe_dim], innermost dim contiguous.
/// @param nope BF16 tensor [..., nope_dim], innermost dim contiguous, same
///             number of rows as pe. pe_dim + nope_dim must equal 128.
/// @param use_ue8m0 If true, per-row scales use the UE8M0 (power-of-two) format.
/// @return (fp8_out [M, 128] Float8_e4m3fn, scale_out [M, 1] float32).
std::tuple<at::Tensor, at::Tensor> fused_cat_fp8(at::Tensor const& pe, at::Tensor const& nope, bool use_ue8m0)
{
    CHECK_TH_CUDA(pe);
    CHECK_TH_CUDA(nope);

    TORCH_CHECK(pe.scalar_type() == at::ScalarType::BFloat16, "pe must be BF16, got ", pe.scalar_type());
    TORCH_CHECK(nope.scalar_type() == at::ScalarType::BFloat16, "nope must be BF16, got ", nope.scalar_type());
    TORCH_CHECK(pe.dim() >= 2, "pe must be >= 2D, got ", pe.dim(), "D");
    TORCH_CHECK(nope.dim() >= 2, "nope must be >= 2D, got ", nope.dim(), "D");

    // Innermost dimension must be contiguous for vectorized loads.
    TORCH_CHECK(pe.stride(-1) == 1, "pe must have contiguous innermost dim (stride(-1)==1), got ", pe.stride(-1));
    TORCH_CHECK(nope.stride(-1) == 1, "nope must have contiguous innermost dim (stride(-1)==1), got ", nope.stride(-1));

    // BUGFIX: the kernel issues 8-byte (int2) vectorized loads, so each base
    // pointer must be 8-byte aligned. A view sliced at an offset that is not a
    // multiple of 4 BF16 elements would otherwise fault with a misaligned
    // address inside the kernel.
    TORCH_CHECK(reinterpret_cast<std::uintptr_t>(pe.data_ptr()) % 8 == 0,
        "pe data pointer must be 8-byte aligned for vectorized loads");
    TORCH_CHECK(reinterpret_cast<std::uintptr_t>(nope.data_ptr()) % 8 == 0,
        "nope data pointer must be 8-byte aligned for vectorized loads");

    auto const pe_dim = static_cast<int32_t>(pe.size(-1));
    auto const nope_dim = static_cast<int32_t>(nope.size(-1));
    auto const head_dim = pe_dim + nope_dim;

    TORCH_CHECK(head_dim == 128, "head_dim (pe_dim + nope_dim) must be 128, got ", head_dim);

    // M = product of all dimensions except the last (handles 2D, 3D, etc.)
    auto const pe_M = pe.numel() / pe_dim;
    auto const nope_M = nope.numel() / nope_dim;
    TORCH_CHECK(pe_M == nope_M, "pe and nope must have same number of rows. pe: ", pe_M, ", nope: ", nope_M);
    // BUGFIX: guard against silent truncation — the kernel takes M as int32.
    TORCH_CHECK(pe_M <= std::numeric_limits<int32_t>::max(), "number of rows (", pe_M, ") exceeds int32 range");
    auto const M = static_cast<int32_t>(pe_M);

    // BUGFIX: the kernel addresses row r as base + r * stride(-2) after
    // collapsing all leading dims into M rows. That is only valid when the
    // leading dims are contiguous with respect to each other; otherwise a
    // >2D non-contiguous view would be read from the wrong addresses.
    auto check_collapsible = [](at::Tensor const& t, char const* name)
    {
        for (int64_t i = 0; i + 2 < t.dim(); ++i)
        {
            TORCH_CHECK(t.stride(i) == t.stride(i + 1) * t.size(i + 1), name,
                " leading dims must be collapsible into one row dim: stride(", i, ") != stride(", i + 1,
                ") * size(", i + 1, ")");
        }
    };
    check_collapsible(pe, "pe");
    check_collapsible(nope, "nope");

    // Extract row strides — stride of the second-to-last dimension.
    // For contiguous [M, pe_dim], stride(-2) == pe_dim (same as before).
    // For non-contiguous views from split(), stride(-2) may be larger (e.g. head_dim).
    auto const pe_row_stride = static_cast<int32_t>(pe.stride(-2));
    auto const nope_row_stride = static_cast<int32_t>(nope.stride(-2));

    // Allocate output tensors (default contiguous layout).
    at::Tensor fp8_out
        = at::detail::empty_cuda({M, head_dim}, at::ScalarType::Float8_e4m3fn, pe.device(), /* stride */ std::nullopt);
    at::Tensor scale_out
        = at::detail::empty_cuda({M, 1}, at::ScalarType::Float, pe.device(), /* stride */ std::nullopt);

    auto stream = at::cuda::getCurrentCUDAStream(pe.get_device());

    tensorrt_llm::kernels::invokeFusedCatFp8(reinterpret_cast<__nv_fp8_e4m3*>(fp8_out.data_ptr()),
        reinterpret_cast<float*>(scale_out.data_ptr()), reinterpret_cast<__nv_bfloat16 const*>(pe.data_ptr()),
        reinterpret_cast<__nv_bfloat16 const*>(nope.data_ptr()), M, pe_dim, nope_dim, head_dim, pe_row_stride,
        nope_row_stride, use_ue8m0, stream);

    return {fp8_out, scale_out};
}

} // namespace torch_ext

TRTLLM_NAMESPACE_END

// Schema registration for the `trtllm` namespace; `use_ue8m0=False` keeps the
// UE8M0 scale format opt-in for callers.
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def("fused_cat_fp8(Tensor pe, Tensor nope, bool use_ue8m0=False) -> (Tensor, Tensor)");
}

// CUDA-backend implementation binding for trtllm::fused_cat_fp8.
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("fused_cat_fp8", &tensorrt_llm::torch_ext::fused_cat_fp8);
}
Loading
Loading