Skip to content

Commit ddf9342

Browse files
authored
WarpXSolverVec: Remove duplicates (BLAST-WarpX#6415)
This appears necessary for PETSc to work.
1 parent 94bd1ab commit ddf9342

File tree

4 files changed

+149
-129
lines changed

4 files changed

+149
-129
lines changed

Source/FieldSolver/ImplicitSolvers/WarpXSolverDOF.H

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,7 @@ class WarpX;
2424
struct WarpXSolverDOF
2525
{
2626
amrex::Vector<std::array<std::unique_ptr<amrex::iMultiFab>,3>> m_array;
27-
amrex::Vector<std::array<std::unique_ptr<amrex::iMultiFab>,3>> m_array_lhs;
2827
amrex::Vector<std::unique_ptr<amrex::iMultiFab>> m_scalar;
29-
amrex::Vector<std::unique_ptr<amrex::iMultiFab>> m_scalar_lhs;
3028

3129
warpx::fields::FieldType m_array_type = warpx::fields::FieldType::None;
3230
warpx::fields::FieldType m_scalar_type = warpx::fields::FieldType::None;
@@ -35,6 +33,9 @@ struct WarpXSolverDOF
3533
amrex::Long m_nDoFs_g = 0; /*!< Global nDOF */
3634

3735
void Define ( WarpX* const, int, const std::string&, const std::string&);
36+
37+
void fill_local_dof (amrex::iMultiFab& dof, amrex::iMultiFab const& mask);
38+
void fill_global_dof ();
3839
};
3940

4041
#endif

Source/FieldSolver/ImplicitSolvers/WarpXSolverDOF.cpp

Lines changed: 121 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@
1111
#include <ablastr/utils/SignalHandling.H>
1212
#include <ablastr/warn_manager/WarnManager.H>
1313

14+
#include <AMReX_Scan.H>
15+
1416
using warpx::fields::FieldType;
17+
using namespace amrex;
1518

1619
void WarpXSolverDOF::Define ( WarpX* const a_WarpX,
1720
const int a_num_amr_levels,
@@ -41,10 +44,7 @@ void WarpXSolverDOF::Define ( WarpX* const a_WarpX,
4144

4245
m_array.resize(a_num_amr_levels);
4346
m_scalar.resize(a_num_amr_levels);
44-
m_array_lhs.resize(a_num_amr_levels);
45-
m_scalar_lhs.resize(a_num_amr_levels);
4647

47-
amrex::Long offset = 0;
4848
m_nDoFs_l = 0;
4949

5050
// Define the 3D vector field data container
@@ -62,26 +62,9 @@ void WarpXSolverDOF::Define ( WarpX* const a_WarpX,
6262
this_array[n]->DistributionMap(),
6363
2*ncomp, // {local, global} for each comp
6464
this_array[n]->nGrowVect() );
65-
m_nDoFs_g += this_array[n]->boxArray().numPts()*ncomp;
66-
67-
m_array[lev][n]->setVal(-1.0);
68-
amrex::Long offset_mf = 0;
69-
for (amrex::MFIter mfi(*m_array[lev][n]); mfi.isValid(); ++mfi) {
70-
auto bx = mfi.tilebox();
71-
auto dof_arr = m_array[lev][n]->array(mfi);
72-
ParallelFor( bx, [=] AMREX_GPU_DEVICE (int i, int j, int k)
73-
{
74-
for (int v = 0; v < ncomp; v++) {
75-
dof_arr(i,j,k,2*v) = bx.index(amrex::IntVect(AMREX_D_DECL(i, j, k))) * ncomp
76-
+ v
77-
+ offset_mf
78-
+ offset;
79-
}
80-
});
81-
offset_mf += bx.numPts()*ncomp;
82-
}
83-
offset += offset_mf;
84-
m_nDoFs_l += offset_mf;
65+
66+
auto* mask = a_WarpX->getFieldDotMaskPointer(m_array_type, lev, ablastr::fields::Direction{n});
67+
fill_local_dof(*m_array[lev][n], *mask);
8568
}
8669
}
8770

@@ -101,106 +84,137 @@ void WarpXSolverDOF::Define ( WarpX* const a_WarpX,
10184
this_mf->DistributionMap(),
10285
2*ncomp, // {local, global} for each comp
10386
this_mf->nGrowVect() );
104-
m_nDoFs_g += this_mf->boxArray().numPts()*ncomp;
105-
106-
m_scalar[lev]->setVal(-1.0);
107-
amrex::Long offset_mf = 0;
108-
for (amrex::MFIter mfi(*m_scalar[lev]); mfi.isValid(); ++mfi) {
109-
auto bx = mfi.tilebox();
110-
auto dof_arr = m_scalar[lev]->array(mfi);
111-
ParallelFor( bx, [=] AMREX_GPU_DEVICE (int i, int j, int k)
112-
{
113-
for (int v = 0; v < ncomp; v++) {
114-
dof_arr(i,j,k,2*v) = bx.index(amrex::IntVect(AMREX_D_DECL(i, j, k))) * ncomp
115-
+ v
116-
+ offset_mf
117-
+ offset;
118-
}
119-
});
120-
offset_mf += bx.numPts()*ncomp;
121-
}
122-
offset += offset_mf;
123-
m_nDoFs_l += offset_mf;
87+
88+
auto* mask = a_WarpX->getFieldDotMaskPointer(m_scalar_type, lev, ablastr::fields::Direction{0});
89+
fill_local_dof(*m_scalar[lev], *mask);
12490
}
12591

