Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions kv_connectors/llmd_fs_backend/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ make image-fs-backend-push IMAGE_TAG_BASE=<your-base-container-registry> FS_BACK
- `STORAGE_CONNECTOR_DEBUG`: legacy flag — setting to `1` enables debug-level logging (equivalent to `STORAGE_LOG_LEVEL=debug`)
- `USE_KERNEL_COPY_WRITE` : enable GPU-kernel-based writes using GPU SMs (default 0 - uses DMA copy).
- `USE_KERNEL_COPY_READ`: enable GPU-kernel-based reads using GPU SMs (default 0 - uses DMA copy).
- `USE_BATCH_MEMCPY_WRITE`: submit all per-(block, layer) copies in one `cudaMemcpyBatchAsync` call on writes (default 1, requires CUDA 12.8+; set to 0 to fall back to the per-call DMA loop).
- `USE_BATCH_MEMCPY_READ`: same as above for reads (default 1).

## Example vLLM YAML

Expand Down
119 changes: 117 additions & 2 deletions kv_connectors/llmd_fs_backend/csrc/storage/tensor_copier.cu
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,24 @@ TensorCopier::TensorCopier(std::vector<torch::Tensor>& tensors,
// Env flags
m_use_kernel_copy_read = get_env_flag("USE_KERNEL_COPY_READ", false);
m_use_kernel_copy_write = get_env_flag("USE_KERNEL_COPY_WRITE", false);
// Batched DMA is the default fast path on CUDA 12.8+; the per-call
// cudaMemcpyAsync loop remains as a fallback when these flags are
// explicitly set to 0 (older toolkits, debugging, A/B comparison).
// cudaMemcpyBatchAsync was introduced in CUDA 12.8 — default off below that.
#if CUDA_VERSION >= 12080
constexpr bool kBatchDefault = true;
#else
constexpr bool kBatchDefault = false;
#endif
m_use_batch_memcpy_read =
get_env_flag("USE_BATCH_MEMCPY_READ", kBatchDefault);
m_use_batch_memcpy_write =
get_env_flag("USE_BATCH_MEMCPY_WRITE", kBatchDefault);
FS_LOG_INFO("TensorCopier: use_kernel_copy_read="
<< m_use_kernel_copy_read
<< ", use_kernel_copy_write=" << m_use_kernel_copy_write
<< ", use_batch_memcpy_read=" << m_use_batch_memcpy_read
<< ", use_batch_memcpy_write=" << m_use_batch_memcpy_write
<< ", m_gpu_blocks_per_file=" << m_gpu_blocks_per_file);
}

Expand Down Expand Up @@ -96,12 +111,112 @@ void TensorCopier::copy_blocks_via_cuda_memcpy(
}
}

