Skip to content

Commit 4fbdc47

Browse files
committed
Reorder methods for easier reading
1 parent bd1cca0 commit 4fbdc47

File tree

1 file changed

+143
-143
lines changed

1 file changed

+143
-143
lines changed

src/make_zero.jl

Lines changed: 143 additions & 143 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,3 @@
1-
const _RealOrComplexFloat = Union{AbstractFloat,Complex{<:AbstractFloat}}
2-
3-
@inline function EnzymeCore.make_zero(prev::FT) where {FT<:_RealOrComplexFloat}
4-
return Base.zero(prev)::FT
5-
end
6-
7-
@inline function EnzymeCore.make_zero(
8-
::Type{FT},
9-
@nospecialize(seen::IdDict),
10-
prev::FT,
11-
@nospecialize(_::Val{copy_if_inactive}=Val(false)),
12-
) where {FT<:_RealOrComplexFloat,copy_if_inactive}
13-
return EnzymeCore.make_zero(prev)::FT
14-
end
15-
16-
@inline function EnzymeCore.make_zero(prev::Array{FT,N}) where {FT<:_RealOrComplexFloat,N}
17-
# convert because Base.zero may return different eltype when FT is not concrete
18-
return convert(Array{FT,N}, Base.zero(prev))::Array{FT,N}
19-
end
20-
21-
@inline function EnzymeCore.make_zero(
22-
::Type{Array{FT,N}},
23-
seen::IdDict,
24-
prev::Array{FT,N},
25-
@nospecialize(_::Val{copy_if_inactive}=Val(false)),
26-
) where {FT<:_RealOrComplexFloat,N,copy_if_inactive}
27-
if haskey(seen, prev)
28-
return seen[prev]::Array{FT,N}
29-
end
30-
newa = EnzymeCore.make_zero(prev)
31-
seen[prev] = newa
32-
return newa::Array{FT,N}
33-
end
34-
35-
@inline function EnzymeCore.make_zero(
36-
::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false)
37-
) where {RT,copy_if_inactive}
38-
isleaftype(_) = false
39-
isleaftype(::Type{<:Union{_RealOrComplexFloat,Array{<:_RealOrComplexFloat}}}) = true
40-
f(p) = EnzymeCore.make_zero(Core.Typeof(p), seen, p, Val(copy_if_inactive))
41-
return recursive_map(RT, f, seen, (prev,), Val(copy_if_inactive), isleaftype)::RT
42-
end
43-
441
recursive_map(f::F, xs::T...) where {F,T} = recursive_map(T, f, IdDict(), xs)::T
452

