Skip to content

Commit e7df9ea

Browse files
committed
add keywords
1 parent ab751b9 commit e7df9ea

File tree

3 files changed

+127
-50
lines changed

3 files changed

+127
-50
lines changed

lib/IntegralsCuba/src/IntegralsCuba.jl

+65-10
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,14 @@ Importance sampling is used to reduce variance.
2121
publisher={Elsevier}
2222
}
2323
"""
24-
struct CubaVegas <: AbstractCubaAlgorithm end
24+
struct CubaVegas <: AbstractCubaAlgorithm
25+
flags::Int
26+
seed::Int
27+
minevals::Int
28+
nstart::Int
29+
nincrease::Int
30+
gridno::Int
31+
end
2532
"""
2633
CubaSUAVE()
2734
@@ -40,7 +47,14 @@ Importance sampling and subdivision are thus used to reduce variance.
4047
publisher={Elsevier}
4148
}
4249
"""
43-
struct CubaSUAVE <: AbstractCubaAlgorithm end
50+
struct CubaSUAVE{R} <: AbstractCubaAlgorithm where {R <: Real}
51+
flags::Int
52+
seed::Int
53+
minevals::Int
54+
nnew::Int
55+
nmin::Int
56+
flatness::R
57+
end
4458
"""
4559
CubaDivonne()
4660
@@ -58,7 +72,19 @@ Stratified sampling is used to reduce variance.
5872
publisher={ACM New York, NY, USA}
5973
}
6074
"""
61-
struct CubaDivonne <: AbstractCubaAlgorithm end
75+
struct CubaDivonne{R1, R2, R3} <:
76+
AbstractCubaAlgorithm where {R1 <: Real, R2 <: Real, R3 <: Real}
77+
flags::Int
78+
seed::Int
79+
minevals::Int
80+
key1::Int
81+
key2::Int
82+
key3::Int
83+
maxpass::Int
84+
border::R1
85+
maxchisq::R2
86+
mindeviation::R3
87+
end
6288
"""
6389
CubaCuhre()
6490
@@ -75,14 +101,33 @@ Multidimensional h-adaptive integration from Cuba.jl.
75101
publisher={ACM New York, NY, USA}
76102
}
77103
"""
78-
struct CubaCuhre <: AbstractCubaAlgorithm end
104+
struct CubaCuhre <: AbstractCubaAlgorithm
105+
flags::Int
106+
minevals::Int
107+
key::Int
108+
end
109+
110+
function CubaVegas(; flags = 0, seed = 0, minevals = 0, nstart = 1000, nincrease = 500,
111+
gridno = 0)
112+
CubaVegas(flags, seed, minevals, nstart, nincrease, gridno)
113+
end
114+
function CubaSUAVE(; flags = 0, seed = 0, minevals = 0, nnew = 1000, nmin = 2,
115+
flatness = 25.0)
116+
CubaSUAVE(flags, seed, minevals, nnew, nmin, flatness)
117+
end
118+
function CubaDivonne(; flags = 0, seed = 0, minevals = 0,
119+
key1 = 47, key2 = 1, key3 = 1, maxpass = 5, border = 0.0,
120+
maxchisq = 10.0, mindeviation = 0.25)
121+
CubaDivonne(flags, seed, minevals, key1, key2, key3, maxpass, border, maxchisq,
122+
mindeviation)
123+
end
124+
CubaCuhre(; flags = 0, minevals = 0, key = 0) = CubaCuhre(flags, minevals, key)
79125

80126
function Integrals.__solvebp_call(prob::IntegralProblem, alg::AbstractCubaAlgorithm,
81127
sensealg,
82128
lb, ub, p;
83129
reltol = 1e-8, abstol = 1e-8,
84-
maxiters = alg isa CubaSUAVE ? 1000000 : typemax(Int),
85-
kwargs...)
130+
maxiters = alg isa CubaSUAVE ? 1000000 : typemax(Int))
86131
@assert maxiters>=1000 "maxiters for $alg should be larger than 1000"
87132
prob = transformation_if_inf(prob) #intercept for infinite transformation
88133
p = p
@@ -160,19 +205,29 @@ function Integrals.__solvebp_call(prob::IntegralProblem, alg::AbstractCubaAlgori
160205
if alg isa CubaVegas
161206
out = Cuba.vegas(f, ndim, prob.nout; rtol = reltol,
162207
atol = abstol, nvec = nvec,
163-
maxevals = maxiters, kwargs...)
208+
maxevals = maxiters,
209+
flags = alg.flags, seed = alg.seed, minevals = alg.minevals,
210+
nstart = alg.nstart, nincrease = alg.nincrease,
211+
gridno = alg.gridno)
164212
elseif alg isa CubaSUAVE
165213
out = Cuba.suave(f, ndim, prob.nout; rtol = reltol,
166214
atol = abstol, nvec = nvec,
167-
maxevals = maxiters, kwargs...)
215+
maxevals = maxiters,
216+
flags = alg.flags, seed = alg.seed, minevals = alg.minevals,
217+
nnew = alg.nnew, nmin = alg.nmin, flatness = alg.flatness)
168218
elseif alg isa CubaDivonne
169219
out = Cuba.divonne(f, ndim, prob.nout; rtol = reltol,
170220
atol = abstol, nvec = nvec,
171-
maxevals = maxiters, kwargs...)
221+
maxevals = maxiters,
222+
flags = alg.flags, seed = alg.seed, minevals = alg.minevals,
223+
key1 = alg.key1, key2 = alg.key2, key3 = alg.key3,
224+
maxpass = alg.maxpass, border = alg.border,
225+
maxchisq = alg.maxchisq, mindeviation = alg.mindeviation)
172226
elseif alg isa CubaCuhre
173227
out = Cuba.cuhre(f, ndim, prob.nout; rtol = reltol,
174228
atol = abstol, nvec = nvec,
175-
maxevals = maxiters, kwargs...)
229+
maxevals = maxiters,
230+
flags = alg.flags, minevals = alg.minevals, key = alg.key)
176231
end
177232

178233
if isinplace(prob) || prob.batch != 0

lib/IntegralsCubature/src/IntegralsCubature.jl

+32-12
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ abstract type AbstractCubatureJLAlgorithm <: SciMLBase.AbstractIntegralAlgorithm
99
CubatureJLh()
1010
1111
Multidimensional h-adaptive integration from Cubature.jl.
12+
`error_norm` specifies the convergence criterion for vector valued integrands.
13+
Defaults to `Cubature.INDIVIDUAL`, other options are
14+
`Cubature.PAIRED`, `Cubature.L1`, `Cubature.L2`, or `Cubature.LINF`.
1215
## References
1316
@article{genz1980remarks,
1417
title={Remarks on algorithm 006: An adaptive algorithm for numerical integration over an N-dimensional rectangular region},
@@ -21,23 +24,32 @@ Multidimensional h-adaptive integration from Cubature.jl.
2124
publisher={Elsevier}
2225
}
2326
"""
24-
struct CubatureJLh <: AbstractCubatureJLAlgorithm end
27+
struct CubatureJLh <: AbstractCubatureJLAlgorithm
28+
error_norm::Int32
29+
end
30+
CubatureJLh() = CubatureJLh(Cubature.INDIVIDUAL)
31+
2532
"""
2633
CubatureJLp()
2734
2835
Multidimensional p-adaptive integration from Cubature.jl.
2936
This method is based on repeatedly doubling the degree of the cubature rules,
3037
until convergence is achieved.
3138
The used cubature rule is a tensor product of Clenshaw–Curtis quadrature rules.
39+
`error_norm` specifies the convergence criterion for vector valued integrands.
40+
Defaults to `Cubature.INDIVIDUAL`, other options are
41+
`Cubature.PAIRED`, `Cubature.L1`, `Cubature.L2`, or `Cubature.LINF`.
3242
"""
33-
struct CubatureJLp <: AbstractCubatureJLAlgorithm end
43+
struct CubatureJLp <: AbstractCubatureJLAlgorithm
44+
error_norm::Int32
45+
end
46+
CubatureJLp() = CubatureJLp(Cubature.INDIVIDUAL)
3447

