Code release for "DualKV: Shared-Prompt Flash-Attention Kernels for Efficient Policy Updates in RL Training".
DualKV deduplicates shared prompts in GRPO/DAPO training — instead of computing attention over N*(P+R) tokens, it computes over P + N*R, yielding up to 6x kernel speedup and 2x end-to-end throughput on long-context RL workloads. This release includes the custom flash-attention kernels, veRL integration (with Ulysses Sequence Parallelism support), and scripts to reproduce all paper experiments.
├── flash-attention/ # FlashAttention-2 (commit 41b2ef6) with DualKV kernels applied
├── verl/ # veRL v0.7.0 with DualKV integration applied
├── experiments/ # Benchmarks, training scripts, reward functions
├── LICENSE # CC-BY-NC-4.0
└── THIRD_PARTY_LICENSES
Key implementation files:
- Forward kernel:
flash-attention/csrc/flash_attn/src/flash_fwd_kernel_dualkv_training.h - Backward kernel:
flash-attention/csrc/flash_attn/src/flash_bwd_kernel_dualkv_training.h - Python interface:
flash-attention/flash_attn/flash_attn_interface.py(search fordualkv) - veRL actor integration:
verl/verl/workers/actor/dp_actor.py(search for_dualkv) - Attention monkey-patch + SP:
verl/verl/models/transformers/monkey_patch.py(DualKV + Ulysses all-to-all) - SP correctness test:
experiments/test_dualkv_sp_correctness.py
| Experiment | GPUs |
|---|---|
| Kernel benchmarks (Table 1, Table 2) | 1x H100-80GB |
| Qwen3-8B end-to-end (Table 5, Table 8) | 8x H100-80GB |
| Qwen3-14B end-to-end | 8x H100-80GB |
| DAPO end-to-end (Table 7) | 8x H100-80GB |
| Qwen3-30B-A3B multi-node (Table 3) | 16x H100-80GB (2 nodes) |
| Memory scaling sweep | 1x H100-80GB |
| Package | Version |
|---|---|
| Python | 3.12 |
| PyTorch | 2.9.0+cu128 |
| CUDA | 12.8 |
| flash-attn | 2.8.4 (included, with DualKV) |
| veRL | 0.7.0 (included, with DualKV) |
| vLLM | 0.12.0 |
| Ray | 2.55.0 |
| Transformers | 4.57.6 |
git clone <this-repo> dualkv && cd dualkv
python3 -m venv .venv && source .venv/bin/activate
pip install torch==2.9.0 --index-url https://download.pytorch.org/whl/cu128cd flash-attention
pip install ninja numpy packaging
git clone --depth 1 https://github.com/NVIDIA/cutlass.git csrc/cutlass
pip install -e . --no-build-isolation
cd ..Verify: python -c "from flash_attn import flash_attn_dualkv_varlen_func; print('OK')"
cd verl
pip install -e .
cd ..Only needed to reproduce FA3 baseline rows in Table 5 and Table 7:
git clone https://github.com/Dao-AILab/flash-attention.git /tmp/flash-attention-3
cd /tmp/flash-attention-3 && git checkout v3.0.0 && cd hopper && pip install -e .Verify: python -c "from flash_attn_interface import flash_attn_func; print('FA3 OK')"
Only needed to reproduce the Prefix Grouper baseline in Table 2:
pip install git+https://github.com/CASIA-IVA-Lab/PrefixGrouper.gitpip install vllm==0.12.0 ray==2.55.0 wandb pandas pyarrowWORKDIR=/path/to/your/workdir
# Models
huggingface-cli download Qwen/Qwen3-8B --local-dir ${WORKDIR}/models/Qwen3-8B
huggingface-cli download Qwen/Qwen3-14B --local-dir ${WORKDIR}/models/Qwen3-14B
huggingface-cli download Qwen/Qwen3-30B-A3B --local-dir ${WORKDIR}/models/Qwen3-30B-A3B
# Data
python experiments/preprocess_longreason.py --local_save_dir ${WORKDIR}/data/longreason
python experiments/preprocess_quality.py --local_save_dir ${WORKDIR}/data/qualitySet environment before running any script:
export WORKDIR=/path/to/your/workdir
export WANDB_API_KEY=your_key # optional, scripts fall back to console loggingNotation: mb = micro-batch size (prompt groups per training step), P = prompt length, N = number of responses per prompt, R = response length, SP = Ulysses sequence parallelism degree, DP = data parallelism degree, FA2/FA3 = FlashAttention-2/3.
Isolated DualKV vs FA2 attention kernel timing (fwd + bwd), fp16.
CUDA_VISIBLE_DEVICES=0 python experiments/reproduce_table1.pyExpected output (H100-80GB):
N P | FA2 fwd FA2 bwd FA2 f+b | DK fwd DK bwd DK f+b | fwd bwd f+b
28 4096 | 49.4 165.8 215.3 | 34.4 98.7 133.1 | 1.44x 1.68x 1.62x
28 16384 | 425.0 1325.8 1750.8 | 120.1 347.6 467.7 | 3.54x 3.81x 3.74x
16 32768 | 857.7 2645.8 3503.4 | 174.5 504.9 679.4 | 4.91x 5.24x 5.16x
28 32768 | 1500.9 4609.0 6109.9 | 259.8 758.4 1018.2 | 5.78x 6.08x 6.00x
16 65536 | OOM OOM OOM | 454.2 1277.7 1731.8 | inf inf inf
Single Qwen3-8B decoder layer fwd+bwd with realistic response lengths. Prefix Grouper is self-implemented (no external package needed).
CUDA_VISIBLE_DEVICES=0 python experiments/reproduce_table2.pyPaper Table 2 reports configs: (P=5K, mb=32), (8K, 16), (16K, 8), (32K, 4).
The script sweeps the full P x mb grid and marks paper configs with *.
torchrun --standalone --nproc-per-node 8 experiments/benchmark_qwen3_single_step.py \
--model ${WORKDIR}/models/Qwen3-8B --path both| Config | Script |
|---|---|
| FA2 mb=4 (baseline) | bash experiments/run_qwen3_8b_longreason_fa2.sh |
| FA3 mb=4 | bash experiments/run_qwen3_8b_longreason_fa3.sh |
| DualKV mb=4 | bash experiments/run_qwen3_8b_longreason_dualkv_mb4.sh |
| DualKV mb=8 | bash experiments/run_qwen3_8b_longreason_dualkv_mb8.sh |
| Config | Script |
|---|---|
| FA2 mb=4 | bash experiments/run_dapo_qwen3_8b_longreason_fa2_mb4.sh |
| FA3 mb=4 | bash experiments/run_dapo_qwen3_8b_longreason_fa3_mb4.sh |
| DualKV mb=4 | bash experiments/run_dapo_qwen3_8b_longreason_dualkv_mb4.sh |
| DualKV mb=8 | bash experiments/run_dapo_qwen3_8b_longreason_dualkv_mb8.sh |
Start a 2-node Ray cluster first:
# On node 0 (head):
ray start --head --port=6379 --num-gpus=8
# On node 1 (worker):
ray start --address='<head_ip>:6379' --num-gpus=8| Config | Script |
|---|---|
| FA2 mb=8 SP=4 | bash experiments/run_qwen3_30b_a3b_longreason_fa2_mb8_sp4.sh |
| FA2 mb=8 SP=2 | bash experiments/run_qwen3_30b_a3b_longreason_fa2_mb8_sp2.sh |
| DualKV mb=8 | bash experiments/run_qwen3_30b_a3b_longreason_dualkv_mb8.sh |
| Config | Script |
|---|---|
| FA2 LongReason mb=4 | bash experiments/run_qwen3_14b_longreason_fa2_mb4.sh |
| DualKV LongReason mb=8 | bash experiments/run_qwen3_14b_longreason_dualkv_mb8.sh |
| FA2 QuALITY mb=8 | bash experiments/run_qwen3_14b_quality_fa2_mb8.sh |
| DualKV QuALITY mb=8 | bash experiments/run_qwen3_14b_quality_dualkv_mb8.sh |
| Config | Script |
|---|---|
| FA2 | bash experiments/run_qwen3_8b_quality_fa2.sh |
| DualKV mb=4 | bash experiments/run_qwen3_8b_quality_dualkv_mb4.sh |
for P in 8192 16384 32768 65536 131072; do
python experiments/generate_padded_data.py \
--src_dir ${WORKDIR}/data/longreason \
--out_dir ${WORKDIR}/data/longreason_padded/P${P} \
--target_tokens $P \
--model_path ${WORKDIR}/models/Qwen3-8B
done
python experiments/bench_long_context_sweep.pypython experiments/predict_memory.pyDualKV composes with Ulysses SP for additional memory savings. The DualKV repack happens before the Ulysses sequence slice, and all-to-all communication is performed inside the attention kernel wrapper to reconstruct the full token sequence per rank.
# DualKV + SP=2 (DP=4, SP=2): combines prompt deduplication with head splitting
bash experiments/run_qwen3_8b_longreason_dualkv_mb8_sp2.shCorrectness test (requires 2 GPUs):
torchrun --nproc-per-node=2 experiments/test_dualkv_sp_correctness.py@article{dualkv2026,
title={DualKV: Shared-Prompt Flash-Attention Kernels for Efficient Policy Updates in RL Training},
author={Gai, Jiading* and Zhang, Shuai* and Song, Xiang and Wang, Bernie and Karypis, George},
journal={arXiv preprint arXiv:2605.15422},
year={2026}
}This project is licensed under CC-BY-NC-4.0. See LICENSE.
This project includes code derived from FlashAttention-2 (BSD-3-Clause) and veRL (Apache-2.0). See THIRD_PARTY_LICENSES.