@@ -2,7 +2,11 @@ module Quadrature
2
2
3
3
using Requires, Reexport, MonteCarloIntegration, QuadGK, HCubature
4
4
@reexport using DiffEqBase
5
- using ZygoteRules, Zygote, ReverseDiff, ForwardDiff , LinearAlgebra
5
+ using Zygote, ReverseDiff, ForwardDiff , LinearAlgebra
6
+
7
+ import ChainRulesCore
8
+ import ChainRulesCore: NoTangent
9
+ import ZygoteRules
6
10
7
11
struct QuadGKJL <: DiffEqBase.AbstractQuadratureAlgorithm end
8
12
struct HCubatureJL <: DiffEqBase.AbstractQuadratureAlgorithm end
@@ -485,7 +489,7 @@ function __init__()
485
489
end
486
490
end
487
491
488
- ZygoteRules . @adjoint function __solvebp ( prob,alg,sensealg,lb,ub,p,args... ;kwargs... )
492
+ function ChainRulesCore . rrule ( :: typeof (__solvebp), prob,alg,sensealg,lb,ub,p,args... ;kwargs... )
489
493
out = __solvebp_call (prob,alg,sensealg,lb,ub,p,args... ;kwargs... )
490
494
function quadrature_adjoint (Δ)
491
495
y = typeof (Δ) <: Array{<:Number,0} ? Δ[1 ] : Δ
@@ -549,9 +553,9 @@ ZygoteRules.@adjoint function __solvebp(prob,alg,sensealg,lb,ub,p,args...;kwargs
549
553
if lb isa Number
550
554
dlb = - _f (lb)
551
555
dub = _f (ub)
552
- return (nothing , nothing , nothing , dlb,dub,dp,ntuple (x-> nothing ,length (args))... )
556
+ return (NoTangent (), NoTangent (), NoTangent (), NoTangent (), dlb,dub,dp,ntuple (x-> NoTangent () ,length (args))... )
553
557
else
554
- return (nothing , nothing , nothing , nothing , nothing , dp,ntuple (x-> nothing ,length (args))... )
558
+ return (NoTangent (), NoTangent (), NoTangent (), NoTangent (), NoTangent (), NoTangent (), dp,ntuple (x-> NoTangent () ,length (args))... )
555
559
end
556
560
end
557
561
out,quadrature_adjoint
0 commit comments