Skip to content

Commit 41b27ea

Browse files
committed
Make error normalisation an actual class method
1 parent 76e7c50 commit 41b27ea

1 file changed

Lines changed: 11 additions & 11 deletions

File tree

probdiffeq/ivpsolvers.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -995,7 +995,7 @@ def body_fn(state: _RejectionState) -> _RejectionState:
995995
)
996996

997997
# Propose a new step
998-
error_power = _error_scale_and_normalize(error_estimate)
998+
error_power = self._error_scale_and_normalize(error_estimate)
999999
dt, state_control = self.control.apply(
10001000
dt, state.control, error_power=error_power
10011001
)
@@ -1007,16 +1007,6 @@ def body_fn(state: _RejectionState) -> _RejectionState:
10071007
step_from=state.step_from,
10081008
)
10091009

1010-
def _error_scale_and_normalize(error: _ErrorEstimate):
1011-
assert isinstance(error, _ErrorEstimate)
1012-
normalize = self.atol + self.rtol * np.abs(error.reference)
1013-
error_relative = error.estimate / normalize
1014-
1015-
dim = np.atleast_1d(error.reference).size
1016-
error_norm = linalg.vector_norm(error_relative, order=self.norm_ord)
1017-
error_norm_rel = error_norm / np.sqrt(dim)
1018-
return error_norm_rel ** (-1.0 / self.solver.error_contraction_rate)
1019-
10201010
def extract(s: _RejectionState) -> _AdaState:
10211011
num_steps = state0.stats + 1.0 # TODO: track step attempts as well
10221012
return _AdaState(
@@ -1031,6 +1021,16 @@ def extract(s: _RejectionState) -> _AdaState:
10311021
state_new = control_flow.while_loop(cond_fn, body_fn, init_val)
10321022
return extract(state_new)
10331023

1024+
def _error_scale_and_normalize(self, error: _ErrorEstimate):
1025+
assert isinstance(error, _ErrorEstimate)
1026+
normalize = self.atol + self.rtol * np.abs(error.reference)
1027+
error_relative = error.estimate / normalize
1028+
1029+
dim = np.atleast_1d(error.reference).size
1030+
error_norm = linalg.vector_norm(error_relative, order=self.norm_ord)
1031+
error_norm_rel = error_norm / np.sqrt(dim)
1032+
return error_norm_rel ** (-1.0 / self.solver.error_contraction_rate)
1033+
10341034
def extract_before_t1(self, state: _AdaState, t):
10351035
del t
10361036
solution_solver = self.solver.extract(state.step_from)

0 commit comments

Comments
 (0)