Skip to content
This repository was archived by the owner on Aug 25, 2025. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ OptimizationZygoteExt = "Zygote"
ADTypes = "1.3"
ArrayInterface = "7.6"
DocStringExtensions = "0.9"
Enzyme = "0.11.11, =0.12.6"
Enzyme = "0.12.12"
FiniteDiff = "2.12"
ForwardDiff = "0.10.26"
LinearAlgebra = "1.9, 1.10"
Expand Down
145 changes: 50 additions & 95 deletions ext/OptimizationEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,47 @@ isdefined(Base, :get_extension) ? (using Enzyme) : (using ..Enzyme)
end
end

function inner_grad(θ, bθ, f, p, args::Vararg{Any, N}) where N
Enzyme.autodiff_deferred(Enzyme.Reverse,
Const(firstapply),
Active,
Const(f),
Enzyme.Duplicated(θ, bθ),
Const(p),
Const.(args)...),
return nothing
end

function hv_f2_alloc(x, f, p, args...)
dx = Enzyme.make_zero(x)
Enzyme.autodiff_deferred(Enzyme.Reverse,
firstapply,
Active,
f,
Enzyme.Duplicated(x, dx),
Const(p),
Const.(args)...)
return dx
end

function inner_cons(x, p, num_cons, i)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
function inner_cons(x, p, num_cons, i)
function inner_cons(f, x, p, num_cons, i)

res = zeros(eltype(x), num_cons)
f.cons(res, x, p)
return res[i]
end

function cons_f2(x, dx, fcons, p, num_cons, i)
Enzyme.autodiff_deferred(Enzyme.Reverse, inner_cons, Active, Enzyme.Duplicated(x, dx), Const(p), Const(num_cons), Const(i))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Enzyme.autodiff_deferred(Enzyme.Reverse, inner_cons, Active, Enzyme.Duplicated(x, dx), Const(p), Const(num_cons), Const(i))
Enzyme.autodiff_deferred(Enzyme.Reverse, inner_cons, Active, fcons, Enzyme.Duplicated(x, dx), Const(p), Const(num_cons), Const(i))

