
XLA 2.7 releases don't work properly with the upstream torch 2.7 #8626

Open
@hosseinsarshar

Description

🐛 Bug

I tested the torch_xla 2.7 nightly builds against the upstream torch 2.7 nightlies across many date combinations, and every combination failed. For example, with this install:

pip install -U --pre torch==2.7.0.dev20250124+cpu --index-url https://download.pytorch.org/whl/nightly/cpu

pip install 'torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp310-cp310-linux_x86_64.whl' \
  -f https://storage.googleapis.com/libtpu-releases/index.html \
  -f https://storage.googleapis.com/libtpu-wheels/index.html

pip install -U torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
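Because `import torch_xla` itself fails in this setup, one way to report exactly which nightlies ended up installed is to read the package metadata without importing the packages. The sketch below uses only the standard library (`importlib.metadata`); the helper names `installed_version` and `dev_date` are hypothetical, and the distribution names `libtpu`, `jax`, and `jaxlib` are assumptions (a missing one simply prints `None`).

```python
import re
from importlib import metadata


def installed_version(dist_name):
    """Return the installed version string of a distribution, or None if absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def dev_date(version):
    """Extract the YYYYMMDD date from a nightly version such as '2.7.0.dev20250124+cpu'."""
    if version is None:
        return None
    match = re.search(r"\.dev(\d{8})", version)
    return match.group(1) if match else None


if __name__ == "__main__":
    # Distribution names below are assumptions; adjust to what pip actually installed.
    for dist in ("torch", "torch_xla", "libtpu", "jax", "jaxlib"):
        version = installed_version(dist)
        print(f"{dist:10} {version!s:30} nightly date: {dev_date(version)}")
```

Reading metadata instead of importing avoids tripping over the broken `_XLAC` extension while still showing whether the torch and torch_xla nightly dates line up.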

I get this error when importing torch_xla:

$ python
Python 3.10.16 (main, Dec 11 2024, 16:24:50) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch_xla
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/hosseins/miniconda3/envs/test-new-nightly/lib/python3.10/site-packages/torch_xla/__init__.py", line 20, in <module>
    import _XLAC
ImportError: /home/hosseins/miniconda3/envs/test-new-nightly/lib/python3.10/site-packages/_XLAC.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN5torch4lazy13MetricFnValueEd
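The mangled symbol `_ZN5torch4lazy13MetricFnValueEd` demangles (for example with `c++filt`) to `torch::lazy::MetricFnValue(double)`, a function from torch's lazy-tensor library. A plausible reading, not confirmed in this report, is that the `_XLAC` extension was built against a libtorch that still exported this function while the installed torch nightly does not. The sketch below, with hypothetical helper names, pulls the missing symbol out of such an `ImportError` so it can be demangled and compared across torch versions.

```python
import re


def missing_symbol(error_message):
    """Return the mangled symbol named in an 'undefined symbol' ImportError, if any."""
    match = re.search(r"undefined symbol: (\S+)", error_message)
    return match.group(1) if match else None


def try_import_torch_xla():
    """Attempt the import; return the missing symbol on a linker-style failure, else None."""
    try:
        import torch_xla  # noqa: F401
    except ImportError as exc:
        # ModuleNotFoundError is also an ImportError; it carries no symbol, so this yields None.
        return missing_symbol(str(exc))
    return None


if __name__ == "__main__":
    print(try_import_torch_xla())
```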

But once I downgrade the upstream torch to 2.6.0.dev20241216 (the latest nightly that worked) and pair it with the torch_xla 2.7.0.dev20250119 wheel, torch_xla works as expected:

pip install -U --pre torch==2.6.0.dev20241216+cpu --index-url https://download.pytorch.org/whl/nightly/cpu

pip install 'torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250119-cp310-cp310-linux_x86_64.whl' \
  -f https://storage.googleapis.com/libtpu-releases/index.html \
  -f https://storage.googleapis.com/libtpu-wheels/index.html

pip install -U torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

and the import succeeds:

$ python
Python 3.10.16 (main, Dec 11 2024, 16:24:50) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import xla
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ModuleNotFoundError: No module named 'xla'
>>> import torch_xla
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
>>> torch_xla.devices()
[device(type='xla', index=0), device(type='xla', index=1), device(type='xla', index=2), device(type='xla', index=3), device(type='xla', index=4), device(type='xla', index=5), device(type='xla', index=6), device(type='xla', index=7)]
