Skip to content

Commit d211e86

Browse files
authored
Merge pull request #417 from meta-inf/collapsed-elbo
Fix typo in CollapsedELBO
2 parents 2d2f451 + 13294a4 commit d211e86

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

gpjax/objectives.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,7 @@ def step(
446446

447447
m = variational_family.num_inducing
448448

449-
noise = variational_family.posterior.likelihood.obs_stddev
449+
noise = variational_family.posterior.likelihood.obs_stddev**2
450450
z = variational_family.inducing_inputs
451451
Kzz = kernel.gram(z)
452452
Kzz += cola.ops.I_like(Kzz) * variational_family.jitter

tests/test_objectives.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,12 @@ def test_collapsed_elbo(
175175
assert isinstance(evaluation, jax.Array)
176176
assert evaluation.shape == ()
177177

178-
# with pytest.raises(TypeError):
178+
# Data on the full dataset should be the same as the marginal likelihood
179+
q = gpx.CollapsedVariationalGaussian(posterior=p * likelihood, inducing_inputs=D.X)
180+
mll = ConjugateMLL(negative=negative)
181+
expected_value = mll(p * likelihood, D)
182+
actual_value = negative_elbo(q, D)
183+
assert jnp.abs(actual_value - expected_value) / expected_value < 1e-6
179184

180185

181186
@pytest.mark.parametrize("num_datapoints", [1, 2, 10])

0 commit comments

Comments
 (0)