Skip to content

Commit f7a60d9

Browse files
Merge pull request #147 from DanielVandH/gauss_quad
Add Gauss-Legendre quadrature
2 parents 70444a9 + de0a6ee commit f7a60d9

8 files changed

+211
-12
lines changed

Project.toml

+6-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Integrals"
22
uuid = "de52edbc-65ea-441a-8357-d3a637375a31"
33
authors = ["Chris Rackauckas <[email protected]>"]
4-
version = "3.6.0"
4+
version = "3.7.0"
55

66
[deps]
77
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
@@ -26,10 +26,12 @@ Requires = "1"
2626
SciMLBase = "1.70"
2727
Zygote = "0.4.22, 0.5, 0.6"
2828
julia = "1.6"
29+
FastGaussQuadrature = "0.5"
2930

3031
[extensions]
3132
IntegralsForwardDiffExt = "ForwardDiff"
3233
IntegralsZygoteExt = ["Zygote", "ChainRulesCore"]
34+
IntegralsFastGaussQuadratureExt = "FastGaussQuadrature"
3335

3436
[extras]
3537
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -42,11 +44,13 @@ SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
4244
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
4345
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4446
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
47+
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
4548

4649
[targets]
47-
test = ["SciMLSensitivity", "StaticArrays", "FiniteDiff", "Pkg", "SafeTestsets", "Test", "Distributions", "ForwardDiff", "Zygote", "ChainRulesCore"]
50+
test = ["SciMLSensitivity", "StaticArrays", "FiniteDiff", "Pkg", "SafeTestsets", "Test", "Distributions", "ForwardDiff", "Zygote", "ChainRulesCore", "FastGaussQuadrature"]
4851

4952
[weakdeps]
5053
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
5154
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
5255
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
56+
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"

docs/src/solvers/IntegralSolvers.md

+2
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@ The following algorithms are available:
1111
- `CubaSUAVE`: SUAVE from Cuba.jl. Requires `using IntegralsCuba`.
1212
- `CubaDivonne`: Divonne from Cuba.jl. Requires `using IntegralsCuba`.
1313
- `CubaCuhre`: Cuhre from Cuba.jl. Requires `using IntegralsCuba`.
14+
- `GaussLegendre`: Uses Gauss-Legendre quadrature with nodes and weights from FastGaussQuadrature.jl.
1415

1516
```@docs
1617
QuadGKJL
1718
HCubatureJL
1819
VEGAS
20+
GaussLegendre
1921
```
+51
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
module IntegralsFastGaussQuadratureExt
2+
using Integrals
3+
if isdefined(Base, :get_extension)
4+
import FastGaussQuadrature
5+
import FastGaussQuadrature: gausslegendre
6+
# and eventually gausschebyshev, etc.
7+
else
8+
import ..FastGaussQuadrature
9+
import ..FastGaussQuadrature: gausslegendre
10+
end
11+
using LinearAlgebra
12+
13+
Integrals.gausslegendre(n) = FastGaussQuadrature.gausslegendre(n)
14+
15+
function gauss_legendre(f, p, lb, ub, nodes, weights)
16+
scale = (ub - lb) / 2
17+
shift = (lb + ub) / 2
18+
I = dot(weights, @. f(scale * nodes + shift, $Ref(p)))
19+
return scale * I
20+
end
21+
function composite_gauss_legendre(f, p, lb, ub, nodes, weights, subintervals)
22+
h = (ub - lb) / subintervals
23+
I = zero(h)
24+
for i in 1:subintervals
25+
_lb = lb + (i - 1) * h
26+
_ub = _lb + h
27+
I += gauss_legendre(f, p, _lb, _ub, nodes, weights)
28+
end
29+
return I
30+
end
31+
32+
function Integrals.__solvebp_call(prob::IntegralProblem, alg::Integrals.GaussLegendre{C},
33+
sensealg, lb, ub, p;
34+
reltol = nothing, abstol = nothing,
35+
maxiters = nothing) where {C}
36+
if isinplace(prob) || lb isa AbstractArray || ub isa AbstractArray
37+
error("GaussLegendre only accepts one-dimensional quadrature problems.")
38+
end
39+
@assert prob.batch == 0
40+
@assert prob.nout == 1
41+
if C
42+
val = composite_gauss_legendre(prob.f, prob.p, lb, ub,
43+
alg.nodes, alg.weights, alg.subintervals)
44+
else
45+
val = gauss_legendre(prob.f, prob.p, lb, ub,
46+
alg.nodes, alg.weights)
47+
end
48+
err = nothing
49+
SciMLBase.build_solution(prob, alg, val, err, retcode = ReturnCode.Success)
50+
end
51+
end

src/Integrals.jl

