Description
MWE:
using Flux
ann = Chain(Dense(5, 50, tanh), Dense(50, 4))
p, st = Flux.destructure(ann)
u0 = rand(4)  # 4 states; dudt_ appends one element to match the network's 5 inputs
using Zygote
function dudt_(u, p, t)
    st(p)([u..., 1f1])  # splatting u into the array literal
end
out, back = Zygote.pullback(dudt_, u0, p, 0f0)
d_u, d_p, d_t = back(rand(4))
typeof(d_u) # NTuple
d_u should be a vector, but it comes back as a tuple (NTuple) instead. The problem goes away if the function concatenates instead of splatting:
function dudt_(u, p, t)
    st(p)([u; 1f1])  # concatenation instead of splatting
end
This was isolated from SciML/SciMLSensitivity.jl#1082