Skip to content

Commit ca32340

Browse files
Merge pull request #319 from AstitvaAggarwal/dev
Fix CI AD tests
2 parents 7c0130d + e5a1ad3 commit ca32340

File tree

3 files changed

+4
-170
lines changed

3 files changed

+4
-170
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ HCubature = "1.7"
6161
LinearAlgebra = "1.10"
6262
MCIntegration = "0.4.2"
6363
MonteCarloIntegration = "0.2"
64-
Mooncake = "0.4.184, 0.5"
64+
Mooncake = "0.5.6"
6565
Pkg = "1.10"
6666
QuadGK = "2.11"
6767
Random = "1.10"

ext/IntegralsMooncakeExt.jl

Lines changed: 3 additions & 168 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@ module IntegralsMooncakeExt
22
using Mooncake
33
using LinearAlgebra: dot
44
using Integrals, SciMLBase, QuadGK
5-
using Mooncake: @from_chainrules, @is_primitive, increment!!, MinimalCtx, rrule!!, NoFData,
6-
NoRData, CoDual, primal, NoRData, zero_fcodual
7-
import Mooncake: increment_and_get_rdata!, @zero_derivative
5+
using Mooncake: @from_chainrules, @is_primitive, increment!!, MinimalCtx, NoFData,
6+
CoDual, primal, NoRData, zero_fcodual, increment_and_get_rdata!, @zero_derivative
87
using Integrals: AbstractIntegralMetaAlgorithm, IntegralProblem
98
import ChainRulesCore
109
import ChainRulesCore: Tangent, NoTangent, ProjectTo
@@ -34,42 +33,9 @@ batch_unwrap(x::AbstractArray) = dropdims(x; dims = ndims(x))
3433
Union{<:AbstractVector, <:Number}, Union{<:AbstractVector, <:Number},
3534
}
3635

37-
# @from_chainrules MinimalCtx Tuple{Type{IntegralProblem{iip}},Any,Any,Any} where {iip} true
38-
@is_primitive MinimalCtx Tuple{Type{IntegralProblem{iip}}, Any, Any, Any} where {iip}
39-
function Mooncake.rrule!!(
40-
::CoDual{Type{IntegralProblem{iip}}}, f::CoDual,
41-
domain::CoDual, p::CoDual; kwargs...
42-
) where {iip}
43-
f_prim, domain_prim, p_prim = map(primal, (f, domain, p))
44-
prob = IntegralProblem{iip}(f_prim, domain_prim, p_prim; kwargs...)
45-
46-
function IntegralProblem_iip_pullback(Δ)
47-
data = Δ isa NoRData ? Δ : Δ.data
48-
ddomain = hasproperty(data, :domain) ? data.domain : NoRData()
49-
dp = hasproperty(data, :p) ? data.p : NoRData()
50-
dkwargs = hasproperty(Δ, :kwargs) ? data.kwargs : NoRData()
51-
52-
# domain is always a Tuple, so it always has NoFData
53-
# below conditional is in case p is an Array or similar
54-
if Mooncake.rdata_type(typeof(p_prim)) == NoRData()
55-
Mooncake.increment!!(p.dx, dp)
56-
grad_p = NoRData()
57-
else
58-
grad_p = dp
59-
end
60-
61-
return NoRData(), NoRData(), ddomain, grad_p, dkwargs
62-
end
63-
return zero_fcodual(prob), IntegralProblem_iip_pullback
64-
end
65-
66-
# Mooncake does not need chainrule for evaluate! as it supports mutation.
36+
@from_chainrules MinimalCtx Tuple{Type{IntegralProblem{iip}}, Any, Any, Any} where {iip} true
6737
@from_chainrules MinimalCtx Tuple{Type{IntegralProblem}, Any, Any, Any} true
6838
@from_chainrules MinimalCtx Tuple{typeof(Integrals.u2t), Any, Any} true
69-
@from_chainrules MinimalCtx Tuple{
70-
typeof(SciMLBase.build_solution), IntegralProblem, Any, Any, Any,
71-
} true
72-
7339
@from_chainrules MinimalCtx Tuple{typeof(Integrals.__solvebp), Any, Any, Any, Any, Any} true
7440

7541
# Add MooncakeVJP support to the dispatch function defined in ZygoteExt
@@ -112,135 +78,4 @@ function Integrals._compute_dfdp_and_f(::Integrals.MooncakeVJP, cache, p, Δ)
11278
return dfdp, _f
11379
end
11480

