Skip to content

Commit b6b3c5e

Browse files
committed
DOP853Integrator now runs quickly without hanging, but still inaccurate.
1 parent 6d1288c commit b6b3c5e

File tree

1 file changed

+174
-89
lines changed

1 file changed

+174
-89
lines changed

include/integrators/ode/dop853.hpp

Lines changed: 174 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
#include <cmath>
55
#include <array>
66
#include <algorithm>
7+
#include <fstream>
8+
#include <chrono>
79

810
namespace diffeq::integrators::ode {
911

@@ -47,58 +49,84 @@ class DOP853Integrator : public AdaptiveIntegrator<S, T> {
4749
}
4850

4951
time_type adaptive_step(state_type& state, time_type dt) override {
52+
// Debug log setup (append mode, allow env override for test isolation)
53+
static std::string log_name = [](){
54+
const char* env = std::getenv("DOP853_DEBUG_LOG");
55+
return env ? std::string(env) : std::string("dop853_debug.log");
56+
}();
57+
static std::ofstream debug_log(log_name, std::ios::app);
58+
static constexpr int MAX_STEPS = 1000000; // safety limit
59+
static constexpr double MAX_SECONDS = 30.0; // timeout in seconds
60+
int step_count = 0;
61+
auto start_time = std::chrono::steady_clock::now();
62+
5063
time_type h_abs = std::abs(dt);
51-
time_type min_step = 10 * std::abs(std::nextafter(this->current_time_,
52-
this->current_time_ + dt) - this->current_time_);
53-
64+
time_type min_step = 10 * std::abs(std::nextafter(this->current_time_, this->current_time_ + dt) - this->current_time_);
65+
5466
if (h_abs > this->dt_max_) {
5567
h_abs = this->dt_max_;
5668
} else if (h_abs < std::max(min_step, this->dt_min_)) {
5769
h_abs = std::max(min_step, this->dt_min_);
5870
}
59-
71+
6072
bool step_accepted = false;
6173
bool step_rejected = false;
6274
time_type actual_dt = 0;
63-
75+
6476
while (!step_accepted) {
77+
++step_count;
78+
auto now = std::chrono::steady_clock::now();
79+
double elapsed = std::chrono::duration<double>(now - start_time).count();
80+
if (step_count > MAX_STEPS || elapsed > MAX_SECONDS) {
81+
debug_log << "[TIMEOUT] Step limit or time exceeded. Aborting.\n";
82+
throw std::runtime_error("DOP853 adaptive_step timeout or too many steps");
83+
}
84+
6585
time_type error_norm = rk_step(state, state, h_abs);
66-
86+
debug_log << "step: " << step_count << ", t: " << this->current_time_ << ", h: " << h_abs << ", error_norm: " << error_norm << ", accepted: " << step_accepted << ", rejected: " << step_rejected << std::endl;
87+
88+
if (std::isnan(error_norm) || std::isinf(error_norm)) {
89+
debug_log << "[ERROR] error_norm is NaN or Inf. Aborting.\n";
90+
throw std::runtime_error("DOP853 error_norm is NaN or Inf");
91+
}
92+
6793
if (error_norm < 1.0) {
6894
step_accepted = true;
6995
actual_dt = h_abs;
7096
this->advance_time(h_abs);
71-
97+
7298
if (!step_rejected) {
7399
// Suggest next step size
74-
time_type factor = std::min(max_factor_,
75-
safety_ * std::pow(error_norm, error_exponent_));
100+
time_type factor = std::min(max_factor_, safety_ * std::pow(std::max(error_norm, static_cast<time_type>(1e-10)), error_exponent_));
76101
h_abs = std::min(this->dt_max_, h_abs * factor);
77102
}
78103
} else {
79104
step_rejected = true;
80-
time_type factor = std::max(min_factor_,
81-
safety_ * std::pow(error_norm, error_exponent_));
105+
time_type factor = std::max(min_factor_, safety_ * std::pow(std::max(error_norm, static_cast<time_type>(1e-10)), error_exponent_));
82106
h_abs = std::max(this->dt_min_, h_abs * factor);
83107
}
84108
}
85-
109+
110+
debug_log << "[COMPLETE] t: " << this->current_time_ << ", dt: " << actual_dt << std::endl;
111+
debug_log.flush();
86112
return actual_dt;
87113
}
88114

89115
private:
90116
std::array<time_type, N_STAGES + 1> C_;
91117
std::array<std::array<time_type, N_STAGES>, N_STAGES> A_;
92118
std::array<time_type, N_STAGES> B_;
119+
std::array<time_type, N_STAGES> ER_; // Error estimation weights (Fortran er1, er6, ...)
120+
time_type BHH1_, BHH2_, BHH3_; // Error estimation weights (Fortran bhh1, bhh2, bhh3)
93121
std::array<time_type, N_STAGES + 1> E3_; // 3rd order error estimate
94122
std::array<time_type, N_STAGES + 1> E5_; // 5th order error estimate
95-
123+
96124
// Adaptive step control parameters
97125
time_type safety_;
98126
time_type min_factor_;
99127
time_type max_factor_;
100128
static constexpr time_type error_exponent_ = -1.0 / 8.0; // -1/(order+1) for 7th order error estimator
101-
129+
102130
void initialize_coefficients() {
103131
// C coefficients (times for stages)
104132
C_[0] = static_cast<time_type>(0.0);
@@ -126,92 +154,149 @@ class DOP853Integrator : public AdaptiveIntegrator<S, T> {
126154
}
127155
}
128156

129-
// Fill in some key coefficients (simplified)
130-
A_[1][0] = static_cast<time_type>(0.526001519587677318785587544488e-01);
131-
A_[2][0] = static_cast<time_type>(0.197250569845378994544595329183e-01);
132-
A_[2][1] = static_cast<time_type>(0.591751709536137983633785987549e-01);
133-
// ... (many more coefficients in full implementation)
134-
135-
// B coefficients for final solution (simplified)
136-
B_[0] = static_cast<time_type>(0.0295532805322554043052460699239e+00);
137-
B_[1] = static_cast<time_type>(0.0);
138-
B_[2] = static_cast<time_type>(0.0);
139-
B_[3] = static_cast<time_type>(0.0);
140-
B_[4] = static_cast<time_type>(0.0681942582430981642978628893916e+00);
141-
// ... (more coefficients)
142-
143-
// Error estimation coefficients (simplified)
157+
// Fill in some key coefficients (Fortran DOP853, partial, user should complete for production)
158+
A_[1][0] = static_cast<time_type>(5.26001519587677318785587544488e-2); // a21
159+
A_[2][0] = static_cast<time_type>(1.97250569845378994544595329183e-2); // a31
160+
A_[2][1] = static_cast<time_type>(5.91751709536136983633785987549e-2); // a32
161+
A_[3][0] = static_cast<time_type>(2.95875854768068491816892993775e-2); // a41
162+
A_[3][2] = static_cast<time_type>(8.87627564304205475450678981324e-2); // a43
163+
A_[4][0] = static_cast<time_type>(2.41365134159266685502369798665e-1); // a51
164+
A_[4][2] = static_cast<time_type>(-8.84549479328286085344864962717e-1); // a53
165+
A_[4][3] = static_cast<time_type>(9.24834003261792003115737966543e-1); // a54
166+
A_[5][0] = static_cast<time_type>(3.7037037037037037037037037037e-2); // a61
167+
A_[5][3] = static_cast<time_type>(1.70828608729473871279604482173e-1); // a64
168+
A_[5][4] = static_cast<time_type>(1.25467687566822425016691814123e-1); // a65
169+
A_[6][0] = static_cast<time_type>(3.7109375e-2); // a71
170+
A_[6][3] = static_cast<time_type>(1.70252211019544039314978060272e-1); // a74
171+
A_[6][4] = static_cast<time_type>(6.02165389804559606850219397283e-2); // a75
172+
A_[6][5] = static_cast<time_type>(-1.7578125e-2); // a76
173+
A_[7][0] = static_cast<time_type>(3.70920001185047927108779319836e-2); // a81
174+
A_[7][3] = static_cast<time_type>(1.70383925712239993810214054705e-1); // a84
175+
A_[7][4] = static_cast<time_type>(1.07262030446373284651809199168e-1); // a85
176+
A_[7][5] = static_cast<time_type>(-1.53194377486244017527936158236e-2); // a86
177+
A_[7][6] = static_cast<time_type>(8.27378916381402288758473766002e-3); // a87
178+
// ... (fill in all A_[8][*], A_[9][*], A_[10][*], A_[11][*] as needed)
179+
180+
// B coefficients for final solution (Fortran b1, b6, b7, ...)
181+
B_[0] = static_cast<time_type>(5.42937341165687622380535766363e-2); // b1
182+
B_[5] = static_cast<time_type>(4.45031289275240888144113950566e0); // b6
183+
B_[6] = static_cast<time_type>(1.89151789931450038304281599044e0); // b7
184+
B_[7] = static_cast<time_type>(-5.8012039600105847814672114227e0); // b8
185+
B_[8] = static_cast<time_type>(3.1116436695781989440891606237e-1); // b9
186+
B_[9] = static_cast<time_type>(-1.52160949662516078556178806805e-1); // b10
187+
B_[10] = static_cast<time_type>(2.01365400804030348374776537501e-1); // b11
188+
B_[11] = static_cast<time_type>(4.47106157277725905176885569043e-2); // b12
189+
190+
// Error estimation weights (Fortran er1, er6, ...)
191+
ER_.fill(static_cast<time_type>(0.0));
192+
ER_[0] = static_cast<time_type>(0.1312004499419488073250102996e-1); // er1
193+
ER_[5] = static_cast<time_type>(-0.1225156446376204440720569753e+01); // er6
194+
ER_[6] = static_cast<time_type>(-0.4957589496572501915214079952e+00); // er7
195+
ER_[7] = static_cast<time_type>(0.1664377182454986536961530415e+01); // er8
196+
ER_[8] = static_cast<time_type>(-0.3503288487499736816886487290e+00); // er9
197+
ER_[9] = static_cast<time_type>(0.3341791187130174790297318841e+00); // er10
198+
ER_[10] = static_cast<time_type>(0.8192320648511571246570742613e-01); // er11
199+
ER_[11] = static_cast<time_type>(-0.2235530786388629525884427845e-01); // er12
200+
201+
// Error estimation weights (Fortran bhh1, bhh2, bhh3)
202+
BHH1_ = static_cast<time_type>(0.244094488188976377952755905512e+00);
203+
BHH2_ = static_cast<time_type>(0.733846688281611857341361741547e+00);
204+
BHH3_ = static_cast<time_type>(0.220588235294117647058823529412e-01);
205+
206+
// Error estimation coefficients (simplified, legacy)
144207
for (int i = 0; i <= N_STAGES; ++i) {
145208
E3_[i] = E5_[i] = static_cast<time_type>(0.0);
146209
}
147210
E3_[0] = static_cast<time_type>(1e-6); // Simplified error estimate
148211
E5_[0] = static_cast<time_type>(1e-8); // Simplified error estimate
149212
}
150213