3548
function Integrals.__solvebp_call(prob::IntegralProblem,
3649
alg::AbstractCubatureJLAlgorithm,
3750
sensealg, lb, ub, p;
3851
reltol = 1e-8, abstol = 1e-8,
39-
maxiters = typemax(Int),
40-
kwargs...)
52+
maxiters = typemax(Int))
4153
prob = transformation_if_inf(prob) #intercept for infinite transformation
4254
nout = prob.nout
4355
if nout == 1
@@ -130,21 +142,25 @@ function Integrals.__solvebp_call(prob::IntegralProblem,
130142
if alg isa CubatureJLh
131143
val, err = Cubature.hquadrature(nout, f, lb, ub;
132144
reltol = reltol, abstol = abstol,
133-
maxevals = maxiters)
145+
maxevals = maxiters,
146+
error_norm = alg.error_norm)
134147
else
135148
val, err = Cubature.pquadrature(nout, f, lb, ub;
136149
reltol = reltol, abstol = abstol,
137-
maxevals = maxiters)
150+
maxevals = maxiters,
151+
error_norm = alg.error_norm)
138152
end
139153
else
140154
if alg isa CubatureJLh
141155
val, err = Cubature.hcubature(nout, f, lb, ub;
142156
reltol = reltol, abstol = abstol,
143-
maxevals = maxiters)
157+
maxevals = maxiters,
158+
error_norm = alg.error_norm)
144159
else
145160
val, err = Cubature.pcubature(nout, f, lb, ub;
146161
reltol = reltol, abstol = abstol,
147-
maxevals = maxiters)
162+
maxevals = maxiters,
163+
error_norm = alg.error_norm)
148164
end
149165
end
150166
else
@@ -162,21 +178,25 @@ function Integrals.__solvebp_call(prob::IntegralProblem,
162178
if alg isa CubatureJLh
163179
val, err = Cubature.hquadrature_v(nout, f, lb, ub;
164180
reltol = reltol, abstol = abstol,
165-
maxevals = maxiters)
181+
maxevals = maxiters,
182+
error_norm = alg.error_norm)
166183
else
167184
val, err = Cubature.pquadrature_v(nout, f, lb, ub;
168185
reltol = reltol, abstol = abstol,
169-
maxevals = maxiters)
186+
maxevals = maxiters,
187+
error_norm = alg.error_norm)
170188
end
171189
else
172190
if alg isa CubatureJLh
173191
val, err = Cubature.hcubature_v(nout, f, lb, ub;
174192
reltol = reltol, abstol = abstol,
175-
maxevals = maxiters)
193+
maxevals = maxiters,
194+
error_norm = alg.error_norm)
176195
else
177196
val, err = Cubature.pcubature_v(nout, f, lb, ub;
178197
reltol = reltol, abstol = abstol,
179-
maxevals = maxiters)
198+
maxevals = maxiters,
199+
error_norm = alg.error_norm)
180200
end
181201
end
182202
end

