'''
Parallel filtering and smoothing for a linear Gaussian state space model (LGSSM).

This implementation is adapted from the work of Adrien Corenflos:
https://github.com/EEA-sensors/sequential-parallelization-examples/

Note that in the original implementation, the initial state distribution
applies to time `t=0`, and the first emission occurs at time `t=1` (i.e. after
the initial state has been transformed by the dynamics), whereas here,
the first emission occurs at time `t=0` and is produced directly by the
untransformed initial state (see below).

Sarkka et al.

      F₀,Q₀          F₁,Q₁          F₂,Q₂
Z₀ ─────────── Z₁ ─────────── Z₂ ─────────── Z₃ ─────...
               |              |              |
               | H₁,R₁        | H₂,R₂        | H₃,R₃
               |              |              |
               Y₁             Y₂             Y₃

Dynamax

      F₀,Q₀          F₁,Q₁          F₂,Q₂
Z₀ ─────────── Z₁ ─────────── Z₂ ─────────── Z₃ ─────...
|              |              |              |
| H₀,R₀        | H₁,R₁        | H₂,R₂        | H₃,R₃
|              |              |              |
Y₀             Y₁             Y₂             Y₃

'''
32+
from functools import partial
from typing import NamedTuple

import jax.numpy as jnp
import jax.scipy as jsc
from jax import vmap, lax
from jax.scipy.linalg import cho_solve, cho_factor
from jaxtyping import Array, Float
from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN

from dynamax.linear_gaussian_ssm import PosteriorGSSMFiltered, PosteriorGSSMSmoothed, ParamsLGSSM
from dynamax.types import PRNGKey
from dynamax.utils.utils import psd_solve
from dynamax.utils.utils import symmetrize
1244
45+
1346def _get_params (x , dim , t ):
1447 if callable (x ):
1548 return x (t )
1649 elif x .ndim == dim + 1 :
1750 return x [t ]
1851 else :
1952 return x
53+
54+ #---------------------------------------------------------------------------#
55+ # Filtering #
56+ #---------------------------------------------------------------------------#
2057
21- def _make_associative_filtering_elements (params , emissions ):
class FilterMessage(NamedTuple):
    """
    Filtering associative scan elements.

    Attributes:
        A: P(z_j | y_{i:j}, z_{i-1}) weights.
        b: P(z_j | y_{i:j}, z_{i-1}) bias.
        C: P(z_j | y_{i:j}, z_{i-1}) covariance.
        J: P(z_{i-1} | y_{i:j}) covariance.
        eta: P(z_{i-1} | y_{i:j}) mean.
        logZ: negative log-likelihood term carried through the scan;
            `lgssm_filter` reports `-logZ[-1]` of the scanned messages
            as the marginal log likelihood.
    """
    A: Float[Array, "ntime state_dim state_dim"]
    b: Float[Array, "ntime state_dim"]
    C: Float[Array, "ntime state_dim state_dim"]
    J: Float[Array, "ntime state_dim state_dim"]
    eta: Float[Array, "ntime state_dim"]
    logZ: Float[Array, "ntime"]
