Skip to content

Commit c56f900

Browse files
mpolson64facebook-github-bot
authored andcommitted
Remove MultiTypeExperiment (#4875)
Summary: Guts MultiTypeExperiment class down to a deprecation warning, replaces all places which initialize a MultiTypeExperiment with a base Experiment, and updated type annotations. Updated storage accordingly; previously stored MultiTypeExperiments will be correctly decoded as Experiments. As previously discussed, also deprecated metric_to_cannonical_name mapping as we intend to reexamine this design in the context of the metric_signature field in Data we added in H2 2025 Differential Revision: D92089176
1 parent 67ff859 commit c56f900

File tree

13 files changed

+201
-533
lines changed

13 files changed

+201
-533
lines changed

ax/core/experiment.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1225,6 +1225,15 @@ def trial_indices_with_data(
12251225

12261226
return trials_with_data
12271227

1228+
@property
1229+
def default_trials(self) -> set[int]:
1230+
"""Return the indicies for trials of the default type."""
1231+
return {
1232+
idx
1233+
for idx, trial in self.trials.items()
1234+
if trial.trial_type == self.default_trial_type
1235+
}
1236+
12281237
def new_trial(
12291238
self,
12301239
generator_run: GeneratorRun | None = None,

ax/core/multi_type_experiment.py

Lines changed: 12 additions & 317 deletions
Original file line numberDiff line numberDiff line change
@@ -6,334 +6,27 @@
66

77
# pyre-strict
88

9-
from collections.abc import Iterable, Sequence
10-
from typing import Any, Self
9+
from collections.abc import Sequence
1110

12-
from ax.core.arm import Arm
13-
from ax.core.base_trial import BaseTrial, TrialStatus
14-
from ax.core.data import Data
11+
from ax.core.base_trial import BaseTrial
1512
from ax.core.experiment import Experiment
16-
from ax.core.metric import Metric, MetricFetchResult
17-
from ax.core.optimization_config import OptimizationConfig
18-
from ax.core.runner import Runner
19-
from ax.core.search_space import SearchSpace
20-
from ax.utils.common.docutils import copy_doc
21-
from pyre_extensions import none_throws
13+
from ax.core.trial_status import TrialStatus
2214

2315

2416
class MultiTypeExperiment(Experiment):
2517
"""Class for experiment with multiple trial types.
2618
27-
A canonical use case for this is tuning a large production system
28-
with limited evaluation budget and a simulator which approximates
29-
evaluations on the main system. Trial deployment and data fetching
30-
is separate for the two systems, but the final data is combined and
31-
fed into multi-task models.
19+
.. deprecated::
20+
The `MultiTypeExperiment` class is deprecated. Use `Experiment` with
21+
`default_trial_type` parameter instead. All multi-type experiment
22+
functionality has been moved to the base `Experiment` class.
3223
33-
See the Multi-Task Modeling tutorial for more details.
34-
35-
Attributes:
36-
name: Name of the experiment.
37-
description: Description of the experiment.
3824
"""
3925

40-
def __init__(
41-
self,
42-
name: str,
43-
search_space: SearchSpace,
44-
default_trial_type: str,
45-
default_runner: Runner | None,
46-
optimization_config: OptimizationConfig | None = None,
47-
tracking_metrics: list[Metric] | None = None,
48-
status_quo: Arm | None = None,
49-
description: str | None = None,
50-
is_test: bool = False,
51-
experiment_type: str | None = None,
52-
properties: dict[str, Any] | None = None,
53-
default_data_type: Any = None,
54-
) -> None:
55-
"""Inits Experiment.
56-
57-
Args:
58-
name: Name of the experiment.
59-
search_space: Search space of the experiment.
60-
default_trial_type: Default type for trials on this experiment.
61-
default_runner: Default runner for trials of the default type.
62-
optimization_config: Optimization config of the experiment.
63-
tracking_metrics: Additional tracking metrics not used for optimization.
64-
These are associated with the default trial type.
65-
runner: Default runner used for trials on this experiment.
66-
status_quo: Arm representing existing "control" arm.
67-
description: Description of the experiment.
68-
is_test: Convenience metadata tracker for the user to mark test experiments.
69-
experiment_type: The class of experiments this one belongs to.
70-
properties: Dictionary of this experiment's properties.
71-
default_data_type: Deprecated and ignored.
72-
"""
73-
74-
# Specifies which trial type each metric belongs to
75-
self._metric_to_trial_type: dict[str, str] = {}
76-
77-
# Maps certain metric names to a canonical name. Useful for ancillary trial
78-
# types' metrics, to specify which primary metrics they correspond to
79-
# (e.g. 'comment_prediction' => 'comment')
80-
self._metric_to_canonical_name: dict[str, str] = {}
81-
82-
# call super.__init__() after defining fields above, because we need
83-
# them to be populated before optimization config is set
84-
super().__init__(
85-
name=name,
86-
search_space=search_space,
87-
optimization_config=optimization_config,
88-
status_quo=status_quo,
89-
description=description,
90-
is_test=is_test,
91-
experiment_type=experiment_type,
92-
properties=properties,
93-
tracking_metrics=tracking_metrics,
94-
runner=default_runner,
95-
default_trial_type=default_trial_type,
96-
default_data_type=default_data_type,
97-
)
98-
99-
def add_trial_type(self, trial_type: str, runner: Runner) -> Self:
100-
"""Add a new trial_type to be supported by this experiment.
101-
102-
Args:
103-
trial_type: The new trial_type to be added.
104-
runner: The default runner for trials of this type.
105-
"""
106-
if self.supports_trial_type(trial_type):
107-
raise ValueError(f"Experiment already contains trial_type `{trial_type}`")
108-
109-
self._trial_type_to_runner[trial_type] = runner
110-
return self
111-
112-
# pyre-fixme [56]: Pyre was not able to infer the type of the decorator
113-
# `Experiment.optimization_config.setter`.
114-
@Experiment.optimization_config.setter
115-
def optimization_config(self, optimization_config: OptimizationConfig) -> None:
116-
# pyre-fixme [16]: `Optional` has no attribute `fset`.
117-
Experiment.optimization_config.fset(self, optimization_config)
118-
for metric_name in optimization_config.metrics.keys():
119-
# Optimization config metrics are required to be the default trial type
120-
# currently. TODO: remove that restriction (T202797235)
121-
self._metric_to_trial_type[metric_name] = none_throws(
122-
self.default_trial_type
123-
)
124-
125-
def update_runner(self, trial_type: str, runner: Runner) -> Self:
126-
"""Update the default runner for an existing trial_type.
127-
128-
Args:
129-
trial_type: The new trial_type to be added.
130-
runner: The new runner for trials of this type.
131-
"""
132-
if not self.supports_trial_type(trial_type):
133-
raise ValueError(f"Experiment does not contain trial_type `{trial_type}`")
134-
135-
self._trial_type_to_runner[trial_type] = runner
136-
self._runner = runner
137-
return self
138-
139-
def add_tracking_metric(
140-
self,
141-
metric: Metric,
142-
trial_type: str | None = None,
143-
canonical_name: str | None = None,
144-
) -> Self:
145-
"""Add a new metric to the experiment.
146-
147-
Args:
148-
metric: The metric to add.
149-
trial_type: The trial type for which this metric is used.
150-
canonical_name: The default metric for which this metric is a proxy.
151-
"""
152-
if trial_type is None:
153-
trial_type = self._default_trial_type
154-
if not self.supports_trial_type(trial_type):
155-
raise ValueError(f"`{trial_type}` is not a supported trial type.")
156-
157-
super().add_tracking_metric(metric)
158-
self._metric_to_trial_type[metric.name] = none_throws(trial_type)
159-
if canonical_name is not None:
160-
self._metric_to_canonical_name[metric.name] = canonical_name
161-
return self
162-
163-
def add_tracking_metrics(
164-
self,
165-
metrics: list[Metric],
166-
metrics_to_trial_types: dict[str, str] | None = None,
167-
canonical_names: dict[str, str] | None = None,
168-
) -> Experiment:
169-
"""Add a list of new metrics to the experiment.
170-
171-
If any of the metrics are already defined on the experiment,
172-
we raise an error and don't add any of them to the experiment
173-
174-
Args:
175-
metrics: Metrics to be added.
176-
metrics_to_trial_types: The mapping from metric names to corresponding
177-
trial types for each metric. If provided, the metrics will be
178-
added to their trial types. If not provided, then the default
179-
trial type will be used.
180-
canonical_names: A mapping of metric names to their
181-
canonical names(The default metrics for which the metrics are
182-
proxies.)
183-
184-
Returns:
185-
The experiment with the added metrics.
186-
"""
187-
metrics_to_trial_types = metrics_to_trial_types or {}
188-
canonical_name = None
189-
for metric in metrics:
190-
if canonical_names is not None:
191-
canonical_name = none_throws(canonical_names).get(metric.name, None)
192-
193-
self.add_tracking_metric(
194-
metric=metric,
195-
trial_type=metrics_to_trial_types.get(
196-
metric.name, self._default_trial_type
197-
),
198-
canonical_name=canonical_name,
199-
)
200-
return self
201-
202-
def update_tracking_metric(
203-
self,
204-
metric: Metric,
205-
trial_type: str | None = None,
206-
canonical_name: str | None = None,
207-
) -> Self:
208-
"""Update an existing metric on the experiment.
209-
210-
Args:
211-
metric: The metric to add.
212-
trial_type: The trial type for which this metric is used. Defaults to
213-
the current trial type of the metric (if set), or the default trial
214-
type otherwise.
215-
canonical_name: The default metric for which this metric is a proxy.
216-
"""
217-
# Default to the existing trial type if not specified
218-
if trial_type is None:
219-
trial_type = self._metric_to_trial_type.get(
220-
metric.name, self._default_trial_type
221-
)
222-
oc = self.optimization_config
223-
oc_metrics = oc.metrics if oc else []
224-
if metric.name in oc_metrics and trial_type != self._default_trial_type:
225-
raise ValueError(
226-
f"Metric `{metric.name}` must remain a "
227-
f"`{self._default_trial_type}` metric because it is part of the "
228-
"optimization_config."
229-
)
230-
elif not self.supports_trial_type(trial_type):
231-
raise ValueError(f"`{trial_type}` is not a supported trial type.")
232-
233-
super().update_tracking_metric(metric)
234-
self._metric_to_trial_type[metric.name] = none_throws(trial_type)
235-
if canonical_name is not None:
236-
self._metric_to_canonical_name[metric.name] = canonical_name
237-
return self
238-
239-
@copy_doc(Experiment.remove_tracking_metric)
240-
def remove_tracking_metric(self, metric_name: str) -> Self:
241-
if metric_name not in self._tracking_metrics:
242-
raise ValueError(f"Metric `{metric_name}` doesn't exist on experiment.")
243-
244-
# Required fields
245-
del self._tracking_metrics[metric_name]
246-
del self._metric_to_trial_type[metric_name]
247-
248-
# Optional
249-
if metric_name in self._metric_to_canonical_name:
250-
del self._metric_to_canonical_name[metric_name]
251-
return self
252-
253-
@copy_doc(Experiment.fetch_data)
254-
def fetch_data(
255-
self,
256-
trial_indices: Iterable[int] | None = None,
257-
metrics: list[Metric] | None = None,
258-
**kwargs: Any,
259-
) -> Data:
260-
# TODO: make this more efficient for fetching
261-
# data for multiple trials of the same type
262-
# by overriding Experiment._lookup_or_fetch_trials_results
263-
return Data.from_multiple_data(
264-
[
265-
(
266-
trial.fetch_data(**kwargs, metrics=metrics)
267-
if trial.status.expecting_data
268-
else Data()
269-
)
270-
for trial in self.trials.values()
271-
]
272-
)
273-
274-
@copy_doc(Experiment._fetch_trial_data)
275-
def _fetch_trial_data(
276-
self, trial_index: int, metrics: list[Metric] | None = None, **kwargs: Any
277-
) -> dict[str, MetricFetchResult]:
278-
trial = self.trials[trial_index]
279-
metrics = [
280-
metric
281-
for metric in (metrics or self.metrics.values())
282-
if self.metric_to_trial_type[metric.name] == trial.trial_type
283-
]
284-
# Invoke parent's fetch method using only metrics for this trial_type
285-
return super()._fetch_trial_data(trial.index, metrics=metrics, **kwargs)
286-
287-
@property
288-
def default_trials(self) -> set[int]:
289-
"""Return the indicies for trials of the default type."""
290-
return {
291-
idx
292-
for idx, trial in self.trials.items()
293-
if trial.trial_type == self.default_trial_type
294-
}
295-
296-
@property
297-
def metric_to_trial_type(self) -> dict[str, str]:
298-
"""Map metrics to trial types.
299-
300-
Adds in default trial type for OC metrics to custom defined trial types..
301-
"""
302-
opt_config_types = {
303-
metric_name: self.default_trial_type
304-
for metric_name in self.optimization_config.metrics.keys()
305-
}
306-
return {**opt_config_types, **self._metric_to_trial_type}
307-
308-
# -- Overridden functions from Base Experiment Class --
309-
@property
310-
def default_trial_type(self) -> str | None:
311-
"""Default trial type assigned to trials in this experiment."""
312-
return self._default_trial_type
313-
314-
def metrics_for_trial_type(self, trial_type: str) -> list[Metric]:
315-
"""The default runner to use for a given trial type.
316-
317-
Looks up the appropriate runner for this trial type in the trial_type_to_runner.
318-
"""
319-
if not self.supports_trial_type(trial_type):
320-
raise ValueError(f"Trial type `{trial_type}` is not supported.")
321-
return [
322-
self.metrics[metric_name]
323-
for metric_name, metric_trial_type in self._metric_to_trial_type.items()
324-
if metric_trial_type == trial_type
325-
]
326-
327-
def supports_trial_type(self, trial_type: str | None) -> bool:
328-
"""Whether this experiment allows trials of the given type.
329-
330-
Only trial types defined in the trial_type_to_runner are allowed.
331-
"""
332-
return trial_type in self._trial_type_to_runner.keys()
333-
33426

33527
def filter_trials_by_type(
336-
trials: Sequence[BaseTrial], trial_type: str | None
28+
trials: Sequence[BaseTrial],
29+
trial_type: str | None,
33730
) -> list[BaseTrial]:
33831
"""Filter trials by trial type if provided.
33932
@@ -352,7 +45,9 @@ def filter_trials_by_type(
35245

35346

35447
def get_trial_indices_for_statuses(
355-
experiment: Experiment, statuses: set[TrialStatus], trial_type: str | None = None
48+
experiment: Experiment,
49+
statuses: set[TrialStatus],
50+
trial_type: str | None = None,
35651
) -> set[int]:
35752
"""Get trial indices for a set of statuses.
35853

0 commit comments

Comments
 (0)