Skip to content

Commit 3d9e90f

Browse files
authored
Merge pull request #20 from ihmeuw-msca/bugfix/linesearch-revert
Bugfix/linesearch-revert
2 parents 6a506eb + 3e72253 commit 3d9e90f

File tree

11 files changed

+131
-63
lines changed

11 files changed

+131
-63
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "msca"
7-
version = "0.3.2"
7+
version = "0.3.3"
88
description = "Mathematical sciences and computational algorithms"
99
readme = "README.md"
1010
requires-python = ">=3.11,<3.13"

src/msca/c2fun/__init__.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,31 @@
1-
from .main import *
1+
from .main import (
2+
C2Fun,
3+
Identity,
4+
Exp,
5+
Log,
6+
Expit,
7+
Logit,
8+
Logerfc,
9+
identity,
10+
exp,
11+
log,
12+
expit,
13+
logit,
14+
logerfc,
15+
)
16+
17+
__all__ = [
18+
"C2Fun",
19+
"Identity",
20+
"Exp",
21+
"Log",
22+
"Expit",
23+
"Logit",
24+
"Logerfc",
25+
"identity",
26+
"exp",
27+
"log",
28+
"expit",
29+
"logit",
30+
"logerfc",
31+
]

src/msca/c2fun/main.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
from __future__ import annotations
3535

36-
from abc import ABC, abstractproperty, abstractstaticmethod
36+
from abc import ABC, abstractmethod
3737
from typing import Dict
3838

3939
import numpy as np
@@ -53,12 +53,14 @@ class C2Fun(ABC):
5353
5454
"""
5555

56-
@abstractproperty
56+
@property
57+
@abstractmethod
5758
def inv(self) -> C2Fun:
5859
"""The inverse of the function such that :code:`x = fun.inv(fun(x))`."""
5960
pass
6061

61-
@abstractstaticmethod
62+
@staticmethod
63+
@abstractmethod
6264
def fun(x: NDArray) -> NDArray:
6365
"""Implementation of the function.
6466
@@ -70,7 +72,8 @@ def fun(x: NDArray) -> NDArray:
7072
"""
7173
pass
7274

73-
@abstractstaticmethod
75+
@staticmethod
76+
@abstractmethod
7477
def dfun(x: NDArray) -> NDArray:
7578
"""Implementation of the derivative of the function.
7679
@@ -82,7 +85,8 @@ def dfun(x: NDArray) -> NDArray:
8285
"""
8386
pass
8487

85-
@abstractstaticmethod
88+
@staticmethod
89+
@abstractmethod
8690
def d2fun(x: NDArray) -> NDArray:
8791
"""Implementation of the second order derivative of the function.
8892
@@ -520,7 +524,9 @@ def dfun(x: NDArray) -> NDArray:
520524

521525
l_indices = x < 25
522526
y[l_indices] = (
523-
-2 * np.exp(-(x[l_indices] ** 2)) / (erfc(x[l_indices]) * np.sqrt(np.pi))
527+
-2
528+
* np.exp(-(x[l_indices] ** 2))
529+
/ (erfc(x[l_indices]) * np.sqrt(np.pi))
524530
)
525531

526532
r_indices = ~l_indices

src/msca/linalg/matrix.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,9 @@ def solve(self, x: ArrayLike, method: str = "", **kwargs) -> NDArray:
131131
elif method == "cg":
132132
result, info = sp.linalg.cg(self, x, **kwargs)
133133
if info > 0:
134-
raise RuntimeError(f"CG convergence not achieved. with {info=:}")
134+
raise RuntimeError(
135+
f"CG convergence not achieved. with {info=:}"
136+
)
135137
else:
136138
raise ValueError(f"{method=:} is not supported.")
137139
return result
@@ -193,7 +195,9 @@ def solve(self, x: NDArray, method: str = "", **kwargs) -> NDArray:
193195
elif method == "cg":
194196
result, info = sp.sparse.linalg.cg(self, x, **kwargs)
195197
if info > 0:
196-
raise RuntimeError(f"CG convergence not achieved. with {info=:}")
198+
raise RuntimeError(
199+
f"CG convergence not achieved. with {info=:}"
200+
)
197201
else:
198202
raise ValueError(f"{method=:} is not supported.")
199203
return result
@@ -255,7 +259,9 @@ def solve(self, x: NDArray, method: str = "", **kwargs) -> NDArray:
255259
elif method == "cg":
256260
result, info = sp.sparse.linalg.cg(self, x, **kwargs)
257261
if info > 0:
258-
raise RuntimeError(f"CG convergence not achieved. with {info=:}")
262+
raise RuntimeError(
263+
f"CG convergence not achieved. with {info=:}"
264+
)
259265
else:
260266
raise ValueError(f"{method=:} is not supported.")
261267
return result

