|
6 | 6 | # pyre-strict |
7 | 7 |
|
8 | 8 | from dataclasses import dataclass |
| 9 | +from typing import Any |
| 10 | + |
| 11 | +from ax.adapter.adapter_utils import can_map_to_binary, is_unordered_choice |
| 12 | +from ax.core.experiment import Experiment |
| 13 | +from ax.core.objective import MultiObjective |
| 14 | +from ax.exceptions.core import OptimizationNotConfiguredError |
| 15 | +from ax.service.orchestrator import OrchestratorOptions |
9 | 16 |
|
10 | 17 |
|
11 | 18 | @dataclass(frozen=True) |
@@ -56,3 +63,75 @@ class OptimizationSummary: |
56 | 63 | min_failed_trials_for_failure_rate_check: int | None = None |
57 | 64 | non_default_advanced_options: bool | None = None |
58 | 65 | uses_merge_multiple_curves: bool | None = None |
| 66 | + |
| 67 | + |
| 68 | +def summarize_experiment_wheelhouse( |
| 69 | + experiment: Experiment, |
| 70 | + options: OrchestratorOptions, |
| 71 | + tier_metadata: dict[str, Any], |
| 72 | +) -> dict[str, Any]: |
| 73 | + """Summarize the Ax experiment. |
| 74 | +
|
| 75 | + Args: |
| 76 | + experiment: The Ax Experiment. |
| 77 | + options: The orchestrator options. |
| 78 | + tier_metadata: tier-related meta-data from the orchestrator. |
| 79 | +
|
| 80 | + Returns: |
| 81 | + A dictionary summarizing the experiment. |
| 82 | + """ |
| 83 | + search_space = experiment.search_space |
| 84 | + optimization_config = experiment.optimization_config |
| 85 | + if optimization_config is None: |
| 86 | + raise OptimizationNotConfiguredError( |
| 87 | + "Experiment must have an optimization_config." |
| 88 | + ) |
| 89 | + params = search_space.tunable_parameters.values() |
| 90 | + |
| 91 | + max_trials = tier_metadata.get("user_supplied_max_trials", None) |
| 92 | + num_params = len(search_space.tunable_parameters) |
| 93 | + num_binary = sum(can_map_to_binary(p) for p in params) |
| 94 | + num_categorical_3_5 = sum( |
| 95 | + is_unordered_choice(p, min_choices=3, max_choices=5) for p in params |
| 96 | + ) |
| 97 | + num_categorical_6_inf = sum(is_unordered_choice(p, min_choices=6) for p in params) |
| 98 | + num_parameter_constraints = len(search_space.parameter_constraints) |
| 99 | + num_objectives = ( |
| 100 | + len(optimization_config.objective.objectives) |
| 101 | + if isinstance(optimization_config.objective, MultiObjective) |
| 102 | + else 1 |
| 103 | + ) |
| 104 | + num_outcome_constraints = len(optimization_config.outcome_constraints) |
| 105 | + uses_early_stopping = options.early_stopping_strategy is not None |
| 106 | + uses_global_stopping = options.global_stopping_strategy is not None |
| 107 | + |
| 108 | + # Check if any metrics use merge_multiple_curves |
| 109 | + uses_merge_multiple_curves = False |
| 110 | + all_metrics = list(optimization_config.metrics.values()) |
| 111 | + if hasattr(experiment, "tracking_metrics"): |
| 112 | + all_metrics.extend(experiment.tracking_metrics) |
| 113 | + |
| 114 | + for metric in all_metrics: |
| 115 | + if getattr(metric, "merge_multiple_curves", False): |
| 116 | + uses_merge_multiple_curves = True |
| 117 | + break |
| 118 | + |
| 119 | + return { |
| 120 | + "max_trials": max_trials, |
| 121 | + "num_params": num_params, |
| 122 | + "num_binary": num_binary, |
| 123 | + "num_categorical_3_5": num_categorical_3_5, |
| 124 | + "num_categorical_6_inf": num_categorical_6_inf, |
| 125 | + "num_parameter_constraints": num_parameter_constraints, |
| 126 | + "num_objectives": num_objectives, |
| 127 | + "num_outcome_constraints": num_outcome_constraints, |
| 128 | + "uses_early_stopping": uses_early_stopping, |
| 129 | + "uses_global_stopping": uses_global_stopping, |
| 130 | + "uses_merge_multiple_curves": uses_merge_multiple_curves, |
| 131 | + "all_inputs_are_configs": tier_metadata.get("all_inputs_are_configs", False), |
| 132 | + "tolerated_trial_failure_rate": options.tolerated_trial_failure_rate, |
| 133 | + "max_pending_trials": options.max_pending_trials, |
| 134 | + "min_failed_trials_for_failure_rate_check": ( |
| 135 | + options.min_failed_trials_for_failure_rate_check |
| 136 | + ), |
| 137 | + } |
0 commit comments