Skip to content

Commit dddc513

Browse files
Merge pull request #176 from johannahaffner/lsmr-init-overflow
Decrease default value to prevent overflow in 32-bit.
2 parents 9d83182 + d297209 commit dddc513

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

lineax/_solver/lsmr.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from equinox.internal import ω
4343
from jaxtyping import Array, PyTree
4444

45+
from .._misc import complex_to_real_dtype
4546
from .._norm import two_norm
4647
from .._operator import AbstractLinearOperator, conj
4748
from .._solution import RESULTS
@@ -108,23 +109,30 @@ def compute(
108109
and self.rtol == 0
109110
)
110111

112+
dtype = jnp.result_type(
113+
*jtu.tree_leaves(vector),
114+
*jtu.tree_leaves(x),
115+
*jtu.tree_leaves(operator.in_structure()),
116+
)
117+
111118
m, n = operator.out_size(), operator.in_size()
112119
# number of singular values
113120
min_dim = min([m, n])
114121
if self.max_steps is None:
115-
max_steps = min_dim * 10 # for consistency with other iterative solvers
122+
# Set max_steps based on the minimum dimension + avoid numerical overflows
123+
# https://github.com/patrick-kidger/lineax/issues/175
124+
# https://github.com/patrick-kidger/lineax/issues/177
125+
int_dtype = jnp.dtype(f"int{complex_to_real_dtype(dtype).itemsize * 8}")
126+
if min_dim > (jnp.iinfo(int_dtype).max / 10):
127+
max_steps = jnp.iinfo(int_dtype).max
128+
else:
129+
max_steps = min_dim * 10 # for consistency with other iterative solvers
116130
else:
117131
max_steps = self.max_steps
118132

119133
if x is None:
120134
x = jtu.tree_map(jnp.zeros_like, operator.in_structure())
121135

122-
dtype = jnp.result_type(
123-
*jtu.tree_leaves(vector),
124-
*jtu.tree_leaves(x),
125-
*jtu.tree_leaves(operator.in_structure()),
126-
)
127-
128136
b = vector
129137
u = (ω(b) - ω(operator.mv(x))).ω
130138
normb = self.norm(b)
@@ -178,7 +186,7 @@ def beta_zero(beta, u):
178186
# variables for estimation of ||A|| and cond(A)
179187
normA2=alpha**2,
180188
maxrbar=0.0,
181-
minrbar=1e100,
189+
minrbar=jnp.finfo(dtype).max,
182190
condA=1.0,
183191
# variables for use in stopping rules
184192
istop=0,

0 commit comments

Comments
 (0)