
Commit e76e4f7

Integrate zero copy api
Signed-off-by: yizhang-nv <187001205+yizhang-nv@users.noreply.github.com>
1 parent d77234d commit e76e4f7

File tree

7 files changed: +363 −41 lines changed

cpp/tensorrt_llm/batch_manager/kvCacheManagerV2Utils.cpp

Lines changed: 48 additions & 0 deletions

@@ -172,4 +172,52 @@ void copyBatchBlockOffsets(ITensor& output, SizeType32 batchSize, std::vector<Bl
     }
 }
 
+SizeType32 IndexMapper::addNewSequence(LlmRequest::RequestIdType requestId)
+{
+    TLLM_CHECK(indexMap_.find(requestId) == indexMap_.end());
+    auto iter = freeIndices_.begin();
+    TLLM_CHECK_WITH_INFO(iter != freeIndices_.end(), "No free index found");
+    auto index = *iter;
+    freeIndices_.erase(iter);
+    indexMap_[requestId] = index;
+    return index;
+}
+
+SizeType32 IndexMapper::getIndex(LlmRequest::RequestIdType requestId)
+{
+    return indexMap_[requestId];
+}
+
+void IndexMapper::removeSequence(LlmRequest::RequestIdType requestId)
+{
+    auto iter = indexMap_.find(requestId);
+    TLLM_CHECK(iter != indexMap_.end());
+    auto index = iter->second;
+    freeIndices_.insert(index);
+    indexMap_.erase(iter);
+}
+
+at::Tensor IndexMapper::getCopyIndex(
+    std::vector<LlmRequest::RequestIdType> const& requestIds, SizeType32 numContext, SizeType32 beamWidth)
+{
+    int numSeqs = numContext + beamWidth * (requestIds.size() - numContext);
+    for (uint32_t i = 0, idx = 0; i < requestIds.size(); i++)
+    {
+        if (i < numContext)
+        {
+            copyIndex_[idx++] = indexMap_[requestIds[i]] * maxBeamWidth_;
+        }
+        else
+        {
+            for (uint32_t j = 0; j < beamWidth; j++)
+            {
+                copyIndex_[idx++] = indexMap_[requestIds[i]] * maxBeamWidth_ + j;
+            }
+        }
+    }
+
+    auto options = at::TensorOptions().dtype(at::ScalarType::Int).pinned_memory(true);
+    return at::from_blob(copyIndex_, numSeqs, options);
+}
+
 } // namespace tensorrt_llm::batch_manager::kv_cache_manager_v2
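Reviewer note on the hunk above: IndexMapper hands out one slot per active request and builds a gather index over slot rows. A minimal Python re-statement of the same bookkeeping (illustrative only, not part of the commit; names mirror the C++ members):

class IndexMapperSketch:
    """Python sketch of IndexMapper's slot bookkeeping."""

    def __init__(self, max_batch_size: int, max_beam_width: int):
        self.free = set(range(max_batch_size))  # freeIndices_
        self.index_map = {}                     # indexMap_: request id -> slot
        self.max_beam_width = max_beam_width

    def add_new_sequence(self, request_id: int) -> int:
        assert request_id not in self.index_map
        slot = min(self.free)  # std::set::begin() yields the smallest free slot
        self.free.remove(slot)
        self.index_map[request_id] = slot
        return slot

    def remove_sequence(self, request_id: int) -> None:
        self.free.add(self.index_map.pop(request_id))

    def get_copy_index(self, request_ids, num_context: int, beam_width: int):
        # Context requests contribute one row (beam 0); generation requests
        # contribute beam_width consecutive rows.
        rows = []
        for i, rid in enumerate(request_ids):
            base = self.index_map[rid] * self.max_beam_width
            if i < num_context:
                rows.append(base)
            else:
                rows.extend(base + j for j in range(beam_width))
        return rows

Because freeIndices_ is an ordered std::set, slots are always reused from the low end; the C++ version additionally writes the rows into a pinned buffer and wraps it with at::from_blob so it can feed the device-side gather kernel without an extra copy.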

cpp/tensorrt_llm/batch_manager/kvCacheManagerV2Utils.cu

Lines changed: 140 additions & 0 deletions

@@ -1,15 +1,20 @@
 #include "kvCacheManagerV2Utils.h"
 #include "tensorrt_llm/common/assert.h"
 #include "tensorrt_llm/common/cudaUtils.h"
+#include "tensorrt_llm/common/envUtils.h"
+#include "tensorrt_llm/common/memoryUtils.h"
 #include <algorithm>
 #include <array>
 #include <cassert>
 #include <cuda_runtime.h>
+#include <vector>
 
 namespace tensorrt_llm::batch_manager::kv_cache_manager_v2
 {
 using Grain = uint4;
 constexpr uint32_t ctaSize = 128;
+constexpr uint32_t copyBlockCtaSize = 128;
+constexpr uint32_t copyBlocknbBufs = 2;
 constexpr uint32_t nbBufs = 4;
 constexpr uint32_t grainBytes = sizeof(Grain);
 
@@ -162,4 +167,139 @@ CUresult copyDeviceToDevice(std::vector<MMTask> const& tasks, ssize_t numBytes,
     return launchBatchedCopy(false, tasks, numBytes, stream);
 }
 
+// dst_tensor[:, :num_seqs, 0] = src_tensor[:, copy_idx]
+// dst_tensor[:, :num_seqs, 1] = dst_tensor[:, :num_seqs, 0] + 1
+template <bool COPY_V_IDX = true>
+__global__ void copyBatchBlockOffsetsToDeviceKernel(SizeType32 const* __restrict__ srcPtr,
+    SizeType32* __restrict__ dstPtr, SizeType32 const maxNumSequences, SizeType32 numBlocksPerSeq,
+    SizeType32 const* __restrict__ copyIndex)
+{
+    constexpr uint32_t kvFactor = 2;
+    constexpr auto elemPerAccess = sizeof(PackedInt) / sizeof(SizeType32);
+
+    __shared__ PackedInt data[copyBlocknbBufs][copyBlockCtaSize];
+
+    auto const iterPerSeq = divUp(numBlocksPerSeq * sizeof(SizeType32), sizeof(PackedInt) * copyBlockCtaSize);
+    auto const tid = threadIdx.x;
+    auto const poolIdx = blockIdx.x;
+    auto const seqIdx = blockIdx.y;
+    auto const seqDimStride = kvFactor * numBlocksPerSeq;
+    uint32_t const srcIdxBeg = tid * elemPerAccess + (poolIdx * maxNumSequences + copyIndex[seqIdx]) * seqDimStride;
+    uint32_t const dstIdxKBeg = tid * elemPerAccess + (poolIdx * maxNumSequences + seqIdx) * seqDimStride;
+    uint32_t const dstIdxVBeg = dstIdxKBeg + numBlocksPerSeq;
+
+    uint32_t const srcIdxEnd = (poolIdx * maxNumSequences + copyIndex[seqIdx]) * seqDimStride + numBlocksPerSeq;
+
+    for (uint32_t i = 0; i < iterPerSeq + copyBlocknbBufs; i++)
+    {
+        uint32_t const idxBuf = i % copyBlocknbBufs;
+        if (i >= copyBlocknbBufs)
+        {
+            uint32_t const stIter = i - copyBlocknbBufs;
+            assert(idxBuf == (stIter % copyBlocknbBufs));
+            auto const offset = copyBlockCtaSize * stIter * elemPerAccess;
+            SizeType32 const srcIdx = srcIdxBeg + offset;
+            SizeType32 const dstIdxK = dstIdxKBeg + offset;
+            SizeType32 const dstIdxV = dstIdxVBeg + offset;
+            PackedInt const& src = data[idxBuf][tid];
+            PackedInt& dstK = *reinterpret_cast<PackedInt*>(dstPtr + dstIdxK);
+            PackedInt& dstV = *reinterpret_cast<PackedInt*>(dstPtr + dstIdxV);
+            asm volatile("cp.async.wait_group %0;\n" ::"n"(copyBlocknbBufs - 1) : "memory");
+            if (srcIdx < srcIdxEnd)
+            {
+                dstK = src;
+                if (COPY_V_IDX)
+                {
+#pragma unroll
+                    for (uint32_t j = 0; j < elemPerAccess; j++)
+                    {
+                        dstV.unpacked[j] = src.unpacked[j] + 1;
+                    }
+                }
+            }
+        }
+        uint32_t const ldIter = i;
+        PackedInt* const dst = &data[idxBuf][tid];
+        uint32_t const srcIdx = srcIdxBeg + copyBlockCtaSize * ldIter * elemPerAccess;
+        PackedInt const* const src = reinterpret_cast<PackedInt const*>(srcPtr + srcIdx);
+        if (srcIdx < srcIdxEnd)
+        {
+            uint32_t const size = sizeof(PackedInt);
+            asm volatile("cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"l"(__cvta_generic_to_shared(dst)),
+                "l"(src), "n"(size), "r"(size)
+                : "memory");
+        }
+        asm volatile("cp.async.commit_group;\n" : : : "memory");
+    }
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+    cudaTriggerProgrammaticLaunchCompletion();
+#endif
+}
+
+// Host-side launcher
+void copyBatchBlockOffsetsToDevice(
+    ITensor const& input, ITensor& output, ITensor const& copyIndex, bool copyVIdx, CUstream stream) noexcept
+{
+    using namespace tensorrt_llm::runtime;
+
+    auto const* srcPtr = bufferCast<tk::KVCacheIndex::UnderlyingType const>(input);
+    auto* dstPtr = bufferCast<tk::KVCacheIndex::UnderlyingType>(
+        output); // [numPools, maxNumSequences, kvFactor, numBlocksPerSeq]
+    auto const* copyIndexPtr = bufferCast<SizeType32 const>(copyIndex);
+    auto const& srcShape = input.getShape();
+    auto const& dstShape = output.getShape();
+    auto const& copyIndexShape = copyIndex.getShape();
+
+    TLLM_CHECK(srcShape.nbDims == 4); // [numPools, maxNumSequences, kvFactor, numBlocksPerSeq]
+    TLLM_CHECK(dstShape.nbDims == 4); // [numPools, maxNumSequences, kvFactor, numBlocksPerSeq]
+
+    SizeType32 numPools = srcShape.d[0];
+    SizeType32 maxNumSequences = srcShape.d[1];
+    SizeType32 numBlocksPerSeq = srcShape.d[3];
+    SizeType32 numSeqs = copyIndexShape.d[0];
+
+    if (numSeqs == 0)
+    {
+        return;
+    }
+
+    TLLM_CHECK_WITH_INFO((numBlocksPerSeq * sizeof(SizeType32)) % sizeof(PackedInt) == 0,
+        "Not implemented case: numBlocksPerSeq * sizeof(SizeType32) = %zu must be a multiple of %zu.",
+        static_cast<size_t>(numBlocksPerSeq * sizeof(SizeType32)), static_cast<size_t>(sizeof(PackedInt)));
+
+    dim3 gridDim(numPools, numSeqs, 1);
+    dim3 blockDim(copyBlockCtaSize);
+
+    if (copyVIdx)
+    {
+        copyBatchBlockOffsetsToDeviceKernel<true>
+            <<<gridDim, blockDim, 0, stream>>>(srcPtr, dstPtr, maxNumSequences, numBlocksPerSeq, copyIndexPtr);
+    }
+    else
+    {
+        copyBatchBlockOffsetsToDeviceKernel<false>
+            <<<gridDim, blockDim, 0, stream>>>(srcPtr, dstPtr, maxNumSequences, numBlocksPerSeq, copyIndexPtr);
+    }
+}
+
+IndexMapper::IndexMapper(SizeType32 maxBatchSize, SizeType32 maxBeamWidth)
+    : maxBatchSize_(maxBatchSize)
+    , maxBeamWidth_(maxBeamWidth)
+{
+    indexMap_.reserve(maxBatchSize);
+    for (SizeType32 i = 0; i < maxBatchSize; i++)
+    {
+        freeIndices_.insert(i);
+    }
+    // Allocate copyIndex_ memory as pinned (page-locked) host memory
+    TLLM_CUDA_CHECK(cudaMallocHost(&copyIndex_, maxBatchSize * maxBeamWidth * sizeof(SizeType32)));
+}
+
+IndexMapper::~IndexMapper()
+{
+    indexMap_.clear();
+    freeIndices_.clear();
+    TLLM_CUDA_CHECK(cudaFreeHost(copyIndex_));
+}
+
 } // namespace tensorrt_llm::batch_manager::kv_cache_manager_v2
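For orientation, the kernel above is a row gather with the V offsets derived from the K offsets. A torch reference with the same semantics, per the kernel's doc comment (the function and variable names here are illustrative, not part of the commit):

import torch

def copy_batch_block_offsets_reference(src: torch.Tensor, dst: torch.Tensor,
                                       copy_index: torch.Tensor,
                                       copy_v_idx: bool = True) -> None:
    # src, dst: [num_pools, max_num_sequences, 2, num_blocks_per_seq] int32;
    # copy_index: [num_seqs] int32, source rows to gather.
    num_seqs = copy_index.numel()
    dst[:, :num_seqs, 0] = src[:, copy_index, 0]         # gather K offsets
    if copy_v_idx:
        dst[:, :num_seqs, 1] = dst[:, :num_seqs, 0] + 1  # V offsets = K + 1

The CUDA version pipelines this gather through shared memory with double-buffered cp.async (copyBlocknbBufs stages), so the load for iteration i overlaps the store for iteration i − copyBlocknbBufs; since the source pointer only needs to be globally addressable, a pinned host buffer can serve as input, which appears to be the zero-copy path this commit integrates.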

cpp/tensorrt_llm/batch_manager/kvCacheManagerV2Utils.h

Lines changed: 37 additions & 0 deletions

@@ -17,11 +17,14 @@
 
 #pragma once
 
+#include "tensorrt_llm/batch_manager/llmRequest.h"
 #include "tensorrt_llm/kernels/kvCacheIndex.h"
 #include "tensorrt_llm/runtime/iBuffer.h"
 #include "tensorrt_llm/runtime/iTensor.h"
+#include <ATen/ATen.h>
 #include <cstdint>
 #include <cuda.h>
+#include <set>
 #include <vector>
 
 namespace tk = tensorrt_llm::kernels;
 
@@ -51,6 +54,37 @@ struct BlockIndices
     SizeType32 length;
 };
 
+using PackedInt = union
+{
+    int4 packed;
+    SizeType32 unpacked[4];
+};
+
+class IndexMapper
+{
+public:
+    IndexMapper(SizeType32 maxBatchSize, SizeType32 maxBeamWidth);
+
+    ~IndexMapper();
+
+    SizeType32 addNewSequence(LlmRequest::RequestIdType requestId);
+
+    SizeType32 getIndex(LlmRequest::RequestIdType requestId);
+
+    void removeSequence(LlmRequest::RequestIdType requestId);
+
+    at::Tensor getCopyIndex(
+        std::vector<LlmRequest::RequestIdType> const& requestIds, SizeType32 numContext, SizeType32 beamWidth);
+
+private:
+    std::unordered_map<LlmRequest::RequestIdType, SizeType32> indexMap_;
+    std::set<SizeType32> freeIndices_;
+    SizeType32* copyIndex_;
+    SizeType32 currentIndex_;
+    SizeType32 maxBatchSize_;
+    SizeType32 maxBeamWidth_;
+};
+
 CUresult copyDiskToDisk(
     std::vector<Task<DiskAddress, DiskAddress>> const& tasks, ssize_t numBytes, CUstream stream) noexcept;
 CUresult copyDiskToHost(
 
@@ -69,4 +103,7 @@ CUresult copyDeviceToDevice(
 void copyBatchBlockOffsets(ITensor& output, SizeType32 batchSize, std::vector<BlockIndices> const& batchBlockIndices,
     SizeType32 numPools, SizeType32 offset) noexcept;
 
+void copyBatchBlockOffsetsToDevice(
+    ITensor const& input, ITensor& output, ITensor const& copyIndex, bool copyVIdx, CUstream stream) noexcept;
+
 } // namespace tensorrt_llm::batch_manager::kv_cache_manager_v2

cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManagerV2Utils.cpp

Lines changed: 23 additions & 0 deletions

@@ -76,6 +76,13 @@ void KVCacheManagerV2UtilsBindings::initBindings(nb::module_& module)
         .def_rw("addr", &BlockIndices::addr)
         .def_rw("length", &BlockIndices::length);
 
+    nb::class_<IndexMapper>(module, "IndexMapper")
+        .def(nb::init<SizeType32, SizeType32>(), nb::arg("max_batch_size"), nb::arg("max_beam_width"))
+        .def("add_new_sequence", &IndexMapper::addNewSequence)
+        .def("get_index", &IndexMapper::getIndex)
+        .def("remove_sequence", &IndexMapper::removeSequence)
+        .def("get_copy_index", &IndexMapper::getCopyIndex);
+
     // Bind copy functions
     module.def(
         "copy_disk_to_disk",
 
@@ -137,6 +144,22 @@ void KVCacheManagerV2UtilsBindings::initBindings(nb::module_& module)
         },
         nb::arg("output"), nb::arg("batch_size"), nb::arg("batch_block_indices"), nb::arg("num_pools"),
         nb::arg("offset"), nb::call_guard<nb::gil_scoped_release>(), "Copy batch block indices to output tensor");
+
+    module.def(
+        "copy_batch_block_offsets_to_device",
+        [](at::Tensor input, at::Tensor output, at::Tensor copyIndex, bool copyVIdx, uintptr_t stream)
+        {
+            auto _input = from_torch(input);
+            auto _output = from_torch(output);
+            auto _copyIndex = from_torch(copyIndex);
+            TLLM_CHECK_WITH_INFO(_input.has_value(), "Invalid input tensor.");
+            TLLM_CHECK_WITH_INFO(_output.has_value(), "Invalid output tensor.");
+            TLLM_CHECK_WITH_INFO(_copyIndex.has_value(), "Invalid copy index tensor.");
+            copyBatchBlockOffsetsToDevice(*(_input.value()), *(_output.value()), *(_copyIndex.value()), copyVIdx,
+                reinterpret_cast<CUstream>(stream));
+        },
+        nb::arg("input"), nb::arg("output"), nb::arg("copy_index"), nb::arg("copy_v_idx"), nb::arg("stream"),
+        nb::call_guard<nb::gil_scoped_release>(), "Copy batch block indices to device");
 }
 
 } // namespace tensorrt_llm::batch_manager::kv_cache_manager_v2
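Taken together, the new bindings would presumably be driven along these lines (a hypothetical sketch: the import path of the compiled module and the tensor shapes are assumptions; only the binding names and signatures come from the code above):

import torch
# Assumed import location; adjust to wherever these bindings are exposed:
# from tensorrt_llm.bindings.internal.batch_manager import (
#     IndexMapper, copy_batch_block_offsets_to_device)

num_pools, max_batch_size, beam_width, blocks_per_seq = 1, 8, 1, 64
max_num_seqs = max_batch_size * beam_width

# Pinned host table of block offsets and its device-side mirror.
host_offsets = torch.zeros(num_pools, max_num_seqs, 2, blocks_per_seq,
                           dtype=torch.int32, pin_memory=True)
device_offsets = torch.zeros(num_pools, max_num_seqs, 2, blocks_per_seq,
                             dtype=torch.int32, device="cuda")

mapper = IndexMapper(max_batch_size=max_batch_size, max_beam_width=beam_width)
slot = mapper.add_new_sequence(0)                       # request id 0 admitted
copy_index = mapper.get_copy_index([0], 1, beam_width)  # one context request
copy_batch_block_offsets_to_device(
    host_offsets, device_offsets, copy_index,
    True,                                      # copy_v_idx: derive V rows
    torch.cuda.current_stream().cuda_stream)   # raw CUstream address
mapper.remove_sequence(0)                      # slot returns to the free set

The stream argument is bound as uintptr_t, so torch's raw cuda_stream integer passes straight through, and the gil_scoped_release call guard keeps the GIL dropped for the duration of the launch.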

tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 4 additions & 15 deletions

@@ -725,11 +725,8 @@ def _post_init_with_buffers(self, buffers) -> None:
             dtype=torch.int32,
             capture_graph=capture_graph,
         )
-        self.host_kv_cache_block_offsets = torch.empty_like(
-            self.kv_cache_block_offsets,
-            device='cpu',
-            pin_memory=True,
-        )
+        self.host_kv_cache_block_offsets = self.kv_cache_manager.host_kv_cache_block_offsets
+        assert self.host_kv_cache_block_offsets.shape == self.kv_cache_block_offsets.shape, f"host_kv_cache_block_offsets and kv_cache_block_offsets should have the same shape, but got {self.host_kv_cache_block_offsets.shape} and {self.kv_cache_block_offsets.shape}"
         self.block_ids_per_seq = None
         self.kv_block_ids_per_seq = None
         if self.enable_flash_mla:
 
@@ -861,16 +858,8 @@ def prepare(self) -> None:
         if self.kv_cache_manager is not None:
             # Copy blocks for all context requests
             self.kv_cache_manager.copy_batch_block_offsets(
-                self.host_kv_cache_block_offsets,
-                self.request_ids[:self.num_contexts], 1, 0)
-            # Copy blocks for all generation requests
-            self.kv_cache_manager.copy_batch_block_offsets(
-                self.host_kv_cache_block_offsets,
-                self.request_ids[self.num_contexts:], self.beam_width,
-                self.num_contexts)
-            self.kv_cache_block_offsets[:, :self.num_seqs].copy_(
-                self.host_kv_cache_block_offsets[:, :self.num_seqs],
-                non_blocking=True)
+                self.kv_cache_block_offsets, self.request_ids, self.beam_width,
+                self.num_contexts, self.num_generations)
 
         error_message = (
             f"The max KV cache length of input sequences ({self.kv_lens[:self.num_seqs].max()}) "

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 3 additions & 0 deletions

@@ -1479,6 +1479,9 @@ def _executor_loop_overlap(self):
                         iter_stats=iter_stats,
                         ctx_transmission_reqs=ctx_transmission_reqs)
 
+                else:
+                    self.previous_batch = None
+
                 if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
                     self._check_kv_transfer_timeout()
                     self._terminate_disagg_ctx_finished_requests()
