
Implement Normal, a solver applying an inner solver to the normal equations #168

Closed
adconner wants to merge 1 commit into patrick-kidger:main from adconner:push-vnvqwxsmlnlm

Conversation

@adconner (Contributor)

This updates #159 to be independent of #158, as requested there.

@adconner force-pushed the push-vnvqwxsmlnlm branch from 1d6f641 to 317577c on July 24, 2025 03:48
@adconner (Contributor, Author)

One question about the implementation of allow_dependent_{rows,columns}: currently I implement them defensively with regard to the possibility that the solver state might be transposed, accounting for the case where the user manually constructs the solver state and transposes it (for instance, to efficiently apply a linear solve to many systems, some involving A and some involving A^T). However, looking more closely at how the library works, I think the only way the user can do this is to call compute directly on the solver state object, and in that usage they do not get the custom JVP (any gradients computed would differentiate through the compute call directly), so the result of allow_dependent_{rows,columns} is irrelevant there. Every call to allow_dependent_{rows,columns} within the library itself comes from the JVP of a linear_solve call, which seems to mean that in these functions the solver state can be assumed to be freshly constructed from the given operator.
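For concreteness, a minimal sketch of that manual usage, assuming lineax's solver interface (init, transpose, and compute on the solver) behaves as described above; the exact signatures and return values are assumptions, not confirmed by this PR:

import jax.numpy as jnp
import lineax as lx

A = lx.MatrixLinearOperator(jnp.array([[1.0, 2.0], [3.0, 4.0]]))
solver = lx.LU()

# Factorise once, then reuse the state for many solves involving A...
state = solver.init(A, options={})
solution, result, stats = solver.compute(state, jnp.array([1.0, 0.0]), options={})

# ...and transpose the same state to also solve systems involving A^T,
# without re-factorising. This is the manual usage described above; note
# that calling compute directly bypasses linear_solve's custom JVP.
t_state, t_options = solver.transpose(state, options={})
solution_t, result_t, stats_t = solver.compute(t_state, jnp.array([0.0, 1.0]), t_options)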

So is it correct to say that allow_dependent_{rows,columns} can assume the eventually used solver state is fresh, i.e. unmodified by transpose or conj? That assumption would allow a more precise and simpler implementation, removing the special case for square operators. If you agree, I can make the simplification.

@patrick-kidger (Owner) left a comment

> So is it correct to say that allow_dependent_{rows,columns} can assume the eventually used solver state is fresh, i.e. unmodified by transpose or conj? That assumption would allow a more precise and simpler implementation, removing the special case for square operators. If you agree, I can make the simplification.

Yup, that's correct. To be precise: I think you can assume that the only callsite for allow_... is the one in the JVP rule.

GMRES as GMRES,
LSMR as LSMR,
LU as LU,
NormalCG as NormalCG,
@patrick-kidger (Owner)

For backwards compatibility let's still provide NormalCG. (Even if it's just a function rather than a class, even if it's removed from the documentation.)
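A minimal sketch of such a shim, assuming the new solver is spelled Normal(inner_solver) as in this PR's title and that the old NormalCG arguments forward directly to CG; the exact signatures are assumptions:

def NormalCG(*args, **kwargs):
    # Backwards-compatible alias: apply CG to the normal equations via the
    # new Normal wrapper. Hypothetical; this would live inside lineax, so
    # Normal and CG are assumed to be in scope without imports.
    return Normal(CG(*args, **kwargs))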

lineax/_solve.py (Outdated)
If `True` then a more expensive backward pass is needed, to account for the
extra generality.

The value `True` does not guarantee that the solver will produce correct
@patrick-kidger (Owner)

Not quite: it must at least be correct.

The desired implications here are that:

  • allow_dependent_*=True implies that non-NaN output will be the correct pseudoinverse solution.
  • allow_dependent_*=False implies that singular problems will have NaN outputs.
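A property-test sketch restating these two implications (this combines both flags for brevity and is my restatement of the intended contract, not lineax's documented behaviour; allow_dependent_rows and allow_dependent_columns are the solver methods under discussion):

import jax.numpy as jnp
import lineax as lx

def check_contract(matrix, b, solver):
    operator = lx.MatrixLinearOperator(matrix)
    sol = lx.linear_solve(operator, b, solver, throw=False)
    reference = lx.linear_solve(operator, b, lx.SVD())  # pseudoinverse solution
    allows = solver.allow_dependent_rows(operator) and solver.allow_dependent_columns(operator)
    singular = jnp.linalg.matrix_rank(matrix) < min(matrix.shape)
    if allows:
        # True: any non-NaN output must agree with the pseudoinverse solution.
        if not jnp.any(jnp.isnan(sol.value)):
            assert jnp.allclose(sol.value, reference.value, atol=1e-5)
    elif singular:
        # False: a singular problem must surface as NaN output, never as a
        # silently wrong answer.
        assert jnp.any(jnp.isnan(sol.value))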

@adconner (Contributor, Author)

I don't think it can be exactly as you say. Consider unpivoted QR for a tall matrix: we have allow_dependent_rows = True and allow_dependent_cols = False, yet the method supports full-rank matrices only. If the input matrix is not full rank, it may give incorrect (not NaN) pseudoinverse results.

But I hear you that the True value is intended to carry some positive assertion. Is it that if both are True then pseudoinverses must be fully supported?

@patrick-kidger (Owner)

Do you have an example of such a solve?

As a reference point, here's a quick test script on my part that I think indicates QR will solve or fail precisely according to whether the matrix is full rank, with non-full-rank implying a failed solve (as expected):

import lineax as lx
import jax.numpy as jnp
import jax.random as jr

# Toggle between a full-rank and a rank-deficient 5x3 matrix.
want_full_rank = True

if want_full_rank:
    cutoff = 4  # zeroing only the last row keeps rank 3 (full column rank)
else:
    cutoff = 2  # zeroing the last three rows drops the rank to 2

key1, key2 = jr.split(jr.key(567), 2)
x = jr.normal(key1, (5, 3))
y = jr.normal(key2, (5,))
x = x.at[cutoff:, :].set(0)  # zero out rows from `cutoff` onwards

def run(x, y):
    sol = lx.linear_solve(lx.MatrixLinearOperator(x), y, lx.QR(), throw=False)
    return sol.value, sol.result

def run2(x, y):
    sol = lx.linear_solve(lx.MatrixLinearOperator(x), y, lx.SVD(), throw=False)
    return sol.value, sol.result

out_qr, result_qr = run(x, y)
out_svd, result_svd = run2(x, y)
assert result_svd == lx.RESULTS.successful

rank = jnp.linalg.matrix_rank(x)
is_full_rank = rank == min(x.shape)
assert is_full_rank == want_full_rank

print(f"Got matrix {x} with rank {rank} which is{'' if is_full_rank else ' not'} full rank.")
if result_qr == lx.RESULTS.successful:
    assert jnp.allclose(out_qr, out_svd)
    print("QR solved succeeded.")
else:
    print("QR solve failed.")

@adconner (Contributor, Author) commented Jul 29, 2025

It's enough, for instance, to use this structurally non-full-rank example (I've modified the lineax tests in this change to use this example as well). For a fully generic nonstructural rank-2 example, uncomment the alternative lines.

import lineax as lx
import jax.numpy as jnp
import jax.random as jr

want_full_rank = False

key1, key2 = jr.split(jr.key(567), 2)
x = jr.normal(key1, (5, 3))
y = jr.normal(key2, (5,))

if not want_full_rank:
    # Keep only the first row and first column: structurally rank 2.
    x = x.at[1:, 1:].set(0)

    # # Alternative: a fully generic nonstructural rank-2 x.
    # key3, key4 = jr.split(key1)
    # x = jr.normal(key3, (5, 2)) @ jr.normal(key4, (2, 3))

def run(x, y):
    sol = lx.linear_solve(lx.MatrixLinearOperator(x), y, lx.QR(), throw=False)
    return sol.value, sol.result

def run2(x, y):
    sol = lx.linear_solve(lx.MatrixLinearOperator(x), y, lx.SVD(), throw=False)
    return sol.value, sol.result

out_qr, result_qr = run(x, y)
out_svd, result_svd = run2(x, y)
assert result_svd == lx.RESULTS.successful

rank = jnp.linalg.matrix_rank(x)
is_full_rank = rank == min(x.shape)
assert is_full_rank == want_full_rank

print(f"Got matrix {x} with rank {rank} which is{'' if is_full_rank else ' not'} full rank.")
if result_qr == lx.RESULTS.successful:
    assert jnp.allclose(out_qr, out_svd)
    print("QR solved succeeded.")
else:
    print("QR solve failed.")

@patrick-kidger (Owner)

Ah bother, that is unfortunate. My belief was that the contract offered by JAX's QR decomposition was to error out when given a non-full-rank matrix.

This is thorny enough that I think I'm coming around to your idea of having a single simple method for tackling this: whether the solver may return pseudoinverse solutions or not. It won't always be as efficient but it is at least guaranteed to be both simple and correct.

WDYT? Sorry for the long back-and-forth across these points!

@adconner (Contributor, Author) commented Aug 3, 2025

The contract in jax/numpy, I imagine, mirrors the contract of the underlying LAPACK routines, e.g. here:

!> It is assumed that A has full rank, and only a rudimentary protection
!> against rank-deficient matrices is provided. This subroutine only detects
!> exact rank-deficiency, where a diagonal element of the triangular factor
!> of A is exactly zero.

Of course, even in an exactly rank-deficient case like the structural rank-2 example above, floating-point roundoff during the computation can make these entries small but nonzero. So the impression one gets is that there are two kinds of solvers: those assuming full-rank input, which are allowed to divide by values very close to zero on the assumption that they are really meant to be nonzero, and those robust to non-full-rank inputs (rank-revealing algorithms).
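To illustrate: even an exactly rank-deficient matrix can come out of floating-point QR with a tiny-but-nonzero trailing diagonal entry in R, defeating LAPACK's exact-zero check. A small demonstration, using only jax.numpy:

import jax.numpy as jnp
import jax.random as jr

key3, key4 = jr.split(jr.key(0))
x = jr.normal(key3, (5, 2)) @ jr.normal(key4, (2, 3))  # rank <= 2 by construction
_, r = jnp.linalg.qr(x)
# The last diagonal entry of R is typically tiny but nonzero, not exactly 0.
print(jnp.abs(jnp.diag(r)))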

My suggestions are documented in our discussion in #158: I would be in favor of an assume_full_rank tag for operators, which the first kind of solver would filter for at compile time and the second kind would verify at runtime, erroring out otherwise. As I mention in that discussion, the cases where generality is lost are necessarily quite trivial and also easily recovered by factoring the matrix appropriately. If you wish, you could also introduce operator tags assume_constant_row_space and assume_constant_column_space, which control the corresponding JVP terms (and are appropriately implied by assume_full_rank together with the shape). You could even provide a method to guess these tags for a given parameterized operator. The general logical contract is then that every solver performs a pseudoinverse solve, but is allowed to exhibit solver-specific behavior (wrong result, NaN, guaranteed error) if given an operator that does not satisfy its tagged assumptions.
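A hypothetical sketch of what that proposal could look like (assume_full_rank_tag does not exist in lineax; existing tags such as lx.symmetric_tag are attached to operators in exactly this way):

import jax.numpy as jnp
import lineax as lx

assume_full_rank_tag = object()  # stand-in for a proposed library-level tag

matrix = jnp.eye(3)
op = lx.TaggedLinearOperator(lx.MatrixLinearOperator(matrix), assume_full_rank_tag)
# A full-rank-only solver (e.g. QR) would require this tag at compile time,
# while a rank-revealing solver (e.g. SVD) could verify the promised rank at
# runtime and error out if the promise is violated.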

I'm also open to assume_full_rank being a property of the solvers, as in #158, but that structure seems slightly less pleasing logically.

@patrick-kidger (Owner) commented Dec 5, 2025

Closing now that #158 is in; we'll merge #159 instead.
