Skip to content

Commit 6a506eb

Browse files
authored
Merge pull request #19 from ihmeuw-msca/feat/quit-early-when-nans
Quit early when nans encountered
2 parents 94976de + a2c5b03 commit 6a506eb

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "msca"
7-
version = "0.3.1"
7+
version = "0.3.2"
88
description = "Mathematical sciences and computational algorithms"
99
readme = "README.md"
1010
requires-python = ">=3.11,<3.13"

src/msca/optim/solver/ntcgsolver.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,16 +144,19 @@ def get_cg_maxiter(niter: int) -> int | None:
144144
step = 1.0
145145
niter = 0
146146
success = False
147+
failure = False
147148

148149
x_pair = deque([x], maxlen=2)
149150
g_pair = deque([g], maxlen=2)
150151

151152
if verbose:
152153
fun = self.fun(x)
153154
print(f"{type(self).__name__}:")
154-
print(f"{niter=:3d}, {fun=:.2e}, {gnorm=:.2e}, {xdiff=:.2e}, {step=:.2e}")
155+
print(
156+
f"{niter=:3d}, {fun=:.2e}, {gnorm=:.2e}, {xdiff=:.2e}, {step=:.2e}"
157+
)
155158

156-
while (not success) and (niter < maxiter):
159+
while (not success) and (not failure) and (niter < maxiter):
157160
niter += 1
158161

159162
# compute all directions
@@ -185,13 +188,19 @@ def cg_iter_counter(xk, cg_info):
185188
x_pair.append(x)
186189
g_pair.append(g)
187190

191+
fun = self.fun(x)
188192
if verbose:
189-
fun = self.fun(x)
190193
print(
191194
f"{niter=:3d}, {fun=:.2e}, {gnorm=:.2e}, {xdiff=:.2e}, "
192195
f"{step=:.2e}, cg_iter={cg_info['iter']}"
193196
)
194197
success = gnorm <= gtol or xdiff <= xtol
198+
failure = not (
199+
np.isfinite(fun)
200+
and np.isfinite(gnorm)
201+
and np.isfinite(xdiff)
202+
and np.isfinite(step)
203+
)
195204

196205
result = NTCGResult(
197206
x=x,

0 commit comments

Comments
 (0)