1
- module OptimizationDIExt
2
-
3
- import OptimizationBase, OptimizationBase. ArrayInterface
1
+ using OptimizationBase
2
+ import OptimizationBase. ArrayInterface
4
3
import OptimizationBase. SciMLBase: OptimizationFunction
5
4
import OptimizationBase. LinearAlgebra: I
6
5
import DifferentiationInterface
@@ -9,21 +8,48 @@ import DifferentiationInterface: prepare_gradient, prepare_hessian, prepare_hvp,
9
8
using ADTypes
10
9
using SparseConnectivityTracer, SparseMatrixColorings
11
10
12
- function OptimizationBase. instantiate_function (f:: OptimizationFunction{true} , x, adtype:: ADTypes.AutoSparse , p = SciMLBase. NullParameters (), num_cons = 0 )
13
- _f = (θ, args... ) -> first (f. f (θ, p, args... ))
14
-
11
"""
    generate_sparse_adtype(adtype)

Normalize an `ADTypes.AutoSparse` backend and build a matching second-order
backend for Hessian computations.

Unset components are replaced with defaults: a missing sparsity detector
becomes `TracerLocalSparsityDetector()` and a missing coloring algorithm
becomes `GreedyColoringAlgorithm()`. The second-order backend pairs the dense
backend with its complementary mode (forward-over-reverse) via
`DifferentiationInterface.SecondOrder`, reusing the same sparsity settings.

Returns the tuple `(adtype, soadtype)`.
"""
function generate_sparse_adtype(adtype)
    # Choose defaults once instead of duplicating them across four branches.
    detector = adtype.sparsity_detector isa ADTypes.NoSparsityDetector ?
               TracerLocalSparsityDetector() : adtype.sparsity_detector
    coloring = adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm ?
               GreedyColoringAlgorithm() : adtype.coloring_algorithm
    adtype = AutoSparse(adtype.dense_ad; sparsity_detector = detector,
        coloring_algorithm = coloring)

    dense = adtype.dense_ad
    if !(dense isa SciMLBase.NoAD) && ADTypes.mode(dense) isa ADTypes.ForwardMode
        # Forward-mode dense backend: reverse mode inside for
        # forward-over-reverse Hessians.
        soadtype = AutoSparse(
            DifferentiationInterface.SecondOrder(dense, AutoReverseDiff());
            sparsity_detector = detector, coloring_algorithm = coloring)
    elseif !(dense isa SciMLBase.NoAD) && ADTypes.mode(dense) isa ADTypes.ReverseMode
        # Reverse-mode dense backend: ForwardDiff outside. NOTE(review): the
        # original wrapped the whole AutoSparse object here instead of the
        # dense backend — made consistent with the forward-mode branch.
        soadtype = AutoSparse(
            DifferentiationInterface.SecondOrder(AutoForwardDiff(), dense);
            sparsity_detector = detector, coloring_algorithm = coloring)
    else
        # Fallback (NoAD or symbolic/unknown mode): reuse the first-order
        # backend so `soadtype` is always defined. The original left it unset
        # on this path, making the `return` below throw an UndefVarError.
        soadtype = adtype
    end
    return adtype, soadtype
