From f50b53b919c92190644f492310856e750e2d1b8d Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Fri, 27 Feb 2026 09:21:37 -0500
Subject: [PATCH] fix logprobs handling

When generation is requested without top-k logprobs, the vLLM server
omits `logprob_token_ids` from its response and returns a single float
per position in `logprobs` rather than a list. The client previously
raised a KeyError on the missing key, and the GRPO trainer's top-1
reduction assumed a list per position. Make both sites handle either
response shape.
---
 trl/generation/vllm_client.py | 18 ++++++++++++------
 trl/trainer/grpo_trainer.py   |  4 +++-
 2 files changed, 15 insertions(+), 7 deletions(-)

diff --git a/trl/generation/vllm_client.py b/trl/generation/vllm_client.py
index 4455decc08d..a2fa817f4af 100644
--- a/trl/generation/vllm_client.py
+++ b/trl/generation/vllm_client.py
@@ -256,7 +256,7 @@ def generate(
                     List of lists of token IDs representing the tokenized input prompts.
                 - `completion_ids` (`list[list[int]]`):
                     List of lists of token IDs representing the model-generated completions for each prompt.
-                - `logprobs` (`list[list[list[float]]]`):
+                - `logprobs` (`list[list[list[float]]]` | `list[list[float]]`):
                     Per-token logprobs of shape (num_sequences, seq_len, num_logprobs), sorted by descending
                     probability.
                 - `logprob_token_ids` (`list[list[list[int]]]`):
@@ -287,12 +287,15 @@
         )
         if response.status_code == 200:
             json_response = response.json()
-            return {
+            result = {
                 "prompt_ids": json_response["prompt_ids"],
                 "completion_ids": json_response["completion_ids"],
                 "logprobs": json_response["logprobs"],
-                "logprob_token_ids": json_response["logprob_token_ids"],
             }
+            if "logprob_token_ids" in json_response:
+                # `logprob_token_ids` is only present when top-k logprobs (`logprobs` > 0) were requested
+                result["logprob_token_ids"] = json_response["logprob_token_ids"]
+            return result
         else:
             raise Exception(f"Request failed: {response.status_code}, {response.text}")
 
@@ -362,7 +365,7 @@ def chat(
                     List of lists of token IDs representing the tokenized input messages.
                 - `completion_ids` (`list[list[int]]`):
                     List of lists of token IDs representing the model-generated completions for each message list.
-                - `logprobs` (`list[list[list[float]]]`):
+                - `logprobs` (`list[list[list[float]]]` | `list[list[float]]`):
                     Per-token logprobs of shape (num_sequences, seq_len, num_logprobs), sorted by descending
                     probability.
                 - `logprob_token_ids` (`list[list[list[int]]]`):
@@ -404,12 +407,15 @@
         )
         if response.status_code == 200:
             json_response = response.json()
-            return {
+            result = {
                 "prompt_ids": json_response["prompt_ids"],
                 "completion_ids": json_response["completion_ids"],
                 "logprobs": json_response["logprobs"],
-                "logprob_token_ids": json_response["logprob_token_ids"],
             }
+            if "logprob_token_ids" in json_response:
+                # `logprob_token_ids` is only present when top-k logprobs (`logprobs` > 0) were requested
+                result["logprob_token_ids"] = json_response["logprob_token_ids"]
+            return result
         else:
             raise Exception(f"Request failed: {response.status_code}, {response.text}")
 
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index 5e44cee18d2..dfe54edfc9c 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -1234,7 +1234,9 @@ def _generate_single_turn(self, prompts: list):
                 prompts=prompts, num_generations=num_generations, profiler=profiling_context(self, "vLLM.generate")
             )
             # vLLM returns per-token top-k logprobs; keep only the top-1 (sampled token) logprob
-            logprobs = [[lp[0] for lp in seq] for seq in logprobs]
+            if isinstance(logprobs[0][0], list):
+                # when `logprobs` > 0 is requested, vLLM returns a top-k list per position; keep only the top-1 entry
+                logprobs = [[lp[0] for lp in seq] for seq in logprobs]
         elif self.use_transformers_paged:
             if is_conversational({"prompt": prompts[0]}):
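
For reference, a minimal standalone sketch of the two `logprobs` shapes the
client can now return and the reduction the trainer applies. The payloads and
the `normalize_logprobs` helper below are fabricated for illustration; they
are not TRL or vLLM code:

    # Fabricated response payloads; real ones come from the vLLM server.
    with_topk = {
        "prompt_ids": [[1, 2]],
        "completion_ids": [[3, 4]],
        # `logprobs` > 0 requested: one top-k list per generated token,
        # sorted by descending probability, plus the matching token IDs.
        "logprobs": [[[-0.1, -2.3], [-0.5, -1.7]]],
        "logprob_token_ids": [[[3, 9], [4, 8]]],
    }
    without_topk = {
        "prompt_ids": [[1, 2]],
        "completion_ids": [[3, 4]],
        # `logprobs` == 0: a single float per generated token, and no
        # `logprob_token_ids` key in the response at all.
        "logprobs": [[-0.1, -0.5]],
    }

    def normalize_logprobs(logprobs):
        # Hypothetical helper mirroring the trainer-side reduction: when each
        # position holds a top-k list, keep only the top-1 (sampled token)
        # logprob; otherwise the values are already flat, so pass through.
        if logprobs and logprobs[0] and isinstance(logprobs[0][0], list):
            return [[lp[0] for lp in seq] for seq in logprobs]
        return logprobs

    # Top-k case: (num_sequences, seq_len, num_logprobs) -> (num_sequences, seq_len)
    assert normalize_logprobs(with_topk["logprobs"]) == [[-0.1, -0.5]]
    # Flat case: already (num_sequences, seq_len); returned unchanged
    assert normalize_logprobs(without_topk["logprobs"]) == [[-0.1, -0.5]]
    assert "logprob_token_ids" not in without_topk

The `isinstance(logprobs[0][0], list)` guard is the same shape test the
patched trainer uses; the extra emptiness checks are added here only to keep
the sketch safe on empty inputs.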