Skip to content

Commit 51c745f

Browse files
committed
Unthunk each element in ∇eachslice
1 parent e055009 commit 51c745f

File tree

1 file changed

+1
-5
lines changed

1 file changed

+1
-5
lines changed

src/rulesets/Base/indexing.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -262,11 +262,7 @@ end
262262
# Using Val(dim) here is worth a factor of 2 in this, on Julia 1.8-
263263
# @btime rrule(eachcol, $([1 2; 3 4]))[2]($([[10, 20], [30, 40]]))
264264
function ∇eachslice(dys_raw, x::AbstractArray, vd::Val{dim}) where {dim}
265-
dys = unthunk(dys_raw)
266-
i1 = findfirst(dy -> dy isa AbstractArray, dys)
267-
if i1 === nothing # all slices are Zero!
268-
return _zero_fill!(similar(x, float(eltype(x)), axes(x)))
269-
end
265+
dys = unthunk.(unthunk(dys_raw))
270266
T = Base.promote_eltype(dys...)
271267
# The whole point of this gradient is that we can allocate one `dx` array:
272268
dx = similar(x, T, axes(x))

0 commit comments

Comments
 (0)