jaxqkv is a transformer framework in JAX, with small amounts of Flax (optimized layers) and Optax (advanced gradient optimizers). The general goal is for it to serve as a reference implementation, a learning resource, and a general playground.
A sample loader and tokenizer for the TinyStories dataset are provided, and training on other datasets should require little extra work.
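
For readers new to the stack, here is a minimal sketch of the query/key/value computation that a JAX transformer is built around. It is an illustration only and is not taken from this repository: the function name `causal_self_attention`, the weight names, and the toy shapes are all assumptions made for the example.

```python
import jax
import jax.numpy as jnp


def causal_self_attention(x, w_q, w_k, w_v):
    """Single-head causal self-attention over a (seq_len, d_model) input.

    Hypothetical helper for illustration; not jaxqkv's actual API.
    """
    q = x @ w_q  # (seq_len, d_head)
    k = x @ w_k  # (seq_len, d_head)
    v = x @ w_v  # (seq_len, d_head)
    d_head = q.shape[-1]
    scores = (q @ k.T) / jnp.sqrt(d_head)  # (seq_len, seq_len)
    # Causal mask: position i may only attend to positions j <= i.
    seq_len = x.shape[0]
    mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    scores = jnp.where(mask, scores, -jnp.inf)
    weights = jax.nn.softmax(scores, axis=-1)
    return weights @ v


# Toy inputs with made-up sizes.
key = jax.random.PRNGKey(0)
k_x, k_q, k_k, k_v = jax.random.split(key, 4)
seq_len, d_model, d_head = 8, 16, 4
x = jax.random.normal(k_x, (seq_len, d_model))
w_q = jax.random.normal(k_q, (d_model, d_head))
w_k = jax.random.normal(k_k, (d_model, d_head))
w_v = jax.random.normal(k_v, (d_model, d_head))

out = jax.jit(causal_self_attention)(x, w_q, w_k, w_v)
print(out.shape)
```

Because of the causal mask, the first position can only attend to itself, so the first output row equals the first value vector. That makes a handy sanity check when experimenting with the mask.
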
git clone https://github.com/friedhar/jaxqkv.git
cd jaxqkv
chmod +x setup_env.sh
./setup_env.sh
uv run data_prep/samplegamma.py
uv run train.py samplegamma
- Flash Attention Support
- RoPE Positional Embedding
- Non-Naive KV Cache
- Support for GRPO
- Full Training Run With Fineweb
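
As orientation for the RoPE item in the list above, the following is a rough sketch of rotary position embeddings in JAX. It is not this repository's implementation: the function name `apply_rope`, the default `base` of 10000, and the "split halves" pairing of features are assumptions chosen for the example, and other codebases pair adjacent features instead.

```python
import jax.numpy as jnp


def apply_rope(x, base=10000.0):
    """Rotate feature pairs of x, shape (seq_len, d) with d even,
    by angles that grow linearly with position.

    Hypothetical helper for illustration; not jaxqkv's actual API.
    """
    seq_len, d = x.shape
    half = d // 2
    # One frequency per feature pair: base^(-2i/d) for i = 0..half-1.
    freqs = base ** (-jnp.arange(half) * 2.0 / d)
    angles = jnp.arange(seq_len)[:, None] * freqs[None, :]  # (seq_len, half)
    cos, sin = jnp.cos(angles), jnp.sin(angles)
    # Pair feature i with feature i + half (assumed convention).
    x1, x2 = x[:, :half], x[:, half:]
    return jnp.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)


# Use the same query and key vector at every position, so that any variation
# in the attention scores comes from the rotation alone.
q = jnp.tile(jnp.arange(1.0, 9.0)[None, :], (4, 1))        # (4, 8)
k = jnp.tile(jnp.linspace(-1.0, 1.0, 8)[None, :], (4, 1))  # (4, 8)
rq, rk = apply_rope(q), apply_rope(k)
scores = rq @ rk.T
print(scores.shape)
```

With identical content at every position, `scores[1, 0]` and `scores[3, 2]` agree up to floating-point error, since both correspond to a relative offset of one position. This dependence on relative rather than absolute position is the property that makes the rotation useful inside attention.
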