
Extending scatter! to work with CUDA sparse arrays #648

Merged
CarloLucibello merged 5 commits into FluxML:master from alonsoC1s:master
Dec 21, 2025

Conversation

@alonsoC1s
Contributor

Aims to fix #647 by extending the signature of scatter! to accept AbstractCuSparseArray, a CUDA array type that the original method excluded. With the proposed patch, calling scatter! with sparse arrays from CUDA.CUSPARSE dispatches to the CUDA-specialized method instead of the generic CPU method, which triggered a scalar indexing error. In my testing, the existing CUDA kernels work fine with CuSparseMatrixCSC.

The proposed implementation, perhaps inelegantly, just expands the types in the signature with Union{...}. I am open to discussing more beautiful ways of implementing this. Ideally, AbstractCuSparseArray would be a subtype of AnyCuArray.
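
As a rough illustration of the Union{...} approach described above, here is a minimal sketch that runs without a GPU. Every type and the function which_method are hypothetical stand-ins that only mimic the shape of the hierarchy (a GPU array supertype that the sparse type does not belong to); they are not the actual NNlib.jl or CUDA.jl definitions.

```julia
# Hypothetical stand-in types; none of these come from CUDA.jl or NNlib.jl.
abstract type FakeGPUArray end          # plays the role of AnyCuArray
struct FakeDense <: FakeGPUArray end    # plays the role of a dense CuArray
struct FakeSparse end                   # plays the role of a CUSPARSE array:
                                        # deliberately NOT a subtype of FakeGPUArray

# Generic fallback, analogous to the CPU method that caused scalar indexing.
which_method(dst) = :generic

# Specialized method whose signature is widened with a Union so that the
# sparse stand-in is also routed here instead of to the fallback.
which_method(dst::Union{FakeGPUArray, FakeSparse}) = :gpu

println(which_method(FakeDense()))   # dense stand-in hits the specialized method
println(which_method(FakeSparse()))  # sparse stand-in now hits it too
println(which_method([1, 2, 3]))     # ordinary Array still uses the fallback
```

Without the Union, FakeSparse would fall through to the generic method, which mirrors the behaviour reported in #647. If the sparse type were a subtype of the GPU supertype, the plain signature would already cover it and the Union would be unnecessary.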

PR Checklist

  • Tests are added
  • Documentation, if applicable

@alonsoC1s
Contributor Author

The integration tests with Lux fail because always_inliner! is not defined, along with what look like Enzyme internal errors. I'm not sure whether these failures are related to this PR.

@mcabbott
Member

Seems fine. Is it possible to add a test on CI somehow, perhaps in https://github.com/FluxML/NNlib.jl/blob/master/test/ext_cuda/scatter.jl ?

@alonsoC1s
Contributor Author

alonsoC1s commented Nov 25, 2025

I added the sparse matrix varieties to the list of array types that are automatically tested

@mcabbott Any thoughts on making the implementation less ugly? Should I open an issue on CUDA.jl suggesting making CUSPARSE arrays subtypes of AnyCuArray?

@CarloLucibello CarloLucibello merged commit c15dd3b into FluxML:master Dec 21, 2025
9 of 11 checks passed
@CarloLucibello
Member

CarloLucibello commented Dec 21, 2025

@alonsoC1s
Contributor Author

alonsoC1s commented Jan 27, 2026

@CarloLucibello I'll take a look at the failing tests. I hope I can get easier access to a GPU to test it on this week

Turns out I don't need a GPU to test. I'm working on the fix


Development

Successfully merging this pull request may close these issues.

Specialized scatter! dispatch for CUSPARSE.CuSparseMatrix

3 participants