Skip to content

Conversation

@buildwithsuhana
Copy link
Contributor

@buildwithsuhana buildwithsuhana commented Oct 28, 2025

This change introduces core building blocks for tensor parallelism by adding two key components.

First, it adds crucial collective operations, all_reduce and all_gather, to the JAX backend. These allow multiple devices to synchronize data by summing tensors (like gradients) or gathering individual slices back into a full tensor. Second, it adds the high-level tensor sharding logic (split_tensor_for_parallelism), which uses ops.array_split to intelligently slice large tensors, even unevenly, for distribution across devices. New tests confirm this new parallel logic, including the uneven splitting, works as expected.

The tests on this PR will pass after the PR #21697 gets merged

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @buildwithsuhana, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request establishes foundational components for tensor parallelism within the JAX backend, crucial for Autosharding. It provides core collective communication primitives like all_reduce and all_gather and introduces a flexible tensor splitting utility, split_tensor_for_parallelism, designed to efficiently distribute tensors across multiple devices, even when uneven splitting is required.

Highlights

  • JAX Collective Operations: Introduced all_reduce (sum, mean) and all_gather functions to the JAX backend for inter-device communication, essential for distributed computing.
  • Tensor Sharding Logic: Added split_tensor_for_parallelism to intelligently slice tensors, including support for uneven distributions, for efficient device parallelism.
  • Comprehensive Testing: Included new test cases to validate the correct functionality of both the collective operations and the tensor splitting logic, ensuring robustness.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces foundational components for tensor parallelism in Keras, specifically for the JAX backend. It adds all_reduce and all_gather collective operations, which are essential for distributed computations. Additionally, it provides a split_tensor_for_parallelism utility for sharding tensors across devices. The changes are well-tested, covering both even and uneven tensor splitting. My review includes a few suggestions to improve documentation accuracy and code simplicity, and to align with the repository's style guide regarding docstring examples.

buildwithsuhana and others added 3 commits October 28, 2025 10:41
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@hertschuh
Copy link
Collaborator

Can you rebase to make the tests pass?

@codecov-commenter
Copy link

codecov-commenter commented Nov 6, 2025

Codecov Report

❌ Patch coverage is 38.37209% with 212 lines in your changes missing coverage. Please review.
✅ Project coverage is 61.29%. Comparing base (d8e0b4a) to head (d9eabc8).
⚠️ Report is 38 commits behind head on master.

Files with missing lines Patch % Lines
...tribution/tensor_parallel/coordinated_optimizer.py 16.34% 174 Missing ⚠️
...ras/src/distribution/tensor_parallel/autoconfig.py 75.86% 16 Missing and 12 partials ⚠️
keras/src/backend/jax/distribution_lib.py 27.27% 8 Missing ⚠️
.../src/distribution/tensor_parallel/tensor_layout.py 77.77% 1 Missing and 1 partial ⚠️

❗ There is a different number of reports uploaded between BASE (d8e0b4a) and HEAD (d9eabc8). Click for more details.

HEAD has 6 uploads less than BASE
Flag BASE (d8e0b4a) HEAD (d9eabc8)
keras 5 2
keras-torch 1 0
keras-tensorflow 1 0
keras-jax 1 0
Additional details and impacted files
@@             Coverage Diff             @@
##           master   #21792       +/-   ##
===========================================
- Coverage   82.66%   61.29%   -21.37%     
===========================================
  Files         577      583        +6     
  Lines       59453    60373      +920     
  Branches     9320     9522      +202     
===========================================
- Hits        49148    37008    -12140     
- Misses       7902    21012    +13110     
+ Partials     2403     2353       -50     
Flag Coverage Δ
keras 61.29% <38.37%> (-21.20%) ⬇️
keras-jax ?
keras-numpy 57.21% <38.37%> (-0.33%) ⬇️
keras-openvino 34.58% <38.37%> (+0.23%) ⬆️
keras-tensorflow ?
keras-torch ?

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

buildwithsuhana added a commit to buildwithsuhana/keras that referenced this pull request Nov 18, 2025
children_to_add = []

if hasattr(current_layer, "layers") and current_layer.layers:
for sub_layer in current_layer.layers:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again model.layers will return the full flat list of all the layers in the model, even if they're deeply nested.

So:

  • processed_layers is not needed
  • stack is not needed
  • prefix is not needed, it will always be "", you will never get a deeply nested path. If you wanted that, you'll need to code this very differently
  • children_to_add is not needed
  • line 234-260 are not needed
  • lines 224-232 are probably not needed

The code could be a lot shorter. This whole function get_default_config can be:

def get_default_config(module, device_ids):
    device_count = len(device_ids)
    state_rules = {}
    output_rules = {}
    for layer in module.layers:
        _apply_layer_sharding_rules(
            layer, layer.name, device_count, state_rules, output_rules
        )
    return LayoutMap(state_rules=state_rules, output_rules=output_rules)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants