Call comms / compute overlap passes when compile=False #304
base: main
Conversation
autoparallel/api.py
Outdated
with V.set_fake_mode(fake_mode):
    cuda_context = get_cuda_device_context(fx_g)
    with cuda_context:
        _recursive_post_grad_passes(fx_g, is_inference=False)
Some of the post_grad passes are bad for perf unless lowered, e.g. view_to_reshape, which materializes all views.
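For context, a minimal illustration (not from the PR) of the pitfall that comment points at: view() always aliases its input, while reshape() is allowed to materialize a contiguous copy when an aliasing view is not possible.

import torch

x = torch.arange(6)

# On a contiguous tensor, both calls alias the original storage.
v = x.view(2, 3)
r = x.reshape(2, 3)

# On a non-contiguous tensor, view() refuses to alias, while reshape()
# silently materializes a contiguous copy.
t = x.view(2, 3).t()                     # transpose -> non-contiguous
try:
    t.view(6)                            # raises RuntimeError
except RuntimeError:
    pass
flat = t.reshape(6)                      # allocates new storage
print(flat.data_ptr() == t.data_ptr())   # False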
I've changed it to only call into the comms / compute reordering pass, to keep graph changes to a minimum
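As a rough sketch (not the PR's actual code), the narrower version could look something like the following; reorder_comms_for_overlap is a hypothetical stand-in for whichever reordering pass is actually wired in, and get_cuda_device_context is assumed to be the same helper used in the hunk above.

from torch._inductor.virtualized import V

def apply_overlap_passes(fx_g, fake_mode):
    # Sketch only: same structure as the diff above, but running a single
    # (hypothetical) comms / compute reordering pass instead of the full
    # _recursive_post_grad_passes pipeline, so unrelated graph rewrites
    # such as view_to_reshape are left out.
    with V.set_fake_mode(fake_mode):
        cuda_context = get_cuda_device_context(fx_g)  # same helper as in the diff
        with cuda_context:
            reorder_comms_for_overlap(fx_g)  # hypothetical pass name
    return fx_g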
wconstab left a comment
Seems OK to me. I will say that it's not super clear to me what the best formulation is; it's a little arbitrary which compiler passes to put 'inside' vs 'outside'.
From a use-case perspective, it seems nice to always have the distributed passes run, even if codegen isn't important. OTOH, other things like cudagraphs might also be preferred, even without codegen. For debugging, the unmodified original GraphModule might be nice to get out? (Though you can see it in its various states of transformation using tlparse.)
Previously, when we called AutoParallel with compile=False, none of the comms / compute overlap passes were applied to the model. This effectively meant that we needed compile=True to have a performant autoparallelized model.
I've for now decided to call into all the post_grad passes, but it is also possible to call into only the comms / compute overlap passes, to keep graph modifications to a minimum.
I'm now calling into the comms / compute reordering pass even when compile=False.
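For reference, a rough usage sketch of what this means in practice; model, input_fn, mesh, and the method names below are approximations of the AutoParallel API rather than exact signatures, and passing compile to the constructor is an assumption based on this PR.

from autoparallel.api import AutoParallel

# Illustrative only: argument and method names approximate the API.
with AutoParallel(model, input_fn, mesh, compile=False) as autop:
    sharding_placement = autop.optimize_placement()
    parallel_mod = autop.apply_placement(sharding_placement)

# Before this change, compile=False skipped the comms / compute overlap
# passes entirely; with it, the comms / compute reordering pass runs even
# when the model is not compiled.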