A minimal, single-run sanity check of PyTorch/XLA SPMD weak-scaling behavior using the official ResNet-50 example.
This experiment evaluates SPMD configured strictly for data-parallel training (batch sharding only).
It is intended for reproducibility and quick validation, not as a rigorous benchmark.
This test uses PyTorch/XLA SPMD only in a data-parallel configuration:
- Sharding is applied exclusively on the batch dimension
- No model parallelism, tensor parallelism, or hybrid sharding strategies are used
- Behavior observed here reflects data-parallel SPMD performance, not general SPMD behavior
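For reference, batch-only sharding in PyTorch/XLA SPMD looks roughly like the sketch below. This is a minimal illustration of the public SPMD API (`torch_xla.distributed.spmd`), not code copied from the forked script; the tensor shape and the mesh axis name are illustrative:

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

# Enable SPMD mode: a single process drives all local devices.
xr.use_spmd()

# 1D device mesh; the single axis is used only for data parallelism.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('data',))

# Shard the batch dimension of the input across the mesh; model weights
# are left replicated, so this is pure data parallelism.
images = torch.zeros(128, 3, 224, 224).to(xm.xla_device())
xs.mark_sharding(images, mesh, ('data', None, None, None))
```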
We used the official ResNet-50 SPMD training example with PyTorch/XLA 2.7 (the latest CUDA-supported release; see the note below), forked at:
https://github.com/marijaEf/xla/blob/ddp-benchmark/test/spmd/test_train_spmd_imagenet.py
Key characteristics:
- Model: ResNet-50 (default in script)
- Mode: SPMD
- Parallelism strategy: Data parallel (via batch sharding only)
- Data: Synthetic (`--fake_data`) using zero tensors
Note: CUDA support was deprecated starting with PyTorch/XLA release 2.8.
| Flag | Effect |
|---|---|
| `--fake_data` | Generates synthetic inputs: `torch.zeros(batch_size, 3, 224, 224)` and `torch.zeros(batch_size, dtype=torch.int64)` (see the sketch after the configuration list) |
| `--sharding batch` | Enforces data-parallel SPMD (batch dimension only) |

- Local batch size per GPU: 32
- Global batch size scales with number of devices (weak scaling)
- Training: `--epochs 2 --log_steps 10`
- Hardware: NVIDIA A100-SXM4-80GB
- Container: `herefortheimage/pytorch-xla-2.7.1-mlperf-resnet50:latest`
- Execution: Kubernetes
- Runs per configuration: 1 (no repetition)
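In code, the synthetic data path amounts to the following (a sketch of the tensors `--fake_data` produces, based on the flag description above; the helper name is ours):

```python
import torch

def make_fake_batch(batch_size: int):
    # Zero-valued images and integer labels, matching the shapes
    # generated by --fake_data; no data loading or I/O is involved.
    images = torch.zeros(batch_size, 3, 224, 224)
    labels = torch.zeros(batch_size, dtype=torch.int64)
    return images, labels
```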
An example pod specification is provided in `test_pytorch_xla_spmd.yaml`.
This setup was tested with CUDA 12.6.2.
You can change the CUDA version as follows:
- Update the Dockerfile base image (`FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04`) and update the PyTorch/XLA GPU wheel to a compatible version. Note: you can find the list of GPU release builds here. The CUDA version must be compatible with the PyTorch and PyTorch/XLA versions used.
- Build a Docker image with the updated CUDA version
- Use the new image when creating the containers
| GPUs | Global Batch Size | Local Batch Size (per GPU) |
|---|---|---|
| 1 | 128 | 32 |
| 2 | 256 | 32 |
| 4 | 512 | 32 |
Global batch size is scaled linearly with the number of GPUs to maintain weak-scaling conditions.
Logs are available in the `logs/` directory.
From the execution logs:
- Total execution time increases as more GPUs are added
- Throughput decreases (reported as `Rate` and `GlobalRate`)
This deviates from ideal weak-scaling expectations, where throughput should remain approximately constant as resources scale.
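As a concrete yardstick: under ideal weak scaling, the per-device `Rate` stays constant and `GlobalRate` grows linearly with device count. A simple efficiency metric can be derived from the logged rates (a hypothetical helper; the rate values must be read from the logs manually):

```python
def weak_scaling_efficiency(global_rate_n: float,
                            global_rate_1: float,
                            num_gpus: int) -> float:
    """Measured speedup divided by ideal speedup.

    Ideal weak scaling gives GlobalRate(n) == n * GlobalRate(1),
    i.e. an efficiency of 1.0; values below 1.0 quantify degradation.
    """
    return (global_rate_n / global_rate_1) / num_gpus
```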
Given that:
- the test uses the official PyTorch/XLA SPMD example
- the setup is restricted to data-parallel batch sharding
- no custom implementation changes were introduced
the observed behavior suggests that:
- weak-scaling degradation is not specific to a custom training setup
- it may be characteristic of data-parallel SPMD in PyTorch/XLA v2.7 under this configuration
Limitations:
- Single-run experiment (no statistical significance)
- Synthetic data (eliminates data loading or I/O effects)
- Limited scale (up to 4 GPUs)
- No profiling (e.g., communication vs compute breakdown)
This sanity check indicates that data-parallel SPMD (batch sharding only) in PyTorch/XLA v2.7 does not exhibit ideal weak-scaling under these conditions.
Further investigation with:
- repeated runs
- real datasets
- profiling tools (e.g., `torch_xla.debug.profiler`; see the sketch below)
- comparison with standard PyTorch DDP
is required to draw definitive conclusions.
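For the profiling point, `torch_xla.debug.profiler` can capture device traces along these lines (a minimal sketch; the port, duration, and output directory are arbitrary choices, and the trace request must come from a separate thread or process while training is running):

```python
import torch_xla.debug.profiler as xp

# Start the profiler server inside the training process.
server = xp.start_server(9012)

# From a separate thread/process, capture a trace window into a logdir
# that can be inspected with TensorBoard's profile plugin.
xp.trace('localhost:9012', logdir='/tmp/xla_trace', duration_ms=10000)
```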
If extending this work:
- Run multiple trials and report averages and variance
- Compare against standard PyTorch DDP baseline
- Profile execution to isolate communication overhead
- Separate time-to-first-step from steady-state throughput (see the sketch below)
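The last recommendation matters because the first XLA step includes graph tracing and compilation. A sketch of how the two phases could be separated (`step_fn` is a stand-in for one training step and is assumed to block until device work completes, e.g. via `xm.wait_device_ops()`):

```python
import time

def split_step_times(step_fn, num_steps: int):
    # First step: dominated by XLA tracing and compilation.
    start = time.perf_counter()
    step_fn()
    first_step_s = time.perf_counter() - start

    # Remaining steps: steady-state throughput.
    start = time.perf_counter()
    for _ in range(num_steps - 1):
        step_fn()
    steady_state_s = (time.perf_counter() - start) / (num_steps - 1)
    return first_step_s, steady_state_s
```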
- `logs/` - execution logs for each configuration
- `test_pytorch_xla_spmd.yaml` - example Kubernetes pod specification
- `Dockerfile` - environment definition for reproducibility
- `requirements.txt` - Python dependencies
This repository is intended as a sanity check and reproducibility artifact, not as a formal benchmark or performance study.