[data][llm] Expose logprobs support in Ray Data LLM #58899
base: master
The first file in the diff extends vLLMOutputData with logprobs fields:
```diff
@@ -87,6 +87,14 @@ class vLLMOutputData(BaseModel):
     # Metrics fields.
     metrics: Optional[Dict[str, Any]] = None

+    # Logprobs fields.
+    # logprobs: List[Dict[int, Dict[str, Any]]] where each dict maps token_id to
+    # logprob info (logprob, rank, decoded_token) for each generated token.
+    logprobs: Optional[List[Dict[int, Dict[str, Any]]]] = None
+    # prompt_logprobs: List[Optional[Dict[int, Dict[str, Any]]]] where each dict
+    # (or None) maps token_id to logprob info for each prompt token.
+    prompt_logprobs: Optional[List[Optional[Dict[int, Dict[str, Any]]]]] = None
+
     @classmethod
     def from_vllm_engine_output(cls, output: Any) -> "vLLMOutputData":
         """Create a vLLMOutputData from a vLLM engine output."""
```
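For illustration (values taken from the unit test added later in this PR), a serialized `logprobs` field ends up as plain nested dicts, one entry per generated token:

```python
# One entry per generated token; each maps candidate token_id -> logprob info.
logprobs = [
    {
        123: {"logprob": -0.5, "rank": 1, "decoded_token": "hello"},
        456: {"logprob": -1.2, "rank": 2, "decoded_token": "hi"},
    },
]
```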
```diff
@@ -111,6 +119,28 @@ def from_vllm_engine_output(cls, output: Any) -> "vLLMOutputData":
             data.generated_tokens = output.outputs[0].token_ids
             data.generated_text = output.outputs[0].text
             data.num_generated_tokens = len(output.outputs[0].token_ids)

+            # Extract logprobs
+            if output.outputs[0].logprobs is not None:
+                data.logprobs = [
+                    {
+                        token_id: dataclasses.asdict(logprob)
+                        for token_id, logprob in logprob_dict.items()
+                    }
+                    for logprob_dict in output.outputs[0].logprobs
+                ]
+
+            # Extract prompt_logprobs
+            if output.prompt_logprobs is not None:
+                data.prompt_logprobs = [
+                    {
+                        token_id: dataclasses.asdict(logprob)
+                        for token_id, logprob in logprob_dict.items()
+                    }
+                    if logprob_dict is not None
+                    else None
+                    for logprob_dict in output.prompt_logprobs
+                ]
+
```
Contributor comment on lines +135 to +143:
make this part a utility and reuse between prompt_logprobs and logprobs
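A minimal sketch of that suggestion, assuming a module-level helper (the name `_serialize_logprob_dict` is hypothetical, not part of this PR) that both branches could call:

```python
import dataclasses
from typing import Any, Dict, Optional


def _serialize_logprob_dict(
    logprob_dict: Optional[Dict[int, Any]],
) -> Optional[Dict[int, Dict[str, Any]]]:
    """Convert one vLLM {token_id: Logprob} mapping into plain dicts.

    Passes None through unchanged so the same helper works for
    prompt_logprobs, where individual entries may be None.
    """
    if logprob_dict is None:
        return None
    return {
        token_id: dataclasses.asdict(logprob)
        for token_id, logprob in logprob_dict.items()
    }


# Reused for both fields:
# data.logprobs = [_serialize_logprob_dict(d) for d in output.outputs[0].logprobs]
# data.prompt_logprobs = [_serialize_logprob_dict(d) for d in output.prompt_logprobs]
```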
The rest of the hunk is unchanged context:

```diff
         elif isinstance(output, vllm.outputs.PoolingRequestOutput):
             data.embeddings = output.outputs.data.cpu()
             if (
```
In the corresponding test file, the diff adds an import and two unit tests:
```diff
@@ -11,6 +11,7 @@
     vLLMEngineStage,
     vLLMEngineStageUDF,
     vLLMEngineWrapper,
+    vLLMOutputData,
     vLLMTaskType,
 )
 from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
```
```diff
@@ -439,5 +440,105 @@ class AnswerModel(BaseModel):
     wrapper.shutdown()


+def test_vllm_output_data_logprobs():
```
Contributor comment:
if I understand correctly, this test does not test any code path in the Ray Data LLM code; it's only testing the data type logic in vLLM. Is that intentional?
```diff
+    """Test that logprobs and prompt_logprobs are correctly extracted."""
+    from vllm.logprobs import Logprob
+    from vllm.outputs import CompletionOutput, RequestOutput
+
+    logprobs = [
+        {
+            123: Logprob(logprob=-0.5, rank=1, decoded_token="hello"),
```
Contributor comment:
question, what is rank here? Rank of the TP workers, or some other rank?
```diff
+            456: Logprob(logprob=-1.2, rank=2, decoded_token="hi"),
+        },
+        {
+            789: Logprob(logprob=-0.3, rank=1, decoded_token="world"),
+            999: Logprob(logprob=-1.5, rank=2, decoded_token="earth"),
+        },
+    ]
+
+    prompt_logprobs = [
+        None,
+        {
+            111: Logprob(logprob=-0.1, rank=1, decoded_token="test"),
+            222: Logprob(logprob=-0.8, rank=2, decoded_token="demo"),
+        },
+    ]
+
+    request_output = RequestOutput(
+        request_id="test",
+        prompt="test prompt",
+        prompt_token_ids=[1, 2],
+        prompt_logprobs=prompt_logprobs,
+        outputs=[
+            CompletionOutput(
+                index=0,
+                text="hello world",
+                token_ids=[123, 789],
+                cumulative_logprob=-0.8,
+                logprobs=logprobs,
+            )
+        ],
+        finished=True,
+    )
+
+    output_data = vLLMOutputData.from_vllm_engine_output(request_output)
+
+    expected_logprobs = [
+        {
+            123: {"logprob": -0.5, "rank": 1, "decoded_token": "hello"},
+            456: {"logprob": -1.2, "rank": 2, "decoded_token": "hi"},
+        },
+        {
+            789: {"logprob": -0.3, "rank": 1, "decoded_token": "world"},
+            999: {"logprob": -1.5, "rank": 2, "decoded_token": "earth"},
+        },
+    ]
+    assert output_data.logprobs == expected_logprobs
+
+    expected_prompt_logprobs = [
+        None,
+        {
+            111: {"logprob": -0.1, "rank": 1, "decoded_token": "test"},
+            222: {"logprob": -0.8, "rank": 2, "decoded_token": "demo"},
+        },
+    ]
```
Contributor comment:
same for this test.
```diff
+    assert output_data.prompt_logprobs == expected_prompt_logprobs
+
+    dumped = output_data.model_dump()
+    assert dumped["logprobs"] == expected_logprobs
+    assert dumped["prompt_logprobs"] == expected_prompt_logprobs
+
+
+def test_vllm_output_data_no_logprobs():
+    """Test that None logprobs are handled correctly when not requested."""
+    from vllm.outputs import CompletionOutput, RequestOutput
+
+    request_output = RequestOutput(
+        request_id="test",
+        prompt="test prompt",
+        prompt_token_ids=[1, 2],
+        prompt_logprobs=None,
+        outputs=[
+            CompletionOutput(
+                index=0,
+                text="test response",
+                token_ids=[4, 5, 6],
+                cumulative_logprob=None,
+                logprobs=None,
+            )
+        ],
+        finished=True,
+    )
+
+    output_data = vLLMOutputData.from_vllm_engine_output(request_output)
+
+    assert output_data.logprobs is None
+    assert output_data.prompt_logprobs is None
+
+    dumped = output_data.model_dump()
+    assert dumped["logprobs"] is None
+    assert dumped["prompt_logprobs"] is None
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main(["-v", __file__]))
```
Contributor comment:
What are the possible types for the logprob object? Can it be pydantic as well as a dataclass? (Think about future changes that could happen and the differences between SGLang and vLLM.) I am afraid using dataclasses.asdict() might overfit to today's version.
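One hedged sketch of how that concern could be addressed (the helper name `_logprob_to_dict` is hypothetical, not part of this PR): dispatch on what the logprob object actually is, so the extraction keeps working if the engine switches between dataclasses, pydantic models, or plain dicts.

```python
import dataclasses
from typing import Any, Dict


def _logprob_to_dict(logprob: Any) -> Dict[str, Any]:
    """Best-effort conversion of an engine logprob object to a plain dict."""
    if dataclasses.is_dataclass(logprob) and not isinstance(logprob, type):
        # vLLM's Logprob is currently a dataclass.
        return dataclasses.asdict(logprob)
    if hasattr(logprob, "model_dump"):
        # Pydantic v2 models expose model_dump().
        return logprob.model_dump()
    if hasattr(logprob, "_asdict"):
        # NamedTuple-style objects.
        return dict(logprob._asdict())
    if isinstance(logprob, dict):
        return dict(logprob)
    # Last resort: pull the known logprob attributes directly.
    return {
        field: getattr(logprob, field)
        for field in ("logprob", "rank", "decoded_token")
        if hasattr(logprob, field)
    }
```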