Unsolvable jax dependency #321

@adiaconu11

Following the most recent commit, the jax.tree module is now used instead of the jax.tree_* functions. This change requires jax >= 0.4.25. However, many parts of the repository still rely on old or deprecated APIs. For instance, if you install 0.4.25 you might get something like:

AttributeError: module 'jax.random' has no attribute 'KeyArray'

This is because jax.random.KeyArray was removed in jax 0.4.24, meaning that to avoid this problem you need jax <= 0.4.23. Obviously this contradicts the requirement above.
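To make the conflict concrete, here is a minimal sketch that checks a few jax release numbers against both constraints described above. The helper names (`parse`, `satisfies_both`) and the candidate list are made up for illustration and are not part of the repository.

```python
def parse(version):
    """Turn a dotted version string like '0.4.25' into a comparable tuple."""
    return tuple(int(part) for part in version.split("."))


def satisfies_both(version):
    """True only if a release has jax.tree AND still has jax.random.KeyArray."""
    v = parse(version)
    needs_tree = v >= (0, 4, 25)      # jax.tree requires jax >= 0.4.25
    has_key_array = v <= (0, 4, 23)   # KeyArray was removed in jax 0.4.24
    return needs_tree and has_key_array


# Hypothetical candidate list, chosen to straddle both bounds.
candidates = ["0.4.23", "0.4.24", "0.4.25"]
compatible = [v for v in candidates if satisfies_both(v)]
print(compatible)  # prints []
```

No release can be both at most 0.4.23 and at least 0.4.25, so the list is empty for any set of candidates.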

Lastly, there is still the issue with DeviceArray and ShardedDeviceArray. Both were replaced by plain jax.Array back in jax 0.4.0! In the current state of the repo you basically need to add lines like:

jax.interpreters.xla.DeviceArray = jax.Array

in order to even be able to import acme...
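As a stopgap, the workaround above could be wrapped in a small shim that only patches names that are actually missing. This is a rough sketch, not a proper fix for the repo: `alias_if_missing` and `apply_jax_shims` are hypothetical helpers, and aliasing `jax.random.KeyArray` to `jax.Array` is my assumption (it only makes sense where the name is used as a type annotation).

```python
import importlib


def alias_if_missing(namespace, name, replacement):
    """Set namespace.<name> = replacement only when the attribute is absent.

    Returns True if the alias was installed, False if the name already existed.
    """
    if hasattr(namespace, name):
        return False
    setattr(namespace, name, replacement)
    return True


def apply_jax_shims():
    """Re-create removed jax names as aliases of jax.Array (assumed workaround)."""
    # Imported lazily so that defining this helper does not itself require jax.
    jax = importlib.import_module("jax")
    xla = importlib.import_module("jax.interpreters.xla")
    # Assumption: KeyArray is only needed as a type annotation by the old code.
    alias_if_missing(jax.random, "KeyArray", jax.Array)
    # Same idea as the manual line above, but skipped if the name still exists.
    alias_if_missing(xla, "DeviceArray", jax.Array)
```

Calling `apply_jax_shims()` before `import acme` would apply both aliases. I have not checked that this covers every removed name the repo touches (ShardedDeviceArray, for example, is not handled here).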
