Adding Tensor_layout for Tensor parallelism for Autosharding #21792
base: master
Conversation
Summary of Changes

Hello @buildwithsuhana, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request establishes foundational components for tensor parallelism within the JAX backend, which are crucial for Autosharding. It provides core collective communication primitives such as all_reduce and all_gather.
Code Review
This pull request introduces foundational components for tensor parallelism in Keras, specifically for the JAX backend. It adds all_reduce and all_gather collective operations, which are essential for distributed computations. Additionally, it provides a split_tensor_for_parallelism utility for sharding tensors across devices. The changes are well-tested, covering both even and uneven tensor splitting. My review includes a few suggestions to improve documentation accuracy and code simplicity, and to align with the repository's style guide regarding docstring examples.
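The semantics of these two collectives can be illustrated outside JAX with a small NumPy simulation. This is a sketch of the expected behavior only, not the actual Keras/JAX implementation; the `simulated_*` function names are illustrative:

```python
import numpy as np

def simulated_all_reduce(shards):
    # all_reduce (sum): every device ends up with the elementwise
    # sum of all per-device tensors, e.g. for combining gradients.
    total = np.sum(shards, axis=0)
    return [total.copy() for _ in shards]

def simulated_all_gather(shards, axis=0):
    # all_gather: every device ends up with the concatenation of
    # all per-device slices along the given axis.
    full = np.concatenate(shards, axis=axis)
    return [full.copy() for _ in shards]

# Two "devices", each holding a (2, 2) slice of a (4, 2) tensor.
dev0 = np.ones((2, 2))
dev1 = 2 * np.ones((2, 2))

reduced = simulated_all_reduce([dev0, dev1])   # every device sees all 3s
gathered = simulated_all_gather([dev0, dev1])  # every device sees (4, 2)
```

In the real backend these would map to JAX collectives executed across devices; the key property shown here is that after either operation, all participants hold identical results.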
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
…keras into tensor_parallel
Can you rebase to make the tests pass?
Codecov Report: ❌ Patch coverage is
Additional details and impacted files

```diff
@@            Coverage Diff             @@
##           master   #21792       +/-  ##
===========================================
- Coverage   82.66%   61.29%   -21.37%
===========================================
  Files         577      583        +6
  Lines       59453    60373      +920
  Branches     9320     9522      +202
===========================================
- Hits        49148    37008    -12140
- Misses       7902    21012    +13110
+ Partials     2403     2353       -50
```
```python
children_to_add = []

if hasattr(current_layer, "layers") and current_layer.layers:
    for sub_layer in current_layer.layers:
```
Again, `model.layers` will return the full flat list of all the layers in the model, even if they're deeply nested.

So:

- `processed_layers` is not needed
- `stack` is not needed
- `prefix` is not needed; it will always be `""`, and you will never get a deeply nested path. If you wanted that, you'd need to code this very differently
- `children_to_add` is not needed
- lines 234-260 are not needed
- lines 224-232 are probably not needed
The code could be a lot shorter. This whole function `get_default_config` can be:
```python
def get_default_config(module, device_ids):
    device_count = len(device_ids)
    state_rules = {}
    output_rules = {}
    for layer in module.layers:
        _apply_layer_sharding_rules(
            layer, layer.name, device_count, state_rules, output_rules
        )
    return LayoutMap(state_rules=state_rules, output_rules=output_rules)
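Under the reviewer's assumption that `module.layers` already yields a flat list of all layers, the simplified function can be exercised with lightweight stand-ins. The stub classes and the body of `_apply_layer_sharding_rules` below are hypothetical, for illustration only; they are not the PR's actual implementations:

```python
# Hypothetical stand-ins so the simplified function runs standalone.
class LayoutMap:
    def __init__(self, state_rules, output_rules):
        self.state_rules = state_rules
        self.output_rules = output_rules

class FakeLayer:
    def __init__(self, name):
        self.name = name

class FakeModule:
    def __init__(self, layers):
        self.layers = layers  # assumed flat list of all layers

def _apply_layer_sharding_rules(
    layer, prefix, device_count, state_rules, output_rules
):
    # Illustrative rule only: shard every layer's kernel across devices.
    state_rules[f"{prefix}.kernel"] = f"split_{device_count}"

def get_default_config(module, device_ids):
    device_count = len(device_ids)
    state_rules = {}
    output_rules = {}
    for layer in module.layers:
        _apply_layer_sharding_rules(
            layer, layer.name, device_count, state_rules, output_rules
        )
    return LayoutMap(state_rules=state_rules, output_rules=output_rules)

module = FakeModule([FakeLayer("dense_1"), FakeLayer("dense_2")])
layout = get_default_config(module, device_ids=[0, 1])
```

The design point is that with a flat layer list, one loop over `module.layers` replaces the manual stack/prefix bookkeeping entirely.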
This change introduces core building blocks for tensor parallelism by adding two key components.
First, it adds crucial collective operations, all_reduce and all_gather, to the JAX backend. These allow multiple devices to synchronize data by summing tensors (like gradients) or gathering individual slices back into a full tensor. Second, it adds the high-level tensor sharding logic (split_tensor_for_parallelism), which uses ops.array_split to intelligently slice large tensors, even unevenly, for distribution across devices. New tests confirm this new parallel logic, including the uneven splitting, works as expected.
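The uneven-split behavior mirrors NumPy's `array_split`, which (unlike `split`) tolerates axis sizes that don't divide evenly, giving the leading sections one extra element each. A quick illustration, with NumPy standing in for `ops.array_split`:

```python
import numpy as np

# Split a length-10 axis across 3 "devices": shard sizes 4, 3, 3.
tensor = np.arange(10)
shards = np.array_split(tensor, 3)
sizes = [s.shape[0] for s in shards]

# Concatenating the shards recovers the original tensor, which is
# what all_gather relies on when reassembling the full result.
restored = np.concatenate(shards)
```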
The tests on this PR will pass after PR #21697 is merged.