Open
Description
Describe the bug 🐞
Taking the gradient with respect to a plain vector of parameter values (which are written into the parameter object via `setp_oop`) fails with `BacksolveAdjoint` when the problem's parameter object is an `MTKParameters`.
Expected behavior
`Zygote.gradient` returns the gradients of the loss with respect to both `u0` and `p` without error.
Minimal Reproducible Example 👇
using ModelingToolkit, OrdinaryDiffEq, Zygote, SciMLSensitivity
using SymbolicIndexingInterface: setp_oop
using ModelingToolkit: t_nounits as t, D_nounits as D
# Global symbolic variables declared outside the system. Only `x` is used later:
# `symbolic_indexing` indexes the solution with it (`soln[x]`). It is a separate
# object from the `x` declared inside `lotka_volterra`; presumably the lookup
# succeeds by matching names (TODO confirm). The global `o` is not used.
@variables x(t) o(t)
"""
    lotka_volterra(; name)

Construct a Lotka–Volterra `ODESystem` with:

- unknowns `x` and `y` (both with default initial value 1.0);
- `o`, defined algebraically as `o ~ x * y`;
- parameters `p1 = 1.5`, `p2 = 1.0`, `p3 = 3.0`, `p4 = 1.0` (declared defaults).

The `name = name` default presumably exists to force callers to supply `name`.
At the call site, `@mtkbuild` provides it.
"""
function lotka_volterra(; name = name)
    states = @variables x(t)=1.0 y(t)=1.0 o(t)
    ps = @parameters p1=1.5 p2=1.0 p3=3.0 p4=1.0

    # Rate of change of x: linear growth minus the x*y interaction term.
    dx_eq = D(x) ~ p1 * x - p2 * x * y
    # Rate of change of y: linear decay plus the x*y interaction term.
    dy_eq = D(y) ~ -p3 * y + p4 * x * y
    # Algebraic equation for o, the product of the two unknowns.
    product_eq = o ~ x * y

    equations = [dx_eq, dy_eq, product_eq]
    return ODESystem(equations, t, states, ps; name = name)
end
# Build the system. `@mtkbuild` supplies the `name` keyword (here
# `:lotka_volterra_sys`) and, per MTK convention, completes/simplifies the system.
@mtkbuild lotka_volterra_sys = lotka_volterra()
# Empty initial-condition and parameter maps, so the values come from the defaults
# declared in `lotka_volterra`. Time span is (0.0, 10.0).
prob = ODEProblem(lotka_volterra_sys, [], (0.0, 10.0), [])
# Plain-vector inputs that the gradient is taken with respect to.
# NOTE(review): `u0` is positional, so its order presumably follows the system's
# unknowns after simplification (confirm). That is harmless here because both
# entries are 1.0.
u0 = [1.0, 1.0]
# Values for (p1, p2, p3, p4), in the same order as the symbol list passed to
# `setp_oop` below. Note that p3 = 1.0 here, not its declared default of 3.0.
p = [1.5, 1.0, 1.0, 1.0]
# Out-of-place setter from SymbolicIndexingInterface. Called as
# `oop_setter(prob, vals)`, it returns a parameter object with p1..p4 replaced
# by `vals`.
oop_setter = setp_oop(prob, [lotka_volterra_sys.p1, lotka_volterra_sys.p2, lotka_volterra_sys.p3, lotka_volterra_sys.p4])
"""
    symbolic_indexing(u0, p)

Loss function used for the gradient in this report. It:

1. builds a parameter object from the numeric vector `p` using the global
   `oop_setter`;
2. remakes the global `prob` with initial state `u0` and that parameter object;
3. solves with `Tsit5` (tolerances `1e-6`, saved every `0.1`);
4. returns the sum of the saved values of `x`.

The `sensealg` keyword requests `BacksolveAdjoint` with a Zygote VJP for
reverse-mode differentiation. With an `MTKParameters` parameter object, this
`solve` call is where the adjoint parameter-compatibility error in the stack
trace originates.
"""
function symbolic_indexing(u0, p)
    new_params = oop_setter(prob, p)
    new_prob = remake(prob; u0 = u0, p = new_params)

    adjoint = BacksolveAdjoint(autojacvec = ZygoteVJP())
    sol = solve(new_prob, Tsit5();
        reltol = 1e-6, abstol = 1e-6, saveat = 0.1, sensealg = adjoint)

    x_trajectory = sol[x]
    return sum(x_trajectory)
