Skip to content

Commit 4a4e8b3

Browse files
Communications arena implementation (#3388)
## Summary Implement a communications arena for comm buffers to replace `the_fa_arena`. It creates a separate arena when GPU-aware MPI is used and `the_arena` is not managed. ## Additional background The motivation for this is a communication performance degradation that is observed for GPU-aware MPI with `amrex.the_arena_is_managed=0`. @WeiqunZhang has a hypothesis that this may be due to the need for frequent re-registering of comm buffer pointers when using the same device arena as the other compute data. Hence a separate arena in this case would alleviate this issue. `the_fa_arena` is eliminated in this PR and the communication buffer directly uses `the_comms_arena` to simplify the code. ## Performance tests The above stated performance degradation is particularly observed with the `GPU/CNS/Exec/Sod` code under `Tests` and is alleviated by using a separate comms arena as seen in the performance data below. `original` refers to the state before we made the change in #3362 related to `the_fa_arena` pointing to the device arena which allowed `amrex.the_arena_is_managed=1` with GPU-aware MPI without a significant performance hit. It is compared with the current development branch and the proposed comms arena implementation. The data pointing to the performance improvement from this PR is highlighted. ![Screenshot 2023-06-27 at 4 11 54 PM](https://github.com/AMReX-Codes/amrex/assets/18251677/ae16b822-0178-4679-a90f-255cad6c5451) In other tests such as the `ABecLaplacian` linear solve or the ERF code, using `amrex.the_arena_is_managed=0` did not show a significant performance hit and using this comms arena implementation did not harm the performance either. More comprehensive tests would be required to determine the effect on other codes and platforms. --------- Co-authored-by: Mukul Dave <[email protected]> Co-authored-by: Weiqun Zhang <[email protected]>
1 parent 0236a37 commit 4a4e8b3

File tree

7 files changed

+73
-33
lines changed

7 files changed

+73
-33
lines changed

Src/Base/AMReX_Arena.H

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ Arena* The_Async_Arena ();
3131
Arena* The_Device_Arena ();
3232
Arena* The_Managed_Arena ();
3333
Arena* The_Pinned_Arena ();
34+
Arena* The_Comms_Arena ();
3435
Arena* The_Cpu_Arena ();
3536

3637
struct ArenaInfo

Src/Base/AMReX_Arena.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,18 @@ namespace {
3434
Arena* the_managed_arena = nullptr;
3535
Arena* the_pinned_arena = nullptr;
3636
Arena* the_cpu_arena = nullptr;
37+
Arena* the_comms_arena = nullptr;
3738

3839
Long the_arena_init_size = 0L;
3940
Long the_device_arena_init_size = 1024*1024*8;
4041
Long the_managed_arena_init_size = 1024*1024*8;
4142
Long the_pinned_arena_init_size = 1024*1024*8;
43+
Long the_comms_arena_init_size = 1024*1024*8;
4244
Long the_arena_release_threshold = std::numeric_limits<Long>::max();
4345
Long the_device_arena_release_threshold = std::numeric_limits<Long>::max();
4446
Long the_managed_arena_release_threshold = std::numeric_limits<Long>::max();
4547
Long the_pinned_arena_release_threshold = std::numeric_limits<Long>::max();
48+
Long the_comms_arena_release_threshold = std::numeric_limits<Long>::max();
4649
Long the_async_arena_release_threshold = std::numeric_limits<Long>::max();
4750
#ifdef AMREX_USE_HIP
4851
bool the_arena_is_managed = false; // xxxxx HIP FIX HERE
@@ -276,6 +279,7 @@ Arena::Initialize ()
276279
BL_ASSERT(the_managed_arena == nullptr || the_managed_arena == The_BArena());
277280
BL_ASSERT(the_pinned_arena == nullptr);
278281
BL_ASSERT(the_cpu_arena == nullptr || the_cpu_arena == The_BArena());
282+
BL_ASSERT(the_comms_arena == nullptr || the_comms_arena == The_BArena());
279283

280284
#ifdef AMREX_USE_GPU
281285
#ifdef AMREX_USE_SYCL
@@ -292,10 +296,12 @@ Arena::Initialize ()
292296
pp.queryAdd( "the_device_arena_init_size", the_device_arena_init_size);
293297
pp.queryAdd("the_managed_arena_init_size", the_managed_arena_init_size);
294298
pp.queryAdd( "the_pinned_arena_init_size", the_pinned_arena_init_size);
299+
pp.queryAdd( "the_comms_arena_init_size", the_comms_arena_init_size);
295300
pp.queryAdd( "the_arena_release_threshold" , the_arena_release_threshold);
296301
pp.queryAdd( "the_device_arena_release_threshold", the_device_arena_release_threshold);
297302
pp.queryAdd("the_managed_arena_release_threshold", the_managed_arena_release_threshold);
298303
pp.queryAdd( "the_pinned_arena_release_threshold", the_pinned_arena_release_threshold);
304+
pp.queryAdd("the_comms_arena_release_threshold", the_comms_arena_release_threshold);
299305
pp.queryAdd( "the_async_arena_release_threshold", the_async_arena_release_threshold);
300306
pp.queryAdd("the_arena_is_managed", the_arena_is_managed);
301307
pp.queryAdd("abort_on_out_of_gpu_memory", abort_on_out_of_gpu_memory);
@@ -361,6 +367,22 @@ Arena::Initialize ()
361367
(the_pinned_arena_release_threshold));
362368
the_pinned_arena->registerForProfiling("Pinned Memory");
363369

370+
#ifdef AMREX_USE_GPU
371+
if (ParallelDescriptor::UseGpuAwareMpi()) {
372+
if (!(the_arena->isDevice())) {
373+
the_comms_arena = the_device_arena;
374+
} else {
375+
the_comms_arena = new CArena(0, ArenaInfo{}.SetDeviceMemory().SetReleaseThreshold
376+
(the_comms_arena_release_threshold));
377+
the_comms_arena->registerForProfiling("Comms Memory");
378+
}
379+
} else {
380+
the_comms_arena = the_pinned_arena;
381+
}
382+
#else
383+
the_comms_arena = The_BArena();
384+
#endif
385+
364386
if (the_device_arena_init_size > 0 && the_device_arena != the_arena) {
365387
BL_PROFILE("The_Device_Arena::Initialize()");
366388
void *p = the_device_arena->alloc(the_device_arena_init_size);
@@ -379,6 +401,13 @@ Arena::Initialize ()
379401
the_pinned_arena->free(p);
380402
}
381403

404+
if (the_comms_arena_init_size > 0 && the_comms_arena != the_arena
405+
&& the_comms_arena != the_device_arena && the_comms_arena != the_pinned_arena) {
406+
BL_PROFILE("The_Comms_Arena::Initialize()");
407+
void *p = the_comms_arena->alloc(the_comms_arena_init_size);
408+
the_comms_arena->free(p);
409+
}
410+
382411
the_cpu_arena = The_BArena();
383412

384413
// Initialize the null arena
@@ -440,6 +469,13 @@ Arena::PrintUsage ()
440469
p->PrintUsage("The Pinned Arena");
441470
}
442471
}
472+
if (The_Comms_Arena() && The_Comms_Arena() != The_Device_Arena()
473+
&& The_Comms_Arena() != The_Pinned_Arena()) {
474+
auto* p = dynamic_cast<CArena*>(The_Comms_Arena());
475+
if (p) {
476+
p->PrintUsage("The Comms Arena");
477+
}
478+
}
443479
}
444480

