
error with ensemble solves when using Zygote #1160

Open
@SebastianM-C


Describe the bug 🐞

I'm trying to take the gradient of a function that contains an ensemble solve. The gradient computation fails because, somewhere in the pullback, `+` ends up being called on two `EnsembleSolution`s (frames [6] and [7] of the stacktrace below).

Expected behavior

`Zygote.gradient` should return the gradient of the loss with respect to the parameters `x` instead of throwing an error.

Minimal Reproducible Example 👇

Without an MRE, we can only help to a limited extent, and the issue is likely to receive less attention. To learn more about MREs, refer to Wikipedia and Stack Overflow.

using OrdinaryDiffEqTsit5
using SciMLSensitivity
using Zygote

function mae2(sol, data)
    l = zero(eltype(data))
    for i in axes(data, 2)
        for j in axes(data, 1)
            l += abs2(sol.u[i][j] - data[j, i])
        end
    end

    l / length(data)
end

function ensemble_setup(x)
    function prob_func(prob, i, repeat)
        remake(prob, u0=rand(2))
    end

    function f(du, u, p, t)
        du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
        du[2] = -3 * u[2] + u[1] * u[2]
    end

    prob = ODEProblem(f, [0.5, 0.5], (0.0, 1.0), x)

    prob, prob_func
end

function ensemble_loss(x, data)
    prob, prob_func = ensemble_setup(x)

    ensembleprob = EnsembleProblem(prob; prob_func, safetycopy=false)

    sim = solve(ensembleprob, Tsit5(), EnsembleSerial();
        trajectories=3, saveat=[0., 0.4, 0.9],
        save_end=true)

    loss = zero(eltype(data))
    for i in Base.OneTo(3)
        sol = sim.u[i]
        loss += mae2(sol, data)
    end

    loss
end
_data = [1.1 2 4
    0 5. 6]
ensemble_loss(rand(4), _data)

Zygote.gradient(x -> ensemble_loss(x, _data), rand(4))

Error & Stacktrace ⚠️

ERROR: MethodError: no method matching size(::Nothing)
The function `size` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  size(::IdentityOperator)
   @ SciMLOperators ~/.julia/packages/SciMLOperators/KVzmP/src/basic.jl:21
  size(::NullOperator)
   @ SciMLOperators ~/.julia/packages/SciMLOperators/KVzmP/src/basic.jl:115
  size(::LLVM.FunctionParameterSet)
   @ LLVM ~/.julia/packages/LLVM/b3kFs/src/core/function.jl:200
  ...

Stacktrace:
  [1] size
    @ ~/.julia/packages/RecursiveArrayTools/Y3i0V/src/vector_of_array.jl:481 [inlined]
  [2] axes(VA::EnsembleSolution{Any, 1, Vector{Union{Nothing, RecursiveArrayTools.VectorOfArray{Float64, 2, Vector{}}}}})
    @ RecursiveArrayTools ~/.julia/packages/RecursiveArrayTools/Y3i0V/src/vector_of_array.jl:485
  [3] combine_axes
    @ ./broadcast.jl:497 [inlined]
  [4] instantiate
    @ ./broadcast.jl:307 [inlined]
  [5] materialize
    @ ./broadcast.jl:872 [inlined]
  [6] +(A::EnsembleSolution{Any, 1, Vector{Union{…}}}, B::EnsembleSolution{Any, 1, Vector{Union{…}}})
    @ RecursiveArrayTools ~/.julia/packages/RecursiveArrayTools/Y3i0V/src/vector_of_array.jl:661
  [7] accum(x::EnsembleSolution{Any, 1, Vector{Union{…}}}, y::EnsembleSolution{Any, 1, Vector{Union{…}}})
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/lib/lib.jl:17
  [8] ensemble_loss
    @ ~/dev/ensemble_zygote.jl:41 [inlined]
  [9] (::Zygote.Pullback{Tuple{typeof(ensemble_loss), Vector{Float64}, Matrix{Float64}}, Any})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0
 [10] #22
    @ ~/dev/ensemble_zygote.jl:52 [inlined]
 [11] (::Zygote.Pullback{Tuple{var"#22#23", Vector{…}}, Tuple{Zygote.Pullback{…}, Zygote.var"#1986#back#198"{…}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0
 [12] (::Zygote.var"#78#79"{Zygote.Pullback{Tuple{}, Tuple{}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface.jl:91
 [13] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface.jl:148
 [14] top-level scope
    @ ~/dev/ensemble_zygote.jl:52
Some type information was truncated. Use `show(err)` to see complete types.
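Reading the stacktrace, the failure appears to come from frames [6] and [7]: Zygote's `accum` adds up the cotangents of `sim` collected from the separate loop iterations in `ensemble_loss`, and those cotangents are `EnsembleSolution`s whose `u` field has element type `Union{Nothing, VectorOfArray}`, so trajectories that a given iteration did not read are `nothing`. RecursiveArrayTools' `+` then broadcasts over the solutions and ends up calling `size` on one of those `nothing` entries. The dependency-free sketch below mimics that situation with plain Julia vectors standing in for the solution types; `Slot`, `cot1`, `cot2` and `accum_slot` are hypothetical names. It also shows a `nothing`-aware accumulation that treats missing entries as zero cotangents. This is only an illustration of the suspected cause, not a patch for SciMLSensitivity or RecursiveArrayTools.

```julia
# Hypothetical, dependency-free illustration of the suspected failure mode.
# Each loop iteration reads one trajectory `sim.u[i]`, so its cotangent for
# `sim.u` holds a value in slot `i` and `nothing` in every other slot.
const Slot = Union{Nothing, Vector{Float64}}

cot1 = Slot[[1.0, 2.0], nothing, nothing]   # stands in for the cotangent from reading sim.u[1]
cot2 = Slot[nothing, [3.0, 4.0], nothing]   # stands in for the cotangent from reading sim.u[2]

# Naive accumulation: broadcasting `+` over the slots pairs a vector with
# `nothing` and throws a MethodError, analogous to `size(::Nothing)` above.
naive_failed = try
    cot1 .+ cot2
    false
catch err
    err isa MethodError
end

# Nothing-aware accumulation: a `nothing` slot acts as a zero cotangent.
accum_slot(::Nothing, ::Nothing) = nothing
accum_slot(::Nothing, y) = y
accum_slot(x, ::Nothing) = x
accum_slot(x, y) = x .+ y

accumulated = map(accum_slot, cot1, cot2)
println(accumulated)
```

The `(::Nothing, ::Nothing)` method is defined explicitly so that calls with two `nothing` slots are not ambiguous between the two one-sided methods.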

Environment (please complete the following information):

  • Output of using Pkg; Pkg.status()
Status `~/dev/Project.toml`
  [1ed8b502] SciMLSensitivity v7.72.0
⌅ [e88e6eb3] Zygote v0.6.75
  • Output of using Pkg; Pkg.status(; mode = PKGMODE_MANIFEST)
  • Output of versioninfo()
Julia Version 1.11.3
Commit d63adeda50d (2025-01-21 19:42 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 32 × Intel(R) Core(TM) i9-14900K
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, alderlake)
Threads: 32 default, 0 interactive, 16 GC (on 32 virtual cores)
Environment:
  JULIA_EDITOR = code
