|
70 | 70 | "\n", |
71 | 71 | "for k, ax in zip(kernels, axes.ravel()):\n", |
72 | 72 | " prior = gpx.Prior(kernel=k)\n", |
73 | | - " params, _, _, _ = gpx.initialise(prior)\n", |
| 73 | + " params, _, _, _ = gpx.initialise(prior, key).unpack()\n", |
74 | 74 | " rv = prior(params)(x)\n", |
75 | 75 | " y = rv.sample(sample_shape=10, seed=key)\n", |
76 | 76 | "\n", |
|
117 | 117 | "outputs": [], |
118 | 118 | "source": [ |
119 | 119 | "print(f\"ARD: {slice_kernel.ard}\")\n", |
120 | | - "print(f\"Lengthscales: {slice_kernel.params['lengthscale']}\")" |
| 120 | + "print(f\"Lengthscales: {slice_kernel._initialise_params(key)['lengthscale']}\")" |
121 | 121 | ] |
122 | 122 | }, |
123 | 123 | { |
|
136 | 136 | "outputs": [], |
137 | 137 | "source": [ |
138 | 138 | "x_matrix = jr.normal(key, shape=(50, 5))\n", |
139 | | - "K = gpx.kernels.gram(slice_kernel, x_matrix, slice_kernel.params)\n", |
| 139 | + "K = gpx.kernels.gram(slice_kernel, x_matrix, slice_kernel._initialise_params(key))\n", |
140 | 140 | "print(K.shape)" |
141 | 141 | ] |
142 | 142 | }, |
|
160 | 160 | "outputs": [], |
161 | 161 | "source": [ |
162 | 162 | "k1 = gpx.RBF()\n", |
163 | | - "k1._params = {\"lengthscale\": jnp.array(1.0), \"variance\": jnp.array(10.0)}\n", |
164 | 163 | "k2 = gpx.Polynomial()\n", |
165 | 164 | "sum_k = k1 + k2\n", |
166 | 165 | "\n", |
167 | 166 | "fig, ax = plt.subplots(ncols=3, figsize=(20, 5))\n", |
168 | | - "im0 = ax[0].matshow(gpx.kernels.gram(k1, x, k1.params))\n", |
169 | | - "im1 = ax[1].matshow(gpx.kernels.gram(k2, x, k2.params))\n", |
170 | | - "im2 = ax[2].matshow(gpx.kernels.gram(sum_k, x, sum_k.params))\n", |
| 167 | + "im0 = ax[0].matshow(gpx.kernels.gram(k1, x, k1._initialise_params(key)))\n", |
| 168 | + "im1 = ax[1].matshow(gpx.kernels.gram(k2, x, k2._initialise_params(key)))\n", |
| 169 | + "im2 = ax[2].matshow(gpx.kernels.gram(sum_k, x, sum_k._initialise_params(key)))\n", |
171 | 170 | "\n", |
172 | 171 | "fig.colorbar(im0, ax=ax[0])\n", |
173 | 172 | "fig.colorbar(im1, ax=ax[1])\n", |
|
194 | 193 | "prod_k = k1 * k2 * k3\n", |
195 | 194 | "\n", |
196 | 195 | "fig, ax = plt.subplots(ncols=4, figsize=(20, 5))\n", |
197 | | - "im0 = ax[0].matshow(gpx.kernels.gram(k1, x, k1.params))\n", |
198 | | - "im1 = ax[1].matshow(gpx.kernels.gram(k2, x, k2.params))\n", |
199 | | - "im2 = ax[2].matshow(gpx.kernels.gram(k3, x, k3.params))\n", |
200 | | - "im3 = ax[3].matshow(gpx.kernels.gram(prod_k, x, prod_k.params))\n", |
| 196 | + "im0 = ax[0].matshow(gpx.kernels.gram(k1, x, k1._initialise_params(key)))\n", |
| 197 | + "im1 = ax[1].matshow(gpx.kernels.gram(k2, x, k2._initialise_params(key)))\n", |
| 198 | + "im2 = ax[2].matshow(gpx.kernels.gram(k3, x, k3._initialise_params(key)))\n", |
| 199 | + "im3 = ax[3].matshow(gpx.kernels.gram(prod_k, x, prod_k._initialise_params(key)))\n", |
201 | 200 | "\n", |
202 | 201 | "fig.colorbar(im0, ax=ax[0])\n", |
203 | 202 | "fig.colorbar(im1, ax=ax[1])\n", |
|
362 | 361 | "circlular_posterior = gpx.Prior(kernel=PKern) * likelihood\n", |
363 | 362 | "\n", |
364 | 363 | "# Initialise parameters and corresponding transformations\n", |
365 | | - "params, trainable, constrainer, unconstrainer = gpx.initialise(circlular_posterior)\n", |
| 364 | + "params, trainable, constrainer, unconstrainer = gpx.initialise(circlular_posterior, key).unpack()\n", |
366 | 365 | "\n", |
367 | 366 | "# Optimise GP's marginal log-likelihood using Adam\n", |
368 | 367 | "mll = jit(circlular_posterior.marginal_log_likelihood(D, constrainer, negative=True))\n", |
369 | | - "learned_params = gpx.fit(\n", |
| 368 | + "learned_params, training_history = gpx.fit(\n", |
370 | 369 | " mll,\n", |
371 | 370 | " params,\n", |
372 | 371 | " trainable,\n", |
373 | 372 | " adam(learning_rate=0.05),\n", |
374 | 373 | " n_iters=1000,\n", |
375 | | - ")\n", |
| 374 | + ").unpack()\n", |
376 | 375 | "\n", |
377 | 376 | "# Untransform learned parameters\n", |
378 | 377 | "final_params = gpx.transform(learned_params, constrainer)" |
|
0 commit comments