Skip to content

Commit c5370f9

Browse files
committed
[SPH] allow stopping evolve until according to walltime
1 parent e8c0a9b commit c5370f9

4 files changed

Lines changed: 120 additions & 8 deletions

File tree

src/shammodels/sph/include/shammodels/sph/Model.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -956,8 +956,9 @@ namespace shammodels::sph {
956956
solver.print_timestep_logs();
957957
}
958958

959-
inline bool evolve_until(Tscal target_time, i32 niter_max) {
960-
return solver.evolve_until(target_time, niter_max);
959+
inline EvolveUntilResults evolve_until( // thin wrapper: hands its three arguments to solver.evolve_until and returns that call's result unchanged
960+
Tscal target_time, i32 niter_max, f64 max_global_walltime = -1) { // max_global_walltime defaults to -1 and is passed through as-is
961+
return solver.evolve_until(target_time, niter_max, max_global_walltime);
961962
}
962963

963964
private:

src/shammodels/sph/include/shammodels/sph/Solver.hpp

Lines changed: 89 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include "shamsys/legacy/log.hpp"
3232
#include "shamtree/TreeTraversalCache.hpp"
3333
#include <functional>
34+
#include <limits>
3435
#include <memory>
3536
#include <optional>
3637
#include <stdexcept>
@@ -51,6 +52,14 @@ namespace shammodels::sph {
5152
inline f64 tcompute_max() { return shamalgs::collective::allreduce_max(tcompute); }
5253
};
5354

55+
struct EvolveUntilResults { // result record returned by evolve_until: three stop-reason flags plus the iteration count
56+
bool reach_target_time; // set to true only by the final return of Solver::evolve_until; false in both early returns
57+
bool reach_niter_max; // set to true by the return taken when iter_count >= niter_max and niter_max != -1
58+
bool reach_max_walltime; // set to true by the return taken when the allreduce_max'd wtime is >= max_walltime
59+
60+
i32 iter_count; // value of the solver's iter_count at the moment evolve_until returned
61+
};
62+
5463
/**
5564
* @brief The shamrock SPH model
5665
*
@@ -236,7 +245,16 @@ namespace shammodels::sph {
236245
return solver_config.get_dt_sph();
237246
}
238247

239-
inline bool evolve_until(Tscal target_time, i32 niter_max) {
248+
inline EvolveUntilResults evolve_until(
249+
Tscal target_time, i32 niter_max, f64 max_walltime = -1) {
250+
251+
f64 start_wall_time
252+
= shamalgs::collective::allreduce_max(shambase::details::get_wtime());
253+
254+
const bool walltime_limit_active = (max_walltime != -1);
255+
i32 next_walltime_check_iter
256+
= walltime_limit_active ? 1 : std::numeric_limits<i32>::max();
257+
240258
auto step = [&]() {
241259
Tscal dt = solver_config.get_dt_sph();
242260
Tscal t = solver_config.get_time();
@@ -260,13 +278,81 @@ namespace shammodels::sph {
260278

261279
if ((iter_count >= niter_max) && (niter_max != -1)) {
262280
logger::info_ln("SPH", "stopping evolve until because of niter =", iter_count);
263-
return false;
281+
return {
282+
.reach_target_time = false,
283+
.reach_niter_max = true,
284+
.reach_max_walltime = false,
285+
.iter_count = iter_count,
286+
};
287+
}
288+
289+
if (walltime_limit_active && iter_count >= next_walltime_check_iter) {
290+
f64 global_walltime
291+
= shamalgs::collective::allreduce_max(shambase::details::get_wtime());
292+
293+
if ((max_walltime != -1) && (global_walltime >= max_walltime)) {
294+
logger::info_ln(
295+
"SPH",
296+
shambase::format(
297+
"stopping evolve until because of "
298+
"max_walltime = {:.2f}s > {:.2f}s",
299+
global_walltime,
300+
max_walltime));
301+
return {
302+
.reach_target_time = false,
303+
.reach_niter_max = false,
304+
.reach_max_walltime = true,
305+
.iter_count = iter_count,
306+
};
307+
}
308+
309+
f64 sec_per_iter
310+
= (global_walltime - start_wall_time) / static_cast<f64>(iter_count);
311+
312+
auto get_remaining_iters = [&](f64 delta_walltime, f64 factor) -> i32 {
313+
if (sec_per_iter > 0) {
314+
f64 tmp = factor * delta_walltime / sec_per_iter;
315+
if (tmp > std::numeric_limits<i32>::max()) {
316+
return std::numeric_limits<i32>::max();
317+
}
318+
return static_cast<i32>(tmp);
319+
}
320+
return 1000; // default to 1000 iterations if sec_per_iter is 0
321+
};
322+
323+
i32 iters_to_next_check = std::numeric_limits<i32>::max();
324+
325+
i32 next_walltime_type = 0; // 0 for walltime, 1 for global walltime
326+
327+
if (max_walltime != -1) {
328+
i32 iters_to_limit
329+
= get_remaining_iters(max_walltime - global_walltime, 0.1);
330+
if (iters_to_limit < iters_to_next_check) {
331+
next_walltime_type = 1;
332+
iters_to_next_check = iters_to_limit;
333+
}
334+
}
335+
336+
next_walltime_check_iter = iter_count + std::max(1, iters_to_next_check);
337+
338+
logger::info_ln(
339+
"SPH",
340+
shambase::format(
341+
"next walltime check in {:.2f}s (niter = {}) global walltime = {:.2f}s",
342+
iters_to_next_check * sec_per_iter,
343+
iters_to_next_check,
344+
global_walltime));
264345
}
265346
}
266347

267348
print_timestep_logs();
268349

269-
return true;
350+
return {
351+
.reach_target_time = true,
352+
.reach_niter_max = false,
353+
.reach_max_walltime = false,
354+
.iter_count = iter_count,
355+
};
270356
}
271357
};
272358

src/shammodels/sph/src/pySPHModel.cpp

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -734,12 +734,13 @@ void add_instance(py::module &m, std::string name_config, std::string name_model
734734
.def("evolve_once", &T::evolve_once)
735735
.def(
736736
"evolve_until",
737-
[](T &self, f64 target_time, i32 niter_max) {
738-
return self.evolve_until(target_time, niter_max);
737+
[](T &self, f64 target_time, i32 niter_max, f64 max_walltime) {
738+
return self.evolve_until(target_time, niter_max, max_walltime);
739739
},
740740
py::arg("target_time"),
741741
py::kw_only(),
742-
py::arg("niter_max") = -1)
742+
py::arg("niter_max") = -1,
743+
py::arg("max_walltime") = -1)
743744
.def(
744745
"set_dt",
745746
[](T &self, f64 dt) {
@@ -1649,6 +1650,21 @@ ON_PYTHON_INIT {
16491650

16501651
py::module msph = m.def_submodule("model_sph", "Shamrock sph solver");
16511652

1653+
py::class_<EvolveUntilResults>(m, "EvolveUntilResults")
1654+
.def_readwrite("reach_target_time", &EvolveUntilResults::reach_target_time)
1655+
.def_readwrite("reach_niter_max", &EvolveUntilResults::reach_niter_max)
1656+
.def_readwrite("reach_max_walltime", &EvolveUntilResults::reach_max_walltime)
1657+
.def_readwrite("iter_count", &EvolveUntilResults::iter_count)
1658+
.def("__repr__", [](const EvolveUntilResults &self) {
1659+
return shambase::format(
1660+
"EvolveUntilResults(reach_target_time={}, reach_niter_max={}, "
1661+
"reach_max_walltime={}, iter_count={})",
1662+
self.reach_target_time,
1663+
self.reach_niter_max,
1664+
self.reach_max_walltime,
1665+
self.iter_count);
1666+
});
1667+
16521668
using namespace shammodels::sph;
16531669

16541670
add_instance<f64_3, shammath::M4>(msph, "SPHModel_f64_3_M4_SolverConfig", "SPHModel_f64_3_M4");

src/shampylib/src/pyShamsys.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "shambase/exception.hpp"
1717
#include "shambase/numeric_limits.hpp"
1818
#include "shambase/stacktrace.hpp"
19+
#include "shamalgs/collective/reduction.hpp"
1920
#include "shambindings/pybindaliases.hpp"
2021
#include "shampylib/pyNodeInstance.hpp"
2122
#include "shamrock/experimental_features.hpp"
@@ -156,6 +157,14 @@ ON_PYTHON_INIT {
156157
dump profiling data
157158
)pbdoc");
158159

160+
m.def("get_wtime", []() {
161+
return shambase::details::get_wtime();
162+
});
163+
164+
m.def("get_wtime_sync", []() {
165+
return shamalgs::collective::allreduce_max(shambase::details::get_wtime());
166+
});
167+
159168
py::module sys_module = m.def_submodule("sys", "system handling part of shamrock");
160169
sys_module.def("signal_handler", &shamsys::details::signal_callback_handler);
161170

0 commit comments

Comments
 (0)