File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -1078,12 +1078,19 @@ def forward(self, X: Tensor) -> Tensor:
10781078 t-batches of ``d``-dim design points each.
10791079
10801080 Returns:
1081- A ``(b1 x ... x bk)``-dim Tensor of Posterior Mean values at the given
1082- design points ``X``.
1081+ A ``(b1 x ... x bk)``-dim Tensor of scalarized Posterior Mean values
1082+ at the given design points ``X``.
10831083 """
1084- # (b1 x ... x bk) x q x 1
1085- mean , _ = self ._mean_and_sigma (X , compute_sigma = False )
1086- return mean .squeeze (- 1 ) @ self .weights
1084+ # ScalarizedPosteriorMean cannot use self._mean_and_sigma, since that squeezes
1085+ # the q-dim.
1086+ posterior = self .model .posterior (
1087+ X = X , posterior_transform = self .posterior_transform
1088+ )
1089+ # posterior.mean has shape (b1 x ... x bk) x q x m
1090+ # squeeze(-1) removes m (should be 1), giving (b1 x ... x bk) x q
1091+ mean = posterior .mean .squeeze (- 1 )
1092+ # @ self.weights: (b1 x ... x bk) x q @ q -> (b1 x ... x bk)
1093+ return mean @ self .weights
10871094
10881095
10891096class PosteriorStandardDeviation (AnalyticAcquisitionFunction ):
You can’t perform that action at this time.
0 commit comments