
[torchax][RFC] Anchor on the device API everywhere #8638

Description

🚀 Feature

Today we have two related concepts in torchax:

  • Environment
  • The "jax" device

In particular, the user needs to know both in order to access JAX (e.g. TPU) features.

This RFC proposes refactoring the API so that the user only needs to know about the "jax" device in order to operate on TPUs.

Motivation

The device is a well-known concept in PyTorch: tutorials talk about tensor.cpu() and tensor.cuda(). It's commonly understood that if I create a tensor without the device argument, then that tensor lives on the default device (usually the CPU), and if I call .cuda(), the tensor is moved to the GPU.

Since people primarily choose torchax to be able to use the TPU (or access the XLA GPU backend), it makes sense to present this functionality as a PyTorch device. By the principle of symmetry, it's natural to introduce jax counterparts for various cuda APIs, where applicable. This gives people a clear mental model of when they are or aren't using JAX.

Let's look at a few examples (all of which assume we have imported torchax); a code sketch of the proposed symmetry follows the list:

  • I can call torch.cuda.current_device() to get the index of the current CUDA device.

    • I can also call torch.jax.current_device() to get the index of the current JAX (XLA) device.
  • I can call torch.cuda.is_available() to check if CUDA support is available.

    • I can call torch.jax.is_available() to check if the JAX backend is available.
  • I can run torch.randn(1, 2, device='cuda') to generate a random number using the CUDA device.

    • ❌ If I run torch.randn(1, 2, device='jax'), it fails with a confusing dispatcher error [1].
  • I can run torch.set_default_device('cuda') to make all subsequently created tensors live on the CUDA device.

    • ❌ If I run torch.set_default_device('jax') and then create a tensor, it fails with another confusing error [2].
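
To make the proposed symmetry concrete, here is a small sketch. The torch.cuda calls are existing APIs; everything under torch.jax is an illustrative name from this RFC and does not exist yet, which is why those lines are commented out.

```python
import torch
import torchax  # today this import alone does not activate the "jax" device

# Existing, well-understood CUDA idioms:
if torch.cuda.is_available():
    idx = torch.cuda.current_device()
    x = torch.randn(1, 2, device='cuda')

# Proposed JAX counterparts (illustrative only; none of these exist yet):
# if torch.jax.is_available():
#     idx = torch.jax.current_device()
#     y = torch.randn(1, 2, device='jax')
#     torch.set_default_device('jax')
#     z = torch.randn(3)  # would land on the JAX device by default
```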

This RFC proposes changing torchax to close behavior divergences such as the two ❌ bullet points above. In the limit, using eager torchax should feel identical to using any other PyTorch backend.

Pitch

Always call enable_globally()

We're pretty close to closing the gaps above. If I run torchax.enable_globally() after importing torchax, then torch.randn(1, 2, device='jax') works, and the error after torch.set_default_device('jax') looks like a fixable bug. I propose we go one step further: automatically call enable_globally() when torchax is imported, and also fix the default-device behavior.
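
A minimal sketch of the current workaround versus the proposed behavior; the explicit enable_globally() call is exactly what this RFC would make automatic.

```python
import torch
import torchax

# Today: an explicit activation step is needed before the "jax" device works.
torchax.enable_globally()
x = torch.randn(1, 2, device='jax')  # works only after enable_globally()

# Proposed: `import torchax` would perform this activation itself, so the
# explicit call above (and the failure mode without it) goes away.
```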

Always keep the torchax modes activated

Today the environment object is what activates the XLAFunctionMode and XLADispatchMode that intercept PyTorch operations. However, these modes are an implementation detail of how torchax supports the JAX device. It should be possible to always keep the XLAFunctionMode and XLADispatchMode activated in the mode stack, without changing the behavior of non-JAX tensors. This is akin to how PyTorch already keeps a few modes such as FuncTorchVmapMode and FuncTorchDynamicLayerFrontMode in the stack most of the time. For testing purposes, it could be useful to temporarily disable the XLAFunctionMode and XLADispatchMode, but that should be an internal API that users don't know about.
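
As a rough illustration of "always on, but a no-op for non-JAX tensors", here is a minimal sketch built on the public TorchFunctionMode API. It is not the actual torchax implementation, the helper name is hypothetical, and the JAX dispatch branch is elided.

```python
import torch
from torch.overrides import TorchFunctionMode

def _involves_jax_tensor(args) -> bool:
    # Hypothetical helper: does any positional argument live on the "jax" device?
    return any(isinstance(a, torch.Tensor) and a.device.type == 'jax' for a in args)

class AlwaysOnJaxFunctionMode(TorchFunctionMode):
    """Sketch: intercept only operations that touch "jax" tensors."""

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if _involves_jax_tensor(args):
            pass  # in torchax this branch would route to the JAX-backed implementation
        # Non-JAX tensors fall through unchanged, so CPU/CUDA behavior is
        # identical whether or not the mode is on the stack.
        return func(*args, **kwargs)

# In the proposal this mode (and its dispatch-mode sibling) would be entered
# once at import time and never exited; a `with` block is used here only to
# keep the sketch self-contained.
with AlwaysOnJaxFunctionMode():
    a = torch.ones(2) + 1  # plain CPU tensor math is unaffected
```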

As a pressure test, we could try running some subset of the PyTorch test suite with XLA{Function,Dispatch}Mode in the mode stack and make sure those tests don't fail. That's to ensure that even if the user imports torchax, their CPU tensor behavior doesn't change.

This suggests we need to decouple the XLAFunctionMode and XLADispatchMode from the environment. For example, perhaps those could be relocated to a torchax._internal.XLAModes context manager.
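
The internal escape hatch could be as simple as a context manager that owns both modes. The class below is a hypothetical sketch of that shape; the name XLAModes comes from this RFC, not from existing code.

```python
import contextlib

class XLAModes(contextlib.AbstractContextManager):
    """Hypothetical internal-only wrapper that enters/exits both torchax modes."""

    def __init__(self, function_mode, dispatch_mode):
        self._function_mode = function_mode
        self._dispatch_mode = dispatch_mode

    def __enter__(self):
        self._function_mode.__enter__()
        self._dispatch_mode.__enter__()
        return self

    def __exit__(self, *exc):
        # Exit in reverse order of entry.
        self._dispatch_mode.__exit__(*exc)
        self._function_mode.__exit__(*exc)
        return False
```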

Configuration context managers

The environment object also holds certain configuration settings (e.g. whether to optimize for performance or for accuracy). As a user, it's sometimes useful to change these settings. We can keep them in the environment and always provide sensible defaults in the default environment. We could also support a stack of environments via context managers, where configuration at the top of the stack takes precedence. That's a useful way to locally change some config and have it revert to the previous values when leaving the scope.
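
A minimal sketch of such a configuration stack; the setting name prefer_accuracy and the config() helper are illustrative, not existing torchax API. The innermost override wins, and the previous values are restored when the scope exits.

```python
import contextlib
import dataclasses

@dataclasses.dataclass(frozen=True)
class Config:
    prefer_accuracy: bool = False  # assumed example setting

_config_stack = [Config()]  # the default environment sits at the bottom

def current_config() -> Config:
    return _config_stack[-1]

@contextlib.contextmanager
def config(**overrides):
    # Push a copy of the current config with the given overrides applied.
    _config_stack.append(dataclasses.replace(current_config(), **overrides))
    try:
        yield current_config()
    finally:
        _config_stack.pop()

# Usage: locally prefer accuracy; the setting reverts when the scope exits.
with config(prefer_accuracy=True):
    assert current_config().prefer_accuracy
assert not current_config().prefer_accuracy
```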

RNG seed

The environment object also holds a seed for the pseudo random number generator. That should probably change as part of solving #8636.

Alternatives

An alternative is to do nothing and stick to the status quo.

Another follow-up is to figure out what torch changes we need in order to remove friction when using the JAX device. For example, today if I write tensor.jax() hoping to move the tensor to the JAX device, the Python type checker complains that jax() is not a known method on Tensor, unlike cuda().
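
A small sketch of that friction, assuming torchax has been imported and enabled so the "jax" device exists at runtime; the monkeypatch below is hypothetical and not something torchax does today.

```python
import torch
import torchax

torchax.enable_globally()

# Hypothetical runtime patch; assumes .to('jax') works once the device exists.
torch.Tensor.jax = lambda self: self.to('jax')

t = torch.randn(2, 2)
t_jax = t.jax()  # runs fine, but a static type checker (mypy/pyright) still
                 # flags .jax() as unknown, because Tensor's type stubs only
                 # declare methods like .cpu() and .cuda()
```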

Additional context

Anecdotally, some people have asked why they have to create an environment in order to use a torchax tensor, and they didn't understand the error message when the environment was missing.
