Skip to content

Commit f4bd1c6

Browse files
Merge branch 'main' into compathelper/new_version/2024-08-24-00-11-57-238-00614299169
2 parents 838bad5 + 2f6b0eb commit f4bd1c6

8 files changed

+1246
-494
lines changed

.github/workflows/SpellCheck.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,4 @@ jobs:
1010
- name: Checkout Actions Repository
1111
uses: actions/checkout@v4
1212
- name: Check spelling
13-
uses: crate-ci/typos@v1.23.6
13+
uses: crate-ci/typos@v1.24.3

Project.toml

+2
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ ArrayInterface = "7.6"
4343
DifferentiationInterface = "0.5"
4444
DocStringExtensions = "0.9"
4545
Enzyme = "0.12.12"
46+
FastClosures = "0.3"
4647
FiniteDiff = "2.12"
4748
ForwardDiff = "0.10.26"
4849
LinearAlgebra = "1.9, 1.10"
@@ -52,6 +53,7 @@ Reexport = "1.2"
5253
Requires = "1"
5354
ReverseDiff = "1.14"
5455
SciMLBase = "2"
56+
SparseConnectivityTracer = "0.6"
5557
SparseMatrixColorings = "0.4"
5658
SymbolicAnalysis = "0.3"
5759
SymbolicIndexingInterface = "0.3"

ext/OptimizationEnzymeExt.jl

+48-26
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ function hv_f2_alloc(x, f, p)
4343
Enzyme.autodiff_deferred(Enzyme.Reverse,
4444
firstapply,
4545
Active,
46-
f,
46+
Const(f),
4747
Enzyme.Duplicated(x, dx),
4848
Const(p)
4949
)
@@ -105,7 +105,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
105105
)
106106
end
107107
elseif g == true
108-
grad = (G, θ) -> f.grad(G, θ, p)
108+
grad = (G, θ, p = p) -> f.grad(G, θ, p)
109109
else
110110
grad = nothing
111111
end
@@ -123,7 +123,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
123123
return y
124124
end
125125
elseif fg == true
126-
fg! = (res, θ) -> f.fg(res, θ, p)
126+
fg! = (res, θ, p = p) -> f.fg(res, θ, p)
127127
else
128128
fg! = nothing
129129
end
@@ -139,7 +139,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
139139
vdbθ = Tuple((copy(r) for r in eachrow(f.hess_prototype)))
140140
end
141141

142-
function hess(res, θ)
142+
function hess(res, θ, p = p)
143143
Enzyme.make_zero!(bθ)
144144
Enzyme.make_zero!.(vdbθ)
145145

@@ -156,13 +156,13 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
156156
end
157157
end
158158
elseif h == true
159-
hess = (H, θ) -> f.hess(H, θ, p)
159+
hess = (H, θ, p = p) -> f.hess(H, θ, p)
160160
else
161161
hess = nothing
162162
end
163163

164164
if fgh == true && f.fgh === nothing
165-
function fgh!(G, H, θ)
165+
function fgh!(G, H, θ, p = p)
166166
vdθ = Tuple((Array(r) for r in eachrow(I(length(θ)) * one(eltype(θ)))))
167167
vdbθ = Tuple(zeros(eltype(θ), length(θ)) for i in eachindex(θ))
168168

@@ -179,20 +179,20 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
179179
end
180180
end
181181
elseif fgh == true
182-
fgh! = (G, H, θ) -> f.fgh(G, H, θ, p)
182+
fgh! = (G, H, θ, p = p) -> f.fgh(G, H, θ, p)
183183
else
184184
fgh! = nothing
185185
end
186186

187187
if hv == true && f.hv === nothing
188-
function hv!(H, θ, v)
188+
function hv!(H, θ, v, p = p)
189189
H .= Enzyme.autodiff(
190190
Enzyme.Forward, hv_f2_alloc, DuplicatedNoNeed, Duplicated(θ, v),
191-
Const(_f), Const(f.f), Const(p)
191+
Const(f.f), Const(p)
192192
)[1]
193193
end
194194
elseif hv == true
195-
hv! = (H, θ, v) -> f.hv(H, θ, v, p)
195+
hv! = (H, θ, v, p = p) -> f.hv(H, θ, v, p)
196196
else
197197
hv! = nothing
198198
end
@@ -247,7 +247,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
247247
cons_j! = nothing
248248
end
249249

250-
if cons !== nothing && cons_vjp == true && f.cons_vjp == true
250+
if cons !== nothing && cons_vjp == true && f.cons_vjp === nothing
251251
cons_res = zeros(eltype(x), num_cons)
252252
function cons_vjp!(res, θ, v)
253253
Enzyme.make_zero!(res)
@@ -267,7 +267,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
267267
cons_vjp! = nothing
268268
end
269269

270-
if cons !== nothing && cons_jvp == true && f.cons_jvp == true
270+
if cons !== nothing && cons_jvp == true && f.cons_jvp === nothing
271271
cons_res = zeros(eltype(x), num_cons)
272272

273273
function cons_jvp!(res, θ, v)
@@ -327,7 +327,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
327327
lag_vdbθ = Tuple((copy(r) for r in eachrow(f.hess_prototype)))
328328
end
329329

