Skip to content

Commit 48706eb

Browse files
Merge pull request #358 from Thomas-Christie/zero-mean-fix
Fix bug in zero mean function and add test
2 parents c5eb47b + fe8bde4 commit 48706eb

File tree

3 files changed

+70
-2
lines changed

3 files changed

+70
-2
lines changed

gpjax/kernels/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from gpjax.kernels.approximations import RFF
1818
from gpjax.kernels.base import (
1919
AbstractKernel,
20+
Constant,
2021
ProductKernel,
2122
SumKernel,
2223
)
@@ -27,7 +28,10 @@
2728
DiagonalKernelComputation,
2829
EigenKernelComputation,
2930
)
30-
from gpjax.kernels.non_euclidean import GraphKernel, CatKernel
31+
from gpjax.kernels.non_euclidean import (
32+
CatKernel,
33+
GraphKernel,
34+
)
3135
from gpjax.kernels.nonstationary import (
3236
ArcCosine,
3337
Linear,
@@ -47,6 +51,7 @@
4751
__all__ = [
4852
"AbstractKernel",
4953
"ArcCosine",
54+
"Constant",
5055
"RBF",
5156
"GraphKernel",
5257
"CatKernel",

gpjax/mean_functions.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,17 @@ def __call__(self, x: Num[Array, "N D"]) -> Float[Array, "N 1"]:
150150
return jnp.ones((x.shape[0], 1)) * self.constant
151151

152152

153+
@dataclasses.dataclass
154+
class Zero(Constant):
155+
r"""Zero mean function.
156+
157+
The zero mean function. This function returns a zero scalar value for all
158+
inputs. Unlike the Constant mean function, the constant scalar zero is fixed, and
159+
cannot be treated as a model hyperparameter and learned during training.
160+
"""
161+
constant: Float[Array, "1"] = static_field(jnp.array([0.0]), init=False)
162+
163+
153164
@dataclasses.dataclass
154165
class CombinationMeanFunction(AbstractMeanFunction):
155166
r"""A base class for products or sums of AbstractMeanFunctions."""
@@ -199,4 +210,3 @@ def __call__(self, x: Num[Array, "N D"]) -> Float[Array, "N 1"]:
199210
ProductMeanFunction = partial(
200211
CombinationMeanFunction, operator=partial(jnp.sum, axis=0)
201212
)
202-
Zero = partial(Constant, constant=jnp.array([0.0]))

tests/test_mean_functions.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,24 @@
1+
# Enable Float64 for more stable matrix inversions.
2+
from jax import config
3+
4+
config.update("jax_enable_x64", True)
5+
6+
7+
import jax
18
import jax.numpy as jnp
9+
import jax.random as jr
210
from jaxtyping import (
311
Array,
412
Float,
513
)
14+
import optax as ox
615
import pytest
716

17+
import gpjax as gpx
818
from gpjax.mean_functions import (
919
AbstractMeanFunction,
1020
Constant,
21+
Zero,
1122
)
1223

1324

@@ -40,3 +51,45 @@ def test_constant(constant: Float[Array, " Q"]) -> None:
4051
assert (
4152
mf(jnp.array([[1.0, 2.0], [3.0, 4.0]])) == jnp.array([constant, constant])
4253
).all()
54+
55+
56+
def test_zero_mean_remains_zero() -> None:
57+
key = jr.PRNGKey(123)
58+
59+
x = jr.uniform(key=key, minval=0, maxval=1, shape=(20, 1))
60+
y = jnp.full((20, 1), 50, dtype=jnp.float64) # Dataset with non-zero mean
61+
D = gpx.Dataset(X=x, y=y)
62+
63+
kernel = gpx.kernels.Constant(constant=jnp.array(0.0))
64+
kernel = kernel.replace_trainable(
65+
constant=False
66+
) # Prevent kernel from modelling non-zero mean
67+
meanf = Zero()
68+
prior = gpx.Prior(mean_function=meanf, kernel=kernel)
69+
likelihood = gpx.Gaussian(num_datapoints=D.n, obs_noise=jnp.array(1e-6))
70+
likelihood = likelihood.replace_trainable(obs_noise=False)
71+
posterior = prior * likelihood
72+
73+
negative_mll = gpx.objectives.ConjugateMLL(negative=True)
74+
opt_posterior, _ = gpx.fit(
75+
model=posterior,
76+
objective=negative_mll,
77+
train_data=D,
78+
optim=ox.adam(learning_rate=0.5),
79+
num_iters=1000,
80+
safe=True,
81+
key=key,
82+
)
83+
84+
assert opt_posterior.prior.mean_function.constant == 0.0
85+
86+
87+
def test_zero_mean_pytree_no_leaves():
88+
zero_mean = Zero()
89+
leaves = jax.tree_util.tree_leaves(zero_mean)
90+
assert len(leaves) == 0
91+
92+
93+
def test_initialising_zero_mean_with_constant_raises_error():
94+
with pytest.raises(TypeError):
95+
Zero(constant=jnp.array([1.0]))

0 commit comments

Comments
 (0)