7
7
!
8
8
module mpas_atm_boundaries
9
9
10
! Timer macros used around the OpenACC data-transfer directives in this module.
! When MPAS_OPENACC is defined, MPAS_ACC_TIMER_START(X) and MPAS_ACC_TIMER_STOP(X)
! expand to calls to mpas_timer_start(X) and mpas_timer_stop(X), which are imported
! from mpas_timer inside the same conditional. When MPAS_OPENACC is not defined,
! both macros expand to nothing, so each use of a macro leaves no statement behind.
+ #ifdef MPAS_OPENACC
11
+ use mpas_timer, only: mpas_timer_start, mpas_timer_stop
12
+ #define MPAS_ACC_TIMER_START(X) call mpas_timer_start(X)
13
+ #define MPAS_ACC_TIMER_STOP(X) call mpas_timer_stop(X)
14
+ #else
15
+ #define MPAS_ACC_TIMER_START(X)
16
+ #define MPAS_ACC_TIMER_STOP(X)
17
+ #endif
18
+
10
19
use mpas_derived_types, only : mpas_pool_type, mpas_clock_type, block_type, mpas_time_type, mpas_timeInterval_type, MPAS_NOW, &
11
20
MPAS_STREAM_LATEST_BEFORE, MPAS_STREAM_EARLIEST_STRICTLY_AFTER, &
12
21
MPAS_streamManager_type
@@ -363,7 +372,7 @@ function mpas_atm_get_bdy_state_2d(clock, block, vertDim, horizDim, field, delta
363
372
real (kind= RKIND), dimension (vertDim,horizDim+1 ) :: return_state
364
373
365
374
type (mpas_pool_type), pointer :: lbc
366
- integer , pointer :: idx
375
+ integer , pointer :: idx_ptr
367
376
real (kind= RKIND), dimension (:,:), pointer :: tend
368
377
real (kind= RKIND), dimension (:,:), pointer :: state
369
378
real (kind= RKIND), dimension (:,:,:), pointer :: tend_scalars
@@ -374,6 +383,7 @@ function mpas_atm_get_bdy_state_2d(clock, block, vertDim, horizDim, field, delta
374
383
real (kind= RKIND) :: dt
375
384
integer :: err_level
376
385
integer :: ierr
386
+ integer :: i,j,idx
377
387
378
388
379
389
currTime = mpas_get_clock_time(clock, MPAS_NOW, ierr)
@@ -410,13 +420,49 @@ function mpas_atm_get_bdy_state_2d(clock, block, vertDim, horizDim, field, delta
410
420
! query the field as a scalar constituent
411
421
!
412
422
if (associated(tend) .and. associated(state)) then
413
- return_state(:,:) = state(:,:) - dt * tend(:,:)
423
+ MPAS_ACC_TIMER_START(' mpas_atm_get_bdy_state_2d [ACC_data_xfer]' )
424
+ !$acc enter data create(return_state) &
425
+ !$acc copyin(tend, state)
426
+ MPAS_ACC_TIMER_STOP(' mpas_atm_get_bdy_state_2d [ACC_data_xfer]' )
427
+
428
+ !$acc parallel default(present)
429
+ !$acc loop gang vector collapse(2 )
430
+ do i= 1 , horizDim+1
431
+ do j= 1 , vertDim
432
+ return_state(j,i) = state(j,i) - dt * tend(j,i)
433
+ end do
434
+ end do
435
+ !$acc end parallel
436
+
437
+ MPAS_ACC_TIMER_START(' mpas_atm_get_bdy_state_2d [ACC_data_xfer]' )
438
+ !$acc exit data copyout(return_state) &
439
+ !$acc delete(tend, state)
440
+ MPAS_ACC_TIMER_STOP(' mpas_atm_get_bdy_state_2d [ACC_data_xfer]' )
414
441
else
415
442
call mpas_pool_get_array(lbc, ' lbc_scalars' , tend_scalars, 1 )
416
443
call mpas_pool_get_array(lbc, ' lbc_scalars' , state_scalars, 2 )
417
- call mpas_pool_get_dimension(lbc, ' index_' // trim (field), idx)
444
+ call mpas_pool_get_dimension(lbc, ' index_' // trim (field), idx_ptr)
445
+
446
+ idx= idx_ptr ! Avoid non- array pointer for OpenACC
418
447
419
- return_state(:,:) = state_scalars(idx,:,:) - dt * tend_scalars(idx,:,:)
448
+ MPAS_ACC_TIMER_START(' mpas_atm_get_bdy_state_2d [ACC_data_xfer]' )
449
+ !$acc enter data create(return_state) &
450
+ !$acc copyin(tend_scalars, state_scalars)
451
+ MPAS_ACC_TIMER_STOP(' mpas_atm_get_bdy_state_2d [ACC_data_xfer]' )
452
+
453
+ !$acc parallel default(present)
454
+ !$acc loop gang vector collapse(2 )
455
+ do i= 1 , horizDim+1
456
+ do j= 1 , vertDim
457
+ return_state(j,i) = state_scalars(idx,j,i) - dt * tend_scalars(idx,j,i)
458
+ end do
459
+ end do
460
+ !$acc end parallel
461
+
462
+ MPAS_ACC_TIMER_START(' mpas_atm_get_bdy_state_2d [ACC_data_xfer]' )
463
+ !$acc exit data copyout(return_state) &
464
+ !$acc delete(tend_scalars, state_scalars)
465
+ MPAS_ACC_TIMER_STOP(' mpas_atm_get_bdy_state_2d [ACC_data_xfer]' )
420
466
end if
421
467
422
468
end function mpas_atm_get_bdy_state_2d
@@ -476,6 +522,7 @@ function mpas_atm_get_bdy_state_3d(clock, block, innerDim, vertDim, horizDim, fi
476
522
real (kind= RKIND) :: dt
477
523
integer :: err_level
478
524
integer :: ierr
525
+ integer :: i,j,k
479
526
480
527
481
528
currTime = mpas_get_clock_time(clock, MPAS_NOW, ierr)
@@ -496,7 +543,26 @@ function mpas_atm_get_bdy_state_3d(clock, block, innerDim, vertDim, horizDim, fi
496
543
call mpas_pool_get_array(lbc, ' lbc_' // trim (field), tend, 1 )
497
544
call mpas_pool_get_array(lbc, ' lbc_' // trim (field), state, 2 )
498
545
499
- return_state(:,:,:) = state(:,:,:) - dt * tend(:,:,:)
546
+ MPAS_ACC_TIMER_START(' mpas_atm_get_bdy_state_3d [ACC_data_xfer]' )
547
+ !$acc enter data create(return_state) &
548
+ !$acc copyin(tend, state)
549
+ MPAS_ACC_TIMER_STOP(' mpas_atm_get_bdy_state_3d [ACC_data_xfer]' )
550
+
551
+ !$acc parallel default(present)
552
+ !$acc loop gang vector collapse(3 )
553
+ do i= 1 , horizDim+1
554
+ do j= 1 , vertDim
555
+ do k= 1 , innerDim
556
+ return_state(k,j,i) = state(k,j,i) - dt * tend(k,j,i)
557
+ end do
558
+ end do
559
+ end do
560
+ !$acc end parallel
561
+
562
+ MPAS_ACC_TIMER_START(' mpas_atm_get_bdy_state_3d [ACC_data_xfer]' )
563
+ !$acc exit data copyout(return_state) &
564
+ !$acc delete(tend, state)
565
+ MPAS_ACC_TIMER_STOP(' mpas_atm_get_bdy_state_3d [ACC_data_xfer]' )
500
566
501
567
end function mpas_atm_get_bdy_state_3d
502
568
0 commit comments