From 10b284852f60079e7298bc2eb6d865d261bb3a3b Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Mon, 24 Nov 2025 22:12:28 +0100 Subject: [PATCH 1/5] Add Numpyro MWE --- examples/heteroscedastic_inference.py | 36 ++- examples/lgcp_numpyro.py | 119 ++++++++++ examples/numpyro_integration.py | 216 ++++++++++++++++++ gpjax/numpyro_extras.py | 188 +++++++--------- gpjax/parameters.py | 101 ++++++++- mkdocs.yml | 2 + tests/test_numpyro_extras.py | 304 +++++++++----------------- 7 files changed, 655 insertions(+), 311 deletions(-) create mode 100644 examples/lgcp_numpyro.py create mode 100644 examples/numpyro_integration.py diff --git a/examples/heteroscedastic_inference.py b/examples/heteroscedastic_inference.py index 8c7542246..6650a6a83 100644 --- a/examples/heteroscedastic_inference.py +++ b/examples/heteroscedastic_inference.py @@ -357,8 +357,22 @@ alpha=0.3, label="One std. dev.", ) -ax.plot(xtest.squeeze(), predictive_mean - predictive_std, "--", color=cols[1], alpha=0.5, linewidth=0.75) -ax.plot(xtest.squeeze(), predictive_mean + predictive_std, "--", color=cols[1], alpha=0.5, linewidth=0.75) +ax.plot( + xtest.squeeze(), + predictive_mean - predictive_std, + "--", + color=cols[1], + alpha=0.5, + linewidth=0.75, +) +ax.plot( + xtest.squeeze(), + predictive_mean + predictive_std, + "--", + color=cols[1], + alpha=0.5, + linewidth=0.75, +) ax.fill_between( xtest.squeeze(), predictive_mean - 2 * predictive_std, @@ -367,8 +381,22 @@ alpha=0.1, label="Two std. dev.", ) -ax.plot(xtest.squeeze(), predictive_mean - 2 * predictive_std, "--", color=cols[1], alpha=0.5, linewidth=0.75) -ax.plot(xtest.squeeze(), predictive_mean + 2 * predictive_std, "--", color=cols[1], alpha=0.5, linewidth=0.75) +ax.plot( + xtest.squeeze(), + predictive_mean - 2 * predictive_std, + "--", + color=cols[1], + alpha=0.5, + linewidth=0.75, +) +ax.plot( + xtest.squeeze(), + predictive_mean + 2 * predictive_std, + "--", + color=cols[1], + alpha=0.5, + linewidth=0.75, +) ax.set_title("Sparse Heteroscedastic Regression") ax.legend(loc="best", fontsize="small") diff --git a/examples/lgcp_numpyro.py b/examples/lgcp_numpyro.py new file mode 100644 index 000000000..96f0a3749 --- /dev/null +++ b/examples/lgcp_numpyro.py @@ -0,0 +1,119 @@ +# %% +import jax.numpy as jnp +from jax import random +from jax import config +import numpy as np + +import gpjax as gpx +from gpjax import numpyro_extras +import numpyro +import numpyro.distributions as dist +from numpyro.infer import MCMC, NUTS +import arviz as az + +import matplotlib.pyplot as plt + +# Enable x64 support for JAX +config.update("jax_enable_x64", True) + +# Set random seed +key = random.PRNGKey(42) + +# Configure MCMC +num_warmup = 1000 +num_samples = 1000 +num_chains = 4 + +# Set device count for numpyro for parallel chains +numpyro.set_host_device_count(num_chains) + +# %% +# 1. Data: Coal Mining Disasters (1851-1962) +# Counts of disasters per year +counts = jnp.array([ + 4, 5, 4, 0, 1, 4, 3, 4, 0, 6, 3, 3, 4, 0, 2, 6, 3, 3, 5, 4, 5, 3, 1, 4, 4, 1, 5, 5, 3, 4, 2, 5, 2, 2, 3, 4, 2, 1, 3, 2, 2, 1, 1, 1, 1, 3, 0, 0, 1, 0, 1, 1, 0, 0, 3, 1, 0, 3, 2, 2, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 2, 1, 0, 0, 0, 1, 1, 0, 2, 3, 3, 1, 1, 2, 1, 1, 1, 1, 2, 4, 2, 0, 0, 1, 4, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1 +], dtype=jnp.float64) + +years = jnp.arange(1851, 1851 + len(counts), dtype=jnp.float64).reshape(-1, 1) +# Normalize years for better numerical stability in GP +years_norm = (years - years.min()) / (years.max() - years.min()) + +# %% +# 2. 
Model Definition +# We model the log-intensity log(lambda(t)) as a Gaussian Process. +# lambda(t) = exp(f(t)) +# y_i ~ Poisson(lambda(t_i)) + +# Mean function: Constant mean +mean_f = gpx.mean_functions.Constant(constant=jnp.array([0.0])) + +# Kernel: Matern52 +# We expect changes over decades, so lengthscale should be non-trivial. +# Since x is normalized to [0, 1], a lengthscale of 0.1 corresponds to ~11 years. +kernel = gpx.kernels.Matern52(lengthscale=0.2, variance=0.5) + +prior = gpx.gps.Prior(mean_function=mean_f, kernel=kernel) + +def model(x, y): + # Register GPJax parameters (lengthscale, variance, mean_constant) with Numpyro + gp = numpyro_extras.register_parameters(prior) + + # Sample the latent function f at the input locations x + f = numpyro.sample("f", gp(x)) + + # The intensity is exp(f) + rate = jnp.exp(f) + + # Observation model: Poisson + numpyro.sample("y", dist.Poisson(rate), obs=y) + +# %% +# 3. Inference +rng_key, rng_key_ = random.split(key) + +kernel_nuts = NUTS(model, target_accept_prob=0.9) +mcmc = MCMC( + kernel_nuts, + num_warmup=num_warmup, + num_samples=num_samples, + num_chains=num_chains, + progress_bar=True, + jit_model_args=True, +) + +# Run MCMC +# Note: We pass years_norm for stability, but we'll plot against original years +mcmc.run(rng_key_, x=years_norm, y=counts) + +# %% +# 4. Analysis & Plotting +mcmc.print_summary() + +# Extract samples +samples = mcmc.get_samples() +f_samples = samples["f"] +intensity_samples = jnp.exp(f_samples) + +# Compute statistics +mean_intensity = jnp.mean(intensity_samples, axis=0) +lower_ci = jnp.percentile(intensity_samples, 2.5, axis=0) +upper_ci = jnp.percentile(intensity_samples, 97.5, axis=0) + +# Plot +plt.figure(figsize=(12, 6)) +plt.bar(years.flatten(), counts, color="gray", alpha=0.5, label="Observed Counts", width=1.0) +plt.plot(years.flatten(), mean_intensity, color="C0", label="Posterior Mean Intensity", linewidth=2) +plt.fill_between(years.flatten(), lower_ci, upper_ci, color="C0", alpha=0.3, label="95% CI") + +plt.xlabel("Year") +plt.ylabel("Number of Disasters") +plt.title("Coal Mining Disasters: Log-Gaussian Cox Process (GPJax + Numpyro)") +plt.legend() +plt.grid(True, alpha=0.3) +plt.tight_layout() +plt.savefig("lgcp_coal_mining.png") +# plt.show() + +# Trace plot for diagnostics +az.plot_trace(mcmc, var_names=["kernel.lengthscale", "kernel.variance"]) +plt.tight_layout() \ No newline at end of file diff --git a/examples/numpyro_integration.py b/examples/numpyro_integration.py new file mode 100644 index 000000000..bed7b0f0f --- /dev/null +++ b/examples/numpyro_integration.py @@ -0,0 +1,216 @@ +# --- +# jupyter: +# jupytext: +# cell_metadata_filter: -all +# custom_cell_magics: kql +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.17.3 +# kernelspec: +# display_name: python3 +# language: python +# name: python3 +# --- + +# %% [markdown] +# # Joint Inference with Numpyro +# +# In this notebook, we demonstrate how to use [Numpyro](https://num.pyro.ai/) to perform fully Bayesian inference over the hyperparameters of a Gaussian process model. +# We will look at a scenario where we have a structured mean function (a linear model) and a GP capturing the residuals. We will infer the parameters of both the linear model and the GP jointly. 
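+#
+# In miniature, the pattern we build below is: draw the trend parameters with
+# `numpyro.sample`, de-trend the data, and score the GP on the residuals via
+# `numpyro.factor`. A sketch (here `mll_fn` is a hypothetical stand-in for the
+# marginal log-likelihood we compute later with `gpx.objectives.conjugate_mll`):
+#
+#     def sketch_model(X, Y):
+#         slope = numpyro.sample("slope", dist.Normal(0.0, 2.0))
+#         numpyro.factor("gp_log_lik", mll_fn(Y - slope * X))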
+ +# %% +from jax import config +import jax.numpy as jnp +import jax.random as jr +import matplotlib.pyplot as plt +import numpyro +import numpyro.distributions as dist +from numpyro.infer import ( + MCMC, + NUTS, +) + +import gpjax as gpx +from gpjax.numpyro_extras import register_parameters + +config.update("jax_enable_x64", True) + +key = jr.key(42) + +# %% [markdown] +# ## Data Generation +# +# We generate a synthetic dataset that consists of a linear trend, a periodic component, and some noise. + +# %% +N = 100 +x = jnp.sort(jr.uniform(key, shape=(N, 1), minval=0.0, maxval=10.0), axis=0) + +# True parameters +true_slope = 0.5 +true_intercept = 2.0 +true_period = 2.0 +true_lengthscale = 1.0 +true_noise = 0.1 + +# Signal +linear_trend = true_slope * x + true_intercept +periodic_signal = jnp.sin(2 * jnp.pi * x / true_period) +y_clean = linear_trend + periodic_signal + +# Observations +y = y_clean + true_noise * jr.normal(key, shape=x.shape) + +plt.figure(figsize=(10, 5)) +plt.scatter(x, y, label="Data", alpha=0.6) +plt.plot(x, y_clean, "k--", label="True Signal") +plt.legend() +plt.show() + +# %% [markdown] +# ## Model Definition +# +# We define a GP model with a generic mean function (zero for now, as we will handle the linear trend explicitly in the Numpyro model) and a kernel that is the product of a periodic kernel and an RBF kernel. This choice reflects our prior knowledge that the signal is locally periodic. + +# %% +kernel = gpx.kernels.RBF() * gpx.kernels.Periodic() +meanf = gpx.mean_functions.Zero() +prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel) + +# We will use a ConjugatePosterior since we assume Gaussian noise +likelihood = gpx.likelihoods.Gaussian(num_datapoints=N) +posterior = prior * likelihood + +# We initialise the model parameters. +# Note: These values will be overwritten by Numpyro samples during inference. +D = gpx.Dataset(X=x, y=y) + +# %% [markdown] +# ## Joint Inference Loop +# +# We define a Numpyro model function that: +# 1. Samples the parameters for the linear trend. +# 2. Computes the residuals (Data - Linear Trend). +# 3. Samples the GP hyperparameters using `register_parameters`. +# 4. Computes the GP marginal log-likelihood on the residuals. +# 5. Adds the GP log-likelihood to the joint density. + + +# %% +def model(X, Y): + # 1. Sample linear model parameters + slope = numpyro.sample("slope", dist.Normal(0.0, 2.0)) + intercept = numpyro.sample("intercept", dist.Normal(0.0, 2.0)) + + # Calculate residuals + trend = slope * X + intercept + residuals = Y - trend + + # 2. Register GP parameters + # This automatically samples parameters from the GPJax model + # and returns a model with updated values. + # We can specify custom priors if needed, but we'll rely on defaults here. + # register_parameters modifies the model in-place (and returns it). + # Since Numpyro re-runs this function, we are overwriting the parameters + # of the same object repeatedly, which is fine as they are completely determined + # by the sample sites. + p_posterior = register_parameters(posterior) + + # Create dataset for residuals + D_resid = gpx.Dataset(X=X, y=residuals) + + # 3. Compute MLL + # We use conjugate_mll which computes log p(y | X, theta) analytically for Gaussian likelihoods. + mll = gpx.objectives.conjugate_mll(p_posterior, D_resid) + + # 4. Add to potential + numpyro.factor("gp_log_lik", mll) + + +# %% [markdown] +# ## Running MCMC +# +# We use the NUTS sampler to draw samples from the posterior. 
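+#
+# A single chain keeps this example quick; for convergence diagnostics one
+# would typically run several chains instead, e.g. (a sketch — this requires
+# calling `numpyro.set_host_device_count(4)` before any JAX computation):
+#
+#     mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=1000, num_chains=4)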
+ +# %% +nuts_kernel = NUTS(model) +mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000, num_chains=1) +mcmc.run(jr.key(0), x, y) + +mcmc.print_summary() + +# %% [markdown] +# ## Analysis and Plotting +# +# We extract the samples and plot the predictions. + +# %% +samples = mcmc.get_samples() + + +# Helper to get predictions +def predict(rng_key, sample_idx): + # Reconstruct model with sampled values + + # Linear part + slope = samples["slope"][sample_idx] + intercept = samples["intercept"][sample_idx] + trend = slope * x + intercept + + # GP part + # We use numpyro.handlers.substitute to inject the sampled values into register_parameters + # to reconstruct the GP model state for this sample. + sample_dict = {k: v[sample_idx] for k, v in samples.items()} + + with numpyro.handlers.substitute(data=sample_dict): + # We call register_parameters again to update the posterior object with this sample's values + p_posterior = register_parameters(posterior) + + # Now predict on residuals + residuals = y - trend + D_resid = gpx.Dataset(X=x, y=residuals) + + latent_dist = p_posterior.predict(x, train_data=D_resid) + predictive_mean = latent_dist.mean + predictive_std = latent_dist.stddev() + + return trend + predictive_mean, predictive_std + + +# Plot +plt.figure(figsize=(12, 6)) +plt.scatter(x, y, alpha=0.5, label="Data", color="gray") +plt.plot(x, y_clean, "k--", label="True Signal") + +# Compute mean prediction (using mean of samples for efficiency) +mean_slope = jnp.mean(samples["slope"]) +mean_intercept = jnp.mean(samples["intercept"]) +mean_trend = mean_slope * x + mean_intercept + +mean_samples = {k: jnp.mean(v, axis=0) for k, v in samples.items()} +with numpyro.handlers.substitute(data=mean_samples): + p_posterior_mean = register_parameters(posterior) + +residuals_mean = y - mean_trend +D_resid_mean = gpx.Dataset(X=x, y=residuals_mean) +latent_dist = p_posterior_mean.predict(x, train_data=D_resid_mean) +pred_mean = latent_dist.mean +pred_std = latent_dist.stddev() + +total_mean = mean_trend.flatten() + pred_mean.flatten() +std_flat = pred_std.flatten() + +plt.plot(x, total_mean, "b-", label="Posterior Mean") +plt.fill_between( + x.flatten(), + total_mean - 2 * std_flat, + total_mean + 2 * std_flat, + color="b", + alpha=0.2, + label="95% CI (GP Uncertainty)", +) + +plt.legend() +plt.show() diff --git a/gpjax/numpyro_extras.py b/gpjax/numpyro_extras.py index 846c39af1..d0845d497 100644 --- a/gpjax/numpyro_extras.py +++ b/gpjax/numpyro_extras.py @@ -1,106 +1,88 @@ -import math - -import jax -import jax.numpy as jnp -from numpyro.distributions.transforms import Transform - -# ----------------------------------------------------------------------------- -# Implementation: FillTriangularTransform -# ----------------------------------------------------------------------------- +import typing as tp + +from flax import nnx +import jax.tree_util as jtu +import numpyro +import numpyro.distributions as dist + +from gpjax.parameters import ( + FillTriangularTransform, + Parameter, +) + + +def _get_default_prior(tag, shape, ndim): + if tag in ("positive", "non_negative"): + return dist.LogNormal(0.0, 1.0).expand(shape).to_event(ndim) + if tag == "real": + return dist.Normal(0.0, 1.0).expand(shape).to_event(ndim) + if tag == "sigmoid": + return dist.Uniform(0.0, 1.0).expand(shape).to_event(ndim) + if tag == "lower_triangular": + N = shape[-1] + K = N * (N + 1) // 2 + batch_shape = shape[:-2] + base_shape = batch_shape + (K,) + base_dist = dist.Normal(0.0, 1.0).expand(base_shape).to_event(1) + td = 
dist.TransformedDistribution(base_dist, FillTriangularTransform()) + return td.to_event(len(batch_shape)) + return dist.Normal(0.0, 1.0).expand(shape).to_event(ndim) + + +def register_parameters( + model: nnx.Module, + priors: tp.Dict[str, dist.Distribution] | None = None, + prefix: str = "", +) -> nnx.Module: + """ + Register GPJax parameters with Numpyro. + Args: + model: The GPJax model (flax.nnx.Module). + priors: Optional dictionary mapping parameter names to Numpyro distributions. + prefix: Optional prefix for parameter names. -class FillTriangularTransform(Transform): + Returns: + The model with parameters updated from Numpyro samples. """ - Transform that maps a vector of length n(n+1)/2 to an n x n lower triangular matrix. - The ordering is assumed to be: - (0,0), (1,0), (1,1), (2,0), (2,1), (2,2), ..., (n-1, n-1) - """ - - # Note: The base class provides `inv` through _InverseTransform wrapping _inverse. - - def __call__(self, x): - """ - Forward transformation. - - Parameters - ---------- - x : array_like, shape (..., L) - Input vector with L = n(n+1)/2 for some integer n. - - Returns - ------- - y : array_like, shape (..., n, n) - Lower-triangular matrix (with zeros in the upper triangle) filled in - row-major order (i.e. [ (0,0), (1,0), (1,1), ... ]). - """ - L = x.shape[-1] - # Use static (Python) math.sqrt to compute n. This avoids tracer issues. - n = int((-1 + math.sqrt(1 + 8 * L)) // 2) - if n * (n + 1) // 2 != L: - raise ValueError("Last dimension must equal n(n+1)/2 for some integer n.") - - def fill_single(vec): - out = jnp.zeros((n, n), dtype=vec.dtype) - row, col = jnp.tril_indices(n) - return out.at[row, col].set(vec) - - if x.ndim == 1: - return fill_single(x) - else: - batch_shape = x.shape[:-1] - flat_x = x.reshape((-1, L)) - out = jax.vmap(fill_single)(flat_x) - return out.reshape(batch_shape + (n, n)) - - def _inverse(self, y): - """ - Inverse transformation. - - Parameters - ---------- - y : array_like, shape (..., n, n) - Lower triangular matrix. - - Returns - ------- - x : array_like, shape (..., n(n+1)/2) - The vector containing the elements from the lower-triangular portion of y. - """ - if y.ndim < 2: - raise ValueError("Input to inverse must be at least two-dimensional.") - n = y.shape[-1] - if y.shape[-2] != n: - raise ValueError( - "Input matrix must be square; got shape %s" % str(y.shape[-2:]) - ) - - row, col = jnp.tril_indices(n) - - def inv_single(mat): - return mat[row, col] - - if y.ndim == 2: - return inv_single(y) - else: - batch_shape = y.shape[:-2] - flat_y = y.reshape((-1, n, n)) - out = jax.vmap(inv_single)(flat_y) - return out.reshape(batch_shape + (n * (n + 1) // 2,)) - - def log_abs_det_jacobian(self, x, y, intermediates=None): - # Since the transform simply reorders the vector into a matrix, the Jacobian determinant is 1. - return jnp.zeros(x.shape[:-1]) - - @property - def sign(self): - # The reordering transformation has a positive derivative everywhere. - return 1.0 - - # Implement tree_flatten and tree_unflatten because base Transform expects them. - def tree_flatten(self): - # This transform is stateless. 
- return (), {} - - @classmethod - def tree_unflatten(cls, aux_data, children): - return cls() + if priors is None: + priors = {} + + def _param_callback(path, param): + if not isinstance(param, Parameter): + return param + + # Construct name + name_parts = [] + for p in path: + if isinstance(p, jtu.DictKey): + name_parts.append(str(p.key)) + elif isinstance(p, jtu.SequenceKey): + name_parts.append(str(p.idx)) + elif isinstance(p, jtu.GetAttrKey): + name_parts.append(str(p.name)) + else: + name_parts.append(str(p)) + + name = ".".join(name_parts) + if prefix: + name = f"{prefix}.{name}" + + # Determine prior + prior = priors.get(name) + if prior is None: + prior = _get_default_prior(param.tag, param.value.shape, param.value.ndim) + + # Sample + value = numpyro.sample(name, prior) + + # Update parameter + return param.replace(value) + + graphdef, state = nnx.split(model) + + new_state = jtu.tree_map_with_path( + _param_callback, state, is_leaf=lambda x: isinstance(x, Parameter) + ) + + return nnx.merge(graphdef, new_state) diff --git a/gpjax/parameters.py b/gpjax/parameters.py index 71b587c5e..0bdcb9957 100644 --- a/gpjax/parameters.py +++ b/gpjax/parameters.py @@ -1,18 +1,115 @@ +import math import typing as tp from flax import nnx +import jax from jax.experimental import checkify import jax.numpy as jnp import jax.tree_util as jtu from jax.typing import ArrayLike import numpyro.distributions.transforms as npt -from gpjax.numpyro_extras import FillTriangularTransform - T = tp.TypeVar("T", bound=tp.Union[ArrayLike, list[float]]) ParameterTag = str +class FillTriangularTransform(npt.Transform): + """ + Transform that maps a vector of length n(n+1)/2 to an n x n lower triangular matrix. + The ordering is assumed to be: + (0,0), (1,0), (1,1), (2,0), (2,1), (2,2), ..., (n-1, n-1) + """ + + # Note: The base class provides `inv` through _InverseTransform wrapping _inverse. + + def __call__(self, x): + """ + Forward transformation. + + Parameters + ---------- + x : array_like, shape (..., L) + Input vector with L = n(n+1)/2 for some integer n. + + Returns + ------- + y : array_like, shape (..., n, n) + Lower-triangular matrix (with zeros in the upper triangle) filled in + row-major order (i.e. [ (0,0), (1,0), (1,1), ... ]). + """ + L = x.shape[-1] + # Use static (Python) math.sqrt to compute n. This avoids tracer issues. + n = int((-1 + math.sqrt(1 + 8 * L)) // 2) + if n * (n + 1) // 2 != L: + raise ValueError("Last dimension must equal n(n+1)/2 for some integer n.") + + def fill_single(vec): + out = jnp.zeros((n, n), dtype=vec.dtype) + row, col = jnp.tril_indices(n) + return out.at[row, col].set(vec) + + if x.ndim == 1: + return fill_single(x) + else: + batch_shape = x.shape[:-1] + flat_x = x.reshape((-1, L)) + out = jax.vmap(fill_single)(flat_x) + return out.reshape(batch_shape + (n, n)) + + def _inverse(self, y): + """ + Inverse transformation. + + Parameters + ---------- + y : array_like, shape (..., n, n) + Lower triangular matrix. + + Returns + ------- + x : array_like, shape (..., n(n+1)/2) + The vector containing the elements from the lower-triangular portion of y. 
+ """ + if y.ndim < 2: + raise ValueError("Input to inverse must be at least two-dimensional.") + n = y.shape[-1] + if y.shape[-2] != n: + raise ValueError( + "Input matrix must be square; got shape %s" % str(y.shape[-2:]) + ) + + row, col = jnp.tril_indices(n) + + def inv_single(mat): + return mat[row, col] + + if y.ndim == 2: + return inv_single(y) + else: + batch_shape = y.shape[:-2] + flat_y = y.reshape((-1, n, n)) + out = jax.vmap(inv_single)(flat_y) + return out.reshape(batch_shape + (n * (n + 1) // 2,)) + + def log_abs_det_jacobian(self, x, y, intermediates=None): + # Since the transform simply reorders the vector into a matrix, the Jacobian determinant is 1. + return jnp.zeros(x.shape[:-1]) + + @property + def sign(self): + # The reordering transformation has a positive derivative everywhere. + return 1.0 + + # Implement tree_flatten and tree_unflatten because base Transform expects them. + def tree_flatten(self): + # This transform is stateless. + return (), {} + + @classmethod + def tree_unflatten(cls, aux_data, children): + return cls() + + def transform( params: nnx.State, params_bijection: tp.Dict[str, npt.Transform], diff --git a/mkdocs.yml b/mkdocs.yml index 8e70d9d49..0526e19bd 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -28,6 +28,8 @@ nav: - Stochastic sparse GPs: _examples/uncollapsed_vi.md - Multi-output GPs for Ocean Modelling: _examples/oceanmodelling.md - Heteroscedastic Inference: _examples/heteroscedastic_inference.md + - Numpyro Integration: _examples/numpyro_integration.md + - Log-Gaussian Cox Process: _examples/lgcp_numpyro.md - ๐Ÿ“– Guides for customisation: - Kernels: _examples/constructing_new_kernels.md - Likelihoods: _examples/likelihoods_guide.md diff --git a/tests/test_numpyro_extras.py b/tests/test_numpyro_extras.py index ce7764352..d7a526a45 100644 --- a/tests/test_numpyro_extras.py +++ b/tests/test_numpyro_extras.py @@ -1,203 +1,103 @@ -from jax import ( - grad, - jit, -) +from flax import nnx import jax.numpy as jnp -import numpy as np -import pytest - -from gpjax.numpyro_extras import FillTriangularTransform - - -# Helper function to generate a test input vector for a given matrix size. -def generate_test_vector(n): - """ - Generate a sequential vector of shape (n(n+1)/2,) with values [1, 2, ..., n(n+1)/2]. - """ - L = n * (n + 1) // 2 - return jnp.arange(1, L + 1, dtype=jnp.float32) - - -# ----------------- Unit tests using PyTest ----------------- - - -@pytest.mark.parametrize("n", [1, 2, 3, 4]) -def test_forward_inverse(n): - """ - Test that for a range of input sizes the forward transform correctly fills - an n x n lower triangular matrix and that the inverse recovers the original vector. - """ - ft = FillTriangularTransform() - vec = generate_test_vector(n) - L = ft(vec) - - # Construct the expected n x n lower triangular matrix - expected = jnp.zeros((n, n), dtype=vec.dtype) - row, col = jnp.tril_indices(n) - expected = expected.at[row, col].set(vec) - - np.testing.assert_allclose(L, expected, rtol=1e-6) - - # Check that the inverse recovers the original vector - vec_rec = ft.inv(L) - np.testing.assert_allclose(vec, vec_rec, rtol=1e-6) - - -@pytest.mark.parametrize("n", [1, 2, 3, 4]) -def test_batched_forward_inverse(n): - """ - Test that the transform correctly handles batched inputs. 
- """ - ft = FillTriangularTransform() - batch_size = 5 - vec = jnp.stack([generate_test_vector(n) for _ in range(batch_size)], axis=0) - L = ft(vec) # Expected shape: (batch_size, n, n) - assert L.shape == (batch_size, n, n) - - vec_rec = ft.inv(L) # Expected shape: (batch_size, n(n+1)/2) - assert vec_rec.shape == (batch_size, n * (n + 1) // 2) - np.testing.assert_allclose(vec, vec_rec, rtol=1e-6) - - -def test_jit_forward(): - """ - Test that the forward transformation works correctly when compiled with JIT. - """ - ft = FillTriangularTransform() - n = 3 - vec = generate_test_vector(n) - - jit_forward = jit(ft) - L = ft(vec) - L_jit = jit_forward(vec) - np.testing.assert_allclose(L, L_jit, rtol=1e-6) - - -def test_jit_inverse(): - """ - Test that the inverse transformation works correctly when compiled with JIT. - """ - ft = FillTriangularTransform() - n = 3 - vec = generate_test_vector(n) - L_mat = ft(vec) - - # Wrap the inverse call in a lambda to avoid hashing the unhashable _InverseTransform. - jit_inverse = jit(lambda y: ft.inv(y)) - vec_rec = ft.inv(L_mat) - vec_rec_jit = jit_inverse(L_mat) - np.testing.assert_allclose(vec_rec, vec_rec_jit, rtol=1e-6) - - -def test_grad_forward(): - """ - Test that JAX gradients can be computed for the forward transform. - We define a simple function that sums the output matrix. - Since the forward transform is just a reordering, the gradient should be 1 - for every element in the input vector. - """ - ft = FillTriangularTransform() - n = 3 - vec = generate_test_vector(n) - - # Define a scalar function f(x) = sum(forward(x)) - f = lambda x: jnp.sum(ft(x)) - grad_f = grad(f)(vec) - np.testing.assert_allclose(grad_f, jnp.ones_like(vec), rtol=1e-6) - - -def test_grad_inverse(): - """ - Test that gradients flow through the inverse transformation. - Define a simple scalar function on the inverse such that g(y) = sum(inv(y)). - The gradient with respect to y should be one on the lower triangular indices. - """ - ft = FillTriangularTransform() - n = 3 - vec = generate_test_vector(n) - L = ft(vec) - - g = lambda y: jnp.sum(ft.inv(y)) - grad_g = grad(g)(L) - - # Construct the expected gradient matrix: zeros everywhere except ones on the lower triangle. - grad_expected = jnp.zeros_like(L) - row, col = jnp.tril_indices(n) - grad_expected = grad_expected.at[row, col].set(1.0) - np.testing.assert_allclose(grad_g, grad_expected, rtol=1e-6) - - -def test_invalid_dimension_error(): - """ - Test that the FillTriangularTransform correctly raises a ValueError when - the last dimension doesn't equal n(n+1)/2 for some integer n. - """ - ft = FillTriangularTransform() - - # Create vectors with invalid dimensions that aren't n(n+1)/2 for any integer n - invalid_dims = [2, 4, 5, 7, 8, 11, 13, 14, 17, 19, 20] - - for dim in invalid_dims: - vec = jnp.ones(dim) - with pytest.raises( - ValueError, - match="Last dimension must equal n\\(n\\+1\\)/2 for some integer n\\.", - ): - ft(vec) - - # Verify that valid dimensions don't raise errors - valid_dims = [1, 3, 6, 10, 15, 21] # n(n+1)/2 for n=1,2,3,4,5,6 - - for dim in valid_dims: - vec = jnp.ones(dim) - try: - ft(vec) - except ValueError: - pytest.fail( - f"FillTriangularTransform raised ValueError for valid dimension {dim}" - ) - - -def test_inverse_dimension_error(): - """ - Test that the FillTriangularTransform.inv correctly raises a ValueError when - the input has less than two dimensions. 
- """ - ft = FillTriangularTransform() - - # Create a one-dimensional array - vec = jnp.ones(3) # 1D array with 3 elements - - # Try to call inverse on the 1D array, should fail - with pytest.raises( - ValueError, match="Input to inverse must be at least two-dimensional." - ): - ft.inv(vec) - - -def test_inverse_non_square_error(): - """ - Test that the FillTriangularTransform.inv correctly raises a ValueError when - the input matrix is not square. - """ - ft = FillTriangularTransform() - - # Create non-square matrices of different shapes - non_square_matrices = [ - jnp.ones((3, 4)), # 3x4 matrix - jnp.ones((5, 2)), # 5x2 matrix - jnp.ones((1, 3)), # 1x3 matrix - ] - - for matrix in non_square_matrices: - # Extract dimensions - dim1, dim2 = matrix.shape[-2:] - # Use a simpler regex pattern that doesn't include parentheses - error_pattern = "Input matrix must be square; got shape" - with pytest.raises(ValueError, match=error_pattern): - ft.inv(matrix) - - # Test with batched non-square matrices - batched_non_square = jnp.ones((2, 3, 4)) # Batch of 2 matrices of shape 3x4 - with pytest.raises(ValueError, match="Input matrix must be square"): - ft.inv(batched_non_square) +import numpyro.distributions as dist +from numpyro.handlers import ( + seed, + trace, +) + +from gpjax.numpyro_extras import register_parameters +from gpjax.parameters import ( + PositiveReal, + Real, +) + + +class MockSubModule(nnx.Module): + def __init__(self): + self.c = Real(jnp.array(3.0)) + + +class MockModel(nnx.Module): + def __init__(self): + self.a = PositiveReal(jnp.array(1.0)) + self.b = Real(jnp.array(2.0)) + self.submodule = MockSubModule() + + +def test_register_parameters_default_priors(): + model = MockModel() + + def model_fn(): + return register_parameters(model) + + with seed(rng_seed=0): + tr = trace(model_fn).get_trace() + + # Check sites exist + assert "a" in tr + assert "b" in tr + assert "submodule.c" in tr + + # Check distributions + # a: PositiveReal -> LogNormal + # LogNormal is a TransformedDistribution. + assert isinstance(tr["a"]["fn"], dist.LogNormal) + + # b: Real -> Normal + # If scalar, to_event(0) returns Normal. If vector, to_event(1) returns Independent. 
+ if isinstance(tr["b"]["fn"], dist.Independent): + assert isinstance(tr["b"]["fn"].base_dist, dist.Normal) + else: + assert isinstance(tr["b"]["fn"], dist.Normal) + + # submodule.c: Real -> Normal + if isinstance(tr["submodule.c"]["fn"], dist.Independent): + assert isinstance(tr["submodule.c"]["fn"].base_dist, dist.Normal) + else: + assert isinstance(tr["submodule.c"]["fn"], dist.Normal) + + # Check values in updated model + with seed(rng_seed=0): + updated_model = model_fn() + + assert jnp.allclose(updated_model.a.value, tr["a"]["value"]) + assert jnp.allclose(updated_model.b.value, tr["b"]["value"]) + assert jnp.allclose(updated_model.submodule.c.value, tr["submodule.c"]["value"]) + + # Verify original values were different (random sample != 1.0) + assert not jnp.allclose(updated_model.a.value, 1.0) + + +def test_register_parameters_custom_priors(): + model = MockModel() + + priors = {"a": dist.Gamma(2.0, 2.0), "submodule.c": dist.Cauchy(0.0, 1.0)} + + def model_fn(): + return register_parameters(model, priors=priors) + + with seed(rng_seed=0): + tr = trace(model_fn).get_trace() + + assert isinstance(tr["a"]["fn"], dist.Gamma) + # b should use default (Normal wrapped in Independent or Normal) + if isinstance(tr["b"]["fn"], dist.Independent): + assert isinstance(tr["b"]["fn"].base_dist, dist.Normal) + else: + assert isinstance(tr["b"]["fn"], dist.Normal) + assert isinstance(tr["submodule.c"]["fn"], dist.Cauchy) + + +def test_register_parameters_prefix(): + model = MockModel() + + def model_fn(): + return register_parameters(model, prefix="foo") + + with seed(rng_seed=0): + tr = trace(model_fn).get_trace() + + assert "foo.a" in tr + assert "foo.b" in tr + assert "foo.submodule.c" in tr From 9d90f32c7ed2cc33e93d348650309c8c2b63148a Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Mon, 24 Nov 2025 23:37:44 +0100 Subject: [PATCH 2/5] Add explicit prior specification --- .github/workflows/test_docs.yml | 28 ++++ examples/lgcp_numpyro.py | 119 -------------- examples/numpyro_integration.py | 106 +++++++------ gpjax/distributions.py | 3 + gpjax/numpyro_extras.py | 26 +--- gpjax/parameters.py | 12 +- tests/test_heteroscedastic.py | 2 +- tests/test_numpyro_extras.py | 267 +++++++++++++++++++++++++------- tests/test_parameters.py | 259 ++++++++++++++++++++++--------- 9 files changed, 501 insertions(+), 321 deletions(-) delete mode 100644 examples/lgcp_numpyro.py diff --git a/.github/workflows/test_docs.yml b/.github/workflows/test_docs.yml index 3da520b65..77dd9ff17 100644 --- a/.github/workflows/test_docs.yml +++ b/.github/workflows/test_docs.yml @@ -3,6 +3,11 @@ name: Test documentation on: pull_request: +permissions: + contents: read + pages: write + id-token: write + jobs: test-docs: # Functionality for testing documentation builds on multiple OSes and Python versions @@ -42,3 +47,26 @@ jobs: uv sync --extra docs uv run python docs/scripts/gen_examples.py --execute uv run mkdocs build + + - name: Upload built docs artifact + uses: actions/upload-artifact@v4 + with: + name: docs-site-html + path: site + + - name: Upload Pages artifact + uses: actions/upload-pages-artifact@v3 + with: + path: site + + deploy-docs-preview: + needs: test-docs + if: github.event_name == 'pull_request' + runs-on: ubuntu-latest + environment: + name: github-pages + url: ${{ steps.deployment.outputs.page_url }} + steps: + - name: Deploy MkDocs preview + id: deployment + uses: actions/deploy-pages@v4 diff --git a/examples/lgcp_numpyro.py b/examples/lgcp_numpyro.py deleted file mode 100644 index 96f0a3749..000000000 
--- a/examples/lgcp_numpyro.py +++ /dev/null @@ -1,119 +0,0 @@ -# %% -import jax.numpy as jnp -from jax import random -from jax import config -import numpy as np - -import gpjax as gpx -from gpjax import numpyro_extras -import numpyro -import numpyro.distributions as dist -from numpyro.infer import MCMC, NUTS -import arviz as az - -import matplotlib.pyplot as plt - -# Enable x64 support for JAX -config.update("jax_enable_x64", True) - -# Set random seed -key = random.PRNGKey(42) - -# Configure MCMC -num_warmup = 1000 -num_samples = 1000 -num_chains = 4 - -# Set device count for numpyro for parallel chains -numpyro.set_host_device_count(num_chains) - -# %% -# 1. Data: Coal Mining Disasters (1851-1962) -# Counts of disasters per year -counts = jnp.array([ - 4, 5, 4, 0, 1, 4, 3, 4, 0, 6, 3, 3, 4, 0, 2, 6, 3, 3, 5, 4, 5, 3, 1, 4, 4, 1, 5, 5, 3, 4, 2, 5, 2, 2, 3, 4, 2, 1, 3, 2, 2, 1, 1, 1, 1, 3, 0, 0, 1, 0, 1, 1, 0, 0, 3, 1, 0, 3, 2, 2, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 2, 1, 0, 0, 0, 1, 1, 0, 2, 3, 3, 1, 1, 2, 1, 1, 1, 1, 2, 4, 2, 0, 0, 1, 4, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1 -], dtype=jnp.float64) - -years = jnp.arange(1851, 1851 + len(counts), dtype=jnp.float64).reshape(-1, 1) -# Normalize years for better numerical stability in GP -years_norm = (years - years.min()) / (years.max() - years.min()) - -# %% -# 2. Model Definition -# We model the log-intensity log(lambda(t)) as a Gaussian Process. -# lambda(t) = exp(f(t)) -# y_i ~ Poisson(lambda(t_i)) - -# Mean function: Constant mean -mean_f = gpx.mean_functions.Constant(constant=jnp.array([0.0])) - -# Kernel: Matern52 -# We expect changes over decades, so lengthscale should be non-trivial. -# Since x is normalized to [0, 1], a lengthscale of 0.1 corresponds to ~11 years. -kernel = gpx.kernels.Matern52(lengthscale=0.2, variance=0.5) - -prior = gpx.gps.Prior(mean_function=mean_f, kernel=kernel) - -def model(x, y): - # Register GPJax parameters (lengthscale, variance, mean_constant) with Numpyro - gp = numpyro_extras.register_parameters(prior) - - # Sample the latent function f at the input locations x - f = numpyro.sample("f", gp(x)) - - # The intensity is exp(f) - rate = jnp.exp(f) - - # Observation model: Poisson - numpyro.sample("y", dist.Poisson(rate), obs=y) - -# %% -# 3. Inference -rng_key, rng_key_ = random.split(key) - -kernel_nuts = NUTS(model, target_accept_prob=0.9) -mcmc = MCMC( - kernel_nuts, - num_warmup=num_warmup, - num_samples=num_samples, - num_chains=num_chains, - progress_bar=True, - jit_model_args=True, -) - -# Run MCMC -# Note: We pass years_norm for stability, but we'll plot against original years -mcmc.run(rng_key_, x=years_norm, y=counts) - -# %% -# 4. 
Analysis & Plotting -mcmc.print_summary() - -# Extract samples -samples = mcmc.get_samples() -f_samples = samples["f"] -intensity_samples = jnp.exp(f_samples) - -# Compute statistics -mean_intensity = jnp.mean(intensity_samples, axis=0) -lower_ci = jnp.percentile(intensity_samples, 2.5, axis=0) -upper_ci = jnp.percentile(intensity_samples, 97.5, axis=0) - -# Plot -plt.figure(figsize=(12, 6)) -plt.bar(years.flatten(), counts, color="gray", alpha=0.5, label="Observed Counts", width=1.0) -plt.plot(years.flatten(), mean_intensity, color="C0", label="Posterior Mean Intensity", linewidth=2) -plt.fill_between(years.flatten(), lower_ci, upper_ci, color="C0", alpha=0.3, label="95% CI") - -plt.xlabel("Year") -plt.ylabel("Number of Disasters") -plt.title("Coal Mining Disasters: Log-Gaussian Cox Process (GPJax + Numpyro)") -plt.legend() -plt.grid(True, alpha=0.3) -plt.tight_layout() -plt.savefig("lgcp_coal_mining.png") -# plt.show() - -# Trace plot for diagnostics -az.plot_trace(mcmc, var_names=["kernel.lengthscale", "kernel.variance"]) -plt.tight_layout() \ No newline at end of file diff --git a/examples/numpyro_integration.py b/examples/numpyro_integration.py index bed7b0f0f..f643b33ef 100644 --- a/examples/numpyro_integration.py +++ b/examples/numpyro_integration.py @@ -30,6 +30,7 @@ from numpyro.infer import ( MCMC, NUTS, + Predictive, ) import gpjax as gpx @@ -67,7 +68,7 @@ plt.scatter(x, y, label="Data", alpha=0.6) plt.plot(x, y_clean, "k--", label="True Signal") plt.legend() -plt.show() +# plt.show() # %% [markdown] # ## Model Definition @@ -75,12 +76,30 @@ # We define a GP model with a generic mean function (zero for now, as we will handle the linear trend explicitly in the Numpyro model) and a kernel that is the product of a periodic kernel and an RBF kernel. This choice reflects our prior knowledge that the signal is locally periodic. # %% -kernel = gpx.kernels.RBF() * gpx.kernels.Periodic() +# Define priors +lengthscale_prior = dist.LogNormal(0.0, 1.0) +variance_prior = dist.LogNormal(0.0, 1.0) +period_prior = dist.LogNormal(0.0, 0.5) +noise_prior = dist.LogNormal(0.0, 1.0) + +# Define Kernel with priors +# We can explicitly attach priors to the parameters +kernel = gpx.kernels.RBF( + lengthscale=gpx.parameters.PositiveReal(1.0, prior=lengthscale_prior), + variance=gpx.parameters.PositiveReal(1.0, prior=variance_prior), +) * gpx.kernels.Periodic( + lengthscale=gpx.parameters.PositiveReal(1.0, prior=lengthscale_prior), + period=gpx.parameters.PositiveReal(1.0, prior=period_prior), +) + meanf = gpx.mean_functions.Zero() prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel) # We will use a ConjugatePosterior since we assume Gaussian noise -likelihood = gpx.likelihoods.Gaussian(num_datapoints=N) +likelihood = gpx.likelihoods.Gaussian( + num_datapoints=N, + obs_stddev=gpx.parameters.NonNegativeReal(1.0, prior=noise_prior), +) posterior = prior * likelihood # We initialise the model parameters. @@ -111,7 +130,8 @@ def model(X, Y): # 2. Register GP parameters # This automatically samples parameters from the GPJax model # and returns a model with updated values. - # We can specify custom priors if needed, but we'll rely on defaults here. + # We attached priors to the parameters during model definition, + # so register_parameters will use those. # register_parameters modifies the model in-place (and returns it). 
# Since Numpyro re-runs this function, we are overwriting the parameters # of the same object repeatedly, which is fine as they are completely determined @@ -150,67 +170,59 @@ def model(X, Y): samples = mcmc.get_samples() -# Helper to get predictions -def predict(rng_key, sample_idx): - # Reconstruct model with sampled values - - # Linear part - slope = samples["slope"][sample_idx] - intercept = samples["intercept"][sample_idx] - trend = slope * x + intercept +def predict_fn(X_new, Y_train): + # 1. Sample linear model parameters + slope = numpyro.sample("slope", dist.Normal(0.0, 2.0)) + intercept = numpyro.sample("intercept", dist.Normal(0.0, 2.0)) - # GP part - # We use numpyro.handlers.substitute to inject the sampled values into register_parameters - # to reconstruct the GP model state for this sample. - sample_dict = {k: v[sample_idx] for k, v in samples.items()} + # Calculate residuals + trend_train = slope * x + intercept + residuals = Y_train - trend_train - with numpyro.handlers.substitute(data=sample_dict): - # We call register_parameters again to update the posterior object with this sample's values - p_posterior = register_parameters(posterior) + # 2. Register GP parameters + p_posterior = register_parameters(posterior) - # Now predict on residuals - residuals = y - trend + # Create dataset for residuals D_resid = gpx.Dataset(X=x, y=residuals) - latent_dist = p_posterior.predict(x, train_data=D_resid) - predictive_mean = latent_dist.mean - predictive_std = latent_dist.stddev() + # 3. Compute latent GP distribution + latent_dist = p_posterior.predict(X_new, train_data=D_resid) - return trend + predictive_mean, predictive_std + # 4. Sample latent function values + f = numpyro.sample("f", latent_dist) + f = f.reshape((-1, 1)) + # 5. Compute and return total prediction + total_prediction = slope * X_new + intercept + f + numpyro.deterministic("y_pred", total_prediction) + return total_prediction -# Plot -plt.figure(figsize=(12, 6)) -plt.scatter(x, y, alpha=0.5, label="Data", color="gray") -plt.plot(x, y_clean, "k--", label="True Signal") -# Compute mean prediction (using mean of samples for efficiency) -mean_slope = jnp.mean(samples["slope"]) -mean_intercept = jnp.mean(samples["intercept"]) -mean_trend = mean_slope * x + mean_intercept +# Create predictive utility +predictive = Predictive(predict_fn, posterior_samples=samples) -mean_samples = {k: jnp.mean(v, axis=0) for k, v in samples.items()} -with numpyro.handlers.substitute(data=mean_samples): - p_posterior_mean = register_parameters(posterior) +# Generate predictions +predictions = predictive(jr.key(1), X_new=x, Y_train=y) +y_pred = predictions["y_pred"] -residuals_mean = y - mean_trend -D_resid_mean = gpx.Dataset(X=x, y=residuals_mean) -latent_dist = p_posterior_mean.predict(x, train_data=D_resid_mean) -pred_mean = latent_dist.mean -pred_std = latent_dist.stddev() +# Compute statistics +mean_prediction = jnp.mean(y_pred, axis=0) +std_prediction = jnp.std(y_pred, axis=0) -total_mean = mean_trend.flatten() + pred_mean.flatten() -std_flat = pred_std.flatten() +# Plot +plt.figure(figsize=(12, 6)) +plt.scatter(x, y, alpha=0.5, label="Data", color="gray") +plt.plot(x, y_clean, "k--", label="True Signal") -plt.plot(x, total_mean, "b-", label="Posterior Mean") +plt.plot(x, mean_prediction, "b-", label="Posterior Mean") plt.fill_between( x.flatten(), - total_mean - 2 * std_flat, - total_mean + 2 * std_flat, + mean_prediction.flatten() - 2 * std_prediction.flatten(), + mean_prediction.flatten() + 2 * std_prediction.flatten(), color="b", 
alpha=0.2, label="95% CI (GP Uncertainty)", ) plt.legend() -plt.show() +# plt.show() diff --git a/gpjax/distributions.py b/gpjax/distributions.py index 0bb00967f..28fe41294 100644 --- a/gpjax/distributions.py +++ b/gpjax/distributions.py @@ -68,6 +68,9 @@ def sample(self, key, sample_shape=()): def affine_transformation(_x): return self.loc + covariance_root @ _x + if not sample_shape: + return affine_transformation(white_noise) + return vmap(affine_transformation)(white_noise) @property diff --git a/gpjax/numpyro_extras.py b/gpjax/numpyro_extras.py index d0845d497..0df6f6eda 100644 --- a/gpjax/numpyro_extras.py +++ b/gpjax/numpyro_extras.py @@ -6,29 +6,10 @@ import numpyro.distributions as dist from gpjax.parameters import ( - FillTriangularTransform, Parameter, ) -def _get_default_prior(tag, shape, ndim): - if tag in ("positive", "non_negative"): - return dist.LogNormal(0.0, 1.0).expand(shape).to_event(ndim) - if tag == "real": - return dist.Normal(0.0, 1.0).expand(shape).to_event(ndim) - if tag == "sigmoid": - return dist.Uniform(0.0, 1.0).expand(shape).to_event(ndim) - if tag == "lower_triangular": - N = shape[-1] - K = N * (N + 1) // 2 - batch_shape = shape[:-2] - base_shape = batch_shape + (K,) - base_dist = dist.Normal(0.0, 1.0).expand(base_shape).to_event(1) - td = dist.TransformedDistribution(base_dist, FillTriangularTransform()) - return td.to_event(len(batch_shape)) - return dist.Normal(0.0, 1.0).expand(shape).to_event(ndim) - - def register_parameters( model: nnx.Module, priors: tp.Dict[str, dist.Distribution] | None = None, @@ -71,7 +52,12 @@ def _param_callback(path, param): # Determine prior prior = priors.get(name) if prior is None: - prior = _get_default_prior(param.tag, param.value.shape, param.value.ndim) + # Check for attached prior + numpyro_props = getattr(param, "numpyro_properties", {}) + prior = numpyro_props.get("prior") + + if prior is None: + return param # Sample value = numpyro.sample(name, prior) diff --git a/gpjax/parameters.py b/gpjax/parameters.py index 0bdcb9957..ef815e48f 100644 --- a/gpjax/parameters.py +++ b/gpjax/parameters.py @@ -7,6 +7,7 @@ import jax.numpy as jnp import jax.tree_util as jtu from jax.typing import ArrayLike +import numpyro.distributions as dist import numpyro.distributions.transforms as npt T = tp.TypeVar("T", bound=tp.Union[ArrayLike, list[float]]) @@ -170,13 +171,22 @@ class Parameter(nnx.Variable[T]): """ - def __init__(self, value: T, tag: ParameterTag, **kwargs): + def __init__( + self, + value: T, + tag: ParameterTag, + prior: tp.Optional[dist.Distribution] = None, + **kwargs, + ): _check_is_arraylike(value) super().__init__(value=jnp.asarray(value), **kwargs) # nnx.Variable metadata must be set via set_metadata (direct setattr is disallowed). 
self.set_metadata(tag=tag) + self.numpyro_properties: tp.Dict[str, tp.Any] = {} + if prior is not None: + self.numpyro_properties["prior"] = prior @property def tag(self) -> ParameterTag: diff --git a/tests/test_heteroscedastic.py b/tests/test_heteroscedastic.py index 6f1bbfae1..3d269bdec 100644 --- a/tests/test_heteroscedastic.py +++ b/tests/test_heteroscedastic.py @@ -178,7 +178,7 @@ def test_softplus_transform_numerical_accuracy(mean: float, variance: float, see mc_inv_variance = jnp.mean(1.0 / transformed_samples) # Allow for some MC error and quadrature approximation error - rtol = 0.1 + rtol = 0.15 assert jnp.allclose(moments.variance, mc_variance, rtol=rtol) assert jnp.allclose(moments.log_variance, mc_log_variance, rtol=rtol) assert jnp.allclose(moments.inv_variance, mc_inv_variance, rtol=rtol) diff --git a/tests/test_numpyro_extras.py b/tests/test_numpyro_extras.py index d7a526a45..8afb56be6 100644 --- a/tests/test_numpyro_extras.py +++ b/tests/test_numpyro_extras.py @@ -1,32 +1,97 @@ from flax import nnx +from hypothesis import ( + given, + strategies as st, +) +from hypothesis.extra.numpy import arrays import jax.numpy as jnp +import numpy as np import numpyro.distributions as dist from numpyro.handlers import ( seed, trace, ) -from gpjax.numpyro_extras import register_parameters +from gpjax.numpyro_extras import ( + register_parameters, +) from gpjax.parameters import ( + LowerTriangular, + NonNegativeReal, PositiveReal, Real, + SigmoidBounded, ) +# --- Strategies --- + + +def valid_shapes(min_dims=0, max_dims=2): + return st.integers(min_dims, max_dims).flatmap( + lambda d: st.lists(st.integers(1, 3), min_size=d, max_size=d).map(tuple) + ) -class MockSubModule(nnx.Module): - def __init__(self): - self.c = Real(jnp.array(3.0)) +def real_arrays(shape=None, min_value=None, max_value=None): + return arrays( + dtype=np.float64, + shape=shape if shape is not None else valid_shapes(), + elements=st.floats( + min_value=min_value, + max_value=max_value, + allow_nan=False, + allow_infinity=False, + width=64, + ), + ).map(jnp.array) -class MockModel(nnx.Module): - def __init__(self): - self.a = PositiveReal(jnp.array(1.0)) - self.b = Real(jnp.array(2.0)) - self.submodule = MockSubModule() +def lower_triangular_matrices(n=2): + return arrays( + dtype=np.float64, + shape=(n, n), + elements=st.floats(min_value=-2.0, max_value=2.0, width=64), + ).map(lambda x: jnp.tril(jnp.array(x))) -def test_register_parameters_default_priors(): - model = MockModel() + +class FlexibleMockModel(nnx.Module): + def __init__( + self, + pos_val, + real_val, + non_neg_val, + sigmoid_val, + lower_val, + vec_val, + pos_prior=None, + real_prior=None, + non_neg_prior=None, + sigmoid_prior=None, + lower_prior=None, + vec_prior=None, + ): + self.pos = PositiveReal(pos_val, prior=pos_prior) + self.real = Real(real_val, prior=real_prior) + self.non_neg = NonNegativeReal(non_neg_val, prior=non_neg_prior) + self.sigmoid = SigmoidBounded(sigmoid_val, prior=sigmoid_prior) + self.lower = LowerTriangular(lower_val, prior=lower_prior) + self.vec = Real(vec_val, prior=vec_prior) + + +@given( + pos_val=real_arrays(shape=(1,), min_value=1e-3, max_value=10.0), + real_val=real_arrays(shape=(1,), min_value=-10.0, max_value=10.0), + non_neg_val=real_arrays(shape=(1,), min_value=0.0, max_value=10.0), + sigmoid_val=real_arrays(shape=(1,), min_value=1e-3, max_value=0.999), + lower_val=lower_triangular_matrices(n=2), + vec_val=real_arrays(shape=(2,), min_value=-10.0, max_value=10.0), +) +def test_no_priors_no_sampling( + pos_val, 
real_val, non_neg_val, sigmoid_val, lower_val, vec_val +): + model = FlexibleMockModel( + pos_val, real_val, non_neg_val, sigmoid_val, lower_val, vec_val + ) def model_fn(): return register_parameters(model) @@ -34,45 +99,121 @@ def model_fn(): with seed(rng_seed=0): tr = trace(model_fn).get_trace() - # Check sites exist - assert "a" in tr - assert "b" in tr - assert "submodule.c" in tr - - # Check distributions - # a: PositiveReal -> LogNormal - # LogNormal is a TransformedDistribution. - assert isinstance(tr["a"]["fn"], dist.LogNormal) - - # b: Real -> Normal - # If scalar, to_event(0) returns Normal. If vector, to_event(1) returns Independent. - if isinstance(tr["b"]["fn"], dist.Independent): - assert isinstance(tr["b"]["fn"].base_dist, dist.Normal) - else: - assert isinstance(tr["b"]["fn"], dist.Normal) - - # submodule.c: Real -> Normal - if isinstance(tr["submodule.c"]["fn"], dist.Independent): - assert isinstance(tr["submodule.c"]["fn"].base_dist, dist.Normal) - else: - assert isinstance(tr["submodule.c"]["fn"], dist.Normal) - - # Check values in updated model + # Should be empty because no priors were attached or passed + assert len(tr) == 0 + + +@given( + pos_val=real_arrays(shape=(1,), min_value=1e-3, max_value=10.0), + real_val=real_arrays(shape=(1,), min_value=-10.0, max_value=10.0), + non_neg_val=real_arrays(shape=(1,), min_value=0.0, max_value=10.0), + sigmoid_val=real_arrays(shape=(1,), min_value=1e-3, max_value=0.999), + lower_val=lower_triangular_matrices(n=2), + vec_val=real_arrays(shape=(2,), min_value=-10.0, max_value=10.0), +) +def test_explicit_priors_sampling( + pos_val, real_val, non_neg_val, sigmoid_val, lower_val, vec_val +): + model = FlexibleMockModel( + pos_val, real_val, non_neg_val, sigmoid_val, lower_val, vec_val + ) + + # Define priors compatible with shapes + priors = { + "pos": dist.LogNormal(0.0, 1.0).expand(pos_val.shape).to_event(pos_val.ndim), + "real": dist.Normal(0.0, 1.0).expand(real_val.shape).to_event(real_val.ndim), + "non_neg": dist.LogNormal(0.0, 1.0) + .expand(non_neg_val.shape) + .to_event(non_neg_val.ndim), + "sigmoid": dist.Uniform(0.0, 1.0) + .expand(sigmoid_val.shape) + .to_event(sigmoid_val.ndim), + # For LowerTriangular, user must provide a prior over the full matrix shape + # OR a transformed prior. Here we simulate providing a prior over the full shape + # just to ensure the site is registered. 
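+        # A transformed prior is a valid alternative, e.g. for n=2 (a sketch
+        # using the FillTriangularTransform defined in gpjax.parameters, which
+        # maps a length-3 vector to a 2x2 lower-triangular matrix):
+        #   dist.TransformedDistribution(
+        #       dist.Normal(0.0, 1.0).expand((3,)).to_event(1),
+        #       FillTriangularTransform(),
+        #   )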
+ "lower": dist.Normal(0.0, 1.0).expand(lower_val.shape).to_event(lower_val.ndim), + "vec": dist.Normal(0.0, 1.0).expand(vec_val.shape).to_event(vec_val.ndim), + } + + def model_fn(): + return register_parameters(model, priors=priors) + with seed(rng_seed=0): - updated_model = model_fn() + tr = trace(model_fn).get_trace() + + assert "pos" in tr + assert "real" in tr + assert "non_neg" in tr + assert "sigmoid" in tr + assert "lower" in tr + assert "vec" in tr - assert jnp.allclose(updated_model.a.value, tr["a"]["value"]) - assert jnp.allclose(updated_model.b.value, tr["b"]["value"]) - assert jnp.allclose(updated_model.submodule.c.value, tr["submodule.c"]["value"]) - # Verify original values were different (random sample != 1.0) - assert not jnp.allclose(updated_model.a.value, 1.0) +@given( + pos_val=real_arrays(shape=(1,), min_value=1e-3, max_value=10.0), + real_val=real_arrays(shape=(1,), min_value=-10.0, max_value=10.0), + non_neg_val=real_arrays(shape=(1,), min_value=0.0, max_value=10.0), + sigmoid_val=real_arrays(shape=(1,), min_value=1e-3, max_value=0.999), + lower_val=lower_triangular_matrices(n=2), + vec_val=real_arrays(shape=(2,), min_value=-10.0, max_value=10.0), +) +def test_attached_priors_sampling( + pos_val, real_val, non_neg_val, sigmoid_val, lower_val, vec_val +): + # Create priors + pos_prior = dist.LogNormal(0.0, 1.0).expand(pos_val.shape).to_event(pos_val.ndim) + real_prior = dist.Normal(0.0, 1.0).expand(real_val.shape).to_event(real_val.ndim) + # Attach only to a subset to verify mixed behavior + model = FlexibleMockModel( + pos_val, + real_val, + non_neg_val, + sigmoid_val, + lower_val, + vec_val, + pos_prior=pos_prior, + real_prior=real_prior, + ) + + def model_fn(): + return register_parameters(model) + with seed(rng_seed=0): + tr = trace(model_fn).get_trace() -def test_register_parameters_custom_priors(): - model = MockModel() + assert "pos" in tr + assert "real" in tr + assert "non_neg" not in tr + assert "vec" not in tr - priors = {"a": dist.Gamma(2.0, 2.0), "submodule.c": dist.Cauchy(0.0, 1.0)} + +@given( + pos_val=real_arrays(shape=(1,), min_value=1e-3, max_value=10.0), +) +def test_prior_precedence(pos_val): + # Attached prior + attached_prior = dist.Gamma(2.0, 1.0).expand(pos_val.shape).to_event(pos_val.ndim) + + # Explicit prior (different) + explicit_prior = dist.Exponential(1.0).expand(pos_val.shape).to_event(pos_val.ndim) + + # Model with attached prior + # We need dummy values for others + dummy_real = jnp.array([0.0]) + dummy_lower = jnp.eye(2) + dummy_vec = jnp.zeros(2) + + model = FlexibleMockModel( + pos_val, + dummy_real, + dummy_real, + dummy_real, + dummy_lower, + dummy_vec, + pos_prior=attached_prior, + ) + + priors = {"pos": explicit_prior} def model_fn(): return register_parameters(model, priors=priors) @@ -80,24 +221,36 @@ def model_fn(): with seed(rng_seed=0): tr = trace(model_fn).get_trace() - assert isinstance(tr["a"]["fn"], dist.Gamma) - # b should use default (Normal wrapped in Independent or Normal) - if isinstance(tr["b"]["fn"], dist.Independent): - assert isinstance(tr["b"]["fn"].base_dist, dist.Normal) - else: - assert isinstance(tr["b"]["fn"], dist.Normal) - assert isinstance(tr["submodule.c"]["fn"], dist.Cauchy) + # Check that the sampled site corresponds to the explicit prior + # We can check the distribution object + # Structure might be Independent(Expanded(Exponential)) or Independent(Exponential) + d = tr["pos"]["fn"] + while hasattr(d, "base_dist"): + d = d.base_dist + assert isinstance(d, dist.Exponential) + +def 
test_register_parameters_nested_prefix(): + class NestedModel(nnx.Module): + def __init__(self): + self.inner = FlexibleMockModel( + jnp.array([1.0]), + jnp.array([0.0]), + jnp.array([1.0]), + jnp.array([0.5]), + jnp.eye(2), + jnp.zeros(2), + ) -def test_register_parameters_prefix(): - model = MockModel() + model = NestedModel() + # Explicit prior for nested + priors = {"outer.inner.pos": dist.LogNormal(0.0, 1.0).expand((1,)).to_event(1)} def model_fn(): - return register_parameters(model, prefix="foo") + return register_parameters(model, prefix="outer", priors=priors) with seed(rng_seed=0): tr = trace(model_fn).get_trace() - assert "foo.a" in tr - assert "foo.b" in tr - assert "foo.submodule.c" in tr + assert "outer.inner.pos" in tr + assert "outer.inner.real" not in tr diff --git a/tests/test_parameters.py b/tests/test_parameters.py index 0cff50b35..ff24912a1 100644 --- a/tests/test_parameters.py +++ b/tests/test_parameters.py @@ -1,111 +1,218 @@ from flax import nnx -from jax import jit -from jax.experimental import checkify +from hypothesis import ( + given, + strategies as st, +) +from hypothesis.extra.numpy import arrays import jax.numpy as jnp +import numpy as np import pytest from gpjax.parameters import ( DEFAULT_BIJECTION, + FillTriangularTransform, LowerTriangular, NonNegativeReal, - Parameter, PositiveReal, Real, SigmoidBounded, - _check_in_bounds, - _check_is_lower_triangular, - _check_is_positive, - _check_is_square, - _safe_assert, transform, ) +# --- Strategies --- -@pytest.mark.parametrize( - "param, value", - [ - (NonNegativeReal, 0.0), - (NonNegativeReal, 1.0), - (PositiveReal, 1.0), - (Real, 2.0), - (SigmoidBounded, 0.5), - ], -) -def test_transform(param, value): - # Create mock parameters and bijectors - params = nnx.State( - { - "param1": param(value), - "param2": Parameter(2.0, tag="real"), - } + +def valid_shapes(min_dims=0, max_dims=2): + return st.integers(min_dims, max_dims).flatmap( + lambda d: st.lists(st.integers(1, 5), min_size=d, max_size=d).map(tuple) ) - # Test forward transformation - t_params = transform(params, DEFAULT_BIJECTION) - t_param1_expected = DEFAULT_BIJECTION[params["param1"].tag](value) - assert jnp.allclose(t_params["param1"].value, t_param1_expected) - assert jnp.allclose(t_params["param2"].value, 2.0) - - -@pytest.mark.parametrize( - "param, tag", - [ - (NonNegativeReal(0.0), "non_negative"), - (PositiveReal(1.0), "positive"), - (Real(2.0), "real"), - (SigmoidBounded(0.5), "sigmoid"), - (LowerTriangular(jnp.eye(2)), "lower_triangular"), - ], -) -def test_default_tags(param, tag): - assert param.tag == tag +def real_arrays(shape_strategy=valid_shapes(), min_value=None, max_value=None): + return arrays( + dtype=np.float64, + shape=shape_strategy, + elements=st.floats( + min_value=min_value, + max_value=max_value, + allow_nan=False, + allow_infinity=False, + width=64, + ), + ).map(jnp.array) + + +# --- Parameter Tests --- + + +@given(value=real_arrays()) +def test_real_parameter(value): + # Should accept any real value + p = Real(value) + assert jnp.array_equal(p.value, value) + assert p.tag == "real" -def test_check_is_positive(): - # Check singleton - _safe_assert(_check_is_positive, jnp.array(3.0)) - # Check array - _safe_assert(_check_is_positive, jnp.array([3.0, 4.0])) - # Check negative singleton +@given(value=real_arrays(min_value=1e-6, max_value=1e6)) +def test_positive_real_valid(value): + p = PositiveReal(value) + assert jnp.array_equal(p.value, value) + assert p.tag == "positive" + + +@given(value=real_arrays(max_value=-1e-6)) 
+def test_positive_real_invalid(value): with pytest.raises(ValueError): - _safe_assert(_check_is_positive, jnp.array(-3.0)) + PositiveReal(value) + + +@given(value=real_arrays(min_value=0.0, max_value=1e6)) +def test_non_negative_real_valid(value): + p = NonNegativeReal(value) + assert jnp.array_equal(p.value, value) + assert p.tag == "non_negative" - # Check negative array + +@given(value=real_arrays(max_value=-1e-6)) +def test_non_negative_real_invalid(value): with pytest.raises(ValueError): - _safe_assert(_check_is_positive, jnp.array([-3.0, 4.0])) + NonNegativeReal(value) - # Test that functions wrapping _check_is_positive are jittable - def _dummy_fn(value): - _safe_assert(_check_is_positive, value) - jitted_fn = jit(checkify.checkify(_dummy_fn)) - jitted_fn(jnp.array(3.0)) +@given(value=real_arrays(min_value=0.0, max_value=1.0)) +def test_sigmoid_bounded_valid(value): + p = SigmoidBounded(value) + assert jnp.array_equal(p.value, value) + assert p.tag == "sigmoid" -def test_check_is_square(): - # Check square matrix - _safe_assert(_check_is_square, jnp.full((2, 2), 1.0)) - # Check non-square matrix +@given(value=real_arrays(min_value=1.001, max_value=1e6)) +def test_sigmoid_bounded_invalid_high(value): with pytest.raises(ValueError): - _safe_assert(_check_is_square, jnp.full((2, 3), 1.0)) + SigmoidBounded(value) -def test_check_is_lower_triangular(): - # Check lower triangular matrix - _safe_assert(_check_is_lower_triangular, jnp.tril(jnp.eye(2))) - # Check non-lower triangular matrix +@given(value=real_arrays(max_value=-0.001)) +def test_sigmoid_bounded_invalid_low(value): with pytest.raises(ValueError): - _safe_assert(_check_is_lower_triangular, jnp.linspace(0.0, 1.0, 4)) + SigmoidBounded(value) + + +# Strategy for lower triangular matrices +def lower_triangular_matrices(n_min=1, n_max=5): + return st.integers(n_min, n_max).flatmap( + lambda n: arrays( + dtype=np.float64, + shape=(n, n), + elements=st.floats(min_value=-10, max_value=10, width=64), + ).map(lambda x: jnp.tril(jnp.array(x))) + ) -def test_check_in_bounds(): - # Check in bounds - _safe_assert( - _check_in_bounds, jnp.array(0.5), low=jnp.array(0.0), high=jnp.array(1.0) +@given(value=lower_triangular_matrices()) +def test_lower_triangular_valid(value): + p = LowerTriangular(value) + assert jnp.array_equal(p.value, value) + assert p.tag == "lower_triangular" + + +@given( + n=st.integers(2, 5), + data=st.data(), +) +def test_lower_triangular_invalid(n, data): + # Generate a square matrix + mat = data.draw( + arrays( + dtype=np.float64, + shape=(n, n), + elements=st.floats(min_value=-10, max_value=10, width=64), + ).map(jnp.array) + ) + # Ensure it's NOT lower triangular by setting an upper element + row, col = np.triu_indices(n, 1) + if len(row) > 0: + # Pick a random upper triangular index + idx = data.draw(st.integers(0, len(row) - 1)) + r, c = row[idx], col[idx] + # Set to non-zero + mat = mat.at[r, c].set(1.0) + + with pytest.raises(ValueError): + LowerTriangular(mat) + + +# --- Transform Tests --- + + +@given( + param_class=st.sampled_from([NonNegativeReal, PositiveReal, Real, SigmoidBounded]), + data=st.data(), +) +def test_transform_roundtrip(param_class, data): + # Generate valid value for the parameter type + if param_class == NonNegativeReal: + val = data.draw(real_arrays(min_value=0.0, max_value=10.0)) + elif param_class == PositiveReal: + val = data.draw(real_arrays(min_value=1e-3, max_value=10.0)) + elif param_class == Real: + val = data.draw(real_arrays(min_value=-10.0, max_value=10.0)) + elif param_class == 
SigmoidBounded: + val = data.draw(real_arrays(min_value=1e-3, max_value=1.0 - 1e-3)) + else: + return # Should not happen + + params = nnx.State({"p": param_class(val)}) + + # Forward + t_params = transform(params, DEFAULT_BIJECTION, inverse=False) + + # Inverse + inv_params = transform(t_params, DEFAULT_BIJECTION, inverse=True) + + # Check + assert jnp.allclose(inv_params["p"].value, val, atol=1e-5, rtol=1e-5) + + +# --- FillTriangularTransform Tests --- + + +@given(n=st.integers(1, 10)) +def test_fill_triangular_shapes(n): + k = n * (n + 1) // 2 + vec = jnp.zeros(k) + ft = FillTriangularTransform() + + mat = ft(vec) + assert mat.shape == (n, n) + assert jnp.allclose(mat, jnp.tril(mat)) + + +@given(n=st.integers(1, 5), data=st.data()) +def test_fill_triangular_roundtrip_hypothesis(n, data): + k = n * (n + 1) // 2 + vec = data.draw( + arrays( + dtype=np.float64, + shape=(k,), + elements=st.floats(min_value=-5.0, max_value=5.0, width=64), + ).map(jnp.array) ) - # Check out of bounds + + ft = FillTriangularTransform() + + # Forward + mat = ft(vec) + assert mat.shape == (n, n) + + # Inverse + vec_recon = ft.inv(mat) + + assert jnp.allclose(vec, vec_recon) + + +def test_fill_triangular_errors(): + ft = FillTriangularTransform() + # Invalid size with pytest.raises(ValueError): - _safe_assert( - _check_in_bounds, jnp.array(1.5), low=jnp.array(0.0), high=jnp.array(1.0) - ) + ft(jnp.zeros(4)) # n=2 -> k=3. n=3 -> k=6. 4 is invalid. From d9ebd756f5afe68402108f55412dff8bfc5c8fc8 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Mon, 24 Nov 2025 23:54:36 +0100 Subject: [PATCH 3/5] Correct preview workflow --- .github/workflows/test_docs.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test_docs.yml b/.github/workflows/test_docs.yml index 77dd9ff17..d04d0bf08 100644 --- a/.github/workflows/test_docs.yml +++ b/.github/workflows/test_docs.yml @@ -64,9 +64,11 @@ jobs: if: github.event_name == 'pull_request' runs-on: ubuntu-latest environment: - name: github-pages + name: docs-preview url: ${{ steps.deployment.outputs.page_url }} steps: - name: Deploy MkDocs preview id: deployment uses: actions/deploy-pages@v4 + with: + preview: true From af8e55d6494bcf766ba99f42b8ca20bd280a05af Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Tue, 25 Nov 2025 00:38:57 +0100 Subject: [PATCH 4/5] Add more complicated DGP --- examples/numpyro_integration.py | 122 ++++++++++++++++---------------- mkdocs.yml | 3 +- 2 files changed, 63 insertions(+), 62 deletions(-) diff --git a/examples/numpyro_integration.py b/examples/numpyro_integration.py index f643b33ef..ed5979526 100644 --- a/examples/numpyro_integration.py +++ b/examples/numpyro_integration.py @@ -21,11 +21,13 @@ # We will look at a scenario where we have a structured mean function (a linear model) and a GP capturing the residuals. We will infer the parameters of both the linear model and the GP jointly. # %% +import numpyro +numpyro.set_host_device_count(4) + from jax import config import jax.numpy as jnp import jax.random as jr import matplotlib.pyplot as plt -import numpyro import numpyro.distributions as dist from numpyro.infer import ( MCMC, @@ -38,31 +40,37 @@ config.update("jax_enable_x64", True) -key = jr.key(42) +key = jr.key(123) # %% [markdown] # ## Data Generation # -# We generate a synthetic dataset that consists of a linear trend, a periodic component, and some noise. 
+# We generate a synthetic dataset that consists of a linear trend together with a locally periodic residual signal whose amplitude varies over time, an additional high-frequency component, and a local bump. This richer structure highlights how a GP can capture deviations from the explicit linear model. # %% -N = 100 -x = jnp.sort(jr.uniform(key, shape=(N, 1), minval=0.0, maxval=10.0), axis=0) - -# True parameters -true_slope = 0.5 -true_intercept = 2.0 -true_period = 2.0 -true_lengthscale = 1.0 -true_noise = 0.1 +N = 200 +key_x, key_y = jr.split(key) +x = jnp.sort(jr.uniform(key_x, shape=(N, 1), minval=0.0, maxval=10.0), axis=0) + +# True parameters for the linear trend +true_slope = 0.45 +true_intercept = 1.5 + +# Structured residual signal captured by the GP +slow_period = 6.0 +fast_period = 0.8 +amplitude_envelope = 1.0 + 0.5 * jnp.sin(2 * jnp.pi * x / slow_period) +modulated_periodic = amplitude_envelope * jnp.sin(2 * jnp.pi * x / fast_period) +high_frequency_component = 0.3 * jnp.cos(2 * jnp.pi * x / 0.35) +localised_bump = 1.2 * jnp.exp(-0.5 * ((x - 7.0) / 0.45) ** 2) -# Signal linear_trend = true_slope * x + true_intercept -periodic_signal = jnp.sin(2 * jnp.pi * x / true_period) -y_clean = linear_trend + periodic_signal +residual_signal = modulated_periodic + high_frequency_component + localised_bump +y_clean = linear_trend + residual_signal -# Observations -y = y_clean + true_noise * jr.normal(key, shape=x.shape) +# Observations with homoscedastic noise +observation_noise = 0.3 +y = y_clean + observation_noise * jr.normal(key_y, shape=x.shape) plt.figure(figsize=(10, 5)) plt.scatter(x, y, label="Data", alpha=0.6) @@ -82,23 +90,30 @@ period_prior = dist.LogNormal(0.0, 0.5) noise_prior = dist.LogNormal(0.0, 1.0) -# Define Kernel with priors # We can explicitly attach priors to the parameters -kernel = gpx.kernels.RBF( - lengthscale=gpx.parameters.PositiveReal(1.0, prior=lengthscale_prior), - variance=gpx.parameters.PositiveReal(1.0, prior=variance_prior), -) * gpx.kernels.Periodic( - lengthscale=gpx.parameters.PositiveReal(1.0, prior=lengthscale_prior), - period=gpx.parameters.PositiveReal(1.0, prior=period_prior), +lengthscale = gpx.parameters.PositiveReal(1.0, prior=lengthscale_prior) +variance = gpx.parameters.PositiveReal(1.0, prior=variance_prior) +period = gpx.parameters.PositiveReal(1.0, prior=period_prior) +noise = gpx.parameters.NonNegativeReal(1.0, prior=noise_prior) + +# Define Kernel with priors +stationary_component = gpx.kernels.RBF( + lengthscale=lengthscale, + variance=variance, ) +periodic_component = gpx.kernels.Periodic( + lengthscale=lengthscale, + period=period, +) +kernel = stationary_component * periodic_component -meanf = gpx.mean_functions.Zero() +meanf = gpx.mean_functions.Constant() prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel) # We will use a ConjugatePosterior since we assume Gaussian noise likelihood = gpx.likelihoods.Gaussian( num_datapoints=N, - obs_stddev=gpx.parameters.NonNegativeReal(1.0, prior=noise_prior), + obs_stddev=gpx.parameters.NonNegativeReal(1.0, prior=dist.LogNormal(0.0, 1.0)), ) posterior = prior * likelihood @@ -118,7 +133,7 @@ # %% -def model(X, Y): +def model(X, Y, X_new=None): # 1. Sample linear model parameters slope = numpyro.sample("slope", dist.Normal(0.0, 2.0)) intercept = numpyro.sample("intercept", dist.Normal(0.0, 2.0)) @@ -148,6 +163,15 @@ def model(X, Y): # 4. 
Add to potential
     numpyro.factor("gp_log_lik", mll)
 
+    # Optional prediction branch for use with Predictive
+    if X_new is not None:
+        latent_dist = p_posterior.predict(X_new, train_data=D_resid)
+        f_new = numpyro.sample("f_new", latent_dist)
+        f_new = f_new.reshape((-1, 1))
+        total_prediction = slope * X_new + intercept + f_new
+        numpyro.deterministic("y_pred", total_prediction)
+        return total_prediction
+
 
 # %% [markdown]
 # ## Running MCMC
@@ -156,8 +180,8 @@ def model(X, Y):
 
 # %%
 nuts_kernel = NUTS(model)
-mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000, num_chains=1)
-mcmc.run(jr.key(0), x, y)
+mcmc = MCMC(nuts_kernel, num_warmup=1500, num_samples=2000, num_chains=4, chain_method="parallel")
+mcmc.run(jr.key(123), x, y)
 
 mcmc.print_summary()
 
@@ -167,42 +191,18 @@ def model(X, Y):
 # We extract the samples and plot the predictions.
 
 # %%
+# Draw posterior samples for downstream use
 samples = mcmc.get_samples()
 
-
-def predict_fn(X_new, Y_train):
-    # 1. Sample linear model parameters
-    slope = numpyro.sample("slope", dist.Normal(0.0, 2.0))
-    intercept = numpyro.sample("intercept", dist.Normal(0.0, 2.0))
-
-    # Calculate residuals
-    trend_train = slope * x + intercept
-    residuals = Y_train - trend_train
-
-    # 2. Register GP parameters
-    p_posterior = register_parameters(posterior)
-
-    # Create dataset for residuals
-    D_resid = gpx.Dataset(X=x, y=residuals)
-
-    # 3. Compute latent GP distribution
-    latent_dist = p_posterior.predict(X_new, train_data=D_resid)
-
-    # 4. Sample latent function values
-    f = numpyro.sample("f", latent_dist)
-    f = f.reshape((-1, 1))
-
-    # 5. Compute and return total prediction
-    total_prediction = slope * X_new + intercept + f
-    numpyro.deterministic("y_pred", total_prediction)
-    return total_prediction
-
-
-# Create predictive utility
-predictive = Predictive(predict_fn, posterior_samples=samples)
+# Create predictive utility that reuses the original model
+predictive = Predictive(
+    model,
+    posterior_samples=samples,
+    return_sites=["y_pred"],
+)
 
 # Generate predictions
-predictions = predictive(jr.key(1), X_new=x, Y_train=y)
+predictions = predictive(jr.key(1), x, y, X_new=x)
 y_pred = predictions["y_pred"]
 
 # Compute statistics
diff --git a/mkdocs.yml b/mkdocs.yml
index 0526e19bd..5cf03b86d 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -28,8 +28,9 @@ nav:
     - Stochastic sparse GPs: _examples/uncollapsed_vi.md
     - Multi-output GPs for Ocean Modelling: _examples/oceanmodelling.md
     - Heteroscedastic Inference: _examples/heteroscedastic_inference.md
-    - Numpyro Integration: _examples/numpyro_integration.md
     - Log-Gaussian Cox Process: _examples/lgcp_numpyro.md
+    - 🧪 Experimental:
+      - Numpyro Integration: _examples/numpyro_integration.md
     - 📖 Guides for customisation:
       - Kernels: _examples/constructing_new_kernels.md
       - Likelihoods: _examples/likelihoods_guide.md

From b6269a7643aafd36fad2e68c05d33d29c8f34dfd Mon Sep 17 00:00:00 2001
From: Thomas Pinder
Date: Tue, 25 Nov 2025 21:16:23 +0100
Subject: [PATCH 5/5] Add spatial GP

---
 examples/spatial_linear_gp.py | 283 ++++++++++++++++++++++++++++++++++
 mkdocs.yml                    |   1 +
 2 files changed, 284 insertions(+)
 create mode 100644 examples/spatial_linear_gp.py

diff --git a/examples/spatial_linear_gp.py b/examples/spatial_linear_gp.py
new file mode 100644
index 000000000..b07403516
--- /dev/null
+++ b/examples/spatial_linear_gp.py
@@ -0,0 +1,283 @@
+# ---
+# jupyter:
+#   jupytext:
+#     cell_metadata_filter: -all
+#     custom_cell_magics: kql
+#     text_representation:
+#       extension: .py
+#       format_name: percent
+#       format_version: '1.3'
+#       jupytext_version: 1.11.2
+#   kernelspec:
+#     display_name: Python 3
+#     language: python
+#     name: python3
+# ---
+
+# %% [markdown]
+# # Spatial Modelling: Linear Regression vs. Gaussian Processes
+#
+# In this notebook, we explore the benefits of combining structured mean functions with Gaussian
+# Processes (GPs) for modelling spatial data. We will compare two approaches:
+# 1. **A Baseline Linear Model**: A standard Bayesian linear regression that assumes the target
+#    variable is a linear combination of the inputs.
+# 2. **A Joint Linear + GP Model**: A semi-parametric model that captures the global linear trend
+#    while using a GP to model the non-linear spatial residuals.
+#
+# Crucially, this example demonstrates the seamless integration between **GPJax** and **NumPyro**.
+# We will show how `GPJax` defines the GP prior and likelihood, while `NumPyro` handles the
+# Hamiltonian Monte Carlo (HMC) inference for both the linear coefficients and the GP
+# hyperparameters simultaneously.
+
+# %% [markdown]
+# ## 1. Setup and Data Simulation
+#
+# First, we import the necessary libraries. We enable 64-bit precision in JAX to ensure numerical
+# stability during matrix decompositions.
+#
+# We simulate a 2D spatial dataset ($N=200$) on a domain $[0, 5] \times [0, 5]$. The true
+# generating process consists of:
+# * **A Linear Trend**: $y_{\text{lin}} = 2x_1 - 1x_2 + 1.5$
+# * **A Spatial Residual**: $y_{\text{res}} = \sin(x_1) \cos(x_2)$
+# * **Observation Noise**: $\epsilon \sim \mathcal{N}(0, 0.1^2)$
+#
+# This structure effectively masks the non-linear signal within a dominant linear trend, posing a
+# challenge for simple linear models.
+
+# %%
+import jax
+import jax.numpy as jnp
+import jax.random as jr
+import numpyro
+import numpyro.distributions as dist
+from numpyro.infer import MCMC, NUTS, Predictive
+import gpjax as gpx
+from gpjax.numpyro_extras import register_parameters
+import matplotlib.pyplot as plt
+
+# Enable x64 precision for better stability
+jax.config.update("jax_enable_x64", True)
+
+print("Spatial Linear GP Comparison Example")
+
+# --- Step 2: Data Simulation ---
+N = 200
+key = jr.key(42)
+key_x, key_noise = jr.split(key)
+
+# Simulate X in [0, 5] x [0, 5]
+X = jr.uniform(key_x, shape=(N, 2), minval=0.0, maxval=5.0)
+
+# True Linear Trend
+true_slope = jnp.array([2.0, -1.0])
+true_intercept = 1.5
+y_lin = X @ true_slope + true_intercept
+
+# Non-linear Spatial Residual
+y_res = jnp.sin(X[:, 0]) * jnp.cos(X[:, 1])
+
+# Total Signal + Noise
+y_clean = y_lin + y_res
+noise_std = 0.1
+y = y_clean + noise_std * jr.normal(key_noise, shape=y_clean.shape)
+
+print(f"Generated {N} data points.")
+
+# %% [markdown]
+# ## 2. Baseline Linear Model
+#
+# We begin by defining a standard Bayesian linear regression model in NumPyro. This model assumes
+# that the data can be fully explained by a hyperplane and Gaussian noise.
+#
+# $$\begin{aligned} \mathbf{w} &\sim \mathcal{N}(\mathbf{0}, 5\mathbf{I}) \\
+# b &\sim \mathcal{N}(0, 5) \\
+# \sigma &\sim \text{LogNormal}(0, 1) \\
+# \mathbf{y} &\sim \mathcal{N}(\mathbf{X}\mathbf{w} + b, \sigma^2 \mathbf{I}) \end{aligned} $$
+#
+# We use the No-U-Turn Sampler (NUTS) to estimate the posterior distributions of the slope
+# $\mathbf{w}$, intercept $b$, and noise $\sigma$.
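+
+# %% [markdown]
+# For reference, the target density follows directly from the model statement above: NUTS
+# explores the log joint
+#
+# $$ \log p(\mathbf{y}, \mathbf{w}, b, \sigma) = \log \mathcal{N}(\mathbf{y} \mid \mathbf{X}\mathbf{w} + b, \sigma^2 \mathbf{I}) + \log p(\mathbf{w}) + \log p(b) + \log p(\sigma), $$
+#
+# which is also the unnormalised log posterior over $(\mathbf{w}, b, \sigma)$.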
+ +# %% +def linear_model(X, Y=None): + # Priors + slope = numpyro.sample("slope", dist.Normal(0.0, 5.0).expand([2])) + intercept = numpyro.sample("intercept", dist.Normal(0.0, 5.0)) + obs_noise = numpyro.sample("obs_noise", dist.LogNormal(0.0, 1.0)) + + # Mean function + mu = X @ slope + intercept + numpyro.deterministic("mu", mu) + + # Likelihood + numpyro.sample("obs", dist.Normal(mu, obs_noise), obs=Y) + +# Run MCMC for Linear Model +print("\nRunning MCMC for Baseline Linear Model...") +nuts_kernel_lin = NUTS(linear_model) +mcmc_lin = MCMC(nuts_kernel_lin, num_warmup=500, num_samples=1000, num_chains=1) +mcmc_lin.run(key, X, y) +mcmc_lin.print_summary() + +# %% [markdown] +# ## 3. Joint Linear + GP Model +# +# Now we define the joint model. Here, the GP accounts for the residuals that the linear model +# cannot explain. +# +# $$ y(\mathbf{x}) = \underbrace{\mathbf{w}^T \mathbf{x} + b}_{\text{Linear Mean}} + +# \underbrace{f(\mathbf{x})}_{\text{GP Residual}} + \epsilon $$ +# +# ### GPJax and NumPyro Integration +# +# This section highlights the interoperability between GPJax and NumPyro. +# +# 1. **GP Definition**: We define the GP prior in `GPJax` using an RBF kernel and a zero mean +# function (since the linear trend is handled explicitly). We attach `dist.LogNormal` priors to +# the kernel's hyperparameters (lengthscale and variance) directly within the GPJax object. +# 2. **`register_parameters`**: Inside the NumPyro model, we call +# `gpx.numpyro_extras.register_parameters(gp_posterior)`. This function traverses the GPJax +# object, identifies parameters with attached priors, and registers them as NumPyro sample +# sites. It returns a new GPJax object where the parameters have been replaced by the values +# sampled by NumPyro. +# 3. **Conjugate Marginal Log-Likelihood**: We compute the exact marginal log-likelihood (MLL) of +# the residuals under the GP prior using `gpx.objectives.conjugate_mll`. This term is added to +# the potential function using `numpyro.factor`, guiding the sampler. + +# %% +# GP Definition +lengthscale = gpx.parameters.PositiveReal(1.0, prior=dist.LogNormal(0.0, 1.0)) +variance = gpx.parameters.PositiveReal(1.0, prior=dist.LogNormal(0.0, 1.0)) + +# active_dims=[0, 1] ensures the kernel operates on both spatial dimensions +kernel = gpx.kernels.RBF(active_dims=[0, 1], lengthscale=lengthscale, variance=variance) +meanf = gpx.mean_functions.Zero() +prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel) + +obs_stddev = gpx.parameters.NonNegativeReal(0.1, prior=dist.LogNormal(0.0, 1.0)) +likelihood = gpx.likelihoods.Gaussian(num_datapoints=N, obs_stddev=obs_stddev) +gp_posterior = prior * likelihood + +def joint_model(X, Y, gp_posterior, X_new=None): + # 1. Sample Linear Model Parameters + slope = numpyro.sample("slope", dist.Normal(0.0, 5.0).expand([2])) + intercept = numpyro.sample("intercept", dist.Normal(0.0, 5.0)) + + trend = X @ slope + intercept + + # 2. Register GP Parameters with NumPyro + # This draws samples for lengthscale, variance, and obs_noise from their priors + p_posterior = register_parameters(gp_posterior) + + if Y is not None: + # Calculate residuals for the GP to model + residuals = Y - trend + # Reshape residuals to (N, 1) for GPJax Dataset + residuals = residuals.reshape(-1, 1) + D_resid = gpx.Dataset(X=X, y=residuals) + + # 3. 
Compute GP Marginal Log-Likelihood + mll = gpx.objectives.conjugate_mll(p_posterior, D_resid) + numpyro.factor("gp_log_lik", mll) + + if X_new is not None: + # Prediction logic + if Y is not None: + residuals = Y - trend + residuals = residuals.reshape(-1, 1) + D_resid = gpx.Dataset(X=X, y=residuals) + + # Compute predictive distribution for the GP component + latent_dist = p_posterior.predict(X_new, train_data=D_resid) + f_new = numpyro.sample("f_new", latent_dist) + f_new = f_new.reshape((-1, 1)) + + # Combine Linear Trend + GP Residual + total_prediction = (X_new @ slope + intercept).reshape(-1, 1) + f_new + numpyro.deterministic("y_pred", total_prediction) + +# Run MCMC for Joint Model +print("\nRunning MCMC for Joint Linear + GP Model...") +# Use a closure to pass the static gp_posterior object to the model +def joint_model_wrapper(X, Y, X_new=None): + joint_model(X, Y, gp_posterior, X_new) + +nuts_kernel_joint = NUTS(joint_model_wrapper) +mcmc_joint = MCMC(nuts_kernel_joint, num_warmup=500, num_samples=1000, num_chains=1) +mcmc_joint.run(key, X, y) +mcmc_joint.print_summary() + +# %% [markdown] +# ## 4. Comparison and Visualization +# +# We evaluate both models by comparing their Root Mean Squared Error (RMSE) against the true +# noise-free signal. We also visualise the predictions over the 2D domain. +# +# We expect the **Joint Model** to significantly outperform the linear baseline because it can +# capture the spatial correlations ($\sin(x_1)\cos(x_2)$) that the linear model ignores. + +# %% +# --- Step 6: Comparison & Visualization --- + +# 1. Prediction on Training Data (for RMSE) +# Linear Model +samples_lin = mcmc_lin.get_samples() +predictive_lin = Predictive(linear_model, samples_lin, return_sites=["mu"]) +preds_lin = predictive_lin(jr.key(1), X=X)["mu"] +mean_pred_lin = jnp.mean(preds_lin, axis=0) + +# Joint Model +samples_joint = mcmc_joint.get_samples() +predictive_joint = Predictive(joint_model_wrapper, samples_joint, return_sites=["y_pred"]) +preds_joint = predictive_joint(jr.key(2), X=X, Y=y, X_new=X)["y_pred"] +mean_pred_joint = jnp.mean(preds_joint, axis=0) + +# Calculate RMSE +rmse_lin = jnp.sqrt(jnp.mean((mean_pred_lin.flatten() - y_clean.flatten())**2)) +rmse_joint = jnp.sqrt(jnp.mean((mean_pred_joint.flatten() - y_clean.flatten())**2)) + +print(f"\nRMSE Comparison (vs True Signal):") +print(f"Linear Model: {rmse_lin:.4f}") +print(f"Joint Model: {rmse_joint:.4f}") + +# 2. 
Visualization on a Grid
+n_grid = 30
+x1 = jnp.linspace(0, 5, n_grid)
+x2 = jnp.linspace(0, 5, n_grid)
+X1, X2 = jnp.meshgrid(x1, x2)
+X_grid = jnp.column_stack([X1.ravel(), X2.ravel()])
+
+# True Signal on Grid
+y_grid_true = (X_grid @ true_slope + true_intercept) + (jnp.sin(X_grid[:, 0]) * jnp.cos(X_grid[:, 1]))
+
+# Linear Prediction on Grid
+preds_lin_grid = predictive_lin(jr.key(3), X=X_grid)["mu"]
+mean_pred_lin_grid = jnp.mean(preds_lin_grid, axis=0)
+
+# Joint Prediction on Grid
+preds_joint_grid = predictive_joint(jr.key(4), X=X, Y=y, X_new=X_grid)["y_pred"]
+mean_pred_joint_grid = jnp.mean(preds_joint_grid, axis=0)
+
+# Plotting
+fig, axes = plt.subplots(1, 3, figsize=(18, 5), sharey=True)
+
+# Truth
+c0 = axes[0].tricontourf(X_grid[:,0], X_grid[:,1], y_grid_true, levels=20, cmap='viridis')
+axes[0].set_title("True Signal")
+plt.colorbar(c0, ax=axes[0])
+
+# Linear
+c1 = axes[1].tricontourf(X_grid[:,0], X_grid[:,1], mean_pred_lin_grid.flatten(), levels=20, cmap='viridis')
+axes[1].set_title(f"Linear Model (RMSE: {rmse_lin:.2f})")
+plt.colorbar(c1, ax=axes[1])
+
+# Joint
+c2 = axes[2].tricontourf(X_grid[:,0], X_grid[:,1], mean_pred_joint_grid.flatten(), levels=20, cmap='viridis')
+axes[2].set_title(f"Joint Model (RMSE: {rmse_joint:.2f})")
+plt.colorbar(c2, ax=axes[2])
+
+for ax in axes:
+    ax.set_xlabel("x1")
+    ax.set_ylabel("x2")
+    ax.scatter(X[:,0], X[:,1], c='k', s=10, alpha=0.3, label="Data")
+
+plt.tight_layout()
diff --git a/mkdocs.yml b/mkdocs.yml
index 5cf03b86d..6622abbff 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -31,6 +31,7 @@ nav:
     - Log-Gaussian Cox Process: _examples/lgcp_numpyro.md
     - 🧪 Experimental:
       - Numpyro Integration: _examples/numpyro_integration.md
+      - Spatial Linear GP: _examples/spatial_linear_gp.md
     - 📖 Guides for customisation:
      - Kernels: _examples/constructing_new_kernels.md
      - Likelihoods: _examples/likelihoods_guide.md
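
The integration pattern used throughout this series (attach NumPyro priors to GPJax parameters,
turn them into sample sites with `register_parameters`, and add the conjugate marginal
log-likelihood to the potential via `numpyro.factor`) is condensed into the minimal sketch
below. It is illustrative rather than part of any example above: it assumes the
`gpjax.numpyro_extras` module introduced in these patches is importable, and the toy data,
priors, and sampler settings are placeholders.

import jax
import jax.numpy as jnp
import jax.random as jr
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

import gpjax as gpx
from gpjax.numpyro_extras import register_parameters

jax.config.update("jax_enable_x64", True)

# Toy 1D dataset (placeholder values).
key_data, key_mcmc = jr.split(jr.key(0))
x = jnp.linspace(0.0, 5.0, 50).reshape(-1, 1)
y = jnp.sin(x) + 0.1 * jr.normal(key_data, shape=x.shape)

# GPJax model with NumPyro priors attached directly to its hyperparameters.
kernel = gpx.kernels.RBF(
    lengthscale=gpx.parameters.PositiveReal(1.0, prior=dist.LogNormal(0.0, 1.0)),
    variance=gpx.parameters.PositiveReal(1.0, prior=dist.LogNormal(0.0, 1.0)),
)
likelihood = gpx.likelihoods.Gaussian(
    num_datapoints=x.shape[0],
    obs_stddev=gpx.parameters.NonNegativeReal(0.1, prior=dist.LogNormal(0.0, 1.0)),
)
posterior = gpx.gps.Prior(mean_function=gpx.mean_functions.Zero(), kernel=kernel) * likelihood


def model(X, Y):
    # Parameters carrying priors become NumPyro sample sites; a new GPJax
    # object is returned with the sampled values substituted in.
    p_posterior = register_parameters(posterior)
    # Score the data under the sampled hyperparameters via the exact
    # (conjugate) marginal log-likelihood.
    mll = gpx.objectives.conjugate_mll(p_posterior, gpx.Dataset(X=X, y=Y))
    numpyro.factor("gp_log_lik", mll)


mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=500)
mcmc.run(key_mcmc, x, y)
mcmc.print_summary()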