Skip to content

Commit

Permalink
openxla pin update to 20250131 (#8621)
Browse files Browse the repository at this point in the history
  • Loading branch information
zpcore authored Jan 31, 2025
1 parent 8572e75 commit 82d3504
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 12 deletions.
10 changes: 3 additions & 7 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,9 @@ new_local_repository(

############################# OpenXLA Setup ###############################

# To update OpenXLA to a new revision,
# a) update URL and strip_prefix to the new git commit hash
# b) get the sha256 hash of the commit by running:
# curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
# and update the sha256 with the result.

xla_hash = '8d06f3680ad046ea44f8e7159f52c728bb66c069'
# To build PyTorch/XLA with OpenXLA to a new revision, update following xla_hash to
# the openxla git commit hash.
xla_hash = '6e91ff19dad528ab7d2025a9bb46150618a3bc7d'

http_archive(
name = "xla",
Expand Down
12 changes: 8 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,14 @@
base_dir = os.path.dirname(os.path.abspath(__file__))

USE_NIGHTLY = True # whether to use nightly or stable libtpu and jax
_date = '20250113'
_libtpu_version = f'0.0.8'
_jax_version = f'0.4.39'
_jaxlib_version = f'0.4.39'

_date = '20250131'

# Note: jax/jaxlib 20250115 build will fail. Check https://github.com/pytorch/xla/pull/8621#issuecomment-2616564634 for more details.
_libtpu_version = '0.0.9'
_jax_version = '0.5.1'
_jaxlib_version = '0.5.1'

_libtpu_wheel_name = f'libtpu-{_libtpu_version}'
_libtpu_storage_directory = 'libtpu-lts-releases'

Expand Down
2 changes: 1 addition & 1 deletion test/test_tpu_paged_attention_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def test_paged_attention_without_query_padding(
self.assertEqual(actual_output.shape, expected_output.shape)

if dtype == jnp.float32:
atol = 1e-2
atol = 1e-1
rtol = 1e-2
elif dtype == jnp.bfloat16:
atol = 6e-1
Expand Down

0 comments on commit 82d3504

Please sign in to comment.