From 87ec9407f9282f8dec04d2e1191c1d779c5c98a2 Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Mon, 12 Jan 2026 18:13:51 +0100 Subject: [PATCH] revert 648 --- ext/NNlibCUDAExt/scatter.jl | 34 ++++++++++++++++++++++------------ test/ext_cuda/scatter.jl | 4 ++-- 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/ext/NNlibCUDAExt/scatter.jl b/ext/NNlibCUDAExt/scatter.jl index 9b323d504..cec16e34d 100644 --- a/ext/NNlibCUDAExt/scatter.jl +++ b/ext/NNlibCUDAExt/scatter.jl @@ -1,5 +1,15 @@ # supported op: +, -, *, /, max, min, &, |, mean -import CUDA.CUSPARSE: AbstractCuSparseArray + +## TODO support sparse dst/src/idx +## See issue https://github.com/FluxML/NNlib.jl/issues/647 +# import CUDA.CUSPARSE: AbstractCuSparseMatrix, CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixCOO, AnyCuSparseVector +# const AnyCuSparseMatrix{Tv,Ti} = Union{ +# AbstractCuSparseMatrix{Tv,Ti}, +# CUDA.CuSparseMatrixCSC{Tv,Ti}, # these types do not inherit from AbstractCuSparseMatrix +# CUDA.CuSparseMatrixCSR{Tv,Ti}, # but from GPUArrays.AbstractGPUSparseMatrixXXX +# CUDA.CuSparseMatrixCOO{Tv,Ti}, +# } +# const AnyCuSparseArray{Tv,Ti} = Union{AnyCuSparseVector{Tv,Ti},AnyCuSparseMatrix{Tv,Ti}} function scatter_kernel!(op::OP, dst, src, idx) where OP index = threadIdx().x + (blockIdx().x - 1) * blockDim().x @@ -45,9 +55,9 @@ function scatter_kernel!(op::OP, dst, src, idx::CUDA.CuDeviceArray{<:CartesianIn end -function NNlib.scatter!(op::OP, dst::Union{AnyCuArray,AbstractCuSparseArray}, - src::Union{AnyCuArray,AbstractCuSparseArray}, - idx::Union{AnyCuArray,AbstractCuSparseArray}) where OP +function NNlib.scatter!(op::OP, dst::AnyCuArray, + src::AnyCuArray, + idx::AnyCuArray) where OP dims = NNlib.scatter_dims(dst, src, idx) args = if dims == 0 max_idx = length(idx) @@ -67,9 +77,9 @@ function NNlib.scatter!(op::OP, dst::Union{AnyCuArray,AbstractCuSparseArray}, return dst end -function NNlib.scatter!(op::typeof(mean), dst::Union{AnyCuArray,AbstractCuSparseArray}, - src::Union{AnyCuArray,AbstractCuSparseArray}, - idx::Union{AnyCuArray,AbstractCuSparseArray}) +function NNlib.scatter!(op::typeof(mean), dst::AnyCuArray, + src::AnyCuArray, + idx::AnyCuArray) Ns = NNlib.scatter!(+, zero(dst), one.(src), idx) dst_ = NNlib.scatter!(+, zero(dst), src, idx) dst .+= NNlib.safe_div.(dst_, Ns) @@ -166,21 +176,21 @@ function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:Ca end function NNlib.∇scatter_src(op::Union{typeof(*),typeof(/)}, Δ, dst, - src::Union{AnyCuArray{Tsrc,Nsrc},AbstractCuSparseArray}, - idx::Union{AnyCuArray{Tidx,Nidx},AbstractCuSparseArray}) where {Tsrc,Tidx,Nsrc,Nidx} - dims = Nsrc - Nidx + src::AnyCuArray, + idx::AnyCuArray) + dims = ndims(src) - ndims(idx) Δsrc = NNlib.modify_src(op, NNlib.gather(Δ, idx), src) rev_idx = NNlib.reverse_indices(idx) rev_idx = CuArray(map(CUDA.cudaconvert, rev_idx)) if dims == 0 max_idx = length(idx) - args = op, Δsrc, src, idx, rev_idx, max_idx, Tsrc + args = op, Δsrc, src, idx, rev_idx, max_idx, eltype(src) else pre_cart_idx = CartesianIndices(axes(src)[1:dims]) max_dims_idx = length(pre_cart_idx) max_idx = max_dims_idx * length(idx) - args = op, Δsrc, src, idx, rev_idx, pre_cart_idx, max_dims_idx, max_idx, Tsrc + args = op, Δsrc, src, idx, rev_idx, pre_cart_idx, max_dims_idx, max_idx, eltype(src) end kernel = @cuda launch=false ∇scatter_src_kernel!(args...) diff --git a/test/ext_cuda/scatter.jl b/test/ext_cuda/scatter.jl index 10b660e40..a4977f285 100644 --- a/test/ext_cuda/scatter.jl +++ b/test/ext_cuda/scatter.jl @@ -21,7 +21,7 @@ idxs = [ (3,) (5,) (5,) (3,)])), # CartesianIndex index ] -types = [CuArray{Int32}, CuArray{Int64}, CuArray{Float32}, CuArray{Float64}, CuSparseMatrixCSC{Float32}, CuSparseMatrixCSR{Float32}, CuSparseMatrixCOO{Float32}] +types = [CuArray{Int32}, CuArray{Int64}, CuArray{Float32}, CuArray{Float64}] @testset "scatter" begin @@ -70,7 +70,7 @@ types = [CuArray{Int32}, CuArray{Int64}, CuArray{Float32}, CuArray{Float64}, CuS end - for T = [CuArray{Float32}, CuArray{Float64}, Sparse, CuSparseMatrixCSC{Float32}, CuSparseMatrixCSR{Float32}, CuSparseMatrixCOO{Float32}] + for T = [CuArray{Float32}, CuArray{Float64}] @testset "$(T)" begin @testset "*" begin for idx = idxs, dims = [0, 1]