
Error when using view of Const to calculate view of Duplicated #1956

@hexaeder


I am trying to make the RHS of an ODEProblem Enzyme-compatible. My function has the signature (du, u, p, t), and I want to differentiate du with respect to u while holding p and t constant. I hit the error

ERROR: Constant memory is stored (or returned) to a differentiable variable.
As a result, Enzyme cannot provably ensure correctness and throws this error.

for some operations that use p to compute du. I am quite new to Enzyme and don't fully understand this error, but in very simple examples, using Const(p) to compute Duplicated(du) is not a problem.

I boiled it down to two MWEs. The first is closer to my actual code, including the loop unrolling. The second seems to fail because of the broadcast and does not need loop unrolling to trigger the error. I am not sure whether both demonstrate the same problem or two different ones.

Both examples were created with Julia 1.10.5 and Enzyme 0.13.8. I am aware of set_runtime_activity, which works for forward mode in my actual code but segfaults in reverse mode.
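For reference, a minimal sketch of the forward-mode runtime-activity call mentioned above, assuming the Enzyme 0.13 `set_runtime_activity(mode)` API. The struct `CopyP` is a hypothetical stand-in that mirrors the pattern of MWE 2 below (a view of a Const array broadcast into a view of a Duplicated array); it is not taken from the actual code.

```julia
using Pkg
pkg"activate --temp"
pkg"add Enzyme"
using Enzyme: Enzyme

# Hypothetical stand-in for the RHS: copies a view of p into a view of du
struct CopyP{RT}
    range::RT
end
function (f::CopyP)(du, u, p, t)
    _du = view(du, f.range)
    _p  = view(p, f.range)
    _du .= _p
    nothing
end

f = CopyP(1:4)
p = collect(1.0:4.0)
dxD = Enzyme.Duplicated(zeros(4), zeros(4))
xD  = Enzyme.Duplicated(zeros(4), [1.0, 0.0, 0.0, 0.0])

# Runtime activity defers the Const-vs-Duplicated check to runtime
# instead of rejecting the store at compile time
mode = Enzyme.set_runtime_activity(Enzyme.Forward)
Enzyme.autodiff(mode, f, dxD, xD, Enzyme.Const(p), Enzyme.Const(NaN))
```

This only illustrates the workaround; it does not address the reverse-mode segfault or explain why the compile-time check fails.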

MWE 1

using Pkg
pkg"activate --temp"
pkg"add Enzyme"
using Enzyme: Enzyme

@inline function unrolled_foreach(f::F, t::Tuple) where {F}
    f(first(t))
    @inline unrolled_foreach(f, Base.tail(t))
end
@inline unrolled_foreach(f::F, t::Tuple{}) where {F} = nothing

struct Functor{T}
    batches::T
end
function (f::Functor)(du, u, p, t)
    unrolled_foreach(f.batches) do batch
        for i in 1:2
            start = 1 + (i-1) * 2
            stop = start + 1
            range = start:stop

            _du  = view(du, range)
            _p   = view(p, range)
            _du[1] = _p[1]
        end
    end
    nothing
end

batches = (1,)
f = Functor(batches)

# test normal call
dx, x, p, t = zeros(4), zeros(4), collect(1.0:4.0), NaN
f(dx, x, p, t)
dx

# f_and_df = Enzyme.Duplicated(f, Enzyme.make_zero(f))
dxD = Enzyme.Duplicated(zeros(4), zeros(4))
xD = Enzyme.Duplicated(x, [1.0, 0.0, 0.0, 0.0])
pC = Enzyme.Const(p)
tC = Enzyme.Const(NaN)
Enzyme.autodiff(Enzyme.Forward, f, dxD, xD, pC, tC)

MWE 2

using Pkg
pkg"activate --temp"
pkg"add Enzyme"
using Enzyme: Enzyme

struct Functor{RT}
    range::RT
end
function (f::Functor)(du, u, p, t)
    r = f.range
    # r = 1:4 # this literal would work
    _du  = view(du, r)
    _p   = view(p, r)
    _du .= _p
    nothing
end

f = Functor(1:4)

# test normal function call
dx, x, p, t = zeros(4), zeros(4), collect(1.0:4.0), NaN
f(dx, x, p, t)
@assert dx == 1:4

dxD = Enzyme.Duplicated(zeros(4), zeros(4))
xD = Enzyme.Duplicated(x, [1.0, 0.0, 0.0, 0.0])
pC = Enzyme.Const(p)
tC = Enzyme.Const(NaN)
Enzyme.autodiff(Enzyme.Forward, f, dxD, xD, pC, tC)
