Skip to content

Commit 07d99db

Browse files
authored
Merge pull request #408 from JaxGaussianProcesses/namespace_cleanup
Namespace cleanup
2 parents ac47576 + 86490b1 commit 07d99db

32 files changed

+198
-247
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,10 +135,10 @@ D = gpx.Dataset(X=x, y=y)
135135
# Construct the prior
136136
meanf = gpx.mean_functions.Zero()
137137
kernel = gpx.kernels.RBF()
138-
prior = gpx.Prior(mean_function=meanf, kernel=kernel)
138+
prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel)
139139

140140
# Define a likelihood
141-
likelihood = gpx.Gaussian(num_datapoints=n)
141+
likelihood = gpx.likelihoods.Gaussian(num_datapoints = n)
142142

143143
# Construct the posterior
144144
posterior = prior * likelihood

benchmarks/objectives.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def setup(self, n_datapoints: int, n_dims: int):
2222
self.data = gpx.Dataset(X=self.X, y=self.y)
2323
kernel = gpx.kernels.RBF(active_dims=list(range(n_dims)))
2424
meanf = gpx.mean_functions.Constant()
25-
self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
25+
self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
2626
self.likelihood = gpx.likelihoods.Gaussian(num_datapoints=self.data.n)
2727
self.objective = gpx.ConjugateMLL()
2828
self.posterior = self.prior * self.likelihood
@@ -48,7 +48,7 @@ def setup(self, n_datapoints: int, n_dims: int):
4848
self.data = gpx.Dataset(X=self.X, y=self.y)
4949
kernel = gpx.kernels.RBF(active_dims=list(range(n_dims)))
5050
meanf = gpx.mean_functions.Constant()
51-
self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
51+
self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
5252
self.likelihood = gpx.likelihoods.Bernoulli(num_datapoints=self.data.n)
5353
self.objective = gpx.LogPosteriorDensity()
5454
self.posterior = self.prior * self.likelihood
@@ -75,7 +75,7 @@ def setup(self, n_datapoints: int, n_dims: int):
7575
self.data = gpx.Dataset(X=self.X, y=self.y)
7676
kernel = gpx.kernels.RBF(active_dims=list(range(n_dims)))
7777
meanf = gpx.mean_functions.Constant()
78-
self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
78+
self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
7979
self.likelihood = gpx.likelihoods.Poisson(num_datapoints=self.data.n)
8080
self.objective = gpx.LogPosteriorDensity()
8181
self.posterior = self.prior * self.likelihood

benchmarks/predictions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def setup(self, n_test: int, n_dims: int):
2121
self.data = gpx.Dataset(X=self.X, y=self.y)
2222
kernel = gpx.kernels.RBF(active_dims=list(range(n_dims)))
2323
meanf = gpx.mean_functions.Constant()
24-
self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
24+
self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
2525
self.likelihood = gpx.likelihoods.Gaussian(num_datapoints=self.data.n)
2626
self.posterior = self.prior * self.likelihood
2727
key, subkey = jr.split(key)
@@ -46,7 +46,7 @@ def setup(self, n_test: int, n_dims: int):
4646
self.data = gpx.Dataset(X=self.X, y=self.y)
4747
kernel = gpx.kernels.RBF(active_dims=list(range(n_dims)))
4848
meanf = gpx.mean_functions.Constant()
49-
self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
49+
self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
5050
self.likelihood = gpx.likelihoods.Bernoulli(num_datapoints=self.data.n)
5151
self.posterior = self.prior * self.likelihood
5252
key, subkey = jr.split(key)
@@ -71,7 +71,7 @@ def setup(self, n_test: int, n_dims: int):
7171
self.data = gpx.Dataset(X=self.X, y=self.y)
7272
kernel = gpx.kernels.RBF(active_dims=list(range(n_dims)))
7373
meanf = gpx.mean_functions.Constant()
74-
self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
74+
self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
7575
self.likelihood = gpx.likelihoods.Bernoulli(num_datapoints=self.data.n)
7676
self.posterior = self.prior * self.likelihood
7777
key, subkey = jr.split(key)

benchmarks/sparse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def setup(self, n_datapoints: int, n_inducing: int):
1919
self.data = gpx.Dataset(X=self.X, y=self.y)
2020
kernel = gpx.kernels.RBF(active_dims=list(range(1)))
2121
meanf = gpx.mean_functions.Constant()
22-
self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
22+
self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
2323
self.likelihood = gpx.likelihoods.Gaussian(num_datapoints=self.data.n)
2424
self.posterior = self.prior * self.likelihood
2525

