|
20 | 20 | from ax.core.arm import Arm |
21 | 21 | from ax.core.metric import Metric |
22 | 22 | from ax.core.objective import MultiObjective, Objective |
23 | | -from ax.core.optimization_config import MultiObjectiveOptimizationConfig |
| 23 | +from ax.core.optimization_config import ( |
| 24 | + MultiObjectiveOptimizationConfig, |
| 25 | + OptimizationConfig, |
| 26 | +) |
24 | 27 | from ax.core.outcome_constraint import ObjectiveThreshold |
25 | 28 | from ax.core.types import ComparisonOp |
| 29 | +from ax.exceptions.core import UnsupportedError |
26 | 30 | from ax.generation_strategy.generation_node import GenerationStep |
27 | 31 | from ax.generation_strategy.generation_strategy import GenerationStrategy |
28 | 32 | from ax.orchestration.orchestrator import Orchestrator |
|
49 | 53 | get_branin_experiment, |
50 | 54 | get_branin_experiment_with_multi_objective, |
51 | 55 | get_branin_experiment_with_timestamp_map_metric, |
| 56 | + get_branin_metric, |
52 | 57 | get_experiment_with_observations, |
53 | 58 | get_high_dimensional_branin_experiment, |
54 | 59 | get_multi_type_experiment, |
@@ -743,3 +748,31 @@ def test_maybe_extract_baseline_comparison_values_metric_missing_moo(self) -> No |
743 | 748 | baseline_arm_name=arm_names[0], |
744 | 749 | ) |
745 | 750 | self.assertIsNone(result) |
| 751 | + |
| 752 | + def test_get_objective_trace_plot_scalarized(self) -> None: |
| 753 | + """_get_objective_trace_plot raises UnsupportedError for scalarized.""" |
| 754 | + exp = get_branin_experiment(with_completed_trial=True) |
| 755 | + exp.add_tracking_metric(get_branin_metric(name="branin2")) |
| 756 | + exp._optimization_config = OptimizationConfig( |
| 757 | + objective=Objective(expression="2*branin + -1*branin2"), |
| 758 | + ) |
| 759 | + with self.assertRaisesRegex(UnsupportedError, "not supported for scalarized"): |
| 760 | + _get_objective_trace_plot(experiment=exp) |
| 761 | + |
| 762 | + def test_maybe_extract_baseline_comparison_values_scalarized(self) -> None: |
| 763 | + """maybe_extract_baseline_comparison_values raises UnsupportedError |
| 764 | + for scalarized.""" |
| 765 | + exp = get_branin_experiment_with_multi_objective(with_batch=True) |
| 766 | + exp.trials[0].run() |
| 767 | + exp.fetch_data() |
| 768 | + arm_names = list(exp.arms_by_name.keys()) |
| 769 | + exp._optimization_config = OptimizationConfig( |
| 770 | + objective=Objective(expression="2*branin_a + -1*branin_b"), |
| 771 | + ) |
| 772 | + with self.assertRaisesRegex(UnsupportedError, "not supported for scalarized"): |
| 773 | + maybe_extract_baseline_comparison_values( |
| 774 | + experiment=exp, |
| 775 | + optimization_config=exp.optimization_config, |
| 776 | + comparison_arm_names=[arm_names[1]], |
| 777 | + baseline_arm_name=arm_names[0], |
| 778 | + ) |
0 commit comments