Skip to content

Add partial and complete missing emission support in LGSSM filtering #406

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 41 additions & 4 deletions dynamax/linear_gaussian_ssm/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,22 @@ def _predict(prior_mean: Float[Array, "state_dim"],
return mu_pred, Sigma_pred


def _mask_emission_weights_and_covar(H_t, R_t, is_missing):
"""
Handle partial + full missing emissions by redefining the weights and covariance.
"""
P_obs = jnp.diag(1 - is_missing.astype(int)) # Projects onto observed emissions.
P_mis = jnp.diag(is_missing.astype(int)) # Projects onto missing emissions.

H = P_obs @ H_t

epsilon = 1e-8 # Prevent matrix from becoming singular.
if R_t.ndim == 2:
R = P_obs @ R_t @ P_obs + epsilon * P_mis
else:
R = P_obs @ R_t + epsilon * is_missing.astype(int)
return H, R

def _condition_on(prior_mean: Float[Array, "state_dim"],
prior_cov: Float[Array, "state_dim state_dim"],
emission_matrix: Float[Array, "emission_dim state_dim"],
Expand All @@ -279,6 +295,13 @@ def _condition_on(prior_mean: Float[Array, "state_dim"],
mu_pred (D_hid,): predicted mean.
Sigma_pred (D_hid,D_hid): predicted covariance.
"""
is_missing = jnp.isnan(emission)
dummy_value = 0.0 # Any finite value will do, as long as 0 * dummy != nan.
emission = jnp.where(~is_missing, emission, dummy_value)
emission_matrix, emission_cov = _mask_emission_weights_and_covar(
emission_matrix, emission_cov, is_missing,
)

if emission_cov.ndim == 2:
S = emission_cov + emission_matrix @ prior_cov @ emission_matrix.T
K = psd_solve(S, emission_matrix @ prior_cov).T
Expand Down Expand Up @@ -465,7 +488,7 @@ def lgssm_filter(params: ParamsLGSSM,

Args:
params: model parameters
emissions: array of observations.
emissions: array of observations. Values set to NaN are considered missing.
inputs: optional array of inputs.

Returns:
Expand All @@ -477,13 +500,27 @@ def lgssm_filter(params: ParamsLGSSM,

def _log_likelihood(pred_mean, pred_cov, H, D, d, R, u, y):
"""Compute the log likelihood of an observation under a linear Gaussian model."""
is_missing = jnp.isnan(y)
missing_mask = is_missing.astype(int)
n_missing = sum(is_missing)
H, R = _mask_emission_weights_and_covar(H, R, is_missing)

m = H @ pred_mean + D @ u + d

# Fill missing values with mean, so that the (y - m) term in the exponent is 0.
y_filled = jnp.where(is_missing, m, y)
# We set variance of missing values to one, so that the normalization constant
# equals sqrt[2pi].
variance_mis = 1.0
# We than substract the normalization constants from the missing (marginalized)
# observations.
correction = -n_missing / 2 * jnp.log(2 * jnp.pi)
if R.ndim==2:
S = R + H @ pred_cov @ H.T
return MVN(m, S).log_prob(y)
S = R + H @ pred_cov @ H.T + variance_mis * jnp.diag(missing_mask)
return MVN(m, S).log_prob(y_filled) - correction
else:
L = H @ jnp.linalg.cholesky(pred_cov)
return MVNLowRank(m, R, L).log_prob(y)
return MVNLowRank(m, R + variance_mis * missing_mask, L).log_prob(y_filled) - correction


def _step(carry, t):
Expand Down
174 changes: 171 additions & 3 deletions dynamax/linear_gaussian_ssm/inference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,14 @@
import jax.numpy as jnp
import tensorflow_probability.substrates.jax.distributions as tfd

from itertools import count
from functools import partial
from dynamax.linear_gaussian_ssm import LinearGaussianSSM
from dynamax.linear_gaussian_ssm.inference import _get_params, _predict
from dynamax.linear_gaussian_ssm import (
LinearGaussianSSM, ParamsLGSSM, ParamsLGSSMEmissions, lgssm_filter, lgssm_joint_sample,
)
from dynamax.utils.utils import has_tpu
from jax import vmap
from jax import tree, vmap
from jax import random as jr

# Use different tolerance threshold for TPU
Expand Down Expand Up @@ -205,8 +209,172 @@ def test_kalman_vs_joint(self):
assert allclose(self.ssm_posterior_diag.smoothed_means, joint_means)
assert allclose(self.ssm_posterior_diag.smoothed_covariances, joint_covs)


def test_posterior_samples(self):
"""Test that posterior samples match the mean of the smoother"""
monte_carlo_var = vmap(jnp.diag)(self.posterior.smoothed_covariances) / self.num_samples
assert jnp.all(abs(jnp.mean(self.samples, axis=0) - self.posterior.smoothed_means) < 6 * jnp.sqrt(monte_carlo_var))

def _random_positive_definite_matrix(key, n):
"""Generate a matrix eligibly to use as a covariance matrix."""
Q0 = jr.normal(key, shape=[n, n])
Q_sym = (Q0 + Q0.T)/2
I = jnp.eye(n)
return Q_sym + n * I

def make_dynamic_lgssm_params(num_timesteps, latent_dim=2, observation_dim=4, seed=0):
"""Create a time-varying LGSSM with time-varying parameters."""
key_seq = map(jr.key, count(seed))

F = jr.normal(next(key_seq), shape=[num_timesteps - 1, latent_dim, latent_dim])
keys = jr.split(next(key_seq), num=num_timesteps - 1)
Q = vmap(partial(_random_positive_definite_matrix, n=latent_dim))(keys)

H = jr.normal(next(key_seq), shape=[num_timesteps, observation_dim, latent_dim])
keys = jr.split(next(key_seq), num=num_timesteps)
R = vmap(partial(_random_positive_definite_matrix, n=observation_dim))(keys)
b = jnp.zeros([num_timesteps - 1, latent_dim])
d = jnp.zeros([num_timesteps, observation_dim])
D = jnp.zeros([num_timesteps, observation_dim, 0])

μ0 = jnp.zeros(latent_dim)
Σ0 = jnp.eye(latent_dim)

lgssm = LinearGaussianSSM(latent_dim, observation_dim)
params, _ = lgssm.initialize(next(key_seq),
initial_mean=μ0,
initial_covariance=Σ0,
dynamics_weights=F,
dynamics_bias=b,
dynamics_covariance=Q,
emission_weights=H,
emission_bias=d,
emission_input_weights=D,
emission_covariance=R)
return params, lgssm

class TestFilterMissingness:
"""
Test filtering with partial and full missing emissions.
"""

num_timesteps = 6
key = jr.PRNGKey(1)

params, lgssm = make_dynamic_lgssm_params(num_timesteps, latent_dim=2, observation_dim=4)
_, emissions = lgssm_joint_sample(params, key, num_timesteps)

def _make_emission_covar_params_diagonal(self, params):
emissions_covar = jnp.diagonal(params.emissions.cov, axis1=1, axis2=2)
params_emissions = ParamsLGSSMEmissions(
params.emissions.weights,
params.emissions.bias,
params.emissions.input_weights,
emissions_covar,
)
return ParamsLGSSM(
params.initial, params.dynamics, params_emissions,
)

@pytest.mark.parametrize("use_diagonal_emissions_covar", [True, False])
def test_partial_missing_observations(self, use_diagonal_emissions_covar):
"""
Test missing subvector of emissions.

The following two cases should be equivalent.
i) The same subvector of emissions is missing in all time points.
ii) A measurement model corresponding to the (observed) subvector, with all
emissions completely observed.
"""
# Index 1 and 3 are missing. Represent by nan.
is_observed = jnp.array([True, False, True, False])


# Method i)
y_partial_observed = jnp.where(is_observed[jnp.newaxis,:], self.emissions, jnp.nan)
params = self.params
if use_diagonal_emissions_covar:
params = self._make_emission_covar_params_diagonal(self.params)
posterior_method_i = lgssm_filter(params, y_partial_observed)

# Method ii)
y_subvector = self.emissions[:, is_observed]
params_emissions = self.params.emissions
sub_cov = params_emissions.cov[:, is_observed][...,is_observed]
params_emissions_subvector = ParamsLGSSMEmissions(
weights=params_emissions.weights[:, is_observed],
bias=params_emissions.bias[:, is_observed],
input_weights=params_emissions.input_weights[:, is_observed],
cov=sub_cov,
)
params_subvector = ParamsLGSSM(
initial=self.params.initial,
dynamics=self.params.dynamics,
emissions=params_emissions_subvector
)
if use_diagonal_emissions_covar:
params_subvector = self._make_emission_covar_params_diagonal(params_subvector)
posterior_method_ii = lgssm_filter(params_subvector, y_subvector)

# Both methods must yield identical results.
is_close = tree.map(allclose, posterior_method_i, posterior_method_ii)
assert tree.all(is_close)

@pytest.mark.parametrize("use_diagonal_emissions_covar", [True, False])
def test_full_missing(self, use_diagonal_emissions_covar):
"""
Test that a full missing emission skips the update step.

The forwards filtering step consists of two steps:
1) Predict the next state.
2) Update the state using the observation.
When an emission is completely missing, the Bayesian update corresponds to
skipping 2).
"""
t_missing = 2
y_mid_missing = self.emissions.at[t_missing].set(jnp.nan)
params = self.params
if use_diagonal_emissions_covar:
params = self._make_emission_covar_params_diagonal(self.params)
posterior = lgssm_filter(params, y_mid_missing)

not_na = tree.map(lambda x: jnp.all(~jnp.isnan(x)), posterior)
assert tree.all(not_na)

# Predict the next state.
filtered_mean_tm1 = posterior.filtered_means[t_missing - 1]
filtered_cov_tm1 = posterior.filtered_covariances[t_missing - 1]
F, B, b, Q, *_ = _get_params(params, self.num_timesteps, t_missing - 1)
u = jnp.zeros(B.shape[1:])
pred_mean_t, pred_cov_t = _predict(filtered_mean_tm1, filtered_cov_tm1, F, B, b, Q, u)

# Check that the filtered mean is corresponds to skipping the update step.
filtered_mean_t = posterior.filtered_means[t_missing]
filtered_cov_t = posterior.filtered_covariances[t_missing]
assert allclose(filtered_mean_t, pred_mean_t)
assert allclose(filtered_cov_t, pred_cov_t)

def test_log_likelihood_last_missing(self):
"""Test the log-likelihood with a missing emission.

When emission y[t] is completely missing, p(y[1:t]) = p(y[1:t-1]).
"""
def _trim_params(params):
return ParamsLGSSM(
params.initial,
tree.map(lambda x: x[:-1], params.dynamics),
tree.map(lambda x: x[:-1], params.emissions),
)

# Method i):
# Compute the log-likelihood of a fully observed sequence not including the last
# emission.
t_missing = -1
params_Tmin1 = _trim_params(self.params)
posterior_Tmin1 = lgssm_filter(params_Tmin1, self.emissions[:-1])

# Method ii):
# Compute the log-likelihood of a sequence with the last emission missing.
y_last_missing = self.emissions.at[t_missing].set(jnp.nan)
posterior_last_missing = lgssm_filter(self.params, y_last_missing)

allclose(posterior_last_missing.marginal_loglik, posterior_Tmin1.marginal_loglik)