 #include "kvCacheManagerV2Utils.h"
 #include "tensorrt_llm/common/assert.h"
 #include "tensorrt_llm/common/cudaUtils.h"
+#include "tensorrt_llm/common/envUtils.h"
+#include "tensorrt_llm/common/memoryUtils.h"
 #include <algorithm>
 #include <array>
 #include <cassert>
 #include <cuda_runtime.h>
+#include <vector>

 namespace tensorrt_llm::batch_manager::kv_cache_manager_v2
 {
 using Grain = uint4;
 constexpr uint32_t ctaSize = 128;
+constexpr uint32_t copyBlockCtaSize = 128;
+constexpr uint32_t copyBlocknbBufs = 2;
 constexpr uint32_t nbBufs = 4;
 constexpr uint32_t grainBytes = sizeof(Grain);

@@ -162,4 +167,139 @@ CUresult copyDeviceToDevice(std::vector<MMTask> const& tasks, ssize_t numBytes,
     return launchBatchedCopy(false, tasks, numBytes, stream);
 }

+// Gather block offsets for the sequences selected by copyIndex:
+//   dst_tensor[:, :num_seqs, 0, :] = src_tensor[:, copy_idx, 0, :]
+//   dst_tensor[:, :num_seqs, 1, :] = dst_tensor[:, :num_seqs, 0, :] + 1   (only when COPY_V_IDX)
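+// The kernel streams each row through shared memory with cp.async double buffering (copyBlocknbBufs stages),
+// overlapping the global load of iteration i with the store of the row fetched copyBlocknbBufs iterations earlier.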
+template <bool COPY_V_IDX = true>
+__global__ void copyBatchBlockOffsetsToDeviceKernel(SizeType32 const* __restrict__ srcPtr,
+    SizeType32* __restrict__ dstPtr, SizeType32 const maxNumSequences, SizeType32 numBlocksPerSeq,
+    SizeType32 const* __restrict__ copyIndex)
+{
+    constexpr uint32_t kvFactor = 2;
+    constexpr auto elemPerAccess = sizeof(PackedInt) / sizeof(SizeType32);
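+    // Each thread moves one PackedInt (elemPerAccess block offsets) per pipeline stage.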
+
+    __shared__ PackedInt data[copyBlocknbBufs][copyBlockCtaSize];
+
+    auto const iterPerSeq = divUp(numBlocksPerSeq * sizeof(SizeType32), sizeof(PackedInt) * copyBlockCtaSize);
+    auto const tid = threadIdx.x;
+    auto const poolIdx = blockIdx.x;
+    auto const seqIdx = blockIdx.y;
+    auto const seqDimStride = kvFactor * numBlocksPerSeq;
+    uint32_t const srcIdxBeg = tid * elemPerAccess + (poolIdx * maxNumSequences + copyIndex[seqIdx]) * seqDimStride;
+    uint32_t const dstIdxKBeg = tid * elemPerAccess + (poolIdx * maxNumSequences + seqIdx) * seqDimStride;
+    uint32_t const dstIdxVBeg = dstIdxKBeg + numBlocksPerSeq;
+
+    uint32_t const srcIdxEnd = (poolIdx * maxNumSequences + copyIndex[seqIdx]) * seqDimStride + numBlocksPerSeq;
+
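+    // Software pipeline over iterPerSeq + copyBlocknbBufs iterations: the first copyBlocknbBufs iterations only
+    // prefetch and the last copyBlocknbBufs iterations only drain; in between, each iteration stores the buffer
+    // filled copyBlocknbBufs iterations earlier while issuing the next asynchronous load.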
+    for (uint32_t i = 0; i < iterPerSeq + copyBlocknbBufs; i++)
+    {
+        uint32_t const idxBuf = i % copyBlocknbBufs;
+        if (i >= copyBlocknbBufs)
+        {
+            uint32_t const stIter = i - copyBlocknbBufs;
+            assert(idxBuf == (stIter % copyBlocknbBufs));
+            auto const offset = copyBlockCtaSize * stIter * elemPerAccess;
+            SizeType32 const srcIdx = srcIdxBeg + offset;
+            SizeType32 const dstIdxK = dstIdxKBeg + offset;
+            SizeType32 const dstIdxV = dstIdxVBeg + offset;
+            PackedInt const& src = data[idxBuf][tid];
+            PackedInt& dstK = *reinterpret_cast<PackedInt*>(dstPtr + dstIdxK);
+            PackedInt& dstV = *reinterpret_cast<PackedInt*>(dstPtr + dstIdxV);
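+            // Wait until at most copyBlocknbBufs - 1 cp.async groups remain in flight, which guarantees that the
+            // group that filled this buffer (issued copyBlocknbBufs iterations ago) has landed in shared memory.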
+            asm volatile("cp.async.wait_group %0;\n" ::"n"(copyBlocknbBufs - 1) : "memory");
+            if (srcIdx < srcIdxEnd)
+            {
+                dstK = src;
+                if (COPY_V_IDX)
+                {
+#pragma unroll
+                    for (uint32_t j = 0; j < elemPerAccess; j++)
+                    {
+                        dstV.unpacked[j] = src.unpacked[j] + 1;
+                    }
+                }
+            }
+        }
+        uint32_t const ldIter = i;
+        PackedInt* const dst = &data[idxBuf][tid];
+        uint32_t const srcIdx = srcIdxBeg + copyBlockCtaSize * ldIter * elemPerAccess;
+        PackedInt const* const src = reinterpret_cast<PackedInt const*>(srcPtr + srcIdx);
+        if (srcIdx < srcIdxEnd)
+        {
+            uint32_t const size = sizeof(PackedInt);
+            asm volatile("cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"l"(__cvta_generic_to_shared(dst)),
+                "l"(src), "n"(size), "r"(size)
+                : "memory");
+        }
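+        // A group is committed every iteration, even when no load was issued, so the wait_group accounting
+        // above stays in sync with the loop index.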
+        asm volatile("cp.async.commit_group;\n" : : : "memory");
+    }
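+    // On Hopper (SM90+), signal programmatic launch completion so a dependent grid launched with PDL
+    // can start before this kernel has fully exited.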
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+    cudaTriggerProgrammaticLaunchCompletion();
+#endif
+}
+
+// Host-side launcher
+void copyBatchBlockOffsetsToDevice(
+    ITensor const& input, ITensor& output, ITensor const& copyIndex, bool copyVIdx, CUstream stream) noexcept
+{
+    using namespace tensorrt_llm::runtime;
+
+    // Both input and output are laid out as [numPools, maxNumSequences, kvFactor, numBlocksPerSeq].
+    auto const* srcPtr = bufferCast<tk::KVCacheIndex::UnderlyingType const>(input);
+    auto* dstPtr = bufferCast<tk::KVCacheIndex::UnderlyingType>(output);
+    auto const* copyIndexPtr = bufferCast<SizeType32 const>(copyIndex);
+    auto const& srcShape = input.getShape();
+    auto const& dstShape = output.getShape();
+    auto const& copyIndexShape = copyIndex.getShape();
+
+    TLLM_CHECK(srcShape.nbDims == 4); // [numPools, maxNumSequences, kvFactor, numBlocksPerSeq]
+    TLLM_CHECK(dstShape.nbDims == 4); // [numPools, maxNumSequences, kvFactor, numBlocksPerSeq]
+
+    SizeType32 numPools = srcShape.d[0];
+    SizeType32 maxNumSequences = srcShape.d[1];
+    SizeType32 numBlocksPerSeq = srcShape.d[3];
+    SizeType32 numSeqs = copyIndexShape.d[0];
+
+    if (numSeqs == 0)
+    {
+        return;
+    }
+
+    TLLM_CHECK_WITH_INFO((numBlocksPerSeq * sizeof(SizeType32)) % sizeof(PackedInt) == 0,
+        "Not implemented case: numBlocksPerSeq * sizeof(SizeType32) = %zu must be a multiple of %zu.",
+        static_cast<size_t>(numBlocksPerSeq * sizeof(SizeType32)), static_cast<size_t>(sizeof(PackedInt)));
+
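+    // One CTA per (pool, sequence) pair; each CTA gathers one row of block offsets.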
+    dim3 gridDim(numPools, numSeqs, 1);
+    dim3 blockDim(copyBlockCtaSize);
+
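+    // The template flag selects whether the V offsets (K offset + 1) are written as well.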
+    if (copyVIdx)
+    {
+        copyBatchBlockOffsetsToDeviceKernel<true>
+            <<<gridDim, blockDim, 0, stream>>>(srcPtr, dstPtr, maxNumSequences, numBlocksPerSeq, copyIndexPtr);
+    }
+    else
+    {
+        copyBatchBlockOffsetsToDeviceKernel<false>
+            <<<gridDim, blockDim, 0, stream>>>(srcPtr, dstPtr, maxNumSequences, numBlocksPerSeq, copyIndexPtr);
+    }
+}
+
+IndexMapper::IndexMapper(SizeType32 maxBatchSize, SizeType32 maxBeamWidth)
+    : maxBatchSize_(maxBatchSize)
+    , maxBeamWidth_(maxBeamWidth)
+{
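+    // Every slot index in [0, maxBatchSize) starts out free.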
+    indexMap_.reserve(maxBatchSize);
+    for (SizeType32 i = 0; i < maxBatchSize; i++)
+    {
+        freeIndices_.insert(i);
+    }
+    // Allocate copyIndex_ as pinned (page-locked) host memory so it can be copied to the device asynchronously.
+    TLLM_CUDA_CHECK(cudaMallocHost(&copyIndex_, maxBatchSize * maxBeamWidth * sizeof(SizeType32)));
+}
+
+IndexMapper::~IndexMapper()
+{
+    indexMap_.clear();
+    freeIndices_.clear();
+    TLLM_CUDA_CHECK(cudaFreeHost(copyIndex_));
+}
+
 } // namespace tensorrt_llm::batch_manager::kv_cache_manager_v2