11# supported op: +, -, *, /, max, min, &, |, mean
2- import CUDA. CUSPARSE: AbstractCuSparseArray
2+
3+ # # TODO support sparse dst/src/idx
4+ # # See issue https://github.com/FluxML/NNlib.jl/issues/647
5+ # import CUDA.CUSPARSE: AbstractCuSparseMatrix, CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixCOO, AnyCuSparseVector
6+ # const AnyCuSparseMatrix{Tv,Ti} = Union{
7+ # AbstractCuSparseMatrix{Tv,Ti},
8+ # CUDA.CuSparseMatrixCSC{Tv,Ti}, # these types do not inherit from AbstractCuSparseMatrix
9+ # CUDA.CuSparseMatrixCSR{Tv,Ti}, # but from GPUArrays.AbstractGPUSparseMatrixXXX
10+ # CUDA.CuSparseMatrixCOO{Tv,Ti},
11+ # }
12+ # const AnyCuSparseArray{Tv,Ti} = Union{AnyCuSparseVector{Tv,Ti},AnyCuSparseMatrix{Tv,Ti}}
313
414function scatter_kernel!(op:: OP , dst, src, idx) where OP
515 index = threadIdx(). x + (blockIdx(). x - 1 ) * blockDim(). x
@@ -45,9 +55,9 @@ function scatter_kernel!(op::OP, dst, src, idx::CUDA.CuDeviceArray{<:CartesianIn
4555end
4656
4757
48- function NNlib. scatter!(op:: OP , dst:: Union{ AnyCuArray,AbstractCuSparseArray} ,
49- src:: Union{ AnyCuArray,AbstractCuSparseArray} ,
50- idx:: Union{ AnyCuArray,AbstractCuSparseArray} ) where OP
58+ function NNlib. scatter!(op:: OP , dst:: AnyCuArray ,
59+ src:: AnyCuArray ,
60+ idx:: AnyCuArray ) where OP
5161 dims = NNlib. scatter_dims(dst, src, idx)
5262 args = if dims == 0
5363 max_idx = length(idx)
@@ -67,9 +77,9 @@ function NNlib.scatter!(op::OP, dst::Union{AnyCuArray,AbstractCuSparseArray},
6777 return dst
6878end
6979
70- function NNlib. scatter!(op:: typeof (mean), dst:: Union{ AnyCuArray,AbstractCuSparseArray} ,
71- src:: Union{ AnyCuArray,AbstractCuSparseArray} ,
72- idx:: Union{ AnyCuArray,AbstractCuSparseArray} )
80+ function NNlib. scatter!(op:: typeof (mean), dst:: AnyCuArray ,
81+ src:: AnyCuArray ,
82+ idx:: AnyCuArray )
7383 Ns = NNlib. scatter!(+ , zero(dst), one.(src), idx)
7484 dst_ = NNlib. scatter!(+ , zero(dst), src, idx)
7585 dst .+ = NNlib. safe_div.(dst_, Ns)
@@ -166,21 +176,21 @@ function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:Ca
166176end
167177
168178function NNlib.∇scatter_src(op:: Union{typeof(*),typeof(/)} , Δ, dst,
169- src:: Union{ AnyCuArray{Tsrc,Nsrc},AbstractCuSparseArray} ,
170- idx:: Union{ AnyCuArray{Tidx,Nidx},AbstractCuSparseArray} ) where {Tsrc,Tidx,Nsrc,Nidx}
171- dims = Nsrc - Nidx
179+ src:: AnyCuArray ,
180+ idx:: AnyCuArray )
181+ dims = ndims(src) - ndims(idx)
172182 Δsrc = NNlib. modify_src(op, NNlib. gather(Δ, idx), src)
173183 rev_idx = NNlib. reverse_indices(idx)
174184 rev_idx = CuArray(map(CUDA. cudaconvert, rev_idx))
175185
176186 if dims == 0
177187 max_idx = length(idx)
178- args = op, Δsrc, src, idx, rev_idx, max_idx, Tsrc
188+ args = op, Δsrc, src, idx, rev_idx, max_idx, eltype(src)
179189 else
180190 pre_cart_idx = CartesianIndices(axes(src)[1 : dims])
181191 max_dims_idx = length(pre_cart_idx)
182192 max_idx = max_dims_idx * length(idx)
183- args = op, Δsrc, src, idx, rev_idx, pre_cart_idx, max_dims_idx, max_idx, Tsrc
193+ args = op, Δsrc, src, idx, rev_idx, pre_cart_idx, max_dims_idx, max_idx, eltype(src)
184194 end
185195
186196 kernel = @cuda launch= false ∇scatter_src_kernel!(args... )
0 commit comments