Commit 88a0e21

fix how mean and covar are transformed
1 parent: 9388662

File tree

1 file changed (+44, -11 lines)

lume_model/models/gp_model.py

Lines changed: 44 additions & 11 deletions
@@ -12,6 +12,7 @@
 from botorch.models.transforms.input import ReversibleInputTransform
 from botorch.models.transforms.outcome import OutcomeTransform
 from linear_operator.utils.cholesky import psd_safe_cholesky
+from linear_operator.operators import DiagLinearOperator
 
 from lume_model.models.prob_model_base import (
     ProbModelBaseModel,
@@ -114,6 +115,8 @@ def check_transforms(self):
         # Either way, the internal transform should be removed
         # to avoid double transformations
         # TODO: should this get saved somewhere though, e.g. as metadata?
+        # TODO: is this how we want to handle this? Wouldn't results be inaccurate if the GP was trained w/
+        # TODO: different outcome_transform?
         delattr(self.model, "outcome_transform")
 
         if self.output_transformers is None:
@@ -253,12 +256,14 @@ def _create_output_dict(
             _cov = torch.zeros(batch, ss, ss, **self._tkwargs)
             _cov[:, :ss, :ss] = cov[:, i * ss : (i + 1) * ss, i * ss : (i + 1) * ss]
 
+            # Check that the covariance matrix is positive definite
             _cov = self._check_covariance_matrix(_cov)
+
             # Last step is to untransform
             if self.output_transformers is not None:
-                _mean, _cov = self._transform_outputs(_mean), self._transform_outputs(_cov)
+                _mean = self._transform_mean(_mean)
+                _cov = self._transform_covar(_cov, _mean)
 
-            # TODO: add a check for final covariance matrix to be positive definite?
             output_distributions[name] = MultivariateNormal(_mean, _cov)
 
         return output_distributions
@@ -279,24 +284,52 @@ def _transform_inputs(self, input_tensor: torch.Tensor) -> torch.Tensor:
             input_tensor = transformer(input_tensor)
         return input_tensor
 
-    def _transform_outputs(self, output_tensor: torch.Tensor) -> torch.Tensor:
-        """(Un-)Transforms the model output tensor.
+    def _transform_mean(self, mean: torch.Tensor) -> torch.Tensor:
+        """(Un-)Transforms the model output mean.
+
+        Args:
+            mean: Output mean tensor from the model.
+
+        Returns:
+            (Un-)Transformed output mean tensor.
+        """
+        for transformer in self.output_transformers:
+            if isinstance(transformer, ReversibleInputTransform):
+                mean = transformer.untransform(mean)
+            elif isinstance(transformer, OutcomeTransform):
+                scale_fac = transformer.stdvs.squeeze(0)
+                offset = transformer.means.squeeze(0)
+                mean = offset + scale_fac * mean
+            else:
+                raise NotImplementedError(
+                    f"Output transformer {type(transformer)} is not supported."
+                )
+        return mean
+
+    def _transform_covar(self, cov: torch.Tensor) -> torch.Tensor:
+        """(Un-)Transforms the model output covariance matrix.
 
         Args:
-            output_tensor: Output tensor from the model.
+            cov: Output covariance matrix tensor from the model.
 
         Returns:
-            (Un-)Transformed output tensor.
+            (Un-)Transformed output covariance matrix tensor.
         """
         for transformer in self.output_transformers:
            if isinstance(transformer, ReversibleInputTransform):
-                output_tensor = transformer.untransform(output_tensor)
+                scale_fac = transformer.coefficient.expand(cov.shape[:-1])
+                scale_mat = DiagLinearOperator(scale_fac)
+                cov = scale_mat @ cov @ scale_mat
            elif isinstance(transformer, OutcomeTransform):
-                output_tensor = transformer.untransform(output_tensor)[0]
+                scale_fac = transformer.stdvs.squeeze(0)
+                scale_fac = scale_fac.expand(cov.shape[:-1])
+                scale_mat = DiagLinearOperator(scale_fac)
+                cov = scale_mat @ cov @ scale_mat
            else:
-                w, b = transformer.weight, transformer.bias
-                output_tensor = torch.matmul((output_tensor - b), torch.linalg.inv(w.T))
-        return output_tensor
+                raise NotImplementedError(
+                    f"Output transformer {type(transformer)} is not supported."
+                )
+        return cov
 
     def _check_covariance_matrix(self, cov: torch.Tensor) -> torch.Tensor:
        """Checks that the covariance matrix is positive definite, and adds jitter if not."""