330-
function lag_h!(h, θ, σ, μ)
330+
function lag_h!(h, θ, σ, μ, p = p)
331331
Enzyme.make_zero!(lag_bθ)
332332
Enzyme.make_zero!.(lag_vdbθ)
333333

@@ -350,8 +350,30 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
350350
k += i
351351
end
352352
end
353+
354+
function lag_h!(H::AbstractMatrix, θ, σ, μ, p = p)
355+
Enzyme.make_zero!(H)
356+
Enzyme.make_zero!(lag_bθ)
357+
Enzyme.make_zero!.(lag_vdbθ)
358+
359+
Enzyme.autodiff(Enzyme.Forward,
360+
lag_grad,
361+
Enzyme.BatchDuplicated(θ, lag_vdθ),
362+
Enzyme.BatchDuplicatedNoNeed(lag_bθ, lag_vdbθ),
363+
Const(lagrangian),
364+
Const(f.f),
365+
Const(f.cons),
366+
Const(p),
367+
Const(σ),
368+
Const(μ)
369+
)
370+
371+
for i in eachindex(θ)
372+
H[i, :] .= lag_vdbθ[i]
373+
end
374+
end
353375
elseif lag_h == true && cons !== nothing
354-
lag_h! = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p)
376+
lag_h! = (θ, σ, μ, p = p) -> f.lag_h(θ, σ, μ, p)
355377
else
356378
lag_h! = nothing
357379
end
@@ -389,7 +411,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
389411
lag_h = false)
390412
if g == true && f.grad === nothing
391413
res = zeros(eltype(x), size(x))
392-
function grad(θ)
414+
function grad, p = p)
393415
Enzyme.make_zero!(res)
394416
Enzyme.autodiff(Enzyme.Reverse,
395417
Const(firstapply),
@@ -401,14 +423,14 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
401423
return res
402424
end
403425
elseif fg == true
404-
grad = (θ) -> f.grad(θ, p)
426+
grad =, p = p) -> f.grad(θ, p)
405427
else
406428
grad = nothing
407429
end
408430

409431
if fg == true && f.fg === nothing
410432
res_fg = zeros(eltype(x), size(x))
411-
function fg!(θ)
433+
function fg!, p = p)
412434
Enzyme.make_zero!(res_fg)
413435
y = Enzyme.autodiff(Enzyme.ReverseWithPrimal,
414436
Const(firstapply),
@@ -420,7 +442,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
420442
return y, res
421443
end
422444
elseif fg == true
423-
fg! = (θ) -> f.fg(θ, p)
445+
fg! =, p = p) -> f.fg(θ, p)
424446
else
425447
fg! = nothing
426448
end
@@ -430,7 +452,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
430452
= zeros(eltype(x), length(x))
431453
vdbθ = Tuple(zeros(eltype(x), length(x)) for i in eachindex(x))
432454

433-
function hess(θ)
455+
function hess, p = p)
434456
Enzyme.make_zero!(bθ)
435457
Enzyme.make_zero!.(vdbθ)
436458

@@ -446,7 +468,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
446468
vcat, [reshape(vdbθ[i], (1, length(vdbθ[i]))) for i in eachindex(θ)])
447469
end
448470
elseif h == true
449-
hess = (θ) -> f.hess(θ, p)
471+
hess =, p = p) -> f.hess(θ, p)
450472
else
451473
hess = nothing
452474
end
@@ -457,7 +479,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
457479
G_fgh = zeros(eltype(x), length(x))
458480
H_fgh = zeros(eltype(x), length(x), length(x))
459481

460-
function fgh!(θ)
482+
function fgh!, p = p)
461483
Enzyme.make_zero!(G_fgh)
462484
Enzyme.make_zero!(H_fgh)
463485
Enzyme.make_zero!.(vdbθ_fgh)
@@ -476,20 +498,20 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
476498
return G_fgh, H_fgh
477499
end
478500
elseif fgh == true
479-
fgh! = (θ) -> f.fgh(θ, p)
501+
fgh! =, p = p) -> f.fgh(θ, p)
480502
else
481503
fgh! = nothing
482504
end
483505

484506
if hv == true && f.hv === nothing
485-
function hv!(θ, v)
507+
function hv!(θ, v, p = p)
486508
return Enzyme.autodiff(
487509
Enzyme.Forward, hv_f2_alloc, DuplicatedNoNeed, Duplicated(θ, v),
488510
Const(_f), Const(f.f), Const(p)
489511
)[1]
490512
end
491513
elseif hv == true
492-
hv! = (θ, v) -> f.hv(θ, v, p)
514+
hv! = (θ, v, p = p) -> f.hv(θ, v, p)
493515
else
494516
hv! = f.hv
495517
end
@@ -604,7 +626,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
604626
lag_vdbθ = Tuple((copy(r) for r in eachrow(f.hess_prototype)))
605627
end
606628

607-
function lag_h!(θ, σ, μ)
629+
function lag_h!(θ, σ, μ, p = p)
608630
Enzyme.make_zero!(lag_bθ)
609631
Enzyme.make_zero!.(lag_vdbθ)
610632

@@ -630,7 +652,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
630652
return res
631653
end
632654
elseif lag_h == true && cons !== nothing
633-
lag_h! = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p)
655+
lag_h! = (θ, σ, μ, p = p) -> f.lag_h(θ, σ, μ, p)
634656
else
635657
lag_h! = nothing
636658
end

0 commit comments

Comments
 (0)