157 changes: 90 additions & 67 deletions python/sglang/srt/entrypoints/openai/serving_chat.py
@@ -37,6 +37,49 @@
)
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor

_SSE_DATA = "data: "
_SSE_NL = "\n\n"


def _pydantic_default(obj):
if hasattr(obj, "model_dump"):
return obj.model_dump()
raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _fast_sse_content(
chunk_id: str,
created: int,
model: str,
index: int,
content: str = None,
reasoning_content: str = None,
finish_reason=None,
logprobs=None,
usage=None,
) -> str:
delta = {}
if content is not None:
delta["content"] = content
if reasoning_content is not None:
delta["reasoning_content"] = reasoning_content
choice = {
"index": index,
"delta": delta,
"logprobs": logprobs,
"finish_reason": finish_reason,
}
resp = {
"id": chunk_id,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": [choice],
}
if usage is not None:
resp["usage"] = usage
return _SSE_DATA + orjson.dumps(resp, default=_pydantic_default).decode() + _SSE_NL
from sglang.srt.entrypoints.openai.utils import (
process_cached_tokens_details_from_ret,
process_hidden_states_from_ret,
@@ -721,20 +764,15 @@ async def _generate_chat_stream(
# First chunk with role
if is_firsts.get(index, True):
is_firsts[index] = False
delta = DeltaMessage(role="assistant", content="")
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=delta,
finish_reason=None,
logprobs=None,
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
created=int(time.time()),
choices=[choice_data],
model=request.model,
)
yield f"data: {chunk.model_dump_json()}\n\n"
delta = {"role": "assistant", "content": ""}
resp = {
"id": content["meta_info"]["id"],
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": request.model,
"choices": [{"index": index, "delta": delta, "finish_reason": None, "logprobs": None}],
}
yield _SSE_DATA + orjson.dumps(resp).decode() + _SSE_NL
stream_started = True

stream_buffer = stream_buffers.get(index, "")
@@ -751,27 +789,22 @@
index, delta, reasoning_parser_dict, content, request
)
if reasoning_text:
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(reasoning_content=reasoning_text),
finish_reason=None,
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
created=int(time.time()),
choices=[choice_data],
model=request.model,
)

# Add usage stats if continuous_usage_stats is enabled
usage = None
if continuous_usage_stats:
chunk.usage = UsageProcessor.calculate_token_usage(
usage = UsageProcessor.calculate_token_usage(
prompt_tokens=prompt_tokens.get(index, 0),
reasoning_tokens=reasoning_tokens.get(index, 0),
completion_tokens=completion_tokens.get(index, 0),
)

yield f"data: {chunk.model_dump_json()}\n\n"
yield _fast_sse_content(
chunk_id=content["meta_info"]["id"],
created=int(time.time()),
model=request.model,
index=index,
reasoning_content=reasoning_text,
usage=usage,
)

# Handle tool calls
if (
@@ -803,29 +836,23 @@
else:
# Regular content
if delta:
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(content=delta),
finish_reason=None,
matched_stop=None,
logprobs=choice_logprobs,
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
created=int(time.time()),
choices=[choice_data],
model=request.model,
)

# Add usage stats if continuous_usage_stats is enabled
usage = None
if continuous_usage_stats:
chunk.usage = UsageProcessor.calculate_token_usage(
usage = UsageProcessor.calculate_token_usage(
prompt_tokens=prompt_tokens.get(index, 0),
reasoning_tokens=reasoning_tokens.get(index, 0),
completion_tokens=completion_tokens.get(index, 0),
)

yield f"data: {chunk.model_dump_json()}\n\n"
yield _fast_sse_content(
chunk_id=content["meta_info"]["id"],
created=int(time.time()),
model=request.model,
index=index,
content=delta,
logprobs=choice_logprobs,
usage=usage,
)

# Send finish_reason chunks for each index that completed
for idx, finish_reason_data in finish_reasons.items():
@@ -836,27 +863,23 @@
if has_tool_calls.get(idx, False) and finish_reason_type == "stop":
final_finish_reason = "tool_calls"

finish_reason_chunk = ChatCompletionStreamResponse(
id=content["meta_info"][
"id"
], # NOTE: openai uses the same chatcmpl-id for all indices
created=int(time.time()),
choices=[
ChatCompletionResponseStreamChoice(
index=idx,
delta=DeltaMessage(),
finish_reason=final_finish_reason,
matched_stop=(
finish_reason_data["matched"]
if "matched" in finish_reason_data
else None
),
)
],
model=request.model,
usage=None,
)
yield f"data: {finish_reason_chunk.model_dump_json()}\n\n"
matched_stop = finish_reason_data.get("matched")
fr_choice = {
"index": idx,
"delta": {},
"finish_reason": final_finish_reason,
"logprobs": None,
}
if matched_stop is not None:
fr_choice["matched_stop"] = matched_stop
fr_resp = {
"id": content["meta_info"]["id"],
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": request.model,
"choices": [fr_choice],
}
yield _SSE_DATA + orjson.dumps(fr_resp).decode() + _SSE_NL

# Send hidden states if requested
if request.return_hidden_states and hidden_states:
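The serving_chat.py hunks above all follow the same pattern: instead of building a ChatCompletionResponseStreamChoice/ChatCompletionStreamResponse Pydantic model and calling model_dump_json() for every streamed chunk, the handler assembles a plain dict and serializes it with orjson, framed by the _SSE_DATA/_SSE_NL constants. Below is a minimal, self-contained sketch of that fast path; it assumes only that orjson is installed, and sse_chunk plus the example field values are illustrative rather than copied from the handler.

# Minimal sketch of the fast SSE-serialization pattern, assuming only orjson.
import time

import orjson

_SSE_DATA = "data: "
_SSE_NL = "\n\n"


def _pydantic_default(obj):
    # Fallback for values that are still Pydantic models (e.g. logprobs):
    # serialize them via model_dump() instead of raising.
    if hasattr(obj, "model_dump"):
        return obj.model_dump()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def sse_chunk(chunk_id: str, model: str, index: int, content: str) -> str:
    # Build the chat.completion.chunk payload as a plain dict and let orjson
    # serialize it in one pass, skipping Pydantic model construction entirely.
    resp = {
        "id": chunk_id,
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": model,
        "choices": [
            {
                "index": index,
                "delta": {"content": content},
                "logprobs": None,
                "finish_reason": None,
            }
        ],
    }
    return _SSE_DATA + orjson.dumps(resp, default=_pydantic_default).decode() + _SSE_NL


if __name__ == "__main__":
    # Prints a single SSE event terminated by a blank line, e.g.
    # data: {"id":"chatcmpl-123","object":"chat.completion.chunk",...}
    print(sse_chunk("chatcmpl-123", "example-model", 0, "Hello"), end="")

orjson.dumps returns bytes, hence the .decode(); the default=_pydantic_default hook keeps the fast path compatible with values that are still Pydantic models, such as the per-choice logprobs.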
25 changes: 21 additions & 4 deletions python/sglang/srt/managers/tokenizer_manager.py
@@ -1619,23 +1619,32 @@ def auto_create_handle_loop(self):
loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
)

_BATCH_NOTIFY_SIZE = 16

async def handle_loop(self):
"""The event loop that handles requests"""
while True:
with self.soft_watchdog.disable():
recv_obj = await self.recv_from_detokenizer.recv_pyobj()
self._result_dispatcher(recv_obj)
if isinstance(
recv_obj,
(BatchStrOutput, BatchEmbeddingOutput, BatchTokenIDOutput, BatchMultimodalOutput),
):
await self._handle_batch_output(recv_obj)
else:
self._result_dispatcher(recv_obj)
self.last_receive_tstamp = real_time()
self.soft_watchdog.feed()

def _handle_batch_output(
async def _handle_batch_output(
self,
recv_obj: Union[
BatchStrOutput,
BatchEmbeddingOutput,
BatchTokenIDOutput,
],
):
pending_notify: dict = {}
for i, rid in enumerate(recv_obj.rids):
state = self.rid_to_state.get(rid, None)
if state is None:
@@ -1828,16 +1837,24 @@ def _handle_batch_output(

if out_dict is not None:
state.out_list.append(out_dict)
state.event.set()
pending_notify[id(state)] = state

if len(pending_notify) >= self._BATCH_NOTIFY_SIZE:
for s in pending_notify.values():
s.event.set()
pending_notify = {}
await asyncio.sleep(0)

# Log metrics and dump
if self.enable_metrics and state.obj.log_metrics:
self.collect_metrics(state, recv_obj, i)
if self.dump_requests_folder and state.finished and state.obj.log_metrics:
self.dump_requests(state, out_dict)
if self.crash_dump_folder and state.finished and state.obj.log_metrics:
self.record_request_for_crash_dump(state, out_dict)

for s in pending_notify.values():
s.event.set()

# When skip_tokenizer_init is enabled, tokenizer_manager receives
# BatchTokenIDOutput.
if (
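The tokenizer_manager.py change batches wake-ups: instead of calling state.event.set() for every request as its output arrives, _handle_batch_output collects states in pending_notify and flushes them every _BATCH_NOTIFY_SIZE entries, yielding with await asyncio.sleep(0) so waiting coroutines get a chance to run. The sketch below reproduces that pattern against plain asyncio.Events; ReqState, BATCH_NOTIFY_SIZE, and handle_batch are illustrative stand-ins, not the actual sglang types.

# Standalone sketch of the batched event-notification pattern, standard library only.
import asyncio
from dataclasses import dataclass, field

BATCH_NOTIFY_SIZE = 16


@dataclass
class ReqState:
    out_list: list = field(default_factory=list)
    event: asyncio.Event = field(default_factory=asyncio.Event)


async def handle_batch(states: list[ReqState], outputs: list[dict]) -> None:
    pending_notify: dict[int, ReqState] = {}
    for state, out in zip(states, outputs):
        state.out_list.append(out)
        # Defer event.set(): key by object id so each state is woken at most once per flush.
        pending_notify[id(state)] = state

        if len(pending_notify) >= BATCH_NOTIFY_SIZE:
            for s in pending_notify.values():
                s.event.set()
            pending_notify = {}
            # Yield to the event loop so awaiting consumers can drain their
            # out_list before the rest of the batch is processed.
            await asyncio.sleep(0)

    # Wake whatever is left over at the end of the batch.
    for s in pending_notify.values():
        s.event.set()


async def main() -> None:
    states = [ReqState() for _ in range(40)]

    async def consumer(st: ReqState) -> dict:
        await st.event.wait()
        return st.out_list[-1]

    tasks = [asyncio.create_task(consumer(s)) for s in states]
    await handle_batch(states, [{"rid": i} for i in range(40)])
    results = await asyncio.gather(*tasks)
    print(len(results), "requests notified")


if __name__ == "__main__":
    asyncio.run(main())

Deduplicating pending_notify by id(state) means a waiter is set only once per flush even if several outputs land on the same state, and the periodic sleep(0) bounds how long a finished request waits behind the rest of the batch.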