Commit c1a5e1f

fix sparse adtype passed to hvp
1 parent 35bd067

4 files changed: +98 -21 lines


src/OptimizationBase.jl (+1 -1)
@@ -31,9 +31,9 @@ Base.length(::NullData) = 0
 
 include("adtypes.jl")
 include("cache.jl")
-include("function.jl")
 include("OptimizationDIExt.jl")
 include("OptimizationDISparseExt.jl")
+include("function.jl")
 
 export solve, OptimizationCache, DEFAULT_CALLBACK, DEFAULT_DATA
 

src/OptimizationDIExt.jl (+4 -4)
@@ -9,7 +9,7 @@ import DifferentiationInterface: prepare_gradient, prepare_hessian, prepare_hvp,
     hvp, jacobian
 using ADTypes, SciMLBase
 
-function OptimizationBase.instantiate_function(
+function instantiate_function(
         f::OptimizationFunction{true}, x, adtype::ADTypes.AbstractADType,
         p = SciMLBase.NullParameters(), num_cons = 0)
     _f = (θ, args...) -> first(f.f(θ, p, args...))
@@ -103,7 +103,7 @@ function OptimizationBase.instantiate_function(
         lag_h, f.lag_hess_prototype)
 end
 
-function OptimizationBase.instantiate_function(
+function instantiate_function(
         f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache,
         adtype::ADTypes.AbstractADType, num_cons = 0)
     x = cache.u0
@@ -199,7 +199,7 @@ function OptimizationBase.instantiate_function(
         lag_h, f.lag_hess_prototype)
 end
 
-function OptimizationBase.instantiate_function(
+function instantiate_function(
         f::OptimizationFunction{false}, x, adtype::ADTypes.AbstractADType,
         p = SciMLBase.NullParameters(), num_cons = 0)
     _f = (θ, args...) -> first(f.f(θ, p, args...))
@@ -295,7 +295,7 @@ function OptimizationBase.instantiate_function(
         lag_h, f.lag_hess_prototype)
 end
 
-function OptimizationBase.instantiate_function(
+function instantiate_function(
         f::OptimizationFunction{false}, cache::OptimizationBase.ReInitCache,
         adtype::ADTypes.AbstractADType, num_cons = 0)
     x = cache.u0

src/OptimizationDISparseExt.jl (+36 -16)
@@ -15,7 +15,12 @@ function generate_sparse_adtype(adtype)
        adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm
         adtype = AutoSparse(adtype.dense_ad; sparsity_detector = TracerSparsityDetector(),
             coloring_algorithm = GreedyColoringAlgorithm())
-        if !(adtype.dense_ad isa SciMLBase.NoAD) &&
+        if adtype.dense_ad isa ADTypes.AutoFiniteDiff
+            soadtype = AutoSparse(
+                DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad),
+                sparsity_detector = TracerSparsityDetector(),
+                coloring_algorithm = GreedyColoringAlgorithm())
+        elseif !(adtype.dense_ad isa SciMLBase.NoAD) &&
                ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode
             soadtype = AutoSparse(
                 DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()),
@@ -32,7 +37,12 @@ function generate_sparse_adtype(adtype)
            !(adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm)
         adtype = AutoSparse(adtype.dense_ad; sparsity_detector = TracerSparsityDetector(),
             coloring_algorithm = adtype.coloring_algorithm)
-        if !(adtype.dense_ad isa SciMLBase.NoAD) &&
+        if adtype.dense_ad isa ADTypes.AutoFiniteDiff
+            soadtype = AutoSparse(
+                DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad),
+                sparsity_detector = TracerSparsityDetector(),
+                coloring_algorithm = adtype.coloring_algorithm)
+        elseif !(adtype.dense_ad isa SciMLBase.NoAD) &&
                ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode
             soadtype = AutoSparse(
                 DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()),
@@ -49,7 +59,12 @@ function generate_sparse_adtype(adtype)
            adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm
         adtype = AutoSparse(adtype.dense_ad; sparsity_detector = adtype.sparsity_detector,
             coloring_algorithm = GreedyColoringAlgorithm())
-        if !(adtype.dense_ad isa SciMLBase.NoAD) &&
+        if adtype.dense_ad isa ADTypes.AutoFiniteDiff
+            soadtype = AutoSparse(
+                DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad),
+                sparsity_detector = adtype.sparsity_detector,
+                coloring_algorithm = GreedyColoringAlgorithm())
+        elseif !(adtype.dense_ad isa SciMLBase.NoAD) &&
                ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode
             soadtype = AutoSparse(
                 DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()),
@@ -63,7 +78,12 @@ function generate_sparse_adtype(adtype)
                 coloring_algorithm = GreedyColoringAlgorithm())
         end
     else
-        if !(adtype.dense_ad isa SciMLBase.NoAD) &&
+        if adtype.dense_ad isa ADTypes.AutoFiniteDiff
+            soadtype = AutoSparse(
+                DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad),
+                sparsity_detector = adtype.sparsity_detector,
+                coloring_algorithm = adtype.coloring_algorithm)
+        elseif !(adtype.dense_ad isa SciMLBase.NoAD) &&
                ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode
             soadtype = AutoSparse(
                 DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()),
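
Note on the new branches above: `ADTypes.mode(AutoFiniteDiff())` reports forward mode, so before this commit a sparse finite-difference adtype fell into the `ForwardMode` branch and was paired with `AutoReverseDiff()` as the second-order backend; the added check pairs finite differences with itself instead. A minimal sketch of the check and of the object the new branch builds (assumed package set; it mirrors the diff rather than a tested snippet):

using ADTypes, DifferentiationInterface
using SparseConnectivityTracer: TracerSparsityDetector
using SparseMatrixColorings: GreedyColoringAlgorithm

dense_ad = AutoFiniteDiff()
ADTypes.mode(dense_ad) isa ADTypes.ForwardMode   # true, hence the explicit check above

# Second-order backend the AutoFiniteDiff branch constructs: finite differences
# over finite differences, wrapped for sparse Hessian detection and coloring.
soadtype = AutoSparse(
    DifferentiationInterface.SecondOrder(dense_ad, dense_ad),
    sparsity_detector = TracerSparsityDetector(),
    coloring_algorithm = GreedyColoringAlgorithm())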
@@ -80,7 +100,7 @@ function generate_sparse_adtype(adtype)
     return adtype, soadtype
 end
 
-function OptimizationBase.instantiate_function(
+function instantiate_function(
         f::OptimizationFunction{true}, x, adtype::ADTypes.AutoSparse{<:AbstractADType},
         p = SciMLBase.NullParameters(), num_cons = 0)
     _f = (θ, args...) -> first(f.f(θ, p, args...))
@@ -108,9 +128,9 @@ function OptimizationBase.instantiate_function(
     end
 
     if f.hv === nothing
-        extras_hvp = prepare_hvp(_f, soadtype, x, rand(size(x)))
+        extras_hvp = prepare_hvp(_f, soadtype.dense_ad, x, rand(size(x)))
         hv = function (H, θ, v, args...)
-            hvp!(_f, H, soadtype, θ, v, extras_hvp)
+            hvp!(_f, H, soadtype.dense_ad, θ, v, extras_hvp)
         end
     else
         hv = f.hv
@@ -168,7 +188,7 @@ function OptimizationBase.instantiate_function(
         lag_h, f.lag_hess_prototype)
 end
 
-function OptimizationBase.instantiate_function(
+function instantiate_function(
         f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache,
         adtype::ADTypes.AutoSparse{<:AbstractADType}, num_cons = 0)
     x = cache.u0
@@ -198,9 +218,9 @@ function OptimizationBase.instantiate_function(
     end
 
     if f.hv === nothing
-        extras_hvp = prepare_hvp(_f, soadtype, x, rand(size(x)))
+        extras_hvp = prepare_hvp(_f, soadtype.dense_ad, x, rand(size(x)))
         hv = function (H, θ, v, args...)
-            hvp!(_f, H, soadtype, θ, v, extras_hvp)
+            hvp!(_f, H, soadtype.dense_ad, θ, v, extras_hvp)
         end
     else
         hv = f.hv
@@ -258,7 +278,7 @@ function OptimizationBase.instantiate_function(
         lag_h, f.lag_hess_prototype)
 end
 
-function OptimizationBase.instantiate_function(
+function instantiate_function(
         f::OptimizationFunction{false}, x, adtype::ADTypes.AutoSparse{<:AbstractADType},
         p = SciMLBase.NullParameters(), num_cons = 0)
     _f = (θ, args...) -> first(f.f(θ, p, args...))
@@ -286,9 +306,9 @@ function OptimizationBase.instantiate_function(
     end
 
     if f.hv === nothing
-        extras_hvp = prepare_hvp(_f, soadtype, x, rand(size(x)))
+        extras_hvp = prepare_hvp(_f, soadtype.dense_ad, x, rand(size(x)))
         hv = function (θ, v, args...)
-            hvp(_f, soadtype, θ, v, extras_hvp)
+            hvp(_f, soadtype.dense_ad, θ, v, extras_hvp)
         end
     else
         hv = f.hv
@@ -348,7 +368,7 @@ function OptimizationBase.instantiate_function(
         lag_h, f.lag_hess_prototype)
 end
 
-function OptimizationBase.instantiate_function(
+function instantiate_function(
         f::OptimizationFunction{false}, cache::OptimizationBase.ReInitCache,
         adtype::ADTypes.AutoSparse{<:AbstractADType}, num_cons = 0)
     x = cache.u0
@@ -378,9 +398,9 @@ function OptimizationBase.instantiate_function(
     end
 
     if f.hv === nothing
-        extras_hvp = prepare_hvp(_f, soadtype, x, rand(size(x)))
+        extras_hvp = prepare_hvp(_f, soadtype.dense_ad, x, rand(size(x)))
         hv = function (θ, v, args...)
-            hvp(_f, soadtype, θ, v, extras_hvp)
+            hvp(_f, soadtype.dense_ad, θ, v, extras_hvp)
        end
    else
        hv = f.hv
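
The `hv` hunks are the fix named in the commit message: `prepare_hvp` and `hvp!`/`hvp` now receive the dense second-order backend stored in `soadtype.dense_ad` instead of the `AutoSparse` wrapper, since a Hessian-vector product is a dense operation that does not go through sparsity detection or coloring. A rough sketch of the two objects involved for a forward-mode dense backend (illustrative only, same assumed packages as above):

using ADTypes, DifferentiationInterface
using SparseConnectivityTracer: TracerSparsityDetector
using SparseMatrixColorings: GreedyColoringAlgorithm

# Roughly what generate_sparse_adtype produces for AutoSparse(AutoForwardDiff()):
soadtype = AutoSparse(
    DifferentiationInterface.SecondOrder(AutoForwardDiff(), AutoReverseDiff()),
    sparsity_detector = TracerSparsityDetector(),
    coloring_algorithm = GreedyColoringAlgorithm())

soadtype.dense_ad   # the dense SecondOrder backend
# Before the fix, the AutoSparse wrapper itself was handed to prepare_hvp/hvp!;
# after the fix, only this dense SecondOrder backend is passed.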

src/function.jl (+57)
@@ -43,6 +43,63 @@ function that is not defined, an error is thrown.
 For more information on the use of automatic differentiation, see the
 documentation of the `AbstractADType` types.
 """
+function instantiate_function(f::OptimizationFunction{true}, x, ::SciMLBase.NoAD,
+        p, num_cons = 0)
+    grad = f.grad === nothing ? nothing : (G, x, args...) -> f.grad(G, x, p, args...)
+    hess = f.hess === nothing ? nothing : (H, x, args...) -> f.hess(H, x, p, args...)
+    hv = f.hv === nothing ? nothing : (H, x, v, args...) -> f.hv(H, x, v, p, args...)
+    cons = f.cons === nothing ? nothing : (res, x) -> f.cons(res, x, p)
+    cons_j = f.cons_j === nothing ? nothing : (res, x) -> f.cons_j(res, x, p)
+    cons_h = f.cons_h === nothing ? nothing : (res, x) -> f.cons_h(res, x, p)
+    hess_prototype = f.hess_prototype === nothing ? nothing :
+                     convert.(eltype(x), f.hess_prototype)
+    cons_jac_prototype = f.cons_jac_prototype === nothing ? nothing :
+                         convert.(eltype(x), f.cons_jac_prototype)
+    cons_hess_prototype = f.cons_hess_prototype === nothing ? nothing :
+                          [convert.(eltype(x), f.cons_hess_prototype[i])
+                           for i in 1:num_cons]
+    expr = symbolify(f.expr)
+    cons_expr = symbolify.(f.cons_expr)
+
+    return OptimizationFunction{true}(f.f, SciMLBase.NoAD(); grad = grad, hess = hess,
+        hv = hv,
+        cons = cons, cons_j = cons_j, cons_h = cons_h,
+        hess_prototype = hess_prototype,
+        cons_jac_prototype = cons_jac_prototype,
+        cons_hess_prototype = cons_hess_prototype,
+        expr = expr, cons_expr = cons_expr,
+        sys = f.sys,
+        observed = f.observed)
+end
+
+function instantiate_function(f::OptimizationFunction{true}, cache::ReInitCache, ::SciMLBase.NoAD,
+        num_cons = 0)
+    grad = f.grad === nothing ? nothing : (G, x, args...) -> f.grad(G, x, cache.p, args...)
+    hess = f.hess === nothing ? nothing : (H, x, args...) -> f.hess(H, x, cache.p, args...)
+    hv = f.hv === nothing ? nothing : (H, x, v, args...) -> f.hv(H, x, v, cache.p, args...)
+    cons = f.cons === nothing ? nothing : (res, x) -> f.cons(res, x, cache.p)
+    cons_j = f.cons_j === nothing ? nothing : (res, x) -> f.cons_j(res, x, cache.p)
+    cons_h = f.cons_h === nothing ? nothing : (res, x) -> f.cons_h(res, x, cache.p)
+    hess_prototype = f.hess_prototype === nothing ? nothing :
+                     convert.(eltype(cache.u0), f.hess_prototype)
+    cons_jac_prototype = f.cons_jac_prototype === nothing ? nothing :
+                         convert.(eltype(cache.u0), f.cons_jac_prototype)
+    cons_hess_prototype = f.cons_hess_prototype === nothing ? nothing :
+                          [convert.(eltype(cache.u0), f.cons_hess_prototype[i])
+                           for i in 1:num_cons]
+    expr = symbolify(f.expr)
+    cons_expr = symbolify.(f.cons_expr)
+
+    return OptimizationFunction{true}(f.f, SciMLBase.NoAD(); grad = grad, hess = hess,
+        hv = hv,
+        cons = cons, cons_j = cons_j, cons_h = cons_h,
+        hess_prototype = hess_prototype,
+        cons_jac_prototype = cons_jac_prototype,
+        cons_hess_prototype = cons_hess_prototype,
+        expr = expr, cons_expr = cons_expr,
+        sys = f.sys,
+        observed = f.observed)
+end
 
 function instantiate_function(f::OptimizationFunction, x, adtype::ADTypes.AbstractADType,
         p, num_cons = 0)
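
The methods added to src/function.jl give `SciMLBase.NoAD()` an explicit passthrough in `instantiate_function`: user-supplied derivative callbacks are re-wrapped so the parameters are closed over, and any prototypes are converted to the element type of the initial point. A hedged usage sketch of that path (not taken from the repo's tests; `obj` and `grad!` are made-up names, and `OptimizationFunction` comes from SciMLBase):

using SciMLBase, OptimizationBase

obj(x, p) = p[1] * sum(abs2, x)          # illustrative objective
grad!(G, x, p) = (G .= 2 .* p[1] .* x)   # hand-written in-place gradient

optf = OptimizationFunction{true}(obj, SciMLBase.NoAD(); grad = grad!)
p = [3.0]
f = OptimizationBase.instantiate_function(optf, zeros(2), SciMLBase.NoAD(), p)

G = zeros(2)
f.grad(G, [1.0, 2.0])   # p is already bound inside the wrapper; G == [6.0, 12.0]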
