Description
I am trying to make the right-hand side (RHS) of an ODEProblem compatible with Enzyme. My function has the signature (du, u, p, t), and I am trying to differentiate du with respect to u while holding p and t constant. I hit the error
ERROR: Constant memory is stored (or returned) to a differentiable variable.
As a result, Enzyme cannot provably ensure correctness and throws this error.
for some operations that use p to compute du. I am quite new to Enzyme and don't fully understand this error, but in very simple examples, using Const(p) to compute a Duplicated(du) is not a problem.
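Here is a minimal sketch of the kind of simple case that does work. It assumes the Enzyme 0.13 API used in the MWEs and shows a plain indexed loop in which a Const p contributes to a Duplicated du. The function name `scale!` is hypothetical. With `Duplicated(u, v)`, forward mode writes the directional derivative of du along v into the shadow of du. Here du = p .* u, so that tangent is p .* v.

```julia
using Enzyme: Enzyme

# Hypothetical in-place RHS: du[i] = p[i] * u[i], with p held constant.
function scale!(du, u, p, t)
    for i in eachindex(du)
        du[i] = p[i] * u[i]
    end
    return nothing
end

p = collect(1.0:4.0)
du = Enzyme.Duplicated(zeros(4), zeros(4))           # output buffer and its tangent
u = Enzyme.Duplicated(ones(4), [1.0, 0.0, 0.0, 0.0])  # input and direction v

# Forward mode: p and t are Const, so only u carries a tangent.
Enzyme.autodiff(Enzyme.Forward, scale!, du, u, Enzyme.Const(p), Enzyme.Const(NaN))

du.val   # [1.0, 2.0, 3.0, 4.0]  (p .* ones(4))
du.dval  # [1.0, 0.0, 0.0, 0.0]  (p .* v)
```

No "Constant memory is stored" error appears here because only Float64 values read from p are stored into du.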
I boiled it down to two MWEs. The first MWE is closer to my actual code, including the loop unrolling. The second MWE seems to fail because of the broadcasting and does not need the loop unrolling to trigger the error. I am not sure whether both demonstrate the same problem or different ones.
Both examples were created with Julia 1.10.5 and Enzyme 0.13.8. I am aware of set_runtime_activity, which works for forward mode in my actual example but segfaults in reverse mode.
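For reference, this is a sketch of the call shape for the runtime-activity workaround mentioned above, assuming the Enzyme 0.13 API, where runtime activity is enabled by wrapping the mode. It uses a hypothetical simple function `scale_rt!` rather than the MWEs, so it shows only where the switch goes. It does not show whether runtime activity fixes the MWEs.

```julia
using Enzyme: Enzyme

# Hypothetical RHS, used only to show where runtime activity is switched on.
function scale_rt!(du, u, p, t)
    for i in eachindex(du)
        du[i] = p[i] * u[i]
    end
    return nothing
end

du = Enzyme.Duplicated(zeros(4), zeros(4))
u = Enzyme.Duplicated(ones(4), [0.0, 1.0, 0.0, 0.0])

# set_runtime_activity returns the mode with runtime activity enabled;
# the same wrapper can be applied to Enzyme.Reverse.
mode = Enzyme.set_runtime_activity(Enzyme.Forward)
Enzyme.autodiff(mode, scale_rt!, du, u, Enzyme.Const(collect(1.0:4.0)), Enzyme.Const(NaN))

du.dval  # [0.0, 2.0, 0.0, 0.0]
```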
MWE 1
using Pkg
pkg"activate --temp"
pkg"add Enzyme"
using Enzyme: Enzyme
@inline function unrolled_foreach(f::F, t::Tuple) where {F}
    f(first(t))
    @inline unrolled_foreach(f, Base.tail(t))
end
@inline unrolled_foreach(f::F, t::Tuple{}) where {F} = nothing
struct Functor{T}
    batches::T
end
function (f::Functor)(du, u, p, t)
    unrolled_foreach(f.batches) do batch
        for i in 1:2
            start = 1 + (i-1) * 2
            stop = start + 1
            range = start:stop
            _du = view(du, range)
            _p = view(p, range)
            _du[1] = _p[1]
        end
    end
    nothing
end
batches = (1,)
f = Functor(batches)
# test normal call
dx, x, p, t = zeros(4), zeros(4), collect(1.0:4.0), NaN
f(dx, x, p, t)
dx
# f_and_df = Enzyme.Duplicated(f, Enzyme.make_zero(f))
dxD = Enzyme.Duplicated(zeros(4), zeros(4))
xD = Enzyme.Duplicated(x, [1.0, 0.0, 0.0, 0.0])
pC = Enzyme.Const(p)
tC = Enzyme.Const(NaN)
Enzyme.autodiff(Enzyme.Forward, f, dxD, xD, pC, tC)
MWE 2
using Pkg
pkg"activate --temp"
pkg"add Enzyme"
using Enzyme: Enzyme
struct Functor{RT}
    range::RT
end
function (f::Functor)(du, u, p, t)
    r = f.range
    # r = 1:4 # this literal would work
    _du = view(du, r)
    _p = view(p, r)
    _du .= _p
    nothing
end
f = Functor(1:4)
# test normal function call
dx, x, p, t = zeros(4), zeros(4), collect(1.0:4.0), NaN
f(dx, x, p, t)
@assert dx == 1:4
dxD = Enzyme.Duplicated(zeros(4), zeros(4))
xD = Enzyme.Duplicated(x, [1.0, 0.0, 0.0, 0.0])
pC = Enzyme.Const(p)
tC = Enzyme.Const(NaN)
Enzyme.autodiff(Enzyme.Forward, f, dxD, xD, pC, tC)
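In both MWEs, du is filled from p only and u is never read. The correct tangent of du is therefore zero for any direction, so if the calls succeeded, `dxD.dval` should remain all zeros. A minimal reference check that does not use Enzyme applies a forward finite difference to a copy of MWE 2. `RangeCopy` and `fd_jvp` are hypothetical names used only for this sketch.

```julia
# Copy of the MWE 2 functor under a hypothetical name.
struct RangeCopy{RT}
    range::RT
end

function (f::RangeCopy)(du, u, p, t)
    _du = view(du, f.range)
    _p = view(p, f.range)
    _du .= _p          # du is filled from p only; u is never read
    return nothing
end

# Hypothetical helper: forward finite difference (f!(u + h*v) - f!(u)) / h,
# evaluated into fresh output buffers.
function fd_jvp(f!, u, v, p, t; h = 1e-6)
    du0 = zeros(length(u))
    du1 = zeros(length(u))
    f!(du0, u, p, t)
    f!(du1, u .+ h .* v, p, t)
    return (du1 .- du0) ./ h
end

rhs = RangeCopy(1:4)
jvp = fd_jvp(rhs, zeros(4), [1.0, 0.0, 0.0, 0.0], collect(1.0:4.0), NaN)
# jvp == [0.0, 0.0, 0.0, 0.0]
```

The same reasoning applies to MWE 1, which copies p[1] and p[3] into du[1] and du[3] and leaves du[2] and du[4] untouched.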