Skip to content

Commit 4593763

Browse files
Merge pull request #239 from lxvm/forward
ForwardDiff directly on all non-C solvers
2 parents 12c635f + e7192fb commit 4593763

File tree

3 files changed

+42
-22
lines changed

3 files changed

+42
-22
lines changed

docs/src/basics/FAQ.md

+12-4
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ The in-place interface allows evaluating vector-valued integrands without
66
allocating an output array. This can be beneficial for reducing allocations when
77
integrating many functions simultaneously or to make use of existing in-place
88
code. However, note that not all algorithms use in-place operations under the
9-
hood, i.e. `HCubatureJL()`, and may still allocate.
9+
hood, i.e. [`HCubatureJL`](@ref), and may still allocate.
1010

1111
You can construct an `IntegralFunction(f, prototype)`, where `f` is of the form
1212
`f(y, u, p)` where `prototype` is of the desired type and shape of `y`.
@@ -22,16 +22,17 @@ different points, which maximizes the parallelism for a given algorithm.
2222
You can construct an out-of-place `BatchIntegralFunction(bf)` where `bf` is of
2323
the form `bf(u, p) = stack(x -> f(x, p), eachslice(u; dims=ndims(u)))`, where
2424
`f` is the (unbatched) integrand.
25+
For interoperability with as many algorithms as possible, it is important that your out-of-place batch integrand accept an **empty** array of quadrature points and still return an output with a size and type consistent with the non-empty case.
2526

2627
You can construct an in-place `BatchIntegralFunction(bf, prototype)`, where `bf`
2728
is of the form `bf(y, u, p) = foreach((y,x) -> f(y,x,p), eachslice(y, dims=ndims(y)), eachslice(x, dims=ndims(x)))`.
2829

2930
Note that not all algorithms use in-place batched operations under the hood,
30-
i.e. `QuadGKJL()`.
31+
i.e. [`QuadGKJL`](@ref).
3132

3233
## What should I do if my solution is not converged?
3334

34-
Certain algorithms, such as `QuadratureRule` used a fixed number of points to
35+
Certain algorithms, such as [`QuadratureRule`](@ref) used a fixed number of points to
3536
calculate an integral and cannot provide an error estimate. In this case, you
3637
have to increase the number of points and check the convergence yourself, which
3738
will depend on the accuracy of the rule you choose.
@@ -47,7 +48,7 @@ precision arithmetic may help.
4748

4849
## How can I integrate arbitrarily-spaced data?
4950

50-
See `SampledIntegralProblem`.
51+
See [`SampledIntegralProblem`](@ref).
5152

5253
## How can I integrate on arbitrary geometries?
5354

@@ -59,6 +60,13 @@ because that is what lower-level packages implement.
5960
Fixed quadrature rules from other packages can be used with `QuadratureRule`.
6061
Otherwise, feel free to open an issue or pull request.
6162

63+
## My integrand works with algorithm X but fails on algorithm Y
64+
65+
While bugs are not out of the question, certain algorithms, especially those implemented in C, are not compatible with arbitrary Julia types and have to return specific numeric types or arrays thereof.
66+
In some cases, such as [`ArblibJL`](@ref), it is also expected that the integrand work with a custom quadrature point type.
67+
Moreover, some algorithms, such as [`VEGAS`](@ref), only support scalar integrands.
68+
For more details see the [solver page](@ref solvers).
69+
6270
## Can I take derivatives with respect to the limits of integration?
6371

6472
Currently this is not implemented.

ext/IntegralsForwardDiffExt.jl

+26-15
Original file line numberDiff line numberDiff line change
@@ -3,27 +3,37 @@ using Integrals
33
isdefined(Base, :get_extension) ? (using ForwardDiff) : (using ..ForwardDiff)
44
### Forward-Mode AD Intercepts
55

6-
#= Direct AD on solvers with QuadGK and HCubature
7-
# incompatible with iip since types must change
8-
function Integrals.__solvebp(cache, alg::QuadGKJL, sensealg, domain,
9-
p::AbstractArray{<:ForwardDiff.Dual{T, V, P}, N};
10-
kwargs...) where {T, V, P, N}
11-
Integrals.__solvebp_call(cache, alg, sensealg, domain, p; kwargs...)
12-
end
6+
# Default to direct AD on solvers
7+
function Integrals.__solvebp(cache, alg, sensealg, domain,
8+
p::Union{D,AbstractArray{<:D}};
9+
kwargs...) where {T, V, P, D<:ForwardDiff.Dual{T, V, P}}
1310

