Skip to content

Commit 9facfc4

Browse files
ItsMrLin authored and meta-codesync[bot] committed
Add MinTrialsWithLILOInputHashCheck transition criterion (#4994)
Summary:
Pull Request resolved: #4994

Add a hash-aware transition criterion for LILO GS loops. Unlike plain `MinTrials`, which counts all completed trials from a node, `MinTrialsWithLILOInputHashCheck` only counts trials whose LILO input hash matches the current experiment state. This ensures the GS correctly transitions from LILO labeling → MBG only when enough *fresh* labels exist (labels produced under the current experiment data + LLM messages). Trials without a LILO input hash (non-LILO trials) are always counted, preserving backward compatibility.

Changes:
- Add `MinTrialsWithLILOInputHashCheck` class to `transition_criterion.py` that delegates hash computation to `get_current_lilo_hash` from `hash_utils` (replacing a private `_compute_current_hash` static method)
- Remove redundant pass-through `__init__` — the parent class handles all args
- Register in JSON encoder/decoder registries for serialization support
- Add tests verifying fresh/stale counting behavior

Reviewed By: saitcakmak

Differential Revision: D95284285
1 parent b8e31a3 commit 9facfc4

File tree

3 files changed

+180
-0
lines changed

3 files changed

+180
-0
lines changed

ax/generation_strategy/tests/test_transition_criterion.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,15 @@
77

88

99
from logging import Logger
10+
from unittest.mock import MagicMock
1011

1112
import pandas as pd
1213
from ax.adapter.registry import Generators
14+
from ax.core.arm import Arm
1315
from ax.core.auxiliary import AuxiliaryExperiment, AuxiliaryExperimentPurpose
1416
from ax.core.data import Data
17+
from ax.core.derived_metric import DerivedMetric
18+
from ax.core.experiment import Experiment
1519
from ax.core.trial_status import TrialStatus
1620
from ax.exceptions.core import DataRequiredError, UserInputError
1721
from ax.exceptions.generation_strategy import MaxParallelismReachedException
@@ -28,7 +32,10 @@
2832
MaxGenerationParallelism,
2933
MaxTrialsAwaitingData,
3034
MinTrials,
35+
MinTrialsWithLILOInputHashCheck,
3136
)
37+
from ax.utils.common.constants import Keys
38+
from ax.utils.common.hash_utils import compute_lilo_input_hash
3239
from ax.utils.common.logger import get_logger
3340
from ax.utils.common.testutils import TestCase
3441
from ax.utils.testing.core_stubs import (
@@ -41,6 +48,13 @@
4148
logger: Logger = get_logger(__name__)
4249

4350

51+
def _mock_node(trials_from_node: set[int]) -> MagicMock:
52+
"""Create a mock GenerationNode with a specified trials_from_node set."""
53+
node = MagicMock()
54+
node.trials_from_node = trials_from_node
55+
return node
56+
57+
4458
class TestTransitionCriterion(TestCase):
4559
def setUp(self) -> None:
4660
super().setUp()
@@ -614,3 +628,93 @@ def test_max_generation_parallelism_block_error(self) -> None:
614628
experiment=self.experiment,
615629
trials_from_node={0, 1, 2},
616630
)
631+
632+
def test_min_trials_with_lilo_input_hash_check(self) -> None:
    """Verify MinTrialsWithLILOInputHashCheck counts only hash-fresh trials.

    The subtests share one experiment and build on each other's state, so
    their order matters: trials are stamped, made stale, and finally
    invalidated by attaching new data.
    """
    exp = get_branin_experiment()

    # Register a DerivedMetric with pairwise name.
    pairwise_metric = DerivedMetric(
        name=Keys.PAIRWISE_PREFERENCE_QUERY.value,
        input_metric_names=["branin"],
    )
    exp.add_tracking_metric(pairwise_metric)

    criterion = MinTrialsWithLILOInputHashCheck(
        threshold=2,
        transition_to="next_node",
        only_in_statuses=[TrialStatus.COMPLETED],
    )

    # Helper to create and complete a trial with data.
    def _add_trial(idx: int, exp: Experiment = exp) -> None:
        trial = exp.new_batch_trial()
        trial.add_arm(
            Arm(name=f"{idx}_0", parameters={"x1": float(idx), "x2": 0.0})
        )
        trial.mark_running(no_runner_required=True)
        trial.mark_completed()
        exp.attach_data(
            Data(
                df=pd.DataFrame(
                    [
                        {
                            "trial_index": idx,
                            "arm_name": f"{idx}_0",
                            "metric_name": "branin",
                            "metric_signature": "branin",
                            "mean": float(idx),
                            "sem": 0.1,
                        }
                    ]
                )
            )
        )

    # Create 3 trials; none carries a hash stamp yet.
    for i in range(3):
        _add_trial(i)

    current_hash = compute_lilo_input_hash(exp, ["branin"])
    trials_from_node = {0, 1, 2}

    with self.subTest("no_hashes_all_count"):
        # No hash stamps → all counted (fallback behavior).
        count = criterion.num_contributing_to_threshold(exp, trials_from_node)
        self.assertEqual(count, 3)

    # Stamp trials 0 and 1 with the current hash.
    exp.trials[0]._properties[Keys.LILO_INPUT_HASH] = current_hash
    exp.trials[1]._properties[Keys.LILO_INPUT_HASH] = current_hash

    with self.subTest("fresh_hashes_count"):
        count = criterion.num_contributing_to_threshold(exp, trials_from_node)
        # Trials 0, 1 (fresh hash) + trial 2 (no hash → included).
        self.assertEqual(count, 3)

    # Make trial 1 stale.
    exp.trials[1]._properties[Keys.LILO_INPUT_HASH] = "stale_hash"

    with self.subTest("stale_hash_excluded"):
        count = criterion.num_contributing_to_threshold(exp, trials_from_node)
        # Trial 0 (fresh) + trial 2 (no hash) = 2. Trial 1 excluded.
        self.assertEqual(count, 2)
        self.assertTrue(criterion.is_met(exp, _mock_node(trials_from_node)))

    # Make trial 0 stale too.
    exp.trials[0]._properties[Keys.LILO_INPUT_HASH] = "another_stale"

    with self.subTest("not_enough_fresh"):
        count = criterion.num_contributing_to_threshold(exp, trials_from_node)
        # Only trial 2 (no hash) counts.
        self.assertEqual(count, 1)
        self.assertFalse(criterion.is_met(exp, _mock_node(trials_from_node)))

    with self.subTest("data_change_invalidates"):
        # Re-stamp trials 0 and 1 with the hash that is still current.
        # Without this they would already be stale from the steps above,
        # and the final assertion would pass even if new data did not
        # change the hash at all.
        exp.trials[0]._properties[Keys.LILO_INPUT_HASH] = current_hash
        exp.trials[1]._properties[Keys.LILO_INPUT_HASH] = current_hash
        count = criterion.num_contributing_to_threshold(exp, trials_from_node)
        # Trials 0, 1 (fresh again) + trial 2 (no hash).
        self.assertEqual(count, 3)

        # Add new data — changes the current hash, making ALL stamped
        # trials stale.
        _add_trial(3)
        trials_from_node.add(3)
        count = criterion.num_contributing_to_threshold(exp, trials_from_node)
        # Trials 0, 1 stale. Trials 2, 3 have no hash → included.
        self.assertEqual(count, 2)