463
@inline function recursive_map(
@@ -59,24 +16,6 @@ recursive_map(f::F, xs::T...) where {F,T} = recursive_map(T, f, IdDict(), xs)::T
5916
return _recursive_map(RT, f, seen, xs, Val(copy_if_inactive), isleaftype)::RT
6017
end
6118

62-
@inline function _recursive_map(
63-
::Type{RT}, f::F, seen::IdDict, xs::NTuple{N,RT}, args...
64-
) where {RT<:Array,F,N}
65-
if haskey(seen, xs)
66-
return seen[xs]::RT
67-
end
68-
y = RT(undef, size(first(xs)))
69-
seen[xs] = y
70-
for I in eachindex(xs...)
71-
if all(x -> isassigned(x, I), xs)
72-
xIs = ntuple(j -> xs[j][I], N)
73-
ST = Core.Typeof(first(xIs))
74-
@inbounds y[I] = recursive_map(ST, f, seen, xIs, args...)
75-
end
76-
end
77-
return y
78-
end
79-
8019
@inline function _recursive_map(
8120
::Type{RT}, f::F, seen::IdDict, xs::NTuple{N,RT}, args...
8221
) where {RT,F,N}
@@ -127,6 +66,103 @@ end
12766
return y
12867
end
12968

69+
@inline function _recursive_map(
70+
::Type{RT}, f::F, seen::IdDict, xs::NTuple{N,RT}, args...
71+
) where {RT<:Array,F,N}
72+
if haskey(seen, xs)
73+
return seen[xs]::RT
74+
end
75+
y = RT(undef, size(first(xs)))
76+
seen[xs] = y
77+
for I in eachindex(xs...)
78+
if all(x -> isassigned(x, I), xs)
79+
xIs = ntuple(j -> xs[j][I], N)
80+
ST = Core.Typeof(first(xIs))
81+
@inbounds y[I] = recursive_map(ST, f, seen, xIs, args...)
82+
end
83+
end
84+
return y
85+
end
86+
87+
@inline function recursive_map!(f::F, y::T, xs::T...) where {F,T}
88+
return recursive_map!(f, y, Base.IdSet(), xs)::Nothing
89+
end
90+
91+
@inline function recursive_map!(
92+
f::F, y::T, seen::Base.IdSet, xs::NTuple{N,T}, isleaftype::L=Returns(false)
93+
) where {F,T,N,L}
94+
if guaranteed_const_nongen(T, nothing)
95+
return nothing
96+
elseif isleaftype(T)
97+
# If there exist T such that isleaftype(T) and T has mutable content that is not
98+
# guaranteed const, including mutables nested inside immutables like Tuple{Vector},
99+
# then f must have a corresponding mutating method:
100+
f(y, xs...)
101+
return nothing
102+
end
103+
return _recursive_map!(f, y, seen, xs, isleaftype)::Nothing
104+
end
105+
106+
@inline function _recursive_map!(
107+
f::F, y::T, seen, xs::NTuple{N,T}, isleaftype
108+
) where {F,T,N}
109+
if y in seen
110+
return nothing
111+
end
112+
@assert !Base.isabstracttype(T)
113+
@assert Base.isconcretetype(T)
114+
nf = fieldcount(T)
115+
if nf == 0
116+
return nothing
117+
end
118+
push!(seen, y)
119+
for i = 1:nf
120+
if isdefined(y, i) && all(x -> isdefined(x, i), xs)
121+
yi = getfield(y, i)
122+
xis = ntuple(j -> getfield(xs[j], i), N)
123+
SBT = Core.Typeof(yi)
124+
activitystate = active_reg_inner(SBT, (), nothing, Val(false))
125+
if activitystate == AnyState
126+
continue
127+
elseif activitystate == DupState
128+
recursive_map!(f, yi, seen, xis, isleaftype)
129+
else
130+
yi = recursive_map_immutable!(f, yi, seen, xis, isleaftype)
131+
if Base.isconst(T, i)
132+
ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), y, i - 1, yi)
133+
else
134+
setfield!(y, i, yi)
135+
end
136+
end
137+
end
138+
end
139+
return nothing
140+
end
141+
142+
@inline function _recursive_map!(
143+
f::F, y::Array{T,M}, seen, xs::NTuple{N,Array{T,M}}, isleaftype
144+
) where {F,T,M,N}
145+
if y in seen
146+
return nothing
147+
end
148+
push!(seen, y)
149+
for I in eachindex(y, xs...)
150+
if isassigned(y, I) && all(x -> isassigned(x, I), xs)
151+
yvalue = y[I]
152+
xvalues = ntuple(j -> xs[j][I], N)
153+
SBT = Core.Typeof(yvalue)
154+
if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=#
155+
@inbounds y[I] = recursive_map_immutable!(
156+
f, yvalue, seen, xvalues, isleaftype
157+
)
158+
else
159+
recursive_map!(f, yvalue, seen, xvalues, isleaftype)
160+
end
161+
end
162+
end
163+
return nothing
164+
end
165+
130166
@inline function recursive_map_immutable!(f::F, y::T, xs::T...) where {F,T}
131167
return recursive_map_immutable!(f, y, Base.IdSet(), xs)::T
132168
end
@@ -185,20 +221,47 @@ end
185221
return newy
186222
end
187223

188-
@inline function EnzymeCore.make_zero!(prev::Array{T,N}) where {T<:_RealOrComplexFloat,N}
189-
fill!(prev, zero(T))
190-
return nothing
224+
const _RealOrComplexFloat = Union{AbstractFloat,Complex{<:AbstractFloat}}
225+
226+
@inline function EnzymeCore.make_zero(
227+
::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false)
228+
) where {RT,copy_if_inactive}
229+
isleaftype(_) = false
230+
isleaftype(::Type{<:Union{_RealOrComplexFloat,Array{<:_RealOrComplexFloat}}}) = true
231+
f(p) = EnzymeCore.make_zero(Core.Typeof(p), seen, p, Val(copy_if_inactive))
232+
return recursive_map(RT, f, seen, (prev,), Val(copy_if_inactive), isleaftype)::RT
191233
end
192234

