Skip to content

Commit 335f529

Browse files
WeiqunZhangatmyers
andauthored
Delay gpu stream sync in MFIter (#4493)
Previously we performed gpu stream sync during MFIter::Initialize in case that kernels on non-default streams inside MFIter loop might depend on the result of kernels on the default amrex stream launched before the MFIter loop. Now we delay the sync if needed until the end of the first iteration. --------- Co-authored-by: Andrew Myers <[email protected]>
1 parent 2d47d39 commit 335f529

File tree

1 file changed

+16
-4
lines changed

1 file changed

+16
-4
lines changed

Src/Base/AMReX_MFIter.cpp

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -282,11 +282,12 @@ MFIter::Initialize ()
282282
"Nested or multiple active MFIters is not supported by default. This can be changed by calling MFIter::allowMultipleMFIters(true)".);
283283
}
284284

285-
#ifdef AMREX_USE_GPU
286-
if (device_sync) {
287-
#ifdef AMREX_USE_OMP
285+
#if defined(AMREX_USE_GPU) && defined(AMREX_USE_OMP)
286+
if (Gpu::inLaunchRegion() && device_sync && (streams > 1)
287+
&& (OpenMP::get_num_threads() > 1))
288+
{ // If there are multiple gpu streams and multiple omp threads, we need
289+
// to sync here. Otherwise, the sync will be delayed.
288290
#pragma omp single
289-
#endif
290291
Gpu::streamSynchronize();
291292
}
292293
#endif
@@ -534,6 +535,17 @@ MFIter::operator++ () noexcept
534535

535536
#ifdef AMREX_USE_GPU
536537
if (Gpu::inLaunchRegion()) {
538+
if (device_sync && (streams > 1) && (OpenMP::get_num_threads() == 1)
539+
&& (currentIndex == 1) && isValid())
540+
{
541+
// Because omp num threads is 1, gpu stream sync has not
542+
// been called in Initialize. We need to sync stream 0
543+
// before launching kernels on stream 1, because the user
544+
// might have launched kernels (such as memcpyAsync) on
545+
// stream 0 before this MFIter and the kernels on stream 1
546+
// might depend on it.
547+
Gpu::streamSynchronize();
548+
}
537549
Gpu::Device::setStreamIndex(currentIndex%streams);
538550
AMREX_GPU_ERROR_CHECK();
539551
#ifdef AMREX_DEBUG

0 commit comments

Comments
 (0)