Skip to content

Commit 5202215

Browse files
committed
Add support for inhomogeneous parameters
1 parent 35e1217 commit 5202215

File tree

2 files changed

+159
-34
lines changed

2 files changed

+159
-34
lines changed

dynamax/linear_gaussian_ssm/models.py

Lines changed: 105 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@
77

88
from fastprogress.fastprogress import progress_bar
99
from functools import partial
10-
from jax import jit
10+
from jax import jit, tree, vmap
1111
from jax.tree_util import tree_map
1212
from jaxtyping import Array, Float
1313
from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN
1414
from typing import Any, Optional, Tuple, Union, runtime_checkable
15-
from typing_extensions import Protocol
15+
from typing_extensions import Protocol
1616

1717
from dynamax.ssm import SSM
1818
from dynamax.linear_gaussian_ssm.inference import lgssm_joint_sample, lgssm_filter, lgssm_smoother, lgssm_posterior_sample
@@ -206,7 +206,7 @@ def sample(self,
206206
key: PRNGKeyT,
207207
num_timesteps: int,
208208
inputs: Optional[Float[Array, "num_timesteps input_dim"]] = None) \
209-
-> Tuple[Float[Array, "num_timesteps state_dim"],
209+
-> Tuple[Float[Array, "num_timesteps state_dim"],
210210
Float[Array, "num_timesteps emission_dim"]]:
211211
"""Sample from the model.
212212
@@ -588,6 +588,47 @@ def m_step(self,
588588
)
589589
return params, m_step_state
590590

591+
def _check_params(self, params: ParamsLGSSM, num_timesteps: int) -> ParamsLGSSM:
592+
"""Replace None parameters with zeros."""
593+
dynamics, emissions = params.dynamics, params.emissions
594+
is_inhomogeneous = dynamics.weights.ndim == 3
595+
596+
def _zeros_if_none(x, shape):
597+
if x is None:
598+
return jnp.zeros(shape)
599+
return x
600+
601+
shape_prefix = ()
602+
if is_inhomogeneous:
603+
shape_prefix = (num_timesteps - 1,)
604+
605+
clean_dynamics = ParamsLGSSMDynamics(
606+
weights=dynamics.weights,
607+
bias=_zeros_if_none(dynamics.bias, shape=shape_prefix + (self.state_dim,)),
608+
input_weights=_zeros_if_none(
609+
dynamics.input_weights, shape=shape_prefix + (self.state_dim, self.input_dim)
610+
),
611+
cov=dynamics.cov
612+
)
613+
shape_prefix = ()
614+
if is_inhomogeneous:
615+
shape_prefix = (num_timesteps,)
616+
617+
clean_emissions = ParamsLGSSMEmissions(
618+
weights=emissions.weights,
619+
bias=_zeros_if_none(emissions.bias, shape=shape_prefix + (self.emission_dim,)),
620+
input_weights=_zeros_if_none(
621+
emissions.input_weights, shape=shape_prefix + (self.emission_dim, self.input_dim)
622+
),
623+
cov=emissions.cov
624+
)
625+
return ParamsLGSSM(
626+
initial=params.initial,
627+
dynamics=clean_dynamics,
628+
emissions=clean_emissions,
629+
)
630+
631+
591632
def fit_blocked_gibbs(self,
592633
key: PRNGKeyT,
593634
initial_params: ParamsLGSSM,
@@ -599,7 +640,8 @@ def fit_blocked_gibbs(self,
599640
600641
Args:
601642
key: random number key.
602-
initial_params: starting parameters.
643+
initial_params: starting parameters. Include a leading time axis for
644+
the dynamics and emissions parameters in inhomogeneous models.
603645
sample_size: how many samples to draw.
604646
emissions: set of observation sequences.
605647
inputs: optional set of input sequences.
@@ -609,67 +651,97 @@ def fit_blocked_gibbs(self,
609651
"""
610652
num_timesteps = len(emissions)
611653

654+
# Inhomogeneous models have a leading time dimension.
655+
is_inhomogeneous = initial_params.dynamics.weights.ndim == 3
656+
612657
if inputs is None:
613658
inputs = jnp.zeros((num_timesteps, 0))
614659

660+
initial_params = self._check_params(initial_params, num_timesteps)
661+
615662
def sufficient_stats_from_sample(states):
616663
"""Convert samples of states to sufficient statistics."""
617664
inputs_joint = jnp.concatenate((inputs, jnp.ones((num_timesteps, 1))), axis=1)
618665
# Let xn[t] = x[t+1] for t = 0...T-2
619-
x, xp, xn = states, states[:-1], states[1:]
620-
u, up = inputs_joint, inputs_joint[:-1]
666+
x, xn = states, states[1:]
667+
u = inputs_joint
668+
# Let z[t] = [x[t], u[t]] for t = 0...T-1
669+
z = jnp.concatenate([x, u], axis=-1)
670+
# Let zp[t] = [x[t], u[t]] for t = 0...T-2
671+
zp = z[:-1]
621672
y = emissions
622673

623674
init_stats = (x[0], jnp.outer(x[0], x[0]), 1)
624675

625676
# Quantities for the dynamics distribution
626-
# Let zp[t] = [x[t], u[t]] for t = 0...T-2
627-
sum_zpzpT = jnp.block([[xp.T @ xp, xp.T @ up], [up.T @ xp, up.T @ up]])
628-
sum_zpxnT = jnp.block([[xp.T @ xn], [up.T @ xn]])
629-
sum_xnxnT = xn.T @ xn
630-
dynamics_stats = (sum_zpzpT, sum_zpxnT, sum_xnxnT, num_timesteps - 1)
677+
sum_zpzpT = jnp.einsum('ti,tj->tij', zp, zp)
678+
sum_zpxnT = jnp.einsum('ti,tj->tij', zp, xn)
679+
sum_xnxnT = jnp.einsum('ti,tj->tij', xn, xn)
680+
z_is_observed = jnp.ones(num_timesteps - 1)
681+
# The dynamics stats have a leading time dimension.
682+
dynamics_stats = (sum_zpzpT, sum_zpxnT, sum_xnxnT, z_is_observed)
631683
if not self.has_dynamics_bias:
632-
dynamics_stats = (sum_zpzpT[:-1, :-1], sum_zpxnT[:-1, :], sum_xnxnT,
633-
num_timesteps - 1)
684+
dynamics_stats = (sum_zpzpT[:, :-1, :-1], sum_zpxnT[:, :-1, :], sum_xnxnT,
685+
z_is_observed)
634686

635687
# Quantities for the emissions
636-
# Let z[t] = [x[t], u[t]] for t = 0...T-1
637-
sum_zzT = jnp.block([[x.T @ x, x.T @ u], [u.T @ x, u.T @ u]])
638-
sum_zyT = jnp.block([[x.T @ y], [u.T @ y]])
639-
sum_yyT = y.T @ y
640-
emission_stats = (sum_zzT, sum_zyT, sum_yyT, num_timesteps)
688+
sum_zzT = jnp.einsum('ti,tj->tij', z, z)
689+
sum_zyT = jnp.einsum('ti,tj->tij', z, y)
690+
sum_yyT = jnp.einsum('ti,tj->tij', y, y)
691+
y_is_observed = jnp.ones(num_timesteps)
692+
# The emissions stats have a leading time dimension.
693+
emission_stats = (sum_zzT, sum_zyT, sum_yyT, y_is_observed)
641694
if not self.has_emissions_bias:
642-
emission_stats = (sum_zzT[:-1, :-1], sum_zyT[:-1, :], sum_yyT, num_timesteps)
695+
emission_stats = (sum_zzT[:, :-1, :-1], sum_zyT[:, :-1, :], sum_yyT, y_is_observed)
643696

644697
return init_stats, dynamics_stats, emission_stats
645698

646-
def lgssm_params_sample(rng, stats):
647-
"""Sample parameters of the model given sufficient statistics from observed states and emissions."""
648-
init_stats, dynamics_stats, emission_stats = stats
649-
rngs = iter(jr.split(rng, 3))
650-
651-
# Sample the initial params
699+
def _sample_initial_params(rng, init_stats):
652700
initial_posterior = niw_posterior_update(self.initial_prior, init_stats)
653-
S, m = initial_posterior.sample(seed=next(rngs))
701+
S, m = initial_posterior.sample(seed=rng)
702+
return ParamsLGSSMInitial(mean=m, cov=S)
654703

655-
# Sample the dynamics params
704+
def _sample_dynamics_params(rng, dynamics_stats):
656705
dynamics_posterior = mniw_posterior_update(self.dynamics_prior, dynamics_stats)
657-
Q, FB = dynamics_posterior.sample(seed=next(rngs))
706+
Q, FB = dynamics_posterior.sample(seed=rng)
658707
F = FB[:, :self.state_dim]
659708
B, b = (FB[:, self.state_dim:-1], FB[:, -1]) if self.has_dynamics_bias \
660709
else (FB[:, self.state_dim:], jnp.zeros(self.state_dim))
710+
return ParamsLGSSMDynamics(weights=F, bias=b, input_weights=B, cov=Q)
661711

662-
# Sample the emission params
712+
def _sample_emission_params(rng, emission_stats):
663713
emission_posterior = mniw_posterior_update(self.emission_prior, emission_stats)
664-
R, HD = emission_posterior.sample(seed=next(rngs))
714+
R, HD = emission_posterior.sample(seed=rng)
665715
H = HD[:, :self.state_dim]
666716
D, d = (HD[:, self.state_dim:-1], HD[:, -1]) if self.has_emissions_bias \
667717
else (HD[:, self.state_dim:], jnp.zeros(self.emission_dim))
718+
return ParamsLGSSMEmissions(weights=H, bias=d, input_weights=D, cov=R)
719+
720+
def lgssm_params_sample(rng, stats):
721+
"""Sample parameters of the model given sufficient statistics from observed states and emissions."""
722+
init_stats, dynamics_stats, emission_stats = stats
723+
rngs = iter(jr.split(rng, 3))
724+
725+
# Sample the initial params
726+
initial_params = _sample_initial_params(next(rngs), init_stats)
727+
728+
# Sample the dynamics and emission params.
729+
if not is_inhomogeneous:
730+
# Aggregate summary statistics across time for homogeneous model.
731+
dynamics_stats = tree.map(lambda x: jnp.sum(x, axis=0), dynamics_stats)
732+
emission_stats = tree.map(lambda x: jnp.sum(x, axis=0), emission_stats)
733+
dynamics_params = _sample_dynamics_params(next(rngs), dynamics_stats)
734+
emission_params = _sample_emission_params(next(rngs), emission_stats)
735+
else:
736+
keys_dynamics = jr.split(next(rngs), num_timesteps - 1)
737+
keys_emission = jr.split(next(rngs), num_timesteps)
738+
dynamics_params = vmap(_sample_dynamics_params)(keys_dynamics, dynamics_stats)
739+
emission_params = vmap(_sample_emission_params)(keys_emission, emission_stats)
668740

669741
params = ParamsLGSSM(
670-
initial=ParamsLGSSMInitial(mean=m, cov=S),
671-
dynamics=ParamsLGSSMDynamics(weights=F, bias=b, input_weights=B, cov=Q),
672-
emissions=ParamsLGSSMEmissions(weights=H, bias=d, input_weights=D, cov=R)
742+
initial=initial_params,
743+
dynamics=dynamics_params,
744+
emissions=emission_params,
673745
)
674746
return params
675747

dynamax/linear_gaussian_ssm/models_test.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
"""
22
Tests for the linear Gaussian SSM models.
33
"""
4+
from itertools import count, product
45

5-
import pytest
6+
import jax.numpy as jnp
67
import jax.random as jr
8+
from jax import tree
9+
import pytest
710

811
from dynamax.linear_gaussian_ssm import LinearGaussianSSM
912
from dynamax.linear_gaussian_ssm import LinearGaussianConjugateSSM
13+
from dynamax.linear_gaussian_ssm.inference import ParamsLGSSM
1014
from dynamax.utils.utils import monotonically_increasing
1115

1216
NUM_TIMESTEPS = 100
@@ -29,3 +33,52 @@ def test_sample_and_fit(cls, kwargs, inputs):
2933
fitted_params, lps = model.fit_em(params, param_props, emissions, inputs=inputs, num_iters=3)
3034
assert monotonically_increasing(lps)
3135
fitted_params, lps = model.fit_sgd(params, param_props, emissions, inputs=inputs, num_epochs=3)
36+
37+
@pytest.mark.parametrize(["has_dynamics_bias", "has_emissions_bias"], product([True, False], repeat=2))
38+
def test_inhomogeneous_lgcssm(has_dynamics_bias, has_emissions_bias):
39+
"""
40+
Test a LinearGaussianConjugateSSM with time-varying dynamics and emission model.
41+
"""
42+
state_dim = 2
43+
emission_dim = 3
44+
num_timesteps = 4
45+
keys = map(jr.PRNGKey, count())
46+
kwargs = {
47+
"state_dim": state_dim,
48+
"emission_dim": emission_dim,
49+
"has_dynamics_bias": has_dynamics_bias,
50+
"has_emissions_bias": has_emissions_bias,
51+
}
52+
model = LinearGaussianConjugateSSM(**kwargs)
53+
params, param_props = model.initialize(jr.PRNGKey(0))
54+
# Repeat the parameters for each timestep.
55+
inhomogeneous_dynamics = tree.map(
56+
lambda x: jnp.repeat(x[None], num_timesteps - 1, axis=0), params.dynamics,
57+
)
58+
inhomogeneous_emissions = tree.map(
59+
lambda x: jnp.repeat(x[None], num_timesteps, axis=0), params.emissions,
60+
)
61+
62+
_, emissions = model.sample(params, next(keys), num_timesteps=num_timesteps)
63+
inhomogeneous_params = ParamsLGSSM(
64+
initial=params.initial,
65+
dynamics=inhomogeneous_dynamics,
66+
emissions=inhomogeneous_emissions,
67+
)
68+
params_trace = model.fit_blocked_gibbs(
69+
next(keys),
70+
inhomogeneous_params,
71+
sample_size=5,
72+
emissions=emissions,
73+
)
74+
75+
# Arbitrarily check the last set of parameters from the Markov chain.
76+
last_params = tree.map(lambda x: x[-1], params_trace)
77+
assert last_params.initial.mean.shape == (state_dim,)
78+
assert last_params.initial.cov.shape == (state_dim, state_dim)
79+
assert last_params.dynamics.weights.shape == (num_timesteps - 1, state_dim, state_dim)
80+
assert last_params.emissions.weights.shape == (num_timesteps, emission_dim, state_dim)
81+
assert last_params.dynamics.bias.shape == (num_timesteps - 1, state_dim)
82+
assert last_params.emissions.bias.shape == (num_timesteps, emission_dim)
83+
assert last_params.dynamics.cov.shape == (num_timesteps - 1, state_dim, state_dim)
84+
assert last_params.emissions.cov.shape == (num_timesteps, emission_dim, emission_dim)

0 commit comments

Comments
 (0)