Skip to content

Commit c8b1d21

Browse files
Heuristic load balance based on number of particles and number of cells (#737)
* Issue #713: Load Balance w.r.t. number of cells + particles README: Update Badged to `master` (#725) * README: Update Badged to `master` Update the badges to check the `master` branch as development branch. * Docs: More dev->master updates - release workflow - contribution guide link Galilean PSATD with shift (#704) * Read Galilean velocity * Prepare structures for Galilean solver * Started implementing Galilean equations * Analytical limits for X1, X2, X3, X4 coefficients added * Slight changes added * Added Galilean position pusher * Scale galilean velocity * Remove unneeded Abort * Fix Galilean pusher * Allocate Theta2 array * Fix definition of coefficients * Increase guard cells for Galilean * Add guard cell in particle exchange * Type corrected * v_gal added to warpx_current_deposition * v_gal added to WarpXParticleContainer.H * Bug fixed - update particle x-position over one time step * Fix issues with merge from dev * Preparation for merging dev into galilean. * Adding galilean shift * Implemented galilean shift * Changed method's name from GalileanShift to ShiftGalileanBoundary * Added doxygen string for ShiftGalileanBoundary * Removed never used method LowerCornerWithCentering * Removed temporary comments * Removed dt as a variable from DepositCharge method and its dependencies * Converted tab to spaces * Removed EOL white space * Add documentation and automated tests * Fix compilation error * Add automated test * Update automated test * Removed temporary used galilean shift * Removed temporary used particle's push for Galilean PSATD * Removed unused statement * Remove EOL white space. * Added zero shift for LowerCorner in RZ geometry * Minor changes to Galilean implementation * Modifications for GPU * Fix typo Co-authored-by: Remi Lehe <[email protected]> [mini-PR] when a cufft error occurs, print a meaningful error message (#728) * added method to translate cufft errors * fixed style * bug fixing avoid duplicate tests and plot less often (#726) * avoid duplicate tests and plot less often * fix tests I broke when trying to save plotfiles doc install yt on Summit (#729) * doc install yt on Summit * eol Do not use local Redistribute for electrostatic solver (#731) Add Reset Random Seed Feature (#717) * Add ResetRandomSeed * Add doc * Modify and change location of the code. * Small fix * Try to fix an alert * Try to fix an alert * Modify based on suggestions * Use INT_MAX * Modify based on suggestions. * Modify based on suggestions. openPMD: warn if step is already written (#718) * making sure iterations are written at most once. * prints a warning when iteration is written more than once writting is not stopped * Fixed tabs * included <iostream> as requested by Axel Minor refactoring of space-charge calculation (#732) Improve clarity and documentation Minor Update GNUmakefile Update based on comments Update GNU Makefile Formatting Formatting Formatting Formatting Remove unneeded function Removed unneeded function Formatting Formatting Formatting Whitespace Minor Formatting Formatting Formatting Formatting Formatting Formatting Formatting whitespace Formatting Minor Formatting Remove unneeded template function Change import Minor Formatting Remove unused variable Formatting Update Source/WarpX.H Co-Authored-By: MaxThevenet <[email protected]> Update Source/Parallelization/WarpXRegrid.cpp Co-Authored-By: MaxThevenet <[email protected]> Remove `n_particles` and `n_cells` Update Source/WarpX.H Co-Authored-By: MaxThevenet <[email protected]> Revert clear costs in case of edge case Update to use new load_balance_api in AMReX Tabs Minor * minor: indentation in Source/WarpX.H Co-authored-by: MaxThevenet <[email protected]>
1 parent 04bd921 commit c8b1d21

File tree

12 files changed

+241
-50
lines changed

12 files changed

+241
-50
lines changed

Source/Diagnostics/FieldIO.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -506,10 +506,11 @@ WarpX::AverageAndPackFields ( Vector<std::string>& varnames,
506506
amrex::Vector<MultiFab>& mf_avg, const int ngrow) const
507507
{
508508
// Count how many different fields should be written (ncomp)
509+
MultiFab* cost = WarpX::getCosts(0);
509510
int ncomp = fields_to_plot.size()
510511
+ static_cast<int>(plot_finepatch)*6
511512
+ static_cast<int>(plot_crsepatch)*6
512-
+ static_cast<int>(costs[0] != nullptr and plot_costs);
513+
+ static_cast<int>(cost != nullptr and plot_costs);
513514

514515
// Add in the RZ modes
515516
if (n_rz_azimuthal_modes > 1) {
@@ -723,13 +724,15 @@ WarpX::AverageAndPackFields ( Vector<std::string>& varnames,
723724
dcomp += 3;
724725
}
725726

726-
if (costs[0] != nullptr and plot_costs)
727+
if (WarpX::load_balance_costs_update_algo == LoadBalanceCostsUpdateAlgo::Timers)
727728
{
728-
AverageAndPackScalarField( mf_avg[lev], *costs[lev], dcomp, ngrow );
729-
if(lev==0) varnames.push_back("costs");
730-
dcomp += 1;
729+
if (costs[0] != nullptr and plot_costs)
730+
{
731+
AverageAndPackScalarField( mf_avg[lev], *costs[lev], dcomp, ngrow );
732+
if(lev==0) varnames.push_back("costs");
733+
dcomp += 1;
734+
}
731735
}
732-
733736
BL_ASSERT(dcomp == ncomp);
734737
} // end loop over levels of refinement
735738

Source/Diagnostics/WarpXIO.cpp

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -187,9 +187,11 @@ WarpX::WriteCheckPointFile() const
187187
pml[lev]->CheckPoint(amrex::MultiFabFileFullPrefix(lev, checkpointname, level_prefix, "pml"));
188188
}
189189

190-
if (costs[lev]) {
191-
VisMF::Write(*costs[lev],
192-
amrex::MultiFabFileFullPrefix(lev, checkpointname, level_prefix, "costs"));
190+
if (WarpX::load_balance_costs_update_algo == LoadBalanceCostsUpdateAlgo::Timers) {
191+
if (costs[lev]) {
192+
VisMF::Write(*costs[lev],
193+
amrex::MultiFabFileFullPrefix(lev, checkpointname, level_prefix, "costs"));
194+
}
193195
}
194196
}
195197

@@ -382,13 +384,15 @@ WarpX::InitFromCheckpoint ()
382384
}
383385
}
384386