+11-10
Original file line numberDiff line numberDiff line change
@@ -108,15 +108,16 @@ function SciMLBase.solve(prob::IntegralProblem,
108108
__solvebp(prob, alg, sensealg, prob.lb, prob.ub, prob.p; kwargs...)
109109
end
110110
# Throw error if alg is not provided, as defaults are not implemented.
111-
SciMLBase.solve(::IntegralProblem) = throw(ArgumentError("""
112-
No integration algorithm `alg` was supplied as the second positional argument.
113-
Reccomended integration algorithms are:
114-
For scalar functions: QuadGKJL()
115-
For ≤ 8 dimensional vector functions: HCubatureJL()
116-
For > 8 dimensional vector functions: MonteCarloIntegration.vegas(f, st, en, kwargs...)
117-
See the docstrings of the different algorithms for more detail.
118-
"""
119-
))
111+
function SciMLBase.solve(::IntegralProblem)
112+
throw(ArgumentError("""
113+
No integration algorithm `alg` was supplied as the second positional argument.
114+
Reccomended integration algorithms are:
115+
For scalar functions: QuadGKJL()
116+
For ≤ 8 dimensional vector functions: HCubatureJL()
117+
For > 8 dimensional vector functions: MonteCarloIntegration.vegas(f, st, en, kwargs...)
118+
See the docstrings of the different algorithms for more detail.
119+
"""))
120+
end
120121

121122
# Give a layer to intercept with AD
122123
__solvebp(args...; kwargs...) = __solvebp_call(args...; kwargs...)
@@ -188,5 +189,5 @@ function __solvebp_call(prob::IntegralProblem, alg::VEGAS, sensealg, lb, ub, p;
188189
SciMLBase.build_solution(prob, alg, val, err, chi = chi, retcode = ReturnCode.Success)
189190
end
190191

191-
export QuadGKJL, HCubatureJL, VEGAS
192+
export QuadGKJL, HCubatureJL, VEGAS, GaussLegendre
192193
end # module

src/algorithms.jl

+40
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,43 @@ struct VEGAS <: SciMLBase.AbstractIntegralAlgorithm
8282
debug::Bool
8383
end
8484
VEGAS(; nbins = 100, ncalls = 1000, debug = false) = VEGAS(nbins, ncalls, debug)
85+
86+
"""
87+
GaussLegendre{C, N, W}
88+
89+
Struct for evaluating an integral via (composite) Gauss-Legendre quadrature.
90+
The field `C` will be `true` if `subintervals > 1`, and `false` otherwise.
91+
92+
The fields `nodes::N` and `weights::W` are defined by
93+
`nodes, weights = gausslegendre(n)` for a given number of nodes `n`.
94+
95+
The field `subintervals::Int64 = 1` (with default value `1`) defines the
96+
number of intervals to partition the original interval of integration
97+
`[a, b]` into, splitting it into `[xⱼ, xⱼ₊₁]` for `j = 1,…,subintervals`,
98+
where `xⱼ = a + (j-1)h` and `h = (b-a)/subintervals`. Gauss-Legendre
99+
quadrature is then applied on each subinterval. For example, if
100+
`[a, b] = [-1, 1]` and `subintervals = 2`, then Gauss-Legendre
101+
quadrature will be applied separately on `[-1, 0]` and `[0, 1]`,
102+
summing the two results.
103+
"""
104+
struct GaussLegendre{C, N, W} <: SciMLBase.AbstractIntegralAlgorithm
105+
nodes::N
106+
weights::W
107+
subintervals::Int64
108+
function GaussLegendre(nodes::N, weights::W, subintervals = 1) where {N, W}
109+
if subintervals > 1
110+
return new{true, N, W}(nodes, weights, subintervals)
111+
elseif subintervals == 1
112+
return new{false, N, W}(nodes, weights, subintervals)
113+
else
114+
throw(ArgumentError("Cannot use a nonpositive number of subintervals."))
115+
end
116+
end
117+
end
118+
function gausslegendre end
119+
function GaussLegendre(; n = 250, subintervals = 1, nodes = nothing, weights = nothing)
120+
if isnothing(nodes) || isnothing(weights)
121+
nodes, weights = gausslegendre(n)
122+
end
123+
return GaussLegendre(nodes, weights, subintervals)
124+
end

src/init.jl

+1
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@
22
function __init__()
33
@require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" begin include("../ext/IntegralsForwardDiffExt.jl") end
44
@require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin include("../ext/IntegralsZygoteExt.jl") end
5+
@require FastGaussQuadrature="442a2c76-b920-505d-bb47-c5924d526838" begin include("../ext/IntegralsFastGaussQuadratureExt.jl") end
56
end
67
end

test/gaussian_quadrature_tests.jl

+99
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
using Integrals, Test, FastGaussQuadrature
2+
3+
#=
4+
f = (x, p) -> x^3 * sin(5x)
5+
n = 250
6+
nodes, weights = gausslegendre(n)
7+
I = gauss_legendre(f, nothing, -1, 1, nodes, weights)
8+
@test I ≈ 2 / (625) * (69sin(5) - 95cos(5))
9+
I = Integrals.composite_gauss_legendre(f, nothing, -1, 1, nodes, weights, 2)
10+
@test I ≈ 2 / (625) * (69sin(5) - 95cos(5))
11+
12+
f = (x, p) -> (x + p) * abs(x)
13+
n = 100
14+
nodes, weights = gausslegendre(n)
15+
I = Integrals.gauss_legendre(f, 0.0, -2, 2, nodes, weights)
16+
Ic = Integrals.composite_gauss_legendre(f, 6, -2, 2, nodes, weights, 5)
17+
@inferred Integrals.gauss_legendre(f, 0.0, -2, 2, nodes, weights)
18+
@inferred Integrals.composite_gauss_legendre(f, 6, -2, 2, nodes, weights, 5)
19+
@test I≈0.0 atol=1e-6
20+
@test Ic≈24 rtol=1e-4
21+
=#
22+
23+
alg = GaussLegendre()
24+
n = 250
25+
nd, wt = gausslegendre(n)
26+
@test alg.nodes == nd
27+
@test alg.weights == wt
28+
@test alg.subintervals == 1
29+
alg = GaussLegendre(n = 125, subintervals = 3)
30+
n = 125
31+
nd, wt = gausslegendre(n)
32+
@test alg.nodes == nd
33+
@test alg.weights == wt
34+
@test alg.subintervals == 3
35+
@test typeof(alg).parameters[1]
36+
nd, wt = gausslegendre(275)
37+
alg = GaussLegendre(nodes = nd, weights = wt)
38+
@test !typeof(alg).parameters[1]
39+
@test alg.nodes == nd
40+
@test alg.weights == wt
41+
@test alg.subintervals == 1
42+
alg = GaussLegendre(nodes = nd, weights = wt, subintervals = 20)
43+
@test typeof(alg).parameters[1]
44+
@test alg.nodes == nd
45+
@test alg.weights == wt
46+
@test alg.subintervals == 20
47+
48+
f = (x, p) -> 5x + sin(x) - p * exp(x)
49+
prob = IntegralProblem(f, -5, 3, 3.3)
50+
alg = GaussLegendre()
51+
sol = solve(prob, alg)
52+
@test isnothing(sol.chi)
53+
@test sol.alg === alg
54+
@test sol.prob === prob
55+
@test isnothing(sol.resid)
56+
@test SciMLBase.successful_retcode(sol)
57+
@test sol.u -exp(3) * 3.3 + 3.3 / exp(5) - 40 + cos(5) - cos(3)
58+
alg = GaussLegendre(subintervals = 7)
59+
sol = solve(prob, alg)
60+
@test sol.u -exp(3) * 3.3 + 3.3 / exp(5) - 40 + cos(5) - cos(3)
61+
62+
f = (x, p) -> exp(-x^2)
63+
prob = IntegralProblem(f, 0.0, Inf)
64+
alg = GaussLegendre()
65+
sol = solve(prob, alg)
66+
@test sol.u sqrt(π)/2
67+
alg = GaussLegendre(subintervals=1)
68+
@test sol.u sqrt(π)/2
69+
alg = GaussLegendre(subintervals=17)
70+
@test sol.u sqrt(π)/2
71+
72+
prob = IntegralProblem(f, -Inf, Inf)
73+
alg = GaussLegendre()
74+
sol = solve(prob, alg)
75+
@test sol.u sqrt(π)
76+
alg = GaussLegendre(subintervals=1)
77+
@test sol.u sqrt(π)
78+
alg = GaussLegendre(subintervals=17)
79+
@test sol.u sqrt(π)
80+
81+
prob = IntegralProblem(f, -Inf, 0.0)
82+
alg = GaussLegendre()
83+
sol = solve(prob, alg)
84+
@test sol.u sqrt(π)/2
85+
alg = GaussLegendre(subintervals=1)
86+
@test sol.u sqrt(π)/2
87+
alg = GaussLegendre(subintervals=17)
88+
@test sol.u sqrt(π)/2
89+
90+
# Make sure broadcasting correctly handles the argument p
91+
f = (x, p) -> 1 + x + x^p[1] - cos(x*p[2]) + exp(x)*p[3]
92+
p = [0.3, 1.3, -0.5]
93+
prob = IntegralProblem(f, 2, 6.3, p)
94+
alg = GaussLegendre()
95+
sol = solve(prob, alg)
96+
@test sol.u -240.25235266303063249920743158729
97+
alg = GaussLegendre(n = 500, subintervals = 17)
98+
sol = solve(prob, alg)
99+
@test sol.u -240.25235266303063249920743158729

test/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@ dev_subpkg("IntegralsCubature")
1313
@time @safetestset "Interface Tests" begin include("interface_tests.jl") end
1414
@time @safetestset "Derivative Tests" begin include("derivative_tests.jl") end
1515
@time @safetestset "Infinite Integral Tests" begin include("inf_integral_tests.jl") end
16+
@time @safetestset "Gaussian Quadrature Tests" begin include("gaussian_quadrature_tests.jl") end

0 commit comments

Comments
 (0)