-
Notifications
You must be signed in to change notification settings - Fork 22
Backend switching for Mooncake #768
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #768 +/- ##
==========================================
- Coverage 97.93% 97.05% -0.89%
==========================================
Files 129 127 -2
Lines 7458 7466 +8
==========================================
- Hits 7304 7246 -58
- Misses 154 220 +66
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
removed the code that piggybacks off the Chainrules wrapper. This is specifically now a Mooncake generic rule which handles backend switching. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for this first draft!
I think there are some changes necessary, and most importantly you need to test it, first locally and then during CI (try not to run CI before having tested your changes locally, the process is very expensive since it tests a dozen different backends for like half an hour each).
For the testing, start with manual tests, and then once your code works you can add AutoMooncake()
to this line
...tionInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl
Outdated
Show resolved
Hide resolved
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Outdated
Show resolved
Hide resolved
...tionInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl
Outdated
Show resolved
Hide resolved
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Outdated
Show resolved
Hide resolved
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Outdated
Show resolved
Hide resolved
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Outdated
Show resolved
Hide resolved
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Outdated
Show resolved
Hide resolved
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Outdated
Show resolved
Hide resolved
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Outdated
Show resolved
Hide resolved
sorry i got preoccupied with some other work, hence the incomplete PR. This would be on route now. |
Please keep in mind that every commit costs around 6 hours of CI budget. I suggest you make as many modifications as possible locally and add tests first before pushing |
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, we're getting closer!
Unfortunately I think my existing tests are not enough to capture everything that can go wrong in a Mooncake rule. Perhaps the Mooncake test utilities should be brought in, or more sophisticated tests should be written.
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Show resolved
Hide resolved
primal_func = primal(dw) | ||
primal_x = primal(x) | ||
(; f, backend) = primal_func | ||
y = zero_fcodual(f(primal_x)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure about this, to me it seems like we're forgetting to carry around the FData
of y
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
im not sure i understand what you mean by this. what ive tried here is instead of creating the output Codual later, we create it early so that in cases of memory address function return types, the pullback functions can have access to the FData to create the correct adjoints.
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith,<:Union{Number,AbstractArray}} | ||
|
||
function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number}) | ||
primal_func = primal(dw) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What if there are derivatives inside dw
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
im not sure i understand. dw would be a function and if derivatives are used inside i think that would be handled by the substitute backend?
# output is a vector, so we need to use the vector pullback | ||
function pullback!!(dy::NoRData) | ||
tx = DI.pullback(f, backend, primal_x, (fdata(y.dx),)) | ||
return NoRData(), only(tx) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We're not sure that f
has no RData
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not all callable objects are functions though.
julia> struct Multiplier
a::Float64
end
julia> (m::Multiplier)(x) = m.a * x
julia> m = Multiplier(2)
Multiplier(2.0)
julia> m(3)
6.0
julia> m isa Function
false
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And I'm not even convinced the current Mooncake behavior is correct, see compintell/Mooncake.jl#557
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice catch, Mooncake treats all callable objects as a function, giving NoRData()
. The issue you have opened in mooncake, i think would need to be figured out independent of DI (atleast this PR) as of now (comes in when Mooncake is a substitute backend as well).
Define
Mooncake.rrule!!
forDI.DifferentiateWith
.