Skip to content

Commit 168a106

Browse files
authored
Make zero-mean fn. constant a Static class (#500)
* Make zero-mean fn. constant a Static class * Remove stray comment * Add missing imports * Fix zero-mean bug * Fix zero-mean bug * Bump version
1 parent 7df38bf commit 168a106

File tree

4 files changed

+34
-37
lines changed

4 files changed

+34
-37
lines changed

gpjax/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
__description__ = "Didactic Gaussian processes in JAX"
4040
__url__ = "https://github.com/JaxGaussianProcesses/GPJax"
4141
__contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
42-
__version__ = "0.10.0"
42+
__version__ = "0.10.1"
4343

4444
__all__ = [
4545
"base",

gpjax/kernels/base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from gpjax.parameters import (
3333
Parameter,
3434
Real,
35+
Static,
3536
)
3637
from gpjax.typing import (
3738
Array,
@@ -220,7 +221,9 @@ class Constant(AbstractKernel):
220221
def __init__(
221222
self,
222223
active_dims: tp.Union[list[int], slice, None] = None,
223-
constant: tp.Union[ScalarFloat, Parameter[ScalarFloat]] = jnp.array(0.0),
224+
constant: tp.Union[
225+
ScalarFloat, Parameter[ScalarFloat], Static[ScalarFloat]
226+
] = jnp.array(0.0),
224227
compute_engine: AbstractKernelComputation = DenseKernelComputation(),
225228
):
226229
if isinstance(constant, Parameter):

gpjax/mean_functions.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from gpjax.parameters import (
2929
Parameter,
3030
Real,
31+
Static
3132
)
3233
from gpjax.typing import (
3334
Array,
@@ -130,9 +131,9 @@ class Constant(AbstractMeanFunction):
130131
"""
131132

132133
def __init__(
133-
self, constant: tp.Union[ScalarFloat, Float[Array, " O"], Parameter] = 0.0
134+
self, constant: tp.Union[ScalarFloat, Float[Array, " O"], Parameter, Static] = 0.0
134135
):
135-
if isinstance(constant, Parameter):
136+
if isinstance(constant, Parameter) or isinstance(constant, Static):
136137
self.constant = constant
137138
else:
138139
self.constant = Real(jnp.array(constant))
@@ -158,7 +159,7 @@ class Zero(Constant):
158159
"""
159160

160161
def __init__(self):
161-
super().__init__(constant=jnp.array(0.0))
162+
super().__init__(constant=Static(jnp.array(0.0)))
162163

163164

164165
class CombinationMeanFunction(AbstractMeanFunction):

tests/test_mean_functions.py

Lines changed: 25 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,20 @@
55

66

77
import jax.numpy as jnp
8+
import jax.random as jr
89
from jaxtyping import (
910
Array,
1011
Float,
1112
)
1213
import pytest
1314

15+
import gpjax as gpx
1416
from gpjax.mean_functions import (
1517
AbstractMeanFunction,
1618
Constant,
1719
Zero,
1820
)
21+
from gpjax.parameters import Static
1922

2023

2124
def test_abstract() -> None:
@@ -49,38 +52,28 @@ def test_constant(constant: Float[Array, " Q"]) -> None:
4952
).all()
5053

5154

52-
# TODO: rewrite this test after work on fit
53-
# def test_zero_mean_remains_zero() -> None:
54-
# key = jr.PRNGKey(123)
55-
56-
# x = jr.uniform(key=key, minval=0, maxval=1, shape=(20, 1))
57-
# y = jnp.full((20, 1), 50, dtype=jnp.float64) # Dataset with non-zero mean
58-
# D = gpx.Dataset(X=x, y=y)
59-
60-
# kernel = gpx.kernels.Constant(constant=jnp.array(0.0))
61-
# kernel = kernel.replace_trainable(
62-
# constant=False
63-
# ) # Prevent kernel from modelling non-zero mean
64-
# meanf = Zero()
65-
# prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
66-
# likelihood = gpx.likelihoods.Gaussian(
67-
# num_datapoints=D.n, obs_stddev=jnp.array(1e-3)
68-
# )
69-
# likelihood = likelihood.replace_trainable(obs_stddev=False)
70-
# posterior = prior * likelihood
71-
72-
# negative_mll = gpx.objectives.ConjugateMLL(negative=True)
73-
# opt_posterior, _ = gpx.fit(
74-
# model=posterior,
75-
# objective=negative_mll,
76-
# train_data=D,
77-
# optim=ox.adam(learning_rate=0.5),
78-
# num_iters=1000,
79-
# safe=True,
80-
# key=key,
81-
# )
82-
83-
# assert opt_posterior.prior.mean_function.constant == 0.0
55+
def test_zero_mean_remains_zero() -> None:
56+
key = jr.PRNGKey(123)
57+
58+
x = jr.uniform(key=key, minval=0, maxval=1, shape=(20, 1))
59+
y = jnp.full((20, 1), 50, dtype=jnp.float64) # Dataset with non-zero mean
60+
D = gpx.Dataset(X=x, y=y)
61+
62+
constant = Static(jnp.array(0.0))
63+
kernel = gpx.kernels.Constant(constant=constant)
64+
meanf = Zero()
65+
prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
66+
likelihood = gpx.likelihoods.Gaussian(
67+
num_datapoints=D.n, obs_stddev=jnp.array(1e-3)
68+
)
69+
posterior = prior * likelihood
70+
71+
opt_posterior, _ = gpx.fit_scipy(
72+
model=posterior,
73+
objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d),
74+
train_data=D,
75+
)
76+
assert opt_posterior.prior.mean_function.constant.value == 0.0
8477

8578

8679
def test_initialising_zero_mean_with_constant_raises_error():

0 commit comments

Comments
 (0)