Skip to content

Commit 4ba2a4c

Browse files
minor updates to get tests passing
1 parent f14086d commit 4ba2a4c

File tree

3 files changed

+21
-19
lines changed

3 files changed

+21
-19
lines changed

ext/OptimizationZygoteExt.jl

+11-11
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ function OptimizationBase.instantiate_function(
220220
if f.lag_h === nothing && cons !== nothing && lag_h == true
221221
lag_extras = prepare_hessian(
222222
lagrangian, soadtype, vcat(x, [one(eltype(x))], ones(eltype(x), num_cons)))
223-
lag_hess_prototype = zeros(Bool, length(x), length(x))
223+
lag_hess_prototype = zeros(Bool, length(x) + num_cons + 1, length(x) + num_cons + 1)
224224

225225
function lag_h!(H::AbstractMatrix, θ, σ, λ)
226226
if σ == zero(eltype(θ))
@@ -232,13 +232,11 @@ function OptimizationBase.instantiate_function(
232232
end
233233
end
234234

235-
function lag_h!(h, θ, σ, λ)
236-
H = eltype(θ).(lag_hess_prototype)
237-
hessian!(x -> lagrangian(x, σ, λ), H, soadtype, θ, lag_extras)
235+
function lag_h!(h::AbstractVector, θ, σ, λ)
236+
H = hessian(lagrangian, soadtype, vcat(θ, [σ], λ), lag_extras)
238237
k = 0
239-
rows, cols, _ = findnz(H)
240-
for (i, j) in zip(rows, cols)
241-
if i <= j
238+
for i in 1:length(θ)
239+
for j in 1:i
242240
k += 1
243241
h[k] = H[i, j]
244242
end
@@ -442,7 +440,7 @@ function OptimizationBase.instantiate_function(
442440
θ = augvars[1:length(x)]
443441
σ = augvars[length(x) + 1]
444442
λ = augvars[(length(x) + 2):end]
445-
return σ * _f(θ) + dot(λ, cons(θ))
443+
return σ * _f(θ) + dot(λ, cons_oop(θ))
446444
end
447445
end
448446

@@ -465,7 +463,8 @@ function OptimizationBase.instantiate_function(
465463
end
466464

467465
if f.cons_vjp === nothing && cons_vjp == true && cons !== nothing
468-
extras_pullback = prepare_pullback(cons_oop, adtype, x)
466+
extras_pullback = prepare_pullback(
467+
cons_oop, adtype.dense_ad, x, ones(eltype(x), num_cons))
469468
function cons_vjp!(J, θ, v)
470469
pullback!(cons_oop, J, adtype.dense_ad, θ, v, extras_pullback)
471470
end
@@ -476,7 +475,8 @@ function OptimizationBase.instantiate_function(
476475
end
477476

478477
if f.cons_jvp === nothing && cons_jvp == true && cons !== nothing
479-
extras_pushforward = prepare_pushforward(cons_oop, adtype, x)
478+
extras_pushforward = prepare_pushforward(
479+
cons_oop, adtype.dense_ad, x, ones(eltype(x), length(x)))
480480
function cons_jvp!(J, θ, v)
481481
pushforward!(cons_oop, J, adtype.dense_ad, θ, v, extras_pushforward)
482482
end
@@ -513,7 +513,7 @@ function OptimizationBase.instantiate_function(
513513
if cons !== nothing && f.lag_h === nothing && lag_h == true
514514
lag_extras = prepare_hessian(
515515
lagrangian, soadtype, vcat(x, [one(eltype(x))], ones(eltype(x), num_cons)))
516-
lag_hess_prototype = lag_extras.coloring_result.S[1:length(θ), 1:length(θ)]
516+
lag_hess_prototype = lag_extras.coloring_result.S[1:length(x), 1:length(x)]
517517
lag_hess_colors = lag_extras.coloring_result.color
518518

519519
function lag_h!(H::AbstractMatrix, θ, σ, λ)

src/OptimizationDIExt.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ function instantiate_function(
229229
if cons !== nothing && lag_h == true && f.lag_h === nothing
230230
lag_extras = prepare_hessian(
231231
lagrangian, soadtype, vcat(x, [one(eltype(x))], ones(eltype(x), num_cons)))
232-
lag_hess_prototype = zeros(Bool, length(x), length(x))
232+
lag_hess_prototype = zeros(Bool, length(x) + num_cons + 1, length(x) + num_cons + 1)
233233

234234
function lag_h!(H::AbstractMatrix, θ, σ, λ)
235235
if σ == zero(eltype(θ))
@@ -507,7 +507,7 @@ function instantiate_function(
507507
if cons !== nothing && lag_h == true && f.lag_h === nothing
508508
lag_extras = prepare_hessian(
509509
lagrangian, soadtype, vcat(x, [one(eltype(x))], ones(eltype(x), num_cons)))
510-
lag_hess_prototype = zeros(Bool, length(x), length(x))
510+
lag_hess_prototype = zeros(Bool, length(x) + num_cons + 1, length(x) + num_cons + 1)
511511

512512
function lag_h!(θ, σ, λ)
513513
if σ == zero(eltype(θ))

src/OptimizationDISparseExt.jl

+8-6
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,8 @@ function instantiate_function(
266266
end
267267

268268
if f.cons_vjp === nothing && cons_vjp == true && cons !== nothing
269-
extras_pullback = prepare_pullback(cons_oop, adtype, x, ones(eltype(x), num_cons))
269+
extras_pullback = prepare_pullback(
270+
cons_oop, adtype.dense_ad, x, ones(eltype(x), num_cons))
270271
function cons_vjp!(J, θ, v)
271272
pullback!(cons_oop, J, adtype.dense_ad, θ, v, extras_pullback)
272273
end
@@ -278,7 +279,7 @@ function instantiate_function(
278279

279280
if f.cons_jvp === nothing && cons_jvp == true && cons !== nothing
280281
extras_pushforward = prepare_pushforward(
281-
cons_oop, adtype, x, ones(eltype(x), length(x)))
282+
cons_oop, adtype.dense_ad, x, ones(eltype(x), length(x)))
282283
function cons_jvp!(J, θ, v)
283284
pushforward!(cons_oop, J, adtype.dense_ad, θ, v, extras_pushforward)
284285
end
@@ -557,9 +558,10 @@ function instantiate_function(
557558
end
558559

559560
if f.cons_vjp === nothing && cons_vjp == true && cons !== nothing
560-
extras_pullback = prepare_pullback(cons, adtype, x, ones(eltype(x), num_cons))
561+
extras_pullback = prepare_pullback(
562+
cons, adtype.dense_ad, x, ones(eltype(x), num_cons))
561563
function cons_vjp!(θ, v)
562-
pullback(cons, adtype, θ, v, extras_pullback)
564+
pullback(cons, adtype.dense_ad, θ, v, extras_pullback)
563565
end
564566
elseif cons_vjp === true && cons !== nothing
565567
cons_vjp! = (θ, v) -> f.cons_vjp(θ, v, p)
@@ -569,9 +571,9 @@ function instantiate_function(
569571

570572
if f.cons_jvp === nothing && cons_jvp == true && cons !== nothing
571573
extras_pushforward = prepare_pushforward(
572-
cons, adtype, x, ones(eltype(x), length(x)))
574+
cons, adtype.dense_ad, x, ones(eltype(x), length(x)))
573575
function cons_jvp!(θ, v)
574-
pushforward(cons, adtype, θ, v, extras_pushforward)
576+
pushforward(cons, adtype.dense_ad, θ, v, extras_pushforward)
575577
end
576578
elseif cons_jvp === true && cons !== nothing
577579
cons_jvp! = (θ, v) -> f.cons_jvp(θ, v, p)

0 commit comments

Comments
 (0)