Skip to content

Commit 9f041dc

Browse files
Merge pull request #66 from SciML/chainrules
Move from ZygoteRules to ChainRules
2 parents b7e75c5 + cc8b221 commit 9f041dc

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

Project.toml

+2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ authors = ["Chris Rackauckas <[email protected]>"]
44
version = "1.8.1"
55

66
[deps]
7+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
78
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
89
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
910
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
@@ -19,6 +20,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1920
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2021

2122
[compat]
23+
ChainRulesCore = "0.10.7"
2224
CommonSolve = "0.2"
2325
DiffEqBase = "6.1"
2426
Distributions = "0.23, 0.24, 0.25"

src/Quadrature.jl

+8-4
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@ module Quadrature
22

33
using Requires, Reexport, MonteCarloIntegration, QuadGK, HCubature
44
@reexport using DiffEqBase
5-
using ZygoteRules, Zygote, ReverseDiff, ForwardDiff , LinearAlgebra
5+
using Zygote, ReverseDiff, ForwardDiff , LinearAlgebra
6+
7+
import ChainRulesCore
8+
import ChainRulesCore: NoTangent
9+
import ZygoteRules
610

711
struct QuadGKJL <: DiffEqBase.AbstractQuadratureAlgorithm end
812
struct HCubatureJL <: DiffEqBase.AbstractQuadratureAlgorithm end
@@ -485,7 +489,7 @@ function __init__()
485489
end
486490
end
487491

488-
ZygoteRules.@adjoint function __solvebp(prob,alg,sensealg,lb,ub,p,args...;kwargs...)
492+
function ChainRulesCore.rrule(::typeof(__solvebp),prob,alg,sensealg,lb,ub,p,args...;kwargs...)
489493
out = __solvebp_call(prob,alg,sensealg,lb,ub,p,args...;kwargs...)
490494
function quadrature_adjoint(Δ)
491495
y = typeof(Δ) <: Array{<:Number,0} ? Δ[1] : Δ
@@ -549,9 +553,9 @@ ZygoteRules.@adjoint function __solvebp(prob,alg,sensealg,lb,ub,p,args...;kwargs
549553
if lb isa Number
550554
dlb = -_f(lb)
551555
dub = _f(ub)
552-
return (nothing,nothing,nothing,dlb,dub,dp,ntuple(x->nothing,length(args))...)
556+
return (NoTangent(),NoTangent(),NoTangent(),NoTangent(),dlb,dub,dp,ntuple(x->NoTangent(),length(args))...)
553557
else
554-
return (nothing,nothing,nothing,nothing,nothing,dp,ntuple(x->nothing,length(args))...)
558+
return (NoTangent(),NoTangent(),NoTangent(),NoTangent(),NoTangent(),NoTangent(),dp,ntuple(x->NoTangent(),length(args))...)
555559
end
556560
end
557561
out,quadrature_adjoint

0 commit comments

Comments
 (0)