Skip to content

Commit 6486db3

Browse files
authored
fix: improve streaming proxy throughput by fixing middleware and logging bottlenecks (#21501)
* fix(middleware): replace BaseHTTPMiddleware with pure ASGI middleware

  BaseHTTPMiddleware wraps streaming responses with receive_or_disconnect per chunk, blocking the event loop and causing severe throughput degradation under concurrent streaming load (53% of CPU in profiling). Converts PrometheusAuthMiddleware to a pure ASGI middleware using the __call__(scope, receive, send) protocol.

* fix(streaming): remove expensive debug logging and optimize usage stripping

  - Remove print_verbose calls that format chunk/response Pydantic objects, triggering millions of __repr__ calls (8% of CPU in profiling)
  - Guard remaining verbose_logger.debug with isEnabledFor(DEBUG) and use lazy %s formatting instead of f-strings
  - Replace usage stripping round-trip (model_dump + delete + reconstruct) with a _usage_stripped flag, deferring exclusion to serialization time

* fix(proxy): remove per-chunk debug log and use _usage_stripped flag

  - Remove verbose_proxy_logger.debug that formatted every streaming chunk
  - Honor _usage_stripped flag from streaming handler to exclude usage during model_dump_json serialization instead of reconstructing objects

* fix(proxy): remove per-chunk debug log in async_data_generator

  Remove verbose_proxy_logger.debug that formatted every streaming chunk, which triggered expensive Pydantic serialization on the hot path.

* fix indentation and add clarifying comment for usage stripping

* fix: guard calculate_total_usage against None usage in chunks

* fix: store chunk copy to preserve usage for calculate_total_usage
1 parent a9058bb commit 6486db3

File tree

4 files changed

+70
-70
lines changed

4 files changed

+70
-70
lines changed

litellm/litellm_core_utils/streaming_handler.py

Lines changed: 24 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import collections.abc
33
import datetime
44
import json
5+
import logging
56
import threading
67
import time
78
import traceback
@@ -435,7 +436,7 @@ def handle_replicate_chunk(self, chunk):
435436

436437
def handle_openai_chat_completion_chunk(self, chunk):
437438
try:
438-
print_verbose(f"\nRaw OpenAI Chunk\n{chunk}\n")
439+
439440
str_line = chunk
440441
text = ""
441442
is_finished = False
@@ -485,7 +486,7 @@ def handle_openai_chat_completion_chunk(self, chunk):
485486

486487
def handle_azure_text_completion_chunk(self, chunk):
487488
try:
488-
print_verbose(f"\nRaw OpenAI Chunk\n{chunk}\n")
489+
489490
text = ""
490491
is_finished = False
491492
finish_reason = None
@@ -506,7 +507,7 @@ def handle_azure_text_completion_chunk(self, chunk):
506507

507508
def handle_openai_text_completion_chunk(self, chunk):
508509
try:
509-
print_verbose(f"\nRaw OpenAI Chunk\n{chunk}\n")
510+
510511
text = ""
511512
is_finished = False
512513
finish_reason = None
@@ -870,9 +871,6 @@ def return_processed_chunk_logic( # noqa
870871
preserve_upstream_non_openai_attributes,
871872
)
872873

873-
print_verbose(
874-
f"completion_obj: {completion_obj}, model_response.choices[0]: {model_response.choices[0]}, response_obj: {response_obj}"
875-
)
876874
is_chunk_non_empty = self.is_chunk_non_empty(
877875
completion_obj, model_response, response_obj
878876
)
@@ -899,11 +897,9 @@ def return_processed_chunk_logic( # noqa
899897
choice_json.pop(
900898
"finish_reason", None
901899
) # for mistral etc. which return a value in their last chunk (not-openai compatible).
902-
print_verbose(f"choice_json: {choice_json}")
903900
choices.append(StreamingChoices(**choice_json))
904901
except Exception:
905902
choices.append(StreamingChoices())
906-
print_verbose(f"choices in streaming: {choices}")
907903
setattr(model_response, "choices", choices)
908904
else:
909905
return
@@ -921,9 +917,11 @@ def return_processed_chunk_logic( # noqa
921917
)
922918

923919
model_response = self.strip_role_from_delta(model_response)
924-
verbose_logger.debug(
925-
f"model_response.choices[0].delta inside is_chunk_non_empty: {model_response.choices[0].delta}"
926-
)
920+
if verbose_logger.isEnabledFor(logging.DEBUG):
921+
verbose_logger.debug(
922+
"model_response.choices[0].delta: %s",
923+
model_response.choices[0].delta,
924+
)
927925
else:
928926
## else
929927
completion_obj["content"] = model_response_str
@@ -1370,9 +1368,6 @@ def chunk_creator(self, chunk: Any): # type: ignore # noqa: PLR0915
13701368
)
13711369

13721370
model_response.model = self.model
1373-
print_verbose(
1374-
f"model_response finish reason 3: {self.received_finish_reason}; response_obj={response_obj}"
1375-
)
13761371
## FUNCTION CALL PARSING
13771372
original_chunk = (
13781373
response_obj.get("original_chunk") if response_obj is not None else None
@@ -1432,7 +1427,6 @@ def chunk_creator(self, chunk: Any): # type: ignore # noqa: PLR0915
14321427
):
14331428
t.function.arguments = ""
14341429
_json_delta = delta.model_dump()
1435-
print_verbose(f"_json_delta: {_json_delta}")
14361430
if "role" not in _json_delta or _json_delta["role"] is None:
14371431
_json_delta[
14381432
"role"
@@ -1466,11 +1460,7 @@ def chunk_creator(self, chunk: Any): # type: ignore # noqa: PLR0915
14661460
if original_chunk.choices[0].delta is None
14671461
else dict(original_chunk.choices[0].delta)
14681462
)
1469-
print_verbose(f"original delta: {delta}")
14701463
model_response.choices[0].delta = Delta(**delta)
1471-
print_verbose(
1472-
f"new delta: {model_response.choices[0].delta}"
1473-
)
14741464
except Exception:
14751465
model_response.choices[0].delta = Delta()
14761466
else:
@@ -1480,11 +1470,6 @@ def chunk_creator(self, chunk: Any): # type: ignore # noqa: PLR0915
14801470
):
14811471
return model_response
14821472
return
1483-
print_verbose(
1484-
f"model_response.choices[0].delta: {model_response.choices[0].delta}; completion_obj: {completion_obj}"
1485-
)
1486-
print_verbose(f"self.sent_first_chunk: {self.sent_first_chunk}")
1487-
14881473
## CHECK FOR TOOL USE
14891474

