Skip to content

Commit cf5ccde

Browse files
mpolson64meta-codesync[bot]
authored andcommitted
Remove bifrucation around MultiTypeExperiment
Summary: With recent changes to experiment we no longer need this bifructation. Next diff will remove places where we construct MultiTypeExperiment, and the one after will deprecate the class entirely Differential Revision: D91920991
1 parent 73184e0 commit cf5ccde

File tree

5 files changed

+24
-44
lines changed

5 files changed

+24
-44
lines changed

ax/orchestration/orchestrator.py

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from ax.core.multi_type_experiment import (
2828
filter_trials_by_type,
2929
get_trial_indices_for_statuses,
30-
MultiTypeExperiment,
3130
)
3231
from ax.core.runner import Runner
3332
from ax.core.trial import Trial
@@ -58,7 +57,7 @@
5857
set_ax_logger_levels,
5958
)
6059
from ax.utils.common.timeutils import current_timestamp_in_millis
61-
from pyre_extensions import assert_is_instance, none_throws
60+
from pyre_extensions import none_throws
6261

6362

6463
NOT_IMPLEMENTED_IN_BASE_CLASS_MSG = """ \
@@ -367,17 +366,11 @@ def options(self, options: OrchestratorOptions) -> None:
367366
def trial_type(self) -> str:
368367
"""Trial type for the experiment this Orchestrator is running.
369368
370-
This returns None if the experiment is not a MultitypeExperiment
371-
372369
Returns:
373-
Trial type for the experiment this Orchestrator is running if the
374-
experiment is a MultiTypeExperiment and None otherwise.
370+
Trial type for the experiment this Orchestrator is running.
371+
Defaults to Keys.DEFAULT_TRIAL_TYPE if not specified.
375372
"""
376-
if isinstance(self.experiment, MultiTypeExperiment):
377-
return (
378-
self.options.mt_experiment_trial_type or Keys.DEFAULT_TRIAL_TYPE.value
379-
)
380-
return Keys.DEFAULT_TRIAL_TYPE.value
373+
return self.options.mt_experiment_trial_type or Keys.DEFAULT_TRIAL_TYPE.value
381374

382375
@property
383376
def running_trials(self) -> list[BaseTrial]:
@@ -1619,23 +1612,14 @@ def _validate_options(self, options: OrchestratorOptions) -> None:
16191612
"will be unable to fetch intermediate results with which to "
16201613
"evaluate early stopping criteria."
16211614
)
1622-
if isinstance(self.experiment, MultiTypeExperiment):
1623-
if options.mt_experiment_trial_type is None:
1624-
raise UserInputError(
1625-
"Must specify `mt_experiment_trial_type` for MultiTypeExperiment."
1626-
)
1615+
if options.mt_experiment_trial_type is not None:
16271616
if not self.experiment.supports_trial_type(
16281617
options.mt_experiment_trial_type
16291618
):
16301619
raise ValueError(
16311620
"Experiment does not support trial type "
16321621
f"{options.mt_experiment_trial_type}."
16331622
)
1634-
elif options.mt_experiment_trial_type is not None:
1635-
raise UserInputError(
1636-
"`mt_experiment_trial_type` must be None unless the experiment is a "
1637-
"MultiTypeExperiment."
1638-
)
16391623

16401624
def _get_max_pending_trials(self) -> int:
16411625
"""Returns the maximum number of pending trials specified in the options, or

