Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions constraints.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
# These vulnerabilities were inherited from the base image (pytorch:25.12-py3) and should be removed when the base image
# is updated.
# WAR against https://github.com/advisories/GHSA-38jv-5279-wg99
urllib3>=2.6.3
# WAR against https://github.com/advisories/GHSA-8rrh-rw8j-w5fx
wheel>=0.46.2
# WAR against https://github.com/advisories/GHSA-7gcm-g887-7qv7
protobuf>=6.33.5
# WAR against https://github.com/advisories/GHSA-6mq8-rvhq-8wgg
aiohttp>=3.13.3
# WAR against https://github.com/advisories/GHSA-qjxf-f2mg-c6mc
tornado>=6.5.5
# WAR against https://github.com/advisories/GHSA-3936-cmfr-pm3m
black>=26.3.1
13 changes: 13 additions & 0 deletions cpp/include/tensorrt_llm/common/cudaUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -1424,3 +1424,16 @@ TRTLLM_NAMESPACE_END
{ \
tensorrt_llm::common::checkEx((stat), {cudaSuccess, cudaErrorCudartUnloading}, #stat, __FILE__, __LINE__); \
} while (0)

// Warn-only variant: log a warning on failure but do not throw or abort.
// Use for cleanup/secondary operations where a CUDA error is non-fatal (e.g. free on an error path).
#define TLLM_CUDA_CHECK_WARN(stat)                                                                                     \
    do                                                                                                                 \
    {                                                                                                                  \
        /* Evaluate once; warn (never throw/abort) so this is safe on cleanup paths. */                                \
        if (cudaError_t const _tllm_cuda_warn_status = (stat); TLLM_UNLIKELY(_tllm_cuda_warn_status != cudaSuccess))   \
        {                                                                                                              \
            TLLM_LOG_WARNING(                                                                                          \
                "CUDA error in %s (%s:%d): %s", #stat, __FILE__, __LINE__, cudaGetErrorString(_tllm_cuda_warn_status)); \
        }                                                                                                              \
    } while (0)
122 changes: 108 additions & 14 deletions cpp/tensorrt_llm/common/ncclUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,69 @@
#include <limits>
#include <stdexcept>

namespace
{

// RAII guard for cudaMalloc — frees the pointer on destruction, logging a warning on failure.
struct CudaMallocGuard
{
    void* ptr{nullptr};

    // Non-copyable: exactly one owner per allocation.
    CudaMallocGuard(CudaMallocGuard const&) = delete;
    CudaMallocGuard& operator=(CudaMallocGuard const&) = delete;

    explicit CudaMallocGuard(void* owned) noexcept
        : ptr{owned}
    {
    }

    // Detach ownership without freeing; returns the held pointer (may be nullptr).
    void* release() noexcept
    {
        void* detached = ptr;
        ptr = nullptr;
        return detached;
    }

    ~CudaMallocGuard()
    {
        // Warn-only free: a destructor must not throw, and a failed free on an
        // error path is non-fatal.
        if (ptr != nullptr)
        {
            TLLM_CUDA_CHECK_WARN(cudaFree(ptr));
        }
    }
};

// RAII guard for ncclMemAlloc — frees the pointer on destruction, logging a warning on failure.
struct NcclMemGuard
{
    void* ptr{nullptr};

    // Non-copyable: exactly one owner per allocation.
    NcclMemGuard(NcclMemGuard const&) = delete;
    NcclMemGuard& operator=(NcclMemGuard const&) = delete;

    explicit NcclMemGuard(void* owned) noexcept
        : ptr{owned}
    {
    }

    // Detach ownership without freeing; returns the held pointer (may be nullptr).
    void* release() noexcept
    {
        void* detached = ptr;
        ptr = nullptr;
        return detached;
    }

    ~NcclMemGuard()
    {
        // Warn-only free: a destructor must not throw, and a failed free on an
        // error path is non-fatal.
        if (ptr != nullptr)
        {
            TLLM_NCCL_CHECK_WARN(ncclMemFree(ptr));
        }
    }
};

} // namespace

namespace tensorrt_llm::common::nccl_util
{

Expand Down Expand Up @@ -403,28 +466,59 @@ bool NCCLWindowAllocator::isCommValid(ncclComm_t comm) const noexcept

/**
 * @brief Allocate symmetric device memory and register it as an NCCL window.
 *
 * ncclMemAlloc is per-rank and non-collective, so it can fail on some ranks
 * only; ncclCommWindowRegister is collective, so if any rank skips it the
 * remaining ranks hang. The local allocation status is therefore min-allreduced
 * across ranks first: when any rank failed, every rank frees its local
 * allocation and returns an empty buffer so the caller can fall back to a
 * non-window path.
 *
 * @param comm NCCL communicator the allreduce and window registration run on
 * @param size Number of bytes of symmetric memory to allocate
 * @param handle Caller-chosen identifier stored in the returned buffer
 * @return Populated NCCLWindowBuffer on success; default-constructed (empty)
 *         buffer when allocation failed on any rank or registration failed.
 */
NCCLWindowBuffer NCCLWindowAllocator::allocateAndRegisterBuffer(ncclComm_t comm, size_t size, int handle)
{
    // Step 1: Allocate symmetric memory (per-rank, non-collective — can fail asymmetrically).
    void* ncclPtr = nullptr;
    TLLM_NCCL_CHECK_WARN(ncclMemAlloc(&ncclPtr, size));
    int const localAllocOk = (ncclPtr != nullptr) ? 1 : 0;
    NcclMemGuard ncclGuard{ncclPtr}; // frees ncclPtr on any early return or exception

    // Step 2: ncclCommWindowRegister is collective — if any rank skips it, all other ranks hang.
    // Synchronize the per-rank alloc status using a small cudaMalloc flag (not ncclMemAlloc, so
    // OOM on symmetric memory does not prevent us from allocating the flag).
    int* rankSyncFlag = nullptr;
    TLLM_CUDA_CHECK(cudaMalloc(&rankSyncFlag, sizeof(int)));
    CudaMallocGuard flagGuard{rankSyncFlag}; // frees rankSyncFlag on any early return or exception

    // Step 3: Populate flag, reduce with min across ranks (0 if any rank failed), then read back.
    // H2D failure is non-fatal: warn and continue — device flag may be stale but the allreduce
    // must still be reached by all ranks. allreduce and D2H failures are catastrophic (throw).
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    TLLM_CUDA_CHECK_WARN(cudaMemcpy(rankSyncFlag, &localAllocOk, sizeof(int), cudaMemcpyHostToDevice));
    TLLM_NCCL_CHECK(ncclAllReduce(rankSyncFlag, rankSyncFlag, 1, ncclInt32, ncclMin, comm, stream));
    TLLM_CUDA_CHECK_WARN(cudaStreamSynchronize(stream));

    int allAllocOk = 0;
    TLLM_CUDA_CHECK(cudaMemcpy(&allAllocOk, rankSyncFlag, sizeof(int), cudaMemcpyDeviceToHost));

    if (!allAllocOk)
    {
        if (localAllocOk)
        {
            TLLM_LOG_WARNING(
                "[NCCLUtil] ncclMemAlloc failed on at least one other rank; "
                "freeing local allocation (size=%zu) and aborting window registration on all ranks.",
                size);
        }
        return NCCLWindowBuffer{}; // ncclGuard frees ncclPtr
    }

    // Step 4: Register with NCCL as a window (collective — all ranks must reach this call).
    // Failure here is non-fatal: warn and fall back to regular allreduce.
    // ncclGuard frees ncclPtr on return.
    ncclWindow_t window = nullptr;
    ncclResult_t const regResult = ncclCommWindowRegister(comm, ncclPtr, size, &window, NCCL_WIN_COLL_SYMMETRIC);
    TLLM_NCCL_CHECK_WARN(regResult);
    if (regResult != ncclSuccess)
    {
        return NCCLWindowBuffer{};
    }

    // Step 5: Success — transfer ownership to the returned buffer.
    ncclGuard.release();
    NCCLWindowBuffer buffer{ncclPtr, handle, size, window};
    TLLM_LOG_TRACE("[NCCLUtil] Allocated and registered NCCL window buffer: handle=%d, ptr=%p, size=%zu, window=%p",
        handle, buffer.ptr, buffer.size, static_cast<void*>(buffer.window));
    return buffer;
}

Expand Down
16 changes: 16 additions & 0 deletions cpp/tensorrt_llm/common/ncclUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "tensorrt_llm/common/config.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/runtime/utils/multiDeviceUtils.h"

#if ENABLE_MULTI_DEVICE
#include <ATen/cuda/CUDAContext.h>
Expand All @@ -43,6 +44,21 @@

#if ENABLE_MULTI_DEVICE

// TLLM_NCCL_CHECK (throw on failure) is provided by multiDeviceUtils.h.

// Warn-only variant: log a warning on NCCL failure but do not throw or abort.
// Use for cleanup/secondary operations where an NCCL error is non-fatal (e.g. ncclMemFree on an error path).
#define TLLM_NCCL_CHECK_WARN(cmd)                                                                                      \
    do                                                                                                                 \
    {                                                                                                                  \
        /* Evaluate once; warn (never throw/abort) so this is safe on cleanup paths. */                                \
        if (ncclResult_t const _tllm_nccl_warn_res = (cmd); TLLM_UNLIKELY(_tllm_nccl_warn_res != ncclSuccess))         \
        {                                                                                                              \
            TLLM_LOG_WARNING(                                                                                          \
                "NCCL error in %s (%s:%d): %s", #cmd, __FILE__, __LINE__, ncclGetErrorString(_tllm_nccl_warn_res));    \
        }                                                                                                              \
    } while (0)

TRTLLM_NAMESPACE_BEGIN

namespace common::nccl_util
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -375,6 +375,84 @@ struct SuffixAutomaton
return SAOptional<LookupResult>();
}

/**
* @brief Find the longest suffix of an external token sequence that appears
* as a substring in this SA's text, then return its continuation position.
*
* Uses the standard longest-common-substring algorithm: process suffix tokens
* in forward order through the SA, following suffix links on mismatch.
*
* Time complexity: O(suffixLen) amortized. Each token either advances the
* match or triggers suffix link fallbacks. Since matchedLen increases at most
* suffixLen times and never goes below 0, total suffix link hops is bounded
* by suffixLen.
*
* @param suffix Pointer to the suffix tokens (forward order: oldest to newest)
* @param suffixLen Number of tokens in the suffix
* @return Optional LookupResult with continuation position and match length
*/
SA_CUDA_CALLABLE SAOptional<LookupResult> lookupWithSuffix(Token const* suffix, int suffixLen) const
{
    // Empty automaton or empty query: nothing can match.
    if (mStates.empty() || suffixLen <= 0)
    {
        return SAOptional<LookupResult>();
    }

    NodeIndex state = NodeIndex(0); // start at the root state (NodeIndex(0))
    int matchedLen = 0;             // length of the current match ending at suffix[i]

    for (int i = 0; i < suffixLen; i++)
    {
        Token token = suffix[i];

        // On mismatch, follow suffix links to progressively shorter suffixes of
        // the current match until a state with a `token` transition is found,
        // or the root is reached.
        while (state != NodeIndex(0) && mStates.at(state, token) == nullptr)
        {
            state = *mStates.at(state).link;
            // After falling back, the match is exactly the fallback state's len.
            matchedLen = mStates.at(state).len;
        }

        NodeIndex const* nextPtr = mStates.at(state, token);
        if (nextPtr != nullptr)
        {
            // Extend the match by one token.
            state = *nextPtr;
            matchedLen++;
        }
        // else: at the root with no `token` transition — the match restarts
        // from scratch at the next suffix token.
    }

    // No non-empty match ending at the last suffix token.
    if (matchedLen == 0 || state == NodeIndex(0))
    {
        return SAOptional<LookupResult>();
    }

    // Walk suffix links upward until a state with a recorded end position is
    // found, shortening matchedLen accordingly; return the position just after
    // the matched occurrence as the continuation point.
    while (state != NodeIndex(0))
    {
        auto& nodeData = mStates.at(state);
        SAOptional<TextIndex> posOpt = nodeData.pos;

        if (posOpt.hasValue())
        {
            TextIndex pos = *posOpt;
            // Only usable when at least one token exists after the match end
            // to serve as the continuation.
            if (+pos + 1 < +mTokens.size())
            {
                LookupResult result;
                result.pos = TextIndex(+pos + 1);
                result.len = matchedLen;
                return SAOptional<LookupResult>(result);
            }
        }

        auto linkOpt = nodeData.link;
        if (!linkOpt.hasValue())
        {
            break;
        }
        state = *linkOpt;
        // The linked state represents a shorter match of exactly its len.
        matchedLen = mStates.at(state).len;
    }

    return SAOptional<LookupResult>();
}

SA_CUDA_CALLABLE void getDraftTokens(Token::ValueType* buf, int bufLen, TextIndex startPos) const
{
int availableLen = +mTokens.size() - +startPos;
Expand Down
Loading
Loading