Skip to content

Commit 3463886

Browse files
authored
EnzymeTestUtils: faster array/complex support (#2688)
* EnzymeTestUtils: faster array/complex support * fix * fix * More speed * fix * fix * fix
1 parent 3110c39 commit 3463886

File tree

3 files changed

+184
-26
lines changed

3 files changed

+184
-26
lines changed

lib/EnzymeTestUtils/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "EnzymeTestUtils"
22
uuid = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
33
authors = ["Seth Axen <seth@sethaxen.com>", "William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>"]
4-
version = "0.2.4"
4+
version = "0.2.5"
55

66
[deps]
77
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"

lib/EnzymeTestUtils/src/finite_difference_calls.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,18 @@ function multi_tovec(active_return, vals)
6060
end
6161

6262
function j′vp(fdm, f_vec, ȳ, x)
63-
mat = transpose(first(FiniteDifferences.jacobian(fdm, f_vec, x)))
63+
ẏs = map(eachindex(x)) do n
64+
return fdm(zero(eltype(x))) do ε
65+
xn = x[n]
66+
try
67+
x[n] = xn + ε
68+
return copy(f_vec(x)) # copy required in case `f(x)` returns something that aliases `x`
69+
finally
70+
x[n] = xn # Can't do `x[n] -= ε` as floating-point math is not associative
71+
end
72+
end
73+
end
74+
mat = transpose(reduce(hcat, ẏs))
6475
result = zero(x)
6576
for i in 1:length(ȳ)
6677
tp = @inbounds ȳ[i]

lib/EnzymeTestUtils/src/to_vec.jl

Lines changed: 171 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ function Base.setindex!(d::AliasDict, val, key)
2828
return d
2929
end
3030

31+
const ElementType = Base.IEEEFloat # , Complex{<:Base.IEEEFloat}}
32+
3133
# alternative to FiniteDifferences.to_vec to use Enzyme's semantics for arrays instead of
3234
# ChainRules': Enzyme treats tangents of AbstractArrays the same as tangents of any other
3335
# struct (i.e. with a container of the same type as the original), while ChainRules
@@ -36,41 +38,130 @@ end
3638
# We take special care that floats that occupy the same memory in the argument only appear
3739
# once in the vector, and that the reconstructed object shares the same memory pattern
3840

41+
function from_vec(from_vec_inner, x_vec::AbstractVector{<:ElementType})
42+
from_vec_inner(x_vec, AliasDict())
43+
end
44+
3945
function to_vec(x)
4046
x_vec, from_vec_inner = to_vec(x, AliasDict())
41-
from_vec(x_vec::Vector{<:AbstractFloat}) = from_vec_inner(x_vec, AliasDict())
42-
return x_vec, from_vec
47+
return x_vec, Base.Fix1(from_vec, from_vec_inner)
4348
end
4449

4550
# base case: we've unwrapped to a number, so we break the recursion
46-
function to_vec(x::AbstractFloat, seen_vecs::AliasDict)
47-
AbstractFloat_from_vec(v::Vector{<:AbstractFloat}, _) = oftype(x, only(v))
51+
function to_vec(x::ElementType, seen_vecs::AliasDict)
52+
AbstractFloat_from_vec(v::AbstractVector{<:ElementType}, _) = oftype(x, only(v))
4853
return [x], AbstractFloat_from_vec
4954
end
5055

56+
# base case: a complex number is flattened to its real and imaginary parts, so we break the recursion
57+
function to_vec(x::Complex{<:ElementType}, seen_vecs::AliasDict)
58+
AbstractComplex_from_vec(v::AbstractVector{<:ElementType}, _) = Core.Typeof(x)(v[1], v[2])
59+
return [real(x), imag(x)], AbstractComplex_from_vec
60+
end
61+
5162
# basic containers: loop over defined elements, recursively converting them to vectors
52-
function to_vec(x::RT, seen_vecs::AliasDict) where {RT<:Array}
63+
function to_vec(x::Array{<:ElementType}, seen_vecs::AliasDict)
5364
has_seen = haskey(seen_vecs, x)
54-
is_const = Enzyme.Compiler.guaranteed_const(RT)
65+
is_const = Enzyme.Compiler.guaranteed_const(Core.Typeof(x))
66+
if has_seen || is_const
67+
x_vec = Float32[]
68+
else
69+
x_vec = reshape(x, length(x))
70+
seen_vecs[x] = x_vec
71+
end
72+
sz = size(x)
73+
function FastArray_from_vec(x_vec_new::AbstractVector{<:ElementType}, seen_xs::AliasDict)
74+
if xor(has_seen, haskey(seen_xs, x))
75+
throw(ErrorException("Arrays must be reconstructed in the same order as they are vectorized."))
76+
end
77+
has_seen && return reshape(seen_xs[x], size(x))
78+
is_const && return x
79+
x_new = reshape(x_vec_new, sz)
80+
if Core.Typeof(x_new) != Core.Typeof(x)
81+
x_new = Core.Typeof(x)(x_new)
82+
end
83+
seen_xs[x] = x_new
84+
return x_new
85+
end
86+
return x_vec, FastArray_from_vec
87+
end
88+
89+
# Returns `(vector, owned)`: `owned` is `true` only when `vector` was allocated here; only such vectors are appended to in place
90+
function append_or_merge(prev::Union{Nothing, Tuple{Vector, Bool}}, newv::Vector)::Tuple{Vector, Bool}
91+
if prev === nothing
92+
return (newv, false)
93+
elseif prev[2] && eltype(newv) <: eltype(prev[1])
94+
append!(prev[1], newv)
95+
return prev
96+
else
97+
ET2 = Base.promote_type(eltype(prev[1]), eltype(newv))
98+
if prev[2] && ET2 == eltype(prev[1])
99+
append!(prev[1], newv)
100+
return prev
101+
else
102+
res = Vector{ET2}(undef, length(prev[1]) + length(newv))
103+
copyto!(@view(res[1:length(prev[1])]), prev[1])
104+
copyto!(@view(res[length(prev[1])+1:end]), newv)
105+
return (res, true)
106+
end
107+
end
108+
end
109+
110+
# fast path for arrays of complex floats: the vector holds all real parts followed by all imaginary parts
111+
function to_vec(x::Array{<:Complex{<:ElementType}}, seen_vecs::AliasDict)
112+
has_seen = haskey(seen_vecs, x)
113+
is_const = Enzyme.Compiler.guaranteed_const(Core.Typeof(x))
114+
if has_seen || is_const
115+
x_vec = Float32[]
116+
else
117+
y = reshape(x, length(x))
118+
x_vec = vcat(real.(y), imag.(y))
119+
seen_vecs[x] = x_vec
120+
end
121+
sz = size(x)
122+
function ComplexArray_from_vec(x_vec_new::AbstractVector{<:ElementType}, seen_xs::AliasDict)
123+
if xor(has_seen, haskey(seen_xs, x))
124+
throw(ErrorException("Arrays must be reconstructed in the same order as they are vectorized."))
125+
end
126+
has_seen && return reshape(seen_xs[x], size(x))
127+
is_const && return x
128+
x_new = Core.Typeof(x)(undef, sz)
129+
@inbounds @simd for i in 1:length(x)
130+
x_new[i] = eltype(x)(x_vec_new[i], x_vec_new[i + length(x)])
131+
end
132+
seen_xs[x] = x_new
133+
return x_new
134+
end
135+
return x_vec, ComplexArray_from_vec
136+
end
137+
138+
# basic containers: loop over defined elements, recursively converting them to vectors
139+
function to_vec(x::Array, seen_vecs::AliasDict)
140+
has_seen = haskey(seen_vecs, x)
141+
is_const = Enzyme.Compiler.guaranteed_const(Core.Typeof(x))
55142
if has_seen || is_const
56143
x_vec = Float32[]
57144
else
58-
x_vecs = Vector{<:AbstractFloat}[]
145+
x_vecs = nothing
59146
from_vecs = []
60147
subvec_inds = UnitRange{Int}[]
61148
l = 0
62149
for i in eachindex(x)
63150
isassigned(x, i) || continue
64151
xi_vec, xi_from_vec = to_vec(x[i], seen_vecs)
65-
push!(x_vecs, xi_vec)
66-
push!(from_vecs, xi_from_vec)
67152
push!(subvec_inds, (l + 1):(l + length(xi_vec)))
153+
push!(from_vecs, xi_from_vec)
154+
x_vecs = append_or_merge(x_vecs, xi_vec)
68155
l += length(xi_vec)
69156
end
70-
x_vec = reduce(vcat, x_vecs; init=Float32[])
157+
158+
if x_vecs === nothing
159+
x_vecs = (Float32[], true)
160+
end
161+
x_vec = x_vecs[1]
71162
seen_vecs[x] = x_vec
72163
end
73-
function Array_from_vec(x_vec_new::Vector{<:AbstractFloat}, seen_xs::AliasDict)
164+
function Array_from_vec(x_vec_new::AbstractVector{<:ElementType}, seen_xs::AliasDict)
74165
if xor(has_seen, haskey(seen_xs, x))
75166
throw(ErrorException("Arrays must be reconstructed in the same order as they are vectorized."))
76167
end
@@ -80,7 +171,7 @@ function to_vec(x::RT, seen_vecs::AliasDict) where {RT<:Array}
80171
k = 1
81172
for i in eachindex(x)
82173
isassigned(x, i) || continue
83-
xi = from_vecs[k](x_vec_new[subvec_inds[k]], seen_xs)
174+
xi = from_vecs[k](@view(x_vec_new[subvec_inds[k]]), seen_xs)
84175
x_new[i] = xi
85176
k += 1
86177
end
@@ -93,28 +184,32 @@ end
93184
@static if VERSION < v"1.11-"
94185
else
95186
# basic containers: loop over defined elements, recursively converting them to vectors
96-
function to_vec(x::RT, seen_vecs::AliasDict) where {RT<:GenericMemory}
187+
function to_vec(x::GenericMemory, seen_vecs::AliasDict)
97188
has_seen = haskey(seen_vecs, x)
98-
is_const = Enzyme.Compiler.guaranteed_const(RT)
189+
is_const = Enzyme.Compiler.guaranteed_const(Core.Typeof(x))
99190
if has_seen || is_const
100191
x_vec = Float32[]
101192
else
102-
x_vecs = Vector{<:AbstractFloat}[]
103193
from_vecs = []
104194
subvec_inds = UnitRange{Int}[]
105195
l = 0
196+
x_vecs = nothing
106197
for i in eachindex(x)
107198
isassigned(x, i) || continue
108199
xi_vec, xi_from_vec = to_vec(x[i], seen_vecs)
109-
push!(x_vecs, xi_vec)
110200
push!(from_vecs, xi_from_vec)
111201
push!(subvec_inds, (l + 1):(l + length(xi_vec)))
202+
x_vecs = append_or_merge(x_vecs, xi_vec)
112203
l += length(xi_vec)
113204
end
114-
x_vec = reduce(vcat, x_vecs; init=Float32[])
205+
206+
if x_vecs === nothing
207+
x_vecs = (Float32[], true)
208+
end
209+
x_vec = x_vecs[1]
115210
seen_vecs[x] = x_vec
116211
end
117-
function Memory_from_vec(x_vec_new::Vector{<:AbstractFloat}, seen_xs::AliasDict)
212+
function Memory_from_vec(x_vec_new::AbstractVector{<:ElementType}, seen_xs::AliasDict)
118213
if xor(has_seen, haskey(seen_xs, x))
119214
throw(ErrorException("Arrays must be reconstructed in the same order as they are vectorized."))
120215
end
@@ -124,7 +219,7 @@ function to_vec(x::RT, seen_vecs::AliasDict) where {RT<:GenericMemory}
124219
k = 1
125220
for i in eachindex(x)
126221
isassigned(x, i) || continue
127-
xi = from_vecs[k](x_vec_new[subvec_inds[k]], seen_xs)
222+
xi = from_vecs[k](@view(x_vec_new[subvec_inds[k]]), seen_xs)
128223
x_new[i] = xi
129224
k += 1
130225
end
@@ -133,18 +228,70 @@ function to_vec(x::RT, seen_vecs::AliasDict) where {RT<:GenericMemory}
133228
end
134229
return x_vec, Memory_from_vec
135230
end
231+
232+
# fast path for memory of plain floats: elements are copied directly instead of recursing (NOTE(review): `x_vec` is never assigned in the `else` branch below)
233+
function to_vec(x::GenericMemory{<:ElementType}, seen_vecs::AliasDict)
234+
has_seen = haskey(seen_vecs, x)
235+
is_const = Enzyme.Compiler.guaranteed_const(Core.Typeof(x))
236+
if has_seen || is_const
237+
x_vec = Float32[]
238+
else
239+
seen_vecs[x] = collect(x)
240+
end
241+
function Memory_from_vec(x_vec_new::AbstractVector{<:ElementType}, seen_xs::AliasDict)
242+
if xor(has_seen, haskey(seen_xs, x))
243+
throw(ErrorException("Arrays must be reconstructed in the same order as they are vectorized."))
244+
end
245+
has_seen && return reshape(seen_xs[x], size(x))
246+
is_const && return x
247+
x_new = typeof(x)(undef, size(x))
248+
copyto!(x_new, x_vec_new)
249+
seen_xs[x] = x_new
250+
return x_new
251+
end
252+
return x_vec, Memory_from_vec
253+
end
136254
end
137255

138256
function to_vec(x::Tuple, seen_vecs::AliasDict)
139-
x_vec, from_vec = to_vec(collect(x), seen_vecs)
140-
function Tuple_from_vec(x_vec_new::Vector{<:AbstractFloat}, seen_xs::AliasDict)
141-
return typeof(x)(Tuple(from_vec(x_vec_new, seen_xs)))
257+
is_const = Enzyme.Compiler.guaranteed_const(Core.Typeof(x))
258+
if is_const
259+
x_vec = Float32[]
260+
else
261+
x_vecs = nothing
262+
from_vecs = []
263+
subvec_inds = UnitRange{Int}[]
264+
l = 0
265+
for xi in x
266+
xi_vec, xi_from_vec = to_vec(xi, seen_vecs)
267+
push!(subvec_inds, (l + 1):(l + length(xi_vec)))
268+
push!(from_vecs, xi_from_vec)
269+
x_vecs = append_or_merge(x_vecs, xi_vec)
270+
l += length(xi_vec)
271+
end
272+
if x_vecs === nothing
273+
x_vecs = (Float32[], true)
274+
end
275+
x_vec = x_vecs[1]
276+
seen_vecs[x] = x_vec
277+
end
278+
function Tuple_from_vec(x_vec_new::AbstractVector{<:ElementType}, seen_xs::AliasDict)
279+
is_const && return x
280+
x_new = Vector{Any}(undef, length(x))
281+
for i in 1:length(x)
282+
xi = from_vecs[i](@view(x_vec_new[subvec_inds[i]]), seen_xs)
283+
x_new[i] = xi
284+
end
285+
x_new = (x_new...,)
286+
seen_xs[x] = x_new
287+
return x_new
142288
end
143289
return x_vec, Tuple_from_vec
144290
end
291+
145292
function to_vec(x::NamedTuple, seen_vecs::AliasDict)
146293
x_vec, from_vec = to_vec(values(x), seen_vecs)
147-
function NamedTuple_from_vec(x_vec_new::Vector{<:AbstractFloat}, seen_xs::AliasDict)
294+
function NamedTuple_from_vec(x_vec_new::AbstractVector{<:ElementType}, seen_xs::AliasDict)
148295
return NamedTuple{keys(x)}(from_vec(x_vec_new, seen_xs))
149296
end
150297
return x_vec, NamedTuple_from_vec
@@ -174,7 +321,7 @@ function to_vec(x::RT, seen_vecs::AliasDict) where {RT}
174321
seen_vecs[x] = x_vec
175322
end
176323
end
177-
function Struct_from_vec(x_vec_new::Vector{<:AbstractFloat}, seen_xs::AliasDict)
324+
function Struct_from_vec(x_vec_new::AbstractVector{<:ElementType}, seen_xs::AliasDict)
178325
if xor(has_seen, haskey(seen_xs, x))
179326
throw(ErrorException("Objects must be reconstructed in the same order as they are vectorized."))
180327
end

0 commit comments

Comments
 (0)