 from gpytorch.mlls import ExactMarginalLogLikelihood
 from botorch.models.transforms.input import ReversibleInputTransform
 from botorch.models.transforms.outcome import OutcomeTransform
+from linear_operator.utils.cholesky import psd_safe_cholesky
 
 from lume_model.models.prob_model_base import (
     ProbModelBaseModel,
@@ -30,15 +31,13 @@ class GPModel(ProbModelBaseModel):
         model: A single task GPyTorch model or BoTorch model.
         input_transformers: List of input transformers to apply to the input data. Optional, default is None.
         output_transformers: List of output transformers to apply to the output data. Optional, default is None.
-        jitter: Jitter to add to diagonal of covariance matrix for numerical stability, if matrix is not positive definite. Optional, default is 1e-8.
     """
 
     model: SingleTaskGP | MultiTaskGP  # TODO: any other types?
     input_transformers: list[ReversibleInputTransform | torch.nn.Linear] = None
     output_transformers: list[
         OutcomeTransform | ReversibleInputTransform | torch.nn.Linear
     ] = None
-    jitter: float = 1e-8
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -187,10 +186,8 @@ def _get_predictions(
         distribution = self._get_distribution(posterior)
         # Take mean and covariance of the distribution
         mean, covar = distribution.mean, distribution.covariance_matrix
-        # Transform the output (mean and covariance)
-        if self.output_transformers is not None:
-            mean, covar = self._transform_outputs(mean), self._transform_outputs(covar)
         # Return a dictionary of output variable names to distributions
+        # (this untransforms the mean and covariance before returning)
         return self._create_output_dict((mean, covar))
 
     def _posterior(self, x):
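
Note for reviewers: the extraction above follows the standard BoTorch posterior API. A minimal standalone sketch of what `_get_predictions` consumes (toy model and data, purely illustrative):

```python
import torch
from botorch.models import SingleTaskGP

# Toy training data, double precision as BoTorch recommends.
train_x = torch.rand(8, 2, dtype=torch.float64)
train_y = torch.sin(train_x.sum(dim=-1, keepdim=True))

gp = SingleTaskGP(train_x, train_y)
posterior = gp.posterior(torch.rand(4, 2, dtype=torch.float64))

# The underlying gpytorch MultivariateNormal exposes the mean and
# covariance that _get_predictions hands to _create_output_dict.
mvn = posterior.distribution
mean, covar = mvn.mean, mvn.covariance_matrix
```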
@@ -257,6 +254,11 @@ def _create_output_dict(
             _cov[:, :ss, :ss] = cov[:, i * ss : (i + 1) * ss, i * ss : (i + 1) * ss]
 
             _cov = self._check_covariance_matrix(_cov)
+            # Last step is to untransform
+            if self.output_transformers is not None:
+                _mean, _cov = self._transform_outputs(_mean), self._transform_outputs(_cov)
+
+            # TODO: add a check for final covariance matrix to be positive definite?
             output_distributions[name] = MultivariateNormal(_mean, _cov)
 
         return output_distributions
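
The reordering above repairs the covariance in model (transformed) space before untransforming. For a purely linear output transformer the repair survives the untransform, since rescaling preserves positive definiteness; a toy check (the `scale` value is hypothetical), which also suggests why the TODO about re-checking the final covariance remains for the general case:

```python
import torch

# PSD but rank-deficient covariance, then the model-space jitter repair.
cov = torch.tensor([[1.0, 1.0], [1.0, 1.0]], dtype=torch.float64)  # rank 1
cov = cov + 1e-8 * torch.eye(2, dtype=torch.float64)  # jitter in model space

scale = 2.0  # stands in for a linear output transformer
cov_untf = scale**2 * cov      # a linear untransform rescales the covariance
torch.linalg.cholesky(cov_untf)  # still positive definite after untransforming
```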
@@ -290,7 +292,7 @@ def _transform_outputs(self, output_tensor: torch.Tensor) -> torch.Tensor:
             if isinstance(transformer, ReversibleInputTransform):
                 output_tensor = transformer.untransform(output_tensor)
             elif isinstance(transformer, OutcomeTransform):
-                output_tensor = transformer.untransform(output_tensor)
+                output_tensor = transformer.untransform(output_tensor)[0]
             else:
                 w, b = transformer.weight, transformer.bias
                 output_tensor = torch.matmul((output_tensor - b), torch.linalg.inv(w.T))
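
The `[0]` is needed because BoTorch's `OutcomeTransform.untransform` returns a `(values, variances)` tuple, unlike `ReversibleInputTransform.untransform`, which returns a bare tensor. A quick round-trip with `Standardize` illustrates this:

```python
import torch
from botorch.models.transforms.outcome import Standardize

tf = Standardize(m=1)
y = torch.randn(10, 1, dtype=torch.float64)

y_std, _ = tf(y)                   # fit the transform and standardize
y_back = tf.untransform(y_std)[0]  # take values, discard the variance slot
assert torch.allclose(y, y_back)
```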
@@ -302,8 +304,9 @@ def _check_covariance_matrix(self, cov: torch.Tensor) -> torch.Tensor:
             torch.linalg.cholesky(cov)
         except torch._C._LinAlgError:
             warnings.warn(
-                f"Covariance matrix is not positive definite. Added jitter of {self.jitter:.1e} to the diagonal."
+                "Covariance matrix is not positive definite. Attempting to add jitter to the diagonal."
             )
-            eps = torch.tensor(self.jitter, **self._tkwargs)
-            cov = cov + torch.eye(cov.shape[-1], **self._tkwargs) * eps
+            chol = psd_safe_cholesky(cov)  # determines the jitter iteratively
+            cov = chol @ chol.transpose(-1, -2)
+
         return cov
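
`psd_safe_cholesky` retries the factorization with geometrically increasing diagonal jitter rather than the single fixed `1e-8` the removed `jitter` field supplied. A small demonstration on a rank-deficient matrix:

```python
import torch
from linear_operator.utils.cholesky import psd_safe_cholesky

# PSD but rank-deficient: plain cholesky raises on it.
a = torch.randn(5, 3, dtype=torch.float64)
cov = a @ a.T  # rank 3 < 5, so not strictly positive definite

# psd_safe_cholesky catches the failure and retries with growing jitter
# until the factorization succeeds.
chol = psd_safe_cholesky(cov)
cov_pd = chol @ chol.transpose(-1, -2)
torch.linalg.cholesky(cov_pd)  # now succeeds
```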