Skip to content

Commit 71957e9

Browse files
authored
feat: use MPI collectives for mpp_scatter and mpp_gather (#1655)
1 parent c062d8c commit 71957e9

File tree

11 files changed

+423
-194
lines changed

11 files changed

+423
-194
lines changed

mpp/include/mpp_comm.inc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,8 @@
384384
#undef MPP_GATHER_PELIST_3D_
385385
#define MPP_GATHER_PELIST_2D_ mpp_gather_pelist_logical_2d
386386
#define MPP_GATHER_PELIST_3D_ mpp_gather_pelist_logical_3d
387+
#undef MPI_TYPE_
388+
#define MPI_TYPE_ MPI_LOGICAL
387389
#include <mpp_gather.fh>
388390
389391
#undef MPP_GATHER_1D_
@@ -396,6 +398,8 @@
396398
#undef MPP_GATHER_PELIST_3D_
397399
#define MPP_GATHER_PELIST_2D_ mpp_gather_pelist_int4_2d
398400
#define MPP_GATHER_PELIST_3D_ mpp_gather_pelist_int4_3d
401+
#undef MPI_TYPE_
402+
#define MPI_TYPE_ MPI_INTEGER4
399403
#include <mpp_gather.fh>
400404
401405
@@ -409,6 +413,8 @@
409413
#undef MPP_GATHER_PELIST_3D_
410414
#define MPP_GATHER_PELIST_2D_ mpp_gather_pelist_int8_2d
411415
#define MPP_GATHER_PELIST_3D_ mpp_gather_pelist_int8_3d
416+
#undef MPI_TYPE_
417+
#define MPI_TYPE_ MPI_INTEGER8
412418
#include <mpp_gather.fh>
413419
414420
@@ -422,6 +428,8 @@
422428
#undef MPP_GATHER_PELIST_3D_
423429
#define MPP_GATHER_PELIST_2D_ mpp_gather_pelist_real4_2d
424430
#define MPP_GATHER_PELIST_3D_ mpp_gather_pelist_real4_3d
431+
#undef MPI_TYPE_
432+
#define MPI_TYPE_ MPI_REAL4
425433
#include <mpp_gather.fh>
426434
427435
#undef MPP_GATHER_1D_
@@ -434,6 +442,8 @@
434442
#undef MPP_GATHER_PELIST_3D_
435443
#define MPP_GATHER_PELIST_2D_ mpp_gather_pelist_real8_2d
436444
#define MPP_GATHER_PELIST_3D_ mpp_gather_pelist_real8_3d
445+
#undef MPI_TYPE_
446+
#define MPI_TYPE_ MPI_REAL8
437447
#include <mpp_gather.fh>
438448
439449
!#################################################
@@ -443,6 +453,8 @@
443453
#define MPP_TYPE_ integer(i4_kind)
444454
#define MPP_SCATTER_PELIST_2D_ mpp_scatter_pelist_int4_2d
445455
#define MPP_SCATTER_PELIST_3D_ mpp_scatter_pelist_int4_3d
456+
#undef MPI_TYPE_
457+
#define MPI_TYPE_ MPI_INTEGER4
446458
#include <mpp_scatter.fh>
447459
448460
#undef MPP_SCATTER_PELIST_2D_
@@ -451,6 +463,8 @@
451463
#define MPP_TYPE_ integer(i8_kind)
452464
#define MPP_SCATTER_PELIST_2D_ mpp_scatter_pelist_int8_2d
453465
#define MPP_SCATTER_PELIST_3D_ mpp_scatter_pelist_int8_3d
466+
#undef MPI_TYPE_
467+
#define MPI_TYPE_ MPI_INTEGER8
454468
#include <mpp_scatter.fh>
455469
456470
#undef MPP_SCATTER_PELIST_2D_
@@ -459,6 +473,8 @@
459473
#define MPP_TYPE_ real(r4_kind)
460474
#define MPP_SCATTER_PELIST_2D_ mpp_scatter_pelist_real4_2d
461475
#define MPP_SCATTER_PELIST_3D_ mpp_scatter_pelist_real4_3d
476+
#undef MPI_TYPE_
477+
#define MPI_TYPE_ MPI_REAL4
462478
#include <mpp_scatter.fh>
463479
464480
#undef MPP_SCATTER_PELIST_2D_
@@ -467,5 +483,7 @@
467483
#define MPP_TYPE_ real(r8_kind)
468484
#define MPP_SCATTER_PELIST_2D_ mpp_scatter_pelist_real8_2d
469485
#define MPP_SCATTER_PELIST_3D_ mpp_scatter_pelist_real8_3d
486+
#undef MPI_TYPE_
487+
#define MPI_TYPE_ MPI_REAL8
470488
#include <mpp_scatter.fh>
471489
!> @}

