I only imported kfac_jax, and it raised the following error:
AttributeError: jax.core.Primitive was removed in JAX v0.6.0. Use jax.extend.core.Primitive instead, and see https://docs.jax.dev/en/latest/jax.extend.html for details.. Did you mean: 'CallPrimitive'?
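The error message says `jax.core.Primitive` was removed in JAX v0.6.0, so the failure depends on the installed JAX version. Below is a minimal sketch for checking whether your JAX falls on the affected side of that cutoff. It only reads the installed version via `importlib.metadata` and compares it to 0.6.0. The helper names (`parse_version`, `primitive_removed`, `installed_jax_version`) are hypothetical. The sketch does not import kfac_jax or check which kfac_jax release you have.

```python
from __future__ import annotations

from importlib import metadata

# Per the error message, jax.core.Primitive was removed in JAX v0.6.0.
REMOVED_IN = (0, 6, 0)


def parse_version(version: str) -> tuple[int, ...]:
    """Parse the leading numeric parts of a version string, e.g. '0.6.1.dev2025' -> (0, 6, 1)."""
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break  # non-numeric segment such as 'dev2025'
        parts.append(int(digits))
        if len(digits) != len(piece):
            break  # segment like '0rc1': keep the number, stop here
    return tuple(parts)


def primitive_removed(version: str) -> bool:
    """Return True if this JAX version no longer provides jax.core.Primitive."""
    parsed = parse_version(version)
    parsed = parsed + (0,) * (3 - len(parsed))  # treat '0.6' as '0.6.0'
    return parsed >= REMOVED_IN


def installed_jax_version() -> str | None:
    """Return the installed jax version string, or None if jax is not installed."""
    try:
        return metadata.version("jax")
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    v = installed_jax_version()
    if v is None:
        print("jax is not installed")
    elif primitive_removed(v):
        print(f"jax {v}: jax.core.Primitive is gone; code still using it will fail")
    else:
        print(f"jax {v}: jax.core.Primitive should still be available")
```

If the check reports an affected version, one possible workaround is to pin JAX below the cutoff (for example `pip install "jax<0.6.0"`). The other is to move to a kfac_jax release that uses `jax.extend.core.Primitive`, if one is available. Both are assumptions here, not confirmed fixes.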