Skip to content

Commit aa971fa

Browse files
authored
fix(proxy): complete cache-locality fix for prompt cache hit rate restoration (#273)
* Implement cache locality fix (Tasks 1-5): model-class prefix, budget reallocation, settings columns, bridge TTL wiring, observability logs * test(migration): verify new dashboard_settings columns * feat(settings): expose cache-locality fields in dashboard API * refactor(test): eliminate stickiness test replica, import production code * feat(observability): populate bridge event cache_key_family and model_class - Enrich all 15 bridge event call-sites with cache_key_family and model_class kwargs - cache_key_family sourced from key.affinity_kind (always available) - model_class derived from _extract_model_class() with None fallback - No new log events created, only enrichment of existing ones - Updated test fixture with new settings fields - All tests passing (17 stickiness tests, bridge integration tests) * fix(lb): use strict > for budget_threshold_pct to match intent (>95% used) * fix(test): update idle_ttl_seconds assertion to >= after prompt-cache TTL wiring * feat(lb): wire configurable budget_threshold_pct through load_balancer * feat(observability): log reallocation-orphan bridge event * test(integration): add regression suite for cache-locality completion * fix(review): address final verification wave rejections - Fix A: Rename test wrapper _select_with_stickiness to _invoke_stickiness (22 call sites) - Fix B: Move AccountsRepository and StickySessionsRepository to TYPE_CHECKING block; inline get_settings import - Fix C: Remove explicit generic type from asyncio.Future annotation in bridge test - Fix D: Collapse multi-line getattr to single line for grep verification All verification tests pass. * fix(ci): resolve ruff lint and ty type-check errors * fix(proxy): preserve sticky grace path and selection budget * fix(types): align proxy and test typing contracts * fix(proxy): scope sticky reallocation and release retry lock * fix(proxy): restore entropy for empty cache-key fallback * fix(test): align proxy fixtures with settings wiring * fix(proxy): scope far-reset reallocation to prompt cache * fix(proxy): preserve bridge ttl and codex model bucketing * fix(test): speed up postgres fixture resets * fix(test): preserve postgres fixture semantics
1 parent c8b6c00 commit aa971fa

33 files changed

Lines changed: 2438 additions & 1396 deletions

app/core/clients/model_fetcher.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,23 @@ def _list_raw(data: dict[str, JsonValue], key: str) -> list[JsonValue]:
4848
return []
4949

5050

51+
def _parse_reasoning_level(value: JsonValue) -> ReasoningLevel | None:
52+
if not isinstance(value, dict):
53+
return None
54+
effort = value.get("effort")
55+
description = value.get("description")
56+
if not isinstance(effort, str) or not isinstance(description, str):
57+
return None
58+
return ReasoningLevel(effort=effort, description=description)
59+
60+
5161
def _parse_upstream_model(data: dict[str, JsonValue]) -> UpstreamModel:
5262
raw = {k: v for k, v in data.items() if k not in _FILTERED_FIELDS}
5363

5464
reasoning_levels = tuple(
55-
ReasoningLevel(effort=rl.get("effort", ""), description=rl.get("description", ""))
65+
parsed_level
5666
for rl in _list_raw(data, "supported_reasoning_levels")
57-
if isinstance(rl, dict)
67+
if (parsed_level := _parse_reasoning_level(rl)) is not None
5868
)
5969

6070
available_in_plans = frozenset(p for p in _list_raw(data, "available_in_plans") if isinstance(p, str))

app/core/clients/proxy.py

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,26 @@
1212
import socket
1313
import time
1414
from dataclasses import dataclass
15-
from typing import Any, AsyncContextManager, AsyncIterator, Awaitable, Mapping, Protocol, TypeAlias, cast
15+
from typing import Any, AsyncContextManager, AsyncIterator, Awaitable, Callable, Mapping, Protocol, TypeAlias, cast
1616
from urllib.parse import ParseResult, urlparse, urlunparse
1717

1818
import aiohttp
1919
from aiohttp import hdrs
2020
from aiohttp.client_ws import DEFAULT_WS_CLIENT_TIMEOUT
21+
from aiohttp.http_websocket import WS_KEY, WebSocketReader, WebSocketWriter
2122
from multidict import CIMultiDict
2223

2324
from app.core.clients.http import get_http_client
2425
from app.core.config.settings import get_settings
25-
from app.core.errors import OpenAIErrorEnvelope, ResponseFailedEvent, openai_error, response_failed_event
26+
from app.core.errors import (
27+
OpenAIErrorDetail,
28+
OpenAIErrorEnvelope,
29+
ResponseFailedEvent,
30+
openai_error,
31+
response_failed_event,
32+
)
2633
from app.core.openai.model_registry import get_model_registry
27-
from app.core.openai.models import CompactResponsePayload
34+
from app.core.openai.models import CompactResponsePayload, OpenAIError
2835
from app.core.openai.parsing import (
2936
parse_compact_response_payload,
3037
parse_error_payload,
@@ -408,7 +415,7 @@ def _error_payload_from_websocket_handshake_error(exc: aiohttp.WSServerHandshake
408415
if extracted is not None:
409416
error = parse_error_payload(extracted)
410417
if error is not None:
411-
return {"error": error.model_dump(exclude_none=True)}
418+
return {"error": _openai_error_detail(error)}
412419

413420
code = _infer_websocket_handshake_error_code(exc.status, message)
414421
if code == "invalid_api_key":
@@ -694,13 +701,32 @@ async def _error_payload_from_response(resp: ErrorResponse) -> OpenAIErrorEnvelo
694701
if isinstance(data, dict):
695702
error = parse_error_payload(data)
696703
if error:
697-
return {"error": error.model_dump(exclude_none=True)}
704+
return {"error": _openai_error_detail(error)}
698705
message = _extract_upstream_message(data)
699706
if message:
700707
return openai_error("upstream_error", message)
701708
return openai_error("upstream_error", fallback_message)
702709

703710

711+
def _openai_error_detail(error: OpenAIError) -> OpenAIErrorDetail:
712+
detail: OpenAIErrorDetail = {}
713+
if error.message is not None:
714+
detail["message"] = error.message
715+
if error.type is not None:
716+
detail["type"] = error.type
717+
if error.code is not None:
718+
detail["code"] = error.code
719+
if error.param is not None:
720+
detail["param"] = error.param
721+
if error.plan_type is not None:
722+
detail["plan_type"] = error.plan_type
723+
if error.resets_at is not None:
724+
detail["resets_at"] = error.resets_at
725+
if error.resets_in_seconds is not None:
726+
detail["resets_in_seconds"] = error.resets_in_seconds
727+
return detail
728+
729+
704730
def _extract_upstream_message(data: Mapping[str, object]) -> str | None:
705731
for key in ("message", "detail", "error"):
706732
value = data.get(key)
@@ -856,8 +882,8 @@ async def _open_upstream_websocket(
856882
connect_timeout_seconds: float,
857883
max_msg_size: int,
858884
) -> tuple[AsyncContextManager[aiohttp.ClientWebSocketResponse], aiohttp.ClientWebSocketResponse]:
859-
request = getattr(session, "request", None)
860-
if not callable(request):
885+
request_obj = getattr(session, "request", None)
886+
if not callable(request_obj):
861887
websocket_cm = session.ws_connect(
862888
url,
863889
headers=headers,
@@ -868,6 +894,7 @@ async def _open_upstream_websocket(
868894
)
869895
websocket = await asyncio.wait_for(websocket_cm.__aenter__(), timeout=connect_timeout_seconds)
870896
return websocket_cm, websocket
897+
request = cast(Callable[..., Awaitable[aiohttp.ClientResponse]], request_obj)
871898

872899
request_headers = CIMultiDict(headers)
873900
request_headers.setdefault(hdrs.UPGRADE, "websocket")
@@ -910,7 +937,7 @@ async def _raise_handshake_error(message: str) -> None:
910937
await _raise_handshake_error("Invalid connection header")
911938

912939
response_key = resp.headers.get(hdrs.SEC_WEBSOCKET_ACCEPT, "")
913-
expected_key = base64.b64encode(hashlib.sha1(sec_key.encode() + aiohttp.client.WS_KEY).digest()).decode()
940+
expected_key = base64.b64encode(hashlib.sha1(sec_key.encode() + WS_KEY).digest()).decode()
914941
if response_key != expected_key:
915942
await _raise_handshake_error("Invalid challenge response")
916943

@@ -922,9 +949,10 @@ async def _raise_handshake_error(message: str) -> None:
922949

923950
transport = conn.transport
924951
assert transport is not None
925-
reader = aiohttp.client.WebSocketDataQueue(conn_proto, 2**16, loop=session._loop)
926-
conn_proto.set_parser(aiohttp.client.WebSocketReader(reader, max_msg_size), reader)
927-
writer = aiohttp.client.WebSocketWriter(conn_proto, transport, use_mask=True, compress=0, notakeover=False)
952+
web_socket_data_queue = cast(Callable[..., Any], getattr(aiohttp.client_ws, "WebSocketDataQueue"))
953+
reader = web_socket_data_queue(conn_proto, 2**16, loop=session._loop)
954+
conn_proto.set_parser(WebSocketReader(reader, max_msg_size), reader)
955+
writer = WebSocketWriter(conn_proto, transport, use_mask=True, compress=0, notakeover=False)
928956
except BaseException:
929957
resp.close()
930958
raise

app/core/clients/proxy_websocket.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020

2121
from app.core.clients.proxy import ProxyResponseError, filter_inbound_headers
2222
from app.core.config.settings import get_settings
23-
from app.core.errors import OpenAIErrorEnvelope, openai_error
23+
from app.core.errors import OpenAIErrorDetail, OpenAIErrorEnvelope, openai_error
24+
from app.core.openai.models import OpenAIError
2425
from app.core.openai.parsing import parse_error_payload
2526
from app.core.utils.request_id import get_request_id
2627

@@ -262,4 +263,23 @@ def _try_parse_handshake_error_payload(
262263
error = parse_error_payload(payload)
263264
if error is None:
264265
return None
265-
return {"error": error.model_dump(exclude_none=True)}
266+
return {"error": _openai_error_detail(error)}
267+
268+
269+
def _openai_error_detail(error: OpenAIError) -> OpenAIErrorDetail:
270+
detail: OpenAIErrorDetail = {}
271+
if error.message is not None:
272+
detail["message"] = error.message
273+
if error.type is not None:
274+
detail["type"] = error.type
275+
if error.code is not None:
276+
detail["code"] = error.code
277+
if error.param is not None:
278+
detail["param"] = error.param
279+
if error.plan_type is not None:
280+
detail["plan_type"] = error.plan_type
281+
if error.resets_at is not None:
282+
detail["resets_at"] = error.resets_at
283+
if error.resets_in_seconds is not None:
284+
detail["resets_in_seconds"] = error.resets_in_seconds
285+
return detail

app/core/middleware/request_decompression.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def _replace_request_body(request: Request, body: bytes) -> None:
111111
headers.append((b"content-length", str(len(body)).encode("ascii")))
112112
request.scope["headers"] = headers
113113
# Ensure subsequent request.headers reflects the updated scope headers.
114-
request._headers = None
114+
request.__dict__.pop("_headers", None)
115115

116116

117117
def add_request_decompression_middleware(app: FastAPI) -> None:

app/core/openai/chat_requests.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -265,11 +265,13 @@ def _text_format_from_parsed(parsed: ChatResponseFormat) -> ResponsesTextFormat:
265265
json_schema = parsed.json_schema
266266
if json_schema is None:
267267
raise ValueError("'response_format.json_schema' is required when type is 'json_schema'.")
268-
return ResponsesTextFormat(
269-
type=parsed.type,
270-
schema_=json_schema.schema_,
271-
name=json_schema.name,
272-
strict=json_schema.strict,
268+
return ResponsesTextFormat.model_validate(
269+
{
270+
"type": parsed.type,
271+
"schema": json_schema.schema_,
272+
"name": json_schema.name,
273+
"strict": json_schema.strict,
274+
}
273275
)
274276
if parsed.type in ("json_object", "text"):
275277
return ResponsesTextFormat(type=parsed.type)

app/db/alembic/versions/20260213_000000_base_schema.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,18 @@ def upgrade() -> None:
178178
nullable=False,
179179
server_default=sa.text("CURRENT_TIMESTAMP"),
180180
),
181+
sa.Column(
182+
"http_responses_session_bridge_prompt_cache_idle_ttl_seconds",
183+
sa.Integer(),
184+
nullable=False,
185+
server_default=sa.text("3600"),
186+
),
187+
sa.Column(
188+
"sticky_reallocation_budget_threshold_pct",
189+
sa.Float(),
190+
nullable=False,
191+
server_default=sa.text("95.0"),
192+
),
181193
)
182194

183195
created_api_keys = not _table_exists(bind, "api_keys")
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
"""add cache-locality settings fields to dashboard_settings
2+
3+
Revision ID: 20260330_000000_add_cache_locality_settings
4+
Revises: 20260325_000000_add_request_log_cost
5+
Create Date: 2026-03-30
6+
"""
7+
8+
from __future__ import annotations
9+
10+
import sqlalchemy as sa
11+
from alembic import op
12+
from sqlalchemy.engine import Connection
13+
14+
# revision identifiers, used by Alembic.
15+
revision = "20260330_000000_add_cache_locality_settings"
16+
down_revision = "20260325_000000_add_request_log_cost"
17+
branch_labels = None
18+
depends_on = None
19+
20+
21+
def _columns(connection: Connection, table_name: str) -> set[str]:
22+
inspector = sa.inspect(connection)
23+
if not inspector.has_table(table_name):
24+
return set()
25+
return {column["name"] for column in inspector.get_columns(table_name)}
26+
27+
28+
def upgrade() -> None:
29+
bind = op.get_bind()
30+
columns = _columns(bind, "dashboard_settings")
31+
if not columns:
32+
return
33+
34+
if "http_responses_session_bridge_prompt_cache_idle_ttl_seconds" not in columns:
35+
with op.batch_alter_table("dashboard_settings") as batch_op:
36+
batch_op.add_column(
37+
sa.Column(
38+
"http_responses_session_bridge_prompt_cache_idle_ttl_seconds",
39+
sa.Integer(),
40+
nullable=False,
41+
server_default="3600",
42+
)
43+
)
44+
45+
if "sticky_reallocation_budget_threshold_pct" not in columns:
46+
with op.batch_alter_table("dashboard_settings") as batch_op:
47+
batch_op.add_column(
48+
sa.Column(
49+
"sticky_reallocation_budget_threshold_pct",
50+
sa.Float(),
51+
nullable=False,
52+
server_default="95.0",
53+
)
54+
)
55+
56+
57+
def downgrade() -> None:
58+
bind = op.get_bind()
59+
columns = _columns(bind, "dashboard_settings")
60+
if not columns:
61+
return
62+
63+
if "http_responses_session_bridge_prompt_cache_idle_ttl_seconds" in columns:
64+
with op.batch_alter_table("dashboard_settings") as batch_op:
65+
batch_op.drop_column("http_responses_session_bridge_prompt_cache_idle_ttl_seconds")
66+
67+
if "sticky_reallocation_budget_threshold_pct" in columns:
68+
with op.batch_alter_table("dashboard_settings") as batch_op:
69+
batch_op.drop_column("sticky_reallocation_budget_threshold_pct")

app/db/models.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,18 @@ class DashboardSettings(Base):
201201
)
202202
totp_secret_encrypted: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True)
203203
totp_last_verified_step: Mapped[int | None] = mapped_column(Integer, nullable=True)
204+
http_responses_session_bridge_prompt_cache_idle_ttl_seconds: Mapped[int] = mapped_column(
205+
Integer,
206+
default=3600,
207+
server_default=text("3600"),
208+
nullable=False,
209+
)
210+
sticky_reallocation_budget_threshold_pct: Mapped[float] = mapped_column(
211+
Float,
212+
default=95.0,
213+
server_default=text("95.0"),
214+
nullable=False,
215+
)
204216
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now(), nullable=False)
205217
updated_at: Mapped[datetime] = mapped_column(
206218
DateTime,

app/modules/firewall/repository.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
from collections.abc import Sequence
4+
35
from sqlalchemy import select
46
from sqlalchemy.exc import IntegrityError
57
from sqlalchemy.ext.asyncio import AsyncSession
@@ -15,7 +17,7 @@ class FirewallRepository:
1517
def __init__(self, session: AsyncSession) -> None:
1618
self._session = session
1719

18-
async def list_entries(self) -> list[ApiFirewallAllowlist]:
20+
async def list_entries(self) -> Sequence[ApiFirewallAllowlist]:
1921
result = await self._session.execute(
2022
select(ApiFirewallAllowlist).order_by(ApiFirewallAllowlist.created_at, ApiFirewallAllowlist.ip_address)
2123
)

0 commit comments

Comments
 (0)