Skip to content

Commit 7ef8995

Browse files
committed
feat: add forward mode batched enzyme jacobian
1 parent 1625528 commit 7ef8995

File tree

5 files changed

+94
-19
lines changed

5 files changed

+94
-19
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Lux"
22
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
33
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
4-
version = "1.0.6"
4+
version = "1.1.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

ext/LuxEnzymeExt/LuxEnzymeExt.jl

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,30 @@
11
module LuxEnzymeExt
22

3-
using ADTypes: AutoEnzyme
4-
using Enzyme: Enzyme, Active, Const, Duplicated
3+
using ADTypes: ADTypes, AutoEnzyme, ForwardMode, ReverseMode
4+
using ArgCheck: @argcheck
5+
using Enzyme: Enzyme, Active, Const, Duplicated, BatchDuplicated
56
using EnzymeCore: EnzymeCore
67
using Setfield: @set!
7-
using Static: False, True
8+
using Static: False, True, StaticBool
89

910
using Lux: Lux
1011
using Lux.Training: TrainingBackendCache, TrainState
1112

13+
Lux.is_extension_loaded(::Val{:Enzyme}) = true
14+
15+
normalize_backend(::StaticBool, ad::AutoEnzyme) = ad
16+
function normalize_backend(#=prefer_forward=#::True, ad::AutoEnzyme{Nothing, A}) where {A}
17+
return AutoEnzyme(; mode=Enzyme.Forward, function_annotation=A)
18+
end
19+
function normalize_backend(#=prefer_forward=#::False, ad::AutoEnzyme{Nothing, A}) where {A}
20+
return AutoEnzyme(; mode=Enzyme.Reverse, function_annotation=A)
21+
end
22+
23+
annotate_function(::AutoEnzyme{<:Any, Nothing}, f::F) where {F} = f
24+
annotate_function(::AutoEnzyme{<:Any, A}, f::F) where {F, A} = A(f)
25+
1226
include("training.jl")
1327

28+
include("batched_autodiff.jl")
29+
1430
end
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
function Lux.AutoDiffInternalImpl.batched_jacobian_impl(
2+
f::F, ad::AutoEnzyme, x::AbstractArray) where {F}
3+
backend = normalize_backend(True(), ad)
4+
return batched_enzyme_jacobian_impl(
5+
annotate_function(ad, f), backend, ADTypes.mode(backend), x)
6+
end
7+
8+
function batched_enzyme_jacobian_impl(
9+
f::F, ad::AutoEnzyme, ::ForwardMode, x::AbstractArray) where {F}
10+
# We need to run the function once to get the output type. Can we use ForwardWithPrimal?
11+
y = f(x)
12+
13+
@argcheck y isa AbstractArray MethodError
14+
if ndims(y) 1 || size(y, ndims(y)) != size(x, ndims(x))
15+
throw(AssertionError("`batched_jacobian` only supports batched outputs \
16+
(ndims(y) > 1) && size(y, ndims(y)) == size(x, ndims(x))."))
17+
end
18+
B = size(y, ndims(y))
19+
20+
J = similar(x, promote_type(eltype(y), eltype(x)), prod(size(y)[1:(end - 1)]),
21+
prod(size(x)[1:(end - 1)]), B)
22+
23+
chunk_size = min(8, length(y) ÷ B)
24+
partials = ntuple(_ -> zero(x), chunk_size)
25+
26+
for i in 1:chunk_size:(length(x) ÷ B)
27+
idxs = i:min(i + chunk_size - 1, length(x) ÷ B)
28+
partials′ = make_onehot!(partials, idxs)
29+
J_partials = only(Enzyme.autodiff(ad.mode, f, BatchDuplicated(x, partials′)))
30+
for (idx, J_partial) in zip(idxs, J_partials)
31+
copyto!(view(J, :, idx, :), reshape(J_partial, :, B))
32+
end
33+
end
34+
35+
return J
36+
end
37+
38+
function batched_enzyme_jacobian_impl(
39+
f::F, ad::AutoEnzyme, ::ReverseMode, x::AbstractArray) where {F}
40+
error("reverse mode is not supported yet")
41+
end
42+
43+
function make_onehot!(partials, idxs)
44+
for (idx, partial) in zip(idxs, partials)
45+
partial′ = reshape(partial, :, size(partial, ndims(partial)))
46+
fill!(partial′, false)
47+
partial′[idx, :] .= true
48+
end
49+
return partials[1:length(idxs)]
50+
end

src/Lux.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ const NAME_TYPE = Union{Nothing, String, Symbol}
3434
const Optional{T} = Union{T, Nothing}
3535

3636
is_extension_loaded(::Val) = false
37+
is_extension_loaded(::Val{:ForwardDiff}) = true
3738

3839
# Preferences
3940
include("preferences.jl")

src/autodiff/api.jl

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,7 @@ function vector_jacobian_product(::F, backend::AbstractADType, _, __) where {F}
3333
end
3434

3535
function vector_jacobian_product(f::F, backend::AutoZygote, x, u) where {F}
36-
if !is_extension_loaded(Val(:Zygote))
37-
error("`Zygote.jl` must be loaded for `vector_jacobian_product` \
38-
to work with `$(backend)`.")
39-
end
36+
assert_backend_loaded(:vector_jacobian_product, backend)
4037
return AutoDiffInternalImpl.vector_jacobian_product(f, backend, x, u)
4138
end
4239

@@ -89,10 +86,11 @@ the following properties for `y = f(x)`:
8986
9087
## Backends & AD Packages
9188
92-
| Supported Backends | Packages Needed |
93-
|:------------------ |:--------------- |
94-
| `AutoForwardDiff` | |
95-
| `AutoZygote` | `Zygote.jl` |
89+
| Supported Backends | Packages Needed | Note |
90+
|:------------------ |:--------------- |:---------------------------------------------- |
91+
| `AutoForwardDiff` | | |
92+
| `AutoZygote` | `Zygote.jl` | |
93+
| `AutoEnzyme` | `Enzyme.jl` | Not compatible with ChainRules based Nested AD |
9694
9795
## Arguments
9896
@@ -118,14 +116,24 @@ function batched_jacobian(::F, backend::AbstractADType, x::AbstractArray) where
118116
throw(ArgumentError("`batched_jacobian` is not implemented for `$(backend)`."))
119117
end
120118

121-
function batched_jacobian(f::F, backend::AutoForwardDiff, x::AbstractArray) where {F}
122-
return AutoDiffInternalImpl.batched_jacobian(f, backend, x)
119+
for implemented_backend in (AutoForwardDiff, AutoZygote, AutoEnzyme)
120+
@eval function batched_jacobian(
121+
f::F, backend::$(implemented_backend), x::AbstractArray) where {F}
122+
assert_backend_loaded(:batched_jacobian, backend)
123+
return AutoDiffInternalImpl.batched_jacobian(f, backend, x)
124+
end
123125
end
124126

125-
function batched_jacobian(f::F, backend::AutoZygote, x::AbstractArray) where {F}
126-
if !is_extension_loaded(Val(:Zygote))
127-
error("`Zygote.jl` must be loaded for `batched_jacobian` to work with \
128-
`$(backend)`.")
127+
function assert_backend_loaded(fname::Symbol, ad::AbstractADType)
128+
return assert_backend_loaded(fname, ad, adtype_to_backend(ad))
129+
end
130+
function assert_backend_loaded(fname::Symbol, ad::AbstractADType, backend::Val{B}) where {B}
131+
if !is_extension_loaded(backend)
132+
error("$(fname) with `$(ad)` requires $(B).jl to be loaded.")
129133
end
130-
return AutoDiffInternalImpl.batched_jacobian(f, backend, x)
134+
return
131135
end
136+
137+
adtype_to_backend(::AutoEnzyme) = Val(:Enzyme)
138+
adtype_to_backend(::AutoForwardDiff) = Val(:ForwardDiff)
139+
adtype_to_backend(::AutoZygote) = Val(:Zygote)

0 commit comments

Comments
 (0)