@@ -196,14 +196,14 @@ function instantiate_function(
196
196
end
197
197
198
198
if hv == true && f. hv === nothing
199
- prep_hvp = prepare_hvp (_f, soadtype. dense_ad, x, zeros (eltype (x), size (x)))
199
+ prep_hvp = prepare_hvp (_f, soadtype. dense_ad, x, ( zeros (eltype (x), size (x)), ))
200
200
function hv! (H, θ, v)
201
- hvp! (_f, H, prep_hvp, soadtype. dense_ad, θ, v )
201
+ only ( hvp! (_f, (H,), prep_hvp, soadtype. dense_ad, θ, (v,)) )
202
202
end
203
203
if p != = SciMLBase. NullParameters () && p != = nothing
204
204
function hv! (H, θ, v, p)
205
205
global _p = p
206
- hvp! (_f, H, prep_hvp, soadtype. dense_ad, θ, v )
206
+ only ( hvp! (_f, (H,), prep_hvp, soadtype. dense_ad, θ, (v,)) )
207
207
end
208
208
end
209
209
elseif hv == true
@@ -253,9 +253,9 @@ function instantiate_function(
253
253
254
254
if f. cons_vjp === nothing && cons_vjp == true && cons != = nothing
255
255
prep_pullback = prepare_pullback (
256
- cons_oop, adtype. dense_ad, x, ones (eltype (x), num_cons))
256
+ cons_oop, adtype. dense_ad, x, ( ones (eltype (x), num_cons), ))
257
257
function cons_vjp! (J, θ, v)
258
- pullback! (cons_oop, J, prep_pullback, adtype. dense_ad, θ, v )
258
+ only ( pullback! (cons_oop, (J,), prep_pullback, adtype. dense_ad, θ, (v,)) )
259
259
end
260
260
elseif cons_vjp === true && cons != = nothing
261
261
cons_vjp! = (J, θ, v) -> f. cons_vjp (J, θ, v, p)
@@ -265,9 +265,9 @@ function instantiate_function(
265
265
266
266
if f. cons_jvp === nothing && cons_jvp == true && cons != = nothing
267
267
prep_pushforward = prepare_pushforward (
268
- cons_oop, adtype. dense_ad, x, ones (eltype (x), length (x)))
268
+ cons_oop, adtype. dense_ad, x, ( ones (eltype (x), length (x)), ))
269
269
function cons_jvp! (J, θ, v)
270
- pushforward! (cons_oop, J, prep_pushforward, adtype. dense_ad, θ, v )
270
+ only ( pushforward! (cons_oop, (J,), prep_pushforward, adtype. dense_ad, θ, (v,)) )
271
271
end
272
272
elseif cons_jvp === true && cons != = nothing
273
273
cons_jvp! = (J, θ, v) -> f. cons_jvp (J, θ, v, p)
@@ -480,15 +480,15 @@ function instantiate_function(
480
480
end
481
481
482
482
if hv == true && f. hv === nothing
483
- prep_hvp = prepare_hvp (_f, soadtype. dense_ad, x, zeros (eltype (x), size (x)))
483
+ prep_hvp = prepare_hvp (_f, soadtype. dense_ad, x, ( zeros (eltype (x), size (x)), ))
484
484
function hv! (θ, v)
485
- hvp (_f, prep_hvp, soadtype. dense_ad, θ, v )
485
+ only ( hvp (_f, prep_hvp, soadtype. dense_ad, θ, (v,)) )
486
486
end
487
487
488
488
if p != = SciMLBase. NullParameters () && p != = nothing
489
489
function hv! (θ, v, p)
490
490
global _p = p
491
- hvp (_f, prep_hvp, soadtype. dense_ad, θ, v )
491
+ only ( hvp (_f, prep_hvp, soadtype. dense_ad, θ, (v,)) )
492
492
end
493
493
end
494
494
elseif hv == true
@@ -533,9 +533,9 @@ function instantiate_function(
533
533
534
534
if f. cons_vjp === nothing && cons_vjp == true && cons != = nothing
535
535
prep_pullback = prepare_pullback (
536
- cons, adtype. dense_ad, x, ones (eltype (x), num_cons))
536
+ cons, adtype. dense_ad, x, ( ones (eltype (x), num_cons), ))
537
537
function cons_vjp! (θ, v)
538
- pullback (cons, prep_pullback, adtype. dense_ad, θ, v )
538
+ only ( pullback (cons, prep_pullback, adtype. dense_ad, θ, (v,)) )
539
539
end
540
540
elseif cons_vjp === true && cons != = nothing
541
541
cons_vjp! = (θ, v) -> f. cons_vjp (θ, v, p)
@@ -545,9 +545,9 @@ function instantiate_function(
545
545
546
546
if f. cons_jvp === nothing && cons_jvp == true && cons != = nothing
547
547
prep_pushforward = prepare_pushforward (
548
- cons, adtype. dense_ad, x, ones (eltype (x), length (x)))
548
+ cons, adtype. dense_ad, x, ( ones (eltype (x), length (x)), ))
549
549
function cons_jvp! (θ, v)
550
- pushforward (cons, prep_pushforward, adtype. dense_ad, θ, v )
550
+ only ( pushforward (cons, prep_pushforward, adtype. dense_ad, θ, (v,)) )
551
551
end
552
552
elseif cons_jvp === true && cons != = nothing
553
553
cons_jvp! = (θ, v) -> f. cons_jvp (θ, v, p)
0 commit comments