Skip to content

Commit 91491c9

Browse files
committed
feat: add reverse mode batched enzyme jacobian
1 parent 000b691 commit 91491c9

File tree

4 files changed

+57
-10
lines changed

4 files changed

+57
-10
lines changed

ext/LuxEnzymeExt/LuxEnzymeExt.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module LuxEnzymeExt
22

33
using ADTypes: ADTypes, AutoEnzyme, ForwardMode, ReverseMode
44
using ArgCheck: @argcheck
5+
using ConcreteStructs: @concrete
56
using Enzyme: Enzyme, Active, Const, Duplicated, BatchDuplicated
67
using EnzymeCore: EnzymeCore
78
using Functors: fmap
@@ -15,8 +16,8 @@ using MLDataDevices: isleaf
1516
Lux.is_extension_loaded(::Val{:Enzyme}) = true
1617

1718
normalize_backend(::StaticBool, ad::AutoEnzyme) = ad
18-
normalize_backend(::True, ad::AutoEnzyme{Nothing}) = @set(ad.mode = Enzyme.Forward)
19-
normalize_backend(::False, ad::AutoEnzyme{Nothing}) = @set(ad.mode = Enzyme.Reverse)
19+
normalize_backend(::True, ad::AutoEnzyme{Nothing}) = @set(ad.mode=Enzyme.Forward)
20+
normalize_backend(::False, ad::AutoEnzyme{Nothing}) = @set(ad.mode=Enzyme.Reverse)
2021

2122
annotate_function(::AutoEnzyme{<:Any, Nothing}, f::F) where {F} = f
2223
annotate_function(::AutoEnzyme{<:Any, A}, f::F) where {F, A} = A(f)

ext/LuxEnzymeExt/batched_autodiff.jl

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
function Lux.AutoDiffInternalImpl.batched_jacobian_impl(
22
f::F, ad::AutoEnzyme, x::AbstractArray) where {F}
33
backend = normalize_backend(True(), ad)
4-
return batched_enzyme_jacobian_impl(
5-
annotate_function(ad, f), backend, ADTypes.mode(backend), x)
4+
return batched_enzyme_jacobian_impl(f, backend, ADTypes.mode(backend), x)
65
end
76

87
function batched_enzyme_jacobian_impl(
9-
f::F, ad::AutoEnzyme, ::ForwardMode, x::AbstractArray) where {F}
8+
f_orig::G, ad::AutoEnzyme, ::ForwardMode, x::AbstractArray) where {G}
109
# We need to run the function once to get the output type. Can we use ForwardWithPrimal?
11-
y = f(x)
10+
y = f_orig(x)
11+
f = annotate_function(ad, f_orig)
1212

1313
@argcheck y isa AbstractArray MethodError
1414
if ndims(y) 1 || size(y, ndims(y)) != size(x, ndims(x))
@@ -36,8 +36,38 @@ function batched_enzyme_jacobian_impl(
3636
end
3737

3838
function batched_enzyme_jacobian_impl(
39-
f::F, ad::AutoEnzyme, ::ReverseMode, x::AbstractArray) where {F}
40-
error("reverse mode is not supported yet")
39+
f_orig::G, ad::AutoEnzyme, ::ReverseMode, x::AbstractArray) where {G}
40+
# We need to run the function once to get the output type. Can we use ReverseWithPrimal?
41+
y = f_orig(x)
42+
43+
@argcheck y isa AbstractArray MethodError
44+
if ndims(y) 1 || size(y, ndims(y)) != size(x, ndims(x))
45+
throw(AssertionError("`batched_jacobian` only supports batched outputs \
46+
(ndims(y) > 1) && size(y, ndims(y)) == size(x, ndims(x))."))
47+
end
48+
B = size(y, ndims(y))
49+
50+
J = similar(x, promote_type(eltype(y), eltype(x)), prod(size(y)[1:(end - 1)]),
51+
prod(size(x)[1:(end - 1)]), B)
52+
53+
chunk_size = min(8, length(x) ÷ B)
54+
partials = ntuple(_ -> zero(y), chunk_size)
55+
J_partials = ntuple(_ -> zero(x), chunk_size)
56+
57+
fn = annotate_function(ad, OOPFunctionWrapper(f_orig))
58+
for i in 1:chunk_size:(length(y) ÷ B)
59+
idxs = i:min(i + chunk_size - 1, length(y) ÷ B)
60+
partials′ = make_onehot!(partials, idxs)
61+
J_partials′ = make_zero!(J_partials, idxs)
62+
Enzyme.autodiff(
63+
ad.mode, fn, BatchDuplicated(y, partials′), BatchDuplicated(x, J_partials′)
64+
)
65+
for (idx, J_partial) in zip(idxs, J_partials)
66+
copyto!(view(J, idx, :, :), reshape(J_partial, :, B))
67+
end
68+
end
69+
70+
return J
4171
end
4272

4373
function make_onehot!(partials, idxs)
@@ -48,3 +78,19 @@ function make_onehot!(partials, idxs)
4878
end
4979
return partials[1:length(idxs)]
5080
end
81+
82+
function make_zero!(partials, idxs)
83+
for partial in partials
84+
fill!(partial, false)
85+
end
86+
return partials[1:length(idxs)]
87+
end
88+
89+
@concrete struct OOPFunctionWrapper
90+
f
91+
end
92+
93+
function (f::OOPFunctionWrapper)(y, x)
94+
copyto!(y, f.f(x))
95+
return
96+
end

src/autodiff/api.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ the following properties for `y = f(x)`:
8686
8787
## Backends & AD Packages
8888
89-
| Supported Backends | Packages Needed | Note |
89+
| Supported Backends | Packages Needed | Notes |
9090
|:------------------ |:--------------- |:---------------------------------------------- |
9191
| `AutoForwardDiff` | | |
9292
| `AutoZygote` | `Zygote.jl` | |

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ using Lux
7575
@test_throws ErrorException vector_jacobian_product(
7676
x -> x, AutoZygote(), rand(2), rand(2))
7777

78-
@test_throws ArgumentError batched_jacobian(x -> x, AutoEnzyme(), rand(2, 2))
78+
@test_throws ArgumentError batched_jacobian(x -> x, AutoTracker(), rand(2, 2))
7979
@test_throws ErrorException batched_jacobian(x -> x, AutoZygote(), rand(2, 2))
8080
end
8181

0 commit comments

Comments
 (0)