Won't we need to zero and duplicate the function if it's a closure?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't follow, can you add more detail what you mean by zeroing the function? It doesn't need to be duplicated it should be Const iiuc (done that in #60)

return nothing
end

function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
adtype::AutoEnzyme, p,
num_cons = 0)
if f.grad === nothing
grad = let
function (res, θ, args...)
res .= zero(eltype(res))
Enzyme.make_zero!(res)
Enzyme.autodiff(Enzyme.Reverse,
Const(firstapply),
Active,
Expand All @@ -36,24 +70,14 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
end

if f.hess === nothing
function g(θ, bθ, f, p, args...)
Enzyme.autodiff_deferred(Enzyme.Reverse,
Const(firstapply),
Active,
Const(f),
Enzyme.Duplicated(θ, bθ),
Const(p),
Const.(args)...),
return nothing
end
function hess(res, θ, args...)
vdθ = Tuple((Array(r) for r in eachrow(I(length(θ)) * 1.0)))

bθ = zeros(length(θ))
vdbθ = Tuple(zeros(length(θ)) for i in eachindex(θ))

Enzyme.autodiff(Enzyme.Forward,
g,
inner_grad,
Enzyme.BatchDuplicated(θ, vdθ),
Enzyme.BatchDuplicated(bθ, vdbθ),
Const(f.f),
Expand All @@ -69,19 +93,8 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
end

if f.hv === nothing
function f2(x, f, p, args...)
dx = zeros(length(x))
Enzyme.autodiff_deferred(Enzyme.Reverse,
firstapply,
Active,
f,
Enzyme.Duplicated(x, dx),
Const(p),
Const.(args)...)
return dx
end
hv = function (H, θ, v, args...)
H .= Enzyme.autodiff(Enzyme.Forward, f2, DuplicatedNoNeed, Duplicated(θ, v),
H .= Enzyme.autodiff(Enzyme.Forward, hv_f2_alloc, DuplicatedNoNeed, Duplicated(θ, v),
Const(_f), Const(f.f), Const(p),
Const.(args)...)[1]
end
Expand Down Expand Up @@ -109,19 +122,6 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
end

if cons !== nothing && f.cons_h === nothing
fncs = map(1:num_cons) do i
function (x)
res = zeros(eltype(x), num_cons)
f.cons(res, x, p)
return res[i]
end
end

function f2(x, dx, fnc)
Enzyme.autodiff_deferred(Enzyme.Reverse, fnc, Enzyme.Duplicated(x, dx))
return nothing
end

cons_h = function (res, θ)
vdθ = Tuple((Array(r) for r in eachrow(I(length(θ)) * 1.0)))
bθ = zeros(length(θ))
Expand All @@ -132,10 +132,14 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
el .= zeros(length(θ))
end
Enzyme.autodiff(Enzyme.Forward,
f2,
cons_f2,
Enzyme.BatchDuplicated(θ, vdθ),
Enzyme.BatchDuplicated(bθ, vdbθ),
Const(fncs[i]))
Const(f.cons),
Const(p),
Const(num_const),
Const(i)
)

for j in eachindex(θ)
res[i][j, :] .= vdbθ[j]
Expand All @@ -161,7 +165,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true},

if f.grad === nothing
function grad(res, θ, args...)
res .= zero(eltype(res))
Enzyme.make_zero!(res)
Enzyme.autodiff(Enzyme.Reverse,
Const(firstapply),
Active,
Expand All @@ -175,21 +179,13 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true},
end

if f.hess === nothing
function g(θ, bθ, f, p, args...)
Enzyme.autodiff_deferred(Enzyme.Reverse, Const(firstapply), Active, Const(f),
Enzyme.Duplicated(θ, bθ),
Const(p),
Const.(args)...)
return nothing
end
function hess(res, θ, args...)
vdθ = Tuple((Array(r) for r in eachrow(I(length(θ)) * 1.0)))

bθ = zeros(length(θ))
vdbθ = Tuple(zeros(length(θ)) for i in eachindex(θ))

Enzyme.autodiff(Enzyme.Forward,
g,
inner_grad,
Enzyme.BatchDuplicated(θ, vdθ),
Enzyme.BatchDuplicated(bθ, vdbθ),
Const(f.f),
Expand All @@ -205,17 +201,8 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true},
end

if f.hv === nothing
function f2(x, f, p, args...)
dx = zeros(length(x))
Enzyme.autodiff_deferred(Enzyme.Reverse, firstapply, Active,
f,
Enzyme.Duplicated(x, dx),
Const(p),
Const.(args)...)
return dx
end
hv = function (H, θ, v, args...)
H .= Enzyme.autodiff(Enzyme.Forward, f2, DuplicatedNoNeed, Duplicated(θ, v),
H .= Enzyme.autodiff(Enzyme.Forward, hv_f2_alloc, DuplicatedNoNeed, Duplicated(θ, v),
Const(f.f), Const(p),
Const.(args)...)[1]
end
Expand Down Expand Up @@ -294,24 +281,14 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
end

if f.hess === nothing
function g(θ, bθ, f, p, args...)
Enzyme.autodiff_deferred(Enzyme.Reverse,
Const(firstapply),
Active,
Const(f),
Enzyme.Duplicated(θ, bθ),
Const(p),
Const.(args)...),
return nothing
end
function hess(θ, args...)
vdθ = Tuple((Array(r) for r in eachrow(I(length(θ)) * 1.0)))

bθ = zeros(length(θ))
vdbθ = Tuple(zeros(length(θ)) for i in eachindex(θ))

Enzyme.autodiff(Enzyme.Forward,
g,
inner_grad,
Enzyme.BatchDuplicated(θ, vdθ),
Enzyme.BatchDuplicated(bθ, vdbθ),
Const(f.f),
Expand Down Expand Up @@ -418,7 +395,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false},
res = zeros(eltype(x), size(x))
grad = let res = res
function (θ, args...)
res .= zero(eltype(res))
Enzyme.make_zero!(res)
Enzyme.autodiff(Enzyme.Reverse,
Const(firstapply),
Active,
Expand All @@ -434,24 +411,14 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false},
end

if f.hess === nothing
function g(θ, bθ, f, p, args...)
Enzyme.autodiff_deferred(Enzyme.Reverse,
Const(firstapply),
Active,
Const(f),
Enzyme.Duplicated(θ, bθ),
Const(p),
Const.(args)...),
return nothing
end
function hess(θ, args...)
vdθ = Tuple((Array(r) for r in eachrow(I(length(θ)) * 1.0)))

bθ = zeros(length(θ))
vdbθ = Tuple(zeros(length(θ)) for i in eachindex(θ))

Enzyme.autodiff(Enzyme.Forward,
g,
inner_grad,
Enzyme.BatchDuplicated(θ, vdθ),
Enzyme.BatchDuplicated(bθ, vdbθ),
Const(f.f),
Expand All @@ -465,20 +432,8 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false},
end

if f.hv === nothing
dx = zeros(length(x))
function f2(x, f, p, args...)
dx .= zero(eltype(dx))
Enzyme.autodiff_deferred(Enzyme.Reverse,
firstapply,
Active,
f,
Enzyme.Duplicated(x, dx),
Const(p),
Const.(args)...)
return dx
end
hv = function (θ, v, args...)
Enzyme.autodiff(Enzyme.Forward, f2, DuplicatedNoNeed, Duplicated(θ, v),
Enzyme.autodiff(Enzyme.Forward, hv_f2_alloc, DuplicatedNoNeed, Duplicated(θ, v),
Const(_f), Const(f.f), Const(p),
Const.(args)...)[1]
end
Expand Down