Skip to content

Commit f56e855

Browse files
mpolson64meta-codesync[bot]
authored andcommitted
Replace isinstance(MultiTypeExperiment) checks with feature checks (#5001)
Summary: Pull Request resolved: #5001 Move `filter_trials_by_type` and `get_trial_indices_for_statuses` from `multi_type_experiment.py` to `experiment.py`, with backward-compatible re-exports from the old module. Replace all `isinstance(experiment, MultiTypeExperiment)` checks with `experiment.default_trial_type is not None` feature checks in: - `ax/orchestration/orchestrator.py` - `ax/fb/realtime/utils.py` - `ax/fb/axolotl/utils_chronos.py` Replace `assert_is_instance(experiment, MultiTypeExperiment).method()` calls with direct calls on base `Experiment` (which now has `runner_for_trial_type` and `metrics_for_trial_type`). This is Phase 6 of the MultiTypeExperiment pull-up plan. Differential Revision: D94999020
1 parent 8baa422 commit f56e855

4 files changed

Lines changed: 84 additions & 70 deletions

File tree

ax/core/experiment.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2736,3 +2736,47 @@ def add_arm_and_prevent_naming_collision(
27362736
stacklevel=2,
27372737
)
27382738
new_trial.add_arm(none_throws(old_trial.arm).clone(clear_name=True))
2739+
2740+
2741+
def filter_trials_by_type(
2742+
trials: Sequence[BaseTrial], trial_type: str | None
2743+
) -> list[BaseTrial]:
2744+
"""Filter trials by trial type if provided.
2745+
2746+
This filters trials by trial type if the experiment has multiple
2747+
trial types.
2748+
2749+
Args:
2750+
trials: Trials to filter.
2751+
trial_type: The trial type to filter by. If None, all trials are returned.
2752+
2753+
Returns:
2754+
Filtered trials.
2755+
"""
2756+
if trial_type is not None:
2757+
return [t for t in trials if t.trial_type == trial_type]
2758+
return list(trials)
2759+
2760+
2761+
def get_trial_indices_for_statuses(
2762+
experiment: Experiment, statuses: set[TrialStatus], trial_type: str | None = None
2763+
) -> set[int]:
2764+
"""Get trial indices for a set of statuses.
2765+
2766+
Args:
2767+
experiment: The experiment to get trial indices from.
2768+
statuses: Set of statuses to get trial indices for.
2769+
trial_type: If provided, only return indices for trials of this type.
2770+
2771+
Returns:
2772+
Set of trial indices for the given statuses.
2773+
"""
2774+
return {
2775+
i
2776+
for i, t in experiment.trials.items()
2777+
if (t.status in statuses)
2778+
and (
2779+
(trial_type is None)
2780+
or ((trial_type is not None) and (t.trial_type == trial_type))
2781+
)
2782+
}

ax/core/multi_type_experiment.py

Lines changed: 11 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66

77
# pyre-strict
88

9-
from collections.abc import Sequence
109
from typing import Any, Self
1110

1211
from ax.core.arm import Arm
13-
from ax.core.base_trial import BaseTrial, TrialStatus
14-
from ax.core.experiment import Experiment
12+
from ax.core.experiment import (
13+
Experiment,
14+
filter_trials_by_type,
15+
get_trial_indices_for_statuses,
16+
)
1517
from ax.core.metric import Metric
1618
from ax.core.optimization_config import OptimizationConfig
1719
from ax.core.runner import Runner
@@ -171,42 +173,9 @@ def remove_metric(self, metric_name: str) -> Self:
171173
return self
172174

173175

174-
def filter_trials_by_type(
175-
trials: Sequence[BaseTrial], trial_type: str | None
176-
) -> list[BaseTrial]:
177-
"""Filter trials by trial type if provided.
178-
179-
This filters trials by trial type if the experiment is a
180-
MultiTypeExperiment.
181-
182-
Args:
183-
trials: Trials to filter.
184-
185-
Returns:
186-
Filtered trials.
187-
"""
188-
if trial_type is not None:
189-
return [t for t in trials if t.trial_type == trial_type]
190-
return list(trials)
191-
192-
193-
def get_trial_indices_for_statuses(
194-
experiment: Experiment, statuses: set[TrialStatus], trial_type: str | None = None
195-
) -> set[int]:
196-
"""Get trial indices for a set of statuses.
197-
198-
Args:
199-
statuses: Set of statuses to get trial indices for.
200-
201-
Returns:
202-
Set of trial indices for the given statuses.
203-
"""
204-
return {
205-
i
206-
for i, t in experiment.trials.items()
207-
if (t.status in statuses)
208-
and (
209-
(trial_type is None)
210-
or ((trial_type is not None) and (t.trial_type == trial_type))
211-
)
212-
}
176+
# Re-exported from ax.core.experiment for backward compatibility.
177+
__all__ = [
178+
"MultiTypeExperiment",
179+
"filter_trials_by_type",
180+
"get_trial_indices_for_statuses",
181+
]

ax/orchestration/orchestrator.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,13 @@
2121
from ax.adapter.adapter_utils import get_fixed_features_from_experiment
2222
from ax.adapter.base import Adapter
2323
from ax.core.base_trial import BaseTrial
24-
from ax.core.experiment import Experiment
25-
from ax.core.generator_run import GeneratorRun
26-
from ax.core.metric import Metric, MetricFetchE, MetricFetchResult
27-
from ax.core.multi_type_experiment import (
24+
from ax.core.experiment import (
25+
Experiment,
2826
filter_trials_by_type,
2927
get_trial_indices_for_statuses,
30-
MultiTypeExperiment,
3128
)
29+
from ax.core.generator_run import GeneratorRun
30+
from ax.core.metric import Metric, MetricFetchE, MetricFetchResult
3231
from ax.core.runner import Runner
3332
from ax.core.trial import Trial
3433
from ax.core.trial_status import TrialStatus
@@ -58,7 +57,7 @@
5857
set_ax_logger_levels,
5958
)
6059
from ax.utils.common.timeutils import current_timestamp_in_millis
61-
from pyre_extensions import assert_is_instance, none_throws
60+
from pyre_extensions import none_throws
6261

6362

6463
NOT_IMPLEMENTED_IN_BASE_CLASS_MSG = """ \
@@ -367,21 +366,21 @@ def options(self, options: OrchestratorOptions) -> None:
367366
def trial_type(self) -> str | None:
368367
"""Trial type for the experiment this Orchestrator is running.
369368
370-
This returns None if the experiment is not a MultitypeExperiment
369+
This returns None if the experiment does not have multiple trial types.
371370
372371
Returns:
373372
Trial type for the experiment this Orchestrator is running if the
374-
experiment is a MultiTypeExperiment and None otherwise.
373+
experiment has multiple trial types and None otherwise.
375374
"""
376-
if isinstance(self.experiment, MultiTypeExperiment):
375+
if self.experiment.default_trial_type is not None:
377376
return self.options.mt_experiment_trial_type
378377
return None
379378

380379
@property
381380
def running_trials(self) -> list[BaseTrial]:
382381
"""Currently running trials.
383382
384-
Note: if the experiment is a MultiTypeExperiment, then this will
383+
Note: if the experiment has multiple trial types, then this will
385384
only fetch trials of type `orchestrator.trial_type`.
386385
387386
@@ -397,7 +396,7 @@ def running_trials(self) -> list[BaseTrial]:
397396
def trials(self) -> list[BaseTrial]:
398397
"""All trials.
399398
400-
Note: if the experiment is a MultiTypeExperiment, then this will
399+
Note: if the experiment has multiple trial types, then this will
401400
only fetch trials of type `orchestrator.trial_type`.
402401
403402
Returns:
@@ -424,7 +423,7 @@ def running_trial_indices(self) -> set[int]:
424423
def failed_abandoned_trial_indices(self) -> set[int]:
425424
"""Failed or abandoned trials.
426425
427-
Note: if the experiment is a MultiTypeExperiment, then this will
426+
Note: if the experiment has multiple trial types, then this will
428427
only fetch trials of type `orchestrator.trial_type`.
429428
430429
Returns:
@@ -441,7 +440,7 @@ def pending_trials(self) -> list[BaseTrial]:
441440
"""Running or staged trials on the experiment this Orchestrator is
442441
running.
443442
444-
Note: if the experiment is a MultiTypeExperiment, then this will
443+
Note: if the experiment has multiple trial types, then this will
445444
only fetch trials of type `orchestrator.trial_type`.
446445
447446
Returns:
@@ -457,7 +456,7 @@ def pending_trials(self) -> list[BaseTrial]:
457456
def candidate_trials(self) -> list[BaseTrial]:
458457
"""Candidate trials on the experiment this Orchestrator is running.
459458
460-
Note: if the experiment is a MultiTypeExperiment, then this will
459+
Note: if the experiment has multiple trial types, then this will
461460
only fetch trials of type `orchestrator.trial_type`.
462461
463462
Returns:
@@ -472,7 +471,7 @@ def candidate_trials(self) -> list[BaseTrial]:
472471
def trials_expecting_data(self) -> list[BaseTrial]:
473472
"""Trials expecting data.
474473
475-
Note: if the experiment is a MultiTypeExperiment, then this will
474+
Note: if the experiment has multiple trial types, then this will
476475
only fetch trials of type `orchestrator.trial_type`.
477476
"""
478477
trials = []
@@ -488,9 +487,9 @@ def runner(self) -> Runner:
488487
instance.
489488
"""
490489
if self.trial_type is not None:
491-
runner = assert_is_instance(
492-
self.experiment, MultiTypeExperiment
493-
).runner_for_trial_type(trial_type=none_throws(self.trial_type))
490+
runner = self.experiment.runner_for_trial_type(
491+
trial_type=none_throws(self.trial_type)
492+
)
494493
else:
495494
runner = self.experiment.runner
496495
if runner is None:
@@ -1626,10 +1625,11 @@ def _validate_options(self, options: OrchestratorOptions) -> None:
16261625
"will be unable to fetch intermediate results with which to "
16271626
"evaluate early stopping criteria."
16281627
)
1629-
if isinstance(self.experiment, MultiTypeExperiment):
1628+
if self.experiment.default_trial_type is not None:
16301629
if options.mt_experiment_trial_type is None:
16311630
raise UserInputError(
1632-
"Must specify `mt_experiment_trial_type` for MultiTypeExperiment."
1631+
"Must specify `mt_experiment_trial_type` for experiments "
1632+
"with multiple trial types."
16331633
)
16341634
if not self.experiment.supports_trial_type(
16351635
options.mt_experiment_trial_type
@@ -1640,8 +1640,8 @@ def _validate_options(self, options: OrchestratorOptions) -> None:
16401640
)
16411641
elif options.mt_experiment_trial_type is not None:
16421642
raise UserInputError(
1643-
"`mt_experiment_trial_type` must be None unless the experiment is a "
1644-
"MultiTypeExperiment."
1643+
"`mt_experiment_trial_type` must be None unless the experiment "
1644+
"has multiple trial types."
16451645
)
16461646

16471647
def _get_max_pending_trials(self) -> int:
@@ -2040,9 +2040,9 @@ def _fetch_and_process_trials_data_results(
20402040
try:
20412041
kwargs = deepcopy(self.options.fetch_kwargs)
20422042
if self.trial_type is not None:
2043-
metrics = assert_is_instance(
2044-
self.experiment, MultiTypeExperiment
2045-
).metrics_for_trial_type(trial_type=none_throws(self.trial_type))
2043+
metrics = self.experiment.metrics_for_trial_type(
2044+
trial_type=none_throws(self.trial_type)
2045+
)
20462046
kwargs["metrics"] = metrics
20472047
results = self.experiment.fetch_trials_data_results(
20482048
trial_indices=trial_indices,

ax/orchestration/tests/test_orchestrator.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2797,8 +2797,8 @@ def test_validate_options_not_none_mt_trial_type(
27972797
# compatible with the type of experiment (single or multi-type)
27982798
if msg is None:
27992799
msg = (
2800-
"`mt_experiment_trial_type` must be None unless the experiment is a "
2801-
"MultiTypeExperiment."
2800+
"`mt_experiment_trial_type` must be None unless the experiment "
2801+
"has multiple trial types."
28022802
)
28032803
options = OrchestratorOptions(
28042804
init_seconds_between_polls=0, # No wait bw polls so test is fast.
@@ -3081,7 +3081,8 @@ def test_validate_options_not_none_mt_trial_type(
30813081
# test if a MultiTypeExperiment with `mt_experiment_trial_type=None`
30823082
self.orchestrator_options_kwargs["mt_experiment_trial_type"] = None
30833083
super().test_validate_options_not_none_mt_trial_type(
3084-
msg="Must specify `mt_experiment_trial_type` for MultiTypeExperiment."
3084+
msg="Must specify `mt_experiment_trial_type` for experiments "
3085+
"with multiple trial types."
30853086
)
30863087

30873088
def test_run_n_trials_single_step_existing_experiment(

0 commit comments

Comments
 (0)