16 changes: 8 additions & 8 deletions docs/api/solvers.md
@@ -42,9 +42,16 @@ These are capable of solving ill-posed linear problems.
members:
- __init__

---

::: lineax.Normal
options:
members:
- __init__

!!! info

In addition to these, `lineax.Diagonal(well_posed=False)` and [`lineax.NormalCG`][] (below) also support ill-posed problems.
In addition to these, `lineax.Diagonal(well_posed=False)` (below) also supports ill-posed problems.

## Structure-exploiting solvers

@@ -95,13 +102,6 @@ These solvers use only matrix-vector products, and do not require instantiating

---

::: lineax.NormalCG
options:
members:
- __init__

---

::: lineax.BiCGStab
options:
members:
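The documentation change above makes `lineax.Normal` the recommended way to handle ill-posed problems with an iterative solver. A minimal sketch of the new spelling in user code (the matrix and tolerances here are made up for illustration):

```python
import jax.numpy as jnp
import lineax as lx

# An overdetermined (nonsquare) system: three equations, two unknowns.
matrix = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
operator = lx.MatrixLinearOperator(matrix)
vector = jnp.array([1.0, 2.0, 3.0])

# Wrap a definite solver in `Normal` to solve the normal equations,
# rather than using the removed `lineax.NormalCG`.
solver = lx.Normal(lx.CG(rtol=1e-6, atol=1e-6))
solution = lx.linear_solve(operator, vector, solver)
```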
2 changes: 1 addition & 1 deletion docs/examples/no_materialisation.ipynb
@@ -54,7 +54,7 @@
"y = jnp.array([1.0, 2.0, 3.0])\n",
"operator = lx.JacobianLinearOperator(f, y, args=None)\n",
"vector = f(y, args=None)\n",
"solver = lx.NormalCG(rtol=1e-6, atol=1e-6)\n",
"solver = lx.Normal(lx.CG(rtol=1e-6, atol=1e-6))\n",
"solution = lx.linear_solve(operator, vector, solver)"
]
},
1 change: 1 addition & 0 deletions lineax/__init__.py
@@ -60,6 +60,7 @@
GMRES as GMRES,
LSMR as LSMR,
LU as LU,
Normal as Normal,
NormalCG as NormalCG,
Owner commented: For backwards compatibility let's still provide NormalCG. (Even if it's just a function rather than a class, even if it's removed from the documentation.)

QR as QR,
SVD as SVD,
1 change: 1 addition & 0 deletions lineax/_solver/__init__.py
@@ -19,6 +19,7 @@
from .gmres import GMRES as GMRES
from .lsmr import LSMR as LSMR
from .lu import LU as LU
from .normal import Normal as Normal
from .qr import QR as QR
from .svd import SVD as SVD
from .triangular import Triangular as Triangular
188 changes: 45 additions & 143 deletions lineax/_solver/cg.py
@@ -13,7 +13,7 @@
# limitations under the License.

from collections.abc import Callable
from typing import Any, ClassVar, TYPE_CHECKING, TypeAlias
from typing import Any, TypeAlias

import equinox.internal as eqxi
import jax
@@ -23,12 +23,6 @@
from equinox.internal import ω
from jaxtyping import Array, PyTree, Scalar


if TYPE_CHECKING:
from typing import ClassVar as AbstractClassVar
else:
from equinox.internal import AbstractClassVar

from .._misc import resolve_rcond, structure_equal, tree_where
from .._norm import max_norm, tree_dot
from .._operator import (
@@ -41,24 +35,42 @@
from .._solution import RESULTS
from .._solve import AbstractLinearSolver
from .misc import preconditioner_and_y0
from .normal import Normal


_CGState: TypeAlias = tuple[AbstractLinearOperator, bool]


# TODO(kidger): this is pretty slow to compile.
# - CG evaluates `operator.mv` three times.
# - Normal CG evaluates `operator.mv` seven (!) times.
# Possibly this can be cheapened a bit somehow?
class _AbstractCG(AbstractLinearSolver[_CGState]):
class CG(AbstractLinearSolver[_CGState]):
"""Conjugate gradient solver for linear systems.

The operator should be positive or negative definite.

Equivalent to `scipy.sparse.linalg.cg`.

This supports the following `options` (as passed to
`lx.linear_solve(..., options=...)`).

- `preconditioner`: A positive definite [`lineax.AbstractLinearOperator`][]
to be used as preconditioner. Defaults to
[`lineax.IdentityLinearOperator`][]. This method uses left preconditioning,
so it is the preconditioned residual that is minimized, though the actual
termination criterion uses the un-preconditioned residual.

- `y0`: The initial estimate of the solution to the linear system. Defaults to all
zeros.

"""

rtol: float
atol: float
norm: Callable[[PyTree], Scalar] = max_norm
stabilise_every: int | None = 10
max_steps: int | None = None

_normal: AbstractClassVar[bool]

def __check_init__(self):
if isinstance(self.rtol, (int, float)) and self.rtol < 0:
raise ValueError("Tolerances must be non-negative.")
@@ -75,18 +87,18 @@ def __check_init__(self):
def init(self, operator: AbstractLinearOperator, options: dict[str, Any]):
del options
is_nsd = is_negative_semidefinite(operator)
if not self._normal:
if not structure_equal(operator.in_structure(), operator.out_structure()):
raise ValueError(
"`CG()` may only be used for linear solves with square matrices."
)
if not (is_positive_semidefinite(operator) | is_nsd):
raise ValueError(
"`CG()` may only be used for positive "
"or negative definite linear operators"
)
if is_nsd:
operator = -operator
if not structure_equal(operator.in_structure(), operator.out_structure()):
raise ValueError(
"`CG()` may only be used for linear solves with square matrices."
)
if not (is_positive_semidefinite(operator) | is_nsd):
raise ValueError(
"`CG()` may only be used for positive "
"or negative definite linear operators"
)
if is_nsd:
operator = -operator
operator = linearise(operator)
return operator, is_nsd

# This differs from jax.scipy.sparse.linalg.cg in:
@@ -103,46 +115,16 @@ def compute(
) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]:
operator, is_nsd = state
preconditioner, y0 = preconditioner_and_y0(operator, vector, options)
if self._normal:
# Linearise if JacobianLinearOperator, to avoid computing the forward
# pass separately for mv and transpose_mv.
# This choice is "fast by default", even at the expense of memory.
# If a downstream user wants to avoid this then they can call
# ```
# linear_solve(
# conj(operator.T) @ operator, operator.mv(b), solver=CG()
# )
# ```
# directly.
operator = linearise(operator)
preconditioner = linearise(preconditioner)

_mv = operator.mv
_transpose_mv = conj(operator.transpose()).mv
_pmv = preconditioner.mv
_transpose_pmv = conj(preconditioner.transpose()).mv

def mv(vector: PyTree) -> PyTree:
return _transpose_mv(_mv(vector))

def psolve(vector: PyTree) -> PyTree:
return _pmv(_transpose_pmv(vector))

vector = _transpose_mv(vector)
else:
if not is_positive_semidefinite(preconditioner):
raise ValueError("The preconditioner must be positive definite.")
mv = operator.mv
psolve = preconditioner.mv

if not is_positive_semidefinite(preconditioner):
raise ValueError("The preconditioner must be positive definite.")
leaves, _ = jtu.tree_flatten(vector)
size = sum(leaf.size for leaf in leaves)
if self.max_steps is None:
max_steps = 10 * size # Copied from SciPy!
else:
max_steps = self.max_steps
r0 = (vector**ω - mv(y0) ** ω).ω
p0 = psolve(r0)
r0 = (vector**ω - operator.mv(y0) ** ω).ω
p0 = preconditioner.mv(r0)
gamma0 = tree_dot(p0, r0)
rcond = resolve_rcond(None, size, size, jnp.result_type(*leaves))
initial_value = (
@@ -184,7 +166,7 @@ def cond_fun(value):

def body_fun(value):
_, y, r, p, gamma, step = value
mat_p = mv(p)
mat_p = operator.mv(p)
inner_prod = tree_dot(mat_p, p)
alpha = gamma / inner_prod
alpha = tree_where(
@@ -199,7 +181,7 @@ def body_fun(value):
# We compute the residual the "expensive" way every now and again, so as to
# correct numerical rounding errors.
def stable_r():
return (vector**ω - mv(y) ** ω).ω
return (vector**ω - operator.mv(y) ** ω).ω

def cheap_r():
return (r**ω - alpha * mat_p**ω).ω
@@ -213,7 +195,7 @@ def cheap_r():
stable_step = eqxi.nonbatchable(stable_step)
r = lax.cond(stable_step, stable_r, cheap_r)

z = psolve(r)
z = preconditioner.mv(r)
gamma_prev = gamma
gamma = tree_dot(z, r)
beta = gamma / gamma_prev
@@ -237,7 +219,7 @@ def cheap_r():
RESULTS.successful,
)

if is_nsd and not self._normal:
if is_nsd:
solution = -(solution**ω).ω
stats = {"num_steps": num_steps, "max_steps": self.max_steps}
return solution, result, stats
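For reference, the `cheap_r`/`stable_r` split above implements the standard CG residual recurrence, periodically recomputed from the definition to correct floating-point drift:

$$r_{i+1} = r_i - \alpha_i A p_i \quad \text{(cheap, every step)}, \qquad r_{i+1} = b - A y_{i+1} \quad \text{(stable, every stabilise\_every steps)}.$$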
@@ -258,80 +240,13 @@ def conj(self, state: _CGState, options: dict[str, Any]):
conj_state = conj(psd_op), is_nsd
return conj_state, conj_options


class CG(_AbstractCG):
"""Conjugate gradient solver for linear systems.

The operator should be positive or negative definite.

Equivalent to `scipy.sparse.linalg.cg`.

This supports the following `options` (as passed to
`lx.linear_solve(..., options=...)`).

- `preconditioner`: A positive definite [`lineax.AbstractLinearOperator`][]
to be used as preconditioner. Defaults to
[`lineax.IdentityLinearOperator`][]. This method uses left preconditioning,
so it is the preconditioned residual that is minimized, though the actual
termination criteria uses the un-preconditioned residual.
- `y0`: The initial estimate of the solution to the linear system. Defaults to all
zeros.

!!! info


"""

_normal: ClassVar[bool] = False

def allow_dependent_columns(self, operator):
return False

def allow_dependent_rows(self, operator):
return False


class NormalCG(_AbstractCG):
"""Conjugate gradient applied to the normal equations:

`A^T A x = A^T b`

of a system of linear equations. Note that this squares the condition
number, so it is not recommended. This is a fast but potentially inaccurate
method, especially in 32 bit floating point precision.

This can handle nonsquare operators provided they are full-rank.

This supports the following `options` (as passed to
`lx.linear_solve(..., options=...)`).

- `preconditioner`: A [`lineax.AbstractLinearOperator`][] to be used as
preconditioner. Defaults to [`lineax.IdentityLinearOperator`][]. Note that
the preconditioner should approximate the inverse of `A`, not the inverse of
`A^T A`. This method uses left preconditioning, so it is the preconditioned
residual that is minimized, though the actual termination criterion uses
the un-preconditioned residual.
- `y0`: The initial estimate of the solution to the linear system. Defaults to all
zeros.

!!! info


"""

_normal: ClassVar[bool] = True

def allow_dependent_columns(self, operator):
rows = operator.out_size()
columns = operator.in_size()
return columns > rows

def allow_dependent_rows(self, operator):
rows = operator.out_size()
columns = operator.in_size()
return rows > columns
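The condition-number caveat in the removed docstring is the standard one: for full-rank $A$, $\kappa(A^T A) = \kappa(A)^2$, which is why solving the normal equations can be inaccurate, especially in 32-bit precision.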


CG.__init__.__doc__ = r"""**Arguments:**

- `rtol`: Relative tolerance for terminating solve.
@@ -349,19 +264,6 @@ def allow_dependent_rows(self, operator):
than this are required, then the solve is halted with a failure.
"""

NormalCG.__init__.__doc__ = r"""**Arguments:**

- `rtol`: Relative tolerance for terminating solve.
- `atol`: Absolute tolerance for terminating solve.
- `norm`: The norm to use when computing whether the error falls within the tolerance.
Defaults to the max norm.
- `stabilise_every`: The conjugate gradient is an iterative method that produces
candidate solutions $x_1, x_2, \ldots$, and terminates once $r_i = \| Ax_i - b \|$
is small enough. For computational efficiency, the values $r_i$ are computed using
other internal quantities, and not by directly evaluating the formula above.
However, this computation of $r_i$ is susceptible to drift due to limited
floating-point precision. Every `stabilise_every` steps, then $r_i$ is computed
directly using the formula above, in order to stabilise the computation.
- `max_steps`: The maximum number of iterations to run the solver for. If more steps
than this are required, then the solve is halted with a failure.
"""
def NormalCG(*args, **kwargs):
    # Backwards-compatible alias: forward keyword arguments too, since
    # NormalCG was previously called as e.g. NormalCG(rtol=..., atol=...).
    return Normal(CG(*args, **kwargs))
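This shim addresses the review comment above: `NormalCG` survives as a function rather than a class. A quick equivalence sketch (values made up; assumes the shim forwards keyword arguments as written):

```python
import jax.numpy as jnp
import lineax as lx

matrix = jnp.array([[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]])  # full-rank, nonsquare
operator = lx.MatrixLinearOperator(matrix)
vector = jnp.array([1.0, 2.0, 3.0])

# The old spelling now just forwards to the new one.
old = lx.linear_solve(operator, vector, lx.NormalCG(rtol=1e-6, atol=1e-6))
new = lx.linear_solve(operator, vector, lx.Normal(lx.CG(rtol=1e-6, atol=1e-6)))
assert jnp.allclose(old.value, new.value)
```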
2 changes: 1 addition & 1 deletion lineax/_solver/cholesky.py
@@ -57,7 +57,7 @@ def init(self, operator: AbstractLinearOperator, options: dict[str, Any]):
if is_nsd:
matrix = -matrix
factor, lower = jsp.linalg.cho_factor(matrix)
# Fix lower triangular for simplicity.
# Fix upper triangular for simplicity.
assert lower is False
return factor, is_nsd

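The one-word comment fix in `cholesky.py` is correct: with the default `lower=False`, `cho_factor` returns an upper-triangular factor. A small sketch confirming the convention:

```python
import jax.numpy as jnp
import jax.scipy as jsp

matrix = jnp.array([[4.0, 1.0], [1.0, 3.0]])
factor, lower = jsp.linalg.cho_factor(matrix)  # default: lower=False
assert lower is False  # mirrors the assert in Cholesky.init
# `factor` holds the upper-triangular Cholesky factor U, with U^T U = matrix.
```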