Skip to content

[Feat][GLA] Add GLA (Gated Linear Attention) Forward Operator (L2)#214

Open
superAngGao wants to merge 8 commits intotile-ai:mainfrom
superAngGao:feat/gla-fwd
Open

[Feat][GLA] Add GLA (Gated Linear Attention) Forward Operator (L2)#214
superAngGao wants to merge 8 commits intotile-ai:mainfrom
superAngGao:feat/gla-fwd

Conversation

@superAngGao
Copy link
Collaborator

@superAngGao superAngGao commented Feb 26, 2026

Summary

Implements Gated Linear Attention (GLA) forward pass as a new L2 operator (Kernel + Op).

Closes #213
Closes #212

Algorithm

Chunked GLA forward in 4 stages:

  1. Gate cumsum (PyTorch): within-chunk prefix sum of log-space gates
  2. Hidden state recurrence (PyTorch): inter-chunk h [B,NT,H,K,V] with gated decay
  3. Intra-chunk attention (TileLang): causal A [B,T,H,BT] with gated QK
  4. Output (TileLang): o = scale*(q*exp(g_cs))@h + A@v

Files Changed

File Description
tileops/kernels/gla/gla_fwd.py GLAFwdKernel — TileLang stages 3 & 4, sm90a
tileops/kernels/gla/__init__.py Kernel package export
tileops/ops/gla.py GLAFwdOp — Op wrapper
tileops/ops/__init__.py Register GLAFwdOp
tests/ops/test_gla.py 8 test cases (fp16 + bf16, ±initial_state, non-divisible seq_len assertion)
.claude/create-new-kernel/skill.md Skill guide for creating new kernels
.claude/create-new-op/skill.md Skill guide for creating new ops
.claude/create-new-op-test/skill.md Skill guide for creating op tests
.claude/creating-pull-request/SKILL.md Updated PR creation skill
.claude/migrating-new-op/SKILL.md Updated op migration skill

Test Results

8/8 passed — fp16, bf16, dim_k=64/128, dim_v=64/128, chunk_size=32/64, with/without initial_state, non-divisible seq_len assertion

Reference

https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gla/chunk.py

Checklist

  • Kernel implemented (TileLang, sm90a)
  • Op wrapper with stable API
  • 8 correctness tests passing (incl. assertion test)
  • __init__.py exports synchronized
  • seq_len % chunk_size == 0 enforced with clear error message
  • Benchmark (out of scope for L2 — to be added separately)

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @superAngGao, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces the Gated Linear Attention (GLA) forward operator, a novel and efficient attention mechanism. By integrating a 4-stage chunked algorithm that strategically utilizes both PyTorch for sequential computations and TileLang for high-performance matrix operations, this change significantly enhances the library's capability for advanced sequence modeling. The new operator is thoroughly tested for correctness, providing a robust foundation for future developments in linear attention models.

Highlights

  • New Operator Implementation: Implemented the Gated Linear Attention (GLA) forward pass as a new L2 operator, leveraging a 4-stage chunked algorithm for efficiency.
  • Hybrid Computation Strategy: The GLA forward pass combines PyTorch for sequential operations (gate cumsum and hidden state recurrence) and TileLang for performance-critical parallel computations (intra-chunk attention and output calculation).
  • Kernel and Op Wrapper: Introduced GLAFwdKernel for the low-level TileLang implementation and GLAFwdOp as a user-friendly Python wrapper, ensuring a stable API.
  • Comprehensive Testing: Added extensive unit tests (test_gla.py) with a PyTorch reference implementation to validate correctness across various data types (fp16, bf16), dimensions, chunk sizes, and initial state configurations.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Changelog
  • tests/ops/test_gla.py
    • Added a new test file containing a PyTorch reference implementation for GLA forward and parameterized unit tests for the GLAFwdOp, covering various configurations and data types.
  • tileops/kernels/gla/init.py
    • Added an __init__.py file to the gla kernel directory, exposing the GLAFwdKernel.
  • tileops/kernels/gla/gla_fwd.py
    • Added the GLAFwdKernel class, which defines the TileLang kernels for intra-chunk attention and output computation, and orchestrates the overall 4-stage GLA forward pass, including PyTorch-based gate cumsum and hidden state recurrence.
  • tileops/ops/init.py
    • Updated the __init__.py file to import and include GLAFwdOp in the module's __all__ export list.
  • tileops/ops/gla.py
    • Added the GLAFwdOp class, providing a high-level Python wrapper for the GLAFwdKernel, handling parameter initialization and dispatching the forward pass.
