Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ std::vector<int64_t> SdkBufferCheckUtil::GetBlocksHash(const BlockBuffers &block
}

std::vector<int64_t> SdkBufferCheckUtil::GetBlocksHash(
const BlockBuffers &block_buffers, IovDevice *iovs_d, uint32_t *crcs_d, size_t max_iov_num, cudaStream_t stream) {
const BlockBuffers &block_buffers, IovDevice *iovs_d, uint32_t *crcs_d, size_t max_iov_num, GpuStream_t stream) {
std::vector<IovDevice> iov_h(max_iov_num);
return GetBlocksHash(block_buffers, iovs_d, crcs_d, iov_h.data(), max_iov_num, stream);
}
Expand All @@ -37,7 +37,7 @@ std::vector<int64_t> SdkBufferCheckUtil::GetBlocksHash(const BlockBuffers &block
uint32_t *crcs_d,
IovDevice *iovs_h_to_save,
size_t max_iov_num,
cudaStream_t stream) {
GpuStream_t stream) {
size_t iov_num = block_buffers.front().iovs.size();
size_t iovs_size = 0;
for (const auto &block_buffer : block_buffers) {
Expand Down Expand Up @@ -74,7 +74,7 @@ std::vector<uint32_t> SdkBufferCheckUtil::GetIovsCrc(const std::vector<IovDevice
std::vector<uint32_t> SdkBufferCheckUtil::GetIovsCrc(const std::vector<IovDevice> &iovs_h,
IovDevice *iovs_d,
uint32_t *crcs_d,
cudaStream_t stream) {
GpuStream_t stream) {
return GetIovsCrc(iovs_h.data(), iovs_h.size(), iovs_d, crcs_d, stream);
}

Expand Down Expand Up @@ -105,7 +105,7 @@ bool SdkBufferCheckPool::Init(size_t max_check_iov_num) {
CHECK_CUDA_ERROR_RETURN(
cudaMalloc(&cell.d_crcs, crcs_byte_size), false, "cudaMalloc [%zu] byte failed", crcs_byte_size);
CHECK_CUDA_ERROR_RETURN(
cudaStreamCreateWithFlags(&cell.cuda_stream, cudaStreamNonBlocking), false, "cuda stream create failed");
cudaStreamCreateWithFlags(&cell.gpu_stream, cudaStreamNonBlocking), false, "cuda stream create failed");
cell_queue_.push(&cell);
}
KVCM_LOG_INFO(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ constexpr uint32_t kDefaultThreadsPerBlock = 512;
} // namespace

std::vector<uint32_t> SdkBufferCheckUtil::GetIovsCrc(
const IovDevice *iovs_h_ptr, size_t iovs_size, IovDevice *iovs_d, uint32_t *crcs_d, cudaStream_t stream) {
const IovDevice *iovs_h_ptr, size_t iovs_size, IovDevice *iovs_d, uint32_t *crcs_d, GpuStream_t stream) {
size_t cal_byte_size = std::min(min_cal_byte_size_, iovs_h_ptr->size / 2);
if (cal_byte_size == 0) {
return {};
Expand Down
19 changes: 14 additions & 5 deletions kv_cache_manager/client/src/internal/sdk/sdk_buffer_check_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,19 @@
#include <vector>

#include "kv_cache_manager/client/include/common.h"

#if defined(USING_CUDA)
#include "kv_cache_manager/client/src/internal/sdk/cuda_util.h"
#endif

namespace kv_cache_manager {

#if defined(USING_CUDA)
using GpuStream_t = cudaStream_t;
#else
using GpuStream_t = void *;
#endif

struct IovDevice {
const void *base;
size_t size;
Expand All @@ -22,19 +31,19 @@ class SdkBufferCheckUtil {
IovDevice *iovs_d,
uint32_t *crcs_d,
size_t max_iov_num,
cudaStream_t stream);
GpuStream_t stream);
static std::vector<int64_t> GetBlocksHash(const BlockBuffers &block_buffers,
IovDevice *iovs_d,
uint32_t *crcs_d,
IovDevice *iovs_h_to_save,
size_t max_iov_num,
cudaStream_t stream);
GpuStream_t stream);

static std::vector<uint32_t> GetIovsCrc(const std::vector<IovDevice> &iovs_h);
static std::vector<uint32_t>
GetIovsCrc(const std::vector<IovDevice> &iovs_h, IovDevice *iovs_d, uint32_t *crcs_d, cudaStream_t stream);
GetIovsCrc(const std::vector<IovDevice> &iovs_h, IovDevice *iovs_d, uint32_t *crcs_d, GpuStream_t stream);
static std::vector<uint32_t>
GetIovsCrc(const IovDevice *iovs_h_ptr, size_t iovs_size, IovDevice *iovs_d, uint32_t *crcs_d, cudaStream_t stream);
GetIovsCrc(const IovDevice *iovs_h_ptr, size_t iovs_size, IovDevice *iovs_d, uint32_t *crcs_d, GpuStream_t stream);

private:
static size_t min_cal_byte_size_;
Expand All @@ -51,7 +60,7 @@ class SdkBufferCheckPool {
IovDevice *h_iovs = nullptr;
IovDevice *d_iovs = nullptr;
uint32_t *d_crcs = nullptr;
cudaStream_t cuda_stream = nullptr;
GpuStream_t gpu_stream = nullptr;
};

class CellHandle {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ TEST_F(SdkBufferCheckUtilTest, TestSdkBufferCheckPool) {
push_buffer();
auto handle = pool.GetCell();
auto block_hashs = SdkBufferCheckUtil::GetBlocksHash(
block_buffers, handle->d_iovs, handle->d_crcs, handle->h_iovs, max_check_iov_num, handle->cuda_stream);
block_buffers, handle->d_iovs, handle->d_crcs, handle->h_iovs, max_check_iov_num, handle->gpu_stream);
ASSERT_EQ(2, block_hashs.size());
ASSERT_EQ(block_hashs[0], block_hashs[1]);
#endif
Expand Down Expand Up @@ -352,7 +352,7 @@ TEST_F(SdkBufferCheckUtilTest, TestSdkBufferCheckPoolMultiThread) {
handle->d_crcs,
handle->h_iovs,
max_check_iov_num,
handle->cuda_stream);
handle->gpu_stream);
ASSERT_EQ(1, block_hashs.size());
expect = block_hashs[0];
}
Expand All @@ -364,7 +364,7 @@ TEST_F(SdkBufferCheckUtilTest, TestSdkBufferCheckPoolMultiThread) {
handle->d_crcs,
handle->h_iovs,
max_check_iov_num,
handle->cuda_stream);
handle->gpu_stream);
ASSERT_EQ(std::vector<int64_t>({expect}), block_hashs);
};
for (int i = 0; i < 20; ++i) {
Expand Down
4 changes: 2 additions & 2 deletions kv_cache_manager/client/src/transfer_client_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ ClientErrorCode TransferClientImpl::LoadKvCaches(const UriStrVec &uri_str_vec,
if (need_print) {
auto handle = sdk_buffer_check_pool_->GetCell();
block_hashs = SdkBufferCheckUtil::GetBlocksHash(
block_buffers, handle->d_iovs, handle->d_crcs, handle->h_iovs, max_check_iov_num_, handle->cuda_stream);
block_buffers, handle->d_iovs, handle->d_crcs, handle->h_iovs, max_check_iov_num_, handle->gpu_stream);
}
PrintBlockHashAndUri("get_", uri_str_vec, block_hashs, trace_info);
}
Expand All @@ -160,7 +160,7 @@ std::pair<ClientErrorCode, UriStrVec> TransferClientImpl::SaveKvCaches(const Uri
if (need_print) {
auto handle = sdk_buffer_check_pool_->GetCell();
block_hashs = SdkBufferCheckUtil::GetBlocksHash(
block_buffers, handle->d_iovs, handle->d_crcs, handle->h_iovs, max_check_iov_num_, handle->cuda_stream);
block_buffers, handle->d_iovs, handle->d_crcs, handle->h_iovs, max_check_iov_num_, handle->gpu_stream);
}
PrintBlockHashAndUri("put_", uri_str_vec, block_hashs, trace_info);
}
Expand Down
Loading