From 82d35047258bf9bcfb78c499dda82f9510cedb13 Mon Sep 17 00:00:00 2001
From: Pei Zhang
Date: Fri, 31 Jan 2025 12:35:08 -0800
Subject: [PATCH] openxla pin update to 20250131 (#8621)

---
 WORKSPACE                               | 10 +++-------
 setup.py                                | 12 ++++++++----
 test/test_tpu_paged_attention_kernel.py |  2 +-
 3 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/WORKSPACE b/WORKSPACE
index 44c02a0785cb..dfd9f4b3221d 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -44,13 +44,9 @@ new_local_repository(
 
 ############################# OpenXLA Setup ###############################
 
-# To update OpenXLA to a new revision,
-# a) update URL and strip_prefix to the new git commit hash
-# b) get the sha256 hash of the commit by running:
-#    curl -L https://github.com/openxla/xla/archive/<git_hash>.tar.gz | sha256sum
-#    and update the sha256 with the result.
-
-xla_hash = '8d06f3680ad046ea44f8e7159f52c728bb66c069'
+# To build PyTorch/XLA against a new OpenXLA revision, update the following
+# xla_hash to the desired OpenXLA git commit hash.
+xla_hash = '6e91ff19dad528ab7d2025a9bb46150618a3bc7d'
 
 http_archive(
     name = "xla",
diff --git a/setup.py b/setup.py
index 453a73172cf9..66e7f1511831 100644
--- a/setup.py
+++ b/setup.py
@@ -65,10 +65,14 @@ base_dir = os.path.dirname(os.path.abspath(__file__))
 
 USE_NIGHTLY = True  # whether to use nightly or stable libtpu and jax
 
-_date = '20250113'
-_libtpu_version = f'0.0.8'
-_jax_version = f'0.4.39'
-_jaxlib_version = f'0.4.39'
+
+_date = '20250131'
+
+# Note: the jax/jaxlib 20250115 builds fail. See https://github.com/pytorch/xla/pull/8621#issuecomment-2616564634 for details.
+_libtpu_version = '0.0.9'
+_jax_version = '0.5.1'
+_jaxlib_version = '0.5.1'
+
 _libtpu_wheel_name = f'libtpu-{_libtpu_version}'
 _libtpu_storage_directory = 'libtpu-lts-releases'
 
diff --git a/test/test_tpu_paged_attention_kernel.py b/test/test_tpu_paged_attention_kernel.py
index 8749334b7112..21b799ecf405 100644
--- a/test/test_tpu_paged_attention_kernel.py
+++ b/test/test_tpu_paged_attention_kernel.py
@@ -185,7 +185,7 @@ def test_paged_attention_without_query_padding(
     self.assertEqual(actual_output.shape, expected_output.shape)
 
     if dtype == jnp.float32:
-      atol = 1e-2
+      atol = 1e-1
       rtol = 1e-2
     elif dtype == jnp.bfloat16:
      atol = 6e-1
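
The removed WORKSPACE comment described computing the archive checksum with
curl piped into sha256sum. For reference, the same check can be done in
Python; a minimal sketch, assuming network access to github.com (the pinned
commit hash is the one introduced by this patch):

    # Compute the sha256 of an OpenXLA commit tarball, mirroring the
    # `curl ... | sha256sum` step from the removed WORKSPACE comment.
    import hashlib
    import urllib.request

    xla_hash = '6e91ff19dad528ab7d2025a9bb46150618a3bc7d'
    url = f'https://github.com/openxla/xla/archive/{xla_hash}.tar.gz'

    sha = hashlib.sha256()
    with urllib.request.urlopen(url) as resp:
        # Stream in 1 MiB chunks so the tarball is never fully in memory.
        for chunk in iter(lambda: resp.read(1 << 20), b''):
            sha.update(chunk)
    print(sha.hexdigest())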
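
The setup.py hunk pins _date alongside the libtpu/jax/jaxlib versions, with
USE_NIGHTLY selecting nightly builds. The logic that combines these values
sits outside this hunk; as a purely hypothetical illustration of how a
nightly pin of this shape is commonly composed (the .dev<date> suffix is an
assumption, not the actual setup.py code):

    # Hypothetical sketch only: append a .dev<date> suffix to the base
    # versions when nightly builds are requested.
    USE_NIGHTLY = True
    _date = '20250131'
    _libtpu_version = '0.0.9'
    _jax_version = '0.5.1'

    if USE_NIGHTLY:
        _libtpu_version += f'.dev{_date}'
        _jax_version += f'.dev{_date}'

    print(_libtpu_version)  # 0.0.9.dev20250131
    print(_jax_version)     # 0.5.1.dev20250131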
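
The test change relaxes the float32 absolute tolerance from 1e-2 to 1e-1
while keeping rtol at 1e-2. A minimal sketch of an elementwise check of that
shape, assuming a NumPy-style allclose comparison (the assertion helper the
test actually uses lies outside this hunk, and the arrays here are
illustrative, not kernel outputs):

    import jax.numpy as jnp
    import numpy as np

    expected = jnp.array([1.00, 2.00, 3.00], dtype=jnp.float32)
    actual = jnp.array([1.04, 2.03, 2.95], dtype=jnp.float32)

    # Passes under the relaxed tolerance: |1.04 - 1.00| = 0.04 is within
    # atol + rtol * |expected| = 0.1 + 0.01 * 1.0 = 0.11.
    np.testing.assert_allclose(actual, expected, atol=1e-1, rtol=1e-2)
    # With the previous atol=1e-2 the same comparison would raise, since
    # 0.04 exceeds 0.01 + 0.01 * 1.0 = 0.02.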