
Commit 2f7f309

Add metadata support to TranslateBetweenGrid for Star VC (elemental#151)
1 parent bf3cd78 commit 2f7f309

1 file changed: +41 −9 lines changed

include/El/blas_like/level1/Copy/TranslateBetweenGrids.hpp

Lines changed: 41 additions & 9 deletions
@@ -3856,7 +3856,8 @@ void TranslateBetweenGrids(
     EL_DEBUG_CSE;

     /* Overview
-
+       We broadcast the size of A to all the ranks in B to make sure that
+       all ranks in the B subgrid have the correct size of A.
        Since we are using blocking communication, some care is required
        to avoid deadlocks. Let's start with a naive algorithm for
        [STAR,VC] matrices and optimize it in steps:
@@ -3883,21 +3884,53 @@ void TranslateBetweenGrids(
     */

     // Matrix dimensions
-    const Int m = A.Height();
-    const Int n = A.Width();
+    Int m = A.Height();
+    Int n = A.Width();
+    Int strideA = A.RowStride();
+    Int ALDim = A.LDim();
+
+    // Create A metadata
+    Int recvMetaData[4];
+    Int metaData[4];
+
+    SyncInfo<El::Device::CPU> syncGeneralMetaData = SyncInfo<El::Device::CPU>();
+    mpi::Comm const& viewingCommB = B.Grid().ViewingComm();
+
+    const bool inAGrid = A.Participating();
+    const bool inBGrid = B.Participating();
+
+    if(inAGrid)
+    {
+        metaData[0] = m;
+        metaData[1] = n;
+        metaData[2] = strideA;
+        metaData[3] = ALDim;
+    }
+    else
+    {
+        metaData[0] = 0;
+        metaData[1] = 0;
+        metaData[2] = 0;
+        metaData[3] = 0;
+    }
+    const std::vector<Int> sendMetaData (metaData, metaData + 4);
+    mpi::AllReduce( sendMetaData.data(), recvMetaData, 4, mpi::MAX, viewingCommB, syncGeneralMetaData);
+    m = recvMetaData[0];
+    n = recvMetaData[1];
+    strideA = recvMetaData[2];
+    ALDim = recvMetaData[3];
+
+
     B.Resize(m, n);
     const Int nLocA = A.LocalWidth();
     const Int nLocB = B.LocalWidth();

     // Return immediately if there is no local data
-    const bool inAGrid = A.Participating();
-    const bool inBGrid = B.Participating();
     if (!inAGrid && !inBGrid) {
         return;
     }

     // Compute the number of messages to send/recv
-    const Int strideA = A.RowStride();
     const Int strideB = B.RowStride();
     const Int strideGCD = GCD(strideA, strideB);
     const Int numSends = Min(strideB/strideGCD, nLocA);
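
For context on what the new block does: every rank in B's viewing communicator contributes either A's real {height, width, row stride, leading dimension} or zeros, and the MAX reduction leaves every rank holding the real values, so the subsequent B.Resize(m, n) sees A's true size even on ranks that are not in A's grid. Below is a minimal stand-alone sketch of the same pattern in plain MPI (not Elemental's mpi:: wrappers; the function and variable names here are illustrative only):

    #include <mpi.h>
    #include <array>

    // Sketch: every rank in viewingComm learns A's {height, width, rowStride,
    // ldim} even if it owns no part of A. Participating ranks contribute the
    // real values, non-participating ranks contribute zeros, and MPI_MAX
    // picks out the real (non-negative) values on every rank.
    std::array<long long, 4> ExchangeMatrixMetaData(
        bool inAGrid, long long m, long long n,
        long long strideA, long long ldimA, MPI_Comm viewingComm)
    {
        std::array<long long, 4> send =
            inAGrid ? std::array<long long, 4>{m, n, strideA, ldimA}
                    : std::array<long long, 4>{0, 0, 0, 0};
        std::array<long long, 4> recv{};
        MPI_Allreduce(send.data(), recv.data(), 4, MPI_LONG_LONG, MPI_MAX,
                      viewingComm);
        return recv;  // identical on every rank of viewingComm
    }

MAX is a safe combiner here because the metadata values are non-negative and identical on every participating rank; a broadcast rooted at a rank of A's grid would also work, but the reduction avoids having to pick a root that is known to exist in both grids.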
@@ -3913,7 +3946,6 @@ void TranslateBetweenGrids(
     // that we can match send/recv communicators. Since A's VC
     // communicator is not necessarily defined on every process, we
     // instead work with A's owning group.
-    mpi::Comm const& viewingCommB = B.Grid().ViewingComm();
     mpi::Group owningGroupA = A.Grid().OwningGroup();
     const int sizeA = A.Grid().Size();
     vector<int> viewingRanksA(sizeA), owningRanksA(sizeA);
@@ -3976,15 +4008,15 @@ void TranslateBetweenGrids(
             // Copy data locally
             copy::util::InterleaveMatrix(
                 m, messageWidth,
-                A.LockedBuffer(0,jLocA), 1, numSends*A.LDim(),
+                A.LockedBuffer(0,jLocA), 1, numSends*ALDim,
                 B.Buffer(0,jLocB), 1, numRecvs*B.LDim(),
                 syncInfo);
         }
         else if (viewingRank == sendViewingRank) {
             // Send data to other rank
             copy::util::InterleaveMatrix(
                 m, messageWidth,
-                A.LockedBuffer(0,jLocA), 1, numSends*A.LDim(),
+                A.LockedBuffer(0,jLocA), 1, numSends*ALDim,
                 messageBuf.data(), 1, m,
                 syncInfo);
             mpi::Send(
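
The last hunk replaces A.LDim() with the ALDim obtained from the metadata exchange when computing the source row stride for the interleaved copy. As a reminder of what those stride arguments mean, here is a simplified stand-in for a strided interleave copy, assuming column-major storage (an illustration only, not Elemental's copy::util::InterleaveMatrix implementation):

    // Copies a height x width block where element (i, j) of the destination
    // comes from src[i*srcColStride + j*srcRowStride]. A row stride larger
    // than the leading dimension skips columns in the source; a destination
    // row stride equal to the height packs the block contiguously.
    template <typename T>
    void InterleaveCopy(long long height, long long width,
                        const T* src, long long srcColStride, long long srcRowStride,
                        T* dst, long long dstColStride, long long dstRowStride)
    {
        for (long long j = 0; j < width; ++j)
            for (long long i = 0; i < height; ++i)
                dst[i*dstColStride + j*dstRowStride] =
                    src[i*srcColStride + j*srcRowStride];
    }

With the arguments shown in the diff, the source row stride numSends*ALDim selects every numSends-th local column of A starting at column jLocA, while the packed message buffer is written contiguously with row stride m.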