14901475
if "tool_calls" in completion_obj and len(completion_obj["tool_calls"]) > 0:
@@ -1915,18 +1900,9 @@ async def __anext__(self): # noqa: PLR0915
19151900
and len(chunk.parts) == 0
19161901
):
19171902
continue
1918-
# chunk_creator() does logging/stream chunk building. We need to let it know its being called in_async_func, so we don't double add chunks.
1919-
# __anext__ also calls async_success_handler, which does logging
1920-
verbose_logger.debug(
1921-
f"PROCESSED ASYNC CHUNK PRE CHUNK CREATOR: {chunk}"
1922-
)
1923-
19241903
processed_chunk: Optional[ModelResponseStream] = self.chunk_creator(
19251904
chunk=chunk
19261905
)
1927-
verbose_logger.debug(
1928-
f"PROCESSED ASYNC CHUNK POST CHUNK CREATOR: {processed_chunk}"
1929-
)
19301906
if processed_chunk is None:
19311907
continue
19321908

@@ -1943,31 +1919,28 @@ async def __anext__(self): # noqa: PLR0915
19431919
self.rules.post_call_rules(
19441920
input=self.response_uptil_now, model=self.model
19451921
)
1946-
self.chunks.append(processed_chunk)
1947-
1922+
# Store a shallow copy so usage stripping below
1923+
# does not mutate the stored chunk.
1924+
self.chunks.append(processed_chunk.model_copy())
1925+
19481926
# Add mcp_list_tools to first chunk if present
19491927
if not self.sent_first_chunk:
19501928
processed_chunk = self._add_mcp_list_tools_to_first_chunk(processed_chunk)
19511929
self.sent_first_chunk = True
1952-
if hasattr(
1953-
processed_chunk, "usage"
1954-
): # remove usage from chunk, only send on final chunk
1955-
# Convert the object to a dictionary
1956-
obj_dict = processed_chunk.model_dump()
1957-
1958-
# Remove an attribute (e.g., 'attr2')
1959-
if "usage" in obj_dict:
1960-
del obj_dict["usage"]
1961-
1962-
# Create a new object without the removed attribute
1963-
processed_chunk = self.model_response_creator(chunk=obj_dict)
1930+
if (
1931+
hasattr(processed_chunk, "usage")
1932+
and getattr(processed_chunk, "usage", None) is not None
1933+
):
1934+
# Strip usage from the outgoing chunk so
1935+
# model_dump_json(exclude_none=True) drops it.
1936+
# The copy in self.chunks retains usage for
1937+
# calculate_total_usage().
1938+
processed_chunk.usage = None # type: ignore
19641939
is_empty = is_model_response_stream_empty(
19651940
model_response=cast(ModelResponseStream, processed_chunk)
19661941
)
1967-
19681942
if is_empty:
19691943
continue
1970-
print_verbose(f"final returned processed chunk: {processed_chunk}")
19711944

