Skip to content

Commit 4f06883

Browse files
authored
Linearisations return conditionals (#830)
* Linearisations return conditionals
* Fix the tests
1 parent d83f097 commit 4f06883

5 files changed

Lines changed: 92 additions & 68 deletions

File tree

probdiffeq/backend/tree_util.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,7 @@ def tree_all(tree, /):
2626

2727
def ravel_pytree(tree, /):
2828
return jax.flatten_util.ravel_pytree(tree)
29+
30+
31+
def register_dataclass(datacls):
32+
return jax.tree_util.register_dataclass(datacls)

probdiffeq/impl/_conditional.py

Lines changed: 66 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,18 @@
99
tree_util,
1010
)
1111
from probdiffeq.backend import numpy as np
12-
from probdiffeq.backend.typing import Any, Array
13-
from probdiffeq.impl import _normal
12+
from probdiffeq.backend.typing import Array
13+
from probdiffeq.impl import _normal, _stats
1414
from probdiffeq.util import cholesky_util
1515

1616

17-
class Conditional(containers.NamedTuple):
17+
@tree_util.register_dataclass
18+
@containers.dataclass
19+
class Conditional:
1820
"""Conditional distributions."""
1921

20-
matmul: Array
21-
noise: Any # Usually a random-variable type
22+
A: Array
23+
noise: _normal.Normal
2224

2325

2426
class ConditionalBackend(abc.ABC):
@@ -57,6 +59,10 @@ def preconditioner_apply(self, cond, p, p_inv, /):
5759
def to_derivative(self, i, standard_deviation):
5860
raise NotImplementedError
5961

62+
@abc.abstractmethod
63+
def rescale_noise(self, cond, scale):
64+
raise NotImplementedError
65+
6066

6167
class DenseConditional(ConditionalBackend):
6268
def __init__(self, ode_shape, num_derivatives, unravel, flat_shape):
@@ -65,19 +71,17 @@ def __init__(self, ode_shape, num_derivatives, unravel, flat_shape):
6571
self.unravel = unravel
6672
self.flat_shape = flat_shape
6773

68-
def apply(self, x, conditional, /):
69-
matrix, noise = conditional
70-
return _normal.Normal(matrix @ x + noise.mean, noise.cholesky)
74+
def apply(self, x, cond, /):
75+
return _normal.Normal(cond.A @ x + cond.noise.mean, cond.noise.cholesky)
7176

72-
def marginalise(self, rv, conditional, /):
73-
matmul, noise = conditional
74-
R_stack = ((matmul @ rv.cholesky).T, noise.cholesky.T)
77+
def marginalise(self, rv, cond, /):
78+
R_stack = ((cond.A @ rv.cholesky).T, cond.noise.cholesky.T)
7579
cholesky_new = cholesky_util.sum_of_sqrtm_factors(R_stack=R_stack).T
76-
return _normal.Normal(matmul @ rv.mean + noise.mean, cholesky_new)
80+
return _normal.Normal(cond.A @ rv.mean + cond.noise.mean, cholesky_new)
7781

7882
def merge(self, cond1, cond2, /):
79-
A, b = cond1
80-
C, d = cond2
83+
A, b = cond1.A, cond1.noise
84+
C, d = cond2.A, cond2.noise
8185

8286
g = A @ C
8387
xi = A @ d.mean + b.mean
@@ -86,8 +90,8 @@ def merge(self, cond1, cond2, /):
8690
)
8791
return Conditional(g, _normal.Normal(xi, Xi.T))
8892

89-
def revert(self, rv, conditional, /):
90-
matrix, noise = conditional
93+
def revert(self, rv, cond, /):
94+
matrix, noise = cond.A, cond.noise
9195
mean, cholesky = rv.mean, rv.cholesky
9296

9397
# QR-decomposition
@@ -133,10 +137,9 @@ def discretise(dt):
133137
return discretise
134138

135139
def preconditioner_apply(self, cond, p, p_inv, /):
136-
A, noise = cond
137140
normal = _normal.DenseNormal(ode_shape=self.ode_shape)
138-
noise = normal.preconditioner_apply(noise, p)
139-
A = p[:, None] * A * p_inv[None, :]
141+
noise = normal.preconditioner_apply(cond.noise, p)
142+
A = p[:, None] * cond.A * p_inv[None, :]
140143
return Conditional(A, noise)
141144

142145
def to_derivative(self, i, standard_deviation):
@@ -153,23 +156,30 @@ def select(a):
153156
noise = _normal.Normal(bias, standard_deviation * eye)
154157
return Conditional(linop, noise)
155158

