Skip to content

Commit 2120d23

Browse files
committed
add typing
1 parent 1a5062a commit 2120d23

File tree

1 file changed

+70
-35
lines changed

1 file changed

+70
-35
lines changed

burnman/optimize/nonlinear_solvers.py

Lines changed: 70 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
from __future__ import annotations
2+
import numpy as np
3+
import numpy.typing as npt
4+
from typing import Optional, Callable, Tuple, Any
5+
16
import numpy as np
27
from scipy.linalg import lu_factor, lu_solve
38
from types import SimpleNamespace
@@ -120,17 +125,29 @@ def __init__(
120125
self.guess
121126
), "tol must either be a float or an array like guess."
122127

123-
def _constraints(self, x):
128+
def _constraints(self, x: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]:
124129
return np.dot(self.linear_constraints[0], x) + self.linear_constraints[1]
125130

126-
def _update_lmda(self, x, dx, h, lmda_bounds):
131+
def _update_lmda(
132+
self,
133+
x: npt.NDArray[np.float64],
134+
dx: npt.NDArray[np.float64],
135+
h: npt.NDArray[np.float64],
136+
lmda_bounds: tuple[float, float],
137+
) -> float:
127138
assert lmda_bounds[1] < 1.0 + self.eps
128139
assert lmda_bounds[0] > 1.0e-8 - self.eps
129140

130141
lmda_j = min(1.0 / (h + self.eps), lmda_bounds[1])
131142
return max(lmda_j, lmda_bounds[0])
132143

133-
def _solve_subject_to_constraints(self, x, jac_x, c_x, c_prime):
144+
def _solve_subject_to_constraints(
145+
self,
146+
x: npt.NDArray[np.float64],
147+
jac_x: npt.NDArray[np.float64],
148+
c_x: npt.NDArray[np.float64],
149+
c_prime: npt.NDArray[np.float64],
150+
) -> npt.NDArray[np.float64]:
134151
"""
135152
Solve a constrained Newton correction step using the method of
136153
Lagrange multipliers (KKT system).
@@ -209,7 +226,14 @@ def _solve_subject_to_constraints(self, x, jac_x, c_x, c_prime):
209226
dx = dx_lambda[:n_x]
210227
return x + dx, dx_lambda[n_x:], condition_number
211228

212-
def _constrain_step_to_feasible_region(self, x, dx, n_constraints, lmda, x_j):
229+
def _constrain_step_to_feasible_region(
230+
self,
231+
x: npt.NDArray[np.float64],
232+
dx: npt.NDArray[np.float64],
233+
n_constraints: int,
234+
lmda: float,
235+
x_j: npt.NDArray[np.float64],
236+
) -> tuple[npt.NDArray[np.float64], float]:
213237
"""
214238
Project a trial Newton step back into the feasible region defined
215239
by linear inequality constraints A.x + b <= 0.
@@ -267,8 +291,13 @@ def _constrain_step_to_feasible_region(self, x, dx, n_constraints, lmda, x_j):
267291
return lmda, x_j, violated_constraints
268292

269293
def _lagrangian_walk_along_constraints(
270-
self, sol, dx, luJ, dx_norm, violated_constraints
271-
):
294+
self,
295+
sol: Any,
296+
dx: npt.NDArray[np.float64],
297+
luJ: Any,
298+
dx_norm: float,
299+
violated_constraints: list[int],
300+
) -> tuple[npt.NDArray[np.float64], float]:
272301
"""
273302
Attempt to find a constrained Newton step when a step along the
274303
standard Newton direction would immediately violate active linear
@@ -354,7 +383,13 @@ def _lagrangian_walk_along_constraints(
354383

355384
return lmda, x_j, dx, persistent_bound_violation
356385

357-
def _check_convergence(self, dxbar_j, dx, lmda, lmda_bounds):
386+
def _check_convergence(
387+
self,
388+
dxbar_j: npt.NDArray[np.float64],
389+
dx: npt.NDArray[np.float64],
390+
lmda: float,
391+
lmda_bounds: tuple[float, float],
392+
) -> bool:
358393
if (
359394
all(np.abs(dxbar_j) < self.tol)
360395
and all(np.abs(dx) < np.sqrt(10.0 * self.tol))
@@ -365,21 +400,21 @@ def _check_convergence(self, dxbar_j, dx, lmda, lmda_bounds):
365400

366401
def _posteriori_loop(
367402
self,
368-
x,
369-
F,
370-
dx,
371-
dx_norm,
372-
dxbar_j,
373-
dxbar_j_norm,
374-
x_j,
375-
luJ,
376-
lmda,
377-
lmda_bounds,
378-
converged,
379-
minimum_lmda,
380-
persistent_bound_violation,
381-
require_posteriori_loop,
382-
):
403+
x: npt.NDArray[np.float64],
404+
F: npt.NDArray[np.float64],
405+
dx: npt.NDArray[np.float64],
406+
dx_norm: float,
407+
dxbar_j: npt.NDArray[np.float64],
408+
dxbar_j_norm: float,
409+
x_j: npt.NDArray[np.float64],
410+
luJ: Any,
411+
lmda: float,
412+
lmda_bounds: tuple[float, float],
413+
converged: bool,
414+
minimum_lmda: bool,
415+
persistent_bound_violation: bool,
416+
require_posteriori_loop: bool,
417+
) -> tuple[npt.NDArray[np.float64], float, bool, bool, bool]:
383418
"""
384419
Perform the a posteriori step-size control loop of Deuflhard's
385420
damped Newton method.
@@ -465,14 +500,14 @@ def _posteriori_loop(
465500

466501
def _termination_info(
467502
self,
468-
converged,
469-
minimum_lmda,
470-
persistent_bound_violation,
471-
lmda_bounds,
472-
n_it,
473-
max_iterations,
474-
violated_constraints,
475-
):
503+
converged: bool,
504+
minimum_lmda: bool,
505+
persistent_bound_violation: bool,
506+
lmda_bounds: tuple[float, float],
507+
n_it: int,
508+
max_iterations: int,
509+
violated_constraints: list[int],
510+
) -> tuple[int, str, bool]:
476511
if converged:
477512
return (
478513
True,
@@ -693,10 +728,10 @@ def damped_newton_solve(
693728
F,
694729
J,
695730
guess,
696-
tol=tol,
697-
max_iterations=max_iterations,
698-
lambda_bounds=lambda_bounds,
699-
linear_constraints=linear_constraints,
700-
store_iterates=store_iterates,
731+
tol,
732+
max_iterations,
733+
lambda_bounds,
734+
linear_constraints,
735+
store_iterates,
701736
)
702737
return solver.solve()

0 commit comments

Comments
 (0)