1313# limitations under the License.
1414
1515from collections .abc import Callable
16- from typing import Any , ClassVar , TYPE_CHECKING , TypeAlias
16+ from typing import Any , TypeAlias
1717
1818import equinox .internal as eqxi
1919import jax
2323from equinox .internal import ω
2424from jaxtyping import Array , PyTree , Scalar
2525
26-
27- if TYPE_CHECKING :
28- from typing import ClassVar as AbstractClassVar
29- else :
30- from equinox .internal import AbstractClassVar
31-
3226from .._misc import resolve_rcond , structure_equal , tree_where
3327from .._norm import max_norm , tree_dot
3428from .._operator import (
4135from .._solution import RESULTS
4236from .._solve import AbstractLinearSolver
4337from .misc import preconditioner_and_y0
38+ from .normal import Normal
4439
4540
4641_CGState : TypeAlias = tuple [AbstractLinearOperator , bool ]
4742
4843
4944# TODO(kidger): this is pretty slow to compile.
5045# - CG evaluates `operator.mv` three times.
51- # - Normal CG evaluates `operator.mv` seven (!) times.
5246# Possibly this can be cheapened a bit somehow?
53- class _AbstractCG (AbstractLinearSolver [_CGState ]):
47+ class CG (AbstractLinearSolver [_CGState ]):
48+ """Conjugate gradient solver for linear systems.
49+
50+ The operator should be positive or negative definite.
51+
52+ Equivalent to `scipy.sparse.linalg.cg`.
53+
54+ This supports the following `options` (as passed to
55+ `lx.linear_solve(..., options=...)`).
56+
57+ - `preconditioner`: A positive definite [`lineax.AbstractLinearOperator`][]
58+ to be used as preconditioner. Defaults to
59+ [`lineax.IdentityLinearOperator`][]. This method uses left preconditioning,
60+ so it is the preconditioned residual that is minimized, though the actual
61+ termination criteria uses the un-preconditioned residual.
62+
63+ - `y0`: The initial estimate of the solution to the linear system. Defaults to all
64+ zeros.
65+
66+ """
67+
5468 rtol : float
5569 atol : float
5670 norm : Callable [[PyTree ], Scalar ] = max_norm
5771 stabilise_every : int | None = 10
5872 max_steps : int | None = None
5973
60- _normal : AbstractClassVar [bool ]
61-
6274 def __check_init__ (self ):
6375 if isinstance (self .rtol , (int , float )) and self .rtol < 0 :
6476 raise ValueError ("Tolerances must be non-negative." )
@@ -75,18 +87,18 @@ def __check_init__(self):
def init(self, operator: AbstractLinearOperator, options: dict[str, Any]):
    """Validate `operator` and build the state consumed by `compute`.

    **Arguments:**

    - `operator`: the linear operator to solve against. Its input and output
        structures must be equal, and either `is_positive_semidefinite` or
        `is_negative_semidefinite` must hold for it.
    - `options`: not used by this method; discarded immediately.

    **Returns:**

    A 2-tuple `(operator, is_nsd)`. `is_nsd` is the result of
    `is_negative_semidefinite` on the operator that was passed in. The
    returned operator has been negated when `is_nsd` is true, and then
    linearised.

    **Raises:**

    - `ValueError`: if the operator's input and output structures differ, or
        if it is neither positive nor negative semidefinite.
    """
    del options
    is_nsd = is_negative_semidefinite(operator)
    if not structure_equal(operator.in_structure(), operator.out_structure()):
        raise ValueError(
            "`CG()` may only be used for linear solves with square matrices."
        )
    if not (is_positive_semidefinite(operator) | is_nsd):
        raise ValueError(
            "`CG()` may only be used for positive "
            "or negative definite linear operators"
        )
    if is_nsd:
        # Solve with the negated operator; `compute` flips the sign of the
        # solution back when `is_nsd` is set.
        operator = -operator
    # Linearise once here rather than on every solve.
    # NOTE(review): presumably this avoids repeating a forward pass on each
    # `operator.mv` call inside the iteration -- confirm against `linearise`.
    operator = linearise(operator)
    return operator, is_nsd
91103
92104 # This differs from jax.scipy.sparse.linalg.cg in:
@@ -103,46 +115,16 @@ def compute(
103115 ) -> tuple [PyTree [Array ], RESULTS , dict [str , Any ]]:
104116 operator , is_nsd = state
105117 preconditioner , y0 = preconditioner_and_y0 (operator , vector , options )
106- if self ._normal :
107- # Linearise if JacobianLinearOperator, to avoid computing the forward
108- # pass separately for mv and transpose_mv.
109- # This choice is "fast by default", even at the expense of memory.
110- # If a downstream user wants to avoid this then they can call
111- # ```
112- # linear_solve(
113- # conj(operator.T) @ operator, operator.mv(b), solver=CG()
114- # )
115- # ```
116- # directly.
117- operator = linearise (operator )
118- preconditioner = linearise (preconditioner )
119-
120- _mv = operator .mv
121- _transpose_mv = conj (operator .transpose ()).mv
122- _pmv = preconditioner .mv
123- _transpose_pmv = conj (preconditioner .transpose ()).mv
124-
125- def mv (vector : PyTree ) -> PyTree :
126- return _transpose_mv (_mv (vector ))
127-
128- def psolve (vector : PyTree ) -> PyTree :
129- return _pmv (_transpose_pmv (vector ))
130-
131- vector = _transpose_mv (vector )
132- else :
133- if not is_positive_semidefinite (preconditioner ):
134- raise ValueError ("The preconditioner must be positive definite." )
135- mv = operator .mv
136- psolve = preconditioner .mv
137-
118+ if not is_positive_semidefinite (preconditioner ):
119+ raise ValueError ("The preconditioner must be positive definite." )
138120 leaves , _ = jtu .tree_flatten (vector )
139121 size = sum (leaf .size for leaf in leaves )
140122 if self .max_steps is None :
141123 max_steps = 10 * size # Copied from SciPy!
142124 else :
143125 max_steps = self .max_steps
144- r0 = (vector ** ω - mv (y0 ) ** ω ).ω
145- p0 = psolve (r0 )
126+ r0 = (vector ** ω - operator . mv (y0 ) ** ω ).ω
127+ p0 = preconditioner . mv (r0 )
146128 gamma0 = tree_dot (p0 , r0 )
147129 rcond = resolve_rcond (None , size , size , jnp .result_type (* leaves ))
148130 initial_value = (
@@ -184,7 +166,7 @@ def cond_fun(value):
184166
185167 def body_fun (value ):
186168 _ , y , r , p , gamma , step = value
187- mat_p = mv (p )
169+ mat_p = operator . mv (p )
188170 inner_prod = tree_dot (mat_p , p )
189171 alpha = gamma / inner_prod
190172 alpha = tree_where (
@@ -199,7 +181,7 @@ def body_fun(value):
199181 # We compute the residual the "expensive" way every now and again, so as to
200182 # correct numerical rounding errors.
201183 def stable_r ():
202- return (vector ** ω - mv (y ) ** ω ).ω
184+ return (vector ** ω - operator . mv (y ) ** ω ).ω
203185
204186 def cheap_r ():
205187 return (r ** ω - alpha * mat_p ** ω ).ω
@@ -213,7 +195,7 @@ def cheap_r():
213195 stable_step = eqxi .nonbatchable (stable_step )
214196 r = lax .cond (stable_step , stable_r , cheap_r )
215197
216- z = psolve (r )
198+ z = preconditioner . mv (r )
217199 gamma_prev = gamma
218200 gamma = tree_dot (z , r )
219201 beta = gamma / gamma_prev
@@ -237,7 +219,7 @@ def cheap_r():
237219 RESULTS .successful ,
238220 )
239221
240- if is_nsd and not self . _normal :
222+ if is_nsd :
241223 solution = - (solution ** ω ).ω
242224 stats = {"num_steps" : num_steps , "max_steps" : self .max_steps }
243225 return solution , result , stats
@@ -258,66 +240,6 @@ def conj(self, state: _CGState, options: dict[str, Any]):
258240 conj_state = conj (psd_op ), is_nsd
259241 return conj_state , conj_options
260242
261-
262- class CG (_AbstractCG ):
263- """Conjugate gradient solver for linear systems.
264-
265- The operator should be positive or negative definite.
266-
267- Equivalent to `scipy.sparse.linalg.cg`.
268-
269- This supports the following `options` (as passed to
270- `lx.linear_solve(..., options=...)`).
271-
272- - `preconditioner`: A positive definite [`lineax.AbstractLinearOperator`][]
273- to be used as preconditioner. Defaults to
274- [`lineax.IdentityLinearOperator`][]. This method uses left preconditioning,
275- so it is the preconditioned residual that is minimized, though the actual
276- termination criteria uses the un-preconditioned residual.
277- - `y0`: The initial estimate of the solution to the linear system. Defaults to all
278- zeros.
279-
280- !!! info
281-
282-
283- """
284-
285- _normal : ClassVar [bool ] = False
286-
287- def assume_full_rank (self ):
288- return True
289-
290-
291- class NormalCG (_AbstractCG ):
292- """Conjugate gradient applied to the normal equations:
293-
294- `A^T A = A^T b`
295-
296- of a system of linear equations. Note that this squares the condition
297- number, so it is not recommended. This is a fast but potentially inaccurate
298- method, especially in 32 bit floating point precision.
299-
300- This can handle nonsquare operators provided they are full-rank.
301-
302- This supports the following `options` (as passed to
303- `lx.linear_solve(..., options=...)`).
304-
305- - `preconditioner`: A [`lineax.AbstractLinearOperator`][] to be used as
306- preconditioner. Defaults to [`lineax.IdentityLinearOperator`][]. Note that
307- the preconditioner should approximate the inverse of `A`, not the inverse of
308- `A^T A`. This method uses left preconditioning, so it is the preconditioned
309- residual that is minimized, though the actual termination criteria uses
310- the un-preconditioned residual.
311- - `y0`: The initial estimate of the solution to the linear system. Defaults to all
312- zeros.
313-
314- !!! info
315-
316-
317- """
318-
319- _normal : ClassVar [bool ] = True
320-
321243 def assume_full_rank (self ):
322244 return True
323245
@@ -339,19 +261,6 @@ def assume_full_rank(self):
339261 than this are required, then the solve is halted with a failure.
340262"""
341263
342- NormalCG .__init__ .__doc__ = r"""**Arguments:**
343264
344- - `rtol`: Relative tolerance for terminating solve.
345- - `atol`: Absolute tolerance for terminating solve.
346- - `norm`: The norm to use when computing whether the error falls within the tolerance.
347- Defaults to the max norm.
348- - `stabilise_every`: The conjugate gradient is an iterative method that produces
349- candidate solutions $x_1, x_2, \ldots$, and terminates once $r_i = \| Ax_i - b \|$
350- is small enough. For computational efficiency, the values $r_i$ are computed using
351- other internal quantities, and not by directly evaluating the formula above.
352- However, this computation of $r_i$ is susceptible to drift due to limited
353- floating-point precision. Every `stabilise_every` steps, then $r_i$ is computed
354- directly using the formula above, in order to stabilise the computation.
355- - `max_steps`: The maximum number of iterations to run the solver for. If more steps
356- than this are required, then the solve is halted with a failure.
357- """
def NormalCG(*args, **kwargs):
    """Build a `CG` solver wrapped in `Normal`.

    This is a thin factory: every positional and keyword argument is forwarded
    unchanged to `CG`, and the resulting solver is passed to `Normal`. Keyword
    arguments are forwarded so that calls which name the solver's fields,
    e.g. `NormalCG(rtol=..., atol=...)`, work as well as positional calls.

    **Arguments:**

    - `*args`, `**kwargs`: passed straight through to `CG` (whose fields are
        `rtol`, `atol`, `norm`, `stabilise_every` and `max_steps`).

    **Returns:**

    `Normal(CG(*args, **kwargs))`.
    """
    return Normal(CG(*args, **kwargs))
0 commit comments