Skip to content

Commit 8d96e7c

Browse files
authored
Implement Normal, a solver applying an inner solver to the normal equations (#159)
1 parent a6bacf7 commit 8d96e7c

File tree

12 files changed

+291
-184
lines changed

12 files changed

+291
-184
lines changed

docs/api/solvers.md

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,16 @@ These are capable of solving ill-posed linear problems.
4242
members:
4343
- __init__
4444

45+
---
46+
47+
::: lineax.Normal
48+
options:
49+
members:
50+
- __init__
51+
4552
!!! info
4653

47-
In addition to these, `lineax.Diagonal(well_posed=False)` and [`lineax.NormalCG`][] (below) also support ill-posed problems.
54+
In addition to these, `lineax.Diagonal(well_posed=False)` (below) also supports ill-posed problems.
4855

4956
## Structure-exploiting solvers
5057

@@ -95,13 +102,6 @@ These solvers use only matrix-vector products, and do not require instantiating
95102

96103
---
97104

98-
::: lineax.NormalCG
99-
options:
100-
members:
101-
- __init__
102-
103-
---
104-
105105
::: lineax.BiCGStab
106106
options:
107107
members:

docs/examples/no_materialisation.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
"y = jnp.array([1.0, 2.0, 3.0])\n",
5555
"operator = lx.JacobianLinearOperator(f, y, args=None)\n",
5656
"vector = f(y, args=None)\n",
57-
"solver = lx.NormalCG(rtol=1e-6, atol=1e-6)\n",
57+
"solver = lx.Normal(lx.CG(rtol=1e-6, atol=1e-6))\n",
5858
"solution = lx.linear_solve(operator, vector, solver)"
5959
]
6060
},

lineax/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
GMRES as GMRES,
6161
LSMR as LSMR,
6262
LU as LU,
63+
Normal as Normal,
6364
NormalCG as NormalCG,
6465
QR as QR,
6566
SVD as SVD,

lineax/_solver/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .gmres import GMRES as GMRES
2020
from .lsmr import LSMR as LSMR
2121
from .lu import LU as LU
22+
from .normal import Normal as Normal
2223
from .qr import QR as QR
2324
from .svd import SVD as SVD
2425
from .triangular import Triangular as Triangular

lineax/_solver/cg.py

Lines changed: 45 additions & 136 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
from collections.abc import Callable
16-
from typing import Any, ClassVar, TYPE_CHECKING, TypeAlias
16+
from typing import Any, TypeAlias
1717

1818
import equinox.internal as eqxi
1919
import jax
@@ -23,12 +23,6 @@
2323
from equinox.internal import ω
2424
from jaxtyping import Array, PyTree, Scalar
2525

