|
42 | 42 | from equinox.internal import ω |
43 | 43 | from jaxtyping import Array, PyTree |
44 | 44 |
|
| 45 | +from .._misc import complex_to_real_dtype |
45 | 46 | from .._norm import two_norm |
46 | 47 | from .._operator import AbstractLinearOperator, conj |
47 | 48 | from .._solution import RESULTS |
@@ -108,23 +109,30 @@ def compute( |
108 | 109 | and self.rtol == 0 |
109 | 110 | ) |
110 | 111 |
|
| 112 | + dtype = jnp.result_type( |
| 113 | + *jtu.tree_leaves(vector), |
| 114 | + *jtu.tree_leaves(x), |
| 115 | + *jtu.tree_leaves(operator.in_structure()), |
| 116 | + ) |
| 117 | + |
111 | 118 | m, n = operator.out_size(), operator.in_size() |
112 | 119 | # number of singular values |
113 | 120 | min_dim = min([m, n]) |
114 | 121 | if self.max_steps is None: |
115 | | - max_steps = min_dim * 10 # for consistency with other iterative solvers |
| 122 | + # Set max_steps based on the minimum dimension + avoid numerical overflows |
| 123 | + # https://github.com/patrick-kidger/lineax/issues/175 |
| 124 | + # https://github.com/patrick-kidger/lineax/issues/177 |
| 125 | + int_dtype = jnp.dtype(f"int{complex_to_real_dtype(dtype).itemsize * 8}") |
| 126 | + if min_dim > (jnp.iinfo(int_dtype).max / 10): |
| 127 | + max_steps = jnp.iinfo(int_dtype).max |
| 128 | + else: |
| 129 | + max_steps = min_dim * 10 # for consistency with other iterative solvers |
116 | 130 | else: |
117 | 131 | max_steps = self.max_steps |
118 | 132 |
|
119 | 133 | if x is None: |
120 | 134 | x = jtu.tree_map(jnp.zeros_like, operator.in_structure()) |
121 | 135 |
|
122 | | - dtype = jnp.result_type( |
123 | | - *jtu.tree_leaves(vector), |
124 | | - *jtu.tree_leaves(x), |
125 | | - *jtu.tree_leaves(operator.in_structure()), |
126 | | - ) |
127 | | - |
128 | 136 | b = vector |
129 | 137 | u = (ω(b) - ω(operator.mv(x))).ω |
130 | 138 | normb = self.norm(b) |
@@ -178,7 +186,7 @@ def beta_zero(beta, u): |
178 | 186 | # variables for estimation of ||A|| and cond(A) |
179 | 187 | normA2=alpha**2, |
180 | 188 | maxrbar=0.0, |
181 | | - minrbar=1e100, |
| 189 | + minrbar=jnp.finfo(dtype).max, |
182 | 190 | condA=1.0, |
183 | 191 | # variables for use in stopping rules |
184 | 192 | istop=0, |
|
0 commit comments