
Commit 9d90f32

Add explicit prior specification
1 parent 10b2848 commit 9d90f32

File tree

9 files changed: +501 -321 lines changed


.github/workflows/test_docs.yml

Lines changed: 28 additions & 0 deletions
@@ -3,6 +3,11 @@ name: Test documentation
 on:
   pull_request:
 
+permissions:
+  contents: read
+  pages: write
+  id-token: write
+
 jobs:
   test-docs:
     # Functionality for testing documentation builds on multiple OSes and Python versions
@@ -42,3 +47,26 @@ jobs:
           uv sync --extra docs
           uv run python docs/scripts/gen_examples.py --execute
           uv run mkdocs build
+
+      - name: Upload built docs artifact
+        uses: actions/upload-artifact@v4
+        with:
+          name: docs-site-html
+          path: site
+
+      - name: Upload Pages artifact
+        uses: actions/upload-pages-artifact@v3
+        with:
+          path: site
+
+  deploy-docs-preview:
+    needs: test-docs
+    if: github.event_name == 'pull_request'
+    runs-on: ubuntu-latest
+    environment:
+      name: github-pages
+      url: ${{ steps.deployment.outputs.page_url }}
+    steps:
+      - name: Deploy MkDocs preview
+        id: deployment
+        uses: actions/deploy-pages@v4

examples/lgcp_numpyro.py

Lines changed: 0 additions & 119 deletions
This file was deleted.

examples/numpyro_integration.py

Lines changed: 59 additions & 47 deletions
@@ -30,6 +30,7 @@
 from numpyro.infer import (
     MCMC,
     NUTS,
+    Predictive,
 )
 
 import gpjax as gpx
@@ -67,20 +68,38 @@
 plt.scatter(x, y, label="Data", alpha=0.6)
 plt.plot(x, y_clean, "k--", label="True Signal")
 plt.legend()
-plt.show()
+# plt.show()
 
 # %% [markdown]
 # ## Model Definition
 #
 # We define a GP model with a generic mean function (zero for now, as we will handle the linear trend explicitly in the Numpyro model) and a kernel that is the product of a periodic kernel and an RBF kernel. This choice reflects our prior knowledge that the signal is locally periodic.
 
 # %%
-kernel = gpx.kernels.RBF() * gpx.kernels.Periodic()
+# Define priors
+lengthscale_prior = dist.LogNormal(0.0, 1.0)
+variance_prior = dist.LogNormal(0.0, 1.0)
+period_prior = dist.LogNormal(0.0, 0.5)
+noise_prior = dist.LogNormal(0.0, 1.0)
+
+# Define Kernel with priors
+# We can explicitly attach priors to the parameters
+kernel = gpx.kernels.RBF(
+    lengthscale=gpx.parameters.PositiveReal(1.0, prior=lengthscale_prior),
+    variance=gpx.parameters.PositiveReal(1.0, prior=variance_prior),
+) * gpx.kernels.Periodic(
+    lengthscale=gpx.parameters.PositiveReal(1.0, prior=lengthscale_prior),
+    period=gpx.parameters.PositiveReal(1.0, prior=period_prior),
+)
+
 meanf = gpx.mean_functions.Zero()
 prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
 
 # We will use a ConjugatePosterior since we assume Gaussian noise
-likelihood = gpx.likelihoods.Gaussian(num_datapoints=N)
+likelihood = gpx.likelihoods.Gaussian(
+    num_datapoints=N,
+    obs_stddev=gpx.parameters.NonNegativeReal(1.0, prior=noise_prior),
+)
 posterior = prior * likelihood
 
 # We initialise the model parameters.
@@ -111,7 +130,8 @@ def model(X, Y):
     # 2. Register GP parameters
     # This automatically samples parameters from the GPJax model
    # and returns a model with updated values.
-    # We can specify custom priors if needed, but we'll rely on defaults here.
+    # We attached priors to the parameters during model definition,
+    # so register_parameters will use those.
    # register_parameters modifies the model in-place (and returns it).
    # Since Numpyro re-runs this function, we are overwriting the parameters
    # of the same object repeatedly, which is fine as they are completely determined
@@ -150,67 +170,59 @@ def model(X, Y):
 samples = mcmc.get_samples()
 
 
-# Helper to get predictions
-def predict(rng_key, sample_idx):
-    # Reconstruct model with sampled values
-
-    # Linear part
-    slope = samples["slope"][sample_idx]
-    intercept = samples["intercept"][sample_idx]
-    trend = slope * x + intercept
+def predict_fn(X_new, Y_train):
+    # 1. Sample linear model parameters
+    slope = numpyro.sample("slope", dist.Normal(0.0, 2.0))
+    intercept = numpyro.sample("intercept", dist.Normal(0.0, 2.0))
 
-    # GP part
-    # We use numpyro.handlers.substitute to inject the sampled values into register_parameters
-    # to reconstruct the GP model state for this sample.
-    sample_dict = {k: v[sample_idx] for k, v in samples.items()}
+    # Calculate residuals
+    trend_train = slope * x + intercept
+    residuals = Y_train - trend_train
 
-    with numpyro.handlers.substitute(data=sample_dict):
-        # We call register_parameters again to update the posterior object with this sample's values
-        p_posterior = register_parameters(posterior)
+    # 2. Register GP parameters
+    p_posterior = register_parameters(posterior)
 
-    # Now predict on residuals
-    residuals = y - trend
+    # Create dataset for residuals
     D_resid = gpx.Dataset(X=x, y=residuals)
 
-    latent_dist = p_posterior.predict(x, train_data=D_resid)
-    predictive_mean = latent_dist.mean
-    predictive_std = latent_dist.stddev()
+    # 3. Compute latent GP distribution
+    latent_dist = p_posterior.predict(X_new, train_data=D_resid)
 
-    return trend + predictive_mean, predictive_std
+    # 4. Sample latent function values
+    f = numpyro.sample("f", latent_dist)
+    f = f.reshape((-1, 1))
 
+    # 5. Compute and return total prediction
+    total_prediction = slope * X_new + intercept + f
+    numpyro.deterministic("y_pred", total_prediction)
+    return total_prediction
 
-# Plot
-plt.figure(figsize=(12, 6))
-plt.scatter(x, y, alpha=0.5, label="Data", color="gray")
-plt.plot(x, y_clean, "k--", label="True Signal")
 
-# Compute mean prediction (using mean of samples for efficiency)
-mean_slope = jnp.mean(samples["slope"])
-mean_intercept = jnp.mean(samples["intercept"])
-mean_trend = mean_slope * x + mean_intercept
+# Create predictive utility
+predictive = Predictive(predict_fn, posterior_samples=samples)
 
-mean_samples = {k: jnp.mean(v, axis=0) for k, v in samples.items()}
-with numpyro.handlers.substitute(data=mean_samples):
-    p_posterior_mean = register_parameters(posterior)
+# Generate predictions
+predictions = predictive(jr.key(1), X_new=x, Y_train=y)
+y_pred = predictions["y_pred"]
 
-residuals_mean = y - mean_trend
-D_resid_mean = gpx.Dataset(X=x, y=residuals_mean)
-latent_dist = p_posterior_mean.predict(x, train_data=D_resid_mean)
-pred_mean = latent_dist.mean
-pred_std = latent_dist.stddev()
+# Compute statistics
+mean_prediction = jnp.mean(y_pred, axis=0)
+std_prediction = jnp.std(y_pred, axis=0)
 
-total_mean = mean_trend.flatten() + pred_mean.flatten()
-std_flat = pred_std.flatten()
+# Plot
+plt.figure(figsize=(12, 6))
+plt.scatter(x, y, alpha=0.5, label="Data", color="gray")
+plt.plot(x, y_clean, "k--", label="True Signal")
 
-plt.plot(x, total_mean, "b-", label="Posterior Mean")
+plt.plot(x, mean_prediction, "b-", label="Posterior Mean")
 plt.fill_between(
     x.flatten(),
-    total_mean - 2 * std_flat,
-    total_mean + 2 * std_flat,
+    mean_prediction.flatten() - 2 * std_prediction.flatten(),
+    mean_prediction.flatten() + 2 * std_prediction.flatten(),
     color="b",
     alpha=0.2,
     label="95% CI (GP Uncertainty)",
 )
 
 plt.legend()
-plt.show()
+# plt.show()
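
For readers unfamiliar with numpyro's Predictive utility adopted above, here is a minimal standalone toy sketch of the pattern (illustrative only: toy_predict_fn, fake_samples and the shapes are made up and are not part of this commit). Predictive pins the named sample sites to the stored posterior draws and collects every numpyro.deterministic site once per draw, so summary statistics are taken over the leading sample axis, as mean_prediction and std_prediction are computed in the example.

import jax.numpy as jnp
import jax.random as jr
import numpyro
import numpyro.distributions as dist
from numpyro.infer import Predictive

def toy_predict_fn(X_new):
    # These sites are overridden by the draws supplied via posterior_samples.
    slope = numpyro.sample("slope", dist.Normal(0.0, 2.0))
    intercept = numpyro.sample("intercept", dist.Normal(0.0, 2.0))
    # Deterministic sites are recorded and returned for every posterior draw.
    numpyro.deterministic("y_pred", slope * X_new + intercept)

# Illustrative "posterior": 100 draws per parameter.
fake_samples = {"slope": jnp.ones(100), "intercept": jnp.zeros(100)}
predictive = Predictive(toy_predict_fn, posterior_samples=fake_samples)
preds = predictive(jr.key(0), X_new=jnp.linspace(0.0, 1.0, 50))
print(preds["y_pred"].shape)  # (100, 50): one row per posterior draw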

gpjax/distributions.py

Lines changed: 3 additions & 0 deletions
@@ -68,6 +68,9 @@ def sample(self, key, sample_shape=()):
         def affine_transformation(_x):
             return self.loc + covariance_root @ _x
 
+        if not sample_shape:
+            return affine_transformation(white_noise)
+
         return vmap(affine_transformation)(white_noise)
 
     @property
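
A minimal self-contained sketch of the behaviour this patch adds (this is not GPJax's actual distribution class; loc, covariance_root and the helper name sample_mvn are illustrative). With an empty sample_shape the draw is returned unbatched, which matches what numpyro.sample expects when it takes a single draw from the latent GP distribution in the example above; a non-empty sample_shape is still handled by vmapping over the leading axis.

import jax.numpy as jnp
import jax.random as jr
from jax import vmap

def sample_mvn(key, loc, covariance_root, sample_shape=()):
    # Draw standard-normal noise with the requested batch shape prepended.
    white_noise = jr.normal(key, shape=sample_shape + loc.shape)

    def affine_transformation(_x):
        return loc + covariance_root @ _x

    # Empty sample_shape: return a single unbatched sample.
    if not sample_shape:
        return affine_transformation(white_noise)
    # Otherwise map the affine transform over the leading sample axis.
    return vmap(affine_transformation)(white_noise)

loc = jnp.zeros(3)
root = jnp.eye(3)
print(sample_mvn(jr.key(0), loc, root).shape)        # (3,)
print(sample_mvn(jr.key(0), loc, root, (5,)).shape)  # (5, 3)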

gpjax/numpyro_extras.py

Lines changed: 6 additions & 20 deletions
@@ -6,29 +6,10 @@
 import numpyro.distributions as dist
 
 from gpjax.parameters import (
-    FillTriangularTransform,
     Parameter,
 )
 
 
-def _get_default_prior(tag, shape, ndim):
-    if tag in ("positive", "non_negative"):
-        return dist.LogNormal(0.0, 1.0).expand(shape).to_event(ndim)
-    if tag == "real":
-        return dist.Normal(0.0, 1.0).expand(shape).to_event(ndim)
-    if tag == "sigmoid":
-        return dist.Uniform(0.0, 1.0).expand(shape).to_event(ndim)
-    if tag == "lower_triangular":
-        N = shape[-1]
-        K = N * (N + 1) // 2
-        batch_shape = shape[:-2]
-        base_shape = batch_shape + (K,)
-        base_dist = dist.Normal(0.0, 1.0).expand(base_shape).to_event(1)
-        td = dist.TransformedDistribution(base_dist, FillTriangularTransform())
-        return td.to_event(len(batch_shape))
-    return dist.Normal(0.0, 1.0).expand(shape).to_event(ndim)
-
-
 def register_parameters(
     model: nnx.Module,
     priors: tp.Dict[str, dist.Distribution] | None = None,
@@ -71,7 +52,12 @@ def _param_callback(path, param):
     # Determine prior
     prior = priors.get(name)
     if prior is None:
-        prior = _get_default_prior(param.tag, param.value.shape, param.value.ndim)
+        # Check for attached prior
+        numpyro_props = getattr(param, "numpyro_properties", {})
+        prior = numpyro_props.get("prior")
+
+    if prior is None:
+        return param
 
     # Sample
     value = numpyro.sample(name, prior)
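
A hypothetical usage sketch of the new lookup order (not part of this commit; the import path gpjax.numpyro_extras and the exact sample-site naming are assumptions): an explicit entry in the priors dict wins, then a prior attached to the parameter via the prior= keyword shown in the example above, and a parameter with neither is now left at its current value instead of receiving a default prior.

import numpyro.distributions as dist

import gpjax as gpx
from gpjax.numpyro_extras import register_parameters  # import path assumed

# Attach a prior to one parameter only; the variance has no prior attached.
kernel = gpx.kernels.RBF(
    lengthscale=gpx.parameters.PositiveReal(1.0, prior=dist.LogNormal(0.0, 1.0)),
    variance=gpx.parameters.PositiveReal(1.0),
)

def numpyro_model():
    # Inside a numpyro model (e.g. run via MCMC), the lengthscale is sampled
    # from its attached LogNormal prior and written back into the kernel;
    # the variance is skipped rather than given a default prior.
    register_parameters(kernel)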
