Skip to content

Commit 56cf51f

Browse files
authored
Merge pull request #95 from thomaspinder/Fix-conjugate-regression-bug-(single-datapoint)-
Fix conjugate regression bug for a single datapoint.
2 parents 7816533 + 7dda2b7 commit 56cf51f

File tree

3 files changed

+7
-5
lines changed

3 files changed

+7
-5
lines changed

gpjax/gps.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,8 @@ def mll(
234234

235235
constant = jnp.array(-1.0) if negative else jnp.array(1.0)
236236
return constant * (
237-
marginal_likelihood.log_prob(y.squeeze()).squeeze() + log_prior_density
237+
marginal_likelihood.log_prob(jnp.atleast_1d(y.squeeze())).squeeze()
238+
+ log_prior_density
238239
)
239240

240241
return mll

tests/test_abstractions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
tfd = tf.data
1212

1313

14-
@pytest.mark.parametrize("n", [20])
14+
@pytest.mark.parametrize("n", [1, 20])
1515
def test_fit(n):
1616
key = jr.PRNGKey(123)
1717
x = jnp.sort(jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(n, 1)), axis=0)
@@ -27,6 +27,7 @@ def test_fit(n):
2727
assert isinstance(optimised_params, dict)
2828
assert mll(optimised_params) < pre_mll_val
2929

30+
3031
def test_stop_grads():
3132
params = {"x": jnp.array(3.0), "y": jnp.array(4.0)}
3233
trainables = {"x": True, "y": False}

tests/test_gp.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import jax.random as jr
66
import pytest
77

8-
from gpjax import Dataset, initialise, likelihoods, transform
8+
from gpjax import Dataset, initialise, transform
99
from gpjax.gps import (
1010
AbstractGP,
1111
ConjugatePosterior,
@@ -36,7 +36,7 @@ def test_prior(num_datapoints):
3636
assert sigma.shape == (num_datapoints, num_datapoints)
3737

3838

39-
@pytest.mark.parametrize("num_datapoints", [2, 10])
39+
@pytest.mark.parametrize("num_datapoints", [1, 2, 10])
4040
def test_conjugate_posterior(num_datapoints):
4141
key = jr.PRNGKey(123)
4242
x = jnp.sort(
@@ -80,7 +80,7 @@ def test_conjugate_posterior(num_datapoints):
8080
assert sigma.shape == (num_datapoints, num_datapoints)
8181

8282

83-
@pytest.mark.parametrize("num_datapoints", [2, 10])
83+
@pytest.mark.parametrize("num_datapoints", [1, 2, 10])
8484
@pytest.mark.parametrize("likel", NonConjugateLikelihoods)
8585
def test_nonconjugate_posterior(num_datapoints, likel):
8686
key = jr.PRNGKey(123)

0 commit comments

Comments
 (0)