mpp/include/mpp_comm_mpi.inc

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,12 +411,21 @@ end subroutine mpp_exit
411411
#define MPP_BROADCAST_4D_ mpp_broadcast_real8_4d
412412
#undef MPP_BROADCAST_5D_
413413
#define MPP_BROADCAST_5D_ mpp_broadcast_real8_5d
414+
#undef MPP_SCATTERV_
415+
#define MPP_SCATTERV_ mpp_scatterv_real8
416+
#undef MPP_GATHER_
417+
#define MPP_GATHER_ mpp_gather_real8
418+
#undef MPP_GATHERV_
419+
#define MPP_GATHERV_ mpp_gatherv_real8
414420
#undef MPP_TYPE_
415421
#define MPP_TYPE_ real(r8_kind)
416422
#undef MPP_TYPE_BYTELEN_
417423
#define MPP_TYPE_BYTELEN_ 8
418424
#undef MPI_TYPE_
419425
#define MPI_TYPE_ MPI_REAL8
426+
427+
428+
420429
#include <mpp_transmit_mpi.fh>
421430

422431
#ifdef OVERLOAD_C8
@@ -468,6 +477,12 @@ end subroutine mpp_exit
468477
#define MPP_BROADCAST_4D_ mpp_broadcast_cmplx8_4d
469478
#undef MPP_BROADCAST_5D_
470479
#define MPP_BROADCAST_5D_ mpp_broadcast_cmplx8_5d
480+
#undef MPP_SCATTERV_
481+
#define MPP_SCATTERV_ mpp_scatterv_complx8
482+
#undef MPP_GATHER_
483+
#define MPP_GATHER_ mpp_gather_complx8
484+
#undef MPP_GATHERV_
485+
#define MPP_GATHERV_ mpp_gatherv_complx8
471486
#undef MPP_TYPE_
472487
#define MPP_TYPE_ complex(c8_kind)
473488
#undef MPP_TYPE_BYTELEN_
@@ -525,6 +540,12 @@ end subroutine mpp_exit
525540
#define MPP_BROADCAST_4D_ mpp_broadcast_real4_4d
526541
#undef MPP_BROADCAST_5D_
527542
#define MPP_BROADCAST_5D_ mpp_broadcast_real4_5d
543+
#undef MPP_SCATTERV_
544+
#define MPP_SCATTERV_ mpp_scatterv_real4
545+
#undef MPP_GATHER_
546+
#define MPP_GATHER_ mpp_gather_real4
547+
#undef MPP_GATHERV_
548+
#define MPP_GATHERV_ mpp_gatherv_real4
528549
#undef MPP_TYPE_
529550
#define MPP_TYPE_ real(r4_kind)
530551
#undef MPP_TYPE_BYTELEN_
@@ -582,6 +603,12 @@ end subroutine mpp_exit
582603
#define MPP_BROADCAST_4D_ mpp_broadcast_cmplx4_4d
583604
#undef MPP_BROADCAST_5D_
584605
#define MPP_BROADCAST_5D_ mpp_broadcast_cmplx4_5d
606+
#undef MPP_SCATTERV_
607+
#define MPP_SCATTERV_ mpp_scatterv_cmplx4
608+
#undef MPP_GATHER_
609+
#define MPP_GATHER_ mpp_gather_cmplx4
610+
#undef MPP_GATHERV_
611+
#define MPP_GATHERV_ mpp_gatherv_cmplx4
585612
#undef MPP_TYPE_
586613
#define MPP_TYPE_ complex(c4_kind)
587614
#undef MPP_TYPE_BYTELEN_
@@ -641,6 +668,12 @@ end subroutine mpp_exit
641668
#define MPP_BROADCAST_4D_ mpp_broadcast_int8_4d
642669
#undef MPP_BROADCAST_5D_
643670
#define MPP_BROADCAST_5D_ mpp_broadcast_int8_5d
671+
#undef MPP_SCATTERV_
672+
#define MPP_SCATTERV_ mpp_scatterv_int8
673+
#undef MPP_GATHER_
674+
#define MPP_GATHER_ mpp_gather_int8
675+
#undef MPP_GATHERV_
676+
#define MPP_GATHERV_ mpp_gatherv_int8
644677
#undef MPP_TYPE_
645678
#define MPP_TYPE_ integer(i8_kind)
646679
#undef MPP_TYPE_BYTELEN_
@@ -697,6 +730,12 @@ end subroutine mpp_exit
697730
#define MPP_BROADCAST_4D_ mpp_broadcast_int4_4d
698731
#undef MPP_BROADCAST_5D_
699732
#define MPP_BROADCAST_5D_ mpp_broadcast_int4_5d
733+
#undef MPP_SCATTERV_
734+
#define MPP_SCATTERV_ mpp_scatterv_int4
735+
#undef MPP_GATHER_
736+
#define MPP_GATHER_ mpp_gather_int4
737+
#undef MPP_GATHERV_
738+
#define MPP_GATHERV_ mpp_gatherv_int4
700739
#undef MPP_TYPE_
701740
#define MPP_TYPE_ integer(i4_kind)
702741
#undef MPP_TYPE_BYTELEN_
@@ -755,6 +794,12 @@ end subroutine mpp_exit
755794
#define MPP_BROADCAST_4D_ mpp_broadcast_logical8_4d
756795
#undef MPP_BROADCAST_5D_
757796
#define MPP_BROADCAST_5D_ mpp_broadcast_logical8_5d
797+
#undef MPP_SCATTERV_
798+
#define MPP_SCATTERV_ mpp_scatterv_logical8
799+
#undef MPP_GATHER_
800+
#define MPP_GATHER_ mpp_gather_logical8
801+
#undef MPP_GATHERV_
802+
#define MPP_GATHERV_ mpp_gatherv_logical8
758803
#undef MPP_TYPE_
759804
#define MPP_TYPE_ logical(l8_kind)
760805
#undef MPP_TYPE_BYTELEN_
@@ -811,6 +856,12 @@ end subroutine mpp_exit
811856
#define MPP_BROADCAST_4D_ mpp_broadcast_logical4_4d
812857
#undef MPP_BROADCAST_5D_
813858
#define MPP_BROADCAST_5D_ mpp_broadcast_logical4_5d
859+
#undef MPP_SCATTERV_
860+
#define MPP_SCATTERV_ mpp_scatterv_logical4
861+
#undef MPP_GATHER_
862+
#define MPP_GATHER_ mpp_gather_logical4
863+
#undef MPP_GATHERV_
864+
#define MPP_GATHERV_ mpp_gatherv_logical4
814865
#undef MPP_TYPE_
815866
#define MPP_TYPE_ logical(l4_kind)
816867
#undef MPP_TYPE_BYTELEN_

