From e75b14b36cb6792a9be97fdac4f37df494e86c93 Mon Sep 17 00:00:00 2001 From: Daniel Marthaler Date: Fri, 31 Oct 2025 20:36:51 -0600 Subject: [PATCH 01/19] updated linting --- gpjax/gps.py | 646 +++++++++++++++++++---------- gpjax/kernels/base.py | 2 +- gpjax/kernels/computations/base.py | 36 +- tests/test_gps.py | 167 +++++++- 4 files changed, 586 insertions(+), 265 deletions(-) diff --git a/gpjax/gps.py b/gpjax/gps.py index fb7a90d22..cddce3ee5 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,16 @@ # limitations under the License. # ============================================================================== # from __future__ import annotations + from abc import abstractmethod +from typing import ( + Literal, + Tuple, +) import beartype.typing as tp from flax import nnx +import jax import jax.numpy as jnp import jax.random as jr from jaxtyping import ( @@ -29,16 +35,20 @@ from gpjax.kernels import RFF from gpjax.kernels.base import AbstractKernel from gpjax.likelihoods import ( - AbstractLikelihood, - Gaussian, - NonGaussian, + AbstractLikelihood, + Gaussian, + NonGaussian, ) from gpjax.linalg import ( Dense, + Diagonal, psd, solve, ) -from gpjax.linalg.operations import lower_cholesky +from gpjax.linalg.operations import ( + LinearOperator, + lower_cholesky, +) from gpjax.linalg.utils import add_jitter from gpjax.mean_functions import AbstractMeanFunction from gpjax.parameters import ( @@ -70,14 +80,19 @@ def __init__( r"""Construct a Gaussian process prior. Args: - kernel: kernel object inheriting from AbstractKernel. - mean_function: mean function object inheriting from AbstractMeanFunction. + kernel: kernel object inheriting from AbstractKernel. + mean_function: mean function object inheriting from AbstractMeanFunction. """ self.kernel = kernel self.mean_function = mean_function self.jitter = jitter - def __call__(self, test_inputs: Num[Array, "N D"]) -> GaussianDistribution: + def __call__( + self, + test_inputs: Num[Array, "N D"], + *, + return_cov_type: Literal["dense", "diagonal"] = "dense", + ) -> GaussianDistribution: r"""Evaluate the Gaussian process at the given points. The output of this function is a @@ -90,16 +105,28 @@ def __call__(self, test_inputs: Num[Array, "N D"]) -> GaussianDistribution: `__call__` method and should instead define a `predict` method. Args: - test_inputs: Input locations where the GP should be evaluated. + test_inputs: Input locations where the GP should be evaluated. + return_cov_type: Literal denoting whether to return the full covariance + of the joint predictive distribution at the test_inputs (dense) + or just the the standard-deviation of the predictive distribution at + the test_inputs. Returns: - GaussianDistribution: A multivariate normal random variable representation - of the Gaussian process. + GaussianDistribution: A multivariate normal random variable representation + of the Gaussian process. 
""" - return self.predict(test_inputs) + return self.predict( + test_inputs, + return_cov_type=return_cov_type, + ) @abstractmethod - def predict(self, test_inputs: Num[Array, "N D"]) -> GaussianDistribution: + def predict( + self, + test_inputs: Num[Array, "N D"], + *, + return_cov_type: Literal["dense", "diagonal"] = "dense", + ) -> GaussianDistribution: r"""Evaluate the predictive distribution. Compute the latent function's multivariate normal distribution for a @@ -107,11 +134,15 @@ def predict(self, test_inputs: Num[Array, "N D"]) -> GaussianDistribution: this method must be implemented. Args: - test_inputs: Input locations where the GP should be evaluated. + test_inputs: Input locations where the GP should be evaluated. + return_cov_type: Literal denoting whether to return the full covariance + of the joint predictive distribution at the test_inputs (dense) + or just the the standard-deviation of the predictive distribution at + the test_inputs. Returns: - GaussianDistribution: A multivariate normal random variable representation - of the Gaussian process. + GaussianDistribution: A multivariate normal random variable representation + of the Gaussian process. """ raise NotImplementedError @@ -135,10 +166,10 @@ class Prior(AbstractPrior[M, K]): Example: ```python - >>> import gpjax as gpx - >>> kernel = gpx.kernels.RBF() - >>> meanf = gpx.mean_functions.Zero() - >>> prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel) + >>> import gpjax as gpx + >>> kernel = gpx.kernels.RBF() + >>> meanf = gpx.mean_functions.Zero() + >>> prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel) ``` """ @@ -148,16 +179,16 @@ class Prior(AbstractPrior[M, K]): def __mul__(self, other: GL) -> "ConjugatePosterior[Prior[M, K], GL]": ... @tp.overload - def __mul__( # noqa: F811 + def __mul__( # noqa: F811 self, other: NGL ) -> "NonConjugatePosterior[Prior[M, K], NGL]": ... @tp.overload - def __mul__( # noqa: F811 + def __mul__( # noqa: F811 self, other: L ) -> "AbstractPosterior[Prior[M, K], L]": ... - def __mul__(self, other): # noqa: F811 + def __mul__(self, other): # noqa: F811 r"""Combine the prior with a likelihood to form a posterior distribution. The product of a prior and likelihood is proportional to the posterior @@ -171,20 +202,20 @@ def __mul__(self, other): # noqa: F811 Example: ```pycon - >>> import gpjax as gpx - >>> meanf = gpx.mean_functions.Zero() - >>> kernel = gpx.kernels.RBF() - >>> prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel) - >>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=100) - >>> prior * likelihood + >>> import gpjax as gpx + >>> meanf = gpx.mean_functions.Zero() + >>> kernel = gpx.kernels.RBF() + >>> prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel) + >>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=100) + >>> prior * likelihood ``` Args: - other (Likelihood): The likelihood distribution of the observed dataset. + other (Likelihood): The likelihood distribution of the observed dataset. Returns - Posterior: The relevant GP posterior for the given prior and - likelihood. Special cases are accounted for where the model - is conjugate. + Posterior: The relevant GP posterior for the given prior and + likelihood. Special cases are accounted for where the model + is conjugate. """ return construct_posterior(prior=self, likelihood=other) @@ -194,33 +225,38 @@ def __mul__(self, other): # noqa: F811 def __rmul__(self, other: GL) -> "ConjugatePosterior[Prior[M, K], GL]": ... 
@tp.overload - def __rmul__( # noqa: F811 + def __rmul__( # noqa: F811 self, other: NGL ) -> "NonConjugatePosterior[Prior[M, K], NGL]": ... @tp.overload - def __rmul__( # noqa: F811 + def __rmul__( # noqa: F811 self, other: L ) -> "AbstractPosterior[Prior[M, K], L]": ... - def __rmul__(self, other): # noqa: F811 + def __rmul__(self, other): # noqa: F811 r"""Combine the prior with a likelihood to form a posterior distribution. Reimplement the multiplication operator to allow for order-invariant product of a likelihood and a prior i.e., likelihood * prior. Args: - other (Likelihood): The likelihood distribution of the observed - dataset. + other (Likelihood): The likelihood distribution of the observed + dataset. Returns - Posterior: The relevant GP posterior for the given prior and - likelihood. Special cases are accounted for where the model - is conjugate. + Posterior: The relevant GP posterior for the given prior and + likelihood. Special cases are accounted for where the model + is conjugate. """ return self.__mul__(other) - def predict(self, test_inputs: Num[Array, "N D"]) -> GaussianDistribution: + def predict( + self, + test_inputs: Num[Array, "N D"], + *, + return_cov_type: Literal["dense", "diagonal"] = "dense", + ) -> GaussianDistribution: r"""Compute the predictive prior distribution for a given set of parameters. The output of this function is a function that computes a TFP distribution for a given set of inputs. @@ -230,28 +266,53 @@ def predict(self, test_inputs: Num[Array, "N D"]) -> GaussianDistribution: Example: ```pycon - >>> import gpjax as gpx - >>> import jax.numpy as jnp - >>> kernel = gpx.kernels.RBF() - >>> mean_function = gpx.mean_functions.Zero() - >>> prior = gpx.gps.Prior(mean_function=mean_function, kernel=kernel) - >>> prior.predict(jnp.linspace(0, 1, 100)[:, None]) + >>> import gpjax as gpx + >>> import jax.numpy as jnp + >>> kernel = gpx.kernels.RBF() + >>> mean_function = gpx.mean_functions.Zero() + >>> prior = gpx.gps.Prior(mean_function=mean_function, kernel=kernel) + >>> prior.predict(jnp.linspace(0, 1, 100)[:, None]) ``` Args: - test_inputs (Float[Array, "N D"]): The inputs at which to evaluate the - prior distribution. + test_inputs (Float[Array, "N D"]): The inputs at which to evaluate the + prior distribution. + return_cov_type: Literal denoting whether to return the full covariance + of the joint predictive distribution at the test_inputs (dense) + or just the the standard-deviation of the predictive distribution at + the test_inputs. Returns: - GaussianDistribution: A multivariate normal random variable representation - of the Gaussian process. + GaussianDistribution: A multivariate normal random variable representation + of the Gaussian process. 
""" - mean_at_test = self.mean_function(test_inputs) - Kxx = self.kernel.gram(test_inputs) - Kxx_dense = add_jitter(Kxx.to_dense(), self.jitter) - Kxx = psd(Dense(Kxx_dense)) - return GaussianDistribution(jnp.atleast_1d(mean_at_test.squeeze()), Kxx) + def _ret_full_cov( + t: Num[Array, "N D"], + ) -> Tuple[Float[Array, " N"], LinearOperator]: + mean_at_test = self.mean_function(t) + Kxx = self.kernel.gram(t) + Kxx_dense = add_jitter(Kxx.to_dense(), self.jitter) + Kxx = psd(Dense(Kxx_dense)) + return jnp.atleast_1d(mean_at_test.squeeze()), Kxx + + def _ret_diag_cov( + t: Num[Array, "N D"], + ) -> Tuple[Float[Array, " N"], LinearOperator]: + mean_at_test = self.mean_function(t) + Kxx = self.kernel.diagonal(t).diagonal + Kxx += self.jitter + Kxx = psd(Dense(Diagonal(Kxx).to_dense())) + return jnp.atleast_1d(mean_at_test.squeeze()), Kxx + + mu, cov = jax.lax.cond( + return_cov_type == "dense", + _ret_full_cov, + _ret_diag_cov, + test_inputs, + ) + + return GaussianDistribution(loc=mu, scale=cov) def sample_approx( self, @@ -285,28 +346,28 @@ def sample_approx( Example: ```pycon - >>> import gpjax as gpx - >>> import jax.numpy as jnp - >>> import jax.random as jr - >>> key = jr.PRNGKey(123) - >>> - >>> meanf = gpx.mean_functions.Zero() - >>> kernel = gpx.kernels.RBF(n_dims=1) - >>> prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel) - >>> - >>> sample_fn = prior.sample_approx(10, key) - >>> sample_fn(jnp.linspace(0, 1, 100).reshape(-1, 1)) + >>> import gpjax as gpx + >>> import jax.numpy as jnp + >>> import jax.random as jr + >>> key = jr.PRNGKey(123) + >>> + >>> meanf = gpx.mean_functions.Zero() + >>> kernel = gpx.kernels.RBF(n_dims=1) + >>> prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel) + >>> + >>> sample_fn = prior.sample_approx(10, key) + >>> sample_fn(jnp.linspace(0, 1, 100).reshape(-1, 1)) ``` Args: - num_samples (int): The desired number of samples. - key (KeyArray): The random seed used for the sample(s). - num_features (int): The number of features used when approximating the - kernel. + num_samples (int): The desired number of samples. + key (KeyArray): The random seed used for the sample(s). + num_features (int): The number of features used when approximating the + kernel. Returns: - FunctionalSample: A function representing an approximate sample from the - Gaussian process prior. + FunctionalSample: A function representing an approximate sample from the + Gaussian process prior. """ if (not isinstance(num_samples, int)) or num_samples <= 0: @@ -329,7 +390,7 @@ def sample_fn(test_inputs: Float[Array, "N D"]) -> Float[Array, "N B"]: ####################### # GP Posteriors -####################### +#######################from gpjax.linalg.operators import LinearOperator class AbstractPosterior(nnx.Module, tp.Generic[P, L]): r"""Abstract Gaussian process posterior. @@ -346,17 +407,21 @@ def __init__( r"""Construct a Gaussian process posterior. Args: - prior (AbstractPrior): The prior distribution. - likelihood (AbstractLikelihood): The likelihood distribution. - jitter (float): A small constant added to the diagonal of the - covariance matrix to ensure numerical stability. + prior (AbstractPrior): The prior distribution. + likelihood (AbstractLikelihood): The likelihood distribution. + jitter (float): A small constant added to the diagonal of the + covariance matrix to ensure numerical stability. 
""" self.prior = prior self.likelihood = likelihood self.jitter = jitter def __call__( - self, test_inputs: Num[Array, "N D"], train_data: Dataset + self, + test_inputs: Num[Array, "N D"], + train_data: Dataset, + *, + return_cov_type: Literal["dense", "diagonal"] = "dense", ) -> GaussianDistribution: r"""Evaluate the Gaussian process posterior at the given points. @@ -370,30 +435,46 @@ def __call__( `__call__` method and should instead define a `predict` method. Args: - test_inputs: Input locations where the GP should be evaluated. - train_data: Training dataset to condition on. + test_inputs: Input locations where the GP should be evaluated. + train_data: Training dataset to condition on. + return_cov_type: Literal denoting whether to return the full covariance + of the joint predictive distribution at the test_inputs (dense) + or just the the standard-deviation of the predictive distribution at + the test_inputs. Returns: - GaussianDistribution: A multivariate normal random variable representation - of the Gaussian process. + GaussianDistribution: A multivariate normal random variable representation + of the Gaussian process. """ - return self.predict(test_inputs, train_data) + return self.predict( + test_inputs, + train_data, + return_cov_type=return_cov_type, + ) @abstractmethod def predict( - self, test_inputs: Num[Array, "N D"], train_data: Dataset + self, + test_inputs: Num[Array, "N D"], + train_data: Dataset, + *, + return_cov_type: Literal["dense", "diagonal"] = "dense", ) -> GaussianDistribution: r"""Compute the latent function's multivariate normal distribution for a given set of parameters. For any class inheriting the `AbstractPosterior` class, this method must be implemented. Args: - test_inputs: Input locations where the GP should be evaluated. - train_data: Training dataset to condition on. + test_inputs: Input locations where the GP should be evaluated. + train_data: Training dataset to condition on. + return_cov_type: Literal denoting whether to return the full covariance + of the joint predictive distribution at the test_inputs (dense) + or just the the standard-deviation of the predictive distribution at + the test_inputs. Returns: - GaussianDistribution: A multivariate normal random variable representation - of the Gaussian process. + GaussianDistribution: A multivariate normal random variable representation + of the Gaussian process. """ raise NotImplementedError @@ -414,29 +495,29 @@ class ConjugatePosterior(AbstractPosterior[P, GL]): ```math \begin{align} p(\mathbf{f}^{\star}\mid \mathbf{y}) & = \int p(\mathbf{f}^{\star}, \mathbf{f} \mid \mathbf{y})\\ - & =\mathcal{N}(\mathbf{f}^{\star} \boldsymbol{\mu}_{\mid \mathbf{y}}, \boldsymbol{\Sigma}_{\mid \mathbf{y}} + & =\mathcal{N}(\mathbf{f}^{\star} \boldsymbol{\mu}_{\mid \mathbf{y}}, \boldsymbol{\Sigma}_{\mid \mathbf{y}} \end{align} ``` where ```math \begin{align} - \boldsymbol{\mu}_{\mid \mathbf{y}} & = k(\mathbf{x}^{\star}, \mathbf{x})\left(k(\mathbf{x}, \mathbf{x}')+\sigma^2\mathbf{I}_n\right)^{-1}\mathbf{y} \\ + \boldsymbol{\mu}_{\mid \mathbf{y}} & = k(\mathbf{x}^{\star}, \mathbf{x})\left(k(\mathbf{x}, \mathbf{x}')+\sigma^2\mathbf{I}_n\right)^{-1}\mathbf{y} \\ \boldsymbol{\Sigma}_{\mid \mathbf{y}} & =k(\mathbf{x}^{\star}, \mathbf{x}^{\star\prime}) -k(\mathbf{x}^{\star}, \mathbf{x})\left( k(\mathbf{x}, \mathbf{x}') + \sigma^2\mathbf{I}_n \right)^{-1}k(\mathbf{x}, \mathbf{x}^{\star}). 
\end{align} ``` Example: ```pycon - >>> import gpjax as gpx - >>> import jax.numpy as jnp - >>> - >>> prior = gpx.gps.Prior( - mean_function = gpx.mean_functions.Zero(), - kernel = gpx.kernels.RBF() - ) - >>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=100) - >>> - >>> posterior = prior * likelihood + >>> import gpjax as gpx + >>> import jax.numpy as jnp + >>> + >>> prior = gpx.gps.Prior( + mean_function = gpx.mean_functions.Zero(), + kernel = gpx.kernels.RBF() + ) + >>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=100) + >>> + >>> posterior = prior * likelihood ``` """ @@ -444,6 +525,8 @@ def predict( self, test_inputs: Num[Array, "N D"], train_data: Dataset, + *, + return_cov_type: Literal["dense", "diagonal"] = "dense", ) -> GaussianDistribution: r"""Query the predictive posterior distribution. @@ -454,13 +537,13 @@ def predict( The predictive distribution of a conjugate GP is given by $$ - p(\mathbf{f}^{\star}\mid \mathbf{y}) & = \int p(\mathbf{f}^{\star} \mathbf{f} \mid \mathbf{y})\\ - & =\mathcal{N}(\mathbf{f}^{\star} \boldsymbol{\mu}_{\mid \mathbf{y}}, \boldsymbol{\Sigma}_{\mid \mathbf{y}} + p(\mathbf{f}^{\star}\mid \mathbf{y}) & = \int p(\mathbf{f}^{\star} \mathbf{f} \mid \mathbf{y})\\ + & =\mathcal{N}(\mathbf{f}^{\star} \boldsymbol{\mu}_{\mid \mathbf{y}}, \boldsymbol{\Sigma}_{\mid \mathbf{y}} $$ where $$ - \boldsymbol{\mu}_{\mid \mathbf{y}} & = k(\mathbf{x}^{\star}, \mathbf{x})\left(k(\mathbf{x}, \mathbf{x}')+\sigma^2\mathbf{I}_n\right)^{-1}\mathbf{y} \\ - \boldsymbol{\Sigma}_{\mid \mathbf{y}} & =k(\mathbf{x}^{\star}, \mathbf{x}^{\star\prime}) -k(\mathbf{x}^{\star}, \mathbf{x})\left( k(\mathbf{x}, \mathbf{x}') + \sigma^2\mathbf{I}_n \right)^{-1}k(\mathbf{x}, \mathbf{x}^{\star}). + \boldsymbol{\mu}_{\mid \mathbf{y}} & = k(\mathbf{x}^{\star}, \mathbf{x})\left(k(\mathbf{x}, \mathbf{x}')+\sigma^2\mathbf{I}_n\right)^{-1}\mathbf{y} \\ + \boldsymbol{\Sigma}_{\mid \mathbf{y}} & =k(\mathbf{x}^{\star}, \mathbf{x}^{\star\prime}) -k(\mathbf{x}^{\star}, \mathbf{x})\left( k(\mathbf{x}, \mathbf{x}') + \sigma^2\mathbf{I}_n \right)^{-1}k(\mathbf{x}, \mathbf{x}^{\star}). $$ The conditioning set is a GPJax `Dataset` object, whilst predictions @@ -468,62 +551,110 @@ def predict( Example: ```pycon - >>> import gpjax as gpx - >>> import jax.numpy as jnp - >>> - >>> xtrain = jnp.linspace(0, 1).reshape(-1, 1) - >>> ytrain = jnp.sin(xtrain) - >>> D = gpx.Dataset(X=xtrain, y=ytrain) - >>> xtest = jnp.linspace(0, 1).reshape(-1, 1) - >>> - >>> prior = gpx.gps.Prior(mean_function = gpx.mean_functions.Zero(), kernel = gpx.kernels.RBF()) - >>> posterior = prior * gpx.likelihoods.Gaussian(num_datapoints = D.n) - >>> predictive_dist = posterior(xtest, D) + >>> import gpjax as gpx + >>> import jax.numpy as jnp + >>> + >>> xtrain = jnp.linspace(0, 1).reshape(-1, 1) + >>> ytrain = jnp.sin(xtrain) + >>> D = gpx.Dataset(X=xtrain, y=ytrain) + >>> xtest = jnp.linspace(0, 1).reshape(-1, 1) + >>> + >>> prior = gpx.gps.Prior(mean_function = gpx.mean_functions.Zero(), kernel = gpx.kernels.RBF()) + >>> posterior = prior * gpx.likelihoods.Gaussian(num_datapoints = D.n) + >>> predictive_dist = posterior(xtest, D) ``` - Args: - test_inputs (Num[Array, "N D"]): A Jax array of test inputs at which the - predictive distribution is evaluated. - train_data (Dataset): A `gpx.Dataset` object that contains the input and - output data used for training dataset. - - Returns: - GaussianDistribution: A function that accepts an input array and - returns the predictive distribution as a `GaussianDistribution`. 
- """ - # Unpack training data - x, y = train_data.X, train_data.y - - # Unpack test inputs - t = test_inputs - - # Observation noise o² - obs_noise = self.likelihood.obs_stddev.value**2 - mx = self.prior.mean_function(x) - - # Precompute Gram matrix, Kxx, at training inputs, x - Kxx = self.prior.kernel.gram(x) - Kxx_dense = add_jitter(Kxx.to_dense(), self.jitter) - Kxx = Dense(Kxx_dense) - - Sigma_dense = Kxx.to_dense() + jnp.eye(Kxx.shape[0]) * obs_noise - Sigma = psd(Dense(Sigma_dense)) - L_sigma = lower_cholesky(Sigma) - - mean_t = self.prior.mean_function(t) - Ktt = self.prior.kernel.gram(t) - Kxt = self.prior.kernel.cross_covariance(x, t) - - L_inv_Kxt = solve(L_sigma, Kxt) - L_inv_y_diff = solve(L_sigma, y - mx) - - mean = mean_t + jnp.matmul(L_inv_Kxt.T, L_inv_y_diff) - - covariance = Ktt.to_dense() - jnp.matmul(L_inv_Kxt.T, L_inv_Kxt) - covariance = add_jitter(covariance, self.prior.jitter) - covariance = psd(Dense(covariance)) - - return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance) + Args: + test_inputs (Num[Array, "N D"]): A Jax array of test inputs at which the + predictive distribution is evaluated. + train_data (Dataset): A `gpx.Dataset` object that contains the input and + output data used for training dataset. + return_cov_type: Literal denoting whether to return the full covariance + of the joint predictive distribution at the test_inputs (dense) + or just the the standard-deviation of the predictive distribution at + the test_inputs. + + Returns: + GaussianDistribution: A function that accepts an input array and + returns the predictive distribution as a `GaussianDistribution`. + """ + + def _ret_full_cov( + x: Num[Array, "N D"], + y: Num[Array, "N Q"], + t: Num[Array, "N D"], + ) -> Tuple[Float[Array, " N"], LinearOperator]: + # Observation noise o² + obs_noise = jnp.square(self.likelihood.obs_stddev.value) + mx = self.prior.mean_function(x) + + # Precompute Gram matrix, Kxx, at training inputs, x + Kxx = self.prior.kernel.gram(x) + Kxx = add_jitter(Kxx.to_dense(), self.jitter) + + Sigma_dense = Kxx + jnp.eye(Kxx.shape[0]) * obs_noise + Sigma = psd(Dense(Sigma_dense)) + L_sigma = lower_cholesky(Sigma) + + mean_t = self.prior.mean_function(t) + Ktt = self.prior.kernel.gram(t) + Kxt = self.prior.kernel.cross_covariance(x, t) + + L_inv_Kxt = solve(L_sigma, Kxt) + L_inv_y_diff = solve(L_sigma, y - mx) + + mean = mean_t + jnp.matmul(L_inv_Kxt.T, L_inv_y_diff) + mean = jnp.atleast_1d(mean.squeeze()) + + covariance = Ktt.to_dense() - jnp.matmul(L_inv_Kxt.T, L_inv_Kxt) + covariance = add_jitter(covariance, self.prior.jitter) + covariance = psd(Dense(covariance)) + return mean, covariance + + def _ret_diag_cov( + x: Num[Array, "N D"], + y: Num[Array, "N Q"], + t: Num[Array, "N D"], + ) -> Tuple[Float[Array, " N"], LinearOperator]: + # Observation noise o² + obs_noise = jnp.square(self.likelihood.obs_stddev.value) + mx = self.prior.mean_function(x) + + # Precompute Gram matrix, Kxx, at training inputs, x + Kxx = self.prior.kernel.diagonal(x).diagonal + Kxx += self.jitter + + Sigma_dense = Kxx + obs_noise + Sigma = psd(Diagonal(Sigma_dense)) + L_sigma = lower_cholesky(Sigma) + + mean_t = self.prior.mean_function(t) + Ktt = self.prior.kernel.diagonal(t).diagonal[:, jnp.newaxis] + Kxt = self.prior.kernel.cross_covariance(x, t) + + # TODO: The following are all diagonal solves, so we can just + # do vector addition as needed. We should furthermore return + # a Diagonal covariance and not a Dense. 
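The TODO above observes that every solve in this branch is taken against a diagonal operator. A standalone sketch, with made-up numbers, of why such a solve reduces to elementwise division rather than a triangular solve:

```python
import jax.numpy as jnp

d = jnp.array([4.0, 9.0, 16.0])   # diagonal of Sigma in this branch
v = jnp.array([1.0, 2.0, 3.0])    # an arbitrary right-hand side

# The Cholesky factor of a diagonal matrix is diagonal with sqrt entries,
# so L^-1 v is simply v / sqrt(d), computed elementwise.
L = jnp.linalg.cholesky(jnp.diag(d))
assert jnp.allclose(jnp.linalg.solve(L, v), v / jnp.sqrt(d))
```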
+ L_inv_Kxt_diag = jnp.diag(solve(L_sigma, Kxt))[:, jnp.newaxis] + L_inv_y_diff_diag = jnp.diag(solve(L_sigma, y - mx))[:, jnp.newaxis] + + mean = mean_t + L_inv_Kxt_diag * L_inv_y_diff_diag + mean = jnp.atleast_1d(mean.squeeze()) + covariance = Ktt - jnp.square(L_inv_Kxt_diag) + covariance += self.prior.jitter + covariance = psd(Dense(jnp.diag(jnp.atleast_1d(covariance.squeeze())))) + return mean, covariance + + mu, cov = jax.lax.cond( + return_cov_type == "dense", + _ret_full_cov, + _ret_diag_cov, + train_data.X, + train_data.y, + test_inputs, + ) + + return GaussianDistribution(loc=mu, scale=cov) def sample_approx( self, @@ -560,14 +691,14 @@ def sample_approx( can be evaluated with constant cost regardless of the required number of queries. Args: - num_samples (int): The desired number of samples. - key (KeyArray): The random seed used for the sample(s). - num_features (int): The number of features used when approximating the - kernel. + num_samples (int): The desired number of samples. + key (KeyArray): The random seed used for the sample(s). + num_features (int): The number of features used when approximating the + kernel. Returns: - FunctionalSample: A function representing an approximate sample from the Gaussian - process prior. + FunctionalSample: A function representing an approximate sample from the Gaussian + process prior. """ if (not isinstance(num_samples, int)) or num_samples <= 0: raise ValueError("num_samples must be a positive integer") @@ -584,9 +715,9 @@ def sample_approx( y = train_data.y - self.prior.mean_function(train_data.X) Phi = fourier_feature_fn(train_data.X) canonical_weights = solve( - Sigma, - y + eps - jnp.inner(Phi, fourier_weights), - ) # [N, B] + Sigma, + y + eps - jnp.inner(Phi, fourier_weights), + ) # [N, B] def sample_fn(test_inputs: Float[Array, "n D"]) -> Float[Array, "n B"]: fourier_features = fourier_feature_fn(test_inputs) @@ -633,10 +764,10 @@ def __init__( r"""Construct a non-conjugate Gaussian process posterior. Args: - prior (AbstractPrior): The prior distribution. - likelihood (AbstractLikelihood): The likelihood distribution. - jitter (float): A small constant added to the diagonal of the - covariance matrix to ensure numerical stability. + prior (AbstractPrior): The prior distribution. + likelihood (AbstractLikelihood): The likelihood distribution. + jitter (float): A small constant added to the diagonal of the + covariance matrix to ensure numerical stability. """ super().__init__(prior=prior, likelihood=likelihood, jitter=jitter) @@ -647,63 +778,124 @@ def __init__( self.latent = latent if isinstance(latent, Parameter) else Real(latent) self.key = key - def predict( - self, test_inputs: Num[Array, "N D"], train_data: Dataset - ) -> GaussianDistribution: - r"""Query the predictive posterior distribution. - - Conditional on a set of training data, compute the GP's posterior - predictive distribution for a given set of parameters. The returned - function can be evaluated at a set of test inputs to compute the - corresponding predictive density. Note, to gain predictions on the scale - of the original data, the returned distribution will need to be - transformed through the likelihood function's inverse link function. - - Args: + def predict( + self, + test_inputs: Num[Array, "N D"], + train_data: Dataset, + *, + return_cov_type: Literal["dense", "diagonal"] = "dense", + ) -> GaussianDistribution: + r"""Query the predictive posterior distribution. 
+ + Conditional on a set of training data, compute the GP's posterior + predictive distribution for a given set of parameters. The returned + function can be evaluated at a set of test inputs to compute the + corresponding predictive density. Note, to gain predictions on the scale + of the original data, the returned distribution will need to be + transformed through the likelihood function's inverse link function. + + Args: + test_inputs (Num[Array, "N D"]): A Jax array of test inputs at which the + predictive distribution is evaluated. train_data (Dataset): A `gpx.Dataset` object that contains the input - and output data used for training dataset. + and output data used for training dataset. + return_cov_type: Literal denoting whether to return the full covariance + of the joint predictive distribution at the test_inputs (dense) + or just the the standard-deviation of the predictive distribution at + the test_inputs. - Returns: + Returns: GaussianDistribution: A function that accepts an - input array and returns the predictive distribution as - a `dx.Distribution`. - """ - # Unpack training data - x = train_data.X - - # Unpack mean function and kernel - mean_function = self.prior.mean_function - kernel = self.prior.kernel - - # Precompute lower triangular of Gram matrix, Lx, at training inputs, x - Kxx = kernel.gram(x) - Kxx_dense = add_jitter(Kxx.to_dense(), self.prior.jitter) - Kxx = psd(Dense(Kxx_dense)) - Lx = lower_cholesky(Kxx) - - # Unpack test inputs - t = test_inputs - - # Compute terms of the posterior predictive distribution - Ktx = kernel.cross_covariance(t, x) - Ktt = kernel.gram(t) - mean_t = mean_function(t) - - # Lx⁻¹ Kxt - Lx_inv_Kxt = solve(Lx, Ktx.T) - - # Whitened function values, wx, corresponding to the inputs, x - wx = self.latent.value - - # μt + Ktx Lx⁻¹ wx - mean = mean_t + jnp.matmul(Lx_inv_Kxt.T, wx) - - # Ktt - Ktx Kxx⁻¹ Kxt, TODO: Take advantage of covariance structure to compute Schur complement more efficiently. - covariance = Ktt.to_dense() - jnp.matmul(Lx_inv_Kxt.T, Lx_inv_Kxt) - covariance = add_jitter(covariance, self.prior.jitter) - covariance = psd(Dense(covariance)) + input array and returns the predictive distribution as + a `dx.Distribution`. + """ + + def _ret_full_cov( + x: Num[Array, "N D"], + t: Num[Array, "N D"], + ) -> Tuple[Float[Array, " N"], Dense]: + mean_function = self.prior.mean_function + kernel = self.prior.kernel + + # Precompute lower triangular of Gram matrix + Kxx = kernel.gram(x) + Kxx_dense = add_jitter(Kxx.to_dense(), self.prior.jitter) + Kxx = psd(Dense(Kxx_dense)) + Lx = lower_cholesky(Kxx) + + # Compute terms of the posterior predictive distribution + Ktx = kernel.cross_covariance(t, x) + Ktt = kernel.gram(t) + mean_t = mean_function(t) + + # Lx⁻¹ Kxt + Lx_inv_Kxt = solve(Lx, Ktx.T) + + # Whitened function values, wx, corresponding to the inputs, x + wx = self.latent.value + + # μt + Ktx Lx⁻¹ wx + mean = mean_t + jnp.matmul(Lx_inv_Kxt.T, wx) + mean = jnp.atleast_1d(mean.squeeze()) + + # Ktt - Ktx Kxx⁻¹ Kxt, TODO: Take advantage of covariance structure + # to compute Schur complement more efficiently. 
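The subtraction that follows computes the Schur complement Ktt - Ktx Kxx⁻¹ Kxt through the Cholesky factor Lx. A standalone check of that identity, using arbitrary stand-in matrices in place of the real Gram and cross-covariance matrices:

```python
import jax.numpy as jnp
import jax.random as jr

A = jr.normal(jr.key(0), (5, 5))
Kxx = A @ A.T + 5.0 * jnp.eye(5)        # SPD stand-in for the train Gram matrix
Kxt = jr.normal(jr.key(1), (5, 3))      # stand-in train/test cross-covariance
Ktt = jnp.eye(3)                        # stand-in test Gram matrix

Lx = jnp.linalg.cholesky(Kxx)
Lx_inv_Kxt = jnp.linalg.solve(Lx, Kxt)  # Lx⁻¹ Kxt

# Kxx⁻¹ = Lx⁻ᵀ Lx⁻¹, so the Schur complement needs only the one solve above.
direct = Ktt - Kxt.T @ jnp.linalg.solve(Kxx, Kxt)
via_cholesky = Ktt - Lx_inv_Kxt.T @ Lx_inv_Kxt
assert jnp.allclose(direct, via_cholesky, atol=1e-4)
```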
+ covariance = Ktt.to_dense() - jnp.matmul(Lx_inv_Kxt.T, Lx_inv_Kxt) + covariance = add_jitter(covariance, self.prior.jitter) + covariance = psd(Dense(covariance)) + + return mean, covariance + + def _ret_diag_cov( + x: Num[Array, "N D"], + t: Num[Array, "N D"], + ) -> Tuple[Float[Array, " N"], Dense]: + mean_function = self.prior.mean_function + kernel = self.prior.kernel + + # Precompute lower triangular of Gram matrix + Kxx = kernel.diagonal(x).diagonal + Kxx += self.prior.jitter + Kxx = psd(Diagonal(Kxx)) + Lx = lower_cholesky(Kxx) + + # Compute terms of the posterior predictive distribution + Ktx = kernel.cross_covariance(t, x) + Ktt = kernel.diagonal(t).diagonal[:, jnp.newaxis] + mean_t = mean_function(t) + + # Lx⁻¹ Kxt + Lx_inv_Kxt_diag = jnp.diag(solve(Lx, Ktx.T))[:, jnp.newaxis] + + # Whitened function values, wx, corresponding to the inputs, x + wx = self.latent.value + + # μt + Ktx Lx⁻¹ wx + mean = mean_t + Lx_inv_Kxt_diag * wx + mean = jnp.atleast_1d(mean.squeeze()) + + # Ktt - Ktx Kxx⁻¹ Kxt, TODO: Take advantage of covariance structure + # to compute Schur complement more efficiently. + covariance = Ktt - jnp.square(Lx_inv_Kxt_diag) + covariance += self.prior.jitter + # It would be nice to return a Diagonal here, but the pytree needs + # to be the same for both cond branches and the other branch needs + # to return a Dense. + # They are both LinearOperators, but they inherit from that class + # and hence are not the same pytree anymore. + covariance = psd(Dense(jnp.diag(jnp.atleast_1d(covariance.squeeze())))) + + return mean, covariance + + mu, cov = jax.lax.cond( + return_cov_type == "dense", + _ret_full_cov, + _ret_diag_cov, + train_data.X, + test_inputs, + ) - return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance) + return GaussianDistribution(mu, cov) ####################### @@ -781,4 +973,4 @@ def eval_fourier_features(test_inputs: Float[Array, "N D"]) -> Float[Array, "N L "ConjugatePosterior", "NonConjugatePosterior", "construct_posterior", -] +] \ No newline at end of file diff --git a/gpjax/kernels/base.py b/gpjax/kernels/base.py index 7d3e40b40..432aeeb74 100644 --- a/gpjax/kernels/base.py +++ b/gpjax/kernels/base.py @@ -123,7 +123,7 @@ def gram(self, x: Num[Array, "N D"]) -> LinearOperator: """ return self.compute_engine.gram(self, x) - def diagonal(self, x: Num[Array, "N D"]) -> Float[Array, " N"]: + def diagonal(self, x: Num[Array, "N D"]) -> LinearOperator: r"""Compute the diagonal of the gram matrix of the kernel. Args: diff --git a/gpjax/kernels/computations/base.py b/gpjax/kernels/computations/base.py index e0312017f..46c73a6fd 100644 --- a/gpjax/kernels/computations/base.py +++ b/gpjax/kernels/computations/base.py @@ -17,6 +17,7 @@ import typing as tp from jax import vmap +import jax.numpy as jnp from jaxtyping import ( Float, Num, @@ -48,52 +49,58 @@ class AbstractKernelComputation: def _gram( self, kernel: K, - x: Num[Array, "N D"], + inputs: Num[Array, "N D"], ) -> Float[Array, "N N"]: - Kxx = self.cross_covariance(kernel, x, x) - return Kxx + return self.cross_covariance(kernel, inputs, inputs) def gram( self, kernel: K, - x: Num[Array, "N D"], + inputs: Num[Array, "N D"], ) -> Dense: r"""For a given kernel, compute Gram covariance operator of the kernel function on an input matrix of shape `(N, D)`. Args: kernel: the kernel function. - x: the inputs to the kernel function of shape `(N, D)`. + inputs: the inputs to the kernel function of shape `(N, D)`. Returns: The Gram covariance of the kernel function as a linear operator. 
""" - Kxx = self.cross_covariance(kernel, x, x) + Kxx = self._gram(kernel, inputs) + # Kxx = self.cross_covariance(kernel, inputs, inputs) return psd(Dense(Kxx)) @abc.abstractmethod def _cross_covariance( - self, kernel: K, x: Num[Array, "N D"], y: Num[Array, "M D"] + self, + kernel: K, + first_inputs: Num[Array, "N D"], + second_inputs: Num[Array, "M D"], ) -> Float[Array, "N M"]: ... def cross_covariance( - self, kernel: K, x: Num[Array, "N D"], y: Num[Array, "M D"] + self, + kernel: K, + first_inputs: Num[Array, "N D"], + second_inputs: Num[Array, "M D"], ) -> Float[Array, "N M"]: r"""For a given kernel, compute the cross-covariance matrix on an a pair of input matrices with shape `(N, D)` and `(M, D)`. Args: kernel: the kernel function. - x: the first input matrix of shape `(N, D)`. - y: the second input matrix of shape `(M, D)`. + first_inputs: the first input matrix of shape `(N, D)`. + second_inputs: the second input matrix of shape `(M, D)`. Returns: The computed cross-covariance of shape `(N, M)`. """ - return self._cross_covariance(kernel, x, y) + return self._cross_covariance(kernel, first_inputs, second_inputs) - def _diagonal(self, kernel: K, inputs: Num[Array, "N D"]) -> Diagonal: - return psd(Diagonal(vmap(lambda x: kernel(x, x))(inputs))) + def _diagonal(self, kernel: K, inputs: Num[Array, "N D"]) -> Float[Array, "N N"]: + return jnp.diag(vmap(lambda x: kernel(x, x))(inputs)) def diagonal(self, kernel: K, inputs: Num[Array, "N D"]) -> Diagonal: r"""For a given kernel, compute the elementwise diagonal of the @@ -106,4 +113,5 @@ def diagonal(self, kernel: K, inputs: Num[Array, "N D"]) -> Diagonal: Returns: The computed diagonal variance as a `Diagonal` linear operator. """ - return self._diagonal(kernel, inputs) + Kxx = self._diagonal(kernel, inputs) + return psd(Diagonal(Kxx)) diff --git a/tests/test_gps.py b/tests/test_gps.py index e40394c75..254dca393 100644 --- a/tests/test_gps.py +++ b/tests/test_gps.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -75,6 +75,37 @@ def test_abstract_posterior(): AbstractPosterior() +@pytest.mark.parametrize("num_datapoints", [1, 10]) +@pytest.mark.parametrize("kernel", [RBF, Matern52]) +@pytest.mark.parametrize("mean_function", [Zero, Constant]) +def test_prior_with_diag( + num_datapoints: int, + kernel: type[AbstractKernel], + mean_function: Type[AbstractMeanFunction], +) -> None: + # Create prior. + prior = Prior(mean_function=mean_function(), kernel=kernel()) + + # Check types. + assert isinstance(prior, Prior) + assert isinstance(prior, AbstractPrior) + + # Query a marginal distribution at some inputs. + inputs = jnp.linspace(-3.0, 3.0, num_datapoints).reshape(-1, 1) + marginal_distribution = prior(inputs, return_cov_type="diagonal") + + # Ensure that the marginal distribution is a Gaussian. + assert isinstance(marginal_distribution, GaussianDistribution) + assert isinstance(marginal_distribution, NumpyroDistribution) + + # Ensure that the marginal distribution has the correct shape. 
+ mu = marginal_distribution.mean + sigma = marginal_distribution.covariance() + assert mu.shape == (num_datapoints,) + assert sigma.shape == (num_datapoints, num_datapoints) + assert jnp.all((sigma - jnp.diag(jnp.diag(sigma))) == 0) + + @pytest.mark.parametrize("num_datapoints", [1, 10]) @pytest.mark.parametrize("kernel", [RBF, Matern52]) @pytest.mark.parametrize("mean_function", [Zero, Constant]) @@ -105,6 +136,48 @@ def test_prior( assert sigma.shape == (num_datapoints, num_datapoints) +@pytest.mark.parametrize("num_datapoints", [1, 10]) +@pytest.mark.parametrize("kernel", [RBF, Matern52]) +@pytest.mark.parametrize("mean_function", [Zero, Constant]) +def test_conjugate_posterior_with_diag( + num_datapoints: int, + kernel: type[AbstractKernel], + mean_function: type[AbstractMeanFunction], +) -> None: + # Create a dataset. + key = jr.key(123) + x = jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(num_datapoints, 1)) + y = jnp.sin(x) + jr.normal(key=key, shape=x.shape) * 0.1 + D = Dataset(X=x, y=y) + + # Define prior. + prior = Prior(mean_function=mean_function(), kernel=kernel()) + + # Define a likelihood. + likelihood = Gaussian(num_datapoints=num_datapoints) + + # Construct the posterior via the class. + posterior = ConjugatePosterior(prior=prior, likelihood=likelihood) + + # Check types. + assert isinstance(posterior, ConjugatePosterior) + + # Query a marginal distribution of the posterior at some inputs. + inputs = jnp.linspace(-3.0, 3.0, num_datapoints).reshape(-1, 1) + marginal_distribution = posterior(inputs, D, return_cov_type="diagonal") + + # Ensure that the marginal distribution is a Gaussian. + assert isinstance(marginal_distribution, GaussianDistribution) + assert isinstance(marginal_distribution, NumpyroDistribution) + + # Ensure that the marginal distribution has the correct shape. + mu = marginal_distribution.mean + sigma = marginal_distribution.covariance() + assert mu.shape == (num_datapoints,) + assert sigma.shape == (num_datapoints, num_datapoints) + assert jnp.all((sigma - jnp.diag(jnp.diag(sigma))) == 0) + + @pytest.mark.parametrize("num_datapoints", [1, 10]) @pytest.mark.parametrize("kernel", [RBF, Matern52]) @pytest.mark.parametrize("mean_function", [Zero, Constant]) @@ -146,6 +219,54 @@ def test_conjugate_posterior( assert sigma.shape == (num_datapoints, num_datapoints) +@pytest.mark.parametrize("num_datapoints", [1, 10]) +@pytest.mark.parametrize("kernel", [RBF, Matern52]) +@pytest.mark.parametrize("mean_function", [Zero, Constant]) +def test_nonconjugate_posterior_with_diag( + num_datapoints: int, + kernel: type[AbstractKernel], + mean_function: type[AbstractMeanFunction], +) -> None: + # Create a dataset. + key = jr.key(123) + x = jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(num_datapoints, 1)) + y = jnp.sin(x) + jr.normal(key=key, shape=x.shape) * 0.1 + D = Dataset(X=x, y=y) + + # Define prior. + prior = Prior(mean_function=mean_function(), kernel=kernel()) + + # Define a likelihood. + likelihood = Bernoulli(num_datapoints=num_datapoints) + + # Construct the posterior via the class. + posterior = NonConjugatePosterior(prior=prior, likelihood=likelihood) + + # Check types. + assert isinstance(posterior, NonConjugatePosterior) + + # Check latent values. + latent_values = jr.normal(posterior.key, (num_datapoints, 1)) + assert (posterior.latent.value == latent_values).all() + + # Query a marginal distribution of the posterior at some inputs. 
+ inputs = jnp.linspace(-3.0, 3.0, num_datapoints).reshape(-1, 1) + marginal_distribution = posterior(inputs, D, return_cov_type="diagonal") + + # Ensure that the marginal distribution is a Gaussian. + assert isinstance(marginal_distribution, GaussianDistribution) + assert isinstance(marginal_distribution, NumpyroDistribution) + + # Ensure that the marginal distribution has the correct shape. + mu = marginal_distribution.mean + sigma = marginal_distribution.covariance() + assert mu.shape == (num_datapoints,) + # We are still returning a full covariance, even though the off diagonal + # should all be zeros... + assert sigma.shape == (num_datapoints, num_datapoints) + assert jnp.all((sigma - jnp.diag(jnp.diag(sigma))) == 0) + + @pytest.mark.parametrize("num_datapoints", [1, 10]) @pytest.mark.parametrize("kernel", [RBF, Matern52]) @pytest.mark.parametrize("mean_function", [Zero, Constant]) @@ -249,22 +370,22 @@ def test_prior_sample_approx(num_datapoints, kernel, mean_function): p.sample_approx(1, key, 0.5) sampled_fn = p.sample_approx(1, key, 100) - assert isinstance(sampled_fn, Callable) # check type + assert isinstance(sampled_fn, Callable) # check type x = jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(num_datapoints, 2)) evals = sampled_fn(x) - assert evals.shape == (num_datapoints, 1.0) # check shape + assert evals.shape == (num_datapoints, 1.0) # check shape sampled_fn_2 = p.sample_approx(1, key, 100) evals_2 = sampled_fn_2(x) max_delta = jnp.max(jnp.abs(evals - evals_2)) - assert max_delta == 0.0 # samples same for same seed + assert max_delta == 0.0 # samples same for same seed new_key = jr.key(12345) sampled_fn_3 = p.sample_approx(1, new_key, num_features=100) evals_3 = sampled_fn_3(x) max_delta = jnp.max(jnp.abs(evals - evals_3)) - assert max_delta > 0.01 # samples different for different seed + assert max_delta > 0.01 # samples different for different seed # Check validty of samples using Monte-Carlo sampled_fn = p.sample_approx(10_000, key, 100) @@ -276,8 +397,8 @@ def test_prior_sample_approx(num_datapoints, kernel, mean_function): true_var = jnp.diagonal(true_predictive.covariance()) max_error_in_mean = jnp.max(jnp.abs(approx_mean - true_mean)) max_error_in_var = jnp.max(jnp.abs(approx_var - true_var)) - assert max_error_in_mean < 0.02 # check that samples are correct - assert max_error_in_var < 0.05 # check that samples are correct + assert max_error_in_mean < 0.02 # check that samples are correct + assert max_error_in_var < 0.05 # check that samples are correct @pytest.mark.parametrize("num_datapoints", [1, 5]) @@ -286,47 +407,47 @@ def test_prior_sample_approx(num_datapoints, kernel, mean_function): def test_conjugate_posterior_sample_approx(num_datapoints, kernel, mean_function): kern = kernel(lengthscale=jnp.array([5.0, 1.0]), variance=0.1) p = Prior(kernel=kern, mean_function=mean_function()) * Gaussian( - num_datapoints=num_datapoints + num_datapoints=num_datapoints ) key = jr.key(123) x = jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(num_datapoints, 2)) y = ( - jnp.mean(jnp.sin(x), 1, keepdims=True) - + jr.normal(key=key, shape=(num_datapoints, 1)) * 0.1 + jnp.mean(jnp.sin(x), 1, keepdims=True) + + jr.normal(key=key, shape=(num_datapoints, 1)) * 0.1 ) D = Dataset(X=x, y=y) # with pytest.raises(ValueError): - # p.sample_approx(-1, D, key) + # p.sample_approx(-1, D, key) # with pytest.raises(ValueError): - # p.sample_approx(0, D, key) + # p.sample_approx(0, D, key) # with pytest.raises(ValidationErrors): - # p.sample_approx(0.5, D, key) + # p.sample_approx(0.5, 
D, key) # with pytest.raises(ValueError): - # p.sample_approx(1, D, key, -10) + # p.sample_approx(1, D, key, -10) # with pytest.raises(ValueError): - # p.sample_approx(1, D, key, 0) + # p.sample_approx(1, D, key, 0) # with pytest.raises(ValidationErrors): - # p.sample_approx(1, D, key, 0.5) + # p.sample_approx(1, D, key, 0.5) sampled_fn = p.sample_approx(1, D, key, num_features=100) - assert isinstance(sampled_fn, Callable) # check type + assert isinstance(sampled_fn, Callable) # check type x = jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(num_datapoints, 2)) evals = sampled_fn(x) - assert evals.shape == (num_datapoints, 1.0) # check shape + assert evals.shape == (num_datapoints, 1.0) # check shape sampled_fn_2 = p.sample_approx(1, D, key, num_features=100) evals_2 = sampled_fn_2(x) max_delta = jnp.max(jnp.abs(evals - evals_2)) - assert max_delta == 0.0 # samples same for same seed + assert max_delta == 0.0 # samples same for same seed new_key = jr.key(12345) sampled_fn_3 = p.sample_approx(1, D, new_key, num_features=100) evals_3 = sampled_fn_3(x) max_delta = jnp.max(jnp.abs(evals - evals_3)) - assert max_delta > 0.01 # samples different for different seed + assert max_delta > 0.01 # samples different for different seed # Check validty of samples using Monte-Carlo sampled_fn = p.sample_approx(10_000, D, key, num_features=100) @@ -338,9 +459,9 @@ def test_conjugate_posterior_sample_approx(num_datapoints, kernel, mean_function true_var = jnp.diagonal(true_predictive.covariance()) max_error_in_mean = jnp.max(jnp.abs(approx_mean - true_mean)) max_error_in_var = jnp.max(jnp.abs(approx_var - true_var)) - assert max_error_in_mean < 0.02 # check that samples are correct - assert max_error_in_var < 0.05 # check that samples are correct + assert max_error_in_mean < 0.02 # check that samples are correct + assert max_error_in_var < 0.05 # check that samples are correct if __name__ == "__main__": - test_conjugate_posterior_sample_approx(10, RBF, Zero) + test_conjugate_posterior_sample_approx(10, RBF, Zero) \ No newline at end of file From 8636b6278c9b6113907dd41af25d1c5a3a412b0b Mon Sep 17 00:00:00 2001 From: Daniel Marthaler Date: Sat, 1 Nov 2025 09:01:21 -0600 Subject: [PATCH 02/19] linting --- gpjax/gps.py | 32 +++++++++++++++--------------- gpjax/kernels/computations/base.py | 7 ++----- tests/test_gps.py | 32 +++++++++++++++--------------- 3 files changed, 34 insertions(+), 37 deletions(-) diff --git a/gpjax/gps.py b/gpjax/gps.py index cddce3ee5..2f4dbbbee 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -35,9 +35,9 @@ from gpjax.kernels import RFF from gpjax.kernels.base import AbstractKernel from gpjax.likelihoods import ( - AbstractLikelihood, - Gaussian, - NonGaussian, + AbstractLikelihood, + Gaussian, + NonGaussian, ) from gpjax.linalg import ( Dense, @@ -116,9 +116,9 @@ def __call__( of the Gaussian process. """ return self.predict( - test_inputs, - return_cov_type=return_cov_type, - ) + test_inputs, + return_cov_type=return_cov_type, + ) @abstractmethod def predict( @@ -179,16 +179,16 @@ class Prior(AbstractPrior[M, K]): def __mul__(self, other: GL) -> "ConjugatePosterior[Prior[M, K], GL]": ... @tp.overload - def __mul__( # noqa: F811 + def __mul__( # noqa: F811 self, other: NGL ) -> "NonConjugatePosterior[Prior[M, K], NGL]": ... @tp.overload - def __mul__( # noqa: F811 + def __mul__( # noqa: F811 self, other: L ) -> "AbstractPosterior[Prior[M, K], L]": ... 
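A minimal usage sketch of the forwarded `return_cov_type` keyword on a conditioned posterior, mirroring the added tests; the keyword exists only once this patch is applied:

```python
import gpjax as gpx
import jax.numpy as jnp
import jax.random as jr

key = jr.key(123)
x = jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(10, 1))
y = jnp.sin(x)
D = gpx.Dataset(X=x, y=y)

prior = gpx.gps.Prior(
    mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF()
)
posterior = prior * gpx.likelihoods.Gaussian(num_datapoints=D.n)

xtest = jnp.linspace(-3.0, 3.0, 10).reshape(-1, 1)
latent_dist = posterior(xtest, D, return_cov_type="diagonal")

# The returned covariance is still (N, N), but with zero off-diagonal entries.
sigma = latent_dist.covariance()
assert sigma.shape == (10, 10)
assert jnp.all(sigma - jnp.diag(jnp.diag(sigma)) == 0)
```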
- def __mul__(self, other): # noqa: F811 + def __mul__(self, other): # noqa: F811 r"""Combine the prior with a likelihood to form a posterior distribution. The product of a prior and likelihood is proportional to the posterior @@ -225,16 +225,16 @@ def __mul__(self, other): # noqa: F811 def __rmul__(self, other: GL) -> "ConjugatePosterior[Prior[M, K], GL]": ... @tp.overload - def __rmul__( # noqa: F811 + def __rmul__( # noqa: F811 self, other: NGL ) -> "NonConjugatePosterior[Prior[M, K], NGL]": ... @tp.overload - def __rmul__( # noqa: F811 + def __rmul__( # noqa: F811 self, other: L ) -> "AbstractPosterior[Prior[M, K], L]": ... - def __rmul__(self, other): # noqa: F811 + def __rmul__(self, other): # noqa: F811 r"""Combine the prior with a likelihood to form a posterior distribution. Reimplement the multiplication operator to allow for order-invariant @@ -715,9 +715,9 @@ def sample_approx( y = train_data.y - self.prior.mean_function(train_data.X) Phi = fourier_feature_fn(train_data.X) canonical_weights = solve( - Sigma, - y + eps - jnp.inner(Phi, fourier_weights), - ) # [N, B] + Sigma, + y + eps - jnp.inner(Phi, fourier_weights), + ) # [N, B] def sample_fn(test_inputs: Float[Array, "n D"]) -> Float[Array, "n B"]: fourier_features = fourier_feature_fn(test_inputs) @@ -973,4 +973,4 @@ def eval_fourier_features(test_inputs: Float[Array, "N D"]) -> Float[Array, "N L "ConjugatePosterior", "NonConjugatePosterior", "construct_posterior", -] \ No newline at end of file +] diff --git a/gpjax/kernels/computations/base.py b/gpjax/kernels/computations/base.py index 46c73a6fd..f32a1df1d 100644 --- a/gpjax/kernels/computations/base.py +++ b/gpjax/kernels/computations/base.py @@ -17,7 +17,6 @@ import typing as tp from jax import vmap -import jax.numpy as jnp from jaxtyping import ( Float, Num, @@ -69,7 +68,6 @@ def gram( The Gram covariance of the kernel function as a linear operator. """ Kxx = self._gram(kernel, inputs) - # Kxx = self.cross_covariance(kernel, inputs, inputs) return psd(Dense(Kxx)) @abc.abstractmethod @@ -100,7 +98,7 @@ def cross_covariance( return self._cross_covariance(kernel, first_inputs, second_inputs) def _diagonal(self, kernel: K, inputs: Num[Array, "N D"]) -> Float[Array, "N N"]: - return jnp.diag(vmap(lambda x: kernel(x, x))(inputs)) + return psd(Diagonal(vmap(lambda x: kernel(x, x))(inputs))) def diagonal(self, kernel: K, inputs: Num[Array, "N D"]) -> Diagonal: r"""For a given kernel, compute the elementwise diagonal of the @@ -113,5 +111,4 @@ def diagonal(self, kernel: K, inputs: Num[Array, "N D"]) -> Diagonal: Returns: The computed diagonal variance as a `Diagonal` linear operator. 
""" - Kxx = self._diagonal(kernel, inputs) - return psd(Diagonal(Kxx)) + return self._diagonal(kernel, inputs) diff --git a/tests/test_gps.py b/tests/test_gps.py index 254dca393..4925a681c 100644 --- a/tests/test_gps.py +++ b/tests/test_gps.py @@ -370,22 +370,22 @@ def test_prior_sample_approx(num_datapoints, kernel, mean_function): p.sample_approx(1, key, 0.5) sampled_fn = p.sample_approx(1, key, 100) - assert isinstance(sampled_fn, Callable) # check type + assert isinstance(sampled_fn, Callable) # check type x = jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(num_datapoints, 2)) evals = sampled_fn(x) - assert evals.shape == (num_datapoints, 1.0) # check shape + assert evals.shape == (num_datapoints, 1.0) # check shape sampled_fn_2 = p.sample_approx(1, key, 100) evals_2 = sampled_fn_2(x) max_delta = jnp.max(jnp.abs(evals - evals_2)) - assert max_delta == 0.0 # samples same for same seed + assert max_delta == 0.0 # samples same for same seed new_key = jr.key(12345) sampled_fn_3 = p.sample_approx(1, new_key, num_features=100) evals_3 = sampled_fn_3(x) max_delta = jnp.max(jnp.abs(evals - evals_3)) - assert max_delta > 0.01 # samples different for different seed + assert max_delta > 0.01 # samples different for different seed # Check validty of samples using Monte-Carlo sampled_fn = p.sample_approx(10_000, key, 100) @@ -397,8 +397,8 @@ def test_prior_sample_approx(num_datapoints, kernel, mean_function): true_var = jnp.diagonal(true_predictive.covariance()) max_error_in_mean = jnp.max(jnp.abs(approx_mean - true_mean)) max_error_in_var = jnp.max(jnp.abs(approx_var - true_var)) - assert max_error_in_mean < 0.02 # check that samples are correct - assert max_error_in_var < 0.05 # check that samples are correct + assert max_error_in_mean < 0.02 # check that samples are correct + assert max_error_in_var < 0.05 # check that samples are correct @pytest.mark.parametrize("num_datapoints", [1, 5]) @@ -407,14 +407,14 @@ def test_prior_sample_approx(num_datapoints, kernel, mean_function): def test_conjugate_posterior_sample_approx(num_datapoints, kernel, mean_function): kern = kernel(lengthscale=jnp.array([5.0, 1.0]), variance=0.1) p = Prior(kernel=kern, mean_function=mean_function()) * Gaussian( - num_datapoints=num_datapoints + num_datapoints=num_datapoints ) key = jr.key(123) x = jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(num_datapoints, 2)) y = ( - jnp.mean(jnp.sin(x), 1, keepdims=True) - + jr.normal(key=key, shape=(num_datapoints, 1)) * 0.1 + jnp.mean(jnp.sin(x), 1, keepdims=True) + + jr.normal(key=key, shape=(num_datapoints, 1)) * 0.1 ) D = Dataset(X=x, y=y) @@ -432,22 +432,22 @@ def test_conjugate_posterior_sample_approx(num_datapoints, kernel, mean_function # p.sample_approx(1, D, key, 0.5) sampled_fn = p.sample_approx(1, D, key, num_features=100) - assert isinstance(sampled_fn, Callable) # check type + assert isinstance(sampled_fn, Callable) # check type x = jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(num_datapoints, 2)) evals = sampled_fn(x) - assert evals.shape == (num_datapoints, 1.0) # check shape + assert evals.shape == (num_datapoints, 1.0) # check shape sampled_fn_2 = p.sample_approx(1, D, key, num_features=100) evals_2 = sampled_fn_2(x) max_delta = jnp.max(jnp.abs(evals - evals_2)) - assert max_delta == 0.0 # samples same for same seed + assert max_delta == 0.0 # samples same for same seed new_key = jr.key(12345) sampled_fn_3 = p.sample_approx(1, D, new_key, num_features=100) evals_3 = sampled_fn_3(x) max_delta = jnp.max(jnp.abs(evals - evals_3)) - assert max_delta > 
0.01 # samples different for different seed + assert max_delta > 0.01 # samples different for different seed # Check validty of samples using Monte-Carlo sampled_fn = p.sample_approx(10_000, D, key, num_features=100) @@ -459,9 +459,9 @@ def test_conjugate_posterior_sample_approx(num_datapoints, kernel, mean_function true_var = jnp.diagonal(true_predictive.covariance()) max_error_in_mean = jnp.max(jnp.abs(approx_mean - true_mean)) max_error_in_var = jnp.max(jnp.abs(approx_var - true_var)) - assert max_error_in_mean < 0.02 # check that samples are correct - assert max_error_in_var < 0.05 # check that samples are correct + assert max_error_in_mean < 0.02 # check that samples are correct + assert max_error_in_var < 0.05 # check that samples are correct if __name__ == "__main__": - test_conjugate_posterior_sample_approx(10, RBF, Zero) \ No newline at end of file + test_conjugate_posterior_sample_approx(10, RBF, Zero) From bb1f47b548a95ea745347c370309a731f1bba481 Mon Sep 17 00:00:00 2001 From: Daniel Marthaler Date: Sat, 1 Nov 2025 09:10:48 -0600 Subject: [PATCH 03/19] reverted computations base --- gpjax/kernels/computations/base.py | 29 ++++++++++++----------------- 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/gpjax/kernels/computations/base.py b/gpjax/kernels/computations/base.py index f32a1df1d..e0312017f 100644 --- a/gpjax/kernels/computations/base.py +++ b/gpjax/kernels/computations/base.py @@ -48,56 +48,51 @@ class AbstractKernelComputation: def _gram( self, kernel: K, - inputs: Num[Array, "N D"], + x: Num[Array, "N D"], ) -> Float[Array, "N N"]: - return self.cross_covariance(kernel, inputs, inputs) + Kxx = self.cross_covariance(kernel, x, x) + return Kxx def gram( self, kernel: K, - inputs: Num[Array, "N D"], + x: Num[Array, "N D"], ) -> Dense: r"""For a given kernel, compute Gram covariance operator of the kernel function on an input matrix of shape `(N, D)`. Args: kernel: the kernel function. - inputs: the inputs to the kernel function of shape `(N, D)`. + x: the inputs to the kernel function of shape `(N, D)`. Returns: The Gram covariance of the kernel function as a linear operator. """ - Kxx = self._gram(kernel, inputs) + Kxx = self.cross_covariance(kernel, x, x) return psd(Dense(Kxx)) @abc.abstractmethod def _cross_covariance( - self, - kernel: K, - first_inputs: Num[Array, "N D"], - second_inputs: Num[Array, "M D"], + self, kernel: K, x: Num[Array, "N D"], y: Num[Array, "M D"] ) -> Float[Array, "N M"]: ... def cross_covariance( - self, - kernel: K, - first_inputs: Num[Array, "N D"], - second_inputs: Num[Array, "M D"], + self, kernel: K, x: Num[Array, "N D"], y: Num[Array, "M D"] ) -> Float[Array, "N M"]: r"""For a given kernel, compute the cross-covariance matrix on an a pair of input matrices with shape `(N, D)` and `(M, D)`. Args: kernel: the kernel function. - first_inputs: the first input matrix of shape `(N, D)`. - second_inputs: the second input matrix of shape `(M, D)`. + x: the first input matrix of shape `(N, D)`. + y: the second input matrix of shape `(M, D)`. Returns: The computed cross-covariance of shape `(N, M)`. 
""" - return self._cross_covariance(kernel, first_inputs, second_inputs) + return self._cross_covariance(kernel, x, y) - def _diagonal(self, kernel: K, inputs: Num[Array, "N D"]) -> Float[Array, "N N"]: + def _diagonal(self, kernel: K, inputs: Num[Array, "N D"]) -> Diagonal: return psd(Diagonal(vmap(lambda x: kernel(x, x))(inputs))) def diagonal(self, kernel: K, inputs: Num[Array, "N D"]) -> Diagonal: From b5c49dce454bd769d2b4b4f70a4d7a5be55b7070 Mon Sep 17 00:00:00 2001 From: Daniel Marthaler Date: Sat, 1 Nov 2025 11:10:50 -0600 Subject: [PATCH 04/19] fixed docstrings --- gpjax/gps.py | 362 +++++++++++++++++++++++++-------------------------- 1 file changed, 181 insertions(+), 181 deletions(-) diff --git a/gpjax/gps.py b/gpjax/gps.py index 2f4dbbbee..f524646c1 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -166,10 +166,10 @@ class Prior(AbstractPrior[M, K]): Example: ```python - >>> import gpjax as gpx - >>> kernel = gpx.kernels.RBF() - >>> meanf = gpx.mean_functions.Zero() - >>> prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel) + >>> import gpjax as gpx + >>> kernel = gpx.kernels.RBF() + >>> meanf = gpx.mean_functions.Zero() + >>> prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel) ``` """ @@ -202,12 +202,12 @@ def __mul__(self, other): # noqa: F811 Example: ```pycon - >>> import gpjax as gpx - >>> meanf = gpx.mean_functions.Zero() - >>> kernel = gpx.kernels.RBF() - >>> prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel) - >>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=100) - >>> prior * likelihood + >>> import gpjax as gpx + >>> meanf = gpx.mean_functions.Zero() + >>> kernel = gpx.kernels.RBF() + >>> prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel) + >>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=100) + >>> prior * likelihood ``` Args: other (Likelihood): The likelihood distribution of the observed dataset. 
@@ -266,12 +266,12 @@ def predict( Example: ```pycon - >>> import gpjax as gpx - >>> import jax.numpy as jnp - >>> kernel = gpx.kernels.RBF() - >>> mean_function = gpx.mean_functions.Zero() - >>> prior = gpx.gps.Prior(mean_function=mean_function, kernel=kernel) - >>> prior.predict(jnp.linspace(0, 1, 100)[:, None]) + >>> import gpjax as gpx + >>> import jax.numpy as jnp + >>> kernel = gpx.kernels.RBF() + >>> mean_function = gpx.mean_functions.Zero() + >>> prior = gpx.gps.Prior(mean_function=mean_function, kernel=kernel) + >>> prior.predict(jnp.linspace(0, 1, 100)[:, None]) ``` Args: @@ -346,17 +346,17 @@ def sample_approx( Example: ```pycon - >>> import gpjax as gpx - >>> import jax.numpy as jnp - >>> import jax.random as jr - >>> key = jr.PRNGKey(123) - >>> - >>> meanf = gpx.mean_functions.Zero() - >>> kernel = gpx.kernels.RBF(n_dims=1) - >>> prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel) - >>> - >>> sample_fn = prior.sample_approx(10, key) - >>> sample_fn(jnp.linspace(0, 1, 100).reshape(-1, 1)) + >>> import gpjax as gpx + >>> import jax.numpy as jnp + >>> import jax.random as jr + >>> key = jr.PRNGKey(123) + >>> + >>> meanf = gpx.mean_functions.Zero() + >>> kernel = gpx.kernels.RBF(n_dims=1) + >>> prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel) + >>> + >>> sample_fn = prior.sample_approx(10, key) + >>> sample_fn(jnp.linspace(0, 1, 100).reshape(-1, 1)) ``` Args: @@ -495,29 +495,29 @@ class ConjugatePosterior(AbstractPosterior[P, GL]): ```math \begin{align} p(\mathbf{f}^{\star}\mid \mathbf{y}) & = \int p(\mathbf{f}^{\star}, \mathbf{f} \mid \mathbf{y})\\ - & =\mathcal{N}(\mathbf{f}^{\star} \boldsymbol{\mu}_{\mid \mathbf{y}}, \boldsymbol{\Sigma}_{\mid \mathbf{y}} + & =\mathcal{N}(\mathbf{f}^{\star} \boldsymbol{\mu}_{\mid \mathbf{y}}, \boldsymbol{\Sigma}_{\mid \mathbf{y}} \end{align} ``` where ```math \begin{align} - \boldsymbol{\mu}_{\mid \mathbf{y}} & = k(\mathbf{x}^{\star}, \mathbf{x})\left(k(\mathbf{x}, \mathbf{x}')+\sigma^2\mathbf{I}_n\right)^{-1}\mathbf{y} \\ + \boldsymbol{\mu}_{\mid \mathbf{y}} & = k(\mathbf{x}^{\star}, \mathbf{x})\left(k(\mathbf{x}, \mathbf{x}')+\sigma^2\mathbf{I}_n\right)^{-1}\mathbf{y} \\ \boldsymbol{\Sigma}_{\mid \mathbf{y}} & =k(\mathbf{x}^{\star}, \mathbf{x}^{\star\prime}) -k(\mathbf{x}^{\star}, \mathbf{x})\left( k(\mathbf{x}, \mathbf{x}') + \sigma^2\mathbf{I}_n \right)^{-1}k(\mathbf{x}, \mathbf{x}^{\star}). 
\end{align} ``` Example: ```pycon - >>> import gpjax as gpx - >>> import jax.numpy as jnp - >>> - >>> prior = gpx.gps.Prior( - mean_function = gpx.mean_functions.Zero(), - kernel = gpx.kernels.RBF() - ) - >>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=100) - >>> - >>> posterior = prior * likelihood + >>> import gpjax as gpx + >>> import jax.numpy as jnp + >>> + >>> prior = gpx.gps.Prior( + mean_function = gpx.mean_functions.Zero(), + kernel = gpx.kernels.RBF() + ) + >>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=100) + >>> + >>> posterior = prior * likelihood ``` """ @@ -551,33 +551,33 @@ def predict( Example: ```pycon - >>> import gpjax as gpx - >>> import jax.numpy as jnp - >>> - >>> xtrain = jnp.linspace(0, 1).reshape(-1, 1) - >>> ytrain = jnp.sin(xtrain) - >>> D = gpx.Dataset(X=xtrain, y=ytrain) - >>> xtest = jnp.linspace(0, 1).reshape(-1, 1) - >>> - >>> prior = gpx.gps.Prior(mean_function = gpx.mean_functions.Zero(), kernel = gpx.kernels.RBF()) - >>> posterior = prior * gpx.likelihoods.Gaussian(num_datapoints = D.n) - >>> predictive_dist = posterior(xtest, D) + >>> import gpjax as gpx + >>> import jax.numpy as jnp + >>> + >>> xtrain = jnp.linspace(0, 1).reshape(-1, 1) + >>> ytrain = jnp.sin(xtrain) + >>> D = gpx.Dataset(X=xtrain, y=ytrain) + >>> xtest = jnp.linspace(0, 1).reshape(-1, 1) + >>> + >>> prior = gpx.gps.Prior(mean_function = gpx.mean_functions.Zero(), kernel = gpx.kernels.RBF()) + >>> posterior = prior * gpx.likelihoods.Gaussian(num_datapoints = D.n) + >>> predictive_dist = posterior(xtest, D) ``` - Args: - test_inputs (Num[Array, "N D"]): A Jax array of test inputs at which the - predictive distribution is evaluated. - train_data (Dataset): A `gpx.Dataset` object that contains the input and - output data used for training dataset. - return_cov_type: Literal denoting whether to return the full covariance - of the joint predictive distribution at the test_inputs (dense) - or just the the standard-deviation of the predictive distribution at - the test_inputs. - - Returns: - GaussianDistribution: A function that accepts an input array and - returns the predictive distribution as a `GaussianDistribution`. - """ + Args: + test_inputs (Num[Array, "N D"]): A Jax array of test inputs at which the + predictive distribution is evaluated. + train_data (Dataset): A `gpx.Dataset` object that contains the input and + output data used for training dataset. + return_cov_type: Literal denoting whether to return the full covariance + of the joint predictive distribution at the test_inputs (dense) + or just the the standard-deviation of the predictive distribution at + the test_inputs. + + Returns: + GaussianDistribution: A function that accepts an input array and + returns the predictive distribution as a `GaussianDistribution`. + """ def _ret_full_cov( x: Num[Array, "N D"], @@ -778,124 +778,124 @@ def __init__( self.latent = latent if isinstance(latent, Parameter) else Real(latent) self.key = key - def predict( - self, - test_inputs: Num[Array, "N D"], - train_data: Dataset, - *, - return_cov_type: Literal["dense", "diagonal"] = "dense", - ) -> GaussianDistribution: - r"""Query the predictive posterior distribution. - - Conditional on a set of training data, compute the GP's posterior - predictive distribution for a given set of parameters. The returned - function can be evaluated at a set of test inputs to compute the - corresponding predictive density. 
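The `sample_approx` examples in the surrounding docstrings draw approximate prior samples by approximating the kernel with a finite set of random Fourier features (the `num_features` argument). A minimal weight-space sketch of that idea for an RBF kernel is given below in plain JAX; it is not the GPJax implementation, and the feature count and lengthscale are arbitrary illustrative choices.

```python
import jax.numpy as jnp
import jax.random as jr

key = jr.PRNGKey(0)
num_features, lengthscale = 100, 1.0
w_key, b_key, theta_key = jr.split(key, 3)

# Spectral frequencies and phases for an RBF kernel with the given lengthscale.
omega = jr.normal(w_key, (num_features, 1)) / lengthscale
phase = jr.uniform(b_key, (num_features,), minval=0.0, maxval=2.0 * jnp.pi)

def features(x):
    # (N, 1) -> (N, num_features); E[features(x) @ features(y).T] ≈ k(x, y).
    return jnp.sqrt(2.0 / num_features) * jnp.cos(x @ omega.T + phase)

# One draw of weights gives one approximate prior sample, evaluable at any inputs
# with constant cost per query -- the property the functional samples rely on.
theta = jr.normal(theta_key, (num_features,))

def sample_fn(x):
    return features(x) @ theta

print(sample_fn(jnp.linspace(0.0, 1.0, 5).reshape(-1, 1)))
```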
Note, to gain predictions on the scale - of the original data, the returned distribution will need to be - transformed through the likelihood function's inverse link function. - - Args: - test_inputs (Num[Array, "N D"]): A Jax array of test inputs at which the - predictive distribution is evaluated. - train_data (Dataset): A `gpx.Dataset` object that contains the input - and output data used for training dataset. - return_cov_type: Literal denoting whether to return the full covariance - of the joint predictive distribution at the test_inputs (dense) - or just the the standard-deviation of the predictive distribution at - the test_inputs. - - Returns: - GaussianDistribution: A function that accepts an - input array and returns the predictive distribution as - a `dx.Distribution`. - """ - - def _ret_full_cov( - x: Num[Array, "N D"], - t: Num[Array, "N D"], - ) -> Tuple[Float[Array, " N"], Dense]: - mean_function = self.prior.mean_function - kernel = self.prior.kernel - - # Precompute lower triangular of Gram matrix - Kxx = kernel.gram(x) - Kxx_dense = add_jitter(Kxx.to_dense(), self.prior.jitter) - Kxx = psd(Dense(Kxx_dense)) - Lx = lower_cholesky(Kxx) - - # Compute terms of the posterior predictive distribution - Ktx = kernel.cross_covariance(t, x) - Ktt = kernel.gram(t) - mean_t = mean_function(t) - - # Lx⁻¹ Kxt - Lx_inv_Kxt = solve(Lx, Ktx.T) - - # Whitened function values, wx, corresponding to the inputs, x - wx = self.latent.value - - # μt + Ktx Lx⁻¹ wx - mean = mean_t + jnp.matmul(Lx_inv_Kxt.T, wx) - mean = jnp.atleast_1d(mean.squeeze()) - - # Ktt - Ktx Kxx⁻¹ Kxt, TODO: Take advantage of covariance structure - # to compute Schur complement more efficiently. - covariance = Ktt.to_dense() - jnp.matmul(Lx_inv_Kxt.T, Lx_inv_Kxt) - covariance = add_jitter(covariance, self.prior.jitter) - covariance = psd(Dense(covariance)) - - return mean, covariance - - def _ret_diag_cov( - x: Num[Array, "N D"], - t: Num[Array, "N D"], - ) -> Tuple[Float[Array, " N"], Dense]: - mean_function = self.prior.mean_function - kernel = self.prior.kernel - - # Precompute lower triangular of Gram matrix - Kxx = kernel.diagonal(x).diagonal - Kxx += self.prior.jitter - Kxx = psd(Diagonal(Kxx)) - Lx = lower_cholesky(Kxx) - - # Compute terms of the posterior predictive distribution - Ktx = kernel.cross_covariance(t, x) - Ktt = kernel.diagonal(t).diagonal[:, jnp.newaxis] - mean_t = mean_function(t) - - # Lx⁻¹ Kxt - Lx_inv_Kxt_diag = jnp.diag(solve(Lx, Ktx.T))[:, jnp.newaxis] - - # Whitened function values, wx, corresponding to the inputs, x - wx = self.latent.value - - # μt + Ktx Lx⁻¹ wx - mean = mean_t + Lx_inv_Kxt_diag * wx - mean = jnp.atleast_1d(mean.squeeze()) - - # Ktt - Ktx Kxx⁻¹ Kxt, TODO: Take advantage of covariance structure - # to compute Schur complement more efficiently. - covariance = Ktt - jnp.square(Lx_inv_Kxt_diag) - covariance += self.prior.jitter - # It would be nice to return a Diagonal here, but the pytree needs - # to be the same for both cond branches and the other branch needs - # to return a Dense. - # They are both LinearOperators, but they inherit from that class - # and hence are not the same pytree anymore. 
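The whitened parameterisation used here can be checked numerically: with latent values f(x) = μx + Lx wx, the usual predictive mean Ktx Kxx⁻¹ (f(x) − μx) reduces to (Lx⁻¹ Kxt)ᵀ wx, which is what the code above computes. A small self-contained check with arbitrary numbers and a zero mean function (not the GPJax objects themselves):

```python
import jax.numpy as jnp
import jax.scipy.linalg as jsl

Kxx = jnp.array([[1.0, 0.8], [0.8, 1.0]]) + 1e-6 * jnp.eye(2)
Lx = jnp.linalg.cholesky(Kxx)
wx = jnp.array([0.3, -1.2])                   # whitened latent values (the free parameter)
fx = Lx @ wx                                  # implied latent function values at x

Ktx = jnp.array([[0.9, 0.7], [0.2, 0.5]])     # k(t, x) at two test points

mean_standard = Ktx @ jnp.linalg.solve(Kxx, fx)             # Ktx Kxx⁻¹ f(x)
Lx_inv_Kxt = jsl.solve_triangular(Lx, Ktx.T, lower=True)    # Lx⁻¹ Kxt
mean_whitened = Lx_inv_Kxt.T @ wx                           # form used in predict()
print(jnp.allclose(mean_standard, mean_whitened))           # True
```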
- covariance = psd(Dense(jnp.diag(jnp.atleast_1d(covariance.squeeze())))) - - return mean, covariance - - mu, cov = jax.lax.cond( - return_cov_type == "dense", - _ret_full_cov, - _ret_diag_cov, - train_data.X, - test_inputs, - ) + def predict( + self, + test_inputs: Num[Array, "N D"], + train_data: Dataset, + *, + return_cov_type: Literal["dense", "diagonal"] = "dense", + ) -> GaussianDistribution: + r"""Query the predictive posterior distribution. + + Conditional on a set of training data, compute the GP's posterior + predictive distribution for a given set of parameters. The returned + function can be evaluated at a set of test inputs to compute the + corresponding predictive density. Note, to gain predictions on the scale + of the original data, the returned distribution will need to be + transformed through the likelihood function's inverse link function. + + Args: + test_inputs (Num[Array, "N D"]): A Jax array of test inputs at which the + predictive distribution is evaluated. + train_data (Dataset): A `gpx.Dataset` object that contains the input + and output data used for training dataset. + return_cov_type: Literal denoting whether to return the full covariance + of the joint predictive distribution at the test_inputs (dense) + or just the the standard-deviation of the predictive distribution at + the test_inputs. + + Returns: + GaussianDistribution: A function that accepts an + input array and returns the predictive distribution as + a `dx.Distribution`. + """ + + def _ret_full_cov( + x: Num[Array, "N D"], + t: Num[Array, "N D"], + ) -> Tuple[Float[Array, " N"], Dense]: + mean_function = self.prior.mean_function + kernel = self.prior.kernel + + # Precompute lower triangular of Gram matrix + Kxx = kernel.gram(x) + Kxx_dense = add_jitter(Kxx.to_dense(), self.prior.jitter) + Kxx = psd(Dense(Kxx_dense)) + Lx = lower_cholesky(Kxx) + + # Compute terms of the posterior predictive distribution + Ktx = kernel.cross_covariance(t, x) + Ktt = kernel.gram(t) + mean_t = mean_function(t) + + # Lx⁻¹ Kxt + Lx_inv_Kxt = solve(Lx, Ktx.T) + + # Whitened function values, wx, corresponding to the inputs, x + wx = self.latent.value + + # μt + Ktx Lx⁻¹ wx + mean = mean_t + jnp.matmul(Lx_inv_Kxt.T, wx) + mean = jnp.atleast_1d(mean.squeeze()) + + # Ktt - Ktx Kxx⁻¹ Kxt, TODO: Take advantage of covariance structure + # to compute Schur complement more efficiently. + covariance = Ktt.to_dense() - jnp.matmul(Lx_inv_Kxt.T, Lx_inv_Kxt) + covariance = add_jitter(covariance, self.prior.jitter) + covariance = psd(Dense(covariance)) + + return mean, covariance + + def _ret_diag_cov( + x: Num[Array, "N D"], + t: Num[Array, "N D"], + ) -> Tuple[Float[Array, " N"], Dense]: + mean_function = self.prior.mean_function + kernel = self.prior.kernel + + # Precompute lower triangular of Gram matrix + Kxx = kernel.diagonal(x).diagonal + Kxx += self.prior.jitter + Kxx = psd(Diagonal(Kxx)) + Lx = lower_cholesky(Kxx) + + # Compute terms of the posterior predictive distribution + Ktx = kernel.cross_covariance(t, x) + Ktt = kernel.diagonal(t).diagonal[:, jnp.newaxis] + mean_t = mean_function(t) + + # Lx⁻¹ Kxt + Lx_inv_Kxt_diag = jnp.diag(solve(Lx, Ktx.T))[:, jnp.newaxis] + + # Whitened function values, wx, corresponding to the inputs, x + wx = self.latent.value + + # μt + Ktx Lx⁻¹ wx + mean = mean_t + Lx_inv_Kxt_diag * wx + mean = jnp.atleast_1d(mean.squeeze()) + + # Ktt - Ktx Kxx⁻¹ Kxt, TODO: Take advantage of covariance structure + # to compute Schur complement more efficiently. 
+ covariance = Ktt - jnp.square(Lx_inv_Kxt_diag) + covariance += self.prior.jitter + # It would be nice to return a Diagonal here, but the pytree needs + # to be the same for both cond branches and the other branch needs + # to return a Dense. + # They are both LinearOperators, but they inherit from that class + # and hence are not the same pytree anymore. + covariance = psd(Dense(jnp.diag(jnp.atleast_1d(covariance.squeeze())))) + + return mean, covariance + + mu, cov = jax.lax.cond( + return_cov_type == "dense", + _ret_full_cov, + _ret_diag_cov, + train_data.X, + test_inputs, + ) - return GaussianDistribution(mu, cov) + return GaussianDistribution(mu, cov) ####################### From 54172d546def977f6527f800e6e6422ffbc91ba1 Mon Sep 17 00:00:00 2001 From: Daniel Marthaler Date: Sat, 1 Nov 2025 16:22:34 -0600 Subject: [PATCH 05/19] ran poe format --- examples/regression.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/regression.py b/examples/regression.py index 0136e025a..610c79cde 100644 --- a/examples/regression.py +++ b/examples/regression.py @@ -7,7 +7,7 @@ # extension: .py # format_name: percent # format_version: '1.3' -# jupytext_version: 1.11.2 +# jupytext_version: 1.17.3 # kernelspec: # display_name: .venv # language: python From 9190024edece6a07130b1fdb187b471053ae16f3 Mon Sep 17 00:00:00 2001 From: Daniel Marthaler Date: Sun, 2 Nov 2025 10:53:11 -0700 Subject: [PATCH 06/19] changed Tuple to tuple --- gpjax/gps.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/gpjax/gps.py b/gpjax/gps.py index f524646c1..d75a5bec8 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -15,10 +15,7 @@ # from __future__ import annotations from abc import abstractmethod -from typing import ( - Literal, - Tuple, -) +from typing import Literal import beartype.typing as tp from flax import nnx @@ -289,7 +286,7 @@ def predict( def _ret_full_cov( t: Num[Array, "N D"], - ) -> Tuple[Float[Array, " N"], LinearOperator]: + ) -> tuple[Float[Array, " N"], LinearOperator]: mean_at_test = self.mean_function(t) Kxx = self.kernel.gram(t) Kxx_dense = add_jitter(Kxx.to_dense(), self.jitter) @@ -298,7 +295,7 @@ def _ret_full_cov( def _ret_diag_cov( t: Num[Array, "N D"], - ) -> Tuple[Float[Array, " N"], LinearOperator]: + ) -> tuple[Float[Array, " N"], LinearOperator]: mean_at_test = self.mean_function(t) Kxx = self.kernel.diagonal(t).diagonal Kxx += self.jitter @@ -583,7 +580,7 @@ def _ret_full_cov( x: Num[Array, "N D"], y: Num[Array, "N Q"], t: Num[Array, "N D"], - ) -> Tuple[Float[Array, " N"], LinearOperator]: + ) -> tuple[Float[Array, " N"], LinearOperator]: # Observation noise o² obs_noise = jnp.square(self.likelihood.obs_stddev.value) mx = self.prior.mean_function(x) @@ -615,7 +612,7 @@ def _ret_diag_cov( x: Num[Array, "N D"], y: Num[Array, "N Q"], t: Num[Array, "N D"], - ) -> Tuple[Float[Array, " N"], LinearOperator]: + ) -> tuple[Float[Array, " N"], LinearOperator]: # Observation noise o² obs_noise = jnp.square(self.likelihood.obs_stddev.value) mx = self.prior.mean_function(x) @@ -813,7 +810,7 @@ def predict( def _ret_full_cov( x: Num[Array, "N D"], t: Num[Array, "N D"], - ) -> Tuple[Float[Array, " N"], Dense]: + ) -> tuple[Float[Array, " N"], Dense]: mean_function = self.prior.mean_function kernel = self.prior.kernel @@ -849,7 +846,7 @@ def _ret_full_cov( def _ret_diag_cov( x: Num[Array, "N D"], t: Num[Array, "N D"], - ) -> Tuple[Float[Array, " N"], Dense]: + ) -> tuple[Float[Array, " N"], Dense]: mean_function = self.prior.mean_function kernel = 
self.prior.kernel From 2ca8c2da80d71a883ed5cdcf5471bddde2af810a Mon Sep 17 00:00:00 2001 From: Daniel Marthaler Date: Sun, 2 Nov 2025 10:59:51 -0700 Subject: [PATCH 07/19] changed return_cov_type to return_covariance_type --- gpjax/gps.py | 39 +++++++++++++++++++-------------------- tests/test_gps.py | 6 +++--- 2 files changed, 22 insertions(+), 23 deletions(-) diff --git a/gpjax/gps.py b/gpjax/gps.py index d75a5bec8..7d6f9585f 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -15,7 +15,6 @@ # from __future__ import annotations from abc import abstractmethod -from typing import Literal import beartype.typing as tp from flax import nnx @@ -88,7 +87,7 @@ def __call__( self, test_inputs: Num[Array, "N D"], *, - return_cov_type: Literal["dense", "diagonal"] = "dense", + return_covariance_type: tp.Literal["dense", "diagonal"] = "dense", ) -> GaussianDistribution: r"""Evaluate the Gaussian process at the given points. @@ -103,7 +102,7 @@ def __call__( Args: test_inputs: Input locations where the GP should be evaluated. - return_cov_type: Literal denoting whether to return the full covariance + return_covariance_type: tp.Literal denoting whether to return the full covariance of the joint predictive distribution at the test_inputs (dense) or just the the standard-deviation of the predictive distribution at the test_inputs. @@ -114,7 +113,7 @@ def __call__( """ return self.predict( test_inputs, - return_cov_type=return_cov_type, + return_covariance_type=return_covariance_type, ) @abstractmethod @@ -122,7 +121,7 @@ def predict( self, test_inputs: Num[Array, "N D"], *, - return_cov_type: Literal["dense", "diagonal"] = "dense", + return_covariance_type: tp.Literal["dense", "diagonal"] = "dense", ) -> GaussianDistribution: r"""Evaluate the predictive distribution. @@ -132,7 +131,7 @@ def predict( Args: test_inputs: Input locations where the GP should be evaluated. - return_cov_type: Literal denoting whether to return the full covariance + return_covariance_type: tp.Literal denoting whether to return the full covariance of the joint predictive distribution at the test_inputs (dense) or just the the standard-deviation of the predictive distribution at the test_inputs. @@ -252,7 +251,7 @@ def predict( self, test_inputs: Num[Array, "N D"], *, - return_cov_type: Literal["dense", "diagonal"] = "dense", + return_covariance_type: tp.Literal["dense", "diagonal"] = "dense", ) -> GaussianDistribution: r"""Compute the predictive prior distribution for a given set of parameters. The output of this function is a function that computes @@ -274,7 +273,7 @@ def predict( Args: test_inputs (Float[Array, "N D"]): The inputs at which to evaluate the prior distribution. - return_cov_type: Literal denoting whether to return the full covariance + return_covariance_type: tp.Literal denoting whether to return the full covariance of the joint predictive distribution at the test_inputs (dense) or just the the standard-deviation of the predictive distribution at the test_inputs. @@ -303,7 +302,7 @@ def _ret_diag_cov( return jnp.atleast_1d(mean_at_test.squeeze()), Kxx mu, cov = jax.lax.cond( - return_cov_type == "dense", + return_covariance_type == "dense", _ret_full_cov, _ret_diag_cov, test_inputs, @@ -418,7 +417,7 @@ def __call__( test_inputs: Num[Array, "N D"], train_data: Dataset, *, - return_cov_type: Literal["dense", "diagonal"] = "dense", + return_covariance_type: tp.Literal["dense", "diagonal"] = "dense", ) -> GaussianDistribution: r"""Evaluate the Gaussian process posterior at the given points. 
@@ -434,7 +433,7 @@ def __call__( Args: test_inputs: Input locations where the GP should be evaluated. train_data: Training dataset to condition on. - return_cov_type: Literal denoting whether to return the full covariance + return_covariance_type: tp.Literal denoting whether to return the full covariance of the joint predictive distribution at the test_inputs (dense) or just the the standard-deviation of the predictive distribution at the test_inputs. @@ -446,7 +445,7 @@ def __call__( return self.predict( test_inputs, train_data, - return_cov_type=return_cov_type, + return_covariance_type=return_covariance_type, ) @abstractmethod @@ -455,7 +454,7 @@ def predict( test_inputs: Num[Array, "N D"], train_data: Dataset, *, - return_cov_type: Literal["dense", "diagonal"] = "dense", + return_covariance_type: tp.Literal["dense", "diagonal"] = "dense", ) -> GaussianDistribution: r"""Compute the latent function's multivariate normal distribution for a given set of parameters. For any class inheriting the `AbstractPosterior` class, @@ -464,7 +463,7 @@ def predict( Args: test_inputs: Input locations where the GP should be evaluated. train_data: Training dataset to condition on. - return_cov_type: Literal denoting whether to return the full covariance + return_covariance_type: tp.Literal denoting whether to return the full covariance of the joint predictive distribution at the test_inputs (dense) or just the the standard-deviation of the predictive distribution at the test_inputs. @@ -523,7 +522,7 @@ def predict( test_inputs: Num[Array, "N D"], train_data: Dataset, *, - return_cov_type: Literal["dense", "diagonal"] = "dense", + return_covariance_type: tp.Literal["dense", "diagonal"] = "dense", ) -> GaussianDistribution: r"""Query the predictive posterior distribution. @@ -566,7 +565,7 @@ def predict( predictive distribution is evaluated. train_data (Dataset): A `gpx.Dataset` object that contains the input and output data used for training dataset. - return_cov_type: Literal denoting whether to return the full covariance + return_covariance_type: tp.Literal denoting whether to return the full covariance of the joint predictive distribution at the test_inputs (dense) or just the the standard-deviation of the predictive distribution at the test_inputs. @@ -643,7 +642,7 @@ def _ret_diag_cov( return mean, covariance mu, cov = jax.lax.cond( - return_cov_type == "dense", + return_covariance_type == "dense", _ret_full_cov, _ret_diag_cov, train_data.X, @@ -780,7 +779,7 @@ def predict( test_inputs: Num[Array, "N D"], train_data: Dataset, *, - return_cov_type: Literal["dense", "diagonal"] = "dense", + return_covariance_type: tp.Literal["dense", "diagonal"] = "dense", ) -> GaussianDistribution: r"""Query the predictive posterior distribution. @@ -796,7 +795,7 @@ def predict( predictive distribution is evaluated. train_data (Dataset): A `gpx.Dataset` object that contains the input and output data used for training dataset. - return_cov_type: Literal denoting whether to return the full covariance + return_covariance_type: tp.Literal denoting whether to return the full covariance of the joint predictive distribution at the test_inputs (dense) or just the the standard-deviation of the predictive distribution at the test_inputs. 
@@ -885,7 +884,7 @@ def _ret_diag_cov( return mean, covariance mu, cov = jax.lax.cond( - return_cov_type == "dense", + return_covariance_type == "dense", _ret_full_cov, _ret_diag_cov, train_data.X, diff --git a/tests/test_gps.py b/tests/test_gps.py index 4925a681c..380a0cebe 100644 --- a/tests/test_gps.py +++ b/tests/test_gps.py @@ -92,7 +92,7 @@ def test_prior_with_diag( # Query a marginal distribution at some inputs. inputs = jnp.linspace(-3.0, 3.0, num_datapoints).reshape(-1, 1) - marginal_distribution = prior(inputs, return_cov_type="diagonal") + marginal_distribution = prior(inputs, return_covariance_type="diagonal") # Ensure that the marginal distribution is a Gaussian. assert isinstance(marginal_distribution, GaussianDistribution) @@ -164,7 +164,7 @@ def test_conjugate_posterior_with_diag( # Query a marginal distribution of the posterior at some inputs. inputs = jnp.linspace(-3.0, 3.0, num_datapoints).reshape(-1, 1) - marginal_distribution = posterior(inputs, D, return_cov_type="diagonal") + marginal_distribution = posterior(inputs, D, return_covariance_type="diagonal") # Ensure that the marginal distribution is a Gaussian. assert isinstance(marginal_distribution, GaussianDistribution) @@ -251,7 +251,7 @@ def test_nonconjugate_posterior_with_diag( # Query a marginal distribution of the posterior at some inputs. inputs = jnp.linspace(-3.0, 3.0, num_datapoints).reshape(-1, 1) - marginal_distribution = posterior(inputs, D, return_cov_type="diagonal") + marginal_distribution = posterior(inputs, D, return_covariance_type="diagonal") # Ensure that the marginal distribution is a Gaussian. assert isinstance(marginal_distribution, GaussianDistribution) From d9ed2d8229c6190001a9a0201dc434fb3138f6ef Mon Sep 17 00:00:00 2001 From: Daniel Marthaler Date: Sun, 2 Nov 2025 11:04:46 -0700 Subject: [PATCH 08/19] rescoped prior lax cond and renamed functions --- gpjax/gps.py | 35 +++++++++++++++++------------------ 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/gpjax/gps.py b/gpjax/gps.py index 7d6f9585f..69260ad26 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -283,32 +283,31 @@ def predict( of the Gaussian process. """ - def _ret_full_cov( + def _return_full_covariance( t: Num[Array, "N D"], - ) -> tuple[Float[Array, " N"], LinearOperator]: - mean_at_test = self.mean_function(t) + ) -> LinearOperator: Kxx = self.kernel.gram(t) Kxx_dense = add_jitter(Kxx.to_dense(), self.jitter) Kxx = psd(Dense(Kxx_dense)) - return jnp.atleast_1d(mean_at_test.squeeze()), Kxx + return Kxx - def _ret_diag_cov( + def _return_diagonal_covariance( t: Num[Array, "N D"], - ) -> tuple[Float[Array, " N"], LinearOperator]: - mean_at_test = self.mean_function(t) + ) -> LinearOperator: Kxx = self.kernel.diagonal(t).diagonal Kxx += self.jitter Kxx = psd(Dense(Diagonal(Kxx).to_dense())) return jnp.atleast_1d(mean_at_test.squeeze()), Kxx - mu, cov = jax.lax.cond( + mean_at_test = self.mean_function(t) + cov = jax.lax.cond( return_covariance_type == "dense", - _ret_full_cov, - _ret_diag_cov, + _return_full_covariance, + _return_diagonal_covariance, test_inputs, ) - return GaussianDistribution(loc=mu, scale=cov) + return GaussianDistribution(loc=jnp.atleast_1d(mean_at_test.squeeze()), scale=cov) def sample_approx( self, @@ -575,7 +574,7 @@ def predict( returns the predictive distribution as a `GaussianDistribution`. 
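End to end, the renamed keyword is used as in the sketch below. It assumes a GPJax build that includes this patch series (released versions do not have `return_covariance_type`); the data, kernel, and tolerances are arbitrary.

```python
import gpjax as gpx
import jax.numpy as jnp

x = jnp.linspace(0.0, 1.0, 50).reshape(-1, 1)
y = jnp.sin(x)
D = gpx.Dataset(X=x, y=y)

prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF())
posterior = prior * gpx.likelihoods.Gaussian(num_datapoints=D.n)

xtest = jnp.linspace(0.0, 1.0, 200).reshape(-1, 1)
dense = posterior(xtest, D, return_covariance_type="dense")      # full joint covariance
diag = posterior(xtest, D, return_covariance_type="diagonal")    # marginal variances only

# The diagonal path should agree with the diagonal of the dense covariance.
print(jnp.allclose(jnp.diag(diag.covariance()), jnp.diag(dense.covariance()), atol=1e-5))
```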
""" - def _ret_full_cov( + def _return_mean_and_full_covariance( x: Num[Array, "N D"], y: Num[Array, "N Q"], t: Num[Array, "N D"], @@ -607,7 +606,7 @@ def _ret_full_cov( covariance = psd(Dense(covariance)) return mean, covariance - def _ret_diag_cov( + def _return_mean_and_diagonal_covariance( x: Num[Array, "N D"], y: Num[Array, "N Q"], t: Num[Array, "N D"], @@ -643,8 +642,8 @@ def _ret_diag_cov( mu, cov = jax.lax.cond( return_covariance_type == "dense", - _ret_full_cov, - _ret_diag_cov, + _return_mean_and_full_covariance, + _return_mean_and_diagonal_covariance, train_data.X, train_data.y, test_inputs, @@ -806,7 +805,7 @@ def predict( a `dx.Distribution`. """ - def _ret_full_cov( + def _return_mean_and_full_covariance( x: Num[Array, "N D"], t: Num[Array, "N D"], ) -> tuple[Float[Array, " N"], Dense]: @@ -842,7 +841,7 @@ def _ret_full_cov( return mean, covariance - def _ret_diag_cov( + def _return_mean_and_diagonal_covariance( x: Num[Array, "N D"], t: Num[Array, "N D"], ) -> tuple[Float[Array, " N"], Dense]: @@ -885,7 +884,7 @@ def _ret_diag_cov( mu, cov = jax.lax.cond( return_covariance_type == "dense", - _ret_full_cov, + _return_mean_and_diagonal_covariance, _ret_diag_cov, train_data.X, test_inputs, From ff96020c648029a50c34513ae6aa28ad4af15a47 Mon Sep 17 00:00:00 2001 From: Daniel Marthaler Date: Sun, 2 Nov 2025 11:09:41 -0700 Subject: [PATCH 09/19] revert typing Literal as beartype doesnt support it --- gpjax/gps.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/gpjax/gps.py b/gpjax/gps.py index 69260ad26..bf5c60bc5 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -15,6 +15,7 @@ # from __future__ import annotations from abc import abstractmethod +from typing import Literal import beartype.typing as tp from flax import nnx @@ -87,7 +88,7 @@ def __call__( self, test_inputs: Num[Array, "N D"], *, - return_covariance_type: tp.Literal["dense", "diagonal"] = "dense", + return_covariance_type: Literal["dense", "diagonal"] = "dense", ) -> GaussianDistribution: r"""Evaluate the Gaussian process at the given points. @@ -102,7 +103,7 @@ def __call__( Args: test_inputs: Input locations where the GP should be evaluated. - return_covariance_type: tp.Literal denoting whether to return the full covariance + return_covariance_type: Literal denoting whether to return the full covariance of the joint predictive distribution at the test_inputs (dense) or just the the standard-deviation of the predictive distribution at the test_inputs. @@ -121,7 +122,7 @@ def predict( self, test_inputs: Num[Array, "N D"], *, - return_covariance_type: tp.Literal["dense", "diagonal"] = "dense", + return_covariance_type: Literal["dense", "diagonal"] = "dense", ) -> GaussianDistribution: r"""Evaluate the predictive distribution. @@ -131,7 +132,7 @@ def predict( Args: test_inputs: Input locations where the GP should be evaluated. - return_covariance_type: tp.Literal denoting whether to return the full covariance + return_covariance_type: Literal denoting whether to return the full covariance of the joint predictive distribution at the test_inputs (dense) or just the the standard-deviation of the predictive distribution at the test_inputs. @@ -251,7 +252,7 @@ def predict( self, test_inputs: Num[Array, "N D"], *, - return_covariance_type: tp.Literal["dense", "diagonal"] = "dense", + return_covariance_type: Literal["dense", "diagonal"] = "dense", ) -> GaussianDistribution: r"""Compute the predictive prior distribution for a given set of parameters. 
The output of this function is a function that computes @@ -273,7 +274,7 @@ def predict( Args: test_inputs (Float[Array, "N D"]): The inputs at which to evaluate the prior distribution. - return_covariance_type: tp.Literal denoting whether to return the full covariance + return_covariance_type: Literal denoting whether to return the full covariance of the joint predictive distribution at the test_inputs (dense) or just the the standard-deviation of the predictive distribution at the test_inputs. @@ -416,7 +417,7 @@ def __call__( test_inputs: Num[Array, "N D"], train_data: Dataset, *, - return_covariance_type: tp.Literal["dense", "diagonal"] = "dense", + return_covariance_type: Literal["dense", "diagonal"] = "dense", ) -> GaussianDistribution: r"""Evaluate the Gaussian process posterior at the given points. @@ -432,7 +433,7 @@ def __call__( Args: test_inputs: Input locations where the GP should be evaluated. train_data: Training dataset to condition on. - return_covariance_type: tp.Literal denoting whether to return the full covariance + return_covariance_type: Literal denoting whether to return the full covariance of the joint predictive distribution at the test_inputs (dense) or just the the standard-deviation of the predictive distribution at the test_inputs. @@ -453,7 +454,7 @@ def predict( test_inputs: Num[Array, "N D"], train_data: Dataset, *, - return_covariance_type: tp.Literal["dense", "diagonal"] = "dense", + return_covariance_type: Literal["dense", "diagonal"] = "dense", ) -> GaussianDistribution: r"""Compute the latent function's multivariate normal distribution for a given set of parameters. For any class inheriting the `AbstractPosterior` class, @@ -462,7 +463,7 @@ def predict( Args: test_inputs: Input locations where the GP should be evaluated. train_data: Training dataset to condition on. - return_covariance_type: tp.Literal denoting whether to return the full covariance + return_covariance_type: Literal denoting whether to return the full covariance of the joint predictive distribution at the test_inputs (dense) or just the the standard-deviation of the predictive distribution at the test_inputs. @@ -521,7 +522,7 @@ def predict( test_inputs: Num[Array, "N D"], train_data: Dataset, *, - return_covariance_type: tp.Literal["dense", "diagonal"] = "dense", + return_covariance_type: Literal["dense", "diagonal"] = "dense", ) -> GaussianDistribution: r"""Query the predictive posterior distribution. @@ -564,7 +565,7 @@ def predict( predictive distribution is evaluated. train_data (Dataset): A `gpx.Dataset` object that contains the input and output data used for training dataset. - return_covariance_type: tp.Literal denoting whether to return the full covariance + return_covariance_type: Literal denoting whether to return the full covariance of the joint predictive distribution at the test_inputs (dense) or just the the standard-deviation of the predictive distribution at the test_inputs. @@ -778,7 +779,7 @@ def predict( test_inputs: Num[Array, "N D"], train_data: Dataset, *, - return_covariance_type: tp.Literal["dense", "diagonal"] = "dense", + return_covariance_type: Literal["dense", "diagonal"] = "dense", ) -> GaussianDistribution: r"""Query the predictive posterior distribution. @@ -794,7 +795,7 @@ def predict( predictive distribution is evaluated. train_data (Dataset): A `gpx.Dataset` object that contains the input and output data used for training dataset. 
- return_covariance_type: tp.Literal denoting whether to return the full covariance + return_covariance_type: Literal denoting whether to return the full covariance of the joint predictive distribution at the test_inputs (dense) or just the the standard-deviation of the predictive distribution at the test_inputs. From ee3f8c2bca82c00b840ba3d47765be3965c95181 Mon Sep 17 00:00:00 2001 From: Daniel Marthaler Date: Sun, 2 Nov 2025 11:10:26 -0700 Subject: [PATCH 10/19] ran poe format --- gpjax/gps.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/gpjax/gps.py b/gpjax/gps.py index bf5c60bc5..2b7737500 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -308,7 +308,9 @@ def _return_diagonal_covariance( test_inputs, ) - return GaussianDistribution(loc=jnp.atleast_1d(mean_at_test.squeeze()), scale=cov) + return GaussianDistribution( + loc=jnp.atleast_1d(mean_at_test.squeeze()), scale=cov + ) def sample_approx( self, From 66bc54ac1911bf79722a7190e7ca01da2d74cfeb Mon Sep 17 00:00:00 2001 From: Daniel Marthaler Date: Sun, 2 Nov 2025 11:13:46 -0700 Subject: [PATCH 11/19] fixed name of return function in nonconjugate posterior --- gpjax/gps.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/gpjax/gps.py b/gpjax/gps.py index 2b7737500..2e046c923 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -286,7 +286,7 @@ def predict( def _return_full_covariance( t: Num[Array, "N D"], - ) -> LinearOperator: + ) -> Dense: Kxx = self.kernel.gram(t) Kxx_dense = add_jitter(Kxx.to_dense(), self.jitter) Kxx = psd(Dense(Kxx_dense)) @@ -294,13 +294,13 @@ def _return_full_covariance( def _return_diagonal_covariance( t: Num[Array, "N D"], - ) -> LinearOperator: + ) -> Dense: Kxx = self.kernel.diagonal(t).diagonal Kxx += self.jitter Kxx = psd(Dense(Diagonal(Kxx).to_dense())) - return jnp.atleast_1d(mean_at_test.squeeze()), Kxx + return Kxx - mean_at_test = self.mean_function(t) + mean_at_test = self.mean_function(test_inputs) cov = jax.lax.cond( return_covariance_type == "dense", _return_full_covariance, @@ -887,8 +887,8 @@ def _return_mean_and_diagonal_covariance( mu, cov = jax.lax.cond( return_covariance_type == "dense", + _return_mean_and_full_covariance, _return_mean_and_diagonal_covariance, - _ret_diag_cov, train_data.X, test_inputs, ) From 1d784430dbe823a8dfacd638c9fa6ed026138a3a Mon Sep 17 00:00:00 2001 From: Daniel Marthaler Date: Mon, 3 Nov 2025 07:48:31 -0700 Subject: [PATCH 12/19] reformatted docstring args --- gpjax/gps.py | 176 +++++++++++++++++++++++++-------------------------- 1 file changed, 88 insertions(+), 88 deletions(-) diff --git a/gpjax/gps.py b/gpjax/gps.py index 2e046c923..118a3ac78 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -77,8 +77,8 @@ def __init__( r"""Construct a Gaussian process prior. Args: - kernel: kernel object inheriting from AbstractKernel. - mean_function: mean function object inheriting from AbstractMeanFunction. + kernel: kernel object inheriting from AbstractKernel. + mean_function: mean function object inheriting from AbstractMeanFunction. """ self.kernel = kernel self.mean_function = mean_function @@ -100,17 +100,17 @@ def __call__( Under the hood, `__call__` is calling the objects `predict` method. For this reasons, classes inheriting the `AbstractPrior` class, should not overwrite the `__call__` method and should instead define a `predict` method. - + Args: - test_inputs: Input locations where the GP should be evaluated. 
- return_covariance_type: Literal denoting whether to return the full covariance - of the joint predictive distribution at the test_inputs (dense) - or just the the standard-deviation of the predictive distribution at - the test_inputs. + test_inputs: Input locations where the GP should be evaluated. + return_covariance_type: Literal denoting whether to return the full covariance + of the joint predictive distribution at the test_inputs (dense) + or just the the standard-deviation of the predictive distribution at + the test_inputs. Returns: - GaussianDistribution: A multivariate normal random variable representation - of the Gaussian process. + GaussianDistribution: A multivariate normal random variable representation + of the Gaussian process. """ return self.predict( test_inputs, @@ -131,15 +131,15 @@ def predict( this method must be implemented. Args: - test_inputs: Input locations where the GP should be evaluated. - return_covariance_type: Literal denoting whether to return the full covariance - of the joint predictive distribution at the test_inputs (dense) - or just the the standard-deviation of the predictive distribution at - the test_inputs. + test_inputs: Input locations where the GP should be evaluated. + return_covariance_type: Literal denoting whether to return the full covariance + of the joint predictive distribution at the test_inputs (dense) + or just the the standard-deviation of the predictive distribution at + the test_inputs. Returns: - GaussianDistribution: A multivariate normal random variable representation - of the Gaussian process. + GaussianDistribution: A multivariate normal random variable representation + of the Gaussian process. """ raise NotImplementedError @@ -207,12 +207,12 @@ def __mul__(self, other): # noqa: F811 >>> prior * likelihood ``` Args: - other (Likelihood): The likelihood distribution of the observed dataset. + other (Likelihood): The likelihood distribution of the observed dataset. Returns - Posterior: The relevant GP posterior for the given prior and - likelihood. Special cases are accounted for where the model - is conjugate. + Posterior: The relevant GP posterior for the given prior and + likelihood. Special cases are accounted for where the model + is conjugate. """ return construct_posterior(prior=self, likelihood=other) @@ -238,13 +238,13 @@ def __rmul__(self, other): # noqa: F811 product of a likelihood and a prior i.e., likelihood * prior. Args: - other (Likelihood): The likelihood distribution of the observed - dataset. + other (Likelihood): The likelihood distribution of the observed + dataset. Returns - Posterior: The relevant GP posterior for the given prior and - likelihood. Special cases are accounted for where the model - is conjugate. + Posterior: The relevant GP posterior for the given prior and + likelihood. Special cases are accounted for where the model + is conjugate. """ return self.__mul__(other) @@ -272,16 +272,16 @@ def predict( ``` Args: - test_inputs (Float[Array, "N D"]): The inputs at which to evaluate the - prior distribution. - return_covariance_type: Literal denoting whether to return the full covariance - of the joint predictive distribution at the test_inputs (dense) - or just the the standard-deviation of the predictive distribution at - the test_inputs. + test_inputs (Float[Array, "N D"]): The inputs at which to evaluate the + prior distribution. 
+ return_covariance_type: Literal denoting whether to return the full covariance + of the joint predictive distribution at the test_inputs (dense) + or just the the standard-deviation of the predictive distribution at + the test_inputs. Returns: - GaussianDistribution: A multivariate normal random variable representation - of the Gaussian process. + GaussianDistribution: A multivariate normal random variable representation + of the Gaussian process. """ def _return_full_covariance( @@ -358,14 +358,14 @@ def sample_approx( ``` Args: - num_samples (int): The desired number of samples. - key (KeyArray): The random seed used for the sample(s). - num_features (int): The number of features used when approximating the - kernel. + num_samples (int): The desired number of samples. + key (KeyArray): The random seed used for the sample(s). + num_features (int): The number of features used when approximating the + kernel. Returns: - FunctionalSample: A function representing an approximate sample from the - Gaussian process prior. + FunctionalSample: A function representing an approximate sample from the + Gaussian process prior. """ if (not isinstance(num_samples, int)) or num_samples <= 0: @@ -405,10 +405,10 @@ def __init__( r"""Construct a Gaussian process posterior. Args: - prior (AbstractPrior): The prior distribution. - likelihood (AbstractLikelihood): The likelihood distribution. - jitter (float): A small constant added to the diagonal of the - covariance matrix to ensure numerical stability. + prior (AbstractPrior): The prior distribution. + likelihood (AbstractLikelihood): The likelihood distribution. + jitter (float): A small constant added to the diagonal of the + covariance matrix to ensure numerical stability. """ self.prior = prior self.likelihood = likelihood @@ -433,16 +433,16 @@ def __call__( `__call__` method and should instead define a `predict` method. Args: - test_inputs: Input locations where the GP should be evaluated. - train_data: Training dataset to condition on. - return_covariance_type: Literal denoting whether to return the full covariance - of the joint predictive distribution at the test_inputs (dense) - or just the the standard-deviation of the predictive distribution at - the test_inputs. + test_inputs: Input locations where the GP should be evaluated. + train_data: Training dataset to condition on. + return_covariance_type: Literal denoting whether to return the full covariance + of the joint predictive distribution at the test_inputs (dense) + or just the the standard-deviation of the predictive distribution at + the test_inputs. Returns: - GaussianDistribution: A multivariate normal random variable representation - of the Gaussian process. + GaussianDistribution: A multivariate normal random variable representation + of the Gaussian process. """ return self.predict( test_inputs, @@ -463,16 +463,16 @@ def predict( this method must be implemented. Args: - test_inputs: Input locations where the GP should be evaluated. - train_data: Training dataset to condition on. - return_covariance_type: Literal denoting whether to return the full covariance - of the joint predictive distribution at the test_inputs (dense) - or just the the standard-deviation of the predictive distribution at - the test_inputs. + test_inputs: Input locations where the GP should be evaluated. + train_data: Training dataset to condition on. 
+ return_covariance_type: Literal denoting whether to return the full covariance + of the joint predictive distribution at the test_inputs (dense) + or just the the standard-deviation of the predictive distribution at + the test_inputs. Returns: - GaussianDistribution: A multivariate normal random variable representation - of the Gaussian process. + GaussianDistribution: A multivariate normal random variable representation + of the Gaussian process. """ raise NotImplementedError @@ -563,18 +563,18 @@ def predict( ``` Args: - test_inputs (Num[Array, "N D"]): A Jax array of test inputs at which the - predictive distribution is evaluated. - train_data (Dataset): A `gpx.Dataset` object that contains the input and - output data used for training dataset. - return_covariance_type: Literal denoting whether to return the full covariance - of the joint predictive distribution at the test_inputs (dense) - or just the the standard-deviation of the predictive distribution at - the test_inputs. + test_inputs (Num[Array, "N D"]): A Jax array of test inputs at which the + predictive distribution is evaluated. + train_data (Dataset): A `gpx.Dataset` object that contains the input and + output data used for training dataset. + return_covariance_type: Literal denoting whether to return the full covariance + of the joint predictive distribution at the test_inputs (dense) + or just the the standard-deviation of the predictive distribution at + the test_inputs. Returns: - GaussianDistribution: A function that accepts an input array and - returns the predictive distribution as a `GaussianDistribution`. + GaussianDistribution: A function that accepts an input array and + returns the predictive distribution as a `GaussianDistribution`. """ def _return_mean_and_full_covariance( @@ -689,14 +689,14 @@ def sample_approx( can be evaluated with constant cost regardless of the required number of queries. Args: - num_samples (int): The desired number of samples. - key (KeyArray): The random seed used for the sample(s). - num_features (int): The number of features used when approximating the - kernel. + num_samples (int): The desired number of samples. + key (KeyArray): The random seed used for the sample(s). + num_features (int): The number of features used when approximating the + kernel. Returns: - FunctionalSample: A function representing an approximate sample from the Gaussian - process prior. + FunctionalSample: A function representing an approximate sample from the Gaussian + process prior. """ if (not isinstance(num_samples, int)) or num_samples <= 0: raise ValueError("num_samples must be a positive integer") @@ -762,10 +762,10 @@ def __init__( r"""Construct a non-conjugate Gaussian process posterior. Args: - prior (AbstractPrior): The prior distribution. - likelihood (AbstractLikelihood): The likelihood distribution. - jitter (float): A small constant added to the diagonal of the - covariance matrix to ensure numerical stability. + prior (AbstractPrior): The prior distribution. + likelihood (AbstractLikelihood): The likelihood distribution. + jitter (float): A small constant added to the diagonal of the + covariance matrix to ensure numerical stability. """ super().__init__(prior=prior, likelihood=likelihood, jitter=jitter) @@ -794,18 +794,18 @@ def predict( Args: test_inputs (Num[Array, "N D"]): A Jax array of test inputs at which the - predictive distribution is evaluated. + predictive distribution is evaluated. 
train_data (Dataset): A `gpx.Dataset` object that contains the input - and output data used for training dataset. + and output data used for training dataset. return_covariance_type: Literal denoting whether to return the full covariance - of the joint predictive distribution at the test_inputs (dense) - or just the the standard-deviation of the predictive distribution at - the test_inputs. + of the joint predictive distribution at the test_inputs (dense) + or just the the standard-deviation of the predictive distribution at + the test_inputs. Returns: - GaussianDistribution: A function that accepts an - input array and returns the predictive distribution as - a `dx.Distribution`. + GaussianDistribution: A function that accepts an + input array and returns the predictive distribution as + a `dx.Distribution`. """ def _return_mean_and_full_covariance( From 77094027982167ae440847206766a75ce0c7b072 Mon Sep 17 00:00:00 2001 From: Daniel Marthaler Date: Mon, 3 Nov 2025 07:53:17 -0700 Subject: [PATCH 13/19] fixed linting error --- gpjax/gps.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gpjax/gps.py b/gpjax/gps.py index 118a3ac78..78c747146 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -100,7 +100,7 @@ def __call__( Under the hood, `__call__` is calling the objects `predict` method. For this reasons, classes inheriting the `AbstractPrior` class, should not overwrite the `__call__` method and should instead define a `predict` method. - + Args: test_inputs: Input locations where the GP should be evaluated. return_covariance_type: Literal denoting whether to return the full covariance From 333c1ccfa0ba842a895eea478643fe86b6faf8c7 Mon Sep 17 00:00:00 2001 From: Daniel Marthaler Date: Tue, 4 Nov 2025 08:02:55 -0700 Subject: [PATCH 14/19] refactoring --- gpjax/gps.py | 45 +++++++++++++++++++--------------------- tests/test_gps.py | 53 ++++++++++++++++++++++++++++++----------------- 2 files changed, 55 insertions(+), 43 deletions(-) diff --git a/gpjax/gps.py b/gpjax/gps.py index 78c747146..46e9d55f3 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -521,7 +521,7 @@ class ConjugatePosterior(AbstractPosterior[P, GL]): def predict( self, - test_inputs: Num[Array, "N D"], + test_inputs: Num[Array, "M D"], train_data: Dataset, *, return_covariance_type: Literal["dense", "diagonal"] = "dense", @@ -576,15 +576,17 @@ def predict( GaussianDistribution: A function that accepts an input array and returns the predictive distribution as a `GaussianDistribution`. 
""" + # Observation noise o² + obs_noise = jnp.square(self.likelihood.obs_stddev.value) + mx = self.prior.mean_function(x) + Kxt = self.prior.kernel.cross_covariance(x, t) + mean_t = self.prior.mean_function(t) - def _return_mean_and_full_covariance( + def _return_full_covariance( x: Num[Array, "N D"], y: Num[Array, "N Q"], - t: Num[Array, "N D"], - ) -> tuple[Float[Array, " N"], LinearOperator]: - # Observation noise o² - obs_noise = jnp.square(self.likelihood.obs_stddev.value) - mx = self.prior.mean_function(x) + t: Num[Array, "M D"], + ) -> Dense: # Precompute Gram matrix, Kxx, at training inputs, x Kxx = self.prior.kernel.gram(x) @@ -594,9 +596,9 @@ def _return_mean_and_full_covariance( Sigma = psd(Dense(Sigma_dense)) L_sigma = lower_cholesky(Sigma) - mean_t = self.prior.mean_function(t) + Ktt = self.prior.kernel.gram(t) - Kxt = self.prior.kernel.cross_covariance(x, t) + L_inv_Kxt = solve(L_sigma, Kxt) L_inv_y_diff = solve(L_sigma, y - mx) @@ -609,14 +611,12 @@ def _return_mean_and_full_covariance( covariance = psd(Dense(covariance)) return mean, covariance - def _return_mean_and_diagonal_covariance( + def _return_diagonal_covariance( x: Num[Array, "N D"], y: Num[Array, "N Q"], - t: Num[Array, "N D"], - ) -> tuple[Float[Array, " N"], LinearOperator]: + t: Num[Array, "M D"], + ) -> Dense: # Observation noise o² - obs_noise = jnp.square(self.likelihood.obs_stddev.value) - mx = self.prior.mean_function(x) # Precompute Gram matrix, Kxx, at training inputs, x Kxx = self.prior.kernel.diagonal(x).diagonal @@ -627,18 +627,15 @@ def _return_mean_and_diagonal_covariance( L_sigma = lower_cholesky(Sigma) mean_t = self.prior.mean_function(t) - Ktt = self.prior.kernel.diagonal(t).diagonal[:, jnp.newaxis] - Kxt = self.prior.kernel.cross_covariance(x, t) - - # TODO: The following are all diagonal solves, so we can just - # do vector addition as needed. We should furthermore return - # a Diagonal covariance and not a Dense. - L_inv_Kxt_diag = jnp.diag(solve(L_sigma, Kxt))[:, jnp.newaxis] - L_inv_y_diff_diag = jnp.diag(solve(L_sigma, y - mx))[:, jnp.newaxis] + Ktt = self.prior.kernel.diagonal(t).diagonal + + L_inv_Kxt = solve(L_sigma, Kxt) + L_inv_y_diff = solve(L_sigma, y - mx) - mean = mean_t + L_inv_Kxt_diag * L_inv_y_diff_diag + mean = mean_t + jnp.matmul(L_inv_Kxt.T, L_inv_y_diff) mean = jnp.atleast_1d(mean.squeeze()) - covariance = Ktt - jnp.square(L_inv_Kxt_diag) + + covariance = Ktt - jnp.einsum("ij, ji->i", L_inv_Kxt.T, L_inv_Kxt) covariance += self.prior.jitter covariance = psd(Dense(jnp.diag(jnp.atleast_1d(covariance.squeeze())))) return mean, covariance diff --git a/tests/test_gps.py b/tests/test_gps.py index 380a0cebe..2d91809d8 100644 --- a/tests/test_gps.py +++ b/tests/test_gps.py @@ -92,18 +92,22 @@ def test_prior_with_diag( # Query a marginal distribution at some inputs. inputs = jnp.linspace(-3.0, 3.0, num_datapoints).reshape(-1, 1) - marginal_distribution = prior(inputs, return_covariance_type="diagonal") + marginal_distribution_diag = prior(inputs, return_covariance_type="diagonal") + marginal_distribution_full = prior(inputs, return_covariance_type="dense") # Ensure that the marginal distribution is a Gaussian. - assert isinstance(marginal_distribution, GaussianDistribution) - assert isinstance(marginal_distribution, NumpyroDistribution) + assert isinstance(marginal_distribution_diag, GaussianDistribution) + assert isinstance(marginal_distribution_diag, NumpyroDistribution) # Ensure that the marginal distribution has the correct shape. 
- mu = marginal_distribution.mean - sigma = marginal_distribution.covariance() + mu = marginal_distribution_diag.mean + sigma = marginal_distribution_diag.covariance() assert mu.shape == (num_datapoints,) assert sigma.shape == (num_datapoints, num_datapoints) + # test that off diagonal elements are zero assert jnp.all((sigma - jnp.diag(jnp.diag(sigma))) == 0) + # test that we return exactly the diagonal of the full covariance + assert jnp.all(jnp.diag(sigma)==jnp.diag(marginal_distribution_full.covariance())) @pytest.mark.parametrize("num_datapoints", [1, 10]) @@ -137,10 +141,12 @@ def test_prior( @pytest.mark.parametrize("num_datapoints", [1, 10]) +@pytest.mark.parametrize("num_test_datapoints", [1, 10, 200]) @pytest.mark.parametrize("kernel", [RBF, Matern52]) @pytest.mark.parametrize("mean_function", [Zero, Constant]) def test_conjugate_posterior_with_diag( num_datapoints: int, + num_test_datapoints: int, kernel: type[AbstractKernel], mean_function: type[AbstractMeanFunction], ) -> None: @@ -163,26 +169,32 @@ def test_conjugate_posterior_with_diag( assert isinstance(posterior, ConjugatePosterior) # Query a marginal distribution of the posterior at some inputs. - inputs = jnp.linspace(-3.0, 3.0, num_datapoints).reshape(-1, 1) - marginal_distribution = posterior(inputs, D, return_covariance_type="diagonal") + inputs = jnp.linspace(-3.0, 3.0, num_test_datapoints).reshape(-1, 1) + marginal_distribution_diag = posterior(inputs, D, return_covariance_type="diagonal") + marginal_distribution_full = posterior(inputs, D, return_covariance_type="dense") # Ensure that the marginal distribution is a Gaussian. - assert isinstance(marginal_distribution, GaussianDistribution) - assert isinstance(marginal_distribution, NumpyroDistribution) + assert isinstance(marginal_distribution_diag, GaussianDistribution) + assert isinstance(marginal_distribution_diag, NumpyroDistribution) # Ensure that the marginal distribution has the correct shape. - mu = marginal_distribution.mean - sigma = marginal_distribution.covariance() - assert mu.shape == (num_datapoints,) - assert sigma.shape == (num_datapoints, num_datapoints) + mu = marginal_distribution_diag.mean + sigma = marginal_distribution_diag.covariance() + assert mu.shape == (num_test_datapoints,) + assert sigma.shape == (num_test_datapoints, num_test_datapoints) + # test that off diagonal elements are zero assert jnp.all((sigma - jnp.diag(jnp.diag(sigma))) == 0) + # test that we return exactly the diagonal of the full covariance + assert jnp.all(jnp.diag(sigma)==jnp.diag(marginal_distribution_full.covariance())) @pytest.mark.parametrize("num_datapoints", [1, 10]) +@pytest.mark.parametrize("num_test_datapoints", [1, 10, 200]) @pytest.mark.parametrize("kernel", [RBF, Matern52]) @pytest.mark.parametrize("mean_function", [Zero, Constant]) def test_conjugate_posterior( num_datapoints: int, + num_test_datapoints: int, kernel: type[AbstractKernel], mean_function: type[AbstractMeanFunction], ) -> None: @@ -205,7 +217,7 @@ def test_conjugate_posterior( assert isinstance(posterior, ConjugatePosterior) # Query a marginal distribution of the posterior at some inputs. - inputs = jnp.linspace(-3.0, 3.0, num_datapoints).reshape(-1, 1) + inputs = jnp.linspace(-3.0, 3.0, num_test_datapoints).reshape(-1, 1) marginal_distribution = posterior(inputs, D) # Ensure that the marginal distribution is a Gaussian. @@ -215,15 +227,17 @@ def test_conjugate_posterior( # Ensure that the marginal distribution has the correct shape. 
mu = marginal_distribution.mean sigma = marginal_distribution.covariance() - assert mu.shape == (num_datapoints,) - assert sigma.shape == (num_datapoints, num_datapoints) + assert mu.shape == (num_test_datapoints,) + assert sigma.shape == (num__test_datapoints, num_test_datapoints) @pytest.mark.parametrize("num_datapoints", [1, 10]) +@pytest.mark.parametrize("num_test_datapoints", [1, 10, 200]) @pytest.mark.parametrize("kernel", [RBF, Matern52]) @pytest.mark.parametrize("mean_function", [Zero, Constant]) def test_nonconjugate_posterior_with_diag( num_datapoints: int, + num_test_datapoints: int, kernel: type[AbstractKernel], mean_function: type[AbstractMeanFunction], ) -> None: @@ -250,7 +264,7 @@ def test_nonconjugate_posterior_with_diag( assert (posterior.latent.value == latent_values).all() # Query a marginal distribution of the posterior at some inputs. - inputs = jnp.linspace(-3.0, 3.0, num_datapoints).reshape(-1, 1) + inputs = jnp.linspace(-3.0, 3.0, num_test_datapoints).reshape(-1, 1) marginal_distribution = posterior(inputs, D, return_covariance_type="diagonal") # Ensure that the marginal distribution is a Gaussian. @@ -260,10 +274,10 @@ def test_nonconjugate_posterior_with_diag( # Ensure that the marginal distribution has the correct shape. mu = marginal_distribution.mean sigma = marginal_distribution.covariance() - assert mu.shape == (num_datapoints,) + assert mu.shape == (num_test_datapoints,) # We are still returning a full covariance, even though the off diagonal # should all be zeros... - assert sigma.shape == (num_datapoints, num_datapoints) + assert sigma.shape == (num_test_datapoints, num_test_datapoints) assert jnp.all((sigma - jnp.diag(jnp.diag(sigma))) == 0) @@ -272,6 +286,7 @@ def test_nonconjugate_posterior_with_diag( @pytest.mark.parametrize("mean_function", [Zero, Constant]) def test_nonconjugate_posterior( num_datapoints: int, + num_test_datapoints: int, kernel: type[AbstractKernel], mean_function: type[AbstractMeanFunction], ) -> None: From 66b112a2cf8adf057b528af177db949a32fdf3ce Mon Sep 17 00:00:00 2001 From: Daniel Marthaler Date: Wed, 5 Nov 2025 08:00:13 -0700 Subject: [PATCH 15/19] refactored and added tests --- gpjax/gps.py | 178 ++++++++++++++++------------------------------ tests/test_gps.py | 32 ++++++--- 2 files changed, 83 insertions(+), 127 deletions(-) diff --git a/gpjax/gps.py b/gpjax/gps.py index 46e9d55f3..34a86e5d5 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -43,7 +43,6 @@ solve, ) from gpjax.linalg.operations import ( - LinearOperator, lower_cholesky, ) from gpjax.linalg.utils import add_jitter @@ -576,80 +575,56 @@ def predict( GaussianDistribution: A function that accepts an input array and returns the predictive distribution as a `GaussianDistribution`. 
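The refactor that follows keeps everything independent of the covariance type — the Cholesky of Kxx + σ²I, the triangular solves, and the predictive mean — outside the `lax.cond`, and branches only on how the covariance is assembled, with the diagonal branch never forming the full M×M Schur complement. A plain-JAX sketch of that shape (toy kernel and data, not the GPJax operators):

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg as jsl

def predict(Kxx, Kxt, Ktt, y, obs_noise, dense: bool):
    # Shared between both branches: Cholesky of Sigma = Kxx + o²I and the solves.
    L = jnp.linalg.cholesky(Kxx + obs_noise * jnp.eye(Kxx.shape[0]))
    A = jsl.solve_triangular(L, Kxt, lower=True)    # L⁻¹ Kxt
    b = jsl.solve_triangular(L, y, lower=True)      # L⁻¹ y   (zero mean function)
    mean = A.T @ b

    cov = jax.lax.cond(
        dense,
        lambda: Ktt - A.T @ A,                                             # full covariance
        lambda: jnp.diag(jnp.diag(Ktt) - jnp.einsum("ij,ji->i", A.T, A)),  # diagonal only
    )
    return mean, cov

def k(a, b):
    return jnp.exp(-0.5 * (a[:, None] - b[None, :]) ** 2)

x = jnp.linspace(0.0, 1.0, 10)
t = jnp.linspace(0.0, 1.0, 25)
mean, cov = predict(k(x, x), k(x, t), k(t, t), jnp.sin(x), 0.01, dense=False)
print(mean.shape, cov.shape)   # (25,) (25, 25)
```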
""" + x = train_data.X + y = train_data.y # Observation noise o² obs_noise = jnp.square(self.likelihood.obs_stddev.value) mx = self.prior.mean_function(x) - Kxt = self.prior.kernel.cross_covariance(x, t) - mean_t = self.prior.mean_function(t) + # Precompute Gram matrix, Kxx, at training inputs, x + Kxx = self.prior.kernel.gram(x) + Kxx = add_jitter(Kxx.to_dense(), self.jitter) - def _return_full_covariance( - x: Num[Array, "N D"], - y: Num[Array, "N Q"], - t: Num[Array, "M D"], - ) -> Dense: - - # Precompute Gram matrix, Kxx, at training inputs, x - Kxx = self.prior.kernel.gram(x) - Kxx = add_jitter(Kxx.to_dense(), self.jitter) - - Sigma_dense = Kxx + jnp.eye(Kxx.shape[0]) * obs_noise - Sigma = psd(Dense(Sigma_dense)) - L_sigma = lower_cholesky(Sigma) + Sigma_dense = Kxx + jnp.eye(Kxx.shape[0]) * obs_noise + Sigma = psd(Dense(Sigma_dense)) + L_sigma = lower_cholesky(Sigma) - - Ktt = self.prior.kernel.gram(t) - + Kxt = self.prior.kernel.cross_covariance(x, test_inputs) - L_inv_Kxt = solve(L_sigma, Kxt) - L_inv_y_diff = solve(L_sigma, y - mx) + L_inv_Kxt = solve(L_sigma, Kxt) + L_inv_y_diff = solve(L_sigma, y - mx) - mean = mean_t + jnp.matmul(L_inv_Kxt.T, L_inv_y_diff) - mean = jnp.atleast_1d(mean.squeeze()) + mean_t = self.prior.mean_function(test_inputs) + mean = mean_t + jnp.matmul(L_inv_Kxt.T, L_inv_y_diff) + def _return_full_covariance( + L_inv_Kxt: Num[Array, "N M"], + t: Num[Array, "M D"], + ) -> Dense: + Ktt = self.prior.kernel.gram(t) covariance = Ktt.to_dense() - jnp.matmul(L_inv_Kxt.T, L_inv_Kxt) covariance = add_jitter(covariance, self.prior.jitter) covariance = psd(Dense(covariance)) - return mean, covariance + return covariance def _return_diagonal_covariance( - x: Num[Array, "N D"], - y: Num[Array, "N Q"], + L_inv_Kxt: Num[Array, "N M"], t: Num[Array, "M D"], ) -> Dense: - # Observation noise o² - - # Precompute Gram matrix, Kxx, at training inputs, x - Kxx = self.prior.kernel.diagonal(x).diagonal - Kxx += self.jitter - - Sigma_dense = Kxx + obs_noise - Sigma = psd(Diagonal(Sigma_dense)) - L_sigma = lower_cholesky(Sigma) - - mean_t = self.prior.mean_function(t) Ktt = self.prior.kernel.diagonal(t).diagonal - - L_inv_Kxt = solve(L_sigma, Kxt) - L_inv_y_diff = solve(L_sigma, y - mx) - - mean = mean_t + jnp.matmul(L_inv_Kxt.T, L_inv_y_diff) - mean = jnp.atleast_1d(mean.squeeze()) - covariance = Ktt - jnp.einsum("ij, ji->i", L_inv_Kxt.T, L_inv_Kxt) covariance += self.prior.jitter covariance = psd(Dense(jnp.diag(jnp.atleast_1d(covariance.squeeze())))) - return mean, covariance + return covariance - mu, cov = jax.lax.cond( + cov = jax.lax.cond( return_covariance_type == "dense", - _return_mean_and_full_covariance, - _return_mean_and_diagonal_covariance, - train_data.X, - train_data.y, + _return_full_covariance, + _return_diagonal_covariance, + L_inv_Kxt, test_inputs, ) - return GaussianDistribution(loc=mu, scale=cov) + return GaussianDistribution(loc=jnp.atleast_1d(mean.squeeze()), scale=cov) def sample_approx( self, @@ -775,7 +750,7 @@ def __init__( def predict( self, - test_inputs: Num[Array, "N D"], + test_inputs: Num[Array, "M D"], train_data: Dataset, *, return_covariance_type: Literal["dense", "diagonal"] = "dense", @@ -804,74 +779,45 @@ def predict( input array and returns the predictive distribution as a `dx.Distribution`. 
""" + x = train_data.X + t = test_inputs + mean_function = self.prior.mean_function + kernel = self.prior.kernel - def _return_mean_and_full_covariance( - x: Num[Array, "N D"], - t: Num[Array, "N D"], - ) -> tuple[Float[Array, " N"], Dense]: - mean_function = self.prior.mean_function - kernel = self.prior.kernel + # Precompute lower triangular of Gram matrix + Kxx = kernel.gram(x) + Kxx_dense = add_jitter(Kxx.to_dense(), self.prior.jitter) + Kxx = psd(Dense(Kxx_dense)) + Lx = lower_cholesky(Kxx) - # Precompute lower triangular of Gram matrix - Kxx = kernel.gram(x) - Kxx_dense = add_jitter(Kxx.to_dense(), self.prior.jitter) - Kxx = psd(Dense(Kxx_dense)) - Lx = lower_cholesky(Kxx) + Kxt = kernel.cross_covariance(x, t) + # Lx⁻¹ Kxt + Lx_inv_Kxt = solve(Lx, Kxt) - # Compute terms of the posterior predictive distribution - Ktx = kernel.cross_covariance(t, x) - Ktt = kernel.gram(t) - mean_t = mean_function(t) + mean_t = mean_function(t) + # Whitened function values, wx, corresponding to the inputs, x + wx = self.latent.value - # Lx⁻¹ Kxt - Lx_inv_Kxt = solve(Lx, Ktx.T) + # μt + Ktx Lx⁻¹ wx + mean = mean_t + jnp.matmul(Lx_inv_Kxt.T, wx) - # Whitened function values, wx, corresponding to the inputs, x - wx = self.latent.value - - # μt + Ktx Lx⁻¹ wx - mean = mean_t + jnp.matmul(Lx_inv_Kxt.T, wx) - mean = jnp.atleast_1d(mean.squeeze()) - - # Ktt - Ktx Kxx⁻¹ Kxt, TODO: Take advantage of covariance structure - # to compute Schur complement more efficiently. + def _return_full_covariance( + Lx_inv_Kxt: Num[Array, "N M"], + t: Num[Array, "M D"], + ) -> Dense: + Ktt = kernel.gram(t) covariance = Ktt.to_dense() - jnp.matmul(Lx_inv_Kxt.T, Lx_inv_Kxt) covariance = add_jitter(covariance, self.prior.jitter) covariance = psd(Dense(covariance)) - return mean, covariance + return covariance - def _return_mean_and_diagonal_covariance( - x: Num[Array, "N D"], - t: Num[Array, "N D"], - ) -> tuple[Float[Array, " N"], Dense]: - mean_function = self.prior.mean_function - kernel = self.prior.kernel - - # Precompute lower triangular of Gram matrix - Kxx = kernel.diagonal(x).diagonal - Kxx += self.prior.jitter - Kxx = psd(Diagonal(Kxx)) - Lx = lower_cholesky(Kxx) - - # Compute terms of the posterior predictive distribution - Ktx = kernel.cross_covariance(t, x) - Ktt = kernel.diagonal(t).diagonal[:, jnp.newaxis] - mean_t = mean_function(t) - - # Lx⁻¹ Kxt - Lx_inv_Kxt_diag = jnp.diag(solve(Lx, Ktx.T))[:, jnp.newaxis] - - # Whitened function values, wx, corresponding to the inputs, x - wx = self.latent.value - - # μt + Ktx Lx⁻¹ wx - mean = mean_t + Lx_inv_Kxt_diag * wx - mean = jnp.atleast_1d(mean.squeeze()) - - # Ktt - Ktx Kxx⁻¹ Kxt, TODO: Take advantage of covariance structure - # to compute Schur complement more efficiently. - covariance = Ktt - jnp.square(Lx_inv_Kxt_diag) + def _return_diagonal_covariance( + Lx_inv_Kxt: Num[Array, "N M"], + t: Num[Array, "M D"], + ) -> Dense: + Ktt = kernel.diagonal(t).diagonal + covariance = Ktt - jnp.einsum("ij, ji->i", Lx_inv_Kxt.T, Lx_inv_Kxt) covariance += self.prior.jitter # It would be nice to return a Diagonal here, but the pytree needs # to be the same for both cond branches and the other branch needs @@ -880,17 +826,17 @@ def _return_mean_and_diagonal_covariance( # and hence are not the same pytree anymore. 
covariance = psd(Dense(jnp.diag(jnp.atleast_1d(covariance.squeeze())))) - return mean, covariance + return covariance - mu, cov = jax.lax.cond( + cov = jax.lax.cond( return_covariance_type == "dense", - _return_mean_and_full_covariance, - _return_mean_and_diagonal_covariance, - train_data.X, + _return_full_covariance, + _return_diagonal_covariance, + Lx_inv_Kxt, test_inputs, ) - return GaussianDistribution(mu, cov) + return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), cov) ####################### diff --git a/tests/test_gps.py b/tests/test_gps.py index 2d91809d8..f815c04e7 100644 --- a/tests/test_gps.py +++ b/tests/test_gps.py @@ -107,7 +107,9 @@ def test_prior_with_diag( # test that off diagonal elements are zero assert jnp.all((sigma - jnp.diag(jnp.diag(sigma))) == 0) # test that we return exactly the diagonal of the full covariance - assert jnp.all(jnp.diag(sigma)==jnp.diag(marginal_distribution_full.covariance())) + assert jnp.allclose( + jnp.diag(sigma), jnp.diag(marginal_distribution_full.covariance()) + ) @pytest.mark.parametrize("num_datapoints", [1, 10]) @@ -185,7 +187,9 @@ def test_conjugate_posterior_with_diag( # test that off diagonal elements are zero assert jnp.all((sigma - jnp.diag(jnp.diag(sigma))) == 0) # test that we return exactly the diagonal of the full covariance - assert jnp.all(jnp.diag(sigma)==jnp.diag(marginal_distribution_full.covariance())) + assert jnp.allclose( + jnp.diag(sigma), jnp.diag(marginal_distribution_full.covariance()) + ) @pytest.mark.parametrize("num_datapoints", [1, 10]) @@ -228,7 +232,7 @@ def test_conjugate_posterior( mu = marginal_distribution.mean sigma = marginal_distribution.covariance() assert mu.shape == (num_test_datapoints,) - assert sigma.shape == (num__test_datapoints, num_test_datapoints) + assert sigma.shape == (num_test_datapoints, num_test_datapoints) @pytest.mark.parametrize("num_datapoints", [1, 10]) @@ -265,23 +269,29 @@ def test_nonconjugate_posterior_with_diag( # Query a marginal distribution of the posterior at some inputs. inputs = jnp.linspace(-3.0, 3.0, num_test_datapoints).reshape(-1, 1) - marginal_distribution = posterior(inputs, D, return_covariance_type="diagonal") + marginal_distribution_diag = posterior(inputs, D, return_covariance_type="diagonal") + marginal_distribution_full = posterior(inputs, D, return_covariance_type="dense") # Ensure that the marginal distribution is a Gaussian. - assert isinstance(marginal_distribution, GaussianDistribution) - assert isinstance(marginal_distribution, NumpyroDistribution) + assert isinstance(marginal_distribution_diag, GaussianDistribution) + assert isinstance(marginal_distribution_diag, NumpyroDistribution) # Ensure that the marginal distribution has the correct shape. - mu = marginal_distribution.mean - sigma = marginal_distribution.covariance() + mu = marginal_distribution_diag.mean + sigma = marginal_distribution_diag.covariance() assert mu.shape == (num_test_datapoints,) # We are still returning a full covariance, even though the off diagonal # should all be zeros... 
assert sigma.shape == (num_test_datapoints, num_test_datapoints) assert jnp.all((sigma - jnp.diag(jnp.diag(sigma))) == 0) + # test that we return exactly the diagonal of the full covariance + assert jnp.allclose( + jnp.diag(sigma), jnp.diag(marginal_distribution_full.covariance()) + ) @pytest.mark.parametrize("num_datapoints", [1, 10]) +@pytest.mark.parametrize("num_test_datapoints", [1, 10, 200]) @pytest.mark.parametrize("kernel", [RBF, Matern52]) @pytest.mark.parametrize("mean_function", [Zero, Constant]) def test_nonconjugate_posterior( @@ -313,7 +323,7 @@ def test_nonconjugate_posterior( assert (posterior.latent.value == latent_values).all() # Query a marginal distribution of the posterior at some inputs. - inputs = jnp.linspace(-3.0, 3.0, num_datapoints).reshape(-1, 1) + inputs = jnp.linspace(-3.0, 3.0, num_test_datapoints).reshape(-1, 1) marginal_distribution = posterior(inputs, D) # Ensure that the marginal distribution is a Gaussian. @@ -323,8 +333,8 @@ def test_nonconjugate_posterior( # Ensure that the marginal distribution has the correct shape. mu = marginal_distribution.mean sigma = marginal_distribution.covariance() - assert mu.shape == (num_datapoints,) - assert sigma.shape == (num_datapoints, num_datapoints) + assert mu.shape == (num_test_datapoints,) + assert sigma.shape == (num_test_datapoints, num_test_datapoints) @pytest.mark.parametrize("likelihood", [Bernoulli, Gaussian]) From 57c603937b29e7b63b163f4180600189764c486c Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Sat, 8 Nov 2025 13:58:30 +0100 Subject: [PATCH 16/19] Fix graph kernel shapes --- gpjax/kernels/non_euclidean/graph.py | 35 +++++++++---- tests/test_kernels/test_non_euclidean.py | 63 +++++++++++++++++++++++- 2 files changed, 88 insertions(+), 10 deletions(-) diff --git a/gpjax/kernels/non_euclidean/graph.py b/gpjax/kernels/non_euclidean/graph.py index 90aad5dee..d6065244b 100644 --- a/gpjax/kernels/non_euclidean/graph.py +++ b/gpjax/kernels/non_euclidean/graph.py @@ -15,11 +15,7 @@ import beartype.typing as tp import jax.numpy as jnp -from jaxtyping import ( - Float, - Int, - Num, -) +from jaxtyping import Float, Integer, Num from gpjax.kernels.computations import ( AbstractKernelComputation, @@ -103,11 +99,32 @@ def __init__( def __call__( self, - x: Int[Array, "N 1"], - y: Int[Array, "M 1"], + x: tp.Union[ScalarInt, Integer[Array, " N"], Integer[Array, "N 1"]], + y: tp.Union[ScalarInt, Integer[Array, " M"], Integer[Array, "M 1"]], ): + x_idx = self._prepare_indices(x) + y_idx = self._prepare_indices(y) S = calculate_heat_semigroup(self) - Kxx = (jax_gather_nd(self.eigenvectors, x) * S.squeeze()) @ jnp.transpose( - jax_gather_nd(self.eigenvectors, y) + Kxx = (jax_gather_nd(self.eigenvectors, x_idx) * S.squeeze()) @ jnp.transpose( + jax_gather_nd(self.eigenvectors, y_idx) ) # shape (n,n) return Kxx.squeeze() + + def _prepare_indices( + self, + indices: tp.Union[ScalarInt, Integer[Array, " N"], Integer[Array, "N 1"]], + ) -> Integer[Array, "N 1"]: + """Ensure index arrays are integer column vectors regardless of caller shape.""" + + idx = jnp.asarray(indices, dtype=jnp.int32) + idx = jnp.atleast_1d(idx) + + if idx.ndim > 2: + raise ValueError( + "GraphKernel expects indices shaped (N,) or (N, 1). " + f"Received {idx.shape}." 
+ ) + if idx.ndim == 2 and idx.shape[-1] != 1: + raise ValueError("GraphKernel expects index arrays with a single column.") + + return idx.reshape(-1, 1) diff --git a/tests/test_kernels/test_non_euclidean.py b/tests/test_kernels/test_non_euclidean.py index a4f12e609..1d4d94429 100644 --- a/tests/test_kernels/test_non_euclidean.py +++ b/tests/test_kernels/test_non_euclidean.py @@ -10,7 +10,7 @@ # # See the License for the specific language governing permissions and # # limitations under the License. -from jax import config +from jax import config, jit, vmap import jax.numpy as jnp import networkx as nx @@ -48,3 +48,64 @@ def test_graph_kernel(): Kxx += Identity(Kxx.shape[0]) * 1e-6 eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) assert all(eigen_values > 0) + + +def _build_test_kernel(n_vertices: int = 10) -> GraphKernel: + graph = nx.path_graph(n_vertices) + laplacian = nx.laplacian_matrix(graph).toarray() + jnp.eye(n_vertices) * 1e-12 + return GraphKernel(laplacian=laplacian) + + +def test_graph_kernel_accepts_vector_indices(): + kernel = _build_test_kernel() + col_indices = jnp.arange(5).reshape(-1, 1) + vector_indices = col_indices.squeeze() + + matrix_eval = kernel(col_indices, col_indices) + vector_eval = kernel(vector_indices, vector_indices) + + assert matrix_eval.shape == (5, 5) + assert vector_eval.shape == (5, 5) + assert jnp.allclose(matrix_eval, vector_eval) + + +def test_graph_kernel_vmappable_over_single_indices(): + kernel = _build_test_kernel() + idx = jnp.arange(4) + + diag_entries = vmap(lambda z: kernel(z, z))(idx) + assert diag_entries.shape == (4,) + assert jnp.all(diag_entries >= 0.0) + + +def test_graph_kernel_vmappable_over_pairs(): + kernel = _build_test_kernel() + x = jnp.arange(5) + y = jnp.array([4, 3, 2, 1, 0]) + + vectorised_eval = vmap(lambda a, b: kernel(a, b))(x, y) + baseline = jnp.asarray([kernel(a, b) for a, b in zip(x, y)]) + + assert vectorised_eval.shape == (5,) + assert jnp.allclose(vectorised_eval, baseline) + + +def test_graph_kernel_is_jittable(): + kernel = _build_test_kernel() + jit_kernel = jit(lambda a, b: kernel(a, b)) + + column = jnp.arange(5).reshape(-1, 1) + vector = jnp.arange(5) + pairs = [ + (0, 0), + (0, column), + (vector, 1), + (vector, vector), + (column, column), + ] + + for x, y in pairs: + expected = kernel(x, y) + result = jit_kernel(x, y) + assert result.shape == expected.shape + assert jnp.allclose(result, expected) From 41c51c6fff3397a4f27819c4e3eae9338e3e471a Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Sat, 8 Nov 2025 14:02:15 +0100 Subject: [PATCH 17/19] Fix graph kernel shapes --- gpjax/kernels/non_euclidean/graph.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/gpjax/kernels/non_euclidean/graph.py b/gpjax/kernels/non_euclidean/graph.py index d6065244b..97c7a5d5e 100644 --- a/gpjax/kernels/non_euclidean/graph.py +++ b/gpjax/kernels/non_euclidean/graph.py @@ -118,13 +118,4 @@ def _prepare_indices( idx = jnp.asarray(indices, dtype=jnp.int32) idx = jnp.atleast_1d(idx) - - if idx.ndim > 2: - raise ValueError( - "GraphKernel expects indices shaped (N,) or (N, 1). " - f"Received {idx.shape}." 
- ) - if idx.ndim == 2 and idx.shape[-1] != 1: - raise ValueError("GraphKernel expects index arrays with a single column.") - return idx.reshape(-1, 1) From 02413518495833b50eb247244d39cb6cc5372996 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Sat, 8 Nov 2025 14:09:48 +0100 Subject: [PATCH 18/19] Tidy typing --- gpjax/kernels/non_euclidean/graph.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gpjax/kernels/non_euclidean/graph.py b/gpjax/kernels/non_euclidean/graph.py index 97c7a5d5e..79e8af4cb 100644 --- a/gpjax/kernels/non_euclidean/graph.py +++ b/gpjax/kernels/non_euclidean/graph.py @@ -99,8 +99,8 @@ def __init__( def __call__( self, - x: tp.Union[ScalarInt, Integer[Array, " N"], Integer[Array, "N 1"]], - y: tp.Union[ScalarInt, Integer[Array, " M"], Integer[Array, "M 1"]], + x: ScalarInt | Integer[Array, " N"] | Integer[Array, "N 1"], + y: ScalarInt | Integer[Array, " M"] | Integer[Array, "M 1"], ): x_idx = self._prepare_indices(x) y_idx = self._prepare_indices(y) @@ -112,7 +112,7 @@ def __call__( def _prepare_indices( self, - indices: tp.Union[ScalarInt, Integer[Array, " N"], Integer[Array, "N 1"]], + indices: ScalarInt | Integer[Array, " N"] | Integer[Array, "N 1"], ) -> Integer[Array, "N 1"]: """Ensure index arrays are integer column vectors regardless of caller shape.""" From 9687c8f2a4d7ed4494b91c5761013568ae3481ec Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Sat, 8 Nov 2025 15:45:38 +0100 Subject: [PATCH 19/19] Linting --- gpjax/kernels/non_euclidean/graph.py | 6 +++++- tests/test_kernels/test_non_euclidean.py | 8 ++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/gpjax/kernels/non_euclidean/graph.py b/gpjax/kernels/non_euclidean/graph.py index 79e8af4cb..30a416a4e 100644 --- a/gpjax/kernels/non_euclidean/graph.py +++ b/gpjax/kernels/non_euclidean/graph.py @@ -15,7 +15,11 @@ import beartype.typing as tp import jax.numpy as jnp -from jaxtyping import Float, Integer, Num +from jaxtyping import ( + Float, + Integer, + Num, +) from gpjax.kernels.computations import ( AbstractKernelComputation, diff --git a/tests/test_kernels/test_non_euclidean.py b/tests/test_kernels/test_non_euclidean.py index 1d4d94429..e09717d03 100644 --- a/tests/test_kernels/test_non_euclidean.py +++ b/tests/test_kernels/test_non_euclidean.py @@ -10,7 +10,11 @@ # # See the License for the specific language governing permissions and # # limitations under the License. -from jax import config, jit, vmap +from jax import ( + config, + jit, + vmap, +) import jax.numpy as jnp import networkx as nx @@ -84,7 +88,7 @@ def test_graph_kernel_vmappable_over_pairs(): y = jnp.array([4, 3, 2, 1, 0]) vectorised_eval = vmap(lambda a, b: kernel(a, b))(x, y) - baseline = jnp.asarray([kernel(a, b) for a, b in zip(x, y)]) + baseline = jnp.asarray([kernel(a, b) for a, b in zip(x, y, strict=False)]) assert vectorised_eval.shape == (5,) assert jnp.allclose(vectorised_eval, baseline)
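Taken together, the patches above change how predictive covariances are requested from a posterior. Below is a minimal usage sketch of the resulting call pattern; it is not part of the patch series. The package-level names (gpx.Dataset, gpx.gps.Prior, gpx.kernels.RBF, gpx.mean_functions.Zero, gpx.likelihoods.Gaussian) and the prior-times-likelihood construction are assumed from the wider library, while the `return_covariance_type` keyword and the `.mean`/`.covariance()` accessors follow the test code in this series.

    import jax
    import jax.numpy as jnp
    import gpjax as gpx

    # Double precision, as used throughout the test suite (assumed setting).
    jax.config.update("jax_enable_x64", True)

    # Toy 1-D regression data; shapes mirror the tests above.
    x = jnp.linspace(-3.0, 3.0, 20).reshape(-1, 1)
    y = jnp.sin(x)
    D = gpx.Dataset(X=x, y=y)  # assumed constructor for the Dataset type used by predict()

    prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF())
    posterior = prior * gpx.likelihoods.Gaussian(num_datapoints=20)

    t = jnp.linspace(-3.0, 3.0, 50).reshape(-1, 1)

    # Default: the full M x M predictive covariance is built.
    dense_dist = posterior(t, D, return_covariance_type="dense")

    # Diagonal: only the marginal variances are computed; they are placed on the
    # diagonal of an otherwise-zero dense matrix so that both lax.cond branches
    # return the same pytree structure (see the comment in the diff above).
    diag_dist = posterior(t, D, return_covariance_type="diagonal")

    # The two paths agree on the marginal variances, as the tests assert.
    assert jnp.allclose(
        jnp.diag(diag_dist.covariance()),
        jnp.diag(dense_dist.covariance()),
    )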
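Likewise, a short sketch of the index shapes that GraphKernel now accepts after the `_prepare_indices` change: scalars, (N,) vectors, and (N, 1) columns are all normalised to column form before gathering eigenvectors. Only the constructor keyword and call pattern shown in the tests are relied on; the import path is an assumption.

    import jax.numpy as jnp
    import networkx as nx
    from gpjax.kernels.non_euclidean import GraphKernel  # assumed import path

    graph = nx.path_graph(10)
    laplacian = nx.laplacian_matrix(graph).toarray() + jnp.eye(10) * 1e-12
    kernel = GraphKernel(laplacian=laplacian)

    # Scalar vertex indices give a single covariance value.
    single = kernel(0, 3)

    # 1-D and column-vector indices give the same (5, 5) covariance block.
    vector_block = kernel(jnp.arange(5), jnp.arange(5))
    column_block = kernel(jnp.arange(5).reshape(-1, 1), jnp.arange(5).reshape(-1, 1))
    assert jnp.allclose(vector_block, column_block)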