12692
}
12793

128-
auto nDoFs_g = m_nDoFs_l;
129-
amrex::ParallelDescriptor::ReduceLongSum(&nDoFs_g,1);
130-
WARPX_ALWAYS_ASSERT_WITH_MESSAGE(
131-
m_nDoFs_g == nDoFs_g,
132-
"WarpXSolverDOF::Define(): something has gone wrong in DoF counting");
133-
134-
auto num_procs = amrex::ParallelDescriptor::NProcs();
135-
auto my_proc = amrex::ParallelDescriptor::MyProc();
136-
amrex::Vector<int> dof_proc_arr(num_procs,0);
137-
dof_proc_arr[my_proc] = m_nDoFs_l;
138-
amrex::ParallelDescriptor::ReduceIntSum(dof_proc_arr.data(), num_procs);
94+
fill_global_dof();
13995

140-
int offset_global = 0;
141-
for (int i = 0; i < my_proc; i++) { offset_global += dof_proc_arr[i]; }
142-
143-
if (m_array_type != FieldType::None) {
144-
for (int lev = 0; lev < a_num_amr_levels; ++lev) {
145-
const ablastr::fields::VectorField this_array = a_WarpX->m_fields.get_alldirs(a_vector_type_name, lev);
146-
for (int n = 0; n < 3; n++) {
147-
auto ncomp = this_array[n]->nComp();
148-
for (amrex::MFIter mfi(*m_array[lev][n]); mfi.isValid(); ++mfi) {
149-
auto bx = mfi.tilebox();
150-
auto dof_arr = m_array[lev][n]->array(mfi);
151-
ParallelFor( bx, [=] AMREX_GPU_DEVICE (int i, int j, int k)
152-
{
153-
for (int v = 0; v < ncomp; v++) {
154-
dof_arr(i,j,k,2*v+1) = dof_arr(i,j,k,2*v) + offset_global;
155-
}
156-
});
96+
for (int lev = 0; lev < a_num_amr_levels; ++lev) {
97+
for (int n = 0; n < 3; n++) {
98+
if (auto* dof = m_array[lev][n].get()) {
99+
for (int comp = 1; comp < dof->nComp(); comp += 2) { // Only call this on global id
100+
dof->FillBoundaryAndSync(comp, 1, dof->nGrowVect(), a_WarpX->Geom(lev).periodicity());
157101
}
158102
}
159103
}
104+
if (auto* dof = m_scalar[lev].get()) {
105+
for (int comp = 1; comp < dof->nComp(); comp += 2) { // Only call this on global id
106+
dof->FillBoundaryAndSync(comp, 1, dof->nGrowVect(), a_WarpX->Geom(lev).periodicity());
107+
}
108+
}
160109
}
161-
if (m_scalar_type != FieldType::None) {
162-
for (int lev = 0; lev < a_num_amr_levels; ++lev) {
163-
const amrex::MultiFab* this_mf = a_WarpX->m_fields.get(a_scalar_type_name,lev);
164-
auto ncomp = this_mf->nComp();
165-
for (amrex::MFIter mfi(*m_scalar[lev]); mfi.isValid(); ++mfi) {
166-
auto bx = mfi.tilebox();
167-
auto dof_arr = m_scalar[lev]->array(mfi);
168-
ParallelFor( bx, [=] AMREX_GPU_DEVICE (int i, int j, int k)
169-
{
170-
for (int v = 0; v < ncomp; v++) {
171-
dof_arr(i,j,k,2*v+1) = dof_arr(i,j,k,2*v) + offset_global;
110+
111+
amrex::Print() << "Defined DOF object for linear solves (total DOFs = " << m_nDoFs_g << ").\n";
112+
}
113+
114+
// Assign process-local DOF ids to every point of `dof` selected by `mask`,
// and advance the running local counter m_nDoFs_l accordingly.
//
// dof  : integer MultiFab holding {local id, global id} pairs per field
//        component (local id in slot 2*c, global id in slot 2*c+1).
// mask : integer MultiFab indexed with the same MFIter as `dof`; a point is
//        numbered only where mask is non-zero.
//
// Points that are not numbered keep the std::numeric_limits<int>::lowest()
// sentinel written by setVal() below (the WarpXSolverVec copy kernels in this
// commit skip entries with dof < 0).
// When built with MPI and running on more than one rank, the global-id slots
// are NOT written here; fill_global_dof() fills them afterwards.
void WarpXSolverDOF::fill_local_dof (iMultiFab& dof, iMultiFab const& mask)
115+
{
116+
int ncomp = dof.nComp() / 2; // /2 because both local and global ids are stored in dof
117+
// Ids are stored as int, so the point count of this MultiFab must fit in int.
// NOTE(review): this bounds only this MultiFab; the ids written below also add
// the cumulative m_nDoFs_l (a Long) -- confirm the running total cannot exceed int.
118+
AMREX_ALWAYS_ASSERT(dof.boxArray().numPts()*ncomp < static_cast<Long>(std::numeric_limits<int>::max()));
119+
// Initialise every slot (local and global) to the "no DOF" sentinel.
120+
dof.setVal(std::numeric_limits<int>::lowest());
121+
122+
#ifdef AMREX_USE_MPI
123+
int nprocs = ParallelDescriptor::NProcs();
124+
#endif
125+
126+
for (MFIter mfi(dof); mfi.isValid(); ++mfi) {
127+
Box const& vbx = mfi.validbox();
128+
// NOTE(review): numPts() result is narrowed to int here; presumably safe given the assert above.
int npts = vbx.numPts();
129+
// Maps a flat offset in [0, npts) to an (i,j,k) index inside vbx.
BoxIndexer boxindex(vbx);
130+
auto const& m = mask.const_array(mfi);
131+
auto const& d = dof.array(mfi);
132+
// First local id available for this box (ids already handed out so far).
auto start_id = m_nDoFs_l;
133+
// Prefix sum over the mask: the first lambda yields 1 for masked points, the
// second receives the scanned value `ps` and uses it as the point's rank among
// masked points of this box. The returned sum (Scan::retSum) is stored in
// ndofs and used below as the number of masked points in the box.
auto ndofs = Scan::PrefixSum<int>(
134+
npts,
135+
[=] AMREX_GPU_DEVICE (int offset) -> int
136+
{
137+
auto [i,j,k] = boxindex(offset);
138+
return m(i,j,k) ? 1 : 0;
139+
},
140+
[=] AMREX_GPU_DEVICE (int offset, int ps)
141+
{
142+
auto [i,j,k] = boxindex(offset);
143+
if (m(i,j,k)) {
144+
// Local id of field component 0.
d(i,j,k,0) = ps + start_id;
145+
#ifdef AMREX_USE_MPI
146+
if (nprocs == 1)
147+
#endif
148+
{
149+
// Single rank (or no MPI): global id equals local id.
d(i,j,k,1) = ps + start_id;
172150
}
173-
});
174-
}
151+
}
152+
},
153+
Scan::Type::exclusive, Scan::retSum);
154+
// Remaining field components reuse component 0's numbering, shifted by a
// multiple of ndofs: component c gets id(comp 0) + c*ndofs, so ids within a
// box are grouped component by component.
if (ncomp > 1) {
155+
ParallelFor(vbx, ncomp-1, [=] AMREX_GPU_DEVICE (int i, int j, int k, int n)
156+
{
157+
if (m(i,j,k)) {
158+
d(i,j,k,2*(n+1)) = d(i,j,k,0) + ndofs*(n+1);
159+
#ifdef AMREX_USE_MPI
160+
if (nprocs == 1)
161+
#endif
162+
{
163+
d(i,j,k,2*(n+1)+1) = d(i,j,k,0) + ndofs*(n+1);
164+
}
165+
}
166+
});
175167
}
168+
// Account for all components of this box in the running local DOF count.
m_nDoFs_l += Long(ndofs)*ncomp;
176169
}
170+
}
177171

178-
if (m_array_type != FieldType::None) {
179-
for (int lev = 0; lev < a_num_amr_levels; ++lev) {
180-
const auto& geom = a_WarpX->Geom(lev);
181-
for (int n = 0; n < 3; n++) {
182-
m_array_lhs[lev][n] = std::make_unique<amrex::iMultiFab>(m_array[lev][n]->boxArray(),
183-
m_array[lev][n]->DistributionMap(),
184-
m_array[lev][n]->nComp(),
185-
0 );
186-
amrex::iMultiFab::Copy(*m_array_lhs[lev][n], *m_array[lev][n], 0, 0, m_array[lev][n]->nComp(), 0);
187-
m_array[lev][n]->FillBoundary(geom.periodicity());
188-
// do NOT call FillBoundary() on m_array_lhs
172+
// Compute the global DOF count m_nDoFs_g and, when running on more than one
// MPI rank, fill the global-id slots (2*n+1) of every defined DOF MultiFab in
// m_array and m_scalar as local id + this rank's starting offset.
// Without MPI, or on a single rank, only m_nDoFs_g is set here (the global-id
// slots are written by fill_local_dof() in that case).
void WarpXSolverDOF::fill_global_dof ()
173+
{
174+
#ifndef AMREX_USE_MPI
175+
m_nDoFs_g = m_nDoFs_l;
176+
#else
177+
int nprocs = ParallelDescriptor::NProcs();
178+
if (nprocs == 1) {
179+
m_nDoFs_g = m_nDoFs_l;
180+
} else {
181+
// Gather every rank's local DOF count so each rank can compute its offset.
Vector<Long> ndofs_allprocs(nprocs);
182+
MPI_Allgather(&m_nDoFs_l, 1, ParallelDescriptor::Mpi_typemap<Long>::type(),
183+
ndofs_allprocs.data(), 1, ParallelDescriptor::Mpi_typemap<Long>::type(),
184+
ParallelDescriptor::Communicator());
185+
// proc_begin: sum of the counts of all lower-numbered ranks, i.e. the first
// global id owned by this rank. m_nDoFs_g: sum over all ranks.
Long proc_begin = 0;
186+
int myproc = ParallelDescriptor::MyProc();
187+
m_nDoFs_g = 0;
188+
for (int iproc = 0; iproc < nprocs; ++iproc) {
189+
if (iproc < myproc) {
190+
proc_begin += ndofs_allprocs[iproc];
189191
}
192+
m_nDoFs_g += ndofs_allprocs[iproc];
190193
}
// NOTE(review): the rows below whose gutter number ends in '-' are lines this
// commit removed from the old Define(); the diff view interleaves them here and
// they are not part of fill_global_dof().
191-
}
192-
if (m_scalar_type != FieldType::None) {
193-
for (int lev = 0; lev < a_num_amr_levels; ++lev) {
194-
m_scalar_lhs[lev] = std::make_unique<amrex::iMultiFab>(m_scalar[lev]->boxArray(),
195-
m_scalar[lev]->DistributionMap(),
196-
m_scalar[lev]->nComp(),
197-
0 );
198-
amrex::iMultiFab::Copy(*m_scalar_lhs[lev], *m_scalar[lev], 0, 0, m_scalar[lev]->nComp(), 0);
199-
const auto& geom = a_WarpX->Geom(lev);
200-
m_scalar[lev]->FillBoundary(geom.periodicity());
201-
// do NOT call FillBoundary() on m_scalar_lhs
// Vector-field DOFs: for each level and direction with a defined MultiFab,
// global id = local id + proc_begin for every component.
// NOTE(review): proc_begin is a Long narrowed with int(...); no check is visible
// here that proc_begin plus the local id fits in int -- confirm for large runs.
// NOTE(review): the IntVect(0) argument presumably restricts the kernel to
// zero ghost cells -- confirm against the AMReX ParallelFor overload.
194+
for (auto& x : m_array) {
195+
for (auto& y : x) {
196+
if (y) {
197+
auto const& dof = y->arrays();
198+
auto ncomp = y->nComp() / 2;
199+
ParallelFor(*y, IntVect(0), ncomp, [=] AMREX_GPU_DEVICE (int b, int i, int j, int k, int n)
200+
{
201+
dof[b](i,j,k,2*n+1) = dof[b](i,j,k,2*n) + int(proc_begin);
202+
});
203+
}
204+
}
205+
}
// Scalar-field DOFs: same offsetting as for the vector field above.
206+
for (auto& x : m_scalar) {
207+
if (x) {
208+
auto const& dof = x->arrays();
209+
auto ncomp = x->nComp() / 2;
210+
ParallelFor(*x, IntVect(0), ncomp, [=] AMREX_GPU_DEVICE (int b, int i, int j, int k, int n)
211+
{
212+
dof[b](i,j,k,2*n+1) = dof[b](i,j,k,2*n) + int(proc_begin);
213+
});
214+
}
202215
}
// Presumably waits for the kernels launched above to finish before returning -- TODO confirm.
216+
Gpu::streamSynchronize();
203217
}
218+
#endif
204219

205-
amrex::Print() << "Defined DOF object for linear solves (total DOFs = " << m_nDoFs_g << ").\n";
206220
}

Source/FieldSolver/ImplicitSolvers/WarpXSolverVec.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -160,10 +160,13 @@ void WarpXSolverVec::copyFrom ( const amrex::Real* const a_arr)
160160
{
161161
for (int v = 0; v < ncomp; v++) {
162162
int dof = dof_arr(i,j,k,2*v); // local
163-
data_arr(i,j,k,v) = a_arr[dof];
163+
if (dof >= 0) {
164+
data_arr(i,j,k,v) = a_arr[dof];
165+
}
164166
}
165167
});
166168
}
169+
m_array_vec[lev][n]->FillBoundaryAndSync(m_WarpX->Geom(lev).periodicity());
167170
}
168171
}
169172
if (m_scalar_type != FieldType::None) {
@@ -176,10 +179,13 @@ void WarpXSolverVec::copyFrom ( const amrex::Real* const a_arr)
176179
{
177180
for (int v = 0; v < ncomp; v++) {
178181
int dof = dof_arr(i,j,k,2*v); // local
179-
data_arr(i,j,k,v) = a_arr[dof];
182+
if (dof >= 0) {
183+
data_arr(i,j,k,v) = a_arr[dof];
184+
}
180185
}
181186
});
182187
}
188+
m_scalar_vec[lev]->FillBoundaryAndSync(m_WarpX->Geom(lev).periodicity());
183189
}
184190
}
185191
}
@@ -205,7 +211,9 @@ void WarpXSolverVec::copyTo ( amrex::Real* const a_arr) const
205211
{
206212
for (int v = 0; v < ncomp; v++) {
207213
int dof = dof_arr(i,j,k,2*v); // local
208-
a_arr[dof] = data_arr(i,j,k,v);
214+
if (dof >= 0) {
215+
a_arr[dof] = data_arr(i,j,k,v);
216+
}
209217
}
210218
});
211219
}
@@ -221,7 +229,9 @@ void WarpXSolverVec::copyTo ( amrex::Real* const a_arr) const
221229
{
222230
for (int v = 0; v < ncomp; v++) {
223231
int dof = dof_arr(i,j,k,2*v); // local
224-
a_arr[dof] = data_arr(i,j,k,v);
232+
if (dof >= 0) {
233+
a_arr[dof] = data_arr(i,j,k,v);
234+
}
225235
}
226236
});
227237
}

0 commit comments

Comments
 (0)