
Replace allow_dependent_rows and allow_dependent_columns with assume_full_rank #158

Merged
patrick-kidger merged 1 commit into patrick-kidger:dev from adconner:push-qpkqvxnzrsow on Dec 5, 2025

Conversation

@adconner (Contributor) commented Jun 9, 2025

The two functions allow_dependent_{rows,columns} together answered whether the solver assumes full-rank matrices for the purposes of the JVP. Allowing them to be implemented separately created some issues:

  1. Invalid states were representable. E.g., what does it mean for dependent
     columns to be allowed for square matrices when dependent rows are not?
     What does it mean for dependent rows to be disallowed for matrices with
     more rows than columns, whose rows are necessarily dependent?
  2. Because the functions accept the operator as input, a custom solver could
     in principle decide its answer based on the operator's dynamic values,
     rather than only on the static (compile-time) information JAX has about
     it, as all the lineax-defined solvers do. This would prevent JAX
     compilation and jit.

Both issues are addressed by asking the solver to report only whether it assumes the input is numerically full rank. If this assumption is violated, its behavior is undefined: it may error, produce NaN values, or produce invalid values.
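
For illustration, a minimal sketch of a custom solver under the new contract (assumed names; an illustration, not lineax's literal source):

```python
# A custom solver under the new contract: one static boolean replaces the two
# operator-dependent predicates. (Illustrative sketch, not lineax's source.)
class MyLeastSquaresSolver:
    # Old interface, removed by this change:
    #   def allow_dependent_rows(self, operator) -> bool: ...
    #   def allow_dependent_columns(self, operator) -> bool: ...

    def assume_full_rank(self) -> bool:
        # Static (compile-time) answer. True means: "I assume the operator is
        # numerically full rank; if it is not, my behaviour is undefined."
        return True
```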

@adconner (Contributor, Author) commented Jun 10, 2025

If you wanted to go further than this PR and did not mind a breaking change, you could make full_rank an operator tag. Then solvers requiring full-rank operators would check it just as they check any other tag, and the JVP would check the operator tag when choosing which JVP terms to emit. This has benefits:

  1. More general solvers also benefit from a faster JVP when called with operators statically known (or assumed) to be full rank.
  2. AutoLinearSolver could become fully automatic, no longer requiring a well_posed argument.
  3. It makes all static assumptions about operators needed by solvers explicit in operator tags.

Users would have to update their code to tag most operators as full rank.
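
For concreteness, a sketch of how tagging might look; full_rank_tag is hypothetical and does not exist in lineax (the existing positive_semidefinite_tag is the model):

```python
# Hypothetical sketch: a full_rank tag by analogy with lineax's existing tags.
import jax.numpy as jnp
import lineax as lx

full_rank_tag = object()  # stand-in for the proposed (nonexistent) tag

matrix = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # tall, full rank
operator = lx.MatrixLinearOperator(matrix, full_rank_tag)
# Under the proposal, solvers requiring full rank would check this tag the way
# Cholesky checks positive_semidefinite_tag today, and the JVP rule would drop
# the extra pseudoinverse terms whenever the tag is present.
```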

@patrick-kidger (Owner)

Hey there! So you're actually touching on a design point from way back in the early days of Lineax. We in fact originally had a maybe_singular_tag on the operator rather than a description of what could be handled by the solver.

We eventually nixed this because it was too much of a footgun.

  • If Lineax defaulted to assuming well-posedness and you left off the maybe_singular_tag, then you'd silently get incorrect JVPs.
  • If Lineax defaulted to assuming singularity (as you seem to be suggesting here) and you left off the full_rank tag, then you'd get silently expensive JVPs. Or conversely, if adding that tag became the accepted default, then forgetting to remove it puts you back in the land of silently incorrect JVPs.

Either way this was clearly going to go wrong for a lot of users a substantial fraction of the time!

Making this a property of the solver instead reflects real-world usage. You're probably only using SVD if you have a singular problem, for example, and you're anyway happy to accept the extra computational cost in the JVP rule.


As for the adjustment you're originally making in this PR: being able to distinguish these two cases is useful for the sake of faster JVPs, in particular for QR solves.

On (1), it's true that illegal states are representable. We could probably add some assert statements to _linear_solve_jvp to catch the cases you highlight, which are all determinable from compile-time information. On (2), indeed it has to be static; this is already captured in its type annotation -> bool, which is not a JAX array; it is an error for a solver to return anything else -- as meaningless as returning a string or object() etc.
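
A sketch of what those asserts might look like, using the old allow_dependent_* names (assumed code, not the actual _linear_solve_jvp):

```python
def _check_dependent_flags(solver, operator):
    # Compile-time sanity checks ruling out the illegal states from (1).
    rows, columns = operator.out_size(), operator.in_size()
    dep_rows = solver.allow_dependent_rows(operator)
    dep_cols = solver.allow_dependent_columns(operator)
    # Per (2), these must be static Python bools, never JAX arrays:
    assert isinstance(dep_rows, bool) and isinstance(dep_cols, bool)
    if rows > columns:
        # A tall matrix necessarily has dependent rows, so they must be allowed.
        assert dep_rows
    if rows < columns:
        # A wide matrix necessarily has dependent columns.
        assert dep_cols
    if rows == columns:
        # For square matrices, rank deficiency hits rows and columns together.
        assert dep_rows == dep_cols
```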

WDYT?

@adconner (Contributor, Author) commented Jun 11, 2025

If Lineax defaulted to assuming singularity (as you seem to be suggesting here) and you left off the full_rank tag, then you'd get silently expensive JVPs. Or conversely, if adding that tag became the accepted default, then forgetting to remove it puts you back in the land of silently incorrect JVPs.

Making this a property of the solver instead reflects real-world usage. You're probably only using SVD if you have a singular problem, for example, and you're anyway happy to accept the extra computational cost in the JVP rule.

The user at some point needs to indicate whether they assume the operators are full rank. Currently they do this via their solver selection if manual, or via the well_posed argument to AutoLinearSolver. The problem (?) of users having to correctly identify their assumptions is already present, and exists irrespective of whether we express the assumptions in the solver or in the operator.

Keep in mind that if the user forgets the full_rank tag and uses a full-rank-only solver, like QR, they will get an error, just as if they forgot the positive semidefinite tag for the Cholesky solver. It is true that AutoLinearSolver would pick SVD for undecorated operators, but this is the only source of silent inefficiencies (and this is also analogous to the psd tag: the user already gets silent inefficiencies from the auto solver selection if they do not tag a known psd operator).
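
For comparison, the analogous behaviour with the existing tag, using lineax's current API (a sketch; the exact error raised may differ):

```python
import jax.numpy as jnp
import lineax as lx

matrix = jnp.array([[2.0, 1.0], [1.0, 2.0]])  # symmetric positive definite
vector = jnp.array([1.0, 0.0])

# Without the tag, Cholesky refuses the operator rather than silently failing:
untagged = lx.MatrixLinearOperator(matrix)
# lx.linear_solve(untagged, vector, solver=lx.Cholesky())  # raises: tag missing

tagged = lx.MatrixLinearOperator(matrix, lx.positive_semidefinite_tag)
solution = lx.linear_solve(tagged, vector, solver=lx.Cholesky())
```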

What's more, there is an additional benefit to the operator (rather than the solver) knowing whether it is full rank, where there isn't for another static assertion like psd: more efficient JVPs can be emitted for that operator even when using solvers supporting more general operators.

As for the adjustment you're originally making in this PR: being able to distinguish these two cases is useful for the sake of faster JVPs, in particular for QR solves.

The JVPs of this approach are identical to those of the previous approach, including for QR. Notice that in the calculation of the JVP, we now recreate the old info as needed. Any values of allow_dependent_{rows,columns} which are not functions of the operator being {tall,square,wide} and of whether it is assumed full rank are invalid.

@patrick-kidger (Owner)

It is true that AutoLinearSolver would pick SVD for undecorated operators, but this is the only source of silent inefficiencies (and this is also analogous to the psd tag: the user already gets silent inefficiencies from the auto solver selection if they do not tag a known psd operator)

Actually, there is one further problem here (and one that is a motivating reason for us not to do this): if a user has what they believe to be a definite operator, but mistakenly provides an indefinite operator, then SVD will silently compute a pseudoinverse (least-squares/least-norm) solve, rather than erroring out as would be desirable.

Any values of allow_dependent_{rows,columns} which are not functions of the operator being {tall,square,wide} and whether it is assumed to be full rank are invalid.

On this part, I think I like these changes. This was always a subtle point I had to think a lot about.

I think my main comment is that in general I think this could still be a function of any structure/sparsity in the operator? I'd need to noodle on this to be sure I'm getting it right, but I could believe that certain solvers might exhibit or not exhibit this behaviour depending on the operator. (The trivial example is of course a user-defined assume_full_rank tag that the solver checks and respects! But perhaps there are other examples?)

@adconner (Contributor, Author)

Actually, there is one further problem here (and one that is a motivating reason for us not to do this): if a user has what they believe to be a definite operator, but mistakenly provides an indefinite operator, then SVD will silently compute a pseudoinverse (least-squares/least-norm) solve, rather than erroring out as would be desirable.

I see. So you're concerned about the situation where the user is using the SVD solver, intending that their operator be assumed full rank, but SVD ignores this and always does rcond filtering of singular values. If this is concerning, it seems like a problem already present in the existing implementation, and one with a simple fix: do the same thing you do with the Diagonal solver and give SVD a well_posed argument. If true, SVD assumes all singular values are nonzero (rcond = 0, basically). Of course, if operators know whether they are full rank, both this argument and the one in Diagonal (and in a future pivoted QR) are unneeded, as they would just respect the operator tag.
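
A self-contained sketch of that suggestion (lineax's actual SVD solver takes no well_posed argument; this mirrors Diagonal's existing flag):

```python
import jax.numpy as jnp

def svd_pseudo_solve(matrix, vector, well_posed=False, rcond=None):
    u, s, vt = jnp.linalg.svd(matrix, full_matrices=False)
    if well_posed:
        s_inv = 1.0 / s  # assume full rank: never zero out singular values
    else:
        if rcond is None:
            rcond = jnp.finfo(s.dtype).eps * max(matrix.shape)
        cutoff = rcond * s[0]  # singular values are sorted in descending order
        s_inv = jnp.where(s > cutoff, 1.0 / s, 0.0)
    return vt.T @ (s_inv * (u.T @ vector))
```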

Or is your concern that, for the question of operator rank, you want the user to be forced to be explicit, and any default behavior is potentially confusing? If this is the case, you could just as well require all operators to be decorated with exactly one of a full_rank or possibly_not_full_rank tag (or perhaps some other equivalent design). My suggestion is only about which object the assumption should be encoded into, not what the default value should be. On the choice (default assume full rank, default no assumption, require explicit), I don't have a strong opinion; there are reasonable arguments for all three options.

On this part, I think I like these changes. This was always a subtle point I had to think a lot about.

I think my main comment is that in general I think this could still be a function of any structure/sparsity in the operator? I'd need to noodle on this to be sure I'm getting it right, but I could believe that certain solvers might exhibit or not exhibit this behaviour depending on the operator. (The trivial example is of course a user-defined assume_full_rank tag that the solver checks and respects! But perhaps there are other examples?)

Structure and sparsity in the operator could potentially mean it is structurally rank deficient, and in this situation you would want the operator to know this. But still, the only question the JVP cares about is whether the rows of the operator are independent and whether the columns of the operator are independent. The answer to this question, mathematically, for any operator, is given entirely by the following table:

full rank?   operator size   cols independent   rows independent
yes          tall            yes                no
yes          square          yes                yes
yes          wide            no                 yes
no           any             no                 no

The JVP knows the size of the operator, so the only information it is missing is whether the operator is full rank.
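
Concretely, the JVP can reconstruct the old answers from the single flag plus the shape, per the table above (a sketch with assumed names):

```python
def _independence_from_flag(solver, operator):
    # Recreate the old per-shape information from the single static flag.
    rows, columns = operator.out_size(), operator.in_size()
    if solver.assume_full_rank():
        rows_independent = rows <= columns     # square or wide
        columns_independent = rows >= columns  # square or tall
    else:
        rows_independent = False
        columns_independent = False
    # The JVP emits the extra least-squares term only when the rows may be
    # dependent, and the extra least-norm term only when the columns may be.
    return rows_independent, columns_independent
```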

@patrick-kidger (Owner) commented Jun 27, 2025

so the only information it is missing is whether the operator is full rank

So let's consider a solver that consumes a 2n x n matrix and a 2n-length vector and returns an n-length solution, and for which assume_full_rank = False; that is to say, it will always return a pseudoinverse solution.

I've not yet told you anything about the structure of the matrix or the behaviour of the solver.

According to the description you've given here, we would always compute a more-expensive JVP, corresponding to the allow-dependent-rows case.

However! It just so happens that I have a case in mind for which we may only need the inexpensive JVP. (The part corresponding to a true linsolve solution, not the more general pseudoinverse solution).

Namely, we introduce all of the following:

```python
# Consumes two n x n operators and stacks them to create a 2n x n operator.
class TwoOperatorsInATrenchcoast(AbstractLinearOperator):
    first: AbstractLinearOperator
    second: AbstractLinearOperator

    ...


# Matrix of all zeros.
class ZeroLinearOperator(AbstractLinearOperator):
    ...


def is_full_and_zero(operator: AbstractLinearOperator) -> bool:
    return isinstance(operator, TwoOperatorsInATrenchcoast) and isinstance(
        operator.second, ZeroLinearOperator
    )


class WeirdSolver(AbstractLinearSolver):
    ...

    def allow_dependent_rows(self, operator):
        return not is_full_and_zero(operator)

    def compute(self, ...):
        # `is_lower_half_zero` elided for brevity; like the operator check, it
        # can be statically evaluatable given suitable pytree structure.
        if is_full_and_zero(operator) and is_lower_half_zero(vector):
            # Perform an LU solve on the upper half, then pad with zeros.
        else:
            # Perform an SVD solve.
```

In this case we have that WeirdSolver is indeed satisfying its contract of only returning a pseudoinverse solution. In the fast-path branch, this coincides with performing a full-rank solve on a submatrix. When on that branch we can skip the extra JVP computation (the tangents in the zero regions are also structurally zero), and this is reflected in allow_dependent_rows.

Thus we have obtained a case in which the solver needs to consume the operator to determine whether or not it handles dependent rows.


The above is obviously highly contrived. But I think it indicates that completely factoring out the operator height-vs-width directly into the JVP rule would indeed lose expressivity.

(Now, one could certainly make the argument that this is a worthy trade-off in the name of simplicity. And if we'd done this the first time around I'd agree, but as it is I'm inclined to keep things as they are for the sake of backwards-compatibility.)

WDYT?

@adconner (Contributor, Author)

I now understand your concern. At the very least, the current allow_dependent_{rows,columns} are misnamed. It's really more like allow_nonconstant_{row,column}_space, as the row span of A(x) or the column span of A(x) is independent of x if and only if the corresponding term of the JVP vanishes (for some neighborhood of x's). The rows or columns being independent are just special cases of this notion, where the row/column span is constantly the whole vector space.

I could potentially get behind the idea that we might want to expose this possibility (better documented and with more descriptive names) to the user. However, first consider that the extra generality we are exposing (a situation where the solver detects that the operator statically has constant row/column space) can already be obtained without this feature in the cases where it is needed. If the column span of A(x) is constant, then A(x) factors as A(x) = B C(x), where B is constant with independent columns and C(x) has independent rows. If the user is in this case, they can compute A^+ b = C(x)^+ B^+ b and get the same JVP.

Furthermore, the only nontrivial instance of this property which can ever occur structurally (looking only at the sparsity pattern of A) is precisely the one in your contrived example, where we discard rows/columns which are zero. And the solver only has access to structural information about A, so the only extra generality we are enabling is the detection of some zero rows or columns inside the solver. I think I might prefer the simpler system where the solver doesn't attempt to detect structurally zero rows/columns, and instead push this functionality into some solver wrapper which structurally factors A(x) = B C(x) D and computes A^+ b = D^T C(x)^+ B^T b, where B is a subset of columns of the identity matrix and D is a subset of rows of the identity matrix (if it is actually desired to implement this in the library, rather than by the user).
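
Spelling out the wrapper's pseudoinverse identity (a standard fact, added here for clarity): since B has orthonormal columns and D has orthonormal rows, B^+ = B^T and D^+ = D^T, so

$$A(x)^+ = \bigl(B\,C(x)\,D\bigr)^+ = D^{+}\,C(x)^{+}\,B^{+} = D^{\top}\,C(x)^{+}\,B^{\top}.$$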

@patrick-kidger (Owner)

Sorry for taking so long to get back to you; as this is a very technical + mostly internal change, it's been pushed down my priority list.

Anyway, I think we agree with each other then! As per my last message:

(Now, one could certainly make the argument that this is a worthy trade-off in the name of simplicity. And if we'd done this the first time around I'd agree, but as it is I'm inclined to keep things as they are for the sake of backwards-compatibility.)

Indeed, a simpler solution could have been arrived at, at the cost of an acceptable loss of generality. However, breaking backward compatibility requires a very high bar, and I don't think that adjusting where we land on this trade-off meets it.

I am conscious of your other work in #159. I'd definitely still be very happy to get that in, if you'd be willing to update it without this PR as a dependency?

@patrick-kidger (Owner) left a comment

Okay, on the basis of #168, returning to this PR @adconner: I think I'm now open to merging this more-or-less as-is. The convincing factor for me was the difficulty of getting allow_dependent_{rows,columns} correct, as you highlighted a long-standing bug in the QR solver of exactly this sort.

I have one question on QR below, otherwise I like this PR as-is.

```python
        columns = operator.in_size()
        return rows > columns

    def assume_full_rank(self):
        return True
```
@patrick-kidger (Owner) commented Aug 10, 2025

I think this one is a bit of a footgun. When using AutoLinearSolver(well_posed=None) we'll select QR if the operator is nonsquare, but as your example in #168 demonstrates, this may erroneously return a solution.

Unlike when we first wrote Lineax, I believe JAX now has a rank-revealing QR decomposition available via jax.lax.linalg.qr.

I think setting the return value here to False, and unconditionally using the rank-revealing implementation, should fix this however?

@adconner (Contributor, Author) commented Aug 11, 2025

All the assume_full_rank=True solvers (i.e. everything except SVD) might erroneously return a solution in the rank-deficient case. If you wish to test this, use a 3x3 matrix with the bottom-right 2x2 block zero and try LU, or make it symmetric and try Cholesky. This is just a fact of life: if you need guaranteed errors, you use SVD or pivoted QR.
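
A sketch reproducing that failure mode with lineax's current API (the exact output, NaNs or plausible-looking garbage, may vary by backend):

```python
import jax.numpy as jnp
import lineax as lx

# 3x3 matrix with the bottom-right 2x2 block zero: rank 2, not full rank.
matrix = jnp.array([[1.0, 1.0, 1.0],
                    [1.0, 0.0, 0.0],
                    [1.0, 0.0, 0.0]])
vector = jnp.array([1.0, 1.0, 1.0])
operator = lx.MatrixLinearOperator(matrix)
# LU assumes full rank; on this singular input it does not reliably error.
solution = lx.linear_solve(operator, vector, solver=lx.LU(), throw=False)
print(solution.value, solution.result)
```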

In my mind, based on the solver selection in the documentation, AutoLinearSolver(well_posed=None) is supposed to mean "handle full-rank nonsquare operators", well_posed=False means "handle arbitrary operators", and well_posed=True means "handle full-rank square operators". The language in the documentation is a bit indirect, but seems to support this ("ill posed" meaning "nonsquare", fine, like "over/underdetermined"). Maybe this isn't what you had in mind when you implemented this option originally, but this is the actual interface lineax currently provides.

And as an aside, you definitely want to handle full-rank nonsquare operators efficiently and explicitly, with both solver and grad support, as this is almost the most common case: remember that a random matrix is full rank, so too are most matrices encountered "in the wild". The ways you might encounter rank deficiency are (1) your application really has it for some structural reason, i.e. all matrices in sight have the same rank, which is not full; in this case you probably know, or you find out during development by checking the singular values of a random example from the application; or (2) you are optimizing in the space of full-rank matrices and your objective takes you close to rank-deficient matrices.

In case (1) you want SVD or pivoted QR. In case (2) you probably don't want to suddenly start changing the structure of the objective function, computing minimum-norm solutions as your singular values get small (not to mention changing your gradients, which are only valid in a locus of constant-rank matrices; recall e.g. the first sentence of https://en.wikipedia.org/wiki/Moore-Penrose_inverse#Derivative ). In this case you probably don't want it hidden from you that you are close to dividing by zero.

(As a truly tangential aside, in case (1) you also have the same concern as in case (2): maybe you are optimizing in the space of constant-rank-r matrices and you get close to the rank ≤ r-1 locus. You probably don't want your solver to suddenly start assuming an extra singular value is zero. Currently there is no clear way to work around this, aside from forking lineax to add operator rank assumptions.)

It's also for exactly this case (full-rank nonsquare) that I contributed Normal, so that we can have many more solvers in this space, most simply Normal(Cholesky).

Aside from the problem of wanting to ignore the above use case, there is another subtlety in implementing only rank-revealing QR and making it assume_full_rank=False. Firstly, it's relatively simple to implement it to solve full-rank systems with guaranteed errors on rank deficiency, if desired (at least, it would only have the same problems as the current QR; see jax-ml/jax#29173 ). This would still be assume_full_rank=True, but with guaranteed checking of this assumption at run time. However, a version doing a pseudoinverse solve needs more LAPACK (and equivalent CUDA/TPU) routines than currently provided by JAX, to do the so-called "complete orthogonal factorization": see the documentation of sgelsy. The additional needed routine (tzrzf) is not provided by JAX; I also mention this in the last paragraph of that JAX issue. Even if you wanted to do the complete orthogonal factorization less efficiently, e.g. by using a second QR and not exploiting upper triangularity, I don't think the current JAX interface allows this, as this would be a QR plus manipulations on a submatrix of dynamic (depending on r) shape.
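
For reference, the complete orthogonal factorization behind sgelsy, sketched from the LAPACK documentation: column-pivoted QR gives

$$A\,\Pi = Q \begin{pmatrix} R_{11} & R_{12} \\ 0 & 0 \end{pmatrix},$$

and tzrzf then annihilates $R_{12}$ from the right,

$$\begin{pmatrix} R_{11} & R_{12} \end{pmatrix} = \begin{pmatrix} T & 0 \end{pmatrix} Z,$$

so the minimum-norm least-squares solution is

$$x = \Pi\, Z^{\top} \begin{pmatrix} T^{-1} Q_1^{\top} b \\ 0 \end{pmatrix},$$

where $r = \operatorname{rank}(A)$, $T$ is $r \times r$ upper triangular, and $Q_1$ is the first $r$ columns of $Q$.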

In any case, even if you want to add the required support to JAX to implement full pseudoinverse pivoted QR in lineax, you would still want to support pivoted and ordinary QR for assumed-full-rank matrices, because there is a real saving in avoiding a JVP term, and this is the common case (and in the pivoted QR case, you additionally save the step of forming the complete orthogonal decomposition).

@patrick-kidger (Owner)

All the assume_full_rank=True solvers (i.e. everything except SVD) might erroneously return a solution in the rank-deficient case. If you wish to test this, use a 3x3 matrix with the bottom-right 2x2 block zero and try LU, or make it symmetric and try Cholesky. This is just a fact of life: if you need guaranteed errors, you use SVD or pivoted QR.

Ah, bother. The fact that the underlying LU/etc routines can also erroneously return solutions was a detail that I hadn't appreciated.

When originally writing Lineax, one of the big hopes was that we could be robust about only ever (a) returning a valid solution or (b) erroring out. We reckoned this was important for good UX, since silently getting the wrong solution is behaviour that is very hard to debug. (The only exception we made to this was operator tags, but since these are an explicit opt-in, it at least becomes clear to users that this is a potential spot for things to go wrong.)

It sounds like you're dashing those hopes.

In my mind, based on the solver selection in the documentation, AutoLinearSolver(well_posed=None) is supposed to mean "handle full-rank nonsquare operators", well_posed=False means "handle arbitrary operators", and well_posed=True means "handle full-rank square operators". The language in the documentation is a bit indirect, but seems to support this ("ill posed" meaning "nonsquare", fine, like "over/underdetermined"). Maybe this isn't what you had in mind when you implemented this option originally, but this is the actual interface lineax currently provides.

Agreed, our language isn't clear enough here. We can adjust this to be exactly what you describe here.

there is another subtlety in implementing only rank-revealing QR and making it assume_full_rank=False

Ah, complete brain fart with what I wrote originally. Indeed I just meant using rank-revealing QR to check for success or not; not to get the pseudoinverse solution.

Broadly, in agreement with all that you've written above.


In terms of next steps, my current thinking:

  1. Merge this PR completely as-is.

  2. Relax our API contract to allow for returning invalid solutions. (Yuck, but I don't see a way around it.)

  3. This won't be perfect, but we could introduce a post-solve safety check: if our solver declares that it is assuming full rank, then compute ||Ax - b|| and check that this is small. Asymptotically this check is O(n^2), and as the solve itself is O(n^3), this may be acceptably small overhead; we could anyway make it possible to disable it for those who would like to opt out. (A sketch follows this list.)

    This won't be perfect because e.g. [[1 0] [0 0]] @ [1 x] == [1 0] no matter the value of x, so our check may erroneously pass in some cases. But perhaps this catches common cases / is better than nothing.

    The main thing I don't like about this idea is that it introduces a tolerance that we'll have to pick carefully if we don't want spurious false positives. That alone may be enough of a reason not to do this. Maybe if we pick something ludicrously big (in the sense of 1e-1 or something) then that will at least catch a reasonable fraction of cases.
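
A minimal sketch of that post-solve check (the tolerance and the opt-out are assumptions, not lineax API):

```python
import jax.numpy as jnp

def residual_check(matvec, solution, vector, tol=1e-1):
    # One extra matrix-vector product: O(n^2) next to the O(n^3) solve.
    residual = matvec(solution) - vector
    # Scale by ||b|| so the check is insensitive to the magnitude of b.
    return jnp.linalg.norm(residual) <= tol * (1.0 + jnp.linalg.norm(vector))
```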

WDYT?

@adconner (Contributor, Author) commented Aug 26, 2025

Relax our API contract to allow for returning invalid solutions. (Yuck, but I don't see a way around it.)

I don't think you should feel as bad about this as you do. I think users frequently actually want the matrix assumed full rank even when its condition number is large. How often do you really want np.where(np.abs(t) < rcond, 0.0, 1/t) rather than 1/t?

This won't be perfect, but we could introduce a post-solve safety check: if our solver declares that it is assuming full rank, then compute ||Ax - b|| and check that this is small.

Solutions, when returned, even if "invalid" for some notion of pseudoinverse with a given rcond, will be valid in the sense that ||Ax - b|| will be small (I believe the only failure of this would be catastrophic cancellation in some of the sums in the resulting dot products). Your test will basically never fail.

For these reasons I don't recommend any post solve check.

Commit: Replace allow_dependent_rows and allow_dependent_columns with assume_full_rank
@patrick-kidger changed the base branch from main to dev on December 5, 2025 00:10
@patrick-kidger merged commit 8e91b44 into patrick-kidger:dev on Dec 5, 2025
@patrick-kidger (Owner)

Okay, I am very belatedly getting back around to this PR now. (Personal life issues caught up with me for a while.)

I've just merged this into our dev branch. I really appreciate your patience, I think this is an improvement really worth having.

Also CC @johannahaffner who I think is coordinating the next release – this is a breaking change so we'll bump the version appropriately. :)

@patrick-kidger mentioned this pull request Jan 26, 2026
patrick-kidger pushed a commit that referenced this pull request Jan 27, 2026: Replace allow_dependent_rows and allow_dependent_columns with assume_full_rank (#158)
