
Commit 99ef16f

Merge pull request #100 from thomaspinder/revamped_init
Revamped init
2 parents: b1142ab + 48b90d9

32 files changed: +525, -466 lines

.github/workflows/workflow-master.yml

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@ on:
 jobs:
   codecov:
     name: Codecov Workflow
-    runs-on: ubuntu-18.04
+    runs-on: ubuntu-22.04
 
     steps:
       - uses: actions/checkout@v1

docs/nbs/barycentres.ipynb

Lines changed: 3 additions & 3 deletions
@@ -120,19 +120,19 @@
  " D = gpx.Dataset(X=x, y=y)\n",
  " likelihood = gpx.Gaussian(num_datapoints=n)\n",
  " posterior = gpx.Prior(kernel=gpx.RBF()) * likelihood\n",
- " params, trainables, constrainers, unconstrainers = gpx.initialise(posterior)\n",
+ " params, trainables, constrainers, unconstrainers = gpx.initialise(posterior, key).unpack()\n",
  " params = gpx.transform(params, unconstrainers)\n",
  "\n",
  " objective = jax.jit(posterior.marginal_log_likelihood(D, constrainers, negative=True))\n",
  "\n",
  " opt = ox.adam(learning_rate=0.01)\n",
- " learned_params = gpx.fit(\n",
+ " learned_params, training_history = gpx.fit(\n",
  " objective=objective,\n",
  " trainables=trainables,\n",
  " params=params,\n",
  " optax_optim=opt,\n",
  " n_iters=1000,\n",
- " )\n",
+ " ).unpack()\n",
  " learned_params = gpx.transform(learned_params, constrainers)\n",
  " return likelihood(posterior(D, learned_params)(xtest), learned_params)\n",
  "\n",

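For reference, the hunks above boil down to the following usage pattern for the revamped API. This is a minimal sketch assembled from the diff, assuming the GPJax version introduced by this PR; the toy data (`n`, `x`, `y`) is hypothetical, and only the `initialise(...).unpack()` and `fit(...).unpack()` calls mirror the notebook cell.

    import jax
    import jax.numpy as jnp
    import jax.random as jr
    import optax as ox
    import gpjax as gpx

    key = jr.PRNGKey(123)

    # Hypothetical 1D regression data standing in for the notebook's dataset.
    n = 100
    x = jnp.sort(jr.uniform(key, shape=(n, 1), minval=-3.0, maxval=3.0), axis=0)
    y = jnp.sin(x) + 0.1 * jr.normal(key, shape=(n, 1))

    D = gpx.Dataset(X=x, y=y)
    likelihood = gpx.Gaussian(num_datapoints=n)
    posterior = gpx.Prior(kernel=gpx.RBF()) * likelihood

    # initialise now takes a PRNG key and returns a state object to unpack.
    params, trainables, constrainers, unconstrainers = gpx.initialise(posterior, key).unpack()
    params = gpx.transform(params, unconstrainers)

    objective = jax.jit(posterior.marginal_log_likelihood(D, constrainers, negative=True))

    # fit now returns both the learned parameters and the training history.
    learned_params, training_history = gpx.fit(
        objective=objective,
        trainables=trainables,
        params=params,
        optax_optim=ox.adam(learning_rate=0.01),
        n_iters=1000,
    ).unpack()
    learned_params = gpx.transform(learned_params, constrainers)
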
docs/nbs/classification.ipynb

Lines changed: 7 additions & 6 deletions
@@ -140,7 +140,8 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "params, trainable, constrainer, unconstrainer = gpx.initialise(posterior)\n",
+ "parameter_state = gpx.initialise(posterior)\n",
+ "params, trainable, constrainer, unconstrainer = parameter_state.unpack()\n",
  "params = gpx.transform(params, unconstrainer)\n",
  "\n",
  "mll = jax.jit(posterior.marginal_log_likelihood(D, constrainer, negative=True))"

@@ -166,13 +167,13 @@
  "outputs": [],
  "source": [
  "opt = ox.adam(learning_rate=0.01)\n",
- "unconstrained_params = gpx.fit(\n",
+ "unconstrained_params, training_history = gpx.fit(\n",
  " mll,\n",
  " params,\n",
  " trainable,\n",
  " opt,\n",
  " n_iters=500,\n",
- ")\n",
+ ").unpack()\n",
  "\n",
  "negative_Hessian = jax.jacfwd(jax.jacrev(mll))(unconstrained_params)[\"latent\"][\"latent\"][:,0,:,0]\n",
  "\n",

@@ -585,7 +586,7 @@
  "custom_cell_magics": "kql"
  },
  "kernelspec": {
- "display_name": "Python 3.8.9 64-bit",
+ "display_name": "Python 3.9.7 ('gpjax')",
  "language": "python",
  "name": "python3"
  },

@@ -599,11 +600,11 @@
  "name": "python",
  "nbconvert_exporter": "python",
  "pygments_lexer": "ipython3",
- "version": "3.8.9"
+ "version": "3.9.7"
  },
  "vscode": {
  "interpreter": {
- "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
+ "hash": "920091140e6b97de16b405af485d142952a229f5dad61a888f46227f5acb94cf"
  }
  }
  },

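One detail worth noting from the classification hunks: `gpx.initialise` now hands back a parameter-state container rather than a bare tuple, so the state can be kept around and unpacked when needed. A minimal sketch of just that step, assuming this PR's GPJax version; the Gaussian likelihood below is only a stand-in, since the unpacking pattern is independent of the model.

    import gpjax as gpx

    # Stand-in model; the classification notebook uses its own posterior here.
    posterior = gpx.Prior(kernel=gpx.RBF()) * gpx.Gaussian(num_datapoints=50)

    # initialise now returns a parameter-state container ...
    parameter_state = gpx.initialise(posterior)
    # ... which unpacks into the familiar four components when needed.
    params, trainable, constrainer, unconstrainer = parameter_state.unpack()
    params = gpx.transform(params, unconstrainer)
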
docs/nbs/collapsed_vi.ipynb

Lines changed: 4 additions & 6 deletions
@@ -31,9 +31,7 @@
  "from jax import jit\n",
  "\n",
  "import gpjax as gpx\n",
- "import tensorflow as tf\n",
  "\n",
- "tf.random.set_seed(42)\n",
  "key = jr.PRNGKey(123)"
  ]
  },

@@ -188,21 +186,21 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "params, trainables, constrainers, unconstrainers = gpx.initialise(sgpr)\n",
+ "params, trainables, constrainers, unconstrainers = gpx.initialise(sgpr, key).unpack()\n",
  "\n",
  "loss_fn = jit(sgpr.elbo(D, constrainers, negative=True))\n",
  "\n",
  "optimiser = ox.adam(learning_rate=0.005)\n",
  "\n",
  "params = gpx.transform(params, unconstrainers)\n",
  "\n",
- "learned_params = gpx.fit(\n",
+ "learned_params, training_history = gpx.fit(\n",
  " objective = loss_fn,\n",
  " params = params,\n",
  " trainables = trainables,\n",
  " optax_optim = optimiser,\n",
  " n_iters=2000,\n",
- ")\n",
+ ").unpack()\n",
  "learned_params = gpx.transform(learned_params, constrainers)"
  ]
  },

@@ -270,7 +268,7 @@
  "outputs": [],
  "source": [
  "full_rank_model = gpx.Prior(kernel = gpx.RBF()) * gpx.Gaussian(num_datapoints=D.n)\n",
- "fr_params, fr_trainables, fr_constrainers, fr_unconstrainers = gpx.initialise(full_rank_model)\n",
+ "fr_params, fr_trainables, fr_constrainers, fr_unconstrainers = gpx.initialise(full_rank_model, key).unpack()\n",
  "fr_params = gpx.transform(fr_params, fr_unconstrainers)\n",
  "mll = jit(full_rank_model.marginal_log_likelihood(D, fr_constrainers, negative=True))\n",
  "%timeit mll(fr_params).block_until_ready()"

docs/nbs/graph_kernels.ipynb

Lines changed: 21 additions & 4 deletions
@@ -113,7 +113,7 @@
  "kernel = gpx.GraphKernel(laplacian=L)\n",
  "f = gpx.Prior(kernel=kernel)\n",
  "\n",
- "true_params = f.params\n",
+ "true_params = f._initialise_params(key)\n",
  "true_params[\"kernel\"] = {\n",
  " \"lengthscale\": jnp.array(2.3),\n",
  " \"variance\": jnp.array(3.2),\n",

@@ -174,21 +174,21 @@
  "source": [
  "likelihood = gpx.Gaussian(num_datapoints=y.shape[0])\n",
  "posterior = f * likelihood\n",
- "params, trainable, constrainer, unconstrainer = gpx.initialise(posterior)\n",
+ "params, trainable, constrainer, unconstrainer = gpx.initialise(posterior, key).unpack()\n",
  "params = gpx.transform(params, unconstrainer)\n",
  "\n",
  "mll = jit(\n",
  " posterior.marginal_log_likelihood(train_data=D, transformations=constrainer, negative=True)\n",
  ")\n",
  "\n",
  "opt = ox.adam(learning_rate=0.01)\n",
- "learned_params = gpx.fit(\n",
+ "learned_params, training_history = gpx.fit(\n",
  " objective=mll,\n",
  " params=params,\n",
  " trainables=trainable,\n",
  " optax_optim=opt,\n",
  " n_iters=1000,\n",
- ")\n",
+ ").unpack()\n",
  "learned_params = gpx.transform(learned_params, constrainer)"
  ]
  },

@@ -297,6 +297,23 @@
  "display_name": "Python 3.9.7 ('gpjax')",
  "language": "python",
  "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.7"
+ },
+ "vscode": {
+ "interpreter": {
+ "hash": "920091140e6b97de16b405af485d142952a229f5dad61a888f46227f5acb94cf"
+ }
  }
  },
  "nbformat": 4,

docs/nbs/haiku.ipynb

Lines changed: 22 additions & 6 deletions
@@ -109,11 +109,10 @@
  "\n",
  " def initialise(self, dummy_x, key):\n",
  " nn_params = self.network.init(rng=key, x=dummy_x)\n",
- " base_kernel_params = self.base_kernel.params\n",
+ " base_kernel_params = self.base_kernel._initialise_params(key)\n",
  " self._params = {**nn_params, **base_kernel_params}\n",
  "\n",
- " @property\n",
- " def params(self):\n",
+ " def _initialise_params(self, key):\n",
  " return self._params"
  ]
  },

@@ -176,7 +175,7 @@
  "likelihood = gpx.Gaussian(num_datapoints=D.n)\n",
  "posterior = prior * likelihood\n",
  "\n",
- "params, trainables, constrainers, unconstrainers = gpx.initialise(posterior)\n",
+ "params, trainables, constrainers, unconstrainers = gpx.initialise(posterior, key).unpack()\n",
  "params = gpx.transform(params, unconstrainers)"
  ]
  },

@@ -217,13 +216,13 @@
  " ox.adamw(learning_rate=schedule),\n",
  ")\n",
  "\n",
- "final_params = gpx.fit(\n",
+ "final_params, training_history = gpx.fit(\n",
  " mll,\n",
  " params,\n",
  " trainables,\n",
  " opt,\n",
  " n_iters=5000,\n",
- ")\n",
+ ").unpack()\n",
  "final_params = gpx.transform(final_params, constrainers)"
  ]
  },

@@ -294,6 +293,23 @@
  "display_name": "Python 3.9.7 ('gpjax')",
  "language": "python",
  "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.7"
+ },
+ "vscode": {
+ "interpreter": {
+ "hash": "920091140e6b97de16b405af485d142952a229f5dad61a888f46227f5acb94cf"
+ }
  }
  },
  "nbformat": 4,

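The haiku hunks also show what the interface change means for custom kernels: the read-only `params` property is replaced by an `_initialise_params(key)` method. Below is a cut-down sketch of that wrapper under this PR's API; the class name and constructor are illustrative, and only the two methods mirror the hunk above.

    class DeepKernelFunction:
        """Illustrative wrapper pairing a Haiku network with a base kernel."""

        def __init__(self, network, base_kernel):
            self.network = network          # a Haiku-transformed feature extractor
            self.base_kernel = base_kernel  # e.g. gpx.RBF()

        def initialise(self, dummy_x, key):
            nn_params = self.network.init(rng=key, x=dummy_x)
            # Base-kernel defaults are now drawn with a key via _initialise_params
            # rather than read from the removed .params property.
            base_kernel_params = self.base_kernel._initialise_params(key)
            self._params = {**nn_params, **base_kernel_params}

        def _initialise_params(self, key):
            # The old `@property params` accessor becomes a keyed method.
            return self._params
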
docs/nbs/kernels.ipynb

Lines changed: 13 additions & 14 deletions
@@ -70,7 +70,7 @@
  "\n",
  "for k, ax in zip(kernels, axes.ravel()):\n",
  " prior = gpx.Prior(kernel=k)\n",
- " params, _, _, _ = gpx.initialise(prior)\n",
+ " params, _, _, _ = gpx.initialise(prior, key).unpack()\n",
  " rv = prior(params)(x)\n",
  " y = rv.sample(sample_shape=10, seed=key)\n",
  "\n",

@@ -117,7 +117,7 @@
  "outputs": [],
  "source": [
  "print(f\"ARD: {slice_kernel.ard}\")\n",
- "print(f\"Lengthscales: {slice_kernel.params['lengthscale']}\")"
+ "print(f\"Lengthscales: {slice_kernel._initialise_params(key)['lengthscale']}\")"
  ]
  },
  {

@@ -136,7 +136,7 @@
  "outputs": [],
  "source": [
  "x_matrix = jr.normal(key, shape=(50, 5))\n",
- "K = gpx.kernels.gram(slice_kernel, x_matrix, slice_kernel.params)\n",
+ "K = gpx.kernels.gram(slice_kernel, x_matrix, slice_kernel._initialise_params(key))\n",
  "print(K.shape)"
  ]
  },

@@ -160,14 +160,13 @@
  "outputs": [],
  "source": [
  "k1 = gpx.RBF()\n",
- "k1._params = {\"lengthscale\": jnp.array(1.0), \"variance\": jnp.array(10.0)}\n",
  "k2 = gpx.Polynomial()\n",
  "sum_k = k1 + k2\n",
  "\n",
  "fig, ax = plt.subplots(ncols=3, figsize=(20, 5))\n",
- "im0 = ax[0].matshow(gpx.kernels.gram(k1, x, k1.params))\n",
- "im1 = ax[1].matshow(gpx.kernels.gram(k2, x, k2.params))\n",
- "im2 = ax[2].matshow(gpx.kernels.gram(sum_k, x, sum_k.params))\n",
+ "im0 = ax[0].matshow(gpx.kernels.gram(k1, x, k1._initialise_params(key)))\n",
+ "im1 = ax[1].matshow(gpx.kernels.gram(k2, x, k2._initialise_params(key)))\n",
+ "im2 = ax[2].matshow(gpx.kernels.gram(sum_k, x, sum_k._initialise_params(key)))\n",
  "\n",
  "fig.colorbar(im0, ax=ax[0])\n",
  "fig.colorbar(im1, ax=ax[1])\n",

@@ -194,10 +193,10 @@
  "prod_k = k1 * k2 * k3\n",
  "\n",
  "fig, ax = plt.subplots(ncols=4, figsize=(20, 5))\n",
- "im0 = ax[0].matshow(gpx.kernels.gram(k1, x, k1.params))\n",
- "im1 = ax[1].matshow(gpx.kernels.gram(k2, x, k2.params))\n",
- "im2 = ax[2].matshow(gpx.kernels.gram(k3, x, k3.params))\n",
- "im3 = ax[3].matshow(gpx.kernels.gram(prod_k, x, prod_k.params))\n",
+ "im0 = ax[0].matshow(gpx.kernels.gram(k1, x, k1._initialise_params(key)))\n",
+ "im1 = ax[1].matshow(gpx.kernels.gram(k2, x, k2._initialise_params(key)))\n",
+ "im2 = ax[2].matshow(gpx.kernels.gram(k3, x, k3._initialise_params(key)))\n",
+ "im3 = ax[3].matshow(gpx.kernels.gram(prod_k, x, prod_k._initialise_params(key)))\n",
  "\n",
  "fig.colorbar(im0, ax=ax[0])\n",
  "fig.colorbar(im1, ax=ax[1])\n",

@@ -362,17 +361,17 @@
  "circlular_posterior = gpx.Prior(kernel=PKern) * likelihood\n",
  "\n",
  "# Initialise parameters and corresponding transformations\n",
- "params, trainable, constrainer, unconstrainer = gpx.initialise(circlular_posterior)\n",
+ "params, trainable, constrainer, unconstrainer = gpx.initialise(circlular_posterior, key).unpack()\n",
  "\n",
  "# Optimise GP's marginal log-likelihood using Adam\n",
  "mll = jit(circlular_posterior.marginal_log_likelihood(D, constrainer, negative=True))\n",
- "learned_params = gpx.fit(\n",
+ "learned_params, training_history = gpx.fit(\n",
  " mll,\n",
  " params,\n",
  " trainable,\n",
  " adam(learning_rate=0.05),\n",
  " n_iters=1000,\n",
- ")\n",
+ ").unpack()\n",
  "\n",
  "# Untransform learned parameters\n",
  "final_params = gpx.transform(learned_params, constrainer)"

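The kernels notebook applies the same change at the level of individual kernels: default parameters are drawn with `_initialise_params(key)` before being passed to `gpx.kernels.gram`, instead of being read from the old `.params` property. A minimal sketch under this PR's API, with hypothetical inputs `x`:

    import jax.numpy as jnp
    import jax.random as jr
    import gpjax as gpx

    key = jr.PRNGKey(123)
    x = jnp.linspace(-3.0, 3.0, 50).reshape(-1, 1)  # hypothetical inputs

    k1 = gpx.RBF()
    k2 = gpx.Polynomial()
    sum_k = k1 + k2  # kernel composition is unchanged by this PR

    # Default parameters are now drawn with a key rather than read from .params.
    K1 = gpx.kernels.gram(k1, x, k1._initialise_params(key))
    K_sum = gpx.kernels.gram(sum_k, x, sum_k._initialise_params(key))
    print(K1.shape, K_sum.shape)  # both (50, 50)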