Skip to content

Commit 3595d61

Browse files
committed
Refactor BDF and RK45 integrators for improved readability and performance
- Simplified residual and error calculations by removing unnecessary iterator usage. - Enhanced clarity in the integration steps by directly accessing vector elements. - Improved code maintainability by reducing complexity in the update processes for new states and errors.
1 parent 1836412 commit 3595d61

File tree

2 files changed

+21
-83
lines changed

2 files changed

+21
-83
lines changed

include/integrators/ode/bdf.hpp

Lines changed: 12 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -187,27 +187,21 @@ class BDFIntegrator : public AdaptiveIntegrator<S> {
187187

188188
// Calculate residual: R = alpha[0]*y_{n+1} + sum(alpha[j]*y_{n+1-j}) - beta*h*f(t_{n+1}, y_{n+1})
189189
for (std::size_t i = 0; i < y_new.size(); ++i) {
190-
auto residual_it = residual.begin();
191-
auto y_new_it = y_new.begin();
192-
auto f_new_it = f_new.begin();
193-
194-
residual_it[i] = alpha_coeffs_[current_order_][0] * y_new_it[i];
190+
residual[i] = alpha_coeffs_[current_order_][0] * y_new[i];
195191

196192
// Add history terms
197193
for (int j = 1; j <= current_order_ && j < static_cast<int>(y_history_.size()); ++j) {
198-
auto y_hist_it = y_history_[j].begin();
199-
residual_it[i] += alpha_coeffs_[current_order_][j] * y_hist_it[i];
194+
residual[i] += alpha_coeffs_[current_order_][j] * y_history_[j][i];
200195
}
201196

202197
// Subtract the f term
203-
residual_it[i] -= beta_coeffs_[current_order_] * dt * f_new_it[i];
198+
residual[i] -= beta_coeffs_[current_order_] * dt * f_new[i];
204199
}
205200

206201
// Check convergence
207202
time_type residual_norm = 0;
208203
for (std::size_t i = 0; i < residual.size(); ++i) {
209-
auto residual_it = residual.begin();
210-
residual_norm += residual_it[i] * residual_it[i];
204+
residual_norm += residual[i] * residual[i];
211205
}
212206
residual_norm = std::sqrt(residual_norm);
213207

@@ -220,17 +214,13 @@ class BDFIntegrator : public AdaptiveIntegrator<S> {
220214
// Update y_new using simplified Newton step
221215
// For simplicity, we use a diagonal approximation of the Jacobian
222216
for (std::size_t i = 0; i < y_new.size(); ++i) {
223-
auto y_new_it = y_new.begin();
224-
auto residual_it = residual.begin();
225-
auto f_new_it = f_new.begin();
226-
227217
// Simplified Newton update: y_new -= residual / (alpha[0] - beta*h*df/dy)
228218
// We approximate df/dy using finite differences
229219
time_type df_dy = estimate_jacobian_diagonal(i, y_new, dt);
230220
time_type denominator = alpha_coeffs_[current_order_][0] - beta_coeffs_[current_order_] * dt * df_dy;
231221

232222
if (std::abs(denominator) > newton_tolerance_) {
233-
y_new_it[i] -= residual_it[i] / denominator;
223+
y_new[i] -= residual[i] / denominator;
234224
}
235225
}
236226
}
@@ -251,14 +241,11 @@ class BDFIntegrator : public AdaptiveIntegrator<S> {
251241

252242
// Perturb y[i] and evaluate f
253243
y_pert = y;
254-
auto y_pert_it = y_pert.begin();
255-
y_pert_it[i] += epsilon;
244+
y_pert[i] += epsilon;
256245
this->sys_(this->current_time_ + dt, y_pert, f_pert);
257246

258247
// Estimate ∂f_i/∂y_i
259-
auto f_orig_it = f_orig.begin();
260-
auto f_pert_it = f_pert.begin();
261-
return (f_pert_it[i] - f_orig_it[i]) / epsilon;
248+
return (f_pert[i] - f_orig[i]) / epsilon;
262249
}
263250

264251
void calculate_error_estimate(const state_type& y_new, state_type& error, time_type dt) {
@@ -269,26 +256,20 @@ class BDFIntegrator : public AdaptiveIntegrator<S> {
269256

270257
// Reconstruct solution using previous order
271258
for (std::size_t i = 0; i < y_new.size(); ++i) {
272-
auto y_prev_order_it = y_prev_order.begin();
273-
auto y_new_it = y_new.begin();
274-
275-
y_prev_order_it[i] = alpha_coeffs_[current_order_ - 1][0] * y_new_it[i];
259+
y_prev_order[i] = alpha_coeffs_[current_order_ - 1][0] * y_new[i];
276260

277261
// Add history terms for previous order
278262
for (int j = 1; j < current_order_ && j < static_cast<int>(y_history_.size()); ++j) {
279-
auto y_hist_it = y_history_[j].begin();
280-
y_prev_order_it[i] += alpha_coeffs_[current_order_ - 1][j] * y_hist_it[i];
263+
y_prev_order[i] += alpha_coeffs_[current_order_ - 1][j] * y_history_[j][i];
281264
}
282265

283266
// Calculate error as difference
284-
auto error_it = error.begin();
285-
error_it[i] = std::abs(y_new_it[i] - y_prev_order_it[i]);
267+
error[i] = std::abs(y_new[i] - y_prev_order[i]);
286268
}
287269
} else {
288270
// Fallback error estimate
289271
for (std::size_t i = 0; i < y_new.size(); ++i) {
290-
auto error_it = error.begin();
291-
error_it[i] = static_cast<time_type>(1e-6);
272+
error[i] = static_cast<time_type>(1e-6);
292273
}
293274
}
294275
}
@@ -306,11 +287,7 @@ class BDFIntegrator : public AdaptiveIntegrator<S> {
306287
this->sys_(this->current_time_ + small_dt, y_new, f_new);
307288

308289
for (std::size_t i = 0; i < state.size(); ++i) {
309-
auto y_new_it = y_new.begin();
310-
auto f_new_it = f_new.begin();
311-
auto state_it = state.begin();
312-
313-
y_new_it[i] = state_it[i] + small_dt * f_new_it[i];
290+
y_new[i] = state[i] + small_dt * f_new[i];
314291
}
315292
}
316293

include/integrators/ode/rk45.hpp

Lines changed: 9 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -81,89 +81,50 @@ class RK45Integrator : public AdaptiveIntegrator<S> {
8181
constexpr time_type c4_4 = static_cast<time_type>(393.0/640.0);
8282
constexpr time_type c5_4 = static_cast<time_type>(-92097.0/339200.0);
8383
constexpr time_type c6_4 = static_cast<time_type>(187.0/2100.0);
84-
constexpr time_type c7_4 = static_cast<time_type>(1.0/40.0);
84+
// Note: c7_4 = 1.0/40.0 is not used in RK45 (only in RK45 with FSAL)
8585

8686
// k1 = f(t, y)
8787
this->sys_(this->current_time_, state, k1);
8888

8989
// k2 = f(t + a2*dt, y + dt*(b21*k1))
9090
for (std::size_t i = 0; i < state.size(); ++i) {
91-
auto state_it = state.begin();
92-
auto k1_it = k1.begin();
93-
auto temp_it = temp_state.begin();
94-
temp_it[i] = state_it[i] + dt * b21 * k1_it[i];
91+
temp_state[i] = state[i] + dt * b21 * k1[i];
9592
}
9693
this->sys_(this->current_time_ + a2 * dt, temp_state, k2);
9794

9895
// k3 = f(t + a3*dt, y + dt*(b31*k1 + b32*k2))
9996
for (std::size_t i = 0; i < state.size(); ++i) {
100-
auto state_it = state.begin();
101-
auto k1_it = k1.begin();
102-
auto k2_it = k2.begin();
103-
auto temp_it = temp_state.begin();
104-
temp_it[i] = state_it[i] + dt * (b31 * k1_it[i] + b32 * k2_it[i]);
97+
temp_state[i] = state[i] + dt * (b31 * k1[i] + b32 * k2[i]);
10598
}
10699
this->sys_(this->current_time_ + a3 * dt, temp_state, k3);
107100

108101
// k4 = f(t + a4*dt, y + dt*(b41*k1 + b42*k2 + b43*k3))
109102
for (std::size_t i = 0; i < state.size(); ++i) {
110-
auto state_it = state.begin();
111-
auto k1_it = k1.begin();
112-
auto k2_it = k2.begin();
113-
auto k3_it = k3.begin();
114-
auto temp_it = temp_state.begin();
115-
temp_it[i] = state_it[i] + dt * (b41 * k1_it[i] + b42 * k2_it[i] + b43 * k3_it[i]);
103+
temp_state[i] = state[i] + dt * (b41 * k1[i] + b42 * k2[i] + b43 * k3[i]);
116104
}
117105
this->sys_(this->current_time_ + a4 * dt, temp_state, k4);
118106

119107
// k5 = f(t + a5*dt, y + dt*(b51*k1 + b52*k2 + b53*k3 + b54*k4))
120108
for (std::size_t i = 0; i < state.size(); ++i) {
121-
auto state_it = state.begin();
122-
auto k1_it = k1.begin();
123-
auto k2_it = k2.begin();
124-
auto k3_it = k3.begin();
125-
auto k4_it = k4.begin();
126-
auto temp_it = temp_state.begin();
127-
temp_it[i] = state_it[i] + dt * (b51 * k1_it[i] + b52 * k2_it[i] + b53 * k3_it[i] + b54 * k4_it[i]);
109+
temp_state[i] = state[i] + dt * (b51 * k1[i] + b52 * k2[i] + b53 * k3[i] + b54 * k4[i]);
128110
}
129111
this->sys_(this->current_time_ + a5 * dt, temp_state, k5);
130112

131113
// k6 = f(t + dt, y + dt*(b61*k1 + b62*k2 + b63*k3 + b64*k4 + b65*k5))
132114
for (std::size_t i = 0; i < state.size(); ++i) {
133-
auto state_it = state.begin();
134-
auto k1_it = k1.begin();
135-
auto k2_it = k2.begin();
136-
auto k3_it = k3.begin();
137-
auto k4_it = k4.begin();
138-
auto k5_it = k5.begin();
139-
auto temp_it = temp_state.begin();
140-
temp_it[i] = state_it[i] + dt * (b61 * k1_it[i] + b62 * k2_it[i] + b63 * k3_it[i] + b64 * k4_it[i] + b65 * k5_it[i]);
115+
temp_state[i] = state[i] + dt * (b61 * k1[i] + b62 * k2[i] + b63 * k3[i] + b64 * k4[i] + b65 * k5[i]);
141116
}
142117
this->sys_(this->current_time_ + dt, temp_state, k6);
143118

144119
// 5th order solution: y_new = y + dt*(c1*k1 + c3*k3 + c4*k4 + c5*k5 + c6*k6)
145120
for (std::size_t i = 0; i < state.size(); ++i) {
146-
auto state_it = state.begin();
147-
auto k1_it = k1.begin();
148-
auto k3_it = k3.begin();
149-
auto k4_it = k4.begin();
150-
auto k5_it = k5.begin();
151-
auto k6_it = k6.begin();
152-
auto y_new_it = y_new.begin();
153-
y_new_it[i] = state_it[i] + dt * (c1 * k1_it[i] + c3 * k3_it[i] + c4 * k4_it[i] + c5 * k5_it[i] + c6 * k6_it[i]);
121+
y_new[i] = state[i] + dt * (c1 * k1[i] + c3 * k3[i] + c4 * k4[i] + c5 * k5[i] + c6 * k6[i]);
154122
}
155123

156124
// 4th order solution for error estimation
157125
for (std::size_t i = 0; i < state.size(); ++i) {
158-
auto state_it = state.begin();
159-
auto k1_it = k1.begin();
160-
auto k3_it = k3.begin();
161-
auto k4_it = k4.begin();
162-
auto k5_it = k5.begin();
163-
auto k6_it = k6.begin();
164-
auto error_it = error.begin();
165-
error_it[i] = dt * ((c1 - c1_4) * k1_it[i] + (c3 - c3_4) * k3_it[i] + (c4 - c4_4) * k4_it[i] +
166-
(c5 - c5_4) * k5_it[i] + (c6 - c6_4) * k6_it[i]);
126+
error[i] = dt * ((c1 - c1_4) * k1[i] + (c3 - c3_4) * k3[i] + (c4 - c4_4) * k4[i] +
127+
(c5 - c5_4) * k5[i] + (c6 - c6_4) * k6[i]);
167128
}
168129

169130
// Calculate error norm

0 commit comments

Comments
 (0)