PyTorch/XLA SPMD Weak-Scaling (Sanity Check) on NVIDIA A100 (80GB)

A minimal, single-run sanity check of PyTorch/XLA SPMD weak-scaling behavior using the official ResNet-50 example.

This experiment evaluates SPMD configured strictly for data-parallel training (batch sharding only).
It is intended for reproducibility and quick validation, not as a rigorous benchmark.

Scope and Clarification

This test uses PyTorch/XLA SPMD only in a data-parallel configuration:

  • Sharding is applied exclusively on the batch dimension
  • No model parallelism, tensor parallelism, or hybrid sharding strategies are used
  • Behavior observed here reflects data-parallel SPMD performance, not general SPMD behavior
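For concreteness, the following minimal sketch shows what batch-only sharding looks like with the public PyTorch/XLA SPMD APIs (xr.use_spmd, xs.Mesh, xs.mark_sharding): a one-dimensional device mesh and an input tensor sharded on its first (batch) dimension only. This is an illustration of the configuration described above, not code taken from the example script.

```python
# Minimal illustration of data-parallel (batch-only) SPMD sharding.
# Assumes PJRT_DEVICE is set (e.g. PJRT_DEVICE=CUDA); not the example script's code.
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()  # switch the runtime into SPMD execution mode

num_devices = xr.global_runtime_device_count()
# 1-D mesh: all devices sit on a single 'data' axis.
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('data',))

# Global batch grows with the number of devices (weak scaling), 32 per GPU.
images = torch.zeros(32 * num_devices, 3, 224, 224).to(xm.xla_device())

# Shard only the batch dimension; the remaining dimensions stay replicated.
xs.mark_sharding(images, mesh, ('data', None, None, None))
```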

Methodology

We used the official ResNet-50 SPMD training example with PyTorch/XLA 2.7 (the latest CUDA-supported release; see the note below), forked at:
https://github.com/marijaEf/xla/blob/ddp-benchmark/test/spmd/test_train_spmd_imagenet.py

Key characteristics:

  • Model: ResNet-50 (default in script)
  • Mode: SPMD
  • Parallelism strategy: Data parallel (via batch sharding only)
  • Data: Synthetic (--fake_data) using zero tensors

Note: CUDA support was deprecated starting with the PyTorch/XLA 2.8 release.

Parameters

  • --fake_data
    Generates synthetic inputs (see the sketch after this list):
    torch.zeros(batch_size, 3, 224, 224)
    torch.zeros(batch_size, dtype=torch.int64)
  • --sharding batch
    Enforces data-parallel SPMD (batch dimension only)
  • Local batch size per GPU: 32
  • Global batch size scales with number of devices (weak scaling)
  • Training:
    • --epochs 2
    • --log_steps 10
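As a concrete illustration of the --fake_data inputs listed above, the sketch below generates zero-tensor batches with the same shapes. The real script builds these through its own sample-generator utility; this snippet only makes the shapes and the per-GPU batch size explicit.

```python
# Illustrative only: synthetic (zero-tensor) batches matching the shapes that
# --fake_data produces; the example script uses its own sample-generator utility.
import torch

def fake_batches(batch_size, num_batches):
    """Yield (images, labels) pairs shaped like ImageNet inputs."""
    for _ in range(num_batches):
        images = torch.zeros(batch_size, 3, 224, 224)
        labels = torch.zeros(batch_size, dtype=torch.int64)
        yield images, labels

# Per-GPU batch size of 32, as used in this experiment.
for images, labels in fake_batches(batch_size=32, num_batches=10):
    pass  # a training step would consume (images, labels) here
```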

Environment

  • Hardware: NVIDIA A100-SXM4-80GB
  • Container: herefortheimage/pytorch-xla-2.7.1-mlperf-resnet50:latest
  • Execution: Kubernetes
  • Runs per configuration: 1 (no repetition)

An example pod specification is provided in test_pytorch_xla_spmd.yaml.

CUDA Version and Build Instructions

This setup was tested with CUDA 12.6.2.

Build with a Different CUDA Version

To change the CUDA version:

  1. Update the base image in the Dockerfile
    FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04
    and update the PyTorch/XLA GPU wheel to a compatible version.

    Note: You can find a list of GPU release builds here. The CUDA version must be compatible with the PyTorch and PyTorch/XLA versions used (see the version-check sketch after these steps).

  2. Build a Docker image with the updated CUDA version.
  3. Use the new image when creating the containers.
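To confirm that a rebuilt image has compatible versions, a quick check along these lines can be run inside the container. This is a convenience sketch, not part of the repository; torch.version.cuda reports the CUDA version the installed PyTorch wheel was built against.

```python
# Sanity check, run inside the container, that the installed wheels line up.
import torch
import torch_xla
import torch_xla.runtime as xr

print("torch:", torch.__version__)
print("torch built against CUDA:", torch.version.cuda)
print("torch_xla:", torch_xla.__version__)

# With PJRT_DEVICE=CUDA set, this should report the number of visible GPUs.
print("XLA devices:", xr.global_runtime_device_count())
```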

Configurations

  GPUs   Global Batch Size   Local Batch Size (per GPU)
  1      128                 32
  2      256                 32
  4      512                 32

Global batch size is scaled linearly to maintain weak-scaling conditions.

Logs are available in the logs/ directory.

Findings

From the execution logs:

  • Total execution time increases as more GPUs are added
  • Throughput (reported as Rate and GlobalRate) decreases

This deviates from ideal weak-scaling expectations, where throughput should remain approximately constant as resources scale.
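As a reference point for "ideal": under weak scaling the per-GPU rate should stay roughly constant, so global throughput should grow roughly linearly with device count. The sketch below shows one way to quantify that from logged GlobalRate values; the throughput numbers are placeholders, not measurements from this repository.

```python
# Placeholder sketch for quantifying weak-scaling efficiency from logged
# GlobalRate values (samples/sec); the rates below are illustrative only.
def weak_scaling_efficiency(global_rates):
    """Efficiency = GlobalRate(n GPUs) / (GlobalRate(1 GPU) * n)."""
    baseline = global_rates[1]
    return {n: rate / (baseline * n) for n, rate in global_rates.items()}

# Example with made-up throughputs; ideal weak scaling gives efficiency ~1.0.
print(weak_scaling_efficiency({1: 400.0, 2: 700.0, 4: 1200.0}))
```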

Interpretation

Given that:

  • the test uses the official PyTorch/XLA SPMD example
  • the setup is restricted to data-parallel batch sharding
  • no custom implementation changes were introduced

the observed behavior suggests that:

  • weak-scaling degradation is not specific to a custom training setup
  • it may be characteristic of data-parallel SPMD in PyTorch/XLA v2.7 under this configuration

Limitations

  • Single-run experiment (no statistical significance)
  • Synthetic data (eliminates data loading or I/O effects)
  • Limited scale (up to 4 GPUs)
  • No profiling (e.g., communication vs compute breakdown)

Conclusion

This sanity check indicates that data-parallel SPMD (batch sharding only) in PyTorch/XLA v2.7 does not exhibit ideal weak-scaling under these conditions.

Further investigation with:

  • repeated runs
  • real datasets
  • profiling tools (e.g., torch_xla.debug.profiler)
  • comparison with standard PyTorch DDP

is required to draw definitive conclusions.

Suggested Next Steps

If extending this work:

  • Run multiple trials and report averages and variance
  • Compare against standard PyTorch DDP baseline
  • Profile execution to isolate communication overhead (a minimal sketch follows this list)
  • Separate time-to-first-step from steady-state throughput
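For the profiling step, the following is a minimal sketch of capturing a trace with torch_xla.debug.profiler: the profiler server is started in the training process and a trace is captured from a separate process. The port, log directory, and duration below are arbitrary illustrative choices.

```python
# Illustrative profiling sketch using torch_xla.debug.profiler; port, logdir,
# and duration are arbitrary choices, not values taken from this repository.
import torch_xla.debug.profiler as xp

# In the training process: start the profiler server once, early in startup.
server = xp.start_server(9012)

# Annotate a region of the training step so it shows up in the trace.
with xp.Trace("train_step"):
    pass  # forward/backward/optimizer work would go here

# From a separate process (or thread), capture a trace for later inspection
# in TensorBoard:
# xp.trace("localhost:9012", logdir="/tmp/xla_profile", duration_ms=30000)
```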

Repository Structure

Disclaimer

This repository is intended as a sanity check and reproducibility artifact, not as a formal benchmark or performance study.
