[torchax][RFC] Anchor on the device API everywhere #8638
Description
🚀 Feature
Today we have two related concepts in `torchax`:
- Environment
- The "jax" device

In particular, the user needs to know both to access JAX (e.g. TPU) features.
This RFC proposes refactoring the API so that the user only needs to know the "jax" device in order to operate on TPUs.
Motivation
The device is a well-known concept in PyTorch: tutorials talk about `tensor.cpu()` and `tensor.cuda()`. It's commonly understood that if I create a tensor without the device argument, that tensor lives on the default device (usually CPU). If I call `.cuda()`, the tensor is moved to the GPU.
Since people primarily choose `torchax` to be able to use the TPU (or access the XLA GPU backend), it makes sense to present this functionality as a PyTorch device. By the principle of symmetry, it's natural to introduce `jax` counterparts for various `cuda` APIs, where applicable. This gives people a clear mental model of when they are or aren't using JAX.
We look at a few examples (all of these assume we import `torchax`):

- I can call `torch.cuda.current_device()` to get the index of the current CUDA device.
  - I can also call `torch.jax.current_device()` to get the index of the current JAX (XLA) device.
- I can call `torch.cuda.is_available()` to check if CUDA support is available.
  - I can call `torch.jax.is_available()` to check if the JAX backend is available.
- I can run `torch.randn(1, 2, device='cuda')` to generate a random tensor on the CUDA device.
  - ❌ If I run `torch.randn(1, 2, device='jax')`, it fails with a confusing dispatcher error. [1]
- I can run `torch.set_default_device('cuda')` to make all subsequent tensors live on the CUDA device.
  - ❌ If I run `torch.set_default_device('jax')` and then create some tensors, that fails with another confusing error. [2]
This RFC proposes that we change `torchax` to close behavior divergences such as the two ❌ cases above. In the limit, using eager torchax should feel identical to using any other PyTorch backend.
Pitch
Always call `enable_globally()`
We're pretty close to closing the gaps above. If I run `torchax.enable_globally()` after importing `torchax`, then `torch.randn(1, 2, device='jax')` works, and the error after `torch.set_default_device('jax')` looks like a fixable bug. I propose we go one step further and call `enable_globally()` automatically; we should also fix the default device behavior.
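For reference, here is roughly what that usage looks like today; under this proposal the explicit `enable_globally()` call would no longer be necessary:

```python
import torch
import torchax

torchax.enable_globally()  # this RFC proposes making this call automatic

# After enabling, the 'jax' device string works with the regular torch APIs.
x = torch.randn(1, 2, device='jax')
y = x + x  # dispatched through the JAX/XLA backend
```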
Always keep the torchax modes activated
Today the environment object is what activates the `XLAFunctionMode` and `XLADispatchMode` that intercept PyTorch operations. However, these modes are an implementation detail of how `torchax` supports the JAX device. It should be possible to keep `XLAFunctionMode` and `XLADispatchMode` activated in the mode stack at all times, without changing the behavior of non-JAX tensors. This is akin to how PyTorch already keeps a few modes such as `FuncTorchVmapMode` and `FuncTorchDynamicLayerFrontMode` in the stack most of the time. For testing purposes, it could be useful to temporarily disable `XLAFunctionMode` and `XLADispatchMode`, but that should be an internal API that users don't need to know about.
As a pressure test, we could try running some subset of the PyTorch tests with `XLA{Function,Dispatch}Mode` in the mode stack and make sure those don't fail. That would ensure that even if the user imports `torchax`, the behavior of their CPU tensors doesn't change.
This suggests we need to decouple `XLAFunctionMode` and `XLADispatchMode` from the environment. For example, perhaps they could be relocated to a `torchax._internal.XLAModes` context manager.
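Here is a minimal sketch of what such a decoupled helper could look like, assuming the mode instances behave as ordinary context managers (as `TorchFunctionMode`/`TorchDispatchMode` subclasses do). The class name follows the suggestion above; everything else is hypothetical:

```python
from contextlib import ExitStack


class XLAModes:
    """Hypothetical torchax._internal context manager that keeps both
    torchax modes on the mode stack, independent of any environment."""

    def __init__(self, function_mode, dispatch_mode):
        # `function_mode` / `dispatch_mode` stand in for instances of
        # XLAFunctionMode and XLADispatchMode.
        self._function_mode = function_mode
        self._dispatch_mode = dispatch_mode
        self._stack = ExitStack()

    def __enter__(self):
        # Push both modes; in the proposal this happens once and they stay
        # active for the lifetime of the process.
        self._stack.enter_context(self._function_mode)
        self._stack.enter_context(self._dispatch_mode)
        return self

    def __exit__(self, *exc):
        # Popping the modes would be an internal-only escape hatch, e.g. for
        # running upstream PyTorch tests without torchax interception.
        return self._stack.__exit__(*exc)
```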
Configuration context managers
The environment object also holds certain configuration (e.g. optimize for performance or accuracy). As a user it's useful to change these settings sometimes. We can keep them in the environment and always provide sensible defaults in the default environment. We could also support a stack of environments via context managers, where configuration at the top of the stack takes precedence. That's a useful way to locally change some config and have it revert to the previous values when leaving the scope.
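As an illustration of the stacking behavior (the names below are made up; only the scoping semantics matter):

```python
import contextlib

# The default environment provides a sensible baseline configuration.
_config_stack = [{"optimize_for": "performance"}]


def current_config():
    # The environment at the top of the stack wins.
    return _config_stack[-1]


@contextlib.contextmanager
def configure(**overrides):
    # Push a new environment whose settings shadow the ones below it,
    # and restore the previous values when the scope exits.
    _config_stack.append({**current_config(), **overrides})
    try:
        yield
    finally:
        _config_stack.pop()


with configure(optimize_for="accuracy"):
    assert current_config()["optimize_for"] == "accuracy"
assert current_config()["optimize_for"] == "performance"
```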
RNG seed
The environment object also holds a seed for the pseudo random number generator. That should probably change as part of solving #8636.
Alternatives
An alternative is to do nothing and stick to the status quo.
Another follow-up is to see what `torch` changes we need in order to remove friction when using the JAX device. For example, today if I write `tensor.jax()` in the hope of moving the tensor to the JAX device, the Python type checker complains that `jax()` is not a known method on `Tensor`, unlike `cuda()`.
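To make that friction concrete (purely illustrative; `jax()` is not part of the `Tensor` type stubs today):

```python
import torch

t = torch.randn(2, 2)
t_gpu = t.cuda()  # a known method on Tensor, type-checks fine
t_tpu = t.jax()   # a static type checker flags this: "jax" is not a known attribute of "Tensor"
```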
Additional context
Anecdotally, some people have asked why they have to create an environment in order to use a torchax tensor, and they didn't understand the error message when the environment was missing.