Skip to content

Commit 8b6da20

Browse files
ItsMrLin authored and facebook-github-bot committed
Hash-based filtering of stale LILO data in adapter (facebook#4993)
Summary: When building the RankingDataset for PairwiseGP model fitting, exclude LILO trial data whose input hash doesn't match the current experiment state. This ensures PairwiseGP is only fitted on labels that are consistent with the current metric data and LLM messages. Changes: - Add `_get_fresh_pairwise_trial_indices` helper to `adapter_utils.py`: uses `get_current_lilo_hash` from `hash_utils` to compute the current hash and returns trial indices whose stamped hash matches, or `None` if not a LILO experiment (preserving BOPE compatibility) - Filter pairwise data in `TorchAdapter._convert_experiment_data` before calling `prep_pairwise_data`, ensuring stale rows are excluded - Add tests for hash-based filtering logic Reviewed By: saitcakmak Differential Revision: D95284286
1 parent 31b2085 commit 8b6da20

File tree

3 files changed

+152
-0
lines changed

3 files changed

+152
-0
lines changed

ax/adapter/adapter_utils.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@
4444
get_weighted_mc_objective_and_objective_thresholds,
4545
pareto_frontier_evaluator,
4646
)
47+
from ax.utils.common.constants import Keys
48+
from ax.utils.common.hash_utils import get_current_lilo_hash
4749
from ax.utils.common.logger import get_logger
4850
from ax.utils.common.typeutils import (
4951
assert_is_instance_of_tuple,
@@ -1270,6 +1272,50 @@ def process_contextual_datasets(
12701272
return contextual_datasets
12711273

12721274

1275+
def _get_fresh_pairwise_trial_indices(
    experiment: Experiment,
) -> set[int] | None:
    """Return trial indices whose pairwise labels match current experiment state.

    LILO (Language-in-the-Loop) trials are stamped with a hash of the
    experiment state (metric data + LLM messages) at labeling time. When
    the experiment state changes (new data, updated LLM messages), old labels
    become stale and should be excluded from PairwiseGP model fitting.

    Args:
        experiment: The experiment whose trials are inspected for LILO
            input-hash stamps.

    Returns:
        A set of trial indices whose LILO input hash matches the current
        experiment state, or ``None`` if hash-based filtering is not
        applicable (e.g., no trials have a LILO input hash — the experiment
        uses BOPE or another non-LILO pairwise workflow).
    """
    # If no trial has been stamped with a LILO input hash, this is not a
    # LILO experiment — no filtering needed. (Avoids building a throwaway
    # dict just to check emptiness.)
    if not any(
        Keys.LILO_INPUT_HASH in trial._properties
        for trial in experiment.trials.values()
    ):
        return None

    current_hash = get_current_lilo_hash(experiment)
    if current_hash is None:
        # The current state cannot be hashed — skip filtering rather than
        # discarding all stamped trials.
        return None

    # Include trials whose stamped hash matches the current state (fresh
    # labels) and trials that were never stamped (non-LILO trials are
    # always included); trials with a mismatched (stale) hash are dropped.
    return {
        idx
        for idx, trial in experiment.trials.items()
        if trial._properties.get(Keys.LILO_INPUT_HASH) in (None, current_hash)
    }
1317+
1318+
12731319
def prep_pairwise_data(
12741320
X: Tensor,
12751321
Y: Tensor,

ax/adapter/tests/test_adapter_utils.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@
88

99

1010
import numpy as np
11+
import pandas as pd
1112
import torch
1213
from ax.adapter.adapter_utils import (
1314
_get_adapter_training_data,
15+
_get_fresh_pairwise_trial_indices,
1416
arm_to_np_array,
1517
can_map_to_binary,
1618
extract_objective_weight_matrix,
@@ -25,6 +27,9 @@
2527
from ax.adapter.torch import TorchAdapter
2628
from ax.adapter.transforms.choice_encode import ChoiceToNumericChoice
2729
from ax.core.arm import Arm
30+
from ax.core.data import Data
31+
from ax.core.derived_metric import DerivedMetric
32+
from ax.core.experiment import Experiment
2833
from ax.core.metric import Metric
2934
from ax.core.objective import MultiObjective, Objective, ScalarizedObjective
3035
from ax.core.optimization_config import MultiObjectiveOptimizationConfig
@@ -34,6 +39,8 @@
3439
from ax.core.types import ComparisonOp
3540
from ax.exceptions.core import UserInputError
3641
from ax.generators.torch.botorch_modular.generator import BoTorchGenerator
42+
from ax.utils.common.constants import Keys
43+
from ax.utils.common.hash_utils import compute_lilo_input_hash
3744
from ax.utils.common.testutils import TestCase
3845
from ax.utils.testing.core_stubs import (
3946
get_experiment_with_observations,
@@ -555,3 +562,81 @@ def test_extract_objective_weight_matrix(self) -> None:
555562
)
556563
result = extract_objective_weight_matrix(multi, outcomes)
557564
np.testing.assert_array_equal(result, [[1.0, 0.0, 0.0], [0.0, 0.0, -1.0]])
565+
566+
def test_get_fresh_pairwise_trial_indices(self) -> None:
    """Verify _get_fresh_pairwise_trial_indices hash-based filtering."""
    search_space = get_search_space_for_range_values()
    exp = Experiment(name="test", search_space=search_space)

    # Register a DerivedMetric under the pairwise-preference name so the
    # function under test can look up input_metric_names.
    exp.add_tracking_metric(
        DerivedMetric(
            name=Keys.PAIRWISE_PREFERENCE_QUERY.value,
            input_metric_names=["latency"],
        )
    )

    # Local helper: attach "latency" observations for the given arms.
    def _attach(
        trial_index: int, arms: dict[str, float], exp: Experiment = exp
    ) -> None:
        records = []
        for arm_name, mean in arms.items():
            records.append(
                {
                    "trial_index": trial_index,
                    "arm_name": arm_name,
                    "metric_name": "latency",
                    "metric_signature": "latency",
                    "mean": mean,
                    "sem": 0.1,
                }
            )
        exp.attach_data(Data(df=pd.DataFrame(records)))

    # Create two completed trials, each with attached data.
    for i in range(2):
        batch = exp.new_batch_trial()
        batch.add_arm(Arm(name=f"{i}_0", parameters={"x": float(i)}))
        batch.mark_running(no_runner_required=True)
        batch.mark_completed()
        _attach(i, {f"{i}_0": float(i + 1)})

    with self.subTest("no_hashes_returns_none"):
        # No trials have LILO_INPUT_HASH → not a LILO experiment.
        self.assertIsNone(_get_fresh_pairwise_trial_indices(exp))

    # Stamp trial 0 with the current hash.
    current_hash = compute_lilo_input_hash(exp, ["latency"])
    exp.trials[0]._properties[Keys.LILO_INPUT_HASH] = current_hash

    with self.subTest("fresh_hash_included"):
        result = _get_fresh_pairwise_trial_indices(exp)
        assert result is not None
        self.assertIn(0, result)
        # Trial 1 has no hash → always included.
        self.assertIn(1, result)

    # Stamp trial 1 with a stale hash.
    exp.trials[1]._properties[Keys.LILO_INPUT_HASH] = "stale_hash_value"

    with self.subTest("stale_hash_excluded"):
        result = _get_fresh_pairwise_trial_indices(exp)
        assert result is not None
        self.assertIn(0, result)
        self.assertNotIn(1, result)

    with self.subTest("all_stale"):
        # Make both hashes stale by adding new data.
        extra = exp.new_batch_trial()
        extra.add_arm(Arm(name="2_0", parameters={"x": 10.0}))
        extra.mark_running(no_runner_required=True)
        extra.mark_completed()
        _attach(2, {"2_0": 999.0})
        # Now both trial 0 and trial 1 have stale hashes.
        result = _get_fresh_pairwise_trial_indices(exp)
        assert result is not None
        # Trial 0 and 1 are stale, trial 2 has no hash → included.
        self.assertNotIn(0, result)
        self.assertNotIn(1, result)
        self.assertIn(2, result)

ax/adapter/torch.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import numpy.typing as npt
1818
import torch
1919
from ax.adapter.adapter_utils import (
20+
_get_fresh_pairwise_trial_indices,
2021
arm_to_np_array,
2122
array_to_observation_data,
2223
extract_objective_thresholds,
@@ -468,6 +469,26 @@ def _convert_experiment_data(
468469
Yvar = torch.from_numpy(sem).double().square().view(-1, 1)
469470
group_indices = torch.from_numpy(trial_indices_np[to_keep])
470471
if outcome == Keys.PAIRWISE_PREFERENCE_QUERY.value:
472+
# Filter out stale LILO trials whose input hash no longer
473+
# matches the current experiment state.
474+
fresh_indices = _get_fresh_pairwise_trial_indices(
475+
experiment=self._experiment,
476+
)
477+
if fresh_indices is not None:
478+
fresh_mask = torch.tensor(
479+
[int(gi.item()) in fresh_indices for gi in group_indices],
480+
dtype=torch.bool,
481+
)
482+
X = X[fresh_mask]
483+
Y = Y[fresh_mask]
484+
group_indices = group_indices[fresh_mask]
485+
# Narrow the NaN-filtered to_keep mask further so
486+
# candidate_metadata stays aligned.
487+
to_keep_indices = np.where(to_keep)[0]
488+
fresh_mask_np = fresh_mask.numpy()
489+
to_keep = np.zeros_like(to_keep)
490+
to_keep[to_keep_indices[fresh_mask_np]] = True
491+
471492
dataset = prep_pairwise_data(
472493
X=X.to(device=self.device),
473494
Y=Y.to(dtype=torch.long, device=self.device),

0 commit comments

Comments
 (0)