Skip to content

Commit c14a3cf

Browse files
committed
move regularization
1 parent c45108a commit c14a3cf

File tree

1 file changed

+10
-12
lines changed

1 file changed

+10
-12
lines changed

burnman/optimize/nonlinear_solvers.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def __init__(
125125
lambda_bounds=lambda dx, x: (1.0e-8, 1.0),
126126
linear_constraints=(0.0, np.array([-1.0])),
127127
store_iterates: bool = False,
128-
regularization: float = 0.0,
128+
regularization: float = np.finfo(float).eps,
129129
cond_lu_thresh: float = 1e12,
130130
cond_lstsq_thresh: float = 1e15,
131131
constraint_thresh: float = 2 * np.finfo(float).eps,
@@ -170,7 +170,7 @@ def __init__(
170170
:type store_iterates: bool, optional
171171
172172
:param regularization: Regularization parameter for the KKT system
173-
in Lagrangian solves, defaults to 0.0.
173+
in Lagrangian solves, defaults to numpy float epsilon.
174174
:type regularization: float, optional
175175
176176
:param cond_lu_thresh: Condition number threshold below which LU decomposition
@@ -709,6 +709,11 @@ def solve(self) -> Solution:
709709
):
710710
sol.J = self.J(sol.x)
711711
condition_number = np.linalg.cond(sol.J)
712+
713+
# Regularize ill-conditioned Jacobian
714+
if condition_number > self.cond_lu_thresh:
715+
sol.J = sol.J + np.eye(sol.J.shape[0]) * self.regularization
716+
712717
luJ = lu_factor(sol.J)
713718
dx = lu_solve(luJ, -sol.F)
714719
dx_norm = np.linalg.norm(dx, ord=2)
@@ -746,19 +751,11 @@ def solve(self) -> Solution:
746751

747752
# Evaluate simplified Newton step
748753
F_j = self.F(x_j)
749-
750-
# Regularise ill-conditioned Jacobian
751-
if condition_number < self.cond_lu_thresh:
752-
dxbar_j = lu_solve(luJ, -F_j)
753-
else:
754-
J_reg = sol.J + np.eye(sol.J.shape[0]) * self.eps
755-
dxbar_j = lu_solve(lu_factor(J_reg), -F_j)
756-
754+
dxbar_j = lu_solve(luJ, -F_j)
757755
dxbar_j_norm = np.linalg.norm(dxbar_j, ord=2)
758756

759-
converged = self._check_convergence(dxbar_j, dx, lmda, lmda_bounds)
760-
761757
# Additional convergence check on F(x)
758+
converged = self._check_convergence(dxbar_j, dx, lmda, lmda_bounds)
762759
if converged and not all(np.abs(F_j) < self.F_tol):
763760
converged = False
764761

@@ -788,6 +785,7 @@ def solve(self) -> Solution:
788785
sol.iterates.append(sol.x, sol.F, lmda)
789786

790787
# Final adjustment for constraints
788+
# and recompute F, J, condition number without regularization
791789
if condition_number < self.max_condition_number:
792790
if converged and not persistent_bound_violation:
793791
sol.x = x_j + dxbar_j

0 commit comments

Comments
 (0)