diff --git a/botorch/optim/initializers.py b/botorch/optim/initializers.py index c2481ac9ac..baaf99bc1a 100644 --- a/botorch/optim/initializers.py +++ b/botorch/optim/initializers.py @@ -683,8 +683,9 @@ def sample_points_around_best( # the view() is to ensure that best_idcs is not a scalar tensor best_idcs = torch.topk(f_pred, n_best).indices.view(-1) best_X = X[best_idcs] + use_perturbed_sampling = best_X.shape[-1] >= 20 or prob_perturb is not None n_trunc_normal_points = ( - n_discrete_points // 2 if best_X.shape[-1] >= 20 else n_discrete_points + n_discrete_points // 2 if use_perturbed_sampling else n_discrete_points ) perturbed_X = sample_truncated_normal_perturbations( X=best_X, @@ -692,7 +693,7 @@ def sample_points_around_best( sigma=sigma, bounds=bounds, ) - if best_X.shape[-1] > 20 or prob_perturb is not None: + if use_perturbed_sampling: perturbed_subset_dims_X = sample_perturbed_subset_dims( X=best_X, bounds=bounds, diff --git a/test/optim/test_initializers.py b/test/optim/test_initializers.py index 97b8b2bb48..ea94de2caa 100644 --- a/test/optim/test_initializers.py +++ b/test/optim/test_initializers.py @@ -710,12 +710,12 @@ def test_sample_points_around_best(self): self.assertEqual(eq_mask[idcs].sum(), 4) self.assertTrue((X_rnd >= 1).all()) self.assertTrue((X_rnd <= 2).all()) - # test that subset_dims is called if d>=21 - X_train = 1 + torch.rand(20, 21, **tkwargs) + # test that subset_dims is called if d>=20 + X_train = 1 + torch.rand(10, 20, **tkwargs) model = MockModel( MockPosterior(mean=(2 * X_train + 1).sum(dim=-1, keepdim=True)) ) - bounds = torch.ones(2, 21, **tkwargs) + bounds = torch.ones(2, 20, **tkwargs) bounds[1] = 2 # test NEI with X_baseline acqf = qNoisyExpectedImprovement( @@ -728,7 +728,7 @@ def test_sample_points_around_best(self): X_rnd = sample_points_around_best( acq_function=acqf, n_discrete_points=5, sigma=1e-3, bounds=bounds ) - self.assertTrue(X_rnd.shape, torch.Size([5, 2])) + self.assertEqual(X_rnd.shape, torch.Size([5, 20])) self.assertTrue((X_rnd >= 1).all()) self.assertTrue((X_rnd <= 2).all()) mock_subset_dims.assert_called_once()