diff --git a/ext/NNlibCUDAExt/scatter.jl b/ext/NNlibCUDAExt/scatter.jl index cec16e34..04a63805 100644 --- a/ext/NNlibCUDAExt/scatter.jl +++ b/ext/NNlibCUDAExt/scatter.jl @@ -58,6 +58,7 @@ end function NNlib.scatter!(op::OP, dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArray) where OP + isempty(idx) && return dst dims = NNlib.scatter_dims(dst, src, idx) args = if dims == 0 max_idx = length(idx) diff --git a/src/gather.jl b/src/gather.jl index d75f89a2..7997f878 100644 --- a/src/gather.jl +++ b/src/gather.jl @@ -110,6 +110,7 @@ function gather!(dst::AbstractArray, src::AbstractArray, idx::AbstractArray) end function gather!(dst::AnyGPUArray, src::AnyGPUArray, idx::AnyGPUArray) + isempty(dst) && return dst n_dims = scatter_dims(src, dst, idx) dims = size(src)[1:n_dims] max_dims_idx = prod(dims) diff --git a/test/ext_cuda/gather.jl b/test/ext_cuda/gather.jl index 9fa30efa..1c2b1384 100644 --- a/test/ext_cuda/gather.jl +++ b/test/ext_cuda/gather.jl @@ -103,4 +103,10 @@ outv2 = NNlib.gather(v2, i) @test collect(outv2) == NNlib.gather(collect(v2), collect(i)) end + + # Zero-sized + x = CT([1,2,3]) + i = CT(Int[]) + y = NNlib.gather(x, i) + @test isempty(y) end