Skip to content

Commit 5b61724

Browse files
authored
Merge pull request #1491 from lxvm/zerod
Preserve 0d arrays in `Zygote.accum`
2 parents 392a1f9 + 65a32f5 commit 5b61724

File tree

2 files changed

+2
-1
lines changed

2 files changed

+2
-1
lines changed

src/lib/lib.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ accum(x, y) =
2222
accum(x, y, zs...) = accum(accum(x, y), zs...)
2323

2424
accum(x::Tuple, ys::Tuple...) = map(accum, x, ys...)
25-
accum(x::AbstractArray, ys::AbstractArray...) = accum.(x, ys...)
25+
accum(x::AbstractArray, ys::AbstractArray...) = Base.broadcast_preserving_zero_d(accum, x, ys...)
2626

2727
@generated function accum(x::NamedTuple, y::NamedTuple)
2828
# assumes that y has no keys apart from those also in x

test/lib/lib.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,6 @@
44
t2 = (a=1, b=2)
55
@test Zygote.accum(t1, t2) == (a = 2, b = 4, c = 3)
66
@test_throws ArgumentError Zygote.accum(t2, t1)
7+
@test Zygote.accum(fill(0.0), fill(0.0)) == fill(0.0)
78
end
89
end

0 commit comments

Comments
 (0)