66
77# pyre-strict
88
9- from collections .abc import Iterable , Sequence
10- from typing import Any , Self
9+ from collections .abc import Sequence
1110
12- from ax .core .arm import Arm
13- from ax .core .base_trial import BaseTrial , TrialStatus
14- from ax .core .data import Data
11+ from ax .core .base_trial import BaseTrial
1512from ax .core .experiment import Experiment
16- from ax .core .metric import Metric , MetricFetchResult
17- from ax .core .optimization_config import OptimizationConfig
18- from ax .core .runner import Runner
19- from ax .core .search_space import SearchSpace
20- from ax .utils .common .docutils import copy_doc
21- from pyre_extensions import none_throws
13+ from ax .core .trial_status import TrialStatus
2214
2315
2416class MultiTypeExperiment (Experiment ):
2517 """Class for experiment with multiple trial types.
2618
27- A canonical use case for this is tuning a large production system
28- with limited evaluation budget and a simulator which approximates
29- evaluations on the main system. Trial deployment and data fetching
30- is separate for the two systems, but the final data is combined and
31- fed into multi-task models.
19+ .. deprecated::
20+ The `MultiTypeExperiment` class is deprecated. Use `Experiment` with
21+ `default_trial_type` parameter instead. All multi-type experiment
22+ functionality has been moved to the base `Experiment` class.
3223
33- See the Multi-Task Modeling tutorial for more details.
34-
35- Attributes:
36- name: Name of the experiment.
37- description: Description of the experiment.
3824 """
3925
40- def __init__ (
41- self ,
42- name : str ,
43- search_space : SearchSpace ,
44- default_trial_type : str ,
45- default_runner : Runner | None ,
46- optimization_config : OptimizationConfig | None = None ,
47- tracking_metrics : list [Metric ] | None = None ,
48- status_quo : Arm | None = None ,
49- description : str | None = None ,
50- is_test : bool = False ,
51- experiment_type : str | None = None ,
52- properties : dict [str , Any ] | None = None ,
53- default_data_type : Any = None ,
54- ) -> None :
55- """Inits Experiment.
56-
57- Args:
58- name: Name of the experiment.
59- search_space: Search space of the experiment.
60- default_trial_type: Default type for trials on this experiment.
61- default_runner: Default runner for trials of the default type.
62- optimization_config: Optimization config of the experiment.
63- tracking_metrics: Additional tracking metrics not used for optimization.
64- These are associated with the default trial type.
65- runner: Default runner used for trials on this experiment.
66- status_quo: Arm representing existing "control" arm.
67- description: Description of the experiment.
68- is_test: Convenience metadata tracker for the user to mark test experiments.
69- experiment_type: The class of experiments this one belongs to.
70- properties: Dictionary of this experiment's properties.
71- default_data_type: Deprecated and ignored.
72- """
73-
74- # Specifies which trial type each metric belongs to
75- self ._metric_to_trial_type : dict [str , str ] = {}
76-
77- # Maps certain metric names to a canonical name. Useful for ancillary trial
78- # types' metrics, to specify which primary metrics they correspond to
79- # (e.g. 'comment_prediction' => 'comment')
80- self ._metric_to_canonical_name : dict [str , str ] = {}
81-
82- # call super.__init__() after defining fields above, because we need
83- # them to be populated before optimization config is set
84- super ().__init__ (
85- name = name ,
86- search_space = search_space ,
87- optimization_config = optimization_config ,
88- status_quo = status_quo ,
89- description = description ,
90- is_test = is_test ,
91- experiment_type = experiment_type ,
92- properties = properties ,
93- tracking_metrics = tracking_metrics ,
94- runner = default_runner ,
95- default_trial_type = default_trial_type ,
96- default_data_type = default_data_type ,
97- )
98-
99- def add_trial_type (self , trial_type : str , runner : Runner ) -> Self :
100- """Add a new trial_type to be supported by this experiment.
101-
102- Args:
103- trial_type: The new trial_type to be added.
104- runner: The default runner for trials of this type.
105- """
106- if self .supports_trial_type (trial_type ):
107- raise ValueError (f"Experiment already contains trial_type `{ trial_type } `" )
108-
109- self ._trial_type_to_runner [trial_type ] = runner
110- return self
111-
112- # pyre-fixme [56]: Pyre was not able to infer the type of the decorator
113- # `Experiment.optimization_config.setter`.
114- @Experiment .optimization_config .setter
115- def optimization_config (self , optimization_config : OptimizationConfig ) -> None :
116- # pyre-fixme [16]: `Optional` has no attribute `fset`.
117- Experiment .optimization_config .fset (self , optimization_config )
118- for metric_name in optimization_config .metrics .keys ():
119- # Optimization config metrics are required to be the default trial type
120- # currently. TODO: remove that restriction (T202797235)
121- self ._metric_to_trial_type [metric_name ] = none_throws (
122- self .default_trial_type
123- )
124-
125- def update_runner (self , trial_type : str , runner : Runner ) -> Self :
126- """Update the default runner for an existing trial_type.
127-
128- Args:
129- trial_type: The new trial_type to be added.
130- runner: The new runner for trials of this type.
131- """
132- if not self .supports_trial_type (trial_type ):
133- raise ValueError (f"Experiment does not contain trial_type `{ trial_type } `" )
134-
135- self ._trial_type_to_runner [trial_type ] = runner
136- self ._runner = runner
137- return self
138-
139- def add_tracking_metric (
140- self ,
141- metric : Metric ,
142- trial_type : str | None = None ,
143- canonical_name : str | None = None ,
144- ) -> Self :
145- """Add a new metric to the experiment.
146-
147- Args:
148- metric: The metric to add.
149- trial_type: The trial type for which this metric is used.
150- canonical_name: The default metric for which this metric is a proxy.
151- """
152- if trial_type is None :
153- trial_type = self ._default_trial_type
154- if not self .supports_trial_type (trial_type ):
155- raise ValueError (f"`{ trial_type } ` is not a supported trial type." )
156-
157- super ().add_tracking_metric (metric )
158- self ._metric_to_trial_type [metric .name ] = none_throws (trial_type )
159- if canonical_name is not None :
160- self ._metric_to_canonical_name [metric .name ] = canonical_name
161- return self
162-
163- def add_tracking_metrics (
164- self ,
165- metrics : list [Metric ],
166- metrics_to_trial_types : dict [str , str ] | None = None ,
167- canonical_names : dict [str , str ] | None = None ,
168- ) -> Experiment :
169- """Add a list of new metrics to the experiment.
170-
171- If any of the metrics are already defined on the experiment,
172- we raise an error and don't add any of them to the experiment
173-
174- Args:
175- metrics: Metrics to be added.
176- metrics_to_trial_types: The mapping from metric names to corresponding
177- trial types for each metric. If provided, the metrics will be
178- added to their trial types. If not provided, then the default
179- trial type will be used.
180- canonical_names: A mapping of metric names to their
181- canonical names(The default metrics for which the metrics are
182- proxies.)
183-
184- Returns:
185- The experiment with the added metrics.
186- """
187- metrics_to_trial_types = metrics_to_trial_types or {}
188- canonical_name = None
189- for metric in metrics :
190- if canonical_names is not None :
191- canonical_name = none_throws (canonical_names ).get (metric .name , None )
192-
193- self .add_tracking_metric (
194- metric = metric ,
195- trial_type = metrics_to_trial_types .get (
196- metric .name , self ._default_trial_type
197- ),
198- canonical_name = canonical_name ,
199- )
200- return self
201-
202- def update_tracking_metric (
203- self ,
204- metric : Metric ,
205- trial_type : str | None = None ,
206- canonical_name : str | None = None ,
207- ) -> Self :
208- """Update an existing metric on the experiment.
209-
210- Args:
211- metric: The metric to add.
212- trial_type: The trial type for which this metric is used. Defaults to
213- the current trial type of the metric (if set), or the default trial
214- type otherwise.
215- canonical_name: The default metric for which this metric is a proxy.
216- """
217- # Default to the existing trial type if not specified
218- if trial_type is None :
219- trial_type = self ._metric_to_trial_type .get (
220- metric .name , self ._default_trial_type
221- )
222- oc = self .optimization_config
223- oc_metrics = oc .metrics if oc else []
224- if metric .name in oc_metrics and trial_type != self ._default_trial_type :
225- raise ValueError (
226- f"Metric `{ metric .name } ` must remain a "
227- f"`{ self ._default_trial_type } ` metric because it is part of the "
228- "optimization_config."
229- )
230- elif not self .supports_trial_type (trial_type ):
231- raise ValueError (f"`{ trial_type } ` is not a supported trial type." )
232-
233- super ().update_tracking_metric (metric )
234- self ._metric_to_trial_type [metric .name ] = none_throws (trial_type )
235- if canonical_name is not None :
236- self ._metric_to_canonical_name [metric .name ] = canonical_name
237- return self
238-
239- @copy_doc (Experiment .remove_tracking_metric )
240- def remove_tracking_metric (self , metric_name : str ) -> Self :
241- if metric_name not in self ._tracking_metrics :
242- raise ValueError (f"Metric `{ metric_name } ` doesn't exist on experiment." )
243-
244- # Required fields
245- del self ._tracking_metrics [metric_name ]
246- del self ._metric_to_trial_type [metric_name ]
247-
248- # Optional
249- if metric_name in self ._metric_to_canonical_name :
250- del self ._metric_to_canonical_name [metric_name ]
251- return self
252-
253- @copy_doc (Experiment .fetch_data )
254- def fetch_data (
255- self ,
256- trial_indices : Iterable [int ] | None = None ,
257- metrics : list [Metric ] | None = None ,
258- ** kwargs : Any ,
259- ) -> Data :
260- # TODO: make this more efficient for fetching
261- # data for multiple trials of the same type
262- # by overriding Experiment._lookup_or_fetch_trials_results
263- return Data .from_multiple_data (
264- [
265- (
266- trial .fetch_data (** kwargs , metrics = metrics )
267- if trial .status .expecting_data
268- else Data ()
269- )
270- for trial in self .trials .values ()
271- ]
272- )
273-
274- @copy_doc (Experiment ._fetch_trial_data )
275- def _fetch_trial_data (
276- self , trial_index : int , metrics : list [Metric ] | None = None , ** kwargs : Any
277- ) -> dict [str , MetricFetchResult ]:
278- trial = self .trials [trial_index ]
279- metrics = [
280- metric
281- for metric in (metrics or self .metrics .values ())
282- if self .metric_to_trial_type [metric .name ] == trial .trial_type
283- ]
284- # Invoke parent's fetch method using only metrics for this trial_type
285- return super ()._fetch_trial_data (trial .index , metrics = metrics , ** kwargs )
286-
287- @property
288- def default_trials (self ) -> set [int ]:
289- """Return the indicies for trials of the default type."""
290- return {
291- idx
292- for idx , trial in self .trials .items ()
293- if trial .trial_type == self .default_trial_type
294- }
295-
296- @property
297- def metric_to_trial_type (self ) -> dict [str , str ]:
298- """Map metrics to trial types.
299-
300- Adds in default trial type for OC metrics to custom defined trial types..
301- """
302- opt_config_types = {
303- metric_name : self .default_trial_type
304- for metric_name in self .optimization_config .metrics .keys ()
305- }
306- return {** opt_config_types , ** self ._metric_to_trial_type }
307-
308- # -- Overridden functions from Base Experiment Class --
309- @property
310- def default_trial_type (self ) -> str | None :
311- """Default trial type assigned to trials in this experiment."""
312- return self ._default_trial_type
313-
314- def metrics_for_trial_type (self , trial_type : str ) -> list [Metric ]:
315- """The default runner to use for a given trial type.
316-
317- Looks up the appropriate runner for this trial type in the trial_type_to_runner.
318- """
319- if not self .supports_trial_type (trial_type ):
320- raise ValueError (f"Trial type `{ trial_type } ` is not supported." )
321- return [
322- self .metrics [metric_name ]
323- for metric_name , metric_trial_type in self ._metric_to_trial_type .items ()
324- if metric_trial_type == trial_type
325- ]
326-
327- def supports_trial_type (self , trial_type : str | None ) -> bool :
328- """Whether this experiment allows trials of the given type.
329-
330- Only trial types defined in the trial_type_to_runner are allowed.
331- """
332- return trial_type in self ._trial_type_to_runner .keys ()
333-
33426
33527def filter_trials_by_type (
336- trials : Sequence [BaseTrial ], trial_type : str | None
28+ trials : Sequence [BaseTrial ],
29+ trial_type : str | None ,
33730) -> list [BaseTrial ]:
33831 """Filter trials by trial type if provided.
33932
@@ -352,7 +45,9 @@ def filter_trials_by_type(
35245
35346
35447def get_trial_indices_for_statuses (
355- experiment : Experiment , statuses : set [TrialStatus ], trial_type : str | None = None
48+ experiment : Experiment ,
49+ statuses : set [TrialStatus ],
50+ trial_type : str | None = None ,
35651) -> set [int ]:
35752 """Get trial indices for a set of statuses.
35853
0 commit comments