Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 37 additions & 25 deletions pymc_extras/statespace/filters/kalman_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,17 +298,20 @@ def _postprocess_scan_results(self, results, a0, P0, n) -> list[TensorVariable]:
return filter_results

def handle_missing_values(
self, y, Z, H
) -> tuple[TensorVariable, TensorVariable, TensorVariable, float]:
self, y, Z, H, d
) -> tuple[TensorVariable, TensorVariable, TensorVariable, TensorVariable, float]:
"""
Handle missing values in the observation data `y`
Handle missing values in the observation data ``y``.

Adjusts the design matrix `Z` and the observation noise covariance matrix `H` by removing rows and/or columns
associated with the data that is not observed at this iteration. Missing values are replaced with zeros to prevent
propagating NaNs through the computation.
Adjust the design matrix ``Z``, the observation noise covariance matrix ``H``, and the observation
intercept ``d`` by zeroing the rows associated with observations that are missing at this iteration.
The missing entries of ``y`` are replaced with zeros to prevent propagating NaNs through the
computation. With ``y``, ``Z @ a``, and ``d`` all zero on the missing rows, the innovation
:math:`v = y - (Z a + d)` is exactly zero there, so missing observations contribute nothing to the
state update.

Return a binary flag tensor `all_nan_flag`,indicating if all values in the observation data are missing. This
flag is used for numerical adjustments in the update method.
Return a binary flag tensor ``all_nan_flag`` indicating whether every component of the observation
is missing. This flag is used for numerical adjustments in the update method.

Parameters
----------
Expand All @@ -318,21 +321,28 @@ def handle_missing_values(
The design matrix.
H : TensorVariable
The observation noise covariance matrix.
d : TensorVariable
The observation intercept.

Returns
-------
y_masked : TensorVariable
Observation vector with missing values replaced by zeros.

Z_masked: TensorVariable
Design matrix adjusted to exclude the missing states from the information set of observed variables in the
update step
Z_masked : TensorVariable
Design matrix with the rows corresponding to missing observations zeroed out.

H_masked : TensorVariable
Observation noise covariance matrix with the rows *and columns* corresponding to missing
observations zeroed out, so the result remains symmetric.

H_masked: TensorVariable
Noise covariance matrix, adjusted to exclude the missing states
d_masked : TensorVariable
Observation intercept with the entries corresponding to missing observations zeroed out. Without
this masking, missing rows of the innovation become :math:`-d`, injecting a fake observation
into the state update and inflating the log-likelihood by :math:`d^2 / \\text{jitter}`.

all_nan_flag: float
1 if the entire state vector is missing
all_nan_flag : float
1 if every component of the observation is missing.

References
----------
Expand All @@ -344,10 +354,11 @@ def handle_missing_values(
W = pt.diag(pt.bitwise_not(nan_mask).astype(pytensor.config.floatX))

Z_masked = W.dot(Z)
H_masked = W.dot(H)
H_masked = W.dot(H).dot(W.mT)
d_masked = W.dot(d)
y_masked = pt.set_subtensor(y[nan_mask], 0.0)

return y_masked, Z_masked, H_masked, all_nan_flag
return y_masked, Z_masked, H_masked, d_masked, all_nan_flag

@staticmethod
def predict(a, P, c, T, R, Q) -> tuple[TensorVariable, TensorVariable]:
Expand Down Expand Up @@ -517,10 +528,12 @@ def kalman_step(self, *args) -> tuple:
2nd ed, Oxford University Press, 2012.
"""
y, a, P, c, d, T, Z, R, H, Q = self.unpack_args(args)
y_masked, Z_masked, H_masked, all_nan_flag = self.handle_missing_values(y, Z, H)
y_masked, Z_masked, H_masked, d_masked, all_nan_flag = self.handle_missing_values(
y, Z, H, d
)

a_filtered, P_filtered, obs_mu, obs_cov, ll = self.update(
y=y_masked, a=a, d=d, P=P, Z=Z_masked, H=H_masked, all_nan_flag=all_nan_flag
y=y_masked, a=a, d=d_masked, P=P, Z=Z_masked, H=H_masked, all_nan_flag=all_nan_flag
)

P_filtered = stabilize(P_filtered, self.cov_jitter)
Expand Down Expand Up @@ -652,10 +665,7 @@ def update(self, a, P, y, d, Z, H, all_nan_flag):

y_hat = Z.dot(a) + d
v = y - y_hat

H_chol = pytensor.ifelse(
pt.all(pt.eq(H, 0.0)), H, pt.linalg.cholesky(H, lower=True, on_error="nan")
)
H_chol = pt.linalg.cholesky(stabilize(H, self.cov_jitter), lower=True)

