🐛 Bug
I'm working with nightly versions of torch/xla on TPU. When moving from torch==2.6.0.dev20241106+cpu to torch==2.6.0.dev20241107, I see a significant increase in TPU memory usage for SPMD training (~2.5x), and in some settings it also crashes with OOM. The newest nightly still hasn't resolved this. I suspect a change in torch that affects SPMD training in torch_xla.
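For reference, a minimal sketch of the kind of SPMD loop I'm running (the model, optimizer, and sharding spec below are illustrative placeholders, not my actual training script); comparing the memory info printed here between the two nightlies shows the increase:

```python
import numpy as np
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()  # enable SPMD execution mode

device = xm.xla_device()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ("data", "model"))

model = nn.Linear(4096, 4096).to(device)  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

for step in range(10):
    data = torch.randn(64, 4096).to(device)
    xs.mark_sharding(data, mesh, ("data", None))  # shard the batch dim across devices

    optimizer.zero_grad()
    loss = model(data).sum()
    loss.backward()
    optimizer.step()
    xm.mark_step()

    # Device memory stats to compare between nightlies; the exact keys in the
    # returned dict can vary across versions, so just dump the whole thing.
    print(step, xm.get_memory_info(device))
```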
Environment
- Reproducible on XLA backend - TPU
- torch_xla version: 2.6.0.dev20241107
- torch/torchvision versions: 2.6.0.dev20241107