
Bug: Jax 0.4.13 support #279

@arnon-1


Hello,

While trying to use kfac_jax with JAX 0.4.13 (which I believe is a supported version), I had to fix several errors.
I installed commit a4531e9, which is fairly recent. These are the tracebacks I hit:

  File "/opt/miniconda3/envs/jax/lib/python3.10/site-packages/kfac_jax/_src/utils/types.py", line 27, in <module>
    DType = jax.typing.DTypeLike
AttributeError: module 'jax.typing' has no attribute 'DTypeLike'
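JAX 0.4.13 does not provide `jax.typing.DTypeLike`. Below is a minimal sketch of a fallback. It assumes NumPy's `numpy.typing.DTypeLike` is an acceptable stand-in for the alias defined in `types.py`. It is not necessarily how kfac_jax itself resolved this.

```python
import jax
import jax.typing
import numpy.typing as npt

# Use jax.typing.DTypeLike when this JAX version provides it; otherwise
# fall back to NumPy's equivalent alias (assumption: it is an acceptable
# substitute for how kfac_jax uses DType).
DType = getattr(jax.typing, "DTypeLike", npt.DTypeLike)
```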

  File "/opt/miniconda3/envs/jax/lib/python3.10/site-packages/kfac_jax/_src/curvature_blocks/curvature_block.py", line 123, in parameters_shapes
    return tuple(jax.tree.map(
  File "/opt/miniconda3/envs/jax/lib/python3.10/site-packages/jax/_src/deprecations.py", line 53, in getattr
    raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax' has no attribute 'tree'
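The `jax.tree` namespace (`jax.tree.map` and friends) only appears in JAX releases later than 0.4.13. In those later releases, `jax.tree_util.tree_map` is still available. A possible shim is sketched below. The `params` dict is a hypothetical example, not kfac_jax's actual parameter structure.

```python
import jax
import jax.numpy as jnp
from jax import tree_util

# On JAX 0.4.13, accessing jax.tree raises AttributeError, so hasattr()
# returns False and we fall back to jax.tree_util.tree_map.
tree_map = jax.tree.map if hasattr(jax, "tree") else tree_util.tree_map

# Hypothetical example: collect the shape of every parameter leaf.
params = {"w": jnp.ones((3, 4)), "b": jnp.zeros((4,))}
shapes = tree_map(lambda x: x.shape, params)
print(shapes)
```

Since `jax.tree_util.tree_map` works on both old and new versions, calling it directly would avoid the conditional altogether.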

  File "/opt/miniconda3/envs/jax/lib/python3.10/site-packages/kfac_jax/_src/tracer.py", line 790, in forward
    write(eqn.outvars, tgm.eval_jaxpr_eqn(eqn, read(eqn.invars)))
  File "/opt/miniconda3/envs/jax/lib/python3.10/site-packages/kfac_jax/_src/tag_graph_matcher.py", line 68, in eval_jaxpr_eqn
    user_context = jax_extend.source_info_util.user_context
AttributeError: module 'jax.extend' has no attribute 'source_info_util'
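`jax.extend.source_info_util` is likewise missing in 0.4.13. One workaround is sketched below, under the assumption that importing from JAX's private `jax._src` package is acceptable as a fallback. That module is internal and can change between releases, so this is a stopgap rather than a proper fix.

```python
# Prefer the public jax.extend.source_info_util when it exists; JAX 0.4.13
# does not expose it there, so fall back to the private internal module.
try:
    from jax.extend import source_info_util
except ImportError:
    from jax._src import source_info_util

# The attribute that tag_graph_matcher.py reads.
user_context = source_info_util.user_context
```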
