Skip to content

Commit

Permalink
Fix sample_points_around_best when using 20-dimensional inputs or `prob_perturb` (#1290)
Browse files Browse the repository at this point in the history

Summary:
Pull Request resolved: #1290

When using 20-dimensional inputs, `sample_points_around_best` would return `n // 2` samples, which would lead to errors when reshaping it in `gen_batch_initial_conditions`. Similarly, using `prob_perturb` with fewer than 20 dimensions would produce `n + n // 2` samples. This changes the two `if` statements that control this to use the same conditional, eliminating the bug.

Reviewed By: Balandat

Differential Revision: D37705879

fbshipit-source-id: b6716dc934c3760cc00bd9d0c74305d6bf05ea27
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Jul 11, 2022
1 parent 50b6d61 commit e7ec084
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
5 changes: 3 additions & 2 deletions botorch/optim/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,16 +683,17 @@ def sample_points_around_best(
# the view() is to ensure that best_idcs is not a scalar tensor
best_idcs = torch.topk(f_pred, n_best).indices.view(-1)
best_X = X[best_idcs]
use_perturbed_sampling = best_X.shape[-1] >= 20 or prob_perturb is not None
n_trunc_normal_points = (
n_discrete_points // 2 if best_X.shape[-1] >= 20 else n_discrete_points
n_discrete_points // 2 if use_perturbed_sampling else n_discrete_points
)
perturbed_X = sample_truncated_normal_perturbations(
X=best_X,
n_discrete_points=n_trunc_normal_points,
sigma=sigma,
bounds=bounds,
)
if best_X.shape[-1] > 20 or prob_perturb is not None:
if use_perturbed_sampling:
perturbed_subset_dims_X = sample_perturbed_subset_dims(
X=best_X,
bounds=bounds,
Expand Down
8 changes: 4 additions & 4 deletions test/optim/test_initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,12 +710,12 @@ def test_sample_points_around_best(self):
self.assertEqual(eq_mask[idcs].sum(), 4)
self.assertTrue((X_rnd >= 1).all())
self.assertTrue((X_rnd <= 2).all())
# test that subset_dims is called if d>=21
X_train = 1 + torch.rand(20, 21, **tkwargs)
# test that subset_dims is called if d>=20
X_train = 1 + torch.rand(10, 20, **tkwargs)
model = MockModel(
MockPosterior(mean=(2 * X_train + 1).sum(dim=-1, keepdim=True))
)
bounds = torch.ones(2, 21, **tkwargs)
bounds = torch.ones(2, 20, **tkwargs)
bounds[1] = 2
# test NEI with X_baseline
acqf = qNoisyExpectedImprovement(
Expand All @@ -728,7 +728,7 @@ def test_sample_points_around_best(self):
X_rnd = sample_points_around_best(
acq_function=acqf, n_discrete_points=5, sigma=1e-3, bounds=bounds
)
self.assertTrue(X_rnd.shape, torch.Size([5, 2]))
self.assertEqual(X_rnd.shape, torch.Size([5, 20]))
self.assertTrue((X_rnd >= 1).all())
self.assertTrue((X_rnd <= 2).all())
mock_subset_dims.assert_called_once()
Expand Down

0 comments on commit e7ec084

Please sign in to comment.