@@ -28,6 +28,8 @@ function Base.setindex!(d::AliasDict, val, key)
2828 return d
2929end
3030
31+ const ElementType = Base. IEEEFloat # , Complex{<:Base.IEEEFloat}}
32+
3133# alternative to FiniteDifferences.to_vec to use Enzyme's semantics for arrays instead of
3234# ChainRules': Enzyme treats tangents of AbstractArrays the same as tangents of any other
3335# struct (i.e. with a container of the same type as the original), while ChainRules
3638# We take special care that floats that occupy the same memory in the argument only appear
3739# once in the vector, and that the reconstructed object shares the same memory pattern
3840
41+ function from_vec (from_vec_inner, x_vec:: AbstractVector{<:ElementType} )
42+ from_vec_inner (x_vec, AliasDict ())
43+ end
44+
3945function to_vec (x)
4046 x_vec, from_vec_inner = to_vec (x, AliasDict ())
41- from_vec (x_vec:: Vector{<:AbstractFloat} ) = from_vec_inner (x_vec, AliasDict ())
42- return x_vec, from_vec
47+ return x_vec, Base. Fix1 (from_vec, from_vec_inner)
4348end
4449
4550# base case: we've unwrapped to a number, so we break the recursion
46- function to_vec (x:: AbstractFloat , seen_vecs:: AliasDict )
47- AbstractFloat_from_vec (v:: Vector {<:AbstractFloat } , _) = oftype (x, only (v))
51+ function to_vec (x:: ElementType , seen_vecs:: AliasDict )
52+ AbstractFloat_from_vec (v:: AbstractVector {<:ElementType } , _) = oftype (x, only (v))
4853 return [x], AbstractFloat_from_vec
4954end
5055
56+ # base case: we've unwrapped to a number, so we break the recursion
57+ function to_vec (x:: Complex{<:ElementType} , seen_vecs:: AliasDict )
58+ AbstractComplex_from_vec (v:: AbstractVector{<:ElementType} , _) = Core. Typeof (x)(v[1 ], v[2 ])
59+ return [real (x), imag (x)], AbstractComplex_from_vec
60+ end
61+
5162# basic containers: loop over defined elements, recursively converting them to vectors
52- function to_vec (x:: RT , seen_vecs:: AliasDict ) where {RT <: Array }
63+ function to_vec (x:: Array{<:ElementType} , seen_vecs:: AliasDict )
5364 has_seen = haskey (seen_vecs, x)
54- is_const = Enzyme. Compiler. guaranteed_const (RT)
65+ is_const = Enzyme. Compiler. guaranteed_const (Core. Typeof (x))
66+ if has_seen || is_const
67+ x_vec = Float32[]
68+ else
69+ x_vec = reshape (x, length (x))
70+ seen_vecs[x] = x_vec
71+ end
72+ sz = size (x)
73+ function FastArray_from_vec (x_vec_new:: AbstractVector{<:ElementType} , seen_xs:: AliasDict )
74+ if xor (has_seen, haskey (seen_xs, x))
75+ throw (ErrorException (" Arrays must be reconstructed in the same order as they are vectorized." ))
76+ end
77+ has_seen && return reshape (seen_xs[x], size (x))
78+ is_const && return x
79+ x_new = reshape (x_vec_new, sz)
80+ if Core. Typeof (x_new) != Core. Typeof (x)
81+ x_new = Core. Typeof (x)(x_new)
82+ end
83+ seen_xs[x] = x_new
84+ return x_new
85+ end
86+ return x_vec, FastArray_from_vec
87+ end
88+
89+ # Returns (vector, bool if new allocation)
90+ function append_or_merge (prev:: Union{Nothing, Tuple{Vector, Bool}} , newv:: Vector ):: Tuple{Vector, Bool}
91+ if prev === nothing
92+ return (newv, false )
93+ elseif prev[2 ] && eltype (newv) <: eltype (prev[1 ])
94+ append! (prev[1 ], newv)
95+ return prev
96+ else
97+ ET2 = Base. promote_type (eltype (prev[1 ]), eltype (newv))
98+ if prev[2 ] && ET2 == eltype (prev[1 ])
99+ append! (prev[1 ], newv)
100+ return prev
101+ else
102+ res = Vector {ET2} (undef, length (prev[1 ]) + length (newv))
103+ copyto! (@view (res[1 : length (prev[1 ])]), prev[1 ])
104+ copyto! (@view (res[length (prev[1 ])+ 1 : end ]), newv)
105+ return (res, true )
106+ end
107+ end
108+ end
109+
110+ # basic containers: loop over defined elements, recursively converting them to vectors
111+ function to_vec (x:: Array{<:Complex{<:ElementType}} , seen_vecs:: AliasDict )
112+ has_seen = haskey (seen_vecs, x)
113+ is_const = Enzyme. Compiler. guaranteed_const (Core. Typeof (x))
114+ if has_seen || is_const
115+ x_vec = Float32[]
116+ else
117+ y = reshape (x, length (x))
118+ x_vec = vcat (real .(y), imag .(y))
119+ seen_vecs[x] = x_vec
120+ end
121+ sz = size (x)
122+ function ComplexArray_from_vec (x_vec_new:: AbstractVector{<:ElementType} , seen_xs:: AliasDict )
123+ if xor (has_seen, haskey (seen_xs, x))
124+ throw (ErrorException (" Arrays must be reconstructed in the same order as they are vectorized." ))
125+ end
126+ has_seen && return reshape (seen_xs[x], size (x))
127+ is_const && return x
128+ x_new = Core. Typeof (x)(undef, sz)
129+ @inbounds @simd for i in 1 : length (x)
130+ x_new[i] = eltype (x)(x_vec_new[i], x_vec_new[i + length (x)])
131+ end
132+ seen_xs[x] = x_new
133+ return x_new
134+ end
135+ return x_vec, ComplexArray_from_vec
136+ end
137+
138+ # basic containers: loop over defined elements, recursively converting them to vectors
139+ function to_vec (x:: Array , seen_vecs:: AliasDict )
140+ has_seen = haskey (seen_vecs, x)
141+ is_const = Enzyme. Compiler. guaranteed_const (Core. Typeof (x))
55142 if has_seen || is_const
56143 x_vec = Float32[]
57144 else
58- x_vecs = Vector{ <: AbstractFloat }[]
145+ x_vecs = nothing
59146 from_vecs = []
60147 subvec_inds = UnitRange{Int}[]
61148 l = 0
62149 for i in eachindex (x)
63150 isassigned (x, i) || continue
64151 xi_vec, xi_from_vec = to_vec (x[i], seen_vecs)
65- push! (x_vecs, xi_vec)
66- push! (from_vecs, xi_from_vec)
67152 push! (subvec_inds, (l + 1 ): (l + length (xi_vec)))
153+ push! (from_vecs, xi_from_vec)
154+ x_vecs = append_or_merge (x_vecs, xi_vec)
68155 l += length (xi_vec)
69156 end
70- x_vec = reduce (vcat, x_vecs; init= Float32[])
157+
158+ if x_vecs === nothing
159+ x_vecs = (Float32[], true )
160+ end
161+ x_vec = x_vecs[1 ]
71162 seen_vecs[x] = x_vec
72163 end
73- function Array_from_vec (x_vec_new:: Vector {<:AbstractFloat } , seen_xs:: AliasDict )
164+ function Array_from_vec (x_vec_new:: AbstractVector {<:ElementType } , seen_xs:: AliasDict )
74165 if xor (has_seen, haskey (seen_xs, x))
75166 throw (ErrorException (" Arrays must be reconstructed in the same order as they are vectorized." ))
76167 end
@@ -80,7 +171,7 @@ function to_vec(x::RT, seen_vecs::AliasDict) where {RT<:Array}
80171 k = 1
81172 for i in eachindex (x)
82173 isassigned (x, i) || continue
83- xi = from_vecs[k](x_vec_new[subvec_inds[k]], seen_xs)
174+ xi = from_vecs[k](@view ( x_vec_new[subvec_inds[k]]) , seen_xs)
84175 x_new[i] = xi
85176 k += 1
86177 end
93184@static if VERSION < v " 1.11-"
94185else
95186# basic containers: loop over defined elements, recursively converting them to vectors
96- function to_vec (x:: RT , seen_vecs:: AliasDict ) where {RT <: GenericMemory }
187+ function to_vec (x:: GenericMemory , seen_vecs:: AliasDict )
97188 has_seen = haskey (seen_vecs, x)
98- is_const = Enzyme. Compiler. guaranteed_const (RT )
189+ is_const = Enzyme. Compiler. guaranteed_const (Core . Typeof (x) )
99190 if has_seen || is_const
100191 x_vec = Float32[]
101192 else
102- x_vecs = Vector{<: AbstractFloat }[]
103193 from_vecs = []
104194 subvec_inds = UnitRange{Int}[]
105195 l = 0
196+ x_vecs = nothing
106197 for i in eachindex (x)
107198 isassigned (x, i) || continue
108199 xi_vec, xi_from_vec = to_vec (x[i], seen_vecs)
109- push! (x_vecs, xi_vec)
110200 push! (from_vecs, xi_from_vec)
111201 push! (subvec_inds, (l + 1 ): (l + length (xi_vec)))
202+ x_vecs = append_or_merge (x_vecs, xi_vec)
112203 l += length (xi_vec)
113204 end
114- x_vec = reduce (vcat, x_vecs; init= Float32[])
205+
206+ if x_vecs === nothing
207+ x_vecs = (Float32[], true )
208+ end
209+ x_vec = x_vecs[1 ]
115210 seen_vecs[x] = x_vec
116211 end
117- function Memory_from_vec (x_vec_new:: Vector {<:AbstractFloat } , seen_xs:: AliasDict )
212+ function Memory_from_vec (x_vec_new:: AbstractVector {<:ElementType } , seen_xs:: AliasDict )
118213 if xor (has_seen, haskey (seen_xs, x))
119214 throw (ErrorException (" Arrays must be reconstructed in the same order as they are vectorized." ))
120215 end
@@ -124,7 +219,7 @@ function to_vec(x::RT, seen_vecs::AliasDict) where {RT<:GenericMemory}
124219 k = 1
125220 for i in eachindex (x)
126221 isassigned (x, i) || continue
127- xi = from_vecs[k](x_vec_new[subvec_inds[k]], seen_xs)
222+ xi = from_vecs[k](@view ( x_vec_new[subvec_inds[k]]) , seen_xs)
128223 x_new[i] = xi
129224 k += 1
130225 end
@@ -133,18 +228,70 @@ function to_vec(x::RT, seen_vecs::AliasDict) where {RT<:GenericMemory}
133228 end
134229 return x_vec, Memory_from_vec
135230end
231+
232+ # basic containers: loop over defined elements, recursively converting them to vectors
233+ function to_vec (x:: GenericMemory{<:ElementType} , seen_vecs:: AliasDict )
234+ has_seen = haskey (seen_vecs, x)
235+ is_const = Enzyme. Compiler. guaranteed_const (Core. Typeof (x))
236+ if has_seen || is_const
237+ x_vec = Float32[]
238+ else
239+ seen_vecs[x] = collect (x)
240+ end
241+ function Memory_from_vec (x_vec_new:: AbstractVector{<:ElementType} , seen_xs:: AliasDict )
242+ if xor (has_seen, haskey (seen_xs, x))
243+ throw (ErrorException (" Arrays must be reconstructed in the same order as they are vectorized." ))
244+ end
245+ has_seen && return reshape (seen_xs[x], size (x))
246+ is_const && return x
247+ x_new = typeof (x)(undef, size (x))
248+ copyto! (x_new, x_vec_new)
249+ seen_xs[x] = x_new
250+ return x_new
251+ end
252+ return x_vec, Memory_from_vec
253+ end
136254end
137255
138256function to_vec (x:: Tuple , seen_vecs:: AliasDict )
139- x_vec, from_vec = to_vec (collect (x), seen_vecs)
140- function Tuple_from_vec (x_vec_new:: Vector{<:AbstractFloat} , seen_xs:: AliasDict )
141- return typeof (x)(Tuple (from_vec (x_vec_new, seen_xs)))
257+ is_const = Enzyme. Compiler. guaranteed_const (Core. Typeof (x))
258+ if is_const
259+ x_vec = Float32[]
260+ else
261+ x_vecs = nothing
262+ from_vecs = []
263+ subvec_inds = UnitRange{Int}[]
264+ l = 0
265+ for xi in x
266+ xi_vec, xi_from_vec = to_vec (xi, seen_vecs)
267+ push! (subvec_inds, (l + 1 ): (l + length (xi_vec)))
268+ push! (from_vecs, xi_from_vec)
269+ x_vecs = append_or_merge (x_vecs, xi_vec)
270+ l += length (xi_vec)
271+ end
272+ if x_vecs === nothing
273+ x_vecs = (Float32[], true )
274+ end
275+ x_vec = x_vecs[1 ]
276+ seen_vecs[x] = x_vec
277+ end
278+ function Tuple_from_vec (x_vec_new:: AbstractVector{<:ElementType} , seen_xs:: AliasDict )
279+ is_const && return x
280+ x_new = Vector {Any} (undef, length (x))
281+ for i in 1 : length (x)
282+ xi = from_vecs[i](@view (x_vec_new[subvec_inds[i]]), seen_xs)
283+ x_new[i] = xi
284+ end
285+ x_new = (x_new... ,)
286+ seen_xs[x] = x_new
287+ return x_new
142288 end
143289 return x_vec, Tuple_from_vec
144290end
291+
145292function to_vec (x:: NamedTuple , seen_vecs:: AliasDict )
146293 x_vec, from_vec = to_vec (values (x), seen_vecs)
147- function NamedTuple_from_vec (x_vec_new:: Vector {<:AbstractFloat } , seen_xs:: AliasDict )
294+ function NamedTuple_from_vec (x_vec_new:: AbstractVector {<:ElementType } , seen_xs:: AliasDict )
148295 return NamedTuple {keys(x)} (from_vec (x_vec_new, seen_xs))
149296 end
150297 return x_vec, NamedTuple_from_vec
@@ -174,7 +321,7 @@ function to_vec(x::RT, seen_vecs::AliasDict) where {RT}
174321 seen_vecs[x] = x_vec
175322 end
176323 end
177- function Struct_from_vec (x_vec_new:: Vector {<:AbstractFloat } , seen_xs:: AliasDict )
324+ function Struct_from_vec (x_vec_new:: AbstractVector {<:ElementType } , seen_xs:: AliasDict )
178325 if xor (has_seen, haskey (seen_xs, x))
179326 throw (ErrorException (" Objects must be reconstructed in the same order as they are vectorized." ))
180327 end
0 commit comments