Skip to content

Commit e6b297c

Browse files
committed
Support multi-output y in obs_variance computation
Expand leverage shape (n,) -> (n, 1) for broadcasting with (n, p) residuals in all three conditional classes. The triangular solves and _process_sigma already handle matrix RHS naturally.
1 parent c3c787d commit e6b297c

2 files changed

Lines changed: 27 additions & 0 deletions

File tree

mellon/conditional.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,8 @@ def _compute_obs_variance(self, x, y, mu, cov_func, sigma, jitter, weights):
237237

238238
# Corrected squared residuals (HC3 estimator)
239239
residual = y - prediction
240+
if residual.ndim > h.ndim:
241+
h = h[..., None]
240242
corrected_r2 = residual**2 / (1 - h) ** 2
241243

242244
# Fit second GP to corrected_r2 with noise regularization sigma.
@@ -456,6 +458,8 @@ def _compute_obs_variance(self, x, y, xu, mu, cov_func, sigma, jitter, weights):
456458

457459
# Corrected squared residuals (HC3 estimator)
458460
residual = y - prediction
461+
if residual.ndim > h.ndim:
462+
h = h[..., None]
459463
corrected_r2 = residual**2 / (1 - h) ** 2
460464

461465
# Fit second GP on landmarks to corrected_r2 with noise sigma.
@@ -675,6 +679,8 @@ def _compute_obs_variance(self, x, y, xu, mu, cov_func, sigma, jitter, weights):
675679

676680
# Corrected squared residuals (HC3 estimator)
677681
residual = y - prediction
682+
if residual.ndim > h.ndim:
683+
h = h[..., None]
678684
corrected_r2 = residual**2 / (1 - h) ** 2
679685

680686
# Fit second GP on landmarks to corrected_r2 with noise sigma.

tests/test_leverage.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,27 @@ def test_estimator_get_obs_variance(setup_data):
247247
)
248248

249249

250+
def test_obs_variance_multi_output():
251+
"""obs_variance should work with multi-output y of shape (n, p)."""
252+
n, d, p = 80, 3, 4
253+
key = jax.random.PRNGKey(42)
254+
k1, k2 = jax.random.split(key)
255+
X = jax.random.normal(k1, (n, d))
256+
y = jax.random.normal(k2, (n, p))
257+
258+
# Full GP
259+
est = mellon.FunctionEstimator(sigma=1.0, n_landmarks=0, obs_variance=True)
260+
est.fit(X, y)
261+
var = est.predict.obs_variance(X)
262+
assert var.shape == (n, p), f"Full GP: expected ({n}, {p}), got {var.shape}"
263+
264+
# Sparse GP
265+
est_s = mellon.FunctionEstimator(sigma=1.0, n_landmarks=20, obs_variance=True)
266+
est_s.fit(X, y)
267+
var_s = est_s.predict.obs_variance(X)
268+
assert var_s.shape == (n, p), f"Sparse GP: expected ({n}, {p}), got {var_s.shape}"
269+
270+
250271
def test_fit_obs_variance_override(setup_data):
251272
"""fit(x, y, obs_variance=True) should override constructor default."""
252273
X, y = setup_data

0 commit comments

Comments
 (0)