Skip to content

Commit 9a38f37

Browse files
johannahaffnerJohanna Haffner
authored andcommitted
Check finite input (#130)
* check for non-finite inputs if solution has NaN or inf * remove private import --------- Co-authored-by: Johanna Haffner <johanna.haffner@bsse.ethz.ch>
1 parent 5bf1627 commit 9a38f37

File tree

3 files changed

+37
-2
lines changed

3 files changed

+37
-2
lines changed

lineax/_solution.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,15 @@
4040
""".strip()
4141

4242

43+
_nonfinite_msg = """
44+
The linear solver received non-finite (NaN or inf) input and cannot determine a
45+
solution.
46+
47+
This means that you have a bug upstream of lineax and should check the inputs to
48+
`lineax.linear_solve` for non-finite values.
49+
""".strip()
50+
51+
4352
class RESULTS(eqxi.Enumeration):
4453
successful = ""
4554
max_steps_reached = (
@@ -55,6 +64,7 @@ class RESULTS(eqxi.Enumeration):
5564
"A stagnation in an iterative linear solve has occurred. Try increasing "
5665
"`stagnation_iters` or `restart`."
5766
)
67+
nonfinite_input = _nonfinite_msg
5868

5969

6070
class Solution(eqx.Module, strict=True):

lineax/_solve.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,16 +90,26 @@ def _linear_solve_impl(_, state, vector, options, solver, throw, *, check_closur
9090
out, name="lineax.linear_solve with respect to a closed-over value"
9191
)
9292
solution, result, stats = out
93-
has_nonfinites = jnp.any(
93+
has_nonfinite_output = jnp.any(
9494
jnp.stack(
9595
[jnp.any(jnp.invert(jnp.isfinite(x))) for x in jtu.tree_leaves(solution)]
9696
)
9797
)
9898
result = RESULTS.where(
99-
(result == RESULTS.successful) & has_nonfinites,
99+
(result == RESULTS.successful) & has_nonfinite_output,
100100
RESULTS.singular,
101101
result,
102102
)
103+
has_nonfinite_input = jnp.any(
104+
jnp.stack(
105+
[jnp.any(jnp.invert(jnp.isfinite(x))) for x in jtu.tree_leaves(vector)]
106+
)
107+
)
108+
result = RESULTS.where(
109+
(result == RESULTS.singular) & has_nonfinite_input,
110+
RESULTS.nonfinite_input,
111+
result,
112+
)
103113
if throw:
104114
solution, result, stats = result.error_if(
105115
(solution, result, stats),

tests/test_solve.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,3 +191,18 @@ def test_iterative_solver_max_steps_only(solver):
191191
rhs = jax.random.normal(jax.random.key(0), (SIZE,))
192192

193193
lx.linear_solve(poisson_operator, rhs, solver)
194+
195+
196+
def test_nonfinite_input():
197+
operator = lx.DiagonalLinearOperator((1.0, 1.0))
198+
vector = (1.0, jnp.inf)
199+
sol = lx.linear_solve(operator, vector, throw=False)
200+
assert sol.result == lx.RESULTS.nonfinite_input
201+
202+
vector = (1.0, jnp.nan)
203+
sol = lx.linear_solve(operator, vector, throw=False)
204+
assert sol.result == lx.RESULTS.nonfinite_input
205+
206+
vector = (jnp.nan, jnp.inf)
207+
sol = lx.linear_solve(operator, vector, throw=False)
208+
assert sol.result == lx.RESULTS.nonfinite_input

0 commit comments

Comments
 (0)