Skip to content

Decrease default value to prevent overflow in 32-bit.#176

Merged
johannahaffner merged 10 commits intopatrick-kidger:devfrom
johannahaffner:lsmr-init-overflow
Dec 13, 2025
Merged

Decrease default value to prevent overflow in 32-bit.#176
johannahaffner merged 10 commits intopatrick-kidger:devfrom
johannahaffner:lsmr-init-overflow

Conversation

@johannahaffner
Copy link
Collaborator

@johannahaffner johannahaffner changed the base branch from main to dev November 14, 2025 08:09
@johannahaffner
Copy link
Collaborator Author

Created a dev branch.

@johannahaffner
Copy link
Collaborator Author

We're failing pre-commit checks due to a new version of pyright, I'll fix this when I have time.

@patrick-kidger
Copy link
Owner

Should this be something like jnp.finfo(jnp.float32).max? If nothing else to document the choice here.

@johannahaffner
Copy link
Collaborator Author

Sure, I can add a comment! A jnp anything in a default value of a public function causes JAX to initialise upon import of the library, I'm not exactly sure if this would be triggered here as well and will check if that would be the case.

@patrick-kidger
Copy link
Owner

This one should be safe from JAX initialisation, I believe. :)

@johannahaffner
Copy link
Collaborator Author

Indeed it is. To address the pyright complaints, I've added ignore statements wherever I could not find an ergonomic way to use astype or cast.

Regarding our two_norm, it looks like pyright now disregards the Scalar return type and looks at what the function may return on all code paths, which leads to a union that includes a None, due to an assert False branch. Is that what pyright should be doing, seeing as assert False, if ever hit, cannot lead to a type-related bug? It probably has no concept of what happens in a branch other than checking what it may return (?), but I'm not sure why it only started picking up on this starting from the newest version.

@johannahaffner
Copy link
Collaborator Author

We're getting the test failure reported here, which seems to be due to numerical bad luck: #172

@johannahaffner
Copy link
Collaborator Author

...and another one in GMRES, although I don't see how that one could be due to the changes I made.

@johannahaffner
Copy link
Collaborator Author

@PTNobel could you take a look at this? We're getting test failures for LSMR from tests that involve the computation of a JVP. The generated matrices should have permissible condition numbers: we've previously set a cutoff for the condition number at 1e3, which is the default we use for solvers that do not normalise the equations, and conlim is 1e8 for LSMR.

Right now I'm decreasing condition numbers to see at which values these tests will pass, but that is a little empirical and unsatisfactory. Besides, if there is a numerical problem underneath this that we should catch, then this is maybe a good canary in the coal mine.

@PTNobel
Copy link
Contributor

PTNobel commented Nov 26, 2025

Happy to look at it. It is Thanksgiving here in the US, so I'm a bit busy the next few days. Feel free to ping me if it haven't finished in a week.

tests/helpers.py Outdated
elif isinstance(solver, lx.GMRES):
cond_cutoff = 900
elif isinstance(solver, lx.LSMR):
cond_cutoff = 800

This comment was marked as outdated.

This comment was marked as outdated.

@johannahaffner
Copy link
Collaborator Author

@PTNobel here is the requested ping ☺️

Hope you had a wonderful thanksgiving with your family!

@johannahaffner
Copy link
Collaborator Author

johannahaffner commented Dec 13, 2025

The LSMR test failures in vmap_jvp are really weird. I cannot reproduce them run-by-run locally - tests will sometimes fail, and sometimes won't. By default the seed is set using the built-in random, so it will be different each time (we do not specify EQX_GETKEY_SEED).

So everything is a bit heuristic here, and I only have some observations:

  • Anecdotally, failures seem to happen more often in CI (on Linux) than they do on my Mac
  • Printing + checking the result does seem to "improve" the success rate - this changes something along the chain of custody and I've seen before that this can affect numerical weirdness
  • It does not seem to come from the backward pass after all - at least things pass more often if throw=False is passed to the wrapped linear_solve in the body of the test definition

What I think is going on here is that some unlucky random streams cause things to fail here. Given that we do guard against high condition numbers, this is potentially a serious issue affecting the reliability of lx.LSMR, which seems to fail for condition numbers below the stated/promised conlim.


As a practical next step, I've reset tests/helpers to what we currently have in dev. This reduces the changes made in this PR down to prevention of overflow due to hyperparameter initialisation when using 32-bit. Since LSMR is already in main, this is a change worth making - and I think we should merge this and fix the JVP issues for LSMR in a different PR.
This then means that ahead of the next release, LSMR will require a second look. (Not sure when I can get around to this, someone taking this on would be greatly appreciated.)

@johannahaffner
Copy link
Collaborator Author

johannahaffner commented Dec 13, 2025

I also note that the CI passed in #182. While the initialisation of the max_steps value in LSMR should not contribute to these test failures, the altered value for minrbar could do so.

Comment on lines 124 to 127
if jnp.issubdtype(dtype, jnp.complexfloating):
real_dtype = jnp.finfo(dtype).dtype
else:
real_dtype = dtype
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have a complex_to_real_dtype function that can be used in place of this. (And I think the if statement here is unnecessary, the jnp.finfo trick should work for real dtypes too.)

@johannahaffner
Copy link
Collaborator Author

Just crossed my mind - do you think conlim should potentially depend on the dtype used?

@johannahaffner johannahaffner merged commit dddc513 into patrick-kidger:dev Dec 13, 2025
1 check passed
@johannahaffner
Copy link
Collaborator Author

Just noting that the last CI run actually passed :D

@johannahaffner johannahaffner deleted the lsmr-init-overflow branch December 13, 2025 16:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

LSMR init overflow

3 participants