Skip to content

Commit 4c7a0d1

Browse files
committed
Implement a block-diagonal TS1
1 parent c827bcb commit 4c7a0d1

3 files changed

Lines changed: 62 additions & 15 deletions

File tree

docs/examples_basic/posterior_uncertainties.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def vf(y, *, t): # noqa: ARG001
4242
# Set up a solver
4343
# To all users: Try replacing the fixedpoint-smoother with a filter!
4444
tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=3)
45-
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact="isotropic")
45+
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact="blockdiag")
4646
ts = ivpsolvers.correction_ts1(vf, ssm=ssm)
4747
strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm)
4848
solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts, ssm=ssm)

probdiffeq/impl/_linearise.py

Lines changed: 56 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,10 @@ def a1(m):
6565
return _Linearization(init, step)
6666

6767
def ode_taylor_1st(
68-
self, ode_order, damp, jvp_samples: int, jvp_samples_seed: int
68+
self, ode_order, damp, jvp_probes: int, jvp_probes_seed: int
6969
) -> _Linearization:
70-
del jvp_samples
71-
del jvp_samples_seed
70+
del jvp_probes
71+
del jvp_probes_seed
7272

7373
def init():
7474
return None
@@ -253,13 +253,13 @@ def __init__(self, unravel):
253253
self.unravel = unravel
254254

255255
def ode_taylor_1st(
256-
self, ode_order, damp: float, jvp_samples: int, jvp_samples_seed: int
256+
self, ode_order, damp: float, jvp_probes: int, jvp_probes_seed: int
257257
):
258258
if ode_order > 1:
259259
raise ValueError
260260

261261
def init():
262-
return random.prng_key(seed=jvp_samples_seed)
262+
return random.prng_key(seed=jvp_probes_seed)
263263

264264
def step(fun, rv, key):
265265
mean = rv.mean
@@ -282,7 +282,7 @@ def select_0(s):
282282
# Estimate the trace using Hutchinson's estimator
283283
# J_trace, jacobian_state = jacobian(Jvp, m0, jacobian_state)
284284
key, subkey = random.split(key, num=2)
285-
sample_shape = (jvp_samples, *m0.shape)
285+
sample_shape = (jvp_probes, *m0.shape)
286286
v = random.rademacher(subkey, shape=sample_shape, dtype=m0.dtype)
287287
J_trace = functools.vmap(lambda s: linalg.vector_dot(s, Jvp(s)))(v)
288288
J_trace = J_trace.mean(axis=0)
@@ -368,9 +368,57 @@ def a1(s):
368368
return _Linearization(init, step)
369369

370370
def ode_taylor_1st(
371-
self, ode_order, damp: float, jvp_samples: int, jvp_samples_seed: int
371+
self, ode_order, damp: float, jvp_probes: int, jvp_probes_seed: int
372372
):
373-
raise NotImplementedError
373+
if ode_order > 1:
374+
raise ValueError
375+
376+
def init():
377+
return random.prng_key(seed=jvp_probes_seed)
378+
379+
def step(fun, rv, key):
380+
mean = rv.mean
381+
382+
def a1(s):
383+
return s[[ode_order], ...]
384+
385+
linop = functools.vmap(functools.jacrev(a1))(mean)
386+
387+
def vf_flat(u):
388+
return tree_util.ravel_pytree(fun(unravel(u)))[0]
389+
390+
def select_0(s):
391+
return tree_util.ravel_pytree(self.unravel(s)[0])
392+
393+
# Evaluate the linearisation
394+
m0, unravel = select_0(rv.mean)
395+
fx, Jvp = functools.linearize(vf_flat, m0)
396+
397+
key, subkey = random.split(key, num=2)
398+
sample_shape = (jvp_probes, *m0.shape)
399+
v = random.rademacher(subkey, shape=sample_shape, dtype=m0.dtype)
400+
J_diag = functools.vmap(lambda s: s * Jvp(s))(v)
401+
J_diag = J_diag.mean(axis=0)
402+
E1 = functools.jacrev(lambda s: s[0])(rv.mean[0])
403+
linop = linop - J_diag[:, None, None] * E1[None, None, :]
404+
405+
fx = rv.mean[:, 1] - fx
406+
fx = fx[..., None]
407+
diff = functools.vmap(lambda a, b: a @ b)(linop, rv.mean)
408+
fx = fx - diff
409+
410+
d, *_ = linop.shape
411+
cov_lower = damp * np.ones((d, 1, 1))
412+
bias = _normal.Normal(fx, cov_lower)
413+
414+
to_latent = np.ones((linop.shape[2],))
415+
to_observed = np.ones((linop.shape[1],))
416+
cond = _conditional.LatentCond(
417+
linop, bias, to_latent=to_latent, to_observed=to_observed
418+
)
419+
return cond, key
420+
421+
return _Linearization(init, step)
374422

375423
def ode_statistical_0th(self, cubature_fun, damp: float):
376424
raise NotImplementedError

probdiffeq/ivpsolvers.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,6 @@ def init(self, x, /):
458458
def correct(self, rv, correction_state, /, t):
459459
"""Perform the correction step."""
460460
f_wrapped = functools.partial(self.vector_field, t=t)
461-
462461
cond, correction_state = self.linearize.update(f_wrapped, rv, correction_state)
463462
observed, reverted = self.ssm.conditional.revert(rv, cond)
464463
corrected = reverted.noise
@@ -496,16 +495,16 @@ def correction_ts1(
496495
ssm,
497496
ode_order=1,
498497
damp: float = 0.0,
499-
jvp_samples=10,
500-
jvp_samples_seed=1,
498+
jvp_probes=10,
499+
jvp_probes_seed=1,
501500
) -> _Correction:
502501
"""First-order Taylor linearisation."""
503-
assert jvp_samples > 0
502+
assert jvp_probes > 0
504503
linearize = ssm.linearise.ode_taylor_1st(
505504
ode_order=ode_order,
506505
damp=damp,
507-
jvp_samples=jvp_samples,
508-
jvp_samples_seed=jvp_samples_seed,
506+
jvp_probes=jvp_probes,
507+
jvp_probes_seed=jvp_probes_seed,
509508
)
510509
return _Correction(
511510
name="TS1",

0 commit comments

Comments
 (0)