Commit e103cd3

Transition to Numpyro backend (#506)
* Add Numpyro dist
* Switch TFP for Numpyro
* Drop unused imports
* Update docs
* Update examples
* Bump version
1 parent c185787 commit e103cd3


42 files changed: +837, -631 lines changed
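
The diffs below repeatedly replace TensorFlow Probability (TFP) distribution calls with their Numpyro equivalents. As orientation, here is a minimal sketch of that mapping (not part of the commit; it assumes only `jax` and `numpyro` are installed):

```python
import jax.numpy as jnp
import jax.random as jr
import numpyro.distributions as npd

# Numpyro has no MultivariateNormalTriL class; a Cholesky factor is passed
# to MultivariateNormal via the `scale_tril` keyword instead.
dist = npd.MultivariateNormal(jnp.zeros(3), scale_tril=jnp.eye(3))

# TFP exposes moments as methods; Numpyro exposes them as properties.
mu = dist.mean                 # was dist.mean()
cov = dist.covariance_matrix   # was dist.covariance()
std = jnp.sqrt(dist.variance)  # was dist.stddev()

# Sampling takes an explicit PRNG key rather than TFP's `seed=` argument.
samples = dist.sample(key=jr.key(0), sample_shape=(20,))
```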

README.md

Lines changed: 0 additions & 59 deletions
@@ -107,65 +107,6 @@ jupytext --to notebook example.py
 jupytext --to py:percent example.ipynb
 ```
 
-# Simple example
-
-Let us import some dependencies and simulate a toy dataset $\mathcal{D}$.
-
-```python
-from jax import config
-
-config.update("jax_enable_x64", True)
-
-import gpjax as gpx
-from jax import grad, jit
-import jax.numpy as jnp
-import jax.random as jr
-import optax as ox
-
-key = jr.key(123)
-
-f = lambda x: 10 * jnp.sin(x)
-
-n = 50
-x = jr.uniform(key=key, minval=-3.0, maxval=3.0, shape=(n,1)).sort()
-y = f(x) + jr.normal(key, shape=(n,1))
-D = gpx.Dataset(X=x, y=y)
-
-# Construct the prior
-meanf = gpx.mean_functions.Zero()
-kernel = gpx.kernels.RBF()
-prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel)
-
-# Define a likelihood
-likelihood = gpx.likelihoods.Gaussian(num_datapoints = n)
-
-# Construct the posterior
-posterior = prior * likelihood
-
-# Define an optimiser
-optimiser = ox.adam(learning_rate=1e-2)
-
-# Obtain Type 2 MLEs of the hyperparameters
-opt_posterior, history = gpx.fit(
-    model=posterior,
-    objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d),
-    train_data=D,
-    optim=optimiser,
-    num_iters=500,
-    safe=True,
-    key=key,
-)
-
-# Infer the predictive posterior distribution
-xtest = jnp.linspace(-3., 3., 100).reshape(-1, 1)
-latent_dist = opt_posterior(xtest, D)
-predictive_dist = opt_posterior.likelihood(latent_dist)
-
-# Obtain the predictive mean and standard deviation
-pred_mean = predictive_dist.mean()
-pred_std = predictive_dist.stddev()
-```
-
 # Installation
 
 ## Stable version

docs/scripts/sharp_bits_figure.py

Lines changed: 3 additions & 3 deletions
@@ -69,12 +69,12 @@
 plt.savefig("../_static/step_size_figure.png", bbox_inches="tight")
 
 # %%
-import tensorflow_probability.substrates.jax.bijectors as tfb
+import numpyro.distributions.transforms as npt
 
-bij = tfb.Exp()
+bij = npt.ExpTransform()
 
 x = np.linspace(0.05, 3.0, 6)
-y = np.asarray(bij.inverse(x))
+y = np.asarray(bij.inv(x))
 lval = 0.5
 rval = 0.52
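
For context, a small sketch of the Numpyro transform API the script now uses: the forward transform is `exp` and `.inv` gives its inverse (`log`), matching TFP's `bij.inverse` (illustrative, not part of the commit):

```python
import numpy as np
import numpyro.distributions.transforms as npt

bij = npt.ExpTransform()

x = np.linspace(0.05, 3.0, 6)
y = np.asarray(bij.inv(x))        # inverse transform: log(x), formerly bij.inverse(x)
x_roundtrip = np.asarray(bij(y))  # forward transform: exp(y), recovers x
```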

docs/sharp_bits.md

Lines changed: 1 addition & 1 deletion
@@ -80,7 +80,7 @@ this value that we apply gradient updates to. When we wish to recover the constr
 value, we apply the inverse of the bijector, which is the exponential function in this
 case. This gives us back the blue cross.
 
-In GPJax, we supply bijective functions using [Tensorflow Probability](https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax/bijectors).
+In GPJax, we supply bijective functions using [Numpyro](https://num.pyro.ai/en/stable/distributions.html#transforms).
 
 
 ## Positive-definiteness

examples/barycentres.py

Lines changed: 10 additions & 10 deletions
@@ -8,7 +8,7 @@
 # extension: .py
 # format_name: percent
 # format_version: '1.3'
-# jupytext_version: 1.16.6
+# jupytext_version: 1.16.7
 # kernelspec:
 # display_name: gpjax
 # language: python
@@ -41,7 +41,7 @@
 import jax.scipy.linalg as jsl
 from jaxtyping import install_import_hook
 import matplotlib.pyplot as plt
-import tensorflow_probability.substrates.jax.distributions as tfd
+import numpyro.distributions as npd
 
 from examples.utils import use_mpl_style
 
@@ -161,7 +161,7 @@
 
 
 # %%
-def fit_gp(x: jax.Array, y: jax.Array) -> tfd.MultivariateNormalFullCovariance:
+def fit_gp(x: jax.Array, y: jax.Array) -> npd.MultivariateNormal:
     if y.ndim == 1:
         y = y.reshape(-1, 1)
     D = gpx.Dataset(X=x, y=y)
@@ -204,9 +204,9 @@ def sqrtm(A: jax.Array):
 
 
 def wasserstein_barycentres(
-    distributions: tp.List[tfd.MultivariateNormalFullCovariance], weights: jax.Array
+    distributions: tp.List[npd.MultivariateNormal], weights: jax.Array
 ):
-    covariances = [d.covariance() for d in distributions]
+    covariances = [d.covariance_matrix for d in distributions]
     cov_stack = jnp.stack(covariances)
     stack_sqrt = jax.vmap(sqrtm)(cov_stack)
 
@@ -231,7 +231,7 @@ def step(covariance_candidate: jax.Array, idx: None):
 # %%
 weights = jnp.ones((n_datasets,)) / n_datasets
 
-means = jnp.stack([d.mean() for d in posterior_preds])
+means = jnp.stack([d.mean for d in posterior_preds])
 barycentre_mean = jnp.tensordot(weights, means, axes=1)
 
 step_fn = jax.jit(wasserstein_barycentres(posterior_preds, weights))
@@ -242,7 +242,7 @@ def step(covariance_candidate: jax.Array, idx: None):
 )
 L = jnp.linalg.cholesky(barycentre_covariance)
 
-barycentre_process = tfd.MultivariateNormalTriL(barycentre_mean, L)
+barycentre_process = npd.MultivariateNormal(barycentre_mean, scale_tril=L)
 
 # %% [markdown]
 # ## Plotting the result
@@ -254,16 +254,16 @@ def step(covariance_candidate: jax.Array, idx: None):
 
 # %%
 def plot(
-    dist: tfd.MultivariateNormalTriL,
+    dist: npd.MultivariateNormal,
     ax,
     color: str,
     label: str = None,
     ci_alpha: float = 0.2,
     linewidth: float = 1.0,
     zorder: int = 0,
 ):
-    mu = dist.mean()
-    sigma = dist.stddev()
+    mu = dist.mean
+    sigma = jnp.sqrt(dist.variance)
     ax.plot(xtest, mu, linewidth=linewidth, color=color, label=label, zorder=zorder)
     ax.fill_between(
         xtest.squeeze(),
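
The barycentre code now reads moments off the distributions as properties. A toy sketch of the access pattern used by `wasserstein_barycentres` above (the two distributions are placeholders, not data from the example):

```python
import jax.numpy as jnp
import numpyro.distributions as npd

# Two toy posterior predictives standing in for `posterior_preds`.
dists = [
    npd.MultivariateNormal(jnp.zeros(4), covariance_matrix=jnp.eye(4)),
    npd.MultivariateNormal(jnp.ones(4), covariance_matrix=2.0 * jnp.eye(4)),
]

# `.covariance_matrix` and `.mean` are properties in Numpyro (methods in TFP),
# which is why the list comprehensions in the diff drop the call parentheses.
cov_stack = jnp.stack([d.covariance_matrix for d in dists])
means = jnp.stack([d.mean for d in dists])
```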

examples/classification.py

Lines changed: 11 additions & 15 deletions
@@ -8,7 +8,7 @@
 # extension: .py
 # format_name: percent
 # format_version: '1.3'
-# jupytext_version: 1.16.6
+# jupytext_version: 1.16.7
 # kernelspec:
 # display_name: gpjax
 # language: python
@@ -37,8 +37,8 @@
     install_import_hook,
 )
 import matplotlib.pyplot as plt
+import numpyro.distributions as npd
 import optax as ox
-import tensorflow_probability.substrates.jax as tfp
 
 from examples.utils import use_mpl_style
 from gpjax.lower_cholesky import lower_cholesky
@@ -50,7 +50,6 @@
 import gpjax as gpx
 
 
-tfd = tfp.distributions
 identity_matrix = jnp.eye
 
 # set the default style for plotting
@@ -120,7 +119,6 @@
 # Optax's optimisers.
 
 # %%
-
 optimiser = ox.adam(learning_rate=0.01)
 
 opt_posterior, history = gpx.fit(
@@ -140,8 +138,8 @@
 map_latent_dist = opt_posterior.predict(xtest, train_data=D)
 predictive_dist = opt_posterior.likelihood(map_latent_dist)
 
-predictive_mean = predictive_dist.mean()
-predictive_std = predictive_dist.stddev()
+predictive_mean = predictive_dist.mean
+predictive_std = jnp.sqrt(predictive_dist.variance)
 
 fig, ax = plt.subplots()
 ax.scatter(x, y, label="Observations", color=cols[0])
@@ -215,8 +213,6 @@
 # datapoints below.
 
 # %%
-
-
 gram, cross_covariance = (kernel.gram, kernel.cross_covariance)
 jitter = 1e-6
 
@@ -246,7 +242,7 @@ def loss(params, D):
 L_inv = jsp.linalg.solve_triangular(L, identity_matrix(D.n), lower=True)
 H_inv = jsp.linalg.solve_triangular(L.T, L_inv, lower=False)
 LH = jnp.linalg.cholesky(H_inv)
-laplace_approximation = tfd.MultivariateNormalTriL(f_hat.squeeze(), LH)
+laplace_approximation = npd.MultivariateNormal(f_hat.squeeze(), scale_tril=LH)
 
 
 # %% [markdown]
@@ -265,7 +261,7 @@ def loss(params, D):
 
 
 # %%
-def construct_laplace(test_inputs: Float[Array, "N D"]) -> tfd.MultivariateNormalTriL:
+def construct_laplace(test_inputs: Float[Array, "N D"]) -> npd.MultivariateNormal:
     map_latent_dist = opt_posterior.predict(xtest, train_data=D)
 
     Kxt = opt_posterior.prior.kernel.cross_covariance(x, test_inputs)
@@ -279,10 +275,10 @@ def construct_laplace(test_inputs: Float[Array, "N D"]) -> tfd.MultivariateNorma
     # Ktx Kxx⁻¹[ H⁻¹ ] Kxx⁻¹ Kxt
     laplace_cov_term = jnp.matmul(jnp.matmul(Kxx_inv_Kxt.T, H_inv), Kxx_inv_Kxt)
 
-    mean = map_latent_dist.mean()
-    covariance = map_latent_dist.covariance() + laplace_cov_term
+    mean = map_latent_dist.mean
+    covariance = map_latent_dist.covariance_matrix + laplace_cov_term
     L = jnp.linalg.cholesky(covariance)
-    return tfd.MultivariateNormalTriL(jnp.atleast_1d(mean.squeeze()), L)
+    return npd.MultivariateNormal(jnp.atleast_1d(mean.squeeze()), scale_tril=L)
 
 
 # %% [markdown]
@@ -291,8 +287,8 @@ def construct_laplace(test_inputs: Float[Array, "N D"]) -> tfd.MultivariateNorma
 laplace_latent_dist = construct_laplace(xtest)
 predictive_dist = opt_posterior.likelihood(laplace_latent_dist)
 
-predictive_mean = predictive_dist.mean()
-predictive_std = predictive_dist.stddev()
+predictive_mean = predictive_dist.mean
+predictive_std = jnp.sqrt(predictive_dist.variance)
 
 fig, ax = plt.subplots()
 ax.scatter(x, y, label="Observations", color=cols[0])
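
The Laplace construction above now builds the Gaussian from a mean and a Cholesky factor via `scale_tril`. A self-contained sketch with placeholder values (not taken from the example):

```python
import jax.numpy as jnp
import numpyro.distributions as npd

# Placeholder mean and covariance standing in for the Laplace approximation.
mean = jnp.array([0.1, -0.3, 0.7])
covariance = 0.5 * jnp.eye(3)

# TFP's MultivariateNormalTriL(mean, L) becomes MultivariateNormal(..., scale_tril=L).
L = jnp.linalg.cholesky(covariance)
laplace_like = npd.MultivariateNormal(jnp.atleast_1d(mean.squeeze()), scale_tril=L)

predictive_std = jnp.sqrt(laplace_like.variance)  # replaces .stddev()
```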

examples/collapsed_vi.py

Lines changed: 4 additions & 4 deletions
@@ -7,7 +7,7 @@
 # extension: .py
 # format_name: percent
 # format_version: '1.3'
-# jupytext_version: 1.16.6
+# jupytext_version: 1.16.7
 # kernelspec:
 # display_name: gpjax_beartype
 # language: python
@@ -161,10 +161,10 @@
 
 inducing_points = opt_posterior.inducing_inputs.value
 
-samples = latent_dist.sample(seed=key, sample_shape=(20,))
+samples = latent_dist.sample(key=key, sample_shape=(20,))
 
-predictive_mean = predictive_dist.mean()
-predictive_std = predictive_dist.stddev()
+predictive_mean = predictive_dist.mean
+predictive_std = jnp.sqrt(predictive_dist.variance)
 
 fig, ax = plt.subplots()
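
The sampling call changes from TFP's `seed=` keyword to Numpyro's explicit PRNG key. A minimal sketch with a placeholder distribution (the real `latent_dist` comes from the fitted posterior):

```python
import jax.numpy as jnp
import jax.random as jr
import numpyro.distributions as npd

key = jr.key(123)
latent_like = npd.MultivariateNormal(jnp.zeros(10), scale_tril=jnp.eye(10))

# TFP: latent_dist.sample(seed=key, sample_shape=(20,))
samples = latent_like.sample(key=key, sample_shape=(20,))
print(samples.shape)  # (20, 10)
```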

examples/constructing_new_kernels.py

Lines changed: 26 additions & 5 deletions
@@ -8,7 +8,7 @@
 # extension: .py
 # format_name: percent
 # format_version: '1.3'
-# jupytext_version: 1.16.6
+# jupytext_version: 1.16.7
 # kernelspec:
 # display_name: gpjax
 # language: python
@@ -24,6 +24,7 @@
 # %%
 # Enable Float64 for more stable matrix inversions.
 from jax import config
+from jax.nn import softplus
 import jax.numpy as jnp
 import jax.random as jr
 from jaxtyping import (
@@ -32,7 +33,9 @@
     install_import_hook,
 )
 import matplotlib.pyplot as plt
-import tensorflow_probability.substrates.jax as tfp
+import numpyro.distributions as npd
+from numpyro.distributions import constraints
+import numpyro.distributions.transforms as npt
 
 from examples.utils import use_mpl_style
 from gpjax.kernels.computations import DenseKernelComputation
@@ -225,9 +228,27 @@ def angular_distance(x, y, c):
     return jnp.abs((x - y + c) % (c * 2) - c)
 
 
-bij = tfb.SoftClip(low=jnp.array(4.0, dtype=jnp.float64))
+class ShiftedSoftplusTransform(npt.ParameterFreeTransform):
+    r"""
+    Transform from unconstrained space to the domain [4, infinity) via
+    :math:`y = 4 + \log(1 + \exp(x))`. The inverse is computed as
+    :math:`x = \log(\exp(y - 4) - 1)`.
+    """
 
-DEFAULT_BIJECTION["polar"] = bij
+    domain = constraints.real
+    codomain = constraints.interval(4.0, jnp.inf)  # updated codomain
+
+    def __call__(self, x):
+        return 4.0 + softplus(x)  # shift the softplus output by 4
+
+    def _inverse(self, y):
+        return npt._softplus_inv(y - 4.0)  # subtract the shift in the inverse
+
+    def log_abs_det_jacobian(self, x, y, intermediates=None):
+        return -softplus(-x)
+
+
+DEFAULT_BIJECTION["polar"] = ShiftedSoftplusTransform()
 
 
 class Polar(gpx.kernels.AbstractKernel):
@@ -307,7 +328,7 @@ def __call__(
 
 # %%
 posterior_rv = opt_posterior.likelihood(opt_posterior.predict(angles, train_data=D))
-mu = posterior_rv.mean()
+mu = posterior_rv.mean
 one_sigma = posterior_rv.stddev()
 
 # %%
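
As an optional sanity check (not part of the commit), the handwritten `ShiftedSoftplusTransform` above should behave like composing Numpyro's built-in `SoftplusTransform` with a shift of +4:

```python
import jax.numpy as jnp
import numpyro.distributions.transforms as npt

# y = 4 + softplus(x), expressed with built-in transforms.
shifted_softplus = npt.ComposeTransform(
    [npt.SoftplusTransform(), npt.AffineTransform(4.0, 1.0)]
)

x = jnp.linspace(-5.0, 5.0, 7)
y = shifted_softplus(x)           # values lie in [4, inf)
x_back = shifted_softplus.inv(y)  # round-trips to x up to floating-point error
```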

examples/deep_kernels.py

Lines changed: 3 additions & 3 deletions
@@ -8,7 +8,7 @@
 # extension: .py
 # format_name: percent
 # format_version: '1.3'
-# jupytext_version: 1.16.6
+# jupytext_version: 1.16.7
 # kernelspec:
 # display_name: gpjax
 # language: python
@@ -238,8 +238,8 @@ def __call__(self, x: jax.Array) -> jax.Array:
 latent_dist = opt_posterior(xtest, train_data=D)
 predictive_dist = opt_posterior.likelihood(latent_dist)
 
-predictive_mean = predictive_dist.mean()
-predictive_std = predictive_dist.stddev()
+predictive_mean = predictive_dist.mean
+predictive_std = jnp.sqrt(predictive_dist.variance)
 
 fig, ax = plt.subplots()
 ax.plot(x, y, "o", label="Observations", color=cols[0])
