-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy pathtest.jl
32 lines (28 loc) · 984 Bytes
/
test.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
using Pkg
Pkg.add(["FiniteDiff", "ForwardDiff", "Zygote", "Mooncake"])
using DifferentiationInterface, DifferentiationInterfaceTest
import DifferentiationInterfaceTest as DIT
using FiniteDiff: FiniteDiff
using ForwardDiff: ForwardDiff
using Zygote: Zygote
using Mooncake: Mooncake
using Test
# Verbose logging flag forwarded as `logging` to `test_differentiation`.
# It is enabled unless the `CI` environment variable is set to something other than "false".
LOGGING = !haskey(ENV, "CI") || ENV["CI"] == "false"
"""
    differentiatewith_scenarios()

Return the out-of-place closurified scenarios from `default_scenarios`, each with
its function wrapped as `DifferentiateWith(f, AutoFiniteDiff())`.
"""
function differentiatewith_scenarios()
    # Predicate keeping only scenarios whose function is out-of-place (`:out`).
    isoutofplace(scen) = DIT.function_place(scen) == :out

    # Per the original note, the closurified scenarios involve mutation and type
    # constraints. That is presumably what makes them unsuitable as-is for the
    # backends under test.
    closurified = default_scenarios(; include_normal=false, include_closurified=true)
    raw_scens = filter(isoutofplace, closurified)

    # Swap each scenario's function for a `DifferentiateWith` wrapper that names
    # FiniteDiff as its backend. `keep_smaller=false` is forwarded unchanged.
    wrap(scen) = DIT.change_function(
        scen, DifferentiateWith(scen.f, AutoFiniteDiff()); keep_smaller=false
    )
    return map(wrap, raw_scens)
end
# Run the DifferentiationInterfaceTest suite for each listed backend on the
# `DifferentiateWith`-wrapped scenarios. Operators listed in `SECOND_ORDER`
# are excluded.
let backends = [AutoForwardDiff(), AutoZygote(), AutoMooncake(; config=nothing)]
    scenarios = differentiatewith_scenarios()
    test_differentiation(backends, scenarios; excluded=SECOND_ORDER, logging=LOGGING)
end