From 436233f29dfa628aaa70c9b909f992e95648227e Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Sun, 23 Nov 2025 15:23:38 +0100 Subject: [PATCH 1/6] Add heteroscedastic implementation --- .github/workflows/pr_greeting.yml | 62 ---- .gitignore | 5 + README.md | 2 +- examples/heteroscedastic_inference.py | 394 +++++++++++++++++++++++++ examples/regression.py | 47 +-- gpjax/__init__.py | 2 +- gpjax/citation.py | 14 + gpjax/gps.py | 77 +++++ gpjax/likelihoods.py | 233 +++++++++++++++ gpjax/objectives.py | 55 +++- gpjax/parameters.py | 9 +- gpjax/variational_families.py | 129 +++++++++ mkdocs.yml | 1 + pyproject.toml | 2 +- tests/integration_tests.py | 21 +- tests/test_citations.py | 18 ++ tests/test_heteroscedastic.py | 400 ++++++++++++++++++++++++++ 17 files changed, 1370 insertions(+), 101 deletions(-) delete mode 100644 .github/workflows/pr_greeting.yml create mode 100644 examples/heteroscedastic_inference.py create mode 100644 tests/test_heteroscedastic.py diff --git a/.github/workflows/pr_greeting.yml b/.github/workflows/pr_greeting.yml deleted file mode 100644 index d20125ba0..000000000 --- a/.github/workflows/pr_greeting.yml +++ /dev/null @@ -1,62 +0,0 @@ ---- -name: PR Greetings - -on: [pull_request_target] - -permissions: - pull-requests: write - -jobs: - greeting: - runs-on: ubuntu-latest - - steps: - - uses: actions/first-interaction@v3.1.0 - with: - repo_token: ${{ secrets.GITHUB_TOKEN }} - issue_message: | - Thank you for opening your first issue into GPJax! - - If you have not heard from us in a while, please feel free to ping - `@gpjax/developers` or anyone who has commented on the PR. - Most of our reviewers are volunteers and sometimes things fall - through the cracks. - - - You can also join us [on - Slack](https://join.slack.com/t/gpjax/shared_invite/zt-1da57pmjn-rdBCVg9kApirEEn2E5Q2Zw) for real-time - discussion. - - - For details on testing, writing docs, and our review process, - please see [the developer - guide](https://docs.jaxgaussianprocesses.com/contributing/) - - - We strive to be a welcoming and open project. Please follow our - [Code of - Conduct](https://github.com/thomaspinder/GPJax/blob/main/.github/CODE_OF_CONDUCT.md). - - pr_message: | - Thank you for opening your first PR into GPJax! - - - If you have not heard from us in a while, please feel free to ping - `@gpjax/developers` or anyone who has commented on the PR. - Most of our reviewers are volunteers and sometimes things fall - through the cracks. - - - You can also join us [on - Slack](https://join.slack.com/t/gpjax/shared_invite/zt-1da57pmjn-rdBCVg9kApirEEn2E5Q2Zw) for real-time - discussion. - - - For details on testing, writing docs, and our review process, - please see [the developer - guide](https://docs.jaxgaussianprocesses.com/contributing/) - - - We strive to be a welcoming and open project. Please follow our - [Code of - Conduct](https://github.com/thomaspinder/GPJax/blob/main/.github/CODE_OF_CONDUCT.md). diff --git a/.gitignore b/.gitignore index e51d8509c..1dcb8bf78 100644 --- a/.gitignore +++ b/.gitignore @@ -153,3 +153,8 @@ node_modules/ docs/api docs/_examples +local_libs/ +local_papers/ +GEMINI.md +AGENTS.md +plans/ \ No newline at end of file diff --git a/README.md b/README.md index 2104be23a..e71a792bd 100644 --- a/README.md +++ b/README.md @@ -70,7 +70,7 @@ GPJax into the package it is today. 
> - [**Laplace Approximation**](https://docs.jaxgaussianprocesses.com/_examples/classification/#laplace-approximation) > - [**Inference on Non-Euclidean Spaces**](https://docs.jaxgaussianprocesses.com/_examples/constructing_new_kernels/#custom-kernel) > - [**Inference on Graphs**](https://docs.jaxgaussianprocesses.com/_examples/graph_kernels/) -> - [**Pathwise Sampling**](https://docs.jaxgaussianprocesses.com/_examples/spatial/) +> - [**Heteroscedastic Inference**](https://docs.jaxgaussianprocesses.com/_examples/heteroscedastic_inference/) > - [**Learning Gaussian Process Barycentres**](https://docs.jaxgaussianprocesses.com/_examples/barycentres/) > - [**Deep Kernel Regression**](https://docs.jaxgaussianprocesses.com/_examples/deep_kernels/) > - [**Poisson Regression**](https://docs.jaxgaussianprocesses.com/_examples/poisson/) diff --git a/examples/heteroscedastic_inference.py b/examples/heteroscedastic_inference.py new file mode 100644 index 000000000..0807c7706 --- /dev/null +++ b/examples/heteroscedastic_inference.py @@ -0,0 +1,394 @@ +# -*- coding: utf-8 -*- +# --- +# jupyter: +# jupytext: +# cell_metadata_filter: -all +# custom_cell_magics: kql +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.17.3 +# kernelspec: +# display_name: .venv +# language: python +# name: python3 +# --- + +# %% [markdown] +# # Heteroscedastic inference for regression and classification +# +# This notebook shows how to fit heteroscedastic Gaussian processes (GPs) for two +# everyday tasks: +# +# - Regression with input-dependent noise using the tight Lázaro-Gredilla & Titsias +# (2011) bound. +# - Classification with input-dependent label noise using the generic chained bound. +# +# +# ## Background +# A heteroscedastic GP couples two latent functions: +# - A **signal GP** $f(\cdot)$ for the mean response. +# - A **noise GP** $g(\cdot)$ that maps to a positive variance +# $\sigma^2(x) = \phi(g(x))$ via a positivity transform $\phi$ (typically +# ${\rm exp}$ or ${\rm softplus}$). Intuitively, we are introducing a pair of GPs; +# one to model the latent mean, and a second that models the log-noise variance. This +# is in direct contrast a +# [homoscedastic GP](https://docs.jaxgaussianprocesses.com/_examples/regression/) +# where we learn a constant value for the noise. +# +# In the Gaussian case, the observed targets follow +# $$y \mid f, g \sim \mathcal{N}(f, \sigma^2(x)).$$ +# Variational inference works with independent posteriors $q(f)q(g)$, combining the +# moments of each into an ELBO. For non-Gaussian likelihoods the same structure +# remains; only the expected log-likelihood changes. + +# %% +from jax import config +import jax.numpy as jnp +import jax.random as jr +import matplotlib.pyplot as plt +import matplotlib as mpl +import optax as ox + +from examples.utils import use_mpl_style +import gpjax as gpx +from gpjax.likelihoods import ( + HeteroscedasticGaussian, + LogNormalTransform, + SoftplusTransform, +) +from gpjax.objectives import heteroscedastic_elbo +from gpjax.variational_families import ( + HeteroscedasticVariationalFamily, + VariationalGaussianInit, +) + +# Enable Float64 for stable linear algebra. +config.update("jax_enable_x64", True) + + +use_mpl_style() +key = jr.key(0) +cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"] + + +# %% [markdown] +# ## Dataset simulation +# We synthesise a regression dataset whose mean structure and noise level vary with +# the input. 
Specifically, we sample inputs $x \sim \mathcal{U}(0, 1)$ and define the +# latent signal to be +# $$f(x) = (x - 0.5)^2 + 0.05,$$ +# a smooth bowl-shaped curve. The observation standard deviation is chosen to be +# proportional to the signal, +# $$\sigma(x) = 0.5\,f(x),$$ +# which yields the heteroscedastic generative model +# $$y \mid x \sim \mathcal{N}\!\big(f(x), \sigma^2(x)\big).$$ +# This construction makes the noise small near the minimum of the bowl and much +# larger in the tails. We also create a dense test grid that we shall use later for +# visualising posterior fits and predictive uncertainty. + +# %% +# Create data with input-dependent variance. +key, x_key, noise_key = jr.split(key, 3) +n = 200 +x = jr.uniform(x_key, (n, 1), minval=0.0, maxval=1.0) +signal = (x - 0.5) ** 2 + 0.05 +noise_scale = 0.5 * signal +noise = noise_scale * jr.normal(noise_key, shape=(n, 1)) +y = signal + noise +train = gpx.Dataset(X=x, y=y) + +xtest = jnp.linspace(-0.1, 1.1, 200)[:, None] +signal_test = (xtest - 0.5) ** 2 + 0.05 +noise_scale_test = 0.5 * signal_test +noise_test = noise_scale_test * jr.normal(noise_key, shape=(200, 1)) +ytest = signal_test + noise_test + +fig, ax = plt.subplots() +ax.plot(x, y, "o", label="Observations", alpha=0.7, color=cols[0]) +ax.plot(xtest, signal_test, label="Signal", alpha=0.7, color=cols[1]) +ax.plot(xtest, noise_scale_test, label="Noise scale", alpha=0.7, color=cols[2]) +ax.set_xlabel("$x$") +ax.set_ylabel("$y$") +ax.legend(loc="upper left") + +# %% [markdown] +# For a homoscedastic baseline, compare this figure with the +# [Gaussian process regression notebook](https://docs.jaxgaussianprocesses.com/_examples/regression/) +# (`examples/regression.py`), where a single latent GP is paired with constant +# observation noise. + +# %% [markdown] +# ## Prior specification +# We place independent Gaussian process priors on the signal and noise processes: +# $$f \sim \mathcal{GP}\big(0, k_f\big), \qquad g \sim \mathcal{GP}\big(0, k_g\big),$$ +# where $k_f$ and $k_g$ are stationary squared-exponential kernels with unit +# variance and lengthscale of one. The noise process $g$ is mapped to the variance +# via the logarithmic transform in `LogNormalTransform`, giving +# $\sigma^2(x) = \exp\big(g(x)\big)$. The joint prior over $(f, g)$ combines with +# the heteroscedastic Gaussian likelihood, +# $$p(\mathbf{y} \mid f, g) = \prod_{i=1}^n +# \mathcal{N}\!\big(y_i \mid f(x_i), \exp(g(x_i))\big),$$ +# to form the posterior target that we shall approximate variationally. The product +# syntax `signal_prior * likelihood` used below constructs this augmented GP model. + +# %% +# Signal and noise priors. +signal_prior = gpx.gps.Prior( + mean_function=gpx.mean_functions.Zero(), + kernel=gpx.kernels.RBF(), +) +noise_prior = gpx.gps.Prior( + mean_function=gpx.mean_functions.Zero(), + kernel=gpx.kernels.RBF(), +) +likelihood = HeteroscedasticGaussian( + num_datapoints=train.n, + noise_prior=noise_prior, + noise_transform=LogNormalTransform(), +) +posterior = signal_prior * likelihood + +# Variational family over both processes. +z = jnp.linspace(-3.2, 3.2, 25)[:, None] +q = HeteroscedasticVariationalFamily( + posterior=posterior, + inducing_inputs=z, + inducing_inputs_g=z, +) + +# %% [markdown] +# The variational family introduces inducing variables for both latent functions, +# located at the set $Z = \{z_m\}_{m=1}^M$. These inducing variables summarise the +# infinite-dimensional GP priors in terms of multivariate Gaussian parameters. 
+# Optimising the evidence lower bound (ELBO) corresponds to adjusting the means and +# covariances of the variational posteriors $q(f)$ and $q(g)$ so that they best +# explain the observed data whilst remaining close to the prior. For a deeper look at +# these constructions in the homoscedastic setting, refer to the +# [Sparse Gaussian Process Regression](https://docs.jaxgaussianprocesses.com/_examples/collapsed_vi/) +# (`examples/collapsed_vi.py`) and +# [Sparse Stochastic Variational Inference](https://docs.jaxgaussianprocesses.com/_examples/uncollapsed_vi/) +# (`examples/uncollapsed_vi.py`) notebooks. + +# %% [markdown] +# ### Optimisation +# With the model specified, we minimise the negative ELBO, +# $$\mathcal{L} = \mathbb{E}_{q(f)q(g)}\!\big[\log p(\mathbf{y}\mid f, g)\big] +# - \mathrm{KL}\!\left[q(f) \,\|\, p(f)\right] +# - \mathrm{KL}\!\left[q(g) \,\|\, p(g)\right],$$ +# using the Adam optimiser. GPJax automatically selects the tight bound of +# Lázaro-Gredilla & Titsias (2011) when the likelihood is Gaussian, yielding an +# analytically tractable expectation over the latent noise process. The resulting +# optimisation iteratively updates the inducing posteriors for both latent GPs. + +# Optimise the heteroscedastic ELBO (selects LGT bound). +objective = lambda model, data: -heteroscedastic_elbo(model, data) +optimiser = ox.adam(1e-2) +q_trained, history = gpx.fit( + model=q, + objective=objective, + train_data=train, + optim=optimiser, + num_iters=10000, + verbose=False, +) + +loss_trace = jnp.asarray(history) +print(f"Final regression ELBO: {-loss_trace[-1]:.3f}") + +# %% [markdown] +# ## Prediction +# After training we obtain posterior marginals for both latent functions. To make a +# prediction we evaluate two quantities: +# 1. The latent posterior over $f$ (mean and variance), which reflects uncertainty +# in the latent function **prior** to observing noise. +# 2. The marginal predictive over observations, which integrates out both $f$ and +# $g$ to provide predictive intervals for future noisy measurements. +# The helper method `likelihood.predict` performs the second integration for us. + +# %% +# Predict on a dense grid. +xtest = jnp.linspace(-0.1, 1.1, 200)[:, None] +mf, vf, mg, vg = q_trained.predict(xtest) + +signal_pred, noise_pred = q_trained.predict_latents(xtest) +predictive = likelihood.predict(signal_pred, noise_pred) + +fig, ax = plt.subplots(figsize=(7, 4)) +ax.plot(train.X, train.y, "o", label="Observations", alpha=0.5) +ax.plot(xtest, mf, color="C0", label="Posterior mean") +ax.fill_between( + xtest.squeeze(), + (mf.squeeze() - 2 * jnp.sqrt(vf.squeeze())).squeeze(), + (mf.squeeze() + 2 * jnp.sqrt(vf.squeeze())).squeeze(), + color="C0", + alpha=0.15, + label="±2 std (latent)", +) +ax.fill_between( + xtest.squeeze(), + predictive.mean - 2 * jnp.sqrt(jnp.diag(predictive.covariance_matrix)), + predictive.mean + 2 * jnp.sqrt(jnp.diag(predictive.covariance_matrix)), + color="C1", + alpha=0.15, + label="±2 std (observed)", +) +ax.set_xlabel("$x$") +ax.set_ylabel("$y$") +ax.legend(loc="upper left") +ax.set_title("Heteroscedastic regression") +plt.show() + +# %% [markdown] +# The latent intervals quantify epistemic uncertainty about $f$, whereas the broader +# observed band adds the aleatoric noise predicted by $g$. The widening of the orange +# band in the right half matches the ground-truth construction of the dataset. 
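
# %% [markdown]
# As a quick sanity check we can rebuild the observed band by hand. Under the
# log-normal transform the expected noise variance is
# $\mathbb{E}[\sigma^2(x)] = \exp\big(m_g(x) + \tfrac{1}{2}v_g(x)\big)$, so the
# marginal predictive variance should equal the latent variance of $f$ plus this
# term. The cell below is a small illustrative check of that decomposition using the
# quantities computed above; the variable names are introduced purely for this
# comparison and the cell is not required for inference.

# %%
# Decompose the predictive variance into epistemic and aleatoric contributions.
expected_noise_variance = jnp.exp(mg + 0.5 * vg)
manual_predictive_variance = vf + expected_noise_variance
predictive_variance = jnp.diag(predictive.covariance_matrix)[:, None]
print(
    "Max deviation between manual and likelihood-based predictive variance:",
    jnp.max(jnp.abs(manual_predictive_variance - predictive_variance)),
)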
+ +# %% [markdown] +# ## Sparse Heteroscedastic Regression +# +# We now demonstrate how the aforementioned heteroscedastic approach can be extended +# into sparse scenarios, thus offering more favourable scalability as the size of our +# dataset grows. To achieve this we defined inducing points for both the signal and +# noise processes. Decoupling these grids allows us to focus modelling +# capacity where each latent function varies the most. The synthetic dataset below +# contains a smooth sinusoidal signal but exhibits a sharply peaked noise shock, +# mimicking the situation where certain regions of the input space are far noisier +# than others. + +# %% +# Generate data +key, x_key, noise_key = jr.split(key, 3) +n = 300 +x = jr.uniform(x_key, (n, 1), minval=-2.0, maxval=2.0) +signal = jnp.sin(2.0 * x) +# Gaussian bump of noise +noise_std = 0.1 + 0.5 * jnp.exp(-0.5 * ((x - 0.5) / 0.4) ** 2) +y = signal + noise_std * jr.normal(noise_key, shape=(n, 1)) +data_adv = gpx.Dataset(X=x, y=y) + +# %% [markdown] +# ### Model components +# We again adopt RBF priors for both processes but now apply a `SoftplusTransform` +# to the noise GP. This alternative map enforces positivity whilst avoiding the +# heavier tails induced by the log-normal transform. The `HeteroscedasticGaussian` +# likelihood seamlessly accepts the new transform. + +# %% +# Define model components +mean_prior = gpx.gps.Prior( + mean_function=gpx.mean_functions.Zero(), + kernel=gpx.kernels.RBF(), +) +noise_prior_adv = gpx.gps.Prior( + mean_function=gpx.mean_functions.Zero(), + kernel=gpx.kernels.RBF(), +) +likelihood_adv = HeteroscedasticGaussian( + num_datapoints=data_adv.n, + noise_prior=noise_prior_adv, + noise_transform=SoftplusTransform(), +) +posterior_adv = mean_prior * likelihood_adv + +# %% +# Configure variational family +# The signal requires a richer inducing set to capture its oscillations, whereas the +# noise process can be summarised with fewer points because the burst is localised. +z_signal = jnp.linspace(-2.0, 2.0, 30)[:, None] +z_noise = jnp.linspace(-2.0, 2.0, 15)[:, None] + +# Use VariationalGaussianInit to pass specific configurations +q_init_f = VariationalGaussianInit(inducing_inputs=z_signal) +q_init_g = VariationalGaussianInit(inducing_inputs=z_noise) + +q_adv = HeteroscedasticVariationalFamily( + posterior=posterior_adv, + signal_init=q_init_f, + noise_init=q_init_g, +) + +# %% [markdown] +# The initialisation objects `VariationalGaussianInit` allow us to prescribe +# different inducing grids and initial covariance structures for $f$ and $g$. This +# flexibility is invaluable when working with large datasets where the latent +# functions have markedly different smoothness properties. + +# %% +# Optimize +objective_adv = lambda model, data: -heteroscedastic_elbo(model, data) +optimiser_adv = ox.adam(1e-2) +q_adv_trained, _ = gpx.fit( + model=q_adv, + objective=objective_adv, + train_data=data_adv, + optim=optimiser_adv, + num_iters=8000, + verbose=False, +) + +# %% +# Plotting +xtest = jnp.linspace(-2.2, 2.2, 200)[:, None] +pred = q_adv_trained.predict(xtest) + +# Unpack the named tuple +mf = pred.mean_f +vf = pred.variance_f +mg = pred.mean_g +vg = pred.variance_g + +# Calculate total predictive variance +# The likelihood expects the *latent* noise distribution to compute the predictive +# but here we can just use the transformed expected variance for plotting. +# For accurate predictive intervals, we should use likelihood.predict. 
+signal_dist, noise_dist = q_adv_trained.predict_latents(xtest) +predictive_dist = likelihood_adv.predict(signal_dist, noise_dist) +predictive_mean = predictive_dist.mean +predictive_std = jnp.sqrt(jnp.diag(predictive_dist.covariance_matrix)) + +fig, ax = plt.subplots(figsize=(7, 4)) +ax.plot(x, y, "o", color="black", alpha=0.3, label="Data") +ax.plot(xtest, mf, color="C0", label="Signal Mean") +ax.fill_between( + xtest.squeeze(), + mf.squeeze() - 2 * jnp.sqrt(vf.squeeze()), + mf.squeeze() + 2 * jnp.sqrt(vf.squeeze()), + color="C0", + alpha=0.2, + label="Signal Uncertainty", +) + +# Plot total uncertainty (signal + noise) +ax.plot(xtest, predictive_mean, "--", color="C1", alpha=0.5) +ax.fill_between( + xtest.squeeze(), + predictive_mean - 2 * predictive_std, + predictive_mean + 2 * predictive_std, + color="C1", + alpha=0.1, + label="Predictive Uncertainty (95%)", +) + +ax.set_title("Heteroscedastic Regression with Custom Inducing Points") +ax.legend(loc="upper left", fontsize="small") +plt.show() + +# %% [markdown] +# ## Takeaways +# - The heteroscedastic GP model couples two latent GPs, enabling separate control of +# epistemic and aleatoric uncertainties. +# - We support multiple positivity transforms for the noise process; the choice +# affects the implied variance tails and should reflect prior beliefs. +# - Inducing points for the signal and noise processes can be tuned independently to +# balance computational budget against the local complexity of each function. +# - The ELBO implementation automatically selects the tightest analytical bound +# available, streamlining heteroscedastic inference workflows. + +# %% [markdown] +# ## System configuration + +# %% +# %reload_ext watermark +# %watermark -n -u -v -iv -w -a 'Thomas Pinder' diff --git a/examples/regression.py b/examples/regression.py index 733e7dc60..c4c6b7fe5 100644 --- a/examples/regression.py +++ b/examples/regression.py @@ -29,7 +29,6 @@ import matplotlib.pyplot as plt from examples.utils import ( - clean_legend, use_mpl_style, ) @@ -129,26 +128,26 @@ # %% # %% [markdown] -prior_dist = prior.predict(xtest, return_covariance_type="dense") - -prior_mean = prior_dist.mean -prior_std = prior_dist.variance -samples = prior_dist.sample(key=key, sample_shape=(20,)) - - -fig, ax = plt.subplots() -ax.plot(xtest, samples.T, alpha=0.5, color=cols[0], label="Prior samples") -ax.plot(xtest, prior_mean, color=cols[1], label="Prior mean") -ax.fill_between( - xtest.flatten(), - prior_mean - prior_std, - prior_mean + prior_std, - alpha=0.3, - color=cols[1], - label="Prior variance", -) -ax.legend(loc="best") -ax = clean_legend(ax) +# prior_dist = prior.predict(xtest, return_covariance_type="dense") +# +# prior_mean = prior_dist.mean +# prior_std = prior_dist.variance +# samples = prior_dist.sample(key=key, sample_shape=(20,)) +# +# +# fig, ax = plt.subplots() +# ax.plot(xtest, samples.T, alpha=0.5, color=cols[0], label="Prior samples") +# ax.plot(xtest, prior_mean, color=cols[1], label="Prior mean") +# ax.fill_between( +# xtest.flatten(), +# prior_mean - prior_std, +# prior_mean + prior_std, +# alpha=0.3, +# color=cols[1], +# label="Prior variance", +# ) +# ax.legend(loc="best") +# ax = clean_legend(ax) # %% [markdown] # ## Constructing the posterior @@ -217,13 +216,15 @@ # this, we use our defined `posterior` and `likelihood` at our test inputs to obtain # the predictive distribution as a `Distrax` multivariate Gaussian upon which `mean` # and `stddev` can be used to extract the predictive mean and standard deviatation. 
-# +# # We are only concerned here about the variance between the test points and themselves, so # we can just copute the diagonal version of the covariance. We enforce this by using # `return_covariance_type = "diagonal"` in the `predict` call. # %% -latent_dist = opt_posterior.predict(xtest, train_data=D, return_covariance_type="diagonal") +latent_dist = opt_posterior.predict( + xtest, train_data=D, return_covariance_type="diagonal" +) predictive_dist = opt_posterior.likelihood(latent_dist) predictive_mean = predictive_dist.mean diff --git a/gpjax/__init__.py b/gpjax/__init__.py index d1023c26d..ce7302e39 100644 --- a/gpjax/__init__.py +++ b/gpjax/__init__.py @@ -40,7 +40,7 @@ __description__ = "Gaussian processes in JAX and Flax" __url__ = "https://github.com/thomaspinder/GPJax" __contributors__ = "https://github.com/thomaspinder/GPJax/graphs/contributors" -__version__ = "0.13.3" +__version__ = "0.13.4" __all__ = [ "gps", diff --git a/gpjax/citation.py b/gpjax/citation.py index 3dc62a971..c42e981c1 100644 --- a/gpjax/citation.py +++ b/gpjax/citation.py @@ -24,6 +24,8 @@ Matern52, ) +from gpjax.likelihoods import HeteroscedasticGaussian + CitationType = Union[None, str, Dict[str, str]] @@ -149,3 +151,15 @@ def _(tree) -> PaperCitation: booktitle="Advances in neural information processing systems", citation_type="article", ) + + +@cite.register(HeteroscedasticGaussian) +def _(tree) -> PaperCitation: + return PaperCitation( + citation_key="lazaro2011variational", + authors="Lázaro-Gredilla, Miguel and Titsias, Michalis", + title="Variational heteroscedastic Gaussian process regression", + year="2011", + booktitle="Proceedings of the 28th International Conference on Machine Learning (ICML)", + citation_type="inproceedings", + ) diff --git a/gpjax/gps.py b/gpjax/gps.py index 2a7e103dc..5690f8c03 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -32,8 +32,10 @@ from gpjax.kernels import RFF from gpjax.kernels.base import AbstractKernel from gpjax.likelihoods import ( + AbstractHeteroscedasticLikelihood, AbstractLikelihood, Gaussian, + HeteroscedasticGaussian, NonGaussian, ) from gpjax.linalg import ( @@ -62,6 +64,7 @@ L = tp.TypeVar("L", bound=AbstractLikelihood) NGL = tp.TypeVar("NGL", bound=NonGaussian) GL = tp.TypeVar("GL", bound=Gaussian) +HL = tp.TypeVar("HL", bound=AbstractHeteroscedasticLikelihood) class AbstractPrior(nnx.Module, tp.Generic[M, K]): @@ -476,6 +479,22 @@ def predict( raise NotImplementedError +class LatentPosterior(AbstractPosterior[P, L]): + r"""A posterior shell used to expose prior structure without inference.""" + + def predict( + self, + test_inputs: Num[Array, "N D"], + train_data: Dataset, + *, + return_covariance_type: Literal["dense", "diagonal"] = "dense", + ) -> GaussianDistribution: + raise NotImplementedError( + "LatentPosteriors are a lightweight wrapper for priors and do not " + "implement predictive distributions. Use a variational family for inference." + ) + + class ConjugatePosterior(AbstractPosterior[P, GL]): r"""A Conjuate Gaussian process posterior object. @@ -839,6 +858,40 @@ def _return_diagonal_covariance( return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), cov) +class HeteroscedasticPosterior(LatentPosterior[P, HL]): + r"""Posterior shell for heteroscedastic likelihoods. + + The posterior retains both the signal and noise priors; inference is delegated + to variational families and specialised objectives. 
+ """ + + def __init__( + self, + prior: AbstractPrior[M, K], + likelihood: HL, + jitter: float = 1e-6, + ): + if likelihood.noise_prior is None: + raise ValueError("Heteroscedastic likelihoods require a noise_prior.") + super().__init__(prior=prior, likelihood=likelihood, jitter=jitter) + self.noise_prior = likelihood.noise_prior + self.noise_posterior = LatentPosterior( + prior=self.noise_prior, likelihood=likelihood, jitter=jitter + ) + + +class ChainedPosterior(HeteroscedasticPosterior[P, HL]): + r"""Posterior routed for heteroscedastic likelihoods using chained bounds.""" + + def __init__( + self, + prior: AbstractPrior[M, K], + likelihood: HL, + jitter: float = 1e-6, + ): + super().__init__(prior=prior, likelihood=likelihood, jitter=jitter) + + ####################### # Utils ####################### @@ -854,6 +907,18 @@ def construct_posterior( # noqa: F811 ) -> NonConjugatePosterior[P, NGL]: ... +@tp.overload +def construct_posterior( # noqa: F811 + prior: P, likelihood: HeteroscedasticGaussian +) -> HeteroscedasticPosterior[P, HeteroscedasticGaussian]: ... + + +@tp.overload +def construct_posterior( # noqa: F811 + prior: P, likelihood: AbstractHeteroscedasticLikelihood +) -> ChainedPosterior[P, AbstractHeteroscedasticLikelihood]: ... + + def construct_posterior(prior, likelihood): # noqa: F811 r"""Utility function for constructing a posterior object from a prior and likelihood. The function will automatically select the correct posterior @@ -873,6 +938,15 @@ def construct_posterior(prior, likelihood): # noqa: F811 if isinstance(likelihood, Gaussian): return ConjugatePosterior(prior=prior, likelihood=likelihood) + if ( + isinstance(likelihood, HeteroscedasticGaussian) + and likelihood.supports_tight_bound() + ): + return HeteroscedasticPosterior(prior=prior, likelihood=likelihood) + + if isinstance(likelihood, AbstractHeteroscedasticLikelihood): + return ChainedPosterior(prior=prior, likelihood=likelihood) + return NonConjugatePosterior(prior=prior, likelihood=likelihood) @@ -911,7 +985,10 @@ def eval_fourier_features(test_inputs: Float[Array, "N D"]) -> Float[Array, "N L "AbstractPrior", "Prior", "AbstractPosterior", + "LatentPosterior", "ConjugatePosterior", "NonConjugatePosterior", + "HeteroscedasticPosterior", + "ChainedPosterior", "construct_posterior", ] diff --git a/gpjax/likelihoods.py b/gpjax/likelihoods.py index c6d5fb891..df712db47 100644 --- a/gpjax/likelihoods.py +++ b/gpjax/likelihoods.py @@ -10,15 +10,20 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +from __future__ import annotations import abc +from dataclasses import dataclass import beartype.typing as tp from flax import nnx +import jax from jax import vmap +import jax.nn as jnn import jax.numpy as jnp import jax.scipy as jsp from jaxtyping import Float +import numpy as np import numpyro.distributions as npd from gpjax.distributions import GaussianDistribution @@ -36,6 +41,20 @@ ) +@dataclass +class NoiseMoments: + log_variance: Array + inv_variance: Array + variance: Array + + +jax.tree_util.register_pytree_node( + NoiseMoments, + lambda x: ((x.log_variance, x.inv_variance, x.variance), None), + lambda _, x: NoiseMoments(*x), +) + + class AbstractLikelihood(nnx.Module): r"""Abstract base class for likelihoods. 
@@ -103,6 +122,9 @@ def expected_log_likelihood( y: Float[Array, "N D"], mean: Float[Array, "N D"], variance: Float[Array, "N D"], + mean_g: tp.Optional[Float[Array, "N D"]] = None, + variance_g: tp.Optional[Float[Array, "N D"]] = None, + **_: tp.Any, ) -> Float[Array, " N"]: r"""Compute the expected log likelihood. @@ -116,6 +138,12 @@ def expected_log_likelihood( y (Float[Array, 'N D']): The observed response variable. mean (Float[Array, 'N D']): The variational mean. variance (Float[Array, 'N D']): The variational variance. + mean_g (Float[Array, 'N D']): Optional moments of the latent noise + process for heteroscedastic likelihoods. + variance_g (Float[Array, 'N D']): Optional moments of the latent noise + process for heteroscedastic likelihoods. + **_: Unused extra arguments for compatibility with specialised + likelihoods. Returns: ScalarFloat: The expected log likelihood. @@ -126,6 +154,142 @@ def expected_log_likelihood( ) +class AbstractNoiseTransform(nnx.Module): + """Abstract base class for noise transformations.""" + + @abc.abstractmethod + def __call__(self, x: Float[Array, "..."]) -> Float[Array, "..."]: + """Transform the input noise signal.""" + raise NotImplementedError + + @abc.abstractmethod + def moments( + self, mean: Float[Array, "..."], variance: Float[Array, "..."] + ) -> NoiseMoments: + """Compute the moments of the transformed noise signal.""" + raise NotImplementedError + + +class LogNormalTransform(AbstractNoiseTransform): + """Log-normal noise transformation.""" + + def __call__(self, x: Float[Array, "..."]) -> Float[Array, "..."]: + return jnp.exp(x) + + def moments( + self, mean: Float[Array, "..."], variance: Float[Array, "..."] + ) -> NoiseMoments: + expected_variance = jnp.exp(mean + 0.5 * variance) + expected_log_variance = mean + expected_inv_variance = jnp.exp(-mean + 0.5 * variance) + return NoiseMoments( + log_variance=expected_log_variance, + inv_variance=expected_inv_variance, + variance=expected_variance, + ) + + +class SoftplusTransform(AbstractNoiseTransform): + """Softplus noise transformation.""" + + def __init__(self, num_points: int = 20): + self.num_points = num_points + + def __call__(self, x: Float[Array, "..."]) -> Float[Array, "..."]: + return jnn.softplus(x) + + def moments( + self, mean: Float[Array, "..."], variance: Float[Array, "..."] + ) -> NoiseMoments: + quad_x, quad_w = np.polynomial.hermite.hermgauss(self.num_points) + quad_w = jnp.asarray(quad_w / jnp.sqrt(jnp.pi)) + quad_x = jnp.asarray(quad_x) + + std = jnp.sqrt(variance) + samples = mean[..., None] + jnp.sqrt(2.0) * std[..., None] * quad_x + sigma2 = self(samples) + log_sigma2 = jnp.log(sigma2) + inv_sigma2 = 1.0 / sigma2 + + expected_variance = jnp.sum(sigma2 * quad_w, axis=-1) + expected_log_variance = jnp.sum(log_sigma2 * quad_w, axis=-1) + expected_inv_variance = jnp.sum(inv_sigma2 * quad_w, axis=-1) + + return NoiseMoments( + log_variance=expected_log_variance, + inv_variance=expected_inv_variance, + variance=expected_variance, + ) + + +class AbstractHeteroscedasticLikelihood(AbstractLikelihood): + r"""Base class for heteroscedastic likelihoods with latent noise processes.""" + + def __init__( + self, + num_datapoints: int, + noise_prior, + noise_transform: tp.Union[ + AbstractNoiseTransform, + tp.Callable[[Float[Array, "..."]], Float[Array, "..."]], + ] = SoftplusTransform(), + integrator: AbstractIntegrator = GHQuadratureIntegrator(), + ): + self.noise_prior = noise_prior + + if isinstance(noise_transform, AbstractNoiseTransform): + self.noise_transform = 
noise_transform + else: + transform_name = getattr(noise_transform, "__name__", "") + if noise_transform is jnp.exp or transform_name == "exp": + self.noise_transform = LogNormalTransform() + else: + # Default to SoftplusTransform for softplus or unknown callables (legacy behavior used quadrature) + # Note: If an unknown callable is passed, we technically use SoftplusTransform which applies softplus. + # Users should implement AbstractNoiseTransform for custom transforms. + self.noise_transform = SoftplusTransform() + + super().__init__(num_datapoints=num_datapoints, integrator=integrator) + + def __call__( + self, + dist: tp.Union[npd.MultivariateNormal, GaussianDistribution], + noise_dist: tp.Optional[ + tp.Union[npd.MultivariateNormal, GaussianDistribution] + ] = None, + ) -> npd.Distribution: + return self.predict(dist, noise_dist) + + def supports_tight_bound(self) -> bool: + """Return whether the tighter LGT bound is applicable.""" + return False + + def noise_statistics( + self, mean: Float[Array, "N D"], variance: Float[Array, "N D"] + ) -> NoiseMoments: + r"""Moment matching of the transformed noise process. + + Args: + mean: Mean of the latent noise GP. + variance: Variance of the latent noise GP. + + Returns: + NoiseMoments: Expected log variance, inverse variance, and variance. + """ + return self.noise_transform.moments(mean, variance) + + def expected_log_likelihood( + self, + y: Float[Array, "N D"], + mean: Float[Array, "N D"], + variance: Float[Array, "N D"], + mean_g: tp.Optional[Float[Array, "N D"]] = None, + variance_g: tp.Optional[Float[Array, "N D"]] = None, + **kwargs: tp.Any, + ) -> Float[Array, " N"]: + raise NotImplementedError + + class Gaussian(AbstractLikelihood): r"""Gaussian likelihood object.""" @@ -186,6 +350,69 @@ def predict( return npd.MultivariateNormal(dist.mean, noisy_cov) +class HeteroscedasticGaussian(AbstractHeteroscedasticLikelihood): + def predict( + self, + dist: tp.Union[npd.MultivariateNormal, GaussianDistribution], + noise_dist: tp.Optional[ + tp.Union[npd.MultivariateNormal, GaussianDistribution] + ] = None, + ) -> npd.MultivariateNormal: + if noise_dist is None: + raise ValueError( + "noise_dist must be provided for heteroscedastic prediction." + ) + + n_data = dist.event_shape[0] + noise_mean = noise_dist.mean + noise_variance = jnp.diag(noise_dist.covariance_matrix) + noise_stats = self.noise_statistics( + noise_mean[..., None], noise_variance[..., None] + ) + + cov = dist.covariance_matrix + noisy_cov = cov.at[jnp.diag_indices(n_data)].add(noise_stats.variance.squeeze()) + + return npd.MultivariateNormal(dist.mean, noisy_cov) + + def link_function(self, f: Float[Array, "..."]) -> npd.Normal: + sigma2 = self.noise_transform(jnp.zeros_like(f)) + return npd.Normal(loc=f, scale=jnp.sqrt(sigma2)) + + def expected_log_likelihood( + self, + y: Float[Array, "N D"], + mean: Float[Array, "N D"], + variance: Float[Array, "N D"], + mean_g: tp.Optional[Float[Array, "N D"]] = None, + variance_g: tp.Optional[Float[Array, "N D"]] = None, + noise_stats: tp.Optional[NoiseMoments] = None, + return_parts: bool = False, + **_: tp.Any, + ) -> tp.Union[Float[Array, " N"], tuple[Float[Array, " N"], NoiseMoments]]: + if mean_g is None or variance_g is None: + raise ValueError( + "mean_g and variance_g must be provided for heteroscedastic models." 
+ ) + + if noise_stats is None: + noise_stats = self.noise_statistics(mean_g, variance_g) + sq_error = jnp.square(y - mean) + log2pi = jnp.log(2.0 * jnp.pi) + expected = -0.5 * ( + log2pi + + noise_stats.log_variance + + (sq_error + variance) * noise_stats.inv_variance + ) + expected_sum = jnp.sum(expected, axis=1) + if return_parts: + return expected_sum, noise_stats + return expected_sum + + def supports_tight_bound(self) -> bool: + return True + + class Bernoulli(AbstractLikelihood): def link_function(self, f: Float[Array, "..."]) -> npd.BernoulliProbs: r"""The probit link function of the Bernoulli likelihood. @@ -268,7 +495,13 @@ def inv_probit(x: Float[Array, " *N"]) -> Float[Array, " *N"]: "AbstractLikelihood", "NonGaussian", "Gaussian", + "AbstractHeteroscedasticLikelihood", + "HeteroscedasticGaussian", "Bernoulli", "Poisson", "inv_probit", + "NoiseMoments", + "AbstractNoiseTransform", + "LogNormalTransform", + "SoftplusTransform", ] diff --git a/gpjax/objectives.py b/gpjax/objectives.py index 2872ff398..41d5d00f4 100644 --- a/gpjax/objectives.py +++ b/gpjax/objectives.py @@ -14,6 +14,9 @@ ConjugatePosterior, NonConjugatePosterior, ) +from gpjax.likelihoods import ( + AbstractHeteroscedasticLikelihood, +) from gpjax.linalg import ( Dense, lower_cholesky, @@ -25,9 +28,13 @@ Array, ScalarFloat, ) -from gpjax.variational_families import AbstractVariationalFamily +from gpjax.variational_families import ( + AbstractVariationalFamily, + HeteroscedasticVariationalFamily, +) VF = TypeVar("VF", bound=AbstractVariationalFamily) +HVF = TypeVar("HVF", bound=HeteroscedasticVariationalFamily) Objective = tpe.Callable[[nnx.Module, Dataset], ScalarFloat] @@ -414,3 +421,49 @@ def collapsed_elbo(variational_family: VF, data: Dataset) -> ScalarFloat: # log N(y; μx, Io² + KxzKzz⁻¹Kzx) - 1/2o² tr(Kxx - KxzKzz⁻¹Kzx) return (two_log_prob - two_trace).squeeze() / 2.0 + + +def heteroscedastic_elbo_lgt(variational_family: HVF, data: Dataset) -> ScalarFloat: + r"""Tight LGT bound for heteroscedastic Gaussian likelihoods.""" + likelihood = variational_family.posterior.likelihood + mean_f, var_f, mean_g, var_g = variational_family.predict(data.X) + + expected_ll, _ = likelihood.expected_log_likelihood( + data.y, + mean_f, + var_f, + mean_g=mean_g, + variance_g=var_g, + return_parts=True, + ) + + scale = likelihood.num_datapoints / data.n + return scale * jnp.sum(expected_ll) - variational_family.prior_kl() + + +def heteroscedastic_elbo_chained(variational_family: HVF, data: Dataset) -> ScalarFloat: + r"""Generic chained bound for heteroscedastic likelihoods.""" + likelihood: AbstractHeteroscedasticLikelihood = ( + variational_family.posterior.likelihood + ) + mean_f, var_f, mean_g, var_g = variational_family.predict(data.X) + noise_stats = likelihood.noise_statistics(mean_g, var_g) + + expected_ll = likelihood.expected_log_likelihood( + data.y, + mean_f, + var_f, + mean_g=mean_g, + variance_g=var_g, + noise_stats=noise_stats, + ) + + scale = likelihood.num_datapoints / data.n + return scale * jnp.sum(expected_ll) - variational_family.prior_kl() + + +def heteroscedastic_elbo(variational_family: HVF, data: Dataset) -> ScalarFloat: + likelihood = variational_family.posterior.likelihood + if likelihood.supports_tight_bound(): + return heteroscedastic_elbo_lgt(variational_family, data) + return heteroscedastic_elbo_chained(variational_family, data) diff --git a/gpjax/parameters.py b/gpjax/parameters.py index e1743ba18..71b587c5e 100644 --- a/gpjax/parameters.py +++ b/gpjax/parameters.py @@ -77,7 +77,14 @@ def 
__init__(self, value: T, tag: ParameterTag, **kwargs): _check_is_arraylike(value) super().__init__(value=jnp.asarray(value), **kwargs) - self.tag = tag + + # nnx.Variable metadata must be set via set_metadata (direct setattr is disallowed). + self.set_metadata(tag=tag) + + @property + def tag(self) -> ParameterTag: + """Return the parameter's constraint tag.""" + return self.metadata.get("tag", "real") class NonNegativeReal(Parameter[T]): diff --git a/gpjax/variational_families.py b/gpjax/variational_families.py index 8f65258c6..0079f3e89 100644 --- a/gpjax/variational_families.py +++ b/gpjax/variational_families.py @@ -14,6 +14,7 @@ # ============================================================================== import abc +from dataclasses import dataclass import beartype.typing as tp from flax import nnx @@ -29,9 +30,12 @@ from gpjax.gps import ( AbstractPosterior, AbstractPrior, + ChainedPosterior, + HeteroscedasticPosterior, ) from gpjax.kernels.base import AbstractKernel from gpjax.likelihoods import ( + AbstractHeteroscedasticLikelihood, Gaussian, NonGaussian, ) @@ -59,8 +63,10 @@ L = tp.TypeVar("L", Gaussian, NonGaussian) NGL = tp.TypeVar("NGL", bound=NonGaussian) GL = tp.TypeVar("GL", bound=Gaussian) +HL = tp.TypeVar("HL", bound=AbstractHeteroscedasticLikelihood) P = tp.TypeVar("P", bound=AbstractPrior) PP = tp.TypeVar("PP", bound=AbstractPosterior) +HP = tp.TypeVar("HP", HeteroscedasticPosterior, ChainedPosterior) class AbstractVariationalFamily(nnx.Module, tp.Generic[L]): @@ -870,6 +876,126 @@ def predict( ) +@dataclass +class VariationalGaussianInit: + """Initialization parameters for a variational Gaussian distribution.""" + + inducing_inputs: tp.Union[Int[Array, "N D"], Float[Array, "N D"]] + variational_mean: tp.Union[Float[Array, "N 1"], None] = None + variational_root_covariance: tp.Union[Float[Array, "N N"], None] = None + + +class HeteroscedasticPrediction(tp.NamedTuple): + """Mean and variance of the signal and noise latent processes.""" + + mean_f: Float[Array, "N 1"] + variance_f: Float[Array, "N 1"] + mean_g: Float[Array, "N 1"] + variance_g: Float[Array, "N 1"] + + +class HeteroscedasticVariationalFamily(AbstractVariationalFamily[HL]): + r"""Variational family for two independent latent processes f and g.""" + + def __init__( + self, + posterior: HP, + inducing_inputs: tp.Union[Int[Array, "N D"], Float[Array, "N D"]] = None, + inducing_inputs_g: tp.Union[ + Int[Array, "M D"], Float[Array, "M D"], None + ] = None, + variational_mean_f: tp.Union[Float[Array, "N 1"], None] = None, + variational_root_covariance_f: tp.Union[Float[Array, "N N"], None] = None, + variational_mean_g: tp.Union[Float[Array, "M 1"], None] = None, + variational_root_covariance_g: tp.Union[Float[Array, "M M"], None] = None, + jitter: ScalarFloat = 1e-6, + signal_init: tp.Optional[VariationalGaussianInit] = None, + noise_init: tp.Optional[VariationalGaussianInit] = None, + ): + self.jitter = jitter + + if signal_init is not None: + self.signal_variational = VariationalGaussian( + posterior=posterior, + inducing_inputs=signal_init.inducing_inputs, + variational_mean=signal_init.variational_mean, + variational_root_covariance=signal_init.variational_root_covariance, + jitter=jitter, + ) + elif inducing_inputs is not None: + self.signal_variational = VariationalGaussian( + posterior=posterior, + inducing_inputs=inducing_inputs, + variational_mean=variational_mean_f, + variational_root_covariance=variational_root_covariance_f, + jitter=jitter, + ) + else: + raise ValueError("Either signal_init or 
inducing_inputs must be provided.") + + if noise_init is not None: + self.noise_variational = VariationalGaussian( + posterior=posterior.noise_posterior, + inducing_inputs=noise_init.inducing_inputs, + variational_mean=noise_init.variational_mean, + variational_root_covariance=noise_init.variational_root_covariance, + jitter=jitter, + ) + else: + noise_inducing = ( + inducing_inputs if inducing_inputs_g is None else inducing_inputs_g + ) + if noise_inducing is None and signal_init is not None: + noise_inducing = signal_init.inducing_inputs + + if noise_inducing is None: + raise ValueError( + "Could not determine inducing inputs for noise process." + ) + + self.noise_variational = VariationalGaussian( + posterior=posterior.noise_posterior, + inducing_inputs=noise_inducing, + variational_mean=variational_mean_g, + variational_root_covariance=variational_root_covariance_g, + jitter=jitter, + ) + super().__init__(posterior) + + def prior_kl(self) -> ScalarFloat: + return self.signal_variational.prior_kl() + self.noise_variational.prior_kl() + + def predict( + self, test_inputs: tp.Union[Int[Array, "N D"], Float[Array, "N D"]] + ) -> HeteroscedasticPrediction: + dist_f = self.signal_variational.predict(test_inputs) + dist_g = self.noise_variational.predict(test_inputs) + + mean_f = dist_f.mean[:, None] if dist_f.mean.ndim == 1 else dist_f.mean + var_f = ( + dist_f.variance[:, None] if dist_f.variance.ndim == 1 else dist_f.variance + ) + mean_g = dist_g.mean[:, None] if dist_g.mean.ndim == 1 else dist_g.mean + var_g = ( + dist_g.variance[:, None] if dist_g.variance.ndim == 1 else dist_g.variance + ) + + return HeteroscedasticPrediction( + mean_f=mean_f, + variance_f=var_f, + mean_g=mean_g, + variance_g=var_g, + ) + + def predict_latents( + self, test_inputs: tp.Union[Int[Array, "N D"], Float[Array, "N D"]] + ) -> tuple[GaussianDistribution, GaussianDistribution]: + return ( + self.signal_variational.predict(test_inputs), + self.noise_variational.predict(test_inputs), + ) + + __all__ = [ "AbstractVariationalFamily", "AbstractVariationalGaussian", @@ -879,4 +1005,7 @@ def predict( "NaturalVariationalGaussian", "ExpectationVariationalGaussian", "CollapsedVariationalGaussian", + "HeteroscedasticVariationalFamily", + "VariationalGaussianInit", + "HeteroscedasticPrediction", ] diff --git a/mkdocs.yml b/mkdocs.yml index 177663be4..8e70d9d49 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -27,6 +27,7 @@ nav: - Sparse GPs: _examples/collapsed_vi.md - Stochastic sparse GPs: _examples/uncollapsed_vi.md - Multi-output GPs for Ocean Modelling: _examples/oceanmodelling.md + - Heteroscedastic Inference: _examples/heteroscedastic_inference.md - 📖 Guides for customisation: - Kernels: _examples/constructing_new_kernels.md - Likelihoods: _examples/likelihoods_guide.md diff --git a/pyproject.toml b/pyproject.toml index 26e4bb1b5..f332ee174 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -140,7 +140,7 @@ help = "Check code formatting and style" # Testing tasks [tool.poe.tasks.test] -cmd = "pytest . 
-v -n 8 --beartype-packages='gpjax'" +cmd = "pytest tests -v -n 8 --beartype-packages='gpjax'" help = "Run tests with pytest" [tool.poe.tasks.coverage] diff --git a/tests/integration_tests.py b/tests/integration_tests.py index d220633cb..8d07c104c 100644 --- a/tests/integration_tests.py +++ b/tests/integration_tests.py @@ -29,9 +29,6 @@ import jax.numpy as jnp # noqa: F401 import jupytext -# %% -import gpjax - # %% get_last = lambda x: x[-1] @@ -80,14 +77,7 @@ def test(self): contents = "\n".join([line for line in lines if not line.startswith("%")]) loc = {} - - # weird bug in interactive interpreter: lambda functions - # don't have access to the global scope of the executed file - # so we need to pass gpjax in the globals explicitly - # since it's used in a lambda function inside the examples - _globals = globals() - _globals["gpx"] = gpjax - exec(contents, _globals, loc) + exec(contents, loc) for k, v in self.comparisons.items(): truth, op = v self._compare( @@ -127,3 +117,12 @@ def test(self): }, ) stochastic.test() + +# %% +heteroscedastic = Result( + path="examples/heteroscedastic_inference.py", + comparisons={ + "history": (251.918, get_last), + }, +) +heteroscedastic.test() diff --git a/tests/test_citations.py b/tests/test_citations.py index 13396032d..016f2e539 100644 --- a/tests/test_citations.py +++ b/tests/test_citations.py @@ -5,6 +5,7 @@ import jax.numpy as jnp import pytest +import gpjax as gpx from gpjax.citation import ( AbstractCitation, NullCitation, @@ -24,6 +25,10 @@ ) +from gpjax.likelihoods import HeteroscedasticGaussian +from gpjax.mean_functions import Zero + + def _check_no_fallback(citation: AbstractCitation): # Check the fallback has not been used assert repr(citation) != repr( @@ -94,3 +99,16 @@ def test_rff(kernel): def test_missing_citation(kernel): citation = cite(kernel) assert isinstance(citation, NullCitation) + + +def test_heteroscedastic_citation(): + noise_prior = gpx.gps.Prior(mean_function=Zero(), kernel=RBF()) + likelihood = HeteroscedasticGaussian(num_datapoints=10, noise_prior=noise_prior) + citation = cite(likelihood) + + assert isinstance(citation, PaperCitation) + assert citation.citation_key == "lazaro2011variational" + assert citation.title == "Variational heteroscedastic Gaussian process regression" + assert citation.authors == "Lázaro-Gredilla, Miguel and Titsias, Michalis" + assert citation.year == "2011" + _check_no_fallback(citation) diff --git a/tests/test_heteroscedastic.py b/tests/test_heteroscedastic.py new file mode 100644 index 000000000..50be2512a --- /dev/null +++ b/tests/test_heteroscedastic.py @@ -0,0 +1,400 @@ +# Copyright 2024 The GPJax Contributors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +from flax import nnx +import jax +from jax import config +import jax.numpy as jnp +import jax.random as jr +import pytest + +import gpjax as gpx +from gpjax.dataset import Dataset +from gpjax.gps import ( + ChainedPosterior, + HeteroscedasticPosterior, + Prior, + construct_posterior, +) +from gpjax.kernels import RBF +from gpjax.likelihoods import ( + HeteroscedasticGaussian, + LogNormalTransform, + NoiseMoments, + SoftplusTransform, +) +from gpjax.mean_functions import Zero +from gpjax.objectives import heteroscedastic_elbo +from gpjax.parameters import Parameter +from gpjax.variational_families import ( + HeteroscedasticPrediction, + HeteroscedasticVariationalFamily, + VariationalGaussianInit, +) + +config.update("jax_enable_x64", True) + + +# --- Fixtures --- + + +@pytest.fixture +def prior() -> Prior: + return Prior(kernel=RBF(), mean_function=Zero()) + + +@pytest.fixture +def noise_prior() -> Prior: + return Prior(kernel=RBF(), mean_function=Zero()) + + +@pytest.fixture +def dataset() -> Dataset: + x = jnp.linspace(-2.0, 2.0, 10)[:, None] + y = jnp.sin(x) + return Dataset(X=x, y=y) + + +class SoftplusHeteroscedastic(HeteroscedasticGaussian): + def supports_tight_bound(self) -> bool: + return False + + +# --- Likelihood Tests --- + + +def test_construct_posterior_routing(prior, noise_prior): + likelihood = HeteroscedasticGaussian(num_datapoints=5, noise_prior=noise_prior) + posterior = construct_posterior(prior=prior, likelihood=likelihood) + assert isinstance(posterior, HeteroscedasticPosterior) + assert posterior.noise_prior is noise_prior + + chained_likelihood = SoftplusHeteroscedastic( + num_datapoints=5, noise_prior=noise_prior + ) + chained_posterior = construct_posterior(prior=prior, likelihood=chained_likelihood) + assert isinstance(chained_posterior, ChainedPosterior) + assert chained_posterior.noise_prior is noise_prior + + +def test_likelihood_callable_compatibility(noise_prior): + # Test that passing jnp.exp uses LogNormalTransform + lik_exp = HeteroscedasticGaussian( + num_datapoints=10, noise_prior=noise_prior, noise_transform=jnp.exp + ) + assert isinstance(lik_exp.noise_transform, LogNormalTransform) + + # Test that passing a custom callable uses SoftplusTransform (default fallback logic) + def custom_transform(x): + return jnp.square(x) + + lik_custom = HeteroscedasticGaussian( + num_datapoints=10, noise_prior=noise_prior, noise_transform=custom_transform + ) + assert isinstance(lik_custom.noise_transform, SoftplusTransform) + + +def test_heteroscedastic_gaussian_validation(noise_prior, dataset): + lik = HeteroscedasticGaussian(num_datapoints=10, noise_prior=noise_prior) + # Construct a valid GaussianDistribution to satisfy jaxtyping + scale = gpx.linalg.Dense(jnp.eye(10)) + dist = gpx.distributions.GaussianDistribution(loc=jnp.zeros(10), scale=scale) + + # Test predict raises ValueError if noise_dist is None + with pytest.raises( + ValueError, match="noise_dist must be provided for heteroscedastic prediction" + ): + lik.predict(dist, noise_dist=None) + + # Test expected_log_likelihood raises ValueError if moments are None + with pytest.raises(ValueError, match="mean_g and variance_g must be provided"): + lik.expected_log_likelihood( + dataset.y, dataset.X, dataset.X, mean_g=None, variance_g=None + ) + + +# --- Transform Tests --- + + +def test_log_normal_transform_moments(): + transform = LogNormalTransform() + mean = jnp.array([[0.5], [1.0]]) + variance = jnp.array([[0.1], [0.2]]) + + 
moments = transform.moments(mean, variance) + + expected_variance = jnp.exp(mean + 0.5 * variance) + expected_log_variance = mean + expected_inv_variance = jnp.exp(-mean + 0.5 * variance) + + assert jnp.allclose(moments.variance, expected_variance) + assert jnp.allclose(moments.log_variance, expected_log_variance) + assert jnp.allclose(moments.inv_variance, expected_inv_variance) + + +def test_softplus_transform_numerical_accuracy(): + # Monte Carlo verification of SoftplusTransform moments + transform = SoftplusTransform(num_points=100) + mean = jnp.array([[0.5]]) + variance = jnp.array([[0.2]]) + + moments = transform.moments(mean, variance) + + key = jr.PRNGKey(42) + samples = mean + jnp.sqrt(variance) * jr.normal(key, (50000, 1)) + transformed_samples = jax.nn.softplus(samples) + + # E[sigma^2] + mc_variance = jnp.mean(transformed_samples) + # E[log(sigma^2)] + mc_log_variance = jnp.mean(jnp.log(transformed_samples)) + # E[1/sigma^2] + mc_inv_variance = jnp.mean(1.0 / transformed_samples) + + # Allow for some MC error and quadrature approximation error + rtol = 2e-2 + assert jnp.allclose(moments.variance, mc_variance, rtol=rtol) + assert jnp.allclose(moments.log_variance, mc_log_variance, rtol=rtol) + assert jnp.allclose(moments.inv_variance, mc_inv_variance, rtol=rtol) + + +# --- Variational Family Tests --- + + +def test_heteroscedastic_variational_predict(prior, noise_prior, dataset): + posterior = prior * HeteroscedasticGaussian( + num_datapoints=dataset.n, noise_prior=noise_prior + ) + variational = HeteroscedasticVariationalFamily( + posterior=posterior, inducing_inputs=dataset.X, inducing_inputs_g=dataset.X[::2] + ) + + mf, vf, mg, vg = variational.predict(dataset.X) + assert mf.shape == (dataset.n, 1) + assert vf.shape == (dataset.n, 1) + assert mg.shape == (dataset.n, 1) + assert vg.shape == (dataset.n, 1) + + kl = variational.prior_kl() + assert jnp.isfinite(kl) + + latent_f, latent_g = variational.predict_latents(dataset.X) + assert latent_f.mean.shape[0] == dataset.n + assert latent_g.mean.shape[0] == dataset.n + + +def test_variational_family_init_structure(prior, noise_prior): + likelihood = HeteroscedasticGaussian(num_datapoints=10, noise_prior=noise_prior) + posterior = HeteroscedasticPosterior(prior=prior, likelihood=likelihood) + + n_inducing = 5 + inducing_inputs = jnp.linspace(0, 1, n_inducing).reshape(-1, 1) + + signal_init = VariationalGaussianInit(inducing_inputs=inducing_inputs) + noise_inducing = jnp.linspace(0, 1, n_inducing).reshape(-1, 1) + 0.1 + noise_init = VariationalGaussianInit(inducing_inputs=noise_inducing) + + q = HeteroscedasticVariationalFamily( + posterior=posterior, signal_init=signal_init, noise_init=noise_init + ) + + assert jnp.allclose(q.signal_variational.inducing_inputs.value, inducing_inputs) + assert jnp.allclose(q.noise_variational.inducing_inputs.value, noise_inducing) + + # Test initialization inference (noise inferred from signal) + q_inferred = HeteroscedasticVariationalFamily( + posterior=posterior, signal_init=signal_init + ) + assert jnp.allclose( + q_inferred.noise_variational.inducing_inputs.value, inducing_inputs + ) + + +def test_variational_family_init_errors(prior, noise_prior): + likelihood = HeteroscedasticGaussian(num_datapoints=10, noise_prior=noise_prior) + posterior = HeteroscedasticPosterior(prior=prior, likelihood=likelihood) + + # Case 1: No inputs provided + with pytest.raises( + ValueError, match="Either signal_init or inducing_inputs must be provided" + ): + HeteroscedasticVariationalFamily(posterior=posterior) + 
    # Case 2: noise inducing inputs cannot be inferred.
    # With the current fallback logic this branch is unreachable: if `signal_init`
    # is supplied the noise process reuses `signal_init.inducing_inputs`, and if
    # `inducing_inputs` is supplied it is reused directly. The guard only protects
    # against future changes to that fallback behaviour.


def test_variational_family_predict_return_type(prior, noise_prior):
    likelihood = HeteroscedasticGaussian(num_datapoints=10, noise_prior=noise_prior)
    posterior = HeteroscedasticPosterior(prior=prior, likelihood=likelihood)

    n_inducing = 5
    inducing_inputs = jnp.linspace(0, 1, n_inducing).reshape(-1, 1)
    q = HeteroscedasticVariationalFamily(
        posterior=posterior, inducing_inputs=inducing_inputs
    )

    test_inputs = jnp.linspace(0.5, 0.6, 3).reshape(-1, 1)
    prediction = q.predict(test_inputs)

    assert isinstance(prediction, HeteroscedasticPrediction)
    assert hasattr(prediction, "mean_f")
    assert hasattr(prediction, "variance_f")
    assert hasattr(prediction, "mean_g")
    assert hasattr(prediction, "variance_g")

    # Check backward compatibility (unpacking)
    mf, vf, mg, vg = prediction
    assert jnp.allclose(mf, prediction.mean_f)


# --- Objective Tests ---


def test_heteroscedastic_elbo_gradients(dataset, prior, noise_prior):
    def _build_variational(likelihood_cls: type[HeteroscedasticGaussian]):
        likelihood = likelihood_cls(num_datapoints=dataset.n, noise_prior=noise_prior)
        posterior = prior * likelihood
        return HeteroscedasticVariationalFamily(
            posterior=posterior, inducing_inputs=dataset.X
        )

    for likelihood_cls in (HeteroscedasticGaussian, SoftplusHeteroscedastic):
        variational = _build_variational(likelihood_cls)
        graphdef, params, *state = nnx.split(variational, Parameter, ...)

        def loss(p, graphdef=graphdef, state=state):
            model = nnx.merge(graphdef, p, *state)
            return -heteroscedastic_elbo(model, dataset)

        loss_val = loss(params)
        loss_jit = jax.jit(loss)(params)
        grads = jax.grad(loss)(params)

        assert jnp.isfinite(loss_val)
        assert jnp.isfinite(loss_jit)
        assert isinstance(grads, nnx.State)

        # Verify correct bound usage (smoke test via return value logic)
        # SoftplusHeteroscedastic forces chained bound.
        # HeteroscedasticGaussian uses LGT bound.
        # We verify they both run and return finite values.
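

# A possible extra check (an illustrative sketch rather than part of the original
# suite): for the Gaussian heteroscedastic likelihood the tight LGT bound and the
# generic chained bound share the same expected log-likelihood term, so the two
# objectives should coincide. This assumes the module-level helpers
# `heteroscedastic_elbo_lgt` and `heteroscedastic_elbo_chained` defined in
# gpjax.objectives remain importable.
def test_lgt_and_chained_bounds_agree(prior, noise_prior, dataset):
    from gpjax.objectives import (
        heteroscedastic_elbo_chained,
        heteroscedastic_elbo_lgt,
    )

    likelihood = HeteroscedasticGaussian(
        num_datapoints=dataset.n, noise_prior=noise_prior
    )
    posterior = prior * likelihood
    variational = HeteroscedasticVariationalFamily(
        posterior=posterior, inducing_inputs=dataset.X
    )

    lgt_value = heteroscedastic_elbo_lgt(variational, dataset)
    chained_value = heteroscedastic_elbo_chained(variational, dataset)

    assert jnp.allclose(lgt_value, chained_value)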
+ + +# --- JIT Compatibility Tests --- + + +def test_jit_prediction(prior, noise_prior, dataset): + likelihood = HeteroscedasticGaussian( + num_datapoints=dataset.n, noise_prior=noise_prior + ) + posterior = prior * likelihood + q = HeteroscedasticVariationalFamily(posterior=posterior, inducing_inputs=dataset.X) + + # JIT compile the predict method + predict_jit = jax.jit(q.predict) + mf, vf, mg, vg = predict_jit(dataset.X) + + assert mf.shape == (dataset.n, 1) + assert jnp.isfinite(mf).all() + + # JIT compile transforms (testing low-level JIT) + log_transform = LogNormalTransform() + moments_fn = jax.jit(log_transform.moments) + + mu = jnp.array([[0.0]]) + var = jnp.array([[1.0]]) + moments = moments_fn(mu, var) + assert jnp.isfinite(moments.variance).all() + + # Test SoftplusTransform JIT + softplus_transform = SoftplusTransform(num_points=20) + moments_fn_soft = jax.jit(softplus_transform.moments) + moments_soft = moments_fn_soft(mu, var) + assert jnp.isfinite(moments_soft.variance).all() + + +def test_jit_likelihood_prediction(dataset, prior, noise_prior): + # Separate test for likelihood prediction to keep things clean + likelihood = HeteroscedasticGaussian( + num_datapoints=dataset.n, noise_prior=noise_prior + ) + + # JIT compile likelihood prediction + # We pass arrays and reconstruct distributions inside to ensure Pytree safety + def lik_predict(f_mean, f_cov, g_mean, g_cov): + f = gpx.distributions.GaussianDistribution(f_mean, gpx.linalg.Dense(f_cov)) + g = gpx.distributions.GaussianDistribution(g_mean, gpx.linalg.Dense(g_cov)) + return likelihood.predict(f, g).mean + + lik_predict_jit = jax.jit(lik_predict) + + cov = jnp.eye(dataset.n) + mu = jnp.zeros(dataset.n) + res = lik_predict_jit(mu, cov, mu, cov) + assert res.shape == (dataset.n,) + + +def test_predictive_variance_tracks_noise(prior, noise_prior): + x = jnp.array([[-1.0], [1.0]]) + likelihood = HeteroscedasticGaussian(num_datapoints=2, noise_prior=noise_prior) + posterior = prior * likelihood + + variational = HeteroscedasticVariationalFamily( + posterior=posterior, + inducing_inputs=x, + inducing_inputs_g=x, + variational_mean_g=jnp.array([[-1.0], [1.5]]), + ) + + signal_dist, noise_dist = variational.predict_latents(x) + predictive = likelihood.predict(signal_dist, noise_dist) + diag_cov = jnp.diag(predictive.covariance_matrix) + + assert diag_cov[1] > diag_cov[0] + + +def test_noise_moments_pytree_registration(): + # Explicitly test Pytree registration for NoiseMoments + nm = NoiseMoments( + log_variance=jnp.array([1.0]), + inv_variance=jnp.array([0.5]), + variance=jnp.array([2.0]), + ) + leaves, treedef = jax.tree_util.tree_flatten(nm) + + # Check structure + assert len(leaves) == 3 + assert leaves[0] is nm.log_variance + assert leaves[1] is nm.inv_variance + assert leaves[2] is nm.variance + + # Check unflattening + nm_restored = jax.tree_util.tree_unflatten(treedef, leaves) + assert isinstance(nm_restored, NoiseMoments) + assert jnp.allclose(nm_restored.log_variance, nm.log_variance) + assert jnp.allclose(nm_restored.inv_variance, nm.inv_variance) + assert jnp.allclose(nm_restored.variance, nm.variance) From e88777dd7628b37234721de61819ebcb656b7958 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Sun, 23 Nov 2025 19:44:24 +0100 Subject: [PATCH 2/6] Fix integration test --- examples/heteroscedastic_inference.py | 6 ++---- tests/integration_tests.py | 9 ++++++++- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/examples/heteroscedastic_inference.py b/examples/heteroscedastic_inference.py index 
0807c7706..6e1160756 100644 --- a/examples/heteroscedastic_inference.py +++ b/examples/heteroscedastic_inference.py @@ -215,7 +215,7 @@ signal_pred, noise_pred = q_trained.predict_latents(xtest) predictive = likelihood.predict(signal_pred, noise_pred) -fig, ax = plt.subplots(figsize=(7, 4)) +fig, ax = plt.subplots() ax.plot(train.X, train.y, "o", label="Observations", alpha=0.5) ax.plot(xtest, mf, color="C0", label="Posterior mean") ax.fill_between( @@ -238,7 +238,6 @@ ax.set_ylabel("$y$") ax.legend(loc="upper left") ax.set_title("Heteroscedastic regression") -plt.show() # %% [markdown] # The latent intervals quantify epistemic uncertainty about $f$, whereas the broader @@ -348,7 +347,7 @@ predictive_mean = predictive_dist.mean predictive_std = jnp.sqrt(jnp.diag(predictive_dist.covariance_matrix)) -fig, ax = plt.subplots(figsize=(7, 4)) +fig, ax = plt.subplots() ax.plot(x, y, "o", color="black", alpha=0.3, label="Data") ax.plot(xtest, mf, color="C0", label="Signal Mean") ax.fill_between( @@ -373,7 +372,6 @@ ax.set_title("Heteroscedastic Regression with Custom Inducing Points") ax.legend(loc="upper left", fontsize="small") -plt.show() # %% [markdown] # ## Takeaways diff --git a/tests/integration_tests.py b/tests/integration_tests.py index 8d07c104c..d251dc8a0 100644 --- a/tests/integration_tests.py +++ b/tests/integration_tests.py @@ -28,6 +28,7 @@ ) import jax.numpy as jnp # noqa: F401 import jupytext +import gpjax # %% get_last = lambda x: x[-1] @@ -77,7 +78,13 @@ def test(self): contents = "\n".join([line for line in lines if not line.startswith("%")]) loc = {} - exec(contents, loc) + # weird bug in interactive interpreter: lambda functions + # don't have access to the global scope of the executed file + # so we need to pass gpjax in the globals explicitly + # since it's used in a lambda function inside the examples + _globals = globals() + _globals["gpx"] = gpjax + exec(contents, _globals, loc) for k, v in self.comparisons.items(): truth, op = v self._compare( From d858ad67e1e6f937fba38d84bd1c27ad82e77bf4 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Sun, 23 Nov 2025 19:59:55 +0100 Subject: [PATCH 3/6] Fix integration tests in heteroscedastic example --- examples/heteroscedastic_inference.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/heteroscedastic_inference.py b/examples/heteroscedastic_inference.py index 6e1160756..4c8b875e0 100644 --- a/examples/heteroscedastic_inference.py +++ b/examples/heteroscedastic_inference.py @@ -182,8 +182,9 @@ # analytically tractable expectation over the latent noise process. The resulting # optimisation iteratively updates the inducing posteriors for both latent GPs. +# %% # Optimise the heteroscedastic ELBO (selects LGT bound). 
-objective = lambda model, data: -heteroscedastic_elbo(model, data)
+objective = lambda model, data: -gpx.objectives.heteroscedastic_elbo(model, data)
optimiser = ox.adam(1e-2)
q_trained, history = gpx.fit(
    model=q,
@@ -316,7 +317,7 @@

# %%
# Optimize
-objective_adv = lambda model, data: -heteroscedastic_elbo(model, data)
+objective_adv = lambda model, data: -gpx.objectives.heteroscedastic_elbo(model, data)
optimiser_adv = ox.adam(1e-2)
q_adv_trained, _ = gpx.fit(
    model=q_adv,

From 8abe45f34bb749e785a5f51d7ca9034214b91b84 Mon Sep 17 00:00:00 2001
From: Thomas Pinder
Date: Sun, 23 Nov 2025 21:44:30 +0100
Subject: [PATCH 4/6] Correct typing

---
 examples/heteroscedastic_inference.py | 19 ++++++++-----------
 1 file changed, 8 insertions(+), 11 deletions(-)

diff --git a/examples/heteroscedastic_inference.py b/examples/heteroscedastic_inference.py
index 4c8b875e0..434ee55c2 100644
--- a/examples/heteroscedastic_inference.py
+++ b/examples/heteroscedastic_inference.py
@@ -18,12 +18,9 @@
# %% [markdown]
# # Heteroscedastic inference for regression and classification
#
-# This notebook shows how to fit heteroscedastic Gaussian processes (GPs) for two
-# everyday tasks:
-#
-# - Regression with input-dependent noise using the tight Lázaro-Gredilla & Titsias
-# (2011) bound.
-# - Classification with input-dependent label noise using the generic chained bound.
+# This notebook shows how to fit heteroscedastic Gaussian processes (GPs) that
+# allow one to perform regression where there is non-constant, or
+# input-dependent, noise.
#
#
# ## Background
@@ -37,7 +34,7 @@
# [homoscedastic GP](https://docs.jaxgaussianprocesses.com/_examples/regression/)
# where we learn a constant value for the noise.
#
-# In the Gaussian case, the observed targets follow
+# In the Gaussian case, the observed response follows
# $$y \mid f, g \sim \mathcal{N}(f, \sigma^2(x)).$$
# Variational inference works with independent posteriors $q(f)q(g)$, combining the
# moments of each into an ELBO. For non-Gaussian likelihoods the same structure
@@ -69,16 +66,16 @@

use_mpl_style()
-key = jr.key(0)
+key = jr.key(123)
cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"]


# %% [markdown]
# ## Dataset simulation
-# We synthesise a regression dataset whose mean structure and noise level vary with
-# the input. Specifically, we sample inputs $x \sim \mathcal{U}(0, 1)$ and define the
+# We simulate a dataset whose mean and noise levels vary with
+# the input. We sample inputs $x \sim \mathcal{U}(0, 1)$ and define the
# latent signal to be
-# $$f(x) = (x - 0.5)^2 + 0.05,$$
+# $$f(x) = (x - 0.5)^2 + 0.05;$$
# a smooth bowl-shaped curve. The observation standard deviation is chosen to be
# proportional to the signal,
# $$\sigma(x) = 0.5\,f(x),$$

From f86a3b943c201a61f709478dc31da5b76137c1e5 Mon Sep 17 00:00:00 2001
From: Thomas Pinder
Date: Sun, 23 Nov 2025 22:09:30 +0100
Subject: [PATCH 5/6] Improve quality

---
 examples/heteroscedastic_inference.py | 2 +-
 gpjax/likelihoods.py | 3 +-
 gpjax/objectives.py | 8 +--
 pyproject.toml | 1 +
 tests/conftest.py | 7 +++
 tests/test_heteroscedastic.py | 74 ++++++++++++++-------------
 uv.lock | 23 +++++++++
 7 files changed, 78 insertions(+), 40 deletions(-)

diff --git a/examples/heteroscedastic_inference.py b/examples/heteroscedastic_inference.py
index 434ee55c2..f3aef835c 100644
--- a/examples/heteroscedastic_inference.py
+++ b/examples/heteroscedastic_inference.py
@@ -180,7 +180,7 @@
# optimisation iteratively updates the inducing posteriors for both latent GPs.
# %% -# Optimise the heteroscedastic ELBO (selects LGT bound). +# Optimise the heteroscedastic ELBO (selects tighter bound). objective = lambda model, data: -gpx.objectives.heteroscedastic_elbo(model, data) optimiser = ox.adam(1e-2) q_trained, history = gpx.fit( diff --git a/gpjax/likelihoods.py b/gpjax/likelihoods.py index df712db47..5fffa9822 100644 --- a/gpjax/likelihoods.py +++ b/gpjax/likelihoods.py @@ -261,7 +261,8 @@ def __call__( return self.predict(dist, noise_dist) def supports_tight_bound(self) -> bool: - """Return whether the tighter LGT bound is applicable.""" + """Return whether the tighter bound from Lázaro-Gredilla & Titsias (2011) + is applicable.""" return False def noise_statistics( diff --git a/gpjax/objectives.py b/gpjax/objectives.py index 41d5d00f4..050bca1af 100644 --- a/gpjax/objectives.py +++ b/gpjax/objectives.py @@ -423,8 +423,10 @@ def collapsed_elbo(variational_family: VF, data: Dataset) -> ScalarFloat: return (two_log_prob - two_trace).squeeze() / 2.0 -def heteroscedastic_elbo_lgt(variational_family: HVF, data: Dataset) -> ScalarFloat: - r"""Tight LGT bound for heteroscedastic Gaussian likelihoods.""" +def heteroscedastic_elbo_conjugate( + variational_family: HVF, data: Dataset +) -> ScalarFloat: + r"""Tight bound from Lázaro-Gredilla & Titsias (2011) for heteroscedastic Gaussian likelihoods.""" likelihood = variational_family.posterior.likelihood mean_f, var_f, mean_g, var_g = variational_family.predict(data.X) @@ -465,5 +467,5 @@ def heteroscedastic_elbo_chained(variational_family: HVF, data: Dataset) -> Scal def heteroscedastic_elbo(variational_family: HVF, data: Dataset) -> ScalarFloat: likelihood = variational_family.posterior.likelihood if likelihood.supports_tight_bound(): - return heteroscedastic_elbo_lgt(variational_family, data) + return heteroscedastic_elbo_conjugate(variational_family, data) return heteroscedastic_elbo_chained(variational_family, data) diff --git a/pyproject.toml b/pyproject.toml index f332ee174..fad426a35 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -100,6 +100,7 @@ dev-dependencies = [ "autoflake", "poethepoet>=0.37.0", "twine>=6.2.0", + "hypothesis>=6.148.2", ] diff --git a/tests/conftest.py b/tests/conftest.py index 4902e7d4c..382bc44ff 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ from jax import config +from hypothesis import settings from jaxtyping import install_import_hook config.update("jax_enable_x64", True) @@ -6,3 +7,9 @@ # import gpjax within import hook to apply beartype everywhere, before running tests with install_import_hook("gpjax", "beartype.beartype"): import gpjax # noqa: F401 + +settings.register_profile( + "gpjax-default", + settings(deadline=None, max_examples=20), +) +settings.load_profile("gpjax-default") diff --git a/tests/test_heteroscedastic.py b/tests/test_heteroscedastic.py index 50be2512a..5c164b110 100644 --- a/tests/test_heteroscedastic.py +++ b/tests/test_heteroscedastic.py @@ -18,6 +18,7 @@ import jax.numpy as jnp import jax.random as jr import pytest +from hypothesis import given, settings, strategies as st import gpjax as gpx from gpjax.dataset import Dataset @@ -127,10 +128,11 @@ def test_heteroscedastic_gaussian_validation(noise_prior, dataset): # --- Transform Tests --- -def test_log_normal_transform_moments(): +@given(num_data=st.integers(min_value=1, max_value=100)) +def test_log_normal_transform_moments(num_data: int): transform = LogNormalTransform() - mean = jnp.array([[0.5], [1.0]]) - variance = jnp.array([[0.1], [0.2]]) + mean = jnp.array([[0.5] for _ 
in range(num_data)]) + variance = jnp.array([[0.1] for _ in range(num_data)]) moments = transform.moments(mean, variance) @@ -143,16 +145,26 @@ def test_log_normal_transform_moments(): assert jnp.allclose(moments.inv_variance, expected_inv_variance) -def test_softplus_transform_numerical_accuracy(): - # Monte Carlo verification of SoftplusTransform moments +@given( + mean=st.floats( + min_value=-2.0, + max_value=5.0, + allow_nan=False, + allow_infinity=False, + ), + variance=st.floats(min_value=1e-3, max_value=3.0, allow_nan=False), + seed=st.integers(min_value=0, max_value=2**32 - 1), +) +def test_softplus_transform_numerical_accuracy(mean: float, variance: float, seed: int): + # Monte Carlo verification of SoftplusTransform moments over a range of inputs transform = SoftplusTransform(num_points=100) - mean = jnp.array([[0.5]]) - variance = jnp.array([[0.2]]) + mean_array = jnp.array([[mean]]) + variance_array = jnp.array([[variance]]) - moments = transform.moments(mean, variance) + moments = transform.moments(mean_array, variance_array) - key = jr.PRNGKey(42) - samples = mean + jnp.sqrt(variance) * jr.normal(key, (50000, 1)) + key = jr.PRNGKey(seed) + samples = mean_array + jnp.sqrt(variance_array) * jr.normal(key, (100000, 1)) transformed_samples = jax.nn.softplus(samples) # E[sigma^2] @@ -163,7 +175,7 @@ def test_softplus_transform_numerical_accuracy(): mc_inv_variance = jnp.mean(1.0 / transformed_samples) # Allow for some MC error and quadrature approximation error - rtol = 2e-2 + rtol = 0.1 assert jnp.allclose(moments.variance, mc_variance, rtol=rtol) assert jnp.allclose(moments.log_variance, mc_log_variance, rtol=rtol) assert jnp.allclose(moments.inv_variance, mc_inv_variance, rtol=rtol) @@ -194,15 +206,27 @@ def test_heteroscedastic_variational_predict(prior, noise_prior, dataset): assert latent_g.mean.shape[0] == dataset.n -def test_variational_family_init_structure(prior, noise_prior): +@given( + n_inducing=st.integers(min_value=1, max_value=10), + offset=st.floats( + min_value=-0.5, + max_value=0.5, + allow_nan=False, + allow_infinity=False, + ), +) +def test_variational_family_init_structure(n_inducing: int, offset: float): + prior = Prior(kernel=RBF(), mean_function=Zero()) + noise_prior = Prior(kernel=RBF(), mean_function=Zero()) likelihood = HeteroscedasticGaussian(num_datapoints=10, noise_prior=noise_prior) posterior = HeteroscedasticPosterior(prior=prior, likelihood=likelihood) - n_inducing = 5 - inducing_inputs = jnp.linspace(0, 1, n_inducing).reshape(-1, 1) + inducing_inputs = jnp.linspace(0.0, 1.0, n_inducing, dtype=jnp.float64).reshape( + -1, 1 + ) signal_init = VariationalGaussianInit(inducing_inputs=inducing_inputs) - noise_inducing = jnp.linspace(0, 1, n_inducing).reshape(-1, 1) + 0.1 + noise_inducing = inducing_inputs + jnp.asarray(offset, dtype=jnp.float64) noise_init = VariationalGaussianInit(inducing_inputs=noise_inducing) q = HeteroscedasticVariationalFamily( @@ -231,21 +255,6 @@ def test_variational_family_init_errors(prior, noise_prior): ): HeteroscedasticVariationalFamily(posterior=posterior) - # Case 2: Cannot infer noise inducing inputs - # This is hard to trigger because if signal_init is provided, it falls back to signal_init.inducing_inputs. - # And if inducing_inputs is provided, it falls back to inducing_inputs. - # We need a case where we supply signal_init, but we want to force a failure in noise inference? 
- # Actually, the code says: - # if noise_inducing is None and signal_init is not None: noise_inducing = signal_init.inducing_inputs - # if noise_inducing is None: raise ValueError - # So if signal_init is passed, it's always safe. - # If inducing_inputs is passed, it's always safe. - # The only failure mode is if BOTH are None (caught by first check) - # OR if logic flows in a way that misses. - # Wait, if `inducing_inputs` is passed, `noise_inducing` becomes `inducing_inputs`. - # So effectively, if the first check passes, the second check should arguably not be reachable - # unless I pass `inducing_inputs_g=None` explicitly? No, default is None. - def test_variational_family_predict_return_type(prior, noise_prior): likelihood = HeteroscedasticGaussian(num_datapoints=10, noise_prior=noise_prior) @@ -298,11 +307,6 @@ def loss(p, graphdef=graphdef, state=state): assert jnp.isfinite(loss_jit) assert isinstance(grads, nnx.State) - # Verify correct bound usage (smoke test via return value logic) - # SoftplusHeteroscedastic forces chained bound. - # HeteroscedasticGaussian uses LGT bound. - # We verify they both run and return finite values. - # --- JIT Compatibility Tests --- diff --git a/uv.lock b/uv.lock index b81b9dd19..08b3dbc1a 100644 --- a/uv.lock +++ b/uv.lock @@ -845,6 +845,7 @@ dev = [ { name = "black" }, { name = "codespell" }, { name = "coverage" }, + { name = "hypothesis" }, { name = "interrogate" }, { name = "isort" }, { name = "jupytext" }, @@ -922,6 +923,7 @@ dev = [ { name = "black" }, { name = "codespell", specifier = ">=2.2.4" }, { name = "coverage", specifier = ">=7.2.2" }, + { name = "hypothesis", specifier = ">=6.148.2" }, { name = "interrogate", specifier = ">=1.5.0" }, { name = "isort" }, { name = "jupytext" }, @@ -960,6 +962,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1e/c7/316e7ca04d26695ef0635dc81683d628350810eb8e9b2299fc08ba49f366/humanize-4.13.0-py3-none-any.whl", hash = "sha256:b810820b31891813b1673e8fec7f1ed3312061eab2f26e3fa192c393d11ed25f", size = 128869, upload-time = "2025-08-25T09:39:18.54Z" }, ] +[[package]] +name = "hypothesis" +version = "6.148.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "sortedcontainers" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4a/99/a3c6eb3fdd6bfa01433d674b0f12cd9102aa99630689427422d920aea9c6/hypothesis-6.148.2.tar.gz", hash = "sha256:07e65d34d687ddff3e92a3ac6b43966c193356896813aec79f0a611c5018f4b1", size = 469984, upload-time = "2025-11-18T20:21:17.047Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b1/d2/c2673aca0127e204965e0e9b3b7a0e91e9b12993859ac8758abd22669b89/hypothesis-6.148.2-py3-none-any.whl", hash = "sha256:bf8ddc829009da73b321994b902b1964bcc3e5c3f0ed9a1c1e6a1631ab97c5fa", size = 536986, upload-time = "2025-11-18T20:21:15.212Z" }, +] + [[package]] name = "id" version = "1.5.0" @@ -3198,6 +3212,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050, upload-time = "2024-12-04T17:35:26.475Z" }, ] +[[package]] +name = "sortedcontainers" +version = "2.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e8/c4/ba2f8066cceb6f23394729afe52f3bf7adec04bf9ed2c820b39e19299111/sortedcontainers-2.4.0.tar.gz", hash = 
"sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88", size = 30594, upload-time = "2021-05-16T22:03:42.897Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/32/46/9cb0e58b2deb7f82b84065f37f3bffeb12413f947f9388e4cac22c4621ce/sortedcontainers-2.4.0-py2.py3-none-any.whl", hash = "sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0", size = 29575, upload-time = "2021-05-16T22:03:41.177Z" }, +] + [[package]] name = "soupsieve" version = "2.8" From 83a98a29d561a3b70a9ce2ea8a86a875dc891dde Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Sun, 23 Nov 2025 22:11:54 +0100 Subject: [PATCH 6/6] Run linters --- examples/heteroscedastic_inference.py | 3 +-- gpjax/citation.py | 1 - tests/conftest.py | 2 +- tests/integration_tests.py | 1 + tests/test_citations.py | 2 -- tests/test_heteroscedastic.py | 5 ++++- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/heteroscedastic_inference.py b/examples/heteroscedastic_inference.py index f3aef835c..7b54d3533 100644 --- a/examples/heteroscedastic_inference.py +++ b/examples/heteroscedastic_inference.py @@ -44,8 +44,8 @@ from jax import config import jax.numpy as jnp import jax.random as jr -import matplotlib.pyplot as plt import matplotlib as mpl +import matplotlib.pyplot as plt import optax as ox from examples.utils import use_mpl_style @@ -55,7 +55,6 @@ LogNormalTransform, SoftplusTransform, ) -from gpjax.objectives import heteroscedastic_elbo from gpjax.variational_families import ( HeteroscedasticVariationalFamily, VariationalGaussianInit, diff --git a/gpjax/citation.py b/gpjax/citation.py index c42e981c1..3b23eb42f 100644 --- a/gpjax/citation.py +++ b/gpjax/citation.py @@ -23,7 +23,6 @@ Matern32, Matern52, ) - from gpjax.likelihoods import HeteroscedasticGaussian CitationType = Union[None, str, Dict[str, str]] diff --git a/tests/conftest.py b/tests/conftest.py index 382bc44ff..ece7dc188 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,5 @@ -from jax import config from hypothesis import settings +from jax import config from jaxtyping import install_import_hook config.update("jax_enable_x64", True) diff --git a/tests/integration_tests.py b/tests/integration_tests.py index d251dc8a0..9c2eeeb14 100644 --- a/tests/integration_tests.py +++ b/tests/integration_tests.py @@ -28,6 +28,7 @@ ) import jax.numpy as jnp # noqa: F401 import jupytext + import gpjax # %% diff --git a/tests/test_citations.py b/tests/test_citations.py index 016f2e539..761ee75e1 100644 --- a/tests/test_citations.py +++ b/tests/test_citations.py @@ -23,8 +23,6 @@ Matern32, Matern52, ) - - from gpjax.likelihoods import HeteroscedasticGaussian from gpjax.mean_functions import Zero diff --git a/tests/test_heteroscedastic.py b/tests/test_heteroscedastic.py index 5c164b110..6f1bbfae1 100644 --- a/tests/test_heteroscedastic.py +++ b/tests/test_heteroscedastic.py @@ -13,12 +13,15 @@ # limitations under the License. # ============================================================================== from flax import nnx +from hypothesis import ( + given, + strategies as st, +) import jax from jax import config import jax.numpy as jnp import jax.random as jr import pytest -from hypothesis import given, settings, strategies as st import gpjax as gpx from gpjax.dataset import Dataset