77
88from fastprogress .fastprogress import progress_bar
99from functools import partial
10- from jax import jit
10+ from jax import jit , tree , vmap
1111from jax .tree_util import tree_map
1212from jaxtyping import Array , Float
1313from tensorflow_probability .substrates .jax .distributions import MultivariateNormalFullCovariance as MVN
1414from typing import Any , Optional , Tuple , Union , runtime_checkable
15- from typing_extensions import Protocol
15+ from typing_extensions import Protocol
1616
1717from dynamax .ssm import SSM
1818from dynamax .linear_gaussian_ssm .inference import lgssm_joint_sample , lgssm_filter , lgssm_smoother , lgssm_posterior_sample
@@ -206,7 +206,7 @@ def sample(self,
206206 key : PRNGKeyT ,
207207 num_timesteps : int ,
208208 inputs : Optional [Float [Array , "num_timesteps input_dim" ]] = None ) \
209- -> Tuple [Float [Array , "num_timesteps state_dim" ],
209+ -> Tuple [Float [Array , "num_timesteps state_dim" ],
210210 Float [Array , "num_timesteps emission_dim" ]]:
211211 """Sample from the model.
212212
@@ -588,6 +588,47 @@ def m_step(self,
588588 )
589589 return params , m_step_state
590590
def _check_params(self, params: ParamsLGSSM, num_timesteps: int) -> ParamsLGSSM:
    """Return a copy of ``params`` in which missing biases and input weights are zeros.

    Only the ``bias`` and ``input_weights`` fields of the dynamics and
    emissions parameters are filled in; weights, covariances and the initial
    parameters are passed through untouched.

    Args:
        params: model parameters, possibly with ``None`` biases / input weights.
        num_timesteps: number of time steps, used to size the leading time
            axis of the zero arrays when the model is inhomogeneous.

    Returns:
        A new ``ParamsLGSSM`` with no ``None`` bias or input-weight entries.
    """
    dyn, emi = params.dynamics, params.emissions

    # A 3-D dynamics weight array is taken to mean the parameters carry a
    # leading time axis (inhomogeneous model).
    time_varying = dyn.weights.ndim == 3

    # Dynamics cover T-1 transitions, emissions cover all T steps.
    dyn_lead = (num_timesteps - 1,) if time_varying else ()
    emi_lead = (num_timesteps,) if time_varying else ()

    def _or_zeros(value, shape):
        """Return ``value`` unless it is None, in which case return zeros of ``shape``."""
        return jnp.zeros(shape) if value is None else value

    filled_dynamics = ParamsLGSSMDynamics(
        weights=dyn.weights,
        bias=_or_zeros(dyn.bias, shape=dyn_lead + (self.state_dim,)),
        input_weights=_or_zeros(
            dyn.input_weights, shape=dyn_lead + (self.state_dim, self.input_dim)
        ),
        cov=dyn.cov,
    )

    filled_emissions = ParamsLGSSMEmissions(
        weights=emi.weights,
        bias=_or_zeros(emi.bias, shape=emi_lead + (self.emission_dim,)),
        input_weights=_or_zeros(
            emi.input_weights, shape=emi_lead + (self.emission_dim, self.input_dim)
        ),
        cov=emi.cov,
    )

    return ParamsLGSSM(
        initial=params.initial,
        dynamics=filled_dynamics,
        emissions=filled_emissions,
    )
631+
591632 def fit_blocked_gibbs (self ,
592633 key : PRNGKeyT ,
593634 initial_params : ParamsLGSSM ,
@@ -599,7 +640,8 @@ def fit_blocked_gibbs(self,
599640
600641 Args:
601642 key: random number key.
602- initial_params: starting parameters.
643+ initial_params: starting parameters. Include a leading time axis for
644+ the dynamics and emissions parameters in inhomogeneous models.
603645 sample_size: how many samples to draw.
604646 emissions: set of observation sequences.
605647 inputs: optional set of input sequences.
@@ -609,67 +651,97 @@ def fit_blocked_gibbs(self,
609651 """
610652 num_timesteps = len (emissions )
611653
654+ # Inhomogeneous models have a leading time dimension.
655+ is_inhomogeneous = initial_params .dynamics .weights .ndim == 3
656+
612657 if inputs is None :
613658 inputs = jnp .zeros ((num_timesteps , 0 ))
614659
660+ initial_params = self ._check_params (initial_params , num_timesteps )
661+
def sufficient_stats_from_sample(states):
    """Convert a sample of the latent states into per-timestep sufficient statistics.

    Args:
        states: sampled latent states, indexed by time along the first axis.

    Returns:
        Tuple ``(init_stats, dynamics_stats, emission_stats)``. The dynamics
        and emission statistics keep a leading time axis (one outer product
        per time step) rather than being summed over time.
    """

    def _stepwise_outer(a, b):
        """Outer product of a[t] and b[t] for every time step t."""
        return jnp.einsum('ti,tj->tij', a, b)

    # Regressors z[t] = [x[t], u[t], 1] for t = 0...T-1; the trailing column
    # of ones is the bias term.
    bias_column = jnp.ones((num_timesteps, 1))
    regressors = jnp.concatenate([states, inputs, bias_column], axis=-1)
    # Pair zp[t] = z[t] with xn[t] = x[t+1] for t = 0...T-2.
    prev_regressors = regressors[:-1]
    next_states = states[1:]

    first_state = states[0]
    init_stats = (first_state, jnp.outer(first_state, first_state), 1)

    # Quantities for the dynamics distribution (leading time dimension).
    zp_zp = _stepwise_outer(prev_regressors, prev_regressors)
    zp_xn = _stepwise_outer(prev_regressors, next_states)
    xn_xn = _stepwise_outer(next_states, next_states)
    transition_mask = jnp.ones(num_timesteps - 1)
    if self.has_dynamics_bias:
        dynamics_stats = (zp_zp, zp_xn, xn_xn, transition_mask)
    else:
        # Without a bias, drop the row/column that belongs to the ones column.
        dynamics_stats = (zp_zp[:, :-1, :-1], zp_xn[:, :-1, :], xn_xn, transition_mask)

    # Quantities for the emissions (leading time dimension).
    z_z = _stepwise_outer(regressors, regressors)
    z_y = _stepwise_outer(regressors, emissions)
    y_y = _stepwise_outer(emissions, emissions)
    emission_mask = jnp.ones(num_timesteps)
    if self.has_emissions_bias:
        emission_stats = (z_z, z_y, y_y, emission_mask)
    else:
        # Without a bias, drop the row/column that belongs to the ones column.
        emission_stats = (z_z[:, :-1, :-1], z_y[:, :-1, :], y_y, emission_mask)

    return init_stats, dynamics_stats, emission_stats
645698
646- def lgssm_params_sample (rng , stats ):
647- """Sample parameters of the model given sufficient statistics from observed states and emissions."""
648- init_stats , dynamics_stats , emission_stats = stats
649- rngs = iter (jr .split (rng , 3 ))
650-
651- # Sample the initial params
def _sample_initial_params(rng, init_stats):
    """Draw the initial-state mean and covariance from their posterior.

    Args:
        rng: random key used as the sampling seed.
        init_stats: sufficient statistics passed to ``niw_posterior_update``
            together with ``self.initial_prior``.

    Returns:
        ``ParamsLGSSMInitial`` holding the sampled mean and covariance.
    """
    posterior = niw_posterior_update(self.initial_prior, init_stats)
    # The posterior sample unpacks as (covariance, mean).
    cov_draw, mean_draw = posterior.sample(seed=rng)
    return ParamsLGSSMInitial(mean=mean_draw, cov=cov_draw)
654703
655- # Sample the dynamics params
def _sample_dynamics_params(rng, dynamics_stats):
    """Draw dynamics weights, input weights, bias and covariance from their posterior.

    Args:
        rng: random key used as the sampling seed.
        dynamics_stats: sufficient statistics passed to
            ``mniw_posterior_update`` together with ``self.dynamics_prior``.

    Returns:
        ``ParamsLGSSMDynamics`` built from the sampled quantities.
    """
    posterior = mniw_posterior_update(self.dynamics_prior, dynamics_stats)
    # The posterior sample unpacks as (covariance, stacked regression matrix).
    cov_draw, stacked = posterior.sample(seed=rng)

    # Columns of the stacked matrix: [state weights | input weights | bias?].
    n_state = self.state_dim
    state_weights = stacked[:, :n_state]
    if self.has_dynamics_bias:
        input_weights = stacked[:, n_state:-1]
        bias = stacked[:, -1]
    else:
        # No bias column was sampled; report a zero bias instead.
        input_weights = stacked[:, n_state:]
        bias = jnp.zeros(self.state_dim)

    return ParamsLGSSMDynamics(
        weights=state_weights, bias=bias, input_weights=input_weights, cov=cov_draw
    )
661711
662- # Sample the emission params
def _sample_emission_params(rng, emission_stats):
    """Draw emission weights, input weights, bias and covariance from their posterior.

    Args:
        rng: random key used as the sampling seed.
        emission_stats: sufficient statistics passed to
            ``mniw_posterior_update`` together with ``self.emission_prior``.

    Returns:
        ``ParamsLGSSMEmissions`` built from the sampled quantities.
    """
    posterior = mniw_posterior_update(self.emission_prior, emission_stats)
    # The posterior sample unpacks as (covariance, stacked regression matrix).
    cov_draw, stacked = posterior.sample(seed=rng)

    # Columns of the stacked matrix: [state weights | input weights | bias?].
    n_state = self.state_dim
    state_weights = stacked[:, :n_state]
    if self.has_emissions_bias:
        input_weights = stacked[:, n_state:-1]
        bias = stacked[:, -1]
    else:
        # No bias column was sampled; report a zero bias instead.
        input_weights = stacked[:, n_state:]
        bias = jnp.zeros(self.emission_dim)

    return ParamsLGSSMEmissions(
        weights=state_weights, bias=bias, input_weights=input_weights, cov=cov_draw
    )
719+
def lgssm_params_sample(rng, stats):
    """Sample model parameters given sufficient statistics of states and emissions.

    Args:
        rng: random key; it is split into three sub-keys, one each for the
            initial, dynamics and emission parameters.
        stats: tuple ``(init_stats, dynamics_stats, emission_stats)`` where
            the dynamics and emission statistics carry a leading time axis.

    Returns:
        ``ParamsLGSSM`` assembled from the three sampled parameter groups.
    """
    init_stats, dynamics_stats, emission_stats = stats
    init_key, dynamics_key, emission_key = jr.split(rng, 3)

    def _sum_over_time(stat_tree):
        """Collapse the leading time axis of every leaf by summation."""
        return tree.map(lambda leaf: jnp.sum(leaf, axis=0), stat_tree)

    # Sample the initial params.
    sampled_initial = _sample_initial_params(init_key, init_stats)

    # Sample the dynamics and emission params.
    if is_inhomogeneous:
        # One independent draw per time step, each with its own key.
        per_step_dynamics_keys = jr.split(dynamics_key, num_timesteps - 1)
        per_step_emission_keys = jr.split(emission_key, num_timesteps)
        sampled_dynamics = vmap(_sample_dynamics_params)(per_step_dynamics_keys, dynamics_stats)
        sampled_emissions = vmap(_sample_emission_params)(per_step_emission_keys, emission_stats)
    else:
        # Homogeneous model: pool the statistics across time before sampling.
        sampled_dynamics = _sample_dynamics_params(dynamics_key, _sum_over_time(dynamics_stats))
        sampled_emissions = _sample_emission_params(emission_key, _sum_over_time(emission_stats))

    return ParamsLGSSM(
        initial=sampled_initial,
        dynamics=sampled_dynamics,
        emissions=sampled_emissions,
    )
675747
0 commit comments