Skip to content

[Proof of Concept] Make ForwardDiff work with implicit time integration #2371

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions examples/tree_1d_dgsem/elixir_advection_diffusion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,39 @@ boundary_conditions = boundary_condition_periodic
boundary_conditions_parabolic = boundary_condition_periodic

# A semidiscretization collects data structures and functions for the spatial discretization
semi = SemidiscretizationHyperbolicParabolic(mesh, (equations, equations_parabolic),
semi_ = SemidiscretizationHyperbolicParabolic(mesh, (equations, equations_parabolic),
initial_condition,
solver;
boundary_conditions = (boundary_conditions,
boundary_conditions_parabolic))

import DiffEqBase, ForwardDiff
DiffEqBase.anyeltypedual(p::SemidiscretizationHyperbolic) = Any
function DiffEqBase.anyeltypedual(p::SemidiscretizationHyperbolic,
::Type{Val{counter}}) where {counter}
Any
end

T = typeof(ForwardDiff.Tag(DiffEqBase.OrdinaryDiffEqTag(), Float64))
dual_type = ForwardDiff.Dual{T, Float64, 1}
semi_dual = Trixi.remake(semi_, uEltype = dual_type)

dual_type11 = ForwardDiff.Dual{T, Float64, 11}
semi_dual11 = Trixi.remake(semi_, uEltype = dual_type11)
# semi = Trixi.remake(semi_, cache = (; semi_.cache..., semi_dual))
new_cache = (; semi_.cache..., semi_dual, semi_dual11)
semi = SemidiscretizationHyperbolic{typeof(semi_.mesh), typeof(semi_.equations),
typeof(semi_.initial_condition),
typeof(semi_.boundary_conditions),
typeof(semi_.source_terms),
typeof(semi_.solver), typeof(new_cache)}(semi_.mesh,
semi_.equations,
semi_.initial_condition,
semi_.boundary_conditions,
semi_.source_terms,
semi_.solver,
new_cache)

###############################################################################
# ODE solvers, callbacks etc.

Expand Down Expand Up @@ -87,6 +114,6 @@ callbacks = CallbackSet(summary_callback, analysis_callback, alive_callback, sav
# OrdinaryDiffEq's `solve` method evolves the solution in time and executes the passed callbacks
time_int_tol = 1.0e-10
time_abs_tol = 1.0e-10
sol = solve(ode, KenCarp4(autodiff = AutoFiniteDiff());
sol = solve(ode, KenCarp4(autodiff = AutoForwardDiff());
abstol = time_abs_tol, reltol = time_int_tol,
ode_default_options()..., callback = callbacks)
7 changes: 7 additions & 0 deletions src/semidiscretization/semidiscretization_hyperbolic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,13 @@ end
function rhs!(du_ode, u_ode, semi::SemidiscretizationHyperbolic, t)
@unpack mesh, equations, boundary_conditions, source_terms, solver, cache = semi

if du_ode isa AbstractArray{<:ForwardDiff.Dual{<:Any, <:Any, 1}} && haskey(cache, :semi_dual)
return rhs!(du_ode, u_ode, semi.cache.semi_dual, t)
end
if du_ode isa AbstractArray{<:ForwardDiff.Dual{<:Any, <:Any, 11}} && haskey(cache, :semi_dual11)
return rhs!(du_ode, u_ode, semi.cache.semi_dual11, t)
end

u = wrap_array(u_ode, mesh, equations, solver, cache)
du = wrap_array(du_ode, mesh, equations, solver, cache)

Expand Down
Loading