@@ -23,7 +23,6 @@ def lsmr(
2323 maxiter : int = 1_000_000 ,
2424 while_loop : Callable = control_flow .while_loop ,
2525 custom_vjp : bool = True ,
26- damp : float = 0.0 ,
2726):
2827 """Construct an experimental implementation of LSMR.
2928
@@ -78,28 +77,48 @@ class State:
7877 # more often than not, the matvec is defined after the LSMR
7978 # solver has been constructed. So it's part of the run()
8079 # function, not the LSMR constructor.
81- def run (vecmat , b , * vecmat_args ):
80+ def run (vecmat , b , * vecmat_args , x0 = None , damp = 0.0 ):
81+ x_like = func .eval_shape (vecmat , b , * vecmat_args )
82+ (ncols ,) = x_like .shape
83+ x = x0 if x0 is not None else np .zeros (ncols , dtype = b .dtype )
84+
85+ # Combine the lstsq_fun with a closure convert, because
86+ # typically, vecmat is a lambda function and if we want to
87+ # have explicit parameter-VJPs, all parameters need to be explicit.
88+ # This means that in this function here, we always use lstsq_public
89+ # (and return lstsq_public!), but provide lstsq_fun with the custom VJP.
90+ # Thereby, the function that gets the custom VJP is, from now on, only
91+ # called after a previous call to closure convert which 'fixes' all namespaces.
92+ vecmat_closure , args = func .closure_convert (
93+ lambda s : vecmat (s , * vecmat_args ), b
94+ )
95+ return _run (vecmat_closure , b , args , x , damp )
96+
97+ def _run (vecmat , b , vecmat_args , x0 , damp ):
8298 def vecmat_noargs (v ):
8399 return vecmat (v , * vecmat_args )
84100
85- (ncols ,) = func .eval_shape (vecmat , b , * vecmat_args ).shape
101+ def matvec_noargs (w ):
102+ matvec = func .linear_transpose (vecmat_noargs , b )
103+ (Aw ,) = matvec (w )
104+ return Aw
86105
87- state , normb , matvec_noargs = init (vecmat_noargs , b , ncols = ncols )
88- step_fun = make_step (matvec_noargs , normb = normb )
106+ state , normb = init (matvec_noargs , b , x0 )
107+ step_fun = make_step (matvec_noargs , normb = normb , damp = damp )
89108 cond_fun = make_cond_fun ()
90109 state = while_loop (cond_fun , step_fun , state )
91110 stats_ = stats (state )
92111 return state .x , stats_
93112
94- def init (vecmat , b , ncols : int ):
113+ def init (matvec_noargs , b , x ):
95114 normb = linalg .vector_norm (b )
96- x = np .zeros (ncols , dtype = b .dtype )
97- beta = normb
98115
99- u = b
116+ Ax , vecmat_noargs = func .vjp (matvec_noargs , x )
117+ u = b - Ax
118+ beta = linalg .vector_norm (u )
100119 u = u / np .where (beta > 0 , beta , 1.0 )
101120
102- v , matvec = func . vjp ( vecmat , u )
121+ ( v ,) = vecmat_noargs ( u )
103122 alpha = linalg .vector_norm (v )
104123 v = v / np .where (alpha > 0 , alpha , 1 )
105124 v = np .where (beta == 0 , np .zeros_like (v ), v )
@@ -115,7 +134,7 @@ def init(vecmat, b, ncols: int):
115134 sbar = 0.0
116135
117136 h = v
118- hbar = np .zeros ( ncols , dtype = b . dtype )
137+ hbar = np .zeros_like ( x )
119138
120139 # Initialize variables for estimation of ||r||.
121140
@@ -176,9 +195,9 @@ def init(vecmat, b, ncols: int):
176195 istop = 0 ,
177196 )
178197 state = tree .tree_map (np .asarray , state )
179- return state , normb , lambda * a : matvec ( * a )[ 0 ]
198+ return state , normb
180199
181- def make_step (matvec , normb : float ) -> Callable :
200+ def make_step (matvec , normb : float , damp : float ) -> Callable :
182201 def step (state : State ) -> State :
183202 # Perform the next step of the bidiagonalization
184203
@@ -338,7 +357,7 @@ def stats(state: State) -> dict:
338357 }
339358
340359 if custom_vjp :
341- return _lstsq_custom_vjp (run )
360+ _run = _lstsq_custom_vjp (_run )
342361 return run
343362
344363
@@ -380,32 +399,23 @@ def _sym_ortho_3(a, b):
380399
381400
382401def _lstsq_custom_vjp (lstsq_fun : Callable ) -> Callable :
383- # Combine the lstsq_fun wiht a closure convert, because
384- # typically, vecmat is a lambda function and if we want to
385- # have explicit parameter-VJPs, all parameters need to be explicit.
386- # This means that in this function here, we always use lstsq_public
387- # (and return lstsq_public!), but provide lstsq_fun with the custom VJP.
388- # Thereby, the function that gets the custom VJP is, from now on, only
389- # called after a previous call to closure convert which 'fixes' all namespaces.
390- def lstsq_public (vecmat , rhs , * vecmat_args ):
391- vecmat_ , args = func .closure_convert (lambda s : vecmat (s , * vecmat_args ), rhs )
392- return lstsq_fun (vecmat_ , rhs , * args )
393-
394- def lstsq_fwd (vecmat , rhs , * vecmat_args ):
395- x , stats = lstsq_public (vecmat , rhs , * vecmat_args )
396- cache = {"x" : x , "rhs" : rhs , "vecmat_args" : vecmat_args }
402+ def lstsq_fwd (vecmat , rhs , vecmat_args , x0 , damp ):
403+ x , stats = lstsq_fun (vecmat , rhs , vecmat_args , x0 , damp )
404+ cache = {"x" : x , "rhs" : rhs , "vecmat_args" : vecmat_args , "x0" : x0 , "damp" : damp }
397405 return (x , stats ), cache
398406
399- def lstsq_rev (vecmat , cache , dmu_dx ):
407+ def lstsq_rev (vecmat , x0 , damp , cache , dmu_dx ):
400408 dmu_dx , _ = dmu_dx
401409 x_like = func .eval_shape (vecmat , cache ["rhs" ], * cache ["vecmat_args" ])
402410 if cache ["rhs" ].size <= x_like .size :
403- return lstsq_rev_wide (vecmat , cache , dmu_dx )
404- return lstsq_rev_tall (vecmat , cache , dmu_dx )
411+ return lstsq_rev_wide (vecmat , x0 , damp , cache , dmu_dx )
412+ return lstsq_rev_tall (vecmat , x0 , damp , cache , dmu_dx )
405413
406- def lstsq_rev_tall (vecmat , cache , dmu_dx ):
414+ def lstsq_rev_tall (vecmat , x0 , damp , cache , dmu_dx ):
407415 x = cache ["x" ]
408416 rhs = cache ["rhs" ]
417+ x0 = cache ["x0" ]
418+ damp = cache ["damp" ]
409419 vecmat_args = cache ["vecmat_args" ]
410420
411421 def vecmat_noargs (z ):
@@ -414,11 +424,12 @@ def vecmat_noargs(z):
414424 def matvec_noargs (z ):
415425 return func .vjp (vecmat_noargs , rhs )[1 ](z )[0 ]
416426
417- dmu_db = lstsq_public (matvec_noargs , dmu_dx )[0 ]
418- p = lstsq_public (vecmat_noargs , - dmu_db )[0 ]
427+ x0_rev = np .zeros_like (rhs )
428+ dmu_db = lstsq_fun (matvec_noargs , dmu_dx , (), x0_rev , damp )[0 ]
429+ p = lstsq_fun (vecmat_noargs , - dmu_db , (), x0 , damp )[0 ]
419430
420- Ax_minus_b = matvec_noargs (x ) - rhs
421431 Ap = matvec_noargs (p )
432+ Ax_minus_b = matvec_noargs (x ) - rhs
422433
423434 @func .grad
424435 def grad_theta (theta ):
@@ -427,9 +438,9 @@ def grad_theta(theta):
427438 return linalg .inner (rA , p ) + linalg .inner (pAA , x )
428439
429440 dmu_dparams = grad_theta (vecmat_args )
430- return dmu_db , * dmu_dparams
441+ return dmu_db , dmu_dparams
431442
432- def lstsq_rev_wide (vecmat , cache , dmu_dx ):
443+ def lstsq_rev_wide (vecmat , x0 , damp , cache , dmu_dx ):
433444 x = cache ["x" ]
434445 rhs = cache ["rhs" ]
435446 vecmat_args = cache ["vecmat_args" ]
@@ -441,11 +452,12 @@ def matvec_noargs(z):
441452 return func .linear_transpose (vecmat_noargs , rhs )(z )[0 ]
442453
443454 # Compute the Lagrange multiplier from the forward pass
444- y = lstsq_public (matvec_noargs , x )[0 ]
455+ x0_rev = np .zeros_like (rhs )
456+ y = lstsq_fun (matvec_noargs , x , (), x0_rev , damp )[0 ]
445457
446458 # Compute the two solutions of the backward pass
447- p = dmu_dx - lstsq_public (vecmat_noargs , matvec_noargs (dmu_dx ))[0 ]
448- q = lstsq_public (matvec_noargs , p - dmu_dx )[0 ]
459+ p = dmu_dx - lstsq_fun (vecmat_noargs , matvec_noargs (dmu_dx ), (), x0 , damp )[0 ]
460+ q = lstsq_fun (matvec_noargs , p - dmu_dx , (), x0_rev , damp )[0 ]
449461
450462 @func .grad
451463 def grad_theta (theta ):
@@ -455,8 +467,8 @@ def grad_theta(theta):
455467
456468 grad_vecmat_args = grad_theta (vecmat_args )
457469 grad_rhs = - q
458- return grad_rhs , * grad_vecmat_args
470+ return grad_rhs , grad_vecmat_args
459471
460- lstsq_fun = func .custom_vjp (lstsq_fun , nondiff_argnums = (0 ,))
472+ lstsq_fun = func .custom_vjp (lstsq_fun , nondiff_argnums = (0 , 3 , 4 ))
461473 lstsq_fun .defvjp (lstsq_fwd , lstsq_rev ) # type: ignore
462- return lstsq_public
474+ return lstsq_fun
0 commit comments