Skip to content

Commit 82d3504

Browse files
authored
openxla pin update to 20250131 (#8621)
1 parent 8572e75 commit 82d3504

File tree

3 files changed

+12
-12
lines changed

3 files changed

+12
-12
lines changed

WORKSPACE

+3-7
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,9 @@ new_local_repository(
4444

4545
############################# OpenXLA Setup ###############################
4646

47-
# To update OpenXLA to a new revision,
48-
# a) update URL and strip_prefix to the new git commit hash
49-
# b) get the sha256 hash of the commit by running:
50-
# curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
51-
# and update the sha256 with the result.
52-
53-
xla_hash = '8d06f3680ad046ea44f8e7159f52c728bb66c069'
47+
# To build PyTorch/XLA with OpenXLA to a new revision, update following xla_hash to
48+
# the openxla git commit hash.
49+
xla_hash = '6e91ff19dad528ab7d2025a9bb46150618a3bc7d'
5450

5551
http_archive(
5652
name = "xla",

setup.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,14 @@
6565
base_dir = os.path.dirname(os.path.abspath(__file__))
6666

6767
USE_NIGHTLY = True # whether to use nightly or stable libtpu and jax
68-
_date = '20250113'
69-
_libtpu_version = f'0.0.8'
70-
_jax_version = f'0.4.39'
71-
_jaxlib_version = f'0.4.39'
68+
69+
_date = '20250131'
70+
71+
# Note: jax/jaxlib 20250115 build will fail. Check https://github.com/pytorch/xla/pull/8621#issuecomment-2616564634 for more details.
72+
_libtpu_version = '0.0.9'
73+
_jax_version = '0.5.1'
74+
_jaxlib_version = '0.5.1'
75+
7276
_libtpu_wheel_name = f'libtpu-{_libtpu_version}'
7377
_libtpu_storage_directory = 'libtpu-lts-releases'
7478

test/test_tpu_paged_attention_kernel.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def test_paged_attention_without_query_padding(
185185
self.assertEqual(actual_output.shape, expected_output.shape)
186186

187187
if dtype == jnp.float32:
188-
atol = 1e-2
188+
atol = 1e-1
189189
rtol = 1e-2
190190
elif dtype == jnp.bfloat16:
191191
atol = 6e-1

0 commit comments

Comments
 (0)