1616#include < cstdlib>
1717#include < cstring>
1818#include < memory>
19+ #include < mutex>
1920
2021#define AMREX_GPU_MAX_STREAMS 8
2122
@@ -46,8 +47,24 @@ using gpuDeviceProp_t = cudaDeviceProp;
4647}
4748#endif
4849
50+ namespace amrex {
51+ class Arena ;
52+ }
53+
4954namespace amrex ::Gpu {
5055
56+ #ifdef AMREX_USE_GPU
57+ class StreamManager {
58+ gpuStream_t m_stream;
59+ std::mutex m_mutex;
60+ Vector<std::pair<Arena*, void *>> m_free_wait_list;
61+ public:
62+ [[nodiscard]] gpuStream_t& get ();
63+ void sync ();
64+ void free_async (Arena* arena, void * mem);
65+ };
66+ #endif
67+
5168class Device
5269{
5370
@@ -57,14 +74,16 @@ public:
5774 static void Finalize ();
5875
5976#if defined(AMREX_USE_GPU)
60- static gpuStream_t gpuStream () noexcept { return gpu_stream[OpenMP::get_thread_num ()]; }
77+ static gpuStream_t gpuStream () noexcept {
78+ return gpu_stream_pool[gpu_stream_index[OpenMP::get_thread_num ()]].get ();
79+ }
6180#ifdef AMREX_USE_CUDA
6281 /* * for backward compatibility */
63- static cudaStream_t cudaStream () noexcept { return gpu_stream[ OpenMP::get_thread_num ()] ; }
82+ static cudaStream_t cudaStream () noexcept { return gpuStream () ; }
6483#endif
6584#ifdef AMREX_USE_SYCL
66- static sycl::queue& streamQueue () noexcept { return *(gpu_stream[ OpenMP::get_thread_num ()] .queue ); }
67- static sycl::queue& streamQueue (int i) noexcept { return *(gpu_stream_pool[i].queue ); }
85+ static sycl::queue& streamQueue () noexcept { return *(gpuStream () .queue ); }
86+ static sycl::queue& streamQueue (int i) noexcept { return *(gpu_stream_pool[i].get (). queue ); }
6887#endif
6988#endif
7089
@@ -104,6 +123,8 @@ public:
104123 */
105124 static void streamSynchronizeAll () noexcept ;
106125
126+ static void freeAsync (Arena* arena, void * mem) noexcept ;
127+
107128#if defined(__CUDACC__)
108129 /* * Generic graph selection. These should be called by users. */
109130 static void startGraphRecording (bool first_iter, void * h_ptr, void * d_ptr, size_t sz);
@@ -196,10 +217,10 @@ private:
196217 static AMREX_EXPORT dim3 numThreadsMin;
197218 static AMREX_EXPORT dim3 numBlocksOverride, numThreadsOverride;
198219
199- static AMREX_EXPORT Vector<gpuStream_t > gpu_stream_pool; // The size of this is max_gpu_stream
200- // The non-owning gpu_stream is used to store the current stream that will be used.
201- // gpu_stream is a vector so that it's thread safe to write to it.
202- static AMREX_EXPORT Vector<gpuStream_t> gpu_stream ; // The size of this is omp_max_threads
220+ static AMREX_EXPORT Vector<StreamManager > gpu_stream_pool; // The size of this is max_gpu_stream
221+ // The non-owning gpu_stream_index is used to store the current stream index that will be used.
222+ // gpu_stream_index is a vector so that it's thread safe to write to it.
223+ static AMREX_EXPORT Vector<int > gpu_stream_index ; // The size of this is omp_max_threads
203224 static AMREX_EXPORT gpuDeviceProp_t device_prop;
204225 static AMREX_EXPORT int memory_pools_supported;
205226 static AMREX_EXPORT unsigned int max_blocks_per_launch;
@@ -208,6 +229,8 @@ private:
208229 static AMREX_EXPORT std::unique_ptr<sycl::context> sycl_context;
209230 static AMREX_EXPORT std::unique_ptr<sycl::device> sycl_device;
210231#endif
232+
233+ friend StreamManager;
211234#endif
212235};
213236
@@ -245,6 +268,21 @@ streamSynchronizeAll () noexcept
245268 Device::streamSynchronizeAll ();
246269}
247270
271+ /* * Deallocate memory belonging to an arena asynchronously.
272+ * Memory deallocated in this way is held in a pool and will not be reused until
273+ * the next amrex::Gpu::streamSynchronize(). GPU kernels that were already launched on the
274+ * currently active stream can still continue to use the memory after this function is called.
275+ * There is no need to use this function for CPU-only memory or with The_Async_Arena.
276+ *
277+ * \param[in] arena the arena the memory belongs to
278+ * \param[in] mem pointer to the memory to be freed
279+ */
280+ inline void
281+ freeAsync (Arena* arena, void * mem) noexcept
282+ {
283+ Device::freeAsync (arena, mem);
284+ }
285+
248286#ifdef AMREX_USE_GPU
249287
250288inline void
0 commit comments