Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions cpp/include/tensorrt_llm/batch_manager/kvCacheEventManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <deque>
#include <mutex>
#include <thread>
#include <unordered_map>
#include <vector>

namespace tensorrt_llm::batch_manager::kv_cache_manager
Expand Down Expand Up @@ -70,6 +71,8 @@ class KVCacheEventManager
// Add an event to mEventQueue
void enqueueEvent(executor::KVCacheEvent&& event);

void flushRemovedEvents(SizeType32 windowSize);

/// @brief Flag to terminate the worker
std::atomic<bool> mRun;
/// @brief Worker thread
Expand Down Expand Up @@ -99,6 +102,8 @@ class KVCacheEventManager
/// @brief An auto-incrementing event id counter
size_t mEventId;

std::unordered_map<SizeType32, std::optional<executor::KVCacheRemovedData>> mLatestRemovedEvents;

/// @brief Attention DP ranks and size
/// If set, we will exchange KV cache events and accumulate on rank 0
std::optional<SizeType32> mAttentionDpRank;
Expand Down
29 changes: 24 additions & 5 deletions cpp/tensorrt_llm/batch_manager/kvCacheEventManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ KVCacheEventManager::KVCacheEventManager(size_t maxKVEventEntries, std::optional
: mRun{true}
, mMaxSize{maxKVEventEntries}
, mEventId{0}
, mLatestRemovedEvents{}
, mAttentionDpRank{attentionDpRank}
, mAttentionDpSize{attentionDpSize}
, mAttentionDpEventsGatherPeriodMs(attentionDpEventsGatherPeriodMs)
Expand Down Expand Up @@ -92,6 +93,8 @@ void KVCacheEventManager::enqueueStoredEvent(std::vector<BlockPtr> const& blocks
return;
}

flushRemovedEvents(windowSize);

auto const parentBlock = blocks.front()->getPrevBlock();
auto const parent = (parentBlock != nullptr && parentBlock->getBlockId() >= 0)
? std::optional<size_t>(parentBlock->getHash())
Expand All @@ -110,16 +113,28 @@ void KVCacheEventManager::enqueueStoredEvent(std::vector<BlockPtr> const& blocks

void KVCacheEventManager::enqueueRemovedEvent(BlockPtr const& block, SizeType32 windowSize)
{
    // Removed-block events are batched lazily per sliding-window size: block hashes accumulate
    // in mLatestRemovedEvents[windowSize] and are only committed to the event queue later by
    // flushRemovedEvents(). Batching is per window because a KVCacheRemovedData event carries a
    // single windowSize.
    auto const blockHash = block->getHash();
    auto& pendingRemoved = mLatestRemovedEvents[windowSize];
    if (!pendingRemoved.has_value())
    {
        // First removal for this window since the last flush: start a new batch.
        pendingRemoved = tle::KVCacheRemovedData{{blockHash}};
    }
    else
    {
        // A batch is already open for this window: append to it.
        pendingRemoved->blockHashes.push_back(blockHash);
    }
}

void KVCacheEventManager::flushRemovedEvents(SizeType32 windowSize)
{
    // Commit the lazily batched removed-block event for the given window size (if any) to the
    // event queue, then clear the pending entry so subsequent removes open a fresh batch.
    //
    // Improvements over the previous version: a single map lookup instead of three
    // (find + two operator[]), the batched data is moved into the event instead of copying the
    // whole hash vector, and no empty entry is inserted for window sizes that never saw a
    // removed block (the old trailing `mLatestRemovedEvents[windowSize] = std::nullopt` inserted
    // one unconditionally).
    auto const it = mLatestRemovedEvents.find(windowSize);
    if (it == mLatestRemovedEvents.end())
    {
        return;
    }
    if (it->second.has_value())
    {
        enqueueEvent({mEventId++, std::move(*it->second), windowSize, mAttentionDpRank});
    }
    // Leave the key in place with an empty optional; behaviorally equivalent to the previous
    // code and keeps iteration in flush() stable (no insertion, no erasure).
    it->second.reset();
}

void KVCacheEventManager::enqueueUpdatedEvent(tle::KVCacheUpdatedData const& data, SizeType32 windowSize)
Expand Down Expand Up @@ -151,6 +166,10 @@ std::deque<tle::KVCacheEvent> KVCacheEventManager::getEvents(std::optional<std::

void KVCacheEventManager::flush()
{
for (auto const& [windowSize, latestRemovedEvent] : mLatestRemovedEvents)
{
flushRemovedEvents(windowSize);
}
auto eventQueue = std::exchange(mEventQueue, {});

if (eventQueue.empty())
Expand Down
264 changes: 264 additions & 0 deletions cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6188,3 +6188,267 @@ TEST(KVCacheManagerReuseAccountingTest, MultipleRequestsWithSharedPrefix)
auto const remaining = kvCacheManager->getRemainingBlocksToCompletion(req1, onlyWindowSize);
EXPECT_EQ(remaining, (promptLength / tokensPerBlock) + (maxNewTokens / tokensPerBlock));
}

// All remove events for the same window size during a single iteration must be consolidated
// into a single KVCacheRemovedData (not emitted as separate events).
TEST_F(KVCacheManagerTest, KVCacheManagerEventRemovedBatchedWithinWindow)
{
    auto constexpr numLayers = 2;
    auto constexpr numHeads = 2;
    auto constexpr sizePerHead = 16;
    auto constexpr tokensPerBlock = 4;
    // Tight pool of 4: seq0 and seq1 together use all 4 blocks, leaving none fresh for seq2.
    // seq2 therefore must evict tree blocks to obtain its 4 needed blocks.
    auto constexpr blocksInPrimaryPool = 4;
    auto constexpr blocksInSecondaryPool = 0;
    auto constexpr maxNumSequences = 4;
    auto constexpr maxAttentionWindow = 32;
    auto constexpr beamWidth = 1;
    auto constexpr dtype = nvinfer1::DataType::kHALF;
    auto const stream = std::make_shared<tr::CudaStream>();
    SizeType32 constexpr maxNewTokens{0};
    tr::SamplingConfig const samplingConfig{beamWidth};
    auto constexpr onboardBlocks = true;

    auto const blocksPerWindow = BlocksPerWindow{{maxAttentionWindow, {blocksInPrimaryPool, blocksInSecondaryPool}}};
    KVCacheManager kvCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences,
        beamWidth, std::vector<BlockManager::SizeType32>{maxAttentionWindow}, std::nullopt, dtype, 0, stream,
        maxAttentionWindow, true, onboardBlocks, CacheType::kSELF, std::nullopt,
        std::make_unique<tlk::KVCacheEventManager>(1024));
    kvCacheManager.allocatePools(false);
    // Drain any events emitted during pool allocation so they don't pollute the counts below.
    (void) getEvents(kvCacheManager);

    // Seq0: stores blockA([0,1,2,3]) as a leaf in the radix tree.
    auto inputTokens0 = std::make_shared<VecTokens>(VecTokens{0, 1, 2, 3, 4});
    auto llmRequest0 = std::make_shared<LlmRequest>(0, maxNewTokens, inputTokens0, samplingConfig, true);
    kvCacheManager.addSequence(0, inputTokens0->size(), beamWidth, llmRequest0);
    kvCacheManager.storeContextBlocks(*llmRequest0);
    (void) kvCacheManager.removeSequence(0, llmRequest0);

    // Seq1: stores blockB([10,11,12,13]) as a separate leaf in the radix tree.
    auto inputTokens1 = std::make_shared<VecTokens>(VecTokens{10, 11, 12, 13, 14});
    auto llmRequest1 = std::make_shared<LlmRequest>(1, maxNewTokens, inputTokens1, samplingConfig, true);
    kvCacheManager.addSequence(1, inputTokens1->size(), beamWidth, llmRequest1);
    kvCacheManager.storeContextBlocks(*llmRequest1);
    (void) kvCacheManager.removeSequence(1, llmRequest1);

    (void) getEvents(kvCacheManager); // drain seq0/seq1 stored events

    // Seq2 needs 4 blocks (15 tokens) with no radix tree match. All 4 pool blocks are in
    // the free queue after seq0 and seq1 released them. Two of those 4 blocks (blockA and
    // blockB) are leaves in the radix tree, so each call to freeChildren emits a remove
    // event. Both removes accumulate into mLatestRemovedEvents[W] and are committed as
    // one consolidated KVCacheRemovedData when flush() is called.
    auto inputTokens2 = std::make_shared<VecTokens>(
        VecTokens{100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114});
    auto llmRequest2 = std::make_shared<LlmRequest>(2, maxNewTokens, inputTokens2, samplingConfig, true);
    kvCacheManager.addSequence(2, inputTokens2->size(), beamWidth, llmRequest2);

    auto events = getEvents(kvCacheManager);

    // Tally the Removed events and the total number of evicted block hashes they carry.
    SizeType32 numRemovedEvents = 0;
    SizeType32 numTotalRemovedHashes = 0;
    for (auto const& event : events)
    {
        if (std::holds_alternative<tle::KVCacheRemovedData>(event.data))
        {
            ++numRemovedEvents;
            numTotalRemovedHashes
                += static_cast<SizeType32>(std::get<tle::KVCacheRemovedData>(event.data).blockHashes.size());
        }
    }

    // blockA and blockB were both evicted from the same window in the same iteration.
    // They must appear in exactly one consolidated Removed event, not two separate events.
    EXPECT_EQ(numRemovedEvents, 1) << "Expected 1 consolidated Removed event for same-window evictions, got "
                                   << numRemovedEvents;
    EXPECT_EQ(numTotalRemovedHashes, 2) << "Expected 2 hashes in the Removed event (blockA and blockB), got "
                                        << numTotalRemovedHashes;
}

// When evictions and a store happen for the same window in the same iteration, the Removed
// event must appear before the Stored event. This is the ordering guarantee provided by
// enqueueStoredEvent calling flushRemovedEvents before appending the Stored event.
TEST_F(KVCacheManagerTest, KVCacheManagerEventRemovedOrderedBeforeStore)
{
    auto constexpr numLayers = 2;
    auto constexpr numHeads = 2;
    auto constexpr sizePerHead = 16;
    auto constexpr tokensPerBlock = 4;
    auto constexpr blocksInPrimaryPool = 8;
    auto constexpr blocksInSecondaryPool = 0;
    auto constexpr maxNumSequences = 4;
    auto constexpr maxAttentionWindow = 32;
    auto constexpr beamWidth = 1;
    auto constexpr dtype = nvinfer1::DataType::kHALF;
    auto const stream = std::make_shared<tr::CudaStream>();
    SizeType32 constexpr maxNewTokens{0};
    tr::SamplingConfig const samplingConfig{beamWidth};
    auto constexpr onboardBlocks = true;
    // Low priority makes seq0's first block the preferred eviction victim later on.
    tle::RetentionPriority constexpr lowPriority = 0;
    tle::RetentionPriority constexpr highPriority = 80;

    auto const blocksPerWindow = BlocksPerWindow{{maxAttentionWindow, {blocksInPrimaryPool, blocksInSecondaryPool}}};
    KVCacheManager kvCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences,
        beamWidth, std::vector<BlockManager::SizeType32>{maxAttentionWindow}, std::nullopt, dtype, 0, stream,
        maxAttentionWindow, true, onboardBlocks, CacheType::kSELF, std::nullopt,
        std::make_unique<tlk::KVCacheEventManager>(1024));
    kvCacheManager.allocatePools(false);
    // Drain any events emitted during pool allocation.
    (void) getEvents(kvCacheManager);

    // Seq0: store root → block0(lowPrio) → block1(highPrio) in the radix tree.
    auto inputTokens0 = std::make_shared<VecTokens>(VecTokens{0, 1, 2, 3, 4, 5, 6, 7, 8});
    auto llmRequest0 = std::make_shared<LlmRequest>(0, maxNewTokens, inputTokens0, samplingConfig, true);
    llmRequest0->setKvCacheRetentionConfig(
        KvCacheRetentionConfig({KvCacheRetentionConfig::TokenRangeRetentionConfig(0, 4, lowPriority),
                                   KvCacheRetentionConfig::TokenRangeRetentionConfig(4, std::nullopt, highPriority)},
            highPriority));
    kvCacheManager.addSequence(0, inputTokens0->size(), beamWidth, llmRequest0);
    kvCacheManager.storeContextBlocks(*llmRequest0);
    (void) kvCacheManager.removeSequence(0, llmRequest0);
    (void) getEvents(kvCacheManager); // drain

    // Seq1 with different tokens.
    // addSequence: evicts seq0's block0 (and its descendant block1) — removes buffered, not yet emitted.
    // storeContextBlocks: calls flushRemovedEvents(W) first, committing the buffered removes,
    // then appends the Stored event for seq1's new blocks.
    auto inputTokens1 = std::make_shared<VecTokens>(VecTokens{100, 101, 102, 103, 104, 105, 106, 107, 108});
    auto llmRequest1 = std::make_shared<LlmRequest>(1, maxNewTokens, inputTokens1, samplingConfig, true);
    kvCacheManager.addSequence(1, inputTokens1->size(), beamWidth, llmRequest1);
    kvCacheManager.storeContextBlocks(*llmRequest1);

    auto events = getEvents(kvCacheManager);

    // Find the positions of the first Removed and first Stored events.
    std::optional<SizeType32> removedPos;
    std::optional<SizeType32> storedPos;
    SizeType32 pos = 0;
    for (auto const& event : events)
    {
        if (!removedPos && std::holds_alternative<tle::KVCacheRemovedData>(event.data))
        {
            removedPos = pos;
        }
        if (!storedPos && std::holds_alternative<tle::KVCacheStoredData>(event.data))
        {
            storedPos = pos;
        }
        ++pos;
    }

    ASSERT_TRUE(removedPos.has_value()) << "Expected at least one Removed event";
    ASSERT_TRUE(storedPos.has_value()) << "Expected at least one Stored event";

    // Ordering guarantee under test: buffered removes are committed before the store.
    EXPECT_LT(*removedPos, *storedPos)
        << "Removed event (pos=" << *removedPos << ") must precede Stored event (pos=" << *storedPos
        << ") for the same window. enqueueStoredEvent must flush pending removes before appending the store.";
}

// A store event for window W2 must not flush pending remove events for a different window W1.
// Removes for W1 must only be committed when a store for W1 occurs or when flush() is called.
// This verifies per-window isolation in the lazy-batching remove event logic.
TEST_F(KVCacheManagerTest, KVCacheManagerEventStoreForDifferentWindowDoesNotFlushPendingRemoves)
{
    // Two windows: wFull (non-SWA, equal to maxSequenceLength) and wSWA (SWA, smaller).
    // storeContextBlocks skips SWA windows, so it only emits a Stored event for wFull.
    // This means wSWA removes are never flushed by the wFull store — they stay buffered
    // until flush() at end of iteration.
    //
    // Expected event order: [Removed(wFull), Stored(wFull), Removed(wSWA)]
    //   Removed(wFull) — flushed by wFull's own storeContextBlocks call
    //   Stored(wFull)  — emitted by storeContextBlocks for wFull
    //   Removed(wSWA)  — only flushed by the iteration-end flush(), AFTER storeContextBlocks
    //
    // If isolation were broken (wFull store flushes ALL windows' removes), the order
    // would be [Removed(wSWA), Removed(wFull), Stored(wFull)] — Stored(wFull) would
    // appear after Removed(wSWA), violating the per-window ordering guarantee.
    auto constexpr numLayers = 2;
    auto constexpr numHeads = 2;
    auto constexpr sizePerHead = 16;
    auto constexpr tokensPerBlock = 4;
    // Tight pool: seq0 uses 3 out of 4 blocks, leaving only 1 fresh block. seq1 therefore
    // has to evict seq0's cached tree blocks to obtain the 3 it needs.
    auto constexpr blocksInPrimaryPool = 4;
    auto constexpr blocksInSecondaryPool = 0;
    auto constexpr maxNumSequences = 4;
    auto constexpr beamWidth = 1;
    auto constexpr dtype = nvinfer1::DataType::kHALF;
    auto const stream = std::make_shared<tr::CudaStream>();
    SizeType32 constexpr maxNewTokens{0};
    tr::SamplingConfig const samplingConfig{beamWidth};
    auto constexpr onboardBlocks = true;

    auto constexpr wSWA = tokensPerBlock * 2;  // 8 tokens — SWA (< maxSequenceLength)
    auto constexpr wFull = tokensPerBlock * 4; // 16 tokens — full attention = maxSequenceLength
    auto constexpr maxSequenceLength = wFull;

    auto const blocksPerWindow = BlocksPerWindow{
        {wSWA, {blocksInPrimaryPool, blocksInSecondaryPool}}, {wFull, {blocksInPrimaryPool, blocksInSecondaryPool}}};
    KVCacheManager kvCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences,
        beamWidth, std::vector<BlockManager::SizeType32>{wSWA, wFull}, std::nullopt, dtype, 0, stream,
        maxSequenceLength, true, onboardBlocks, CacheType::kSELF, std::nullopt,
        std::make_unique<tlk::KVCacheEventManager>(1024));
    kvCacheManager.allocatePools(false);
    // Drain any events emitted during pool allocation.
    (void) getEvents(kvCacheManager);

    // Seq0: 9 tokens → 3 blocks per window. storeContextBlocks stores 2 full blocks in wFull
    // (skips wSWA). removeSequence stores 2 full blocks in wSWA as well (releaseBlocks covers
    // all windows). After release, each window's free queue is [block3_fresh, block2, block1, block0],
    // with block0 and block1 in the respective radix trees.
    auto inputTokens0 = std::make_shared<VecTokens>(VecTokens{0, 1, 2, 3, 4, 5, 6, 7, 8});
    auto llmRequest0 = std::make_shared<LlmRequest>(0, maxNewTokens, inputTokens0, samplingConfig, true);
    kvCacheManager.addSequence(0, inputTokens0->size(), beamWidth, llmRequest0);
    kvCacheManager.storeContextBlocks(*llmRequest0);
    (void) kvCacheManager.removeSequence(0, llmRequest0);
    (void) getEvents(kvCacheManager); // drain

    // Seq1 with different tokens (9 tokens → 3 blocks per window).
    // addSequence for each window: gets block3 (fresh, no event), block2 (not in tree, no event),
    // then block1 (in tree as leaf) → freeChildren(block1) → Removed(block1) buffered for that window.
    // storeContextBlocks:
    //   wSWA:  skipped (SWA) — wSWA removes stay buffered
    //   wFull: flushRemovedEvents(wFull) → Removed(wFull) committed; Stored(wFull) committed
    // flush(): flushRemovedEvents(wSWA) → Removed(wSWA) committed
    auto inputTokens1 = std::make_shared<VecTokens>(VecTokens{100, 101, 102, 103, 104, 105, 106, 107, 108});
    auto llmRequest1 = std::make_shared<LlmRequest>(1, maxNewTokens, inputTokens1, samplingConfig, true);
    kvCacheManager.addSequence(1, inputTokens1->size(), beamWidth, llmRequest1);
    kvCacheManager.storeContextBlocks(*llmRequest1);

    auto events = getEvents(kvCacheManager);

    // Find the position of the first Removed and Stored event for each window.
    std::optional<SizeType32> removedSWAPos, storedFullPos, removedFullPos;
    SizeType32 pos = 0;
    for (auto const& event : events)
    {
        if (std::holds_alternative<tle::KVCacheRemovedData>(event.data))
        {
            if (event.windowSize == wSWA && !removedSWAPos)
                removedSWAPos = pos;
            if (event.windowSize == wFull && !removedFullPos)
                removedFullPos = pos;
        }
        else if (std::holds_alternative<tle::KVCacheStoredData>(event.data))
        {
            if (event.windowSize == wFull && !storedFullPos)
            {
                storedFullPos = pos;
            }
        }
        ++pos;
    }

    ASSERT_TRUE(removedSWAPos.has_value()) << "Expected Removed event for wSWA";
    ASSERT_TRUE(removedFullPos.has_value()) << "Expected Removed event for wFull";
    ASSERT_TRUE(storedFullPos.has_value()) << "Expected Stored event for wFull";

    // Within wFull, removes must precede stores.
    EXPECT_LT(*removedFullPos, *storedFullPos) << "Removed(wFull) must precede Stored(wFull)";

    // The wFull store must NOT have flushed wSWA's pending removes prematurely.
    // Correct isolation: Stored(wFull) appears before Removed(wSWA).
    // Broken isolation: Removed(wSWA) appears before Stored(wFull).
    EXPECT_LT(*storedFullPos, *removedSWAPos)
        << "Stored(wFull) (pos=" << *storedFullPos << ") must precede Removed(wSWA) (pos=" << *removedSWAPos
        << "). The wFull store must not prematurely flush pending removes for wSWA.";
}
4 changes: 4 additions & 0 deletions examples/auto_deploy/nano_v3.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@ kv_cache_config:
transforms:
detect_sharding:
allreduce_strategy: SYMM_MEM
# NOTE: add 'tp' to sharding dims only for high-throughput runs
# For low-latency, keep mamba and attention replicated
sharding_dims: ['ep', 'bmm']
# NOTE: sharding_source applies only to TP sharding
sharding_source: ['manual']
manual_config:
head_dim: 128
tp_plan:
Expand Down
6 changes: 6 additions & 0 deletions examples/auto_deploy/super_v3.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@ kv_cache_config:
transforms:
detect_sharding:
allreduce_strategy: SYMM_MEM
# NOTE: add 'tp' to sharding dims only for high-throughput runs
# For low-latency, keep mamba and attention replicated
sharding_dims: ['ep', 'bmm']
# NOTE: sharding_source applies only to TP sharding
sharding_source: ['manual']
manual_config:
head_dim: 128
tp_plan:
Expand Down Expand Up @@ -46,3 +50,5 @@ transforms:
enabled: true
insert_cached_ssm_attention:
backend: flashinfer_ssm
fuse_nvfp4_moe:
backend: trtllm_gen
5 changes: 4 additions & 1 deletion examples/visual_gen/serve/benchmark_visual_gen.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
#
# Requirements:
# pip install git+https://github.com/huggingface/diffusers.git
# pip install av
#
# Optional (for MP4/H.264 video output):
# apt-get install ffmpeg # or: conda install ffmpeg
# Without ffmpeg, videos are saved as AVI/MJPEG using a pure-Python encoder.

set -euo pipefail

Expand Down
Loading
Loading