Commit c161954

Hack p to serve as data arg and implement stochastic gradient oracle
1 parent bbfdf96 commit c161954

6 files changed  (+230 -16 lines)

ext/OptimizationEnzymeExt.jl  (+2 -2)
@@ -94,7 +94,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
         cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,
         lag_h = false)
     if g == true && f.grad === nothing
-        function grad(res, θ)
+        function grad(res, θ, p = p)
             Enzyme.make_zero!(res)
             Enzyme.autodiff(Enzyme.Reverse,
                 Const(firstapply),
@@ -111,7 +111,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
     end

     if fg == true && f.fg === nothing
-        function fg!(res, θ)
+        function fg!(res, θ, p = p)
             Enzyme.make_zero!(res)
             y = Enzyme.autodiff(Enzyme.ReverseWithPrimal,
                 Const(firstapply),
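
In this Enzyme path the new third argument defaults to the p captured when the function was instantiated, so existing two-argument calls keep working while a caller can supply a fresh data batch per call, which is the stochastic gradient oracle named in the commit message. A minimal usage sketch follows; the least-squares objective, the data, and the assumption that instantiate_function takes the positional arguments (f, x, adtype, p) are illustrative only, not something this commit defines.

using OptimizationBase, SciMLBase, ADTypes
import Enzyme

# Hypothetical least-squares objective; p carries the data as an (X, y) tuple.
mse(θ, p) = sum(abs2, p[1] * θ .- p[2]) / length(p[2])

X, y = randn(100, 3), randn(100)
θ0 = zeros(3)
optf = SciMLBase.OptimizationFunction(mse, ADTypes.AutoEnzyme())
inst = OptimizationBase.instantiate_function(optf, θ0, ADTypes.AutoEnzyme(), (X, y); g = true)

G = zeros(3)
inst.grad(G, θ0)                         # gradient over the full data captured at setup
idx = rand(1:100, 10)
inst.grad(G, θ0, (X[idx, :], y[idx]))    # stochastic gradient over a random minibatch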

ext/OptimizationZygoteExt.jl  (+51 -4)
@@ -7,7 +7,7 @@ import OptimizationBase.SciMLBase: OptimizationFunction
 import OptimizationBase.LinearAlgebra: I, dot
 import DifferentiationInterface
 import DifferentiationInterface: prepare_gradient, prepare_hessian, prepare_hvp,
-                                 prepare_jacobian,
+                                 prepare_jacobian, value_and_gradient!, value_derivative_and_second_derivative!,
                                  gradient!, hessian!, hvp!, jacobian!, gradient, hessian,
                                  hvp, jacobian
 using ADTypes, SciMLBase
@@ -19,8 +19,9 @@ function OptimizationBase.instantiate_function(
         g = false, h = false, hv = false, fg = false, fgh = false,
         cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,
         lag_h = false)
+    global _p = p
     function _f(θ)
-        return f(θ, p)[1]
+        return f(θ, _p)[1]
     end

     adtype, soadtype = OptimizationBase.generate_adtype(adtype)
@@ -30,19 +31,41 @@ function OptimizationBase.instantiate_function(
         function grad(res, θ)
             gradient!(_f, res, adtype, θ, extras_grad)
         end
+        if p !== SciMLBase.NullParameters() && p !== nothing
+            function grad(res, θ, p)
+                global _p = p
+                gradient!(_f, res, adtype, θ)
+            end
+        end
     elseif g == true
         grad = (G, θ) -> f.grad(G, θ, p)
+        if p !== SciMLBase.NullParameters() && p !== nothing
+            grad = (G, θ, p) -> f.grad(G, θ, p)
+        end
     else
         grad = nothing
     end

     if fg == true && f.fg === nothing
+        if g == false
+            extras_grad = prepare_gradient(_f, adtype, x)
+        end
         function fg!(res, θ)
             (y, _) = value_and_gradient!(_f, res, adtype, θ, extras_grad)
             return y
         end
+        if p !== SciMLBase.NullParameters() && p !== nothing
+            function fg!(res, θ, p)
+                global _p = p
+                (y, _) = value_and_gradient!(_f, res, adtype, θ)
+                return y
+            end
+        end
     elseif fg == true
         fg! = (G, θ) -> f.fg(G, θ, p)
+        if p !== SciMLBase.NullParameters() && p !== nothing
+            fg! = (G, θ, p) -> f.fg(G, θ, p)
+        end
     else
         fg! = nothing
     end
@@ -188,7 +211,8 @@ function OptimizationBase.instantiate_function(
         lag_h! = nothing
     end

-    return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv!,
+    return OptimizationFunction{true}(f.f, adtype;
+        grad = grad, fg = fg!, hess = hess, hv = hv!, fgh = fgh!,
         cons = cons, cons_j = cons_j!, cons_h = cons_h!,
         cons_vjp = cons_vjp!, cons_jvp = cons_jvp!,
         hess_prototype = hess_sparsity,
@@ -232,19 +256,41 @@ function OptimizationBase.instantiate_function(
         function grad(res, θ)
             gradient!(_f, res, adtype.dense_ad, θ, extras_grad)
         end
+        if p !== SciMLBase.NullParameters() && p !== nothing
+            function grad(res, θ, p)
+                global p = p
+                gradient!(_f, res, adtype.dense_ad, θ)
+            end
+        end
     elseif g == true
         grad = (G, θ) -> f.grad(G, θ, p)
+        if p !== SciMLBase.NullParameters() && p !== nothing
+            grad = (G, θ, p) -> f.grad(G, θ, p)
+        end
     else
         grad = nothing
     end

     if fg == true && f.fg !== nothing
+        if g == false
+            extras_grad = prepare_gradient(_f, adtype.dense_ad, x)
+        end
         function fg!(res, θ)
             (y, _) = value_and_gradient!(_f, res, adtype.dense_ad, θ, extras_grad)
             return y
         end
+        if p !== SciMLBase.NullParameters() && p !== nothing
+            function fg!(res, θ, p)
+                global p = p
+                (y, _) = value_and_gradient!(_f, res, adtype.dense_ad, θ)
+                return y
+            end
+        end
     elseif fg == true
         fg! = (G, θ) -> f.fg(G, θ, p)
+        if p !== SciMLBase.NullParameters() && p !== nothing
+            fg! = (G, θ, p) -> f.fg(G, θ, p)
+        end
     else
         fg! = nothing
     end
@@ -398,7 +444,8 @@ function OptimizationBase.instantiate_function(
     else
         lag_h! = nothing
     end
-    return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv!,
+    return OptimizationFunction{true}(f.f, adtype;
+        grad = grad, fg = fg!, hess = hess, hv = hv!, fgh = fgh!,
         cons = cons, cons_j = cons_j!, cons_h = cons_h!,
         hess_prototype = hess_sparsity,
         hess_colorvec = hess_colors,
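
The three-argument oracles above work by rebinding a global _p that the closure _f reads, so the already prepared closure sees whatever batch was passed most recently. Below is a standalone sketch of that pattern, assuming DifferentiationInterface with a Zygote backend; the objective, the data, and the names obj and grad_oracle! are made up for illustration and are not package API.

using DifferentiationInterface, ADTypes
import Zygote

obj(θ, p) = sum(abs2, θ .- p)        # objective parameterised by data p
_p = [1.0, 2.0, 3.0]                 # global holding the "current" batch
_f(θ) = obj(θ, _p)                   # the closure that is differentiated; it reads the global

function grad_oracle!(res, θ, p)
    global _p = p                    # swap in the new batch before differentiating
    gradient!(_f, res, AutoZygote(), θ)
    return res
end

res = zeros(3)
grad_oracle!(res, zeros(3), [0.5, 0.5, 0.5])   # 2 .* (θ .- p) at θ = 0 for this batch

Rebinding a non-const global keeps the prepared AD machinery untouched, but it is type-unstable and not thread-safe, which is presumably why the commit message calls it a hack.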

src/OptimizationDIExt.jl  (+54 -5)
@@ -4,7 +4,8 @@ import OptimizationBase.SciMLBase: OptimizationFunction
 import OptimizationBase.LinearAlgebra: I
 import DifferentiationInterface
 import DifferentiationInterface: prepare_gradient, prepare_hessian, prepare_hvp,
-                                 prepare_jacobian,
+                                 prepare_jacobian, value_and_gradient!, value_and_gradient,
+                                 value_derivative_and_second_derivative!, value_derivative_and_second_derivative,
                                  gradient!, hessian!, hvp!, jacobian!, gradient, hessian,
                                  hvp, jacobian
 using ADTypes, SciMLBase
@@ -26,8 +27,9 @@ function instantiate_function(
         g = false, h = false, hv = false, fg = false, fgh = false,
         cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,
         lag_h = false)
+    global _p = p
     function _f(θ)
-        return f(θ, p)[1]
+        return f(θ, _p)[1]
     end

     adtype, soadtype = generate_adtype(adtype)
@@ -37,19 +39,41 @@ function instantiate_function(
         function grad(res, θ)
             gradient!(_f, res, adtype, θ, extras_grad)
         end
+        if p !== SciMLBase.NullParameters() && p !== nothing
+            function grad(res, θ, p)
+                global _p = p
+                gradient!(_f, res, adtype, θ)
+            end
+        end
     elseif g == true
         grad = (G, θ) -> f.grad(G, θ, p)
+        if p !== SciMLBase.NullParameters() && p !== nothing
+            grad = (G, θ, p) -> f.grad(G, θ, p)
+        end
     else
         grad = nothing
     end

     if fg == true && f.fg === nothing
+        if g == false
+            extras_grad = prepare_gradient(_f, adtype, x)
+        end
         function fg!(res, θ)
             (y, _) = value_and_gradient!(_f, res, adtype, θ, extras_grad)
             return y
         end
+        if p !== SciMLBase.NullParameters() && p !== nothing
+            function fg!(res, θ, p)
+                global _p = p
+                (y, _) = value_and_gradient!(_f, res, adtype, θ)
+                return y
+            end
+        end
     elseif fg == true
         fg! = (G, θ) -> f.fg(G, θ, p)
+        if p !== SciMLBase.NullParameters()
+            fg! = (G, θ, p) -> f.fg(G, θ, p)
+        end
     else
         fg! = nothing
     end
@@ -196,7 +220,8 @@ function instantiate_function(
         lag_h! = nothing
     end

-    return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv!,
+    return OptimizationFunction{true}(f.f, adtype;
+        grad = grad, fg = fg!, hess = hess, hv = hv!, fgh = fgh!,
         cons = cons, cons_j = cons_j!, cons_h = cons_h!,
         cons_vjp = cons_vjp!, cons_jvp = cons_jvp!,
         hess_prototype = hess_sparsity,
@@ -232,8 +257,9 @@ function instantiate_function(
         g = false, h = false, hv = false, fg = false, fgh = false,
         cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,
         lag_h = false)
+    global _p = p
     function _f(θ)
-        return f(θ, p)[1]
+        return f(θ, _p)[1]
     end

     adtype, soadtype = generate_adtype(adtype)
@@ -243,19 +269,41 @@ function instantiate_function(
         function grad(θ)
             gradient(_f, adtype, θ, extras_grad)
         end
+        if p !== SciMLBase.NullParameters() && p !== nothing
+            function grad(θ, p)
+                global _p = p
+                gradient(_f, adtype, θ)
+            end
+        end
     elseif g == true
         grad = (θ) -> f.grad(θ, p)
+        if p !== SciMLBase.NullParameters() && p !== nothing
+            grad = (θ, p) -> f.grad(θ, p)
+        end
     else
         grad = nothing
     end

     if fg == true && f.fg === nothing
+        if g == false
+            extras_grad = prepare_gradient(_f, adtype, x)
+        end
         function fg!(θ)
             (y, res) = value_and_gradient(_f, adtype, θ, extras_grad)
             return y, res
         end
+        if p !== SciMLBase.NullParameters() && p !== nothing
+            function fg!(θ, p)
+                global _p = p
+                (y, res) = value_and_gradient(_f, adtype, θ)
+                return y, res
+            end
+        end
     elseif fg == true
         fg! = (θ) -> f.fg(θ, p)
+        if p !== SciMLBase.NullParameters() && p !== nothing
+            fg! = (θ, p) -> f.fg(θ, p)
+        end
     else
         fg! = nothing
     end
@@ -387,7 +435,8 @@ function instantiate_function(
         lag_h! = nothing
     end

-    return OptimizationFunction{false}(f.f, adtype; grad = grad, hess = hess, hv = hv!,
+    return OptimizationFunction{false}(f.f, adtype;
+        grad = grad, fg = fg!, hess = hess, hv = hv!, fgh = fgh!,
         cons = cons, cons_j = cons_j!, cons_h = cons_h!,
         cons_vjp = cons_vjp!, cons_jvp = cons_jvp!,
         hess_prototype = hess_sparsity,
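
The second method in this file is the out-of-place variant, so its new fg!(θ, p) returns the objective value together with the gradient for whatever batch is supplied. A hypothetical minibatch SGD loop driving such an oracle is sketched below; the batching scheme, the step size, and the sgd name are illustrative only.

# fg is assumed to behave like the two-argument fg!(θ, p) built above,
# returning (loss, gradient) for the given batch.
function sgd(fg, θ0, batches; η = 1e-2)
    θ = copy(θ0)
    for batch in batches
        loss, g = fg(θ, batch)    # value and gradient on this batch only
        θ = θ .- η .* g
    end
    return θ
end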
