Conversation

@mathDR (Contributor) commented Nov 16, 2025

MultiOutput GP Proposal.

In the spirit of multi-output GP frameworks like GPFlow, we will adopt the same problem statement:

Problem Statement

We will consider a regression problem for functions $f: \mathbb{R}^D \rightarrow \mathbb{R}^Q$. We assume that the dataset is of the form $(X_1, f_1), \dots, (X_Q, f_Q)$, that is, we may observe different inputs for each output dimension.

Here we assume a model of the form:
$$f(x) = W g(x), $$
where $g(x) \in \mathbb{R}^L$, $f(x) \in \mathbb{R}^Q$ and $W \in \mathbb{R}^{Q \times L}$. We assume that the outputs of $g$ are uncorrelated, and that by mixing them with $W$ they become correlated.

Note, we have two options for $g$:

  1. The output dimensions of $g$ share the _same_ kernel.
  2. Each output of $g$ has a separate kernel.

In a following PR we can discuss variational GPs, where we need a further sub-option for the inducing inputs of $g$:

  1. The instances of $g$ share the same inducing inputs.
  2. Each output of $g$ has its own set of inducing inputs.

The notation is as follows:

  • $X_i \in \mathbb{R}^{N \times D}$, $i = 1, \ldots, Q$ denotes the input
  • $Y \in \mathbb{R}^{N \times Q}$ denotes the output
  • $k_1, \ldots, k_L$ are $L$ kernels on $\mathbb{R}^{N \times D}$
  • $g_1, \ldots, g_L$ are $L$ independent GPs with $g_l \sim \mathcal{GP}(0, k_l)$
  • $f_1, \ldots, f_Q$ are $Q$ correlated GPs with $\mathbf{f} = \mathbf{W} \mathbf{g}$

Phase 1:

We write a multi-output kernel that initially just takes in a single kernel. Mimicking GPFlow again, we could have something like:

"""MultiOutput Kernels."""

from abc import abstractmethod

import jax
import jax.numpy as jnp
from jaxtyping import Num

from gpjax.kernels import AbstractKernel
from gpjax.typing import Array


# TODO describe various output shapes
class MultioutputKernel(AbstractKernel):
    """
    Multi Output Kernel class.

    This kernel can represent correlation between outputs of different datapoints.

    The `full_output_cov` argument controls whether the kernel should calculate
    the covariance between the outputs. If there is no correlation but
    `full_output_cov` is set to `True`, the covariance matrix is padded with zeros
    to the appropriate size.
    """

    @property
    @abstractmethod
    def num_latent_gps(self) -> int:
        """The number of latent GPs in the multioutput kernel"""
        raise NotImplementedError

    @property
    @abstractmethod
    def latent_kernels(self) -> tuple[AbstractKernel, ...]:
        """The underlying kernels in the multioutput kernel"""
        raise NotImplementedError

    @abstractmethod
    def K(
        self,
        X: Num[Array, "N D"],
        X2: Num[Array, "M D"] | None = None,
        full_output_cov: bool = True,
    ) -> Array:
        """
        Returns the correlation of f(X) and f(X2), where f(.) can be multi-dimensional.

        :param X: data matrix
        :param X2: data matrix
        :param full_output_cov: calculate correlation between outputs.
        :return: cov[f(X), f(X2)]
        """
        raise NotImplementedError

    @abstractmethod
    def K_diag(self, X: Num[Array, "N D"], full_output_cov: bool = True) -> Array:
        """
        Returns the correlation of f(X) and f(X), where f(.) can be multi-dimensional.

        :param X: data matrix
        :param full_output_cov: calculate correlation between outputs.
        :return: var[f(X)]
        """
        raise NotImplementedError

    def __call__(
        self,
        X: Num[Array, "N D"],
        X2: Num[Array, "M D"] | None = None,
        *,
        full_cov: bool = False,
        full_output_cov: bool = True,
    ) -> Array:
        if not full_cov and X2 is not None:
            raise ValueError(
                "Ambiguous inputs: passing in `X2` is not compatible with `full_cov=False`."
            )
        if not full_cov:
            return self.K_diag(X, full_output_cov=full_output_cov)
        return self.K(X, X2, full_output_cov=full_output_cov)


class SharedIndependent(MultioutputKernel):
    """
    - Shared: we use the same kernel for each latent GP
    - Independent: Latents are uncorrelated a priori.
    """

    def __init__(self, kernel: AbstractKernel, output_dim: int) -> None:
        super().__init__()
        self.kernel = kernel
        self.output_dim = output_dim

    @property
    def num_latent_gps(self) -> int:
        # In this case the number of latent GPs (L) == output_dim (Q)
        return self.output_dim

    @property
    def latent_kernels(self) -> tuple[AbstractKernel, ...]:
        """The underlying kernels in the multioutput kernel"""
        return (self.kernel,)

    def K(
        self,
        X: Num[Array, "N D"],
        X2: Num[Array, "M D"] | None = None,
        full_output_cov: bool = True,
    ) -> Array:
        if X2 is None:
            # NOTE: in GPJax, gram() returns a linear operator, so densify before use.
            K = self.kernel.gram(X).to_dense()
        else:
            K = self.kernel.cross_covariance(X, X2)

        if full_output_cov:
            # The latents are independent a priori, so the full (N*Q, M*Q) covariance
            # is block-diagonal with K repeated along the diagonal.
            return jnp.kron(jnp.eye(self.num_latent_gps, dtype=K.dtype), K)
        else:
            # One shared (N, M) block per latent GP: shape (L, N, M).
            return jnp.tile(K[None, ...], (self.num_latent_gps, 1, 1))

    def K_diag(self, X: Num[Array, "N D"], full_output_cov: bool = True) -> Array:
        # Per-point variances k(x_n, x_n), shared across the L latent GPs: shape (N, L).
        K = jax.vmap(lambda x: self.kernel(x, x))(X)
        Ks = jnp.tile(K[:, None], (1, self.num_latent_gps))
        # (N, L, L) batched diagonal matrices if full_output_cov, else (N, L).
        return jax.vmap(jnp.diag)(Ks) if full_output_cov else Ks

We could then add a SeparateIndependent kernel where we pass in a sequence of kernels and apply one kernel to each of the num_latent_gps latent dimensions, as sketched below.
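
A minimal sketch of what that could look like, extending the MultioutputKernel base class above. The class name and method bodies mirror GPFlow's SeparateIndependent and are only illustrative; in particular, I am assuming gram() returns a linear operator that we can densify with to_dense().

from collections.abc import Sequence

import jax
import jax.scipy as jsp


class SeparateIndependent(MultioutputKernel):
    """
    - Separate: each latent GP has its own kernel.
    - Independent: latents are uncorrelated a priori.
    """

    def __init__(self, kernels: Sequence[AbstractKernel]) -> None:
        super().__init__()
        self.kernels = list(kernels)

    @property
    def num_latent_gps(self) -> int:
        return len(self.kernels)

    @property
    def latent_kernels(self) -> tuple[AbstractKernel, ...]:
        return tuple(self.kernels)

    def K(
        self,
        X: Num[Array, "N D"],
        X2: Num[Array, "M D"] | None = None,
        full_output_cov: bool = True,
    ) -> Array:
        # One (N, M) block per latent kernel.
        if X2 is None:
            blocks = [k.gram(X).to_dense() for k in self.kernels]
        else:
            blocks = [k.cross_covariance(X, X2) for k in self.kernels]
        if full_output_cov:
            # Latents are independent, so cross-output blocks are zero.
            return jsp.linalg.block_diag(*blocks)
        return jnp.stack(blocks, axis=0)  # (L, N, M)

    def K_diag(self, X: Num[Array, "N D"], full_output_cov: bool = True) -> Array:
        # Per-point variances k_l(x_n, x_n) for each latent kernel: shape (N, L).
        Ks = jnp.stack(
            [jax.vmap(lambda x: k(x, x))(X) for k in self.kernels], axis=-1
        )
        return jax.vmap(jnp.diag)(Ks) if full_output_cov else Ks  # (N, L, L) or (N, L)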

Phase 2

We add a MultiLatentPrior class (probably starting with the conjugate case) where, instead of returning a GaussianDistribution, we develop a MatrixNormalDistribution class that inherits from numpyro's MatrixNormal and allows us to sample.
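
As a very rough sketch of the sampling piece, assuming numpyro's MatrixNormal parameterisation via row/column Cholesky factors (the helper name and jitter value below are purely illustrative):

  import jax.numpy as jnp
  import numpyro.distributions as npd


  def matrix_normal_prior(K, WWT, jitter=1e-6):
      # Matrix normal MN(0, K, W W^T): row covariance K (N x N) over inputs,
      # column covariance W W^T (Q x Q) over outputs.
      N, Q = K.shape[0], WWT.shape[0]
      L_row = jnp.linalg.cholesky(K + jitter * jnp.eye(N))
      L_col = jnp.linalg.cholesky(WWT + jitter * jnp.eye(Q))
      return npd.MatrixNormal(
          loc=jnp.zeros((N, Q)),
          scale_tril_row=L_row,
          scale_tril_column=L_col,
      )

Sampling would then be, e.g., matrix_normal_prior(K, W @ W.T).sample(jax.random.PRNGKey(0)), returning draws of shape (N, Q).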

I am still fleshing this out, but overall it seems as if we need to add a multi-output kernel and a multi-output prior of some sort.

Thoughts?

@github-actions commented

Thank you for opening your first PR into GPJax!

If you have not heard from us in a while, please feel free to ping
@gpjax/developers or anyone who has commented on the PR. Most of our
reviewers are volunteers and sometimes things fall through the cracks.

You can also join us on Slack for real-time discussion.

For details on testing, writing docs, and our review process, please see
the developer guide.

We strive to be a welcoming and open project. Please follow our Code of
Conduct.

@mathDR marked this pull request as draft on November 16, 2025 at 21:21
@thomaspinder (Owner) commented

Thanks for the detail here @mathDR - give me a few days to review and provide comments.

@mathDR (Contributor, Author) commented Nov 17, 2025

Great @thomaspinder! I am also working on an MVP for the multi-latent prior. When I get that up, I will post it here.

@daniel-dodd any thoughts/comments you want to make?

@mathDR (Contributor, Author) commented Nov 18, 2025

Some questions arising as I build this out.
Assume for the following that the input dimension is $D$, the number of latent GPs is $L$, and the output dimension of $y$ is $Q$.

If we have a single kernel (i.e., SharedIndependent), then we need $L = Q$. Similarly for SeparateIndependent, where each output of $y$ has its own kernel.

When we have a full LMC kernel with a mixing matrix $W$, $L$ and $Q$ do not have to be equal.

For input data:
We could have a few different APIs:

  1. a single Dataset with $X \in \mathbb{R}^{N \times D}$ and $y \in \mathbb{R}^{N \times Q}$ and the kernel "knows" to stack $y$ appropriately.
  2. a single Dataset with $X \in \mathbb{R}^{N \times D}$ and $y \in \mathbb{R}^{NQ}$ -- that is, $y$ is already stacked, and the Posterior predict() function "knows" how to reshape predictions to match $N_{\text{predict}} \times Q$ (a sketch of this stacking is below).
  3. we pass in $Q$ Datasets, each having $X_i \in \mathbb{R}^{N_i \times D}$ and $y_i \in \mathbb{R}^{N_i}$ -- that is, each component can have possibly different input locations (and counts).

We should allow for all of these in MultiLatentPrior initialization, don't you think?
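
To make option 2 concrete, the stacking could look like the following (shapes only; how this ultimately plugs into the Dataset/prior API is exactly the open question):

  import jax.numpy as jnp

  N, D, Q = 50, 3, 4
  X = jnp.ones((N, D))
  Y = jnp.ones((N, Q))

  # Stack the output columns into one long vector and tile the inputs to match,
  # so a single Dataset (which expects matching row counts) can hold everything.
  y_stacked = Y.T.reshape(-1, 1)    # (N * Q, 1): [y_1; y_2; ...; y_Q]
  X_stacked = jnp.tile(X, (Q, 1))   # (N * Q, D)

  # predict() would then reshape a (N_predict * Q,) mean back to (N_predict, Q).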

@thomaspinder (Owner) commented Nov 20, 2025

Thanks for putting this together @mathDR. Some unstructured thoughts.

  1. It might be clearer and better aligned with the maths if we allow the user to directly specify/initialise the mixing matrix ($W$). Something like:
  mixing = Real(jnp.eye(num_outputs, rank)) 
  latent = [gpx.kernels.RBF(lengthscale=...) for _ in range(rank)]
  kernel = gpx.kernels.MultiOutputKernel(
      latent_kernels=latent,
      mixing=mixing,
      num_outputs=num_outputs,
  )
  meanf = gpx.mean_functions.Zero(output_dims=num_outputs)
  prior = gpx.gps.MultiOutputPrior(mean_function=meanf, kernel=kernel)

What do you think? I think this allows us to repurpose a lot of your code whilst keeping things very close to the maths.
2. We should be very mindful of how we create the data. In fact, based on previous efforts, it's equally important that we get this right alongside the kernel. We need to be able to correctly align the data with the kernels and leverage the Kronecker linear operator, otherwise operations will be unnecessarily expensive.
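
For point 2, a rough illustration of what leaning on the Kronecker structure could look like, assuming cola (which recent GPJax versions use for structured linear algebra); the shapes and matrices here are placeholders:

  import cola
  import jax.numpy as jnp

  N, Q = 500, 4
  K = jnp.eye(N)   # stand-in for the (N, N) kernel Gram matrix
  B = jnp.eye(Q)   # stand-in for W @ W.T, the (Q, Q) output covariance

  # Structured (N*Q, N*Q) covariance that is never materialised densely.
  # (The Kronecker ordering depends on how y is stacked.)
  Sigma = cola.ops.Kronecker(cola.ops.Dense(K), cola.ops.Dense(B))

  # Downstream solves and log-determinants can then exploit this structure
  # rather than forming jnp.kron(K, B) explicitly.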

@mathDR (Contributor, Author) commented Nov 20, 2025

Okay for this simple case (one kernel repeated for each output), we can use the following:
$$\operatorname{vec}(y) \sim \mathcal{N}\big(\operatorname{vec}(0),\, \Sigma + \sigma^2 \mathbf{I}_{NQ}\big), \qquad \Sigma = K \otimes W W^\top$$

If the covariance of the above multivariate normal were just $K \otimes W W^\top$, then we could use a matrix normal $\mathcal{MN}(0, K, W W^\top)$, but I don't know whether that still holds once the noise term is added. In any case, we can use a numpyro MultivariateNormal distribution.
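
A dense reference construction for this shared-kernel case could be (the helper name is illustrative; fine for small $NQ$, whereas a Kronecker-structured path is what we would want at scale):

  import jax.numpy as jnp
  import numpyro.distributions as npd


  def shared_lmc_marginal(K, W, noise_variance):
      # vec(y) ~ N(0, K ⊗ W Wᵀ + σ² I) with K of shape (N, N) and W of shape (Q, L).
      N, Q = K.shape[0], W.shape[0]
      Sigma = jnp.kron(K, W @ W.T) + noise_variance * jnp.eye(N * Q)
      return npd.MultivariateNormal(loc=jnp.zeros(N * Q), covariance_matrix=Sigma)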

@mathDR (Contributor, Author) commented Nov 21, 2025

In the general case with $L$ latent kernels, we need to compute:
$$\operatorname{vec}(y) \sim \mathcal{N}\big(\operatorname{vec}(0),\, \Sigma + \sigma^2 \mathbf{I}_{NQ}\big), \qquad \Sigma = \sum_l K_l \otimes w_l w_l^\top$$
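
A dense reference implementation of that sum (function name illustrative; again, for real use we would want a structured/Kronecker representation rather than building the full matrix):

  import jax.numpy as jnp


  def lmc_covariance(Ks, W):
      # Σ = Σ_l K_l ⊗ w_l w_lᵀ, with Ks a list of L (N, N) Gram matrices and
      # W the (Q, L) mixing matrix whose l-th column is w_l.
      N, Q = Ks[0].shape[0], W.shape[0]
      Sigma = jnp.zeros((N * Q, N * Q))
      for l, K_l in enumerate(Ks):
          w_l = W[:, l : l + 1]                       # (Q, 1)
          Sigma = Sigma + jnp.kron(K_l, w_l @ w_l.T)  # (N * Q, N * Q)
      return Sigma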

@mathDR (Contributor, Author) commented Nov 21, 2025

Okay @thomaspinder, I like your idea, but I would say we should add the mixing matrix ($W$) to the prior:

  mixing = Real(jnp.eye(num_outputs, rank)) 
  latent = [gpx.kernels.RBF(lengthscale=...) for _ in range(rank)]
  kernel = gpx.kernels.MultiOutputKernel(
      latent_kernels=latent,
      num_outputs=num_outputs,
  )
  meanf = gpx.mean_functions.Zero(output_dims=num_outputs)
  prior = gpx.gps.MultiOutputPrior(mean_function=meanf, kernel=kernel, mixing=mixing)

then the "kernel" produces $g(x)$ and the prior produces $f(x) = W g(x).$. Does that make sense?

@mathDR (Contributor, Author) commented Nov 21, 2025

Also, not for this PR, but when we allow for variational inference, we will have a new get_posterior() method that (probably) matches something like what GPFlow has in this area.

So we would have a SharedIndependent, a SeparateIndependent, and a LinearCoregionalization set of multi-output kernels, and we would have a matching prior for each.

For the variational model(s), we would need to specify, for the latter two kernels, whether we want the inducing points shared among all of the kernels or each kernel having its own set of inducing points.

But as I said, that is for a later PR.

@mathDR (Contributor, Author) commented Nov 21, 2025

For this PR, we would implement:

  1. MultiOutputKernels
    a. SharedIndependent
    b. SeparateIndependent
    c. LinearCoregionalization
  2. Priors/Posteriors for the above. We assume only the conjugate case for this PR (a later PR would add the non-conjugate case). There would be a one-to-one match between them, so:
    a. SharedIndependentPrior; SharedIndependentConjugatePosterior
    b. SeparateIndependentPrior; SeparateIndependentConjugatePosterior
    c. LinearCoregionalizationPrior; LinearCoregionalizationConjugatePosterior

Thoughts?

@thomaspinder (Owner) commented

OK. I'm good with leaving the mixing component in the prior; this is fine. I also align with your proposal for this PR.