159+
def rescale_noise(self, cond, scale):
160+
A = cond.A
161+
noise = cond.noise
162+
stats = _stats.DenseStats(ode_shape=self.ode_shape, unravel=self.unravel)
163+
noise_new = stats.rescale_cholesky(noise, scale)
164+
return Conditional(A, noise_new)
165+
156166

157167
class IsotropicConditional(ConditionalBackend):
158168
def __init__(self, *, ode_shape, num_derivatives, unravel_tree):
159169
self.ode_shape = ode_shape
160170
self.num_derivatives = num_derivatives
161171
self.unravel_tree = unravel_tree
162172

163-
def apply(self, x, conditional, /):
164-
A, noise = conditional
173+
def apply(self, x, cond, /):
174+
A, noise = cond.A, cond.noise
165175
# if the gain is qoi-to-hidden, the data is a (d,) array.
166176
# this is problematic for the isotropic model unless we explicitly broadcast.
167177
if np.ndim(x) == 1:
168178
x = x[None, :]
169179
return _normal.Normal(A @ x + noise.mean, noise.cholesky)
170180

171-
def marginalise(self, rv, conditional, /):
172-
matrix, noise = conditional
181+
def marginalise(self, rv, cond, /):
182+
matrix, noise = cond.A, cond.noise
173183

174184
mean = matrix @ rv.mean + noise.mean
175185

@@ -178,8 +188,8 @@ def marginalise(self, rv, conditional, /):
178188
return _normal.Normal(mean, cholesky)
179189

180190
def merge(self, cond1, cond2, /):
181-
A, b = cond1
182-
C, d = cond2
191+
A, b = cond1.A, cond1.noise
192+
C, d = cond2.A, cond2.noise
183193

184194
g = A @ C
185195
xi = A @ d.mean + b.mean
@@ -189,8 +199,8 @@ def merge(self, cond1, cond2, /):
189199
noise = _normal.Normal(xi, Xi)
190200
return Conditional(g, noise)
191201

192-
def revert(self, rv, conditional, /):
193-
matrix, noise = conditional
202+
def revert(self, rv, cond, /):
203+
matrix, noise = cond.A, cond.noise
194204

