
Commit c50d34d

Merge pull request #148 from JaxGaussianProcesses/More_Kernels
This PR adds additional kernels and introduces the notion of a "compute_engine" for performing kernel operations, which will in future lay the foundation for alternative matrix-solving algorithms, e.g., conjugate gradients.
2 parents ebd6cb7 + e5c339a commit c50d34d

19 files changed: +897 additions, -466 deletions
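The snippet below is an editorial sketch, not part of the commit: it illustrates the compute-engine pattern the changed notebooks adopt, using the `kernel.compute_engine.gram` and `kernel.gram(params=..., inputs=...)` calls that appear in the `examples/graph_kernels.pct.py` diff. The toy inputs and the use of `gpx.RBF` (with the `_initialise_params` helper shown for the graph kernel) are assumptions for illustration only.

# Hedged sketch of the compute_engine-backed kernel API introduced by this PR.
# Attribute and method names mirror the diff below; the data is illustrative.
import jax.numpy as jnp
import jax.random as jr

import gpjax as gpx

key = jr.PRNGKey(123)
x = jnp.sort(jr.uniform(key, shape=(20, 1)), axis=0)  # toy inputs (assumption)

kernel = gpx.RBF()

# Each kernel now carries a compute engine (DenseKernelComputation by default)
# that owns its gram/cross-covariance routines.
print(kernel.compute_engine.gram)

# Gram calls take the kernel's parameters and the inputs; the kernel object is
# no longer passed as an explicit leading argument.
Kxx = kernel.gram(params=kernel._initialise_params(key), inputs=x)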

.pre-commit-config.yaml

Lines changed: 9 additions & 1 deletion
@@ -19,4 +19,12 @@ repos:
       - id: nbqa-pyupgrade
         args: [--py37-plus]
       - id: nbqa-flake8
-        args: ['--ignore=E501,E203,E302,E402,E731,W503']
+        args: ['--ignore=E501,E203,E302,E402,E731,W503']
+  - repo: https://github.com/PyCQA/autoflake
+    rev: v2.0.0
+    hooks:
+      - id: autoflake
+        args: ["--in-place", "--remove-unused-variables", "--remove-all-unused-imports", "--recursive"]
+        name: AutoFlake
+        description: "Format with AutoFlake"
+        stages: [commit]

examples/classification.pct.py

Lines changed: 24 additions & 24 deletions
@@ -9,7 +9,7 @@
 #       format_version: '1.3'
 #       jupytext_version: 1.11.2
 #   kernelspec:
-#     display_name: base
+#     display_name: Python 3.9.7 ('gpjax')
 #     language: python
 #     name: python3
 # ---
@@ -19,7 +19,7 @@
 #
 # In this notebook we demonstrate how to perform inference for Gaussian process models with non-Gaussian likelihoods via maximum a posteriori (MAP) and Markov chain Monte Carlo (MCMC). We focus on a classification task here and use [BlackJax](https://github.com/blackjax-devs/blackjax/) for sampling.

-# %% vscode={"languageId": "python"}
+# %%
 import blackjax
 import distrax as dx
 import jax
@@ -47,7 +47,7 @@
 #
 # We store our data $\mathcal{D}$ as a GPJax `Dataset` and create test inputs for later.

-# %% vscode={"languageId": "python"}
+# %%
 x = jnp.sort(jr.uniform(key, shape=(100, 1), minval=-1.0, maxval=1.0), axis=0)
 y = 0.5 * jnp.sign(jnp.cos(3 * x + jr.normal(key, shape=x.shape) * 0.05)) + 0.5

@@ -61,15 +61,15 @@
 #
 # We begin by defining a Gaussian process prior with a radial basis function (RBF) kernel, chosen for the purpose of exposition. Since our observations are binary, we choose a Bernoulli likelihood with a probit link function.

-# %% vscode={"languageId": "python"}
+# %%
 kernel = gpx.RBF()
 prior = gpx.Prior(kernel=kernel)
 likelihood = gpx.Bernoulli(num_datapoints=D.n)

 # %% [markdown]
 # We construct the posterior through the product of our prior and likelihood.

-# %% vscode={"languageId": "python"}
+# %%
 posterior = prior * likelihood
 print(type(posterior))

@@ -79,7 +79,7 @@
 # %% [markdown]
 # To begin we obtain an initial parameter state through the `initialise` callable (see the [regression notebook](https://gpjax.readthedocs.io/en/latest/nbs/regression.html)). We can obtain a MAP estimate by optimising the marginal log-likelihood with Optax's optimisers.

-# %% vscode={"languageId": "python"}
+# %%
 parameter_state = gpx.initialise(posterior)
 negative_mll = jax.jit(posterior.marginal_log_likelihood(D, negative=True))

@@ -97,7 +97,7 @@
 # %% [markdown]
 # From which we can make predictions at novel inputs, as illustrated below.

-# %% vscode={"languageId": "python"}
+# %%
 map_latent_dist = posterior(map_estimate, D)(xtest)

 predictive_dist = likelihood(map_estimate, map_latent_dist)
@@ -158,15 +158,15 @@
 #
 # that we identify as a Gaussian distribution, $p(\boldsymbol{f}| \mathcal{D}) \approx q(\boldsymbol{f}) := \mathcal{N}(\hat{\boldsymbol{f}}, [-\nabla^2 \tilde{p}(\boldsymbol{y}|\boldsymbol{f})|_{\hat{\boldsymbol{f}}} ]^{-1} )$. Since the negative Hessian is positive definite, we can use the Cholesky decomposition to obtain the covariance matrix of the Laplace approximation at the datapoints below.

-# %% vscode={"languageId": "python"}
+# %%
 gram, cross_covariance = (kernel.gram, kernel.cross_covariance)
 jitter = 1e-6

 # Compute (latent) function value map estimates at training points:
-Kxx = gram(kernel, map_estimate["kernel"], x)
+Kxx = gram(map_estimate["kernel"], x)
 Kxx += I(D.n) * jitter
-Lx = Kxx.triangular_lower()
-f_hat = jnp.matmul(Lx, map_estimate["latent"])
+Lx = Kxx.to_root()
+f_hat = Lx @ map_estimate["latent"]

 # Negative Hessian, H = -∇²p_tilde(y|f):
 H = jax.jacfwd(jax.jacrev(negative_mll))(map_estimate)["latent"]["latent"][:, 0, :, 0]
@@ -190,21 +190,21 @@
 #
 # This is the same approximate distribution $q_{map}(f(\cdot))$, but we have perturbed the covariance by a curvature term of $\mathbf{K}_{\boldsymbol{(\cdot)\boldsymbol{x}}} \mathbf{K}_{\boldsymbol{xx}}^{-1} [-\nabla^2 \tilde{p}(\boldsymbol{y}|\boldsymbol{f})|_{\hat{\boldsymbol{f}}} ]^{-1} \mathbf{K}_{\boldsymbol{xx}}^{-1} \mathbf{K}_{\boldsymbol{\boldsymbol{x}(\cdot)}}$. We take the latent distribution computed in the previous section and add this term to the covariance to construct $q_{Laplace}(f(\cdot))$.

-# %% vscode={"languageId": "python"}
+# %%
 def construct_laplace(test_inputs: Float[Array, "N D"]) -> dx.MultivariateNormalTri:

     map_latent_dist = posterior(map_estimate, D)(test_inputs)

-    Kxt = cross_covariance(kernel, map_estimate["kernel"], x, test_inputs)
-    Kxx = gram(kernel, map_estimate["kernel"], x)
+    Kxt = cross_covariance(map_estimate["kernel"], x, test_inputs)
+    Kxx = gram(map_estimate["kernel"], x)
     Kxx += I(D.n) * jitter
-    Lx = Kxx.triangular_lower()
+    Lx = Kxx.to_root()

     # Lx⁻¹ Kxt
-    Lx_inv_Ktx = jsp.linalg.solve_triangular(Lx, Kxt, lower=True)
+    Lx_inv_Ktx = Lx.solve(Kxt)

     # Kxx⁻¹ Kxt
-    Kxx_inv_Ktx = jsp.linalg.solve_triangular(Lx.T, Lx_inv_Ktx, lower=False)
+    Kxx_inv_Ktx = Lx.T.solve(Lx_inv_Ktx)

     # Ktx Kxx⁻¹[ H⁻¹ ] Kxx⁻¹ Kxt
     laplace_cov_term = jnp.matmul(jnp.matmul(Kxx_inv_Ktx.T, H_inv), Kxx_inv_Ktx)
@@ -217,7 +217,7 @@ def construct_laplace(test_inputs: Float[Array, "N D"]) -> dx.MultivariateNormal

 # %% [markdown]
 # From this we can construct the predictive distribution at the test points.
-# %% vscode={"languageId": "python"}
+# %%
 laplace_latent_dist = construct_laplace(xtest)
 predictive_dist = likelihood(map_estimate, laplace_latent_dist)

@@ -267,7 +267,7 @@ def construct_laplace(test_inputs: Float[Array, "N D"]) -> dx.MultivariateNormal
 #
 # We begin by generating _sensible_ initial positions for our sampler before defining an inference loop and sampling 500 values from our Markov chain. In practice, drawing more samples will be necessary.

-# %% vscode={"languageId": "python"}
+# %%
 # Adapted from BlackJax's introduction notebook.
 num_adapt = 500
 num_samples = 500
@@ -304,14 +304,14 @@ def one_step(state, rng_key):
 #
 # BlackJax gives us easy access to our sampler's efficiency through metrics such as the sampler's _acceptance probability_ (the number of times that our chain accepted a proposed sample, divided by the total number of steps run by the chain). For NUTS and Hamiltonian Monte Carlo sampling, we typically seek an acceptance rate of 60-70% to strike the right balance between having a chain which is _stuck_ and rarely moves versus a chain that is too jumpy with frequent small steps.

-# %% vscode={"languageId": "python"}
+# %%
 acceptance_rate = jnp.mean(infos.acceptance_probability)
 print(f"Acceptance rate: {acceptance_rate:.2f}")

 # %% [markdown]
 # Our acceptance rate is slightly too large, prompting an examination of the chain's trace plots. A well-mixing chain will have very few (if any) flat spots in its trace plot whilst also not having too many steps in the same direction. In addition to the model's hyperparameters, there will be 500 samples for each of the 100 latent function values in the `states.position` dictionary. We depict the chains that correspond to the model hyperparameters and the first value of the latent function for brevity.

-# %% vscode={"languageId": "python"}
+# %%
 fig, (ax0, ax1, ax2) = plt.subplots(ncols=3, figsize=(15, 5), tight_layout=True)
 ax0.plot(states.position["kernel"]["lengthscale"])
 ax1.plot(states.position["kernel"]["variance"])
@@ -327,7 +327,7 @@ def one_step(state, rng_key):
 #
 # An ideal Markov chain would have samples completely uncorrelated with their neighbours after a single lag. However, in practice, correlations often exist within our chain's sample set. A commonly used technique to try and reduce this correlation is _thinning_ whereby we select every $n$th sample where $n$ is the minimum lag length at which we believe the samples are uncorrelated. Although further analysis of the chain's autocorrelation is required to find appropriate thinning factors, we employ a thin factor of 10 for demonstration purposes.

-# %% vscode={"languageId": "python"}
+# %%
 thin_factor = 10
 samples = []

@@ -351,7 +351,7 @@ def one_step(state, rng_key):
 #
 # Finally, we end this tutorial by plotting the predictions obtained from our model against the observed data.

-# %% vscode={"languageId": "python"}
+# %%
 fig, ax = plt.subplots(figsize=(16, 5), tight_layout=True)
 ax.plot(
     x, y, "o", markersize=5, color="tab:red", label="Observations", zorder=2, alpha=0.7
@@ -371,6 +371,6 @@ def one_step(state, rng_key):
 # %% [markdown]
 # ## System configuration

-# %% vscode={"languageId": "python"}
+# %%
 # %load_ext watermark
 # %watermark -n -u -v -iv -w -a "Thomas Pinder & Daniel Dodd"
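As an editorial aside (not part of the commit), the Laplace cells above can be read as one consolidated sketch of the new operator-style workflow. It assumes the notebook's earlier cells have run, so that `kernel`, `map_estimate`, `x`, `xtest`, `D`, `jitter`, and the identity helper `I` are already in scope; every call shown is taken from the diff above.

# Consolidated sketch of the new kernel-matrix workflow in this notebook.
# Previously: gram(kernel, params, x), Kxx.triangular_lower(), and explicit
# jsp.linalg.solve_triangular calls; now the Gram operator handles the solves.
gram, cross_covariance = (kernel.gram, kernel.cross_covariance)

Kxx = gram(map_estimate["kernel"], x)   # parameters and inputs only
Kxx += I(D.n) * jitter                  # jitter for numerical stability
Lx = Kxx.to_root()                      # Cholesky-style root operator

f_hat = Lx @ map_estimate["latent"]     # whitened values -> latent function values
Kxt = cross_covariance(map_estimate["kernel"], x, xtest)
Lx_inv_Kxt = Lx.solve(Kxt)              # Lx⁻¹ Kxt
Kxx_inv_Kxt = Lx.T.solve(Lx_inv_Kxt)    # two root solves give Kxx⁻¹ Kxt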

examples/graph_kernels.pct.py

Lines changed: 8 additions & 2 deletions
@@ -9,7 +9,7 @@
 #       format_version: '1.3'
 #       jupytext_version: 1.11.2
 #   kernelspec:
-#     display_name: base
+#     display_name: Python 3.9.7 ('gpjax')
 #     language: python
 #     name: python3
 # ---
@@ -55,7 +55,7 @@

 pos = nx.spring_layout(G, seed=123) # positions for all nodes

-nx.draw(G, pos, node_color="tab:blue", with_labels=False, alpha=0.5)
+nx.draw(G) # , pos, node_color="tab:blue", with_labels=False, alpha=0.5)

 # %% [markdown]
 #
@@ -95,6 +95,12 @@

 D = gpx.Dataset(X=x, y=y)

+# %%
+kernel.compute_engine.gram
+
+# %%
+kernel.gram(params=kernel._initialise_params(key), inputs=x)
+
 # %% [markdown]
 #
 # We can visualise this signal in the following cell.

examples/haiku.pct.py

Lines changed: 21 additions & 11 deletions
@@ -9,7 +9,7 @@
 #       format_version: '1.3'
 #       jupytext_version: 1.11.2
 #   kernelspec:
-#     display_name: base
+#     display_name: Python 3.9.7 ('gpjax')
 #     language: python
 #     name: python3
 # ---
@@ -28,15 +28,18 @@
 import jax.random as jr
 import matplotlib.pyplot as plt
 import optax as ox
-from chex import dataclass
 from jax.config import config
 from scipy.signal import sawtooth
 from jaxtyping import Float, Array
 from typing import Dict


 import gpjax as gpx
-from gpjax.kernels import DenseKernelComputation, AbstractKernel
+from gpjax.kernels import (
+    DenseKernelComputation,
+    AbstractKernelComputation,
+    AbstractKernel,
+)
 from gpjax.types import PRNGKeyType

 # Enable Float64 for more stable matrix inversions.
@@ -79,16 +82,23 @@
 # Although deep kernels are not currently supported natively in GPJax, defining one is straightforward as we now demonstrate. Using the base `AbstractKernel` object given in GPJax, we provide a mixin class named `_DeepKernelFunction` to facilitate the user supplying the neural network and base kernel of their choice. Kernel matrices are then computed using the regular `gram` and `cross_covariance` functions.

 # %%
-@dataclass
-class _DeepKernelFunction:
-    network: hk.Module
-    base_kernel: AbstractKernel
-
+class DeepKernelFunction(AbstractKernel):
+    def __init__(
+        self,
+        network: hk.Module,
+        base_kernel: AbstractKernel,
+        compute_engine: AbstractKernelComputation = DenseKernelComputation,
+        active_dims: tp.Optional[tp.List[int]] = None,
+    ) -> None:
+        super().__init__(compute_engine, active_dims, True, False, "Deep Kernel")
+        self.network = network
+        self.base_kernel = base_kernel

-@dataclass
-class DeepKernelFunction(AbstractKernel, DenseKernelComputation, _DeepKernelFunction):
     def __call__(
-        self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"],
+        self,
+        params: Dict,
+        x: Float[Array, "1 D"],
+        y: Float[Array, "1 D"],
     ) -> Float[Array, "1"]:
         xt = self.network.apply(params=params, x=x)
         yt = self.network.apply(params=params, x=y)
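The hunk above is truncated by the diff view after the two `network.apply` calls. A hypothetical usage sketch follows; only the constructor signature mirrors the diff, while the Haiku network definition, the chosen base kernel, and the closing `return` suggested in the final comment are assumptions, not part of the commit. It presumes the notebook's imports (`gpjax as gpx`, `haiku as hk`, `DenseKernelComputation`) are in scope.

# Hypothetical instantiation of the refactored DeepKernelFunction (assumed names).
def forward(x):
    # Small MLP feature extractor; architecture chosen purely for illustration.
    return hk.nets.MLP([32, 32, 1])(x)

network = hk.without_apply_rng(hk.transform(forward))

deep_kernel = DeepKernelFunction(
    network=network,
    base_kernel=gpx.RBF(),                  # any GPJax base kernel (assumption)
    compute_engine=DenseKernelComputation,  # default engine, per the new signature
)

# Inside __call__, one would typically finish by evaluating the base kernel on
# the transformed inputs, e.g. `return self.base_kernel(params, xt, yt)`.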
