Skip to content

Commit b500fff

Browse files
jburnimtensorflower-gardener
authored andcommitted
Configure tests to explicitly use jax_threefry_partitionable=False.
See jax-ml/jax#18480 PiperOrigin-RevId: 746378505
1 parent 6781129 commit b500fff

File tree

7 files changed

+11
-0
lines changed

7 files changed

+11
-0
lines changed

discussion/pathfinder/pathfinder_test.py

+2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from discussion.pathfinder import pathfinder
2020
import tensorflow_probability.substrates.jax as tfp
2121

22+
jax.config.update("jax_threefry_partitionable", False)
23+
2224
tfd = tfp.distributions
2325

2426

spinoffs/autobnn/autobnn/bnn_tree_test.py

+2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
from autobnn import kernels
2323
from absl.testing import absltest
2424

25+
jax.config.update('jax_threefry_partitionable', False)
26+
2527

2628
class TreeTest(parameterized.TestCase):
2729

spinoffs/autobnn/autobnn/kernels_test.py

+2
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424

2525
from absl.testing import absltest
2626

27+
jax.config.update('jax_threefry_partitionable', False)
28+
2729

2830
KERNELS = [
2931
kernels.IdentityBNN,

tensorflow_probability/python/experimental/fastgp/fast_gp_test.py

+1
Original file line numberDiff line numberDiff line change
@@ -821,4 +821,5 @@ class FastGpTestFloat64(_FastGpTest):
821821

822822
if __name__ == "__main__":
823823
config.update("jax_enable_x64", True)
824+
config.update("jax_threefry_partitionable", False)
824825
absltest.main()

tensorflow_probability/python/experimental/fastgp/fast_log_det_test.py

+2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
from tensorflow_probability.substrates import jax as tfp
2626
from tensorflow_probability.substrates.jax.internal import test_util
2727

28+
jax.config.update('jax_threefry_partitionable', False)
29+
2830
# pylint: disable=invalid-name
2931

3032

tensorflow_probability/python/experimental/fastgp/fast_mtgp_test.py

+1
Original file line numberDiff line numberDiff line change
@@ -458,4 +458,5 @@ class FastMultiTaskGpTestFloat64(_FastMultiTaskGpTest):
458458

459459
if __name__ == "__main__":
460460
config.update("jax_enable_x64", True)
461+
config.update("jax_threefry_partitionable", False)
461462
absltest.main()

tensorflow_probability/python/experimental/fastgp/preconditioners_test.py

+1
Original file line numberDiff line numberDiff line change
@@ -696,4 +696,5 @@ class PreconditionersTestFloat64(_PreconditionersTest):
696696

697697
if __name__ == "__main__":
698698
config.update("jax_enable_x64", True)
699+
config.update("jax_threefry_partitionable", False)
699700
absltest.main()

0 commit comments

Comments
 (0)