Skip to content

Commit eb69964

Browse files
committed
apply linearise in init for Normal and iterative solvers
1 parent c718ff8 commit eb69964

File tree

5 files changed

+18
-24
lines changed

5 files changed

+18
-24
lines changed

lineax/_solver/bicgstab.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from jaxtyping import Array, PyTree
2424

2525
from .._norm import max_norm, tree_dot
26-
from .._operator import AbstractLinearOperator, conj
26+
from .._operator import AbstractLinearOperator, conj, linearise
2727
from .._solution import RESULTS
2828
from .._solve import AbstractLinearSolver
2929
from .misc import preconditioner_and_y0
@@ -73,7 +73,7 @@ def init(self, operator: AbstractLinearOperator, options: dict[str, Any]):
7373
"`BiCGstab(..., normal=False)` may only be used for linear solves with "
7474
"square matrices."
7575
)
76-
return operator
76+
return linearise(operator)
7777

7878
def compute(
7979
self, state: _BiCGStabState, vector: PyTree[Array], options: dict[str, Any]

lineax/_solver/gmres.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,7 @@
2626

2727
from .._misc import structure_equal
2828
from .._norm import max_norm, two_norm
29-
from .._operator import (
30-
AbstractLinearOperator,
31-
conj,
32-
MatrixLinearOperator,
33-
)
29+
from .._operator import AbstractLinearOperator, conj, linearise, MatrixLinearOperator
3430
from .._solution import RESULTS
3531
from .._solve import AbstractLinearSolver, linear_solve
3632
from .misc import preconditioner_and_y0
@@ -86,7 +82,7 @@ def init(self, operator: AbstractLinearOperator, options: dict[str, Any]):
8682
"`GMRES(..., normal=False)` may only be used for linear solves with "
8783
"square matrices."
8884
)
89-
return operator
85+
return linearise(operator)
9086

9187
#
9288
# This differs from `jax.scipy.sparse.linalg.gmres` in a few ways:

lineax/_solver/lsmr.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444

4545
from .._misc import complex_to_real_dtype
4646
from .._norm import two_norm
47-
from .._operator import AbstractLinearOperator, conj
47+
from .._operator import AbstractLinearOperator, conj, linearise
4848
from .._solution import RESULTS
4949
from .._solve import AbstractLinearSolver
5050

@@ -89,7 +89,7 @@ def __check_init__(self):
8989
)
9090

9191
def init(self, operator: AbstractLinearOperator, options: dict[str, Any]):
92-
return operator
92+
return linearise(operator)
9393

9494
def compute(
9595
self,

lineax/_solver/misc.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,18 +24,15 @@
2424
from jaxtyping import Array, PyTree, Shaped
2525

2626
from .._misc import strip_weak_dtype, structure_equal
27-
from .._operator import (
28-
AbstractLinearOperator,
29-
IdentityLinearOperator,
30-
)
27+
from .._operator import AbstractLinearOperator, IdentityLinearOperator, linearise
3128

3229

3330
def preconditioner_and_y0(
3431
operator: AbstractLinearOperator, vector: PyTree[Array], options: dict[str, Any]
3532
):
3633
structure = operator.in_structure()
3734
try:
38-
preconditioner = options["preconditioner"]
35+
preconditioner = linearise(options["preconditioner"])
3936
except KeyError:
4037
preconditioner = IdentityLinearOperator(structure)
4138
else:

lineax/_solver/normal.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,7 @@
1818
import equinox.internal as eqxi
1919
from jaxtyping import Array, PyTree
2020

21-
from .._operator import (
22-
conj,
23-
TaggedLinearOperator,
24-
)
21+
from .._operator import conj, linearise, TaggedLinearOperator
2522
from .._solution import RESULTS
2623
from .._solve import AbstractLinearOperator, AbstractLinearSolver
2724
from .._tags import positive_semidefinite_tag
@@ -36,6 +33,7 @@ def normal_preconditioner_and_y0(options: dict[str, Any], tall: bool):
3633
inner_options = copy(options)
3734
del options
3835
if preconditioner is not None:
36+
preconditioner = linearise(preconditioner)
3937
if tall:
4038
inner_options["preconditioner"] = TaggedLinearOperator(
4139
preconditioner @ conj(preconditioner.transpose()),
@@ -46,8 +44,8 @@ def normal_preconditioner_and_y0(options: dict[str, Any], tall: bool):
4644
conj(preconditioner.transpose()) @ preconditioner,
4745
positive_semidefinite_tag,
4846
)
49-
if preconditioner is not None and y0 is not None and not tall:
50-
inner_options["y0"] = conj(preconditioner.transpose()).mv(y0)
47+
if y0 is not None:
48+
inner_options["y0"] = conj(preconditioner.transpose()).mv(y0)
5149
return inner_options
5250

5351

@@ -105,14 +103,17 @@ class Normal(
105103

106104
def init(self, operator, options):
107105
tall = operator.out_size() >= operator.in_size()
106+
# we are apply repeated mv's when constructing normal matrix
107+
# these cannot be parallelised so more efficient to linearise first
108+
lin_op = linearise(operator)
108109
if tall:
109-
inner_operator = conj(operator.transpose()) @ operator
110+
inner_operator = conj(lin_op.transpose()) @ lin_op
110111
else:
111-
inner_operator = operator @ conj(operator.transpose())
112+
inner_operator = lin_op @ conj(lin_op.transpose())
112113
inner_operator = TaggedLinearOperator(inner_operator, positive_semidefinite_tag)
113114
inner_options = normal_preconditioner_and_y0(options, tall)
114115
inner_state = self.inner_solver.init(inner_operator, inner_options)
115-
operator_conj_transpose = conj(operator.transpose())
116+
operator_conj_transpose = conj(lin_op.transpose())
116117
return inner_state, eqxi.Static(tall), operator_conj_transpose, inner_options
117118

118119
def compute(

0 commit comments

Comments
 (0)