1+ # taken from implementation in Lux.jl
2+ struct LuxEltypeAdaptor{T} end
3+
4+ (l:: LuxEltypeAdaptor )(x) = fmap(adapt(l), x)
5+ function (l:: LuxEltypeAdaptor )(x:: AbstractArray{T} ) where {T}
6+ return isbitstype(T) ? adapt(l, x) : map(adapt(l), x)
7+ end
8+
9+ function Adapt. adapt_storage(
10+ :: LuxEltypeAdaptor{T} , x:: AbstractArray{<:AbstractFloat}
11+ ) where {T<: AbstractFloat }
12+ return convert(AbstractArray{T}, x)
13+ end
14+
15+ function Adapt. adapt_storage(
16+ :: LuxEltypeAdaptor{T} , x:: AbstractArray{<:Complex{<:AbstractFloat}}
17+ ) where {T<: AbstractFloat }
18+ return convert(AbstractArray{Complex{T}}, x)
19+ end
20+
121struct Constant{T}
222 val:: T
323end
424
5- # Zygote .jl on CPU
25+ # FiniteDiff .jl on CPU
626function ground_truth_gradient(f, args... )
727 cdev = cpu_device()
28+ f64 = LuxEltypeAdaptor{Float64}()
829 f_cpu = try
9- cdev(f)
30+ f64( cdev(f) )
1031 catch err
1132 @error " Encountered error while moving $(f) to CPU. Skipping movement... This can \
1233 be fixed by defining overloads using ConstructionBase.jl" err
1334 f
1435 end
15- return gradient(f_cpu, AutoZygote (), map(cdev, args). .. )
36+ return gradient(f_cpu, AutoFiniteDiff (), map(f64, map( cdev, args) ). .. )
1637end
1738
1839# Zygote.jl
@@ -31,8 +52,9 @@ function gradient(f::F, ::AutoEnzyme{Nothing}, args...) where {F}
3152end
3253
3354function gradient(f:: F , ad:: AutoEnzyme{<:Enzyme.ReverseMode} , args... ) where {F}
34- ! ENZYME_TESTING_ENABLED &&
55+ if ! ENZYME_TESTING_ENABLED[]
3556 return ntuple(Returns(GradientComputationSkipped()), length(args))
57+ end
3658
3759 args_activity = map(args) do x
3860 needs_gradient(x) && return Enzyme. Duplicated(x, Enzyme. make_zero(x))
79101"""
80102 test_gradients(f, args...; skip_backends=[], broken_backends=[], kwargs...)
81103
82- Test the gradients of `f` with respect to `args` using the specified backends.
104+ Test the gradients of `f` with respect to `args` using the specified backends. The ground
105+ truth gradients are computed using FiniteDiff.jl on CPU.
83106
84107| Backend | ADType | CPU | GPU | Notes |
85108|:-------------- |:------------------- |:--- |:--- |:----------------- |
86109| Zygote.jl | `AutoZygote()` | ✔ | ✔ | |
87110| ForwardDiff.jl | `AutoForwardDiff()` | ✔ | ✖ | `len ≤ 32` |
88- | FiniteDiff.jl | `AutoFiniteDiff()` | ✔ | ✖ | `len ≤ 32` |
89111| Enzyme.jl | `AutoEnzyme()` | ✔ | ✖ | Only Reverse Mode |
90112
91113## Arguments
@@ -156,9 +178,8 @@ function test_gradients(
156178 push!(backends, AutoZygote())
157179 if ! on_gpu
158180 total_length ≤ 32 && push!(backends, AutoForwardDiff())
159- total_length ≤ 32 && push!(backends, AutoFiniteDiff())
160181 # TODO : Move Enzyme out of here once it supports GPUs
161- if enable_enzyme_reverse_mode || ENZYME_TESTING_ENABLED
182+ if enable_enzyme_reverse_mode || ENZYME_TESTING_ENABLED[]
162183 mode = if enzyme_set_runtime_activity
163184 Enzyme. set_runtime_activity(Enzyme. Reverse)
164185 else
0 commit comments