[Feat][GLA] Add GLA (Gated Linear Attention) Forward Operator (L2)#214
[Feat][GLA] Add GLA (Gated Linear Attention) Forward Operator (L2)#214superAngGao wants to merge 8 commits intotile-ai:mainfrom
Conversation
Summary of ChangesHello @superAngGao, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces the Gated Linear Attention (GLA) forward operator, a novel and efficient attention mechanism. By integrating a 4-stage chunked algorithm that strategically utilizes both PyTorch for sequential computations and TileLang for high-performance matrix operations, this change significantly enhances the library's capability for advanced sequence modeling. The new operator is thoroughly tested for correctness, providing a robust foundation for future developments in linear attention models. Highlights
🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console. Changelog
Activity
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request introduces a new Gated Linear Attention (GLA) forward operator, including the TileLang kernel, operator wrapper, and unit tests. A critical security vulnerability has been identified where the GPU kernels perform out-of-bounds writes when the sequence length is not a multiple of the chunk size. This stems from a lack of boundary checks in the TileLang kernels and insufficient input validation in the Op wrapper, which also contributes to general implementation correctness issues for non-divisible sequence lengths. To address this, an assertion should be added to the Op wrapper to enforce the documented constraint. Additionally, the review highlights areas for improving numerical stability, memory efficiency, and overall robustness, as well as resolving inconsistencies between the test reference and the actual implementation.
2055eb1 to
0f24a82
Compare
fc3c7ab to
5f1e0c7
Compare
Implements chunked GLA forward pass with: - Stage 1+2 (PyTorch): within-chunk gate cumsum + inter-chunk hidden state recurrence - Stage 3 (TileLang): intra-chunk causal attention matrix A [B, T, H, BT] - Stage 4 (TileLang): output combining inter-chunk and intra-chunk contributions Files added: - tileops/kernels/gla/gla_fwd.py -- GLAFwdKernel (sm90a) - tileops/kernels/gla/__init__.py - tileops/ops/gla.py -- GLAFwdOp - tests/ops/test_gla.py -- 7 test cases (fp16 + bf16, with/without initial_state) Closes tile-ai#213 Reference: https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gla/chunk.py Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Add seq_len % chunk_size == 0 assertion in GLAFwdOp to prevent OOB writes in TileLang kernels on non-divisible sequence lengths - Cast k/v to float32 per-chunk in GLAFwdKernel.forward to reduce peak memory usage - Fix k_adj formula in ref_gla_fwd to use log-space subtraction (matching GLAFwdKernel) instead of division with clamp - Add test_gla_fwd_non_divisible_seq_len to verify the assertion fires - Add skill.md files for create-new-kernel, create-new-op, create-new-op-test, creating-pull-request, migrating-new-op Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…on skill, add YAML frontmatter and auto-invoke to all skills Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…entions - Single @T.prim_func with 4 @T.macro stages in one T.Serial(num_chunks) loop - Stages run in order 1→3→4→2 so stage4 reads pre-decay h_s before stage2 updates it - Hoist all shared buffers into _main and pass as parameters to eliminate duplicate allocations (stays within 232448 byte optin limit) - Move shape lists inside _gla_fwd_func so outer closure only captures serializable scalars (fixes autotuner assertion) - Add self.kernel assignment in __init__ to support autotune - Fix custom_op namespace to top:: and add autotune_configs - forward() only allocates buffers and calls wrapper; no PyTorch compute Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
5f1e0c7 to
b2783e1
Compare
Summary
Implements Gated Linear Attention (GLA) forward pass as a new L2 operator (Kernel + Op).
Closes #213
Closes #212
Algorithm
Chunked GLA forward in 4 stages:
h [B,NT,H,K,V]with gated decayA [B,T,H,BT]with gated QKo = scale*(q*exp(g_cs))@h + A@vFiles Changed
tileops/kernels/gla/gla_fwd.pyGLAFwdKernel— TileLang stages 3 & 4, sm90atileops/kernels/gla/__init__.pytileops/ops/gla.pyGLAFwdOp— Op wrappertileops/ops/__init__.pyGLAFwdOptests/ops/test_gla.py.claude/create-new-kernel/skill.md.claude/create-new-op/skill.md.claude/create-new-op-test/skill.md.claude/creating-pull-request/SKILL.md.claude/migrating-new-op/SKILL.mdTest Results
Reference
https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gla/chunk.py
Checklist
__init__.pyexports synchronizedseq_len % chunk_size == 0enforced with clear error message