[data][llm] Expose logprobs support in Ray Data LLM #58899
base: master
Conversation
Extract and surface logprobs from vLLM outputs. Previously, logprobs could be requested in sampling_params but were not returned in output rows. This adds logprobs and prompt_logprobs fields to vLLMOutputData and extracts them from vLLM's RequestOutput. Signed-off-by: Nikhil Ghosh <[email protected]>
Add unit tests for logprobs and prompt_logprobs extraction from vLLM outputs, including cases with multiple logprobs per token and None values. Signed-off-by: Nikhil Ghosh <[email protected]>
/gemini review
Code Review
This pull request effectively exposes logprobs and prompt_logprobs from vLLM outputs in Ray Data LLM. The modifications to the vLLMOutputData model and its from_vllm_engine_output factory method are clear and correct. The new unit tests validate the functionality, and I've provided suggestions to enhance them by making assertions more comprehensive. This will improve test robustness and maintainability by verifying the entire data structure.
Signed-off-by: Nikhil Ghosh <[email protected]>
```python
if output.outputs[0].logprobs is not None:
    data.logprobs = [
        {
            token_id: dataclasses.asdict(logprob)
            for token_id, logprob in logprob_dict.items()
        }
        for logprob_dict in output.outputs[0].logprobs
    ]
```
What are the possible types of the logprob object? Could it be a pydantic model as well as a dataclass? (Think about future changes, and the differences between SGLang and vLLM.) I am afraid using `dataclasses.asdict()` might overfit to today's version.
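One way to address this concern is a small converter that is agnostic to the engine's representation, handling dataclasses (vLLM today), pydantic v2 models, or plain dicts. This is a hedged sketch, not the PR's implementation; the function name `logprob_to_dict` is hypothetical:

```python
import dataclasses
from typing import Any


def logprob_to_dict(logprob: Any) -> dict:
    """Convert a per-token logprob object to a plain dict, whatever
    the engine happens to return it as."""
    if dataclasses.is_dataclass(logprob) and not isinstance(logprob, type):
        return dataclasses.asdict(logprob)  # vLLM's Logprob dataclass
    if hasattr(logprob, "model_dump"):  # pydantic v2 models
        return logprob.model_dump()
    if isinstance(logprob, dict):  # already serialized
        return logprob
    # Fall back to the object's attribute dict.
    return vars(logprob)
```

This keeps the extraction code working even if a future engine version swaps the underlying type.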
```python
data.prompt_logprobs = [
    {
        token_id: dataclasses.asdict(logprob)
        for token_id, logprob in logprob_dict.items()
    }
    if logprob_dict is not None
    else None
    for logprob_dict in output.prompt_logprobs
]
```
Make this part a utility and reuse it between `prompt_logprobs` and `logprobs`.
```python
logprobs = [
    {
        123: Logprob(logprob=-0.5, rank=1, decoded_token="hello"),
    },
]
```
Question: what is `rank` here? The rank of the TP workers, or some other rank?
```python
wrapper.shutdown()


def test_vllm_output_data_logprobs():
```
If I understand correctly, this test does not exercise any code path in Ray Data LLM; it only tests the data type logic in vLLM. Is that intentional?
```python
        111: {"logprob": -0.1, "rank": 1, "decoded_token": "test"},
        222: {"logprob": -0.8, "rank": 2, "decoded_token": "demo"},
    },
]
```
Same for this test.
Description
Exposes logprobs support in Ray Data LLM by extracting logprobs and prompt_logprobs from vLLM outputs and including them in the output rows.
Changes
- Added `logprobs` and `prompt_logprobs` fields to the `vLLMOutputData` model
- Extracts `output.outputs[0].logprobs` in `from_vllm_engine_output`
- Extracts `output.prompt_logprobs` in `from_vllm_engine_output`
- Converts `Logprob` dataclass instances to serializable dicts using `dataclasses.asdict()`

Testing
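With these changes, each output row carries `logprobs` and `prompt_logprobs` as lists of `{token_id: {logprob, rank, decoded_token}}` dicts, one entry per generated (or prompt) token. A sketch of the expected row shape, with hypothetical token ids and values:

```python
# Hypothetical output row after requesting logprobs in sampling_params.
row = {
    "generated_text": "hello world",
    "logprobs": [
        # One dict per generated token; keys are candidate token ids.
        {
            123: {"logprob": -0.5, "rank": 1, "decoded_token": "hello"},
            456: {"logprob": -1.2, "rank": 2, "decoded_token": "hi"},
        },
    ],
    "prompt_logprobs": [
        None,  # vLLM emits None for the first prompt token
        {789: {"logprob": -0.3, "rank": 1, "decoded_token": "world"}},
    ],
}

# Example downstream use: the top-ranked candidate at the first position.
best_id, best = min(row["logprobs"][0].items(), key=lambda kv: kv[1]["rank"])
```

Here `best_id` would be `123`, the rank-1 candidate for the first generated token.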
Added unit tests verifying:
Related issues
Closes #58894
Additional information