2121from ax .adapter .adapter_utils import get_fixed_features_from_experiment
2222from ax .adapter .base import Adapter
2323from ax .core .base_trial import BaseTrial
24- from ax .core .experiment import Experiment
25- from ax .core .generator_run import GeneratorRun
26- from ax .core .metric import Metric , MetricFetchE , MetricFetchResult
27- from ax .core .multi_type_experiment import (
24+ from ax .core .experiment import (
25+ Experiment ,
2826 filter_trials_by_type ,
2927 get_trial_indices_for_statuses ,
30- MultiTypeExperiment ,
3128)
29+ from ax .core .generator_run import GeneratorRun
30+ from ax .core .metric import Metric , MetricFetchE , MetricFetchResult
3231from ax .core .runner import Runner
3332from ax .core .trial import Trial
3433from ax .core .trial_status import TrialStatus
5857 set_ax_logger_levels ,
5958)
6059from ax .utils .common .timeutils import current_timestamp_in_millis
61- from pyre_extensions import assert_is_instance , none_throws
60+ from pyre_extensions import none_throws
6261
6362
6463NOT_IMPLEMENTED_IN_BASE_CLASS_MSG = """ \
@@ -367,21 +366,24 @@ def options(self, options: OrchestratorOptions) -> None:
367366 def trial_type (self ) -> str | None :
368367 """Trial type for the experiment this Orchestrator is running.
369368
370- This returns None if the experiment is not a MultitypeExperiment
369+ Returns ``None`` for single-type experiments (where
370+ ``default_trial_type`` is ``None``). For multi-type experiments,
371+ ``_validate_options`` guarantees that ``mt_experiment_trial_type`` is
372+ set, so the returned value is always a valid ``str`` in that case.
371373
372374 Returns:
373375 Trial type for the experiment this Orchestrator is running if the
374- experiment is a MultiTypeExperiment and None otherwise.
376+ experiment has multiple trial types and None otherwise.
375377 """
376- if isinstance ( self .experiment , MultiTypeExperiment ) :
378+ if self .experiment . default_trial_type is not None :
377379 return self .options .mt_experiment_trial_type
378380 return None
379381
380382 @property
381383 def running_trials (self ) -> list [BaseTrial ]:
382384 """Currently running trials.
383385
384- Note: if the experiment is a MultiTypeExperiment , then this will
386+ Note: if the experiment has multiple trial types , then this will
385387 only fetch trials of type `orchestrator.trial_type`.
386388
387389
@@ -397,7 +399,7 @@ def running_trials(self) -> list[BaseTrial]:
397399 def trials (self ) -> list [BaseTrial ]:
398400 """All trials.
399401
400- Note: if the experiment is a MultiTypeExperiment , then this will
402+ Note: if the experiment has multiple trial types , then this will
401403 only fetch trials of type `orchestrator.trial_type`.
402404
403405 Returns:
@@ -424,7 +426,7 @@ def running_trial_indices(self) -> set[int]:
424426 def failed_abandoned_trial_indices (self ) -> set [int ]:
425427 """Failed or abandoned trials.
426428
427- Note: if the experiment is a MultiTypeExperiment , then this will
429+ Note: if the experiment has multiple trial types , then this will
428430 only fetch trials of type `orchestrator.trial_type`.
429431
430432 Returns:
@@ -441,7 +443,7 @@ def pending_trials(self) -> list[BaseTrial]:
441443 """Running or staged trials on the experiment this Orchestrator is
442444 running.
443445
444- Note: if the experiment is a MultiTypeExperiment , then this will
446+ Note: if the experiment has multiple trial types , then this will
445447 only fetch trials of type `orchestrator.trial_type`.
446448
447449 Returns:
@@ -457,7 +459,7 @@ def pending_trials(self) -> list[BaseTrial]:
457459 def candidate_trials (self ) -> list [BaseTrial ]:
458460 """Candidate trials on the experiment this Orchestrator is running.
459461
460- Note: if the experiment is a MultiTypeExperiment , then this will
462+ Note: if the experiment has multiple trial types , then this will
461463 only fetch trials of type `orchestrator.trial_type`.
462464
463465 Returns:
@@ -472,7 +474,7 @@ def candidate_trials(self) -> list[BaseTrial]:
472474 def trials_expecting_data (self ) -> list [BaseTrial ]:
473475 """Trials expecting data.
474476
475- Note: if the experiment is a MultiTypeExperiment , then this will
477+ Note: if the experiment has multiple trial types , then this will
476478 only fetch trials of type `orchestrator.trial_type`.
477479 """
478480 trials = []
@@ -488,9 +490,9 @@ def runner(self) -> Runner:
488490 instance.
489491 """
490492 if self .trial_type is not None :
491- runner = assert_is_instance (
492- self .experiment , MultiTypeExperiment
493- ). runner_for_trial_type ( trial_type = none_throws ( self . trial_type ))
493+ runner = self . experiment . runner_for_trial_type (
494+ trial_type = none_throws ( self .trial_type )
495+ )
494496 else :
495497 runner = self .experiment .runner
496498 if runner is None :
@@ -1626,10 +1628,11 @@ def _validate_options(self, options: OrchestratorOptions) -> None:
16261628 "will be unable to fetch intermediate results with which to "
16271629 "evaluate early stopping criteria."
16281630 )
1629- if isinstance ( self .experiment , MultiTypeExperiment ) :
1631+ if self .experiment . default_trial_type is not None :
16301632 if options .mt_experiment_trial_type is None :
16311633 raise UserInputError (
1632- "Must specify `mt_experiment_trial_type` for MultiTypeExperiment."
1634+ "Must specify `mt_experiment_trial_type` for experiments "
1635+ "with multiple trial types."
16331636 )
16341637 if not self .experiment .supports_trial_type (
16351638 options .mt_experiment_trial_type
@@ -1640,8 +1643,8 @@ def _validate_options(self, options: OrchestratorOptions) -> None:
16401643 )
16411644 elif options .mt_experiment_trial_type is not None :
16421645 raise UserInputError (
1643- "`mt_experiment_trial_type` must be None unless the experiment is a "
1644- "MultiTypeExperiment ."
1646+ "`mt_experiment_trial_type` must be None unless the experiment "
1647+ "has multiple trial types ."
16451648 )
16461649
16471650 def _get_max_pending_trials (self ) -> int :
@@ -2040,9 +2043,9 @@ def _fetch_and_process_trials_data_results(
20402043 try :
20412044 kwargs = deepcopy (self .options .fetch_kwargs )
20422045 if self .trial_type is not None :
2043- metrics = assert_is_instance (
2044- self .experiment , MultiTypeExperiment
2045- ). metrics_for_trial_type ( trial_type = none_throws ( self . trial_type ))
2046+ metrics = self . experiment . metrics_for_trial_type (
2047+ trial_type = none_throws ( self .trial_type )
2048+ )
20462049 kwargs ["metrics" ] = metrics
20472050 results = self .experiment .fetch_trials_data_results (
20482051 trial_indices = trial_indices ,
0 commit comments