@@ -22,7 +22,7 @@ def prior_wiener_integrated(
2222 """Construct an adaptive(/continuous-time), multiply-integrated Wiener process."""
2323 ssm = impl .choose (ssm_fact , tcoeffs_like = tcoeffs )
2424 if output_scale is None :
25- output_scale = ssm .prototypes .output_scale ()
25+ output_scale = np . ones_like ( ssm .prototypes .output_scale () )
2626 discretize = ssm .conditional .ibm_transitions (base_scale = output_scale )
2727 init = ssm .normal .from_tcoeffs (tcoeffs )
2828 return init , discretize , ssm
@@ -35,7 +35,10 @@ def prior_wiener_integrated_discrete(
3535 init , discretize , ssm = prior_wiener_integrated (
3636 tcoeffs_like , output_scale = output_scale , ssm_fact = ssm_fact
3737 )
38- transitions , (p , p_inv ) = functools .vmap (discretize )(np .diff (ts ))
38+
39+ scales = np .ones_like (ssm .prototypes .output_scale ())
40+ discretize_vmap = functools .vmap (discretize , in_axes = (0 , None ))
41+ transitions , (p , p_inv ) = discretize_vmap (np .diff (ts ), scales )
3942
4043 preconditioner_apply_vmap = functools .vmap (ssm .conditional .preconditioner_apply )
4144 conditionals = preconditioner_apply_vmap (transitions , p , p_inv )
@@ -798,7 +801,7 @@ def _calibration_running_mean(*, ssm) -> _Calibration:
798801 # In this case, the _calibration_most_recent() stuff becomes void.
799802
800803 def init ():
801- prior = ssm .prototypes .output_scale ()
804+ prior = np . ones_like ( ssm .prototypes .output_scale () )
802805 return prior , prior , 0.0
803806
804807 def update (state , / , observed ):
@@ -820,7 +823,7 @@ def solver_dynamic(strategy, *, correction, prior, ssm):
820823
821824 def step_dynamic (state , / , * , dt , calibration ):
822825 # Estimate error and calibrate the output scale
823- ones = ssm .prototypes .output_scale ()
826+ ones = np . ones_like ( ssm .prototypes .output_scale () )
824827 transition = prior (dt , ones )
825828 hidden = strategy .extrapolate_mean (state .rv , transition = transition )
826829 t = state .t + dt
@@ -855,7 +858,7 @@ def step_dynamic(state, /, *, dt, calibration):
855858
856859def _calibration_most_recent (* , ssm ) -> _Calibration :
857860 def init ():
858- return ssm .prototypes .output_scale ()
861+ return np . ones_like ( ssm .prototypes .output_scale () )
859862
860863 def update (_state , / , observed ):
861864 return ssm .stats .mahalanobis_norm_relative (0.0 , observed )
@@ -906,7 +909,7 @@ def step(state: _State, *, dt, calibration):
906909
907910def _calibration_none (* , ssm ) -> _Calibration :
908911 def init ():
909- return ssm .prototypes .output_scale ()
912+ return np . ones_like ( ssm .prototypes .output_scale () )
910913
911914 def update (_state , / , observed ):
912915 raise NotImplementedError
0 commit comments