diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index f85db58a0125..c667f6e3e70e 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -37,6 +37,49 @@ ) from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor + +_SSE_DATA = "data: " +_SSE_NL = "\n\n" + + +def _pydantic_default(obj): + if hasattr(obj, "model_dump"): + return obj.model_dump() + raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable") + + +def _fast_sse_content( + chunk_id: str, + created: int, + model: str, + index: int, + content: str = None, + reasoning_content: str = None, + finish_reason=None, + logprobs=None, + usage=None, +) -> str: + delta = {} + if content is not None: + delta["content"] = content + if reasoning_content is not None: + delta["reasoning_content"] = reasoning_content + choice = { + "index": index, + "delta": delta, + "logprobs": logprobs, + "finish_reason": finish_reason, + } + resp = { + "id": chunk_id, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": [choice], + } + if usage is not None: + resp["usage"] = usage + return _SSE_DATA + orjson.dumps(resp, default=_pydantic_default).decode() + _SSE_NL from sglang.srt.entrypoints.openai.utils import ( process_cached_tokens_details_from_ret, process_hidden_states_from_ret, @@ -721,20 +764,15 @@ async def _generate_chat_stream( # First chunk with role if is_firsts.get(index, True): is_firsts[index] = False - delta = DeltaMessage(role="assistant", content="") - choice_data = ChatCompletionResponseStreamChoice( - index=index, - delta=delta, - finish_reason=None, - logprobs=None, - ) - chunk = ChatCompletionStreamResponse( - id=content["meta_info"]["id"], - created=int(time.time()), - choices=[choice_data], - model=request.model, - ) - yield f"data: {chunk.model_dump_json()}\n\n" + delta = {"role": "assistant", "content": ""} + resp = { + "id": content["meta_info"]["id"], + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": request.model, + "choices": [{"index": index, "delta": delta, "finish_reason": None, "logprobs": None}], + } + yield _SSE_DATA + orjson.dumps(resp).decode() + _SSE_NL stream_started = True stream_buffer = stream_buffers.get(index, "") @@ -751,27 +789,22 @@ async def _generate_chat_stream( index, delta, reasoning_parser_dict, content, request ) if reasoning_text: - choice_data = ChatCompletionResponseStreamChoice( - index=index, - delta=DeltaMessage(reasoning_content=reasoning_text), - finish_reason=None, - ) - chunk = ChatCompletionStreamResponse( - id=content["meta_info"]["id"], - created=int(time.time()), - choices=[choice_data], - model=request.model, - ) - - # Add usage stats if continuous_usage_stats is enabled + usage = None if continuous_usage_stats: - chunk.usage = UsageProcessor.calculate_token_usage( + usage = UsageProcessor.calculate_token_usage( prompt_tokens=prompt_tokens.get(index, 0), reasoning_tokens=reasoning_tokens.get(index, 0), completion_tokens=completion_tokens.get(index, 0), ) - yield f"data: {chunk.model_dump_json()}\n\n" + yield _fast_sse_content( + chunk_id=content["meta_info"]["id"], + created=int(time.time()), + model=request.model, + index=index, + reasoning_content=reasoning_text, + usage=usage, + ) # Handle tool calls if ( @@ -803,29 +836,23 @@ async def _generate_chat_stream( else: # Regular content if delta: - choice_data = ChatCompletionResponseStreamChoice( - index=index, - delta=DeltaMessage(content=delta), - finish_reason=None, - matched_stop=None, - logprobs=choice_logprobs, - ) - chunk = ChatCompletionStreamResponse( - id=content["meta_info"]["id"], - created=int(time.time()), - choices=[choice_data], - model=request.model, - ) - - # Add usage stats if continuous_usage_stats is enabled + usage = None if continuous_usage_stats: - chunk.usage = UsageProcessor.calculate_token_usage( + usage = UsageProcessor.calculate_token_usage( prompt_tokens=prompt_tokens.get(index, 0), reasoning_tokens=reasoning_tokens.get(index, 0), completion_tokens=completion_tokens.get(index, 0), ) - yield f"data: {chunk.model_dump_json()}\n\n" + yield _fast_sse_content( + chunk_id=content["meta_info"]["id"], + created=int(time.time()), + model=request.model, + index=index, + content=delta, + logprobs=choice_logprobs, + usage=usage, + ) # Send finish_reason chunks for each index that completed for idx, finish_reason_data in finish_reasons.items(): @@ -836,27 +863,23 @@ async def _generate_chat_stream( if has_tool_calls.get(idx, False) and finish_reason_type == "stop": final_finish_reason = "tool_calls" - finish_reason_chunk = ChatCompletionStreamResponse( - id=content["meta_info"][ - "id" - ], # NOTE: openai uses the same chatcmpl-id for all indices - created=int(time.time()), - choices=[ - ChatCompletionResponseStreamChoice( - index=idx, - delta=DeltaMessage(), - finish_reason=final_finish_reason, - matched_stop=( - finish_reason_data["matched"] - if "matched" in finish_reason_data - else None - ), - ) - ], - model=request.model, - usage=None, - ) - yield f"data: {finish_reason_chunk.model_dump_json()}\n\n" + matched_stop = finish_reason_data.get("matched") + fr_choice = { + "index": idx, + "delta": {}, + "finish_reason": final_finish_reason, + "logprobs": None, + } + if matched_stop is not None: + fr_choice["matched_stop"] = matched_stop + fr_resp = { + "id": content["meta_info"]["id"], + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": request.model, + "choices": [fr_choice], + } + yield _SSE_DATA + orjson.dumps(fr_resp).decode() + _SSE_NL # Send hidden states if requested if request.return_hidden_states and hidden_states: diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index bfc5bef63aa5..ff16ec16240d 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -1619,16 +1619,24 @@ def auto_create_handle_loop(self): loop.create_task(print_exception_wrapper(self.sigterm_watchdog)) ) + _BATCH_NOTIFY_SIZE = 16 + async def handle_loop(self): """The event loop that handles requests""" while True: with self.soft_watchdog.disable(): recv_obj = await self.recv_from_detokenizer.recv_pyobj() - self._result_dispatcher(recv_obj) + if isinstance( + recv_obj, + (BatchStrOutput, BatchEmbeddingOutput, BatchTokenIDOutput, BatchMultimodalOutput), + ): + await self._handle_batch_output(recv_obj) + else: + self._result_dispatcher(recv_obj) self.last_receive_tstamp = real_time() self.soft_watchdog.feed() - def _handle_batch_output( + async def _handle_batch_output( self, recv_obj: Union[ BatchStrOutput, @@ -1636,6 +1644,7 @@ def _handle_batch_output( BatchTokenIDOutput, ], ): + pending_notify: dict = {} for i, rid in enumerate(recv_obj.rids): state = self.rid_to_state.get(rid, None) if state is None: @@ -1828,9 +1837,14 @@ def _handle_batch_output( if out_dict is not None: state.out_list.append(out_dict) - state.event.set() + pending_notify[id(state)] = state + + if len(pending_notify) >= self._BATCH_NOTIFY_SIZE: + for s in pending_notify.values(): + s.event.set() + pending_notify = {} + await asyncio.sleep(0) - # Log metrics and dump if self.enable_metrics and state.obj.log_metrics: self.collect_metrics(state, recv_obj, i) if self.dump_requests_folder and state.finished and state.obj.log_metrics: @@ -1838,6 +1852,9 @@ def _handle_batch_output( if self.crash_dump_folder and state.finished and state.obj.log_metrics: self.record_request_for_crash_dump(state, out_dict) + for s in pending_notify.values(): + s.event.set() + # When skip_tokenizer_init is enabled, tokensizer_manager receives # BatchTokenIDOutput. if (