bug: Heteroscedastic noise gaussian processes #524

@VadimBim

I'm not sure if it's a bug; most probably, I am doing something wrong.

GPJax version: 0.11.2

Hi! I want to fit a model to some noisy measurements and decided that GPs are a natural framework for this task. I started by adapting the regression tutorial to my data. Here is a minimal example:

import gpjax as gpx

from jax import config
import jax.numpy as jnp

config.update("jax_enable_x64", True)

x = jnp.array([0.1, 0.2, 0.5, 2.0]).reshape(-1, 1)
y = jnp.array([1.48, 2.2, 5.18, 5.62,]).reshape(-1, 1)
y_error = jnp.array([1e-3, 1e-2, 1e-1, 1.0]).reshape(-1, 1)
x_test = jnp.linspace(0, 2.0, 100)

dataset = gpx.Dataset(X=x, y=y)

#construct gp
kernel = gpx.kernels.RBF(lengthscale=0.25, variance=0.8)  # 1-dimensional input
meanf = gpx.mean_functions.Constant(jnp.max(dataset.y))
prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
likelihood = gpx.likelihoods.Gaussian(num_datapoints=dataset.n, obs_stddev=y_error)
posterior = prior * likelihood

print(f"obs_stddev before optimization: {likelihood.obs_stddev.value}")

# fit using bfgs
opt_posterior, history = gpx.fit_scipy(
    model=posterior,
    # we use the negative mll as we are minimising
    objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d),
    train_data=dataset,
)

print(f"obs_stddev after optimization: {opt_posterior.likelihood.obs_stddev.value}")

latent_dist = opt_posterior.predict(x_test.reshape(-1, 1), dataset)
predictive_dist = opt_posterior.likelihood(latent_dist)

Output and traceback:

obs_stddev before optimization: [[0.001]
 [0.01 ]
 [0.1  ]
 [1.   ]]
Optimization terminated successfully.
         Current function value: 6.641228
         Iterations: 61
         Function evaluations: 62
         Gradient evaluations: 62
obs_stddev after optimization: [[1.39191897e-005]
 [5.68286249e-261]
 [0.00000000e+000]
 [0.00000000e+000]]
Traceback (most recent call last):
  File "mvp.py", line 44, in <module>
    predictive_dist = posterior.likelihood(latent_dist)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/envs/default/lib/python3.11/site-packages/gpjax/likelihoods.py", line 73, in __call__
    return self.predict(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/envs/default/lib/python3.11/site-packages/gpjax/likelihoods.py", line 187, in predict
    noisy_cov = cov.at[jnp.diag_indices(n_data)].add(self.obs_stddev.value**2)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "envs/default/lib/python3.11/site-packages/jax/_src/numpy/array_methods.py", line 834, in add
    return scatter._scatter_update(self.array, self.index, values,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "envs/default/lib/python3.11/site-packages/jax/_src/ops/scatter.py", line 77, in _scatter_update
    return _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/envs/default/lib/python3.11/site-packages/jax/_src/ops/scatter.py", line 112, in _scatter_impl
    y = jnp.broadcast_to(y, tuple(indexer.slice_shape))
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/envs/default/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 3138, in broadcast_to
    return util._broadcast_to(array, shape)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/envs/default/lib/python3.11/site-packages/jax/_src/numpy/util.py", line 271, in _broadcast_to
    raise ValueError(f"Cannot broadcast to shape with fewer dimensions: {arr_shape=} {shape=}")
ValueError: Cannot broadcast to shape with fewer dimensions: arr_shape=(4, 1) shape=(100,)
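
My reading of the traceback is that the (4, 1) array of per-training-point noise is being added to a diagonal of length 100, one entry per test point. Below is a minimal NumPy analogue of that single step, as a sketch only: the identity matrix is a stand-in for the real predictive covariance, and NumPy's error text differs from the JAX message above.

```python
import numpy as np

n_test = 100
obs_stddev = np.array([1e-3, 1e-2, 1e-1, 1.0]).reshape(-1, 1)  # shape (4, 1)
cov = np.eye(n_test)  # stand-in for the (100, 100) predictive covariance

# A scalar noise level broadcasts over all 100 diagonal entries.
scalar_cov = cov.copy()
scalar_cov[np.diag_indices(n_test)] += 0.1**2

# One noise value per training point cannot be spread over 100 test points.
try:
    vector_cov = cov.copy()
    vector_cov[np.diag_indices(n_test)] += obs_stddev**2
    mismatch = None
except ValueError as err:
    mismatch = err

print(type(mismatch).__name__)
```

So the failure appears to come from the shape of the noise itself, not from the optimization: even the unoptimized `y_error` would not fit on a 100-point diagonal.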

Expected behavior:

  1. I expected the optimizer to leave the observation noise untouched, since it is "ground truth" from the measurements. Am I missing something?
  2. The dimension mismatch that happens here could be resolved by building a noise array of the same size as the diagonal and keeping the noise value only where a test point is close to one of the training x points. A naive solution that worked for my 1D case:
    def predict(
        self,
        train_ds: Dataset, 
        x_test: Array,
        dist: tp.Union[npd.MultivariateNormal, GaussianDistribution]
    ) -> npd.MultivariateNormal:
        r"""Evaluate the Gaussian likelihood.

        Evaluate the Gaussian likelihood function at a given predictive
        distribution. Computationally, this is equivalent to summing the
        observation noise term to the diagonal elements of the predictive
        distribution's covariance matrix.

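        Args:
            train_ds (Dataset): The training data; its inputs `X` are the
                points that the entries of `obs_stddev` belong to.
            x_test (Array): The test inputs at which `dist` was evaluated.
            dist (npd.Distribution): The Gaussian process posterior,
                evaluated at a finite set of test points.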

        Returns:
            npd.Distribution: The predictive distribution.
        """
        n_data = dist.event_shape[0]
        cov = dist.covariance_matrix
        diag_mask = jnp.squeeze(vmap(lambda train_x: jnp.isclose(x_test - train_x, 0.0, atol=1e-2))(train_ds.X).sum(axis=0, dtype=bool))
        assert diag_mask.sum() == jnp.size(self.obs_stddev)
        diag_noise = jnp.zeros_like(diag_mask, dtype=self.obs_stddev.value.dtype).at[diag_mask].set(jnp.squeeze(self.obs_stddev.value**2))
        noisy_cov = cov.at[jnp.diag_indices(n_data)].add(diag_noise)

        return npd.MultivariateNormal(dist.mean, noisy_cov)
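
To make the expectation in point 1 concrete: with `y_error` treated as fixed, the noise only enters the training covariance as K + diag(y_error**2), and the latent posterior at the test points needs no noise term at all. Here is a minimal NumPy sketch of that computation. It is not GPJax code: the helper names `rbf` and `latent_posterior` are made up for illustration, and the kernel hyperparameters are simply held at the initial values from my example instead of being optimized.

```python
import numpy as np


def rbf(a, b, lengthscale=0.25, variance=0.8):
    """Squared-exponential kernel between column vectors a (n, 1) and b (m, 1)."""
    sq_dist = (a - b.T) ** 2
    return variance * np.exp(-0.5 * sq_dist / lengthscale**2)


def latent_posterior(x, y, y_error, x_test, mean=0.0):
    """Posterior over the latent function with fixed per-point noise (hypothetical helper)."""
    # Heteroscedastic noise: one variance per training point on the diagonal.
    k_xx = rbf(x, x) + np.diag(y_error.ravel() ** 2)
    k_sx = rbf(x_test, x)
    k_ss = rbf(x_test, x_test)

    # Cholesky-based solve of (K + diag(noise)) alpha = y - mean.
    chol = np.linalg.cholesky(k_xx)
    alpha = np.linalg.solve(chol.T, np.linalg.solve(chol, y - mean))
    v = np.linalg.solve(chol, k_sx.T)

    post_mean = mean + k_sx @ alpha
    post_cov = k_ss - v.T @ v
    return post_mean, post_cov


x = np.array([0.1, 0.2, 0.5, 2.0]).reshape(-1, 1)
y = np.array([1.48, 2.2, 5.18, 5.62]).reshape(-1, 1)
y_error = np.array([1e-3, 1e-2, 1e-1, 1.0]).reshape(-1, 1)
x_test = np.linspace(0, 2.0, 100).reshape(-1, 1)

post_mean, post_cov = latent_posterior(x, y, y_error, x_test, mean=float(y.max()))
print(post_mean.shape, post_cov.shape)
```

In this sketch the measurement errors are never parameters, so there is nothing for an optimizer to change. Adding observation noise at the test points would need a noise value for each of those points, which is the gap the `diag_mask` workaround above tries to fill.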
