Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ jobs:
- SDE3
version:
- '1'
- '1.11'
- 'lts'
steps:
- uses: actions/checkout@v4
Expand Down
4 changes: 3 additions & 1 deletion src/SciMLSensitivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ using OrdinaryDiffEqCore: OrdinaryDiffEqCore, BrownFullBasicInit, DefaultInit,
# AD Backends
using ChainRulesCore: unthunk, @thunk, NoTangent, @not_implemented, Tangent, ZeroTangent,
AbstractThunk, AbstractTangent
using Enzyme: Enzyme
@static if VERSION < v"1.12"
using Enzyme: Enzyme
end
using FiniteDiff: FiniteDiff
using ForwardDiff: ForwardDiff
using Tracker: Tracker, TrackedArray
Expand Down
178 changes: 90 additions & 88 deletions src/concrete_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1285,105 +1285,107 @@ function DiffEqBase._concrete_solve_adjoint(
p)
end

function DiffEqBase._concrete_solve_adjoint(
prob::Union{SciMLBase.AbstractDiscreteProblem,
SciMLBase.AbstractODEProblem,
SciMLBase.AbstractDAEProblem,
SciMLBase.AbstractDDEProblem,
SciMLBase.AbstractSDEProblem,
SciMLBase.AbstractSDDEProblem,
SciMLBase.AbstractRODEProblem
},
alg, sensealg::EnzymeAdjoint,
u0, p, originator::SciMLBase.ADOriginator,
args...; kwargs...)
kwargs_filtered = NamedTuple(filter(x -> x[1] != :sensealg, kwargs))
du0 = Enzyme.make_zero(u0)
dp = Enzyme.make_zero(p)
mode = sensealg.mode

# Force no FunctionWrappers for Enzyme
_prob = remake(prob, f = f = ODEFunction{isinplace(prob), SciMLBase.FullSpecialize}(unwrapped_f(prob.f)) )

diff_func = (u0,
p) -> solve(_prob, alg, args...; u0 = u0, p = p,
sensealg = SensitivityADPassThrough(),
kwargs_filtered...)

splitmode = if mode isa Enzyme.ForwardMode
error("EnzymeAdjoint currently only allows mode=Reverse. File an issue if this is necessary.")
elseif mode === nothing || mode isa Enzyme.ReverseMode
Enzyme.set_runtime_activity(Enzyme.ReverseSplitWithPrimal)
end
@static if VERSION < v"1.12"
function DiffEqBase._concrete_solve_adjoint(
prob::Union{SciMLBase.AbstractDiscreteProblem,
SciMLBase.AbstractODEProblem,
SciMLBase.AbstractDAEProblem,
SciMLBase.AbstractDDEProblem,
SciMLBase.AbstractSDEProblem,
SciMLBase.AbstractSDDEProblem,
SciMLBase.AbstractRODEProblem
},
alg, sensealg::EnzymeAdjoint,
u0, p, originator::SciMLBase.ADOriginator,
args...; kwargs...)
kwargs_filtered = NamedTuple(filter(x -> x[1] != :sensealg, kwargs))
du0 = Enzyme.make_zero(u0)
dp = Enzyme.make_zero(p)
mode = sensealg.mode

forward,
reverse = Enzyme.autodiff_thunk(
splitmode, Enzyme.Const{typeof(diff_func)}, Enzyme.Duplicated,
Enzyme.Duplicated{typeof(u0)}, Enzyme.Duplicated{typeof(p)})
tape, result,
shadow_result = forward(
Enzyme.Const(diff_func), Enzyme.Duplicated(copy(u0), du0), Enzyme.Duplicated(copy(p), dp))

function enzyme_sensitivity_backpass(Δ)
if (Δ isa AbstractArray{<:AbstractArray} || Δ isa AbstractVectorOfArray)
for (x, y) in zip(shadow_result.u, Δ.u)
x .= y
end
else
error("typeof(Δ) = $(typeof(Δ)) is not currently handled in EnzymeAdjoint. Please open an issue with an MWE to add support")
# Force no FunctionWrappers for Enzyme
_prob = remake(prob, f = f = ODEFunction{isinplace(prob), SciMLBase.FullSpecialize}(unwrapped_f(prob.f)) )

