Skip to content

Commit 8f88f29

Browse files
committed
remove the TimeType from template <system_state State, can_be_time TimeType>, which can be inferred from State
1 parent abec6be commit 8f88f29

File tree

15 files changed

+309
-326
lines changed

15 files changed

+309
-326
lines changed

include/core/abstract_integrator.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,16 @@
55
#include "concepts.hpp"
66

77
// Abstract integrator base class
8-
template<system_state S, can_be_time T = double>
8+
template<system_state S>
99
class AbstractIntegrator {
1010
public:
1111
using state_type = S;
12-
using time_type = T;
12+
using time_type = typename S::value_type;
1313
using value_type = typename S::value_type;
1414
using system_function = std::function<void(time_type, const state_type&, state_type&)>;
1515

1616
explicit AbstractIntegrator(system_function sys)
17-
: sys_(std::move(sys)), current_time_(T{0}) {}
17+
: sys_(std::move(sys)), current_time_(time_type{0}) {}
1818

1919
virtual ~AbstractIntegrator() = default;
2020

include/core/adaptive_integrator.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@
1313
#include <core/state_creator.hpp>
1414

1515
// Abstract adaptive integrator with error control
16-
template<system_state S, can_be_time T = double>
17-
class AdaptiveIntegrator : public AbstractIntegrator<S, T> {
16+
template<system_state S>
17+
class AdaptiveIntegrator : public AbstractIntegrator<S> {
1818
public:
19-
using base_type = AbstractIntegrator<S, T>;
19+
using base_type = AbstractIntegrator<S>;
2020
using state_type = typename base_type::state_type;
2121
using time_type = typename base_type::time_type;
2222
using value_type = typename base_type::value_type;

include/core/concepts.hpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,7 @@
44
#include <iterator>
55
#include <string>
66

7-
// Time type concept - supports both integer and floating point time
8-
template<typename T>
9-
concept can_be_time = std::is_arithmetic_v<T>;
10-
11-
// 状态概念 - 支持向量、矩阵、多维张量等类型
7+
// State concept - supports vectors, matrices, multi-dimensional tensors, etc.
128
template<typename T>
139
concept system_state = requires(T state) {
1410
typename T::value_type;

include/integrators/ode/bdf.hpp

Lines changed: 82 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#pragma once
2+
#include <core/concepts.hpp>
23
#include <core/adaptive_integrator.hpp>
34
#include <core/state_creator.hpp>
45
#include <vector>
@@ -20,10 +21,10 @@ namespace diffeq::integrators::ode {
2021
* Adaptive: Yes
2122
* Stiff: Excellent
2223
*/
23-
template<system_state S, can_be_time T = double>
24-
class BDFIntegrator : public AdaptiveIntegrator<S, T> {
24+
template<system_state S>
25+
class BDFIntegrator : public AdaptiveIntegrator<S> {
2526
public:
26-
using base_type = AdaptiveIntegrator<S, T>;
27+
using base_type = AdaptiveIntegrator<S>;
2728
using state_type = typename base_type::state_type;
2829
using time_type = typename base_type::time_type;
2930
using value_type = typename base_type::value_type;
@@ -198,93 +199,125 @@ class BDFIntegrator : public AdaptiveIntegrator<S, T> {
198199
residual_it[i] += alpha_coeffs_[current_order_][j] * y_hist_it[i];
199200
}
200201

201-
// Subtract beta*h*f term
202+
// Subtract the f term
202203
residual_it[i] -= beta_coeffs_[current_order_] * dt * f_new_it[i];
203204
}
204205

205206
// Check convergence
206-
time_type residual_norm = static_cast<time_type>(0);
207+
time_type residual_norm = 0;
207208
for (std::size_t i = 0; i < residual.size(); ++i) {
208209
auto residual_it = residual.begin();
209210
residual_norm += residual_it[i] * residual_it[i];
210211
}
211212
residual_norm = std::sqrt(residual_norm);
212213

213214
if (residual_norm < newton_tolerance_) {
214-
// Converged - calculate error estimate using lower order method
215+
// Newton iteration converged
215216
calculate_error_estimate(y_new, error, dt);
216217
return true;
217218
}
218219

219-
// Update y_new using simplified Newton update
220-
// This is a simplified approach - a full implementation would compute the Jacobian
220+
// Update y_new using simplified Newton step
221+
// For simplicity, we use a diagonal approximation of the Jacobian
221222
for (std::size_t i = 0; i < y_new.size(); ++i) {
222223
auto y_new_it = y_new.begin();
223224
auto residual_it = residual.begin();
225+
auto f_new_it = f_new.begin();
226+
227+
// Simplified Newton update: y_new -= residual / (alpha[0] - beta*h*df/dy)
228+
// We approximate df/dy using finite differences
229+
time_type df_dy = estimate_jacobian_diagonal(i, y_new, dt);
230+
time_type denominator = alpha_coeffs_[current_order_][0] - beta_coeffs_[current_order_] * dt * df_dy;
224231

225-
y_new_it[i] = y_new_it[i] - residual_it[i] / alpha_coeffs_[current_order_][0];
232+
if (std::abs(denominator) > newton_tolerance_) {
233+
y_new_it[i] -= residual_it[i] / denominator;
234+
}
226235
}
227236
}
228237

229-
return false; // Newton iteration failed to converge
238+
// Newton iteration failed to converge
239+
return false;
240+
}
241+
242+
time_type estimate_jacobian_diagonal(std::size_t i, const state_type& y, time_type dt) {
243+
// Estimate diagonal element of Jacobian using finite differences
244+
time_type epsilon = static_cast<time_type>(1e-8);
245+
state_type y_pert = StateCreator<state_type>::create(y);
246+
state_type f_orig = StateCreator<state_type>::create(y);
247+
state_type f_pert = StateCreator<state_type>::create(y);
248+
249+
// Evaluate f at original point
250+
this->sys_(this->current_time_ + dt, y, f_orig);
251+
252+
// Perturb y[i] and evaluate f
253+
y_pert = y;
254+
auto y_pert_it = y_pert.begin();
255+
y_pert_it[i] += epsilon;
256+
this->sys_(this->current_time_ + dt, y_pert, f_pert);
257+
258+
// Estimate ∂f_i/∂y_i
259+
auto f_orig_it = f_orig.begin();
260+
auto f_pert_it = f_pert.begin();
261+
return (f_pert_it[i] - f_orig_it[i]) / epsilon;
230262
}
231263

232264
void calculate_error_estimate(const state_type& y_new, state_type& error, time_type dt) {
233-
// Improved error estimate using difference between current and lower order methods
234-
if (current_order_ > 1 && y_history_.size() >= static_cast<size_t>(current_order_)) {
235-
// Use difference between current order and order-1 solution
236-
state_type y_lower = StateCreator<state_type>::create(y_new);
265+
// Simple error estimate based on the difference between current and previous order solutions
266+
if (y_history_.size() >= 2 && current_order_ > 1) {
267+
// Use the difference between current order and previous order as error estimate
268+
state_type y_prev_order = StateCreator<state_type>::create(y_new);
237269

238-
// Calculate order-1 solution (backward Euler)
270+
// Reconstruct solution using previous order
239271
for (std::size_t i = 0; i < y_new.size(); ++i) {
240-
auto y_lower_it = y_lower.begin();
241-
auto y_hist_it = y_history_[0].begin();
242-
y_lower_it[i] = y_hist_it[i];
243-
}
244-
245-
// Simple backward Euler step
246-
state_type f_lower = StateCreator<state_type>::create(y_lower);
247-
this->sys_(this->current_time_ + dt, y_lower, f_lower);
248-
249-
for (std::size_t i = 0; i < y_new.size(); ++i) {
250-
auto y_lower_it = y_lower.begin();
251-
auto f_lower_it = f_lower.begin();
252-
y_lower_it[i] += dt * f_lower_it[i];
253-
}
254-
255-
// Error estimate is the difference
256-
for (std::size_t i = 0; i < error.size(); ++i) {
257-
auto error_it = error.begin();
272+
auto y_prev_order_it = y_prev_order.begin();
258273
auto y_new_it = y_new.begin();
259-
auto y_lower_it = y_lower.begin();
260-
error_it[i] = y_new_it[i] - y_lower_it[i];
274+
275+
y_prev_order_it[i] = alpha_coeffs_[current_order_ - 1][0] * y_new_it[i];
276+
277+
// Add history terms for previous order
278+
for (int j = 1; j < current_order_ && j < static_cast<int>(y_history_.size()); ++j) {
279+
auto y_hist_it = y_history_[j].begin();
280+
y_prev_order_it[i] += alpha_coeffs_[current_order_ - 1][j] * y_hist_it[i];
281+
}
282+
283+
// Calculate error as difference
284+
auto error_it = error.begin();
285+
error_it[i] = std::abs(y_new_it[i] - y_prev_order_it[i]);
261286
}
262287
} else {
263288
// Fallback error estimate
264-
for (std::size_t i = 0; i < error.size(); ++i) {
289+
for (std::size_t i = 0; i < y_new.size(); ++i) {
265290
auto error_it = error.begin();
266-
error_it[i] = dt * static_cast<time_type>(1e-6);
291+
error_it[i] = static_cast<time_type>(1e-6);
267292
}
268293
}
269294
}
270295

271296
time_type fallback_step(state_type& state, time_type dt) {
272-
// Fallback to simple backward Euler when BDF fails
273-
// This ensures the integrator doesn't crash on very stiff problems
297+
// Fallback to backward Euler with very small step
298+
time_type small_dt = std::min(dt, static_cast<time_type>(1e-6));
274299

275-
time_type actual_dt = std::min(dt, static_cast<time_type>(1e-6)); // Very small step
300+
state_type y_new = StateCreator<state_type>::create(state);
301+
state_type f_new = StateCreator<state_type>::create(state);
276302

277-
state_type f = StateCreator<state_type>::create(state);
278-
this->sys_(this->current_time_ + actual_dt, state, f);
279-
280-
for (std::size_t i = 0; i < state.size(); ++i) {
281-
auto state_it = state.begin();
282-
auto f_it = f.begin();
283-
state_it[i] += actual_dt * f_it[i];
303+
// Simple backward Euler iteration
304+
y_new = state;
305+
for (int iter = 0; iter < 5; ++iter) {
306+
this->sys_(this->current_time_ + small_dt, y_new, f_new);
307+
308+
for (std::size_t i = 0; i < state.size(); ++i) {
309+
auto y_new_it = y_new.begin();
310+
auto f_new_it = f_new.begin();
311+
auto state_it = state.begin();
312+
313+
y_new_it[i] = state_it[i] + small_dt * f_new_it[i];
314+
}
284315
}
285316

286-
this->advance_time(actual_dt);
287-
return actual_dt * static_cast<time_type>(2.0); // Suggest doubling for next step
317+
state = y_new;
318+
this->advance_time(small_dt);
319+
320+
return small_dt;
288321
}
289322
};
290323

include/integrators/ode/dop853.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@
77

88
namespace diffeq::integrators::ode {
99

10-
template<system_state S, can_be_time T>
10+
template<system_state S>
1111
class DOP853Integrator;
1212

13-
template<system_state S, can_be_time T>
13+
template<system_state S>
1414
class DOP853DenseOutputHelper {
1515
public:
16-
using value_type = typename DOP853Integrator<S, T>::value_type;
16+
using value_type = typename DOP853Integrator<S>::value_type;
1717

1818
// Dense output for DOP853: ported from Fortran CONTD8
1919
// CON: continuous output coefficients, size 8*nd
@@ -47,12 +47,12 @@ class DOP853DenseOutputHelper {
4747
* Eighth-order method with embedded 5th and 3rd order error estimation.
4848
* Reference: Hairer, Norsett, Wanner, "Solving Ordinary Differential Equations I"
4949
*/
50-
template<system_state S, can_be_time T = double>
51-
class DOP853Integrator : public AdaptiveIntegrator<S, T> {
50+
template<system_state S>
51+
class DOP853Integrator : public AdaptiveIntegrator<S> {
5252

5353

5454
public:
55-
using base_type = AdaptiveIntegrator<S, T>;
55+
using base_type = AdaptiveIntegrator<S>;
5656
using state_type = typename base_type::state_type;
5757
using time_type = typename base_type::time_type;
5858
using value_type = typename base_type::value_type;

include/integrators/ode/euler.hpp

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,18 @@
66
namespace diffeq::integrators::ode {
77

88
/**
9-
* @brief Simple Euler integrator: y_{n+1} = y_n + h * f(t_n, y_n)
9+
* @brief Forward Euler integrator
1010
*
11-
* First-order explicit method for ODEs.
12-
* Simple but not very accurate - mainly for educational purposes.
11+
* First-order explicit method. Simple but often unstable.
1312
*
1413
* Order: 1
15-
* Stability: Conditionally stable
14+
* Stability: Poor for stiff problems
15+
* Usage: Educational purposes, simple problems
1616
*/
17-
template<system_state S, can_be_time T = double>
18-
class EulerIntegrator : public AbstractIntegrator<S, T> {
17+
template<system_state S>
18+
class EulerIntegrator : public AbstractIntegrator<S> {
1919
public:
20-
using base_type = AbstractIntegrator<S, T>;
20+
using base_type = AbstractIntegrator<S>;
2121
using state_type = typename base_type::state_type;
2222
using time_type = typename base_type::time_type;
2323
using value_type = typename base_type::value_type;
@@ -28,17 +28,16 @@ class EulerIntegrator : public AbstractIntegrator<S, T> {
2828

2929
void step(state_type& state, time_type dt) override {
3030
// Create temporary state for derivative
31-
state_type dydt = StateCreator<state_type>::create(state);
31+
state_type derivative = StateCreator<state_type>::create(state);
3232

33-
// Compute derivative: dydt = f(t, y)
34-
this->sys_(this->current_time_, state, dydt);
33+
// Compute derivative: f(t, y)
34+
this->sys_(this->current_time_, state, derivative);
3535

36-
// Update state: y_new = y + dt * dydt
36+
// Update state: y_{n+1} = y_n + dt * f(t_n, y_n)
3737
for (std::size_t i = 0; i < state.size(); ++i) {
3838
auto state_it = state.begin();
39-
auto dydt_it = dydt.begin();
40-
41-
state_it[i] = state_it[i] + dt * dydt_it[i];
39+
auto deriv_it = derivative.begin();
40+
state_it[i] += dt * deriv_it[i];
4241
}
4342

4443
this->advance_time(dt);

include/integrators/ode/improved_euler.hpp

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,18 @@
66
namespace diffeq::integrators::ode {
77

88
/**
9-
* @brief Improved Euler (Heun's method): y_{n+1} = y_n + h/2 * (k1 + k2)
9+
* @brief Improved Euler (Heun's method) integrator
1010
*
11-
* Second-order explicit method where:
12-
* - k1 = f(t_n, y_n)
13-
* - k2 = f(t_n + h, y_n + h*k1)
14-
*
15-
* Also known as Heun's method or the explicit trapezoidal method.
11+
* Second-order explicit method. Better than basic Euler.
1612
*
1713
* Order: 2
18-
* Stability: Better than Euler for most problems
14+
* Stability: Better than Euler, but still poor for stiff problems
15+
* Usage: Simple problems where RK4 is overkill
1916
*/
20-
template<system_state S, can_be_time T = double>
21-
class ImprovedEulerIntegrator : public AbstractIntegrator<S, T> {
17+
template<system_state S>
18+
class ImprovedEulerIntegrator : public AbstractIntegrator<S> {
2219
public:
23-
using base_type = AbstractIntegrator<S, T>;
20+
using base_type = AbstractIntegrator<S>;
2421
using state_type = typename base_type::state_type;
2522
using time_type = typename base_type::time_type;
2623
using value_type = typename base_type::value_type;
@@ -34,29 +31,25 @@ class ImprovedEulerIntegrator : public AbstractIntegrator<S, T> {
3431
state_type k1 = StateCreator<state_type>::create(state);
3532
state_type k2 = StateCreator<state_type>::create(state);
3633
state_type temp_state = StateCreator<state_type>::create(state);
37-
34+
3835
// k1 = f(t, y)
3936
this->sys_(this->current_time_, state, k1);
4037

41-
// temp_state = y + dt * k1
38+
// k2 = f(t + dt, y + dt*k1)
4239
for (std::size_t i = 0; i < state.size(); ++i) {
4340
auto state_it = state.begin();
4441
auto k1_it = k1.begin();
4542
auto temp_it = temp_state.begin();
46-
4743
temp_it[i] = state_it[i] + dt * k1_it[i];
4844
}
49-
50-
// k2 = f(t + dt, temp_state)
5145
this->sys_(this->current_time_ + dt, temp_state, k2);
5246

5347
// y_new = y + dt/2 * (k1 + k2)
5448
for (std::size_t i = 0; i < state.size(); ++i) {
5549
auto state_it = state.begin();
5650
auto k1_it = k1.begin();
5751
auto k2_it = k2.begin();
58-
59-
state_it[i] = state_it[i] + dt * (k1_it[i] + k2_it[i]) / static_cast<time_type>(2);
52+
state_it[i] += dt * (k1_it[i] + k2_it[i]) / static_cast<time_type>(2);
6053
}
6154

6255
this->advance_time(dt);

0 commit comments

Comments
 (0)