19721945
# add usage as hidden param
19731946
if self.sent_last_chunk is True and self.stream_options is None:
@@ -1982,7 +1955,7 @@ async def __anext__(self): # noqa: PLR0915
19821955
)
19831956
)
19841957
# Add MCP metadata to final chunk if present (after hooks)
1985-
processed_chunk = self._add_mcp_metadata_to_final_chunk(processed_chunk)
1958+
processed_chunk = self._add_mcp_metadata_to_final_chunk(processed_chunk) # type: ignore[reportArgumentType]
19861959

19871960
return processed_chunk
19881961
raise StopAsyncIteration
@@ -1996,13 +1969,9 @@ async def __anext__(self): # noqa: PLR0915
19961969
else:
19971970
chunk = next(self.completion_stream)
19981971
if chunk is not None and chunk != b"":
1999-
print_verbose(f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}")
20001972
processed_chunk: Optional[
20011973
ModelResponseStream
20021974
] = self.chunk_creator(chunk=chunk)
2003-
print_verbose(
2004-
f"PROCESSED CHUNK POST CHUNK CREATOR: {processed_chunk}"
2005-
)
20061975
if processed_chunk is None:
20071976
continue
20081977

@@ -2193,7 +2162,7 @@ def calculate_total_usage(chunks: List[ModelResponse]) -> Usage:
21932162
prompt_tokens: int = 0
21942163
completion_tokens: int = 0
21952164
for chunk in chunks:
2196-
if "usage" in chunk:
2165+
if "usage" in chunk and chunk["usage"] is not None:
21972166
if "prompt_tokens" in chunk["usage"]:
21982167
prompt_tokens = chunk["usage"].get("prompt_tokens", 0) or 0
21992168
if "completion_tokens" in chunk["usage"]:

