@@ -3856,7 +3856,8 @@ void TranslateBetweenGrids(
     EL_DEBUG_CSE;
 
     /* Overview
-
+       We broadcast the size of A to all ranks in B so that every
+       rank in B's subgrid has the correct dimensions of A.
        Since we are using blocking communication, some care is required
        to avoid deadlocks. Let's start with a naive algorithm for
        [STAR,VC] matrices and optimize it in steps:
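
The "broadcast" described in the new comment is realized below as a max-allreduce over B's viewing communicator rather than an MPI_Bcast, which sidesteps choosing a root rank that is guaranteed to belong to both grids. A minimal standalone sketch of that idiom in plain MPI (function and parameter names here are illustrative, not part of the patch):

#include <mpi.h>

// Ranks that own A contribute its true metadata; all other ranks
// contribute zeros, so MPI_MAX reproduces the values everywhere.
void broadcast_metadata(MPI_Comm comm, bool in_a_grid,
                        int height, int width, int stride, int ldim,
                        int out[4])
{
    int local[4] = {0, 0, 0, 0};
    if (in_a_grid) {
        local[0] = height;
        local[1] = width;
        local[2] = stride;
        local[3] = ldim;
    }
    // No root is needed: every participating rank holds identical
    // metadata, so the reduction acts as a de facto broadcast.
    MPI_Allreduce(local, out, 4, MPI_INT, MPI_MAX, comm);
}
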
@@ -3883,21 +3884,53 @@ void TranslateBetweenGrids(
     */
 
     // Matrix dimensions
-    const Int m = A.Height();
-    const Int n = A.Width();
+    Int m = A.Height();
+    Int n = A.Width();
+    Int strideA = A.RowStride();
+    Int ALDim = A.LDim();
+
+    // Gather A's metadata (dimensions and strides)
+    Int recvMetaData[4];
+    Int metaData[4];
+
+    SyncInfo<El::Device::CPU> syncGeneralMetaData = SyncInfo<El::Device::CPU>();
+    mpi::Comm const& viewingCommB = B.Grid().ViewingComm();
+
+    const bool inAGrid = A.Participating();
+    const bool inBGrid = B.Participating();
+
+    if (inAGrid)
+    {
+        metaData[0] = m;
+        metaData[1] = n;
+        metaData[2] = strideA;
+        metaData[3] = ALDim;
+    }
+    else
+    {
+        metaData[0] = 0;
+        metaData[1] = 0;
+        metaData[2] = 0;
+        metaData[3] = 0;
+    }
+    const std::vector<Int> sendMetaData(metaData, metaData+4);
+    mpi::AllReduce(sendMetaData.data(), recvMetaData, 4, mpi::MAX, viewingCommB, syncGeneralMetaData);
+    m = recvMetaData[0];
+    n = recvMetaData[1];
+    strideA = recvMetaData[2];
+    ALDim = recvMetaData[3];
+
+
     B.Resize(m, n);
     const Int nLocA = A.LocalWidth();
     const Int nLocB = B.LocalWidth();
 
     // Return immediately if there is no local data
-    const bool inAGrid = A.Participating();
-    const bool inBGrid = B.Participating();
     if (!inAGrid && !inBGrid) {
         return;
     }
 
     // Compute the number of messages to send/recv
-    const Int strideA = A.RowStride();
     const Int strideB = B.RowStride();
     const Int strideGCD = GCD(strideA, strideB);
     const Int numSends = Min(strideB/strideGCD, nLocA);
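
For context on the unchanged lines above: the message counts follow from the periodicity of the two cyclic distributions, since the columns owned by one rank of A land on B's ranks in a pattern that repeats every strideB/GCD(strideA, strideB) columns. A hedged, standalone sketch of the arithmetic (not the patch's code):

#include <algorithm>
#include <numeric>  // std::gcd, C++17

// Each rank of A sends to at most strideB/gcd distinct ranks of B
// before the ownership pattern repeats; nLocA caps the count when
// the rank owns fewer local columns than one full period.
int num_sends(int strideA, int strideB, int nLocA)
{
    const int g = std::gcd(strideA, strideB);
    return std::min(strideB / g, nLocA);
}

// Example: strideA = 4, strideB = 6 gives g = 2, so a rank of A
// talks to at most 3 distinct ranks of B.
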
@@ -3913,7 +3946,6 @@ void TranslateBetweenGrids(
     // that we can match send/recv communicators. Since A's VC
     // communicator is not necessarily defined on every process, we
     // instead work with A's owning group.
-    mpi::Comm const& viewingCommB = B.Grid().ViewingComm();
     mpi::Group owningGroupA = A.Grid().OwningGroup();
     const int sizeA = A.Grid().Size();
     vector<int> viewingRanksA(sizeA), owningRanksA(sizeA);
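
The rank translation here goes through Elemental's wrappers; in raw MPI the same step looks roughly like the following sketch, under the assumption that viewingRanksA ends up holding each owning rank's id in the viewing communicator:

#include <mpi.h>
#include <vector>

// Map ranks 0..sizeA-1 of A's owning group to their ranks in the
// viewing communicator, so matching sends/recvs can be posted on a
// communicator that every process belongs to.
std::vector<int> translate_ranks(MPI_Group owningGroupA,
                                 MPI_Comm viewingComm, int sizeA)
{
    std::vector<int> owning(sizeA), viewing(sizeA);
    for (int r = 0; r < sizeA; ++r)
        owning[r] = r;
    MPI_Group viewingGroup;
    MPI_Comm_group(viewingComm, &viewingGroup);
    MPI_Group_translate_ranks(owningGroupA, sizeA, owning.data(),
                              viewingGroup, viewing.data());
    MPI_Group_free(&viewingGroup);
    return viewing;
}
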
@@ -3976,15 +4008,15 @@ void TranslateBetweenGrids(
         // Copy data locally
         copy::util::InterleaveMatrix(
             m, messageWidth,
-            A.LockedBuffer(0,jLocA), 1, numSends*A.LDim(),
+            A.LockedBuffer(0,jLocA), 1, numSends*ALDim,
             B.Buffer(0,jLocB), 1, numRecvs*B.LDim(),
             syncInfo);
     }
     else if (viewingRank == sendViewingRank) {
         // Send data to other rank
         copy::util::InterleaveMatrix(
             m, messageWidth,
-            A.LockedBuffer(0,jLocA), 1, numSends*A.LDim(),
+            A.LockedBuffer(0,jLocA), 1, numSends*ALDim,
             messageBuf.data(), 1, m,
             syncInfo);
         mpi::Send(
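
This hunk swaps numSends*A.LDim() for numSends*ALDim, the leading dimension exchanged via the allreduce above, so the stride stays consistent with the metadata all ranks now agree on. As a reading aid, a scalar reference for the strided copy that copy::util::InterleaveMatrix performs here (hypothetical reference code, not the library's implementation):

// Copy a height x width block where consecutive entries in a column
// are rowStride apart and consecutive columns are colStride apart.
// With srcRowStride = 1 and srcColStride = numSends*ALDim, every
// numSends-th local column of A is packed contiguously
// (dstColStride = m) into messageBuf.
template <typename T>
void interleave_copy(int height, int width,
                     const T* src, int srcRowStride, int srcColStride,
                     T* dst, int dstRowStride, int dstColStride)
{
    for (int j = 0; j < width; ++j)
        for (int i = 0; i < height; ++i)
            dst[i*dstRowStride + j*dstColStride] =
                src[i*srcRowStride + j*srcColStride];
}
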