-
I am trying to convert a JAX checkpoint to PyTorch. If there is already a utility that converts JAX checkpoints to PyTorch, please send it my way so I don't need to attempt this.

First, I am trying to read a checkpoint that was saved with JAX/Flax. I need to reconstruct the original structure so that I can map it to PyTorch. However, the pytree definition is stored as a string, so I cannot use the `jax.tree_util` module to unflatten it. How can I convert `str_pytree_state` back into a `PyTreeDef`, or otherwise unflatten the following?

```python
flattened_state = ckpt["flattened_state"]
pytree_structure = ckpt["str_pytree_state"]
```

Here is a minimal example:

```python
from jax.tree_util import tree_unflatten

transformed_flat = [2.0, 4.0, 6.0]
value_tree_str = 'PyTreeDef([*, (*, *)])'
tree_unflatten(value_tree_str, transformed_flat)
```

This fails with:

```
106 def tree_unflatten(treedef: PyTreeDef, leaves: Iterable[Leaf]) -> Any:
107 """Reconstructs a pytree from the treedef and the leaves.
108
109 The inverse of :func:`tree_flatten`.
(...)
130 - :func:`jax.tree.structure`
131 """
--> 132 return treedef.unflatten(leaves)

AttributeError: 'str' object has no attribute 'unflatten'
```
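The traceback shows that `tree_unflatten` calls `.unflatten` on its first argument, so it needs a real `PyTreeDef` object, not its printed form. If the layout is known (for example, from the model definition), one workaround is to build a template pytree with the same shape, take its `PyTreeDef` with `jax.tree_util.tree_structure`, and unflatten the saved leaves against that. The sketch below assumes the `[*, (*, *)]` structure from the example; `template` is a hypothetical placeholder that must match the real checkpoint layout exactly. It then uses `tree_flatten_with_path` to give each leaf a name that could later be matched to a PyTorch `state_dict` key.

```python
import numpy as np
from jax.tree_util import keystr, tree_flatten_with_path, tree_structure, tree_unflatten

# Flat leaves as stored in the checkpoint (values from the example above).
transformed_flat = [2.0, 4.0, 6.0]

# Hypothetical template matching 'PyTreeDef([*, (*, *)])'.
# Only the container layout matters; the leaf values are ignored.
template = [0, (0, 0)]
treedef = tree_structure(template)

# The rebuilt structure prints exactly like the stored string.
assert str(treedef) == "PyTreeDef([*, (*, *)])"

# A real PyTreeDef has .unflatten, so this no longer raises AttributeError.
restored = tree_unflatten(treedef, transformed_flat)
print(restored)  # [2.0, (4.0, 6.0)]

# Name every leaf by its key path, e.g. "[1][0]", for mapping to PyTorch later.
named = {
    keystr(path): np.asarray(leaf)
    for path, leaf in tree_flatten_with_path(restored)[0]
}
print(list(named))
```

In a real Flax checkpoint the tree is typically nested dicts of arrays, so the path strings look like `['params']['Dense_0']['kernel']`. Translating those names into PyTorch parameter names is model-specific (for example, a Flax `Dense` kernel is stored as `(in_features, out_features)` while a PyTorch `nn.Linear` weight is `(out_features, in_features)`).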
-
Hello – I'm having a hard time understanding your question. Could you put together a minimal reproducible example showing the context of what you're trying to do? |
-
There is no general way to convert a string repr of a pytree object into a pytree object. You'll need to serialize and store the original object (for example, using pickle) rather than the string representation of it. |
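A minimal sketch of that suggestion: pickle the pytree itself (not `str(treedef)`) when saving, then recover both the flat leaves and a genuine `PyTreeDef` after loading. The file name `state.pkl` and the example `state` are hypothetical, and the sketch assumes the leaves are picklable; converting JAX arrays to NumPy with `np.asarray` before saving is one safe choice.

```python
import os
import pickle
import tempfile

import numpy as np
from jax.tree_util import tree_flatten, tree_unflatten

# Hypothetical state; leaves are NumPy arrays so they pickle portably.
state = [np.asarray(2.0), (np.asarray(4.0), np.asarray(6.0))]

path = os.path.join(tempfile.mkdtemp(), "state.pkl")

# Save: store the original object, not the string repr of its structure.
with open(path, "wb") as f:
    pickle.dump(state, f)

# Load: the structure travels with the object, so flattening it again
# yields a real PyTreeDef alongside the leaves.
with open(path, "rb") as f:
    loaded = pickle.load(f)

leaves, treedef = tree_flatten(loaded)
print(treedef)  # PyTreeDef([*, (*, *)])

# Leaves can be transformed and put back into the same structure.
doubled = tree_unflatten(treedef, [leaf * 2 for leaf in leaves])
```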