11from probdiffeq .backend import abc , functools
22from probdiffeq .backend import numpy as np
3+ from probdiffeq .backend .typing import Callable
34from probdiffeq .impl import _normal
45from probdiffeq .util import cholesky_util
56
67
78class LinearisationBackend (abc .ABC ):
89 @abc .abstractmethod
9- def ode_taylor_0th (self , ode_order ) :
10+ def ode_taylor_0th (self , ode_order : int , damp : float ) -> _normal . Normal :
1011 raise NotImplementedError
1112
1213 @abc .abstractmethod
13- def ode_taylor_1st (self , ode_order ) :
14+ def ode_taylor_1st (self , ode_order : int , damp : float ) -> _normal . Normal :
1415 raise NotImplementedError
1516
1617 @abc .abstractmethod
17- def ode_statistical_1st (self , cubature_fun ): # ode_order > 1 not supported
18+ def ode_statistical_1st (
19+ self , cubature_fun : Callable , damp : float
20+ ) -> _normal .Normal :
1821 raise NotImplementedError
1922
2023 @abc .abstractmethod
21- def ode_statistical_0th (self , cubature_fun ): # ode_order > 1 not supported
24+ def ode_statistical_0th (
25+ self , cubature_fun : Callable , damp : float
26+ ) -> _normal .Normal :
2227 raise NotImplementedError
2328
2429
@@ -27,8 +32,9 @@ def __init__(self, ode_shape, unravel):
2732 self .ode_shape = ode_shape
2833 self .unravel = unravel
2934
30- def ode_taylor_0th (self , ode_order ):
31- def linearise_fun_wrapped (fun , mean ):
35+ def ode_taylor_0th (self , ode_order , damp : float ):
36+ def linearise_fun_wrapped (fun , rv ):
37+ mean = rv .mean
3238 a0 = functools .partial (self ._select_dy , idx_or_slice = slice (0 , ode_order ))
3339 a1 = functools .partial (self ._select_dy , idx_or_slice = ode_order )
3440
@@ -40,12 +46,15 @@ def linearise_fun_wrapped(fun, mean):
4046 linop = _jac_materialize (
4147 lambda v , _p : self ._autobatch_linop (a1 )(v ), inputs = mean
4248 )
43- return linop , - fx
49+ cov_lower = damp * np .eye (len (fx ))
50+ bias = _normal .Normal (- fx , cov_lower )
51+ return linop , bias
4452
4553 return linearise_fun_wrapped
4654
47- def ode_taylor_1st (self , ode_order ):
48- def new (fun , mean , / ):
55+ def ode_taylor_1st (self , ode_order , damp ):
56+ def new (fun , rv , / ):
57+ mean = rv .mean
4958 a0 = functools .partial (self ._select_dy , idx_or_slice = slice (0 , ode_order ))
5059 a1 = functools .partial (self ._select_dy , idx_or_slice = ode_order )
5160
@@ -62,11 +71,13 @@ def A(x):
6271 return x1 - jvp (x0 )
6372
6473 linop = _jac_materialize (lambda v , _p : A (v ), inputs = mean )
65- return linop , - fx
74+ cov_lower = damp * np .eye (len (fx ))
75+ bias = _normal .Normal (- fx , cov_lower )
76+ return linop , bias
6677
6778 return new
6879
69- def ode_statistical_1st (self , cubature_fun ):
80+ def ode_statistical_1st (self , cubature_fun , damp : float ):
7081 cubature_rule = cubature_fun (input_shape = self .ode_shape )
7182 linearise_fun = functools .partial (self .slr1 , cubature_rule = cubature_rule )
7283
@@ -91,14 +102,18 @@ def A(x):
91102 return a1 (x ) - J @ a0 (x )
92103
93104 linop = _jac_materialize (lambda v , _p : A (v ), inputs = rv .mean )
94-
95105 mean , cov_lower = noise .mean , noise .cholesky
106+
107+ # Include the damping term. (TODO: use a single qr?)
108+ damping = damp * np .eye (len (cov_lower ))
109+ stack = np .concatenate ((cov_lower .T , damping .T ))
110+ cov_lower = cholesky_util .triu_via_qr (stack ).T
96111 bias = _normal .Normal (- mean , cov_lower )
97112 return linop , bias
98113
99114 return new
100115
101- def ode_statistical_0th (self , cubature_fun ):
116+ def ode_statistical_0th (self , cubature_fun , damp : float ):
102117 cubature_rule = cubature_fun (input_shape = self .ode_shape )
103118 linearise_fun = functools .partial (self .slr0 , cubature_rule = cubature_rule )
104119
@@ -119,6 +134,12 @@ def new(fun, rv, /):
119134 # Gather the variables and return
120135 noise = linearise_fun (fun , linearisation_pt )
121136 mean , cov_lower = noise .mean , noise .cholesky
137+
138+ # Include the damping term. (TODO: use a single qr?)
139+ damping = damp * np .eye (len (cov_lower ))
140+ stack = np .concatenate ((cov_lower .T , damping .T ))
141+ cov_lower = cholesky_util .triu_via_qr (stack ).T
142+
122143 bias = _normal .Normal (- mean , cov_lower )
123144 linop = _jac_materialize (lambda v , _p : a1 (v ), inputs = rv .mean )
124145 return linop , bias
@@ -195,23 +216,26 @@ def slr0(fn, x, *, cubature_rule):
195216
196217
197218class IsotropicLinearisation (LinearisationBackend ):
198- def ode_taylor_1st (self , ode_order ):
219+ def ode_taylor_1st (self , ode_order , damp : float ):
199220 raise NotImplementedError
200221
201- def ode_taylor_0th (self , ode_order ):
202- def linearise_fun_wrapped (fun , mean ):
222+ def ode_taylor_0th (self , ode_order , damp : float ):
223+ def linearise_fun_wrapped (fun , rv ):
224+ mean = rv .mean
203225 fx = self .ts0 (fun , mean [:ode_order , ...])
204226 linop = _jac_materialize (
205227 lambda s , _p : s [[ode_order ], ...], inputs = mean [:, 0 ]
206228 )
207- return linop , - fx
229+ cov_lower = damp * np .eye (1 )
230+ bias = _normal .Normal (- fx , cov_lower )
231+ return linop , bias
208232
209233 return linearise_fun_wrapped
210234
211- def ode_statistical_0th (self , cubature_fun ):
235+ def ode_statistical_0th (self , cubature_fun , damp : float ):
212236 raise NotImplementedError
213237
214- def ode_statistical_1st (self , cubature_fun ):
238+ def ode_statistical_1st (self , cubature_fun , damp : float ):
215239 raise NotImplementedError
216240
217241 @staticmethod
@@ -220,8 +244,9 @@ def ts0(fn, m):
220244
221245
222246class BlockDiagLinearisation (LinearisationBackend ):
223- def ode_taylor_0th (self , ode_order ):
224- def linearise_fun_wrapped (fun , mean ):
247+ def ode_taylor_0th (self , ode_order , damp : float ):
248+ def linearise_fun_wrapped (fun , rv ):
249+ mean = rv .mean
225250 m0 = mean [:, :ode_order ]
226251 fx = self .ts0 (fun , m0 .T )
227252
@@ -233,17 +258,20 @@ def lo(s):
233258 return _jac_materialize (lambda v , _p : a1 (v ), inputs = s )
234259
235260 linop = lo (mean )
236- return linop , - fx [:, None ]
261+ d , * _ = linop .shape
262+ cov_lower = damp * np .ones ((d , 1 , 1 ))
263+ bias = _normal .Normal (- fx [:, None ], cov_lower )
264+ return linop , bias
237265
238266 return linearise_fun_wrapped
239267
240- def ode_taylor_1st (self , ode_order ):
268+ def ode_taylor_1st (self , ode_order , damp : float ):
241269 raise NotImplementedError
242270
243- def ode_statistical_0th (self , cubature_fun ):
271+ def ode_statistical_0th (self , cubature_fun , damp : float ):
244272 raise NotImplementedError
245273
246- def ode_statistical_1st (self , cubature_fun ):
274+ def ode_statistical_1st (self , cubature_fun , damp : float ):
247275 raise NotImplementedError
248276
249277 @staticmethod
0 commit comments