diff_func = (u0,
p) -> solve(_prob, alg, args...; u0 = u0, p = p,
sensealg = SensitivityADPassThrough(),
kwargs_filtered...)

splitmode = if mode isa Enzyme.ForwardMode
error("EnzymeAdjoint currently only allows mode=Reverse. File an issue if this is necessary.")
elseif mode === nothing || mode isa Enzyme.ReverseMode
Enzyme.set_runtime_activity(Enzyme.ReverseSplitWithPrimal)
end
reverse(Enzyme.Const(diff_func), Enzyme.Duplicated(u0, du0), Enzyme.Duplicated(p, dp), tape)
if originator isa SciMLBase.TrackerOriginator ||
originator isa SciMLBase.ReverseDiffOriginator
(NoTangent(), NoTangent(), du0, dp, NoTangent(),
ntuple(_ -> NoTangent(), length(args))...)
else
(NoTangent(), NoTangent(), NoTangent(), du0, dp, NoTangent(),
ntuple(_ -> NoTangent(), length(args))...)

forward,
reverse = Enzyme.autodiff_thunk(
splitmode, Enzyme.Const{typeof(diff_func)}, Enzyme.Duplicated,
Enzyme.Duplicated{typeof(u0)}, Enzyme.Duplicated{typeof(p)})
tape, result,
shadow_result = forward(
Enzyme.Const(diff_func), Enzyme.Duplicated(copy(u0), du0), Enzyme.Duplicated(copy(p), dp))

function enzyme_sensitivity_backpass(Δ)
if (Δ isa AbstractArray{<:AbstractArray} || Δ isa AbstractVectorOfArray)
for (x, y) in zip(shadow_result.u, Δ.u)
x .= y
end
else
error("typeof(Δ) = $(typeof(Δ)) is not currently handled in EnzymeAdjoint. Please open an issue with an MWE to add support")
end
reverse(Enzyme.Const(diff_func), Enzyme.Duplicated(u0, du0), Enzyme.Duplicated(p, dp), tape)
if originator isa SciMLBase.TrackerOriginator ||
originator isa SciMLBase.ReverseDiffOriginator
(NoTangent(), NoTangent(), du0, dp, NoTangent(),
ntuple(_ -> NoTangent(), length(args))...)
else
(NoTangent(), NoTangent(), NoTangent(), du0, dp, NoTangent(),
ntuple(_ -> NoTangent(), length(args))...)
end
end
result, enzyme_sensitivity_backpass
end
result, enzyme_sensitivity_backpass
end

# NOTE: This is needed to prevent a method ambiguity error
function DiffEqBase._concrete_solve_adjoint(
prob::AbstractNonlinearProblem, alg, sensealg::EnzymeAdjoint,
u0, p, originator::SciMLBase.ADOriginator,
args...; kwargs...)
kwargs_filtered = NamedTuple(filter(x -> x[1] != :sensealg, kwargs))

du0 = make_zero(u0)
dp = make_zero(p)
mode = sensealg.mode
# NOTE: This is needed to prevent a method ambiguity error
function DiffEqBase._concrete_solve_adjoint(
prob::AbstractNonlinearProblem, alg, sensealg::EnzymeAdjoint,
u0, p, originator::SciMLBase.ADOriginator,
args...; kwargs...)
kwargs_filtered = NamedTuple(filter(x -> x[1] != :sensealg, kwargs))

f = (u0,
p) -> solve(prob, alg, args...; u0 = u0, p = p,
sensealg = SensitivityADPassThrough(),
kwargs_filtered...)
du0 = make_zero(u0)
dp = make_zero(p)
mode = sensealg.mode

splitmode = if mode isa Forward
error("EnzymeAdjoint currently only allows mode=Reverse. File an issue if this is necessary.")
elseif mode === nothing || mode === Reverse
ReverseSplitWithPrimal
end
f = (u0,
p) -> solve(prob, alg, args...; u0 = u0, p = p,
sensealg = SensitivityADPassThrough(),
kwargs_filtered...)