src/msca/optim/line_search/armijo.py

Lines changed: 47 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -3,49 +3,54 @@
33
import numpy as np
44
from numpy.typing import NDArray
55

6+
67
def armijo_line_search(
7-
x,
8-
p,
9-
g,
10-
objective: Callable,
8+
gradient: Callable,
9+
x: NDArray,
10+
dx: NDArray,
1111
step_init: float = 1.0,
12-
alpha: float = 0.01,
13-
shrinkage: float = 0.5,
14-
):
15-
"""
16-
Performs an Armijo line search to select an appropriate step size along a given search direction.
17-
This function iteratively reduces the step size until the decrease in the objective function, along the direction of descent,
18-
satisfies the Armijo (sufficient decrease) condition. In each iteration, it checks whether the new point yields a value that is
19-
lower than the current value by a margin proportional to the step and directional derivative. If no satisfactory step size is found
20-
and the step size becomes exceedingly small (<= 1e-15), a RuntimeError is raised.
21-
Parameters:
22-
x (array_like): The current point or position in the parameter space.
23-
p (array_like): The descent direction along which the line search is performed.
24-
g (array_like): The gradient of the objective function evaluated at x.
25-
objective (Callable): A callable that computes the objective function value given a point.
26-
step_init (float, optional): The initial step size to start the line search. Default is 1.0.
27-
alpha (float, optional): The Armijo condition control parameter defining the sufficient decrease criterion. Default is 0.01.
28-
shrinkage (float, optional): The factor by which the step is multiplied to reduce the step size in each iteration. Default is 0.5.
29-
Returns:
30-
float: The step size that satisfies the Armijo sufficient decrease condition.
31-
Raises:
32-
RuntimeError: If the step size becomes too small (<= 1e-15) without satisfying the Armijo condition,
33-
indicating failure in finding a suitable step size.
34-
"""
35-
def sufficiently_improved(new_val, step):
36-
return (new_val - val <= -1 * alpha * step * np.dot(g, p)) and (
37-
not np.isnan(new_val)
38-
)
12+
step_const: float = 0.01,
13+
step_scale: float = 0.9,
14+
step_lb: float = 1e-3,
15+
) -> float:
16+
"""Armijo line search.
17+
18+
Parameters
19+
----------
20+
x
21+
A list of parameters, including x, s, and v, where s is the slackness
22+
variable and v is the dual variable for the constraints.
23+
dx
24+
A list of directions for the parameters.
25+
step_init
26+
Initial step size, by default 1.0.
27+
step_const
28+
Constant for the line search condition; larger values make the condition stricter, by
29+
default 0.01.
30+
step_scale
31+
Shrinkage factor for step size, by default 0.9.
32+
step_lb
33+
Lower bound of the step size; when the step size falls below this bound
34+
the line search will be terminated.
3935
36+
Returns
37+
-------
38+
float
39+
The step size in the given direction.
40+
41+
"""
4042
step = step_init
41-
new_x = x - step * p
42-
val, new_val = objective(x), objective(new_x)
43-
while (not sufficiently_improved(new_val, step)):
44-
if step <= 1e-15:
45-
raise RuntimeError(
46-
f"Line Search Failed, new_val = {new_val}, prev_val = {val}"
47-
)
48-
step *= shrinkage
49-
new_x = x - step * p
50-
new_val = objective(new_x)
51-
return step
43+
x_next = x + step * dx
44+
g_next = gradient(x_next)
45+
gnorm_curr = np.max(np.abs(gradient(x)))
46+
gnorm_next = np.max(np.abs(g_next))
47+
48+
while gnorm_next > (1 - step_const * step) * gnorm_curr:
49+
if step * step_scale < step_lb:
50+
break
51+
step *= step_scale
52+
x_next = x + step * dx
53+
g_next = gradient(x_next)
54+
gnorm_next = np.max(np.abs(g_next))
55+
56+
return step

