How to get the gradient (zygote) calculated by DynamicalSystems #241

Open
@chooron

Describe the bug
Hello, I want to use Zygote to compute the gradient of a function that calls `trajectory` from DynamicalSystems.jl, but I get the error below, which seems to be related to array mutation.

Minimal Working Example

using DynamicalSystems
using Zygote
using StaticArrays

function f2(pp)
    function henon_rule(u, p, n) # here `n` is "time", but we don't use it.
        x, y = u # system state
        a, b = p # system parameters
        xn = 1.0 - a * x^2 + y
        yn = b * x
        return SVector(xn, yn)
        # return SA[xn, yn]
        # return SVector(1, 2, 3)
        # return @SVector [1, 2, 3]
        # return SVector{2}([xn,yn])
    end

    u0 = [0.2, 0.3]
    p0 = [1.4, 0.3] # unused: the parameters actually passed to the system are `pp`

    henon = DeterministicIteratedMap(henon_rule, u0, pp)

    total_time = 10_000
    X, t = trajectory(henon, total_time)
    sum(sum(X))
end

gradient(f2, [0.7, 0.6])

output:

ERROR: Mutating arrays is not supported -- called setindex!(Vector{SVector{2, Float64}}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
  https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
  [1] error(s::String)
    @ Base .\error.jl:35
  [2] _throw_mutation_error(f::Function, args::Vector{SVector{2, Float64}})
    @ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\lib\array.jl:70
  [3] (::Zygote.var"#539#540"{Vector{SVector{2, Float64}}})(::Nothing)
    @ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\lib\array.jl:82
  [4] (::Zygote.var"#2623#back#541"{Zygote.var"#539#540"{Vector{SVector{2, Float64}}}})(Δ::Nothing)
    @ Zygote D:\Julia\Julia-1.10.4\packages\packages\ZygoteRules\M4xmc\src\adjoint.jl:72
  [5] #trajectory_discrete#6
    @ D:\Julia\Julia-1.10.4\packages\packages\DynamicalSystemsBase\zFDur\src\core\trajectory.jl:58 [inlined]
  [6] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Tuple{@NamedTuple{…}, Nothing})
    @ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
  [7] trajectory_discrete
    @ D:\Julia\Julia-1.10.4\packages\packages\DynamicalSystemsBase\zFDur\src\core\trajectory.jl:45 [inlined]
  [8] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Tuple{@NamedTuple{…}, Nothing})
    @ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
  [9] #trajectory#5
    @ D:\Julia\Julia-1.10.4\packages\packages\DynamicalSystemsBase\zFDur\src\core\trajectory.jl:39 [inlined]
 [10] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Tuple{@NamedTuple{…}, Nothing})
    @ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
 [11] trajectory
    @ D:\Julia\Julia-1.10.4\packages\packages\DynamicalSystemsBase\zFDur\src\core\trajectory.jl:33 [inlined]
 [12] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{@NamedTuple{…}, Nothing})
    @ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
 [13] trajectory
    @ D:\Julia\Julia-1.10.4\packages\packages\DynamicalSystemsBase\zFDur\src\core\trajectory.jl:33 [inlined]
 [14] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{@NamedTuple{…}, Nothing})
    @ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
 [15] f2
    @ e:\JlCode\LumpedHydro\temp\test_ds.jl:83 [inlined]
 [16] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
 [17] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
    @ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface.jl:91
 [18] gradient(f::Function, args::Vector{Float64})
    @ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface.jl:148
 [19] top-level scope
    @ e:\JlCode\LumpedHydro\temp\test_ds.jl:87
Some type information was truncated. Use `show(err)` to see complete types.
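
The stack trace above points at a `setindex!` on a `Vector{SVector{2, Float64}}` inside `trajectory_discrete`, so the mutation comes from `trajectory` filling its output container rather than from `henon_rule` itself. As a minimal sketch of a non-mutating alternative, assuming it is acceptable to bypass `DeterministicIteratedMap` and `trajectory` entirely, the map can be iterated by hand while accumulating the sum in a scalar. The names `henon_step` and `henon_sum` are hypothetical helpers introduced here for illustration; they are not part of DynamicalSystems.jl.

```julia
using Zygote

# One step of the Hénon map, written without any array mutation.
# `u` is the state (x, y) and `p` holds the parameters (a, b).
henon_step(u, p) = (1.0 - p[1] * u[1]^2 + u[2], p[2] * u[1])

# Sum of all state components over the initial state plus `n` iterations.
# Only scalars and tuples are rebound, nothing is written in place,
# so Zygote never encounters a `setindex!` call.
function henon_sum(p, u0, n)
    u = u0
    total = u[1] + u[2]
    for _ in 1:n
        u = henon_step(u, p)
        total += u[1] + u[2]
    end
    return total
end

# Gradient with respect to the parameters for a short trajectory.
grad = gradient(p -> henon_sum(p, (0.2, 0.3), 100), [0.7, 0.6])[1]
```

This only sidesteps the mutation error; it does not make `trajectory` itself differentiable. For parameter values where the map is chaotic, the gradient of a long trajectory can grow very quickly with `n`, so a much smaller step count than `10_000` is probably a safer starting point.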

Package versions

Please provide the versions you use. To do this, run the code:

using Pkg
Pkg.status([
    "DynamicalSystems",
    "StateSpaceSets", "DynamicalSystemsBase", "RecurrenceAnalysis", "FractalDimensions", "DelayEmbeddings", "ComplexityMeasures", "TimeseriesSurrogates", "PredefinedDynamicalSystems", "Attractors", "ChaosTools"
    ];
    mode = PKGMODE_MANIFEST
)
  [61744808] DynamicalSystems v3.3.17
  [90137ffa] StaticArrays v1.9.7
  [e88e6eb3] Zygote v0.6.70

Labels: external (Related with an external library), question