
[RFC] PyTorch/XLA Auto-Sharding API #6322

@yeounoh


🚀 Feature & Motivation

PyTorch/XLA recently launched PyTorch/XLA SPMD (RFC, blog, docs/spmd.md) as a first step toward automating the parallelization of ML workloads using GSPMD. In contrast to the previous PyTorch experience of implementing tensor and model parallelism by hand, PyTorch/XLA SPMD lets users provide a handful of “sharding hints” via the PyTorch/XLA SPMD sharding annotation API while keeping the original model implementation as-is. We presented some exciting Cloud TPU PyTorch LLaMA2 training results using SPMD at Google Next (blog). Another highlight of PyTorch/XLA SPMD is that it enables more advanced and hybrid parallelism strategies for PyTorch users, combining data and tensor parallelism as well as pipelining. While this is all great and we are happy to release the new feature to the PyTorch community, one challenge remains: providing the optimal sharding hints. It turns out that performance largely depends on the quality of the sharding hints provided by the user -- and coming up with optimal hints requires a correct and deep understanding of the model architecture and considerable expertise.

To address this problem, we propose to integrate PyTorch/XLA SPMD with XLA's auto-sharding service, which allows the XLA compiler to shard and optimize the whole model without any user input. The XLA auto-sharding service is based on published research, Alpa (blog). While this may sound like a leap of faith, it is already being tested and showing promising results on Google-internal workloads (also see our mini-benchmark results below).

API and Usage Example

To enable auto-sharding, simply call use_spmd with the auto=True flag:

import torch_xla.runtime as xr 

# Enable XLA SPMD execution mode with auto-sharding.
xr.use_spmd(auto=True)

# Write a PyTorch program without any sharding annotations.
# User can still provide sharding hints, but optionally. 
...

That should be it: the PyTorch program is automatically sharded and executed.

There are optional configuration knobs, though:

# The optional auto_sharding config can be passed to control
# the auto-sharding behavior.
config = {"auto_sharding": {"partitioner": "alpa", "keep_user_sharding": True}}
xr.use_spmd(auto=True, spmd_config=config)

The auto-sharding pass uses the "alpa" partitioner (partitioner="alpa"), which is currently the only option available.

Auto-sharding runs in SPMD mode and works with zero sharding hints; the execution is sharded and optimized by the XLA compiler. By default, the auto-sharding pass respects pre-existing sharding annotations on the inputs and outputs; the user can choose to provide more hints using the PyTorch/XLA SPMD mark_sharding API and setting the keep_user_sharding option.
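As a concrete illustration, here is a minimal sketch of combining a user hint with auto-sharding. It assumes the Mesh/mark_sharding API described in docs/spmd.md together with the spmd_config proposed above; the tensor shape and mesh are illustrative only.

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

# Enable auto-sharding and ask the pass to keep all user-provided hints.
xr.use_spmd(auto=True,
            spmd_config={"auto_sharding": {"keep_user_sharding": True}})

# Build a 1D ("data",) mesh over all addressable devices.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.array(range(num_devices)), (num_devices,), ("data",))

# Optional hint: shard the input batch along the "data" mesh axis.
# Everything left unannotated is sharded by the auto-sharding pass.
x = torch.randn(128, 1024).to(xm.xla_device())
xs.mark_sharding(x, mesh, ("data", None))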

Auto-Sharding Configuration

Here is the list of supported auto-sharding configuration options:

  • partitioner (str): the auto-sharding or partitioning algorithm to use. Set to “alpa” by default.
  • keep_user_sharding (bool): if set, respect all user-provided sharding annotations. Otherwise, only input and output shardings are kept.
  • memory_budget (int): per-TPU-chip memory budget (in GB) for the auto-sharding pass to consider during optimization.
  • mesh_shape (str): if provided, the auto-sharding pass generates sharding strategies for the given mesh shape. If unset, a 1D (num_devices, 1) mesh is used.
  • auto_mesh_selection (bool): if set, allow the compiler to explore different mesh shapes and pick an optimal one. This increases compile time.
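Putting these options together, a fuller configuration might look like the sketch below; the keys follow the proposed configuration above, and the values are illustrative only.

import torch_xla.runtime as xr

# Illustrative auto-sharding configuration combining the options above.
config = {
    "auto_sharding": {
        "partitioner": "alpa",       # currently the only supported partitioner
        "keep_user_sharding": True,  # respect all user mark_sharding hints
        "memory_budget": 16,         # per-chip memory budget in GB
        "mesh_shape": "4,1",         # illustrative mesh shape string
    }
}
xr.use_spmd(auto=True, spmd_config=config)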

PyTorch DTensor integration

It is important for PyTorch/XLA SPMD to plug into the PyTorch distributed API, providing a unified UX for PyTorch distributed [RFC]. DTensor is PyTorch's SPMD-style distributed tensor computation API, with which PyTorch/XLA SPMD is integrated.
We propose to introduce a new auto-partition option to the DTensor distribute_module API via partition_fn="AUTO":

import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from torch.distributed import DeviceMesh
from torch.distributed._tensor import distribute_module

# Define a PyTorch module
...
my_module = MyModule().to(xm.xla_device())

# Automatically sharded (annotated) module for XLA distributed execution
world_size = xr.global_runtime_device_count()
mesh = DeviceMesh("xla", list(range(world_size)))
my_sharded_module = distribute_module(my_module, mesh, partition_fn="AUTO")

Mini-Benchmark Results

Here we present preliminary benchmark results using GPT-2, LLaMA2, and GPT-Neo from HuggingFace. Auto-sharding parallelizes and distributes these transformer-based language models without any user sharding annotations on the models. We used PyTorch/XLA’s MpDeviceLoader for background data loading with batch-dimension sharding.
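For reference, the input pipeline referred to above looks roughly like the sketch below. It assumes the MpDeviceLoader input_sharding option described in docs/spmd.md; train_loader stands in for the actual HuggingFace dataloader.

import numpy as np
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
import torch_xla.distributed.parallel_loader as pl

# 1D ("data",) mesh over all devices for batch-dimension sharding.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.array(range(num_devices)), (num_devices,), ("data",))

# Background data loading; each batch is sharded along its leading
# (batch) dimension as it is transferred to the devices.
loader = pl.MpDeviceLoader(
    train_loader,  # placeholder for the actual DataLoader
    xm.xla_device(),
    input_sharding=xs.ShardingSpec(mesh, ("data", None)))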

[Figures: GPT-2 on TPU v4-8, MFU vs. batch size; GPT-2 (2B) step time vs. batch size]

A preliminary benchmark based on a GPT-2 (2B parameters) model on TPU v4-8 shows that the auto-sharding pass generates results comparable to the human-curated 2D sharding strategies:

  • Auto-sharding resulted in more memory-efficient shardings, enabling larger batch sizes.
  • The auto-sharding pass heuristics do not seem to produce the best results for smaller batch sizes; it may be producing the optimal shardings for the maximal available batch size.
  • MFU at the maximal batch size is comparable but slightly worse, possibly due to resharding (room for improvement) and sharding differences.

[Figures: LLaMA2 (2B) at batch size 128; auto-sharding vs. manual sharding performance]

The above figures show that (left) auto-sharding does not always generate the best-performing (MFU) shardings, as in the case of LLaMA2, though it still produces performant ones. It is important to note that (right) it worked with three popular models from HuggingFace without customization or manual annotations.

Alternatives

This work automates model parallelization using PyTorch/XLA SPMD, allowing the XLA compiler to come up with optimal sharding strategies on behalf of the user. Alternatively, we will introduce a high-level API (e.g., FSDP) that iteratively calls PyTorch/XLA SPMD for a given policy (RFC). Our goal is to provide PyTorch users with useful tools and good optionality.

Additional Context

Alpa is still an experimental feature, and it works for XLA-supported hardware types like TPU and GPU -- we hope to provide a single approach for all PyTorch/XLA backend types. In the near future, we will also expand the choice of auto-sharding algorithms beyond Alpa.

cc @JackCaoG @miladm @shauheen @alanwaketan @baoleai @anw90 @yitongh @Seventeen17 @wconstab @wanchaol