445481
void
@@ -485,6 +521,13 @@ Arena::PrintUsageToFiles (const std::string& filename, const std::string& messag
485521
p->PrintUsage(ofs, "The Pinned Arena", " ");
486522
}
487523
}
524+
if (The_Comms_Arena() && The_Comms_Arena() != The_Device_Arena()
525+
&& The_Comms_Arena() != The_Pinned_Arena()) {
526+
auto* p = dynamic_cast<CArena*>(The_Comms_Arena());
527+
if (p) {
528+
p->PrintUsage(ofs, "The Comms Arena", " ");
529+
}
530+
}
488531

489532
ofs << "\n";
490533
}
@@ -509,6 +552,13 @@ Arena::Finalize ()
509552
// MultiFab mf(...); // this should be scoped in { ... }
510553
// amrex::Finalize();
511554
// mf cannot be used now, but it can at least be freed without a segfault
555+
if (!dynamic_cast<BArena*>(the_comms_arena)) {
556+
if (the_comms_arena != the_device_arena && the_comms_arena != the_pinned_arena) {
557+
delete the_comms_arena;
558+
}
559+
the_comms_arena = nullptr;
560+
}
561+
512562
if (!dynamic_cast<BArena*>(the_device_arena)) {
513563
if (the_device_arena != the_arena) {
514564
delete the_device_arena;
@@ -600,4 +650,14 @@ The_Cpu_Arena ()
600650
}
601651
}
602652