benchmarks/stochastic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def setup(self, n_datapoints: int, n_inducing: int, batch_size: int):
2020
self.data = gpx.Dataset(X=self.X, y=self.y)
2121
kernel = gpx.kernels.RBF(active_dims=list(range(1)))
2222
meanf = gpx.mean_functions.Constant()
23-
self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
23+
self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
2424
self.likelihood = gpx.likelihoods.Gaussian(num_datapoints=self.data.n)
2525
self.posterior = self.prior * self.likelihood
2626

docs/examples/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ class Prior(AbstractPrior):
6767
>>>
6868
>>> meanf = gpx.mean_functions.Zero()
6969
>>> kernel = gpx.kernels.RBF()
70-
>>> prior = gpx.Prior(mean_function=meanf, kernel = kernel)
70+
>>> prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel)
7171
7272
Attributes:
7373
kernel (Kernel): The kernel function used to parameterise the prior.

docs/examples/barycentres.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,13 @@ def fit_gp(x: jax.Array, y: jax.Array) -> tfd.MultivariateNormalFullCovariance:
134134
y = y.reshape(-1, 1)
135135
D = gpx.Dataset(X=x, y=y)
136136

137-
likelihood = gpx.Gaussian(num_datapoints=n)
138-
posterior = gpx.Prior(mean_function=gpx.Constant(), kernel=gpx.RBF()) * likelihood
137+
likelihood = gpx.likelihoods.Gaussian(num_datapoints=n)
138+
posterior = (
139+
gpx.gps.Prior(
140+
mean_function=gpx.mean_functions.Constant(), kernel=gpx.kernels.RBF()
141+
)
142+
* likelihood
143+
)
139144

