Skip to content

Backend switching for Mooncake #768

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 12 commits into
base: main
Choose a base branch
from

Conversation

AstitvaAggarwal
Copy link

@AstitvaAggarwal AstitvaAggarwal commented Apr 1, 2025

Define Mooncake.rrule!! for DI.DifferentiateWith.

@AstitvaAggarwal AstitvaAggarwal requested a review from gdalle as a code owner April 1, 2025 12:30
Copy link

codecov bot commented Apr 1, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 97.05%. Comparing base (4b85c4f) to head (233c312).

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #768      +/-   ##
==========================================
- Coverage   97.93%   97.05%   -0.89%     
==========================================
  Files         129      127       -2     
  Lines        7458     7466       +8     
==========================================
- Hits         7304     7246      -58     
- Misses        154      220      +66     
Flag Coverage Δ
DI 97.86% <100.00%> (-1.16%) ⬇️
DIT 95.01% <ø> (-0.20%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@gdalle gdalle marked this pull request as draft April 1, 2025 12:43
@AstitvaAggarwal
Copy link
Author

removed the code that piggybacks off the Chainrules wrapper. This is specifically now a Mooncake generic rule which handles backend switching.

Copy link
Member

@gdalle gdalle left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this first draft!
I think there are some changes necessary, and most importantly you need to test it, first locally and then during CI (try not to run CI before having tested your changes locally, the process is very expensive since it tests a dozen different backends for like half an hour each).
For the testing, start with manual tests, and then once your code works you can add AutoMooncake() to this line

@AstitvaAggarwal
Copy link
Author

AstitvaAggarwal commented Apr 9, 2025

sorry i got preoccupied with some other work, hence the incomplete PR. This would be on route now.

@gdalle
Copy link
Member

gdalle commented Apr 10, 2025

Please keep in mind that every commit costs around 6 hours of CI budget. I suggest you make as many modifications as possible locally and add tests first before pushing

Copy link
Member

@gdalle gdalle left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, we're getting closer!
Unfortunately I think my existing tests are not enough to capture everything that can go wrong in a Mooncake rule. Perhaps the Mooncake test utilities should be brought in, or more sophisticated tests should be written.

primal_func = primal(dw)
primal_x = primal(x)
(; f, backend) = primal_func
y = zero_fcodual(f(primal_x))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure about this, to me it seems like we're forgetting to carry around the FData of y.

Copy link
Author

@AstitvaAggarwal AstitvaAggarwal Apr 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

im not sure i understand what you mean by this. what ive tried here is instead of creating the output Codual later, we create it early so that in cases of memory address function return types, the pullback functions can have access to the FData to create the correct adjoints.

@is_primitive MinimalCtx Tuple{DI.DifferentiateWith,<:Union{Number,AbstractArray}}

function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number})
primal_func = primal(dw)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if there are derivatives inside dw?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

im not sure i understand. dw would be a function and if derivatives are used inside i think that would be handled by the substitute backend?

# output is a vector, so we need to use the vector pullback
function pullback!!(dy::NoRData)
tx = DI.pullback(f, backend, primal_x, (fdata(y.dx),))
return NoRData(), only(tx)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're not sure that f has no RData

Copy link
Author

@AstitvaAggarwal AstitvaAggarwal Apr 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image
I think all functions would have no RData

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not all callable objects are functions though.

julia> struct Multiplier
           a::Float64
       end

julia> (m::Multiplier)(x) = m.a * x

julia> m = Multiplier(2)
Multiplier(2.0)

julia> m(3)
6.0

julia> m isa Function
false

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And I'm not even convinced the current Mooncake behavior is correct, see compintell/Mooncake.jl#557

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch, Mooncake treats all callable objects as a function, giving NoRData(). The issue you have opened in mooncake, i think would need to be figured out independent of DI (atleast this PR) as of now (comes in when Mooncake is a substitute backend as well).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants