145 changes: 145 additions & 0 deletions cpp/include/tensorrt_llm/batch_manager/cudaVmmArena.h
@@ -0,0 +1,145 @@
/*
* Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef TRTLLM_CUDAVMMARENA_H
#define TRTLLM_CUDAVMMARENA_H
Comment on lines +1 to +18
⚠️ Potential issue | 🟡 Minor

Align the new header prologue with the repo conventions.

The new file does not use the standard TensorRT-LLM Apache/SPDX prologue, and the guard should follow TRTLLM_<FILENAME_IN_CAPS>_H—for this file, TRTLLM_CUDA_VMM_ARENA_H.

As per coding guidelines, "All TensorRT-LLM source files should contain an NVIDIA copyright header with the year of the latest meaningful modification. The header should be an Apache 2.0 license block as specified" and "Use a preprocessor guard in C++ header files with the format TRTLLM_<FILENAME_IN_CAPS>_H."

Also applies to: 145-145

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/include/tensorrt_llm/batch_manager/cudaVmmArena.h` around lines 1 - 18,
The header prologue and include guard in cudaVmmArena.h do not follow repo
conventions; replace the current file header with the standard TensorRT-LLM
Apache/SPDX prologue including the correct NVIDIA copyright line with the latest
modification year, and change the include guard macro from TRTLLM_CUDAVMMARENA_H
to the required format TRTLLM_CUDA_VMM_ARENA_H so the guard matches
TRTLLM_<FILENAME_IN_CAPS>_H and aligns with project guidelines.


#include <cuda.h>
#include <cstddef>
#include <vector>
#include <stdexcept>
#include <string>

namespace tensorrt_llm::batch_manager::vmm {

/// Exception thrown for CUDA driver API errors.
class CudaVmmError : public std::runtime_error {
public:
explicit CudaVmmError(const std::string& msg, CUresult result = CUDA_SUCCESS)
: std::runtime_error(msg), result_(result) {}

CUresult result() const noexcept { return result_; }

private:
CUresult result_;
};

/// Manages a contiguous virtual address range backed by physical CUDA memory pages
/// that can be added and removed at runtime using the CUDA Virtual Memory Management API.
///
/// The arena reserves a fixed VA window of `max_size` bytes upfront, then commits
/// (maps) physical pages into it on demand in multiples of the device's allocation
/// granularity. All committed memory is accessible on the owning device with
/// read/write permissions.
///
/// Typical usage:
/// CudaVmmArena arena(1ULL << 30, 0); // Reserve 1 GiB VA on device 0
/// arena.grow(64 << 20); // Commit first 64 MiB
/// void* p = reinterpret_cast<void*>(arena.ptr());
/// ...
/// arena.shrink(32 << 20); // Release upper 32 MiB back to OS
///
/// Thread safety: not thread-safe; external synchronization is required.
class CudaVmmArena {
public:
/// Reserve `max_size` bytes of virtual address space on `device`.
/// `max_size` is rounded up to the device's allocation granularity.
/// No physical memory is allocated until grow() is called.
explicit CudaVmmArena(size_t max_size, int device = 0);

~CudaVmmArena();

// Non-copyable, non-movable: owns CUDA virtual/physical resources.
CudaVmmArena(const CudaVmmArena&) = delete;
CudaVmmArena& operator=(const CudaVmmArena&) = delete;
CudaVmmArena(CudaVmmArena&&) = delete;
CudaVmmArena& operator=(CudaVmmArena&&) = delete;

// -----------------------------------------------------------------------
// Resize operations
// -----------------------------------------------------------------------

/// Increase committed size to `new_size` by mapping additional physical pages.
/// `new_size` is rounded up to granularity.
/// Throws if new_size <= committed_size() or new_size > max_size().
void grow(size_t new_size);

/// Decrease committed size to `new_size` by unmapping and releasing tail pages.
/// `new_size` is rounded down to the nearest granularity boundary.
/// Throws if new_size >= committed_size().
void shrink(size_t new_size);

/// Convenience: call grow() or shrink() depending on `new_size`.
/// A no-op if new_size (after alignment) equals committed_size().
void resize(size_t new_size);

// -----------------------------------------------------------------------
// Accessors
// -----------------------------------------------------------------------

/// Base device pointer of the reserved VA range.
/// Only bytes in [ptr(), ptr() + committed_size()) are valid to access.
CUdeviceptr ptr() const noexcept { return base_ptr_; }

/// Number of bytes currently mapped to physical memory.
size_t committed_size() const noexcept { return committed_size_; }

/// Total reserved virtual address range (>= max_size passed to constructor).
size_t max_size() const noexcept { return max_size_; }

/// Physical allocation granularity in bytes for this device.
size_t granularity() const noexcept { return granularity_; }

/// CUDA device index this arena was created for.
int device() const noexcept { return device_; }

private:
// Allocate one granularity-sized physical handle, map it at `offset` into
// the reserved VA range, and grant read/write access.
void map_chunk(size_t offset);

// Revoke access, unmap, and release the physical handle at slot `chunk_idx`.
void unmap_chunk(size_t chunk_idx);

// Throw CudaVmmError if `res` is not CUDA_SUCCESS.
static void check(CUresult res, const char* where);

// Round `n` up to the next multiple of `align` (which must be a power of 2).
static size_t align_up(size_t n, size_t align) noexcept {
return (n + align - 1) & ~(align - 1);
}

// Round `n` down to the previous multiple of `align`.
static size_t align_down(size_t n, size_t align) noexcept {
return n & ~(align - 1);
}

int device_;
size_t granularity_; ///< Minimum physical page granularity, bytes.
size_t max_size_; ///< Reserved VA range size (aligned up).
size_t committed_size_; ///< Currently mapped byte count.
CUdeviceptr base_ptr_; ///< Start of the reserved VA range.

/// One handle per committed granularity chunk, in order.
std::vector<CUmemGenericAllocationHandle> handles_;

CUmemAllocationProp alloc_prop_; ///< Shared allocation properties.
CUmemAccessDesc access_desc_; ///< Shared access descriptor.
};

} // namespace tensorrt_llm::batch_manager::vmm

#endif // TRTLLM_CUDAVMMARENA_H
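The header's private `align_up`/`align_down` helpers assume a power-of-two `align`. A GPU-free restatement of that arithmetic (names copied from the header, values illustrative) shows the rounding that `grow()` and `shrink()` rely on:

```cpp
#include <cstddef>

// Power-of-two rounding, as used by CudaVmmArena for granularity alignment.
// `align` must be a power of two; ~(align - 1) masks off the low bits.
inline std::size_t align_up(std::size_t n, std::size_t align) {
    return (n + align - 1) & ~(align - 1);
}

inline std::size_t align_down(std::size_t n, std::size_t align) {
    return n & ~(align - 1);
}
```

With a 2 MiB granularity, `grow(1)` commits one full granule while `shrink(granularity - 1)` releases everything, because grow rounds up and shrink rounds down.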
202 changes: 202 additions & 0 deletions cpp/tensorrt_llm/batch_manager/cudaVmmArena.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
/*
* Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "tensorrt_llm/batch_manager/cudaVmmArena.h"

#include <cstring>
#include <sstream>

namespace tensorrt_llm::batch_manager::vmm {

// ---------------------------------------------------------------------------
// Internal helpers
// ---------------------------------------------------------------------------

void CudaVmmArena::check(CUresult res, const char* where) {
if (res == CUDA_SUCCESS) return;

const char* name = nullptr;
const char* desc = nullptr;
cuGetErrorName(res, &name);
cuGetErrorString(res, &desc);

std::ostringstream oss;
oss << "CUDA VMM error in " << where << ": "
<< (name ? name : "?") << " (" << res << ")"
<< (desc ? std::string(" — ") + desc : std::string{});
throw CudaVmmError(oss.str(), res);
}

// ---------------------------------------------------------------------------
// Constructor / Destructor
// ---------------------------------------------------------------------------

CudaVmmArena::CudaVmmArena(size_t max_size, int device)
: device_(device)
, granularity_(0)
, max_size_(0)
, committed_size_(0)
, base_ptr_(0)
{
// Build allocation properties: pinned device memory on the selected GPU.
std::memset(&alloc_prop_, 0, sizeof(alloc_prop_));
alloc_prop_.type = CU_MEM_ALLOCATION_TYPE_PINNED;
alloc_prop_.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
alloc_prop_.location.id = device_;

// Query the minimum granularity required by this device/allocation type.
check(cuMemGetAllocationGranularity(
&granularity_, &alloc_prop_,
CU_MEM_ALLOC_GRANULARITY_MINIMUM),
"cuMemGetAllocationGranularity");

if (granularity_ == 0)
throw CudaVmmError("Device reported zero allocation granularity.");

// Round requested max_size up to a granularity boundary.
max_size_ = align_up(max_size, granularity_);
if (max_size_ == 0)
throw CudaVmmError("max_size rounds to zero after granularity alignment.");

// Reserve the virtual address range. No physical memory is allocated yet.
check(cuMemAddressReserve(&base_ptr_, max_size_,
/*alignment=*/0, /*hint=*/0, /*flags=*/0),
"cuMemAddressReserve");

// Reserve capacity for the handle vector; entries are appended as chunks are mapped.
handles_.reserve(max_size_ / granularity_);

// Build the access descriptor once; reused for every chunk.
std::memset(&access_desc_, 0, sizeof(access_desc_));
access_desc_.location = alloc_prop_.location;
access_desc_.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
}

CudaVmmArena::~CudaVmmArena() {
// Unmap and release all committed chunks in reverse order.
for (size_t i = handles_.size(); i-- > 0;) {
unmap_chunk(i);
}
handles_.clear();

// Release the virtual address reservation.
if (base_ptr_) {
cuMemAddressFree(base_ptr_, max_size_);
base_ptr_ = 0;
}
}

// ---------------------------------------------------------------------------
// Private: map / unmap a single granularity-sized chunk
// ---------------------------------------------------------------------------

void CudaVmmArena::map_chunk(size_t offset) {
CUmemGenericAllocationHandle handle{};

// Allocate one granularity-sized physical page.
check(cuMemCreate(&handle, granularity_, &alloc_prop_, /*flags=*/0),
"cuMemCreate");

// Map the physical page into our reserved VA range at `offset`.
CUresult res = cuMemMap(base_ptr_ + offset, granularity_,
/*offset into handle=*/0, handle, /*flags=*/0);
if (res != CUDA_SUCCESS) {
cuMemRelease(handle); // best-effort cleanup
check(res, "cuMemMap");
}

// Grant read/write access on the mapped range.
res = cuMemSetAccess(base_ptr_ + offset, granularity_,
&access_desc_, /*count=*/1);
if (res != CUDA_SUCCESS) {
cuMemUnmap(base_ptr_ + offset, granularity_);
cuMemRelease(handle);
check(res, "cuMemSetAccess");
}

handles_.push_back(handle);
}

void CudaVmmArena::unmap_chunk(size_t chunk_idx) {
const size_t offset = chunk_idx * granularity_;

// Revoke access before unmapping (required by the CUDA VMM spec).
CUmemAccessDesc no_access{};
no_access.location = alloc_prop_.location;
no_access.flags = CU_MEM_ACCESS_FLAGS_PROT_NONE;
cuMemSetAccess(base_ptr_ + offset, granularity_, &no_access, 1);

cuMemUnmap(base_ptr_ + offset, granularity_);
cuMemRelease(handles_[chunk_idx]);
handles_[chunk_idx] = CUmemGenericAllocationHandle{};
}
Comment on lines +133 to +145
⚠️ Potential issue | 🟠 Major

Don't hide CUDA VMM failures on the shrink path.

unmap_chunk() ignores the results of cuMemSetAccess, cuMemUnmap, and cuMemRelease, yet shrink() pops the handle and decrements committed_size_ as if the tail chunk was fully released. A checked path is needed here; otherwise the arena can report success while leaking or leaving pages mapped.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/tensorrt_llm/batch_manager/cudaVmmArena.cpp` around lines 133 - 145,
unmap_chunk currently ignores failures from cuMemSetAccess, cuMemUnmap, and
cuMemRelease, causing shrink() to pop handles_ and decrement committed_size_
even if unmapping/releasing fails; update unmap_chunk (and the shrink path that
calls it) to check return codes from cuMemSetAccess, cuMemUnmap, and
cuMemRelease, and on any failure: stop mutating state (do not clear
handles_[chunk_idx] or decrement committed_size_), propagate or return the error
to the caller so shrink can abort and preserve the tail chunk, and only clear
handles_[chunk_idx] and adjust committed_size_ after all CUDA calls succeed; use
the existing symbols (unmap_chunk, shrink, cuMemSetAccess, cuMemUnmap,
cuMemRelease, handles_, committed_size_, base_ptr_, granularity_) to locate and
implement the checks and error propagation.
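One way to address this finding is to route all three driver calls through the arena's existing `check()` helper and clear the slot only after every call succeeds. The sketch below exercises that ordering with stub functions standing in for `cuMemSetAccess`/`cuMemUnmap`/`cuMemRelease` (stubs fail on a negative handle; this illustrates the pattern, not the PR's actual fix):

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Stubs for the CUDA driver calls; 0 mirrors CUDA_SUCCESS, a negative
// handle makes each stub report failure.
using Status = int;
static Status stub_set_access(long h) { return h < 0 ? 1 : 0; }
static Status stub_unmap(long h)      { return h < 0 ? 1 : 0; }
static Status stub_release(long h)    { return h < 0 ? 1 : 0; }

static void check(Status s, const char* where) {
    if (s != 0) throw std::runtime_error(where);
}

// Checked unmap pattern: throw before mutating bookkeeping, so a failed
// driver call leaves the handle slot (and the caller's committed size) intact.
void unmap_chunk_checked(std::vector<long>& handles, std::size_t idx) {
    check(stub_set_access(handles[idx]), "cuMemSetAccess");
    check(stub_unmap(handles[idx]),      "cuMemUnmap");
    check(stub_release(handles[idx]),    "cuMemRelease");
    handles[idx] = 0;  // only reached when every call succeeded
}
```

`shrink()` would then pop the handle and decrement `committed_size_` only after `unmap_chunk` returns normally.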


// ---------------------------------------------------------------------------
// Public: grow / shrink / resize
// ---------------------------------------------------------------------------

void CudaVmmArena::grow(size_t new_size) {
const size_t aligned = align_up(new_size, granularity_);

if (aligned == 0)
throw CudaVmmError("grow(): new_size rounds to zero.");
if (aligned > max_size_)
throw CudaVmmError("grow(): new_size exceeds the reserved VA range.");
if (aligned <= committed_size_)
throw CudaVmmError("grow(): new_size must be larger than current committed_size.");

// Map chunks covering [committed_size_, aligned).
size_t offset = committed_size_;
while (offset < aligned) {
map_chunk(offset); // may throw; already-mapped chunks stay valid
offset += granularity_;
}

committed_size_ = aligned;
Comment on lines +151 to +168
⚠️ Potential issue | 🟠 Major

grow() leaves the arena state inconsistent after a mid-grow failure.

If one map_chunk() succeeds and a later one throws, handles_ already contains extra mappings but committed_size_ is still the old value. The next grow() or shrink() call will then reason from stale state. Roll back the chunks mapped in this call, or advance committed_size_ in lockstep with each successful mapping.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/tensorrt_llm/batch_manager/cudaVmmArena.cpp` around lines 151 - 168,
CudaVmmArena::grow currently leaves handles_ and committed_size_ out of sync if
map_chunk throws partway; modify grow to maintain invariants by either advancing
committed_size_ in lockstep after each successful map_chunk call (set
committed_size_ = offset + granularity_ or similar inside the loop) or wrap the
map loop in a try/catch that on exception iterates over the newly-added handles_
for this invocation to unmap/close them and restore committed_size_ to its
pre-grow value before rethrowing; refer to CudaVmmArena::grow, map_chunk,
committed_size_, and handles_ when making the fix.

}

void CudaVmmArena::shrink(size_t new_size) {
// Round *down* so we never expose a partially-unmapped granule.
const size_t aligned = align_down(new_size, granularity_);

if (aligned >= committed_size_)
throw CudaVmmError("shrink(): new_size must be smaller than current committed_size.");

// Unmap chunks covering [aligned, committed_size_) in reverse order.
size_t offset = committed_size_;
while (offset > aligned) {
offset -= granularity_;
unmap_chunk(handles_.size() - 1);
handles_.pop_back();
}

committed_size_ = aligned;
}

void CudaVmmArena::resize(size_t new_size) {
// Determine what the aligned target size would be without committing.
const size_t aligned_up = align_up(new_size, granularity_);
const size_t aligned_down = align_down(new_size, granularity_);

if (aligned_up > committed_size_) {
grow(new_size);
} else if (aligned_down < committed_size_) {
shrink(new_size);
}
// else: already at the right size, nothing to do.
}

} // namespace tensorrt_llm::batch_manager::vmm
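`resize()` rounds the target up when deciding to grow but down when deciding to shrink. A small model of that decision (`classify_resize` is an illustrative name, not part of the PR) makes the asymmetry explicit:

```cpp
#include <cstddef>

enum class ResizeAction { Grow, Shrink, NoOp };

// Mirrors CudaVmmArena::resize(): the target rounds up for the grow test and
// down for the shrink test. Only a target exactly equal to the committed size
// (which is always granularity-aligned) classifies as a no-op.
ResizeAction classify_resize(std::size_t new_size, std::size_t committed,
                             std::size_t gran) {
    const std::size_t up   = (new_size + gran - 1) & ~(gran - 1);
    const std::size_t down = new_size & ~(gran - 1);
    if (up > committed)   return ResizeAction::Grow;
    if (down < committed) return ResizeAction::Shrink;
    return ResizeAction::NoOp;
}
```

Note the consequence: with 128 bytes committed and granularity 64, `resize(100)` shrinks to 64 rather than keeping the 128 already committed, because the shrink path rounds the target down.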