Commit 17c6825

Merge pull request #400 from JaxGaussianProcesses/fix_noise
Fix obs_noise confusion
2 parents c9836fc + 8076f74; commit 17c6825

27 files changed (+93, -58 lines)
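
At a glance, the change this PR makes throughout: GPJax's Gaussian likelihood is now parameterised by the observation noise standard deviation (obs_stddev) instead of the variance (obs_noise), and call sites convert accordingly (the examples' variance of 1e-6 becomes its square root, 1e-3). A minimal before/after sketch; the toy dataset D here is made up for illustration:

```python
import jax.numpy as jnp
import gpjax as gpx

# Hypothetical toy dataset, only needed to supply num_datapoints.
D = gpx.Dataset(
    X=jnp.linspace(0.0, 1.0, 10).reshape(-1, 1),
    y=jnp.zeros((10, 1)),
)

# Before this commit the parameter was the noise *variance*:
#   gpx.Gaussian(num_datapoints=D.n, obs_noise=jnp.array(1e-6))
# After it, the parameter is the noise *standard deviation*,
# so the equivalent value is sqrt(1e-6) = 1e-3:
likelihood = gpx.Gaussian(num_datapoints=D.n, obs_stddev=jnp.array(1e-3))
```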

.gitignore

Lines changed: 1 addition & 1 deletion
```diff
@@ -151,4 +151,4 @@ package.json
 package-lock.json
 node_modules/
 
-docs/api
+docs/api
```

docs/GOVERNANCE.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -93,4 +93,4 @@ maintainers or reach out over
 -----
 
 This file was adapted from
-[BlackJAX](https://github.com/blackjax-devs/blackjax/blob/main/GOVERNANCE.md).
+[BlackJAX](https://github.com/blackjax-devs/blackjax/blob/main/GOVERNANCE.md).
```

docs/examples/bayesian_optimisation.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -204,9 +204,9 @@ def return_optimised_posterior(
     data: gpx.Dataset, prior: gpx.Module, key: Array
 ) -> gpx.Module:
     likelihood = gpx.Gaussian(
-        num_datapoints=data.n, obs_noise=jnp.array(1e-6)
-    )  # Our function is noise-free, so we set the observation noise to a very small value
-    likelihood = likelihood.replace_trainable(obs_noise=False)
+        num_datapoints=data.n, obs_stddev=jnp.array(1e-3)
+    )  # Our function is noise-free, so we set the observation noise's standard deviation to a very small value
+    likelihood = likelihood.replace_trainable(obs_stddev=False)
 
     posterior = prior * likelihood
```
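
For noise-free objective functions, the examples pin the (now tiny) noise standard deviation rather than learning it; the same pattern lands in intro_to_kernels.py below. A minimal sketch of that pattern, with a toy dataset standing in for the example's training data:

```python
import jax.numpy as jnp
import gpjax as gpx

# Toy stand-in for the noise-free training data used in the example.
D = gpx.Dataset(X=jnp.linspace(0.0, 1.0, 5).reshape(-1, 1), y=jnp.zeros((5, 1)))

likelihood = gpx.Gaussian(num_datapoints=D.n, obs_stddev=jnp.array(1e-3))
# The noise level is known (near zero), so keep it fixed during training
# rather than letting hyperparameter optimisation inflate it.
likelihood = likelihood.replace_trainable(obs_stddev=False)
```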

docs/examples/decision_making.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -153,7 +153,9 @@ def forrester(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]:
 # with the correct number of datapoints:
 
 # %%
-likelihood_builder = lambda n: gpx.Gaussian(num_datapoints=n, obs_noise=jnp.array(1e-6))
+likelihood_builder = lambda n: gpx.Gaussian(
+    num_datapoints=n, obs_stddev=jnp.array(1e-3)
+)
 
 # %% [markdown]
 # Now we have all the components required for constructing our GP posterior. Since we'll
```
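
The builder is just a callable from a dataset size to a likelihood; something along these lines presumably happens inside the decision-making utilities each time the dataset grows. A hypothetical direct call:

```python
# Hypothetical direct invocation of the builder defined in the diff above.
likelihood = likelihood_builder(10)
assert likelihood.num_datapoints == 10
```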

docs/examples/intro_to_kernels.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -223,9 +223,9 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]:
 prior = gpx.Prior(mean_function=mean, kernel=kernel)
 
 likelihood = gpx.Gaussian(
-    num_datapoints=D.n, obs_noise=jnp.array(1e-6)
-)  # Our function is noise-free, so we set the observation noise to a very small value
-likelihood = likelihood.replace_trainable(obs_noise=False)
+    num_datapoints=D.n, obs_stddev=jnp.array(1e-3)
+)  # Our function is noise-free, so we set the observation noise's standard deviation to a very small value
+likelihood = likelihood.replace_trainable(obs_stddev=False)
 
 no_opt_posterior = prior * likelihood
```

docs/examples/likelihoods_guide.py

Lines changed: 6 additions & 5 deletions
```diff
@@ -99,12 +99,13 @@
 # Some likelihoods, such as the Gaussian likelihood, contain parameters that we seek
 # to infer. In the case of the Gaussian likelihood, we have a single parameter
 # $\sigma^2$ that determines the observation noise. In GPJax, we can specify the value
-# of this parameter when instantiating the likelihood object. If we do not specify a
+# of $\sigma$ when instantiating the likelihood object. If we do not specify a
 # value, then the likelihood will be initialised with a default value. In the case of
 # the Gaussian likelihood, the default value is $1.0$. If we instead wanted to
-# initialise the likelihood with a value of $0.5$, then we would do this as follows:
+# initialise the likelihood standard deviation with a value of $0.5$, then we would do
+# this as follows:
 
-gpx.likelihoods.Gaussian(num_datapoints=D.n, obs_noise=0.5)
+gpx.likelihoods.Gaussian(num_datapoints=D.n, obs_stddev=0.5)
 
 # To control other properties of the observation noise such as trainability and value
 # constraints, see our [PyTree guide](pytrees.md).
@@ -127,7 +128,7 @@
 meanf = gpx.Zero()
 prior = gpx.Prior(kernel=kernel, mean_function=meanf)
 
-likelihood = gpx.Gaussian(num_datapoints=D.n, obs_noise=0.1)
+likelihood = gpx.Gaussian(num_datapoints=D.n, obs_stddev=0.1)
 
 posterior = prior * likelihood
@@ -252,7 +253,7 @@ def q_moments(x):
 
 lquad = gpx.Gaussian(
     num_datapoints=D.n,
-    obs_noise=jnp.array([0.1]),
+    obs_stddev=jnp.array([0.1]),
     integrator=gpx.integrators.GHQuadratureIntegrator(num_points=20),
 )
```
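
Worth keeping in mind when reading the renamed parameter: obs_stddev is sigma, while the quantity folded into predictive covariances is sigma squared. A small sketch of that relationship, assuming the squaring shown in the gpjax/gps.py diff below:

```python
import gpjax as gpx

likelihood = gpx.likelihoods.Gaussian(num_datapoints=10, obs_stddev=0.5)

# The likelihood stores the standard deviation; the variance added to the
# diagonal of a conjugate posterior's covariance is its square.
obs_var = likelihood.obs_stddev**2  # 0.25
```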

docs/examples/oceanmodelling.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -226,7 +226,7 @@ def __call__(
 def initialise_gp(kernel, mean, dataset):
     prior = gpx.Prior(mean_function=mean, kernel=kernel)
     likelihood = gpx.Gaussian(
-        num_datapoints=dataset.n, obs_noise=jnp.array([1.0e-6], dtype=jnp.float64)
+        num_datapoints=dataset.n, obs_stddev=jnp.array([1.0e-3], dtype=jnp.float64)
     )
     posterior = prior * likelihood
     return posterior
```
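
One detail in this hunk: the explicit dtype=jnp.float64 only takes effect when JAX's 64-bit mode is on, since JAX arrays default to float32. Enabling it uses a standard JAX config flag (shown here for completeness; not part of the diff):

```python
import jax

# JAX defaults to float32; turn on 64-bit support before building arrays
# so that jnp.array([1.0e-3], dtype=jnp.float64) is not silently downcast.
jax.config.update("jax_enable_x64", True)
```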

gpjax/dataset.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -17,9 +17,9 @@
 import warnings
 
 from beartype.typing import (
+    Literal,
     Optional,
     Union,
-    Literal,
 )
 import jax.numpy as jnp
 from jaxtyping import (
```

gpjax/distributions.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -16,9 +16,9 @@
 
 from beartype.typing import (
     Any,
+    Generic,
     Optional,
     Tuple,
-    Generic,
     TypeVar,
     Union,
 )
```

gpjax/gps.py

Lines changed: 13 additions & 11 deletions
```diff
@@ -15,7 +15,10 @@
 
 # from __future__ import annotations
 from abc import abstractmethod
-from dataclasses import dataclass, field
+from dataclasses import (
+    dataclass,
+    field,
+)
 from typing import overload
 
 from beartype.typing import (
@@ -25,7 +28,6 @@
 )
 import cola
 from cola.ops import Dense
-
 import jax.numpy as jnp
 from jax.random import (
     PRNGKey,
@@ -47,7 +49,10 @@
     ReshapedDistribution,
     ReshapedGaussianDistribution,
 )
-from gpjax.kernels import RFF, White
+from gpjax.kernels import (
+    RFF,
+    White,
+)
 from gpjax.kernels.base import AbstractKernel
 from gpjax.likelihoods import (
     AbstractLikelihood,
@@ -503,7 +508,7 @@ def predict(
         n_test = len(test_inputs)
 
         # Observation noise o²
-        obs_noise = self.likelihood.obs_noise
+        obs_var = self.likelihood.obs_stddev**2
         mx = self.prior.mean_function(x)
 
         # Precompute Gram matrix, Kxx, at training inputs, x
@@ -512,7 +517,7 @@
 
         # Σ = Kxx + Io²
         Sigma = cola.ops.Kronecker(Kxx, Kyy)
-        Sigma += cola.ops.I_like(Sigma) * (obs_noise + self.jitter)
+        Sigma += cola.ops.I_like(Sigma) * (obs_var + self.jitter)
         Sigma = cola.PSD(Sigma)
 
         if mask is not None:
@@ -606,13 +611,10 @@
 
         # sample weights v for canonical features
         # v = Σ⁻¹ (y + ε - ɸ⍵) for Σ = Kxx + Io² and ε ~ N(0, o²)
+        obs_var = self.likelihood.obs_stddev**2
         Kxx = self.prior.kernel.gram(train_data.X)  # [N, N]
-        Sigma = Kxx + cola.ops.I_like(Kxx) * (
-            self.likelihood.obs_noise + self.jitter
-        )  # [N, N]
-        eps = jnp.sqrt(self.likelihood.obs_noise) * normal(
-            key, [train_data.n, num_samples]
-        )  # [N, B]
+        Sigma = Kxx + cola.ops.I_like(Kxx) * (obs_var + self.jitter)  # [N, N]
+        eps = jnp.sqrt(obs_var) * normal(key, [train_data.n, num_samples])  # [N, B]
         y = train_data.y - self.prior.mean_function(train_data.X)  # account for mean
         Phi = fourier_feature_fn(train_data.X)
         canonical_weights = cola.solve(
```
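
The substance of the gps.py hunks: wherever the conjugate posterior needs the noise variance, it now squares obs_stddev once and reuses it, both for Sigma = Kxx + I(sigma² + jitter) and for drawing noise samples with standard deviation sigma. A standalone numerical sketch of those two lines in plain JAX (the Gram matrix and constants are made up; this is not GPJax's internal code):

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
n, num_samples = 5, 3

# Stand-ins for the training Gram matrix and likelihood parameters.
Kxx = jnp.eye(n) + 0.1 * jnp.ones((n, n))
obs_stddev = jnp.array(1e-3)
jitter = 1e-6

# The noise enters the covariance as a variance: Sigma = Kxx + I(sigma² + jitter).
obs_var = obs_stddev**2
Sigma = Kxx + jnp.eye(n) * (obs_var + jitter)

# Noise samples epsilon ~ N(0, sigma²) are scaled by the standard deviation.
eps = jnp.sqrt(obs_var) * random.normal(key, (n, num_samples))
```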
