1+ from gpjax .parameters import log_density
2+ from gpjax .parameters .priors import evaluate_prior , prior_checks
3+ from gpjax .gps import Prior
4+ from gpjax .kernels import RBF
5+ from gpjax .likelihoods import Bernoulli
6+ from tensorflow_probability .substrates .jax import distributions as tfd
7+ import pytest
8+ import jax .numpy as jnp
9+
10+
11+ @pytest .mark .parametrize ('x' , [- 1. , 0. , 1. ])
12+ def test_lpd (x ):
13+ val = jnp .array (x )
14+ dist = tfd .Normal (loc = 0. , scale = 1. )
15+ lpd = log_density (val , dist )
16+ assert lpd is not None
17+
18+
19+ def test_prior_evaluation ():
20+ """
21+ Test the regular setup that every parameter has a corresponding prior distribution attached to its unconstrained
22+ value.
23+ """
24+ params = {
25+ "lengthscale" : jnp .array ([1. ]),
26+ "variance" : jnp .array ([1. ]),
27+ "obs_noise" : jnp .array ([1. ]),
28+ }
29+ priors = {
30+ "lengthscale" : tfd .Gamma (1.0 , 1.0 ),
31+ "variance" : tfd .Gamma (2.0 , 2.0 ),
32+ "obs_noise" : tfd .Gamma (3.0 , 3.0 ),
33+ }
34+ lpd = evaluate_prior (params , priors )
35+ assert pytest .approx (lpd ) == - 2.0110168
36+
37+
38+ def test_none_prior ():
39+ """
40+ Test that multiple dispatch is working in the case of no priors.
41+ """
42+ params = {
43+ "lengthscale" : jnp .array ([1. ]),
44+ "variance" : jnp .array ([1. ]),
45+ "obs_noise" : jnp .array ([1. ]),
46+ }
47+ lpd = evaluate_prior (params , None )
48+ assert lpd == 0.
49+
50+
51+ def test_incomplete_priors ():
52+ """
53+ Test the case where a user specifies priors for some, but not all, parameters.
54+ """
55+ params = {
56+ "lengthscale" : jnp .array ([1. ]),
57+ "variance" : jnp .array ([1. ]),
58+ "obs_noise" : jnp .array ([1. ]),
59+ }
60+ priors = {
61+ "lengthscale" : tfd .Gamma (1.0 , 1.0 ),
62+ "variance" : tfd .Gamma (2.0 , 2.0 ),
63+ }
64+ lpd = evaluate_prior (params , priors )
65+ assert pytest .approx (lpd ) == - 1.6137061
66+
67+
68+ def test_checks ():
69+ incomplete_priors = {'lengthscale' : jnp .array ([1. ])}
70+ posterior = Prior (kernel = RBF ()) * Bernoulli ()
71+ priors = prior_checks (posterior , incomplete_priors )
72+ assert 'latent' in priors .keys ()
73+ assert 'variance' not in priors .keys ()
74+
75+
76+ def test_check_needless ():
77+ complete_prior = {
78+ "lengthscale" : tfd .Gamma (1.0 , 1.0 ),
79+ "variance" : tfd .Gamma (2.0 , 2.0 ),
80+ "obs_noise" : tfd .Gamma (3.0 , 3.0 ),
81+ "latent" : tfd .Normal (loc = 0. , scale = 1. )
82+ }
83+ posterior = Prior (kernel = RBF ()) * Bernoulli ()
84+ priors = prior_checks (posterior , complete_prior )
85+ assert priors == complete_prior
0 commit comments