115-
# Internal Mooncake overloads to accommodate IntegralSolution etc. Struct's Tangent Types.
116-
# Allows clear translation from ChainRules -> Mooncake's tangent.
117-
function Mooncake.increment_and_get_rdata!(
118-
f::NoFData, r::Tuple{T, T},
119-
t::Union{Tangent{Tuple{T, T}, Tuple{T, T}}, Tangent{Any, Tuple{Float64, Float64}}}
120-
) where {T <: Base.IEEEFloat}
121-
return r .+ t.backing
122-
end
123-
124-
function Mooncake.increment_and_get_rdata!(
125-
f::Tuple{Vector{T}, Vector{T}},
126-
r::NoRData,
127-
t::Tangent{Any, Tuple{Vector{T}, Vector{T}}}
128-
) where {T <: Base.IEEEFloat}
129-
Mooncake.increment!!(f, t.backing)
130-
return NoRData()
131-
end
132-
133-
# sol.u & p are single scalar values, domain (lb,ub) is single/multi - variate.
134-
function Mooncake.increment_and_get_rdata!(
135-
f::NoFData,
136-
r::T,
137-
t::Tangent{
138-
Any,
139-
@NamedTuple{
140-
u::T,
141-
resid::R,
142-
prob::Tangent{
143-
Any,
144-
@NamedTuple{
145-
f::NoTangent,
146-
domain::Tangent{Any, Tuple{M, M}},
147-
p::P,
148-
kwargs::NoTangent,
149-
}
150-
},
151-
alg::A,
152-
retcode::NoTangent,
153-
chi::NoTangent,
154-
stats::NoTangent,
155-
}
156-
}
157-
) where {
158-
T <: Base.IEEEFloat,
159-
R <: Union{NoTangent, T},
160-
P <: Union{T, Vector{T}},
161-
M <: Union{T, Vector{T}},
162-
A <: Union{
163-
NoTangent,
164-
Tangent{
165-
Any,
166-
@NamedTuple{
167-
nodes::Vector{T},
168-
weights::Vector{T},
169-
subintervals::NoTangent,
170-
}
171-
}
172-
}
173-
}
174-
# rdata component of t + r (u field)
175-
return Mooncake.increment_and_get_rdata!(f, r, t.u)
176-
end
177-
178-
# sol.u is vector valued, p is scalar/vector valued, domain can be single/multi - variate
179-
# resid can be single/vector valued. For inplace integrals (iip true) : included integrand_prototype field in typeof{prob.f}
180-
function Mooncake.increment_and_get_rdata!(
181-
f::Vector{T},
182-
r::NoRData,
183-
t::Union{
184-
Tangent{
185-
Any,
186-
@NamedTuple{
187-
u::Vector{T},
188-
resid::R,
189-
prob::Tangent{
190-
Any,
191-
@NamedTuple{
192-
f::F,
193-
domain::Tangent{Any, M},
194-
p::P,
195-
kwargs::NoTangent,
196-
}
197-
},
198-
alg::A,
199-
retcode::NoTangent,
200-
chi::NoTangent,
201-
stats::NoTangent,
202-
}
203-
}
204-
}
205-
) where {
206-
T <: Base.IEEEFloat,
207-
R <: Union{NoTangent, T, Vector{T}},
208-
P <: Union{T, Vector{T}},
209-
M <: Union{Tuple{T, T}, Tuple{Vector{T}, Vector{T}}},
210-
F <: Union{
211-
NoTangent,
212-
Tangent{
213-
Any,
214-
@NamedTuple{
215-
f::NoTangent,
216-
integrand_prototype::Vector{T},
217-
}
218-
}
219-
},
220-
A <: Union{
221-
NoTangent,
222-
Tangent{
223-
Any,
224-
@NamedTuple{
225-
nodes::Vector{T},
226-
weights::Vector{T},
227-
subintervals::NoTangent,
228-
}
229-
}
230-
}
231-
}
232-
Mooncake.increment!!(f, t.u)
233-
# rdata component(t) + r
234-
return t.prob.domain
235-
end
236-
237-
# cannot mutate NoRData() in place, therefore return as is.
238-
function Mooncake.increment!!(
239-
::Mooncake.NoRData,
240-
y::Tangent{Any, Y}
241-
) where {
242-
T <: Base.IEEEFloat, Y <: Union{Tuple{T, T}, Tuple{Vector{T}, Vector{T}}},
243-
}
244-
return Mooncake.NoRData()
245-
end
24681
end

test/derivative_tests.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,6 @@ do_tests_mooncake = function (; f, scalarize, lb, ub, p, alg, abstol, reltol)
182182
end
183183
sol_fp = testf(lb, ub, p)
184184

185-
# sensealg when non zygoet?
186185
cache = Mooncake.prepare_gradient_cache(
187186
testf, lb, ub, p isa Number && f isa BatchIntegralFunction ? Scalar(p) : p
188187
)

0 commit comments

Comments
 (0)