-
Notifications
You must be signed in to change notification settings - Fork 25
feat: backend switching for Mooncake #768
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 10 commits
1a389a6
08b176a
ba0c9e6
1340d92
2ce1ee2
08de6df
84f27c9
2e95299
13233e5
1e8df98
afdddd4
233c312
7a07127
f3e436d
6a0d937
e543958
2472ecc
c63c956
36da036
d2b5a8c
c389a80
b4fe0f8
ec4b75d
0f0b9fc
3c5f99e
d94f146
c982f46
749fea5
9e5ecfd
1e85f17
ff5c4e2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith,<:Union{Number,AbstractArray}} | ||
|
||
function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number}) | ||
primal_func = primal(dw) | ||
AstitvaAggarwal marked this conversation as resolved.
Show resolved
Hide resolved
|
||
primal_x = primal(x) | ||
(; f, backend) = primal_func | ||
y = zero_fcodual(f(primal_x)) | ||
Check warning on line 7 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
|
||
gdalle marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# output is a vector, so we need to use the vector pullback | ||
AstitvaAggarwal marked this conversation as resolved.
Show resolved
Hide resolved
|
||
function pullback!!(dy::NoRData) | ||
tx = DI.pullback(f, backend, primal_x, (fdata(y.dx),)) | ||
return NoRData(), only(tx) | ||
Check warning on line 12 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
|
||
AstitvaAggarwal marked this conversation as resolved.
Show resolved
Hide resolved
AstitvaAggarwal marked this conversation as resolved.
Show resolved
Hide resolved
|
||
end | ||
|
||
# output is a scalar, so we can use the scalar pullback | ||
function pullback!!(dy::Number) | ||
tx = DI.pullback(f, backend, primal_x, (dy,)) | ||
return NoRData(), only(tx) | ||
Check warning on line 18 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
|
||
end | ||
|
||
return y, pullback!! | ||
Check warning on line 21 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
|
||
end | ||
|
||
function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:AbstractArray}) | ||
primal_func = primal(dw) | ||
primal_x = primal(x) | ||
fdata_arg = fdata(x.dx) | ||
(; f, backend) = primal_func | ||
y = zero_fcodual(f(primal_x)) | ||
Check warning on line 29 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
|
||
# in case x is mutated in f calls | ||
AstitvaAggarwal marked this conversation as resolved.
Show resolved
Hide resolved
|
||
cp_primal_x = copy(primal_x) | ||
Check warning on line 31 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
|
||
|
||
# output is a vector, so we need to use the vector pullback | ||
function pullback!!(dy::NoRData) | ||
tx = DI.pullback(f, backend, cp_primal_x, (fdata(y.dx),)) | ||
fdata_arg .+= only(tx) | ||
return NoRData(), dy | ||
Check warning on line 37 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
|
||
end | ||
|
||
# output is a scalar, so we can use the scalar pullback | ||
function pullback!!(dy::Number) | ||
tx = DI.pullback(f, backend, cp_primal_x, (dy,)) | ||
fdata_arg .+= only(tx) | ||
return NoRData(), NoRData() | ||
Check warning on line 44 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
|
||
end | ||
|
||
return y, pullback!! | ||
Check warning on line 47 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
|
||
end |
Uh oh!
There was an error while loading. Please reload this page.