1+ #include " kv_cache_manager/client/src/internal/sdk/sdk_buffer_check_util.h"
2+
3+ #include < algorithm>
4+ #include < cassert>
5+
6+ #include " kv_cache_manager/common/env_util.h"
7+ #include " kv_cache_manager/common/hash_util.h"
8+
9+ namespace kv_cache_manager {
10+
11+ std::vector<int64_t > SdkBufferCheckUtil::GetBlocksHash (const BlockBuffers &block_buffers) {
12+ std::vector<IovDevice> iov_h;
13+ size_t iov_num = block_buffers.front ().iovs .size ();
14+ iov_h.reserve (iov_num * block_buffers.size ());
15+ for (const auto &block_buffer : block_buffers) {
16+ for (const auto &raw_iov : block_buffer.iovs ) {
17+ iov_h.push_back ({raw_iov.base , raw_iov.size });
18+ }
19+ }
20+ auto crcs = GetIovsCrc (iov_h);
21+ std::vector<int64_t > result;
22+ result.reserve (block_buffers.size ());
23+ for (size_t offset = 0 ; offset < crcs.size (); offset += iov_num) {
24+ result.push_back (HashUtil::HashIntArray (&crcs[offset], &crcs[offset + iov_num], 0 ));
25+ }
26+ return result;
27+ }
28+
29+ std::vector<int64_t > SdkBufferCheckUtil::GetBlocksHash (
30+ const BlockBuffers &block_buffers, IovDevice *iovs_d, uint32_t *crcs_d, size_t max_iov_num, cudaStream_t stream) {
31+ std::vector<IovDevice> iov_h (max_iov_num);
32+ return GetBlocksHash (block_buffers, iovs_d, crcs_d, iov_h.data (), max_iov_num, stream);
33+ }
34+
35+ std::vector<int64_t > SdkBufferCheckUtil::GetBlocksHash (const BlockBuffers &block_buffers,
36+ IovDevice *iovs_d,
37+ uint32_t *crcs_d,
38+ IovDevice *iovs_h_to_save,
39+ size_t max_iov_num,
40+ cudaStream_t stream) {
41+ size_t iov_num = block_buffers.front ().iovs .size ();
42+ size_t iovs_size = 0 ;
43+ for (const auto &block_buffer : block_buffers) {
44+ assert (iov_num == block_buffer.iovs .size ());
45+ if (iovs_size + block_buffer.iovs .size () > max_iov_num) {
46+ break ;
47+ }
48+ for (const auto &raw_iov : block_buffer.iovs ) {
49+ iovs_h_to_save[iovs_size].base = raw_iov.base ;
50+ iovs_h_to_save[iovs_size].size = raw_iov.size ;
51+ iovs_size++;
52+ }
53+ }
54+ auto crcs = GetIovsCrc (iovs_h_to_save, iovs_size, iovs_d, crcs_d, stream);
55+ std::vector<int64_t > result;
56+ result.reserve (iovs_size / iov_num);
57+ for (size_t offset = 0 ; offset < crcs.size (); offset += iov_num) {
58+ result.push_back (HashUtil::HashIntArray (&crcs[offset], &crcs[offset + iov_num], 0 ));
59+ }
60+ return result;
61+ }
62+
63+ std::vector<uint32_t > SdkBufferCheckUtil::GetIovsCrc (const std::vector<IovDevice> &iovs_h) {
64+ IovDevice *iovs_d = nullptr ;
65+ uint32_t *crcs_d = nullptr ;
66+ CHECK_CUDA_ERROR_RETURN (cudaMalloc (&iovs_d, sizeof (IovDevice) * iovs_h.size ()), {}, " cudaMalloc fail" );
67+ CHECK_CUDA_ERROR_RETURN (cudaMalloc (&crcs_d, sizeof (uint32_t ) * iovs_h.size ()), {}, " cudaMalloc fail" );
68+ auto crcs = GetIovsCrc (iovs_h, iovs_d, crcs_d, nullptr );
69+ CHECK_CUDA_ERROR_RETURN (cudaFree (iovs_d), {}, " cudaMalloc fail" );
70+ CHECK_CUDA_ERROR_RETURN (cudaFree (crcs_d), {}, " cudaMalloc fail" );
71+ return crcs;
72+ }
73+
74+ std::vector<uint32_t > SdkBufferCheckUtil::GetIovsCrc (const std::vector<IovDevice> &iovs_h,
75+ IovDevice *iovs_d,
76+ uint32_t *crcs_d,
77+ cudaStream_t stream) {
78+ return GetIovsCrc (iovs_h.data (), iovs_h.size (), iovs_d, crcs_d, stream);
79+ }
80+
81+ SdkBufferCheckPool::SdkBufferCheckPool (size_t cell_num) { cells_.resize (cell_num); }
82+
83+ SdkBufferCheckPool::~SdkBufferCheckPool () {
84+ for (const auto &cell : cells_) {
85+ if (cell.h_iovs ) {
86+ CHECK_CUDA_ERROR (cudaFreeHost (cell.h_iovs ), " cuda free iovs_h_mem[%p] failed" , cell.h_iovs );
87+ }
88+ if (cell.d_iovs ) {
89+ CHECK_CUDA_ERROR (cudaFree (cell.d_iovs ), " cuda free d_iovs[%p] failed" , cell.d_iovs );
90+ }
91+ if (cell.d_iovs ) {
92+ CHECK_CUDA_ERROR (cudaFree (cell.d_crcs ), " cuda free d_crcs[%p] failed" , cell.d_crcs );
93+ }
94+ }
95+ }
96+
97+ bool SdkBufferCheckPool::Init (size_t max_check_iov_num) {
98+ size_t iovs_byte_size = max_check_iov_num * sizeof (IovDevice);
99+ size_t crcs_byte_size = max_check_iov_num * sizeof (uint32_t );
100+ for (auto &cell : cells_) {
101+ CHECK_CUDA_ERROR_RETURN (
102+ cudaMallocHost (&cell.h_iovs , iovs_byte_size), false , " cudaMallocHost [%zu] bytes failed" , iovs_byte_size);
103+ CHECK_CUDA_ERROR_RETURN (
104+ cudaMalloc (&cell.d_iovs , iovs_byte_size), false , " cudaMalloc [%zu] byte failed" , iovs_byte_size);
105+ CHECK_CUDA_ERROR_RETURN (
106+ cudaMalloc (&cell.d_crcs , crcs_byte_size), false , " cudaMalloc [%zu] byte failed" , crcs_byte_size);
107+ CHECK_CUDA_ERROR_RETURN (
108+ cudaStreamCreateWithFlags (&cell.cuda_stream , cudaStreamNonBlocking), false , " cuda stream create failed" );
109+ cell_queue_.push (&cell);
110+ }
111+ return true ;
112+ }
113+
114+ SdkBufferCheckPool::CellHandle::~CellHandle () {
115+ if (pool_) {
116+ pool_->PutCell (cell_);
117+ }
118+ }
119+
120+ SdkBufferCheckPool::CellHandle SdkBufferCheckPool::GetCell () {
121+ std::unique_lock lock (mutex_);
122+ cv_.wait (lock, [this ] { return !cell_queue_.empty (); });
123+ Cell *cell = cell_queue_.front ();
124+ cell_queue_.pop ();
125+ return CellHandle (this , cell);
126+ }
127+
128+ void SdkBufferCheckPool::PutCell (Cell *cell) {
129+ std::unique_lock lock (mutex_);
130+ cell_queue_.push (cell);
131+ cv_.notify_one ();
132+ }
133+
134+ } // namespace kv_cache_manager
0 commit comments