use of core.unsafe_get_axis_names_DO_NOT_USE which no longer exists #285

@svandenhaute

Using:

jax                    0.4.31.dev20241012+eafb38ca8
jax-rocm60-pjrt        0.4.31
jax-rocm60-plugin      0.4.31
jaxlib                 0.4.31
jaxtyping              0.2.34
kfac-jax               0.0.6

Trying to use ferminet with these versions gives the following error:

  File "/*/lib/python3.11/site-packages/ferminet/pretrain.py", line 245, in loss_fn
    return constants.pmean(result)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/*/lib/python3.11/site-packages/kfac_jax/_src/utils/parallel.py", line 58, in pmean_if_pmap
    return lax.pmean(obj, axis_name) if in_pmap(axis_name) else obj
                                        ^^^^^^^^^^^^^^^^^^
  File "/*/lib/python3.11/site-packages/kfac_jax/_src/utils/parallel.py", line 38, in in_pmap
    return axis_name in core.unsafe_get_axis_names_DO_NOT_USE()
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/*/lib/python3.11/site-packages/jax/_src/deprecations.py", line 55, in getattr
    raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax.core' has no attribute 'unsafe_get_axis_names_DO_NOT_USE'

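It appears that core.unsafe_get_axis_names_DO_NOT_USE was deprecated and has since been removed from jax.core. Is kfac-jax only compatible with JAX up to a certain version? If so, that upper bound should probably be declared in the package's dependency configuration so that an incompatible JAX is rejected at install time rather than failing at runtime.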
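For anyone who wants to check their own environment, here is a minimal sketch of a probe for the missing attribute. The helper name has_axis_names_helper and the stand-in modules are my own for illustration and are not part of JAX or kfac-jax; the only real name used is the attribute from the traceback above.

```python
import types

# Attribute name taken verbatim from the traceback above.
AXIS_NAMES_ATTR = "unsafe_get_axis_names_DO_NOT_USE"


def has_axis_names_helper(core_module) -> bool:
    """Return True if `core_module` still exposes the helper that kfac-jax calls.

    Hypothetical helper for illustration only. hasattr() returns False when
    attribute lookup raises AttributeError, so it also works for modules that
    raise from a module-level __getattr__.
    """
    return hasattr(core_module, AXIS_NAMES_ATTR)


# Stand-ins for jax.core so the sketch runs without JAX installed.
old_core = types.SimpleNamespace(**{AXIS_NAMES_ATTR: lambda: []})
new_core = types.SimpleNamespace()

print(has_axis_names_helper(old_core))  # True
print(has_axis_names_helper(new_core))  # False
```

In a real environment the same check would be has_axis_names_helper(jax.core) after import jax. A False result would mean the installed JAX is too new for the kfac-jax version shown in the traceback.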
