99 tree_util ,
1010)
1111from probdiffeq .backend import numpy as np
12- from probdiffeq .backend .typing import Any , Array
13- from probdiffeq .impl import _normal
12+ from probdiffeq .backend .typing import Array
13+ from probdiffeq .impl import _normal , _stats
1414from probdiffeq .util import cholesky_util
1515
1616
17- class Conditional (containers .NamedTuple ):
17+ @tree_util .register_dataclass
18+ @containers .dataclass
19+ class Conditional :
1820 """Conditional distributions."""
1921
20- matmul : Array
21- noise : Any # Usually a random-variable type
22+ A : Array
23+ noise : _normal . Normal
2224
2325
2426class ConditionalBackend (abc .ABC ):
@@ -57,6 +59,10 @@ def preconditioner_apply(self, cond, p, p_inv, /):
5759 def to_derivative (self , i , standard_deviation ):
5860 raise NotImplementedError
5961
62+ @abc .abstractmethod
63+ def rescale_noise (self , cond , scale ):
64+ raise NotImplementedError
65+
6066
6167class DenseConditional (ConditionalBackend ):
6268 def __init__ (self , ode_shape , num_derivatives , unravel , flat_shape ):
@@ -65,19 +71,17 @@ def __init__(self, ode_shape, num_derivatives, unravel, flat_shape):
6571 self .unravel = unravel
6672 self .flat_shape = flat_shape
6773
68- def apply (self , x , conditional , / ):
69- matrix , noise = conditional
70- return _normal .Normal (matrix @ x + noise .mean , noise .cholesky )
74+ def apply (self , x , cond , / ):
75+ return _normal .Normal (cond .A @ x + cond .noise .mean , cond .noise .cholesky )
7176
72- def marginalise (self , rv , conditional , / ):
73- matmul , noise = conditional
74- R_stack = ((matmul @ rv .cholesky ).T , noise .cholesky .T )
77+ def marginalise (self , rv , cond , / ):
78+ R_stack = ((cond .A @ rv .cholesky ).T , cond .noise .cholesky .T )
7579 cholesky_new = cholesky_util .sum_of_sqrtm_factors (R_stack = R_stack ).T
76- return _normal .Normal (matmul @ rv .mean + noise .mean , cholesky_new )
80+ return _normal .Normal (cond . A @ rv .mean + cond . noise .mean , cholesky_new )
7781
7882 def merge (self , cond1 , cond2 , / ):
79- A , b = cond1
80- C , d = cond2
83+ A , b = cond1 . A , cond1 . noise
84+ C , d = cond2 . A , cond2 . noise
8185
8286 g = A @ C
8387 xi = A @ d .mean + b .mean
@@ -86,8 +90,8 @@ def merge(self, cond1, cond2, /):
8690 )
8791 return Conditional (g , _normal .Normal (xi , Xi .T ))
8892
89- def revert (self , rv , conditional , / ):
90- matrix , noise = conditional
93+ def revert (self , rv , cond , / ):
94+ matrix , noise = cond . A , cond . noise
9195 mean , cholesky = rv .mean , rv .cholesky
9296
9397 # QR-decomposition
@@ -133,10 +137,9 @@ def discretise(dt):
133137 return discretise
134138
135139 def preconditioner_apply (self , cond , p , p_inv , / ):
136- A , noise = cond
137140 normal = _normal .DenseNormal (ode_shape = self .ode_shape )
138- noise = normal .preconditioner_apply (noise , p )
139- A = p [:, None ] * A * p_inv [None , :]
141+ noise = normal .preconditioner_apply (cond . noise , p )
142+ A = p [:, None ] * cond . A * p_inv [None , :]
140143 return Conditional (A , noise )
141144
142145 def to_derivative (self , i , standard_deviation ):
@@ -153,23 +156,30 @@ def select(a):
153156 noise = _normal .Normal (bias , standard_deviation * eye )
154157 return Conditional (linop , noise )
155158
159+ def rescale_noise (self , cond , scale ):
160+ A = cond .A
161+ noise = cond .noise
162+ stats = _stats .DenseStats (ode_shape = self .ode_shape , unravel = self .unravel )
163+ noise_new = stats .rescale_cholesky (noise , scale )
164+ return Conditional (A , noise_new )
165+
156166
157167class IsotropicConditional (ConditionalBackend ):
158168 def __init__ (self , * , ode_shape , num_derivatives , unravel_tree ):
159169 self .ode_shape = ode_shape
160170 self .num_derivatives = num_derivatives
161171 self .unravel_tree = unravel_tree
162172
163- def apply (self , x , conditional , / ):
164- A , noise = conditional
173+ def apply (self , x , cond , / ):
174+ A , noise = cond . A , cond . noise
165175 # if the gain is qoi-to-hidden, the data is a (d,) array.
166176 # this is problematic for the isotropic model unless we explicitly broadcast.
167177 if np .ndim (x ) == 1 :
168178 x = x [None , :]
169179 return _normal .Normal (A @ x + noise .mean , noise .cholesky )
170180
171- def marginalise (self , rv , conditional , / ):
172- matrix , noise = conditional
181+ def marginalise (self , rv , cond , / ):
182+ matrix , noise = cond . A , cond . noise
173183
174184 mean = matrix @ rv .mean + noise .mean
175185
@@ -178,8 +188,8 @@ def marginalise(self, rv, conditional, /):
178188 return _normal .Normal (mean , cholesky )
179189
180190 def merge (self , cond1 , cond2 , / ):
181- A , b = cond1
182- C , d = cond2
191+ A , b = cond1 . A , cond1 . noise
192+ C , d = cond2 . A , cond2 . noise
183193
184194 g = A @ C
185195 xi = A @ d .mean + b .mean
@@ -189,8 +199,8 @@ def merge(self, cond1, cond2, /):
189199 noise = _normal .Normal (xi , Xi )
190200 return Conditional (g , noise )
191201
192- def revert (self , rv , conditional , / ):
193- matrix , noise = conditional
202+ def revert (self , rv , cond , / ):
203+ matrix , noise = cond . A , cond . noise
194204
195205 r_ext_p , (r_bw_p , gain ) = cholesky_util .revert_conditional (
196206 R_X_F = (matrix @ rv .cholesky ).T , R_X = rv .cholesky .T , R_YX = noise .cholesky .T
@@ -225,7 +235,7 @@ def discretise(dt):
225235 return discretise
226236
227237 def preconditioner_apply (self , cond , p , p_inv , / ):
228- A , noise = cond
238+ A , noise = cond . A , cond . noise
229239
230240 A_new = p [:, None ] * A * p_inv [None , :]
231241
@@ -245,25 +255,34 @@ def select(a):
245255
246256 return Conditional (linop , noise )
247257
258+ def rescale_noise (self , cond , scale ):
259+ A = cond .A
260+ noise = cond .noise
261+ stats = _stats .IsotropicStats (
262+ ode_shape = self .ode_shape , unravel = self .unravel_tree
263+ )
264+ noise_new = stats .rescale_cholesky (noise , scale )
265+ return Conditional (A , noise_new )
266+
248267
249268class BlockDiagConditional (ConditionalBackend ):
250269 def __init__ (self , * , ode_shape , num_derivatives , unravel_tree ):
251270 self .ode_shape = ode_shape
252271 self .num_derivatives = num_derivatives
253272 self .unravel_tree = unravel_tree
254273
255- def apply (self , x , conditional , / ):
274+ def apply (self , x , cond , / ):
256275 if np .ndim (x ) == 1 :
257276 x = x [..., None ]
258277
259278 def apply_unbatch (m , s , n ):
260279 return _normal .Normal (m @ s + n .mean , n .cholesky )
261280
262- matrix , noise = conditional
281+ matrix , noise = cond . A , cond . noise
263282 return functools .vmap (apply_unbatch )(matrix , x , noise )
264283
265- def marginalise (self , rv , conditional , / ):
266- matrix , noise = conditional
284+ def marginalise (self , rv , cond , / ):
285+ matrix , noise = cond . A , cond . noise
267286 assert matrix .ndim == 3
268287
269288 mean = np .einsum ("ijk,ik->ij" , matrix , rv .mean ) + noise .mean
@@ -275,8 +294,8 @@ def marginalise(self, rv, conditional, /):
275294 return _normal .Normal (mean , _transpose (cholesky ))
276295
277296 def merge (self , cond1 , cond2 , / ):
278- A , b = cond1
279- C , d = cond2
297+ A , b = cond1 . A , cond1 . noise
298+ C , d = cond2 . A , cond2 . noise
280299
281300 g = A @ C
282301 xi = (A @ d .mean [..., None ])[..., 0 ] + b .mean
@@ -286,8 +305,8 @@ def merge(self, cond1, cond2, /):
286305 noise = _normal .Normal (xi , Xi )
287306 return Conditional (g , noise )
288307
289- def revert (self , rv , conditional , / ):
290- A , noise = conditional
308+ def revert (self , rv , cond , / ):
309+ A , noise = cond . A , cond . noise
291310 rv_chol_upper = np .transpose (rv .cholesky , axes = (0 , 2 , 1 ))
292311 noise_chol_upper = np .transpose (noise .cholesky , axes = (0 , 2 , 1 ))
293312 A_rv_chol_upper = np .transpose (A @ rv .cholesky , axes = (0 , 2 , 1 ))
@@ -329,11 +348,10 @@ def discretise(dt):
329348 return discretise
330349
331350 def preconditioner_apply (self , cond , p , p_inv , / ):
332- A , noise = cond
333- A_new = p [None , :, None ] * A * p_inv [None , None , :]
351+ A_new = p [None , :, None ] * cond .A * p_inv [None , None , :]
334352
335353 normal = _normal .BlockDiagNormal (ode_shape = self .ode_shape )
336- noise = normal .preconditioner_apply (noise , p )
354+ noise = normal .preconditioner_apply (cond . noise , p )
337355 return Conditional (A_new , noise )
338356
339357 def to_derivative (self , i , standard_deviation ):
@@ -349,6 +367,15 @@ def select(a):
349367
350368 return Conditional (linop , noise )
351369
370+ def rescale_noise (self , cond , scale ):
371+ A = cond .A
372+ noise = cond .noise
373+ stats = _stats .BlockDiagStats (
374+ ode_shape = self .ode_shape , unravel = self .unravel_tree
375+ )
376+ noise_new = stats .rescale_cholesky (noise , scale )
377+ return Conditional (A , noise_new )
378+
352379
353380def _transpose (matrix ):
354381 return np .transpose (matrix , axes = (0 , 2 , 1 ))
0 commit comments