@@ -210,17 +210,17 @@ These common arguments are:
210
210
- `reltol` (relative tolerance in changes of the objective value)
211
211
"""
212
212
function SciMLBase. solve (prob:: IntegralProblem ,
213
- alg:: SciMLBase.AbstractIntegralAlgorithm ,
214
- args ... ; sensealg = ReCallVJP (ZygoteVJP ()),
213
+ alg:: SciMLBase.AbstractIntegralAlgorithm ;
214
+ sensealg = ReCallVJP (ZygoteVJP ()),
215
215
do_inf_transformation = nothing , kwargs... )
216
216
prob = transformation_if_inf (prob, do_inf_transformation)
217
- __solvebp (prob, alg, sensealg, prob. lb, prob. ub, prob. p, args ... ; kwargs... )
217
+ __solvebp (prob, alg, sensealg, prob. lb, prob. ub, prob. p; kwargs... )
218
218
end
219
219
220
220
# Give a layer to intercept with AD
221
221
__solvebp (args... ; kwargs... ) = __solvebp_call (args... ; kwargs... )
222
222
223
- function __solvebp_call (prob:: IntegralProblem , alg:: QuadGKJL , sensealg, lb, ub, p, args ... ;
223
+ function __solvebp_call (prob:: IntegralProblem , alg:: QuadGKJL , sensealg, lb, ub, p;
224
224
reltol = 1e-8 , abstol = 1e-8 ,
225
225
maxiters = typemax (Int),
226
226
kwargs... )
@@ -237,8 +237,7 @@ function __solvebp_call(prob::IntegralProblem, alg::QuadGKJL, sensealg, lb, ub,
237
237
SciMLBase. build_solution (prob, QuadGKJL (), val, err, retcode = ReturnCode. Success)
238
238
end
239
239
240
- function __solvebp_call (prob:: IntegralProblem , alg:: HCubatureJL , sensealg, lb, ub, p,
241
- args... ;
240
+ function __solvebp_call (prob:: IntegralProblem , alg:: HCubatureJL , sensealg, lb, ub, p;
242
241
reltol = 1e-8 , abstol = 1e-8 ,
243
242
maxiters = typemax (Int),
244
243
kwargs... )
@@ -264,7 +263,7 @@ function __solvebp_call(prob::IntegralProblem, alg::HCubatureJL, sensealg, lb, u
264
263
SciMLBase. build_solution (prob, HCubatureJL (), val, err, retcode = ReturnCode. Success)
265
264
end
266
265
267
- function __solvebp_call (prob:: IntegralProblem , alg:: VEGAS , sensealg, lb, ub, p, args ... ;
266
+ function __solvebp_call (prob:: IntegralProblem , alg:: VEGAS , sensealg, lb, ub, p;
268
267
reltol = 1e-8 , abstol = 1e-8 ,
269
268
maxiters = typemax (Int),
270
269
kwargs... )
@@ -292,9 +291,9 @@ function __solvebp_call(prob::IntegralProblem, alg::VEGAS, sensealg, lb, ub, p,
292
291
SciMLBase. build_solution (prob, alg, val, err, chi = chi, retcode = ReturnCode. Success)
293
292
end
294
293
295
- function ChainRulesCore. rrule (:: typeof (__solvebp), prob, alg, sensealg, lb, ub, p, args ... ;
294
+ function ChainRulesCore. rrule (:: typeof (__solvebp), prob, alg, sensealg, lb, ub, p;
296
295
kwargs... )
297
- out = __solvebp_call (prob, alg, sensealg, lb, ub, p, args ... ; kwargs... )
296
+ out = __solvebp_call (prob, alg, sensealg, lb, ub, p; kwargs... )
298
297
function quadrature_adjoint (Δ)
299
298
y = typeof (Δ) <: Array{<:Number, 0} ? Δ[1 ] : Δ
300
299
if isinplace (prob)
@@ -349,19 +348,18 @@ function ChainRulesCore.rrule(::typeof(__solvebp), prob, alg, sensealg, lb, ub,
349
348
dp_prob = remake (prob, f = dfdp, lb = lb, ub = ub, p = p, nout = length (p))
350
349
351
350
if p isa Number
352
- dp = __solvebp_call (dp_prob, alg, sensealg, lb, ub, p, args ... ; kwargs... )[1 ]
351
+ dp = __solvebp_call (dp_prob, alg, sensealg, lb, ub, p; kwargs... )[1 ]
353
352
else
354
- dp = __solvebp_call (dp_prob, alg, sensealg, lb, ub, p, args ... ; kwargs... ). u
353
+ dp = __solvebp_call (dp_prob, alg, sensealg, lb, ub, p; kwargs... ). u
355
354
end
356
355
357
356
if lb isa Number
358
357
dlb = - _f (lb)
359
358
dub = _f (ub)
360
- return (NoTangent (), NoTangent (), NoTangent (), NoTangent (), dlb, dub, dp,
361
- ntuple (x -> NoTangent (), length (args))... )
359
+ return (NoTangent (), NoTangent (), NoTangent (), NoTangent (), dlb, dub, dp)
362
360
else
363
361
return (NoTangent (), NoTangent (), NoTangent (), NoTangent (), NoTangent (),
364
- NoTangent (), dp, ntuple (x -> NoTangent (), length (args)) ... )
362
+ NoTangent (), dp)
365
363
end
366
364
end
367
365
out, quadrature_adjoint
@@ -376,22 +374,22 @@ end
376
374
377
375
# Direct AD on solvers with QuadGK and HCubature
378
376
function __solvebp (prob, alg:: QuadGKJL , sensealg, lb, ub,
379
- p:: AbstractArray{<:ForwardDiff.Dual{T, V, P}, N} , args ... ;
377
+ p:: AbstractArray{<:ForwardDiff.Dual{T, V, P}, N} ;
380
378
kwargs... ) where {T, V, P, N}
381
- __solvebp_call (prob, alg, sensealg, lb, ub, p, args ... ; kwargs... )
379
+ __solvebp_call (prob, alg, sensealg, lb, ub, p; kwargs... )
382
380
end
383
381
384
382
function __solvebp (prob, alg:: HCubatureJL , sensealg, lb, ub,
385
- p:: AbstractArray{<:ForwardDiff.Dual{T, V, P}, N} , args ... ;
383
+ p:: AbstractArray{<:ForwardDiff.Dual{T, V, P}, N} ;
386
384
kwargs... ) where {T, V, P, N}
387
- __solvebp_call (prob, alg, sensealg, lb, ub, p, args ... ; kwargs... )
385
+ __solvebp_call (prob, alg, sensealg, lb, ub, p; kwargs... )
388
386
end
389
387
390
388
# Manually split for the pushforward
391
389
function __solvebp (prob, alg, sensealg, lb, ub,
392
- p:: AbstractArray{<:ForwardDiff.Dual{T, V, P}, N} , args ... ;
390
+ p:: AbstractArray{<:ForwardDiff.Dual{T, V, P}, N} ;
393
391
kwargs... ) where {T, V, P, N}
394
- primal = __solvebp_call (prob, alg, sensealg, lb, ub, ForwardDiff. value .(p), args ... ;
392
+ primal = __solvebp_call (prob, alg, sensealg, lb, ub, ForwardDiff. value .(p);
395
393
kwargs... )
396
394
397
395
nout = prob. nout * P
@@ -439,7 +437,7 @@ function __solvebp(prob, alg, sensealg, lb, ub,
439
437
440
438
dp_prob = IntegralProblem (dfdp, lb, ub, rawp; nout = nout, batch = prob. batch,
441
439
kwargs... )
442
- dual = __solvebp_call (dp_prob, alg, sensealg, lb, ub, rawp, args ... ; kwargs... )
440
+ dual = __solvebp_call (dp_prob, alg, sensealg, lb, ub, rawp; kwargs... )
443
441
res = similar (p, prob. nout)
444
442
partials = reinterpret (typeof (first (res). partials), dual. u)
445
443
for idx in eachindex (res)
0 commit comments