NSA-Torch is a pure PyTorch implementation of Native Sparse Attention (NSA), from the paper "Native Sparse Attention: Hardware-Aligned and Natively Trainable Sparse Attention".
Clone the Repository:
git clone https://github.com/not-heavychevy/NSA-Torch.git
cd NSA-Torch
Install Dependencies:
It is recommended to use a virtual environment. Then install the required packages:
pip install -r requirements.txt
You can import and use the NSA model in your PyTorch projects. Below is a sample code snippet:
import torch
from nsa import NSA
# Initialize the NSA model with desired parameters
model = NSA(block_size=16, window_size=8, topk_blocks=4)
# Create random input tensors
B, T, D = 2, 128, 64 # Batch size, sequence length, feature dimension
q = torch.randn(B, T, D)
k = torch.randn(B, T, D)
v = torch.randn(B, T, D)
# Define gate values for the compression, selection, and sliding-window branches (can be learned or fixed)
gate_cmp = torch.ones(B, T)
gate_slc = torch.ones(B, T)
gate_swa = torch.ones(B, T)
# Run the NSA model
output = model(q, k, v, gate_cmp, gate_slc, gate_swa)
print("Output shape:", output.shape)To run the unit tests and verify the correctness of the NSA implementation:
To run the unit tests and verify the correctness of the NSA implementation:
- Ensure you are in the repository root.
- Run the following command:
pytest tests/test_nsa.py > test_results.txt 2>&1
This command runs all tests in tests/test_nsa.py and saves the output, including error messages, to test_results.txt.
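For reference, a minimal test of this kind might look like the sketch below. It only checks the output shape, reuses the same constructor arguments as the usage example above, and is not a copy of the actual tests in tests/test_nsa.py:

import torch
from nsa import NSA

def test_output_shape():
    # Shape-only smoke test: the output should keep the (B, T, D) input layout.
    B, T, D = 2, 128, 64
    model = NSA(block_size=16, window_size=8, topk_blocks=4)
    q, k, v = (torch.randn(B, T, D) for _ in range(3))
    ones = torch.ones(B, T)
    out = model(q, k, v, ones, ones, ones)
    assert out.shape == (B, T, D)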
This project is licensed under the MIT License.
Contributions, issues, and feature requests are welcome! Feel free to check the issues page.
This repository is inspired by recent research in sparse attention mechanisms for efficient long-context modeling in language models.
Feel free to reach out if you have any questions or suggestions. Happy coding!