Merge pull request #148 from JaxGaussianProcesses/More_Kernels
This PR adds additional kernels and introduces the notion of a "compute_engine" for performing kernel operations, which will in future provide the foundation for alternative matrix-solving algorithms, e.g., conjugate gradients.
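As a rough illustration of the compute-engine idea described above, the sketch below uses hypothetical names only (it is not the interface introduced by this PR): kernel code builds Gram and cross-covariance matrices, while a swappable engine decides how linear algebra against those matrices is carried out, so that a conjugate-gradient engine could later replace a dense Cholesky one without touching the kernel code.

```python
import jax
import jax.numpy as jnp

def gram(kernel_fn, params, x):
    """Dense Gram matrix K[i, j] = kernel_fn(params, x[i], x[j])."""
    return jax.vmap(lambda xi: jax.vmap(lambda xj: kernel_fn(params, xi, xj))(x))(x)

def cross_covariance(kernel_fn, params, x, y):
    """Dense cross-covariance K[i, j] = kernel_fn(params, x[i], y[j])."""
    return jax.vmap(lambda xi: jax.vmap(lambda yj: kernel_fn(params, xi, yj))(y))(x)

class DenseCholeskyEngine:
    """Solve K v = b via a Cholesky factorisation of the (jittered) Gram matrix."""

    def solve(self, K, b, jitter=1e-6):
        L = jnp.linalg.cholesky(K + jitter * jnp.eye(K.shape[0]))
        return jnp.linalg.solve(L.T, jnp.linalg.solve(L, b))
```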
examples/classification.pct.py (24 additions & 24 deletions)
@@ -9,7 +9,7 @@
# format_version: '1.3'
# jupytext_version: 1.11.2
# kernelspec:
-# display_name: base
+# display_name: Python 3.9.7 ('gpjax')
# language: python
# name: python3
# ---
@@ -19,7 +19,7 @@
#
# In this notebook we demonstrate how to perform inference for Gaussian process models with non-Gaussian likelihoods via maximum a posteriori (MAP) and Markov chain Monte Carlo (MCMC). We focus on a classification task here and use [BlackJax](https://github.com/blackjax-devs/blackjax/) for sampling.

-# %% vscode={"languageId": "python"}
+# %%
import blackjax
import distrax as dx
import jax
@@ -47,7 +47,7 @@
#
# We store our data $\mathcal{D}$ as a GPJax `Dataset` and create test inputs for later.
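For concreteness, such a dataset might be constructed as follows; the toy inputs and labels here are purely illustrative stand-ins for the notebook's own data-generating code, while `gpx.Dataset(X=..., y=...)` is the standard GPJax container.

```python
import jax.numpy as jnp
import jax.random as jr
import gpjax as gpx

# Illustrative binary classification data (not the notebook's own generator).
key = jr.PRNGKey(123)
x = jnp.sort(jr.uniform(key, shape=(100, 1), minval=-1.0, maxval=1.0), axis=0)
y = jnp.where(jnp.cos(4.0 * x) + 0.15 * jr.normal(key, shape=x.shape) > 0.0, 1.0, 0.0)

D = gpx.Dataset(X=x, y=y)                                # training data
xtest = jnp.linspace(-1.05, 1.05, 500).reshape(-1, 1)    # test inputs for later
```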
# We begin by defining a Gaussian process prior with a radial basis function (RBF) kernel, chosen for the purpose of exposition. Since our observations are binary, we choose a Bernoulli likelihood with a probit link function.
-# %% vscode={"languageId": "python"}
+# %%
kernel = gpx.RBF()
prior = gpx.Prior(kernel=kernel)
likelihood = gpx.Bernoulli(num_datapoints=D.n)

# %% [markdown]
# We construct the posterior through the product of our prior and likelihood.

-# %% vscode={"languageId": "python"}
+# %%
posterior = prior * likelihood
print(type(posterior))
@@ -79,7 +79,7 @@
# %% [markdown]
# To begin we obtain an initial parameter state through the `initialise` callable (see the [regression notebook](https://gpjax.readthedocs.io/en/latest/nbs/regression.html)). We can obtain a MAP estimate by optimising the marginal log-likelihood with Optax's optimisers.
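A minimal sketch of such an optimisation is given below using Optax directly. The `objective` callable stands in for the negative marginal log-likelihood / unnormalised log-posterior produced by GPJax; the exact accessor for that objective differs between GPJax versions, so it is treated here as an assumption rather than the notebook's actual call.

```python
import jax
import optax as ox

# Hedged sketch: plain Optax loop over an assumed `objective(params)` returning
# a scalar loss (e.g. the negative marginal log-likelihood).
def map_fit(objective, params, n_iters=500, learning_rate=0.01):
    optimiser = ox.adam(learning_rate)
    opt_state = optimiser.init(params)

    @jax.jit
    def step(params, opt_state):
        loss, grads = jax.value_and_grad(objective)(params)
        updates, opt_state = optimiser.update(grads, opt_state)
        params = ox.apply_updates(params, updates)
        return params, opt_state, loss

    for _ in range(n_iters):
        params, opt_state, _ = step(params, opt_state)
    return params
```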
# that we identify as a Gaussian distribution, $p(\boldsymbol{f}| \mathcal{D}) \approx q(\boldsymbol{f}) := \mathcal{N}(\hat{\boldsymbol{f}}, [-\nabla^2 \tilde{p}(\boldsymbol{y}|\boldsymbol{f})|_{\hat{\boldsymbol{f}}} ]^{-1} )$. Since the negative Hessian is positive definite, we can use the Cholesky decomposition to obtain the covariance matrix of the Laplace approximation at the datapoints below.
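The Cholesky-based inversion mentioned above can be sketched as follows; `neg_hessian` is an assumed name for the matrix $-\nabla^2 \tilde{p}(\boldsymbol{y}|\boldsymbol{f})|_{\hat{\boldsymbol{f}}}$, which is positive definite by construction.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

# Hedged sketch: invert the negative Hessian via a Cholesky factorisation
# rather than forming an explicit matrix inverse.
def laplace_covariance(neg_hessian, jitter=1e-6):
    n = neg_hessian.shape[0]
    chol = cho_factor(neg_hessian + jitter * jnp.eye(n), lower=True)
    return cho_solve(chol, jnp.eye(n))   # (-∇² log p̃)^{-1}, the Laplace covariance
```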
# This is the same approximate distribution $q_{map}(f(\cdot))$, but we have perturbed the covariance by a curvature term of $\mathbf{K}_{(\cdot)\boldsymbol{x}} \mathbf{K}_{\boldsymbol{xx}}^{-1} [-\nabla^2 \tilde{p}(\boldsymbol{y}|\boldsymbol{f})|_{\hat{\boldsymbol{f}}} ]^{-1} \mathbf{K}_{\boldsymbol{xx}}^{-1} \mathbf{K}_{\boldsymbol{x}(\cdot)}$. We take the latent distribution computed in the previous section and add this term to the covariance to construct $q_{Laplace}(f(\cdot))$.
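A sketch of that correction term is given below with illustrative names (not GPJax API): `Kxt` is the train/test cross-covariance $\mathbf{K}_{\boldsymbol{x}(\cdot)}$, `Kxx` the training Gram matrix, and `H_inv` the Laplace covariance from the previous step.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

# Hedged sketch: add K_(·)x Kxx^{-1} H_inv Kxx^{-1} K_x(·) to the latent
# predictive covariance, using a Cholesky solve instead of explicit inverses.
def laplace_corrected_covariance(latent_cov, Kxx, Kxt, H_inv, jitter=1e-6):
    n = Kxx.shape[0]
    chol = cho_factor(Kxx + jitter * jnp.eye(n), lower=True)
    A = cho_solve(chol, Kxt)              # Kxx^{-1} K_x(·), shape (n, m)
    return latent_cov + A.T @ H_inv @ A   # curvature-perturbed covariance
```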
# We begin by generating _sensible_ initial positions for our sampler before defining an inference loop and sampling 500 values from our Markov chain. In practice, drawing more samples will be necessary.
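The inference loop referred to here (and in the `one_step` hunk further down) typically follows BlackJax's standard scan pattern; the sketch below assumes `step_fn(rng_key, state)` is a BlackJax kernel returning an updated state and an info record.

```python
import jax

# Hedged sketch of a scan-based BlackJax inference loop.
def inference_loop(rng_key, step_fn, initial_state, num_samples):
    def one_step(state, rng_key):
        state, info = step_fn(rng_key, state)
        return state, (state, info)

    keys = jax.random.split(rng_key, num_samples)
    _, (states, infos) = jax.lax.scan(one_step, initial_state, keys)
    return states, infos
```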
# BlackJax gives us easy access to our sampler's efficiency through metrics such as the sampler's _acceptance probability_ (the number of times that our chain accepted a proposed sample, divided by the total number of steps run by the chain). For NUTS and Hamiltonian Monte Carlo sampling, we typically seek an acceptance rate of 60-70% to strike the right balance between having a chain which is _stuck_ and rarely moves versus a chain that is too jumpy with frequent small steps.
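Continuing the sketch above, the metric can be read off the per-step info records; whether the NUTS info object exposes `acceptance_rate` or `acceptance_probability` depends on the BlackJax version, so the attribute name below is an assumption.

```python
import jax.numpy as jnp

# Assumed attribute name; average the per-step acceptance statistic.
acceptance = float(jnp.mean(infos.acceptance_rate))
print(f"Acceptance rate: {acceptance:.2f}")
```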
# Our acceptance rate is slightly too large, prompting an examination of the chain's trace plots. A well-mixing chain will have very few (if any) flat spots in its trace plot whilst also not having too many steps in the same direction. In addition to the model's hyperparameters, there will be 500 samples for each of the 100 latent function values in the `states.position` dictionary. We depict the chains that correspond to the model hyperparameters and the first value of the latent function for brevity.
# An ideal Markov chain would have samples completely uncorrelated with their neighbours after a single lag. However, in practice, correlations often exist within our chain's sample set. A commonly used technique to try and reduce this correlation is _thinning_ whereby we select every $n$th sample where $n$ is the minimum lag length at which we believe the samples are uncorrelated. Although further analysis of the chain's autocorrelation is required to find appropriate thinning factors, we employ a thin factor of 10 for demonstration purposes.
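A compact way to apply this thinning, assuming the draws are stacked along the leading axis of each leaf in `states.position` (as in the scan-based loop sketched earlier), is shown below.

```python
import jax.tree_util as jtu

# Hedged sketch: keep every 10th draw of every sampled quantity.
thin_factor = 10
thinned_position = jtu.tree_map(lambda leaf: leaf[::thin_factor], states.position)
```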
-# %% vscode={"languageId": "python"}
+# %%
thin_factor = 10
samples = []
@@ -351,7 +351,7 @@ def one_step(state, rng_key):
#
# Finally, we end this tutorial by plotting the predictions obtained from our model against the observed data.
# Enable Float64 for more stable matrix inversions.
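The comment above refers to JAX's double-precision switch, which is enabled with the standard configuration call:

```python
import jax

# Run all JAX computation in 64-bit floats for more stable matrix inversions.
jax.config.update("jax_enable_x64", True)
```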
@@ -79,16 +82,23 @@
# Although deep kernels are not currently supported natively in GPJax, defining one is straightforward as we now demonstrate. Using the base `AbstractKernel` object given in GPJax, we provide a mixin class named `_DeepKernelFunction` to facilitate the user supplying the neural network and base kernel of their choice. Kernel matrices are then computed using the regular `gram` and `cross_covariance` functions.
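The deep-kernel construction described here can be sketched generically as below; the signatures are illustrative placeholders, not GPJax's `AbstractKernel` / `_DeepKernelFunction` API: the inputs are warped through a network and a base kernel is evaluated on the warped inputs.

```python
# Hedged sketch of a deep kernel: k_deep(x, y) = k_base(g(x), g(y)),
# where g is a neural network feature map.
def deep_kernel(base_kernel, network):
    def k(params, x, y):
        xt = network(params["network"], x)          # feature-space representation of x
        yt = network(params["network"], y)          # feature-space representation of y
        return base_kernel(params["base_kernel"], xt, yt)
    return k
```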