diff --git a/botorch/acquisition/analytic.py b/botorch/acquisition/analytic.py index 7743d18652..cf4517fe31 100644 --- a/botorch/acquisition/analytic.py +++ b/botorch/acquisition/analytic.py @@ -1078,12 +1078,20 @@ def forward(self, X: Tensor) -> Tensor: t-batches of ``d``-dim design points each. Returns: - A ``(b1 x ... x bk)``-dim Tensor of Posterior Mean values at the given - design points ``X``. + A ``(b1 x ... x bk)``-dim Tensor of scalarized Posterior Mean values + at the given design points ``X``. """ - # (b1 x ... x bk) x q x 1 - mean, _ = self._mean_and_sigma(X, compute_sigma=False) - return mean.squeeze(-1) @ self.weights + # ScalarizedPosteriorMean cannot use self._mean_and_sigma, since that squeezes + # the q-dim. + self.to(X) # Sync weights buffer to X's device/dtype + posterior = self.model.posterior( + X=X, posterior_transform=self.posterior_transform + ) + # posterior.mean has shape (b1 x ... x bk) x q x m + # squeeze(-1) removes m (should be 1), giving (b1 x ... x bk) x q + mean = posterior.mean.squeeze(-1) + # @ self.weights: (b1 x ... x bk) x q @ q -> (b1 x ... x bk) + return mean @ self.weights class PosteriorStandardDeviation(AnalyticAcquisitionFunction):