Skip to content

Commit 74abf89

Browse files
mpolson64facebook-github-bot
authored andcommitted
Replace isinstance(MultiTypeExperiment) checks with feature checks (#5001)
Summary: Move `filter_trials_by_type` and `get_trial_indices_for_statuses` from `multi_type_experiment.py` to `experiment.py`, with backward-compatible re-exports from the old module. Replace all `isinstance(experiment, MultiTypeExperiment)` checks with `experiment.default_trial_type is not None` feature checks in: - `ax/orchestration/orchestrator.py` - `ax/fb/realtime/utils.py` - `ax/fb/axolotl/utils_chronos.py` Replace `assert_is_instance(experiment, MultiTypeExperiment).method()` calls with direct calls on base `Experiment` (which now has `runner_for_trial_type` and `metrics_for_trial_type`). This is Phase 6 of the MultiTypeExperiment pull-up plan. Differential Revision: D94999020
1 parent ccf3dea commit 74abf89

4 files changed

Lines changed: 87 additions & 70 deletions

File tree

ax/core/experiment.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2795,3 +2795,47 @@ def add_arm_and_prevent_naming_collision(
27952795
stacklevel=2,
27962796
)
27972797
new_trial.add_arm(none_throws(old_trial.arm).clone(clear_name=True))
2798+
2799+
2800+
def filter_trials_by_type(
2801+
trials: Sequence[BaseTrial], trial_type: str | None
2802+
) -> list[BaseTrial]:
2803+
"""Filter trials by trial type if provided.
2804+
2805+
This filters trials by trial type if the experiment has multiple
2806+
trial types.
2807+
2808+
Args:
2809+
trials: Trials to filter.
2810+
trial_type: The trial type to filter by. If None, all trials are returned.
2811+
2812+
Returns:
2813+
Filtered trials.
2814+
"""
2815+
if trial_type is not None:
2816+
return [t for t in trials if t.trial_type == trial_type]
2817+
return list(trials)
2818+
2819+
2820+
def get_trial_indices_for_statuses(
2821+
experiment: Experiment, statuses: set[TrialStatus], trial_type: str | None = None
2822+
) -> set[int]:
2823+
"""Get trial indices for a set of statuses.
2824+
2825+
Args:
2826+
experiment: The experiment to get trial indices from.
2827+
statuses: Set of statuses to get trial indices for.
2828+
trial_type: If provided, only return indices for trials of this type.
2829+
2830+
Returns:
2831+
Set of trial indices for the given statuses.
2832+
"""
2833+
return {
2834+
i
2835+
for i, t in experiment.trials.items()
2836+
if (t.status in statuses)
2837+
and (
2838+
(trial_type is None)
2839+
or ((trial_type is not None) and (t.trial_type == trial_type))
2840+
)
2841+
}

ax/core/multi_type_experiment.py

Lines changed: 11 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66

77
# pyre-strict
88

9-
from collections.abc import Sequence
109
from typing import Any, Self
1110

1211
from ax.core.arm import Arm
13-
from ax.core.base_trial import BaseTrial, TrialStatus
14-
from ax.core.experiment import Experiment
12+
from ax.core.experiment import (
13+
Experiment,
14+
filter_trials_by_type,
15+
get_trial_indices_for_statuses,
16+
)
1517
from ax.core.metric import Metric
1618
from ax.core.optimization_config import OptimizationConfig
1719
from ax.core.runner import Runner
@@ -171,42 +173,9 @@ def remove_metric(self, metric_name: str) -> Self:
171173
return self
172174

173175

174-
def filter_trials_by_type(
175-
trials: Sequence[BaseTrial], trial_type: str | None
176-
) -> list[BaseTrial]:
177-
"""Filter trials by trial type if provided.
178-
179-
This filters trials by trial type if the experiment is a
180-
MultiTypeExperiment.
181-
182-
Args:
183-
trials: Trials to filter.
184-
185-
Returns:
186-
Filtered trials.
187-
"""
188-
if trial_type is not None:
189-
return [t for t in trials if t.trial_type == trial_type]
190-
return list(trials)
191-
192-
193-
def get_trial_indices_for_statuses(
194-
experiment: Experiment, statuses: set[TrialStatus], trial_type: str | None = None
195-
) -> set[int]:
196-
"""Get trial indices for a set of statuses.
197-
198-
Args:
199-
statuses: Set of statuses to get trial indices for.
200-
201-
Returns:
202-
Set of trial indices for the given statuses.
203-
"""
204-
return {
205-
i
206-
for i, t in experiment.trials.items()
207-
if (t.status in statuses)
208-
and (
209-
(trial_type is None)
210-
or ((trial_type is not None) and (t.trial_type == trial_type))
211-
)
212-
}
176+
# Re-exported from ax.core.experiment for backward compatibility.
177+
__all__ = [
178+
"MultiTypeExperiment",
179+
"filter_trials_by_type",
180+
"get_trial_indices_for_statuses",
181+
]

ax/orchestration/orchestrator.py

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,13 @@
2121
from ax.adapter.adapter_utils import get_fixed_features_from_experiment
2222
from ax.adapter.base import Adapter
2323
from ax.core.base_trial import BaseTrial
24-
from ax.core.experiment import Experiment
25-
from ax.core.generator_run import GeneratorRun
26-
from ax.core.metric import Metric, MetricFetchE, MetricFetchResult
27-
from ax.core.multi_type_experiment import (
24+
from ax.core.experiment import (
25+
Experiment,
2826
filter_trials_by_type,
2927
get_trial_indices_for_statuses,
30-
MultiTypeExperiment,
3128
)
29+
from ax.core.generator_run import GeneratorRun
30+
from ax.core.metric import Metric, MetricFetchE, MetricFetchResult
3231
from ax.core.runner import Runner
3332
from ax.core.trial import Trial
3433
from ax.core.trial_status import TrialStatus
@@ -58,7 +57,7 @@
5857
set_ax_logger_levels,
5958
)
6059
from ax.utils.common.timeutils import current_timestamp_in_millis
61-
from pyre_extensions import assert_is_instance, none_throws
60+
from pyre_extensions import none_throws
6261

6362

6463
NOT_IMPLEMENTED_IN_BASE_CLASS_MSG = """ \
@@ -367,21 +366,24 @@ def options(self, options: OrchestratorOptions) -> None:
367366
def trial_type(self) -> str | None:
368367
"""Trial type for the experiment this Orchestrator is running.
369368
370-
This returns None if the experiment is not a MultitypeExperiment
369+
Returns ``None`` for single-type experiments (where
370+
``default_trial_type`` is ``None``). For multi-type experiments,
371+
``_validate_options`` guarantees that ``mt_experiment_trial_type`` is
372+
set, so the returned value is always a valid ``str`` in that case.
371373
372374
Returns:
373375
Trial type for the experiment this Orchestrator is running if the
374-
experiment is a MultiTypeExperiment and None otherwise.
376+
experiment has multiple trial types and None otherwise.
375377
"""
376-
if isinstance(self.experiment, MultiTypeExperiment):
378+
if self.experiment.default_trial_type is not None:
377379
return self.options.mt_experiment_trial_type
378380
return None
379381

380382
@property
381383
def running_trials(self) -> list[BaseTrial]:
382384
"""Currently running trials.
383385
384-
Note: if the experiment is a MultiTypeExperiment, then this will
386+
Note: if the experiment has multiple trial types, then this will
385387
only fetch trials of type `orchestrator.trial_type`.
386388
387389
@@ -397,7 +399,7 @@ def running_trials(self) -> list[BaseTrial]:
397399
def trials(self) -> list[BaseTrial]:
398400
"""All trials.
399401
400-
Note: if the experiment is a MultiTypeExperiment, then this will
402+
Note: if the experiment has multiple trial types, then this will
401403
only fetch trials of type `orchestrator.trial_type`.
402404
403405
Returns:
@@ -424,7 +426,7 @@ def running_trial_indices(self) -> set[int]:
424426
def failed_abandoned_trial_indices(self) -> set[int]:
425427
"""Failed or abandoned trials.
426428
427-
Note: if the experiment is a MultiTypeExperiment, then this will
429+
Note: if the experiment has multiple trial types, then this will
428430
only fetch trials of type `orchestrator.trial_type`.
429431
430432
Returns:
@@ -441,7 +443,7 @@ def pending_trials(self) -> list[BaseTrial]:
441443
"""Running or staged trials on the experiment this Orchestrator is
442444
running.
443445
444-
Note: if the experiment is a MultiTypeExperiment, then this will
446+
Note: if the experiment has multiple trial types, then this will
445447
only fetch trials of type `orchestrator.trial_type`.
446448
447449
Returns:
@@ -457,7 +459,7 @@ def pending_trials(self) -> list[BaseTrial]:
457459
def candidate_trials(self) -> list[BaseTrial]:
458460
"""Candidate trials on the experiment this Orchestrator is running.
459461
460-
Note: if the experiment is a MultiTypeExperiment, then this will
462+
Note: if the experiment has multiple trial types, then this will
461463
only fetch trials of type `orchestrator.trial_type`.
462464
463465
Returns:
@@ -472,7 +474,7 @@ def candidate_trials(self) -> list[BaseTrial]:
472474
def trials_expecting_data(self) -> list[BaseTrial]:
473475
"""Trials expecting data.
474476
475-
Note: if the experiment is a MultiTypeExperiment, then this will
477+
Note: if the experiment has multiple trial types, then this will
476478
only fetch trials of type `orchestrator.trial_type`.
477479
"""
478480
trials = []
@@ -488,9 +490,9 @@ def runner(self) -> Runner:
488490
instance.
489491
"""
490492
if self.trial_type is not None:
491-
runner = assert_is_instance(
492-
self.experiment, MultiTypeExperiment
493-
).runner_for_trial_type(trial_type=none_throws(self.trial_type))
493+
runner = self.experiment.runner_for_trial_type(
494+
trial_type=none_throws(self.trial_type)
495+
)
494496
else:
495497
runner = self.experiment.runner
496498
if runner is None:
@@ -1626,10 +1628,11 @@ def _validate_options(self, options: OrchestratorOptions) -> None:
16261628
"will be unable to fetch intermediate results with which to "
16271629
"evaluate early stopping criteria."
16281630
)
1629-
if isinstance(self.experiment, MultiTypeExperiment):
1631+
if self.experiment.default_trial_type is not None:
16301632
if options.mt_experiment_trial_type is None:
16311633
raise UserInputError(
1632-
"Must specify `mt_experiment_trial_type` for MultiTypeExperiment."
1634+
"Must specify `mt_experiment_trial_type` for experiments "
1635+
"with multiple trial types."
16331636
)
16341637
if not self.experiment.supports_trial_type(
16351638
options.mt_experiment_trial_type
@@ -1640,8 +1643,8 @@ def _validate_options(self, options: OrchestratorOptions) -> None:
16401643
)
16411644
elif options.mt_experiment_trial_type is not None:
16421645
raise UserInputError(
1643-
"`mt_experiment_trial_type` must be None unless the experiment is a "
1644-
"MultiTypeExperiment."
1646+
"`mt_experiment_trial_type` must be None unless the experiment "
1647+
"has multiple trial types."
16451648
)
16461649

16471650
def _get_max_pending_trials(self) -> int:
@@ -2040,9 +2043,9 @@ def _fetch_and_process_trials_data_results(
20402043
try:
20412044
kwargs = deepcopy(self.options.fetch_kwargs)
20422045
if self.trial_type is not None:
2043-
metrics = assert_is_instance(
2044-
self.experiment, MultiTypeExperiment
2045-
).metrics_for_trial_type(trial_type=none_throws(self.trial_type))
2046+
metrics = self.experiment.metrics_for_trial_type(
2047+
trial_type=none_throws(self.trial_type)
2048+
)
20462049
kwargs["metrics"] = metrics
20472050
results = self.experiment.fetch_trials_data_results(
20482051
trial_indices=trial_indices,

ax/orchestration/tests/test_orchestrator.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2798,8 +2798,8 @@ def test_validate_options_not_none_mt_trial_type(
27982798
# compatible with the type of experiment (single or multi-type)
27992799
if msg is None:
28002800
msg = (
2801-
"`mt_experiment_trial_type` must be None unless the experiment is a "
2802-
"MultiTypeExperiment."
2801+
"`mt_experiment_trial_type` must be None unless the experiment "
2802+
"has multiple trial types."
28032803
)
28042804
options = OrchestratorOptions(
28052805
init_seconds_between_polls=0, # No wait bw polls so test is fast.
@@ -3084,7 +3084,8 @@ def test_validate_options_not_none_mt_trial_type(
30843084
# test if a MultiTypeExperiment with `mt_experiment_trial_type=None`
30853085
self.orchestrator_options_kwargs["mt_experiment_trial_type"] = None
30863086
super().test_validate_options_not_none_mt_trial_type(
3087-
msg="Must specify `mt_experiment_trial_type` for MultiTypeExperiment."
3087+
msg="Must specify `mt_experiment_trial_type` for experiments "
3088+
"with multiple trial types."
30883089
)
30893090

30903091
def test_run_n_trials_single_step_existing_experiment(

0 commit comments

Comments
 (0)