
Commit e7ec084

saitcakmak authored and facebook-github-bot committed
Fix sample_points_around_best when using 20-dimensional inputs or prob_perturb (#1290)

Summary:
Pull Request resolved: #1290

When using 20-dimensional inputs, `sample_points_around_best` would return `n // 2` samples, which would lead to errors when reshaping them in `gen_batch_initial_conditions`. Similarly, using `prob_perturb` with fewer than 20 dimensions would produce `n + n // 2` samples. This change makes the two `if` statements that control this use the same conditional, eliminating the bug.

Reviewed By: Balandat

Differential Revision: D37705879

fbshipit-source-id: b6716dc934c3760cc00bd9d0c74305d6bf05ea27
1 parent 50b6d61 commit e7ec084
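
To make the point counts in the summary concrete, here is a minimal sketch of the arithmetic, reconstructed from the summary rather than from the actual BoTorch internals; `point_count_old` and `point_count_new` are hypothetical helpers, and the `n // 2` contribution of the old subset-dims branch is inferred from the reported `n + n // 2` total.

def point_count_old(n, d, prob_perturb):
    # Pre-fix logic: the two branches used different conditionals
    # (">= 20" for the truncated-normal count, "> 20" for the subset-dims branch).
    n_trunc = n // 2 if d >= 20 else n
    n_subset = n // 2 if (d > 20 or prob_perturb is not None) else 0
    return n_trunc + n_subset

def point_count_new(n, d, prob_perturb):
    # Post-fix logic: one shared conditional; the subset-dims branch is assumed
    # to supply the remaining points so the total is always n.
    use_perturbed_sampling = d >= 20 or prob_perturb is not None
    n_trunc = n // 2 if use_perturbed_sampling else n
    n_subset = n - n_trunc if use_perturbed_sampling else 0
    return n_trunc + n_subset

assert point_count_old(10, 20, None) == 5   # d == 20: only n // 2 points
assert point_count_old(10, 5, 0.5) == 15    # prob_perturb with d < 20: n + n // 2 points
assert point_count_new(10, 20, None) == 10  # fixed: always n points
assert point_count_new(10, 5, 0.5) == 10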

2 files changed: +7 -6 lines changed


botorch/optim/initializers.py

Lines changed: 3 additions & 2 deletions
@@ -683,16 +683,17 @@ def sample_points_around_best(
         # the view() is to ensure that best_idcs is not a scalar tensor
         best_idcs = torch.topk(f_pred, n_best).indices.view(-1)
         best_X = X[best_idcs]
+    use_perturbed_sampling = best_X.shape[-1] >= 20 or prob_perturb is not None
     n_trunc_normal_points = (
-        n_discrete_points // 2 if best_X.shape[-1] >= 20 else n_discrete_points
+        n_discrete_points // 2 if use_perturbed_sampling else n_discrete_points
     )
     perturbed_X = sample_truncated_normal_perturbations(
         X=best_X,
         n_discrete_points=n_trunc_normal_points,
         sigma=sigma,
         bounds=bounds,
     )
-    if best_X.shape[-1] > 20 or prob_perturb is not None:
+    if use_perturbed_sampling:
         perturbed_subset_dims_X = sample_perturbed_subset_dims(
             X=best_X,
             bounds=bounds,

test/optim/test_initializers.py

Lines changed: 4 additions & 4 deletions
@@ -710,12 +710,12 @@ def test_sample_points_around_best(self):
             self.assertEqual(eq_mask[idcs].sum(), 4)
             self.assertTrue((X_rnd >= 1).all())
             self.assertTrue((X_rnd <= 2).all())
-            # test that subset_dims is called if d>=21
-            X_train = 1 + torch.rand(20, 21, **tkwargs)
+            # test that subset_dims is called if d>=20
+            X_train = 1 + torch.rand(10, 20, **tkwargs)
             model = MockModel(
                 MockPosterior(mean=(2 * X_train + 1).sum(dim=-1, keepdim=True))
             )
-            bounds = torch.ones(2, 21, **tkwargs)
+            bounds = torch.ones(2, 20, **tkwargs)
             bounds[1] = 2
             # test NEI with X_baseline
             acqf = qNoisyExpectedImprovement(
@@ -728,7 +728,7 @@ def test_sample_points_around_best(self):
                 X_rnd = sample_points_around_best(
                     acq_function=acqf, n_discrete_points=5, sigma=1e-3, bounds=bounds
                 )
-                self.assertTrue(X_rnd.shape, torch.Size([5, 2]))
+                self.assertEqual(X_rnd.shape, torch.Size([5, 20]))
                 self.assertTrue((X_rnd >= 1).all())
                 self.assertTrue((X_rnd <= 2).all())
                 mock_subset_dims.assert_called_once()
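
For context, a hedged end-to-end sketch of what the updated test exercises, using a real GP rather than the mocks; the SingleTaskGP setup, training data, and import paths are illustrative assumptions, while the bounds and the sample_points_around_best call mirror the test above.

import torch
from botorch.acquisition import qNoisyExpectedImprovement
from botorch.models import SingleTaskGP
from botorch.optim.initializers import sample_points_around_best

d = 20
# Domain [1, 2]^20, mirroring the bounds used in the test.
bounds = torch.stack([torch.ones(d), 2 * torch.ones(d)])
train_X = 1 + torch.rand(10, d)
train_Y = (2 * train_X + 1).sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)
acqf = qNoisyExpectedImprovement(model=model, X_baseline=train_X)

X_rnd = sample_points_around_best(
    acq_function=acqf, n_discrete_points=5, sigma=1e-3, bounds=bounds
)
# With the fix, all 5 requested points come back; the pre-fix code
# returned only n // 2 = 2 points when d == 20.
assert X_rnd.shape == torch.Size([5, d])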
