|
| 1 | +{ |
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "cell_type": "markdown", |
| 5 | + "metadata": {}, |
| 6 | + "source": [ |
| 7 | + "# Technical Details\n", |
| 8 | + "\n", |
| 9 | + "As with all packages, there are numerous technical details that are abstracted away from the user. Now in order to ensure a clean interface, this abstraction is entirely necessary. However, it can sometimes be confusing when navigating a package's source code to pin down what's going on when there's so many _under the hood_ operations taking place. In this notebook I'll aim to shed some light on all of tricks that we do in GPJax in order to help elucidate the code to anyone wishing to extend GPJax for their own uses.\n" |
| 10 | + ] |
| 11 | + }, |
| 12 | + { |
| 13 | + "cell_type": "markdown", |
| 14 | + "metadata": {}, |
| 15 | + "source": [ |
| 16 | + "## Parameter Transformations" |
| 17 | + ] |
| 18 | + }, |
| 19 | + { |
| 20 | + "cell_type": "markdown", |
| 21 | + "metadata": {}, |
| 22 | + "source": [ |
| 23 | + "### Motivations\n", |
| 24 | + "\n", |
| 25 | + "Many parameters in a Gaussian process are what we call a _constrained parameter_. By this, we mean that the parameters value is only defined on a subset of $\\mathbb{R}$. One example of this is the lengthscale parameter in any of the stationary kernels. It would not make sense to have a negative lengthscale, and as such the parameter's value is constrained to exist only on the positive real line. \n", |
| 26 | + "\n", |
| 27 | + "Whilst mathematically correct, constrained parameters can become a pain when optimising as many optimisers are designed to operate on an unconstrained space. Further, it can often be computationally inefficient to restrict the search space of an optimiser. For these reasons, we instead transform the constrained parameter to exist in an unconstrained space. Optimisation is then done on this unconstrained parameter before we transform it back when we need to evaluate its value. \n", |
| 28 | + "\n", |
| 29 | + "Only bijective transformations are valid as we cannot afford to lose our original parameter value when transforming. As such, we have to be careful about which transformations we use. Some common choices include the log-exponential bijection and the softplus transform. We, by default, opt for the softplus transformation in GPJax as it less prone to overflowing in comparison to log-exp transformations.\n" |
| 30 | + ] |
| 31 | + }, |
| 32 | + { |
| 33 | + "cell_type": "markdown", |
| 34 | + "metadata": {}, |
| 35 | + "source": [ |
| 36 | + "### Implementation\n", |
| 37 | + "\n", |
| 38 | + "When it comes to implementations, we attach the transformation directly to the `Parameter` class. It is an optional argument that one can specify when instantiating their parameter. To see this, simply consider the following example" |
| 39 | + ] |
| 40 | + }, |
| 41 | + { |
| 42 | + "cell_type": "code", |
| 43 | + "execution_count": 6, |
| 44 | + "metadata": {}, |
| 45 | + "outputs": [], |
| 46 | + "source": [ |
| 47 | + "from gpjax.parameters import Parameter\n", |
| 48 | + "from gpjax.transforms import Softplus\n", |
| 49 | + "import jax.numpy as jnp\n", |
| 50 | + "\n", |
| 51 | + "x = Parameter(jnp.array(1.0), transform=Softplus())" |
| 52 | + ] |
| 53 | + }, |
| 54 | + { |
| 55 | + "cell_type": "markdown", |
| 56 | + "metadata": {}, |
| 57 | + "source": [ |
| 58 | + "Now we know that the softplus transformation operation on an input $x \\in \\mathbb{R}_{>0}$ can be written as \n", |
| 59 | + "$$\\alpha(x) = \\log(\\exp(x)-1)$$\n", |
| 60 | + "where $\\alpha(x) \\in \\mathbb{R}$. In this instance, it can be seen that $\\alpha(1)=0.54$. Now this unconstrained value is stored within the parameter's `value` property" |
| 61 | + ] |
| 62 | + }, |
| 63 | + { |
| 64 | + "cell_type": "code", |
| 65 | + "execution_count": 7, |
| 66 | + "metadata": {}, |
| 67 | + "outputs": [ |
| 68 | + { |
| 69 | + "name": "stdout", |
| 70 | + "output_type": "stream", |
| 71 | + "text": [ |
| 72 | + "0.541324854612918\n" |
| 73 | + ] |
| 74 | + } |
| 75 | + ], |
| 76 | + "source": [ |
| 77 | + "print(x.value)" |
| 78 | + ] |
| 79 | + }, |
| 80 | + { |
| 81 | + "cell_type": "markdown", |
| 82 | + "metadata": {}, |
| 83 | + "source": [ |
| 84 | + "whilst the original constrained value can be computed by accesing the parameter's `untransform` property" |
| 85 | + ] |
| 86 | + }, |
| 87 | + { |
| 88 | + "cell_type": "code", |
| 89 | + "execution_count": 8, |
| 90 | + "metadata": {}, |
| 91 | + "outputs": [ |
| 92 | + { |
| 93 | + "name": "stdout", |
| 94 | + "output_type": "stream", |
| 95 | + "text": [ |
| 96 | + "1.0\n" |
| 97 | + ] |
| 98 | + } |
| 99 | + ], |
| 100 | + "source": [ |
| 101 | + "print(x.untransform)" |
| 102 | + ] |
| 103 | + }, |
| 104 | + { |
| 105 | + "cell_type": "markdown", |
| 106 | + "metadata": {}, |
| 107 | + "source": [ |
| 108 | + "### Custom transformation\n", |
| 109 | + "\n", |
| 110 | + "Should you wish to define your own custom transformation, then this can very easily be done by simply extending the `Transform` class within `gpjax.transforms` and defining a forward transformation and a backward transformation." |
| 111 | + ] |
| 112 | + }, |
| 113 | + { |
| 114 | + "cell_type": "code", |
| 115 | + "execution_count": 9, |
| 116 | + "metadata": {}, |
| 117 | + "outputs": [], |
| 118 | + "source": [ |
| 119 | + "class Transform:\n", |
| 120 | + " def __init__(self, name=\"Transformation\"):\n", |
| 121 | + " self.name = name\n", |
| 122 | + "\n", |
| 123 | + " @staticmethod\n", |
| 124 | + " def forward(x: jnp.ndarray) -> jnp.ndarray:\n", |
| 125 | + " raise NotImplementedError\n", |
| 126 | + "\n", |
| 127 | + " @staticmethod\n", |
| 128 | + " def backward(x: jnp.ndarray) -> jnp.ndarray:\n", |
| 129 | + " raise NotImplementedError" |
| 130 | + ] |
| 131 | + }, |
| 132 | + { |
| 133 | + "cell_type": "markdown", |
| 134 | + "metadata": {}, |
| 135 | + "source": [ |
| 136 | + "The `forward` method is the transformation that maps from a constrained space to an unconstrained space, whilst the `backward` method is the transformation that reverses this. A nice example of this can be seen for the earlier used softplus transformation" |
| 137 | + ] |
| 138 | + }, |
| 139 | + { |
| 140 | + "cell_type": "code", |
| 141 | + "execution_count": 10, |
| 142 | + "metadata": {}, |
| 143 | + "outputs": [], |
| 144 | + "source": [ |
| 145 | + "from jax.nn import softplus\n", |
| 146 | + "\n", |
| 147 | + "class Softplus(Transform):\n", |
| 148 | + " def __init__(self):\n", |
| 149 | + " super().__init__(name='Softplus')\n", |
| 150 | + "\n", |
| 151 | + " @staticmethod\n", |
| 152 | + " def forward(x: jnp.ndarray) -> jnp.ndarray:\n", |
| 153 | + " return jnp.log(jnp.exp(x) - 1.)\n", |
| 154 | + "\n", |
| 155 | + " @staticmethod\n", |
| 156 | + " def backward(x: jnp.ndarray) -> jnp.ndarray:\n", |
| 157 | + " return softplus(x)" |
| 158 | + ] |
| 159 | + }, |
| 160 | + { |
| 161 | + "cell_type": "markdown", |
| 162 | + "metadata": {}, |
| 163 | + "source": [ |
| 164 | + "## Prior distributions" |
| 165 | + ] |
| 166 | + }, |
| 167 | + { |
| 168 | + "cell_type": "markdown", |
| 169 | + "metadata": {}, |
| 170 | + "source": [ |
| 171 | + "### Motivations\n", |
| 172 | + "\n", |
| 173 | + "Often when we use Gaussian processes, we do so as they facilitate easily incorporation of prior information into the model. Implicitly, by the very use of a Gaussian process we are incorporating our prior inforation around the functional behaviour of the latent function that we are seeking to recover. However, we can take this one step further by placing priors on the hyperparameters of the Gaussian process. Going into the details of which priors are recommended and how to go about selecting them goes beyond the scope of this article, but it's suffice to say that doing so can greatly enhance the utility of a Gaussian process. \n", |
| 174 | + "\n", |
| 175 | + "At least in my own experience, when priors are placed on the hyperparameters of a Gaussian process they are specified with respect to the constrained parameter value. As an example of this, consider the lengthscale parameter $\\ell \\in \\mathbb{R}_{>0}$. When specifying a prior distribution $p_{0}(\\ell)$, I would typically select a distribution that has support on the positive real line, such as the Gamma distribution. An opposing approach would be to transform the parameter so that it is defined on the entire real line (as discussed in §1) and then specify a prior distribution such as a Gaussian that has an unconstrained support. Deciding which of these two approaches to adopt in GPJax is somewhat a moot point to me, so I've opted for priors to be defined on the constrained parameter. That being said, I'd be more than open to altering this opinion is people felt strongly that priors should be defined on the unconstrained parameter value." |
| 176 | + ] |
| 177 | + }, |
| 178 | + { |
| 179 | + "cell_type": "markdown", |
| 180 | + "metadata": {}, |
| 181 | + "source": [ |
| 182 | + "### Implementation\n", |
| 183 | + "\n", |
| 184 | + "Regarding the implementational details of enabling prior specification, this is hopefully a more lucid concept upon code inspection. As with the earlier discussed parameter transformations, the notion of a prior distribution is acknolwedged in the definition of a parameter. To exactly specify a prior distribution, one should simply call in the relevant distribution from TensorFlow probability's distributions module. For an example of this, consider the parameter `x` that was earlier defined." |
| 185 | + ] |
| 186 | + }, |
| 187 | + { |
| 188 | + "cell_type": "code", |
| 189 | + "execution_count": 22, |
| 190 | + "metadata": {}, |
| 191 | + "outputs": [], |
| 192 | + "source": [ |
| 193 | + "from tensorflow_probability.substrates.jax import distributions as tfd\n", |
| 194 | + "\n", |
| 195 | + "x.prior = tfd.Gamma(concentration = 3., rate = 2.)" |
| 196 | + ] |
| 197 | + }, |
| 198 | + { |
| 199 | + "cell_type": "markdown", |
| 200 | + "metadata": {}, |
| 201 | + "source": [ |
| 202 | + "If we momentarily pause to consider the state of this parameter now, then we have a constrained parameter value with corresponding prior distribution. When it comes to deriving our posterior distribution, then we know that it is proportional to the product of the likelihood and the prior density function. As addition is less prone to numerical overflow than multiplication, we take the log of this produce. The log of a product is just a sum of logs, meaning that our log-posterior is then proportional to the sum of our log-likelihood and the log-prior density. Therefore, to connect the value of our parameter and its respective prior distribution, the only implementational point left to cover is how to evaluate the parameters log-prior density. This can be done through the following `@property`" |
| 203 | + ] |
| 204 | + }, |
| 205 | + { |
| 206 | + "cell_type": "code", |
| 207 | + "execution_count": 24, |
| 208 | + "metadata": {}, |
| 209 | + "outputs": [ |
| 210 | + { |
| 211 | + "name": "stdout", |
| 212 | + "output_type": "stream", |
| 213 | + "text": [ |
| 214 | + "-0.613706111907959\n" |
| 215 | + ] |
| 216 | + } |
| 217 | + ], |
| 218 | + "source": [ |
| 219 | + "print(x.log_density)" |
| 220 | + ] |
| 221 | + }, |
| 222 | + { |
| 223 | + "cell_type": "markdown", |
| 224 | + "metadata": {}, |
| 225 | + "source": [ |
| 226 | + "Naturally, should one wish to evaluate the prior density of the parameter, then the exponent can be taken" |
| 227 | + ] |
| 228 | + }, |
| 229 | + { |
| 230 | + "cell_type": "code", |
| 231 | + "execution_count": 26, |
| 232 | + "metadata": {}, |
| 233 | + "outputs": [ |
| 234 | + { |
| 235 | + "name": "stdout", |
| 236 | + "output_type": "stream", |
| 237 | + "text": [ |
| 238 | + "0.5413408768770793\n" |
| 239 | + ] |
| 240 | + } |
| 241 | + ], |
| 242 | + "source": [ |
| 243 | + "print(jnp.exp(x.log_density))" |
| 244 | + ] |
| 245 | + }, |
| 246 | + { |
| 247 | + "cell_type": "markdown", |
| 248 | + "metadata": {}, |
| 249 | + "source": [ |
| 250 | + "## Cholesky decomposition\n", |
| 251 | + "\n" |
| 252 | + ] |
| 253 | + }, |
| 254 | + { |
| 255 | + "cell_type": "code", |
| 256 | + "execution_count": null, |
| 257 | + "metadata": {}, |
| 258 | + "outputs": [], |
| 259 | + "source": [] |
| 260 | + } |
| 261 | + ], |
| 262 | + "metadata": { |
| 263 | + "kernelspec": { |
| 264 | + "display_name": "gpblocks", |
| 265 | + "language": "python", |
| 266 | + "name": "gpblocks" |
| 267 | + }, |
| 268 | + "language_info": { |
| 269 | + "codemirror_mode": { |
| 270 | + "name": "ipython", |
| 271 | + "version": 3 |
| 272 | + }, |
| 273 | + "file_extension": ".py", |
| 274 | + "mimetype": "text/x-python", |
| 275 | + "name": "python", |
| 276 | + "nbconvert_exporter": "python", |
| 277 | + "pygments_lexer": "ipython3", |
| 278 | + "version": "3.8.5" |
| 279 | + }, |
| 280 | + "toc-autonumbering": false, |
| 281 | + "toc-showcode": false, |
| 282 | + "toc-showmarkdowntxt": false, |
| 283 | + "toc-showtags": false |
| 284 | + }, |
| 285 | + "nbformat": 4, |
| 286 | + "nbformat_minor": 4 |
| 287 | +} |
0 commit comments