|
8 | 8 |
|
9 | 9 | from __future__ import annotations |
10 | 10 |
|
| 11 | +import math |
| 12 | + |
11 | 13 | import warnings |
12 | 14 | from collections.abc import Mapping, Sequence |
13 | 15 | from dataclasses import dataclass, field |
14 | 16 | from logging import Logger |
15 | 17 |
|
| 18 | +import numpy as np |
| 19 | + |
16 | 20 | import pandas as pd |
17 | 21 | from ax import core |
18 | 22 | from ax.core.arm import Arm |
|
36 | 40 | from ax.utils.common.constants import Keys |
37 | 41 | from ax.utils.common.logger import get_logger |
38 | 42 | from pyre_extensions import none_throws |
| 43 | +from scipy.optimize import linprog |
| 44 | + |
| 45 | +from scipy.special import expit, logit |
39 | 46 |
|
40 | 47 |
|
41 | 48 | logger: Logger = get_logger(__name__) |
@@ -572,6 +579,134 @@ def clone(self) -> SearchSpace: |
572 | 579 | parameter_constraints=[pc.clone() for pc in self._parameter_constraints], |
573 | 580 | ) |
574 | 581 |
|
def compute_naive_center(self) -> TParameterization:
    """Return a parameterization located at the naive center of the search space.

    How each parameter type is handled:

    * Range parameters use the midpoint of ``[lower, upper]``. The midpoint
      is taken in logit space for logit-scale parameters and in log10 space
      for log-scale parameters, then mapped back to the original scale. The
      result is passed through the parameter's ``cast``, which determines
      how a fractional midpoint of an integer range is resolved.
    * Choice parameters use the entry at index ``len(values) // 2`` of the
      values list, so a list of even length resolves to the larger of the
      two middle indices (e.g. values ``[0, 1]`` yield ``1``).
    * Fixed parameters use their single value.
    * Derived parameters are computed last, from the values chosen for all
      other parameters.

    For hierarchical search spaces the result is additionally passed through
    ``_cast_parameterization``.

    Returns:
        A parameterization with the center value for each parameter.

    Raises:
        NotImplementedError: If a parameter is of an unsupported type.
    """
    center: TParameterization = {}
    # Derived parameters need the other values, so they are resolved after
    # the main pass (and therefore appear last in the returned mapping).
    deferred = []

    for param_name, param in self.parameters.items():
        if isinstance(param, RangeParameter):
            if param.logit_scale:
                logit_mid = (logit(param.lower) + logit(param.upper)) / 2.0
                midpoint = expit(logit_mid)
            elif param.log_scale:
                log_mid = (math.log10(param.lower) + math.log10(param.upper)) / 2.0
                midpoint = 10**log_mid
            else:
                midpoint = (float(param.lower) + float(param.upper)) / 2.0
            center[param_name] = param.cast(midpoint)
            continue

        if isinstance(param, ChoiceParameter):
            middle_index = len(param.values) // 2
            center[param_name] = param.values[middle_index]
            continue

        if isinstance(param, FixedParameter):
            center[param_name] = param.value
            continue

        if isinstance(param, DerivedParameter):
            deferred.append(param)
            continue

        raise NotImplementedError(f"Parameter type {type(param)} is not supported.")

    for derived in deferred:
        center[derived.name] = derived.compute(parameters=center)

    if not self.is_hierarchical:
        return center
    return self._cast_parameterization(parameters=center)
| 623 | + |
def compute_chebyshev_center(self) -> dict[str, float] | None:
    """Compute the Chebyshev center of the constraint polytope.

    The Chebyshev center is the center of the largest ball inscribed in the
    feasible region defined by the parameter bounds and the parameter
    constraints. It is found by solving a linear program.

    For a polytope defined by ``A @ x <= b``, the Chebyshev center
    ``(x_c, r)`` is the solution to::

        maximize   r
        subject to a_i^T x + r * ||a_i||_2 <= b_i   for all i
                   r >= 0

    Only natural (non-log, non-logit) range parameters are LP variables.
    Parameter constraints that reference none of these parameters are
    skipped: they would contribute an all-zero row ``0 <= b_i``, which
    says nothing about the LP variables and, for a negative ``b_i``, would
    make the LP spuriously infeasible.

    NOTE(review): for a constraint that mixes natural range parameters with
    other parameters, only the coefficients of the natural range parameters
    are used and the bound is left unchanged, i.e. the other terms are
    treated as zero. Confirm this is the intended handling.

    Returns:
        A dictionary mapping the names of the natural range parameters to
        their values at the Chebyshev center; an empty dictionary if there
        are no such parameters; or ``None`` if the linear program could not
        be solved (e.g. the constraints are infeasible).
    """
    # Only non-log, non-logit range parameters take part in the LP.
    natural_range_params = {
        name: param
        for name, param in self.range_parameters.items()
        if not param.log_scale and not param.logit_scale
    }
    if not natural_range_params:
        return {}

    param_names = list(natural_range_params)
    num_params = len(param_names)
    param_name_to_idx = {name: idx for idx, name in enumerate(param_names)}

    constraint_rows = []
    bound_values = []

    # Linear parameter constraints: sum_j w_j * x_j <= bound.
    for constraint in self.parameter_constraints:
        weights = constraint.constraint_dict
        if not any(name in param_name_to_idx for name in weights):
            # The constraint does not touch any LP variable; keeping it
            # would add a vacuous (or spuriously infeasible) zero row.
            continue
        row = np.zeros(num_params)
        for param_name, weight in weights.items():
            idx = param_name_to_idx.get(param_name)
            if idx is not None:
                row[idx] = weight
        constraint_rows.append(row)
        bound_values.append(float(constraint.bound))

    # Box bounds, expressed as two half-spaces per parameter.
    for name, idx in param_name_to_idx.items():
        param = natural_range_params[name]

        # Lower bound: -x_i <= -lower_i
        row_lower = np.zeros(num_params)
        row_lower[idx] = -1.0
        constraint_rows.append(row_lower)
        bound_values.append(-float(param.lower))

        # Upper bound: x_i <= upper_i
        row_upper = np.zeros(num_params)
        row_upper[idx] = 1.0
        constraint_rows.append(row_upper)
        bound_values.append(float(param.upper))

    constraint_matrix = np.asarray(constraint_rows, dtype=float)
    bound_vector = np.asarray(bound_values, dtype=float)

    # The coefficient of r in row i is ||a_i||_2, so that a ball of radius r
    # around x stays on the feasible side of every half-space.
    row_norms = np.linalg.norm(constraint_matrix, axis=1)
    augmented_constraint_matrix = np.column_stack([constraint_matrix, row_norms])

    # linprog minimizes, so minimize -r in order to maximize r.
    radius_objective_vector = np.zeros(num_params + 1)
    radius_objective_vector[-1] = -1.0
    result = linprog(
        c=radius_objective_vector,
        A_ub=augmented_constraint_matrix,
        b_ub=bound_vector,
        # x is free here (its box bounds are rows of A_ub); only r >= 0.
        bounds=[(None, None)] * num_params + [(0, None)],
    )

    if not result.success or result.x is None:
        return None

    center_values = result.x[:num_params]  # drop the trailing radius r
    return {name: float(center_values[idx]) for name, idx in param_name_to_idx.items()}
| 709 | + |
575 | 710 | def _validate_parameter_constraints( |
576 | 711 | self, parameter_constraints: list[ParameterConstraint] |
577 | 712 | ) -> None: |
|
0 commit comments