Skip to content

Commit d0983e7

Browse files
Tidy up results.
1 parent a6bacf7 commit d0983e7

File tree

4 files changed

+24
-25
lines changed

4 files changed

+24
-25
lines changed

lineax/_solver/bicgstab.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -187,14 +187,14 @@ def body_fun(carry):
187187

188188
if self.max_steps is None:
189189
result = RESULTS.where(
190-
(num_steps == max_steps), RESULTS.singular, RESULTS.successful
190+
num_steps == max_steps, RESULTS.singular, RESULTS.successful
191191
)
192-
else:
192+
elif has_scale:
193193
result = RESULTS.where(
194-
(num_steps == self.max_steps),
195-
RESULTS.max_steps_reached if has_scale else RESULTS.successful,
196-
RESULTS.successful,
194+
num_steps == max_steps, RESULTS.max_steps_reached, RESULTS.successful
197195
)
196+
else:
197+
result = RESULTS.successful
198198
# breakdown is only an issue if we did not converge
199199
breakdown = breakdown_occurred(omega, alpha, rho) & not_converged(
200200
residual, diff, solution

lineax/_solver/cg.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -226,18 +226,16 @@ def cheap_r():
226226
cond_fun, body_fun, initial_value
227227
)
228228

229-
if (self.max_steps is None) or (max_steps < self.max_steps):
229+
if self.max_steps is None:
230230
result = RESULTS.where(
231-
num_steps == max_steps,
232-
RESULTS.singular,
233-
RESULTS.successful,
231+
num_steps == max_steps, RESULTS.singular, RESULTS.successful
234232
)
235-
else:
233+
elif has_scale:
236234
result = RESULTS.where(
237-
num_steps == max_steps,
238-
RESULTS.max_steps_reached if has_scale else RESULTS.successful,
239-
RESULTS.successful,
235+
num_steps == max_steps, RESULTS.max_steps_reached, RESULTS.successful
240236
)
237+
else:
238+
result = RESULTS.successful
241239

242240
if is_nsd and not self._normal:
243241
solution = -(solution**ω).ω

lineax/_solver/gmres.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -220,14 +220,15 @@ def body_fun(carry):
220220

221221
if self.max_steps is None:
222222
result = RESULTS.where(
223-
(num_steps == max_steps), RESULTS.singular, RESULTS.successful
223+
num_steps == max_steps, RESULTS.singular, RESULTS.successful
224224
)
225-
else:
225+
elif has_scale:
226226
result = RESULTS.where(
227-
(num_steps == self.max_steps),
228-
RESULTS.max_steps_reached if has_scale else RESULTS.successful,
229-
RESULTS.successful,
227+
num_steps == max_steps, RESULTS.max_steps_reached, RESULTS.successful
230228
)
229+
else:
230+
result = RESULTS.successful
231+
231232
result = RESULTS.where(
232233
stagnation_counter >= self.stagnation_iters, RESULTS.stagnation, result
233234
)

lineax/_solver/lsmr.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -327,19 +327,19 @@ def beta_zero(alpha, beta, u, v):
327327
"cond_A": loop_state["condA"],
328328
"norm_x": self.norm(loop_state["x"]),
329329
}
330-
if (self.max_steps is None) or (max_steps < self.max_steps):
330+
331+
if self.max_steps is None:
331332
result = RESULTS.where(
332-
loop_state["itn"] == max_steps,
333-
RESULTS.singular,
334-
RESULTS.successful,
333+
loop_state["itn"] == max_steps, RESULTS.singular, RESULTS.successful
335334
)
336-
else:
335+
elif has_scale:
337336
result = RESULTS.where(
338337
loop_state["itn"] == max_steps,
339-
RESULTS.max_steps_reached if has_scale else RESULTS.successful,
338+
RESULTS.max_steps_reached,
340339
RESULTS.successful,
341340
)
342-
341+
else:
342+
result = RESULTS.successful
343343
result = RESULTS.where(loop_state["istop"] < 3, RESULTS.successful, result)
344344
result = RESULTS.where(loop_state["istop"] == 3, RESULTS.conlim, result)
345345

0 commit comments

Comments
 (0)