653+
Arena*
654+
The_Comms_Arena ()
655+
{
656+
if (the_comms_arena) {
657+
return the_comms_arena;
658+
} else {
659+
return The_Null_Arena();
660+
}
603661
}
662+
663+
}

Src/Base/AMReX_FabArray.H

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ struct MFInfo {
144144
struct TheFaArenaDeleter {
145145
using pointer = char*;
146146
void operator()(pointer p) const noexcept {
147-
The_FA_Arena()->free(p);
147+
The_Comms_Arena()->free(p);
148148
}
149149
};
150150
using TheFaArenaPointer = std::unique_ptr<char, TheFaArenaDeleter>;

Src/Base/AMReX_FabArrayBase.H

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -721,8 +721,6 @@ bool CheckRcvStats (Vector<MPI_Status>& recv_stats, const Vector<std::size_t>& r
721721

722722
std::ostream& operator<< (std::ostream& os, const FabArrayBase::BDKey& id);
723723

724-
Arena* The_FA_Arena ();
725-
726724
}
727725

728726
#endif

Src/Base/AMReX_FabArrayBase.cpp

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@ std::vector<std::string> FabArrayBase::m_region_tag;
8888

8989
namespace
9090
{
91-
Arena* the_fa_arena = nullptr;
9291
bool initialized = false;
9392
}
9493

@@ -123,16 +122,6 @@ FabArrayBase::Initialize ()
123122
MaxComp = 1;
124123
}
125124

126-
#ifdef AMREX_USE_GPU
127-
if (ParallelDescriptor::UseGpuAwareMpi()) {
128-
the_fa_arena = The_Device_Arena();
129-
} else {
130-
the_fa_arena = The_Pinned_Arena();
131-
}
132-
#else
133-
the_fa_arena = The_Cpu_Arena();
134-
#endif
135-
136125
amrex::ExecOnFinalize(FabArrayBase::Finalize);
137126

138127
#ifdef AMREX_MEM_PROFILING
@@ -159,12 +148,6 @@ FabArrayBase::Initialize ()
159148
#endif
160149
}
161150

162-
Arena*
163-
The_FA_Arena ()
164-
{
165-
return the_fa_arena;
166-
}
167-
168151
FabArrayBase::FabArrayBase (const BoxArray& bxs,
169152
const DistributionMapping& dm,
170153
int nvar,
@@ -2245,8 +2228,6 @@ FabArrayBase::Finalize ()
22452228

22462229
m_FA_stats = FabArrayStats();
22472230

2248-
the_fa_arena = nullptr;
2249-
22502231
initialized = false;
22512232
}
22522233

Src/Base/AMReX_FabArrayCommI.H

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ FabArray<FAB>::FillBoundary_finish ()
228228

229229
if (fbd->the_recv_data)
230230
{
231-
amrex::The_FA_Arena()->free(fbd->the_recv_data);
231+
amrex::The_Comms_Arena()->free(fbd->the_recv_data);
232232
fbd->the_recv_data = nullptr;
233233
}
234234
}
@@ -237,7 +237,7 @@ FabArray<FAB>::FillBoundary_finish ()
237237
if (N_snds > 0) {
238238
Vector<MPI_Status> stats(fbd->send_reqs.size());
239239
ParallelDescriptor::Waitall(fbd->send_reqs, stats);
240-
amrex::The_FA_Arena()->free(fbd->the_send_data);
240+
amrex::The_Comms_Arena()->free(fbd->the_send_data);
241241
fbd->the_send_data = nullptr;
242242
}
243243

@@ -548,7 +548,7 @@ FabArray<FAB>::ParallelCopy_finish ()
548548

