Description
Context
Objective
In this RFC I will lay out the roadmap for making eager mode the default computation mode for PyTorch/XLA users and for enabling graph compilation on top of it.
Background
PyTorch/XLA has been using tracing mode as the default mode since the project started. All of the torch operations a user issues are accumulated in the background and sent to XLA for compilation and execution upon a mark_step call.
The upside of this approach is that users don’t need to change their model code much; as long as the user adds a mark_step in the right place, everything should just work. However, user feedback over the last couple of years shows that this approach creates too much confusion and frustration. Both PyTorch and JAX took the approach of using eager mode as the default and asking users to specify the function they want to compile. PyTorch/XLA should take the same approach.
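For illustration, here is a minimal sketch of that tracing-mode workflow (the tensor ops are just placeholders): nothing runs on the device until mark_step is called.

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
x = torch.randn(4, 4, device=device)
y = x + 1        # recorded in the lazy graph, nothing runs on the device yet
loss = y.sum()   # still only tracing
xm.mark_step()   # the accumulated graph is compiled and executed here
```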
Design
Eager mode
There is no real eager mode on TPU. However, we can fake eager mode by compiling and executing each torch operation as it is issued. Such a mode already exists today as a debug-only mode; it was contributed by @aws-rhsoln two years ago in #3306. The work here is to add better API-level wrapping and make sure this mode works with other features (debug output, SPMD, multi-process, etc.). This approach was far too slow a couple of years ago because XRT could not execute small computations efficiently, but with PJRT the performance is much better.
The whole eager mode still builds on top of the existing lazy tensor framework, but this becomes invisible to the user. A couple of things we need to do to accommodate eager mode are:
- Increase the compilation cache size from 1024 to 2048, since each torch op will also reside in the compilation cache, and each torch op needs to be recompiled for every new input shape.
- Increase the maximum number of executions we can queue at the PJRT level, since we will now execute many more small computations.
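To make the proposed user experience concrete, here is a minimal sketch (using the eager_mode API proposed below) of what running individual ops eagerly would look like; each op is compiled, or served from the per-op compilation cache, and executed immediately.

```python
import torch
import torch_xla

# Proposed API: turn the (fake) eager mode on for this process.
torch_xla.experimental.eager_mode(True)

x = torch.randn(4, 4, device=torch_xla.device())
y = x @ x        # compiled (or fetched from the per-op cache) and executed now
print(y.sum())   # no pending lazy graph left to flush with mark_step
```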
Compile
For the compile part we currently have two options: lazy tensor and torch dynamo (torch.compile).
For lazy tensor based compile I will add a new API:
torch_xla.experimental.compile(fn) -> compiled_fn
which under the hood just enables tracing mode while running the function and executes the traced graph before returning. Here is the implementation. For torch.compile we can just use the existing API.
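As a rough illustration only (the actual implementation is in the PR linked above), torch_xla.experimental.compile can be thought of as a wrapper that temporarily switches back to tracing mode, runs the function, and cuts the traced graph before returning:

```python
import functools

import torch_xla
import torch_xla.core.xla_model as xm


def compile(fn):
  """Sketch of the proposed torch_xla.experimental.compile."""

  @functools.wraps(fn)
  def compiled_fn(*args, **kwargs):
    # Switch back to tracing (lazy) mode so the ops inside fn are accumulated
    # into one graph instead of being executed eagerly one by one.
    torch_xla.experimental.eager_mode(False)
    try:
      result = fn(*args, **kwargs)
      # Cut the traced graph and hand it to XLA for compilation/execution
      # before returning to the caller.
      xm.mark_step()
    finally:
      torch_xla.experimental.eager_mode(True)
    return result

  return compiled_fn
```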
Example UX
```python
import torch_xla

torch_xla.experimental.eager_mode(True)


class TrainDecoderOnlyBase:

  def __init__(self):
    self.train_loader = MyLoader()
    self.model = DecoderOnlyModel(self.config).to(torch_xla.device())
    # if run with dynamo, use
    # self.step_fn = torch.compile(self.step_fn, backend="openxla")
    self.step_fn = torch_xla.experimental.compile(self.step_fn)

  def step_fn(self, data, target):
    self.optimizer.zero_grad()
    logits = self.model(data)
    loss = self.loss_fn(
        logits.view(-1, self.config.vocab_size), target.view(-1))
    loss.backward()
    self.run_optimizer()
    return loss

  def start_training(self):
    for step, (data, target) in enumerate(self.train_loader):
      loss = self.step_fn(data, target)


if __name__ == '__main__':
  base = TrainDecoderOnlyBase()
  base.start_training()
```
Note that the two changes a user needs to make are to enable eager mode with torch_xla.experimental.eager_mode(True) and to compile the step function with torch_xla.experimental.compile or torch.compile.
Users can also choose to run the whole model in eager mode.
Why
IMO using tracing mode as the default has a couple of very significant drawbacks:
- Users are often confused about when the framework is tracing and when the framework is executing.
- Users don’t know where to add the mark_step.
- Random Python code (data preprocessing, for example) often generates small pending executions that leak into the main graph (the step function) and cause recompilation; recompiling the whole graph is usually very expensive (see the sketch after this list).
- It is hard to debug when and why recompilation happens.
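To make the leaked-pending-execution point concrete, here is a hypothetical tracing-mode snippet (step_fn and the preprocessing op are placeholders, not from the RFC):

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()


def step_fn(data):
  # stand-in for the real training step
  return (data * 2).sum()


data = torch.randn(8, 128, device=device)
# "Random python code": a preprocessing op issued outside the step function.
# Under tracing mode it is just another pending lazy op...
data = data / data.abs().max()
loss = step_fn(data)
# ...so the graph cut here contains preprocessing + step fused together.
# If the preprocessing varies between steps, the whole fused graph recompiles.
xm.mark_step()
```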
Both JAX and PyTorch took the approach of asking users to explicitly mark the region/function they want compiled. This methodology seems to be well received by users who want a compilation mode. I think this proposal will make for a much better usability story by:
- Allowing users to use eager mode for initial model development and compile mode to scale up. This also significantly lowers the bar for a normal PyTorch user to onboard onto PyTorch/XLA.
- Reducing the number of recompilations generated by non-core model code, since that code will be executed eagerly.
- Making graph recompilation easier to debug, since only the compiled_fn should generate graphs.
Benchmark
I ran a 2-layer decoder-only model training (it is pretty much just a llama2) with fake data on a single chip of a v4-8 for 300 steps. This is not a very scientific benchmark, so take it with a grain of salt.
| | token/s |
| --- | --- |
| Tracing mode (baseline) | 147 |
| Eager mode | 65 |
| Eager + torch_xla compile | 147 |
Eager mode achieves roughly 45% of the performance of the fully compiled model for this decoder-only model. The trainer I used for testing can be found here and here.
Work Breakdown
- Enable eager mode (done)
- Enable torch_xla.experimental.compile (done)
- Support eager mode with torch.compile (pr)
- Test eager mode with SPMD (1 day)
- Test eager mode with multi-process distributed (2 days)
- Test eager mode with Pallas kernels (1 day)
- Test eager mode with the rest of the PyTorch/XLA features (1 week)
- Enable more tests with eager mode (1 week)
- Enable more tests with eager mode + torch_xla.compile (1 week)
- Update examples and README to use eager + torch_xla.compile (1 week)
- Integrate eager mode with HF (2 weeks to 1 month)
- Integrate eager mode with Torch Lightning (2 weeks to 1 month)
Timeline
2.4 release -> experimental
2.5 release -> beta
2.6 release -> enabled by default
cc @ezyang @bdhirsh @wconstab @baoleai @amithrm @jeffhataws @albanD @gkroiz @Liyang90