1818import equinox .internal as eqxi
1919from jaxtyping import Array , PyTree
2020
21- from .._operator import (
22- conj ,
23- TaggedLinearOperator ,
24- )
21+ from .._operator import conj , linearise , TaggedLinearOperator
2522from .._solution import RESULTS
2623from .._solve import AbstractLinearOperator , AbstractLinearSolver
2724from .._tags import positive_semidefinite_tag
@@ -36,6 +33,7 @@ def normal_preconditioner_and_y0(options: dict[str, Any], tall: bool):
3633 inner_options = copy (options )
3734 del options
3835 if preconditioner is not None :
36+ preconditioner = linearise (preconditioner )
3937 if tall :
4038 inner_options ["preconditioner" ] = TaggedLinearOperator (
4139 preconditioner @ conj (preconditioner .transpose ()),
@@ -46,8 +44,8 @@ def normal_preconditioner_and_y0(options: dict[str, Any], tall: bool):
4644 conj (preconditioner .transpose ()) @ preconditioner ,
4745 positive_semidefinite_tag ,
4846 )
49- if preconditioner is not None and y0 is not None and not tall :
50- inner_options ["y0" ] = conj (preconditioner .transpose ()).mv (y0 )
47+ if y0 is not None :
48+ inner_options ["y0" ] = conj (preconditioner .transpose ()).mv (y0 )
5149 return inner_options
5250
5351
@@ -105,14 +103,17 @@ class Normal(
105103
106104 def init (self , operator , options ):
107105 tall = operator .out_size () >= operator .in_size ()
106+ # we are apply repeated mv's when constructing normal matrix
107+ # these cannot be parallelised so more efficient to linearise first
108+ lin_op = linearise (operator )
108109 if tall :
109- inner_operator = conj (operator .transpose ()) @ operator
110+ inner_operator = conj (lin_op .transpose ()) @ lin_op
110111 else :
111- inner_operator = operator @ conj (operator .transpose ())
112+ inner_operator = lin_op @ conj (lin_op .transpose ())
112113 inner_operator = TaggedLinearOperator (inner_operator , positive_semidefinite_tag )
113114 inner_options = normal_preconditioner_and_y0 (options , tall )
114115 inner_state = self .inner_solver .init (inner_operator , inner_options )
115- operator_conj_transpose = conj (operator .transpose ())
116+ operator_conj_transpose = conj (lin_op .transpose ())
116117 return inner_state , eqxi .Static (tall ), operator_conj_transpose , inner_options
117118
118119 def compute (
0 commit comments