549549
if (pcd->the_recv_data)
550550
{
551-
amrex::The_FA_Arena()->free(pcd->the_recv_data);
551+
amrex::The_Comms_Arena()->free(pcd->the_recv_data);
552552
pcd->the_recv_data = nullptr;
553553
}
554554
}
@@ -558,7 +558,7 @@ FabArray<FAB>::ParallelCopy_finish ()
558558
Vector<MPI_Status> stats(pcd->send_reqs.size());
559559
ParallelDescriptor::Waitall(pcd->send_reqs, stats);
560560
}
561-
amrex::The_FA_Arena()->free(pcd->the_send_data);
561+
amrex::The_Comms_Arena()->free(pcd->the_send_data);
562562
pcd->the_send_data = nullptr;
563563
}
564564

@@ -685,7 +685,7 @@ FabArray<FAB>::PrepareSendBuffers (const MapOfCopyComTagContainers& SndTags,
685685

686686
if (total_volume > 0)
687687
{
688-
the_send_data = static_cast<char*>(amrex::The_FA_Arena()->alloc(total_volume));
688+
the_send_data = static_cast<char*>(amrex::The_Comms_Arena()->alloc(total_volume));
689689
for (int i = 0, N = static_cast<int>(send_size.size()); i < N; ++i) {
690690
send_data[i] = the_send_data + offset[i];
691691
}
@@ -783,7 +783,7 @@ FabArray<FAB>::PostRcvs (const MapOfCopyComTagContainers& RcvTags,
783783
}
784784
else
785785
{
786-
the_recv_data = static_cast<char*>(amrex::The_FA_Arena()->alloc(TotalRcvsVolume));
786+
the_recv_data = static_cast<char*>(amrex::The_Comms_Arena()->alloc(TotalRcvsVolume));
787787

788788
for (int i = 0; i < nrecv; ++i)
789789
{
@@ -1004,7 +1004,7 @@ FillBoundary (Vector<MF*> const& mf, Vector<int> const& scomp,
10041004
recv_size.push_back(nbytes);
10051005
}
10061006
1007-
the_recv_data = static_cast<char*>(amrex::The_FA_Arena()->alloc(TotalRcvsVolume));
1007+
the_recv_data = static_cast<char*>(amrex::The_Comms_Arena()->alloc(TotalRcvsVolume));
10081008
10091009
int k = 0;
10101010
for (int i = 0; i < nrecv; ++i) {
@@ -1077,7 +1077,7 @@ FillBoundary (Vector<MF*> const& mf, Vector<int> const& scomp,
10771077
send_size.push_back(nbytes);
10781078
}
10791079
1080-
the_send_data = static_cast<char*>(amrex::The_FA_Arena()->alloc(TotalSndsVolume));
1080+
the_send_data = static_cast<char*>(amrex::The_Comms_Arena()->alloc(TotalSndsVolume));
10811081
int k = 0;
10821082
for (int i = 0; i < nsend; ++i) {
10831083
send_data[i] = the_send_data + offset[i];
@@ -1113,13 +1113,13 @@ FillBoundary (Vector<MF*> const& mf, Vector<int> const& scomp,
11131113
11141114
detail::fbv_copy(recv_tags);
11151115
1116-
amrex::The_FA_Arena()->free(the_recv_data);
1116+
amrex::The_Comms_Arena()->free(the_recv_data);
11171117
}
11181118
11191119
if (N_snds > 0) {
11201120
Vector<MPI_Status> stats(send_reqs.size());
11211121
ParallelDescriptor::Waitall(send_reqs, stats);
1122-
amrex::The_FA_Arena()->free(the_send_data);
1122+
amrex::The_Comms_Arena()->free(the_send_data);
11231123
}
11241124
11251125
#endif // #ifdef AMREX_USE_MPI

Src/Base/AMReX_NonLocalBC.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ void PrepareCommBuffers(CommData& comm,
9595
}
9696
else
9797
{
98-
comm.the_data.reset(static_cast<char*>(amrex::The_FA_Arena()->alloc(total_volume)));
98+
comm.the_data.reset(static_cast<char*>(amrex::The_Comms_Arena()->alloc(total_volume)));
9999
for (int i = 0; i < N_comms; ++i) {
100100
comm.data[i] = comm.the_data.get() + comm.offset[i];
101101
}

0 commit comments

Comments
 (0)