Since recent JAX updates, MrVI training is significantly slower than before. We suspect this is caused by JAX's new ahead-of-time (AOT) compilation strategy (https://jax.readthedocs.io/en/latest/aot.html).
This happens with any basic MrVI training run on a fresh install. Reproduced by @PierreBoyeau and myself.
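To check whether the extra time goes into compilation or into step execution, one option is a minimal sketch like the one below. It uses JAX's AOT API (`jax.jit(...).lower(...).compile()`) to time compilation separately from running the compiled step. `train_step` and `time_compile_vs_run` are hypothetical stand-ins for illustration only, not MrVI's actual training code. The shapes and learning rate are arbitrary assumptions.

```python
import time

import jax
import jax.numpy as jnp


def train_step(params, x):
    # Hypothetical stand-in for a MrVI training step (NOT MrVI code):
    # a quadratic loss followed by a plain SGD update.
    loss_fn = lambda p: jnp.mean((x @ p) ** 2)
    grads = jax.grad(loss_fn)(params)
    return params - 1e-3 * grads


def time_compile_vs_run(n_steps=10):
    # Arbitrary toy shapes, assumed for illustration.
    params = jnp.ones((64, 8))
    x = jnp.ones((128, 64))

    # AOT path: lower and compile explicitly so compile time is measured on its own.
    t0 = time.perf_counter()
    compiled = jax.jit(train_step).lower(params, x).compile()
    compile_s = time.perf_counter() - t0

    # Execution only: reuse the compiled executable.
    # block_until_ready() waits for JAX's asynchronous dispatch to finish.
    t0 = time.perf_counter()
    for _ in range(n_steps):
        params = compiled(params, x)
    params.block_until_ready()
    run_s = time.perf_counter() - t0

    return compile_s, run_s, params


if __name__ == "__main__":
    compile_s, run_s, _ = time_compile_vs_run()
    print(f"compile: {compile_s:.3f}s, run: {run_s:.3f}s")
```

Comparing these two numbers across JAX versions would indicate whether the regression is in compilation or in per-step execution.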