@@ -43,7 +43,7 @@ function hv_f2_alloc(x, f, p)
43
43
Enzyme. autodiff_deferred (Enzyme. Reverse,
44
44
firstapply,
45
45
Active,
46
- f ,
46
+ Const (f) ,
47
47
Enzyme. Duplicated (x, dx),
48
48
Const (p)
49
49
)
@@ -105,7 +105,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
105
105
)
106
106
end
107
107
elseif g == true
108
- grad = (G, θ) -> f. grad (G, θ, p)
108
+ grad = (G, θ, p = p ) -> f. grad (G, θ, p)
109
109
else
110
110
grad = nothing
111
111
end
@@ -123,7 +123,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
123
123
return y
124
124
end
125
125
elseif fg == true
126
- fg! = (res, θ) -> f. fg (res, θ, p)
126
+ fg! = (res, θ, p = p ) -> f. fg (res, θ, p)
127
127
else
128
128
fg! = nothing
129
129
end
@@ -139,7 +139,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
139
139
vdbθ = Tuple ((copy (r) for r in eachrow (f. hess_prototype)))
140
140
end
141
141
142
- function hess (res, θ)
142
+ function hess (res, θ, p = p )
143
143
Enzyme. make_zero! (bθ)
144
144
Enzyme. make_zero! .(vdbθ)
145
145
@@ -156,13 +156,13 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
156
156
end
157
157
end
158
158
elseif h == true
159
- hess = (H, θ) -> f. hess (H, θ, p)
159
+ hess = (H, θ, p = p ) -> f. hess (H, θ, p)
160
160
else
161
161
hess = nothing
162
162
end
163
163
164
164
if fgh == true && f. fgh === nothing
165
- function fgh! (G, H, θ)
165
+ function fgh! (G, H, θ, p = p )
166
166
vdθ = Tuple ((Array (r) for r in eachrow (I (length (θ)) * one (eltype (θ)))))
167
167
vdbθ = Tuple (zeros (eltype (θ), length (θ)) for i in eachindex (θ))
168
168
@@ -179,20 +179,20 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
179
179
end
180
180
end
181
181
elseif fgh == true
182
- fgh! = (G, H, θ) -> f. fgh (G, H, θ, p)
182
+ fgh! = (G, H, θ, p = p ) -> f. fgh (G, H, θ, p)
183
183
else
184
184
fgh! = nothing
185
185
end
186
186
187
187
if hv == true && f. hv === nothing
188
- function hv! (H, θ, v)
188
+ function hv! (H, θ, v, p = p )
189
189
H .= Enzyme. autodiff (
190
190
Enzyme. Forward, hv_f2_alloc, DuplicatedNoNeed, Duplicated (θ, v),
191
- Const (_f), Const ( f. f), Const (p)
191
+ Const (f. f), Const (p)
192
192
)[1 ]
193
193
end
194
194
elseif hv == true
195
- hv! = (H, θ, v) -> f. hv (H, θ, v, p)
195
+ hv! = (H, θ, v, p = p ) -> f. hv (H, θ, v, p)
196
196
else
197
197
hv! = nothing
198
198
end
@@ -247,7 +247,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
247
247
cons_j! = nothing
248
248
end
249
249
250
- if cons != = nothing && cons_vjp == true && f. cons_vjp == true
250
+ if cons != = nothing && cons_vjp == true && f. cons_vjp === nothing
251
251
cons_res = zeros (eltype (x), num_cons)
252
252
function cons_vjp! (res, θ, v)
253
253
Enzyme. make_zero! (res)
@@ -267,7 +267,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
267
267
cons_vjp! = nothing
268
268
end
269
269
270
- if cons != = nothing && cons_jvp == true && f. cons_jvp == true
270
+ if cons != = nothing && cons_jvp == true && f. cons_jvp === nothing
271
271
cons_res = zeros (eltype (x), num_cons)
272
272
273
273
function cons_jvp! (res, θ, v)
@@ -327,7 +327,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
327
327
lag_vdbθ = Tuple ((copy (r) for r in eachrow (f. hess_prototype)))
328
328
end
329
329
330
- function lag_h! (h, θ, σ, μ)
330
+ function lag_h! (h, θ, σ, μ, p = p )
331
331
Enzyme. make_zero! (lag_bθ)
332
332
Enzyme. make_zero! .(lag_vdbθ)
333
333
@@ -350,8 +350,30 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
350
350
k += i
351
351
end
352
352
end
353
+
354
+ function lag_h! (H:: AbstractMatrix , θ, σ, μ, p = p)
355
+ Enzyme. make_zero! (H)
356
+ Enzyme. make_zero! (lag_bθ)
357
+ Enzyme. make_zero! .(lag_vdbθ)
358
+
359
+ Enzyme. autodiff (Enzyme. Forward,
360
+ lag_grad,
361
+ Enzyme. BatchDuplicated (θ, lag_vdθ),
362
+ Enzyme. BatchDuplicatedNoNeed (lag_bθ, lag_vdbθ),
363
+ Const (lagrangian),
364
+ Const (f. f),
365
+ Const (f. cons),
366
+ Const (p),
367
+ Const (σ),
368
+ Const (μ)
369
+ )
370
+
371
+ for i in eachindex (θ)
372
+ H[i, :] .= lag_vdbθ[i]
373
+ end
374
+ end
353
375
elseif lag_h == true && cons != = nothing
354
- lag_h! = (θ, σ, μ) -> f. lag_h (θ, σ, μ, p)
376
+ lag_h! = (θ, σ, μ, p = p ) -> f. lag_h (θ, σ, μ, p)
355
377
else
356
378
lag_h! = nothing
357
379
end
@@ -389,7 +411,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
389
411
lag_h = false )
390
412
if g == true && f. grad === nothing
391
413
res = zeros (eltype (x), size (x))
392
- function grad (θ)
414
+ function grad (θ, p = p )
393
415
Enzyme. make_zero! (res)
394
416
Enzyme. autodiff (Enzyme. Reverse,
395
417
Const (firstapply),
@@ -401,14 +423,14 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
401
423
return res
402
424
end
403
425
elseif fg == true
404
- grad = (θ) -> f. grad (θ, p)
426
+ grad = (θ, p = p ) -> f. grad (θ, p)
405
427
else
406
428
grad = nothing
407
429
end
408
430
409
431
if fg == true && f. fg === nothing
410
432
res_fg = zeros (eltype (x), size (x))
411
- function fg! (θ)
433
+ function fg! (θ, p = p )
412
434
Enzyme. make_zero! (res_fg)
413
435
y = Enzyme. autodiff (Enzyme. ReverseWithPrimal,
414
436
Const (firstapply),
@@ -420,7 +442,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
420
442
return y, res
421
443
end
422
444
elseif fg == true
423
- fg! = (θ) -> f. fg (θ, p)
445
+ fg! = (θ, p = p ) -> f. fg (θ, p)
424
446
else
425
447
fg! = nothing
426
448
end
@@ -430,7 +452,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
430
452
bθ = zeros (eltype (x), length (x))
431
453
vdbθ = Tuple (zeros (eltype (x), length (x)) for i in eachindex (x))
432
454
433
- function hess (θ)
455
+ function hess (θ, p = p )
434
456
Enzyme. make_zero! (bθ)
435
457
Enzyme. make_zero! .(vdbθ)
436
458
@@ -446,7 +468,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
446
468
vcat, [reshape (vdbθ[i], (1 , length (vdbθ[i]))) for i in eachindex (θ)])
447
469
end
448
470
elseif h == true
449
- hess = (θ) -> f. hess (θ, p)
471
+ hess = (θ, p = p ) -> f. hess (θ, p)
450
472
else
451
473
hess = nothing
452
474
end
@@ -457,7 +479,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
457
479
G_fgh = zeros (eltype (x), length (x))
458
480
H_fgh = zeros (eltype (x), length (x), length (x))
459
481
460
- function fgh! (θ)
482
+ function fgh! (θ, p = p )
461
483
Enzyme. make_zero! (G_fgh)
462
484
Enzyme. make_zero! (H_fgh)
463
485
Enzyme. make_zero! .(vdbθ_fgh)
@@ -476,20 +498,20 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
476
498
return G_fgh, H_fgh
477
499
end
478
500
elseif fgh == true
479
- fgh! = (θ) -> f. fgh (θ, p)
501
+ fgh! = (θ, p = p ) -> f. fgh (θ, p)
480
502
else
481
503
fgh! = nothing
482
504
end
483
505
484
506
if hv == true && f. hv === nothing
485
- function hv! (θ, v)
507
+ function hv! (θ, v, p = p )
486
508
return Enzyme. autodiff (
487
509
Enzyme. Forward, hv_f2_alloc, DuplicatedNoNeed, Duplicated (θ, v),
488
510
Const (_f), Const (f. f), Const (p)
489
511
)[1 ]
490
512
end
491
513
elseif hv == true
492
- hv! = (θ, v) -> f. hv (θ, v, p)
514
+ hv! = (θ, v, p = p ) -> f. hv (θ, v, p)
493
515
else
494
516
hv! = f. hv
495
517
end
@@ -604,7 +626,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
604
626
lag_vdbθ = Tuple ((copy (r) for r in eachrow (f. hess_prototype)))
605
627
end
606
628
607
- function lag_h! (θ, σ, μ)
629
+ function lag_h! (θ, σ, μ, p = p )
608
630
Enzyme. make_zero! (lag_bθ)
609
631
Enzyme. make_zero! .(lag_vdbθ)
610
632
@@ -630,7 +652,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
630
652
return res
631
653
end
632
654
elseif lag_h == true && cons != = nothing
633
- lag_h! = (θ, σ, μ) -> f. lag_h (θ, σ, μ, p)
655
+ lag_h! = (θ, σ, μ, p = p ) -> f. lag_h (θ, σ, μ, p)
634
656
else
635
657
lag_h! = nothing
636
658
end
0 commit comments