Skip to content

Commit 96543c5

Browse files
committed
Update tangents
1 parent 63a3d42 commit 96543c5

File tree

3 files changed

+29
-29
lines changed

3 files changed

+29
-29
lines changed

src/OptimizationDIExt.jl

+14-14
Original file line numberDiff line numberDiff line change
@@ -116,14 +116,14 @@ function instantiate_function(
116116
end
117117

118118
if hv == true && f.hv === nothing
119-
prep_hvp = prepare_hvp(_f, soadtype, x, zeros(eltype(x), size(x)))
119+
prep_hvp = prepare_hvp(_f, soadtype, x, (zeros(eltype(x), size(x)),))
120120
function hv!(H, θ, v)
121-
hvp!(_f, H, prep_hvp, soadtype, θ, v)
121+
only(hvp!(_f, (H,), prep_hvp, soadtype, θ, (v,)))
122122
end
123123
if p !== SciMLBase.NullParameters() && p !== nothing
124124
function hv!(H, θ, v, p)
125125
global _p = p
126-
hvp!(_f, H, soadtype, θ, v)
126+
only(hvp!(_f, (H,), soadtype, θ, (v,)))
127127
end
128128
end
129129
elseif hv == true
@@ -170,9 +170,9 @@ function instantiate_function(
170170
end
171171

172172
if f.cons_vjp === nothing && cons_vjp == true && cons !== nothing
173-
prep_pullback = prepare_pullback(cons_oop, adtype, x, ones(eltype(x), num_cons))
173+
prep_pullback = prepare_pullback(cons_oop, adtype, x, (ones(eltype(x), num_cons),))
174174
function cons_vjp!(J, θ, v)
175-
pullback!(cons_oop, J, prep_pullback, adtype, θ, v)
175+
only(pullback!(cons_oop, (J,), prep_pullback, adtype, θ, (v,)))
176176
end
177177
elseif cons_vjp == true && cons !== nothing
178178
cons_vjp! = (J, θ, v) -> f.cons_vjp(J, θ, v, p)
@@ -182,9 +182,9 @@ function instantiate_function(
182182

183183
if f.cons_jvp === nothing && cons_jvp == true && cons !== nothing
184184
prep_pushforward = prepare_pushforward(
185-
cons_oop, adtype, x, ones(eltype(x), length(x)))
185+
cons_oop, adtype, x, (ones(eltype(x), length(x)),))
186186
function cons_jvp!(J, θ, v)
187-
pushforward!(cons_oop, J, prep_pushforward, adtype, θ, v)
187+
only(pushforward!(cons_oop, (J,), prep_pushforward, adtype, θ, (v,)))
188188
end
189189
elseif cons_jvp == true && cons !== nothing
190190
cons_jvp! = (J, θ, v) -> f.cons_jvp(J, θ, v, p)
@@ -383,14 +383,14 @@ function instantiate_function(
383383
end
384384

385385
if hv == true && f.hv === nothing
386-
prep_hvp = prepare_hvp(_f, soadtype, x, zeros(eltype(x), size(x)))
386+
prep_hvp = prepare_hvp(_f, soadtype, x, (zeros(eltype(x), size(x)),))
387387
function hv!(θ, v)
388-
hvp(_f, prep_hvp, soadtype, θ, v)
388+
only(hvp(_f, prep_hvp, soadtype, θ, (v)))
389389
end
390390
if p !== SciMLBase.NullParameters() && p !== nothing
391391
function hv!(θ, v, p)
392392
global _p = p
393-
hvp(_f, prep_hvp, soadtype, θ, v)
393+
only(vp(_f, prep_hvp, soadtype, θ, (v,)))
394394
end
395395
end
396396
elseif hv == true
@@ -432,9 +432,9 @@ function instantiate_function(
432432
end
433433

434434
if f.cons_vjp === nothing && cons_vjp == true && cons !== nothing
435-
prep_pullback = prepare_pullback(cons, adtype, x, ones(eltype(x), num_cons))
435+
prep_pullback = prepare_pullback(cons, adtype, x, (ones(eltype(x), num_cons),))
436436
function cons_vjp!(θ, v)
437-
return pullback(cons, prep_pullback, adtype, θ, v)
437+
return only(pullback(cons, prep_pullback, adtype, θ, (v,)))
438438
end
439439
elseif cons_vjp == true && cons !== nothing
440440
cons_vjp! = (θ, v) -> f.cons_vjp(θ, v, p)
@@ -444,9 +444,9 @@ function instantiate_function(
444444

445445
if f.cons_jvp === nothing && cons_jvp == true && cons !== nothing
446446
prep_pushforward = prepare_pushforward(
447-
cons, adtype, x, ones(eltype(x), length(x)))
447+
cons, adtype, x, (ones(eltype(x), length(x)),))
448448
function cons_jvp!(θ, v)
449-
return pushforward(cons, prep_pushforward, adtype, θ, v)
449+
return only(pushforward(cons, prep_pushforward, adtype, θ, (v,)))
450450
end
451451
elseif cons_jvp == true && cons !== nothing
452452
cons_jvp! = (θ, v) -> f.cons_jvp(θ, v, p)

src/OptimizationDISparseExt.jl

+14-14
Original file line numberDiff line numberDiff line change
@@ -196,14 +196,14 @@ function instantiate_function(
196196
end
197197

198198
if hv == true && f.hv === nothing
199-
prep_hvp = prepare_hvp(_f, soadtype.dense_ad, x, zeros(eltype(x), size(x)))
199+
prep_hvp = prepare_hvp(_f, soadtype.dense_ad, x, (zeros(eltype(x), size(x)),))
200200
function hv!(H, θ, v)
201-
hvp!(_f, H, prep_hvp, soadtype.dense_ad, θ, v)
201+
only(hvp!(_f, (H,), prep_hvp, soadtype.dense_ad, θ, (v,)))
202202
end
203203
if p !== SciMLBase.NullParameters() && p !== nothing
204204
function hv!(H, θ, v, p)
205205
global _p = p
206-
hvp!(_f, H, prep_hvp, soadtype.dense_ad, θ, v)
206+
only(hvp!(_f, (H,), prep_hvp, soadtype.dense_ad, θ, (v,)))
207207
end
208208
end
209209
elseif hv == true
@@ -253,9 +253,9 @@ function instantiate_function(
253253

254254
if f.cons_vjp === nothing && cons_vjp == true && cons !== nothing
255255
prep_pullback = prepare_pullback(
256-
cons_oop, adtype.dense_ad, x, ones(eltype(x), num_cons))
256+
cons_oop, adtype.dense_ad, x, (ones(eltype(x), num_cons),))
257257
function cons_vjp!(J, θ, v)
258-
pullback!(cons_oop, J, prep_pullback, adtype.dense_ad, θ, v)
258+
only(pullback!(cons_oop, (J,), prep_pullback, adtype.dense_ad, θ, (v,)))
259259
end
260260
elseif cons_vjp === true && cons !== nothing
261261
cons_vjp! = (J, θ, v) -> f.cons_vjp(J, θ, v, p)
@@ -265,9 +265,9 @@ function instantiate_function(
265265

266266
if f.cons_jvp === nothing && cons_jvp == true && cons !== nothing
267267
prep_pushforward = prepare_pushforward(
268-
cons_oop, adtype.dense_ad, x, ones(eltype(x), length(x)))
268+
cons_oop, adtype.dense_ad, x, (ones(eltype(x), length(x)),))
269269
function cons_jvp!(J, θ, v)
270-
pushforward!(cons_oop, J, prep_pushforward, adtype.dense_ad, θ, v)
270+
only(pushforward!(cons_oop, (J,), prep_pushforward, adtype.dense_ad, θ, (v,)))
271271
end
272272
elseif cons_jvp === true && cons !== nothing
273273
cons_jvp! = (J, θ, v) -> f.cons_jvp(J, θ, v, p)
@@ -480,15 +480,15 @@ function instantiate_function(
480480
end
481481

482482
if hv == true && f.hv === nothing
483-
prep_hvp = prepare_hvp(_f, soadtype.dense_ad, x, zeros(eltype(x), size(x)))
483+
prep_hvp = prepare_hvp(_f, soadtype.dense_ad, x, (zeros(eltype(x), size(x)),))
484484
function hv!(θ, v)
485-
hvp(_f, prep_hvp, soadtype.dense_ad, θ, v)
485+
only(hvp(_f, prep_hvp, soadtype.dense_ad, θ, (v,)))
486486
end
487487

488488
if p !== SciMLBase.NullParameters() && p !== nothing
489489
function hv!(θ, v, p)
490490
global _p = p
491-
hvp(_f, prep_hvp, soadtype.dense_ad, θ, v)
491+
only(hvp(_f, prep_hvp, soadtype.dense_ad, θ, (v,)))
492492
end
493493
end
494494
elseif hv == true
@@ -533,9 +533,9 @@ function instantiate_function(
533533

534534
if f.cons_vjp === nothing && cons_vjp == true && cons !== nothing
535535
prep_pullback = prepare_pullback(
536-
cons, adtype.dense_ad, x, ones(eltype(x), num_cons))
536+
cons, adtype.dense_ad, x, (ones(eltype(x), num_cons),))
537537
function cons_vjp!(θ, v)
538-
pullback(cons, prep_pullback, adtype.dense_ad, θ, v)
538+
only(pullback(cons, prep_pullback, adtype.dense_ad, θ, (v,)))
539539
end
540540
elseif cons_vjp === true && cons !== nothing
541541
cons_vjp! = (θ, v) -> f.cons_vjp(θ, v, p)
@@ -545,9 +545,9 @@ function instantiate_function(
545545

546546
if f.cons_jvp === nothing && cons_jvp == true && cons !== nothing
547547
prep_pushforward = prepare_pushforward(
548-
cons, adtype.dense_ad, x, ones(eltype(x), length(x)))
548+
cons, adtype.dense_ad, x, (ones(eltype(x), length(x)),))
549549
function cons_jvp!(θ, v)
550-
pushforward(cons, prep_pushforward, adtype.dense_ad, θ, v)
550+
only(pushforward(cons, prep_pushforward, adtype.dense_ad, θ, (v,)))
551551
end
552552
elseif cons_jvp === true && cons !== nothing
553553
cons_jvp! = (θ, v) -> f.cons_jvp(θ, v, p)

test/Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3333
[compat]
3434
Aqua = "0.8"
3535
ComponentArrays = ">= 0.13.9"
36-
DifferentiationInterface = "0.5.2"
36+
DifferentiationInterface = "0.6.1"
3737
DiffEqFlux = ">= 2"
3838
Flux = "0.13, 0.14"
3939
IterTools = ">= 1.3.0"

0 commit comments

Comments
 (0)