Skip to content

Backend switching for Mooncake #768

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 12 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ In general, using a forward outer backend over a reverse inner backend will yiel
The wrapper [`DifferentiateWith`](@ref) allows you to switch between backends.
It takes a function `f` and specifies that `f` should be differentiated with the substitute backend of your choice, instead of whatever true backend the surrounding code is trying to use.
In other words, when someone tries to differentiate `dw = DifferentiateWith(f, substitute_backend)` with `true_backend`, then `substitute_backend` steps in and `true_backend` does not dive into the function `f` itself.
At the moment, `DifferentiateWith` only works when `true_backend` is either [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) or a [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible backend.
At the moment, `DifferentiateWith` only works when `true_backend` is either [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl), [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl) or a [Mooncake.jl](https://github.com/compintell/Mooncake.jl)-compatible backend.

## Implementations

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,4 +111,4 @@ There are, however, translation utilities:
### Backend switch

Also note the existence of [`DifferentiationInterface.DifferentiateWith`](@ref), which allows the user to wrap a function that should be differentiated with a specific backend.
Right now it only targets ForwardDiff.jl and ChainRulesCore.jl, but PRs are welcome to define Enzyme.jl and Mooncake.jl rules for this object.
Right now it only targets ForwardDiff.jl, ChainRulesCore.jl and Mooncake.jl but PRs are welcome to define Enzyme.jl rules for this object.
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,21 @@ module DifferentiationInterfaceMooncakeExt
using ADTypes: ADTypes, AutoMooncake
import DifferentiationInterface as DI
using Mooncake:
Mooncake,
CoDual,
Config,
prepare_gradient_cache,
prepare_pullback_cache,
tangent_type,
value_and_gradient!!,
value_and_pullback!!,
zero_tangent
zero_tangent,
@is_primitive,
zero_fcodual,
MinimalCtx,
NoRData,
fdata,
primal

DI.check_available(::AutoMooncake) = true

Expand All @@ -26,5 +33,6 @@ mycopy(x) = deepcopy(x)

include("onearg.jl")
include("twoarg.jl")
include("differentiate_with.jl")

end
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith,<:Union{Number,AbstractArray}}

function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number})
primal_func = primal(dw)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if there are derivatives inside dw?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

im not sure i understand. dw would be a function and if derivatives are used inside i think that would be handled by the substitute backend?

primal_x = primal(x)
(; f, backend) = primal_func
y = zero_fcodual(f(primal_x))

Check warning on line 7 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl#L3-L7

Added lines #L3 - L7 were not covered by tests
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure about this, to me it seems like we're forgetting to carry around the FData of y.

Copy link
Author

@AstitvaAggarwal AstitvaAggarwal Apr 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

im not sure i understand what you mean by this. what ive tried here is instead of creating the output Codual later, we create it early so that in cases of memory address function return types, the pullback functions can have access to the FData to create the correct adjoints.


# output is a vector, so we need to use the vector pullback
function pullback!!(dy::NoRData)
tx = DI.pullback(f, backend, primal_x, (fdata(y.dx),))
return NoRData(), only(tx)

Check warning on line 12 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl#L10-L12

Added lines #L10 - L12 were not covered by tests
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're not sure that f has no RData

Copy link
Author

@AstitvaAggarwal AstitvaAggarwal Apr 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image
I think all functions would have no RData

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not all callable objects are functions though.

julia> struct Multiplier
           a::Float64
       end

julia> (m::Multiplier)(x) = m.a * x

julia> m = Multiplier(2)
Multiplier(2.0)

julia> m(3)
6.0

julia> m isa Function
false

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And I'm not even convinced the current Mooncake behavior is correct, see chalk-lab/Mooncake.jl#557

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch, Mooncake treats all callable objects as a function, giving NoRData(). The issue you have opened in mooncake, i think would need to be figured out independent of DI (atleast this PR) as of now (comes in when Mooncake is a substitute backend as well).

end

# output is a scalar, so we can use the scalar pullback
function pullback!!(dy::Number)
tx = DI.pullback(f, backend, primal_x, (dy,))
return NoRData(), only(tx)

Check warning on line 18 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl#L16-L18

Added lines #L16 - L18 were not covered by tests
end

return y, pullback!!

Check warning on line 21 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl#L21

Added line #L21 was not covered by tests
end

function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:AbstractArray})
primal_func = primal(dw)
primal_x = primal(x)
fdata_arg = fdata(x.dx)
(; f, backend) = primal_func
y = zero_fcodual(f(primal_x))

Check warning on line 29 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl#L24-L29

Added lines #L24 - L29 were not covered by tests
# in case x is mutated in f calls
cp_primal_x = copy(primal_x)

Check warning on line 31 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl#L31

Added line #L31 was not covered by tests

# output is a vector, so we need to use the vector pullback
function pullback!!(dy::NoRData)
tx = DI.pullback(f, backend, cp_primal_x, (fdata(y.dx),))
fdata_arg .+= only(tx)
return NoRData(), dy

Check warning on line 37 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl#L34-L37

Added lines #L34 - L37 were not covered by tests
end

# output is a scalar, so we can use the scalar pullback
function pullback!!(dy::Number)
tx = DI.pullback(f, backend, cp_primal_x, (dy,))
fdata_arg .+= only(tx)
return NoRData(), NoRData()

Check warning on line 44 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl#L41-L44

Added lines #L41 - L44 were not covered by tests
end

return y, pullback!!

Check warning on line 47 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl#L47

Added line #L47 was not covered by tests
end
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
using Pkg
Pkg.add(["FiniteDiff", "ForwardDiff", "Zygote"])
Pkg.add(["FiniteDiff", "ForwardDiff", "Zygote", "Mooncake"])

using DifferentiationInterface, DifferentiationInterfaceTest
import DifferentiationInterfaceTest as DIT
using FiniteDiff: FiniteDiff
using ForwardDiff: ForwardDiff
using Zygote: Zygote
using Mooncake: Mooncake
using Test

LOGGING = get(ENV, "CI", "false") == "false"
Expand All @@ -24,7 +25,7 @@ function differentiatewith_scenarios()
end

test_differentiation(
[AutoForwardDiff(), AutoZygote()],
[AutoForwardDiff(), AutoZygote(), AutoMooncake(; config=nothing)],
differentiatewith_scenarios();
excluded=SECOND_ORDER,
logging=LOGGING,
Expand Down