140145
opt_posterior, _ = gpx.fit_scipy(
141146
model=posterior,

docs/examples/bayesian_optimisation.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -201,9 +201,9 @@ def forrester(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]:
201201

202202
# %%
203203
def return_optimised_posterior(
204-
data: gpx.Dataset, prior: gpx.Module, key: Array
205-
) -> gpx.Module:
206-
likelihood = gpx.Gaussian(
204+
data: gpx.Dataset, prior: gpx.base.Module, key: Array
205+
) -> gpx.base.Module:
206+
likelihood = gpx.likelihoods.Gaussian(
207207
num_datapoints=data.n, obs_stddev=jnp.array(1e-3)
208208
) # Our function is noise-free, so we set the observation noise's standard deviation to a very small value
209209
likelihood = likelihood.replace_trainable(obs_stddev=False)
@@ -230,7 +230,7 @@ def return_optimised_posterior(
230230

231231
mean = gpx.mean_functions.Zero()
232232
kernel = gpx.kernels.Matern52()
233-
prior = gpx.Prior(mean_function=mean, kernel=kernel)
233+
prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)
234234
opt_posterior = return_optimised_posterior(D, prior, key)
235235

236236
# %% [markdown]
@@ -315,7 +315,7 @@ def optimise_sample(
315315

316316
# %%
317317
def plot_bayes_opt(
318-
posterior: gpx.Module,
318+
posterior: gpx.base.Module,
319319
sample: FunctionalSample,
320320
dataset: gpx.Dataset,
321321
queried_x: ScalarFloat,
@@ -401,7 +401,7 @@ def plot_bayes_opt(
401401
# Generate optimised posterior using previously observed data
402402
mean = gpx.mean_functions.Zero()
403403
kernel = gpx.kernels.Matern52()
404-
prior = gpx.Prior(mean_function=mean, kernel=kernel)
404+
prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)
405405
opt_posterior = return_optimised_posterior(D, prior, subkey)
406406

407407
# Draw a sample from the posterior, and find the minimiser of it
@@ -543,7 +543,7 @@ def six_hump_camel(x: Float[Array, "N 2"]) -> Float[Array, "N 1"]:
543543
kernel = gpx.kernels.Matern52(
544544
active_dims=[0, 1], lengthscale=jnp.array([1.0, 1.0]), variance=2.0
545545
)
546-
prior = gpx.Prior(mean_function=mean, kernel=kernel)
546+
prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)
547547
opt_posterior = return_optimised_posterior(D, prior, subkey)
548548

549549
# Draw a sample from the posterior, and find the minimiser of it
@@ -561,7 +561,8 @@ def six_hump_camel(x: Float[Array, "N 2"]) -> Float[Array, "N 1"]:
561561
# Evaluate the black-box function at the best point observed so far, and add it to the dataset
562562
y_star = six_hump_camel(x_star)
563563
print(
564-
f"BO Iteration: {i + 1}, Queried Point: {x_star}, Black-Box Function Value: {y_star}"
564+
f"BO Iteration: {i + 1}, Queried Point: {x_star}, Black-Box Function Value:"
565+
f" {y_star}"
565566
)
566567
D = D + gpx.Dataset(X=x_star, y=y_star)
567568
bo_experiment_results.append(D)

docs/examples/classification.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,10 @@
8989
# choose a Bernoulli likelihood with a probit link function.
9090

9191
# %%
92-
kernel = gpx.RBF()
93-
meanf = gpx.Constant()
94-
prior = gpx.Prior(mean_function=meanf, kernel=kernel)
95-
likelihood = gpx.Bernoulli(num_datapoints=D.n)
92+
kernel = gpx.kernels.RBF()
93+
meanf = gpx.mean_functions.Constant()
94+
prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
95+
likelihood = gpx.likelihoods.Bernoulli(num_datapoints=D.n)
9696

9797
# %% [markdown]
9898
# We construct the posterior through the product of our prior and likelihood.
@@ -116,7 +116,7 @@
116116
# Optax's optimisers.
117117

118118
# %%
119-
negative_lpd = jax.jit(gpx.LogPosteriorDensity(negative=True))
119+
negative_lpd = jax.jit(gpx.objectives.LogPosteriorDensity(negative=True))
120120

121121
optimiser = ox.adam(learning_rate=0.01)
122122

@@ -345,7 +345,7 @@ def construct_laplace(test_inputs: Float[Array, "N D"]) -> tfd.MultivariateNorma
345345
num_adapt = 500
346346
num_samples = 500
347347

348-
lpd = jax.jit(gpx.LogPosteriorDensity(negative=False))
348+
lpd = jax.jit(gpx.objectives.LogPosteriorDensity(negative=False))
349349
unconstrained_lpd = jax.jit(lambda tree: lpd(tree.constrain(), D))
350350

351351
adapt = blackjax.window_adaptation(

docs/examples/collapsed_vi.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,10 @@
106106
# this, it is intractable to evaluate.
107107

108108
# %%
109-
meanf = gpx.Constant()
110-
kernel = gpx.RBF()
111-
likelihood = gpx.Gaussian(num_datapoints=D.n)
112-
prior = gpx.Prior(mean_function=meanf, kernel=kernel)
109+
meanf = gpx.mean_functions.Constant()
110+
kernel = gpx.kernels.RBF()
111+
likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)
112+
prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
113113
posterior = prior * likelihood
114114

115115
# %% [markdown]
@@ -119,15 +119,17 @@
119119
# inducing points into the constructor as arguments.
120120

121121
# %%
122-
q = gpx.CollapsedVariationalGaussian(posterior=posterior, inducing_inputs=z)
122+
q = gpx.variational_families.CollapsedVariationalGaussian(
123+
posterior=posterior, inducing_inputs=z
124+
)
123125

124126
# %% [markdown]
125127
# We define our variational inference algorithm through `CollapsedVI`. This defines
126128
# the collapsed variational free energy bound considered in
127129
# <strong data-cite="titsias2009">Titsias (2009)</strong>.
128130

129131
# %%
130-
elbo = gpx.CollapsedELBO(negative=True)
132+
elbo = gpx.objectives.CollapsedELBO(negative=True)
131133

132134
# %% [markdown]
133135
# For researchers, GPJax has the capacity to print the bibtex citation for objects such
@@ -241,14 +243,14 @@
241243
# full model.
242244

243245
# %%
244-
full_rank_model = gpx.Prior(mean_function=gpx.Zero(), kernel=gpx.RBF()) * gpx.Gaussian(
245-
num_datapoints=D.n
246-
)
247-
negative_mll = jit(gpx.ConjugateMLL(negative=True).step)
246+
full_rank_model = gpx.gps.Prior(
247+
mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF()
248+
) * gpx.likelihoods.Gaussian(num_datapoints=D.n)
249+
negative_mll = jit(gpx.objectives.ConjugateMLL(negative=True).step)
248250
# %timeit negative_mll(full_rank_model, D).block_until_ready()
249251

250252
# %%
251-
negative_elbo = jit(gpx.CollapsedELBO(negative=True).step)
253+
negative_elbo = jit(gpx.objectives.CollapsedELBO(negative=True).step)
252254
# %timeit negative_elbo(q, D).block_until_ready()
253255

254256
# %% [markdown]

0 commit comments

Comments (0)