Skip to content

Error when differentiating solutions of ModelingToolkit models #292

Closed
@Antomek

Description

@Antomek

As I wrote in this issue, differentiating MTK models recently stopped working for me.

A MWE:

using DifferentialEquations, ModelingToolkit

"""
    lotka_volterra(; name)

Build the Lotka–Volterra predator–prey model as a ModelingToolkit `ODESystem`.

The states `x(t)` and `y(t)` both default to `1.0`. The parameters default to
`p1 = 1.5`, `p2 = 1.0`, `p3 = 3.0` and `p4 = 1.0`.

# Keywords
- `name`: name of the resulting system (supplied automatically by `@named`).
"""
function lotka_volterra(; name)
    # The independent variable and its derivative operator are declared here so
    # the example is self-contained; the snippet otherwise uses `t` and `D`
    # without ever defining them.
    @variables t
    D = Differential(t)

    sts = @variables x(t)=1.0 y(t)=1.0
    ps = @parameters p1=1.5 p2=1.0 p3=3.0 p4=1.0

    eqs = [
        D(x) ~ p1 * x - p2 * x * y,
        D(y) ~ -p3 * y + p4 * x * y,
    ]

    return ODESystem(eqs, t, sts, ps; name = name)
end

# `@named sys = f()` is shorthand for `sys = f(; name = :sys)`; written out here.
lotka_volterra_sys = lotka_volterra(; name = :lotka_volterra_sys)

# Empty initial-condition and parameter maps, so the defaults declared on the
# system's states and parameters are used; integrate over t in [0, 10].
prob = ODEProblem(lotka_volterra_sys, [], (0.0, 10.0), [])
sol = solve(prob, Tsit5(); reltol = 1e-6, abstol = 1e-6)

using Zygote, SciMLSensitivity

"""
    sum_of_solution(u0, p)

Re-solve the global `prob` with initial state `u0` and parameters `p`, saving
every 0.1 time units, and return the `sum` of the solution as a scalar loss.

Sensitivities are computed with `BacksolveAdjoint` using Zygote-based
vector-Jacobian products, so this function can be passed to `Zygote.gradient`.
"""
function sum_of_solution(u0, p)
    newprob = remake(prob; u0 = u0, p = p)
    adjoint_alg = BacksolveAdjoint(; autojacvec = ZygoteVJP())
    solution = solve(newprob, Tsit5();
        reltol = 1e-6, abstol = 1e-6, saveat = 0.1, sensealg = adjoint_alg)
    return sum(solution)
end

# Use commas: `[a b]` (space-separated) builds a 1×n row matrix, whereas `prob`
# stores its state and parameters as `Vector{Float64}`.
u0 = [1.0, 1.0]
# NOTE(review): p3 = 1.0 here differs from the system default p3 = 3.0 —
# confirm this is intended.
p = [1.5, 1.0, 1.0, 1.0]
du01, dp1 = Zygote.gradient(sum_of_solution, u0, p)

gives the error:

ERROR: Compiling Tuple{Type{Dict}, Base.Iterators.Zip{Tuple{Vector{Sym{Real, Base.ImmutableDict{DataType, Any}}}, Vector{Float64}}}}: try/catch is not supported.
Refer to the Zygote documentation for fixes.
https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] instrument(ir::IRTools.Inner.IR)
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/reverse.jl:121
  [3] #Primal#23
    @ ~/.julia/packages/Zygote/dABKa/src/compiler/reverse.jl:205 [inlined]
  [4] Zygote.Adjoint(ir::IRTools.Inner.IR; varargs::Nothing, normalise::Bool)
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/reverse.jl:330
  [5] _generate_pullback_via_decomposition(T::Type)
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/emit.jl:101
  [6] #s2924#1068
    @ ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:28 [inlined]
  [7] var"#s2924#1068"(::Any, ctx::Any, f::Any, args::Any)
    @ Zygote ./none:0
  [8] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core ./boot.jl:582
  [9] _pullback
    @ ~/.julia/packages/SciMLBase/m11uN/src/utils.jl:477 [inlined]
 [10] _pullback(::Zygote.Context{false}, ::typeof(SciMLBase.mergedefaults), ::Dict{Any, Any}, ::Vector{Float64}, ::Vector{Sym{Real, Base.ImmutableDict{DataType, Any}}})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
 [11] _pullback
    @ ~/.julia/packages/SciMLBase/m11uN/src/remake.jl:57 [inlined]
 [12] _pullback(::Zygote.Context{false}, ::SciMLBase.var"##remake#527", ::Missing, ::Matrix{Float64}, ::Missing, ::Matrix{Float64}, ::Missing, ::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::typeof(remake), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, SciMLBase.AutoSpecialize, ModelingToolkit.var"#f#465"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x7fe24678, 0xefbf7ae3, 0x14077d65, 0xd38b0358, 0xca1226cf)}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x2b3c6fd1, 0xac2f72a0, 0xdcafd855, 0x30fb2acf, 0x61128de0)}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, Vector{Symbol}, ModelingToolkit.var"#488#generated_observed#472"{Bool, ODESystem, Dict{Any, Any}}, Nothing, ODESystem}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
 [13] _pullback
    @ ~/.julia/packages/SciMLBase/m11uN/src/remake.jl:45 [inlined]
 [14] _pullback(::Zygote.Context{false}, ::SciMLBase.var"#remake##kw", ::NamedTuple{(:u0, :p), Tuple{Matrix{Float64}, Matrix{Float64}}}, ::typeof(remake), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, SciMLBase.AutoSpecialize, ModelingToolkit.var"#f#465"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x7fe24678, 0xefbf7ae3, 0x14077d65, 0xd38b0358, 0xca1226cf)}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x2b3c6fd1, 0xac2f72a0, 0xdcafd855, 0x30fb2acf, 0x61128de0)}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, Vector{Symbol}, ModelingToolkit.var"#488#generated_observed#472"{Bool, ODESystem, Dict{Any, Any}}, Nothing, ODESystem}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
 [15] _pullback
    @ ./REPL[15]:2 [inlined]
 [16] _pullback(::Zygote.Context{false}, ::typeof(sum_of_solution), ::Matrix{Float64}, ::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
 [17] pullback(::Function, ::Zygote.Context{false}, ::Matrix{Float64}, ::Vararg{Matrix{Float64}})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:44
 [18] pullback(::Function, ::Matrix{Float64}, ::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:42
 [19] gradient(::Function, ::Matrix{Float64}, ::Vararg{Matrix{Float64}})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:96
 [20] top-level scope
    @ REPL[18]:1

As a comment on the linked issue mentions, this is likely caused by commit 6249468,
and a possible fix is given in FluxML/Zygote.jl#1293 (comment).

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions