The jax2torch tests fail occasionally.

TODO:

- [ ] set a seed
- [ ] remove print statement from tests
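A minimal sketch of the "set a seed" item. It assumes the tests draw their random inputs from the global Python, NumPy and PyTorch generators. `seed_everything` and `make_inputs` are hypothetical helpers, not part of jax2torch. JAX has no global seed, so any JAX-side randomness would instead take an explicit key such as `jax.random.PRNGKey(0)`.

```python
import random

import numpy as np
import torch


def seed_everything(seed: int = 0) -> None:
    """Seed the global Python, NumPy and PyTorch RNGs (hypothetical helper)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # seeds the CPU and all CUDA generators


def make_inputs(seed: int = 0) -> torch.Tensor:
    """Hypothetical stand-in for how a test might build its random input."""
    seed_everything(seed)
    return torch.randn(3, 4)
```

Calling `seed_everything` at the start of each test, or from an autouse pytest fixture, makes a failing input reproducible. Passing the same seed twice then yields identical tensors.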