11#pragma once
2+ #include < core/concepts.hpp>
23#include < core/adaptive_integrator.hpp>
34#include < core/state_creator.hpp>
45#include < vector>
@@ -20,10 +21,10 @@ namespace diffeq::integrators::ode {
2021 * Adaptive: Yes
2122 * Stiff: Excellent
2223 */
23- template <system_state S, can_be_time T = double >
24- class BDFIntegrator : public AdaptiveIntegrator <S, T > {
24+ template <system_state S>
25+ class BDFIntegrator : public AdaptiveIntegrator <S> {
2526public:
26- using base_type = AdaptiveIntegrator<S, T >;
27+ using base_type = AdaptiveIntegrator<S>;
2728 using state_type = typename base_type::state_type;
2829 using time_type = typename base_type::time_type;
2930 using value_type = typename base_type::value_type;
@@ -198,93 +199,125 @@ class BDFIntegrator : public AdaptiveIntegrator<S, T> {
198199 residual_it[i] += alpha_coeffs_[current_order_][j] * y_hist_it[i];
199200 }
200201
201- // Subtract beta*h* f term
202+ // Subtract the f term
202203 residual_it[i] -= beta_coeffs_[current_order_] * dt * f_new_it[i];
203204 }
204205
205206 // Check convergence
206- time_type residual_norm = static_cast <time_type>( 0 ) ;
207+ time_type residual_norm = 0 ;
207208 for (std::size_t i = 0 ; i < residual.size (); ++i) {
208209 auto residual_it = residual.begin ();
209210 residual_norm += residual_it[i] * residual_it[i];
210211 }
211212 residual_norm = std::sqrt (residual_norm);
212213
213214 if (residual_norm < newton_tolerance_) {
214- // Converged - calculate error estimate using lower order method
215+ // Newton iteration converged
215216 calculate_error_estimate (y_new, error, dt);
216217 return true ;
217218 }
218219
219- // Update y_new using simplified Newton update
220- // This is a simplified approach - a full implementation would compute the Jacobian
220+ // Update y_new using simplified Newton step
221+ // For simplicity, we use a diagonal approximation of the Jacobian
221222 for (std::size_t i = 0 ; i < y_new.size (); ++i) {
222223 auto y_new_it = y_new.begin ();
223224 auto residual_it = residual.begin ();
225+ auto f_new_it = f_new.begin ();
226+
227+ // Simplified Newton update: y_new -= residual / (alpha[0] - beta*h*df/dy)
228+ // We approximate df/dy using finite differences
229+ time_type df_dy = estimate_jacobian_diagonal (i, y_new, dt);
230+ time_type denominator = alpha_coeffs_[current_order_][0 ] - beta_coeffs_[current_order_] * dt * df_dy;
224231
225- y_new_it[i] = y_new_it[i] - residual_it[i] / alpha_coeffs_[current_order_][0 ];
232+ if (std::abs (denominator) > newton_tolerance_) {
233+ y_new_it[i] -= residual_it[i] / denominator;
234+ }
226235 }
227236 }
228237
229- return false ; // Newton iteration failed to converge
238+ // Newton iteration failed to converge
239+ return false ;
240+ }
241+
242+ time_type estimate_jacobian_diagonal (std::size_t i, const state_type& y, time_type dt) {
243+ // Estimate diagonal element of Jacobian using finite differences
244+ time_type epsilon = static_cast <time_type>(1e-8 );
245+ state_type y_pert = StateCreator<state_type>::create (y);
246+ state_type f_orig = StateCreator<state_type>::create (y);
247+ state_type f_pert = StateCreator<state_type>::create (y);
248+
249+ // Evaluate f at original point
250+ this ->sys_ (this ->current_time_ + dt, y, f_orig);
251+
252+ // Perturb y[i] and evaluate f
253+ y_pert = y;
254+ auto y_pert_it = y_pert.begin ();
255+ y_pert_it[i] += epsilon;
256+ this ->sys_ (this ->current_time_ + dt, y_pert, f_pert);
257+
258+ // Estimate ∂f_i/∂y_i
259+ auto f_orig_it = f_orig.begin ();
260+ auto f_pert_it = f_pert.begin ();
261+ return (f_pert_it[i] - f_orig_it[i]) / epsilon;
230262 }
231263
232264 void calculate_error_estimate (const state_type& y_new, state_type& error, time_type dt) {
233- // Improved error estimate using difference between current and lower order methods
234- if (current_order_ > 1 && y_history_.size () >= static_cast < size_t >( current_order_) ) {
235- // Use difference between current order and order-1 solution
236- state_type y_lower = StateCreator<state_type>::create (y_new);
265+ // Simple error estimate based on the difference between current and previous order solutions
266+ if (y_history_.size () >= 2 && current_order_ > 1 ) {
267+ // Use the difference between current order and previous order as error estimate
268+ state_type y_prev_order = StateCreator<state_type>::create (y_new);
237269
238- // Calculate order-1 solution (backward Euler)
270+ // Reconstruct solution using previous order
239271 for (std::size_t i = 0 ; i < y_new.size (); ++i) {
240- auto y_lower_it = y_lower.begin ();
241- auto y_hist_it = y_history_[0 ].begin ();
242- y_lower_it[i] = y_hist_it[i];
243- }
244-
245- // Simple backward Euler step
246- state_type f_lower = StateCreator<state_type>::create (y_lower);
247- this ->sys_ (this ->current_time_ + dt, y_lower, f_lower);
248-
249- for (std::size_t i = 0 ; i < y_new.size (); ++i) {
250- auto y_lower_it = y_lower.begin ();
251- auto f_lower_it = f_lower.begin ();
252- y_lower_it[i] += dt * f_lower_it[i];
253- }
254-
255- // Error estimate is the difference
256- for (std::size_t i = 0 ; i < error.size (); ++i) {
257- auto error_it = error.begin ();
272+ auto y_prev_order_it = y_prev_order.begin ();
258273 auto y_new_it = y_new.begin ();
259- auto y_lower_it = y_lower.begin ();
260- error_it[i] = y_new_it[i] - y_lower_it[i];
274+
275+ y_prev_order_it[i] = alpha_coeffs_[current_order_ - 1 ][0 ] * y_new_it[i];
276+
277+ // Add history terms for previous order
278+ for (int j = 1 ; j < current_order_ && j < static_cast <int >(y_history_.size ()); ++j) {
279+ auto y_hist_it = y_history_[j].begin ();
280+ y_prev_order_it[i] += alpha_coeffs_[current_order_ - 1 ][j] * y_hist_it[i];
281+ }
282+
283+ // Calculate error as difference
284+ auto error_it = error.begin ();
285+ error_it[i] = std::abs (y_new_it[i] - y_prev_order_it[i]);
261286 }
262287 } else {
263288 // Fallback error estimate
264- for (std::size_t i = 0 ; i < error .size (); ++i) {
289+ for (std::size_t i = 0 ; i < y_new .size (); ++i) {
265290 auto error_it = error.begin ();
266- error_it[i] = dt * static_cast <time_type>(1e-6 );
291+ error_it[i] = static_cast <time_type>(1e-6 );
267292 }
268293 }
269294 }
270295
271296 time_type fallback_step (state_type& state, time_type dt) {
272- // Fallback to simple backward Euler when BDF fails
273- // This ensures the integrator doesn't crash on very stiff problems
297+ // Fallback to backward Euler with very small step
298+ time_type small_dt = std::min (dt, static_cast <time_type>( 1e-6 ));
274299
275- time_type actual_dt = std::min (dt, static_cast <time_type>(1e-6 )); // Very small step
300+ state_type y_new = StateCreator<state_type>::create (state);
301+ state_type f_new = StateCreator<state_type>::create (state);
276302
277- state_type f = StateCreator<state_type>::create (state);
278- this ->sys_ (this ->current_time_ + actual_dt, state, f);
279-
280- for (std::size_t i = 0 ; i < state.size (); ++i) {
281- auto state_it = state.begin ();
282- auto f_it = f.begin ();
283- state_it[i] += actual_dt * f_it[i];
303+ // Simple backward Euler iteration
304+ y_new = state;
305+ for (int iter = 0 ; iter < 5 ; ++iter) {
306+ this ->sys_ (this ->current_time_ + small_dt, y_new, f_new);
307+
308+ for (std::size_t i = 0 ; i < state.size (); ++i) {
309+ auto y_new_it = y_new.begin ();
310+ auto f_new_it = f_new.begin ();
311+ auto state_it = state.begin ();
312+
313+ y_new_it[i] = state_it[i] + small_dt * f_new_it[i];
314+ }
284315 }
285316
286- this ->advance_time (actual_dt);
287- return actual_dt * static_cast <time_type>(2.0 ); // Suggest doubling for next step
317+ state = y_new;
318+ this ->advance_time (small_dt);
319+
320+ return small_dt;
288321 }
289322};
290323
0 commit comments