Thank you for your interest in contributing to Flash Linear Attention! All pull requests are very welcome and greatly appreciated.
## Table of Contents
- Report Bugs
- Ask Questions
- Submit Pull Requests
- Setup Development Environment
- Project Structure
- Code Style
- Adding a New Operator
- Adding a New Model
- Testing
- Environment Variables
- License
## Report Bugs

If you run into any weird behavior while using `fla`, feel free to open a new issue! Please search existing issues before opening a new one, to make sure that someone else hasn't already reported or solved the bug you've found.
Any issue you open should include:
- A minimal code snippet that reproduces the bug.
- A clear explanation of what the issue is.
## Ask Questions

Please ask questions in GitHub issues or on Discord. Check `FAQs.md` first for common questions.
## Submit Pull Requests

> [!NOTE]
> Please include tests with every pull request if applicable!
- Keep the scope focused: one PR should do one thing. If you have multiple unrelated changes, please split them into separate PRs.
- Use Draft PRs: feel free to open a draft early for design feedback or work-in-progress discussion.
Use a prefix tag in square brackets to categorize your change. Here are some common examples:
| Tag | Usage | Example |
|---|---|---|
| `[Fix]` | Bug fixes | [Fix] Guard checkpoint weight re-initialization |
| `[Misc]` | Miscellaneous | [Misc] Upgrade minimum PyTorch requirement |
| `[Docs]` | Documentation | [Docs] Update CP README |
| `[CI]` | CI/CD changes | [CI] Fix skip-test check failing on fork PRs |
| `[Test]` | Test additions or fixes | [Test] Add varlen backward gradient checks |
| `[Perf]` | Performance optimizations | [Perf] Fuse gate multiplication in delta rule |
| `[Refactor]` | Code refactoring | [Refactor] Unify chunk kernel entry points |
| `[Ops]` | General operator changes | [Ops] Refactor common chunk reduction utilities |
| `[Model]` | Model architecture changes | [Model] Add RoPE scaling to GLA config |
| `[Layer]` | Layer-level changes | [Layer] Normalize initial state initialization |
| `[Attn]` | Attention-related changes | [Attn] Add sliding window attention support |
| `[GDN]` | Gated Delta Net | [GDN] Add fused gate kernel |
| `[KDA]` | Kimi Delta Attention | [KDA] Fix illegal memory access in backward |
| `[CP]` | Context Parallel | [CP] Enable KCP for DPLR |
| `[Conv]` | Convolution | [Conv] Fix int32 overflow in varlen conv kernel |
| `[CE]` | Cross Entropy | [CE] Add logit softcapping support |
If your change doesn't fit any of the above, `[Misc]`/`[chore]` is the safe default.
Include a clear description with:
- Summary: What the PR does and why (bullet points preferred).
- Test plan: How the change is tested.
- Breaking changes (if any): List any API changes that are not backward compatible and describe the migration path.
See recent PRs for examples.
When you submit a PR, the following checks run automatically:
- Linting — Ruff + autopep8 via pre-commit
- License header check — Ensures copyright headers are present
- GPU tests — On NVIDIA H100/A100/4090 and Intel B580 (when available)
- Benchmarks — Performance regression checks
Add `[skip test]` to your commit message to skip GPU tests for documentation-only changes.
Before submitting, please go through the following checklist:
- Code follows the project's style conventions.
- Copyright header is present on all new files.
- Tests pass locally (`pytest tests/ops/test_<your_op>.py`).
- New operators include a naive reference implementation.
- Both forward and backward passes are tested.
- Gradient correctness is verified against a reference implementation.
- Pre-commit hooks pass (`pre-commit run --files <your_files>`).
## Setup Development Environment

- Python >= 3.10
- PyTorch >= 2.7.0
- A GPU with Triton support (NVIDIA, AMD, or Intel)
- Fork [flash-linear-attention](https://github.com/fla-org/flash-linear-attention) on GitHub and clone the repository:

  ```bash
  git clone git@github.com:<your username>/flash-linear-attention.git
  cd flash-linear-attention
  git remote add upstream git@github.com:fla-org/flash-linear-attention.git
  ```

- Install in development mode:

  ```bash
  pip install -e '.[test]'
  ```

  > [!TIP]
  > If the install fails, double-check that your PyTorch version matches your local CUDA toolkit and that `nvcc` is available in your `PATH`.

- Set up the `pre-commit` hooks:

  ```bash
  pip install pre-commit
  pre-commit install
  ```

To check the linting and run the test suite:

```bash
pre-commit run --all-files
pytest tests/
```

## Project Structure

```
fla/
├── layers/ # PyTorch attention layer implementations
├── ops/ # Triton kernel operators (the core of the project)
│ ├── common/ # Shared kernels reused across operators
│ └── <op_name>/ # Each operator in its own directory
│ ├── __init__.py
│ ├── naive.py # Reference implementation in pure PyTorch
│ ├── chunk.py # Chunk-based implementation
│ ├── parallel.py # Parallel Triton kernel implementation
│ ├── fused_recurrent.py # Fused recurrent implementation
│ └── README.md # (optional) Mathematical derivations
├── models/ # Full language model definitions (config + modeling)
├── modules/ # Utility modules (norms, feature maps, rotary, etc.)
└── utils.py # Global utilities and decorators
tests/
├── conftest.py # Pytest config with NaN memory poisoning
├── ops/ # Operator tests
├── layers/ # Layer tests
├── models/ # Model tests
└── modules/ # Module tests
```
## Code Style

Every source file should begin with the following header:

```python
# Copyright (c) 2023-2026, Songlin Yang, Yu Zhang, Zhiyuan Li
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# For a list of all contributors, visit:
# https://github.com/fla-org/flash-linear-attention/graphs/contributors
```

A CI workflow (`check-header.yml`) enforces this automatically.
We use Ruff for linting and autopep8 for formatting. Pre-commit hooks run both automatically.
Key rules:
- Max line length: 127 characters
- Target Python version: 3.10+
- Import sorting: `isort`-compatible via Ruff (`fla` as first-party)
- Type hints: use modern syntax (`X | None` instead of `Optional[X]`, `list[str]` instead of `List[str]`)
- Use `TYPE_CHECKING` for imports only needed at type-check time (see the sketch below)
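As a quick illustration of the type-hint conventions above, here is a minimal sketch; the helper `split_heads` is hypothetical and not part of the codebase:

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported for annotations only; the type checker sees it, runtime does not import it here.
    import torch


def split_heads(x: torch.Tensor, num_heads: int, head_dim: int | None = None) -> list[torch.Tensor]:
    """Split a [B, T, H*D] tensor into a list of per-head [B, T, D] tensors (illustrative helper only)."""
    if head_dim is None:
        head_dim = x.shape[-1] // num_heads
    return list(x.view(*x.shape[:-1], num_heads, head_dim).unbind(dim=-2))
```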
| Entity | Convention | Example |
|---|---|---|
| Classes | PascalCase | GatedDeltaNet, LinearAttention |
| Functions | snake_case | chunk_delta_rule, fused_recurrent_gla |
| Constants | UPPER_SNAKE_CASE | FLA_CI_ENV, SUPPORTS_AUTOTUNE_CACHE |
| Private helpers | Leading underscore | _guarded_empty, _is_called_from_fla |
- Kernel functions use `@triton.jit` with `do_not_specialize=['T']` for the sequence-length argument.
- Use `tl.constexpr` for compile-time constants (block sizes, flags like `USE_INITIAL_STATE`).
- Use `tl.make_block_ptr` for coalesced memory access.
- Gate autotune configs with `autotune_cache_kwargs` for cache support.
- Kernel naming: `<op>_fwd_kernel_<suffix>` / `<op>_bwd_kernel_<suffix>` (see the skeleton sketched after this list).
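The following is a minimal, hypothetical skeleton that ties these conventions together. It only copies a contiguous `[T, D]` tensor block by block; real kernels in `fla/ops/` are far more involved, and autotuning with `autotune_cache_kwargs` is omitted for brevity:

```python
import torch
import triton
import triton.language as tl


@triton.jit(do_not_specialize=['T'])
def copy_fwd_kernel_blk(
    x,
    o,
    T,                   # sequence length: not specialized, so varying T reuses one compiled kernel
    D: tl.constexpr,     # feature dimension, known at compile time
    BT: tl.constexpr,    # block size along the sequence dimension
):
    i_t = tl.program_id(0)
    # Block pointers give coalesced, bounds-checked access to a [BT, D] tile.
    p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, D), (1, 0))
    p_o = tl.make_block_ptr(o, (T, D), (D, 1), (i_t * BT, 0), (BT, D), (1, 0))
    b_x = tl.load(p_x, boundary_check=(0, 1))
    tl.store(p_o, b_x, boundary_check=(0, 1))


def copy_fwd(x: torch.Tensor, block_size: int = 64) -> torch.Tensor:
    # Assumes a contiguous 2-D tensor on a Triton-capable device.
    T, D = x.shape
    o = torch.empty_like(x)
    copy_fwd_kernel_blk[(triton.cdiv(T, block_size),)](x, o, T, D=D, BT=block_size)
    return o
```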
- Wrap public-facing ops with the `@input_guard` decorator to ensure tensor contiguity.
- Use `@autocast_custom_fwd` / `@autocast_custom_bwd` for mixed-precision support (a sketch follows this list).
- Provide a reference (naive) implementation in `naive.py` for testing.
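A hedged sketch of how these decorators are typically combined on a public-facing op. The op itself (`dummy_scale`) is hypothetical, its forward/backward bodies stand in for real Triton-backed kernels, and the import path `fla.utils` is assumed (matching the test example further below):

```python
import torch

from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard


class DummyScaleFunction(torch.autograd.Function):
    # Hypothetical op: scales the input by a constant.

    @staticmethod
    @input_guard            # ensure incoming tensors are contiguous before any kernel sees them
    @autocast_custom_fwd    # cooperate with torch.autocast / mixed precision
    def forward(ctx, x: torch.Tensor, scale: float) -> torch.Tensor:
        ctx.scale = scale
        return x * scale

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(ctx, do: torch.Tensor):
        # One gradient per forward input; `scale` is not differentiable here.
        return do * ctx.scale, None


def dummy_scale(x: torch.Tensor, scale: float = 2.0) -> torch.Tensor:
    """Public entry point; user code calls this rather than the Function directly."""
    return DummyScaleFunction.apply(x, scale)
```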
## Adding a New Operator

When adding a new operator under `fla/ops/<op_name>/`:

- Create the directory with an `__init__.py` that exports the public API.
- Write a naive implementation (`naive.py`) in pure PyTorch. This serves as the ground-truth reference for testing (a minimal example is sketched after this list).
- Implement the optimized kernel(s) in `chunk.py`, `parallel.py`, and/or `fused_recurrent.py`.
- Reuse shared kernels from `fla/ops/common/` where possible (e.g., `chunk_fwd_o`, `chunk_gated_delta_rule_fwd_h`).
- Add tests in `tests/ops/test_<op_name>.py` (see Testing below).
- (Optional) Add a `README.md` with mathematical derivations.
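To make the naive-reference step concrete, here is a hedged sketch of what a `naive.py` might contain for a hypothetical plain (ungated) linear-attention operator; the function name and semantics are illustrative only, not an existing op in the repo:

```python
import torch


def naive_recurrent_linear_attn(
    q: torch.Tensor,  # [B, T, H, K]
    k: torch.Tensor,  # [B, T, H, K]
    v: torch.Tensor,  # [B, T, H, V]
) -> torch.Tensor:
    """Step-by-step recurrence o_t = q_t S_t with S_t = S_{t-1} + k_t^T v_t.

    Deliberately slow and simple: it exists only as a ground-truth reference
    for the optimized chunk / fused_recurrent kernels.
    """
    B, T, H, K = q.shape
    V = v.shape[-1]
    # Accumulate in float32 so the reference is as numerically stable as possible.
    q, k, v = (x.float() for x in (q, k, v))
    S = q.new_zeros(B, H, K, V)
    outputs = []
    for t in range(T):
        S = S + torch.einsum('bhk,bhv->bhkv', k[:, t], v[:, t])
        outputs.append(torch.einsum('bhk,bhkv->bhv', q[:, t], S))
    return torch.stack(outputs, dim=1)
```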
## Adding a New Model

Each model lives under `fla/models/<model_name>/` with:

- `configuration_<model_name>.py` — Config class extending `PretrainedConfig`
- `modeling_<model_name>.py` — Model, PreTrainedModel, and ForCausalLM classes
- `__init__.py` — Auto-registration with `transformers`

Register your model in `fla/models/__init__.py` for auto-discovery; a sketch of the registration boilerplate is shown below.
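As an illustration of the registration step, here is a hedged sketch of what a per-model `__init__.py` might look like. The model name (`mynet`) and class names are hypothetical, and the exact registration pattern used in the repo may differ slightly:

```python
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.mynet.configuration_mynet import MyNetConfig
from fla.models.mynet.modeling_mynet import MyNetForCausalLM, MyNetModel

# Register with transformers' Auto* factories so the model can be resolved by its model_type.
AutoConfig.register('mynet', MyNetConfig)
AutoModel.register(MyNetConfig, MyNetModel)
AutoModelForCausalLM.register(MyNetConfig, MyNetForCausalLM)

__all__ = ['MyNetConfig', 'MyNetForCausalLM', 'MyNetModel']
```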
## Testing

```bash
# Run all tests
pytest tests/

# Run a specific test file
pytest tests/ops/test_delta.py

# Run a specific test
pytest tests/ops/test_delta.py::test_chunk -v
```

Tests compare optimized (Triton) implementations against reference (naive/recurrent) implementations. Follow this pattern:

```python
import pytest
import torch
from fla.ops.your_op import chunk_your_op, fused_recurrent_your_op
from fla.utils import assert_close, device, device_platform
@pytest.mark.parametrize(
    ('B', 'T', 'H', 'D', 'dtype'),
    [
        pytest.param(*test, id="B{}-T{}-H{}-D{}".format(*test))
        for test in [
            (1, 63, 1, 64, torch.float16),
            (2, 1000, 4, 128, torch.float16),
        ]
    ],
)
def test_chunk(B: int, T: int, H: int, D: int, dtype: torch.dtype):
    torch.manual_seed(42)
    q = torch.randn(B, T, H, D, dtype=dtype).to(device).requires_grad_(True)
    k = torch.randn(B, T, H, D, dtype=dtype).to(device).requires_grad_(True)
    v = torch.randn(B, T, H, D, dtype=dtype).to(device).requires_grad_(True)
    do = torch.rand_like(v)

    # Triton implementation
    tri = chunk_your_op(q.clone(), k.clone(), v.clone())
    (tri * do).sum().backward()
    tri_dq, tri_dk, tri_dv = q.grad, k.grad, v.grad
    q.grad = k.grad = v.grad = None

    # Reference implementation
    ref = fused_recurrent_your_op(q.clone(), k.clone(), v.clone())
    (ref * do).sum().backward()
    ref_dq, ref_dk, ref_dv = q.grad, k.grad, v.grad

    assert_close('o', ref, tri, 0.006)
    assert_close('dq', ref_dq, tri_dq, 0.006)
    assert_close('dk', ref_dk, tri_dk, 0.006)
    assert_close('dv', ref_dv, tri_dv, 0.006)
```

Key guidelines:
- Always use `torch.manual_seed(42)` for reproducibility.
- Use `assert_close` from `fla.utils` for numerical comparison with relative tolerance.
- Test both forward and backward passes by computing gradients.
- Use `device` from `fla.utils` for device-agnostic tests.
- Parametrize with diverse shapes, including non-power-of-2 sequence lengths (e.g., 63, 100, 2000).
- Skip unsupported platforms with `@pytest.mark.skipif(device_platform == 'intel', ...)` when needed.
- Include test IDs in `parametrize` for readable output.
The test suite (`conftest.py`) automatically replaces `torch.empty` with NaN-filled tensors for `tests/ops/` and `tests/modules/`. This catches bugs where uninitialized memory is accidentally used. You don't need to do anything special; just be aware that your kernels must fully initialize all output tensors.
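For intuition only, here is a hedged sketch of how such NaN poisoning can be implemented (this is not the repo's actual `conftest.py`): an autouse fixture monkeypatches `torch.empty` so every "uninitialized" buffer starts out as NaN, and any kernel that forgets to write part of its output fails loudly.

```python
import pytest
import torch

_real_empty = torch.empty


def _nan_poisoned_empty(*args, **kwargs):
    out = _real_empty(*args, **kwargs)
    # Only floating-point buffers can hold NaN; leave integer/bool buffers alone.
    if out.is_floating_point():
        out.fill_(float('nan'))
    return out


@pytest.fixture(autouse=True)
def poison_uninitialized_memory(monkeypatch):
    # Any torch.empty(...) call inside the test now returns NaN-filled storage.
    monkeypatch.setattr(torch, 'empty', _nan_poisoned_empty)
    yield
```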
## Environment Variables

See `ENVs.md` for a full list.
## License

By contributing, you agree that your contributions will be licensed under the MIT License.