|
| 1 | +import logging |
| 2 | +import numpy as np |
| 3 | +from collections.abc import Sequence |
| 4 | +from typing import Dict, List, Union |
| 5 | + |
| 6 | +# Configure module-level logger; users can configure handlers in their application |
| 7 | +logger = logging.getLogger(__name__) |
| 8 | + |
| 9 | + |
| 10 | +def sample_alternative_solutions( |
| 11 | + problem, |
| 12 | + variable_name: str, |
| 13 | + *, |
| 14 | + percentage: float = 0.10, |
| 15 | + scale: float = 0.03, |
| 16 | + rel_opt_tol: float = 0.05, |
| 17 | + max_samples: int = 30, |
| 18 | + perturbation_name: str = "perturbation", |
| 19 | + solver_kwargs: dict | None = None, |
| 20 | + rng: np.random.Generator | int | None = None, |
| 21 | + collect_vars: Sequence[str] | None = None, |
| 22 | + verbose: int = 1, # 0 = silent, 1 = summary, 2 = full detail |
| 23 | +) -> Dict[str, np.ndarray]: |
| 24 | + """Sample alternative solutions by perturbing a chosen decision variable. |
| 25 | +
|
| 26 | + This routine takes an optimization problem (with attributes .expr, |
| 27 | + .solve(), and .objectives), identifies a target variable within it, |
| 28 | + and generates up to max_samples new feasible solutions by randomly |
| 29 | + perturbing a fraction of that variable's entries. Only those perturbations |
| 30 | + that keep all original objectives within a relative tolerance of the |
| 31 | + baseline are accepted. |
| 32 | +
|
| 33 | + Args: |
| 34 | + problem: An optimization problem instance exposing |
| 35 | + - expr: a mapping of variable names to variable objects, |
| 36 | + - solve(...): method to solve the problem, |
| 37 | + - objectives: list of objective objects (each with .name and .value). |
| 38 | + variable_name (str): Name of the variable in problem.expr to perturb. |
| 39 | + percentage (float, optional): Fraction of the variable's entries to |
| 40 | + perturb in each trial (default 0.10). |
| 41 | + scale (float, optional): Standard deviation of the normal random noise |
| 42 | + (default 0.03). |
| 43 | + rel_opt_tol (float, optional): Maximum allowed relative deviation of |
| 44 | + any original objective from its baseline value (default 0.05). |
| 45 | + max_samples (int, optional): Maximum number of perturbation trials |
| 46 | + to attempt (default 30). |
| 47 | + perturbation_name (str, optional): Name to assign to the added |
| 48 | + perturbation objective (default `"perturbation"). |
| 49 | + solver_kwargs (dict or None, optional): Extra keyword arguments passed |
| 50 | + to problem.solve() (default None). |
| 51 | + rng (np.random.Generator or int or None, optional): Random number |
| 52 | + generator or seed for reproducibility (default None). |
| 53 | + collect_vars (Sequence[str] or None, optional): |
| 54 | + Names of the variables whose values you want back. |
| 55 | +
|
| 56 | + - `None (default) – collect **every** variable in problem.expr |
| 57 | + - `[] – collect **none** (method solves but returns an empty dict) |
| 58 | + - `["x", "y"] – collect only those named variables |
| 59 | + verbose (int, optional): Verbosity level: |
| 60 | + 0 = silent, 1 = summary, 2 = full detail (default 1). |
| 61 | +
|
| 62 | + Returns: |
| 63 | + dict: |
| 64 | + A dictionary that maps each collected variable name to a NumPy |
| 65 | + array with shape `(n_samples, *variable.shape) where |
| 66 | +
|
| 67 | + * `n_samples ≥ 1 – it counts the incumbent plus every accepted |
| 68 | + perturbation; |
| 69 | + * the remaining dimensions match the variable’s own shape. |
| 70 | +
|
| 71 | + Example:: |
| 72 | +
|
| 73 | + out = sample_alternative_solutions(problem, "x", collect_vars=["x", "y"]) |
| 74 | + x_stack = out["x"] # shape (n_samples, *x.shape) |
| 75 | + incumbent_x = x_stack[0] # first slice is always the baseline |
| 76 | +
|
| 77 | + Raises: |
| 78 | + KeyError: |
| 79 | + If `variable_name is not in problem.expr **or** if any name |
| 80 | + inside `collect_vars is missing from problem.expr. |
| 81 | + """ |
| 82 | + # Map verbosity to logging levels |
| 83 | + if verbose >= 2: |
| 84 | + log_level = logging.DEBUG |
| 85 | + elif verbose == 1: |
| 86 | + log_level = logging.INFO |
| 87 | + else: |
| 88 | + log_level = logging.WARNING |
| 89 | + logger.setLevel(log_level) |
| 90 | + |
| 91 | + if solver_kwargs is None: |
| 92 | + solver_kwargs = {} |
| 93 | + rng = rng if isinstance(rng, np.random.Generator) else np.random.default_rng(rng) |
| 94 | + |
| 95 | + # ------------------ sanity checks ------------------ |
| 96 | + if variable_name not in problem.expr: |
| 97 | + raise KeyError(f"Variable '{variable_name}' not found in problem.expr") |
| 98 | + |
| 99 | + if collect_vars is None: |
| 100 | + collect_vars = list(problem.expr.keys()) |
| 101 | + else: |
| 102 | + missing = [v for v in collect_vars if v not in problem.expr] |
| 103 | + if missing: |
| 104 | + raise KeyError(f"Variables not found in problem.expr: {missing}") |
| 105 | + |
| 106 | + collected: Dict[str, List[np.ndarray]] = {v: [] for v in collect_vars} |
| 107 | + target_var = problem.expr[variable_name] |
| 108 | + |
| 109 | + # 1) original solve --------------------------------- |
| 110 | + logger.debug("Solving original model …") |
| 111 | + problem.solve(**solver_kwargs, verbosity=0) |
| 112 | + baseline_obj = {o.name: float(o.value) for o in problem.objectives} |
| 113 | + logger.debug( |
| 114 | + "Baseline objectives: " |
| 115 | + + ", ".join(f"{k}={v:.6g}" for k, v in baseline_obj.items()) |
| 116 | + ) |
| 117 | + |
| 118 | + for v in collect_vars: |
| 119 | + collected[v].append(np.asarray(problem.expr[v].value).copy()) |
| 120 | + |
| 121 | + # 2) build perturbation parameter ------------------- |
| 122 | + var_shape = tuple(int(s) for s in target_var.shape) |
| 123 | + total_elems = int(np.prod(var_shape)) |
| 124 | + n_perturb = max(1, int(total_elems * percentage)) |
| 125 | + |
| 126 | + noise_buf = np.zeros(var_shape, dtype=float) |
| 127 | + pert = problem.backend.Parameter( |
| 128 | + name=f"{perturbation_name}_param", shape=var_shape, value=noise_buf |
| 129 | + ) |
| 130 | + problem.add_objective( |
| 131 | + (target_var.multiply(pert)) |
| 132 | + .sum() |
| 133 | + .reshape( |
| 134 | + 1, |
| 135 | + ), |
| 136 | + name=perturbation_name, |
| 137 | + ) |
| 138 | + |
| 139 | + flat_buf = noise_buf.reshape(-1) |
| 140 | + n_accept = n_reject = 0 |
| 141 | + |
| 142 | + # 3) sampling loop |
| 143 | + for trial in range(1, max_samples + 1): |
| 144 | + # 3a) new perturbation |
| 145 | + flat_buf.fill(0.0) |
| 146 | + idx = rng.choice(total_elems, n_perturb, replace=False) |
| 147 | + flat_buf[idx] = rng.normal(0.0, scale, n_perturb) |
| 148 | + pert.value = noise_buf |
| 149 | + |
| 150 | + # 3b) solve |
| 151 | + problem.solve(warm_start=True, **solver_kwargs, verbosity=0) |
| 152 | + |
| 153 | + # 3c) compute relative errors for each objective |
| 154 | + relerrs = {} |
| 155 | + current_vals = {} |
| 156 | + for o in problem.objectives: |
| 157 | + if o.name == perturbation_name: |
| 158 | + continue |
| 159 | + val = float(o.value) |
| 160 | + current_vals[o.name] = val |
| 161 | + denom = max(abs(baseline_obj[o.name]), 1e-9) |
| 162 | + relerrs[o.name] = abs(val - baseline_obj[o.name]) / denom |
| 163 | + |
| 164 | + # check tolerance |
| 165 | + violated = next( |
| 166 | + ((name, err) for name, err in relerrs.items() if err > rel_opt_tol), None |
| 167 | + ) |
| 168 | + |
| 169 | + # log objective values and errors |
| 170 | + detail_msg = ", ".join( |
| 171 | + f"{name}: val={current_vals[name]:.6g}, rel.err={relerrs[name]:.4f}" |
| 172 | + for name in current_vals |
| 173 | + ) |
| 174 | + |
| 175 | + if violated is None: |
| 176 | + for v in collect_vars: |
| 177 | + collected[v].append(np.asarray(problem.expr[v].value).copy()) |
| 178 | + n_accept += 1 |
| 179 | + logger.info( |
| 180 | + f"[{trial}/{max_samples}] accepted (total accepted={n_accept}) -> {detail_msg}" |
| 181 | + ) |
| 182 | + else: |
| 183 | + n_reject += 1 |
| 184 | + logger.info( |
| 185 | + f"[{trial}/{max_samples}] rejected (tol={rel_opt_tol}) -> {detail_msg}" |
| 186 | + ) |
| 187 | + |
| 188 | + # 4) stack lists into arrays ------------------------ |
| 189 | + out: Dict[str, np.ndarray] = { |
| 190 | + v: np.stack(values, axis=0) for v, values in collected.items() |
| 191 | + } |
| 192 | + |
| 193 | + logger.info( |
| 194 | + f"Done. accepted={n_accept}, rejected={n_reject}, solutions returned=" |
| 195 | + f"{out[next(iter(out))].shape[0]}" |
| 196 | + ) |
| 197 | + return out |
0 commit comments