@@ -2,9 +2,8 @@ module IntegralsMooncakeExt
22using Mooncake
33using LinearAlgebra: dot
44using Integrals, SciMLBase, QuadGK
5- using Mooncake: @from_chainrules , @is_primitive , increment!!, MinimalCtx, rrule!!, NoFData,
6- NoRData, CoDual, primal, NoRData, zero_fcodual
7- import Mooncake: increment_and_get_rdata!, @zero_derivative
5+ using Mooncake: @from_chainrules , @is_primitive , increment!!, MinimalCtx, NoFData,
6+ CoDual, primal, NoRData, zero_fcodual, increment_and_get_rdata!, @zero_derivative
87using Integrals: AbstractIntegralMetaAlgorithm, IntegralProblem
98import ChainRulesCore
109import ChainRulesCore: Tangent, NoTangent, ProjectTo
@@ -34,42 +33,9 @@ batch_unwrap(x::AbstractArray) = dropdims(x; dims = ndims(x))
3433 Union{<: AbstractVector , <: Number }, Union{<: AbstractVector , <: Number },
3534}
3635
37- # @from_chainrules MinimalCtx Tuple{Type{IntegralProblem{iip}},Any,Any,Any} where {iip} true
38- @is_primitive MinimalCtx Tuple{Type{IntegralProblem{iip}}, Any, Any, Any} where {iip}
39- function Mooncake. rrule!! (
40- :: CoDual{Type{IntegralProblem{iip}}} , f:: CoDual ,
41- domain:: CoDual , p:: CoDual ; kwargs...
42- ) where {iip}
43- f_prim, domain_prim, p_prim = map (primal, (f, domain, p))
44- prob = IntegralProblem {iip} (f_prim, domain_prim, p_prim; kwargs... )
45-
46- function IntegralProblem_iip_pullback (Δ)
47- data = Δ isa NoRData ? Δ : Δ. data
48- ddomain = hasproperty (data, :domain ) ? data. domain : NoRData ()
49- dp = hasproperty (data, :p ) ? data. p : NoRData ()
50- dkwargs = hasproperty (Δ, :kwargs ) ? data. kwargs : NoRData ()
51-
52- # domain is always a Tuple, so it always has NoFData
53- # below conditional is in case p is an Array or similar
54- if Mooncake. rdata_type (typeof (p_prim)) == NoRData ()
55- Mooncake. increment!! (p. dx, dp)
56- grad_p = NoRData ()
57- else
58- grad_p = dp
59- end
60-
61- return NoRData (), NoRData (), ddomain, grad_p, dkwargs
62- end
63- return zero_fcodual (prob), IntegralProblem_iip_pullback
64- end
65-
66- # Mooncake does not need chainrule for evaluate! as it supports mutation.
36+ @from_chainrules MinimalCtx Tuple{Type{IntegralProblem{iip}}, Any, Any, Any} where {iip} true
6737@from_chainrules MinimalCtx Tuple{Type{IntegralProblem}, Any, Any, Any} true
6838@from_chainrules MinimalCtx Tuple{typeof (Integrals. u2t), Any, Any} true
69- @from_chainrules MinimalCtx Tuple{
70- typeof (SciMLBase. build_solution), IntegralProblem, Any, Any, Any,
71- } true
72-
7339@from_chainrules MinimalCtx Tuple{typeof (Integrals. __solvebp), Any, Any, Any, Any, Any} true
7440
7541# Add MooncakeVJP support to the dispatch function defined in ZygoteExt
@@ -112,135 +78,4 @@ function Integrals._compute_dfdp_and_f(::Integrals.MooncakeVJP, cache, p, Δ)
11278 return dfdp, _f
11379end
11480
115- # Internal Mooncake overloads to accommodate IntegralSolution etc. Struct's Tangent Types.
116- # Allows clear translation from ChainRules -> Mooncake's tangent.
117- function Mooncake. increment_and_get_rdata! (
118- f:: NoFData , r:: Tuple{T, T} ,
119- t:: Union{Tangent{Tuple{T, T}, Tuple{T, T}}, Tangent{Any, Tuple{Float64, Float64}}}
120- ) where {T <: Base.IEEEFloat }
121- return r .+ t. backing
122- end
123-
124- function Mooncake. increment_and_get_rdata! (
125- f:: Tuple{Vector{T}, Vector{T}} ,
126- r:: NoRData ,
127- t:: Tangent{Any, Tuple{Vector{T}, Vector{T}}}
128- ) where {T <: Base.IEEEFloat }
129- Mooncake. increment!! (f, t. backing)
130- return NoRData ()
131- end
132-
133- # sol.u & p are single scalar values, domain (lb,ub) is single/multi - variate.
134- function Mooncake. increment_and_get_rdata! (
135- f:: NoFData ,
136- r:: T ,
137- t:: Tangent {
138- Any,
139- @NamedTuple {
140- u:: T ,
141- resid:: R ,
142- prob:: Tangent {
143- Any,
144- @NamedTuple {
145- f:: NoTangent ,
146- domain:: Tangent{Any, Tuple{M, M}} ,
147- p:: P ,
148- kwargs:: NoTangent ,
149- }
150- },
151- alg:: A ,
152- retcode:: NoTangent ,
153- chi:: NoTangent ,
154- stats:: NoTangent ,
155- }
156- }
157- ) where {
158- T <: Base.IEEEFloat ,
159- R <: Union{NoTangent, T} ,
160- P <: Union{T, Vector{T}} ,
161- M <: Union{T, Vector{T}} ,
162- A <: Union {
163- NoTangent,
164- Tangent{
165- Any,
166- @NamedTuple {
167- nodes:: Vector{T} ,
168- weights:: Vector{T} ,
169- subintervals:: NoTangent ,
170- }
171- }
172- }
173- }
174- # rdata component of t + r (u field)
175- return Mooncake. increment_and_get_rdata! (f, r, t. u)
176- end
177-
178- # sol.u is vector valued, p is scalar/vector valued, domain can be single/multi - variate
179- # resid can be single/vector valued. For inplace integrals (iip true) : included integrand_prototype field in typeof{prob.f}
180- function Mooncake. increment_and_get_rdata! (
181- f:: Vector{T} ,
182- r:: NoRData ,
183- t:: Union {
184- Tangent{
185- Any,
186- @NamedTuple {
187- u:: Vector{T} ,
188- resid:: R ,
189- prob:: Tangent {
190- Any,
191- @NamedTuple {
192- f:: F ,
193- domain:: Tangent{Any, M} ,
194- p:: P ,
195- kwargs:: NoTangent ,
196- }
197- },
198- alg:: A ,
199- retcode:: NoTangent ,
200- chi:: NoTangent ,
201- stats:: NoTangent ,
202- }
203- }
204- }
205- ) where {
206- T <: Base.IEEEFloat ,
207- R <: Union{NoTangent, T, Vector{T}} ,
208- P <: Union{T, Vector{T}} ,
209- M <: Union{Tuple{T, T}, Tuple{Vector{T}, Vector{T}}} ,
210- F <: Union {
211- NoTangent,
212- Tangent{
213- Any,
214- @NamedTuple {
215- f:: NoTangent ,
216- integrand_prototype:: Vector{T} ,
217- }
218- }
219- },
220- A <: Union {
221- NoTangent,
222- Tangent{
223- Any,
224- @NamedTuple {
225- nodes:: Vector{T} ,
226- weights:: Vector{T} ,
227- subintervals:: NoTangent ,
228- }
229- }
230- }
231- }
232- Mooncake. increment!! (f, t. u)
233- # rdata component(t) + r
234- return t. prob. domain
235- end
236-
237- # cannot mutate NoRData() in place, therefore return as is.
238- function Mooncake. increment!! (
239- :: Mooncake.NoRData ,
240- y:: Tangent{Any, Y}
241- ) where {
242- T <: Base.IEEEFloat , Y <: Union{Tuple{T, T}, Tuple{Vector{T}, Vector{T}}} ,
243- }
244- return Mooncake. NoRData ()
245- end
24681end
0 commit comments