
Commit de493c5

feat: /v1/query/{prepare,execute} endpoints for span queries
This PR adds two endpoints:

- `/v1/query/prepare`
- `/v1/query/execute`

These add REST APIs around the spans core feature. They take as input a [span query](https://github.com/IBM/spnl). This PR also adds example queries under examples/offline_inference/spans/query-{ab,ba}.json. See the [spans README](examples/offline_inference/spans/README.md) for an example of usage.

Signed-off-by: Nick Mitchell <[email protected]>
1 parent 1c3cd0e commit de493c5

File tree

5 files changed: +195 −0 lines changed

examples/offline_inference/spans/README.md

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
# Span (a.k.a. Block Attention) Examples

## Span Queries

This directory contains a [span query](https://github.com/IBM/spnl#readme). To send a query, first prepare the query shape:

```bash
curl -s -XPOST http://localhost:8000/v1/query/prepare --data @./query-ab.json -o /dev/null -w "%{time_total}\n"
1.504452
```

And then you can execute the query in either order, and you should see millisecond-level TTFT:

```bash
curl -s -XPOST http://localhost:8000/v1/query/execute --data @./query-ba.json -o /dev/null -w "%{time_total}\n"
0.077699
```

```bash
curl -s -XPOST http://localhost:8000/v1/query/execute --data @./query-ab.json -o /dev/null -w "%{time_total}\n"
0.078419
```
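
For reference, here is a minimal Python sketch of the same prepare-then-execute flow shown in the curl commands above. It is illustrative rather than part of this PR: it assumes a vLLM server started with `VLLM_V1_SPANS_ENABLED` listening on `localhost:8000`, the `query-ab.json` / `query-ba.json` files from this directory, and the third-party `requests` package.

```python
# Minimal sketch of the prepare/execute flow shown above with curl.
# Assumptions: server at localhost:8000 started with VLLM_V1_SPANS_ENABLED,
# and query-ab.json / query-ba.json in the current directory.
import time

import requests

BASE = "http://localhost:8000"


def post_query(endpoint: str, path: str) -> float:
    """POST a span-query file and return the elapsed wall-clock time."""
    with open(path, "rb") as f:
        body = f.read()
    start = time.perf_counter()
    resp = requests.post(f"{BASE}{endpoint}", data=body)
    resp.raise_for_status()
    return time.perf_counter() - start


# Prepare once (slower: this should warm the cache for the query's spans) ...
print("prepare:", post_query("/v1/query/prepare", "query-ab.json"))

# ... then execute in either order; both should come back quickly.
print("execute ba:", post_query("/v1/query/execute", "query-ba.json"))
print("execute ab:", post_query("/v1/query/execute", "query-ab.json"))
```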

examples/offline_inference/spans/query-ab.json

Lines changed: 26 additions & 0 deletions
Large diffs are not rendered by default.

examples/offline_inference/spans/query-ba.json

Lines changed: 26 additions & 0 deletions
Large diffs are not rendered by default.

requirements/common.txt

Lines changed: 1 addition & 0 deletions
@@ -49,3 +49,4 @@ pybase64 # fast base64 implementation
 cbor2 # Required for cross-language serialization of hashable objects
 setproctitle # Used to set process names for better debugging and monitoring
 openai-harmony >= 0.0.3 # Required for gpt-oss
+spnl >= 0.8.0 # Required for span queries

vllm/entrypoints/openai/api_server.py

Lines changed: 120 additions & 0 deletions
@@ -665,6 +665,126 @@ async def cancel_responses(response_id: str, raw_request: Request):
    return JSONResponse(content=response.model_dump())

if envs.VLLM_V1_SPANS_ENABLED:
    import spnl
    import time
    from fastapi import Body
    from vllm import SamplingParams
    from vllm.inputs import TokensPrompt
    from vllm.outputs import CompletionOutput, RequestOutput
    from vllm.entrypoints.openai.protocol import (ChatMessage,
                                                  ChatCompletionStreamResponse,
                                                  ChatCompletionResponseStreamChoice,
                                                  ChatCompletionResponseChoice,
                                                  DeltaMessage, UsageInfo)

    spnl_state = spnl.init(10)
    PAD_TOKEN = 27
    PLUS_TOKEN = (envs.VLLM_V1_SPANS_TOKEN_PLUS
                  if envs.VLLM_V1_SPANS_TOKEN_PLUS >= 0 else None)
    CROSS_TOKEN = (envs.VLLM_V1_SPANS_TOKEN_CROSS
                   if envs.VLLM_V1_SPANS_TOKEN_CROSS >= 0 else None)

    def wrap(
        prompt: list[int] | list[list[int]]
    ) -> TokensPrompt | list[TokensPrompt]:
        """Wrap tokenized span(s) in TokensPrompt objects for the engine."""
        if isinstance(prompt[0], list):
            return [TokensPrompt(prompt_token_ids=p) for p in prompt]
        return TokensPrompt(prompt_token_ids=prompt)

    @router.post("/v1/query/prepare")
    @with_cancellation
    @load_aware_call
    async def prepare_query(raw_request: Request,
                            query: str = Body(..., media_type="text/plain")):
        """Tokenize a span query and warm the cache for its spans."""
        docs = [
            wrap(doc) for doc in spnl.tokenize_prepare(
                spnl_state,
                query,
                True,  # we need to preload the prefix of the plus/independent spans
                PAD_TOKEN,
                PLUS_TOKEN,
                raw_request.app.state.vllm_config.cache_config.block_size,
            )
        ]

        request_id = raw_request.headers.get("X-Request-Id") or uuid.uuid4().hex
        client = engine_client(raw_request)
        # Generate a single token per span so that its prefix gets cached.
        generators = [
            client.generate(doc, SamplingParams(temperature=0, max_tokens=1),
                            request_id) for doc in docs
        ]
        # Drain each generator so the requests actually run.
        for generator in generators:
            async for res in generator:
                final = res.outputs[0]

        if isinstance(generator, ErrorResponse):
            return JSONResponse(content=generator.model_dump(),
                                status_code=generator.error.code)
        return JSONResponse(content={"success": True})

    @router.post("/v1/query/execute")
    @with_cancellation
    @load_aware_call
    async def execute_query(raw_request: Request,
                            query: str = Body(..., media_type="text/plain"),
                            stream: bool = False):
        """Tokenize and run a span query, optionally streaming the output."""
        req = spnl.tokenize_query(
            spnl_state,
            query,
            PAD_TOKEN,
            CROSS_TOKEN,
            PLUS_TOKEN,
            raw_request.app.state.vllm_config.cache_config.block_size,
        )

        request_id = raw_request.headers.get("X-Request-Id") or uuid.uuid4().hex
        client = engine_client(raw_request)
        generator = client.generate(
            wrap(req.messages),
            SamplingParams(
                n=1 if req.n <= 0 else req.n,
                temperature=req.temperature
                if req.temperature is not None else 0,
                max_tokens=req.max_tokens
                if req.max_tokens is not None and req.max_tokens != 0 else 2048),
            request_id)

        if stream:

            async def sgen():
                # Track how much of each choice's text has already been sent,
                # so that each chunk only carries the newly generated delta.
                output_idx: list[int] = [0 for _ in range(req.n)]
                async for res in generator:
                    yield ChatCompletionStreamResponse(
                        id=request_id,
                        object="chat.completion.chunk",
                        created=int(time.time()),
                        model=req.model,
                        choices=[
                            ChatCompletionResponseStreamChoice(
                                index=index,
                                delta=DeltaMessage(
                                    role="assistant",
                                    content=output.text[output_idx[index]:]),
                                logprobs=output.logprobs,
                                finish_reason=output.finish_reason,
                                stop_reason=output.stop_reason,
                            ) for index, output in enumerate(res.outputs)
                        ]).model_dump_json(exclude_unset=True)
                    for index, output in enumerate(res.outputs):
                        output_idx[index] = len(output.text)

            return StreamingResponse(content=sgen(),
                                     media_type="text/event-stream")

        # Non-streaming path: drain the generator and keep the final output
        # for each choice.
        outputs: list[CompletionOutput | None] = [None for _ in range(req.n)]
        async for res in generator:
            for output in res.outputs:
                outputs[output.index] = output
        choices = [
            ChatCompletionResponseChoice(
                index=index,
                message=ChatMessage(role="assistant", content=output.text),
                logprobs=output.logprobs,
                finish_reason=output.finish_reason,
                stop_reason=output.stop_reason,
            ) for index, output in enumerate(outputs)
        ]
        num_prompt_tokens = 0  # TODO
        num_generated_tokens = 0  # TODO
        usage = UsageInfo(prompt_tokens=num_prompt_tokens,
                          completion_tokens=num_generated_tokens,
                          total_tokens=num_prompt_tokens + num_generated_tokens)
        response = ChatCompletionResponse(id=request_id,
                                          created=int(time.time()),
                                          model=req.model,
                                          choices=choices,
                                          usage=usage)

        if isinstance(generator, ErrorResponse):
            return JSONResponse(content=generator.model_dump(),
                                status_code=generator.error.code)
        return JSONResponse(content=response.model_dump())

@router.post("/v1/chat/completions",
             dependencies=[Depends(validate_json_request)],
             responses={
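
Since `stream` is declared as a plain `bool` argument of `execute_query`, FastAPI should expose it as a query parameter, so the streaming path can be exercised with `POST /v1/query/execute?stream=true`. Below is a rough client sketch for consuming that stream; it is illustrative, assumes the third-party `requests` package and the example query file from this PR, and simply prints the serialized `ChatCompletionStreamResponse` chunks that `sgen()` yields (the generator does not add `data:` SSE framing).

```python
# Illustrative consumer for the streaming variant of /v1/query/execute.
# Assumptions: server on localhost:8000 with VLLM_V1_SPANS_ENABLED, and
# examples/offline_inference/spans/query-ab.json in the working directory.
import requests

with open("query-ab.json", "rb") as f:
    body = f.read()

with requests.post(
        "http://localhost:8000/v1/query/execute",
        params={"stream": "true"},  # `stream` is a handler query parameter
        data=body,
        stream=True,  # read the response incrementally
) as resp:
    resp.raise_for_status()
    for chunk in resp.iter_content(chunk_size=None):
        # Each chunk is one (or part of one) serialized
        # ChatCompletionStreamResponse; a fuller client would parse the JSON
        # and read choices[*].delta.content.
        print(chunk.decode("utf-8", errors="replace"), end="", flush=True)
```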
