Add cu_seqlens support to OlmoHybridGatedDeltaNet for packed sequences#44836
tyler-romero wants to merge 4 commits into huggingface:main
Conversation
Pass cu_seqlens derived from packed attention masks to FLA's ShortConvolution and chunk_gated_delta_rule kernels, preventing recurrent state from leaking across sequence boundaries during packed-sequence training.
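To illustrate the idea, here is a hedged sketch (not the PR's actual helper) of deriving `cu_seqlens` from a packed attention mask, assuming the mask carries a per-token id of the sequence each position belongs to:

```python
import torch

def cu_seqlens_from_segment_ids(segment_ids: torch.Tensor) -> torch.Tensor:
    # segment_ids: (total_tokens,) int tensor of per-token sequence ids,
    # e.g. [1, 1, 1, 2, 2, 3] for three packed sequences.
    # A boundary exists wherever the id changes between adjacent tokens.
    boundaries = torch.nonzero(segment_ids[1:] != segment_ids[:-1]).flatten() + 1
    total = torch.tensor([segment_ids.numel()])
    # cu_seqlens is the cumulative offsets: [0, b1, ..., total_tokens]
    return torch.cat([torch.zeros(1, dtype=torch.long), boundaries, total]).to(torch.int32)

ids = torch.tensor([1, 1, 1, 2, 2, 3])
print(cu_seqlens_from_segment_ids(ids).tolist())  # [0, 3, 5, 6]
```

With these offsets, each kernel invocation treats `[0, 3)`, `[3, 5)`, and `[5, 6)` as independent sequences, so recurrent state cannot leak across boundaries.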
[For maintainers] Suggested jobs to run (before merge): run-slow: olmo_hybrid
vasqu
left a comment
Commented on the modeling file (of course, the change would also need to be made in the modular file, sorry).
I just wanted to intervene because our approach is to do this before the model forward, not within it. It always causes issues otherwise, be it with compile, export, or just readability. I went into more detail in my first comment.
| ) | ||
|
|
||
|
|
||
| def _cu_seqlens_from_packed_mask(attention_mask: torch.Tensor) -> torch.Tensor: |
I have to intervene a bit here. We should not manually unpad and pad; this is something I explicitly disallowed in the model addition PR.
We can still support padding-free training/forward, though, by preparing the input properly before passing it to the model. We are also somewhat stuck with FA supporting this via position ids only, which was a mistake imo.
Tl;dr: Use our data collator, and pass the prepared input with its metadata (cu_seqlens, max seq lens, etc.).
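As a hedged sketch of this "prepare before the forward" approach: a padding-free collator (e.g. transformers' DataCollatorWithFlattening) emits flattened input ids plus position ids that restart at 0 at each packed-sequence boundary, so `cu_seqlens` can be recovered as metadata outside the model. The helper name below is illustrative, not the PR's code:

```python
import torch

def cu_seqlens_from_position_ids(position_ids: torch.Tensor) -> torch.Tensor:
    # position_ids: (total_tokens,) e.g. [0, 1, 2, 0, 1, 0] for three
    # packed sequences; each zero marks the start of a new sequence.
    starts = torch.nonzero(position_ids == 0).flatten()
    total = torch.tensor([position_ids.numel()])
    return torch.cat([starts, total]).to(torch.int32)

pos = torch.tensor([0, 1, 2, 0, 1, 0])
print(cu_seqlens_from_position_ids(pos).tolist())  # [0, 3, 5, 6]
```

The resulting offsets can then be handed to the model alongside the flattened batch, keeping the unpad/pad logic out of the forward pass.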
```diff
 # Requires LEFT padding to work correctly
 hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
```

```diff
+    q, cache=conv_state_q, use_precomputed=use_precomputed, output_final_state=use_cache, cu_seqlens=cu_seqlens
 )
 k, new_conv_state_k = self.k_conv1d(
-    k, cache=conv_state_k, use_precomputed=use_precomputed, output_final_state=use_cache
+    k, cache=conv_state_k, use_precomputed=use_precomputed, output_final_state=use_cache, cu_seqlens=cu_seqlens
 )
 v, new_conv_state_v = self.v_conv1d(
-    v, cache=conv_state_v, use_precomputed=use_precomputed, output_final_state=use_cache
+    v, cache=conv_state_v, use_precomputed=use_precomputed, output_final_state=use_cache, cu_seqlens=cu_seqlens
```
This should be hidden behind kwargs if possible (from the signature).
I hope FLA has **kwargs in its signature and it doesn't destroy anything.
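A hedged, illustrative sketch of the suggestion (names are not the PR's code): forward `cu_seqlens` only when it is present, so a kernel whose signature lacks the argument is not broken by an unconditional keyword:

```python
def fake_conv(x, cache=None, output_final_state=False, **kwargs):
    # Stand-in for an FLA ShortConvolution forward; the real kernel is
    # external and may or may not accept cu_seqlens.
    return x, kwargs.get("cu_seqlens")

def run(x, cu_seqlens=None, use_cache=False):
    # Build the optional kwargs dict only when cu_seqlens was provided.
    extra = {"cu_seqlens": cu_seqlens} if cu_seqlens is not None else {}
    return fake_conv(x, output_final_state=use_cache, **extra)

print(run([1, 2, 3]))                        # ([1, 2, 3], None)
print(run([1, 2, 3], cu_seqlens=[0, 2, 3]))  # ([1, 2, 3], [0, 2, 3])
```

This keeps the unpacked call path byte-for-byte identical to before, which matters for compile/export stability.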
What does this PR do?
Pass cu_seqlens derived from packed attention masks to FLA's ShortConvolution and chunk_gated_delta_rule kernels, preventing recurrent state from leaking across sequence boundaries during packed-sequence training.
Fixes # (issue)
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.