forward,
reverse = autodiff_thunk(splitmode, Const{typeof(f)}, Duplicated,
Duplicated{typeof(u0)}, Duplicated{typeof(p)})
tape, result, shadow_result = forward(Const(f), Duplicated(u0, du0), Duplicated(p, dp))
splitmode = if mode isa Forward
error("EnzymeAdjoint currently only allows mode=Reverse. File an issue if this is necessary.")
elseif mode === nothing || mode === Reverse
ReverseSplitWithPrimal
end

function enzyme_sensitivity_backpass(Δ)
reverse(Const(f), Duplicated(u0, du0), Duplicated(p, dp), Δ, tape)
if originator isa SciMLBase.TrackerOriginator ||
originator isa SciMLBase.ReverseDiffOriginator
(NoTangent(), NoTangent(), du0, dp, NoTangent(),
ntuple(_ -> NoTangent(), length(args))...)
else
(NoTangent(), NoTangent(), NoTangent(), du0, dp, NoTangent(),
ntuple(_ -> NoTangent(), length(args))...)
forward,
reverse = autodiff_thunk(splitmode, Const{typeof(f)}, Duplicated,
Duplicated{typeof(u0)}, Duplicated{typeof(p)})
tape, result, shadow_result = forward(Const(f), Duplicated(u0, du0), Duplicated(p, dp))

function enzyme_sensitivity_backpass(Δ)
reverse(Const(f), Duplicated(u0, du0), Duplicated(p, dp), Δ, tape)
if originator isa SciMLBase.TrackerOriginator ||
originator isa SciMLBase.ReverseDiffOriginator
(NoTangent(), NoTangent(), du0, dp, NoTangent(),
ntuple(_ -> NoTangent(), length(args))...)
else
(NoTangent(), NoTangent(), NoTangent(), du0, dp, NoTangent(),
ntuple(_ -> NoTangent(), length(args))...)
end
end
sol, enzyme_sensitivity_backpass
end
sol, enzyme_sensitivity_backpass
end

const ENZYME_TRACKED_REAL_ERROR_MESSAGE = """
Expand Down
19 changes: 15 additions & 4 deletions src/sensitivity_algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -801,10 +801,21 @@ EnzymeAdjoint(mode = nothing)

Currently fails on almost every solver.
"""
struct EnzymeAdjoint{M <: Union{Nothing, Enzyme.EnzymeCore.Mode}} <:
AbstractAdjointSensitivityAlgorithm{nothing, true, nothing}
mode::M
EnzymeAdjoint(mode = nothing) = new{typeof(mode)}(mode)
@static if VERSION < v"1.12"
struct EnzymeAdjoint{M <: Union{Nothing, Enzyme.EnzymeCore.Mode}} <:
AbstractAdjointSensitivityAlgorithm{nothing, true, nothing}
mode::M
EnzymeAdjoint(mode = nothing) = new{typeof(mode)}(mode)
end
else
# Dummy type for Julia 1.12+ - Enzyme is not loaded on this version
struct EnzymeAdjoint{M <: Nothing} <:
AbstractAdjointSensitivityAlgorithm{nothing, true, nothing}
mode::M
function EnzymeAdjoint(mode = nothing)
error("EnzymeAdjoint is not supported on Julia 1.12+. Please use a different sensitivity algorithm.")
end
end
end

