Changes from 97 commits (102 commits total)
f10285e
support prompts or token IDs in VLLMClient and update API request han…
qgallouedec Mar 5, 2026
7d2bb67
test
qgallouedec Mar 5, 2026
3b356ac
consistency
qgallouedec Mar 5, 2026
82c4508
fix
qgallouedec Mar 5, 2026
3ea2fcf
another fix
qgallouedec Mar 5, 2026
445f4ba
fix docstring
qgallouedec Mar 5, 2026
8c6c88d
Add support for multi-modal inputs in VLLMClient and vllm_serve
qgallouedec Mar 5, 2026
f617b2d
Merge branch 'main' into vllm-accept-token-ids
qgallouedec Mar 6, 2026
eaffd67
Merge branch 'main' into vllm-accept-token-ids
qgallouedec Mar 6, 2026
f3f6a5d
Move `rollout_func` from `_generate_single_turn` to `_generate`
qgallouedec Mar 6, 2026
d417543
fix style
qgallouedec Mar 6, 2026
4b927d6
support multi-image
qgallouedec Mar 6, 2026
029fc1f
style
qgallouedec Mar 6, 2026
20b4039
Merge branch 'vllm-accept-token-ids' into vllm-support-image-with-raw…
qgallouedec Mar 6, 2026
b8e3912
Merge branch 'vllm-support-image-with-raw-token' into move-rollout-func
qgallouedec Mar 6, 2026
07181cb
Fix handling of images in OnlineDPOTrainer to ensure proper structure…
qgallouedec Mar 7, 2026
6ff1e56
Merge branch 'main' into vllm-accept-token-ids
qgallouedec Mar 7, 2026
9f340e4
Merge branch 'vllm-accept-token-ids' into vllm-support-image-with-raw…
qgallouedec Mar 7, 2026
d138be7
Merge branch 'vllm-support-image-with-raw-token' into move-rollout-func
qgallouedec Mar 7, 2026
09128d6
Move tokenization before vLLM generation call
qgallouedec Mar 7, 2026
7fd1711
Fix deadlock issue by ensuring images are always gathered in VLLMGene…
qgallouedec Mar 7, 2026
3ab04b0
Unify tokenization across all generation backends in _generate_single…
qgallouedec Mar 7, 2026
5d6d067
Extract tokenization out of _generate_single_turn into _tokenize_prompts
qgallouedec Mar 7, 2026
b4d2c34
Enhance multimodal input handling in GRPO and RLOO trainers by adding…
qgallouedec Mar 7, 2026
4922362
style
qgallouedec Mar 7, 2026
37c48b3
Merge branch 'unify-tokenization-generate' into extract-tokenize-prompts
qgallouedec Mar 7, 2026
3375aea
Fix re-tokenization bug in tool-calling loop by concatenating token IDs
qgallouedec Mar 7, 2026
638f88a
Enhance _tool_call_loop to support multimodal inputs by adding images…
qgallouedec Mar 7, 2026
9825358
Refactor generation methods in GRPO and RLOO trainers to remove unuse…
qgallouedec Mar 7, 2026
65d62db
Refactor GRPOTrainer generation methods to remove unused extra_fields…
qgallouedec Mar 7, 2026
d1685b1
multimodal
qgallouedec Mar 7, 2026
71de8c0
fix
qgallouedec Mar 7, 2026
0a264a2
Fix tokenization padding issue in GRPOTrainer to handle unpadded inpu…
qgallouedec Mar 7, 2026
0aa0e30
style
qgallouedec Mar 7, 2026
b490357
Merge branch 'unify-tokenization-generate' into extract-tokenize-prompts
qgallouedec Mar 7, 2026
6fd47dc
Merge branch 'extract-tokenize-prompts' into fix-retokenization-tool-…
qgallouedec Mar 7, 2026
8fecba1
align rloo
qgallouedec Mar 7, 2026
6c093dd
style
qgallouedec Mar 7, 2026
a9a91c7
Merge branch 'unify-tokenization-generate' into extract-tokenize-prompts
qgallouedec Mar 7, 2026
934aae7
Merge branch 'extract-tokenize-prompts' into fix-retokenization-tool-…
qgallouedec Mar 7, 2026
7e863e1
fix
qgallouedec Mar 7, 2026
f033e63
revert doc modif
qgallouedec Mar 9, 2026
5a1f609
Merge branch 'vllm-accept-token-ids' into vllm-support-image-with-raw…
qgallouedec Mar 9, 2026
1eb3540
Merge branch 'vllm-support-image-with-raw-token' into move-rollout-func
qgallouedec Mar 9, 2026
498a564
Merge branch 'move-rollout-func' into vllm-generate-with-token-ids
qgallouedec Mar 9, 2026
be2ff99
Merge branch 'vllm-generate-with-token-ids' into unify-tokenization-g…
qgallouedec Mar 9, 2026
5df2069
Merge branch 'unify-tokenization-generate' into extract-tokenize-prompts
qgallouedec Mar 9, 2026
ae8767f
Merge branch 'extract-tokenize-prompts' into fix-retokenization-tool-…
qgallouedec Mar 9, 2026
d3f7971
Merge branch 'main' into vllm-support-image-with-raw-token
qgallouedec Mar 9, 2026
319d52a
simplify multimodal
qgallouedec Mar 9, 2026
d5e1906
Merge branch 'main' into vllm-support-image-with-raw-token
qgallouedec Mar 9, 2026
4ccadcf
Merge branch 'vllm-support-image-with-raw-token' into move-rollout-func
qgallouedec Mar 9, 2026
2a80df9
Merge branch 'move-rollout-func' into vllm-generate-with-token-ids
qgallouedec Mar 9, 2026
a0df552
Merge branch 'vllm-generate-with-token-ids' into unify-tokenization-g…
qgallouedec Mar 9, 2026
3350588
Merge branch 'unify-tokenization-generate' into extract-tokenize-prompts
qgallouedec Mar 9, 2026
19ffe9e
Merge branch 'extract-tokenize-prompts' into fix-retokenization-tool-…
qgallouedec Mar 9, 2026
0558dc9
Merge branch 'main' into move-rollout-func
qgallouedec Mar 9, 2026
6ebb681
Merge branch 'move-rollout-func' into vllm-generate-with-token-ids
qgallouedec Mar 9, 2026
93640e4
Merge branch 'vllm-generate-with-token-ids' into unify-tokenization-g…
qgallouedec Mar 9, 2026
1c009b0
Merge branch 'unify-tokenization-generate' into extract-tokenize-prompts
qgallouedec Mar 9, 2026
0c1fe0f
Merge branch 'extract-tokenize-prompts' into fix-retokenization-tool-…
qgallouedec Mar 9, 2026
97a813b
Merge branch 'main' into vllm-generate-with-token-ids
qgallouedec Mar 10, 2026
83ab9bd
Merge branch 'vllm-generate-with-token-ids' into unify-tokenization-g…
qgallouedec Mar 10, 2026
408fb2e
Merge branch 'unify-tokenization-generate' into extract-tokenize-prompts
qgallouedec Mar 10, 2026
087b5e9
Merge branch 'extract-tokenize-prompts' into fix-retokenization-tool-…
qgallouedec Mar 10, 2026
ade2831
Merge branch 'main' into vllm-generate-with-token-ids
qgallouedec Mar 10, 2026
258e0a8
Update trl/trainer/grpo_trainer.py
qgallouedec Mar 10, 2026
ef96048
Update trl/trainer/rloo_trainer.py
qgallouedec Mar 10, 2026
0ee6495
Merge branch 'vllm-generate-with-token-ids' into unify-tokenization-g…
qgallouedec Mar 10, 2026
bb6dc69
Update trl/trainer/grpo_trainer.py
qgallouedec Mar 10, 2026
0effa0d
Update trl/trainer/rloo_trainer.py
qgallouedec Mar 10, 2026
fad1fdd
Merge branch 'unify-tokenization-generate' into extract-tokenize-prompts
qgallouedec Mar 10, 2026
f2d1e01
Merge branch 'extract-tokenize-prompts' into fix-retokenization-tool-…
qgallouedec Mar 10, 2026
b35f250
Remove unused chat/tool configuration parameters from VLLM and RLOO t…
qgallouedec Mar 10, 2026
040e392
Update trl/generation/vllm_generation.py
qgallouedec Mar 10, 2026
ca2cae3
Update trl/trainer/rloo_trainer.py
qgallouedec Mar 10, 2026
fee553d
Merge branch 'main' into vllm-generate-with-token-ids
qgallouedec Mar 10, 2026
90df2de
Merge branch 'vllm-generate-with-token-ids' into unify-tokenization-g…
qgallouedec Mar 10, 2026
f36c0ea
Merge branch 'unify-tokenization-generate' into extract-tokenize-prompts
qgallouedec Mar 10, 2026
8678382
Merge branch 'extract-tokenize-prompts' into fix-retokenization-tool-…
qgallouedec Mar 10, 2026
fdaa90a
fix
qgallouedec Mar 10, 2026
6f10cd2
style
qgallouedec Mar 10, 2026
533c337
Merge branch 'unify-tokenization-generate' into extract-tokenize-prompts
qgallouedec Mar 10, 2026
50418e0
Merge branch 'extract-tokenize-prompts' into fix-retokenization-tool-…
qgallouedec Mar 10, 2026
7e7e3b3
Merge branch 'main' into unify-tokenization-generate
qgallouedec Mar 10, 2026
31d8a0c
Merge branch 'unify-tokenization-generate' into extract-tokenize-prompts
qgallouedec Mar 10, 2026
e88987f
Merge branch 'extract-tokenize-prompts' into fix-retokenization-tool-…
qgallouedec Mar 10, 2026
8b4f6af
Merge branch 'main' into extract-tokenize-prompts
qgallouedec Mar 10, 2026
a704d89
Merge branch 'extract-tokenize-prompts' into fix-retokenization-tool-…
qgallouedec Mar 10, 2026
81cf273
Merge branch 'main' into extract-tokenize-prompts
qgallouedec Mar 10, 2026
918686b
Remove dead code: eliminate prompt tokenization logic from GRPOTraine…
qgallouedec Mar 10, 2026
9b8de83
remove unused extra_fields from _generate_single_turn return value
qgallouedec Mar 10, 2026
6c8f55c
style
qgallouedec Mar 10, 2026
130d974
Merge branch 'extract-tokenize-prompts' into fix-retokenization-tool-…
qgallouedec Mar 10, 2026
8b27397
properly merge upstream
qgallouedec Mar 10, 2026
6c9db28
fix
qgallouedec Mar 10, 2026
441725b
Merge branch 'main' into fix-retokenization-tool-loop
qgallouedec Mar 13, 2026
367a79e
align with main
qgallouedec Mar 13, 2026
f3f0f8d
fix
qgallouedec Mar 14, 2026
5147625
Merge branch 'main' into fix-retokenization-tool-loop
qgallouedec Mar 14, 2026
10708ca
Merge branch 'main' into fix-retokenization-tool-loop
qgallouedec Mar 16, 2026
f81f6a9
Merge branch 'main' into fix-retokenization-tool-loop
qgallouedec Mar 18, 2026
102 changes: 63 additions & 39 deletions trl/trainer/grpo_trainer.py
@@ -1243,7 +1243,7 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields):

# Generate using vLLM with raw token IDs
num_generations = self.num_generations if mode == "train" else self.num_generations_eval
prompt_ids, completion_ids, logprobs, _ = self.vllm_generation.generate(
_, completion_ids, logprobs, _ = self.vllm_generation.generate(
prompts=prompt_ids,
images=images,
num_generations=num_generations,
@@ -1309,8 +1309,7 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields):
**generate_inputs, generation_config=self.generation_config, disable_compile=True
)
# Compute prompt length and extract completion ids
prompt_ids_tensor, prompt_mask = generate_inputs["input_ids"], generate_inputs["attention_mask"]
prompt_length = prompt_ids_tensor.size(1)
prompt_length = generate_inputs["input_ids"].size(1)
completion_ids = prompt_completion_ids[:, prompt_length:]
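In the hunk above, generation returns concatenated prompt+completion rows, and the completions are recovered by slicing past the prompt length. A minimal pure-Python stand-in for that tensor slicing (the function name here is illustrative, not from the PR):

```python
def split_completions(prompt_completion_ids, prompt_length):
    """Given batched prompt+completion rows, return just the generated tails."""
    return [row[prompt_length:] for row in prompt_completion_ids]
```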

# Mask everything after the first EOS token
@@ -1319,18 +1318,34 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields):
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
# Move tensors to CPU before per-sample to avoid many CUDA syncs/copies (costly at scale/contention).
prompt_ids = [
p[m].tolist() for p, m in zip(prompt_ids_tensor.cpu(), prompt_mask.bool().cpu(), strict=True)
]
completion_ids = [
c[m].tolist() for c, m in zip(completion_ids.cpu(), completion_mask.bool().cpu(), strict=True)
]
completion_ids = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool(), strict=True)]
logprobs = None # not used in this case

return prompt_ids, completion_ids, logprobs
return completion_ids, logprobs
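The surrounding context lines build a completion mask that zeroes everything after the first EOS token. A pure-Python sketch of the same logic (a stand-in for the tensorized version shown in the diff; rows with no EOS keep every token):

```python
def mask_after_first_eos(completion_ids, eos_token_id):
    """0/1 mask keeping tokens up to and including the first EOS in each row."""
    masks = []
    for row in completion_ids:
        # Rows without an EOS default to the last index, so the whole row is kept.
        eos_pos = row.index(eos_token_id) if eos_token_id in row else len(row) - 1
        masks.append([1 if i <= eos_pos else 0 for i in range(len(row))])
    return masks
```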

def _get_tool_suffix_ids(self, tool_messages):
"""
Get token IDs for tool result formatting by using a minimal dummy conversation."""
dummy_messages = [{"role": "user", "content": "dummy"}, {"role": "assistant", "content": "dummy"}]
prefix_ids = self.processing_class.apply_chat_template(
dummy_messages,
add_generation_prompt=False,
chat_template=self.chat_template,
return_dict=False,
**self.chat_template_kwargs,
)
full_ids = self.processing_class.apply_chat_template(
dummy_messages + tool_messages,
add_generation_prompt=True,
chat_template=self.chat_template,
return_dict=False,
**self.chat_template_kwargs,
)
if full_ids[: len(prefix_ids)] != prefix_ids:
raise ValueError("Unexpected tokenization: the prefix IDs are not a prefix of the full IDs.")
return full_ids[len(prefix_ids) :]
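The `_get_tool_suffix_ids` helper added above relies on the chat template being prefix-preserving: rendering `dummy + tool_messages` must reproduce the rendering of `dummy` verbatim at the front. A toy illustration with a hypothetical word-level template (not the real tokenizer API) shows the idea; the `ValueError` is exactly what fires for templates that re-render earlier messages:

```python
def toy_apply_chat_template(messages, add_generation_prompt=False):
    """Stand-in for a tokenizer's chat template: one "token" per role marker and word."""
    ids = []
    for m in messages:
        ids.append(f"<{m['role']}>")
        ids.extend(m["content"].split())
        ids.append(f"</{m['role']}>")
    if add_generation_prompt:
        ids.append("<assistant>")
    return ids

def get_tool_suffix_ids(tool_messages):
    """Token IDs the template appends for tool results, isolated via a dummy prefix."""
    dummy = [{"role": "user", "content": "dummy"}, {"role": "assistant", "content": "dummy"}]
    prefix_ids = toy_apply_chat_template(dummy)
    full_ids = toy_apply_chat_template(dummy + tool_messages, add_generation_prompt=True)
    if full_ids[: len(prefix_ids)] != prefix_ids:
        raise ValueError("chat template is not prefix-preserving")
    return full_ids[len(prefix_ids):]
```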

def _tool_call_loop(self, prompts, prompt_ids, completion_ids, completions, logprobs):
def _tool_call_loop(self, prompts, prompt_ids, completion_ids, completions, logprobs, images, multimodal_fields):
# Tool execution loop: execute tools, then regenerate completions with tool results appended to the prompt
tool_calls = [completion[0].get("tool_calls") for completion in completions]
idxs_with_tool = [idx for idx, tool_call in enumerate(tool_calls) if tool_call]
@@ -1397,17 +1412,24 @@ async def _run_async_tools(async_coros):
prompt_completion_tool.append(tool_message)
completions[idx_with_tool].append(tool_message)

# Tokenize and filter samples whose length exceeds max allowed length. This is important, because both
# Build token IDs by concatenation: prompt + completion + tool_suffix.
prompt_completion_tool_ids = []
for idx in range(len(idxs_with_tool)):
idx_with_tool = idxs_with_tool[idx]
# Extract trailing tool messages from completions
tool_messages = []
for message in reversed(completions[idx_with_tool]):
if message["role"] == "tool":
tool_messages.insert(0, message)
else:
break
suffix_ids = self._get_tool_suffix_ids(tool_messages)
prompt_completion_tool_ids.append(
prompt_ids[idx_with_tool] + completion_ids[idx_with_tool] + suffix_ids
)

# Filter samples whose length exceeds max allowed length. This is important, because both
# vLLM and transformers will error out if the input is longer than the model's max length.
pct_ids = self.processing_class.apply_chat_template(
prompt_completion_tools,
tools=self.tools,
chat_template=self.chat_template,
add_generation_prompt=True,
tokenize=True,
return_dict=False,
**self.chat_template_kwargs,
)
if self.use_vllm and self.vllm_mode == "colocate":
max_model_len = self.llm.llm_engine.model_config.max_model_len
elif not self.use_vllm:
@@ -1416,37 +1438,37 @@
raise NotImplementedError(
f"Unsupported mode detected: use_vllm={self.use_vllm}, vllm_mode={self.vllm_mode}"
)
overlong = [len(pct) >= max_model_len for pct in pct_ids]
overlong = [len(pct) >= max_model_len for pct in prompt_completion_tool_ids]
for idx in range(len(idxs_with_tool)):
idx_with_tool = idxs_with_tool[idx]
if overlong[idx]:
prompt_length = len(prompt_ids[idx_with_tool])
ct = pct_ids[idx][prompt_length : prompt_length + self.max_completion_length]
ct = prompt_completion_tool_ids[idx][prompt_length : prompt_length + self.max_completion_length]
completion_ids[idx_with_tool] = ct
tool_mask[idx_with_tool] += [1] * (len(ct) - len(tool_mask[idx_with_tool]))
if logprobs is not None:
logprobs[idx_with_tool] += [0.0] * (len(ct) - len(logprobs[idx_with_tool]))
# Keep only non-overlong items for further processing
idxs_with_tool = [idx for idx, o in zip(idxs_with_tool, overlong, strict=True) if not o]
prompt_completion_tools = [pct for pct, o in zip(prompt_completion_tools, overlong, strict=True) if not o]
prompt_completion_tool_ids = [
pct for pct, o in zip(prompt_completion_tool_ids, overlong, strict=True) if not o
]
if not idxs_with_tool:
break # all overlong, exit tool loop
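The overlong handling above truncates each too-long sample back to at most `max_completion_length` tokens after the prompt, rather than sending it through another generation round. A minimal sketch of that slicing, with hypothetical names:

```python
def truncate_overlong(pct_ids, prompt_length, max_model_len, max_completion_length):
    """If prompt+completion+tool IDs would exceed the model's context window,
    keep only the first max_completion_length tokens after the prompt."""
    if len(pct_ids) >= max_model_len:
        return pct_ids[prompt_length : prompt_length + max_completion_length]
    return None  # not overlong; the caller keeps it in the generation set
```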

# Generate new completions after tool execution
pct_prompt_ids, pct_images, pct_multimodal_fields = self._tokenize_prompts(prompt_completion_tools)
prompt_completion_tool_ids, post_tool_ids, post_tool_logprobs = self._generate_single_turn(
pct_prompt_ids, pct_images, pct_multimodal_fields
# Filter images and multimodal fields to match the current subset (index into full batch)
loop_images = [images[i] for i in idxs_with_tool] if images else None
loop_multimodal_fields = (
{k: [v[i] for i in idxs_with_tool] for k, v in multimodal_fields.items()}
if multimodal_fields
else None
)
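Filtering `images` and `multimodal_fields` down to the surviving indices, as done above, is a dict-of-lists subset operation. A small sketch, assuming each field holds one entry per sample in the full batch:

```python
def subset_multimodal_fields(multimodal_fields, idxs):
    """Index every per-sample list in a dict of multimodal fields by surviving indices."""
    if not multimodal_fields:
        return None
    return {k: [v[i] for i in idxs] for k, v in multimodal_fields.items()}
```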

# Sanity check: from experience, this is useful to catch bugs in the chat template
for idx in range(len(idxs_with_tool)):
idx_with_tool = idxs_with_tool[idx]
pct = prompt_completion_tool_ids[idx] # = prompt-completion-tool
if prompt_ids[idx_with_tool] != pct[: len(prompt_ids[idx_with_tool])]:
raise ValueError(
"The chat template is not prefix-preserving. Please update it to use a prefix-preserving "
"format."
)
# Generate new completions after tool execution (using concatenated IDs, no re-tokenization)
post_tool_ids, post_tool_logprobs = self._generate_single_turn(
prompt_completion_tool_ids, loop_images, loop_multimodal_fields
)

# Truncate so that pct[len(prompt_ids[idx]) :] + post_tool does not exceed max_completion_length
for idx in range(len(idxs_with_tool)):
@@ -1528,7 +1550,7 @@ def _generate(self, prompts: list):
prompt_ids, completion_ids, logprobs = output["prompt_ids"], output["completion_ids"], output["logprobs"]
else:
prompt_ids, images, multimodal_fields = self._tokenize_prompts(prompts)
prompt_ids, completion_ids, logprobs = self._generate_single_turn(prompt_ids, images, multimodal_fields)
completion_ids, logprobs = self._generate_single_turn(prompt_ids, images, multimodal_fields)
extra_fields = {}

# Decode completions. It's important to use `parse_response` when possible, because it handles tool calls.
@@ -1555,7 +1577,9 @@
logprobs,
tool_call_count,
tool_failure_count,
) = self._tool_call_loop(prompts, prompt_ids, completion_ids, completions, logprobs)
) = self._tool_call_loop(
prompts, prompt_ids, completion_ids, completions, logprobs, images, multimodal_fields
)
else:
# Support custom env_mask from rollout_func (e.g., for environment feedback masking)
# Internally treated as tool_mask - marks model tokens (1) vs external tokens (0)
19 changes: 6 additions & 13 deletions trl/trainer/rloo_trainer.py
@@ -891,7 +891,7 @@ def _tokenize_prompts(self, prompts: list):
# Unpad input_ids: remove padding tokens using attention_mask to get per-sequence lists
prompt_ids = [
[tok for tok, m in zip(ids, mask, strict=True) if m]
for ids, mask in zip(tokenized["input_ids"], tokenized["attention_mask"], strict=True)
for ids, mask in zip(tokenized["input_ids"], tokenized["attention_mask"], strict=False)
]
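The unpadding list comprehension above keeps only positions where the attention mask is 1, so it handles both left- and right-padded batches. As a standalone sketch:

```python
def unpad(input_ids, attention_mask):
    """Drop padding positions (mask == 0) from each sequence in a batch."""
    return [
        [tok for tok, m in zip(ids, mask) if m]
        for ids, mask in zip(input_ids, attention_mask)
    ]
```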
# For VLMs, the processor returns extra multimodal fields (pixel_values, image_grid_thw, etc.)
multimodal_fields = {k: v for k, v in tokenized.items() if k not in ("input_ids", "attention_mask")}
@@ -915,7 +915,7 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields):

# Generate using vLLM (note: RLOO doesn't use logprobs from generation, so we ignore them)
num_generations = self.num_generations if mode == "train" else self.num_generations_eval
prompt_ids, completion_ids, _, _ = self.vllm_generation.generate(
_, completion_ids, _, _ = self.vllm_generation.generate(
prompts=prompt_ids,
images=images,
num_generations=num_generations,
@@ -976,8 +976,7 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields):
**generate_inputs, generation_config=self.generation_config, disable_compile=True
)
# Compute prompt length and extract completion ids
prompt_ids_tensor, prompt_mask = generate_inputs["input_ids"], generate_inputs["attention_mask"]
prompt_length = prompt_ids_tensor.size(1)
prompt_length = generate_inputs["input_ids"].size(1)
completion_ids = prompt_completion_ids[:, prompt_length:]

# Mask everything after the first EOS token
@@ -986,15 +985,9 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields):
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
# Move tensors to CPU before per-sample to avoid many CUDA syncs/copies (costly at scale/contention).
prompt_ids = [
p[m].tolist() for p, m in zip(prompt_ids_tensor.cpu(), prompt_mask.bool().cpu(), strict=True)
]
completion_ids = [
c[m].tolist() for c, m in zip(completion_ids.cpu(), completion_mask.bool().cpu(), strict=True)
]
completion_ids = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool(), strict=True)]

return prompt_ids, completion_ids
return completion_ids

def _generate(self, prompts: list):
device = self.accelerator.device
@@ -1004,7 +997,7 @@
prompts = copy.deepcopy(prompts)

prompt_ids, images, multimodal_fields = self._tokenize_prompts(prompts)
prompt_ids, completion_ids = self._generate_single_turn(prompt_ids, images, multimodal_fields)
completion_ids = self._generate_single_turn(prompt_ids, images, multimodal_fields)

# Decode completions. It's important to use `parse_response` when possible, because it handles tool calls.
if is_conversational({"prompt": prompts[0]}):