Activity
  • The core GLAFwdKernel has been implemented using TileLang for stages 3 and 4 of the GLA forward pass.
  • A Python GLAFwdOp wrapper has been created, offering a stable API for the new operator.
  • Seven correctness tests have been added and are passing for various configurations, including fp16 and bf16 data types, different dimension sizes, chunk sizes, and scenarios with/without an initial state.
  • The __init__.py exports for the new kernel and operator have been synchronized.
  • Benchmarking for this operator is noted as out of scope for this L2 implementation and is planned to be added separately.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a new Gated Linear Attention (GLA) forward operator, including the TileLang kernel, operator wrapper, and unit tests. A critical security vulnerability has been identified where the GPU kernels perform out-of-bounds writes when the sequence length is not a multiple of the chunk size. This stems from a lack of boundary checks in the TileLang kernels and insufficient input validation in the Op wrapper, which also contributes to general implementation correctness issues for non-divisible sequence lengths. To address this, an assertion should be added to the Op wrapper to enforce the documented constraint. Additionally, the review highlights areas for improving numerical stability, memory efficiency, and overall robustness, as well as resolving inconsistencies between the test reference and the actual implementation.

@superAngGao superAngGao changed the title feat: Add GLA (Gated Linear Attention) Forward Operator (L2) [Feat][GLA] Add GLA (Gated Linear Attention) Forward Operator (L2) Feb 26, 2026
@superAngGao superAngGao force-pushed the feat/gla-fwd branch 2 times, most recently from 2055eb1 to 0f24a82 Compare February 26, 2026 11:26
@superAngGao superAngGao reopened this Feb 27, 2026
@superAngGao superAngGao marked this pull request as draft February 27, 2026 03:38
@superAngGao superAngGao marked this pull request as ready for review February 27, 2026 05:16
@superAngGao superAngGao force-pushed the feat/gla-fwd branch 2 times, most recently from fc3c7ab to 5f1e0c7 Compare February 27, 2026 06:53
superAngGao and others added 8 commits February 27, 2026 16:54
Implements chunked GLA forward pass with:
- Stage 1+2 (PyTorch): within-chunk gate cumsum + inter-chunk hidden state recurrence
- Stage 3 (TileLang): intra-chunk causal attention matrix A [B, T, H, BT]
- Stage 4 (TileLang): output combining inter-chunk and intra-chunk contributions

Files added:
- tileops/kernels/gla/gla_fwd.py  -- GLAFwdKernel (sm90a)
- tileops/kernels/gla/__init__.py
- tileops/ops/gla.py              -- GLAFwdOp
- tests/ops/test_gla.py           -- 7 test cases (fp16 + bf16, with/without initial_state)

Closes tile-ai#213
Reference: https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gla/chunk.py

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Add seq_len % chunk_size == 0 assertion in GLAFwdOp to prevent OOB
  writes in TileLang kernels on non-divisible sequence lengths
- Cast k/v to float32 per-chunk in GLAFwdKernel.forward to reduce peak
  memory usage
- Fix k_adj formula in ref_gla_fwd to use log-space subtraction
  (matching GLAFwdKernel) instead of division with clamp
- Add test_gla_fwd_non_divisible_seq_len to verify the assertion fires
- Add skill.md files for create-new-kernel, create-new-op,
  create-new-op-test, creating-pull-request, migrating-new-op

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…on skill, add YAML frontmatter and auto-invoke to all skills

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…entions

- Single @T.prim_func with 4 @T.macro stages in one T.Serial(num_chunks) loop
- Stages run in order 1→3→4→2 so stage4 reads pre-decay h_s before stage2 updates it
- Hoist all shared buffers into _main and pass as parameters to eliminate duplicate allocations (stays within 232448 byte optin limit)
- Move shape lists inside _gla_fwd_func so outer closure only captures serializable scalars (fixes autotuner assertion)
- Add self.kernel assignment in __init__ to support autotune
- Fix custom_op namespace to top:: and add autotune_configs
- forward() only allocates buffers and calls wrapper; no PyTorch compute

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

1 participant