Skip to content

Commit 043a01c

Browse files
Enzyme updates correctly
1 parent 94061d8 commit 043a01c

File tree

3 files changed

+19
-16
lines changed

3 files changed

+19
-16
lines changed

ext/OptimizationEnzymeExt.jl

+14-12
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ using Core: Vararg
1818
end
1919

2020
function inner_grad(θ, bθ, f, p)
21-
Enzyme.autodiff(Enzyme.Reverse,
22-
firstapply,
21+
Enzyme.autodiff_deferred(Enzyme.Reverse,
22+
Const(firstapply),
2323
Active,
2424
Const(f),
2525
Enzyme.Duplicated(θ, bθ),
@@ -29,8 +29,9 @@ function inner_grad(θ, bθ, f, p)
2929
end
3030

3131
function inner_grad_primal(θ, bθ, f, p)
32-
Enzyme.autodiff(Enzyme.ReverseWithPrimal,
33-
firstapply,
32+
Enzyme.autodiff_deferred(Enzyme.ReverseWithPrimal,
33+
Const(firstapply),
34+
Active,
3435
Const(f),
3536
Enzyme.Duplicated(θ, bθ),
3637
Const(p)
@@ -39,8 +40,9 @@ end
3940

4041
function hv_f2_alloc(x, f, p)
4142
dx = Enzyme.make_zero(x)
42-
Enzyme.autodiff(Enzyme.Reverse,
43-
firstapply,
43+
Enzyme.autodiff_deferred(Enzyme.Reverse,
44+
Const(firstapply),
45+
Active,
4446
Const(f),
4547
Enzyme.Duplicated(x, dx),
4648
Const(p)
@@ -56,7 +58,7 @@ function inner_cons(x, fcons::Function, p::Union{SciMLBase.NullParameters, Nothi
5658
end
5759

5860
function cons_f2(x, dx, fcons, p, num_cons, i)
59-
Enzyme.autodiff(Enzyme.Reverse, inner_cons, Enzyme.Duplicated(x, dx),
61+
Enzyme.autodiff_deferred(Enzyme.Reverse, Const(inner_cons), Active, Enzyme.Duplicated(x, dx),
6062
Const(fcons), Const(p), Const(num_cons), Const(i))
6163
return nothing
6264
end
@@ -68,8 +70,8 @@ function inner_cons_oop(
6870
end
6971

7072
function cons_f2_oop(x, dx, fcons, p, i)
71-
Enzyme.autodiff(
72-
Enzyme.Reverse, inner_cons_oop, Enzyme.Duplicated(x, dx),
73+
Enzyme.autodiff_deferred(
74+
Enzyme.Reverse, Const(inner_cons_oop), Active, Enzyme.Duplicated(x, dx),
7375
Const(fcons), Const(p), Const(i))
7476
return nothing
7577
end
@@ -81,7 +83,7 @@ function lagrangian(x, _f::Function, cons::Function, p, λ, σ = one(eltype(x)))
8183
end
8284

8385
function lag_grad(x, dx, lagrangian::Function, _f::Function, cons::Function, p, σ, λ)
84-
Enzyme.autodiff(Enzyme.Reverse, lagrangian, Active, Enzyme.Duplicated(x, dx),
86+
Enzyme.autodiff_deferred(Enzyme.Reverse, Const(lagrangian), Active, Enzyme.Duplicated(x, dx),
8587
Const(_f), Const(cons), Const(p), Const(λ), Const(σ))
8688
return nothing
8789
end
@@ -185,7 +187,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
185187
if hv == true && f.hv === nothing
186188
function hv!(H, θ, v, p = p)
187189
H .= Enzyme.autodiff(
188-
Enzyme.Forward, hv_f2_alloc, DuplicatedNoNeed, Duplicated(θ, v),
190+
Enzyme.Forward, hv_f2_alloc, Duplicated(θ, v),
189191
Const(f.f), Const(p)
190192
)[1]
191193
end
@@ -529,7 +531,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
529531
for i in eachindex(Jaccache)
530532
Enzyme.make_zero!(Jaccache[i])
531533
end
532-
y, Jaccache = Enzyme.autodiff(Enzyme.Forward, f.cons, Duplicated,
534+
Jaccache, y = Enzyme.autodiff(Enzyme.ForwardWithPrimal, f.cons, Duplicated,
533535
BatchDuplicated(θ, seeds), Const(p))
534536
if size(y, 1) == 1
535537
return reduce(vcat, Jaccache)

src/OptimizationDISparseExt.jl

+1
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ function instantiate_function(
212212
function cons_oop(x, i)
213213
_res = zeros(eltype(x))
214214
f.cons(_res, x, p)
215+
@show _res
215216
return _res[i]
216217
end
217218

test/adtests.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,8 @@ optprob.cons_h(H3, x0)
164164
optprob.cons(res, x0)
165165
@test res == [0.0]
166166
J = Array{Float64}(undef, 2)
167-
@test_broken optprob.cons_j(J, [5.0, 3.0])
168-
# @test J == [10.0, 6.0]
167+
optprob.cons_j(J, [5.0, 3.0])
168+
@test J == [10.0, 6.0]
169169
vJ = Array{Float64}(undef, 2)
170170
optprob.cons_vjp(vJ, [5.0, 3.0], [1.0])
171171
@test vJ == [10.0, 6.0]
@@ -202,8 +202,8 @@ optprob.cons_h(H3, x0)
202202
optprob.cons(res, x0)
203203
@test res == [0.0]
204204
J = Array{Float64}(undef, 2)
205-
@test_broken optprob.cons_j(J, [5.0, 3.0])
206-
# @test J == [10.0, 6.0]
205+
optprob.cons_j(J, [5.0, 3.0])
206+
@test J == [10.0, 6.0]
207207
vJ = Array{Float64}(undef, 2)
208208
optprob.cons_vjp(vJ, [5.0, 3.0], [1.0])
209209
@test vJ == [10.0, 6.0]

0 commit comments

Comments
 (0)