Open
Description
🐛 Bug
I am attempting to develop custom Pallas kernels locally on CPU before running them on a TPU. I am following the official example here, with the minor modification that I run the script on CPU using interpret mode. After investigating, it appears that the latest custom-kernel code on the main branch should fix this error.
To Reproduce
Please use the colab here:
Steps to reproduce the behavior:
- Run the colab
- Observe errors in the last two cells
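For context, the kind of kernel involved can be sketched as a minimal standalone JAX Pallas call run in interpret mode (this is an illustrative sketch, not the colab's exact code — the kernel body, names, and shapes are assumptions; the torch_xla wrapper layer is omitted):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    # Element-wise add; refs are read/written with array indexing.
    o_ref[...] = x_ref[...] + y_ref[...]

@jax.jit
def add(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # emulate the kernel on CPU, no TPU required
    )(x, y)

x = jnp.arange(8, dtype=jnp.float32)
y = jnp.arange(8, dtype=jnp.float32)
print(add(x, y))
```

With `interpret=True`, the same `pallas_call` that would be lowered for TPU is instead evaluated with the interpreter, which is what makes CPU-only local development possible.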
Expected behavior
The code should execute without errors.
Environment
- Reproducible on XLA backend [CPU/TPU/CUDA]: CPU
- torch_xla version: ~=2.3.0
Additional context
N/A