JAX NCCL Analyser for multi-node collective communication bandwidth analysis #438

Open
amital-amd wants to merge 4 commits into main from feat/jax-rccl-analyser



@amital-amd commented Nov 21, 2025

Pull Request Template

Note to AMDers:
This is a public repository. Please do not upload any confidential or customer data. Make sure all such data has been anonymized or removed before making this PR. If you need to attach any private files or links, please insert an internal OneDrive link or a Jira ticket link instead.

This PR introduces the JAX collective communication analyzer for TraceLens, addressing issue #262 with support for multi-node distributed training analysis.

JAX-Specific Analysis:

  • Integrates JAX protobuf traces with XLA dump files for collective operation analysis
  • Supports complex replica group configurations and multi-node multi-rank distributed setups
  • Calculates bandwidths for all collective types
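
To make the bandwidth terminology concrete, here is a minimal sketch of how algorithmic and bus bandwidth are commonly related. It is not the code in this PR: the function names are hypothetical, the correction factors follow the nccl-tests convention, and treating collective-permute as a point-to-point transfer with a factor of 1 is an assumption.

```python
# Illustrative sketch only; not the implementation from this PR.
# Correction factors follow the nccl-tests convention for bus bandwidth.

def busbw_factor(collective: str, n_ranks: int) -> float:
    """Return the factor that converts algorithmic bandwidth to bus bandwidth."""
    if n_ranks < 2:
        raise ValueError("a collective needs at least 2 ranks")
    if collective == "all-reduce":
        return 2.0 * (n_ranks - 1) / n_ranks
    if collective in ("all-gather", "reduce-scatter", "all-to-all"):
        return (n_ranks - 1) / n_ranks
    if collective == "collective-permute":
        # Assumption: treated like a point-to-point send/recv, so no correction.
        return 1.0
    raise ValueError(f"unknown collective: {collective}")


def bandwidths_gbps(collective: str, message_bytes: int,
                    duration_us: float, n_ranks: int) -> tuple:
    """Return (algorithmic, bus) bandwidth in GB/s for one collective call."""
    # Algorithmic bandwidth: bytes moved divided by elapsed time.
    algbw = message_bytes / (duration_us * 1e-6) / 1e9
    return algbw, algbw * busbw_factor(collective, n_ranks)
```

For example, `bandwidths_gbps("all-reduce", 1_000_000_000, 1_000_000, 8)` describes a 1 GB all-reduce across 8 ranks that took one second. Here `n_ranks` would be the size of the replica group the collective ran in, not necessarily the total number of devices.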

Scope:

The initial implementation prioritizes core bandwidth analysis for JAX workloads. Accurate bandwidth calculation requires parsing and aggregating data from multiple sources (traces and XLA dumps), so features such as synchronization analysis, multiprocessing support, and integrated reporting (similar to the PyTorch NCCL analyzer) are intentionally deferred to future iterations.

Contributor

Copilot AI left a comment


Pull Request Overview

This PR implements a JAX NCCL Analyser for analyzing collective communication bandwidth in multi-node JAX distributed training workloads. The implementation provides comprehensive tooling to parse JAX traces, extract collective operations from XLA dumps, and calculate algorithmic and bus bandwidth metrics.

Key Changes:

  • JAX-specific NCCL analyzer with XLA dump parsing capabilities
  • Bandwidth calculation engine supporting multiple collective types (all-reduce, all-gather, reduce-scatter, all-to-all, collective-permute)
  • Utility functions for automatic node-to-protobuf file mapping

Reviewed Changes

Copilot reviewed 11 out of 11 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| TraceLens/NcclAnalyser/jax_nccl_analyser.py | Core analyzer class implementing trace loading, XLA parsing, and bandwidth calculations |
| TraceLens/NcclAnalyser/util/xla_parser.py | XLA dump parser for extracting collective operation metadata including replica groups and tensor dimensions |
| TraceLens/NcclAnalyser/util/node_rank_to_protobuf_file_mapping.py | Utility for automatic discovery and mapping of node ranks to protobuf trace files |
| TraceLens/util.py | Added regex pattern for extracting replica_groups information from HLO operations |
| tests/test_jax_nccl_analyser.py | Unit tests covering trace loading, bandwidth scaler calculations, and end-to-end analysis |
| examples/jax_nccl_analyser_example.ipynb | Jupyter notebook demonstrating usage with detailed documentation |
| TraceLens/\_\_init\_\_.py | Export JaxNcclAnalyser class |
| TraceLens/NcclAnalyser/\_\_init\_\_.py | Export JaxNcclAnalyser from submodule |


@olehtika
Contributor

@gabeweisz and @devalshahamd Any input for the PR review? This has been open quite a while.

Collaborator

@gabeweisz left a comment


Looks OK to me
