@@ -399,6 +399,48 @@ def _get_logprobs(self,
399399 logprobs [token_id ] = logprob .logprob
400400 return super ()._get_logprobs (logprobs_list , token_ids , top_logprobs )
401401
def _get_prompt_logprobs(
    self,
    prompt_logprobs: Optional[List[Optional[Dict]]],
    prompt_token_ids: List[int],
) -> Optional[List[Dict[str, Any]]]:
    """Convert per-position prompt logprobs into a list of plain dicts.

    Positions are formed by pairing ``prompt_token_ids`` with
    ``prompt_logprobs`` via ``zip``, so the output has as many entries as
    the shorter of the two sequences.

    Args:
        prompt_logprobs: One item per prompt position. Each item is either
            ``None`` or a mapping from token id to a logprob entry. An
            entry may be an object exposing a ``.logprob`` attribute or a
            bare number (presumably the engine's logprob objects; confirm
            against the caller).
        prompt_token_ids: Token ids of the prompt.

    Returns:
        ``None`` when ``prompt_logprobs`` is ``None`` or
        ``prompt_token_ids`` is empty. Otherwise a list of dicts with keys:

        * ``'token_id'``: the prompt token id at that position.
        * ``'token'``: ``self.tokenizer.decode(token_id)``.
        * ``'logprob'``: the logprob of that token, or ``None`` when the
          position has no logprobs or the token id is not in the mapping.
        * ``'top_logprobs'``: dicts with ``'token_id'``, ``'token'`` and
          ``'logprob'`` for every candidate at that position, ordered by
          descending logprob; candidates whose logprob is ``-inf`` are
          omitted.
    """
    if prompt_logprobs is None or not prompt_token_ids:
        return None

    def _logprob_value(logprob_obj):
        # Entries may be objects carrying `.logprob` or already-bare numbers.
        return logprob_obj.logprob if hasattr(logprob_obj, 'logprob') else logprob_obj

    neg_inf = float('-inf')
    result = []
    for token_id, pos_logprobs in zip(prompt_token_ids, prompt_logprobs):
        entry = {
            'token_id': token_id,
            'token': self.tokenizer.decode(token_id),
            'logprob': None,  # Filled below when this position reports the token
            'top_logprobs': [],
        }

        if pos_logprobs is not None:
            # Logprob of the token that actually occurs at this position.
            if token_id in pos_logprobs:
                entry['logprob'] = _logprob_value(pos_logprobs[token_id])

            # Resolve each candidate's value once, then order by descending
            # logprob. The sort is stable, so ties keep the mapping's order.
            candidates = [(tid, _logprob_value(obj)) for tid, obj in pos_logprobs.items()]
            candidates.sort(key=lambda item: -item[1])
            entry['top_logprobs'] = [
                {
                    'token_id': tid,
                    'token': self.tokenizer.decode(tid),
                    'logprob': value,
                }
                for tid, value in candidates
                if value != neg_inf  # -inf candidates are not reported
            ]

        result.append(entry)

    return result
443+
402444 def _prepare_generation_config (self , request_config : RequestConfig ) -> SamplingParams :
403445 kwargs = {'max_tokens' : request_config .max_tokens }
404446 for key in ['temperature' , 'top_k' , 'top_p' , 'repetition_penalty' ]:
@@ -424,6 +466,10 @@ def _prepare_generation_config(self, request_config: RequestConfig) -> SamplingP
424466 # Return only the sampled token's logprob
425467 kwargs ['logprobs' ] = 0
426468
469+ # Handle prompt_logprobs: return logprobs for prompt/input tokens
470+ if request_config .prompt_logprobs is not None :
471+ kwargs ['prompt_logprobs' ] = request_config .prompt_logprobs
472+
427473 # TODO: beam search
428474 for key in ['n' , 'best_of' , 'frequency_penalty' , 'presence_penalty' , 'seed' ]:
429475 if hasattr (SamplingParams , key ):
@@ -582,13 +628,21 @@ def _create_chat_completion_response(
582628 logprobs = self ._get_logprobs (output .logprobs , output .token_ids , request_config .top_logprobs )
583629 toolcall = self ._get_toolcall (content ) # Use content instead of response for tool calls
584630 token_ids = output .token_ids if request_config .return_details else None
631+
632+ # Get prompt logprobs if requested
633+ prompt_logprobs_result = None
634+ if request_config .prompt_logprobs is not None :
635+ prompt_logprobs_result = self ._get_prompt_logprobs (result .prompt_logprobs ,
636+ list (result .prompt_token_ids ))
637+
585638 choice = ChatCompletionResponseChoice (
586639 index = output .index ,
587640 message = ChatMessage (
588641 role = 'assistant' , content = content , reasoning_content = reasoning_content , tool_calls = toolcall ),
589642 finish_reason = output .finish_reason ,
590643 logprobs = logprobs ,
591- token_ids = token_ids )
644+ token_ids = token_ids ,
645+ prompt_logprobs = prompt_logprobs_result )
592646 choices .append (choice )
593647 prompt_token_ids = None
594648 images_size = None
0 commit comments