@@ -15,7 +15,12 @@ function generate_sparse_adtype(adtype)
15
15
adtype. coloring_algorithm isa ADTypes. NoColoringAlgorithm
16
16
adtype = AutoSparse (adtype. dense_ad; sparsity_detector = TracerSparsityDetector (),
17
17
coloring_algorithm = GreedyColoringAlgorithm ())
18
- if ! (adtype. dense_ad isa SciMLBase. NoAD) &&
18
+ if adtype. dense_ad isa ADTypes. AutoFiniteDiff
19
+ soadtype = AutoSparse (
20
+ DifferentiationInterface. SecondOrder (adtype. dense_ad, adtype. dense_ad),
21
+ sparsity_detector = TracerSparsityDetector (),
22
+ coloring_algorithm = GreedyColoringAlgorithm ())
23
+ elseif ! (adtype. dense_ad isa SciMLBase. NoAD) &&
19
24
ADTypes. mode (adtype. dense_ad) isa ADTypes. ForwardMode
20
25
soadtype = AutoSparse (
21
26
DifferentiationInterface. SecondOrder (adtype. dense_ad, AutoReverseDiff ()),
@@ -32,7 +37,12 @@ function generate_sparse_adtype(adtype)
32
37
! (adtype. coloring_algorithm isa ADTypes. NoColoringAlgorithm)
33
38
adtype = AutoSparse (adtype. dense_ad; sparsity_detector = TracerSparsityDetector (),
34
39
coloring_algorithm = adtype. coloring_algorithm)
35
- if ! (adtype. dense_ad isa SciMLBase. NoAD) &&
40
+ if adtype. dense_ad isa ADTypes. AutoFiniteDiff
41
+ soadtype = AutoSparse (
42
+ DifferentiationInterface. SecondOrder (adtype. dense_ad, adtype. dense_ad),
43
+ sparsity_detector = TracerSparsityDetector (),
44
+ coloring_algorithm = adtype. coloring_algorithm)
45
+ elseif ! (adtype. dense_ad isa SciMLBase. NoAD) &&
36
46
ADTypes. mode (adtype. dense_ad) isa ADTypes. ForwardMode
37
47
soadtype = AutoSparse (
38
48
DifferentiationInterface. SecondOrder (adtype. dense_ad, AutoReverseDiff ()),
@@ -49,7 +59,12 @@ function generate_sparse_adtype(adtype)
49
59
adtype. coloring_algorithm isa ADTypes. NoColoringAlgorithm
50
60
adtype = AutoSparse (adtype. dense_ad; sparsity_detector = adtype. sparsity_detector,
51
61
coloring_algorithm = GreedyColoringAlgorithm ())
52
- if ! (adtype. dense_ad isa SciMLBase. NoAD) &&
62
+ if adtype. dense_ad isa ADTypes. AutoFiniteDiff
63
+ soadtype = AutoSparse (
64
+ DifferentiationInterface. SecondOrder (adtype. dense_ad, adtype. dense_ad),
65
+ sparsity_detector = adtype. sparsity_detector,
66
+ coloring_algorithm = GreedyColoringAlgorithm ())
67
+ elseif ! (adtype. dense_ad isa SciMLBase. NoAD) &&
53
68
ADTypes. mode (adtype. dense_ad) isa ADTypes. ForwardMode
54
69
soadtype = AutoSparse (
55
70
DifferentiationInterface. SecondOrder (adtype. dense_ad, AutoReverseDiff ()),
@@ -63,7 +78,12 @@ function generate_sparse_adtype(adtype)
63
78
coloring_algorithm = GreedyColoringAlgorithm ())
64
79
end
65
80
else
66
- if ! (adtype. dense_ad isa SciMLBase. NoAD) &&
81
+ if adtype. dense_ad isa ADTypes. AutoFiniteDiff
82
+ soadtype = AutoSparse (
83
+ DifferentiationInterface. SecondOrder (adtype. dense_ad, adtype. dense_ad),
84
+ sparsity_detector = adtype. sparsity_detector,
85
+ coloring_algorithm = adtype. coloring_algorithm)
86
+ elseif ! (adtype. dense_ad isa SciMLBase. NoAD) &&
67
87
ADTypes. mode (adtype. dense_ad) isa ADTypes. ForwardMode
68
88
soadtype = AutoSparse (
69
89
DifferentiationInterface. SecondOrder (adtype. dense_ad, AutoReverseDiff ()),
@@ -80,7 +100,7 @@ function generate_sparse_adtype(adtype)
80
100
return adtype, soadtype
81
101
end
82
102
83
- function OptimizationBase . instantiate_function (
103
+ function instantiate_function (
84
104
f:: OptimizationFunction{true} , x, adtype:: ADTypes.AutoSparse{<:AbstractADType} ,
85
105
p = SciMLBase. NullParameters (), num_cons = 0 )
86
106
_f = (θ, args... ) -> first (f. f (θ, p, args... ))
@@ -108,9 +128,9 @@ function OptimizationBase.instantiate_function(
108
128
end
109
129
110
130
if f. hv === nothing
111
- extras_hvp = prepare_hvp (_f, soadtype, x, rand (size (x)))
131
+ extras_hvp = prepare_hvp (_f, soadtype. dense_ad , x, rand (size (x)))
112
132
hv = function (H, θ, v, args... )
113
- hvp! (_f, H, soadtype, θ, v, extras_hvp)
133
+ hvp! (_f, H, soadtype. dense_ad , θ, v, extras_hvp)
114
134
end
115
135
else
116
136
hv = f. hv
@@ -168,7 +188,7 @@ function OptimizationBase.instantiate_function(
168
188
lag_h, f. lag_hess_prototype)
169
189
end
170
190
171
- function OptimizationBase . instantiate_function (
191
+ function instantiate_function (
172
192
f:: OptimizationFunction{true} , cache:: OptimizationBase.ReInitCache ,
173
193
adtype:: ADTypes.AutoSparse{<:AbstractADType} , num_cons = 0 )
174
194
x = cache. u0
@@ -198,9 +218,9 @@ function OptimizationBase.instantiate_function(
198
218
end
199
219
200
220
if f. hv === nothing
201
- extras_hvp = prepare_hvp (_f, soadtype, x, rand (size (x)))
221
+ extras_hvp = prepare_hvp (_f, soadtype. dense_ad , x, rand (size (x)))
202
222
hv = function (H, θ, v, args... )
203
- hvp! (_f, H, soadtype, θ, v, extras_hvp)
223
+ hvp! (_f, H, soadtype. dense_ad , θ, v, extras_hvp)
204
224
end
205
225
else
206
226
hv = f. hv
@@ -258,7 +278,7 @@ function OptimizationBase.instantiate_function(
258
278
lag_h, f. lag_hess_prototype)
259
279
end
260
280
261
- function OptimizationBase . instantiate_function (
281
+ function instantiate_function (
262
282
f:: OptimizationFunction{false} , x, adtype:: ADTypes.AutoSparse{<:AbstractADType} ,
263
283
p = SciMLBase. NullParameters (), num_cons = 0 )
264
284
_f = (θ, args... ) -> first (f. f (θ, p, args... ))
@@ -286,9 +306,9 @@ function OptimizationBase.instantiate_function(
286
306
end
287
307
288
308
if f. hv === nothing
289
- extras_hvp = prepare_hvp (_f, soadtype, x, rand (size (x)))
309
+ extras_hvp = prepare_hvp (_f, soadtype. dense_ad , x, rand (size (x)))
290
310
hv = function (θ, v, args... )
291
- hvp (_f, soadtype, θ, v, extras_hvp)
311
+ hvp (_f, soadtype. dense_ad , θ, v, extras_hvp)
292
312
end
293
313
else
294
314
hv = f. hv
@@ -348,7 +368,7 @@ function OptimizationBase.instantiate_function(
348
368
lag_h, f. lag_hess_prototype)
349
369
end
350
370
351
- function OptimizationBase . instantiate_function (
371
+ function instantiate_function (
352
372
f:: OptimizationFunction{false} , cache:: OptimizationBase.ReInitCache ,
353
373
adtype:: ADTypes.AutoSparse{<:AbstractADType} , num_cons = 0 )
354
374
x = cache. u0
@@ -378,9 +398,9 @@ function OptimizationBase.instantiate_function(
378
398
end
379
399
380
400
if f. hv === nothing
381
- extras_hvp = prepare_hvp (_f, soadtype, x, rand (size (x)))
401
+ extras_hvp = prepare_hvp (_f, soadtype. dense_ad , x, rand (size (x)))
382
402
hv = function (θ, v, args... )
383
- hvp (_f, soadtype, θ, v, extras_hvp)
403
+ hvp (_f, soadtype. dense_ad , θ, v, extras_hvp)
384
404
end
385
405
else
386
406
hv = f. hv
0 commit comments