Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,17 @@ if(ENABLE_UCX)
if(NOT ${ucx_FOUND})
set(ENABLE_UCX 0)
else()
if(DEFINED ENV{GITHUB_MIRROR} AND NOT "$ENV{GITHUB_MIRROR}" STREQUAL "")
if(EXISTS "${3RDPARTY_DIR}/ucxx/fetch_rapids.cmake")
file(READ "${3RDPARTY_DIR}/ucxx/fetch_rapids.cmake" FILE_CONTENTS)
string(
REPLACE "https://raw.githubusercontent.com/rapidsai/rapids-cmake"
"$ENV{GITHUB_MIRROR}/rapidsai/rapids-cmake/raw/refs/heads"
FILE_CONTENTS "${FILE_CONTENTS}")
file(WRITE "${3RDPARTY_DIR}/ucxx/fetch_rapids.cmake" "${FILE_CONTENTS}")
message(WARNING "Replace UCXX fetch_rapids.cmake with internal mirror")
endif()
endif()
# installing ucxx via add_subdirectory results in strange cudart linking
# error, thus using their installation script to isolate the installation
# process until the issue is understood. And always trigger the build so
Expand Down
26 changes: 0 additions & 26 deletions cpp/include/tensorrt_llm/common/assert.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,8 @@

#pragma once

#include "tensorrt_llm/common/stringUtils.h"
#include "tensorrt_llm/common/tllmException.h"

#include <string>

namespace tensorrt_llm::common
{
[[noreturn]] inline void throwRuntimeError(char const* const file, int const line, char const* info)
{
throw TllmException(file, line, fmtstr("[TensorRT-LLM][ERROR] Assertion failed: %s", info).c_str());
}

[[noreturn]] inline void throwRuntimeError(char const* const file, int const line, std::string const& info = "")
{
throw TllmException(file, line, fmtstr("[TensorRT-LLM][ERROR] Assertion failed: %s", info.c_str()).c_str());
}

} // namespace tensorrt_llm::common

class DebugConfig
{
public:
Expand Down Expand Up @@ -86,12 +69,3 @@ class DebugConfig
__FILE__, __LINE__, tensorrt_llm::common::fmtstr(info, ##__VA_ARGS__).c_str()); \
} \
} while (0)

#define TLLM_THROW(...) \
do \
{ \
throw NEW_TLLM_EXCEPTION(__VA_ARGS__); \
} while (0)

#define TLLM_WRAP(ex) \
NEW_TLLM_EXCEPTION("%s: %s", tensorrt_llm::common::TllmException::demangle(typeid(ex).name()).c_str(), ex.what())
21 changes: 21 additions & 0 deletions cpp/include/tensorrt_llm/common/tllmException.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,22 @@

#pragma once

#include "tensorrt_llm/common/stringUtils.h"

#include <array>
#include <cstddef>
#include <stdexcept>
#include <string>

#define TLLM_THROW(...) \
do \
{ \
throw NEW_TLLM_EXCEPTION(__VA_ARGS__); \
} while (0)

#define TLLM_WRAP(ex) \
NEW_TLLM_EXCEPTION("%s: %s", tensorrt_llm::common::TllmException::demangle(typeid(ex).name()).c_str(), ex.what())

#define NEW_TLLM_EXCEPTION(...) \
tensorrt_llm::common::TllmException(__FILE__, __LINE__, tensorrt_llm::common::fmtstr(__VA_ARGS__).c_str())

Expand All @@ -45,4 +56,14 @@ class TllmException : public std::runtime_error
int mNbFrames;
};

[[noreturn]] inline void throwRuntimeError(char const* const file, int const line, char const* info)
{
throw TllmException(file, line, fmtstr("[TensorRT-LLM][ERROR] Assertion failed: %s", info).c_str());
}

