Skip to content

Commit f81332c

Browse files
mpolson64 and meta-codesync[bot]
authored and committed
Filter fetch_data by trial type in base Experiment (facebook#5005)
Summary: Pull Request resolved: facebook#5005 Phase 4 of moving MultiTypeExperiment features into base Experiment. Updates the base Experiment's `_fetch_trial_data` to filter metrics by trial type when `_trial_type_to_metric_names` is populated. Only metrics whose names appear in the set for the trial's type are fetched, preventing metrics from being evaluated against trials of the wrong type. Updates `fetch_data` to group trials by trial type and bulk-fetch each group when `_trial_type_to_metric_names` is populated (so each group fetches only its relevant metrics), while preserving the existing bulk fetch path for single-type experiments. Removes the `fetch_data` and `_fetch_trial_data` overrides from MultiTypeExperiment, along with now-unused imports (Iterable, Data, MetricFetchResult). Differential Revision: D94990429
1 parent 8eafffe commit f81332c

3 files changed

Lines changed: 152 additions & 37 deletions

File tree

ax/core/experiment.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1138,6 +1138,37 @@ def fetch_data(
11381138
Returns:
11391139
Data for the experiment.
11401140
"""
1141+
if self._trial_type_to_metric_names:
1142+
# When metrics are mapped to trial types, group trials by type
1143+
# and bulk-fetch per group so each group only fetches its
1144+
# relevant metrics.
1145+
all_trials = (
1146+
list(self.trials.values())
1147+
if trial_indices is None
1148+
else self.get_trials_by_indices(trial_indices)
1149+
)
1150+
trials_by_type: dict[str | None, list[BaseTrial]] = defaultdict(list)
1151+
for trial in all_trials:
1152+
if trial.status.expecting_data:
1153+
trials_by_type[trial.trial_type].append(trial)
1154+
all_results: dict[int, dict[str, MetricFetchResult]] = {}
1155+
for trial_type, type_trials in trials_by_type.items():
1156+
type_metrics = (
1157+
metrics
1158+
if metrics is not None
1159+
else (
1160+
self.metrics_for_trial_type(trial_type)
1161+
if trial_type is not None
1162+
else list(self.metrics.values())
1163+
)
1164+
)
1165+
results = self._lookup_or_fetch_trials_results(
1166+
trials=type_trials,
1167+
metrics=type_metrics,
1168+
**kwargs,
1169+
)
1170+
all_results.update(results)
1171+
return Metric._unwrap_experiment_data_multi(results=all_results)
11411172
results = self._lookup_or_fetch_trials_results(
11421173
trials=list(self.trials.values())
11431174
if trial_indices is None
@@ -1278,6 +1309,14 @@ def _fetch_trial_data(
12781309
) -> dict[str, MetricFetchResult]:
12791310
trial = self.trials[trial_index]
12801311

1312+
# When metrics are mapped to trial types, filter to only the
1313+
# metrics relevant to this trial's type.
1314+
trial_type = trial.trial_type
1315+
if self._trial_type_to_metric_names and trial_type is not None:
1316+
valid_names = self._trial_type_to_metric_names.get(trial_type, set())
1317+
all_metrics = list(metrics or self.metrics.values())
1318+
metrics = [m for m in all_metrics if m.name in valid_names]
1319+
12811320
trial_data = self._lookup_or_fetch_trials_results(
12821321
trials=[trial], metrics=metrics, **kwargs
12831322
)

ax/core/multi_type_experiment.py

Lines changed: 2 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,13 @@
66

77
# pyre-strict
88

9-
from collections.abc import Iterable, Sequence
9+
from collections.abc import Sequence
1010
from typing import Any, Self
1111

1212
from ax.core.arm import Arm
1313
from ax.core.base_trial import BaseTrial, TrialStatus
14-
from ax.core.data import Data
1514
from ax.core.experiment import Experiment
16-
from ax.core.metric import Metric, MetricFetchResult
15+
from ax.core.metric import Metric
1716
from ax.core.optimization_config import OptimizationConfig
1817
from ax.core.runner import Runner
1918
from ax.core.search_space import SearchSpace
@@ -171,40 +170,6 @@ def remove_metric(self, metric_name: str) -> Self:
171170
self._metric_to_canonical_name.pop(metric_name, None)
172171
return self
173172

174-
@copy_doc(Experiment.fetch_data)
175-
def fetch_data(
176-
self,
177-
trial_indices: Iterable[int] | None = None,
178-
metrics: list[Metric] | None = None,
179-
**kwargs: Any,
180-
) -> Data:
181-
# TODO: make this more efficient for fetching
182-
# data for multiple trials of the same type
183-
# by overriding Experiment._lookup_or_fetch_trials_results
184-
return Data.from_multiple_data(
185-
[
186-
(
187-
trial.fetch_data(**kwargs, metrics=metrics)
188-
if trial.status.expecting_data
189-
else Data()
190-
)
191-
for trial in self.trials.values()
192-
]
193-
)
194-
195-
@copy_doc(Experiment._fetch_trial_data)
196-
def _fetch_trial_data(
197-
self, trial_index: int, metrics: list[Metric] | None = None, **kwargs: Any
198-
) -> dict[str, MetricFetchResult]:
199-
trial = self.trials[trial_index]
200-
metrics = [
201-
metric
202-
for metric in (metrics or self.metrics.values())
203-
if self.metric_to_trial_type[metric.name] == trial.trial_type
204-
]
205-
# Invoke parent's fetch method using only metrics for this trial_type
206-
return super()._fetch_trial_data(trial.index, metrics=metrics, **kwargs)
207-
208173

209174
def filter_trials_by_type(
210175
trials: Sequence[BaseTrial], trial_type: str | None

ax/core/tests/test_experiment.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2772,3 +2772,114 @@ def test_extract_relevant_trials(self) -> None:
27722772
)
27732773
self.assertEqual(len(trials), 1)
27742774
self.assertEqual(trials[0].index, 0)
2775+
2776+
def _setup_multi_type_branin_experiment(self, n: int = 10) -> Experiment:
2777+
"""Create a base Experiment with two trial types and metrics mapped
2778+
to each, mimicking a multi-type setup without using
2779+
MultiTypeExperiment.
2780+
"""
2781+
exp = Experiment(
2782+
name="multi_type_test",
2783+
search_space=get_branin_search_space(),
2784+
default_trial_type="type1",
2785+
tracking_metrics=[
2786+
BraninMetric(name="m1", param_names=["x1", "x2"]),
2787+
],
2788+
runner=SyntheticRunner(),
2789+
)
2790+
# Register a second trial type with its own runner and metric.
2791+
exp._trial_type_to_runner["type2"] = SyntheticRunner()
2792+
exp.add_tracking_metric(
2793+
BraninMetric(name="m2", param_names=["x2", "x1"]),
2794+
trial_type="type2",
2795+
)
2796+
2797+
# Create one batch per trial type and run them.
2798+
b1 = exp.new_batch_trial(trial_type="type1")
2799+
b1.add_arms_and_weights(arms=get_branin_arms(n=n, seed=0))
2800+
b1.run()
2801+
2802+
b2 = exp.new_batch_trial(trial_type="type2")
2803+
b2.add_arms_and_weights(arms=get_branin_arms(n=n, seed=0))
2804+
b2.run()
2805+
2806+
return exp
2807+
2808+
def test_fetch_data_filters_by_trial_type(self) -> None:
2809+
"""fetch_data should return only the metrics mapped to each trial's
2810+
type when _trial_type_to_metric_names is populated."""
2811+
n = 10
2812+
exp = self._setup_multi_type_branin_experiment(n=n)
2813+
2814+
df = exp.fetch_data().df
2815+
# Each trial should have n rows (one per arm), for a total of 2*n.
2816+
self.assertEqual(len(df), 2 * n)
2817+
2818+
# Trial 0 (type1) should only have metric "m1".
2819+
trial_0_df = df[df["trial_index"] == 0]
2820+
self.assertEqual(set(trial_0_df["metric_name"]), {"m1"})
2821+
self.assertEqual(len(trial_0_df), n)
2822+
2823+
# Trial 1 (type2) should only have metric "m2".
2824+
trial_1_df = df[df["trial_index"] == 1]
2825+
self.assertEqual(set(trial_1_df["metric_name"]), {"m2"})
2826+
self.assertEqual(len(trial_1_df), n)
2827+
2828+
def test_fetch_data_with_trial_indices_and_trial_types(self) -> None:
2829+
"""fetch_data with trial_indices should respect trial type filtering."""
2830+
n = 10
2831+
exp = self._setup_multi_type_branin_experiment(n=n)
2832+
2833+
# Fetch only trial 1 (type2).
2834+
df = exp.fetch_data(trial_indices=[1]).df
2835+
self.assertEqual(len(df), n)
2836+
self.assertEqual(set(df["metric_name"]), {"m2"})
2837+
self.assertTrue((df["trial_index"] == 1).all())
2838+
2839+
def test_fetch_data_skips_non_expecting_trials_with_trial_types(self) -> None:
2840+
"""fetch_data should skip trials not expecting data when
2841+
_trial_type_to_metric_names is populated."""
2842+
n = 10
2843+
exp = self._setup_multi_type_branin_experiment(n=n)
2844+
2845+
# Mark trial 0 as abandoned so it doesn't expect data.
2846+
exp.trials[0].mark_abandoned()
2847+
2848+
df = exp.fetch_data().df
2849+
# Only trial 1 should have data.
2850+
self.assertEqual(len(df), n)
2851+
self.assertTrue((df["trial_index"] == 1).all())
2852+
self.assertEqual(set(df["metric_name"]), {"m2"})
2853+
2854+
def test_fetch_trial_data_filters_metrics_by_trial_type(self) -> None:
2855+
"""_fetch_trial_data should filter to only metrics relevant to the
2856+
trial's type when _trial_type_to_metric_names is populated."""
2857+
n = 10
2858+
exp = self._setup_multi_type_branin_experiment(n=n)
2859+
2860+
# Fetch data for trial 0 (type1) — should only contain "m1".
2861+
results_0 = exp._fetch_trial_data(trial_index=0)
2862+
self.assertIn("m1", results_0)
2863+
self.assertNotIn("m2", results_0)
2864+
2865+
# Fetch data for trial 1 (type2) — should only contain "m2".
2866+
results_1 = exp._fetch_trial_data(trial_index=1)
2867+
self.assertIn("m2", results_1)
2868+
self.assertNotIn("m1", results_1)
2869+
2870+
def test_fetch_trial_data_filters_explicit_metrics_by_trial_type(self) -> None:
2871+
"""_fetch_trial_data should filter even an explicit metrics list to
2872+
only those relevant to the trial's type."""
2873+
n = 10
2874+
exp = self._setup_multi_type_branin_experiment(n=n)
2875+
2876+
both_metrics = list(exp.metrics.values())
2877+
# Passing both metrics to a type1 trial should still only return m1.
2878+
results_0 = exp._fetch_trial_data(trial_index=0, metrics=both_metrics)
2879+
self.assertIn("m1", results_0)
2880+
self.assertNotIn("m2", results_0)
2881+
2882+
# Passing both metrics to a type2 trial should still only return m2.
2883+
results_1 = exp._fetch_trial_data(trial_index=1, metrics=both_metrics)
2884+
self.assertIn("m2", results_1)
2885+
self.assertNotIn("m1", results_1)

0 commit comments

Comments
 (0)