🐛 Bug
When training a sharded model with Flash Attention and segment_ids, the segment_ids are not sharded, so their size no longer matches the sharded q/k/v tensors and the kernel fails with a size mismatch. We attempted to fix this by modifying custom_kernel.py (PR #8333), which resolves the mismatch. However, with this fix applied, the loss does not converge to zero when training on dummy data; instead, it plateaus at 0.2.
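To make the mismatch concrete, here is a shape-only sketch in plain Python (no torch_xla calls). It assumes q/k/v are laid out as [batch, heads, seq, head_dim], segment_ids as [batch, seq], and a hypothetical 4-device mesh with a single "data" axis. The helpers `shard_shape`, `segment_ids_spec` and `check_segment_ids` are illustrative only; they are not part of torch_xla or PR #8333. The sketch shows why unsharded segment_ids disagree with sharded q, and that reusing q's batch and sequence partition entries for segment_ids makes the per-device shapes line up.

```python
def shard_shape(global_shape, partition_spec, mesh_shape):
    """Per-device shape of a tensor split according to `partition_spec`.

    partition_spec[i] is a mesh axis name (dimension i is split across that
    axis) or None (dimension i is kept whole on every device).
    """
    local = []
    for size, axis in zip(global_shape, partition_spec):
        if axis is None:
            local.append(size)
        else:
            n = mesh_shape[axis]
            if size % n != 0:
                raise ValueError(f"dim of size {size} not divisible by axis {axis!r} ({n})")
            local.append(size // n)
    return tuple(local)


def segment_ids_spec(qkv_spec):
    # q is [batch, heads, seq, head_dim] and segment_ids is [batch, seq],
    # so segment_ids takes the batch and sequence entries of q's spec.
    return (qkv_spec[0], qkv_spec[2])


def check_segment_ids(q_shape, seg_shape, qkv_spec, seg_spec, mesh_shape):
    q_local = shard_shape(q_shape, qkv_spec, mesh_shape)
    seg_local = shard_shape(seg_shape, seg_spec, mesh_shape)
    # The kernel needs the local batch and seq sizes to agree on each device.
    return (q_local[0], q_local[2]) == seg_local, q_local, seg_local


mesh_shape = {"data": 4}               # hypothetical 4-device data-parallel mesh
q_shape = (8, 16, 1024, 128)           # [batch, heads, seq, head_dim]
seg_shape = (8, 1024)                  # [batch, seq]
qkv_spec = ("data", None, None, None)  # batch split across "data"

# segment_ids left unsharded: every device still holds the full batch of 8.
print(check_segment_ids(q_shape, seg_shape, qkv_spec, (None, None), mesh_shape))
# (False, (2, 16, 1024, 128), (8, 1024))

# segment_ids sharded like q: local shapes agree.
print(check_segment_ids(q_shape, seg_shape, qkv_spec, segment_ids_spec(qkv_spec), mesh_shape))
# (True, (2, 16, 1024, 128), (2, 1024))
```

This only checks that per-device shapes agree. It does not verify that the segment_ids rows on each device correspond to that device's q rows, so it says nothing about the loss plateau.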
To Reproduce
Run any sharded training job that uses Flash Attention with segment_ids.
Expected behavior
With this fix applied, sharded training with Flash Attention and segment_ids should converge (on dummy data, the loss should approach zero).
Environment
- Reproducible on XLA backend [CPU/TPU/CUDA]:
- torch_xla version: 2.4 / 2.6