Skip to content

Commit 29e6da9

Browse files
shrutipatel31facebook-github-bot
authored andcommitted
(2/6) Port helpers to OSS for the new Complexity Rating Healthcheck - summarize_ax_optimization_complexity (facebook#4649)
Summary: Pull Request resolved: facebook#4649 Reviewed By: bernardbeckerman Differential Revision: D88894871
1 parent 89b35c7 commit 29e6da9

2 files changed

Lines changed: 251 additions & 0 deletions

File tree

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-strict
7+
8+
from typing import Any
9+
10+
from ax.adapter.adapter_utils import can_map_to_binary, is_unordered_choice
11+
from ax.core.experiment import Experiment
12+
from ax.core.objective import MultiObjective
13+
from ax.exceptions.core import OptimizationNotConfiguredError
14+
from ax.service.orchestrator import OrchestratorOptions
15+
16+
17+
def summarize_ax_optimization_complexity(
18+
experiment: Experiment,
19+
options: OrchestratorOptions,
20+
tier_metadata: dict[str, Any],
21+
) -> dict[str, Any]:
22+
"""Summarize the experiment's optimization complexity.
23+
24+
This function analyzes an experiment's configuration and returns metrics and key
25+
characteristics that help assess the difficulty of the optimization problem.
26+
27+
Args:
28+
experiment: The Ax Experiment.
29+
options: The orchestrator options.
30+
tier_metadata: tier-related meta-data from the orchestrator.
31+
32+
Returns:
33+
A dictionary summarizing the experiment.
34+
"""
35+
search_space = experiment.search_space
36+
optimization_config = experiment.optimization_config
37+
if optimization_config is None:
38+
raise OptimizationNotConfiguredError(
39+
"Experiment must have an optimization_config."
40+
)
41+
params = search_space.tunable_parameters.values()
42+
43+
max_trials = tier_metadata.get("user_supplied_max_trials", None)
44+
num_params = len(search_space.tunable_parameters)
45+
num_binary = sum(can_map_to_binary(p) for p in params)
46+
num_categorical_3_5 = sum(
47+
is_unordered_choice(p, min_choices=3, max_choices=5) for p in params
48+
)
49+
num_categorical_6_inf = sum(is_unordered_choice(p, min_choices=6) for p in params)
50+
num_parameter_constraints = len(search_space.parameter_constraints)
51+
num_objectives = (
52+
len(optimization_config.objective.objectives)
53+
if isinstance(optimization_config.objective, MultiObjective)
54+
else 1
55+
)
56+
num_outcome_constraints = len(optimization_config.outcome_constraints)
57+
uses_early_stopping = options.early_stopping_strategy is not None
58+
uses_global_stopping = options.global_stopping_strategy is not None
59+
60+
# Check if any metrics use merge_multiple_curves
61+
uses_merge_multiple_curves = False
62+
all_metrics = list(optimization_config.metrics.values())
63+
if hasattr(experiment, "tracking_metrics"):
64+
all_metrics.extend(experiment.tracking_metrics)
65+
66+
for metric in all_metrics:
67+
if getattr(metric, "merge_multiple_curves", False):
68+
uses_merge_multiple_curves = True
69+
break
70+
71+
return {
72+
"max_trials": max_trials,
73+
"num_params": num_params,
74+
"num_binary": num_binary,
75+
"num_categorical_3_5": num_categorical_3_5,
76+
"num_categorical_6_inf": num_categorical_6_inf,
77+
"num_parameter_constraints": num_parameter_constraints,
78+
"num_objectives": num_objectives,
79+
"num_outcome_constraints": num_outcome_constraints,
80+
"uses_early_stopping": uses_early_stopping,
81+
"uses_global_stopping": uses_global_stopping,
82+
"uses_merge_multiple_curves": uses_merge_multiple_curves,
83+
"all_inputs_are_configs": tier_metadata.get("all_inputs_are_configs", False),
84+
"tolerated_trial_failure_rate": options.tolerated_trial_failure_rate,
85+
"max_pending_trials": options.max_pending_trials,
86+
"min_failed_trials_for_failure_rate_check": (
87+
options.min_failed_trials_for_failure_rate_check
88+
),
89+
}
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
from ax.exceptions.core import OptimizationNotConfiguredError
10+
from ax.service.orchestrator import OrchestratorOptions
11+
from ax.utils.common.complexity_utils import summarize_ax_optimization_complexity
12+
from ax.utils.common.testutils import TestCase
13+
from ax.utils.testing.core_stubs import (
14+
get_experiment,
15+
get_experiment_with_multi_objective,
16+
)
17+
18+
19+
class TestSummarizeAxOptimizationComplexity(TestCase):
20+
"""Tests for the summarize_ax_optimization_complexity function."""
21+
22+
def setUp(self) -> None:
23+
super().setUp()
24+
self.experiment = get_experiment()
25+
self.options = OrchestratorOptions()
26+
self.tier_metadata: dict[str, object] = {}
27+
28+
def test_basic_experiment_summary(self) -> None:
29+
# GIVEN a basic experiment with single objective (from setUp)
30+
31+
# WHEN we summarize the experiment
32+
summary = summarize_ax_optimization_complexity(
33+
experiment=self.experiment,
34+
options=self.options,
35+
tier_metadata=self.tier_metadata,
36+
)
37+
38+
# THEN the summary should contain all expected keys with correct values
39+
expected_keys = [
40+
"max_trials",
41+
"num_params",
42+
"num_binary",
43+
"num_categorical_3_5",
44+
"num_categorical_6_inf",
45+
"num_parameter_constraints",
46+
"num_objectives",
47+
"num_outcome_constraints",
48+
"uses_early_stopping",
49+
"uses_global_stopping",
50+
"uses_merge_multiple_curves",
51+
"all_inputs_are_configs",
52+
"tolerated_trial_failure_rate",
53+
"max_pending_trials",
54+
"min_failed_trials_for_failure_rate_check",
55+
]
56+
for key in expected_keys:
57+
self.assertIn(key, summary)
58+
59+
# Validate specific values for single-objective experiment
60+
self.assertEqual(summary["num_objectives"], 1)
61+
self.assertFalse(summary["uses_early_stopping"])
62+
self.assertFalse(summary["uses_global_stopping"])
63+
64+
def test_multi_objective_experiment(self) -> None:
65+
# GIVEN a multi-objective experiment
66+
experiment = get_experiment_with_multi_objective()
67+
68+
# WHEN we summarize the experiment
69+
summary = summarize_ax_optimization_complexity(
70+
experiment=experiment,
71+
options=self.options,
72+
tier_metadata=self.tier_metadata,
73+
)
74+
75+
# THEN num_objectives should be greater than 1
76+
self.assertGreater(summary["num_objectives"], 1)
77+
78+
def test_experiment_without_optimization_config_raises(self) -> None:
79+
# GIVEN an experiment without optimization config
80+
self.experiment._optimization_config = None
81+
82+
# WHEN/THEN summarizing should raise OptimizationNotConfiguredError
83+
with self.assertRaisesRegex(
84+
OptimizationNotConfiguredError,
85+
"Experiment must have an optimization_config",
86+
):
87+
summarize_ax_optimization_complexity(
88+
experiment=self.experiment,
89+
options=self.options,
90+
tier_metadata=self.tier_metadata,
91+
)
92+
93+
def test_tier_metadata_extraction(self) -> None:
94+
# Test that tier_metadata values are correctly extracted
95+
test_cases = [
96+
(
97+
"with_values",
98+
{"user_supplied_max_trials": 50, "all_inputs_are_configs": True},
99+
50,
100+
True,
101+
),
102+
(
103+
"empty_defaults",
104+
{},
105+
None,
106+
False,
107+
),
108+
]
109+
110+
for (
111+
name,
112+
tier_metadata,
113+
expected_max_trials,
114+
expected_all_configs,
115+
) in test_cases:
116+
with self.subTest(name=name):
117+
# WHEN we summarize the experiment
118+
summary = summarize_ax_optimization_complexity(
119+
experiment=self.experiment,
120+
options=self.options,
121+
tier_metadata=tier_metadata,
122+
)
123+
124+
# THEN the summary should reflect tier metadata values
125+
self.assertEqual(summary["max_trials"], expected_max_trials)
126+
self.assertEqual(
127+
summary["all_inputs_are_configs"], expected_all_configs
128+
)
129+
130+
def test_orchestrator_options_extraction(self) -> None:
131+
# GIVEN custom orchestrator options
132+
options = OrchestratorOptions(
133+
tolerated_trial_failure_rate=0.25,
134+
max_pending_trials=5,
135+
min_failed_trials_for_failure_rate_check=10,
136+
)
137+
138+
# WHEN we summarize the experiment
139+
summary = summarize_ax_optimization_complexity(
140+
experiment=self.experiment,
141+
options=options,
142+
tier_metadata=self.tier_metadata,
143+
)
144+
145+
# THEN the summary should reflect orchestrator options
146+
self.assertEqual(summary["tolerated_trial_failure_rate"], 0.25)
147+
self.assertEqual(summary["max_pending_trials"], 5)
148+
self.assertEqual(summary["min_failed_trials_for_failure_rate_check"], 10)
149+
150+
def test_parameter_constraints_counted(self) -> None:
151+
# GIVEN an experiment with parameter constraints
152+
experiment = get_experiment(constrain_search_space=True)
153+
154+
# WHEN we summarize the experiment
155+
summary = summarize_ax_optimization_complexity(
156+
experiment=experiment,
157+
options=self.options,
158+
tier_metadata=self.tier_metadata,
159+
)
160+
161+
# THEN num_parameter_constraints should be greater than 0
162+
self.assertGreater(summary["num_parameter_constraints"], 0)

0 commit comments

Comments
 (0)