Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions trl/generation/vllm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def generate(
List of lists of token IDs representing the tokenized input prompts.
- `completion_ids` (`list[list[int]]`):
List of lists of token IDs representing the model-generated completions for each prompt.
- `logprobs` (`list[list[list[float]]]`):
- `logprobs` (`list[list[list[float]]]` | `list[list[float]]`):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not so sure about this change, I think it's always list[list[list[float]]]:

>>> client.generate(["Hello, AI!", "Tell me a joke"], logprobs=0)["logprobs"]
[[[-0.2079554945230484], [-5.021359443664551], [-4.506778717041016], [-0.16933049261569977], [-1.328775405883789], [-5.707622051239014], [-6.522100925445557], [-1.3067556619644165], [-6.344869136810303], [-0.061117831617593765], [-1.44622802734375], [-0.04607903212308884], [-0.00957468245178461], [-3.2259726524353027], [-3.274900436401367], [-3.4954776763916016]], [[-0.0422004759311676], [-5.590917587280273], [-1.9313716888427734], [-1.106265664100647], [-0.0110595328733325], [-0.0010186012368649244], [-1.6689160474925302e-05], [-1.3351351299206726e-05], [-0.11895198374986649], [-0.0006528153317049146], [-1.1920922133867862e-06], [-0.002003330737352371], [-0.003181754844263196], [-0.049776118248701096], [-0.011047743260860443], [-0.8712134957313538]]]
>>> client.generate(["Hello, AI!", "Tell me a joke"], logprobs=0)["logprobs"]
[[[-0.2079554945230484], [-0.7713592648506165, -1.6463592052459717], [-0.4071059823036194], [-1.109581470489502, -1.609581470489502], [-0.06224556267261505], [-1.5437123775482178, -3.6687123775482178], [-1.0006335973739624, -1.5006335973739624], [-0.1264881044626236], [-0.006634233985096216], [-0.3543497920036316], [-0.8073388934135437], [-0.2931288480758667], [-0.2984274625778198], [-1.352670669555664, -1.977670669555664], [-0.2433444708585739, -1.9933444261550903], [-0.45821458101272583, -3.833214521408081]], [[-0.0422004759311676], [-2.3409173488616943, -6.840917587280273], [-0.7118824124336243], [-2.3060879707336426, -7.743587970733643], [-0.8871606588363647], [-1.029018759727478], [-0.04861651360988617], [-0.0002108589978888631], [-2.753696753643453e-05], [-1.4305012882687151e-05], [-0.05704395845532417], [-0.01012428104877472], [-0.00042632073746062815], [0.0], [-2.622600959512056e-06], [-0.004123874939978123]]]

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wait a sec, isn't it a behavioural change from vLLM?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what if you don't pass logprobs at all?

Copy link
Member

@qgallouedec qgallouedec Feb 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

server side error:

>>> client.generate(["Hello, AI!", "Tell me a joke"], logprobs=None)["logprobs"]
ERROR:    Exception in ASGI application
Traceback (most recent call last):
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/uvicorn/protocols/http/httptools_impl.py", line 416, in run_asgi
    result = await app(  # type: ignore[func-returns-value]
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        self.scope, self.receive, self.send
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/uvicorn/middleware/proxy_headers.py", line 60, in __call__
    return await self.app(scope, receive, send)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/fastapi/applications.py", line 1135, in __call__
    await super().__call__(scope, receive, send)
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/starlette/applications.py", line 107, in __call__
    await self.middleware_stack(scope, receive, send)
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/starlette/middleware/errors.py", line 186, in __call__
    raise exc
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/starlette/middleware/errors.py", line 164, in __call__
    await self.app(scope, receive, _send)
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/starlette/middleware/exceptions.py", line 63, in __call__
    await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/starlette/_exception_handler.py", line 53, in wrapped_app
    raise exc
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/starlette/_exception_handler.py", line 42, in wrapped_app
    await app(scope, receive, sender)
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/fastapi/middleware/asyncexitstack.py", line 18, in __call__
    await self.app(scope, receive, send)
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/starlette/routing.py", line 716, in __call__
    await self.middleware_stack(scope, receive, send)
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/starlette/routing.py", line 736, in app
    await route.handle(scope, receive, send)
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/starlette/routing.py", line 290, in handle
    await self.app(scope, receive, send)
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/fastapi/routing.py", line 115, in app
    await wrap_app_handling_exceptions(app, request)(scope, receive, send)
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/starlette/_exception_handler.py", line 53, in wrapped_app
    raise exc
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/starlette/_exception_handler.py", line 42, in wrapped_app
    await app(scope, receive, sender)
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/fastapi/routing.py", line 101, in app
    response = await f(request)
               ^^^^^^^^^^^^^^^^
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/fastapi/routing.py", line 377, in app
    content = await serialize_response(
              ^^^^^^^^^^^^^^^^^^^^^^^^^
    ...<10 lines>...
    )
    ^
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/fastapi/routing.py", line 215, in serialize_response
    raise ResponseValidationError(
    ...<3 lines>...
    )
fastapi.exceptions.ResponseValidationError: 2 validation errors:
  {'type': 'list_type', 'loc': ('response', 'logprobs'), 'msg': 'Input should be a valid list', 'input': None}
  {'type': 'list_type', 'loc': ('response', 'logprob_token_ids'), 'msg': 'Input should be a valid list', 'input': None}

  File "/fsx/qgallouedec/trl/trl/scripts/vllm_serve.py", line 512, in generate
    POST /generate/

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or

>>> client.generate(["Hello, AI!", "Tell me a joke"])["logprobs"]
[[[-0.2079554945230484], [-1.6463592052459717], [-0.4071059823036194], [-1.609581470489502], [-0.06224556267261505], [-1.5437123775482178], [-2.3698196411132812], [-0.29048386216163635], [-0.2877020239830017], [-2.652569532394409], [-0.4416201710700989], [-10.628130912780762], [-1.1008961200714111], [-0.37943634390830994], [-0.7756030559539795], [-14.819398880004883]], [[-0.0422004759311676], [-5.340917587280273], [-0.9934616088867188], [-1.7789040803909302], [-0.009760985150933266], [-0.00201974855735898], [-1.4781842764932662e-05], [-5.8412379075889476e-06], [-0.07062072306871414], [-0.0012309125158935785], [0.0], [-0.018214812502264977], [-0.001459129503928125], [-0.5814141035079956], [-0.01416344940662384], [-0.0002592465898487717]]]

Copy link
Collaborator

@LeonEricsson LeonEricsson Feb 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, I think it's always list[list[list[float]]]. Or actually, it should be list[list[list[float]]] | None, and vllm_serve should correctly return None when passing logprobs=None; right now it crashes as Quentin noted.

Per-token logprobs of shape (num_sequences, seq_len, num_logprobs), sorted by descending
probability.
- `logprob_token_ids` (`list[list[list[int]]]`):
Expand Down Expand Up @@ -287,12 +287,15 @@ def generate(
)
if response.status_code == 200:
json_response = response.json()
return {
result = {
"prompt_ids": json_response["prompt_ids"],
"completion_ids": json_response["completion_ids"],
"logprobs": json_response["logprobs"],
"logprob_token_ids": json_response["logprob_token_ids"],
}
if "logprob_token_ids" in json_response:
# `logprob_token_ids` only appears in the response when `logprobs` is greater than 0
result["logprob_token_ids"] = json_response["logprob_token_ids"]
return result
else:
raise Exception(f"Request failed: {response.status_code}, {response.text}")

Expand Down Expand Up @@ -362,7 +365,7 @@ def chat(
List of lists of token IDs representing the tokenized input messages.
- `completion_ids` (`list[list[int]]`):
List of lists of token IDs representing the model-generated completions for each message list.
- `logprobs` (`list[list[list[float]]]`):
- `logprobs` (`list[list[list[float]]]` | `list[list[float]]`):
Per-token logprobs of shape (num_sequences, seq_len, num_logprobs), sorted by descending
probability.
- `logprob_token_ids` (`list[list[list[int]]]`):
Expand Down Expand Up @@ -404,12 +407,15 @@ def chat(
)
if response.status_code == 200:
json_response = response.json()
return {
result = {
"prompt_ids": json_response["prompt_ids"],
"completion_ids": json_response["completion_ids"],
"logprobs": json_response["logprobs"],
"logprob_token_ids": json_response["logprob_token_ids"],
}
if "logprob_token_ids" in json_response:
# `logprob_token_ids` only appears in the response when `logprobs` is greater than 0
result["logprob_token_ids"] = json_response["logprob_token_ids"]
return result
else:
raise Exception(f"Request failed: {response.status_code}, {response.text}")

Expand Down
4 changes: 3 additions & 1 deletion trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1234,7 +1234,9 @@ def _generate_single_turn(self, prompts: list):
prompts=prompts, num_generations=num_generations, profiler=profiling_context(self, "vLLM.generate")
)
# vLLM returns per-token top-k logprobs; keep only the top-1 (sampled token) logprob
logprobs = [[lp[0] for lp in seq] for seq in logprobs]
if isinstance(logprobs[0][0], list):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't "really" like this check since it depends on a lot of nesting. Open to a better way to handle this guard.

# reduce when we request logprobs > 0 from vllm and they are returned as a list per position
logprobs = [[lp[0] for lp in seq] for seq in logprobs]

elif self.use_transformers_paged:
if is_conversational({"prompt": prompts[0]}):
Expand Down