195205
r_ext_p, (r_bw_p, gain) = cholesky_util.revert_conditional(
196206
R_X_F=(matrix @ rv.cholesky).T, R_X=rv.cholesky.T, R_YX=noise.cholesky.T
@@ -225,7 +235,7 @@ def discretise(dt):
225235
return discretise
226236

227237
def preconditioner_apply(self, cond, p, p_inv, /):
228-
A, noise = cond
238+
A, noise = cond.A, cond.noise
229239

230240
A_new = p[:, None] * A * p_inv[None, :]
231241

@@ -245,25 +255,34 @@ def select(a):
245255

246256
return Conditional(linop, noise)
247257

258+
def rescale_noise(self, cond, scale):
259+
A = cond.A
260+
noise = cond.noise
261+
stats = _stats.IsotropicStats(
262+
ode_shape=self.ode_shape, unravel=self.unravel_tree
263+
)
264+
noise_new = stats.rescale_cholesky(noise, scale)
265+
return Conditional(A, noise_new)
266+
248267

249268
class BlockDiagConditional(ConditionalBackend):
250269
def __init__(self, *, ode_shape, num_derivatives, unravel_tree):
251270
self.ode_shape = ode_shape
252271
self.num_derivatives = num_derivatives
253272
self.unravel_tree = unravel_tree
254273

255-
def apply(self, x, conditional, /):
274+
def apply(self, x, cond, /):
256275
if np.ndim(x) == 1:
257276
x = x[..., None]
258277

259278
def apply_unbatch(m, s, n):
260279
return _normal.Normal(m @ s + n.mean, n.cholesky)
261280

262-
matrix, noise = conditional
281+
matrix, noise = cond.A, cond.noise
263282
return functools.vmap(apply_unbatch)(matrix, x, noise)
264283

265-
def marginalise(self, rv, conditional, /):
266-
matrix, noise = conditional
284+
def marginalise(self, rv, cond, /):
285+
matrix, noise = cond.A, cond.noise
267286
assert matrix.ndim == 3
268287

269288
mean = np.einsum("ijk,ik->ij", matrix, rv.mean) + noise.mean
@@ -275,8 +294,8 @@ def marginalise(self, rv, conditional, /):
275294
return _normal.Normal(mean, _transpose(cholesky))
276295

277296
def merge(self, cond1, cond2, /):
278-
A, b = cond1
279-
C, d = cond2
297+
A, b = cond1.A, cond1.noise
298+
C, d = cond2.A, cond2.noise
280299

281300
g = A @ C
282301
xi = (A @ d.mean[..., None])[..., 0] + b.mean
@@ -286,8 +305,8 @@ def merge(self, cond1, cond2, /):
286305
noise = _normal.Normal(xi, Xi)
287306
return Conditional(g, noise)
288307

289-
def revert(self, rv, conditional, /):
290-
A, noise = conditional
308+
def revert(self, rv, cond, /):
309+
A, noise = cond.A, cond.noise
291310
rv_chol_upper = np.transpose(rv.cholesky, axes=(0, 2, 1))
292311
noise_chol_upper = np.transpose(noise.cholesky, axes=(0, 2, 1))
293312
A_rv_chol_upper = np.transpose(A @ rv.cholesky, axes=(0, 2, 1))
@@ -329,11 +348,10 @@ def discretise(dt):
329348
return discretise
330349

331350
def preconditioner_apply(self, cond, p, p_inv, /):
332-
A, noise = cond
333-
A_new = p[None, :, None] * A * p_inv[None, None, :]
351+
A_new = p[None, :, None] * cond.A * p_inv[None, None, :]
334352

335353
normal = _normal.BlockDiagNormal(ode_shape=self.ode_shape)
336-
noise = normal.preconditioner_apply(noise, p)
354+
noise = normal.preconditioner_apply(cond.noise, p)
337355
return Conditional(A_new, noise)
338356

339357
def to_derivative(self, i, standard_deviation):
@@ -349,6 +367,15 @@ def select(a):
349367

350368
return Conditional(linop, noise)
351369

370+
def rescale_noise(self, cond, scale):
371+
A = cond.A
372+
noise = cond.noise
373+
stats = _stats.BlockDiagStats(
374+
ode_shape=self.ode_shape, unravel=self.unravel_tree
375+
)
376+
noise_new = stats.rescale_cholesky(noise, scale)
377+
return Conditional(A, noise_new)
378+
352379

353380
def _transpose(matrix):
354381
return np.transpose(matrix, axes=(0, 2, 1))

probdiffeq/impl/_linearise.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from probdiffeq.backend import abc, functools
22
from probdiffeq.backend import numpy as np
33
from probdiffeq.backend.typing import Callable
4-
from probdiffeq.impl import _normal
4+
from probdiffeq.impl import _conditional, _normal
55
from probdiffeq.util import cholesky_util
66

77

@@ -48,7 +48,7 @@ def linearise_fun_wrapped(fun, rv):
4848
)
4949
cov_lower = damp * np.eye(len(fx))
5050
bias = _normal.Normal(-fx, cov_lower)
51-
return linop, bias
51+
return _conditional.Conditional(linop, bias)
5252

5353
return linearise_fun_wrapped
5454

@@ -73,7 +73,7 @@ def A(x):
7373
linop = _jac_materialize(lambda v, _p: A(v), inputs=mean)
7474
cov_lower = damp * np.eye(len(fx))
7575
bias = _normal.Normal(-fx, cov_lower)
76-
return linop, bias
76+
return _conditional.Conditional(linop, bias)
7777

7878
return new
7979

@@ -109,7 +109,7 @@ def A(x):
109109
stack = np.concatenate((cov_lower.T, damping.T))
110110
cov_lower = cholesky_util.triu_via_qr(stack).T
111111
bias = _normal.Normal(-mean, cov_lower)
112-
return linop, bias
112+
return _conditional.Conditional(linop, bias)
113113

114114
return new
115115

@@ -142,7 +142,7 @@ def new(fun, rv, /):
142142

143143
bias = _normal.Normal(-mean, cov_lower)
144144
linop = _jac_materialize(lambda v, _p: a1(v), inputs=rv.mean)
145-
return linop, bias
145+
return _conditional.Conditional(linop, bias)
146146

147147
return new
148148

@@ -228,7 +228,7 @@ def linearise_fun_wrapped(fun, rv):
228228
)
229229
cov_lower = damp * np.eye(1)
230230
bias = _normal.Normal(-fx, cov_lower)
231-
return linop, bias
231+
return _conditional.Conditional(linop, bias)
232232

233233
return linearise_fun_wrapped
234234

