11function Lux. AutoDiffInternalImpl. batched_jacobian_impl(
22 f:: F , ad:: AutoEnzyme , x:: AbstractArray ) where {F}
33 backend = normalize_backend(True(), ad)
4- return batched_enzyme_jacobian_impl(
5- annotate_function(ad, f), backend, ADTypes. mode(backend), x)
4+ return batched_enzyme_jacobian_impl(f, backend, ADTypes. mode(backend), x)
65end
76
87function batched_enzyme_jacobian_impl(
9- f :: F , ad:: AutoEnzyme , :: ForwardMode , x:: AbstractArray ) where {F }
8+ f_orig :: G , ad:: AutoEnzyme , :: ForwardMode , x:: AbstractArray ) where {G }
109 # We need to run the function once to get the output type. Can we use ForwardWithPrimal?
11- y = f(x)
10+ y = f_orig(x)
11+ f = annotate_function(ad, f_orig)
1212
1313 @argcheck y isa AbstractArray MethodError
1414 if ndims(y) ≤ 1 || size(y, ndims(y)) != size(x, ndims(x))
@@ -36,8 +36,38 @@ function batched_enzyme_jacobian_impl(
3636end
3737
3838function batched_enzyme_jacobian_impl(
39- f:: F , ad:: AutoEnzyme , :: ReverseMode , x:: AbstractArray ) where {F}
40- error(" reverse mode is not supported yet" )
39+ f_orig:: G , ad:: AutoEnzyme , :: ReverseMode , x:: AbstractArray ) where {G}
40+ # We need to run the function once to get the output type. Can we use ReverseWithPrimal?
41+ y = f_orig(x)
42+
43+ @argcheck y isa AbstractArray MethodError
44+ if ndims(y) ≤ 1 || size(y, ndims(y)) != size(x, ndims(x))
45+ throw(AssertionError(" `batched_jacobian` only supports batched outputs \
46+ (ndims(y) > 1) && size(y, ndims(y)) == size(x, ndims(x))." ))
47+ end
48+ B = size(y, ndims(y))
49+
50+ J = similar(x, promote_type(eltype(y), eltype(x)), prod(size(y)[1 : (end - 1 )]),
51+ prod(size(x)[1 : (end - 1 )]), B)
52+
53+ chunk_size = min(8 , length(x) ÷ B)
54+ partials = ntuple(_ -> zero(y), chunk_size)
55+ J_partials = ntuple(_ -> zero(x), chunk_size)
56+
57+ fn = annotate_function(ad, OOPFunctionWrapper(f_orig))
58+ for i in 1 : chunk_size: (length(y) ÷ B)
59+ idxs = i: min(i + chunk_size - 1 , length(y) ÷ B)
60+ partials′ = make_onehot!(partials, idxs)
61+ J_partials′ = make_zero!(J_partials, idxs)
62+ Enzyme. autodiff(
63+ ad. mode, fn, BatchDuplicated(y, partials′), BatchDuplicated(x, J_partials′)
64+ )
65+ for (idx, J_partial) in zip(idxs, J_partials)
66+ copyto!(view(J, idx, :, :), reshape(J_partial, :, B))
67+ end
68+ end
69+
70+ return J
4171end
4272
4373function make_onehot!(partials, idxs)
@@ -48,3 +78,19 @@ function make_onehot!(partials, idxs)
4878 end
4979 return partials[1 : length(idxs)]
5080end
81+
82+ function make_zero!(partials, idxs)
83+ for partial in partials
84+ fill!(partial, false )
85+ end
86+ return partials[1 : length(idxs)]
87+ end
88+
89+ @concrete struct OOPFunctionWrapper
90+ f
91+ end
92+
93+ function (f:: OOPFunctionWrapper )(y, x)
94+ copyto!(y, f. f(x))
95+ return
96+ end
0 commit comments