Skip to content

Commit 9bc217c

Browse files
committed
Make recursive_add/accumulate more recursive
1 parent 6a19be2 commit 9bc217c

File tree

1 file changed

+15
-6
lines changed

1 file changed

+15
-6
lines changed

src/compiler.jl

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6562,7 +6562,13 @@ end
65626562
Base.@_inline_meta
65636563
prev = getfield(x, i)
65646564
next = getfield(y, i)
6565-
recursive_add(prev, next, f, forcelhs)
6565+
ST = Core.Typeof(prev)
6566+
if !mutable_register(ST)
6567+
recursive_add(prev, next, f, forcelhs)
6568+
elseif !(ST <: Integer)
6569+
recursive_accumulate(prev, next, f)
6570+
prev
6571+
end
65666572
end)
65676573
end
65686574

@@ -6591,18 +6597,19 @@ end
65916597

65926598
# Recursively In-place accumulate(aka +=). E.g. generalization of x .+= f(y)
65936599
@inline function recursive_accumulate(x::Array{T}, y::Array{T}, f::F=identity) where {T, F}
6594-
if !mutable_register(T)
6595-
for I in eachindex(x)
6596-
prev = x[I]
6600+
for I in eachindex(x, y)
6601+
if !mutable_register(T)
65976602
@inbounds x[I] = recursive_add(x[I], (@inbounds y[I]), f, mutable_register)
6603+
elseif !(T <: Integer)
6604+
recursive_accumulate((@inbounds x[I]), (@inbounds y[I]), f)
65986605
end
65996606
end
66006607
end
66016608

66026609

66036610
# Recursively In-place accumulate(aka +=). E.g. generalization of x .+= f(y)
66046611
@inline function recursive_accumulate(x::Core.Box, y::Core.Box, f::F=identity) where {F}
6605-
recursive_accumulate(x.contents, y.contents, seen, f)
6612+
recursive_accumulate(x.contents, y.contents, f)
66066613
end
66076614

66086615
@inline function recursive_accumulate(x::T, y::T, f::F=identity) where {T, F}
@@ -6613,12 +6620,14 @@ end
66136620
for i in 1:nf
66146621
if isdefined(x, i)
66156622
xi = getfield(x, i)
6623+
yi = getfield(y, i)
66166624
ST = Core.Typeof(xi)
66176625
if !mutable_register(ST)
66186626
@assert ismutable(x)
6619-
yi = getfield(y, i)
66206627
nexti = recursive_add(xi, yi, f, mutable_register)
66216628
setfield!(x, i, nexti)
6629+
elseif !(ST <: Integer)
6630+
recursive_accumulate(xi, yi, f)
66226631
end
66236632
end
66246633
end

0 commit comments

Comments
 (0)