1515
1616# from __future__ import annotations
1717from abc import abstractmethod
18- from dataclasses import dataclass , field
18+ from dataclasses import (
19+ dataclass ,
20+ field ,
21+ )
1922from typing import overload
2023
2124from beartype .typing import (
2528)
2629import cola
2730from cola .ops import Dense
28-
2931import jax .numpy as jnp
3032from jax .random import (
3133 PRNGKey ,
4749 ReshapedDistribution ,
4850 ReshapedGaussianDistribution ,
4951)
50- from gpjax .kernels import RFF , White
52+ from gpjax .kernels import (
53+ RFF ,
54+ White ,
55+ )
5156from gpjax .kernels .base import AbstractKernel
5257from gpjax .likelihoods import (
5358 AbstractLikelihood ,
@@ -503,7 +508,7 @@ def predict(
503508 n_test = len (test_inputs )
504509
505510 # Observation noise o²
506- obs_noise = self .likelihood .obs_noise
511+ obs_var = self .likelihood .obs_stddev ** 2
507512 mx = self .prior .mean_function (x )
508513
509514 # Precompute Gram matrix, Kxx, at training inputs, x
@@ -512,7 +517,7 @@ def predict(
512517
513518 # Σ = Kxx + Io²
514519 Sigma = cola .ops .Kronecker (Kxx , Kyy )
515- Sigma += cola .ops .I_like (Sigma ) * (obs_noise + self .jitter )
520+ Sigma += cola .ops .I_like (Sigma ) * (obs_var + self .jitter )
516521 Sigma = cola .PSD (Sigma )
517522
518523 if mask is not None :
@@ -606,13 +611,10 @@ def sample_approx(
606611
607612 # sample weights v for canonical features
608613 # v = Σ⁻¹ (y + ε - ɸ⍵) for Σ = Kxx + Io² and ε ᯈ N(0, o²)
614+ obs_var = self .likelihood .obs_stddev ** 2
609615 Kxx = self .prior .kernel .gram (train_data .X ) # [N, N]
610- Sigma = Kxx + cola .ops .I_like (Kxx ) * (
611- self .likelihood .obs_noise + self .jitter
612- ) # [N, N]
613- eps = jnp .sqrt (self .likelihood .obs_noise ) * normal (
614- key , [train_data .n , num_samples ]
615- ) # [N, B]
616+ Sigma = Kxx + cola .ops .I_like (Kxx ) * (obs_var + self .jitter ) # [N, N]
617+ eps = jnp .sqrt (obs_var ) * normal (key , [train_data .n , num_samples ]) # [N, B]
616618 y = train_data .y - self .prior .mean_function (train_data .X ) # account for mean
617619 Phi = fourier_feature_fn (train_data .X )
618620 canonical_weights = cola .solve (
0 commit comments