Potentially speed up the models by using JAX (as scvi-tools does, for example). Tutorial: https://github.com/ludwigwinkler/JaxLightning
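For illustration only, a minimal sketch of where a JAX speed-up would typically come from: compiling the whole training step with `jax.jit`. The toy linear model, the names `loss_fn` and `train_step`, and the learning rate are assumptions made up for this example; they are not taken from this project, scvi-tools, or the linked tutorial.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # assumed value, chosen only for this toy problem


def loss_fn(params, x, y):
    # Mean squared error of a toy linear model (a stand-in for a real model).
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@jax.jit  # traced once, then run as a single compiled XLA computation
def train_step(params, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Plain gradient descent applied to every leaf of the parameter dict.
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return new_params, loss


# Synthetic data: y = x @ true_w + 0.3
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (128, 3))
true_w = jnp.array([1.0, -2.0, 0.5])
y = x @ true_w + 0.3

params = {"w": jnp.zeros(3), "b": jnp.array(0.0)}
losses = []
for _ in range(200):
    params, loss = train_step(params, x, y)
    losses.append(float(loss))
```

The first call to `train_step` pays the compilation cost; later calls with the same input shapes reuse the compiled code. Whether this beats the current implementation would need to be benchmarked on the actual models.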