diff --git a/kv_cache_manager/client/src/internal/sdk/sdk_buffer_check_util.cc b/kv_cache_manager/client/src/internal/sdk/sdk_buffer_check_util.cc index 7eb0124..757611c 100644 --- a/kv_cache_manager/client/src/internal/sdk/sdk_buffer_check_util.cc +++ b/kv_cache_manager/client/src/internal/sdk/sdk_buffer_check_util.cc @@ -27,7 +27,7 @@ std::vector SdkBufferCheckUtil::GetBlocksHash(const BlockBuffers &block } std::vector SdkBufferCheckUtil::GetBlocksHash( - const BlockBuffers &block_buffers, IovDevice *iovs_d, uint32_t *crcs_d, size_t max_iov_num, cudaStream_t stream) { + const BlockBuffers &block_buffers, IovDevice *iovs_d, uint32_t *crcs_d, size_t max_iov_num, GpuStream_t stream) { std::vector iov_h(max_iov_num); return GetBlocksHash(block_buffers, iovs_d, crcs_d, iov_h.data(), max_iov_num, stream); } @@ -37,7 +37,7 @@ std::vector SdkBufferCheckUtil::GetBlocksHash(const BlockBuffers &block uint32_t *crcs_d, IovDevice *iovs_h_to_save, size_t max_iov_num, - cudaStream_t stream) { + GpuStream_t stream) { size_t iov_num = block_buffers.front().iovs.size(); size_t iovs_size = 0; for (const auto &block_buffer : block_buffers) { @@ -74,7 +74,7 @@ std::vector SdkBufferCheckUtil::GetIovsCrc(const std::vector SdkBufferCheckUtil::GetIovsCrc(const std::vector &iovs_h, IovDevice *iovs_d, uint32_t *crcs_d, - cudaStream_t stream) { + GpuStream_t stream) { return GetIovsCrc(iovs_h.data(), iovs_h.size(), iovs_d, crcs_d, stream); } @@ -105,7 +105,7 @@ bool SdkBufferCheckPool::Init(size_t max_check_iov_num) { CHECK_CUDA_ERROR_RETURN( cudaMalloc(&cell.d_crcs, crcs_byte_size), false, "cudaMalloc [%zu] byte failed", crcs_byte_size); CHECK_CUDA_ERROR_RETURN( - cudaStreamCreateWithFlags(&cell.cuda_stream, cudaStreamNonBlocking), false, "cuda stream create failed"); + cudaStreamCreateWithFlags(&cell.gpu_stream, cudaStreamNonBlocking), false, "cuda stream create failed"); cell_queue_.push(&cell); } KVCM_LOG_INFO( diff --git a/kv_cache_manager/client/src/internal/sdk/sdk_buffer_check_util.cu b/kv_cache_manager/client/src/internal/sdk/sdk_buffer_check_util.cu index c9ff772..8d39475 100644 --- a/kv_cache_manager/client/src/internal/sdk/sdk_buffer_check_util.cu +++ b/kv_cache_manager/client/src/internal/sdk/sdk_buffer_check_util.cu @@ -48,7 +48,7 @@ constexpr uint32_t kDefaultThreadsPerBlock = 512; } // namespace std::vector SdkBufferCheckUtil::GetIovsCrc( - const IovDevice *iovs_h_ptr, size_t iovs_size, IovDevice *iovs_d, uint32_t *crcs_d, cudaStream_t stream) { + const IovDevice *iovs_h_ptr, size_t iovs_size, IovDevice *iovs_d, uint32_t *crcs_d, GpuStream_t stream) { size_t cal_byte_size = std::min(min_cal_byte_size_, iovs_h_ptr->size / 2); if (cal_byte_size == 0) { return {}; diff --git a/kv_cache_manager/client/src/internal/sdk/sdk_buffer_check_util.h b/kv_cache_manager/client/src/internal/sdk/sdk_buffer_check_util.h index 91222b9..472983c 100644 --- a/kv_cache_manager/client/src/internal/sdk/sdk_buffer_check_util.h +++ b/kv_cache_manager/client/src/internal/sdk/sdk_buffer_check_util.h @@ -6,10 +6,19 @@ #include #include "kv_cache_manager/client/include/common.h" + +#if defined(USING_CUDA) #include "kv_cache_manager/client/src/internal/sdk/cuda_util.h" +#endif namespace kv_cache_manager { +#if defined(USING_CUDA) +using GpuStream_t = cudaStream_t; +#else +using GpuStream_t = void *; +#endif + struct IovDevice { const void *base; size_t size; @@ -22,19 +31,19 @@ class SdkBufferCheckUtil { IovDevice *iovs_d, uint32_t *crcs_d, size_t max_iov_num, - cudaStream_t stream); + GpuStream_t stream); static std::vector GetBlocksHash(const BlockBuffers &block_buffers, IovDevice *iovs_d, uint32_t *crcs_d, IovDevice *iovs_h_to_save, size_t max_iov_num, - cudaStream_t stream); + GpuStream_t stream); static std::vector GetIovsCrc(const std::vector &iovs_h); static std::vector - GetIovsCrc(const std::vector &iovs_h, IovDevice *iovs_d, uint32_t *crcs_d, cudaStream_t stream); + GetIovsCrc(const std::vector &iovs_h, IovDevice *iovs_d, uint32_t *crcs_d, GpuStream_t stream); static std::vector - GetIovsCrc(const IovDevice *iovs_h_ptr, size_t iovs_size, IovDevice *iovs_d, uint32_t *crcs_d, cudaStream_t stream); + GetIovsCrc(const IovDevice *iovs_h_ptr, size_t iovs_size, IovDevice *iovs_d, uint32_t *crcs_d, GpuStream_t stream); private: static size_t min_cal_byte_size_; @@ -51,7 +60,7 @@ class SdkBufferCheckPool { IovDevice *h_iovs = nullptr; IovDevice *d_iovs = nullptr; uint32_t *d_crcs = nullptr; - cudaStream_t cuda_stream = nullptr; + GpuStream_t gpu_stream = nullptr; }; class CellHandle { diff --git a/kv_cache_manager/client/src/internal/sdk/test/sdk_buffer_check_util_test.cc b/kv_cache_manager/client/src/internal/sdk/test/sdk_buffer_check_util_test.cc index 20da1c6..056795e 100644 --- a/kv_cache_manager/client/src/internal/sdk/test/sdk_buffer_check_util_test.cc +++ b/kv_cache_manager/client/src/internal/sdk/test/sdk_buffer_check_util_test.cc @@ -303,7 +303,7 @@ TEST_F(SdkBufferCheckUtilTest, TestSdkBufferCheckPool) { push_buffer(); auto handle = pool.GetCell(); auto block_hashs = SdkBufferCheckUtil::GetBlocksHash( - block_buffers, handle->d_iovs, handle->d_crcs, handle->h_iovs, max_check_iov_num, handle->cuda_stream); + block_buffers, handle->d_iovs, handle->d_crcs, handle->h_iovs, max_check_iov_num, handle->gpu_stream); ASSERT_EQ(2, block_hashs.size()); ASSERT_EQ(block_hashs[0], block_hashs[1]); #endif @@ -352,7 +352,7 @@ TEST_F(SdkBufferCheckUtilTest, TestSdkBufferCheckPoolMultiThread) { handle->d_crcs, handle->h_iovs, max_check_iov_num, - handle->cuda_stream); + handle->gpu_stream); ASSERT_EQ(1, block_hashs.size()); expect = block_hashs[0]; } @@ -364,7 +364,7 @@ TEST_F(SdkBufferCheckUtilTest, TestSdkBufferCheckPoolMultiThread) { handle->d_crcs, handle->h_iovs, max_check_iov_num, - handle->cuda_stream); + handle->gpu_stream); ASSERT_EQ(std::vector({expect}), block_hashs); }; for (int i = 0; i < 20; ++i) { diff --git a/kv_cache_manager/client/src/transfer_client_impl.cc b/kv_cache_manager/client/src/transfer_client_impl.cc index 18c9ac4..bed5028 100644 --- a/kv_cache_manager/client/src/transfer_client_impl.cc +++ b/kv_cache_manager/client/src/transfer_client_impl.cc @@ -138,7 +138,7 @@ ClientErrorCode TransferClientImpl::LoadKvCaches(const UriStrVec &uri_str_vec, if (need_print) { auto handle = sdk_buffer_check_pool_->GetCell(); block_hashs = SdkBufferCheckUtil::GetBlocksHash( - block_buffers, handle->d_iovs, handle->d_crcs, handle->h_iovs, max_check_iov_num_, handle->cuda_stream); + block_buffers, handle->d_iovs, handle->d_crcs, handle->h_iovs, max_check_iov_num_, handle->gpu_stream); } PrintBlockHashAndUri("get_", uri_str_vec, block_hashs, trace_info); } @@ -160,7 +160,7 @@ std::pair TransferClientImpl::SaveKvCaches(const Uri if (need_print) { auto handle = sdk_buffer_check_pool_->GetCell(); block_hashs = SdkBufferCheckUtil::GetBlocksHash( - block_buffers, handle->d_iovs, handle->d_crcs, handle->h_iovs, max_check_iov_num_, handle->cuda_stream); + block_buffers, handle->d_iovs, handle->d_crcs, handle->h_iovs, max_check_iov_num_, handle->gpu_stream); } PrintBlockHashAndUri("put_", uri_str_vec, block_hashs, trace_info); }