mpp/include/mpp_comm_nocomm.inc

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,12 @@ end subroutine mpp_exit
311311
#define MPP_BROADCAST_4D_ mpp_broadcast_real8_4d
312312
#undef MPP_BROADCAST_5D_
313313
#define MPP_BROADCAST_5D_ mpp_broadcast_real8_5d
314+
#undef MPP_SCATTERV_
315+
#define MPP_SCATTERV_ mpp_scatterv_real8
316+
#undef MPP_GATHER_
317+
#define MPP_GATHER_ mpp_gather_real8
318+
#undef MPP_GATHERV_
319+
#define MPP_GATHERV_ mpp_gatherv_real8
314320
#undef MPP_TYPE_
315321
#define MPP_TYPE_ real(r8_kind)
316322
#undef MPP_TYPE_BYTELEN_
@@ -368,6 +374,12 @@ end subroutine mpp_exit
368374
#define MPP_BROADCAST_4D_ mpp_broadcast_cmplx8_4d
369375
#undef MPP_BROADCAST_5D_
370376
#define MPP_BROADCAST_5D_ mpp_broadcast_cmplx8_5d
377+
#undef MPP_SCATTERV_
378+
#define MPP_SCATTERV_ mpp_scatterv_cmplx8
379+
#undef MPP_GATHER_
380+
#define MPP_GATHER_ mpp_gather_cmplx8
381+
#undef MPP_GATHERV_
382+
#define MPP_GATHERV_ mpp_gatherv_cmplx8
371383
#undef MPP_TYPE_
372384
#define MPP_TYPE_ complex(c8_kind)
373385
#undef MPP_TYPE_BYTELEN_
@@ -425,6 +437,12 @@ end subroutine mpp_exit
425437
#define MPP_BROADCAST_4D_ mpp_broadcast_real4_4d
426438
#undef MPP_BROADCAST_5D_
427439
#define MPP_BROADCAST_5D_ mpp_broadcast_real4_5d
440+
#undef MPP_SCATTERV_
441+
#define MPP_SCATTERV_ mpp_scatterv_real4
442+
#undef MPP_GATHER_
443+
#define MPP_GATHER_ mpp_gather_real4
444+
#undef MPP_GATHERV_
445+
#define MPP_GATHERV_ mpp_gatherv_real4
428446
#undef MPP_TYPE_
429447
#define MPP_TYPE_ real(r4_kind)
430448
#undef MPP_TYPE_BYTELEN_
@@ -482,6 +500,12 @@ end subroutine mpp_exit
482500
#define MPP_BROADCAST_4D_ mpp_broadcast_cmplx4_4d
483501
#undef MPP_BROADCAST_5D_
484502
#define MPP_BROADCAST_5D_ mpp_broadcast_cmplx4_5d
503+
#undef MPP_SCATTERV_
504+
#define MPP_SCATTERV_ mpp_scatterv_cmplx4
505+
#undef MPP_GATHER_
506+
#define MPP_GATHER_ mpp_gather_cmplx4
507+
#undef MPP_GATHERV_
508+
#define MPP_GATHERV_ mpp_gatherv_cmplx4
485509
#undef MPP_TYPE_
486510
#define MPP_TYPE_ complex(c4_kind)
487511
#undef MPP_TYPE_BYTELEN_
@@ -541,6 +565,12 @@ end subroutine mpp_exit
541565
#define MPP_BROADCAST_4D_ mpp_broadcast_int8_4d
542566
#undef MPP_BROADCAST_5D_
543567
#define MPP_BROADCAST_5D_ mpp_broadcast_int8_5d
568+
#undef MPP_SCATTERV_
569+
#define MPP_SCATTERV_ mpp_scatterv_int8
570+
#undef MPP_GATHER_
571+
#define MPP_GATHER_ mpp_gather_int8
572+
#undef MPP_GATHERV_
573+
#define MPP_GATHERV_ mpp_gatherv_int8
544574
#undef MPP_TYPE_
545575
#define MPP_TYPE_ integer(i8_kind)
546576
#undef MPP_TYPE_BYTELEN_
@@ -597,6 +627,12 @@ end subroutine mpp_exit
597627
#define MPP_BROADCAST_4D_ mpp_broadcast_int4_4d
598628
#undef MPP_BROADCAST_5D_
599629
#define MPP_BROADCAST_5D_ mpp_broadcast_int4_5d
630+
#undef MPP_SCATTERV_
631+
#define MPP_SCATTERV_ mpp_scatterv_int4
632+
#undef MPP_GATHER_
633+
#define MPP_GATHER_ mpp_gather_int4
634+
#undef MPP_GATHERV_
635+
#define MPP_GATHERV_ mpp_gatherv_int4
600636
#undef MPP_TYPE_
601637
#define MPP_TYPE_ integer(i4_kind)
602638
#undef MPP_TYPE_BYTELEN_
@@ -655,6 +691,12 @@ end subroutine mpp_exit
655691
#define MPP_BROADCAST_4D_ mpp_broadcast_logical8_4d
656692
#undef MPP_BROADCAST_5D_
657693
#define MPP_BROADCAST_5D_ mpp_broadcast_logical8_5d
694+
#undef MPP_SCATTERV_
695+
#define MPP_SCATTERV_ mpp_scatterv_logical8
696+
#undef MPP_GATHER_
697+
#define MPP_GATHER_ mpp_gather_logical8
698+
#undef MPP_GATHERV_
699+
#define MPP_GATHERV_ mpp_gatherv_logical8
658700
#undef MPP_TYPE_
659701
#define MPP_TYPE_ logical(l8_kind)
660702
#undef MPP_TYPE_BYTELEN_
@@ -711,6 +753,12 @@ end subroutine mpp_exit
711753
#define MPP_BROADCAST_4D_ mpp_broadcast_logical4_4d
712754
#undef MPP_BROADCAST_5D_
713755
#define MPP_BROADCAST_5D_ mpp_broadcast_logical4_5d
756+
#undef MPP_SCATTERV_
757+
#define MPP_SCATTERV_ mpp_scatterv_logical4
758+
#undef MPP_GATHER_
759+
#define MPP_GATHER_ mpp_gather_logical4
760+
#undef MPP_GATHERV_
761+
#define MPP_GATHERV_ mpp_gatherv_logical4
714762
#undef MPP_TYPE_
715763
#define MPP_TYPE_ logical(l4_kind)
716764
#undef MPP_TYPE_BYTELEN_

