2121from ax .adapter .adapter_utils import get_fixed_features_from_experiment
2222from ax .adapter .base import Adapter
2323from ax .core .base_trial import BaseTrial
24- from ax .core .experiment import Experiment
25- from ax .core .generator_run import GeneratorRun
26- from ax .core .metric import Metric , MetricFetchE , MetricFetchResult
27- from ax .core .multi_type_experiment import (
24+ from ax .core .experiment import (
25+ Experiment ,
2826 filter_trials_by_type ,
2927 get_trial_indices_for_statuses ,
30- MultiTypeExperiment ,
3128)
29+ from ax .core .generator_run import GeneratorRun
30+ from ax .core .metric import Metric , MetricFetchE , MetricFetchResult
3231from ax .core .runner import Runner
3332from ax .core .trial import Trial
3433from ax .core .trial_status import TrialStatus
5857 set_ax_logger_levels ,
5958)
6059from ax .utils .common .timeutils import current_timestamp_in_millis
61- from pyre_extensions import assert_is_instance , none_throws
60+ from pyre_extensions import none_throws
6261
6362
6463NOT_IMPLEMENTED_IN_BASE_CLASS_MSG = """ \
@@ -367,21 +366,21 @@ def options(self, options: OrchestratorOptions) -> None:
367366 def trial_type (self ) -> str | None :
368367 """Trial type for the experiment this Orchestrator is running.
369368
370- This returns None if the experiment is not a MultitypeExperiment
369+ This returns None if the experiment does not have multiple trial types.
371370
372371 Returns:
373372 Trial type for the experiment this Orchestrator is running if the
374- experiment is a MultiTypeExperiment and None otherwise.
373+ experiment has multiple trial types and None otherwise.
375374 """
376- if isinstance ( self .experiment , MultiTypeExperiment ) :
375+ if self .experiment . default_trial_type is not None :
377376 return self .options .mt_experiment_trial_type
378377 return None
379378
380379 @property
381380 def running_trials (self ) -> list [BaseTrial ]:
382381 """Currently running trials.
383382
384- Note: if the experiment is a MultiTypeExperiment , then this will
383+ Note: if the experiment has multiple trial types , then this will
385384 only fetch trials of type `orchestrator.trial_type`.
386385
387386
@@ -397,7 +396,7 @@ def running_trials(self) -> list[BaseTrial]:
397396 def trials (self ) -> list [BaseTrial ]:
398397 """All trials.
399398
400- Note: if the experiment is a MultiTypeExperiment , then this will
399+ Note: if the experiment has multiple trial types , then this will
401400 only fetch trials of type `orchestrator.trial_type`.
402401
403402 Returns:
@@ -424,7 +423,7 @@ def running_trial_indices(self) -> set[int]:
424423 def failed_abandoned_trial_indices (self ) -> set [int ]:
425424 """Failed or abandoned trials.
426425
427- Note: if the experiment is a MultiTypeExperiment , then this will
426+ Note: if the experiment has multiple trial types , then this will
428427 only fetch trials of type `orchestrator.trial_type`.
429428
430429 Returns:
@@ -441,7 +440,7 @@ def pending_trials(self) -> list[BaseTrial]:
441440 """Running or staged trials on the experiment this Orchestrator is
442441 running.
443442
444- Note: if the experiment is a MultiTypeExperiment , then this will
443+ Note: if the experiment has multiple trial types , then this will
445444 only fetch trials of type `orchestrator.trial_type`.
446445
447446 Returns:
@@ -457,7 +456,7 @@ def pending_trials(self) -> list[BaseTrial]:
457456 def candidate_trials (self ) -> list [BaseTrial ]:
458457 """Candidate trials on the experiment this Orchestrator is running.
459458
460- Note: if the experiment is a MultiTypeExperiment , then this will
459+ Note: if the experiment has multiple trial types , then this will
461460 only fetch trials of type `orchestrator.trial_type`.
462461
463462 Returns:
@@ -472,7 +471,7 @@ def candidate_trials(self) -> list[BaseTrial]:
472471 def trials_expecting_data (self ) -> list [BaseTrial ]:
473472 """Trials expecting data.
474473
475- Note: if the experiment is a MultiTypeExperiment , then this will
474+ Note: if the experiment has multiple trial types , then this will
476475 only fetch trials of type `orchestrator.trial_type`.
477476 """
478477 trials = []
@@ -488,9 +487,9 @@ def runner(self) -> Runner:
488487 instance.
489488 """
490489 if self .trial_type is not None :
491- runner = assert_is_instance (
492- self .experiment , MultiTypeExperiment
493- ). runner_for_trial_type ( trial_type = none_throws ( self . trial_type ))
490+ runner = self . experiment . runner_for_trial_type (
491+ trial_type = none_throws ( self .trial_type )
492+ )
494493 else :
495494 runner = self .experiment .runner
496495 if runner is None :
@@ -1626,10 +1625,11 @@ def _validate_options(self, options: OrchestratorOptions) -> None:
16261625 "will be unable to fetch intermediate results with which to "
16271626 "evaluate early stopping criteria."
16281627 )
1629- if isinstance ( self .experiment , MultiTypeExperiment ) :
1628+ if self .experiment . default_trial_type is not None :
16301629 if options .mt_experiment_trial_type is None :
16311630 raise UserInputError (
1632- "Must specify `mt_experiment_trial_type` for MultiTypeExperiment."
1631+ "Must specify `mt_experiment_trial_type` for experiments "
1632+ "with multiple trial types."
16331633 )
16341634 if not self .experiment .supports_trial_type (
16351635 options .mt_experiment_trial_type
@@ -1640,8 +1640,8 @@ def _validate_options(self, options: OrchestratorOptions) -> None:
16401640 )
16411641 elif options .mt_experiment_trial_type is not None :
16421642 raise UserInputError (
1643- "`mt_experiment_trial_type` must be None unless the experiment is a "
1644- "MultiTypeExperiment ."
1643+ "`mt_experiment_trial_type` must be None unless the experiment "
1644+ "has multiple trial types ."
16451645 )
16461646
16471647 def _get_max_pending_trials (self ) -> int :
@@ -2040,9 +2040,9 @@ def _fetch_and_process_trials_data_results(
20402040 try :
20412041 kwargs = deepcopy (self .options .fetch_kwargs )
20422042 if self .trial_type is not None :
2043- metrics = assert_is_instance (
2044- self .experiment , MultiTypeExperiment
2045- ). metrics_for_trial_type ( trial_type = none_throws ( self . trial_type ))
2043+ metrics = self . experiment . metrics_for_trial_type (
2044+ trial_type = none_throws ( self .trial_type )
2045+ )
20462046 kwargs ["metrics" ] = metrics
20472047 results = self .experiment .fetch_trials_data_results (
20482048 trial_indices = trial_indices ,
0 commit comments