
Commit 6c4dd70

feat: /v1/query/{prepare,execute} endpoints for span queries
This PR adds two endpoints:

- `/v1/query/prepare`
- `/v1/query/execute`

These add REST APIs around the spans core feature. They take as input a [span query](https://github.com/IBM/spnl). This PR also adds an example query under `examples/offline_inference/spans/query-{ab,ba}.json`. See the [spans readme](examples/offline_inference/spans/README.md) for an example of usage.

Signed-off-by: Nick Mitchell <[email protected]>
1 parent 1c3cd0e commit 6c4dd70
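
For illustration, here is a minimal client-side sketch of the prepare/execute flow these endpoints expose. The host/port, the example file names in the working directory, and the use of `requests` are assumptions made for the sketch (the endpoint paths and their text/plain bodies come from the diff below); it is not part of this change.

```python
# Hypothetical client for the new endpoints. Assumes a vLLM server with spans
# enabled is listening on http://localhost:8000, and that query-ab.json /
# query-ba.json (the example span queries from this commit) are in the CWD.
import time

import requests

BASE = "http://localhost:8000"


def post_query(endpoint: str, path: str) -> float:
    """POST a span-query JSON file as a text/plain body; return elapsed seconds."""
    with open(path) as f:
        body = f.read()
    start = time.time()
    resp = requests.post(f"{BASE}/v1/query/{endpoint}",
                         data=body,
                         headers={"Content-Type": "text/plain"})
    resp.raise_for_status()
    return time.time() - start


# Prepare once (slower: it preloads the cache for the independent spans), then
# execute in either span order; both executes should be fast.
print("prepare query-ab:", post_query("prepare", "query-ab.json"))
print("execute query-ba:", post_query("execute", "query-ba.json"))
print("execute query-ab:", post_query("execute", "query-ab.json"))
```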

5 files changed, +164 −0 lines changed
examples/offline_inference/spans/README.md

Lines changed: 22 additions & 0 deletions

# Span (a.k.a. Block Attention) Examples

## Span Queries

This directory contains a [span query](https://github.com/IBM/spnl#readme). To send a query, first prepare the query shape:

```bash
curl -s -XPOST http://localhost:8000/v1/query/prepare --data @./query-ab.json -o /dev/null -w "%{time_total}\n"
1.504452
```

You can then execute the query in either span order, and you should see millisecond-level TTFT:

```bash
curl -s -XPOST http://localhost:8000/v1/query/execute --data @./query-ba.json -o /dev/null -w "%{time_total}\n"
0.077699
```

```bash
curl -s -XPOST http://localhost:8000/v1/query/execute --data @./query-ab.json -o /dev/null -w "%{time_total}\n"
0.078419
```

examples/offline_inference/spans/query-ab.json

Lines changed: 26 additions & 0 deletions
Large diffs are not rendered by default.

examples/offline_inference/spans/query-ba.json

Lines changed: 26 additions & 0 deletions
Large diffs are not rendered by default.

requirements/common.txt

Lines changed: 1 addition & 0 deletions

```diff
@@ -49,3 +49,4 @@ pybase64 # fast base64 implementation
 cbor2 # Required for cross-language serialization of hashable objects
 setproctitle # Used to set process names for better debugging and monitoring
 openai-harmony >= 0.0.3 # Required for gpt-oss
+spnl >= 0.7.0
```

vllm/entrypoints/openai/api_server.py

Lines changed: 89 additions & 0 deletions

```diff
@@ -665,6 +665,95 @@ async def cancel_responses(response_id: str, raw_request: Request):
     return JSONResponse(content=response.model_dump())
 
 
+if envs.VLLM_V1_SPANS_ENABLED:
+    import spnl
+    import time
+    from fastapi import Body
+
+    from vllm import SamplingParams
+    from vllm.inputs import TokensPrompt
+    from vllm.outputs import RequestOutput
+    from vllm.entrypoints.openai.protocol import (ChatMessage,
+                                                  ChatCompletionResponseChoice,
+                                                  UsageInfo)
+
+    spnl_state = spnl.init(10)
+    PAD_TOKEN = 27
+    PLUS_TOKEN = (envs.VLLM_V1_SPANS_TOKEN_PLUS
+                  if envs.VLLM_V1_SPANS_TOKEN_PLUS >= 0 else None)
+    CROSS_TOKEN = (envs.VLLM_V1_SPANS_TOKEN_CROSS
+                   if envs.VLLM_V1_SPANS_TOKEN_CROSS >= 0 else None)
+
+    def wrap(prompt: str | list[str]) -> TokensPrompt:
+        if isinstance(prompt[0], list):
+            return [TokensPrompt(prompt_token_ids=p) for p in prompt]
+        return TokensPrompt(prompt_token_ids=prompt)
+
+    @router.post("/v1/query/prepare")
+    @with_cancellation
+    @load_aware_call
+    async def prepare_query(raw_request: Request,
+                            query: str = Body(..., media_type="text/plain")):
+        docs = [
+            wrap(doc) for doc in spnl.tokenize_prepare(
+                spnl_state,
+                query,
+                True,  # we need to preload the prefix of the plus/independent spans
+                PAD_TOKEN,
+                PLUS_TOKEN,
+                raw_request.app.state.vllm_config.cache_config.block_size)
+        ]
+
+        request_id = raw_request.headers.get("X-Request-Id") or uuid.uuid4().hex
+        client = engine_client(raw_request)
+        generators = [
+            client.generate(doc, SamplingParams(temperature=0, max_tokens=1),
+                            request_id) for doc in docs
+        ]
+        for generator in generators:
+            async for res in generator:
+                final = res.outputs[0]
+
+        return JSONResponse(content={"success": True})
+
+    @router.post("/v1/query/execute")
+    @with_cancellation
+    @load_aware_call
+    async def execute_query(raw_request: Request,
+                            query: str = Body(..., media_type="text/plain")):
+        req = spnl.tokenize_query(
+            spnl_state,
+            query,
+            PAD_TOKEN,
+            CROSS_TOKEN,
+            PLUS_TOKEN,
+            raw_request.app.state.vllm_config.cache_config.block_size)
+
+        request_id = raw_request.headers.get("X-Request-Id") or uuid.uuid4().hex
+        client = engine_client(raw_request)
+        generator = client.generate(
+            wrap(req.messages),
+            SamplingParams(
+                temperature=req.temperature
+                if req.temperature is not None else 0,
+                max_tokens=req.max_tokens if req.max_tokens is not None
+                and req.max_tokens != 0 else 2048), request_id)
+
+        # TODO streaming output...
+        final_res: Optional[RequestOutput] = None
+        async for res in generator:
+            final_res = res
+        final = final_res.outputs[0]
+        choices = [
+            ChatCompletionResponseChoice(
+                index=0,
+                message=ChatMessage(role="assistant", content=final.text),
+                logprobs=final.logprobs,
+                finish_reason=final.finish_reason,
+                stop_reason=final.stop_reason,
+            )
+        ]
+        num_prompt_tokens = 0  # TODO
+        num_generated_tokens = 0  # TODO
+        usage = UsageInfo(prompt_tokens=num_prompt_tokens,
+                          completion_tokens=num_generated_tokens,
+                          total_tokens=num_prompt_tokens + num_generated_tokens)
+        response = ChatCompletionResponse(id=request_id,
+                                          created=int(time.time()),
+                                          model=req.model,
+                                          choices=choices,
+                                          usage=usage)
+
+        return JSONResponse(content=response.model_dump())
+
 @router.post("/v1/chat/completions",
              dependencies=[Depends(validate_json_request)],
              responses={
```
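
For context, a sketch of reading the result on the client side: per the handler above, `/v1/query/execute` returns a `ChatCompletionResponse`-shaped JSON body, so the generated text can be pulled out like a regular chat completion. The host and query file are again illustrative assumptions, not part of the change.

```python
# Illustrative only: the server address and query file are assumptions; the
# response shape (choices[0].message.content) follows the ChatCompletionResponse
# built in execute_query above.
import requests

with open("query-ab.json") as f:
    resp = requests.post("http://localhost:8000/v1/query/execute",
                         data=f.read(),
                         headers={"Content-Type": "text/plain"})
resp.raise_for_status()
body = resp.json()
print(body["choices"][0]["message"]["content"])
```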

0 commit comments