mpp/include/mpp_gather.fh

Lines changed: 19 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,21 @@
2020

2121
!> @addtogroup mpp_mod
2222
!> @{
23-
subroutine MPP_GATHER_1D_(sbuf, rbuf,pelist)
23+
subroutine MPP_GATHER_1D_(sbuf, rbuf, pelist)
2424
! JWD: Did not create mpp_gather_2d because have no requirement for it
2525
! JWD: See mpp_gather_2dv below
2626
MPP_TYPE_, dimension(:), intent(in) :: sbuf
2727
MPP_TYPE_, dimension(:), intent(inout) :: rbuf
2828
integer, dimension(:), intent(in), optional :: pelist(:)
2929

30-
integer :: cnt, l, nproc, op_root
30+
integer :: cnt, l, nproc, op_root, ierr
3131
integer, allocatable :: pelist2(:)
3232

33+
if( .NOT.module_is_initialized ) call mpp_error( FATAL, 'MPP_GATHER_1D_: You must first call mpp_init.' )
3334

34-
! If pelist is provided, the first position must be
35-
! the operation root
35+
! If pelist is provided, the first position must be the operation root, w.r.t. new comm, op_root = 0
3636
if(PRESENT(pelist))then
37+
if(.not.ANY(mpp_pe().eq.pelist(:))) return
3738
nproc = size(pelist)
3839
allocate(pelist2(nproc))
3940
pelist2 = pelist
@@ -48,17 +49,8 @@ subroutine MPP_GATHER_1D_(sbuf, rbuf,pelist)
4849
if(size(rbuf(:)) < cnt*nproc) call mpp_error(FATAL, &
4950
"MPP_GATHER_1D_: size(rbuf) must be at least npes*size(sbuf) ")
5051

51-
!--- pre-post receiving
52-
if(pe == op_root) then
53-
rbuf(1:cnt) = sbuf
54-
do l = 2, nproc
55-
call mpp_recv(rbuf((l-1)*cnt+1), glen=cnt, from_pe=pelist2(l), block=.FALSE., tag=COMM_TAG_1 )
56-
enddo
57-
else
58-
call mpp_send(sbuf(1), plen=cnt, to_pe=op_root, tag=COMM_TAG_1)
59-
endif
52+
call mpp_gather( sbuf, rbuf, size(sbuf), op_root, pelist2, ierr )
6053

61-
call mpp_sync_self(check=EVENT_RECV)
6254
call mpp_sync_self()
6355
deallocate(pelist2)
6456
end subroutine MPP_GATHER_1D_
@@ -70,8 +62,9 @@ subroutine MPP_GATHER_1DV_(sbuf, ssize, rbuf, rsize, pelist)
7062
integer, dimension(:), intent(in) :: rsize
7163
integer, dimension(:), intent(in), optional :: pelist(:)
7264

73-
integer :: l, nproc, pos, op_root
74-
integer, allocatable :: pelist2(:)
65+
integer :: l, nproc, op_root, ierr
66+
integer, dimension(:), allocatable :: displs
67+
integer, dimension(:), allocatable :: pelist2
7568

7669
! If pelist is provided, the first position must be
7770
! the operation root
@@ -82,30 +75,24 @@ subroutine MPP_GATHER_1DV_(sbuf, ssize, rbuf, rsize, pelist)
8275
else
8376
nproc = mpp_npes()
8477
allocate(pelist2(nproc))
85-
pelist2 = (/ (l, l=0+root_pe, nproc-1+root_pe) /)
78+
pelist2 = (/ (l, l=root_pe, nproc-1+root_pe) /)
8679
endif
8780
op_root = pelist2(1)
8881

82+
if(pe .eq. op_root) then
83+
allocate(displs(nproc))
8984

90-
!--- pre-post receiving
91-
if (pe .eq. op_root) then
92-
pos = 1
93-
do l = 1,nproc ! include op_root to simplify logic
94-
if (rsize(l) == 0) then
95-
cycle ! avoid ranks with no data
96-
endif
97-
call mpp_recv(rbuf(pos),glen=rsize(l),from_pe=pelist2(l), &
98-
block=.FALSE.,tag=COMM_TAG_2)
99-
pos = pos + rsize(l)
100-
enddo
101-
endif
102-
if (ssize .gt. 0) then
103-
call mpp_send(sbuf(1),plen=ssize,to_pe=op_root,tag=COMM_TAG_2) !avoid ranks with no data
85+
displs(1) = 0
86+
do l = 2, nproc
87+
displs(l) = displs(l-1) + rsize(l-1)
88+
enddo
10489
endif
10590

106-
call mpp_sync_self(check=EVENT_RECV)
91+
call mpp_gather( sbuf, ssize, rbuf, rsize, displs, op_root, pelist2, ierr )
92+
10793
call mpp_sync_self()
10894
deallocate(pelist2)
95+
if(pe .eq. op_root) deallocate(displs)
10996
end subroutine MPP_GATHER_1DV_
11097

11198

0 commit comments

Comments
 (0)