
Commit 7171fbd

pnkraemer and Nicholas Kraemer authored
Implement damping (#820)
* Implement damping for TS0
* Implement damping for the other corrections
* Delete unused impl.transform
* TS* linearise around a random variable now
* Merge Correction implementations
* Leave a TODO
* Reorder some functions

---------

Co-authored-by: Nicholas Kraemer <nicholaskraemer@Libbys-MacBook-Air-2.local>
1 parent 17aae8e commit 7171fbd

5 files changed

Lines changed: 158 additions & 243 deletions
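For orientation: the recurring change in the Taylor-series corrections below is that the linearisation no longer returns the exact residual -fx but a Gaussian bias whose covariance Cholesky factor is damp * I, so the ODE constraint is treated as a slightly noisy observation. A minimal, self-contained sketch of that idea in plain NumPy (the Normal tuple is a stand-in for the library's _normal.Normal; the vector field and numbers are made up for illustration):

import numpy as np
from collections import namedtuple

Normal = namedtuple("Normal", ("mean", "cholesky"))  # stand-in for _normal.Normal

def linearise_ts0(fun, u0, damp):
    """Zeroth-order Taylor linearisation of the ODE residual, with damping."""
    fx = fun(u0)
    # Before this commit, the correction used the exact residual -fx.
    # Now the residual is a Gaussian with Cholesky factor damp * I.
    cov_lower = damp * np.eye(len(fx))
    return Normal(mean=-fx, cholesky=cov_lower)

bias = linearise_ts0(lambda u: -0.5 * u, u0=np.ones(3), damp=1e-7)
print(bias.mean, bias.cholesky)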

probdiffeq/impl/_conditional.py

Lines changed: 0 additions & 99 deletions
@@ -14,101 +14,6 @@
 from probdiffeq.util import cholesky_util


-class TransformBackend(abc.ABC):
-    @abc.abstractmethod
-    def marginalise(self, rv, transformation, /):
-        raise NotImplementedError
-
-    @abc.abstractmethod
-    def revert(self, rv, transformation, /):
-        raise NotImplementedError
-
-
-class DenseTransform(TransformBackend):
-    def marginalise(self, rv, transformation, /):
-        A, b = transformation
-        cholesky_new = cholesky_util.triu_via_qr((A @ rv.cholesky).T).T
-        return _normal.Normal(A @ rv.mean + b, cholesky_new)
-
-    def revert(self, rv, transformation, /):
-        A, b = transformation
-        mean, cholesky = rv.mean, rv.cholesky
-
-        # QR-decomposition
-        # (todo: rename revert_conditional_noisefree to
-        # revert_transformation_cov_sqrt())
-        r_obs, (r_cor, gain) = cholesky_util.revert_conditional_noisefree(
-            R_X_F=(A @ cholesky).T, R_X=cholesky.T
-        )
-
-        # Gather terms and return
-        m_cor = mean - gain @ (A @ mean + b)
-        corrected = _normal.Normal(m_cor, r_cor.T)
-        observed = _normal.Normal(A @ mean + b, r_obs.T)
-        return observed, Conditional(gain, corrected)
-
-
-class IsotropicTransform(TransformBackend):
-    def marginalise(self, rv, transformation, /):
-        A, b = transformation
-        mean, cholesky = rv.mean, rv.cholesky
-        cholesky_new = cholesky_util.triu_via_qr((A @ cholesky).T)
-        cholesky_squeezed = np.reshape(cholesky_new, ())
-        return _normal.Normal((A @ mean) + b, cholesky_squeezed)
-
-    def revert(self, rv, transformation, /):
-        A, b = transformation
-        mean, cholesky = rv.mean, rv.cholesky
-
-        # QR-decomposition
-        # (todo: rename revert_conditional_noisefree
-        # to revert_transformation_cov_sqrt())
-        r_obs, (r_cor, gain) = cholesky_util.revert_conditional_noisefree(
-            R_X_F=(A @ cholesky).T, R_X=cholesky.T
-        )
-        cholesky_obs = np.reshape(r_obs, ())
-        cholesky_cor = r_cor.T
-
-        # Gather terms and return
-        mean_observed = A @ mean + b
-        m_cor = mean - gain * mean_observed
-        corrected = _normal.Normal(m_cor, cholesky_cor)
-        observed = _normal.Normal(mean_observed, cholesky_obs)
-        return observed, Conditional(gain, corrected)
-
-
-class BlockDiagTransform(TransformBackend):
-    def __init__(self, ode_shape):
-        self.ode_shape = ode_shape
-
-    def marginalise(self, rv, transformation, /):
-        A, b = transformation
-        mean, cholesky = rv.mean, rv.cholesky
-
-        A_cholesky = A @ cholesky
-        cholesky = functools.vmap(cholesky_util.triu_via_qr)(_transpose(A_cholesky))
-
-        mean = functools.vmap(lambda x, y, z: x @ y + z)(A, mean, b)
-        return _normal.Normal(mean, cholesky)
-
-    def revert(self, rv, transformation, /):
-        A, bias = transformation
-        cholesky_upper = np.transpose(rv.cholesky, axes=(0, -1, -2))
-        A_cholesky_upper = _transpose(A @ rv.cholesky)
-
-        revert_fun = functools.vmap(cholesky_util.revert_conditional_noisefree)
-        r_obs, (r_cor, gain) = revert_fun(A_cholesky_upper, cholesky_upper)
-        cholesky_obs = _transpose(r_obs)
-        cholesky_cor = _transpose(r_cor)
-
-        # Gather terms and return
-        mean_observed = functools.vmap(lambda x, y, z: x @ y + z)(A, rv.mean, bias)
-        m_cor = rv.mean - (gain * (mean_observed[..., None]))[..., 0]
-        corrected = _normal.Normal(m_cor, cholesky_cor)
-        observed = _normal.Normal(mean_observed, cholesky_obs)
-        return observed, Conditional(gain, corrected)
-
-
 class Conditional(containers.NamedTuple):
     """Conditional distributions."""

@@ -550,10 +455,6 @@ def preconditioner_prepare(*, num_derivatives):
     return tree_util.Partial(preconditioner_diagonal, scales=scales, powers=powers)


-def _hilbert(a):
-    return 1 / (a[:, None] + a[None, :] + 1)
-
-
 def _pascal(a, /):
     return _batch_gram(_binom)(a[:, None], a[None, :])
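The classes deleted above marginalised and reverted Gaussians through a noise-free affine map y = A x + b; with the damped corrections this special case is no longer used (the commit message lists "Delete unused impl.transform"). For reference, the marginalisation identity that the removed DenseTransform.marginalise relied on is easy to verify in plain NumPy. This is an illustrative check only, with numpy.linalg.qr standing in for cholesky_util.triu_via_qr:

import numpy as np

rng = np.random.default_rng(0)
m, L = rng.normal(size=4), np.tril(rng.normal(size=(4, 4)))  # x ~ N(m, L @ L.T)
A, b = rng.normal(size=(2, 4)), rng.normal(size=2)           # y = A @ x + b

# Marginalise: the mean is A @ m + b, and the covariance square root comes
# from a QR decomposition of (A @ L).T, as in the deleted marginalise().
_, R = np.linalg.qr((A @ L).T)
mean_y, cholesky_y = A @ m + b, R.T

assert np.allclose(cholesky_y @ cholesky_y.T, (A @ L) @ (A @ L).T)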

probdiffeq/impl/_linearise.py

Lines changed: 53 additions & 25 deletions
@@ -1,24 +1,29 @@
 from probdiffeq.backend import abc, functools
 from probdiffeq.backend import numpy as np
+from probdiffeq.backend.typing import Callable
 from probdiffeq.impl import _normal
 from probdiffeq.util import cholesky_util


 class LinearisationBackend(abc.ABC):
     @abc.abstractmethod
-    def ode_taylor_0th(self, ode_order):
+    def ode_taylor_0th(self, ode_order: int, damp: float) -> _normal.Normal:
         raise NotImplementedError

     @abc.abstractmethod
-    def ode_taylor_1st(self, ode_order):
+    def ode_taylor_1st(self, ode_order: int, damp: float) -> _normal.Normal:
         raise NotImplementedError

     @abc.abstractmethod
-    def ode_statistical_1st(self, cubature_fun):  # ode_order > 1 not supported
+    def ode_statistical_1st(
+        self, cubature_fun: Callable, damp: float
+    ) -> _normal.Normal:
         raise NotImplementedError

     @abc.abstractmethod
-    def ode_statistical_0th(self, cubature_fun):  # ode_order > 1 not supported
+    def ode_statistical_0th(
+        self, cubature_fun: Callable, damp: float
+    ) -> _normal.Normal:
         raise NotImplementedError


@@ -27,8 +32,9 @@ def __init__(self, ode_shape, unravel):
         self.ode_shape = ode_shape
         self.unravel = unravel

-    def ode_taylor_0th(self, ode_order):
-        def linearise_fun_wrapped(fun, mean):
+    def ode_taylor_0th(self, ode_order, damp: float):
+        def linearise_fun_wrapped(fun, rv):
+            mean = rv.mean
             a0 = functools.partial(self._select_dy, idx_or_slice=slice(0, ode_order))
             a1 = functools.partial(self._select_dy, idx_or_slice=ode_order)

@@ -40,12 +46,15 @@ def linearise_fun_wrapped(fun, mean):
             linop = _jac_materialize(
                 lambda v, _p: self._autobatch_linop(a1)(v), inputs=mean
             )
-            return linop, -fx
+            cov_lower = damp * np.eye(len(fx))
+            bias = _normal.Normal(-fx, cov_lower)
+            return linop, bias

         return linearise_fun_wrapped

-    def ode_taylor_1st(self, ode_order):
-        def new(fun, mean, /):
+    def ode_taylor_1st(self, ode_order, damp):
+        def new(fun, rv, /):
+            mean = rv.mean
             a0 = functools.partial(self._select_dy, idx_or_slice=slice(0, ode_order))
             a1 = functools.partial(self._select_dy, idx_or_slice=ode_order)

@@ -62,11 +71,13 @@ def A(x):
                 return x1 - jvp(x0)

             linop = _jac_materialize(lambda v, _p: A(v), inputs=mean)
-            return linop, -fx
+            cov_lower = damp * np.eye(len(fx))
+            bias = _normal.Normal(-fx, cov_lower)
+            return linop, bias

         return new

-    def ode_statistical_1st(self, cubature_fun):
+    def ode_statistical_1st(self, cubature_fun, damp: float):
         cubature_rule = cubature_fun(input_shape=self.ode_shape)
         linearise_fun = functools.partial(self.slr1, cubature_rule=cubature_rule)

@@ -91,14 +102,18 @@ def A(x):
                 return a1(x) - J @ a0(x)

             linop = _jac_materialize(lambda v, _p: A(v), inputs=rv.mean)
-
             mean, cov_lower = noise.mean, noise.cholesky
+
+            # Include the damping term. (TODO: use a single qr?)
+            damping = damp * np.eye(len(cov_lower))
+            stack = np.concatenate((cov_lower.T, damping.T))
+            cov_lower = cholesky_util.triu_via_qr(stack).T
             bias = _normal.Normal(-mean, cov_lower)
             return linop, bias

         return new

-    def ode_statistical_0th(self, cubature_fun):
+    def ode_statistical_0th(self, cubature_fun, damp: float):
         cubature_rule = cubature_fun(input_shape=self.ode_shape)
         linearise_fun = functools.partial(self.slr0, cubature_rule=cubature_rule)

@@ -119,6 +134,12 @@ def new(fun, rv, /):
             # Gather the variables and return
             noise = linearise_fun(fun, linearisation_pt)
             mean, cov_lower = noise.mean, noise.cholesky
+
+            # Include the damping term. (TODO: use a single qr?)
+            damping = damp * np.eye(len(cov_lower))
+            stack = np.concatenate((cov_lower.T, damping.T))
+            cov_lower = cholesky_util.triu_via_qr(stack).T
+
             bias = _normal.Normal(-mean, cov_lower)
             linop = _jac_materialize(lambda v, _p: a1(v), inputs=rv.mean)
             return linop, bias
@@ -195,23 +216,26 @@ def slr0(fn, x, *, cubature_rule):


 class IsotropicLinearisation(LinearisationBackend):
-    def ode_taylor_1st(self, ode_order):
+    def ode_taylor_1st(self, ode_order, damp: float):
         raise NotImplementedError

-    def ode_taylor_0th(self, ode_order):
-        def linearise_fun_wrapped(fun, mean):
+    def ode_taylor_0th(self, ode_order, damp: float):
+        def linearise_fun_wrapped(fun, rv):
+            mean = rv.mean
             fx = self.ts0(fun, mean[:ode_order, ...])
             linop = _jac_materialize(
                 lambda s, _p: s[[ode_order], ...], inputs=mean[:, 0]
             )
-            return linop, -fx
+            cov_lower = damp * np.eye(1)
+            bias = _normal.Normal(-fx, cov_lower)
+            return linop, bias

         return linearise_fun_wrapped

-    def ode_statistical_0th(self, cubature_fun):
+    def ode_statistical_0th(self, cubature_fun, damp: float):
         raise NotImplementedError

-    def ode_statistical_1st(self, cubature_fun):
+    def ode_statistical_1st(self, cubature_fun, damp: float):
         raise NotImplementedError

     @staticmethod
@@ -220,8 +244,9 @@ def ts0(fn, m):


 class BlockDiagLinearisation(LinearisationBackend):
-    def ode_taylor_0th(self, ode_order):
-        def linearise_fun_wrapped(fun, mean):
+    def ode_taylor_0th(self, ode_order, damp: float):
+        def linearise_fun_wrapped(fun, rv):
+            mean = rv.mean
             m0 = mean[:, :ode_order]
             fx = self.ts0(fun, m0.T)

@@ -233,17 +258,20 @@ def lo(s):
                 return _jac_materialize(lambda v, _p: a1(v), inputs=s)

             linop = lo(mean)
-            return linop, -fx[:, None]
+            d, *_ = linop.shape
+            cov_lower = damp * np.ones((d, 1, 1))
+            bias = _normal.Normal(-fx[:, None], cov_lower)
+            return linop, bias

         return linearise_fun_wrapped

-    def ode_taylor_1st(self, ode_order):
+    def ode_taylor_1st(self, ode_order, damp: float):
         raise NotImplementedError

-    def ode_statistical_0th(self, cubature_fun):
+    def ode_statistical_0th(self, cubature_fun, damp: float):
         raise NotImplementedError

-    def ode_statistical_1st(self, cubature_fun):
+    def ode_statistical_1st(self, cubature_fun, damp: float):
         raise NotImplementedError

     @staticmethod
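In the statistical (SLR) corrections above, the cubature linearisation already yields a non-trivial noise Cholesky factor, so the damping is folded in by stacking that factor with damp * I and re-triangularising with a single QR decomposition; the result is a square root of cov + damp**2 * I. A quick NumPy check of that stacking trick (illustrative only; numpy.linalg.qr stands in for cholesky_util.triu_via_qr):

import numpy as np

rng = np.random.default_rng(1)
cov_lower = np.tril(rng.normal(size=(3, 3)))  # noise Cholesky factor from the SLR step
damp = 0.1
damping = damp * np.eye(3)

# Stack both square roots and re-triangularise, mirroring the new code path.
stack = np.concatenate((cov_lower.T, damping.T))
_, R = np.linalg.qr(stack)
cov_lower_damped = R.T

expected = cov_lower @ cov_lower.T + damp**2 * np.eye(3)
assert np.allclose(cov_lower_damped @ cov_lower_damped.T, expected)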

probdiffeq/impl/impl.py

Lines changed: 0 additions & 7 deletions
@@ -15,7 +15,6 @@ class FactImpl:
     stats: _stats.StatsBackend
     linearise: _linearise.LinearisationBackend
     conditional: _conditional.ConditionalBackend
-    transform: _conditional.TransformBackend

     # To assert a valid tree_equal of solutions, the factorisations
     # must be comparable.
@@ -61,11 +60,9 @@ def _select_dense(*, tcoeffs_like) -> FactImpl:
         unravel=unravel,
         flat_shape=flat.shape,
     )
-    transform = _conditional.DenseTransform()
     return FactImpl(
         name="dense",
         linearise=linearise,
-        transform=transform,
         conditional=conditional,
         normal=normal,
         prototypes=prototypes,
@@ -88,15 +85,13 @@ def _select_isotropic(*, tcoeffs_like) -> FactImpl:
     conditional = _conditional.IsotropicConditional(
         ode_shape=ode_shape, num_derivatives=num_derivatives, unravel_tree=unravel_tree
     )
-    transform = _conditional.IsotropicTransform()
     return FactImpl(
         name="isotropic",
         prototypes=prototypes,
         normal=normal,
         stats=stats,
         linearise=linearise,
         conditional=conditional,
-        transform=transform,
     )


@@ -115,13 +110,11 @@ def _select_blockdiag(*, tcoeffs_like) -> FactImpl:
     conditional = _conditional.BlockDiagConditional(
         ode_shape=ode_shape, num_derivatives=num_derivatives, unravel_tree=unravel_tree
     )
-    transform = _conditional.BlockDiagTransform(ode_shape=ode_shape)
     return FactImpl(
         name="blockdiag",
         prototypes=prototypes,
         normal=normal,
         stats=stats,
         linearise=linearise,
         conditional=conditional,
-        transform=transform,
     )
