@@ -4,7 +4,8 @@ import OptimizationBase.SciMLBase: OptimizationFunction
4
4
import OptimizationBase. LinearAlgebra: I
5
5
import DifferentiationInterface
6
6
import DifferentiationInterface: prepare_gradient, prepare_hessian, prepare_hvp,
7
- prepare_jacobian,
7
+ prepare_jacobian, value_and_gradient!, value_and_gradient,
8
+ value_derivative_and_second_derivative!, value_derivative_and_second_derivative,
8
9
gradient!, hessian!, hvp!, jacobian!, gradient, hessian,
9
10
hvp, jacobian
10
11
using ADTypes, SciMLBase
@@ -26,8 +27,9 @@ function instantiate_function(
26
27
g = false , h = false , hv = false , fg = false , fgh = false ,
27
28
cons_j = false , cons_vjp = false , cons_jvp = false , cons_h = false ,
28
29
lag_h = false )
30
+ global _p = p
29
31
function _f (θ)
30
- return f (θ, p )[1 ]
32
+ return f (θ, _p )[1 ]
31
33
end
32
34
33
35
adtype, soadtype = generate_adtype (adtype)
@@ -37,19 +39,41 @@ function instantiate_function(
37
39
function grad (res, θ)
38
40
gradient! (_f, res, adtype, θ, extras_grad)
39
41
end
42
+ if p != = SciMLBase. NullParameters () && p != = nothing
43
+ function grad (res, θ, p)
44
+ global _p = p
45
+ gradient! (_f, res, adtype, θ)
46
+ end
47
+ end
40
48
elseif g == true
41
49
grad = (G, θ) -> f. grad (G, θ, p)
50
+ if p != = SciMLBase. NullParameters () && p != = nothing
51
+ grad = (G, θ, p) -> f. grad (G, θ, p)
52
+ end
42
53
else
43
54
grad = nothing
44
55
end
45
56
46
57
if fg == true && f. fg === nothing
58
+ if g == false
59
+ extras_grad = prepare_gradient (_f, adtype, x)
60
+ end
47
61
function fg! (res, θ)
48
62
(y, _) = value_and_gradient! (_f, res, adtype, θ, extras_grad)
49
63
return y
50
64
end
65
+ if p != = SciMLBase. NullParameters () && p != = nothing
66
+ function fg! (res, θ, p)
67
+ global _p = p
68
+ (y, _) = value_and_gradient! (_f, res, adtype, θ)
69
+ return y
70
+ end
71
+ end
51
72
elseif fg == true
52
73
fg! = (G, θ) -> f. fg (G, θ, p)
74
+ if p != = SciMLBase. NullParameters ()
75
+ fg! = (G, θ, p) -> f. fg (G, θ, p)
76
+ end
53
77
else
54
78
fg! = nothing
55
79
end
@@ -196,7 +220,8 @@ function instantiate_function(
196
220
lag_h! = nothing
197
221
end
198
222
199
- return OptimizationFunction {true} (f. f, adtype; grad = grad, hess = hess, hv = hv!,
223
+ return OptimizationFunction {true} (f. f, adtype;
224
+ grad = grad, fg = fg!, hess = hess, hv = hv!, fgh = fgh!,
200
225
cons = cons, cons_j = cons_j!, cons_h = cons_h!,
201
226
cons_vjp = cons_vjp!, cons_jvp = cons_jvp!,
202
227
hess_prototype = hess_sparsity,
@@ -232,8 +257,9 @@ function instantiate_function(
232
257
g = false , h = false , hv = false , fg = false , fgh = false ,
233
258
cons_j = false , cons_vjp = false , cons_jvp = false , cons_h = false ,
234
259
lag_h = false )
260
+ global _p = p
235
261
function _f (θ)
236
- return f (θ, p )[1 ]
262
+ return f (θ, _p )[1 ]
237
263
end
238
264
239
265
adtype, soadtype = generate_adtype (adtype)
@@ -243,19 +269,41 @@ function instantiate_function(
243
269
function grad (θ)
244
270
gradient (_f, adtype, θ, extras_grad)
245
271
end
272
+ if p != = SciMLBase. NullParameters () && p != = nothing
273
+ function grad (θ, p)
274
+ global _p = p
275
+ gradient (_f, adtype, θ)
276
+ end
277
+ end
246
278
elseif g == true
247
279
grad = (θ) -> f. grad (θ, p)
280
+ if p != = SciMLBase. NullParameters () && p != = nothing
281
+ grad = (θ, p) -> f. grad (θ, p)
282
+ end
248
283
else
249
284
grad = nothing
250
285
end
251
286
252
287
if fg == true && f. fg === nothing
288
+ if g == false
289
+ extras_grad = prepare_gradient (_f, adtype, x)
290
+ end
253
291
function fg! (θ)
254
292
(y, res) = value_and_gradient (_f, adtype, θ, extras_grad)
255
293
return y, res
256
294
end
295
+ if p != = SciMLBase. NullParameters () && p != = nothing
296
+ function fg! (θ, p)
297
+ global _p = p
298
+ (y, res) = value_and_gradient (_f, adtype, θ)
299
+ return y, res
300
+ end
301
+ end
257
302
elseif fg == true
258
303
fg! = (θ) -> f. fg (θ, p)
304
+ if p != = SciMLBase. NullParameters () && p != = nothing
305
+ fg! = (θ, p) -> f. fg (θ, p)
306
+ end
259
307
else
260
308
fg! = nothing
261
309
end
@@ -387,7 +435,8 @@ function instantiate_function(
387
435
lag_h! = nothing
388
436
end
389
437
390
- return OptimizationFunction {false} (f. f, adtype; grad = grad, hess = hess, hv = hv!,
438
+ return OptimizationFunction {false} (f. f, adtype;
439
+ grad = grad, fg = fg!, hess = hess, hv = hv!, fgh = fgh!,
391
440
cons = cons, cons_j = cons_j!, cons_h = cons_h!,
392
441
cons_vjp = cons_vjp!, cons_jvp = cons_jvp!,
393
442
hess_prototype = hess_sparsity,
0 commit comments