Skip to content

Commit f696793

Browse files
revert 648 (#670)
1 parent c15dd3b commit f696793

File tree

2 files changed

+24
-14
lines changed

2 files changed

+24
-14
lines changed

ext/NNlibCUDAExt/scatter.jl

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
11
# supported op: +, -, *, /, max, min, &, |, mean
2-
import CUDA.CUSPARSE: AbstractCuSparseArray
2+
3+
## TODO support sparse dst/src/idx
4+
## See issue https://github.com/FluxML/NNlib.jl/issues/647
5+
# import CUDA.CUSPARSE: AbstractCuSparseMatrix, CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixCOO, AnyCuSparseVector
6+
# const AnyCuSparseMatrix{Tv,Ti} = Union{
7+
# AbstractCuSparseMatrix{Tv,Ti},
8+
# CUDA.CuSparseMatrixCSC{Tv,Ti}, # these types do not inherit from AbstractCuSparseMatrix
9+
# CUDA.CuSparseMatrixCSR{Tv,Ti}, # but from GPUArrays.AbstractGPUSparseMatrixXXX
10+
# CUDA.CuSparseMatrixCOO{Tv,Ti},
11+
# }
12+
# const AnyCuSparseArray{Tv,Ti} = Union{AnyCuSparseVector{Tv,Ti},AnyCuSparseMatrix{Tv,Ti}}
313

414
function scatter_kernel!(op::OP, dst, src, idx) where OP
515
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
@@ -45,9 +55,9 @@ function scatter_kernel!(op::OP, dst, src, idx::CUDA.CuDeviceArray{<:CartesianIn
4555
end
4656

4757

48-
function NNlib.scatter!(op::OP, dst::Union{AnyCuArray,AbstractCuSparseArray},
49-
src::Union{AnyCuArray,AbstractCuSparseArray},
50-
idx::Union{AnyCuArray,AbstractCuSparseArray}) where OP
58+
function NNlib.scatter!(op::OP, dst::AnyCuArray,
59+
src::AnyCuArray,
60+
idx::AnyCuArray) where OP
5161
dims = NNlib.scatter_dims(dst, src, idx)
5262
args = if dims == 0
5363
max_idx = length(idx)
@@ -67,9 +77,9 @@ function NNlib.scatter!(op::OP, dst::Union{AnyCuArray,AbstractCuSparseArray},
6777
return dst
6878
end
6979

70-
function NNlib.scatter!(op::typeof(mean), dst::Union{AnyCuArray,AbstractCuSparseArray},
71-
src::Union{AnyCuArray,AbstractCuSparseArray},
72-
idx::Union{AnyCuArray,AbstractCuSparseArray})
80+
function NNlib.scatter!(op::typeof(mean), dst::AnyCuArray,
81+
src::AnyCuArray,
82+
idx::AnyCuArray)
7383
Ns = NNlib.scatter!(+, zero(dst), one.(src), idx)
7484
dst_ = NNlib.scatter!(+, zero(dst), src, idx)
7585
dst .+= NNlib.safe_div.(dst_, Ns)
@@ -166,21 +176,21 @@ function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:Ca
166176
end
167177

168178
function NNlib.∇scatter_src(op::Union{typeof(*),typeof(/)}, Δ, dst,
169-
src::Union{AnyCuArray{Tsrc,Nsrc},AbstractCuSparseArray},
170-
idx::Union{AnyCuArray{Tidx,Nidx},AbstractCuSparseArray}) where {Tsrc,Tidx,Nsrc,Nidx}
171-
dims = Nsrc - Nidx
179+
src::AnyCuArray,
180+
idx::AnyCuArray)
181+
dims = ndims(src) - ndims(idx)
172182
Δsrc = NNlib.modify_src(op, NNlib.gather(Δ, idx), src)
173183
rev_idx = NNlib.reverse_indices(idx)
174184
rev_idx = CuArray(map(CUDA.cudaconvert, rev_idx))
175185

176186
if dims == 0
177187
max_idx = length(idx)
178-
args = op, Δsrc, src, idx, rev_idx, max_idx, Tsrc
188+
args = op, Δsrc, src, idx, rev_idx, max_idx, eltype(src)
179189
else
180190
pre_cart_idx = CartesianIndices(axes(src)[1:dims])
181191
max_dims_idx = length(pre_cart_idx)
182192
max_idx = max_dims_idx * length(idx)
183-
args = op, Δsrc, src, idx, rev_idx, pre_cart_idx, max_dims_idx, max_idx, Tsrc
193+
args = op, Δsrc, src, idx, rev_idx, pre_cart_idx, max_dims_idx, max_idx, eltype(src)
184194
end
185195

186196
kernel = @cuda launch=false ∇scatter_src_kernel!(args...)

test/ext_cuda/scatter.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ idxs = [
2121
(3,) (5,) (5,) (3,)])), # CartesianIndex index
2222
]
2323

24-
types = [CuArray{Int32}, CuArray{Int64}, CuArray{Float32}, CuArray{Float64}, CuSparseMatrixCSC{Float32}, CuSparseMatrixCSR{Float32}, CuSparseMatrixCOO{Float32}]
24+
types = [CuArray{Int32}, CuArray{Int64}, CuArray{Float32}, CuArray{Float64}]
2525

2626

2727
@testset "scatter" begin
@@ -70,7 +70,7 @@ types = [CuArray{Int32}, CuArray{Int64}, CuArray{Float32}, CuArray{Float64}, CuS
7070
end
7171

7272

73-
for T = [CuArray{Float32}, CuArray{Float64}, Sparse, CuSparseMatrixCSC{Float32}, CuSparseMatrixCSR{Float32}, CuSparseMatrixCOO{Float32}]
73+
for T = [CuArray{Float32}, CuArray{Float64}]
7474
@testset "$(T)" begin
7575
@testset "*" begin
7676
for idx = idxs, dims = [0, 1]

0 commit comments

Comments
 (0)