Skip to content

Commit 6d4f6d7

Browse files
hessian of lagrangian
1 parent eb74297 commit 6d4f6d7

File tree

2 files changed

+68
-46
lines changed

2 files changed

+68
-46
lines changed

ext/OptimizationEnzymeExt.jl

+47-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module OptimizationEnzymeExt
33
import OptimizationBase, OptimizationBase.ArrayInterface
44
import OptimizationBase.SciMLBase: OptimizationFunction
55
import OptimizationBase.SciMLBase
6-
import OptimizationBase.LinearAlgebra: I
6+
import OptimizationBase.LinearAlgebra: I, dot
77
import OptimizationBase.ADTypes: AutoEnzyme
88
using Enzyme
99
using Core: Vararg
@@ -76,6 +76,18 @@ function cons_f2_oop(x, dx, fcons, p, i)
7676
return nothing
7777
end
7878

79+
function lagrangian(x, _f::Function, cons::Function, p, λ, σ = one(eltype(x)))::Float64
80+
res = zeros(eltype(x), length(λ))
81+
cons(res, x, p)
82+
return σ * _f(x, p) + dot(λ, res)
83+
end
84+
85+
function lag_grad(x, dx, lagrangian::Function, _f::Function, cons::Function, p, σ, λ)
86+
Enzyme.autodiff_deferred(Enzyme.Reverse, lagrangian, Active, Enzyme.Duplicated(x, dx),
87+
Const(_f), Const(cons), Const(p), Const(λ), Const(σ))
88+
return nothing
89+
end
90+
7991
function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
8092
adtype::AutoEnzyme, p,
8193
num_cons = 0)
@@ -219,7 +231,40 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
219231
end
220232

221233
if f.lag_h === nothing
222-
lag_h = nothing # Consider implementing this
234+
lag_vdθ = Tuple((Array(r) for r in eachrow(I(length(x)) * one(eltype(x)))))
235+
lag_bθ = zeros(eltype(x), length(x))
236+
237+
if f.hess_prototype === nothing
238+
lag_vdbθ = Tuple(zeros(eltype(x), length(x)) for i in eachindex(x))
239+
else
240+
#useless right now, looks like there is no way to tell Enzyme the sparsity pattern?
241+
lag_vdbθ = Tuple((copy(r) for r in eachrow(f.hess_prototype)))
242+
end
243+
244+
function lag_h(h, θ, σ, μ)
245+
Enzyme.make_zero!.(lag_vdθ)
246+
Enzyme.make_zero!(lag_bθ)
247+
Enzyme.make_zero!.(lag_vdbθ)
248+
249+
Enzyme.autodiff(Enzyme.Forward,
250+
lag_grad,
251+
Enzyme.BatchDuplicated(θ, lag_vdθ),
252+
Enzyme.BatchDuplicatedNoNeed(lag_bθ, lag_vdbθ),
253+
Const(lagrangian),
254+
Const(f.f),
255+
Const(f.cons),
256+
Const(p),
257+
Const(σ),
258+
Const(μ)
259+
)
260+
k = 0
261+
262+
for i in eachindex(θ)
263+
vec_lagv = lag_vdbθ[i]
264+
h[k+1:k+i] .= @view(vec_lagv[1:i])
265+
k += i
266+
end
267+
end
223268
else
224269
lag_h = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p)
225270
end

src/OptimizationDISparseExt.jl

+21-44
Original file line numberDiff line numberDiff line change
@@ -174,12 +174,12 @@ function instantiate_function(
174174
conshess_colors = f.cons_hess_colorvec
175175
if cons !== nothing && f.cons_h === nothing
176176
fncs = [@closure (x) -> cons_oop(x)[i] for i in 1:num_cons]
177-
extras_cons_hess = Vector{DifferentiationInterface.SparseHessianExtras}(undef, length(fncs))
178-
for ind in 1:num_cons
179-
extras_cons_hess[ind] = prepare_hessian(fncs[ind], soadtype, x)
180-
end
181-
conshess_sparsity = [sum(sparse, cons)]
182-
conshess_colors = getfield.(extras_cons_hess, Ref(:colors))
177+
# extras_cons_hess = Vector(undef, length(fncs))
178+
# for ind in 1:num_cons
179+
# extras_cons_hess[ind] = prepare_hessian(fncs[ind], soadtype, x)
180+
# end
181+
# conshess_sparsity = getfield.(extras_cons_hess, :sparsity)
182+
# conshess_colors = getfield.(extras_cons_hess, :colors)
183183
function cons_h(H, θ)
184184
for i in 1:num_cons
185185
hessian!(fncs[i], H[i], soadtype, θ)
@@ -189,56 +189,33 @@ function instantiate_function(
189189
cons_h = (res, θ) -> f.cons_h(res, θ, p)
190190
end
191191

192-
function lagrangian(x, σ = one(eltype(x)))
193-
θ = x[1:end-num_cons]
194-
λ = x[end-num_cons+1:end]
195-
return σ * _f(θ) + dot(λ, cons_oop(θ))
192+
function lagrangian(x, σ = one(eltype(x)), λ = ones(eltype(x), num_cons))
193+
return σ * _f(x) + dot(λ, cons_oop(x))
196194
end
197195

196+
lag_hess_prototype = f.lag_hess_prototype
198197
if f.lag_h === nothing
199-
lag_extras = prepare_hessian(lagrangian, soadtype, vcat(x, ones(eltype(x), num_cons)))
198+
lag_extras = prepare_hessian(lagrangian, soadtype, x)
200199
lag_hess_prototype = lag_extras.sparsity
201-
202-
function lag_h(H::Matrix, θ, σ, λ)
203-
@show size(H)
204-
@show size(θ)
205-
@show size(λ)
200+
201+
function lag_h(H::AbstractMatrix, θ, σ, λ)
206202
if σ == zero(eltype(θ))
207203
cons_h(H, θ)
208204
H *= λ
209205
else
210-
hessian!(lagrangian, H, soadtype, vcat(θ, λ), lag_extras)
206+
hessian!(x -> lagrangian(x, σ, λ), H, soadtype, θ, lag_extras)
211207
end
212208
end
213209

214210
function lag_h(h, θ, σ, λ)
215-
# @show h
216-
sparseHproto = findnz(lag_extras.sparsity)
217-
H = sparse(sparseHproto[1], sparseHproto[2], zeros(eltype(θ), length(sparseHproto[1])))
218-
if σ == zero(eltype(θ))
219-
cons_h(H, θ)
220-
H *= λ
221-
else
222-
hessian!(lagrangian, H, soadtype, vcat(θ, λ), lag_extras)
223-
k = 0
224-
rows, cols, _ = findnz(H)
225-
for (i, j) in zip(rows, cols)
226-
if i <= j
227-
k += 1
228-
h[k] = σ * H[i, j]
229-
end
230-
end
231-
k = 0
232-
for λi in λ
233-
if Hi isa SparseMatrixCSC
234-
rows, cols, _ = findnz(Hi)
235-
for (i, j) in zip(rows, cols)
236-
if i <= j
237-
k += 1
238-
h[k] += λi * Hi[i, j]
239-
end
240-
end
241-
end
211+
H = eltype(θ).(lag_hess_prototype)
212+
hessian!(x -> lagrangian(x, σ, λ), H, soadtype, θ, lag_extras)
213+
k = 0
214+
rows, cols, _ = findnz(H)
215+
for (i, j) in zip(rows, cols)
216+
if i <= j
217+
k += 1
218+
h[k] = H[i, j]
242219
end
243220
end
244221
end

0 commit comments

Comments
 (0)