75+
76+
def _initialize_filtering_messages(params, emissions):
    """Preprocess observations to construct input for filtering associative scan.

    One message is built per time step: a special message for the first
    emission, computed from the initial state distribution, and generic
    messages for the remaining emissions, computed from the dynamics and
    emission parameters.

    Args:
        params: model parameters; reads `initial.mean`, `initial.cov`,
            `dynamics.weights`, `dynamics.bias`, `dynamics.cov`,
            `emissions.weights`, `emissions.bias` and `emissions.cov`.
        emissions: observations stacked along the leading (time) axis.

    Returns:
        A `FilterMessage` whose fields are stacked along the leading axis,
        with the first-emission message at index 0.
    """

    def _first_message(params, y):
        """Message for the first emission `y`, conditioned on the initial state prior."""
        H = _get_params(params.emissions.weights, 2, 0)
        R = _get_params(params.emissions.cov, 2, 0)
        d = _get_params(params.emissions.bias, 1, 0)
        m = params.initial.mean
        P = params.initial.cov

        S = H @ P @ H.T + R
        CF, low = cho_factor(S)
        # K = P H^T S^{-1}, obtained from the Cholesky factor of S.
        K = cho_solve((CF, low), H @ P).T

        A = jnp.zeros_like(P)
        b = m + K @ (y - H @ m - d)
        C = symmetrize(P - K @ S @ K.T)
        eta = jnp.zeros_like(b)
        J = jnp.eye(len(b))

        # NOTE(review): logZ is evaluated at the raw `y` with zero mean, while
        # `b` uses the residual `y - H m - d` -- confirm this is intended when
        # the initial mean or emission bias is nonzero.
        logZ = -MVN(loc=jnp.zeros_like(y), covariance_matrix=S).log_prob(y)
        return A, b, C, J, eta, logZ

    @partial(vmap, in_axes=(None, 0, 0))
    def _generic_message(params, y, t):
        """Message for emission `y` at step `t + 1`, using the dynamics of step `t`."""
        F = _get_params(params.dynamics.weights, 2, t)
        Q = _get_params(params.dynamics.cov, 2, t)
        b = _get_params(params.dynamics.bias, 1, t)
        H = _get_params(params.emissions.weights, 2, t + 1)
        R = _get_params(params.emissions.cov, 2, t + 1)
        d = _get_params(params.emissions.bias, 1, t + 1)

        S = H @ Q @ H.T + R
        CF, low = cho_factor(S)
        # K = Q H^T S^{-1}, obtained from the Cholesky factor of S.
        K = cho_solve((CF, low), H @ Q).T

        eta = F.T @ H.T @ cho_solve((CF, low), y - H @ b - d)
        J = symmetrize(F.T @ H.T @ cho_solve((CF, low), H @ F))

        A = F - K @ H @ F
        # `b` is the dynamics bias on the right-hand side and becomes the
        # message bias after this assignment.
        b = b + K @ (y - H @ b - d)
        C = symmetrize(Q - K @ H @ Q)

        # NOTE(review): logZ is evaluated at the raw `y`, while `eta` and `b`
        # use the residual `y - H b - d` -- confirm for nonzero biases.
        logZ = -MVN(loc=jnp.zeros_like(y), covariance_matrix=S).log_prob(y)
        return A, b, C, J, eta, logZ

    A0, b0, C0, J0, eta0, logZ0 = _first_message(params, emissions[0])
    # Generic messages pair emission t + 1 with dynamics step t, for t = 0..T-2.
    At, bt, Ct, Jt, etat, logZt = _generic_message(params, emissions[1:], jnp.arange(len(emissions) - 1))

    return FilterMessage(
        A=jnp.concatenate([A0[None], At]),
        b=jnp.concatenate([b0[None], bt]),
        C=jnp.concatenate([C0[None], Ct]),
        J=jnp.concatenate([J0[None], Jt]),
        eta=jnp.concatenate([eta0[None], etat]),
        logZ=jnp.concatenate([logZ0[None], logZt])
    )