"""
Expand Down
96 changes: 52 additions & 44 deletions test/adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,16 +139,18 @@ easy_res11 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = InterpolatingAdjoint(autojacvec = SciMLSensitivity.ReverseDiffVJP(true)))
_,
easy_res12 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = InterpolatingAdjoint(autojacvec = SciMLSensitivity.EnzymeVJP()))
_,
easy_res13 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = QuadratureAdjoint(autojacvec = SciMLSensitivity.EnzymeVJP()))
@static if VERSION < v"1.12"
_,
easy_res12 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = InterpolatingAdjoint(autojacvec = SciMLSensitivity.EnzymeVJP()))
_,
easy_res13 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = QuadratureAdjoint(autojacvec = SciMLSensitivity.EnzymeVJP()))
end
_,
easy_res14 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
Expand Down Expand Up @@ -179,11 +181,13 @@ easy_res143 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = GaussAdjoint(autojacvec = ReverseDiffVJP(true)))
_,
easy_res144 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = GaussAdjoint(autojacvec = SciMLSensitivity.EnzymeVJP()))
@static if VERSION < v"1.12"
_,
easy_res144 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = GaussAdjoint(autojacvec = SciMLSensitivity.EnzymeVJP()))
end
_,
easy_res145 = adjoint_sensitivities(sol_nodense, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
Expand Down Expand Up @@ -212,11 +216,13 @@ easy_res143k = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = GaussKronrodAdjoint(autojacvec = ReverseDiffVJP(true)))
_,
easy_res144k = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = GaussKronrodAdjoint(autojacvec = SciMLSensitivity.EnzymeVJP()))
@static if VERSION < v"1.12"
_,
easy_res144k = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = GaussKronrodAdjoint(autojacvec = SciMLSensitivity.EnzymeVJP()))
end
_,
easy_res145k = adjoint_sensitivities(sol_nodense, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
Expand Down Expand Up @@ -1049,34 +1055,36 @@ function dynamics!(du, u, p, t)
du[2] = -u[2] + tanh(p[3] * u[1] + p[4] * u[2])
end

function backsolve_grad(sol, lqr_params, checkpointing)
bwd_sol = solve(
ODEAdjointProblem(sol,
BacksolveAdjoint(autojacvec = EnzymeVJP(),
checkpointing = checkpointing),
@static if VERSION < v"1.12"
function backsolve_grad(sol, lqr_params, checkpointing)
bwd_sol = solve(
ODEAdjointProblem(sol,
BacksolveAdjoint(autojacvec = EnzymeVJP(),
checkpointing = checkpointing),
Tsit5(),
nothing, nothing, nothing, nothing, nothing,
(x, lqr_params, t) -> cost(x, lqr_params)),
Tsit5(),
nothing, nothing, nothing, nothing, nothing,
(x, lqr_params, t) -> cost(x, lqr_params)),
Tsit5(),
dense = false,
save_everystep = false)

bwd_sol.u[end][1:(end - x_dim)]
#fwd_sol, bwd_sol
end

x0 = ones(x_dim)
fwd_sol = solve(ODEProblem(dynamics!, x0, (0, T), params),
Tsit5(), abstol = 1e-9, reltol = 1e-9,
u0 = x0,
p = params,
dense = false,
save_everystep = false)
save_everystep = true)

bwd_sol.u[end][1:(end - x_dim)]
#fwd_sol, bwd_sol
end
backsolve_results = backsolve_grad(fwd_sol, params, false)
backsolve_checkpointing_results = backsolve_grad(fwd_sol, params, true)

x0 = ones(x_dim)
fwd_sol = solve(ODEProblem(dynamics!, x0, (0, T), params),
Tsit5(), abstol = 1e-9, reltol = 1e-9,
u0 = x0,
p = params,
dense = false,
save_everystep = true)

backsolve_results = backsolve_grad(fwd_sol, params, false)
backsolve_checkpointing_results = backsolve_grad(fwd_sol, params, true)

@test backsolve_results != backsolve_checkpointing_results
@test backsolve_results != backsolve_checkpointing_results
end

int_u0,
int_p = adjoint_sensitivities(fwd_sol, Tsit5(),
Expand Down
6 changes: 3 additions & 3 deletions test/autodiff_events.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using SciMLSensitivity
using OrdinaryDiffEq, Calculus, Test
using OrdinaryDiffEq, OrdinaryDiffEqCore, Calculus, Test
using Zygote

function f(du, u, p, t)
Expand Down Expand Up @@ -56,11 +56,11 @@ g4 = Zygote.gradient(θ -> test_f2(θ, ReverseDiffAdjoint(), PIController(7 // 5
p)
g6 = Zygote.gradient(
θ -> test_f2(θ, ForwardDiffSensitivity(),
OrdinaryDiffEq.PredictiveController(), TRBDF2()),
OrdinaryDiffEqCore.PredictiveController(), TRBDF2()),
p)
@test_broken g7 = Zygote.gradient(
θ -> test_f2(θ, ReverseDiffAdjoint(),
OrdinaryDiffEq.PredictiveController(),
OrdinaryDiffEqCore.PredictiveController(),
TRBDF2()),
p)

Expand Down
Loading
Loading