Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions ax/adapter/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from ax.generators.discrete.eb_thompson import EmpiricalBayesThompsonSampler
from ax.generators.discrete.full_factorial import FullFactorialGenerator
from ax.generators.discrete.thompson import ThompsonSampler
from ax.generators.random.in_sample import InSampleUniformGenerator
from ax.generators.random.sobol import SobolGenerator
from ax.generators.random.uniform import UniformGenerator
from ax.generators.torch.botorch_modular.generator import (
Expand Down Expand Up @@ -215,6 +216,11 @@ class GeneratorSetup(NamedTuple):
generator_class=UniformGenerator,
transforms=Cont_X_trans,
),
"InSampleUniform": GeneratorSetup(
adapter_class=RandomAdapter,
generator_class=InSampleUniformGenerator,
transforms=Cont_X_trans,
),
"ST_MTGP": GeneratorSetup(
adapter_class=TorchAdapter,
generator_class=ModularBoTorchGenerator,
Expand Down Expand Up @@ -454,6 +460,7 @@ class Generators(GeneratorRegistryBase):
EMPIRICAL_BAYES_THOMPSON = "EB"
EB_ASHR = "EB_Ashr"
UNIFORM = "Uniform"
IN_SAMPLE_UNIFORM = "InSampleUniform"
ST_MTGP = "ST_MTGP"
BO_MIXED = "BO_MIXED"

Expand Down
4 changes: 1 addition & 3 deletions ax/core/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -1880,11 +1880,9 @@ def supports_trial_type(self, trial_type: str | None) -> bool:
"""
return (
trial_type is None
# We temporarily allow "short run" and "long run" trial
# types in single-type experiments during development of
# a new ``GenerationStrategy`` that needs them.
or trial_type == Keys.SHORT_RUN
or trial_type == Keys.LONG_RUN
or trial_type == Keys.LILO_LABELING
)

def attach_trial(
Expand Down
12 changes: 12 additions & 0 deletions ax/core/tests/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,18 @@ def test_basic_batch_creation(self) -> None:
new_exp = get_experiment()
new_exp._attach_trial(batch)

def test_supports_trial_type(self) -> None:
    """Check which trial types a freshly built experiment reports as supported.

    ``None`` and the ``SHORT_RUN``, ``LONG_RUN`` and ``LILO_LABELING`` keys
    must be accepted, an arbitrary string must be rejected, and a batch
    trial created with ``LILO_LABELING`` must carry that trial type.
    """
    experiment = get_experiment()

    accepted_types = (None, Keys.SHORT_RUN, Keys.LONG_RUN, Keys.LILO_LABELING)
    for accepted in accepted_types:
        self.assertTrue(experiment.supports_trial_type(accepted))
    self.assertFalse(experiment.supports_trial_type("unsupported_type"))

    # A batch trial can be created directly with the LILO labeling type.
    lilo_batch = experiment.new_batch_trial(trial_type=Keys.LILO_LABELING)
    self.assertEqual(lilo_batch.trial_type, Keys.LILO_LABELING)

def test_repr(self) -> None:
self.assertEqual("Experiment(test)", str(self.experiment))

Expand Down
10 changes: 6 additions & 4 deletions ax/generation_strategy/generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,12 +179,14 @@ def __init__(
)
if len(generator_specs) > 1 and best_model_selector is None:
raise UserInputError(MISSING_MODEL_SELECTOR_MESSAGE)
if trial_type is not None and (
trial_type != Keys.SHORT_RUN and trial_type != Keys.LONG_RUN
if trial_type is not None and trial_type not in (
Keys.SHORT_RUN,
Keys.LONG_RUN,
Keys.LILO_LABELING,
):
raise NotImplementedError(
f"Trial type must be either {Keys.SHORT_RUN} or {Keys.LONG_RUN},"
f" got {trial_type}."
f"Trial type must be one of {Keys.SHORT_RUN}, {Keys.LONG_RUN},"
f" or {Keys.LILO_LABELING}, got {trial_type}."
)
# If possible, assign `_generator_spec_to_gen_from` right away, for use in
# `__repr__`
Expand Down
8 changes: 7 additions & 1 deletion ax/generation_strategy/tests/test_generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def test_input_constructor_none(self) -> None:
self.assertEqual(self.sobol_generation_node.input_constructors, {})

def test_incorrect_trial_type(self) -> None:
with self.assertRaisesRegex(NotImplementedError, "Trial type must be either"):
with self.assertRaisesRegex(NotImplementedError, "Trial type must be one of"):
GenerationNode(
name="test",
generator_specs=[self.sobol_generator_spec],
Expand All @@ -147,12 +147,18 @@ def test_init_with_trial_type(self) -> None:
generator_specs=[self.sobol_generator_spec],
trial_type=Keys.LONG_RUN,
)
node_lilo = GenerationNode(
name="test",
generator_specs=[self.sobol_generator_spec],
trial_type=Keys.LILO_LABELING,
)
node_default = GenerationNode(
name="test",
generator_specs=[self.sobol_generator_spec],
)
self.assertEqual(self.node_short._trial_type, Keys.SHORT_RUN)
self.assertEqual(node_long._trial_type, Keys.LONG_RUN)
self.assertEqual(node_lilo._trial_type, Keys.LILO_LABELING)
self.assertIsNone(node_default._trial_type)

def test_input_constructor(self) -> None:
Expand Down
83 changes: 83 additions & 0 deletions ax/generators/random/in_sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from collections.abc import Callable

import numpy as np
import numpy.typing as npt
from ax.core.search_space import SearchSpaceDigest
from ax.generators.random.base import RandomGenerator
from ax.generators.types import TConfig


class InSampleUniformGenerator(RandomGenerator):
    """Pick candidates at random from arms that already exist on the experiment.

    Instead of sampling new points, this generator draws ``n`` rows uniformly
    and without replacement from the ``generated_points`` array handed over
    by the adapter. That array holds the in-design, non-failed arms on the
    experiment (deduplicated).

    Intended for model-free candidate selection in use cases like LILO
    (Language-in-the-Loop Optimization), where a labeling node has to
    re-select previously observed configurations at random and no surrogate
    model is fit.

    See base ``RandomGenerator`` for a description of model attributes.
    """

    def gen(
        self,
        n: int,
        search_space_digest: SearchSpaceDigest,
        linear_constraints: tuple[npt.NDArray, npt.NDArray] | None = None,
        fixed_features: dict[int, float] | None = None,
        model_gen_options: TConfig | None = None,
        rounding_func: Callable[[npt.NDArray], npt.NDArray] | None = None,
        generated_points: npt.NDArray | None = None,
    ) -> tuple[npt.NDArray, npt.NDArray]:
        """Draw ``n`` rows from ``generated_points`` without replacement.

        Args:
            n: How many candidates to pick.
            search_space_digest: A ``SearchSpaceDigest`` with metadata on the
                features in the datasets. Not read by this method.
            linear_constraints: Ignored; present for interface compatibility.
            fixed_features: Ignored; present for interface compatibility.
            model_gen_options: Ignored; present for interface compatibility.
            rounding_func: Ignored; present for interface compatibility.
            generated_points: Array of shape ``(num_arms, d)`` with the
                existing experiment arms to choose from. Built by the adapter
                from in-design, non-failed arms (deduplicated).

        Returns:
            A 2-tuple of

            - the ``(n, d)`` array of chosen rows, and
            - an ``n``-array of ones (uniform weights).

        Raises:
            ValueError: If ``generated_points`` is None or contains fewer
                than ``n`` rows.
        """
        pool_size = len(generated_points) if generated_points is not None else 0
        if generated_points is None or pool_size < n:
            raise ValueError(
                f"Cannot select {n} arms: only {pool_size} eligible "
                f"arms available on the experiment."
            )

        # The RNG is re-seeded on every call from ``seed + init_position``;
        # advancing ``init_position`` below changes the seed for the next call.
        sampler = np.random.default_rng(seed=self.seed + self.init_position)
        chosen_rows = sampler.choice(pool_size, size=n, replace=False)
        self.init_position += n

        selected = generated_points[chosen_rows]
        uniform_weights = np.ones(n)
        return selected, uniform_weights

    def _gen_samples(
        self, n: int, tunable_d: int, bounds: npt.NDArray
    ) -> npt.NDArray:
        """Always raise: this generator never creates new samples.

        Raises:
            NotImplementedError: Unconditionally.
        """
        raise NotImplementedError(
            "InSampleUniformGenerator selects from existing points "
            "and does not generate new samples."
        )
123 changes: 123 additions & 0 deletions ax/generators/tests/test_in_sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import numpy as np
from ax.core.search_space import SearchSpaceDigest
from ax.generators.random.in_sample import InSampleUniformGenerator
from ax.utils.common.testutils import TestCase


class InSampleUniformGeneratorTest(TestCase):
    """Unit tests for ``InSampleUniformGenerator``."""

    def setUp(self) -> None:
        super().setUp()
        # Pool of five distinct two-dimensional points to select from.
        pool_rows = [
            [0.1, 0.2],
            [0.3, 0.4],
            [0.5, 0.6],
            [0.7, 0.8],
            [0.9, 1.0],
        ]
        self.generated_points = np.array(pool_rows)
        self.ssd = SearchSpaceDigest(
            feature_names=["x0", "x1"],
            bounds=[(0.0, 1.0), (0.0, 1.0)],
        )

    def _draw(
        self, generator: InSampleUniformGenerator, n: int
    ) -> tuple[np.ndarray, np.ndarray]:
        """Call ``generator.gen`` for ``n`` points against the fixture pool."""
        return generator.gen(
            n=n,
            search_space_digest=self.ssd,
            generated_points=self.generated_points,
        )

    def test_basic_selection(self) -> None:
        """A draw of two returns two pool rows with unit weights."""
        sampler = InSampleUniformGenerator(seed=0)
        selected, wts = self._draw(sampler, n=2)
        self.assertEqual(selected.shape, (2, 2))
        self.assertTrue(np.all(wts == 1.0))
        # Every returned row has to match some row of the pool exactly.
        for candidate in selected:
            found = any(
                np.array_equal(candidate, existing)
                for existing in self.generated_points
            )
            self.assertTrue(found)

    def test_selects_all(self) -> None:
        """Selecting all points should return all of them (in some order)."""
        sampler = InSampleUniformGenerator(seed=0)
        selected, wts = self._draw(sampler, n=5)
        self.assertEqual(selected.shape, (5, 2))
        self.assertTrue(np.all(wts == 1.0))
        # Compare as sets of tuples so that row order does not matter.
        selected_rows = {tuple(row) for row in selected.tolist()}
        pool_rows = {tuple(row) for row in self.generated_points.tolist()}
        self.assertEqual(selected_rows, pool_rows)

    def test_not_enough_points(self) -> None:
        """Asking for more rows than the pool holds raises ``ValueError``."""
        sampler = InSampleUniformGenerator(seed=0)
        with self.assertRaisesRegex(ValueError, "Cannot select 6 arms"):
            self._draw(sampler, n=6)

    def test_no_generated_points(self) -> None:
        """Passing ``generated_points=None`` raises ``ValueError``."""
        sampler = InSampleUniformGenerator(seed=0)
        with self.assertRaisesRegex(ValueError, "Cannot select 1 arms: only 0"):
            sampler.gen(
                n=1,
                search_space_digest=self.ssd,
                generated_points=None,
            )

    def test_reproducibility(self) -> None:
        """Same seed and init_position produce the same selection."""
        first_sampler = InSampleUniformGenerator(seed=42)
        second_sampler = InSampleUniformGenerator(seed=42)
        first_draw, _ = self._draw(first_sampler, n=2)
        second_draw, _ = self._draw(second_sampler, n=2)
        self.assertTrue(np.array_equal(first_draw, second_draw))

    def test_different_selections_across_calls(self) -> None:
        """Successive calls produce different selections (init_position advances)."""
        sampler = InSampleUniformGenerator(seed=0)
        draws = []
        # ``init_position`` is expected to grow by ``n`` after each call.
        for expected_position in (2, 4):
            drawn, _ = self._draw(sampler, n=2)
            draws.append(drawn)
            self.assertEqual(sampler.init_position, expected_position)
        # The two calls use different effective seeds, so the selections
        # are expected to differ.
        self.assertFalse(np.array_equal(draws[0], draws[1]))

    def test_gen_samples_raises(self) -> None:
        """``_gen_samples`` raises ``NotImplementedError``."""
        sampler = InSampleUniformGenerator()
        unit_bounds = np.array([[0.0, 1.0], [0.0, 1.0]])
        with self.assertRaises(NotImplementedError):
            sampler._gen_samples(n=1, tunable_d=2, bounds=unit_bounds)
1 change: 1 addition & 0 deletions ax/utils/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class Keys(StrEnum):
FRAC_RANDOM = "frac_random"
FULL_PARAMETERIZATION = "full_parameterization"
IMMUTABLE_SEARCH_SPACE_AND_OPT_CONF = "immutable_search_space_and_opt_config"
LILO_LABELING = "lilo_labeling"
LLM_MESSAGES = "llm_messages"
LONG_RUN = "long_run"
MAXIMIZE = "maximize"
Expand Down