
Sparse operators in Lineax: CSR/CSC and COO format #178

Closed
johannahaffner wants to merge 2 commits into patrick-kidger:main from johannahaffner:sparse-operators

Conversation

@johannahaffner
Collaborator

These operators wrap the respective JAX operators. They are currently limited to 2D matrices, which is all that JAX supports as well. As discussed in #24 (comment), this introduces a dependency on jax.experimental.sparse, which, while part of the experimental module, does seem to be actively maintained. This is mostly necessary so that we can use their implementation of __matmul__.
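
For concreteness, the kernel reuse this buys us looks roughly like the sketch below: jax.experimental.sparse supplies the __matmul__, so the wrapping operator only needs to store the sparse matrix and delegate matrix-vector products to it. (BCOO is used here purely for illustration; it is not necessarily what the operator classes in this PR hold.)

```python
import jax.numpy as jnp
from jax.experimental import sparse

# Build a small sparse matrix and take a matrix-vector product.
dense = jnp.array([[1.0, 0.0, 0.0],
                   [0.0, 2.0, 0.0],
                   [0.0, 0.0, 3.0]])
mat = sparse.BCOO.fromdense(dense)  # stores values and indices as JAX arrays
vec = jnp.array([1.0, 1.0, 1.0])

# __matmul__ dispatches to the sparse kernel (bcoo_dot_general under the hood),
# so a wrapping operator's mv can simply delegate to `mat @ v`.
out = mat @ vec
print(out)  # [1. 2. 3.]
```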

It turns out I already had the COO version; I had forgotten about that 😃

As a drive-by change, this adds extra fixtures for cache clearing, which were necessary to get tests to pass on the Linux machine I was using while coding this up.

Some loose ends to tie up:

  • When not type-checking, the batched sparse formats (BCOO, BCSR) may be substituted for their bare-bones, non-batched counterparts. I have not checked whether these pass all the tests, or whether their special batching rules would interfere with what we're already doing. Vmapping works without using these formats.
  • I'll expand the documentation of both to explicitly mention the limitation to 2D matrices, and I'll add an example of their usage.

@johannahaffner
Collaborator Author

Test failures are the same as in #176, so perhaps we'll get that in first and then I'll rebase? We only get the condition number error in LSMR in 0.7.1, documented in #172.

@johannahaffner
Collaborator Author

I've updated the documentation. I think some sparsity-specific tests would be good; I'm not yet convinced that we're catching everything that we should be catching.

@johannahaffner
Collaborator Author

Cross-reference: discussion about sparse data types in jax-ml/jax#33514

@patrick-kidger
Owner

Whilst I love the idea of sparse operators in JAX, I'm a bit antsy about depending on the unmaintained jax.experimental.sparse.

At least for the representation of the operator, I'm guessing that we could just store coordinates etc. as JAX arrays ourselves, without using e.g. jsp.COO objects. The only bit I'm not sure about is whether we could easily call into sparse kernels, which might only be available through the unmaintained namespace.
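
For the representation itself, I'm imagining something as simple as the sketch below: plain arrays plus a scatter-add. (Purely illustrative; whether this is competitive with the dedicated kernels is exactly the bit I'm unsure about.)

```python
import jax.numpy as jnp

def coo_matvec(values, rows, cols, vector, *, num_rows):
    # y[i] = sum over stored entries (i, j) of A[i, j] * x[j],
    # with `values`, `rows`, `cols` held as plain JAX arrays -- no jsp.COO object.
    contributions = values * vector[cols]
    return jnp.zeros(num_rows, vector.dtype).at[rows].add(contributions)
```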

@johannahaffner
Collaborator Author

Yes, that is pretty much where I land on this as well. I don't think this should be part of the next release; at the very least, a lot of sparsity-specific testing needs to be done first. It could be that we circumvent a lot of edge cases simply because we test these with solvers that never actually need to interact with the sparse structures themselves, and instead work with either the densely materialised matrix or with matrix-vector products.
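
To illustrate that last point: an iterative solver only ever sees matrix-vector products, so a sparse matvec can already be routed through Lineax today with a plain FunctionLinearOperator, without any dedicated sparse operator class. A rough (untested) sketch:

```python
import jax
import jax.numpy as jnp
import lineax as lx
from jax.experimental import sparse

# A sparse matrix, built from dense here just for illustration.
mat = sparse.BCOO.fromdense(jnp.array([[2.0, 0.0, 0.0],
                                       [0.0, 3.0, 0.0],
                                       [0.0, 0.0, 4.0]]))
b = jnp.array([2.0, 3.0, 4.0])

# The iterative solver only needs the matvec, so the sparse structure
# stays hidden behind the FunctionLinearOperator.
operator = lx.FunctionLinearOperator(
    lambda v: mat @ v,
    jax.ShapeDtypeStruct(b.shape, b.dtype),
)
solution = lx.linear_solve(operator, b, solver=lx.GMRES(rtol=1e-6, atol=1e-6))
print(solution.value)  # approximately [1. 1. 1.]
```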

Not relying on jsp.COO etc. at all would require implementing our own __matmul__ for these structures, and that is a bit more involved. The BCSR formats also have other special features, such as summing operators with the same shape but different sparsity structures, which is useful when you need to add a diagonal for regularisation, for example. These are all quite good, and rewriting them would be a big technical task.

FWIW, I would not call jax.experimental.sparse unmaintained: it is in maintenance mode, so it is not being developed any further, but it does seem to be quite stable and has been around for quite a while now.

TLDR: let's not merge this now; leave it here and/or revisit it for BCSR later. Right now the only use case is CuDSS anyway, and the operator is basically a vessel for communicating the values, offsets and columns to CuDSS. We can do all of this in Spineax.

@johannahaffner
Collaborator Author

Closing this for now.