@@ -261,7 +261,7 @@ def lo(s):
261261
d, *_ = linop.shape
262262
cov_lower = damp * np.ones((d, 1, 1))
263263
bias = _normal.Normal(-fx[:, None], cov_lower)
264-
return linop, bias
264+
return _conditional.Conditional(linop, bias)
265265

266266
return linearise_fun_wrapped
267267

probdiffeq/ivpsolvers.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -234,9 +234,8 @@ def complete(self, _ssv, extra, /, output_scale):
234234
cond, (p, p_inv), rv_p = extra
235235

236236
# Extrapolate the Cholesky factor (re-extrapolate the mean for simplicity)
237-
A, noise = cond
238-
noise = self.ssm.stats.rescale_cholesky(noise, output_scale)
239-
extrapolated_p, cond_p = self.ssm.conditional.revert(rv_p, (A, noise))
237+
cond = self.ssm.conditional.rescale_noise(cond, output_scale)
238+
extrapolated_p, cond_p = self.ssm.conditional.revert(rv_p, cond)
240239
extrapolated = self.ssm.normal.preconditioner_apply(extrapolated_p, p)
241240
cond = self.ssm.conditional.preconditioner_apply(cond_p, p, p_inv)
242241

@@ -331,9 +330,8 @@ def complete(self, _ssv, extra, /, output_scale):
331330
cond, (p, p_inv), rv_p = extra
332331

333332
# Extrapolate the Cholesky factor (re-extrapolate the mean for simplicity)
334-
A, noise = cond
335-
noise = self.ssm.stats.rescale_cholesky(noise, output_scale)
336-
extrapolated_p = self.ssm.conditional.marginalise(rv_p, (A, noise))
333+
cond = self.ssm.conditional.rescale_noise(cond, output_scale)
334+
extrapolated_p = self.ssm.conditional.marginalise(rv_p, cond)
337335
extrapolated = self.ssm.normal.preconditioner_apply(extrapolated_p, p)
338336

339337
# Gather and return
@@ -397,9 +395,8 @@ def complete(self, _rv, extra, /, output_scale):
397395
cond, (p, p_inv), rv_p, bw0 = extra
398396

399397
# Extrapolate the Cholesky factor (re-extrapolate the mean for simplicity)
400-
A, noise = cond
401-
noise = self.ssm.stats.rescale_cholesky(noise, output_scale)
402-
extrapolated_p, cond_p = self.ssm.conditional.revert(rv_p, (A, noise))
398+
cond = self.ssm.conditional.rescale_noise(cond, output_scale)
399+
extrapolated_p, cond_p = self.ssm.conditional.revert(rv_p, cond)
403400
extrapolated = self.ssm.normal.preconditioner_apply(extrapolated_p, p)
404401
cond = self.ssm.conditional.preconditioner_apply(cond_p, p, p_inv)
405402

@@ -513,8 +510,8 @@ def init(self, x, /):
513510
def estimate_error(self, rv, /, t):
514511
"""Perform all elements of the correction until the error estimate."""
515512
f_wrapped = self._parametrize_vector_field(t=t)
516-
A, b = self.linearize(f_wrapped, rv)
517-
observed = self.ssm.conditional.marginalise(rv, (A, b))
513+
cond = self.linearize(f_wrapped, rv)
514+
observed = self.ssm.conditional.marginalise(rv, cond)
518515

519516
# TODO: the functions involved in error estimation are still a bit patchy.
520517
# for instance, they assume that they are called
@@ -525,7 +522,7 @@ def estimate_error(self, rv, /, t):
525522
stdev = self.ssm.stats.standard_deviation(observed)
526523
error_estimate_unscaled = np.squeeze(stdev)
527524
error_estimate = output_scale * error_estimate_unscaled
528-
return error_estimate, observed, (A, b, f_wrapped)
525+
return error_estimate, observed, (cond, f_wrapped)
529526

530527
def _parametrize_vector_field(self, *, t):
531528
if self.can_handle_higher_order:
@@ -539,10 +536,11 @@ def f_wrapped(s):
539536

540537
def complete(self, rv, cache, /):
541538
"""Complete what has been left out by `estimate_error`."""
542-
A, b, f_wrapped = cache
539+
cond, f_wrapped = cache
543540
if self.use_re_linearize:
544-
A, b = self.linearize(f_wrapped, rv)
545-
observed, (_gain, corrected) = self.ssm.conditional.revert(rv, (A, b))
541+
cond = self.linearize(f_wrapped, rv)
542+
observed, reverted = self.ssm.conditional.revert(rv, cond)
543+
corrected = reverted.noise
546544
return corrected, observed
547545

548546

0 commit comments

Comments (0)