Skip to content

Commit 27de1d6

Browse files
mpolson64meta-codesync[bot]
authored andcommitted
Replace isinstance(MultiTypeExperiment) checks with feature checks (facebook#5001)
Summary: Pull Request resolved: facebook#5001 Move `filter_trials_by_type` and `get_trial_indices_for_statuses` from `multi_type_experiment.py` to `experiment.py`, with backward-compatible re-exports from the old module. Replace all `isinstance(experiment, MultiTypeExperiment)` checks with `experiment.default_trial_type is not None` feature checks in: - `ax/orchestration/orchestrator.py` - `ax/fb/realtime/utils.py` - `ax/fb/axolotl/utils_chronos.py` Replace `assert_is_instance(experiment, MultiTypeExperiment).method()` calls with direct calls on base `Experiment` (which now has `runner_for_trial_type` and `metrics_for_trial_type`). This is Phase 6 of the MultiTypeExperiment pull-up plan. Differential Revision: D94999020
1 parent 34b9dc5 commit 27de1d6

4 files changed

Lines changed: 84 additions & 70 deletions

File tree

ax/core/experiment.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2739,3 +2739,47 @@ def add_arm_and_prevent_naming_collision(
27392739
stacklevel=2,
27402740
)
27412741
new_trial.add_arm(none_throws(old_trial.arm).clone(clear_name=True))
2742+
2743+
2744+
def filter_trials_by_type(
2745+
trials: Sequence[BaseTrial], trial_type: str | None
2746+
) -> list[BaseTrial]:
2747+
"""Filter trials by trial type if provided.
2748+
2749+
This filters trials by trial type if the experiment has multiple
2750+
trial types.
2751+
2752+
Args:
2753+
trials: Trials to filter.
2754+
trial_type: The trial type to filter by. If None, all trials are returned.
2755+
2756+
Returns:
2757+
Filtered trials.
2758+
"""
2759+
if trial_type is not None:
2760+
return [t for t in trials if t.trial_type == trial_type]
2761+
return list(trials)
2762+
2763+
2764+
def get_trial_indices_for_statuses(
2765+
experiment: Experiment, statuses: set[TrialStatus], trial_type: str | None = None
2766+
) -> set[int]:
2767+
"""Get trial indices for a set of statuses.
2768+
2769+
Args:
2770+
experiment: The experiment to get trial indices from.
2771+
statuses: Set of statuses to get trial indices for.
2772+
trial_type: If provided, only return indices for trials of this type.
2773+
2774+
Returns:
2775+
Set of trial indices for the given statuses.
2776+
"""
2777+
return {
2778+
i
2779+
for i, t in experiment.trials.items()
2780+
if (t.status in statuses)
2781+
and (
2782+
(trial_type is None)
2783+
or ((trial_type is not None) and (t.trial_type == trial_type))
2784+
)
2785+
}

ax/core/multi_type_experiment.py

Lines changed: 11 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66

77
# pyre-strict
88

9-
from collections.abc import Sequence
109
from typing import Any, Self
1110

1211
from ax.core.arm import Arm
13-
from ax.core.base_trial import BaseTrial, TrialStatus
14-
from ax.core.experiment import Experiment
12+
from ax.core.experiment import (
13+
Experiment,
14+
filter_trials_by_type,
15+
get_trial_indices_for_statuses,
16+
)
1517
from ax.core.metric import Metric
1618
from ax.core.optimization_config import OptimizationConfig
1719
from ax.core.runner import Runner
@@ -171,42 +173,9 @@ def remove_metric(self, metric_name: str) -> Self:
171173
return self
172174

173175

174-
def filter_trials_by_type(
175-
trials: Sequence[BaseTrial], trial_type: str | None
176-
) -> list[BaseTrial]:
177-
"""Filter trials by trial type if provided.
178-
179-
This filters trials by trial type if the experiment is a
180-
MultiTypeExperiment.
181-
182-
Args:
183-
trials: Trials to filter.
184-
185-
Returns:
186-
Filtered trials.
187-
"""
188-
if trial_type is not None:
189-
return [t for t in trials if t.trial_type == trial_type]
190-
return list(trials)
191-
192-
193-
def get_trial_indices_for_statuses(
194-
experiment: Experiment, statuses: set[TrialStatus], trial_type: str | None = None
195-
) -> set[int]:
196-
"""Get trial indices for a set of statuses.
197-
198-
Args:
199-
statuses: Set of statuses to get trial indices for.
200-
201-
Returns:
202-
Set of trial indices for the given statuses.
203-
"""
204-
return {
205-
i
206-
for i, t in experiment.trials.items()
207-
if (t.status in statuses)
208-
and (
209-
(trial_type is None)
210-
or ((trial_type is not None) and (t.trial_type == trial_type))
211-
)
212-
}
176+
# Re-exported from ax.core.experiment for backward compatibility.
177+
__all__ = [
178+
"MultiTypeExperiment",
179+
"filter_trials_by_type",
180+
"get_trial_indices_for_statuses",
181+
]

ax/orchestration/orchestrator.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,13 @@
2121
from ax.adapter.adapter_utils import get_fixed_features_from_experiment
2222
from ax.adapter.base import Adapter
2323
from ax.core.base_trial import BaseTrial
24-
from ax.core.experiment import Experiment
25-
from ax.core.generator_run import GeneratorRun
26-
from ax.core.metric import Metric, MetricFetchE, MetricFetchResult
27-
from ax.core.multi_type_experiment import (
24+
from ax.core.experiment import (
25+
Experiment,
2826
filter_trials_by_type,
2927
get_trial_indices_for_statuses,
30-
MultiTypeExperiment,
3128
)
29+
from ax.core.generator_run import GeneratorRun
30+
from ax.core.metric import Metric, MetricFetchE, MetricFetchResult
3231
from ax.core.runner import Runner
3332
from ax.core.trial import Trial
3433
from ax.core.trial_status import TrialStatus
@@ -58,7 +57,7 @@
5857
set_ax_logger_levels,
5958
)
6059
from ax.utils.common.timeutils import current_timestamp_in_millis
61-
from pyre_extensions import assert_is_instance, none_throws
60+
from pyre_extensions import none_throws
6261

6362

6463
NOT_IMPLEMENTED_IN_BASE_CLASS_MSG = """ \
@@ -367,21 +366,21 @@ def options(self, options: OrchestratorOptions) -> None:
367366
def trial_type(self) -> str | None:
368367
"""Trial type for the experiment this Orchestrator is running.
369368
370-
This returns None if the experiment is not a MultitypeExperiment
369+
This returns None if the experiment does not have multiple trial types.
371370
372371
Returns:
373372
Trial type for the experiment this Orchestrator is running if the
374-
experiment is a MultiTypeExperiment and None otherwise.
373+
experiment has multiple trial types and None otherwise.
375374
"""
376-
if isinstance(self.experiment, MultiTypeExperiment):
375+
if self.experiment.default_trial_type is not None:
377376
return self.options.mt_experiment_trial_type
378377
return None
379378

380379
@property
381380
def running_trials(self) -> list[BaseTrial]:
382381
"""Currently running trials.
383382
384-
Note: if the experiment is a MultiTypeExperiment, then this will
383+
Note: if the experiment has multiple trial types, then this will
385384
only fetch trials of type `orchestrator.trial_type`.
386385
387386
@@ -397,7 +396,7 @@ def running_trials(self) -> list[BaseTrial]:
397396
def trials(self) -> list[BaseTrial]:
398397
"""All trials.
399398
400-
Note: if the experiment is a MultiTypeExperiment, then this will
399+
Note: if the experiment has multiple trial types, then this will
401400
only fetch trials of type `orchestrator.trial_type`.
402401
403402
Returns:
@@ -424,7 +423,7 @@ def running_trial_indices(self) -> set[int]:
424423
def failed_abandoned_trial_indices(self) -> set[int]:
425424
"""Failed or abandoned trials.
426425
427-
Note: if the experiment is a MultiTypeExperiment, then this will
426+
Note: if the experiment has multiple trial types, then this will
428427
only fetch trials of type `orchestrator.trial_type`.
429428
430429
Returns:
@@ -441,7 +440,7 @@ def pending_trials(self) -> list[BaseTrial]:
441440
"""Running or staged trials on the experiment this Orchestrator is
442441
running.
443442
444-
Note: if the experiment is a MultiTypeExperiment, then this will
443+
Note: if the experiment has multiple trial types, then this will
445444
only fetch trials of type `orchestrator.trial_type`.
446445
447446
Returns:
@@ -457,7 +456,7 @@ def pending_trials(self) -> list[BaseTrial]:
457456
def candidate_trials(self) -> list[BaseTrial]:
458457
"""Candidate trials on the experiment this Orchestrator is running.
459458
460-
Note: if the experiment is a MultiTypeExperiment, then this will
459+
Note: if the experiment has multiple trial types, then this will
461460
only fetch trials of type `orchestrator.trial_type`.
462461
463462
Returns:
@@ -472,7 +471,7 @@ def candidate_trials(self) -> list[BaseTrial]:
472471
def trials_expecting_data(self) -> list[BaseTrial]:
473472
"""Trials expecting data.
474473
475-
Note: if the experiment is a MultiTypeExperiment, then this will
474+
Note: if the experiment has multiple trial types, then this will
476475
only fetch trials of type `orchestrator.trial_type`.
477476
"""
478477
trials = []
@@ -488,9 +487,9 @@ def runner(self) -> Runner:
488487
instance.
489488
"""
490489
if self.trial_type is not None:
491-
runner = assert_is_instance(
492-
self.experiment, MultiTypeExperiment
493-
).runner_for_trial_type(trial_type=none_throws(self.trial_type))
490+
runner = self.experiment.runner_for_trial_type(
491+
trial_type=none_throws(self.trial_type)
492+
)
494493
else:
495494
runner = self.experiment.runner
496495
if runner is None:
@@ -1626,10 +1625,11 @@ def _validate_options(self, options: OrchestratorOptions) -> None:
16261625
"will be unable to fetch intermediate results with which to "
16271626
"evaluate early stopping criteria."
16281627
)
1629-
if isinstance(self.experiment, MultiTypeExperiment):
1628+
if self.experiment.default_trial_type is not None:
16301629
if options.mt_experiment_trial_type is None:
16311630
raise UserInputError(
1632-
"Must specify `mt_experiment_trial_type` for MultiTypeExperiment."
1631+
"Must specify `mt_experiment_trial_type` for experiments "
1632+
"with multiple trial types."
16331633
)
16341634
if not self.experiment.supports_trial_type(
16351635
options.mt_experiment_trial_type
@@ -1640,8 +1640,8 @@ def _validate_options(self, options: OrchestratorOptions) -> None:
16401640
)
16411641
elif options.mt_experiment_trial_type is not None:
16421642
raise UserInputError(
1643-
"`mt_experiment_trial_type` must be None unless the experiment is a "
1644-
"MultiTypeExperiment."
1643+
"`mt_experiment_trial_type` must be None unless the experiment "
1644+
"has multiple trial types."
16451645
)
16461646

16471647
def _get_max_pending_trials(self) -> int:
@@ -2040,9 +2040,9 @@ def _fetch_and_process_trials_data_results(
20402040
try:
20412041
kwargs = deepcopy(self.options.fetch_kwargs)
20422042
if self.trial_type is not None:
2043-
metrics = assert_is_instance(
2044-
self.experiment, MultiTypeExperiment
2045-
).metrics_for_trial_type(trial_type=none_throws(self.trial_type))
2043+
metrics = self.experiment.metrics_for_trial_type(
2044+
trial_type=none_throws(self.trial_type)
2045+
)
20462046
kwargs["metrics"] = metrics
20472047
results = self.experiment.fetch_trials_data_results(
20482048
trial_indices=trial_indices,

ax/orchestration/tests/test_orchestrator.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2796,8 +2796,8 @@ def test_validate_options_not_none_mt_trial_type(
27962796
# compatible with the type of experiment (single or multi-type)
27972797
if msg is None:
27982798
msg = (
2799-
"`mt_experiment_trial_type` must be None unless the experiment is a "
2800-
"MultiTypeExperiment."
2799+
"`mt_experiment_trial_type` must be None unless the experiment "
2800+
"has multiple trial types."
28012801
)
28022802
options = OrchestratorOptions(
28032803
init_seconds_between_polls=0, # No wait bw polls so test is fast.
@@ -3080,7 +3080,8 @@ def test_validate_options_not_none_mt_trial_type(
30803080
# test if a MultiTypeExperiment with `mt_experiment_trial_type=None`
30813081
self.orchestrator_options_kwargs["mt_experiment_trial_type"] = None
30823082
super().test_validate_options_not_none_mt_trial_type(
3083-
msg="Must specify `mt_experiment_trial_type` for MultiTypeExperiment."
3083+
msg="Must specify `mt_experiment_trial_type` for experiments "
3084+
"with multiple trial types."
30843085
)
30853086

30863087
def test_run_n_trials_single_step_existing_experiment(

0 commit comments

Comments
 (0)