🚀 Feature
Currently, if we wrap a model with `torch_xla2.compile` and try to train it using a traditional torch training loop, similar to https://github.com/pytorch/xla/blob/master/experimental/torch_xla2/examples/basic_training.py, it doesn't work.

The reason is that the compile wrapper `JittableModule` will eventually call a `jax.jit`'d callable, and torch doesn't know how to compute the gradient of that callable.

The solution is to create a `torch.autograd.Function` subclass on the fly, with `backward` defined to call `jax.vjp`, similar to this tutorial: https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html

The result would be that a model wrapped with `torch_xla2.compile` remains trainable.
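For illustration, here is a minimal sketch of what such an on-the-fly `torch.autograd.Function` could look like. This is not the actual `JittableModule` implementation; the `t2j`/`j2t` converters below go through numpy purely for illustration, and the wrapped `jax_linear` function is a toy stand-in for a compiled model.

```python
import jax
import jax.numpy as jnp
import numpy as np
import torch


def t2j(t: torch.Tensor) -> jax.Array:
    # torch -> jax via numpy (illustrative only; torch_xla2 has its own interop)
    return jnp.asarray(t.detach().cpu().numpy())


def j2t(a: jax.Array) -> torch.Tensor:
    # jax -> torch via numpy
    return torch.from_numpy(np.array(a))


def make_autograd_fn(jax_fn):
    """Wrap a jax-jittable function so torch autograd can differentiate it."""
    jitted = jax.jit(jax_fn)

    class JaxFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, *torch_args):
            jax_args = [t2j(t) for t in torch_args]
            # jax.vjp returns the primal output plus a function that maps
            # output cotangents back to input cotangents.
            out, vjp_fn = jax.vjp(jitted, *jax_args)
            ctx.vjp_fn = vjp_fn
            return j2t(out)

        @staticmethod
        def backward(ctx, grad_out):
            grads = ctx.vjp_fn(t2j(grad_out))
            return tuple(j2t(g) for g in grads)

    return JaxFunction.apply


# Usage: a toy jax "model" driven by a normal torch training step.
def jax_linear(w, x):
    return jnp.tanh(x @ w)

fn = make_autograd_fn(jax_linear)
w = torch.randn(4, 2, requires_grad=True)
x = torch.randn(8, 4)
loss = fn(w, x).sum()
loss.backward()
print(w.grad.shape)  # torch.Size([4, 2])
```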
Motivation
Having both the forward and backward passes compiled with `jax.jit` makes them faster to run.