MrVI slowdown due to JAX compilation update #3179

@justjhong

Description

With recent updates to JAX, MrVI trains significantly slower than before. We suspect the cause is JAX's new ahead-of-time (AOT) compilation strategy (https://jax.readthedocs.io/en/latest/aot.html).

To reproduce, run any basic MrVI training on a fresh install. This was reproduced by @PierreBoyeau and myself.
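
One way to check whether compilation is where the time goes is to time the compile step separately from steady-state execution, then compare the two numbers across JAX versions. The sketch below does this with JAX's explicit lower/compile API on a toy function. `toy_step` and `time_compile_and_run` are hypothetical names used only for illustration; `toy_step` is a stand-in and not MrVI's actual training step, so this only shows the measurement pattern.

```python
import time

import jax
import jax.numpy as jnp


def toy_step(params, batch):
    # Hypothetical stand-in for a training step; NOT MrVI's real step.
    preds = jnp.tanh(batch @ params)
    return jnp.mean(preds ** 2)


def time_compile_and_run(fn, *args, n_runs=10):
    """Return (compile_seconds, mean_run_seconds) for `fn` under jax.jit."""
    # Lower and compile explicitly so compile time is measured on its own.
    t0 = time.perf_counter()
    compiled = jax.jit(fn).lower(*args).compile()
    compile_s = time.perf_counter() - t0

    # One warm-up call, then time repeated calls of the compiled executable.
    jax.block_until_ready(compiled(*args))
    t0 = time.perf_counter()
    for _ in range(n_runs):
        out = compiled(*args)
    # Dispatch is asynchronous, so wait for the last result before stopping.
    jax.block_until_ready(out)
    run_s = (time.perf_counter() - t0) / n_runs
    return compile_s, run_s


params = jnp.ones((64, 8))
batch = jnp.ones((128, 64))
compile_s, run_s = time_compile_and_run(toy_step, params, batch)
print(f"jax {jax.__version__}: compile {compile_s:.4f}s, per-step {run_s:.6f}s")
```

If the per-step time is similar across versions while the compile time grows, or if recompilation happens on every step, that would point toward compilation rather than execution as the source of the slowdown.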
