Skip to content

Commit 2bad384

Browse files
more of the switches, and in tests too
1 parent 8fb865d commit 2bad384

7 files changed

+202
-123
lines changed

ext/OptimizationEnzymeExt.jl

+4-6
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,6 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
9393
g = false, h = false, hv = false, fg = false, fgh = false,
9494
cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,
9595
lag_h = false)
96-
9796
if g == true && f.grad === nothing
9897
function grad(res, θ)
9998
Enzyme.make_zero!(res)
@@ -351,7 +350,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
351350
k += i
352351
end
353352
end
354-
elseif lag_h == true && cons !== nothing
353+
elseif lag_h == true && cons !== nothing
355354
lag_h! = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p)
356355
else
357356
lag_h! = nothing
@@ -384,11 +383,10 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true},
384383
end
385384

386385
function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x,
387-
adtype::AutoEnzyme, p, num_cons = 0;
386+
adtype::AutoEnzyme, p, num_cons = 0;
388387
g = false, h = false, hv = false, fg = false, fgh = false,
389388
cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,
390389
lag_h = false)
391-
392390
if g == true && f.grad === nothing
393391
res = zeros(eltype(x), size(x))
394392
function grad(θ)
@@ -637,10 +635,10 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
637635
lag_h! = nothing
638636
end
639637

640-
return OptimizationFunction{false}(f.f, adtype; grad = grad,
638+
return OptimizationFunction{false}(f.f, adtype; grad = grad,
641639
fg = fg!, fgh = fgh!,
642640
hess = hess, hv = hv!,
643-
cons = cons, cons_j = cons_j!,
641+
cons = cons, cons_j = cons_j!,
644642
cons_jvp = cons_jvp!, cons_vjp = cons_vjp!,
645643
cons_h = cons_h!,
646644
hess_prototype = f.hess_prototype,

ext/OptimizationMTKExt.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ function OptimizationBase.instantiate_function(
169169
num_cons))))
170170
#sys = ModelingToolkit.structural_simplify(sys)
171171
f = OptimizationProblem(sys, cache.u0, cache.p, grad = g, hess = h,
172-
sparse = false, cons_j = cons_j, cons_h = cons_h,
172+
sparse = false, cons_j = cons_j, cons_h = cons_h,
173173
cons_sparse = false).f
174174

175175
grad = (G, θ, args...) -> f.grad(G, θ, cache.p, args...)

0 commit comments

Comments
 (0)