Skip to content

Commit 18f9727

Browse files
authored
End-to-end GRPO with KeysAndValues KV cache (#128)
* Add end-to-end GRPO pipeline with KV-cache generation and gradient Phase 2 of the TRL integration. Adds the generation and policy-gradient pieces so the full GRPO loop runs through KeysAndValues' KV cache: - generate/trl_rollout.py: generate_completions, chunked KV-cache decode - finetune/grpo_loss.py: GRPOLossHeadModel, GRPO loss as a HeadModel so the policy gradient flows through LongContextGradientModel (memory-bounded backward) - finetune/grpo_loop.py: grpo_step + compute_group_advantages, the standalone end-to-end loop on a keys_values GPT - finetune/grpo.py: fix stale import (chunked_per_token_logps -> compute_logprobs) - logprobs.py: add backward-compatible verbose parameter - examples/trl_grpo_demo.ipynb: full pipeline demo, runs on CPU - tests: end-to-end grpo_step + generation adapter (19 passing on CPU) * Update ai_dev notes to cover phase 2 TRL integration files * Reorganize GRPO/RL code into keys_values/rl module; add docs Addresses review feedback on PR #128: - Move GRPO files into a dedicated keys_values/rl/grpo package and the shared log-prob utility into keys_values/rl, so RL support is discoverable. - Add docs/grpo_integration.md covering installation, how to run (standalone loop and TRL trainer), and where the KV cache is used (inference vs. gradient updates). - Update all imports, tests, notebook, and AI-dev notes for the new paths.
1 parent e296d94 commit 18f9727

15 files changed

Lines changed: 1492 additions & 9 deletions

File tree

ai_dev/trl_integration.md

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,11 @@
22

33
## AI Usage
44

5-
AI was used to generate docstrings for the functions and classes in
6-
`keys_values/logprobs.py` and `keys_values/finetune/grpo.py`.
5+
AI was used to generate docstrings for the functions and classes in the TRL
6+
integration code, prior to merging:
7+
8+
- `keys_values/rl/logprobs.py`
9+
- `keys_values/rl/grpo/trainer.py`
10+
- `keys_values/rl/grpo/loss.py`
11+
- `keys_values/rl/grpo/loop.py`
12+
- `keys_values/rl/grpo/rollout.py`

docs/grpo_integration.md

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
# GRPO Fine-tuning with KeysAndValues
2+
3+
This document explains how to run GRPO (Group Relative Policy Optimization)
4+
fine-tuning on top of KeysAndValues, what you need to install, and exactly
5+
where the KeysAndValues KV cache is used in the pipeline.
6+
7+
The code lives under [`keys_values/rl/`](../keys_values/rl):
8+
9+
```
10+
keys_values/rl/
11+
logprobs.py # memory-efficient per-token log-prob computation
12+
grpo/
13+
trainer.py # GRPOLongContextTrainer (TRL GRPOTrainer subclass)
14+
loop.py # grpo_step: standalone GRPO loop (no TRL needed)
15+
loss.py # GRPOLossHeadModel: GRPO loss as a HeadModel
16+
rollout.py # generate_completions: KV-cache generation adapter
17+
```
18+
19+
## Why a KV cache for RL?
20+
21+
GRPO repeatedly (1) generates completions, (2) scores them under the sampling
22+
policy, and (3) computes a policy gradient. For long prompts/completions, each
23+
of these steps would normally hold the activations or attention state for the
24+
entire sequence in GPU memory.
25+
26+
KeysAndValues processes sequences in **chunks** through a bounded KV cache, so
27+
peak memory stays flat as sequence length grows. The GRPO integration routes
28+
the memory-heavy steps through that infrastructure.
29+
30+
## Where the KV cache is used
31+
32+
This is the key question, and the answer depends on which entry point you use.
33+
34+
### Standalone loop (`keys_values.rl.grpo.loop.grpo_step`)
35+
36+
The KV cache is used at **every** memory-heavy stage — both inference and the
37+
gradient update:
38+
39+
| Stage | Component | KV cache role |
40+
|-------|-----------|---------------|
41+
| 1. Generation | `rollout.generate_completions``LongContextInferenceModel` | inference (chunked prefill + decode) |
42+
| 2. Old (sampling) log-probs | `logprobs.compute_logprobs` (under `no_grad`) | inference (scoring) |
43+
| 3. Policy gradient | `loss.GRPOLossHeadModel` + `LongContextGradientModel` | gradient updates (memory-bounded backward) |
44+
| 4. Optimizer step | `torch.optim.Optimizer` ||
45+
46+
So for the standalone loop the answer is **both**: the cache backs generation
47+
and old-log-prob scoring (inference) *and* the gradient/backward pass.
48+
49+
### TRL trainer (`keys_values.rl.grpo.trainer.GRPOLongContextTrainer`)
50+
51+
This subclass plugs into TRL's `GRPOTrainer` and overrides only the per-token
52+
log-probability computation (`_get_per_token_logps_and_entropies`), routing it
53+
through `compute_logprobs` when a sequence is longer than `kv_cache_length`.
54+
TRL still owns generation and the optimizer.
55+
56+
| Stage | Owner | KV cache role |
57+
|-------|-------|---------------|
58+
| Generation | TRL (transformers / vLLM) | not used |
59+
| Per-token log-probs (reference and policy) | `compute_logprobs` | inference scoring for long sequences; short sequences fall through to TRL's default with zero overhead |
60+
| Optimizer step | TRL ||
61+
62+
So for the TRL path the cache is used specifically for the **log-prob
63+
computation** TRL depends on, on long sequences only.
64+
65+
## Installation
66+
67+
Start from a working KeysAndValues install (see the top-level
68+
[README](../README.md) for the base setup), then:
69+
70+
```bash
71+
# Base package (editable install from the repo root)
72+
pip install -e .
73+
74+
# The standalone loop (keys_values.rl.grpo.loop) and logprobs need nothing
75+
# beyond the base install — they run anywhere KeysAndValues runs, incl. CPU.
76+
77+
# The TRL trainer additionally needs TRL:
78+
pip install -e .[trl] # installs trl>=1.0.0
79+
```
80+
81+
GPU is recommended for real runs but not required for the standalone loop on a
82+
small model (the unit tests run on CPU).
83+
84+
## Running
85+
86+
### Option A — standalone loop (no TRL)
87+
88+
`grpo_step` runs one full GRPO optimization step on a `keys_values.model.GPT`
89+
that has non-dense KV caches assigned. You supply prompts and a reward
90+
function.
91+
92+
```python
93+
import torch
94+
from keys_values.model import GPT
95+
from keys_values.kvcache.factory import KVCacheFactory
96+
from keys_values.rl.grpo.loop import grpo_step
97+
98+
# model must have (non-dense) KV caches assigned
99+
gpt_model.assign_kv_caches(
100+
KVCacheFactory.create(
101+
gpt_model=gpt_model,
102+
name="lastrec-default",
103+
max_batch_size=num_prompts * group_size,
104+
cache_length=cache_length,
105+
dtype=torch.float32,
106+
)
107+
)
108+
optimizer = torch.optim.SGD(gpt_model.parameters(), lr=1e-2)
109+
110+
def reward_fn(prompt_ids, completion_ids):
111+
# return a reward tensor of shape (num_prompts * group_size,)
112+
...
113+
114+
metrics = grpo_step(
115+
gpt_model=gpt_model,
116+
prompt_ids=prompt_ids, # (num_prompts, prompt_len), left-padded
117+
reward_fn=reward_fn,
118+
optimizer=optimizer,
119+
group_size=group_size, # completions sampled per prompt
120+
max_new_tokens=64,
121+
chunk_size=16,
122+
)
123+
print(metrics) # loss, mean_reward, mean_advantage, ...
124+
```
125+
126+
A runnable, annotated walkthrough is in
127+
[`examples/trl_grpo_demo.ipynb`](../examples/trl_grpo_demo.ipynb).
128+
129+
### Option B — TRL trainer
130+
131+
A drop-in `GRPOTrainer` for users already on TRL who have a KeysAndValues
132+
model. Long sequences get the bounded-memory log-prob path; short ones use
133+
TRL's default.
134+
135+
```python
136+
from keys_values.rl.grpo.trainer import GRPOLongContextTrainer
137+
138+
trainer = GRPOLongContextTrainer(
139+
model="Qwen/Qwen2.5-0.5B-Instruct",
140+
reward_funcs=my_reward_func,
141+
train_dataset=dataset,
142+
kv_cache_name="h2o-torch-quantized8",
143+
kv_cache_length=16384, # sequences longer than this use the chunked path
144+
kv_chunk_size=1024,
145+
)
146+
trainer.train()
147+
```
148+
149+
## Tests
150+
151+
```bash
152+
pytest test/rl/
153+
```
154+
155+
The suite runs on CPU with a tiny model and exercises the full pipeline:
156+
generation → reward → advantages → old log-probs → policy gradient →
157+
optimizer step.

0 commit comments

Comments
 (0)