Skip to content

Commit 5955fb8

Browse files
Merge pull request #52 from scheidan/master
pass kwargs to `hcubature` and `hquadrature`
2 parents 5acb010 + 81afc20 commit 5955fb8

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

src/Quadrature.jl

+9-9
Original file line numberDiff line numberDiff line change
@@ -185,12 +185,12 @@ function __solvebp_call(prob::QuadratureProblem,::HCubatureJL,sensealg,lb,ub,p,a
185185

186186
if lb isa Number
187187
val,err = hquadrature(f, lb, ub;
188-
rtol=reltol, atol=abstol,
189-
maxevals=maxiters, initdiv=1)
188+
rtol=reltol, atol=abstol,
189+
maxevals=maxiters, kwargs...)
190190
else
191191
val,err = hcubature(f, lb, ub;
192192
rtol=reltol, atol=abstol,
193-
maxevals=maxiters, initdiv=1)
193+
maxevals=maxiters, kwargs...)
194194
end
195195
DiffEqBase.build_solution(prob,HCubatureJL(),val,err,retcode = :Success)
196196
end
@@ -490,7 +490,7 @@ ZygoteRules.@adjoint function __solvebp(prob,alg,sensealg,lb,ub,p,args...;kwargs
490490
function quadrature_adjoint(Δ)
491491
y = typeof(Δ) <: Array{<:Number,0} ? Δ[1] : Δ
492492
if isinplace(prob)
493-
dx = zeros(prob.nout)
493+
dx = zeros(prob.nout)
494494
_f = (x) -> prob.f(dx,x,p)
495495
if sensealg.vjp isa ZygoteVJP
496496
dfdp = function (dx,x,p)
@@ -500,8 +500,8 @@ ZygoteRules.@adjoint function __solvebp(prob,alg,sensealg,lb,ub,p,args...;kwargs
500500
copy(_dx)
501501
end
502502

503-
z = zeros(size(x,2))
504-
for idx in 1:size(x,2)
503+
z = zeros(size(x,2))
504+
for idx in 1:size(x,2)
505505
z[1] = 1
506506
dx[:,idx] = back(z)[1]
507507
z[idx]=0
@@ -516,13 +516,13 @@ ZygoteRules.@adjoint function __solvebp(prob,alg,sensealg,lb,ub,p,args...;kwargs
516516
if prob.batch > 0
517517
dfdp = function (x,p)
518518
_,back = Zygote.pullback(p->prob.f(x,p),p)
519-
519+
520520
out = zeros(length(p),size(x,2))
521521
z = zeros(size(x,2))
522522
for idx in 1:size(x,2)
523523
z[idx] = 1
524524
out[:,idx] = back(z)[1]
525-
z[idx]=0
525+
z[idx]=0
526526
end
527527
out
528528
end
@@ -583,7 +583,7 @@ function __solvebp(prob,alg,sensealg,lb,ub,p::AbstractArray{<:ForwardDiff.Dual{T
583583
if isinplace(prob)
584584
dfdp = function (out,x,p)
585585
dualp = reinterpret(ForwardDiff.Dual{T,V,P}, p)
586-
if prob.batch > 0
586+
if prob.batch > 0
587587
dx = similar(dualp, prob.nout, size(x,2))
588588
else
589589
dx = similar(dualp, prob.nout)

0 commit comments

Comments
 (0)