
Commit 477c09f

Add amrex::Gpu::freeAsync (#4804)
This PR adds the function `amrex::Gpu::freeAsync (Arena* arena, void* mem)`, which can be used to free memory the next time the current GPU stream is synchronized. It is based on #4432, but with the OpenMP-related complexity much reduced. The interface is now opt-in and always available, instead of having to be enabled with runtime parameters.
1 parent 367e101 · commit 477c09f
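
Below is a minimal usage sketch of the new call (the helper function and variable names are hypothetical; it assumes a device buffer obtained from `The_Arena()` and a kernel launched on the current stream), showing `freeAsync` taking the place of a blocking synchronize-then-free:

    #include <AMReX_Arena.H>
    #include <AMReX_Gpu.H>

    void scale_on_device (int n)   // hypothetical helper
    {
        using namespace amrex;

        auto* p = static_cast<Real*>(The_Arena()->alloc(n*sizeof(Real)));

        // Kernel launched on the current GPU stream; it may still be
        // running when freeAsync is called below.
        ParallelFor(n, [=] AMREX_GPU_DEVICE (int i) { p[i] = Real(2*i); });

        // Previously one would call Gpu::streamSynchronize() and then
        // The_Arena()->free(p). With this PR the pointer can be handed
        // back without blocking; it is actually released the next time
        // the current stream is synchronized.
        Gpu::freeAsync(The_Arena(), p);
    }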

5 files changed (+209, -68 lines)


Src/Base/AMReX_GpuDevice.H

Lines changed: 46 additions & 8 deletions
@@ -16,6 +16,7 @@
 #include <cstdlib>
 #include <cstring>
 #include <memory>
+#include <mutex>
 
 #define AMREX_GPU_MAX_STREAMS 8
 
@@ -46,8 +47,24 @@ using gpuDeviceProp_t = cudaDeviceProp;
 }
 #endif
 
+namespace amrex {
+    class Arena;
+}
+
 namespace amrex::Gpu {
 
+#ifdef AMREX_USE_GPU
+class StreamManager {
+    gpuStream_t m_stream;
+    std::mutex m_mutex;
+    Vector<std::pair<Arena*, void*>> m_free_wait_list;
+public:
+    [[nodiscard]] gpuStream_t& get ();
+    void sync ();
+    void free_async (Arena* arena, void* mem);
+};
+#endif
+
 class Device
 {
 
@@ -57,14 +74,16 @@ public:
     static void Finalize ();
 
 #if defined(AMREX_USE_GPU)
-    static gpuStream_t gpuStream () noexcept { return gpu_stream[OpenMP::get_thread_num()]; }
+    static gpuStream_t gpuStream () noexcept {
+        return gpu_stream_pool[gpu_stream_index[OpenMP::get_thread_num()]].get();
+    }
 #ifdef AMREX_USE_CUDA
     /** for backward compatibility */
-    static cudaStream_t cudaStream () noexcept { return gpu_stream[OpenMP::get_thread_num()]; }
+    static cudaStream_t cudaStream () noexcept { return gpuStream(); }
 #endif
 #ifdef AMREX_USE_SYCL
-    static sycl::queue& streamQueue () noexcept { return *(gpu_stream[OpenMP::get_thread_num()].queue); }
-    static sycl::queue& streamQueue (int i) noexcept { return *(gpu_stream_pool[i].queue); }
+    static sycl::queue& streamQueue () noexcept { return *(gpuStream().queue); }
+    static sycl::queue& streamQueue (int i) noexcept { return *(gpu_stream_pool[i].get().queue); }
 #endif
 #endif

@@ -104,6 +123,8 @@ public:
     */
     static void streamSynchronizeAll () noexcept;
 
+    static void freeAsync (Arena* arena, void* mem) noexcept;
+
 #if defined(__CUDACC__)
     /** Generic graph selection. These should be called by users. */
     static void startGraphRecording(bool first_iter, void* h_ptr, void* d_ptr, size_t sz);
@@ -196,10 +217,10 @@ private:
     static AMREX_EXPORT dim3 numThreadsMin;
     static AMREX_EXPORT dim3 numBlocksOverride, numThreadsOverride;
 
-    static AMREX_EXPORT Vector<gpuStream_t> gpu_stream_pool; // The size of this is max_gpu_stream
-    // The non-owning gpu_stream is used to store the current stream that will be used.
-    // gpu_stream is a vector so that it's thread safe to write to it.
-    static AMREX_EXPORT Vector<gpuStream_t> gpu_stream; // The size of this is omp_max_threads
+    static AMREX_EXPORT Vector<StreamManager> gpu_stream_pool; // The size of this is max_gpu_stream
+    // The non-owning gpu_stream_index is used to store the current stream index that will be used.
+    // gpu_stream_index is a vector so that it's thread safe to write to it.
+    static AMREX_EXPORT Vector<int> gpu_stream_index; // The size of this is omp_max_threads
     static AMREX_EXPORT gpuDeviceProp_t device_prop;
     static AMREX_EXPORT int memory_pools_supported;
     static AMREX_EXPORT unsigned int max_blocks_per_launch;
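
A note on the member changes above: because `StreamManager` now carries a `std::mutex` and a shared wait list, each OpenMP thread can no longer keep its own copy of the raw `gpuStream_t`; instead it stores an index into `gpu_stream_pool` and reaches the stream (and its wait list) through `StreamManager::get()`.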
@@ -208,6 +229,8 @@ private:
     static AMREX_EXPORT std::unique_ptr<sycl::context> sycl_context;
     static AMREX_EXPORT std::unique_ptr<sycl::device> sycl_device;
 #endif
+
+    friend StreamManager;
 #endif
 };

@@ -245,6 +268,21 @@
     Device::streamSynchronizeAll();
 }
 
+/** Deallocate memory belonging to an arena asynchronously.
+ * Memory deallocated in this way is held in a pool and will not be reused until
+ * the next amrex::Gpu::streamSynchronize(). GPU kernels that were already launched on the
+ * currently active stream can still continue to use the memory after this function is called.
+ * There is no need to use this function for CPU-only memory or with The_Async_Arena.
+ *
+ * \param[in] arena the arena the memory belongs to
+ * \param[in] mem pointer to the memory to be freed
+ */
+inline void
+freeAsync (Arena* arena, void* mem) noexcept
+{
+    Device::freeAsync(arena, mem);
+}
+
 #ifdef AMREX_USE_GPU
 
 inline void
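The doxygen comment above pins down the ordering guarantees; the following sketch (hypothetical names, default stream assumed throughout) spells out when the memory is actually returned to the arena:

    #include <cstddef>
    #include <AMReX_Arena.H>
    #include <AMReX_Gpu.H>

    void lifecycle_sketch ()   // hypothetical
    {
        std::size_t nbytes = 1024;
        void* q = amrex::The_Device_Arena()->alloc(nbytes);

        // ... kernels already enqueued on the current stream may keep using q ...

        amrex::Gpu::freeAsync(amrex::The_Device_Arena(), q);
        // Non-blocking: q is only queued for release and is not reused yet,
        // but work enqueued from this point on must no longer touch it.

        amrex::Gpu::streamSynchronize();
        // The stream has drained; q is handed back to the arena and may be reused.
    }
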
Src/Base/AMReX_GpuDevice.cpp

Lines changed: 105 additions & 40 deletions
@@ -1,4 +1,5 @@
 
+#include <AMReX_Arena.H>
 #include <AMReX_GpuDevice.H>
 #include <AMReX_GpuLaunch.H>
 #include <AMReX_Machine.H>
@@ -97,10 +98,10 @@ dim3 Device::numThreadsOverride = dim3(0, 0, 0);
 dim3 Device::numBlocksOverride = dim3(0, 0, 0);
 unsigned int Device::max_blocks_per_launch = 2560;
 
-Vector<gpuStream_t>   Device::gpu_stream_pool;
-Vector<gpuStream_t>   Device::gpu_stream;
-gpuDeviceProp_t       Device::device_prop;
-int                   Device::memory_pools_supported = 0;
+Vector<StreamManager> Device::gpu_stream_pool;
+Vector<int>           Device::gpu_stream_index;
+gpuDeviceProp_t       Device::device_prop;
+int                   Device::memory_pools_supported = 0;
 
 constexpr int Device::warp_size;
 
@@ -141,6 +142,64 @@ namespace {
     }
 }
 
+[[nodiscard]] gpuStream_t&
+StreamManager::get () {
+    return m_stream;
+}
+
+void
+StreamManager::sync () {
+    decltype(m_free_wait_list) new_empty_wait_list{};
+
+    {
+        // lock mutex before accessing and modifying member variables
+        std::lock_guard<std::mutex> lock(m_mutex);
+        m_free_wait_list.swap(new_empty_wait_list);
+    }
+    // unlock mutex before stream sync and memory free
+    // to avoid deadlocks from the CArena mutex
+
+    // actual stream sync
+#ifdef AMREX_USE_SYCL
+    try {
+        m_stream.queue->wait_and_throw();
+    } catch (sycl::exception const& ex) {
+        amrex::Abort(std::string("streamSynchronize: ")+ex.what()+"!!!!!");
+    }
+#else
+    AMREX_HIP_OR_CUDA( AMREX_HIP_SAFE_CALL(hipStreamSynchronize(m_stream));,
+                       AMREX_CUDA_SAFE_CALL(cudaStreamSynchronize(m_stream)); )
+#endif
+
+    // synchronizing the stream may have taken a long time and
+    // there may be new kernels launched already, so we free memory
+    // according to the state from before the stream was synced
+
+    for (auto [arena, mem] : new_empty_wait_list) {
+        arena->free(mem);
+    }
+}
+
+void
+StreamManager::free_async (Arena* arena, void* mem) {
+    if (arena->isDeviceAccessible()) {
+        std::size_t free_wait_list_size = 0;
+        {
+            // lock mutex before accessing and modifying member variables
+            std::lock_guard<std::mutex> lock(m_mutex);
+            m_free_wait_list.emplace_back(arena, mem);
+            free_wait_list_size = m_free_wait_list.size();
+        }
+        // Limit the number of memory allocations in m_free_wait_list
+        // in case the stream is never synchronized
+        if (free_wait_list_size > 100) {
+            sync();
+        }
+    } else {
+        arena->free(mem);
+    }
+}
+
 #endif
 
 void
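
Two points about the implementation above: `sync()` swaps the wait list into a local container while holding the mutex and only afterwards synchronizes the stream and calls `arena->free`, so the `StreamManager` mutex is never held while an arena's own lock (the comments name the CArena mutex) is taken, and any entries queued by other threads during the potentially long synchronization are kept for the next sync instead of being freed too early. In addition, `free_async()` caps the wait list at about 100 entries and triggers a synchronization itself, so memory cannot pile up indefinitely on a stream that is never explicitly synchronized.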
@@ -384,24 +443,25 @@ void
 Device::Finalize ()
 {
 #ifdef AMREX_USE_GPU
+    streamSynchronizeAll();
     Device::profilerStop();
 
 #ifdef AMREX_USE_SYCL
     for (auto& s : gpu_stream_pool) {
-        delete s.queue;
-        s.queue = nullptr;
+        delete s.get().queue;
+        s.get().queue = nullptr;
     }
     sycl_context.reset();
     sycl_device.reset();
 #else
     for (int i = 0; i < max_gpu_streams; ++i)
     {
-        AMREX_HIP_OR_CUDA( AMREX_HIP_SAFE_CALL( hipStreamDestroy(gpu_stream_pool[i]));,
-                           AMREX_CUDA_SAFE_CALL(cudaStreamDestroy(gpu_stream_pool[i])); );
+        AMREX_HIP_OR_CUDA( AMREX_HIP_SAFE_CALL( hipStreamDestroy(gpu_stream_pool[i].get()));,
+                           AMREX_CUDA_SAFE_CALL(cudaStreamDestroy(gpu_stream_pool[i].get())); );
     }
 #endif
 
-    gpu_stream.clear();
+    gpu_stream_index.clear();
 
 #ifdef AMREX_USE_ACC
     amrex_finalize_acc();
@@ -417,7 +477,10 @@ Device::initialize_gpu (bool minimal)
 
 #ifdef AMREX_USE_GPU
 
-    gpu_stream_pool.resize(max_gpu_streams);
+    if (gpu_stream_pool.size() != max_gpu_streams) {
+        // no copy/move constructor for std::mutex
+        gpu_stream_pool = Vector<StreamManager>(max_gpu_streams);
+    }
 
 #ifdef AMREX_USE_HIP
 
@@ -430,7 +493,7 @@
     // AMD devices do not support shared cache banking.
 
     for (int i = 0; i < max_gpu_streams; ++i) {
-        AMREX_HIP_SAFE_CALL(hipStreamCreate(&gpu_stream_pool[i]));
+        AMREX_HIP_SAFE_CALL(hipStreamCreate(&gpu_stream_pool[i].get()));
     }
 
 #ifdef AMREX_GPU_STREAM_ALLOC_SUPPORT
@@ -458,9 +521,9 @@ Device::initialize_gpu (bool minimal)
 #endif
 
     for (int i = 0; i < max_gpu_streams; ++i) {
-        AMREX_CUDA_SAFE_CALL(cudaStreamCreate(&gpu_stream_pool[i]));
+        AMREX_CUDA_SAFE_CALL(cudaStreamCreate(&gpu_stream_pool[i].get()));
 #ifdef AMREX_USE_ACC
-        acc_set_cuda_stream(i, gpu_stream_pool[i]);
+        acc_set_cuda_stream(i, gpu_stream_pool[i].get());
 #endif
     }
 
@@ -473,7 +536,7 @@
     sycl_device = std::make_unique<sycl::device>(gpu_devices[device_id]);
     sycl_context = std::make_unique<sycl::context>(*sycl_device, amrex_sycl_error_handler);
     for (int i = 0; i < max_gpu_streams; ++i) {
-        gpu_stream_pool[i].queue = new sycl::queue(*sycl_context, *sycl_device,
+        gpu_stream_pool[i].get().queue = new sycl::queue(*sycl_context, *sycl_device,
                                          sycl::property_list{sycl::property::queue::in_order{}});
     }
 }
@@ -556,7 +619,7 @@
     }
 #endif
 
-    gpu_stream.resize(OpenMP::get_max_threads(), gpu_stream_pool[0]);
+    gpu_stream_index.resize(OpenMP::get_max_threads(), 0);
 
     ParmParse pp("device");
 
@@ -626,8 +689,13 @@ int Device::numDevicePartners () noexcept
 int
 Device::streamIndex (gpuStream_t s) noexcept
 {
-    auto it = std::find(std::begin(gpu_stream_pool), std::end(gpu_stream_pool), s);
-    return static_cast<int>(std::distance(std::begin(gpu_stream_pool), it));
+    const int N = gpu_stream_pool.size();
+    for (int i = 0; i < N ; ++i) {
+        if (gpu_stream_pool[i].get() == s) {
+            return i;
+        }
+    }
+    return N;
 }
 #endif
 
@@ -636,7 +704,7 @@
 {
     amrex::ignore_unused(idx);
 #ifdef AMREX_USE_GPU
-    gpu_stream[OpenMP::get_thread_num()] = gpu_stream_pool[idx % max_gpu_streams];
+    gpu_stream_index[OpenMP::get_thread_num()] = idx % max_gpu_streams;
 #ifdef AMREX_USE_ACC
     amrex_set_acc_stream(idx % max_gpu_streams);
 #endif
@@ -647,16 +715,16 @@ Device::setStreamIndex (int idx) noexcept
 gpuStream_t
 Device::resetStream () noexcept
 {
-    gpuStream_t r = gpu_stream[OpenMP::get_thread_num()];
-    gpu_stream[OpenMP::get_thread_num()] = gpu_stream_pool[0];
+    gpuStream_t r = gpuStream();
+    gpu_stream_index[OpenMP::get_thread_num()] = 0;
     return r;
 }
 
 gpuStream_t
 Device::setStream (gpuStream_t s) noexcept
 {
-    gpuStream_t r = gpu_stream[OpenMP::get_thread_num()];
-    gpu_stream[OpenMP::get_thread_num()] = s;
+    gpuStream_t r = gpuStream();
+    gpu_stream_index[OpenMP::get_thread_num()] = streamIndex(s);
     return r;
 }
 #endif
@@ -665,9 +733,9 @@ void
 Device::synchronize () noexcept
 {
 #ifdef AMREX_USE_SYCL
-    for (auto const& s : gpu_stream_pool) {
+    for (auto& s : gpu_stream_pool) {
         try {
-            s.queue->wait_and_throw();
+            s.get().queue->wait_and_throw();
         } catch (sycl::exception const& ex) {
             amrex::Abort(std::string("synchronize: ")+ex.what()+"!!!!!");
         }
@@ -681,31 +749,28 @@ Device::synchronize () noexcept
 void
 Device::streamSynchronize () noexcept
 {
-#ifdef AMREX_USE_SYCL
-    auto& q = streamQueue();
-    try {
-        q.wait_and_throw();
-    } catch (sycl::exception const& ex) {
-        amrex::Abort(std::string("streamSynchronize: ")+ex.what()+"!!!!!");
-    }
-#else
-    AMREX_HIP_OR_CUDA( AMREX_HIP_SAFE_CALL(hipStreamSynchronize(gpuStream()));,
-                       AMREX_CUDA_SAFE_CALL(cudaStreamSynchronize(gpuStream())); )
+#ifdef AMREX_USE_GPU
+    gpu_stream_pool[gpu_stream_index[OpenMP::get_thread_num()]].sync();
 #endif
 }
 
 void
 Device::streamSynchronizeAll () noexcept
 {
 #ifdef AMREX_USE_GPU
-#ifdef AMREX_USE_SYCL
-    Device::synchronize();
-#else
-    for (auto const& s : gpu_stream_pool) {
-        AMREX_HIP_OR_CUDA( AMREX_HIP_SAFE_CALL(hipStreamSynchronize(s));,
-                           AMREX_CUDA_SAFE_CALL(cudaStreamSynchronize(s)); )
+    for (auto& s : gpu_stream_pool) {
+        s.sync();
     }
 #endif
+}
+
+void
+Device::freeAsync (Arena* arena, void* mem) noexcept
+{
+#ifdef AMREX_USE_GPU
+    gpu_stream_pool[gpu_stream_index[OpenMP::get_thread_num()]].free_async(arena, mem);
+#else
+    arena->free(mem);
 #endif
 }
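
With these changes every call to `Gpu::streamSynchronize()` or `Gpu::streamSynchronizeAll()` goes through `StreamManager::sync()`, so memory queued by `freeAsync` is returned to its arena at each synchronization point; in CPU-only builds `Device::freeAsync` simply forwards to `arena->free` right away.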
