Skip to content

Commit bfb9a4b

Browse files
fix: gather and scatter on empty arrays (#668)
* fix: Zero case in gather * test: Test zero-sized array * fix: scatter! skip empty --------- Co-authored-by: Carlo Lucibello <carlo.lucibello@gmail.com>
1 parent f696793 commit bfb9a4b

File tree

3 files changed

+8
-0
lines changed

3 files changed

+8
-0
lines changed

ext/NNlibCUDAExt/scatter.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ end
5858
function NNlib.scatter!(op::OP, dst::AnyCuArray,
5959
src::AnyCuArray,
6060
idx::AnyCuArray) where OP
61+
isempty(idx) && return dst
6162
dims = NNlib.scatter_dims(dst, src, idx)
6263
args = if dims == 0
6364
max_idx = length(idx)

src/gather.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ function gather!(dst::AbstractArray, src::AbstractArray, idx::AbstractArray)
110110
end
111111

112112
function gather!(dst::AnyGPUArray, src::AnyGPUArray, idx::AnyGPUArray)
113+
isempty(dst) && return dst
113114
n_dims = scatter_dims(src, dst, idx)
114115
dims = size(src)[1:n_dims]
115116
max_dims_idx = prod(dims)

test/ext_cuda/gather.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,4 +103,10 @@
103103
outv2 = NNlib.gather(v2, i)
104104
@test collect(outv2) == NNlib.gather(collect(v2), collect(i))
105105
end
106+
107+
# Zero-sized
108+
x = CT([1,2,3])
109+
i = CT(Int[])
110+
y = NNlib.gather(x, i)
111+
@test isempty(y)
106112
end

0 commit comments

Comments
 (0)