Skip to content

Commit 274bd47

Browse files
saitcakmakmeta-codesync[bot]
authored andcommitted
Fix inconsistent output shapes when all features are fixed in optimize_acqf (#3241)
Summary: When all features are fixed, _optimize_acqf_all_features_fixed now returns shapes consistent with the normal _optimize_acqf_batch path: - return_best_only=True: candidates (q, d), acq_value scalar - return_best_only=False: candidates (1, q, d), acq_value (1,) Closes #2740 Pull Request resolved: #3241 Reviewed By: dme65 Differential Revision: D97315513 Pulled By: saitcakmak fbshipit-source-id: 28a44486f72706ccdf77d3a77faa5cce6ec9ef69
1 parent 1d7bbdb commit 274bd47

2 files changed

Lines changed: 44 additions & 1 deletion

File tree

botorch/optim/optimize.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ def _optimize_acqf_all_features_fixed(
198198
fixed_features: dict[int, float],
199199
q: int,
200200
acq_function: AcquisitionFunction,
201+
return_best_only: bool = True,
201202
return_acq_values: bool = True,
202203
) -> tuple[Tensor, Tensor | None]:
203204
"""
@@ -210,10 +211,21 @@ def _optimize_acqf_all_features_fixed(
210211
dtype=bounds.dtype,
211212
)
212213
X = X.expand(q, *X.shape)
214+
if not return_best_only:
215+
# When return_best_only=False, candidates have shape
216+
# `num_restarts x q x d`. With all features fixed there is only one
217+
# candidate, so num_restarts=1.
218+
X = X.unsqueeze(0)
213219
if not return_acq_values:
214220
return X, None
215221
with torch.no_grad():
216-
acq_value = acq_function(X)
222+
acq_value = acq_function(X if return_best_only else X.squeeze(0))
223+
# Ensure acq_value is a scalar (consistent with return_best_only=True)
224+
# or 1-d with shape `(1,)` (consistent with return_best_only=False).
225+
if return_best_only:
226+
acq_value = acq_value.squeeze()
227+
else:
228+
acq_value = acq_value.view(1)
217229
return X, acq_value
218230

219231

@@ -810,6 +822,7 @@ def _optimize_acqf(opt_inputs: OptimizeAcqfInputs) -> tuple[Tensor, Tensor | Non
810822
fixed_features=opt_inputs.fixed_features,
811823
q=opt_inputs.q,
812824
acq_function=opt_inputs.acq_function,
825+
return_best_only=opt_inputs.return_best_only,
813826
return_acq_values=opt_inputs.return_acq_values,
814827
)
815828

test/optim/test_optimize.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,36 @@ def test_optimize_acqf_return_acq_values(
407407
)
408408
)
409409

410+
# All features fixed: return shape consistency with normal path.
411+
# return_best_only=True (default): candidates (q, d), acq_value scalar
412+
candidates_rbo, acq_rbo = optimize_acqf(
413+
acq_function=mock_acq_function,
414+
bounds=bounds,
415+
q=1,
416+
num_restarts=num_restarts,
417+
raw_samples=raw_samples,
418+
options=options,
419+
fixed_features=fixed_all,
420+
return_best_only=True,
421+
)
422+
self.assertEqual(candidates_rbo.shape, (1, 3))
423+
self.assertEqual(acq_rbo.shape, torch.Size([]))
424+
425+
# return_best_only=False: candidates (num_restarts, q, d),
426+
# acq_value (num_restarts,)
427+
candidates_all, acq_all = optimize_acqf(
428+
acq_function=mock_acq_function,
429+
bounds=bounds,
430+
q=1,
431+
num_restarts=num_restarts,
432+
raw_samples=raw_samples,
433+
options=options,
434+
fixed_features=fixed_all,
435+
return_best_only=False,
436+
)
437+
self.assertEqual(candidates_all.shape, (1, 1, 3))
438+
self.assertEqual(acq_all.shape, (1,))
439+
410440
# Sequential path: return_acq_values=True and return_acq_values=False
411441
mock_gen_candidates_scipy.return_value = (
412442
torch.rand(1, 1, 3, device=self.device, dtype=torch.double),

0 commit comments

Comments
 (0)