1
- const _RealOrComplexFloat = Union{AbstractFloat,Complex{<: AbstractFloat }}
2
-
3
- @inline function EnzymeCore. make_zero (prev:: FT ) where {FT<: _RealOrComplexFloat }
4
- return Base. zero (prev):: FT
5
- end
6
-
7
- @inline function EnzymeCore. make_zero (
8
- :: Type{FT} ,
9
- @nospecialize (seen:: IdDict ),
10
- prev:: FT ,
11
- @nospecialize (_:: Val{copy_if_inactive} = Val (false )),
12
- ) where {FT<: _RealOrComplexFloat ,copy_if_inactive}
13
- return EnzymeCore. make_zero (prev):: FT
14
- end
15
-
16
- @inline function EnzymeCore. make_zero (prev:: Array{FT,N} ) where {FT<: _RealOrComplexFloat ,N}
17
- # convert because Base.zero may return different eltype when FT is not concrete
18
- return convert (Array{FT,N}, Base. zero (prev)):: Array{FT,N}
19
- end
20
-
21
- @inline function EnzymeCore. make_zero (
22
- :: Type{Array{FT,N}} ,
23
- seen:: IdDict ,
24
- prev:: Array{FT,N} ,
25
- @nospecialize (_:: Val{copy_if_inactive} = Val (false )),
26
- ) where {FT<: _RealOrComplexFloat ,N,copy_if_inactive}
27
- if haskey (seen, prev)
28
- return seen[prev]:: Array{FT,N}
29
- end
30
- newa = EnzymeCore. make_zero (prev)
31
- seen[prev] = newa
32
- return newa:: Array{FT,N}
33
- end
34
-
35
- @inline function EnzymeCore. make_zero (
36
- :: Type{RT} , seen:: IdDict , prev:: RT , :: Val{copy_if_inactive} = Val (false )
37
- ) where {RT,copy_if_inactive}
38
- isleaftype (_) = false
39
- isleaftype (:: Type{<:Union{_RealOrComplexFloat,Array{<:_RealOrComplexFloat}}} ) = true
40
- f (p) = EnzymeCore. make_zero (Core. Typeof (p), seen, p, Val (copy_if_inactive))
41
- return recursive_map (RT, f, seen, (prev,), Val (copy_if_inactive), isleaftype):: RT
42
- end
43
-
44
1
recursive_map (f:: F , xs:: T... ) where {F,T} = recursive_map (T, f, IdDict (), xs):: T
45
2
46
3
@inline function recursive_map (
@@ -59,24 +16,6 @@ recursive_map(f::F, xs::T...) where {F,T} = recursive_map(T, f, IdDict(), xs)::T
59
16
return _recursive_map (RT, f, seen, xs, Val (copy_if_inactive), isleaftype):: RT
60
17
end
61
18
62
- @inline function _recursive_map (
63
- :: Type{RT} , f:: F , seen:: IdDict , xs:: NTuple{N,RT} , args...
64
- ) where {RT<: Array ,F,N}
65
- if haskey (seen, xs)
66
- return seen[xs]:: RT
67
- end
68
- y = RT (undef, size (first (xs)))
69
- seen[xs] = y
70
- for I in eachindex (xs... )
71
- if all (x -> isassigned (x, I), xs)
72
- xIs = ntuple (j -> xs[j][I], N)
73
- ST = Core. Typeof (first (xIs))
74
- @inbounds y[I] = recursive_map (ST, f, seen, xIs, args... )
75
- end
76
- end
77
- return y
78
- end
79
-
80
19
@inline function _recursive_map (
81
20
:: Type{RT} , f:: F , seen:: IdDict , xs:: NTuple{N,RT} , args...
82
21
) where {RT,F,N}
127
66
return y
128
67
end
129
68
69
+ @inline function _recursive_map (
70
+ :: Type{RT} , f:: F , seen:: IdDict , xs:: NTuple{N,RT} , args...
71
+ ) where {RT<: Array ,F,N}
72
+ if haskey (seen, xs)
73
+ return seen[xs]:: RT
74
+ end
75
+ y = RT (undef, size (first (xs)))
76
+ seen[xs] = y
77
+ for I in eachindex (xs... )
78
+ if all (x -> isassigned (x, I), xs)
79
+ xIs = ntuple (j -> xs[j][I], N)
80
+ ST = Core. Typeof (first (xIs))
81
+ @inbounds y[I] = recursive_map (ST, f, seen, xIs, args... )
82
+ end
83
+ end
84
+ return y
85
+ end
86
+
87
+ @inline function recursive_map! (f:: F , y:: T , xs:: T... ) where {F,T}
88
+ return recursive_map! (f, y, Base. IdSet (), xs):: Nothing
89
+ end
90
+
91
+ @inline function recursive_map! (
92
+ f:: F , y:: T , seen:: Base.IdSet , xs:: NTuple{N,T} , isleaftype:: L = Returns (false )
93
+ ) where {F,T,N,L}
94
+ if guaranteed_const_nongen (T, nothing )
95
+ return nothing
96
+ elseif isleaftype (T)
97
+ # If there exist T such that isleaftype(T) and T has mutable content that is not
98
+ # guaranteed const, including mutables nested inside immutables like Tuple{Vector},
99
+ # then f must have a corresponding mutating method:
100
+ f (y, xs... )
101
+ return nothing
102
+ end
103
+ return _recursive_map! (f, y, seen, xs, isleaftype):: Nothing
104
+ end
105
+
106
+ @inline function _recursive_map! (
107
+ f:: F , y:: T , seen, xs:: NTuple{N,T} , isleaftype
108
+ ) where {F,T,N}
109
+ if y in seen
110
+ return nothing
111
+ end
112
+ @assert ! Base. isabstracttype (T)
113
+ @assert Base. isconcretetype (T)
114
+ nf = fieldcount (T)
115
+ if nf == 0
116
+ return nothing
117
+ end
118
+ push! (seen, y)
119
+ for i = 1 : nf
120
+ if isdefined (y, i) && all (x -> isdefined (x, i), xs)
121
+ yi = getfield (y, i)
122
+ xis = ntuple (j -> getfield (xs[j], i), N)
123
+ SBT = Core. Typeof (yi)
124
+ activitystate = active_reg_inner (SBT, (), nothing , Val (false ))
125
+ if activitystate == AnyState
126
+ continue
127
+ elseif activitystate == DupState
128
+ recursive_map! (f, yi, seen, xis, isleaftype)
129
+ else
130
+ yi = recursive_map_immutable! (f, yi, seen, xis, isleaftype)
131
+ if Base. isconst (T, i)
132
+ ccall (:jl_set_nth_field , Cvoid, (Any, Csize_t, Any), y, i - 1 , yi)
133
+ else
134
+ setfield! (y, i, yi)
135
+ end
136
+ end
137
+ end
138
+ end
139
+ return nothing
140
+ end
141
+
142
+ @inline function _recursive_map! (
143
+ f:: F , y:: Array{T,M} , seen, xs:: NTuple{N,Array{T,M}} , isleaftype
144
+ ) where {F,T,M,N}
145
+ if y in seen
146
+ return nothing
147
+ end
148
+ push! (seen, y)
149
+ for I in eachindex (y, xs... )
150
+ if isassigned (y, I) && all (x -> isassigned (x, I), xs)
151
+ yvalue = y[I]
152
+ xvalues = ntuple (j -> xs[j][I], N)
153
+ SBT = Core. Typeof (yvalue)
154
+ if active_reg_inner (SBT, (), nothing , Val (true )) == ActiveState #= justActive=#
155
+ @inbounds y[I] = recursive_map_immutable! (
156
+ f, yvalue, seen, xvalues, isleaftype
157
+ )
158
+ else
159
+ recursive_map! (f, yvalue, seen, xvalues, isleaftype)
160
+ end
161
+ end
162
+ end
163
+ return nothing
164
+ end
165
+
130
166
@inline function recursive_map_immutable! (f:: F , y:: T , xs:: T... ) where {F,T}
131
167
return recursive_map_immutable! (f, y, Base. IdSet (), xs):: T
132
168
end
@@ -185,20 +221,47 @@ end
185
221
return newy
186
222
end
187
223
188
- @inline function EnzymeCore. make_zero! (prev:: Array{T,N} ) where {T<: _RealOrComplexFloat ,N}
189
- fill! (prev, zero (T))
190
- return nothing
224
+ const _RealOrComplexFloat = Union{AbstractFloat,Complex{<: AbstractFloat }}
225
+
226
+ @inline function EnzymeCore. make_zero (
227
+ :: Type{RT} , seen:: IdDict , prev:: RT , :: Val{copy_if_inactive} = Val (false )
228
+ ) where {RT,copy_if_inactive}
229
+ isleaftype (_) = false
230
+ isleaftype (:: Type{<:Union{_RealOrComplexFloat,Array{<:_RealOrComplexFloat}}} ) = true
231
+ f (p) = EnzymeCore. make_zero (Core. Typeof (p), seen, p, Val (copy_if_inactive))
232
+ return recursive_map (RT, f, seen, (prev,), Val (copy_if_inactive), isleaftype):: RT
191
233
end
192
234
193
- @inline function EnzymeCore. make_zero! (
194
- prev:: Array{T,N} , seen:: Base.IdSet ,
195
- ) where {T<: _RealOrComplexFloat ,N}
196
- if prev in seen
197
- return nothing
235
+ @inline function EnzymeCore. make_zero (
236
+ :: Type{FT} ,
237
+ @nospecialize (seen:: IdDict ),
238
+ prev:: FT ,
239
+ @nospecialize (_:: Val{copy_if_inactive} = Val (false )),
240
+ ) where {FT<: _RealOrComplexFloat ,copy_if_inactive}
241
+ return EnzymeCore. make_zero (prev):: FT
242
+ end
243
+
244
+ @inline function EnzymeCore. make_zero (prev:: FT ) where {FT<: _RealOrComplexFloat }
245
+ return Base. zero (prev):: FT
246
+ end
247
+
248
+ @inline function EnzymeCore. make_zero (
249
+ :: Type{Array{FT,N}} ,
250
+ seen:: IdDict ,
251
+ prev:: Array{FT,N} ,
252
+ @nospecialize (_:: Val{copy_if_inactive} = Val (false )),
253
+ ) where {FT<: _RealOrComplexFloat ,N,copy_if_inactive}
254
+ if haskey (seen, prev)
255
+ return seen[prev]:: Array{FT,N}
198
256
end
199
- push! (seen, prev)
200
- EnzymeCore. make_zero! (prev)
201
- return nothing
257
+ newa = EnzymeCore. make_zero (prev)
258
+ seen[prev] = newa
259
+ return newa:: Array{FT,N}
260
+ end
261
+
262
+ @inline function EnzymeCore. make_zero (prev:: Array{FT,N} ) where {FT<: _RealOrComplexFloat ,N}
263
+ # convert because Base.zero may return different eltype when FT is not concrete
264
+ return convert (Array{FT,N}, Base. zero (prev)):: Array{FT,N}
202
265
end
203
266
204
267
@inline function EnzymeCore. make_zero! (prev, seen:: Base.IdSet = Base. IdSet ())
@@ -213,81 +276,18 @@ end
213
276
return recursive_map! (f, prev, seen, (prev,), isleaftype):: Nothing
214
277
end
215
278
216
- @inline function recursive_map! (f:: F , y:: T , xs:: T... ) where {F,T}
217
- return recursive_map! (f, y, Base. IdSet (), xs):: Nothing
218
- end
219
-
220
- @inline function recursive_map! (
221
- f:: F , y:: T , seen:: Base.IdSet , xs:: NTuple{N,T} , isleaftype:: L = Returns (false )
222
- ) where {F,T,N,L}
223
- if guaranteed_const_nongen (T, nothing )
224
- return nothing
225
- elseif isleaftype (T)
226
- # If there exist T such that isleaftype(T) and T has mutable content that is not
227
- # guaranteed const, including mutables nested inside immutables like Tuple{Vector},
228
- # then f must have a corresponding mutating method:
229
- f (y, xs... )
230
- return nothing
231
- end
232
- return _recursive_map! (f, y, seen, xs, isleaftype):: Nothing
233
- end
234
-
235
- @inline function _recursive_map! (
236
- f:: F , y:: Array{T,M} , seen, xs:: NTuple{N,Array{T,M}} , isleaftype
237
- ) where {F,T,M,N}
238
- if y in seen
279
+ @inline function EnzymeCore. make_zero! (
280
+ prev:: Array{T,N} , seen:: Base.IdSet ,
281
+ ) where {T<: _RealOrComplexFloat ,N}
282
+ if prev in seen
239
283
return nothing
240
284
end
241
- push! (seen, y)
242
- for I in eachindex (y, xs... )
243
- if isassigned (y, I) && all (x -> isassigned (x, I), xs)
244
- yvalue = y[I]
245
- xvalues = ntuple (j -> xs[j][I], N)
246
- SBT = Core. Typeof (yvalue)
247
- if active_reg_inner (SBT, (), nothing , Val (true )) == ActiveState #= justActive=#
248
- @inbounds y[I] = recursive_map_immutable! (
249
- f, yvalue, seen, xvalues, isleaftype
250
- )
251
- else
252
- recursive_map! (f, yvalue, seen, xvalues, isleaftype)
253
- end
254
- end
255
- end
285
+ push! (seen, prev)
286
+ EnzymeCore. make_zero! (prev)
256
287
return nothing
257
288
end
258
289
259
- @inline function _recursive_map! (
260
- f:: F , y:: T , seen, xs:: NTuple{N,T} , isleaftype
261
- ) where {F,T,N}
262
- if y in seen
263
- return nothing
264
- end
265
- @assert ! Base. isabstracttype (T)
266
- @assert Base. isconcretetype (T)
267
- nf = fieldcount (T)
268
- if nf == 0
269
- return nothing
270
- end
271
- push! (seen, y)
272
- for i = 1 : nf
273
- if isdefined (y, i) && all (x -> isdefined (x, i), xs)
274
- yi = getfield (y, i)
275
- xis = ntuple (j -> getfield (xs[j], i), N)
276
- SBT = Core. Typeof (yi)
277
- activitystate = active_reg_inner (SBT, (), nothing , Val (false ))
278
- if activitystate == AnyState
279
- continue
280
- elseif activitystate == DupState
281
- recursive_map! (f, yi, seen, xis, isleaftype)
282
- else
283
- yi = recursive_map_immutable! (f, yi, seen, xis, isleaftype)
284
- if Base. isconst (T, i)
285
- ccall (:jl_set_nth_field , Cvoid, (Any, Csize_t, Any), y, i - 1 , yi)
286
- else
287
- setfield! (y, i, yi)
288
- end
289
- end
290
- end
291
- end
290
+ @inline function EnzymeCore. make_zero! (prev:: Array{T,N} ) where {T<: _RealOrComplexFloat ,N}
291
+ fill! (prev, zero (T))
292
292
return nothing
293
293
end
0 commit comments