Skip to content

Commit a3735ed

Browse files
authored
Merge branch 'main' into dev-bench-moe
2 parents 06e6568 + 09449d4 commit a3735ed

3 files changed

Lines changed: 119 additions & 22 deletions

File tree

cpp/include/tensorrt_llm/batch_manager/kvCacheTransferManager.h

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,11 @@ namespace kvc = tensorrt_llm::executor::kv_cache;
2424

2525
#pragma once
2626

27+
namespace tensorrt_llm::testing
28+
{
29+
class KVCacheTransferManagerTestAccess;
30+
} // namespace tensorrt_llm::testing
31+
2732
namespace tensorrt_llm::batch_manager::kv_cache_manager
2833
{
2934

@@ -76,10 +81,15 @@ class KVCacheTransferManager
7681
[[nodiscard]] KvCacheTransferStats getAndResetTransferStats();
7782

7883
private:
84+
friend class ::tensorrt_llm::testing::KVCacheTransferManagerTestAccess;
85+
7986
//! \brief Get pointer to pool specified by cache block.
8087
static tr::ITensor::SharedPtr computeBlockPointer(
8188
BlockPtr const& block, std::vector<KVCacheBlockPool> const& pools, size_t poolIdx);
8289

90+
//! \brief Get the key used to track pending transfers for a block: its memory pool block index, flagged when the block is not in the primary pool.
91+
[[nodiscard]] static kernels::KVCacheIndex::UnderlyingType getPendingTransferIndex(BlockPtr const& block);
92+
8393
/*!
8494
* \brief The key method that copies the src block to the dst block.
8595
*
@@ -107,8 +117,8 @@ class KVCacheTransferManager
107117
runtime::BufferManager mOnboardManager;
108118
runtime::BufferManager mOffloadManager;
109119

110-
// Track reads and writes for blocks. Note that it is the memory pool index that
111-
// identifies the raw memory blocks involved in I/O, not the block Id.
120+
// Track pending reads and writes per raw memory block. The key is the memory pool block index
121+
// qualified by pool (primary or secondary), which identifies the raw memory involved in I/O; it is not the block id.
112122
std::unordered_map<kernels::KVCacheIndex::UnderlyingType, tr::CudaEvent> mPendingReads;
113123
std::unordered_map<kernels::KVCacheIndex::UnderlyingType, tr::CudaEvent> mPendingWrites;
114124
// Reference to parent loopback agent

cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,12 @@ tr::ITensor::SharedPtr KVCacheTransferManager::computeBlockPointer(
9797
return blockTensor;
9898
}
9999

100+
tk::KVCacheIndex::UnderlyingType KVCacheTransferManager::getPendingTransferIndex(BlockPtr const& block)
101+
{
102+
auto const blockOffset = block->getMemoryPoolBlockIndex();
103+
return block->isPrimary() ? blockOffset : blockOffset | tk::KVCacheIndex::kSecondaryPoolFlag;
104+
}
105+
100106
void KVCacheTransferManager::copyBlock(BlockPtr const& src, BlockPtr const& dst,
101107
std::vector<KVCacheBlockPool> const& pools, bool isOffload, int numTokensToCopy, executor::KvCacheTransferMode mode,
102108
std::string const& directory)
@@ -252,10 +258,9 @@ void KVCacheTransferManager::copyBlock(BlockPtr const& src, BlockPtr const& dst,
252258
}
253259

254260
//
255-
// Note about recording events to wait for cudaMempyAsync calls between blocks:
256-
// The memory copy involves raw memory blocks, which are pointed to by the
257-
// memory pool block index. When recording events, you must use getMemoryPoolBlockIndex()
258-
// as the raw memory block identifier. Using getBlockId() when recording events is wrong.
261+
// Note about recording events to wait for cudaMemcpyAsync calls between blocks:
262+
// The memory copy involves raw memory blocks, which are identified by the pool-qualified
263+
// memory pool block index (see getPendingTransferIndex()). Using getBlockId() when recording events is wrong.
259264
// getBlockId() returns the logical block id, which has nothing to do with the raw memory
260265
// block pointers involved in a cudaMemcpy.
261266
//
@@ -289,22 +294,25 @@ void KVCacheTransferManager::onboard(BlockPtr const& offloadedBlock, BlockPtr co
289294
std::vector<KVCacheBlockPool> const& pools, int numTokensToCopy, executor::KvCacheTransferMode mode,
290295
std::string const& directory)
291296
{
297+
auto const offloadedBlockIndex = getPendingTransferIndex(offloadedBlock);
298+
auto const blockIndex = getPendingTransferIndex(block);
299+
292300
// Wait for any pending writes before reading from offloadedBlock
293-
auto offloadedBlockPendingWriteItr = mPendingWrites.find(offloadedBlock->getMemoryPoolBlockIndex());
301+
auto offloadedBlockPendingWriteItr = mPendingWrites.find(offloadedBlockIndex);
294302
if (offloadedBlockPendingWriteItr != mPendingWrites.end())
295303
{
296304
mOnboardManager.getStream().wait(offloadedBlockPendingWriteItr->second);
297305
// Don't erase, we are not changing state of offloadedBlock
298306
}
299307
// Wait for any pending reads before overwriting block
300-
auto blockPendingReadItr = mPendingReads.find(block->getMemoryPoolBlockIndex());
308+
auto blockPendingReadItr = mPendingReads.find(blockIndex);
301309
if (blockPendingReadItr != mPendingReads.end())
302310
{
303311
mOnboardManager.getStream().wait(blockPendingReadItr->second);
304312
mPendingReads.erase(blockPendingReadItr);
305313
}
306314
// Wait for any pending writes before overwriting block
307-
auto blockPendingWriteItr = mPendingWrites.find(block->getMemoryPoolBlockIndex());
315+
auto blockPendingWriteItr = mPendingWrites.find(blockIndex);
308316
if (blockPendingWriteItr != mPendingWrites.end())
309317
{
310318
mOnboardManager.getStream().wait(blockPendingWriteItr->second);
@@ -330,33 +338,36 @@ void KVCacheTransferManager::onboard(BlockPtr const& offloadedBlock, BlockPtr co
330338
}
331339

332340
// Record new pending read from offloadedBlock
333-
mPendingReads[offloadedBlock->getMemoryPoolBlockIndex()] = tr::CudaEvent();
334-
mOnboardManager.getStream().record(mPendingReads[offloadedBlock->getMemoryPoolBlockIndex()]);
341+
mPendingReads[offloadedBlockIndex] = tr::CudaEvent();
342+
mOnboardManager.getStream().record(mPendingReads[offloadedBlockIndex]);
335343
// Record new pending write to block
336-
mPendingWrites[block->getMemoryPoolBlockIndex()] = tr::CudaEvent();
337-
mOnboardManager.getStream().record(mPendingWrites[block->getMemoryPoolBlockIndex()]);
344+
mPendingWrites[blockIndex] = tr::CudaEvent();
345+
mOnboardManager.getStream().record(mPendingWrites[blockIndex]);
338346
}
339347

340348
void KVCacheTransferManager::offload(BlockPtr const& block, BlockPtr const& offloadBlock,
341349
std::vector<KVCacheBlockPool> const& pools, int numTokensToCopy, executor::KvCacheTransferMode mode,
342350
std::string const& directory)
343351
{
352+
auto const blockIndex = getPendingTransferIndex(block);
353+
auto const offloadBlockIndex = getPendingTransferIndex(offloadBlock);
354+
344355
// Wait for any pending writes before reading from block
345-
auto blockPendingWriteItr = mPendingWrites.find(block->getMemoryPoolBlockIndex());
356+
auto blockPendingWriteItr = mPendingWrites.find(blockIndex);
346357
if (blockPendingWriteItr != mPendingWrites.end())
347358
{
348359
mOffloadManager.getStream().wait(blockPendingWriteItr->second);
349360
// Don't erase, we are not changing state of block
350361
}
351362
// Wait for any pending reads before overwriting offloadBlock
352-
auto offloadBlockPendingReadItr = mPendingReads.find(offloadBlock->getMemoryPoolBlockIndex());
363+
auto offloadBlockPendingReadItr = mPendingReads.find(offloadBlockIndex);
353364
if (offloadBlockPendingReadItr != mPendingReads.end())
354365
{
355366
mOffloadManager.getStream().wait(offloadBlockPendingReadItr->second);
356367
mPendingReads.erase(offloadBlockPendingReadItr);
357368
}
358369
// Wait for any pending writes before overwriting offloadBlock
359-
auto offloadBlockPendingWriteItr = mPendingWrites.find(offloadBlock->getMemoryPoolBlockIndex());
370+
auto offloadBlockPendingWriteItr = mPendingWrites.find(offloadBlockIndex);
360371
if (offloadBlockPendingWriteItr != mPendingWrites.end())
361372
{
362373
mOffloadManager.getStream().wait(offloadBlockPendingWriteItr->second);
@@ -373,11 +384,11 @@ void KVCacheTransferManager::offload(BlockPtr const& block, BlockPtr const& offl
373384
}
374385

375386
// Record new pending read from block
376-
mPendingReads[block->getMemoryPoolBlockIndex()] = tr::CudaEvent();
377-
mOffloadManager.getStream().record(mPendingReads[block->getMemoryPoolBlockIndex()]);
387+
mPendingReads[blockIndex] = tr::CudaEvent();
388+
mOffloadManager.getStream().record(mPendingReads[blockIndex]);
378389
// Record new pending write to offloadBlock
379-
mPendingWrites[offloadBlock->getMemoryPoolBlockIndex()] = tr::CudaEvent();
380-
mOffloadManager.getStream().record(mPendingWrites[offloadBlock->getMemoryPoolBlockIndex()]);
390+
mPendingWrites[offloadBlockIndex] = tr::CudaEvent();
391+
mOffloadManager.getStream().record(mPendingWrites[offloadBlockIndex]);
381392
}
382393

383394
void KVCacheTransferManager::syncWithBufferManager()
@@ -390,7 +401,7 @@ void KVCacheTransferManager::syncWithBufferManager()
390401
mBufferManager.getStream().record(readyForOnboardEvent);
391402
mOnboardManager.getStream().wait(readyForOnboardEvent);
392403

393-
// Once we synchronize, clear our list of pending thransfers.
404+
// Once we synchronize, clear our list of pending transfers.
394405
mPendingReads.clear();
395406
mPendingWrites.clear();
396407
}
@@ -405,7 +416,7 @@ void KVCacheTransferManager::syncTransfers()
405416
mOnboardManager.getStream().record(onboardEvent);
406417
mBufferManager.getStream().wait(onboardEvent);
407418

408-
// Once we synchronize, clear our list of pending thransfers.
419+
// Once we synchronize, clear our list of pending transfers.
409420
mPendingReads.clear();
410421
mPendingWrites.clear();
411422
}

cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,39 @@ namespace tle = tensorrt_llm::executor;
6262
namespace tr = tensorrt_llm::runtime;
6363
namespace fs = std::filesystem;
6464

65+
namespace tensorrt_llm::testing
66+
{
67+
namespace tbk = batch_manager::kv_cache_manager;
68+
69+
class KVCacheTransferManagerTestAccess
70+
{
71+
public:
72+
[[nodiscard]] static std::size_t pendingReadCount(tbk::KVCacheTransferManager const& transferManager)
73+
{
74+
return transferManager.mPendingReads.size();
75+
}
76+
77+
[[nodiscard]] static std::size_t pendingWriteCount(tbk::KVCacheTransferManager const& transferManager)
78+
{
79+
return transferManager.mPendingWrites.size();
80+
}
81+
82+
[[nodiscard]] static bool hasPendingReadForBlock(
83+
tbk::KVCacheTransferManager const& transferManager, tbk::BlockPtr const& block)
84+
{
85+
return transferManager.mPendingReads.find(tbk::KVCacheTransferManager::getPendingTransferIndex(block))
86+
!= transferManager.mPendingReads.end();
87+
}
88+
89+
[[nodiscard]] static bool hasPendingWriteForBlock(
90+
tbk::KVCacheTransferManager const& transferManager, tbk::BlockPtr const& block)
91+
{
92+
return transferManager.mPendingWrites.find(tbk::KVCacheTransferManager::getPendingTransferIndex(block))
93+
!= transferManager.mPendingWrites.end();
94+
}
95+
};
96+
} // namespace tensorrt_llm::testing
97+
6598
using BlocksPerWindow = std::map<SizeType32, std::tuple<SizeType32, SizeType32>>;
6699

67100
using ParamType = bool;
@@ -5013,6 +5046,49 @@ TEST_F(KVCacheManagerTest, KVCacheTransferManagerConcurrencyTest)
50135046
}
50145047
}
50155048

5049+
TEST_F(KVCacheManagerTest, KVCacheTransferManagerPendingTransfersDistinguishPrimaryAndSecondarySlots)
5050+
{
5051+
using tensorrt_llm::testing::KVCacheTransferManagerTestAccess;
5052+
5053+
auto constexpr kBlockSize = 1024;
5054+
auto constexpr kNumSlotsPerPool = 2;
5055+
5056+
auto bufferManager = tr::BufferManager(std::make_shared<tr::CudaStream>());
5057+
auto transferManager = KVCacheTransferManager(bufferManager);
5058+
5059+
auto pool = KVCacheBlockPool(0, 2, 0, 0, 0);
5060+
pool.primaryPtr
5061+
= bufferManager.gpu(tr::ITensor::makeShape({kNumSlotsPerPool, kBlockSize}), nvinfer1::DataType::kFLOAT);
5062+
bufferManager.setZero(*pool.primaryPtr);
5063+
5064+
pool.secondaryPtr
5065+
= tr::BufferManager::pinned(tr::ITensor::makeShape({kNumSlotsPerPool, kBlockSize}), nvinfer1::DataType::kFLOAT);
5066+
5067+
auto primarySlot0 = std::make_shared<KVCacheBlock>(0, tk::KVCacheIndex(0, false));
5068+
auto primarySlot1 = std::make_shared<KVCacheBlock>(1, tk::KVCacheIndex(1, false));
5069+
auto secondarySlot0 = std::make_shared<KVCacheBlock>(2, tk::KVCacheIndex(0, true));
5070+
auto secondarySlot1 = std::make_shared<KVCacheBlock>(3, tk::KVCacheIndex(1, true));
5071+
5072+
transferManager.offload(primarySlot0, secondarySlot1, {pool});
5073+
5074+
EXPECT_EQ(KVCacheTransferManagerTestAccess::pendingReadCount(transferManager), 1);
5075+
EXPECT_EQ(KVCacheTransferManagerTestAccess::pendingWriteCount(transferManager), 1);
5076+
EXPECT_TRUE(KVCacheTransferManagerTestAccess::hasPendingReadForBlock(transferManager, primarySlot0));
5077+
EXPECT_TRUE(KVCacheTransferManagerTestAccess::hasPendingWriteForBlock(transferManager, secondarySlot1));
5078+
5079+
transferManager.onboard(secondarySlot0, primarySlot1, {pool});
5080+
5081+
EXPECT_EQ(KVCacheTransferManagerTestAccess::pendingReadCount(transferManager), 2);
5082+
EXPECT_EQ(KVCacheTransferManagerTestAccess::pendingWriteCount(transferManager), 2);
5083+
EXPECT_TRUE(KVCacheTransferManagerTestAccess::hasPendingReadForBlock(transferManager, primarySlot0));
5084+
EXPECT_TRUE(KVCacheTransferManagerTestAccess::hasPendingReadForBlock(transferManager, secondarySlot0));
5085+
EXPECT_TRUE(KVCacheTransferManagerTestAccess::hasPendingWriteForBlock(transferManager, primarySlot1));
5086+
EXPECT_TRUE(KVCacheTransferManagerTestAccess::hasPendingWriteForBlock(transferManager, secondarySlot1));
5087+
5088+
transferManager.syncTransfers();
5089+
bufferManager.getStream().synchronize();
5090+
}
5091+
50165092
TEST_P(KVCacheManagerTest, DISABLED_KVCacheManagerSinkTokenLengthTest)
50175093
{
50185094
// TODO: Support sink attention and add coverage

0 commit comments

Comments
 (0)