3131#include " shamsys/legacy/log.hpp"
3232#include " shamtree/TreeTraversalCache.hpp"
3333#include < functional>
34+ #include < limits>
3435#include < memory>
3536#include < optional>
3637#include < stdexcept>
@@ -51,6 +52,14 @@ namespace shammodels::sph {
5152 inline f64 tcompute_max () { return shamalgs::collective::allreduce_max (tcompute); }
5253 };
5354
55+ struct EvolveUntilResults {
56+ bool reach_target_time;
57+ bool reach_niter_max;
58+ bool reach_max_walltime;
59+
60+ i32 iter_count;
61+ };
62+
5463 /* *
5564 * @brief The shamrock SPH model
5665 *
@@ -236,7 +245,16 @@ namespace shammodels::sph {
236245 return solver_config.get_dt_sph ();
237246 }
238247
239- inline bool evolve_until (Tscal target_time, i32 niter_max) {
248+ inline EvolveUntilResults evolve_until (
249+ Tscal target_time, i32 niter_max, f64 max_walltime = -1 ) {
250+
251+ f64 start_wall_time
252+ = shamalgs::collective::allreduce_max (shambase::details::get_wtime ());
253+
254+ const bool walltime_limit_active = (max_walltime != -1 );
255+ i32 next_walltime_check_iter
256+ = walltime_limit_active ? 1 : std::numeric_limits<i32 >::max ();
257+
240258 auto step = [&]() {
241259 Tscal dt = solver_config.get_dt_sph ();
242260 Tscal t = solver_config.get_time ();
@@ -260,13 +278,81 @@ namespace shammodels::sph {
260278
261279 if ((iter_count >= niter_max) && (niter_max != -1 )) {
262280 logger::info_ln (" SPH" , " stopping evolve until because of niter =" , iter_count);
263- return false ;
281+ return {
282+ .reach_target_time = false ,
283+ .reach_niter_max = true ,
284+ .reach_max_walltime = false ,
285+ .iter_count = iter_count,
286+ };
287+ }
288+
289+ if (walltime_limit_active && iter_count >= next_walltime_check_iter) {
290+ f64 global_walltime
291+ = shamalgs::collective::allreduce_max (shambase::details::get_wtime ());
292+
293+ if ((max_walltime != -1 ) && (global_walltime >= max_walltime)) {
294+ logger::info_ln (
295+ " SPH" ,
296+ shambase::format (
297+ " stopping evolve until because of "
298+ " max_walltime = {:.2f}s > {:.2f}s" ,
299+ global_walltime,
300+ max_walltime));
301+ return {
302+ .reach_target_time = false ,
303+ .reach_niter_max = false ,
304+ .reach_max_walltime = true ,
305+ .iter_count = iter_count,
306+ };
307+ }
308+
309+ f64 sec_per_iter
310+ = (global_walltime - start_wall_time) / static_cast <f64 >(iter_count);
311+
312+ auto get_remaining_iters = [&](f64 delta_walltime, f64 factor) -> i32 {
313+ if (sec_per_iter > 0 ) {
314+ f64 tmp = factor * delta_walltime / sec_per_iter;
315+ if (tmp > std::numeric_limits<i32 >::max ()) {
316+ return std::numeric_limits<i32 >::max ();
317+ }
318+ return static_cast <i32 >(tmp);
319+ }
320+ return 1000 ; // default to 1000 iterations if sec_per_iter is 0
321+ };
322+
323+ i32 iters_to_next_check = std::numeric_limits<i32 >::max ();
324+
325+ i32 next_walltime_type = 0 ; // 0 for walltime, 1 for global walltime
326+
327+ if (max_walltime != -1 ) {
328+ i32 iters_to_limit
329+ = get_remaining_iters (max_walltime - global_walltime, 0.1 );
330+ if (iters_to_limit < iters_to_next_check) {
331+ next_walltime_type = 1 ;
332+ iters_to_next_check = iters_to_limit;
333+ }
334+ }
335+
336+ next_walltime_check_iter = iter_count + std::max (1 , iters_to_next_check);
337+
338+ logger::info_ln (
339+ " SPH" ,
340+ shambase::format (
341+ " next walltime check in {:.2f}s (niter = {}) global walltime = {:.2f}s" ,
342+ iters_to_next_check * sec_per_iter,
343+ iters_to_next_check,
344+ global_walltime));
264345 }
265346 }
266347
267348 print_timestep_logs ();
268349
269- return true ;
350+ return {
351+ .reach_target_time = true ,
352+ .reach_niter_max = false ,
353+ .reach_max_walltime = false ,
354+ .iter_count = iter_count,
355+ };
270356 }
271357 };
272358
0 commit comments