Merged
Conversation
…full_rank (#158) The two functions allow_dependent_{rows,columns} together did the job of answering if the solver accepts full rank matrices for the purposes of the jvp. Allowing them to be implemented separately created some issues: 1) Invalid states were representable. Eg. What does it mean that dependent columns are allowed for square matrices if dependent rows are not? What does it mean that dependent rows are not allowed for matrices with more rows than columns? 2) As the functions accept operator as input, a custom solver could in principle decide its answer based on operator's dynamic value rather than only jax compilation static information regarding it, as in all the lineax defined solvers. This would prevent jax compilation and jit. Both issues are addressed by asking the solver to report only if it assumes the input is numerically full rank. If this assumption is exactly violated, its behavior is allowed to be undefined, and is allowed to error, produce NaN values, and produce invalid values.
Decrease default value to prevent overflow in 32-bit.
There seem to be some spurious downstream failures in Diffrax with JAX 0.8.2 otherwise. Probably JAX has started promoting these to tracers on some unusual codepath.
…190) * Deprecate NormalCG helper function - Added **kwargs support to NormalCG function signature - Added DeprecationWarning directing users to use lx.Normal(lx.CG(...)) instead - Added docstring with deprecation notice - Imported warnings module * Align deprecation warning with Diffrax patterns - Use "in favour of" phrasing (consistent with Diffrax) - Add backticks around code examples - Specify "in some future version of Lineax" - Keep DeprecationWarning category (more semantically correct) - Update both warning message and docstring * fix precommit
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Phew, it's been a while since a release! It's about time we got something out.
Breaking changes
AbstractLinearSolver.allow_dependent_{rows, columns}have been removed in favour of the simplerAbstractLinearSolver.assume_full_rank. (Thanks @adconner! In particular for being so patient with me and my bugs 😅 Replace allow_dependent_rows and allow_dependent_columns with assume_full_rank #158)Features
lineax.LSMRsolver. This is a solver that will return the pseudoinverse (likelineax.SVD) solution, and handles nonsquare/singular matrices. In addition, it is an iterative solver. (Thanks @f0uriest @PTNobel @healeyq3 @johannahaffner! Implementation of LSMR for iterative least squares. #86)lineax.Normalsolver. This wraps another existing solver so that it operates via the normal equations. (Thanks @adconner! Implement Normal, a solver applying an inner solver to the normal equations #159)Compatibility
LU.compute#187, fix TracerBoolConversionError by using lax.cond instead of if/else, for jax>=0.8.2 bug #188)Bugfixes
Performance