| 1 | +# GRPO Fine-tuning with KeysAndValues |
| 2 | + |
| 3 | +This document explains how to run GRPO (Group Relative Policy Optimization) |
| 4 | +fine-tuning on top of KeysAndValues, what you need to install, and exactly |
| 5 | +where the KeysAndValues KV cache is used in the pipeline. |
| 6 | + |
| 7 | +The code lives under [`keys_values/rl/`](../keys_values/rl): |
| 8 | + |
| 9 | +``` |
| 10 | +keys_values/rl/ |
| 11 | + logprobs.py # memory-efficient per-token log-prob computation |
| 12 | + grpo/ |
| 13 | + trainer.py # GRPOLongContextTrainer (TRL GRPOTrainer subclass) |
| 14 | + loop.py # grpo_step: standalone GRPO loop (no TRL needed) |
| 15 | + loss.py # GRPOLossHeadModel: GRPO loss as a HeadModel |
| 16 | + rollout.py # generate_completions: KV-cache generation adapter |
| 17 | +``` |
| 18 | + |
| 19 | +## Why a KV cache for RL? |
| 20 | + |
| 21 | +GRPO repeatedly (1) generates completions, (2) scores them under the sampling |
| 22 | +policy, and (3) computes a policy gradient. For long prompts/completions, each |
| 23 | +of these steps would normally hold the activations or attention state for the |
| 24 | +entire sequence in GPU memory. |
| 25 | + |
| 26 | +KeysAndValues processes sequences in **chunks** through a bounded KV cache, so |
| 27 | +peak memory is governed by the cache length and chunk size rather than growing |
| 28 | +with sequence length. The GRPO integration routes the memory-heavy steps through that infrastructure. |
| 29 | + |
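To make the memory argument concrete, here is a rough back-of-the-envelope sketch of KV cache size as a function of the number of cached positions. The model shape, the 16-bit element size, and the helper name `kv_cache_bytes` are illustrative assumptions, not values taken from KeysAndValues. The sketch counts only the key and value tensors, not activations or any extra buffers a particular cache policy may keep.

```python
def kv_cache_bytes(num_tokens, n_layers, n_kv_heads, head_dim,
                   bytes_per_elem=2, batch_size=1):
    """Bytes needed to store keys and values for `num_tokens` cached positions."""
    # Factor 2: one tensor for keys and one for values in every layer.
    return 2 * batch_size * n_layers * n_kv_heads * head_dim * num_tokens * bytes_per_elem


# Hypothetical model shape and cache length, chosen only for illustration.
shape = dict(n_layers=24, n_kv_heads=2, head_dim=64)
cache_length = 16_384

for seq_len in (16_384, 65_536, 262_144):
    dense = kv_cache_bytes(seq_len, **shape)                       # grows with the sequence
    bounded = kv_cache_bytes(min(seq_len, cache_length), **shape)  # capped at cache_length
    print(f"{seq_len:>7} tokens: dense {dense / 2**20:7.0f} MiB, "
          f"bounded {bounded / 2**20:5.0f} MiB")
```

Under these assumptions the dense figure scales linearly with sequence length, while the bounded figure stops growing once the sequence exceeds the cache length.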
| 30 | +## Where the KV cache is used |
| 31 | + |
| 32 | +This is the key question, and the answer depends on which entry point you use. |
| 33 | + |
| 34 | +### Standalone loop (`keys_values.rl.grpo.loop.grpo_step`) |
| 35 | + |
| 36 | +The KV cache is used at **every** memory-heavy stage — both inference and the |
| 37 | +gradient update: |
| 38 | + |
| 39 | +| Stage | Component | KV cache role | |
| 40 | +|-------|-----------|---------------| |
| 41 | +| 1. Generation | `rollout.generate_completions` → `LongContextInferenceModel` | inference (chunked prefill + decode) | |
| 42 | +| 2. Old (sampling) log-probs | `logprobs.compute_logprobs` (under `no_grad`) | inference (scoring) | |
| 43 | +| 3. Policy gradient | `loss.GRPOLossHeadModel` + `LongContextGradientModel` | gradient updates (memory-bounded backward) | |
| 44 | +| 4. Optimizer step | `torch.optim.Optimizer` | — | |
| 45 | + |
| 46 | +So for the standalone loop the answer is **both**: the cache backs generation |
| 47 | +and old-log-prob scoring (inference) *and* the gradient/backward pass. |
| 48 | + |
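Stage 2 boils down to reading off, for every position, the log-probability the model assigned to the token that was actually sampled. The plain-Python sketch below shows that computation and a chunked variant that handles only `chunk_size` positions at a time. It is a toy under stated assumptions, not the code in `logprobs.py`: the logits are given up front, whereas a real model produces each chunk's logits from the KV cache, and the function names are hypothetical.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one row of logits."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def per_token_logprobs(logits, targets):
    """log p(target_t) for each position t, given one logits row per position."""
    return [log_softmax(row)[tok] for row, tok in zip(logits, targets)]


def per_token_logprobs_chunked(logits, targets, chunk_size):
    """Same result, but only `chunk_size` positions are processed at a time."""
    out = []
    for start in range(0, len(targets), chunk_size):
        end = start + chunk_size
        out.extend(per_token_logprobs(logits[start:end], targets[start:end]))
    return out


# Toy data: 5 positions, vocabulary of 3 tokens.
logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 0.3], [-0.5, 1.5, 0.0],
          [1.0, 1.0, 1.0], [0.0, -2.0, 3.0]]
targets = [0, 2, 1, 1, 2]

full = per_token_logprobs(logits, targets)
chunked = per_token_logprobs_chunked(logits, targets, chunk_size=2)
print(full)
print(chunked)
```

Because each position is scored independently once its logits are known, chunking changes peak memory but not the values.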
| 49 | +### TRL trainer (`keys_values.rl.grpo.trainer.GRPOLongContextTrainer`) |
| 50 | + |
| 51 | +This subclass plugs into TRL's `GRPOTrainer` and overrides only the per-token |
| 52 | +log-probability computation (`_get_per_token_logps_and_entropies`), routing it |
| 53 | +through `compute_logprobs` when a sequence is longer than `kv_cache_length`. |
| 54 | +TRL still owns generation and the optimizer. |
| 55 | + |
| 56 | +| Stage | Owner | KV cache role | |
| 57 | +|-------|-------|---------------| |
| 58 | +| Generation | TRL (transformers / vLLM) | not used | |
| 59 | +| Per-token log-probs (reference and policy) | `compute_logprobs` | inference scoring for long sequences; short sequences fall through to TRL's default with zero overhead | |
| 60 | +| Optimizer step | TRL | — | |
| 61 | + |
| 62 | +So for the TRL path the cache is used specifically for the **log-prob |
| 63 | +computation** TRL depends on, on long sequences only. |
| 64 | + |
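The routing rule described above can be sketched in a few lines. The function name `choose_logprob_path` is hypothetical and the real check inside `GRPOLongContextTrainer` may look different. The sketch only illustrates the stated behaviour that sequences strictly longer than `kv_cache_length` take the chunked path.

```python
def choose_logprob_path(seq_len, kv_cache_length):
    """Return which log-prob path would score a sequence of `seq_len` tokens."""
    # "Longer than" is read as strictly greater, so a sequence that exactly
    # fills the cache still goes through the default path.
    return "chunked" if seq_len > kv_cache_length else "default"


kv_cache_length = 16384  # same value as in the trainer example in this document
for seq_len in (512, 16384, 16385, 100_000):
    print(seq_len, choose_logprob_path(seq_len, kv_cache_length))
```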
| 65 | +## Installation |
| 66 | + |
| 67 | +Start from a working KeysAndValues install (see the top-level |
| 68 | +[README](../README.md) for the base setup), then: |
| 69 | + |
| 70 | +```bash |
| 71 | +# Base package (editable install from the repo root) |
| 72 | +pip install -e . |
| 73 | + |
| 74 | +# The standalone loop (keys_values.rl.grpo.loop) and logprobs need nothing |
| 75 | +# beyond the base install; they run anywhere KeysAndValues runs, including CPU. |
| 76 | + |
| 77 | +# The TRL trainer additionally needs TRL: |
| 78 | +pip install -e ".[trl]" # installs trl>=1.0.0 (quoted so shells like zsh do not expand the brackets) |
| 79 | +``` |
| 80 | + |
| 81 | +GPU is recommended for real runs but not required for the standalone loop on a |
| 82 | +small model (the unit tests run on CPU). |
| 83 | + |
| 84 | +## Running |
| 85 | + |
| 86 | +### Option A — standalone loop (no TRL) |
| 87 | + |
| 88 | +`grpo_step` runs one full GRPO optimization step on a `keys_values.model.GPT` |
| 89 | +that has non-dense KV caches assigned. You supply prompts and a reward |
| 90 | +function. |
| 91 | + |
| 92 | +```python |
| 93 | +import torch |
| 94 | +from keys_values.model import GPT |
| 95 | +from keys_values.kvcache.factory import KVCacheFactory |
| 96 | +from keys_values.rl.grpo.loop import grpo_step |
| 97 | + |
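| 98 | +# Assumes gpt_model (a GPT instance), prompt_ids, num_prompts, group_size and cache_length are already defined; the model must have (non-dense) KV caches assigned |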
| 99 | +gpt_model.assign_kv_caches( |
| 100 | + KVCacheFactory.create( |
| 101 | + gpt_model=gpt_model, |
| 102 | + name="lastrec-default", |
| 103 | + max_batch_size=num_prompts * group_size, |
| 104 | + cache_length=cache_length, |
| 105 | + dtype=torch.float32, |
| 106 | + ) |
| 107 | +) |
| 108 | +optimizer = torch.optim.SGD(gpt_model.parameters(), lr=1e-2) |
| 109 | + |
| 110 | +def reward_fn(prompt_ids, completion_ids): |
| 111 | + # return a reward tensor of shape (num_prompts * group_size,) |
| 112 | + ... |
| 113 | + |
| 114 | +metrics = grpo_step( |
| 115 | + gpt_model=gpt_model, |
| 116 | + prompt_ids=prompt_ids, # (num_prompts, prompt_len), left-padded |
| 117 | + reward_fn=reward_fn, |
| 118 | + optimizer=optimizer, |
| 119 | + group_size=group_size, # completions sampled per prompt |
| 120 | + max_new_tokens=64, |
| 121 | + chunk_size=16, |
| 122 | +) |
| 123 | +print(metrics) # loss, mean_reward, mean_advantage, ... |
| 124 | +``` |
| 125 | + |
| 126 | +A runnable, annotated walkthrough is in |
| 127 | +[`examples/trl_grpo_demo.ipynb`](../examples/trl_grpo_demo.ipynb). |
| 128 | + |
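The "group relative" part of GRPO refers to how the per-completion rewards returned by `reward_fn` become advantages: each reward is normalized against the other completions sampled for the same prompt. The plain-Python sketch below illustrates this under assumptions that are not confirmed by the source. It assumes the flat reward vector stores each prompt's `group_size` completions contiguously, and it uses the population standard deviation with a small `eps`. Implementations differ on both points, and this is not the code in `loop.py`.

```python
from statistics import mean, pstdev


def group_relative_advantages(rewards, group_size, eps=1e-4):
    """Normalize each reward against the other completions of the same prompt.

    Assumes `rewards` is a flat list of length num_prompts * group_size in which
    the completions belonging to one prompt are stored contiguously.
    """
    if len(rewards) % group_size != 0:
        raise ValueError("len(rewards) must be a multiple of group_size")
    advantages = []
    for start in range(0, len(rewards), group_size):
        group = rewards[start:start + group_size]
        mu = mean(group)
        sigma = pstdev(group)  # population std; some implementations use the sample std
        advantages.extend((r - mu) / (sigma + eps) for r in group)
    return advantages


# Two prompts, four completions each. The second group has identical rewards,
# so it carries no learning signal and all of its advantages are zero.
rewards = [1.0, 0.0, 0.0, 1.0, 0.5, 0.5, 0.5, 0.5]
adv = group_relative_advantages(rewards, group_size=4)
print(adv)
```

The `eps` term keeps the division well defined when every completion in a group receives the same reward.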
| 129 | +### Option B — TRL trainer |
| 130 | + |
| 131 | +A drop-in `GRPOTrainer` for users already on TRL who have a KeysAndValues |
| 132 | +model. Long sequences get the bounded-memory log-prob path; short ones use |
| 133 | +TRL's default. |
| 134 | + |
| 135 | +```python |
| 136 | +from keys_values.rl.grpo.trainer import GRPOLongContextTrainer |
| 137 | + |
| 138 | +trainer = GRPOLongContextTrainer( |
| 139 | + model="Qwen/Qwen2.5-0.5B-Instruct", |
| 140 | + reward_funcs=my_reward_func, |
| 141 | + train_dataset=dataset, |
| 142 | + kv_cache_name="h2o-torch-quantized8", |
| 143 | + kv_cache_length=16384, # sequences longer than this use the chunked path |
| 144 | + kv_chunk_size=1024, |
| 145 | +) |
| 146 | +trainer.train() |
| 147 | +``` |
| 148 | + |
| 149 | +## Tests |
| 150 | + |
| 151 | +```bash |
| 152 | +pytest test/rl/ |
| 153 | +``` |
| 154 | + |
| 155 | +The suite runs on CPU with a tiny model and exercises the full pipeline: |
| 156 | +generation → reward → advantages → old log-probs → policy gradient → |
| 157 | +optimizer step. |