|
| 1 | +using ADTypes: AutoEnzyme |
| 2 | +using DifferentiationInterface: DifferentiationInterface |
| 3 | +using DifferentiationInterfaceTest: |
| 4 | + default_scenarios, |
| 5 | + sparse_scenarios, |
| 6 | + static_scenarios, |
| 7 | + test_differentiation, |
| 8 | + function_place, |
| 9 | + operator_place, |
| 10 | + FIRST_ORDER, |
| 11 | + SECOND_ORDER, |
| 12 | + Scenario |
| 13 | +using Enzyme: Enzyme |
| 14 | +using EnzymeCore: Forward, Reverse, Const, Duplicated |
| 15 | +using StaticArrays: StaticArrays |
| 16 | +using Test |
| 17 | + |
| 18 | +logging = get(ENV, "CI", "false") == "false" |
| 19 | + |
| 20 | +function remove_matrices(scens::Vector{<:Scenario}) # TODO: remove |
| 21 | + return filter(s -> s.x isa Union{Number, AbstractVector} && s.y isa Union{Number, AbstractVector}, scens) |
| 22 | +end |
| 23 | + |
| 24 | +backends = [ |
| 25 | + AutoEnzyme(; function_annotation = Const), |
| 26 | + AutoEnzyme(; mode = Forward), |
| 27 | + AutoEnzyme(; mode = Reverse), |
| 28 | +] |
| 29 | + |
| 30 | +duplicated_backends = [ |
| 31 | + AutoEnzyme(; mode = Forward, function_annotation = Duplicated), |
| 32 | + AutoEnzyme(; mode = Reverse, function_annotation = Duplicated), |
| 33 | +] |
| 34 | + |
| 35 | +@testset verbose = true "DifferentiationInterface integration" begin |
| 36 | + test_differentiation( |
| 37 | + backends, |
| 38 | + default_scenarios(; include_constantified = true); |
| 39 | + excluded = SECOND_ORDER, |
| 40 | + logging, |
| 41 | + testset_name = "Generic first order", |
| 42 | + ) |
| 43 | + |
| 44 | + test_differentiation( |
| 45 | + backends[1], |
| 46 | + remove_matrices(default_scenarios(; include_constantified = true)); |
| 47 | + excluded = FIRST_ORDER, |
| 48 | + logging, |
| 49 | + testset_name = "Generic second order", |
| 50 | + ) |
| 51 | + |
| 52 | + test_differentiation( |
| 53 | + backends[2], |
| 54 | + remove_matrices( |
| 55 | + default_scenarios(; |
| 56 | + include_normal = false, |
| 57 | + include_cachified = true, |
| 58 | + include_constantorcachified = true, |
| 59 | + use_tuples = true, |
| 60 | + ) |
| 61 | + ); |
| 62 | + excluded = SECOND_ORDER, |
| 63 | + logging, |
| 64 | + testset_name = "Caches", |
| 65 | + ) |
| 66 | + |
| 67 | + test_differentiation( |
| 68 | + duplicated_backends, |
| 69 | + remove_matrices(default_scenarios(; include_normal = false, include_closurified = true)); |
| 70 | + excluded = SECOND_ORDER, |
| 71 | + logging, |
| 72 | + testset_name = "Closures", |
| 73 | + ) |
| 74 | + |
| 75 | + filtered_static_scenarios = filter(remove_matrices(static_scenarios())) do s |
| 76 | + operator_place(s) == :out && function_place(s) == :out |
| 77 | + end |
| 78 | + |
| 79 | + test_differentiation( |
| 80 | + backends[2:3], |
| 81 | + filtered_static_scenarios; |
| 82 | + excluded = SECOND_ORDER, |
| 83 | + logging, |
| 84 | + testset_name = "Static arrays", |
| 85 | + ) |
| 86 | +end |
0 commit comments