Skip to content

Commit 8dde1b5

Browse files
mpolson64facebook-github-bot
authored andcommitted
Upstream MultiTypeExperiment features into Experiment
Summary: These changes will enable us to deprecate multitypeexperment, simplifying the Ax data model ahead of storage changes. 1. In Experiment make the default_trial_type a new Key.DEFAULT_TRIAL_TYPE value instead of None 2. Move over logic for bookkeeping metric -> trial_type and runner -> trial_type mappings 3. Treat LONG_ and SHORT_RUN trial types as special cases which map to DEFAULT_TRIAL_TYPE (i.e. if a Trial has trial_type=LONG_RUN then use whichever metrics and runners are mapped to DEFAULT_TRIAL_TYPE 4. Fix tests which expect the default_trial_type of an Experiment to be None This diff allows us to remove all isinstance(foo, MultiTypeExperiment) checks in Ax in the next diff, then to deprecate MultiTypeExperiment entirely. Differential Revision: D91618283
1 parent 8a706b8 commit 8dde1b5

File tree

12 files changed

+291
-67
lines changed

12 files changed

+291
-67
lines changed

ax/core/base_trial.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,16 +98,15 @@ def __init__(
9898
self._ttl_seconds: int | None = ttl_seconds
9999
self._index: int = self._experiment._attach_trial(self, index=index)
100100

101-
trial_type = (
101+
self._trial_type: str = (
102102
trial_type
103103
if trial_type is not None
104104
else self._experiment.default_trial_type
105105
)
106-
if not self._experiment.supports_trial_type(trial_type):
106+
if not self._experiment.supports_trial_type(self._trial_type):
107107
raise ValueError(
108-
f"Trial type {trial_type} is not supported by the experiment."
108+
f"Trial type {self._trial_type} is not supported by the experiment."
109109
)
110-
self._trial_type = trial_type
111110

112111
self.__status: TrialStatus | None = None
113112
# Uses `_status` setter, which updates trial statuses to trial indices
@@ -285,7 +284,7 @@ def stop_metadata(self) -> dict[str, Any]:
285284
return self._stop_metadata
286285

287286
@property
288-
def trial_type(self) -> str | None:
287+
def trial_type(self) -> str:
289288
"""The type of the trial.
290289
291290
Relevant for experiments containing different kinds of trials

ax/core/experiment.py

Lines changed: 202 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def __init__(
100100
default_data_type: Any = None,
101101
auxiliary_experiments_by_purpose: None
102102
| (dict[AuxiliaryExperimentPurpose, list[AuxiliaryExperiment]]) = None,
103-
default_trial_type: str | None = None,
103+
default_trial_type: str = Keys.DEFAULT_TRIAL_TYPE.value,
104104
) -> None:
105105
"""Inits Experiment.
106106
@@ -123,6 +123,8 @@ def __init__(
123123
default_data_type: Deprecated and ignored.
124124
auxiliary_experiments_by_purpose: Dictionary of auxiliary experiments
125125
for different purposes (e.g., transfer learning).
126+
default_trial_type: Default trial type for trials on this experiment.
127+
Defaults to Keys.DEFAULT_TRIAL_TYPE.
126128
"""
127129
if default_data_type is not None:
128130
warnings.warn(
@@ -150,10 +152,16 @@ def __init__(
150152
self._properties: dict[str, Any] = properties or {}
151153

152154
# Initialize trial type to runner mapping
153-
self._default_trial_type = default_trial_type
154-
self._trial_type_to_runner: dict[str | None, Runner | None] = {
155-
default_trial_type: runner
155+
self._default_trial_type: str = (
156+
default_trial_type or Keys.DEFAULT_TRIAL_TYPE.value
157+
)
158+
self._trial_type_to_runner: dict[str, Runner | None] = {
159+
self._default_trial_type: runner
156160
}
161+
162+
# Maps metric names to their trial types. Every metric must have an entry.
163+
self._metric_to_trial_type: dict[str, str] = {}
164+
157165
# Used to keep track of whether any trials on the experiment
158166
# specify a TTL. Since trials need to be checked for their TTL's
159167
# expiration often, having this attribute helps avoid unnecessary
@@ -413,16 +421,46 @@ def runner(self) -> Runner | None:
413421
def runner(self, runner: Runner | None) -> None:
414422
"""Set the default runner and update trial type mapping."""
415423
self._runner = runner
416-
if runner is not None:
417-
self._trial_type_to_runner[self._default_trial_type] = runner
418-
else:
419-
self._trial_type_to_runner = {None: None}
424+
self._trial_type_to_runner[self._default_trial_type] = runner
420425

421426
@runner.deleter
422427
def runner(self) -> None:
423428
"""Delete the runner."""
424429
self._runner = None
425-
self._trial_type_to_runner = {None: None}
430+
self._trial_type_to_runner[self._default_trial_type] = None
431+
432+
def add_trial_type(self, trial_type: str, runner: Runner) -> "Experiment":
433+
"""Add a new trial_type to be supported by this experiment.
434+
435+
Args:
436+
trial_type: The new trial_type to be added.
437+
runner: The default runner for trials of this type.
438+
439+
Returns:
440+
The experiment with the new trial type added.
441+
"""
442+
if self.supports_trial_type(trial_type):
443+
raise ValueError(f"Experiment already contains trial_type `{trial_type}`")
444+
445+
self._trial_type_to_runner[trial_type] = runner
446+
return self
447+
448+
def update_runner(self, trial_type: str, runner: Runner) -> "Experiment":
449+
"""Update the default runner for an existing trial_type.
450+
451+
Args:
452+
trial_type: The trial_type to update.
453+
runner: The new runner for trials of this type.
454+
455+
Returns:
456+
The experiment with the updated runner.
457+
"""
458+
if not self.supports_trial_type(trial_type):
459+
raise ValueError(f"Experiment does not contain trial_type `{trial_type}`")
460+
461+
self._trial_type_to_runner[trial_type] = runner
462+
self._runner = runner
463+
return self
426464

427465
@property
428466
def parameters(self) -> dict[str, Parameter]:
@@ -489,13 +527,25 @@ def optimization_config(self, optimization_config: OptimizationConfig) -> None:
489527
f"`{Keys.IMMUTABLE_SEARCH_SPACE_AND_OPT_CONF.value}` "
490528
"property that is set to `True` on this experiment."
491529
)
530+
531+
# Remove old OC metrics from trial type mapping
532+
prev_optimization_config = self._optimization_config
533+
if prev_optimization_config is not None:
534+
for metric_name in prev_optimization_config.metrics.keys():
535+
self._metric_to_trial_type.pop(metric_name, None)
536+
492537
for metric_name in optimization_config.metrics.keys():
493538
if metric_name in self._tracking_metrics:
494539
self.remove_tracking_metric(metric_name)
540+
495541
# add metrics from the previous optimization config that are not in the new
496542
# optimization config as tracking metrics
497-
prev_optimization_config = self._optimization_config
498543
self._optimization_config = optimization_config
544+
545+
# Map new OC metrics to default trial type
546+
for metric_name in optimization_config.metrics.keys():
547+
self._metric_to_trial_type[metric_name] = self._default_trial_type
548+
499549
if prev_optimization_config is not None:
500550
metrics_to_track = (
501551
set(prev_optimization_config.metrics.keys())
@@ -505,6 +555,16 @@ def optimization_config(self, optimization_config: OptimizationConfig) -> None:
505555
for metric_name in metrics_to_track:
506556
self.add_tracking_metric(prev_optimization_config.metrics[metric_name])
507557

558+
# Clean up any stale entries in _metric_to_trial_type that don't correspond
559+
# to actual metrics (can happen when same optimization_config object is
560+
# mutated and reassigned).
561+
current_metric_names = set(self.metrics.keys())
562+
stale_metric_names = (
563+
set(self._metric_to_trial_type.keys()) - current_metric_names
564+
)
565+
for metric_name in stale_metric_names:
566+
self._metric_to_trial_type.pop(metric_name, None)
567+
508568
@property
509569
def is_moo_problem(self) -> bool:
510570
"""Whether the experiment's optimization config contains multiple objectives."""
@@ -553,12 +613,25 @@ def immutable_search_space_and_opt_config(self) -> bool:
553613
def tracking_metrics(self) -> list[Metric]:
554614
return list(self._tracking_metrics.values())
555615

556-
def add_tracking_metric(self, metric: Metric) -> Self:
616+
def add_tracking_metric(
617+
self,
618+
metric: Metric,
619+
trial_type: str | None = None,
620+
) -> Self:
557621
"""Add a new metric to the experiment.
558622
559623
Args:
560624
metric: Metric to be added.
625+
trial_type: The trial type for which this metric is used. If not
626+
provided, defaults to the experiment's default trial type.
561627
"""
628+
effective_trial_type = (
629+
trial_type if trial_type is not None else self._default_trial_type
630+
)
631+
632+
if not self.supports_trial_type(effective_trial_type):
633+
raise ValueError(f"`{effective_trial_type}` is not a supported trial type.")
634+
562635
if metric.name in self._tracking_metrics:
563636
raise ValueError(
564637
f"Metric `{metric.name}` already defined on experiment. "
@@ -574,33 +647,73 @@ def add_tracking_metric(self, metric: Metric) -> Self:
574647
)
575648

576649
self._tracking_metrics[metric.name] = metric
650+
self._metric_to_trial_type[metric.name] = effective_trial_type
577651
return self
578652

579-
def add_tracking_metrics(self, metrics: list[Metric]) -> Self:
653+
def add_tracking_metrics(
654+
self,
655+
metrics: list[Metric],
656+
metrics_to_trial_types: dict[str, str] | None = None,
657+
) -> Self:
580658
"""Add a list of new metrics to the experiment.
581659
582660
If any of the metrics are already defined on the experiment,
583661
we raise an error and don't add any of them to the experiment
584662
585663
Args:
586664
metrics: Metrics to be added.
665+
metrics_to_trial_types: Optional mapping from metric names to
666+
corresponding trial types. If not provided for a metric,
667+
the experiment's default trial type is used.
587668
"""
588-
# Before setting any metrics, we validate none are already on
589-
# the experiment
669+
metrics_to_trial_types = metrics_to_trial_types or {}
590670
for metric in metrics:
591-
self.add_tracking_metric(metric)
671+
self.add_tracking_metric(
672+
metric=metric,
673+
trial_type=metrics_to_trial_types.get(metric.name),
674+
)
592675
return self
593676

594-
def update_tracking_metric(self, metric: Metric) -> Self:
677+
def update_tracking_metric(
678+
self,
679+
metric: Metric,
680+
trial_type: str | None = None,
681+
) -> Self:
595682
"""Redefine a metric that already exists on the experiment.
596683
597684
Args:
598685
metric: New metric definition.
686+
trial_type: The trial type for which this metric is used. If not
687+
provided, keeps the existing trial type mapping.
599688
"""
600689
if metric.name not in self._tracking_metrics:
601690
raise ValueError(f"Metric `{metric.name}` doesn't exist on experiment.")
602691

692+
# Validate trial type if provided
693+
effective_trial_type = (
694+
trial_type
695+
if trial_type is not None
696+
else self._metric_to_trial_type.get(metric.name, self._default_trial_type)
697+
)
698+
699+
# Check that optimization config metrics stay on default trial type
700+
oc = self.optimization_config
701+
oc_metrics = oc.metrics if oc else {}
702+
if (
703+
metric.name in oc_metrics
704+
and effective_trial_type != self._default_trial_type
705+
):
706+
raise ValueError(
707+
f"Metric `{metric.name}` must remain a "
708+
f"`{self._default_trial_type}` metric because it is part of the "
709+
"optimization_config."
710+
)
711+
712+
if not self.supports_trial_type(effective_trial_type):
713+
raise ValueError(f"`{effective_trial_type}` is not a supported trial type.")
714+
603715
self._tracking_metrics[metric.name] = metric
716+
self._metric_to_trial_type[metric.name] = effective_trial_type
604717
return self
605718

606719
def remove_tracking_metric(self, metric_name: str) -> Self:
@@ -613,6 +726,7 @@ def remove_tracking_metric(self, metric_name: str) -> Self:
613726
raise ValueError(f"Metric `{metric_name}` doesn't exist on experiment.")
614727

615728
del self._tracking_metrics[metric_name]
729+
self._metric_to_trial_type.pop(metric_name, None)
616730
return self
617731

618732
@property
@@ -852,8 +966,21 @@ def _fetch_trial_data(
852966
) -> dict[str, MetricFetchResult]:
853967
trial = self.trials[trial_index]
854968

969+
# If metrics are not provided, fetch all metrics on the experiment for the
970+
# relevant trial type, or the default trial type as a fallback. Otherwise,
971+
# fetch provided metrics.
972+
if metrics is None:
973+
resolved_metrics = [
974+
metric
975+
for metric in list(self.metrics.values())
976+
if self._metric_to_trial_type.get(metric.name, self._default_trial_type)
977+
== trial.trial_type
978+
]
979+
else:
980+
resolved_metrics = metrics
981+
855982
trial_data = self._lookup_or_fetch_trials_results(
856-
trials=[trial], metrics=metrics, **kwargs
983+
trials=[trial], metrics=resolved_metrics, **kwargs
857984
)
858985

859986
if trial_index in trial_data:
@@ -1548,39 +1675,79 @@ def __repr__(self) -> str:
15481675
# overridden in the MultiTypeExperiment class.
15491676

15501677
@property
1551-
def default_trial_type(self) -> str | None:
1552-
"""Default trial type assigned to trials in this experiment.
1553-
1554-
In the base experiment class this is always None. For experiments
1555-
with multiple trial types, use the MultiTypeExperiment class.
1556-
"""
1678+
def default_trial_type(self) -> str:
1679+
"""Default trial type assigned to trials in this experiment."""
15571680
return self._default_trial_type
15581681

1559-
def runner_for_trial_type(self, trial_type: str | None) -> Runner | None:
1682+
def runner_for_trial_type(self, trial_type: str) -> Runner | None:
15601683
"""The default runner to use for a given trial type.
15611684
15621685
Looks up the appropriate runner for this trial type in the trial_type_to_runner.
15631686
"""
1687+
# Special case for LONG_ and SHORT_RUN trial types, which we treat as "default"
1688+
# trial types for deployment.
1689+
if (
1690+
trial_type == Keys.SHORT_RUN or trial_type == Keys.LONG_RUN
1691+
) and self.supports_trial_type(trial_type=Keys.DEFAULT_TRIAL_TYPE):
1692+
return self._trial_type_to_runner[Keys.DEFAULT_TRIAL_TYPE]
1693+
15641694
if not self.supports_trial_type(trial_type):
15651695
raise ValueError(f"Trial type `{trial_type}` is not supported.")
15661696
if (runner := self._trial_type_to_runner.get(trial_type)) is None:
15671697
return self.runner # return the default runner
15681698
return runner
15691699

1570-
def supports_trial_type(self, trial_type: str | None) -> bool:
1700+
def supports_trial_type(self, trial_type: str) -> bool:
15711701
"""Whether this experiment allows trials of the given type.
15721702
1573-
The base experiment class only supports None. For experiments
1574-
with multiple trial types, use the MultiTypeExperiment class.
1703+
Checks if the trial type is registered in the trial_type_to_runner mapping.
15751704
"""
1576-
return (
1577-
trial_type is None
1578-
# We temporarily allow "short run" and "long run" trial
1579-
# types in single-type experiments during development of
1580-
# a new ``GenerationStrategy`` that needs them.
1581-
or trial_type == Keys.SHORT_RUN
1582-
or trial_type == Keys.LONG_RUN
1583-
)
1705+
# Special case for LONG_ and SHORT_RUN trial types, which we treat as "default"
1706+
# trial types for deployment.
1707+
if (
1708+
trial_type == Keys.SHORT_RUN or trial_type == Keys.LONG_RUN
1709+
) and self.supports_trial_type(trial_type=Keys.DEFAULT_TRIAL_TYPE):
1710+
return True
1711+
1712+
return trial_type in self._trial_type_to_runner
1713+
1714+
@property
1715+
def is_multi_type(self) -> bool:
1716+
"""Returns True if this experiment has multiple trial types registered."""
1717+
return len(self._trial_type_to_runner) > 1
1718+
1719+
@property
1720+
def metric_to_trial_type(self) -> dict[str, str]:
1721+
"""Read-only mapping of metric names to trial types."""
1722+
return self._metric_to_trial_type.copy()
1723+
1724+
def metrics_for_trial_type(self, trial_type: str) -> list[Metric]:
1725+
"""Returns metrics associated with a specific trial type.
1726+
1727+
Args:
1728+
trial_type: The trial type to get metrics for.
1729+
1730+
Returns:
1731+
List of metrics associated with the given trial type.
1732+
"""
1733+
# Special case for LONG_ and SHORT_RUN trial types, which we treat as "default"
1734+
# trial types for deployment.
1735+
if (
1736+
trial_type == Keys.SHORT_RUN or trial_type == Keys.LONG_RUN
1737+
) and self.supports_trial_type(trial_type=Keys.DEFAULT_TRIAL_TYPE):
1738+
return [
1739+
self.metrics[metric_name]
1740+
for metric_name, metric_trial_type in self._metric_to_trial_type.items()
1741+
if metric_trial_type == Keys.DEFAULT_TRIAL_TYPE
1742+
]
1743+
1744+
if not self.supports_trial_type(trial_type):
1745+
raise ValueError(f"Trial type `{trial_type}` is not supported.")
1746+
return [
1747+
self.metrics[metric_name]
1748+
for metric_name, metric_trial_type in self._metric_to_trial_type.items()
1749+
if metric_trial_type == trial_type
1750+
]
15841751

15851752
def attach_trial(
15861753
self,

0 commit comments

Comments
 (0)