
[torchax] jit compile the model constructor #8635

Open
@tengyifei

Description

🚀 Feature

It's worth providing a function that captures the model constructor (i.e., the torch ops that generate the random weights making up a model's parameters) as a single StableHLO graph and runs that graph on accelerator devices.

Motivation

The primary motivation is to more closely match PyTorch eager UX during SPMD training.

Today, to initialize a large model on, e.g., 256 TPUs, we randomly initialize every weight on the host CPU and then transfer it to the TPUs according to a sharding spec:

```python
import jax
import torch
import torchax
from jax.sharding import NamedSharding, PartitionSpec as P


def create_sharded_weights(model, mesh, sharding_map):
  res = {}
  env = torchax.default_env()
  for name, weight_meta in model.state_dict().items():
    # _process_sharding_name (defined elsewhere, not shown) maps a
    # parameter name to a key in sharding_map.
    sharding_spec = sharding_map.get(_process_sharding_name(name))
    if sharding_spec is None:
      print('Skipping weight:', name)
      continue
    sharding = NamedSharding(mesh, P(*sharding_spec))
    # Materialize the full weight on the host CPU with a Gaussian init.
    with jax.default_device(jax.devices('cpu')[0]):
      weight_torch = torch.randn(weight_meta.shape, dtype=weight_meta.dtype)
      weight_jax = env.to_xla(weight_torch).jax()
    # Each device copies only the slice it owns into a globally sharded array.
    res[name] = env.j2t_iso(jax.make_array_from_callback(
        weight_jax.shape, sharding, lambda index: weight_jax[index]))
  return res
```

This works but has some drawbacks:

  • We initialize every weight with torch.randn, but eager PyTorch initializes weights with a variety of distributions that depend on the module type. When I trained a Llama model with randn (Gaussian-distributed) weights, the loss at step 0 was 10x larger than the starting loss in eager PyTorch.
  • In the near term, we could probably work around this by having the user specify a dictionary of module: initializer_fn mappings. But that is extra code compared to eager PyTorch: a cost users pay with no corresponding gain. By contrast, SPMD sharding annotations are a cost users pay to get automatic collectives and sharding propagation.
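To make the first drawback concrete, here is a minimal sketch comparing `torch.randn` with the default initialization of `nn.Linear`, which draws weights uniformly from `[-1/sqrt(in_features), 1/sqrt(in_features)]`. The layer size is an arbitrary choice for illustration, not taken from the Llama experiment.

```python
import math

import torch
from torch import nn

torch.manual_seed(0)

in_features, out_features = 1024, 1024
layer = nn.Linear(in_features, out_features)  # default eager initialization

# Default nn.Linear weights are uniform in [-1/sqrt(fan_in), 1/sqrt(fan_in)],
# so their standard deviation is bound / sqrt(3).
bound = 1 / math.sqrt(in_features)
default_std = layer.weight.std().item()

# The current workaround replaces every weight with a standard normal sample.
randn_weight = torch.randn(out_features, in_features)
randn_std = randn_weight.std().item()

print(f"default nn.Linear std: {default_std:.4f} (bound {bound:.4f})")
print(f"torch.randn std:       {randn_std:.4f}")
print(f"ratio: {randn_std / default_std:.0f}x")
```

A per-layer scale gap of this size compounds across a deep network, which plausibly contributes to the much larger step-0 loss described above.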

Pitch

In PyTorch/XLA, I could do this:

```python
import torch_xla
import torch_xla.distributed.spmd as xs

with torch_xla.device():
  model = Model()
  xs.mark_sharding(model.some_weight, ...)
torch_xla.sync()
```

The above compiles down to a graph that outputs all of the model weights, with sharding annotations attached to the outputs. When this graph is executed, each TPU initializes only the shard of each weight that it is responsible for. Note that the only extra cost here (besides torch_xla.sync()) is the mark_sharding call for SPMD sharding annotations, in keeping with the principle of "only pay for what you use".
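The "each device initializes only its own shard" behavior can be expressed directly in JAX, the layer torchax lowers to. This is a minimal sketch, not torchax code: it assumes a 1-D mesh over all available devices, a hypothetical axis name `'fsdp'`, an arbitrary `0.02` scale, and a single weight whose row count is divisible by the device count.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Let XLA partition random-number generation across shards
# (already the default in recent JAX releases).
jax.config.update("jax_threefry_partitionable", True)

# 1-D mesh over every available device (also works with a single CPU device).
mesh = Mesh(np.array(jax.devices()), axis_names=("fsdp",))  # hypothetical axis name
sharding = NamedSharding(mesh, P("fsdp", None))  # shard rows across devices

rows = 8 * len(jax.devices())
cols = 16


def init_weight(key):
  # Compiled as one program; because the output sharding is fixed, the SPMD
  # partitioner can have each device generate only the rows it owns.
  return jax.random.normal(key, (rows, cols), dtype=jnp.float32) * 0.02


init_sharded = jax.jit(init_weight, out_shardings=sharding)
weight = init_sharded(jax.random.PRNGKey(0))

print(weight.shape, weight.sharding)
for shard in weight.addressable_shards:
  print(shard.device, shard.data.shape)
```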

We propose investigating the feasibility of a similar feature in torchax. For example, could we run the model constructor under some compile function or jit context manager in which all model weights are tracers, and then lower them into StableHLO?
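For intuition on what a lowered constructor could look like, the sketch below writes a toy constructor directly in JAX, mixing initializers the way an eager PyTorch model does (standard normal for an embedding, uniform for a linear layer), and lowers it to a single StableHLO module with sharded outputs. It is not a torchax API; the parameter names, sizes, and shardings are assumptions for illustration. A torchax version would still need to trace the PyTorch constructor into an equivalent JAX function, which is the open question here.

```python
import math

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

VOCAB, HIDDEN = 256, 64  # toy sizes, chosen for illustration


def construct_params(key):
  """Toy 'model constructor': returns every parameter in one pytree."""
  k_emb, k_w, k_b = jax.random.split(key, 3)
  bound = 1 / math.sqrt(HIDDEN)
  return {
      # Mirrors nn.Embedding's default: standard normal.
      "embed.weight": jax.random.normal(k_emb, (VOCAB, HIDDEN)),
      # Mirrors nn.Linear's default: uniform in [-1/sqrt(fan_in), 1/sqrt(fan_in)].
      "proj.weight": jax.random.uniform(
          k_w, (HIDDEN, HIDDEN), minval=-bound, maxval=bound),
      "proj.bias": jax.random.uniform(
          k_b, (HIDDEN,), minval=-bound, maxval=bound),
  }


mesh = Mesh(np.array(jax.devices()), axis_names=("fsdp",))  # hypothetical axis
out_shardings = {
    "embed.weight": NamedSharding(mesh, P("fsdp", None)),
    "proj.weight": NamedSharding(mesh, P(None, "fsdp")),
    "proj.bias": NamedSharding(mesh, P()),  # replicated
}

init_fn = jax.jit(construct_params, out_shardings=out_shardings)

# One StableHLO (MLIR) module covering the whole constructor.
stablehlo_text = init_fn.lower(jax.random.PRNGKey(0)).as_text()
print(stablehlo_text[:500])

# Executing it materializes every parameter directly in its sharded layout.
params = init_fn(jax.random.PRNGKey(0))
```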

Alternatives

An alternative is to maintain a comprehensive module: initializer_fn mapping inside torchax itself. However, that still wouldn't cover cases where the user adds custom initialization logic in the model constructor.
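The limitation of the mapping-based alternative can be sketched in plain PyTorch. The `INITIALIZERS` table and `reinit_with_table` helper below are hypothetical (nothing like them exists in torchax today); because they re-create parameters purely by module type, any custom initialization done in the model's own constructor is silently overwritten.

```python
import torch
from torch import nn

# Hypothetical module-type -> initializer table, the kind torchax would
# have to maintain under this alternative.
INITIALIZERS = {
    nn.Linear: lambda m: m.reset_parameters(),
    nn.Embedding: lambda m: m.reset_parameters(),
}


def reinit_with_table(model):
  """Re-initialize every module whose type appears in INITIALIZERS."""
  for module in model.modules():
    init_fn = INITIALIZERS.get(type(module))
    if init_fn is not None:
      init_fn(module)
  return model


class TinyModel(nn.Module):
  def __init__(self):
    super().__init__()
    self.proj = nn.Linear(8, 8)
    # Custom init in the constructor, e.g. a zero-initialized output layer.
    nn.init.zeros_(self.proj.weight)


model = TinyModel()
print(model.proj.weight.abs().sum().item())  # 0.0: the constructor's custom init
reinit_with_table(model)
print(model.proj.weight.abs().sum().item())  # nonzero: the table overwrote it
```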

Additional context

We ran into this while bringing up Llama 3.1 405B training; cf. https://github.com/AI-Hypercomputer/torchprime/pull/25/files

cc @qihqi
