@@ -97,6 +97,12 @@ tr::ITensor::SharedPtr KVCacheTransferManager::computeBlockPointer(
9797 return blockTensor;
9898}
9999
100+ tk::KVCacheIndex::UnderlyingType KVCacheTransferManager::getPendingTransferIndex (BlockPtr const & block)
101+ {
102+ auto const blockOffset = block->getMemoryPoolBlockIndex ();
103+ return block->isPrimary () ? blockOffset : blockOffset | tk::KVCacheIndex::kSecondaryPoolFlag ;
104+ }
105+
100106void KVCacheTransferManager::copyBlock (BlockPtr const & src, BlockPtr const & dst,
101107 std::vector<KVCacheBlockPool> const & pools, bool isOffload, int numTokensToCopy, executor::KvCacheTransferMode mode,
102108 std::string const & directory)
@@ -252,10 +258,9 @@ void KVCacheTransferManager::copyBlock(BlockPtr const& src, BlockPtr const& dst,
252258}
253259
254260//
255- // Note about recording events to wait for cudaMempyAsync calls between blocks:
256- // The memory copy involves raw memory blocks, which are pointed to by the
257- // memory pool block index. When recording events, you must use getMemoryPoolBlockIndex()
258- // as the raw memory block identifier. Using getBlockId() when recording events is wrong.
261+ // Note about recording events to wait for cudaMemcpyAsync calls between blocks:
262+ // The memory copy involves raw memory blocks, which are identified by the pool-qualified
263+ // memory pool block index. Using getBlockId() when recording events is wrong.
259264// getBlockId() returns the logical block id, which has nothing to do with the raw memory
260265// block pointers involved in a cudaMemcpy.
261266//
@@ -289,22 +294,25 @@ void KVCacheTransferManager::onboard(BlockPtr const& offloadedBlock, BlockPtr co
289294 std::vector<KVCacheBlockPool> const & pools, int numTokensToCopy, executor::KvCacheTransferMode mode,
290295 std::string const & directory)
291296{
297+ auto const offloadedBlockIndex = getPendingTransferIndex (offloadedBlock);
298+ auto const blockIndex = getPendingTransferIndex (block);
299+
292300 // Wait for any pending writes before reading from offloadedBlock
293- auto offloadedBlockPendingWriteItr = mPendingWrites .find (offloadedBlock-> getMemoryPoolBlockIndex () );
301+ auto offloadedBlockPendingWriteItr = mPendingWrites .find (offloadedBlockIndex );
294302 if (offloadedBlockPendingWriteItr != mPendingWrites .end ())
295303 {
296304 mOnboardManager .getStream ().wait (offloadedBlockPendingWriteItr->second );
297305 // Don't erase, we are not changing state of offloadedBlock
298306 }
299307 // Wait for any pending reads before overwriting block
300- auto blockPendingReadItr = mPendingReads .find (block-> getMemoryPoolBlockIndex () );
308+ auto blockPendingReadItr = mPendingReads .find (blockIndex );
301309 if (blockPendingReadItr != mPendingReads .end ())
302310 {
303311 mOnboardManager .getStream ().wait (blockPendingReadItr->second );
304312 mPendingReads .erase (blockPendingReadItr);
305313 }
306314 // Wait for any pending writes before overwriting block
307- auto blockPendingWriteItr = mPendingWrites .find (block-> getMemoryPoolBlockIndex () );
315+ auto blockPendingWriteItr = mPendingWrites .find (blockIndex );
308316 if (blockPendingWriteItr != mPendingWrites .end ())
309317 {
310318 mOnboardManager .getStream ().wait (blockPendingWriteItr->second );
@@ -330,33 +338,36 @@ void KVCacheTransferManager::onboard(BlockPtr const& offloadedBlock, BlockPtr co
330338 }
331339
332340 // Record new pending read from offloadedBlock
333- mPendingReads [offloadedBlock-> getMemoryPoolBlockIndex () ] = tr::CudaEvent ();
334- mOnboardManager .getStream ().record (mPendingReads [offloadedBlock-> getMemoryPoolBlockIndex () ]);
341+ mPendingReads [offloadedBlockIndex ] = tr::CudaEvent ();
342+ mOnboardManager .getStream ().record (mPendingReads [offloadedBlockIndex ]);
335343 // Record new pending write to block
336- mPendingWrites [block-> getMemoryPoolBlockIndex () ] = tr::CudaEvent ();
337- mOnboardManager .getStream ().record (mPendingWrites [block-> getMemoryPoolBlockIndex () ]);
344+ mPendingWrites [blockIndex ] = tr::CudaEvent ();
345+ mOnboardManager .getStream ().record (mPendingWrites [blockIndex ]);
338346}
339347
340348void KVCacheTransferManager::offload (BlockPtr const & block, BlockPtr const & offloadBlock,
341349 std::vector<KVCacheBlockPool> const & pools, int numTokensToCopy, executor::KvCacheTransferMode mode,
342350 std::string const & directory)
343351{
352+ auto const blockIndex = getPendingTransferIndex (block);
353+ auto const offloadBlockIndex = getPendingTransferIndex (offloadBlock);
354+
344355 // Wait for any pending writes before reading from block
345- auto blockPendingWriteItr = mPendingWrites .find (block-> getMemoryPoolBlockIndex () );
356+ auto blockPendingWriteItr = mPendingWrites .find (blockIndex );
346357 if (blockPendingWriteItr != mPendingWrites .end ())
347358 {
348359 mOffloadManager .getStream ().wait (blockPendingWriteItr->second );
349360 // Don't erase, we are not changing state of block
350361 }
351362 // Wait for any pending reads before overwriting offloadBlock
352- auto offloadBlockPendingReadItr = mPendingReads .find (offloadBlock-> getMemoryPoolBlockIndex () );
363+ auto offloadBlockPendingReadItr = mPendingReads .find (offloadBlockIndex );
353364 if (offloadBlockPendingReadItr != mPendingReads .end ())
354365 {
355366 mOffloadManager .getStream ().wait (offloadBlockPendingReadItr->second );
356367 mPendingReads .erase (offloadBlockPendingReadItr);
357368 }
358369 // Wait for any pending writes before overwriting offloadBlock
359- auto offloadBlockPendingWriteItr = mPendingWrites .find (offloadBlock-> getMemoryPoolBlockIndex () );
370+ auto offloadBlockPendingWriteItr = mPendingWrites .find (offloadBlockIndex );
360371 if (offloadBlockPendingWriteItr != mPendingWrites .end ())
361372 {
362373 mOffloadManager .getStream ().wait (offloadBlockPendingWriteItr->second );
@@ -373,11 +384,11 @@ void KVCacheTransferManager::offload(BlockPtr const& block, BlockPtr const& offl
373384 }
374385
375386 // Record new pending read from block
376- mPendingReads [block-> getMemoryPoolBlockIndex () ] = tr::CudaEvent ();
377- mOffloadManager .getStream ().record (mPendingReads [block-> getMemoryPoolBlockIndex () ]);
387+ mPendingReads [blockIndex ] = tr::CudaEvent ();
388+ mOffloadManager .getStream ().record (mPendingReads [blockIndex ]);
378389 // Record new pending write to offloadBlock
379- mPendingWrites [offloadBlock-> getMemoryPoolBlockIndex () ] = tr::CudaEvent ();
380- mOffloadManager .getStream ().record (mPendingWrites [offloadBlock-> getMemoryPoolBlockIndex () ]);
390+ mPendingWrites [offloadBlockIndex ] = tr::CudaEvent ();
391+ mOffloadManager .getStream ().record (mPendingWrites [offloadBlockIndex ]);
381392}
382393
383394void KVCacheTransferManager::syncWithBufferManager ()
@@ -390,7 +401,7 @@ void KVCacheTransferManager::syncWithBufferManager()
390401 mBufferManager .getStream ().record (readyForOnboardEvent);
391402 mOnboardManager .getStream ().wait (readyForOnboardEvent);
392403
393- // Once we synchronize, clear our list of pending thransfers .
404+ // Once we synchronize, clear our list of pending transfers .
394405 mPendingReads .clear ();
395406 mPendingWrites .clear ();
396407}
@@ -405,7 +416,7 @@ void KVCacheTransferManager::syncTransfers()
405416 mOnboardManager .getStream ().record (onboardEvent);
406417 mBufferManager .getStream ().wait (onboardEvent);
407418
408- // Once we synchronize, clear our list of pending thransfers .
419+ // Once we synchronize, clear our list of pending transfers .
409420 mPendingReads .clear ();
410421 mPendingWrites .clear ();
411422}
0 commit comments