diff --git a/ext/NNlibCUDAExt/scatter.jl b/ext/NNlibCUDAExt/scatter.jl index 04a63805..f9e3b6f1 100644 --- a/ext/NNlibCUDAExt/scatter.jl +++ b/ext/NNlibCUDAExt/scatter.jl @@ -1,4 +1,5 @@ # supported op: +, -, *, /, max, min, &, |, mean +import CUDA.CUSPARSE: AbstractCuSparseArray ## TODO support sparse dst/src/idx ## See issue https://github.com/FluxML/NNlib.jl/issues/647 @@ -54,10 +55,9 @@ function scatter_kernel!(op::OP, dst, src, idx::CUDA.CuDeviceArray{<:CartesianIn return nothing end - -function NNlib.scatter!(op::OP, dst::AnyCuArray, - src::AnyCuArray, - idx::AnyCuArray) where OP +function NNlib.scatter!(op::OP, dst::Union{AnyCuArray,AbstractCuSparseArray}, + src::Union{AnyCuArray,AbstractCuSparseArray}, + idx::Union{AnyCuArray,AbstractCuSparseArray}) where OP isempty(idx) && return dst dims = NNlib.scatter_dims(dst, src, idx) args = if dims == 0 @@ -78,9 +78,9 @@ function NNlib.scatter!(op::OP, dst::AnyCuArray, return dst end -function NNlib.scatter!(op::typeof(mean), dst::AnyCuArray, - src::AnyCuArray, - idx::AnyCuArray) +function NNlib.scatter!(op::typeof(mean), dst::Union{AnyCuArray,AbstractCuSparseArray}, + src::Union{AnyCuArray,AbstractCuSparseArray}, + idx::Union{AnyCuArray,AbstractCuSparseArray}) Ns = NNlib.scatter!(+, zero(dst), one.(src), idx) dst_ = NNlib.scatter!(+, zero(dst), src, idx) dst .+= NNlib.safe_div.(dst_, Ns) @@ -177,8 +177,8 @@ function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:Ca end function NNlib.∇scatter_src(op::Union{typeof(*),typeof(/)}, Δ, dst, - src::AnyCuArray, - idx::AnyCuArray) + src::Union{AnyCuArray{Tsrc,Nsrc},AbstractCuSparseArray}, + idx::Union{AnyCuArray{Tidx,Nidx},AbstractCuSparseArray}) where {Tsrc,Tidx,Nsrc,Nidx} dims = ndims(src) - ndims(idx) Δsrc = NNlib.modify_src(op, NNlib.gather(Δ, idx), src) rev_idx = NNlib.reverse_indices(idx) diff --git a/test/ext_cuda/scatter.jl b/test/ext_cuda/scatter.jl index a4977f28..421099aa 100644 --- a/test/ext_cuda/scatter.jl +++ b/test/ext_cuda/scatter.jl @@ -1,13 +1,13 @@ dsts = Dict( 0 => cu([3, 4, 5, 6, 7]), 1 => cu([3 3 4 4 5; - 5 5 6 6 7]), + 5 5 6 6 7]), ) srcs = Dict( (0, true) => cu(ones(Int, 3, 4)), (0, false) => cu(ones(Int, 3) * collect(1:4)'), (1, true) => cu(ones(Int, 2, 3, 4)), - (1, false) => cu([1, 2] .* reshape(ones(Int, 3) * collect(1:4)', 1,3,4)), + (1, false) => cu([1, 2] .* reshape(ones(Int, 3) * collect(1:4)', 1, 3, 4)), ) idxs = [ cu([1 2 3 4; @@ -21,7 +21,7 @@ idxs = [ (3,) (5,) (5,) (3,)])), # CartesianIndex index ] -types = [CuArray{Int32}, CuArray{Int64}, CuArray{Float32}, CuArray{Float64}] +types = [CuArray{Int32}, CuArray{Int64}, CuArray{Float32}, CuArray{Float64}, CuSparseMatrixCSC{Float32}, CuSparseMatrixCSR{Float32}, CuSparseMatrixCOO{Float32}] @testset "scatter" begin @@ -70,6 +70,108 @@ types = [CuArray{Int32}, CuArray{Int64}, CuArray{Float32}, CuArray{Float64}] end + # Specialized sparse scatter kernels. Duplicated as test cases above do not cover sparse arrays. + dsts_sp = Dict( + 0 => cu(sparse([3, 4, 5, 6, 7])), + 1 => cu(sparse([3 3 4 4 5; + 5 5 6 6 7])), + ) + srcs_sp = Dict( + (0, true) => cu(sparse(ones(Int, 3, 4))), + (0, false) => cu(sparse(ones(Int, 3) * collect(1:4)')), + # No sparse equivalent for 3D arrays + ) + types_sp = [ + CuSparseMatrixCSC{Int32}, CuSparseMatrixCSC{Int64}, CuSparseMatrixCSC{Float32}, CuSparseMatrixCSC{Float64}, + CuSparseMatrixCSR{Int32}, CuSparseMatrixCSR{Int64}, CuSparseMatrixCSR{Float32}, CuSparseMatrixCSR{Float64}, + CuSparseMatrixCOO{Int32}, CuSparseMatrixCOO{Int64}, CuSparseMatrixCOO{Float32}, CuSparseMatrixCOO{Float64} + ] + + @testset "scatter sparse-specialized" begin + for T = types_sp + @testset "$(T)" begin + @testset "+" begin + # Dims is implicitly 0. No sparse equivant for multidimensional src/dst + for idx = idxs + mutated = true + gputest((dst, src) -> NNlib.scatter!(+, dst, src, idx), T(copy(dsts[0])), T(srcs[(0, mutated)]), checkgrad=true) + + mutated = false + gputest(src -> NNlib.scatter(+, src, idx), T(srcs[(0, mutated)]), checkgrad=true) + end + end + + @testset "-" begin + for idx = idxs + mutated = true + gputest((dst, src) -> NNlib.scatter!(-, dst, src, idx), T(copy(dsts[0])), T(srcs[(0, mutated)]), checkgrad=true) + + mutated = false + gputest(src -> NNlib.scatter(-, src, idx), T(srcs[(0, mutated)]), checkgrad=true) + end + end + + @testset "max" begin + for idx = idxs + mutated = true + gputest((dst, src) -> NNlib.scatter!(max, dst, src, idx), T(copy(dsts[0])), T(srcs[(0, mutated)]), checkgrad=true) + + mutated = false + gputest(src -> NNlib.scatter(max, src, idx), T(srcs[(0, mutated)]), checkgrad=true) + end + end + + @testset "min" begin + for idx = idxs + mutated = true + gputest((dst, src) -> NNlib.scatter!(min, dst, src, idx), T(copy(dsts[0])), T(srcs[(0, mutated)]), checkgrad=true) + + mutated = false + gputest(src -> NNlib.scatter(min, src, idx), T(srcs[(0, mutated)]), checkgrad=true) + end + end + end + end + + # Sparse-specialized for operations not tested on eltype <: Integer + for T = [CuSparseMatrixCSC{Float32}, CuSparseMatrixCSC{Float64}, CuSparseMatrixCSR{Float32}, CuSparseMatrixCSR{Float64}, CuSparseMatrixCOO{Float32}, CuSparseMatrixCOO{Float64}] + @testset "$(T)" begin + # Dims is implicitly 0. No sparse equivant for multidimensional src/dst + @testset "*" begin + for idx = idxs + mutated = true + gputest((dst, src) -> NNlib.scatter!(*, dst, src, idx), T(copy(dsts[0])), T(srcs[(0, mutated)]), checkgrad=true) + + mutated = false + gputest(src -> NNlib.scatter(*, src, idx), T(srcs[(0, mutated)]), checkgrad=true) + end + end + + @testset "/" begin + for idx = idxs, dims = [0, 1] + mutated = true + gputest((dst, src) -> NNlib.scatter!(/, dst, src, idx), T(copy(dsts[0])), T(srcs[(0, mutated)]), checkgrad=true) + + mutated = false + gputest(src -> NNlib.scatter(/, src, idx), T(srcs[(0, mutated)]), checkgrad=true) + end + end + + @testset "mean" begin + for idx = idxs, dims = [0, 1] + mutated = true + gputest((dst, src) -> NNlib.scatter!(mean, dst, src, idx), T(copy(dsts[0])), T(srcs[(0, mutated)]), checkgrad=true) + + mutated = false + gputest(src -> NNlib.scatter(mean, src, idx), T(srcs[(0, mutated)]), checkgrad=true) + end + end + end + end + end + + + for T = [CuArray{Float32}, CuArray{Float64}] @testset "$(T)" begin @testset "*" begin