Closed
Description
As I wrote in the linked issue, differentiating ModelingToolkit (MTK) models recently stopped working for me.
A minimal working example (MWE):
using DifferentialEquations, ModelingToolkit
"""
    lotka_volterra(; name)

Build the Lotka–Volterra `ODESystem` with states `x(t)` and `y(t)` (both with default
`1.0`) and parameters `p1`, `p2`, `p3`, `p4` (defaults `1.5`, `1.0`, `3.0`, `1.0`).

`name` is a required keyword; it is normally supplied by `@named`.
"""
function lotka_volterra(; name)
    # The independent variable and its differential operator must be defined before the
    # states are declared; the snippet previously used `t` and `D` without defining them.
    @variables t
    D = Differential(t)
    # Local names chosen so they do not shadow MTK's exported `states`/`parameters`.
    sts = @variables x(t)=1.0 y(t)=1.0
    ps = @parameters p1=1.5 p2=1.0 p3=3.0 p4=1.0
    eqs = [
        D(x) ~ p1 * x - p2 * x * y,
        D(y) ~ -p3 * y + p4 * x * y,
    ]
    return ODESystem(eqs, t, sts, ps; name = name)
end
# Instantiate the named system and build an ODE problem on the interval (0.0, 10.0).
# The initial-condition and parameter maps are empty, so values presumably come from the
# defaults declared in `lotka_volterra`.
@named lotka_volterra_sys = lotka_volterra()
prob = ODEProblem(lotka_volterra_sys, [], (0.0, 10.0), [])
# Plain forward solve of `prob` (no sensitivity algorithm involved).
sol = solve(prob, Tsit5(); reltol = 1e-6, abstol = 1e-6)
using Zygote, SciMLSensitivity
"""
    sum_of_solution(u0, p)

Remake the global `prob` with initial state `u0` and parameters `p`, solve it with
`Tsit5` (`reltol = abstol = 1e-6`, `saveat = 0.1`), and return the sum of the saved
solution values. `BacksolveAdjoint` with a `ZygoteVJP` vector-Jacobian product is passed
as the sensitivity algorithm used when this function is differentiated.
"""
function sum_of_solution(u0, p)
    newprob = remake(prob; u0 = u0, p = p)
    sensealg = BacksolveAdjoint(autojacvec = ZygoteVJP())
    solution = solve(newprob, Tsit5(); reltol = 1e-6, abstol = 1e-6, saveat = 0.1,
                     sensealg = sensealg)
    return sum(solution)
end
# Space-separated literals build 1×N row matrices (Matrix{Float64}), not vectors; this is
# why the stack trace below shows Matrix{Float64} arguments to `sum_of_solution`.
u0 = [1.0 1.0]
p = [1.5 1.0 1.0 1.0]
du01, dp1 = Zygote.gradient(sum_of_solution, u0, p)
Running this gives the following error:
ERROR: Compiling Tuple{Type{Dict}, Base.Iterators.Zip{Tuple{Vector{Sym{Real, Base.ImmutableDict{DataType, Any}}}, Vector{Float64}}}}: try/catch is not supported.
Refer to the Zygote documentation for fixes.
https://fluxml.ai/Zygote.jl/latest/limitations
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] instrument(ir::IRTools.Inner.IR)
@ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/reverse.jl:121
[3] #Primal#23
@ ~/.julia/packages/Zygote/dABKa/src/compiler/reverse.jl:205 [inlined]
[4] Zygote.Adjoint(ir::IRTools.Inner.IR; varargs::Nothing, normalise::Bool)
@ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/reverse.jl:330
[5] _generate_pullback_via_decomposition(T::Type)
@ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/emit.jl:101
[6] #s2924#1068
@ ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:28 [inlined]
[7] var"#s2924#1068"(::Any, ctx::Any, f::Any, args::Any)
@ Zygote ./none:0
[8] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
@ Core ./boot.jl:582
[9] _pullback
@ ~/.julia/packages/SciMLBase/m11uN/src/utils.jl:477 [inlined]
[10] _pullback(::Zygote.Context{false}, ::typeof(SciMLBase.mergedefaults), ::Dict{Any, Any}, ::Vector{Float64}, ::Vector{Sym{Real, Base.ImmutableDict{DataType, Any}}})
@ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
[11] _pullback
@ ~/.julia/packages/SciMLBase/m11uN/src/remake.jl:57 [inlined]
[12] _pullback(::Zygote.Context{false}, ::SciMLBase.var"##remake#527", ::Missing, ::Matrix{Float64}, ::Missing, ::Matrix{Float64}, ::Missing, ::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::typeof(remake), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, SciMLBase.AutoSpecialize, ModelingToolkit.var"#f#465"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x7fe24678, 0xefbf7ae3, 0x14077d65, 0xd38b0358, 0xca1226cf)}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x2b3c6fd1, 0xac2f72a0, 0xdcafd855, 0x30fb2acf, 0x61128de0)}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, Vector{Symbol}, ModelingToolkit.var"#488#generated_observed#472"{Bool, ODESystem, Dict{Any, Any}}, Nothing, ODESystem}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem})
@ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
[13] _pullback
@ ~/.julia/packages/SciMLBase/m11uN/src/remake.jl:45 [inlined]
[14] _pullback(::Zygote.Context{false}, ::SciMLBase.var"#remake##kw", ::NamedTuple{(:u0, :p), Tuple{Matrix{Float64}, Matrix{Float64}}}, ::typeof(remake), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, SciMLBase.AutoSpecialize, ModelingToolkit.var"#f#465"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x7fe24678, 0xefbf7ae3, 0x14077d65, 0xd38b0358, 0xca1226cf)}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x2b3c6fd1, 0xac2f72a0, 0xdcafd855, 0x30fb2acf, 0x61128de0)}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, Vector{Symbol}, ModelingToolkit.var"#488#generated_observed#472"{Bool, ODESystem, Dict{Any, Any}}, Nothing, ODESystem}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem})
@ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
[15] _pullback
@ ./REPL[15]:2 [inlined]
[16] _pullback(::Zygote.Context{false}, ::typeof(sum_of_solution), ::Matrix{Float64}, ::Matrix{Float64})
@ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
[17] pullback(::Function, ::Zygote.Context{false}, ::Matrix{Float64}, ::Vararg{Matrix{Float64}})
@ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:44
[18] pullback(::Function, ::Matrix{Float64}, ::Matrix{Float64})
@ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:42
[19] gradient(::Function, ::Matrix{Float64}, ::Vararg{Matrix{Float64}})
@ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:96
[20] top-level scope
@ REPL[18]:1
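As a comment on the linked issue mentions, this is likely caused by 6249468. The failing frame is `SciMLBase.mergedefaults` (called from `remake`), which builds a `Dict` from a `zip` of symbolic variables and values. Zygote cannot differentiate that `Dict` constructor because it contains a `try/catch`.
A possible fix is given in FluxML/Zygote.jl#1293 (comment).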
Metadata
Metadata
Assignees
Labels
No labels