src/msca/optim/prox/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
11
from .capped_simplex import proj_capped_simplex
2+
3+
__all__ = ["proj_capped_simplex"]

src/msca/optim/solver/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
from .ipsolver import IPSolver
22
from .ntcgsolver import NTCGSolver
33
from .ntsolver import NTSolver
4+
5+
__all__ = ["IPSolver", "NTCGSolver", "NTSolver"]

src/msca/optim/solver/ipsolver.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,12 @@ class IPSolver:
5959
"""
6060

6161
def __init__(
62-
self, fun: Callable, grad: Callable, hess: Callable, cmat: Matrix, cvec: NDArray
62+
self,
63+
fun: Callable,
64+
grad: Callable,
65+
hess: Callable,
66+
cmat: Matrix,
67+
cvec: NDArray,
6368
):
6469
self.fun = fun
6570
self.grad = grad
@@ -257,7 +262,9 @@ def minimize(
257262
dp = [dx, ds, dv]
258263

259264
# get step size
260-
step, p = self._update_params(p, dp, m, a_init, a_const, a_scale, a_lb)
265+
step, p = self._update_params(
266+
p, dp, m, a_init, a_const, a_scale, a_lb
267+
)
261268

262269
# update m
263270
if niter % m_freq == 0:

src/msca/optim/solver/ntcgsolver.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def minimize(
126126
precon_builder = precon_builder_map[precon_builder](
127127
**(precon_builder_options or {})
128128
)
129-
cg_options = cg_options or {"rtol":1e-2}
129+
cg_options = cg_options or {"rtol": 1e-2}
130130

131131
def get_cg_maxiter(niter: int) -> int | None:
132132
if cg_maxiter_init is None and cg_maxiter is None:
@@ -171,13 +171,17 @@ def cg_iter_counter(xk, cg_info):
171171
if precon_builder is not None:
172172
cg_options["M"] = precon_builder(x_pair, g_pair)
173173
cg_options["maxiter"] = get_cg_maxiter(niter)
174-
dx = cg(hess, -g,**cg_options)[0]
174+
dx = cg(hess, -g, **cg_options)[0]
175175
try:
176176
# get step size
177-
step = line_search(x, -dx,g,self.fun, **line_search_options)
178-
except:
177+
step = line_search(
178+
gradient=self.grad, x=x, dx=-dx, **line_search_options
179+
)
180+
except RuntimeError:
179181
dx = -g
180-
step = line_search(x, -dx,g,self.fun, **line_search_options)
182+
step = line_search(
183+
gradient=self.grad, x=x, dx=-dx, **line_search_options
184+
)
181185
x = x + step * dx
182186

183187
# update f and gnorm

src/msca/optim/solver/ntsolver.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,13 +168,17 @@ def minimize(
168168
if verbose:
169169
fun = self.fun(x)
170170
print(f"{type(self).__name__}:")
171-
print(f"{niter=:3d}, {fun=:.2e}, {gnorm=:.2e}, {xdiff=:.2e}, {step=:.2e}")
171+
print(
172+
f"{niter=:3d}, {fun=:.2e}, {gnorm=:.2e}, {xdiff=:.2e}, {step=:.2e}"
173+
)
172174

173175
while (not success) and (niter < max_iter):
174176
niter += 1
175177

176178
# compute all directions
177-
dx = -self.hess(x).solve(g, method=mat_solve_method, **mat_solve_options)
179+
dx = -self.hess(x).solve(
180+
g, method=mat_solve_method, **mat_solve_options
181+
)
178182

179183
# get step size
180184
step, x = self._update_params(x, dx, a_init, a_const, a_scale, a_lb)

0 commit comments

Comments (0)