
Commit 24fbde8

Support priors
1 parent 97cc8c8 commit 24fbde8

10 files changed: 183 additions, 35 deletions


gpjax/objectives/mlls.py

Lines changed: 10 additions & 9 deletions
@@ -7,7 +7,7 @@
 from ..gps import ConjugatePosterior, NonConjugatePosterior
 from ..kernels import gram
 from ..likelihoods import link_function
-from ..parameters.prior_densities import log_density
+from ..parameters.priors import evaluate_prior, prior_checks
 from ..parameters.transforms import (SoftplusTransformation, Transformation,
                                      untransform)
 from ..types import Array
@@ -28,17 +28,17 @@ def marginal_ll(
     Returns: A multivariate normal distribution
     """
 
-    def mll(params: dict, x: Array, y: Array):
+    def mll(params: dict, x: Array, y: Array, priors: dict = None):
         params = untransform(params, transformation)
         mu = gp.prior.mean_function(x)
         gram_matrix = params["variance"] * gram(gp.prior.kernel, x / params["lengthscale"])
         gram_matrix += params["obs_noise"] * I(x.shape[0])
         L = jnp.linalg.cholesky(gram_matrix)
         random_variable = tfd.MultivariateNormalTriL(mu, L)
-        # TODO: Attach log-prior density sum here
-        constant = jnp.array(-1.0) if negative else jnp.array(1.0)
-        return constant * random_variable.log_prob(y.squeeze()).mean()
 
+        log_prior_density = evaluate_prior(params, priors)
+        constant = jnp.array(-1.0) if negative else jnp.array(1.0)
+        return constant * (random_variable.log_prob(y.squeeze()).mean() + log_prior_density)
     return mll
 
 
@@ -49,7 +49,7 @@ def marginal_ll(
     negative: bool = False,
     jitter: float = 1e-6,
 ) -> Callable:
-    def mll(params: dict, x: Array, y: Array):
+    def mll(params: dict, x: Array, y: Array, priors: dict = {'latent': tfd.Normal(loc=0., scale=1.)}):
         params = untransform(params, transformation)
         n = x.shape[0]
         link = link_function(gp.likelihood)
@@ -59,9 +59,10 @@ def mll(params: dict, x: Array, y: Array):
         F = jnp.matmul(L, params["latent"])
         rv = link(F)
         ll = jnp.sum(rv.log_prob(y))
-        # TODO: Attach full log-prior density sum here
-        latent_prior = jnp.sum(log_density(params["latent"], tfd.Normal(loc=0.0, scale=1.0)))
+
+        priors = prior_checks(gp, priors)
+        log_prior_density = evaluate_prior(params, priors)
         constant = jnp.array(-1.0) if negative else jnp.array(1.0)
-        return constant * (ll + latent_prior)
+        return constant * (ll + log_prior_density)
 
     return mll
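
For orientation, a minimal sketch of how the new priors argument is exercised in the conjugate case, mirroring test_prior_mll further down this diff. The Gaussian import path is an assumption (the test's own import lines are not shown in this commit):

    import jax.numpy as jnp
    import jax.random as jr
    from tensorflow_probability.substrates.jax import distributions as tfd

    from gpjax.gps import Prior
    from gpjax.kernels import RBF
    from gpjax.likelihoods import Gaussian         # assumed export path
    from gpjax.objectives.mlls import marginal_ll  # module shown in this hunk
    from gpjax.parameters import initialise

    key = jr.PRNGKey(123)
    x = jnp.sort(jr.uniform(key, minval=-5.0, maxval=5.0, shape=(100, 1)), axis=0)
    y = jnp.sin(x)

    posterior = Prior(kernel=RBF()) * Gaussian()
    mll = marginal_ll(posterior)
    params = initialise(posterior)

    # priors=None (the default) dispatches to the zero-returning overload of
    # evaluate_prior, so pre-existing behaviour is unchanged.
    unpenalised = mll(params, x, y)

    # A dict of TFP distributions keyed by parameter name adds the summed
    # log-prior density to the returned objective.
    priors = {"lengthscale": tfd.Gamma(1.0, 1.0), "obs_noise": tfd.Gamma(2.0, 2.0)}
    penalised = mll(params, x, y, priors)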

gpjax/parameters/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 from .base import complete, initialise
-from .prior_densities import log_density
+from .priors import log_density
 from .transforms import (IdentityTransformation, SoftplusTransformation,
                          transform, untransform)

gpjax/parameters/prior_densities.py

Lines changed: 0 additions & 12 deletions
This file was deleted.

gpjax/parameters/priors.py

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+import jax.numpy as jnp
+from jax.interpreters.ad import JVPTracer
+from jax.interpreters.partial_eval import DynamicJaxprTracer
+from multipledispatch import dispatch
+from tensorflow_probability.substrates.jax import distributions as tfd
+from ..gps import NonConjugatePosterior
+import warnings
+
+from ..types import Array, NoneType
+
+
+@dispatch((jnp.DeviceArray, JVPTracer, DynamicJaxprTracer), tfd.Distribution)
+def log_density(param: jnp.DeviceArray, density: tfd.Distribution) -> Array:
+    return density.log_prob(param)
+
+
+@dispatch(dict, NoneType)
+def evaluate_prior(params: dict, priors: dict) -> Array:
+    return jnp.array(0.)
+
+
+@dispatch(dict, dict)
+def evaluate_prior(params: dict, priors: dict) -> Array:
+    lpd = jnp.array(0)
+    for param, val in priors.items():
+        lpd += jnp.sum(log_density(params[param], priors[param]))
+    return lpd
+
+
+@dispatch(NonConjugatePosterior, dict)
+def prior_checks(gp: NonConjugatePosterior, priors: dict) -> dict:
+    if 'latent' in priors.keys():
+        latent_prior = priors['latent']
+        if latent_prior.name != 'Normal':
+            warnings.warn(f'A {latent_prior.name} distribution prior has been placed on the latent function. It is strongly advised that a unit-Gaussian prior is used.')
+        return priors
+    else:
+        priors['latent'] = tfd.Normal(loc=0., scale=1.)
+        return priors
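
A short sketch of how the three dispatched functions above interact; Prior, RBF, and Bernoulli are taken from this commit's tests, and the assertions follow directly from the module:

    import jax.numpy as jnp
    from tensorflow_probability.substrates.jax import distributions as tfd

    from gpjax.gps import Prior
    from gpjax.kernels import RBF
    from gpjax.likelihoods import Bernoulli
    from gpjax.parameters.priors import evaluate_prior, prior_checks

    params = {"lengthscale": jnp.array([1.0])}

    # (dict, NoneType) resolves to the no-op overload: zero contribution.
    assert evaluate_prior(params, None) == 0.0

    # (dict, dict) sums log_density over the keys present in `priors`;
    # parameters with no prior attached simply contribute nothing.
    lpd = evaluate_prior(params, {"lengthscale": tfd.Gamma(1.0, 1.0)})

    # On a non-conjugate posterior, prior_checks injects the unit-Gaussian
    # latent prior whenever the user has not supplied one.
    posterior = Prior(kernel=RBF()) * Bernoulli()
    checked = prior_checks(posterior, {"lengthscale": tfd.Gamma(1.0, 1.0)})
    assert "latent" in checked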

gpjax/types.py

Lines changed: 1 addition & 0 deletions
@@ -6,3 +6,4 @@
 
 # Array = Union[jnp.ndarray, ShardedDeviceArray, jnp.DeviceArray] # Cannot currently dispatch on a Union type
 # Data = Tuple[Array, Array]
+NoneType = type(None)
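
NoneType is added because multipledispatch resolves overloads on argument types, and the type of a literal None is type(None). A self-contained illustration, independent of gpjax:

    from multipledispatch import dispatch

    NoneType = type(None)

    @dispatch(dict, NoneType)
    def describe(params, priors):
        return "no priors supplied"

    @dispatch(dict, dict)
    def describe(params, priors):
        return f"{len(priors)} prior(s) supplied"

    print(describe({}, None))       # -> no priors supplied
    print(describe({}, {"a": 1}))   # -> 1 prior(s) supplied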

tests/objectives/test_mlls.py

Lines changed: 26 additions & 0 deletions
@@ -4,8 +4,10 @@
 from gpjax.kernels import RBF
 from gpjax.parameters import transform, SoftplusTransformation, initialise
 import jax.numpy as jnp
+import jax.random as jr
 import pytest
 from typing import Callable
+from tensorflow_probability.substrates.jax import distributions as tfd
 
 
 def test_conjugate():
@@ -29,3 +31,27 @@ def test_non_conjugate():
     y = jnp.sin(x)
     params = transform(params=initialise(posterior, n), transformation=SoftplusTransformation)
     assert neg_mll(params, x, y) == jnp.array(-1.)*mll(params, x, y)
+
+
+def test_prior_mll():
+    """
+    Test that the MLL evaluation works with priors attached to the parameter values.
+    """
+    key = jr.PRNGKey(123)
+    x = jnp.sort(jr.uniform(key, minval=-5.0, maxval=5.0, shape=(100, 1)), axis=0)
+    f = lambda x: jnp.sin(jnp.pi * x) / (jnp.pi * x)
+    y = f(x) + jr.normal(key, shape=x.shape) * 0.1
+    posterior = Prior(kernel=RBF()) * Gaussian()
+    mll = marginal_ll(posterior)
+
+    params = initialise(posterior)
+    priors = {
+        "lengthscale": tfd.Gamma(1.0, 1.0),
+        "variance": tfd.Gamma(2.0, 2.0),
+        "obs_noise": tfd.Gamma(2.0, 2.0),
+    }
+    mll_eval = mll(params, x, y)
+    mll_eval_priors = mll(params, x, y, priors)
+
+    assert pytest.approx(mll_eval) == jnp.array(-115.72332969)
+    assert pytest.approx(mll_eval_priors) == jnp.array(-118.97202259)
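
Note that, by construction of the new conjugate mll, the gap between the two asserted values, -118.97202259 - (-115.72332969) = -3.2486929, is exactly the summed log-prior density that evaluate_prior returns at the initialised parameter values.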

tests/parameters/test_prior_densities.py

Lines changed: 0 additions & 12 deletions
This file was deleted.

tests/parameters/test_priors.py

Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
+from gpjax.parameters import log_density
+from gpjax.parameters.priors import evaluate_prior, prior_checks
+from gpjax.gps import Prior
+from gpjax.kernels import RBF
+from gpjax.likelihoods import Bernoulli
+from tensorflow_probability.substrates.jax import distributions as tfd
+import pytest
+import jax.numpy as jnp
+
+
+@pytest.mark.parametrize('x', [-1., 0., 1.])
+def test_lpd(x):
+    val = jnp.array(x)
+    dist = tfd.Normal(loc=0., scale=1.)
+    lpd = log_density(val, dist)
+    assert lpd is not None
+
+
+def test_prior_evaluation():
+    """
+    Test the regular setup in which every parameter has a corresponding prior distribution attached to its
+    unconstrained value.
+    """
+    params = {
+        "lengthscale": jnp.array([1.]),
+        "variance": jnp.array([1.]),
+        "obs_noise": jnp.array([1.]),
+    }
+    priors = {
+        "lengthscale": tfd.Gamma(1.0, 1.0),
+        "variance": tfd.Gamma(2.0, 2.0),
+        "obs_noise": tfd.Gamma(3.0, 3.0),
+    }
+    lpd = evaluate_prior(params, priors)
+    assert pytest.approx(lpd) == -2.0110168
+
+
+def test_none_prior():
+    """
+    Test that multiple dispatch is working in the case of no priors.
+    """
+    params = {
+        "lengthscale": jnp.array([1.]),
+        "variance": jnp.array([1.]),
+        "obs_noise": jnp.array([1.]),
+    }
+    lpd = evaluate_prior(params, None)
+    assert lpd == 0.
+
+
+def test_incomplete_priors():
+    """
+    Test the case where a user specifies priors for some, but not all, parameters.
+    """
+    params = {
+        "lengthscale": jnp.array([1.]),
+        "variance": jnp.array([1.]),
+        "obs_noise": jnp.array([1.]),
+    }
+    priors = {
+        "lengthscale": tfd.Gamma(1.0, 1.0),
+        "variance": tfd.Gamma(2.0, 2.0),
+    }
+    lpd = evaluate_prior(params, priors)
+    assert pytest.approx(lpd) == -1.6137061
+
+
+def test_checks():
+    incomplete_priors = {'lengthscale': jnp.array([1.])}
+    posterior = Prior(kernel=RBF()) * Bernoulli()
+    priors = prior_checks(posterior, incomplete_priors)
+    assert 'latent' in priors.keys()
+    assert 'variance' not in priors.keys()
+
+
+def test_check_needless():
+    complete_prior = {
+        "lengthscale": tfd.Gamma(1.0, 1.0),
+        "variance": tfd.Gamma(2.0, 2.0),
+        "obs_noise": tfd.Gamma(3.0, 3.0),
+        "latent": tfd.Normal(loc=0., scale=1.)
+    }
+    posterior = Prior(kernel=RBF()) * Bernoulli()
+    priors = prior_checks(posterior, complete_prior)
+    assert priors == complete_prior
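
As a sanity check on the asserted constants: with every parameter equal to 1, evaluate_prior reduces to a sum of Gamma log-densities at x = 1, where for Gamma(concentration=a, rate=b) the log-density at 1 is a*log(b) - lgamma(a) - b. The following reproduces both figures to within pytest.approx's default tolerance (the tests' values carry float32 rounding):

    import math

    # Gamma(concentration=a, rate=b) log-pdf evaluated at x = 1.
    logp = lambda a, b: a * math.log(b) - math.lgamma(a) - b

    print(logp(1, 1) + logp(2, 2) + logp(3, 3))  # ~ -2.011016  (test_prior_evaluation)
    print(logp(1, 1) + logp(2, 2))               # ~ -1.613706  (test_incomplete_priors)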

tests/test_types.py

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+from gpjax.types import NoneType
+
+
+def test_nonetype():
+    assert isinstance(None, NoneType)

tests/test_utilities.py

Lines changed: 16 additions & 1 deletion
@@ -1,5 +1,6 @@
-from gpjax.utils import concat_dictionaries, I, merge_dictionaries
+from gpjax.utils import concat_dictionaries, I, merge_dictionaries, standardise, unstandardise
 import jax.numpy as jnp
+import jax.random as jr
 import pytest
 
 
@@ -24,3 +25,17 @@ def test_merge_dicts():
     d = merge_dictionaries(d1, d2)
     assert list(d.keys()) == ['a', 'b']
     assert list(d.values()) == [1, 3]
+
+
+def test_standardise():
+    key = jr.PRNGKey(123)
+    x = jr.uniform(key, shape=(100, 1))
+    xtr, xmean, xstd = standardise(x)
+    assert pytest.approx(jnp.mean(xtr), rel=4) == 0.
+
+    xtr2 = standardise(x, xmean, xstd)
+    assert pytest.approx(jnp.mean(xtr2), rel=4) == 0.
+
+    xuntr = unstandardise(xtr, xmean, xstd)
+    diff = jnp.sum(jnp.abs(xuntr - x))
+    assert pytest.approx(diff) == 0
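
standardise and unstandardise themselves are not shown in this diff; judging only from the call patterns in the test above (three return values on the first call, a single array when statistics are supplied), a plausible sketch is column-wise mean/standard-deviation scaling. This is an assumption, not the committed implementation:

    import jax.numpy as jnp

    def standardise(x, xmean=None, xstd=None):
        # First call: compute column statistics and return them alongside the
        # transformed array; subsequent calls reuse the supplied statistics.
        if xmean is None or xstd is None:
            xmean, xstd = jnp.mean(x, axis=0), jnp.std(x, axis=0)
            return (x - xmean) / xstd, xmean, xstd
        return (x - xmean) / xstd

    def unstandardise(xtr, xmean, xstd):
        # Invert the scaling to recover the original data.
        return xtr * xstd + xmean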
