
Bug - Using Sharding in Flash Attention with segment ids. #8334

Closed
@dudulightricks

Description

🐛 Bug

When training a sharded model with Flash Attention and segment_ids, the segment_ids are not sharded, which causes a size mismatch. We attempted to resolve this by modifying custom_kernel.py (PR #8333), and that change does fix the mismatch. However, with the fix applied, the loss does not converge to zero when training on dummy data; instead, it stalls at 0.2.
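To illustrate where the size mismatch comes from, here is a minimal pure-Python sketch. It assumes q/k/v have shape (batch, heads, seq, head_dim), segment_ids have shape (batch, seq), and q is sharded along the batch dimension over a single 4-device mesh axis. The partition-spec convention (one mesh-axis index or None per dimension) is modeled loosely on XLA SPMD sharding, and the helpers `shard_shape` and `check_segment_ids` are hypothetical; nothing here is the torch_xla API or the change in PR #8333.

```python
# Hypothetical sketch: per-device shapes under sharding (not the torch_xla API).

def shard_shape(shape, partition_spec, mesh_shape):
    """Return the per-device shape of a tensor sharded by partition_spec.

    partition_spec[i] is the mesh axis index that dimension i is split across,
    or None if that dimension is replicated on every device.
    """
    local = []
    for dim, axis in zip(shape, partition_spec):
        if axis is None:
            local.append(dim)
        else:
            n = mesh_shape[axis]
            if dim % n != 0:
                raise ValueError(f"dimension {dim} not divisible by mesh axis size {n}")
            local.append(dim // n)
    return tuple(local)


def check_segment_ids(q_shape, seg_shape, q_spec, seg_spec, mesh_shape):
    """Raise if the local (batch, seq) dims of q and segment_ids disagree."""
    q_local = shard_shape(q_shape, q_spec, mesh_shape)
    seg_local = shard_shape(seg_shape, seg_spec, mesh_shape)
    # segment_ids is (batch, seq); compare with q's batch (dim 0) and seq (dim 2).
    if (q_local[0], q_local[2]) != seg_local:
        raise ValueError(f"local q {q_local} vs local segment_ids {seg_local}")
    return q_local, seg_local


mesh_shape = (4,)                 # assumption: one mesh axis with 4 devices
q_shape = (8, 16, 512, 64)        # (batch, heads, seq, head_dim)
seg_shape = (8, 512)              # (batch, seq)
q_spec = (0, None, None, None)    # shard q's batch dimension across mesh axis 0

# segment_ids left unsharded: each device still sees the full batch of 8.
try:
    check_segment_ids(q_shape, seg_shape, q_spec, (None, None), mesh_shape)
except ValueError as e:
    print("mismatch:", e)

# Shard segment_ids with q's batch and seq entries: local shapes now agree.
print(check_segment_ids(q_shape, seg_shape, q_spec, (q_spec[0], q_spec[2]), mesh_shape))
```

The point of the sketch is only that segment_ids must be partitioned along the same batch (and sequence) axes as q/k/v; otherwise each device pairs a local shard of q with the full, unsharded segment_ids.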

To Reproduce

Run any training job that uses Flash Attention with segment_ids on a sharded model.
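One kernel-independent sanity check on dummy data is to confirm that the segment mask each shard would build matches the corresponding slice of the unsharded mask. The sketch below assumes the usual segment-id semantics (query i may attend to key j only if both carry the same segment id in that batch row), self-attention, and contiguous batch sharding over 2 devices; the helpers `segment_mask` and `shard_batch` are hypothetical. It checks only mask alignment, not the kernel itself or its backward pass.

```python
# Hypothetical sketch: segment-id masking checked shard by shard (pure Python).

def segment_mask(q_seg, kv_seg):
    """mask[b][i][j] is True iff query i and key j share a segment in batch row b."""
    return [[[qi == kj for kj in kv_row] for qi in q_row]
            for q_row, kv_row in zip(q_seg, kv_seg)]


def shard_batch(rows, num_devices):
    """Split a list of batch rows into contiguous, equal-size chunks."""
    size = len(rows) // num_devices
    return [rows[d * size:(d + 1) * size] for d in range(num_devices)]


# Dummy packed batch: 4 rows of 6 tokens, each row holding 2 or 3 segments.
segment_ids = [
    [0, 0, 0, 1, 1, 1],
    [0, 0, 1, 1, 1, 1],
    [0, 1, 1, 2, 2, 2],
    [0, 0, 0, 0, 1, 1],
]
num_devices = 2

full = segment_mask(segment_ids, segment_ids)

# Aligned: each device's segment_ids shard covers the same rows as its q shard.
local = [segment_mask(s, s) for s in shard_batch(segment_ids, num_devices)]
print(sum(local, []) == full)  # concatenating per-device masks recovers the full mask

# Misaligned: shards swapped, so each device masks with another device's rows.
swapped = list(reversed(shard_batch(segment_ids, num_devices)))
wrong = [segment_mask(s, s) for s in swapped]
print(sum(wrong, []) == full)
```

If the aligned comparison holds but training still stalls, the cause is more likely elsewhere (for example in how the sharded kernel's outputs or gradients are computed) than in how segment_ids are split.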

Expected behavior

With this fix applied, the loss should converge when training a sharded model with Flash Attention and segment_ids.

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]:
  • torch_xla version: 2.4 / 2.6
