|
5 | 5 |
|
6 | 6 |
|
7 | 7 | import jax.numpy as jnp |
| 8 | +import jax.random as jr |
8 | 9 | from jaxtyping import ( |
9 | 10 | Array, |
10 | 11 | Float, |
11 | 12 | ) |
12 | 13 | import pytest |
13 | 14 |
|
| 15 | +import gpjax as gpx |
14 | 16 | from gpjax.mean_functions import ( |
15 | 17 | AbstractMeanFunction, |
16 | 18 | Constant, |
17 | 19 | Zero, |
18 | 20 | ) |
| 21 | +from gpjax.parameters import Static |
19 | 22 |
|
20 | 23 |
|
21 | 24 | def test_abstract() -> None: |
@@ -49,38 +52,28 @@ def test_constant(constant: Float[Array, " Q"]) -> None: |
49 | 52 | ).all() |
50 | 53 |
|
51 | 54 |
|
52 | | -# TODO: rewrite this test after work on fit |
53 | | -# def test_zero_mean_remains_zero() -> None: |
54 | | -# key = jr.PRNGKey(123) |
55 | | - |
56 | | -# x = jr.uniform(key=key, minval=0, maxval=1, shape=(20, 1)) |
57 | | -# y = jnp.full((20, 1), 50, dtype=jnp.float64) # Dataset with non-zero mean |
58 | | -# D = gpx.Dataset(X=x, y=y) |
59 | | - |
60 | | -# kernel = gpx.kernels.Constant(constant=jnp.array(0.0)) |
61 | | -# kernel = kernel.replace_trainable( |
62 | | -# constant=False |
63 | | -# ) # Prevent kernel from modelling non-zero mean |
64 | | -# meanf = Zero() |
65 | | -# prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel) |
66 | | -# likelihood = gpx.likelihoods.Gaussian( |
67 | | -# num_datapoints=D.n, obs_stddev=jnp.array(1e-3) |
68 | | -# ) |
69 | | -# likelihood = likelihood.replace_trainable(obs_stddev=False) |
70 | | -# posterior = prior * likelihood |
71 | | - |
72 | | -# negative_mll = gpx.objectives.ConjugateMLL(negative=True) |
73 | | -# opt_posterior, _ = gpx.fit( |
74 | | -# model=posterior, |
75 | | -# objective=negative_mll, |
76 | | -# train_data=D, |
77 | | -# optim=ox.adam(learning_rate=0.5), |
78 | | -# num_iters=1000, |
79 | | -# safe=True, |
80 | | -# key=key, |
81 | | -# ) |
82 | | - |
83 | | -# assert opt_posterior.prior.mean_function.constant == 0.0 |
| 55 | +def test_zero_mean_remains_zero() -> None: |
| 56 | + key = jr.PRNGKey(123) |
| 57 | + |
| 58 | + x = jr.uniform(key=key, minval=0, maxval=1, shape=(20, 1)) |
| 59 | + y = jnp.full((20, 1), 50, dtype=jnp.float64) # Dataset with non-zero mean |
| 60 | + D = gpx.Dataset(X=x, y=y) |
| 61 | + |
| 62 | + constant = Static(jnp.array(0.0)) |
| 63 | + kernel = gpx.kernels.Constant(constant=constant) |
| 64 | + meanf = Zero() |
| 65 | + prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel) |
| 66 | + likelihood = gpx.likelihoods.Gaussian( |
| 67 | + num_datapoints=D.n, obs_stddev=jnp.array(1e-3) |
| 68 | + ) |
| 69 | + posterior = prior * likelihood |
| 70 | + |
| 71 | + opt_posterior, _ = gpx.fit_scipy( |
| 72 | + model=posterior, |
| 73 | + objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d), |
| 74 | + train_data=D, |
| 75 | + ) |
| 76 | + assert opt_posterior.prior.mean_function.constant.value == 0.0 |
84 | 77 |
|
85 | 78 |
|
86 | 79 | def test_initialising_zero_mean_with_constant_raises_error(): |
|
0 commit comments