|
15 | 15 | _get_fresh_pairwise_trial_indices, |
16 | 16 | arm_to_np_array, |
17 | 17 | can_map_to_binary, |
| 18 | + extract_equality_constraints, |
| 19 | + extract_inequality_constraints, |
18 | 20 | extract_objective_weight_matrix, |
19 | 21 | extract_search_space_digest, |
20 | 22 | feasible_hypervolume, |
|
35 | 37 | from ax.core.optimization_config import MultiObjectiveOptimizationConfig |
36 | 38 | from ax.core.outcome_constraint import ObjectiveThreshold, OutcomeConstraint |
37 | 39 | from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter |
| 40 | +from ax.core.parameter_constraint import ParameterConstraint |
38 | 41 | from ax.core.search_space import SearchSpace |
39 | 42 | from ax.core.types import ComparisonOp |
40 | 43 | from ax.exceptions.core import UserInputError |
@@ -377,6 +380,7 @@ def test_validate_and_apply_final_transform_with_target_point(self) -> None: |
377 | 380 | _, |
378 | 381 | _, |
379 | 382 | target_p, |
| 383 | + _, |
380 | 384 | ) = validate_and_apply_final_transform( |
381 | 385 | objective_weights=objective_weights, |
382 | 386 | outcome_constraints=outcome_constraints, |
@@ -412,6 +416,7 @@ def test_validate_and_apply_final_transform_none_target_point(self) -> None: |
412 | 416 | _, |
413 | 417 | _, |
414 | 418 | target_p, |
| 419 | + _, |
415 | 420 | ) = validate_and_apply_final_transform( |
416 | 421 | objective_weights=objective_weights, |
417 | 422 | outcome_constraints=outcome_constraints, |
@@ -652,3 +657,90 @@ def _attach( |
652 | 657 | self.assertNotIn(0, result) |
653 | 658 | self.assertNotIn(1, result) |
654 | 659 | self.assertIn(2, result) |
| 660 | + |
| 661 | + def test_extract_inequality_constraints(self) -> None: |
| 662 | + param_names = ["x", "y"] |
| 663 | + ineq = ParameterConstraint(inequality="x + y <= 1") |
| 664 | + eq = ParameterConstraint(equality="x + y == 1") |
| 665 | + |
| 666 | + # Only inequality constraints are extracted |
| 667 | + result = extract_inequality_constraints([ineq, eq], param_names) |
| 668 | + self.assertIsNotNone(result) |
| 669 | + assert result is not None |
| 670 | + A, b = result |
| 671 | + self.assertEqual(A.shape, (1, 2)) |
| 672 | + self.assertEqual(b.shape, (1, 1)) |
| 673 | + np.testing.assert_array_equal(A[0], [1.0, 1.0]) |
| 674 | + np.testing.assert_array_equal(b[0], [1.0]) |
| 675 | + |
| 676 | + # Returns None when no inequality constraints |
| 677 | + result = extract_inequality_constraints([eq], param_names) |
| 678 | + self.assertIsNone(result) |
| 679 | + |
| 680 | + # Returns None for empty list |
| 681 | + result = extract_inequality_constraints([], param_names) |
| 682 | + self.assertIsNone(result) |
| 683 | + |
| 684 | + def test_extract_equality_constraints(self) -> None: |
| 685 | + param_names = ["x", "y"] |
| 686 | + ineq = ParameterConstraint(inequality="x + y <= 1") |
| 687 | + eq = ParameterConstraint(equality="x + y == 1") |
| 688 | + |
| 689 | + # Only equality constraints are extracted |
| 690 | + result = extract_equality_constraints([ineq, eq], param_names) |
| 691 | + self.assertIsNotNone(result) |
| 692 | + assert result is not None |
| 693 | + A, b = result |
| 694 | + self.assertEqual(A.shape, (1, 2)) |
| 695 | + self.assertEqual(b.shape, (1, 1)) |
| 696 | + np.testing.assert_array_equal(A[0], [1.0, 1.0]) |
| 697 | + np.testing.assert_array_equal(b[0], [1.0]) |
| 698 | + |
| 699 | + # Returns None when no equality constraints |
| 700 | + result = extract_equality_constraints([ineq], param_names) |
| 701 | + self.assertIsNone(result) |
| 702 | + |
| 703 | + def test_extract_constraints_mixed(self) -> None: |
| 704 | + """Both functions correctly partition a mixed list.""" |
| 705 | + param_names = ["x", "y"] |
| 706 | + ineq1 = ParameterConstraint(inequality="x <= 0.5") |
| 707 | + ineq2 = ParameterConstraint(inequality="y <= 0.8") |
| 708 | + eq1 = ParameterConstraint(equality="x + y == 1") |
| 709 | + |
| 710 | + ineq_result = extract_inequality_constraints([ineq1, eq1, ineq2], param_names) |
| 711 | + eq_result = extract_equality_constraints([ineq1, eq1, ineq2], param_names) |
| 712 | + |
| 713 | + assert ineq_result is not None |
| 714 | + assert eq_result is not None |
| 715 | + self.assertEqual(ineq_result[0].shape, (2, 2)) # 2 inequalities |
| 716 | + self.assertEqual(eq_result[0].shape, (1, 2)) # 1 equality |
| 717 | + |
| 718 | + def test_validate_and_apply_final_transform_equality_constraints(self) -> None: |
| 719 | + """equality_constraints are converted to tensors.""" |
| 720 | + objective_weights = np.array([1.0, 0.0]) |
| 721 | + A_eq = np.array([[1.0, 1.0]]) |
| 722 | + b_eq = np.array([[1.0]]) |
| 723 | + |
| 724 | + _, _, _, _, _, _, eq_c = validate_and_apply_final_transform( |
| 725 | + objective_weights=objective_weights, |
| 726 | + outcome_constraints=None, |
| 727 | + linear_constraints=None, |
| 728 | + pending_observations=None, |
| 729 | + equality_constraints=(A_eq, b_eq), |
| 730 | + ) |
| 731 | + self.assertIsNotNone(eq_c) |
| 732 | + assert eq_c is not None |
| 733 | + self.assertTrue(torch.equal(eq_c[0], torch.tensor(A_eq))) |
| 734 | + self.assertTrue(torch.equal(eq_c[1], torch.tensor(b_eq))) |
| 735 | + |
| 736 | + def test_validate_and_apply_final_transform_no_equality_constraints(self) -> None: |
| 737 | + """equality_constraints defaults to None.""" |
| 738 | + objective_weights = np.array([1.0]) |
| 739 | + |
| 740 | + _, _, _, _, _, _, eq_c = validate_and_apply_final_transform( |
| 741 | + objective_weights=objective_weights, |
| 742 | + outcome_constraints=None, |
| 743 | + linear_constraints=None, |
| 744 | + pending_observations=None, |
| 745 | + ) |
| 746 | + self.assertIsNone(eq_c) |
0 commit comments