diff --git a/DifferentiationInterface/docs/src/explanation/backends.md b/DifferentiationInterface/docs/src/explanation/backends.md index a3c55dd5c..7d67e3981 100644 --- a/DifferentiationInterface/docs/src/explanation/backends.md +++ b/DifferentiationInterface/docs/src/explanation/backends.md @@ -95,7 +95,7 @@ In general, using a forward outer backend over a reverse inner backend will yiel The wrapper [`DifferentiateWith`](@ref) allows you to switch between backends. It takes a function `f` and specifies that `f` should be differentiated with the substitute backend of your choice, instead of whatever true backend the surrounding code is trying to use. In other words, when someone tries to differentiate `dw = DifferentiateWith(f, substitute_backend)` with `true_backend`, then `substitute_backend` steps in and `true_backend` does not dive into the function `f` itself. -At the moment, `DifferentiateWith` only works when `true_backend` is either [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) or a [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible backend. +At the moment, `DifferentiateWith` only works when `true_backend` is either [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl), [Mooncake.jl](https://github.com/compintell/Mooncake.jl), or a [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible backend (e.g., [Zygote.jl](https://github.com/FluxML/Zygote.jl)). ## Implementations diff --git a/DifferentiationInterface/docs/src/faq/differentiability.md b/DifferentiationInterface/docs/src/faq/differentiability.md index 2a89ded21..d3e51dd35 100644 --- a/DifferentiationInterface/docs/src/faq/differentiability.md +++ b/DifferentiationInterface/docs/src/faq/differentiability.md @@ -111,4 +111,4 @@ There are, however, translation utilities: ### Backend switch Also note the existence of [`DifferentiationInterface.DifferentiateWith`](@ref), which allows the user to wrap a function that should be differentiated with a specific backend. -Right now it only targets ForwardDiff.jl and ChainRulesCore.jl, but PRs are welcome to define Enzyme.jl and Mooncake.jl rules for this object. \ No newline at end of file +Right now, it only targets ForwardDiff.jl, Mooncake.jl, ChainRules.jl-compatible backends (e.g., [Zygote.jl](https://github.com/FluxML/Zygote.jl)), but PRs are welcome to define Enzyme.jl rules for this object. \ No newline at end of file diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl index 52e742b05..c61969019 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl @@ -3,6 +3,7 @@ module DifferentiationInterfaceMooncakeExt using ADTypes: ADTypes, AutoMooncake import DifferentiationInterface as DI using Mooncake: + Mooncake, CoDual, Config, prepare_gradient_cache, @@ -10,7 +11,13 @@ using Mooncake: tangent_type, value_and_gradient!!, value_and_pullback!!, - zero_tangent + zero_tangent, + @is_primitive, + zero_fcodual, + MinimalCtx, + NoRData, + fdata, + primal DI.check_available(::AutoMooncake) = true @@ -26,5 +33,6 @@ mycopy(x) = deepcopy(x) include("onearg.jl") include("twoarg.jl") +include("differentiate_with.jl") end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl new file mode 100644 index 000000000..18114437e --- /dev/null +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl @@ -0,0 +1,46 @@ +@is_primitive MinimalCtx Tuple{DI.DifferentiateWith,<:Union{Number,AbstractArray}} + +function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number}) + primal_func = primal(dw) + primal_x = primal(x) + (; f, backend) = primal_func + y = zero_fcodual(f(primal_x)) + + # output is a vector, so we need to use the vector pullback + function pullback_array!!(dy::NoRData) + tx = DI.pullback(f, backend, primal_x, (fdata(y.dx),)) + return NoRData(), only(tx) + end + + # output is a scalar, so we can use the scalar pullback + function pullback_scalar!!(dy::Number) + tx = DI.pullback(f, backend, primal_x, (dy,)) + return NoRData(), only(tx) + end + + return y, typeof(primal(y)) <: Number ? pullback_scalar!! : pullback_array!! +end + +function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:AbstractArray}) + primal_func = primal(dw) + primal_x = primal(x) + fdata_arg = fdata(x.dx) + (; f, backend) = primal_func + y = zero_fcodual(f(primal_x)) + + # output is a vector, so we need to use the vector pullback + function pullback_array!!(dy::NoRData) + tx = DI.pullback(f, backend, primal_x, (fdata(y.dx),)) + fdata_arg .+= only(tx) + return NoRData(), dy + end + + # output is a scalar, so we can use the scalar pullback + function pullback_scalar!!(dy::Number) + tx = DI.pullback(f, backend, primal_x, (dy,)) + fdata_arg .+= only(tx) + return NoRData(), NoRData() + end + + return y, typeof(primal(y)) <: Number ? pullback_scalar!! : pullback_array!! +end diff --git a/DifferentiationInterface/test/Back/DifferentiateWith/test.jl b/DifferentiationInterface/test/Back/DifferentiateWith/test.jl index dbc41f548..d474dd517 100644 --- a/DifferentiationInterface/test/Back/DifferentiateWith/test.jl +++ b/DifferentiationInterface/test/Back/DifferentiateWith/test.jl @@ -1,11 +1,12 @@ using Pkg -Pkg.add(["FiniteDiff", "ForwardDiff", "Zygote"]) +Pkg.add(["FiniteDiff", "ForwardDiff", "Zygote", "Mooncake"]) using DifferentiationInterface, DifferentiationInterfaceTest import DifferentiationInterfaceTest as DIT using FiniteDiff: FiniteDiff using ForwardDiff: ForwardDiff using Zygote: Zygote +using Mooncake: Mooncake using Test LOGGING = get(ENV, "CI", "false") == "false" @@ -24,7 +25,7 @@ function differentiatewith_scenarios() end test_differentiation( - [AutoForwardDiff(), AutoZygote()], + [AutoForwardDiff(), AutoZygote(), AutoMooncake(; config=nothing)], differentiatewith_scenarios(); excluded=SECOND_ORDER, logging=LOGGING,