Remove need to sync Gpu stream before deallocating memory #4432

Open · wants to merge 9 commits into base: development
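For context, the practical effect on user code is roughly the following (a hedged sketch; the kernel-launch comment and the old explicit-synchronization pattern are illustrative, not taken from this PR):

```cpp
#include <AMReX.H>
#include <AMReX_Arena.H>
#include <AMReX_Gpu.H>

void example ()
{
    // Allocate device memory from the default AMReX arena.
    void* p = amrex::The_Arena()->alloc(1024);  // 1024 bytes of arena memory

    // ... launch GPU kernels on the current stream that read/write p ...

    // Before this change: the stream had to be synchronized before freeing,
    // otherwise the block could be handed to a new allocation while a kernel
    // was still using it.
    // amrex::Gpu::streamSynchronize();

    // After this change: free() can be called immediately; for device-accessible
    // CArena memory the actual release is deferred until the stream that may
    // still touch the memory has been synchronized.
    amrex::The_Arena()->free(p);
}
```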
1 change: 1 addition & 0 deletions Src/Base/AMReX_Arena.cpp
@@ -175,6 +175,7 @@ Arena::allocate_system (std::size_t nbytes) // NOLINT(readability-make-member-fu
{
std::size_t free_mem_avail = Gpu::Device::freeMemAvailable();
if (nbytes >= free_mem_avail) {
Gpu::streamSynchronizeAll(); // this could cause some memory to be freed
free_mem_avail += freeUnused_protected(); // For CArena, mutex has already acquired
if (abort_on_out_of_gpu_memory && nbytes >= free_mem_avail) {
amrex::Abort("Out of gpu memory. Free: " + std::to_string(free_mem_avail)
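Why synchronizing helps here: under the deferred-free scheme introduced in AMReX_GpuDevice.H below, pointers freed through a device-accessible CArena wait on a per-stream list until that stream is synchronized. A minimal sketch of what streamSynchronizeAll() presumably amounts to with the new StreamManager pool (inferred from the interface below, not copied from the PR):

```cpp
// Presumed behavior, not the PR's literal implementation: synchronize every
// stream in the pool.  Each sync() drains that stream's deferred-free wait
// list, handing memory back to the owning arenas so that the
// freeUnused_protected() call above can actually release it.
void streamSynchronizeAll_sketch ()
{
    for (auto& sm : gpu_stream_pool) {   // Vector<StreamManager>, one per stream
        sm.sync();
    }
}
```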
2 changes: 2 additions & 0 deletions Src/Base/AMReX_CArena.H
@@ -63,6 +63,8 @@ public:
*/
void free (void* vp) final;

void free_now (void* vp);

std::size_t freeUnused () final;

/**
18 changes: 18 additions & 0 deletions Src/Base/AMReX_CArena.cpp
@@ -2,6 +2,7 @@
#include <AMReX_CArena.H>
#include <AMReX_BLassert.H>
#include <AMReX_Gpu.H>
#include <AMReX_GpuDevice.H>
#include <AMReX_ParallelReduce.H>

#include <utility>
@@ -265,6 +266,23 @@ CArena::free (void* vp)
return;
}

if (this->isDeviceAccessible()) {
Gpu::Device::freeAfterSync(this, vp);
} else {
free_now(vp);
}
}

void
CArena::free_now (void* vp)
{
if (vp == nullptr) {
//
// Allow calls with NULL as allowed by C++ delete.
//
return;
}

std::lock_guard<std::mutex> lock(carena_mutex);

//
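Putting the two halves together: free() queues device-accessible pointers on the current stream, and free_now() performs the immediate release once that stream has synchronized. Below is a self-contained toy of that dispatch; the names (ToyArena, stream_wait_list) are hypothetical stand-ins for CArena and StreamManager::m_free_wait_list, not AMReX code:

```cpp
#include <cstdlib>
#include <mutex>
#include <utility>
#include <vector>

struct ToyArena {
    bool device_accessible = true;
    std::mutex mtx;

    // Immediate release, analogous to CArena::free_now().
    void free_now (void* p) {
        if (p == nullptr) { return; }   // allow free(nullptr), like C++ delete
        std::lock_guard<std::mutex> lock(mtx);
        std::free(p);
    }

    // Deferred release, analogous to the new CArena::free().
    void free (void* p, std::vector<std::pair<ToyArena*, void*>>& stream_wait_list) {
        if (p == nullptr) { return; }
        if (device_accessible) {
            stream_wait_list.emplace_back(this, p);  // released only after the stream syncs
        } else {
            free_now(p);                             // host-only memory can go right away
        }
    }
};

int main ()
{
    ToyArena arena;
    std::vector<std::pair<ToyArena*, void*>> wait_list;  // stands in for m_free_wait_list
    void* p = std::malloc(64);

    arena.free(p, wait_list);  // deferred: p may still be read by in-flight stream work
    // ... a stream synchronization would happen here ...
    for (auto& item : wait_list) { item.first->free_now(item.second); }  // drain after sync
    wait_list.clear();
    return 0;
}
```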
63 changes: 55 additions & 8 deletions Src/Base/AMReX_GpuDevice.H
@@ -16,6 +16,7 @@
#include <cstdlib>
#include <cstring>
#include <memory>
#include <mutex>

#define AMREX_GPU_MAX_STREAMS 8

@@ -46,8 +47,28 @@ using gpuDeviceProp_t = cudaDeviceProp;
}
#endif

namespace amrex {
class CArena;
}

namespace amrex::Gpu {

#ifdef AMREX_USE_GPU
class StreamManager {
gpuStream_t m_stream;
std::uint64_t m_stream_op_id = 0;
std::uint64_t m_last_sync = 0;
Vector<std::pair<CArena*, void*>> m_free_wait_list;
std::mutex m_mutex;
public:
[[nodiscard]] gpuStream_t get ();
[[nodiscard]] gpuStream_t& internal_get ();
void sync ();
void internal_after_sync ();
void stream_free (CArena* arena, void* mem);
};
#endif

class Device
{

@@ -57,17 +78,32 @@ public:
static void Finalize ();

#if defined(AMREX_USE_GPU)
static gpuStream_t gpuStream () noexcept { return gpu_stream[OpenMP::get_thread_num()]; }
static gpuStream_t gpuStream () noexcept {
return gpu_stream_pool[gpu_stream_index[OpenMP::get_thread_num()]].get();
}
#ifdef AMREX_USE_CUDA
/** for backward compatibility */
static cudaStream_t cudaStream () noexcept { return gpu_stream[OpenMP::get_thread_num()]; }
static cudaStream_t cudaStream () noexcept {
return gpu_stream_pool[gpu_stream_index[OpenMP::get_thread_num()]].get();
}
#endif
#ifdef AMREX_USE_SYCL
static sycl::queue& streamQueue () noexcept { return *(gpu_stream[OpenMP::get_thread_num()].queue); }
static sycl::queue& streamQueue (int i) noexcept { return *(gpu_stream_pool[i].queue); }
static sycl::queue& streamQueue () noexcept {
return *(gpu_stream_pool[gpu_stream_index[OpenMP::get_thread_num()]].get().queue);
}
static sycl::queue& streamQueue (int i) noexcept {
return *(gpu_stream_pool[i].get().queue);
}
#endif
#endif

static void freeAfterSync (CArena* arena, void* mem) noexcept {
amrex::ignore_unused(arena, mem);
#ifdef AMREX_USE_CUDA
gpu_stream_pool[gpu_stream_index[OpenMP::get_thread_num()]].stream_free(arena, mem);
#endif
}

static int numGpuStreams () noexcept {
return inSingleStreamRegion() ? 1 : max_gpu_streams;
}
@@ -104,6 +140,16 @@ public:
*/
static void streamSynchronizeAll () noexcept;

#ifdef AMREX_USE_GPU
/**
* Halt execution of code until the given GPU stream has finished processing all
* previously enqueued tasks. Unlike streamSynchronize, which skips redundant
* synchronizations when called multiple times in a row, this function always
* synchronizes the stream.
*/
static void actualStreamSynchronize (gpuStream_t stream) noexcept;
#endif

#if defined(__CUDACC__)
/** Generic graph selection. These should be called by users. */
static void startGraphRecording(bool first_iter, void* h_ptr, void* d_ptr, size_t sz);
@@ -196,10 +242,11 @@ private:
static AMREX_EXPORT dim3 numThreadsMin;
static AMREX_EXPORT dim3 numBlocksOverride, numThreadsOverride;

static AMREX_EXPORT Vector<gpuStream_t> gpu_stream_pool; // The size of this is max_gpu_stream
// The non-owning gpu_stream is used to store the current stream that will be used.
// gpu_stream is a vector so that it's thread safe to write to it.
static AMREX_EXPORT Vector<gpuStream_t> gpu_stream; // The size of this is omp_max_threads
static AMREX_EXPORT Vector<StreamManager> gpu_stream_pool; // The size of this is max_gpu_stream
// The non-owning gpu_stream_index is used to store the current stream index that will be used.
// gpu_stream_index is a vector so that it's thread safe to write to it.
static AMREX_EXPORT Vector<int> gpu_stream_index; // The size of this is omp_max_threads

static AMREX_EXPORT gpuDeviceProp_t device_prop;
static AMREX_EXPORT int memory_pools_supported;
static AMREX_EXPORT unsigned int max_blocks_per_launch;
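Finally, a self-contained sketch of the bookkeeping suggested by StreamManager's members (m_stream_op_id, m_last_sync, m_free_wait_list) and by the doc comment on actualStreamSynchronize. The counter logic and the note_stream_op helper are assumptions drawn from the interface, not the PR's actual implementation:

```cpp
#include <cstdint>
#include <cstdlib>
#include <mutex>
#include <utility>
#include <vector>

struct ToyStreamManager {
    std::uint64_t op_id     = 0;   // bumped whenever work or a deferred free is queued
    std::uint64_t last_sync = 0;   // value of op_id at the time of the last real sync
    std::vector<std::pair<void (*)(void*), void*>> free_wait_list;  // (deleter, ptr)
    std::mutex mtx;

    // Hypothetical stand-in for StreamManager::get(): record that a stream op was queued.
    void note_stream_op () { ++op_id; }

    // Analogous to StreamManager::stream_free(): queue the pointer instead of freeing it.
    void stream_free (void (*deleter)(void*), void* p) {
        std::lock_guard<std::mutex> lock(mtx);
        free_wait_list.emplace_back(deleter, p);
        ++op_id;
    }

    // Analogous to StreamManager::sync(): skip redundant syncs, then drain the wait list
    // (the drain plays the role of internal_after_sync()).
    void sync () {
        std::lock_guard<std::mutex> lock(mtx);
        if (op_id == last_sync) { return; }   // nothing new since the last sync
        // actualStreamSynchronize(m_stream) would be called here in the real code.
        last_sync = op_id;
        for (auto& item : free_wait_list) { item.first(item.second); }
        free_wait_list.clear();
    }
};

int main ()
{
    ToyStreamManager sm;
    void* p = std::malloc(32);

    sm.note_stream_op();                               // e.g. a kernel launch on this stream
    sm.stream_free([](void* q) { std::free(q); }, p);  // deferred free of p
    sm.sync();   // "synchronizes" the stream and releases p
    sm.sync();   // no new ops since the last sync: returns immediately
    return 0;
}
```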