Implement Normal, a solver applying an inner solver to the normal equations #168
adconner wants to merge 1 commit into patrick-kidger:main
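For context on the title: the normal-equations approach obtains a least-squares solution of Ax = b by applying an inner solver to AᵀAx = Aᵀb (AᵀA is positive semidefinite, so CG is a natural inner choice). A minimal JAX sketch of the idea, using plain CG rather than the lineax operator/solver API this PR actually builds on:

import jax.numpy as jnp
import jax.random as jr
from jax.scipy.sparse.linalg import cg

key1, key2 = jr.split(jr.key(0))
a = jr.normal(key1, (5, 3))  # tall and (almost surely) full rank
b = jr.normal(key2, (5,))

def normal_matvec(x):
    # Matrix-vector product with A^T A, without forming it explicitly.
    return a.T @ (a @ x)

x, _ = cg(normal_matvec, a.T @ b)  # inner solver applied to the normal equations
assert jnp.allclose(a.T @ (a @ x), a.T @ b, atol=1e-4)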
Conversation
Force-pushed from 1d6f641 to 317577c.
One question about the implementation: is it correct to say that the result of allow_dependent_{rows,columns} can assume the eventually used solver state is fresh, i.e. unmodified by transpose or conj? In that case, this assumption would allow a more precise and simpler implementation, removing the special case for square operators. If you agree, I can make the simplification.
patrick-kidger left a comment
> So is it correct to say that the result of allow_dependent_{rows,columns} can assume the eventually used solver state is fresh, i.e. unmodified by transpose or conj? In that case, this assumption would allow a more precise and simpler implementation, removing the special case for square operators. If you agree, I can make the simplification.
Yup, that's correct. To be precise: I think you can assume that the only callsite for allow_... is the one in the JVP rule.
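To illustrate why the JVP rule is the relevant call site: differentiating a least-squares solve must agree with differentiating `jnp.linalg.pinv(a) @ b` directly, wherever the pseudoinverse is differentiable. A quick hedged check, assuming lineax's `SVD` solver (which handles non-square operators) supports forward-mode autodiff here:

import jax
import jax.numpy as jnp
import jax.random as jr
import lineax as lx

key1, key2 = jr.split(jr.key(0))
a = jr.normal(key1, (5, 3))  # tall and (almost surely) full rank
b = jr.normal(key2, (5,))

def solve(a):
    # Least-squares solve of a @ x = b via lineax.
    return lx.linear_solve(lx.MatrixLinearOperator(a), b, lx.SVD()).value

def pinv_solve(a):
    # The same solution via the explicit pseudoinverse.
    return jnp.linalg.pinv(a) @ b

assert jnp.allclose(jax.jacfwd(solve)(a), jax.jacfwd(pinv_solve)(a), rtol=1e-3, atol=1e-3)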
GMRES as GMRES,
LSMR as LSMR,
LU as LU,
NormalCG as NormalCG,
For backwards compatibility let's still provide NormalCG. (Even if it's just a function rather than a class, even if it's removed from the documentation.)
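One hypothetical shape for that shim, assuming the new solver is called `Normal` and wraps an inner solver as in this PR's title (not existing lineax API):

# Hypothetical backwards-compatibility shim: forward the old NormalCG
# constructor arguments to CG and wrap it in the proposed Normal solver.
def NormalCG(*args, **kwargs):
    return Normal(CG(*args, **kwargs))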
lineax/_solve.py (outdated)

If `True` then a more expensive backward pass is needed, to account for the
extra generality.

The value `True` does not guarantee that the solver will produce correct
Not quite: it must at least be correct.
The desired implications here are that:
- `allow_dependent_*=True` implies that non-NaN output will be the correct pseudoinverse solution.
- `allow_dependent_*=False` implies that singular problems will have NaN outputs.
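These two implications can be written down as a runtime check against `jnp.linalg.pinv`. A hedged sketch (the check itself is illustrative, not lineax API):

import jax.numpy as jnp
import lineax as lx

def check_implications(matrix, vector, solver, allow_dependent):
    sol = lx.linear_solve(
        lx.MatrixLinearOperator(matrix), vector, solver, throw=False
    )
    pinv_solution = jnp.linalg.pinv(matrix) @ vector
    is_singular = jnp.linalg.matrix_rank(matrix) < min(matrix.shape)
    if allow_dependent:
        # True: any non-NaN output must be the pseudoinverse solution.
        assert jnp.any(jnp.isnan(sol.value)) or jnp.allclose(
            sol.value, pinv_solution, atol=1e-4
        )
    else:
        # False: a singular problem must produce NaN output.
        assert not is_singular or jnp.any(jnp.isnan(sol.value))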
I think it can't be exactly as you say. Consider unpivoted QR for a tall matrix. We have allow_dependent_rows = True and allow_dependent_cols = False, and the method supports full-rank matrices only. If the input matrix is not full rank, it may give incorrect (not NaN) pseudoinverse results.
But I hear you that there is intended to be some positive assertion for the `True` value. Is it that if both are `True`, then pseudoinverses must be fully supported?
Do you have an example of such a solve?
As a reference point, here's a quick test script on my part that I think indicates QR will solve or not solve precisely according to whether the matrix is full rank. In particular, non-full-rank implies a failed solve (as expected):
import lineax as lx
import jax.numpy as jnp
import jax.random as jr
want_full_rank = True
if want_full_rank:
cutoff = 4
else:
cutoff = 2
key1, key2 = jr.split(jr.key(567), 2)
x = jr.normal(key1, (5, 3))
y = jr.normal(key2, (5,))
x = x.at[cutoff:, :].set(0)
def run(x, y):
sol = lx.linear_solve(lx.MatrixLinearOperator(x), y, lx.QR(), throw=False)
return sol.value, sol.result
def run2(x, y):
sol = lx.linear_solve(lx.MatrixLinearOperator(x), y, lx.SVD(), throw=False)
return sol.value, sol.result
out_qr, result_qr = run(x, y)
out_svd, result_svd = run2(x, y)
assert result_svd == lx.RESULTS.successful
rank = jnp.linalg.matrix_rank(x)
is_full_rank = rank == min(x.shape)
assert is_full_rank == want_full_rank
print(f"Got matrix {x} with rank {rank} which is{'' if is_full_rank else ' not'} full rank.")
if result_qr == lx.RESULTS.successful:
assert jnp.allclose(out_qr, out_svd)
print("QR solved succeeded.")
else:
print("QR solve failed.")There was a problem hiding this comment.
It's enough, for instance, to use this structurally non-full-rank example (I've modified the lineax tests in this change to use this example as well). For a fully generic, non-structural rank-2 x, uncomment the alternative lines.
import lineax as lx
import jax.numpy as jnp
import jax.random as jr
want_full_rank = False
key1, key2 = jr.split(jr.key(567), 2)
x = jr.normal(key1, (5, 3))
y = jr.normal(key2, (5,))
if not want_full_rank:
x = x.at[1:, 1:].set(0)
# # alternative nonstructural rank 2 x
# key3, key4 = jr.split(key1)
# x = jr.normal(key3, (5,2)) @ jr.normal(key4, (2,3))
def run(x, y):
sol = lx.linear_solve(lx.MatrixLinearOperator(x), y, lx.QR(), throw=False)
return sol.value, sol.result
def run2(x, y):
sol = lx.linear_solve(lx.MatrixLinearOperator(x), y, lx.SVD(), throw=False)
return sol.value, sol.result
out_qr, result_qr = run(x, y)
out_svd, result_svd = run2(x, y)
assert result_svd == lx.RESULTS.successful
rank = jnp.linalg.matrix_rank(x)
is_full_rank = rank == min(x.shape)
assert is_full_rank == want_full_rank
print(f"Got matrix {x} with rank {rank} which is{'' if is_full_rank else ' not'} full rank.")
if result_qr == lx.RESULTS.successful:
assert jnp.allclose(out_qr, out_svd)
print("QR solved succeeded.")
else:
print("QR solve failed.")There was a problem hiding this comment.
Ah bother, that is unfortunate. My belief was that the contract offered by JAX's QR decomposition was to error out when given a non-full-rank matrix.
This is thorny enough that I think I'm coming around to your idea of having a single simple method for tackling this: whether the solver may return pseudoinverse solutions or not. It won't always be as efficient but it is at least guaranteed to be both simple and correct.
WDYT? Sorry for the long back-and-forth across these points!
There was a problem hiding this comment.
The jax/numpy contract, I imagine, mirrors that of the underlying LAPACK routines, e.g. here:
!> It is assumed that A has full rank, and only a rudimentary protection
!> against rank-deficient matrices is provided. This subroutine only detects
!> exact rank-deficiency, where a diagonal element of the triangular factor
!> of A is exactly zero.
Of course, even in an exactly rank-deficient case like the structural rank-2 example above, floating-point roundoff during the computation can make these entries small but nonzero. So the impression one gets is that there are two kinds of solvers: ones that assume full-rank input, and are therefore allowed to divide by values very close to zero on the assumption that they are really meant to be nonzero; and ones robust to non-full-rank inputs (rank-revealing algorithms).
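The distinction is easy to observe directly on the diagonal of the triangular factor: for a rank-2 product, roundoff typically leaves the last diagonal entry of R tiny but not exactly zero, so an exact-zero check does not trigger. A small sketch (assuming `jnp.linalg.qr`'s default reduced mode):

import jax.numpy as jnp
import jax.random as jr

key1, key2 = jr.split(jr.key(0))
# Rank 2 by construction, but only *numerically* singular after roundoff.
a = jr.normal(key1, (5, 2)) @ jr.normal(key2, (2, 3))
_, r = jnp.linalg.qr(a)
print(jnp.diag(r))  # last entry is tiny but usually not exactly 0.0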
My suggestions are documented in our discussion in #158: I would be in favor of an assume_full_rank tag for operators, which the first kind of solver would filter for at compile time and the second kind would verify at runtime, erroring out otherwise. As I mention in that discussion, the cases where generality is lost are both necessarily quite trivial and easily recovered by factoring the matrix appropriately. If you wish, you can also introduce operator tags assume_constant_row_space and assume_constant_column_space, which control the corresponding JVP terms (and are appropriately implied by assume_full_rank together with the shape). You could even add a method to guess these tags for a given parameterized operator. The general logical contract is then that every solver does a pseudoinverse solve, but is allowed to exhibit solver-specific behavior (wrong result, NaN, guaranteed error) if given an operator that does not satisfy its tagged assumptions.
I'm also open to assume_full_rank being a property of the solvers like in #158, but this structure seems slightly less logically pleasing.
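For concreteness, a hypothetical sketch of how such a tag might slot into lineax's existing tag mechanism. `assume_full_rank_tag` does not exist in lineax; the sentinel-object style and `TaggedLinearOperator` usage follow lineax's built-in tags such as `lx.symmetric_tag` (assuming a single tag may be passed, as with those):

import jax.random as jr
import lineax as lx

# Hypothetical tag, following lineax's convention of sentinel-object tags.
assume_full_rank_tag = object()

matrix = jr.normal(jr.key(0), (5, 3))
operator = lx.TaggedLinearOperator(
    lx.MatrixLinearOperator(matrix), assume_full_rank_tag
)
# Under the proposed contract: a full-rank-only solver (e.g. QR) would
# filter for this tag at compile time, while a rank-revealing solver
# (e.g. SVD) would verify full rank at runtime and error out otherwise.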
Force-pushed from 317577c to 3cae981.
This updates #159 to be independent of #158, as requested there.