Fix RoPE positions to reset at document boundaries when using doc_lens #591
base: main
Conversation
This script demonstrates that when using doc_lens for intra-document masking, RoPE positions are NOT reset per document. When packing two identical sequences [seq | seq] with doc_lens=[10, 10]:
- Doc1 gets RoPE positions [0, 1, ..., 9] (correct)
- Doc2 gets RoPE positions [10, 11, ..., 19] (incorrect, should be [0, 1, ..., 9])

This causes the second document's logits to differ from what they would be if processed separately, which affects use cases like DPO training where chosen and rejected sequences are packed together.

To run: uv run python src/scripts/doc_lens_rope_issue.py
Requires: CUDA and flash attention
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
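A hedged sketch of the repro's core check (the actual script is src/scripts/doc_lens_rope_issue.py; `model` and `vocab_size` here are hypothetical stand-ins for whatever the script builds):

```python
import torch

# Hypothetical stand-ins: `model` is any transformer whose forward accepts
# `doc_lens` for intra-document masking; `vocab_size` matches its embedding.
seq = torch.randint(0, vocab_size, (1, 10))
packed = torch.cat([seq, seq], dim=1)  # [seq | seq], total length 20
logits = model(packed, doc_lens=torch.tensor([[10, 10]]))
# With intra-document masking the two halves should be independent, so
# their logits should match -- before this fix they do not, because the
# second copy gets RoPE positions 10..19 instead of 0..9.
torch.testing.assert_close(logits[:, :10], logits[:, 10:])
```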
The TransformerBlockConfig was migrated from block.attention to block.sequence_mixer.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
When packing multiple documents with cu_doc_lens, RoPE was using global positions instead of per-document positions. This caused the second document to get positions [20, 21, ...] instead of [0, 1, ...], breaking the expected behavior where identical documents should produce identical outputs.

The fix adds a cu_doc_lens parameter to RotaryEmbedding.forward() which computes per-document positions using searchsorted to find document boundaries.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
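A minimal sketch of the position computation this commit describes (illustrative, not the exact code in the PR):

```python
import torch

# Two packed documents of length 20 each; cu_doc_lens holds cumulative
# document boundaries starting at 0, flash-attn cu_seqlens style.
cu_doc_lens = torch.tensor([0, 20, 40])
pos = torch.arange(40)                                   # global positions 0..39
# searchsorted maps each global position to the document containing it;
# subtracting the document's start resets positions at every boundary.
doc_idx = torch.searchsorted(cu_doc_lens, pos, right=True) - 1
local_pos = pos - cu_doc_lens[doc_idx]                   # 0..19, then 0..19 again
assert local_pos.tolist() == list(range(20)) + list(range(20))
```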
…cript
- Add cu_doc_lens parameter to ComplexRotaryEmbedding.forward() for per-document position computation
- Remove src/scripts/doc_lens_rope_issue.py which had lint/type errors

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: df72f4eeb7
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Nice! Related: #503
- Add test_rope_cu_doc_lens_resets_positions: verifies positions reset at document boundaries for B=1
- Add test_rope_cu_doc_lens_batch_gt_1: verifies B>1 with the same doc structure per batch element
- Add test_rope_cu_doc_lens_uneven_docs: verifies B>1 with different doc structures per batch element (e.g., batch0 has docs of len 3+5, batch1 has docs of len 2+6)
- Tests cover both RotaryEmbedding and ComplexRotaryEmbedding classes
- Tests cover both head_first=True and head_first=False modes
- Fix B>1 support in RotaryEmbedding and ComplexRotaryEmbedding by computing global positions across the flattened [B*T] space before searchsorted (see the sketch below)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
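For instance, the uneven-docs case can be checked directly with searchsorted over the flattened space (a sketch, with document lengths matching the commit message):

```python
import torch

# batch0 packs docs of length 3+5, batch1 packs docs of length 2+6 (T=8 each).
B, T = 2, 8
cu_doc_lens = torch.tensor([0, 3, 8, 10, 16])  # boundaries in flattened [B*T] space
global_pos = torch.arange(B * T)
doc_starts = cu_doc_lens[torch.searchsorted(cu_doc_lens, global_pos, right=True) - 1]
local_pos = (global_pos - doc_starts).view(B, T)
assert local_pos.tolist() == [[0, 1, 2, 0, 1, 2, 3, 4],
                              [0, 1, 0, 1, 2, 3, 4, 5]]
```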
Force-pushed from 6aae7d5 to 3cb1939
- Extract compute_local_positions() helper function to compute per-document positions from cu_doc_lens, used by both RotaryEmbedding and ComplexRotaryEmbedding (a possible shape for the helper is sketched below)
- Keep original bracket notation for tensor unsqueezing (e.g., [:, None, :, :]) for clarity

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
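A possible shape for that helper, under the assumption that cu_doc_lens indexes the flattened [B*T] token stream (the PR's actual implementation may differ):

```python
import torch

def compute_local_positions(cu_doc_lens: torch.Tensor, batch_size: int, k_len: int) -> torch.Tensor:
    """Map each token to its position within its own document.

    Assumes ``cu_doc_lens`` holds cumulative document lengths over the
    flattened (batch_size * k_len) token stream, starting at 0.
    """
    device = cu_doc_lens.device
    # Global position of every token in the flattened [B*T] space.
    batch_offsets = torch.arange(batch_size, device=device) * k_len
    positions = torch.arange(k_len, device=device)
    global_positions = (batch_offsets[:, None] + positions[None, :]).flatten()
    # For each token, find the start offset of the document containing it.
    doc_starts = cu_doc_lens[torch.searchsorted(cu_doc_lens, global_positions, right=True) - 1]
    return (global_positions - doc_starts).view(batch_size, k_len)
```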
Force-pushed from 3cb1939 to 8650c5a
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 50ab95eca1
```python
batch_offsets = torch.arange(batch_size, device=device) * k_len
positions = torch.arange(k_len, device=device)
global_positions = batch_offsets[:, None] + positions[None, :]
```
Account for start_pos when building local positions
When cu_doc_lens is provided, local positions are derived from positions = arange(k_len) and never offset by start_pos. In KV-cache decoding (start_pos > 0), this makes the newest token always appear at local position 0 (or near the doc start) instead of its true absolute position within the document, so RoPE repeats earlier positions and yields incorrect attention for continued generation. This only affects runs that combine cu_doc_lens with decoding/start_pos (e.g., cached autoregressive generation) and does not occur for full-sequence training where start_pos is None.
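A hedged sketch of that adjustment, mirroring the quoted lines above (assuming start_pos is the absolute index of the first new token in the decoding stream):

```python
import torch

batch_size, k_len = 1, 4   # e.g. decoding 4 new tokens
start_pos = 16             # 16 tokens already in the KV cache (assumed meaning)
device = torch.device("cpu")

batch_offsets = torch.arange(batch_size, device=device) * k_len
# Offset by start_pos so the new tokens keep their true absolute positions
# (16..19) instead of restarting at 0..3 before document boundaries are found.
positions = torch.arange(k_len, device=device) + start_pos
global_positions = batch_offsets[:, None] + positions[None, :]
```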
When packing multiple documents with cu_doc_lens, RoPE was using global positions instead of per-document positions. This caused the second document to get positions [20, 21, ...] instead of [0, 1, ...], breaking the expected behavior where identical documents should produce identical outputs.

To fix, we add a cu_doc_lens parameter to RotaryEmbedding.forward() and ComplexRotaryEmbedding.forward() which computes per-document positions using searchsorted to find document boundaries. We have a minimal repro which we ran on Beaker.

Before this fix, identical documents packed with doc_lens produced different outputs because RoPE applied global positions (see minimal_doc_lens_repro.py). After our fix, identical documents packed with doc_lens produce identical outputs.
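For reference, a hedged sketch of how the new parameter is used — tensor shapes and the RotaryEmbedding construction here are illustrative placeholders; only the cu_doc_lens keyword comes from this PR:

```python
import torch

B, T, n_heads, head_dim = 1, 20, 8, 64
q = torch.randn(B, T, n_heads, head_dim)
k = torch.randn(B, T, n_heads, head_dim)
# Two packed documents of length 10 each.
cu_doc_lens = torch.tensor([0, 10, 20], dtype=torch.int32)
rope = RotaryEmbedding(...)  # configured however the model configures it
q_rot, k_rot = rope(q, k, cu_doc_lens=cu_doc_lens)
# Both documents now get positions 0..9, so identical packed inputs yield
# identical rotated queries and keys.
```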