Skip to content

Commit 218bf55

Browse files
ItsMrLinmeta-codesync[bot]
authored andcommitted
Fix InSampleUniformGenerator selecting arms from CANDIDATE trials (#5109)
Summary: Pull Request resolved: #5109 When LILO labeling generates pairwise comparisons, `InSampleUniformGenerator` selects 2 arms uniformly at random from `generated_points` built by `RandomAdapter`. Previously, this pool included arms from CANDIDATE trials (which have no observed data) via two sources: 1. `arms_by_signature_for_deduplication` (all non-FAILED arms) 2. `pending_observations` (CANDIDATE/STAGED/RUNNING arms appended for dedup) When a CANDIDATE trial arm was selected, `LILOPairwiseMetric` could not find source metric data for it, causing `fetch_data` to raise an `ExceptionGroup`. **Fix:** In `RandomAdapter._gen()`, when the generator is `InSampleUniformGenerator`: - Filter `arms_to_deduplicate` to only arms from `expecting_data` trials (COMPLETED, EARLY_STOPPED, RUNNING) - Skip appending `pending_observations` (which re-adds CANDIDATE arms) This is correct for all current and foreseeable use cases of `InSampleUniformGenerator`. The generator is used in two scenarios: LILO labeling (active) and potentially bandit candidate generation (not currently using it). In both cases, we need arms with observed data -- selecting an arm from a CANDIDATE or STAGED trial that has never been evaluated is never meaningful. Excluding pending observations is similarly safe: pending points exist to prevent regular generators (Sobol, BO) from re-suggesting in-flight arms, but for in-sample selection the entire point is to pick from already-observed arms. The change is gated behind an `isinstance` check -- no effect on other generators (Sobol, Uniform, etc.). Reviewed By: saitcakmak Differential Revision: D98627073 fbshipit-source-id: b1bf9f378aac1eb56907d32da687e5874d22b632
1 parent 3bcf8c4 commit 218bf55

3 files changed

Lines changed: 89 additions & 7 deletions

File tree

ax/adapter/random.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from ax.core.optimization_config import OptimizationConfig
2727
from ax.core.search_space import SearchSpace
2828
from ax.generators.random.base import RandomGenerator
29+
from ax.generators.random.in_sample import InSampleUniformGenerator
2930
from ax.generators.types import TConfig
3031

3132

@@ -97,8 +98,25 @@ def _gen(
9798
# Exclude out-of-design arms (which can only be manual arms
9899
# instead of adapter-generated arms).
99100
generated_points = None
101+
is_in_sample = isinstance(self.generator, InSampleUniformGenerator)
100102
if self.generator.deduplicate:
101103
arms_to_deduplicate = self._experiment.arms_by_signature_for_deduplication
104+
# For in-sample generators, restrict to arms from trials that
105+
# have or expect observed data (COMPLETED, EARLY_STOPPED,
106+
# RUNNING). This prevents selecting arms from CANDIDATE/STAGED
107+
# trials that have never been evaluated.
108+
if is_in_sample:
109+
expecting_sigs = {
110+
arm.signature
111+
for trial in self._experiment.trials.values()
112+
if trial.status.expecting_data
113+
for arm in trial.arms
114+
}
115+
arms_to_deduplicate = {
116+
sig: arm
117+
for sig, arm in arms_to_deduplicate.items()
118+
if sig in expecting_sigs
119+
}
102120
generated_obs = [
103121
ObservationFeatures.from_arm(arm=arm)
104122
for arm in arms_to_deduplicate.values()
@@ -108,9 +126,16 @@ def _gen(
108126
for t in self.transforms.values():
109127
generated_obs = t.transform_observation_features(generated_obs)
110128
# Add pending observations -- already transformed.
111-
generated_obs.extend(
112-
[obs for obs_list in pending_observations.values() for obs in obs_list]
113-
)
129+
# Skipped for in-sample generators: pending observations include
130+
# CANDIDATE arms that should not enter the selection pool.
131+
if not is_in_sample:
132+
generated_obs.extend(
133+
[
134+
obs
135+
for obs_list in pending_observations.values()
136+
for obs in obs_list
137+
]
138+
)
114139
if len(generated_obs) > 0:
115140
# Extract generated points array (n x d).
116141
generated_points = np.array(

ax/adapter/tests/test_random_adapter.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from ax.core.search_space import SearchSpace
2323
from ax.exceptions.core import SearchSpaceExhausted
2424
from ax.generators.random.base import RandomGenerator
25+
from ax.generators.random.in_sample import InSampleUniformGenerator
2526
from ax.generators.random.sobol import SobolGenerator
2627
from ax.generators.types import TConfig
2728
from ax.utils.common.testutils import TestCase
@@ -297,6 +298,60 @@ def test_generated_points(self) -> None:
297298
generated_points_all_out = mock_gen.call_args.kwargs["generated_points"]
298299
self.assertIsNone(generated_points_all_out)
299300

301+
def test_in_sample_excludes_non_data_bearing_trial_arms(self) -> None:
302+
"""For InSampleUniformGenerator, generated_points should only contain
303+
arms from trials with expecting_data status (COMPLETED, RUNNING,
304+
EARLY_STOPPED). Arms from CANDIDATE, FAILED, and ABANDONED trials
305+
must be excluded even though they exist on the experiment."""
306+
search_space = SearchSpace(self.parameters[:2])
307+
exp = Experiment(search_space=search_space)
308+
309+
# Trial 0: COMPLETED -- should be included.
310+
exp.new_trial().add_arm(Arm(parameters={"x": 0.5, "y": 1.5})).mark_running(
311+
no_runner_required=True
312+
)
313+
exp.trials[0].mark_completed()
314+
315+
# Trial 1: CANDIDATE -- should be excluded.
316+
exp.new_trial().add_arm(Arm(parameters={"x": 0.8, "y": 1.8}))
317+
318+
# Trial 2: RUNNING -- should be included.
319+
exp.new_trial().add_arm(Arm(parameters={"x": 0.3, "y": 1.3})).mark_running(
320+
no_runner_required=True
321+
)
322+
323+
# Trial 3: FAILED -- should be excluded.
324+
exp.new_trial().add_arm(Arm(parameters={"x": 0.1, "y": 1.1})).mark_running(
325+
no_runner_required=True
326+
)
327+
exp.trials[3].mark_failed()
328+
329+
# Trial 4: ABANDONED -- should be excluded.
330+
exp.new_trial().add_arm(Arm(parameters={"x": 0.9, "y": 1.9})).mark_running(
331+
no_runner_required=True
332+
)
333+
exp.trials[4].mark_abandoned()
334+
335+
generator = InSampleUniformGenerator(seed=0)
336+
adapter = RandomAdapter(
337+
experiment=exp,
338+
generator=generator,
339+
transforms=Cont_X_trans,
340+
)
341+
342+
with mock.patch.object(
343+
generator,
344+
"gen",
345+
wraps=generator.gen,
346+
) as mock_gen:
347+
adapter.gen(n=2)
348+
349+
# generated_points should have exactly 2 points (COMPLETED + RUNNING).
350+
# CANDIDATE, FAILED, and ABANDONED arms must be excluded.
351+
generated_points = mock_gen.call_args.kwargs["generated_points"]
352+
assert generated_points is not None
353+
self.assertEqual(len(generated_points), 2)
354+
300355
def test_generation_with_all_fixed(self) -> None:
301356
# Make sure candidate generation succeeds and returns correct parameters
302357
# when all parameters are fixed.

ax/generators/random/in_sample.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@ class InSampleUniformGenerator(RandomGenerator):
1919
"""Randomly select candidates from existing experiment arms.
2020
2121
Selects n arms uniformly at random without replacement from the
22-
``generated_points`` array passed by the adapter. This array contains
23-
the in-design, non-failed arms on the experiment (deduplicated).
22+
``generated_points`` array passed by the adapter. For this generator,
23+
the adapter restricts ``generated_points`` to arms from trials that
24+
have or expect observed data (``status.expecting_data``), excluding
25+
arms from CANDIDATE or STAGED trials that have never been evaluated.
2426
2527
Used for model-free candidate selection in use cases like LILO
2628
(Language-in-the-Loop Optimization), where a labeling node needs
@@ -51,8 +53,8 @@ def gen(
5153
model_gen_options: Not used. Accepted for interface compatibility.
5254
rounding_func: Not used. Accepted for interface compatibility.
5355
generated_points: A numpy array of shape ``(num_arms, d)`` containing
54-
the existing experiment arms to select from. Constructed by the
55-
adapter from in-design, non-failed arms (deduplicated).
56+
the existing experiment arms to select from. The adapter
57+
filters this to arms from trials with observed data.
5658
5759
Returns:
5860
2-element tuple containing

0 commit comments

Comments
 (0)