Skip to content

Commit 11bd2f6

Browse files
Carl Hvarfnerfacebook-github-bot
authored andcommitted
q-dim Bugfix in ScalarizedPosteriorMean
Differential Revision: D92409564
1 parent 1855320 commit 11bd2f6

1 file changed

Lines changed: 12 additions & 5 deletions

File tree

botorch/acquisition/analytic.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1078,12 +1078,19 @@ def forward(self, X: Tensor) -> Tensor:
10781078
t-batches of ``d``-dim design points each.
10791079
10801080
Returns:
1081-
A ``(b1 x ... x bk)``-dim Tensor of Posterior Mean values at the given
1082-
design points ``X``.
1081+
A ``(b1 x ... x bk)``-dim Tensor of scalarized Posterior Mean values
1082+
at the given design points ``X``.
10831083
"""
1084-
# (b1 x ... x bk) x q x 1
1085-
mean, _ = self._mean_and_sigma(X, compute_sigma=False)
1086-
return mean.squeeze(-1) @ self.weights
1084+
# ScalarizedPosteriorMean cannot use self._mean_and_sigma, since that squeezes
1085+
# the q-dim.
1086+
posterior = self.model.posterior(
1087+
X=X, posterior_transform=self.posterior_transform
1088+
)
1089+
# posterior.mean has shape (b1 x ... x bk) x q x m
1090+
# squeeze(-1) removes m (should be 1), giving (b1 x ... x bk) x q
1091+
mean = posterior.mean.squeeze(-1)
1092+
# @ self.weights: (b1 x ... x bk) x q @ q -> (b1 x ... x bk)
1093+
return mean @ self.weights
10871094

10881095

10891096
class PosteriorStandardDeviation(AnalyticAcquisitionFunction):

0 commit comments

Comments
 (0)