[torchax] RNG handling in a jitted graph is unsound #8636

Open
@tengyifei

Description

🐛 Bug

If I jit compile some model code that uses the RNG (e.g. dropout layers), then all future invocations of that jitted function will use the same RNG values. The RNG output is burned into the compiled StableHLO.
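
The mechanism is easiest to see in plain JAX, which torchax lowers to: if the PRNG key is a trace-time constant rather than a function argument, `jax.jit` bakes it into the graph. A minimal sketch of that failure mode (illustrative only, not torchax internals):

```python
import jax
import jax.numpy as jnp

# A key closed over at trace time becomes a constant in the compiled graph.
_key = jax.random.PRNGKey(0)

@jax.jit
def dropout_like(x):
    # This mask is computed from a baked-in key, so it is identical
    # on every invocation of the jitted function.
    mask = jax.random.bernoulli(_key, p=0.5, shape=x.shape)
    return jnp.where(mask, x / 0.5, 0.0)

x = jnp.ones(4)
print(dropout_like(x))  # some mask
print(dropout_like(x))  # bit-identical mask: the "randomness" is constant
```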

To Reproduce

See this notebook:

https://github.com/tengyifei/playground/blob/master/torchax/rng-test.ipynb

The jitted function produces the same random values on every call.
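
For convenience, a minimal sketch approximating what the notebook does, assuming torchax's `default_env()` and an `interop.jax_jit` wrapper (names taken from the torchax interop module and may differ from the notebook):

```python
import torch
import torchax
from torchax import interop  # jax_jit wrapper; name assumed

def sample(x):
    # torch.rand_like consults the torch RNG; under jit tracing the
    # sampled value ends up frozen into the exported StableHLO.
    return torch.rand_like(x)

with torchax.default_env():
    x = torch.zeros(4, device='jax')
    jitted = interop.jax_jit(sample)  # assumed API
    print(jitted(x))  # random-looking values
    print(jitted(x))  # exactly the same values again
```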

Expected behavior

I'd expect each iteration in the loop to output a different random number.
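
For comparison, plain JAX avoids this by making the PRNG key an explicit argument that the caller splits and passes in, so each call sees fresh state; a sketch of that pattern:

```python
import jax
import jax.numpy as jnp

@jax.jit
def dropout_like(key, x):
    # The key is a traced argument, not a baked-in constant, so the
    # mask changes whenever the caller passes a new key.
    mask = jax.random.bernoulli(key, p=0.5, shape=x.shape)
    return jnp.where(mask, x / 0.5, 0.0)

key = jax.random.PRNGKey(0)
x = jnp.ones(4)
for _ in range(3):
    key, subkey = jax.random.split(key)
    print(dropout_like(subkey, x))  # different values each iteration
```

A fix on the torchax side would presumably need to thread the torch RNG state through the jitted graph as an input in a similar way, rather than capturing it at trace time.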

Environment

  • Reproducible on XLA backend: CPU/TPU
  • torchax version: 8e6ca6000e83ccbc4365a9d9358e510504b71dea

Additional context

If we don't fix this, then the jitted behavior of any model with dropout layers, random masking, or any other random operation is wrong. This may, for example, cause training to not converge.

cc @qihqi
