Closed
Description
We currently have a demo for transpiling PyTorch models to JAX followed by inference on a GPU, but there was a recent request to also support running the transpiled JAX model on a TPU.
This task involves creating a Colab notebook that demonstrates transpilation of PyTorch models to JAX, followed by inference on a TPU. This demo will then be added to our docs for the wider community.
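For reference, the TPU-specific part of the notebook mostly comes down to standard JAX device placement. The sketch below is a minimal, hypothetical example: the `model` function and its `params` stand in for whatever callable the transpilation produces (the real transpiled model is an assumption here), and the same pattern runs unchanged on a CPU or GPU backend, with `jax.devices()` returning TPU cores on a Colab TPU runtime.

```python
import jax
import jax.numpy as jnp

# Stand-in for the transpiled model: a plain JAX callable plus a params
# pytree (hypothetical; the real demo would use the transpiler's output).
params = {"w": jnp.ones((4, 2)), "b": jnp.zeros((2,))}

def model(params, x):
    return x @ params["w"] + params["b"]

# Pick the first device JAX can see (a TPU core on a TPU runtime,
# CPU otherwise) and move params and inputs onto it before jitting.
device = jax.devices()[0]
params = jax.device_put(params, device)
x = jax.device_put(jnp.ones((3, 4)), device)

# jit compiles the forward pass for whichever backend owns the inputs.
forward = jax.jit(model)
out = forward(params, x)
print(out.shape)  # (3, 2)
```

On a Colab TPU runtime, `jax.devices()` would list the available TPU cores, and no other changes to the inference code are needed.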
Feel free to reach out in case you face any issues, thanks!