[[noreturn]] inline void throwRuntimeError(char const* const file, int const line, std::string const& info = "")
{
throw TllmException(file, line, fmtstr("[TensorRT-LLM][ERROR] Assertion failed: %s", info.c_str()).c_str());
}

} // namespace tensorrt_llm::common
3 changes: 2 additions & 1 deletion cpp/tensorrt_llm/common/workspace.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
namespace tensorrt_llm::common
{

std::uintptr_t constexpr kCudaMemAlign = 128;
// CuBLAS >= 12.9.1 requires 256-byte alignment.
std::uintptr_t constexpr kCudaMemAlign = 256;

inline int8_t* alignPtr(int8_t* ptr, uintptr_t to)
{
Expand Down
38 changes: 27 additions & 11 deletions cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaKernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,9 @@ class TllmGenFmhaKernel
int headDimPerCtaV, int headDimQk, int headDimV, int tileSizeKv, int numTokensPerPage,
int maxNumHeadsQPerKvInCta, bool reuseSmemKForV, bool uses2CtaMma) const
{
TLLM_CHECK_WITH_INFO((headDimPerCtaV >= 32) && (headDimQk >= 32) && (headDimV >= 32) && (headDimPerCtaV <= 2048)
&& (headDimQk <= 2048) && (headDimV <= 2048) && (numTokensPerPage <= 128),
"Expect (32 <= headDim <= 2048) && (numTokensPerPage <= 128), got headDimPerCtaV=%d, headDimQk=%d, "
TLLM_CHECK_WITH_INFO((headDimPerCtaV >= 32) && (headDimQk >= 32) && (headDimV >= 32) && (headDimPerCtaV <= 1024)
&& (headDimQk <= 1024) && (headDimV <= 1024) && (numTokensPerPage <= 128),
"Expect (32 <= headDim <= 1024) && (numTokensPerPage <= 128), got headDimPerCtaV=%d, headDimQk=%d, "
"headDimV=%d, numTokensPerPage=%d",
headDimPerCtaV, headDimQk, headDimV, numTokensPerPage);
TLLM_CHECK_WITH_INFO(maxNumHeadsQPerKvInCta <= 128, "The maxNumHeadsQPerKvInCta <= 128 is required.");
Expand All @@ -115,19 +115,19 @@ class TllmGenFmhaKernel
// Bit 8 - 11: kernelType.
// Bit 12 - 15: tileScheduler.
// Bit 16 - 17: multiCtasKvMode.
// Bit 18 - 24: (headDimPerCtaV >> 5).
// Bit 25 - 31: (headDimQk >> 5).
// Bit 32 - 38: (headDimV >> 5).
// Bit 39 - 40: (tileSizeKv >> 6).
// Bit 41 - 48: numTokensPerPage.
// Bit 18 - 25: (headDimPerCtaV >> 3).
// Bit 26 - 33: (headDimQk >> 3).
// Bit 34 - 41: (headDimV >> 3).
// Bit 42 - 43: (tileSizeKv >> 6).
// Bit 44 - 48: (numTokensPerPage >> 3).
// Bit 49 - 56: maxNumHeadsQPerKvInCta.
// Bit 57 - 57: reuseSmemKForV.
// Bit 58 - 58: uses2CtaMma.
return (static_cast<uint64_t>(qkvLayout) << 0) | (static_cast<uint64_t>(maskType) << 4)
| (static_cast<uint64_t>(kernelType) << 8) | (static_cast<uint64_t>(scheduler) << 12)
| (static_cast<uint64_t>(multiCtasKvMode) << 16) | (static_cast<uint64_t>(headDimPerCtaV >> 5) << 18)
| (static_cast<uint64_t>(headDimQk >> 5) << 25) | (static_cast<uint64_t>(headDimV >> 5) << 32)
| (static_cast<uint64_t>(tileSizeKv >> 6) << 39) | (static_cast<uint64_t>(numTokensPerPage) << 41)
| (static_cast<uint64_t>(multiCtasKvMode) << 16) | (static_cast<uint64_t>(headDimPerCtaV >> 3) << 18)
| (static_cast<uint64_t>(headDimQk >> 3) << 26) | (static_cast<uint64_t>(headDimV >> 3) << 34)
| (static_cast<uint64_t>(tileSizeKv >> 6) << 42) | (static_cast<uint64_t>(numTokensPerPage >> 3) << 44)
| (static_cast<uint64_t>(maxNumHeadsQPerKvInCta) << 49) | (static_cast<uint64_t>(reuseSmemKForV) << 57)
| (static_cast<uint64_t>(uses2CtaMma) << 58);
}
Expand All @@ -142,6 +142,17 @@ class TllmGenFmhaKernel

std::pair<bool, std::string> checkIfKernelExist(RunnerParams const& params) const
{
// Some conditions to check if the kernel is supported.
// This is meant to avoid occupying unnecessary hashId bits.
if (params.mHeadDimQk % 8 != 0 || params.mHeadDimV % 8 != 0)
{
return std::make_pair(false, "HeadDimQk and HeadDimV must be divisible by 8");
}
if (params.mNumTokensPerPage % 8 != 0)
{
return std::make_pair(false, "NumTokensPerPage must be divisible by 8");
}

// The selectKernelParams that might be updated.
SelectKernelParams selectKernelParams{params};
auto [hashId, info] = hashFromRunnerParams(params, selectKernelParams);
Expand Down Expand Up @@ -347,6 +358,11 @@ class TllmGenFmhaKernel
selectKernelParams.mTileScheduler = TileScheduler::Persistent;
// Need to select a different kernel.
selectKernelParams.mSelectNewKernel = true;
// FIXME(perkz): use static scheduler instead as WAR for https://nvbugspro.nvidia.com/bug/5394685.
if (selectKernelParams.mUses2CtaMma)
{
selectKernelParams.mTileScheduler = TileScheduler::Static;
}
}
else if (totalNumCtas < params.mMultiProcessorCount && isMlaGenKernel(params)
&& selectKernelParams.mTileSizeKv == 128 && tensorrt_llm::common::getEnvUseTileSizeKv64ForTrtllmGen())
Expand Down
3 changes: 2 additions & 1 deletion cpp/tests/unit_tests/runtime/decodingLayerWorkspaceTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "tensorrt_llm/runtime/decodingLayerWorkspace.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/workspace.h"
#include <gtest/gtest.h>
#include <random>

Expand Down Expand Up @@ -171,7 +172,7 @@ TEST_P(MirrorInWorkspaceTest, TestMirrorInWorkspaceFunctionality)
requiredWorkspaceSize)
<< "The calculated workspace size cannot possibly be enough to contain all the tensors.";

constexpr std::size_t addressAlignment = 128;
constexpr std::size_t addressAlignment = tensorrt_llm::common::kCudaMemAlign;
constexpr std::size_t numTensors = 3;
constexpr std::size_t maxAlignmentOverhead = numTensors * addressAlignment;
ASSERT_GE(hostTensor1->getSizeInBytes() + hostTensor2->getSizeInBytes() + hostTensor3->getSizeInBytes()
Expand Down
10 changes: 7 additions & 3 deletions docker/Makefile
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Default base image for the docker build as defined in Dockerfile.multi
BASE_IMAGE ?= $(shell grep '^ARG BASE_IMAGE=' Dockerfile.multi | grep -o '=.*' | tr -d '="')
BASE_TAG ?= $(shell grep '^ARG BASE_TAG=' Dockerfile.multi | grep -o '=.*' | tr -d '="')
TRITON_IMAGE ?= $(shell grep '^ARG TRITON_IMAGE=' Dockerfile.multi | grep -o '=.*' | tr -d '="')
TRITON_BASE_TAG ?= $(shell grep '^ARG TRITON_BASE_TAG=' Dockerfile.multi | grep -o '=.*' | tr -d '="')
# Name of the new image
IMAGE_NAME ?= tensorrt_llm
IMAGE_TAG ?= latest
Expand Down Expand Up @@ -80,6 +82,8 @@ endef
--progress $(DOCKER_PROGRESS) \
$(if $(BASE_IMAGE), --build-arg BASE_IMAGE=$(BASE_IMAGE)) \
$(if $(BASE_TAG), --build-arg BASE_TAG=$(BASE_TAG)) \
$(if $(TRITON_IMAGE), --build-arg TRITON_IMAGE=$(TRITON_IMAGE)) \
$(if $(TRITON_BASE_TAG), --build-arg TRITON_BASE_TAG=$(TRITON_BASE_TAG)) \
$(if $(BUILD_WHEEL_ARGS), --build-arg BUILD_WHEEL_ARGS="$(BUILD_WHEEL_ARGS)") \
$(if $(BUILD_WHEEL_SCRIPT), --build-arg BUILD_WHEEL_SCRIPT="$(BUILD_WHEEL_SCRIPT)") \
$(if $(TORCH_INSTALL_TYPE), --build-arg TORCH_INSTALL_TYPE="$(TORCH_INSTALL_TYPE)") \
Expand Down Expand Up @@ -187,16 +191,16 @@ jenkins-aarch64_%: STAGE = tritondevel
jenkins-rockylinux8_%: PYTHON_VERSION_TAG_ID = $(if $(findstring 3.12,${PYTHON_VERSION}),PY312,$(if $(findstring 3.10,${PYTHON_VERSION}),PY310,$(error Unknown PYTHON_VERSION specified)))
jenkins-rockylinux8_%: IMAGE_WITH_TAG = $(shell . ../jenkins/current_image_tags.properties && echo $$LLM_ROCKYLINUX8_${PYTHON_VERSION_TAG_ID}_DOCKER_IMAGE)
jenkins-rockylinux8_%: STAGE = tritondevel
jenkins-rockylinux8_%: BASE_IMAGE = nvidia/cuda
jenkins-rockylinux8_%: BASE_IMAGE = nvcr.io/nvidia/cuda
jenkins-rockylinux8_%: BASE_TAG = 12.9.1-devel-rockylinux8

rockylinux8_%: STAGE = tritondevel
rockylinux8_%: BASE_IMAGE = nvidia/cuda
rockylinux8_%: BASE_IMAGE = nvcr.io/nvidia/cuda
rockylinux8_%: BASE_TAG = 12.9.1-devel-rockylinux8

# For x86_64 and aarch64
ubuntu22_%: STAGE = tritondevel
ubuntu22_%: BASE_IMAGE = nvidia/cuda
ubuntu22_%: BASE_IMAGE = nvcr.io/nvidia/cuda
ubuntu22_%: BASE_TAG = 12.9.1-devel-ubuntu22.04

trtllm_%: STAGE = release
Expand Down
4 changes: 1 addition & 3 deletions docker/common/install_tensorrt.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@ CUDNN_VER="9.10.2.21-1"
# NGC PyTorch 25.06 image uses NCCL 2.27.3, while NCCL 2.27.5 resolves a perf regression issue.
# Use NCCL version 2.27.5 instead.
NCCL_VER="2.27.5-1+cuda12.9"
# NGC PyTorch 25.06 image uses cuBLAS 12.9.1.4, but which leads to failures with MoE Lora (see https://nvbugs/5376270).
# Continue using cuBLAS 12.9.0.13 until this issue is resolved.
CUBLAS_VER="12.9.0.13-1"
CUBLAS_VER="12.9.1.4-1"
# Align with the pre-installed CUDA / NVCC / NVRTC versions from
# https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html
NVRTC_VER="12.9.86-1"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,14 @@ cuda_graph_config:
max_batch_size: 1024
kv_cache_config:
dtype: fp8
use_torch_sampler: true
EOF
```

> Here `use_torch_sampler: true` is added as a temporary workaround (WAR) for an illegal memory access issue that occurs when using the TRT-LLM native sampler.
>
> TODO: Remove this once the issue is resolved.

### Launch the TRT-LLM Server

Below is an example command to launch the TRT-LLM server with the Llama-4-Scout-17B-16E-Instruct-FP8 model from within the container. The command is specifically configured for the 1024/1024 Input/Output Sequence Length test. The explanation of each flag is shown in the “Configs and Parameters” section.
Expand Down
12 changes: 12 additions & 0 deletions examples/llm-api/quickstart_advanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ def add_llm_args(parser):
default=False,
action='store_true',
help='Use piecewise CUDA graph to optimize the model')
parser.add_argument('--apply_chat_template',
default=False,
action='store_true')

# Sampling
parser.add_argument("--max_tokens", type=int, default=64)
Expand Down Expand Up @@ -273,6 +276,15 @@ def main():
prompts = args.prompt if args.prompt else example_prompts

llm, sampling_params = setup_llm(args)
new_prompts = []
if args.apply_chat_template:
for prompt in prompts:
messages = [{"role": "user", "content": f"{prompt}"}]
new_prompts.append(
llm.tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True))
prompts = new_prompts
outputs = llm.generate(prompts, sampling_params)

for i, output in enumerate(outputs):
Expand Down
16 changes: 10 additions & 6 deletions examples/models/core/gpt_oss/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,16 @@ GPT-OSS is a reasoning model with MoE weights quantized with mxfp4. All the othe

## MoE Support Matrix

In MoE, the weights are pre-quantized to mxfp4. The activation can be in either bf16 (Hopper) or mxfp8 (Blackwell), with similar accuracy.

| device | Activation | Weight | Supported moe_backend |
|----------|----------|----------|----------|
| Hopper | bf16 | mxfp4 | **TRITON**, CUTLASS |
| Blackwell | mxfp8 | mxfp4 | CUTLASS, TRTLLM |
In MoE, the weights are pre-quantized to mxfp4. The activation can be in either bf16 (Hopper) or mxfp8 (Blackwell), with similar accuracy. FP8 activation with a per-tensor scaling factor has limited support. Note that with the official mxfp4 checkpoints, the per-tensor scaling factor needs to be calculated dynamically during inference, which may negatively impact performance. The configs in **bold** are the recommended configs for the official checkpoints.

| device | Activation | Weight | Supported moe_backend | MMA|
|----------|----------|----------|----------|----------|
| Hopper | **bf16** | mxfp4 | **TRITON**, CUTLASS | simulated mxfp4, HGMMA |
| Hopper | fp8 | mxfp4 | CUTLASS (not enabled) | simulated mxfp4, QGMMA |
| Blackwell | **mxfp8** | mxfp4 | **CUTLASS, TRTLLM** | UTCQMMA |
| Blackwell | fp8 | mxfp4 | CUTLASS, TRTLLM | UTCQMMA |
| Blackwell | fp8 | mxfp4 | TRITON (experimental) | NA |
| Blackwell | bf16 | mxfp4 | TRTLLM | simulated mxfp4, UTCHMMA |


| moe_backend | TP | EP | AlltoAll |
Expand Down
Loading
Loading