Skip to content

Commit 98307a0

Browse files
committed
Reintroduce output scale into ibm transitions
1 parent 610c3e9 commit 98307a0

3 files changed

Lines changed: 16 additions & 10 deletions

File tree

probdiffeq/ivpsolvers.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def prior_wiener_integrated(
2222
"""Construct an adaptive(/continuous-time), multiply-integrated Wiener process."""
2323
ssm = impl.choose(ssm_fact, tcoeffs_like=tcoeffs)
2424
if output_scale is None:
25-
output_scale = ssm.prototypes.output_scale()
25+
output_scale = np.ones_like(ssm.prototypes.output_scale())
2626
discretize = ssm.conditional.ibm_transitions(base_scale=output_scale)
2727
init = ssm.normal.from_tcoeffs(tcoeffs)
2828
return init, discretize, ssm
@@ -35,7 +35,10 @@ def prior_wiener_integrated_discrete(
3535
init, discretize, ssm = prior_wiener_integrated(
3636
tcoeffs_like, output_scale=output_scale, ssm_fact=ssm_fact
3737
)
38-
transitions, (p, p_inv) = functools.vmap(discretize)(np.diff(ts))
38+
39+
scales = np.ones_like(ssm.prototypes.output_scale())
40+
discretize_vmap = functools.vmap(discretize, in_axes=(0, None))
41+
transitions, (p, p_inv) = discretize_vmap(np.diff(ts), scales)
3942

4043
preconditioner_apply_vmap = functools.vmap(ssm.conditional.preconditioner_apply)
4144
conditionals = preconditioner_apply_vmap(transitions, p, p_inv)
@@ -798,7 +801,7 @@ def _calibration_running_mean(*, ssm) -> _Calibration:
798801
# In this case, the _calibration_most_recent() stuff becomes void.
799802

800803
def init():
801-
prior = ssm.prototypes.output_scale()
804+
prior = np.ones_like(ssm.prototypes.output_scale())
802805
return prior, prior, 0.0
803806

804807
def update(state, /, observed):
@@ -820,7 +823,7 @@ def solver_dynamic(strategy, *, correction, prior, ssm):
820823

821824
def step_dynamic(state, /, *, dt, calibration):
822825
# Estimate error and calibrate the output scale
823-
ones = ssm.prototypes.output_scale()
826+
ones = np.ones_like(ssm.prototypes.output_scale())
824827
transition = prior(dt, ones)
825828
hidden = strategy.extrapolate_mean(state.rv, transition=transition)
826829
t = state.t + dt
@@ -855,7 +858,7 @@ def step_dynamic(state, /, *, dt, calibration):
855858

856859
def _calibration_most_recent(*, ssm) -> _Calibration:
857860
def init():
858-
return ssm.prototypes.output_scale()
861+
return np.ones_like(ssm.prototypes.output_scale())
859862

860863
def update(_state, /, observed):
861864
return ssm.stats.mahalanobis_norm_relative(0.0, observed)
@@ -906,7 +909,7 @@ def step(state: _State, *, dt, calibration):
906909

907910
def _calibration_none(*, ssm) -> _Calibration:
908911
def init():
909-
return ssm.prototypes.output_scale()
912+
return np.ones_like(ssm.prototypes.output_scale())
910913

911914
def update(_state, /, observed):
912915
raise NotImplementedError

probdiffeq/taylor.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,11 @@ def starter(vf, initial_values, /, num: int, t):
4242
init = (rv_t0, conditional_t0)
4343

4444
# Discretised prior
45-
discretise = ssm.conditional.ibm_transitions(output_scale=1.0)
46-
ibm_transitions = functools.vmap(discretise)(np.diff(ts))
45+
scale = ssm.prototypes.output_scale()
46+
discretise = ssm.conditional.ibm_transitions(scale)
47+
ibm_transitions = functools.vmap(discretise, in_axes=(0, None))(
48+
np.diff(ts), scale
49+
)
4750

4851
# Generate an observation-model for the QOI
4952
# (1e-7 observation noise for nuggets and for reusing existing code)

tests/test_impl/test_logpdfs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,6 @@ def random_variable(fact):
3535
tcoeffs = [np.ones((3,))] * 5 # values irrelevant
3636
ssm = impl.choose(fact, tcoeffs_like=tcoeffs)
3737
output_scale = np.ones_like(ssm.prototypes.output_scale())
38-
discretize = ssm.conditional.ibm_transitions(output_scale=output_scale)
39-
rv = discretize(0.1)
38+
discretize = ssm.conditional.ibm_transitions(output_scale)
39+
rv = discretize(0.1, output_scale)
4040
return rv[0].noise, ssm

0 commit comments

Comments
 (0)