end
# Reverse-mode gradient of the loss with respect to `u0` and `p`.
# In the reported environment this call raises the adjoint
# parameter-compatibility error reproduced in the "Error & Stacktrace" section.
du01, dp1 = Zygote.gradient(symbolic_indexing, u0, p)
Error & Stacktrace
Adjoint sensitivity analysis functionality requires being able to solve
a differential equation defined by the parameter struct `p`. Thus while
DifferentialEquations.jl can support any parameter struct type, usage
with adjoint sensitivity analysis requires that `p` could be a valid
type for being the initial condition `u0` of an array. This means that
many simple types, such as `Tuple`s and `NamedTuple`s, will work as
parameters in normal contexts but will fail during adjoint differentiation.
To work around this issue for complicated cases like nested structs, look
into defining `p` using `AbstractArray` libraries such as RecursiveArrayTools.jl
or ComponentArrays.jl so that `p` is an `AbstractArray` with a concrete element type.
Stacktrace:
[1] _concrete_solve_adjoint(::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, ODEFunction{true, SciMLBase.AutoSpecialize, FunctionWrappersWrappers.FunctionWrappersWrapper{Tuple{FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{Float64}, Vector{Float64}, MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{Float64}, MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}}}, false}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, ModelingToolkit.ObservedFunctionCache{ODESystem}, Nothing, ODESystem, Nothing, Nothing}, @Kwargs{}, SciMLBase.StandardODEProblem}, ::Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, ::BacksolveAdjoint{0, true, Val{:central}, ZygoteVJP}, ::Vector{Float64}, ::MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, ::SciMLBase.ChainRulesOriginator; save_start::Bool, save_end::Bool, 
saveat::Float64, save_idxs::Nothing, kwargs::@Kwargs{reltol::Float64, abstol::Float64})
@ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/HRhwU/src/concrete_solve.jl:378
[2] _solve_adjoint(prob::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, ODEFunction{true, SciMLBase.AutoSpecialize, ModelingToolkit.var"#f#835"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x4abac1fc, 0xb17c7b66, 0x95b8fd42, 0xf9281edf, 0xa2a10f56), Nothing}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x25ed63bf, 0xcf4bccc8, 0x9286cc6c, 0x7330a30a, 0xd77e77e0), Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, ModelingToolkit.ObservedFunctionCache{ODESystem}, Nothing, ODESystem, Nothing, Nothing}, @Kwargs{}, SciMLBase.StandardODEProblem}, sensealg::BacksolveAdjoint{0, true, Val{:central}, ZygoteVJP}, u0::Vector{Float64}, p::MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, originator::SciMLBase.ChainRulesOriginator, args::Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}; merge_callbacks::Bool, kwargs::@Kwargs{reltol::Float64, abstol::Float64, saveat::Float64})
@ DiffEqBase ~/.julia/packages/DiffEqBase/DdIeW/src/solve.jl:1537
[3] rrule(::typeof(DiffEqBase.solve_up), prob::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, ODEFunction{true, SciMLBase.AutoSpecialize, ModelingToolkit.var"#f#835"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x4abac1fc, 0xb17c7b66, 0x95b8fd42, 0xf9281edf, 0xa2a10f56), Nothing}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x25ed63bf, 0xcf4bccc8, 0x9286cc6c, 0x7330a30a, 0xd77e77e0), Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, ModelingToolkit.ObservedFunctionCache{ODESystem}, Nothing, ODESystem, Nothing, Nothing}, @Kwargs{}, SciMLBase.StandardODEProblem}, sensealg::BacksolveAdjoint{0, true, Val{:central}, ZygoteVJP}, u0::Vector{Float64}, p::MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, args::Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}; kwargs::@Kwargs{reltol::Float64, abstol::Float64, saveat::Float64})
@ DiffEqBaseChainRulesCoreExt ~/.julia/packages/DiffEqBase/DdIeW/ext/DiffEqBaseChainRulesCoreExt.jl:26
[4] kwcall(::@NamedTuple{reltol::Float64, abstol::Float64, saveat::Float64}, ::typeof(ChainRulesCore.rrule), ::Zygote.ZygoteRuleConfig{Zygote.Context{false}}, ::Function, ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, ODEFunction{true, SciMLBase.AutoSpecialize, ModelingToolkit.var"#f#835"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x4abac1fc, 0xb17c7b66, 0x95b8fd42, 0xf9281edf, 0xa2a10f56), Nothing}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x25ed63bf, 0xcf4bccc8, 0x9286cc6c, 0x7330a30a, 0xd77e77e0), Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, ModelingToolkit.ObservedFunctionCache{ODESystem}, Nothing, ODESystem, Nothing, Nothing}, @Kwargs{}, SciMLBase.StandardODEProblem}, ::BacksolveAdjoint{0, true, Val{:central}, ZygoteVJP}, ::Vector{Float64}, ::MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, ::Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False})
@ ChainRulesCore ~/.julia/packages/ChainRulesCore/6Pucz/src/rules.jl:144
[5] chain_rrule_kw
@ ~/.julia/packages/Zygote/Tt5Gx/src/compiler/chainrules.jl:236 [inlined]
[6] macro expansion
@ ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0 [inlined]
[7] _pullback(::Zygote.Context{false}, ::typeof(Core.kwcall), ::@NamedTuple{reltol::Float64, abstol::Float64, saveat::Float64}, ::typeof(DiffEqBase.solve_up), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, ODEFunction{true, SciMLBase.AutoSpecialize, ModelingToolkit.var"#f#835"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x4abac1fc, 0xb17c7b66, 0x95b8fd42, 0xf9281edf, 0xa2a10f56), Nothing}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x25ed63bf, 0xcf4bccc8, 0x9286cc6c, 0x7330a30a, 0xd77e77e0), Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, ModelingToolkit.ObservedFunctionCache{ODESystem}, Nothing, ODESystem, Nothing, Nothing}, @Kwargs{}, SciMLBase.StandardODEProblem}, ::BacksolveAdjoint{0, true, Val{:central}, ZygoteVJP}, ::Vector{Float64}, ::MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, ::Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False})
@ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:87
[8] _apply(::Function, ::Vararg{Any})
@ Core ./boot.jl:838
[9] adjoint
@ ~/.julia/packages/Zygote/Tt5Gx/src/lib/lib.jl:203 [inlined]
[10] _pullback
@ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
[11] #solve#51
@ ~/.julia/packages/DiffEqBase/DdIeW/src/solve.jl:1003 [inlined]
[12] _pullback(::Zygote.Context{false}, ::DiffEqBase.var"##solve#51", ::BacksolveAdjoint{0, true, Val{:central}, ZygoteVJP}, ::Nothing, ::Nothing, ::Val{true}, ::@Kwargs{reltol::Float64, abstol::Float64, saveat::Float64}, ::typeof(solve), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, ODEFunction{true, SciMLBase.AutoSpecialize, ModelingToolkit.var"#f#835"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x4abac1fc, 0xb17c7b66, 0x95b8fd42, 0xf9281edf, 0xa2a10f56), Nothing}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x25ed63bf, 0xcf4bccc8, 0x9286cc6c, 0x7330a30a, 0xd77e77e0), Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, ModelingToolkit.ObservedFunctionCache{ODESystem}, Nothing, ODESystem, Nothing, Nothing}, @Kwargs{}, SciMLBase.StandardODEProblem}, ::Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False})
@ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
[13] _apply(::Function, ::Vararg{Any})
@ Core ./boot.jl:838
[14] adjoint
@ ~/.julia/packages/Zygote/Tt5Gx/src/lib/lib.jl:203 [inlined]
[15] _pullback
@ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
[16] solve
@ ~/.julia/packages/DiffEqBase/DdIeW/src/solve.jl:993 [inlined]
[17] _pullback(::Zygote.Context{false}, ::typeof(Core.kwcall), ::@NamedTuple{reltol::Float64, abstol::Float64, saveat::Float64, sensealg::BacksolveAdjoint{0, true, Val{:central}, ZygoteVJP}}, ::typeof(solve), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, ODEFunction{true, SciMLBase.AutoSpecialize, ModelingToolkit.var"#f#835"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x4abac1fc, 0xb17c7b66, 0x95b8fd42, 0xf9281edf, 0xa2a10f56), Nothing}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x25ed63bf, 0xcf4bccc8, 0x9286cc6c, 0x7330a30a, 0xd77e77e0), Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, ModelingToolkit.ObservedFunctionCache{ODESystem}, Nothing, ODESystem, Nothing, Nothing}, @Kwargs{}, SciMLBase.StandardODEProblem}, ::Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False})
@ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
[18] symbolic_indexing
@ ./REPL[125]:4 [inlined]
[19] _pullback(::Zygote.Context{false}, ::typeof(symbolic_indexing), ::Vector{Float64}, ::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
[20] pullback(::Function, ::Zygote.Context{false}, ::Vector{Float64}, ::Vararg{Vector{Float64}})
@ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface.jl:90
[21] pullback(::Function, ::Vector{Float64}, ::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface.jl:88
[22] gradient(::Function, ::Vector{Float64}, ::Vararg{Any})
@ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface.jl:147
[23] top-level scope
@ REPL[126]:1
Environment (please complete the following information):
- Output of `using Pkg; Pkg.status()` (run in the SciMLBase `test/downstream` environment):
Status `~/Julia/SciML/SciMLBase.jl/test/downstream/Project.toml`
[764a87c0] BoundaryValueDiffEq v5.10.0
[bcd4f6db] DelayDiffEq v5.48.1
⌅ [459566f4] DiffEqCallbacks v3.9.1
[f6369f11] ForwardDiff v0.10.36
[ccbc3e58] JumpProcesses v9.13.7
[961ee093] ModelingToolkit v9.42.0
[16a59e39] ModelingToolkitStandardLibrary v2.15.0
⌃ [8913a72c] NonlinearSolve v3.14.0
⌅ [7f7a1694] Optimization v3.28.0
⌅ [fd9f6733] OptimizationMOI v0.4.3
⌅ [36348300] OptimizationOptimJL v0.3.2
[1dea7af3] OrdinaryDiffEq v6.89.0
[91a5bcdd] Plots v1.40.8
[731186ca] RecursiveArrayTools v3.27.0
[0bca4576] SciMLBase v2.55.0 `../..`
[1ed8b502] SciMLSensitivity v7.68.0
[53ae85a6] SciMLStructures v1.5.0
[860ef19b] StableRNGs v1.0.2
[9672c7b4] SteadyStateDiffEq v2.4.1
[789caeaf] StochasticDiffEq v6.69.1
[c3572dad] Sundials v4.25.0
[2efcf032] SymbolicIndexingInterface v0.3.31
[d1185830] SymbolicUtils v3.7.1
[1986cc42] Unitful v1.21.0
[e88e6eb3] Zygote v0.6.71