136+
137+
75138
76139def lgssm_filter (
77140 params : ParamsLGSSM ,
@@ -83,71 +146,81 @@ def lgssm_filter(
83146
84147 Note: This function does not yet handle `inputs` to the system.
85148 """
86- #TODO: Add input handling.
87- initial_elements = _make_associative_filtering_elements (params , emissions )
88-
89149 @vmap
90- def filtering_operator (elem1 , elem2 ):
150+ def _operator (elem1 , elem2 ):
91151 A1 , b1 , C1 , J1 , eta1 , logZ1 = elem1
92152 A2 , b2 , C2 , J2 , eta2 , logZ2 = elem2
93- dim = A1 .shape [0 ]
94- I = jnp .eye (dim )
153+ I = jnp .eye (A1 .shape [0 ])
95154
96155 I_C1J2 = I + C1 @ J2
97- temp = jsc .linalg .solve (I_C1J2 .T , A2 .T ).T
156+ temp = jnp .linalg .solve (I_C1J2 .T , A2 .T ).T
98157 A = temp @ A1
99158 b = temp @ (b1 + C1 @ eta2 ) + b2
100- C = temp @ C1 @ A2 .T + C2
159+ C = symmetrize ( temp @ C1 @ A2 .T + C2 )
101160
102161 I_J2C1 = I + J2 @ C1
103- temp = jsc .linalg .solve (I_J2C1 .T , A1 ).T
104-
162+ temp = jnp .linalg .solve (I_J2C1 .T , A1 ).T
105163 eta = temp @ (eta2 - J2 @ b1 ) + eta1
106- J = temp @ J2 @ A1 + J1
107-
108- # mu = jsc.linalg.solve(J2, eta2)
109- # t2 = - eta2 @ mu + (b1 - mu) @ jsc.linalg.solve(I_J2C1, (J2 @ b1 - eta2))
164+ J = symmetrize (temp @ J2 @ A1 + J1 )
110165
111166 mu = jnp .linalg .solve (C1 , b1 )
112167 t1 = (b1 @ mu - (eta2 + mu ) @ jnp .linalg .solve (I_C1J2 , C1 @ eta2 + b1 ))
113-
114168 logZ = (logZ1 + logZ2 + 0.5 * jnp .linalg .slogdet (I_C1J2 )[1 ] + 0.5 * t1 )
169+ return FilterMessage (A , b , C , J , eta , logZ )
170+
171+ initial_messages = _initialize_filtering_messages (params , emissions )
172+ final_messages = lax .associative_scan (_operator , initial_messages )
173+
174+ return PosteriorGSSMFiltered (
175+ filtered_means = final_messages .b ,
176+ filtered_covariances = final_messages .C ,
177+ marginal_loglik = - final_messages .logZ [- 1 ])
115178
116- return A , b , C , J , eta , logZ
117179
118- _ , filtered_means , filtered_covs , _ , _ , logZ = lax . associative_scan (
119- filtering_operator , initial_elements
120- )
180+ #---------------------------------------------------------------------------#
181+ # Smoothing #
182+ #---------------------------------------------------------------------------#
121183
122- return PosteriorGSSMFiltered (marginal_loglik = - logZ [- 1 ],
123- filtered_means = filtered_means , filtered_covariances = filtered_covs )
class SmoothMessage(NamedTuple):
    """
    Smoothing associative scan elements.

    Attributes:
        E: P(z_i | y_{1:j}, z_{j+1}) weights.
        g: P(z_i | y_{1:j}, z_{j+1}) bias.
        L: P(z_i | y_{1:j}, z_{j+1}) covariance.
    """
    E: Float[Array, "ntime state_dim state_dim"]
    g: Float[Array, "ntime state_dim"]
    L: Float[Array, "ntime state_dim state_dim"]
125196
126197
127- def _make_associative_smoothing_elements (params , filtered_means , filtered_covariances ):
def _initialize_smoothing_messages(params, filtered_means, filtered_covariances):
    """Preprocess filtering output to construct input for smoothing associative scan.

    Args:
        params: model parameters; reads `dynamics.weights`, `dynamics.cov`
            and `dynamics.bias`.
        filtered_means: filtered means stacked along the leading (time) axis.
        filtered_covariances: filtered covariances stacked along the leading
            (time) axis.

    Returns:
        A `SmoothMessage` whose fields are stacked along the leading axis,
        with the message built from the last filtered step at the end.
    """

    def _last_message(m, P):
        """Message for the final step: zero weights, filtered mean and covariance."""
        return jnp.zeros_like(P), m, P

    @partial(vmap, in_axes=(None, 0, 0, 0))
    def _generic_message(params, m, P, t):
        """Message for step `t`, built from its filtered mean `m` and covariance `P`."""
        F = _get_params(params.dynamics.weights, 2, t)
        Q = _get_params(params.dynamics.cov, 2, t)
        b = _get_params(params.dynamics.bias, 1, t)

        # E = P F^T (F P F^T + Q)^{-1}, via the Cholesky factor of the
        # one-step-ahead predicted covariance.
        CF, low = cho_factor(F @ P @ F.T + Q)
        E = cho_solve((CF, low), F @ P).T
        g = m - E @ (F @ m + b)
        L = symmetrize(P - E @ F @ P)
        return E, g, L

    En, gn, Ln = _last_message(filtered_means[-1], filtered_covariances[-1])
    Et, gt, Lt = _generic_message(params, filtered_means[:-1], filtered_covariances[:-1], jnp.arange(len(filtered_means) - 1))

    return SmoothMessage(
        E=jnp.concatenate([Et, En[None]]),
        g=jnp.concatenate([gt, gn[None]]),
        L=jnp.concatenate([Lt, Ln[None]])
    )
151224
152225
153226def lgssm_smoother (
@@ -163,26 +236,78 @@ def lgssm_smoother(
163236 filtered_posterior = lgssm_filter (params , emissions )
164237 filtered_means = filtered_posterior .filtered_means
165238 filtered_covs = filtered_posterior .filtered_covariances
166- initial_elements = _make_associative_smoothing_elements (params , filtered_means , filtered_covs )
167-
239+
168240 @vmap
169- def smoothing_operator (elem1 , elem2 ):
241+ def _operator (elem1 , elem2 ):
170242 E1 , g1 , L1 = elem1
171243 E2 , g2 , L2 = elem2
172-
173244 E = E2 @ E1
174245 g = E2 @ g1 + g2
175- L = E2 @ L1 @ E2 .T + L2
176-
246+ L = symmetrize (E2 @ L1 @ E2 .T + L2 )
177247 return E , g , L
178248
179- _ , smoothed_means , smoothed_covs , * _ = lax . associative_scan (
180- smoothing_operator , initial_elements , reverse = True
181- )
249+ initial_messages = _initialize_smoothing_messages ( params , filtered_means , filtered_covs )
250+ final_messages = lax . associative_scan ( _operator , initial_messages , reverse = True )
251+
182252 return PosteriorGSSMSmoothed (
183253 marginal_loglik = filtered_posterior .marginal_loglik ,
184254 filtered_means = filtered_means ,
185255 filtered_covariances = filtered_covs ,
186- smoothed_means = smoothed_means ,
187- smoothed_covariances = smoothed_covs
256+ smoothed_means = final_messages . g ,
257+ smoothed_covariances = final_messages . L
188258 )
259+
260+
261+ #---------------------------------------------------------------------------#
262+ # Sampling #
263+ #---------------------------------------------------------------------------#
264+
class SampleMessage(NamedTuple):
    """
    Sampling associative scan elements.

    Attributes:
        E: z_i ~ z_{j+1} weights.
        h: z_i ~ z_{j+1} bias.
    """
    E: Float[Array, "ntime state_dim state_dim"]
    h: Float[Array, "ntime state_dim"]
275+
276+
def _initialize_sampling_messages(key, params, filtered_means, filtered_covariances):
    """Construct input for the sampling associative scan.

    Given parallel smoothing messages `z_i ~ N(E_i z_{i+1} + g_i, L_i)`,
    the parallel sampling messages are `(E_i,h_i)` where `h_i ~ N(g_i, L_i)`.

    Args:
        key: random key passed as the seed for drawing `h`.
        params: model parameters, forwarded to `_initialize_smoothing_messages`.
        filtered_means: filtered means stacked along the leading (time) axis.
        filtered_covariances: filtered covariances stacked along the leading
            (time) axis.

    Returns:
        A `SampleMessage` holding the smoothing weights `E` and the sampled
        biases `h`.
    """
    E, g, L = _initialize_smoothing_messages(params, filtered_means, filtered_covariances)
    return SampleMessage(E=E, h=MVN(g, L).sample(seed=key))
285+
286+
def lgssm_posterior_sample(
    key: PRNGKey,
    params: ParamsLGSSM,
    emissions: Float[Array, "ntime emission_dim"]
) -> Float[Array, "ntime state_dim"]:
    """A parallel version of the lgssm sampling algorithm.

    See S. Särkkä and Á. F. García-Fernández (2021) - https://arxiv.org/abs/1905.13002.

    Note: This function does not yet handle `inputs` to the system.

    Args:
        key: random key used to draw the sample.
        params: model parameters.
        emissions: observations stacked along the leading (time) axis.

    Returns:
        The `h` component of the reverse-scanned sampling messages, one row
        per time step.
    """
    filtered_posterior = lgssm_filter(params, emissions)
    filtered_means = filtered_posterior.filtered_means
    filtered_covs = filtered_posterior.filtered_covariances

    @vmap
    def _operator(elem1, elem2):
        """Combine two sampling messages; vmapped over their leading axis."""
        E1, h1 = elem1
        E2, h2 = elem2

        E = E2 @ E1
        h = E2 @ h1 + h2
        return SampleMessage(E, h)

    initial_messages = _initialize_sampling_messages(key, params, filtered_means, filtered_covs)
    # The scan runs in reverse over the time axis.
    _, samples = lax.associative_scan(_operator, initial_messages, reverse=True)
    return samples
0 commit comments