litellm/proxy/middleware/prometheus_auth_middleware.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,21 @@
11
"""
22
Prometheus Auth Middleware
3+
4+
Pure ASGI middleware — avoids Starlette's BaseHTTPMiddleware which wraps
5+
streaming responses with receive_or_disconnect per chunk, blocking the
6+
event loop and causing severe throughput degradation under concurrent
7+
streaming load.
38
"""
4-
from fastapi import Request
5-
from fastapi.responses import JSONResponse
6-
from starlette.middleware.base import BaseHTTPMiddleware
9+
from starlette.requests import Request
10+
from starlette.responses import JSONResponse
11+
from starlette.types import ASGIApp, Receive, Scope, Send
712

813
import litellm
914
from litellm.proxy._types import SpecialHeaders
1015
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
1116

1217

13-
class PrometheusAuthMiddleware(BaseHTTPMiddleware):
18+
class PrometheusAuthMiddleware:
1419
"""
1520
Middleware to authenticate requests to the metrics endpoint
1621
@@ -24,8 +29,15 @@ class PrometheusAuthMiddleware(BaseHTTPMiddleware):
2429
```
2530
"""
2631

27-
async def dispatch(self, request: Request, call_next):
28-
# Check if this is a request to the metrics endpoint
32+
def __init__(self, app: ASGIApp) -> None:
33+
self.app = app
34+
35+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
36+
if scope["type"] not in ("http", "websocket"):
37+
await self.app(scope, receive, send)
38+
return
39+
40+
request = Request(scope, receive)
2941

3042
if self._is_prometheus_metrics_endpoint(request):
3143
if self._should_run_auth_on_metrics_endpoint() is True:
@@ -38,15 +50,14 @@ async def dispatch(self, request: Request, call_next):
3850
or "",
3951
)
4052
except Exception as e:
41-
return JSONResponse(
53+
response = JSONResponse(
4254
status_code=401,
4355
content=f"Unauthorized access to metrics endpoint: {getattr(e, 'message', str(e))}",
4456
)
57+
await response(scope, receive, send)
58+
return
4559

46-
# Process the request and get the response
47-
response = await call_next(request)
48-
49-
return response
60+
await self.app(scope, receive, send)
5061

5162
@staticmethod
5263
def _is_prometheus_metrics_endpoint(request: Request):

litellm/proxy/proxy_server.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5074,10 +5074,6 @@ async def async_data_generator(
50745074
response=response,
50755075
request_data=request_data,
50765076
):
5077-
verbose_proxy_logger.debug(
5078-
"async_data_generator: received streaming chunk - {}".format(chunk)
5079-
)
5080-
50815077
### CALL HOOKS ### - modify outgoing data
50825078
chunk = await proxy_logging_obj.async_post_call_streaming_hook(
50835079
user_api_key_dict=user_api_key_dict,
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
"""
2+
Tests that PrometheusAuthMiddleware is a pure ASGI middleware (not BaseHTTPMiddleware).
3+
4+
BaseHTTPMiddleware wraps streaming responses with receive_or_disconnect per chunk,
5+
which blocks the event loop and causes severe throughput degradation.
6+
"""
7+
from starlette.middleware.base import BaseHTTPMiddleware
8+
9+
from litellm.proxy.middleware.prometheus_auth_middleware import PrometheusAuthMiddleware
10+
11+
12+
def test_is_not_base_http_middleware():
13+
"""PrometheusAuthMiddleware must NOT inherit from BaseHTTPMiddleware."""
14+
assert not issubclass(PrometheusAuthMiddleware, BaseHTTPMiddleware), (
15+
"PrometheusAuthMiddleware should be a pure ASGI middleware, not BaseHTTPMiddleware. "
16+
"BaseHTTPMiddleware causes severe streaming performance degradation."
17+
)
18+
19+
20+
def test_has_asgi_call_protocol():
21+
"""PrometheusAuthMiddleware must implement the ASGI __call__ protocol."""
22+
assert "__call__" in PrometheusAuthMiddleware.__dict__, (
23+
"PrometheusAuthMiddleware must define __call__(self, scope, receive, send)"
24+
)

0 commit comments

Comments
 (0)