
[RFC] A high-level GSPMD API in PT/XLA (based on xs.mark_sharding) #3755

Open
@ronghanghu

Description


This RFC proposes a high-level API for GSPMD through a wrapper class and a partitioning rule function, based on xs.mark_sharding.

Motivation

GSPMD is a powerful approach for model sharding and parallelization, and is enabled in PyTorch/XLA by the xs.mark_sharding API in #3476 and #3684. However, there are a few limitations when directly using this API to build SPMD programs:

  • To build an SPMD program (e.g. for a transformer model) directly with this API, one has to edit the PyTorch model and insert xs.mark_sharding into the initialization and forward code of each layer. This forces the model implementation to be aware of the SPMD sharding strategies, makes the model code device- and sharding-dependent and incompatible with native PyTorch, and makes it hard to switch to other sharding strategies (e.g. from Megatron to Optimus) or other parallelism APIs (e.g. the FSDP wrapper).
  • The requirement of adding xs.mark_sharding to a model implementation (e.g. its forward code) makes it hard or impossible to run SPMD when users cannot easily edit the model code (e.g. when a model comes from an external library), which largely limits its applicability. For example, if one wants to apply SPMD to a torchvision model or a Hugging Face transformer, it is infeasible to submit a pull request to these libraries just to add xs.mark_sharding to their model implementations.
  • SPMD applications often require annotating many tensors, sometimes in submodules several levels below the base model (e.g. every transformer MLP nn.Linear layer in a T5 encoder-decoder model for Megatron-style sharding), so building SPMD programs directly on top of xs.mark_sharding is inconvenient and error-prone.

Ideally, one would like to be able to take any existing PyTorch models (e.g. BERT from Hugging Face), and apply a specific sharding strategy to it (e.g. Megatron), without changing the model code (in Hugging Face library code in this case).

In this RFC, we advocate for a high-level SPMD API in PyTorch/XLA (built upon xs.mark_sharding) that does not require users to rewrite their model code to be aware of the SPMD sharding partitions. Specifically, we propose to parallelize a model with SPMD by first building the base model and then wrapping it with a wrapper class GSPMDParallel, similar to how DDP or FSDP is applied to an existing model.

The proposed GSPMDParallel class takes as input

  1. a given PyTorch module (the base model, an nn.Module instance), and
  2. a user-specified function/callable sharding_rule_func to define the sharding rules, and
  3. a mesh_shape tuple to define the TPU mesh.

Since the tensors we want to shard in SPMD are usually either the parameters or the input/output tensors of a submodule (e.g. an nn.Linear) in the network, the GSPMDParallel class recursively visits all submodules of the wrapped model and applies the sharding rules to each submodule's parameters, inputs, and outputs. The sharded HLO graph is built when forward (and backward) is called on the GSPMDParallel instance.

In this way, the base model code does not need to change to use SPMD -- the model is wrapped and partitioned post hoc, so existing implementations from PyTorch libraries such as timm, torchvision, and Hugging Face can be used directly for SPMD parallelism without rewriting those libraries.
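For concreteness, a minimal usage sketch follows; GSPMDParallel and example_sharding_rule_func are the prototypes defined later in this RFC, while the timm model name and the 2x4 (data, model) mesh are assumptions for illustration only:

import timm
import torch
import torch_xla.core.xla_model as xm

# build an unmodified model from an existing library (timm here) and move it to the XLA device
model = timm.create_model("vit_base_patch16_224", pretrained=False)
model = model.to(xm.xla_device())

# mesh_shape = (data, model), e.g. 8 devices split into 2-way data x 4-way model parallelism
mesh_shape = (2, 4)
sharded_model = GSPMDParallel(model, example_sharding_rule_func, mesh_shape)

# the sharded HLO graph is built when forward/backward run through the wrapper
images = torch.randn(8, 3, 224, 224, device=xm.xla_device())
loss = sharded_model(images).sum()
loss.backward()
xm.mark_step()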

Proposed Implementation

Building upon the xs.mark_sharding API, the proposed implementation (prototype) of the high-level SPMD API consists of a GSPMDParallel class that recursively applies a user-defined function/callable sharding_rule_func to the submodules of an input module, based on each submodule's name and the submodule instance itself.

A prototype is as follows:

from copy import deepcopy
from typing import Callable, Tuple, Union

import torch.nn as nn


class GSPMDParallel(nn.Module):
    """Recursively apply a `sharding_rule_func` to the submodules of an input `module`."""

    def __init__(self, module: nn.Module, sharding_rule_func: Callable,
                 mesh_shape: Tuple[Union[int, None], ...]):
        super().__init__()

        # apply the SPMD sharding rule to the base model
        sharded_module = self.apply_sharding(module, sharding_rule_func, mesh_shape)
        self.module = sharded_module
        self.sharding_rule_func = sharding_rule_func
        self.mesh_shape = mesh_shape

    def apply_sharding(self, module, sharding_rule_func, mesh_shape):
        # make a copy so that the original `module` is kept unmodified
        sharded_module = deepcopy(module)

        # recursively apply the sharding rule to all submodules of the copy
        # (iterate over the copy, so the annotations land on the module that `forward` runs)
        for name, m in sharded_module.named_modules():
            sharding_rule_func(name, m, mesh_shape)

        return sharded_module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

A user-defined sharding rule function/callable decides whether and how to apply SPMD sharding annotations to an nn.Module's parameters, inputs, and outputs via xs.mark_sharding. For example, the following sharding rule function applies Megatron-style sharding to the MLP layers in a timm Vision Transformer (ViT).

from types import MethodType
from typing import Tuple, Union

import torch.nn as nn
import torch_xla.experimental.xla_sharding as xs


def example_sharding_rule_func(name: str, submodule: nn.Module,
                               mesh_shape: Tuple[Union[int, None], ...]):
    """Apply Megatron sharding to the MLP layers in a timm ViT, assuming a (data, model) `mesh_shape` like T5X."""

    # timm ViT submodule names look like "blocks.<i>.mlp.fc1"
    if name.endswith("mlp.fc1"):
        assert isinstance(submodule, nn.Linear)
        # shard the 1st MLP layer's weight param (mlp_dim, hidden_size)
        submodule.weight = xs.mark_sharding(submodule.weight, mesh_shape, (1, None))
        # shard the 1st MLP layer's bias param (mlp_dim,)
        submodule.bias = xs.mark_sharding(submodule.bias, mesh_shape, (1,))

        # shard the 1st MLP layer's output (batch_size, seq_length, mlp_dim) by patching its forward
        # TODO: change to decorators on `forward` to get cleaner code
        submodule._orig_forward = submodule.forward
        def _new_forward(m, x):
            return xs.mark_sharding(m._orig_forward(x), mesh_shape, (0, None, 1))
        submodule.forward = MethodType(_new_forward, submodule)

    elif name.endswith("mlp.fc2"):
        assert isinstance(submodule, nn.Linear)
        # shard the 2nd MLP layer's weight param (hidden_size, mlp_dim)
        submodule.weight = xs.mark_sharding(submodule.weight, mesh_shape, (None, 1))

In the example above, it is straightforward to shard a submodule's parameter tensors, but rather hacky to shard its input and output tensors (by manually patching the forward method). We can switch to decorators or a small wrapper helper to get cleaner code.
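For instance, a small helper along the following lines (hypothetical, not an existing PT/XLA API; the xs import path is the one assumed throughout this RFC) could hide the forward patching:

import torch_xla.experimental.xla_sharding as xs


def shard_output(submodule, mesh_shape, partition_spec):
    """Hypothetical helper: wrap `submodule.forward` so that its output is
    annotated with `xs.mark_sharding`, instead of patching `forward` by hand."""
    orig_forward = submodule.forward

    def wrapped_forward(*args, **kwargs):
        return xs.mark_sharding(orig_forward(*args, **kwargs), mesh_shape, partition_spec)

    submodule.forward = wrapped_forward
    return submodule


# with this helper, the manual patching in the rule above becomes:
#   shard_output(submodule, mesh_shape, (0, None, 1))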

We could also provide easier ways to build sharding_rule_func or to use it in GSPMDParallel. For example, we can make sharding_rule_func a callable class instead of a plain function, and provide a class structure that makes sharding rule callables easy to build and compose.
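A hedged sketch of one such callable class follows; the suffix-to-partition-spec table format is an assumption rather than a settled design, and the class also records what it annotated, which helps with the inspectability concern raised under Alternatives below:

from typing import Dict, Tuple, Union

import torch.nn as nn
import torch_xla.experimental.xla_sharding as xs

PartitionSpec = Tuple[Union[int, None], ...]


class SuffixShardingRule:
    """Hypothetical callable sharding rule: match submodule names against name
    suffixes and annotate the listed parameters with the given partition specs."""

    def __init__(self, param_rules: Dict[str, Dict[str, PartitionSpec]]):
        # e.g. {"mlp.fc1": {"weight": (1, None), "bias": (1,)},
        #       "mlp.fc2": {"weight": (None, 1)}}
        self.param_rules = param_rules
        self.annotated = []  # record of (module name, param name, partition spec)

    def __call__(self, name: str, submodule: nn.Module, mesh_shape):
        for suffix, specs in self.param_rules.items():
            if name.endswith(suffix):
                for param_name, spec in specs.items():
                    param = getattr(submodule, param_name)
                    # same reassignment pattern as in the prototype rule above
                    setattr(submodule, param_name, xs.mark_sharding(param, mesh_shape, spec))
                    self.annotated.append((name, param_name, spec))


# usage: GSPMDParallel(model, SuffixShardingRule({"mlp.fc1": {"weight": (1, None)}}), mesh_shape)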

Alternatives

The implementation above requires a sharding_rule_func (taking a submodule object and its name) to decide how to annotate its parameters, inputs, or return values using the xs.mark_sharding API. While this should be sufficient to implement nearly all SPMD use cases, it makes it hard to later inspect the sharding annotations in a GSPMDParallel instance, such as printing a list of the tensors annotated by xs.mark_sharding and their partitioning details; it relies on the user to keep track of what sharding_rule_func has sharded.

An alternative way to implement the GSPMDParallel class is to enforce a more principled definition of the sharding rule. Rather than an arbitrary function that can do anything, one can use a name-based (string-based) sharding rule similar to the logical axis names in T5X. Under this name-based definition, a sharding rule consists of the following (a sketch is given after the list):

  1. a function/callable to extract those tensors to be sharded (param, input tensor, return values) from a submodule and map their tensor axes to logical axis names (such as (batch, mlp_size)).

  2. a user-specified mapping rule (e.g. a list) to map a logical axis to a TPU mesh axis, similar to those in T5X.
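A minimal sketch of this name-based approach, under assumed formats for the logical-axis extraction function and the T5X-style rule list (none of these names are settled API):

from typing import Dict, Optional, Tuple

import torch.nn as nn


# 1. a function mapping the tensors to shard in a submodule to logical axis names
def mlp_fc1_logical_axes(submodule: nn.Linear) -> Dict[str, Tuple[str, ...]]:
    return {
        "weight": ("mlp", "embed"),            # (mlp_dim, hidden_size)
        "bias": ("mlp",),                      # (mlp_dim,)
        "output": ("batch", "length", "mlp"),  # (batch_size, seq_length, mlp_dim)
    }


# 2. T5X-style rules mapping each logical axis to a mesh axis (None = replicated),
#    for a (data, model) mesh_shape
logical_axis_rules = [("batch", "data"), ("mlp", "model"), ("embed", None), ("length", None)]
mesh_axis_index = {"data": 0, "model": 1}


def to_partition_spec(logical_axes: Tuple[str, ...]) -> Tuple[Optional[int], ...]:
    rules = dict(logical_axis_rules)
    return tuple(
        mesh_axis_index[rules[a]] if rules.get(a) is not None else None for a in logical_axes)


# to_partition_spec(("batch", "length", "mlp")) == (0, None, 1), matching the example rule above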

An orthogonal and complementary way to simplify the sharding rule is to use the Named Tensors API to give each tensor axis a name (a string). If users implement both their base model (to be wrapped by GSPMDParallel) and their sharding_rule_func, it would be easier to use named tensors in the base model implementation and refer to those axis names in the sharding_rule_func implementation. However, this approach cannot be applied to existing models (e.g. those in torchvision or Hugging Face) that don't use named tensors, so named tensors should not be a requirement of the GSPMDParallel class.
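As a small illustration (assuming PyTorch's prototype named tensor API is available, and that the resulting positional partition spec is then passed to xs.mark_sharding):

import torch

# in the base model, axes carry string names via the (prototype) named tensor API
x = torch.randn(8, 196, 3072).refine_names("batch", "length", "mlp")

# a sharding rule can then look at `x.names` instead of hard-coding axis positions,
# assuming a (data, model) mesh where batch maps to axis 0 and mlp to axis 1
name_to_mesh_axis = {"batch": 0, "mlp": 1}
partition_spec = tuple(name_to_mesh_axis.get(n) for n in x.names)  # (0, None, 1)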

Additional context

A related problem is how to save and load an SPMD partitioned model's parameters and optimizer state dicts, especially in those cases where the full model cannot fit into a single TPU VM's host memory. This part requires a mechanism to save and load checkpoints in a distributed manner without consolidating them onto a single host.

  • We can enable distributed handling of an SPMD model's state_dict() by building an API that saves and loads a sharded partition (rather than the consolidated version) of an XLAShardedTensor instance (a rough sketch of per-host saving follows this list).
  • We can also make it compatible with the tensorstore package for use in PT/XLA.
  • We can consider making the sequential serialization API xser.save work with XLAShardedTensor.
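A very rough sketch of the per-host saving direction; the local_shards() accessor on XLAShardedTensor is hypothetical and stands in for whatever shard-access API is eventually provided:

import os

import torch
import torch_xla.core.xla_model as xm


def save_sharded_state_dict(model, ckpt_dir):
    """Hypothetical sketch: each process saves only the shards it holds locally,
    so the full state_dict never needs to be consolidated on a single host."""
    local_state = {}
    for key, value in model.state_dict().items():
        if hasattr(value, "local_shards"):   # XLAShardedTensor (assumed interface)
            local_state[key] = value.local_shards()
        else:                                # replicated / unsharded tensor
            local_state[key] = value.cpu()
    torch.save(local_state, os.path.join(ckpt_dir, f"rank_{xm.get_ordinal()}.pth"))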

On our end (FAIR), we are happy to work on a prototype implementation of the GSPMDParallel class above and first try it out in our internal use cases. We can submit a PR once we have a mature implementation.

cc: @yeounoh @JackCaoG @miladm @ultrons
