Skip to content

Commit d48ac5e

Browse files
committed
remove args
1 parent ce92e87 commit d48ac5e

File tree

3 files changed

+21
-23
lines changed

3 files changed

+21
-23
lines changed

lib/IntegralsCuba/src/IntegralsCuba.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ struct CubaCuhre <: AbstractCubaAlgorithm end
7979

8080
function Integrals.__solvebp_call(prob::IntegralProblem, alg::AbstractCubaAlgorithm,
8181
sensealg,
82-
lb, ub, p, args...;
82+
lb, ub, p;
8383
reltol = 1e-8, abstol = 1e-8,
8484
maxiters = alg isa CubaSUAVE ? 1000000 : typemax(Int),
8585
kwargs...)

lib/IntegralsCubature/src/IntegralsCubature.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ struct CubatureJLp <: AbstractCubatureJLAlgorithm end
3434

3535
function Integrals.__solvebp_call(prob::IntegralProblem,
3636
alg::AbstractCubatureJLAlgorithm,
37-
sensealg, lb, ub, p, args...;
37+
sensealg, lb, ub, p;
3838
reltol = 1e-8, abstol = 1e-8,
3939
maxiters = typemax(Int),
4040
kwargs...)

src/Integrals.jl

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -210,17 +210,17 @@ These common arguments are:
210210
- `reltol` (relative tolerance in changes of the objective value)
211211
"""
212212
function SciMLBase.solve(prob::IntegralProblem,
213-
alg::SciMLBase.AbstractIntegralAlgorithm,
214-
args...; sensealg = ReCallVJP(ZygoteVJP()),
213+
alg::SciMLBase.AbstractIntegralAlgorithm;
214+
sensealg = ReCallVJP(ZygoteVJP()),
215215
do_inf_transformation = nothing, kwargs...)
216216
prob = transformation_if_inf(prob, do_inf_transformation)
217-
__solvebp(prob, alg, sensealg, prob.lb, prob.ub, prob.p, args...; kwargs...)
217+
__solvebp(prob, alg, sensealg, prob.lb, prob.ub, prob.p; kwargs...)
218218
end
219219

220220
# Give a layer to intercept with AD
221221
__solvebp(args...; kwargs...) = __solvebp_call(args...; kwargs...)
222222

223-
function __solvebp_call(prob::IntegralProblem, alg::QuadGKJL, sensealg, lb, ub, p, args...;
223+
function __solvebp_call(prob::IntegralProblem, alg::QuadGKJL, sensealg, lb, ub, p;
224224
reltol = 1e-8, abstol = 1e-8,
225225
maxiters = typemax(Int),
226226
kwargs...)
@@ -237,8 +237,7 @@ function __solvebp_call(prob::IntegralProblem, alg::QuadGKJL, sensealg, lb, ub,
237237
SciMLBase.build_solution(prob, QuadGKJL(), val, err, retcode = ReturnCode.Success)
238238
end
239239

240-
function __solvebp_call(prob::IntegralProblem, alg::HCubatureJL, sensealg, lb, ub, p,
241-
args...;
240+
function __solvebp_call(prob::IntegralProblem, alg::HCubatureJL, sensealg, lb, ub, p;
242241
reltol = 1e-8, abstol = 1e-8,
243242
maxiters = typemax(Int),
244243
kwargs...)
@@ -264,7 +263,7 @@ function __solvebp_call(prob::IntegralProblem, alg::HCubatureJL, sensealg, lb, u
264263
SciMLBase.build_solution(prob, HCubatureJL(), val, err, retcode = ReturnCode.Success)
265264
end
266265

267-
function __solvebp_call(prob::IntegralProblem, alg::VEGAS, sensealg, lb, ub, p, args...;
266+
function __solvebp_call(prob::IntegralProblem, alg::VEGAS, sensealg, lb, ub, p;
268267
reltol = 1e-8, abstol = 1e-8,
269268
maxiters = typemax(Int),
270269
kwargs...)
@@ -292,9 +291,9 @@ function __solvebp_call(prob::IntegralProblem, alg::VEGAS, sensealg, lb, ub, p,
292291
SciMLBase.build_solution(prob, alg, val, err, chi = chi, retcode = ReturnCode.Success)
293292
end
294293

295-
function ChainRulesCore.rrule(::typeof(__solvebp), prob, alg, sensealg, lb, ub, p, args...;
294+
function ChainRulesCore.rrule(::typeof(__solvebp), prob, alg, sensealg, lb, ub, p;
296295
kwargs...)
297-
out = __solvebp_call(prob, alg, sensealg, lb, ub, p, args...; kwargs...)
296+
out = __solvebp_call(prob, alg, sensealg, lb, ub, p; kwargs...)
298297
function quadrature_adjoint(Δ)
299298
y = typeof(Δ) <: Array{<:Number, 0} ? Δ[1] : Δ
300299
if isinplace(prob)
@@ -349,19 +348,18 @@ function ChainRulesCore.rrule(::typeof(__solvebp), prob, alg, sensealg, lb, ub,
349348
dp_prob = remake(prob, f = dfdp, lb = lb, ub = ub, p = p, nout = length(p))
350349

351350
if p isa Number
352-
dp = __solvebp_call(dp_prob, alg, sensealg, lb, ub, p, args...; kwargs...)[1]
351+
dp = __solvebp_call(dp_prob, alg, sensealg, lb, ub, p; kwargs...)[1]
353352
else
354-
dp = __solvebp_call(dp_prob, alg, sensealg, lb, ub, p, args...; kwargs...).u
353+
dp = __solvebp_call(dp_prob, alg, sensealg, lb, ub, p; kwargs...).u
355354
end
356355

357356
if lb isa Number
358357
dlb = -_f(lb)
359358
dub = _f(ub)
360-
return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), dlb, dub, dp,
361-
ntuple(x -> NoTangent(), length(args))...)
359+
return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), dlb, dub, dp)
362360
else
363361
return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(),
364-
NoTangent(), dp, ntuple(x -> NoTangent(), length(args))...)
362+
NoTangent(), dp)
365363
end
366364
end
367365
out, quadrature_adjoint
@@ -376,22 +374,22 @@ end
376374

377375
# Direct AD on solvers with QuadGK and HCubature
378376
function __solvebp(prob, alg::QuadGKJL, sensealg, lb, ub,
379-
p::AbstractArray{<:ForwardDiff.Dual{T, V, P}, N}, args...;
377+
p::AbstractArray{<:ForwardDiff.Dual{T, V, P}, N};
380378
kwargs...) where {T, V, P, N}
381-
__solvebp_call(prob, alg, sensealg, lb, ub, p, args...; kwargs...)
379+
__solvebp_call(prob, alg, sensealg, lb, ub, p; kwargs...)
382380
end
383381

384382
function __solvebp(prob, alg::HCubatureJL, sensealg, lb, ub,
385-
p::AbstractArray{<:ForwardDiff.Dual{T, V, P}, N}, args...;
383+
p::AbstractArray{<:ForwardDiff.Dual{T, V, P}, N};
386384
kwargs...) where {T, V, P, N}
387-
__solvebp_call(prob, alg, sensealg, lb, ub, p, args...; kwargs...)
385+
__solvebp_call(prob, alg, sensealg, lb, ub, p; kwargs...)
388386
end
389387

390388
# Manually split for the pushforward
391389
function __solvebp(prob, alg, sensealg, lb, ub,
392-
p::AbstractArray{<:ForwardDiff.Dual{T, V, P}, N}, args...;
390+
p::AbstractArray{<:ForwardDiff.Dual{T, V, P}, N};
393391
kwargs...) where {T, V, P, N}
394-
primal = __solvebp_call(prob, alg, sensealg, lb, ub, ForwardDiff.value.(p), args...;
392+
primal = __solvebp_call(prob, alg, sensealg, lb, ub, ForwardDiff.value.(p);
395393
kwargs...)
396394

397395
nout = prob.nout * P
@@ -439,7 +437,7 @@ function __solvebp(prob, alg, sensealg, lb, ub,
439437

440438
dp_prob = IntegralProblem(dfdp, lb, ub, rawp; nout = nout, batch = prob.batch,
441439
kwargs...)
442-
dual = __solvebp_call(dp_prob, alg, sensealg, lb, ub, rawp, args...; kwargs...)
440+
dual = __solvebp_call(dp_prob, alg, sensealg, lb, ub, rawp; kwargs...)
443441
res = similar(p, prob.nout)
444442
partials = reinterpret(typeof(first(res).partials), dual.u)
445443
for idx in eachindex(res)

0 commit comments

Comments
 (0)