Skip to content

Commit 26014b1

Browse files
committed
Implement _Correction.correct
1 parent 98307a0 commit 26014b1

1 file changed

Lines changed: 14 additions & 34 deletions

File tree

probdiffeq/ivpsolvers.py

Lines changed: 14 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -502,30 +502,32 @@ class _Correction:
502502
linearize: Callable
503503
vector_field: Callable
504504

505-
use_re_linearize: bool
506505
can_handle_higher_order: bool
507506

508507
def init(self, x, /):
509508
"""Initialise the state from the solution."""
510509
y = self.ssm.prototypes.observed()
511510
return x, y
512511

512+
def correct(self, rv, /, t):
513+
f_wrapped = self._parametrize_vector_field(t=t)
514+
cond = self.linearize(f_wrapped, rv)
515+
observed, reverted = self.ssm.conditional.revert(rv, cond)
516+
corrected = reverted.noise
517+
return corrected, observed
518+
513519
def estimate_error(self, rv, /, t):
514520
"""Perform all elements of the correction until the error estimate."""
515521
f_wrapped = self._parametrize_vector_field(t=t)
516522
cond = self.linearize(f_wrapped, rv)
517523
observed = self.ssm.conditional.marginalise(rv, cond)
518524

519-
# TODO: the functions involved in error estimation are still a bit patchy.
520-
# for instance, they assume that they are called
521-
# in exactly this error estimation
522-
# context. Same for prototype_qoi etc.
523525
zero_data = np.zeros(())
524526
output_scale = self.ssm.stats.mahalanobis_norm_relative(zero_data, rv=observed)
525527
stdev = self.ssm.stats.standard_deviation(observed)
526528
error_estimate_unscaled = np.squeeze(stdev)
527529
error_estimate = output_scale * error_estimate_unscaled
528-
return error_estimate, observed, (cond, f_wrapped)
530+
return error_estimate, observed
529531

530532
def _parametrize_vector_field(self, *, t):
531533
if self.can_handle_higher_order:
@@ -537,15 +539,6 @@ def f_wrapped(s):
537539

538540
return functools.partial(self.vector_field, t=t)
539541

540-
def complete(self, rv, cache, /):
541-
"""Complete what has been left out by `estimate_error`."""
542-
cond, f_wrapped = cache
543-
if self.use_re_linearize:
544-
cond = self.linearize(f_wrapped, rv)
545-
observed, reverted = self.ssm.conditional.revert(rv, cond)
546-
corrected = reverted.noise
547-
return corrected, observed
548-
549542

550543
def correction_ts0(vector_field, *, ssm, ode_order=1, damp: float = 0.0) -> _Correction:
551544
"""Zeroth-order Taylor linearisation."""
@@ -556,7 +549,6 @@ def correction_ts0(vector_field, *, ssm, ode_order=1, damp: float = 0.0) -> _Cor
556549
ode_order=ode_order,
557550
ssm=ssm,
558551
linearize=linearize,
559-
use_re_linearize=False,
560552
can_handle_higher_order=True,
561553
)
562554

@@ -570,7 +562,6 @@ def correction_ts1(vector_field, *, ssm, ode_order=1, damp: float = 0.0) -> _Cor
570562
ode_order=ode_order,
571563
ssm=ssm,
572564
linearize=linearize,
573-
use_re_linearize=False,
574565
can_handle_higher_order=True,
575566
)
576567

@@ -586,7 +577,6 @@ def correction_slr0(
586577
ode_order=1,
587578
linearize=linearize,
588579
name="SLR0",
589-
use_re_linearize=True,
590580
can_handle_higher_order=False, # TODO: implement this
591581
)
592582

@@ -602,7 +592,6 @@ def correction_slr1(
602592
ode_order=1,
603593
linearize=linearize,
604594
name="SLR1",
605-
use_re_linearize=True,
606595
can_handle_higher_order=False, # TODO: implement this
607596
)
608597

@@ -761,22 +750,15 @@ def step_mle(state, /, *, dt, calibration):
761750
transition = prior(dt, output_scale_prior)
762751
mean_extra = strategy.extrapolate_mean(state.rv, transition=transition)
763752
t = state.t + dt
764-
error, *_ = correction.estimate_error(mean_extra, t=t)
765-
766-
# TODO (next):
767-
# - Give correction a step() function that does both.
768-
# - Then, see whether begin and complete are called separately anywhere else.
769-
# - Then, move output scale arguments to prior evaluation.
770-
# - Then, done?
753+
error, _ = correction.estimate_error(mean_extra, t=t)
771754

772755
# Do the full prediction step (reuse previous discretisation)
773756
hidden, extra = strategy.extrapolate(
774757
state.rv, state.strategy_state, transition=transition
775758
)
776759

777760
# Do the full correction step
778-
*_, corr = correction.estimate_error(hidden, t=t)
779-
hidden, observed = correction.complete(hidden, corr)
761+
hidden, observed = correction.correct(hidden, t=t)
780762

781763
# Calibrate the output scale
782764
output_scale = calibration.update(state.output_scale, observed=observed)
@@ -827,7 +809,7 @@ def step_dynamic(state, /, *, dt, calibration):
827809
transition = prior(dt, ones)
828810
hidden = strategy.extrapolate_mean(state.rv, transition=transition)
829811
t = state.t + dt
830-
error, observed, _ = correction.estimate_error(hidden, t=t)
812+
error, observed = correction.estimate_error(hidden, t=t)
831813
output_scale = calibration.update(state.output_scale, observed=observed)
832814

833815
# Do the full extrapolation with the calibrated output scale
@@ -838,8 +820,7 @@ def step_dynamic(state, /, *, dt, calibration):
838820
)
839821

840822
# Do the full correction step
841-
*_, corr = correction.estimate_error(hidden, t=t)
842-
hidden, corr = correction.complete(hidden, corr)
823+
hidden, corr = correction.correct(hidden, t=t)
843824

844825
# Return solution
845826
state = _State(t=t, rv=hidden, strategy_state=extra, output_scale=output_scale)
@@ -879,16 +860,15 @@ def step(state: _State, *, dt, calibration):
879860
transition = prior(dt, state.output_scale)
880861
hidden = strategy.extrapolate_mean(state.rv, transition=transition)
881862
t = state.t + dt
882-
error, *_ = correction.estimate_error(hidden, t=t)
863+
error, _ = correction.estimate_error(hidden, t=t)
883864

884865
# Do the full extrapolation step (reuse the transition)
885866
hidden, extra = strategy.extrapolate(
886867
state.rv, state.strategy_state, transition=transition
887868
)
888869

889870
# Do the full correction step
890-
*_, corr = correction.estimate_error(hidden, t=t)
891-
hidden, corr = correction.complete(hidden, corr)
871+
hidden, corr = correction.correct(hidden, t=t)
892872

893873
# Extract and return solution
894874
state = _State(

0 commit comments

Comments
 (0)