14-
function Integrals.__solvebp(cache, alg::HCubatureJL, sensealg, domain,
15-
p::AbstractArray{<:ForwardDiff.Dual{T, V, P}, N};
16-
kwargs...) where {T, V, P, N}
17-
Integrals.__solvebp_call(cache, alg, sensealg, domain, p; kwargs...)
11+
if isinplace(cache.f)
12+
prototype = cache.f.integrand_prototype
13+
elt = eltype(prototype)
14+
ForwardDiff.can_dual(elt) || throw(ArgumentError("ForwardDiff of in-place integrands only supports prototypes with real elements"))
15+
dprototype = similar(prototype, replace_dualvaltype(D, elt))
16+
df = if cache.f isa BatchIntegralFunction
17+
BatchIntegralFunction{true}(cache.f.f, dprototype)
18+
else
19+
IntegralFunction{true}(cache.f.f, dprototype)
20+
end
21+
prob = Integrals.build_problem(cache)
22+
dprob = remake(prob, f = df)
23+
dcache = init(dprob, alg; sensealg = sensealg, do_inf_transformation=Val(false), kwargs...)
24+
Integrals.__solvebp_call(dcache, alg, sensealg, domain, p; kwargs...)
25+
else
26+
Integrals.__solvebp_call(cache, alg, sensealg, domain, p; kwargs...)
27+
end
1828
end
19-
=#
29+
2030

2131
# TODO: add the pushforward for derivative w.r.t lb, and ub (and then combinations?)
2232

2333
# Manually split for the pushforward
24-
function Integrals.__solvebp(cache, alg, sensealg, domain,
25-
p::Union{D, AbstractArray{<:D}};
26-
kwargs...) where {T, V, P, D <: ForwardDiff.Dual{T, V, P}}
34+
function Integrals.__solvebp(cache, alg::Integrals.AbstractIntegralCExtensionAlgorithm, sensealg, domain,
35+
p::Union{D,AbstractArray{<:D}};
36+
kwargs...) where {T, V, P, D<:ForwardDiff.Dual{T, V, P}}
2737

2838
# we need the output type to avoid perturbation confusion while unwrapping nested duals
2939
# We compute a vector-valued integral of the primal and dual simultaneously
@@ -73,6 +83,7 @@ function Integrals.__solvebp(cache, alg, sensealg, domain,
7383
end
7484
end
7585

86+
DT <: Real || throw(ArgumentError("differentiating algorithms in C"))
7687
ForwardDiff.can_dual(elt) || ForwardDiff.throw_cannot_dual(elt)
7788
rawp = p isa D ? reinterpret(V, [p]) : copy(reinterpret(V, vec(p)))
7889

src/algorithms_extension.jl

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
## Extension Algorithms
22

33
abstract type AbstractIntegralExtensionAlgorithm <: SciMLBase.AbstractIntegralAlgorithm end
4+
abstract type AbstractIntegralCExtensionAlgorithm <: AbstractIntegralExtensionAlgorithm end
45

5-
abstract type AbstractCubaAlgorithm <: AbstractIntegralExtensionAlgorithm end
6+
abstract type AbstractCubaAlgorithm <: AbstractIntegralCExtensionAlgorithm end
67

78
"""
89
CubaVegas()
@@ -152,7 +153,7 @@ function CubaCuhre(; flags = 0, minevals = 0, key = 0)
152153
return CubaCuhre(flags, minevals, key)
153154
end
154155

155-
abstract type AbstractCubatureJLAlgorithm <: AbstractIntegralExtensionAlgorithm end
156+
abstract type AbstractCubatureJLAlgorithm <: AbstractIntegralCExtensionAlgorithm end
156157

157158
"""
158159
CubatureJLh(; error_norm=Cubature.INDIVIDUAL)
@@ -219,7 +220,7 @@ documentation for additional details the algorithm arguments and on implementing
219220
high-precision integrands. Additionally, the error estimate is included in the return value
220221
of the integral, representing a ball.
221222
"""
222-
struct ArblibJL{O} <: AbstractIntegralExtensionAlgorithm
223+
struct ArblibJL{O} <: AbstractIntegralCExtensionAlgorithm
223224
check_analytic::Bool
224225
take_prec::Bool
225226
warn_on_no_convergence::Bool

0 commit comments

Comments
 (0)