1212from botorch .models .transforms .input import ReversibleInputTransform
1313from botorch .models .transforms .outcome import OutcomeTransform
1414from linear_operator .utils .cholesky import psd_safe_cholesky
15+ from linear_operator .operators import DiagLinearOperator
1516
1617from lume_model .models .prob_model_base import (
1718 ProbModelBaseModel ,
@@ -114,6 +115,8 @@ def check_transforms(self):
114115 # Either way, the internal transform should be removed
115116 # to avoid double transformations
116117 # TODO: should this get saved somewhere though, e.g. as metadata?
118+ # TODO: is this how we want to handle this? Wouldn't results be inaccurate if the GP was trained w/
119+ # TODO: different outcome_transform?
117120 delattr (self .model , "outcome_transform" )
118121
119122 if self .output_transformers is None :
@@ -253,12 +256,14 @@ def _create_output_dict(
253256 _cov = torch .zeros (batch , ss , ss , ** self ._tkwargs )
254257 _cov [:, :ss , :ss ] = cov [:, i * ss : (i + 1 ) * ss , i * ss : (i + 1 ) * ss ]
255258
259+ # Check that the covariance matrix is positive definite
256260 _cov = self ._check_covariance_matrix (_cov )
261+
257262 # Last step is to untransform
258263 if self .output_transformers is not None :
259- _mean , _cov = self ._transform_outputs (_mean ), self ._transform_outputs (_cov )
264+ _mean = self ._transform_mean (_mean )
265+ _cov = self ._transform_covar (_cov , _mean )
260266
261- # TODO: add a check for final covariance matrix to be positive definite?
262267 output_distributions [name ] = MultivariateNormal (_mean , _cov )
263268
264269 return output_distributions
@@ -279,24 +284,52 @@ def _transform_inputs(self, input_tensor: torch.Tensor) -> torch.Tensor:
279284 input_tensor = transformer (input_tensor )
280285 return input_tensor
281286
282- def _transform_outputs (self , output_tensor : torch .Tensor ) -> torch .Tensor :
283- """(Un-)Transforms the model output tensor.
287+ def _transform_mean (self , mean : torch .Tensor ) -> torch .Tensor :
288+ """(Un-)Transforms the model output mean.
289+
290+ Args:
291+ mean: Output mean tensor from the model.
292+
293+ Returns:
294+ (Un-)Transformed output mean tensor.
295+ """
296+ for transformer in self .output_transformers :
297+ if isinstance (transformer , ReversibleInputTransform ):
298+ mean = transformer .untransform (mean )
299+ elif isinstance (transformer , OutcomeTransform ):
300+ scale_fac = transformer .stdvs .squeeze (0 )
301+ offset = transformer .means .squeeze (0 )
302+ mean = offset + scale_fac * mean
303+ else :
304+ raise NotImplementedError (
305+ f"Output transformer { type (transformer )} is not supported."
306+ )
307+ return mean
308+
309+ def _transform_covar (self , cov : torch .Tensor ) -> torch .Tensor :
310+ """(Un-)Transforms the model output covariance matrix.
284311
285312 Args:
286- output_tensor : Output tensor from the model.
313+ cov : Output covariance matrix tensor from the model.
287314
288315 Returns:
289- (Un-)Transformed output tensor.
316+ (Un-)Transformed output covariance matrix tensor.
290317 """
291318 for transformer in self .output_transformers :
292319 if isinstance (transformer , ReversibleInputTransform ):
293- output_tensor = transformer .untransform (output_tensor )
320+ scale_fac = transformer .coefficient .expand (cov .shape [:- 1 ])
321+ scale_mat = DiagLinearOperator (scale_fac )
322+ cov = scale_mat @ cov @ scale_mat
294323 elif isinstance (transformer , OutcomeTransform ):
295- output_tensor = transformer .untransform (output_tensor )[0 ]
324+ scale_fac = transformer .stdvs .squeeze (0 )
325+ scale_fac = scale_fac .expand (cov .shape [:- 1 ])
326+ scale_mat = DiagLinearOperator (scale_fac )
327+ cov = scale_mat @ cov @ scale_mat
296328 else :
297- w , b = transformer .weight , transformer .bias
298- output_tensor = torch .matmul ((output_tensor - b ), torch .linalg .inv (w .T ))
299- return output_tensor
329+ raise NotImplementedError (
330+ f"Output transformer { type (transformer )} is not supported."
331+ )
332+ return cov
300333
301334 def _check_covariance_matrix (self , cov : torch .Tensor ) -> torch .Tensor :
302335 """Checks that the covariance matrix is positive definite, and adds jitter if not."""
0 commit comments