// Main transfer function - dispatches to kernel or memcpy path
// Batched DMA path: one cudaMemcpyBatchAsync covers all per-(block, layer)
// copies for the blocks in this file (num_blocks * num_tensors).
// The batch executes in stream order; ordering within the batch is unspecified.
void TensorCopier::copy_blocks_via_batch_memcpy(

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure it's relevant, but HIP (AMD) also supports batch memcpy out of the box, and might be worth adding as well.

uint8_t* cpu_base,
const std::vector<int64_t>& block_ids_list,
bool is_store) {
const size_t num_tensors = m_gpu_tensors.size();
const size_t num_blocks = block_ids_list.size();
const size_t total = num_tensors * num_blocks;
if (total == 0) return;

// Thread-local scratch arrays — avoid alloc on every transfer. Capacities
// grow monotonically with the largest call we've seen.
thread_local std::vector<void*> dsts;
thread_local std::vector<void*> srcs;
thread_local std::vector<size_t> sizes;
dsts.resize(total);
srcs.resize(total);
sizes.resize(total);

// Compute CPU block offset, Each block in CPU memory stores all layers
// sequentially: [layer0_data, layer1_data, ..., layerN_data]
uint8_t* cpu_blk_ptr = cpu_base + (m_gpu_blocks_per_file - num_blocks) *
num_tensors * m_tensor_block_size;

// Build one (dst, src, size) descriptor per (block, layer) copy.
size_t idx = 0;
for (size_t bi = 0; bi < num_blocks; ++bi) {
int64_t gpu_block_idx = block_ids_list[bi];
for (const auto& tensor : m_gpu_tensors) {
uint8_t* gpu_blk_ptr = reinterpret_cast<uint8_t*>(tensor.data_ptr()) +
gpu_block_idx * m_tensor_block_size;
if (is_store) {
srcs[idx] = gpu_blk_ptr;
dsts[idx] = cpu_blk_ptr;
} else {
srcs[idx] = cpu_blk_ptr;
dsts[idx] = gpu_blk_ptr;
}
sizes[idx] = m_tensor_block_size;
cpu_blk_ptr += m_tensor_block_size;
++idx;
}
}

#if CUDA_VERSION >= 12080
// Set attributes with srcAccessOrder=ANY (cudaMemcpySrcAccessOrderAny)
// for malloc'd host staging buffer. Same as vLLM's cuda_mem_ops.py.
// static (not thread_local): never mutated, no per-thread duplication needed.
// Not const: CUDA's C API takes non-const pointers.
static cudaMemcpyAttributes attrs = [] {
cudaMemcpyAttributes a{};
a.srcAccessOrder = cudaMemcpySrcAccessOrderAny;
return a;
}();
static size_t attrs_idx = 0;

// Get current CUDA stream
const auto stream = at::cuda::getCurrentCUDAStream();

// CUDA 13 dropped the failIdx out-param; CUDA 12.8/12.9 still requires it.
#if CUDA_VERSION >= 13000
cudaError_t err = cudaMemcpyBatchAsync(dsts.data(),
srcs.data(),
sizes.data(),
total,
&attrs,
&attrs_idx,
/*numAttrs=*/1,
stream.stream());
#else
static thread_local size_t fail_idx;
cudaError_t err = cudaMemcpyBatchAsync(dsts.data(),
srcs.data(),
sizes.data(),
total,
&attrs,
&attrs_idx,
/*numAttrs=*/1,
&fail_idx,
stream.stream());
#endif
TORCH_CHECK(err == cudaSuccess,
Comment thread
kfirtoledo marked this conversation as resolved.
"cudaMemcpyBatchAsync failed err=",
cudaGetErrorString(err));
#else
// CUDA < 12.8: cudaMemcpyBatchAsync is not available — fall back.
copy_blocks_via_cuda_memcpy(cpu_base, block_ids_list, is_store);
#endif
}

// Dispatches to one of three paths (priority: batch > kernel > memcpy):
// - batch memcpy: one cudaMemcpyBatchAsync (CUDA 12.8+) for all
// per-(block, layer) copies in this file.
// - kernel copy: custom CUDA kernel doing the copies.
// - memcpy loop: one cudaMemcpyAsync per (block, layer) (fallback).
void TensorCopier::copy_blocks(uint8_t* cpu_base,
const std::vector<int64_t>& block_ids_list,
bool is_store) {
bool use_batch =
is_store ? m_use_batch_memcpy_write : m_use_batch_memcpy_read;
bool use_kernel = is_store ? m_use_kernel_copy_write : m_use_kernel_copy_read;
if (use_kernel) {
if (use_batch) {
copy_blocks_via_batch_memcpy(cpu_base, block_ids_list, is_store);
} else if (use_kernel) {
copy_blocks_via_kernels(cpu_base, block_ids_list, is_store);
} else {
copy_blocks_via_cuda_memcpy(cpu_base, block_ids_list, is_store);
Expand Down
10 changes: 10 additions & 0 deletions kv_connectors/llmd_fs_backend/csrc/storage/tensor_copier.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ class TensorCopier {
bool m_use_kernel_copy_write;
// Use kernel-based copy for get operations
bool m_use_kernel_copy_read;
// Use cudaMemcpyBatchAsync (CUDA 12.8+) for put operations
bool m_use_batch_memcpy_write;
// Use cudaMemcpyBatchAsync (CUDA 12.8+) for get operations
bool m_use_batch_memcpy_read;

// Performs block transfers using cudaMemcpyAsync (DMA-based copy)
void copy_blocks_via_cuda_memcpy(uint8_t* cpu_base,
Expand All @@ -57,4 +61,10 @@ class TensorCopier {
void copy_blocks_via_kernels(uint8_t* cpu_base,
const std::vector<int64_t>& block_ids_list,
bool is_store);

// Single cudaMemcpyBatchAsync call (CUDA 12.8+) submitting all
// (block, layer) copies — removes per-call dispatch overhead.
void copy_blocks_via_batch_memcpy(uint8_t* cpu_base,
const std::vector<int64_t>& block_ids_list,
bool is_store);
};
Loading