|
9 | 9 | from unittest import mock |
10 | 10 |
|
11 | 11 | import numpy as np |
| 12 | +import torch |
12 | 13 | from ax.core.search_space import SearchSpaceDigest |
13 | 14 | from ax.exceptions.core import SearchSpaceExhausted |
14 | 15 | from ax.generators.random.sobol import SobolGenerator |
@@ -244,6 +245,81 @@ def rounding_function(x: npt.NDArray) -> npt.NDArray: |
244 | 245 |
|
245 | 246 | rounding_func.assert_called() |
246 | 247 |
|
| 248 | + def test_SobolGeneratorFallbackToPolytopeSamplerWithEqualityConstraints( |
| 249 | + self, |
| 250 | + ) -> None: |
| 251 | + # Ten parameters with sum <= 1 (forces polytope fallback) plus |
| 252 | + # an equality constraint x0 + x1 = 0.4 to verify equality constraints |
| 253 | + # are threaded through to the polytope sampler. |
| 254 | + generator = SobolGenerator(seed=0, fallback_to_sample_polytope=True) |
| 255 | + ssd = self._create_ssd(n_tunable=10, n_fixed=0) |
| 256 | + A_ineq = np.ones((1, 10)) |
| 257 | + b_ineq = np.array([1]).reshape((1, 1)) |
| 258 | + # Equality constraint: x0 + x1 = 0.4 |
| 259 | + A_eq = np.zeros((1, 10)) |
| 260 | + A_eq[0, 0] = 1.0 |
| 261 | + A_eq[0, 1] = 1.0 |
| 262 | + b_eq = np.array([[0.4]]) |
| 263 | + generated_points, _ = generator.gen( |
| 264 | + n=3, |
| 265 | + search_space_digest=ssd, |
| 266 | + linear_constraints=(A_ineq, b_ineq), |
| 267 | + equality_constraints=(A_eq, b_eq), |
| 268 | + rounding_func=lambda x: x, |
| 269 | + ) |
| 270 | + self.assertEqual(np.shape(generated_points), (3, 10)) |
| 271 | + # Inequality constraint satisfied. |
| 272 | + self.assertTrue(np.all(generated_points @ A_ineq.T <= b_ineq + 1e-6)) |
| 273 | + # Equality constraint satisfied: x0 + x1 ≈ 0.4. |
| 274 | + eq_vals = generated_points @ A_eq.T |
| 275 | + np.testing.assert_allclose(eq_vals, 0.4, atol=1e-6) |
| 276 | + |
| 277 | + def test_SobolGeneratorFallbackToPolytopeSamplerWithEqAndFixedFeatures( |
| 278 | + self, |
| 279 | + ) -> None: |
| 280 | + # Verify that equality constraints and fixed features are combined |
| 281 | + # correctly and passed to the polytope sampler. |
| 282 | + generator = SobolGenerator(seed=0, fallback_to_sample_polytope=True) |
| 283 | + ssd = self._create_ssd(n_tunable=10, n_fixed=1) |
| 284 | + A_ineq = np.insert(np.ones((1, 10)), 10, 0, axis=1) |
| 285 | + b_ineq = np.array([1]).reshape((1, 1)) |
| 286 | + # Equality constraint: x0 + x1 = 0.3. |
| 287 | + A_eq = np.zeros((1, 11)) |
| 288 | + A_eq[0, 0] = 1.0 |
| 289 | + A_eq[0, 1] = 1.0 |
| 290 | + b_eq = np.array([[0.3]]) |
| 291 | + |
| 292 | + with mock.patch( |
| 293 | + "ax.generators.random.base.HitAndRunPolytopeSampler" |
| 294 | + ) as MockSampler: |
| 295 | + mock_instance = MockSampler.return_value |
| 296 | + mock_instance.draw.return_value = torch.rand(1, 11, dtype=torch.double) |
| 297 | + try: |
| 298 | + generator.gen( |
| 299 | + n=1, |
| 300 | + search_space_digest=ssd, |
| 301 | + linear_constraints=(A_ineq, b_ineq), |
| 302 | + equality_constraints=(A_eq, b_eq), |
| 303 | + fixed_features={10: 1}, |
| 304 | + rounding_func=lambda x: x, |
| 305 | + ) |
| 306 | + except Exception: |
| 307 | + pass |
| 308 | + if MockSampler.called: |
| 309 | + eq_arg = MockSampler.call_args.kwargs["equality_constraints"] |
| 310 | + self.assertIsNotNone(eq_arg) |
| 311 | + C, c = eq_arg |
| 312 | + # 1 from fixed features + 1 from parameter equality = 2 rows. |
| 313 | + self.assertEqual(C.shape[0], 2) |
| 314 | + self.assertEqual(C.shape[1], 11) |
| 315 | + # Fixed feature: x[10] = 1 (last row from fixed, first in sorted). |
| 316 | + self.assertEqual(C[0, 10].item(), 1.0) |
| 317 | + self.assertAlmostEqual(c[0].item(), 1.0) |
| 318 | + # Parameter equality: x[0] + x[1] = 0.3. |
| 319 | + self.assertEqual(C[1, 0].item(), 1.0) |
| 320 | + self.assertEqual(C[1, 1].item(), 1.0) |
| 321 | + self.assertAlmostEqual(c[1].item(), 0.3) |
| 322 | + |
247 | 323 | def test_SobolGeneratorFallbackToPolytopeSamplerWithFixedParam(self) -> None: |
248 | 324 | # Ten parameters with sum less than 1. In this example, the rejection |
249 | 325 | # sampler gives a search space exhausted error. Testing fallback to |
|
0 commit comments