File tree 3 files changed +12
-12
lines changed
3 files changed +12
-12
lines changed Original file line number Diff line number Diff line change @@ -44,13 +44,9 @@ new_local_repository(
44
44
45
45
############################# OpenXLA Setup ###############################
46
46
47
- # To update OpenXLA to a new revision,
48
- # a) update URL and strip_prefix to the new git commit hash
49
- # b) get the sha256 hash of the commit by running:
50
- # curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
51
- # and update the sha256 with the result.
52
-
53
- xla_hash = '8d06f3680ad046ea44f8e7159f52c728bb66c069'
47
+ # To build PyTorch/XLA with OpenXLA to a new revision, update following xla_hash to
48
+ # the openxla git commit hash.
49
+ xla_hash = '6e91ff19dad528ab7d2025a9bb46150618a3bc7d'
54
50
55
51
http_archive (
56
52
name = "xla" ,
Original file line number Diff line number Diff line change 65
65
base_dir = os .path .dirname (os .path .abspath (__file__ ))
66
66
67
67
USE_NIGHTLY = True # whether to use nightly or stable libtpu and jax
68
- _date = '20250113'
69
- _libtpu_version = f'0.0.8'
70
- _jax_version = f'0.4.39'
71
- _jaxlib_version = f'0.4.39'
68
+
69
+ _date = '20250131'
70
+
71
+ # Note: jax/jaxlib 20250115 build will fail. Check https://github.com/pytorch/xla/pull/8621#issuecomment-2616564634 for more details.
72
+ _libtpu_version = '0.0.9'
73
+ _jax_version = '0.5.1'
74
+ _jaxlib_version = '0.5.1'
75
+
72
76
_libtpu_wheel_name = f'libtpu-{ _libtpu_version } '
73
77
_libtpu_storage_directory = 'libtpu-lts-releases'
74
78
Original file line number Diff line number Diff line change @@ -185,7 +185,7 @@ def test_paged_attention_without_query_padding(
185
185
self .assertEqual (actual_output .shape , expected_output .shape )
186
186
187
187
if dtype == jnp .float32 :
188
- atol = 1e-2
188
+ atol = 1e-1
189
189
rtol = 1e-2
190
190
elif dtype == jnp .bfloat16 :
191
191
atol = 6e-1
You can’t perform that action at this time.
0 commit comments