Skip to content

Commit 7ab24b3

Browse files
mpolson64meta-codesync[bot]
authored andcommitted
Move add_trial_type, update_runner, supports_trial_type to base Experiment (#5003)
Summary: Pull Request resolved: #5003 Phase 3 of moving MultiTypeExperiment features into base Experiment. Moves `add_trial_type` and `update_runner` from MultiTypeExperiment to the base Experiment class, making them available to all experiments. Updates `supports_trial_type` to unify the logic: for multi-type experiments (where `default_trial_type` is set), only trial types registered in `_trial_type_to_runner` are supported. For single-type experiments, `None` is supported along with `SHORT_RUN` and `LONG_RUN` for backward compatibility with generation strategies that use those trial types. Removes the corresponding overrides from MultiTypeExperiment — all three methods are now inherited from the base class. Reviewed By: saitcakmak Differential Revision: D94988577
1 parent ec492dc commit 7ab24b3

2 files changed

Lines changed: 36 additions & 36 deletions

File tree

ax/core/experiment.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2033,6 +2033,34 @@ def default_trials(self) -> set[int]:
20332033
if trial.trial_type == self.default_trial_type
20342034
}
20352035

2036+
def add_trial_type(self, trial_type: str, runner: Runner | None = None) -> Self:
2037+
"""Add a new trial type to be supported by this experiment.
2038+
2039+
Args:
2040+
trial_type: The new trial type to be added.
2041+
runner: The default runner for trials of this type.
2042+
"""
2043+
if self.supports_trial_type(trial_type):
2044+
raise ValueError(f"Experiment already contains trial_type `{trial_type}`")
2045+
2046+
if runner is not None:
2047+
self._trial_type_to_runner[trial_type] = runner
2048+
2049+
return self
2050+
2051+
def update_runner(self, trial_type: str, runner: Runner) -> Self:
2052+
"""Update the default runner for an existing trial type.
2053+
2054+
Args:
2055+
trial_type: The trial type whose runner should be updated.
2056+
runner: The new runner for trials of this type.
2057+
"""
2058+
if not self.supports_trial_type(trial_type):
2059+
raise ValueError(f"Experiment does not contain trial_type `{trial_type}`")
2060+
self._trial_type_to_runner[trial_type] = runner
2061+
self._runner = runner
2062+
return self
2063+
20362064
def runner_for_trial_type(self, trial_type: str | None) -> Runner | None:
20372065
"""The default runner to use for a given trial type.
20382066
@@ -2047,14 +2075,20 @@ def runner_for_trial_type(self, trial_type: str | None) -> Runner | None:
20472075
def supports_trial_type(self, trial_type: str | None) -> bool:
20482076
"""Whether this experiment allows trials of the given type.
20492077
2050-
The base experiment class only supports None. For experiments
2051-
with multiple trial types, use the MultiTypeExperiment class.
2078+
For experiments with a ``default_trial_type`` (multi-type experiments),
2079+
only trial types registered in ``_trial_type_to_runner`` are supported.
2080+
For single-type experiments, ``None`` is always supported, along with
2081+
``SHORT_RUN`` and ``LONG_RUN`` for backward compatibility with
2082+
generation strategies that use those trial types.
20522083
"""
2084+
if self._default_trial_type is not None:
2085+
return trial_type in self._trial_type_to_runner
20532086
return (
20542087
trial_type is None
20552088
or trial_type == Keys.SHORT_RUN
20562089
or trial_type == Keys.LONG_RUN
20572090
or trial_type == Keys.LILO_LABELING
2091+
or trial_type in self._trial_type_to_runner
20582092
)
20592093

20602094
def attach_trial(

ax/core/multi_type_experiment.py

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -104,19 +104,6 @@ def __init__(
104104
self._default_trial_type
105105
)
106106

107-
def add_trial_type(self, trial_type: str, runner: Runner) -> Self:
108-
"""Add a new trial_type to be supported by this experiment.
109-
110-
Args:
111-
trial_type: The new trial_type to be added.
112-
runner: The default runner for trials of this type.
113-
"""
114-
if self.supports_trial_type(trial_type):
115-
raise ValueError(f"Experiment already contains trial_type `{trial_type}`")
116-
117-
self._trial_type_to_runner[trial_type] = runner
118-
return self
119-
120107
# pyre does not support inferring the type of property setter decorators
121108
# or the `.fset` attribute on properties.
122109
# pyre-fixme[56]: Pyre was not able to infer the type of the decorator.
@@ -130,20 +117,6 @@ def optimization_config(self, optimization_config: OptimizationConfig) -> None:
130117
self.default_trial_type
131118
)
132119

133-
def update_runner(self, trial_type: str, runner: Runner) -> Self:
134-
"""Update the default runner for an existing trial_type.
135-
136-
Args:
137-
trial_type: The new trial_type to be added.
138-
runner: The new runner for trials of this type.
139-
"""
140-
if not self.supports_trial_type(trial_type):
141-
raise ValueError(f"Experiment does not contain trial_type `{trial_type}`")
142-
143-
self._trial_type_to_runner[trial_type] = runner
144-
self._runner = runner
145-
return self
146-
147120
def add_tracking_metric(
148121
self,
149122
metric: Metric,
@@ -232,13 +205,6 @@ def _fetch_trial_data(
232205
# Invoke parent's fetch method using only metrics for this trial_type
233206
return super()._fetch_trial_data(trial.index, metrics=metrics, **kwargs)
234207

235-
def supports_trial_type(self, trial_type: str | None) -> bool:
236-
"""Whether this experiment allows trials of the given type.
237-
238-
Only trial types defined in the trial_type_to_runner are allowed.
239-
"""
240-
return trial_type in self._trial_type_to_runner.keys()
241-
242208

243209
def filter_trials_by_type(
244210
trials: Sequence[BaseTrial], trial_type: str | None

0 commit comments

Comments
 (0)