Adding torch support for Data Parallel #22397

Draft

buildwithsuhana wants to merge 18 commits into keras-team:master from buildwithsuhana:dp_ddp
Conversation

@buildwithsuhana (Collaborator)

This PR introduces torch backend support for Data Parallelism (DP) in Keras. It aligns the internal distribution_lib implementations so that the high-level Keras Distribution APIs (such as DeviceMesh, LayoutMap, and ModelParallel) behave consistently regardless of the underlying framework. The implementation leverages PyTorch Distributed Data Parallel (DDP).

Design document: go/distributionLib

Kaggle notebook testing data parallelism on the torch and jax backends (using the keras_hub OPT model):
https://www.kaggle.com/code/buildwithsuhana/dataparallel-torch-ddp?scriptVersionId=302683263
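
For reference, the core mechanism this PR builds on is PyTorch's DistributedDataParallel. Below is a minimal, single-process sketch of the DDP wiring only (not the PR's code); a real data-parallel run would launch one process per device with torchrun, and this sketch assumes the CPU gloo backend:

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Single-process "world" just to show the wiring; real multi-device runs
# launch one process per device (e.g. via torchrun), which sets these
# environment variables automatically.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = torch.nn.Linear(4, 2)
ddp_model = DDP(model)  # gradients are all-reduced across processes

# Each process sees its own slice of the batch; weights stay replicated.
shape = ddp_model(torch.randn(8, 4)).shape
dist.destroy_process_group()
print(shape)  # torch.Size([8, 2])
```

In multi-process runs, DDP registers autograd hooks that all-reduce gradients during backward, so each replica applies identical updates.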

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the Keras Torch backend by integrating robust support for data parallelism. It ensures that Keras's high-level distribution APIs function seamlessly with PyTorch's distributed capabilities, allowing models to be efficiently scaled across multiple devices or processes. The changes involve making core tensor and variable operations distribution-aware, adding XLA device compatibility, and providing a structured way to parallelize Keras layers.

Highlights

  • Torch Backend Data Parallelism: Introduced comprehensive support for Data Parallelism (DP) within the Keras Torch backend, leveraging PyTorch Distributed Data Parallel (DDP) for distributed training.
  • Distribution-Aware Tensor Operations: Implemented sharding-aware mechanisms for various Torch tensor operations and Keras Variable handling, ensuring proper distribution and redistribution of tensors across devices.
  • XLA Device Support: Expanded device detection in the Torch backend to include XLA devices, allowing Keras to utilize TPUs via PyTorch/XLA.
  • Keras Layer Parallelization: Integrated logic to automatically parallelize Keras layers and models based on the active distribution strategy (DataParallel or ModelParallel) within the Torch backend.
  • New Distribution Library: Added a new distribution_lib.py module specifically for the Torch backend, centralizing utilities for device management, process initialization, and tensor/variable distribution.
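
To make the "sharding-aware op" highlight concrete, here is a hedged toy sketch of the idea: before calling an op that expects a full tensor, gather any sharded input back to a replicated local tensor. The names `ShardedTensor` and `_ensure_replicated_local` are illustrative stand-ins for this sketch, not the PR's actual classes (the real implementation works with PyTorch DTensor):

```python
import functools
import torch

class ShardedTensor:
    """Toy stand-in for a tensor split along dim 0 across devices."""
    def __init__(self, shards):
        self.shards = shards

def _ensure_replicated_local(x):
    # Gather shards into one local tensor; pass plain tensors through.
    if isinstance(x, ShardedTensor):
        return torch.cat(x.shards, dim=0)
    return x

def _sharding_aware_op(fn):
    # Wrap an op so sharded arguments are gathered before the call.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        args = [_ensure_replicated_local(a) for a in args]
        return fn(*args, **kwargs)
    return wrapper

reshape = _sharding_aware_op(torch.reshape)

# Two (2, 3) shards gather into a (4, 3) tensor, which reshapes cleanly.
t = ShardedTensor([torch.ones(2, 3), torch.ones(2, 3)])
print(reshape(t, (3, 4)).shape)  # torch.Size([3, 4])
```

The PR applies the same pattern to ops like `view`, `__getitem__`, `unbind`, `broadcast_to`, and `einsum`, where a naive call on a sharded tensor would otherwise see only the local shard's shape.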


Changelog
  • keras/src/backend/__init__.py
    • Removed the explicit distribution_lib = None assignment for the Torch backend, allowing the actual distribution library to be imported.
  • keras/src/backend/torch/__init__.py
    • Imported the new distribution_lib module to expose its functionalities within the Torch backend.
  • keras/src/backend/torch/core.py
    • Imported threading and torch.func for new functionalities.
    • Added XLA as a potential DEFAULT_DEVICE if torch_xla is available.
    • Modified the Variable class to accept a layout argument during initialization and to distribute its underlying value using distribution_lib.
    • Updated __torch_function__ to use tree.map_structure for unwrapping Variable instances, improving flexibility.
    • Introduced _is_sharded and _ensure_replicated_local helper functions for managing sharded tensors.
    • Implemented _sharding_aware_op decorator and specific sharding-aware wrappers for torch.Tensor methods like reshape, view, __getitem__, detach, and torch functions like unbind, broadcast_to, einsum.
    • Added _distribution_aware_creation_op decorator to make tensor creation functions (e.g., arange, ones) distribution-aware.
    • Defined maybe_distribute_tensor to conditionally distribute tensors based on the active distribution strategy.
  • keras/src/backend/torch/distribution_lib.py
    • Added a new file to house Torch-specific distribution utilities.
    • Implemented list_devices and get_device_count for various device types (cpu, cuda, xla).
    • Provided initialize function to set up the Torch distributed process group.
    • Added num_processes and process_id to query distributed group information.
    • Created _to_backend_mesh and _to_backend_layout to convert Keras distribution objects to their Torch equivalents.
    • Defined DDPModelWrapper to correctly wrap Keras models for PyTorch's DistributedDataParallel (DDP).
    • Implemented distribute_variable and distribute_tensor to apply sharding based on a given layout.
    • Added distribute_data_input for distributing input batches.
    • Provided parallelize_layer to apply DDP or ModelParallel to Keras layers/models.
    • Included _infer_parallel_style to deduce PyTorch parallel styles (ColwiseParallel, RowwiseParallel) from Keras layout maps.
  • keras/src/backend/torch/distribution_lib_test.py
    • Added a new file containing unit tests for the distribution_lib module, covering device listing, process info, mesh/layout conversion, variable distribution, and parallel style inference.
  • keras/src/backend/torch/layer.py
    • Integrated a call to distribution_lib.parallelize_layer within the _post_build method of Keras layers, enabling automatic parallelization.
    • Modified the forward method to conditionally use a _ddp_wrapper for DistributedDataParallel, preventing infinite recursion and ensuring proper DDP execution.
  • keras/src/ops/nn.py
    • Added a specialized implementation for dot_product_attention that handles distributed tensors, ensuring correct behavior when inputs are sharded across devices.
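
The changelog mentions `_infer_parallel_style`, which deduces a PyTorch tensor-parallel style from a Keras layout map. A toy version of that mapping is sketched below; the rule shown (shard the last weight axis for column-wise, the first for row-wise) is the common convention for Dense kernels of shape (in, out), but the PR's actual logic may differ:

```python
def infer_parallel_style(layout):
    """layout: tuple with one entry per tensor dim, each an axis name
    or None, e.g. (None, "model") shards the last dim across the
    "model" mesh axis. Illustrative sketch only."""
    if layout[-1] is not None:
        return "ColwiseParallel"   # output features split across devices
    if layout[0] is not None:
        return "RowwiseParallel"   # input features split across devices
    return "Replicated"

print(infer_parallel_style((None, "model")))  # ColwiseParallel
print(infer_parallel_style(("model", None)))  # RowwiseParallel
```

In the real code, the returned styles would be `torch.distributed.tensor.parallel.ColwiseParallel` / `RowwiseParallel` objects rather than strings.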

@gemini-code-assist (Contributor) left a comment:

Code Review

This pull request introduces support for data parallelism in the Torch backend, which is a significant feature. The implementation leverages PyTorch's Distributed Data Parallel (DDP) and DTensor APIs. The changes include a new distribution_lib for Torch, modifications to core.py to make tensor operations and variables distribution-aware, and updates to layers to hook into the distribution logic.

My review has identified a critical bug related to integer-to-tensor conversion that could lead to incorrect behavior. I've also pointed out a piece of code that needs clarification and a potential design issue in the dot_product_attention op where distribution-specific logic is mixed with the generic implementation. Overall, the approach is solid, but these points should be addressed to ensure correctness and maintainability.

@buildwithsuhana changed the title from "Adding torch support for Data Prallel" to "Adding torch support for Data Parallel" on Mar 11, 2026
@codecov-commenter

codecov-commenter commented Mar 11, 2026

Codecov Report

❌ Patch coverage is 35.71429% with 261 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.71%. Comparing base (e4834f6) to head (839ae0e).
⚠️ Report is 4 commits behind head on master.

Files with missing lines                      Patch %   Lines
keras/src/backend/torch/distribution_lib.py   28.09%    134 Missing and 17 partials ⚠️
keras/src/backend/torch/core.py               54.60%    50 Missing and 19 partials ⚠️
keras/src/ops/nn.py                           0.00%     32 Missing and 1 partial ⚠️
keras/src/backend/torch/layer.py              20.00%    6 Missing and 2 partials ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #22397      +/-   ##
==========================================
- Coverage   82.99%   82.71%   -0.29%     
==========================================
  Files         596      597       +1     
  Lines       66423    66982     +559     
  Branches    10353    10470     +117     
==========================================
+ Hits        55130    55406     +276     
- Misses       8665     8899     +234     
- Partials     2628     2677      +49     
Flag               Coverage Δ
keras              82.54% <35.71%> (-0.28%) ⬇️
keras-jax          60.38% <19.45%> (-0.36%) ⬇️
keras-numpy        54.64% <19.95%> (-0.31%) ⬇️
keras-openvino     49.24% <19.95%> (-0.04%) ⬇️
keras-tensorflow   61.60% <19.95%> (-0.36%) ⬇️
keras-torch        60.54% <35.71%> (-0.28%) ⬇️

Flags with carried forward coverage won't be shown.
@buildwithsuhana marked this pull request as draft on March 12, 2026 at 02:59