317 changes: 317 additions & 0 deletions .claude/skills/ad-pipeline-failure-pr/SKILL.md

Large diffs are not rendered by default.

40 changes: 40 additions & 0 deletions .github/workflows/model-registry-check.yml
@@ -0,0 +1,40 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

name: Model Registry Check

on:
  pull_request:
    types: [opened, edited, synchronize, reopened]
    paths:
      - examples/auto_deploy/model_registry/models.yaml

jobs:
  validate-model-registry:
    name: Validate AutoDeploy Model Registry
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v6

      - uses: actions/setup-python@v6
        with:
          python-version: "3.12"
          cache: "pip"

      - name: Install validator dependency
        run: python3 -m pip install PyYAML

      - name: Validate model registry
        run: python3 scripts/check_model_registry.py
8 changes: 8 additions & 0 deletions .pre-commit-config.yaml
@@ -1459,6 +1459,14 @@ repos:
files: ".*/auto_deploy/.*"
- repo: local
hooks:
- id: model-registry-check
name: Validate AutoDeploy model registry
entry: python scripts/check_model_registry.py
language: python
additional_dependencies:
- PyYAML
files: ^examples/auto_deploy/model_registry/models\.yaml$
pass_filenames: false
- id: test lists format
name: Check for tabs and multiple spaces in test_lists txt files
entry: ./scripts/format_test_list.py
2 changes: 1 addition & 1 deletion 3rdparty/fetch_content.json
@@ -11,7 +11,7 @@
{
"name": "cutlass",
"git_repository": "https://github.com/NVIDIA/cutlass",
"git_tag": "v4.3.0",
"git_tag": "v4.4.1",
"git_shallow": true,
"source_subdir": "dont-add-this-project-with-add-subdirectory"
},
2 changes: 1 addition & 1 deletion cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h
@@ -288,7 +288,7 @@ class CacheTransceiver : public BaseCacheTransceiver
std::unique_ptr<executor::kv_cache::ConnectionManager> mManager;
std::optional<executor::CacheTransceiverConfig> mCacheTransceiverConfig;
std::vector<std::unique_ptr<kv_cache_manager::CacheTransBufferManager>> mCacheTransBufferManagers;
std::vector<kv_cache_manager::CacheTransBufferManager*> mCacheTransBufferManagerPtrs;
std::vector<BaseTransBufferManager*> mCacheTransBufferManagerPtrs;

rnn_state_manager::RnnStateManager* mRnnStateManager{nullptr};
// TODO(shreyasm): update this to use same container as kv by using base trans buffers instead
9 changes: 9 additions & 0 deletions cpp/include/tensorrt_llm/executor/cacheCommunicator.h
@@ -18,6 +18,8 @@

#include "tensorrt_llm/executor/serialization.h"
#include <atomic>
#include <cstdint>
#include <optional>
#include <vector>

namespace tensorrt_llm::executor::kv_cache
@@ -63,6 +65,13 @@ class Connection
{
return false;
}

virtual void activateBuffer(uint8_t /*kind*/) const {}

[[nodiscard]] virtual std::optional<size_t> getPreAssignedBufferId(uint8_t /*kind*/) const
{
return std::nullopt;
}
};

class ConnectionManager
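The two new Connection hooks default to a no-op and std::nullopt, so existing backends (UCX, MPI) are unaffected; only agent-style backends that pre-register staging buffers need overrides. A minimal sketch of such an override, assuming a hypothetical PreRegisteredConnection that negotiates one buffer slot per kind during setup (illustrative only, not code from this PR):

#include <cstdint>
#include <map>
#include <optional>

// Hypothetical subclass sketch; transport methods (send/recv) are omitted.
class PreRegisteredConnection /* : public Connection */
{
public:
    // Record which buffer kind subsequent transfers on this connection use.
    void activateBuffer(uint8_t kind) const /* override */
    {
        mActiveKind = kind;
    }

    // Report the buffer slot negotiated for this kind during setup, if any.
    [[nodiscard]] std::optional<size_t> getPreAssignedBufferId(uint8_t kind) const /* override */
    {
        auto const it = mBufferIdsByKind.find(kind);
        return it == mBufferIdsByKind.end() ? std::nullopt : std::make_optional(it->second);
    }

private:
    mutable uint8_t mActiveKind{0};             // mutable: the interface methods are const
    std::map<uint8_t, size_t> mBufferIdsByKind; // filled when the peer assigns slots
};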
10 changes: 10 additions & 0 deletions cpp/tensorrt_llm/batch_manager/baseTransBuffer.h
@@ -23,6 +23,7 @@
#include <atomic>
#include <condition_variable>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <mutex>
#include <optional>
@@ -38,6 +39,13 @@ class FabricMemory;
namespace tensorrt_llm::batch_manager
{

enum class BufferKind : uint8_t
{
kKV = 0,
kKV_INDEXER = 1,
kRNN = 2
};

/// @brief Base class for cache transfer buffer management.
/// Handles buffer pool allocation, index assignment, and slicing.
/// Derived classes provide cache-specific size calculations.
@@ -46,6 +54,8 @@ class BaseTransBufferManager
public:
virtual ~BaseTransBufferManager() = default;

[[nodiscard]] virtual BufferKind getBufferKind() const = 0;

/// @brief Assign a buffer index for sending.
/// @return Assigned buffer index, or nullopt if using dynamic buffers.
std::optional<int> assignBufferIndexForSend();
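Since CacheTransceiver now stores KV, KV-indexer, and RNN managers behind one vector of BaseTransBufferManager pointers (see cacheTransceiver.h above), getBufferKind() is what lets callers recover the manager for a particular cache. A minimal dispatch sketch; the findManager helper is an assumption for illustration, not part of this PR:

#include <vector>

// BaseTransBufferManager and BufferKind are the types from baseTransBuffer.h above.
BaseTransBufferManager* findManager(
    std::vector<BaseTransBufferManager*> const& managers, BufferKind kind)
{
    for (auto* manager : managers)
    {
        if (manager != nullptr && manager->getBufferKind() == kind)
        {
            return manager; // kinds are unique per transceiver, first match wins
        }
    }
    return nullptr; // kind not configured, e.g. no RNN state manager present
}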
15 changes: 7 additions & 8 deletions cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp
@@ -539,9 +539,9 @@ void CacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& sessio
"bufferCoverTargetNum:%d pickUpConnections.size():%ld",
bufferTargetNum, targetNum, peerDuplicateHeadFactor, targetInfo.mDupHeadFactor, bufferCoverTargetNum,
pickUpConnections.size());
auto* agentConnnecion
auto const* agentConnection
= dynamic_cast<executor::kv_cache::AgentConnection const*>(connections[pickUpConnections[0]]);
if (agentConnnecion != nullptr)
if (agentConnection != nullptr)
{
TLLM_CHECK_WITH_INFO(bufferCoverTargetNum == bufferTargetNum, "Agent need all buffer pre-allocated");
TLLM_CHECK(onlyUseDynamicBuffer == false);
@@ -792,12 +792,11 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& sess

TLLM_CHECK(blockNum > 0);

auto* agentConnnecion
= dynamic_cast<executor::kv_cache::AgentConnection const*>(connections[pickUpConnections[0]]);
if (agentConnnecion != nullptr)
auto preAssignedKvId
= connections[pickUpConnections[0]]->getPreAssignedBufferId(static_cast<uint8_t>(BufferKind::kKV));
if (preAssignedKvId.has_value())
{
cacheBufferId = agentConnnecion->getCacheBufferId();
TLLM_CHECK(cacheBufferId.has_value());
cacheBufferId = static_cast<int>(*preAssignedKvId);
}
else
{
@@ -811,7 +810,7 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& sess
bufferCoverTargetNum = bufferCoverTargetNumtmp;
remainNoCoverTargetNum = targetNum > bufferCoverTargetNum ? targetNum - bufferCoverTargetNum : 0;

if (agentConnnecion != nullptr)
if (preAssignedKvId.has_value())
{
TLLM_CHECK_WITH_INFO(bufferCoverTargetNum == targetNum, "Agent need buffer pre-allocated");
TLLM_CHECK(onlyUseDynamicBuffer == false);
1 change: 1 addition & 0 deletions cpp/tensorrt_llm/batch_manager/cacheTransBuffer.cpp
@@ -249,6 +249,7 @@ CacheTransBufferManager::CacheTransBufferManager(
: cacheManager->getPrimaryPool(0)->getDataType(),
maxNumTokens)
, mCacheManager{cacheManager}
, mTransferIndexerKCache{transferIndexerKCache}
{
// TODO: FP4 dataSize
TLLM_CHECK(mCacheManager);
6 changes: 6 additions & 0 deletions cpp/tensorrt_llm/batch_manager/cacheTransBuffer.h
@@ -74,12 +74,18 @@ class CacheTransBufferManager : public BaseTransBufferManager
return mCacheManager;
}

[[nodiscard]] BufferKind getBufferKind() const override
{
return mTransferIndexerKCache ? BufferKind::kKV_INDEXER : BufferKind::kKV;
}

private:
/// @brief Compute transfer buffer size from KV cache configuration.
static size_t computeTransferBufferSize(KVCacheManager::BaseKVCacheManager* cacheManager,
std::optional<size_t> maxNumTokens, bool transferIndexerKCache);

KVCacheManager::BaseKVCacheManager* mCacheManager;
bool mTransferIndexerKCache;
};

} // namespace tensorrt_llm::batch_manager::kv_cache_manager
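The RNN-side buffer manager presumably carries the matching override; since that file is not shown in this diff, the following is an assumed sketch of the counterpart:

// Assumed counterpart in RnnCacheTransBufferManager (not shown in this diff):
// the RNN manager reports its own kind so mixed-vector dispatch can tell it apart.
[[nodiscard]] BufferKind getBufferKind() const override
{
    return BufferKind::kRNN;
}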
42 changes: 26 additions & 16 deletions cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp
@@ -185,26 +185,13 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
mCacheTransBufferManagers.push_back(
std::make_unique<kv_cache_manager::CacheTransBufferManager>(cacheManager, maxNumTokens, true));
}
mCacheTransBufferManagerPtrs.clear();
mCacheTransBufferManagerPtrs.reserve(mCacheTransBufferManagers.size());
for (auto& manager : mCacheTransBufferManagers)
{
mCacheTransBufferManagerPtrs.push_back(manager.get());
}

// RNN specific setup
if (mRnnStateManager != nullptr)
{
TLLM_LOG_DEBUG("Setting up RNN cache transfer components.");
TLLM_CHECK(!rnnLayerNumPerPP.empty());

if (backendType.value() == executor::CacheTransceiverConfig::BackendType::NIXL
|| backendType.value() == executor::CacheTransceiverConfig::BackendType::MOONCAKE)
{
TLLM_LOG_ERROR("RNN cache transfer is not supported for NIXL and MOONCAKE yet");
return;
}

mRnnCacheTransBufferManager
= std::make_unique<rnn_state_manager::RnnCacheTransBufferManager>(mRnnStateManager, maxNumTokens);

@@ -218,6 +205,17 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
TLLM_LOG_INFO("RNN cache transfer components initialized.");
}

mCacheTransBufferManagerPtrs.clear();
mCacheTransBufferManagerPtrs.reserve(mCacheTransBufferManagers.size() + (mRnnCacheTransBufferManager ? 1 : 0));
for (auto& manager : mCacheTransBufferManagers)
{
mCacheTransBufferManagerPtrs.push_back(manager.get());
}
if (mRnnCacheTransBufferManager)
{
mCacheTransBufferManagerPtrs.push_back(mRnnCacheTransBufferManager.get());
}

if (backendType.value() == executor::CacheTransceiverConfig::BackendType::UCX)
{
std::lock_guard<std::mutex> lock(mDllMutex);
@@ -239,14 +237,18 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
}
else if (backendType.value() == executor::CacheTransceiverConfig::BackendType::NIXL)
{
auto rnnState
= mCacheState->hasRnnConfig() ? std::make_optional(mCacheState->getRnnCacheState()) : std::nullopt;
mManager = std::make_unique<tensorrt_llm::executor::kv_cache::AgentConnectionManager>(
mCacheTransBufferManagerPtrs, *mCacheState, "nixl");
mCacheTransBufferManagerPtrs, *mCacheState, "nixl", rnnState);
TLLM_LOG_INFO("NIXL Connection Manager created");
}
else if (backendType.value() == executor::CacheTransceiverConfig::BackendType::MOONCAKE)
{
auto rnnState
= mCacheState->hasRnnConfig() ? std::make_optional(mCacheState->getRnnCacheState()) : std::nullopt;
mManager = std::make_unique<tensorrt_llm::executor::kv_cache::AgentConnectionManager>(
mCacheTransBufferManagerPtrs, *mCacheState, "mooncake");
mCacheTransBufferManagerPtrs, *mCacheState, "mooncake", rnnState);
TLLM_LOG_INFO("MOONCAKE Connection Manager created");
}
else if (backendType.value() == executor::CacheTransceiverConfig::BackendType::MPI)
@@ -261,7 +263,15 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
}

auto makeFormatter = [cacheManager, isMLA, this]()
{ return createCacheFormatter(cacheManager, mCacheTransBufferManagerPtrs, isMLA); };
{
std::vector<kv_cache_manager::CacheTransBufferManager*> kvBufferPtrs;
kvBufferPtrs.reserve(mCacheTransBufferManagers.size());
for (auto& mgr : mCacheTransBufferManagers)
{
kvBufferPtrs.push_back(mgr.get());
}
return createCacheFormatter(cacheManager, kvBufferPtrs, isMLA);
};

auto makeRnnFormatter = [this]() -> std::unique_ptr<RnnCacheFormatter>
{
8 changes: 8 additions & 0 deletions cpp/tensorrt_llm/batch_manager/cacheTransferLayer.cpp
@@ -21,6 +21,7 @@
#include "tensorrt_llm/batch_manager/rnnCacheFormatter.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/executor/cache_transmission/agent_utils/connection.h"
#include "tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h"

#include <algorithm>
@@ -95,6 +96,13 @@ void CacheTransferLayer::format(TransferSession& session) const
mKvFormatter->format(session);
if (mRnnFormatter)
{
for (auto const* conn : session.getConnections())
{
if (conn != nullptr)
{
conn->activateBuffer(static_cast<uint8_t>(BufferKind::kRNN));
}
}
mRnnFormatter->format(session);
}
}
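The activation loop above runs per connection before the RNN formatter is invoked; if the consumer side needs the same dance, it could be factored into a shared helper. A speculative sketch (not in this PR; the connection container type is assumed):

// Hypothetical helper shared by the format/unformat paths.
void activateAll(std::vector<executor::kv_cache::Connection const*> const& connections,
    BufferKind kind)
{
    for (auto const* conn : connections)
    {
        if (conn != nullptr)
        {
            conn->activateBuffer(static_cast<uint8_t>(kind));
        }
    }
}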