11import jax .numpy as jnp
2- from jax import scipy as jsc
32from jax import lax , vmap , value_and_grad
43from jax .scipy .linalg import solve_triangular
54from jaxtyping import Array , Float
65from typing import NamedTuple , Optional
76
8- from dynamax .utils .utils import linear_solve
7+ from dynamax .utils .utils import psd_solve
98
109
1110class ParamsLGSSMInfo (NamedTuple ):
@@ -60,7 +59,7 @@ def info_to_moment_form(etas, Lambdas):
6059 means (N,D)
6160 covs (N,D,D)
6261 """
63- means = vmap (lambda A , b :linear_solve (A , b ))(Lambdas , etas )
62+ means = vmap (lambda A , b :psd_solve (A , b ))(Lambdas , etas )
6463 covs = jnp .linalg .inv (Lambdas )
6564 return means , covs
6665
@@ -82,7 +81,7 @@ def _mvn_info_log_prob(eta, Lambda, x):
8281 """
8382 D = len (Lambda )
8483 lp = x .T @ eta - 0.5 * x .T @ Lambda @ x
85- lp += - 0.5 * eta .T @ linear_solve (Lambda , eta )
84+ lp += - 0.5 * eta .T @ psd_solve (Lambda , eta )
8685 sign , logdet = jnp .linalg .slogdet (Lambda )
8786 lp += - 0.5 * (D * jnp .log (2 * jnp .pi ) - sign * logdet )
8887 return lp
@@ -121,7 +120,7 @@ def _info_predict(eta, Lambda, F, Q_prec, B, u, b):
121120 eta_pred (D_hid,): predicted precision weighted mean.
122121 Lambda_pred (D_hid,D_hid): predicted precision.
123122 """
124- K = linear_solve (Lambda + F .T @ Q_prec @ F , F .T @ Q_prec ).T
123+ K = psd_solve (Lambda + F .T @ Q_prec @ F , F .T @ Q_prec ).T
125124 I = jnp .eye (F .shape [0 ])
126125 ## This version should be more stable than:
127126 # Lambda_pred = (I - K @ F.T) @ Q_prec
@@ -263,7 +262,7 @@ def _smooth_step(carry, args):
263262
264263 # This is the information form version of the 'reverse' Kalman gain
265264 # See Eq 8.11 of Särkkä's "Bayesian Filtering and Smoothing"
266- G = linear_solve (Q_prec + smoothed_prec_next - pred_prec , Q_prec @ F )
265+ G = psd_solve (Q_prec + smoothed_prec_next - pred_prec , Q_prec @ F )
267266
268267 # Compute the smoothed parameter estimates
269268 smoothed_prec = filtered_prec + F .T @ Q_prec @ (F - G )
@@ -398,18 +397,18 @@ def lds_to_block_tridiag(lds, data, inputs):
398397 T = len (data )
399398
400399 # diagonal blocks of precision matrix
401- J_diag = jnp .array ([jnp .dot (C (t ).T , linear_solve (R (t ), C (t ))) for t in range (T )])
400+ J_diag = jnp .array ([jnp .dot (C (t ).T , psd_solve (R (t ), C (t ))) for t in range (T )])
402401 J_diag = J_diag .at [0 ].add (jnp .linalg .inv (Q0 ))
403- J_diag = J_diag .at [:- 1 ].add (jnp .array ([jnp .dot (A (t ).T , linear_solve (Q (t ), A (t ))) for t in range (T - 1 )]))
402+ J_diag = J_diag .at [:- 1 ].add (jnp .array ([jnp .dot (A (t ).T , psd_solve (Q (t ), A (t ))) for t in range (T - 1 )]))
404403 J_diag = J_diag .at [1 :].add (jnp .array ([jnp .linalg .inv (Q (t )) for t in range (0 , T - 1 )]))
405404
406405 # lower diagonal blocks of precision matrix
407- J_lower_diag = jnp .array ([- linear_solve (Q (t ), A (t )) for t in range (T - 1 )])
406+ J_lower_diag = jnp .array ([- psd_solve (Q (t ), A (t )) for t in range (T - 1 )])
408407
409408 # linear potential
410- h = jnp .array ([jnp .dot (data [t ] - D (t ) @ inputs [t ], linear_solve (R (t ), C (t ))) for t in range (T )])
411- h = h .at [0 ].add (linear_solve (Q0 , m0 ))
412- h = h .at [:- 1 ].add (jnp .array ([- jnp .dot (A (t ).T , linear_solve (Q (t ), B (t ) @ inputs [t ])) for t in range (T - 1 )]))
413- h = h .at [1 :].add (jnp .array ([linear_solve (Q (t ), B (t ) @ inputs [t ]) for t in range (T - 1 )]))
409+ h = jnp .array ([jnp .dot (data [t ] - D (t ) @ inputs [t ], psd_solve (R (t ), C (t ))) for t in range (T )])
410+ h = h .at [0 ].add (psd_solve (Q0 , m0 ))
411+ h = h .at [:- 1 ].add (jnp .array ([- jnp .dot (A (t ).T , psd_solve (Q (t ), B (t ) @ inputs [t ])) for t in range (T - 1 )]))
412+ h = h .at [1 :].add (jnp .array ([psd_solve (Q (t ), B (t ) @ inputs [t ]) for t in range (T - 1 )]))
414413
415414 return J_diag , J_lower_diag , h
0 commit comments