Skip to content
17 changes: 13 additions & 4 deletions ext/IntegralsMooncakeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ using Mooncake
using LinearAlgebra: dot
using Integrals, SciMLBase, QuadGK
using Mooncake: @from_chainrules, @is_primitive, increment!!, MinimalCtx, rrule!!, NoFData,
NoRData, CoDual, primal, NoRData, zero_fcodual
CoDual, primal, NoRData, zero_fcodual
import Mooncake: increment_and_get_rdata!, @zero_derivative
using Integrals: AbstractIntegralMetaAlgorithm, IntegralProblem
import ChainRulesCore
Expand Down Expand Up @@ -116,7 +116,7 @@ end
# Allows clear translation from ChainRules -> Mooncake's tangent.
function Mooncake.increment_and_get_rdata!(
f::NoFData, r::Tuple{T, T},
t::Union{Tangent{Tuple{T, T}, Tuple{T, T}}, Tangent{Any, Tuple{Float64, Float64}}}
t::Union{Tangent{Tuple{T, T}, Tuple{T, T}}, Tangent{Any, Tuple{T, T}}}
) where {T <: Base.IEEEFloat}
return r .+ t.backing
end
Expand Down Expand Up @@ -234,13 +234,22 @@ function Mooncake.increment_and_get_rdata!(
return t.prob.domain
end

# cannot mutate NoRData() in place, therefore return as is.
function Mooncake.increment!!(
::Mooncake.NoRData,
y::Tangent{Any, Y}
) where {
T <: Base.IEEEFloat, Y <: Union{Tuple{T, T}, Tuple{Vector{T}, Vector{T}}},
}
return Mooncake.NoRData()
return NoRData()
end

function Mooncake.increment!!(
x::Tangent{Any, Y},
::Mooncake.NoRData
) where {
T <: Base.IEEEFloat, Y <: Union{Tuple{T, T}, Tuple{Vector{T}, Vector{T}}},
}
return x
end

end
1 change: 0 additions & 1 deletion test/derivative_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,6 @@ do_tests_mooncake = function (; f, scalarize, lb, ub, p, alg, abstol, reltol)
end
sol_fp = testf(lb, ub, p)

# sensealg when non zygoet?
cache = Mooncake.prepare_gradient_cache(
testf, lb, ub, p isa Number && f isa BatchIntegralFunction ? Scalar(p) : p
)
Expand Down