26-
27-
if TYPE_CHECKING:
28-
from typing import ClassVar as AbstractClassVar
29-
else:
30-
from equinox.internal import AbstractClassVar
31-
3226
from .._misc import resolve_rcond, structure_equal, tree_where
3327
from .._norm import max_norm, tree_dot
3428
from .._operator import (
@@ -41,24 +35,42 @@
4135
from .._solution import RESULTS
4236
from .._solve import AbstractLinearSolver
4337
from .misc import preconditioner_and_y0
38+
from .normal import Normal
4439

4540

4641
_CGState: TypeAlias = tuple[AbstractLinearOperator, bool]
4742

4843

4944
# TODO(kidger): this is pretty slow to compile.
5045
# - CG evaluates `operator.mv` three times.
51-
# - Normal CG evaluates `operator.mv` seven (!) times.
5246
# Possibly this can be cheapened a bit somehow?
53-
class _AbstractCG(AbstractLinearSolver[_CGState]):
47+
class CG(AbstractLinearSolver[_CGState]):
48+
"""Conjugate gradient solver for linear systems.
49+
50+
The operator should be positive or negative definite.
51+
52+
Equivalent to `scipy.sparse.linalg.cg`.
53+
54+
This supports the following `options` (as passed to
55+
`lx.linear_solve(..., options=...)`).
56+
57+
- `preconditioner`: A positive definite [`lineax.AbstractLinearOperator`][]
58+
to be used as preconditioner. Defaults to
59+
[`lineax.IdentityLinearOperator`][]. This method uses left preconditioning,
60+
so it is the preconditioned residual that is minimized, though the actual
61+
termination criterion uses the un-preconditioned residual.
62+
63+
- `y0`: The initial estimate of the solution to the linear system. Defaults to all
64+
zeros.
65+
66+
"""
67+
5468
rtol: float
5569
atol: float
5670
norm: Callable[[PyTree], Scalar] = max_norm
5771
stabilise_every: int | None = 10
5872
max_steps: int | None = None
5973

60-
_normal: AbstractClassVar[bool]
61-
6274
def __check_init__(self):
6375
if isinstance(self.rtol, (int, float)) and self.rtol < 0:
6476
raise ValueError("Tolerances must be non-negative.")
@@ -75,18 +87,18 @@ def __check_init__(self):
7587
def init(self, operator: AbstractLinearOperator, options: dict[str, Any]):
7688
del options
7789
is_nsd = is_negative_semidefinite(operator)
78-
if not self._normal:
79-
if not structure_equal(operator.in_structure(), operator.out_structure()):
80-
raise ValueError(
81-
"`CG()` may only be used for linear solves with square matrices."
82-
)
83-
if not (is_positive_semidefinite(operator) | is_nsd):
84-
raise ValueError(
85-
"`CG()` may only be used for positive "
86-
"or negative definite linear operators"
87-
)
88-
if is_nsd:
89-
operator = -operator
90+
if not structure_equal(operator.in_structure(), operator.out_structure()):
91+
raise ValueError(
92+
"`CG()` may only be used for linear solves with square matrices."
93+
)
94+
if not (is_positive_semidefinite(operator) | is_nsd):
95+
raise ValueError(
96+
"`CG()` may only be used for positive "
97+
"or negative definite linear operators"
98+
)
99+
if is_nsd:
100+
operator = -operator
101+
operator = linearise(operator)
90102
return operator, is_nsd
91103

92104
# This differs from jax.scipy.sparse.linalg.cg in:
@@ -103,46 +115,16 @@ def compute(
103115
) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]:
104116
operator, is_nsd = state
105117
preconditioner, y0 = preconditioner_and_y0(operator, vector, options)
106-
if self._normal:
107-
# Linearise if JacobianLinearOperator, to avoid computing the forward
108-
# pass separately for mv and transpose_mv.
109-
# This choice is "fast by default", even at the expense of memory.
110-
# If a downstream user wants to avoid this then they can call
111-
# ```
112-
# linear_solve(
113-
# conj(operator.T) @ operator, operator.mv(b), solver=CG()
114-
# )
115-
# ```
116-
# directly.
117-
operator = linearise(operator)
118-
preconditioner = linearise(preconditioner)
119-
120-
_mv = operator.mv
121-
_transpose_mv = conj(operator.transpose()).mv
122-
_pmv = preconditioner.mv
123-
_transpose_pmv = conj(preconditioner.transpose()).mv
124-
125-
def mv(vector: PyTree) -> PyTree:
126-
return _transpose_mv(_mv(vector))
127-
128-
def psolve(vector: PyTree) -> PyTree:
129-
return _pmv(_transpose_pmv(vector))
130-
131-
vector = _transpose_mv(vector)
132-
else:
133-
if not is_positive_semidefinite(preconditioner):
134-
raise ValueError("The preconditioner must be positive definite.")
135-
mv = operator.mv
136-
psolve = preconditioner.mv
137-
118+
if not is_positive_semidefinite(preconditioner):
119+
raise ValueError("The preconditioner must be positive definite.")
138120
leaves, _ = jtu.tree_flatten(vector)
139121
size = sum(leaf.size for leaf in leaves)
140122
if self.max_steps is None:
141123
max_steps = 10 * size # Copied from SciPy!
142124
else:
143125
max_steps = self.max_steps
144-
r0 = (vector**ω - mv(y0) ** ω).ω
145-
p0 = psolve(r0)
126+
r0 = (vector**ω - operator.mv(y0) ** ω).ω
127+
p0 = preconditioner.mv(r0)
146128
gamma0 = tree_dot(p0, r0)
147129
rcond = resolve_rcond(None, size, size, jnp.result_type(*leaves))
148130
initial_value = (
@@ -184,7 +166,7 @@ def cond_fun(value):
184166

185167
def body_fun(value):
186168
_, y, r, p, gamma, step = value
187-
mat_p = mv(p)
169+
mat_p = operator.mv(p)
188170
inner_prod = tree_dot(mat_p, p)
189171
alpha = gamma / inner_prod
190172
alpha = tree_where(
@@ -201,7 +183,7 @@ def body_fun(value):
201183
# We compute the residual the "expensive" way every now and again, so as to
202184
# correct numerical rounding errors.
203185
def stable_r():
204-
return (vector**ω - mv(y) ** ω).ω
186+
return (vector**ω - operator.mv(y) ** ω).ω
205187

206188
def cheap_r():
207189
return (r**ω - alpha * mat_p**ω).ω
@@ -215,7 +197,7 @@ def cheap_r():
215197
stable_step = eqxi.nonbatchable(stable_step)
216198
r = lax.cond(stable_step, stable_r, cheap_r)
217199

218-
z = psolve(r)
200+
z = preconditioner.mv(r)
219201
gamma_prev = gamma
220202
gamma = tree_dot(z, r)
221203
beta = gamma / gamma_prev
@@ -239,7 +221,7 @@ def cheap_r():
239221
RESULTS.successful,
240222
)
241223

242-
if is_nsd and not self._normal:
224+
if is_nsd:
243225
solution = -(solution**ω).ω
244226
stats = {"num_steps": num_steps, "max_steps": self.max_steps}
245227
return solution, result, stats
@@ -260,66 +242,6 @@ def conj(self, state: _CGState, options: dict[str, Any]):
260242
conj_state = conj(psd_op), is_nsd
261243
return conj_state, conj_options
262244

263-
264-
class CG(_AbstractCG):
265-
"""Conjugate gradient solver for linear systems.
266-
267-
The operator should be positive or negative definite.
268-
269-
Equivalent to `scipy.sparse.linalg.cg`.
270-
271-
This supports the following `options` (as passed to
272-
`lx.linear_solve(..., options=...)`).
273-
274-
- `preconditioner`: A positive definite [`lineax.AbstractLinearOperator`][]
275-
to be used as preconditioner. Defaults to
276-
[`lineax.IdentityLinearOperator`][]. This method uses left preconditioning,
277-
so it is the preconditioned residual that is minimized, though the actual
278-
termination criteria uses the un-preconditioned residual.
279-
- `y0`: The initial estimate of the solution to the linear system. Defaults to all
280-
zeros.
281-
282-
!!! info
283-
284-
285-
"""
286-
287-
_normal: ClassVar[bool] = False
288-
289-
def assume_full_rank(self):
290-
return True
291-
292-
293-
class NormalCG(_AbstractCG):
294-
"""Conjugate gradient applied to the normal equations:
295-
296-
`A^T A x = A^T b`
297-
298-
of a system of linear equations. Note that this squares the condition
299-
number, so it is not recommended. This is a fast but potentially inaccurate
300-
method, especially in 32 bit floating point precision.
301-
302-
This can handle nonsquare operators provided they are full-rank.
303-
304-
This supports the following `options` (as passed to
305-
`lx.linear_solve(..., options=...)`).
306-
307-
- `preconditioner`: A [`lineax.AbstractLinearOperator`][] to be used as
308-
preconditioner. Defaults to [`lineax.IdentityLinearOperator`][]. Note that
309-
the preconditioner should approximate the inverse of `A`, not the inverse of
310-
`A^T A`. This method uses left preconditioning, so it is the preconditioned
311-
residual that is minimized, though the actual termination criteria uses
312-
the un-preconditioned residual.
313-
- `y0`: The initial estimate of the solution to the linear system. Defaults to all
314-
zeros.
315-
316-
!!! info
317-
318-
319-
"""
320-
321-
_normal: ClassVar[bool] = True
322-
323245
def assume_full_rank(self):
324246
return True
325247

@@ -341,19 +263,6 @@ def assume_full_rank(self):
341263
than this are required, then the solve is halted with a failure.
342264
"""
343265

344-
NormalCG.__init__.__doc__ = r"""**Arguments:**
345266

346-
- `rtol`: Relative tolerance for terminating solve.
347-
- `atol`: Absolute tolerance for terminating solve.
348-
- `norm`: The norm to use when computing whether the error falls within the tolerance.
349-
Defaults to the max norm.
350-
- `stabilise_every`: The conjugate gradient is an iterative method that produces
351-
candidate solutions $x_1, x_2, \ldots$, and terminates once $r_i = \| Ax_i - b \|$
352-
is small enough. For computational efficiency, the values $r_i$ are computed using
353-
other internal quantities, and not by directly evaluating the formula above.
354-
However, this computation of $r_i$ is susceptible to drift due to limited
355-
floating-point precision. Every `stabilise_every` steps, then $r_i$ is computed
356-
directly using the formula above, in order to stabilise the computation.
357-
- `max_steps`: The maximum number of iterations to run the solver for. If more steps
358-
than this are required, then the solve is halted with a failure.
359-
"""
267+
def NormalCG(*args, **kwargs):
    """Backward-compatible constructor for conjugate gradient on the normal equations.

    Builds a `CG` solver from the given arguments and wraps it in `Normal`, i.e.
    this is shorthand for `Normal(CG(*args, **kwargs))`.

    **Arguments:**

    - `*args`, `**kwargs`: forwarded unchanged to `CG`. Keyword arguments must be
        forwarded as well as positional ones, since existing callers construct
        this solver as `NormalCG(rtol=..., atol=...)`.

    **Returns:**

    A `Normal` solver whose inner solver is the constructed `CG` instance.
    """
    return Normal(CG(*args, **kwargs))

lineax/_solver/cholesky.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def init(self, operator: AbstractLinearOperator, options: dict[str, Any]):
5757
if is_nsd:
5858
matrix = -matrix
5959
factor, lower = jsp.linalg.cho_factor(matrix)
60-
# Fix lower triangular for simplicity.
60+
# Fix upper triangular for simplicity.
6161
assert lower is False
6262
return factor, is_nsd
6363

0 commit comments

Comments
 (0)