end
42
+
43
+
44
+ function OptimizationBase. instantiate_function (f:: OptimizationFunction{true} , x, adtype:: ADTypes.AutoSparse{<:AbstractADType} , p = SciMLBase. NullParameters (), num_cons = 0 )
45
+ _f = (θ, args... ) -> first (f. f (θ, p, args... ))
46
+
47
+ adtype, soadtype = generate_sparse_adtype (adtype)
22
48
23
49
if f. grad === nothing
24
- extras_grad = prepare_gradient (_f, adtype, x)
50
+ extras_grad = prepare_gradient (_f, adtype. dense_ad , x)
25
51
function grad (res, θ)
26
- gradient! (_f, res, adtype, θ, extras_grad)
52
+ gradient! (_f, res, adtype. dense_ad , θ, extras_grad)
27
53
end
28
54
else
29
55
grad = (G, θ, args... ) -> f. grad (G, θ, p, args... )
@@ -34,7 +60,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
34
60
if f. hess === nothing
35
61
extras_hess = prepare_hessian (_f, soadtype, x) # placeholder logic, can be made much better
36
62
function hess (res, θ, args... )
37
- hessian! (_f, res, adtype , θ, extras_hess)
63
+ hessian! (_f, res, soadtype , θ, extras_hess)
38
64
end
39
65
else
40
66
hess = (H, θ, args... ) -> f. hess (H, θ, p, args... )
@@ -81,7 +107,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
81
107
82
108
function cons_h (H, θ)
83
109
for i in 1 : num_cons
84
- hessian! (fncs[i], H[i], adtype , θ, extras_cons_hess[i])
110
+ hessian! (fncs[i], H[i], soadtype , θ, extras_cons_hess[i])
85
111
end
86
112
end
87
113
else
@@ -104,11 +130,12 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
104
130
lag_h, f. lag_hess_prototype)
105
131
end
106
132
107
- function OptimizationBase. instantiate_function (f:: OptimizationFunction{true} , cache:: OptimizationBase.ReInitCache , adtype:: ADTypes.AbstractADType , num_cons = 0 )
133
+ function OptimizationBase. instantiate_function (f:: OptimizationFunction{true} , cache:: OptimizationBase.ReInitCache , adtype:: ADTypes.AutoSparse{<: AbstractADType} , num_cons = 0 )
108
134
x = cache. u0
109
135
p = cache. p
110
136
_f = (θ, args... ) -> first (f. f (θ, p, args... ))
111
- soadtype = DifferentiationInterface. SecondOrder (adtype, adtype)
137
+
138
+ adtype, soadtype = generate_sparse_adtype (adtype)
112
139
113
140
if f. grad === nothing
114
141
extras_grad = prepare_gradient (_f, adtype, x)
@@ -171,7 +198,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, ca
171
198
172
199
function cons_h (H, θ)
173
200
for i in 1 : num_cons
174
- hessian! (fncs[i], H[i], adtype , θ, extras_cons_hess[i])
201
+ hessian! (fncs[i], H[i], soadtype , θ, extras_cons_hess[i])
175
202
end
176
203
end
177
204
else
@@ -195,9 +222,10 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, ca
195
222
end
196
223
197
224
198
- function OptimizationBase. instantiate_function (f:: OptimizationFunction{false} , x, adtype:: ADTypes.AbstractADType , p = SciMLBase. NullParameters (), num_cons = 0 )
225
+ function OptimizationBase. instantiate_function (f:: OptimizationFunction{false} , x, adtype:: ADTypes.AutoSparse{<: AbstractADType} , p = SciMLBase. NullParameters (), num_cons = 0 )
199
226
_f = (θ, args... ) -> first (f. f (θ, p, args... ))
200
- soadtype = DifferentiationInterface. SecondOrder (adtype, adtype)
227
+
228
+ adtype, soadtype = generate_sparse_adtype (adtype)
201
229
202
230
if f. grad === nothing
203
231
extras_grad = prepare_gradient (_f, adtype, x)
@@ -213,7 +241,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
213
241
if f. hess === nothing
214
242
extras_hess = prepare_hessian (_f, soadtype, x) # placeholder logic, can be made much better
215
243
function hess (θ, args... )
216
- hessian (_f, adtype , θ, extras_hess)
244
+ hessian (_f, soadtype , θ, extras_hess)
217
245
end
218
246
else
219
247
hess = (θ, args... ) -> f. hess (θ, p, args... )
@@ -261,7 +289,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
261
289
262
290
function cons_h (θ)
263
291
H = map (1 : num_cons) do i
264
- hessian (fncs[i], adtype , θ, extras_cons_hess[i])
292
+ hessian (fncs[i], soadtype , θ, extras_cons_hess[i])
265
293
end
266
294
return H
267
295
end
@@ -285,11 +313,12 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
285
313
lag_h, f. lag_hess_prototype)
286
314
end
287
315
288
- function OptimizationBase. instantiate_function (f:: OptimizationFunction{false} , cache:: OptimizationBase.ReInitCache , adtype:: ADTypes.AbstractADType , num_cons = 0 )
316
+ function OptimizationBase. instantiate_function (f:: OptimizationFunction{false} , cache:: OptimizationBase.ReInitCache , adtype:: ADTypes.AutoSparse{<: AbstractADType} , num_cons = 0 )
289
317
x = cache. u0
290
318
p = cache. p
291
319
_f = (θ, args... ) -> first (f. f (θ, p, args... ))
292
- soadtype = DifferentiationInterface. SecondOrder (adtype, adtype)
320
+
321
+ adtype, soadtype = generate_sparse_adtype (adtype)
293
322
294
323
if f. grad === nothing
295
324
extras_grad = prepare_gradient (_f, adtype, x)
@@ -353,7 +382,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, c
353
382
354
383
function cons_h (θ)
355
384
H = map (1 : num_cons) do i
356
- hessian (fncs[i], adtype , θ, extras_cons_hess[i])
385
+ hessian (fncs[i], soadtype , θ, extras_cons_hess[i])
357
386
end
358
387
return H
359
388
end
@@ -376,5 +405,3 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, c
376
405
cons_hess_colorvec = conshess_colors,
377
406
lag_h, f. lag_hess_prototype)
378
407
end
379
-
380
- end
0 commit comments