# The following notation comes from https://ipnpr.jpl.nasa.gov/progress_report/42-233/42-233A.pdf
# Construct upper-triangular block matrix A = [[chol(H), Z @ L_pred],
Expand Down Expand Up @@ -787,11 +797,13 @@ def kalman_step(self, *args):
y, a, P, c, d, T, Z, R, H, Q = self.unpack_args(args)

nan_mask = pt.or_(pt.isnan(y), pt.eq(y, self.missing_fill_value))
y_masked, Z_masked, H_masked, all_nan_flag = self.handle_missing_values(y, Z, H)
y_masked, Z_masked, H_masked, d_masked, all_nan_flag = self.handle_missing_values(
y, Z, H, d
)

result = pytensor.scan(
self._univariate_inner_filter_step,
sequences=[y_masked, Z_masked, d, pt.diag(H_masked), nan_mask],
sequences=[y_masked, Z_masked, d_masked, pt.diag(H_masked), nan_mask],
outputs_info=[a, P, None, None, None],
name="univariate_inner_scan",
return_updates=False,
Expand Down
56 changes: 56 additions & 0 deletions tests/statespace/filters/test_kalman_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,62 @@ def test_missing_data(filter_name, p, rng):
), f"Shape of {name} does not match expected"


@pytest.mark.parametrize("filter_name", filter_names)
def test_missing_value_with_nonzero_obs_intercept(filter_name, rng):
    """
    Check that a non-zero observation intercept does not leak into missing rows.

    The filter is run twice on data containing missing values: once with the intercept
    ``d`` passed explicitly, and once with ``d`` subtracted from the data and a zero
    intercept passed instead. Every filter output must agree between the two runs.
    """
    p, m, r, n = 3, 5, 1, 10
    inputs = make_test_inputs(p, m, r, n, rng, missing_data=2)
    data, a0, P0, c, d, T, Z, R, H, Q = inputs

    d_nonzero = np.array([1.5, -0.7, 2.1], dtype=floatX)
    d_zero = np.zeros_like(d_nonzero)

    # Reference run: fold the intercept into the observations. NaN - d is still NaN,
    # so the missing-data pattern is unchanged.
    data_absorbed = data - d_nonzero
    reference_filter = get_filter_function(filter_name)
    out_ref = reference_filter(data_absorbed, a0, P0, c, d_zero, T, Z, R, H, Q)

    # Run under test: raw data together with the explicit non-zero intercept.
    intercept_filter = get_filter_function(filter_name)
    out_d = intercept_filter(data, a0, P0, c, d_nonzero, T, Z, R, H, Q)

    for idx, name in enumerate(output_names):
        message = f"{name} differs between (d, y) and (0, y - d) with missing observations"
        assert_allclose(out_d[idx], out_ref[idx], atol=ATOL, rtol=RTOL, err_msg=message)


@pytest.mark.parametrize("filter_name", filter_names)
def test_missing_value_with_nondiagonal_obs_cov(filter_name, rng):
    """
    Check that entries of ``H`` tied to a missing observation do not affect the outputs.

    The second observed series is set entirely to NaN. The filter is then run with a
    non-diagonal ``H`` and again with an ``H`` whose second row and column were zeroed
    by hand. Every filter output must agree between the two runs.
    """
    p, m, r, n = 2, 5, 1, 10
    inputs = make_test_inputs(p, m, r, n, rng)
    data, a0, P0, c, d, T, Z, R, H, Q = inputs

    # Mark the whole second observation series as missing (in place).
    data[:, 1] = np.nan

    H_full = np.array([[1.0, 0.4], [0.4, 1.0]], dtype=floatX)
    H_zeroed = np.array([[1.0, 0.0], [0.0, 0.0]], dtype=floatX)

    full_filter = get_filter_function(filter_name)
    out_full = full_filter(data, a0, P0, c, d, T, Z, R, H_full, Q)

    zeroed_filter = get_filter_function(filter_name)
    out_zeroed = zeroed_filter(data, a0, P0, c, d, T, Z, R, H_zeroed, Q)

    for idx, name in enumerate(output_names):
        message = f"{name} depends on H entries at masked positions"
        assert_allclose(out_full[idx], out_zeroed[idx], atol=ATOL, rtol=RTOL, err_msg=message)


@pytest.mark.parametrize("filter_name", filter_names)
@pytest.mark.parametrize("output_idx", [(0, 2), (3, 5)], ids=["smoothed_states", "smoothed_covs"])
def test_last_smoother_is_last_filtered(filter_name, output_idx, rng):
Expand Down
Loading