diff --git a/Project.toml b/Project.toml index 714e701..4973e88 100644 --- a/Project.toml +++ b/Project.toml @@ -38,7 +38,7 @@ OptimizationZygoteExt = "Zygote" ADTypes = "1.3" ArrayInterface = "7.6" DocStringExtensions = "0.9" -Enzyme = "0.11.11, =0.12.6" +Enzyme = "0.12.12" FiniteDiff = "2.12" ForwardDiff = "0.10.26" LinearAlgebra = "1.9, 1.10" diff --git a/ext/OptimizationEnzymeExt.jl b/ext/OptimizationEnzymeExt.jl index 223d6d9..08f362d 100644 --- a/ext/OptimizationEnzymeExt.jl +++ b/ext/OptimizationEnzymeExt.jl @@ -15,13 +15,47 @@ isdefined(Base, :get_extension) ? (using Enzyme) : (using ..Enzyme) end end +function inner_grad(θ, bθ, f, p, args::Vararg{Any, N}) where N + Enzyme.autodiff_deferred(Enzyme.Reverse, + Const(firstapply), + Active, + Const(f), + Enzyme.Duplicated(θ, bθ), + Const(p), + Const.(args)...), + return nothing +end + +function hv_f2_alloc(x, f, p, args...) + dx = Enzyme.make_zero(x) + Enzyme.autodiff_deferred(Enzyme.Reverse, + firstapply, + Active, + f, + Enzyme.Duplicated(x, dx), + Const(p), + Const.(args)...) + return dx +end + +function inner_cons(x, p, num_cons, i) + res = zeros(eltype(x), num_cons) + f.cons(res, x, p) + return res[i] +end + +function cons_f2(x, dx, fcons, p, num_cons, i) + Enzyme.autodiff_deferred(Enzyme.Reverse, inner_cons, Active, Enzyme.Duplicated(x, dx), Const(p), Const(num_cons), Const(i)) + return nothing +end + function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, adtype::AutoEnzyme, p, num_cons = 0) if f.grad === nothing grad = let function (res, θ, args...) - res .= zero(eltype(res)) + Enzyme.make_zero!(res) Enzyme.autodiff(Enzyme.Reverse, Const(firstapply), Active, @@ -36,16 +70,6 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, end if f.hess === nothing - function g(θ, bθ, f, p, args...) - Enzyme.autodiff_deferred(Enzyme.Reverse, - Const(firstapply), - Active, - Const(f), - Enzyme.Duplicated(θ, bθ), - Const(p), - Const.(args)...), - return nothing - end function hess(res, θ, args...) vdθ = Tuple((Array(r) for r in eachrow(I(length(θ)) * 1.0))) @@ -53,7 +77,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, vdbθ = Tuple(zeros(length(θ)) for i in eachindex(θ)) Enzyme.autodiff(Enzyme.Forward, - g, + inner_grad, Enzyme.BatchDuplicated(θ, vdθ), Enzyme.BatchDuplicated(bθ, vdbθ), Const(f.f), @@ -69,19 +93,8 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, end if f.hv === nothing - function f2(x, f, p, args...) - dx = zeros(length(x)) - Enzyme.autodiff_deferred(Enzyme.Reverse, - firstapply, - Active, - f, - Enzyme.Duplicated(x, dx), - Const(p), - Const.(args)...) - return dx - end hv = function (H, θ, v, args...) - H .= Enzyme.autodiff(Enzyme.Forward, f2, DuplicatedNoNeed, Duplicated(θ, v), + H .= Enzyme.autodiff(Enzyme.Forward, hv_f2_alloc, DuplicatedNoNeed, Duplicated(θ, v), Const(_f), Const(f.f), Const(p), Const.(args)...)[1] end @@ -109,19 +122,6 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, end if cons !== nothing && f.cons_h === nothing - fncs = map(1:num_cons) do i - function (x) - res = zeros(eltype(x), num_cons) - f.cons(res, x, p) - return res[i] - end - end - - function f2(x, dx, fnc) - Enzyme.autodiff_deferred(Enzyme.Reverse, fnc, Enzyme.Duplicated(x, dx)) - return nothing - end - cons_h = function (res, θ) vdθ = Tuple((Array(r) for r in eachrow(I(length(θ)) * 1.0))) bθ = zeros(length(θ)) @@ -132,10 +132,14 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, el .= zeros(length(θ)) end Enzyme.autodiff(Enzyme.Forward, - f2, + cons_f2, Enzyme.BatchDuplicated(θ, vdθ), Enzyme.BatchDuplicated(bθ, vdbθ), - Const(fncs[i])) + Const(f.cons), + Const(p), + Const(num_cons), + Const(i) + ) for j in eachindex(θ) res[i][j, :] .= vdbθ[j] @@ -161,7 +165,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, if f.grad === nothing function grad(res, θ, args...) - res .= zero(eltype(res)) + Enzyme.make_zero!(res) Enzyme.autodiff(Enzyme.Reverse, Const(firstapply), Active, @@ -175,21 +179,13 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, end if f.hess === nothing - function g(θ, bθ, f, p, args...) - Enzyme.autodiff_deferred(Enzyme.Reverse, Const(firstapply), Active, Const(f), - Enzyme.Duplicated(θ, bθ), - Const(p), - Const.(args)...) - return nothing - end function hess(res, θ, args...) vdθ = Tuple((Array(r) for r in eachrow(I(length(θ)) * 1.0))) - bθ = zeros(length(θ)) vdbθ = Tuple(zeros(length(θ)) for i in eachindex(θ)) Enzyme.autodiff(Enzyme.Forward, - g, + inner_grad, Enzyme.BatchDuplicated(θ, vdθ), Enzyme.BatchDuplicated(bθ, vdbθ), Const(f.f), @@ -205,17 +201,8 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, end if f.hv === nothing - function f2(x, f, p, args...) - dx = zeros(length(x)) - Enzyme.autodiff_deferred(Enzyme.Reverse, firstapply, Active, - f, - Enzyme.Duplicated(x, dx), - Const(p), - Const.(args)...) - return dx - end hv = function (H, θ, v, args...) - H .= Enzyme.autodiff(Enzyme.Forward, f2, DuplicatedNoNeed, Duplicated(θ, v), + H .= Enzyme.autodiff(Enzyme.Forward, hv_f2_alloc, DuplicatedNoNeed, Duplicated(θ, v), Const(f.f), Const(p), Const.(args)...)[1] end @@ -294,16 +281,6 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x end if f.hess === nothing - function g(θ, bθ, f, p, args...) - Enzyme.autodiff_deferred(Enzyme.Reverse, - Const(firstapply), - Active, - Const(f), - Enzyme.Duplicated(θ, bθ), - Const(p), - Const.(args)...), - return nothing - end function hess(θ, args...) vdθ = Tuple((Array(r) for r in eachrow(I(length(θ)) * 1.0))) @@ -311,7 +288,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x vdbθ = Tuple(zeros(length(θ)) for i in eachindex(θ)) Enzyme.autodiff(Enzyme.Forward, - g, + inner_grad, Enzyme.BatchDuplicated(θ, vdθ), Enzyme.BatchDuplicated(bθ, vdbθ), Const(f.f), @@ -418,7 +395,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, res = zeros(eltype(x), size(x)) grad = let res = res function (θ, args...) - res .= zero(eltype(res)) + Enzyme.make_zero!(res) Enzyme.autodiff(Enzyme.Reverse, Const(firstapply), Active, @@ -434,16 +411,6 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, end if f.hess === nothing - function g(θ, bθ, f, p, args...) - Enzyme.autodiff_deferred(Enzyme.Reverse, - Const(firstapply), - Active, - Const(f), - Enzyme.Duplicated(θ, bθ), - Const(p), - Const.(args)...), - return nothing - end function hess(θ, args...) vdθ = Tuple((Array(r) for r in eachrow(I(length(θ)) * 1.0))) @@ -451,7 +418,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, vdbθ = Tuple(zeros(length(θ)) for i in eachindex(θ)) Enzyme.autodiff(Enzyme.Forward, - g, + inner_grad, Enzyme.BatchDuplicated(θ, vdθ), Enzyme.BatchDuplicated(bθ, vdbθ), Const(f.f), @@ -465,20 +432,8 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, end if f.hv === nothing - dx = zeros(length(x)) - function f2(x, f, p, args...) - dx .= zero(eltype(dx)) - Enzyme.autodiff_deferred(Enzyme.Reverse, - firstapply, - Active, - f, - Enzyme.Duplicated(x, dx), - Const(p), - Const.(args)...) - return dx - end hv = function (θ, v, args...) - Enzyme.autodiff(Enzyme.Forward, f2, DuplicatedNoNeed, Duplicated(θ, v), + Enzyme.autodiff(Enzyme.Forward, hv_f2_alloc, DuplicatedNoNeed, Duplicated(θ, v), Const(_f), Const(f.f), Const(p), Const.(args)...)[1] end