Skip to content

Commit 9388662

Browse files
committed
adjust cov matrix handling
1 parent 51f0fc7 commit 9388662

File tree

2 files changed

+12
-10
lines changed

2 files changed

+12
-10
lines changed

lume_model/base.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,6 @@ def recursive_serialize(
127127
and any(isinstance(ele, torch.nn.Module) for ele in value)
128128
):
129129
# List of transformers
130-
print(v[key])
131130
v[key] = [
132131
process_torch_module(
133132
value[i], base_key, f"{key}_{i}", file_prefix, save_models, False

lume_model/models/gp_model.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from gpytorch.mlls import ExactMarginalLogLikelihood
1212
from botorch.models.transforms.input import ReversibleInputTransform
1313
from botorch.models.transforms.outcome import OutcomeTransform
14+
from linear_operator.utils.cholesky import psd_safe_cholesky
1415

1516
from lume_model.models.prob_model_base import (
1617
ProbModelBaseModel,
@@ -30,15 +31,13 @@ class GPModel(ProbModelBaseModel):
3031
model: A single task GPyTorch model or BoTorch model.
3132
input_transformers: List of input transformers to apply to the input data. Optional, default is None.
3233
output_transformers: List of output transformers to apply to the output data. Optional, default is None.
33-
jitter: Jitter to add to diagonal of covariance matrix for numerical stability, if matrix is not positive definite. Optional, default is 1e-8.
3434
"""
3535

3636
model: SingleTaskGP | MultiTaskGP # TODO: any other types?
3737
input_transformers: list[ReversibleInputTransform | torch.nn.Linear] = None
3838
output_transformers: list[
3939
OutcomeTransform | ReversibleInputTransform | torch.nn.Linear
4040
] = None
41-
jitter: float = 1e-8
4241

4342
def __init__(self, *args, **kwargs):
4443
super().__init__(*args, **kwargs)
@@ -187,10 +186,8 @@ def _get_predictions(
187186
distribution = self._get_distribution(posterior)
188187
# Take mean and covariance of the distribution
189188
mean, covar = distribution.mean, distribution.covariance_matrix
190-
# Transform the output (mean and covariance)
191-
if self.output_transformers is not None:
192-
mean, covar = self._transform_outputs(mean), self._transform_outputs(covar)
193189
# Return a dictionary of output variable names to distributions
190+
# this untransforms the mean and covariance before returning
194191
return self._create_output_dict((mean, covar))
195192

196193
def _posterior(self, x):
@@ -257,6 +254,11 @@ def _create_output_dict(
257254
_cov[:, :ss, :ss] = cov[:, i * ss : (i + 1) * ss, i * ss : (i + 1) * ss]
258255

259256
_cov = self._check_covariance_matrix(_cov)
257+
# Last step is to untransform
258+
if self.output_transformers is not None:
259+
_mean, _cov = self._transform_outputs(_mean), self._transform_outputs(_cov)
260+
261+
# TODO: add a check for final covariance matrix to be positive definite?
260262
output_distributions[name] = MultivariateNormal(_mean, _cov)
261263

262264
return output_distributions
@@ -290,7 +292,7 @@ def _transform_outputs(self, output_tensor: torch.Tensor) -> torch.Tensor:
290292
if isinstance(transformer, ReversibleInputTransform):
291293
output_tensor = transformer.untransform(output_tensor)
292294
elif isinstance(transformer, OutcomeTransform):
293-
output_tensor = transformer.untransform(output_tensor)
295+
output_tensor = transformer.untransform(output_tensor)[0]
294296
else:
295297
w, b = transformer.weight, transformer.bias
296298
output_tensor = torch.matmul((output_tensor - b), torch.linalg.inv(w.T))
@@ -302,8 +304,9 @@ def _check_covariance_matrix(self, cov: torch.Tensor) -> torch.Tensor:
302304
torch.linalg.cholesky(cov)
303305
except torch._C._LinAlgError:
304306
warnings.warn(
305-
f"Covariance matrix is not positive definite. Added jitter of {self.jitter:.1e} to the diagonal."
307+
f"Covariance matrix is not positive definite. Attempting to add jitter the diagonal."
306308
)
307-
eps = torch.tensor(self.jitter, **self._tkwargs)
308-
cov = cov + torch.eye(cov.shape[-1], **self._tkwargs) * eps
309+
l = psd_safe_cholesky(cov) # determines jitter iteratively
310+
cov = l @ l.transpose(-1, -2)
311+
309312
return cov

0 commit comments

Comments
 (0)