ax/orchestration/tests/test_orchestrator.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2736,12 +2736,9 @@ def test_validate_options_not_none_mt_trial_type(
27362736
self, msg: str | None = None
27372737
) -> None:
27382738
# test that error is raised if `mt_experiment_trial_type` is not
2739-
# compatible with the type of experiment (single or multi-type)
2739+
# a supported trial type for this experiment
27402740
if msg is None:
2741-
msg = (
2742-
"`mt_experiment_trial_type` must be None unless the experiment is a "
2743-
"MultiTypeExperiment."
2744-
)
2741+
msg = "Experiment does not support trial type type1."
27452742
options = OrchestratorOptions(
27462743
init_seconds_between_polls=0, # No wait bw polls so test is fast.
27472744
batch_size=10,
@@ -2752,7 +2749,7 @@ def test_validate_options_not_none_mt_trial_type(
27522749
),
27532750
)
27542751
gs = self.two_sobol_steps_GS
2755-
with self.assertRaisesRegex(UserInputError, msg):
2752+
with self.assertRaisesRegex(ValueError, msg):
27562753
Orchestrator(
27572754
experiment=self.branin_experiment,
27582755
generation_strategy=gs,
@@ -3010,10 +3007,11 @@ def test_fetch_and_process_trials_data_results_failed_non_objective(
30103007
def test_validate_options_not_none_mt_trial_type(
30113008
self, msg: str | None = None
30123009
) -> None:
3013-
# test if a MultiTypeExperiment with `mt_experiment_trial_type=None`
3014-
self.orchestrator_options_kwargs["mt_experiment_trial_type"] = None
3010+
# test that error is raised if `mt_experiment_trial_type` is not
3011+
# a supported trial type for this experiment (using an invalid type)
3012+
self.orchestrator_options_kwargs["mt_experiment_trial_type"] = "invalid_type"
30153013
super().test_validate_options_not_none_mt_trial_type(
3016-
msg="Must specify `mt_experiment_trial_type` for MultiTypeExperiment."
3014+
msg="Experiment does not support trial type invalid_type."
30173015
)
30183016

30193017
def test_run_n_trials_single_step_existing_experiment(

ax/service/ax_client.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from ax.core.evaluations_to_data import raw_evaluations_to_data
2525
from ax.core.experiment import Experiment
2626
from ax.core.generator_run import GeneratorRun
27-
from ax.core.multi_type_experiment import MultiTypeExperiment
2827
from ax.core.objective import MultiObjective, Objective
2928
from ax.core.observation import ObservationFeatures
3029
from ax.core.runner import Runner
@@ -458,15 +457,11 @@ def add_tracking_metrics(
458457
for metric_name in metric_names
459458
]
460459

461-
if isinstance(self.experiment, MultiTypeExperiment):
462-
experiment = assert_is_instance(self.experiment, MultiTypeExperiment)
463-
experiment.add_tracking_metrics(
464-
metrics=metric_objects,
465-
metrics_to_trial_types=metrics_to_trial_types,
466-
canonical_names=canonical_names,
467-
)
468-
else:
469-
self.experiment.add_tracking_metrics(metrics=metric_objects)
460+
self.experiment.add_tracking_metrics(
461+
metrics=metric_objects,
462+
metrics_to_trial_types=metrics_to_trial_types,
463+
**({"canonical_names": canonical_names} if canonical_names else {}),
464+
)
470465

471466
@copy_doc(Experiment.remove_tracking_metric)
472467
def remove_tracking_metric(self, metric_name: str) -> None:

ax/service/tests/test_report_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,9 +199,9 @@ def test_exp_to_df_with_failure(self) -> None:
199199
self.assertEqual(f"{fail_reason}...", df["reason"].iloc[0])
200200

201201
def test_exp_to_df(self) -> None:
202-
# MultiTypeExperiment should fail
202+
# Experiments with multiple trial types should fail
203203
exp = get_multi_type_experiment()
204-
with self.assertRaisesRegex(ValueError, "MultiTypeExperiment"):
204+
with self.assertRaisesRegex(ValueError, "multiple trial types"):
205205
exp_to_df(exp=exp)
206206

207207
# exp with no trials should return empty results

ax/service/utils/report_utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -787,8 +787,11 @@ def exp_to_df(
787787
)
788788

789789
# Accept Experiment and SimpleExperiment
790-
if isinstance(exp, MultiTypeExperiment):
791-
raise ValueError("Cannot transform MultiTypeExperiments to DataFrames.")
790+
# Reject experiments with multiple trial types as they need special handling
791+
if len(exp._trial_type_to_runner) > 1:
792+
raise ValueError(
793+
"Cannot transform experiments with multiple trial types to DataFrames."
794+
)
792795

793796
key_components = ["trial_index", "arm_name"]
794797

0 commit comments

Comments
 (0)