Skip to content

Commit f638e7f

Browse files
authored
Allow initial guesses to be passed to LSMR (#251)
* Make x0 explicit in LSMR * Make damp and x0 non-differentiable to postpone gradient derivations * Assert the LSMR starting vector is used correctly * Avoid linalg.lstsq-equivalence test for wide matrices * Improve formatting * Run equivalence test in double precision because scipy uses it
1 parent 1ace202 commit f638e7f

3 files changed

Lines changed: 111 additions & 54 deletions

File tree

matfree/backend/testing.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,5 +24,9 @@ def warns(warning, /):
2424
return pytest.warns(warning)
2525

2626

27+
def filterwarnings(warning, /):
28+
return pytest.mark.filterwarnings(warning)
29+
30+
2731
def case():
2832
return pytest_cases.case()

matfree/lstsq.py

Lines changed: 55 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ def lsmr(
2323
maxiter: int = 1_000_000,
2424
while_loop: Callable = control_flow.while_loop,
2525
custom_vjp: bool = True,
26-
damp: float = 0.0,
2726
):
2827
"""Construct an experimental implementation of LSMR.
2928
@@ -78,28 +77,48 @@ class State:
7877
# more often than not, the matvec is defined after the LSMR
7978
# solver has been constructed. So it's part of the run()
8079
# function, not the LSMR constructor.
81-
def run(vecmat, b, *vecmat_args):
80+
def run(vecmat, b, *vecmat_args, x0=None, damp=0.0):
81+
x_like = func.eval_shape(vecmat, b, *vecmat_args)
82+
(ncols,) = x_like.shape
83+
x = x0 if x0 is not None else np.zeros(ncols, dtype=b.dtype)
84+
85+
# Combine the lstsq_fun with a closure convert, because
86+
# typically, vecmat is a lambda function and if we want to
87+
# have explicit parameter-VJPs, all parameters need to be explicit.
88+
# This means that in this function here, we always use lstsq_public
89+
# (and return lstsq_public!), but provide lstsq_fun with the custom VJP.
90+
# Thereby, the function that gets the custom VJP is, from now on, only
91+
# called after a previous call to closure convert which 'fixes' all namespaces.
92+
vecmat_closure, args = func.closure_convert(
93+
lambda s: vecmat(s, *vecmat_args), b
94+
)
95+
return _run(vecmat_closure, b, args, x, damp)
96+
97+
def _run(vecmat, b, vecmat_args, x0, damp):
8298
def vecmat_noargs(v):
8399
return vecmat(v, *vecmat_args)
84100

85-
(ncols,) = func.eval_shape(vecmat, b, *vecmat_args).shape
101+
def matvec_noargs(w):
102+
matvec = func.linear_transpose(vecmat_noargs, b)
103+
(Aw,) = matvec(w)
104+
return Aw
86105

87-
state, normb, matvec_noargs = init(vecmat_noargs, b, ncols=ncols)
88-
step_fun = make_step(matvec_noargs, normb=normb)
106+
state, normb = init(matvec_noargs, b, x0)
107+
step_fun = make_step(matvec_noargs, normb=normb, damp=damp)
89108
cond_fun = make_cond_fun()
90109
state = while_loop(cond_fun, step_fun, state)
91110
stats_ = stats(state)
92111
return state.x, stats_
93112

94-
def init(vecmat, b, ncols: int):
113+
def init(matvec_noargs, b, x):
95114
normb = linalg.vector_norm(b)
96-
x = np.zeros(ncols, dtype=b.dtype)
97-
beta = normb
98115

99-
u = b
116+
Ax, vecmat_noargs = func.vjp(matvec_noargs, x)
117+
u = b - Ax
118+
beta = linalg.vector_norm(u)
100119
u = u / np.where(beta > 0, beta, 1.0)
101120

102-
v, matvec = func.vjp(vecmat, u)
121+
(v,) = vecmat_noargs(u)
103122
alpha = linalg.vector_norm(v)
104123
v = v / np.where(alpha > 0, alpha, 1)
105124
v = np.where(beta == 0, np.zeros_like(v), v)
@@ -115,7 +134,7 @@ def init(vecmat, b, ncols: int):
115134
sbar = 0.0
116135

117136
h = v
118-
hbar = np.zeros(ncols, dtype=b.dtype)
137+
hbar = np.zeros_like(x)
119138

120139
# Initialize variables for estimation of ||r||.
121140

@@ -176,9 +195,9 @@ def init(vecmat, b, ncols: int):
176195
istop=0,
177196
)
178197
state = tree.tree_map(np.asarray, state)
179-
return state, normb, lambda *a: matvec(*a)[0]
198+
return state, normb
180199

181-
def make_step(matvec, normb: float) -> Callable:
200+
def make_step(matvec, normb: float, damp: float) -> Callable:
182201
def step(state: State) -> State:
183202
# Perform the next step of the bidiagonalization
184203

@@ -338,7 +357,7 @@ def stats(state: State) -> dict:
338357
}
339358

340359
if custom_vjp:
341-
return _lstsq_custom_vjp(run)
360+
_run = _lstsq_custom_vjp(_run)
342361
return run
343362

344363

@@ -380,32 +399,23 @@ def _sym_ortho_3(a, b):
380399

381400

382401
def _lstsq_custom_vjp(lstsq_fun: Callable) -> Callable:
383-
# Combine the lstsq_fun with a closure convert, because
384-
# typically, vecmat is a lambda function and if we want to
385-
# have explicit parameter-VJPs, all parameters need to be explicit.
386-
# This means that in this function here, we always use lstsq_public
387-
# (and return lstsq_public!), but provide lstsq_fun with the custom VJP.
388-
# Thereby, the function that gets the custom VJP is, from now on, only
389-
# called after a previous call to closure convert which 'fixes' all namespaces.
390-
def lstsq_public(vecmat, rhs, *vecmat_args):
391-
vecmat_, args = func.closure_convert(lambda s: vecmat(s, *vecmat_args), rhs)
392-
return lstsq_fun(vecmat_, rhs, *args)
393-
394-
def lstsq_fwd(vecmat, rhs, *vecmat_args):
395-
x, stats = lstsq_public(vecmat, rhs, *vecmat_args)
396-
cache = {"x": x, "rhs": rhs, "vecmat_args": vecmat_args}
402+
def lstsq_fwd(vecmat, rhs, vecmat_args, x0, damp):
403+
x, stats = lstsq_fun(vecmat, rhs, vecmat_args, x0, damp)
404+
cache = {"x": x, "rhs": rhs, "vecmat_args": vecmat_args, "x0": x0, "damp": damp}
397405
return (x, stats), cache
398406

399-
def lstsq_rev(vecmat, cache, dmu_dx):
407+
def lstsq_rev(vecmat, x0, damp, cache, dmu_dx):
400408
dmu_dx, _ = dmu_dx
401409
x_like = func.eval_shape(vecmat, cache["rhs"], *cache["vecmat_args"])
402410
if cache["rhs"].size <= x_like.size:
403-
return lstsq_rev_wide(vecmat, cache, dmu_dx)
404-
return lstsq_rev_tall(vecmat, cache, dmu_dx)
411+
return lstsq_rev_wide(vecmat, x0, damp, cache, dmu_dx)
412+
return lstsq_rev_tall(vecmat, x0, damp, cache, dmu_dx)
405413

406-
def lstsq_rev_tall(vecmat, cache, dmu_dx):
414+
def lstsq_rev_tall(vecmat, x0, damp, cache, dmu_dx):
407415
x = cache["x"]
408416
rhs = cache["rhs"]
417+
x0 = cache["x0"]
418+
damp = cache["damp"]
409419
vecmat_args = cache["vecmat_args"]
410420

411421
def vecmat_noargs(z):
@@ -414,11 +424,12 @@ def vecmat_noargs(z):
414424
def matvec_noargs(z):
415425
return func.vjp(vecmat_noargs, rhs)[1](z)[0]
416426

417-
dmu_db = lstsq_public(matvec_noargs, dmu_dx)[0]
418-
p = lstsq_public(vecmat_noargs, -dmu_db)[0]
427+
x0_rev = np.zeros_like(rhs)
428+
dmu_db = lstsq_fun(matvec_noargs, dmu_dx, (), x0_rev, damp)[0]
429+
p = lstsq_fun(vecmat_noargs, -dmu_db, (), x0, damp)[0]
419430

420-
Ax_minus_b = matvec_noargs(x) - rhs
421431
Ap = matvec_noargs(p)
432+
Ax_minus_b = matvec_noargs(x) - rhs
422433

423434
@func.grad
424435
def grad_theta(theta):
@@ -427,9 +438,9 @@ def grad_theta(theta):
427438
return linalg.inner(rA, p) + linalg.inner(pAA, x)
428439

429440
dmu_dparams = grad_theta(vecmat_args)
430-
return dmu_db, *dmu_dparams
441+
return dmu_db, dmu_dparams
431442

432-
def lstsq_rev_wide(vecmat, cache, dmu_dx):
443+
def lstsq_rev_wide(vecmat, x0, damp, cache, dmu_dx):
433444
x = cache["x"]
434445
rhs = cache["rhs"]
435446
vecmat_args = cache["vecmat_args"]
@@ -441,11 +452,12 @@ def matvec_noargs(z):
441452
return func.linear_transpose(vecmat_noargs, rhs)(z)[0]
442453

443454
# Compute the Lagrange multiplier from the forward pass
444-
y = lstsq_public(matvec_noargs, x)[0]
455+
x0_rev = np.zeros_like(rhs)
456+
y = lstsq_fun(matvec_noargs, x, (), x0_rev, damp)[0]
445457

446458
# Compute the two solutions of the backward pass
447-
p = dmu_dx - lstsq_public(vecmat_noargs, matvec_noargs(dmu_dx))[0]
448-
q = lstsq_public(matvec_noargs, p - dmu_dx)[0]
459+
p = dmu_dx - lstsq_fun(vecmat_noargs, matvec_noargs(dmu_dx), (), x0, damp)[0]
460+
q = lstsq_fun(matvec_noargs, p - dmu_dx, (), x0_rev, damp)[0]
449461

450462
@func.grad
451463
def grad_theta(theta):
@@ -455,8 +467,8 @@ def grad_theta(theta):
455467

456468
grad_vecmat_args = grad_theta(vecmat_args)
457469
grad_rhs = -q
458-
return grad_rhs, *grad_vecmat_args
470+
return grad_rhs, grad_vecmat_args
459471

460-
lstsq_fun = func.custom_vjp(lstsq_fun, nondiff_argnums=(0,))
472+
lstsq_fun = func.custom_vjp(lstsq_fun, nondiff_argnums=(0, 3, 4))
461473
lstsq_fun.defvjp(lstsq_fwd, lstsq_rev) # type: ignore
462-
return lstsq_public
474+
return lstsq_fun

tests/test_lstsq.py

Lines changed: 52 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,7 @@
11
"""Tests for least-squares functionality."""
22

33
from matfree import lstsq, test_util
4-
from matfree.backend import func, linalg, prng, testing
5-
from matfree.backend.typing import Callable
6-
7-
8-
@testing.case()
9-
def case_lstsq_lsmr() -> Callable:
10-
return lstsq.lsmr(atol=1e-5, btol=1e-5, ctol=1e-5)
4+
from matfree.backend import config, func, linalg, np, prng, testing
115

126

137
def case_A_shape_wide() -> tuple:
@@ -22,9 +16,9 @@ def case_A_shape_square() -> tuple:
2216
return 3, 3
2317

2418

25-
@testing.parametrize_with_cases("lstsq_fun", cases=".", prefix="case_lstsq_")
2619
@testing.parametrize_with_cases("A_shape", cases=".", prefix="case_A_shape_")
27-
def test_value_and_grad_matches_numpy_lstsq(lstsq_fun: Callable, A_shape: tuple):
20+
@testing.parametrize("provide_x0", [True, False])
21+
def test_value_and_grad_matches_numpy_lstsq(A_shape: tuple, provide_x0: bool):
2822
key = prng.prng_key(1)
2923

3024
key, subkey = prng.split(key, 2)
@@ -34,6 +28,13 @@ def test_value_and_grad_matches_numpy_lstsq(lstsq_fun: Callable, A_shape: tuple)
3428
key, subkey = prng.split(key, num=2)
3529
dsol = prng.normal(subkey, shape=(A_shape[1],))
3630

31+
# If the matrix is wide, any nonzero initial guess affects the optimal solution
32+
# so the comparison to np.linalg.lstsq() is no longer valid. Thus, the caveat below.
33+
key, subkey = prng.split(key, num=2)
34+
is_wide = A_shape[1] > A_shape[0]
35+
x0_suggestion = prng.normal(subkey, shape=(A_shape[1],))
36+
x0 = x0_suggestion if provide_x0 and not is_wide else None
37+
3738
def lstsq_jnp(a, b):
3839
sol, *_ = linalg.lstsq(a, b)
3940
return sol
@@ -46,12 +47,52 @@ def vecmat(vector, p_as_list):
4647
return p.T @ vector
4748

4849
def lstsq_matfree(a, b):
49-
sol, _ = lstsq_fun(vecmat, a, b)
50+
lsmr = lstsq.lsmr(atol=1e-5, btol=1e-5, ctol=1e-5)
51+
sol, _ = lsmr(vecmat, a, b, x0=x0)
5052
return sol
5153

5254
received, received_vjp = func.vjp(lstsq_matfree, rhs, [matrix])
5355
drhs2, [dmatrix2] = received_vjp(dsol) # mind the order of rhs & matrix
5456

5557
test_util.assert_allclose(received, expected)
56-
test_util.assert_allclose(dmatrix1, dmatrix2)
5758
test_util.assert_allclose(drhs1, drhs2)
59+
test_util.assert_allclose(dmatrix1, dmatrix2)
60+
61+
62+
@testing.parametrize_with_cases("A_shape", cases=".", prefix="case_A_shape_")
63+
@testing.filterwarnings("ignore: overflow encountered in") # SciPy LSMR warns...
64+
def test_output_matches_original_scipy_lsmr(A_shape: tuple):
65+
"""Assert that the implementation of scipy's LSMR is matched exactly."""
66+
import numpy as onp # noqa: ICN001
67+
import scipy.sparse.linalg
68+
69+
# Scipy uses double precision, so we emulate this behaviour
70+
config.update("jax_enable_x64", True)
71+
72+
key = prng.prng_key(1)
73+
key, subkey = prng.split(key, 2)
74+
matrix = prng.normal(subkey, shape=A_shape)
75+
key, subkey = prng.split(key, 2)
76+
rhs = prng.normal(subkey, shape=(A_shape[0],))
77+
key, subkey = prng.split(key, num=2)
78+
x0 = prng.normal(subkey, shape=(A_shape[1],))
79+
key, subkey = prng.split(key, num=2)
80+
damp = (prng.uniform(subkey, shape=())) ** 2
81+
82+
# Our code
83+
lsmr = lstsq.lsmr(atol=1e-5, btol=1e-5, ctol=1e-5)
84+
sol, _ = lsmr(lambda v: matrix.T @ v, rhs, damp=damp, x0=x0)
85+
86+
# Original NumPy
87+
matrix = onp.asarray(matrix)
88+
rhs = onp.asarray(rhs)
89+
x0 = onp.asarray(x0)
90+
damp = onp.asarray(damp)
91+
sol2, *_ = scipy.sparse.linalg.lsmr(
92+
matrix, rhs, atol=1e-5, btol=1e-5, conlim=1e5, damp=damp, x0=x0
93+
)
94+
95+
assert np.allclose(sol, np.asarray(sol2))
96+
97+
# Scipy uses double precision, so we emulate this behaviour
98+
config.update("jax_enable_x64", False)

0 commit comments

Comments
 (0)