Hello,
While trying to use kfac_jax with JAX 0.4.13 (which I believe is supported), I had to fix several errors.
I installed commit a4531e9, which is fairly recent.

  File "/opt/miniconda3/envs/jax/lib/python3.10/site-packages/kfac_jax/_src/utils/types.py", line 27, in <module>
    DType = jax.typing.DTypeLike
AttributeError: module 'jax.typing' has no attribute 'DTypeLike'
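
One possible workaround for this first error, sketched as a hypothetical helper (`resolve_dtype_alias` is not part of kfac_jax or JAX): keep `jax.typing.DTypeLike` when the installed JAX defines it, and otherwise fall back to `numpy.typing.DTypeLike`, or to `typing.Any` if NumPy's typing module is unavailable. Treating NumPy's alias as a stand-in is an assumption that only makes sense because `DType` is used for annotations. The module is passed in as an argument so the fallback logic can be checked without JAX installed.

```python
from typing import Any


def resolve_dtype_alias(jax_typing_module: Any) -> Any:
    """Return a DType alias usable with both older and newer JAX releases.

    Hypothetical helper: prefers jax.typing.DTypeLike, falls back to
    numpy.typing.DTypeLike on releases that lack it (e.g. JAX 0.4.13,
    per this report), and finally to typing.Any.
    """
    alias = getattr(jax_typing_module, "DTypeLike", None)
    if alias is not None:
        return alias
    try:
        import numpy.typing as npt  # DTypeLike is public since NumPy 1.20
    except ImportError:
        return Any
    return npt.DTypeLike
```

With JAX installed, line 27 of `types.py` could then read `DType = resolve_dtype_alias(jax.typing)` after `import jax.typing`.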

  File "/opt/miniconda3/envs/jax/lib/python3.10/site-packages/kfac_jax/_src/curvature_blocks/curvature_block.py", line 123, in parameters_shapes
    return tuple(jax.tree.map(
  File "/opt/miniconda3/envs/jax/lib/python3.10/site-packages/jax/_src/deprecations.py", line 53, in getattr
    raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax' has no attribute 'tree'
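
For the second error, `jax.tree_util.tree_map` is a possible substitute for `jax.tree.map` on JAX releases that predate the `jax.tree` namespace. The helper below (`resolve_tree_map`, hypothetical) is a minimal sketch that prefers the newer name and falls back to the older one, assuming both accept the same `(f, tree, *rest)` arguments for this call. Other `jax.tree.*` functions, if kfac_jax uses them, would need similar handling.

```python
from typing import Any, Callable


def resolve_tree_map(jax_module: Any) -> Callable[..., Any]:
    """Return jax.tree.map when it exists, else jax.tree_util.tree_map."""
    # getattr with a default also absorbs the AttributeError that JAX's
    # module-level getattr hook raises for the missing `jax.tree` name.
    tree_ns = getattr(jax_module, "tree", None)
    tree_map = getattr(tree_ns, "map", None)
    if tree_map is not None:
        return tree_map
    # Older releases: fall back to the long-standing tree_util function.
    return jax_module.tree_util.tree_map
```

With JAX installed, this would be used as `tree_map = resolve_tree_map(jax)` followed by calls such as `tree_map(lambda x: x * 2, params)`.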

  File "/opt/miniconda3/envs/jax/lib/python3.10/site-packages/kfac_jax/_src/tracer.py", line 790, in forward
    write(eqn.outvars, tgm.eval_jaxpr_eqn(eqn, read(eqn.invars)))
  File "/opt/miniconda3/envs/jax/lib/python3.10/site-packages/kfac_jax/_src/tag_graph_matcher.py", line 68, in eval_jaxpr_eqn
    user_context = jax_extend.source_info_util.user_context
AttributeError: module 'jax.extend' has no attribute 'source_info_util'
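
The third error could be handled by probing import paths. In this sketch, both the helper `resolve_user_context` and the fallback to the private module `jax._src.source_info_util` are assumptions, not documented API: it tries the public `jax.extend` location first and then the private module, which may change between JAX releases. The importer is injectable so the lookup order can be tested without JAX.

```python
import importlib
from typing import Any, Callable

# Candidate homes for user_context, newest first. The second is a private
# JAX module; relying on it is an assumption, not a stable API.
_CANDIDATES = ("jax.extend.source_info_util", "jax._src.source_info_util")


def resolve_user_context(
    import_module: Callable[[str], Any] = importlib.import_module,
) -> Any:
    """Return source_info_util.user_context from whichever module provides it."""
    for path in _CANDIDATES:
        try:
            module = import_module(path)
        except ImportError:  # also covers ModuleNotFoundError
            continue
        user_context = getattr(module, "user_context", None)
        if user_context is not None:
            return user_context
    raise ImportError(f"user_context not found in any of {_CANDIDATES}")
```

In `tag_graph_matcher.py`, line 68 could then become `user_context = resolve_user_context()`.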