
Gradient w.r.t. parameters not working with MTKParameters #1130

Open
@AayushSabharwal

Description

Describe the bug 🐞

Taking the gradient with respect to a plain vector of parameter values fails when those values are substituted into the problem's `MTKParameters` object with an out-of-place setter (`setp_oop`). The failure is an adjoint sensitivity error.
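The pattern being differentiated has two sides. The AD input is a plain `Vector{Float64}`, but the solver receives a structured parameter object. The sketch below mimics that data flow without any SciML packages. `ToyParams` and `with_tunables` are hypothetical names invented for this illustration. They are not the `MTKParameters` or `setp_oop` implementations.

```julia
# Hypothetical stand-in for a structured parameter object such as
# MTKParameters: tunable values sit in a flat vector, other data elsewhere.
struct ToyParams{T}
    tunable::Vector{T}
    other::Tuple
end

# Hypothetical out-of-place setter: builds a new ToyParams with the tunable
# values replaced and leaves the input untouched. Avoiding mutation matters
# because Zygote does not support mutating arrays. The element type follows
# `vals`, so AD number types (e.g. dual numbers) can flow through.
function with_tunables(p::ToyParams, vals::AbstractVector{T}) where {T}
    length(vals) == length(p.tunable) || throw(DimensionMismatch(
        "expected $(length(p.tunable)) values, got $(length(vals))"))
    return ToyParams{T}(collect(vals), p.other)
end

# Model defaults, then the replacement values the issue differentiates at
base = ToyParams([1.5, 1.0, 3.0, 1.0], ())
newp = with_tunables(base, [1.5, 1.0, 1.0, 1.0])
```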

Expected behavior

`Zygote.gradient` returns the gradients of the loss with respect to both `u0` and the parameter vector `p`.
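A dependency-free reference value can help check a fix. The sketch below is a stand-in and not SciML code. It integrates the same Lotka-Volterra equations as the reproducer with a hand-written fixed-step RK4 (dt = 0.01, saving every 0.1 to mirror `saveat = 0.1`). It then differentiates the same objective as `symbolic_indexing`, the sum of `x` over the saved points, using central finite differences. The helper names `lv`, `rk4_saveat`, `loss` and `fd_gradient` are hypothetical. The result should agree with a working adjoint gradient to a few significant digits, but that is an expectation, not a measured result.

```julia
# Lotka-Volterra right-hand side with parameters ordered [p1, p2, p3, p4],
# matching the order passed to setp_oop in the reproducer.
lv(u, p) = [p[1] * u[1] - p[2] * u[1] * u[2],
            -p[3] * u[2] + p[4] * u[1] * u[2]]

# Classic fixed-step RK4; keeps the initial state plus the state after
# every `save_every` steps (dt = 0.01, save_every = 10 gives saveat = 0.1).
function rk4_saveat(u0, p; tspan = (0.0, 10.0), dt = 0.01, save_every = 10)
    nsteps = round(Int, (tspan[2] - tspan[1]) / dt)
    u = copy(u0)
    saved = [copy(u)]
    for n in 1:nsteps
        k1 = lv(u, p)
        k2 = lv(u .+ (dt / 2) .* k1, p)
        k3 = lv(u .+ (dt / 2) .* k2, p)
        k4 = lv(u .+ dt .* k3, p)
        u = u .+ (dt / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
        n % save_every == 0 && push!(saved, copy(u))
    end
    return saved
end

# Same objective as symbolic_indexing: sum of x over the saved time points.
loss(u0, p) = sum(s[1] for s in rk4_saveat(u0, p))

# Central finite differences with respect to each entry of v.
function fd_gradient(f, v; h = 1e-6)
    g = similar(v, Float64)
    for i in eachindex(v)
        vp = copy(v); vp[i] += h
        vm = copy(v); vm[i] -= h
        g[i] = (f(vp) - f(vm)) / (2h)
    end
    return g
end

u0 = [1.0, 1.0]
p = [1.5, 1.0, 1.0, 1.0]
du0_ref = fd_gradient(x -> loss(x, p), u0)
dp_ref = fd_gradient(x -> loss(u0, x), p)
```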

Minimal Reproducible Example 👇

```julia
using ModelingToolkit, OrdinaryDiffEq, Zygote, SciMLSensitivity
using SymbolicIndexingInterface: setp_oop
using ModelingToolkit: t_nounits as t, D_nounits as D

# Top-level symbolic handles, used below to index the solution (soln[x])
@variables x(t) o(t)

function lotka_volterra(; name = name)
    unknowns = @variables x(t)=1.0 y(t)=1.0 o(t)
    params = @parameters p1=1.5 p2=1.0 p3=3.0 p4=1.0
    eqs = [
        D(x) ~ p1 * x - p2 * x * y,
        D(y) ~ -p3 * y + p4 * x * y,
        o ~ x * y
    ]
    return ODESystem(eqs, t, unknowns, params; name = name)
end

@mtkbuild lotka_volterra_sys = lotka_volterra()
prob = ODEProblem(lotka_volterra_sys, [], (0.0, 10.0), [])
u0 = [1.0, 1.0]
p = [1.5, 1.0, 1.0, 1.0]

# Out-of-place setter: returns a new parameter object (MTKParameters)
# with p1..p4 replaced by the entries of a plain vector
oop_setter = setp_oop(prob, [lotka_volterra_sys.p1, lotka_volterra_sys.p2, lotka_volterra_sys.p3, lotka_volterra_sys.p4])

function symbolic_indexing(u0, p)
    _p = oop_setter(prob, p)
    _prob = remake(prob, u0 = u0, p = _p)
    soln = solve(_prob, Tsit5(), reltol = 1e-6, abstol = 1e-6, saveat = 0.1,
        sensealg = BacksolveAdjoint(autojacvec = ZygoteVJP()))
    sum(soln[x])
end

# Throws the error reported below
du01, dp1 = Zygote.gradient(symbolic_indexing, u0, p)
```

Error & Stacktrace ⚠️

Adjoint sensitivity analysis functionality requires being able to solve
a differential equation defined by the parameter struct `p`. Thus while
DifferentialEquations.jl can support any parameter struct type, usage
with adjoint sensitivity analysis requires that `p` could be a valid
type for being the initial condition `u0` of an array. This means that
many simple types, such as `Tuple`s and `NamedTuple`s, will work as
parameters in normal contexts but will fail during adjoint differentiation.
To work around this issue for complicated cases like nested structs, look
into defining `p` using `AbstractArray` libraries such as RecursiveArrayTools.jl
or ComponentArrays.jl so that `p` is an `AbstractArray` with a concrete element type.

Stacktrace:
  [1] _concrete_solve_adjoint(::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, ODEFunction{true, SciMLBase.AutoSpecialize, FunctionWrappersWrappers.FunctionWrappersWrapper{Tuple{FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{Float64}, Vector{Float64}, MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{Float64}, MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}}}, false}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, ModelingToolkit.ObservedFunctionCache{ODESystem}, Nothing, ODESystem, Nothing, Nothing}, @Kwargs{}, SciMLBase.StandardODEProblem}, ::Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, ::BacksolveAdjoint{0, true, Val{:central}, ZygoteVJP}, ::Vector{Float64}, ::MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, ::SciMLBase.ChainRulesOriginator; save_start::Bool, save_end::Bool, 
saveat::Float64, save_idxs::Nothing, kwargs::@Kwargs{reltol::Float64, abstol::Float64})
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/HRhwU/src/concrete_solve.jl:378
  [2] _solve_adjoint(prob::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, ODEFunction{true, SciMLBase.AutoSpecialize, ModelingToolkit.var"#f#835"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x4abac1fc, 0xb17c7b66, 0x95b8fd42, 0xf9281edf, 0xa2a10f56), Nothing}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x25ed63bf, 0xcf4bccc8, 0x9286cc6c, 0x7330a30a, 0xd77e77e0), Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, ModelingToolkit.ObservedFunctionCache{ODESystem}, Nothing, ODESystem, Nothing, Nothing}, @Kwargs{}, SciMLBase.StandardODEProblem}, sensealg::BacksolveAdjoint{0, true, Val{:central}, ZygoteVJP}, u0::Vector{Float64}, p::MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, originator::SciMLBase.ChainRulesOriginator, args::Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}; merge_callbacks::Bool, kwargs::@Kwargs{reltol::Float64, abstol::Float64, saveat::Float64})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/DdIeW/src/solve.jl:1537
  [3] rrule(::typeof(DiffEqBase.solve_up), prob::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, ODEFunction{true, SciMLBase.AutoSpecialize, ModelingToolkit.var"#f#835"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x4abac1fc, 0xb17c7b66, 0x95b8fd42, 0xf9281edf, 0xa2a10f56), Nothing}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x25ed63bf, 0xcf4bccc8, 0x9286cc6c, 0x7330a30a, 0xd77e77e0), Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, ModelingToolkit.ObservedFunctionCache{ODESystem}, Nothing, ODESystem, Nothing, Nothing}, @Kwargs{}, SciMLBase.StandardODEProblem}, sensealg::BacksolveAdjoint{0, true, Val{:central}, ZygoteVJP}, u0::Vector{Float64}, p::MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, args::Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}; kwargs::@Kwargs{reltol::Float64, abstol::Float64, saveat::Float64})
    @ DiffEqBaseChainRulesCoreExt ~/.julia/packages/DiffEqBase/DdIeW/ext/DiffEqBaseChainRulesCoreExt.jl:26
  [4] kwcall(::@NamedTuple{reltol::Float64, abstol::Float64, saveat::Float64}, ::typeof(ChainRulesCore.rrule), ::Zygote.ZygoteRuleConfig{Zygote.Context{false}}, ::Function, ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, ODEFunction{true, SciMLBase.AutoSpecialize, ModelingToolkit.var"#f#835"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x4abac1fc, 0xb17c7b66, 0x95b8fd42, 0xf9281edf, 0xa2a10f56), Nothing}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x25ed63bf, 0xcf4bccc8, 0x9286cc6c, 0x7330a30a, 0xd77e77e0), Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, ModelingToolkit.ObservedFunctionCache{ODESystem}, Nothing, ODESystem, Nothing, Nothing}, @Kwargs{}, SciMLBase.StandardODEProblem}, ::BacksolveAdjoint{0, true, Val{:central}, ZygoteVJP}, ::Vector{Float64}, ::MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, ::Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False})
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/6Pucz/src/rules.jl:144
  [5] chain_rrule_kw
    @ ~/.julia/packages/Zygote/Tt5Gx/src/compiler/chainrules.jl:236 [inlined]
  [6] macro expansion
    @ ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0 [inlined]
  [7] _pullback(::Zygote.Context{false}, ::typeof(Core.kwcall), ::@NamedTuple{reltol::Float64, abstol::Float64, saveat::Float64}, ::typeof(DiffEqBase.solve_up), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, ODEFunction{true, SciMLBase.AutoSpecialize, ModelingToolkit.var"#f#835"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x4abac1fc, 0xb17c7b66, 0x95b8fd42, 0xf9281edf, 0xa2a10f56), Nothing}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x25ed63bf, 0xcf4bccc8, 0x9286cc6c, 0x7330a30a, 0xd77e77e0), Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, ModelingToolkit.ObservedFunctionCache{ODESystem}, Nothing, ODESystem, Nothing, Nothing}, @Kwargs{}, SciMLBase.StandardODEProblem}, ::BacksolveAdjoint{0, true, Val{:central}, ZygoteVJP}, ::Vector{Float64}, ::MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, ::Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:87
  [8] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
  [9] adjoint
    @ ~/.julia/packages/Zygote/Tt5Gx/src/lib/lib.jl:203 [inlined]
 [10] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [11] #solve#51
    @ ~/.julia/packages/DiffEqBase/DdIeW/src/solve.jl:1003 [inlined]
 [12] _pullback(::Zygote.Context{false}, ::DiffEqBase.var"##solve#51", ::BacksolveAdjoint{0, true, Val{:central}, ZygoteVJP}, ::Nothing, ::Nothing, ::Val{true}, ::@Kwargs{reltol::Float64, abstol::Float64, saveat::Float64}, ::typeof(solve), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, ODEFunction{true, SciMLBase.AutoSpecialize, ModelingToolkit.var"#f#835"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x4abac1fc, 0xb17c7b66, 0x95b8fd42, 0xf9281edf, 0xa2a10f56), Nothing}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x25ed63bf, 0xcf4bccc8, 0x9286cc6c, 0x7330a30a, 0xd77e77e0), Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, ModelingToolkit.ObservedFunctionCache{ODESystem}, Nothing, ODESystem, Nothing, Nothing}, @Kwargs{}, SciMLBase.StandardODEProblem}, ::Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [13] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [14] adjoint
    @ ~/.julia/packages/Zygote/Tt5Gx/src/lib/lib.jl:203 [inlined]
 [15] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [16] solve
    @ ~/.julia/packages/DiffEqBase/DdIeW/src/solve.jl:993 [inlined]
 [17] _pullback(::Zygote.Context{false}, ::typeof(Core.kwcall), ::@NamedTuple{reltol::Float64, abstol::Float64, saveat::Float64, sensealg::BacksolveAdjoint{0, true, Val{:central}, ZygoteVJP}}, ::typeof(solve), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, ODEFunction{true, SciMLBase.AutoSpecialize, ModelingToolkit.var"#f#835"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x4abac1fc, 0xb17c7b66, 0x95b8fd42, 0xf9281edf, 0xa2a10f56), Nothing}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x25ed63bf, 0xcf4bccc8, 0x9286cc6c, 0x7330a30a, 0xd77e77e0), Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, ModelingToolkit.ObservedFunctionCache{ODESystem}, Nothing, ODESystem, Nothing, Nothing}, @Kwargs{}, SciMLBase.StandardODEProblem}, ::Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [18] symbolic_indexing
    @ ./REPL[125]:4 [inlined]
 [19] _pullback(::Zygote.Context{false}, ::typeof(symbolic_indexing), ::Vector{Float64}, ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [20] pullback(::Function, ::Zygote.Context{false}, ::Vector{Float64}, ::Vararg{Vector{Float64}})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface.jl:90
 [21] pullback(::Function, ::Vector{Float64}, ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface.jl:88
 [22] gradient(::Function, ::Vector{Float64}, ::Vararg{Any})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface.jl:147
 [23] top-level scope
    @ REPL[126]:1
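The requirement in the error message can be illustrated without any SciML packages. In the sketch below, `usable_as_state` is a hypothetical check written for this illustration. It is not the check SciMLSensitivity performs. `flatten_params` and `repack_params` hand-write the flatten-and-rebuild pattern that the message suggests getting from RecursiveArrayTools.jl or ComponentArrays.jl. In the trace above, frame [1] receives an `MTKParameters` value as `p` rather than a plain vector, and this issue is about that case.

```julia
# Hypothetical check mirroring the stated requirement: `p` must be usable as
# an ODE state, i.e. an AbstractArray with a concrete element type.
usable_as_state(p) = p isa AbstractArray && isconcretetype(eltype(p))

# A nested parameter container of the kind the message warns about.
nested = (rates = (p1 = 1.5, p2 = 1.0), decay = (p3 = 1.0, p4 = 1.0))

# Flatten the leaves into a concretely typed vector...
flatten_params(nt) = [nt.rates.p1, nt.rates.p2, nt.decay.p3, nt.decay.p4]

# ...and rebuild the same nested structure from any vector of numbers.
# The leaf types follow `v`, so AD number types pass through unchanged.
repack_params(v::AbstractVector) =
    (rates = (p1 = v[1], p2 = v[2]), decay = (p3 = v[3], p4 = v[4]))

flat = flatten_params(nested)
```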

Environment (please complete the following information):

  • Output of `using Pkg; Pkg.status()`:

(SciMLBase test/downstream environment)

Status `~/Julia/SciML/SciMLBase.jl/test/downstream/Project.toml`
  [764a87c0] BoundaryValueDiffEq v5.10.0
  [bcd4f6db] DelayDiffEq v5.48.1
⌅ [459566f4] DiffEqCallbacks v3.9.1
  [f6369f11] ForwardDiff v0.10.36
  [ccbc3e58] JumpProcesses v9.13.7
  [961ee093] ModelingToolkit v9.42.0
  [16a59e39] ModelingToolkitStandardLibrary v2.15.0
⌃ [8913a72c] NonlinearSolve v3.14.0
⌅ [7f7a1694] Optimization v3.28.0
⌅ [fd9f6733] OptimizationMOI v0.4.3
⌅ [36348300] OptimizationOptimJL v0.3.2
  [1dea7af3] OrdinaryDiffEq v6.89.0
  [91a5bcdd] Plots v1.40.8
  [731186ca] RecursiveArrayTools v3.27.0
  [0bca4576] SciMLBase v2.55.0 `../..`
  [1ed8b502] SciMLSensitivity v7.68.0
  [53ae85a6] SciMLStructures v1.5.0
  [860ef19b] StableRNGs v1.0.2
  [9672c7b4] SteadyStateDiffEq v2.4.1
  [789caeaf] StochasticDiffEq v6.69.1
  [c3572dad] Sundials v4.25.0
  [2efcf032] SymbolicIndexingInterface v0.3.31
  [d1185830] SymbolicUtils v3.7.1
  [1986cc42] Unitful v1.21.0
  [e88e6eb3] Zygote v0.6.71
