Skip to content

Commit 3decde5

Browse files
Carl Hvarfnermeta-codesync[bot]
authored andcommitted
q-dim bugfix in ScalarizedPosteriorMean (#3191)
Summary: Pull Request resolved: #3191 Fixes ScalarizedPosteriorMean to correctly reduce over the q-dimension, ensuring consistent behavior with other analytic acquisition functions. Reviewed By: saitcakmak Differential Revision: D93691048 fbshipit-source-id: e0be83af71cef1cde1474b1e653a0ae8ec1600d2
1 parent 9f269fa commit 3decde5

1 file changed

Lines changed: 13 additions & 5 deletions

File tree

botorch/acquisition/analytic.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1078,12 +1078,20 @@ def forward(self, X: Tensor) -> Tensor:
10781078
t-batches of ``d``-dim design points each.
10791079
10801080
Returns:
1081-
A ``(b1 x ... x bk)``-dim Tensor of Posterior Mean values at the given
1082-
design points ``X``.
1081+
A ``(b1 x ... x bk)``-dim Tensor of scalarized Posterior Mean values
1082+
at the given design points ``X``.
10831083
"""
1084-
# (b1 x ... x bk) x q x 1
1085-
mean, _ = self._mean_and_sigma(X, compute_sigma=False)
1086-
return mean.squeeze(-1) @ self.weights
1084+
# ScalarizedPosteriorMean cannot use self._mean_and_sigma, since that squeezes
1085+
# the q-dim.
1086+
self.to(X) # Sync weights buffer to X's device/dtype
1087+
posterior = self.model.posterior(
1088+
X=X, posterior_transform=self.posterior_transform
1089+
)
1090+
# posterior.mean has shape (b1 x ... x bk) x q x m
1091+
# squeeze(-1) removes m (should be 1), giving (b1 x ... x bk) x q
1092+
mean = posterior.mean.squeeze(-1)
1093+
# @ self.weights: (b1 x ... x bk) x q @ q -> (b1 x ... x bk)
1094+
return mean @ self.weights
10871095

10881096

10891097
class PosteriorStandardDeviation(AnalyticAcquisitionFunction):

0 commit comments

Comments
 (0)