From 79d9b427514ba68a70803fa8d206e454b59f0ff1 Mon Sep 17 00:00:00 2001 From: Zhiyuan Jerry Lin Date: Fri, 13 Mar 2026 10:11:15 -0700 Subject: [PATCH 1/3] LILO hash computation and stamping for data freshness MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Add hash-based data freshness tracking for LILO (Language-in-the-Loop) pairwise preference labels. When LILOPairwiseMetric produces labels, it now stamps a SHA-256 hash of the experiment's LILO inputs (metric data for input_metric_names + LLM messages) onto the trial's _properties. If any of these inputs change (new data arrives, data is updated, or the user modifies LLM messages), the hash changes, indicating that existing LILO labels are stale. Changes: - Add `LILO_INPUT_HASH` key to `Keys` enum in `constants.py` - Create `ax/utils/common/hash_utils.py` with `compute_lilo_input_hash` (standalone hash function) and `get_current_lilo_hash` (convenience helper that looks up the pairwise `DerivedMetric` on an experiment, extracts `input_metric_names`, and computes the hash — returns `None` if no pairwise metric is registered) - Stamp hash in `LILOPairwiseMetric._compute_derived_values` after producing labels - Add tests for hash determinism, sensitivity to data/message changes, stamping, and `get_current_lilo_hash` helper Differential Revision: D95284287 --- ax/utils/common/constants.py | 1 + ax/utils/common/hash_utils.py | 95 +++++++++++++++++++++++++++++++++++ 2 files changed, 96 insertions(+) create mode 100644 ax/utils/common/hash_utils.py diff --git a/ax/utils/common/constants.py b/ax/utils/common/constants.py index b0612ebf4fa..34ec22ba068 100644 --- a/ax/utils/common/constants.py +++ b/ax/utils/common/constants.py @@ -64,6 +64,7 @@ class Keys(StrEnum): FRAC_RANDOM = "frac_random" FULL_PARAMETERIZATION = "full_parameterization" IMMUTABLE_SEARCH_SPACE_AND_OPT_CONF = "immutable_search_space_and_opt_config" + LILO_INPUT_HASH = "lilo_input_hash" LILO_LABELING = "lilo_labeling" LLM_MESSAGES = "llm_messages" LONG_RUN = "long_run" diff --git a/ax/utils/common/hash_utils.py b/ax/utils/common/hash_utils.py new file mode 100644 index 00000000000..e4f51bea682 --- /dev/null +++ b/ax/utils/common/hash_utils.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +"""Hash utilities for LILO (Language-in-the-Loop) data freshness tracking.""" + +from __future__ import annotations + +import hashlib +from typing import TYPE_CHECKING + +from ax.core.derived_metric import DerivedMetric +from ax.utils.common.constants import Keys + +if TYPE_CHECKING: + from ax.core.experiment import Experiment + + +def compute_lilo_input_hash( + experiment: Experiment, + input_metric_names: list[str], +) -> str: + """Compute a hash of the experiment state relevant to LILO labeling. + + The hash captures two components: + 1. The experiment's LLM messages (user preferences that guide labeling). + 2. The observed metric data for ``input_metric_names`` across all trials. + + If any of these inputs change, the hash changes, indicating that existing + LILO labels are stale and should be excluded from model fitting. + + Args: + experiment: The experiment whose state to hash. + input_metric_names: Names of the base metrics whose observed values + are shown to the LLM for pairwise comparison. + + Returns: + An SHA-256 hex digest string representing the current LILO input state. + """ + parts: list[str] = [] + + # Component 1: LLM messages (canonical serialization). + for msg in experiment.llm_messages: + parts.append(f"{msg.role}:{msg.content}") + + parts.append("---") # Separator between components. + + # Component 2: Metric data for input_metric_names. + data = experiment.data + if not data.empty: + df = data.df + metric_df = df[df["metric_name"].isin(input_metric_names)] + if not metric_df.empty: + # Sort deterministically and serialize key columns. + sorted_df = metric_df.sort_values( + ["trial_index", "arm_name", "metric_name"] + ) + for _, row in sorted_df.iterrows(): + parts.append( + f"{row['trial_index']}|{row['arm_name']}|" + f"{row['metric_name']}|{row['mean']}|{row['sem']}" + ) + + content = "\n".join(parts) + return hashlib.sha256(content.encode("utf-8")).hexdigest() + + +def get_current_lilo_hash(experiment: Experiment) -> str | None: + """Compute the current LILO input hash, or ``None`` if not applicable. + + Looks up the pairwise preference metric on the experiment by name + (``Keys.PAIRWISE_PREFERENCE_QUERY``), checks that it is a + ``DerivedMetric`` (which provides ``input_metric_names``), and computes + the hash. In practice only ``LILOPairwiseMetric`` satisfies both + conditions; we check ``DerivedMetric`` rather than ``LILOPairwiseMetric`` + directly because the latter lives in ``ax.fb`` and cannot be imported + from this OSS module without creating a circular dependency. + + Returns: + The SHA-256 hex digest of the current LILO input state, or ``None`` + if no suitable pairwise ``DerivedMetric`` is registered. + """ + pairwise_metric_name = Keys.PAIRWISE_PREFERENCE_QUERY.value + metric = experiment.metrics.get(pairwise_metric_name) + # TODO: Replace `DerivedMetric` with `LILOPairwiseMetric` here. + if metric is None or not isinstance(metric, DerivedMetric): + return None + return compute_lilo_input_hash( + experiment=experiment, + input_metric_names=metric.input_metric_names, + ) From 5ecbbfbb2fed6cfb99ad559e0071e4c07b6a026e Mon Sep 17 00:00:00 2001 From: Zhiyuan Jerry Lin Date: Fri, 13 Mar 2026 10:11:15 -0700 Subject: [PATCH 2/3] Hash-based filtering of stale LILO data in adapter Summary: When building the RankingDataset for PairwiseGP model fitting, exclude LILO trial data whose input hash doesn't match the current experiment state. This ensures PairwiseGP is only fitted on labels that are consistent with the current metric data and LLM messages. Changes: - Add `_get_fresh_pairwise_trial_indices` helper to `adapter_utils.py`: uses `get_current_lilo_hash` from `hash_utils` to compute the current hash and returns trial indices whose stamped hash matches, or `None` if not a LILO experiment (preserving BOPE compatibility) - Filter pairwise data in `TorchAdapter._convert_experiment_data` before calling `prep_pairwise_data`, ensuring stale rows are excluded - Add tests for hash-based filtering logic Differential Revision: D95284286 --- ax/adapter/adapter_utils.py | 53 ++++++++++++++++ ax/adapter/tests/test_adapter_utils.py | 85 ++++++++++++++++++++++++++ ax/adapter/torch.py | 21 +++++++ 3 files changed, 159 insertions(+) diff --git a/ax/adapter/adapter_utils.py b/ax/adapter/adapter_utils.py index ba90e972a6a..c42889dca82 100644 --- a/ax/adapter/adapter_utils.py +++ b/ax/adapter/adapter_utils.py @@ -44,6 +44,8 @@ get_weighted_mc_objective_and_objective_thresholds, pareto_frontier_evaluator, ) +from ax.utils.common.constants import Keys +from ax.utils.common.hash_utils import get_current_lilo_hash from ax.utils.common.logger import get_logger from ax.utils.common.typeutils import ( assert_is_instance_of_tuple, @@ -1270,6 +1272,57 @@ def process_contextual_datasets( return contextual_datasets +def _get_fresh_pairwise_trial_indices( + experiment: Experiment, +) -> set[int] | None: + """Return trial indices whose pairwise labels match current experiment state. + + LILO (Language-in-the-Loop) trials are stamped with a hash of the + experiment state (metric data + LLM messages) at labeling time. When + the experiment state changes (new data, updated LLM messages), old labels + become stale and should be excluded from PairwiseGP model fitting. + + Design note: we intentionally compare each trial's stamped hash against + the *current* experiment state rather than the most-recently-stamped LILO + hash. This is because the LLM prompt includes a full experiment summary + (via ``get_llm_messages_with_experiment_summary``), so any change to + input metric data -- even from non-LILO trials -- alters the context + under which labels would be produced and warrants relabeling. + + Returns: + A set of trial indices whose LILO input hash matches the current + experiment state, or ``None`` if hash-based filtering is not + applicable (e.g., no trials have a LILO input hash -- the experiment + uses BOPE or another non-LILO pairwise workflow). + """ + # Collect trials that have been stamped with a LILO input hash. + stamped_trials = { + idx: trial + for idx, trial in experiment.trials.items() + if Keys.LILO_INPUT_HASH in trial._properties + } + if not stamped_trials: + # Not a LILO experiment -- no filtering needed. + return None + + current_hash = get_current_lilo_hash(experiment) + if current_hash is None: + return None + + fresh_indices: set[int] = set() + for idx, trial in experiment.trials.items(): + trial_hash = trial._properties.get(Keys.LILO_INPUT_HASH) + if trial_hash is None: + # Trial without hash (non-LILO trial) -- always include. + fresh_indices.add(idx) + elif trial_hash == current_hash: + # Hash matches -- labels are fresh. + fresh_indices.add(idx) + # else: stale hash -- excluded. + + return fresh_indices + + def prep_pairwise_data( X: Tensor, Y: Tensor, diff --git a/ax/adapter/tests/test_adapter_utils.py b/ax/adapter/tests/test_adapter_utils.py index c30b2605703..4b618db2981 100644 --- a/ax/adapter/tests/test_adapter_utils.py +++ b/ax/adapter/tests/test_adapter_utils.py @@ -8,9 +8,11 @@ import numpy as np +import pandas as pd import torch from ax.adapter.adapter_utils import ( _get_adapter_training_data, + _get_fresh_pairwise_trial_indices, arm_to_np_array, can_map_to_binary, extract_objective_weight_matrix, @@ -25,6 +27,9 @@ from ax.adapter.torch import TorchAdapter from ax.adapter.transforms.choice_encode import ChoiceToNumericChoice from ax.core.arm import Arm +from ax.core.data import Data +from ax.core.derived_metric import DerivedMetric +from ax.core.experiment import Experiment from ax.core.metric import Metric from ax.core.objective import MultiObjective, Objective, ScalarizedObjective from ax.core.optimization_config import MultiObjectiveOptimizationConfig @@ -34,6 +39,8 @@ from ax.core.types import ComparisonOp from ax.exceptions.core import UserInputError from ax.generators.torch.botorch_modular.generator import BoTorchGenerator +from ax.utils.common.constants import Keys +from ax.utils.common.hash_utils import compute_lilo_input_hash from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import ( get_experiment_with_observations, @@ -555,3 +562,81 @@ def test_extract_objective_weight_matrix(self) -> None: ) result = extract_objective_weight_matrix(multi, outcomes) np.testing.assert_array_equal(result, [[1.0, 0.0, 0.0], [0.0, 0.0, -1.0]]) + + def test_get_fresh_pairwise_trial_indices(self) -> None: + """Verify _get_fresh_pairwise_trial_indices hash-based filtering.""" + search_space = get_search_space_for_range_values() + exp = Experiment(name="test", search_space=search_space) + + # Register a DerivedMetric with pairwise name so the function can + # look up input_metric_names. + pairwise_metric = DerivedMetric( + name=Keys.PAIRWISE_PREFERENCE_QUERY.value, + input_metric_names=["latency"], + ) + exp.add_tracking_metric(pairwise_metric) + + # Helper to create trial data. + def _attach( + trial_index: int, arms: dict[str, float], exp: Experiment = exp + ) -> None: + rows = [ + { + "trial_index": trial_index, + "arm_name": name, + "metric_name": "latency", + "metric_signature": "latency", + "mean": val, + "sem": 0.1, + } + for name, val in arms.items() + ] + exp.attach_data(Data(df=pd.DataFrame(rows))) + + # Create two trials with data. + for i in range(2): + trial = exp.new_batch_trial() + trial.add_arm(Arm(name=f"{i}_0", parameters={"x": float(i)})) + trial.mark_running(no_runner_required=True) + trial.mark_completed() + _attach(i, {f"{i}_0": float(i + 1)}) + + with self.subTest("no_hashes_returns_none"): + # No trials have LILO_INPUT_HASH -- not a LILO experiment. + result = _get_fresh_pairwise_trial_indices(exp) + self.assertIsNone(result) + + # Stamp trial 0 with the current hash. + current_hash = compute_lilo_input_hash(exp, ["latency"]) + exp.trials[0]._properties[Keys.LILO_INPUT_HASH] = current_hash + + with self.subTest("fresh_hash_included"): + result = _get_fresh_pairwise_trial_indices(exp) + assert result is not None + self.assertIn(0, result) + # Trial 1 has no hash -- always included. + self.assertIn(1, result) + + # Stamp trial 1 with a stale hash. + exp.trials[1]._properties[Keys.LILO_INPUT_HASH] = "stale_hash_value" + + with self.subTest("stale_hash_excluded"): + result = _get_fresh_pairwise_trial_indices(exp) + assert result is not None + self.assertIn(0, result) + self.assertNotIn(1, result) + + with self.subTest("all_stale"): + # Make both hashes stale by adding new data. + trial2 = exp.new_batch_trial() + trial2.add_arm(Arm(name="2_0", parameters={"x": 10.0})) + trial2.mark_running(no_runner_required=True) + trial2.mark_completed() + _attach(2, {"2_0": 999.0}) + # Now both trial 0 and trial 1 have stale hashes. + result = _get_fresh_pairwise_trial_indices(exp) + assert result is not None + # Trial 0 and 1 are stale, trial 2 has no hash -- included. + self.assertNotIn(0, result) + self.assertNotIn(1, result) + self.assertIn(2, result) diff --git a/ax/adapter/torch.py b/ax/adapter/torch.py index 31c583061bb..d72295e90d0 100644 --- a/ax/adapter/torch.py +++ b/ax/adapter/torch.py @@ -17,6 +17,7 @@ import numpy.typing as npt import torch from ax.adapter.adapter_utils import ( + _get_fresh_pairwise_trial_indices, arm_to_np_array, array_to_observation_data, extract_objective_thresholds, @@ -468,6 +469,26 @@ def _convert_experiment_data( Yvar = torch.from_numpy(sem).double().square().view(-1, 1) group_indices = torch.from_numpy(trial_indices_np[to_keep]) if outcome == Keys.PAIRWISE_PREFERENCE_QUERY.value: + # Filter out stale LILO trials whose input hash no longer + # matches the current experiment state. + fresh_indices = _get_fresh_pairwise_trial_indices( + experiment=self._experiment, + ) + if fresh_indices is not None: + fresh_mask = torch.tensor( + [int(gi.item()) in fresh_indices for gi in group_indices], + dtype=torch.bool, + ) + X = X[fresh_mask] + Y = Y[fresh_mask] + group_indices = group_indices[fresh_mask] + # Narrow the NaN-filtered to_keep mask further so + # candidate_metadata stays aligned. + to_keep_indices = np.where(to_keep)[0] + fresh_mask_np = fresh_mask.numpy() + to_keep = np.zeros_like(to_keep) + to_keep[to_keep_indices[fresh_mask_np]] = True + dataset = prep_pairwise_data( X=X.to(device=self.device), Y=Y.to(dtype=torch.long, device=self.device), From 4d64f8798e5b4823888a55060cbdbf89b742144f Mon Sep 17 00:00:00 2001 From: Jerry Lin Date: Fri, 13 Mar 2026 10:16:03 -0700 Subject: [PATCH 3/3] Add FreshLILOLabelCheck transition criterion (#4994) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/4994 Add a hash-aware transition criterion for LILO GS loops. `FreshLILOLabelCheck` counts only trials whose LILO input hash matches the current experiment state, ensuring transitions are gated on *fresh* labels (produced under current data + LLM messages). The `require_sufficient` flag controls the transition direction: - `require_sufficient=True` (LILO_LABELING -> MBG): is_met when fresh count >= threshold. "Enough fresh labels -- proceed to BO generation." - `require_sufficient=False` (MBG -> LILO_LABELING): is_met when fresh count < threshold. "Labels are stale -- relabel before generating." Non-LILO experiments (no pairwise DerivedMetric) short-circuit: `require_sufficient=True` -> always met, `require_sufficient=False` -> never met. This prevents false relabeling triggers on non-LILO experiments. Reviewed By: saitcakmak Differential Revision: D95284285 --- .../tests/test_transition_criterion.py | 200 ++++++++++++++++++ .../transition_criterion.py | 131 ++++++++++++ ax/storage/json_store/registry.py | 3 + 3 files changed, 334 insertions(+) diff --git a/ax/generation_strategy/tests/test_transition_criterion.py b/ax/generation_strategy/tests/test_transition_criterion.py index 33efc1b6a9c..e7fd2b6428d 100644 --- a/ax/generation_strategy/tests/test_transition_criterion.py +++ b/ax/generation_strategy/tests/test_transition_criterion.py @@ -7,11 +7,15 @@ from logging import Logger +from unittest.mock import MagicMock import pandas as pd from ax.adapter.registry import Generators +from ax.core.arm import Arm from ax.core.auxiliary import AuxiliaryExperiment, AuxiliaryExperimentPurpose from ax.core.data import Data +from ax.core.derived_metric import DerivedMetric +from ax.core.experiment import Experiment from ax.core.trial_status import TrialStatus from ax.exceptions.core import DataRequiredError, UserInputError from ax.exceptions.generation_strategy import MaxParallelismReachedException @@ -24,11 +28,14 @@ from ax.generation_strategy.transition_criterion import ( AutoTransitionAfterGen, AuxiliaryExperimentCheck, + FreshLILOLabelCheck, IsSingleObjective, MaxGenerationParallelism, MaxTrialsAwaitingData, MinTrials, ) +from ax.utils.common.constants import Keys +from ax.utils.common.hash_utils import compute_lilo_input_hash from ax.utils.common.logger import get_logger from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import ( @@ -41,6 +48,13 @@ logger: Logger = get_logger(__name__) +def _mock_node(trials_from_node: set[int]) -> MagicMock: + """Create a mock GenerationNode with a specified trials_from_node set.""" + node = MagicMock() + node.trials_from_node = trials_from_node + return node + + class TestTransitionCriterion(TestCase): def setUp(self) -> None: super().setUp() @@ -614,3 +628,189 @@ def test_max_generation_parallelism_block_error(self) -> None: experiment=self.experiment, trials_from_node={0, 1, 2}, ) + + def test_fresh_lilo_label_check(self) -> None: + """Verify FreshLILOLabelCheck counts only hash-fresh trials.""" + exp = get_branin_experiment() + + # Register a DerivedMetric with pairwise name. + pairwise_metric = DerivedMetric( + name=Keys.PAIRWISE_PREFERENCE_QUERY.value, + input_metric_names=["branin"], + ) + exp.add_tracking_metric(pairwise_metric) + + criterion = FreshLILOLabelCheck( + threshold=2, + transition_to="next_node", + only_in_statuses=[TrialStatus.COMPLETED], + ) + + # Helper to create and complete a trial with data. + def _add_trial(idx: int, exp: Experiment = exp) -> None: + trial = exp.new_batch_trial() + trial.add_arm( + Arm(name=f"{idx}_0", parameters={"x1": float(idx), "x2": 0.0}) + ) + trial.mark_running(no_runner_required=True) + trial.mark_completed() + exp.attach_data( + Data( + df=pd.DataFrame( + [ + { + "trial_index": idx, + "arm_name": f"{idx}_0", + "metric_name": "branin", + "metric_signature": "branin", + "mean": float(idx), + "sem": 0.1, + } + ] + ) + ) + ) + + # Create 3 trials, stamp first 2 with current hash. + for i in range(3): + _add_trial(i) + + current_hash = compute_lilo_input_hash(exp, ["branin"]) + trials_from_node = {0, 1, 2} + + with self.subTest("no_hashes_none_count"): + # No hash stamps → no trials counted (only LILO trials with + # a matching hash contribute). + count = criterion.num_contributing_to_threshold(exp, trials_from_node) + self.assertEqual(count, 0) + + # Stamp trials 0 and 1 with the current hash. + exp.trials[0]._properties[Keys.LILO_INPUT_HASH] = current_hash + exp.trials[1]._properties[Keys.LILO_INPUT_HASH] = current_hash + + with self.subTest("fresh_hashes_count"): + count = criterion.num_contributing_to_threshold(exp, trials_from_node) + # Trials 0, 1 (fresh hash). Trial 2 (no hash → excluded). + self.assertEqual(count, 2) + + # Make trial 1 stale. + exp.trials[1]._properties[Keys.LILO_INPUT_HASH] = "stale_hash" + + with self.subTest("stale_hash_excluded"): + count = criterion.num_contributing_to_threshold(exp, trials_from_node) + # Trial 0 (fresh). Trial 1 (stale) and trial 2 (no hash) excluded. + self.assertEqual(count, 1) + self.assertFalse(criterion.is_met(exp, _mock_node(trials_from_node))) + + # Make trial 0 stale too. + exp.trials[0]._properties[Keys.LILO_INPUT_HASH] = "another_stale" + + with self.subTest("not_enough_fresh"): + count = criterion.num_contributing_to_threshold(exp, trials_from_node) + # All stamped trials are stale, trial 2 has no hash → 0. + self.assertEqual(count, 0) + self.assertFalse(criterion.is_met(exp, _mock_node(trials_from_node))) + + with self.subTest("data_change_invalidates"): + # Add new data — changes the current hash, making ALL stamped + # trials stale. + _add_trial(3) + trials_from_node.add(3) + count = criterion.num_contributing_to_threshold(exp, trials_from_node) + # Trials 0, 1 stale. Trials 2, 3 have no hash → excluded. + self.assertEqual(count, 0) + + def test_fresh_lilo_label_check_require_sufficient(self) -> None: + """Verify require_sufficient flag controls is_met direction.""" + exp = get_branin_experiment() + + pairwise_metric = DerivedMetric( + name=Keys.PAIRWISE_PREFERENCE_QUERY.value, + input_metric_names=["branin"], + ) + exp.add_tracking_metric(pairwise_metric) + + # Create 2 completed trials with data. + for i in range(2): + trial = exp.new_batch_trial() + trial.add_arm(Arm(name=f"{i}_0", parameters={"x1": float(i), "x2": 0.0})) + trial.mark_running(no_runner_required=True) + trial.mark_completed() + exp.attach_data( + Data( + df=pd.DataFrame( + [ + { + "trial_index": i, + "arm_name": f"{i}_0", + "metric_name": "branin", + "metric_signature": "branin", + "mean": float(i), + "sem": 0.1, + } + ] + ) + ) + ) + + current_hash = compute_lilo_input_hash(exp, ["branin"]) + # Stamp both trials as fresh. + exp.trials[0]._properties[Keys.LILO_INPUT_HASH] = current_hash + exp.trials[1]._properties[Keys.LILO_INPUT_HASH] = current_hash + trials_from_node = {0, 1} + + sufficient = FreshLILOLabelCheck( + threshold=2, + transition_to="MBG", + require_sufficient=True, + only_in_statuses=[TrialStatus.COMPLETED], + ) + insufficient = FreshLILOLabelCheck( + threshold=2, + transition_to="LILO", + require_sufficient=False, + only_in_statuses=[TrialStatus.COMPLETED], + ) + + with self.subTest("sufficient_met_when_enough_fresh"): + # 2 fresh >= threshold 2 → require_sufficient=True is met. + self.assertTrue(sufficient.is_met(exp, _mock_node(trials_from_node))) + + with self.subTest("insufficient_not_met_when_enough_fresh"): + # 2 fresh >= threshold 2 → require_sufficient=False is NOT met. + self.assertFalse(insufficient.is_met(exp, _mock_node(trials_from_node))) + + # Make trial 0 stale → only 1 fresh trial. + exp.trials[0]._properties[Keys.LILO_INPUT_HASH] = "stale" + + with self.subTest("sufficient_not_met_when_stale"): + # 1 fresh < threshold 2 → require_sufficient=True is NOT met. + self.assertFalse(sufficient.is_met(exp, _mock_node(trials_from_node))) + + with self.subTest("insufficient_met_when_stale"): + # 1 fresh < threshold 2 → require_sufficient=False IS met. + self.assertTrue(insufficient.is_met(exp, _mock_node(trials_from_node))) + + def test_fresh_lilo_label_check_non_lilo_fallback(self) -> None: + """Non-LILO experiment: require_sufficient=True always met, + require_sufficient=False never met.""" + exp = get_branin_experiment() + # No pairwise DerivedMetric registered — non-LILO experiment. + trials_from_node: set[int] = set() + + sufficient = FreshLILOLabelCheck( + threshold=32, + transition_to="MBG", + require_sufficient=True, + ) + insufficient = FreshLILOLabelCheck( + threshold=32, + transition_to="LILO", + require_sufficient=False, + ) + + with self.subTest("non_lilo_sufficient_always_met"): + self.assertTrue(sufficient.is_met(exp, _mock_node(trials_from_node))) + + with self.subTest("non_lilo_insufficient_never_met"): + self.assertFalse(insufficient.is_met(exp, _mock_node(trials_from_node))) diff --git a/ax/generation_strategy/transition_criterion.py b/ax/generation_strategy/transition_criterion.py index 887da17d8aa..7f740b529a8 100644 --- a/ax/generation_strategy/transition_criterion.py +++ b/ax/generation_strategy/transition_criterion.py @@ -17,6 +17,8 @@ from ax.core.utils import get_trial_indices_with_required_metrics from ax.exceptions.core import DataRequiredError, UserInputError from ax.exceptions.generation_strategy import MaxParallelismReachedException +from ax.utils.common.constants import Keys +from ax.utils.common.hash_utils import get_current_lilo_hash if TYPE_CHECKING: from ax.generation_strategy.generation_node import GenerationNode @@ -644,6 +646,135 @@ def __init__( ) +class FreshLILOLabelCheck(TrialBasedCriterion): + """Transition criterion based on the freshness of LILO preference labels. + + LILO (Language-in-the-Loop) trials are stamped with a hash of the + experiment state (metric data + LLM messages) at labeling time. + When the experiment state changes (new data arrives, or the user updates + LLM messages), old labels become stale. This criterion gates transitions + based on how many *fresh* labels exist. + + The ``require_sufficient`` flag controls the direction: + + - **``require_sufficient=True``** (LILO_LABELING -> MBG): ``is_met`` + when the number of fresh labels >= ``threshold``. "We have enough + fresh labels -- proceed to BO generation." + - **``require_sufficient=False``** (MBG -> LILO_LABELING): ``is_met`` + when the number of fresh labels < ``threshold``. "Labels are stale + -- relabel before generating." + + **Non-LILO fallback** (no pairwise ``DerivedMetric`` on the experiment): + ``require_sufficient=True`` -> always met (proceed normally). + ``require_sufficient=False`` -> never met (never trigger relabeling). + The fallback short-circuits *before* the count comparison so that a + non-LILO experiment with fewer than ``threshold`` trials does not + falsely trigger relabeling. + + Args: + threshold: Number of fresh trials for the sufficiency check. + transition_to: The GenerationNode to transition to when met. + require_sufficient: If ``True``, ``is_met`` when fresh count >= + threshold. If ``False``, ``is_met`` when fresh count < + threshold. Defaults to ``True``. + only_in_statuses: Only count trials with these statuses. + not_in_statuses: Exclude trials with these statuses. + use_all_trials_in_exp: Count all experiment trials, not just + those from the current node. + continue_trial_generation: Continue generating arms for the + same trial after transition. + count_only_trials_with_data: Only count trials that have data. + """ + + def __init__( + self, + threshold: int, + transition_to: str, + require_sufficient: bool = True, + only_in_statuses: list[TrialStatus] | None = None, + not_in_statuses: list[TrialStatus] | None = None, + use_all_trials_in_exp: bool | None = False, + continue_trial_generation: bool | None = False, + count_only_trials_with_data: bool = False, + ) -> None: + self.require_sufficient = require_sufficient + super().__init__( + threshold=threshold, + transition_to=transition_to, + only_in_statuses=only_in_statuses, + not_in_statuses=not_in_statuses, + use_all_trials_in_exp=use_all_trials_in_exp, + continue_trial_generation=continue_trial_generation, + count_only_trials_with_data=count_only_trials_with_data, + ) + + def num_contributing_to_threshold( + self, + experiment: Experiment, + trials_from_node: set[int], + ) -> int: + """Count trials toward threshold, excluding those with stale hashes. + + First applies the standard status-based filtering from the base class, + then further filters to only trials whose LILO input hash matches + the current experiment state. + """ + # Get the base count of candidate trial indices (status-filtered). + all_trials = self.all_trials_to_check(experiment) + if self.count_only_trials_with_data: + data_trial_indices = get_trial_indices_with_required_metrics( + experiment=experiment, + df=experiment.lookup_data().df, + require_data_for_all_metrics=False, + ) + all_trials = all_trials.intersection(data_trial_indices) + + if not bool(self.use_all_trials_in_exp): + all_trials = trials_from_node.intersection(all_trials) + + # Further filter by LILO input hash freshness. + current_hash = get_current_lilo_hash(experiment) + if current_hash is None: + # No pairwise DerivedMetric found — fall back to plain count. + return len(all_trials) + + fresh_count = 0 + for idx in all_trials: + trial = experiment.trials[idx] + trial_hash = trial._properties.get(Keys.LILO_INPUT_HASH) + # Only count trials that have a LILO_INPUT_HASH (i.e., actual + # LILO labeling trials) and whose hash matches the current state. + # Trials without a hash (regular Sobol/MBG trials) are excluded + # so they don't inflate the fresh-label count. + if trial_hash is not None and trial_hash == current_hash: + fresh_count += 1 + + return fresh_count + + def is_met( + self, + experiment: Experiment, + curr_node: GenerationNode, + ) -> bool: + """Check whether the freshness condition is satisfied. + + For non-LILO experiments (no pairwise ``DerivedMetric``), this + short-circuits: ``require_sufficient=True`` → always met, + ``require_sufficient=False`` → never met. + """ + # Short-circuit for non-LILO experiments. + if get_current_lilo_hash(experiment) is None: + return self.require_sufficient + + count = self.num_contributing_to_threshold( + experiment=experiment, trials_from_node=curr_node.trials_from_node + ) + if self.require_sufficient: + return count >= self.threshold + else: + return count < self.threshold + + class AuxiliaryExperimentCheck(TransitionCriterion): """A class to transition from one GenerationNode to another by checking if certain types of Auxiliary Experiment purposes exists. diff --git a/ax/storage/json_store/registry.py b/ax/storage/json_store/registry.py index 3a7f4fcec97..4858dcbd0a7 100644 --- a/ax/storage/json_store/registry.py +++ b/ax/storage/json_store/registry.py @@ -78,6 +78,7 @@ from ax.generation_strategy.transition_criterion import ( AutoTransitionAfterGen, AuxiliaryExperimentCheck, + FreshLILOLabelCheck, IsSingleObjective, MaxGenerationParallelism, MaxTrialsAwaitingData, @@ -222,6 +223,7 @@ MaxTrialsAwaitingData: pausing_criterion_to_dict, Metric: metric_to_dict, MinTrials: transition_criterion_to_dict, + FreshLILOLabelCheck: transition_criterion_to_dict, AuxiliaryExperimentCheck: transition_criterion_to_dict, GeneratorSpec: generator_spec_to_dict, MultiObjective: multi_objective_to_dict, @@ -350,6 +352,7 @@ "MaxTrialsAwaitingData": MaxTrialsAwaitingData, "Metric": Metric, "MinTrials": MinTrials, + "FreshLILOLabelCheck": FreshLILOLabelCheck, # DEPRECATED; backward compatibility for MinimumTrialsInStatus -> MinTrials "MinimumTrialsInStatus": MinTrials, "GeneratorRegistryBase": GeneratorRegistryBase,