
Commit 7cdca1f

feat: /v1/query/{prepare,execute} endpoints for span queries
This PR adds two endpoints, `/v1/query/prepare` and `/v1/query/execute`, which add REST APIs around the spans core feature. They take as input a [span query](https://github.com/IBM/spnl). This PR also adds example queries under `examples/offline_inference/spans/query-{ab,ba}.json`. See the [spans readme](examples/offline_inference/spans/README.md) for an example of usage.

Signed-off-by: Nick Mitchell <[email protected]>
1 parent 1c3cd0e commit 7cdca1f
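
For context, here is a minimal client-side sketch of the two new endpoints, mirroring the curl commands in the spans readme. It assumes a locally running server with spans enabled and the example query files from this PR; the use of the `requests` library and the exact fields printed are illustrative assumptions, not part of this change.

```python
# Hypothetical client sketch mirroring the curl commands in the spans README.
# Assumes a vLLM server on localhost:8000 started with VLLM_V1_SPANS_ENABLED
# and the example query files added by this PR. `requests` is used purely for
# illustration and is not a vLLM dependency.
import requests

BASE_URL = "http://localhost:8000"

# Step 1: prepare the query shape (preloads the span prefixes).
with open("examples/offline_inference/spans/query-ab.json") as f:
    prep = requests.post(f"{BASE_URL}/v1/query/prepare", data=f.read())
prep.raise_for_status()
print(prep.json())  # {"success": true}

# Step 2: execute the query (either span ordering should now be fast).
with open("examples/offline_inference/spans/query-ba.json") as f:
    resp = requests.post(f"{BASE_URL}/v1/query/execute", data=f.read())
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
```

The prepare call is the slow step; once the span prefixes are cached, execute should return with millisecond-level TTFT in either span order, as shown in the readme timings.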

File tree

5 files changed, +171 -0 lines changed
examples/offline_inference/spans/README.md

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
# Span (a.k.a. Block Attention) Examples

## Span Queries

This directory contains a [span query](https://github.com/IBM/spnl#readme). To send a query, first prepare the query shape:

```bash
curl -s -XPOST http://localhost:8000/v1/query/prepare --data @./query-ab.json -o /dev/null -w "%{time_total}\n"
1.504452
```

Then you can execute the query in either order, and you should see millisecond-level TTFT:

```bash
curl -s -XPOST http://localhost:8000/v1/query/execute --data @./query-ba.json -o /dev/null -w "%{time_total}\n"
0.077699
```

```bash
curl -s -XPOST http://localhost:8000/v1/query/execute --data @./query-ab.json -o /dev/null -w "%{time_total}\n"
0.078419
```

examples/offline_inference/spans/query-ab.json

Lines changed: 26 additions & 0 deletions
Large diffs are not rendered by default.

examples/offline_inference/spans/query-ba.json

Lines changed: 26 additions & 0 deletions
Large diffs are not rendered by default.

requirements/common.txt

Lines changed: 1 addition & 0 deletions
@@ -49,3 +49,4 @@ pybase64 # fast base64 implementation
cbor2 # Required for cross-language serialization of hashable objects
setproctitle # Used to set process names for better debugging and monitoring
openai-harmony >= 0.0.3 # Required for gpt-oss
spnl >= 0.7.0 # Required for the span query endpoints (/v1/query/prepare, /v1/query/execute)

vllm/entrypoints/openai/api_server.py

Lines changed: 96 additions & 0 deletions
@@ -665,6 +665,102 @@ async def cancel_responses(response_id: str, raw_request: Request):
    return JSONResponse(content=response.model_dump())


if envs.VLLM_V1_SPANS_ENABLED:
    import time

    import spnl
    from fastapi import Body

    from vllm import SamplingParams
    from vllm.entrypoints.openai.protocol import (ChatCompletionResponse,
                                                  ChatCompletionResponseChoice,
                                                  ChatMessage, UsageInfo)
    from vllm.inputs import TokensPrompt
    from vllm.outputs import CompletionOutput

    spnl_state = spnl.init(10)
    PAD_TOKEN = 27
    PLUS_TOKEN = (envs.VLLM_V1_SPANS_TOKEN_PLUS
                  if envs.VLLM_V1_SPANS_TOKEN_PLUS >= 0 else None)
    CROSS_TOKEN = (envs.VLLM_V1_SPANS_TOKEN_CROSS
                   if envs.VLLM_V1_SPANS_TOKEN_CROSS >= 0 else None)

    def wrap(
        prompt: list[int] | list[list[int]]
    ) -> TokensPrompt | list[TokensPrompt]:
        """Wrap tokenized span content as TokensPrompt(s) for the engine."""
        if isinstance(prompt[0], list):
            return [TokensPrompt(prompt_token_ids=p) for p in prompt]
        return TokensPrompt(prompt_token_ids=prompt)

    @router.post("/v1/query/prepare")
    @with_cancellation
    @load_aware_call
    async def prepare_query(raw_request: Request,
                            query: str = Body(..., media_type="text/plain")):
        docs = [
            wrap(doc) for doc in spnl.tokenize_prepare(
                spnl_state,
                query,
                True,  # we need to preload the prefix of the plus/independent spans
                PAD_TOKEN,
                PLUS_TOKEN,
                raw_request.app.state.vllm_config.cache_config.block_size,
            )
        ]

        request_id = raw_request.headers.get("X-Request-Id") or uuid.uuid4().hex
        client = engine_client(raw_request)
        # Generate a single token per span so that the span prefixes are
        # loaded into the KV cache; the generated text itself is discarded.
        generators = [
            client.generate(doc, SamplingParams(temperature=0, max_tokens=1),
                            request_id) for doc in docs
        ]
        for generator in generators:
            async for res in generator:
                pass

        if isinstance(generator, ErrorResponse):
            return JSONResponse(content=generator.model_dump(),
                                status_code=generator.error.code)
        return JSONResponse(content={"success": True})

    @router.post("/v1/query/execute")
    @with_cancellation
    @load_aware_call
    async def execute_query(raw_request: Request,
                            query: str = Body(..., media_type="text/plain")):
        req = spnl.tokenize_query(
            spnl_state,
            query,
            PAD_TOKEN,
            CROSS_TOKEN,
            PLUS_TOKEN,
            raw_request.app.state.vllm_config.cache_config.block_size,
        )

        request_id = raw_request.headers.get("X-Request-Id") or uuid.uuid4().hex
        client = engine_client(raw_request)
        n = 1 if req.n <= 0 else req.n
        generator = client.generate(
            wrap(req.messages),
            SamplingParams(
                n=n,
                temperature=req.temperature
                if req.temperature is not None else 0,
                max_tokens=req.max_tokens
                if req.max_tokens is not None and req.max_tokens != 0 else 2048,
            ),
            request_id)

        # TODO streaming output...
        outputs: list[CompletionOutput | None] = [None] * n
        async for res in generator:
            for output in res.outputs:
                outputs[output.index] = output
        choices = [
            ChatCompletionResponseChoice(
                index=index,
                message=ChatMessage(role="assistant", content=output.text),
                logprobs=output.logprobs,
                finish_reason=output.finish_reason,
                stop_reason=output.stop_reason,
            ) for index, output in enumerate(outputs)
        ]
        num_prompt_tokens = 0  # TODO
        num_generated_tokens = 0  # TODO
        usage = UsageInfo(prompt_tokens=num_prompt_tokens,
                          completion_tokens=num_generated_tokens,
                          total_tokens=num_prompt_tokens + num_generated_tokens)
        response = ChatCompletionResponse(
            id=request_id,
            created=int(time.time()),
            model=req.model,
            choices=choices,
            usage=usage,
        )

        if isinstance(generator, ErrorResponse):
            return JSONResponse(content=generator.model_dump(),
                                status_code=generator.error.code)
        return JSONResponse(content=response.model_dump())


@router.post("/v1/chat/completions",
             dependencies=[Depends(validate_json_request)],
             responses={

0 commit comments