ax/generation_strategy/transition_criterion.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
from ax.core.utils import get_trial_indices_with_required_metrics
1818
from ax.exceptions.core import DataRequiredError, UserInputError
1919
from ax.exceptions.generation_strategy import MaxParallelismReachedException
20+
from ax.utils.common.constants import Keys
21+
from ax.utils.common.hash_utils import get_current_lilo_hash
2022

2123
if TYPE_CHECKING:
2224
from ax.generation_strategy.generation_node import GenerationNode
@@ -644,6 +646,77 @@ def __init__(
644646
)
645647

646648

649+
class MinTrialsWithLILOInputHashCheck(TrialBasedCriterion):
    """Like ``MinTrials``, but only counts trials whose LILO input hash
    matches the current experiment state.

    LILO (Language-in-the-Loop) trials are stamped with a hash of the
    experiment state (metric data + LLM messages) at labeling time.
    When the experiment state changes (new data arrives, or the user updates
    LLM messages), old labels become stale. This criterion ensures that
    the transition fires only when enough *fresh* labels exist — i.e.,
    labels produced under the current experiment state.

    Freshness is checked against the *current* experiment state (not the
    most-recently-stamped LILO hash) because the LLM prompt includes a
    full experiment summary, so any change to input metric data alters the
    context under which labels would be produced and warrants relabeling.

    Trials without a LILO input hash (e.g., Sobol or MBG trials) are always
    counted, preserving backward compatibility with non-LILO workflows.

    Args:
        threshold: Minimum number of fresh trials required.
        transition_to: The GenerationNode to transition to when met.
        only_in_statuses: Only count trials with these statuses.
        not_in_statuses: Exclude trials with these statuses.
        use_all_trials_in_exp: Count all experiment trials, not just
            those from the current node.
        continue_trial_generation: Continue generating arms for the
            same trial after transition.
        count_only_trials_with_data: Only count trials that have data.
    """

    def num_contributing_to_threshold(
        self,
        experiment: Experiment,
        trials_from_node: set[int],
    ) -> int:
        """Count trials toward threshold, excluding those with stale hashes.

        First narrows the candidate trials (status filtering via
        ``all_trials_to_check``, then optionally to trials with data and to
        trials from the current node), then keeps only trials whose LILO
        input hash is absent or matches the current experiment state.

        Args:
            experiment: The experiment whose trials are inspected.
            trials_from_node: Indices of trials generated by the current
                node; ignored when ``use_all_trials_in_exp`` is truthy.

        Returns:
            The number of candidate trials that are unstamped or fresh. If
            ``get_current_lilo_hash`` returns ``None``, the plain candidate
            count is returned.
        """
        # Candidate trial indices after status-based filtering.
        candidates = self.all_trials_to_check(experiment)
        if self.count_only_trials_with_data:
            data_trial_indices = get_trial_indices_with_required_metrics(
                experiment=experiment,
                df=experiment.lookup_data().df,
                require_data_for_all_metrics=False,
            )
            candidates = candidates.intersection(data_trial_indices)

        if not bool(self.use_all_trials_in_exp):
            candidates = trials_from_node.intersection(candidates)

        # Nothing to count: skip the hash computation, the result is 0
        # regardless of the current hash.
        if not candidates:
            return 0

        # Further filter by LILO input hash freshness.
        current_hash = get_current_lilo_hash(experiment)
        if current_hash is None:
            # No pairwise DerivedMetric found — fall back to plain count.
            return len(candidates)

        trial_hashes = (
            experiment.trials[idx]._properties.get(Keys.LILO_INPUT_HASH)
            for idx in candidates
        )
        # Unstamped trials (hash is None) always count; stamped trials count
        # only when their hash equals the current one.
        return sum(
            1
            for trial_hash in trial_hashes
            if trial_hash is None or trial_hash == current_hash
        )
