Describe the bug 🐞
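I'm trying to take the gradient of a function that has an ensemble solve inside it. The gradient call errors because, for reasons I don't understand, the backward pass ends up applying + to two EnsembleSolution objects (via Zygote's accum, frames [6] and [7] of the stack trace below), which then fails with a MethodError on size(::Nothing).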
Expected behavior
Zygote.gradient should return the gradient of ensemble_loss with respect to x instead of throwing.
Minimal Reproducible Example 👇
using OrdinaryDiffEqTsit5
using SciMLSensitivity
using Zygote
function mae2(sol, data)
    l = zero(eltype(data))
    for i in axes(data, 2)
        for j in axes(data, 1)
            l += abs2(sol.u[i][j] - data[j, i])
        end
    end
    l / length(data)
end
function ensemble_setup(x)
    function prob_func(prob, i, repeat)
        remake(prob, u0=rand(2))
    end
    function f(du, u, p, t)
        du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
        du[2] = -3 * u[2] + u[1] * u[2]
    end
    prob = ODEProblem(f, [0.5, 0.5], (0.0, 1.0), x)
    prob, prob_func
end
function ensemble_loss(x, data)
    prob, prob_func = ensemble_setup(x)
    ensembleprob = EnsembleProblem(prob; prob_func, safetycopy=false)
    sim = solve(ensembleprob, Tsit5(), EnsembleSerial();
        trajectories=3, saveat=[0., 0.4, 0.9],
        save_end=true)
    loss = zero(eltype(data))
    for i in Base.OneTo(3)
        sol = sim.u[i]
        loss += mae2(sol, data)
    end
    loss
end
_data = [1.1 2 4
         0 5. 6]
ensemble_loss(rand(4), _data)
Zygote.gradient(x -> ensemble_loss(x, _data), rand(4))
Error & Stacktrace
ERROR: MethodError: no method matching size(::Nothing)
The function `size` exists, but no method is defined for this combination of argument types.
Closest candidates are:
size(::IdentityOperator)
@ SciMLOperators ~/.julia/packages/SciMLOperators/KVzmP/src/basic.jl:21
size(::NullOperator)
@ SciMLOperators ~/.julia/packages/SciMLOperators/KVzmP/src/basic.jl:115
size(::LLVM.FunctionParameterSet)
@ LLVM ~/.julia/packages/LLVM/b3kFs/src/core/function.jl:200
...
Stacktrace:
[1] size
@ ~/.julia/packages/RecursiveArrayTools/Y3i0V/src/vector_of_array.jl:481 [inlined]
[2] axes(VA::EnsembleSolution{Any, 1, Vector{Union{Nothing, RecursiveArrayTools.VectorOfArray{Float64, 2, Vector{…}}}}})
@ RecursiveArrayTools ~/.julia/packages/RecursiveArrayTools/Y3i0V/src/vector_of_array.jl:485
[3] combine_axes
@ ./broadcast.jl:497 [inlined]
[4] instantiate
@ ./broadcast.jl:307 [inlined]
[5] materialize
@ ./broadcast.jl:872 [inlined]
[6] +(A::EnsembleSolution{Any, 1, Vector{Union{…}}}, B::EnsembleSolution{Any, 1, Vector{Union{…}}})
@ RecursiveArrayTools ~/.julia/packages/RecursiveArrayTools/Y3i0V/src/vector_of_array.jl:661
[7] accum(x::EnsembleSolution{Any, 1, Vector{Union{…}}}, y::EnsembleSolution{Any, 1, Vector{Union{…}}})
@ Zygote ~/.julia/packages/Zygote/TWpme/src/lib/lib.jl:17
[8] ensemble_loss
@ ~/dev/ensemble_zygote.jl:41 [inlined]
[9] (::Zygote.Pullback{Tuple{typeof(ensemble_loss), Vector{Float64}, Matrix{Float64}}, Any})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0
[10] #22
@ ~/dev/ensemble_zygote.jl:52 [inlined]
[11] (::Zygote.Pullback{Tuple{var"#22#23", Vector{…}}, Tuple{Zygote.Pullback{…}, Zygote.var"#1986#back#198"{…}}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0
[12] (::Zygote.var"#78#79"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface.jl:91
[13] gradient(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface.jl:148
[14] top-level scope
@ ~/dev/ensemble_zygote.jl:52
Some type information was truncated. Use `show(err)` to see complete types.
Environment (please complete the following information):
- Output of using Pkg; Pkg.status()
Status `~/dev/Project.toml`
[1ed8b502] SciMLSensitivity v7.72.0
⌅ [e88e6eb3] Zygote v0.6.75
- Output of using Pkg; Pkg.status(; mode = PKGMODE_MANIFEST)
- Output of versioninfo()
Julia Version 1.11.3
Commit d63adeda50d (2025-01-21 19:42 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 32 × Intel(R) Core(TM) i9-14900K
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, alderlake)
Threads: 32 default, 0 interactive, 16 GC (on 32 virtual cores)
Environment:
JULIA_EDITOR = code