Skip to content

Commit 4abb4bc

Browse files
authored
[client] fix check buffer not thread safe, add SdkBufferCheckPool (#9)
1 parent 61f046e commit 4abb4bc

File tree

8 files changed

+303
-112
lines changed

8 files changed

+303
-112
lines changed

kv_cache_manager/client/include/common.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ enum [[nodiscard]] ClientErrorCode : int32_t{
2828
ER_TRANSFERCLIENT_INIT_ERROR = 13,
2929
ER_MANAGERCLIENT_INIT_ERROR = 14,
3030
ER_CLIENT_NOT_EXISTS = 15,
31+
ER_INIT_CHECK_BUFFER_ERROR = 16,
3132

3233
// service status code
3334
ER_SERVICE_NO_STATUS = 50,

kv_cache_manager/client/src/internal/sdk/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ cc_library(
138138
cuda_library(
139139
name = "sdk_buffer_check_util",
140140
srcs = [
141+
"sdk_buffer_check_util.cc",
141142
"sdk_buffer_check_util.cu",
142143
],
143144
hdrs = [
kv_cache_manager/client/src/internal/sdk/sdk_buffer_check_util.cc

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
#include "kv_cache_manager/client/src/internal/sdk/sdk_buffer_check_util.h"
2+
3+
#include <algorithm>
4+
#include <cassert>
5+
6+
#include "kv_cache_manager/common/env_util.h"
7+
#include "kv_cache_manager/common/hash_util.h"
8+
9+
namespace kv_cache_manager {
10+
11+
std::vector<int64_t> SdkBufferCheckUtil::GetBlocksHash(const BlockBuffers &block_buffers) {
12+
std::vector<IovDevice> iov_h;
13+
size_t iov_num = block_buffers.front().iovs.size();
14+
iov_h.reserve(iov_num * block_buffers.size());
15+
for (const auto &block_buffer : block_buffers) {
16+
for (const auto &raw_iov : block_buffer.iovs) {
17+
iov_h.push_back({raw_iov.base, raw_iov.size});
18+
}
19+
}
20+
auto crcs = GetIovsCrc(iov_h);
21+
std::vector<int64_t> result;
22+
result.reserve(block_buffers.size());
23+
for (size_t offset = 0; offset < crcs.size(); offset += iov_num) {
24+
result.push_back(HashUtil::HashIntArray(&crcs[offset], &crcs[offset + iov_num], 0));
25+
}
26+
return result;
27+
}
28+
29+
std::vector<int64_t> SdkBufferCheckUtil::GetBlocksHash(
30+
const BlockBuffers &block_buffers, IovDevice *iovs_d, uint32_t *crcs_d, size_t max_iov_num, cudaStream_t stream) {
31+
std::vector<IovDevice> iov_h(max_iov_num);
32+
return GetBlocksHash(block_buffers, iovs_d, crcs_d, iov_h.data(), max_iov_num, stream);
33+
}
34+
35+
std::vector<int64_t> SdkBufferCheckUtil::GetBlocksHash(const BlockBuffers &block_buffers,
36+
IovDevice *iovs_d,
37+
uint32_t *crcs_d,
38+
IovDevice *iovs_h_to_save,
39+
size_t max_iov_num,
40+
cudaStream_t stream) {
41+
size_t iov_num = block_buffers.front().iovs.size();
42+
size_t iovs_size = 0;
43+
for (const auto &block_buffer : block_buffers) {
44+
assert(iov_num == block_buffer.iovs.size());
45+
if (iovs_size + block_buffer.iovs.size() > max_iov_num) {
46+
break;
47+
}
48+
for (const auto &raw_iov : block_buffer.iovs) {
49+
iovs_h_to_save[iovs_size].base = raw_iov.base;
50+
iovs_h_to_save[iovs_size].size = raw_iov.size;
51+
iovs_size++;
52+
}
53+
}
54+
auto crcs = GetIovsCrc(iovs_h_to_save, iovs_size, iovs_d, crcs_d, stream);
55+
std::vector<int64_t> result;
56+
result.reserve(iovs_size / iov_num);
57+
for (size_t offset = 0; offset < crcs.size(); offset += iov_num) {
58+
result.push_back(HashUtil::HashIntArray(&crcs[offset], &crcs[offset + iov_num], 0));
59+
}
60+
return result;
61+
}
62+
63+
std::vector<uint32_t> SdkBufferCheckUtil::GetIovsCrc(const std::vector<IovDevice> &iovs_h) {
64+
IovDevice *iovs_d = nullptr;
65+
uint32_t *crcs_d = nullptr;
66+
CHECK_CUDA_ERROR_RETURN(cudaMalloc(&iovs_d, sizeof(IovDevice) * iovs_h.size()), {}, "cudaMalloc fail");
67+
CHECK_CUDA_ERROR_RETURN(cudaMalloc(&crcs_d, sizeof(uint32_t) * iovs_h.size()), {}, "cudaMalloc fail");
68+
auto crcs = GetIovsCrc(iovs_h, iovs_d, crcs_d, nullptr);
69+
CHECK_CUDA_ERROR_RETURN(cudaFree(iovs_d), {}, "cudaMalloc fail");
70+
CHECK_CUDA_ERROR_RETURN(cudaFree(crcs_d), {}, "cudaMalloc fail");
71+
return crcs;
72+
}
73+
74+
std::vector<uint32_t> SdkBufferCheckUtil::GetIovsCrc(const std::vector<IovDevice> &iovs_h,
75+
IovDevice *iovs_d,
76+
uint32_t *crcs_d,
77+
cudaStream_t stream) {
78+
return GetIovsCrc(iovs_h.data(), iovs_h.size(), iovs_d, crcs_d, stream);
79+
}
80+
81+
// Creates `cell_num` default-initialised cells. No CUDA resources are
// acquired here; that happens in Init().
SdkBufferCheckPool::SdkBufferCheckPool(size_t cell_num) : cells_(cell_num) {}
82+
83+
// Releases every CUDA resource that Init() acquired for the cells.
//
// Each resource is guarded by its own null check, so cells that were never
// initialised or only partially initialised (Init() returned early) are
// handled correctly. Failures are reported through CHECK_CUDA_ERROR.
SdkBufferCheckPool::~SdkBufferCheckPool() {
    for (const auto &cell : cells_) {
        // The stream is created in Init() and must be destroyed here, otherwise
        // it is leaked for the lifetime of the process.
        if (cell.cuda_stream) {
            CHECK_CUDA_ERROR(cudaStreamDestroy(cell.cuda_stream),
                             "cuda destroy stream[%p] failed",
                             static_cast<void *>(cell.cuda_stream));
        }
        if (cell.h_iovs) {
            CHECK_CUDA_ERROR(cudaFreeHost(cell.h_iovs), "cuda free iovs_h_mem[%p] failed", cell.h_iovs);
        }
        if (cell.d_iovs) {
            CHECK_CUDA_ERROR(cudaFree(cell.d_iovs), "cuda free d_iovs[%p] failed", cell.d_iovs);
        }
        // Guard on d_crcs itself (previously this tested d_iovs by mistake).
        if (cell.d_crcs) {
            CHECK_CUDA_ERROR(cudaFree(cell.d_crcs), "cuda free d_crcs[%p] failed", cell.d_crcs);
        }
    }
}
96+
97+
bool SdkBufferCheckPool::Init(size_t max_check_iov_num) {
98+
size_t iovs_byte_size = max_check_iov_num * sizeof(IovDevice);
99+
size_t crcs_byte_size = max_check_iov_num * sizeof(uint32_t);
100+
for (auto &cell : cells_) {
101+
CHECK_CUDA_ERROR_RETURN(
102+
cudaMallocHost(&cell.h_iovs, iovs_byte_size), false, "cudaMallocHost [%zu] bytes failed", iovs_byte_size);
103+
CHECK_CUDA_ERROR_RETURN(
104+
cudaMalloc(&cell.d_iovs, iovs_byte_size), false, "cudaMalloc [%zu] byte failed", iovs_byte_size);
105+
CHECK_CUDA_ERROR_RETURN(
106+
cudaMalloc(&cell.d_crcs, crcs_byte_size), false, "cudaMalloc [%zu] byte failed", crcs_byte_size);
107+
CHECK_CUDA_ERROR_RETURN(
108+
cudaStreamCreateWithFlags(&cell.cuda_stream, cudaStreamNonBlocking), false, "cuda stream create failed");
109+
cell_queue_.push(&cell);
110+
}
111+
return true;
112+
}
113+
114+
// Gives the held cell back to its pool. A handle without a pool (for example
// one that has been moved from) does nothing.
SdkBufferCheckPool::CellHandle::~CellHandle() {
    if (pool_ == nullptr) {
        return;
    }
    pool_->PutCell(cell_);
}
119+
120+
// Takes a free cell out of the pool, blocking until one is available.
//
// @return a handle that returns the cell to this pool when it is destroyed.
SdkBufferCheckPool::CellHandle SdkBufferCheckPool::GetCell() {
    std::unique_lock<std::mutex> guard(mutex_);
    // Loop to tolerate spurious wake-ups and wake-ups lost to another taker.
    while (cell_queue_.empty()) {
        cv_.wait(guard);
    }
    Cell *const free_cell = cell_queue_.front();
    cell_queue_.pop();
    return CellHandle(this, free_cell);
}
127+
128+
// Puts `cell` back on the free queue and wakes one thread waiting in GetCell().
void SdkBufferCheckPool::PutCell(Cell *cell) {
    std::lock_guard<std::mutex> guard(mutex_);
    cell_queue_.push(cell);
    cv_.notify_one();
}
133+
134+
} // namespace kv_cache_manager

kv_cache_manager/client/src/internal/sdk/sdk_buffer_check_util.cu

Lines changed: 0 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
#include <algorithm>
2-
#include <cassert>
3-
41
#include "kv_cache_manager/client/src/internal/sdk/sdk_buffer_check_util.h"
52
#include "kv_cache_manager/common/env_util.h"
63
#include "kv_cache_manager/common/hash_util.h"
@@ -50,76 +47,6 @@ constexpr uint32_t kDefaultThreadsPerBlock = 512;
5047

5148
} // namespace
5249

53-
std::vector<int64_t> SdkBufferCheckUtil::GetBlocksHash(const BlockBuffers &block_buffers) {
54-
std::vector<IovDevice> iov_h;
55-
size_t iov_num = block_buffers.front().iovs.size();
56-
iov_h.reserve(iov_num * block_buffers.size());
57-
for (const auto &block_buffer : block_buffers) {
58-
for (const auto &raw_iov : block_buffer.iovs) {
59-
iov_h.push_back({raw_iov.base, raw_iov.size});
60-
}
61-
}
62-
auto crcs = GetIovsCrc(iov_h);
63-
std::vector<int64_t> result;
64-
result.reserve(block_buffers.size());
65-
for (size_t offset = 0; offset < crcs.size(); offset += iov_num) {
66-
result.push_back(HashUtil::HashIntArray(&crcs[offset], &crcs[offset + iov_num], 0));
67-
}
68-
return result;
69-
}
70-
71-
std::vector<int64_t> SdkBufferCheckUtil::GetBlocksHash(
72-
const BlockBuffers &block_buffers, IovDevice *iovs_d, uint32_t *crcs_d, size_t max_iov_num, cudaStream_t stream) {
73-
std::vector<IovDevice> iov_h(max_iov_num);
74-
return GetBlocksHash(block_buffers, iovs_d, crcs_d, iov_h.data(), max_iov_num, stream);
75-
}
76-
77-
std::vector<int64_t> SdkBufferCheckUtil::GetBlocksHash(const BlockBuffers &block_buffers,
78-
IovDevice *iovs_d,
79-
uint32_t *crcs_d,
80-
IovDevice *iovs_h_to_save,
81-
size_t max_iov_num,
82-
cudaStream_t stream) {
83-
size_t iov_num = block_buffers.front().iovs.size();
84-
size_t iovs_size = 0;
85-
for (const auto &block_buffer : block_buffers) {
86-
assert(iov_num == block_buffer.iovs.size());
87-
if (iovs_size + block_buffer.iovs.size() > max_iov_num) {
88-
break;
89-
}
90-
for (const auto &raw_iov : block_buffer.iovs) {
91-
iovs_h_to_save[iovs_size].base = raw_iov.base;
92-
iovs_h_to_save[iovs_size].size = raw_iov.size;
93-
iovs_size++;
94-
}
95-
}
96-
auto crcs = GetIovsCrc(iovs_h_to_save, iovs_size, iovs_d, crcs_d, stream);
97-
std::vector<int64_t> result;
98-
result.reserve(iovs_size / iov_num);
99-
for (size_t offset = 0; offset < crcs.size(); offset += iov_num) {
100-
result.push_back(HashUtil::HashIntArray(&crcs[offset], &crcs[offset + iov_num], 0));
101-
}
102-
return result;
103-
}
104-
105-
std::vector<uint32_t> SdkBufferCheckUtil::GetIovsCrc(const std::vector<IovDevice> &iovs_h) {
106-
IovDevice *iovs_d = nullptr;
107-
uint32_t *crcs_d = nullptr;
108-
CHECK_CUDA_ERROR_RETURN(cudaMalloc(&iovs_d, sizeof(IovDevice) * iovs_h.size()), {}, "cudaMalloc fail");
109-
CHECK_CUDA_ERROR_RETURN(cudaMalloc(&crcs_d, sizeof(uint32_t) * iovs_h.size()), {}, "cudaMalloc fail");
110-
auto crcs = GetIovsCrc(iovs_h, iovs_d, crcs_d, nullptr);
111-
CHECK_CUDA_ERROR_RETURN(cudaFree(iovs_d), {}, "cudaMalloc fail");
112-
CHECK_CUDA_ERROR_RETURN(cudaFree(crcs_d), {}, "cudaMalloc fail");
113-
return crcs;
114-
}
115-
116-
std::vector<uint32_t> SdkBufferCheckUtil::GetIovsCrc(const std::vector<IovDevice> &iovs_h,
117-
IovDevice *iovs_d,
118-
uint32_t *crcs_d,
119-
cudaStream_t stream) {
120-
return GetIovsCrc(iovs_h.data(), iovs_h.size(), iovs_d, crcs_d, stream);
121-
}
122-
12350
std::vector<uint32_t> SdkBufferCheckUtil::GetIovsCrc(
12451
const IovDevice *iovs_h_ptr, size_t iovs_size, IovDevice *iovs_d, uint32_t *crcs_d, cudaStream_t stream) {
12552
size_t cal_byte_size = std::min(min_cal_byte_size_, iovs_h_ptr->size / 2);

kv_cache_manager/client/src/internal/sdk/sdk_buffer_check_util.h

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
#pragma once
22

3+
#include <condition_variable>
4+
#include <mutex>
5+
#include <queue>
36
#include <vector>
47

58
#include "kv_cache_manager/client/include/common.h"
@@ -37,4 +40,49 @@ class SdkBufferCheckUtil {
3740
static size_t min_cal_byte_size_;
3841
};
3942

43+
class SdkBufferCheckPool {
44+
static constexpr size_t kDefaultCellNum = 4;
45+
46+
public:
47+
explicit SdkBufferCheckPool(size_t cell_num = kDefaultCellNum);
48+
~SdkBufferCheckPool();
49+
50+
struct Cell {
51+
IovDevice *h_iovs = nullptr;
52+
IovDevice *d_iovs = nullptr;
53+
uint32_t *d_crcs = nullptr;
54+
cudaStream_t cuda_stream = nullptr;
55+
};
56+
57+
class CellHandle {
58+
public:
59+
CellHandle(SdkBufferCheckPool *pool, Cell *cell) : pool_(pool), cell_(cell) {}
60+
CellHandle(const CellHandle &) = delete;
61+
CellHandle(CellHandle &&other) : pool_(std::move(other.pool_)), cell_(std::move(other.cell_)) {
62+
other.pool_ = nullptr;
63+
other.cell_ = nullptr;
64+
}
65+
~CellHandle();
66+
Cell *operator->() { return cell_; }
67+
Cell &operator*() { return *cell_; }
68+
explicit operator bool() const { return cell_ != nullptr; }
69+
70+
private:
71+
SdkBufferCheckPool *pool_;
72+
Cell *cell_;
73+
};
74+
75+
bool Init(size_t max_check_iov_num);
76+
CellHandle GetCell();
77+
78+
private:
79+
friend class CellHandle;
80+
void PutCell(Cell *cell);
81+
82+
std::mutex mutex_;
83+
std::condition_variable cv_;
84+
std::queue<Cell *> cell_queue_;
85+
std::vector<Cell> cells_;
86+
};
87+
4088
}; // namespace kv_cache_manager

0 commit comments

Comments
 (0)