718+
719+
647720
class AuxiliaryExperimentCheck(TransitionCriterion):
648721
"""A class to transition from one GenerationNode to another by checking if certain
649722
types of Auxiliary Experiment purposes exists.

ax/storage/json_store/registry.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@
8282
MaxGenerationParallelism,
8383
MaxTrialsAwaitingData,
8484
MinTrials,
85+
MinTrialsWithLILOInputHashCheck,
8586
TransitionCriterion,
8687
)
8788
from ax.generators.torch.botorch_modular.acquisition import Acquisition
@@ -222,6 +223,7 @@
222223
MaxTrialsAwaitingData: pausing_criterion_to_dict,
223224
Metric: metric_to_dict,
224225
MinTrials: transition_criterion_to_dict,
226+
MinTrialsWithLILOInputHashCheck: transition_criterion_to_dict,
225227
AuxiliaryExperimentCheck: transition_criterion_to_dict,
226228
GeneratorSpec: generator_spec_to_dict,
227229
MultiObjective: multi_objective_to_dict,
@@ -350,6 +352,7 @@
350352
"MaxTrialsAwaitingData": MaxTrialsAwaitingData,
351353
"Metric": Metric,
352354
"MinTrials": MinTrials,
355+
"MinTrialsWithLILOInputHashCheck": MinTrialsWithLILOInputHashCheck,
353356
# DEPRECATED; backward compatibility for MinimumTrialsInStatus -> MinTrials
354357
"MinimumTrialsInStatus": MinTrials,
355358
"GeneratorRegistryBase": GeneratorRegistryBase,

0 commit comments

Comments
 (0)