Description
Problem
Tape compilation can't be used on functions with run-time control flow: the tape only records the branches taken at recording time, so replaying it on other inputs can silently produce wrong results. This stops such code from taking advantage of tape compilation.
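To make the failure mode concrete, here is a minimal sketch (using the same toy function as the demo further down) of a compiled tape silently returning the gradient of the branch it was recorded on:

```julia
using ReverseDiff

# The branch taken depends on the input value.
f(x) = x[1] > 1 ? x[1]^3 : x[1]^2

# Recording at x = [0.0] bakes the `x[1]^2` branch into the tape.
tape = ReverseDiff.compile(ReverseDiff.GradientTape(f, [0.0]))

ReverseDiff.gradient!(tape, [0.5]) # ≈ [1.0], correct (gradient of x^2 is 2x)
ReverseDiff.gradient!(tape, [1.1]) # ≈ [2.2], still the x^2 branch; the true gradient is 3 * 1.1^2 ≈ 3.63
```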
Possible solution
Enable ReverseDiff's tape caching functionality to be used in the presence of run-time control flow by introducing guarded/sub-tapes that are recompiled automatically when the instructions they contain are invalidated by a user-specified guard statement.
My implementation idea is that these guarded/sub-tapes live directly on normal compiled tapes as another subtype of `AbstractInstruction` (if I'm correct in assuming that they don't fit inside `SpecialInstruction`).
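As a rough sketch of where such an instruction could live, it might look something like the following. This is purely illustrative: `GuardedInstruction` and its fields are names I made up here, not existing ReverseDiff internals.

```julia
import ReverseDiff: AbstractInstruction, CompiledTape

# Hypothetical instruction type that owns its own compiled sub-tape together with
# the guard and cache needed to invalidate/recompile it during the forward pass.
mutable struct GuardedInstruction{G,V} <: AbstractInstruction
    guard_f::G                   # user-specified guard statement
    guard_value::V               # guard value the current sub-tape was compiled against
    subtape::CompiledTape        # currently active compiled sub-tape
    cache::Dict{V,CompiledTape}  # previously compiled sub-tapes, keyed by guard value
end
```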
Here's a quick-and-dirty non-implementation showcasing the idea in action:
```julia
using ReverseDiff
import ReverseDiff: CompiledTape, GradientTape, compile, gradient!

mutable struct GuardedTape{F,T,G,V,C} # mutable because `guard_value` and `tape` get swapped out
    f::F
    tape::CompiledTape{T}
    guard_f::G
    guard_value::V
    cache::C
end

function guarded_tape(func, guard_func, input)
    # Record and compile a tape for the branch taken at `input`,
    # seeding the cache with it under the current guard value.
    tape = GradientTape(func, input)
    ctape = compile(tape)
    guard_value = guard_func(input)
    cache = Dict(guard_value => ctape)
    return GuardedTape(func, ctape, guard_func, guard_value, cache)
end

function gradient!(gt::GuardedTape, input)
    new_guard_value = gt.guard_f(input)
    if new_guard_value != gt.guard_value
        # Guard changed: reuse a previously compiled tape if we have one,
        # otherwise record and compile a new one for this guard value.
        new_ctape = get!(gt.cache, new_guard_value) do
            println("Recompiling")
            tape = GradientTape(gt.f, input)
            compile(tape)
        end
        gt.guard_value = new_guard_value
        gt.tape = new_ctape
    end
    gradient!(gt.tape, input)
end

f(x) = x[1] > 1 ? x[1]^3 : x[1]^2

input = [0.0]
gt = guarded_tape(f, x -> x[1] > 1, input)

gradient!(gt, [0.1]) # No recompilation
gradient!(gt, [1.1]) # Recompilation triggered
gradient!(gt, [0.5]) # No recompilation (cached tape for this guard value is reused)
```
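As a sanity check on the sketch above (hand-reasoned rather than actually run), the guarded tape should agree with ForwardDiff on both sides of the branch, whereas a single compiled tape would keep returning the stale-branch gradient for inputs above 1:

```julia
using ForwardDiff

# Both should hold if recompilation fires as intended:
ForwardDiff.gradient(f, [0.1]) ≈ gradient!(gt, [0.1]) # 2x branch, original tape
ForwardDiff.gradient(f, [1.1]) ≈ gradient!(gt, [1.1]) # 3x^2 branch, only reachable after recompiling
```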
The soul of this is borrowed from JAX's `static_argnums`/`static_argnames` in `jit`, where users can specify one or more arguments that, if changed, trigger the lookup/recompilation step. This is essentially value dispatch, and I'm not sure about its performance implications.
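One implication worth spelling out (my own caveat, not something taken from the JAX docs): every distinct guard value gets its own compiled tape, so guards should map inputs onto a small, finite set of values. Two hypothetical guards illustrate the difference:

```julia
# Good: only two possible values, so at most two tapes are ever compiled.
branch_guard(x) = x[1] > 1

# Bad: an effectively unbounded set of values, so nearly every new input misses
# the cache and pays the full recording + compilation cost.
leaky_guard(x) = round(x[1]; digits = 2)
```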
Impact
The original context for this project is the Turing package. Gradient-based methods like HMC and NUTS are the state of the art for MCMC sampling and, as stated on Turing's GSoC projects page, their performance is greatly improved by the caching features of ReverseDiff. However, tape caching is not universally applicable, and more complicated models that use other packages will often contain unavoidable control flow.
More generally, the ability to efficiently differentiate through control flow would allow ReverseDiff to be recommended more universally in packages that currently rely on ForwardDiff. AD backend selection is a great feature of the SciML ecosystem, and many of its packages, such as Optimization, could benefit from this contribution: it would turn backend selection from a correctness footgun (an admittedly blatant one, but not always trivial to diagnose) into, at worst, a performance footgun.
While next-generation AD backends such as Diffractor and Enzyme are a hot topic in the ecosystem at the moment, ReverseDiff is a package that has stood the test of time for its reliability and performance. For workloads such as those found in Turing, it is almost always faster than Zygote out of the box, especially in compiled mode. Zygote can sometimes be faster, but it requires far more hand-tuning to reach those speeds, and most of that tuning is inaccessible to end users.
ReverseDiff has a clear niche in the AD backend ecosystem: its target users are moderately performance-sensitive with medium- to high-dimensional problems, and it covers these cases very well with little to no hand-tuning. Enzyme has incredible performance, which matters for the most performance-critical applications, but it is neither trivial to use and tune nor applicable in every situation due to compatibility issues. In a similar vein, Zygote is a high-performance solution that works great for applications heavy in linear algebra, but it often requires significant hand-tuning.