diff --git a/dynamax/linear_gaussian_ssm/inference.py b/dynamax/linear_gaussian_ssm/inference.py index 03174d0f..456b4d24 100644 --- a/dynamax/linear_gaussian_ssm/inference.py +++ b/dynamax/linear_gaussian_ssm/inference.py @@ -255,6 +255,22 @@ def _predict(prior_mean: Float[Array, "state_dim"], return mu_pred, Sigma_pred +def _mask_emission_weights_and_covar(H_t, R_t, is_missing): + """ + Handle partially and fully missing emissions: zero the weights and covariance of missing entries and put a small epsilon on their covariance diagonal. + """ + P_obs = jnp.diag(1 - is_missing.astype(int)) # Projects onto observed emissions. + P_mis = jnp.diag(is_missing.astype(int)) # Projects onto missing emissions. + + H = P_obs @ H_t + + epsilon = 1e-8 # Keeps the masked covariance non-singular on the missing entries. + if R_t.ndim == 2: + R = P_obs @ R_t @ P_obs + epsilon * P_mis + else: + R = P_obs @ R_t + epsilon * is_missing.astype(int) + return H, R + def _condition_on(prior_mean: Float[Array, "state_dim"], prior_cov: Float[Array, "state_dim state_dim"], emission_matrix: Float[Array, "emission_dim state_dim"], @@ -279,6 +295,13 @@ def _condition_on(prior_mean: Float[Array, "state_dim"], mu_pred (D_hid,): predicted mean. Sigma_pred (D_hid,D_hid): predicted covariance. """ + is_missing = jnp.isnan(emission) + dummy_value = 0.0 # Any finite value will do, as long as 0 * dummy != nan. + emission = jnp.where(~is_missing, emission, dummy_value) + emission_matrix, emission_cov = _mask_emission_weights_and_covar( + emission_matrix, emission_cov, is_missing, + ) + if emission_cov.ndim == 2: S = emission_cov + emission_matrix @ prior_cov @ emission_matrix.T K = psd_solve(S, emission_matrix @ prior_cov).T @@ -465,7 +488,7 @@ def lgssm_filter(params: ParamsLGSSM, Args: params: model parameters - emissions: array of observations. + emissions: array of observations. Values set to NaN are considered missing. inputs: optional array of inputs. 
Returns: @@ -477,13 +500,27 @@ def lgssm_filter(params: ParamsLGSSM, def _log_likelihood(pred_mean, pred_cov, H, D, d, R, u, y): """Compute the log likelihood of an observation under a linear Gaussian model.""" + is_missing = jnp.isnan(y) + missing_mask = is_missing.astype(int) + n_missing = jnp.sum(is_missing) + H, R = _mask_emission_weights_and_covar(H, R, is_missing) + m = H @ pred_mean + D @ u + d + + # Fill missing values with mean, so that the (y - m) term in the exponent is 0. + y_filled = jnp.where(is_missing, m, y) + # We set variance of missing values to one, so that the normalization constant + # equals sqrt[2pi]. + variance_mis = 1.0 + # We then remove the normalization constants contributed by the missing (marginalized) + # observations. + correction = -n_missing / 2 * jnp.log(2 * jnp.pi) if R.ndim==2: - S = R + H @ pred_cov @ H.T - return MVN(m, S).log_prob(y) + S = R + H @ pred_cov @ H.T + variance_mis * jnp.diag(missing_mask) + return MVN(m, S).log_prob(y_filled) - correction else: L = H @ jnp.linalg.cholesky(pred_cov) - return MVNLowRank(m, R, L).log_prob(y) + return MVNLowRank(m, R + variance_mis * missing_mask, L).log_prob(y_filled) - correction def _step(carry, t): diff --git a/dynamax/linear_gaussian_ssm/inference_test.py b/dynamax/linear_gaussian_ssm/inference_test.py index a1c89e06..31225129 100644 --- a/dynamax/linear_gaussian_ssm/inference_test.py +++ b/dynamax/linear_gaussian_ssm/inference_test.py @@ -7,10 +7,14 @@ import jax.numpy as jnp import tensorflow_probability.substrates.jax.distributions as tfd +from itertools import count from functools import partial -from dynamax.linear_gaussian_ssm import LinearGaussianSSM +from dynamax.linear_gaussian_ssm.inference import _get_params, _predict +from dynamax.linear_gaussian_ssm import ( + LinearGaussianSSM, ParamsLGSSM, ParamsLGSSMEmissions, lgssm_filter, lgssm_joint_sample, +) from dynamax.utils.utils import has_tpu -from jax import vmap +from jax import tree, vmap from jax import random as jr # Use 
different tolerance threshold for TPU @@ -205,8 +209,172 @@ def test_kalman_vs_joint(self): assert allclose(self.ssm_posterior_diag.smoothed_means, joint_means) assert allclose(self.ssm_posterior_diag.smoothed_covariances, joint_covs) - def test_posterior_samples(self): """Test that posterior samples match the mean of the smoother""" monte_carlo_var = vmap(jnp.diag)(self.posterior.smoothed_covariances) / self.num_samples assert jnp.all(abs(jnp.mean(self.samples, axis=0) - self.posterior.smoothed_means) < 6 * jnp.sqrt(monte_carlo_var)) + +def _random_positive_definite_matrix(key, n): + """Generate a random symmetric positive definite matrix, eligible for use as a covariance matrix.""" + Q0 = jr.normal(key, shape=[n, n]) + Q_psd = Q0 @ Q0.T # Symmetric positive semi-definite by construction. + I = jnp.eye(n) + return Q_psd + n * I # Adding n * I makes the result strictly positive definite. + +def make_dynamic_lgssm_params(num_timesteps, latent_dim=2, observation_dim=4, seed=0): + """Create an LGSSM whose dynamics and emission parameters are indexed by time; returns (params, lgssm).""" + key_seq = map(jr.key, count(seed)) + + F = jr.normal(next(key_seq), shape=[num_timesteps - 1, latent_dim, latent_dim]) + keys = jr.split(next(key_seq), num=num_timesteps - 1) + Q = vmap(partial(_random_positive_definite_matrix, n=latent_dim))(keys) + + H = jr.normal(next(key_seq), shape=[num_timesteps, observation_dim, latent_dim]) + keys = jr.split(next(key_seq), num=num_timesteps) + R = vmap(partial(_random_positive_definite_matrix, n=observation_dim))(keys) + b = jnp.zeros([num_timesteps - 1, latent_dim]) + d = jnp.zeros([num_timesteps, observation_dim]) + D = jnp.zeros([num_timesteps, observation_dim, 0]) + + μ0 = jnp.zeros(latent_dim) + Σ0 = jnp.eye(latent_dim) + + lgssm = LinearGaussianSSM(latent_dim, observation_dim) + params, _ = lgssm.initialize(next(key_seq), + initial_mean=μ0, + initial_covariance=Σ0, + dynamics_weights=F, + dynamics_bias=b, + dynamics_covariance=Q, + emission_weights=H, + emission_bias=d, + emission_input_weights=D, + emission_covariance=R) + return params, lgssm + +class TestFilterMissingness: + """ + Test filtering with partial and full 
missing emissions. + """ + + num_timesteps = 6 + key = jr.PRNGKey(1) + + params, lgssm = make_dynamic_lgssm_params(num_timesteps, latent_dim=2, observation_dim=4) + _, emissions = lgssm_joint_sample(params, key, num_timesteps) + + def _make_emission_covar_params_diagonal(self, params): + emissions_covar = jnp.diagonal(params.emissions.cov, axis1=1, axis2=2) + params_emissions = ParamsLGSSMEmissions( + params.emissions.weights, + params.emissions.bias, + params.emissions.input_weights, + emissions_covar, + ) + return ParamsLGSSM( + params.initial, params.dynamics, params_emissions, + ) + + @pytest.mark.parametrize("use_diagonal_emissions_covar", [True, False]) + def test_partial_missing_observations(self, use_diagonal_emissions_covar): + """ + Test filtering when a subvector of the emissions is missing. + + The following two cases should be equivalent. + i) The same subvector of emissions is missing at all time points. + ii) A measurement model corresponding to the (observed) subvector, with all + emissions completely observed. + """ + # Indices 1 and 3 are missing; they are represented by NaN. 
+ is_observed = jnp.array([True, False, True, False]) + + + # Method i): replace the unobserved entries by NaN and filter with the full model. + y_partial_observed = jnp.where(is_observed[jnp.newaxis,:], self.emissions, jnp.nan) + params = self.params + if use_diagonal_emissions_covar: + params = self._make_emission_covar_params_diagonal(self.params) + posterior_method_i = lgssm_filter(params, y_partial_observed) + + # Method ii): filter with a reduced emission model restricted to the observed entries. + y_subvector = self.emissions[:, is_observed] + params_emissions = self.params.emissions + sub_cov = params_emissions.cov[:, is_observed][...,is_observed] + params_emissions_subvector = ParamsLGSSMEmissions( + weights=params_emissions.weights[:, is_observed], + bias=params_emissions.bias[:, is_observed], + input_weights=params_emissions.input_weights[:, is_observed], + cov=sub_cov, + ) + params_subvector = ParamsLGSSM( + initial=self.params.initial, + dynamics=self.params.dynamics, + emissions=params_emissions_subvector + ) + if use_diagonal_emissions_covar: + params_subvector = self._make_emission_covar_params_diagonal(params_subvector) + posterior_method_ii = lgssm_filter(params_subvector, y_subvector) + + # Both methods must agree on every field of the posterior (up to tolerance). + is_close = tree.map(allclose, posterior_method_i, posterior_method_ii) + assert tree.all(is_close) + + @pytest.mark.parametrize("use_diagonal_emissions_covar", [True, False]) + def test_full_missing(self, use_diagonal_emissions_covar): + """ + Test that a fully missing emission skips the update step. + + Each forward filtering step consists of two stages: + 1) Predict the next state. + 2) Update the state using the observation. + When an emission is completely missing, the Bayesian update corresponds to + skipping 2). 
+ """ + t_missing = 2 + y_mid_missing = self.emissions.at[t_missing].set(jnp.nan) + params = self.params + if use_diagonal_emissions_covar: + params = self._make_emission_covar_params_diagonal(self.params) + posterior = lgssm_filter(params, y_mid_missing) + + not_na = tree.map(lambda x: jnp.all(~jnp.isnan(x)), posterior) + assert tree.all(not_na) + + # Predict the next state. + filtered_mean_tm1 = posterior.filtered_means[t_missing - 1] + filtered_cov_tm1 = posterior.filtered_covariances[t_missing - 1] + F, B, b, Q, *_ = _get_params(params, self.num_timesteps, t_missing - 1) + u = jnp.zeros(B.shape[1:]) + pred_mean_t, pred_cov_t = _predict(filtered_mean_tm1, filtered_cov_tm1, F, B, b, Q, u) + + # Check that the filtered mean and covariance correspond to skipping the update step. + filtered_mean_t = posterior.filtered_means[t_missing] + filtered_cov_t = posterior.filtered_covariances[t_missing] + assert allclose(filtered_mean_t, pred_mean_t) + assert allclose(filtered_cov_t, pred_cov_t) + + def test_log_likelihood_last_missing(self): + """Test the log-likelihood with a missing emission. + + When emission y[t] is completely missing, p(y[1:t]) = p(y[1:t-1]). + """ + def _trim_params(params): + return ParamsLGSSM( + params.initial, + tree.map(lambda x: x[:-1], params.dynamics), + tree.map(lambda x: x[:-1], params.emissions), + ) + + # Method i): + # Compute the log-likelihood of a fully observed sequence not including the last + # emission. + t_missing = -1 + params_Tmin1 = _trim_params(self.params) + posterior_Tmin1 = lgssm_filter(params_Tmin1, self.emissions[:-1]) + + # Method ii): + # Compute the log-likelihood of a sequence with the last emission missing. + y_last_missing = self.emissions.at[t_missing].set(jnp.nan) + posterior_last_missing = lgssm_filter(self.params, y_last_missing) + + assert allclose(posterior_last_missing.marginal_loglik, posterior_Tmin1.marginal_loglik)