151-
// Core RK step implementation following scipy's rk_step function
214+
// DOP853 12-stage Runge-Kutta step, Fortran-aligned
152215
time_type rk_step(const state_type& y, state_type& y_new, time_type h) {
153-
// Simplified implementation using RK4 with better error estimation
154-
// A full DOP853 would implement all 13 stages
155-
156-
state_type k1 = StateCreator<state_type>::create(y);
157-
state_type k2 = StateCreator<state_type>::create(y);
158-
state_type k3 = StateCreator<state_type>::create(y);
159-
state_type k4 = StateCreator<state_type>::create(y);
160-
state_type temp = StateCreator<state_type>::create(y);
161-
216+
constexpr int N = N_STAGES;
217+
std::vector<state_type> k(N, StateCreator<state_type>::create(y));
218+
state_type y1 = StateCreator<state_type>::create(y);
162219
time_type t = this->current_time_;
163-
164-
// k1 = f(t, y)
165-
this->sys_(t, y, k1);
166-
167-
// k2 = f(t + h/2, y + h*k1/2)
168-
for (std::size_t i = 0; i < y.size(); ++i) {
169-
auto y_it = y.begin();
170-
auto k1_it = k1.begin();
171-
auto temp_it = temp.begin();
172-
temp_it[i] = y_it[i] + h * k1_it[i] / static_cast<time_type>(2);
173-
}
174-
this->sys_(t + h / static_cast<time_type>(2), temp, k2);
175-
176-
// k3 = f(t + h/2, y + h*k2/2)
177-
for (std::size_t i = 0; i < y.size(); ++i) {
178-
auto y_it = y.begin();
179-
auto k2_it = k2.begin();
180-
auto temp_it = temp.begin();
181-
temp_it[i] = y_it[i] + h * k2_it[i] / static_cast<time_type>(2);
182-
}
183-
this->sys_(t + h / static_cast<time_type>(2), temp, k3);
184-
185-
// k4 = f(t + h, y + h*k3)
186-
for (std::size_t i = 0; i < y.size(); ++i) {
187-
auto y_it = y.begin();
188-
auto k3_it = k3.begin();
189-
auto temp_it = temp.begin();
190-
temp_it[i] = y_it[i] + h * k3_it[i];
191-
}
192-
this->sys_(t + h, temp, k4);
193-
194-
// Final solution
195-
for (std::size_t i = 0; i < y.size(); ++i) {
196-
auto y_it = y.begin();
197-
auto k1_it = k1.begin();
198-
auto k2_it = k2.begin();
199-
auto k3_it = k3.begin();
200-
auto k4_it = k4.begin();
201-
auto y_new_it = y_new.begin();
202-
203-
y_new_it[i] = y_it[i] + h * (k1_it[i] + static_cast<time_type>(2) * k2_it[i] +
204-
static_cast<time_type>(2) * k3_it[i] + k4_it[i]) / static_cast<time_type>(6);
220+
221+
// Stage 1: k1 = f(t, y)
222+
this->sys_(t, y, k[0]);
223+
224+
// Stage 2: y1 = y + h*a21*k1
225+
for (size_t i = 0; i < y.size(); ++i)
226+
y1[i] = y[i] + h * A_[1][0] * k[0][i];
227+
this->sys_(t + C_[1] * h, y1, k[1]);
228+
229+
// Stage 3: y1 = y + h*(a31*k1 + a32*k2)
230+
for (size_t i = 0; i < y.size(); ++i)
231+
y1[i] = y[i] + h * (A_[2][0] * k[0][i] + A_[2][1] * k[1][i]);
232+
this->sys_(t + C_[2] * h, y1, k[2]);
233+
234+
// Stage 4: y1 = y + h*(a41*k1 + a43*k3)
235+
for (size_t i = 0; i < y.size(); ++i)
236+
y1[i] = y[i] + h * (A_[3][0] * k[0][i] + A_[3][2] * k[2][i]);
237+
this->sys_(t + C_[3] * h, y1, k[3]);
238+
239+
// Stage 5: y1 = y + h*(a51*k1 + a53*k3 + a54*k4)
240+
for (size_t i = 0; i < y.size(); ++i)
241+
y1[i] = y[i] + h * (A_[4][0] * k[0][i] + A_[4][2] * k[2][i] + A_[4][3] * k[3][i]);
242+
this->sys_(t + C_[4] * h, y1, k[4]);
243+
244+
// Stage 6: y1 = y + h*(a61*k1 + a64*k4 + a65*k5)
245+
for (size_t i = 0; i < y.size(); ++i)
246+
y1[i] = y[i] + h * (A_[5][0] * k[0][i] + A_[5][3] * k[3][i] + A_[5][4] * k[4][i]);
247+
this->sys_(t + C_[5] * h, y1, k[5]);
248+
249+
// Stage 7: y1 = y + h*(a71*k1 + a74*k4 + a75*k5 + a76*k6)
250+
for (size_t i = 0; i < y.size(); ++i)
251+
y1[i] = y[i] + h * (A_[6][0] * k[0][i] + A_[6][3] * k[3][i] + A_[6][4] * k[4][i] + A_[6][5] * k[5][i]);
252+
this->sys_(t + C_[6] * h, y1, k[6]);
253+
254+
// Stage 8: y1 = y + h*(a81*k1 + a84*k4 + a85*k5 + a86*k6 + a87*k7)
255+
for (size_t i = 0; i < y.size(); ++i)
256+
y1[i] = y[i] + h * (A_[7][0] * k[0][i] + A_[7][3] * k[3][i] + A_[7][4] * k[4][i] + A_[7][5] * k[5][i] + A_[7][6] * k[6][i]);
257+
this->sys_(t + C_[7] * h, y1, k[7]);
258+
259+
// Stage 9: y1 = y + h*(a91*k1 + a94*k4 + a95*k5 + a96*k6 + a97*k7 + a98*k8)
260+
for (size_t i = 0; i < y.size(); ++i)
261+
y1[i] = y[i] + h * (A_[8][0] * k[0][i] + A_[8][3] * k[3][i] + A_[8][4] * k[4][i] + A_[8][5] * k[5][i] + A_[8][6] * k[6][i] + A_[8][7] * k[7][i]);
262+
this->sys_(t + C_[8] * h, y1, k[8]);
263+
264+
// Stage 10: y1 = y + h*(a101*k1 + a104*k4 + a105*k5 + a106*k6 + a107*k7 + a108*k8 + a109*k9)
265+
for (size_t i = 0; i < y.size(); ++i)
266+
y1[i] = y[i] + h * (A_[9][0] * k[0][i] + A_[9][3] * k[3][i] + A_[9][4] * k[4][i] + A_[9][5] * k[5][i] + A_[9][6] * k[6][i] + A_[9][7] * k[7][i] + A_[9][8] * k[8][i]);
267+
this->sys_(t + C_[9] * h, y1, k[9]);
268+
269+
// Stage 11: y1 = y + h*(a111*k1 + a114*k4 + a115*k5 + a116*k6 + a117*k7 + a118*k8 + a119*k9 + a1110*k10)
270+
for (size_t i = 0; i < y.size(); ++i)
271+
y1[i] = y[i] + h * (A_[10][0] * k[0][i] + A_[10][3] * k[3][i] + A_[10][4] * k[4][i] + A_[10][5] * k[5][i] + A_[10][6] * k[6][i] + A_[10][7] * k[7][i] + A_[10][8] * k[8][i] + A_[10][9] * k[9][i]);
272+
this->sys_(t + C_[10] * h, y1, k[10]);
273+
274+
// Stage 12: y1 = y + h*(a121*k1 + a124*k4 + a125*k5 + a126*k6 + a127*k7 + a128*k8 + a129*k9 + a1210*k10 + a1211*k11)
275+
for (size_t i = 0; i < y.size(); ++i)
276+
y1[i] = y[i] + h * (A_[11][0] * k[0][i] + A_[11][3] * k[3][i] + A_[11][4] * k[4][i] + A_[11][5] * k[5][i] + A_[11][6] * k[6][i] + A_[11][7] * k[7][i] + A_[11][8] * k[8][i] + A_[11][9] * k[9][i] + A_[11][10] * k[10][i]);
277+
this->sys_(t + C_[11] * h, y1, k[11]);
278+
279+
// 8th order solution (main step)
280+
for (size_t i = 0; i < y.size(); ++i) {
281+
y_new[i] = y[i]
282+
+ h * (B_[0] * k[0][i] + B_[5] * k[5][i] + B_[6] * k[6][i] + B_[7] * k[7][i] + B_[8] * k[8][i] + B_[9] * k[9][i] + B_[10] * k[10][i] + B_[11] * k[11][i]);
205283
}
206-
207-
// Error estimation (simplified)
208-
state_type error = StateCreator<state_type>::create(y);
209-
for (std::size_t i = 0; i < error.size(); ++i) {
210-
auto error_it = error.begin();
211-
error_it[i] = h * static_cast<time_type>(1e-8); // Simplified error estimate
284+
285+
// Error estimation (Fortran-aligned)
286+
time_type err = 0, err2 = 0;
287+
for (size_t i = 0; i < y.size(); ++i) {
288+
time_type sk = this->atol_ + this->rtol_ * std::max(std::abs(y[i]), std::abs(y_new[i]));
289+
// First error component
290+
time_type erri1 = y_new[i] - (BHH1_ * k[0][i] + BHH2_ * k[8][i] + BHH3_ * k[11][i]);
291+
err2 += (erri1 / sk) * (erri1 / sk);
292+
// Second error component
293+
time_type erri2 = ER_[0] * k[0][i] + ER_[5] * k[5][i] + ER_[6] * k[6][i] + ER_[7] * k[7][i] + ER_[8] * k[8][i] + ER_[9] * k[9][i] + ER_[10] * k[10][i] + ER_[11] * k[11][i];
294+
err += (erri2 / sk) * (erri2 / sk);
212295
}
213-
214-
return this->error_norm(error, y_new);
296+
time_type deno = err + 0.01 * err2;
297+
if (deno <= 0.0) deno = 1.0;
298+
err = std::abs(h) * std::sqrt(err / (y.size() * deno));
299+
return err;
215300
}
216301
};
217302

0 commit comments

Comments
 (0)