Skip to content

Commit 0fb6bcf

Browse files
IvyZXFlax Authors
authored andcommitted
Migrate away from private jax._src.tree_util._registry
PiperOrigin-RevId: 918133843
1 parent 8485ec3 commit 0fb6bcf

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

flax/nnx/graphlib.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3627,7 +3627,7 @@ class Static(tp.Generic[A]):
36273627
class GenericPytree: ...
36283628

36293629

3630-
from jax._src.tree_util import _registry as JAX_PYTREE_REGISTRY
3630+
36313631

36323632

36333633
def is_pytree_node(
@@ -3637,7 +3637,7 @@ def is_pytree_node(
36373637
return False
36383638
elif isinstance(x, Variable):
36393639
return False
3640-
elif type(x) in JAX_PYTREE_REGISTRY:
3640+
elif jax.tree_util.is_tree_node(type(x)):
36413641
return True
36423642
elif isinstance(x, tuple):
36433643
return True

0 commit comments

Comments
 (0)