1313# limitations under the License.
1414
1515from collections .abc import Callable
16- from typing import Any , ClassVar , TYPE_CHECKING , TypeAlias
16+ from typing import Any , TypeAlias
1717
1818import equinox .internal as eqxi
1919import jax
2323from equinox .internal import ω
2424from jaxtyping import Array , PyTree , Scalar
2525
26-
27- if TYPE_CHECKING :
28- from typing import ClassVar as AbstractClassVar
29- else :
30- from equinox .internal import AbstractClassVar
31-
3226from .._misc import resolve_rcond , structure_equal , tree_where
3327from .._norm import max_norm , tree_dot
3428from .._operator import (
4135from .._solution import RESULTS
4236from .._solve import AbstractLinearSolver
4337from .misc import preconditioner_and_y0
38+ from .normal import Normal
4439
4540
4641_CGState : TypeAlias = tuple [AbstractLinearOperator , bool ]
4742
4843
4944# TODO(kidger): this is pretty slow to compile.
5045# - CG evaluates `operator.mv` three times.
51- # - Normal CG evaluates `operator.mv` seven (!) times.
5246# Possibly this can be cheapened a bit somehow?
53- class _AbstractCG (AbstractLinearSolver [_CGState ]):
47+ class CG (AbstractLinearSolver [_CGState ]):
48+ """Conjugate gradient solver for linear systems.
49+
50+ The operator should be positive or negative definite.
51+
52+ Equivalent to `scipy.sparse.linalg.cg`.
53+
54+ This supports the following `options` (as passed to
55+ `lx.linear_solve(..., options=...)`).
56+
57+ - `preconditioner`: A positive definite [`lineax.AbstractLinearOperator`][]
58+ to be used as preconditioner. Defaults to
59+ [`lineax.IdentityLinearOperator`][]. This method uses left preconditioning,
60+ so it is the preconditioned residual that is minimized, though the actual
61+ termination criteria uses the un-preconditioned residual.
62+
63+ - `y0`: The initial estimate of the solution to the linear system. Defaults to all
64+ zeros.
65+
66+ """
67+
5468 rtol : float
5569 atol : float
5670 norm : Callable [[PyTree ], Scalar ] = max_norm
5771 stabilise_every : int | None = 10
5872 max_steps : int | None = None
5973
60- _normal : AbstractClassVar [bool ]
61-
6274 def __check_init__ (self ):
6375 if isinstance (self .rtol , (int , float )) and self .rtol < 0 :
6476 raise ValueError ("Tolerances must be non-negative." )
@@ -75,18 +87,18 @@ def __check_init__(self):
def init(self, operator: AbstractLinearOperator, options: dict[str, Any]):
    """Validate `operator` and build the solver state for CG.

    **Arguments:**

    - `operator`: the linear operator to solve against. Its input and output
      structures must be equal, and it must be tagged as positive or negative
      semidefinite.
    - `options`: ignored by this method (deleted immediately).

    **Returns:**

    A 2-tuple `(operator, is_nsd)`. `is_nsd` records whether the operator that
    was passed in is negative semidefinite; when it is, the returned operator
    is the negation of the one passed in. The returned operator has been passed
    through `linearise`.

    **Raises:**

    - `ValueError`: if the operator's input and output structures differ.
    - `ValueError`: if the operator is neither positive nor negative
      semidefinite.
    """
    del options
    # Evaluated first so the flag is available both for the definiteness check
    # and for the sign flip below.
    is_nsd = is_negative_semidefinite(operator)

    shapes_match = structure_equal(
        operator.in_structure(), operator.out_structure()
    )
    if not shapes_match:
        raise ValueError(
            "`CG()` may only be used for linear solves with square matrices."
        )

    is_definite = is_positive_semidefinite(operator) | is_nsd
    if not is_definite:
        raise ValueError(
            "`CG()` may only be used for positive "
            "or negative definite linear operators "
        )

    # A negative operator is flipped in sign so the iteration always works on a
    # positive one; `is_nsd` is returned so the caller can undo the flip.
    operator = linearise(-operator if is_nsd else operator)
    return operator, is_nsd
91103
92104 # This differs from jax.scipy.sparse.linalg.cg in:
@@ -103,46 +115,16 @@ def compute(
103115 ) -> tuple [PyTree [Array ], RESULTS , dict [str , Any ]]:
104116 operator , is_nsd = state
105117 preconditioner , y0 = preconditioner_and_y0 (operator , vector , options )
106- if self ._normal :
107- # Linearise if JacobianLinearOperator, to avoid computing the forward
108- # pass separately for mv and transpose_mv.
109- # This choice is "fast by default", even at the expense of memory.
110- # If a downstream user wants to avoid this then they can call
111- # ```
112- # linear_solve(
113- # conj(operator.T) @ operator, operator.mv(b), solver=CG()
114- # )
115- # ```
116- # directly.
117- operator = linearise (operator )
118- preconditioner = linearise (preconditioner )
119-
120- _mv = operator .mv
121- _transpose_mv = conj (operator .transpose ()).mv
122- _pmv = preconditioner .mv
123- _transpose_pmv = conj (preconditioner .transpose ()).mv
124-
125- def mv (vector : PyTree ) -> PyTree :
126- return _transpose_mv (_mv (vector ))
127-
128- def psolve (vector : PyTree ) -> PyTree :
129- return _pmv (_transpose_pmv (vector ))
130-
131- vector = _transpose_mv (vector )
132- else :
133- if not is_positive_semidefinite (preconditioner ):
134- raise ValueError ("The preconditioner must be positive definite." )
135- mv = operator .mv
136- psolve = preconditioner .mv
137-
118+ if not is_positive_semidefinite (preconditioner ):
119+ raise ValueError ("The preconditioner must be positive definite." )
138120 leaves , _ = jtu .tree_flatten (vector )
139121 size = sum (leaf .size for leaf in leaves )
140122 if self .max_steps is None :
141123 max_steps = 10 * size # Copied from SciPy!
142124 else :
143125 max_steps = self .max_steps
144- r0 = (vector ** ω - mv (y0 ) ** ω ).ω
145- p0 = psolve (r0 )
126+ r0 = (vector ** ω - operator . mv (y0 ) ** ω ).ω
127+ p0 = preconditioner . mv (r0 )
146128 gamma0 = tree_dot (p0 , r0 )
147129 rcond = resolve_rcond (None , size , size , jnp .result_type (* leaves ))
148130 initial_value = (
@@ -184,7 +166,7 @@ def cond_fun(value):
184166
185167 def body_fun (value ):
186168 _ , y , r , p , gamma , step = value
187- mat_p = mv (p )
169+ mat_p = operator . mv (p )
188170 inner_prod = tree_dot (mat_p , p )
189171 alpha = gamma / inner_prod
190172 alpha = tree_where (
@@ -201,7 +183,7 @@ def body_fun(value):
201183 # We compute the residual the "expensive" way every now and again, so as to
202184 # correct numerical rounding errors.
203185 def stable_r ():
204- return (vector ** ω - mv (y ) ** ω ).ω
186+ return (vector ** ω - operator . mv (y ) ** ω ).ω
205187
206188 def cheap_r ():
207189 return (r ** ω - alpha * mat_p ** ω ).ω
@@ -215,7 +197,7 @@ def cheap_r():
215197 stable_step = eqxi .nonbatchable (stable_step )
216198 r = lax .cond (stable_step , stable_r , cheap_r )
217199
218- z = psolve (r )
200+ z = preconditioner . mv (r )
219201 gamma_prev = gamma
220202 gamma = tree_dot (z , r )
221203 beta = gamma / gamma_prev
@@ -239,7 +221,7 @@ def cheap_r():
239221 RESULTS .successful ,
240222 )
241223
242- if is_nsd and not self . _normal :
224+ if is_nsd :
243225 solution = - (solution ** ω ).ω
244226 stats = {"num_steps" : num_steps , "max_steps" : self .max_steps }
245227 return solution , result , stats
@@ -260,66 +242,6 @@ def conj(self, state: _CGState, options: dict[str, Any]):
260242 conj_state = conj (psd_op ), is_nsd
261243 return conj_state , conj_options
262244
263-
264- class CG (_AbstractCG ):
265- """Conjugate gradient solver for linear systems.
266-
267- The operator should be positive or negative definite.
268-
269- Equivalent to `scipy.sparse.linalg.cg`.
270-
271- This supports the following `options` (as passed to
272- `lx.linear_solve(..., options=...)`).
273-
274- - `preconditioner`: A positive definite [`lineax.AbstractLinearOperator`][]
275- to be used as preconditioner. Defaults to
276- [`lineax.IdentityLinearOperator`][]. This method uses left preconditioning,
277- so it is the preconditioned residual that is minimized, though the actual
278- termination criteria uses the un-preconditioned residual.
279- - `y0`: The initial estimate of the solution to the linear system. Defaults to all
280- zeros.
281-
282- !!! info
283-
284-
285- """
286-
287- _normal : ClassVar [bool ] = False
288-
289- def assume_full_rank (self ):
290- return True
291-
292-
293- class NormalCG (_AbstractCG ):
294- """Conjugate gradient applied to the normal equations:
295-
296- `A^T A = A^T b`
297-
298- of a system of linear equations. Note that this squares the condition
299- number, so it is not recommended. This is a fast but potentially inaccurate
300- method, especially in 32 bit floating point precision.
301-
302- This can handle nonsquare operators provided they are full-rank.
303-
304- This supports the following `options` (as passed to
305- `lx.linear_solve(..., options=...)`).
306-
307- - `preconditioner`: A [`lineax.AbstractLinearOperator`][] to be used as
308- preconditioner. Defaults to [`lineax.IdentityLinearOperator`][]. Note that
309- the preconditioner should approximate the inverse of `A`, not the inverse of
310- `A^T A`. This method uses left preconditioning, so it is the preconditioned
311- residual that is minimized, though the actual termination criteria uses
312- the un-preconditioned residual.
313- - `y0`: The initial estimate of the solution to the linear system. Defaults to all
314- zeros.
315-
316- !!! info
317-
318-
319- """
320-
321- _normal : ClassVar [bool ] = True
322-
def assume_full_rank(self):
    """Always return `True`.

    NOTE(review): presumably signals to the caller that this solver treats the
    operator as full rank -- confirm against `AbstractLinearSolver`.
    """
    return True
325247
@@ -341,19 +263,6 @@ def assume_full_rank(self):
341263 than this are required, then the solve is halted with a failure.
342264"""
343265
344- NormalCG .__init__ .__doc__ = r"""**Arguments:**
345266
346- - `rtol`: Relative tolerance for terminating solve.
347- - `atol`: Absolute tolerance for terminating solve.
348- - `norm`: The norm to use when computing whether the error falls within the tolerance.
349- Defaults to the max norm.
350- - `stabilise_every`: The conjugate gradient is an iterative method that produces
351- candidate solutions $x_1, x_2, \ldots$, and terminates once $r_i = \| Ax_i - b \|$
352- is small enough. For computational efficiency, the values $r_i$ are computed using
353- other internal quantities, and not by directly evaluating the formula above.
354- However, this computation of $r_i$ is susceptible to drift due to limited
355- floating-point precision. Every `stabilise_every` steps, then $r_i$ is computed
356- directly using the formula above, in order to stabilise the computation.
357- - `max_steps`: The maximum number of iterations to run the solver for. If more steps
358- than this are required, then the solve is halted with a failure.
359- """
def NormalCG(*args, **kwargs):
    """Construct a `CG` solver and wrap it in `Normal`.

    **Arguments:**

    - `*args`, `**kwargs`: forwarded unchanged to `CG`, whose fields are
      `rtol`, `atol`, `norm`, `stabilise_every` and `max_steps`. Keyword
      arguments are forwarded as well as positional ones, so calls such as
      `NormalCG(rtol=1e-6, atol=1e-6)` work.

    **Returns:**

    `Normal(CG(*args, **kwargs))`.

    NOTE(review): presumably `Normal` applies the wrapped solver to the normal
    equations of the system -- confirm against `.normal.Normal`.
    """
    return Normal(CG(*args, **kwargs))
0 commit comments