Commit 10b2848
Add Numpyro MWE
1 parent fcc55e5

File tree: 7 files changed (+655, -311 lines)


examples/heteroscedastic_inference.py
Lines changed: 32 additions & 4 deletions

@@ -357,8 +357,22 @@
     alpha=0.3,
     label="One std. dev.",
 )
-ax.plot(xtest.squeeze(), predictive_mean - predictive_std, "--", color=cols[1], alpha=0.5, linewidth=0.75)
-ax.plot(xtest.squeeze(), predictive_mean + predictive_std, "--", color=cols[1], alpha=0.5, linewidth=0.75)
+ax.plot(
+    xtest.squeeze(),
+    predictive_mean - predictive_std,
+    "--",
+    color=cols[1],
+    alpha=0.5,
+    linewidth=0.75,
+)
+ax.plot(
+    xtest.squeeze(),
+    predictive_mean + predictive_std,
+    "--",
+    color=cols[1],
+    alpha=0.5,
+    linewidth=0.75,
+)
 ax.fill_between(
     xtest.squeeze(),
     predictive_mean - 2 * predictive_std,
@@ -367,8 +381,22 @@
     alpha=0.1,
     label="Two std. dev.",
 )
-ax.plot(xtest.squeeze(), predictive_mean - 2 * predictive_std, "--", color=cols[1], alpha=0.5, linewidth=0.75)
-ax.plot(xtest.squeeze(), predictive_mean + 2 * predictive_std, "--", color=cols[1], alpha=0.5, linewidth=0.75)
+ax.plot(
+    xtest.squeeze(),
+    predictive_mean - 2 * predictive_std,
+    "--",
+    color=cols[1],
+    alpha=0.5,
+    linewidth=0.75,
+)
+ax.plot(
+    xtest.squeeze(),
+    predictive_mean + 2 * predictive_std,
+    "--",
+    color=cols[1],
+    alpha=0.5,
+    linewidth=0.75,
+)

 ax.set_title("Sparse Heteroscedastic Regression")
 ax.legend(loc="best", fontsize="small")
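The reformatted calls above implement a common uncertainty-banding pattern: a shaded region for each +/- k standard-deviation band via fill_between, with dashed lines tracing the band edges. A minimal self-contained sketch of that pattern on synthetic data (every name below is illustrative, not taken from the example):

import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0.0, 1.0, 50)
mean = np.sin(2 * np.pi * x)
std = 0.1 + 0.1 * x  # heteroscedastic spread, growing with x

fig, ax = plt.subplots()
ax.plot(x, mean, color="C1", label="Mean")
for k, alpha in [(1, 0.3), (2, 0.1)]:
    # Shaded band for +/- k standard deviations ...
    ax.fill_between(x, mean - k * std, mean + k * std, color="C1", alpha=alpha)
    # ... with dashed lines marking the band edges
    for sign in (-1, 1):
        ax.plot(x, mean + sign * k * std, "--", color="C1", alpha=0.5, linewidth=0.75)
ax.legend(loc="best")
plt.show()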

examples/lgcp_numpyro.py
Lines changed: 119 additions & 0 deletions

@@ -0,0 +1,119 @@
+# %%
+import jax.numpy as jnp
+from jax import random
+from jax import config
+import numpy as np
+
+import gpjax as gpx
+from gpjax import numpyro_extras
+import numpyro
+import numpyro.distributions as dist
+from numpyro.infer import MCMC, NUTS
+import arviz as az
+
+import matplotlib.pyplot as plt
+
+# Enable x64 support for JAX
+config.update("jax_enable_x64", True)
+
+# Set random seed
+key = random.PRNGKey(42)
+
+# Configure MCMC
+num_warmup = 1000
+num_samples = 1000
+num_chains = 4
+
+# Set device count for numpyro to run parallel chains
+numpyro.set_host_device_count(num_chains)
+
+# %%
+# 1. Data: Coal Mining Disasters (1851-1961)
+# Counts of disasters per year
+counts = jnp.array([
+    4, 5, 4, 0, 1, 4, 3, 4, 0, 6, 3, 3, 4, 0, 2, 6, 3, 3, 5, 4, 5, 3, 1, 4, 4, 1, 5, 5, 3, 4, 2, 5, 2, 2, 3, 4, 2, 1, 3, 2, 2, 1, 1, 1, 1, 3, 0, 0, 1, 0, 1, 1, 0, 0, 3, 1, 0, 3, 2, 2, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 2, 1, 0, 0, 0, 1, 1, 0, 2, 3, 3, 1, 1, 2, 1, 1, 1, 1, 2, 4, 2, 0, 0, 1, 4, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1
+], dtype=jnp.float64)
+
+years = jnp.arange(1851, 1851 + len(counts), dtype=jnp.float64).reshape(-1, 1)
+# Normalize years for better numerical stability in the GP
+years_norm = (years - years.min()) / (years.max() - years.min())
+
+# %%
+# 2. Model Definition
+# We model the log-intensity log(lambda(t)) as a Gaussian process:
+#   lambda(t) = exp(f(t))
+#   y_i ~ Poisson(lambda(t_i))
+
+# Mean function: constant mean
+mean_f = gpx.mean_functions.Constant(constant=jnp.array([0.0]))
+
+# Kernel: Matern52
+# We expect changes over decades, so the lengthscale should be non-trivial.
+# Since x is normalized to [0, 1], the lengthscale of 0.2 corresponds to ~22 years.
+kernel = gpx.kernels.Matern52(lengthscale=0.2, variance=0.5)
+
+prior = gpx.gps.Prior(mean_function=mean_f, kernel=kernel)
+
+def model(x, y):
+    # Register GPJax parameters (lengthscale, variance, mean constant) with Numpyro
+    gp = numpyro_extras.register_parameters(prior)
+
+    # Sample the latent function f at the input locations x
+    f = numpyro.sample("f", gp(x))
+
+    # The intensity is exp(f)
+    rate = jnp.exp(f)
+
+    # Observation model: Poisson
+    numpyro.sample("y", dist.Poisson(rate), obs=y)
+
+# %%
+# 3. Inference
+rng_key, rng_key_ = random.split(key)
+
+kernel_nuts = NUTS(model, target_accept_prob=0.9)
+mcmc = MCMC(
+    kernel_nuts,
+    num_warmup=num_warmup,
+    num_samples=num_samples,
+    num_chains=num_chains,
+    progress_bar=True,
+    jit_model_args=True,
+)
+
+# Run MCMC
+# Note: we fit on years_norm for stability, but plot against the original years
+mcmc.run(rng_key_, x=years_norm, y=counts)
+
+# %%
+# 4. Analysis & Plotting
+mcmc.print_summary()
+
+# Extract samples
+samples = mcmc.get_samples()
+f_samples = samples["f"]
+intensity_samples = jnp.exp(f_samples)
+
+# Compute statistics
+mean_intensity = jnp.mean(intensity_samples, axis=0)
+lower_ci = jnp.percentile(intensity_samples, 2.5, axis=0)
+upper_ci = jnp.percentile(intensity_samples, 97.5, axis=0)
+
+# Plot
+plt.figure(figsize=(12, 6))
+plt.bar(years.flatten(), counts, color="gray", alpha=0.5, label="Observed Counts", width=1.0)
+plt.plot(years.flatten(), mean_intensity, color="C0", label="Posterior Mean Intensity", linewidth=2)
+plt.fill_between(years.flatten(), lower_ci, upper_ci, color="C0", alpha=0.3, label="95% CI")
+
+plt.xlabel("Year")
+plt.ylabel("Number of Disasters")
+plt.title("Coal Mining Disasters: Log-Gaussian Cox Process (GPJax + Numpyro)")
+plt.legend()
+plt.grid(True, alpha=0.3)
+plt.tight_layout()
+plt.savefig("lgcp_coal_mining.png")
+# plt.show()
+
+# Trace plot for diagnostics
+az.plot_trace(mcmc, var_names=["kernel.lengthscale", "kernel.variance"])
+plt.tight_layout()
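The model in this file is the log-Gaussian Cox process lambda(t) = exp(f(t)) with f drawn from a GP and y_i ~ Poisson(lambda(t_i)). For intuition, the same generative model can be written directly in Numpyro without GPJax; a minimal sketch with fixed hyperparameters and a hand-rolled Matern-5/2 kernel (matern52, lgcp_model, and the jitter constant are illustrative assumptions, not part of this commit):

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def matern52(x1, x2, lengthscale=0.2, variance=0.5):
    # Pairwise scaled distances between inputs of shape (n, 1)
    d = jnp.abs(x1[:, None, 0] - x2[None, :, 0]) / lengthscale
    s = jnp.sqrt(5.0) * d
    # Matern-5/2: variance * (1 + s + s^2 / 3) * exp(-s)
    return variance * (1.0 + s + s**2 / 3.0) * jnp.exp(-s)

def lgcp_model(x, y=None):
    n = x.shape[0]
    # Latent GP draw f ~ N(0, K); a small jitter keeps K positive definite
    K = matern52(x, x) + 1e-6 * jnp.eye(n)
    f = numpyro.sample("f", dist.MultivariateNormal(jnp.zeros(n), covariance_matrix=K))
    # Poisson observations with a log link: y_i ~ Poisson(exp(f_i))
    numpyro.sample("y", dist.Poisson(jnp.exp(f)), obs=y)

Modulo the registered hyperparameter priors, this matches the distribution that numpyro.sample("f", gp(x)) draws from in model() above; register_parameters additionally exposes the kernel hyperparameters as sample sites so NUTS can infer them.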

examples/numpyro_integration.py
Lines changed: 216 additions & 0 deletions

@@ -0,0 +1,216 @@
+# ---
+# jupyter:
+#   jupytext:
+#     cell_metadata_filter: -all
+#     custom_cell_magics: kql
+#     text_representation:
+#       extension: .py
+#       format_name: percent
+#       format_version: '1.3'
+#     jupytext_version: 1.17.3
+#   kernelspec:
+#     display_name: python3
+#     language: python
+#     name: python3
+# ---
+
+# %% [markdown]
+# # Joint Inference with Numpyro
+#
+# In this notebook, we demonstrate how to use [Numpyro](https://num.pyro.ai/) to perform fully Bayesian inference over the hyperparameters of a Gaussian process model.
+# We consider a scenario with a structured mean function (a linear model) and a GP capturing the residuals, and we infer the parameters of the linear model and the GP jointly.

+# %%
+from jax import config
+import jax.numpy as jnp
+import jax.random as jr
+import matplotlib.pyplot as plt
+import numpyro
+import numpyro.distributions as dist
+from numpyro.infer import (
+    MCMC,
+    NUTS,
+)
+
+import gpjax as gpx
+from gpjax.numpyro_extras import register_parameters
+
+config.update("jax_enable_x64", True)
+
+key = jr.key(42)
+
+# %% [markdown]
+# ## Data Generation
+#
+# We generate a synthetic dataset consisting of a linear trend, a periodic component, and some noise.
+
+# %%
+N = 100
+x = jnp.sort(jr.uniform(key, shape=(N, 1), minval=0.0, maxval=10.0), axis=0)
+
+# True parameters
+true_slope = 0.5
+true_intercept = 2.0
+true_period = 2.0
+true_lengthscale = 1.0
+true_noise = 0.1
+
+# Signal
+linear_trend = true_slope * x + true_intercept
+periodic_signal = jnp.sin(2 * jnp.pi * x / true_period)
+y_clean = linear_trend + periodic_signal
+
+# Observations
+y = y_clean + true_noise * jr.normal(key, shape=x.shape)
+
+plt.figure(figsize=(10, 5))
+plt.scatter(x, y, label="Data", alpha=0.6)
+plt.plot(x, y_clean, "k--", label="True Signal")
+plt.legend()
+plt.show()
+
+# %% [markdown]
+# ## Model Definition
+#
+# We define a GP model with a zero mean function (the linear trend is handled explicitly in the Numpyro model) and a kernel that is the product of a periodic kernel and an RBF kernel. This choice reflects our prior knowledge that the signal is locally periodic.
+
+# %%
+kernel = gpx.kernels.RBF() * gpx.kernels.Periodic()
+meanf = gpx.mean_functions.Zero()
+prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
+
+# We use a ConjugatePosterior since we assume Gaussian noise
+likelihood = gpx.likelihoods.Gaussian(num_datapoints=N)
+posterior = prior * likelihood
+
+# Bundle the training data.
+# Note: the parameter values above will be overwritten by Numpyro samples during inference.
+D = gpx.Dataset(X=x, y=y)
+
+# %% [markdown]
+# ## Joint Inference Loop
+#
+# We define a Numpyro model function that:
+# 1. Samples the parameters for the linear trend.
+# 2. Computes the residuals (data minus linear trend).
+# 3. Samples the GP hyperparameters using `register_parameters`.
+# 4. Computes the GP marginal log-likelihood on the residuals.
+# 5. Adds the GP log-likelihood to the joint density.
+
+
+# %%
+def model(X, Y):
+    # 1. Sample linear model parameters
+    slope = numpyro.sample("slope", dist.Normal(0.0, 2.0))
+    intercept = numpyro.sample("intercept", dist.Normal(0.0, 2.0))
+
+    # Calculate residuals
+    trend = slope * X + intercept
+    residuals = Y - trend
+
+    # 2. Register GP parameters
+    # This automatically samples parameters from the GPJax model
+    # and returns a model with updated values.
+    # We could specify custom priors, but we rely on the defaults here.
+    # register_parameters modifies the model in place (and returns it).
+    # Since Numpyro re-runs this function, we overwrite the parameters of the
+    # same object repeatedly, which is fine because they are completely
+    # determined by the sample sites.
+    p_posterior = register_parameters(posterior)
+
+    # Create dataset for residuals
+    D_resid = gpx.Dataset(X=X, y=residuals)
+
+    # 3. Compute MLL
+    # conjugate_mll computes log p(y | X, theta) analytically for Gaussian likelihoods.
+    mll = gpx.objectives.conjugate_mll(p_posterior, D_resid)
+
+    # 4. Add the MLL to the joint log density
+    numpyro.factor("gp_log_lik", mll)
+
+
+# %% [markdown]
+# ## Running MCMC
+#
+# We use the NUTS sampler to draw samples from the posterior.
+
+# %%
+nuts_kernel = NUTS(model)
+mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000, num_chains=1)
+mcmc.run(jr.key(0), x, y)
+
+mcmc.print_summary()
+
+# %% [markdown]
+# ## Analysis and Plotting
+#
+# We extract the samples and plot the predictions.
+
+# %%
+samples = mcmc.get_samples()
+
+
+# Helper to get predictions for a single posterior sample
+def predict(rng_key, sample_idx):
+    # Reconstruct the model with the sampled values
+
+    # Linear part
+    slope = samples["slope"][sample_idx]
+    intercept = samples["intercept"][sample_idx]
+    trend = slope * x + intercept
+
+    # GP part
+    # We use numpyro.handlers.substitute to inject the sampled values into
+    # register_parameters, reconstructing the GP model state for this sample.
+    sample_dict = {k: v[sample_idx] for k, v in samples.items()}
+
+    with numpyro.handlers.substitute(data=sample_dict):
+        # Calling register_parameters again updates the posterior object with this sample's values
+        p_posterior = register_parameters(posterior)
+
+    # Now predict on the residuals
+    residuals = y - trend
+    D_resid = gpx.Dataset(X=x, y=residuals)
+
+    latent_dist = p_posterior.predict(x, train_data=D_resid)
+    predictive_mean = latent_dist.mean
+    predictive_std = latent_dist.stddev()
+
+    return trend + predictive_mean, predictive_std
+
+
+# Plot
+plt.figure(figsize=(12, 6))
+plt.scatter(x, y, alpha=0.5, label="Data", color="gray")
+plt.plot(x, y_clean, "k--", label="True Signal")
+
+# Compute mean prediction (plugging in the posterior-mean parameters for efficiency)
+mean_slope = jnp.mean(samples["slope"])
+mean_intercept = jnp.mean(samples["intercept"])
+mean_trend = mean_slope * x + mean_intercept
+
+mean_samples = {k: jnp.mean(v, axis=0) for k, v in samples.items()}
+with numpyro.handlers.substitute(data=mean_samples):
+    p_posterior_mean = register_parameters(posterior)
+
+residuals_mean = y - mean_trend
+D_resid_mean = gpx.Dataset(X=x, y=residuals_mean)
+latent_dist = p_posterior_mean.predict(x, train_data=D_resid_mean)
+pred_mean = latent_dist.mean
+pred_std = latent_dist.stddev()
+
+total_mean = mean_trend.flatten() + pred_mean.flatten()
+std_flat = pred_std.flatten()
+
+plt.plot(x, total_mean, "b-", label="Posterior Mean")
+plt.fill_between(
+    x.flatten(),
+    total_mean - 2 * std_flat,
+    total_mean + 2 * std_flat,
+    color="b",
+    alpha=0.2,
+    label="95% CI (GP Uncertainty)",
+)
+
+plt.legend()
+plt.show()
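Both the predict helper and the plug-in prediction above hinge on numpyro.handlers.substitute: the handler replays model code with fixed values injected at named sample sites instead of drawing from their priors. A minimal sketch of just that mechanism (toy_model is illustrative, not part of this commit):

import numpyro
import numpyro.distributions as dist
from numpyro import handlers

def toy_model():
    # A single latent site; under substitute its value is injected, not sampled
    return numpyro.sample("z", dist.Normal(0.0, 1.0))

# seed supplies randomness for any sites left un-substituted; "z" is fixed at 3.0
with handlers.seed(rng_seed=0), handlers.substitute(data={"z": 3.0}):
    assert toy_model() == 3.0

This is why calling register_parameters(posterior) inside the substitute context reproduces the GP state for a given posterior sample: each sample site it creates receives the stored value rather than a fresh draw.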
