@@ -665,6 +665,126 @@ async def cancel_responses(response_id: str, raw_request: Request):
     return JSONResponse(content=response.model_dump())
 
 
+if envs.VLLM_V1_SPANS_ENABLED:
+    import time
+
+    import spnl
+    from fastapi import Body
+
+    from vllm import SamplingParams
+    from vllm.entrypoints.openai.protocol import (
+        ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
+        ChatCompletionStreamResponse, ChatMessage, DeltaMessage, UsageInfo)
+    from vllm.inputs import TokensPrompt
+    from vllm.outputs import CompletionOutput
+
+    spnl_state = spnl.init(10)
+    PAD_TOKEN = 27
+    PLUS_TOKEN = (envs.VLLM_V1_SPANS_TOKEN_PLUS
+                  if envs.VLLM_V1_SPANS_TOKEN_PLUS >= 0 else None)
+    CROSS_TOKEN = (envs.VLLM_V1_SPANS_TOKEN_CROSS
+                   if envs.VLLM_V1_SPANS_TOKEN_CROSS >= 0 else None)
+
+    def wrap(
+        prompt: list[int] | list[list[int]]
+    ) -> TokensPrompt | list[TokensPrompt]:
+        """Wrap pre-tokenized prompt(s) as TokensPrompt engine inputs."""
+        if isinstance(prompt[0], list):
+            return [TokensPrompt(prompt_token_ids=p) for p in prompt]
+        return TokensPrompt(prompt_token_ids=prompt)
+
+    @router.post("/v1/query/prepare")
+    @with_cancellation
+    @load_aware_call
+    async def prepare_query(raw_request: Request,
+                            query: str = Body(..., media_type="text/plain")):
+        docs = [
+            wrap(doc) for doc in spnl.tokenize_prepare(
+                spnl_state,
+                query,
+                # we need to preload the prefix of the plus/independent spans
+                True,
+                PAD_TOKEN,
+                PLUS_TOKEN,
+                raw_request.app.state.vllm_config.cache_config.block_size)
+        ]
+
+        request_id = raw_request.headers.get(
+            "X-Request-Id") or uuid.uuid4().hex
+        client = engine_client(raw_request)
+        generators = [
+            client.generate(doc, SamplingParams(temperature=0, max_tokens=1),
+                            request_id) for doc in docs
+        ]
+        # Drain each generator so the spans are prefilled and cached.
+        for generator in generators:
+            async for res in generator:
+                final = res.outputs[0]
+
+        if isinstance(generator, ErrorResponse):
+            return JSONResponse(content=generator.model_dump(),
+                                status_code=generator.error.code)
+        return JSONResponse(content={"success": True})
+
+    @router.post("/v1/query/execute")
+    @with_cancellation
+    @load_aware_call
+    async def execute_query(raw_request: Request,
+                            query: str = Body(..., media_type="text/plain"),
+                            stream: bool = False):
+        req = spnl.tokenize_query(
+            spnl_state,
+            query,
+            PAD_TOKEN,
+            CROSS_TOKEN,
+            PLUS_TOKEN,
+            raw_request.app.state.vllm_config.cache_config.block_size)
+
+        request_id = raw_request.headers.get(
+            "X-Request-Id") or uuid.uuid4().hex
+        client = engine_client(raw_request)
+        generator = client.generate(
+            wrap(req.messages),
+            SamplingParams(
+                n=1 if req.n <= 0 else req.n,
+                temperature=req.temperature
+                if req.temperature is not None else 0,
+                max_tokens=req.max_tokens if req.max_tokens is not None
+                and req.max_tokens != 0 else 2048),
+            request_id)
+
+        if stream:
+
+            async def sgen():
+                # Track how much of each choice's text has already been sent,
+                # so every chunk only carries the newly generated delta.
+                output_idx: list[int] = [0 for _ in range(req.n)]
+                async for res in generator:
+                    chunk = ChatCompletionStreamResponse(
+                        id=request_id,
+                        object="chat.completion.chunk",
+                        created=int(time.time()),
+                        model=req.model,
+                        choices=[
+                            ChatCompletionResponseStreamChoice(
+                                index=index,
+                                delta=DeltaMessage(
+                                    role="assistant",
+                                    content=output.text[output_idx[index]:]),
+                                logprobs=output.logprobs,
+                                finish_reason=output.finish_reason,
+                                stop_reason=output.stop_reason,
+                            ) for index, output in enumerate(res.outputs)
+                        ])
+                    yield ("data: " +
+                           chunk.model_dump_json(exclude_unset=True) + "\n\n")
+                    for index, output in enumerate(res.outputs):
+                        output_idx[index] = len(output.text)
+                yield "data: [DONE]\n\n"
+
+            return StreamingResponse(content=sgen(),
+                                     media_type="text/event-stream")
+
+        outputs: list[CompletionOutput | None] = [None for _ in range(req.n)]
+        async for res in generator:
+            for output in res.outputs:
+                outputs[output.index] = output
+        choices = [
+            ChatCompletionResponseChoice(
+                index=index,
+                message=ChatMessage(role="assistant", content=output.text),
+                logprobs=output.logprobs,
+                finish_reason=output.finish_reason,
+                stop_reason=output.stop_reason,
+            ) for index, output in enumerate(outputs)
+        ]
+        num_prompt_tokens = 0  # TODO
+        num_generated_tokens = 0  # TODO
+        usage = UsageInfo(prompt_tokens=num_prompt_tokens,
+                          completion_tokens=num_generated_tokens,
+                          total_tokens=num_prompt_tokens +
+                          num_generated_tokens)
+        response = ChatCompletionResponse(id=request_id,
+                                          created=int(time.time()),
+                                          model=req.model,
+                                          choices=choices,
+                                          usage=usage)
+
+        if isinstance(generator, ErrorResponse):
+            return JSONResponse(content=generator.model_dump(),
+                                status_code=generator.error.code)
+        return JSONResponse(content=response.model_dump())
+
 @router.post("/v1/chat/completions",
              dependencies=[Depends(validate_json_request)],
              responses={
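Below is a minimal client-side sketch (not part of the diff above) of how the two new routes might be exercised. The base URL, request id, and query file are illustrative assumptions; the query is posted as plain text to match Body(..., media_type="text/plain"), and the response shapes follow the handlers shown in the diff.

# Hypothetical client for the new /v1/query/prepare and /v1/query/execute
# routes. Assumptions: the API server listens on localhost:8000 and
# query.spnl contains the SPNL query text.
import requests

BASE_URL = "http://localhost:8000"          # assumed server address
query = open("query.spnl").read()           # hypothetical query file
headers = {"Content-Type": "text/plain", "X-Request-Id": "spans-demo"}

# Warm the KV cache for the independent ("plus") spans of the query.
prep = requests.post(f"{BASE_URL}/v1/query/prepare", data=query,
                     headers=headers)
prep.raise_for_status()                     # body is {"success": true}

# Execute the query; with stream left at its default (False), the body
# mirrors a ChatCompletionResponse.
resp = requests.post(f"{BASE_URL}/v1/query/execute", data=query,
                     headers=headers)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])

Passing stream=true as a query parameter instead returns text/event-stream chunks shaped like ChatCompletionStreamResponse.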