@@ -502,30 +502,32 @@ class _Correction:
502502 linearize : Callable
503503 vector_field : Callable
504504
505- use_re_linearize : bool
506505 can_handle_higher_order : bool
507506
508507 def init (self , x , / ):
509508 """Initialise the state from the solution."""
510509 y = self .ssm .prototypes .observed ()
511510 return x , y
512511
512+ def correct (self , rv , / , t ):
513+ f_wrapped = self ._parametrize_vector_field (t = t )
514+ cond = self .linearize (f_wrapped , rv )
515+ observed , reverted = self .ssm .conditional .revert (rv , cond )
516+ corrected = reverted .noise
517+ return corrected , observed
518+
513519 def estimate_error (self , rv , / , t ):
514520 """Perform all elements of the correction until the error estimate."""
515521 f_wrapped = self ._parametrize_vector_field (t = t )
516522 cond = self .linearize (f_wrapped , rv )
517523 observed = self .ssm .conditional .marginalise (rv , cond )
518524
519- # TODO: the functions involved in error estimation are still a bit patchy.
520- # for instance, they assume that they are called
521- # in exactly this error estimation
522- # context. Same for prototype_qoi etc.
523525 zero_data = np .zeros (())
524526 output_scale = self .ssm .stats .mahalanobis_norm_relative (zero_data , rv = observed )
525527 stdev = self .ssm .stats .standard_deviation (observed )
526528 error_estimate_unscaled = np .squeeze (stdev )
527529 error_estimate = output_scale * error_estimate_unscaled
528- return error_estimate , observed , ( cond , f_wrapped )
530+ return error_estimate , observed
529531
530532 def _parametrize_vector_field (self , * , t ):
531533 if self .can_handle_higher_order :
@@ -537,15 +539,6 @@ def f_wrapped(s):
537539
538540 return functools .partial (self .vector_field , t = t )
539541
540- def complete (self , rv , cache , / ):
541- """Complete what has been left out by `estimate_error`."""
542- cond , f_wrapped = cache
543- if self .use_re_linearize :
544- cond = self .linearize (f_wrapped , rv )
545- observed , reverted = self .ssm .conditional .revert (rv , cond )
546- corrected = reverted .noise
547- return corrected , observed
548-
549542
550543def correction_ts0 (vector_field , * , ssm , ode_order = 1 , damp : float = 0.0 ) -> _Correction :
551544 """Zeroth-order Taylor linearisation."""
@@ -556,7 +549,6 @@ def correction_ts0(vector_field, *, ssm, ode_order=1, damp: float = 0.0) -> _Cor
556549 ode_order = ode_order ,
557550 ssm = ssm ,
558551 linearize = linearize ,
559- use_re_linearize = False ,
560552 can_handle_higher_order = True ,
561553 )
562554
@@ -570,7 +562,6 @@ def correction_ts1(vector_field, *, ssm, ode_order=1, damp: float = 0.0) -> _Cor
570562 ode_order = ode_order ,
571563 ssm = ssm ,
572564 linearize = linearize ,
573- use_re_linearize = False ,
574565 can_handle_higher_order = True ,
575566 )
576567
@@ -586,7 +577,6 @@ def correction_slr0(
586577 ode_order = 1 ,
587578 linearize = linearize ,
588579 name = "SLR0" ,
589- use_re_linearize = True ,
590580 can_handle_higher_order = False , # TODO: implement this
591581 )
592582
@@ -602,7 +592,6 @@ def correction_slr1(
602592 ode_order = 1 ,
603593 linearize = linearize ,
604594 name = "SLR1" ,
605- use_re_linearize = True ,
606595 can_handle_higher_order = False , # TODO: implement this
607596 )
608597
@@ -761,22 +750,15 @@ def step_mle(state, /, *, dt, calibration):
761750 transition = prior (dt , output_scale_prior )
762751 mean_extra = strategy .extrapolate_mean (state .rv , transition = transition )
763752 t = state .t + dt
764- error , * _ = correction .estimate_error (mean_extra , t = t )
765-
766- # TODO (next):
767- # - Give correction a step() function that does both.
768- # - Then, see whether begin and complete are called separately anywhere else.
769- # - Then, move output scale arguments to prior evaluation.
770- # - Then, done?
753+ error , _ = correction .estimate_error (mean_extra , t = t )
771754
772755 # Do the full prediction step (reuse previous discretisation)
773756 hidden , extra = strategy .extrapolate (
774757 state .rv , state .strategy_state , transition = transition
775758 )
776759
777760 # Do the full correction step
778- * _ , corr = correction .estimate_error (hidden , t = t )
779- hidden , observed = correction .complete (hidden , corr )
761+ hidden , observed = correction .correct (hidden , t = t )
780762
781763 # Calibrate the output scale
782764 output_scale = calibration .update (state .output_scale , observed = observed )
@@ -827,7 +809,7 @@ def step_dynamic(state, /, *, dt, calibration):
827809 transition = prior (dt , ones )
828810 hidden = strategy .extrapolate_mean (state .rv , transition = transition )
829811 t = state .t + dt
830- error , observed , _ = correction .estimate_error (hidden , t = t )
812+ error , observed = correction .estimate_error (hidden , t = t )
831813 output_scale = calibration .update (state .output_scale , observed = observed )
832814
833815 # Do the full extrapolation with the calibrated output scale
@@ -838,8 +820,7 @@ def step_dynamic(state, /, *, dt, calibration):
838820 )
839821
840822 # Do the full correction step
841- * _ , corr = correction .estimate_error (hidden , t = t )
842- hidden , corr = correction .complete (hidden , corr )
823+ hidden , corr = correction .correct (hidden , t = t )
843824
844825 # Return solution
845826 state = _State (t = t , rv = hidden , strategy_state = extra , output_scale = output_scale )
@@ -879,16 +860,15 @@ def step(state: _State, *, dt, calibration):
879860 transition = prior (dt , state .output_scale )
880861 hidden = strategy .extrapolate_mean (state .rv , transition = transition )
881862 t = state .t + dt
882- error , * _ = correction .estimate_error (hidden , t = t )
863+ error , _ = correction .estimate_error (hidden , t = t )
883864
884865 # Do the full extrapolation step (reuse the transition)
885866 hidden , extra = strategy .extrapolate (
886867 state .rv , state .strategy_state , transition = transition
887868 )
888869
889870 # Do the full correction step
890- * _ , corr = correction .estimate_error (hidden , t = t )
891- hidden , corr = correction .complete (hidden , corr )
871+ hidden , corr = correction .correct (hidden , t = t )
892872
893873 # Extract and return solution
894874 state = _State (
0 commit comments