@@ -441,24 +441,22 @@ class _Correction:
441441 linearize : Callable
442442 vector_field : Callable
443443
444- can_handle_higher_order : bool
445-
446444 def init (self , x , / ):
447445 """Initialise the state from the solution."""
448446 y = self .ssm .prototypes .observed ()
449447 return x , y
450448
451449 def correct (self , rv , / , t ):
452450 """Perform the correction step."""
453- f_wrapped = self . _parametrize_vector_field ( t = t )
451+ f_wrapped = functools . partial ( self . vector_field , t = t )
454452 cond = self .linearize (f_wrapped , rv )
455453 observed , reverted = self .ssm .conditional .revert (rv , cond )
456454 corrected = reverted .noise
457455 return corrected , observed
458456
459457 def estimate_error (self , rv , / , t ):
460458 """Estimate the error."""
461- f_wrapped = self . _parametrize_vector_field ( t = t )
459+ f_wrapped = functools . partial ( self . vector_field , t = t )
462460 cond = self .linearize (f_wrapped , rv )
463461 observed = self .ssm .conditional .marginalise (rv , cond )
464462
@@ -469,9 +467,6 @@ def estimate_error(self, rv, /, t):
469467 error_estimate = output_scale * error_estimate_unscaled
470468 return error_estimate , observed
471469
472- def _parametrize_vector_field (self , * , t ):
473- return functools .partial (self .vector_field , t = t )
474-
475470
476471def correction_ts0 (vector_field , * , ssm , ode_order = 1 , damp : float = 0.0 ) -> _Correction :
477472 """Zeroth-order Taylor linearisation."""
@@ -482,7 +477,6 @@ def correction_ts0(vector_field, *, ssm, ode_order=1, damp: float = 0.0) -> _Cor
482477 ode_order = ode_order ,
483478 ssm = ssm ,
484479 linearize = linearize ,
485- can_handle_higher_order = True ,
486480 )
487481
488482
@@ -495,7 +489,6 @@ def correction_ts1(vector_field, *, ssm, ode_order=1, damp: float = 0.0) -> _Cor
495489 ode_order = ode_order ,
496490 ssm = ssm ,
497491 linearize = linearize ,
498- can_handle_higher_order = True ,
499492 )
500493
501494
@@ -510,7 +503,6 @@ def correction_slr0(
510503 ode_order = 1 ,
511504 linearize = linearize ,
512505 name = "SLR0" ,
513- can_handle_higher_order = False , # TODO: implement this
514506 )
515507
516508
@@ -525,7 +517,6 @@ def correction_slr1(
525517 ode_order = 1 ,
526518 linearize = linearize ,
527519 name = "SLR1" ,
528- can_handle_higher_order = False , # TODO: implement this
529520 )
530521
531522
0 commit comments