From e7ec08456709a7f9aaa51548bbe90596eff6a295 Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Mon, 11 Jul 2022 13:20:28 -0700 Subject: [PATCH] Fix `sample_points_around_best` when using 20-dimensional inputs or `prob_perturb` (#1290) Summary: Pull Request resolved: https://github.com/pytorch/botorch/pull/1290 When using 20-dimensional inputs, `sample_points_around_best` would return `n // 2` samples, which would lead to errors when reshaping them in `gen_batch_initial_conditions`. Similarly, if one were to use `prob_perturb` with <20 dimensions, they would get `n + n // 2` samples. This changes the two `if` statements that control this to use the same conditional, eliminating the bug. Reviewed By: Balandat Differential Revision: D37705879 fbshipit-source-id: b6716dc934c3760cc00bd9d0c74305d6bf05ea27 --- botorch/optim/initializers.py | 5 +++-- test/optim/test_initializers.py | 8 ++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/botorch/optim/initializers.py b/botorch/optim/initializers.py index c2481ac9ac..baaf99bc1a 100644 --- a/botorch/optim/initializers.py +++ b/botorch/optim/initializers.py @@ -683,8 +683,9 @@ def sample_points_around_best( # the view() is to ensure that best_idcs is not a scalar tensor best_idcs = torch.topk(f_pred, n_best).indices.view(-1) best_X = X[best_idcs] + use_perturbed_sampling = best_X.shape[-1] >= 20 or prob_perturb is not None n_trunc_normal_points = ( - n_discrete_points // 2 if best_X.shape[-1] >= 20 else n_discrete_points + n_discrete_points // 2 if use_perturbed_sampling else n_discrete_points ) perturbed_X = sample_truncated_normal_perturbations( X=best_X, @@ -692,7 +693,7 @@ def sample_points_around_best( sigma=sigma, bounds=bounds, ) - if best_X.shape[-1] > 20 or prob_perturb is not None: + if use_perturbed_sampling: perturbed_subset_dims_X = sample_perturbed_subset_dims( X=best_X, bounds=bounds, diff --git a/test/optim/test_initializers.py b/test/optim/test_initializers.py index 97b8b2bb48..ea94de2caa 
100644 --- a/test/optim/test_initializers.py +++ b/test/optim/test_initializers.py @@ -710,12 +710,12 @@ def test_sample_points_around_best(self): self.assertEqual(eq_mask[idcs].sum(), 4) self.assertTrue((X_rnd >= 1).all()) self.assertTrue((X_rnd <= 2).all()) - # test that subset_dims is called if d>=21 - X_train = 1 + torch.rand(20, 21, **tkwargs) + # test that subset_dims is called if d>=20 + X_train = 1 + torch.rand(10, 20, **tkwargs) model = MockModel( MockPosterior(mean=(2 * X_train + 1).sum(dim=-1, keepdim=True)) ) - bounds = torch.ones(2, 21, **tkwargs) + bounds = torch.ones(2, 20, **tkwargs) bounds[1] = 2 # test NEI with X_baseline acqf = qNoisyExpectedImprovement( @@ -728,7 +728,7 @@ def test_sample_points_around_best(self): X_rnd = sample_points_around_best( acq_function=acqf, n_discrete_points=5, sigma=1e-3, bounds=bounds ) - self.assertTrue(X_rnd.shape, torch.Size([5, 2])) + self.assertEqual(X_rnd.shape, torch.Size([5, 20])) self.assertTrue((X_rnd >= 1).all()) self.assertTrue((X_rnd <= 2).all()) mock_subset_dims.assert_called_once()