Description
Using:
jax 0.4.31.dev20241012+eafb38ca8
jax-rocm60-pjrt 0.4.31
jax-rocm60-plugin 0.4.31
jaxlib 0.4.31
jaxtyping 0.2.34
kfac-jax 0.0.6
With these versions installed, running FermiNet fails with the following error:
```
  File "/*/lib/python3.11/site-packages/ferminet/pretrain.py", line 245, in loss_fn
    return constants.pmean(result)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/*/lib/python3.11/site-packages/kfac_jax/_src/utils/parallel.py", line 58, in pmean_if_pmap
    return lax.pmean(obj, axis_name) if in_pmap(axis_name) else obj
                                        ^^^^^^^^^^^^^^^^^^
  File "/*/lib/python3.11/site-packages/kfac_jax/_src/utils/parallel.py", line 38, in in_pmap
    return axis_name in core.unsafe_get_axis_names_DO_NOT_USE()
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/*/lib/python3.11/site-packages/jax/_src/deprecations.py", line 55, in getattr
    raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax.core' has no attribute 'unsafe_get_axis_names_DO_NOT_USE'
```
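To confirm which versions are actually loaded and whether the installed `jax.core` still exposes the attribute that kfac-jax's `in_pmap` looks up, a small diagnostic like the one below may help. It is only a sketch that uses the standard library (`importlib`, `importlib.metadata`). The helper names `installed_version` and `has_attribute` are hypothetical and are not part of kfac-jax or JAX.

```python
"""Hypothetical diagnostic sketch: report installed package versions and
whether a module still exposes a given attribute. Not part of kfac-jax or JAX."""
import importlib
import importlib.metadata
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return importlib.metadata.version(dist_name)
    except importlib.metadata.PackageNotFoundError:
        return None


def has_attribute(module_name: str, attr: str) -> Optional[bool]:
    """Return True/False if the module imports, or None if it cannot be imported."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    # hasattr() treats an AttributeError raised by a module-level __getattr__
    # (as in jax/_src/deprecations.py in the traceback) as "attribute missing".
    return hasattr(module, attr)


if __name__ == "__main__":
    # Print the versions of the packages involved in the traceback.
    for dist in ("jax", "jaxlib", "kfac-jax"):
        print(f"{dist}: {installed_version(dist)}")
    # None means jax.core could not be imported at all.
    print(
        "jax.core.unsafe_get_axis_names_DO_NOT_USE present:",
        has_attribute("jax.core", "unsafe_get_axis_names_DO_NOT_USE"),
    )
```

If the last line prints `False`, the installed JAX no longer provides the function that kfac-jax 0.0.6 calls. This would explain the error independently of FermiNet.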
It appears that `unsafe_get_axis_names_DO_NOT_USE` has been deprecated and removed from `jax.core`. Is kfac-jax only compatible with JAX up to a certain version? If so, that upper bound should probably be declared in the package's dependency metadata.