Skip to content

Commit 6ce2e7c

Browse files
ItsMrLinmeta-codesync[bot]
authored andcommitted
Add Keys.LILO_LABELING trial type for Language-in-the-Loop labeling trials (facebook#4986)
Summary: Pull Request resolved: facebook#4986 Add `Keys.LILO_LABELING` to identify LILO (Language-in-the-Loop) labeling trials in generation strategies. LILO trials collect pairwise preference labels via LLM calls and need a distinct trial type so they can be selectively marked STALE before relabeling rounds. - Add `LILO_LABELING = "lilo_labeling"` to `Keys` enum - Extend `GenerationNode.__init__` trial type validation to accept `LILO_LABELING` - Extend `Experiment.supports_trial_type()` to accept `LILO_LABELING` Reviewed By: saitcakmak Differential Revision: D94743845 fbshipit-source-id: 9526915c193f4ce510a90c9e528290818fc3fa27
1 parent 9f50a93 commit 6ce2e7c

File tree

5 files changed

+27
-8
lines changed

5 files changed

+27
-8
lines changed

ax/core/experiment.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1880,11 +1880,9 @@ def supports_trial_type(self, trial_type: str | None) -> bool:
18801880
"""
18811881
return (
18821882
trial_type is None
1883-
# We temporarily allow "short run" and "long run" trial
1884-
# types in single-type experiments during development of
1885-
# a new ``GenerationStrategy`` that needs them.
18861883
or trial_type == Keys.SHORT_RUN
18871884
or trial_type == Keys.LONG_RUN
1885+
or trial_type == Keys.LILO_LABELING
18881886
)
18891887

18901888
def attach_trial(

ax/core/tests/test_experiment.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,18 @@ def test_basic_batch_creation(self) -> None:
207207
new_exp = get_experiment()
208208
new_exp._attach_trial(batch)
209209

210+
def test_supports_trial_type(self) -> None:
211+
exp = get_experiment()
212+
self.assertTrue(exp.supports_trial_type(None))
213+
self.assertTrue(exp.supports_trial_type(Keys.SHORT_RUN))
214+
self.assertTrue(exp.supports_trial_type(Keys.LONG_RUN))
215+
self.assertTrue(exp.supports_trial_type(Keys.LILO_LABELING))
216+
self.assertFalse(exp.supports_trial_type("unsupported_type"))
217+
218+
# Verify LILO_LABELING trial type works with new_batch_trial
219+
batch = exp.new_batch_trial(trial_type=Keys.LILO_LABELING)
220+
self.assertEqual(batch.trial_type, Keys.LILO_LABELING)
221+
210222
def test_repr(self) -> None:
211223
self.assertEqual("Experiment(test)", str(self.experiment))
212224

ax/generation_strategy/generation_node.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -179,12 +179,14 @@ def __init__(
179179
)
180180
if len(generator_specs) > 1 and best_model_selector is None:
181181
raise UserInputError(MISSING_MODEL_SELECTOR_MESSAGE)
182-
if trial_type is not None and (
183-
trial_type != Keys.SHORT_RUN and trial_type != Keys.LONG_RUN
182+
if trial_type is not None and trial_type not in (
183+
Keys.SHORT_RUN,
184+
Keys.LONG_RUN,
185+
Keys.LILO_LABELING,
184186
):
185187
raise NotImplementedError(
186-
f"Trial type must be either {Keys.SHORT_RUN} or {Keys.LONG_RUN},"
187-
f" got {trial_type}."
188+
f"Trial type must be one of {Keys.SHORT_RUN}, {Keys.LONG_RUN},"
189+
f" or {Keys.LILO_LABELING}, got {trial_type}."
188190
)
189191
# If possible, assign `_generator_spec_to_gen_from` right away, for use in
190192
# `__repr__`

ax/generation_strategy/tests/test_generation_node.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def test_input_constructor_none(self) -> None:
134134
self.assertEqual(self.sobol_generation_node.input_constructors, {})
135135

136136
def test_incorrect_trial_type(self) -> None:
137-
with self.assertRaisesRegex(NotImplementedError, "Trial type must be either"):
137+
with self.assertRaisesRegex(NotImplementedError, "Trial type must be one of"):
138138
GenerationNode(
139139
name="test",
140140
generator_specs=[self.sobol_generator_spec],
@@ -147,12 +147,18 @@ def test_init_with_trial_type(self) -> None:
147147
generator_specs=[self.sobol_generator_spec],
148148
trial_type=Keys.LONG_RUN,
149149
)
150+
node_lilo = GenerationNode(
151+
name="test",
152+
generator_specs=[self.sobol_generator_spec],
153+
trial_type=Keys.LILO_LABELING,
154+
)
150155
node_default = GenerationNode(
151156
name="test",
152157
generator_specs=[self.sobol_generator_spec],
153158
)
154159
self.assertEqual(self.node_short._trial_type, Keys.SHORT_RUN)
155160
self.assertEqual(node_long._trial_type, Keys.LONG_RUN)
161+
self.assertEqual(node_lilo._trial_type, Keys.LILO_LABELING)
156162
self.assertIsNone(node_default._trial_type)
157163

158164
def test_input_constructor(self) -> None:

ax/utils/common/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ class Keys(StrEnum):
6464
FRAC_RANDOM = "frac_random"
6565
FULL_PARAMETERIZATION = "full_parameterization"
6666
IMMUTABLE_SEARCH_SPACE_AND_OPT_CONF = "immutable_search_space_and_opt_config"
67+
LILO_LABELING = "lilo_labeling"
6768
LLM_MESSAGES = "llm_messages"
6869
LONG_RUN = "long_run"
6970
MAXIMIZE = "maximize"

0 commit comments

Comments
 (0)