diff --git a/ax/adapter/adapter_utils.py b/ax/adapter/adapter_utils.py index ba90e972a6a..c42889dca82 100644 --- a/ax/adapter/adapter_utils.py +++ b/ax/adapter/adapter_utils.py @@ -44,6 +44,8 @@ get_weighted_mc_objective_and_objective_thresholds, pareto_frontier_evaluator, ) +from ax.utils.common.constants import Keys +from ax.utils.common.hash_utils import get_current_lilo_hash from ax.utils.common.logger import get_logger from ax.utils.common.typeutils import ( assert_is_instance_of_tuple, @@ -1270,6 +1272,57 @@ def process_contextual_datasets( return contextual_datasets +def _get_fresh_pairwise_trial_indices( + experiment: Experiment, +) -> set[int] | None: + """Return trial indices whose pairwise labels match current experiment state. + + LILO (Language-in-the-Loop) trials are stamped with a hash of the + experiment state (metric data + LLM messages) at labeling time. When + the experiment state changes (new data, updated LLM messages), old labels + become stale and should be excluded from PairwiseGP model fitting. + + Design note: we intentionally compare each trial's stamped hash against + the *current* experiment state rather than the most-recently-stamped LILO + hash. This is because the LLM prompt includes a full experiment summary + (via ``get_llm_messages_with_experiment_summary``), so any change to + input metric data -- even from non-LILO trials -- alters the context + under which labels would be produced and warrants relabeling. + + Returns: + A set of trial indices whose LILO input hash matches the current + experiment state, or ``None`` if hash-based filtering is not + applicable (e.g., no trials have a LILO input hash -- the experiment + uses BOPE or another non-LILO pairwise workflow). + """ + # Collect trials that have been stamped with a LILO input hash. + stamped_trials = { + idx: trial + for idx, trial in experiment.trials.items() + if Keys.LILO_INPUT_HASH in trial._properties + } + if not stamped_trials: + # Not a LILO experiment -- no filtering needed. + return None + + current_hash = get_current_lilo_hash(experiment) + if current_hash is None: + return None + + fresh_indices: set[int] = set() + for idx, trial in experiment.trials.items(): + trial_hash = trial._properties.get(Keys.LILO_INPUT_HASH) + if trial_hash is None: + # Trial without hash (non-LILO trial) -- always include. + fresh_indices.add(idx) + elif trial_hash == current_hash: + # Hash matches -- labels are fresh. + fresh_indices.add(idx) + # else: stale hash -- excluded. + + return fresh_indices + + def prep_pairwise_data( X: Tensor, Y: Tensor, diff --git a/ax/adapter/tests/test_adapter_utils.py b/ax/adapter/tests/test_adapter_utils.py index c30b2605703..4b618db2981 100644 --- a/ax/adapter/tests/test_adapter_utils.py +++ b/ax/adapter/tests/test_adapter_utils.py @@ -8,9 +8,11 @@ import numpy as np +import pandas as pd import torch from ax.adapter.adapter_utils import ( _get_adapter_training_data, + _get_fresh_pairwise_trial_indices, arm_to_np_array, can_map_to_binary, extract_objective_weight_matrix, @@ -25,6 +27,9 @@ from ax.adapter.torch import TorchAdapter from ax.adapter.transforms.choice_encode import ChoiceToNumericChoice from ax.core.arm import Arm +from ax.core.data import Data +from ax.core.derived_metric import DerivedMetric +from ax.core.experiment import Experiment from ax.core.metric import Metric from ax.core.objective import MultiObjective, Objective, ScalarizedObjective from ax.core.optimization_config import MultiObjectiveOptimizationConfig @@ -34,6 +39,8 @@ from ax.core.types import ComparisonOp from ax.exceptions.core import UserInputError from ax.generators.torch.botorch_modular.generator import BoTorchGenerator +from ax.utils.common.constants import Keys +from ax.utils.common.hash_utils import compute_lilo_input_hash from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import ( get_experiment_with_observations, @@ -555,3 +562,81 @@ def test_extract_objective_weight_matrix(self) -> None: ) result = extract_objective_weight_matrix(multi, outcomes) np.testing.assert_array_equal(result, [[1.0, 0.0, 0.0], [0.0, 0.0, -1.0]]) + + def test_get_fresh_pairwise_trial_indices(self) -> None: + """Verify _get_fresh_pairwise_trial_indices hash-based filtering.""" + search_space = get_search_space_for_range_values() + exp = Experiment(name="test", search_space=search_space) + + # Register a DerivedMetric with pairwise name so the function can + # look up input_metric_names. + pairwise_metric = DerivedMetric( + name=Keys.PAIRWISE_PREFERENCE_QUERY.value, + input_metric_names=["latency"], + ) + exp.add_tracking_metric(pairwise_metric) + + # Helper to create trial data. + def _attach( + trial_index: int, arms: dict[str, float], exp: Experiment = exp + ) -> None: + rows = [ + { + "trial_index": trial_index, + "arm_name": name, + "metric_name": "latency", + "metric_signature": "latency", + "mean": val, + "sem": 0.1, + } + for name, val in arms.items() + ] + exp.attach_data(Data(df=pd.DataFrame(rows))) + + # Create two trials with data. + for i in range(2): + trial = exp.new_batch_trial() + trial.add_arm(Arm(name=f"{i}_0", parameters={"x": float(i)})) + trial.mark_running(no_runner_required=True) + trial.mark_completed() + _attach(i, {f"{i}_0": float(i + 1)}) + + with self.subTest("no_hashes_returns_none"): + # No trials have LILO_INPUT_HASH -- not a LILO experiment. + result = _get_fresh_pairwise_trial_indices(exp) + self.assertIsNone(result) + + # Stamp trial 0 with the current hash. + current_hash = compute_lilo_input_hash(exp, ["latency"]) + exp.trials[0]._properties[Keys.LILO_INPUT_HASH] = current_hash + + with self.subTest("fresh_hash_included"): + result = _get_fresh_pairwise_trial_indices(exp) + assert result is not None + self.assertIn(0, result) + # Trial 1 has no hash -- always included. + self.assertIn(1, result) + + # Stamp trial 1 with a stale hash. + exp.trials[1]._properties[Keys.LILO_INPUT_HASH] = "stale_hash_value" + + with self.subTest("stale_hash_excluded"): + result = _get_fresh_pairwise_trial_indices(exp) + assert result is not None + self.assertIn(0, result) + self.assertNotIn(1, result) + + with self.subTest("all_stale"): + # Make both hashes stale by adding new data. + trial2 = exp.new_batch_trial() + trial2.add_arm(Arm(name="2_0", parameters={"x": 10.0})) + trial2.mark_running(no_runner_required=True) + trial2.mark_completed() + _attach(2, {"2_0": 999.0}) + # Now both trial 0 and trial 1 have stale hashes. + result = _get_fresh_pairwise_trial_indices(exp) + assert result is not None + # Trial 0 and 1 are stale, trial 2 has no hash -- included. + self.assertNotIn(0, result) + self.assertNotIn(1, result) + self.assertIn(2, result) diff --git a/ax/adapter/torch.py b/ax/adapter/torch.py index 31c583061bb..d72295e90d0 100644 --- a/ax/adapter/torch.py +++ b/ax/adapter/torch.py @@ -17,6 +17,7 @@ import numpy.typing as npt import torch from ax.adapter.adapter_utils import ( + _get_fresh_pairwise_trial_indices, arm_to_np_array, array_to_observation_data, extract_objective_thresholds, @@ -468,6 +469,26 @@ def _convert_experiment_data( Yvar = torch.from_numpy(sem).double().square().view(-1, 1) group_indices = torch.from_numpy(trial_indices_np[to_keep]) if outcome == Keys.PAIRWISE_PREFERENCE_QUERY.value: + # Filter out stale LILO trials whose input hash no longer + # matches the current experiment state. + fresh_indices = _get_fresh_pairwise_trial_indices( + experiment=self._experiment, + ) + if fresh_indices is not None: + fresh_mask = torch.tensor( + [int(gi.item()) in fresh_indices for gi in group_indices], + dtype=torch.bool, + ) + X = X[fresh_mask] + Y = Y[fresh_mask] + group_indices = group_indices[fresh_mask] + # Narrow the NaN-filtered to_keep mask further so + # candidate_metadata stays aligned. + to_keep_indices = np.where(to_keep)[0] + fresh_mask_np = fresh_mask.numpy() + to_keep = np.zeros_like(to_keep) + to_keep[to_keep_indices[fresh_mask_np]] = True + dataset = prep_pairwise_data( X=X.to(device=self.device), Y=Y.to(dtype=torch.long, device=self.device), diff --git a/ax/generation_strategy/tests/test_transition_criterion.py b/ax/generation_strategy/tests/test_transition_criterion.py index 33efc1b6a9c..e7fd2b6428d 100644 --- a/ax/generation_strategy/tests/test_transition_criterion.py +++ b/ax/generation_strategy/tests/test_transition_criterion.py @@ -7,11 +7,15 @@ from logging import Logger +from unittest.mock import MagicMock import pandas as pd from ax.adapter.registry import Generators +from ax.core.arm import Arm from ax.core.auxiliary import AuxiliaryExperiment, AuxiliaryExperimentPurpose from ax.core.data import Data +from ax.core.derived_metric import DerivedMetric +from ax.core.experiment import Experiment from ax.core.trial_status import TrialStatus from ax.exceptions.core import DataRequiredError, UserInputError from ax.exceptions.generation_strategy import MaxParallelismReachedException @@ -24,11 +28,14 @@ from ax.generation_strategy.transition_criterion import ( AutoTransitionAfterGen, AuxiliaryExperimentCheck, + FreshLILOLabelCheck, IsSingleObjective, MaxGenerationParallelism, MaxTrialsAwaitingData, MinTrials, ) +from ax.utils.common.constants import Keys +from ax.utils.common.hash_utils import compute_lilo_input_hash from ax.utils.common.logger import get_logger from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import ( @@ -41,6 +48,13 @@ logger: Logger = get_logger(__name__) +def _mock_node(trials_from_node: set[int]) -> MagicMock: + """Create a mock GenerationNode with a specified trials_from_node set.""" + node = MagicMock() + node.trials_from_node = trials_from_node + return node + + class TestTransitionCriterion(TestCase): def setUp(self) -> None: super().setUp() @@ -614,3 +628,189 @@ def test_max_generation_parallelism_block_error(self) -> None: experiment=self.experiment, trials_from_node={0, 1, 2}, ) + + def test_fresh_lilo_label_check(self) -> None: + """Verify FreshLILOLabelCheck counts only hash-fresh trials.""" + exp = get_branin_experiment() + + # Register a DerivedMetric with pairwise name. + pairwise_metric = DerivedMetric( + name=Keys.PAIRWISE_PREFERENCE_QUERY.value, + input_metric_names=["branin"], + ) + exp.add_tracking_metric(pairwise_metric) + + criterion = FreshLILOLabelCheck( + threshold=2, + transition_to="next_node", + only_in_statuses=[TrialStatus.COMPLETED], + ) + + # Helper to create and complete a trial with data. + def _add_trial(idx: int, exp: Experiment = exp) -> None: + trial = exp.new_batch_trial() + trial.add_arm( + Arm(name=f"{idx}_0", parameters={"x1": float(idx), "x2": 0.0}) + ) + trial.mark_running(no_runner_required=True) + trial.mark_completed() + exp.attach_data( + Data( + df=pd.DataFrame( + [ + { + "trial_index": idx, + "arm_name": f"{idx}_0", + "metric_name": "branin", + "metric_signature": "branin", + "mean": float(idx), + "sem": 0.1, + } + ] + ) + ) + ) + + # Create 3 trials, stamp first 2 with current hash. + for i in range(3): + _add_trial(i) + + current_hash = compute_lilo_input_hash(exp, ["branin"]) + trials_from_node = {0, 1, 2} + + with self.subTest("no_hashes_none_count"): + # No hash stamps → no trials counted (only LILO trials with + # a matching hash contribute). + count = criterion.num_contributing_to_threshold(exp, trials_from_node) + self.assertEqual(count, 0) + + # Stamp trials 0 and 1 with the current hash. + exp.trials[0]._properties[Keys.LILO_INPUT_HASH] = current_hash + exp.trials[1]._properties[Keys.LILO_INPUT_HASH] = current_hash + + with self.subTest("fresh_hashes_count"): + count = criterion.num_contributing_to_threshold(exp, trials_from_node) + # Trials 0, 1 (fresh hash). Trial 2 (no hash → excluded). + self.assertEqual(count, 2) + + # Make trial 1 stale. + exp.trials[1]._properties[Keys.LILO_INPUT_HASH] = "stale_hash" + + with self.subTest("stale_hash_excluded"): + count = criterion.num_contributing_to_threshold(exp, trials_from_node) + # Trial 0 (fresh). Trial 1 (stale) and trial 2 (no hash) excluded. + self.assertEqual(count, 1) + self.assertFalse(criterion.is_met(exp, _mock_node(trials_from_node))) + + # Make trial 0 stale too. + exp.trials[0]._properties[Keys.LILO_INPUT_HASH] = "another_stale" + + with self.subTest("not_enough_fresh"): + count = criterion.num_contributing_to_threshold(exp, trials_from_node) + # All stamped trials are stale, trial 2 has no hash → 0. + self.assertEqual(count, 0) + self.assertFalse(criterion.is_met(exp, _mock_node(trials_from_node))) + + with self.subTest("data_change_invalidates"): + # Add new data — changes the current hash, making ALL stamped + # trials stale. + _add_trial(3) + trials_from_node.add(3) + count = criterion.num_contributing_to_threshold(exp, trials_from_node) + # Trials 0, 1 stale. Trials 2, 3 have no hash → excluded. + self.assertEqual(count, 0) + + def test_fresh_lilo_label_check_require_sufficient(self) -> None: + """Verify require_sufficient flag controls is_met direction.""" + exp = get_branin_experiment() + + pairwise_metric = DerivedMetric( + name=Keys.PAIRWISE_PREFERENCE_QUERY.value, + input_metric_names=["branin"], + ) + exp.add_tracking_metric(pairwise_metric) + + # Create 2 completed trials with data. + for i in range(2): + trial = exp.new_batch_trial() + trial.add_arm(Arm(name=f"{i}_0", parameters={"x1": float(i), "x2": 0.0})) + trial.mark_running(no_runner_required=True) + trial.mark_completed() + exp.attach_data( + Data( + df=pd.DataFrame( + [ + { + "trial_index": i, + "arm_name": f"{i}_0", + "metric_name": "branin", + "metric_signature": "branin", + "mean": float(i), + "sem": 0.1, + } + ] + ) + ) + ) + + current_hash = compute_lilo_input_hash(exp, ["branin"]) + # Stamp both trials as fresh. + exp.trials[0]._properties[Keys.LILO_INPUT_HASH] = current_hash + exp.trials[1]._properties[Keys.LILO_INPUT_HASH] = current_hash + trials_from_node = {0, 1} + + sufficient = FreshLILOLabelCheck( + threshold=2, + transition_to="MBG", + require_sufficient=True, + only_in_statuses=[TrialStatus.COMPLETED], + ) + insufficient = FreshLILOLabelCheck( + threshold=2, + transition_to="LILO", + require_sufficient=False, + only_in_statuses=[TrialStatus.COMPLETED], + ) + + with self.subTest("sufficient_met_when_enough_fresh"): + # 2 fresh >= threshold 2 → require_sufficient=True is met. + self.assertTrue(sufficient.is_met(exp, _mock_node(trials_from_node))) + + with self.subTest("insufficient_not_met_when_enough_fresh"): + # 2 fresh >= threshold 2 → require_sufficient=False is NOT met. + self.assertFalse(insufficient.is_met(exp, _mock_node(trials_from_node))) + + # Make trial 0 stale → only 1 fresh trial. + exp.trials[0]._properties[Keys.LILO_INPUT_HASH] = "stale" + + with self.subTest("sufficient_not_met_when_stale"): + # 1 fresh < threshold 2 → require_sufficient=True is NOT met. + self.assertFalse(sufficient.is_met(exp, _mock_node(trials_from_node))) + + with self.subTest("insufficient_met_when_stale"): + # 1 fresh < threshold 2 → require_sufficient=False IS met. + self.assertTrue(insufficient.is_met(exp, _mock_node(trials_from_node))) + + def test_fresh_lilo_label_check_non_lilo_fallback(self) -> None: + """Non-LILO experiment: require_sufficient=True always met, + require_sufficient=False never met.""" + exp = get_branin_experiment() + # No pairwise DerivedMetric registered — non-LILO experiment. + trials_from_node: set[int] = set() + + sufficient = FreshLILOLabelCheck( + threshold=32, + transition_to="MBG", + require_sufficient=True, + ) + insufficient = FreshLILOLabelCheck( + threshold=32, + transition_to="LILO", + require_sufficient=False, + ) + + with self.subTest("non_lilo_sufficient_always_met"): + self.assertTrue(sufficient.is_met(exp, _mock_node(trials_from_node))) + + with self.subTest("non_lilo_insufficient_never_met"): + self.assertFalse(insufficient.is_met(exp, _mock_node(trials_from_node))) diff --git a/ax/generation_strategy/transition_criterion.py b/ax/generation_strategy/transition_criterion.py index 887da17d8aa..7f740b529a8 100644 --- a/ax/generation_strategy/transition_criterion.py +++ b/ax/generation_strategy/transition_criterion.py @@ -17,6 +17,8 @@ from ax.core.utils import get_trial_indices_with_required_metrics from ax.exceptions.core import DataRequiredError, UserInputError from ax.exceptions.generation_strategy import MaxParallelismReachedException +from ax.utils.common.constants import Keys +from ax.utils.common.hash_utils import get_current_lilo_hash if TYPE_CHECKING: from ax.generation_strategy.generation_node import GenerationNode @@ -644,6 +646,135 @@ def __init__( ) +class FreshLILOLabelCheck(TrialBasedCriterion): + """Transition criterion based on the freshness of LILO preference labels. + + LILO (Language-in-the-Loop) trials are stamped with a hash of the + experiment state (metric data + LLM messages) at labeling time. + When the experiment state changes (new data arrives, or the user updates + LLM messages), old labels become stale. This criterion gates transitions + based on how many *fresh* labels exist. + + The ``require_sufficient`` flag controls the direction: + + - **``require_sufficient=True``** (LILO_LABELING -> MBG): ``is_met`` + when the number of fresh labels >= ``threshold``. "We have enough + fresh labels -- proceed to BO generation." + - **``require_sufficient=False``** (MBG -> LILO_LABELING): ``is_met`` + when the number of fresh labels < ``threshold``. "Labels are stale + -- relabel before generating." + + **Non-LILO fallback** (no pairwise ``DerivedMetric`` on the experiment): + ``require_sufficient=True`` -> always met (proceed normally). + ``require_sufficient=False`` -> never met (never trigger relabeling). + The fallback short-circuits *before* the count comparison so that a + non-LILO experiment with fewer than ``threshold`` trials does not + falsely trigger relabeling. + + Args: + threshold: Number of fresh trials for the sufficiency check. + transition_to: The GenerationNode to transition to when met. + require_sufficient: If ``True``, ``is_met`` when fresh count >= + threshold. If ``False``, ``is_met`` when fresh count < + threshold. Defaults to ``True``. + only_in_statuses: Only count trials with these statuses. + not_in_statuses: Exclude trials with these statuses. + use_all_trials_in_exp: Count all experiment trials, not just + those from the current node. + continue_trial_generation: Continue generating arms for the + same trial after transition. + count_only_trials_with_data: Only count trials that have data. + """ + + def __init__( + self, + threshold: int, + transition_to: str, + require_sufficient: bool = True, + only_in_statuses: list[TrialStatus] | None = None, + not_in_statuses: list[TrialStatus] | None = None, + use_all_trials_in_exp: bool | None = False, + continue_trial_generation: bool | None = False, + count_only_trials_with_data: bool = False, + ) -> None: + self.require_sufficient = require_sufficient + super().__init__( + threshold=threshold, + transition_to=transition_to, + only_in_statuses=only_in_statuses, + not_in_statuses=not_in_statuses, + use_all_trials_in_exp=use_all_trials_in_exp, + continue_trial_generation=continue_trial_generation, + count_only_trials_with_data=count_only_trials_with_data, + ) + + def num_contributing_to_threshold( + self, + experiment: Experiment, + trials_from_node: set[int], + ) -> int: + """Count trials toward threshold, excluding those with stale hashes. + + First applies the standard status-based filtering from the base class, + then further filters to only trials whose LILO input hash matches + the current experiment state. + """ + # Get the base count of candidate trial indices (status-filtered). + all_trials = self.all_trials_to_check(experiment) + if self.count_only_trials_with_data: + data_trial_indices = get_trial_indices_with_required_metrics( + experiment=experiment, + df=experiment.lookup_data().df, + require_data_for_all_metrics=False, + ) + all_trials = all_trials.intersection(data_trial_indices) + + if not bool(self.use_all_trials_in_exp): + all_trials = trials_from_node.intersection(all_trials) + + # Further filter by LILO input hash freshness. + current_hash = get_current_lilo_hash(experiment) + if current_hash is None: + # No pairwise DerivedMetric found — fall back to plain count. + return len(all_trials) + + fresh_count = 0 + for idx in all_trials: + trial = experiment.trials[idx] + trial_hash = trial._properties.get(Keys.LILO_INPUT_HASH) + # Only count trials that have a LILO_INPUT_HASH (i.e., actual + # LILO labeling trials) and whose hash matches the current state. + # Trials without a hash (regular Sobol/MBG trials) are excluded + # so they don't inflate the fresh-label count. + if trial_hash is not None and trial_hash == current_hash: + fresh_count += 1 + + return fresh_count + + def is_met( + self, + experiment: Experiment, + curr_node: GenerationNode, + ) -> bool: + """Check whether the freshness condition is satisfied. + + For non-LILO experiments (no pairwise ``DerivedMetric``), this + short-circuits: ``require_sufficient=True`` → always met, + ``require_sufficient=False`` → never met. + """ + # Short-circuit for non-LILO experiments. + if get_current_lilo_hash(experiment) is None: + return self.require_sufficient + + count = self.num_contributing_to_threshold( + experiment=experiment, trials_from_node=curr_node.trials_from_node + ) + if self.require_sufficient: + return count >= self.threshold + else: + return count < self.threshold + + class AuxiliaryExperimentCheck(TransitionCriterion): """A class to transition from one GenerationNode to another by checking if certain types of Auxiliary Experiment purposes exists. diff --git a/ax/storage/json_store/registry.py b/ax/storage/json_store/registry.py index 3a7f4fcec97..4858dcbd0a7 100644 --- a/ax/storage/json_store/registry.py +++ b/ax/storage/json_store/registry.py @@ -78,6 +78,7 @@ from ax.generation_strategy.transition_criterion import ( AutoTransitionAfterGen, AuxiliaryExperimentCheck, + FreshLILOLabelCheck, IsSingleObjective, MaxGenerationParallelism, MaxTrialsAwaitingData, @@ -222,6 +223,7 @@ MaxTrialsAwaitingData: pausing_criterion_to_dict, Metric: metric_to_dict, MinTrials: transition_criterion_to_dict, + FreshLILOLabelCheck: transition_criterion_to_dict, AuxiliaryExperimentCheck: transition_criterion_to_dict, GeneratorSpec: generator_spec_to_dict, MultiObjective: multi_objective_to_dict, @@ -350,6 +352,7 @@ "MaxTrialsAwaitingData": MaxTrialsAwaitingData, "Metric": Metric, "MinTrials": MinTrials, + "FreshLILOLabelCheck": FreshLILOLabelCheck, # DEPRECATED; backward compatibility for MinimumTrialsInStatus -> MinTrials "MinimumTrialsInStatus": MinTrials, "GeneratorRegistryBase": GeneratorRegistryBase, diff --git a/ax/utils/common/constants.py b/ax/utils/common/constants.py index b0612ebf4fa..34ec22ba068 100644 --- a/ax/utils/common/constants.py +++ b/ax/utils/common/constants.py @@ -64,6 +64,7 @@ class Keys(StrEnum): FRAC_RANDOM = "frac_random" FULL_PARAMETERIZATION = "full_parameterization" IMMUTABLE_SEARCH_SPACE_AND_OPT_CONF = "immutable_search_space_and_opt_config" + LILO_INPUT_HASH = "lilo_input_hash" LILO_LABELING = "lilo_labeling" LLM_MESSAGES = "llm_messages" LONG_RUN = "long_run" diff --git a/ax/utils/common/hash_utils.py b/ax/utils/common/hash_utils.py new file mode 100644 index 00000000000..e4f51bea682 --- /dev/null +++ b/ax/utils/common/hash_utils.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +"""Hash utilities for LILO (Language-in-the-Loop) data freshness tracking.""" + +from __future__ import annotations + +import hashlib +from typing import TYPE_CHECKING + +from ax.core.derived_metric import DerivedMetric +from ax.utils.common.constants import Keys + +if TYPE_CHECKING: + from ax.core.experiment import Experiment + + +def compute_lilo_input_hash( + experiment: Experiment, + input_metric_names: list[str], +) -> str: + """Compute a hash of the experiment state relevant to LILO labeling. + + The hash captures two components: + 1. The experiment's LLM messages (user preferences that guide labeling). + 2. The observed metric data for ``input_metric_names`` across all trials. + + If any of these inputs change, the hash changes, indicating that existing + LILO labels are stale and should be excluded from model fitting. + + Args: + experiment: The experiment whose state to hash. + input_metric_names: Names of the base metrics whose observed values + are shown to the LLM for pairwise comparison. + + Returns: + An SHA-256 hex digest string representing the current LILO input state. + """ + parts: list[str] = [] + + # Component 1: LLM messages (canonical serialization). + for msg in experiment.llm_messages: + parts.append(f"{msg.role}:{msg.content}") + + parts.append("---") # Separator between components. + + # Component 2: Metric data for input_metric_names. + data = experiment.data + if not data.empty: + df = data.df + metric_df = df[df["metric_name"].isin(input_metric_names)] + if not metric_df.empty: + # Sort deterministically and serialize key columns. + sorted_df = metric_df.sort_values( + ["trial_index", "arm_name", "metric_name"] + ) + for _, row in sorted_df.iterrows(): + parts.append( + f"{row['trial_index']}|{row['arm_name']}|" + f"{row['metric_name']}|{row['mean']}|{row['sem']}" + ) + + content = "\n".join(parts) + return hashlib.sha256(content.encode("utf-8")).hexdigest() + + +def get_current_lilo_hash(experiment: Experiment) -> str | None: + """Compute the current LILO input hash, or ``None`` if not applicable. + + Looks up the pairwise preference metric on the experiment by name + (``Keys.PAIRWISE_PREFERENCE_QUERY``), checks that it is a + ``DerivedMetric`` (which provides ``input_metric_names``), and computes + the hash. In practice only ``LILOPairwiseMetric`` satisfies both + conditions; we check ``DerivedMetric`` rather than ``LILOPairwiseMetric`` + directly because the latter lives in ``ax.fb`` and cannot be imported + from this OSS module without creating a circular dependency. + + Returns: + The SHA-256 hex digest of the current LILO input state, or ``None`` + if no suitable pairwise ``DerivedMetric`` is registered. + """ + pairwise_metric_name = Keys.PAIRWISE_PREFERENCE_QUERY.value + metric = experiment.metrics.get(pairwise_metric_name) + # TODO: Replace `DerivedMetric` with `LILOPairwiseMetric` here. + if metric is None or not isinstance(metric, DerivedMetric): + return None + return compute_lilo_input_hash( + experiment=experiment, + input_metric_names=metric.input_metric_names, + )