4 changes: 3 additions & 1 deletion cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
@@ -673,7 +673,9 @@ class KVCacheBlockPool

// When true, pool tensor is laid out as {numLayers, numBlocks, kvFactor, blockSize}
// instead of the default {numBlocks, numLayers, kvFactor, blockSize}.
// Used for recurrent state (linear attention) pools.
// This path is currently disabled (always false) but the supporting code
// is kept in place so layer-first layout can be re-enabled for recurrent
// state (linear attention) pools in the future. See allocatePools().
bool layerFirstLayout;

KVCacheBlockPool(SizeType32 numLayers, SizeType32 kvFactor, SizeType32 numKvHeads, SizeType32 sizePerHead,
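The two layouts described in the comment above differ only in where a given (block, layer, field) triple lands. A minimal standalone sketch of that index arithmetic, using hypothetical sizes rather than anything from this PR:

```cpp
#include <cstdint>
#include <cstdio>

int main()
{
    // Hypothetical pool dimensions, not taken from the PR.
    int64_t const numBlocks = 4, numLayers = 3, kvFactor = 2, blockSize = 8;

    // Element offset of the first element of (blockIdx, layerIdx, fieldIdx) under each layout.
    auto blockFirst = [&](int64_t b, int64_t l, int64_t f) // {numBlocks, numLayers, kvFactor, blockSize}
    { return ((b * numLayers + l) * kvFactor + f) * blockSize; };
    auto layerFirst = [&](int64_t b, int64_t l, int64_t f) // {numLayers, numBlocks, kvFactor, blockSize}
    { return ((l * numBlocks + b) * kvFactor + f) * blockSize; };

    // In the default block-first layout one block is contiguous across all of its layers;
    // in the layer-first layout the same block is strided by numBlocks * kvFactor * blockSize per layer.
    std::printf("block-first offset: %lld, layer-first offset: %lld\n",
        static_cast<long long>(blockFirst(1, 2, 0)), static_cast<long long>(layerFirst(1, 2, 0)));
    return 0;
}
```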
18 changes: 14 additions & 4 deletions cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h
@@ -107,9 +107,12 @@ class BlockRange
int32_t indexFromEnd, SizeType32 windowSize)
{
int32_t const numBlocksToCollect = indexFromEnd + 1;
TLLM_CHECK_WITH_INFO(numBlocksToCollect >= 0,
"fromReuseTree: numBlocksToCollect=%d (indexFromEnd=%d) would underflow reserve() to a huge size_t",
numBlocksToCollect, indexFromEnd);

std::vector<SizeType32> blockIds;
blockIds.reserve(numBlocksToCollect);
blockIds.reserve(static_cast<size_t>(numBlocksToCollect));
for (int32_t i = 0; i < numBlocksToCollect; ++i)
{
TLLM_CHECK_WITH_INFO(
@@ -316,9 +319,16 @@ class BlockIterator
{
BlockPtr const& block = mRange->mCacheManager->getBlockManager().getBlockById(
mRange->mBlockIds.at(mIdx), mRange->mWindowSize);
TLLM_CHECK_WITH_INFO(block->isPrimary(), "cache transceiver only supports primary blocks");
auto const blockOffset = block->getMemoryPoolBlockIndex();
mCurrent = runtime::ITensor::slice(mRange->mPool, blockOffset, 1);
if (block->isPlaceholder())
{
mCurrent = nullptr;
}
else
{
TLLM_CHECK_WITH_INFO(block->isPrimary(), "cache transceiver only supports primary blocks");
auto const blockOffset = block->getMemoryPoolBlockIndex();
mCurrent = runtime::ITensor::slice(mRange->mPool, blockOffset, 1);
}
}
else
{
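The new check in fromReuseTree guards a signed-to-unsigned wrap: a negative count cast to size_t becomes enormous, and reserve() then fails with an unhelpful std::length_error or std::bad_alloc far from the real cause. A tiny standalone sketch of the failure mode, with a hypothetical bad indexFromEnd:

```cpp
#include <cstdint>
#include <vector>

int main()
{
    int32_t const indexFromEnd = -2;                     // hypothetical bad input
    int32_t const numBlocksToCollect = indexFromEnd + 1; // -1
    std::vector<int32_t> blockIds;
    // Without the TLLM_CHECK above, the cast below turns -1 into SIZE_MAX and
    // reserve() throws with no hint about the bad indexFromEnd:
    // blockIds.reserve(static_cast<size_t>(numBlocksToCollect));
    (void) blockIds;
    (void) numBlocksToCollect;
    return 0;
}
```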
13 changes: 12 additions & 1 deletion cpp/include/tensorrt_llm/common/cudaUtils.h
@@ -181,7 +181,18 @@ inline bool doCheckError(cudaStream_t stream)

inline void syncAndCheck(cudaStream_t stream, char const* const file, int const line)
{
if (doCheckError(stream))
bool doCheck = false;
try
{
doCheck = doCheckError(stream);
}
catch (TllmException const&)
{
TLLM_THROW(
"[TensorRT-LLM][ERROR] Failed to determine whether the CUDA stream is capturing. Original caller: %s:%d.",
file, line);
}
if (doCheck)
{
cudaStreamSynchronize(stream);
check(cudaGetLastError(), "cudaGetLastError", file, line);
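doCheckError() itself is not part of this diff, so the following is only an assumption about what it queries: a minimal sketch of the underlying CUDA call whose failure the new try/catch re-throws with the original caller's file:line.

```cpp
#include <cuda_runtime_api.h>

#include <stdexcept>
#include <string>

// Sketch only: assumes doCheckError() ultimately asks whether the stream is
// capturing, because cudaStreamSynchronize() is not allowed during graph capture.
bool streamMayBeSynchronized(cudaStream_t stream)
{
    cudaStreamCaptureStatus status{};
    cudaError_t const err = cudaStreamIsCapturing(stream, &status);
    if (err != cudaSuccess)
    {
        // This is the failure path the new try/catch in syncAndCheck() wraps so the
        // original caller's location is preserved in the re-thrown message.
        throw std::runtime_error(std::string("cudaStreamIsCapturing failed: ") + cudaGetErrorString(err));
    }
    return status == cudaStreamCaptureStatusNone;
}
```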
1,061 changes: 606 additions & 455 deletions cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions cpp/tensorrt_llm/batch_manager/cacheFormatter.h
@@ -180,6 +180,11 @@ namespace tensorrt_llm::batch_manager::kv_cache_manager
BlockRange getBlockRangeForSending(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest,
BlockKey const& lastBlockKey, SizeType32 indexFromEnd, bool recvSideHasCP = false, SizeType32 ppSize = 1);

// Emit an INFO-level log line listing the block ids in the given BlockRange. Placeholder blocks
// (negative ids) are surfaced with a [P] tag so sender/receiver logs can be visually compared
// to confirm the placeholder layout is symmetric across executors.
void logBlockIds(char const* direction, LlmRequest const& llmRequest, SizeType32 selfIdx, BlockRange const& blockRange);

using DataContext = tensorrt_llm::executor::kv_cache::DataContext;
using Connection = tensorrt_llm::executor::kv_cache::Connection;
using SizeType32 = tensorrt_llm::runtime::SizeType32;
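The implementation of logBlockIds lives in cacheFormatter.cpp, whose diff is not rendered above, so this is only a guess at the formatting the comment describes: a minimal sketch that tags negative (placeholder) ids with [P].

```cpp
#include <cstdint>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical helper name; formats one window's block id list the way the
// declaration comment above describes, with [P] marking placeholder (negative) ids.
std::string formatBlockIds(std::vector<int32_t> const& blockIds)
{
    std::ostringstream oss;
    oss << '[';
    for (size_t i = 0; i < blockIds.size(); ++i)
    {
        if (i > 0)
        {
            oss << ", ";
        }
        oss << blockIds[i];
        if (blockIds[i] < 0)
        {
            oss << "[P]";
        }
    }
    oss << ']';
    return oss.str();
}
```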
19 changes: 19 additions & 0 deletions cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp
@@ -33,6 +33,7 @@
#include <future>
#include <map>
#include <memory>
#include <sstream>
#include <unordered_map>

namespace tensorrt_llm::batch_manager
@@ -841,6 +842,24 @@ class CacheReceiver::Impl
int32_t requestedBlockSize = requestedBlockRange.getBlockIdsPerWindow().begin()->second.size();
TLLM_CHECK_WITH_INFO(requestedBlockSize > 0, "requestedBlockSize must be > 0");
int32_t indexFromEnd = requestedBlockSize - 1;
{
// Log the full requested block id list (placeholders have negative ids) so sender
// and receiver logs can be compared side-by-side to confirm layout symmetry.
std::ostringstream idsOss;
idsOss << '[';
auto const& idList = requestedBlockRange.getBlockIdsPerWindow().begin()->second;
for (size_t k = 0; k < idList.size(); ++k)
{
if (k > 0)
{
idsOss << ", ";
}
idsOss << idList[k];
}
idsOss << ']';
TLLM_LOG_INFO("[disagg-sendRequestInfo] req=%lu requestedBlockSize=%d indexFromEnd=%d blockIds=%s",
requestId, requestedBlockSize, indexFromEnd, idsOss.str().c_str());
}

requestInfo = RequestInfo(requestId, mSelfState, indexFromEnd, lastBlockKey);
}
12 changes: 8 additions & 4 deletions cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
@@ -1015,10 +1015,13 @@ void WindowBlockManager::allocatePools(bool useUvm)
poolDtype = nvinfer1::DataType::kUINT8;
}

nvinfer1::Dims cacheShape = isRecurrentState()
// Linear-attention (recurrent state) pools could optionally use a
// layer-first layout. That path is currently disabled; set the flag
// below back to `isRecurrentState()` to re-enable it.
pool.layerFirstLayout = false;
nvinfer1::Dims cacheShape = pool.layerFirstLayout
? ITensor::makeShape({pool.numLayers, mNumPrimaryBlocks, mKVFactor, blockSize})
: ITensor::makeShape({mNumPrimaryBlocks, pool.numLayers, mKVFactor, blockSize});
pool.layerFirstLayout = isRecurrentState();

TLLM_LOG_DEBUG(
"[%s] Allocating primary pool with %d blocks for %d layers with %d kv heads, shape={%d, %d, %d, %d}%s",
@@ -1031,7 +1034,7 @@ void WindowBlockManager::allocatePools(bool useUvm)
pool.primaryPtr = mBufferManager.gpuSync(cacheShape, poolDtype);
if (mNumSecondaryBlocks > 0)
{
nvinfer1::Dims cacheShapeOffload = isRecurrentState()
nvinfer1::Dims cacheShapeOffload = pool.layerFirstLayout
? ITensor::makeShape({pool.numLayers, mNumSecondaryBlocks, mKVFactor, blockSize})
: ITensor::makeShape({mNumSecondaryBlocks, pool.numLayers, mKVFactor, blockSize});
TLLM_LOG_DEBUG("[%s] Allocating secondary pool with %d blocks for %d layers with %d kv heads",
@@ -1203,7 +1206,8 @@ void WindowBlockManager::setOffsets(tk::KVCacheIndex* offsetsPtr, nvinfer1::Dims
if (pool.layerFirstLayout)
{
// Layer-first layout: {numLayers, numBlocks, kvFactor, blockSize}
// Flat index: layerIdx * numBlocks * kvFactor + blockIdx * kvFactor + fieldIdx
// Flat index: layerIdx * numBlocks * kvFactor + blockIdx * kvFactor + fieldIdx.
// Currently dead because allocatePools forces layerFirstLayout = false.
return tk::KVCacheIndex{common::flat_index3(
layerIdx, block->getMemoryPoolBlockIndex(), fieldIdx, mNumPrimaryBlocks, mKVFactor)};
}
3 changes: 2 additions & 1 deletion cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp
@@ -114,7 +114,8 @@ void KVCacheTransferManager::copyBlock(BlockPtr const& src, BlockPtr const& dst,
auto const& pool = pools[poolIdx];

// For layer-first layout pools, block data is non-contiguous across layers.
// Copy each layer's block data separately.
// Copy each layer's block data separately. Currently dead because
// allocatePools forces layerFirstLayout = false.
if (pool.layerFirstLayout)
{
auto srcPool = src->isPrimary() ? pool.primaryPtr : pool.secondaryPtr;
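To make the non-contiguity comment concrete: a simplified model of the per-layer copy, using plain float arrays and a hypothetical helper name instead of the ITensor slicing the real code performs.

```cpp
#include <cstdint>
#include <cstring>

// Layer-first pool flattened as [numLayers][numBlocks][blockElems]: one block's
// data for consecutive layers is separated by a full layer stride, so a single
// contiguous memcpy would also grab other blocks' data.
void copyBlockLayerFirst(float const* srcPool, float* dstPool, int64_t srcBlockIdx, int64_t dstBlockIdx,
    int64_t numBlocks, int64_t numLayers, int64_t blockElems)
{
    for (int64_t layer = 0; layer < numLayers; ++layer)
    {
        int64_t const srcOff = (layer * numBlocks + srcBlockIdx) * blockElems;
        int64_t const dstOff = (layer * numBlocks + dstBlockIdx) * blockElems;
        std::memcpy(dstPool + dstOff, srcPool + srcOff, static_cast<size_t>(blockElems) * sizeof(float));
    }
}
```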
19 changes: 19 additions & 0 deletions cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp
@@ -17,6 +17,7 @@

#include "mlaCacheFormatter.h"
#include "tensorrt_llm/batch_manager/cacheFormatter.h"
#include "tensorrt_llm/batch_manager/kvCacheUtils.h"

#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/dataType.h"
@@ -162,6 +163,16 @@ void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& ses
auto const ppSize = selfConfig.getParallelConfig().mPipelineParallelism;
auto blockRange
= getBlockRangeForSending(mCacheManager, llmRequest, lastBlockKey, indexFromEnd, recvSideHasCP, ppSize);
// logBlockIds("send", llmRequest, selfIdx, blockRange);
// MLA does not currently combine with Mamba / linear-attention, so placeholder blocks
// (negative ids) must not appear here. Reject them explicitly rather than silently mishandling them.
for (auto const& [windowSize, blockIds] : blockRange.getBlockIdsPerWindow())
{
for (auto const id : blockIds)
{
TLLM_CHECK_WITH_INFO(id >= 0, "MLACacheFormatter does not support placeholder blocks");
}
}
auto const& windowSizes = blockRange.getWindowSizes();
TLLM_CHECK_WITH_INFO(
static_cast<int>(windowSizes.size()) == numPools, "window sizes should be the same as numPools");
@@ -416,6 +427,14 @@ void MLACacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& s
auto const srcPpSize = destConfig.getParallelConfig().mPipelineParallelism;
auto blockRange = getBlockRangeForReceiving(mCacheManager, llmRequest, destConfig.getEnableBlockReuse(),
destConfig.getEnablePartialReuse(), recvSideHasCP, srcPpSize);
// logBlockIds("recv", llmRequest, selfIdx, blockRange);
for (auto const& [windowSize, blockIds] : blockRange.getBlockIdsPerWindow())
{
for (auto const id : blockIds)
{
TLLM_CHECK_WITH_INFO(id >= 0, "MLACacheFormatter does not support placeholder blocks");
}
}
auto const numPools = mCacheManager->getBlockManager().getNumPools(
/*includeBlockScalePools=*/false, /*includeIndexerKCachePools=*/false);
auto const& windowSizes = blockRange.getWindowSizes();
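format() and unformat() now repeat the same placeholder-rejection loop; if deduplication is wanted, a hypothetical helper (not part of the PR, using only identifiers already visible in this diff) could carry it:

```cpp
// Hypothetical helper, not in the PR: rejects placeholder (negative) block ids
// for every window in the range, mirroring the loops added above.
inline void checkNoPlaceholderBlocks(BlockRange const& blockRange)
{
    for (auto const& [windowSize, blockIds] : blockRange.getBlockIdsPerWindow())
    {
        for (auto const id : blockIds)
        {
            TLLM_CHECK_WITH_INFO(id >= 0, "MLACacheFormatter does not support placeholder blocks");
        }
    }
}
```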
5 changes: 3 additions & 2 deletions cpp/tensorrt_llm/kernels/causalConv1d/causalConv1d.cu
@@ -94,7 +94,7 @@ __global__ __launch_bounds__(Ktraits::kNThreads) void causal_conv1d_fwd_kernel(C
? false
: reinterpret_cast<bool*>(params.has_initial_state_ptr)[batch_id];

int cache_index;
int64_t cache_index;
if constexpr (kHasConvStateIndices)
{
cache_index = reinterpret_cast<int*>(params.cache_indices_ptr)[batch_id];
@@ -379,10 +379,11 @@ __global__ __launch_bounds__(Ktraits::kNThreads) void causal_conv1d_update_kerne
if (channel_id >= params.dim)
return;

assert(batch_id < params.batch);
input_t* x
= reinterpret_cast<input_t*>(params.x_ptr) + batch_id * params.x_batch_stride + channel_id * params.x_c_stride;

int conv_state_batch_coord;
int64_t conv_state_batch_coord;
if constexpr (kHasConvStateIndices)
{
conv_state_batch_coord = params.conv_state_indices_ptr[batch_id];
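The diff does not say why cache_index and conv_state_batch_coord were widened to int64_t; presumably it is the usual index-times-stride product overflowing 32-bit arithmetic once state pools get large. A standalone sketch of that hazard with hypothetical sizes:

```cpp
#include <cstdint>
#include <cstdio>

int main()
{
    int32_t const cacheIndex = 70000;  // hypothetical state-cache entry index
    int32_t const strideElems = 40000; // hypothetical per-entry element stride
    // int32_t bad = cacheIndex * strideElems; // ~2.8e9 overflows 32-bit arithmetic (UB)
    int64_t const good = static_cast<int64_t>(cacheIndex) * strideElems;
    std::printf("product = %lld (INT32_MAX = %d)\n", static_cast<long long>(good), INT32_MAX);
    return 0;
}
```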
@@ -503,7 +503,8 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
auto pool_layer_idx = self.getPoolLayerIdx(layer_idx);
if (self.isPoolLayerFirst(layer_idx))
{
// Layer-first layout: pool[pool_layer_idx, :]
// Layer-first layout: pool[pool_layer_idx, :].
// Currently dead because allocatePools forces layerFirstLayout = false.
return pool.index({pool_layer_idx});
}
// Standard layout: pool[:, pool_layer_idx]
5 changes: 4 additions & 1 deletion cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp
@@ -7637,7 +7637,10 @@ void testKVCacheManagerLinearAttention_BlockCopying(

char* poolBaseAddr
= reinterpret_cast<char*>(kvCacheManager.getBlockManager().getRecurrentStatesPool().primaryPtr->data());
// memory layout of the pool: [numLayers, blocksInPrimaryPool, 1 (kvFactor), sizePerBlock]
// memory layout of the pool: [blocksInPrimaryPool, numLayers, 1 (kvFactor), sizePerBlock]
// (layer-first layout is currently disabled; see allocatePools()).
// setOffsets writes blockIdx * numLayers, and allRecurrentStatesBytes is the per-layer stride,
// so (blockOffset * allRecurrentStatesBytes) points to the start of layer 0 of block blockIdx.
size_t const strideBlockId = linearAttentionMetadata.allRecurrentStatesBytes;
std::unique_ptr<char[]> hostBuffer(new char[strideBlockId]);

1 change: 1 addition & 0 deletions tensorrt_llm/_torch/attention_backend/trtllm.py
@@ -321,6 +321,7 @@ def plan(
self.kv_cache_block_offsets = kv_cache_block_offsets
self.host_kv_cache_pool_pointers = host_kv_cache_pool_pointers
self.host_kv_cache_pool_mapping = host_kv_cache_pool_mapping
print(f"host_kv_cache_pool_mapping: {host_kv_cache_pool_mapping}")
self.workspace = workspace
self.cache_indirection = cache_indirection
self.kv_scale_orig_quant = kv_scale_orig_quant if kv_scales_sf_inv is None else kv_scales_sf_inv
4 changes: 2 additions & 2 deletions tensorrt_llm/_torch/model_config.py
@@ -825,9 +825,9 @@ def get_num_attention_layers(
use_disagg = is_disagg or os.environ.get('TRTLLM_USE_CPP_MAMBA',
'0') == '1'
use_reuse = kv_cache_config is not None and kv_cache_config.enable_block_reuse
use_spec = spec_config is not None
spec_config is not None

use_v1_mamba_manager = use_disagg or use_spec
use_v1_mamba_manager = False
if is_hybrid_linear(
self.pretrained_config) and use_v1_mamba_manager and use_reuse:
logger.warning(
5 changes: 3 additions & 2 deletions tensorrt_llm/_torch/pyexecutor/_util.py
@@ -36,7 +36,8 @@
from .guided_decoder import GuidedDecoder
from .kv_cache_transceiver import AttentionTypeCpp, create_kv_cache_transceiver
from .llm_request import ExecutorResponse
from .mamba_cache_manager import BaseMambaCacheManager, MambaHybridCacheManager
from .mamba_cache_manager import (MambaHybridCacheManager,
MixedMambaHybridCacheManager)
from .model_engine import PyTorchModelEngine
from .py_executor import PyExecutor
from .resource_manager import (KVCacheManager, KVCacheManagerV2,
@@ -1405,7 +1406,7 @@ def create_py_executor_instance(

# For hybrid models, this has both impl and mamba_impl
mamba_cache_manager = None
if isinstance(kv_cache_manager, BaseMambaCacheManager):
if isinstance(kv_cache_manager, MixedMambaHybridCacheManager):
mamba_cache_manager = kv_cache_manager

kv_cache_transceiver = create_kv_cache_transceiver(