@@ -33,10 +33,7 @@ function vector_jacobian_product(::F, backend::AbstractADType, _, __) where {F}
3333end
3434
3535function vector_jacobian_product(f:: F , backend:: AutoZygote , x, u) where {F}
36- if ! is_extension_loaded(Val(:Zygote))
37- error(" `Zygote.jl` must be loaded for `vector_jacobian_product` \
38- to work with `$(backend) `." )
39- end
36+ assert_backend_loaded(:vector_jacobian_product, backend)
4037 return AutoDiffInternalImpl. vector_jacobian_product(f, backend, x, u)
4138end
4239
@@ -89,10 +86,11 @@ the following properties for `y = f(x)`:
8986
9087## Backends & AD Packages
9188
92- | Supported Backends | Packages Needed |
93- |:------------------ |:--------------- |
94- | `AutoForwardDiff` | |
95- | `AutoZygote` | `Zygote.jl` |
89+ | Supported Backends | Packages Needed | Note |
90+ |:------------------ |:--------------- |:---------------------------------------------- |
91+ | `AutoForwardDiff` | | |
92+ | `AutoZygote` | `Zygote.jl` | |
93+ | `AutoEnzyme` | `Enzyme.jl` | Not compatible with ChainRules based Nested AD |
9694
9795## Arguments
9896
@@ -118,14 +116,24 @@ function batched_jacobian(::F, backend::AbstractADType, x::AbstractArray) where
118116 throw(ArgumentError(" `batched_jacobian` is not implemented for `$(backend) `." ))
119117end
120118
121- function batched_jacobian(f:: F , backend:: AutoForwardDiff , x:: AbstractArray ) where {F}
122- return AutoDiffInternalImpl. batched_jacobian(f, backend, x)
119+ for implemented_backend in (AutoForwardDiff, AutoZygote, AutoEnzyme)
120+ @eval function batched_jacobian(
121+ f:: F , backend:: $ (implemented_backend), x:: AbstractArray ) where {F}
122+ assert_backend_loaded(:batched_jacobian, backend)
123+ return AutoDiffInternalImpl. batched_jacobian(f, backend, x)
124+ end
123125end
124126
125- function batched_jacobian(f:: F , backend:: AutoZygote , x:: AbstractArray ) where {F}
126- if ! is_extension_loaded(Val(:Zygote))
127- error(" `Zygote.jl` must be loaded for `batched_jacobian` to work with \
128- `$(backend) `." )
127+ function assert_backend_loaded(fname:: Symbol , ad:: AbstractADType )
128+ return assert_backend_loaded(fname, ad, adtype_to_backend(ad))
129+ end
130+ function assert_backend_loaded(fname:: Symbol , ad:: AbstractADType , backend:: Val{B} ) where {B}
131+ if ! is_extension_loaded(backend)
132+ error(" $(fname) with `$(ad) ` requires $(B) .jl to be loaded." )
129133 end
130- return AutoDiffInternalImpl . batched_jacobian(f, backend, x)
134+ return
131135end
136+
137+ adtype_to_backend(:: AutoEnzyme ) = Val(:Enzyme)
138+ adtype_to_backend(:: AutoForwardDiff ) = Val(:ForwardDiff)
139+ adtype_to_backend(:: AutoZygote ) = Val(:Zygote)
0 commit comments