Description
🚀 Feature
It would be worth providing a function that captures the model constructor (i.e., the torch ops that generate the random weights making up the parameters of a model) as one StableHLO graph, and runs that graph on accelerator devices.
Motivation
The primary motivation is to more closely match PyTorch eager UX during SPMD training.
Today, in order to initialize a large model on e.g. 256 TPUs, we randomly initialize every layer, and then send that layer to TPUs following a sharding spec:
xla/torchax/examples/train_llama_torchtitan/train_llama.py, lines 193 to 211 (at commit 8e6ca60)
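The referenced lines are not reproduced here. As a rough, hypothetical picture of the same pattern, the sketch below uses plain JAX (which torchax builds on) rather than torchax's own API. The full weight is drawn on the host with a standard normal, like torch.randn, and only then split and copied to devices. The shape and the mesh axis name "fsdp" are made up, and the leading dimension is assumed to divide evenly by the number of devices.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical layer shape; a real Llama layer is far larger.
SHAPE = (8, 16)

# 1-D mesh over every visible device, with a hypothetical axis name.
mesh = Mesh(np.array(jax.devices()), axis_names=("fsdp",))
# Split rows across the "fsdp" axis and keep columns whole.
sharding = NamedSharding(mesh, P("fsdp", None))

# Step 1: materialize the full weight on the host (standard normal, like torch.randn).
host_weight = np.random.default_rng(0).standard_normal(SHAPE).astype(np.float32)

# Step 2: slice it according to the sharding spec and copy each slice to its device.
weight = jax.device_put(host_weight, sharding)

print(weight.sharding)
```

In this pattern every weight exists in full on the host before it is sharded, and its distribution is whatever the initialization loop hard-codes, not what the module's own constructor would have chosen.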
This works but has some drawbacks:
- We're initializing the weights with torch.randn, but eager PyTorch initializes the weights with a variety of different distributions. When I tested training a Llama model with randn (Gaussian-distributed) weights, the loss at step 0 was 10x larger than what eager PyTorch gives us to start with.
- We could probably work around this in the near term by having the user specify a dictionary of module: initializer_fn mappings. But that's more code on top of eager PyTorch, and a cost that users pay without corresponding gains. In comparison, SPMD sharding annotations are a cost that users pay to get automatic collectives and sharding propagation.
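To give a sense of how far torch.randn can be from PyTorch's defaults, the arithmetic below covers only one case: nn.Linear's default weight init, which is kaiming_uniform_ with a = sqrt(5). Many models, Llama included, override this with their own init, and the fan_in values are illustrative. The sketch compares standard deviations only. It does not by itself account for the loss gap reported above.

```python
import math


def linear_default_weight_std(fan_in: int) -> float:
    """Standard deviation of nn.Linear's default weight initialization.

    nn.Linear calls kaiming_uniform_(weight, a=math.sqrt(5)), which samples
    U(-bound, bound) with bound = gain * sqrt(3 / fan_in) and
    gain = sqrt(2 / (1 + a**2)) = sqrt(1 / 3), so bound = 1 / sqrt(fan_in).
    """
    a = math.sqrt(5)
    gain = math.sqrt(2.0 / (1.0 + a * a))
    bound = gain * math.sqrt(3.0 / fan_in)
    # A uniform distribution on [-b, b] has standard deviation b / sqrt(3).
    return bound / math.sqrt(3.0)


RANDN_STD = 1.0  # torch.randn samples from N(0, 1)

for fan_in in (1024, 4096, 16384):
    std = linear_default_weight_std(fan_in)
    print(f"fan_in={fan_in}: default std={std:.5f}, randn is {RANDN_STD / std:.0f}x wider")
```

Because the default std is 1 / sqrt(3 * fan_in), the gap grows with layer width. At fan_in = 4096, randn weights are roughly 110x wider than the default.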
Pitch
In PyTorch/XLA, I could do this:
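```python
import torch_xla
import torch_xla.distributed.spmd as xs

# Model is the user's own nn.Module; its constructor ops are recorded lazily on the XLA device.
with torch_xla.device():
    model = Model()
# Annotate how a weight is sharded (mesh and partition spec elided).
xs.mark_sharding(model.some_weight, ...)
# Compile and execute the captured initialization graph.
torch_xla.sync()
```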
The above will compile down to a graph that outputs a bunch of model weights, and the outputs (weights) carry sharding annotations. When this graph is executed, each TPU initializes only the shard of each model weight that it is responsible for. Note that in this case the only extra cost (besides the torch_xla.sync()) is the mark_sharding call for SPMD sharding annotations, reflecting a principle of "only pay for what you use".
We propose looking into the feasibility of this sort of feature in torchax. For example, could we run the model constructor under some sort of compile function or jit context manager, where all the model weights are tracers, and lower them into StableHLO?
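Since torchax already lowers PyTorch ops through JAX, a plain-JAX analogue gives a rough idea of what such a feature might produce. This is a sketch in JAX, not an existing torchax API. init_params, the "fsdp" axis name, and the shapes are hypothetical. Tracing the initializer under jax.jit yields one program, and out_shardings attaches a sharding to every output weight.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def init_params(key):
    """Hypothetical 'model constructor': each weight uses its own initializer."""
    k1, k2 = jax.random.split(key)
    return {
        # Uniform init, similar in spirit to nn.Linear's default.
        "w1": jax.random.uniform(k1, (8, 16), minval=-0.25, maxval=0.25),
        # Small-std normal init, as many transformer recipes use.
        "w2": 0.02 * jax.random.normal(k2, (16, 4)),
        "b2": jnp.zeros((4,)),
    }


mesh = Mesh(np.array(jax.devices()), axis_names=("fsdp",))
out_shardings = {
    "w1": NamedSharding(mesh, P("fsdp", None)),
    "w2": NamedSharding(mesh, P("fsdp", None)),
    "b2": NamedSharding(mesh, P()),  # replicated on every device
}

# Tracing turns the whole constructor into one program; out_shardings tells the
# compiler how each output weight is laid out across devices.
sharded_init = jax.jit(init_params, out_shardings=out_shardings)

key = jax.random.PRNGKey(0)
stablehlo_text = sharded_init.lower(key).as_text()  # the lowered StableHLO module as text
params = sharded_init(key)

print(stablehlo_text[:500])
print({name: p.sharding for name, p in params.items()})
```

Each device ends up holding only its shard of every output. Whether the random-number generation itself is also split per device, rather than computed in full and then sliced, depends on the JAX version and its jax_threefry_partitionable setting.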
Alternatives
An alternative is to maintain a comprehensive module: initializer_fn mapping inside torchax itself. However, that still won't cover cases where the user added custom initialization logic in their model constructor.
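A minimal sketch of what such a mapping might look like in eager PyTorch follows. INITIALIZERS, apply_initializers, and Toy are hypothetical names, not torchax APIs. The two entries copy PyTorch's defaults for nn.Linear and nn.Embedding. The sketch also shows the gap described above: the LayerNorm, which has no entry, and the parameter set in Toy's own constructor can only be reported, not reproduced.

```python
import math

import torch
import torch.nn as nn


def _init_linear(m: nn.Linear) -> None:
    # Mirrors nn.Linear.reset_parameters() in eager PyTorch.
    nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5))
    if m.bias is not None:
        bound = 1.0 / math.sqrt(m.in_features)
        nn.init.uniform_(m.bias, -bound, bound)


def _init_embedding(m: nn.Embedding) -> None:
    # nn.Embedding's default is N(0, 1).
    nn.init.normal_(m.weight)


# Hypothetical module-type -> initializer_fn registry.
INITIALIZERS = {nn.Linear: _init_linear, nn.Embedding: _init_embedding}


@torch.no_grad()
def apply_initializers(model: nn.Module) -> list:
    """Run the registered initializer on each module; return types that have no entry."""
    missing = []
    for module in model.modules():
        init_fn = INITIALIZERS.get(type(module))
        if init_fn is not None:
            init_fn(module)
        elif any(True for _ in module.parameters(recurse=False)):
            # This module owns parameters, but the registry doesn't know how to init them.
            missing.append(type(module).__name__)
    return missing


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)
        self.proj = nn.Linear(4, 4)
        self.norm = nn.LayerNorm(4)  # not in the registry
        # Custom constructor logic that the registry can't see.
        self.scale = nn.Parameter(torch.full((4,), 0.5))


missing = apply_initializers(Toy())
print(missing)
```

Every new module type or custom init would need its own entry, so the registry grows with the model zoo.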
Additional context
We found this while bringing up Llama 3.1 405B training, cf. https://github.com/AI-Hypercomputer/torchprime/pull/25/files
cc @qihqi