Skip to content

Commit 827c746

Browse files
authored
Add chunked log-prob computation for TRL GRPO integration (#126)
* Add chunked log-prob computation for TRL GRPO integration * Addressed MR comments * Reformmated using updated black package
1 parent 71791b3 commit 827c746

6 files changed

Lines changed: 592 additions & 0 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@ logs/
1616
out/
1717
checkpoints/
1818
test_tokenizer/
19+
keyval_venv/

ai_dev/trl_integration.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# TRL Integration - AI Development Notes
2+
3+
## AI Usage
4+
5+
AI was used to generate docstrings for the functions and classes in
6+
`keys_values/logprobs.py` and `keys_values/finetune/grpo.py`.

keys_values/finetune/grpo.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""
15+
GRPO fine-tuning with KeysAndValues long-context KV cache support.
16+
17+
Provides ``GRPOLongContextTrainer`` — a subclass of TRL's ``GRPOTrainer``
18+
whose per-token log-probability computation uses KeysAndValues' chunked
19+
KV-cache forward pass, bounding GPU memory for arbitrarily long sequences.
20+
21+
Usage::
22+
23+
from keys_values.finetune.grpo import GRPOLongContextTrainer
24+
25+
trainer = GRPOLongContextTrainer(
26+
model="Qwen/Qwen2.5-0.5B-Instruct",
27+
reward_funcs=my_reward_func,
28+
train_dataset=dataset,
29+
kv_cache_name="h2o-torch-quantized8",
30+
kv_cache_length=16384,
31+
kv_chunk_size=1024,
32+
)
33+
trainer.train()
34+
"""
35+
36+
from __future__ import annotations
37+
38+
import torch
39+
from trl.trainer.grpo_trainer import GRPOTrainer
40+
41+
from keys_values.logprobs import chunked_per_token_logps
42+
from keys_values.model import GPT
43+
44+
_UNWRAP_ATTRS = ("gpt_model", "model", "base_model", "module")
45+
46+
47+
class GRPOLongContextTrainer(GRPOTrainer):
48+
"""``GRPOTrainer`` with KV-cache chunked log-prob computation.
49+
50+
Overrides TRL's full-sequence forward with KeysAndValues' bounded-memory
51+
chunked path for sequences exceeding ``kv_cache_length``. Short sequences
52+
fall through to TRL's default — zero overhead.
53+
54+
Parameters
55+
----------
56+
kv_cache_name : str
57+
Cache policy, e.g. ``"h2o-torch-quantized8"``.
58+
kv_cache_length : int
59+
Slot count. Sequences longer than this trigger the chunked path.
60+
kv_chunk_size : int
61+
Chunk size for post-prefill processing.
62+
kv_cache_kwargs : dict | None
63+
Extra kwargs forwarded to ``KVCacheFactory.create``.
64+
"""
65+
66+
def __init__(
67+
self,
68+
*args,
69+
kv_cache_name: str = "h2o-torch-quantized8",
70+
kv_cache_length: int = 16384,
71+
kv_chunk_size: int = 1024,
72+
kv_cache_kwargs: dict | None = None,
73+
**kwargs,
74+
):
75+
self.kv_cache_name = kv_cache_name
76+
self.kv_cache_length = kv_cache_length
77+
self.kv_chunk_size = kv_chunk_size
78+
self.kv_cache_kwargs = kv_cache_kwargs
79+
super().__init__(*args, **kwargs)
80+
81+
def _get_per_token_logps_and_entropies(
82+
self,
83+
model,
84+
input_ids,
85+
attention_mask,
86+
logits_to_keep,
87+
*,
88+
batch_size=None,
89+
compute_entropy=False,
90+
**kwargs,
91+
):
92+
"""Route long sequences through the chunked KV-cache path."""
93+
seq_len = input_ids.shape[1]
94+
95+
if seq_len <= self.kv_cache_length:
96+
return super()._get_per_token_logps_and_entropies(
97+
model,
98+
input_ids,
99+
attention_mask,
100+
logits_to_keep,
101+
batch_size=batch_size,
102+
compute_entropy=compute_entropy,
103+
**kwargs,
104+
)
105+
106+
gpt = _unwrap(model)
107+
bs = batch_size or input_ids.size(0)
108+
109+
results = [
110+
chunked_per_token_logps(
111+
gpt_model=gpt,
112+
input_ids=input_ids[i : i + bs],
113+
logits_to_keep=logits_to_keep,
114+
cache_name=self.kv_cache_name,
115+
cache_length=self.kv_cache_length,
116+
chunk_size=self.kv_chunk_size,
117+
cache_kwargs=self.kv_cache_kwargs,
118+
temperature=self.temperature,
119+
compute_entropy=compute_entropy,
120+
)
121+
for i in range(0, input_ids.size(0), bs)
122+
]
123+
124+
logps = torch.cat([r[0] for r in results])
125+
entropies = (
126+
torch.cat([r[1] for r in results if r[1] is not None])
127+
if results[0][1] is not None
128+
else None
129+
)
130+
return logps, entropies
131+
132+
133+
def _unwrap(model) -> GPT:
134+
"""Peel wrappers (DDP, PEFT, HF, ...) until we hit a ``GPT`` instance."""
135+
seen: set[int] = set()
136+
cur = model
137+
while not isinstance(cur, GPT):
138+
if id(cur) in seen:
139+
break
140+
seen.add(id(cur))
141+
nxt = next(
142+
(getattr(cur, a) for a in _UNWRAP_ATTRS if hasattr(cur, a)),
143+
None,
144+
)
145+
if nxt is None:
146+
break
147+
cur = nxt
148+
149+
if not isinstance(cur, GPT):
150+
raise TypeError(
151+
f"Cannot locate a keys_values.model.GPT inside {type(model).__name__}. "
152+
"Ensure your model was loaded through keys_values."
153+
)
154+
return cur

keys_values/logprobs.py

Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""
15+
Memory-efficient per-token log-probability computation using KV cache.
16+
17+
Implements a :class:`HeadModel` that accumulates per-token log-probs
18+
(and optionally entropies) instead of a scalar loss. Plug it into
19+
:class:`LongContextInferenceModel` and the existing chunked forward
20+
infrastructure handles everything — no new forward loop needed.
21+
22+
Usage::
23+
24+
from keys_values.logprobs import compute_logprobs
25+
26+
logps, entropies = compute_logprobs(
27+
gpt_model=model,
28+
input_ids=input_ids,
29+
targets=completion_ids,
30+
cache_name="h2o-torch-quantized8",
31+
cache_length=16384,
32+
chunk_size=1024,
33+
)
34+
"""
35+
36+
from typing import Optional, Tuple
37+
38+
import torch
39+
import torch.nn.functional as F
40+
41+
from keys_values.config import Config
42+
from keys_values.head_model import HeadModel
43+
from keys_values.kvcache.factory import KVCacheFactory
44+
from keys_values.long_context import LongContextInferenceModel
45+
from keys_values.model import GPT
46+
47+
48+
class LogProbsHeadModel(HeadModel):
49+
"""HeadModel that accumulates per-token log-probs instead of a loss.
50+
51+
Wraps the same logic as :class:`CrossEntropyOnLogits` but instead of
52+
reducing to a scalar loss, it gathers the log-probability of each target
53+
token and stores it. After the full chunked forward pass completes,
54+
call :meth:`get_results` to retrieve the accumulated tensors.
55+
56+
This is meant to be used with :class:`LongContextInferenceModel` — the
57+
existing chunk/cell/layer loop calls ``forward()`` chunk by chunk, and
58+
this class collects log-probs as they come.
59+
"""
60+
61+
NAME = "log_probs"
62+
63+
def __init__(
64+
self, config: Config, temperature: float = 1.0, compute_entropy: bool = False
65+
):
66+
super().__init__()
67+
self._vocab_size = config.padded_vocab_size
68+
self._temperature = temperature
69+
self._compute_entropy = compute_entropy
70+
self._logps_chunks: list[torch.Tensor] = []
71+
self._entropy_chunks: list[torch.Tensor] = []
72+
73+
def needs_logits(self) -> bool:
74+
return True
75+
76+
def forward(
77+
self,
78+
model_outputs: torch.Tensor,
79+
targets: Optional[torch.Tensor],
80+
input_pos: int,
81+
) -> torch.Tensor:
82+
"""Accumulate log-probs for target tokens in this chunk.
83+
84+
Called by LongContextInferenceModel for each chunk. When targets
85+
is None (prompt-only chunk), we skip. When targets are present,
86+
we gather log-probs and optionally entropy.
87+
"""
88+
if input_pos == 0:
89+
self._logps_chunks.clear()
90+
self._entropy_chunks.clear()
91+
92+
diff = self._check_model_outputs_targets(
93+
model_outputs, targets, final_dim=self._vocab_size
94+
)
95+
96+
if diff is not None:
97+
logits = model_outputs[:, diff:, :]
98+
if self._temperature != 1.0:
99+
logits = logits / self._temperature
100+
101+
# Per-token log-probs
102+
log_probs = F.log_softmax(logits, dim=-1)
103+
token_logps = torch.gather(
104+
log_probs, dim=-1, index=targets.unsqueeze(-1)
105+
).squeeze(-1)
106+
self._logps_chunks.append(token_logps)
107+
108+
if self._compute_entropy:
109+
ent = -(log_probs.exp() * log_probs).sum(dim=-1)
110+
self._entropy_chunks.append(ent)
111+
112+
# Return zeros
113+
return torch.zeros(
114+
model_outputs.shape[0],
115+
device=model_outputs.device,
116+
dtype=model_outputs.dtype,
117+
)
118+
119+
def num_target_entries(self, targets: torch.Tensor) -> Optional[torch.Tensor]:
120+
return None
121+
122+
def get_results(self) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
123+
"""Retrieve accumulated log-probs and entropies after forward pass.
124+
125+
Returns
126+
-------
127+
logps : torch.Tensor
128+
Shape ``(batch_size, num_target_tokens)``.
129+
entropies : torch.Tensor | None
130+
Shape ``(batch_size, num_target_tokens)`` or None.
131+
"""
132+
logps = torch.cat(self._logps_chunks, dim=1)
133+
entropies = (
134+
torch.cat(self._entropy_chunks, dim=1) if self._entropy_chunks else None
135+
)
136+
return logps, entropies
137+
138+
def _empty_clone(self, device: Optional[torch.device] = None) -> "HeadModel":
139+
config = Config()
140+
config.padded_vocab_size = self._vocab_size
141+
return LogProbsHeadModel(
142+
config,
143+
temperature=self._temperature,
144+
compute_entropy=self._compute_entropy,
145+
)
146+
147+
148+
def compute_logprobs(
149+
gpt_model: GPT,
150+
input_ids: torch.Tensor,
151+
targets: torch.Tensor,
152+
cache_name: str = "h2o-torch-quantized8",
153+
cache_length: int = 16384,
154+
chunk_size: int = 1024,
155+
cache_kwargs: Optional[dict] = None,
156+
temperature: float = 1.0,
157+
compute_entropy: bool = False,
158+
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
159+
"""Compute per-token log-probs via LongContextInferenceModel.
160+
161+
This is the primary entry point. It creates a :class:`LogProbsHeadModel`,
162+
plugs it into :class:`LongContextInferenceModel`, and runs the existing
163+
chunked forward pass. All KV cache management, chunk/cell grouping, and
164+
layer processing is handled by the existing infrastructure.
165+
166+
Args:
167+
gpt_model: KeysAndValues GPT model.
168+
input_ids: Full sequence (prompt + completion), shape
169+
``(batch_size, seq_length)``.
170+
targets: Target tokens (right-aligned with input_ids), shape
171+
``(batch_size, num_completion_tokens)``.
172+
cache_name: KV cache policy name.
173+
cache_length: Number of slots in the KV cache.
174+
chunk_size: Chunk size for post-prefill processing.
175+
cache_kwargs: Extra args for KV cache construction.
176+
temperature: Scales logits before softmax.
177+
compute_entropy: Whether to also return per-token entropy.
178+
179+
Returns:
180+
Tuple of (log_probs, entropies).
181+
"""
182+
batch_size = input_ids.shape[0]
183+
config = gpt_model.config
184+
dtype = next(gpt_model.parameters()).dtype
185+
186+
head = LogProbsHeadModel(
187+
config, temperature=temperature, compute_entropy=compute_entropy
188+
)
189+
190+
caches_created = False
191+
if gpt_model.get_kv_caches()[0] is None:
192+
gpt_model.assign_kv_caches(
193+
KVCacheFactory.create(
194+
gpt_model=gpt_model,
195+
name=cache_name,
196+
max_batch_size=batch_size,
197+
cache_length=cache_length,
198+
dtype=dtype,
199+
cache_kwargs=cache_kwargs or {},
200+
)
201+
)
202+
caches_created = True
203+
204+
inference_model = LongContextInferenceModel(
205+
gpt_model=gpt_model,
206+
head_model=head,
207+
chunk_size=chunk_size,
208+
)
209+
210+
# Run the forward pass
211+
inference_model(input_ids=input_ids, targets=targets)
212+
213+
logps, entropies = head.get_results()
214+
215+
if caches_created:
216+
from keys_values.kvcache.factory import deallocate_kv_cache_buffers_of_model
217+
218+
deallocate_kv_cache_buffers_of_model(gpt_model)
219+
gpt_model.assign_kv_caches([None] * config.n_layer)
220+
221+
return logps, entropies

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ dependencies = [
2929

3030
[project.optional-dependencies]
3131
cuda = ["flashinfer-python"]
32+
trl = ["trl>=1.0.0"]
3233

3334
[tool.setuptools.packages.find]
3435
include = [

0 commit comments

Comments
 (0)