
Add cu_seqlens support to OlmoHybridGatedDeltaNet for packed sequences #44836

Draft
tyler-romero wants to merge 4 commits into huggingface:main from tyler-romero:olmo-hybrid-cu-seqlens

Conversation

@tyler-romero (Contributor) commented Mar 18, 2026

What does this PR do?

Pass cu_seqlens derived from packed attention masks to FLA's ShortConvolution and chunk_gated_delta_rule kernels, preventing recurrent state from leaking across sequence boundaries during packed-sequence training.
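
A minimal sketch of the idea, assuming the packed attention mask tags each segment with an incrementing integer; this is illustrative only, not the exact helper in the PR:

```python
import torch

# Sketch only: derive FLA-style cumulative sequence lengths from a packed
# attention mask such as [[1, 1, 1, 2, 2, 3]], where each integer marks one
# packed segment. The helper in this PR may differ in details.
def cu_seqlens_from_packed_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    ids = attention_mask.flatten()
    # A new segment starts wherever the segment label changes.
    starts = (torch.nonzero(ids[1:] != ids[:-1]).flatten() + 1).to(torch.int32)
    zero = torch.zeros(1, dtype=torch.int32, device=ids.device)
    end = torch.tensor([ids.numel()], dtype=torch.int32, device=ids.device)
    return torch.cat([zero, starts, end])

mask = torch.tensor([[1, 1, 1, 2, 2, 3]])
print(cu_seqlens_from_packed_mask(mask))  # tensor([0, 3, 5, 6], dtype=torch.int32)
```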

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@github-actions (Contributor)

[For maintainers] Suggested jobs to run (before merge)

run-slow: olmo_hybrid

@vasqu (Contributor) left a comment:

Commented on the modeling file (ofc it would need to be modified in the modular file, sry)

I just wanted to intervene because our approach is to do this before the model forward, not within. It always causes issues otherwise - be it compile, export, or just readability. I went into more detail in my first comment.

Comment on: def _cu_seqlens_from_packed_mask(attention_mask: torch.Tensor) -> torch.Tensor:

I have to intervene a bit here. We should not manually unpad and pad - this is something that I explicitly disallowed in the model addition PR.

We can still support padding-free training/forward tho by preparing the input properly before passing it to the model. We are kinda stuck with FA supporting it with position ids only as well, which was a mistake imo.

Tl;dr: Use our data collator

class DataCollatorWithFlattening(DefaultDataCollator):

And pass the prepared input with its metadata (cu seq lens, max seq lens, etc.).
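
For illustration, a minimal usage sketch of that collator. The outputs shown assume the default return_position_ids=True; recent transformers versions can also emit the cu/max seq-len metadata via a return_flash_attn_kwargs flag, which is worth verifying against your installed version:

```python
from transformers import DataCollatorWithFlattening

# Flattens a batch of examples into one packed row; position_ids restart at
# each example boundary so downstream code can recover where sequences begin.
collator = DataCollatorWithFlattening()
batch = collator([
    {"input_ids": [1, 2, 3]},
    {"input_ids": [4, 5]},
])
print(batch["input_ids"])     # tensor([[1, 2, 3, 4, 5]])
print(batch["position_ids"])  # tensor([[0, 1, 2, 0, 1]])
# labels mask the first token of each packed example with -100,
# e.g. tensor([[-100, 2, 3, -100, 5]])
```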

Comment on lines -722 to -723
# Requires LEFT padding to work correctly
hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)

This will stay, then.
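
For context, a sketch of what that helper does, matching the variant in Transformers' other SSM modeling files (e.g. Mamba-2); the OLMo-hybrid copy may differ slightly:

```python
import torch

def apply_mask_to_padding_states(
    hidden_states: torch.Tensor, attention_mask: torch.Tensor | None
) -> torch.Tensor:
    # Zero out hidden states at padded positions so padding cannot bleed
    # into the convolution / recurrent state (hence the left-padding note).
    if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
        dtype = hidden_states.dtype
        hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
    return hidden_states
```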

Comment on lines +770 to +776
+    q, cache=conv_state_q, use_precomputed=use_precomputed, output_final_state=use_cache, cu_seqlens=cu_seqlens
     )
     k, new_conv_state_k = self.k_conv1d(
-        k, cache=conv_state_k, use_precomputed=use_precomputed, output_final_state=use_cache
+        k, cache=conv_state_k, use_precomputed=use_precomputed, output_final_state=use_cache, cu_seqlens=cu_seqlens
     )
     v, new_conv_state_v = self.v_conv1d(
-        v, cache=conv_state_v, use_precomputed=use_precomputed, output_final_state=use_cache
+        v, cache=conv_state_v, use_precomputed=use_precomputed, output_final_state=use_cache, cu_seqlens=cu_seqlens

Should be hidden behind kwargs if possible (from the signature)

I hope FLA has **kwargs in its signature and it doesn't destroy anything.
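
A hypothetical sketch of that suggestion; all names here are illustrative stand-ins, not FLA's or the PR's actual signatures:

```python
import torch

def short_convolution(x: torch.Tensor, cu_seqlens: torch.Tensor | None = None, **kwargs):
    # Stand-in for FLA's ShortConvolution.forward. If the real kernel accepts
    # **kwargs, extra keys are ignored rather than raising a TypeError.
    return x

def gated_delta_net_forward(x: torch.Tensor, **kwargs):
    # Packing metadata (cu_seqlens, max seq lens, ...) is prepared before the
    # model forward and travels through opaque kwargs, so the layer never has
    # to spell out cu_seqlens at every call site.
    return short_convolution(x, **kwargs)

x = torch.randn(1, 6, 8)
out = gated_delta_net_forward(x, cu_seqlens=torch.tensor([0, 3, 6], dtype=torch.int32))
```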
