27
27
from ax .core .generator_run import GeneratorRun
28
28
from ax .core .map_data import MapData
29
29
from ax .core .map_metric import MapMetric
30
+ from ax .core .multi_type_experiment import MultiTypeExperiment
30
31
from ax .core .objective import MultiObjective , Objective
31
32
from ax .core .observation import ObservationFeatures
32
33
from ax .core .optimization_config import (
33
34
MultiObjectiveOptimizationConfig ,
34
35
OptimizationConfig ,
35
36
)
37
+ from ax .core .runner import Runner
36
38
from ax .core .trial import Trial
37
39
from ax .core .types import (
38
40
TEvaluationOutcome ,
39
41
TModelPredictArm ,
40
42
TParameterization ,
41
43
TParamValue ,
42
44
)
45
+
43
46
from ax .core .utils import get_pending_observation_features_based_on_trial_status
44
47
from ax .early_stopping .strategies import BaseEarlyStoppingStrategy
45
48
from ax .early_stopping .utils import estimate_early_stopping_savings
90
93
from ax .utils .common .typeutils import checked_cast
91
94
from pyre_extensions import assert_is_instance , none_throws
92
95
96
+
93
97
logger : Logger = get_logger (__name__ )
94
98
95
99
@@ -251,6 +255,8 @@ def create_experiment(
251
255
immutable_search_space_and_opt_config : bool = True ,
252
256
is_test : bool = False ,
253
257
metric_definitions : dict [str , dict [str , Any ]] | None = None ,
258
+ default_trial_type : str | None = None ,
259
+ default_runner : Runner | None = None ,
254
260
) -> None :
255
261
"""Create a new experiment and save it if DBSettings available.
256
262
@@ -316,6 +322,15 @@ def create_experiment(
316
322
to that metric. Note these are modified in-place. Each
317
323
Metric must have its own dictionary (metrics cannot share a
318
324
single dictionary object).
325
+ default_trial_type: The default trial type if multiple
326
+ trial types are intended to be used in the experiment. If specified,
327
+ a MultiTypeExperiment will be created. Otherwise, a single-type
328
+ Experiment will be created.
329
+ default_runner: The default runner in this experiment.
330
+ This applies to MultiTypeExperiment (when default_trial_type
331
+ is specified) and needs to be specified together with
332
+ default_trial_type. This will be ignored for single-type Experiment
333
+ (when default_trial_type is not specified).
319
334
"""
320
335
self ._validate_early_stopping_strategy (support_intermediate_data )
321
336
@@ -344,6 +359,8 @@ def create_experiment(
344
359
support_intermediate_data = support_intermediate_data ,
345
360
immutable_search_space_and_opt_config = immutable_search_space_and_opt_config ,
346
361
is_test = is_test ,
362
+ default_trial_type = default_trial_type ,
363
+ default_runner = default_runner ,
347
364
** objective_kwargs ,
348
365
)
349
366
self ._set_runner (experiment = experiment )
@@ -416,6 +433,8 @@ def add_tracking_metrics(
416
433
self ,
417
434
metric_names : list [str ],
418
435
metric_definitions : dict [str , dict [str , Any ]] | None = None ,
436
+ metrics_to_trial_types : dict [str , str ] | None = None ,
437
+ canonical_names : dict [str , str ] | None = None ,
419
438
) -> None :
420
439
"""Add a list of new metrics to the experiment.
421
440
@@ -428,20 +447,34 @@ def add_tracking_metrics(
428
447
to that metric. Note these are modified in-place. Each
429
448
Metric must have its is own dictionary (metrics cannot share a
430
449
single dictionary object).
450
+ metrics_to_trial_types: Only applicable to MultiTypeExperiment.
451
+ The mapping from metric names to corresponding
452
+ trial types for each metric. If provided, the metrics will be
453
+ added with their respective trial types. If not provided, then the
454
+ default trial type will be used.
455
+ canonical_names: A mapping from metric name (of a particular trial type)
456
+ to the metric name of the default trial type. Only applicable to
457
+ MultiTypeExperiment.
431
458
"""
432
459
metric_definitions = (
433
460
self .metric_definitions
434
461
if metric_definitions is None
435
462
else metric_definitions
436
463
)
437
- self .experiment .add_tracking_metrics (
438
- metrics = [
439
- self ._make_metric (
440
- name = metric_name , metric_definitions = metric_definitions
441
- )
442
- for metric_name in metric_names
443
- ]
444
- )
464
+ metric_objects = [
465
+ self ._make_metric (name = metric_name , metric_definitions = metric_definitions )
466
+ for metric_name in metric_names
467
+ ]
468
+
469
+ if isinstance (self .experiment , MultiTypeExperiment ):
470
+ experiment = assert_is_instance (self .experiment , MultiTypeExperiment )
471
+ experiment .add_tracking_metrics (
472
+ metrics = metric_objects ,
473
+ metrics_to_trial_types = metrics_to_trial_types ,
474
+ canonical_names = canonical_names ,
475
+ )
476
+ else :
477
+ self .experiment .add_tracking_metrics (metrics = metric_objects )
445
478
446
479
@copy_doc (Experiment .remove_tracking_metric )
447
480
def remove_tracking_metric (self , metric_name : str ) -> None :
0 commit comments