Skip to content

Commit a89122d

Browse files
ItsMrLin and meta-codesync[bot]
authored and committed
Add InSampleUniformGenerator for model-free in-sample candidate selection (facebook#4987)
Summary:

Pull Request resolved: facebook#4987

Add `InSampleUniformGenerator`, a `RandomGenerator` subclass that randomly selects `n` arms (without replacement) from existing experiment arms. This replaces the previous `in_sample` mode on `RandomAdapter` (which bypassed the generator entirely) with a proper generator class, following bletham's review feedback on the original diff.

The generator overrides `gen()` to select from the `generated_points` array that the adapter already constructs from in-design, non-failed experiment arms (filtered, transformed, and deduplicated). This reuses existing infrastructure without adding new interface surface.

Registered as `Generators.IN_SAMPLE_UNIFORM` in the adapter registry with `RandomAdapter` and `Cont_X_trans`, matching the pattern of other random generators (Sobol, Uniform).

The resulting user-facing matrix is clean:

| | Out-of-sample | In-sample |
|---|---|---|
| **Model-free** | `Generators.SOBOL` | `Generators.IN_SAMPLE_UNIFORM` |
| **Model-based** | `Generators.BOTORCH_MODULAR` | `Generators.BOTORCH_MODULAR` + `model_gen_options={"in_sample": True}` |

The asymmetry (enum swap for model-free, flag for model-based) reflects the real architectural difference: model-based in/out-of-sample share a fitted model; model-free in/out-of-sample share nothing.

Reviewed By: bletham, saitcakmak

Differential Revision: D94973263

fbshipit-source-id: 1cbaa958cf0514caf6b9d91f738e39f2c2cb516b
1 parent 6ce2e7c commit a89122d

File tree

3 files changed

+213
-0
lines changed

3 files changed

+213
-0
lines changed

ax/adapter/registry.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
from ax.generators.discrete.eb_thompson import EmpiricalBayesThompsonSampler
5757
from ax.generators.discrete.full_factorial import FullFactorialGenerator
5858
from ax.generators.discrete.thompson import ThompsonSampler
59+
from ax.generators.random.in_sample import InSampleUniformGenerator
5960
from ax.generators.random.sobol import SobolGenerator
6061
from ax.generators.random.uniform import UniformGenerator
6162
from ax.generators.torch.botorch_modular.generator import (
@@ -215,6 +216,11 @@ class GeneratorSetup(NamedTuple):
215216
generator_class=UniformGenerator,
216217
transforms=Cont_X_trans,
217218
),
219+
"InSampleUniform": GeneratorSetup(
220+
adapter_class=RandomAdapter,
221+
generator_class=InSampleUniformGenerator,
222+
transforms=Cont_X_trans,
223+
),
218224
"ST_MTGP": GeneratorSetup(
219225
adapter_class=TorchAdapter,
220226
generator_class=ModularBoTorchGenerator,
@@ -454,6 +460,7 @@ class Generators(GeneratorRegistryBase):
454460
EMPIRICAL_BAYES_THOMPSON = "EB"
455461
EB_ASHR = "EB_Ashr"
456462
UNIFORM = "Uniform"
463+
IN_SAMPLE_UNIFORM = "InSampleUniform"
457464
ST_MTGP = "ST_MTGP"
458465
BO_MIXED = "BO_MIXED"
459466

ax/generators/random/in_sample.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
from collections.abc import Callable
10+
11+
import numpy as np
12+
import numpy.typing as npt
13+
from ax.core.search_space import SearchSpaceDigest
14+
from ax.generators.random.base import RandomGenerator
15+
from ax.generators.types import TConfig
16+
17+
18+
class InSampleUniformGenerator(RandomGenerator):
    """Pick candidates uniformly at random from arms already on the experiment.

    Rather than sampling new points from the search space, this generator
    draws ``n`` distinct rows (i.e. without replacement) from the
    ``generated_points`` array handed to ``gen``. That array is expected to
    be built by the adapter from the in-design, non-failed arms on the
    experiment (deduplicated).

    Intended for model-free candidate selection, e.g. LILO
    (Language-in-the-Loop Optimization), where a labeling node needs to
    randomly pick previously observed configurations without fitting a
    surrogate model.

    See base ``RandomGenerator`` for a description of model attributes.
    """

    def gen(
        self,
        n: int,
        search_space_digest: SearchSpaceDigest,
        linear_constraints: tuple[npt.NDArray, npt.NDArray] | None = None,
        fixed_features: dict[int, float] | None = None,
        model_gen_options: TConfig | None = None,
        rounding_func: Callable[[npt.NDArray], npt.NDArray] | None = None,
        generated_points: npt.NDArray | None = None,
    ) -> tuple[npt.NDArray, npt.NDArray]:
        """Draw ``n`` rows from ``generated_points`` without replacement.

        Each call seeds a fresh NumPy generator with
        ``self.seed + self.init_position`` and then advances
        ``self.init_position`` by ``n``, so repeated calls on the same
        instance use different seeds while remaining reproducible.

        Args:
            n: Number of candidates to select.
            search_space_digest: A ``SearchSpaceDigest`` object containing
                metadata on the features in the datasets. Not read by this
                method.
            linear_constraints: Not used. Accepted for interface compatibility.
            fixed_features: Not used. Accepted for interface compatibility.
            model_gen_options: Not used. Accepted for interface compatibility.
            rounding_func: Not used. Accepted for interface compatibility.
            generated_points: A numpy array of shape ``(num_arms, d)`` holding
                the existing experiment arms to choose from.

        Returns:
            2-element tuple containing

            - ``(n, d)`` array with the selected rows of ``generated_points``.
            - An n-array of ones (uniform weights).

        Raises:
            ValueError: If ``generated_points`` is None or has fewer than
                ``n`` rows.
        """
        if generated_points is None:
            num_available = 0
        else:
            num_available = len(generated_points)

        # Fail before touching any RNG state so a rejected request leaves
        # ``init_position`` unchanged.
        if generated_points is None or n > num_available:
            raise ValueError(
                f"Cannot select {n} arms: only {num_available} eligible "
                f"arms available on the experiment."
            )

        # Offsetting the seed by ``init_position`` makes successive calls
        # draw from different (but deterministic) random streams.
        call_seed = self.seed + self.init_position
        rng = np.random.default_rng(seed=call_seed)
        row_indices = rng.choice(num_available, size=n, replace=False)
        self.init_position += n

        selected = generated_points[row_indices]
        return selected, np.ones(n)

    def _gen_samples(self, n: int, tunable_d: int, bounds: npt.NDArray) -> npt.NDArray:
        """Unsupported: this generator never samples new points.

        Raises:
            NotImplementedError: Always.
        """
        message = (
            "InSampleUniformGenerator selects from existing points "
            "and does not generate new samples."
        )
        raise NotImplementedError(message)
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import numpy as np
10+
from ax.core.search_space import SearchSpaceDigest
11+
from ax.generators.random.in_sample import InSampleUniformGenerator
12+
from ax.utils.common.testutils import TestCase
13+
14+
15+
class InSampleUniformGeneratorTest(TestCase):
    """Unit tests for ``InSampleUniformGenerator``."""

    def setUp(self) -> None:
        super().setUp()
        # Five distinct 2-d points acting as the pool of existing arms.
        self.generated_points = np.array(
            [
                [0.1, 0.2],
                [0.3, 0.4],
                [0.5, 0.6],
                [0.7, 0.8],
                [0.9, 1.0],
            ]
        )
        self.ssd = SearchSpaceDigest(
            feature_names=["x0", "x1"],
            bounds=[(0.0, 1.0), (0.0, 1.0)],
        )

    def test_basic_selection(self) -> None:
        """A subset request returns distinct rows drawn from the pool."""
        generator = InSampleUniformGenerator(seed=0)
        points, weights = generator.gen(
            n=2,
            search_space_digest=self.ssd,
            generated_points=self.generated_points,
        )
        self.assertEqual(points.shape, (2, 2))
        self.assertTrue(np.all(weights == 1.0))
        # Each selected row must be present in the original set.
        for row in points:
            self.assertTrue(
                any(np.array_equal(row, gp) for gp in self.generated_points)
            )
        # Selection is without replacement, so no row may be picked twice.
        self.assertEqual(len({tuple(row) for row in points.tolist()}), 2)

    def test_selects_all(self) -> None:
        """Selecting all points should return all of them (in some order)."""
        generator = InSampleUniformGenerator(seed=0)
        points, weights = generator.gen(
            n=5,
            search_space_digest=self.ssd,
            generated_points=self.generated_points,
        )
        self.assertEqual(points.shape, (5, 2))
        self.assertTrue(np.all(weights == 1.0))
        # Should be a permutation of the input.
        self.assertEqual(
            {tuple(row) for row in points.tolist()},
            {tuple(row) for row in self.generated_points.tolist()},
        )

    def test_not_enough_points(self) -> None:
        """Requesting more arms than available raises and leaves state alone."""
        generator = InSampleUniformGenerator(seed=0)
        with self.assertRaisesRegex(ValueError, "Cannot select 6 arms"):
            generator.gen(
                n=6,
                search_space_digest=self.ssd,
                generated_points=self.generated_points,
            )
        # A rejected request must not advance the RNG position.
        self.assertEqual(generator.init_position, 0)

    def test_no_generated_points(self) -> None:
        """Passing no pool at all is reported as zero eligible arms."""
        generator = InSampleUniformGenerator(seed=0)
        with self.assertRaisesRegex(ValueError, "Cannot select 1 arms: only 0"):
            generator.gen(
                n=1,
                search_space_digest=self.ssd,
                generated_points=None,
            )
        # A rejected request must not advance the RNG position.
        self.assertEqual(generator.init_position, 0)

    def test_reproducibility(self) -> None:
        """Same seed and init_position produce the same selection."""
        gen1 = InSampleUniformGenerator(seed=42)
        gen2 = InSampleUniformGenerator(seed=42)
        points1, _ = gen1.gen(
            n=2,
            search_space_digest=self.ssd,
            generated_points=self.generated_points,
        )
        points2, _ = gen2.gen(
            n=2,
            search_space_digest=self.ssd,
            generated_points=self.generated_points,
        )
        self.assertTrue(np.array_equal(points1, points2))

    def test_different_selections_across_calls(self) -> None:
        """Successive calls produce different selections (init_position advances)."""
        generator = InSampleUniformGenerator(seed=0)
        points1, _ = generator.gen(
            n=2,
            search_space_digest=self.ssd,
            generated_points=self.generated_points,
        )
        self.assertEqual(generator.init_position, 2)
        points2, _ = generator.gen(
            n=2,
            search_space_digest=self.ssd,
            generated_points=self.generated_points,
        )
        self.assertEqual(generator.init_position, 4)
        # Each call is seeded with ``seed + init_position`` (0, then 2), so
        # both draws are deterministic. With only 5 points and n=2 there are
        # just 20 ordered selections, so distinct seeds are not guaranteed to
        # differ in general; this pins that these two seeds do.
        self.assertFalse(np.array_equal(points1, points2))

    def test_gen_samples_raises(self) -> None:
        """``_gen_samples`` is unsupported and always raises."""
        generator = InSampleUniformGenerator()
        with self.assertRaises(NotImplementedError):
            generator._gen_samples(
                n=1,
                tunable_d=2,
                bounds=np.array([[0.0, 1.0], [0.0, 1.0]]),
            )

0 commit comments

Comments
 (0)