src/Integrals.jl

+30-28
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,12 @@ import ChainRulesCore: NoTangent
99
import ZygoteRules
1010

1111
"""
12-
QuadGKJL(; order = 7)
12+
QuadGKJL(; order = 7, norm=norm)
1313
1414
One-dimensional Gauss-Kronrod integration from QuadGK.jl.
15-
This method also takes the optional argument `order`,
16-
which is the order of the integration rule.
15+
This method also takes the optional arguments `order` and `norm`.
16+
Which are the order of the integration rule
17+
and the norm for calculating the error, respectively
1718
## References
1819
@article{laurie1997calculation,
1920
title={Calculation of Gauss-Kronrod quadrature rules},
@@ -25,16 +26,18 @@ which is the order of the integration rule.
2526
year={1997}
2627
}
2728
"""
28-
struct QuadGKJL <: SciMLBase.AbstractIntegralAlgorithm
29+
struct QuadGKJL{F} <: SciMLBase.AbstractIntegralAlgorithm where {F}
2930
order::Int
31+
norm::F
3032
end
3133
"""
32-
HCubatureJL(; initdiv=1)
34+
HCubatureJL(; norm=norm, initdiv=1)
3335
3436
Multidimensional "h-adaptive" integration from HCubature.jl.
35-
This method also takes the optional argument `initdiv`,
36-
which is the intial number of segments
37-
each dimension of the integration domain is divided into.
37+
This method also takes the optional arguments `initdiv` and `norm`.
38+
Which are the intial number of segments
39+
each dimension of the integration domain is divided into,
40+
and the norm for calculating the error, respectively.
3841
## References
3942
@article{genz1980remarks,
4043
title={Remarks on algorithm 006: An adaptive algorithm for numerical integration over an N-dimensional rectangular region},
@@ -47,18 +50,20 @@ each dimension of the integration domain is divided into.
4750
publisher={Elsevier}
4851
}
4952
"""
50-
struct HCubatureJL <: SciMLBase.AbstractIntegralAlgorithm
53+
struct HCubatureJL{F} <: SciMLBase.AbstractIntegralAlgorithm where {F}
5154
initdiv::Int
55+
norm::F
5256
end
5357
"""
54-
VEGAS(; nbins = 100, ncalls = 1000)
58+
VEGAS(; nbins = 100, ncalls = 1000, debug=false)
5559
5660
Multidimensional adaptive Monte Carlo integration from MonteCarloIntegration.jl.
5761
Importance sampling is used to reduce variance.
58-
This method also takes two optional arguments `nbins` and `ncalls`,
62+
This method also takes three optional arguments `nbins`, `ncalls` and `debug`
5963
which are the intial number of bins
60-
each dimension of the integration domain is divided into
61-
and the number of function calls per iteration of the algorithm.
64+
each dimension of the integration domain is divided into,
65+
the number of function calls per iteration of the algorithm,
66+
and whether debug info should be printed, respectively.
6267
## References
6368
@article{lepage1978new,
6469
title={A new algorithm for adaptive multidimensional integration},
@@ -74,10 +79,11 @@ and the number of function calls per iteration of the algorithm.
7479
struct VEGAS <: SciMLBase.AbstractIntegralAlgorithm
7580
nbins::Int
7681
ncalls::Int
82+
debug::Bool
7783
end
78-
QuadGKJL(; order = 7) = QuadGKJL(order)
79-
HCubatureJL(; initdiv = 1) = HCubatureJL(initdiv)
80-
VEGAS(; nbins = 100, ncalls = 1000) = VEGAS(nbins, ncalls)
84+
QuadGKJL(; order = 7, norm = norm) = QuadGKJL(order, norm)
85+
HCubatureJL(; initdiv = 1, norm = norm) = HCubatureJL(initdiv, norm)
86+
VEGAS(; nbins = 100, ncalls = 1000, debug = false) = VEGAS(nbins, ncalls, debug)
8187

8288
abstract type QuadSensitivityAlg end
8389
struct ReCallVJP{V}
@@ -276,8 +282,7 @@ __solvebp(args...; kwargs...) = __solvebp_call(args...; kwargs...)
276282

277283
function __solvebp_call(prob::IntegralProblem, alg::QuadGKJL, sensealg, lb, ub, p;
278284
reltol = 1e-8, abstol = 1e-8,
279-
maxiters = typemax(Int),
280-
kwargs...)
285+
maxiters = typemax(Int))
281286
if isinplace(prob) || lb isa AbstractArray || ub isa AbstractArray
282287
error("QuadGKJL only accepts one-dimensional quadrature problems.")
283288
end
@@ -286,15 +291,13 @@ function __solvebp_call(prob::IntegralProblem, alg::QuadGKJL, sensealg, lb, ub,
286291
p = p
287292
f = x -> prob.f(x, p)
288293
val, err = quadgk(f, lb, ub,
289-
rtol = reltol, atol = abstol, order = alg.order,
290-
kwargs...)
294+
rtol = reltol, atol = abstol, order = alg.order, norm = alg.norm)
291295
SciMLBase.build_solution(prob, QuadGKJL(), val, err, retcode = ReturnCode.Success)
292296
end
293297

294298
function __solvebp_call(prob::IntegralProblem, alg::HCubatureJL, sensealg, lb, ub, p;
295299
reltol = 1e-8, abstol = 1e-8,
296-
maxiters = typemax(Int),
297-
kwargs...)
300+
maxiters = typemax(Int))
298301
p = p
299302

300303
if isinplace(prob)
@@ -308,19 +311,18 @@ function __solvebp_call(prob::IntegralProblem, alg::HCubatureJL, sensealg, lb, u
308311
if lb isa Number
309312
val, err = hquadrature(f, lb, ub;
310313
rtol = reltol, atol = abstol,
311-
maxevals = maxiters, initdiv = alg.initdiv, kwargs...)
314+
maxevals = maxiters, norm = alg.norm, initdiv = alg.initdiv)
312315
else
313316
val, err = hcubature(f, lb, ub;
314317
rtol = reltol, atol = abstol,
315-
maxevals = maxiters, initdiv = alg.initdiv, kwargs...)
318+
maxevals = maxiters, norm = alg.norm, initdiv = alg.initdiv)
316319
end
317320
SciMLBase.build_solution(prob, HCubatureJL(), val, err, retcode = ReturnCode.Success)
318321
end
319322

320323
function __solvebp_call(prob::IntegralProblem, alg::VEGAS, sensealg, lb, ub, p;
321324
reltol = 1e-8, abstol = 1e-8,
322-
maxiters = typemax(Int),
323-
kwargs...)
325+
maxiters = typemax(Int))
324326
p = p
325327
@assert prob.nout == 1
326328
if prob.batch == 0
@@ -340,8 +342,8 @@ function __solvebp_call(prob::IntegralProblem, alg::VEGAS, sensealg, lb, ub, p;
340342
end
341343
ncalls = prob.batch == 0 ? alg.ncalls : prob.batch
342344
val, err, chi = vegas(f, lb, ub, rtol = reltol, atol = abstol,
343-
maxiter = maxiters, nbins = alg.nbins,
344-
ncalls = ncalls, batch = prob.batch != 0, kwargs...)
345+
maxiter = maxiters, nbins = alg.nbins, debug = alg.debug,
346+
ncalls = ncalls, batch = prob.batch != 0)
345347
SciMLBase.build_solution(prob, alg, val, err, chi = chi, retcode = ReturnCode.Success)
346348
end
347349

0 commit comments

Comments
 (0)