On this line, it looks like dmgas_dt is being computed via finite differences. @alexalar, since mgas is computed with differentiable JAX operations, shouldn't it be possible to compute dmgas_dt with jax.grad instead?
This is currently the only use of the _jax_get_dt_array function in the repo.
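For illustration, here's a minimal sketch of what I have in mind. It assumes mgas can be written as a function of a scalar time `t`; the exponential `mgas` below is a hypothetical stand-in, not the repo's actual model, and `t_table` is a made-up time grid. Since `jax.grad` only differentiates scalar-output functions, the derivative is taken at a single `t` and then mapped over the grid with `jax.vmap`:

```python
import jax
import jax.numpy as jnp

TAU = 2.0  # hypothetical decay timescale for the toy model


def mgas(t):
    # Hypothetical stand-in for the repo's gas-mass model: a smooth,
    # differentiable function of a scalar time t.
    return jnp.exp(-t / TAU)


# Hypothetical uniform time grid.
t_table = jnp.linspace(0.1, 13.8, 200)

# Finite-difference derivative on the grid (central differences in the
# interior, one-sided at the edges), analogous to the current approach.
dt = t_table[1] - t_table[0]
dmgas_dt_fd = jnp.gradient(mgas(t_table), dt)

# Autodiff derivative: grad of the scalar function, vectorized over t.
dmgas_dt_ad = jax.vmap(jax.grad(mgas))(t_table)
```

The autodiff version is exact up to floating-point rounding and doesn't depend on the grid spacing, whereas the finite-difference version picks up truncation error, especially at the endpoints. If mgas is instead built up over the whole time array (e.g. inside a scan), this would need restructuring so the derivative is taken with respect to each `t`.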