Skip to content

Commit c687ec5

Browse files
committed
neatened nonlinear solver
1 parent 4e5677c commit c687ec5

File tree

1 file changed

+119
-166
lines changed

1 file changed

+119
-166
lines changed

burnman/optimize/nonlinear_solvers.py

Lines changed: 119 additions & 166 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,6 @@ def __init__(
2929
"""
3030
Initialize the Solution instance.
3131
32-
Arguments are stored as attributes of the instance
33-
with the same names.
34-
3532
:param x: Final solution vector.
3633
:type x: np.ndarray, optional
3734
:param n_it: Number of iterations performed.
@@ -248,7 +245,6 @@ def __init__(
248245
store_iterates: bool = False,
249246
regularization: float = np.finfo(float).eps,
250247
cond_lu_thresh: float = 1e12,
251-
cond_lstsq_thresh: float = 1e15,
252248
constraint_thresh: float = 2 * np.finfo(float).eps,
253249
):
254250
"""
@@ -322,7 +318,6 @@ def __init__(
322318
self.linear_constraints = linear_constraints
323319
self.regularization = regularization
324320
self.cond_lu_thresh = cond_lu_thresh
325-
self.cond_lstsq_thresh = cond_lstsq_thresh
326321
self.eps = 2.0 * np.finfo(float).eps
327322
self.max_condition_number = 1.0 / np.finfo(float).eps
328323

@@ -368,82 +363,66 @@ def _solve_subject_to_constraints(
368363
c_prime: npt.NDArray[np.float64],
369364
) -> npt.NDArray[np.float64]:
370365
"""
371-
Solve a constrained Newton correction step using the method of
372-
Lagrange multipliers (KKT system).
373-
374-
This method computes a step ``dx`` that minimizes the linearized
375-
residual ||J(x)·dx|| subject to linear equality constraints derived
376-
from the currently active inequality constraints.
377-
378-
The system is solved using the KKT (Karush-Kuhn-Tucker) formulation:
379-
380-
.. math::
381-
382-
\\begin{bmatrix}
383-
J^T J + \\alpha I & A^T \\\\
384-
A & 0
385-
\\end{bmatrix}
386-
\\begin{bmatrix}
387-
dx \\\\
388-
\\lambda
389-
\\end{bmatrix}
390-
=
391-
- \\begin{bmatrix}
392-
0 \\\\
393-
c(x)
394-
\\end{bmatrix}
366+
Compute a constrained Newton correction step using the
367+
Karush-Kuhn-Tucker (KKT) formulation.
395368
396-
where:
369+
This method solves for the update ``dx`` that minimizes the linearized
370+
residual ``||J(x)·dx||`` subject to the active linear equality constraints
371+
``c'·dx + c(x) = 0``, where ``c'`` is the active constraint Jacobian (``c_prime``):
397372
398-
* ``J`` is the Jacobian at ``x``
399-
* ``A`` is the constraint Jacobian (``c_prime``)
400-
* ``c(x)`` is the constraint evaluation
401-
* ``\\lambda`` are the Lagrange multipliers
402-
* ``\\alpha`` = ``self.regularization`` is an optional regularization parameter
373+
[ JᵀJ + alpha * I   c'ᵀ ] [   dx   ] = -[  0   ]
374+
[ c'                0   ] [ lambda ]    [ c(x) ]
403375
404-
The KKT system is solved using one of three strategies depending on
405-
the estimated condition number of the matrix:
376+
where:
377+
* ``J`` is the Jacobian at ``x`` (``jac_x``)
378+
* ``c'`` is the active constraint Jacobian (``c_prime``)
379+
* ``c(x)`` are the constraint values
380+
* ``lambda`` are the Lagrange multipliers
381+
* ``alpha = self.regularization`` is a Tikhonov regularization parameter
406382
407-
1. **LU factorization** if ``cond < cond_lu_thresh``
408-
2. **Least-squares solve** if ``cond < cond_lstsq_thresh``
409-
3. **SVD-based pseudo-inverse** for ill-conditioned cases
383+
The KKT system is solved adaptively based on its condition number:
384+
1. LU factorization when the condition number is below ``self.cond_lu_thresh``
385+
2. SVD-based pseudo-inverse otherwise (ill-conditioned systems)
410386
411387
:param x: Current solution vector.
412-
:type x: np.ndarray
413-
:param jac_x: Current Jacobian matrix J(x).
414-
:type jac_x: np.ndarray
415-
:param c_x: Values of the active constraints at x.
416-
:type c_x: np.ndarray
417-
:param c_prime: Jacobian of the active constraints (A in Ax + b = 0).
418-
:type c_prime: np.ndarray
419-
420-
:return: A 3-tuple containing:
421-
422-
* **x_new** (np.ndarray) -- Updated solution ``x + dx``.
423-
* **lambdas** (np.ndarray) -- Lagrange multipliers for active constraints.
424-
* **condition_number** (float) -- Estimated condition number of the KKT matrix.
425-
426-
:rtype: tuple[np.ndarray, np.ndarray, float]
388+
:type x: numpy.ndarray
389+
:param jac_x: Jacobian of residuals at ``x``.
390+
:type jac_x: numpy.ndarray
391+
:param c_x: Values of the active constraints at ``x``.
392+
:type c_x: numpy.ndarray
393+
:param c_prime: Jacobian of the active constraints.
394+
:type c_prime: numpy.ndarray
395+
396+
:return: Tuple ``(x_new, lambdas, condition_number)`` where:
397+
- **x_new** (*numpy.ndarray*) – Updated solution vector ``x + dx``.
398+
- **lambdas** (*numpy.ndarray*) – Lagrange multipliers for active constraints.
399+
- **condition_number** (*float*) – Estimated condition number of the KKT matrix.
400+
:rtype: tuple[numpy.ndarray, numpy.ndarray, float]
427401
"""
402+
428403
n_x = x.shape[0]
429404
n_c = c_x.shape[0]
430405
JTJ_reg = jac_x.T @ jac_x + self.regularization * np.eye(n_x)
431-
norm = n_x * n_x / np.linalg.norm(JTJ_reg)
432-
KKT = np.block([[JTJ_reg * norm, c_prime.T], [c_prime, np.zeros((n_c, n_c))]])
406+
scale = np.linalg.norm(JTJ_reg)
407+
if scale == 0:
408+
scale = 1.0
409+
KKT = np.block([[JTJ_reg / scale, c_prime.T], [c_prime, np.zeros((n_c, n_c))]])
433410
rhs = -np.concatenate([np.zeros(n_x), c_x])
434411

435412
condition_number = np.linalg.cond(KKT)
436413
if condition_number < self.cond_lu_thresh:
437-
dx_lambda = lu_solve(lu_factor(KKT), rhs)
438-
elif condition_number < self.cond_lstsq_thresh:
439-
dx_lambda, *_ = np.linalg.lstsq(KKT, rhs, rcond=None)
414+
lu, piv = lu_factor(KKT)
415+
dx_lambda = lu_solve((lu, piv), rhs)
440416
else:
441417
U, s, Vt = np.linalg.svd(KKT, full_matrices=False)
442-
s_inv = np.where(s > 1e-12, 1.0 / s, 0.0)
418+
tol = np.finfo(float).eps * max(KKT.shape) * np.max(s)
419+
s_inv = np.where(s > tol, 1.0 / s, 0.0)
443420
dx_lambda = Vt.T @ (s_inv * (U.T @ rhs))
444421

445422
dx = dx_lambda[:n_x]
446-
return x + dx, dx_lambda[n_x:], condition_number
423+
lambdas = dx_lambda[n_x:]
424+
425+
return x + dx, lambdas, condition_number
447426

448427
def _constrain_step_to_feasible_region(
449428
self,
@@ -454,46 +433,30 @@ def _constrain_step_to_feasible_region(
454433
x_j: npt.NDArray[np.float64],
455434
) -> tuple[npt.NDArray[np.float64], float]:
456435
"""
457-
Project a trial Newton step back into the feasible region defined
458-
by linear inequality constraints A.x + b <= 0.
459-
460-
This method checks whether the trial point x_j = x + lambda.dx violates
461-
any constraints. If so, it computes the maximum allowable step scaling
462-
factor to remain feasible, reduces lambda accordingly, and updates the
463-
trial iterate.
464-
465-
The scaling factor is computed per violated constraint as:
466-
467-
.. math::
436+
Project a trial step back into the feasible region of linear inequality constraints.
468437
469-
\\lambda_i = \\frac{c_x[i]}{c_x[i] - c_{x_j}[i]}
470-
471-
where c_x and c_{x_j} are the constraint function values at x and x_j.
472-
The smallest lambda_i is used to rescale the step to just touch the first
473-
violated constraint.
438+
Given a trial point x_j = x + lambda*dx, this method checks for constraint
439+
violations and rescales the step to remain feasible. The scaling factor is
440+
computed per violated constraint, and the smallest factor is applied to
441+
lambda to ensure the trial point stays within the feasible region.
474442
475443
:param x: Current solution vector.
476-
:type x: np.ndarray
477-
:param dx: Full Newton step direction.
478-
:type dx: np.ndarray
479-
:param n_constraints: Total number of linear inequality constraints.
444+
:type x: numpy.ndarray
445+
:param dx: Newton step direction.
446+
:type dx: numpy.ndarray
447+
:param n_constraints: Number of linear inequality constraints.
480448
:type n_constraints: int
481-
:param lmda: Current damping factor lambda for the trial step.
449+
:param lmda: Current step scaling factor.
482450
:type lmda: float
483-
:param x_j: Current trial iterate x + lambda.dx.
484-
:type x_j: np.ndarray
485-
486-
:return: A 3-tuple containing:
487-
488-
* **lmda** (float)
489-
-- Updated damping factor lambda that ensures feasibility.
490-
* **x_j** (np.ndarray)
491-
-- Adjusted trial point within the feasible region.
492-
* **violated_constraints** (list[tuple[int, float]])
493-
-- List of (index, lambda_i) for each violated constraint,
494-
sorted by lambda_i.
495-
496-
:rtype: tuple[float, np.ndarray, list[tuple[int, float]]]
451+
:param x_j: Trial point x + lambda*dx.
452+
:type x_j: numpy.ndarray
453+
454+
:return: Tuple ``(lmda, x_j, violated_constraints)`` where:
455+
- **lmda** (*float*) - Updated scaling factor ensuring feasibility.
456+
- **x_j** (*numpy.ndarray*) - Adjusted trial point within feasible region.
457+
- **violated_constraints** (*list[tuple[int, float]]*) - List of
458+
(constraint index, scaling factor) for violated constraints, sorted by factor.
459+
:rtype: tuple[float, numpy.ndarray, list[tuple[int, float]]]
497460
"""
498461
c_x_j = self._constraints(x_j)
499462
c_x = self._constraints(x)
@@ -514,98 +477,88 @@ def _lagrangian_walk_along_constraints(
514477
dx: npt.NDArray[np.float64],
515478
luJ: Any,
516479
dx_norm: float,
517-
violated_constraints: list[int],
518-
) -> tuple[npt.NDArray[np.float64], float]:
480+
violated_constraints: list[tuple[int, float]],
481+
) -> tuple[float, npt.NDArray[np.float64], npt.NDArray[np.float64], bool]:
519482
"""
520-
Attempt to find a constrained Newton step when a step along the
521-
standard Newton direction would immediately violate active linear
522-
inequality constraints (A.x + b <= 0).
523-
Uses the method of Lagrange multipliers, attempting to "walk along"
524-
the active constraints to remain in the feasible region while
525-
decreasing the residual norm ||F(x)||.
526-
527-
:param sol: Current solver state with fields x and F.
528-
:type sol: SimpleNamespace
529-
:param dx: Current Newton step direction.
530-
:type dx: np.ndarray
531-
:param luJ: LU factorization of the current Jacobian, as returned by
532-
``scipy.linalg.lu_factor``.
533-
:type luJ: tuple
534-
:param dx_norm: L2 norm of the current Newton step dx.
535-
:type dx_norm: float
536-
:param lmda_bounds: Tuple (min_lambda, max_lambda) for the damping factor.
537-
:type lmda_bounds: tuple[float, float]
483+
Attempt a constrained Newton step along active linear constraints
484+
to remain feasible while decreasing the residual norm.
485+
486+
:param dx: Newton step direction.
487+
:param luJ: LU factorization of current Jacobian (from `lu_factor`).
488+
:param dx_norm: L2 norm of the Newton step.
538489
:param violated_constraints: List of (index, fraction) for constraints
539-
that would be violated by the current Newton step.
540-
:type violated_constraints: list[tuple[int, float]]
490+
that would be violated by the current step.
541491
542-
:return: Updated damping factor, updated values, full Newton step,
543-
and flag indicating whether the solver encountered a persistent
544-
constraint violation or reached the minimum lambda.
545-
:rtype: tuple[float, np.ndarray, np.ndarray, bool]
492+
:return: Tuple of (lambda, adjusted trial point x_j, constrained step dx
493+
(``x_m - sol.x``), persistent_violation flag).
546494
"""
547495
sol = self.sol
548496

549-
active_constraint_indices = [
497+
# Split constraints into active and inactive based on proximity to boundary
498+
active_idx = [
550499
i for i, vc in violated_constraints if vc < self.constraint_thresh
551500
]
552-
inactive_constraint_indices = [
501+
inactive_idx = [
553502
i for i, vc in violated_constraints if vc >= self.constraint_thresh
554503
]
555-
c_newton = self._constraints(self.sol.x + dx)[active_constraint_indices]
556-
c_A = self.linear_constraints[0][active_constraint_indices]
557-
x_n = self.sol.x + dx
558-
persistent_bound_violation = False
559504

560-
if len(c_A) > 0 and np.linalg.matrix_rank(c_A) == len(dx):
561-
n_act = len(active_constraint_indices)
505+
# Evaluate active constraints and corresponding Jacobian
506+
c_active = self._constraints(sol.x + dx)[active_idx]
507+
A_active = self.linear_constraints[0][active_idx]
508+
x_n = sol.x + dx
509+
persistent_violation = False
510+
511+
# Solve KKT system along active constraints if well-posed
512+
if len(A_active) > 0 and np.linalg.matrix_rank(A_active) == len(dx):
513+
n_act = len(active_idx)
514+
# Attempt to remove one active constraint at a time if necessary
562515
for i_rm in range(n_act):
563-
potential_active_indices = [
564-
active_constraint_indices[i] for i in range(n_act) if i != i_rm
565-
]
566-
c_newton = self._constraints(sol.x + dx)[potential_active_indices]
567-
c_A = self.linear_constraints[0][potential_active_indices]
568-
x_m = self._solve_subject_to_constraints(x_n, sol.J, c_newton, c_A)[0]
569-
if self._constraints(x_m)[active_constraint_indices[i_rm]] < 0.0:
516+
keep_idx = [active_idx[j] for j in range(n_act) if j != i_rm]
517+
c_subset = self._constraints(sol.x + dx)[keep_idx]
518+
A_subset = self.linear_constraints[0][keep_idx]
519+
x_m = self._solve_subject_to_constraints(
520+
x_n, sol.J, c_subset, A_subset
521+
)[0]
522+
if self._constraints(x_m)[active_idx[i_rm]] < 0:
570523
break
571524
else:
572-
x_m = self._solve_subject_to_constraints(x_n, sol.J, c_newton, c_A)[0]
525+
x_m = self._solve_subject_to_constraints(x_n, sol.J, c_active, A_active)[0]
573526

527+
# Update step and damping factor
574528
dx = x_m - sol.x
575-
lmda_bounds_new = self.lambda_bounds(dx, sol.x)
576-
lmda = lmda_bounds_new[1]
529+
lmda_min, lmda_max = self.lambda_bounds(dx, sol.x)
530+
lmda = lmda_max
577531
x_j = sol.x + lmda * dx
578532

579-
# Check feasibility
533+
# Check feasibility at minimum lambda
580534
try:
581-
x_j_min = sol.x + lmda_bounds_new[0] * dx
535+
x_j_min = sol.x + lmda_min * dx
582536
F_j_min = self.F(x_j_min)
583537
dxbar_j_min = lu_solve(luJ, -F_j_min)
584-
dxbar_j_min_norm = np.linalg.norm(dxbar_j_min, ord=2)
585-
586-
if dxbar_j_min_norm > dx_norm or np.linalg.norm(dx, ord=2) < self.eps:
587-
persistent_bound_violation = True
538+
if np.linalg.norm(dxbar_j_min) > dx_norm or np.linalg.norm(dx) < self.eps:
539+
persistent_violation = True
588540
except Exception:
589-
# For example, if self.F(x_j_min) fails
590-
persistent_bound_violation = True
591-
592-
# Check newly violated inactive constraints
593-
n_inactive = len(inactive_constraint_indices)
594-
c_x_j = self._constraints(x_j)[inactive_constraint_indices]
595-
if not np.all(c_x_j < self.eps):
596-
c_x = self._constraints(sol.x)[inactive_constraint_indices]
597-
violated_constraints = sorted(
598-
[
599-
(i, c_x[i] / (c_x[i] - c_x_j[i]))
600-
for i in range(n_inactive)
601-
if c_x_j[i] >= self.eps
602-
],
603-
key=lambda x: x[1],
604-
)
605-
lmda *= violated_constraints[0][1]
606-
x_j = sol.x + lmda * dx
541+
persistent_violation = True
542+
543+
# Check that inactive constraints are not now violated
544+
# If they are, rescale lambda
545+
if inactive_idx:
546+
c_inactive_new = self._constraints(x_j)[inactive_idx]
547+
if not np.all(c_inactive_new < self.eps):
548+
c_inactive_old = self._constraints(sol.x)[inactive_idx]
549+
violated_new = sorted(
550+
[
551+
(i, c_inactive_old[i] / (c_inactive_old[i] - c_inactive_new[i]))
552+
for i in range(len(inactive_idx))
553+
if c_inactive_new[i] >= self.eps
554+
],
555+
key=lambda t: t[1],
556+
)
557+
# Rescale lambda to maintain feasibility
558+
lmda *= violated_new[0][1]
559+
x_j = sol.x + lmda * dx
607560

608-
return lmda, x_j, dx, persistent_bound_violation
561+
return lmda, x_j, dx, persistent_violation
609562

610563
def _check_convergence(
611564
self,

0 commit comments

Comments
 (0)