Skip to content

Backend switching for Mooncake #768

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 12 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ DifferentiationInterfaceFiniteDiffExt = "FiniteDiff"
DifferentiationInterfaceFiniteDifferencesExt = "FiniteDifferences"
DifferentiationInterfaceForwardDiffExt = ["ForwardDiff", "DiffResults"]
DifferentiationInterfaceGTPSAExt = "GTPSA"
DifferentiationInterfaceMooncakeExt = "Mooncake"
DifferentiationInterfaceMooncakeExt = ["ChainRulesCore", "Mooncake"]
DifferentiationInterfacePolyesterForwardDiffExt = ["PolyesterForwardDiff", "ForwardDiff", "DiffResults"]
DifferentiationInterfaceReverseDiffExt = ["ReverseDiff", "DiffResults"]
DifferentiationInterfaceSparseArraysExt = "SparseArrays"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,21 @@ module DifferentiationInterfaceMooncakeExt
using ADTypes: ADTypes, AutoMooncake
import DifferentiationInterface as DI
using Mooncake:
Mooncake,
CoDual,
Config,
prepare_gradient_cache,
prepare_pullback_cache,
tangent_type,
value_and_gradient!!,
value_and_pullback!!,
zero_tangent
zero_tangent,
@is_primitive,
zero_fcodual,
MinimalCtx,
NoRData

using ChainRulesCore: ChainRulesCore, rrule

DI.check_available(::AutoMooncake) = true

Expand All @@ -26,5 +33,6 @@ mycopy(x) = deepcopy(x)

include("onearg.jl")
include("twoarg.jl")
include("differentiate_with.jl")

end
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
@is_primitive MinimalCtx Tuple{CoDual{<:DI.DifferentiateWith},CoDual{<:AbstractArray}}
@is_primitive MinimalCtx Tuple{CoDual{<:DI.DifferentiateWith},CoDual{<:Number}}

function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, args::CoDual...)
primal_func = Mooncake.primal(dw)
primal_args = map(arg -> Mooncake.primal(arg), args)

Check warning on line 6 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl#L4-L6

Added lines #L4 - L6 were not covered by tests

(; f, backend) = primal_func
y = f(primal_args...)

Check warning on line 9 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl#L8-L9

Added lines #L8 - L9 were not covered by tests

prep_same = DI.prepare_pullback_same_point_nokwarg(

Check warning on line 11 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl#L11

Added line #L11 was not covered by tests
Val(true), f, backend, primal_args..., (y,)
)

function pullback!!(dy)
tx = DI.pullback(f, prep_same, backend, primal_args, (dy,))
args_rdata = map((x) -> (x, Mooncake.zero_rdata(x)), only(tx))
return NoRData(), args_rdata...

Check warning on line 18 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl#L15-L18

Added lines #L15 - L18 were not covered by tests
end

return zero_fcodual(y), pullback!!

Check warning on line 21 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl#L21

Added line #L21 was not covered by tests
end