Description
Now that direct adjoints are starting to work with Enzyme over OrdinaryDiffEq.jl, it would make sense to add this to the SciMLSensitivity.jl system.
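using Enzyme, OrdinaryDiffEq, StaticArrays

# In-place Lorenz right-hand side
function lorenz!(du, u, p, t)
    du[1] = 10.0(u[2] - u[1])
    du[2] = u[1] * (28.0 - u[3]) - u[2]
    du[3] = u[1] * u[2] - (8 / 3) * u[3]
end

const _saveat = SA[0.0, 0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.25, 2.5, 2.75, 3.0]

# Solve the ODE and write the first state component at each save point into y
function f(y::Array{Float64}, u0::Array{Float64})
    tspan = (0.0, 3.0)
    prob = ODEProblem{true, SciMLBase.FullSpecialize}(lorenz!, u0, tspan)
    # SensitivityADPassThrough() bypasses the SciMLSensitivity adjoint rules,
    # so Enzyme differentiates the solver directly
    sol = DiffEqBase.solve(prob, Tsit5(), saveat = _saveat,
        sensealg = DiffEqBase.SensitivityADPassThrough())
    y .= sol[1, :]
    return nothing
end;

u0 = [1.0; 0.0; 0.0]
d_u0 = zeros(3)  # shadow of u0: receives the gradient
y = zeros(13)
dy = zeros(13)   # shadow of y (adjoint seed); all zeros here, so set entries to get a nonzero d_u0

Enzyme.autodiff(Reverse, f, Duplicated(y, dy), Duplicated(u0, d_u0));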
That's a working demonstration. Now all we need is an EnzymeAdjoint struct which then does exactly that internally: https://github.com/SciML/SciMLSensitivity.jl/blob/master/src/concrete_solve.jl#L1222-L1405.
Better Support for EnzymeAdjoint inside an Enzyme Diff
Now that version is great for a user who defines a loss function with Zygote but then sets sensealg=EnzymeAdjoint(), so that we take care of the hard ODE part. But if the user uses Enzyme for the loss function and differentiates the ODE, we should somehow detect this case and keep it from hitting the SciMLSensitivity path in DiffEqBase at all. Basically, if sensealg=EnzymeAdjoint() and we are in an Enzyme environment, solve should just switch to sensealg = DiffEqBase.SensitivityADPassThrough(). That said, I don't know how to detect "in an Enzyme environment", so I don't know how to pull this off. @wsmoses it would be helpful to know how to do this. If this is done, then I think we get some extra speed bonuses, since no rules are used at all in this case.
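To make the intended switch concrete, here is a minimal sketch of the dispatch logic only. Everything in it is a stand-in: the two structs are hypothetical placeholders rather than the real DiffEqBase/SciMLSensitivity types, `resolve_sensealg` is a made-up helper name, and the `in_enzyme` flag is passed in by hand precisely because how to compute it is the open question above.

```julia
# Hypothetical stand-ins, NOT the real DiffEqBase / SciMLSensitivity types.
struct EnzymeAdjoint end
struct SensitivityADPassThrough end

# Hypothetical helper: pick the sensealg that solve would actually use.
# `in_enzyme` is supplied by the caller; detecting it automatically is
# the unresolved part.
function resolve_sensealg(alg, in_enzyme::Bool)
    if alg isa EnzymeAdjoint && in_enzyme
        # Already inside an Enzyme differentiation: skip the adjoint rules
        # and let Enzyme differentiate the solver directly.
        return SensitivityADPassThrough()
    end
    return alg
end

outer_zygote = resolve_sensealg(EnzymeAdjoint(), false)  # stays EnzymeAdjoint
outer_enzyme = resolve_sensealg(EnzymeAdjoint(), true)   # becomes pass-through
```

Under these assumptions, a Zygote-differentiated loss would still go through the EnzymeAdjoint path, while an Enzyme-differentiated loss would fall through to the pass-through mode.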
Supporting EnzymeAdjoint for SDEs
It's probably the same steps as were required for ODEs, which were:
- Refactor ODEIntegrator to not allow undef fsal states OrdinaryDiffEq.jl#2390
- Split and inactivate increment functions OrdinaryDiffEq.jl#2389
- Add Enzyme support for fastpow DiffEqBase.jl#1072
Since both use the same fastpow, that should already be handled. The SDE integrator type does not use FSAL, https://github.com/SciML/StochasticDiffEq.jl/blob/master/src/integrators/type.jl, so that PR isn't handled. This means only SciML/OrdinaryDiffEq.jl#2390 is the same issue.
But SciML/OrdinaryDiffEq.jl#2390 was a workaround for a bug in Enzyme, which is maybe fixed now? (@wsmoses). So it's worth just giving direct Enzyme a try. To do it, you'd put solve into a mode that forces it to ignore the SciMLSensitivity adjoint rules, which is what the ODE code above is doing. We'd just need an SDE test case like:
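using Enzyme, StochasticDiffEq, StaticArrays

# In-place Lorenz drift
function lorenz!(du, u, p, t)
    du[1] = 10.0(u[2] - u[1])
    du[2] = u[1] * (28.0 - u[3]) - u[2]
    du[3] = u[1] * u[2] - (8 / 3) * u[3]
end

# Diagonal multiplicative noise
function lorenz_noise!(du, u, p, t)
    du .= 0.1u
end

const _saveat = SA[0.0, 0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.25, 2.5, 2.75, 3.0]

# Solve the SDE and write the first state component at each save point into y
function f(y::Array{Float64}, u0::Array{Float64})
    tspan = (0.0, 3.0)
    prob = SDEProblem{true}(lorenz!, lorenz_noise!, u0, tspan)
    # SensitivityADPassThrough() bypasses the SciMLSensitivity adjoint rules,
    # so Enzyme differentiates the solver directly
    sol = DiffEqBase.solve(prob, EM(), saveat = _saveat,
        sensealg = DiffEqBase.SensitivityADPassThrough())
    y .= sol[1, :]
    return nothing
end;

u0 = [1.0; 0.0; 0.0]
d_u0 = zeros(3)  # shadow of u0: receives the gradient
y = zeros(13)
dy = zeros(13)   # shadow of y (adjoint seed); all zeros here, so set entries to get a nonzero d_u0

Enzyme.autodiff(Reverse, f, Duplicated(y, dy), Duplicated(u0, d_u0));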
I haven't run that to see how it does, but it might just work now.
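If it does run, a test would also need a reference value to compare against. One possible approach, sketched here without any SciML dependencies, is central finite differences with the noise path held fixed. The Euler-Maruyama loop below is hand-written as a stand-in for EM(), not StochasticDiffEq's implementation, and the names `em_lorenz` and `fd_gradient`, the step size, and the short time horizon are all assumptions chosen for illustration. Reproducing the same noise path inside the real solver is a separate detail this sketch does not address.

```julia
using Random

# Drift and diagonal noise of the test problem, written out-of-place.
drift(u) = [10.0 * (u[2] - u[1]),
            u[1] * (28.0 - u[3]) - u[2],
            u[1] * u[2] - (8 / 3) * u[3]]
noise(u) = 0.1 .* u

# Hand-written Euler-Maruyama using pre-drawn Brownian increments dW,
# so every perturbed u0 sees exactly the same noise path.
function em_lorenz(u0, dW, dt)
    u = copy(u0)
    for k in axes(dW, 2)
        u = u .+ dt .* drift(u) .+ noise(u) .* dW[:, k]
    end
    return u[1]  # scalar output: first component at the final time
end

# Central finite differences of em_lorenz with respect to u0.
function fd_gradient(u0, dW, dt; h = 1e-6)
    g = similar(u0)
    for i in eachindex(u0)
        e = zeros(length(u0))
        e[i] = h
        g[i] = (em_lorenz(u0 .+ e, dW, dt) - em_lorenz(u0 .- e, dW, dt)) / (2h)
    end
    return g
end

rng = Random.MersenneTwister(1234)
dt = 1e-3
nsteps = 250                            # integrate to t = 0.25
dW = sqrt(dt) .* randn(rng, 3, nsteps)  # fixed Brownian increments
g = fd_gradient([1.0, 0.0, 0.0], dW, dt)
```

The horizon is kept short on purpose: the Lorenz system is chaotic, so finite-difference sensitivities over long time spans become unreliable as a reference.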