385-
if (costs[lev]) {
386-
const auto& cost_mf_name =
387+
if (WarpX::load_balance_costs_update_algo == LoadBalanceCostsUpdateAlgo::Timers) {
388+
if (costs[lev]) {
389+
const auto& cost_mf_name =
387390
amrex::MultiFabFileFullPrefix(lev, restart_chkfile, level_prefix, "costs");
388-
if (VisMF::Exist(cost_mf_name)) {
389-
VisMF::Read(*costs[lev], cost_mf_name);
390-
} else {
391-
costs[lev]->setVal(0.0);
391+
if (VisMF::Exist(cost_mf_name)) {
392+
VisMF::Read(*costs[lev], cost_mf_name);
393+
} else {
394+
costs[lev]->setVal(0.0);
395+
}
392396
}
393397
}
394398
}

Source/Evolve/WarpXEvolveEM.cpp

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ WarpX::EvolveEM (int numsteps)
3838
static int last_check_file_step = 0;
3939
static int last_insitu_step = 0;
4040

41-
if (do_compute_max_step_from_zmax){
41+
if (do_compute_max_step_from_zmax) {
4242
computeMaxStepBoostAccelerator(geom[0]);
4343
}
4444

@@ -61,25 +61,36 @@ WarpX::EvolveEM (int numsteps)
6161
if (warpx_py_beforestep) warpx_py_beforestep();
6262
#endif
6363

64-
if (costs[0] != nullptr)
65-
{
64+
MultiFab* cost = WarpX::getCosts(0);
65+
amrex::Vector<amrex::Real>* cost_heuristic = WarpX::getCostsHeuristic(0);
66+
if (cost != nullptr || cost_heuristic != nullptr) {
6667
#ifdef WARPX_USE_PSATD
6768
amrex::Abort("LoadBalance for PSATD: TODO");
6869
#endif
69-
7070
if (step > 0 && (step+1) % load_balance_int == 0)
7171
{
7272
LoadBalance();
7373
// Reset the costs to 0
74-
for (int lev = 0; lev <= finest_level; ++lev) {
75-
costs[lev]->setVal(0.0);
74+
for (int lev = 0; lev <= finest_level; ++lev)
75+
{
76+
if (WarpX::load_balance_costs_update_algo == LoadBalanceCostsUpdateAlgo::Timers)
77+
{
78+
costs[lev]->setVal(0.0);
79+
} else if (WarpX::load_balance_costs_update_algo == LoadBalanceCostsUpdateAlgo::Heuristic)
80+
{
81+
costs_heuristic[lev]->assign((*costs_heuristic[lev]).size(), 0.0);
82+
}
7683
}
7784
}
78-
79-
for (int lev = 0; lev <= finest_level; ++lev) {
80-
// Perform running average of the costs
81-
// (Giving more importance to most recent costs)
82-
(*costs[lev].get()).mult( (1. - 2./load_balance_int) );
85+
for (int lev = 0; lev <= finest_level; ++lev)
86+
{
87+
MultiFab* cost = WarpX::getCosts(lev);
88+
if (cost)
89+
{
90+
// Perform running average of the costs
91+
// (Giving more importance to most recent costs)
92+
cost->mult( (1. - 2./load_balance_int) );
93+
}
8394
}
8495
}
8596

@@ -92,13 +103,13 @@ WarpX::EvolveEM (int numsteps)
92103
FillBoundaryB(guard_cells.ng_alloc_EB, guard_cells.ng_Extra);
93104
UpdateAuxilaryData();
94105
// on first step, push p by -0.5*dt
95-
for (int lev = 0; lev <= finest_level; ++lev) {
106+
for (int lev = 0; lev <= finest_level; ++lev)
107+
{
96108
mypc->PushP(lev, -0.5*dt[lev],
97109
*Efield_aux[lev][0],*Efield_aux[lev][1],*Efield_aux[lev][2],
98110
*Bfield_aux[lev][0],*Bfield_aux[lev][1],*Bfield_aux[lev][2]);
99111
}
100112
is_synchronized = false;
101-
102113
} else {
103114
// Beyond one step, we have E^{n} and B^{n}.
104115
// Particles have p^{n-1/2} and x^{n}.

Source/FieldSolver/WarpXPushFieldsEM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ WarpX::EvolveE (int lev, PatchType patch_type, amrex::Real a_dt)
262262
F = F_cp[lev].get();
263263
}
264264

265-
MultiFab* cost = costs[lev].get();
265+
MultiFab* cost = WarpX::getCosts(lev);
266266
const IntVect& rr = (lev > 0) ? refRatio(lev-1) : IntVect::TheUnitVector();
267267

268268
// xmin is only used by the kernel for cylindrical geometry,

Source/FieldSolver/WarpX_QED_Field_Pushers.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ WarpX::Hybrid_QED_Push (int lev, PatchType patch_type, Real a_dt)
8383
Bz = Bfield_cp[lev][2].get();
8484
}
8585

86-
MultiFab* cost = costs[lev].get();
86+
MultiFab* cost = WarpX::getCosts(lev);
8787
const IntVect& rr = (lev > 0) ? refRatio(lev-1) : IntVect::TheUnitVector();
8888

8989
// xmin is only used by the kernel for cylindrical geometry,

Source/Initialization/WarpXInitData.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "Filter/NCIGodfreyFilter.H"
1313
#include "Parser/GpuParser.H"
1414
#include "Utils/WarpXUtil.H"
15+
#include "Utils/WarpXAlgorithmSelection.H"
1516

1617
#include <AMReX_ParallelDescriptor.H>
1718
#include <AMReX_ParmParse.H>
@@ -420,8 +421,16 @@ WarpX::InitLevelData (int lev, Real /*time*/)
420421
rho_cp[lev]->setVal(0.0);
421422
}
422423

423-
if (costs[lev]) {
424-
costs[lev]->setVal(0.0);
424+
if (WarpX::load_balance_costs_update_algo == LoadBalanceCostsUpdateAlgo::Timers) {
425+
if (costs[lev]) {
426+
costs[lev]->setVal(0.0);
427+
}
428+
} else if (WarpX::load_balance_costs_update_algo == LoadBalanceCostsUpdateAlgo::Heuristic) {
429+
if (costs_heuristic[lev]) {
430+
std::fill((*costs_heuristic[lev]).begin(),
431+
(*costs_heuristic[lev]).end(),
432+
0.0);
433+
}
425434
}
426435
}
427436

Source/Parallelization/WarpXRegrid.cpp

Lines changed: 89 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
* License: BSD-3-Clause-LBNL
88
*/
99
#include <WarpX.H>
10+
#include <WarpXAlgorithmSelection.H>
1011
#include <AMReX_BLProfiler.H>
1112

1213
using namespace amrex;
@@ -17,6 +18,18 @@ WarpX::LoadBalance ()
1718
WARPX_PROFILE_REGION("LoadBalance");
1819
WARPX_PROFILE("WarpX::LoadBalance()");
1920

21+
if (WarpX::load_balance_costs_update_algo == LoadBalanceCostsUpdateAlgo::Timers)
22+
{
23+
LoadBalanceTimers();
24+
} else if (WarpX::load_balance_costs_update_algo == LoadBalanceCostsUpdateAlgo::Heuristic)
25+
{
26+
LoadBalanceHeuristic();
27+
}
28+
}
29+
30+
void
31+
WarpX::LoadBalanceTimers ()
32+
{
2033
AMREX_ALWAYS_ASSERT(costs[0] != nullptr);
2134

2235
const int nLevels = finestLevel();
@@ -30,7 +43,36 @@ WarpX::LoadBalance ()
3043
: DistributionMapping::makeKnapSack(*costs[lev], nmax);
3144
RemakeLevel(lev, t_new[lev], boxArray(lev), newdm);
3245
}
46+
mypc->Redistribute();
47+
}
48+
49+
void
50+
WarpX::LoadBalanceHeuristic ()
51+
{
52+
AMREX_ALWAYS_ASSERT(costs_heuristic[0] != nullptr);
53+
WarpX::ComputeCostsHeuristic();
54+
55+
const int nLevels = finestLevel();
56+
for (int lev = 0; lev <= nLevels; ++lev)
57+
{
58+
#ifdef AMREX_USE_MPI
59+
// Parallel reduce the costs_heurisitc
60+
amrex::Vector<Real>::iterator it = (*costs_heuristic[lev]).begin();
61+
amrex::Real* itAddr = &(*it);
62+
ParallelAllReduce::Sum(itAddr,
63+
costs_heuristic[lev]->size(),
64+
ParallelContext::CommunicatorSub());
65+
#endif
66+
const amrex::Real nboxes = costs_heuristic[lev]->size();
67+
const amrex::Real nprocs = ParallelContext::NProcsSub();
68+
const int nmax = static_cast<int>(std::ceil(nboxes/nprocs*load_balance_knapsack_factor));
3369

70+
const DistributionMapping newdm = (load_balance_with_sfc)
71+
? DistributionMapping::makeSFC(*costs_heuristic[lev], boxArray(lev), false)
72+
: DistributionMapping::makeKnapSack(*costs_heuristic[lev], nmax);
73+
74+
RemakeLevel(lev, t_new[lev], boxArray(lev), newdm);
75+
}
3476
mypc->Redistribute();
3577
}
3678

@@ -47,7 +89,6 @@ WarpX::RemakeLevel (int lev, Real /*time*/, const BoxArray& ba, const Distributi
4789
#endif // WARPX_DO_ELECTROSTATIC
4890

4991
// Fine patch
50-
5192
const auto& period = Geom(lev).periodicity();
5293
for (int idim=0; idim < 3; ++idim)
5394
{
@@ -98,7 +139,6 @@ WarpX::RemakeLevel (int lev, Real /*time*/, const BoxArray& ba, const Distributi
98139
}
99140

100141
// Aux patch
101-
102142
if (lev == 0 && Bfield_aux[0][0]->ixType() == Bfield_fp[0][0]->ixType())
103143
{
104144
for (int idim = 0; idim < 3; ++idim) {
@@ -223,15 +263,57 @@ WarpX::RemakeLevel (int lev, Real /*time*/, const BoxArray& ba, const Distributi
223263
}
224264
}
225265

226-
if (costs[lev] != nullptr) {
227-
costs[lev].reset(new MultiFab(costs[lev]->boxArray(), dm, 1, 0));
228-
costs[lev]->setVal(0.0);
266+
if (WarpX::load_balance_costs_update_algo == LoadBalanceCostsUpdateAlgo::Timers)
267+
{
268+
if (costs[lev] != nullptr)
269+
{
270+
costs[lev].reset(new MultiFab(costs[lev]->boxArray(), dm, 1, 0));
271+
costs[lev]->setVal(0.0);
272+
}
273+
} else if (WarpX::load_balance_costs_update_algo == LoadBalanceCostsUpdateAlgo::Heuristic)
274+
{
275+
if (costs_heuristic[lev] != nullptr)
276+
{
277+
costs_heuristic[lev].reset(new amrex::Vector<Real>);
278+
const int nboxes = Efield_fp[lev][0].get()->size();
279+
costs_heuristic[lev]->resize(nboxes, 0.0); // Initializes to 0.0?
280+
}
229281
}
230282

231283
SetDistributionMap(lev, dm);
232-
}
233-
else
284+
285+
} else
234286
{
235287
amrex::Abort("RemakeLevel: to be implemented");
236288
}
237289
}
290+
291+
void
292+
WarpX::ComputeCostsHeuristic ()
293+
{
294+
for (int lev = 0; lev <= finest_level; ++lev)
295+
{
296+
auto & mypc = WarpX::GetInstance().GetPartContainer();
297+
auto nSpecies = mypc.nSpecies();
298+
299+
// Species loop
300+
for (int i_s = 0; i_s < nSpecies; ++i_s)
301+
{
302+
auto & myspc = mypc.GetParticleContainer(i_s);
303+
304+
// Particle loop
305+
for (WarpXParIter pti(myspc, lev); pti.isValid(); ++pti)
306+
{
307+
(*costs_heuristic[lev])[pti.index()] += costs_heuristic_particles_wt*pti.numParticles();
308+
}
309+
}
310+
311+
//Cell loop
312+
MultiFab* Ex = Efield_fp[lev][0].get();
313+
for (MFIter mfi(*Ex, false); mfi.isValid(); ++mfi)
314+
{
315+
const Box& gbx = mfi.growntilebox();
316+
(*costs_heuristic[lev])[mfi.index()] += costs_heuristic_cells_wt*gbx.numPts();
317+
}
318+
} // for (int lev ...)
319+
} // WarpX::ComputeCostsHeuristic

Source/Particles/PhysicalParticleContainer.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,9 +156,7 @@ void PhysicalParticleContainer::InitData()
156156
// Init ionization module here instead of in the PhysicalParticleContainer
157157
// constructor because dt is required
158158
if (do_field_ionization) {InitIonizationModule();}
159-
160159
AddParticles(0); // Note - add on level 0
161-
162160
Redistribute(); // We then redistribute
163161
}
164162

@@ -1135,6 +1133,7 @@ PhysicalParticleContainer::Evolve (int lev,
11351133
BL_ASSERT(OnSameGrids(lev,jx));
11361134

11371135
MultiFab* cost = WarpX::getCosts(lev);
1136+
11381137
const iMultiFab* current_masks = WarpX::CurrentBufferMasks(lev);
11391138
const iMultiFab* gather_masks = WarpX::GatherBufferMasks(lev);
11401139

Source/Utils/WarpXAlgorithmSelection.H

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,16 @@ struct GatheringAlgo {
4747
};
4848
};
4949

50+
/** Strategy to compute weights for use in load balance.
51+
*/
52+
struct LoadBalanceCostsUpdateAlgo {
53+
enum {
54+
Timers = 0, //!< load balance according to in-code timer-based weights (i.e., with `costs`)
55+
Heuristic = 1 /**< load balance according to weights computed from number of cells
56+
and number of particles per box (i.e., with `costs_heuristic`)*/
57+
};
58+
};
59+
5060
int
5161
GetAlgorithmInteger( amrex::ParmParse& pp, const char* pp_search_key );
5262

Source/Utils/WarpXAlgorithmSelection.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,12 @@ const std::map<std::string, int> gathering_algo_to_int = {
5252
{"default", GatheringAlgo::EnergyConserving }
5353
};
5454

55+
const std::map<std::string, int> load_balance_costs_update_algo_to_int = {
56+
{"timers", LoadBalanceCostsUpdateAlgo::Timers },
57+
{"heuristic", LoadBalanceCostsUpdateAlgo::Heuristic },
58+
{"default", LoadBalanceCostsUpdateAlgo::Timers }
59+
};
60+
5561

5662
int
5763
GetAlgorithmInteger( amrex::ParmParse& pp, const char* pp_search_key ){
@@ -74,13 +80,15 @@ GetAlgorithmInteger( amrex::ParmParse& pp, const char* pp_search_key ){
7480
algo_to_int = charge_deposition_algo_to_int;
7581
} else if (0 == std::strcmp(pp_search_key, "field_gathering")) {
7682
algo_to_int = gathering_algo_to_int;
83+
} else if (0 == std::strcmp(pp_search_key, "load_balance_costs_update")) {
84+
algo_to_int = load_balance_costs_update_algo_to_int;
7785
} else {
7886
std::string pp_search_string = pp_search_key;
7987
amrex::Abort("Unknown algorithm type: " + pp_search_string);
8088
}
8189

8290
// Check if the user-input is a valid key for the dictionary
83-
if (algo_to_int.count(algo) == 0){
91+
if (algo_to_int.count(algo) == 0) {
8492
// Not a valid key ; print error message
8593
std::string pp_search_string = pp_search_key;
8694
std::string error_message = "Invalid string for algo." + pp_search_string

0 commit comments

Comments
 (0)