Skip to content

Commit 2ad6285

Browse files
jburnimMctxDev
authored andcommitted
Configure tests to explicitly use jax_threefry_partitionable=False.
See jax-ml/jax#18480 PiperOrigin-RevId: 746182723
1 parent 2a6919d commit 2ad6285

File tree

2 files changed

+4
-0
lines changed

2 files changed

+4
-0
lines changed

mctx/_src/tests/policies_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
from mctx._src import policies
2323
import numpy as np
2424

25+
jax.config.update("jax_threefry_partitionable", False)
26+
2527

2628
def _make_bandit_recurrent_fn(rewards, dummy_embedding=()):
2729
"""Returns a recurrent_fn with discount=0."""

mctx/_src/tests/tree_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
import mctx
2727
import numpy as np
2828

29+
jax.config.update("jax_threefry_partitionable", False)
30+
2931

3032
def _prepare_root(batch_size, num_actions):
3133
"""Returns a root consistent with the stored expected trees."""

0 commit comments

Comments
 (0)