193-
@inline function EnzymeCore.make_zero!(
194-
prev::Array{T,N}, seen::Base.IdSet,
195-
) where {T<:_RealOrComplexFloat,N}
196-
if prev in seen
197-
return nothing
235+
@inline function EnzymeCore.make_zero(
236+
::Type{FT},
237+
@nospecialize(seen::IdDict),
238+
prev::FT,
239+
@nospecialize(_::Val{copy_if_inactive}=Val(false)),
240+
) where {FT<:_RealOrComplexFloat,copy_if_inactive}
241+
return EnzymeCore.make_zero(prev)::FT
242+
end
243+
244+
@inline function EnzymeCore.make_zero(prev::FT) where {FT<:_RealOrComplexFloat}
245+
return Base.zero(prev)::FT
246+
end
247+
248+
@inline function EnzymeCore.make_zero(
249+
::Type{Array{FT,N}},
250+
seen::IdDict,
251+
prev::Array{FT,N},
252+
@nospecialize(_::Val{copy_if_inactive}=Val(false)),
253+
) where {FT<:_RealOrComplexFloat,N,copy_if_inactive}
254+
if haskey(seen, prev)
255+
return seen[prev]::Array{FT,N}
198256
end
199-
push!(seen, prev)
200-
EnzymeCore.make_zero!(prev)
201-
return nothing
257+
newa = EnzymeCore.make_zero(prev)
258+
seen[prev] = newa
259+
return newa::Array{FT,N}
260+
end
261+
262+
@inline function EnzymeCore.make_zero(prev::Array{FT,N}) where {FT<:_RealOrComplexFloat,N}
263+
# convert because Base.zero may return different eltype when FT is not concrete
264+
return convert(Array{FT,N}, Base.zero(prev))::Array{FT,N}
202265
end
203266

204267
@inline function EnzymeCore.make_zero!(prev, seen::Base.IdSet=Base.IdSet())
@@ -213,81 +276,18 @@ end
213276
return recursive_map!(f, prev, seen, (prev,), isleaftype)::Nothing
214277
end
215278

216-
@inline function recursive_map!(f::F, y::T, xs::T...) where {F,T}
217-
return recursive_map!(f, y, Base.IdSet(), xs)::Nothing
218-
end
219-
220-
@inline function recursive_map!(
221-
f::F, y::T, seen::Base.IdSet, xs::NTuple{N,T}, isleaftype::L=Returns(false)
222-
) where {F,T,N,L}
223-
if guaranteed_const_nongen(T, nothing)
224-
return nothing
225-
elseif isleaftype(T)
226-
# If there exist T such that isleaftype(T) and T has mutable content that is not
227-
# guaranteed const, including mutables nested inside immutables like Tuple{Vector},
228-
# then f must have a corresponding mutating method:
229-
f(y, xs...)
230-
return nothing
231-
end
232-
return _recursive_map!(f, y, seen, xs, isleaftype)::Nothing
233-
end
234-
235-
@inline function _recursive_map!(
236-
f::F, y::Array{T,M}, seen, xs::NTuple{N,Array{T,M}}, isleaftype
237-
) where {F,T,M,N}
238-
if y in seen
279+
@inline function EnzymeCore.make_zero!(
280+
prev::Array{T,N}, seen::Base.IdSet,
281+
) where {T<:_RealOrComplexFloat,N}
282+
if prev in seen
239283
return nothing
240284
end
241-
push!(seen, y)
242-
for I in eachindex(y, xs...)
243-
if isassigned(y, I) && all(x -> isassigned(x, I), xs)
244-
yvalue = y[I]
245-
xvalues = ntuple(j -> xs[j][I], N)
246-
SBT = Core.Typeof(yvalue)
247-
if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=#
248-
@inbounds y[I] = recursive_map_immutable!(
249-
f, yvalue, seen, xvalues, isleaftype
250-
)
251-
else
252-
recursive_map!(f, yvalue, seen, xvalues, isleaftype)
253-
end
254-
end
255-
end
285+
push!(seen, prev)
286+
EnzymeCore.make_zero!(prev)
256287
return nothing
257288
end
258289

259-
@inline function _recursive_map!(
260-
f::F, y::T, seen, xs::NTuple{N,T}, isleaftype
261-
) where {F,T,N}
262-
if y in seen
263-
return nothing
264-
end
265-
@assert !Base.isabstracttype(T)
266-
@assert Base.isconcretetype(T)
267-
nf = fieldcount(T)
268-
if nf == 0
269-
return nothing
270-
end
271-
push!(seen, y)
272-
for i = 1:nf
273-
if isdefined(y, i) && all(x -> isdefined(x, i), xs)
274-
yi = getfield(y, i)
275-
xis = ntuple(j -> getfield(xs[j], i), N)
276-
SBT = Core.Typeof(yi)
277-
activitystate = active_reg_inner(SBT, (), nothing, Val(false))
278-
if activitystate == AnyState
279-
continue
280-
elseif activitystate == DupState
281-
recursive_map!(f, yi, seen, xis, isleaftype)
282-
else
283-
yi = recursive_map_immutable!(f, yi, seen, xis, isleaftype)
284-
if Base.isconst(T, i)
285-
ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), y, i - 1, yi)
286-
else
287-
setfield!(y, i, yi)
288-
end
289-
end
290-
end
291-
end
290+
@inline function EnzymeCore.make_zero!(prev::Array{T,N}) where {T<:_RealOrComplexFloat,N}
291+
fill!(prev, zero(T))
292292
return nothing
293293
end

0 commit comments

Comments
 (0)