forked from NVIDIA/TensorRT-LLM
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcudaVmmArena.cpp
More file actions
202 lines (164 loc) · 7.11 KB
/
cudaVmmArena.cpp
File metadata and controls
202 lines (164 loc) · 7.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
/*
* Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/batch_manager/cudaVmmArena.h"
#include <cstring>
#include <sstream>
namespace tensorrt_llm::batch_manager::vmm {
// ---------------------------------------------------------------------------
// Internal helpers
// ---------------------------------------------------------------------------
/// Throw a CudaVmmError describing `res` if it is not CUDA_SUCCESS.
/// `where` names the driver API call for the error message.
void CudaVmmArena::check(CUresult res, const char* where) {
    if (res == CUDA_SUCCESS) {
        return;
    }
    // Best-effort symbol/description lookup; either pointer may stay null
    // when the driver does not recognize the error code.
    const char* err_name = nullptr;
    const char* err_desc = nullptr;
    cuGetErrorName(res, &err_name);
    cuGetErrorString(res, &err_desc);
    std::ostringstream msg;
    msg << "CUDA VMM error in " << where << ": "
        << (err_name ? err_name : "?") << " (" << res << ")";
    if (err_desc) {
        msg << std::string(" — ") << err_desc;
    }
    throw CudaVmmError(msg.str(), res);
}
// ---------------------------------------------------------------------------
// Constructor / Destructor
// ---------------------------------------------------------------------------
/// Construct an arena that reserves (but does not commit) a virtual address
/// range of up to `max_size` bytes of pinned device memory on `device`.
/// `max_size` is rounded up to the device's allocation granularity.
/// Throws CudaVmmError on driver failure or degenerate sizes.
CudaVmmArena::CudaVmmArena(size_t max_size, int device)
    : device_(device)
    , granularity_(0)
    , max_size_(0)
    , committed_size_(0)
    , base_ptr_(0)
{
    // Build allocation properties: pinned device memory on the selected GPU.
    std::memset(&alloc_prop_, 0, sizeof(alloc_prop_));
    alloc_prop_.type = CU_MEM_ALLOCATION_TYPE_PINNED;
    alloc_prop_.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    alloc_prop_.location.id = device_;

    // Query the minimum granularity required by this device/allocation type.
    check(cuMemGetAllocationGranularity(
              &granularity_, &alloc_prop_,
              CU_MEM_ALLOC_GRANULARITY_MINIMUM),
          "cuMemGetAllocationGranularity");
    if (granularity_ == 0)
        throw CudaVmmError("Device reported zero allocation granularity.");

    // Round requested max_size up to a granularity boundary.
    max_size_ = align_up(max_size, granularity_);
    if (max_size_ == 0)
        throw CudaVmmError("max_size rounds to zero after granularity alignment.");

    // Reserve handle capacity (one slot per granule) *before* taking the VA
    // reservation: if this threw std::bad_alloc after cuMemAddressReserve,
    // the destructor would not run (constructor throw) and the VA range
    // would leak. Note reserve() only sets capacity; handles_ stays empty.
    handles_.reserve(max_size_ / granularity_);

    // Build the access descriptor once; reused for every chunk.
    std::memset(&access_desc_, 0, sizeof(access_desc_));
    access_desc_.location = alloc_prop_.location;
    access_desc_.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;

    // Reserve the virtual address range; no physical memory is allocated yet.
    // Done last so every earlier failure path has nothing to clean up.
    check(cuMemAddressReserve(&base_ptr_, max_size_,
                              /*alignment=*/0, /*hint=*/0, /*flags=*/0),
          "cuMemAddressReserve");
}
/// Tear down all committed physical chunks (newest first), then give back
/// the virtual address reservation. All calls are best-effort: a destructor
/// must not throw.
CudaVmmArena::~CudaVmmArena() {
    size_t idx = handles_.size();
    while (idx > 0) {
        --idx;
        unmap_chunk(idx);
    }
    handles_.clear();

    if (base_ptr_ != 0) {
        cuMemAddressFree(base_ptr_, max_size_);
        base_ptr_ = 0;
    }
}
// ---------------------------------------------------------------------------
// Private: map / unmap a single granularity-sized chunk
// ---------------------------------------------------------------------------
/// Commit one granularity-sized physical chunk at byte `offset` inside the
/// reserved VA range and record its handle. On any driver failure the
/// partially-created chunk is torn down and check() throws CudaVmmError.
void CudaVmmArena::map_chunk(size_t offset) {
    // Allocate one granule of physical memory.
    CUmemGenericAllocationHandle phys{};
    check(cuMemCreate(&phys, granularity_, &alloc_prop_, /*flags=*/0),
          "cuMemCreate");

    const CUdeviceptr va = base_ptr_ + offset;

    // Attach the physical page at `offset` within the reservation.
    CUresult rc = cuMemMap(va, granularity_,
                           /*offset into handle=*/0, phys, /*flags=*/0);
    if (rc != CUDA_SUCCESS) {
        cuMemRelease(phys); // best-effort cleanup
        check(rc, "cuMemMap");
    }

    // Enable read/write access for this device on the new mapping.
    rc = cuMemSetAccess(va, granularity_, &access_desc_, /*count=*/1);
    if (rc != CUDA_SUCCESS) {
        cuMemUnmap(va, granularity_);
        cuMemRelease(phys);
        check(rc, "cuMemSetAccess");
    }

    handles_.push_back(phys);
}
/// Decommit the chunk at index `chunk_idx`: revoke access, unmap the VA
/// window, release the physical handle, and clear the slot. All driver
/// calls are best-effort (no throw) — this runs from the destructor and
/// from shrink(). The handles_ entry is NOT erased here; callers manage
/// the vector's size.
void CudaVmmArena::unmap_chunk(size_t chunk_idx) {
    const CUdeviceptr va = base_ptr_ + chunk_idx * granularity_;

    // Revoke access before unmapping (required by the CUDA VMM spec).
    CUmemAccessDesc revoke{};
    revoke.location = alloc_prop_.location;
    revoke.flags = CU_MEM_ACCESS_FLAGS_PROT_NONE;
    cuMemSetAccess(va, granularity_, &revoke, 1);

    cuMemUnmap(va, granularity_);
    cuMemRelease(handles_[chunk_idx]);
    handles_[chunk_idx] = CUmemGenericAllocationHandle{};
}
// ---------------------------------------------------------------------------
// Public: grow / shrink / resize
// ---------------------------------------------------------------------------
/// Commit physical memory so that at least `new_size` bytes (rounded up to
/// granularity) are backed. Throws CudaVmmError if the target is zero,
/// exceeds the reserved VA range, or does not exceed the current committed
/// size, or if a driver call fails.
void CudaVmmArena::grow(size_t new_size) {
    const size_t aligned = align_up(new_size, granularity_);
    if (aligned == 0)
        throw CudaVmmError("grow(): new_size rounds to zero.");
    if (aligned > max_size_)
        throw CudaVmmError("grow(): new_size exceeds the reserved VA range.");
    if (aligned <= committed_size_)
        throw CudaVmmError("grow(): new_size must be larger than current committed_size.");

    // Map chunks covering [committed_size_, aligned). committed_size_ is
    // advanced after each successful chunk so the invariant
    //   committed_size_ == handles_.size() * granularity_
    // holds even if map_chunk() throws partway through: chunks mapped so
    // far remain usable, and a subsequent grow()/shrink()/destructor all
    // observe a consistent state. (The previous version only updated
    // committed_size_ at the end, so a partial failure left stale state —
    // a retried grow() would have re-mapped already-mapped offsets.)
    while (committed_size_ < aligned) {
        map_chunk(committed_size_);
        committed_size_ += granularity_;
    }
}
/// Decommit physical memory down to `new_size` bytes, rounded *down* to a
/// granularity boundary so a partially-unmapped granule is never exposed.
/// Throws CudaVmmError if the rounded target is not strictly smaller than
/// the current committed size.
void CudaVmmArena::shrink(size_t new_size) {
    const size_t target = align_down(new_size, granularity_);
    if (target >= committed_size_)
        throw CudaVmmError("shrink(): new_size must be smaller than current committed_size.");

    // Drop the tail chunks covering [target, committed_size_), newest first.
    const size_t n_drop = (committed_size_ - target) / granularity_;
    for (size_t i = 0; i < n_drop; ++i) {
        unmap_chunk(handles_.size() - 1);
        handles_.pop_back();
    }
    committed_size_ = target;
}
/// Grow or shrink the committed region so that at least `new_size` bytes
/// are backed, using as few whole granules as possible. No-op when the
/// rounded-up target already equals the committed size.
void CudaVmmArena::resize(size_t new_size) {
    // Round *up* for both directions so the arena never ends up smaller
    // than the caller requested. The previous version rounded down on the
    // shrink path, so e.g. resize(1.5 granules) with 2 granules committed
    // shrank to 1 granule — less than was asked for.
    const size_t target = align_up(new_size, granularity_);
    if (target > committed_size_) {
        grow(new_size);
    } else if (target < committed_size_) {
        // `target` is already granularity-aligned, so shrink()'s own
        // round-down keeps it unchanged.
        shrink(target);
    }
    // else: already at the right size, nothing to do.
}
} // namespace tensorrt_llm::batch_manager::vmm