Enhancement proposal: Modular tape caching #234

Open
@jacobusmmsmit

Problem

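Tape compilation can't be used safely with run-time control flow: a compiled tape replays only the branch that was taken when it was recorded. This stops code that branches on its input values from taking advantage of tape compilation.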
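A minimal sketch of the failure mode, using only the ReverseDiff calls that already appear in this issue (`GradientTape`, `compile`, `gradient!`). The function `f` is the same toy example used in the proposal; the variable names `ctape`, `stale` and `fresh` are purely illustrative. The tape is recorded on the `x[1]^2` branch, so replaying it with an input that should take the other branch silently gives the wrong derivative:

```julia
using ReverseDiff
import ReverseDiff: GradientTape, compile, gradient!

f(x) = x[1] > 1 ? x[1]^3 : x[1]^2

# Recorded at x = 0.5, so only the x[1]^2 branch ends up on the tape.
ctape = compile(GradientTape(f, [0.5]))

# Replaying at x = 2.0 reuses the recorded x[1]^2 instructions,
# so this is the derivative 2x rather than 3x^2.
stale = gradient!(ctape, [2.0])

# Re-recording at x = 2.0 picks up the x[1]^3 branch.
fresh = gradient!(compile(GradientTape(f, [2.0])), [2.0])

println("stale = ", stale, ", fresh = ", fresh)
```

No error or warning is raised in the stale case, which is what makes this a correctness problem rather than a performance one.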

Possible solution

Enable ReverseDiff's tape caching functionality to be used in cases with run-time control flow by introducing guarded/sub-tapes which are recompiled automatically if the instructions they contain are invalidated by a user-specified guard statement.

My implementation idea is that these guarded/sub-tapes live directly on normal compiled tapes as another type of AbstractInstruction (if I'm correct in assuming that it doesn't fit inside SpecialInstruction).
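To make the "sub-tape as an instruction" idea concrete, here is a toy model in plain Julia. It is only a sketch under strong assumptions: `ToyInstruction`, `ToyOp`, `ToyGuarded`, `run_instruction` and `run_tape` are all hypothetical names, they do not mirror ReverseDiff's internal instruction types, and only a forward replay is modelled (no reverse pass).

```julia
# Toy model only: hypothetical types, not ReverseDiff internals.
abstract type ToyInstruction end

# An ordinary recorded operation: applies `op` to the running value.
struct ToyOp{F} <: ToyInstruction
    op::F
end

# A guarded sub-tape: one sub-tape per guard value, recorded lazily.
struct ToyGuarded{G,R} <: ToyInstruction
    guard::G                                     # value -> guard key
    record::R                                    # guard key -> sub-tape
    subtapes::Dict{Any,Vector{ToyInstruction}}   # cache of recorded sub-tapes
end
ToyGuarded(guard, record) =
    ToyGuarded(guard, record, Dict{Any,Vector{ToyInstruction}}())

run_instruction(instr::ToyOp, x) = instr.op(x)

function run_instruction(instr::ToyGuarded, x)
    key = instr.guard(x)
    # Record a sub-tape for this guard value only on a cache miss.
    subtape = get!(() -> instr.record(key), instr.subtapes, key)
    return run_tape(subtape, x)
end

# Replay every instruction in order, threading the value through.
run_tape(tape, x) = foldl((acc, instr) -> run_instruction(instr, acc), tape; init = x)

guarded = ToyGuarded(
    x -> x > 1,
    key -> key ? ToyInstruction[ToyOp(x -> x^3)] : ToyInstruction[ToyOp(x -> x^2)],
)
tape = ToyInstruction[ToyOp(x -> 2x), guarded]

println(run_tape(tape, 0.25), " ", run_tape(tape, 1.0))
```

The outer tape stays fixed; only the guarded instruction looks at a run-time value and swaps in the matching sub-tape. A real version would also need the reverse pass to replay the sub-tape chosen during the forward pass.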

Here's a quick-and-dirty non-implementation showcasing the idea in action:

using ReverseDiff
import ReverseDiff: CompiledTape, GradientTape, compile, gradient!

mutable struct GuardedTape{F,T,G,V,C} # mutable because `tape` and `guard_value` are reassigned
    f::F
    tape::CompiledTape{T}
    guard_f::G
    guard_value::V
    cache::C
end

# Record and compile a tape for `input`, cached under the guard value seen at recording time.
function guarded_tape(func, guard_func, input)
    tape = GradientTape(func, input)
    ctape = compile(tape)
    guard_value = guard_func(input)
    cache = Dict(guard_value => ctape)
    return GuardedTape(func, ctape, guard_func, guard_value, cache)
end

function gradient!(gt::GuardedTape, input)
    new_guard_value = gt.guard_f(input)
    if new_guard_value != gt.guard_value
        # Guard changed: reuse the tape cached for this guard value, or record and compile a new one.
        new_ctape = get!(gt.cache, new_guard_value) do
            println("Recompiling")
            tape = GradientTape(gt.f, input)
            compile(tape)
        end
        gt.guard_value = new_guard_value
        gt.tape = new_ctape
    end
    return gradient!(gt.tape, input)
end

f(x) = x[1] > 1 ? x[1]^3 : x[1]^2
input = [0.0]
gt = guarded_tape(f, x -> x[1] > 1, input)

gradient!(gt, [0.1]) # Guard unchanged: no recompilation
gradient!(gt, [1.1]) # Guard changed, no cached tape: recompilation triggered
gradient!(gt, [0.5]) # Guard changed back: cached tape reused, no recompilation

The spirit of this is borrowed from JAX's static_argnums/static_argnames in jit, where users can mark one or more arguments as static; calling the function with a new value for a static argument triggers the lookup/recompilation step. This is essentially value dispatch. I'm not sure about its performance implications.
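The value-dispatch idea can be sketched independently of any AD library. In the sketch below, `StaticCache`, `static_cache` and `lookup!` are hypothetical helpers (not part of ReverseDiff or JAX), and plain closures stand in for compiled tapes. The counter makes the number of "recompilations" visible:

```julia
# Hypothetical helper (not ReverseDiff or JAX API): one artifact per static value.
struct StaticCache{F,K,V}
    build::F                      # static value -> specialised artifact
    cache::Dict{K,V}
    nbuilds::Base.RefValue{Int}   # counts cache misses ("recompilations")
end

static_cache(build, ::Type{K}, ::Type{V}) where {K,V} =
    StaticCache(build, Dict{K,V}(), Ref(0))

function lookup!(sc::StaticCache, key)
    return get!(sc.cache, key) do
        sc.nbuilds[] += 1
        sc.build(key)
    end
end

# Stand-in for "compile a tape for this branch": returns the derivative of
# x^3 when the static value is `true`, and of x^2 otherwise.
sc = static_cache(Bool, Function) do branch
    branch ? (x -> 3x^2) : (x -> 2x)
end

derivative(x) = lookup!(sc, x > 1)(x)

results = [derivative(0.5), derivative(2.0), derivative(0.1)]
println(results, " after ", sc.nbuilds[], " builds")
```

In this sketch every call pays for one guard evaluation and one hash lookup, and the cache holds one entry per distinct guard value. That suggests guards should map inputs to a small, discrete set of values (such as a `Bool` or a tuple of branch decisions), since a guard that returns a continuous value would miss the cache on almost every call.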

Impact

The original context for this project is the Turing package. Gradient-based methods like HMC and NUTS are the state of the art for MCMC sampling and, as stated on Turing's GSoC projects page, their performance is greatly improved by the caching features of ReverseDiff. However, tape caching is not universally applicable: more complicated models that use other packages will normally contain unavoidable control flow.

More generally, the ability to efficiently differentiate through control flow will allow ReverseDiff to be more universally recommended in packages that rely on ForwardDiff. AD backend selection is a great feature in the SciML ecosystem, and many of its packages, such as Optimization, could benefit from this contribution: choosing the wrong AD backend would become a potential performance footgun rather than an (admittedly blatant but not trivial) correctness one.

While next-generation AD backends such as Diffractor and Enzyme are a hot topic in the ecosystem at the moment, ReverseDiff is a package that has stood the test of time for its reliability and performance. For workloads such as those found in Turing, it is almost always faster than Zygote "out of the box", especially in compiled mode. Zygote may sometimes be faster, but it requires far more hand-tuning to reach the necessary speeds, most of which is inaccessible to end users.

ReverseDiff has a clear niche in the AD backend ecosystem: its target users are moderately performance-sensitive, with medium-to-high-dimensional problems, and it covers these very well with little to no hand-tuning. While Enzyme has incredible performance, which matters for the most performance-critical applications, it is neither trivial to use and tune, nor can it be applied in every situation due to some compatibility issues. In a similar vein, Zygote is a high-performance solution that works great for applications heavy in linear algebra, but it often requires significant hand-tuning.
