diff --git a/cpp/include/tensorrt_llm/batch_manager/cudaVmmArena.h b/cpp/include/tensorrt_llm/batch_manager/cudaVmmArena.h new file mode 100644 index 00000000000..a0d8ae6d43d --- /dev/null +++ b/cpp/include/tensorrt_llm/batch_manager/cudaVmmArena.h @@ -0,0 +1,145 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef TRTLLM_CUDAVMMARENA_H +#define TRTLLM_CUDAVMMARENA_H + +#include +#include +#include +#include +#include + +namespace tensorrt_llm::batch_manager::vmm { + +/// Exception thrown for CUDA driver API errors. +class CudaVmmError : public std::runtime_error { +public: + explicit CudaVmmError(const std::string& msg, CUresult result = CUDA_SUCCESS) + : std::runtime_error(msg), result_(result) {} + + CUresult result() const noexcept { return result_; } + +private: + CUresult result_; +}; + +/// Manages a contiguous virtual address range backed by physical CUDA memory pages +/// that can be added and removed at runtime using the CUDA Virtual Memory Management API. +/// +/// The arena reserves a fixed VA window of `max_size` bytes upfront, then commits +/// (maps) physical pages into it on demand in multiples of the device's allocation +/// granularity. All committed memory is accessible on the owning device with +/// read/write permissions. 
+///
+/// Typical usage:
+///   CudaVmmArena arena(1ULL << 30, 0);   // Reserve 1 GiB VA on device 0
+///   arena.grow(64 << 20);                // Commit first 64 MiB
+///   void* p = reinterpret_cast<void*>(arena.ptr());
+///   ...
+///   arena.shrink(32 << 20);              // Shrink to 32 MiB, releasing the upper 32 MiB
+///
+/// Thread safety: not thread-safe; external synchronization is required.
+class CudaVmmArena {
+public:
+    /// Reserve `max_size` bytes of virtual address space on `device`.
+    /// `max_size` is rounded up to the device's allocation granularity.
+    /// No physical memory is allocated until grow() is called.
+    explicit CudaVmmArena(size_t max_size, int device = 0);
+
+    ~CudaVmmArena();
+
+    // Non-copyable, non-movable: owns CUDA virtual/physical resources.
+    CudaVmmArena(const CudaVmmArena&) = delete;
+    CudaVmmArena& operator=(const CudaVmmArena&) = delete;
+    CudaVmmArena(CudaVmmArena&&) = delete;
+    CudaVmmArena& operator=(CudaVmmArena&&) = delete;
+
+    // -----------------------------------------------------------------------
+    // Resize operations
+    // -----------------------------------------------------------------------
+
+    /// Increase committed size to `new_size` by mapping additional physical pages.
+    /// `new_size` is rounded up to granularity before any check is applied.
+    /// Throws CudaVmmError if the rounded size is <= committed_size() or > max_size().
+    void grow(size_t new_size);
+
+    /// Decrease committed size to `new_size` by unmapping and releasing tail pages.
+    /// `new_size` is rounded down to the nearest granularity boundary.
+    /// Throws CudaVmmError if the rounded size is >= committed_size().
+    void shrink(size_t new_size);
+
+    /// Convenience: call grow() or shrink() depending on `new_size`.
+    /// A no-op if new_size (after alignment) equals committed_size().
+    void resize(size_t new_size);
+
+    // -----------------------------------------------------------------------
+    // Accessors
+    // -----------------------------------------------------------------------
+
+    /// Base device pointer of the reserved VA range.
+    /// Only bytes in [ptr(), ptr() + committed_size()) are valid to access.
+    CUdeviceptr ptr() const noexcept { return base_ptr_; }
+
+    /// Number of bytes currently mapped to physical memory.
+    size_t committed_size() const noexcept { return committed_size_; }
+
+    /// Total reserved virtual address range (>= max_size passed to constructor).
+    size_t max_size() const noexcept { return max_size_; }
+
+    /// Physical allocation granularity in bytes for this device.
+    size_t granularity() const noexcept { return granularity_; }
+
+    /// CUDA device index this arena was created for.
+    int device() const noexcept { return device_; }
+
+private:
+    // Allocate one granularity-sized physical handle, map it at `offset` into
+    // the reserved VA range, and grant read/write access.
+    void map_chunk(size_t offset);
+
+    // Revoke access, unmap, and release the physical handle at slot `chunk_idx`.
+    void unmap_chunk(size_t chunk_idx);
+
+    // Throw CudaVmmError if `res` is not CUDA_SUCCESS.
+    static void check(CUresult res, const char* where);
+
+    // Round `n` up to a multiple of `align` (a power of 2); wraps to 0 if n + align - 1 overflows.
+    static size_t align_up(size_t n, size_t align) noexcept {
+        return (n + align - 1) & ~(align - 1);
+    }
+
+    // Round `n` down to the previous multiple of `align` (a power of 2).
+    static size_t align_down(size_t n, size_t align) noexcept {
+        return n & ~(align - 1);
+    }
+
+    int device_;
+    size_t granularity_; ///< Minimum physical page granularity, bytes.
+    size_t max_size_; ///< Reserved VA range size (aligned up).
+    size_t committed_size_;///< Currently mapped byte count.
+    CUdeviceptr base_ptr_; ///< Start of the reserved VA range.
+
+    /// One handle per committed granularity chunk, in order.
+    std::vector<CUmemGenericAllocationHandle> handles_;
+
+    CUmemAllocationProp alloc_prop_; ///< Shared allocation properties.
+    CUmemAccessDesc access_desc_;///< Shared access descriptor.
+}; + +} // namespace tensorrt_llm::batch_manager::vmm + +#endif // TRTLLM_CUDAVMMARENA_H diff --git a/cpp/tensorrt_llm/batch_manager/cudaVmmArena.cpp b/cpp/tensorrt_llm/batch_manager/cudaVmmArena.cpp new file mode 100644 index 00000000000..bc1b9e78f80 --- /dev/null +++ b/cpp/tensorrt_llm/batch_manager/cudaVmmArena.cpp @@ -0,0 +1,202 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/batch_manager/cudaVmmArena.h" + +#include +#include + +namespace tensorrt_llm::batch_manager::vmm { + +// --------------------------------------------------------------------------- +// Internal helpers +// --------------------------------------------------------------------------- + +void CudaVmmArena::check(CUresult res, const char* where) { + if (res == CUDA_SUCCESS) return; + + const char* name = nullptr; + const char* desc = nullptr; + cuGetErrorName(res, &name); + cuGetErrorString(res, &desc); + + std::ostringstream oss; + oss << "CUDA VMM error in " << where << ": " + << (name ? name : "?") << " (" << res << ")" + << (desc ? 
std::string(" — ") + desc : std::string{}); + throw CudaVmmError(oss.str(), res); +} + +// --------------------------------------------------------------------------- +// Constructor / Destructor +// --------------------------------------------------------------------------- + +CudaVmmArena::CudaVmmArena(size_t max_size, int device) + : device_(device) + , granularity_(0) + , max_size_(0) + , committed_size_(0) + , base_ptr_(0) +{ + // Build allocation properties: pinned device memory on the selected GPU. + std::memset(&alloc_prop_, 0, sizeof(alloc_prop_)); + alloc_prop_.type = CU_MEM_ALLOCATION_TYPE_PINNED; + alloc_prop_.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + alloc_prop_.location.id = device_; + + // Query the minimum granularity required by this device/allocation type. + check(cuMemGetAllocationGranularity( + &granularity_, &alloc_prop_, + CU_MEM_ALLOC_GRANULARITY_MINIMUM), + "cuMemGetAllocationGranularity"); + + if (granularity_ == 0) + throw CudaVmmError("Device reported zero allocation granularity."); + + // Round requested max_size up to a granularity boundary. + max_size_ = align_up(max_size, granularity_); + if (max_size_ == 0) + throw CudaVmmError("max_size rounds to zero after granularity alignment."); + + // Reserve the virtual address range. No physical memory is allocated yet. + check(cuMemAddressReserve(&base_ptr_, max_size_, + /*alignment=*/0, /*hint=*/0, /*flags=*/0), + "cuMemAddressReserve"); + + // Pre-size the handle vector but leave all entries empty. + handles_.reserve(max_size_ / granularity_); + + // Build the access descriptor once; reused for every chunk. + std::memset(&access_desc_, 0, sizeof(access_desc_)); + access_desc_.location = alloc_prop_.location; + access_desc_.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; +} + +CudaVmmArena::~CudaVmmArena() { + // Unmap and release all committed chunks in reverse order. 
+    for (size_t i = handles_.size(); i-- > 0;) {
+        unmap_chunk(i);
+    }
+    handles_.clear();
+
+    // Release the virtual address reservation (skipped if none was made).
+    if (base_ptr_) {
+        cuMemAddressFree(base_ptr_, max_size_);
+        base_ptr_ = 0;
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Private: map / unmap a single granularity-sized chunk
+// ---------------------------------------------------------------------------
+
+void CudaVmmArena::map_chunk(size_t offset) {
+    CUmemGenericAllocationHandle handle{};
+
+    // Allocate one granularity-sized physical chunk; throws before anything needs undoing.
+    check(cuMemCreate(&handle, granularity_, &alloc_prop_, /*flags=*/0),
+          "cuMemCreate");
+
+    // Map the physical chunk into our reserved VA range at `offset`.
+    CUresult res = cuMemMap(base_ptr_ + offset, granularity_,
+                            /*offset into handle=*/0, handle, /*flags=*/0);
+    if (res != CUDA_SUCCESS) {
+        cuMemRelease(handle); // best-effort cleanup; the cuMemMap error is the one reported
+        check(res, "cuMemMap");
+    }
+
+    // Grant read/write access on the mapped range; on failure undo the map and the allocation.
+    res = cuMemSetAccess(base_ptr_ + offset, granularity_,
+                         &access_desc_, /*count=*/1);
+    if (res != CUDA_SUCCESS) {
+        cuMemUnmap(base_ptr_ + offset, granularity_);
+        cuMemRelease(handle);
+        check(res, "cuMemSetAccess");
+    }
+
+    handles_.push_back(handle); // the handle is recorded only once the chunk is fully usable
+}
+
+void CudaVmmArena::unmap_chunk(size_t chunk_idx) {
+    const size_t offset = chunk_idx * granularity_;
+
+    // Revoke access before unmapping (NOTE(review): said to be required by the CUDA VMM spec — confirm).
+ CUmemAccessDesc no_access{}; + no_access.location = alloc_prop_.location; + no_access.flags = CU_MEM_ACCESS_FLAGS_PROT_NONE; + cuMemSetAccess(base_ptr_ + offset, granularity_, &no_access, 1); + + cuMemUnmap(base_ptr_ + offset, granularity_); + cuMemRelease(handles_[chunk_idx]); + handles_[chunk_idx] = CUmemGenericAllocationHandle{}; +} + +// --------------------------------------------------------------------------- +// Public: grow / shrink / resize +// --------------------------------------------------------------------------- + +void CudaVmmArena::grow(size_t new_size) { + const size_t aligned = align_up(new_size, granularity_); + + if (aligned == 0) + throw CudaVmmError("grow(): new_size rounds to zero."); + if (aligned > max_size_) + throw CudaVmmError("grow(): new_size exceeds the reserved VA range."); + if (aligned <= committed_size_) + throw CudaVmmError("grow(): new_size must be larger than current committed_size."); + + // Map chunks covering [committed_size_, aligned). + size_t offset = committed_size_; + while (offset < aligned) { + map_chunk(offset); // may throw; already-mapped chunks stay valid + offset += granularity_; + } + + committed_size_ = aligned; +} + +void CudaVmmArena::shrink(size_t new_size) { + // Round *down* so we never expose a partially-unmapped granule. + const size_t aligned = align_down(new_size, granularity_); + + if (aligned >= committed_size_) + throw CudaVmmError("shrink(): new_size must be smaller than current committed_size."); + + // Unmap chunks covering [aligned, committed_size_) in reverse order. + size_t offset = committed_size_; + while (offset > aligned) { + offset -= granularity_; + unmap_chunk(handles_.size() - 1); + handles_.pop_back(); + } + + committed_size_ = aligned; +} + +void CudaVmmArena::resize(size_t new_size) { + // Determine what the aligned target size would be without committing. 
+ const size_t aligned_up = align_up(new_size, granularity_); + const size_t aligned_down = align_down(new_size, granularity_); + + if (aligned_up > committed_size_) { + grow(new_size); + } else if (aligned_down < committed_size_) { + shrink(new_size); + } + // else: already at the right size, nothing to do. +} + +} // namespace tensorrt_llm::batch_manager::vmm diff --git a/cpp/tensorrt_llm/batch_manager/example.cu b/cpp/tensorrt_llm/batch_manager/example.cu new file mode 100644 index 00000000000..53c0de17d34 --- /dev/null +++ b/cpp/tensorrt_llm/batch_manager/example.cu @@ -0,0 +1,125 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// example.cu — demonstrates CudaVmmArena grow/shrink. +// +// Build: +// nvcc -std=c++17 -lcuda example.cu cuda_vmm_arena.cpp -o vmm_example +// +// Requirements: CUDA 10.2+, GPU with compute capability >= 7.0, +// driver support for cuMemAddressReserve. + +#include "tensorrt_llm/batch_manager/cudaVmmArena.h" + +#include + +using namespace tensorrt_llm::batch_manager::vmm; +#include + +// Simple kernel: write the index into each element. +__global__ void fill_kernel(int* data, int n) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < n) data[i] = i; +} + +// Verify the fill on the host. 
+static bool verify(const int* host, int n) {
+    for (int i = 0; i < n; ++i) {
+        if (host[i] != i) {
+            std::printf("MISMATCH at index %d: got %d, expected %d\n",
+                        i, host[i], i);
+            return false;
+        }
+    }
+    return true;
+}
+
+int main() {
+    // Initialize the CUDA driver API.
+    if (cuInit(0) != CUDA_SUCCESS) {
+        std::printf("cuInit failed\n");
+        return 1;
+    }
+
+    CUdevice dev{};
+    CUcontext ctx{};
+    if (cuDeviceGet(&dev, 0) != CUDA_SUCCESS) { std::printf("cuDeviceGet failed\n"); return 1; }
+    if (cuCtxCreate(&ctx, nullptr, 0, dev) != CUDA_SUCCESS) { std::printf("cuCtxCreate failed\n"); return 1; }
+
+    try {
+        constexpr size_t MB = 1ULL << 20;
+        constexpr size_t MAX = 512 * MB; // reserve 512 MiB VA
+        constexpr size_t STEP = 16 * MB; // commit in 16 MiB increments
+
+        CudaVmmArena arena(MAX, /*device=*/0);
+        std::printf("Granularity : %zu bytes\n", arena.granularity());
+        std::printf("Reserved VA : %zu bytes (%zu MiB)\n",
+                    arena.max_size(), arena.max_size() / MB);
+
+        // ----------------------------------------------------------------
+        // Phase 1: grow to 32 MiB in two steps.
+        // ----------------------------------------------------------------
+        arena.grow(STEP);
+        std::printf("\n[grow] committed = %zu MiB\n",
+                    arena.committed_size() / MB);
+
+        int n1 = static_cast<int>(arena.committed_size() / sizeof(int));
+        int* d_ptr = reinterpret_cast<int*>(arena.ptr());
+        fill_kernel<<<(n1 + 255) / 256, 256>>>(d_ptr, n1);
+        cudaDeviceSynchronize();
+
+        arena.grow(2 * STEP);
+        std::printf("[grow] committed = %zu MiB\n",
+                    arena.committed_size() / MB);
+
+        int n2 = static_cast<int>(arena.committed_size() / sizeof(int));
+        fill_kernel<<<(n2 + 255) / 256, 256>>>(d_ptr, n2);
+        cudaDeviceSynchronize();
+
+        // Copy back and verify.
+        std::vector<int> host(n2);
+        cudaMemcpy(host.data(), d_ptr, n2 * sizeof(int), cudaMemcpyDeviceToHost);
+        std::printf("Verify 32 MiB fill: %s\n", verify(host.data(), n2) ? "OK" : "FAIL");
+
+        // ----------------------------------------------------------------
+        // Phase 2: shrink back to 16 MiB — verify original data still intact.
+ // ---------------------------------------------------------------- + arena.shrink(STEP); + std::printf("\n[shrink] committed = %zu MiB\n", + arena.committed_size() / MB); + + int n3 = static_cast(arena.committed_size() / sizeof(int)); + std::vector host2(n3); + cudaMemcpy(host2.data(), d_ptr, n3 * sizeof(int), cudaMemcpyDeviceToHost); + std::printf("Verify 16 MiB region still intact: %s\n", + verify(host2.data(), n3) ? "OK" : "FAIL"); + + // ---------------------------------------------------------------- + // Phase 3: use resize() to go back up. + // ---------------------------------------------------------------- + arena.resize(4 * STEP); + std::printf("\n[resize] committed = %zu MiB\n", + arena.committed_size() / MB); + + std::printf("\nAll phases completed successfully.\n"); + } catch (const CudaVmmError& e) { + std::printf("CudaVmmError: %s (CUresult=%d)\n", e.what(), e.result()); + return 1; + } + + cuCtxDestroy(ctx); + return 0; +} diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index ff79e4eb3a5..2b25179c435 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -38,6 +38,7 @@ #include #include #include +#include #include namespace tc = tensorrt_llm::common; @@ -1856,9 +1857,17 @@ void WindowBlockManager::pinBlocks(GenerationRequest& sequence) { auto const requestId = sequence.getRequestId(); auto& allocatedBlocks = mAllocatedBlocksPerSeq.at(requestId); + // Shared context blocks appear beamWidth times in the flat vector (once per beam), + // but must be pinned exactly once: unpinBlocksById decrements once per unique block ID. + // Without deduplication, a shared block with beamWidth=B gets pinned B times but + // unpinned once, leaking the block permanently. 
+ std::unordered_set seen; for (auto& block : allocatedBlocks) { - block->incRefCount(); + if (seen.insert(block->getBlockId()).second) + { + block->incRefCount(); + } } } @@ -2581,18 +2590,17 @@ bool KVCacheManager::addSequence( void KVCacheManager::storeContextBlocks(LlmRequest const& llmRequest) { auto const requestId = llmRequest.mRequestId; - bool found = false; - { - // protect the mSequences - std::scoped_lock lock(mSequencesMtx); - found = mSequences.find(requestId) != mSequences.end(); - } - if (found) + // Hold the lock for the entire find-and-use operation to eliminate the TOCTOU race + // between the existence check and the subsequent access: removeSequence() could extract + // the node from mSequences between a two-step check-then-get, leaving a dangling reference. + // mBlockManager.storeContextBlocks() does not acquire mSequencesMtx, so no deadlock. + std::scoped_lock lock(mSequencesMtx); + auto const it = mSequences.find(requestId); + if (it != mSequences.end()) { - auto& sequence = getSequence(requestId); if (mEnableBlockReuse && !llmRequest.isDummyRequest()) { - mBlockManager.storeContextBlocks(sequence, llmRequest); + mBlockManager.storeContextBlocks(it->second, llmRequest); } } else diff --git a/cpp/tests/unit_tests/batch_manager/CMakeLists.txt b/cpp/tests/unit_tests/batch_manager/CMakeLists.txt index e07add91887..df3a90fc832 100644 --- a/cpp/tests/unit_tests/batch_manager/CMakeLists.txt +++ b/cpp/tests/unit_tests/batch_manager/CMakeLists.txt @@ -27,3 +27,4 @@ add_gtest(microBatchSchedulerTest microBatchSchedulerTest.cpp) add_gtest(peftCacheManagerTest peftCacheManagerTest.cpp) add_gtest(staticThreadPoolTest staticThreadPoolTest.cpp) add_gtest(rnnCacheFormatterTest rnnCacheFormatterTest.cpp) +add_gtest(testCudaVmmArena testCudaVmmArena.cu) diff --git a/cpp/tests/unit_tests/batch_manager/testCudaVmmArena.cu b/cpp/tests/unit_tests/batch_manager/testCudaVmmArena.cu new file mode 100644 index 00000000000..85eb3af7a32 --- /dev/null +++ 
b/cpp/tests/unit_tests/batch_manager/testCudaVmmArena.cu @@ -0,0 +1,494 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// test_cuda_vmm_arena.cu — Google Test suite for CudaVmmArena. +// +// Build (adjust paths to your GTest installation): +// nvcc -std=c++17 -lcuda \ +// -I/path/to/googletest/include \ +// -L/path/to/googletest/lib -lgtest -lgtest_main \ +// test_cuda_vmm_arena.cu ../cuda_vmm_arena.cpp \ +// -o test_vmm_arena +// +// Or with CMake, see the accompanying CMakeLists.txt. + +#include "tensorrt_llm/batch_manager/cudaVmmArena.h" + +#include + +using namespace tensorrt_llm::batch_manager::vmm; +#include // cudaMemcpy, cudaDeviceSynchronize +#include +#include +#include + +// --------------------------------------------------------------------------- +// Test fixture: initialises CUDA driver once per process, creates a CUDA +// context, and provides a fresh arena for each test case. +// --------------------------------------------------------------------------- + +/// Per-process CUDA driver + context setup. 
+class CudaEnv : public ::testing::Environment {
+public:
+    void SetUp() override {
+        ASSERT_EQ(cuInit(0), CUDA_SUCCESS) << "cuInit failed";
+        ASSERT_EQ(cuDeviceGet(&dev_, 0), CUDA_SUCCESS);
+        ASSERT_EQ(cuCtxCreate(&ctx_, nullptr, 0, dev_), CUDA_SUCCESS);
+    }
+    void TearDown() override {
+        if (ctx_) cuCtxDestroy(ctx_);
+    }
+    static CUdevice dev_;
+    static CUcontext ctx_;
+};
+CUdevice CudaEnv::dev_{};
+CUcontext CudaEnv::ctx_{};
+
+/// Base fixture used by all test cases.
+class ArenaTest : public ::testing::Test {
+protected:
+    // Default reserved size. Each test that needs a different value
+    // passes it to make_arena() explicitly.
+    static constexpr size_t kMaxSize = 256ULL << 20; // 256 MiB reserved VA
+
+    // Skip the test early if the granularity query for device 0 fails or returns 0.
+    // GTEST_SKIP() requires a void context — SetUp() qualifies.
+    void SetUp() override {
+        CUmemAllocationProp prop{};
+        prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
+        prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+        prop.location.id = 0;
+        size_t g = 0;
+        CUresult res = cuMemGetAllocationGranularity(
+            &g, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM);
+        if (res == CUDA_ERROR_NOT_SUPPORTED || res == CUDA_ERROR_NO_DEVICE)
+            GTEST_SKIP() << "VMM not supported on this system (CUresult=" << res << ")";
+        if (res != CUDA_SUCCESS || g == 0)
+            GTEST_SKIP() << "Could not query VMM granularity (CUresult=" << res << ")";
+    }
+
+    // Helper: heap-allocate a fresh arena; the caller owns the returned pointer. Throws CudaVmmError on failure.
+    static CudaVmmArena* make_arena(size_t max_size = kMaxSize,
+                                    int device = 0)
+    {
+        return new CudaVmmArena(max_size, device);
+    }
+
+    // Granularity helper (we need it before the arena exists in some tests).
+    static size_t query_granularity(int device = 0) {
+        CUmemAllocationProp prop{};
+        prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
+        prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+        prop.location.id = device;
+        size_t g = 0; // stays 0 if the query fails; callers treat 0 as "skip"
+        cuMemGetAllocationGranularity(&g, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM);
+        return g;
+    }
+};
+
+// ===========================================================================
+// Construction
+// ===========================================================================
+
+TEST_F(ArenaTest, ConstructionSetsInitialState) {
+    std::unique_ptr<CudaVmmArena> a(make_arena(kMaxSize));
+
+    EXPECT_GT(a->granularity(), 0u);
+    EXPECT_GE(a->max_size(), kMaxSize); // rounded up >= requested
+    EXPECT_EQ(a->committed_size(), 0u); // nothing committed yet
+    EXPECT_NE(a->ptr(), 0u); // VA range was reserved
+    EXPECT_EQ(a->device(), 0);
+}
+
+TEST_F(ArenaTest, MaxSizeAlignedUpToGranularity) {
+    const size_t g = query_granularity();
+    if (g == 0) GTEST_SKIP() << "Cannot query granularity";
+
+    // Request a size that is deliberately not aligned.
+    const size_t unaligned = g + 1;
+    std::unique_ptr<CudaVmmArena> a(make_arena(unaligned));
+
+    EXPECT_EQ(a->max_size() % g, 0u) << "max_size must be a multiple of granularity";
+    EXPECT_GE(a->max_size(), unaligned);
+}
+
+TEST_F(ArenaTest, SmallestPossibleReservation) {
+    // A reservation of exactly one granule should succeed.
+ const size_t g = query_granularity(); + if (g == 0) GTEST_SKIP(); + + std::unique_ptr a(make_arena(g)); + EXPECT_EQ(a->max_size(), g); + EXPECT_EQ(a->committed_size(), 0u); +} + +// =========================================================================== +// grow() +// =========================================================================== + +TEST_F(ArenaTest, GrowCommitsSingleChunk) { + std::unique_ptr a(make_arena()); + const size_t g = a->granularity(); + + a->grow(g); + + EXPECT_EQ(a->committed_size(), g); +} + +TEST_F(ArenaTest, GrowCommitsMultipleChunks) { + std::unique_ptr a(make_arena()); + const size_t g = a->granularity(); + + a->grow(4 * g); + + EXPECT_EQ(a->committed_size(), 4 * g); +} + +TEST_F(ArenaTest, GrowRoundsUpToGranularity) { + std::unique_ptr a(make_arena()); + const size_t g = a->granularity(); + + // Request 1 byte — should round up to one granule. + a->grow(1); + + EXPECT_EQ(a->committed_size(), g); + EXPECT_EQ(a->committed_size() % g, 0u); +} + +TEST_F(ArenaTest, GrowIncrementallyAccumulatesCommitted) { + std::unique_ptr a(make_arena()); + const size_t g = a->granularity(); + + a->grow(g); + EXPECT_EQ(a->committed_size(), g); + + a->grow(2 * g); + EXPECT_EQ(a->committed_size(), 2 * g); + + a->grow(4 * g); + EXPECT_EQ(a->committed_size(), 4 * g); +} + +TEST_F(ArenaTest, GrowThrowsWhenNotLargerThanCommitted) { + std::unique_ptr a(make_arena()); + const size_t g = a->granularity(); + a->grow(2 * g); + + // Same size — must throw. + EXPECT_THROW(a->grow(2 * g), CudaVmmError); + // Smaller size — must throw. + EXPECT_THROW(a->grow(g), CudaVmmError); +} + +TEST_F(ArenaTest, GrowThrowsWhenExceedsMaxSize) { + const size_t g = query_granularity(); + if (g == 0) GTEST_SKIP(); + + // Reserve exactly one granule. + std::unique_ptr a(make_arena(g)); + + // Trying to grow beyond the reserved VA must throw. 
+    EXPECT_THROW(a->grow(2 * g), CudaVmmError);
+}
+
+// ===========================================================================
+// shrink()
+// ===========================================================================
+
+TEST_F(ArenaTest, ShrinkReleasesChunks) {
+    std::unique_ptr<CudaVmmArena> a(make_arena());
+    const size_t g = a->granularity();
+    a->grow(4 * g);
+
+    a->shrink(2 * g);
+
+    EXPECT_EQ(a->committed_size(), 2 * g);
+}
+
+TEST_F(ArenaTest, ShrinkToZeroReleasesAll) {
+    std::unique_ptr<CudaVmmArena> a(make_arena());
+    const size_t g = a->granularity();
+    a->grow(4 * g);
+
+    a->shrink(0);
+
+    EXPECT_EQ(a->committed_size(), 0u);
+}
+
+TEST_F(ArenaTest, ShrinkRoundsDownToGranularity) {
+    std::unique_ptr<CudaVmmArena> a(make_arena());
+    const size_t g = a->granularity();
+    a->grow(4 * g);
+
+    // Request shrink to "2*g + 1" — rounds DOWN to 2*g.
+    a->shrink(2 * g + 1);
+
+    EXPECT_EQ(a->committed_size(), 2 * g);
+}
+
+TEST_F(ArenaTest, ShrinkThrowsWhenNotSmallerThanCommitted) {
+    std::unique_ptr<CudaVmmArena> a(make_arena());
+    const size_t g = a->granularity();
+    a->grow(2 * g);
+
+    // Same size (after alignment) — must throw.
+    EXPECT_THROW(a->shrink(2 * g), CudaVmmError);
+    // Larger — must throw.
+    EXPECT_THROW(a->shrink(4 * g), CudaVmmError);
+}
+
+TEST_F(ArenaTest, ShrinkThrowsOnUninitializedArena) {
+    std::unique_ptr<CudaVmmArena> a(make_arena());
+    // committed_size == 0; any shrink target is >= committed_size.
+    EXPECT_THROW(a->shrink(0), CudaVmmError);
+}
+
+// ===========================================================================
+// resize()
+// ===========================================================================
+
+TEST_F(ArenaTest, ResizeGrowsWhenLarger) {
+    std::unique_ptr<CudaVmmArena> a(make_arena());
+    const size_t g = a->granularity();
+
+    a->resize(3 * g);
+
+    EXPECT_EQ(a->committed_size(), 3 * g);
+}
+
+TEST_F(ArenaTest, ResizeShrinksWhenSmaller) {
+    std::unique_ptr<CudaVmmArena> a(make_arena());
+    const size_t g = a->granularity();
+    a->grow(4 * g);
+
+    a->resize(2 * g);
+
+    EXPECT_EQ(a->committed_size(), 2 * g);
+}
+
+TEST_F(ArenaTest, ResizeIsNoOpWhenAlreadyAtTarget) {
+    std::unique_ptr<CudaVmmArena> a(make_arena());
+    const size_t g = a->granularity();
+    a->grow(2 * g);
+
+    // resize to the exact same aligned size — must not throw.
+    EXPECT_NO_THROW(a->resize(2 * g));
+    EXPECT_EQ(a->committed_size(), 2 * g);
+}
+
+// ===========================================================================
+// grow → shrink → grow cycle
+// ===========================================================================
+
+TEST_F(ArenaTest, GrowShrinkGrowCycle) {
+    std::unique_ptr<CudaVmmArena> a(make_arena());
+    const size_t g = a->granularity();
+
+    a->grow(4 * g);
+    EXPECT_EQ(a->committed_size(), 4 * g);
+
+    a->shrink(2 * g);
+    EXPECT_EQ(a->committed_size(), 2 * g);
+
+    a->grow(6 * g);
+    EXPECT_EQ(a->committed_size(), 6 * g);
+
+    a->shrink(0);
+    EXPECT_EQ(a->committed_size(), 0u);
+}
+
+// ===========================================================================
+// ptr() stability
+// ===========================================================================
+
+TEST_F(ArenaTest, BasePointerDoesNotChangeAfterGrow) {
+    std::unique_ptr<CudaVmmArena> a(make_arena());
+    const CUdeviceptr base = a->ptr();
+
+    a->grow(a->granularity());
+    EXPECT_EQ(a->ptr(), base) << "grow() must not change the base pointer";
+
+    a->grow(2 * a->granularity());
+    EXPECT_EQ(a->ptr(), base);
+}
+
+TEST_F(ArenaTest, BasePointerDoesNotChangeAfterShrink) 
{
+    std::unique_ptr<CudaVmmArena> a(make_arena());
+    const size_t g = a->granularity();
+    a->grow(4 * g);
+    const CUdeviceptr base = a->ptr();
+
+    a->shrink(2 * g);
+    EXPECT_EQ(a->ptr(), base) << "shrink() must not change the base pointer";
+}
+
+// ===========================================================================
+// Memory accessibility — write via kernel, read back to host
+// ===========================================================================
+
+/// Device kernel: write sequential uint32_t values starting at `offset_elems`.
+__global__ void write_seq_kernel(uint32_t* data, uint32_t n, uint32_t offset_elems) {
+    uint32_t i = blockIdx.x * blockDim.x + threadIdx.x;
+    if (i < n) data[i] = offset_elems + i;
+}
+
+/// Device kernel: verify sequential uint32_t values, store 1 on mismatch.
+__global__ void verify_seq_kernel(const uint32_t* data, uint32_t n,
+                                  uint32_t offset_elems, int* mismatch_flag) {
+    uint32_t i = blockIdx.x * blockDim.x + threadIdx.x;
+    if (i < n && data[i] != offset_elems + i)
+        atomicExch(mismatch_flag, 1);
+}
+
+static void launch_write(CUdeviceptr base, size_t bytes, uint32_t offset_elems = 0) {
+    auto* p = reinterpret_cast<uint32_t*>(base);
+    uint32_t n = static_cast<uint32_t>(bytes / sizeof(uint32_t));
+    write_seq_kernel<<<(n + 255) / 256, 256>>>(p, n, offset_elems);
+    ASSERT_EQ(cudaDeviceSynchronize(), cudaSuccess);
+}
+
+static bool device_verify(CUdeviceptr base, size_t bytes, uint32_t offset_elems = 0) {
+    auto* p = reinterpret_cast<const uint32_t*>(base);
+    uint32_t n = static_cast<uint32_t>(bytes / sizeof(uint32_t));
+
+    int* d_flag{};
+    if (cudaMalloc(&d_flag, sizeof(int)) != cudaSuccess) return false;
+    bool ok = cudaMemset(d_flag, 0, sizeof(int)) == cudaSuccess;
+
+    verify_seq_kernel<<<(n + 255) / 256, 256>>>(p, n, offset_elems, d_flag);
+    ok = ok && cudaDeviceSynchronize() == cudaSuccess;
+
+    int flag = 1; // pessimistic default: a failed copy must never look like a match
+    ok = ok && cudaMemcpy(&flag, d_flag, sizeof(int), cudaMemcpyDeviceToHost) == cudaSuccess;
+    cudaFree(d_flag);
+    return ok && flag == 0;
+}
+
+TEST_F(ArenaTest, CommittedMemoryIsWriteable) {
+    std::unique_ptr<CudaVmmArena> a(make_arena());
+    const size_t g = a->granularity();
+    a->grow(g);
+
+    
ASSERT_NO_THROW(launch_write(a->ptr(), g));
+    EXPECT_TRUE(device_verify(a->ptr(), g));
+}
+
+TEST_F(ArenaTest, DataInRetainedChunksSurvivesShrink) {
+    std::unique_ptr<CudaVmmArena> a(make_arena());
+    const size_t g = a->granularity();
+    a->grow(4 * g);
+
+    // Fill the entire 4-chunk region.
+    launch_write(a->ptr(), 4 * g);
+
+    // Shrink to 2 chunks — the lower half must be intact.
+    a->shrink(2 * g);
+
+    EXPECT_TRUE(device_verify(a->ptr(), 2 * g))
+        << "Data in retained chunks should survive shrink()";
+}
+
+TEST_F(ArenaTest, NewChunksAfterRegrowAreWriteable) {
+    std::unique_ptr<CudaVmmArena> a(make_arena());
+    const size_t g = a->granularity();
+
+    a->grow(2 * g);
+    launch_write(a->ptr(), 2 * g);
+
+    a->shrink(g);
+
+    // Regrow: the new chunk is fresh physical memory — just check it's writable.
+    a->grow(2 * g);
+    const size_t new_chunk_offset = g;
+    const uint32_t off_elems = static_cast<uint32_t>(new_chunk_offset / sizeof(uint32_t));
+    launch_write(a->ptr() + new_chunk_offset, g, off_elems);
+    EXPECT_TRUE(device_verify(a->ptr() + new_chunk_offset, g, off_elems));
+}
+
+TEST_F(ArenaTest, MultipleSequentialGrowsAllAccessible) {
+    std::unique_ptr<CudaVmmArena> a(make_arena());
+    const size_t g = a->granularity();
+
+    // Grow in three steps.
+    a->grow(g);
+    a->grow(3 * g);
+    a->grow(6 * g);
+
+    // Fill and verify the whole 6-chunk region in one pass.
+    launch_write(a->ptr(), 6 * g);
+    EXPECT_TRUE(device_verify(a->ptr(), 6 * g));
+}
+
+// ===========================================================================
+// Destructor safety
+// ===========================================================================
+
+TEST_F(ArenaTest, DestructorWithCommittedMemoryDoesNotLeak) {
+    // Constructing and immediately destroying a grown arena should not crash or
+    // assert. NOTE(review): nothing here actually measures leaked CUDA resources.
+    {
+        std::unique_ptr<CudaVmmArena> a(make_arena());
+        const size_t g = a->granularity();
+        a->grow(4 * g);
+        // Destructor runs here.
+    }
+    SUCCEED(); // If we reach here, no crash occurred.
+}
+
+TEST_F(ArenaTest, DestructorWithZeroCommittedMemoryDoesNotCrash) {
+    {
+        std::unique_ptr<CudaVmmArena> a(make_arena());
+        // Never committed anything.
+    }
+    SUCCEED();
+}
+
+// ===========================================================================
+// Multiple independent arenas
+// ===========================================================================
+
+TEST_F(ArenaTest, TwoArenasAreIndependent) {
+    std::unique_ptr<CudaVmmArena> a(make_arena(64ULL << 20));
+    std::unique_ptr<CudaVmmArena> b(make_arena(64ULL << 20));
+
+    const size_t g = a->granularity();
+    a->grow(g);
+    b->grow(g);
+
+    // Base pointers must differ.
+    EXPECT_NE(a->ptr(), b->ptr());
+
+    // Writes to one must not alias the other: each arena gets a distinct sequence.
+    uint32_t n = static_cast<uint32_t>(g / sizeof(uint32_t));
+    write_seq_kernel<<<(n + 255) / 256, 256>>>(
+        reinterpret_cast<uint32_t*>(a->ptr()), n, /*offset=*/0);
+    write_seq_kernel<<<(n + 255) / 256, 256>>>(
+        reinterpret_cast<uint32_t*>(b->ptr()), n, /*offset=*/n);
+    cudaDeviceSynchronize();
+
+    EXPECT_TRUE(device_verify(a->ptr(), g, /*offset_elems=*/0));
+    EXPECT_TRUE(device_verify(b->ptr(), g, /*offset_elems=*/n));
+}
+
+// ===========================================================================
+// main
+// ===========================================================================
+
+int main(int argc, char** argv) {
+    ::testing::InitGoogleTest(&argc, argv);
+    ::testing::AddGlobalTestEnvironment(new CudaEnv());
+    return RUN_ALL_TESTS();
+}