Skip to content

Commit efd72a6

Browse files
saitcakmakmeta-codesync[bot]
authored andcommitted
Clean up sampler usage in pick_best_out_of_sample_point_acqf_class (facebook#5095)
Summary: Pull Request resolved: facebook#5095 This was generating a SobolQMCSampler regardless of the model it is used with. SobolQMCSampler does not support PosteriorList, which was leading to issues with LILO integration. This diff cleans up the generated options in `pick_best_out_of_sample_point_acqf_class`, since BoTorch (in `MCSamplerMixin.get_posterior_samples`) will dispatch to proper defaults without them. This led to the `options` input becoming unused, so I removed those rather than having them get ignored silently. Reviewed By: hvarfner Differential Revision: D97956036 fbshipit-source-id: e363c433d4769488f9225fc2ecca6395442638c2
1 parent 09eadfe commit efd72a6

4 files changed

Lines changed: 10 additions & 42 deletions

File tree

ax/generators/torch/botorch_modular/surrogate.py

Lines changed: 3 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,8 @@
5050
from ax.generators.types import TConfig
5151
from ax.generators.utils import best_in_sample_point
5252
from ax.utils.common.base import Base
53-
from ax.utils.common.constants import Keys
5453
from ax.utils.common.logger import get_logger
55-
from ax.utils.common.typeutils import (
56-
_argparse_type_encoder,
57-
assert_is_instance_optional,
58-
)
54+
from ax.utils.common.typeutils import _argparse_type_encoder
5955
from ax.utils.stats.model_fit_stats import (
6056
DIAGNOSTIC_FN_DIRECTIONS,
6157
DIAGNOSTIC_FNS,
@@ -1023,7 +1019,6 @@ def best_out_of_sample_point(
10231019
self,
10241020
search_space_digest: SearchSpaceDigest,
10251021
torch_opt_config: TorchOptConfig,
1026-
options: TConfig | None = None,
10271022
) -> tuple[Tensor, Tensor]:
10281023
"""Finds the best predicted point and the corresponding value of the
10291024
appropriate best point acquisition function.
@@ -1032,9 +1027,6 @@ def best_out_of_sample_point(
10321027
search_space_digest: A `SearchSpaceDigest`.
10331028
torch_opt_config: A `TorchOptConfig`; none-None `fixed_features` is
10341029
not supported.
1035-
options: Optional. If present, `seed_inner` (default None) and `qmc`
1036-
(default True) will be parsed from `options`; any other keys
1037-
will be ignored.
10381030
10391031
Returns:
10401032
A two-tuple (`candidate`, `acqf_value`), where `candidate` is a 1d
@@ -1048,18 +1040,8 @@ def best_out_of_sample_point(
10481040
# TODO (ref: https://fburl.com/diff/uneqb3n9)
10491041
raise NotImplementedError("Fixed features not yet supported.")
10501042

1051-
options = options or {}
1052-
botorch_acqf_class, botorch_acqf_options = (
1053-
pick_best_out_of_sample_point_acqf_class(
1054-
outcome_constraints=torch_opt_config.outcome_constraints,
1055-
seed_inner=assert_is_instance_optional(
1056-
options.get(Keys.SEED_INNER, None), int
1057-
),
1058-
qmc=assert_is_instance(
1059-
options.get(Keys.QMC, True),
1060-
bool,
1061-
),
1062-
)
1043+
botorch_acqf_class = pick_best_out_of_sample_point_acqf_class(
1044+
outcome_constraints=torch_opt_config.outcome_constraints,
10631045
)
10641046

10651047
# Avoiding circular import between `Surrogate` and `Acquisition`.
@@ -1070,7 +1052,6 @@ def best_out_of_sample_point(
10701052
botorch_acqf_class=botorch_acqf_class,
10711053
search_space_digest=search_space_digest,
10721054
torch_opt_config=torch_opt_config,
1073-
botorch_acqf_options=botorch_acqf_options,
10741055
)
10751056
candidates, acqf_value, _ = acqf.optimize(
10761057
n=1,

ax/generators/torch/tests/test_surrogate.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1652,7 +1652,6 @@ def test_best_out_of_sample_point(self) -> None:
16521652
candidate, acqf_value = surrogate.best_out_of_sample_point(
16531653
search_space_digest=self.search_space_digest,
16541654
torch_opt_config=torch_opt_config,
1655-
options=self.options,
16561655
)
16571656
candidate_in_bounds = all(
16581657
((x >= b[0]) & (x <= b[1]) for x, b in zip(candidate, self.bounds))

ax/generators/torch/tests/test_utils.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1471,17 +1471,14 @@ def test_choose_botorch_acqf_class_with_map_key_fixing(self) -> None:
14711471
self.assertEqual(acqf_class, qLogNoisyExpectedImprovement)
14721472

14731473
def test_pick_best_out_of_sample_point_acqf_class(self) -> None:
1474-
# Unconstrained: PosteriorMean with no options.
1475-
acqf_class, acqf_options = pick_best_out_of_sample_point_acqf_class(
1474+
# Unconstrained: PosteriorMean.
1475+
acqf_class = pick_best_out_of_sample_point_acqf_class(
14761476
outcome_constraints=None,
14771477
)
14781478
self.assertEqual(acqf_class, PosteriorMean)
1479-
self.assertEqual(acqf_options, {})
14801479

1481-
# Constrained: qSimpleRegret with no explicit sampler (get_sampler
1482-
# auto-dispatches, which handles PosteriorList correctly).
1483-
acqf_class, acqf_options = pick_best_out_of_sample_point_acqf_class(
1480+
# Constrained: qSimpleRegret.
1481+
acqf_class = pick_best_out_of_sample_point_acqf_class(
14841482
outcome_constraints=(torch.tensor([[1.0, 0.0]]), torch.tensor([[0.5]])),
14851483
)
14861484
self.assertEqual(acqf_class, qSimpleRegret)
1487-
self.assertEqual(acqf_options, {})

ax/generators/torch/utils.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import logging
1010
from collections.abc import Callable, Sequence
1111
from dataclasses import dataclass
12-
from typing import Any, cast
1312

1413
import numpy.typing as npt
1514
import torch
@@ -404,18 +403,10 @@ def objective(samples: Tensor, X: Tensor | None = None) -> Tensor:
404403

405404
def pick_best_out_of_sample_point_acqf_class(
406405
outcome_constraints: tuple[Tensor, Tensor] | None = None,
407-
mc_samples: int = 512,
408-
qmc: bool = True,
409-
seed_inner: int | None = None,
410-
) -> tuple[type[AcquisitionFunction], dict[str, Any]]:
406+
) -> type[AcquisitionFunction]:
411407
if outcome_constraints is None:
412-
acqf_class = PosteriorMean
413-
acqf_options = {}
414-
else:
415-
acqf_class = qSimpleRegret
416-
acqf_options = {}
417-
418-
return cast(type[AcquisitionFunction], acqf_class), acqf_options
408+
return PosteriorMean
409+
return qSimpleRegret
419410

420411

421412
def predict_from_model(

0 commit comments

Comments
 (0)