Skip to content

Commit f72b4d0

Browse files
mpolson64 authored and facebook-github-bot committed
Filter fetch_data by trial type in base Experiment (facebook#5005)
Summary: Phase 4 of moving MultiTypeExperiment features into base Experiment. Updates the base Experiment's `_fetch_trial_data` to filter metrics by trial type when `_trial_type_to_metric_names` is populated. Only metrics whose names appear in the set for the trial's type are fetched, preventing metrics from being evaluated against trials of the wrong type. Updates `fetch_data` to group trials by trial type and bulk-fetch each group when `_trial_type_to_metric_names` is populated (so each group fetches only the metrics relevant to its trial type), while preserving the existing bulk fetch path for single-type experiments. Removes the `fetch_data` and `_fetch_trial_data` overrides from MultiTypeExperiment, along with now-unused imports (Iterable, Data, MetricFetchResult). Differential Revision: D94990429
1 parent 2873662 commit f72b4d0

3 files changed

Lines changed: 142 additions & 37 deletions

File tree

ax/core/experiment.py

Lines changed: 43 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1180,6 +1180,41 @@ def fetch_data(
11801180
Returns:
11811181
Data for the experiment.
11821182
"""
1183+
# Only use trial-type-aware fetching for multi-type experiments;
1184+
# single-type experiments have an empty mapping.
1185+
if self._trial_type_to_metric_names:
1186+
# When metrics are mapped to trial types, group trials by type
1187+
# and bulk-fetch per group so each group only fetches its
1188+
# relevant metrics.
1189+
all_trials = (
1190+
list(self.trials.values())
1191+
if trial_indices is None
1192+
else self.get_trials_by_indices(trial_indices)
1193+
)
1194+
trials_by_type: dict[str | None, list[BaseTrial]] = defaultdict(list)
1195+
for trial in all_trials:
1196+
if trial.status.expecting_data:
1197+
trials_by_type[trial.trial_type].append(trial)
1198+
all_results: dict[int, dict[str, MetricFetchResult]] = {}
1199+
for trial_type, type_trials in trials_by_type.items():
1200+
if metrics is not None and trial_type is not None:
1201+
valid_names = self._trial_type_to_metric_names.get(
1202+
trial_type, set()
1203+
)
1204+
type_metrics = [m for m in metrics if m.name in valid_names]
1205+
elif metrics is not None:
1206+
type_metrics = metrics
1207+
elif trial_type is not None:
1208+
type_metrics = self.metrics_for_trial_type(trial_type)
1209+
else:
1210+
type_metrics = list(self.metrics.values())
1211+
results = self._lookup_or_fetch_trials_results(
1212+
trials=type_trials,
1213+
metrics=type_metrics,
1214+
**kwargs,
1215+
)
1216+
all_results.update(results)
1217+
return Metric._unwrap_experiment_data_multi(results=all_results)
11831218
results = self._lookup_or_fetch_trials_results(
11841219
trials=list(self.trials.values())
11851220
if trial_indices is None
@@ -1320,6 +1355,14 @@ def _fetch_trial_data(
13201355
) -> dict[str, MetricFetchResult]:
13211356
trial = self.trials[trial_index]
13221357

1358+
# When metrics are mapped to trial types, filter to only the
1359+
# metrics relevant to this trial's type.
1360+
trial_type = trial.trial_type
1361+
if self._trial_type_to_metric_names and trial_type is not None:
1362+
valid_names = self._trial_type_to_metric_names.get(trial_type, set())
1363+
all_metrics = list(metrics or self.metrics.values())
1364+
metrics = [m for m in all_metrics if m.name in valid_names]
1365+
13231366
trial_data = self._lookup_or_fetch_trials_results(
13241367
trials=[trial], metrics=metrics, **kwargs
13251368
)

ax/core/multi_type_experiment.py

Lines changed: 2 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -6,14 +6,13 @@
66

77
# pyre-strict
88

9-
from collections.abc import Iterable, Sequence
9+
from collections.abc import Sequence
1010
from typing import Any, Self
1111

1212
from ax.core.arm import Arm
1313
from ax.core.base_trial import BaseTrial, TrialStatus
14-
from ax.core.data import Data
1514
from ax.core.experiment import Experiment
16-
from ax.core.metric import Metric, MetricFetchResult
15+
from ax.core.metric import Metric
1716
from ax.core.optimization_config import OptimizationConfig
1817
from ax.core.runner import Runner
1918
from ax.core.search_space import SearchSpace
@@ -171,40 +170,6 @@ def remove_metric(self, metric_name: str) -> Self:
171170
self._metric_to_canonical_name.pop(metric_name, None)
172171
return self
173172

174-
@copy_doc(Experiment.fetch_data)
175-
def fetch_data(
176-
self,
177-
trial_indices: Iterable[int] | None = None,
178-
metrics: list[Metric] | None = None,
179-
**kwargs: Any,
180-
) -> Data:
181-
# TODO: make this more efficient for fetching
182-
# data for multiple trials of the same type
183-
# by overriding Experiment._lookup_or_fetch_trials_results
184-
return Data.from_multiple_data(
185-
[
186-
(
187-
trial.fetch_data(**kwargs, metrics=metrics)
188-
if trial.status.expecting_data
189-
else Data()
190-
)
191-
for trial in self.trials.values()
192-
]
193-
)
194-
195-
@copy_doc(Experiment._fetch_trial_data)
196-
def _fetch_trial_data(
197-
self, trial_index: int, metrics: list[Metric] | None = None, **kwargs: Any
198-
) -> dict[str, MetricFetchResult]:
199-
trial = self.trials[trial_index]
200-
metrics = [
201-
metric
202-
for metric in (metrics or self.metrics.values())
203-
if self.metric_to_trial_type[metric.name] == trial.trial_type
204-
]
205-
# Invoke parent's fetch method using only metrics for this trial_type
206-
return super()._fetch_trial_data(trial.index, metrics=metrics, **kwargs)
207-
208173

209174
def filter_trials_by_type(
210175
trials: Sequence[BaseTrial], trial_type: str | None

ax/core/tests/test_experiment.py

Lines changed: 97 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2816,3 +2816,100 @@ def test_extract_relevant_trials(self) -> None:
28162816
)
28172817
self.assertEqual(len(trials), 1)
28182818
self.assertEqual(trials[0].index, 0)
2819+
2820+
def _setup_multi_type_branin_experiment(self, n: int = 10) -> Experiment:
2821+
"""Create a base Experiment with two trial types and metrics mapped
2822+
to each, mimicking a multi-type setup without using
2823+
MultiTypeExperiment.
2824+
"""
2825+
exp = Experiment(
2826+
name="multi_type_test",
2827+
search_space=get_branin_search_space(),
2828+
default_trial_type="type1",
2829+
tracking_metrics=[
2830+
BraninMetric(name="m1", param_names=["x1", "x2"]),
2831+
],
2832+
runner=SyntheticRunner(),
2833+
)
2834+
# Register a second trial type with its own runner and metric.
2835+
exp._trial_type_to_runner["type2"] = SyntheticRunner()
2836+
exp.add_tracking_metric(
2837+
BraninMetric(name="m2", param_names=["x2", "x1"]),
2838+
trial_type="type2",
2839+
)
2840+
2841+
# Create one batch per trial type and run them.
2842+
b1 = exp.new_batch_trial(trial_type="type1")
2843+
b1.add_arms_and_weights(arms=get_branin_arms(n=n, seed=0))
2844+
b1.run()
2845+
2846+
b2 = exp.new_batch_trial(trial_type="type2")
2847+
b2.add_arms_and_weights(arms=get_branin_arms(n=n, seed=0))
2848+
b2.run()
2849+
2850+
return exp
2851+
2852+
def test_fetch_data_with_trial_types(self) -> None:
2853+
"""fetch_data should correctly filter metrics by trial type."""
2854+
n = 10
2855+
exp = self._setup_multi_type_branin_experiment(n=n)
2856+
2857+
with self.subTest("filters_by_trial_type"):
2858+
df = exp.fetch_data().df
2859+
# Each trial should have n rows (one per arm), for a total of 2*n.
2860+
self.assertEqual(len(df), 2 * n)
2861+
2862+
# Trial 0 (type1) should only have metric "m1".
2863+
trial_0_df = df[df["trial_index"] == 0]
2864+
self.assertEqual(set(trial_0_df["metric_name"]), {"m1"})
2865+
self.assertEqual(len(trial_0_df), n)
2866+
2867+
# Trial 1 (type2) should only have metric "m2".
2868+
trial_1_df = df[df["trial_index"] == 1]
2869+
self.assertEqual(set(trial_1_df["metric_name"]), {"m2"})
2870+
self.assertEqual(len(trial_1_df), n)
2871+
2872+
with self.subTest("with_trial_indices"):
2873+
# Fetch only trial 1 (type2).
2874+
df = exp.fetch_data(trial_indices=[1]).df
2875+
self.assertEqual(len(df), n)
2876+
self.assertEqual(set(df["metric_name"]), {"m2"})
2877+
self.assertTrue((df["trial_index"] == 1).all())
2878+
2879+
with self.subTest("skips_non_expecting_trials"):
2880+
# Mark trial 0 as abandoned so it doesn't expect data.
2881+
exp.trials[0].mark_abandoned()
2882+
2883+
df = exp.fetch_data().df
2884+
# Only trial 1 should have data.
2885+
self.assertEqual(len(df), n)
2886+
self.assertTrue((df["trial_index"] == 1).all())
2887+
self.assertEqual(set(df["metric_name"]), {"m2"})
2888+
2889+
def test_fetch_trial_data_with_trial_types(self) -> None:
2890+
"""_fetch_trial_data should filter metrics by trial type."""
2891+
n = 10
2892+
exp = self._setup_multi_type_branin_experiment(n=n)
2893+
2894+
with self.subTest("filters_metrics_by_trial_type"):
2895+
# Fetch data for trial 0 (type1) -- should only contain "m1".
2896+
results_0 = exp._fetch_trial_data(trial_index=0)
2897+
self.assertIn("m1", results_0)
2898+
self.assertNotIn("m2", results_0)
2899+
2900+
# Fetch data for trial 1 (type2) -- should only contain "m2".
2901+
results_1 = exp._fetch_trial_data(trial_index=1)
2902+
self.assertIn("m2", results_1)
2903+
self.assertNotIn("m1", results_1)
2904+
2905+
with self.subTest("filters_explicit_metrics_by_trial_type"):
2906+
both_metrics = list(exp.metrics.values())
2907+
# Passing both metrics to a type1 trial should still only return m1.
2908+
results_0 = exp._fetch_trial_data(trial_index=0, metrics=both_metrics)
2909+
self.assertIn("m1", results_0)
2910+
self.assertNotIn("m2", results_0)
2911+
2912+
# Passing both metrics to a type2 trial should still only return m2.
2913+
results_1 = exp._fetch_trial_data(trial_index=1, metrics=both_metrics)
2914+
self.assertIn("m2", results_1)
2915+
self.assertNotIn("m1", results_1)

0 commit comments

Comments
 (0)