Skip to content

Commit 4d7cc96

Browse files
committed
Merge branch 'atmosphere/acc_atm_get_bdy_states' into develop (PR #1279)
This merge adds OpenACC directives and data movement to the mpas_atm_get_bdy_state_2d and mpas_atm_get_bdy_state_3d functions. Timing information for the OpenACC data transfers in these routines is captured in the log file by two new timers:
- mpas_atm_get_bdy_state_2d [ACC_data_xfer]
- mpas_atm_get_bdy_state_3d [ACC_data_xfer]

Commits merged from atmosphere/acc_atm_get_bdy_states:
* Enforce correct OpenACC data in mpas_atm_get_bdy_state_{2d,3d} functions
* Add OpenACC data movement to mpas_atm_get_bdy_state_{2d,3d} functions
* Initial OpenACC port of the mpas_atm_get_bdy_state_{2d,3d} functions
2 parents d1a4a94 + 947343f commit 4d7cc96

File tree

1 file changed

+71
-5
lines changed

1 file changed

+71
-5
lines changed

src/core_atmosphere/dynamics/mpas_atm_boundaries.F

+71-5
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,15 @@
77
!
88
module mpas_atm_boundaries
99

10+
#ifdef MPAS_OPENACC
11+
use mpas_timer, only: mpas_timer_start, mpas_timer_stop
12+
#define MPAS_ACC_TIMER_START(X) call mpas_timer_start(X)
13+
#define MPAS_ACC_TIMER_STOP(X) call mpas_timer_stop(X)
14+
#else
15+
#define MPAS_ACC_TIMER_START(X)
16+
#define MPAS_ACC_TIMER_STOP(X)
17+
#endif
18+
1019
use mpas_derived_types, only : mpas_pool_type, mpas_clock_type, block_type, mpas_time_type, mpas_timeInterval_type, MPAS_NOW, &
1120
MPAS_STREAM_LATEST_BEFORE, MPAS_STREAM_EARLIEST_STRICTLY_AFTER, &
1221
MPAS_streamManager_type
@@ -363,7 +372,7 @@ function mpas_atm_get_bdy_state_2d(clock, block, vertDim, horizDim, field, delta
363372
real (kind=RKIND), dimension(vertDim,horizDim+1) :: return_state
364373

365374
type (mpas_pool_type), pointer :: lbc
366-
integer, pointer :: idx
375+
integer, pointer :: idx_ptr
367376
real (kind=RKIND), dimension(:,:), pointer :: tend
368377
real (kind=RKIND), dimension(:,:), pointer :: state
369378
real (kind=RKIND), dimension(:,:,:), pointer :: tend_scalars
@@ -374,6 +383,7 @@ function mpas_atm_get_bdy_state_2d(clock, block, vertDim, horizDim, field, delta
374383
real (kind=RKIND) :: dt
375384
integer :: err_level
376385
integer :: ierr
386+
integer :: i,j,idx
377387

378388

379389
currTime = mpas_get_clock_time(clock, MPAS_NOW, ierr)
@@ -410,13 +420,49 @@ function mpas_atm_get_bdy_state_2d(clock, block, vertDim, horizDim, field, delta
410420
! query the field as a scalar constituent
411421
!
412422
if (associated(tend) .and. associated(state)) then
413-
return_state(:,:) = state(:,:) - dt * tend(:,:)
423+
MPAS_ACC_TIMER_START('mpas_atm_get_bdy_state_2d [ACC_data_xfer]')
424+
!$acc enter data create(return_state) &
425+
!$acc copyin(tend, state)
426+
MPAS_ACC_TIMER_STOP('mpas_atm_get_bdy_state_2d [ACC_data_xfer]')
427+
428+
!$acc parallel default(present)
429+
!$acc loop gang vector collapse(2)
430+
do i=1, horizDim+1
431+
do j=1, vertDim
432+
return_state(j,i) = state(j,i) - dt * tend(j,i)
433+
end do
434+
end do
435+
!$acc end parallel
436+
437+
MPAS_ACC_TIMER_START('mpas_atm_get_bdy_state_2d [ACC_data_xfer]')
438+
!$acc exit data copyout(return_state) &
439+
!$acc delete(tend, state)
440+
MPAS_ACC_TIMER_STOP('mpas_atm_get_bdy_state_2d [ACC_data_xfer]')
414441
else
415442
call mpas_pool_get_array(lbc, 'lbc_scalars', tend_scalars, 1)
416443
call mpas_pool_get_array(lbc, 'lbc_scalars', state_scalars, 2)
417-
call mpas_pool_get_dimension(lbc, 'index_'//trim(field), idx)
444+
call mpas_pool_get_dimension(lbc, 'index_'//trim(field), idx_ptr)
445+
446+
idx=idx_ptr ! Avoid non-array pointer for OpenACC
418447

419-
return_state(:,:) = state_scalars(idx,:,:) - dt * tend_scalars(idx,:,:)
448+
MPAS_ACC_TIMER_START('mpas_atm_get_bdy_state_2d [ACC_data_xfer]')
449+
!$acc enter data create(return_state) &
450+
!$acc copyin(tend_scalars, state_scalars)
451+
MPAS_ACC_TIMER_STOP('mpas_atm_get_bdy_state_2d [ACC_data_xfer]')
452+
453+
!$acc parallel default(present)
454+
!$acc loop gang vector collapse(2)
455+
do i=1, horizDim+1
456+
do j=1, vertDim
457+
return_state(j,i) = state_scalars(idx,j,i) - dt * tend_scalars(idx,j,i)
458+
end do
459+
end do
460+
!$acc end parallel
461+
462+
MPAS_ACC_TIMER_START('mpas_atm_get_bdy_state_2d [ACC_data_xfer]')
463+
!$acc exit data copyout(return_state) &
464+
!$acc delete(tend_scalars, state_scalars)
465+
MPAS_ACC_TIMER_STOP('mpas_atm_get_bdy_state_2d [ACC_data_xfer]')
420466
end if
421467

422468
end function mpas_atm_get_bdy_state_2d
@@ -476,6 +522,7 @@ function mpas_atm_get_bdy_state_3d(clock, block, innerDim, vertDim, horizDim, fi
476522
real (kind=RKIND) :: dt
477523
integer :: err_level
478524
integer :: ierr
525+
integer :: i,j,k
479526

480527

481528
currTime = mpas_get_clock_time(clock, MPAS_NOW, ierr)
@@ -496,7 +543,26 @@ function mpas_atm_get_bdy_state_3d(clock, block, innerDim, vertDim, horizDim, fi
496543
call mpas_pool_get_array(lbc, 'lbc_'//trim(field), tend, 1)
497544
call mpas_pool_get_array(lbc, 'lbc_'//trim(field), state, 2)
498545

499-
return_state(:,:,:) = state(:,:,:) - dt * tend(:,:,:)
546+
MPAS_ACC_TIMER_START('mpas_atm_get_bdy_state_3d [ACC_data_xfer]')
547+
!$acc enter data create(return_state) &
548+
!$acc copyin(tend, state)
549+
MPAS_ACC_TIMER_STOP('mpas_atm_get_bdy_state_3d [ACC_data_xfer]')
550+
551+
!$acc parallel default(present)
552+
!$acc loop gang vector collapse(3)
553+
do i=1, horizDim+1
554+
do j=1, vertDim
555+
do k=1, innerDim
556+
return_state(k,j,i) = state(k,j,i) - dt * tend(k,j,i)
557+
end do
558+
end do
559+
end do
560+
!$acc end parallel
561+
562+
MPAS_ACC_TIMER_START('mpas_atm_get_bdy_state_3d [ACC_data_xfer]')
563+
!$acc exit data copyout(return_state) &
564+
!$acc delete(tend, state)
565+
MPAS_ACC_TIMER_STOP('mpas_atm_get_bdy_state_3d [ACC_data_xfer]')
500566

501567
end function mpas_atm_get_bdy_state_3d
502568

0 commit comments

Comments
 (0)