🚀 [RFC] A high-level GSPMD API in PT/XLA (based on `xs.mark_sharding`)
This RFC proposes a high-level API for GSPMD through a wrapper class and a partitioning rule function, based on `xs.mark_sharding`.
Motivation
GSPMD is a powerful approach for model sharding and parallelization, and is enabled in PyTorch/XLA by the `xs.mark_sharding` API in #3476 and #3684. However, there are a few limitations when directly using this API to build SPMD programs:
- To build SPMD (e.g. on a transformer model) directly with this API, one would have to edit the PyTorch model and insert `xs.mark_sharding` into the initialization and forward code of each layer. This forces the model implementation to be aware of the SPMD sharding strategies, and makes the model code device-dependent and sharding-dependent, incompatible with native PyTorch, and hard to switch to other sharding strategies (e.g. from Megatron to Optimus) or other parallelism APIs (e.g. the FSDP wrapper).
- The requirement of adding `xs.mark_sharding` to a model implementation (e.g. its forward code) makes it hard or impossible to run SPMD when users cannot easily edit the model implementation code (e.g. when a model comes from an external library), which largely limits its application. For example, if one wants to apply SPMD to a torchvision model or a Hugging Face transformer, it is infeasible to submit a pull request to these libraries just to add `xs.mark_sharding` into their model implementations.
- SPMD applications often require annotating many tensors, sometimes on submodules several levels below the base model (e.g. every `nn.Linear` layer in the transformer MLPs of a T5 encoder-decoder model for Megatron-style sharding), so it is inconvenient and error-prone to build SPMD programs directly on top of `xs.mark_sharding`.
Ideally, one would like to take any existing PyTorch model (e.g. BERT from Hugging Face) and apply a specific sharding strategy (e.g. Megatron) to it, without changing the model code (in this case, the Hugging Face library code).
In this RFC, we advocate for a high-level SPMD API in PyTorch/XLA (built upon `xs.mark_sharding`) that does not require users to rewrite their entire model code to be aware of the SPMD sharding partitions. Specifically, we propose to parallelize a model with SPMD by first building the base model and then wrapping it with a wrapper class `GSPMDParallel`, similar to how DDP or FSDP is applied to an existing model.
The proposed `GSPMDParallel` class takes as input:
- a PyTorch `module` (the base model, an `nn.Module` instance),
- a user-specified function/callable `sharding_rule_func` that defines the sharding rules, and
- a `mesh_shape` tuple that defines the TPU mesh.
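For example, a usage sketch (the model choice and the `(4, 2)` mesh shape here are placeholders) with the `GSPMDParallel` class and the example sharding rule prototyped in the Proposed Implementation section below:

```python
import timm
import torch

# build an unmodified base model from an existing library
base_model = timm.create_model("vit_base_patch16_224")

# wrap it post-hoc; (4, 2) is a placeholder (data, model) mesh over 8 TPU devices
model = GSPMDParallel(base_model,
                      sharding_rule_func=example_sharding_rule_func,
                      mesh_shape=(4, 2))

inputs = torch.randn(8, 3, 224, 224)  # dummy batch for illustration
outputs = model(inputs)               # the sharded HLO graph is built on forward
```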
Since the tensors we want to shard in SPMD are usually either the parameters or the input/output tensors of a submodule (e.g. an `nn.Linear`) in the whole network, the `GSPMDParallel` class will recursively go through all the submodules in the wrapped model and apply the sharding rules to the parameters, inputs, and outputs of each submodule. The sharded HLO graph will be built when forward (and backward) is called on the `GSPMDParallel` class.
In this way, the base model code doesn't need to be changed when using SPMD -- it will just be wrapped and partitioned in a post-hoc manner, so all the current model implementations from the abundant existing PyTorch libraries (e.g. timm, torchvision, Hugging Face) can be directly used for SPMD parallelism without rewriting these libraries.
Proposed Implementation
Building upon the `xs.mark_sharding` API, the proposed implementation (prototype) of the high-level SPMD API consists of a `GSPMDParallel` class that recursively applies a user-defined function/callable `sharding_rule_func` to the submodules of an input `module`, where the rule can act based on the submodule names and the submodule instances themselves.
A prototype is as follows:
```python
from copy import deepcopy
from typing import Callable, Tuple, Union

import torch.nn as nn


class GSPMDParallel(nn.Module):
    """Recursively apply a `sharding_rule_func` to submodules of an input `module`."""

    def __init__(self, module: nn.Module, sharding_rule_func: Callable,
                 mesh_shape: Tuple[Union[int, None], ...]):
        super().__init__()
        # apply the SPMD sharding rule to the base model
        sharded_module = self.apply_sharding(module, sharding_rule_func, mesh_shape)
        self.module = sharded_module
        self.sharding_rule_func = sharding_rule_func
        self.mesh_shape = mesh_shape

    def apply_sharding(self, module, sharding_rule_func, mesh_shape):
        # maybe make a copy if we want to keep the original `module` unmodified
        sharded_module = deepcopy(module)
        # recursively apply the sharding rule to all the submodules (of the copy)
        for name, m in sharded_module.named_modules():
            sharding_rule_func(name, m, mesh_shape)
        return sharded_module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)
```
A user-defined sharding rule function/callable decides whether and how to apply SPMD sharding annotations to an `nn.Module`'s parameters, inputs, and outputs with `xs.mark_sharding`. For example, the following sharding rule function can be used to apply Megatron-style sharding to the MLP layers in a timm Vision Transformer (ViT).
```python
from types import MethodType
from typing import Tuple, Union

import torch.nn as nn
import torch_xla.experimental.xla_sharding as xs  # mark_sharding API from #3476/#3684


def example_sharding_rule_func(name: str, submodule: nn.Module,
                               mesh_shape: Tuple[Union[int, None], ...]):
    """Apply Megatron to MLP layers in timm ViT. Assumes a (data, model) `mesh_shape` like T5X."""
    if name.endswith(".mlp.fc1"):  # matches e.g. "blocks.0.mlp.fc1" in timm ViT
        assert isinstance(submodule, nn.Linear)
        # shard the 1st MLP layer's weight param (mlp_dim, hidden_size)
        submodule.weight = xs.mark_sharding(submodule.weight, mesh_shape, (1, None))
        # shard the 1st MLP layer's bias param (mlp_dim,)
        submodule.bias = xs.mark_sharding(submodule.bias, mesh_shape, (1,))
        # shard the 1st MLP layer's output (batch_size, seq_length, mlp_dim) by patching its forward
        # TODO: change to decorators on `forward` to get cleaner code
        submodule._orig_forward = submodule.forward

        def _new_forward(m, x):
            return xs.mark_sharding(m._orig_forward(x), mesh_shape, (0, None, 1))

        submodule.forward = MethodType(_new_forward, submodule)
    elif name.endswith(".mlp.fc2"):  # matches e.g. "blocks.0.mlp.fc2" in timm ViT
        assert isinstance(submodule, nn.Linear)
        # shard the 2nd MLP layer's weight param (hidden_size, mlp_dim)
        submodule.weight = xs.mark_sharding(submodule.weight, mesh_shape, (None, 1))
```
In the example above, it is straightforward to shard a submodule's parameter tensors, but rather hacky to shard its input and output tensors (by manually patching the forward method). We can switch to decorators to get cleaner code.
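As a minimal sketch of that cleaner direction (the `shard_output` helper below is hypothetical, not an existing PT/XLA API), a small wrapper could hide the forward patching so a sharding rule only needs a one-line call:

```python
import functools

import torch.nn as nn
import torch_xla.experimental.xla_sharding as xs  # mark_sharding API from #3476/#3684


def shard_output(submodule: nn.Module, mesh_shape, partition_spec):
    """Hypothetical helper: wrap `submodule.forward` so that its output is
    annotated with `xs.mark_sharding(output, mesh_shape, partition_spec)`."""
    orig_forward = submodule.forward  # bound method of this submodule

    @functools.wraps(orig_forward)
    def _forward_with_sharded_output(*args, **kwargs):
        output = orig_forward(*args, **kwargs)
        return xs.mark_sharding(output, mesh_shape, partition_spec)

    submodule.forward = _forward_with_sharded_output
    return submodule
```

With such a helper, the output-sharding block in `example_sharding_rule_func` above would reduce to `shard_output(submodule, mesh_shape, (0, None, 1))`.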
We could also provide a few easier ways to build a `sharding_rule_func` or to use it in `GSPMDParallel`. For example, we can make `sharding_rule_func` a callable class instead of a function, and provide a good class structure for building such sharding rule callables.
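One possible shape of such a class structure (a sketch; `ShardingRule` and its fields are hypothetical) keeps the name-to-partition-spec mapping as data, which also makes the rule easier to inspect:

```python
from typing import Dict, Optional, Tuple

import torch.nn as nn
import torch_xla.experimental.xla_sharding as xs  # mark_sharding API from #3476/#3684


class ShardingRule:
    """Hypothetical callable sharding rule mapping submodule-name suffixes to weight partition specs."""

    def __init__(self, weight_specs: Dict[str, Tuple[Optional[int], ...]]):
        # e.g. {".mlp.fc1": (1, None), ".mlp.fc2": (None, 1)} for Megatron-style MLP sharding
        self.weight_specs = weight_specs

    def __call__(self, name: str, submodule: nn.Module, mesh_shape):
        for suffix, spec in self.weight_specs.items():
            if name.endswith(suffix) and getattr(submodule, "weight", None) is not None:
                submodule.weight = xs.mark_sharding(submodule.weight, mesh_shape, spec)


# usage: megatron_rule = ShardingRule({".mlp.fc1": (1, None), ".mlp.fc2": (None, 1)})
#        model = GSPMDParallel(base_model, megatron_rule, mesh_shape=(4, 2))
```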
Alternatives
The implementation above requires a `sharding_rule_func` (that takes a submodule object and its name) to decide how to decorate its parameters, inputs, or return values using the `xs.mark_sharding` API. While this should be sufficient to implement nearly all SPMD use cases, it is hard to later inspect the sharding annotations in a `GSPMDParallel` instance, such as printing a list of the tensors annotated by `xs.mark_sharding` and their partitioning details. It relies on the user to keep track of what is sharded by `sharding_rule_func`.
An alternative way to implement this `GSPMDParallel` class is to enforce a more principled way to define the sharding rule. Rather than having an arbitrary function that can do anything, one can use a name-based (string-based) sharding rule similar to the logical axis names in T5X. Under this name-based definition, a sharding rule consists of the following:
- a function/callable that extracts the tensors to be sharded (parameters, input tensors, return values) from a submodule and maps their tensor axes to logical axis names (such as `(batch, mlp_size)`), and
- a user-specified mapping rule (e.g. a list) that maps each logical axis to a TPU mesh axis, similar to those in T5X.
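A rough sketch of these two pieces (all names here, e.g. `extract_logical_axes` and `logical_axis_rules`, are hypothetical, loosely mirroring T5X):

```python
from typing import Dict, List, Optional, Tuple

import torch.nn as nn


# 1) extract the tensors to shard from a submodule and name their axes with logical axis names
def extract_logical_axes(name: str, submodule: nn.Module) -> Dict[str, Tuple[str, ...]]:
    if name.endswith(".mlp.fc1") and isinstance(submodule, nn.Linear):
        return {"weight": ("mlp", "embed"), "bias": ("mlp",)}
    return {}  # leave other submodules unannotated (replicated)


# 2) map each logical axis name to a TPU mesh axis (None means replicated), as in T5X
logical_axis_rules: List[Tuple[str, Optional[str]]] = [
    ("batch", "data"),
    ("mlp", "model"),
    ("embed", None),
]
```

Because the annotations are declared as data rather than arbitrary code, a `GSPMDParallel` instance could record them and later print which tensors are sharded and how.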
An orthogonal and complementary way to simplify the sharding rule is to use the Named Tensors API to give each tensor axis a name (a string). If users implement both their base model (to be wrapped by `GSPMDParallel`) and their `sharding_rule_func`, it would be easier for them to use named tensors in the base model implementation and refer to those axis names in the `sharding_rule_func` implementation. However, this approach cannot be applied to existing models (e.g. those in torchvision or Hugging Face) that don't use named tensors, so named tensors should not be a requirement of the `GSPMDParallel` class.
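For illustration only, a sketch of how named tensor axes could be translated into an `xs.mark_sharding` partition spec (`spec_from_names` and the axis-to-mesh-dim mapping are hypothetical):

```python
from typing import Dict, Optional, Tuple

import torch


def spec_from_names(tensor: torch.Tensor,
                    axis_to_mesh_dim: Dict[str, Optional[int]]) -> Tuple[Optional[int], ...]:
    """Build an `xs.mark_sharding` partition spec from a named tensor's axis names."""
    return tuple(axis_to_mesh_dim.get(axis) for axis in tensor.names)


x = torch.randn(8, 196, 3072, names=("batch", "seq", "mlp"))
spec_from_names(x, {"batch": 0, "mlp": 1})  # -> (0, None, 1)
```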
Additional context
A related problem is how to save and load an SPMD partitioned model's parameters and optimizer state dicts, especially in those cases where the full model cannot fit into a single TPU VM's host memory. This part requires a mechanism to save and load checkpoints in a distributed manner without consolidating them onto a single host.
- We can enable distributed handling of an SPMD model's `state_dict()` by building an API that allows saving and loading a sharded partition (rather than the consolidated version) of an `XLAShardedTensor` instance; a rough sketch is given below.
- We can also make it compatible with the tensorstore package so that it can be used in PT/XLA.
- We can consider making the sequential serialization API `xser.save` work with `XLAShardedTensor`.
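As a very rough sketch of the distributed save path (heavily hedged: `get_local_shard` below is a hypothetical placeholder for an accessor that would return only the host-local portion of an `XLAShardedTensor`; it does not exist today), each host would persist only its own shards under a per-ordinal file name:

```python
import torch
import torch_xla.core.xla_model as xm


def get_local_shard(tensor):
    # Hypothetical: return only this host's shard of an XLAShardedTensor.
    # A stand-in is used here (moving the full tensor to CPU); a real
    # implementation would extract the local shard without consolidation.
    return tensor.detach().cpu()


def save_sharded_state_dict(state_dict, path_prefix: str):
    """Each host saves its own (unconsolidated) portion of the state dict."""
    ordinal = xm.get_ordinal()  # this host/process's index
    local_state = {k: get_local_shard(v) for k, v in state_dict.items()}
    torch.save(local_state, f"{path_prefix}.rank{ordinal}.pt")
```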
On our end (FAIR), we are happy to work on a prototype implementation of the `GSPMDParallel` class above and first try it out in our internal use cases. We can submit a PR once we have a mature implementation.