Skip to content

Commit d9cd2aa

Browse files
edenoclaude
andcommitted
fix: symmetrize covariance matrices before Cholesky decomposition
Added covariance symmetrization in GMM estimation for numerical stability. This ensures exact symmetry before Cholesky decomposition, preventing potential numerical issues from floating-point rounding errors. Changes: - In _estimate_gaussian_covariances_full(): symmetrize with (cov + cov.T) * 0.5 - In _estimate_gaussian_covariances_tied(): symmetrize with (cov + cov.T) * 0.5 - Applied before regularization to ensure exact symmetry Benefits: - Guarantees numerically exact symmetry (error = 0.00e+00) - Prevents rare Cholesky decomposition failures - Best practice for numerical hygiene in covariance estimation Validation: - Tested with synthetic data: all covariances perfectly symmetric - Cholesky decomposition works correctly for all components - No changes to convergence or model quality 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
1 parent 6367144 commit d9cd2aa

File tree

1 file changed

+4
-0
lines changed
  • src/non_local_detector/likelihoods

1 file changed

+4
-0
lines changed

src/non_local_detector/likelihoods/gmm.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ def _estimate_gaussian_covariances_full(
6262
_, n_features = means.shape
6363
diff = X[jnp.newaxis, :, :] - means[:, jnp.newaxis, :] # (K, N, D)
6464
covariances = jax.vmap(lambda r, d, n: ((d.T * r) @ d) / n)(resp.T, diff, nk)
65+
# Symmetrize covariances for numerical stability (ensure exact symmetry)
66+
covariances = (covariances + jnp.swapaxes(covariances, -2, -1)) * 0.5
6567
covariances += jnp.eye(n_features, dtype=X.dtype) * reg_covar
6668
return covariances
6769

@@ -106,6 +108,8 @@ def _estimate_gaussian_covariances_tied(
106108
avg_means2 = (nk * means.T) @ means # (D, D)
107109

108110
covariance = (XTX_w - avg_means2) / jnp.maximum(sum_w, 1.0)
111+
# Symmetrize covariance for numerical stability (ensure exact symmetry)
112+
covariance = (covariance + covariance.T) * 0.5
109113
covariance += jnp.eye(n_features, dtype=X.dtype) * reg_covar
110114
return covariance
111115

0 commit comments

Comments
 (0)