🐛 Bug
If I jit-compile some model code that uses the RNG (e.g. dropout layers), then all future invocations of that jitted function will use the same RNG value. The RNG output is burned into the compiled StableHLO.
To Reproduce
See this notebook:
https://github.com/tengyifei/playground/blob/master/torchax/rng-test.ipynb
The jitted function returns the same random values on every call.
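For reference, here is a minimal sketch of the underlying failure mode in plain JAX (an illustrative analogue, not torchax's actual implementation): if the PRNG key is read from Python-side state while the function is being traced, rather than passed in as a traced argument, the sampled bits are constant-folded into the compiled program and every call returns the same value.

```python
import jax
import jax.numpy as jnp

# Illustrative analogue only -- this is not torchax code. The global RNG state
# is read at trace time, so the sampled bits become a compile-time constant.
_rng_state = {"key": jax.random.PRNGKey(0)}

@jax.jit
def dropout_like(x):
    key = _rng_state["key"]  # captured as a constant when the function is traced
    mask = jax.random.bernoulli(key, p=0.5, shape=x.shape)
    return jnp.where(mask, x / 0.5, 0.0)

x = jnp.ones((8,))
print(dropout_like(x))   # some mask
print(dropout_like(x))   # identical mask

_rng_state["key"] = jax.random.PRNGKey(42)
print(dropout_like(x))   # still the identical mask: the function is not retraced
```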
Expected behavior
I'd expect each iteration in the loop to output a different random number.
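As a sketch of what the expected behavior looks like in plain JAX (again an illustrative analogue, not a proposal for torchax's API), threading the PRNG key through the jitted function as an argument yields a different draw on every call:

```python
import jax
import jax.numpy as jnp

@jax.jit
def dropout_like_keyed(x, key):
    # The key is a traced argument, so fresh randomness flows in on every call.
    mask = jax.random.bernoulli(key, p=0.5, shape=x.shape)
    return jnp.where(mask, x / 0.5, 0.0)

x = jnp.ones((8,))
key = jax.random.PRNGKey(0)
for _ in range(3):
    key, subkey = jax.random.split(key)
    print(dropout_like_keyed(x, subkey))  # a different mask each iteration
```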
Environment
- Reproducible on XLA backend: CPU/TPU
- torchax version: 8e6ca6000e83ccbc4365a9d9358e510504b71dea
Additional context
If we don't fix this, the jitted behavior of any model with dropout layers, random masking, or any other random operation is wrong. This may, for example, cause training to not converge.
cc @qihqi